Commit 4ed3c225 authored by peterd-NV's avatar peterd-NV Committed by Kelly Guo

Fixes env.unwrapped errors in recorder/replayer scripts (#235)

Set `env` to be `env.unwrapped` during initial environment creation to
avoid needing to manually specify `env.unwrapped` multiple times later
in the scripts, which would often lead to one being missed causing an
error.

## Type of change

<!-- As you go through the list, delete the ones that are not
applicable. -->

- Bug fix (non-breaking change which fixes an issue)

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./isaaclab.sh --format`
- [ ] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [ ] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there

<!--
As you go through the checklist above, you can mark something as done by
putting an x character in it

For example,
- [x] I have done this task
- [ ] I have not done this task
-->
parent 2101c555
......@@ -89,7 +89,7 @@ def main():
env_cfg.recorders = None
# Create environment
env = gym.make(args_cli.task, cfg=env_cfg)
env = gym.make(args_cli.task, cfg=env_cfg).unwrapped
# Set seed
torch.manual_seed(args_cli.seed)
......
......@@ -94,7 +94,7 @@ class RateLimiter:
next_wakeup_time = self.last_time + self.sleep_duration
while time.time() < next_wakeup_time:
time.sleep(self.render_period)
env.unwrapped.sim.render()
env.sim.render()
self.last_time = self.last_time + self.sleep_duration
......@@ -162,7 +162,7 @@ def main():
env_cfg.recorders.dataset_filename = output_file_name
# create environment
env = gym.make(args_cli.task, cfg=env_cfg)
env = gym.make(args_cli.task, cfg=env_cfg).unwrapped
# add teleoperation key for reset current recording instance
should_reset_recording_instance = False
......@@ -203,9 +203,7 @@ def main():
# get keyboard command
delta_pose, gripper_command = teleop_interface.advance()
# convert to torch
delta_pose = torch.tensor(delta_pose, dtype=torch.float, device=env.unwrapped.device).repeat(
env.unwrapped.num_envs, 1
)
delta_pose = torch.tensor(delta_pose, dtype=torch.float, device=env.device).repeat(env.num_envs, 1)
# compute actions based on environment
actions = pre_process_actions(delta_pose, gripper_command)
......@@ -218,7 +216,7 @@ def main():
if success_step_count >= args_cli.num_success_steps:
env.recorder_manager.record_pre_reset([0], force_export_or_skip=False)
env.recorder_manager.set_success_to_episodes(
[0], torch.tensor([[True]], dtype=torch.bool, device=env.unwrapped.device)
[0], torch.tensor([[True]], dtype=torch.bool, device=env.device)
)
env.recorder_manager.export_episodes([0])
should_reset_recording_instance = True
......@@ -226,29 +224,26 @@ def main():
success_step_count = 0
if should_reset_recording_instance:
env.unwrapped.recorder_manager.reset()
env.recorder_manager.reset()
env.reset()
should_reset_recording_instance = False
success_step_count = 0
# print out the current demo count if it has changed
if env.unwrapped.recorder_manager.exported_successful_episode_count > current_recorded_demo_count:
current_recorded_demo_count = env.unwrapped.recorder_manager.exported_successful_episode_count
if env.recorder_manager.exported_successful_episode_count > current_recorded_demo_count:
current_recorded_demo_count = env.recorder_manager.exported_successful_episode_count
print(f"Recorded {current_recorded_demo_count} successful demonstrations.")
if (
args_cli.num_demos > 0
and env.unwrapped.recorder_manager.exported_successful_episode_count >= args_cli.num_demos
):
if args_cli.num_demos > 0 and env.recorder_manager.exported_successful_episode_count >= args_cli.num_demos:
print(f"All {args_cli.num_demos} demonstrations recorded. Exiting the app.")
break
# check that simulation is stopped or not
if env.unwrapped.sim.is_stopped():
if env.sim.is_stopped():
break
if rate_limiter:
rate_limiter.sleep(env.unwrapped)
rate_limiter.sleep(env)
env.close()
......
......@@ -133,7 +133,7 @@ def main():
env_cfg.terminations = {}
# create environment from loaded config
env = gym.make(env_name, cfg=env_cfg)
env = gym.make(env_name, cfg=env_cfg).unwrapped
teleop_interface = Se3Keyboard(pos_sensitivity=0.1, rot_sensitivity=0.1)
teleop_interface.add_callback("N", play_cb)
......@@ -161,7 +161,7 @@ def main():
has_next_action = True
while has_next_action:
# initialize actions with zeros so those without next action will not move
actions = torch.zeros(env.unwrapped.action_space.shape)
actions = torch.zeros(env.action_space.shape)
has_next_action = False
for env_id in range(num_envs):
env_next_action = env_episode_data_map[env_id].get_next_action()
......@@ -177,14 +177,12 @@ def main():
replayed_episode_count += 1
print(f"{replayed_episode_count :4}: Loading #{next_episode_index} episode to env_{env_id}")
episode_data = dataset_file_handler.load_episode(
episode_names[next_episode_index], env.unwrapped.device
episode_names[next_episode_index], env.device
)
env_episode_data_map[env_id] = episode_data
# Set initial state for the new episode
initial_state = episode_data.get_initial_state()
env.unwrapped.reset_to(
initial_state, torch.tensor([env_id], device=env.unwrapped.device), is_relative=True
)
env.reset_to(initial_state, torch.tensor([env_id], device=env.device), is_relative=True)
# Get the first action for the new episode
env_next_action = env_episode_data_map[env_id].get_next_action()
has_next_action = True
......@@ -197,7 +195,7 @@ def main():
first_loop = False
else:
while is_paused:
env.unwrapped.sim.render()
env.sim.render()
continue
env.step(actions)
......@@ -208,7 +206,7 @@ def main():
f"Validating states at action-index: {env_episode_data_map[0].next_state_index - 1 :4}",
end="",
)
current_runtime_state = env.unwrapped.scene.get_state(is_relative=True)
current_runtime_state = env.scene.get_state(is_relative=True)
states_matched, comparison_log = compare_states(state_from_dataset, current_runtime_state, 0)
if states_matched:
print("\t- matched.")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment