module emote.algorithms.amp
Functions
def gradient_loss_function(model_output, model_input) -> Tensor
Given the inputs and outputs of an nn.Module, computes the sum of squared derivatives of the outputs with respect to the inputs.

Arguments:
model_output(Tensor)
: the output of the nn.Module
model_input(Tensor)
: the input to the nn.Module
Returns:
- loss (Tensor): the sum of squared derivatives
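A minimal sketch of what this computes, assuming the derivative is taken with torch.autograd.grad and the per-sample sums of squares are averaged over the batch (the reduction is an assumption; only the sum of squares is stated above):

```python
import torch
from torch import Tensor

def gradient_penalty_sketch(model_output: Tensor, model_input: Tensor) -> Tensor:
    # Differentiate the outputs w.r.t. the inputs, keeping the graph so the
    # penalty itself can be backpropagated through. model_input must have
    # requires_grad=True when it is fed to the module.
    (grads,) = torch.autograd.grad(
        outputs=model_output,
        inputs=model_input,
        grad_outputs=torch.ones_like(model_output),
        create_graph=True,
    )
    # Sum of squared derivatives per sample, then a batch mean (assumed).
    return torch.sum(grads**2, dim=-1).mean()
```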
Classes
class DiscriminatorLoss(LossCallback):
This loss is used to train a discriminator for adversarial imitation learning.
Methods
def __init__(
self,
discriminator,
imitation_state_map_fn,
policy_state_map_fn,
grad_loss_weight,
optimizer,
lr_schedule,
max_grad_norm,
input_key,
name
) -> None
Arguments:
discriminator(nn.Module)
imitation_state_map_fn(Callable[[Tensor], Tensor])
policy_state_map_fn(Callable[[Tensor], Tensor])
grad_loss_weight(float)
optimizer(torch.optim.Optimizer)
lr_schedule(torch.optim.lr_scheduler._LRScheduler)
max_grad_norm(float)
input_key(str)
(default: features)
name(str)
(default: Discriminator)
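A hedged construction example; the network shape, state maps, learning rate, and weights below are illustrative assumptions, not values prescribed by the library:

```python
import torch
from torch import Tensor, nn

from emote.algorithms.amp import DiscriminatorLoss

OBS_DIM = 32  # illustrative feature width

# Small MLP discriminator (architecture is an assumption). The 2 * OBS_DIM
# input assumes the discriminator scores concatenated (obs, next_obs) pairs,
# as in the AMP paper.
discriminator = nn.Sequential(
    nn.Linear(2 * OBS_DIM, 256),
    nn.ReLU(),
    nn.Linear(256, 1),
)

def identity_map(state: Tensor) -> Tensor:
    # Placeholder state map; real setups typically select or transform the
    # observation features the discriminator should see.
    return state

optimizer = torch.optim.Adam(discriminator.parameters(), lr=1e-4)

loss_cb = DiscriminatorLoss(
    discriminator=discriminator,
    imitation_state_map_fn=identity_map,
    policy_state_map_fn=identity_map,
    grad_loss_weight=1.0,
    optimizer=optimizer,
    lr_schedule=torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1.0),
    max_grad_norm=10.0,
    input_key="features",
    name="Discriminator",
)
```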
def loss(self, imitation_batch, policy_batch) -> Tensor
Computes the discriminator loss.
Arguments:
imitation_batch(dict)
: a batch of data from the reference animation. The discriminator is trained to classify data from this batch as positive samples.
policy_batch(dict)
: a batch of data from the RL buffer. The discriminator is trained to classify data from this batch as negative samples.
Returns:
- loss (Tensor): the loss tensor
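For orientation, the AMP paper (Peng et al., 2021) trains the discriminator with a least-squares objective plus a gradient penalty on the reference data. A sketch of that objective, under the assumption that this class follows the paper's formulation:

```python
from torch import Tensor, nn

from emote.algorithms.amp import gradient_loss_function

def amp_discriminator_loss_sketch(
    discriminator: nn.Module,
    imitation_obs: Tensor,  # needs requires_grad=True for the penalty term
    policy_obs: Tensor,
    grad_loss_weight: float,
) -> Tensor:
    pos_out = discriminator(imitation_obs)  # reference animation -> target +1
    neg_out = discriminator(policy_obs)     # RL rollouts         -> target -1
    lsgan = ((pos_out - 1.0) ** 2).mean() + ((neg_out + 1.0) ** 2).mean()
    # Gradient penalty on the reference samples, weighted as in __init__.
    return lsgan + grad_loss_weight * gradient_loss_function(pos_out, imitation_obs)
```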
class AMPReward(LoggingMixin, Callback):
Adversarial reward augmentation with AMP (Adversarial Motion Priors).
Methods
def __init__(
self,
discriminator,
state_map_fn,
style_reward_weight,
rollout_length,
observation_key,
data_group
) -> None
Arguments:
discriminator(nn.Module)
state_map_fn(Callable[[Tensor], Tensor])
style_reward_weight(float)
rollout_length(int)
observation_key(str)
data_group(str)
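A construction sketch reusing the discriminator and state map from the DiscriminatorLoss example above; all values are illustrative assumptions:

```python
from emote.algorithms.amp import AMPReward

amp_reward = AMPReward(
    discriminator=discriminator,
    state_map_fn=identity_map,
    style_reward_weight=0.5,  # blend factor for the style reward
    rollout_length=1,
    observation_key="features",
    data_group="default",
)
```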
def begin_batch(self, observation, next_observation, rewards) -> None
Updates the task reward by adding the weighted AMP style reward.
Arguments:
observation(dict[str, Tensor])
: the current observation
next_observation(dict[str, Tensor])
: the next observation
rewards(Tensor)
: the task reward
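The AMP paper derives the style reward from the discriminator score d as r_style = max(0, 1 - 0.25 (d - 1)^2). Assuming this callback follows that form, the reward update is roughly:

```python
from typing import Callable

import torch
from torch import Tensor, nn

def amp_reward_update_sketch(
    discriminator: nn.Module,
    state_map_fn: Callable[[Tensor], Tensor],
    obs: Tensor,        # current observation features
    next_obs: Tensor,   # next observation features
    rewards: Tensor,    # task reward
    style_reward_weight: float,
) -> Tensor:
    # Score the (s, s') transition; the concatenation order is an assumption.
    d = discriminator(
        torch.cat([state_map_fn(obs), state_map_fn(next_obs)], dim=-1)
    )
    style_reward = torch.clamp(1.0 - 0.25 * (d - 1.0) ** 2, min=0.0)
    # Add the weighted style reward to the task reward.
    return rewards + style_reward_weight * style_reward
```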