Unverified Commit 207566b1 authored by Mayank Mittal's avatar Mayank Mittal Committed by GitHub

Loads actuator networks in eval() mode to prevent gradients (#1862)

# Description

Previously, the actuator networks were loaded using `torch.jit.load`,
however, they weren't set to eval mode. This meant that gradient
computation occurred in the background which is not desired. This MR
fixes this issue.

## Type of change

- Bug fix (non-breaking change which fixes an issue)

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./isaaclab.sh --format`
- [x] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [x] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there
parent ecf551f8
[package]
# Note: Semantic Versioning is used: https://semver.org/
version = "0.34.0"
version = "0.34.1"
# Description
title = "Isaac Lab framework for Robot Learning"
......
Changelog
---------
0.34.1 (2025-02-17)
~~~~~~~~~~~~~~~~~~~
Fixed
^^^^^
* Ensured that the loaded torch JIT models inside actuator networks are correctly set to eval mode
to prevent any unexpected behavior during inference.
0.34.0 (2025-02-14)
~~~~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~~~
Fixed
^^^^^
......
......@@ -47,7 +47,7 @@ class ActuatorNetLSTM(DCMotor):
# load the model from JIT file
file_bytes = read_file(self.cfg.network_file)
self.network = torch.jit.load(file_bytes, map_location=self._device)
self.network = torch.jit.load(file_bytes, map_location=self._device).eval()
# extract number of lstm layers and hidden dim from the shape of weights
num_layers = len(self.network.lstm.state_dict()) // 4
......@@ -126,7 +126,7 @@ class ActuatorNetMLP(DCMotor):
# load the model from JIT file
file_bytes = read_file(self.cfg.network_file)
self.network = torch.jit.load(file_bytes, map_location=self._device)
self.network = torch.jit.load(file_bytes, map_location=self._device).eval()
# create buffers for MLP history
history_length = max(self.cfg.input_idx) + 1
......@@ -175,7 +175,8 @@ class ActuatorNetMLP(DCMotor):
)
# run network inference
torques = self.network(network_input).view(self._num_envs, self.num_joints)
with torch.inference_mode():
torques = self.network(network_input).view(self._num_envs, self.num_joints)
self.computed_effort = torques.view(self._num_envs, self.num_joints) * self.cfg.torque_scale
# clip the computed effort based on the motor limits
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment