module emote.extra.onnx_exporter

Classes

class QueuedExport:

Methods

def __init__(self, metadata) -> None
def process(self, storage) -> None
def block_until_complete(self) -> None

class OnnxExporter(LoggingMixin, Callback):

Handles ONNX exports of an ML policy. Call export whenever you want to save an ONNX version of the current model, or export_threadsafe if you are calling from outside the training loop.

Methods

def __init__(
    self,
    agent_proxy,
    spaces,
    requires_epsilon,
    directory,
    interval,
    prefix,
    device
) -> None

Arguments:

  • agent_proxy(AgentProxy): The agent API to export.
  • spaces(MDPSpace): The spaces describing the model inputs and outputs.
  • requires_epsilon(bool): If true, the API should accept an input epsilon per action.
  • directory(str): Path to the directory where the files should be created. If it does not exist, it will be created.
  • interval(int | None): If provided, ONNX files are exported automatically at this cadence.
  • prefix(str): All file names will have this prefix. (default: savedmodel_)
  • device(torch.device | None): If provided, the model inputs are transferred to this device before exporting.

def add_metadata(self, key, value) -> None
def end_batch(self) -> None
def end_cycle(self) -> None
def process_pending_exports(self) -> None

If you are using export_threadsafe, the main thread must call this method regularly; queued exports are only written when it runs.

def export_threadsafe(self, metadata) -> StorageItem

Same as export, but can be called from threads other than the main thread.

This method relies on the main thread calling process_pending_exports periodically. Do not call it from the main thread: the call would block indefinitely, because the main thread could then never process the queued export.

Arguments:

  • metadata
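
The interaction between export_threadsafe and process_pending_exports described above can be sketched with the standard library alone. This is a minimal illustration of the queue-and-wait pattern, not emote's actual implementation: QueuedRequest, MainThreadExporter, and run_demo are hypothetical names, and the "export" here only records the metadata instead of serializing a model.

```python
import queue
import threading


class QueuedRequest:
    """Hypothetical stand-in for a queued export request."""

    def __init__(self, metadata):
        self.metadata = metadata
        self.result = None
        self._done = threading.Event()

    def process(self, do_export):
        # Runs on the main thread: perform the export, then wake the waiter.
        self.result = do_export(self.metadata)
        self._done.set()

    def block_until_complete(self):
        # Runs on the requesting thread: wait until the main thread is done.
        self._done.wait()


class MainThreadExporter:
    """Hypothetical exporter that only touches the model on the main thread."""

    def __init__(self):
        self._pending = queue.Queue()
        self.exported = []

    def export(self, metadata):
        # Stand-in for serializing the model; main thread only.
        item = {"index": len(self.exported), "metadata": metadata}
        self.exported.append(item)
        return item

    def export_threadsafe(self, metadata):
        request = QueuedRequest(metadata)
        self._pending.put(request)
        # Blocks forever if called on the main thread, since nothing
        # would then drain the queue.
        request.block_until_complete()
        return request.result

    def process_pending_exports(self):
        # Drain every request queued so far without blocking.
        while True:
            try:
                request = self._pending.get_nowait()
            except queue.Empty:
                return
            request.process(self.export)


def run_demo():
    exporter = MainThreadExporter()
    results = []
    worker = threading.Thread(
        target=lambda: results.append(exporter.export_threadsafe({"step": 1}))
    )
    worker.start()
    # Stand-in for the training loop: drain the queue regularly.
    while worker.is_alive():
        exporter.process_pending_exports()
        worker.join(timeout=0.01)
    return results
```

Calling run_demo() starts one worker thread that requests an export while the calling thread plays the role of the training loop. The worker returns only after the loop has processed its request, which is why the main thread has to keep calling process_pending_exports.
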
def export(self, metadata) -> StorageItem

Serializes a model to ONNX and saves it to disk. This must only be called from the main thread, that is, the thread that owns and modifies the model. This is usually the thread running the training loop.

Arguments:

  • metadata
def delete(self, handle) -> bool
def get(self, handle) -> bool
def items(self) -> Sequence[StorageItem]
def latest(self) -> Optional[StorageItem]