Unverified Commit aec72bdc authored by Sixiang Chen's avatar Sixiang Chen Committed by GitHub

Fix: observation space Dict for non-concatenated groups only keeps last term (#3134)

# Description

This PR fixes a bug in the observation space construction for
non-concatenated groups in `ManagerBasedRLEnv._configure_gym_env_spaces`
method. Previously, only the last term in each group was included in the
Dict, causing loss of observation information. Now, all terms are
correctly added to the group Dict.

Fixes #3133 

<!-- As a practice, it is recommended to open an issue to have
discussions on the proposed pull request.
This makes it easier for the community to keep track of what is being
developed or added, and if a given feature
is demanded by more than one party. -->

## Type of change

<!-- As you go through the list, delete the ones that are not
applicable. -->

- Bug fix (non-breaking change which fixes an issue)


<!--
Example:

| Before | After |
| ------ | ----- |
| _gif/png before_ | _gif/png after_ |

To upload images to a PR -- simply drag and drop an image while in edit
mode and it should upload the image directly. You can then paste that
source into the above before/after sections.
-->
parent 65d6087f
...@@ -117,6 +117,7 @@ Guidelines for modifications: ...@@ -117,6 +117,7 @@ Guidelines for modifications:
* Shafeef Omar * Shafeef Omar
* Shaoshu Su * Shaoshu Su
* Shaurya Dewan * Shaurya Dewan
* Sixiang Chen
* Shundo Kishi * Shundo Kishi
* Stefan Van de Mosselaer * Stefan Van de Mosselaer
* Stephan Pleines * Stephan Pleines
......
...@@ -334,12 +334,12 @@ class ManagerBasedRLEnv(ManagerBasedEnv, gym.Env): ...@@ -334,12 +334,12 @@ class ManagerBasedRLEnv(ManagerBasedEnv, gym.Env):
self.single_observation_space[group_name] = gym.spaces.Box(low=-np.inf, high=np.inf, shape=group_dim) self.single_observation_space[group_name] = gym.spaces.Box(low=-np.inf, high=np.inf, shape=group_dim)
else: else:
group_term_cfgs = self.observation_manager._group_obs_term_cfgs[group_name] group_term_cfgs = self.observation_manager._group_obs_term_cfgs[group_name]
term_dict = {}
for term_name, term_dim, term_cfg in zip(group_term_names, group_dim, group_term_cfgs): for term_name, term_dim, term_cfg in zip(group_term_names, group_dim, group_term_cfgs):
low = -np.inf if term_cfg.clip is None else term_cfg.clip[0] low = -np.inf if term_cfg.clip is None else term_cfg.clip[0]
high = np.inf if term_cfg.clip is None else term_cfg.clip[1] high = np.inf if term_cfg.clip is None else term_cfg.clip[1]
self.single_observation_space[group_name] = gym.spaces.Dict( term_dict[term_name] = gym.spaces.Box(low=low, high=high, shape=term_dim)
{term_name: gym.spaces.Box(low=low, high=high, shape=term_dim)} self.single_observation_space[group_name] = gym.spaces.Dict(term_dict)
)
# action space (unbounded since we don't impose any limits) # action space (unbounded since we don't impose any limits)
action_dim = sum(self.action_manager.action_term_dim) action_dim = sum(self.action_manager.action_term_dim)
self.single_action_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(action_dim,)) self.single_action_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(action_dim,))
......
...@@ -12,6 +12,7 @@ simulation_app = AppLauncher(headless=True, enable_cameras=True).app ...@@ -12,6 +12,7 @@ simulation_app = AppLauncher(headless=True, enable_cameras=True).app
import gymnasium as gym import gymnasium as gym
import numpy as np import numpy as np
import torch
import omni.usd import omni.usd
import pytest import pytest
...@@ -55,3 +56,86 @@ def test_obs_space_follows_clip_contraint(env_cfg_cls, device): ...@@ -55,3 +56,86 @@ def test_obs_space_follows_clip_contraint(env_cfg_cls, device):
assert np.all(term_space.high == high) assert np.all(term_space.high == high)
env.close() env.close()
@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_non_concatenated_obs_groups_contain_all_terms(device):
"""Test that non-concatenated observation groups contain all defined terms (issue #3133).
Before the fix, only the last term in each non-concatenated group would be present
in the observation space Dict. This test ensures all terms are correctly included.
"""
from isaaclab_tasks.manager_based.manipulation.stack.config.franka.stack_joint_pos_env_cfg import (
FrankaCubeStackEnvCfg,
)
# new USD stage
omni.usd.get_context().new_stage()
# configure the stack env - it has multiple non-concatenated observation groups
env_cfg = FrankaCubeStackEnvCfg()
env_cfg.scene.num_envs = 2 # keep num_envs small for testing
env_cfg.sim.device = device
env = ManagerBasedRLEnv(cfg=env_cfg)
# Verify that observation space is properly structured
assert isinstance(env.observation_space, gym.spaces.Dict), "Top-level observation space should be Dict"
# Test 'policy' group - should have 9 terms (not just the last one due to the bug)
assert "policy" in env.observation_space.spaces, "Policy group missing from observation space"
policy_space = env.observation_space.spaces["policy"]
assert isinstance(policy_space, gym.spaces.Dict), "Policy group should be Dict space"
expected_policy_terms = [
"actions",
"joint_pos",
"joint_vel",
"object",
"cube_positions",
"cube_orientations",
"eef_pos",
"eef_quat",
"gripper_pos",
]
# This is the key test - before the fix, only "gripper_pos" (last term) would be present
assert len(policy_space.spaces) == len(expected_policy_terms), (
f"Policy group should have {len(expected_policy_terms)} terms, got {len(policy_space.spaces)}:"
f" {list(policy_space.spaces.keys())}"
)
for term_name in expected_policy_terms:
assert term_name in policy_space.spaces, f"Term '{term_name}' missing from policy group"
assert isinstance(policy_space.spaces[term_name], gym.spaces.Box), f"Term '{term_name}' should be Box space"
# Test 'subtask_terms' group - should have 3 terms (not just the last one)
assert "subtask_terms" in env.observation_space.spaces, "Subtask_terms group missing from observation space"
subtask_space = env.observation_space.spaces["subtask_terms"]
assert isinstance(subtask_space, gym.spaces.Dict), "Subtask_terms group should be Dict space"
expected_subtask_terms = ["grasp_1", "stack_1", "grasp_2"]
# Before the fix, only "grasp_2" (last term) would be present
assert len(subtask_space.spaces) == len(expected_subtask_terms), (
f"Subtask_terms group should have {len(expected_subtask_terms)} terms, got {len(subtask_space.spaces)}:"
f" {list(subtask_space.spaces.keys())}"
)
for term_name in expected_subtask_terms:
assert term_name in subtask_space.spaces, f"Term '{term_name}' missing from subtask_terms group"
assert isinstance(subtask_space.spaces[term_name], gym.spaces.Box), f"Term '{term_name}' should be Box space"
# Test that we can get observations and they match the space structure
env.reset()
action = torch.tensor(env.action_space.sample(), device=env.device)
obs, reward, terminated, truncated, info = env.step(action)
# Verify all terms are present in actual observations
for term_name in expected_policy_terms:
assert term_name in obs["policy"], f"Term '{term_name}' missing from policy observation"
for term_name in expected_subtask_terms:
assert term_name in obs["subtask_terms"], f"Term '{term_name}' missing from subtask_terms observation"
env.close()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment