Unverified Commit 7ea72c40 authored by Rishi Veerapaneni's avatar Rishi Veerapaneni Committed by GitHub

Fixes MARL workflows for recording videos during training/inferencing (#1596)

# Description

Fixing bug so that using training workflow on MARL workflow populates
videos/train.
See #1595

## Type of change
- Bug fix (non-breaking change which fixes an issue)
## Screenshots

![before_and_after](https://github.com/user-attachments/assets/5b662a88-dedd-4220-a0c4-8e7d09ceb51f)
The first run was without the changes where we see videos/train empty.
The second run is after the changes with videos/train successfully
populated.

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./isaaclab.sh --format`
- [N/A] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [Sort of] I have added tests that prove my fix is effective or that my
feature works; I have verified that it works on train.py for skrl and
rl_games. I have not verified rsl_rl or sb3 as well have not verified
play.py on any of the four. However I have implemented the changes on
all of them as they all seem to follow the exact same structure.
- [ ] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [ ] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there; Unsure if this fix is worth being labelled as a
contributor, if so would be happy to be added to the contributors.md
(full name is Rishi Veerapaneni).
parent e8ea1850
...@@ -58,6 +58,7 @@ def multi_agent_to_single_agent(env: DirectMARLEnv, state_as_observation: bool = ...@@ -58,6 +58,7 @@ def multi_agent_to_single_agent(env: DirectMARLEnv, state_as_observation: bool =
self.cfg = self.env.cfg self.cfg = self.env.cfg
self.sim = self.env.sim self.sim = self.env.sim
self.scene = self.env.scene self.scene = self.env.scene
self.render_mode = self.env.render_mode
self.single_observation_space = gym.spaces.Dict() self.single_observation_space = gym.spaces.Dict()
if self._state_as_observation: if self._state_as_observation:
...@@ -126,7 +127,7 @@ def multi_agent_to_single_agent(env: DirectMARLEnv, state_as_observation: bool = ...@@ -126,7 +127,7 @@ def multi_agent_to_single_agent(env: DirectMARLEnv, state_as_observation: bool =
return obs, rewards, terminated, time_outs, extras return obs, rewards, terminated, time_outs, extras
def render(self, recompute: bool = False) -> np.ndarray | None: def render(self, recompute: bool = False) -> np.ndarray | None:
self.env.render(recompute) return self.env.render(recompute)
def close(self) -> None: def close(self) -> None:
self.env.close() self.env.close()
......
...@@ -76,5 +76,5 @@ agent: ...@@ -76,5 +76,5 @@ agent:
# https://skrl.readthedocs.io/en/latest/api/trainers/sequential.html # https://skrl.readthedocs.io/en/latest/api/trainers/sequential.html
trainer: trainer:
class: SequentialTrainer class: SequentialTrainer
timesteps: 1600 timesteps: 4800
environment_info: log environment_info: log
...@@ -78,5 +78,5 @@ agent: ...@@ -78,5 +78,5 @@ agent:
# https://skrl.readthedocs.io/en/latest/api/trainers/sequential.html # https://skrl.readthedocs.io/en/latest/api/trainers/sequential.html
trainer: trainer:
class: SequentialTrainer class: SequentialTrainer
timesteps: 1600 timesteps: 4800
environment_info: log environment_info: log
...@@ -76,5 +76,5 @@ agent: ...@@ -76,5 +76,5 @@ agent:
# https://skrl.readthedocs.io/en/latest/api/trainers/sequential.html # https://skrl.readthedocs.io/en/latest/api/trainers/sequential.html
trainer: trainer:
class: SequentialTrainer class: SequentialTrainer
timesteps: 1600 timesteps: 4800
environment_info: log environment_info: log
...@@ -94,6 +94,11 @@ def main(): ...@@ -94,6 +94,11 @@ def main():
# create isaac environment # create isaac environment
env = gym.make(args_cli.task, cfg=env_cfg, render_mode="rgb_array" if args_cli.video else None) env = gym.make(args_cli.task, cfg=env_cfg, render_mode="rgb_array" if args_cli.video else None)
# convert to single-agent instance if required by the RL algorithm
if isinstance(env.unwrapped, DirectMARLEnv):
env = multi_agent_to_single_agent(env)
# wrap for video recording # wrap for video recording
if args_cli.video: if args_cli.video:
video_kwargs = { video_kwargs = {
...@@ -106,10 +111,6 @@ def main(): ...@@ -106,10 +111,6 @@ def main():
print_dict(video_kwargs, nesting=4) print_dict(video_kwargs, nesting=4)
env = gym.wrappers.RecordVideo(env, **video_kwargs) env = gym.wrappers.RecordVideo(env, **video_kwargs)
# convert to single-agent instance if required by the RL algorithm
if isinstance(env.unwrapped, DirectMARLEnv):
env = multi_agent_to_single_agent(env)
# wrap around environment for rl-games # wrap around environment for rl-games
env = RlGamesVecEnvWrapper(env, rl_device, clip_obs, clip_actions) env = RlGamesVecEnvWrapper(env, rl_device, clip_obs, clip_actions)
......
...@@ -129,6 +129,11 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen ...@@ -129,6 +129,11 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
# create isaac environment # create isaac environment
env = gym.make(args_cli.task, cfg=env_cfg, render_mode="rgb_array" if args_cli.video else None) env = gym.make(args_cli.task, cfg=env_cfg, render_mode="rgb_array" if args_cli.video else None)
# convert to single-agent instance if required by the RL algorithm
if isinstance(env.unwrapped, DirectMARLEnv):
env = multi_agent_to_single_agent(env)
# wrap for video recording # wrap for video recording
if args_cli.video: if args_cli.video:
video_kwargs = { video_kwargs = {
...@@ -141,10 +146,6 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen ...@@ -141,10 +146,6 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
print_dict(video_kwargs, nesting=4) print_dict(video_kwargs, nesting=4)
env = gym.wrappers.RecordVideo(env, **video_kwargs) env = gym.wrappers.RecordVideo(env, **video_kwargs)
# convert to single-agent instance if required by the RL algorithm
if isinstance(env.unwrapped, DirectMARLEnv):
env = multi_agent_to_single_agent(env)
# wrap around environment for rl-games # wrap around environment for rl-games
env = RlGamesVecEnvWrapper(env, rl_device, clip_obs, clip_actions) env = RlGamesVecEnvWrapper(env, rl_device, clip_obs, clip_actions)
......
...@@ -74,6 +74,11 @@ def main(): ...@@ -74,6 +74,11 @@ def main():
# create isaac environment # create isaac environment
env = gym.make(args_cli.task, cfg=env_cfg, render_mode="rgb_array" if args_cli.video else None) env = gym.make(args_cli.task, cfg=env_cfg, render_mode="rgb_array" if args_cli.video else None)
# convert to single-agent instance if required by the RL algorithm
if isinstance(env.unwrapped, DirectMARLEnv):
env = multi_agent_to_single_agent(env)
# wrap for video recording # wrap for video recording
if args_cli.video: if args_cli.video:
video_kwargs = { video_kwargs = {
...@@ -86,10 +91,6 @@ def main(): ...@@ -86,10 +91,6 @@ def main():
print_dict(video_kwargs, nesting=4) print_dict(video_kwargs, nesting=4)
env = gym.wrappers.RecordVideo(env, **video_kwargs) env = gym.wrappers.RecordVideo(env, **video_kwargs)
# convert to single-agent instance if required by the RL algorithm
if isinstance(env.unwrapped, DirectMARLEnv):
env = multi_agent_to_single_agent(env)
# wrap around environment for rsl-rl # wrap around environment for rsl-rl
env = RslRlVecEnvWrapper(env) env = RslRlVecEnvWrapper(env)
......
...@@ -100,6 +100,10 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen ...@@ -100,6 +100,10 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
# create isaac environment # create isaac environment
env = gym.make(args_cli.task, cfg=env_cfg, render_mode="rgb_array" if args_cli.video else None) env = gym.make(args_cli.task, cfg=env_cfg, render_mode="rgb_array" if args_cli.video else None)
# convert to single-agent instance if required by the RL algorithm
if isinstance(env.unwrapped, DirectMARLEnv):
env = multi_agent_to_single_agent(env)
# save resume path before creating a new log_dir # save resume path before creating a new log_dir
if agent_cfg.resume: if agent_cfg.resume:
resume_path = get_checkpoint_path(log_root_path, agent_cfg.load_run, agent_cfg.load_checkpoint) resume_path = get_checkpoint_path(log_root_path, agent_cfg.load_run, agent_cfg.load_checkpoint)
...@@ -116,10 +120,6 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen ...@@ -116,10 +120,6 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
print_dict(video_kwargs, nesting=4) print_dict(video_kwargs, nesting=4)
env = gym.wrappers.RecordVideo(env, **video_kwargs) env = gym.wrappers.RecordVideo(env, **video_kwargs)
# convert to single-agent instance if required by the RL algorithm
if isinstance(env.unwrapped, DirectMARLEnv):
env = multi_agent_to_single_agent(env)
# wrap around environment for rsl-rl # wrap around environment for rsl-rl
env = RslRlVecEnvWrapper(env) env = RslRlVecEnvWrapper(env)
......
...@@ -48,6 +48,7 @@ import torch ...@@ -48,6 +48,7 @@ import torch
from stable_baselines3 import PPO from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import VecNormalize from stable_baselines3.common.vec_env import VecNormalize
from omni.isaac.lab.envs import DirectMARLEnv, multi_agent_to_single_agent
from omni.isaac.lab.utils.dict import print_dict from omni.isaac.lab.utils.dict import print_dict
import omni.isaac.lab_tasks # noqa: F401 import omni.isaac.lab_tasks # noqa: F401
...@@ -82,6 +83,11 @@ def main(): ...@@ -82,6 +83,11 @@ def main():
# create isaac environment # create isaac environment
env = gym.make(args_cli.task, cfg=env_cfg, render_mode="rgb_array" if args_cli.video else None) env = gym.make(args_cli.task, cfg=env_cfg, render_mode="rgb_array" if args_cli.video else None)
# convert to single-agent instance if required by the RL algorithm
if isinstance(env.unwrapped, DirectMARLEnv):
env = multi_agent_to_single_agent(env)
# wrap for video recording # wrap for video recording
if args_cli.video: if args_cli.video:
video_kwargs = { video_kwargs = {
......
...@@ -104,6 +104,11 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen ...@@ -104,6 +104,11 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
# create isaac environment # create isaac environment
env = gym.make(args_cli.task, cfg=env_cfg, render_mode="rgb_array" if args_cli.video else None) env = gym.make(args_cli.task, cfg=env_cfg, render_mode="rgb_array" if args_cli.video else None)
# convert to single-agent instance if required by the RL algorithm
if isinstance(env.unwrapped, DirectMARLEnv):
env = multi_agent_to_single_agent(env)
# wrap for video recording # wrap for video recording
if args_cli.video: if args_cli.video:
video_kwargs = { video_kwargs = {
...@@ -116,10 +121,6 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen ...@@ -116,10 +121,6 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
print_dict(video_kwargs, nesting=4) print_dict(video_kwargs, nesting=4)
env = gym.wrappers.RecordVideo(env, **video_kwargs) env = gym.wrappers.RecordVideo(env, **video_kwargs)
# convert to single-agent instance if required by the RL algorithm
if isinstance(env.unwrapped, DirectMARLEnv):
env = multi_agent_to_single_agent(env)
# wrap around environment for stable baselines # wrap around environment for stable baselines
env = Sb3VecEnvWrapper(env) env = Sb3VecEnvWrapper(env)
......
...@@ -116,6 +116,11 @@ def main(): ...@@ -116,6 +116,11 @@ def main():
# create isaac environment # create isaac environment
env = gym.make(args_cli.task, cfg=env_cfg, render_mode="rgb_array" if args_cli.video else None) env = gym.make(args_cli.task, cfg=env_cfg, render_mode="rgb_array" if args_cli.video else None)
# convert to single-agent instance if required by the RL algorithm
if isinstance(env.unwrapped, DirectMARLEnv) and algorithm in ["ppo"]:
env = multi_agent_to_single_agent(env)
# wrap for video recording # wrap for video recording
if args_cli.video: if args_cli.video:
video_kwargs = { video_kwargs = {
...@@ -128,10 +133,6 @@ def main(): ...@@ -128,10 +133,6 @@ def main():
print_dict(video_kwargs, nesting=4) print_dict(video_kwargs, nesting=4)
env = gym.wrappers.RecordVideo(env, **video_kwargs) env = gym.wrappers.RecordVideo(env, **video_kwargs)
# convert to single-agent instance if required by the RL algorithm
if isinstance(env.unwrapped, DirectMARLEnv) and algorithm in ["ppo"]:
env = multi_agent_to_single_agent(env)
# wrap around environment for skrl # wrap around environment for skrl
env = SkrlVecEnvWrapper(env, ml_framework=args_cli.ml_framework) # same as: `wrap_env(env, wrapper="auto")` env = SkrlVecEnvWrapper(env, ml_framework=args_cli.ml_framework) # same as: `wrap_env(env, wrapper="auto")`
......
...@@ -151,6 +151,11 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen ...@@ -151,6 +151,11 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
# create isaac environment # create isaac environment
env = gym.make(args_cli.task, cfg=env_cfg, render_mode="rgb_array" if args_cli.video else None) env = gym.make(args_cli.task, cfg=env_cfg, render_mode="rgb_array" if args_cli.video else None)
# convert to single-agent instance if required by the RL algorithm
if isinstance(env.unwrapped, DirectMARLEnv) and algorithm in ["ppo"]:
env = multi_agent_to_single_agent(env)
# wrap for video recording # wrap for video recording
if args_cli.video: if args_cli.video:
video_kwargs = { video_kwargs = {
...@@ -163,10 +168,6 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen ...@@ -163,10 +168,6 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
print_dict(video_kwargs, nesting=4) print_dict(video_kwargs, nesting=4)
env = gym.wrappers.RecordVideo(env, **video_kwargs) env = gym.wrappers.RecordVideo(env, **video_kwargs)
# convert to single-agent instance if required by the RL algorithm
if isinstance(env.unwrapped, DirectMARLEnv) and algorithm in ["ppo"]:
env = multi_agent_to_single_agent(env)
# wrap around environment for skrl # wrap around environment for skrl
env = SkrlVecEnvWrapper(env, ml_framework=args_cli.ml_framework) # same as: `wrap_env(env, wrapper="auto")` env = SkrlVecEnvWrapper(env, ml_framework=args_cli.ml_framework) # same as: `wrap_env(env, wrapper="auto")`
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment