module emote.models.callbacks

Classes

class ModelLoss(LossCallback):

Trains a dynamic model by minimizing the model loss.

Methods

def __init__(
    self,
    *,
    model,
    opt,
    lr_schedule,
    max_grad_norm,
    name,
    data_group,
    input_key
) -> None

Arguments:

  • model(DynamicModel): A dynamic model.
  • opt(optim.Optimizer): An optimizer.
  • lr_schedule(Optional[optim.lr_scheduler._LRScheduler]): A learning rate scheduler.
  • max_grad_norm(float): Clip the norm of the gradient during backprop using this value.
  • name(str): The name of the module, used e.g. when logging.
  • data_group(str): The name of the data group from which this Loss takes its data.
  • input_key(str)
def loss(self, observation, next_observation, actions, rewards) -> None
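The training step this callback performs can be sketched in miniature: fit a dynamics model to observed transitions by minimizing mean-squared prediction error, with gradient-norm clipping as controlled by max_grad_norm. The 1-D linear model, learning rate, and data below are illustrative assumptions, not the emote implementation:

```python
# Minimal sketch: train a 1-D dynamics model x' = a * x + b * u by
# minimizing mean-squared prediction error with gradient descent.
# All names and numbers here are illustrative, not from emote.

def model_loss(a, b, transitions):
    """Mean-squared error of predicted next states."""
    return sum((a * x + b * u - x_next) ** 2
               for x, u, x_next in transitions) / len(transitions)

def train_step(a, b, transitions, lr=0.05, max_grad=1.0):
    # Analytic gradients of the MSE with respect to a and b.
    n = len(transitions)
    ga = sum(2 * (a * x + b * u - x_next) * x for x, u, x_next in transitions) / n
    gb = sum(2 * (a * x + b * u - x_next) * u for x, u, x_next in transitions) / n
    # Clip the gradient norm, mirroring the max_grad_norm argument.
    norm = (ga ** 2 + gb ** 2) ** 0.5
    if norm > max_grad:
        ga, gb = ga * max_grad / norm, gb * max_grad / norm
    return a - lr * ga, b - lr * gb

# Noiseless data generated from true dynamics x' = 0.9 * x + 0.5 * u.
data = [(x, u, 0.9 * x + 0.5 * u) for x, u in [(1.0, 0.0), (0.5, 1.0), (-1.0, 0.5)]]
a, b = 0.0, 0.0
for _ in range(500):
    a, b = train_step(a, b, data)
# a and b converge toward the true coefficients 0.9 and 0.5.
```

The same structure holds for the real callback: each BP step consumes one batch from data_group, computes the model loss, and applies one clipped optimizer step.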

class LossProgressCheck(LoggingMixin, BatchCallback):

Methods

def __init__(self, model, num_bp, data_group, input_key) -> None
def begin_batch(self) -> None
def end_cycle(self) -> None
def get_batch(self, observation, next_observation, actions, rewards) -> None

class BatchSampler(BatchCallback):

The BatchSampler class provides batches of data for the RL training callbacks. On every BP step, it samples one batch from either the Gym buffer or the model buffer according to a Bernoulli probability distribution, and outputs the batch to a separate data group used by the other RL training callbacks.

Arguments:

  • dataloader(MemoryLoader): the dataloader to load data from the model buffer.
  • prob_scheduler(BPStepScheduler): the scheduler to update the probability of data samples coming from the model buffer vs. the Gym buffer.
  • data_group(str): the data_group to receive data.
  • rl_data_group(str): the data_group to upload data for RL training.
  • generator(Optional[torch.Generator]): an optional random generator.

Methods

def __init__(
    self,
    dataloader,
    prob_scheduler,
    data_group,
    rl_data_group,
    generator
) -> None

Arguments:

  • dataloader(MemoryLoader)
  • prob_scheduler(BPStepScheduler)
  • data_group(str) (default: default)
  • rl_data_group(str) (default: rl_buffer)
  • generator(Optional[torch.Generator])
def begin_batch(self) -> None

Generates a batch of data, either by sampling from the model buffer or by cloning the input batch.

Returns:

  • the batch of data
def sample_model_batch(self) -> None

Samples a batch of data from the model buffer.

Returns:

  • batch samples
def use_model_batch(self) -> None

Decides whether the batch should come from the model-generated buffer.

Returns:

  • True if model samples should be used, False otherwise.
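The per-step decision described above can be sketched with the standard library: draw from a Bernoulli distribution whose probability follows a schedule over BP steps, then pick the buffer accordingly. The linear schedule, its bounds, and the toy buffers are placeholder assumptions, not the emote API:

```python
import random

# Sketch of BatchSampler's core decision: on each BP step, draw from a
# Bernoulli distribution whose probability is set by a schedule, then take
# the batch from the model buffer or the Gym buffer accordingly.

def prob_schedule(bp_step, start=0.0, end=0.9, warmup=1000):
    """Linearly anneal the model-sample probability over `warmup` BP steps."""
    return start + (end - start) * min(bp_step / warmup, 1.0)

def sample_batch(bp_step, model_buffer, gym_buffer, rng):
    use_model = rng.random() < prob_schedule(bp_step)  # Bernoulli draw
    source = model_buffer if use_model else gym_buffer
    return rng.choice(source), use_model

rng = random.Random(0)
model_buffer = ["model_batch_%d" % i for i in range(4)]
gym_buffer = ["gym_batch_%d" % i for i in range(4)]

# Early in training, all batches come from the Gym buffer; once the
# schedule has warmed up, roughly 90% come from the model buffer.
early = sum(sample_batch(0, model_buffer, gym_buffer, rng)[1] for _ in range(1000))
late = sum(sample_batch(5000, model_buffer, gym_buffer, rng)[1] for _ in range(1000))
```

In the real callback the probability is supplied by prob_scheduler and the draw uses the optional torch.Generator; the structure of the decision is the same.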

class ModelBasedCollector(LoggingMixin, BatchCallback):

ModelBasedCollector class is used to sample rollouts from the trained dynamic model. The rollouts are stored in a replay buffer memory.

Arguments:

  • model_env: The Gym-like dynamic model.
  • agent: The policy used to sample actions.
  • memory: The memory to store the new synthetic samples.
  • rollout_scheduler: A scheduler used to set the rollout length when unrolling the dynamic model.
  • num_bp_to_retain_buffer: The number of BP steps to retain samples for. Samples are overwritten (first in, first out) for BP steps beyond this.
  • data_group: The data group to receive data from. This must be set to get real (Gym) samples.

Methods

def __init__(
    self,
    model_env,
    agent,
    memory,
    rollout_scheduler,
    num_bp_to_retain_buffer,
    data_group,
    input_key
) -> None

Arguments:

  • model_env(ModelEnv)
  • agent(AgentProxy)
  • memory(MemoryProxy)
  • rollout_scheduler(BPStepScheduler)
  • num_bp_to_retain_buffer (default: 1000000)
  • data_group(str) (default: default)
  • input_key(str) (default: obs)
def begin_batch(self) -> None
def get_batch(self, observation) -> None
def collect_sample(self) -> None

Collect a single rollout.

def update_rollout_size(self) -> None
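The interaction of a growing rollout length with bounded FIFO retention can be sketched as follows. The linear schedule bounds, deque-based memory, and capacity are placeholder assumptions standing in for rollout_scheduler, the MemoryProxy, and num_bp_to_retain_buffer:

```python
from collections import deque

# Sketch: grow the model-rollout length over BP steps and keep only the
# most recent synthetic samples in a bounded FIFO memory.

def rollout_length(bp_step, start_step=0, end_step=10000, start_len=1, end_len=20):
    """Linearly interpolate rollout length between two BP-step milestones."""
    if bp_step <= start_step:
        return start_len
    if bp_step >= end_step:
        return end_len
    frac = (bp_step - start_step) / (end_step - start_step)
    return round(start_len + frac * (end_len - start_len))

# FIFO memory: old synthetic samples are overwritten once capacity is
# reached, mirroring num_bp_to_retain_buffer.
capacity = 5
memory = deque(maxlen=capacity)
for bp_step in range(8):
    memory.append(("rollout", bp_step, rollout_length(bp_step * 2000)))

# Only the five most recent rollouts remain; their lengths reflect the
# schedule having saturated at end_len.
lengths = [entry[2] for entry in memory]
```

Short rollouts early in training limit compounding model error; longer rollouts later exploit the improved model, while the FIFO memory keeps the synthetic buffer from being dominated by stale samples.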