package emote.callbacks

Classes

class Checkpointer(Callback):

Checkpointer writes out a checkpoint every n steps. Exactly what is written to the checkpoint is determined by the restorees supplied in the constructor.

Methods

def __init__(
    self,
    *restorees,
    run_root,
    checkpoint_interval,
    checkpoint_index,
    storage_subdirectory
) -> None

Arguments:

  • restorees(list[Restoree]): A list of restorees that should be saved.
  • run_root(str): The root path to where the run artifacts should be stored.
  • checkpoint_interval(int): Number of backprops between checkpoints.
  • checkpoint_index(int): The index to start numbering checkpoints from.
  • storage_subdirectory(str): The subdirectory where the checkpoints are stored.
def begin_training(self) -> None
def end_cycle(self, bp_step, bp_samples) -> None
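
A minimal usage sketch; the paths and values are illustrative, and policy stands in for any object implementing Restoree:

from emote.callbacks import Checkpointer

checkpointer = Checkpointer(
    policy,  # assumed to implement Restoree
    run_root="/tmp/my_run",
    checkpoint_interval=10_000,
    checkpoint_index=0,
    storage_subdirectory="checkpoints",
)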

class CheckpointLoader(Callback):

CheckpointLoader loads a checkpoint like the one created by Checkpointer.

This is intended for resuming training given a specific checkpoint index. It also enables you to load network weights, optimizer, or other callback hyper-params independently. If you want to do something more specific, like only restore a specific network (outside a callback), it is probably easier to just do it explicitly when the network is constructed.

Methods

def __init__(
    self,
    *restorees,
    run_root,
    checkpoint_index,
    load_weights,
    load_optimizers,
    load_hparams,
    storage_subdirectory
) -> None

Arguments:

  • restorees(list[Restoree]): A list of restorees that should be restored.
  • run_root(str): The root path to where the run artifacts should be stored.
  • checkpoint_index(int): Which checkpoint to load.
  • load_weights(bool): If True, the network weights are loaded.
  • load_optimizers(bool): If True, the optimizer state is loaded.
  • load_hparams(bool): If True, other callback hyper-params are loaded.
  • storage_subdirectory(str): The subdirectory where the checkpoints are stored.
def restore_state(self) -> None
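
For example, resuming from the checkpoint with index 2 might look like this (again a sketch; policy and the paths are assumptions):

from emote.callbacks import CheckpointLoader

loader = CheckpointLoader(
    policy,  # assumed to implement Restoree
    run_root="/tmp/my_run",
    checkpoint_index=2,
    load_weights=True,
    load_optimizers=True,
    load_hparams=True,
    storage_subdirectory="checkpoints",
)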

class BackPropStepsTerminator(Callback):

Terminates training after a given number of backprops.

Methods

def __init__(self, bp_steps) -> None

Arguments:

  • bp_steps(int): The total number of backprops that the trainer should run for.
def end_cycle(self) -> None
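
For example, to stop training after 100,000 backprops:

from emote.callbacks import BackPropStepsTerminator

terminator = BackPropStepsTerminator(bp_steps=100_000)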

class LoggingMixin:

A mixin that accepts logging calls. Logged data is saved on this object and written out by a Logger. The mixin therefore doesn't care how the data is logged; it only provides a standard interface for storing data to be handled by a Logger.

Methods

def __init__(self, *, default_window_length) -> None

Arguments:

  • default_window_length(int): The default window length used by log_windowed_scalar.
def log_scalar(self, key, value) -> None

Use log_scalar to periodically log scalar data.

Arguments:

  • key(str)
  • value(float | torch.Tensor)
def log_windowed_scalar(self, key, value) -> None

Log scalars using a moving-window average. By default this uses default_window_length from the constructor as the window length. It can also be overridden on a per-key basis using the format windowed[LENGTH]:foo/bar (see the sketch below). Note that the window length cannot be changed between invocations: whichever length is seen first for a key is permanent.

Arguments:

  • key(str)
  • value(float | torch.Tensor | Iterable[torch.Tensor | float])
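
A sketch of both scalar variants from inside a LoggingMixin subclass; the class and its record method are hypothetical:

from emote.callbacks import LoggingMixin

class MyLogger(LoggingMixin):
    def record(self, loss_value):
        # Plain scalar logging under the key "loss/total".
        self.log_scalar("loss/total", loss_value)
        # Moving-window average with a per-key window of 100 entries;
        # the first length seen for this key is permanent.
        self.log_windowed_scalar("windowed[100]:loss/smoothed", loss_value)

logger = MyLogger(default_window_length=250)
logger.record(0.5)
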
def log_image(self, key, value) -> None

Use log_image to periodically log image data.

Arguments:

  • key(str)
  • value(torch.Tensor)
def log_video(self, key, value) -> None

Use log_video to periodically log video data.

Arguments:

  • key(str)
  • value(Tuple[np.ndarray, int])
def log_histogram(self, key, value) -> None
def state_dict(self) -> dict
def load_state_dict(
    self,
    state_dict,
    load_network,
    load_optimizer,
    load_hparams
) -> None

class TensorboardLogger(Callback):

Logs data from the provided loggable callbacks to TensorBoard.

Methods

def __init__(self, loggables, writer, log_interval, log_by_samples) -> None

Arguments:

  • loggables(List[LoggingMixin]): The logging-enabled callbacks whose data should be written.
  • writer(SummaryWriter): The TensorBoard summary writer to log to.
  • log_interval(int): Number of backprops between log writes.
  • log_by_samples(bool): If True, data is also logged against the sample count.
def begin_training(self, bp_step, bp_samples) -> None
def end_cycle(self, bp_step, bp_samples) -> None
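
A construction sketch, assuming my_loss is a LoggingMixin-derived callback and the log directory is illustrative:

from torch.utils.tensorboard import SummaryWriter

from emote.callbacks import TensorboardLogger

writer = SummaryWriter(log_dir="/tmp/my_run/logs")
tb_logger = TensorboardLogger(
    loggables=[my_loss],
    writer=writer,
    log_interval=100,
    log_by_samples=False,
)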

class LossCallback(LoggingMixin, Callback):

Losses are callbacks that implement a loss function.

Methods

def __init__(
    self,
    lr_schedule,
    *,
    name,
    network,
    optimizer,
    max_grad_norm,
    data_group,
    log_per_param_weights,
    log_per_param_grads
) -> None

Arguments:

  • lr_schedule(Optional[optim.lr_scheduler._LRScheduler])
  • name(str)
  • network(Optional[nn.Module])
  • optimizer(Optional[optim.Optimizer])
  • max_grad_norm(float): The maximum gradient norm used for gradient clipping.
  • data_group(str)
  • log_per_param_weights(bool): If True, each parameter's weights are logged.
  • log_per_param_grads(bool): If True, each parameter's gradients are logged.
def backward(self) -> None
def log_per_param_weights_and_grads(self) -> None
def state_dict(self) -> dict
def load_state_dict(
    self,
    state_dict,
    load_weights,
    load_optimizers,
    load_hparams
) -> None
def loss(self) -> Tensor

The loss method needs to be overridden to implement a loss.

Returns:

  • A PyTorch tensor of shape (batch,).
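
A sketch of a concrete loss. It assumes the sampled batch supplies observation and target tensors to loss (the rendered signature above omits the per-batch arguments), and the class name and hyper-parameters are illustrative:

from torch import nn, optim

from emote.callbacks import LossCallback

class MSELoss(LossCallback):
    def __init__(self, network: nn.Module):
        super().__init__(
            lr_schedule=None,
            name="mse",
            network=network,
            optimizer=optim.Adam(network.parameters(), lr=1e-3),
            max_grad_norm=1.0,
            data_group="default",
            log_per_param_weights=False,
            log_per_param_grads=False,
        )
        self.network = network

    def loss(self, observation, target):
        prediction = self.network(observation)
        # One loss value per batch element, shape (batch,).
        return ((prediction - target) ** 2).mean(dim=-1)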