package emote
Emote
To do reinforcement learning we need two things: a learning protocol that specifies which losses, network architectures, optimizers, and so forth to use; and a data collector that interacts with the world and stores the resulting experiences in a way that makes them accessible to the learning protocol.
In Emote, data collection is done by Collectors, the protocol for the learning algorithm is built up of Callbacks, and they are tied together by a Trainer.
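In code, the wiring looks roughly like the sketch below (a minimal sketch, assuming the package exports Trainer directly; `protocol` and `batch_iter` are placeholders for a real list of callbacks and a collector-fed batch iterable):

from emote import Trainer  # import path assumed

# `protocol` is the list of callbacks (losses, loggers, ...) that make up the
# learning protocol; `batch_iter` is an iterable of sampled batches, typically
# backed by a memory that collectors write to. Both are placeholders here.
trainer = Trainer(callbacks=protocol, dataloader=batch_iter)
trainer.train(shutdown_signal=lambda: False)  # run until stopped externally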
Classes
class Callback:
The principal modular building block of emote. Callbacks are modular pieces of code that together build up the training loop. They contain hooks that are executed at different points during training. These can consume values from other callbacks, and generate their own for others to consume. This allows a very loosely coupled flow of data between different parts of the code. The most important examples of callbacks in emote are the Losses.
The concept has been borrowed from Keras and FastAI.
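As a concrete illustration, a minimal custom callback might look like the following sketch. The class name and counting logic are made up for this example, the import path is assumed, and the hooks it overrides are documented under Methods below.

from emote import Callback  # import path assumed

class BatchCounter(Callback):
    """Toy callback that counts batches and reports once per cycle."""

    def __init__(self):
        # cycle=1000: begin_cycle()/end_cycle() bracket spans of 1000 batches
        # (assuming cycles are measured in batches, as the Trainer's
        # cycle_length field suggests).
        super().__init__(cycle=1000)
        self.batches = 0

    def begin_training(self):
        self.batches = 0

    def backward(self):
        # A loss callback would do its per-batch processing here.
        self.batches += 1

    def end_cycle(self):
        print(f"{self.batches} batches processed so far")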
Methods
def __init__(self, cycle) -> None
Arguments:
cycle(int | None)
def restore_state(self) -> None
Called before training starts to allow loader modules to import state.
At this point, no assumptions can be made about the state of other modules.
def begin_training(self) -> None
Called when training starts, both from scratch and when restoring from a checkpoint.
def begin_cycle(self) -> None
Called at the start of each cycle.
def begin_batch(self) -> None
Called at the start of each batch, immediately after data has been sampled.
def backward(self) -> None
The main batch processing should happen here.
def end_batch(self) -> None
Called when the backward pass has been completed.
def end_cycle(self) -> None
Called when a callback's cycle is completed.
def end_training(self) -> None
Called right before shutdown, if possible.
def state_dict(self) -> Dict[str, Any]
Called by checkpointers primarily to capture state for on-disk saving.
def load_state_dict(
self,
state_dict,
load_network,
load_optimizer,
load_hparams
) -> None
Called from checkpoint-loaders, primarily during the restore_state phase.
Arguments:
state_dict(Dict[str, Any])
load_network(bool)
(default: True)
load_optimizer(bool)
(default: True)
load_hparams(bool)
(default: True)
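A sketch of how a callback might participate in checkpointing; the hyperparameter stored here and the import path are assumptions, and the base-class state is preserved by delegating to super():

from typing import Any, Dict

from emote import Callback  # import path assumed

class TemperatureSchedule(Callback):
    """Toy callback that carries one hyperparameter across checkpoints."""

    def __init__(self):
        super().__init__(cycle=None)
        self.temperature = 1.0  # illustrative hyperparameter

    def state_dict(self) -> Dict[str, Any]:
        # Keep whatever the base class captures and add our own value.
        state = super().state_dict()
        state["temperature"] = self.temperature
        return state

    def load_state_dict(
        self,
        state_dict: Dict[str, Any],
        load_network: bool = True,
        load_optimizer: bool = True,
        load_hparams: bool = True,
    ) -> None:
        super().load_state_dict(state_dict, load_network, load_optimizer, load_hparams)
        if load_hparams:
            # Only restore when hyperparameter loading is requested.
            self.temperature = state_dict.get("temperature", self.temperature)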
class Trainer:
The Trainer class manages the main training loop in emote. It does so by invoking the registered callbacks at a number of points during the loop.
Fields
- state: StateDict
- callbacks: List[Callback]
- dataloader: Iterable
- cycle_length: int
Methods
def __init__(self, callbacks, dataloader, batch_size_key) -> None
Arguments:
callbacks(List[Callback])
dataloader(Iterable)
batch_size_key(str)
(default: batch_size)
def train(self, shutdown_signal) -> None
The main training loop. This method will wait until the memory is full enough to start sampling, and then start running cycles of backprops on batches sampled from the memory.
Arguments:
shutdown_signal(Callable)
: A function that returns True if training should end, False otherwise.
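For example, a shutdown signal can be as simple as a wall-clock budget. This helper is an illustration, not part of emote, and `trainer` is assumed to have been constructed as shown for __init__ above.

import time

def time_limit(seconds: float):
    """Return a shutdown signal that becomes True after `seconds` of training."""
    deadline = time.monotonic() + seconds
    return lambda: time.monotonic() >= deadline

trainer.train(shutdown_signal=time_limit(8 * 60 * 60))  # stop after eight hours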