module emote.trainer
Classes
class StateDict(dict, MutableMapping[str, Any]):
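A wrapper around a dict that allows it to be used with a weakref (plain dict instances cannot be weakly referenced, but instances of a dict subclass can).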
Methods
def get_handle(self) -> WeakReference['StateDict']
Retrieve a weak handle to this state dict, with no promise of ownership or lifetime.
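The reason for the dict subclass can be shown with the standard weakref module. The sketch below is a minimal stand-in, not emote's implementation: it assumes get_handle simply returns a weakref.ref to the dict, and it omits the MutableMapping base for brevity.

```python
import weakref


class StateDict(dict):
    """Minimal stand-in: a dict subclass, which (unlike a plain dict) can be weakly referenced."""

    def get_handle(self) -> "weakref.ref[StateDict]":
        # Assumption: the handle is a plain weak reference; it does not keep the dict alive.
        return weakref.ref(self)


# A plain dict cannot be the target of a weak reference.
try:
    weakref.ref({})
except TypeError as err:
    print("plain dict:", err)

state = StateDict(step=0)
handle = state.get_handle()
handle()["step"] += 1   # dereference while the owner is still alive
print(state["step"])    # 1

del state               # drop the only strong reference
print(handle())         # None in CPython, where refcounting frees the dict immediately
```

Because the handle carries no ownership, callers must check for None after dereferencing it.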
class TrainingShutdownException(Exception):
class Trainer:
The Trainer class manages the main training loop in emote. It does so by invoking callbacks at several points in the loop.
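To make the callback-driven design concrete, here is a minimal sketch of a loop that delegates all work to callbacks sharing a state dict. The Callback base class, the hook names (begin_training, begin_batch, end_batch, end_training), the fire helper and StepCounter are hypothetical illustrations, not emote's actual API.

```python
from typing import Any, Dict, List


class Callback:
    """Hypothetical base class: every hook is a no-op unless overridden."""

    def begin_training(self, state: Dict[str, Any]) -> None: ...
    def begin_batch(self, state: Dict[str, Any]) -> None: ...
    def end_batch(self, state: Dict[str, Any]) -> None: ...
    def end_training(self, state: Dict[str, Any]) -> None: ...


def fire(callbacks: List[Callback], hook: str, state: Dict[str, Any]) -> None:
    # Invoke the named hook on every callback, in registration order.
    for cb in callbacks:
        getattr(cb, hook)(state)


class StepCounter(Callback):
    """Example callback that counts processed batches in the shared state."""

    def begin_training(self, state):
        state["steps"] = 0

    def end_batch(self, state):
        state["steps"] += 1


def run(callbacks: List[Callback], batches) -> Dict[str, Any]:
    state: Dict[str, Any] = {}
    fire(callbacks, "begin_training", state)
    for batch in batches:
        state["batch"] = batch
        fire(callbacks, "begin_batch", state)
        # ... a real trainer would run the forward/backward pass here ...
        fire(callbacks, "end_batch", state)
    fire(callbacks, "end_training", state)
    return state


print(run([StepCounter()], [1, 2, 3])["steps"])  # 3
```

Keeping the loop itself this thin means behaviour such as logging, checkpointing or optimization can be added by registering another callback rather than editing the loop.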
Fields
- state: StateDict
- callbacks: List[Callback]
- dataloader: Iterable
- cycle_length: int
Methods
def __init__(self, callbacks, dataloader, batch_size_key) -> None
Arguments:
- callbacks (List[Callback])
- dataloader (Iterable)
- batch_size_key (str, default: "batch_size")
def train(self, shutdown_signal) -> None
The main training loop. This method waits until the memory holds enough data to start sampling, then runs cycles of backpropagation steps on batches sampled from the memory.
Arguments:
- shutdown_signal (Callable): A function that returns True if training should end, and False otherwise.
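The role of shutdown_signal can be sketched as a check made once per batch. This is a hypothetical shape of the loop, not emote's train method: it assumes that iterating the dataloader blocks until the memory can be sampled, and the state keys bp_step and cycles, as well as the way cycle_length groups steps, are illustrative assumptions.

```python
import itertools
from typing import Any, Callable, Dict, Iterable


def train(
    dataloader: Iterable[Dict[str, Any]],
    shutdown_signal: Callable[[], bool],
    cycle_length: int = 2,
) -> Dict[str, Any]:
    """Hypothetical loop shape: stop as soon as shutdown_signal() returns True."""
    state: Dict[str, Any] = {"bp_step": 0, "cycles": 0}
    # Assumption: iterating the dataloader blocks until the memory is full enough
    # to sample from, so the "wait" step is hidden inside the iterator.
    for batch in dataloader:
        if shutdown_signal():
            break
        state["batch"] = batch
        # ... backprop on `batch` would happen here ...
        state["bp_step"] += 1
        if state["bp_step"] % cycle_length == 0:
            state["cycles"] += 1  # one full cycle of backprop steps completed


    return state


batches = ({"x": i} for i in itertools.count())  # endless stream of fake batches
calls = itertools.count()
final = train(batches, shutdown_signal=lambda: next(calls) >= 5)
print(final["bp_step"], final["cycles"])  # 5 2
```

Polling a callable rather than a flag lets the caller decide what "should end" means (a step budget, a wall-clock limit, an external signal) without the loop knowing about it.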