module emote.optimizers
Functions
def separate_modules_for_weight_decay(
network,
whitelist_weight_modules,
blacklist_weight_modules,
layers_to_exclude
) -> tuple[set[str], set[str]]
Separate the parameters of network into two sets: one set of parameters that will have weight decay, and one set that will not.
Arguments:
network(torch.nn.Module)
: Network whose modules we want to separate.
whitelist_weight_modules(tuple[Type[torch.nn.Module], ...])
: Modules that should have weight decay applied to the weights.
blacklist_weight_modules(tuple[Type[torch.nn.Module], ...])
: Modules that should not have weight decay applied to the weights.
layers_to_exclude(set[str] | None)
: Names of layers that should be excluded from weight decay. (default: None)
Returns:
- Sets of parameter names with and without weight decay.
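A minimal usage sketch, assuming the returned sets contain fully qualified parameter names as produced by network.named_parameters(); the small Sequential network below is purely illustrative:

```python
import torch

from emote.optimizers import separate_modules_for_weight_decay

# Toy network used only to illustrate the call; any torch.nn.Module works.
network = torch.nn.Sequential(
    torch.nn.Linear(8, 16),
    torch.nn.LayerNorm(16),
    torch.nn.Linear(16, 4),
)

decay, no_decay = separate_modules_for_weight_decay(
    network,
    whitelist_weight_modules=(torch.nn.Linear,),
    blacklist_weight_modules=(torch.nn.LayerNorm,),
    layers_to_exclude=None,
)

# decay is expected to hold the Linear weights, e.g. {"0.weight", "2.weight"},
# while no_decay holds the remaining parameters such as biases and LayerNorm
# weights; the exact names depend on how the network registers its parameters.
print(sorted(decay))
print(sorted(no_decay))
```

The two sets can then be used to build separate optimizer parameter groups, one with and one without weight decay.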
Classes
class ModifiedAdamW(torch.optim.AdamW):
Modifies AdamW (Adam with weight decay) so that weight decay is not applied to biases, layer-normalization weights, or, optionally, additional excluded modules.
Methods
def __init__(
self,
network,
lr,
weight_decay,
whitelist_weight_modules,
blacklist_weight_modules,
layers_to_exclude
) -> None
Arguments:
network(torch.nn.Module)
: Network whose parameters will be optimized.
lr(float)
: Learning rate.
weight_decay(float)
: Weight decay coefficient.
whitelist_weight_modules(tuple[Type[torch.nn.Module], ...])
: Modules whose weights should get weight decay. (default: (torch.nn.Linear, ))
blacklist_weight_modules(tuple[Type[torch.nn.Module], ...])
: Modules whose weights should not get weight decay. (default: (torch.nn.LayerNorm, ))
layers_to_exclude(set[str] | None)
: Names of additional layers to exclude, e.g. the last layer of a Q-network. (default: None)
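A hedged construction sketch; the toy Q-network and the module name "2" passed to layers_to_exclude (the final Linear layer of the Sequential) are illustrative assumptions, not part of the emote API:

```python
import torch

from emote.optimizers import ModifiedAdamW

# Toy Q-network used only to show how the optimizer is constructed.
q_network = torch.nn.Sequential(
    torch.nn.Linear(32, 64),
    torch.nn.LayerNorm(64),
    torch.nn.Linear(64, 1),
)

optimizer = ModifiedAdamW(
    network=q_network,
    lr=3e-4,
    weight_decay=0.01,
    whitelist_weight_modules=(torch.nn.Linear,),
    blacklist_weight_modules=(torch.nn.LayerNorm,),
    # Assumed naming: in a Sequential the last layer is registered as "2".
    layers_to_exclude={"2"},
)

# Since ModifiedAdamW subclasses torch.optim.AdamW, it is stepped like any
# other torch optimizer.
optimizer.zero_grad()
loss = q_network(torch.randn(4, 32)).mean()
loss.backward()
optimizer.step()
```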