Unverified Commit d1ecc377 authored by Patrick Yin's avatar Patrick Yin Committed by GitHub

Fixes ObservationManager history buffer corrupted by external calls to...

Fixes ObservationManager history buffer corrupted by external calls to ObservationManager.compute (#2885)

# Description

When observation group has history length greater than zero, calling
`ObservationManager.compute` modifies history state by appending current
observation to history. This creates history corruption when
non-`ManagerBasedEnv` classes invoke `ObservationManager.compute`. This
PR introduces `update_history` flag (default to `False`) and only
`ManagerBasedEnv` has the privilege to run `ObservationManager.compute`
with `update_history=True`. If `update_history=False` and the history
buffer is `None`, a copy of history is returned instead of the original.

I have added test cases to verify this fix is effective.

Fixes #2884 

## Type of change

- Bug fix (non-breaking change which fixes an issue)

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./isaaclab.sh --format`
- [ ] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [x] I have added tests that prove my fix is effective or that my
feature works
- [x] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there

---------
Signed-off-by: 's avatarooctipus <zhengyuz@nvidia.com>
Signed-off-by: 's avatarKelly Guo <kellyg@nvidia.com>
Co-authored-by: 's avatarooctipus <zhengyuz@nvidia.com>
Co-authored-by: 's avatarKelly Guo <kellyg@nvidia.com>
parent 8e57a3a6
...@@ -97,6 +97,7 @@ Guidelines for modifications: ...@@ -97,6 +97,7 @@ Guidelines for modifications:
* Ori Gadot * Ori Gadot
* Oyindamola Omotuyi * Oyindamola Omotuyi
* Özhan Özen * Özhan Özen
* Patrick Yin
* Peter Du * Peter Du
* Pulkit Goyal * Pulkit Goyal
* Qian Wan * Qian Wan
......
[package] [package]
# Note: Semantic Versioning is used: https://semver.org/ # Note: Semantic Versioning is used: https://semver.org/
version = "0.40.17" version = "0.40.18"
# Description # Description
title = "Isaac Lab framework for Robot Learning" title = "Isaac Lab framework for Robot Learning"
......
Changelog Changelog
--------- ---------
0.40.18 (2025-07-09)
~~~~~~~~~~~~~~~~~~~~
Added
^^^^^
* Added input param ``update_history`` to :meth:`~isaaclab.managers.ObservationManager.compute`
to control whether the history buffer should be updated.
* Added unit test for :class:`~isaaclab.envs.ManagerBasedEnv`.
Fixed
^^^^^
* Fixed :class:`~isaaclab.envs.ManagerBasedEnv` and :class:`~isaaclab.envs.ManagerBasedRLEnv` to not update the history
buffer on recording.
0.40.17 (2025-07-10) 0.40.17 (2025-07-10)
~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~
......
...@@ -305,7 +305,7 @@ class ManagerBasedEnv: ...@@ -305,7 +305,7 @@ class ManagerBasedEnv:
self.recorder_manager.record_post_reset(env_ids) self.recorder_manager.record_post_reset(env_ids)
# compute observations # compute observations
self.obs_buf = self.observation_manager.compute() self.obs_buf = self.observation_manager.compute(update_history=True)
if self.cfg.wait_for_textures and self.sim.has_rtx_sensors(): if self.cfg.wait_for_textures and self.sim.has_rtx_sensors():
while SimulationManager.assets_loading(): while SimulationManager.assets_loading():
...@@ -365,7 +365,7 @@ class ManagerBasedEnv: ...@@ -365,7 +365,7 @@ class ManagerBasedEnv:
self.recorder_manager.record_post_reset(env_ids) self.recorder_manager.record_post_reset(env_ids)
# compute observations # compute observations
self.obs_buf = self.observation_manager.compute() self.obs_buf = self.observation_manager.compute(update_history=True)
# return observations # return observations
return self.obs_buf, self.extras return self.obs_buf, self.extras
...@@ -416,7 +416,7 @@ class ManagerBasedEnv: ...@@ -416,7 +416,7 @@ class ManagerBasedEnv:
self.event_manager.apply(mode="interval", dt=self.step_dt) self.event_manager.apply(mode="interval", dt=self.step_dt)
# -- compute observations # -- compute observations
self.obs_buf = self.observation_manager.compute() self.obs_buf = self.observation_manager.compute(update_history=True)
self.recorder_manager.record_post_step() self.recorder_manager.record_post_step()
# return observations and extras # return observations and extras
......
...@@ -237,7 +237,7 @@ class ManagerBasedRLEnv(ManagerBasedEnv, gym.Env): ...@@ -237,7 +237,7 @@ class ManagerBasedRLEnv(ManagerBasedEnv, gym.Env):
self.event_manager.apply(mode="interval", dt=self.step_dt) self.event_manager.apply(mode="interval", dt=self.step_dt)
# -- compute observations # -- compute observations
# note: done after reset to get the correct observations for reset envs # note: done after reset to get the correct observations for reset envs
self.obs_buf = self.observation_manager.compute() self.obs_buf = self.observation_manager.compute(update_history=True)
# return observations, rewards, resets and extras # return observations, rewards, resets and extras
return self.obs_buf, self.reward_buf, self.reset_terminated, self.reset_time_outs, self.extras return self.obs_buf, self.reward_buf, self.reset_terminated, self.reset_time_outs, self.extras
......
...@@ -245,12 +245,17 @@ class ObservationManager(ManagerBase): ...@@ -245,12 +245,17 @@ class ObservationManager(ManagerBase):
# nothing to log here # nothing to log here
return {} return {}
def compute(self) -> dict[str, torch.Tensor | dict[str, torch.Tensor]]: def compute(self, update_history: bool = False) -> dict[str, torch.Tensor | dict[str, torch.Tensor]]:
"""Compute the observations per group for all groups. """Compute the observations per group for all groups.
The method computes the observations for all the groups handled by the observation manager. The method computes the observations for all the groups handled by the observation manager.
Please check the :meth:`compute_group` on the processing of observations per group. Please check the :meth:`compute_group` on the processing of observations per group.
Args:
update_history: The boolean indicator without return obs should be appended to observation history.
Default to False, in which case calling compute_group does not modify history. This input is no-ops
if the group's history_length == 0.
Returns: Returns:
A dictionary with keys as the group names and values as the computed observations. A dictionary with keys as the group names and values as the computed observations.
The observations are either concatenated into a single tensor or returned as a dictionary The observations are either concatenated into a single tensor or returned as a dictionary
...@@ -260,14 +265,14 @@ class ObservationManager(ManagerBase): ...@@ -260,14 +265,14 @@ class ObservationManager(ManagerBase):
obs_buffer = dict() obs_buffer = dict()
# iterate over all the terms in each group # iterate over all the terms in each group
for group_name in self._group_obs_term_names: for group_name in self._group_obs_term_names:
obs_buffer[group_name] = self.compute_group(group_name) obs_buffer[group_name] = self.compute_group(group_name, update_history=update_history)
# otherwise return a dict with observations of all groups # otherwise return a dict with observations of all groups
# Cache the observations. # Cache the observations.
self._obs_buffer = obs_buffer self._obs_buffer = obs_buffer
return obs_buffer return obs_buffer
def compute_group(self, group_name: str) -> torch.Tensor | dict[str, torch.Tensor]: def compute_group(self, group_name: str, update_history: bool = False) -> torch.Tensor | dict[str, torch.Tensor]:
"""Computes the observations for a given group. """Computes the observations for a given group.
The observations for a given group are computed by calling the registered functions for each The observations for a given group are computed by calling the registered functions for each
...@@ -290,6 +295,9 @@ class ObservationManager(ManagerBase): ...@@ -290,6 +295,9 @@ class ObservationManager(ManagerBase):
Args: Args:
group_name: The name of the group for which to compute the observations. Defaults to None, group_name: The name of the group for which to compute the observations. Defaults to None,
in which case observations for all the groups are computed and returned. in which case observations for all the groups are computed and returned.
update_history: The boolean indicator without return obs should be appended to observation group's history.
Default to False, in which case calling compute_group does not modify history. This input is no-ops
if the group's history_length == 0.
Returns: Returns:
Depending on the group's configuration, the tensors for individual observation terms are Depending on the group's configuration, the tensors for individual observation terms are
...@@ -330,13 +338,23 @@ class ObservationManager(ManagerBase): ...@@ -330,13 +338,23 @@ class ObservationManager(ManagerBase):
obs = obs.mul_(term_cfg.scale) obs = obs.mul_(term_cfg.scale)
# Update the history buffer if observation term has history enabled # Update the history buffer if observation term has history enabled
if term_cfg.history_length > 0: if term_cfg.history_length > 0:
self._group_obs_term_history_buffer[group_name][term_name].append(obs) circular_buffer = self._group_obs_term_history_buffer[group_name][term_name]
if term_cfg.flatten_history_dim: if update_history:
group_obs[term_name] = self._group_obs_term_history_buffer[group_name][term_name].buffer.reshape( circular_buffer.append(obs)
self._env.num_envs, -1 elif circular_buffer._buffer is None:
# because circular buffer only exits after the simulation steps,
# this guards history buffer from corruption by external calls before simulation start
circular_buffer = CircularBuffer(
max_len=circular_buffer.max_length,
batch_size=circular_buffer.batch_size,
device=circular_buffer.device,
) )
circular_buffer.append(obs)
if term_cfg.flatten_history_dim:
group_obs[term_name] = circular_buffer.buffer.reshape(self._env.num_envs, -1)
else: else:
group_obs[term_name] = self._group_obs_term_history_buffer[group_name][term_name].buffer group_obs[term_name] = circular_buffer.buffer
else: else:
group_obs[term_name] = obs group_obs[term_name] = obs
......
...@@ -23,6 +23,8 @@ import omni.usd ...@@ -23,6 +23,8 @@ import omni.usd
import pytest import pytest
from isaaclab.envs import ManagerBasedEnv, ManagerBasedEnvCfg from isaaclab.envs import ManagerBasedEnv, ManagerBasedEnvCfg
from isaaclab.managers import ObservationGroupCfg as ObsGroup
from isaaclab.managers import ObservationTermCfg as ObsTerm
from isaaclab.scene import InteractiveSceneCfg from isaaclab.scene import InteractiveSceneCfg
from isaaclab.utils import configclass from isaaclab.utils import configclass
...@@ -34,6 +36,22 @@ class EmptyManagerCfg: ...@@ -34,6 +36,22 @@ class EmptyManagerCfg:
pass pass
@configclass
class EmptyObservationWithHistoryCfg:
"""Empty observation with history specifications for the environment."""
@configclass
class EmptyObservationGroupWithHistoryCfg(ObsGroup):
"""Empty observation with history specifications for the environment."""
dummy_term: ObsTerm = ObsTerm(func=lambda env: torch.randn(env.num_envs, 1, device=env.device))
def __post_init__(self):
self.history_length = 5
empty_observation: EmptyObservationGroupWithHistoryCfg = EmptyObservationGroupWithHistoryCfg()
@configclass @configclass
class EmptySceneCfg(InteractiveSceneCfg): class EmptySceneCfg(InteractiveSceneCfg):
"""Configuration for an empty scene.""" """Configuration for an empty scene."""
...@@ -67,6 +85,32 @@ def get_empty_base_env_cfg(device: str = "cuda:0", num_envs: int = 1, env_spacin ...@@ -67,6 +85,32 @@ def get_empty_base_env_cfg(device: str = "cuda:0", num_envs: int = 1, env_spacin
return EmptyEnvCfg() return EmptyEnvCfg()
def get_empty_base_env_cfg_with_history(device: str = "cuda:0", num_envs: int = 1, env_spacing: float = 1.0):
"""Generate base environment config based on device"""
@configclass
class EmptyEnvWithHistoryCfg(ManagerBasedEnvCfg):
"""Configuration for the empty test environment."""
# Scene settings
scene: EmptySceneCfg = EmptySceneCfg(num_envs=num_envs, env_spacing=env_spacing)
# Basic settings
actions: EmptyManagerCfg = EmptyManagerCfg()
observations: EmptyObservationWithHistoryCfg = EmptyObservationWithHistoryCfg()
def __post_init__(self):
"""Post initialization."""
# step settings
self.decimation = 4 # env step every 4 sim steps: 200Hz / 4 = 50Hz
# simulation settings
self.sim.dt = 0.005 # sim step every 5ms: 200Hz
self.sim.render_interval = self.decimation # render every 4 sim steps
# pass device down from test
self.sim.device = device
return EmptyEnvWithHistoryCfg()
@pytest.mark.parametrize("device", ["cuda:0", "cpu"]) @pytest.mark.parametrize("device", ["cuda:0", "cpu"])
def test_initialization(device): def test_initialization(device):
"""Test initialization of ManagerBasedEnv.""" """Test initialization of ManagerBasedEnv."""
...@@ -90,3 +134,67 @@ def test_initialization(device): ...@@ -90,3 +134,67 @@ def test_initialization(device):
obs, ext = env.step(action=act) obs, ext = env.step(action=act)
# close the environment # close the environment
env.close() env.close()
@pytest.mark.parametrize("device", ["cuda:0", "cpu"])
def test_observation_history_changes_only_after_step(device):
"""Test observation history of ManagerBasedEnv.
The history buffer should only change after a step is taken.
"""
# create a new stage
omni.usd.get_context().new_stage()
# create environment with history length of 5
env = ManagerBasedEnv(cfg=get_empty_base_env_cfg_with_history(device=device))
# check if history buffer is empty
for group_name in env.observation_manager._group_obs_term_names:
group_term_names = env.observation_manager._group_obs_term_names[group_name]
for term_name in group_term_names:
torch.testing.assert_close(
env.observation_manager._group_obs_term_history_buffer[group_name][term_name].current_length,
torch.zeros((env.num_envs,), device=device, dtype=torch.int64),
)
# check if history buffer is empty after compute
env.observation_manager.compute()
for group_name in env.observation_manager._group_obs_term_names:
group_term_names = env.observation_manager._group_obs_term_names[group_name]
for term_name in group_term_names:
torch.testing.assert_close(
env.observation_manager._group_obs_term_history_buffer[group_name][term_name].current_length,
torch.zeros((env.num_envs,), device=device, dtype=torch.int64),
)
# check if history buffer is not empty after step
act = torch.randn_like(env.action_manager.action)
env.step(act)
group_obs = dict()
for group_name in env.observation_manager._group_obs_term_names:
group_term_names = env.observation_manager._group_obs_term_names[group_name]
group_obs[group_name] = dict()
for term_name in group_term_names:
torch.testing.assert_close(
env.observation_manager._group_obs_term_history_buffer[group_name][term_name].current_length,
torch.ones((env.num_envs,), device=device, dtype=torch.int64),
)
group_obs[group_name][term_name] = env.observation_manager._group_obs_term_history_buffer[group_name][
term_name
].buffer
# check if history buffer is not empty after compute and is the same as the buffer after step
env.observation_manager.compute()
for group_name in env.observation_manager._group_obs_term_names:
group_term_names = env.observation_manager._group_obs_term_names[group_name]
for term_name in group_term_names:
torch.testing.assert_close(
env.observation_manager._group_obs_term_history_buffer[group_name][term_name].current_length,
torch.ones((env.num_envs,), device=device, dtype=torch.int64),
)
assert torch.allclose(
group_obs[group_name][term_name],
env.observation_manager._group_obs_term_history_buffer[group_name][term_name].buffer,
)
# close the environment
env.close()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment