Commit 90b11509 authored by amrmousa144, committed by Mayank Mittal

Fixes setting the device from CLI in the RL training scripts (#1013)

This pull request fixes the issue where the device (`CPU` or `CUDA`) is
not set correctly when using the `--device` argument in Hydra-configured
scripts like `rsl_rl/train.py` and `skrl/train.py`. The bug caused the
scripts to always default to `cuda:0`, even when `cpu` or a specific
CUDA device (e.g., `cuda:1`) was selected.

The fix adds the following line to ensure that the selected device is
properly set in `env_cfg` before initializing the environment with
`gym.make()`:

```python
env_cfg.sim.device = args_cli.device if args_cli.device is not None else env_cfg.sim.device
```

Fixes #1012

- Bug fix (non-breaking change which fixes an issue)

Before:
- `skrl/train.py`: running the script with `--device cpu` still defaults to
`cuda:0`.
- `rsl_rl/train.py`: the script freezes at `[INFO]: Starting the
simulation. This may take a few seconds. Please wait....`

After:
- Both scripts run correctly on the specified device (e.g., cpu or
cuda:1) without defaulting to cuda:0 or freezing.

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./isaaclab.sh --format`
- [ ] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there
parent 59fd1f7f
...@@ -31,9 +31,10 @@ Guidelines for modifications: ...@@ -31,9 +31,10 @@ Guidelines for modifications:
## Contributors ## Contributors
* Anton Bjørndahl Mortensen
* Alice Zhou * Alice Zhou
* Amr Mousa
* Andrej Orsula * Andrej Orsula
* Anton Bjørndahl Mortensen
* Antonio Serrano-Muñoz * Antonio Serrano-Muñoz
* Arjun Bhardwaj * Arjun Bhardwaj
* Brayden Zhang * Brayden Zhang
......
...@@ -79,6 +79,12 @@ class DirectMARLEnv: ...@@ -79,6 +79,12 @@ class DirectMARLEnv:
# initialize internal variables # initialize internal variables
self._is_closed = False self._is_closed = False
# set the seed for the environment
if self.cfg.seed is not None:
self.seed(self.cfg.seed)
else:
carb.log_warn("Seed not set for the environment. The environment creation may not be deterministic.")
# create a simulation context to control the simulator # create a simulation context to control the simulator
if SimulationContext.instance() is None: if SimulationContext.instance() is None:
self.sim: SimulationContext = SimulationContext(self.cfg.sim) self.sim: SimulationContext = SimulationContext(self.cfg.sim)
...@@ -88,6 +94,7 @@ class DirectMARLEnv: ...@@ -88,6 +94,7 @@ class DirectMARLEnv:
# print useful information # print useful information
print("[INFO]: Base environment:") print("[INFO]: Base environment:")
print(f"\tEnvironment device : {self.device}") print(f"\tEnvironment device : {self.device}")
print(f"\tEnvironment seed : {self.cfg.seed}")
print(f"\tPhysics step-size : {self.physics_dt}") print(f"\tPhysics step-size : {self.physics_dt}")
print(f"\tRendering step-size : {self.physics_dt * self.cfg.sim.render_interval}") print(f"\tRendering step-size : {self.physics_dt * self.cfg.sim.render_interval}")
print(f"\tEnvironment step-size : {self.step_dt}") print(f"\tEnvironment step-size : {self.step_dt}")
......
...@@ -41,6 +41,14 @@ class DirectMARLEnvCfg: ...@@ -41,6 +41,14 @@ class DirectMARLEnvCfg:
""" """
# general settings # general settings
seed: int | None = None
"""The seed for the random number generator. Defaults to None, in which case the seed is not set.
Note:
The seed is set at the beginning of the environment initialization. This ensures that the environment
creation is deterministic and behaves similarly across different runs.
"""
decimation: int = MISSING decimation: int = MISSING
"""Number of control action updates @ sim dt per policy dt. """Number of control action updates @ sim dt per policy dt.
......
...@@ -74,6 +74,8 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen ...@@ -74,6 +74,8 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
"""Train with RL-Games agent.""" """Train with RL-Games agent."""
# override configurations with non-hydra CLI arguments # override configurations with non-hydra CLI arguments
env_cfg.scene.num_envs = args_cli.num_envs if args_cli.num_envs is not None else env_cfg.scene.num_envs env_cfg.scene.num_envs = args_cli.num_envs if args_cli.num_envs is not None else env_cfg.scene.num_envs
env_cfg.sim.device = args_cli.device if args_cli.device is not None else env_cfg.sim.device
agent_cfg["params"]["seed"] = args_cli.seed if args_cli.seed is not None else agent_cfg["params"]["seed"] agent_cfg["params"]["seed"] = args_cli.seed if args_cli.seed is not None else agent_cfg["params"]["seed"]
agent_cfg["params"]["config"]["max_epochs"] = ( agent_cfg["params"]["config"]["max_epochs"] = (
args_cli.max_iterations if args_cli.max_iterations is not None else agent_cfg["params"]["config"]["max_epochs"] args_cli.max_iterations if args_cli.max_iterations is not None else agent_cfg["params"]["config"]["max_epochs"]
......
...@@ -85,6 +85,7 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen ...@@ -85,6 +85,7 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
# set the environment seed # set the environment seed
# note: certain randomizations occur in the environment initialization so we set the seed here # note: certain randomizations occur in the environment initialization so we set the seed here
env_cfg.seed = agent_cfg.seed env_cfg.seed = agent_cfg.seed
env_cfg.sim.device = args_cli.device if args_cli.device is not None else env_cfg.sim.device
# specify directory for logging experiments # specify directory for logging experiments
log_root_path = os.path.join("logs", "rsl_rl", agent_cfg.experiment_name) log_root_path = os.path.join("logs", "rsl_rl", agent_cfg.experiment_name)
......
...@@ -81,6 +81,7 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen ...@@ -81,6 +81,7 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
# set the environment seed # set the environment seed
# note: certain randomizations occur in the environment initialization so we set the seed here # note: certain randomizations occur in the environment initialization so we set the seed here
env_cfg.seed = agent_cfg["seed"] env_cfg.seed = agent_cfg["seed"]
env_cfg.sim.device = args_cli.device if args_cli.device is not None else env_cfg.sim.device
# directory for logging into # directory for logging into
log_dir = os.path.join("logs", "sb3", args_cli.task, datetime.now().strftime("%Y-%m-%d_%H-%M-%S")) log_dir = os.path.join("logs", "sb3", args_cli.task, datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
......
...@@ -106,6 +106,8 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen ...@@ -106,6 +106,8 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
"""Train with skrl agent.""" """Train with skrl agent."""
# override configurations with non-hydra CLI arguments # override configurations with non-hydra CLI arguments
env_cfg.scene.num_envs = args_cli.num_envs if args_cli.num_envs is not None else env_cfg.scene.num_envs env_cfg.scene.num_envs = args_cli.num_envs if args_cli.num_envs is not None else env_cfg.scene.num_envs
env_cfg.sim.device = args_cli.device if args_cli.device is not None else env_cfg.sim.device
# multi-gpu training config # multi-gpu training config
if args_cli.distributed: if args_cli.distributed:
env_cfg.sim.device = f"cuda:{app_launcher.local_rank}" env_cfg.sim.device = f"cuda:{app_launcher.local_rank}"
...@@ -118,7 +120,7 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen ...@@ -118,7 +120,7 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
skrl.config.jax.backend = "jax" if args_cli.ml_framework == "jax" else "numpy" skrl.config.jax.backend = "jax" if args_cli.ml_framework == "jax" else "numpy"
# set the environment seed # set the environment seed
# note: certain randomizations occur in the environment initialization so we set the seed here # note: certain randomization occur in the environment initialization so we set the seed here
env_cfg.seed = args_cli.seed if args_cli.seed is not None else agent_cfg["seed"] env_cfg.seed = args_cli.seed if args_cli.seed is not None else agent_cfg["seed"]
# specify directory for logging experiments # specify directory for logging experiments
...@@ -135,11 +137,6 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen ...@@ -135,11 +137,6 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
# update log_dir # update log_dir
log_dir = os.path.join(log_root_path, log_dir) log_dir = os.path.join(log_root_path, log_dir)
# multi-gpu training config
if args_cli.distributed:
# update env config device
env_cfg.sim.device = f"cuda:{app_launcher.local_rank}"
# dump the configuration into log-directory # dump the configuration into log-directory
dump_yaml(os.path.join(log_dir, "params", "env.yaml"), env_cfg) dump_yaml(os.path.join(log_dir, "params", "env.yaml"), env_cfg)
dump_yaml(os.path.join(log_dir, "params", "agent.yaml"), agent_cfg) dump_yaml(os.path.join(log_dir, "params", "agent.yaml"), agent_cfg)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment