module emote.algorithms.amp
Functions
def gradient_loss_function(model_output, model_input) -> Tensor
Given the inputs and outputs of an nn.Module, computes the sum of squared derivatives of the outputs with respect to the inputs.
Arguments:

- model_output (Tensor): the output of the nn.Module
- model_input (Tensor): the input to the nn.Module

Returns:

- loss (Tensor): the sum of squared derivatives
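A minimal sketch of how such a gradient penalty can be computed with `torch.autograd.grad`; the function name and the averaging over the batch are assumptions for illustration, not necessarily what the library does internally.

```python
import torch
from torch import Tensor


def gradient_loss_sketch(model_output: Tensor, model_input: Tensor) -> Tensor:
    # Differentiate the outputs with respect to the inputs, keeping the graph
    # so the penalty itself can be backpropagated through.
    (grads,) = torch.autograd.grad(
        outputs=model_output,
        inputs=model_input,
        grad_outputs=torch.ones_like(model_output),
        create_graph=True,
    )
    # Sum of squared derivatives per sample; averaging over the batch is an
    # assumption of this sketch.
    return (grads**2).sum(dim=-1).mean()
```

Note that `model_input` must have `requires_grad=True` and be part of the graph that produced `model_output` for the derivatives to be defined.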
Classes
class DiscriminatorLoss(LossCallback):
This loss is used to train a discriminator for adversarial training.
Methods
def __init__(
    self,
    discriminator,
    imitation_state_map_fn,
    policy_state_map_fn,
    grad_loss_weight,
    optimizer,
    lr_schedule,
    max_grad_norm,
    input_key,
    name
) -> None
Arguments:
- discriminator (nn.Module)
- imitation_state_map_fn (Callable[[Tensor], Tensor])
- policy_state_map_fn (Callable[[Tensor], Tensor])
- grad_loss_weight (float)
- optimizer (torch.optim.Optimizer)
- lr_schedule (torch.optim.lr_scheduler._LRScheduler)
- max_grad_norm (float)
- input_key (str) (default: features)
- name (str) (default: Discriminator)
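A hypothetical construction of `DiscriminatorLoss`; the discriminator architecture, optimizer settings, and identity map functions below are placeholder choices for illustration, not library defaults.

```python
import torch
from torch import nn

from emote.algorithms.amp import DiscriminatorLoss

# Placeholder discriminator over 64-dimensional features.
discriminator = nn.Sequential(
    nn.Linear(64, 256),
    nn.ReLU(),
    nn.Linear(256, 256),
    nn.ReLU(),
    nn.Linear(256, 1),
)
optimizer = torch.optim.Adam(discriminator.parameters(), lr=1e-4)
lr_schedule = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1.0)

disc_loss = DiscriminatorLoss(
    discriminator=discriminator,
    imitation_state_map_fn=lambda obs: obs,  # identity placeholder
    policy_state_map_fn=lambda obs: obs,  # identity placeholder
    grad_loss_weight=10.0,
    optimizer=optimizer,
    lr_schedule=lr_schedule,
    max_grad_norm=1.0,
    input_key="features",
    name="Discriminator",
)
```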
def loss(self, imitation_batch, policy_batch) -> Tensor
Computes the discriminator loss.
Arguments:
- imitation_batch (dict): a batch of data from the reference animation. The discriminator is trained to classify data from this batch as positive samples.
- policy_batch (dict): a batch of data from the RL buffer. The discriminator is trained to classify data from this batch as negative samples.
Returns:
- loss (Tensor): the loss tensor
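Schematically, an AMP-style discriminator loss combines a least-squares objective (reference data pushed toward +1, policy data toward -1) with the gradient penalty from `gradient_loss_function`. The sketch below illustrates that idea under those assumptions; it is not the library's exact implementation.

```python
from torch import Tensor, nn

from emote.algorithms.amp import gradient_loss_function


def discriminator_loss_sketch(
    discriminator: nn.Module,
    imitation_obs: Tensor,
    policy_obs: Tensor,
    grad_loss_weight: float,
) -> Tensor:
    # Reference (imitation) samples need gradients for the penalty term.
    imitation_obs.requires_grad_(True)
    pos_scores = discriminator(imitation_obs)
    neg_scores = discriminator(policy_obs)

    # Least-squares targets: +1 for reference data, -1 for policy data.
    pos_loss = ((pos_scores - 1.0) ** 2).mean()
    neg_loss = ((neg_scores + 1.0) ** 2).mean()
    grad_penalty = gradient_loss_function(pos_scores, imitation_obs)

    return pos_loss + neg_loss + grad_loss_weight * grad_penalty
```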
 
class AMPReward(LoggingMixin, Callback):
Adversarial reward computation with AMP.
Methods
def __init__(
    self,
    discriminator,
    state_map_fn,
    style_reward_weight,
    rollout_length,
    observation_key,
    data_group
) -> None
Arguments:
- discriminator (nn.Module)
- state_map_fn (Callable[[Tensor], Tensor])
- style_reward_weight (float)
- rollout_length (int)
- observation_key (str)
- data_group (str)
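A hypothetical instantiation of `AMPReward`, reusing the discriminator from the earlier example; all values and key names below are placeholders, not defaults.

```python
from emote.algorithms.amp import AMPReward

amp_reward = AMPReward(
    discriminator=discriminator,
    state_map_fn=lambda obs: obs,  # identity placeholder
    style_reward_weight=0.5,
    rollout_length=1,
    observation_key="features",
    data_group="default",
)
```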
def begin_batch(self, observation, next_observation, rewards) -> None
Updates the task reward by adding the weighted AMP style reward.
Arguments:
- observation (dict[str, Tensor]): current observation
- next_observation (dict[str, Tensor]): next observation
- rewards (Tensor): task reward
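In AMP (Peng et al., 2021), the style reward is typically derived from the discriminator's score of a state transition and added to the task reward. The sketch below illustrates that shape; the clipped quadratic form and the concatenation of (state, next state) are assumptions of this sketch rather than the library's exact code.

```python
import torch
from torch import Tensor, nn


def style_reward_sketch(
    discriminator: nn.Module,
    state: Tensor,
    next_state: Tensor,
    task_reward: Tensor,
    style_reward_weight: float,
) -> Tensor:
    # Score the (state, next_state) transition without tracking gradients.
    with torch.no_grad():
        d = discriminator(torch.cat([state, next_state], dim=-1))
    # Style reward is high when the discriminator rates the transition as
    # close to the reference data (score near +1).
    style_reward = torch.clamp(1.0 - 0.25 * (d - 1.0) ** 2, min=0.0)
    return task_reward + style_reward_weight * style_reward
```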