package emote

Emote

To do reinforcement learning we need two things: a learning protocol that specifies which losses, network architectures, optimizers, and so forth to use; and a data collector that interacts with the world and stores the resulting experiences in a way that makes them accessible to the learning protocol.

In Emote, data collection is done by Collectors, the protocol for the learning algorithm is built up of Callbacks, and they are tied together by a Trainer.
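As a rough sketch of how the pieces fit together, assuming Callback and Trainer can be imported directly from the emote package as this page suggests; the collector and memory setup are elided, and the toy dataloader below is purely illustrative:

from emote import Callback, Trainer

def toy_dataloader():
    # Stand-in for a real memory-backed dataloader: an iterable that yields
    # sampled batches as dictionaries.
    while True:
        yield {"batch_size": 32}

# The learning protocol is expressed as a list of callbacks (losses, logging,
# checkpointing, ...). A single do-nothing callback keeps the sketch small.
callbacks = [Callback(cycle=None)]

trainer = Trainer(callbacks, toy_dataloader())
# trainer.train(shutdown_signal=...)  # would run the training loop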

Classes

class Callback:

The principal building block of emote. Callbacks are modular pieces of code that together make up the training loop. They contain hooks that are executed at different points during training. These can consume values produced by other callbacks and generate values of their own for others to consume, which keeps the flow of data between different parts of the code loosely coupled. The most important examples of callbacks in emote are the Losses.

The concept has been borrowed from Keras and FastAI.
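A minimal sketch of a custom callback, assuming Callback can be imported directly from the emote package; the BatchTimer name and what it measures are made up for illustration:

import time

from emote import Callback

class BatchTimer(Callback):
    """Hypothetical callback that reports how long each batch takes."""

    def __init__(self):
        # cycle=None: this callback does not make use of the cycle hooks.
        super().__init__(cycle=None)
        self._start = 0.0

    def begin_batch(self):
        # Runs right after a batch has been sampled.
        self._start = time.perf_counter()

    def end_batch(self):
        # Runs once the backward pass for the batch has completed.
        elapsed = time.perf_counter() - self._start
        print(f"batch took {elapsed:.4f}s")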

Methods

def __init__(self, cycle) -> None

Arguments:

  • cycle(int | None)
def restore_state(self) -> None

Called before training starts to allow loader modules to import state.

At this point, no assumptions can be made about other modules' state.

def begin_training(self) -> None

Called when training starts, both from scratch and when restoring from a checkpoint.

def begin_cycle(self) -> None

Called at the start of each cycle.

def begin_batch(self) -> None

Called at the start of each batch, immediately after data has been sampled.

def backward(self) -> None

The main batch processing should happen here.

def end_batch(self) -> None

Called when the backward pass has been completed.

def end_cycle(self) -> None

Called when a callback's cycle is completed.

def end_training(self) -> None

Called right before shutdown, if possible.
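Drawing only on the descriptions above, the hooks fire in roughly this order over the lifetime of a single callback (an illustrative outline, not the actual Trainer implementation):

# restore_state()      before training starts; other callbacks' state unknown
# begin_training()     training starts (from scratch or from a checkpoint)
# -- repeated while training runs --
#   begin_cycle()      a new cycle for this callback starts
#     begin_batch()    a batch has just been sampled
#     backward()       main batch processing
#     end_batch()      the backward pass has completed
#   end_cycle()        this callback's cycle is complete
# -- repeat ends --
# end_training()       right before shutdown, if possible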

def state_dict(self) -> Dict[str, Any]

Called by checkpointers primarily to capture state for on-disk saving.

def load_state_dict(
    self,
    state_dict,
    load_network,
    load_optimizer,
    load_hparams
) -> None

Called primarily from checkpoint loaders during the restore_state phase.

Arguments:

  • state_dict(Dict[str, Any])
  • load_network(bool) (default: True)
  • load_optimizer(bool) (default: True)
  • load_hparams(bool) (default: True)
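A hedged sketch of how a callback might take part in checkpointing through these two hooks; the StepCounter name and the stored field are made up for illustration, and a real callback may also want to delegate to the base-class implementations:

from typing import Any, Dict

from emote import Callback

class StepCounter(Callback):
    """Hypothetical callback that counts batches and survives checkpointing."""

    def __init__(self):
        super().__init__(cycle=None)
        self.steps = 0

    def end_batch(self):
        self.steps += 1

    def state_dict(self) -> Dict[str, Any]:
        # Captured by checkpointers for on-disk saving.
        return {"steps": self.steps}

    def load_state_dict(
        self,
        state_dict,
        load_network=True,
        load_optimizer=True,
        load_hparams=True,
    ) -> None:
        # Called by checkpoint loaders during the restore_state phase.
        self.steps = state_dict.get("steps", 0)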

class Trainer:

The Trainer class manages the main training loop in emote. It does so by invoking its callbacks at a number of points during the loop.

Fields

  • state: StateDict

  • callbacks: List[Callback]

  • dataloader: Iterable

  • cycle_length: int

Methods

def __init__(self, callbacks, dataloader, batch_size_key) -> None

Arguments:

  • callbacks(List[Callback])
  • dataloader(Iterable)
  • batch_size_key(str) (default: batch_size)
def train(self, shutdown_signal) -> None

The main training loop. This method waits until the memory is full enough to start sampling, and then runs cycles of backward passes on batches sampled from the memory.

Arguments:

  • shutdown_signal(Callable): A function that returns True if training should end, False otherwise.
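For illustration, the shutdown signal can be any zero-argument callable; a hypothetical time-based one (the trainer variable stands for a constructed Trainer) might look like this:

import time

def time_limited(seconds):
    # Returns a shutdown signal that fires once roughly `seconds` have passed.
    deadline = time.monotonic() + seconds
    return lambda: time.monotonic() >= deadline

# trainer.train(time_limited(60 * 60))  # stop training after about an hour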