Unverified Commit d1ecc377 authored by Patrick Yin, committed by GitHub

Fixes ObservationManager history buffer corrupted by external calls to...

Fixes ObservationManager history buffer corrupted by external calls to ObservationManager.compute (#2885)

# Description

When an observation group has a history length greater than zero, calling
`ObservationManager.compute` modifies the history state by appending the
current observation to the history. This corrupts the history when
non-`ManagerBasedEnv` classes invoke `ObservationManager.compute`. This
PR introduces an `update_history` flag (defaulting to `False`); only
`ManagerBasedEnv` calls `ObservationManager.compute` with
`update_history=True`. If `update_history=False` and the history
buffer has not been initialized yet (it is `None`), the observation is
written into a temporary buffer and that copy is returned, leaving the
original history untouched.

I have added test cases to verify this fix is effective.

Fixes #2884 

## Type of change

- Bug fix (non-breaking change which fixes an issue)

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./isaaclab.sh --format`
- [ ] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [x] I have added tests that prove my fix is effective or that my
feature works
- [x] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there

---------
Signed-off-by: ooctipus <zhengyuz@nvidia.com>
Signed-off-by: Kelly Guo <kellyg@nvidia.com>
Co-authored-by: ooctipus <zhengyuz@nvidia.com>
Co-authored-by: Kelly Guo <kellyg@nvidia.com>
parent 8e57a3a6
......@@ -97,6 +97,7 @@ Guidelines for modifications:
* Ori Gadot
* Oyindamola Omotuyi
* Özhan Özen
* Patrick Yin
* Peter Du
* Pulkit Goyal
* Qian Wan
......
[package]
# Note: Semantic Versioning is used: https://semver.org/
version = "0.40.17"
version = "0.40.18"
# Description
title = "Isaac Lab framework for Robot Learning"
......
Changelog
---------
0.40.18 (2025-07-09)
~~~~~~~~~~~~~~~~~~~~
Added
^^^^^
* Added input param ``update_history`` to :meth:`~isaaclab.managers.ObservationManager.compute`
to control whether the history buffer should be updated.
* Added unit test for :class:`~isaaclab.envs.ManagerBasedEnv`.
Fixed
^^^^^
* Fixed :class:`~isaaclab.envs.ManagerBasedEnv` and :class:`~isaaclab.envs.ManagerBasedRLEnv` to not update the history
buffer on recording.
0.40.17 (2025-07-10)
~~~~~~~~~~~~~~~~~~~~
......
......@@ -305,7 +305,7 @@ class ManagerBasedEnv:
self.recorder_manager.record_post_reset(env_ids)
# compute observations
self.obs_buf = self.observation_manager.compute()
self.obs_buf = self.observation_manager.compute(update_history=True)
if self.cfg.wait_for_textures and self.sim.has_rtx_sensors():
while SimulationManager.assets_loading():
......@@ -365,7 +365,7 @@ class ManagerBasedEnv:
self.recorder_manager.record_post_reset(env_ids)
# compute observations
self.obs_buf = self.observation_manager.compute()
self.obs_buf = self.observation_manager.compute(update_history=True)
# return observations
return self.obs_buf, self.extras
......@@ -416,7 +416,7 @@ class ManagerBasedEnv:
self.event_manager.apply(mode="interval", dt=self.step_dt)
# -- compute observations
self.obs_buf = self.observation_manager.compute()
self.obs_buf = self.observation_manager.compute(update_history=True)
self.recorder_manager.record_post_step()
# return observations and extras
......
......@@ -237,7 +237,7 @@ class ManagerBasedRLEnv(ManagerBasedEnv, gym.Env):
self.event_manager.apply(mode="interval", dt=self.step_dt)
# -- compute observations
# note: done after reset to get the correct observations for reset envs
self.obs_buf = self.observation_manager.compute()
self.obs_buf = self.observation_manager.compute(update_history=True)
# return observations, rewards, resets and extras
return self.obs_buf, self.reward_buf, self.reset_terminated, self.reset_time_outs, self.extras
......
......@@ -245,12 +245,17 @@ class ObservationManager(ManagerBase):
# nothing to log here
return {}
def compute(self) -> dict[str, torch.Tensor | dict[str, torch.Tensor]]:
def compute(self, update_history: bool = False) -> dict[str, torch.Tensor | dict[str, torch.Tensor]]:
"""Compute the observations per group for all groups.
The method computes the observations for all the groups handled by the observation manager.
Please check the :meth:`compute_group` on the processing of observations per group.
Args:
update_history: Whether the returned observations should be appended to the observation history.
Defaults to False, in which case calling compute_group does not modify the history. This input is a no-op
if the group's history_length == 0.
Returns:
A dictionary with keys as the group names and values as the computed observations.
The observations are either concatenated into a single tensor or returned as a dictionary
......@@ -260,14 +265,14 @@ class ObservationManager(ManagerBase):
obs_buffer = dict()
# iterate over all the terms in each group
for group_name in self._group_obs_term_names:
obs_buffer[group_name] = self.compute_group(group_name)
obs_buffer[group_name] = self.compute_group(group_name, update_history=update_history)
# otherwise return a dict with observations of all groups
# Cache the observations.
self._obs_buffer = obs_buffer
return obs_buffer
def compute_group(self, group_name: str) -> torch.Tensor | dict[str, torch.Tensor]:
def compute_group(self, group_name: str, update_history: bool = False) -> torch.Tensor | dict[str, torch.Tensor]:
"""Computes the observations for a given group.
The observations for a given group are computed by calling the registered functions for each
......@@ -290,6 +295,9 @@ class ObservationManager(ManagerBase):
Args:
group_name: The name of the group for which to compute the observations. Defaults to None,
in which case observations for all the groups are computed and returned.
update_history: Whether the returned observations should be appended to the observation group's history.
Defaults to False, in which case calling compute_group does not modify the history. This input is a no-op
if the group's history_length == 0.
Returns:
Depending on the group's configuration, the tensors for individual observation terms are
......@@ -330,13 +338,23 @@ class ObservationManager(ManagerBase):
obs = obs.mul_(term_cfg.scale)
# Update the history buffer if observation term has history enabled
if term_cfg.history_length > 0:
self._group_obs_term_history_buffer[group_name][term_name].append(obs)
if term_cfg.flatten_history_dim:
group_obs[term_name] = self._group_obs_term_history_buffer[group_name][term_name].buffer.reshape(
self._env.num_envs, -1
circular_buffer = self._group_obs_term_history_buffer[group_name][term_name]
if update_history:
circular_buffer.append(obs)
elif circular_buffer._buffer is None:
# because the circular buffer only exists after the simulation steps,
# this guards the history buffer from corruption by external calls before the simulation starts
circular_buffer = CircularBuffer(
max_len=circular_buffer.max_length,
batch_size=circular_buffer.batch_size,
device=circular_buffer.device,
)
circular_buffer.append(obs)
if term_cfg.flatten_history_dim:
group_obs[term_name] = circular_buffer.buffer.reshape(self._env.num_envs, -1)
else:
group_obs[term_name] = self._group_obs_term_history_buffer[group_name][term_name].buffer
group_obs[term_name] = circular_buffer.buffer
else:
group_obs[term_name] = obs
......
......@@ -23,6 +23,8 @@ import omni.usd
import pytest
from isaaclab.envs import ManagerBasedEnv, ManagerBasedEnvCfg
from isaaclab.managers import ObservationGroupCfg as ObsGroup
from isaaclab.managers import ObservationTermCfg as ObsTerm
from isaaclab.scene import InteractiveSceneCfg
from isaaclab.utils import configclass
......@@ -34,6 +36,22 @@ class EmptyManagerCfg:
pass
@configclass
class EmptyObservationWithHistoryCfg:
    """Observation configuration containing a single group that has history enabled."""

    @configclass
    class EmptyObservationGroupWithHistoryCfg(ObsGroup):
        """Observation group with one placeholder term and a history length of 5."""

        # Placeholder term: returns a freshly sampled ``torch.randn`` tensor of shape
        # ``(env.num_envs, 1)`` on ``env.device`` each time it is called.
        dummy_term: ObsTerm = ObsTerm(
            func=lambda env: torch.randn(env.num_envs, 1, device=env.device),
        )

        def __post_init__(self):
            """Set the history length of the group to 5."""
            self.history_length = 5

    # The only observation group declared by this configuration.
    empty_observation: EmptyObservationGroupWithHistoryCfg = EmptyObservationGroupWithHistoryCfg()
@configclass
class EmptySceneCfg(InteractiveSceneCfg):
"""Configuration for an empty scene."""
......@@ -67,6 +85,32 @@ def get_empty_base_env_cfg(device: str = "cuda:0", num_envs: int = 1, env_spacin
return EmptyEnvCfg()
def get_empty_base_env_cfg_with_history(device: str = "cuda:0", num_envs: int = 1, env_spacing: float = 1.0):
    """Build the configuration of an empty environment whose observations use history.

    Args:
        device: Device assigned to ``sim.device`` of the returned configuration.
        num_envs: Number of environments passed to the scene configuration.
        env_spacing: Environment spacing passed to the scene configuration.

    Returns:
        An instance of a locally defined :class:`ManagerBasedEnvCfg` subclass.
    """

    @configclass
    class EmptyEnvWithHistoryCfg(ManagerBasedEnvCfg):
        """Empty test environment configuration with history-enabled observations."""

        # -- scene
        scene: EmptySceneCfg = EmptySceneCfg(num_envs=num_envs, env_spacing=env_spacing)
        # -- managers
        actions: EmptyManagerCfg = EmptyManagerCfg()
        observations: EmptyObservationWithHistoryCfg = EmptyObservationWithHistoryCfg()

        def __post_init__(self):
            """Fill in the stepping and simulation settings."""
            # simulation runs at 200Hz, i.e. one sim step every 5ms
            self.sim.dt = 0.005
            # one env step every 4 sim steps: 200Hz / 4 = 50Hz
            self.decimation = 4
            # render once per env step (every 4 sim steps)
            self.sim.render_interval = self.decimation
            # device is forwarded from the calling test
            self.sim.device = device

    return EmptyEnvWithHistoryCfg()
@pytest.mark.parametrize("device", ["cuda:0", "cpu"])
def test_initialization(device):
"""Test initialization of ManagerBasedEnv."""
......@@ -90,3 +134,67 @@ def test_initialization(device):
obs, ext = env.step(action=act)
# close the environment
env.close()
@pytest.mark.parametrize("device", ["cuda:0", "cpu"])
def test_observation_history_changes_only_after_step(device):
    """Test observation history of ManagerBasedEnv.

    The history buffer should only change after a step is taken: calling
    ``ObservationManager.compute`` directly (with the default ``update_history=False``)
    must leave both the history length and the history contents untouched.
    """

    def _iter_history_buffers(obs_manager):
        """Yield ``(group_name, term_name, history_buffer)`` for every observation term."""
        for group_name, term_names in obs_manager._group_obs_term_names.items():
            for term_name in term_names:
                yield group_name, term_name, obs_manager._group_obs_term_history_buffer[group_name][term_name]

    def _assert_history_length(env, expected_length):
        """Check that every term's history buffer reports ``expected_length`` for all envs."""
        expected = torch.full((env.num_envs,), expected_length, device=device, dtype=torch.int64)
        for _, _, history_buffer in _iter_history_buffers(env.observation_manager):
            torch.testing.assert_close(history_buffer.current_length, expected)

    # create a new stage
    omni.usd.get_context().new_stage()
    # create environment with history length of 5
    env = ManagerBasedEnv(cfg=get_empty_base_env_cfg_with_history(device=device))
    # make sure the environment is closed even when an assertion below fails,
    # so a failing run does not leave it open for the next parametrized run
    try:
        obs_manager = env.observation_manager

        # check if history buffer is empty
        _assert_history_length(env, 0)

        # check if history buffer is still empty after an external compute
        obs_manager.compute()
        _assert_history_length(env, 0)

        # check if history buffer is not empty after step
        act = torch.randn_like(env.action_manager.action)
        env.step(act)
        _assert_history_length(env, 1)
        # remember what each history buffer holds right after the step
        history_after_step = {
            (group_name, term_name): history_buffer.buffer
            for group_name, term_name, history_buffer in _iter_history_buffers(obs_manager)
        }

        # check if history buffer is not empty after compute and is the same as the buffer after step
        obs_manager.compute()
        _assert_history_length(env, 1)
        for group_name, term_name, history_buffer in _iter_history_buffers(obs_manager):
            assert torch.allclose(history_after_step[(group_name, term_name)], history_buffer.buffer)
    finally:
        # close the environment
        env.close()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment