Unverified Commit 7379dcee authored by Mayank Mittal's avatar Mayank Mittal Committed by GitHub

Fixes mode-based checks inside the `EventManager.apply` call (#777)

# Description

Noticed that we were checking for function arguments inside a for-loop,
which isn't necessary. Moved this check outside to make it simpler to
read the code.

Also noticed a small corner case in the event manager when reset is
called and `env_ids` is None. In that case, it would bypass the check
for min steps between reset and directly apply the term to the
environment. I am not sure if that was intentional. if so, I can revert
the behavior.

## Type of change

- Bug fix (non-breaking change which fixes an issue)

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./isaaclab.sh --format`
- [x] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [x] I have added tests that prove my fix is effective or that my
feature works
- [x] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there
parent 765666d5
[package]
# Note: Semantic Versioning is used: https://semver.org/
version = "0.21.1"
version = "0.21.2"
# Description
title = "Isaac Lab framework for Robot Learning"
......
Changelog
---------
0.21.2 (2024-08-13)
~~~~~~~~~~~~~~~~~~~
Fixed
^^^^^
* Moved event mode-based checks in the :meth:`omni.isaac.lab.managers.EventManager.apply` method outside
the loop that iterates over the event terms. This prevents unnecessary checks and improves readability.
* Fixed the logic for global and per environment interval times when using the "interval" mode inside the
event manager. Earlier, the internal lists for these times were of unequal lengths which led to wrong indexing
inside the loop that iterates over the event terms.
0.21.1 (2024-08-06)
~~~~~~~~~~~~~~~~~~~
......
......@@ -519,7 +519,7 @@ class DirectRLEnv(gym.Env):
if self.cfg.events:
if "reset" in self.event_manager.available_modes:
env_step_count = self._sim_step_counter // self.cfg.decimation
self.event_manager.apply(env_ids=env_ids, mode="reset", global_env_step_count=env_step_count)
self.event_manager.apply(mode="reset", env_ids=env_ids, global_env_step_count=env_step_count)
# reset noise models
if self.cfg.action_noise_model:
......
......@@ -337,10 +337,10 @@ class ManagerBasedEnv:
"""
# reset the internal buffers of the scene elements
self.scene.reset(env_ids)
# apply events such as randomizations for environments that need a reset
# apply events such as randomization for environments that need a reset
if "reset" in self.event_manager.available_modes:
env_step_count = self._sim_step_counter // self.cfg.decimation
self.event_manager.apply(env_ids=env_ids, mode="reset", global_env_step_count=env_step_count)
self.event_manager.apply(mode="reset", env_ids=env_ids, global_env_step_count=env_step_count)
# iterate over all managers and reset them
# this returns a dictionary of information which is stored in the extras
......
......@@ -319,7 +319,7 @@ class ManagerBasedRLEnv(ManagerBasedEnv, gym.Env):
# apply events such as randomizations for environments that need a reset
if "reset" in self.event_manager.available_modes:
env_step_count = self._sim_step_counter // self.cfg.decimation
self.event_manager.apply(env_ids=env_ids, mode="reset", global_env_step_count=env_step_count)
self.event_manager.apply(mode="reset", env_ids=env_ids, global_env_step_count=env_step_count)
# iterate over all managers and reset them
# this returns a dictionary of information which is stored in the extras
......
......@@ -128,73 +128,114 @@ class EventManager(ManagerBase):
):
"""Calls each event term in the specified mode.
Note:
For interval mode, the time step of the environment is used to determine if the event
should be applied.
This function iterates over all the event terms in the specified mode and calls the function
corresponding to the term. The function is called with the environment instance and the environment
indices to apply the event to.
For the "interval" mode, the function is called when the time interval has passed. This requires
specifying the time step of the environment.
For the "reset" mode, the function is called when the mode is "reset" and the total number of environment
steps that have happened since the last trigger of the function is equal to its configured parameter for
the number of environment steps between resets.
Args:
mode: The mode of event.
env_ids: The indices of the environments to apply the event to.
Defaults to None, in which case the event is applied to all environments.
Defaults to None, in which case the event is applied to all environments when applicable.
dt: The time step of the environment. This is only used for the "interval" mode.
Defaults to None to simplify the call for other modes.
global_env_step_count: The environment step count of the task. This is only used for the "reset" mode.
Defaults to None to simplify the call for other modes.
global_env_step_count: The total number of environment steps that have happened. This is only used
for the "reset" mode. Defaults to None to simplify the call for other modes.
Raises:
ValueError: If the mode is ``"interval"`` and the time step is not provided.
ValueError: If the mode is ``"interval"`` and the environment indices are provided. This is an undefined
behavior as the environment indices are computed based on the time left for each environment.
ValueError: If the mode is ``"reset"`` and the total number of environment steps that have happened
is not provided.
"""
# check if mode is valid
if mode not in self._mode_term_names:
carb.log_warn(f"Event mode '{mode}' is not defined. Skipping event.")
return
# check if mode is interval and dt is not provided
if mode == "interval" and dt is None:
raise ValueError(f"Event mode '{mode}' requires the time-step of the environment.")
if mode == "interval" and env_ids is not None:
raise ValueError(
f"Event mode '{mode}' does not require environment indices. This is an undefined behavior"
" as the environment indices are computed based on the time left for each environment."
)
# check if mode is reset and env step count is not provided
if mode == "reset" and global_env_step_count is None:
raise ValueError(f"Event mode '{mode}' requires the total number of environment steps to be provided.")
# iterate over all the event terms
for index, term_cfg in enumerate(self._mode_term_cfgs[mode]):
# resample interval if needed
if mode == "interval":
if dt is None:
raise ValueError(
f"Event mode '{mode}' requires the time step of the environment"
" to be passed to the event manager."
)
# extract time left for this term
time_left = self._interval_term_time_left[index]
# update the time left for each environment
time_left -= dt
# check if the interval has passed and sample a new interval
# note: we compare with a small value to handle floating point errors
if term_cfg.is_global_time:
# extract time left for this term
time_left = self._interval_mode_time_global[index]
# update the time left for each environment
time_left -= dt
# check if the interval has passed
if time_left <= 0.0:
if time_left < 1e-6:
lower, upper = term_cfg.interval_range_s
self._interval_mode_time_global[index] = torch.rand(1) * (upper - lower) + lower
sampled_interval = torch.rand(1) * (upper - lower) + lower
self._interval_term_time_left[index][:] = sampled_interval
else:
# no need to call func to sample
# no need to call func to apply term
continue
else:
# extract time left for this term
time_left = self._interval_mode_time_left[index]
# update the time left for each environment
time_left -= dt
# check if the interval has passed
env_ids = (time_left <= 0.0).nonzero().flatten()
env_ids = (time_left < 1e-6).nonzero().flatten()
if len(env_ids) > 0:
lower, upper = term_cfg.interval_range_s
time_left[env_ids] = torch.rand(len(env_ids), device=self.device) * (upper - lower) + lower
sampled_time = torch.rand(len(env_ids), device=self.device) * (upper - lower) + lower
self._interval_term_time_left[index][env_ids] = sampled_time
else:
# no need to call func to apply term
continue
# check for minimum frequency for reset
elif mode == "reset":
if global_env_step_count is None:
raise ValueError(
f"Event mode '{mode}' requires the step count of the environment"
" to be passed to the event manager."
)
# obtain the minimum step count between resets
min_step_count = term_cfg.min_step_count_between_reset
# resolve the environment indices
if env_ids is None:
env_ids = slice(None)
# We bypass the trigger mechanism if min_step_count is zero, i.e. apply term on every reset call.
# This should avoid the overhead of checking the trigger condition.
if min_step_count == 0:
self._reset_term_last_triggered_step_id[index][env_ids] = global_env_step_count
self._reset_term_last_triggered_once[index][env_ids] = True
else:
# extract last reset step for this term
last_triggered_step = self._reset_term_last_triggered_step_id[index][env_ids]
triggered_at_least_once = self._reset_term_last_triggered_once[index][env_ids]
# compute the steps since last reset
steps_since_triggered = global_env_step_count - last_triggered_step
# check if the term can be applied after the minimum step count between triggers has passed
valid_trigger = steps_since_triggered >= min_step_count
# check if the term has not been triggered yet (in that case, we trigger it at least once)
# this is usually only needed at the start of the environment
valid_trigger |= (last_triggered_step == 0) & ~triggered_at_least_once
# select the valid environment indices based on the trigger
if env_ids == slice(None):
env_ids = valid_trigger.nonzero().flatten()
else:
env_ids = env_ids[valid_trigger]
if env_ids is not None and len(env_ids) > 0:
last_reset_step = self._reset_mode_last_reset_step_count[index]
steps_since_last_reset = global_env_step_count - last_reset_step
env_ids = env_ids[steps_since_last_reset[env_ids] >= term_cfg.min_step_count_between_reset]
# reset the last reset step for each environment to the current env step count
if len(env_ids) > 0:
last_reset_step[env_ids] = global_env_step_count
self._reset_term_last_triggered_once[index][env_ids] = True
self._reset_term_last_triggered_step_id[index][env_ids] = global_env_step_count
else:
# no need to call func to sample
# no need to call func to apply term
continue
# call the event term
term_cfg.func(self._env, env_ids, **term_cfg.params)
......@@ -255,12 +296,12 @@ class EventManager(ManagerBase):
self._mode_term_names: dict[str, list[str]] = dict()
self._mode_term_cfgs: dict[str, list[EventTermCfg]] = dict()
self._mode_class_term_cfgs: dict[str, list[EventTermCfg]] = dict()
# buffer to store the time left for each environment for "interval" mode
self._interval_mode_time_left: list[torch.Tensor] = list()
# global timer for "interval" mode for global properties
self._interval_mode_time_global: list[torch.Tensor] = list()
# buffer to store the step count when reset was last performed for each environment for "reset" mode
self._reset_mode_last_reset_step_count: list[torch.Tensor] = list()
# buffer to store the time left for "interval" mode
# if interval is global, then it is a single value, otherwise it is per environment
self._interval_term_time_left: list[torch.Tensor] = list()
# buffer to store the step count when the term was last triggered for each environment for "reset" mode
self._reset_term_last_triggered_step_id: list[torch.Tensor] = list()
self._reset_term_last_triggered_once: list[torch.Tensor] = list()
# check if config is dict already
if isinstance(self.cfg, dict):
......@@ -302,6 +343,7 @@ class EventManager(ManagerBase):
self._mode_class_term_cfgs[term_cfg.mode].append(term_cfg)
# resolve the mode of the events
# -- interval mode
if term_cfg.mode == "interval":
if term_cfg.interval_range_s is None:
raise ValueError(
......@@ -312,13 +354,23 @@ class EventManager(ManagerBase):
if term_cfg.is_global_time:
lower, upper = term_cfg.interval_range_s
time_left = torch.rand(1) * (upper - lower) + lower
self._interval_mode_time_global.append(time_left)
self._interval_term_time_left.append(time_left)
else:
# sample the time left for each environment
lower, upper = term_cfg.interval_range_s
time_left = torch.rand(self.num_envs, device=self.device) * (upper - lower) + lower
self._interval_mode_time_left.append(time_left)
self._interval_term_time_left.append(time_left)
# -- reset mode
elif term_cfg.mode == "reset":
if term_cfg.min_step_count_between_reset < 0:
raise ValueError(
f"Event term '{term_name}' has mode 'reset' but 'min_step_count_between_reset' is"
f" negative: {term_cfg.min_step_count_between_reset}. Please provide a non-negative value."
)
# initialize the current step count for each environment to zero
step_count = torch.zeros(self.num_envs, device=self.device, dtype=torch.int32)
self._reset_mode_last_reset_step_count.append(step_count)
self._reset_term_last_triggered_step_id.append(step_count)
# initialize the trigger flag for each environment to zero
no_trigger = torch.zeros(self.num_envs, device=self.device, dtype=torch.bool)
self._reset_term_last_triggered_once.append(no_trigger)
......@@ -191,7 +191,7 @@ class EventTermCfg(ManagerTermBaseCfg):
"""
interval_range_s: tuple[float, float] | None = None
"""The range of time in seconds at which the term is applied.
"""The range of time in seconds at which the term is applied. Defaults to None.
Based on this, the interval is sampled uniformly between the specified
range for each environment instance. The term is applied on the environment
......@@ -202,21 +202,24 @@ class EventTermCfg(ManagerTermBaseCfg):
"""
is_global_time: bool = False
""" Whether randomization should be tracked on a per-environment basis.
"""Whether randomization should be tracked on a per-environment basis. Defaults to False.
If True, the same time for the interval is tracked for all the environments instead of
tracking the time per-environment.
If True, the same interval time is used for all the environment instances.
If False, the interval time is sampled independently for each environment instance
and the term is applied when the current time hits the interval time for that instance.
Note:
This is only used if the mode is ``"interval"``.
"""
min_step_count_between_reset: int = 0
"""The minimum number of environment steps between when term is applied.
"""The number of environment steps after which the term is applied since its last application. Defaults to 0.
When mode is "reset", the term will not be applied on the next reset unless
the number of steps since the last application of the term has exceeded this.
This is useful to avoid calling this term too often and improve performance.
When the mode is "reset", the term is only applied if the number of environment steps since
its last application exceeds this quantity. This helps to avoid calling the term too often,
thereby improving performance.
If the value is zero, the term is applied on every call to the manager with the mode "reset".
Note:
This is only used if the mode is ``"reset"``.
......
# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause
# ignore private usage of variables warning
# pyright: reportPrivateUsage=none
"""Launch Isaac Sim Simulator first."""
from omni.isaac.lab.app import AppLauncher, run_tests
# launch omniverse app
simulation_app = AppLauncher(headless=True).app
"""Rest everything follows."""
import torch
import unittest
from collections import namedtuple
from omni.isaac.lab.managers import EventManager, EventTermCfg
from omni.isaac.lab.utils import configclass
DummyEnv = namedtuple("ManagerBasedRLEnv", ["num_envs", "dt", "device", "dummy1", "dummy2"])
"""Dummy environment for testing."""
def reset_dummy1_to_zero(env, env_ids: torch.Tensor):
env.dummy1[env_ids] = 0
def increment_dummy1_by_one(env, env_ids: torch.Tensor):
env.dummy1[env_ids] += 1
def change_dummy1_by_value(env, env_ids: torch.Tensor, value: int):
env.dummy1[env_ids] += value
def reset_dummy2_to_zero(env, env_ids: torch.Tensor):
env.dummy2[env_ids] = 0
def increment_dummy2_by_one(env, env_ids: torch.Tensor):
env.dummy2[env_ids] += 1
class TestEventManager(unittest.TestCase):
"""Test cases for various situations with event manager."""
def setUp(self) -> None:
# create values
num_envs = 32
device = "cpu"
# create dummy tensors
dummy1 = torch.zeros((num_envs, 2), device=device)
dummy2 = torch.zeros((num_envs, 10), device=device)
# create dummy environment
self.env = DummyEnv(num_envs, 0.01, device, dummy1, dummy2)
def test_str(self):
"""Test the string representation of the event manager."""
cfg = {
"term_1": EventTermCfg(func=increment_dummy1_by_one, mode="interval", interval_range_s=(0.1, 0.1)),
"term_2": EventTermCfg(func=reset_dummy1_to_zero, mode="reset"),
"term_3": EventTermCfg(func=change_dummy1_by_value, mode="custom", params={"value": 10}),
"term_4": EventTermCfg(func=change_dummy1_by_value, mode="custom", params={"value": 2}),
}
self.event_man = EventManager(cfg, self.env)
# print the expected string
print()
print(self.event_man)
def test_config_equivalence(self):
"""Test the equivalence of event manager created from different config types."""
# create from dictionary
cfg = {
"term_1": EventTermCfg(func=increment_dummy1_by_one, mode="interval", interval_range_s=(0.1, 0.1)),
"term_2": EventTermCfg(func=reset_dummy1_to_zero, mode="reset"),
"term_3": EventTermCfg(func=change_dummy1_by_value, mode="custom", params={"value": 10}),
}
event_man_from_dict = EventManager(cfg, self.env)
# create from config class
@configclass
class MyEventManagerCfg:
"""Event manager config with no type annotations."""
term_1 = EventTermCfg(func=increment_dummy1_by_one, mode="interval", interval_range_s=(0.1, 0.1))
term_2 = EventTermCfg(func=reset_dummy1_to_zero, mode="reset")
term_3 = EventTermCfg(func=change_dummy1_by_value, mode="custom", params={"value": 10})
cfg = MyEventManagerCfg()
event_man_from_cfg = EventManager(cfg, self.env)
# create from config class
@configclass
class MyEventManagerAnnotatedCfg:
"""Event manager config with type annotations."""
term_1: EventTermCfg = EventTermCfg(
func=increment_dummy1_by_one, mode="interval", interval_range_s=(0.1, 0.1)
)
term_2: EventTermCfg = EventTermCfg(func=reset_dummy1_to_zero, mode="reset")
term_3: EventTermCfg = EventTermCfg(func=change_dummy1_by_value, mode="custom", params={"value": 10})
cfg = MyEventManagerAnnotatedCfg()
event_man_from_annotated_cfg = EventManager(cfg, self.env)
# check equivalence
# parsed terms
self.assertDictEqual(event_man_from_dict.active_terms, event_man_from_annotated_cfg.active_terms)
self.assertDictEqual(event_man_from_cfg.active_terms, event_man_from_annotated_cfg.active_terms)
self.assertDictEqual(event_man_from_dict.active_terms, event_man_from_cfg.active_terms)
# parsed term configs
self.assertDictEqual(event_man_from_dict._mode_term_cfgs, event_man_from_annotated_cfg._mode_term_cfgs)
self.assertDictEqual(event_man_from_cfg._mode_term_cfgs, event_man_from_annotated_cfg._mode_term_cfgs)
self.assertDictEqual(event_man_from_dict._mode_term_cfgs, event_man_from_cfg._mode_term_cfgs)
def test_active_terms(self):
"""Test the correct reading of active terms."""
cfg = {
"term_1": EventTermCfg(func=increment_dummy1_by_one, mode="interval", interval_range_s=(0.1, 0.1)),
"term_2": EventTermCfg(func=reset_dummy1_to_zero, mode="reset"),
"term_3": EventTermCfg(func=change_dummy1_by_value, mode="custom", params={"value": 10}),
"term_4": EventTermCfg(func=change_dummy1_by_value, mode="custom", params={"value": 2}),
}
self.event_man = EventManager(cfg, self.env)
self.assertEqual(len(self.event_man.active_terms), 3)
self.assertEqual(len(self.event_man.active_terms["interval"]), 1)
self.assertEqual(len(self.event_man.active_terms["reset"]), 1)
self.assertEqual(len(self.event_man.active_terms["custom"]), 2)
def test_invalid_event_func_module(self):
"""Test the handling of invalid event function's module in string representation."""
cfg = {
"term_1": EventTermCfg(func=increment_dummy1_by_one, mode="interval", interval_range_s=(0.1, 0.1)),
"term_2": EventTermCfg(func="a:reset_dummy1_to_zero", mode="reset"),
}
with self.assertRaises(ValueError):
self.event_man = EventManager(cfg, self.env)
def test_invalid_event_config(self):
"""Test the handling of invalid event function's config parameters."""
cfg = {
"term_1": EventTermCfg(func=increment_dummy1_by_one, mode="interval", interval_range_s=(0.1, 0.1)),
"term_2": EventTermCfg(func=reset_dummy1_to_zero, mode="reset"),
"term_3": EventTermCfg(func=change_dummy1_by_value, mode="custom"),
}
with self.assertRaises(ValueError):
self.event_man = EventManager(cfg, self.env)
def test_apply_interval_mode_without_global_time(self):
"""Test the application of event terms that are in interval mode without global time.
During local time, each environment instance has its own time for the interval term.
"""
# make two intervals -- one is fixed and the other is random
term_1_interval_range_s = (10 * self.env.dt, 10 * self.env.dt)
term_2_interval_range_s = (2 * self.env.dt, 10 * self.env.dt)
cfg = {
"term_1": EventTermCfg(
func=increment_dummy1_by_one,
mode="interval",
interval_range_s=term_1_interval_range_s,
is_global_time=False,
),
"term_2": EventTermCfg(
func=increment_dummy2_by_one,
mode="interval",
interval_range_s=term_2_interval_range_s,
is_global_time=False,
),
}
self.event_man = EventManager(cfg, self.env)
# obtain the initial time left for the interval terms
term_2_interval_time = self.event_man._interval_term_time_left[1].clone()
expected_dummy2_value = torch.zeros_like(self.env.dummy2)
for count in range(50):
# apply the event terms
self.event_man.apply("interval", dt=self.env.dt)
# manually decrement the interval time for term2 since it is randomly sampled
term_2_interval_time -= self.env.dt
# check the values
# we increment the dummy1 by 1 every 10 steps. at the 9th count (aka 10th apply), the value should be 1
torch.testing.assert_close(self.env.dummy1, (count + 1) // 10 * torch.ones_like(self.env.dummy1))
# we increment the dummy2 by 1 every 2 to 10 steps based on the random interval
expected_dummy2_value += term_2_interval_time.unsqueeze(1) < 1e-6
torch.testing.assert_close(self.env.dummy2, expected_dummy2_value)
# check the time sampled at the end of the interval is valid
# -- fixed interval
if (count + 1) % 10 == 0:
term_1_interval_time_init = self.event_man._interval_term_time_left[0].clone()
expected_time_interval_init = torch.full_like(term_1_interval_time_init, term_1_interval_range_s[1])
torch.testing.assert_close(term_1_interval_time_init, expected_time_interval_init)
# -- random interval
env_ids = (term_2_interval_time < 1e-6).nonzero(as_tuple=True)[0]
if len(env_ids) > 0:
term_2_interval_time[env_ids] = self.event_man._interval_term_time_left[1][env_ids]
def test_apply_interval_mode_with_global_time(self):
"""Test the application of event terms that are in interval mode with global time.
During global time, all the environment instances share the same time for the interval term.
"""
# make two intervals -- one is fixed and the other is random
term_1_interval_range_s = (10 * self.env.dt, 10 * self.env.dt)
term_2_interval_range_s = (2 * self.env.dt, 10 * self.env.dt)
cfg = {
"term_1": EventTermCfg(
func=increment_dummy1_by_one,
mode="interval",
interval_range_s=term_1_interval_range_s,
is_global_time=True,
),
"term_2": EventTermCfg(
func=increment_dummy2_by_one,
mode="interval",
interval_range_s=term_2_interval_range_s,
is_global_time=True,
),
}
self.event_man = EventManager(cfg, self.env)
# obtain the initial time left for the interval terms
term_2_interval_time = self.event_man._interval_term_time_left[1].clone()
expected_dummy2_value = torch.zeros_like(self.env.dummy2)
for count in range(50):
# apply the event terms
self.event_man.apply("interval", dt=self.env.dt)
# manually decrement the interval time for term2 since it is randomly sampled
term_2_interval_time -= self.env.dt
# check the values
# we increment the dummy1 by 1 every 10 steps. at the 9th count (aka 10th apply), the value should be 1
torch.testing.assert_close(self.env.dummy1, (count + 1) // 10 * torch.ones_like(self.env.dummy1))
# we increment the dummy2 by 1 every 2 to 10 steps based on the random interval
expected_dummy2_value += term_2_interval_time < 1e-6
torch.testing.assert_close(self.env.dummy2, expected_dummy2_value)
# check the time sampled at the end of the interval is valid
# -- fixed interval
if (count + 1) % 10 == 0:
term_1_interval_time_init = self.event_man._interval_term_time_left[0].clone()
expected_time_interval_init = torch.full_like(term_1_interval_time_init, term_1_interval_range_s[1])
torch.testing.assert_close(term_1_interval_time_init, expected_time_interval_init)
# -- random interval
if term_2_interval_time < 1e-6:
term_2_interval_time = self.event_man._interval_term_time_left[1].clone()
def test_apply_reset_mode(self):
"""Test the application of event terms that are in reset mode."""
cfg = {
"term_1": EventTermCfg(func=increment_dummy1_by_one, mode="reset"),
"term_2": EventTermCfg(func=reset_dummy1_to_zero, mode="reset", min_step_count_between_reset=10),
}
self.event_man = EventManager(cfg, self.env)
# manually keep track of the expected values for dummy1 and trigger count
expected_dummy1_value = torch.zeros_like(self.env.dummy1)
term_2_trigger_step_id = torch.zeros((self.env.num_envs,), dtype=torch.int32, device=self.env.device)
for count in range(50):
# apply the event terms for all the env ids
if count % 3 == 0:
self.event_man.apply("reset", global_env_step_count=count)
# we increment the dummy1 by 1 every call to reset mode due to term 1
expected_dummy1_value[:] += 1
# manually update the expected value for term 2
if (count - term_2_trigger_step_id[0]) >= 10 or count == 0:
expected_dummy1_value = torch.zeros_like(self.env.dummy1)
term_2_trigger_step_id[:] = count
# check the values of trigger count
# -- term 1
expected_trigger_count = torch.full(
(self.env.num_envs,), 3 * (count // 3), dtype=torch.int32, device=self.env.device
)
torch.testing.assert_close(self.event_man._reset_term_last_triggered_step_id[0], expected_trigger_count)
# -- term 2
torch.testing.assert_close(self.event_man._reset_term_last_triggered_step_id[1], term_2_trigger_step_id)
# check the values of dummy1
torch.testing.assert_close(self.env.dummy1, expected_dummy1_value)
def test_apply_reset_mode_subset_env_ids(self):
"""Test the application of event terms that are in reset mode over a subset of environment ids."""
cfg = {
"term_1": EventTermCfg(func=increment_dummy1_by_one, mode="reset"),
"term_2": EventTermCfg(func=reset_dummy1_to_zero, mode="reset", min_step_count_between_reset=10),
}
self.event_man = EventManager(cfg, self.env)
# since we are applying the event terms over a subset of env ids, we need to keep track of the trigger count
# manually for the sake of testing
term_2_trigger_step_id = torch.zeros((self.env.num_envs,), dtype=torch.int32, device=self.env.device)
term_2_trigger_once = torch.zeros((self.env.num_envs,), dtype=torch.bool, device=self.env.device)
expected_dummy1_value = torch.zeros_like(self.env.dummy1)
for count in range(50):
# randomly select a subset of environment ids
env_ids = (torch.rand(self.env.num_envs, device=self.env.device) < 0.5).nonzero().flatten()
# apply the event terms for the selected env ids
self.event_man.apply("reset", env_ids=env_ids, global_env_step_count=count)
# modify the trigger count for term 2
trigger_ids = (count - term_2_trigger_step_id[env_ids]) >= 10
trigger_ids |= (term_2_trigger_step_id[env_ids] == 0) & ~term_2_trigger_once[env_ids]
term_2_trigger_step_id[env_ids[trigger_ids]] = count
term_2_trigger_once[env_ids[trigger_ids]] = True
# we increment the dummy1 by 1 every call to reset mode
# every 10th call, we reset the dummy1 to 0
expected_dummy1_value[env_ids] += 1 # effect of term 1
expected_dummy1_value[env_ids[trigger_ids]] = 0 # effect of term 2
# check the values of trigger count
# -- term 1
expected_trigger_count = torch.full((len(env_ids),), count, dtype=torch.int32, device=self.env.device)
torch.testing.assert_close(
self.event_man._reset_term_last_triggered_step_id[0][env_ids], expected_trigger_count
)
# -- term 2
torch.testing.assert_close(self.event_man._reset_term_last_triggered_step_id[1], term_2_trigger_step_id)
# check the values of dummy1
torch.testing.assert_close(self.env.dummy1, expected_dummy1_value)
if __name__ == "__main__":
run_tests()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment