Unverified Commit d63e58f9 authored by Mayank Mittal's avatar Mayank Mittal Committed by GitHub

Simplifies buffer validation check for CircularBuffer (#2617)

# Description

As pointed out in the reported issue, there seem to be some expensive
`tolist()` operations inside the circular buffer class, that aren't
necessary. This MR simplifies the code.

Fixes #2590

## Type of change

- Bug fix (non-breaking change which fixes an issue)

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./isaaclab.sh --format`
- [ ] I have made corresponding changes to the documentation
- [ ] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there
Co-authored-by: 's avatarKelly Guo <kellyg@nvidia.com>
parent 9d52c4b9
...@@ -121,8 +121,10 @@ class CircularBuffer: ...@@ -121,8 +121,10 @@ class CircularBuffer:
""" """
# check the batch size # check the batch size
if data.shape[0] != self.batch_size: if data.shape[0] != self.batch_size:
raise ValueError(f"The input data has {data.shape[0]} environments while expecting {self.batch_size}") raise ValueError(f"The input data has '{data.shape[0]}' batch size while expecting '{self.batch_size}'")
# move the data to the device
data = data.to(self._device)
# at the first call, initialize the buffer size # at the first call, initialize the buffer size
if self._buffer is None: if self._buffer is None:
self._pointer = -1 self._pointer = -1
...@@ -130,12 +132,11 @@ class CircularBuffer: ...@@ -130,12 +132,11 @@ class CircularBuffer:
# move the head to the next slot # move the head to the next slot
self._pointer = (self._pointer + 1) % self.max_length self._pointer = (self._pointer + 1) % self.max_length
# add the new data to the last layer # add the new data to the last layer
self._buffer[self._pointer] = data.to(self._device) self._buffer[self._pointer] = data
# Check for batches with zero pushes and initialize all values in batch to first append # Check for batches with zero pushes and initialize all values in batch to first append
if 0 in self._num_pushes.tolist(): is_first_push = self._num_pushes == 0
fill_ids = [i for i, x in enumerate(self._num_pushes.tolist()) if x == 0] if torch.any(is_first_push):
self._num_pushes.tolist().index(0) if 0 in self._num_pushes.tolist() else None self._buffer[:, is_first_push] = data[is_first_push]
self._buffer[:, fill_ids, :] = data.to(self._device)[fill_ids]
# increment number of number of pushes for all batches # increment number of number of pushes for all batches
self._num_pushes += 1 self._num_pushes += 1
......
...@@ -51,10 +51,10 @@ class DelayBuffer: ...@@ -51,10 +51,10 @@ class DelayBuffer:
# the buffer size: current data plus the history length # the buffer size: current data plus the history length
self._circular_buffer = CircularBuffer(self._history_length + 1, batch_size, device) self._circular_buffer = CircularBuffer(self._history_length + 1, batch_size, device)
# the minimum and maximum lags across all environments. # the minimum and maximum lags across all batch indices.
self._min_time_lag = 0 self._min_time_lag = 0
self._max_time_lag = 0 self._max_time_lag = 0
# the lags for each environment. # the lags for each batch index.
self._time_lags = torch.zeros(batch_size, dtype=torch.int, device=device) self._time_lags = torch.zeros(batch_size, dtype=torch.int, device=device)
""" """
......
...@@ -48,7 +48,7 @@ Usage with a class modifier: ...@@ -48,7 +48,7 @@ Usage with a class modifier:
# create a modifier configuration # create a modifier configuration
# a digital filter with a simple delay of 1 timestep # a digital filter with a simple delay of 1 timestep
cfg = modifiers.DigitalFilter(A=[0.0], B=[0.0, 1.0]) cfg = modifiers.DigitalFilterCfg(A=[0.0], B=[0.0, 1.0])
# create the modifier instance # create the modifier instance
my_modifier = modifiers.DigitalFilter(cfg, my_tensor.shape, "cuda") my_modifier = modifiers.DigitalFilter(cfg, my_tensor.shape, "cuda")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment