Unverified Commit 73f26f6b authored by Mayank Mittal's avatar Mayank Mittal Committed by GitHub

Puts all environment-related scripts in inference mode (#215)

# Description

When having agent-environment interaction, it is important to ensure
that PyTorch constructs no computational graph. Otherwise, it will keep
allocating memory for gradients and result in out-of-memory error.

This MR wraps all scripts with `torch.inference_mode()` to prevent this
issue.

## Type of change

- Bug fix (non-breaking change which fixes an issue)

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./orbit.sh --format`
- [ ] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
parent 34f8b443
...@@ -53,13 +53,12 @@ def main(): ...@@ -53,13 +53,12 @@ def main():
env.reset() env.reset()
# simulate environment # simulate environment
while simulation_app.is_running(): while simulation_app.is_running():
# sample actions from -1 to 1 # run everything in inference mode
actions = 2 * torch.rand((env.num_envs, env.action_space.shape[0]), device=env.device) - 1 with torch.inference_mode():
# apply actions # sample actions from -1 to 1
_, _, _, _ = env.step(actions) actions = 2 * torch.rand((env.num_envs, env.action_space.shape[0]), device=env.device) - 1
# check if simulator is stopped # apply actions
if env.unwrapped.sim.is_stopped(): _, _, _, _ = env.step(actions)
break
# close the simulator # close the simulator
env.close() env.close()
......
...@@ -258,27 +258,29 @@ def main(): ...@@ -258,27 +258,29 @@ def main():
# run state machine # run state machine
for _ in range(10000): for _ in range(10000):
# step environment # run everything in inference mode
dones = env.step(actions)[-2] with torch.inference_mode():
# step environment
# observations dones = env.step(actions)[-2]
ee_pose = env.robot.data.ee_state_w[:, :7].clone()
object_pose = env.object.data.root_state_w[:, :7].clone() # observations
des_object_pose = env.object_des_pose_w.clone() ee_pose = env.robot.data.ee_state_w[:, :7].clone()
# transform from world to base frames object_pose = env.object.data.root_state_w[:, :7].clone()
ee_pose[:, :3] -= env.robot.data.root_pos_w des_object_pose = env.object_des_pose_w.clone()
object_pose[:, :3] -= env.robot.data.root_pos_w # transform from world to base frames
des_object_pose[:, :3] -= env.robot.data.root_pos_w ee_pose[:, :3] -= env.robot.data.root_pos_w
# advance state machine object_pose[:, :3] -= env.robot.data.root_pos_w
with Timer("state machine"): des_object_pose[:, :3] -= env.robot.data.root_pos_w
sm_actions = pick_sm.compute(ee_pose, object_pose, des_object_pose) # advance state machine
with Timer("state machine"):
# set actions for IK with positions sm_actions = pick_sm.compute(ee_pose, object_pose, des_object_pose)
actions[:, :3] = sm_actions[:, :3]
actions[:, -1] = sm_actions[:, -1] # set actions for IK with positions
# reset state machine actions[:, :3] = sm_actions[:, :3]
if dones.any(): actions[:, -1] = sm_actions[:, -1]
pick_sm.reset_idx(dones.nonzero(as_tuple=False).squeeze(-1)) # reset state machine
if dones.any():
pick_sm.reset_idx(dones.nonzero(as_tuple=False).squeeze(-1))
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -103,17 +103,16 @@ def main(): ...@@ -103,17 +103,16 @@ def main():
# simulate environment # simulate environment
while simulation_app.is_running(): while simulation_app.is_running():
# get keyboard command # run everything in inference mode
delta_pose, gripper_command = teleop_interface.advance() with torch.inference_mode():
# convert to torch # get keyboard command
delta_pose = torch.tensor(delta_pose, dtype=torch.float, device=env.device).repeat(env.num_envs, 1) delta_pose, gripper_command = teleop_interface.advance()
# pre-process actions # convert to torch
actions = pre_process_actions(delta_pose, gripper_command) delta_pose = torch.tensor(delta_pose, dtype=torch.float, device=env.device).repeat(env.num_envs, 1)
# apply actions # pre-process actions
_, _, _, _ = env.step(actions) actions = pre_process_actions(delta_pose, gripper_command)
# check if simulator is stopped # apply actions
if env.unwrapped.sim.is_stopped(): _, _, _, _ = env.step(actions)
break
# close the simulator # close the simulator
env.close() env.close()
......
...@@ -52,13 +52,12 @@ def main(): ...@@ -52,13 +52,12 @@ def main():
env.reset() env.reset()
# simulate environment # simulate environment
while simulation_app.is_running(): while simulation_app.is_running():
# compute zero actions # run everything in inference mode
actions = torch.zeros((env.num_envs, env.action_space.shape[0]), device=env.device) with torch.inference_mode():
# apply actions # compute zero actions
_, _, _, _ = env.step(actions) actions = torch.zeros((env.num_envs, env.action_space.shape[0]), device=env.device)
# check if simulator is stopped # apply actions
if env.unwrapped.sim.is_stopped(): _, _, _, _ = env.step(actions)
break
# close the simulator # close the simulator
env.close() env.close()
......
...@@ -40,6 +40,7 @@ simulation_app = app_launcher.app ...@@ -40,6 +40,7 @@ simulation_app = app_launcher.app
import gym import gym
import math import math
import os import os
import torch
import traceback import traceback
import carb import carb
...@@ -124,21 +125,21 @@ def main(): ...@@ -124,21 +125,21 @@ def main():
# attempt to have complete control over environment stepping. However, this removes other # attempt to have complete control over environment stepping. However, this removes other
# operations such as masking that is used for multi-agent learning by RL-Games. # operations such as masking that is used for multi-agent learning by RL-Games.
while simulation_app.is_running(): while simulation_app.is_running():
# convert obs to agent format # run everything in inference mode
obs = agent.obs_to_torch(obs) with torch.inference_mode():
# agent stepping # convert obs to agent format
actions = agent.get_action(obs, is_deterministic) obs = agent.obs_to_torch(obs)
# env stepping # agent stepping
obs, _, dones, _ = env.step(actions) actions = agent.get_action(obs, is_deterministic)
# check if simulator is stopped # env stepping
if env.unwrapped.sim.is_stopped(): obs, _, dones, _ = env.step(actions)
break
# perform operations for terminated episodes # perform operations for terminated episodes
if len(dones) > 0: if len(dones) > 0:
# reset rnn state for terminated episodes # reset rnn state for terminated episodes
if agent.is_rnn and agent.states is not None: if agent.is_rnn and agent.states is not None:
for s in agent.states: for s in agent.states:
s[:, dones, :] = 0.0 s[:, dones, :] = 0.0
# close the simulator # close the simulator
env.close() env.close()
......
...@@ -116,8 +116,8 @@ def main(): ...@@ -116,8 +116,8 @@ def main():
teleop_interface.reset() teleop_interface.reset()
collector_interface.reset() collector_interface.reset()
# simulate environment # simulate environment -- run everything in inference mode
with contextlib.suppress(KeyboardInterrupt): with contextlib.suppress(KeyboardInterrupt) and torch.inference_mode():
while not collector_interface.is_stopped(): while not collector_interface.is_stopped():
# get keyboard command # get keyboard command
delta_pose, gripper_command = teleop_interface.advance() delta_pose, gripper_command = teleop_interface.advance()
......
...@@ -70,16 +70,15 @@ def main(): ...@@ -70,16 +70,15 @@ def main():
obs = obs_dict["policy"] obs = obs_dict["policy"]
# simulate environment # simulate environment
while simulation_app.is_running(): while simulation_app.is_running():
# compute actions # run everything in inference mode
actions = policy(obs) with torch.inference_mode():
actions = torch.from_numpy(actions).to(device=device).view(1, env.action_space.shape[0]) # compute actions
# apply actions actions = policy(obs)
obs_dict, _, _, _ = env.step(actions) actions = torch.from_numpy(actions).to(device=device).view(1, env.action_space.shape[0])
# check if simulator is stopped # apply actions
if env.unwrapped.sim.is_stopped(): obs_dict, _, _, _ = env.step(actions)
break # robomimic only cares about policy observations
# robomimic only cares about policy observations obs = obs_dict["policy"]
obs = obs_dict["policy"]
# close the simulator # close the simulator
env.close() env.close()
......
...@@ -34,6 +34,7 @@ simulation_app = app_launcher.app ...@@ -34,6 +34,7 @@ simulation_app = app_launcher.app
import gym import gym
import os import os
import torch
import traceback import traceback
import carb import carb
...@@ -93,15 +94,12 @@ def main(): ...@@ -93,15 +94,12 @@ def main():
obs, _ = env.get_observations() obs, _ = env.get_observations()
# simulate environment # simulate environment
while simulation_app.is_running(): while simulation_app.is_running():
# agent stepping # run everything in inference mode
actions = policy(obs) with torch.inference_mode():
# env stepping # agent stepping
obs, _, _, _ = env.step(actions) actions = policy(obs)
# env rendering # env stepping
env.render() obs, _, _, _ = env.step(actions)
# check if simulator is stopped
if env.unwrapped.sim.is_stopped():
break
# close the simulator # close the simulator
env.close() env.close()
......
...@@ -34,6 +34,7 @@ simulation_app = app_launcher.app ...@@ -34,6 +34,7 @@ simulation_app = app_launcher.app
import gym import gym
import torch
import traceback import traceback
import carb import carb
...@@ -82,13 +83,12 @@ def main(): ...@@ -82,13 +83,12 @@ def main():
obs = env.reset() obs = env.reset()
# simulate environment # simulate environment
while simulation_app.is_running(): while simulation_app.is_running():
# agent stepping # run everything in inference mode
actions, _ = agent.predict(obs, deterministic=True) with torch.inference_mode():
# env stepping # agent stepping
obs, _, _, _ = env.step(actions) actions, _ = agent.predict(obs, deterministic=True)
# check if simulator is stopped # env stepping
if env.unwrapped.sim.is_stopped(): obs, _, _, _ = env.step(actions)
break
# close the simulator # close the simulator
env.close() env.close()
......
...@@ -39,6 +39,7 @@ simulation_app = app_launcher.app ...@@ -39,6 +39,7 @@ simulation_app = app_launcher.app
import gym import gym
import torch
import traceback import traceback
import carb import carb
...@@ -137,13 +138,12 @@ def main(): ...@@ -137,13 +138,12 @@ def main():
obs, _ = env.reset() obs, _ = env.reset()
# simulate environment # simulate environment
while simulation_app.is_running(): while simulation_app.is_running():
# agent stepping # run everything in inference mode
actions = agent.act(obs, timestep=0, timesteps=0)[0] with torch.inference_mode():
# env stepping # agent stepping
obs, _, _, _, _ = env.step(actions) actions = agent.act(obs, timestep=0, timesteps=0)[0]
# check if simulator is stopped # env stepping
if env.sim.is_stopped(): obs, _, _, _, _ = env.step(actions)
break
# close the simulator # close the simulator
env.close() env.close()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment