Commit f9eca34c authored by Mayank Mittal's avatar Mayank Mittal

Adds support for class types in configclass (#92)

# Description

This MR adds checks to the `dataclass` wrapper called `configclass` to extend its support for types.  This supports type hinting annotations such as `type`, `Type[Myclass]`, and `ClassVar[type]`.

It also adds a method called `replace` to the configclass that calls the [`dataclasses.replace`](https://docs.python.org/3/library/dataclasses.html#dataclasses.replace) function. This has been added for the convenience of users.

## Type of change

- Bug fix (non-breaking change which fixes an issue)
- New feature (non-breaking change which adds functionality)

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with `./orbit.sh --format`
- [x] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [x] I have added tests that prove my fix is effective or that my feature works
- [x] I have updated the changelog and the corresponding version in the extension's `config/extension.toml` file
parent 82cbdc22
......@@ -187,7 +187,7 @@ html_show_sphinx = False
def skip_member(app, what, name, obj, skip, options):
exclusions = ["from_dict", "to_dict"] # List the names of the functions you want to skip here
exclusions = ["from_dict", "to_dict", "replace"] # List the names of the functions you want to skip here
if name in exclusions:
return True
return None
......
[package]
# Note: Semantic Versioning is used: https://semver.org/
version = "0.7.1"
version = "0.7.2"
# Description
title = "ORBIT framework for Robot Learning"
......
Changelog
---------
0.7.1 (2023-07-10)
0.7.2 (2023-07-24)
~~~~~~~~~~~~~~~~~~
Added
^^^^^
* Added the method :meth:`replace` to the :class:`omni.isaac.orbit.utils.configclass` decorator to allow
creating a new configuration object with values replaced from keyword arguments. This function internally
calls the `dataclasses.replace <https://docs.python.org/3/library/dataclasses.html#dataclasses.replace>`_.
Fixed
^^^^^
* Fixed the handling of class types as member values in the :meth:`omni.isaac.orbit.utils.configclass`. Earlier it was
throwing an error since class types were skipped in the if-else block.
0.7.1 (2023-07-22)
~~~~~~~~~~~~~~~~~~
Added
......@@ -11,7 +29,7 @@ Added
to the :mod:`omni.isaac.orbit.managers` module to handle termination, curriculum, and randomization respectively.
0.7.0 (2023-07-10)
0.7.0 (2023-07-22)
~~~~~~~~~~~~~~~~~~
Added
......
......@@ -7,8 +7,8 @@
from copy import deepcopy
from dataclasses import MISSING, Field, dataclass, field
from typing import Any, Callable, ClassVar, Dict
from dataclasses import MISSING, Field, dataclass, field, replace
from typing import Any, Callable, ClassVar, Dict, Type
from .dict import class_to_dict, update_class_from_dict
......@@ -16,6 +16,9 @@ from .dict import class_to_dict, update_class_from_dict
__all__ = ["configclass"]
_CONFIGCLASS_METHODS = ["to_dict", "from_dict", "replace"]
"""List of class methods added at runtime to dataclass."""
"""
Wrapper around dataclass.
"""
......@@ -75,6 +78,7 @@ def configclass(cls, **kwargs):
# add helper functions for dictionary conversion
setattr(cls, "to_dict", _class_to_dict)
setattr(cls, "from_dict", _update_class_from_dict)
setattr(cls, "replace", _replace_class_with_kwargs)
# wrap around dataclass
cls = dataclass(cls, **kwargs)
# return wrapped class
......@@ -113,6 +117,30 @@ def _update_class_from_dict(obj, data: Dict[str, Any]) -> None:
return update_class_from_dict(obj, data, _ns="")
def _replace_class_with_kwargs(obj: object, **kwargs) -> object:
"""Return a new object replacing specified fields with new values.
This is especially useful for frozen classes. Example usage:
@configclass(frozen=True)
class C:
x: int
y: int
c = C(1, 2)
c1 = c.replace(x=3)
assert c1.x == 3 and c1.y == 2
Args:
obj (object): The object to replace.
**kwargs: The fields to replace and their new values.
Returns:
object: The new object.
"""
return replace(obj, **kwargs)
"""
Private helper functions.
"""
......@@ -156,7 +184,7 @@ def _add_annotation_types(cls):
if key.startswith("__"):
continue
# skip class functions
if key in ["from_dict", "to_dict"]:
if key in _CONFIGCLASS_METHODS:
continue
# check if key is already present
if key in hints:
......@@ -175,6 +203,10 @@ def _add_annotation_types(cls):
)
# add type annotation
hints[key] = type(value)
elif key != value.__name__:
# note: we don't want to add type annotations for nested configclass. Thus, we check if
# the name of the type matches the name of the variable.
hints[key] = Type[value]
# Note: Do not change this line. `cls.__dict__.get("__annotations__", {})` is different from
# `cls.__annotations__` because of inheritance.
......@@ -203,8 +235,12 @@ def _process_mutable_types(cls):
If the function is NOT used, the following value-error is returned:
ValueError: mutable default <class 'list'> for field pos is not allowed: use default_factory
"""
class_members = {}
# note: Need to set this up in the same order as annotations. Otherwise, it
# complains about missing positional arguments.
ann = cls.__dict__.get("__annotations__", {})
# iterate over all class members and store them in a dictionary
class_members = {}
for base in reversed(cls.__mro__):
# check if base is object
if base is object:
......@@ -215,29 +251,27 @@ def _process_mutable_types(cls):
if key.startswith("__"):
continue
# skip class functions
if key in ["from_dict", "to_dict"]:
if key in _CONFIGCLASS_METHODS:
continue
# get class member
f = getattr(base, key)
# store class member
if not isinstance(f, type):
# store class member if it is not a type or if it is already present in annotations
if not isinstance(f, type) or key in ann:
class_members[key] = f
# iterate over base class data fields
# in previous call, things that became a dataclass field were removed from class members
for key, f in base.__dict__.get("__dataclass_fields__", {}).items():
# store class member
class_members[key] = f
if not isinstance(f, type) and key not in class_members:
class_members[key] = f
# note: Need to set this up in the same order as annotations. Otherwise, it
# complains about missing positional arguments.
ann = cls.__dict__.get("__annotations__", {})
# check that all annotations are present in class members
# note: mainly for debugging purposes
if len(class_members) != len(ann):
raise ValueError(
f"Number of annotations ({len(ann)}) does not match number of class members ({len(class_members)})."
" Please check that all class members have type annotations and a default value."
" If you don't want to specify a default value, please use the literal `dataclasses.MISSING`."
f"In class '{cls.__name__}', number of annotations ({len(ann)}) does not match number of class "
f"members ({len(class_members)}). Please check that all class members have type annotations and/or "
"a default value. If you don't want to specify a default value, please use the literal `dataclasses.MISSING`."
)
# iterate over annotations and add field factory for mutable types
for key in ann:
......
......@@ -8,12 +8,58 @@ import os
import unittest
from dataclasses import MISSING, asdict, field
from functools import wraps
from typing import List
from typing import ClassVar, List, Type
from omni.isaac.orbit.utils.configclass import configclass
from omni.isaac.orbit.utils.dict import class_to_dict, update_class_from_dict
from omni.isaac.orbit.utils.io import dump_yaml, load_yaml
"""
Mock classes and functions.
"""
def dummy_function1() -> int:
"""Dummy function 1."""
return 1
def dummy_function2() -> int:
"""Dummy function 2."""
return 2
def dummy_wrapper(func):
"""Decorator for wrapping function."""
@wraps(func)
def wrapper():
return func() + 1
return wrapper
@dummy_wrapper
def wrapped_dummy_function3():
"""Dummy function 3."""
return 3
@dummy_wrapper
def wrapped_dummy_function4():
"""Dummy function 4."""
return 4
class DummyClass:
"""Dummy class."""
def __init__(self):
"""Initialize dummy class."""
self.a = 1
self.b = 2
"""
Dummy configuration: Basic
"""
......@@ -109,42 +155,55 @@ class ChildDemoCfg(ParentDemoCfg):
k: List[str] = ["c", "d"]
e: ViewerCfg = MISSING # add new missing field
dummy_class = DummyClass
"""
Dummy configuration: Functions
Configuration with class inside.
"""
def dummy_function1() -> int:
"""Dummy function 1."""
return 1
@configclass
class DummyClassCfg:
"""Dummy class configuration with class type."""
class_name_1: type = DummyClass
class_name_2: Type[DummyClass] = DummyClass
class_name_3 = DummyClass
class_name_4: ClassVar[Type[DummyClass]] = DummyClass
def dummy_function2() -> int:
"""Dummy function 2."""
return 2
b: str = "dummy"
def dummy_wrapper(func):
"""Decorator for wrapping function."""
"""
Configuration with nested classes.
"""
@wraps(func)
def wrapper():
return func() + 1
return wrapper
@configclass
class OutsideClassCfg:
"""Outermost dummy configuration."""
@configclass
class InsideClassCfg:
"""Inner dummy configuration."""
@dummy_wrapper
def wrapped_dummy_function3():
"""Dummy function 3."""
return 3
@configclass
class InsideInsideClassCfg:
"""Dummy configuration with class type."""
u: List[int] = [1, 2, 3]
@dummy_wrapper
def wrapped_dummy_function4():
"""Dummy function 4."""
return 4
class_name: type = DummyClass
b: str = "dummy"
inside: InsideClassCfg = InsideClassCfg()
x: int = 20
"""
Dummy configuration: Functions
"""
@configclass
......@@ -284,7 +343,6 @@ class TestConfigClass(unittest.TestCase):
cfg = BasicDemoCfg()
cfg_dict = {"env": {"num_envs": 22, "viewer": {"eye": (2.0, 2.0, 2.0)}}}
cfg.from_dict(cfg_dict)
print("Updated config: ", cfg.to_dict())
self.assertDictEqual(cfg.to_dict(), basic_demo_cfg_change_correct)
def test_invalid_update_key(self):
......@@ -295,7 +353,7 @@ class TestConfigClass(unittest.TestCase):
update_class_from_dict(cfg, cfg_dict)
def test_multiple_instances(self):
"""Test multiple instances of the same configclass."""
"""Test multiple instances with twice instantiation."""
# create two config instances
cfg1 = BasicDemoCfg()
cfg2 = BasicDemoCfg()
......@@ -308,6 +366,11 @@ class TestConfigClass(unittest.TestCase):
# immutable -- variables are the same
self.assertEqual(id(cfg1.robot_default_state.dof_pos), id(cfg2.robot_default_state.dof_pos))
self.assertEqual(id(cfg1.env.num_envs), id(cfg2.env.num_envs))
self.assertEqual(id(cfg1.device_id), id(cfg2.device_id))
# check values
self.assertDictEqual(cfg1.env.to_dict(), cfg2.env.to_dict())
self.assertDictEqual(cfg1.robot_default_state.to_dict(), cfg2.robot_default_state.to_dict())
def test_alter_values_multiple_instances(self):
"""Test alterations in multiple instances of the same configclass."""
......@@ -331,6 +394,48 @@ class TestConfigClass(unittest.TestCase):
# immutable -- altered variables are different ids
self.assertNotEqual(id(cfg1.env.num_envs), id(cfg2.env.num_envs))
def test_multiple_instances_with_replace(self):
"""Test multiple instances with creation through replace function."""
# create two config instances
cfg1 = BasicDemoCfg()
cfg2 = cfg1.replace()
# check variable IDs
# mutable -- variables should be different
self.assertNotEqual(id(cfg1.env.viewer.eye), id(cfg2.env.viewer.eye))
self.assertNotEqual(id(cfg1.env.viewer.lookat), id(cfg2.env.viewer.lookat))
self.assertNotEqual(id(cfg1.robot_default_state), id(cfg2.robot_default_state))
# immutable -- variables are the same
self.assertEqual(id(cfg1.robot_default_state.dof_pos), id(cfg2.robot_default_state.dof_pos))
self.assertEqual(id(cfg1.env.num_envs), id(cfg2.env.num_envs))
self.assertEqual(id(cfg1.device_id), id(cfg2.device_id))
# check values
self.assertDictEqual(cfg1.to_dict(), cfg2.to_dict())
def test_alter_values_multiple_instances_wth_replace(self):
"""Test alterations in multiple instances through replace function."""
# create two config instances
cfg1 = BasicDemoCfg()
cfg2 = cfg1.replace(device_id=1)
# alter configurations
cfg1.env.num_envs = 22 # immutable data: int
cfg1.env.viewer.eye[0] = 1.0 # mutable data: list
cfg1.env.viewer.lookat[2] = 12.0 # mutable data: list
# check variables
# values should be different
self.assertNotEqual(cfg1.env.num_envs, cfg2.env.num_envs)
self.assertNotEqual(cfg1.env.viewer.eye, cfg2.env.viewer.eye)
self.assertNotEqual(cfg1.env.viewer.lookat, cfg2.env.viewer.lookat)
# mutable -- variables are different ids
self.assertNotEqual(id(cfg1.env.viewer.eye), id(cfg2.env.viewer.eye))
self.assertNotEqual(id(cfg1.env.viewer.lookat), id(cfg2.env.viewer.lookat))
# immutable -- altered variables are different ids
self.assertNotEqual(id(cfg1.env.num_envs), id(cfg2.env.num_envs))
self.assertNotEqual(id(cfg1.device_id), id(cfg2.device_id))
def test_configclass_type_ordering(self):
"""Checks ordering of config objects when no type annotation is provided."""
......@@ -346,12 +451,14 @@ class TestConfigClass(unittest.TestCase):
def test_functions_config(self):
"""Tests having functions as values in the configuration instance."""
cfg = FunctionsDemoCfg()
# check types
self.assertEqual(cfg.__annotations__["func"], type(dummy_function1))
self.assertEqual(cfg.__annotations__["wrapped_func"], type(wrapped_dummy_function3))
self.assertEqual(cfg.__annotations__["func_in_dict"], dict)
# check calling
self.assertEqual(cfg.func(), 1)
self.assertEqual(cfg.wrapped_func(), 4)
self.assertEqual(cfg.func_in_dict["func"](), 1)
# print dictionary
print(class_to_dict(cfg))
def test_dict_conversion_functions_config(self):
"""Tests conversion of config with functions into dictionary."""
......@@ -427,6 +534,38 @@ class TestConfigClass(unittest.TestCase):
self.assertEqual(cfg.d, 3)
self.assertEqual(cfg.j, ["c", "d"])
def test_config_with_class_type(self):
"""Tests that configclass works properly with class type."""
cfg = DummyClassCfg()
# check types
self.assertEqual(cfg.__annotations__["class_name_1"], type)
self.assertEqual(cfg.__annotations__["class_name_2"], Type[DummyClass])
self.assertEqual(cfg.__annotations__["class_name_3"], Type[DummyClass])
self.assertEqual(cfg.__annotations__["class_name_4"], ClassVar[Type[DummyClass]])
# check values
self.assertEqual(cfg.class_name_1, DummyClass)
self.assertEqual(cfg.class_name_2, DummyClass)
self.assertEqual(cfg.class_name_3, DummyClass)
self.assertEqual(cfg.class_name_4, DummyClass)
self.assertEqual(cfg.b, "dummy")
def test_nested_config_class_declarations(self):
"""Tests that configclass works properly with nested class class declarations."""
cfg = OutsideClassCfg()
# check types
self.assertNotIn("InsideClassCfg", cfg.__annotations__)
self.assertNotIn("InsideClassCfg", OutsideClassCfg.__annotations__)
self.assertNotIn("InsideInsideClassCfg", OutsideClassCfg.InsideClassCfg.__annotations__)
self.assertNotIn("InsideInsideClassCfg", cfg.inside.__annotations__)
# check values
self.assertEqual(cfg.inside.class_name, DummyClass)
self.assertEqual(cfg.inside.b, "dummy")
self.assertEqual(cfg.x, 20)
def test_config_dumping(self):
"""Check that config dumping works properly."""
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment