Unverified Commit ad4ec6e5 authored by Kelly Guo's avatar Kelly Guo Committed by GitHub

Fixes updating configclass parameter with a list of objects (#847)

# Description

When configclass dicts are nested inside lists, the list is treated as
an Iterable object and assigned directly to the outer configclass when
updating configclass data with dicts. This overwrites the configclass
object in the list with a dict object and causes undesired behavior.

This change checks for nested dictionaries inside Iterables and updates
the values inside the dictionary individually without overwiting the
full Iterable.

Fixes https://github.com/isaac-sim/IsaacLab/issues/843

## Type of change

- Bug fix (non-breaking change which fixes an issue)


## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./isaaclab.sh --format`
- [x] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [x] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there
parent d906c4a0
......@@ -95,6 +95,16 @@ def update_class_from_dict(obj, data: dict[str, Any], _ns: str = "") -> None:
)
if isinstance(obj_mem, tuple):
value = tuple(value)
else:
set_obj = True
# recursively call if iterable contains dictionaries
for i in range(len(obj_mem)):
if isinstance(value[i], dict):
update_class_from_dict(obj_mem[i], value[i], _ns=key_ns)
set_obj = False
# do not set value to obj, otherwise it overwrites the cfg class with the dict
if not set_obj:
continue
elif callable(obj_mem):
# update function name
value = string_to_callable(value)
......
......@@ -298,10 +298,11 @@ Dummy configuration: Nested dictionaries
@configclass
class NestedDictCfg:
"""Dummy configuration class with nested dictionaries."""
class NestedDictAndListCfg:
"""Dummy configuration class with nested dictionaries and lists."""
dict_1: dict = {"dict_2": {"func": dummy_function1}}
list_1: list[EnvCfg] = [EnvCfg(), EnvCfg()]
"""
......@@ -341,10 +342,14 @@ basic_demo_cfg_change_with_none_correct = {
"device_id": 0,
}
basic_demo_cfg_nested_dict = {
basic_demo_cfg_nested_dict_and_list = {
"dict_1": {
"dict_2": {"func": dummy_function2},
},
"list_1": [
{"num_envs": 23, "episode_length": 3000, "viewer": {"eye": [5.0, 5.0, 5.0], "lookat": [0.0, 0.0, 0.0]}},
{"num_envs": 24, "episode_length": 2000, "viewer": {"eye": [6.0, 6.0, 6.0], "lookat": [0.0, 0.0, 0.0]}},
],
}
basic_demo_post_init_cfg_correct = {
......@@ -456,6 +461,10 @@ class TestConfigClass(unittest.TestCase):
update_class_from_dict(cfg, cfg_dict)
self.assertDictEqual(asdict(cfg), basic_demo_cfg_change_correct)
# check types are also correct
self.assertIsInstance(cfg.env.viewer, ViewerCfg)
self.assertIsInstance(cfg.env.viewer.eye, tuple)
def test_config_update_dict_with_none(self):
"""Test updating configclass using a dictionary that contains None."""
cfg = BasicDemoCfg()
......@@ -464,11 +473,23 @@ class TestConfigClass(unittest.TestCase):
self.assertDictEqual(asdict(cfg), basic_demo_cfg_change_with_none_correct)
def test_config_update_nested_dict(self):
"""Test updating configclass with sub-dictionnaries."""
cfg = NestedDictCfg()
cfg_dict = {"dict_1": {"dict_2": {"func": "__main__:dummy_function2"}}}
"""Test updating configclass with sub-dictionaries."""
cfg = NestedDictAndListCfg()
cfg_dict = {
"dict_1": {"dict_2": {"func": "__main__:dummy_function2"}},
"list_1": [
{"num_envs": 23, "episode_length": 3000, "viewer": {"eye": [5.0, 5.0, 5.0]}},
{"num_envs": 24, "viewer": {"eye": [6.0, 6.0, 6.0]}},
],
}
update_class_from_dict(cfg, cfg_dict)
self.assertDictEqual(asdict(cfg), basic_demo_cfg_nested_dict)
self.assertDictEqual(asdict(cfg), basic_demo_cfg_nested_dict_and_list)
# check types are also correct
self.assertIsInstance(cfg.list_1[0], EnvCfg)
self.assertIsInstance(cfg.list_1[1], EnvCfg)
self.assertIsInstance(cfg.list_1[0].viewer, ViewerCfg)
self.assertIsInstance(cfg.list_1[1].viewer, ViewerCfg)
def test_config_update_dict_using_internal(self):
"""Test updating configclass from a dictionary using configclass method."""
......
......@@ -74,7 +74,7 @@ def hydra_task_config(task_name: str, agent_cfg_entry_point: str) -> Callable:
# register the task to Hydra
env_cfg, agent_cfg = register_task_to_hydra(task_name, agent_cfg_entry_point)
# define thr new Hydra main function
# define the new Hydra main function
@hydra.main(config_path=None, config_name=task_name, version_base="1.3")
def hydra_main(hydra_env_cfg: DictConfig, env_cfg=env_cfg, agent_cfg=agent_cfg):
# convert to a native dictionary
......
......@@ -16,11 +16,50 @@ simulation_app = app_launcher.app
"""Rest everything follows."""
import functools
import unittest
from collections.abc import Callable
import hydra
from hydra import compose, initialize
from omegaconf import OmegaConf
from omni.isaac.lab.utils import replace_strings_with_slices
import omni.isaac.lab_tasks # noqa: F401
from omni.isaac.lab_tasks.utils.hydra import hydra_task_config
from omni.isaac.lab_tasks.utils.hydra import register_task_to_hydra
def hydra_task_config_test(task_name: str, agent_cfg_entry_point: str) -> Callable:
"""Copied from hydra.py hydra_task_config, since hydra.main requires a single point of entry,
which will not work with multiple tests. Here, we replace hydra.main with hydra initialize
and compose."""
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
# register the task to Hydra
env_cfg, agent_cfg = register_task_to_hydra(task_name, agent_cfg_entry_point)
# replace hydra.main with initialize and compose
with initialize(config_path=None, version_base="1.3"):
hydra_env_cfg = compose(config_name=task_name, overrides=sys.argv[1:])
# convert to a native dictionary
hydra_env_cfg = OmegaConf.to_container(hydra_env_cfg, resolve=True)
# replace string with slices because OmegaConf does not support slices
hydra_env_cfg = replace_strings_with_slices(hydra_env_cfg)
# update the configs with the Hydra command line arguments
env_cfg.from_dict(hydra_env_cfg["env"])
if isinstance(agent_cfg, dict):
agent_cfg = hydra_env_cfg["agent"]
else:
agent_cfg.from_dict(hydra_env_cfg["agent"])
# call the original function
func(env_cfg, agent_cfg, *args, **kwargs)
return wrapper
return decorator
class TestHydra(unittest.TestCase):
......@@ -39,7 +78,7 @@ class TestHydra(unittest.TestCase):
"agent.max_iterations=3", # test simple agent modification
]
@hydra_task_config("Isaac-Velocity-Flat-H1-v0", "rsl_rl_cfg_entry_point")
@hydra_task_config_test("Isaac-Velocity-Flat-H1-v0", "rsl_rl_cfg_entry_point")
def main(env_cfg, agent_cfg, self):
# env
self.assertEqual(env_cfg.decimation, 42)
......@@ -50,6 +89,23 @@ class TestHydra(unittest.TestCase):
self.assertEqual(agent_cfg.max_iterations, 3)
main(self)
# clean up
sys.argv = [sys.argv[0]]
hydra.core.global_hydra.GlobalHydra.instance().clear()
def test_nested_iterable_dict(self):
"""Test the hydra configuration system when dict is nested in an Iterable."""
@hydra_task_config_test("Isaac-Lift-Cube-Franka-v0", "rsl_rl_cfg_entry_point")
def main(env_cfg, agent_cfg, self):
# env
self.assertEqual(env_cfg.scene.ee_frame.target_frames[0].name, "end_effector")
self.assertEqual(env_cfg.scene.ee_frame.target_frames[0].offset.pos[2], 0.1034)
main(self)
# clean up
sys.argv = [sys.argv[0]]
hydra.core.global_hydra.GlobalHydra.instance().clear()
if __name__ == "__main__":
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment