Unverified Commit a7dbc84e authored by Mayank Mittal's avatar Mayank Mittal Committed by GitHub

Simplify the if-elses inside the event manager apply method (#948)

# Description

The if-else-continue logic inside the event manager has become scary.
This MR simplifies it and keeps the term call local to the if
statements. I hope this should help with the readability.

## Type of change

- New feature (non-breaking change which adds functionality)
- This change requires a documentation update

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./isaaclab.sh --format`
- [ ] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there
parent 683fb5c6
......@@ -173,12 +173,12 @@ class EventManager(ManagerBase):
# iterate over all the event terms
for index, term_cfg in enumerate(self._mode_term_cfgs[mode]):
# resample interval if needed
if mode == "interval":
# extract time left for this term
time_left = self._interval_term_time_left[index]
# update the time left for each environment
time_left -= dt
# check if the interval has passed and sample a new interval
# note: we compare with a small value to handle floating point errors
if term_cfg.is_global_time:
......@@ -186,19 +186,18 @@ class EventManager(ManagerBase):
lower, upper = term_cfg.interval_range_s
sampled_interval = torch.rand(1) * (upper - lower) + lower
self._interval_term_time_left[index][:] = sampled_interval
else:
# no need to call func to apply term
continue
# call the event term (with None for env_ids)
term_cfg.func(self._env, None, **term_cfg.params)
else:
env_ids = (time_left < 1e-6).nonzero().flatten()
if len(env_ids) > 0:
valid_env_ids = (time_left < 1e-6).nonzero().flatten()
if len(valid_env_ids) > 0:
lower, upper = term_cfg.interval_range_s
sampled_time = torch.rand(len(env_ids), device=self.device) * (upper - lower) + lower
self._interval_term_time_left[index][env_ids] = sampled_time
else:
# no need to call func to apply term
continue
# check for minimum frequency for reset
sampled_time = torch.rand(len(valid_env_ids), device=self.device) * (upper - lower) + lower
self._interval_term_time_left[index][valid_env_ids] = sampled_time
# call the event term
term_cfg.func(self._env, valid_env_ids, **term_cfg.params)
elif mode == "reset":
# obtain the minimum step count between resets
min_step_count = term_cfg.min_step_count_between_reset
......@@ -211,6 +210,9 @@ class EventManager(ManagerBase):
if min_step_count == 0:
self._reset_term_last_triggered_step_id[index][env_ids] = global_env_step_count
self._reset_term_last_triggered_once[index][env_ids] = True
# call the event term with the environment indices
term_cfg.func(self._env, env_ids, **term_cfg.params)
else:
# extract last reset step for this term
last_triggered_step = self._reset_term_last_triggered_step_id[index][env_ids]
......@@ -234,12 +236,12 @@ class EventManager(ManagerBase):
if len(valid_env_ids) > 0:
self._reset_term_last_triggered_once[index][valid_env_ids] = True
self._reset_term_last_triggered_step_id[index][valid_env_ids] = global_env_step_count
# call the event term
term_cfg.func(self._env, valid_env_ids, **term_cfg.params)
# no need to call func to apply term again
continue
# call the event term
term_cfg.func(self._env, env_ids, **term_cfg.params)
else:
# call the event term
term_cfg.func(self._env, env_ids, **term_cfg.params)
"""
Operations - Term settings.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment