Commit 535e1803 authored by Wesley Maa's avatar Wesley Maa Committed by Kelly Guo

Make GRU-based RNNs exportable in RSL RL (#3009)

# Description

Adds correct forward pass for GRU-based recurrent policies in the
`isaaclab_rl/rsl_rl/exporter.py` script.

See [issue](https://github.com/isaac-sim/IsaacLab/issues/3008) for more
details

## Type of change
- New feature (non-breaking change which adds functionality)

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./isaaclab.sh --format`
- [ ] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [ ] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there
parent aa421304
...@@ -64,10 +64,17 @@ class _TorchPolicyExporter(torch.nn.Module): ...@@ -64,10 +64,17 @@ class _TorchPolicyExporter(torch.nn.Module):
# set up recurrent network # set up recurrent network
if self.is_recurrent: if self.is_recurrent:
self.rnn.cpu() self.rnn.cpu()
self.rnn_type = type(self.rnn).__name__.lower() # 'lstm' or 'gru'
self.register_buffer("hidden_state", torch.zeros(self.rnn.num_layers, 1, self.rnn.hidden_size)) self.register_buffer("hidden_state", torch.zeros(self.rnn.num_layers, 1, self.rnn.hidden_size))
if self.rnn_type == "lstm":
self.register_buffer("cell_state", torch.zeros(self.rnn.num_layers, 1, self.rnn.hidden_size)) self.register_buffer("cell_state", torch.zeros(self.rnn.num_layers, 1, self.rnn.hidden_size))
self.forward = self.forward_lstm self.forward = self.forward_lstm
self.reset = self.reset_memory self.reset = self.reset_memory
elif self.rnn_type == "gru":
self.forward = self.forward_gru
self.reset = self.reset_memory
else:
raise NotImplementedError(f"Unsupported RNN type: {self.rnn_type}")
# copy normalizer if exists # copy normalizer if exists
if normalizer: if normalizer:
self.normalizer = copy.deepcopy(normalizer) self.normalizer = copy.deepcopy(normalizer)
...@@ -82,6 +89,13 @@ class _TorchPolicyExporter(torch.nn.Module): ...@@ -82,6 +89,13 @@ class _TorchPolicyExporter(torch.nn.Module):
x = x.squeeze(0) x = x.squeeze(0)
return self.actor(x) return self.actor(x)
def forward_gru(self, x):
x = self.normalizer(x)
x, h = self.rnn(x.unsqueeze(0), self.hidden_state)
self.hidden_state[:] = h
x = x.squeeze(0)
return self.actor(x)
def forward(self, x): def forward(self, x):
return self.actor(self.normalizer(x)) return self.actor(self.normalizer(x))
...@@ -91,6 +105,7 @@ class _TorchPolicyExporter(torch.nn.Module): ...@@ -91,6 +105,7 @@ class _TorchPolicyExporter(torch.nn.Module):
def reset_memory(self): def reset_memory(self):
self.hidden_state[:] = 0.0 self.hidden_state[:] = 0.0
if hasattr(self, "cell_state"):
self.cell_state[:] = 0.0 self.cell_state[:] = 0.0
def export(self, path, filename): def export(self, path, filename):
...@@ -122,7 +137,13 @@ class _OnnxPolicyExporter(torch.nn.Module): ...@@ -122,7 +137,13 @@ class _OnnxPolicyExporter(torch.nn.Module):
# set up recurrent network # set up recurrent network
if self.is_recurrent: if self.is_recurrent:
self.rnn.cpu() self.rnn.cpu()
self.rnn_type = type(self.rnn).__name__.lower() # 'lstm' or 'gru'
if self.rnn_type == "lstm":
self.forward = self.forward_lstm self.forward = self.forward_lstm
elif self.rnn_type == "gru":
self.forward = self.forward_gru
else:
raise NotImplementedError(f"Unsupported RNN type: {self.rnn_type}")
# copy normalizer if exists # copy normalizer if exists
if normalizer: if normalizer:
self.normalizer = copy.deepcopy(normalizer) self.normalizer = copy.deepcopy(normalizer)
...@@ -135,6 +156,12 @@ class _OnnxPolicyExporter(torch.nn.Module): ...@@ -135,6 +156,12 @@ class _OnnxPolicyExporter(torch.nn.Module):
x = x.squeeze(0) x = x.squeeze(0)
return self.actor(x), h, c return self.actor(x), h, c
def forward_gru(self, x_in, h_in):
x_in = self.normalizer(x_in)
x, h = self.rnn(x_in.unsqueeze(0), h_in)
x = x.squeeze(0)
return self.actor(x), h
def forward(self, x): def forward(self, x):
return self.actor(self.normalizer(x)) return self.actor(self.normalizer(x))
...@@ -144,8 +171,9 @@ class _OnnxPolicyExporter(torch.nn.Module): ...@@ -144,8 +171,9 @@ class _OnnxPolicyExporter(torch.nn.Module):
if self.is_recurrent: if self.is_recurrent:
obs = torch.zeros(1, self.rnn.input_size) obs = torch.zeros(1, self.rnn.input_size)
h_in = torch.zeros(self.rnn.num_layers, 1, self.rnn.hidden_size) h_in = torch.zeros(self.rnn.num_layers, 1, self.rnn.hidden_size)
if self.rnn_type == "lstm":
c_in = torch.zeros(self.rnn.num_layers, 1, self.rnn.hidden_size) c_in = torch.zeros(self.rnn.num_layers, 1, self.rnn.hidden_size)
actions, h_out, c_out = self(obs, h_in, c_in)
torch.onnx.export( torch.onnx.export(
self, self,
(obs, h_in, c_in), (obs, h_in, c_in),
...@@ -157,6 +185,20 @@ class _OnnxPolicyExporter(torch.nn.Module): ...@@ -157,6 +185,20 @@ class _OnnxPolicyExporter(torch.nn.Module):
output_names=["actions", "h_out", "c_out"], output_names=["actions", "h_out", "c_out"],
dynamic_axes={}, dynamic_axes={},
) )
elif self.rnn_type == "gru":
torch.onnx.export(
self,
(obs, h_in),
os.path.join(path, filename),
export_params=True,
opset_version=11,
verbose=self.verbose,
input_names=["obs", "h_in"],
output_names=["actions", "h_out"],
dynamic_axes={},
)
else:
raise NotImplementedError(f"Unsupported RNN type: {self.rnn_type}")
else: else:
obs = torch.zeros(1, self.actor[0].in_features) obs = torch.zeros(1, self.actor[0].in_features)
torch.onnx.export( torch.onnx.export(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment