module emote.models.ensemble

Functions

def truncated_normal_init(m) -> None

Initializes the weights of the given module using a truncated normal distribution.

Arguments:

  • m(nn.Module): the module whose weights are initialized.

Classes

class EnsembleLinearLayer(nn.Module):

Linear layer for ensemble models.

Methods

def __init__(self, num_members, in_size, out_size) -> None

Arguments:

  • num_members(int): the ensemble size
  • in_size(int): the input size of the layer
  • out_size(int): the output size of the layer
def forward(self, x) -> None

class EnsembleOfGaussian(nn.Module):

Methods

def __init__(
    self,
    *,
    in_size,
    out_size,
    device,
    num_layers,
    ensemble_size,
    hidden_size,
    learn_logvar_bounds,
    deterministic
) -> None
def default_forward(self, x) -> tuple[torch.Tensor, torch.Tensor]
def forward(self, x) -> tuple[torch.Tensor, torch.Tensor]

Computes mean and logvar predictions for the given input.

Arguments:

  • x(torch.Tensor): the input to the model.

Returns:

  • (tuple of two tensors): the predicted mean and log variance of the output.
def loss(self, model_in, target) -> tuple[torch.Tensor, dict[str, Any]]

Computes the Gaussian negative log-likelihood (NLL) loss.

Arguments:

  • model_in(torch.Tensor): input tensor.
  • target(Optional[torch.Tensor]): target tensor.

Returns:

  • (tuple of a tensor and a dict): the loss tensor and a dict containing extra information.
def sample(self, model_input, rng) -> torch.Tensor

Samples the next observation, reward and terminal indicator from the ensemble model.

Arguments:

  • model_input(torch.Tensor): the observation and action.
  • rng(torch.Generator): a random number generator.

Returns:

  • (torch.Tensor): the sampled prediction (next observation, reward and terminal indicator).
def save(self, save_dir) -> None

Saves the model to the given directory.

Arguments:

  • save_dir(str): the directory to save the model to.
def load(self, load_dir) -> None

Loads the model from the given directory.

Arguments:

  • load_dir(str): the directory to load the model from.