module emote.algorithms.genrl.wrappers
Classes
class DecoderWrapper(nn.Module):
Methods
def __init__(self, decoder, condition_fn, latent_multiplier) -> None
def forward(self, latent, observation) -> torch.Tensor
Run the decoder.
Arguments:
latent(torch.Tensor)
: batch x latent_size
observation(torch.Tensor)
: batch x obs_size
Returns:
- the sample (batch x data_size)
def load_state_dict(self, state_dict, strict) -> None
class EncoderWrapper(nn.Module):
Methods
def __init__(self, encoder, condition_fn) -> None
def forward(self, action, observation) -> torch.Tensor
Run the encoder.
Arguments:
action(torch.Tensor)
: batch x data_size
observation(torch.Tensor)
: batch x obs_size
Returns:
- the mean (batch x data_size)
def load_state_dict(self, state_dict, strict) -> None