Unverified Commit 5931b15a authored by Mayank Mittal's avatar Mayank Mittal Committed by GitHub

Fixes ordering of terms in configclass when no type annotation is present (#76)

# Description

Previously, type annotation was always required to make the terms follow
the order in which they are defined in the configclass. If this was not
done, then the terms were getting sorted alphabetically which made it
different from the expected behavior (user-defined order).

On further inspection, turned out that in our wrappers for configclass,
we were using `dir(cls)` to parse the class members, which sorts all the
members of the class alphabetically. Changing it to `cls.__dict__` fixed
this issue since in Python 3.7 onwards, dictionaries follow the
user-defined ordering.

Since this behavior changes the way config terms are parsed, the old
configclass still exists inside the
`omni.isaac.orbit.compat.utils.configclass` module so that people can
still run policies trained with the old ordering.

## Type of change

- Bug fix (non-breaking change which fixes an issue)
- Breaking change (fix or feature that would cause existing
functionality to not work as expected)

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./orbit.sh --format`
- [x] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [x] I have added tests that prove my fix is effective or that my
feature works
- [x] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
parent d0310cda
[package]
# Note: Semantic Versioning is used: https://semver.org/
version = "0.5.0"
version = "0.6.0"
# Description
title = "ORBIT framework for Robot Learning"
......
Changelog
---------
0.6.0 (2023-07-16)
~~~~~~~~~~~~~~~~~~
Added
^^^^^
* Added the argument :attr:`sort_keys` to the :meth:`omni.isaac.orbit.utils.io.yaml.dump_yaml` method to allow
enabling/disabling of sorting of keys in the output yaml file.
Fixed
^^^^^
* Fixed the ordering of terms in :mod:`omni.isaac.orbit.core.utils.configclass` to be consistent in the order in which
they are defined. Previously, the ordering was done alphabetically which made it inconsistent with the order in which
the parameters were defined.
Changed
^^^^^^^
* Changed the default value of the argument :attr:`sort_keys` in the :meth:`omni.isaac.orbit.utils.io.yaml.dump_yaml`
method to ``False``.
* Moved the old config classes in :mod:`omni.isaac.orbit.core.utils.configclass` to
:mod:`omni.isaac.orbit.compat.utils.configclass` so that users can still run their old code where alphabetical
ordering was used.
0.5.0 (2023-07-04)
Added
......
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES, ETH Zurich, and University of Toronto
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause
from .configclass import configclass
__all__ = [
# config wrapper
"configclass",
]
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES, ETH Zurich, and University of Toronto
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause
"""Wrapper around the Python 3.7 onwards `dataclasses` module."""
from copy import deepcopy
from dataclasses import Field, dataclass, field
from typing import Any, Callable, ClassVar, Dict
from .dict import class_to_dict, update_class_from_dict
# List of all methods provided by sub-module.
__all__ = ["configclass"]
"""
Wrapper around dataclass.
"""
def __dataclass_transform__():
"""Add annotations decorator for PyLance."""
return lambda a: a
@__dataclass_transform__()
def configclass(cls, **kwargs):
"""Wrapper around `dataclass` functionality to add extra checks and utilities.
As of Python3.8, the standard dataclasses have two main issues which makes them non-generic for configuration use-cases.
These include:
1. Requiring a type annotation for all its members.
2. Requiring explicit usage of :meth:`field(default_factory=...)` to reinitialize mutable variables.
This function wraps around :class:`dataclass` utility to deal with the above two issues.
Usage:
.. code-block:: python
from dataclasses import MISSING
from omni.isaac.orbit.utils.configclass import configclass
@configclass
class ViewerCfg:
eye: list = [7.5, 7.5, 7.5] # field missing on purpose
lookat: list = field(default_factory=[0.0, 0.0, 0.0])
@configclass
class EnvCfg:
num_envs: int = MISSING
episode_length: int = 2000
viewer: ViewerCfg = ViewerCfg()
# create configuration instance
env_cfg = EnvCfg(num_envs=24)
# print information
print(env_cfg.to_dict())
Reference:
https://docs.python.org/3/library/dataclasses.html#dataclasses.Field
"""
# add type annotations
_add_annotation_types(cls)
# add field factory
_process_mutable_types(cls)
# copy mutable members
setattr(cls, "__post_init__", _custom_post_init)
# add helper functions for dictionary conversion
setattr(cls, "to_dict", _class_to_dict)
setattr(cls, "from_dict", _update_class_from_dict)
# wrap around dataclass
cls = dataclass(cls, **kwargs)
# return wrapped class
return cls
"""
Dictionary <-> Class operations.
These are redefined here to add new docstrings.
"""
def _class_to_dict(obj: object) -> Dict[str, Any]:
"""Convert an object into dictionary recursively.
Returns:
Dict[str, Any]: Converted dictionary mapping.
"""
return class_to_dict(obj)
def _update_class_from_dict(obj, data: Dict[str, Any]) -> None:
"""Reads a dictionary and sets object variables recursively.
This function performs in-place update of the class member attributes.
Args:
data (Dict[str, Any]): Input (nested) dictionary to update from.
Raises:
TypeError: When input is not a dictionary.
ValueError: When dictionary has a value that does not match default config type.
KeyError: When dictionary has a key that does not exist in the default config type.
"""
return update_class_from_dict(obj, data, _ns="")
"""
Private helper functions.
"""
def _add_annotation_types(cls):
"""Add annotations to all elements in the dataclass.
By definition in Python, a field is defined as a class variable that has a type annotation.
In case type annotations are not provided, dataclass ignores those members when :func:`__dict__()` is called.
This function adds these annotations to the class variable to prevent any issues in case the user forgets to
specify the type annotation.
This makes the following a feasible operation:
@dataclass
class State:
pos = (0.0, 0.0, 0.0)
^^
If the function is NOT used, the following type-error is returned:
TypeError: 'pos' is a field but has no type annotation
"""
# Note: Do not change this line. `cls.__dict__.get("__annotations__", {})` is different from
# `cls.__annotations__` because of inheritance.
cls.__annotations__ = cls.__dict__.get("__annotations__", {})
# cls.__annotations__ = dict()
for key in dir(cls):
# skip dunder members
if key.startswith("__"):
continue
# skip class functions
if key in ["from_dict", "to_dict"]:
continue
# add type annotations for members that are not functions
var = getattr(cls, key)
if not isinstance(var, type):
if key not in cls.__annotations__:
cls.__annotations__[key] = type(var)
def _process_mutable_types(cls):
"""Initialize all mutable elements through :obj:`dataclasses.Field` to avoid unnecessary complaints.
By default, dataclass requires usage of :obj:`field(default_factory=...)` to reinitialize mutable objects every time a new
class instance is created. If a member has a mutable type and it is created without specifying the `field(default_factory=...)`,
then Python throws an error requiring the usage of `default_factory`.
Additionally, Python only explicitly checks for field specification when the type is a list, set or dict. This misses the
use-case where the type is class itself. Thus, the code silently carries a bug with it which can lead to undesirable effects.
This function deals with this issue
This makes the following a feasible operation:
@dataclass
class State:
pos: list = [0.0, 0.0, 0.0]
^^
If the function is NOT used, the following value-error is returned:
ValueError: mutable default <class 'list'> for field pos is not allowed: use default_factory
"""
def _return_f(f: Any) -> Callable[[], Any]:
"""Returns default function for creating mutable/immutable variables."""
def _wrap():
if isinstance(f, Field):
return f.default_factory
else:
return f
return _wrap
for key in dir(cls):
# skip dunder members
if key.startswith("__"):
continue
# skip class functions
if key in ["from_dict", "to_dict"]:
continue
# do not create field for class variables
if key in cls.__annotations__:
origin = getattr(cls.__annotations__[key], "__origin__", None)
if origin is ClassVar:
continue
# define explicit field for data members
f = getattr(cls, key)
# add field for mutable types
if not isinstance(f, type):
f = field(default_factory=_return_f(f))
setattr(cls, key, f)
def _custom_post_init(obj):
"""Deepcopy all elements to avoid shared memory issues for mutable objects in dataclasses initialization.
This function is called explicitly instead of as a part of :func:`_process_mutable_types()` to prevent mapping
proxy type i.e. a read only proxy for mapping objects. The error is thrown when using hierarchical data-classes
for configuration.
"""
for key in dir(obj):
# skip dunder members
if key.startswith("__"):
continue
# duplicate data members
var = getattr(obj, key)
if not callable(var):
setattr(obj, key, deepcopy(var))
......@@ -7,7 +7,7 @@
from copy import deepcopy
from dataclasses import Field, dataclass, field
from dataclasses import MISSING, Field, dataclass, field
from typing import Any, Callable, ClassVar, Dict
from .dict import class_to_dict, update_class_from_dict
......@@ -136,21 +136,50 @@ def _add_annotation_types(cls):
If the function is NOT used, the following type-error is returned:
TypeError: 'pos' is a field but has no type annotation
"""
# Note: Do not change this line. `cls.__dict__.get("__annotations__", {})` is different from `cls.__annotations__` because of inheritance.
cls.__annotations__ = cls.__dict__.get("__annotations__", {})
# cls.__annotations__ = dict()
for key in dir(cls):
# skip dunder members
if key.startswith("__"):
# get type hints
hints = {}
# iterate over class inheritance
# we add annotations from base classes first
for base in reversed(cls.__mro__):
# check if base is object
if base is object:
continue
# skip class functions
if key in ["from_dict", "to_dict"]:
continue
# add type annotations for members that are not functions
var = getattr(cls, key)
if not isinstance(var, type):
if key not in cls.__annotations__:
cls.__annotations__[key] = type(var)
# get base class annotations
ann = base.__dict__.get("__annotations__", {})
# directly add all annotations from base class
hints.update(ann)
# iterate over base class members
# Note: Do not change this to dir(base) since it orders the members alphabetically.
# This is not desirable since the order of the members is important in some cases.
for key in base.__dict__:
# skip dunder members
if key.startswith("__"):
continue
# skip class functions
if key in ["from_dict", "to_dict"]:
continue
# check if key is already present
if key in hints:
continue
# add type annotations for members that don't have explicit type annotations
# for these, we deduce the type from the default value
value = getattr(base, key)
if not isinstance(value, type):
if key not in hints:
# check if var type is not MISSING
# we cannot deduce type from MISSING!
if value is MISSING:
raise TypeError(
f"Missing type annotation for '{key}' in class '{cls.__name__}'."
" Please add a type annotation or set a default value."
)
# add type annotation
hints[key] = type(value)
# Note: Do not change this line. `cls.__dict__.get("__annotations__", {})` is different from
# `cls.__annotations__` because of inheritance.
cls.__annotations__ = cls.__dict__.get("__annotations__", {})
cls.__annotations__ = hints
def _process_mutable_types(cls):
......@@ -174,35 +203,64 @@ def _process_mutable_types(cls):
If the function is NOT used, the following value-error is returned:
ValueError: mutable default <class 'list'> for field pos is not allowed: use default_factory
"""
def _return_f(f: Any) -> Callable[[], Any]:
"""Returns default function for creating mutable/immutable variables."""
def _wrap():
if isinstance(f, Field):
return f.default_factory
else:
return f
return _wrap
for key in dir(cls):
# skip dunder members
if key.startswith("__"):
class_members = {}
# iterate over all class members and store them in a dictionary
for base in reversed(cls.__mro__):
# check if base is object
if base is object:
continue
# skip class functions
if key in ["from_dict", "to_dict"]:
continue
# do not create field for class variables
if key in cls.__annotations__:
origin = getattr(cls.__annotations__[key], "__origin__", None)
if origin is ClassVar:
# iterate over base class members
for key in base.__dict__:
# skip dunder members
if key.startswith("__"):
continue
# skip class functions
if key in ["from_dict", "to_dict"]:
continue
# define explicit field for data members
f = getattr(cls, key)
if not isinstance(f, type):
f = field(default_factory=_return_f(f))
setattr(cls, key, f)
# get class member
f = getattr(base, key)
# store class member
if not isinstance(f, type):
class_members[key] = f
# iterate over base class data fields
# in previous call, things that became a dataclass field were removed from class members
for key, f in base.__dict__.get("__dataclass_fields__", {}).items():
# store class member
class_members[key] = f
# note: Need to set this up in the same order as annotations. Otherwise, it
# complains about missing positional arguments.
ann = cls.__dict__.get("__annotations__", {})
# check that all annotations are present in class members
# note: mainly for debugging purposes
if len(class_members) != len(ann):
raise ValueError(
f"Number of annotations ({len(ann)}) does not match number of class members ({len(class_members)})."
" Please check that all class members have type annotations and a default value."
" If you don't want to specify a default value, please use the literal `dataclasses.MISSING`."
)
# iterate over annotations and add field factory for mutable types
for key in ann:
# find matching field in class
value = class_members.get(key, MISSING)
# check if key belongs to ClassVar
# in that case, we cannot use default_factory!
origin = getattr(ann[key], "__origin__", None)
if origin is ClassVar:
continue
# check if f is MISSING
# note: commented out for now since it causes issue with inheritance
# of dataclasses when parent have some positional and some keyword arguments.
# Ref: https://stackoverflow.com/questions/51575931/class-inheritance-in-python-3-7-dataclasses
# TODO: check if this is fixed in Python 3.10
# if f is MISSING:
# continue
if isinstance(value, Field):
setattr(cls, key, value)
elif not isinstance(value, type):
# create field factory for mutable types
value = field(default_factory=_return_f(value))
setattr(cls, key, value)
def _custom_post_init(obj):
......@@ -216,7 +274,38 @@ def _custom_post_init(obj):
# skip dunder members
if key.startswith("__"):
continue
# get data member
value = getattr(obj, key)
# duplicate data members
var = getattr(obj, key)
if not callable(var):
setattr(obj, key, deepcopy(var))
if not callable(value):
setattr(obj, key, deepcopy(value))
"""
Helper functions
"""
def _return_f(f: Any) -> Callable[[], Any]:
"""Returns default factory function for creating mutable/immutable variables.
This function should be used to create default factory functions for variables.
Example:
.. code-block:: python
value = field(default_factory=_return_f(value))
setattr(cls, key, value)
"""
def _wrap():
if isinstance(f, Field):
if f.default_factory is MISSING:
return deepcopy(f.default)
else:
return f.default_factory
else:
return f
return _wrap
......@@ -31,7 +31,7 @@ def load_yaml(filename: str) -> Dict:
return data
def dump_yaml(filename: str, data: Union[Dict, object]):
def dump_yaml(filename: str, data: Union[Dict, object], sort_keys: bool = False):
"""Saves data into a YAML file safely.
Note:
......@@ -40,6 +40,7 @@ def dump_yaml(filename: str, data: Union[Dict, object]):
Args:
filename (str): The path to save the file at.
data (Union[Dict, object]): The data to save either a dictionary or class object.
sort_keys (bool, optional): Whether to sort the keys in the output file. Defaults to False.
"""
# check ending
if not filename.endswith("yaml"):
......@@ -52,4 +53,4 @@ def dump_yaml(filename: str, data: Union[Dict, object]):
data = class_to_dict(data)
# save data
with open(filename, "w") as f:
yaml.dump(data, f, default_flow_style=None)
yaml.dump(data, f, default_flow_style=None, sort_keys=sort_keys)
......@@ -4,12 +4,15 @@
# SPDX-License-Identifier: BSD-3-Clause
import copy
import os
import unittest
from dataclasses import asdict, field
from dataclasses import MISSING, asdict, field
from functools import wraps
from typing import List
from omni.isaac.orbit.utils.configclass import configclass
from omni.isaac.orbit.utils.dict import class_to_dict, update_class_from_dict
from omni.isaac.orbit.utils.io import dump_yaml, load_yaml
"""
Dummy configuration: Basic
......@@ -24,7 +27,7 @@ def double(x):
@configclass
class ViewerCfg:
eye: list = [7.5, 7.5, 7.5] # field missing on purpose
lookat: list = field(default_factory=[0.0, 0.0, 0.0])
lookat: list = field(default_factory=lambda: [0.0, 0.0, 0.0])
@configclass
......@@ -51,6 +54,62 @@ class BasicDemoCfg:
robot_default_state: RobotDefaultStateCfg = RobotDefaultStateCfg()
"""
Dummy configuration to check type annotations ordering.
"""
@configclass
class TypeAnnotationOrderingDemoCfg:
"""Config class with type annotations."""
anymal: RobotDefaultStateCfg = RobotDefaultStateCfg()
unitree: RobotDefaultStateCfg = RobotDefaultStateCfg()
franka: RobotDefaultStateCfg = RobotDefaultStateCfg()
@configclass
class NonTypeAnnotationOrderingDemoCfg:
"""Config class without type annotations."""
anymal = RobotDefaultStateCfg()
unitree = RobotDefaultStateCfg()
franka = RobotDefaultStateCfg()
@configclass
class InheritedNonTypeAnnotationOrderingDemoCfg(NonTypeAnnotationOrderingDemoCfg):
"""Inherited config class without type annotations."""
pass
"""
Dummy configuration: Inheritance
"""
@configclass
class ParentDemoCfg:
"""Dummy parent configuration with missing fields."""
a: int = MISSING # add new missing field
b = 2 # type annotation missing on purpose
c: RobotDefaultStateCfg = MISSING # add new missing field
j: List[str] = MISSING # add new missing field
@configclass
class ChildDemoCfg(ParentDemoCfg):
"""Dummy child configuration with missing fields."""
c = RobotDefaultStateCfg() # set default value for missing field
d: int = MISSING # add new missing field
k: List[str] = ["c", "d"]
e: ViewerCfg = MISSING # add new missing field
"""
Dummy configuration: Functions
"""
......@@ -159,6 +218,7 @@ class TestConfigClass(unittest.TestCase):
print()
print("Using dataclass function: ", asdict(cfg))
print("Using internal function: ", cfg.to_dict())
self.assertDictEqual(asdict(cfg), cfg.to_dict())
def test_dict_conversion(self):
"""Test dictionary conversion of configclass instance."""
......@@ -271,6 +331,18 @@ class TestConfigClass(unittest.TestCase):
# immutable -- altered variables are different ids
self.assertNotEqual(id(cfg1.env.num_envs), id(cfg2.env.num_envs))
def test_configclass_type_ordering(self):
"""Checks ordering of config objects when no type annotation is provided."""
cfg_1 = TypeAnnotationOrderingDemoCfg()
cfg_2 = NonTypeAnnotationOrderingDemoCfg()
cfg_3 = InheritedNonTypeAnnotationOrderingDemoCfg()
# check ordering
self.assertEqual(list(cfg_1.__dict__.keys()), list(cfg_2.__dict__.keys()))
self.assertEqual(list(cfg_3.__dict__.keys()), list(cfg_2.__dict__.keys()))
self.assertEqual(list(cfg_1.__dict__.keys()), list(cfg_3.__dict__.keys()))
def test_functions_config(self):
"""Tests having functions as values in the configuration instance."""
cfg = FunctionsDemoCfg()
......@@ -299,6 +371,89 @@ class TestConfigClass(unittest.TestCase):
self.assertEqual(cfg.wrapped_func(), 5)
self.assertEqual(cfg.func_in_dict["func"](), 2)
def test_missing_type_in_config(self):
"""Tests missing type annotation in config.
Should complain that 'c' is missing type annotation since it cannot be inferred
from 'MISSING' value.
"""
with self.assertRaises(TypeError):
@configclass
class MissingTypeDemoCfg:
a: int = 1
b = 2
c = MISSING
def test_missing_default_value_in_config(self):
"""Tests missing default value in config.
Should complain that 'a' is missing default value since it cannot be inferred
from type annotation.
"""
with self.assertRaises(ValueError):
@configclass
class MissingTypeDemoCfg:
a: int
b = 2
def test_required_argument_for_missing_type_in_config(self):
"""Tests required positional argument for missing type annotation in config creation."""
@configclass
class MissingTypeDemoCfg:
a: int = 1
b = 2
c: int = MISSING
# should complain that 'c' is missed in positional arguments
# TODO: Uncomment this when we move to 3.10.
# with self.assertRaises(TypeError):
# cfg = MissingTypeDemoCfg(a=1)
# should not complain
cfg = MissingTypeDemoCfg(a=1, c=3)
self.assertEqual(cfg.a, 1)
self.assertEqual(cfg.b, 2)
def test_config_inheritance(self):
"""Tests that inheritance works properly."""
# check variables
cfg = ChildDemoCfg(a=20, d=3, e=ViewerCfg(), j=["c", "d"])
self.assertEqual(cfg.a, 20)
self.assertEqual(cfg.b, 2)
self.assertEqual(cfg.d, 3)
self.assertEqual(cfg.j, ["c", "d"])
def test_config_dumping(self):
"""Check that config dumping works properly."""
# file for dumping
dirname = os.path.dirname(os.path.abspath(__file__))
filename = os.path.join(dirname, "output", "configclass", "test_config.yaml")
# create config
cfg = ChildDemoCfg(a=20, d=3, e=ViewerCfg(), j=["c", "d"])
# save config
dump_yaml(filename, cfg)
# load config
cfg_loaded = load_yaml(filename)
# check dictionaries are the same
self.assertEqual(list(cfg.to_dict().keys()), list(cfg_loaded.keys()))
self.assertDictEqual(cfg.to_dict(), cfg_loaded)
# save config with sorted order won't work!
# save config
dump_yaml(filename, cfg, sort_keys=True)
# load config
cfg_loaded = load_yaml(filename)
# check dictionaries are the same
self.assertNotEqual(list(cfg.to_dict().keys()), list(cfg_loaded.keys()))
self.assertDictEqual(cfg.to_dict(), cfg_loaded)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment