Commit 4ed3c225 authored by peterd-NV, committed by Kelly Guo

Fixes env.unwrapped errors in recorder/replayer scripts (#235)

Set `env` to `env.unwrapped` when the environment is first created, so
the scripts no longer need to specify `env.unwrapped` manually at each
later use. Those repeated accesses were easy to miss, and a missed one
caused an error.

## Type of change

<!-- As you go through the list, delete the ones that are not
applicable. -->

- Bug fix (non-breaking change which fixes an issue)

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./isaaclab.sh --format`
- [ ] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [ ] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there

<!--
As you go through the checklist above, you can mark something as done by
putting an x character in it

For example,
- [x] I have done this task
- [ ] I have not done this task
-->
parent 2101c555
...@@ -89,7 +89,7 @@ def main(): ...@@ -89,7 +89,7 @@ def main():
env_cfg.recorders = None env_cfg.recorders = None
# Create environment # Create environment
env = gym.make(args_cli.task, cfg=env_cfg) env = gym.make(args_cli.task, cfg=env_cfg).unwrapped
# Set seed # Set seed
torch.manual_seed(args_cli.seed) torch.manual_seed(args_cli.seed)
......
...@@ -94,7 +94,7 @@ class RateLimiter: ...@@ -94,7 +94,7 @@ class RateLimiter:
next_wakeup_time = self.last_time + self.sleep_duration next_wakeup_time = self.last_time + self.sleep_duration
while time.time() < next_wakeup_time: while time.time() < next_wakeup_time:
time.sleep(self.render_period) time.sleep(self.render_period)
env.unwrapped.sim.render() env.sim.render()
self.last_time = self.last_time + self.sleep_duration self.last_time = self.last_time + self.sleep_duration
...@@ -162,7 +162,7 @@ def main(): ...@@ -162,7 +162,7 @@ def main():
env_cfg.recorders.dataset_filename = output_file_name env_cfg.recorders.dataset_filename = output_file_name
# create environment # create environment
env = gym.make(args_cli.task, cfg=env_cfg) env = gym.make(args_cli.task, cfg=env_cfg).unwrapped
# add teleoperation key for reset current recording instance # add teleoperation key for reset current recording instance
should_reset_recording_instance = False should_reset_recording_instance = False
...@@ -203,9 +203,7 @@ def main(): ...@@ -203,9 +203,7 @@ def main():
# get keyboard command # get keyboard command
delta_pose, gripper_command = teleop_interface.advance() delta_pose, gripper_command = teleop_interface.advance()
# convert to torch # convert to torch
delta_pose = torch.tensor(delta_pose, dtype=torch.float, device=env.unwrapped.device).repeat( delta_pose = torch.tensor(delta_pose, dtype=torch.float, device=env.device).repeat(env.num_envs, 1)
env.unwrapped.num_envs, 1
)
# compute actions based on environment # compute actions based on environment
actions = pre_process_actions(delta_pose, gripper_command) actions = pre_process_actions(delta_pose, gripper_command)
...@@ -218,7 +216,7 @@ def main(): ...@@ -218,7 +216,7 @@ def main():
if success_step_count >= args_cli.num_success_steps: if success_step_count >= args_cli.num_success_steps:
env.recorder_manager.record_pre_reset([0], force_export_or_skip=False) env.recorder_manager.record_pre_reset([0], force_export_or_skip=False)
env.recorder_manager.set_success_to_episodes( env.recorder_manager.set_success_to_episodes(
[0], torch.tensor([[True]], dtype=torch.bool, device=env.unwrapped.device) [0], torch.tensor([[True]], dtype=torch.bool, device=env.device)
) )
env.recorder_manager.export_episodes([0]) env.recorder_manager.export_episodes([0])
should_reset_recording_instance = True should_reset_recording_instance = True
...@@ -226,29 +224,26 @@ def main(): ...@@ -226,29 +224,26 @@ def main():
success_step_count = 0 success_step_count = 0
if should_reset_recording_instance: if should_reset_recording_instance:
env.unwrapped.recorder_manager.reset() env.recorder_manager.reset()
env.reset() env.reset()
should_reset_recording_instance = False should_reset_recording_instance = False
success_step_count = 0 success_step_count = 0
# print out the current demo count if it has changed # print out the current demo count if it has changed
if env.unwrapped.recorder_manager.exported_successful_episode_count > current_recorded_demo_count: if env.recorder_manager.exported_successful_episode_count > current_recorded_demo_count:
current_recorded_demo_count = env.unwrapped.recorder_manager.exported_successful_episode_count current_recorded_demo_count = env.recorder_manager.exported_successful_episode_count
print(f"Recorded {current_recorded_demo_count} successful demonstrations.") print(f"Recorded {current_recorded_demo_count} successful demonstrations.")
if ( if args_cli.num_demos > 0 and env.recorder_manager.exported_successful_episode_count >= args_cli.num_demos:
args_cli.num_demos > 0
and env.unwrapped.recorder_manager.exported_successful_episode_count >= args_cli.num_demos
):
print(f"All {args_cli.num_demos} demonstrations recorded. Exiting the app.") print(f"All {args_cli.num_demos} demonstrations recorded. Exiting the app.")
break break
# check that simulation is stopped or not # check that simulation is stopped or not
if env.unwrapped.sim.is_stopped(): if env.sim.is_stopped():
break break
if rate_limiter: if rate_limiter:
rate_limiter.sleep(env.unwrapped) rate_limiter.sleep(env)
env.close() env.close()
......
...@@ -133,7 +133,7 @@ def main(): ...@@ -133,7 +133,7 @@ def main():
env_cfg.terminations = {} env_cfg.terminations = {}
# create environment from loaded config # create environment from loaded config
env = gym.make(env_name, cfg=env_cfg) env = gym.make(env_name, cfg=env_cfg).unwrapped
teleop_interface = Se3Keyboard(pos_sensitivity=0.1, rot_sensitivity=0.1) teleop_interface = Se3Keyboard(pos_sensitivity=0.1, rot_sensitivity=0.1)
teleop_interface.add_callback("N", play_cb) teleop_interface.add_callback("N", play_cb)
...@@ -161,7 +161,7 @@ def main(): ...@@ -161,7 +161,7 @@ def main():
has_next_action = True has_next_action = True
while has_next_action: while has_next_action:
# initialize actions with zeros so those without next action will not move # initialize actions with zeros so those without next action will not move
actions = torch.zeros(env.unwrapped.action_space.shape) actions = torch.zeros(env.action_space.shape)
has_next_action = False has_next_action = False
for env_id in range(num_envs): for env_id in range(num_envs):
env_next_action = env_episode_data_map[env_id].get_next_action() env_next_action = env_episode_data_map[env_id].get_next_action()
...@@ -177,14 +177,12 @@ def main(): ...@@ -177,14 +177,12 @@ def main():
replayed_episode_count += 1 replayed_episode_count += 1
print(f"{replayed_episode_count :4}: Loading #{next_episode_index} episode to env_{env_id}") print(f"{replayed_episode_count :4}: Loading #{next_episode_index} episode to env_{env_id}")
episode_data = dataset_file_handler.load_episode( episode_data = dataset_file_handler.load_episode(
episode_names[next_episode_index], env.unwrapped.device episode_names[next_episode_index], env.device
) )
env_episode_data_map[env_id] = episode_data env_episode_data_map[env_id] = episode_data
# Set initial state for the new episode # Set initial state for the new episode
initial_state = episode_data.get_initial_state() initial_state = episode_data.get_initial_state()
env.unwrapped.reset_to( env.reset_to(initial_state, torch.tensor([env_id], device=env.device), is_relative=True)
initial_state, torch.tensor([env_id], device=env.unwrapped.device), is_relative=True
)
# Get the first action for the new episode # Get the first action for the new episode
env_next_action = env_episode_data_map[env_id].get_next_action() env_next_action = env_episode_data_map[env_id].get_next_action()
has_next_action = True has_next_action = True
...@@ -197,7 +195,7 @@ def main(): ...@@ -197,7 +195,7 @@ def main():
first_loop = False first_loop = False
else: else:
while is_paused: while is_paused:
env.unwrapped.sim.render() env.sim.render()
continue continue
env.step(actions) env.step(actions)
...@@ -208,7 +206,7 @@ def main(): ...@@ -208,7 +206,7 @@ def main():
f"Validating states at action-index: {env_episode_data_map[0].next_state_index - 1 :4}", f"Validating states at action-index: {env_episode_data_map[0].next_state_index - 1 :4}",
end="", end="",
) )
current_runtime_state = env.unwrapped.scene.get_state(is_relative=True) current_runtime_state = env.scene.get_state(is_relative=True)
states_matched, comparison_log = compare_states(state_from_dataset, current_runtime_state, 0) states_matched, comparison_log = compare_states(state_from_dataset, current_runtime_state, 0)
if states_matched: if states_matched:
print("\t- matched.") print("\t- matched.")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment