Unverified Commit f7b59b31 authored by James Tigue's avatar James Tigue Committed by GitHub

Adds observation term history support to Observation Manager (#1439)

# Description

<!--
Thank you for your interest in sending a pull request. Please make sure
to check the contribution guidelines.

Link: https://isaac-sim.github.io/IsaacLab/source/refs/contributing.html
-->

This PR adds observation history by adding configuration parameters to
ObservationTerms and having the ObservationManager handling the
collection and storage of the histories via CircularBuffers.

Fixes #1208 

<!-- As a practice, it is recommended to open an issue to have
discussions on the proposed pull request.
This makes it easier for the community to keep track of what is being
developed or added, and if a given feature
is demanded by more than one party. -->

## Type of change

<!-- As you go through the list, delete the ones that are not
applicable. -->

- New feature (non-breaking change which adds functionality)

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./isaaclab.sh --format`
- [x] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [x] I have added tests that prove my fix is effective or that my
feature works
- [x] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there

<!--
As you go through the checklist above, you can mark something as done by
putting an x character in it

For example,
- [x] I have done this task
- [ ] I have not done this task
-->

---------
Signed-off-by: 's avatarKelly Guo <kellyg@nvidia.com>
Co-authored-by: 's avatarFangzhou Yu <156015326+fyu-bdai@users.noreply.github.com>
Co-authored-by: 's avatarKelly Guo <kellyg@nvidia.com>
parent ee3f0224
[package] [package]
# Note: Semantic Versioning is used: https://semver.org/ # Note: Semantic Versioning is used: https://semver.org/
version = "0.27.29" version = "0.28.0"
# Description # Description
title = "Isaac Lab framework for Robot Learning" title = "Isaac Lab framework for Robot Learning"
......
Changelog Changelog
--------- ---------
0.28.0 (2024-12-15)
~~~~~~~~~~~~~~~~~~~
Added
^^^^^
* Added observation history computation to :class:`omni.isaac.lab.manager.observation_manager.ObservationManager`.
* Added ``history_length`` and ``flatten_history_dim`` configuration parameters to :class:`omni.isaac.lab.manager.manager_term_cfg.ObservationTermCfg`
* Added ``history_length`` and ``flatten_history_dim`` configuration parameters to :class:`omni.isaac.lab.manager.manager_term_cfg.ObservationGroupCfg`
* Added full buffer property to :class:`omni.isaac.lab.utils.buffers.circular_buffer.CircularBuffer`
0.27.29 (2024-12-15) 0.27.29 (2024-12-15)
~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~
......
...@@ -180,6 +180,19 @@ class ObservationTermCfg(ManagerTermBaseCfg): ...@@ -180,6 +180,19 @@ class ObservationTermCfg(ManagerTermBaseCfg):
please make sure the length of the tuple matches the dimensions of the tensor outputted from the term. please make sure the length of the tuple matches the dimensions of the tensor outputted from the term.
""" """
history_length: int = 0
"""Number of past observations to store in the observation buffers. Defaults to 0, meaning no history.
Observation history initializes to empty, but is filled with the first append after reset or initialization. Subsequent history
only adds a single entry to the history buffer. If flatten_history_dim is set to True, the source data of shape
(N, H, D, ...) where N is the batch dimension and H is the history length will be reshaped to a 2D tensor of shape
(N, H*D*...). Otherwise, the data will be returned as is.
"""
flatten_history_dim: bool = True
"""Whether or not the observation manager should flatten history-based observation terms to a 2D (N, D) tensor.
Defaults to True."""
@configclass @configclass
class ObservationGroupCfg: class ObservationGroupCfg:
...@@ -201,6 +214,22 @@ class ObservationGroupCfg: ...@@ -201,6 +214,22 @@ class ObservationGroupCfg:
Otherwise, no corruption is applied. Otherwise, no corruption is applied.
""" """
history_length: int | None = None
"""Number of past observation to store in the observation buffers for all observation terms in group.
This parameter will override :attr:`ObservationTermCfg.history_length` if set. Defaults to None. If None, each
terms history will be controlled on a per term basis. See :class:`ObservationTermCfg` for details on history_length
implementation.
"""
flatten_history_dim: bool = True
"""Flag to flatten history-based observation terms to a 2D (num_env, D) tensor for all observation terms in group.
Defaults to True.
This parameter will override all :attr:`ObservationTermCfg.flatten_history_dim` in the group if
ObservationGroupCfg.history_length is set.
"""
## ##
# Event manager # Event manager
......
...@@ -8,12 +8,14 @@ ...@@ -8,12 +8,14 @@
from __future__ import annotations from __future__ import annotations
import inspect import inspect
import numpy as np
import torch import torch
from collections.abc import Sequence from collections.abc import Sequence
from prettytable import PrettyTable from prettytable import PrettyTable
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from omni.isaac.lab.utils import modifiers from omni.isaac.lab.utils import modifiers
from omni.isaac.lab.utils.buffers import CircularBuffer
from .manager_base import ManagerBase, ManagerTermBase from .manager_base import ManagerBase, ManagerTermBase
from .manager_term_cfg import ObservationGroupCfg, ObservationTermCfg from .manager_term_cfg import ObservationGroupCfg, ObservationTermCfg
...@@ -45,6 +47,11 @@ class ObservationManager(ManagerBase): ...@@ -45,6 +47,11 @@ class ObservationManager(ManagerBase):
concatenated. In this case, please set the :attr:`ObservationGroupCfg.concatenate_terms` attribute in the concatenated. In this case, please set the :attr:`ObservationGroupCfg.concatenate_terms` attribute in the
group configuration to False. group configuration to False.
Observations can also have history. This means a running history is updated per sim step. History can be controlled
per :class:`ObservationTermCfg` (See the :attr:`ObservationTermCfg.history_length` and
:attr:`ObservationTermCfg.flatten_history_dim`). History can also be controlled via :class:`ObservationGroupCfg`
where group configuration overwrites per term configuration if set. History follows an oldest to newest ordering.
The observation manager can be used to compute observations for all the groups or for a specific group. The The observation manager can be used to compute observations for all the groups or for a specific group. The
observations are computed by calling the registered functions for each term in the group. The functions are observations are computed by calling the registered functions for each term in the group. The functions are
called in the order of the terms in the group. The functions are expected to return a tensor with shape called in the order of the terms in the group. The functions are expected to return a tensor with shape
...@@ -174,12 +181,17 @@ class ObservationManager(ManagerBase): ...@@ -174,12 +181,17 @@ class ObservationManager(ManagerBase):
def reset(self, env_ids: Sequence[int] | None = None) -> dict[str, float]: def reset(self, env_ids: Sequence[int] | None = None) -> dict[str, float]:
# call all terms that are classes # call all terms that are classes
for group_cfg in self._group_obs_class_term_cfgs.values(): for group_name, group_cfg in self._group_obs_class_term_cfgs.items():
for term_cfg in group_cfg: for term_cfg in group_cfg:
term_cfg.func.reset(env_ids=env_ids) term_cfg.func.reset(env_ids=env_ids)
# reset terms with history
for term_name in self._group_obs_term_names[group_name]:
if term_name in self._group_obs_term_history_buffer[group_name]:
self._group_obs_term_history_buffer[group_name][term_name].reset(batch_ids=env_ids)
# call all modifiers that are classes # call all modifiers that are classes
for mod in self._group_obs_class_modifiers: for mod in self._group_obs_class_modifiers:
mod.reset(env_ids=env_ids) mod.reset(env_ids=env_ids)
# nothing to log here # nothing to log here
return {} return {}
...@@ -248,7 +260,7 @@ class ObservationManager(ManagerBase): ...@@ -248,7 +260,7 @@ class ObservationManager(ManagerBase):
obs_terms = zip(group_term_names, self._group_obs_term_cfgs[group_name]) obs_terms = zip(group_term_names, self._group_obs_term_cfgs[group_name])
# evaluate terms: compute, add noise, clip, scale, custom modifiers # evaluate terms: compute, add noise, clip, scale, custom modifiers
for name, term_cfg in obs_terms: for term_name, term_cfg in obs_terms:
# compute term's value # compute term's value
obs: torch.Tensor = term_cfg.func(self._env, **term_cfg.params).clone() obs: torch.Tensor = term_cfg.func(self._env, **term_cfg.params).clone()
# apply post-processing # apply post-processing
...@@ -261,8 +273,17 @@ class ObservationManager(ManagerBase): ...@@ -261,8 +273,17 @@ class ObservationManager(ManagerBase):
obs = obs.clip_(min=term_cfg.clip[0], max=term_cfg.clip[1]) obs = obs.clip_(min=term_cfg.clip[0], max=term_cfg.clip[1])
if term_cfg.scale is not None: if term_cfg.scale is not None:
obs = obs.mul_(term_cfg.scale) obs = obs.mul_(term_cfg.scale)
# add value to list # Update the history buffer if observation term has history enabled
group_obs[name] = obs if term_cfg.history_length > 0:
self._group_obs_term_history_buffer[group_name][term_name].append(obs)
if term_cfg.flatten_history_dim:
group_obs[term_name] = self._group_obs_term_history_buffer[group_name][term_name].buffer.reshape(
self._env.num_envs, -1
)
else:
group_obs[term_name] = self._group_obs_term_history_buffer[group_name][term_name].buffer
else:
group_obs[term_name] = obs
# concatenate all observations in the group together # concatenate all observations in the group together
if self._group_obs_concatenate[group_name]: if self._group_obs_concatenate[group_name]:
...@@ -283,7 +304,7 @@ class ObservationManager(ManagerBase): ...@@ -283,7 +304,7 @@ class ObservationManager(ManagerBase):
self._group_obs_term_cfgs: dict[str, list[ObservationTermCfg]] = dict() self._group_obs_term_cfgs: dict[str, list[ObservationTermCfg]] = dict()
self._group_obs_class_term_cfgs: dict[str, list[ObservationTermCfg]] = dict() self._group_obs_class_term_cfgs: dict[str, list[ObservationTermCfg]] = dict()
self._group_obs_concatenate: dict[str, bool] = dict() self._group_obs_concatenate: dict[str, bool] = dict()
self._group_obs_term_history_buffer: dict[str, dict] = dict()
# create a list to store modifiers that are classes # create a list to store modifiers that are classes
# we store it as a separate list to only call reset on them and prevent unnecessary calls # we store it as a separate list to only call reset on them and prevent unnecessary calls
self._group_obs_class_modifiers: list[modifiers.ModifierBase] = list() self._group_obs_class_modifiers: list[modifiers.ModifierBase] = list()
...@@ -309,6 +330,7 @@ class ObservationManager(ManagerBase): ...@@ -309,6 +330,7 @@ class ObservationManager(ManagerBase):
self._group_obs_term_dim[group_name] = list() self._group_obs_term_dim[group_name] = list()
self._group_obs_term_cfgs[group_name] = list() self._group_obs_term_cfgs[group_name] = list()
self._group_obs_class_term_cfgs[group_name] = list() self._group_obs_class_term_cfgs[group_name] = list()
group_entry_history_buffer: dict[str, CircularBuffer] = dict()
# read common config for the group # read common config for the group
self._group_obs_concatenate[group_name] = group_cfg.concatenate_terms self._group_obs_concatenate[group_name] = group_cfg.concatenate_terms
# check if config is dict already # check if config is dict already
...@@ -319,7 +341,7 @@ class ObservationManager(ManagerBase): ...@@ -319,7 +341,7 @@ class ObservationManager(ManagerBase):
# iterate over all the terms in each group # iterate over all the terms in each group
for term_name, term_cfg in group_cfg_items: for term_name, term_cfg in group_cfg_items:
# skip non-obs settings # skip non-obs settings
if term_name in ["enable_corruption", "concatenate_terms"]: if term_name in ["enable_corruption", "concatenate_terms", "history_length", "flatten_history_dim"]:
continue continue
# check for non config # check for non config
if term_cfg is None: if term_cfg is None:
...@@ -335,12 +357,26 @@ class ObservationManager(ManagerBase): ...@@ -335,12 +357,26 @@ class ObservationManager(ManagerBase):
# check noise settings # check noise settings
if not group_cfg.enable_corruption: if not group_cfg.enable_corruption:
term_cfg.noise = None term_cfg.noise = None
# check group history params and override terms
if group_cfg.history_length is not None:
term_cfg.history_length = group_cfg.history_length
term_cfg.flatten_history_dim = group_cfg.flatten_history_dim
# add term config to list to list # add term config to list to list
self._group_obs_term_names[group_name].append(term_name) self._group_obs_term_names[group_name].append(term_name)
self._group_obs_term_cfgs[group_name].append(term_cfg) self._group_obs_term_cfgs[group_name].append(term_cfg)
# call function the first time to fill up dimensions # call function the first time to fill up dimensions
obs_dims = tuple(term_cfg.func(self._env, **term_cfg.params).shape) obs_dims = tuple(term_cfg.func(self._env, **term_cfg.params).shape)
# create history buffers and calculate history term dimensions
if term_cfg.history_length > 0:
group_entry_history_buffer[term_name] = CircularBuffer(
max_len=term_cfg.history_length, batch_size=self._env.num_envs, device=self._env.device
)
old_dims = list(obs_dims)
old_dims.insert(1, term_cfg.history_length)
obs_dims = tuple(old_dims)
if term_cfg.flatten_history_dim:
obs_dims = (obs_dims[0], np.prod(obs_dims[1:]))
self._group_obs_term_dim[group_name].append(obs_dims[1:]) self._group_obs_term_dim[group_name].append(obs_dims[1:])
# if scale is set, check if single float or tuple # if scale is set, check if single float or tuple
...@@ -411,3 +447,5 @@ class ObservationManager(ManagerBase): ...@@ -411,3 +447,5 @@ class ObservationManager(ManagerBase):
self._group_obs_class_term_cfgs[group_name].append(term_cfg) self._group_obs_class_term_cfgs[group_name].append(term_cfg)
# call reset (in-case above call to get obs dims changed the state) # call reset (in-case above call to get obs dims changed the state)
term_cfg.func.reset() term_cfg.func.reset()
# add history buffers for each group
self._group_obs_term_history_buffer[group_name] = group_entry_history_buffer
...@@ -75,6 +75,16 @@ class CircularBuffer: ...@@ -75,6 +75,16 @@ class CircularBuffer:
""" """
return torch.minimum(self._num_pushes, self._max_len) return torch.minimum(self._num_pushes, self._max_len)
@property
def buffer(self) -> torch.Tensor:
"""Complete circular buffer with most recent entry at the end and oldest entry at the beginning.
Returns:
Complete circular buffer with most recent entry at the end and oldest entry at the beginning of dimension 1. The shape is [batch_size, max_length, data.shape[1:]].
"""
buf = self._buffer.clone()
buf = torch.roll(buf, shifts=self.max_length - self._pointer - 1, dims=0)
return torch.transpose(buf, dim0=0, dim1=1)
""" """
Operations. Operations.
""" """
...@@ -89,8 +99,10 @@ class CircularBuffer: ...@@ -89,8 +99,10 @@ class CircularBuffer:
if batch_ids is None: if batch_ids is None:
batch_ids = slice(None) batch_ids = slice(None)
# reset the number of pushes for the specified batch indices # reset the number of pushes for the specified batch indices
# note: we don't need to reset the buffer since it will be overwritten. The pointer handles this.
self._num_pushes[batch_ids] = 0 self._num_pushes[batch_ids] = 0
if self._buffer is not None:
# set buffer at batch_id reset indices to 0.0 so that the buffer() getter returns the cleared circular buffer after reset.
self._buffer[:, batch_ids, :] = 0.0
def append(self, data: torch.Tensor): def append(self, data: torch.Tensor):
"""Append the data to the circular buffer. """Append the data to the circular buffer.
...@@ -106,7 +118,7 @@ class CircularBuffer: ...@@ -106,7 +118,7 @@ class CircularBuffer:
if data.shape[0] != self.batch_size: if data.shape[0] != self.batch_size:
raise ValueError(f"The input data has {data.shape[0]} environments while expecting {self.batch_size}") raise ValueError(f"The input data has {data.shape[0]} environments while expecting {self.batch_size}")
# at the fist call, initialize the buffer # at the first call, initialize the buffer size
if self._buffer is None: if self._buffer is None:
self._pointer = -1 self._pointer = -1
self._buffer = torch.empty((self.max_length, *data.shape), dtype=data.dtype, device=self._device) self._buffer = torch.empty((self.max_length, *data.shape), dtype=data.dtype, device=self._device)
...@@ -114,7 +126,12 @@ class CircularBuffer: ...@@ -114,7 +126,12 @@ class CircularBuffer:
self._pointer = (self._pointer + 1) % self.max_length self._pointer = (self._pointer + 1) % self.max_length
# add the new data to the last layer # add the new data to the last layer
self._buffer[self._pointer] = data.to(self._device) self._buffer[self._pointer] = data.to(self._device)
# increment number of number of pushes # Check for batches with zero pushes and initialize all values in batch to first append
if 0 in self._num_pushes.tolist():
fill_ids = [i for i, x in enumerate(self._num_pushes.tolist()) if x == 0]
self._num_pushes.tolist().index(0) if 0 in self._num_pushes.tolist() else None
self._buffer[:, fill_ids, :] = data.to(self._device)[fill_ids]
# increment number of number of pushes for all batches
self._num_pushes += 1 self._num_pushes += 1
def __getitem__(self, key: torch.Tensor) -> torch.Tensor: def __getitem__(self, key: torch.Tensor) -> torch.Tensor:
......
...@@ -46,9 +46,31 @@ class TestCircularBuffer(unittest.TestCase): ...@@ -46,9 +46,31 @@ class TestCircularBuffer(unittest.TestCase):
# reset the buffer # reset the buffer
self.buffer.reset() self.buffer.reset()
# check if the buffer is empty # check if the buffer has zeros entries
self.assertEqual(self.buffer.current_length.tolist(), [0, 0, 0]) self.assertEqual(self.buffer.current_length.tolist(), [0, 0, 0])
def test_reset_subset(self):
"""Test resetting a subset of batches in the circular buffer."""
data1 = torch.ones((self.batch_size, 2), device=self.device)
data2 = 2.0 * data1.clone()
data3 = 3.0 * data1.clone()
self.buffer.append(data1)
self.buffer.append(data2)
# reset the buffer
reset_batch_id = 1
self.buffer.reset(batch_ids=[reset_batch_id])
# check that correct batch is reset
self.assertEqual(self.buffer.current_length.tolist()[reset_batch_id], 0)
# Append new set of data
self.buffer.append(data3)
# check if the correct number of entries are in each batch
expected_length = [3, 3, 3]
expected_length[reset_batch_id] = 1
self.assertEqual(self.buffer.current_length.tolist(), expected_length)
# check that all entries of the recently reset and appended batch are equal
for i in range(self.max_len):
torch.testing.assert_close(self.buffer.buffer[reset_batch_id, 0], self.buffer.buffer[reset_batch_id, i])
def test_append_and_retrieve(self): def test_append_and_retrieve(self):
"""Test appending and retrieving data from the circular buffer.""" """Test appending and retrieving data from the circular buffer."""
# append some data # append some data
...@@ -121,6 +143,33 @@ class TestCircularBuffer(unittest.TestCase): ...@@ -121,6 +143,33 @@ class TestCircularBuffer(unittest.TestCase):
retrieved_data = self.buffer[torch.tensor([5, 5, 5], device=self.device)] retrieved_data = self.buffer[torch.tensor([5, 5, 5], device=self.device)]
self.assertTrue(torch.equal(retrieved_data, data1)) self.assertTrue(torch.equal(retrieved_data, data1))
def test_return_buffer_prop(self):
"""Test retrieving the whole buffer for correct size and contents.
Returning the whole buffer should have the shape [batch_size,max_len,data.shape[1:]]
"""
num_overflow = 2
for i in range(self.buffer.max_length + num_overflow):
data = torch.tensor([[i]], device=self.device).repeat(3, 2)
self.buffer.append(data)
retrieved_buffer = self.buffer.buffer
# check shape
self.assertTrue(retrieved_buffer.shape == torch.Size([self.buffer.batch_size, self.buffer.max_length, 2]))
# check that batch is first dimension
torch.testing.assert_close(retrieved_buffer[0], retrieved_buffer[1])
# check oldest
torch.testing.assert_close(
retrieved_buffer[:, 0], torch.tensor([[num_overflow]], device=self.device).repeat(3, 2)
)
# check most recent
torch.testing.assert_close(
retrieved_buffer[:, -1],
torch.tensor([[self.buffer.max_length + num_overflow - 1]], device=self.device).repeat(3, 2),
)
# check that it is returned oldest first
for idx in range(self.buffer.max_length - 1):
self.assertTrue(torch.all(torch.le(retrieved_buffer[:, idx], retrieved_buffer[:, idx + 1])))
if __name__ == "__main__": if __name__ == "__main__":
run_tests() run_tests()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment