Unverified commit e151bb93, authored by Johnson Sun, committed by GitHub

Adds video recording to the play scripts in RL workflows (#763)

# Description

This MR enables users to easily record videos from trained checkpoints,
which is useful for sanity checks or creating promotional videos for
research papers.

Fixes: https://github.com/isaac-sim/IsaacLab/issues/130

## Type of change

- New feature (non-breaking change which adds functionality)
- This change requires a documentation update

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./isaaclab.sh --format`
- [x] I have made corresponding changes to the documentation
- [ ] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there
parent 4cb968ac
...@@ -64,7 +64,7 @@ Recording during training ...@@ -64,7 +64,7 @@ Recording during training
Isaac Lab supports recording video clips during training using the `gymnasium.wrappers.RecordVideo <https://gymnasium.farama.org/main/_modules/gymnasium/wrappers/record_video/>`_ class. Isaac Lab supports recording video clips during training using the `gymnasium.wrappers.RecordVideo <https://gymnasium.farama.org/main/_modules/gymnasium/wrappers/record_video/>`_ class.
This feature can be enabled by using the following command line arguments with the training script: This feature can be enabled by installing ``ffmpeg`` and using the following command line arguments with the training script:
* ``--video`` - enables video recording during training * ``--video`` - enables video recording during training
* ``--video_length`` - length of each recorded video (in steps) * ``--video_length`` - length of each recorded video (in steps)
...@@ -77,6 +77,6 @@ Example usage: ...@@ -77,6 +77,6 @@ Example usage:
.. code-block:: shell .. code-block:: shell
python source/standalone/workflows/rl_games/train.py --task=Isaac-Cartpole-v0 --headless --enable_cameras --video --video_length 100 --video_interval 500 python source/standalone/workflows/rl_games/train.py --task=Isaac-Cartpole-v0 --headless --video --video_length 100 --video_interval 500
Recorded videos will be saved in the same directory as the training checkpoints, under ``IsaacLab/logs/<rl_workflow>/<task>/<run>/videos``. Recorded videos will be saved in the same directory as the training checkpoints, under ``IsaacLab/logs/<rl_workflow>/<task>/<run>/videos/train``.
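A minimal sketch of the directory layout described above, assuming the `logs/<rl_workflow>/<task>/<run>/videos/<mode>` pattern with `train` for the training scripts and `play` for the play scripts. The helper `video_dir` and the example run name are hypothetical and are not part of Isaac Lab.

```python
import os


def video_dir(rl_workflow: str, task: str, run: str, mode: str = "train") -> str:
    """Build logs/<rl_workflow>/<task>/<run>/videos/<mode> (hypothetical helper)."""
    # This change separates training clips from play clips by sub-folder.
    if mode not in ("train", "play"):
        raise ValueError(f"mode must be 'train' or 'play', got {mode!r}")
    return os.path.join("logs", rl_workflow, task, run, "videos", mode)


# Example with a made-up run name.
print(video_dir("rl_games", "cartpole", "example-run"))
print(video_dir("rl_games", "cartpole", "example-run", mode="play"))
```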
...@@ -113,7 +113,7 @@ for 200 steps, and saves it in the ``videos`` folder at a step interval of 1500 ...@@ -113,7 +113,7 @@ for 200 steps, and saves it in the ``videos`` folder at a step interval of 1500
env = gym.make(task_name, cfg=env_cfg, render_mode="rgb_array") env = gym.make(task_name, cfg=env_cfg, render_mode="rgb_array")
# wrap for video recording # wrap for video recording
video_kwargs = { video_kwargs = {
"video_folder": "videos", "video_folder": "videos/train",
"step_trigger": lambda step: step % 1500 == 0, "step_trigger": lambda step: step % 1500 == 0,
"video_length": 200, "video_length": 200,
} }
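A rough model of how the `step_trigger` and `video_length` settings above interact. It assumes the wrapper starts a clip whenever the trigger fires and ignores further triggers while a clip is being recorded. `recorded_windows` is a hypothetical helper, not part of Gymnasium, and the exact frame offsets in the real wrapper may differ.

```python
from typing import Callable, List, Tuple


def recorded_windows(
    total_steps: int, step_trigger: Callable[[int], bool], video_length: int
) -> List[Tuple[int, int]]:
    """Return half-open (start, end) step ranges that would be recorded."""
    windows = []
    step = 0
    while step < total_steps:
        if step_trigger(step):
            # A clip runs for video_length steps or until the run ends.
            end = min(step + video_length, total_steps)
            windows.append((step, end))
            step = end
        else:
            step += 1
    return windows


# Training-style trigger from the snippet above: every 1500 steps, 200 steps long.
print(recorded_windows(4000, lambda step: step % 1500 == 0, 200))
# -> [(0, 200), (1500, 1700), (3000, 3200)]

# Play-style trigger used by the play scripts in this change: a single clip at step 0.
print(recorded_windows(1000, lambda step: step == 0, 200))
# -> [(0, 200)]
```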
......
...@@ -186,6 +186,8 @@ from the environments into the respective libraries function argument and return ...@@ -186,6 +186,8 @@ from the environments into the respective libraries function argument and return
./isaaclab.sh -p source/standalone/workflows/sb3/train.py --task Isaac-Cartpole-v0 --headless --cpu ./isaaclab.sh -p source/standalone/workflows/sb3/train.py --task Isaac-Cartpole-v0 --headless --cpu
# run script for playing with 32 environments # run script for playing with 32 environments
./isaaclab.sh -p source/standalone/workflows/sb3/play.py --task Isaac-Cartpole-v0 --num_envs 32 --checkpoint /PATH/TO/model.zip ./isaaclab.sh -p source/standalone/workflows/sb3/play.py --task Isaac-Cartpole-v0 --num_envs 32 --checkpoint /PATH/TO/model.zip
# run script for recording video of a trained agent (requires installing `ffmpeg`)
./isaaclab.sh -p source/standalone/workflows/sb3/play.py --task Isaac-Cartpole-v0 --headless --video --video_length 200
- Training an agent with - Training an agent with
`SKRL <https://skrl.readthedocs.io>`__ on ``Isaac-Reach-Franka-v0``: `SKRL <https://skrl.readthedocs.io>`__ on ``Isaac-Reach-Franka-v0``:
...@@ -202,6 +204,8 @@ from the environments into the respective libraries function argument and return ...@@ -202,6 +204,8 @@ from the environments into the respective libraries function argument and return
./isaaclab.sh -p source/standalone/workflows/skrl/train.py --task Isaac-Reach-Franka-v0 --headless ./isaaclab.sh -p source/standalone/workflows/skrl/train.py --task Isaac-Reach-Franka-v0 --headless
# run script for playing with 32 environments # run script for playing with 32 environments
./isaaclab.sh -p source/standalone/workflows/skrl/play.py --task Isaac-Reach-Franka-v0 --num_envs 32 --checkpoint /PATH/TO/model.pt ./isaaclab.sh -p source/standalone/workflows/skrl/play.py --task Isaac-Reach-Franka-v0 --num_envs 32 --checkpoint /PATH/TO/model.pt
# run script for recording video of a trained agent (requires installing `ffmpeg`)
./isaaclab.sh -p source/standalone/workflows/skrl/play.py --task Isaac-Reach-Franka-v0 --headless --video --video_length 200
.. tab-item:: JAX .. tab-item:: JAX
...@@ -215,6 +219,8 @@ from the environments into the respective libraries function argument and return ...@@ -215,6 +219,8 @@ from the environments into the respective libraries function argument and return
./isaaclab.sh -p source/standalone/workflows/skrl/train.py --task Isaac-Reach-Franka-v0 --headless --ml_framework jax ./isaaclab.sh -p source/standalone/workflows/skrl/train.py --task Isaac-Reach-Franka-v0 --headless --ml_framework jax
# run script for playing with 32 environments # run script for playing with 32 environments
./isaaclab.sh -p source/standalone/workflows/skrl/play.py --task Isaac-Reach-Franka-v0 --num_envs 32 --ml_framework jax --checkpoint /PATH/TO/model.pt ./isaaclab.sh -p source/standalone/workflows/skrl/play.py --task Isaac-Reach-Franka-v0 --num_envs 32 --ml_framework jax --checkpoint /PATH/TO/model.pt
# run script for recording video of a trained agent (requires installing `ffmpeg`)
./isaaclab.sh -p source/standalone/workflows/skrl/play.py --task Isaac-Reach-Franka-v0 --headless --ml_framework jax --video --video_length 200
- Training an agent with - Training an agent with
`RL-Games <https://github.com/Denys88/rl_games>`__ on ``Isaac-Ant-v0``: `RL-Games <https://github.com/Denys88/rl_games>`__ on ``Isaac-Ant-v0``:
...@@ -227,6 +233,8 @@ from the environments into the respective libraries function argument and return ...@@ -227,6 +233,8 @@ from the environments into the respective libraries function argument and return
./isaaclab.sh -p source/standalone/workflows/rl_games/train.py --task Isaac-Ant-v0 --headless ./isaaclab.sh -p source/standalone/workflows/rl_games/train.py --task Isaac-Ant-v0 --headless
# run script for playing with 32 environments # run script for playing with 32 environments
./isaaclab.sh -p source/standalone/workflows/rl_games/play.py --task Isaac-Ant-v0 --num_envs 32 --checkpoint /PATH/TO/model.pth ./isaaclab.sh -p source/standalone/workflows/rl_games/play.py --task Isaac-Ant-v0 --num_envs 32 --checkpoint /PATH/TO/model.pth
# run script for recording video of a trained agent (requires installing `ffmpeg`)
./isaaclab.sh -p source/standalone/workflows/rl_games/play.py --task Isaac-Ant-v0 --headless --video --video_length 200
- Training an agent with - Training an agent with
`RSL-RL <https://github.com/leggedrobotics/rsl_rl>`__ on ``Isaac-Reach-Franka-v0``: `RSL-RL <https://github.com/leggedrobotics/rsl_rl>`__ on ``Isaac-Reach-Franka-v0``:
...@@ -239,6 +247,8 @@ from the environments into the respective libraries function argument and return ...@@ -239,6 +247,8 @@ from the environments into the respective libraries function argument and return
./isaaclab.sh -p source/standalone/workflows/rsl_rl/train.py --task Isaac-Reach-Franka-v0 --headless ./isaaclab.sh -p source/standalone/workflows/rsl_rl/train.py --task Isaac-Reach-Franka-v0 --headless
# run script for playing with 32 environments # run script for playing with 32 environments
./isaaclab.sh -p source/standalone/workflows/rsl_rl/play.py --task Isaac-Reach-Franka-v0 --num_envs 32 --load_run run_folder_name --checkpoint model.pt ./isaaclab.sh -p source/standalone/workflows/rsl_rl/play.py --task Isaac-Reach-Franka-v0 --num_envs 32 --load_run run_folder_name --checkpoint model.pt
# run script for recording video of a trained agent (requires installing `ffmpeg`)
./isaaclab.sh -p source/standalone/workflows/rsl_rl/play.py --task Isaac-Reach-Franka-v0 --headless --video --video_length 200
All the scripts above log the training progress to `Tensorboard`_ in the ``logs`` directory in the root of All the scripts above log the training progress to `Tensorboard`_ in the ``logs`` directory in the root of
the repository. The logs directory follows the pattern ``logs/<library>/<task>/<date-time>``, where ``<library>`` the repository. The logs directory follows the pattern ``logs/<library>/<task>/<date-time>``, where ``<library>``
......
...@@ -99,9 +99,9 @@ agent's behavior during training. ...@@ -99,9 +99,9 @@ agent's behavior during training.
.. code-block:: bash .. code-block:: bash
./isaaclab.sh -p source/standalone/workflows/sb3/train.py --task Isaac-Cartpole-v0 --num_envs 64 --headless --enable_cameras --video ./isaaclab.sh -p source/standalone/workflows/sb3/train.py --task Isaac-Cartpole-v0 --num_envs 64 --headless --video
The videos are saved to the ``logs/sb3/Isaac-Cartpole-v0/<run-dir>/videos`` directory. You can open these videos The videos are saved to the ``logs/sb3/Isaac-Cartpole-v0/<run-dir>/videos/train`` directory. You can open these videos
using any video player. using any video player.
Interactive execution Interactive execution
......
...@@ -370,7 +370,7 @@ Added ...@@ -370,7 +370,7 @@ Added
* Added a new flag ``viewport`` to the :class:`IsaacEnv` class to enable/disable rendering of the viewport. * Added a new flag ``viewport`` to the :class:`IsaacEnv` class to enable/disable rendering of the viewport.
If the flag is set to ``True``, the viewport is enabled and the environment is rendered in the background. If the flag is set to ``True``, the viewport is enabled and the environment is rendered in the background.
* Updated the training scripts in the ``source/standalone/workflows`` directory to use the new flag ``viewport``. * Updated the training scripts in the ``source/standalone/workflows`` directory to use the new flag ``viewport``.
If the CLI argument ``--video`` is passed, videos are recorded in the ``videos`` directory using the If the CLI argument ``--video`` is passed, videos are recorded in the ``videos/train`` directory using the
:class:`gym.wrappers.RecordVideo` wrapper. :class:`gym.wrappers.RecordVideo` wrapper.
Changed Changed
......
...@@ -41,7 +41,7 @@ class TestRecordVideoWrapper(unittest.TestCase): ...@@ -41,7 +41,7 @@ class TestRecordVideoWrapper(unittest.TestCase):
# print all existing task names # print all existing task names
print(">>> All registered environments:", cls.registered_tasks) print(">>> All registered environments:", cls.registered_tasks)
# directory to save videos # directory to save videos
cls.videos_dir = os.path.join(os.path.dirname(__file__), "output", "videos") cls.videos_dir = os.path.join(os.path.dirname(__file__), "output", "videos", "train")
def setUp(self) -> None: def setUp(self) -> None:
# common parameters # common parameters
......
...@@ -13,6 +13,8 @@ from omni.isaac.lab.app import AppLauncher ...@@ -13,6 +13,8 @@ from omni.isaac.lab.app import AppLauncher
# add argparse arguments # add argparse arguments
parser = argparse.ArgumentParser(description="Play a checkpoint of an RL agent from RL-Games.") parser = argparse.ArgumentParser(description="Play a checkpoint of an RL agent from RL-Games.")
parser.add_argument("--video", action="store_true", default=False, help="Record a video of the trained agent during play.")
parser.add_argument("--video_length", type=int, default=200, help="Length of the recorded video (in steps).")
parser.add_argument("--cpu", action="store_true", default=False, help="Use CPU pipeline.") parser.add_argument("--cpu", action="store_true", default=False, help="Use CPU pipeline.")
parser.add_argument( parser.add_argument(
"--disable_fabric", action="store_true", default=False, help="Disable fabric and use USD I/O operations." "--disable_fabric", action="store_true", default=False, help="Disable fabric and use USD I/O operations."
...@@ -29,6 +31,9 @@ parser.add_argument( ...@@ -29,6 +31,9 @@ parser.add_argument(
AppLauncher.add_app_launcher_args(parser) AppLauncher.add_app_launcher_args(parser)
# parse the arguments # parse the arguments
args_cli = parser.parse_args() args_cli = parser.parse_args()
# always enable cameras to record video
if args_cli.video:
args_cli.enable_cameras = True
# launch omniverse app # launch omniverse app
app_launcher = AppLauncher(args_cli) app_launcher = AppLauncher(args_cli)
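The argument handling in this hunk can be sketched in isolation as follows. This is a standalone illustration that assumes `--enable_cameras` is a boolean flag normally registered by `AppLauncher.add_app_launcher_args`; here it is declared by hand as a stand-in so the snippet runs without Isaac Lab. `parse_play_args` is a hypothetical name.

```python
import argparse
from typing import List


def parse_play_args(argv: List[str]) -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Play a checkpoint of an RL agent.")
    parser.add_argument("--video", action="store_true", default=False, help="Record a video during play.")
    parser.add_argument("--video_length", type=int, default=200, help="Length of the recorded video (in steps).")
    # Stand-in for the flag that AppLauncher.add_app_launcher_args(parser) is assumed to add.
    parser.add_argument("--enable_cameras", action="store_true", default=False)
    args = parser.parse_args(argv)
    # Recording needs rendering, so --video forces cameras on before the app launches.
    if args.video:
        args.enable_cameras = True
    return args


args = parse_play_args(["--video", "--video_length", "100"])
print(args.enable_cameras, args.video_length)  # -> True 100
```

Setting the flag on the parsed namespace is what lets the documented commands drop the explicit `--enable_cameras` argument.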
...@@ -47,6 +52,7 @@ from rl_games.common.player import BasePlayer ...@@ -47,6 +52,7 @@ from rl_games.common.player import BasePlayer
from rl_games.torch_runner import Runner from rl_games.torch_runner import Runner
from omni.isaac.lab.utils.assets import retrieve_file_path from omni.isaac.lab.utils.assets import retrieve_file_path
from omni.isaac.lab.utils.dict import print_dict
import omni.isaac.lab_tasks # noqa: F401 import omni.isaac.lab_tasks # noqa: F401
from omni.isaac.lab_tasks.utils import get_checkpoint_path, load_cfg_from_registry, parse_env_cfg from omni.isaac.lab_tasks.utils import get_checkpoint_path, load_cfg_from_registry, parse_env_cfg
...@@ -61,23 +67,6 @@ def main(): ...@@ -61,23 +67,6 @@ def main():
) )
agent_cfg = load_cfg_from_registry(args_cli.task, "rl_games_cfg_entry_point") agent_cfg = load_cfg_from_registry(args_cli.task, "rl_games_cfg_entry_point")
# wrap around environment for rl-games
rl_device = agent_cfg["params"]["config"]["device"]
clip_obs = agent_cfg["params"]["env"].get("clip_observations", math.inf)
clip_actions = agent_cfg["params"]["env"].get("clip_actions", math.inf)
# create isaac environment
env = gym.make(args_cli.task, cfg=env_cfg)
# wrap around environment for rl-games
env = RlGamesVecEnvWrapper(env, rl_device, clip_obs, clip_actions)
# register the environment to rl-games registry
# note: in agents configuration: environment name must be "rlgpu"
vecenv.register(
"IsaacRlgWrapper", lambda config_name, num_actors, **kwargs: RlGamesGpuEnv(config_name, num_actors, **kwargs)
)
env_configurations.register("rlgpu", {"vecenv_type": "IsaacRlgWrapper", "env_creator": lambda **kwargs: env})
# specify directory for logging experiments # specify directory for logging experiments
log_root_path = os.path.join("logs", "rl_games", agent_cfg["params"]["config"]["name"]) log_root_path = os.path.join("logs", "rl_games", agent_cfg["params"]["config"]["name"])
log_root_path = os.path.abspath(log_root_path) log_root_path = os.path.abspath(log_root_path)
...@@ -96,6 +85,36 @@ def main(): ...@@ -96,6 +85,36 @@ def main():
resume_path = get_checkpoint_path(log_root_path, run_dir, checkpoint_file, other_dirs=["nn"]) resume_path = get_checkpoint_path(log_root_path, run_dir, checkpoint_file, other_dirs=["nn"])
else: else:
resume_path = retrieve_file_path(args_cli.checkpoint) resume_path = retrieve_file_path(args_cli.checkpoint)
log_dir = os.path.dirname(os.path.dirname(resume_path))
# wrap around environment for rl-games
rl_device = agent_cfg["params"]["config"]["device"]
clip_obs = agent_cfg["params"]["env"].get("clip_observations", math.inf)
clip_actions = agent_cfg["params"]["env"].get("clip_actions", math.inf)
# create isaac environment
env = gym.make(args_cli.task, cfg=env_cfg, render_mode="rgb_array" if args_cli.video else None)
# wrap for video recording
if args_cli.video:
video_kwargs = {
"video_folder": os.path.join(log_root_path, log_dir, "videos", "play"),
"step_trigger": lambda step: step == 0,
"video_length": args_cli.video_length,
"disable_logger": True,
}
print("[INFO] Recording videos during play.")
print_dict(video_kwargs, nesting=4)
env = gym.wrappers.RecordVideo(env, **video_kwargs)
# wrap around environment for rl-games
env = RlGamesVecEnvWrapper(env, rl_device, clip_obs, clip_actions)
# register the environment to rl-games registry
# note: in agents configuration: environment name must be "rlgpu"
vecenv.register(
"IsaacRlgWrapper", lambda config_name, num_actors, **kwargs: RlGamesGpuEnv(config_name, num_actors, **kwargs)
)
env_configurations.register("rlgpu", {"vecenv_type": "IsaacRlgWrapper", "env_creator": lambda **kwargs: env})
# load previously trained model # load previously trained model
agent_cfg["params"]["load_checkpoint"] = True agent_cfg["params"]["load_checkpoint"] = True
agent_cfg["params"]["load_path"] = resume_path agent_cfg["params"]["load_path"] = resume_path
...@@ -115,6 +134,7 @@ def main(): ...@@ -115,6 +134,7 @@ def main():
obs = env.reset() obs = env.reset()
if isinstance(obs, dict): if isinstance(obs, dict):
obs = obs["obs"] obs = obs["obs"]
timestep = 0
# required: enables the flag for batched observations # required: enables the flag for batched observations
_ = agent.get_batch_size(obs, 1) _ = agent.get_batch_size(obs, 1)
# initialize RNN states if used # initialize RNN states if used
...@@ -140,6 +160,11 @@ def main(): ...@@ -140,6 +160,11 @@ def main():
if agent.is_rnn and agent.states is not None: if agent.is_rnn and agent.states is not None:
for s in agent.states: for s in agent.states:
s[:, dones, :] = 0.0 s[:, dones, :] = 0.0
if args_cli.video:
timestep += 1
# Exit the play loop after recording one video
if timestep == args_cli.video_length:
break
# close the simulator # close the simulator
env.close() env.close()
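The early exit added to the play loop can be illustrated with a dummy environment. This is a sketch only: `CountingEnv` and `play` are hypothetical names, and the bounded `for` loop stands in for `while simulation_app.is_running()`.

```python
class CountingEnv:
    """Dummy environment that only counts how many times it was stepped."""

    def __init__(self) -> None:
        self.steps = 0
        self.closed = False

    def step(self, action) -> None:
        self.steps += 1

    def close(self) -> None:
        self.closed = True


def play(env: CountingEnv, video: bool, video_length: int, max_steps: int) -> int:
    timestep = 0
    for _ in range(max_steps):  # stand-in for the simulation-app loop
        env.step(None)
        if video:
            timestep += 1
            # Exit the play loop after recording one video.
            if timestep == video_length:
                break
    env.close()
    return env.steps


print(play(CountingEnv(), video=True, video_length=200, max_steps=1000))   # -> 200
print(play(CountingEnv(), video=False, video_length=200, max_steps=1000))  # -> 1000
```

Without `--video` the counter is never advanced, so the loop keeps running until the application stops, which matches the behaviour before this change.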
......
...@@ -113,7 +113,7 @@ def main(): ...@@ -113,7 +113,7 @@ def main():
# wrap for video recording # wrap for video recording
if args_cli.video: if args_cli.video:
video_kwargs = { video_kwargs = {
"video_folder": os.path.join(log_root_path, log_dir, "videos"), "video_folder": os.path.join(log_root_path, log_dir, "videos", "train"),
"step_trigger": lambda step: step % args_cli.video_interval == 0, "step_trigger": lambda step: step % args_cli.video_interval == 0,
"video_length": args_cli.video_length, "video_length": args_cli.video_length,
"disable_logger": True, "disable_logger": True,
......
...@@ -16,6 +16,8 @@ import cli_args # isort: skip ...@@ -16,6 +16,8 @@ import cli_args # isort: skip
# add argparse arguments # add argparse arguments
parser = argparse.ArgumentParser(description="Train an RL agent with RSL-RL.") parser = argparse.ArgumentParser(description="Train an RL agent with RSL-RL.")
parser.add_argument("--video", action="store_true", default=False, help="Record a video of the trained agent during play.")
parser.add_argument("--video_length", type=int, default=200, help="Length of the recorded video (in steps).")
parser.add_argument("--cpu", action="store_true", default=False, help="Use CPU pipeline.") parser.add_argument("--cpu", action="store_true", default=False, help="Use CPU pipeline.")
parser.add_argument( parser.add_argument(
"--disable_fabric", action="store_true", default=False, help="Disable fabric and use USD I/O operations." "--disable_fabric", action="store_true", default=False, help="Disable fabric and use USD I/O operations."
...@@ -28,6 +30,9 @@ cli_args.add_rsl_rl_args(parser) ...@@ -28,6 +30,9 @@ cli_args.add_rsl_rl_args(parser)
# append AppLauncher cli args # append AppLauncher cli args
AppLauncher.add_app_launcher_args(parser) AppLauncher.add_app_launcher_args(parser)
args_cli = parser.parse_args() args_cli = parser.parse_args()
# always enable cameras to record video
if args_cli.video:
args_cli.enable_cameras = True
# launch omniverse app # launch omniverse app
app_launcher = AppLauncher(args_cli) app_launcher = AppLauncher(args_cli)
...@@ -41,6 +46,8 @@ import torch ...@@ -41,6 +46,8 @@ import torch
from rsl_rl.runners import OnPolicyRunner from rsl_rl.runners import OnPolicyRunner
from omni.isaac.lab.utils.dict import print_dict
import omni.isaac.lab_tasks # noqa: F401 import omni.isaac.lab_tasks # noqa: F401
from omni.isaac.lab_tasks.utils import get_checkpoint_path, parse_env_cfg from omni.isaac.lab_tasks.utils import get_checkpoint_path, parse_env_cfg
from omni.isaac.lab_tasks.utils.wrappers.rsl_rl import ( from omni.isaac.lab_tasks.utils.wrappers.rsl_rl import (
...@@ -59,22 +66,33 @@ def main(): ...@@ -59,22 +66,33 @@ def main():
) )
agent_cfg: RslRlOnPolicyRunnerCfg = cli_args.parse_rsl_rl_cfg(args_cli.task, args_cli) agent_cfg: RslRlOnPolicyRunnerCfg = cli_args.parse_rsl_rl_cfg(args_cli.task, args_cli)
# create isaac environment
env = gym.make(args_cli.task, cfg=env_cfg)
# wrap around environment for rsl-rl
env = RslRlVecEnvWrapper(env)
# specify directory for logging experiments # specify directory for logging experiments
log_root_path = os.path.join("logs", "rsl_rl", agent_cfg.experiment_name) log_root_path = os.path.join("logs", "rsl_rl", agent_cfg.experiment_name)
log_root_path = os.path.abspath(log_root_path) log_root_path = os.path.abspath(log_root_path)
print(f"[INFO] Loading experiment from directory: {log_root_path}") print(f"[INFO] Loading experiment from directory: {log_root_path}")
resume_path = get_checkpoint_path(log_root_path, agent_cfg.load_run, agent_cfg.load_checkpoint) resume_path = get_checkpoint_path(log_root_path, agent_cfg.load_run, agent_cfg.load_checkpoint)
print(f"[INFO]: Loading model checkpoint from: {resume_path}") log_dir = os.path.dirname(resume_path)
# create isaac environment
env = gym.make(args_cli.task, cfg=env_cfg, render_mode="rgb_array" if args_cli.video else None)
# wrap for video recording
if args_cli.video:
video_kwargs = {
"video_folder": os.path.join(log_dir, "videos", "play"),
"step_trigger": lambda step: step == 0,
"video_length": args_cli.video_length,
"disable_logger": True,
}
print("[INFO] Recording videos during play.")
print_dict(video_kwargs, nesting=4)
env = gym.wrappers.RecordVideo(env, **video_kwargs)
# wrap around environment for rsl-rl
env = RslRlVecEnvWrapper(env)
print(f"[INFO]: Loading model checkpoint from: {resume_path}")
# load previously trained model # load previously trained model
ppo_runner = OnPolicyRunner(env, agent_cfg.to_dict(), log_dir=None, device=agent_cfg.device) ppo_runner = OnPolicyRunner(env, agent_cfg.to_dict(), log_dir=None, device=agent_cfg.device)
ppo_runner.load(resume_path) ppo_runner.load(resume_path)
print(f"[INFO]: Loading model checkpoint from: {resume_path}")
# obtain the trained policy for inference # obtain the trained policy for inference
policy = ppo_runner.get_inference_policy(device=env.unwrapped.device) policy = ppo_runner.get_inference_policy(device=env.unwrapped.device)
...@@ -90,6 +108,7 @@ def main(): ...@@ -90,6 +108,7 @@ def main():
# reset environment # reset environment
obs, _ = env.get_observations() obs, _ = env.get_observations()
timestep = 0
# simulate environment # simulate environment
while simulation_app.is_running(): while simulation_app.is_running():
# run everything in inference mode # run everything in inference mode
...@@ -98,6 +117,11 @@ def main(): ...@@ -98,6 +117,11 @@ def main():
actions = policy(obs) actions = policy(obs)
# env stepping # env stepping
obs, _, _, _ = env.step(actions) obs, _, _, _ = env.step(actions)
if args_cli.video:
timestep += 1
# Exit the play loop after recording one video
if timestep == args_cli.video_length:
break
# close the simulator # close the simulator
env.close() env.close()
......
...@@ -91,7 +91,7 @@ def main(): ...@@ -91,7 +91,7 @@ def main():
# wrap for video recording # wrap for video recording
if args_cli.video: if args_cli.video:
video_kwargs = { video_kwargs = {
"video_folder": os.path.join(log_dir, "videos"), "video_folder": os.path.join(log_dir, "videos", "train"),
"step_trigger": lambda step: step % args_cli.video_interval == 0, "step_trigger": lambda step: step % args_cli.video_interval == 0,
"video_length": args_cli.video_length, "video_length": args_cli.video_length,
"disable_logger": True, "disable_logger": True,
......
...@@ -13,6 +13,8 @@ from omni.isaac.lab.app import AppLauncher ...@@ -13,6 +13,8 @@ from omni.isaac.lab.app import AppLauncher
# add argparse arguments # add argparse arguments
parser = argparse.ArgumentParser(description="Play a checkpoint of an RL agent from Stable-Baselines3.") parser = argparse.ArgumentParser(description="Play a checkpoint of an RL agent from Stable-Baselines3.")
parser.add_argument("--video", action="store_true", default=False, help="Record a video of the trained agent during play.")
parser.add_argument("--video_length", type=int, default=200, help="Length of the recorded video (in steps).")
parser.add_argument("--cpu", action="store_true", default=False, help="Use CPU pipeline.") parser.add_argument("--cpu", action="store_true", default=False, help="Use CPU pipeline.")
parser.add_argument( parser.add_argument(
"--disable_fabric", action="store_true", default=False, help="Disable fabric and use USD I/O operations." "--disable_fabric", action="store_true", default=False, help="Disable fabric and use USD I/O operations."
...@@ -29,6 +31,9 @@ parser.add_argument( ...@@ -29,6 +31,9 @@ parser.add_argument(
AppLauncher.add_app_launcher_args(parser) AppLauncher.add_app_launcher_args(parser)
# parse the arguments # parse the arguments
args_cli = parser.parse_args() args_cli = parser.parse_args()
# always enable cameras to record video
if args_cli.video:
args_cli.enable_cameras = True
# launch omniverse app # launch omniverse app
app_launcher = AppLauncher(args_cli) app_launcher = AppLauncher(args_cli)
...@@ -44,6 +49,8 @@ import torch ...@@ -44,6 +49,8 @@ import torch
from stable_baselines3 import PPO from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import VecNormalize from stable_baselines3.common.vec_env import VecNormalize
from omni.isaac.lab.utils.dict import print_dict
import omni.isaac.lab_tasks # noqa: F401 import omni.isaac.lab_tasks # noqa: F401
from omni.isaac.lab_tasks.utils.parse_cfg import get_checkpoint_path, load_cfg_from_registry, parse_env_cfg from omni.isaac.lab_tasks.utils.parse_cfg import get_checkpoint_path, load_cfg_from_registry, parse_env_cfg
from omni.isaac.lab_tasks.utils.wrappers.sb3 import Sb3VecEnvWrapper, process_sb3_cfg from omni.isaac.lab_tasks.utils.wrappers.sb3 import Sb3VecEnvWrapper, process_sb3_cfg
...@@ -56,11 +63,37 @@ def main(): ...@@ -56,11 +63,37 @@ def main():
args_cli.task, use_gpu=not args_cli.cpu, num_envs=args_cli.num_envs, use_fabric=not args_cli.disable_fabric args_cli.task, use_gpu=not args_cli.cpu, num_envs=args_cli.num_envs, use_fabric=not args_cli.disable_fabric
) )
agent_cfg = load_cfg_from_registry(args_cli.task, "sb3_cfg_entry_point") agent_cfg = load_cfg_from_registry(args_cli.task, "sb3_cfg_entry_point")
# directory for logging into
log_root_path = os.path.join("logs", "sb3", args_cli.task)
log_root_path = os.path.abspath(log_root_path)
# check checkpoint is valid
if args_cli.checkpoint is None:
if args_cli.use_last_checkpoint:
checkpoint = "model_.*.zip"
else:
checkpoint = "model.zip"
checkpoint_path = get_checkpoint_path(log_root_path, ".*", checkpoint)
else:
checkpoint_path = args_cli.checkpoint
log_dir = os.path.dirname(checkpoint_path)
# post-process agent configuration # post-process agent configuration
agent_cfg = process_sb3_cfg(agent_cfg) agent_cfg = process_sb3_cfg(agent_cfg)
# create isaac environment # create isaac environment
env = gym.make(args_cli.task, cfg=env_cfg) env = gym.make(args_cli.task, cfg=env_cfg, render_mode="rgb_array" if args_cli.video else None)
# wrap for video recording
if args_cli.video:
video_kwargs = {
"video_folder": os.path.join(log_dir, "videos", "play"),
"step_trigger": lambda step: step == 0,
"video_length": args_cli.video_length,
"disable_logger": True,
}
print("[INFO] Recording videos during play.")
print_dict(video_kwargs, nesting=4)
env = gym.wrappers.RecordVideo(env, **video_kwargs)
# wrap around environment for stable baselines # wrap around environment for stable baselines
env = Sb3VecEnvWrapper(env) env = Sb3VecEnvWrapper(env)
...@@ -76,24 +109,13 @@ def main(): ...@@ -76,24 +109,13 @@ def main():
clip_reward=np.inf, clip_reward=np.inf,
) )
# directory for logging into
log_root_path = os.path.join("logs", "sb3", args_cli.task)
log_root_path = os.path.abspath(log_root_path)
# check checkpoint is valid
if args_cli.checkpoint is None:
if args_cli.use_last_checkpoint:
checkpoint = "model_.*.zip"
else:
checkpoint = "model.zip"
checkpoint_path = get_checkpoint_path(log_root_path, ".*", checkpoint)
else:
checkpoint_path = args_cli.checkpoint
# create agent from stable baselines # create agent from stable baselines
print(f"Loading checkpoint from: {checkpoint_path}") print(f"Loading checkpoint from: {checkpoint_path}")
agent = PPO.load(checkpoint_path, env, print_system_info=True) agent = PPO.load(checkpoint_path, env, print_system_info=True)
# reset environment # reset environment
obs = env.reset() obs = env.reset()
timestep = 0
# simulate environment # simulate environment
while simulation_app.is_running(): while simulation_app.is_running():
# run everything in inference mode # run everything in inference mode
...@@ -102,6 +124,11 @@ def main(): ...@@ -102,6 +124,11 @@ def main():
actions, _ = agent.predict(obs, deterministic=True) actions, _ = agent.predict(obs, deterministic=True)
# env stepping # env stepping
obs, _, _, _ = env.step(actions) obs, _, _, _ = env.step(actions)
if args_cli.video:
timestep += 1
# Exit the play loop after recording one video
if timestep == args_cli.video_length:
break
# close the simulator # close the simulator
env.close() env.close()
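The play scripts derive the run directory from the checkpoint path in slightly different ways: the SB3 and RSL-RL scripts take one `dirname`, while the RL-Games and skrl scripts take two because their checkpoints sit in an `nn` or `checkpoints` sub-folder. A sketch of that logic, with made-up POSIX paths and a hypothetical helper `run_dir_from_checkpoint`:

```python
import posixpath


def run_dir_from_checkpoint(checkpoint_path: str, levels_up: int) -> str:
    """Strip `levels_up` trailing path components from a checkpoint path."""
    path = checkpoint_path
    for _ in range(levels_up):
        path = posixpath.dirname(path)
    return path


# SB3-style layout (assumed): the checkpoint lives directly in the run directory.
sb3_run = run_dir_from_checkpoint("/logs/sb3/Isaac-Cartpole-v0/run/model.zip", 1)
print(sb3_run)  # -> /logs/sb3/Isaac-Cartpole-v0/run

# RL-Games-style layout (assumed): the checkpoint lives in <run>/nn/.
rlg_run = run_dir_from_checkpoint("/logs/rl_games/ant/run/nn/model.pth", 2)
print(rlg_run)  # -> /logs/rl_games/ant/run

# When a later component is absolute, join discards the earlier ones, so
# joining a root path with an absolute run directory yields the run directory.
print(posixpath.join("/logs/rl_games/ant", rlg_run, "videos", "play"))
# -> /logs/rl_games/ant/run/videos/play
```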
......
...@@ -96,7 +96,7 @@ def main(): ...@@ -96,7 +96,7 @@ def main():
# wrap for video recording # wrap for video recording
if args_cli.video: if args_cli.video:
video_kwargs = { video_kwargs = {
"video_folder": os.path.join(log_dir, "videos"), "video_folder": os.path.join(log_dir, "videos", "train"),
"step_trigger": lambda step: step % args_cli.video_interval == 0, "step_trigger": lambda step: step % args_cli.video_interval == 0,
"video_length": args_cli.video_length, "video_length": args_cli.video_length,
"disable_logger": True, "disable_logger": True,
......
...@@ -19,6 +19,8 @@ from omni.isaac.lab.app import AppLauncher ...@@ -19,6 +19,8 @@ from omni.isaac.lab.app import AppLauncher
# add argparse arguments # add argparse arguments
parser = argparse.ArgumentParser(description="Play a checkpoint of an RL agent from skrl.") parser = argparse.ArgumentParser(description="Play a checkpoint of an RL agent from skrl.")
parser.add_argument("--video", action="store_true", default=False, help="Record a video of the trained agent during play.")
parser.add_argument("--video_length", type=int, default=200, help="Length of the recorded video (in steps).")
parser.add_argument("--cpu", action="store_true", default=False, help="Use CPU pipeline.") parser.add_argument("--cpu", action="store_true", default=False, help="Use CPU pipeline.")
parser.add_argument( parser.add_argument(
"--disable_fabric", action="store_true", default=False, help="Disable fabric and use USD I/O operations." "--disable_fabric", action="store_true", default=False, help="Disable fabric and use USD I/O operations."
...@@ -38,6 +40,9 @@ parser.add_argument( ...@@ -38,6 +40,9 @@ parser.add_argument(
AppLauncher.add_app_launcher_args(parser) AppLauncher.add_app_launcher_args(parser)
# parse the arguments # parse the arguments
args_cli = parser.parse_args() args_cli = parser.parse_args()
# always enable cameras to record video
if args_cli.video:
args_cli.enable_cameras = True
# launch omniverse app # launch omniverse app
app_launcher = AppLauncher(args_cli) app_launcher = AppLauncher(args_cli)
...@@ -58,6 +63,8 @@ elif args_cli.ml_framework.startswith("jax"): ...@@ -58,6 +63,8 @@ elif args_cli.ml_framework.startswith("jax"):
from skrl.agents.jax.ppo import PPO, PPO_DEFAULT_CONFIG from skrl.agents.jax.ppo import PPO, PPO_DEFAULT_CONFIG
from skrl.utils.model_instantiators.jax import deterministic_model, gaussian_model from skrl.utils.model_instantiators.jax import deterministic_model, gaussian_model
from omni.isaac.lab.utils.dict import print_dict
import omni.isaac.lab_tasks # noqa: F401 import omni.isaac.lab_tasks # noqa: F401
from omni.isaac.lab_tasks.utils import get_checkpoint_path, load_cfg_from_registry, parse_env_cfg from omni.isaac.lab_tasks.utils import get_checkpoint_path, load_cfg_from_registry, parse_env_cfg
from omni.isaac.lab_tasks.utils.wrappers.skrl import SkrlVecEnvWrapper, process_skrl_cfg from omni.isaac.lab_tasks.utils.wrappers.skrl import SkrlVecEnvWrapper, process_skrl_cfg
...@@ -74,8 +81,30 @@ def main(): ...@@ -74,8 +81,30 @@ def main():
) )
experiment_cfg = load_cfg_from_registry(args_cli.task, "skrl_cfg_entry_point") experiment_cfg = load_cfg_from_registry(args_cli.task, "skrl_cfg_entry_point")
# specify directory for logging experiments (load checkpoint)
log_root_path = os.path.join("logs", "skrl", experiment_cfg["agent"]["experiment"]["directory"])
log_root_path = os.path.abspath(log_root_path)
print(f"[INFO] Loading experiment from directory: {log_root_path}")
# get checkpoint path
if args_cli.checkpoint:
resume_path = os.path.abspath(args_cli.checkpoint)
else:
resume_path = get_checkpoint_path(log_root_path, other_dirs=["checkpoints"])
log_dir = os.path.dirname(os.path.dirname(resume_path))
# create isaac environment # create isaac environment
env = gym.make(args_cli.task, cfg=env_cfg) env = gym.make(args_cli.task, cfg=env_cfg, render_mode="rgb_array" if args_cli.video else None)
# wrap for video recording
if args_cli.video:
video_kwargs = {
"video_folder": os.path.join(log_dir, "videos", "play"),
"step_trigger": lambda step: step == 0,
"video_length": args_cli.video_length,
"disable_logger": True,
}
print("[INFO] Recording videos during training.")
print_dict(video_kwargs, nesting=4)
env = gym.wrappers.RecordVideo(env, **video_kwargs)
# wrap around environment for skrl # wrap around environment for skrl
env = SkrlVecEnvWrapper(env, ml_framework=args_cli.ml_framework) # same as: `wrap_env(env, wrapper="isaaclab")` env = SkrlVecEnvWrapper(env, ml_framework=args_cli.ml_framework) # same as: `wrap_env(env, wrapper="isaaclab")`
...@@ -137,25 +166,16 @@ def main(): ...@@ -137,25 +166,16 @@ def main():
device=env.device, device=env.device,
) )
# specify directory for logging experiments (load checkpoint)
log_root_path = os.path.join("logs", "skrl", experiment_cfg["agent"]["experiment"]["directory"])
log_root_path = os.path.abspath(log_root_path)
print(f"[INFO] Loading experiment from directory: {log_root_path}")
# get checkpoint path
if args_cli.checkpoint:
resume_path = os.path.abspath(args_cli.checkpoint)
else:
resume_path = get_checkpoint_path(log_root_path, other_dirs=["checkpoints"])
print(f"[INFO] Loading model checkpoint from: {resume_path}")
# initialize agent # initialize agent
agent.init() agent.init()
print(f"[INFO] Loading model checkpoint from: {resume_path}")
agent.load(resume_path) agent.load(resume_path)
# set agent to evaluation mode # set agent to evaluation mode
agent.set_running_mode("eval") agent.set_running_mode("eval")
# reset environment # reset environment
obs, _ = env.reset() obs, _ = env.reset()
timestep = 0
# simulate environment # simulate environment
while simulation_app.is_running(): while simulation_app.is_running():
# run everything in inference mode # run everything in inference mode
...@@ -164,6 +184,11 @@ def main(): ...@@ -164,6 +184,11 @@ def main():
actions = agent.act(obs, timestep=0, timesteps=0)[0] actions = agent.act(obs, timestep=0, timesteps=0)[0]
# env stepping # env stepping
obs, _, _, _, _ = env.step(actions) obs, _, _, _, _ = env.step(actions)
if args_cli.video:
timestep += 1
# Exit the play loop after recording one video
if timestep == args_cli.video_length:
break
# close the simulator # close the simulator
env.close() env.close()
......
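The play script above derives `log_dir` by calling `os.path.dirname` twice on the checkpoint path, so the recorded clip ends up in a `videos/play` folder next to the checkpoint's parent directory. The following is a minimal standalone sketch of that path arithmetic, assuming the checkpoint sits in a `checkpoints` subfolder of the run directory (as `other_dirs=["checkpoints"]` suggests). The helper names and the example path are hypothetical and not part of the commit, and `posixpath` is used only to keep the output platform-independent.

```python
import posixpath


def log_dir_from_checkpoint(resume_path):
    """Drop the checkpoint file name and its parent folder from the path."""
    return posixpath.dirname(posixpath.dirname(resume_path))


def play_video_folder(resume_path):
    """Build the folder that the play script would pass as `video_folder`."""
    return posixpath.join(log_dir_from_checkpoint(resume_path), "videos", "play")


# Hypothetical checkpoint path, for illustration only.
example = "/logs/skrl/cartpole/run_01/checkpoints/agent.pt"
print(play_video_folder(example))
```

If a checkpoint passed via `--checkpoint` lives in a different layout, the same two-level strip still applies, so the video folder may land somewhere other than the run directory.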
...@@ -45,7 +45,7 @@ parser.add_argument( ...@@ -45,7 +45,7 @@ parser.add_argument(
AppLauncher.add_app_launcher_args(parser) AppLauncher.add_app_launcher_args(parser)
# parse the arguments # parse the arguments
args_cli = parser.parse_args() args_cli = parser.parse_args()
# always enable cameras to record video
if args_cli.video: if args_cli.video:
args_cli.enable_cameras = True args_cli.enable_cameras = True
...@@ -132,7 +132,7 @@ def main(): ...@@ -132,7 +132,7 @@ def main():
# wrap for video recording # wrap for video recording
if args_cli.video: if args_cli.video:
video_kwargs = { video_kwargs = {
"video_folder": os.path.join(log_dir, "videos"), "video_folder": os.path.join(log_dir, "videos", "train"),
"step_trigger": lambda step: step % args_cli.video_interval == 0, "step_trigger": lambda step: step % args_cli.video_interval == 0,
"video_length": args_cli.video_length, "video_length": args_cli.video_length,
"disable_logger": True, "disable_logger": True,
......
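The training and play scripts in this diff differ mainly in their `step_trigger`: training starts a clip every `video_interval` steps, while play starts a single clip at step 0 and leaves the loop once `video_length` steps have run. Below is a minimal sketch of that behavior in plain Python, with no Isaac Lab or gymnasium dependency. The helper names (`train_trigger`, `play_trigger`, `recording_starts`, `play_loop_steps`) and the `max_steps` guard are hypothetical and only illustrate the logic.

```python
def train_trigger(video_interval):
    """Training-style trigger: start a clip every `video_interval` steps."""
    return lambda step: step % video_interval == 0


def play_trigger():
    """Play-style trigger: start exactly one clip, at the first step."""
    return lambda step: step == 0


def recording_starts(trigger, total_steps):
    """List the steps at which a recording would begin."""
    return [step for step in range(total_steps) if trigger(step)]


def play_loop_steps(video_length, max_steps=10_000):
    """Mimic the play loop: count steps and stop after one clip's worth."""
    timestep = 0
    for _ in range(max_steps):
        # An environment step would happen here in the real script.
        timestep += 1
        if timestep == video_length:
            break
    return timestep


print(recording_starts(train_trigger(2000), 5000))
print(recording_starts(play_trigger(), 5000))
print(play_loop_steps(200))
```

Writing the two kinds of clips to separate `videos/train` and `videos/play` folders, as the diff does, keeps a play recording from being mixed in with clips produced during training.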