module emote.models.model_env

Classes

class ModelEnv:

Wraps a dynamics model into a gym-like environment.

Methods

def __init__(
    self,
    *,
    num_envs,
    model,
    termination_fn,
    reward_fn,
    generator,
    input_key
) -> None

Arguments:

  • num_envs(int): the number of envs to simulate in parallel (batch_size).
  • model(DynamicModel): the dynamics model to wrap.
  • termination_fn(TermFnType): a function that receives observations and returns a boolean flag per env indicating whether the episode should end.
  • reward_fn(Optional[RewardFnType]): a function that receives actions and observations and returns the resulting reward. If None, the reward is predicted by the model.
  • generator(Optional[torch.Generator]): a torch random number generator.
  • input_key(str): the observation key used as the model input.
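A framework-free sketch of the two callback signatures the constructor expects. These are hypothetical examples, not part of emote: real callbacks receive torch.Tensors, while plain lists stand in for batched values here.

```python
# Hypothetical termination/reward callbacks matching the argument
# descriptions above. Each receives a batch (one entry per parallel env).

def termination_fn(observations):
    # One boolean per env: end the episode once the first observation
    # feature leaves the interval [-1, 1].
    return [abs(obs[0]) > 1.0 for obs in observations]

def reward_fn(actions, observations):
    # One reward per env: penalize distance from the origin.
    return [-sum(x * x for x in obs) ** 0.5 for obs in observations]
```

With `reward_fn=None`, the wrapped model is instead expected to predict the reward itself, as noted under `step` below.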

def reset(self, initial_obs_batch, len_rollout) -> None

Resets the model environment.

Arguments:

  • initial_obs_batch(torch.Tensor): a batch of initial observations.
  • len_rollout(int): the max length of the model rollout.

def step(self, actions) -> tuple[Tensor, Tensor, Tensor, dict[str, Tensor]]

Steps the model environment with the given batch of actions.

Arguments:

  • actions(np.ndarray): the actions for each "episode" to roll out. The shape must be batch_size x dim_actions. If an np.ndarray is given, it is converted to a torch.Tensor and moved to the model device.

Returns:

  • (tuple[Tensor, Tensor, Tensor, dict[str, Tensor]]): the predicted next observation, reward, done flag, and an info dict. The done flag and reward are computed using the termination_fn and reward_fn passed to the constructor; the reward can also be predicted by the model.
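The reset/step contract above can be sketched with a toy stand-in. This is not the real ModelEnv: the "model" is a hypothetical function mapping (observation, action) to the next observation, plain lists replace torch.Tensors, and single-env batching details are simplified.

```python
# Toy illustration of the gym-like reset/step semantics described above.
class ToyModelEnv:
    def __init__(self, model, termination_fn, reward_fn):
        self.model = model
        self.termination_fn = termination_fn
        self.reward_fn = reward_fn

    def reset(self, initial_obs_batch, len_rollout):
        # Store the starting observations and cap the rollout length.
        self.obs = initial_obs_batch
        self.len_rollout = len_rollout
        self.steps = 0

    def step(self, actions):
        # Predict the next observation for each parallel env.
        next_obs = [self.model(o, a) for o, a in zip(self.obs, actions)]
        rewards = self.reward_fn(actions, next_obs)
        dones = self.termination_fn(next_obs)
        self.steps += 1
        if self.steps >= self.len_rollout:
            dones = [True] * len(dones)  # truncate at the max rollout length
        self.obs = next_obs
        return next_obs, rewards, dones, {}
```

A rollout then alternates `step` calls until every done flag is set or the max rollout length is reached.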

def dict_step(
    self,
    actions
) -> tuple[dict[AgentId, DictObservation], dict[str, float]]

Steps the gym-like model environment with a dict of actions.

Arguments:

  • actions(dict[AgentId, DictResponse]): the actions keyed by agent id.

Returns:

  • (tuple[dict[AgentId, DictObservation], dict[str, float]]): the predicted next observations keyed by agent id, and a dict of scalar episode metrics.

def dict_reset(self, obs, len_rollout) -> dict[AgentId, DictObservation]

Resets the model environment.

Arguments:

  • obs(torch.Tensor): the initial observations.
  • len_rollout(int): the max rollout length.

Returns:

  • (dict[AgentId, DictObservation]): the formatted initial observations keyed by agent id.
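The dict-based variants wrap the batched interface with agent-keyed dicts. A hypothetical sketch of that idea, with placeholder agent ids and plain values standing in for emote's DictResponse/DictObservation types:

```python
# Hypothetical illustration of the dict_step pattern: unpack agent-keyed
# actions into a batch, take one batched step, then repack per-agent
# results alongside scalar metrics.
def dict_step(env, actions):
    agent_ids = sorted(actions)
    batch = [actions[aid] for aid in agent_ids]
    # env is assumed to follow the step() contract documented above.
    next_obs, rewards, dones, _ = env.step(batch)
    observations = {aid: obs for aid, obs in zip(agent_ids, next_obs)}
    metrics = {"reward/mean": sum(rewards) / len(rewards)}
    return observations, metrics
```

`dict_reset` follows the same pattern in reverse: it calls the batched `reset` and then formats the initial observations as an agent-keyed dict.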