package emote.memory

This module contains all the major building blocks for our memory implementation. The memory was developed in the same period as DeepMind's Reverb <https://www.deepmind.com/open-source/reverb>_ and shares naming with it, which in turn borrows from database terminology. Unlike Reverb, we do not have RateSamplers (though they could be added), and we do not share data between ArrayTables.

The goal of the memory is to provide a unified interface for all types of machine learning tasks. This is achieved by focusing on configuration and pluggability over code-driven functionality.

Currently, there are three main points of customization:

  • Shape and type of data
  • Insertion, sampling, and eviction
  • Data transformation and generation

High-level parts

ArrayTable

A table is a data structure containing a specific type of data that shares the same high-level structure.

Columns and Virtual Columns

A column is storage for a specific type of data where each item has the same shape and type. A virtual column is like a column, but it references another column and synthesizes or modifies data derived from it. For example, dones and masks are synthetic data based only on indices.
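
To make this concrete, here is a minimal sketch of what column definitions can look like. The class layout and field names below are illustrative assumptions, not the exact emote.memory API:

from dataclasses import dataclass
from typing import Callable, Optional, Tuple

import numpy as np


@dataclass
class Column:
    """Storage description for one kind of data: a name plus a fixed
    per-item shape and dtype."""

    name: str
    shape: Tuple[int, ...]
    dtype: type


@dataclass
class VirtualColumn(Column):
    """References another column and synthesizes its data from it,
    e.g. deriving dones or masks purely from episode indices."""

    target_name: str = ""
    # Hypothetical hook invoked at sample time to build the data.
    mapper: Optional[Callable] = None


obs_column = Column(name="obs", shape=(64,), dtype=np.float32)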

Adaptors

Adaptors are another approach to virtual columns, but are better suited to transforming the whole batch, such as scaling or reshaping specific data. Since this step occurs after the data has been converted to tensors, the full power of PyTorch is available here and gradients will be correctly tracked.
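
As a sketch, an adaptor can be modeled as a callable over the sampled batch; the (batch, count, sequence_length) signature below is an assumption for illustration:

import torch


class RewardScaleAdaptor:
    """Hypothetical adaptor that rescales the rewards of a sampled batch.

    Because the batch already holds tensors, autograd will track this
    scaling like any other differentiable operation.
    """

    def __init__(self, scale: float):
        self.scale = scale

    def __call__(self, batch: dict, count: int, sequence_length: int) -> dict:
        batch["rewards"] = batch["rewards"] * self.scale
        return batch


batch = {"rewards": torch.ones(20, 8)}
batch = RewardScaleAdaptor(0.1)(batch, count=20, sequence_length=8)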

Strategies, Samplers and Ejectors

Strategies are based on the delegate pattern, where we can inject implementation details through objects instead of using inheritance. Strategies define the API for sampling and ejection from memories, and are queried from the table upon sampling and insertion.

Samplers and Ejectors track the data (but do not own it!). They are used by the table for sampling and ejection based on the policy they implement. Currently we have Fifo and Uniform samplers and ejectors, but one could have prioritized samplers/ejectors, etc.
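
Because strategies are delegates, adding a new policy means implementing the sampling interface and handing the object to the table. A minimal sketch, assuming track/forget/sample hooks (the real interface may differ):

import random


class IllustrativeUniformSampler:
    """Tracks identities that live in the table (without owning their
    data) and draws uniformly among them."""

    def __init__(self):
        self._identities: list[int] = []

    def track(self, identity: int, sequence_length: int) -> None:
        self._identities.append(identity)

    def forget(self, identity: int) -> None:
        self._identities.remove(identity)

    def sample(self, count: int, sequence_length: int) -> list[int]:
        # Uniform with replacement; a prioritized sampler would weight this.
        return random.choices(self._identities, k=count)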

Proxy Wrappers

Wrappers live around the memory proxy and extend functionality. This is a great point for data conversion, validation, and logging.
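
A minimal sketch of such a wrapper, assuming a dict-like rewards field on DictObservation; everything not overridden is forwarded to the inner proxy:

import math


class ValidatingProxyWrapper:
    """Hypothetical wrapper that rejects non-finite rewards before they
    reach the wrapped proxy."""

    def __init__(self, inner):
        self._inner = inner

    def add(self, observations, responses):
        for agent_id, obs in observations.items():
            # Reading rewards off the observation like this assumes a
            # dict-like 'rewards' field on DictObservation.
            for key, reward in (getattr(obs, "rewards", None) or {}).items():
                if reward is not None and not math.isfinite(reward):
                    raise ValueError(f"non-finite reward {key!r} from agent {agent_id}")
        self._inner.add(observations, responses)

    def __getattr__(self, name):
        # Forward everything else (size, store, is_initial, ...) to the inner proxy.
        return getattr(self._inner, name)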

Classes

class MemoryTable(Protocol):

Fields

  • adaptors: List[Adaptor]

Methods

def sample(self, count, sequence_length) -> SampleResult

Sample COUNT traces from the memory, each consisting of SEQUENCE_LENGTH frames.

The data is transposed in an SoA (structure-of-arrays) fashion, since this is both easier to store and easier to consume.

Arguments:

  • count(int)
  • sequence_length(int)
def size(self) -> int

Query the number of elements currently in the memory.

def full(self) -> bool

Query whether the memory is filled.

def add_sequence(self, identity, sequence) -> None

Add a fully terminated sequence to the memory.

Arguments:

  • identity(int)
  • sequence
def store(self, path, version) -> bool

Persist the whole table and all metadata to the designated path.

Arguments:

  • path(str)
  • version(TableSerializationVersion)
def restore(self, path, override_version) -> bool

Restore the data table from the provided path. This also clears the data stores.

Arguments:

  • path(str)
  • override_version(TableSerializationVersion | None)
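
A short sketch of driving the protocol from a training loop, using only the methods documented above (the import path is an assumption):

from emote.memory import MemoryTable  # import path assumed


def sample_when_ready(table: MemoryTable, count: int, sequence_length: int):
    # Guard against sampling an under-filled memory (see MemoryWarmup below).
    if table.size() < count * sequence_length:
        return None
    # The result is SoA-transposed: one array per column, covering
    # `count` traces of `sequence_length` frames each.
    return table.sample(count, sequence_length)


def checkpoint(table: MemoryTable, path: str, version) -> bool:
    # Persists the table and metadata; restore(path, None) is the inverse.
    return table.store(path, version)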

class MemoryTableProxy:

The sequence builder wraps a sequence-based memory to build full episodes from [identity, observation] data.

Not thread safe.

Methods

def __init__(
    self,
    memory_table,
    minimum_length_threshold,
    use_terminal,
    *name
) -> None

Arguments:

  • memory_table(MemoryTable)
  • minimum_length_threshold(Optional[int])
  • use_terminal(bool)
  • name(str)
def name(self) -> str
def size(self) -> int
def resize(self, new_size) -> None
def store(self, path) -> None
def is_initial(self, identity) -> bool

Returns true if the identity is not already used in a partial sequence. Does not validate whether the identity is associated with a complete episode.

Arguments:

  • identity(int)
def add(self, observations, responses) -> None
def timers(self) -> None
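
A sketch of constructing a proxy around an existing table, with illustrative argument values (the import path is an assumption):

from emote.memory import MemoryTable, MemoryTableProxy  # import path assumed


def make_proxy(table: MemoryTable) -> MemoryTableProxy:
    # Illustrative values; see the argument list above. add() will then
    # accumulate [identity, observation] data into full episodes and
    # flush completed ones into the wrapped table.
    return MemoryTableProxy(
        memory_table=table,
        minimum_length_threshold=1,
        use_terminal=True,
    )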

class MemoryLoader:

Methods

def __init__(
    self,
    memory_table,
    rollout_count,
    rollout_length,
    size_key,
    data_group
) -> None
def is_ready(self) -> bool

True if the data loader has enough data to start providing batches.
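
A sketch of configuring a loader; the argument values are illustrative, and the import path is an assumption:

from emote.memory import MemoryLoader, MemoryTable  # import path assumed


def make_loader(table: MemoryTable) -> MemoryLoader:
    # Illustrative configuration: batches of 20 traces, 8 frames each,
    # published under the "default" data group.
    return MemoryLoader(
        memory_table=table,
        rollout_count=20,
        rollout_length=8,
        size_key="batch_size",
        data_group="default",
    )

Consumers should gate on is_ready() before pulling batches; MemoryWarmup below wraps exactly that wait.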

class MemoryExporterProxyWrapper(LoggingMixin, MemoryTableProxyWrapper):

Export the memory at regular intervals.

Methods

def __init__(
    self,
    memory,
    target_memory_name,
    inf_steps_per_memory_export,
    experiment_root_path,
    min_time_per_export
) -> None

Arguments:

  • memory(MemoryTableProxy | MemoryTableProxyWrapper)
  • target_memory_name
  • inf_steps_per_memory_export
  • experiment_root_path(str)
  • min_time_per_export(int) (default: 600)
def add(self, observations, responses) -> None

First adds the new batch to the memory, then triggers an export if the configured step and time thresholds have passed.

Arguments:

  • observations(Dict[AgentId, DictObservation])
  • responses(Dict[AgentId, DictResponse])
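
A sketch of wrapping a proxy so collection also exports the memory; paths and thresholds are illustrative, and the import path is an assumption:

from emote.memory import MemoryExporterProxyWrapper, MemoryTableProxy  # import path assumed


def make_exporting_proxy(proxy: MemoryTableProxy) -> MemoryExporterProxyWrapper:
    # Illustrative values: export at most every 600 seconds, and only
    # after enough inference steps since the previous export.
    return MemoryExporterProxyWrapper(
        memory=proxy,
        target_memory_name="replay",
        inf_steps_per_memory_export=10_000,
        experiment_root_path="/tmp/experiment",
        min_time_per_export=600,
    )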

class MemoryImporterCallback(Callback):

Load and validate a previously exported memory.

Methods

def __init__(
    self,
    memory_table,
    target_memory_name,
    experiment_load_dir,
    load_fname_override
) -> None

Arguments:

  • memory_table(MemoryTable)
  • target_memory_name(str)
  • experiment_load_dir(str)
  • load_fname_override
def restore_state(self) -> None
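
The importer mirrors the exporter. A sketch of restoring a previously exported memory, with illustrative paths and an assumed import path:

from emote.memory import MemoryImporterCallback, MemoryTable  # import path assumed


def restore_exported_memory(table: MemoryTable) -> None:
    importer = MemoryImporterCallback(
        memory_table=table,
        target_memory_name="replay",            # must match the name used at export
        experiment_load_dir="/tmp/experiment",  # illustrative path
        load_fname_override=None,
    )
    importer.restore_state()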

class LoggingProxyWrapper(LoggingMixin, MemoryTableProxyWrapper):

Methods

def __init__(self, inner, writer, log_interval) -> None
def state_dict(self) -> dict[str, Any]
def load_state_dict(
    self,
    state_dict,
    load_network,
    load_optimizer,
    load_hparams
) -> None
def add(self, observations, responses) -> None
def report(self, metrics, metrics_lists) -> None
def get_report(
    self,
    keys
) -> Tuple[dict[str, int | float | list[float]], dict[str, list[float]]]

class MemoryWarmup(Callback):

A blocker to ensure the memory has enough data when training starts, as the memory will panic otherwise. This is useful if you use an async data generator.

If you do not use an async data generator, this can deadlock your training loop and prevent progress.

Methods

def __init__(self, loader, exporter, shutdown_signal) -> None

Arguments:

  • loader(MemoryLoader)
  • exporter(Optional[OnnxExporter])
  • shutdown_signal(Optional[Callable[[], bool]])
def begin_training(self) -> None
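
A sketch of gating training start on memory readiness; the shutdown signal shown is illustrative, and the import path is an assumption:

from emote.memory import MemoryLoader, MemoryWarmup  # import path assumed


def wait_for_data(loader: MemoryLoader) -> None:
    warmup = MemoryWarmup(
        loader=loader,
        exporter=None,                  # optionally an OnnxExporter to keep serving
        shutdown_signal=lambda: False,  # illustrative: return True to abort the wait
    )
    # Blocks until loader.is_ready(); only safe with an async data generator.
    warmup.begin_training()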

class JointMemoryLoader:

A memory loader capable of loading data from multiple MemoryLoaders.

Methods

def __init__(self, loaders, size_key) -> None

Arguments:

  • loaders(list[MemoryLoader])
  • size_key(str) (default: batch_size)
def is_ready(self) -> bool
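
A sketch of combining two loaders; the loader names and the readiness semantics noted in the comment are assumptions:

from emote.memory import JointMemoryLoader, MemoryLoader  # import path assumed


def make_joint_loader(policy_loader: MemoryLoader, demo_loader: MemoryLoader) -> JointMemoryLoader:
    # Hypothetical pairing of an on-policy loader with a demonstration
    # loader; presumably is_ready() requires every inner loader to be ready.
    return JointMemoryLoader(
        loaders=[policy_loader, demo_loader],
        size_key="batch_size",
    )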