package emote.memory
This module contains all the major building blocks for our memory
implementation. The memory was developed in the same time period as
DeepMind's Reverb <https://www.deepmind.com/open-source/reverb>, and
shares naming with it, which in turn borrows from databases. Unlike
Reverb, we do not have RateSamplers (though they could be added), and
we do not share data between ArrayTables.
The goal of the memory is to provide a unified interface for all types of machine learning tasks. This is achieved by focusing on configuration and pluggability over code-driven functionality.
Currently, there are three main points of customization:
- Shape and type of data
- Insertion, sampling, and eviction
- Data transformation and generation
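For a feel of how these points compose, a hypothetical construction might look like the following. This is an illustrative sketch only: the ArrayTable constructor arguments shown here, and the strategy class names, are assumptions rather than the exact API.

# Hypothetical sketch; names and arguments are illustrative.
table = ArrayTable(
    columns=[...],                    # shape and type of data
    sampler=UniformSampleStrategy(),  # sampling policy
    ejector=FifoEjectionStrategy(),   # eviction policy
    adaptors=[...],                   # data transformation and generation
)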
High-level parts
ArrayTable
A table is a data structure containing a specific type of data that shares the same high-level structure.
Columns and Virtual Columns
A column is storage for a specific type of data where each item has the same shape and type. A virtual column is like a column, but it references another column and synthesizes or modifies data based on it. For example, dones and masks are synthetic data derived only from indices.
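To make the distinction concrete, here is a minimal, self-contained sketch of the idea. These are illustrative classes written for this document, not the library's own:

import numpy as np

class Column:
    # A regular column owns its storage.
    def __init__(self):
        self.data = {}  # index -> stored item

class MaskColumn:
    # A virtual column: synthesizes a mask of 1.0 for every index
    # present in the referenced column, storing nothing itself.
    def __init__(self, target):
        self.target = target

    def __getitem__(self, index):
        return np.float32(index in self.target.data)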
Adaptors
Adaptors are another approach to virtual columns, better suited to transforming a whole batch, such as scaling or reshaping specific fields. Since this step occurs after the data has been converted to tensors, the full power of PyTorch is available here and gradients will be correctly tracked.
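As an illustration of the concept (not the library's Adaptor interface), an adaptor can be thought of as a callable that maps a batch of tensors to a transformed batch, with gradients flowing through. The field name "rewards" below is an assumption for the example:

import torch

class RewardScalingAdaptor:
    # Scales the rewards field of a whole sampled batch.
    def __init__(self, scale):
        self.scale = scale

    def __call__(self, batch):
        batch = dict(batch)  # shallow copy; leave the input untouched
        batch["rewards"] = batch["rewards"] * self.scale
        return batch

batch = {"rewards": torch.ones(32, requires_grad=True)}
scaled = RewardScalingAdaptor(0.1)(batch)
scaled["rewards"].sum().backward()  # gradients are tracked through the adaptor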
Strategies, Samplers and Ejectors
Strategies are based on the delegate pattern, where we can inject implementation details through objects instead of using inheritance. Strategies define the API for sampling and ejection from memories, and are queried from the table upon sampling and insertion.
Samplers and Ejectors track the data (but do not own it!). They are used by the table for sampling and ejection based on the policy they implement. Currently we have Fifo and Uniform samplers and ejectors, but one could have prioritized samplers/ejectors, etc.
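A minimal sketch of the delegate pattern as used here (illustrative classes, not the exact strategy API):

import random

class UniformSampler:
    # Tracks identities but does not own the underlying data.
    def __init__(self):
        self._ids = []

    def track(self, identity):
        self._ids.append(identity)

    def sample(self, count):
        return random.choices(self._ids, k=count)

class FifoEjector:
    def __init__(self):
        self._ids = []

    def track(self, identity):
        self._ids.append(identity)

    def eject(self):
        return self._ids.pop(0)  # oldest identity goes first

The table holds one sampler and one ejector and delegates to them on sampling and insertion, so swapping policies requires no subclassing.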
Proxy Wrappers
Wrappers live around the memory proxy and extend functionality. This is a great point for data conversion, validation, and logging.
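A minimal wrapper sketch (illustrative; the real wrappers derive from MemoryTableProxyWrapper, and the validation check here is an example, not a rule the library enforces):

class ValidatingProxyWrapper:
    def __init__(self, inner):
        self._inner = inner

    def add(self, observations, responses):
        # Illustrative validation before forwarding the insertion.
        assert observations.keys() >= responses.keys(), (
            "every response must belong to an observed agent"
        )
        self._inner.add(observations, responses)

    def __getattr__(self, name):
        # Everything else passes straight through to the wrapped proxy.
        return getattr(self._inner, name)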
Classes
class MemoryTable(Protocol):
Fields
adaptors: List[Adaptor]
Methods
def sample(self, count, sequence_length) -> SampleResult
Sample COUNT traces from the memory, each consisting of SEQUENCE_LENGTH frames.
The data is transposed in a SoA fashion (since this is both easier to store and easier to consume).
Arguments:
count(int)
sequence_length(int)
def size(self) -> int
Query the number of elements currently in the memory.
def full(self) -> bool
Query whether the memory is filled.
def add_sequence(self, identity, sequence) -> None
Add a fully terminated sequence to the memory.
Arguments:
identity(int)
sequence
def store(self, path, version) -> bool
Persist the whole table and all metadata to the designated path.
Arguments:
path(str)
version(TableSerializationVersion)
def restore(self, path, override_version) -> bool
Restore the data table from the provided path. This also clears the data stores.
Arguments:
path(str)
override_version(TableSerializationVersion | None)
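Putting the protocol together, a typical interaction looks like this. The construction of the table and the contents of seq are implementation-specific and elided:

# `table` is any MemoryTable implementation.
table.add_sequence(identity=0, sequence=seq)  # insert a terminated episode

# 32 traces of 8 frames each, returned column-wise (SoA).
batch = table.sample(count=32, sequence_length=8)

if table.full():
    ...  # further inserts will invoke the ejection strategy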
class MemoryTableProxy:
The sequence builder wraps a sequence-based memory to build full episodes from [identity, observation] data.
Not thread safe.
Methods
def __init__(
self,
memory_table,
minimum_length_threshold,
use_terminal,
*,
name
) -> None
Arguments:
memory_table(MemoryTable)
minimum_length_threshold(Optional[int])
use_terminal(bool)
name(str)
def name(self) -> str
def size(self) -> int
def resize(self, new_size) -> None
def store(self, path) -> None
def is_initial(self, identity) -> bool
Returns true if the identity is not already used in a partial sequence. Does not validate whether the identity is associated with a complete episode.
Arguments:
identity(int)
def add(self, observations, responses) -> None
def timers(self) -> None
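A usage sketch, with illustrative argument values; building the observation and response dicts is elided, as it depends on the agent setup:

proxy = MemoryTableProxy(
    memory_table=table,
    minimum_length_threshold=None,
    use_terminal=True,
    name="default",
)

if proxy.is_initial(agent_id):
    ...  # this identity starts a fresh partial sequence

proxy.add(observations, responses)  # accumulates episodes per identity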
class MemoryLoader:
Methods
def __init__(
self,
memory_table,
rollout_count,
rollout_length,
size_key,
data_group
) -> None
def is_ready(self) -> bool
True if the data loader has enough data to start providing data.
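A construction sketch; the size key and data group names are illustrative:

import time

loader = MemoryLoader(
    memory_table=table,
    rollout_count=32,       # sequences per batch
    rollout_length=8,       # frames per sequence
    size_key="batch_size",
    data_group="default",
)

while not loader.is_ready():
    time.sleep(0.1)  # wait until enough data has been inserted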
class MemoryExporterProxyWrapper(LoggingMixin, MemoryTableProxyWrapper):
Export the memory at regular intervals.
Methods
def __init__(
self,
memory,
target_memory_name,
inf_steps_per_memory_export,
experiment_root_path,
min_time_per_export
) -> None
Arguments:
memory(MemoryTableProxy | MemoryTableProxyWrapper)
target_memory_name(str)
inf_steps_per_memory_export(int)
experiment_root_path(str)
min_time_per_export(int)
(default: 600)
def add(self, observations, responses) -> None
First add the new batch to the memory, then trigger an export if the configured step count and minimum time since the last export have elapsed.
Arguments:
observations(Dict[AgentId, DictObservation])
responses(Dict[AgentId, DictResponse])
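A wiring sketch; the memory name and paths are illustrative:

proxy = MemoryExporterProxyWrapper(
    memory=inner_proxy,                   # the wrapped MemoryTableProxy
    target_memory_name="replay",
    inf_steps_per_memory_export=10_000,   # export cadence in inference steps
    experiment_root_path="/tmp/experiment",
    min_time_per_export=600,              # at most one export per 10 minutes
)
proxy.add(observations, responses)  # inserts, then exports when due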
class MemoryImporterCallback(Callback):
Load and validate a previously exported memory.
Methods
def __init__(
self,
memory_table,
target_memory_name,
experiment_load_dir,
load_fname_override
) -> None
Arguments:
memory_table(MemoryTable)
target_memory_name(str)
experiment_load_dir(str)
load_fname_override
def restore_state(self) -> None
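A matching import sketch. The name and directory must match what the exporter wrote; the values here are illustrative, restore_state is shown called directly for demonstration, and passing None for load_fname_override is an assumption:

importer = MemoryImporterCallback(
    memory_table=table,
    target_memory_name="replay",
    experiment_load_dir="/tmp/experiment",
    load_fname_override=None,
)
importer.restore_state()  # loads and validates the exported memory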
class LoggingProxyWrapper(LoggingMixin, MemoryTableProxyWrapper):
Methods
def __init__(self, inner, writer, log_interval) -> None
def state_dict(self) -> dict[str, Any]
def load_state_dict(
self,
state_dict,
load_network,
load_optimizer,
load_hparams
) -> None
def add(self, observations, responses) -> None
def report(self, metrics, metrics_lists) -> None
def get_report(
self,
keys
) -> Tuple[dict[str, int | float | list[float]], dict[str, list[float]]]
class MemoryWarmup(Callback):
A blocker that ensures the memory has enough data when training starts, as the memory will panic otherwise. This is useful if you use an async data generator.
If you do not use an async data generator this can deadlock your training loop and prevent progress.
Methods
def __init__(self, loader, exporter, shutdown_signal) -> None
Arguments:
loader(MemoryLoader)
exporter(Optional[OnnxExporter])
shutdown_signal(Optional[Callable[[], bool]])
def begin_training(self) -> None
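A usage sketch, assuming begin_training blocks until the loader reports ready:

warmup = MemoryWarmup(
    loader=loader,
    exporter=None,                  # optionally an OnnxExporter
    shutdown_signal=lambda: False,  # illustrative; return True to stop waiting
)
warmup.begin_training()  # blocks until loader.is_ready()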
class JointMemoryLoader:
A memory loader capable of loading data from multiple MemoryLoaders.
Methods
def __init__(self, loaders, size_key) -> None
Arguments:
loaders(list[MemoryLoader])
size_key(str)
(default: batch_size)
def is_ready(self) -> bool
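A combination sketch; the loader names are illustrative, and it is assumed that readiness requires every underlying loader to be ready:

joint = JointMemoryLoader(
    loaders=[loader_a, loader_b],  # two independently configured MemoryLoaders
    size_key="batch_size",
)

if joint.is_ready():
    ...  # every underlying loader can now provide data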