module emote.callbacks.checkpointing

Classes

class Restoree(Protocol):

Fields

  • name: str

Methods

def state_dict(self) -> dict[str, Any]
def load_state_dict(
    self,
    state_dict,
    load_network,
    load_optimizer,
    load_hparams
) -> None
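
A minimal sketch of an object that satisfies the Restoree protocol, written without importing emote. The local Restoree definition only mirrors the signatures listed above. TinyRestoree, its fields, the default values, and the reading of the three load_* arguments as boolean switches are assumptions made for illustration.

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class Restoree(Protocol):
    """Local stand-in that mirrors the documented protocol (not imported from emote)."""

    name: str

    def state_dict(self) -> dict[str, Any]: ...

    def load_state_dict(
        self,
        state_dict: dict[str, Any],
        load_network: bool = True,
        load_optimizer: bool = True,
        load_hparams: bool = True,
    ) -> None: ...


class TinyRestoree:
    """Hypothetical restoree holding a weight list and one hyper-parameter."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.weights = [0.0, 0.0]
        self.lr = 1e-3

    def state_dict(self) -> dict[str, Any]:
        # Return copies so later mutation does not alter a saved snapshot.
        return {"weights": list(self.weights), "lr": self.lr}

    def load_state_dict(
        self,
        state_dict: dict[str, Any],
        load_network: bool = True,
        load_optimizer: bool = True,
        load_hparams: bool = True,
    ) -> None:
        # Assumption: each flag gates one part of the saved state.
        if load_network:
            self.weights = list(state_dict["weights"])
        if load_hparams:
            self.lr = state_dict["lr"]
        # This toy object has no optimizer state, so load_optimizer is unused.


source = TinyRestoree("policy")
source.weights = [1.0, 2.0]
source.lr = 0.5

target = TinyRestoree("policy")
target.load_state_dict(
    source.state_dict(), load_network=True, load_optimizer=False, load_hparams=False
)
```

Here only the weights are copied into target; its lr keeps the value it was constructed with because load_hparams is False.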

class Checkpointer(Callback):

Checkpointer writes a checkpoint every checkpoint_interval backprop steps. The contents of each checkpoint are determined by the restorees passed to the constructor.

Methods

def __init__(
    self,
    *restorees,
    run_root,
    checkpoint_interval,
    checkpoint_index,
    storage_subdirectory
) -> None

Arguments:

  • restorees(list[Restoree]): A list of restorees that should be saved.
  • run_root(str): The root path where the run artifacts should be stored.
  • checkpoint_interval(int): Number of backprops between checkpoints.
  • checkpoint_index(int)
  • storage_subdirectory(str): The subdirectory where the checkpoints are stored.

def begin_training(self) -> None
def end_cycle(self, bp_step, bp_samples) -> None
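
The interval-driven behaviour described for Checkpointer can be sketched with a self-contained stand-in. SketchCheckpointer is not emote's implementation: the modulo check, the pickle format, the checkpoint_N.pkl file naming, the default argument values, and the Counter restoree are all assumptions chosen to keep the example runnable without emote.

```python
import os
import pickle
import tempfile


class Counter:
    """Hypothetical restoree that exposes a single counter value."""

    def __init__(self, name):
        self.name = name
        self.count = 0

    def state_dict(self):
        return {"count": self.count}


class SketchCheckpointer:
    """Stand-in with the documented method names; internals are assumed."""

    def __init__(self, *restorees, run_root, checkpoint_interval,
                 checkpoint_index=0, storage_subdirectory="checkpoints"):
        self._restorees = restorees
        self._folder = os.path.join(run_root, storage_subdirectory)
        self._interval = checkpoint_interval
        self._index = checkpoint_index

    def begin_training(self):
        # Make sure the storage directory exists before the first write.
        os.makedirs(self._folder, exist_ok=True)

    def end_cycle(self, bp_step, bp_samples):
        # Assumption: write only when bp_step is a multiple of the interval.
        if bp_step % self._interval != 0:
            return
        state = {r.name: r.state_dict() for r in self._restorees}
        path = os.path.join(self._folder, f"checkpoint_{self._index}.pkl")
        with open(path, "wb") as f:
            pickle.dump(state, f)
        self._index += 1


with tempfile.TemporaryDirectory() as run_root:
    checkpointer = SketchCheckpointer(
        Counter("counter"), run_root=run_root, checkpoint_interval=3
    )
    checkpointer.begin_training()
    for step in range(1, 7):
        checkpointer.end_cycle(bp_step=step, bp_samples=step * 32)
    written = sorted(os.listdir(os.path.join(run_root, "checkpoints")))
```

With an interval of 3 and six steps, the sketch writes two files, one at step 3 and one at step 6.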

class CheckpointLoader(Callback):

CheckpointLoader loads a checkpoint like the one created by Checkpointer.

This is intended for resuming training from a specific checkpoint index. It also lets you load network weights, optimizer state, or other callback hyper-params independently of one another. For anything more specific, such as restoring a single network outside a callback, it is probably easier to do it explicitly when the network is constructed.

Methods

def __init__(
    self,
    *restorees,
    run_root,
    checkpoint_index,
    load_weights,
    load_optimizers,
    load_hparams,
    storage_subdirectory
) -> None

Arguments:

  • restorees(list[Restoree]): A list of restorees that should be restored.
  • run_root(str): The root path where the run artifacts are stored.
  • checkpoint_index(int): Which checkpoint to load.
  • load_weights(bool): If True, it loads the network weights.
  • load_optimizers(bool): If True, it loads the optimizer state.
  • load_hparams(bool): If True, it loads other callback hyper-params.
  • storage_subdirectory(str): The subdirectory where the checkpoints are stored.

def restore_state(self) -> None
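
A rough sketch of how the three load flags might be forwarded to each restoree when a checkpoint is read back. It does not use emote: restore_from_file, the Weights class, the pickle format, and the name-keyed layout of the saved file are hypothetical, and the mapping of load_weights and load_optimizers onto the load_network and load_optimizer arguments of load_state_dict is an assumption based on the signatures above.

```python
import os
import pickle
import tempfile


class Weights:
    """Hypothetical restoree holding a list of weights."""

    def __init__(self, name):
        self.name = name
        self.weights = [0.0, 0.0]

    def state_dict(self):
        return {"weights": list(self.weights)}

    def load_state_dict(self, state_dict, load_network=True,
                        load_optimizer=True, load_hparams=True):
        # Only the network flag matters for this toy object.
        if load_network:
            self.weights = list(state_dict["weights"])


def restore_from_file(path, restorees, load_weights=True,
                      load_optimizers=True, load_hparams=True):
    """Read a pickled {name: state_dict} mapping and pass each entry to its restoree."""
    with open(path, "rb") as f:
        saved = pickle.load(f)
    for restoree in restorees:
        restoree.load_state_dict(
            saved[restoree.name], load_weights, load_optimizers, load_hparams
        )


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "checkpoint_0.pkl")
    with open(path, "wb") as f:
        pickle.dump({"policy": {"weights": [1.0, 2.0]}}, f)

    loaded = Weights("policy")
    restore_from_file(path, [loaded], load_weights=True,
                      load_optimizers=False, load_hparams=False)

    skipped = Weights("policy")
    restore_from_file(path, [skipped], load_weights=False)
```

In this sketch loaded picks up the saved weights, while skipped keeps its initial values because load_weights is False.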

class InvalidCheckpointLocation(ValueError):