module emote.proxies

Proxies are bridges between the world the agent acts in and the algorithm training loop.

Classes

class AgentProxy(Protocol):

The interface between the agent in the game and the network used during training.

Methods

def __call__(self, obserations) -> Dict[AgentId, DictResponse]

Takes observations for the active agents and returns the relevant network output.

Arguments:

  • obserations(Dict[AgentId, DictObservation])

def policy(self) -> nn.Module
def input_names(self) -> tuple[str, ...]
def output_names(self) -> tuple[str, ...]
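
As a rough illustration of what satisfying this interface can look like, the sketch below defines a local mirror of the protocol and a trivial proxy that returns an all-zero action for every active agent. The names `AgentProxyLike`, `ZeroActionProxy`, `Observation`, and `Response` are hypothetical stand-ins, not part of the library; plain dicts replace `DictObservation` and `DictResponse`, and `policy` returns `None` instead of an `nn.Module` so the example has no dependencies.

```python
from typing import Any, Dict, Protocol, Tuple, runtime_checkable

# Hypothetical stand-ins for the library's AgentId, DictObservation and
# DictResponse types; plain dicts keep the sketch dependency-free.
AgentId = int
Observation = Dict[str, list]
Response = Dict[str, list]


@runtime_checkable
class AgentProxyLike(Protocol):
    """Local mirror of the AgentProxy interface described above."""

    def __call__(
        self, observations: Dict[AgentId, Observation]
    ) -> Dict[AgentId, Response]: ...

    def policy(self) -> Any: ...

    def input_names(self) -> Tuple[str, ...]: ...

    def output_names(self) -> Tuple[str, ...]: ...


class ZeroActionProxy:
    """Returns an all-zero action for every active agent (illustration only)."""

    def __init__(self, action_size: int) -> None:
        self._action_size = action_size

    def __call__(self, observations):
        # One response per agent id found in the incoming observations.
        return {
            agent_id: {"actions": [0.0] * self._action_size}
            for agent_id in observations
        }

    def policy(self) -> Any:
        return None  # a real proxy would return its nn.Module here

    def input_names(self):
        return ("obs",)

    def output_names(self):
        return ("actions",)


proxy = ZeroActionProxy(action_size=2)
responses = proxy({0: {"obs": [0.1, 0.2]}, 7: {"obs": [0.3, 0.4]}})
```

Because the interface is a `Protocol`, the proxy does not need to inherit from it; having the right methods is enough for structural checks.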

class MemoryProxy(Protocol):

The interface between the agent in the game and the memory buffer the network trains from.

Methods

def add(self, observations, responses) -> None

Stores episodes in the memory buffer used for training. This is useful, for example, when data collection runs from a checkpointed model on another machine.

Arguments:

  • observations(Dict[AgentId, DictObservation])
  • responses(Dict[AgentId, DictResponse])
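
A minimal sketch of an object with this `add` method is shown below, assuming a plain in-process list as the "buffer". `ListMemoryProxy`, `Observation`, and `Response` are hypothetical names used only for illustration; a real memory would typically write into a replay buffer or table. The sketch also assumes that an agent may appear in `observations` without a matching entry in `responses`, for instance on its last step.

```python
from collections import defaultdict
from typing import Dict, List, Tuple

# Hypothetical stand-ins for AgentId, DictObservation and DictResponse.
AgentId = int
Observation = Dict[str, list]
Response = Dict[str, list]


class ListMemoryProxy:
    """Keeps each agent's (observation, response) pairs in arrival order."""

    def __init__(self) -> None:
        self.transitions: Dict[
            AgentId, List[Tuple[Observation, Response]]
        ] = defaultdict(list)

    def add(self, observations, responses) -> None:
        for agent_id, observation in observations.items():
            # Assumption of this sketch: a missing response is stored
            # as an empty dict rather than raising an error.
            response = responses.get(agent_id, {})
            self.transitions[agent_id].append((observation, response))


memory = ListMemoryProxy()
memory.add(
    {0: {"obs": [0.1]}, 1: {"obs": [0.2]}},
    {0: {"actions": [1.0]}, 1: {"actions": [0.0]}},
)
# Agent 1 sends a final observation with no response attached.
memory.add({1: {"obs": [0.3]}}, {})
```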

class GenericAgentProxy(AgentProxy):

Observations are dicts that contain multiple input and output keys. For example, a policy might take both "obs" and "goal" as inputs and output "actions". To invoke the network correctly, this proxy is responsible for collating the inputs and decollating the outputs per agent.
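
The collate and decollate steps described above can be sketched in plain Python. This is not the class's actual implementation: `collate`, `decollate`, and `toy_policy` are hypothetical helpers, nested lists stand in for tensors, and the "policy" simply returns `goal - obs` for each agent.

```python
from typing import Dict, List, Sequence, Tuple

AgentId = int
Observation = Dict[str, List[float]]  # hypothetical stand-in for DictObservation


def collate(
    observations: Dict[AgentId, Observation], input_keys: Sequence[str]
) -> Tuple[Tuple[AgentId, ...], Dict[str, List[List[float]]]]:
    """Stack each input key across agents, remembering the agent order."""
    agent_ids = tuple(observations)
    batches = {
        key: [observations[agent_id][key] for agent_id in agent_ids]
        for key in input_keys
    }
    return agent_ids, batches


def decollate(
    agent_ids: Sequence[AgentId],
    outputs: Dict[str, List[List[float]]],
    output_keys: Sequence[str],
) -> Dict[AgentId, Dict[str, List[float]]]:
    """Split batched outputs back into one dict per agent."""
    return {
        agent_id: {key: outputs[key][row] for key in output_keys}
        for row, agent_id in enumerate(agent_ids)
    }


def toy_policy(obs_batch, goal_batch):
    """Stand-in for a network: the action is the vector from obs to goal."""
    return {
        "actions": [
            [g - o for o, g in zip(obs, goal)]
            for obs, goal in zip(obs_batch, goal_batch)
        ]
    }


observations = {
    3: {"obs": [1.0, 2.0], "goal": [2.0, 2.0]},
    9: {"obs": [0.0, 0.0], "goal": [1.0, -1.0]},
}
agent_ids, batches = collate(observations, input_keys=("obs", "goal"))
outputs = toy_policy(batches["obs"], batches["goal"])
responses = decollate(agent_ids, outputs, output_keys=("actions",))
```

Keeping the agent order from `collate` is what lets `decollate` route row `i` of every output back to the right agent.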

Methods

def __init__(
    self,
    policy,
    device,
    input_keys,
    output_keys,
    uses_logprobs,
    spaces
) -> None

Handle multi-input multi-output policy networks.

Arguments:

  • policy(nn.Module): The neural network policy that takes observations and returns actions.
  • device(torch.device): The device to run the policy on.
  • input_keys(tuple): Keys specifying what fields from the observation to pass to the policy.
  • output_keys(tuple): Keys for the fields in the output dictionary that the policy is responsible for.
  • uses_logprobs(bool) (default: True)
  • spaces(MDPSpace | None): A utility for managing observation and action spaces, for validation.

def __call__(self, observations) -> dict[AgentId, DictResponse]

Runs the policy and returns the actions.

Arguments:

  • observations(dict[AgentId, DictObservation])

def input_names(self) -> tuple[str, ...]
def output_names(self) -> tuple[str, ...]
def policy(self) -> nn.Module