module emote.models.model
Classes
class DynamicModel(nn.Module):
Wrapper class for dynamics models, including ensembles. It also provides data manipulations that are common when using dynamics models with observations and actions (e.g., predicting delta observations, input normalization).
Methods
def __init__(self, *, model, learned_rewards, obs_process_fn, no_delta_list) -> None
Arguments:
model(nn.Module)
: the model to wrap.
learned_rewards(bool)
: if True, the wrapper considers the last output of the model to correspond to reward predictions.
obs_process_fn(Optional[nn.Module])
: if provided, observations will be passed through this function before being given to the model.
no_delta_list(Optional[list[int]])
: if provided, represents a list of dimensions over which the model predicts the actual observation and not just a delta.
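The role of no_delta_list can be illustrated with a minimal sketch. The helper name `make_targets` is hypothetical and not part of this module; it only shows how delta targets might be built while keeping selected dimensions as raw next observations.

```python
import torch


def make_targets(obs, next_obs, no_delta_list=None):
    """Hypothetical helper (not the library's code): build observation targets
    as deltas, except for the dimensions in no_delta_list, which keep the raw
    next-observation value."""
    target = next_obs - obs
    if no_delta_list:
        # Overwrite the selected dimensions with the actual next observation.
        target[..., no_delta_list] = next_obs[..., no_delta_list]
    return target


obs = torch.tensor([[1.0, 2.0, 3.0]])
next_obs = torch.tensor([[1.5, 2.5, 10.0]])
target = make_targets(obs, next_obs, no_delta_list=[2])
```

Here the first two dimensions hold the change between steps, while dimension 2 holds the next observation itself.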
def forward(self, x) -> tuple[torch.Tensor, ...]
Computes the output of the dynamics model.
Arguments:
x(torch.Tensor)
: input
Returns:
- (tuple of tensors): predicted tensors
def loss(self, obs, next_obs, action, reward) -> tuple[torch.Tensor, dict[str, any]]
Computes the model loss over a batch of transitions.
Arguments:
obs(torch.Tensor)
: current observations
next_obs(torch.Tensor)
: next observations
action(torch.Tensor)
: actions
reward(torch.Tensor)
: rewards
Returns:
- (tensor and optional dict): the loss tensor and optional info
def sample(
self,
action,
observation,
rng
) -> tuple[torch.Tensor, Optional[torch.Tensor]]
Samples a simulated transition from the dynamics model. The function first normalizes the inputs to the model, and then denormalizes the model output to produce the final output.
Arguments:
action(torch.Tensor)
: the action a_t.
observation(torch.Tensor)
: the observation/state s_t.
rng(torch.Generator)
: a random number generator.
Returns:
- predicted observation and rewards.
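The normalize, predict, denormalize sequence described for sample can be sketched as follows. This is an illustration under stated assumptions, not the wrapper's implementation: the stand-in model, the fixed statistics, and the name `sample_sketch` are all hypothetical, the model is assumed to predict deltas, and the rng argument is omitted because the stand-in model is deterministic.

```python
import torch
from torch import nn

torch.manual_seed(0)
obs_dim, act_dim = 2, 1
model = nn.Linear(obs_dim + act_dim, obs_dim)  # stand-in for a trained dynamics model

# Hypothetical fixed statistics; a real wrapper would track these during training.
in_mean = torch.full((obs_dim + act_dim,), 0.5)
in_std = torch.full((obs_dim + act_dim,), 2.0)
out_mean = torch.full((obs_dim,), 0.1)
out_std = torch.full((obs_dim,), 0.5)


def sample_sketch(action, observation):
    model_in = torch.cat([observation, action], dim=-1)
    model_in = (model_in - in_mean) / in_std          # normalize the model input
    with torch.no_grad():
        delta = model(model_in) * out_std + out_mean  # denormalize the model output
    return observation + delta                        # assumed: the model predicts deltas


observation = torch.randn(4, obs_dim)
action = torch.randn(4, act_dim)
next_observation = sample_sketch(action, observation)
```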
def get_model_input(self, obs, action) -> torch.Tensor
Prepares the input to the neural network model by concatenating observations and actions. If obs_process_fn is given, the observations are processed by that function before the concatenation.
Arguments:
obs(torch.Tensor)
: observation tensor
action(torch.Tensor)
: action tensor
Returns:
- the concatenation of obs and actions
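A minimal sketch of the concatenation described above, assuming the last tensor dimension is the feature dimension. The function name `get_model_input_sketch` and the example `nn.Linear` preprocessor are hypothetical.

```python
import torch
from torch import nn


def get_model_input_sketch(obs, action, obs_process_fn=None):
    # Optionally transform the observations before concatenation.
    if obs_process_fn is not None:
        obs = obs_process_fn(obs)
    # Assumption: features live on the last dimension.
    return torch.cat([obs, action], dim=-1)


obs = torch.randn(4, 3)
action = torch.randn(4, 2)

plain = get_model_input_sketch(obs, action)
processed = get_model_input_sketch(obs, action, obs_process_fn=nn.Linear(3, 8))
```

Without a preprocessor the input width is the sum of the observation and action widths; with one, the observation part takes the preprocessor's output width instead.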
def process_batch(
self,
obs,
next_obs,
action,
reward
) -> tuple[torch.Tensor, torch.Tensor]
Processes the given batch, normalizes the inputs and targets, and prepares them for training.
Arguments:
obs(torch.Tensor)
: the observations tensor
next_obs(torch.Tensor)
: the next observation tensor
action(torch.Tensor)
: the actions tensor
reward(torch.Tensor)
: the rewards tensor
Returns:
- (tuple[torch.Tensor, torch.Tensor]): the training input and target tensors
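One way such batch processing might look is sketched below. The assumptions are labeled in the code: targets are deltas, rewards are appended as the last target column when rewards are learned, and per-batch statistics stand in for whatever running statistics the wrapper actually keeps. The name `process_batch_sketch` is hypothetical.

```python
import torch


def process_batch_sketch(obs, next_obs, action, reward, learned_rewards=True, eps=1e-8):
    model_in = torch.cat([obs, action], dim=-1)
    target = next_obs - obs  # assumed: the model is trained on delta observations
    if learned_rewards:
        # Assumed: the reward becomes the last column of the target.
        target = torch.cat([target, reward], dim=-1)
    # Stand-in normalization using per-batch statistics.
    model_in = (model_in - model_in.mean(dim=0)) / (model_in.std(dim=0) + eps)
    target = (target - target.mean(dim=0)) / (target.std(dim=0) + eps)
    return model_in, target


torch.manual_seed(0)
obs, next_obs = torch.randn(8, 3), torch.randn(8, 3)
action, reward = torch.randn(8, 2), torch.randn(8, 1)
model_in, target = process_batch_sketch(obs, next_obs, action, reward)
```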
def save(self, save_dir) -> None
Saves the model.
Arguments:
save_dir(str)
: the directory to save the model to
def load(self, load_dir) -> None
Loads the model.
Arguments:
load_dir(str)
: the directory to load the model from
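A directory-based save and load round trip can be sketched with standard PyTorch calls. This is an assumption about the general pattern, not the module's actual file layout: the file name `model.pt` and the helper names are hypothetical.

```python
import os
import tempfile

import torch
from torch import nn

FILE_NAME = "model.pt"  # hypothetical file name, not taken from the library


def save_sketch(model, save_dir):
    os.makedirs(save_dir, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(save_dir, FILE_NAME))


def load_sketch(model, load_dir):
    state = torch.load(os.path.join(load_dir, FILE_NAME))
    model.load_state_dict(state)


src, dst = nn.Linear(3, 2), nn.Linear(3, 2)
with tempfile.TemporaryDirectory() as tmp:
    save_sketch(src, tmp)
    load_sketch(dst, tmp)  # dst now holds the same parameters as src
```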
class DeterministicModel(nn.Module):
Methods
def __init__(self, in_size, out_size, device, hidden_size, num_hidden_layers) -> None
def forward(self, x) -> torch.Tensor
def loss(self, model_in, target) -> tuple[torch.Tensor, dict[str, any]]
def sample(self, model_input, rng) -> torch.Tensor
Samples next observation, reward and terminal from the model.
Arguments:
model_input(torch.Tensor)
: the observation and action.
rng(torch.Generator)
: a random number generator.
Returns:
- predicted observation, rewards, terminal indicator and model state dictionary.
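For orientation, a deterministic model with this constructor signature could be a plain multilayer perceptron trained with a mean squared error loss. The sketch below is an assumption: the class name `MLPSketch`, the ReLU activations, the default sizes, and the `loss_value` info key are illustrative and not taken from the library.

```python
import torch
from torch import nn


class MLPSketch(nn.Module):
    """Illustrative stand-in, not the library's DeterministicModel."""

    def __init__(self, in_size, out_size, device="cpu", hidden_size=64, num_hidden_layers=2):
        super().__init__()
        layers = []
        width = in_size
        for _ in range(num_hidden_layers):
            layers += [nn.Linear(width, hidden_size), nn.ReLU()]
            width = hidden_size
        layers.append(nn.Linear(width, out_size))
        self.net = nn.Sequential(*layers).to(device)

    def forward(self, x):
        return self.net(x)

    def loss(self, model_in, target):
        # Assumed loss: mean squared error between prediction and target.
        value = nn.functional.mse_loss(self.forward(model_in), target)
        return value, {"loss_value": value.item()}  # hypothetical info key


torch.manual_seed(0)
net = MLPSketch(in_size=5, out_size=4)
model_in, target = torch.randn(8, 5), torch.randn(8, 4)
prediction = net(model_in)
loss_value, info = net.loss(model_in, target)
```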
class Normalizer:
Class that keeps a running mean and variance and normalizes data accordingly.
Methods
def __init__(self) -> None
def update_stats(self, data) -> None
Updates the stored statistics using the given data.
Arguments:
data(torch.Tensor)
: The data used to compute the statistics.
def normalize(self, val, update_state) -> torch.Tensor
Normalizes the value according to the stored statistics.
Arguments:
val(torch.Tensor)
: The value to normalize.
update_state(bool)
: Whether to update the stored statistics.
Returns:
- The normalized value.
def denormalize(self, val) -> torch.Tensor
De-normalizes the value according to the stored statistics.
Arguments:
val(torch.Tensor)
: The value to de-normalize.
Returns:
- The de-normalized value.
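A running mean and variance can be maintained batch by batch with the parallel variance update (Chan et al.). The sketch below illustrates that technique under stated assumptions; it is not the library's Normalizer. The class name `RunningNormalizer`, the `eps` parameter, the use of population variance, and updating the statistics with `val` when `update_state` is True are all assumptions.

```python
import torch


class RunningNormalizer:
    """Illustrative running-statistics normalizer (not the library's class)."""

    def __init__(self, eps: float = 1e-8):
        self.count = 0
        self.mean = None
        self.m2 = None  # running sum of squared deviations from the mean
        self.eps = eps

    def update_stats(self, data):
        n_b = data.shape[0]
        mean_b = data.mean(dim=0)
        m2_b = ((data - mean_b) ** 2).sum(dim=0)
        if self.count == 0:
            self.count, self.mean, self.m2 = n_b, mean_b, m2_b
            return
        # Merge the batch statistics into the running statistics.
        total = self.count + n_b
        delta = mean_b - self.mean
        self.mean = self.mean + delta * n_b / total
        self.m2 = self.m2 + m2_b + delta ** 2 * self.count * n_b / total
        self.count = total

    def _std(self):
        # Population variance plus eps to avoid division by zero.
        return torch.sqrt(self.m2 / self.count + self.eps)

    def normalize(self, val, update_state=True):
        if update_state:
            self.update_stats(val)
        return (val - self.mean) / self._std()

    def denormalize(self, val):
        return val * self._std() + self.mean


torch.manual_seed(0)
data = torch.randn(100, 3) * 5.0 + 2.0
norm = RunningNormalizer()
norm.update_stats(data[:40])
norm.update_stats(data[40:])
restored = norm.denormalize(norm.normalize(data, update_state=False))
```

Feeding the data in two batches yields the same mean and variance as computing them over the full data at once, and denormalize inverts normalize.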