module emote.extra.onnx_exporter
Classes
class QueuedExport:
Methods
def __init__(self, metadata) -> None
def process(self, storage) -> None
def block_until_complete(self) -> None
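QueuedExport has no docstring. Its method names suggest a request that one thread creates and another thread fulfils. The sketch below shows that pattern under some assumptions. Completion is signalled with a threading.Event, and storage is modelled as a plain callable. QueuedExportSketch and its error field are hypothetical and are not the emote implementation:

```python
import threading
from typing import Any, Callable, Optional


class QueuedExportSketch:
    """Hypothetical stand-in for QueuedExport: created on a worker thread,
    fulfilled later on the main thread."""

    def __init__(self, metadata: Optional[dict] = None) -> None:
        self.metadata = metadata
        self.result: Any = None
        self.error: Optional[BaseException] = None
        self._done = threading.Event()

    def process(self, storage: Callable[[Optional[dict]], Any]) -> None:
        # Runs on the main thread: perform the export, then wake the waiter.
        try:
            self.result = storage(self.metadata)
        except BaseException as exc:
            self.error = exc  # remembered so the waiting thread can re-raise it
            raise
        finally:
            self._done.set()  # always release the waiter, even on failure

    def block_until_complete(self) -> None:
        # Runs on the requesting thread: wait until process() has finished.
        self._done.wait()
        if self.error is not None:
            raise self.error
```

The finally clause matters here. Without it, a failed export would leave the requesting thread waiting forever.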
class OnnxExporter(LoggingMixin, Callback):
Handles ONNX exports of an ML policy.
Call export whenever you want to save an ONNX version of the current model, or export_threadsafe if you are outside the training loop.
Methods
def __init__(
self,
agent_proxy,
spaces,
requires_epsilon,
directory,
interval,
prefix,
device
) -> None
Arguments:
agent_proxy (AgentProxy): the agent API to export.
spaces (MDPSpace): the spaces describing the model inputs and outputs.
requires_epsilon (bool): if true, the API should accept an input epsilon per action.
directory (str): path to the directory where the files should be created. If it does not exist, it will be created.
interval (int | None): if provided, ONNX files will be exported automatically at this cadence.
prefix (str): all file names will have this prefix (default: savedmodel_).
device (torch.device | None): if provided, the model inputs will be transferred to this device before exporting.
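The directory, prefix and interval arguments can be illustrated with a small scheduling helper. This is a sketch under assumptions. The class name, the tick method and the "<prefix><name>.onnx" naming scheme are all hypothetical. The description above also does not say whether end_batch or end_cycle drives the cadence:

```python
import os
from typing import Optional


class ExportScheduleSketch:
    """Hypothetical illustration of the directory, prefix and interval arguments."""

    def __init__(self, directory: str, interval: Optional[int], prefix: str = "savedmodel_") -> None:
        os.makedirs(directory, exist_ok=True)  # create the directory if it does not exist
        self.directory = directory
        self.interval = interval
        self.prefix = prefix
        self._count = 0

    def tick(self) -> bool:
        """Call once per step; returns True when an automatic export is due."""
        if self.interval is None:
            return False  # no interval: exports happen only on explicit calls
        self._count += 1
        return self._count % self.interval == 0

    def path_for(self, name: str) -> str:
        # Every file name starts with the configured prefix (naming scheme assumed).
        return os.path.join(self.directory, f"{self.prefix}{name}.onnx")
```

With interval=3, tick() returns True on every third call.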
def add_metadata(self, key, value) -> None
def end_batch(self) -> None
def end_cycle(self) -> None
def process_pending_exports(self) -> None
If you are using export_threadsafe, the main thread must call this method regularly to make sure that pending exports are actually written.
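A main-thread drain of pending work usually looks like the loop below. It runs every queued job without ever blocking. This is a minimal sketch that assumes pending exports are held in a queue.Queue of zero-argument callables. The process_pending helper is hypothetical:

```python
import queue
from typing import Callable


def process_pending(pending: "queue.Queue[Callable[[], None]]") -> int:
    """Hypothetical main-thread helper: run every queued job without blocking."""
    handled = 0
    while True:
        try:
            job = pending.get_nowait()  # never wait: the training loop must keep moving
        except queue.Empty:
            return handled
        job()  # the export itself runs here, on the calling (main) thread
        handled += 1
```

Calling this once per training step keeps waiting threads responsive. An empty queue costs almost nothing.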
def export_threadsafe(self, metadata) -> StorageItem
Same as export, but it can be called from threads other than the main thread.
This method relies on the main thread calling process_pending_exports from time to time.
You cannot call this method from the main thread: it will block indefinitely, because the export it waits for is only processed by the main thread.
Arguments:
metadata
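The interplay between export_threadsafe and process_pending_exports can be modelled end to end. In this model a worker thread enqueues a request and waits, while the main thread drains the queue. This is a hypothetical sketch, not emote's code. The do_export callable stands in for the real ONNX serialization. The sketch also raises on the main thread instead of blocking indefinitely, which is a design choice of the sketch only:

```python
import queue
import threading
from typing import Any, Callable, List


class ThreadsafeExporterSketch:
    """Hypothetical model of the export_threadsafe / process_pending_exports pair."""

    def __init__(self, do_export: Callable[[dict], Any]) -> None:
        self._do_export = do_export  # stands in for the real ONNX serialization
        self._pending: "queue.Queue[tuple]" = queue.Queue()

    def export_threadsafe(self, metadata: dict) -> Any:
        if threading.current_thread() is threading.main_thread():
            # Waiting here would deadlock: only the main thread drains the queue.
            raise RuntimeError("export_threadsafe must not be called from the main thread")
        done = threading.Event()
        box: List[Any] = []
        self._pending.put((metadata, box, done))
        done.wait()  # block until the main thread has run the export
        return box[0]

    def process_pending_exports(self) -> None:
        # Called regularly by the main thread, e.g. once per training step.
        while True:
            try:
                metadata, box, done = self._pending.get_nowait()
            except queue.Empty:
                return
            box.append(self._do_export(metadata))
            done.set()
```

Usage: start a worker thread that calls export_threadsafe. Then keep calling process_pending_exports from the main loop until the worker returns.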
def export(self, metadata) -> StorageItem
Serializes the current model to ONNX and saves it to disk. This must only be called from the main thread, that is, the thread that owns the model and modifies it. This is usually the thread that runs the training loop.
Arguments:
metadata
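The main-thread-only rule for export can be enforced with a small decorator. The sketch below assumes that the thread owning the model is Python's main thread, as the description above suggests. The main_thread_only decorator is hypothetical and is not part of emote:

```python
import functools
import threading
from typing import Any, Callable, TypeVar

F = TypeVar("F", bound=Callable[..., Any])


def main_thread_only(fn: F) -> F:
    """Hypothetical guard: refuse to run fn anywhere but the main thread."""

    @functools.wraps(fn)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        if threading.current_thread() is not threading.main_thread():
            raise RuntimeError(f"{fn.__name__} must be called from the main thread")
        return fn(*args, **kwargs)

    return wrapper  # type: ignore[return-value]
```

Failing fast like this turns a subtle data race on the model weights into an immediate, readable error.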
def delete(self, handle) -> bool
def get(self, handle) -> bool
def items(self) -> Sequence[StorageItem]
def latest(self) -> Optional[StorageItem]
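The storage methods operate on handles, returning every stored item or only the most recent one. The in-memory sketch below illustrates that shape. It assumes integer handles and insertion-ordered storage. StoredExport, ExportStoreSketch and add are hypothetical names, and get is not modelled:

```python
import itertools
from dataclasses import dataclass
from typing import Dict, Optional, Sequence


@dataclass(frozen=True)
class StoredExport:
    """Hypothetical stand-in for StorageItem."""
    handle: int
    filepath: str


class ExportStoreSketch:
    def __init__(self) -> None:
        self._items: Dict[int, StoredExport] = {}
        self._ids = itertools.count()  # monotonically increasing handles

    def add(self, filepath: str) -> StoredExport:
        item = StoredExport(next(self._ids), filepath)
        self._items[item.handle] = item
        return item

    def delete(self, handle: int) -> bool:
        # True if something was removed, False for an unknown handle.
        return self._items.pop(handle, None) is not None

    def items(self) -> Sequence[StoredExport]:
        return list(self._items.values())  # dicts keep insertion (export) order

    def latest(self) -> Optional[StoredExport]:
        # Most recently added item, or None when the store is empty.
        return next(reversed(self._items.values()), None)
```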