module emote.algorithms.dqn
Classes
class QTarget(LoggingMixin, Callback):
Methods
def __init__(
    self,
    *,
    q_net,
    target_q_net,
    gamma,
    reward_scale,
    target_q_tau,
    data_group,
    roll_length
) -> None
Compute and manage the target Q-values for Q-Learning algorithms.
Arguments:
q_net(nn.Module)
: The Q-network.
target_q_net(Optional[nn.Module])
: The target Q-network. Defaults to a copy of q_net.
gamma(float)
: Discount factor for future rewards.
reward_scale(float)
: A scaling factor for the reward values.
target_q_tau(float)
: The soft update rate for the target Q-network.
data_group(str)
: The data group in which to store the computed Q-target.
roll_length(int)
: The rollout length for a batch.
def begin_batch(self, next_observation, rewards, masks) -> None
def end_batch(self) -> None
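As a minimal sketch of what a begin_batch/end_batch pair like this typically computes: a standard max-over-actions one-step Q-learning target, followed by a Polyak (soft) update of the target network at the rate target_q_tau. The function names, tensor shapes, and default values below are illustrative assumptions, not part of the emote API.

```python
import torch
from torch import nn


def compute_q_target(target_q_net: nn.Module, next_observation, rewards, masks,
                     gamma: float = 0.99, reward_scale: float = 1.0):
    # Illustrative helper (not an emote function). Standard one-step target:
    #   y = reward_scale * r + gamma * mask * max_a Q_target(s', a)
    with torch.no_grad():
        next_q = target_q_net(next_observation)              # [batch, num_actions]
        max_next_q = next_q.max(dim=1, keepdim=True).values  # [batch, 1]
        return reward_scale * rewards + gamma * masks * max_next_q


def soft_update(target_q_net: nn.Module, q_net: nn.Module, tau: float):
    # Illustrative helper (not an emote function). Polyak averaging:
    #   theta_target <- (1 - tau) * theta_target + tau * theta
    with torch.no_grad():
        for tp, p in zip(target_q_net.parameters(), q_net.parameters()):
            tp.mul_(1.0 - tau).add_(p, alpha=tau)
```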
class QLoss(LossCallback):
Compute the Q-Learning loss.
Methods
def __init__(
    self,
    *,
    name,
    q,
    opt,
    lr_schedule,
    max_grad_norm,
    data_group,
    log_per_param_weights,
    log_per_param_grads
) -> None
Arguments:
name(str)
: Identifier for this loss component.
q(nn.Module)
: The Q-network.
opt(optim.Optimizer)
: The optimizer to use for the Q-network.
lr_schedule(Optional[optim.lr_scheduler._LRScheduler])
: Learning rate scheduler.
max_grad_norm(float)
: Maximum gradient norm for gradient clipping.
data_group(str)
: The data group from which to pull data.
log_per_param_weights(bool)
: Whether to log weights per parameter.
log_per_param_grads(bool)
: Whether to log gradients per parameter.
def loss(self, observation, q_target, actions) -> None
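As a rough sketch of the loss computation (assuming a discrete-action Q-network that returns one value per action, and MSE regression toward the precomputed target; the helper name q_learning_loss and the exact loss form are assumptions, not the emote implementation):

```python
import torch.nn.functional as F
from torch import nn


def q_learning_loss(q_net: nn.Module, observation, q_target, actions):
    # Q-values of the actions actually taken, gathered per batch row.
    q_values = q_net(observation).gather(1, actions.long())
    # Regress the taken-action values toward the target produced by QTarget.
    return F.mse_loss(q_values, q_target)
```

The data_group arguments suggest the intended pairing: QTarget writes the computed Q-target into a data group at the start of each batch, and a QLoss configured with the same data group reads it back when computing its loss.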