module emote.algorithms.amp

Functions

def gradient_loss_function(model_output, model_input) -> Tensor

Given the inputs and outputs of an nn.Module, computes the sum of squared derivatives of the outputs with respect to the inputs. A usage sketch follows the argument list below.

Arguments:

  • model_output(Tensor): the output of the nn.Module
  • model_input(Tensor): the input to the nn.Module

Returns:

  • loss (Tensor): the sum of squared derivatives

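A minimal usage sketch; the discriminator network, feature size, and batch size are illustrative and not part of the module:

import torch
from torch import nn

from emote.algorithms.amp import gradient_loss_function

# Illustrative discriminator; any nn.Module mapping tensors to tensors works.
disc = nn.Sequential(nn.Linear(8, 64), nn.ReLU(), nn.Linear(64, 1))

model_input = torch.randn(32, 8, requires_grad=True)  # inputs must track gradients
model_output = disc(model_input)

# Sum of squared d(output)/d(input); DiscriminatorLoss presumably scales a
# penalty of this form by grad_loss_weight when training the discriminator.
penalty = gradient_loss_function(model_output, model_input)
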
Classes

class DiscriminatorLoss(LossCallback):

This loss is used to train a discriminator for adversarial training.

Methods

def __init__(
    self,
    discriminator,
    imitation_state_map_fn,
    policy_state_map_fn,
    grad_loss_weight,
    optimizer,
    lr_schedule,
    max_grad_norm,
    input_key,
    name
) -> None

Arguments:

  • discriminator(nn.Module)
  • imitation_state_map_fn(Callable[[Tensor], Tensor])
  • policy_state_map_fn(Callable[[Tensor], Tensor])
  • grad_loss_weight(float)
  • optimizer(torch.optim.Optimizer)
  • lr_schedule(torch.optim.lr_scheduler._LRScheduler)
  • max_grad_norm(float)
  • input_key(str) (default: features)
  • name(str) (default: Discriminator)
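
A hedged construction sketch using the arguments above; the discriminator network, the mapping functions, and the hyper-parameter values are placeholders, not library defaults:

import torch
from torch import nn

from emote.algorithms.amp import DiscriminatorLoss

# Placeholder discriminator; the input width depends on your feature mapping.
discriminator = nn.Sequential(nn.Linear(16, 256), nn.ReLU(), nn.Linear(256, 1))

optimizer = torch.optim.Adam(discriminator.parameters(), lr=1e-4)
lr_schedule = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda step: 1.0)

disc_loss = DiscriminatorLoss(
    discriminator=discriminator,
    imitation_state_map_fn=lambda obs: obs,  # map reference-animation states to discriminator features
    policy_state_map_fn=lambda obs: obs,     # map RL-buffer states to discriminator features
    grad_loss_weight=5.0,                    # illustrative weight on the gradient penalty
    optimizer=optimizer,
    lr_schedule=lr_schedule,
    max_grad_norm=1.0,
    input_key="features",
    name="Discriminator",
)
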
def loss(self, imitation_batch, policy_batch) -> Tensor

Computes the discriminator loss. A sketch of the formulation follows the Returns list below.

Arguments:

  • imitation_batch(dict): a batch of data from the reference animation. The discriminator is trained to classify data from this batch as positive samples.
  • policy_batch(dict): a batch of data from the RL buffer. The discriminator is trained to classify data from this batch as negative samples.

Returns:

  • loss (Tensor): the loss tensor

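The positive/negative classification described above matches the least-squares discriminator objective with a gradient penalty from the AMP paper (Peng et al., 2021). The standalone sketch below shows that formulation; the +1/-1 targets, the batch reductions, and applying the penalty to the positive samples are assumptions about this implementation rather than something the docstring guarantees:

import torch
from torch import nn

def amp_discriminator_loss(
    discriminator: nn.Module,
    imitation_input: torch.Tensor,  # mapped reference-animation features (positive samples)
    policy_input: torch.Tensor,     # mapped RL-buffer features (negative samples)
    grad_loss_weight: float,
) -> torch.Tensor:
    imitation_input = imitation_input.clone().requires_grad_(True)

    pos = discriminator(imitation_input)
    neg = discriminator(policy_input)

    # Least-squares targets: +1 for reference data, -1 for policy data.
    pos_loss = (pos - 1.0).square().mean()
    neg_loss = (neg + 1.0).square().mean()

    # Gradient penalty on the positive samples, as in gradient_loss_function.
    grads = torch.autograd.grad(
        pos, imitation_input, grad_outputs=torch.ones_like(pos), create_graph=True
    )[0]
    grad_penalty = grads.square().sum(dim=-1).mean()

    return pos_loss + neg_loss + grad_loss_weight * grad_penalty
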
class AMPReward(LoggingMixin, Callback):

Adversarial rewarding with AMP (Adversarial Motion Priors).

Methods

def __init__(
    self,
    discriminator,
    state_map_fn,
    style_reward_weight,
    rollout_length,
    observation_key,
    data_group
) -> None

Arguments:

  • discriminator(nn.Module)
  • state_map_fn(Callable[[Tensor], Tensor])
  • style_reward_weight(float)
  • rollout_length(int)
  • observation_key(str)
  • data_group(str)
def begin_batch(self, observation, next_observation, rewards) -> None

Updates the reward by adding the weighted AMP style reward; see the sketch after the argument list below.

Arguments:

  • observation(dict[str, Tensor]): current observation
  • next_observation(dict[str, Tensor]): next observation
  • rewards(Tensor): task reward
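
The style reward is derived from the discriminator's score of the observed transition. A minimal sketch, assuming the AMP-paper form of the reward, a discriminator that scores concatenated (state, next state) features, and a purely additive weighting; each of these is an assumption about this implementation:

import torch
from torch import nn
from typing import Callable

def amp_augmented_reward(
    discriminator: nn.Module,
    state_map_fn: Callable[[torch.Tensor], torch.Tensor],
    observation: torch.Tensor,
    next_observation: torch.Tensor,
    task_reward: torch.Tensor,
    style_reward_weight: float,
) -> torch.Tensor:
    # Score the transition with the discriminator (input layout is an assumption).
    disc_input = torch.cat(
        [state_map_fn(observation), state_map_fn(next_observation)], dim=-1
    )
    d = discriminator(disc_input)

    # AMP-paper style reward: max(0, 1 - 0.25 * (d - 1)^2), which lies in [0, 1].
    style_reward = torch.clamp(1.0 - 0.25 * (d - 1.0) ** 2, min=0.0)

    # Combine with the task reward; begin_batch writes a reward of this kind
    # back into the batch (the exact mixing is an assumption).
    return task_reward + style_reward_weight * style_reward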