module emote.models.ensemble
Functions
def truncated_normal_init(m) -> None
Initializes the weights of the given module using a truncated normal distribution.
Arguments:
m(nn.Module)
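A minimal sketch of what a truncated normal initializer of this kind could look like. The function name `truncated_normal_init_sketch`, the choice of standard deviation, and the two-sigma cutoff are assumptions made for illustration; they are not taken from the module's source.

```python
import math

import torch
from torch import nn


def truncated_normal_init_sketch(m: nn.Module) -> None:
    """Hypothetical initializer: std and cutoff below are assumptions."""
    if isinstance(m, nn.Linear):
        # Assumed scale: shrink the std with the fan-in of the layer.
        std = 1.0 / (2.0 * math.sqrt(m.in_features))
        # Draw weights from N(0, std^2), truncated to [-2 * std, 2 * std].
        nn.init.trunc_normal_(m.weight, mean=0.0, std=std, a=-2.0 * std, b=2.0 * std)
        if m.bias is not None:
            nn.init.zeros_(m.bias)


layer = nn.Linear(16, 4)
# nn.Module.apply calls the function on the module and all of its submodules.
layer.apply(truncated_normal_init_sketch)
```

Passing the function to `nn.Module.apply` is the usual way to run an initializer over every submodule of a network.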
Classes
class EnsembleLinearLayer(nn.Module):
Linear layer for ensemble models.
Methods
def __init__(self, num_members, in_size, out_size) -> None
Arguments:
num_members(int)
: the ensemble size
in_size(int)
: the input size of the model
out_size(int)
: the output size of the model
def forward(self, x) -> None
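One common way to build a linear layer for an ensemble is to store one weight matrix per member and apply them all with a single batched matrix multiply. The sketch below illustrates that idea; the class name `EnsembleLinearSketch`, the parameter layout, and the initial values are assumptions, and the sketch returns the output tensor from `forward`.

```python
import torch
from torch import nn


class EnsembleLinearSketch(nn.Module):
    """Hypothetical ensemble layer: one (in_size, out_size) matrix per member."""

    def __init__(self, num_members: int, in_size: int, out_size: int) -> None:
        super().__init__()
        # Assumed layout: the leading dimension indexes the ensemble member.
        self.weight = nn.Parameter(torch.randn(num_members, in_size, out_size) * 0.01)
        self.bias = nn.Parameter(torch.zeros(num_members, 1, out_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (num_members, batch, in_size) -> (num_members, batch, out_size)
        return torch.bmm(x, self.weight) + self.bias


layer = EnsembleLinearSketch(num_members=5, in_size=3, out_size=2)
x = torch.randn(5, 8, 3)
out = layer(x)  # each member processes its own slice x[i]
```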
class EnsembleOfGaussian(nn.Module):
Methods
def __init__(self, *in_size, out_size, device, num_layers, ensemble_size, hidden_size, learn_logvar_bounds, deterministic) -> None
def default_forward(self, x) -> tuple[torch.Tensor, torch.Tensor]
def forward(self, x) -> tuple[torch.Tensor, torch.Tensor]
Computes mean and logvar predictions for the given input.
Arguments:
x(torch.Tensor)
: the input to the model.
Returns:
- (tuple of two tensors): the predicted mean and log variance of the output.
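Models that predict a log variance often keep it inside a bounded range, which is presumably what the `learn_logvar_bounds` constructor argument relates to. The sketch below shows one widely used soft clamp based on softplus; the helper name `bound_logvar` and the bound values are assumptions, and this is not necessarily how `forward` is implemented.

```python
import torch
import torch.nn.functional as F


def bound_logvar(
    logvar: torch.Tensor, min_logvar: torch.Tensor, max_logvar: torch.Tensor
) -> torch.Tensor:
    """Hypothetical soft clamp of logvar into roughly [min_logvar, max_logvar]."""
    # Smoothly cap from above, then smoothly floor from below.
    logvar = max_logvar - F.softplus(max_logvar - logvar)
    logvar = min_logvar + F.softplus(logvar - min_logvar)
    return logvar


raw = torch.tensor([-100.0, 0.0, 100.0])
bounded = bound_logvar(raw, torch.tensor(-10.0), torch.tensor(0.5))
```

Because the clamp is smooth, gradients still flow when a prediction sits near a bound, and the bounds themselves can be made learnable parameters.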
def loss(self, model_in, target) -> tuple[torch.Tensor, dict[str, any]]
Computes Gaussian NLL loss.
Arguments:
model_in(torch.Tensor)
: input tensor.
target(Optional[torch.Tensor])
: target tensor.
Returns:
- (a tuple of tensor and dict): a loss tensor and a dict which includes extra info.
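For reference, a Gaussian negative log-likelihood written in terms of a predicted mean and log variance can be sketched as follows. The function name `gaussian_nll_sketch`, the mean reduction, and the omission of the constant term are assumptions; the actual `loss` method may reduce or weight the terms differently.

```python
import torch


def gaussian_nll_sketch(
    mean: torch.Tensor, logvar: torch.Tensor, target: torch.Tensor
) -> torch.Tensor:
    """Hypothetical Gaussian NLL, dropping the constant 0.5 * log(2 * pi) term."""
    inv_var = torch.exp(-logvar)
    # 0.5 * (log sigma^2 + (y - mu)^2 / sigma^2), averaged over all elements.
    return (0.5 * (logvar + (target - mean) ** 2 * inv_var)).mean()


mean = torch.zeros(4, 2)
logvar = torch.zeros(4, 2)  # unit variance
target = torch.ones(4, 2)
loss = gaussian_nll_sketch(mean, logvar, target)
```

Working with the log variance keeps the predicted variance positive without any explicit constraint.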
def sample(self, model_input, rng) -> torch.Tensor
Samples next observation, reward and terminal from the model using the ensemble.
Arguments:
model_input(torch.Tensor)
: the observation and action.
rng(torch.Generator)
: a random number generator.
Returns:
- predicted observation, rewards, terminal indicator and model state dictionary.
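A sketch of how a prediction might be drawn from an ensemble of Gaussians with an explicit `torch.Generator`. It covers only the sampling step, not the split into observation, reward, and terminal. The function name `sample_from_ensemble_sketch`, the tensor shapes, and the choice of one random member per batch row are assumptions.

```python
import torch


def sample_from_ensemble_sketch(
    mean: torch.Tensor, logvar: torch.Tensor, rng: torch.Generator
) -> torch.Tensor:
    """Hypothetical sampler; mean and logvar are (ensemble_size, batch, out_size)."""
    ensemble_size, batch, _ = mean.shape
    # Pick one ensemble member per batch row.
    member = torch.randint(ensemble_size, (batch,), generator=rng)
    rows = torch.arange(batch)
    chosen_mean = mean[member, rows]
    chosen_std = torch.exp(0.5 * logvar[member, rows])
    # Reparameterized draw: mean + std * standard normal noise.
    noise = torch.randn(chosen_mean.shape, generator=rng)
    return chosen_mean + chosen_std * noise


mean = torch.randn(5, 8, 3)
logvar = torch.zeros(5, 8, 3)
sample_a = sample_from_ensemble_sketch(mean, logvar, torch.Generator().manual_seed(0))
sample_b = sample_from_ensemble_sketch(mean, logvar, torch.Generator().manual_seed(0))
```

Passing the generator explicitly makes the draw reproducible: two generators seeded identically yield the same sample.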
def save(self, save_dir) -> None
Saves the model to the given directory.
Arguments:
save_dir(str)
def load(self, load_dir) -> None
Loads the model from the given path.
Arguments:
load_dir(str)
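Directory-based `save` and `load` methods are commonly thin wrappers around a module's state dict. The sketch below shows that pattern; the helper names `save_sketch` and `load_sketch` and the file name `model.pth` are hypothetical, since the file layout is not given here.

```python
import os
import tempfile

import torch
from torch import nn

MODEL_FILE = "model.pth"  # hypothetical file name, not taken from the library


def save_sketch(model: nn.Module, save_dir: str) -> None:
    """Write the model's parameters into save_dir."""
    os.makedirs(save_dir, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(save_dir, MODEL_FILE))


def load_sketch(model: nn.Module, load_dir: str) -> None:
    """Restore parameters previously written by save_sketch."""
    state = torch.load(os.path.join(load_dir, MODEL_FILE), map_location="cpu")
    model.load_state_dict(state)


src = nn.Linear(4, 2)
dst = nn.Linear(4, 2)
with tempfile.TemporaryDirectory() as tmp:
    save_sketch(src, tmp)
    load_sketch(dst, tmp)  # dst now holds the same parameters as src
```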