Unverified Commit 73f26f6b authored by Mayank Mittal's avatar Mayank Mittal Committed by GitHub

Puts all environment-related scripts in inference mode (#215)

# Description

When having agent-environment interaction, it is important to ensure
that PyTorch constructs no computational graph. Otherwise, it will keep
allocating memory for gradients and result in out-of-memory error.

This MR wraps all scripts with `torch.inference_mode()` to prevent this
issue.

## Type of change

- Bug fix (non-breaking change which fixes an issue)

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./orbit.sh --format`
- [ ] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
parent 34f8b443
......@@ -53,13 +53,12 @@ def main():
env.reset()
# simulate environment
while simulation_app.is_running():
# run everything in inference mode
with torch.inference_mode():
# sample actions from -1 to 1
actions = 2 * torch.rand((env.num_envs, env.action_space.shape[0]), device=env.device) - 1
# apply actions
_, _, _, _ = env.step(actions)
# check if simulator is stopped
if env.unwrapped.sim.is_stopped():
break
# close the simulator
env.close()
......
......@@ -258,6 +258,8 @@ def main():
# run state machine
for _ in range(10000):
# run everything in inference mode
with torch.inference_mode():
# step environment
dones = env.step(actions)[-2]
......
......@@ -103,6 +103,8 @@ def main():
# simulate environment
while simulation_app.is_running():
# run everything in inference mode
with torch.inference_mode():
# get keyboard command
delta_pose, gripper_command = teleop_interface.advance()
# convert to torch
......@@ -111,9 +113,6 @@ def main():
actions = pre_process_actions(delta_pose, gripper_command)
# apply actions
_, _, _, _ = env.step(actions)
# check if simulator is stopped
if env.unwrapped.sim.is_stopped():
break
# close the simulator
env.close()
......
......@@ -52,13 +52,12 @@ def main():
env.reset()
# simulate environment
while simulation_app.is_running():
# run everything in inference mode
with torch.inference_mode():
# compute zero actions
actions = torch.zeros((env.num_envs, env.action_space.shape[0]), device=env.device)
# apply actions
_, _, _, _ = env.step(actions)
# check if simulator is stopped
if env.unwrapped.sim.is_stopped():
break
# close the simulator
env.close()
......
......@@ -40,6 +40,7 @@ simulation_app = app_launcher.app
import gym
import math
import os
import torch
import traceback
import carb
......@@ -124,15 +125,15 @@ def main():
# attempt to have complete control over environment stepping. However, this removes other
# operations such as masking that is used for multi-agent learning by RL-Games.
while simulation_app.is_running():
# run everything in inference mode
with torch.inference_mode():
# convert obs to agent format
obs = agent.obs_to_torch(obs)
# agent stepping
actions = agent.get_action(obs, is_deterministic)
# env stepping
obs, _, dones, _ = env.step(actions)
# check if simulator is stopped
if env.unwrapped.sim.is_stopped():
break
# perform operations for terminated episodes
if len(dones) > 0:
# reset rnn state for terminated episodes
......
......@@ -116,8 +116,8 @@ def main():
teleop_interface.reset()
collector_interface.reset()
# simulate environment
with contextlib.suppress(KeyboardInterrupt):
# simulate environment -- run everything in inference mode
with contextlib.suppress(KeyboardInterrupt) and torch.inference_mode():
while not collector_interface.is_stopped():
# get keyboard command
delta_pose, gripper_command = teleop_interface.advance()
......
......@@ -70,14 +70,13 @@ def main():
obs = obs_dict["policy"]
# simulate environment
while simulation_app.is_running():
# run everything in inference mode
with torch.inference_mode():
# compute actions
actions = policy(obs)
actions = torch.from_numpy(actions).to(device=device).view(1, env.action_space.shape[0])
# apply actions
obs_dict, _, _, _ = env.step(actions)
# check if simulator is stopped
if env.unwrapped.sim.is_stopped():
break
# robomimic only cares about policy observations
obs = obs_dict["policy"]
......
......@@ -34,6 +34,7 @@ simulation_app = app_launcher.app
import gym
import os
import torch
import traceback
import carb
......@@ -93,15 +94,12 @@ def main():
obs, _ = env.get_observations()
# simulate environment
while simulation_app.is_running():
# run everything in inference mode
with torch.inference_mode():
# agent stepping
actions = policy(obs)
# env stepping
obs, _, _, _ = env.step(actions)
# env rendering
env.render()
# check if simulator is stopped
if env.unwrapped.sim.is_stopped():
break
# close the simulator
env.close()
......
......@@ -34,6 +34,7 @@ simulation_app = app_launcher.app
import gym
import torch
import traceback
import carb
......@@ -82,13 +83,12 @@ def main():
obs = env.reset()
# simulate environment
while simulation_app.is_running():
# run everything in inference mode
with torch.inference_mode():
# agent stepping
actions, _ = agent.predict(obs, deterministic=True)
# env stepping
obs, _, _, _ = env.step(actions)
# check if simulator is stopped
if env.unwrapped.sim.is_stopped():
break
# close the simulator
env.close()
......
......@@ -39,6 +39,7 @@ simulation_app = app_launcher.app
import gym
import torch
import traceback
import carb
......@@ -137,13 +138,12 @@ def main():
obs, _ = env.reset()
# simulate environment
while simulation_app.is_running():
# run everything in inference mode
with torch.inference_mode():
# agent stepping
actions = agent.act(obs, timestep=0, timesteps=0)[0]
# env stepping
obs, _, _, _, _ = env.step(actions)
# check if simulator is stopped
if env.sim.is_stopped():
break
# close the simulator
env.close()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment