module emote.callback

Classes

class CallbackMeta(ABCMeta):

The CallbackMeta metaclass modifies the callbacks so that they accept data groups.
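
As a rough illustration of the general technique, the sketch below uses a metaclass to rewrite hook methods when a class is created, so that a hook can be called with a dict of named data groups and receives only the keys it declares as parameters. `SketchCallbackMeta`, `WRAPPED_HOOKS`, the `data_group` attribute and the `Echo` class are all hypothetical. They are not emote's implementation, and the exact meaning of "data groups" in emote is assumed here.

```python
import functools
import inspect
from abc import ABCMeta
from typing import Any, Callable, Dict

# Hook names this sketch rewrites; an assumption for illustration only.
WRAPPED_HOOKS = ("begin_batch", "backward", "end_batch")


class SketchCallbackMeta(ABCMeta):
    """Hypothetical metaclass: lets hooks be called with a dict of data groups."""

    def __init__(cls, name, bases, fields):
        super().__init__(name, bases, fields)
        for hook in WRAPPED_HOOKS:
            # Only wrap hooks defined on this class, so inherited ones are not wrapped twice.
            if hook in fields:
                setattr(cls, hook, SketchCallbackMeta._wrap(fields[hook]))

    @staticmethod
    def _wrap(func: Callable) -> Callable:
        # Names of the parameters the hook declares (excluding self).
        wanted = [p for p in inspect.signature(func).parameters if p != "self"]

        @functools.wraps(func)
        def wrapper(self, data: Dict[str, Dict[str, Any]]):
            # Select this callback's group via a hypothetical `data_group` attribute.
            group = data.get(self.data_group, {})
            # Forward only the declared keys; any other keys in the group are ignored.
            return func(self, **{k: group[k] for k in wanted if k in group})

        return wrapper


class Echo(metaclass=SketchCallbackMeta):
    """Hypothetical callback that simply returns what its hook received."""

    data_group = "default"

    def backward(self, observation, reward=0.0):
        return observation, reward
```

Calling `Echo().backward({"default": {"observation": [1, 2], "reward": 1.5, "extra": 0}})` returns `([1, 2], 1.5)`: the unused `extra` key is dropped, and a missing `reward` would fall back to its default.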

Methods

def __init__(self, cls, bases, fields) -> None

Arguments:

  • cls
  • bases
  • fields
def __call__(self) -> None
def extend(self, func) -> None
def keys_from_member(self) -> None

class Callback:

The principal modular building block of emote. Callbacks are modular pieces of code that together build up the training loop. They contain hooks that are executed at different points during training. These can consume values from other callbacks, and generate their own for others to consume. This allows a very loosely coupled flow of data between different parts of the code. The most important examples of callbacks in emote are the Losses.

The concept has been borrowed from Keras and FastAI.
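
A toy, self-contained sketch of the loose coupling described above: one callback produces a value, another consumes it, and neither refers to the other. `SketchCallback`, `run`, `SquaredError` and `LossLogger` are hypothetical stand-ins, not emote's API. In particular, passing a single shared `state` dict to every hook and merging the dicts the hooks return is an assumption made for illustration.

```python
from typing import Any, Dict, List


class SketchCallback:
    """Hypothetical base: every hook sees the shared state and may return new values."""

    def begin_batch(self, state: Dict[str, Any]) -> Dict[str, Any]:
        return {}

    def backward(self, state: Dict[str, Any]) -> Dict[str, Any]:
        return {}

    def end_batch(self, state: Dict[str, Any]) -> Dict[str, Any]:
        return {}


def run(callbacks: List[SketchCallback], batches: List[Dict[str, Any]]) -> Dict[str, Any]:
    state: Dict[str, Any] = {}
    for batch in batches:
        state.update(batch)  # stands in for freshly sampled data
        for hook in ("begin_batch", "backward", "end_batch"):
            for cb in callbacks:
                # Values returned by one callback become visible to all later hooks.
                state.update(getattr(cb, hook)(state))
    return state


class SquaredError(SketchCallback):
    """Produces a "loss" value for other callbacks to consume."""

    def backward(self, state: Dict[str, Any]) -> Dict[str, Any]:
        return {"loss": (state["prediction"] - state["target"]) ** 2}


class LossLogger(SketchCallback):
    """Consumes "loss" without knowing which callback produced it."""

    def __init__(self) -> None:
        self.history: List[float] = []

    def end_batch(self, state: Dict[str, Any]) -> Dict[str, Any]:
        self.history.append(state["loss"])
        return {}
```

Running `run([SquaredError(), logger], [{"prediction": 3.0, "target": 1.0}, {"prediction": 2.0, "target": 2.0}])` with `logger = LossLogger()` leaves `logger.history == [4.0, 0.0]`. Swapping `SquaredError` for any other callback that emits `"loss"` requires no change to `LossLogger`.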

Methods

def __init__(self, cycle) -> None

Arguments:

  • cycle(int | None)
def restore_state(self) -> None

Called before training starts to allow loader modules to import state.

At this point, no assumptions can be made about other modules' state.
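
A minimal sketch of a loader module acting in `restore_state`, assuming the checkpoint is already available as an in-memory dict keyed by callback name. The loader relies only on the checkpoint contents and never on the current state of the callbacks it restores. `Counter`, `CheckpointLoader` and the name-keyed checkpoint layout are hypothetical.

```python
from typing import Any, Dict, List


class Counter:
    """Hypothetical callback with a single piece of persistent state."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.steps = 0

    def state_dict(self) -> Dict[str, Any]:
        return {"steps": self.steps}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.steps = state_dict["steps"]


class CheckpointLoader:
    """Hypothetical loader that pushes saved state into other callbacks."""

    def __init__(self, checkpoint: Dict[str, Dict[str, Any]], targets: List[Counter]) -> None:
        self.checkpoint = checkpoint
        self.targets = targets

    def restore_state(self) -> None:
        # Use only the checkpoint contents; the targets' current state may not be meaningful yet.
        for target in self.targets:
            if target.name in self.checkpoint:
                target.load_state_dict(self.checkpoint[target.name])
```

With `a, b = Counter("a"), Counter("b")`, calling `CheckpointLoader({"a": {"steps": 7}}, [a, b]).restore_state()` sets `a.steps` to 7 and leaves `b.steps` at 0, since the checkpoint holds nothing for `b`.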

def begin_training(self) -> None

Called when training starts, both from scratch and when restoring from a checkpoint.

def begin_cycle(self) -> None

Called at the start of each cycle.

def begin_batch(self) -> None

Called at the start of each batch, immediately after data has been sampled.

def backward(self) -> None

The main batch processing should happen here.

def end_batch(self) -> None

Called when the backward pass has been completed.

def end_cycle(self) -> None

Called when a callback's cycle is completed.
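
The relationship between the `cycle` constructor argument and the cycle hooks is not spelled out here. The sketch below assumes that a cycle spans `cycle` batches and that a falsy `cycle` (such as `None`) disables the cycle hooks; both are assumptions, not documented emote behaviour. `CycleRecorder`, `run` and the `bp_step` argument are hypothetical.

```python
from typing import List, Optional


class CycleRecorder:
    """Hypothetical callback that records when its cycle hooks fire."""

    def __init__(self, cycle: Optional[int]) -> None:
        self.cycle = cycle
        self.events: List[str] = []

    def begin_cycle(self, bp_step: int) -> None:
        self.events.append(f"begin@{bp_step}")

    def end_cycle(self, bp_step: int) -> None:
        self.events.append(f"end@{bp_step}")


def run(callback: CycleRecorder, num_batches: int) -> None:
    """Assumed semantics: a cycle spans `cycle` batches; a falsy cycle disables the hooks."""
    for step in range(1, num_batches + 1):
        if callback.cycle and (step - 1) % callback.cycle == 0:
            callback.begin_cycle(step)
        # begin_batch / backward / end_batch would run here.
        if callback.cycle and step % callback.cycle == 0:
            callback.end_cycle(step)
```

Under these assumptions, a `CycleRecorder(2)` run for 5 batches records `["begin@1", "end@2", "begin@3", "end@4", "begin@5"]`. The last cycle is left open, and `CycleRecorder(None)` records nothing.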

def end_training(self) -> None

Called right before shutdown, if possible.
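
The hooks listed above can be read as one nested loop. The sketch below records the order in which a hypothetical trainer would invoke them. `HOOKS`, `Recorder` and `train` are illustrative names. The single shared cycle length is a simplification, since each emote callback takes its own `cycle` argument.

```python
from typing import Any, Callable, List, Sequence

HOOKS = (
    "restore_state", "begin_training", "begin_cycle", "begin_batch",
    "backward", "end_batch", "end_cycle", "end_training",
)


class Recorder:
    """Hypothetical callback whose hooks only record their own names."""

    def __init__(self) -> None:
        self.calls: List[str] = []

    def __getattr__(self, name: str) -> Callable[[], None]:
        # Only reached for attributes not found normally; serve the hook names.
        if name not in HOOKS:
            raise AttributeError(name)
        return lambda: self.calls.append(name)


def train(callbacks: Sequence[Any], num_cycles: int, batches_per_cycle: int) -> None:
    def dispatch(hook: str) -> None:
        for cb in callbacks:
            getattr(cb, hook)()

    dispatch("restore_state")
    dispatch("begin_training")
    for _ in range(num_cycles):
        dispatch("begin_cycle")
        for _ in range(batches_per_cycle):
            # Data would be sampled here, immediately before begin_batch.
            for hook in ("begin_batch", "backward", "end_batch"):
                dispatch(hook)
        dispatch("end_cycle")
    dispatch("end_training")
```

For one cycle of two batches, `Recorder().calls` ends up as `restore_state, begin_training, begin_cycle`, then `begin_batch, backward, end_batch` twice, then `end_cycle, end_training`.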

def state_dict(self) -> Dict[str, Any]

Called primarily by checkpointers to capture state for saving to disk.

def load_state_dict(
    self,
    state_dict,
    load_network,
    load_optimizer,
    load_hparams
) -> None

Called primarily by checkpoint loaders during the restore_state phase.

Arguments:

  • state_dict(Dict[str, Any])
  • load_network(bool) (default: True)
  • load_optimizer(bool) (default: True)
  • load_hparams(bool) (default: True)
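
A sketch of how a callback might honour the three `load_*` flags, assuming its state splits into network, optimizer and hyperparameter sections. The section keys `"network"`, `"optimizer"` and `"hparams"` and the `SketchLoss` class are hypothetical, and plain dicts stand in for real model and optimizer state. A typical use of the flags is to reuse trained weights while starting from a fresh optimizer.

```python
import copy
from typing import Any, Dict


class SketchLoss:
    """Hypothetical callback owning network weights, optimizer state and hparams."""

    def __init__(self) -> None:
        self.network: Dict[str, Any] = {"w": 0.0}
        self.optimizer: Dict[str, Any] = {"step": 0}
        self.hparams: Dict[str, Any] = {"lr": 1e-3}

    def state_dict(self) -> Dict[str, Any]:
        # Deep copy so that further training does not mutate a saved snapshot.
        return copy.deepcopy(
            {"network": self.network, "optimizer": self.optimizer, "hparams": self.hparams}
        )

    def load_state_dict(
        self,
        state_dict: Dict[str, Any],
        load_network: bool = True,
        load_optimizer: bool = True,
        load_hparams: bool = True,
    ) -> None:
        # Each flag gates one section of the saved state.
        if load_network:
            self.network = copy.deepcopy(state_dict["network"])
        if load_optimizer:
            self.optimizer = copy.deepcopy(state_dict["optimizer"])
        if load_hparams:
            self.hparams = copy.deepcopy(state_dict["hparams"])
```

Loading a snapshot with `load_optimizer=False` copies the saved weights and hyperparameters but keeps the new instance's optimizer state at `{"step": 0}`.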

class BatchCallback(Callback):

Methods

def __init__(self, cycle) -> None
def get_batch(self) -> None