module emote.callbacks.loss

Classes

class LossCallback(LoggingMixin, Callback):

Losses are callbacks that implement a loss function.

Methods

def __init__(
    self,
    lr_schedule,
    *,
    name,
    network,
    optimizer,
    max_grad_norm,
    data_group,
    log_per_param_weights,
    log_per_param_grads
) -> None

Arguments:

  • lr_schedule(Optional[optim.lr_scheduler._LRScheduler])
  • name(str)
  • network(Optional[nn.Module])
  • optimizer(Optional[optim.Optimizer])
  • max_grad_norm(float)
  • data_group(str)
  • log_per_param_weights
  • log_per_param_grads
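
The `max_grad_norm` argument suggests that gradients are clipped to a maximum global norm before the optimizer steps. This page does not spell out the rule. The following is a minimal, dependency-free sketch that assumes the common global L2-norm clipping used by PyTorch's `torch.nn.utils.clip_grad_norm_`. The helper `clip_global_grad_norm` is hypothetical and works on plain lists instead of tensors.

```python
import math
from typing import List


def clip_global_grad_norm(grads: List[List[float]], max_grad_norm: float) -> float:
    """Hypothetical helper: scale gradients in place so their global L2 norm
    is at most max_grad_norm. Returns the norm measured before clipping."""
    # Treat every parameter's gradient as part of one long vector.
    total_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    # Same coefficient as the usual rule; epsilon avoids division by zero.
    clip_coef = max_grad_norm / (total_norm + 1e-6)
    # Only ever shrink gradients, never enlarge them.
    if clip_coef < 1.0:
        for grad in grads:
            for i in range(len(grad)):
                grad[i] *= clip_coef
    return total_norm
```

For example, gradients `[[3.0], [4.0]]` have a global norm of 5.0. With `max_grad_norm=1.0` they are scaled to roughly `[[0.6], [0.8]]`, while gradients already under the limit are left unchanged.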

def backward(self) -> None
def log_per_param_weights_and_grads(self) -> None
def state_dict(self) -> None
def load_state_dict(
    self,
    state_dict,
    load_weights,
    load_optimizers,
    load_hparams
) -> None
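
`load_state_dict` takes three flags, `load_weights`, `load_optimizers`, and `load_hparams`, which suggests that network weights, optimizer state, and hyperparameters can be restored independently. The sketch below illustrates that reading only. The function `select_parts_to_restore` and the checkpoint keys `network_state_dict`, `optimizer_state_dict`, and `hparams` are hypothetical and may not match what `state_dict` actually stores.

```python
from typing import Any, Dict


def select_parts_to_restore(
    state_dict: Dict[str, Any],
    load_weights: bool,
    load_optimizers: bool,
    load_hparams: bool,
) -> Dict[str, Any]:
    """Hypothetical helper: pick which parts of a checkpoint to restore."""
    # Key names below are assumptions for illustration only.
    restored: Dict[str, Any] = {}
    if load_weights and "network_state_dict" in state_dict:
        restored["network"] = state_dict["network_state_dict"]
    if load_optimizers and "optimizer_state_dict" in state_dict:
        restored["optimizer"] = state_dict["optimizer_state_dict"]
    if load_hparams and "hparams" in state_dict:
        restored["hparams"] = state_dict["hparams"]
    return restored
```

Restoring only the weights, for instance when fine-tuning with a fresh optimizer, would correspond to `select_parts_to_restore(ckpt, load_weights=True, load_optimizers=False, load_hparams=False)`.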
def loss(self) -> Tensor

The loss method must be overridden in a subclass to implement a concrete loss.

Returns:

  • A PyTorch tensor of shape (batch,).
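
Subclasses supply the actual loss by overriding `loss`. The sketch below mimics that pattern without PyTorch or emote so it runs on its own. `MiniLossCallback` is a hypothetical stand-in for `LossCallback`, plain lists stand in for tensors, and reducing the per-sample losses with a mean inside `backward` is an assumption, not documented behavior. With the real library you would subclass `LossCallback` and return a `Tensor` of shape `(batch,)` from `loss`.

```python
from typing import List, Sequence


class MiniLossCallback:
    """Hypothetical stand-in for LossCallback: backward() delegates to loss()."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.last_loss: float = 0.0

    def loss(self) -> List[float]:
        # Must be overridden, just as LossCallback.loss must be.
        raise NotImplementedError

    def backward(self) -> None:
        per_sample = self.loss()  # one value per sample, i.e. shape (batch,)
        # Assumed mean reduction to a scalar; a real implementation would then
        # backpropagate, clip gradients, and step the optimizer.
        self.last_loss = sum(per_sample) / len(per_sample)


class SquaredErrorLoss(MiniLossCallback):
    """Example subclass: per-sample squared error between predictions and targets."""

    def __init__(self, predictions: Sequence[float], targets: Sequence[float]) -> None:
        super().__init__(name="squared_error")
        self.predictions = predictions
        self.targets = targets

    def loss(self) -> List[float]:
        # One squared error per sample.
        return [(p - t) ** 2 for p, t in zip(self.predictions, self.targets)]
```

Here `SquaredErrorLoss([1.0, 2.0, 3.0], [1.0, 0.0, 3.0]).loss()` yields `[0.0, 4.0, 0.0]`, and calling `backward()` reduces it to 4/3.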