Commit f9eca34c authored by Mayank Mittal's avatar Mayank Mittal

Adds support for class types in configclass (#92)

# Description

This MR adds checks to the `dataclass` wrapper called `configclass` to extend its support for types.  This supports type hinting annotations such as `type`, `Type[Myclass]`, and `ClassVar[type]`.

It also adds a method called `replace` to the configclass that calls the [`dataclasses.replace`](https://docs.python.org/3/library/dataclasses.html#dataclasses.replace) function. This has been added for the convenience of users.

## Type of change

- Bug fix (non-breaking change which fixes an issue)
- New feature (non-breaking change which adds functionality)

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with `./orbit.sh --format`
- [x] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [x] I have added tests that prove my fix is effective or that my feature works
- [x] I have updated the changelog and the corresponding version in the extension's `config/extension.toml` file
parent 82cbdc22
...@@ -187,7 +187,7 @@ html_show_sphinx = False ...@@ -187,7 +187,7 @@ html_show_sphinx = False
def skip_member(app, what, name, obj, skip, options): def skip_member(app, what, name, obj, skip, options):
exclusions = ["from_dict", "to_dict"] # List the names of the functions you want to skip here exclusions = ["from_dict", "to_dict", "replace"] # List the names of the functions you want to skip here
if name in exclusions: if name in exclusions:
return True return True
return None return None
......
[package] [package]
# Note: Semantic Versioning is used: https://semver.org/ # Note: Semantic Versioning is used: https://semver.org/
version = "0.7.1" version = "0.7.2"
# Description # Description
title = "ORBIT framework for Robot Learning" title = "ORBIT framework for Robot Learning"
......
Changelog Changelog
--------- ---------
0.7.1 (2023-07-10)
0.7.2 (2023-07-24)
~~~~~~~~~~~~~~~~~~
Added
^^^^^
* Added the method :meth:`replace` to the :class:`omni.isaac.orbit.utils.configclass` decorator to allow
creating a new configuration object with values replaced from keyword arguments. This function internally
calls the `dataclasses.replace <https://docs.python.org/3/library/dataclasses.html#dataclasses.replace>`_.
Fixed
^^^^^
* Fixed the handling of class types as member values in the :meth:`omni.isaac.orbit.utils.configclass`. Earlier it was
throwing an error since class types were skipped in the if-else block.
0.7.1 (2023-07-22)
~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~
Added Added
...@@ -11,7 +29,7 @@ Added ...@@ -11,7 +29,7 @@ Added
to the :mod:`omni.isaac.orbit.managers` module to handle termination, curriculum, and randomization respectively. to the :mod:`omni.isaac.orbit.managers` module to handle termination, curriculum, and randomization respectively.
0.7.0 (2023-07-10) 0.7.0 (2023-07-22)
~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~
Added Added
......
...@@ -7,8 +7,8 @@ ...@@ -7,8 +7,8 @@
from copy import deepcopy from copy import deepcopy
from dataclasses import MISSING, Field, dataclass, field from dataclasses import MISSING, Field, dataclass, field, replace
from typing import Any, Callable, ClassVar, Dict from typing import Any, Callable, ClassVar, Dict, Type
from .dict import class_to_dict, update_class_from_dict from .dict import class_to_dict, update_class_from_dict
...@@ -16,6 +16,9 @@ from .dict import class_to_dict, update_class_from_dict ...@@ -16,6 +16,9 @@ from .dict import class_to_dict, update_class_from_dict
__all__ = ["configclass"] __all__ = ["configclass"]
_CONFIGCLASS_METHODS = ["to_dict", "from_dict", "replace"]
"""List of class methods added at runtime to dataclass."""
""" """
Wrapper around dataclass. Wrapper around dataclass.
""" """
...@@ -75,6 +78,7 @@ def configclass(cls, **kwargs): ...@@ -75,6 +78,7 @@ def configclass(cls, **kwargs):
# add helper functions for dictionary conversion # add helper functions for dictionary conversion
setattr(cls, "to_dict", _class_to_dict) setattr(cls, "to_dict", _class_to_dict)
setattr(cls, "from_dict", _update_class_from_dict) setattr(cls, "from_dict", _update_class_from_dict)
setattr(cls, "replace", _replace_class_with_kwargs)
# wrap around dataclass # wrap around dataclass
cls = dataclass(cls, **kwargs) cls = dataclass(cls, **kwargs)
# return wrapped class # return wrapped class
...@@ -113,6 +117,30 @@ def _update_class_from_dict(obj, data: Dict[str, Any]) -> None: ...@@ -113,6 +117,30 @@ def _update_class_from_dict(obj, data: Dict[str, Any]) -> None:
return update_class_from_dict(obj, data, _ns="") return update_class_from_dict(obj, data, _ns="")
def _replace_class_with_kwargs(obj: object, **kwargs) -> object:
"""Return a new object replacing specified fields with new values.
This is especially useful for frozen classes. Example usage:
@configclass(frozen=True)
class C:
x: int
y: int
c = C(1, 2)
c1 = c.replace(x=3)
assert c1.x == 3 and c1.y == 2
Args:
obj (object): The object to replace.
**kwargs: The fields to replace and their new values.
Returns:
object: The new object.
"""
return replace(obj, **kwargs)
""" """
Private helper functions. Private helper functions.
""" """
...@@ -156,7 +184,7 @@ def _add_annotation_types(cls): ...@@ -156,7 +184,7 @@ def _add_annotation_types(cls):
if key.startswith("__"): if key.startswith("__"):
continue continue
# skip class functions # skip class functions
if key in ["from_dict", "to_dict"]: if key in _CONFIGCLASS_METHODS:
continue continue
# check if key is already present # check if key is already present
if key in hints: if key in hints:
...@@ -175,6 +203,10 @@ def _add_annotation_types(cls): ...@@ -175,6 +203,10 @@ def _add_annotation_types(cls):
) )
# add type annotation # add type annotation
hints[key] = type(value) hints[key] = type(value)
elif key != value.__name__:
# note: we don't want to add type annotations for nested configclass. Thus, we check if
# the name of the type matches the name of the variable.
hints[key] = Type[value]
# Note: Do not change this line. `cls.__dict__.get("__annotations__", {})` is different from # Note: Do not change this line. `cls.__dict__.get("__annotations__", {})` is different from
# `cls.__annotations__` because of inheritance. # `cls.__annotations__` because of inheritance.
...@@ -203,8 +235,12 @@ def _process_mutable_types(cls): ...@@ -203,8 +235,12 @@ def _process_mutable_types(cls):
If the function is NOT used, the following value-error is returned: If the function is NOT used, the following value-error is returned:
ValueError: mutable default <class 'list'> for field pos is not allowed: use default_factory ValueError: mutable default <class 'list'> for field pos is not allowed: use default_factory
""" """
class_members = {} # note: Need to set this up in the same order as annotations. Otherwise, it
# complains about missing positional arguments.
ann = cls.__dict__.get("__annotations__", {})
# iterate over all class members and store them in a dictionary # iterate over all class members and store them in a dictionary
class_members = {}
for base in reversed(cls.__mro__): for base in reversed(cls.__mro__):
# check if base is object # check if base is object
if base is object: if base is object:
...@@ -215,29 +251,27 @@ def _process_mutable_types(cls): ...@@ -215,29 +251,27 @@ def _process_mutable_types(cls):
if key.startswith("__"): if key.startswith("__"):
continue continue
# skip class functions # skip class functions
if key in ["from_dict", "to_dict"]: if key in _CONFIGCLASS_METHODS:
continue continue
# get class member # get class member
f = getattr(base, key) f = getattr(base, key)
# store class member # store class member if it is not a type or if it is already present in annotations
if not isinstance(f, type): if not isinstance(f, type) or key in ann:
class_members[key] = f class_members[key] = f
# iterate over base class data fields # iterate over base class data fields
# in previous call, things that became a dataclass field were removed from class members # in previous call, things that became a dataclass field were removed from class members
for key, f in base.__dict__.get("__dataclass_fields__", {}).items(): for key, f in base.__dict__.get("__dataclass_fields__", {}).items():
# store class member # store class member
if not isinstance(f, type) and key not in class_members:
class_members[key] = f class_members[key] = f
# note: Need to set this up in the same order as annotations. Otherwise, it
# complains about missing positional arguments.
ann = cls.__dict__.get("__annotations__", {})
# check that all annotations are present in class members # check that all annotations are present in class members
# note: mainly for debugging purposes # note: mainly for debugging purposes
if len(class_members) != len(ann): if len(class_members) != len(ann):
raise ValueError( raise ValueError(
f"Number of annotations ({len(ann)}) does not match number of class members ({len(class_members)})." f"In class '{cls.__name__}', number of annotations ({len(ann)}) does not match number of class "
" Please check that all class members have type annotations and a default value." f"members ({len(class_members)}). Please check that all class members have type annotations and/or "
" If you don't want to specify a default value, please use the literal `dataclasses.MISSING`." "a default value. If you don't want to specify a default value, please use the literal `dataclasses.MISSING`."
) )
# iterate over annotations and add field factory for mutable types # iterate over annotations and add field factory for mutable types
for key in ann: for key in ann:
......
...@@ -8,12 +8,58 @@ import os ...@@ -8,12 +8,58 @@ import os
import unittest import unittest
from dataclasses import MISSING, asdict, field from dataclasses import MISSING, asdict, field
from functools import wraps from functools import wraps
from typing import List from typing import ClassVar, List, Type
from omni.isaac.orbit.utils.configclass import configclass from omni.isaac.orbit.utils.configclass import configclass
from omni.isaac.orbit.utils.dict import class_to_dict, update_class_from_dict from omni.isaac.orbit.utils.dict import class_to_dict, update_class_from_dict
from omni.isaac.orbit.utils.io import dump_yaml, load_yaml from omni.isaac.orbit.utils.io import dump_yaml, load_yaml
"""
Mock classes and functions.
"""
def dummy_function1() -> int:
"""Dummy function 1."""
return 1
def dummy_function2() -> int:
"""Dummy function 2."""
return 2
def dummy_wrapper(func):
"""Decorator for wrapping function."""
@wraps(func)
def wrapper():
return func() + 1
return wrapper
@dummy_wrapper
def wrapped_dummy_function3():
"""Dummy function 3."""
return 3
@dummy_wrapper
def wrapped_dummy_function4():
"""Dummy function 4."""
return 4
class DummyClass:
"""Dummy class."""
def __init__(self):
"""Initialize dummy class."""
self.a = 1
self.b = 2
""" """
Dummy configuration: Basic Dummy configuration: Basic
""" """
...@@ -109,42 +155,55 @@ class ChildDemoCfg(ParentDemoCfg): ...@@ -109,42 +155,55 @@ class ChildDemoCfg(ParentDemoCfg):
k: List[str] = ["c", "d"] k: List[str] = ["c", "d"]
e: ViewerCfg = MISSING # add new missing field e: ViewerCfg = MISSING # add new missing field
dummy_class = DummyClass
""" """
Dummy configuration: Functions Configuration with class inside.
""" """
def dummy_function1() -> int: @configclass
"""Dummy function 1.""" class DummyClassCfg:
return 1 """Dummy class configuration with class type."""
class_name_1: type = DummyClass
class_name_2: Type[DummyClass] = DummyClass
class_name_3 = DummyClass
class_name_4: ClassVar[Type[DummyClass]] = DummyClass
def dummy_function2() -> int: b: str = "dummy"
"""Dummy function 2."""
return 2
def dummy_wrapper(func): """
"""Decorator for wrapping function.""" Configuration with nested classes.
"""
@wraps(func)
def wrapper():
return func() + 1
return wrapper @configclass
class OutsideClassCfg:
"""Outermost dummy configuration."""
@configclass
class InsideClassCfg:
"""Inner dummy configuration."""
@dummy_wrapper @configclass
def wrapped_dummy_function3(): class InsideInsideClassCfg:
"""Dummy function 3.""" """Dummy configuration with class type."""
return 3
u: List[int] = [1, 2, 3]
@dummy_wrapper class_name: type = DummyClass
def wrapped_dummy_function4(): b: str = "dummy"
"""Dummy function 4."""
return 4 inside: InsideClassCfg = InsideClassCfg()
x: int = 20
"""
Dummy configuration: Functions
"""
@configclass @configclass
...@@ -284,7 +343,6 @@ class TestConfigClass(unittest.TestCase): ...@@ -284,7 +343,6 @@ class TestConfigClass(unittest.TestCase):
cfg = BasicDemoCfg() cfg = BasicDemoCfg()
cfg_dict = {"env": {"num_envs": 22, "viewer": {"eye": (2.0, 2.0, 2.0)}}} cfg_dict = {"env": {"num_envs": 22, "viewer": {"eye": (2.0, 2.0, 2.0)}}}
cfg.from_dict(cfg_dict) cfg.from_dict(cfg_dict)
print("Updated config: ", cfg.to_dict())
self.assertDictEqual(cfg.to_dict(), basic_demo_cfg_change_correct) self.assertDictEqual(cfg.to_dict(), basic_demo_cfg_change_correct)
def test_invalid_update_key(self): def test_invalid_update_key(self):
...@@ -295,7 +353,7 @@ class TestConfigClass(unittest.TestCase): ...@@ -295,7 +353,7 @@ class TestConfigClass(unittest.TestCase):
update_class_from_dict(cfg, cfg_dict) update_class_from_dict(cfg, cfg_dict)
def test_multiple_instances(self): def test_multiple_instances(self):
"""Test multiple instances of the same configclass.""" """Test multiple instances with twice instantiation."""
# create two config instances # create two config instances
cfg1 = BasicDemoCfg() cfg1 = BasicDemoCfg()
cfg2 = BasicDemoCfg() cfg2 = BasicDemoCfg()
...@@ -308,6 +366,11 @@ class TestConfigClass(unittest.TestCase): ...@@ -308,6 +366,11 @@ class TestConfigClass(unittest.TestCase):
# immutable -- variables are the same # immutable -- variables are the same
self.assertEqual(id(cfg1.robot_default_state.dof_pos), id(cfg2.robot_default_state.dof_pos)) self.assertEqual(id(cfg1.robot_default_state.dof_pos), id(cfg2.robot_default_state.dof_pos))
self.assertEqual(id(cfg1.env.num_envs), id(cfg2.env.num_envs)) self.assertEqual(id(cfg1.env.num_envs), id(cfg2.env.num_envs))
self.assertEqual(id(cfg1.device_id), id(cfg2.device_id))
# check values
self.assertDictEqual(cfg1.env.to_dict(), cfg2.env.to_dict())
self.assertDictEqual(cfg1.robot_default_state.to_dict(), cfg2.robot_default_state.to_dict())
def test_alter_values_multiple_instances(self): def test_alter_values_multiple_instances(self):
"""Test alterations in multiple instances of the same configclass.""" """Test alterations in multiple instances of the same configclass."""
...@@ -331,6 +394,48 @@ class TestConfigClass(unittest.TestCase): ...@@ -331,6 +394,48 @@ class TestConfigClass(unittest.TestCase):
# immutable -- altered variables are different ids # immutable -- altered variables are different ids
self.assertNotEqual(id(cfg1.env.num_envs), id(cfg2.env.num_envs)) self.assertNotEqual(id(cfg1.env.num_envs), id(cfg2.env.num_envs))
def test_multiple_instances_with_replace(self):
"""Test multiple instances with creation through replace function."""
# create two config instances
cfg1 = BasicDemoCfg()
cfg2 = cfg1.replace()
# check variable IDs
# mutable -- variables should be different
self.assertNotEqual(id(cfg1.env.viewer.eye), id(cfg2.env.viewer.eye))
self.assertNotEqual(id(cfg1.env.viewer.lookat), id(cfg2.env.viewer.lookat))
self.assertNotEqual(id(cfg1.robot_default_state), id(cfg2.robot_default_state))
# immutable -- variables are the same
self.assertEqual(id(cfg1.robot_default_state.dof_pos), id(cfg2.robot_default_state.dof_pos))
self.assertEqual(id(cfg1.env.num_envs), id(cfg2.env.num_envs))
self.assertEqual(id(cfg1.device_id), id(cfg2.device_id))
# check values
self.assertDictEqual(cfg1.to_dict(), cfg2.to_dict())
def test_alter_values_multiple_instances_wth_replace(self):
"""Test alterations in multiple instances through replace function."""
# create two config instances
cfg1 = BasicDemoCfg()
cfg2 = cfg1.replace(device_id=1)
# alter configurations
cfg1.env.num_envs = 22 # immutable data: int
cfg1.env.viewer.eye[0] = 1.0 # mutable data: list
cfg1.env.viewer.lookat[2] = 12.0 # mutable data: list
# check variables
# values should be different
self.assertNotEqual(cfg1.env.num_envs, cfg2.env.num_envs)
self.assertNotEqual(cfg1.env.viewer.eye, cfg2.env.viewer.eye)
self.assertNotEqual(cfg1.env.viewer.lookat, cfg2.env.viewer.lookat)
# mutable -- variables are different ids
self.assertNotEqual(id(cfg1.env.viewer.eye), id(cfg2.env.viewer.eye))
self.assertNotEqual(id(cfg1.env.viewer.lookat), id(cfg2.env.viewer.lookat))
# immutable -- altered variables are different ids
self.assertNotEqual(id(cfg1.env.num_envs), id(cfg2.env.num_envs))
self.assertNotEqual(id(cfg1.device_id), id(cfg2.device_id))
def test_configclass_type_ordering(self): def test_configclass_type_ordering(self):
"""Checks ordering of config objects when no type annotation is provided.""" """Checks ordering of config objects when no type annotation is provided."""
...@@ -346,12 +451,14 @@ class TestConfigClass(unittest.TestCase): ...@@ -346,12 +451,14 @@ class TestConfigClass(unittest.TestCase):
def test_functions_config(self): def test_functions_config(self):
"""Tests having functions as values in the configuration instance.""" """Tests having functions as values in the configuration instance."""
cfg = FunctionsDemoCfg() cfg = FunctionsDemoCfg()
# check types
self.assertEqual(cfg.__annotations__["func"], type(dummy_function1))
self.assertEqual(cfg.__annotations__["wrapped_func"], type(wrapped_dummy_function3))
self.assertEqual(cfg.__annotations__["func_in_dict"], dict)
# check calling # check calling
self.assertEqual(cfg.func(), 1) self.assertEqual(cfg.func(), 1)
self.assertEqual(cfg.wrapped_func(), 4) self.assertEqual(cfg.wrapped_func(), 4)
self.assertEqual(cfg.func_in_dict["func"](), 1) self.assertEqual(cfg.func_in_dict["func"](), 1)
# print dictionary
print(class_to_dict(cfg))
def test_dict_conversion_functions_config(self): def test_dict_conversion_functions_config(self):
"""Tests conversion of config with functions into dictionary.""" """Tests conversion of config with functions into dictionary."""
...@@ -427,6 +534,38 @@ class TestConfigClass(unittest.TestCase): ...@@ -427,6 +534,38 @@ class TestConfigClass(unittest.TestCase):
self.assertEqual(cfg.d, 3) self.assertEqual(cfg.d, 3)
self.assertEqual(cfg.j, ["c", "d"]) self.assertEqual(cfg.j, ["c", "d"])
def test_config_with_class_type(self):
"""Tests that configclass works properly with class type."""
cfg = DummyClassCfg()
# check types
self.assertEqual(cfg.__annotations__["class_name_1"], type)
self.assertEqual(cfg.__annotations__["class_name_2"], Type[DummyClass])
self.assertEqual(cfg.__annotations__["class_name_3"], Type[DummyClass])
self.assertEqual(cfg.__annotations__["class_name_4"], ClassVar[Type[DummyClass]])
# check values
self.assertEqual(cfg.class_name_1, DummyClass)
self.assertEqual(cfg.class_name_2, DummyClass)
self.assertEqual(cfg.class_name_3, DummyClass)
self.assertEqual(cfg.class_name_4, DummyClass)
self.assertEqual(cfg.b, "dummy")
def test_nested_config_class_declarations(self):
"""Tests that configclass works properly with nested class class declarations."""
cfg = OutsideClassCfg()
# check types
self.assertNotIn("InsideClassCfg", cfg.__annotations__)
self.assertNotIn("InsideClassCfg", OutsideClassCfg.__annotations__)
self.assertNotIn("InsideInsideClassCfg", OutsideClassCfg.InsideClassCfg.__annotations__)
self.assertNotIn("InsideInsideClassCfg", cfg.inside.__annotations__)
# check values
self.assertEqual(cfg.inside.class_name, DummyClass)
self.assertEqual(cfg.inside.b, "dummy")
self.assertEqual(cfg.x, 20)
def test_config_dumping(self): def test_config_dumping(self):
"""Check that config dumping works properly.""" """Check that config dumping works properly."""
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment