Commit ace3fe3a authored by peterd-NV's avatar peterd-NV Committed by Kelly Guo

Extracts success term and add check in Robomimic play script (#358)

# Description

<!--
Thank you for your interest in sending a pull request. Please make sure
to check the contribution guidelines.

Link:
https://isaac-sim.github.io/IsaacLab/main/source/refs/contributing.html
-->

This change extracts the success term from the env and uses it to check
for policy success in the Robomimic play script.

Fixes # (issue)

<!-- As a practice, it is recommended to open an issue to have
discussions on the proposed pull request.
This makes it easier for the community to keep track of what is being
developed or added, and if a given feature
is demanded by more than one party. -->

Previously the script used the termination condition as a check for
success. However, if other termination terms are added to an env that
reset the env without a success occuring, then that would cause false
positives. The script now explicitly uses the "success" termination
event to fix this.

## Type of change

<!-- As you go through the list, delete the ones that are not
applicable. -->

- Bug fix (non-breaking change which fixes an issue)


## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./isaaclab.sh --format`
- [ ] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there

<!--
As you go through the checklist above, you can mark something as done by
putting an x character in it

For example,
- [x] I have done this task
- [ ] I have not done this task
-->
parent 2ba09f8c
......@@ -72,7 +72,7 @@ if args_cli.enable_pinocchio:
from isaaclab_tasks.utils import parse_env_cfg
def rollout(policy, env, horizon, device):
def rollout(policy, env, success_term, horizon, device):
"""Perform a single rollout of the policy in the environment.
Args:
......@@ -128,9 +128,10 @@ def rollout(policy, env, horizon, device):
traj["actions"].append(actions.tolist())
traj["next_obs"].append(obs)
if terminated:
# Check if rollout was successful
if bool(success_term.func(env, **success_term.params)[0]):
return True, traj
elif truncated:
elif terminated or truncated:
return False, traj
return False, traj
......@@ -150,6 +151,10 @@ def main():
# Disable recorder
env_cfg.recorders = None
# Extract success checking function
success_term = env_cfg.terminations.success
env_cfg.terminations.success = None
# Create environment
env = gym.make(args_cli.task, cfg=env_cfg).unwrapped
......@@ -167,7 +172,7 @@ def main():
results = []
for trial in range(args_cli.num_rollouts):
print(f"[INFO] Starting trial {trial}")
terminated, traj = rollout(policy, env, args_cli.horizon, device)
terminated, traj = rollout(policy, env, success_term, args_cli.horizon, device)
results.append(terminated)
print(f"[INFO] Trial {trial}: {terminated}\n")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment