module emote.algorithms.action_symmetry
Classes
class ActionSymmetryDiscriminatorLoss(LossCallback):
This loss is used to train a discriminator for adversarial training.
Methods
def __init__(
self,
discriminator,
right_action_map_fn,
left_action_map_fn,
grad_loss_weight,
optimizer,
lr_schedule,
max_grad_norm,
data_group,
name
) -> None
Arguments:
discriminator(Discriminator)
right_action_map_fn(Callable[[Tensor], Tensor])
left_action_map_fn(Callable[[Tensor], Tensor])
grad_loss_weight(float)
optimizer(torch.optim.Optimizer)
lr_schedule(torch.optim.lr_scheduler._LRScheduler)
max_grad_norm(float)
data_group(str)
name(str)
def loss(self, actions) -> Tensor
Computes the loss used to train the discriminator to distinguish right-side action values from left-side action values.
Arguments:
actions
class ActionSymmetryAMPReward(LoggingMixin, Callback):
Adversarial rewarding with AMP.
Methods
def __init__(
self,
discriminator,
right_action_map_fn,
left_action_map_fn,
confusion_reward_weight,
data_group
) -> None
Arguments:
discriminator(Discriminator)
right_action_map_fn(Callable[[Tensor], Tensor])
left_action_map_fn(Callable[[Tensor], Tensor])
confusion_reward_weight(float)
data_group(str)
def begin_batch(self, actions, rewards) -> None
Updates the reward by adding the weighted AMP reward.
Arguments:
actions(Tensor): batch of actions
rewards(Tensor): task reward