Unverified Commit f7b59b31 authored by James Tigue's avatar James Tigue Committed by GitHub

Adds observation term history support to Observation Manager (#1439)

# Description

<!--
Thank you for your interest in sending a pull request. Please make sure
to check the contribution guidelines.

Link: https://isaac-sim.github.io/IsaacLab/source/refs/contributing.html
-->

This PR adds observation history by adding configuration parameters to
ObservationTerms and having the ObservationManager handling the
collection and storage of the histories via CircularBuffers.

Fixes #1208 

<!-- As a practice, it is recommended to open an issue to have
discussions on the proposed pull request.
This makes it easier for the community to keep track of what is being
developed or added, and if a given feature
is demanded by more than one party. -->

## Type of change

<!-- As you go through the list, delete the ones that are not
applicable. -->

- New feature (non-breaking change which adds functionality)

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./isaaclab.sh --format`
- [x] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [x] I have added tests that prove my fix is effective or that my
feature works
- [x] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there

<!--
As you go through the checklist above, you can mark something as done by
putting an x character in it

For example,
- [x] I have done this task
- [ ] I have not done this task
-->

---------
Signed-off-by: 's avatarKelly Guo <kellyg@nvidia.com>
Co-authored-by: 's avatarFangzhou Yu <156015326+fyu-bdai@users.noreply.github.com>
Co-authored-by: 's avatarKelly Guo <kellyg@nvidia.com>
parent ee3f0224
[package]
# Note: Semantic Versioning is used: https://semver.org/
version = "0.27.29"
version = "0.28.0"
# Description
title = "Isaac Lab framework for Robot Learning"
......
Changelog
---------
0.28.0 (2024-12-15)
~~~~~~~~~~~~~~~~~~~
Added
^^^^^
* Added observation history computation to :class:`omni.isaac.lab.manager.observation_manager.ObservationManager`.
* Added ``history_length`` and ``flatten_history_dim`` configuration parameters to :class:`omni.isaac.lab.manager.manager_term_cfg.ObservationTermCfg`
* Added ``history_length`` and ``flatten_history_dim`` configuration parameters to :class:`omni.isaac.lab.manager.manager_term_cfg.ObservationGroupCfg`
* Added full buffer property to :class:`omni.isaac.lab.utils.buffers.circular_buffer.CircularBuffer`
0.27.29 (2024-12-15)
~~~~~~~~~~~~~~~~~~~~
......
......@@ -180,6 +180,19 @@ class ObservationTermCfg(ManagerTermBaseCfg):
please make sure the length of the tuple matches the dimensions of the tensor outputted from the term.
"""
history_length: int = 0
"""Number of past observations to store in the observation buffers. Defaults to 0, meaning no history.
Observation history initializes to empty, but is filled with the first append after reset or initialization. Subsequent history
only adds a single entry to the history buffer. If flatten_history_dim is set to True, the source data of shape
(N, H, D, ...) where N is the batch dimension and H is the history length will be reshaped to a 2D tensor of shape
(N, H*D*...). Otherwise, the data will be returned as is.
"""
flatten_history_dim: bool = True
"""Whether or not the observation manager should flatten history-based observation terms to a 2D (N, D) tensor.
Defaults to True."""
@configclass
class ObservationGroupCfg:
......@@ -201,6 +214,22 @@ class ObservationGroupCfg:
Otherwise, no corruption is applied.
"""
history_length: int | None = None
"""Number of past observation to store in the observation buffers for all observation terms in group.
This parameter will override :attr:`ObservationTermCfg.history_length` if set. Defaults to None. If None, each
terms history will be controlled on a per term basis. See :class:`ObservationTermCfg` for details on history_length
implementation.
"""
flatten_history_dim: bool = True
"""Flag to flatten history-based observation terms to a 2D (num_env, D) tensor for all observation terms in group.
Defaults to True.
This parameter will override all :attr:`ObservationTermCfg.flatten_history_dim` in the group if
ObservationGroupCfg.history_length is set.
"""
##
# Event manager
......
......@@ -8,12 +8,14 @@
from __future__ import annotations
import inspect
import numpy as np
import torch
from collections.abc import Sequence
from prettytable import PrettyTable
from typing import TYPE_CHECKING
from omni.isaac.lab.utils import modifiers
from omni.isaac.lab.utils.buffers import CircularBuffer
from .manager_base import ManagerBase, ManagerTermBase
from .manager_term_cfg import ObservationGroupCfg, ObservationTermCfg
......@@ -45,6 +47,11 @@ class ObservationManager(ManagerBase):
concatenated. In this case, please set the :attr:`ObservationGroupCfg.concatenate_terms` attribute in the
group configuration to False.
Observations can also have history. This means a running history is updated per sim step. History can be controlled
per :class:`ObservationTermCfg` (See the :attr:`ObservationTermCfg.history_length` and
:attr:`ObservationTermCfg.flatten_history_dim`). History can also be controlled via :class:`ObservationGroupCfg`
where group configuration overwrites per term configuration if set. History follows an oldest to newest ordering.
The observation manager can be used to compute observations for all the groups or for a specific group. The
observations are computed by calling the registered functions for each term in the group. The functions are
called in the order of the terms in the group. The functions are expected to return a tensor with shape
......@@ -174,12 +181,17 @@ class ObservationManager(ManagerBase):
def reset(self, env_ids: Sequence[int] | None = None) -> dict[str, float]:
# call all terms that are classes
for group_cfg in self._group_obs_class_term_cfgs.values():
for group_name, group_cfg in self._group_obs_class_term_cfgs.items():
for term_cfg in group_cfg:
term_cfg.func.reset(env_ids=env_ids)
# reset terms with history
for term_name in self._group_obs_term_names[group_name]:
if term_name in self._group_obs_term_history_buffer[group_name]:
self._group_obs_term_history_buffer[group_name][term_name].reset(batch_ids=env_ids)
# call all modifiers that are classes
for mod in self._group_obs_class_modifiers:
mod.reset(env_ids=env_ids)
# nothing to log here
return {}
......@@ -248,7 +260,7 @@ class ObservationManager(ManagerBase):
obs_terms = zip(group_term_names, self._group_obs_term_cfgs[group_name])
# evaluate terms: compute, add noise, clip, scale, custom modifiers
for name, term_cfg in obs_terms:
for term_name, term_cfg in obs_terms:
# compute term's value
obs: torch.Tensor = term_cfg.func(self._env, **term_cfg.params).clone()
# apply post-processing
......@@ -261,8 +273,17 @@ class ObservationManager(ManagerBase):
obs = obs.clip_(min=term_cfg.clip[0], max=term_cfg.clip[1])
if term_cfg.scale is not None:
obs = obs.mul_(term_cfg.scale)
# add value to list
group_obs[name] = obs
# Update the history buffer if observation term has history enabled
if term_cfg.history_length > 0:
self._group_obs_term_history_buffer[group_name][term_name].append(obs)
if term_cfg.flatten_history_dim:
group_obs[term_name] = self._group_obs_term_history_buffer[group_name][term_name].buffer.reshape(
self._env.num_envs, -1
)
else:
group_obs[term_name] = self._group_obs_term_history_buffer[group_name][term_name].buffer
else:
group_obs[term_name] = obs
# concatenate all observations in the group together
if self._group_obs_concatenate[group_name]:
......@@ -283,7 +304,7 @@ class ObservationManager(ManagerBase):
self._group_obs_term_cfgs: dict[str, list[ObservationTermCfg]] = dict()
self._group_obs_class_term_cfgs: dict[str, list[ObservationTermCfg]] = dict()
self._group_obs_concatenate: dict[str, bool] = dict()
self._group_obs_term_history_buffer: dict[str, dict] = dict()
# create a list to store modifiers that are classes
# we store it as a separate list to only call reset on them and prevent unnecessary calls
self._group_obs_class_modifiers: list[modifiers.ModifierBase] = list()
......@@ -309,6 +330,7 @@ class ObservationManager(ManagerBase):
self._group_obs_term_dim[group_name] = list()
self._group_obs_term_cfgs[group_name] = list()
self._group_obs_class_term_cfgs[group_name] = list()
group_entry_history_buffer: dict[str, CircularBuffer] = dict()
# read common config for the group
self._group_obs_concatenate[group_name] = group_cfg.concatenate_terms
# check if config is dict already
......@@ -319,7 +341,7 @@ class ObservationManager(ManagerBase):
# iterate over all the terms in each group
for term_name, term_cfg in group_cfg_items:
# skip non-obs settings
if term_name in ["enable_corruption", "concatenate_terms"]:
if term_name in ["enable_corruption", "concatenate_terms", "history_length", "flatten_history_dim"]:
continue
# check for non config
if term_cfg is None:
......@@ -335,12 +357,26 @@ class ObservationManager(ManagerBase):
# check noise settings
if not group_cfg.enable_corruption:
term_cfg.noise = None
# check group history params and override terms
if group_cfg.history_length is not None:
term_cfg.history_length = group_cfg.history_length
term_cfg.flatten_history_dim = group_cfg.flatten_history_dim
# add term config to list to list
self._group_obs_term_names[group_name].append(term_name)
self._group_obs_term_cfgs[group_name].append(term_cfg)
# call function the first time to fill up dimensions
obs_dims = tuple(term_cfg.func(self._env, **term_cfg.params).shape)
# create history buffers and calculate history term dimensions
if term_cfg.history_length > 0:
group_entry_history_buffer[term_name] = CircularBuffer(
max_len=term_cfg.history_length, batch_size=self._env.num_envs, device=self._env.device
)
old_dims = list(obs_dims)
old_dims.insert(1, term_cfg.history_length)
obs_dims = tuple(old_dims)
if term_cfg.flatten_history_dim:
obs_dims = (obs_dims[0], np.prod(obs_dims[1:]))
self._group_obs_term_dim[group_name].append(obs_dims[1:])
# if scale is set, check if single float or tuple
......@@ -411,3 +447,5 @@ class ObservationManager(ManagerBase):
self._group_obs_class_term_cfgs[group_name].append(term_cfg)
# call reset (in-case above call to get obs dims changed the state)
term_cfg.func.reset()
# add history buffers for each group
self._group_obs_term_history_buffer[group_name] = group_entry_history_buffer
......@@ -75,6 +75,16 @@ class CircularBuffer:
"""
return torch.minimum(self._num_pushes, self._max_len)
@property
def buffer(self) -> torch.Tensor:
"""Complete circular buffer with most recent entry at the end and oldest entry at the beginning.
Returns:
Complete circular buffer with most recent entry at the end and oldest entry at the beginning of dimension 1. The shape is [batch_size, max_length, data.shape[1:]].
"""
buf = self._buffer.clone()
buf = torch.roll(buf, shifts=self.max_length - self._pointer - 1, dims=0)
return torch.transpose(buf, dim0=0, dim1=1)
"""
Operations.
"""
......@@ -89,8 +99,10 @@ class CircularBuffer:
if batch_ids is None:
batch_ids = slice(None)
# reset the number of pushes for the specified batch indices
# note: we don't need to reset the buffer since it will be overwritten. The pointer handles this.
self._num_pushes[batch_ids] = 0
if self._buffer is not None:
# set buffer at batch_id reset indices to 0.0 so that the buffer() getter returns the cleared circular buffer after reset.
self._buffer[:, batch_ids, :] = 0.0
def append(self, data: torch.Tensor):
"""Append the data to the circular buffer.
......@@ -106,7 +118,7 @@ class CircularBuffer:
if data.shape[0] != self.batch_size:
raise ValueError(f"The input data has {data.shape[0]} environments while expecting {self.batch_size}")
# at the fist call, initialize the buffer
# at the first call, initialize the buffer size
if self._buffer is None:
self._pointer = -1
self._buffer = torch.empty((self.max_length, *data.shape), dtype=data.dtype, device=self._device)
......@@ -114,7 +126,12 @@ class CircularBuffer:
self._pointer = (self._pointer + 1) % self.max_length
# add the new data to the last layer
self._buffer[self._pointer] = data.to(self._device)
# increment number of number of pushes
# Check for batches with zero pushes and initialize all values in batch to first append
if 0 in self._num_pushes.tolist():
fill_ids = [i for i, x in enumerate(self._num_pushes.tolist()) if x == 0]
self._num_pushes.tolist().index(0) if 0 in self._num_pushes.tolist() else None
self._buffer[:, fill_ids, :] = data.to(self._device)[fill_ids]
# increment number of number of pushes for all batches
self._num_pushes += 1
def __getitem__(self, key: torch.Tensor) -> torch.Tensor:
......
......@@ -131,8 +131,51 @@ class TestObservationManager(unittest.TestCase):
self.obs_man = ObservationManager(cfg, self.env)
self.assertEqual(len(self.obs_man.active_terms["policy"]), 5)
# print the expected string
obs_man_str = str(self.obs_man)
print()
print(self.obs_man)
print(obs_man_str)
obs_man_str_split = obs_man_str.split("|")
term_1_str_index = obs_man_str_split.index(" term_1 ")
term_1_str_shape = obs_man_str_split[term_1_str_index + 1].strip()
self.assertEqual(term_1_str_shape, "(4,)")
def test_str_with_history(self):
"""Test the string representation of the observation manager with history terms."""
TERM_1_HISTORY = 5
@configclass
class MyObservationManagerCfg:
"""Test config class for observation manager."""
@configclass
class SampleGroupCfg(ObservationGroupCfg):
"""Test config class for policy observation group."""
term_1 = ObservationTermCfg(func="__main__:grilled_chicken", scale=10, history_length=TERM_1_HISTORY)
term_2 = ObservationTermCfg(func=grilled_chicken, scale=2)
term_3 = ObservationTermCfg(func=grilled_chicken_with_bbq, scale=5, params={"bbq": True})
term_4 = ObservationTermCfg(
func=grilled_chicken_with_yoghurt, scale=1.0, params={"hot": False, "bland": 2.0}
)
term_5 = ObservationTermCfg(
func=grilled_chicken_with_yoghurt_and_bbq, scale=1.0, params={"hot": False, "bland": 2.0}
)
policy: ObservationGroupCfg = SampleGroupCfg()
# create observation manager
cfg = MyObservationManagerCfg()
self.obs_man = ObservationManager(cfg, self.env)
self.assertEqual(len(self.obs_man.active_terms["policy"]), 5)
# print the expected string
obs_man_str = str(self.obs_man)
print()
print(obs_man_str)
obs_man_str_split = obs_man_str.split("|")
term_1_str_index = obs_man_str_split.index(" term_1 ")
term_1_str_shape = obs_man_str_split[term_1_str_index + 1].strip()
self.assertEqual(term_1_str_shape, "(20,)")
def test_config_equivalence(self):
"""Test the equivalence of observation manager created from different config types."""
......@@ -304,6 +347,157 @@ class TestObservationManager(unittest.TestCase):
torch.testing.assert_close(obs_policy[:, 5:8], obs_critic[:, 0:3])
torch.testing.assert_close(obs_policy[:, 8:11], obs_critic[:, 3:6])
def test_compute_with_history(self):
"""Test the observation computation with history buffers."""
HISTORY_LENGTH = 5
@configclass
class MyObservationManagerCfg:
"""Test config class for observation manager."""
@configclass
class PolicyCfg(ObservationGroupCfg):
"""Test config class for policy observation group."""
term_1 = ObservationTermCfg(func=grilled_chicken, history_length=HISTORY_LENGTH)
# total observation size: term_dim (4) * history_len (5) = 20
term_2 = ObservationTermCfg(func=lin_vel_w_data)
# total observation size: term_dim (3) = 3
policy: ObservationGroupCfg = PolicyCfg()
# create observation manager
cfg = MyObservationManagerCfg()
self.obs_man = ObservationManager(cfg, self.env)
# compute observation using manager
observations = self.obs_man.compute()
# obtain the group observations
obs_policy: torch.Tensor = observations["policy"]
# check the observation shape
self.assertEqual((self.env.num_envs, 23), obs_policy.shape)
# check the observation data
expected_obs_term_1_data = torch.ones(self.env.num_envs, 4 * HISTORY_LENGTH, device=self.env.device)
expected_obs_term_2_data = lin_vel_w_data(self.env)
expected_obs_data_t0 = torch.concat((expected_obs_term_1_data, expected_obs_term_2_data), dim=-1)
print(expected_obs_data_t0, obs_policy)
self.assertTrue(torch.equal(expected_obs_data_t0, obs_policy))
# test that the history buffer holds previous data
for _ in range(HISTORY_LENGTH):
observations = self.obs_man.compute()
obs_policy = observations["policy"]
expected_obs_term_1_data = torch.ones(self.env.num_envs, 4 * HISTORY_LENGTH, device=self.env.device)
expected_obs_data_t5 = torch.concat((expected_obs_term_1_data, expected_obs_term_2_data), dim=-1)
self.assertTrue(torch.equal(expected_obs_data_t5, obs_policy))
# test reset
self.obs_man.reset()
observations = self.obs_man.compute()
obs_policy = observations["policy"]
self.assertTrue(torch.equal(expected_obs_data_t0, obs_policy))
# test reset of specific env ids
reset_env_ids = [2, 4, 16]
self.obs_man.reset(reset_env_ids)
self.assertTrue(torch.equal(expected_obs_data_t0[reset_env_ids], obs_policy[reset_env_ids]))
def test_compute_with_2d_history(self):
"""Test the observation computation with history buffers for 2D observations."""
HISTORY_LENGTH = 5
@configclass
class MyObservationManagerCfg:
"""Test config class for observation manager."""
@configclass
class FlattenedPolicyCfg(ObservationGroupCfg):
"""Test config class for policy observation group."""
term_1 = ObservationTermCfg(
func=grilled_chicken_image, params={"bland": 1.0, "channel": 1}, history_length=HISTORY_LENGTH
)
# total observation size: term_dim (128, 256) * history_len (5) = 163840
@configclass
class PolicyCfg(ObservationGroupCfg):
"""Test config class for policy observation group."""
term_1 = ObservationTermCfg(
func=grilled_chicken_image,
params={"bland": 1.0, "channel": 1},
history_length=HISTORY_LENGTH,
flatten_history_dim=False,
)
# total observation size: (5, 128, 256, 1)
flat_obs_policy: ObservationGroupCfg = FlattenedPolicyCfg()
policy: ObservationGroupCfg = PolicyCfg()
# create observation manager
cfg = MyObservationManagerCfg()
self.obs_man = ObservationManager(cfg, self.env)
# compute observation using manager
observations = self.obs_man.compute()
# obtain the group observations
obs_policy_flat: torch.Tensor = observations["flat_obs_policy"]
obs_policy: torch.Tensor = observations["policy"]
# check the observation shapes
self.assertEqual((self.env.num_envs, 163840), obs_policy_flat.shape)
self.assertEqual((self.env.num_envs, HISTORY_LENGTH, 128, 256, 1), obs_policy.shape)
def test_compute_with_group_history(self):
"""Test the observation computation with group level history buffer configuration."""
TERM_HISTORY_LENGTH = 5
GROUP_HISTORY_LENGTH = 10
@configclass
class MyObservationManagerCfg:
"""Test config class for observation manager."""
@configclass
class PolicyCfg(ObservationGroupCfg):
"""Test config class for policy observation group."""
history_length = GROUP_HISTORY_LENGTH
# group level history length will override all terms
term_1 = ObservationTermCfg(func=grilled_chicken, history_length=TERM_HISTORY_LENGTH)
# total observation size: term_dim (4) * history_len (5) = 20
# with override total obs size: term_dim (4) * history_len (10) = 40
term_2 = ObservationTermCfg(func=lin_vel_w_data)
# total observation size: term_dim (3) = 3
# with override total obs size: term_dim (3) * history_len (10) = 30
policy: ObservationGroupCfg = PolicyCfg()
# create observation manager
cfg = MyObservationManagerCfg()
self.obs_man = ObservationManager(cfg, self.env)
# compute observation using manager
observations = self.obs_man.compute()
# obtain the group observations
obs_policy: torch.Tensor = observations["policy"]
# check the total observation shape
self.assertEqual((self.env.num_envs, 70), obs_policy.shape)
# check the observation data is initialized properly
expected_obs_term_1_data = torch.ones(self.env.num_envs, 4 * GROUP_HISTORY_LENGTH, device=self.env.device)
expected_obs_term_2_data = lin_vel_w_data(self.env).repeat(1, GROUP_HISTORY_LENGTH)
expected_obs_data_t0 = torch.concat((expected_obs_term_1_data, expected_obs_term_2_data), dim=-1)
self.assertTrue(torch.equal(expected_obs_data_t0, obs_policy))
# test that the history buffer holds previous data
for _ in range(GROUP_HISTORY_LENGTH):
observations = self.obs_man.compute()
obs_policy = observations["policy"]
expected_obs_term_1_data = torch.ones(self.env.num_envs, 4 * GROUP_HISTORY_LENGTH, device=self.env.device)
expected_obs_term_2_data = lin_vel_w_data(self.env).repeat(1, GROUP_HISTORY_LENGTH)
expected_obs_data_t10 = torch.concat((expected_obs_term_1_data, expected_obs_term_2_data), dim=-1)
self.assertTrue(torch.equal(expected_obs_data_t10, obs_policy))
# test reset
self.obs_man.reset()
observations = self.obs_man.compute()
obs_policy = observations["policy"]
self.assertTrue(torch.equal(expected_obs_data_t0, obs_policy))
# test reset of specific env ids
reset_env_ids = [2, 4, 16]
self.obs_man.reset(reset_env_ids)
self.assertTrue(torch.equal(expected_obs_data_t0[reset_env_ids], obs_policy[reset_env_ids]))
def test_invalid_observation_config(self):
"""Test the invalid observation config."""
......
......@@ -46,9 +46,31 @@ class TestCircularBuffer(unittest.TestCase):
# reset the buffer
self.buffer.reset()
# check if the buffer is empty
# check if the buffer has zeros entries
self.assertEqual(self.buffer.current_length.tolist(), [0, 0, 0])
def test_reset_subset(self):
"""Test resetting a subset of batches in the circular buffer."""
data1 = torch.ones((self.batch_size, 2), device=self.device)
data2 = 2.0 * data1.clone()
data3 = 3.0 * data1.clone()
self.buffer.append(data1)
self.buffer.append(data2)
# reset the buffer
reset_batch_id = 1
self.buffer.reset(batch_ids=[reset_batch_id])
# check that correct batch is reset
self.assertEqual(self.buffer.current_length.tolist()[reset_batch_id], 0)
# Append new set of data
self.buffer.append(data3)
# check if the correct number of entries are in each batch
expected_length = [3, 3, 3]
expected_length[reset_batch_id] = 1
self.assertEqual(self.buffer.current_length.tolist(), expected_length)
# check that all entries of the recently reset and appended batch are equal
for i in range(self.max_len):
torch.testing.assert_close(self.buffer.buffer[reset_batch_id, 0], self.buffer.buffer[reset_batch_id, i])
def test_append_and_retrieve(self):
"""Test appending and retrieving data from the circular buffer."""
# append some data
......@@ -121,6 +143,33 @@ class TestCircularBuffer(unittest.TestCase):
retrieved_data = self.buffer[torch.tensor([5, 5, 5], device=self.device)]
self.assertTrue(torch.equal(retrieved_data, data1))
def test_return_buffer_prop(self):
"""Test retrieving the whole buffer for correct size and contents.
Returning the whole buffer should have the shape [batch_size,max_len,data.shape[1:]]
"""
num_overflow = 2
for i in range(self.buffer.max_length + num_overflow):
data = torch.tensor([[i]], device=self.device).repeat(3, 2)
self.buffer.append(data)
retrieved_buffer = self.buffer.buffer
# check shape
self.assertTrue(retrieved_buffer.shape == torch.Size([self.buffer.batch_size, self.buffer.max_length, 2]))
# check that batch is first dimension
torch.testing.assert_close(retrieved_buffer[0], retrieved_buffer[1])
# check oldest
torch.testing.assert_close(
retrieved_buffer[:, 0], torch.tensor([[num_overflow]], device=self.device).repeat(3, 2)
)
# check most recent
torch.testing.assert_close(
retrieved_buffer[:, -1],
torch.tensor([[self.buffer.max_length + num_overflow - 1]], device=self.device).repeat(3, 2),
)
# check that it is returned oldest first
for idx in range(self.buffer.max_length - 1):
self.assertTrue(torch.all(torch.le(retrieved_buffer[:, idx], retrieved_buffer[:, idx + 1])))
if __name__ == "__main__":
run_tests()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment