Unverified commit e151bb93, authored by Johnson Sun, committed by GitHub

Adds video recording to the play scripts in RL workflows (#763)

# Description

This MR enables users to easily record videos from trained checkpoints,
which is useful for sanity checks or creating promotional videos for
research papers.
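
As a rough illustration of the recording behavior, the sketch below contrasts the two step triggers that appear in the scripts: the training scripts start a clip every `video_interval` steps, while the play scripts start a single clip at step 0 and leave the loop after `video_length` steps. The function names and the `max_steps` guard are hypothetical, and the loop body only stands in for policy inference and `env.step`. It models the control flow in plain Python and is not the scripts themselves.

```python
def train_trigger(step: int, video_interval: int) -> bool:
    """Training: start a new clip every `video_interval` steps."""
    return step % video_interval == 0


def play_trigger(step: int) -> bool:
    """Play: start a single clip at the very first step."""
    return step == 0


def simulate_play_loop(video_length: int, max_steps: int = 10_000) -> int:
    """Return how many steps run before the play loop exits.

    `max_steps` is a hypothetical guard standing in for
    `simulation_app.is_running()` in the real scripts.
    """
    timestep = 0
    while timestep < max_steps:
        # policy inference and `env.step(actions)` would happen here
        timestep += 1
        # exit the play loop after recording one video
        if timestep == video_length:
            break
    return timestep


if __name__ == "__main__":
    print([s for s in range(1001) if train_trigger(s, 500)])  # [0, 500, 1000]
    print([s for s in range(1001) if play_trigger(s)])        # [0]
    print(simulate_play_loop(video_length=200))               # 200
```

With the default `--video_length 200`, the play loop therefore stops after exactly 200 environment steps, which is also the length of the one recorded clip.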

Fixes: https://github.com/isaac-sim/IsaacLab/issues/130
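
The output locations used by this change can be sketched as follows. In the diff, the play scripts derive the run directory from the checkpoint path (two levels up for RL-Games and skrl, one level up for RSL-RL and SB3) and write to `videos/play`, while the training scripts now write to `videos/train`. The helper names, the `LEVELS_UP` table, and the example checkpoint paths are hypothetical, and `posixpath` is used only to keep the result platform-independent.

```python
import posixpath

# Directory levels between the checkpoint file and the run directory,
# as read from the `log_dir = ...` lines of each play script in the diff.
LEVELS_UP = {"rl_games": 2, "rsl_rl": 1, "sb3": 1, "skrl": 2}


def run_dir_from_checkpoint(workflow: str, checkpoint_path: str) -> str:
    """Walk up from the checkpoint file to the run directory."""
    log_dir = checkpoint_path
    for _ in range(LEVELS_UP[workflow]):
        log_dir = posixpath.dirname(log_dir)
    return log_dir


def play_video_dir(workflow: str, checkpoint_path: str) -> str:
    """Folder used when recording from a trained checkpoint."""
    return posixpath.join(run_dir_from_checkpoint(workflow, checkpoint_path), "videos", "play")


def train_video_dir(log_dir: str) -> str:
    """Folder used when recording during training."""
    return posixpath.join(log_dir, "videos", "train")


if __name__ == "__main__":
    # hypothetical checkpoint locations, for illustration only
    print(play_video_dir("rl_games", "logs/rl_games/ant/run1/nn/ant.pth"))
    print(play_video_dir("sb3", "logs/sb3/Isaac-Cartpole-v0/run1/model.zip"))
    print(train_video_dir("logs/sb3/Isaac-Cartpole-v0/run1"))
```

Keeping `train` and `play` as sibling subfolders means clips recorded from a checkpoint do not mix with clips captured while the same run was training.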

## Type of change

- New feature (non-breaking change which adds functionality)
- This change requires a documentation update
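
The feature is opt-in, which is what keeps the change non-breaking: without `--video`, the play scripts behave as before. A minimal sketch of the flag handling is shown below. The `--enable_cameras` argument is declared here only as a stand-in for the camera flag that the real scripts obtain through `AppLauncher.add_app_launcher_args(parser)`, and `build_parser` / `parse_cli` are hypothetical names.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Play a checkpoint of an RL agent.")
    parser.add_argument("--video", action="store_true", default=False, help="Record a video during play.")
    parser.add_argument("--video_length", type=int, default=200, help="Length of the recorded video (in steps).")
    # Stand-in (assumption): in the real scripts this flag comes from AppLauncher.
    parser.add_argument("--enable_cameras", action="store_true", default=False)
    return parser


def parse_cli(argv: list[str]) -> argparse.Namespace:
    args_cli = build_parser().parse_args(argv)
    # always enable cameras to record video
    if args_cli.video:
        args_cli.enable_cameras = True
    return args_cli


if __name__ == "__main__":
    print(parse_cli([]))
    print(parse_cli(["--video", "--video_length", "50"]))
```

Forcing `enable_cameras` when `--video` is set is why the documented commands in this change no longer pass `--enable_cameras` explicitly.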

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./isaaclab.sh --format`
- [x] I have made corresponding changes to the documentation
- [ ] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there
parent 4cb968ac
......@@ -64,7 +64,7 @@ Recording during training
Isaac Lab supports recording video clips during training using the `gymnasium.wrappers.RecordVideo <https://gymnasium.farama.org/main/_modules/gymnasium/wrappers/record_video/>`_ class.
- This feature can be enabled by using the following command line arguments with the training script:
+ This feature can be enabled by installing ``ffmpeg`` and using the following command line arguments with the training script:
* ``--video`` - enables video recording during training
* ``--video_length`` - length of each recorded video (in steps)
......@@ -77,6 +77,6 @@ Example usage:
.. code-block:: shell
- python source/standalone/workflows/rl_games/train.py --task=Isaac-Cartpole-v0 --headless --enable_cameras --video --video_length 100 --video_interval 500
+ python source/standalone/workflows/rl_games/train.py --task=Isaac-Cartpole-v0 --headless --video --video_length 100 --video_interval 500
- Recorded videos will be saved in the same directory as the training checkpoints, under ``IsaacLab/logs/<rl_workflow>/<task>/<run>/videos``.
+ Recorded videos will be saved in the same directory as the training checkpoints, under ``IsaacLab/logs/<rl_workflow>/<task>/<run>/videos/train``.
......@@ -113,7 +113,7 @@ for 200 steps, and saves it in the ``videos`` folder at a step interval of 1500
env = gym.make(task_name, cfg=env_cfg, render_mode="rgb_array")
# wrap for video recording
video_kwargs = {
- "video_folder": "videos",
+ "video_folder": "videos/train",
"step_trigger": lambda step: step % 1500 == 0,
"video_length": 200,
}
......
......@@ -186,6 +186,8 @@ from the environments into the respective libraries function argument and return
./isaaclab.sh -p source/standalone/workflows/sb3/train.py --task Isaac-Cartpole-v0 --headless --cpu
# run script for playing with 32 environments
./isaaclab.sh -p source/standalone/workflows/sb3/play.py --task Isaac-Cartpole-v0 --num_envs 32 --checkpoint /PATH/TO/model.zip
# run script for recording video of a trained agent (requires installing `ffmpeg`)
./isaaclab.sh -p source/standalone/workflows/sb3/play.py --task Isaac-Cartpole-v0 --headless --video --video_length 200
- Training an agent with
`SKRL <https://skrl.readthedocs.io>`__ on ``Isaac-Reach-Franka-v0``:
......@@ -202,6 +204,8 @@ from the environments into the respective libraries function argument and return
./isaaclab.sh -p source/standalone/workflows/skrl/train.py --task Isaac-Reach-Franka-v0 --headless
# run script for playing with 32 environments
./isaaclab.sh -p source/standalone/workflows/skrl/play.py --task Isaac-Reach-Franka-v0 --num_envs 32 --checkpoint /PATH/TO/model.pt
# run script for recording video of a trained agent (requires installing `ffmpeg`)
./isaaclab.sh -p source/standalone/workflows/skrl/play.py --task Isaac-Reach-Franka-v0 --headless --video --video_length 200
.. tab-item:: JAX
......@@ -215,6 +219,8 @@ from the environments into the respective libraries function argument and return
./isaaclab.sh -p source/standalone/workflows/skrl/train.py --task Isaac-Reach-Franka-v0 --headless --ml_framework jax
# run script for playing with 32 environments
./isaaclab.sh -p source/standalone/workflows/skrl/play.py --task Isaac-Reach-Franka-v0 --num_envs 32 --ml_framework jax --checkpoint /PATH/TO/model.pt
# run script for recording video of a trained agent (requires installing `ffmpeg`)
./isaaclab.sh -p source/standalone/workflows/skrl/play.py --task Isaac-Reach-Franka-v0 --headless --ml_framework jax --video --video_length 200
- Training an agent with
`RL-Games <https://github.com/Denys88/rl_games>`__ on ``Isaac-Ant-v0``:
......@@ -227,6 +233,8 @@ from the environments into the respective libraries function argument and return
./isaaclab.sh -p source/standalone/workflows/rl_games/train.py --task Isaac-Ant-v0 --headless
# run script for playing with 32 environments
./isaaclab.sh -p source/standalone/workflows/rl_games/play.py --task Isaac-Ant-v0 --num_envs 32 --checkpoint /PATH/TO/model.pth
# run script for recording video of a trained agent (requires installing `ffmpeg`)
./isaaclab.sh -p source/standalone/workflows/rl_games/play.py --task Isaac-Ant-v0 --headless --video --video_length 200
- Training an agent with
`RSL-RL <https://github.com/leggedrobotics/rsl_rl>`__ on ``Isaac-Reach-Franka-v0``:
......@@ -239,6 +247,8 @@ from the environments into the respective libraries function argument and return
./isaaclab.sh -p source/standalone/workflows/rsl_rl/train.py --task Isaac-Reach-Franka-v0 --headless
# run script for playing with 32 environments
./isaaclab.sh -p source/standalone/workflows/rsl_rl/play.py --task Isaac-Reach-Franka-v0 --num_envs 32 --load_run run_folder_name --checkpoint model.pt
# run script for recording video of a trained agent (requires installing `ffmpeg`)
./isaaclab.sh -p source/standalone/workflows/rsl_rl/play.py --task Isaac-Reach-Franka-v0 --headless --video --video_length 200
All the scripts above log the training progress to `Tensorboard`_ in the ``logs`` directory in the root of
the repository. The logs directory follows the pattern ``logs/<library>/<task>/<date-time>``, where ``<library>``
......
......@@ -99,9 +99,9 @@ agent's behavior during training.
.. code-block:: bash
- ./isaaclab.sh -p source/standalone/workflows/sb3/train.py --task Isaac-Cartpole-v0 --num_envs 64 --headless --enable_cameras --video
+ ./isaaclab.sh -p source/standalone/workflows/sb3/train.py --task Isaac-Cartpole-v0 --num_envs 64 --headless --video
- The videos are saved to the ``logs/sb3/Isaac-Cartpole-v0/<run-dir>/videos`` directory. You can open these videos
+ The videos are saved to the ``logs/sb3/Isaac-Cartpole-v0/<run-dir>/videos/train`` directory. You can open these videos
using any video player.
Interactive execution
......
......@@ -370,7 +370,7 @@ Added
* Added a new flag ``viewport`` to the :class:`IsaacEnv` class to enable/disable rendering of the viewport.
If the flag is set to ``True``, the viewport is enabled and the environment is rendered in the background.
* Updated the training scripts in the ``source/standalone/workflows`` directory to use the new flag ``viewport``.
- If the CLI argument ``--video`` is passed, videos are recorded in the ``videos`` directory using the
+ If the CLI argument ``--video`` is passed, videos are recorded in the ``videos/train`` directory using the
:class:`gym.wrappers.RecordVideo` wrapper.
Changed
......
......@@ -41,7 +41,7 @@ class TestRecordVideoWrapper(unittest.TestCase):
# print all existing task names
print(">>> All registered environments:", cls.registered_tasks)
# directory to save videos
- cls.videos_dir = os.path.join(os.path.dirname(__file__), "output", "videos")
+ cls.videos_dir = os.path.join(os.path.dirname(__file__), "output", "videos", "train")
def setUp(self) -> None:
# common parameters
......
......@@ -13,6 +13,8 @@ from omni.isaac.lab.app import AppLauncher
# add argparse arguments
parser = argparse.ArgumentParser(description="Play a checkpoint of an RL agent from RL-Games.")
parser.add_argument("--video", action="store_true", default=False, help="Record a video during play.")
parser.add_argument("--video_length", type=int, default=200, help="Length of the recorded video (in steps).")
parser.add_argument("--cpu", action="store_true", default=False, help="Use CPU pipeline.")
parser.add_argument(
"--disable_fabric", action="store_true", default=False, help="Disable fabric and use USD I/O operations."
......@@ -29,6 +31,9 @@ parser.add_argument(
AppLauncher.add_app_launcher_args(parser)
# parse the arguments
args_cli = parser.parse_args()
# always enable cameras to record video
if args_cli.video:
args_cli.enable_cameras = True
# launch omniverse app
app_launcher = AppLauncher(args_cli)
......@@ -47,6 +52,7 @@ from rl_games.common.player import BasePlayer
from rl_games.torch_runner import Runner
from omni.isaac.lab.utils.assets import retrieve_file_path
from omni.isaac.lab.utils.dict import print_dict
import omni.isaac.lab_tasks # noqa: F401
from omni.isaac.lab_tasks.utils import get_checkpoint_path, load_cfg_from_registry, parse_env_cfg
......@@ -61,23 +67,6 @@ def main():
)
agent_cfg = load_cfg_from_registry(args_cli.task, "rl_games_cfg_entry_point")
# wrap around environment for rl-games
rl_device = agent_cfg["params"]["config"]["device"]
clip_obs = agent_cfg["params"]["env"].get("clip_observations", math.inf)
clip_actions = agent_cfg["params"]["env"].get("clip_actions", math.inf)
# create isaac environment
env = gym.make(args_cli.task, cfg=env_cfg)
# wrap around environment for rl-games
env = RlGamesVecEnvWrapper(env, rl_device, clip_obs, clip_actions)
# register the environment to rl-games registry
# note: in agents configuration: environment name must be "rlgpu"
vecenv.register(
"IsaacRlgWrapper", lambda config_name, num_actors, **kwargs: RlGamesGpuEnv(config_name, num_actors, **kwargs)
)
env_configurations.register("rlgpu", {"vecenv_type": "IsaacRlgWrapper", "env_creator": lambda **kwargs: env})
# specify directory for logging experiments
log_root_path = os.path.join("logs", "rl_games", agent_cfg["params"]["config"]["name"])
log_root_path = os.path.abspath(log_root_path)
......@@ -96,6 +85,36 @@ def main():
resume_path = get_checkpoint_path(log_root_path, run_dir, checkpoint_file, other_dirs=["nn"])
else:
resume_path = retrieve_file_path(args_cli.checkpoint)
log_dir = os.path.dirname(os.path.dirname(resume_path))
# wrap around environment for rl-games
rl_device = agent_cfg["params"]["config"]["device"]
clip_obs = agent_cfg["params"]["env"].get("clip_observations", math.inf)
clip_actions = agent_cfg["params"]["env"].get("clip_actions", math.inf)
# create isaac environment
env = gym.make(args_cli.task, cfg=env_cfg, render_mode="rgb_array" if args_cli.video else None)
# wrap for video recording
if args_cli.video:
video_kwargs = {
"video_folder": os.path.join(log_root_path, log_dir, "videos", "play"),
"step_trigger": lambda step: step == 0,
"video_length": args_cli.video_length,
"disable_logger": True,
}
print("[INFO] Recording video during play.")
print_dict(video_kwargs, nesting=4)
env = gym.wrappers.RecordVideo(env, **video_kwargs)
# wrap around environment for rl-games
env = RlGamesVecEnvWrapper(env, rl_device, clip_obs, clip_actions)
# register the environment to rl-games registry
# note: in agents configuration: environment name must be "rlgpu"
vecenv.register(
"IsaacRlgWrapper", lambda config_name, num_actors, **kwargs: RlGamesGpuEnv(config_name, num_actors, **kwargs)
)
env_configurations.register("rlgpu", {"vecenv_type": "IsaacRlgWrapper", "env_creator": lambda **kwargs: env})
# load previously trained model
agent_cfg["params"]["load_checkpoint"] = True
agent_cfg["params"]["load_path"] = resume_path
......@@ -115,6 +134,7 @@ def main():
obs = env.reset()
if isinstance(obs, dict):
obs = obs["obs"]
timestep = 0
# required: enables the flag for batched observations
_ = agent.get_batch_size(obs, 1)
# initialize RNN states if used
......@@ -140,6 +160,11 @@ def main():
if agent.is_rnn and agent.states is not None:
for s in agent.states:
s[:, dones, :] = 0.0
if args_cli.video:
timestep += 1
# Exit the play loop after recording one video
if timestep == args_cli.video_length:
break
# close the simulator
env.close()
......
......@@ -113,7 +113,7 @@ def main():
# wrap for video recording
if args_cli.video:
video_kwargs = {
- "video_folder": os.path.join(log_root_path, log_dir, "videos"),
+ "video_folder": os.path.join(log_root_path, log_dir, "videos", "train"),
"step_trigger": lambda step: step % args_cli.video_interval == 0,
"video_length": args_cli.video_length,
"disable_logger": True,
......
......@@ -16,6 +16,8 @@ import cli_args # isort: skip
# add argparse arguments
parser = argparse.ArgumentParser(description="Play a checkpoint of an RL agent from RSL-RL.")
parser.add_argument("--video", action="store_true", default=False, help="Record a video during play.")
parser.add_argument("--video_length", type=int, default=200, help="Length of the recorded video (in steps).")
parser.add_argument("--cpu", action="store_true", default=False, help="Use CPU pipeline.")
parser.add_argument(
"--disable_fabric", action="store_true", default=False, help="Disable fabric and use USD I/O operations."
......@@ -28,6 +30,9 @@ cli_args.add_rsl_rl_args(parser)
# append AppLauncher cli args
AppLauncher.add_app_launcher_args(parser)
args_cli = parser.parse_args()
# always enable cameras to record video
if args_cli.video:
args_cli.enable_cameras = True
# launch omniverse app
app_launcher = AppLauncher(args_cli)
......@@ -41,6 +46,8 @@ import torch
from rsl_rl.runners import OnPolicyRunner
from omni.isaac.lab.utils.dict import print_dict
import omni.isaac.lab_tasks # noqa: F401
from omni.isaac.lab_tasks.utils import get_checkpoint_path, parse_env_cfg
from omni.isaac.lab_tasks.utils.wrappers.rsl_rl import (
......@@ -59,22 +66,33 @@ def main():
)
agent_cfg: RslRlOnPolicyRunnerCfg = cli_args.parse_rsl_rl_cfg(args_cli.task, args_cli)
# create isaac environment
env = gym.make(args_cli.task, cfg=env_cfg)
# wrap around environment for rsl-rl
env = RslRlVecEnvWrapper(env)
# specify directory for logging experiments
log_root_path = os.path.join("logs", "rsl_rl", agent_cfg.experiment_name)
log_root_path = os.path.abspath(log_root_path)
print(f"[INFO] Loading experiment from directory: {log_root_path}")
resume_path = get_checkpoint_path(log_root_path, agent_cfg.load_run, agent_cfg.load_checkpoint)
print(f"[INFO]: Loading model checkpoint from: {resume_path}")
log_dir = os.path.dirname(resume_path)
# create isaac environment
env = gym.make(args_cli.task, cfg=env_cfg, render_mode="rgb_array" if args_cli.video else None)
# wrap for video recording
if args_cli.video:
video_kwargs = {
"video_folder": os.path.join(log_dir, "videos", "play"),
"step_trigger": lambda step: step == 0,
"video_length": args_cli.video_length,
"disable_logger": True,
}
print("[INFO] Recording video during play.")
print_dict(video_kwargs, nesting=4)
env = gym.wrappers.RecordVideo(env, **video_kwargs)
# wrap around environment for rsl-rl
env = RslRlVecEnvWrapper(env)
print(f"[INFO]: Loading model checkpoint from: {resume_path}")
# load previously trained model
ppo_runner = OnPolicyRunner(env, agent_cfg.to_dict(), log_dir=None, device=agent_cfg.device)
ppo_runner.load(resume_path)
print(f"[INFO]: Loading model checkpoint from: {resume_path}")
# obtain the trained policy for inference
policy = ppo_runner.get_inference_policy(device=env.unwrapped.device)
......@@ -90,6 +108,7 @@ def main():
# reset environment
obs, _ = env.get_observations()
timestep = 0
# simulate environment
while simulation_app.is_running():
# run everything in inference mode
......@@ -98,6 +117,11 @@ def main():
actions = policy(obs)
# env stepping
obs, _, _, _ = env.step(actions)
if args_cli.video:
timestep += 1
# Exit the play loop after recording one video
if timestep == args_cli.video_length:
break
# close the simulator
env.close()
......
......@@ -91,7 +91,7 @@ def main():
# wrap for video recording
if args_cli.video:
video_kwargs = {
- "video_folder": os.path.join(log_dir, "videos"),
+ "video_folder": os.path.join(log_dir, "videos", "train"),
"step_trigger": lambda step: step % args_cli.video_interval == 0,
"video_length": args_cli.video_length,
"disable_logger": True,
......
......@@ -13,6 +13,8 @@ from omni.isaac.lab.app import AppLauncher
# add argparse arguments
parser = argparse.ArgumentParser(description="Play a checkpoint of an RL agent from Stable-Baselines3.")
parser.add_argument("--video", action="store_true", default=False, help="Record a video during play.")
parser.add_argument("--video_length", type=int, default=200, help="Length of the recorded video (in steps).")
parser.add_argument("--cpu", action="store_true", default=False, help="Use CPU pipeline.")
parser.add_argument(
"--disable_fabric", action="store_true", default=False, help="Disable fabric and use USD I/O operations."
......@@ -29,6 +31,9 @@ parser.add_argument(
AppLauncher.add_app_launcher_args(parser)
# parse the arguments
args_cli = parser.parse_args()
# always enable cameras to record video
if args_cli.video:
args_cli.enable_cameras = True
# launch omniverse app
app_launcher = AppLauncher(args_cli)
......@@ -44,6 +49,8 @@ import torch
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import VecNormalize
from omni.isaac.lab.utils.dict import print_dict
import omni.isaac.lab_tasks # noqa: F401
from omni.isaac.lab_tasks.utils.parse_cfg import get_checkpoint_path, load_cfg_from_registry, parse_env_cfg
from omni.isaac.lab_tasks.utils.wrappers.sb3 import Sb3VecEnvWrapper, process_sb3_cfg
......@@ -56,11 +63,37 @@ def main():
args_cli.task, use_gpu=not args_cli.cpu, num_envs=args_cli.num_envs, use_fabric=not args_cli.disable_fabric
)
agent_cfg = load_cfg_from_registry(args_cli.task, "sb3_cfg_entry_point")
# directory for logging into
log_root_path = os.path.join("logs", "sb3", args_cli.task)
log_root_path = os.path.abspath(log_root_path)
# check checkpoint is valid
if args_cli.checkpoint is None:
if args_cli.use_last_checkpoint:
checkpoint = "model_.*.zip"
else:
checkpoint = "model.zip"
checkpoint_path = get_checkpoint_path(log_root_path, ".*", checkpoint)
else:
checkpoint_path = args_cli.checkpoint
log_dir = os.path.dirname(checkpoint_path)
# post-process agent configuration
agent_cfg = process_sb3_cfg(agent_cfg)
# create isaac environment
- env = gym.make(args_cli.task, cfg=env_cfg)
+ env = gym.make(args_cli.task, cfg=env_cfg, render_mode="rgb_array" if args_cli.video else None)
# wrap for video recording
if args_cli.video:
video_kwargs = {
"video_folder": os.path.join(log_dir, "videos", "play"),
"step_trigger": lambda step: step == 0,
"video_length": args_cli.video_length,
"disable_logger": True,
}
print("[INFO] Recording video during play.")
print_dict(video_kwargs, nesting=4)
env = gym.wrappers.RecordVideo(env, **video_kwargs)
# wrap around environment for stable baselines
env = Sb3VecEnvWrapper(env)
......@@ -76,24 +109,13 @@ def main():
clip_reward=np.inf,
)
# directory for logging into
log_root_path = os.path.join("logs", "sb3", args_cli.task)
log_root_path = os.path.abspath(log_root_path)
# check checkpoint is valid
if args_cli.checkpoint is None:
if args_cli.use_last_checkpoint:
checkpoint = "model_.*.zip"
else:
checkpoint = "model.zip"
checkpoint_path = get_checkpoint_path(log_root_path, ".*", checkpoint)
else:
checkpoint_path = args_cli.checkpoint
# create agent from stable baselines
print(f"Loading checkpoint from: {checkpoint_path}")
agent = PPO.load(checkpoint_path, env, print_system_info=True)
# reset environment
obs = env.reset()
timestep = 0
# simulate environment
while simulation_app.is_running():
# run everything in inference mode
......@@ -102,6 +124,11 @@ def main():
actions, _ = agent.predict(obs, deterministic=True)
# env stepping
obs, _, _, _ = env.step(actions)
if args_cli.video:
timestep += 1
# Exit the play loop after recording one video
if timestep == args_cli.video_length:
break
# close the simulator
env.close()
......
......@@ -96,7 +96,7 @@ def main():
# wrap for video recording
if args_cli.video:
video_kwargs = {
- "video_folder": os.path.join(log_dir, "videos"),
+ "video_folder": os.path.join(log_dir, "videos", "train"),
"step_trigger": lambda step: step % args_cli.video_interval == 0,
"video_length": args_cli.video_length,
"disable_logger": True,
......
......@@ -19,6 +19,8 @@ from omni.isaac.lab.app import AppLauncher
# add argparse arguments
parser = argparse.ArgumentParser(description="Play a checkpoint of an RL agent from skrl.")
parser.add_argument("--video", action="store_true", default=False, help="Record a video during play.")
parser.add_argument("--video_length", type=int, default=200, help="Length of the recorded video (in steps).")
parser.add_argument("--cpu", action="store_true", default=False, help="Use CPU pipeline.")
parser.add_argument(
"--disable_fabric", action="store_true", default=False, help="Disable fabric and use USD I/O operations."
......@@ -38,6 +40,9 @@ parser.add_argument(
AppLauncher.add_app_launcher_args(parser)
# parse the arguments
args_cli = parser.parse_args()
# always enable cameras to record video
if args_cli.video:
args_cli.enable_cameras = True
# launch omniverse app
app_launcher = AppLauncher(args_cli)
......@@ -58,6 +63,8 @@ elif args_cli.ml_framework.startswith("jax"):
from skrl.agents.jax.ppo import PPO, PPO_DEFAULT_CONFIG
from skrl.utils.model_instantiators.jax import deterministic_model, gaussian_model
from omni.isaac.lab.utils.dict import print_dict
import omni.isaac.lab_tasks # noqa: F401
from omni.isaac.lab_tasks.utils import get_checkpoint_path, load_cfg_from_registry, parse_env_cfg
from omni.isaac.lab_tasks.utils.wrappers.skrl import SkrlVecEnvWrapper, process_skrl_cfg
......@@ -74,8 +81,30 @@ def main():
)
experiment_cfg = load_cfg_from_registry(args_cli.task, "skrl_cfg_entry_point")
# specify directory for logging experiments (load checkpoint)
log_root_path = os.path.join("logs", "skrl", experiment_cfg["agent"]["experiment"]["directory"])
log_root_path = os.path.abspath(log_root_path)
print(f"[INFO] Loading experiment from directory: {log_root_path}")
# get checkpoint path
if args_cli.checkpoint:
resume_path = os.path.abspath(args_cli.checkpoint)
else:
resume_path = get_checkpoint_path(log_root_path, other_dirs=["checkpoints"])
log_dir = os.path.dirname(os.path.dirname(resume_path))
# create isaac environment
- env = gym.make(args_cli.task, cfg=env_cfg)
+ env = gym.make(args_cli.task, cfg=env_cfg, render_mode="rgb_array" if args_cli.video else None)
# wrap for video recording
if args_cli.video:
video_kwargs = {
"video_folder": os.path.join(log_dir, "videos", "play"),
"step_trigger": lambda step: step == 0,
"video_length": args_cli.video_length,
"disable_logger": True,
}
print("[INFO] Recording video during play.")
print_dict(video_kwargs, nesting=4)
env = gym.wrappers.RecordVideo(env, **video_kwargs)
# wrap around environment for skrl
env = SkrlVecEnvWrapper(env, ml_framework=args_cli.ml_framework) # same as: `wrap_env(env, wrapper="isaaclab")`
......@@ -137,25 +166,16 @@ def main():
device=env.device,
)
# specify directory for logging experiments (load checkpoint)
log_root_path = os.path.join("logs", "skrl", experiment_cfg["agent"]["experiment"]["directory"])
log_root_path = os.path.abspath(log_root_path)
print(f"[INFO] Loading experiment from directory: {log_root_path}")
# get checkpoint path
if args_cli.checkpoint:
resume_path = os.path.abspath(args_cli.checkpoint)
else:
resume_path = get_checkpoint_path(log_root_path, other_dirs=["checkpoints"])
print(f"[INFO] Loading model checkpoint from: {resume_path}")
# initialize agent
agent.init()
print(f"[INFO] Loading model checkpoint from: {resume_path}")
agent.load(resume_path)
# set agent to evaluation mode
agent.set_running_mode("eval")
# reset environment
obs, _ = env.reset()
timestep = 0
# simulate environment
while simulation_app.is_running():
# run everything in inference mode
......@@ -164,6 +184,11 @@ def main():
actions = agent.act(obs, timestep=0, timesteps=0)[0]
# env stepping
obs, _, _, _, _ = env.step(actions)
if args_cli.video:
timestep += 1
# Exit the play loop after recording one video
if timestep == args_cli.video_length:
break
# close the simulator
env.close()
......
......@@ -45,7 +45,7 @@ parser.add_argument(
AppLauncher.add_app_launcher_args(parser)
# parse the arguments
args_cli = parser.parse_args()
# always enable cameras to record video
if args_cli.video:
args_cli.enable_cameras = True
......@@ -132,7 +132,7 @@ def main():
# wrap for video recording
if args_cli.video:
video_kwargs = {
- "video_folder": os.path.join(log_dir, "videos"),
+ "video_folder": os.path.join(log_dir, "videos", "train"),
"step_trigger": lambda step: step % args_cli.video_interval == 0,
"video_length": args_cli.video_length,
"disable_logger": True,
......