Unverified Commit db5c1c32 authored by James Tigue's avatar James Tigue Committed by GitHub

Fixes export LSTM to onnx file (#2821)

# Description

This PR fixes an issue when exporting LSTM to ONNX. The normalizer was
resetting to zero. This PR calls `eval()` during the `forward()`.

Fixes # (issue)

## Type of change

<!-- As you go through the list, delete the ones that are not
applicable. -->

- Bug fix (non-breaking change which fixes an issue)

## Screenshots
Left: After, Right: Before


![image](https://github.com/user-attachments/assets/9a8f765f-653a-4a57-b9ee-af00e8e0539c)

## Checklist

- [ ] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./isaaclab.sh --format`
- [ ] I have made corresponding changes to the documentation
- [ ] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [ ] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there

<!--
As you go through the checklist above, you can mark something as done by
putting an x character in it

For example,
- [x] I have done this task
- [ ] I have not done this task
-->
parent bc5a3670
[package] [package]
# Note: Semantic Versioning is used: https://semver.org/ # Note: Semantic Versioning is used: https://semver.org/
version = "0.1.6" version = "0.1.7"
# Description # Description
title = "Isaac Lab RL" title = "Isaac Lab RL"
......
Changelog Changelog
--------- ---------
0.1.7 (2025-06-30)
~~~~~~~~~~~~~~~~~~
Fixed
^^^^^
* Call :meth:`eval` during :meth:`forward`` RSL-RL OnnxPolicyExporter
0.1.6 (2025-06-26) 0.1.6 (2025-06-26)
~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~
......
...@@ -140,6 +140,7 @@ class _OnnxPolicyExporter(torch.nn.Module): ...@@ -140,6 +140,7 @@ class _OnnxPolicyExporter(torch.nn.Module):
def export(self, path, filename): def export(self, path, filename):
self.to("cpu") self.to("cpu")
self.eval()
if self.is_recurrent: if self.is_recurrent:
obs = torch.zeros(1, self.rnn.input_size) obs = torch.zeros(1, self.rnn.input_size)
h_in = torch.zeros(self.rnn.num_layers, 1, self.rnn.hidden_size) h_in = torch.zeros(self.rnn.num_layers, 1, self.rnn.hidden_size)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment