Commit 7af7aa82 authored by Nemantor's avatar Nemantor Committed by Mayank Mittal

Fixes RSL-RL ONNX exporter for empirical normalization (#78)

The current onnx exporter does not export the empirical normalization
layer. This MR adds the empirical normalization exporting to the JIT
and ONNX exporters for RSL-RL.

- Bug fix (non-breaking change which fixes an issue)

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./orbit.sh --format`
- [x] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [x] I have run all the tests with `./orbit.sh --test` and they pass
(some did timeout)
- [x] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [ ] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there

---------
Co-authored-by: 's avatarMayank Mittal <mittalma@leggedrobotics.com>
parent 90f6fb10
[package]
# Note: Semantic Versioning is used: https://semver.org/
version = "0.6.1"
version = "0.6.2"
# Description
title = "ORBIT Environments"
......
Changelog
---------
0.6.2 (2024-05-31)
~~~~~~~~~~~~~~~~~~
Added
^^^^^
* Added exporting of empirical normalization layer to ONNX and JIT when exporting the model using
:meth:`omni.isaac.orbit.actuators.ActuatorNetMLP.export` method. Previously, the normalization layer
was not exported to the ONNX and JIT models. This caused the exported model to not work properly
when used for inference.
0.6.1 (2024-04-16)
~~~~~~~~~~~~~~~~~~
......
......@@ -8,33 +8,37 @@ import os
import torch
def export_policy_as_jit(actor_critic: object, path: str, filename="policy.pt"):
def export_policy_as_jit(actor_critic: object, normalizer: object | None, path: str, filename="policy.pt"):
"""Export policy into a Torch JIT file.
Args:
actor_critic: The actor-critic torch module.
normalizer: The empirical normalizer module. If None, Identity is used.
path: The path to the saving directory.
filename: The name of exported JIT file. Defaults to "policy.pt".
Reference:
https://github.com/leggedrobotics/legged_gym/blob/master/legged_gym/utils/helpers.py#L180
"""
policy_exporter = _TorchPolicyExporter(actor_critic)
policy_exporter = _TorchPolicyExporter(actor_critic, normalizer)
policy_exporter.export(path, filename)
def export_policy_as_onnx(actor_critic: object, path: str, filename="policy.onnx", verbose=False):
def export_policy_as_onnx(
actor_critic: object, normalizer: object | None, path: str, filename="policy.onnx", verbose=False
):
"""Export policy into a Torch ONNX file.
Args:
actor_critic: The actor-critic torch module.
normalizer: The empirical normalizer module. If None, Identity is used.
path: The path to the saving directory.
filename: The name of exported JIT file. Defaults to "policy.pt".
filename: The name of exported ONNX file. Defaults to "policy.onnx".
verbose: Whether to print the model summary. Defaults to False.
"""
if not os.path.exists(path):
os.makedirs(path, exist_ok=True)
policy_exporter = _OnnxPolicyExporter(actor_critic, verbose)
policy_exporter = _OnnxPolicyExporter(actor_critic, normalizer, verbose)
policy_exporter.export(path, filename)
......@@ -50,7 +54,7 @@ class _TorchPolicyExporter(torch.nn.Module):
https://github.com/leggedrobotics/legged_gym/blob/master/legged_gym/utils/helpers.py#L193
"""
def __init__(self, actor_critic):
def __init__(self, actor_critic, normalizer=None):
super().__init__()
self.actor = copy.deepcopy(actor_critic.actor)
self.is_recurrent = actor_critic.is_recurrent
......@@ -61,8 +65,14 @@ class _TorchPolicyExporter(torch.nn.Module):
self.register_buffer("cell_state", torch.zeros(self.rnn.num_layers, 1, self.rnn.hidden_size))
self.forward = self.forward_lstm
self.reset = self.reset_memory
# copy normalizer if exists
if normalizer:
self.normalizer = copy.deepcopy(normalizer)
else:
self.normalizer = torch.nn.Identity()
def forward_lstm(self, x):
x = self.normalizer(x)
x, (h, c) = self.rnn(x.unsqueeze(0), (self.hidden_state, self.cell_state))
self.hidden_state[:] = h
self.cell_state[:] = c
......@@ -70,7 +80,7 @@ class _TorchPolicyExporter(torch.nn.Module):
return self.actor(x)
def forward(self, x):
return self.actor(x)
return self.actor(self.normalizer(x))
@torch.jit.export
def reset(self):
......@@ -91,7 +101,7 @@ class _TorchPolicyExporter(torch.nn.Module):
class _OnnxPolicyExporter(torch.nn.Module):
"""Exporter of actor-critic into ONNX file."""
def __init__(self, actor_critic, verbose=False):
def __init__(self, actor_critic, normalizer=None, verbose=False):
super().__init__()
self.verbose = verbose
self.actor = copy.deepcopy(actor_critic.actor)
......@@ -100,14 +110,20 @@ class _OnnxPolicyExporter(torch.nn.Module):
self.rnn = copy.deepcopy(actor_critic.memory_a.rnn)
self.rnn.cpu()
self.forward = self.forward_lstm
# copy normalizer if exists
if normalizer:
self.normalizer = copy.deepcopy(normalizer)
else:
self.normalizer = torch.nn.Identity()
def forward_lstm(self, x_in, h_in, c_in):
x_in = self.normalizer(x_in)
x, (h, c) = self.rnn(x_in.unsqueeze(0), (h_in, c_in))
x = x.squeeze(0)
return self.actor(x), h, c
def forward(self, x):
return self.actor(x)
return self.actor(self.normalizer(x))
def export(self, path, filename):
self.to("cpu")
......
......@@ -46,6 +46,7 @@ from omni.isaac.orbit_tasks.utils import get_checkpoint_path, parse_env_cfg
from omni.isaac.orbit_tasks.utils.wrappers.rsl_rl import (
RslRlOnPolicyRunnerCfg,
RslRlVecEnvWrapper,
export_policy_as_jit,
export_policy_as_onnx,
)
......@@ -78,9 +79,14 @@ def main():
# obtain the trained policy for inference
policy = ppo_runner.get_inference_policy(device=env.unwrapped.device)
# export policy to onnx
# export policy to onnx/jit
export_model_dir = os.path.join(os.path.dirname(resume_path), "exported")
export_policy_as_onnx(ppo_runner.alg.actor_critic, export_model_dir, filename="policy.onnx")
export_policy_as_jit(
ppo_runner.alg.actor_critic, ppo_runner.obs_normalizer, path=export_model_dir, filename="policy.pt"
)
export_policy_as_onnx(
ppo_runner.alg.actor_critic, ppo_runner.obs_normalizer, path=export_model_dir, filename="policy.onnx"
)
# reset environment
obs, _ = env.get_observations()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment