Unverified Commit 2a9198c8 authored by Kelly Guo's avatar Kelly Guo Committed by GitHub

Modifies workflow scripts to generate random seed when seed=-1 (#1048)

# Description

This change lets the train.py workflow scripts accept seed=-1 as a
request to generate a random seed. Previously, setting the seed to -1
caused the RL libraries to throw errors, because they reject negative
seed values.

## Type of change

- Bug fix (non-breaking change which fixes an issue)

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./isaaclab.sh --format`
- [ ] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there
parent 23290191
...@@ -81,7 +81,7 @@ class DirectMARLEnv: ...@@ -81,7 +81,7 @@ class DirectMARLEnv:
# set the seed for the environment # set the seed for the environment
if self.cfg.seed is not None: if self.cfg.seed is not None:
self.seed(self.cfg.seed) self.cfg.seed = self.seed(self.cfg.seed)
else: else:
carb.log_warn("Seed not set for the environment. The environment creation may not be deterministic.") carb.log_warn("Seed not set for the environment. The environment creation may not be deterministic.")
......
...@@ -86,7 +86,7 @@ class DirectRLEnv(gym.Env): ...@@ -86,7 +86,7 @@ class DirectRLEnv(gym.Env):
# set the seed for the environment # set the seed for the environment
if self.cfg.seed is not None: if self.cfg.seed is not None:
self.seed(self.cfg.seed) self.cfg.seed = self.seed(self.cfg.seed)
else: else:
carb.log_warn("Seed not set for the environment. The environment creation may not be deterministic.") carb.log_warn("Seed not set for the environment. The environment creation may not be deterministic.")
......
...@@ -76,7 +76,7 @@ class ManagerBasedEnv: ...@@ -76,7 +76,7 @@ class ManagerBasedEnv:
# set the seed for the environment # set the seed for the environment
if self.cfg.seed is not None: if self.cfg.seed is not None:
self.seed(self.cfg.seed) self.cfg.seed = self.seed(self.cfg.seed)
else: else:
carb.log_warn("Seed not set for the environment. The environment creation may not be deterministic.") carb.log_warn("Seed not set for the environment. The environment creation may not be deterministic.")
......
...@@ -47,6 +47,7 @@ simulation_app = app_launcher.app ...@@ -47,6 +47,7 @@ simulation_app = app_launcher.app
import gymnasium as gym import gymnasium as gym
import math import math
import os import os
import random
from datetime import datetime from datetime import datetime
from rl_games.common import env_configurations, vecenv from rl_games.common import env_configurations, vecenv
...@@ -76,6 +77,10 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen ...@@ -76,6 +77,10 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
env_cfg.scene.num_envs = args_cli.num_envs if args_cli.num_envs is not None else env_cfg.scene.num_envs env_cfg.scene.num_envs = args_cli.num_envs if args_cli.num_envs is not None else env_cfg.scene.num_envs
env_cfg.sim.device = args_cli.device if args_cli.device is not None else env_cfg.sim.device env_cfg.sim.device = args_cli.device if args_cli.device is not None else env_cfg.sim.device
# randomly sample a seed if seed = -1
if args_cli.seed == -1:
args_cli.seed = random.randint(0, 10000)
agent_cfg["params"]["seed"] = args_cli.seed if args_cli.seed is not None else agent_cfg["params"]["seed"] agent_cfg["params"]["seed"] = args_cli.seed if args_cli.seed is not None else agent_cfg["params"]["seed"]
agent_cfg["params"]["config"]["max_epochs"] = ( agent_cfg["params"]["config"]["max_epochs"] = (
args_cli.max_iterations if args_cli.max_iterations is not None else agent_cfg["params"]["config"]["max_epochs"] args_cli.max_iterations if args_cli.max_iterations is not None else agent_cfg["params"]["config"]["max_epochs"]
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
from __future__ import annotations from __future__ import annotations
import argparse import argparse
import random
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -68,6 +69,9 @@ def update_rsl_rl_cfg(agent_cfg: RslRlOnPolicyRunnerCfg, args_cli: argparse.Name ...@@ -68,6 +69,9 @@ def update_rsl_rl_cfg(agent_cfg: RslRlOnPolicyRunnerCfg, args_cli: argparse.Name
""" """
# override the default configuration with CLI arguments # override the default configuration with CLI arguments
if hasattr(args_cli, "seed") and args_cli.seed is not None: if hasattr(args_cli, "seed") and args_cli.seed is not None:
# randomly sample a seed if seed = -1
if args_cli.seed == -1:
args_cli.seed = random.randint(0, 10000)
agent_cfg.seed = args_cli.seed agent_cfg.seed = args_cli.seed
if args_cli.resume is not None: if args_cli.resume is not None:
agent_cfg.resume = args_cli.resume agent_cfg.resume = args_cli.resume
......
...@@ -46,6 +46,7 @@ simulation_app = app_launcher.app ...@@ -46,6 +46,7 @@ simulation_app = app_launcher.app
import gymnasium as gym import gymnasium as gym
import numpy as np import numpy as np
import os import os
import random
from datetime import datetime from datetime import datetime
from stable_baselines3 import PPO from stable_baselines3 import PPO
...@@ -71,6 +72,10 @@ from omni.isaac.lab_tasks.utils.wrappers.sb3 import Sb3VecEnvWrapper, process_sb ...@@ -71,6 +72,10 @@ from omni.isaac.lab_tasks.utils.wrappers.sb3 import Sb3VecEnvWrapper, process_sb
@hydra_task_config(args_cli.task, "sb3_cfg_entry_point") @hydra_task_config(args_cli.task, "sb3_cfg_entry_point")
def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agent_cfg: dict): def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agent_cfg: dict):
"""Train with stable-baselines agent.""" """Train with stable-baselines agent."""
# randomly sample a seed if seed = -1
if args_cli.seed == -1:
args_cli.seed = random.randint(0, 10000)
# override configurations with non-hydra CLI arguments # override configurations with non-hydra CLI arguments
env_cfg.scene.num_envs = args_cli.num_envs if args_cli.num_envs is not None else env_cfg.scene.num_envs env_cfg.scene.num_envs = args_cli.num_envs if args_cli.num_envs is not None else env_cfg.scene.num_envs
agent_cfg["seed"] = args_cli.seed if args_cli.seed is not None else agent_cfg["seed"] agent_cfg["seed"] = args_cli.seed if args_cli.seed is not None else agent_cfg["seed"]
......
...@@ -63,6 +63,7 @@ simulation_app = app_launcher.app ...@@ -63,6 +63,7 @@ simulation_app = app_launcher.app
import gymnasium as gym import gymnasium as gym
import os import os
import random
from datetime import datetime from datetime import datetime
import skrl import skrl
...@@ -119,9 +120,14 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen ...@@ -119,9 +120,14 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
if args_cli.ml_framework.startswith("jax"): if args_cli.ml_framework.startswith("jax"):
skrl.config.jax.backend = "jax" if args_cli.ml_framework == "jax" else "numpy" skrl.config.jax.backend = "jax" if args_cli.ml_framework == "jax" else "numpy"
# set the environment seed # randomly sample a seed if seed = -1
if args_cli.seed == -1:
args_cli.seed = random.randint(0, 10000)
# set the agent and environment seed from command line
# note: certain randomization occur in the environment initialization so we set the seed here # note: certain randomization occur in the environment initialization so we set the seed here
env_cfg.seed = args_cli.seed if args_cli.seed is not None else agent_cfg["seed"] agent_cfg["seed"] = args_cli.seed if args_cli.seed is not None else agent_cfg["seed"]
env_cfg.seed = agent_cfg["seed"]
# specify directory for logging experiments # specify directory for logging experiments
log_root_path = os.path.join("logs", "skrl", agent_cfg["agent"]["experiment"]["directory"]) log_root_path = os.path.join("logs", "skrl", agent_cfg["agent"]["experiment"]["directory"])
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment