module emote.algorithms.dqn

Classes

class QTarget(LoggingMixin, Callback):

Methods

def __init__(
    self,
    *,
    q_net,
    target_q_net,
    gamma,
    reward_scale,
    target_q_tau,
    data_group,
    roll_length
) -> None

Compute and manage the target Q-values for Q-Learning algorithms.

Arguments:

  • q_net(nn.Module): The Q-network.
  • target_q_net(Optional[nn.Module]): The target Q-network. Defaults to a copy of q_net.
  • gamma(float): Discount factor for future rewards.
  • reward_scale(float): A scaling factor for the reward values.
  • target_q_tau(float): The soft update rate for the target Q-network.
  • data_group(str): The data group in which the computed Q-target is stored.
  • roll_length(int): The rollout length for a batch.
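
A minimal construction sketch (the argument values, the network shape, and the "default" data-group name are illustrative assumptions, not documented defaults):

import torch.nn as nn

from emote.algorithms.dqn import QTarget

# Hypothetical discrete-action Q-network: 4-dimensional observations, 2 actions.
q_net = nn.Sequential(nn.Linear(4, 64), nn.ReLU(), nn.Linear(64, 2))

q_target = QTarget(
    q_net=q_net,            # target_q_net defaults to a copy of q_net
    gamma=0.99,             # discount factor for future rewards
    reward_scale=1.0,       # leave rewards unscaled
    target_q_tau=0.005,     # soft-update rate for the target network
    data_group="default",   # assumed data-group name
    roll_length=1,          # one-step rollouts
)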
def begin_batch(self, next_observation, rewards, masks) -> None
def end_batch(self) -> None
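
Conceptually, begin_batch computes the Q-target for the batch and end_batch soft-updates the target network using target_q_tau. The sketch below illustrates the standard one-step Q-Learning target only; roll_length handling and exact tensor shapes are simplified, so treat it as a reference formula rather than the implementation:

import torch
from torch import nn

def compute_q_target(
    target_q_net: nn.Module,
    next_observation: torch.Tensor,  # [batch, obs_dim]
    rewards: torch.Tensor,           # [batch, 1]
    masks: torch.Tensor,             # [batch, 1], 0.0 where the episode ended
    gamma: float,
    reward_scale: float,
) -> torch.Tensor:
    with torch.no_grad():
        next_q = target_q_net(next_observation)          # [batch, num_actions]
        max_next_q, _ = next_q.max(dim=1, keepdim=True)  # greedy next-state value
        return reward_scale * rewards + gamma * masks * max_next_q

# end_batch then soft-updates the target network towards q_net:
#   theta_target <- target_q_tau * theta + (1 - target_q_tau) * theta_target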

class QLoss(LossCallback):

Compute the Q-Learning loss.

Methods

def __init__(
    self,
    *,
    name,
    q,
    opt,
    lr_schedule,
    max_grad_norm,
    data_group,
    log_per_param_weights,
    log_per_param_grads
) -> None

Arguments:

  • name(str): Identifier for this loss component.
  • q(nn.Module): The Q-network.
  • opt(optim.Optimizer): The optimizer to use for the Q-network.
  • lr_schedule(Optional[optim.lr_scheduler._LRScheduler]): Learning rate scheduler.
  • max_grad_norm(float): Maximum gradient norm for gradient clipping.
  • data_group(str): The data group from which to pull data.
  • log_per_param_weights(bool): Whether to log weights per parameter.
  • log_per_param_grads(bool): Whether to log gradients per parameter.
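
A construction sketch to pair with the QTarget example above; the optimizer settings are illustrative, and omitting lr_schedule and the per-parameter logging flags assumes they have defaults:

from torch import optim

from emote.algorithms.dqn import QLoss

q_loss = QLoss(
    name="q",                                      # identifier used in logs
    q=q_net,                                       # the same network passed to QTarget above
    opt=optim.Adam(q_net.parameters(), lr=1e-3),
    max_grad_norm=1.0,                             # clip gradients at this norm
    data_group="default",                          # read from the group QTarget writes to
)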
def loss(self, observation, q_target, actions) -> None
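
loss regresses the Q-value of the taken action towards the q_target produced by QTarget. A sketch assuming a single observation tensor, integer action indices of shape [batch, 1], and a mean-squared-error objective (the library may use a different regression loss):

import torch.nn.functional as F

def q_learning_loss(q_net, observation, q_target, actions):
    q_values = q_net(observation)                 # [batch, num_actions]
    q_taken = q_values.gather(1, actions.long())  # value of the action actually taken
    return F.mse_loss(q_taken, q_target)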