module emote.algorithms.action_symmetry

Classes

class ActionSymmetryDiscriminatorLoss(LossCallback):

Loss callback that trains the discriminator used for adversarial training. The discriminator learns to tell right-side from left-side action values.

Methods

def __init__(
    self,
    discriminator,
    right_action_map_fn,
    left_action_map_fn,
    grad_loss_weight,
    optimizer,
    lr_schedule,
    max_grad_norm,
    data_group,
    name
) -> None

Arguments:

  • discriminator(Discriminator)
  • right_action_map_fn(Callable[[Tensor], Tensor])
  • left_action_map_fn(Callable[[Tensor], Tensor])
  • grad_loss_weight(float)
  • optimizer(torch.optim.Optimizer)
  • lr_schedule(torch.optim.lr_scheduler._LRScheduler)
  • max_grad_norm(float)
  • data_group(str)
  • name(str)
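
As an illustration of the two `Callable[[Tensor], Tensor]` arguments, the sketch below shows one way such map functions might be written. The column layout (`RIGHT_IDX`, `LEFT_IDX`) is a hypothetical assumption, not something this module defines; a real action space will have its own indices.

```python
import torch
from torch import Tensor

# Hypothetical layout (assumption): columns 0-2 drive the right side,
# columns 3-5 drive the left side of the agent.
RIGHT_IDX = [0, 1, 2]
LEFT_IDX = [3, 4, 5]


def right_action_map_fn(actions: Tensor) -> Tensor:
    """Select the right-side action dimensions from a batch of actions."""
    return actions[:, RIGHT_IDX]


def left_action_map_fn(actions: Tensor) -> Tensor:
    """Select the left-side action dimensions from a batch of actions."""
    return actions[:, LEFT_IDX]


# A batch of 2 actions with 6 dimensions each.
actions = torch.arange(12, dtype=torch.float32).reshape(2, 6)
right = right_action_map_fn(actions)
left = left_action_map_fn(actions)
```

Both functions return tensors of the same shape here, so a single discriminator can score either side.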

def loss(self, actions) -> Tensor

Computes the loss used to train a discriminator to distinguish right-side from left-side action values.

Arguments:

  • actions(Tensor): batch of actions
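
To make the idea of classifying right-side against left-side actions concrete, here is a minimal sketch of a binary classification loss. It is an assumption for illustration only, not the implementation in this module: the helper `symmetry_discriminator_loss`, the binary cross-entropy objective, and the label convention (right = 1, left = 0) are all hypothetical, and any term weighted by `grad_loss_weight` is left out because its form is not described here.

```python
import torch
import torch.nn.functional as F
from torch import Tensor, nn


def symmetry_discriminator_loss(
    discriminator: nn.Module, right_actions: Tensor, left_actions: Tensor
) -> Tensor:
    """Hypothetical loss: label right-side actions 1 and left-side actions 0."""
    right_logits = discriminator(right_actions)
    left_logits = discriminator(left_actions)
    right_loss = F.binary_cross_entropy_with_logits(
        right_logits, torch.ones_like(right_logits)
    )
    left_loss = F.binary_cross_entropy_with_logits(
        left_logits, torch.zeros_like(left_logits)
    )
    return right_loss + left_loss


torch.manual_seed(0)
discriminator = nn.Linear(3, 1)  # stand-in for a real discriminator network
actions = torch.randn(8, 6)      # batch of 8 actions, 6 dimensions each

# Assumed split: first three columns are right-side, last three are left-side.
loss = symmetry_discriminator_loss(discriminator, actions[:, :3], actions[:, 3:])
loss.backward()  # gradients flow into the discriminator parameters
```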

class ActionSymmetryAMPReward(LoggingMixin, Callback):

Callback that adds an adversarial reward, computed with AMP (Adversarial Motion Priors), to the task reward.

Methods

def __init__(
    self,
    discriminator,
    right_action_map_fn,
    left_action_map_fn,
    confusion_reward_weight,
    data_group
) -> None

Arguments:

  • discriminator(Discriminator)
  • right_action_map_fn(Callable[[Tensor], Tensor])
  • left_action_map_fn(Callable[[Tensor], Tensor])
  • confusion_reward_weight(float)
  • data_group(str)

def begin_batch(self, actions, rewards) -> None

Updates the reward by adding the weighted AMP reward.

Arguments:

  • actions(Tensor): batch of actions
  • rewards(Tensor): task reward
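
The reward update described for `begin_batch` can be sketched as a weighted sum. This is an illustrative assumption rather than the module's code: the helper `add_weighted_reward` is hypothetical, and the `amp_reward` values are placeholders, since how the reward is derived from the discriminator output is not specified here.

```python
import torch
from torch import Tensor


def add_weighted_reward(rewards: Tensor, amp_reward: Tensor, weight: float) -> Tensor:
    """Hypothetical helper: add a weighted adversarial reward to the task reward."""
    return rewards + weight * amp_reward


rewards = torch.tensor([[1.0], [0.5]])     # task reward, one row per sample
amp_reward = torch.tensor([[0.2], [0.4]])  # placeholder adversarial reward
updated = add_weighted_reward(rewards, amp_reward, weight=0.5)
```

In this sketch, `weight` plays the role of `confusion_reward_weight`: a value of 0 leaves the task reward unchanged, and larger values give the adversarial term more influence.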