Commit ce598c2b authored by ooctipus's avatar ooctipus Committed by Kelly Guo

Enables set seed for all play scripts (#2997)

This PRs enables add mannual seeding for all play scripts,
To minimize the merge conflict due to ordering and formating, this pr
built on top of #2995, and should be merge after it.

I have mannual tested all rl-framework will work

Fixes # (issue)

<!-- As a practice, it is recommended to open an issue to have
discussions on the proposed pull request.
This makes it easier for the community to keep track of what is being
developed or added, and if a given feature
is demanded by more than one party. -->

<!-- As you go through the list, delete the ones that are not
applicable. -->

- New feature (non-breaking change which adds functionality)

Please attach before and after screenshots of the change if applicable.

<!--
Example:

| Before | After |
| ------ | ----- |
| _gif/png before_ | _gif/png after_ |

To upload images to a PR -- simply drag and drop an image while in edit
mode and it should upload the image directly. You can then paste that
source into the above before/after sections.
-->

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./isaaclab.sh --format`
- [ ] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there

<!--
As you go through the checklist above, you can mark something as done by
putting an x character in it

For example,
- [x] I have done this task
- [ ] I have not done this task
-->
parent c907fa6c
...@@ -22,6 +22,7 @@ parser.add_argument( ...@@ -22,6 +22,7 @@ parser.add_argument(
parser.add_argument("--num_envs", type=int, default=None, help="Number of environments to simulate.") parser.add_argument("--num_envs", type=int, default=None, help="Number of environments to simulate.")
parser.add_argument("--task", type=str, default=None, help="Name of the task.") parser.add_argument("--task", type=str, default=None, help="Name of the task.")
parser.add_argument("--checkpoint", type=str, default=None, help="Path to model checkpoint.") parser.add_argument("--checkpoint", type=str, default=None, help="Path to model checkpoint.")
parser.add_argument("--seed", type=int, default=None, help="Seed used for the environment")
parser.add_argument( parser.add_argument(
"--use_pretrained_checkpoint", "--use_pretrained_checkpoint",
action="store_true", action="store_true",
...@@ -53,6 +54,7 @@ simulation_app = app_launcher.app ...@@ -53,6 +54,7 @@ simulation_app = app_launcher.app
import gymnasium as gym import gymnasium as gym
import math import math
import os import os
import random
import time import time
import torch import torch
...@@ -91,6 +93,15 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen ...@@ -91,6 +93,15 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
env_cfg.scene.num_envs = args_cli.num_envs if args_cli.num_envs is not None else env_cfg.scene.num_envs env_cfg.scene.num_envs = args_cli.num_envs if args_cli.num_envs is not None else env_cfg.scene.num_envs
env_cfg.sim.device = args_cli.device if args_cli.device is not None else env_cfg.sim.device env_cfg.sim.device = args_cli.device if args_cli.device is not None else env_cfg.sim.device
# randomly sample a seed if seed = -1
if args_cli.seed == -1:
args_cli.seed = random.randint(0, 10000)
agent_cfg["params"]["seed"] = args_cli.seed if args_cli.seed is not None else agent_cfg["params"]["seed"]
# set the environment seed (after multi-gpu config for updated rank from agent seed)
# note: certain randomizations occur in the environment initialization so we set the seed here
env_cfg.seed = agent_cfg["params"]["seed"]
# specify directory for logging experiments # specify directory for logging experiments
log_root_path = os.path.join("logs", "rl_games", agent_cfg["params"]["config"]["name"]) log_root_path = os.path.join("logs", "rl_games", agent_cfg["params"]["config"]["name"])
log_root_path = os.path.abspath(log_root_path) log_root_path = os.path.abspath(log_root_path)
......
...@@ -24,6 +24,7 @@ parser.add_argument( ...@@ -24,6 +24,7 @@ parser.add_argument(
) )
parser.add_argument("--num_envs", type=int, default=None, help="Number of environments to simulate.") parser.add_argument("--num_envs", type=int, default=None, help="Number of environments to simulate.")
parser.add_argument("--task", type=str, default=None, help="Name of the task.") parser.add_argument("--task", type=str, default=None, help="Name of the task.")
parser.add_argument("--seed", type=int, default=None, help="Seed used for the environment")
parser.add_argument( parser.add_argument(
"--use_pretrained_checkpoint", "--use_pretrained_checkpoint",
action="store_true", action="store_true",
...@@ -86,6 +87,10 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen ...@@ -86,6 +87,10 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
# override configurations with non-hydra CLI arguments # override configurations with non-hydra CLI arguments
agent_cfg = cli_args.update_rsl_rl_cfg(agent_cfg, args_cli) agent_cfg = cli_args.update_rsl_rl_cfg(agent_cfg, args_cli)
env_cfg.scene.num_envs = args_cli.num_envs if args_cli.num_envs is not None else env_cfg.scene.num_envs env_cfg.scene.num_envs = args_cli.num_envs if args_cli.num_envs is not None else env_cfg.scene.num_envs
# set the environment seed
# note: certain randomizations occur in the environment initialization so we set the seed here
env_cfg.seed = agent_cfg.seed
env_cfg.sim.device = args_cli.device if args_cli.device is not None else env_cfg.sim.device env_cfg.sim.device = args_cli.device if args_cli.device is not None else env_cfg.sim.device
# specify directory for logging experiments # specify directory for logging experiments
......
...@@ -23,6 +23,7 @@ parser.add_argument( ...@@ -23,6 +23,7 @@ parser.add_argument(
parser.add_argument("--num_envs", type=int, default=None, help="Number of environments to simulate.") parser.add_argument("--num_envs", type=int, default=None, help="Number of environments to simulate.")
parser.add_argument("--task", type=str, default=None, help="Name of the task.") parser.add_argument("--task", type=str, default=None, help="Name of the task.")
parser.add_argument("--checkpoint", type=str, default=None, help="Path to model checkpoint.") parser.add_argument("--checkpoint", type=str, default=None, help="Path to model checkpoint.")
parser.add_argument("--seed", type=int, default=None, help="Seed used for the environment")
parser.add_argument( parser.add_argument(
"--use_pretrained_checkpoint", "--use_pretrained_checkpoint",
action="store_true", action="store_true",
...@@ -59,6 +60,7 @@ simulation_app = app_launcher.app ...@@ -59,6 +60,7 @@ simulation_app = app_launcher.app
import gymnasium as gym import gymnasium as gym
import os import os
import random
import time import time
import torch import torch
...@@ -90,9 +92,16 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen ...@@ -90,9 +92,16 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
# grab task name for checkpoint path # grab task name for checkpoint path
task_name = args_cli.task.split(":")[-1] task_name = args_cli.task.split(":")[-1]
train_task_name = task_name.replace("-Play", "") train_task_name = task_name.replace("-Play", "")
# randomly sample a seed if seed = -1
if args_cli.seed == -1:
args_cli.seed = random.randint(0, 10000)
# override configurations with non-hydra CLI arguments # override configurations with non-hydra CLI arguments
env_cfg.scene.num_envs = args_cli.num_envs if args_cli.num_envs is not None else env_cfg.scene.num_envs env_cfg.scene.num_envs = args_cli.num_envs if args_cli.num_envs is not None else env_cfg.scene.num_envs
agent_cfg["seed"] = args_cli.seed if args_cli.seed is not None else agent_cfg["seed"]
# set the environment seed
# note: certain randomizations occur in the environment initialization so we set the seed here
env_cfg.seed = agent_cfg["seed"]
env_cfg.sim.device = args_cli.device if args_cli.device is not None else env_cfg.sim.device env_cfg.sim.device = args_cli.device if args_cli.device is not None else env_cfg.sim.device
# directory for logging into # directory for logging into
......
...@@ -27,6 +27,7 @@ parser.add_argument( ...@@ -27,6 +27,7 @@ parser.add_argument(
parser.add_argument("--num_envs", type=int, default=None, help="Number of environments to simulate.") parser.add_argument("--num_envs", type=int, default=None, help="Number of environments to simulate.")
parser.add_argument("--task", type=str, default=None, help="Name of the task.") parser.add_argument("--task", type=str, default=None, help="Name of the task.")
parser.add_argument("--checkpoint", type=str, default=None, help="Path to model checkpoint.") parser.add_argument("--checkpoint", type=str, default=None, help="Path to model checkpoint.")
parser.add_argument("--seed", type=int, default=None, help="Seed used for the environment")
parser.add_argument( parser.add_argument(
"--use_pretrained_checkpoint", "--use_pretrained_checkpoint",
action="store_true", action="store_true",
...@@ -66,6 +67,7 @@ simulation_app = app_launcher.app ...@@ -66,6 +67,7 @@ simulation_app = app_launcher.app
import gymnasium as gym import gymnasium as gym
import os import os
import random
import time import time
import torch import torch
...@@ -124,6 +126,15 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, expe ...@@ -124,6 +126,15 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, expe
if args_cli.ml_framework.startswith("jax"): if args_cli.ml_framework.startswith("jax"):
skrl.config.jax.backend = "jax" if args_cli.ml_framework == "jax" else "numpy" skrl.config.jax.backend = "jax" if args_cli.ml_framework == "jax" else "numpy"
# randomly sample a seed if seed = -1
if args_cli.seed == -1:
args_cli.seed = random.randint(0, 10000)
# set the agent and environment seed from command line
# note: certain randomization occur in the environment initialization so we set the seed here
experiment_cfg["seed"] = args_cli.seed if args_cli.seed is not None else experiment_cfg["seed"]
env_cfg.seed = experiment_cfg["seed"]
# specify directory for logging experiments (load checkpoint) # specify directory for logging experiments (load checkpoint)
log_root_path = os.path.join("logs", "skrl", experiment_cfg["agent"]["experiment"]["directory"]) log_root_path = os.path.join("logs", "skrl", experiment_cfg["agent"]["experiment"]["directory"])
log_root_path = os.path.abspath(log_root_path) log_root_path = os.path.abspath(log_root_path)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment