module emote.trainer

Classes

class StateDict(dict, MutableMapping[str, Any]):

A thin subclass of dict that can be the target of a weakref. Plain dict instances do not support weak references, but subclasses of dict do.

Methods

def get_handle(self) -> WeakReference['StateDict']

Retrieve a weak handle to this state dict. The handle makes no promise of ownership or lifetime: holding it does not keep the dict alive.
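A minimal sketch of why the subclass matters, assuming `get_handle` simply wraps the standard `weakref.ref`. This is an illustrative re-implementation, not emote's source, and the helper `plain_dict_supports_weakref` is hypothetical:

```python
import weakref
from typing import Any, MutableMapping


class StateDict(dict, MutableMapping[str, Any]):
    """Sketch: subclassing dict gives instances a __weakref__ slot."""

    def get_handle(self) -> "weakref.ReferenceType[StateDict]":
        # Assumption: the handle is a plain weakref. It does not keep the
        # dict alive; calling it returns None once the dict is collected.
        return weakref.ref(self)


def plain_dict_supports_weakref() -> bool:
    """Hypothetical helper: check whether a built-in dict can be weakly referenced."""
    try:
        weakref.ref({})
        return True
    except TypeError:  # "cannot create weak reference to 'dict' object"
        return False


state = StateDict(step=0)
handle = state.get_handle()
assert handle() is state  # the handle resolves while the dict is alive
```

In CPython, deleting the last strong reference to `state` frees it immediately, after which `handle()` returns `None`.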

class TrainingShutdownException(Exception):

class Trainer:

The Trainer class manages the main training loop in emote. It does so by invoking its callbacks at several points in the loop.
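The description above does not spell out how the callbacks are dispatched. The following is a minimal sketch of one common pattern, assuming hypothetical hook names (`begin_training`, `end_cycle`) and the convention that a hook may return a dict that is merged into shared state; emote's actual `Callback` interface may use different hooks and signatures:

```python
from typing import Any, Dict, List, Optional


class Callback:
    """Hypothetical callback base: every hook is a no-op unless overridden."""

    def begin_training(self, **state: Any) -> Optional[Dict[str, Any]]:
        return None

    def end_cycle(self, **state: Any) -> Optional[Dict[str, Any]]:
        return None


def invoke(callbacks: List[Callback], hook: str, state: Dict[str, Any]) -> None:
    """Call `hook` on each callback in order, merging any returned dict into state."""
    for cb in callbacks:
        update = getattr(cb, hook)(**state)
        if update:
            state.update(update)


class CycleCounter(Callback):
    """Example callback that tracks how many cycles have finished."""

    def begin_training(self, **state: Any) -> Dict[str, Any]:
        return {"cycles": 0}

    def end_cycle(self, cycles: int = 0, **state: Any) -> Dict[str, Any]:
        return {"cycles": cycles + 1}


state: Dict[str, Any] = {}
callbacks: List[Callback] = [CycleCounter()]
invoke(callbacks, "begin_training", state)
for _ in range(3):
    invoke(callbacks, "end_cycle", state)
print(state)  # {'cycles': 3}
```

Passing state as keyword arguments lets each hook declare only the entries it cares about, while the merge step lets later callbacks see what earlier ones produced.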

Fields

  • state: StateDict

  • callbacks: List[Callback]

  • dataloader: Iterable

  • cycle_length: int

Methods

def __init__(self, callbacks, dataloader, batch_size_key) -> None

Arguments:

  • callbacks(List[Callback])
  • dataloader(Iterable)
  • batch_size_key(str) (default: "batch_size")

def train(self, shutdown_signal) -> None

The main training loop. This method waits until the memory holds enough data to start sampling, then runs cycles of backpropagation on batches sampled from the memory.

Arguments:

  • shutdown_signal(Callable): A function that returns True if training should end, False otherwise.
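A minimal sketch of the loop structure described for `train`, under several stated assumptions: `train_sketch` and `backprop` are hypothetical names, the dataloader is assumed to block until memory can be sampled, `shutdown_signal` is polled once per batch, a cycle is `cycle_length` consecutive batches, and raising `TrainingShutdownException` from inside a step also ends training. The real `Trainer.train` drives its callbacks rather than taking a single step function:

```python
from typing import Any, Callable, Dict, Iterable, Tuple


class TrainingShutdownException(Exception):
    """Stand-in for emote's exception; assumed here to mean 'stop training now'."""


def train_sketch(
    dataloader: Iterable[Dict[str, Any]],
    cycle_length: int,
    shutdown_signal: Callable[[], bool],
    backprop: Callable[[Dict[str, Any]], None],
) -> Tuple[int, int]:
    """Run `backprop` on batches, grouped into cycles, until shutdown.

    Returns (batches_processed, cycles_completed).
    """
    batches = cycles = 0
    try:
        # Assumption: the dataloader only yields once memory can be sampled,
        # so "waiting for memory" is just blocking inside this iteration.
        for batch in dataloader:
            if shutdown_signal():  # polled once per batch in this sketch
                break
            backprop(batch)  # may raise TrainingShutdownException
            batches += 1
            if batches % cycle_length == 0:
                cycles += 1  # end-of-cycle callbacks would run here
    except TrainingShutdownException:
        pass  # either way, end-of-training callbacks would run after this point
    return batches, cycles
```

For example, with `cycle_length=3` and a `shutdown_signal` that starts returning True after seven batches, the sketch processes seven batches and completes two cycles before stopping.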