Commit 535e1803 authored by Wesley Maa, committed by Kelly Guo

Make GRU-based RNNs exportable in RSL RL (#3009)

# Description

Adds a correct forward pass for GRU-based recurrent policies to the
`isaaclab_rl/rsl_rl/exporter.py` script.

See the [issue](https://github.com/isaac-sim/IsaacLab/issues/3008) for more
details.

## Type of change
- New feature (non-breaking change which adds functionality)

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./isaaclab.sh --format`
- [ ] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [ ] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there
parent aa421304
......@@ -64,10 +64,17 @@ class _TorchPolicyExporter(torch.nn.Module):
# set up recurrent network
if self.is_recurrent:
self.rnn.cpu()
self.rnn_type = type(self.rnn).__name__.lower() # 'lstm' or 'gru'
self.register_buffer("hidden_state", torch.zeros(self.rnn.num_layers, 1, self.rnn.hidden_size))
self.register_buffer("cell_state", torch.zeros(self.rnn.num_layers, 1, self.rnn.hidden_size))
self.forward = self.forward_lstm
self.reset = self.reset_memory
if self.rnn_type == "lstm":
self.register_buffer("cell_state", torch.zeros(self.rnn.num_layers, 1, self.rnn.hidden_size))
self.forward = self.forward_lstm
self.reset = self.reset_memory
elif self.rnn_type == "gru":
self.forward = self.forward_gru
self.reset = self.reset_memory
else:
raise NotImplementedError(f"Unsupported RNN type: {self.rnn_type}")
# copy normalizer if exists
if normalizer:
self.normalizer = copy.deepcopy(normalizer)
......@@ -82,6 +89,13 @@ class _TorchPolicyExporter(torch.nn.Module):
x = x.squeeze(0)
return self.actor(x)
def forward_gru(self, x):
    """Run one GRU policy step and keep the recurrent state on the module.

    The observation is passed through ``self.normalizer``, given a leading
    dimension via ``unsqueeze(0)``, and fed to ``self.rnn`` together with the
    ``hidden_state`` buffer. The returned hidden state is copied back into
    that buffer in place, so the next call continues from it.

    Args:
        x: Observation tensor (presumably ``(batch, obs_dim)`` with batch 1,
            matching the registered buffer -- confirm against the caller).

    Returns:
        The output of ``self.actor`` applied to the GRU output with the
        leading dimension squeezed away again.
    """
    normalized_obs = self.normalizer(x)
    rnn_out, next_hidden = self.rnn(normalized_obs.unsqueeze(0), self.hidden_state)
    # In-place copy: the buffer object itself must persist across calls.
    self.hidden_state[:] = next_hidden
    return self.actor(rnn_out.squeeze(0))
def forward(self, x):
    """Feed-forward policy step: normalize the observation, then run the actor.

    Args:
        x: Observation input accepted by ``self.normalizer``.

    Returns:
        Whatever ``self.actor`` returns for the normalized observation.
    """
    normalized_obs = self.normalizer(x)
    return self.actor(normalized_obs)
......@@ -91,7 +105,8 @@ class _TorchPolicyExporter(torch.nn.Module):
def reset_memory(self):
    """Zero the recurrent state buffers in place.

    ``hidden_state`` is always cleared. ``cell_state`` is cleared only when
    the module actually has that attribute: the GRU setup path registers no
    ``cell_state`` buffer, so an unconditional write would raise
    ``AttributeError`` for GRU-based exporters.
    """
    self.hidden_state[:] = 0.0
    # Guarded so the same reset works for both LSTM and GRU exporters.
    if hasattr(self, "cell_state"):
        self.cell_state[:] = 0.0
def export(self, path, filename):
os.makedirs(path, exist_ok=True)
......@@ -122,7 +137,13 @@ class _OnnxPolicyExporter(torch.nn.Module):
# set up recurrent network
if self.is_recurrent:
self.rnn.cpu()
self.forward = self.forward_lstm
self.rnn_type = type(self.rnn).__name__.lower() # 'lstm' or 'gru'
if self.rnn_type == "lstm":
self.forward = self.forward_lstm
elif self.rnn_type == "gru":
self.forward = self.forward_gru
else:
raise NotImplementedError(f"Unsupported RNN type: {self.rnn_type}")
# copy normalizer if exists
if normalizer:
self.normalizer = copy.deepcopy(normalizer)
......@@ -135,6 +156,12 @@ class _OnnxPolicyExporter(torch.nn.Module):
x = x.squeeze(0)
return self.actor(x), h, c
def forward_gru(self, x_in, h_in):
    """Run one stateless GRU policy step with the hidden state passed explicitly.

    The observation is normalized, given a leading dimension via
    ``unsqueeze(0)``, and fed to ``self.rnn`` together with ``h_in``. Nothing
    is stored on the module; the caller carries the hidden state between
    steps.

    Args:
        x_in: Observation tensor (presumably ``(batch, obs_dim)`` -- confirm
            against the export call).
        h_in: Hidden state handed to ``self.rnn`` for this step.

    Returns:
        A tuple ``(actions, h_out)``: the output of ``self.actor`` on the
        squeezed GRU output, and the hidden state returned by ``self.rnn``.
    """
    normalized_obs = self.normalizer(x_in)
    rnn_out, h_out = self.rnn(normalized_obs.unsqueeze(0), h_in)
    actions = self.actor(rnn_out.squeeze(0))
    return actions, h_out
def forward(self, x):
    """Non-recurrent policy step: apply the normalizer, then the actor.

    Args:
        x: Observation input accepted by ``self.normalizer``.

    Returns:
        The result of ``self.actor`` on the normalized observation.
    """
    normalized_obs = self.normalizer(x)
    actions = self.actor(normalized_obs)
    return actions
......@@ -144,19 +171,34 @@ class _OnnxPolicyExporter(torch.nn.Module):
if self.is_recurrent:
obs = torch.zeros(1, self.rnn.input_size)
h_in = torch.zeros(self.rnn.num_layers, 1, self.rnn.hidden_size)
c_in = torch.zeros(self.rnn.num_layers, 1, self.rnn.hidden_size)
actions, h_out, c_out = self(obs, h_in, c_in)
torch.onnx.export(
self,
(obs, h_in, c_in),
os.path.join(path, filename),
export_params=True,
opset_version=11,
verbose=self.verbose,
input_names=["obs", "h_in", "c_in"],
output_names=["actions", "h_out", "c_out"],
dynamic_axes={},
)
if self.rnn_type == "lstm":
c_in = torch.zeros(self.rnn.num_layers, 1, self.rnn.hidden_size)
torch.onnx.export(
self,
(obs, h_in, c_in),
os.path.join(path, filename),
export_params=True,
opset_version=11,
verbose=self.verbose,
input_names=["obs", "h_in", "c_in"],
output_names=["actions", "h_out", "c_out"],
dynamic_axes={},
)
elif self.rnn_type == "gru":
torch.onnx.export(
self,
(obs, h_in),
os.path.join(path, filename),
export_params=True,
opset_version=11,
verbose=self.verbose,
input_names=["obs", "h_in"],
output_names=["actions", "h_out"],
dynamic_axes={},
)
else:
raise NotImplementedError(f"Unsupported RNN type: {self.rnn_type}")
else:
obs = torch.zeros(1, self.actor[0].in_features)
torch.onnx.export(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment