Commit 7af7aa82 authored by Nemantor, committed by Mayank Mittal

Fixes RSL-RL ONNX exporter for empirical normalization (#78)

The current ONNX exporter does not export the empirical normalization
layer. This MR adds export of the empirical normalization layer to both
the JIT and ONNX exporters for RSL-RL.

- Bug fix (non-breaking change which fixes an issue)

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./orbit.sh --format`
- [x] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [x] I have run all the tests with `./orbit.sh --test` and they pass
(some did timeout)
- [x] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [ ] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there

---------
Co-authored-by: Mayank Mittal <mittalma@leggedrobotics.com>
parent 90f6fb10
[package] [package]
# Note: Semantic Versioning is used: https://semver.org/ # Note: Semantic Versioning is used: https://semver.org/
version = "0.6.1" version = "0.6.2"
# Description # Description
title = "ORBIT Environments" title = "ORBIT Environments"
......
Changelog Changelog
--------- ---------
0.6.2 (2024-05-31)
~~~~~~~~~~~~~~~~~~
Added
^^^^^
* Added exporting of the empirical normalization layer to ONNX and JIT when exporting the model using the
:func:`omni.isaac.orbit_tasks.utils.wrappers.rsl_rl.export_policy_as_jit` and
:func:`omni.isaac.orbit_tasks.utils.wrappers.rsl_rl.export_policy_as_onnx` functions. Previously, the
normalization layer was not included in the exported ONNX and JIT models, so an exported policy received
unnormalized observations and did not work properly when used for inference.
0.6.1 (2024-04-16) 0.6.1 (2024-04-16)
~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~
......
...@@ -8,33 +8,37 @@ import os ...@@ -8,33 +8,37 @@ import os
import torch import torch
def export_policy_as_jit(actor_critic: object, path: str, filename="policy.pt"): def export_policy_as_jit(actor_critic: object, normalizer: object | None, path: str, filename="policy.pt"):
"""Export policy into a Torch JIT file. """Export policy into a Torch JIT file.
Args: Args:
actor_critic: The actor-critic torch module. actor_critic: The actor-critic torch module.
normalizer: The empirical normalizer module. If None, Identity is used.
path: The path to the saving directory. path: The path to the saving directory.
filename: The name of exported JIT file. Defaults to "policy.pt". filename: The name of exported JIT file. Defaults to "policy.pt".
Reference: Reference:
https://github.com/leggedrobotics/legged_gym/blob/master/legged_gym/utils/helpers.py#L180 https://github.com/leggedrobotics/legged_gym/blob/master/legged_gym/utils/helpers.py#L180
""" """
policy_exporter = _TorchPolicyExporter(actor_critic) policy_exporter = _TorchPolicyExporter(actor_critic, normalizer)
policy_exporter.export(path, filename) policy_exporter.export(path, filename)
def export_policy_as_onnx(actor_critic: object, path: str, filename="policy.onnx", verbose=False): def export_policy_as_onnx(
actor_critic: object, normalizer: object | None, path: str, filename="policy.onnx", verbose=False
):
"""Export policy into a Torch ONNX file. """Export policy into a Torch ONNX file.
Args: Args:
actor_critic: The actor-critic torch module. actor_critic: The actor-critic torch module.
normalizer: The empirical normalizer module. If None, Identity is used.
path: The path to the saving directory. path: The path to the saving directory.
filename: The name of exported JIT file. Defaults to "policy.pt". filename: The name of exported ONNX file. Defaults to "policy.onnx".
verbose: Whether to print the model summary. Defaults to False. verbose: Whether to print the model summary. Defaults to False.
""" """
if not os.path.exists(path): if not os.path.exists(path):
os.makedirs(path, exist_ok=True) os.makedirs(path, exist_ok=True)
policy_exporter = _OnnxPolicyExporter(actor_critic, verbose) policy_exporter = _OnnxPolicyExporter(actor_critic, normalizer, verbose)
policy_exporter.export(path, filename) policy_exporter.export(path, filename)
...@@ -50,7 +54,7 @@ class _TorchPolicyExporter(torch.nn.Module): ...@@ -50,7 +54,7 @@ class _TorchPolicyExporter(torch.nn.Module):
https://github.com/leggedrobotics/legged_gym/blob/master/legged_gym/utils/helpers.py#L193 https://github.com/leggedrobotics/legged_gym/blob/master/legged_gym/utils/helpers.py#L193
""" """
def __init__(self, actor_critic): def __init__(self, actor_critic, normalizer=None):
super().__init__() super().__init__()
self.actor = copy.deepcopy(actor_critic.actor) self.actor = copy.deepcopy(actor_critic.actor)
self.is_recurrent = actor_critic.is_recurrent self.is_recurrent = actor_critic.is_recurrent
...@@ -61,8 +65,14 @@ class _TorchPolicyExporter(torch.nn.Module): ...@@ -61,8 +65,14 @@ class _TorchPolicyExporter(torch.nn.Module):
self.register_buffer("cell_state", torch.zeros(self.rnn.num_layers, 1, self.rnn.hidden_size)) self.register_buffer("cell_state", torch.zeros(self.rnn.num_layers, 1, self.rnn.hidden_size))
self.forward = self.forward_lstm self.forward = self.forward_lstm
self.reset = self.reset_memory self.reset = self.reset_memory
# copy normalizer if exists
if normalizer:
self.normalizer = copy.deepcopy(normalizer)
else:
self.normalizer = torch.nn.Identity()
def forward_lstm(self, x): def forward_lstm(self, x):
x = self.normalizer(x)
x, (h, c) = self.rnn(x.unsqueeze(0), (self.hidden_state, self.cell_state)) x, (h, c) = self.rnn(x.unsqueeze(0), (self.hidden_state, self.cell_state))
self.hidden_state[:] = h self.hidden_state[:] = h
self.cell_state[:] = c self.cell_state[:] = c
...@@ -70,7 +80,7 @@ class _TorchPolicyExporter(torch.nn.Module): ...@@ -70,7 +80,7 @@ class _TorchPolicyExporter(torch.nn.Module):
return self.actor(x) return self.actor(x)
def forward(self, x): def forward(self, x):
return self.actor(x) return self.actor(self.normalizer(x))
@torch.jit.export @torch.jit.export
def reset(self): def reset(self):
...@@ -91,7 +101,7 @@ class _TorchPolicyExporter(torch.nn.Module): ...@@ -91,7 +101,7 @@ class _TorchPolicyExporter(torch.nn.Module):
class _OnnxPolicyExporter(torch.nn.Module): class _OnnxPolicyExporter(torch.nn.Module):
"""Exporter of actor-critic into ONNX file.""" """Exporter of actor-critic into ONNX file."""
def __init__(self, actor_critic, verbose=False): def __init__(self, actor_critic, normalizer=None, verbose=False):
super().__init__() super().__init__()
self.verbose = verbose self.verbose = verbose
self.actor = copy.deepcopy(actor_critic.actor) self.actor = copy.deepcopy(actor_critic.actor)
...@@ -100,14 +110,20 @@ class _OnnxPolicyExporter(torch.nn.Module): ...@@ -100,14 +110,20 @@ class _OnnxPolicyExporter(torch.nn.Module):
self.rnn = copy.deepcopy(actor_critic.memory_a.rnn) self.rnn = copy.deepcopy(actor_critic.memory_a.rnn)
self.rnn.cpu() self.rnn.cpu()
self.forward = self.forward_lstm self.forward = self.forward_lstm
# copy normalizer if exists
if normalizer:
self.normalizer = copy.deepcopy(normalizer)
else:
self.normalizer = torch.nn.Identity()
def forward_lstm(self, x_in, h_in, c_in): def forward_lstm(self, x_in, h_in, c_in):
x_in = self.normalizer(x_in)
x, (h, c) = self.rnn(x_in.unsqueeze(0), (h_in, c_in)) x, (h, c) = self.rnn(x_in.unsqueeze(0), (h_in, c_in))
x = x.squeeze(0) x = x.squeeze(0)
return self.actor(x), h, c return self.actor(x), h, c
def forward(self, x): def forward(self, x):
return self.actor(x) return self.actor(self.normalizer(x))
def export(self, path, filename): def export(self, path, filename):
self.to("cpu") self.to("cpu")
......
...@@ -46,6 +46,7 @@ from omni.isaac.orbit_tasks.utils import get_checkpoint_path, parse_env_cfg ...@@ -46,6 +46,7 @@ from omni.isaac.orbit_tasks.utils import get_checkpoint_path, parse_env_cfg
from omni.isaac.orbit_tasks.utils.wrappers.rsl_rl import ( from omni.isaac.orbit_tasks.utils.wrappers.rsl_rl import (
RslRlOnPolicyRunnerCfg, RslRlOnPolicyRunnerCfg,
RslRlVecEnvWrapper, RslRlVecEnvWrapper,
export_policy_as_jit,
export_policy_as_onnx, export_policy_as_onnx,
) )
...@@ -78,9 +79,14 @@ def main(): ...@@ -78,9 +79,14 @@ def main():
# obtain the trained policy for inference # obtain the trained policy for inference
policy = ppo_runner.get_inference_policy(device=env.unwrapped.device) policy = ppo_runner.get_inference_policy(device=env.unwrapped.device)
# export policy to onnx # export policy to onnx/jit
export_model_dir = os.path.join(os.path.dirname(resume_path), "exported") export_model_dir = os.path.join(os.path.dirname(resume_path), "exported")
export_policy_as_onnx(ppo_runner.alg.actor_critic, export_model_dir, filename="policy.onnx") export_policy_as_jit(
ppo_runner.alg.actor_critic, ppo_runner.obs_normalizer, path=export_model_dir, filename="policy.pt"
)
export_policy_as_onnx(
ppo_runner.alg.actor_critic, ppo_runner.obs_normalizer, path=export_model_dir, filename="policy.onnx"
)
# reset environment # reset environment
obs, _ = env.get_observations() obs, _ = env.get_observations()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment