module emote.algorithms.genrl.wrappers

Classes

class DecoderWrapper(nn.Module):

Methods

def __init__(self, decoder, condition_fn, latent_multiplier) -> None
def forward(self, latent, observation) -> torch.Tensor

Runs the decoder.

Arguments:

  • latent (torch.Tensor): batch x latent_size
  • observation (torch.Tensor): batch x obs_size

Returns:

  • the sample (batch x data_size)
def load_state_dict(self, state_dict, strict) -> None

class EncoderWrapper(nn.Module):

Methods

def __init__(self, encoder, condition_fn) -> None
def forward(self, action, observation) -> torch.Tensor

Runs the encoder.

Arguments:

  • action (torch.Tensor): batch x data_size
  • observation (torch.Tensor): batch x obs_size

Returns:

  • the mean of the latent encoding (batch x latent_size)
def load_state_dict(self, state_dict, strict) -> None

class PolicyWrapper(nn.Module):

Methods

def __init__(self, decoder, policy) -> None
def forward(self, obs, epsilon) -> None