module emote.nn.curl
Functions
def soft_update_from_to(source_params, target_params, tau) -> None
def rand_uniform(minval, maxval, shape) -> None
Classes
class ImageAugmentor:
Methods
def __init__(
self,
device,
use_fast_augment,
use_noise_aug,
use_per_image_mask_size,
min_mask_relative_size,
max_mask_relative_size
) -> None
def __call__(self, image) -> None
class CurlLoss(LossCallback):
Contrastive Unsupervised Representations for Reinforcement Learning (CURL).
paper: https://arxiv.org/abs/2004.04136
Methods
def __init__(
self,
encoder_model,
target_encoder_model,
device,
learning_rate,
learning_rate_start_frac,
learning_rate_end_frac,
learning_rate_steps,
max_grad_norm,
data_group,
desired_zdim,
tau,
use_noise_aug,
temperature,
use_temperature_variant,
use_per_image_mask_size,
use_fast_augment,
use_projection_layer,
augment_anchor_and_pos,
log_images
) -> None
Arguments:
encoder_model (Conv2dEncoder): The image encoder that will be trained using CURL.
target_encoder_model (Conv2dEncoder): The target image encoder.
device (torch.DeviceObjType): The torch.device to use for computation.
learning_rate (float): The learning rate.
learning_rate_start_frac (float): The start fraction for the LR schedule. (default: 1.0)
learning_rate_end_frac (float): The end fraction for the LR schedule. (default: 1.0)
learning_rate_steps (float): The number of steps to decay the LR over; intended as an integer count despite the float annotation. (default: 1)
max_grad_norm (float): The maximum gradient norm, used for gradient clipping. (default: 1.0)
data_group (str): The data group to use. (default: "default")
desired_zdim (int): The size of the latent. If the projection layer is not used, this will default to the encoder output size. (default: 128)
tau (float): The tau value used for updating the target encoder. (default: 0.005)
use_noise_aug (bool): Add noise during image augmentation.
temperature (float): The temperature used in the temperature-scaled cross-entropy calculation. (default: 0.1)
use_temperature_variant (bool): Use the normalised temperature-scaled cross-entropy variant. (default: True)
use_per_image_mask_size (bool): Use a different mask size for each image in the batch.
use_fast_augment (bool): Use a GPU-compatible image augmentation that applies a fixed cutout position and size per batch.
use_projection_layer (bool): Add an additional dense layer to the encoder that projects to zdim size. (default: True)
augment_anchor_and_pos (bool): Augment both the anchor and positive images. (default: True)
log_images (bool): Log the augmented images. (default: True)
def parameters(self) -> None
def backward(self, observation) -> None
def end_batch(self) -> None