module emote.callbacks.loss
Classes
class LossCallback(LoggingMixin, Callback):
Losses are callbacks that implement a loss function.
Methods
def __init__(
self,
lr_schedule,
*,
name,
network,
optimizer,
max_grad_norm,
data_group,
log_per_param_weights,
log_per_param_grads
) -> None
Arguments:
lr_schedule (Optional[optim.lr_scheduler._LRScheduler])
name (str)
network (Optional[nn.Module])
optimizer (Optional[optim.Optimizer])
max_grad_norm (float)
data_group (str)
log_per_param_weights
log_per_param_grads
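In Python, parameters that follow `*` in a signature are keyword-only. Apart from `lr_schedule`, the constructor arguments therefore have to be passed by name. The sketch below shows the rule without depending on emote or PyTorch. It defines a hypothetical stand-in, `make_loss_callback`, that copies the parameter layout of `__init__` and simply returns its arguments. The argument values (`"q_loss"`, `"default"`, `10.0`) are illustrative only.

```python
from typing import Any, Optional


def make_loss_callback(
    lr_schedule: Optional[Any],
    *,
    name: str,
    network: Optional[Any],
    optimizer: Optional[Any],
    max_grad_norm: float,
    data_group: str,
    log_per_param_weights: bool,
    log_per_param_grads: bool,
) -> dict:
    """Hypothetical stand-in mirroring the parameter layout of LossCallback.__init__."""
    # Collect the arguments so the caller can inspect what was passed.
    return {
        "lr_schedule": lr_schedule,
        "name": name,
        "network": network,
        "optimizer": optimizer,
        "max_grad_norm": max_grad_norm,
        "data_group": data_group,
        "log_per_param_weights": log_per_param_weights,
        "log_per_param_grads": log_per_param_grads,
    }


# lr_schedule may be passed positionally; everything after `*` must be named.
config = make_loss_callback(
    None,
    name="q_loss",
    network=None,
    optimizer=None,
    max_grad_norm=10.0,
    data_group="default",
    log_per_param_weights=False,
    log_per_param_grads=False,
)

# Passing `name` positionally is rejected by the interpreter.
try:
    make_loss_callback(None, "q_loss")
except TypeError as err:
    print("TypeError:", err)
```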
def backward(self) -> None
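The `max_grad_norm` argument suggests that gradients are clipped before the optimizer step. A common form of clipping, and the one `torch.nn.utils.clip_grad_norm_` implements, is global-norm clipping. When the L2 norm of all gradients taken together exceeds the limit, every gradient is multiplied by the same factor, so the combined norm matches the limit and the overall direction is preserved. This documentation does not confirm that `backward` clips in exactly this way, so treat that as an assumption. The pure-Python sketch below uses a hypothetical `clip_grad_norm` helper over plain lists instead of tensors to show the arithmetic.

```python
import math
from typing import List


def clip_grad_norm(grads: List[List[float]], max_norm: float) -> float:
    """Hypothetical helper: scale gradients in place so their global L2 norm is at most max_norm.

    Returns the norm measured before clipping, as torch.nn.utils.clip_grad_norm_ does.
    """
    # Global norm: square root of the sum of squares over every element of every gradient.
    total_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    if total_norm > max_norm:
        # One shared factor for all gradients keeps the update direction unchanged.
        scale = max_norm / total_norm
        for grad in grads:
            for i in range(len(grad)):
                grad[i] *= scale
    return total_norm


# Two "parameters" with gradients [3, 0] and [0, 4]: the global norm is 5.
grads = [[3.0, 0.0], [0.0, 4.0]]
norm_before = clip_grad_norm(grads, max_norm=2.5)
print(norm_before)  # 5.0
print(grads)        # [[1.5, 0.0], [0.0, 2.0]]
```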
def log_per_param_weights_and_grads(self) -> None
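`log_per_param_weights_and_grads` and the two `log_per_param_*` constructor flags indicate that statistics can be logged for each parameter tensor individually. This documentation does not say which statistic is recorded. The sketch below assumes the L2 norm and uses a hypothetical `per_param_log_entries` helper with made-up `weights/<name>` and `grads/<name>` keys. It operates on plain lists so it runs without PyTorch.

```python
import math
from typing import Dict, List


def per_param_log_entries(
    weights: Dict[str, List[float]],
    grads: Dict[str, List[float]],
    log_weights: bool,
    log_grads: bool,
) -> Dict[str, float]:
    """Hypothetical helper: one L2-norm log entry per named parameter, gated by two flags."""
    entries: Dict[str, float] = {}
    for name, values in weights.items():
        if log_weights:
            # Norm of the parameter's current values.
            entries[f"weights/{name}"] = math.sqrt(sum(v * v for v in values))
        if log_grads and name in grads:
            # Norm of the parameter's gradient, if one has been computed.
            entries[f"grads/{name}"] = math.sqrt(sum(g * g for g in grads[name]))
    return entries


weights = {"layer1.weight": [3.0, 4.0], "layer1.bias": [0.0]}
grads = {"layer1.weight": [0.0, 2.0], "layer1.bias": [1.0]}
print(per_param_log_entries(weights, grads, log_weights=True, log_grads=False))
# {'weights/layer1.weight': 5.0, 'weights/layer1.bias': 0.0}
```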
def state_dict(self) -> dict
def load_state_dict(
self,
state_dict,
load_weights,
load_optimizers,
load_hparams
) -> None
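The three boolean flags of `load_state_dict` suggest that a checkpoint can be restored selectively, for example to warm-start from saved weights while discarding the optimizer state. The sketch below is hypothetical. It assumes the saved state is a dict with separate entries for network weights, optimizer state and hyperparameters, and it assumes each flag defaults to `True`. The key names `network_state_dict`, `optimizer_state_dict` and `hparams` are illustrative, not taken from emote. The sketch shows how each flag might gate one entry.

```python
from typing import Any, Dict


class CheckpointableSketch:
    """Hypothetical stand-in showing flag-gated restoration of saved state."""

    def __init__(self) -> None:
        self.weights = {"w": 0.0}
        self.optimizer_state = {"step": 0}
        self.max_grad_norm = 1.0  # treated here as a hyperparameter

    def state_dict(self) -> Dict[str, Any]:
        # Copy each part so later changes to the object do not alter the snapshot.
        return {
            "network_state_dict": dict(self.weights),
            "optimizer_state_dict": dict(self.optimizer_state),
            "hparams": {"max_grad_norm": self.max_grad_norm},
        }

    def load_state_dict(
        self,
        state_dict: Dict[str, Any],
        load_weights: bool = True,  # default values are assumptions
        load_optimizers: bool = True,
        load_hparams: bool = True,
    ) -> None:
        # Each flag decides whether its part of the checkpoint is restored.
        if load_weights:
            self.weights = dict(state_dict["network_state_dict"])
        if load_optimizers:
            self.optimizer_state = dict(state_dict["optimizer_state_dict"])
        if load_hparams:
            self.max_grad_norm = state_dict["hparams"]["max_grad_norm"]


# Save a "trained" object, then warm-start a fresh one with weights only.
trained = CheckpointableSketch()
trained.weights["w"] = 0.42
trained.optimizer_state["step"] = 100
trained.max_grad_norm = 5.0
checkpoint = trained.state_dict()

fresh = CheckpointableSketch()
fresh.load_state_dict(checkpoint, load_optimizers=False, load_hparams=False)
print(fresh.weights, fresh.optimizer_state, fresh.max_grad_norm)
# {'w': 0.42} {'step': 0} 1.0
```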
def loss(self) -> Tensor
The loss method must be overridden in a subclass to implement a concrete loss.
Returns:
- A PyTorch tensor of shape (batch,).
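Because `loss` has to be overridden, a concrete loss subclasses the base class and returns one value per sample. The sketch below runs without emote. It substitutes a hypothetical abstract base, `LossBase`, for `LossCallback` and uses plain Python lists in place of PyTorch tensors. The squared-error subclass, its stored `predictions`/`targets` batch, and the `mean_loss` reduction to a scalar are illustrative assumptions. This documentation does not state how per-sample losses are reduced.

```python
from abc import ABC, abstractmethod
from typing import List


class LossBase(ABC):
    """Hypothetical stand-in for LossCallback: subclasses must define loss()."""

    @abstractmethod
    def loss(self) -> List[float]:
        """Return one loss value per sample, i.e. a sequence of shape (batch,)."""

    def mean_loss(self) -> float:
        # Assumed reduction: average the per-sample losses to a single scalar.
        values = self.loss()
        return sum(values) / len(values)


class SquaredErrorLoss(LossBase):
    """Illustrative subclass: per-sample squared error on a stored batch."""

    def __init__(self, predictions: List[float], targets: List[float]) -> None:
        self.predictions = predictions
        self.targets = targets

    def loss(self) -> List[float]:
        return [(p - t) ** 2 for p, t in zip(self.predictions, self.targets)]


criterion = SquaredErrorLoss(predictions=[1.0, 2.0, 4.0], targets=[1.0, 3.0, 2.0])
print(criterion.loss())       # [0.0, 1.0, 4.0]
print(criterion.mean_loss())  # 1.6666666666666667

# Instantiating the base class directly fails because loss() is abstract.
try:
    LossBase()
except TypeError as err:
    print("TypeError:", err)
```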