package emote.callbacks
Classes
class Checkpointer(Callback):
Checkpointer writes out a checkpoint every n steps. Exactly what is written to the checkpoint is determined by the restorees supplied in the constructor.
Methods
def __init__(
    self,
    *,
    restorees,
    run_root,
    checkpoint_interval,
    checkpoint_index,
    storage_subdirectory
) -> None
Arguments:
restorees(list[Restoree])
: A list of restorees that should be saved.
run_root(str)
: The root path to where the run artifacts should be stored.
checkpoint_interval(int)
: Number of backprops between checkpoints.
checkpoint_index(int)
storage_subdirectory(str)
: The subdirectory where the checkpoints are stored.
def begin_training(self) -> None
def end_cycle(self, bp_step, bp_samples) -> None
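Example:
A minimal usage sketch, assuming the keyword arguments listed above; policy_restoree, the paths, and the interval are illustrative placeholders rather than library defaults.

from emote.callbacks import Checkpointer

# `policy_restoree` stands in for a Restoree-compatible object built elsewhere.
policy_restoree = ...

checkpointer = Checkpointer(
    restorees=[policy_restoree],          # what gets written to each checkpoint
    run_root="/tmp/example_run",          # root directory for run artifacts
    checkpoint_interval=10_000,           # backprops between checkpoints
    checkpoint_index=0,
    storage_subdirectory="checkpoints",
)
# The trainer drives the callback: end_cycle(bp_step, bp_samples) decides
# when the next checkpoint is written.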
class CheckpointLoader(Callback):
CheckpointLoader loads a checkpoint like the one created by Checkpointer.
This is intended for resuming training given a specific checkpoint index. It also enables you to load network weights, optimizer, or other callback hyper-params independently. If you want to do something more specific, like only restore a specific network (outside a callback), it is probably easier to just do it explicitly when the network is constructed.
Methods
def __init__(
    self,
    *,
    restorees,
    run_root,
    checkpoint_index,
    load_weights,
    load_optimizers,
    load_hparams,
    storage_subdirectory
) -> None
Arguments:
restorees(list[Restoree])
: A list of restorees that should be restored.
run_root(str)
: The root path to where the run artifacts should be stored.
checkpoint_index(int)
: Which checkpoint to load.
load_weights(bool)
: If True, it loads the network weights.
load_optimizers(bool)
: If True, it loads the optimizer state.
load_hparams(bool)
: If True, it loads other callback hyper-params.
storage_subdirectory(str)
: The subdirectory where the checkpoints are stored.
def restore_state(self) -> None
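Example:
A sketch of resuming from a checkpoint written by the Checkpointer example above; the restoree, paths, and checkpoint index are illustrative placeholders.

from emote.callbacks import CheckpointLoader

policy_restoree = ...  # the same Restoree-compatible object the Checkpointer saved

loader = CheckpointLoader(
    restorees=[policy_restoree],
    run_root="/tmp/example_run",
    checkpoint_index=3,                   # which checkpoint to load
    load_weights=True,
    load_optimizers=True,
    load_hparams=False,                   # skip callback hyper-params
    storage_subdirectory="checkpoints",
)
# restore_state() reads the checkpoint and restores the selected parts;
# it is normally invoked by the trainer at startup, shown here for illustration.
loader.restore_state()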
class BackPropStepsTerminator(Callback):
Terminates training after a given number of backprops.
Methods
def __init__(self, bp_steps) -> None
Arguments:
bp_steps(int)
: The total number of backprops that the trainer should run for.
def end_cycle(self) -> None
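Example:
A short sketch; the step count is an illustrative value.

from emote.callbacks import BackPropStepsTerminator

# Ends training once the trainer has performed 100_000 backprops.
terminator = BackPropStepsTerminator(bp_steps=100_000)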
class LoggingMixin:
A mixin that accepts logging calls. Logged data is saved on this object and gets written out by a Logger. The mixin itself does not care how the data is logged; it only provides a standard interface for storing the data to be handled by a Logger.
Methods
def __init__(self, *, default_window_length) -> None
Arguments:
default_window_length(int)
def log_scalar(self, key, value) -> None
Use log_scalar to periodically log scalar data.
Arguments:
key(str)
value(float | torch.Tensor)
def log_windowed_scalar(self, key, value) -> None
Log scalars using a moving window average.
By default this uses default_window_length from the constructor as the window length. It can also be overridden on a per-key basis using the key format windowed[LENGTH]:foo/bar. Note that the window length cannot be changed between invocations; whichever length is found first for a key is permanent.
Arguments:
key(str)
value(float | torch.Tensor | Iterable[torch.Tensor | float])
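Example:
A sketch of the per-key window override described above, as called from inside a LoggingMixin subclass (hence self); reward is a hypothetical float.

# Uses default_window_length from the constructor as the window:
self.log_windowed_scalar("training/reward", reward)

# Per-key override: a 100-sample window for this key only. Whichever
# length is seen first for a key stays fixed for the rest of the run.
self.log_windowed_scalar("windowed[100]:training/reward", reward)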
def log_image(self, key, value) -> None
Use log_image to periodically log image data.
Arguments:
key(str)
value(torch.Tensor)
def log_video(self, key, value) -> None
Use log_video to periodically log video data.
Arguments:
key(str)
value(Tuple[np.ndarray, int])
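Example:
A sketch of the (frames, fps) tuple implied by the signature above, as called from a LoggingMixin subclass; the array layout is an assumption, not specified here.

import numpy as np

# Hypothetical frame stack plus a frame rate of 30.
frames = np.zeros((16, 3, 64, 64), dtype=np.uint8)
self.log_video("rollout/episode", (frames, 30))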
def log_histogram(self, key, value) -> None
def state_dict(self) -> None
def load_state_dict(
self,
state_dict,
load_network,
load_optimizer,
load_hparams
) -> None
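Example:
A minimal sketch of a callback that mixes in LoggingMixin; the import paths, the end_cycle hook signature, and the RewardLogger class itself are assumptions for illustration.

from emote.callbacks import Callback, LoggingMixin  # import paths assumed


class RewardLogger(LoggingMixin, Callback):
    """Hypothetical callback that records training statistics each cycle."""

    def __init__(self):
        super().__init__(default_window_length=250)
        self.latest_reward = 0.0  # updated elsewhere, e.g. by a data collector

    def end_cycle(self, bp_step, bp_samples):
        # Plain scalar, picked up later by a Logger such as TensorboardLogger.
        self.log_scalar("training/bp_samples", bp_samples)
        # Moving-window average using the default window length.
        self.log_windowed_scalar("training/reward", self.latest_reward)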
class TensorboardLogger(Callback):
Logs the provided loggable callbacks to tensorboard.
Methods
def __init__(self, loggables, writer, log_interval, log_by_samples) -> None
Arguments:
loggables(List[LoggingMixin])
writer(SummaryWriter)
log_interval(int)
log_by_samples(bool)
def begin_training(self, bp_step, bp_samples) -> None
def end_cycle(self, bp_step, bp_samples) -> None
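Example:
A sketch of wiring loggable callbacks into a TensorboardLogger, assuming the parameters can be passed by keyword as listed above; the log directory and interval are illustrative.

from torch.utils.tensorboard import SummaryWriter
from emote.callbacks import TensorboardLogger

logger = TensorboardLogger(
    loggables=[RewardLogger()],           # e.g. the LoggingMixin sketch above
    writer=SummaryWriter(log_dir="/tmp/example_run/tb"),
    log_interval=1_000,                   # backprops between writes
    log_by_samples=False,
)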
class LossCallback(LoggingMixin, Callback):
Losses are callbacks that implement a loss function.
Methods
def __init__(
    self,
    lr_schedule,
    *,
    name,
    network,
    optimizer,
    max_grad_norm,
    data_group,
    log_per_param_weights,
    log_per_param_grads
) -> None
Arguments:
lr_schedule(Optional[optim.lr_scheduler._LRScheduler])
name(str)
network(Optional[nn.Module])
optimizer(Optional[optim.Optimizer])
max_grad_norm(float)
data_group(str)
log_per_param_weights
log_per_param_grads
def backward(self) -> None
def log_per_param_weights_and_grads(self) -> None
def state_dict(self) -> None
def load_state_dict(
self,
state_dict,
load_weights,
load_optimizers,
load_hparams
) -> None
def loss(self) -> Tensor
The loss method needs to be overridden to implement a loss.
Returns:
- A PyTorch tensor of shape (batch,).
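Example:
A sketch of a custom loss callback, assuming LossCallback can be subclassed with the constructor arguments listed above; the toy regression setup, the fixed observation/target tensors, and the optimizer choice are all hypothetical.

import torch
from torch import nn, optim

from emote.callbacks import LossCallback  # import path assumed


class ToyRegressionLoss(LossCallback):
    """Hypothetical loss fitting a network to fixed targets, for illustration only."""

    def __init__(self, network: nn.Module, observations: torch.Tensor, targets: torch.Tensor):
        super().__init__(
            lr_schedule=None,
            name="toy_regression",
            network=network,
            optimizer=optim.Adam(network.parameters(), lr=1e-3),
            max_grad_norm=1.0,
            data_group="default",
            log_per_param_weights=False,
            log_per_param_grads=False,
        )
        self._network = network
        self._observations = observations
        self._targets = targets

    def loss(self) -> torch.Tensor:
        # One loss value per sample, matching the documented (batch,) shape.
        predictions = self._network(self._observations)
        return ((predictions - self._targets) ** 2).mean(dim=-1)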