module emote.algorithms.hlgauss

Classes

class LogitNet(nn.Module):

QNet assumes that the network passed to it has a num_bins property.

Methods

def __init__(self, num_bins) -> None

Arguments:

  • num_bins

class QNet(nn.Module):

The HL Gauss QNet must both output the q-value for a given input and convert logits to q-values.

Methods

def __init__(self, logit_net, min_value, max_value) -> None

Arguments:

  • logit_net(LogitNet)
  • min_value(float)
  • max_value(float)

def forward(self) -> Tensor
def q_from_logit(self, logits) -> Tensor
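
To illustrate what converting logits to a q-value can look like, the sketch below takes a softmax over the logits and returns the expected value over the bin centers, as in the HL Gauss paper. It is written in plain Python rather than PyTorch so that it runs without dependencies. The helper names bin_centers and q_from_logits are hypothetical, and it is an assumption that q_from_logit computes the same quantity.

```python
import math


def bin_centers(min_value, max_value, num_bins):
    """Midpoints of num_bins equal-width bins spanning [min_value, max_value]."""
    width = (max_value - min_value) / num_bins
    return [min_value + (i + 0.5) * width for i in range(num_bins)]


def q_from_logits(logits, centers):
    """Expected value of the softmax distribution over the bin centers.

    Hypothetical stand-in for QNet.q_from_logit; not the library code.
    """
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return sum(e / total * c for e, c in zip(exps, centers))


centers = bin_centers(0.0, 10.0, 5)
# Uniform logits put equal mass on every bin, so q is the middle of the range.
q_uniform = q_from_logits([0.0] * 5, centers)
# A strongly peaked logit vector pulls q towards the center of that bin.
q_peaked = q_from_logits([0.0, 0.0, 0.0, 0.0, 50.0], centers)
```

Because q is an expectation over bin centers, it always lies between the first and last bin center, which is why min_value and max_value should bracket the returns the task can produce.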

class HLGaussLoss(nn.Module):

An HL Gauss loss as described by Imani and White. The code is taken from the Google DeepMind paper at https://arxiv.org/pdf/2403.03950v1.pdf.

Methods

def __init__(self, min_value, max_value, num_bins, sigma) -> None

Arguments:

  • min_value(float): Minimal value of the range of target bins.
  • max_value(float): Maximal value of the range of target bins.
  • num_bins(int): Number of bins.
  • sigma(float): Standard deviation of the Gaussian used to convert regression targets to distributions.

def forward(self, logits, target) -> torch.Tensor
def transform_to_probs(self, target) -> torch.Tensor
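
The following is a minimal sketch of the HL Gauss idea for a single scalar target, written in plain Python instead of PyTorch. The target is turned into a histogram by integrating a Gaussian centered on it over each bin, and the loss is the cross-entropy between that histogram and the softmax of the logits. The function names hl_gauss_probs and cross_entropy are hypothetical, and the details of the library's own forward and transform_to_probs may differ.

```python
import math


def hl_gauss_probs(target, min_value, max_value, num_bins, sigma):
    """Mass that a Gaussian N(target, sigma^2) places in each bin,
    renormalized so the bins over [min_value, max_value] sum to one."""
    width = (max_value - min_value) / num_bins
    edges = [min_value + i * width for i in range(num_bins + 1)]
    # erf at each bin edge; differences are proportional to Gaussian CDF differences.
    cdf = [math.erf((e - target) / (math.sqrt(2.0) * sigma)) for e in edges]
    z = cdf[-1] - cdf[0]  # total (scaled) mass inside the supported range
    return [(cdf[i + 1] - cdf[i]) / z for i in range(num_bins)]


def cross_entropy(logits, probs):
    """Cross-entropy between soft target probabilities and softmax(logits)."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return -sum(p * (x - log_z) for p, x in zip(probs, logits))


probs = hl_gauss_probs(target=5.0, min_value=0.0, max_value=10.0, num_bins=5, sigma=1.5)
# With uniform logits the loss equals log(num_bins), whatever the target is.
loss = cross_entropy([0.0] * 5, probs)
```

A larger sigma spreads the target mass over more neighboring bins, while a very small sigma approaches a one-hot target on the bin that contains the target.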

class QLoss(LossCallback):

A classification loss between the action value net and the target q. The target q values are not calculated here and need to be added to the state before the loss of this module runs.

Methods

def __init__(
    self,
    *,
    name,
    q,
    opt,
    lr_schedule,
    max_grad_norm,
    smoothing_ratio,
    data_group,
    log_per_param_weights,
    log_per_param_grads
) -> None

Arguments:

  • name(str): The name of the module. Used e.g. while logging.
  • q(QNet): A deep neural net that outputs the estimated discounted return (q-value) given the current observations and a given action.
  • opt(optim.Optimizer): An optimizer for q.
  • lr_schedule(Optional[optim.lr_scheduler._LRScheduler]): Learning rate schedule for the optimizer of q.
  • max_grad_norm(float): Clip the norm of the gradient during backprop using this value.
  • smoothing_ratio(float): The HL Gauss smoothing ratio is the standard deviation of the Gaussian divided by the bin size.
  • data_group(str): The name of the data group from which this Loss takes its data.
  • log_per_param_weights(bool): If true, log each individual policy parameter that is optimized (norm and value histogram).
  • log_per_param_grads(bool): If true, log the gradients of each individual policy parameter that is optimized (norm and histogram).

def loss(self, observation, actions, q_target) -> None
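
The smoothing_ratio argument above is defined as the Gaussian's standard deviation divided by the bin size, so the sigma that HLGaussLoss expects can be recovered by multiplying the ratio by the bin width. The sketch below shows that arithmetic. The helper name sigma_from_smoothing_ratio and the numeric values are illustrative assumptions, and equal-width bins over [min_value, max_value] are assumed.

```python
def sigma_from_smoothing_ratio(smoothing_ratio, min_value, max_value, num_bins):
    """Standard deviation implied by a smoothing ratio, assuming equal-width bins."""
    bin_width = (max_value - min_value) / num_bins
    return smoothing_ratio * bin_width


# 20 bins over [0, 10] give a bin width of 0.5; the ratio 0.75 is only an example value.
sigma = sigma_from_smoothing_ratio(0.75, min_value=0.0, max_value=10.0, num_bins=20)
```

Expressing the smoothing as a ratio keeps the amount of label spreading comparable when num_bins or the value range changes.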