Unverified Commit 2a9198c8 authored by Kelly Guo's avatar Kelly Guo Committed by GitHub

Modifies workflow scripts to generate random seed when seed=-1 (#1048)

# Description

This change adds support to the train.py workflow scripts to support
setting seed=-1 to generate a random seed. Previously, setting seed to
-1 would cause errors to be thrown from RL libraries due to negative
seed values.

## Type of change

- Bug fix (non-breaking change which fixes an issue)

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./isaaclab.sh --format`
- [ ] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there
parent 23290191
......@@ -81,7 +81,7 @@ class DirectMARLEnv:
# set the seed for the environment
if self.cfg.seed is not None:
self.seed(self.cfg.seed)
self.cfg.seed = self.seed(self.cfg.seed)
else:
carb.log_warn("Seed not set for the environment. The environment creation may not be deterministic.")
......
......@@ -86,7 +86,7 @@ class DirectRLEnv(gym.Env):
# set the seed for the environment
if self.cfg.seed is not None:
self.seed(self.cfg.seed)
self.cfg.seed = self.seed(self.cfg.seed)
else:
carb.log_warn("Seed not set for the environment. The environment creation may not be deterministic.")
......
......@@ -76,7 +76,7 @@ class ManagerBasedEnv:
# set the seed for the environment
if self.cfg.seed is not None:
self.seed(self.cfg.seed)
self.cfg.seed = self.seed(self.cfg.seed)
else:
carb.log_warn("Seed not set for the environment. The environment creation may not be deterministic.")
......
......@@ -47,6 +47,7 @@ simulation_app = app_launcher.app
import gymnasium as gym
import math
import os
import random
from datetime import datetime
from rl_games.common import env_configurations, vecenv
......@@ -76,6 +77,10 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
env_cfg.scene.num_envs = args_cli.num_envs if args_cli.num_envs is not None else env_cfg.scene.num_envs
env_cfg.sim.device = args_cli.device if args_cli.device is not None else env_cfg.sim.device
# randomly sample a seed if seed = -1
if args_cli.seed == -1:
args_cli.seed = random.randint(0, 10000)
agent_cfg["params"]["seed"] = args_cli.seed if args_cli.seed is not None else agent_cfg["params"]["seed"]
agent_cfg["params"]["config"]["max_epochs"] = (
args_cli.max_iterations if args_cli.max_iterations is not None else agent_cfg["params"]["config"]["max_epochs"]
......
......@@ -6,6 +6,7 @@
from __future__ import annotations
import argparse
import random
from typing import TYPE_CHECKING
if TYPE_CHECKING:
......@@ -68,6 +69,9 @@ def update_rsl_rl_cfg(agent_cfg: RslRlOnPolicyRunnerCfg, args_cli: argparse.Name
"""
# override the default configuration with CLI arguments
if hasattr(args_cli, "seed") and args_cli.seed is not None:
# randomly sample a seed if seed = -1
if args_cli.seed == -1:
args_cli.seed = random.randint(0, 10000)
agent_cfg.seed = args_cli.seed
if args_cli.resume is not None:
agent_cfg.resume = args_cli.resume
......
......@@ -46,6 +46,7 @@ simulation_app = app_launcher.app
import gymnasium as gym
import numpy as np
import os
import random
from datetime import datetime
from stable_baselines3 import PPO
......@@ -71,6 +72,10 @@ from omni.isaac.lab_tasks.utils.wrappers.sb3 import Sb3VecEnvWrapper, process_sb
@hydra_task_config(args_cli.task, "sb3_cfg_entry_point")
def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agent_cfg: dict):
"""Train with stable-baselines agent."""
# randomly sample a seed if seed = -1
if args_cli.seed == -1:
args_cli.seed = random.randint(0, 10000)
# override configurations with non-hydra CLI arguments
env_cfg.scene.num_envs = args_cli.num_envs if args_cli.num_envs is not None else env_cfg.scene.num_envs
agent_cfg["seed"] = args_cli.seed if args_cli.seed is not None else agent_cfg["seed"]
......
......@@ -63,6 +63,7 @@ simulation_app = app_launcher.app
import gymnasium as gym
import os
import random
from datetime import datetime
import skrl
......@@ -119,9 +120,14 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
if args_cli.ml_framework.startswith("jax"):
skrl.config.jax.backend = "jax" if args_cli.ml_framework == "jax" else "numpy"
# set the environment seed
# randomly sample a seed if seed = -1
if args_cli.seed == -1:
args_cli.seed = random.randint(0, 10000)
# set the agent and environment seed from command line
# note: certain randomization occur in the environment initialization so we set the seed here
env_cfg.seed = args_cli.seed if args_cli.seed is not None else agent_cfg["seed"]
agent_cfg["seed"] = args_cli.seed if args_cli.seed is not None else agent_cfg["seed"]
env_cfg.seed = agent_cfg["seed"]
# specify directory for logging experiments
log_root_path = os.path.join("logs", "skrl", agent_cfg["agent"]["experiment"]["directory"])
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment