Unverified Commit c438fe97 authored by Pascal Roth's avatar Pascal Roth Committed by GitHub

Fixes configclass behavior to support user-defined post-init call (#114)

# Description

Previously, the configclass was overwriting a user-defined configclass
which isn't desirable. This MR adds a function to allow configclasses to
support
[`__post_init__`](https://docs.python.org/3/library/dataclasses.html#post-init-processing).

## Type of change

- Bug fix (non-breaking change which fixes an issue)

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./orbit.sh --format`
- [ ] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [x] I have added tests that prove my fix is effective or that my
feature works
- [x] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file

---------
Co-authored-by: 's avatarNikita Rudin <nrudin@nvidia.com>
Co-authored-by: 's avatarMayank Mittal <mittalma@leggedrobotics.com>
parent 16c740f5
[package] [package]
# Note: Semantic Versioning is used: https://semver.org/ # Note: Semantic Versioning is used: https://semver.org/
version = "0.8.6" version = "0.8.7"
# Description # Description
title = "ORBIT framework for Robot Learning" title = "ORBIT framework for Robot Learning"
......
Changelog Changelog
--------- ---------
0.8.7 (2023-08-03)
~~~~~~~~~~~~~~~~~~
Fixed
^^^^^
* Added support for `__post_init__ <https://docs.python.org/3/library/dataclasses.html#post-init-processing>`_ in
the :class:`omni.isaac.orbit.utils.configclass` decorator.
0.8.6 (2023-08-03) 0.8.6 (2023-08-03)
~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~
......
...@@ -74,7 +74,11 @@ def configclass(cls, **kwargs): ...@@ -74,7 +74,11 @@ def configclass(cls, **kwargs):
# add field factory # add field factory
_process_mutable_types(cls) _process_mutable_types(cls)
# copy mutable members # copy mutable members
setattr(cls, "__post_init__", _custom_post_init) # note: we check if user defined __post_init__ function exists and augment it with our own
if hasattr(cls, "__post_init__"):
setattr(cls, "__post_init__", _combined_function(cls.__post_init__, _custom_post_init))
else:
setattr(cls, "__post_init__", _custom_post_init)
# add helper functions for dictionary conversion # add helper functions for dictionary conversion
setattr(cls, "to_dict", _class_to_dict) setattr(cls, "to_dict", _class_to_dict)
setattr(cls, "from_dict", _update_class_from_dict) setattr(cls, "from_dict", _update_class_from_dict)
...@@ -316,6 +320,25 @@ def _custom_post_init(obj): ...@@ -316,6 +320,25 @@ def _custom_post_init(obj):
setattr(obj, key, deepcopy(value)) setattr(obj, key, deepcopy(value))
def _combined_function(f1: Callable, f2: Callable) -> Callable:
"""Combine two functions into one.
Args:
f1 (Callable): The first function.
f2 (Callable): The second function.
Returns:
Callable: The combined function.
"""
def _combined(*args, **kwargs):
# call both functions
f1(*args, **kwargs)
f2(*args, **kwargs)
return _combined
""" """
Helper functions Helper functions
""" """
......
...@@ -100,6 +100,19 @@ class BasicDemoCfg: ...@@ -100,6 +100,19 @@ class BasicDemoCfg:
robot_default_state: RobotDefaultStateCfg = RobotDefaultStateCfg() robot_default_state: RobotDefaultStateCfg = RobotDefaultStateCfg()
@configclass
class BasicDemoPostInitCfg:
"""Dummy configuration class."""
device_id: int = 0
env: EnvCfg = EnvCfg()
robot_default_state: RobotDefaultStateCfg = RobotDefaultStateCfg()
def __post_init__(self):
self.device_id = 1
self.add_variable = 3
""" """
Dummy configuration to check type annotations ordering. Dummy configuration to check type annotations ordering.
""" """
...@@ -143,6 +156,7 @@ class ParentDemoCfg: ...@@ -143,6 +156,7 @@ class ParentDemoCfg:
b = 2 # type annotation missing on purpose b = 2 # type annotation missing on purpose
c: RobotDefaultStateCfg = MISSING # add new missing field c: RobotDefaultStateCfg = MISSING # add new missing field
j: List[str] = MISSING # add new missing field j: List[str] = MISSING # add new missing field
i: List[str] = MISSING # add new missing field
func: Callable = MISSING # add new missing field func: Callable = MISSING # add new missing field
...@@ -160,6 +174,10 @@ class ChildDemoCfg(ParentDemoCfg): ...@@ -160,6 +174,10 @@ class ChildDemoCfg(ParentDemoCfg):
dummy_class = DummyClass dummy_class = DummyClass
def __post_init__(self):
self.b = 3 # change value of existing field
self.i = ["a", "b"] # change value of existing field
@configclass @configclass
class ChildChildDemoCfg(ChildDemoCfg): class ChildChildDemoCfg(ChildDemoCfg):
...@@ -168,6 +186,12 @@ class ChildChildDemoCfg(ChildDemoCfg): ...@@ -168,6 +186,12 @@ class ChildChildDemoCfg(ChildDemoCfg):
func_2 = dummy_function2 func_2 = dummy_function2
d = 2 # set default value for missing field d = 2 # set default value for missing field
def __post_init__(self):
"""Post initialization function."""
super().__post_init__()
self.b = 4 # set default value for missing field
self.f = "new" # add new missing field
""" """
Configuration with class inside. Configuration with class inside.
...@@ -211,6 +235,9 @@ class OutsideClassCfg: ...@@ -211,6 +235,9 @@ class OutsideClassCfg:
inside: InsideClassCfg = InsideClassCfg() inside: InsideClassCfg = InsideClassCfg()
x: int = 20 x: int = 20
def __post_init__(self):
self.inside.b = "dummy_changed"
""" """
Dummy configuration: Functions Dummy configuration: Functions
...@@ -252,6 +279,18 @@ basic_demo_cfg_change_correct = { ...@@ -252,6 +279,18 @@ basic_demo_cfg_change_correct = {
"device_id": 0, "device_id": 0,
} }
basic_demo_post_init_cfg_correct = {
"env": {"num_envs": 56, "episode_length": 2000, "viewer": {"eye": [7.5, 7.5, 7.5], "lookat": [0.0, 0.0, 0.0]}},
"robot_default_state": {
"pos": (0.0, 0.0, 0.0),
"rot": (1.0, 0.0, 0.0, 0.0),
"dof_pos": (0.0, 0.0, 0.0, 0.0, 0.0, 0.0),
"dof_vel": [0.0, 0.0, 0.0, 0.0, 0.0, 1.0],
},
"device_id": 1,
"add_variable": 3,
}
""" """
Test solutions: Functions Test solutions: Functions
""" """
...@@ -356,6 +395,10 @@ class TestConfigClass(unittest.TestCase): ...@@ -356,6 +395,10 @@ class TestConfigClass(unittest.TestCase):
cfg.from_dict(cfg_dict) cfg.from_dict(cfg_dict)
self.assertDictEqual(cfg.to_dict(), basic_demo_cfg_change_correct) self.assertDictEqual(cfg.to_dict(), basic_demo_cfg_change_correct)
def test_config_update_dict_using_post_init(self):
cfg = BasicDemoPostInitCfg()
self.assertDictEqual(cfg.to_dict(), basic_demo_post_init_cfg_correct)
def test_invalid_update_key(self): def test_invalid_update_key(self):
"""Test invalid key update.""" """Test invalid key update."""
cfg = BasicDemoCfg() cfg = BasicDemoCfg()
...@@ -542,10 +585,13 @@ class TestConfigClass(unittest.TestCase): ...@@ -542,10 +585,13 @@ class TestConfigClass(unittest.TestCase):
self.assertEqual(cfg.func, dummy_function1) self.assertEqual(cfg.func, dummy_function1)
self.assertEqual(cfg.a, 20) self.assertEqual(cfg.a, 20)
self.assertEqual(cfg.b, 2)
self.assertEqual(cfg.d, 3) self.assertEqual(cfg.d, 3)
self.assertEqual(cfg.j, ["c", "d"]) self.assertEqual(cfg.j, ["c", "d"])
# check post init
self.assertEqual(cfg.b, 3)
self.assertEqual(cfg.i, ["a", "b"])
def test_config_double_inheritance(self): def test_config_double_inheritance(self):
"""Tests that inheritance works properly when inheriting twice.""" """Tests that inheritance works properly when inheriting twice."""
# check variables # check variables
...@@ -554,10 +600,14 @@ class TestConfigClass(unittest.TestCase): ...@@ -554,10 +600,14 @@ class TestConfigClass(unittest.TestCase):
self.assertEqual(cfg.func, dummy_function1) self.assertEqual(cfg.func, dummy_function1)
self.assertEqual(cfg.func_2, dummy_function2) self.assertEqual(cfg.func_2, dummy_function2)
self.assertEqual(cfg.a, 20) self.assertEqual(cfg.a, 20)
self.assertEqual(cfg.b, 2)
self.assertEqual(cfg.d, 3) self.assertEqual(cfg.d, 3)
self.assertEqual(cfg.j, ["c", "d"]) self.assertEqual(cfg.j, ["c", "d"])
# check post init
self.assertEqual(cfg.b, 4)
self.assertEqual(cfg.f, "new")
self.assertEqual(cfg.i, ["a", "b"])
def test_config_with_class_type(self): def test_config_with_class_type(self):
"""Tests that configclass works properly with class type.""" """Tests that configclass works properly with class type."""
...@@ -587,7 +637,7 @@ class TestConfigClass(unittest.TestCase): ...@@ -587,7 +637,7 @@ class TestConfigClass(unittest.TestCase):
self.assertNotIn("InsideInsideClassCfg", cfg.inside.__annotations__) self.assertNotIn("InsideInsideClassCfg", cfg.inside.__annotations__)
# check values # check values
self.assertEqual(cfg.inside.class_name, DummyClass) self.assertEqual(cfg.inside.class_name, DummyClass)
self.assertEqual(cfg.inside.b, "dummy") self.assertEqual(cfg.inside.b, "dummy_changed")
self.assertEqual(cfg.x, 20) self.assertEqual(cfg.x, 20)
def test_config_dumping(self): def test_config_dumping(self):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment