module emote.models.model

Classes

class DynamicModel(nn.Module):

Wrapper class for dynamics models, including ensembles. DynamicModel also provides data manipulations that are common when using dynamics models with observations and actions (e.g., predicting delta observations, input normalization).

Methods

def __init__(self, *, model, learned_rewards, obs_process_fn, no_delta_list) -> None

Arguments:

  • model(nn.Module): the model to wrap.
  • learned_rewards(bool): if True, the wrapper considers the last output of the model to correspond to reward predictions.
  • obs_process_fn(Optional[nn.Module]): if provided, observations will be passed through this function before being given to the model.
  • no_delta_list(Optional[list[int]]): if provided, a list of observation dimensions for which the model predicts the actual next observation rather than a delta.
def forward(self, x) -> tuple[torch.Tensor, ...]

Computes the output of the dynamics model.

Arguments:

  • x(torch.Tensor): the input tensor passed to the wrapped model

Returns:

  • (tuple of tensors): the predictions of the wrapped model
def loss(self, obs, next_obs, action, reward) -> tuple[torch.Tensor, dict[str, any]]

Computes the model loss over a batch of transitions.

Arguments:

  • obs(torch.Tensor): current observations
  • next_obs(torch.Tensor): next observations
  • action(torch.Tensor): actions
  • reward(torch.Tensor): rewards

Returns:

  • (tensor and optional dict): the loss tensor and optional info
def sample(
    self,
    action,
    observation,
    rng
) -> tuple[torch.Tensor, Optional[torch.Tensor]]

Samples a simulated transition from the dynamics model. The function first normalizes the inputs to the model and then denormalizes the model output to produce the final output.

Arguments:

  • action(torch.Tensor): the action a_t.
  • observation(torch.Tensor): the observation/state s_t.
  • rng(torch.Generator): a random number generator.

Returns:

  • the predicted next observation and the predicted rewards (None when rewards are not learned).
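The output post-processing that sample describes can be sketched as follows. This is a minimal illustration under stated assumptions, not emote's implementation: the wrapped model is assumed to output normalized targets with the reward (when learned) in the last column, the target normalizer is represented by hypothetical target_mean/target_std tensors, and postprocess_prediction is a hypothetical helper name.

```python
from typing import Optional

import torch


def postprocess_prediction(
    preds: torch.Tensor,
    observation: torch.Tensor,
    target_mean: torch.Tensor,
    target_std: torch.Tensor,
    learned_rewards: bool = True,
    no_delta_list: Optional[list[int]] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Hypothetical helper: map normalized model output to (next_obs, reward)."""
    # Undo target normalization (assumed to cover deltas and reward jointly).
    preds = preds * target_std + target_mean
    # With learned rewards, the last output column is assumed to be the reward.
    if learned_rewards:
        deltas, rewards = preds[:, :-1], preds[:, -1:]
    else:
        deltas, rewards = preds, None
    # Most dimensions are deltas: next_obs = obs + delta.
    next_obs = observation + deltas
    # Dimensions in no_delta_list are predicted directly, so use them as-is.
    for dim in no_delta_list or []:
        next_obs[:, dim] = deltas[:, dim]
    return next_obs, rewards
```

Keeping the reward in the last column lets a single output head serve both targets; slicing it off before the delta addition keeps the reward from being added to an observation.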
def get_model_input(self, obs, action) -> torch.Tensor

The function prepares the input to the neural network model by concatenating observations and actions. If obs_process_fn is given, the observations are passed through it before the concatenation.

Arguments:

  • obs(torch.Tensor): observation tensor
  • action(torch.Tensor): action tensor

Returns:

  • the concatenation of the (processed) observations and the actions
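A minimal sketch of this input preparation, assuming features live in the last tensor dimension; the free function get_model_input here is a hypothetical stand-in for the method, and obs_process_fn can be any callable (an nn.Module qualifies).

```python
from typing import Callable, Optional

import torch


def get_model_input(
    obs: torch.Tensor,
    action: torch.Tensor,
    obs_process_fn: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
) -> torch.Tensor:
    """Sketch: optionally process observations, then concatenate with actions."""
    if obs_process_fn is not None:
        obs = obs_process_fn(obs)
    # Concatenate along the last (feature) dimension.
    return torch.cat([obs, action], dim=-1)
```

For a batch of 4 observations with 3 features and actions with 2 features, the result has shape (4, 5).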
def process_batch(
    self,
    obs,
    next_obs,
    action,
    reward
) -> tuple[torch.Tensor, torch.Tensor]

The function processes the given batch, normalizes inputs and targets, and prepares them for training.

Arguments:

  • obs(torch.Tensor): the observations tensor
  • next_obs(torch.Tensor): the next observation tensor
  • action(torch.Tensor): the actions tensor
  • reward(torch.Tensor): the rewards tensor

Returns:

  • (tuple[torch.Tensor, torch.Tensor]): the training input and target tensors
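The batch processing can be sketched as below. This is an assumption-laden illustration rather than emote's code: targets are taken to be observation deltas (absolute values for no_delta_list dimensions) with the reward appended as the last column, and normalization uses the current batch's statistics in place of the running Normalizer. build_training_pair and _standardize are hypothetical names.

```python
from typing import Optional

import torch


def _standardize(x: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    # Per-feature z-score using this batch's statistics only
    # (a simplification of a running normalizer).
    mean = x.mean(dim=0, keepdim=True)
    std = ((x - mean) ** 2).mean(dim=0, keepdim=True).sqrt()
    return (x - mean) / (std + eps)


def build_training_pair(
    obs: torch.Tensor,
    next_obs: torch.Tensor,
    action: torch.Tensor,
    reward: torch.Tensor,
    learned_rewards: bool = True,
    no_delta_list: Optional[list[int]] = None,
    normalize: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical helper: build (model_input, target) for one training batch."""
    # Targets are observation deltas by default...
    target_obs = next_obs - obs
    # ...except for dimensions the model predicts directly.
    for dim in no_delta_list or []:
        target_obs[:, dim] = next_obs[:, dim]
    model_in = torch.cat([obs, action], dim=-1)
    # Append the reward as the last target column when rewards are learned.
    target = torch.cat([target_obs, reward], dim=-1) if learned_rewards else target_obs
    if normalize:
        model_in, target = _standardize(model_in), _standardize(target)
    return model_in, target
```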
def save(self, save_dir) -> None

Saves the model.

Arguments:

  • save_dir(str): the directory to save the model to
def load(self, load_dir) -> None

Loads the model.

Arguments:

  • load_dir(str): the directory to load the model from
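A directory-based save/load pair might look like the sketch below. It assumes the model is persisted as a PyTorch state_dict in a single file; the file name dynamic_model.pt and the helper names save_model/load_model are hypothetical, and emote's actual file layout may differ.

```python
import os

import torch
from torch import nn

CHECKPOINT_NAME = "dynamic_model.pt"  # hypothetical file name


def save_model(model: nn.Module, save_dir: str) -> None:
    # Create the directory if needed and write the parameters (state_dict) only.
    os.makedirs(save_dir, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(save_dir, CHECKPOINT_NAME))


def load_model(model: nn.Module, load_dir: str) -> None:
    # Load onto CPU first; the caller can move the model to another device afterwards.
    state = torch.load(os.path.join(load_dir, CHECKPOINT_NAME), map_location="cpu")
    model.load_state_dict(state)
```

Saving the state_dict rather than the pickled module keeps checkpoints independent of the class definition's import path.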

class DeterministicModel(nn.Module):

Methods

def __init__(self, in_size, out_size, device, hidden_size, num_hidden_layers) -> None
def forward(self, x) -> torch.Tensor
def loss(self, model_in, target) -> tuple[torch.Tensor, dict[str, any]]
def sample(self, model_input, rng) -> torch.Tensor

Samples next observation, reward and terminal from the model.

Arguments:

  • model_input(torch.Tensor): the concatenated observation and action.
  • rng(torch.Generator): a random number generator.

Returns:

  • predicted observation, rewards, terminal indicator and model state dictionary.
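A deterministic model with the constructor arguments listed above could be sketched as a plain MLP trained with mean-squared error. The ReLU activation, the MSE loss, the info dictionary contents, and the class name DeterministicMLP are assumptions for illustration, not emote's implementation.

```python
import torch
import torch.nn.functional as F
from torch import nn


class DeterministicMLP(nn.Module):
    """Hypothetical sketch of a deterministic dynamics network."""

    def __init__(self, in_size, out_size, device, hidden_size, num_hidden_layers):
        super().__init__()
        layers = []
        size = in_size
        # Stack num_hidden_layers Linear+ReLU blocks (ReLU is an assumption).
        for _ in range(num_hidden_layers):
            layers += [nn.Linear(size, hidden_size), nn.ReLU()]
            size = hidden_size
        layers.append(nn.Linear(size, out_size))
        self.net = nn.Sequential(*layers).to(device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)

    def loss(self, model_in: torch.Tensor, target: torch.Tensor):
        # Mean-squared error between prediction and (normalized) target.
        prediction = self.forward(model_in)
        loss = F.mse_loss(prediction, target)
        return loss, {"loss": loss.item()}

    def sample(self, model_input: torch.Tensor, rng: torch.Generator) -> torch.Tensor:
        # A deterministic model needs no randomness; rng is accepted for
        # interface compatibility and is unused here.
        with torch.no_grad():
            return self.forward(model_input)
```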

class Normalizer:

Class that keeps a running mean and variance and normalizes data accordingly.

Methods

def __init__(self) -> None
def update_stats(self, data) -> None

Updates the stored statistics using the given data.

Arguments:

  • data(torch.Tensor): The data used to compute the statistics.
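One standard way to keep a running per-feature mean and variance is Chan et al.'s parallel update, which merges batch statistics into the stored ones without revisiting earlier data. The sketch below is an assumption about how such an update could work, not emote's Normalizer code; RunningStats is a hypothetical name and it tracks the population (biased) variance.

```python
from typing import Optional

import torch


class RunningStats:
    """Sketch of running per-feature mean/variance via a parallel (Chan et al.) update."""

    def __init__(self) -> None:
        self.count = 0
        self.mean: Optional[torch.Tensor] = None
        self.var: Optional[torch.Tensor] = None

    def update_stats(self, data: torch.Tensor) -> None:
        # Treat the last dimension as features and flatten everything else.
        data = data.reshape(-1, data.shape[-1])
        batch_count = data.shape[0]
        batch_mean = data.mean(dim=0)
        batch_var = ((data - batch_mean) ** 2).mean(dim=0)
        if self.mean is None:
            self.count, self.mean, self.var = batch_count, batch_mean, batch_var
            return
        total = self.count + batch_count
        delta = batch_mean - self.mean
        new_mean = self.mean + delta * batch_count / total
        # Combine sums of squared deviations, then convert back to variance.
        m2 = (
            self.var * self.count
            + batch_var * batch_count
            + delta**2 * self.count * batch_count / total
        )
        self.count, self.mean, self.var = total, new_mean, m2 / total
```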
def normalize(self, val, update_state) -> torch.Tensor

Normalizes the value according to the stored statistics.

Arguments:

  • val(torch.Tensor): The value to normalize.
  • update_state(bool): if True, val is also used to update the stored statistics.

Returns:

  • The normalized value.
def denormalize(self, val) -> torch.Tensor

De-normalizes the value according to the stored statistics.

Arguments:

  • val(torch.Tensor): The value to de-normalize.

Returns:

  • The de-normalized value.
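The relationship between normalize and denormalize can be shown with fixed statistics: normalization is a per-feature z-score and de-normalization is its exact inverse. This sketch omits the update_state behavior, and the EPS guard against zero standard deviation is an assumption.

```python
import torch

EPS = 1e-8  # assumed guard against division by zero for constant features


def normalize(val: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
    # Standardize each feature: z = (x - mean) / std.
    return (val - mean) / (std + EPS)


def denormalize(val: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
    # Inverse of normalize: x = z * std + mean.
    return val * (std + EPS) + mean
```

For x = [[1, 10], [3, 30]] with mean [2, 20] and std [1, 10], both rows normalize to [-1, -1] and [1, 1], and denormalize recovers x.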