Unverified Commit c9f6ac57 authored by lgulich's avatar lgulich Committed by GitHub

Fixes configclass dict conversion for torch tensors (#1530)

# Description

Fix configclass dict conversion for torch tensors

Up to v1.2.0 if a configclass would contain a list/tuple of torch
tensors it would be left as is.

\#1227 changed the behavior of converting lists/tuples in a dict, which
means that currently torch tensors are converted to an empty dict,
effectively losing all contained data.

The underlying issue is that `torch.tensor.__dict__` returns an empty
dict, which was (luckily) ignored previously because we did not convert
the contents of lists.

This MR fixes this by treating torch tensors specially. I don't like
having a special case for a non-builtin class but given that
IsaacLab is heavily married with torch tensors I think it's ok in this
case.

Since currently the behavior is different between 1.2 and 1.3: can we
cherry pick this change to the 1.3 branch?

## Type of change

- Bug fix (non-breaking change which fixes an issue)

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./isaaclab.sh --format`
- [ ] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [x] I have added tests that prove my fix is effective or that my
feature works
- [x] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there
Co-authored-by: 's avatarKelly Guo <kellyg@nvidia.com>
parent 37e0a798
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
import collections.abc import collections.abc
import hashlib import hashlib
import json import json
import torch
from collections.abc import Iterable, Mapping from collections.abc import Iterable, Mapping
from typing import Any from typing import Any
...@@ -40,6 +41,11 @@ def class_to_dict(obj: object) -> dict[str, Any]: ...@@ -40,6 +41,11 @@ def class_to_dict(obj: object) -> dict[str, Any]:
# convert object to dictionary # convert object to dictionary
if isinstance(obj, dict): if isinstance(obj, dict):
obj_dict = obj obj_dict = obj
elif isinstance(obj, torch.Tensor):
# We have to treat torch tensors specially because `torch.tensor.__dict__` returns an empty
# dict, which would mean that a torch.tensor would be stored as an empty dict. Instead we
# want to store it directly as the tensor.
return obj
elif hasattr(obj, "__dict__"): elif hasattr(obj, "__dict__"):
obj_dict = obj.__dict__ obj_dict = obj.__dict__
else: else:
...@@ -57,6 +63,7 @@ def class_to_dict(obj: object) -> dict[str, Any]: ...@@ -57,6 +63,7 @@ def class_to_dict(obj: object) -> dict[str, Any]:
# check if attribute is a dictionary # check if attribute is a dictionary
elif hasattr(value, "__dict__") or isinstance(value, dict): elif hasattr(value, "__dict__") or isinstance(value, dict):
data[key] = class_to_dict(value) data[key] = class_to_dict(value)
# check if attribute is a list or tuple
elif isinstance(value, (list, tuple)): elif isinstance(value, (list, tuple)):
data[key] = type(value)([class_to_dict(v) for v in value]) data[key] = type(value)([class_to_dict(v) for v in value])
else: else:
......
...@@ -19,6 +19,7 @@ simulation_app = app_launcher.app ...@@ -19,6 +19,7 @@ simulation_app = app_launcher.app
import copy import copy
import os import os
import torch
import unittest import unittest
from collections.abc import Callable from collections.abc import Callable
from dataclasses import MISSING, asdict, field from dataclasses import MISSING, asdict, field
...@@ -134,6 +135,14 @@ class BasicDemoPostInitCfg: ...@@ -134,6 +135,14 @@ class BasicDemoPostInitCfg:
self.add_variable = 3 self.add_variable = 3
@configclass
class BasicDemoTorchCfg:
"""Dummy configuration class with a torch tensor ."""
some_number: int = 0
some_tensor: torch.Tensor = torch.Tensor([1, 2, 3])
""" """
Dummy configuration to check type annotations ordering. Dummy configuration to check type annotations ordering.
""" """
...@@ -515,6 +524,12 @@ class TestConfigClass(unittest.TestCase): ...@@ -515,6 +524,12 @@ class TestConfigClass(unittest.TestCase):
self.assertDictEqual(cfg.to_dict(), basic_demo_cfg_correct) self.assertDictEqual(cfg.to_dict(), basic_demo_cfg_correct)
self.assertDictEqual(cfg.env.to_dict(), basic_demo_cfg_correct["env"]) self.assertDictEqual(cfg.env.to_dict(), basic_demo_cfg_correct["env"])
torch_cfg = BasicDemoTorchCfg()
torch_cfg_dict = torch_cfg.to_dict()
# We have to do a manual check because torch.Tensor does not work with assertDictEqual.
self.assertEqual(torch_cfg_dict["some_number"], 0)
self.assertTrue(torch.all(torch_cfg_dict["some_tensor"] == torch.tensor([1, 2, 3])))
def test_dict_conversion_order(self): def test_dict_conversion_order(self):
"""Tests that order is conserved when converting to dictionary.""" """Tests that order is conserved when converting to dictionary."""
true_outer_order = ["device_id", "env", "robot_default_state", "list_config"] true_outer_order = ["device_id", "env", "robot_default_state", "list_config"]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment