module emote.algorithms.dqn
Classes
class QTarget(LoggingMixin, Callback):
Methods
def __init__(
    self,
    *,
    q_net,
    target_q_net,
    gamma,
    reward_scale,
    target_q_tau,
    data_group,
    roll_length,
) -> None
Compute and manage the target Q-values for Q-Learning algorithms.
Arguments:
q_net (nn.Module): The Q-network.
target_q_net (Optional[nn.Module]): The target Q-network. Defaults to a copy of q_net.
gamma (float): Discount factor for future rewards.
reward_scale (float): A scaling factor for the reward values.
target_q_tau (float): A soft update rate for the target Q-network.
data_group (str): The data group to store the computed Q-target in.
roll_length (int): The rollout length for a batch.
def begin_batch(self, next_observation, rewards, masks) -> None
def end_batch(self) -> None
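As a rough illustration of the quantities QTarget manages, the sketch below builds a Bellman target from a target network and then applies a Polyak (soft) update controlled by target_q_tau. It is a minimal standalone approximation, not the library's implementation: the network shapes and the helper names compute_q_target and soft_update_target are hypothetical, and rollout handling via roll_length is omitted.

```python
import copy
import torch
from torch import nn

# Hypothetical setup mirroring the constructor arguments above.
q_net = nn.Sequential(nn.Linear(4, 64), nn.ReLU(), nn.Linear(64, 2))
target_q_net = copy.deepcopy(q_net)  # default: a copy of q_net
gamma, reward_scale, target_q_tau = 0.99, 1.0, 0.005

def compute_q_target(next_observation, rewards, masks):
    # Bellman target: scaled reward + gamma * max_a' Q_target(s', a'),
    # with masks zeroing out the bootstrap term at terminal states.
    with torch.no_grad():
        next_q = target_q_net(next_observation).max(dim=1, keepdim=True).values
    return reward_scale * rewards + gamma * masks * next_q

def soft_update_target():
    # Polyak averaging of the online network into the target network.
    for p, tp in zip(q_net.parameters(), target_q_net.parameters()):
        tp.data.mul_(1.0 - target_q_tau).add_(target_q_tau * p.data)
```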
class QLoss(LossCallback):
Compute the Q-Learning loss.
Methods
def __init__(
    self,
    *,
    name,
    q,
    opt,
    lr_schedule,
    max_grad_norm,
    data_group,
    log_per_param_weights,
    log_per_param_grads,
) -> None
Arguments:
name (str): Identifier for this loss component.
q (nn.Module): The Q-network.
opt (optim.Optimizer): The optimizer to use for the Q-network.
lr_schedule (Optional[optim.lr_scheduler._LRScheduler]): Learning rate scheduler.
max_grad_norm (float): Maximum gradient norm for gradient clipping.
data_group (str): The data group from which to pull data.
log_per_param_weights (bool): Whether to log weights per parameter.
log_per_param_grads (bool): Whether to log gradients per parameter.
def loss(self, observation, q_target, actions) -> None
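To make the loss concrete, the sketch below regresses the Q-values of the taken actions toward a precomputed q_target with an MSE objective and clips gradients with max_grad_norm. This is a hedged, self-contained approximation under assumed shapes and a discrete action space; q_learning_loss is a hypothetical helper, not the library's implementation.

```python
import torch
from torch import nn, optim
import torch.nn.functional as F

# Hypothetical Q-network, optimizer, and clipping threshold.
q = nn.Sequential(nn.Linear(4, 64), nn.ReLU(), nn.Linear(64, 2))
opt = optim.Adam(q.parameters(), lr=1e-3)
max_grad_norm = 1.0

def q_learning_loss(observation, q_target, actions):
    # Select Q(s, a) for the actions actually taken and regress toward the target.
    q_taken = q(observation).gather(1, actions.long())
    return F.mse_loss(q_taken, q_target)

# One optimization step with gradient clipping.
obs = torch.randn(32, 4)
acts = torch.randint(0, 2, (32, 1))
targets = torch.randn(32, 1)  # would come from QTarget in practice
loss = q_learning_loss(obs, targets, acts)
opt.zero_grad()
loss.backward()
nn.utils.clip_grad_norm_(q.parameters(), max_grad_norm)
opt.step()
```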