Unverified Commit e5584baf authored by Ziwen Zhuang's avatar Ziwen Zhuang Committed by GitHub

Fixes the issue of using Modifiers and history buffer together (#2461)

# Description

The modifier should get `obs_dims` before modified by history_buffer
initialization, since Modifiers are called before the history buffer.

Fixes #2460

## Type of change

- Bug fix (non-breaking change which fixes an issue)

## Screenshots

<img width="1594" alt="screenshots2025-05-10 14 19 57"
src="https://github.com/user-attachments/assets/6d1d65e4-3ac4-4842-8a14-5cafb5521045"
/>

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./isaaclab.sh --format`
- [ ] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [x] I have added tests that prove my fix is effective or that my
feature works
- [x] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there
Co-authored-by: 's avatarJames Tigue <166445701+jtigue-bdai@users.noreply.github.com>
parent 26785e42
...@@ -442,19 +442,6 @@ class ObservationManager(ManagerBase): ...@@ -442,19 +442,6 @@ class ObservationManager(ManagerBase):
# call function the first time to fill up dimensions # call function the first time to fill up dimensions
obs_dims = tuple(term_cfg.func(self._env, **term_cfg.params).shape) obs_dims = tuple(term_cfg.func(self._env, **term_cfg.params).shape)
# create history buffers and calculate history term dimensions
if term_cfg.history_length > 0:
group_entry_history_buffer[term_name] = CircularBuffer(
max_len=term_cfg.history_length, batch_size=self._env.num_envs, device=self._env.device
)
old_dims = list(obs_dims)
old_dims.insert(1, term_cfg.history_length)
obs_dims = tuple(old_dims)
if term_cfg.flatten_history_dim:
obs_dims = (obs_dims[0], np.prod(obs_dims[1:]))
self._group_obs_term_dim[group_name].append(obs_dims[1:])
# if scale is set, check if single float or tuple # if scale is set, check if single float or tuple
if term_cfg.scale is not None: if term_cfg.scale is not None:
if not isinstance(term_cfg.scale, (float, int, tuple)): if not isinstance(term_cfg.scale, (float, int, tuple)):
...@@ -518,6 +505,19 @@ class ObservationManager(ManagerBase): ...@@ -518,6 +505,19 @@ class ObservationManager(ManagerBase):
f" and optional parameters: {args_with_defaults}, but received: {term_params}." f" and optional parameters: {args_with_defaults}, but received: {term_params}."
) )
# create history buffers and calculate history term dimensions
if term_cfg.history_length > 0:
group_entry_history_buffer[term_name] = CircularBuffer(
max_len=term_cfg.history_length, batch_size=self._env.num_envs, device=self._env.device
)
old_dims = list(obs_dims)
old_dims.insert(1, term_cfg.history_length)
obs_dims = tuple(old_dims)
if term_cfg.flatten_history_dim:
obs_dims = (obs_dims[0], np.prod(obs_dims[1:]))
self._group_obs_term_dim[group_name].append(obs_dims[1:])
# add term in a separate list if term is a class # add term in a separate list if term is a class
if isinstance(term_cfg.func, ManagerTermBase): if isinstance(term_cfg.func, ManagerTermBase):
self._group_obs_class_term_cfgs[group_name].append(term_cfg) self._group_obs_class_term_cfgs[group_name].append(term_cfg)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment