module emote.models.model_env
Classes
class ModelEnv:
Wraps a dynamics model into a gym-like environment.
Methods
def __init__(
self,
*,
num_envs,
model,
termination_fn,
reward_fn,
generator,
input_key
) -> None
Arguments:
num_envs(int)
: the number of envs to simulate in parallel (batch_size).
model(DynamicModel)
: the dynamic model to wrap.
termination_fn(TermFnType)
: a function that receives observations and returns a boolean flag indicating whether the episode should end.
reward_fn(Optional[RewardFnType])
: a function that receives actions and observations and returns the value of the resulting reward in the environment.
generator(Optional[torch.Generator])
: a torch random number generator.
input_key(str)
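To illustrate the shapes the termination_fn and reward_fn callbacks are expected to handle, here is a minimal sketch using plain Python lists as stand-ins for batched torch.Tensors; the example rules (threshold of 1.0, quadratic action cost) are hypothetical, not part of emote:

```python
# Hypothetical stand-ins for the termination_fn and reward_fn callbacks.
# Plain Python lists substitute for torch.Tensors to keep the sketch
# self-contained; a real ModelEnv receives batched tensors instead.

def termination_fn(observations):
    # One done flag per env in the batch: here an episode ends when the
    # first observation feature exceeds 1.0 (an arbitrary example rule).
    return [obs[0] > 1.0 for obs in observations]

def reward_fn(actions, observations):
    # One reward per env: here a simple negative action-magnitude cost.
    return [-sum(a * a for a in act) for act in actions]

batch_obs = [[0.5, 0.0], [1.5, 0.2]]  # batch_size=2, obs_dim=2
batch_act = [[0.1], [0.3]]            # batch_size=2, dim_actions=1
print(termination_fn(batch_obs))      # [False, True]
print(reward_fn(batch_act, batch_obs))
```

Both callbacks are applied batch-wise, returning one value per simulated env.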
def reset(self, initial_obs_batch, len_rollout) -> None
Resets the model environment.
Arguments:
initial_obs_batch(torch.Tensor)
: a batch of initial observations.
len_rollout(int)
: the max length of the model rollout.
def step(self, actions) -> tuple[Tensor, Tensor, Tensor, dict[str, Tensor]]
Steps the model environment with the given batch of actions.
Arguments:
actions(np.ndarray)
: the actions for each "episode" to roll out. Shape must be batch_size x dim_actions. If a np.ndarray is given, it is converted to a torch.Tensor and sent to the model device.
Returns:
- (tuple[Tensor, Tensor, Tensor, dict[str, Tensor]]): the predicted next observation, reward, done flag, and an info dict. The done flag and rewards are computed using the termination_fn and reward_fn passed to the constructor; the rewards can also be predicted by the model.
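The reset/step cycle above can be sketched with a toy stand-in that mimics ModelEnv's gym-like interface; ToyModelEnv and its additive dynamics are hypothetical, not part of emote:

```python
# ToyModelEnv is a hypothetical stand-in mimicking ModelEnv's reset/step
# interface, using one scalar float per env instead of torch.Tensors.

class ToyModelEnv:
    def reset(self, initial_obs_batch, len_rollout):
        self.obs = list(initial_obs_batch)
        self.steps_left = len_rollout

    def step(self, actions):
        # Toy dynamics: next_obs = obs + action.
        self.obs = [o + a for o, a in zip(self.obs, actions)]
        self.steps_left -= 1
        rewards = [-abs(o) for o in self.obs]           # toy reward_fn
        dones = [self.steps_left <= 0] * len(self.obs)  # toy termination_fn
        return self.obs, rewards, dones, {}

env = ToyModelEnv()
env.reset([0.0, 1.0], len_rollout=2)
obs, rewards, dones, info = env.step([0.5, -0.5])
print(obs, rewards, dones)  # [0.5, 0.5] [-0.5, -0.5] [False, False]
```

As in the real ModelEnv, every episode in the batch is stepped in lockstep, and the rollout ends when len_rollout steps have been taken or the termination function fires.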
def dict_step(
self,
actions
) -> tuple[dict[AgentId, DictObservation], dict[str, float]]
Steps the gym-like model environment with a dict of actions.
Arguments:
actions(dict[AgentId, DictResponse])
: the dict actions.
Returns:
- (tuple[dict[AgentId, DictObservation], dict[str, float]]): the predicted next dict observation, reward, and done flag.
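The agent-keyed dict convention can be sketched as follows; toy_dict_step, the "list_data" field, and the metric name are illustrative assumptions, not emote's actual implementation:

```python
# Hypothetical sketch of the agent-keyed dict convention used by
# dict_step: actions arrive as {agent_id: response} and observations
# are returned as {agent_id: observation}. Field names are illustrative.

def toy_dict_step(actions):
    next_obs = {}
    for agent_id, response in actions.items():
        # Toy dynamics: echo each agent's action back as its next obs.
        next_obs[agent_id] = {"obs": response["list_data"]}
    metrics = {"episode/reward": 0.0}  # illustrative scalar metrics dict
    return next_obs, metrics

obs, metrics = toy_dict_step({0: {"list_data": [0.1]},
                              1: {"list_data": [0.2]}})
```

Each agent id maps to its own response on the way in and its own observation on the way out, so a batch of parallel envs is represented as one dict per direction.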
def dict_reset(self, obs, len_rollout) -> dict[AgentId, DictObservation]
Resets the model env.
Arguments:
obs(torch.Tensor)
: the initial observations.
len_rollout(int)
: the max rollout length.
Returns:
- (dict[AgentId, DictObservation]): the formatted initial observation.