module emote.nn.curl

Functions

def soft_update_from_to(source_params, target_params, tau) -> None
def rand_uniform(minval, maxval, shape) -> None

Classes

class ImageAugmentor:

Methods

def __init__(
    self,
    device,
    use_fast_augment,
    use_noise_aug,
    use_per_image_mask_size,
    min_mask_relative_size,
    max_mask_relative_size
) -> None
def __call__(self, image) -> None

class CurlLoss(LossCallback):

Contrastive Unsupervised Representations for Reinforcement Learning (CURL).

paper: https://arxiv.org/abs/2004.04136

Methods

def __init__(
    self,
    encoder_model,
    target_encoder_model,
    device,
    learning_rate,
    learning_rate_start_frac,
    learning_rate_end_frac,
    learning_rate_steps,
    max_grad_norm,
    data_group,
    desired_zdim,
    tau,
    use_noise_aug,
    temperature,
    use_temperature_variant,
    use_per_image_mask_size,
    use_fast_augment,
    use_projection_layer,
    augment_anchor_and_pos,
    log_images
) -> None

Arguments:

  • encoder_model (Conv2dEncoder): The image encoder that will be trained using CURL.
  • target_encoder_model (Conv2dEncoder): The target image encoder.
  • device (torch.device): The device to use for computation.
  • learning_rate (float): The base learning rate.
  • learning_rate_start_frac (float): The start fraction for the LR schedule. (default: 1.0)
  • learning_rate_end_frac (float): The end fraction for the LR schedule. (default: 1.0)
  • learning_rate_steps (int): The number of steps to decay the LR over. (default: 1)
  • max_grad_norm (float): The maximum gradient norm, used for gradient clipping. (default: 1.0)
  • data_group (str): The data group used by this loss. (default: "default")
  • desired_zdim (int): The size of the latent. If the projection layer is not used, this defaults to the encoder output size. (default: 128)
  • tau (float): The tau value used for updating the target encoder. (default: 0.005)
  • use_noise_aug (bool): Add noise during image augmentation.
  • temperature (float): The temperature used in the temperature-scaled cross-entropy calculation. (default: 0.1)
  • use_temperature_variant (bool): Use the normalised temperature-scaled cross-entropy variant. (default: True)
  • use_per_image_mask_size (bool): Use a different mask size for every image in the batch.
  • use_fast_augment (bool): Use a GPU-compatible image augmentation that applies a fixed cutout position and size per batch.
  • use_projection_layer (bool): Add an additional dense layer to the encoder that projects to the desired_zdim size. (default: True)
  • augment_anchor_and_pos (bool): Augment both the anchor and positive images. (default: True)
  • log_images (bool): Log the augmented images. (default: True)
def parameters(self) -> None
def backward(self, observation) -> None
def end_batch(self) -> None