module emote.optimizers

Functions

def separate_modules_for_weight_decay(
    network,
    whitelist_weight_modules,
    blacklist_weight_modules,
    layers_to_exclude
) -> tuple[set[str], set[str]]

Separate the parameters of network into two sets: those that will have weight decay applied and those that will not.

Arguments:

  • network(torch.nn.Module): Network whose parameters should be separated.
  • whitelist_weight_modules(tuple[Type[torch.nn.Module], ...]): Modules that should have weight decay applied to the weights.
  • blacklist_weight_modules(tuple[Type[torch.nn.Module], ...]): Modules that should not have weight decay applied to the weights.
  • layers_to_exclude(set[str] | None): Names of additional layers to exclude from weight decay. (default: None)

Returns:

  • Sets of parameter names that will and will not have weight decay applied.
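
A minimal usage sketch. The toy network and the argument values are illustrative; the keyword names follow the signature above.

import torch

from emote.optimizers import separate_modules_for_weight_decay

# Toy network: Linear weights should decay, LayerNorm weights and all
# biases should not.
network = torch.nn.Sequential(
    torch.nn.Linear(8, 16),
    torch.nn.LayerNorm(16),
    torch.nn.Linear(16, 4),
)

decay, no_decay = separate_modules_for_weight_decay(
    network,
    whitelist_weight_modules=(torch.nn.Linear,),
    blacklist_weight_modules=(torch.nn.LayerNorm,),
    layers_to_exclude=None,
)
# decay and no_decay are sets of parameter names; they can be used to build
# torch.optim parameter groups with different weight_decay values.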

Classes

class ModifiedAdamW(torch.optim.AdamW):

Modifies AdamW (Adam with weight decay) so that weight decay is not applied to biases, layer-normalization weights, or, optionally, additional user-specified layers.

Methods

def __init__(
    self,
    network,
    lr,
    weight_decay,
    whitelist_weight_modules,
    blacklist_weight_modules,
    layers_to_exclude
) -> None

Arguments:

  • network(torch.nn.Module): Network whose parameters will be optimized.
  • lr(float): Learning rate.
  • weight_decay(float): Weight decay coefficient applied to the whitelisted weights.
  • whitelist_weight_modules(tuple[Type[torch.nn.Module], ...]): Module types whose weights should have weight decay applied. (default: (torch.nn.Linear,))
  • blacklist_weight_modules(tuple[Type[torch.nn.Module], ...]): Module types whose weights should not have weight decay applied. (default: (torch.nn.LayerNorm,))
  • layers_to_exclude(set[str] | None): Names of additional layers to exclude from weight decay, e.g. the last layer of a Q-network. (default: None)
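
A minimal construction sketch. The network, learning rate, and weight-decay value are illustrative, and the excluded-layer name "3" is an assumption based on how torch.nn.Sequential names its submodules; substitute the names from your own model.

import torch

from emote.optimizers import ModifiedAdamW

# Toy Q-network; the final Linear layer is excluded from weight decay.
q_network = torch.nn.Sequential(
    torch.nn.Linear(32, 64),
    torch.nn.LayerNorm(64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 1),
)

optimizer = ModifiedAdamW(
    network=q_network,
    lr=3e-4,            # illustrative learning rate
    weight_decay=0.01,  # illustrative weight-decay coefficient
    whitelist_weight_modules=(torch.nn.Linear,),     # default shown explicitly
    blacklist_weight_modules=(torch.nn.LayerNorm,),  # default shown explicitly
    layers_to_exclude={"3"},  # assumed name of the last Linear layer in this Sequential
)

# The optimizer is used like any torch.optim.AdamW instance, e.g.
# optimizer.zero_grad(); loss.backward(); optimizer.step().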