Unverified Commit 421ece0b authored by Mayank Mittal, committed by GitHub

Fixes shared memory address between observation terms data (#493)

# Description

This MR makes the following fixes:

* Adds a clone operator to the observation manager term computation to
prevent shared data between terms
* Fixes the flushing of data for the imitation learning workflow

Fixes https://github.com/NVIDIA-Omniverse/orbit/issues/356

## Type of change

- Bug fix (non-breaking change which fixes an issue)

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./orbit.sh --format`
- [x] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [x] I have added tests that prove my fix is effective or that my
feature works
- [x] I have run all the tests with `./orbit.sh --test` and they pass
- [x] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there
parent bca680a9
[package] [package]
# Note: Semantic Versioning is used: https://semver.org/ # Note: Semantic Versioning is used: https://semver.org/
version = "0.15.9" version = "0.15.10"
# Description # Description
title = "ORBIT framework for Robot Learning" title = "ORBIT framework for Robot Learning"
......
Changelog Changelog
--------- ---------
0.15.10 (2024-04-11)
~~~~~~~~~~~~~~~~~~~~
Fixed
^^^^^
* Fixed sharing of the same memory address between returned tensors from observation terms
in the :class:`omni.isaac.orbit.managers.ObservationManager` class. Earlier, the returned
tensors could map to the same memory address, causing issues when the tensors were modified
during scaling, clipping or other operations.
0.15.9 (2024-04-04) 0.15.9 (2024-04-04)
~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~
......
...@@ -168,7 +168,7 @@ class ObservationManager(ManagerBase): ...@@ -168,7 +168,7 @@ class ObservationManager(ManagerBase):
# evaluate terms: compute, add noise, clip, scale. # evaluate terms: compute, add noise, clip, scale.
for name, term_cfg in obs_terms: for name, term_cfg in obs_terms:
# compute term's value # compute term's value
obs: torch.Tensor = term_cfg.func(self._env, **term_cfg.params) obs: torch.Tensor = term_cfg.func(self._env, **term_cfg.params).clone()
# apply post-processing # apply post-processing
if term_cfg.noise: if term_cfg.noise:
obs = term_cfg.noise.func(obs, term_cfg.noise) obs = term_cfg.noise.func(obs, term_cfg.noise)
......
...@@ -10,8 +10,7 @@ from __future__ import annotations ...@@ -10,8 +10,7 @@ from __future__ import annotations
from omni.isaac.orbit.app import AppLauncher, run_tests from omni.isaac.orbit.app import AppLauncher, run_tests
# launch omniverse app # launch omniverse app
config = {"headless": True} simulation_app = AppLauncher(headless=True).app
simulation_app = AppLauncher(config).app
"""Rest everything follows.""" """Rest everything follows."""
...@@ -71,11 +70,32 @@ class non_callable_complex_function_class(ManagerTermBase): ...@@ -71,11 +70,32 @@ class non_callable_complex_function_class(ManagerTermBase):
return torch.ones(env.num_envs, 2, device=env.device) * self._cost return torch.ones(env.num_envs, 2, device=env.device) * self._cost
class MyDataClass:
def __init__(self, num_envs: int, device: str):
self.pos_w = torch.rand((num_envs, 3), device=device)
self.lin_vel_w = torch.rand((num_envs, 3), device=device)
def pos_w_data(env) -> torch.Tensor:
return env.data.pos_w
def lin_vel_w_data(env) -> torch.Tensor:
return env.data.lin_vel_w
class TestObservationManager(unittest.TestCase): class TestObservationManager(unittest.TestCase):
"""Test cases for various situations with observation manager.""" """Test cases for various situations with observation manager."""
def setUp(self) -> None: def setUp(self) -> None:
self.env = namedtuple("BaseEnv", ["num_envs", "device"])(20, "cpu") # set up the environment
self.num_envs = 20
self.device = "cuda:0"
# create dummy environment
self.env = namedtuple("BaseEnv", ["num_envs", "device", "data"])(
self.num_envs, self.device, MyDataClass(self.num_envs, self.device)
)
def test_str(self): def test_str(self):
"""Test the string representation of the observation manager.""" """Test the string representation of the observation manager."""
...@@ -203,16 +223,39 @@ class TestObservationManager(unittest.TestCase): ...@@ -203,16 +223,39 @@ class TestObservationManager(unittest.TestCase):
term_1 = ObservationTermCfg(func=grilled_chicken, scale=10) term_1 = ObservationTermCfg(func=grilled_chicken, scale=10)
term_2 = ObservationTermCfg(func=grilled_chicken_with_curry, scale=0.0, params={"hot": False}) term_2 = ObservationTermCfg(func=grilled_chicken_with_curry, scale=0.0, params={"hot": False})
term_3 = ObservationTermCfg(func=pos_w_data, scale=2.0)
term_4 = ObservationTermCfg(func=lin_vel_w_data, scale=1.5)
@configclass
class CriticCfg(ObservationGroupCfg):
term_1 = ObservationTermCfg(func=pos_w_data, scale=2.0)
term_2 = ObservationTermCfg(func=lin_vel_w_data, scale=1.5)
term_3 = ObservationTermCfg(func=pos_w_data, scale=2.0)
term_4 = ObservationTermCfg(func=lin_vel_w_data, scale=1.5)
policy: ObservationGroupCfg = PolicyCfg() policy: ObservationGroupCfg = PolicyCfg()
critic: ObservationGroupCfg = CriticCfg()
# create observation manager # create observation manager
cfg = MyObservationManagerCfg() cfg = MyObservationManagerCfg()
self.obs_man = ObservationManager(cfg, self.env) self.obs_man = ObservationManager(cfg, self.env)
# compute observation using manager # compute observation using manager
observations = self.obs_man.compute() observations = self.obs_man.compute()
# obtain the group observations
obs_policy: torch.Tensor = observations["policy"]
obs_critic: torch.Tensor = observations["critic"]
# check the observation shape # check the observation shape
self.assertEqual((self.env.num_envs, 5), observations["policy"].shape) self.assertEqual((self.env.num_envs, 11), obs_policy.shape)
self.assertEqual((self.env.num_envs, 12), obs_critic.shape)
# make sure that the data are the same for same terms
# -- within group
torch.testing.assert_close(obs_critic[:, 0:3], obs_critic[:, 6:9])
torch.testing.assert_close(obs_critic[:, 3:6], obs_critic[:, 9:12])
# -- between groups
torch.testing.assert_close(obs_policy[:, 5:8], obs_critic[:, 0:3])
torch.testing.assert_close(obs_policy[:, 8:11], obs_critic[:, 3:6])
def test_invalid_observation_config(self): def test_invalid_observation_config(self):
"""Test the invalid observation config.""" """Test the invalid observation config."""
......
...@@ -10,8 +10,7 @@ from __future__ import annotations ...@@ -10,8 +10,7 @@ from __future__ import annotations
from omni.isaac.orbit.app import AppLauncher, run_tests from omni.isaac.orbit.app import AppLauncher, run_tests
# launch omniverse app # launch omniverse app
config = {"headless": True} simulation_app = AppLauncher(headless=True).app
simulation_app = AppLauncher(config).app
"""Rest everything follows.""" """Rest everything follows."""
......
...@@ -193,10 +193,12 @@ class RobomimicDataCollector: ...@@ -193,10 +193,12 @@ class RobomimicDataCollector:
if self._h5_file_stream is None or self._h5_data_group is None: if self._h5_file_stream is None or self._h5_data_group is None:
carb.log_error("No file stream has been opened. Please call reset before flushing data.") carb.log_error("No file stream has been opened. Please call reset before flushing data.")
return return
# iterate over each environment and add their data # iterate over each environment and add their data
for index in env_ids: for index in env_ids:
# data corresponding to demo # data corresponding to demo
env_dataset = self._dataset[f"env_{index}"] env_dataset = self._dataset[f"env_{index}"]
# create episode group based on demo count # create episode group based on demo count
h5_episode_group = self._h5_data_group.create_group(f"demo_{self._demo_count}") h5_episode_group = self._h5_data_group.create_group(f"demo_{self._demo_count}")
# store number of steps taken # store number of steps taken
...@@ -213,17 +215,23 @@ class RobomimicDataCollector: ...@@ -213,17 +215,23 @@ class RobomimicDataCollector:
h5_episode_group.create_dataset(key, data=np.array(value)) h5_episode_group.create_dataset(key, data=np.array(value))
# increment total step counts # increment total step counts
self._h5_data_group.attrs["total"] += h5_episode_group.attrs["num_samples"] self._h5_data_group.attrs["total"] += h5_episode_group.attrs["num_samples"]
# increment total demo counts # increment total demo counts
self._demo_count += 1 self._demo_count += 1
# reset buffer for environment # reset buffer for environment
self._dataset[f"env_{index}"] = dict() self._dataset[f"env_{index}"] = dict()
# dump at desired frequency # dump at desired frequency
if self._demo_count % self._flush_freq == 0: if self._demo_count % self._flush_freq == 0:
self._h5_file_stream.flush() self._h5_file_stream.flush()
print(f">>> Flushing data to disk. Collected demos: {self._demo_count} / {self._num_demos}") print(f">>> Flushing data to disk. Collected demos: {self._demo_count} / {self._num_demos}")
# if demos collected then stop
if self._demo_count >= self._num_demos: # if demos collected then stop
self.close() if self._demo_count >= self._num_demos:
print(f">>> Desired number of demonstrations collected: {self._demo_count} >= {self._num_demos}.")
self.close()
# break out of loop
break
def close(self): def close(self):
"""Stop recording and save the file at its current state.""" """Stop recording and save the file at its current state."""
...@@ -266,6 +274,8 @@ class RobomimicDataCollector: ...@@ -266,6 +274,8 @@ class RobomimicDataCollector:
if self._env_config is None: if self._env_config is None:
self._env_config = dict() self._env_config = dict()
# -- add info # -- add info
self._h5_data_group.attrs["env_args"] = json.dumps( self._h5_data_group.attrs["env_args"] = json.dumps({
{"env_name": self._env_name, "type": env_type, "env_kwargs": self._env_config} "env_name": self._env_name,
) "type": env_type,
"env_kwargs": self._env_config,
})
...@@ -140,12 +140,14 @@ def main(): ...@@ -140,12 +140,14 @@ def main():
collector_interface.add(f"obs/{key}", value) collector_interface.add(f"obs/{key}", value)
# -- actions # -- actions
collector_interface.add("actions", actions) collector_interface.add("actions", actions)
# perform action on environment # perform action on environment
obs_dict, rewards, terminated, truncated, info = env.step(actions) obs_dict, rewards, terminated, truncated, info = env.step(actions)
dones = terminated | truncated dones = terminated | truncated
# check that simulation is stopped or not # check that simulation is stopped or not
if env.unwrapped.sim.is_stopped(): if env.unwrapped.sim.is_stopped():
break break
# robomimic only cares about policy observations # robomimic only cares about policy observations
# store signals from the environment # store signals from the environment
# -- next_obs # -- next_obs
...@@ -163,6 +165,10 @@ def main(): ...@@ -163,6 +165,10 @@ def main():
reset_env_ids = dones.nonzero(as_tuple=False).squeeze(-1) reset_env_ids = dones.nonzero(as_tuple=False).squeeze(-1)
collector_interface.flush(reset_env_ids) collector_interface.flush(reset_env_ids)
# check if enough data is collected
if collector_interface.is_stopped():
break
# close the simulator # close the simulator
collector_interface.close() collector_interface.close()
env.close() env.close()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment