Unverified Commit 5e84450c authored by Mayank Mittal's avatar Mayank Mittal Committed by GitHub

Simplifies the return type for `parse_env_cfg` method (#965)

# Description

Previously, the returned config object for `parse_env_cfg` mentioned the
dictionary. However, this is hardly used or supported in our workflows
code. This MR simplifies the return type to only be an instance of
manager-based or direct environment configuration classes. Doing so,
also cleans the code a bit by removing the need of explicit
type-hinting.

## Type of change

- Breaking change (fix or feature that would cause existing
functionality to not work as expected)

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./isaaclab.sh --format`
- [x] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there

---------
Signed-off-by: 's avatarMayank Mittal <12863862+Mayankm96@users.noreply.github.com>
parent ac71354c
...@@ -13,11 +13,10 @@ import os ...@@ -13,11 +13,10 @@ import os
import re import re
import yaml import yaml
from omni.isaac.lab.envs import ManagerBasedRLEnvCfg from omni.isaac.lab.envs import DirectRLEnvCfg, ManagerBasedRLEnvCfg
from omni.isaac.lab.utils import update_class_from_dict, update_dict
def load_cfg_from_registry(task_name: str, entry_point_key: str) -> dict | ManagerBasedRLEnvCfg: def load_cfg_from_registry(task_name: str, entry_point_key: str) -> dict | object:
"""Load default configuration given its entry point from the gym registry. """Load default configuration given its entry point from the gym registry.
This function loads the configuration object from the gym registry for the given task name. This function loads the configuration object from the gym registry for the given task name.
...@@ -46,7 +45,8 @@ def load_cfg_from_registry(task_name: str, entry_point_key: str) -> dict | Manag ...@@ -46,7 +45,8 @@ def load_cfg_from_registry(task_name: str, entry_point_key: str) -> dict | Manag
entry_point_key: The entry point key to resolve the configuration file. entry_point_key: The entry point key to resolve the configuration file.
Returns: Returns:
The parsed configuration object. This is either a dictionary or a class object. The parsed configuration object. If the entry point is a YAML file, it is parsed into a dictionary.
If the entry point is a Python class, it is instantiated and returned.
Raises: Raises:
ValueError: If the entry point key is not available in the gym registry for the task. ValueError: If the entry point key is not available in the gym registry for the task.
...@@ -98,7 +98,7 @@ def load_cfg_from_registry(task_name: str, entry_point_key: str) -> dict | Manag ...@@ -98,7 +98,7 @@ def load_cfg_from_registry(task_name: str, entry_point_key: str) -> dict | Manag
def parse_env_cfg( def parse_env_cfg(
task_name: str, device: str = "cuda:0", num_envs: int | None = None, use_fabric: bool | None = None task_name: str, device: str = "cuda:0", num_envs: int | None = None, use_fabric: bool | None = None
) -> dict | ManagerBasedRLEnvCfg: ) -> ManagerBasedRLEnvCfg | DirectRLEnvCfg:
"""Parse configuration for an environment and override based on inputs. """Parse configuration for an environment and override based on inputs.
Args: Args:
...@@ -110,35 +110,28 @@ def parse_env_cfg( ...@@ -110,35 +110,28 @@ def parse_env_cfg(
Defaults to None, in which case it is left unchanged. Defaults to None, in which case it is left unchanged.
Returns: Returns:
The parsed configuration object. This is either a dictionary or a class object. The parsed configuration object.
Raises: Raises:
ValueError: If the task name is not provided, i.e. None. RuntimeError: If the configuration for the task is not a class. We assume users always use a class for the
environment configuration.
""" """
# check if a task name is provided # load the default configuration
if task_name is None: cfg = load_cfg_from_registry(task_name, "env_cfg_entry_point")
raise ValueError("Please provide a valid task name. Hint: Use --task <task_name>.")
# create a dictionary to update from
args_cfg = {"sim": {"physx": dict()}, "scene": dict()}
# simulation device # check that it is not a dict
args_cfg["sim"]["device"] = device # we assume users always use a class for the configuration
if isinstance(cfg, dict):
raise RuntimeError(f"Configuration for the task: '{task_name}' is not a class. Please provide a class.")
# simulation device
cfg.sim.device = device
# disable fabric to read/write through USD # disable fabric to read/write through USD
if use_fabric is not None: if use_fabric is not None:
args_cfg["sim"]["use_fabric"] = use_fabric cfg.sim.use_fabric = use_fabric
# number of environments # number of environments
if num_envs is not None: if num_envs is not None:
args_cfg["scene"]["num_envs"] = num_envs cfg.scene.num_envs = num_envs
# load the default configuration
cfg = load_cfg_from_registry(task_name, "env_cfg_entry_point")
# update the main configuration
if isinstance(cfg, dict):
cfg = update_dict(cfg, args_cfg)
else:
update_class_from_dict(cfg, args_cfg)
return cfg return cfg
...@@ -165,12 +158,13 @@ def get_checkpoint_path( ...@@ -165,12 +158,13 @@ def get_checkpoint_path(
sort_alpha: Whether to sort the runs by alphabetical order. Defaults to True. sort_alpha: Whether to sort the runs by alphabetical order. Defaults to True.
If False, the folders in :attr:`run_dir` are sorted by the last modified time. If False, the folders in :attr:`run_dir` are sorted by the last modified time.
Returns:
The path to the model checkpoint.
Raises: Raises:
ValueError: When no runs are found in the input directory. ValueError: When no runs are found in the input directory.
ValueError: When no checkpoints are found in the input directory. ValueError: When no checkpoints are found in the input directory.
Returns:
The path to the model checkpoint.
""" """
# check if runs present in directory # check if runs present in directory
try: try:
......
...@@ -21,8 +21,6 @@ import unittest ...@@ -21,8 +21,6 @@ import unittest
import carb import carb
import omni.usd import omni.usd
from omni.isaac.lab.envs import ManagerBasedRLEnv, ManagerBasedRLEnvCfg
import omni.isaac.lab_tasks # noqa: F401 import omni.isaac.lab_tasks # noqa: F401
from omni.isaac.lab_tasks.utils.parse_cfg import parse_env_cfg from omni.isaac.lab_tasks.utils.parse_cfg import parse_env_cfg
...@@ -88,9 +86,9 @@ class TestEnvironments(unittest.TestCase): ...@@ -88,9 +86,9 @@ class TestEnvironments(unittest.TestCase):
# create a new stage # create a new stage
omni.usd.get_context().new_stage() omni.usd.get_context().new_stage()
# parse configuration # parse configuration
env_cfg: ManagerBasedRLEnvCfg = parse_env_cfg(task_name, device=device, num_envs=num_envs) env_cfg = parse_env_cfg(task_name, device=device, num_envs=num_envs)
# create environment # create environment
env: ManagerBasedRLEnv = gym.make(task_name, cfg=env_cfg) env = gym.make(task_name, cfg=env_cfg)
# disable control on stop # disable control on stop
env.unwrapped.sim._app_control_on_stop_handle = None # type: ignore env.unwrapped.sim._app_control_on_stop_handle = None # type: ignore
......
...@@ -20,8 +20,6 @@ import unittest ...@@ -20,8 +20,6 @@ import unittest
import omni.usd import omni.usd
from omni.isaac.lab.envs import ManagerBasedRLEnvCfg
import omni.isaac.lab_tasks # noqa: F401 import omni.isaac.lab_tasks # noqa: F401
from omni.isaac.lab_tasks.utils import parse_env_cfg from omni.isaac.lab_tasks.utils import parse_env_cfg
...@@ -60,7 +58,7 @@ class TestRecordVideoWrapper(unittest.TestCase): ...@@ -60,7 +58,7 @@ class TestRecordVideoWrapper(unittest.TestCase):
omni.usd.get_context().new_stage() omni.usd.get_context().new_stage()
# parse configuration # parse configuration
env_cfg: ManagerBasedRLEnvCfg = parse_env_cfg(task_name, device=self.device, num_envs=self.num_envs) env_cfg = parse_env_cfg(task_name, device=self.device, num_envs=self.num_envs)
# create environment # create environment
env = gym.make(task_name, cfg=env_cfg, render_mode="rgb_array") env = gym.make(task_name, cfg=env_cfg, render_mode="rgb_array")
......
...@@ -20,8 +20,6 @@ import unittest ...@@ -20,8 +20,6 @@ import unittest
import omni.usd import omni.usd
from omni.isaac.lab.envs import ManagerBasedRLEnvCfg
import omni.isaac.lab_tasks # noqa: F401 import omni.isaac.lab_tasks # noqa: F401
from omni.isaac.lab_tasks.utils.parse_cfg import load_cfg_from_registry, parse_env_cfg from omni.isaac.lab_tasks.utils.parse_cfg import load_cfg_from_registry, parse_env_cfg
from omni.isaac.lab_tasks.utils.wrappers.rl_games import RlGamesVecEnvWrapper from omni.isaac.lab_tasks.utils.wrappers.rl_games import RlGamesVecEnvWrapper
...@@ -58,7 +56,7 @@ class TestRlGamesVecEnvWrapper(unittest.TestCase): ...@@ -58,7 +56,7 @@ class TestRlGamesVecEnvWrapper(unittest.TestCase):
# create a new stage # create a new stage
omni.usd.get_context().new_stage() omni.usd.get_context().new_stage()
# parse configuration # parse configuration
env_cfg: ManagerBasedRLEnvCfg = parse_env_cfg(task_name, device=self.device, num_envs=self.num_envs) env_cfg = parse_env_cfg(task_name, device=self.device, num_envs=self.num_envs)
agent_cfg = load_cfg_from_registry(task_name, "rl_games_cfg_entry_point") # noqa: F841 agent_cfg = load_cfg_from_registry(task_name, "rl_games_cfg_entry_point") # noqa: F841
# create environment # create environment
env = gym.make(task_name, cfg=env_cfg) env = gym.make(task_name, cfg=env_cfg)
......
...@@ -20,8 +20,6 @@ import unittest ...@@ -20,8 +20,6 @@ import unittest
import omni.usd import omni.usd
from omni.isaac.lab.envs import ManagerBasedRLEnvCfg
import omni.isaac.lab_tasks # noqa: F401 import omni.isaac.lab_tasks # noqa: F401
from omni.isaac.lab_tasks.utils.parse_cfg import load_cfg_from_registry, parse_env_cfg from omni.isaac.lab_tasks.utils.parse_cfg import load_cfg_from_registry, parse_env_cfg
from omni.isaac.lab_tasks.utils.wrappers.rsl_rl import RslRlVecEnvWrapper from omni.isaac.lab_tasks.utils.wrappers.rsl_rl import RslRlVecEnvWrapper
...@@ -58,7 +56,7 @@ class TestRslRlVecEnvWrapper(unittest.TestCase): ...@@ -58,7 +56,7 @@ class TestRslRlVecEnvWrapper(unittest.TestCase):
# create a new stage # create a new stage
omni.usd.get_context().new_stage() omni.usd.get_context().new_stage()
# parse configuration # parse configuration
env_cfg: ManagerBasedRLEnvCfg = parse_env_cfg(task_name, device=self.device, num_envs=self.num_envs) env_cfg = parse_env_cfg(task_name, device=self.device, num_envs=self.num_envs)
agent_cfg = load_cfg_from_registry(task_name, "rsl_rl_cfg_entry_point") # noqa: F841 agent_cfg = load_cfg_from_registry(task_name, "rsl_rl_cfg_entry_point") # noqa: F841
# create environment # create environment
env = gym.make(task_name, cfg=env_cfg) env = gym.make(task_name, cfg=env_cfg)
...@@ -94,7 +92,7 @@ class TestRslRlVecEnvWrapper(unittest.TestCase): ...@@ -94,7 +92,7 @@ class TestRslRlVecEnvWrapper(unittest.TestCase):
# create a new stage # create a new stage
omni.usd.get_context().new_stage() omni.usd.get_context().new_stage()
# parse configuration # parse configuration
env_cfg: ManagerBasedRLEnvCfg = parse_env_cfg(task_name, device=self.device, num_envs=self.num_envs) env_cfg = parse_env_cfg(task_name, device=self.device, num_envs=self.num_envs)
# change to finite horizon # change to finite horizon
env_cfg.is_finite_horizon = True env_cfg.is_finite_horizon = True
......
...@@ -21,8 +21,6 @@ import unittest ...@@ -21,8 +21,6 @@ import unittest
import omni.usd import omni.usd
from omni.isaac.lab.envs import ManagerBasedRLEnvCfg
import omni.isaac.lab_tasks # noqa: F401 import omni.isaac.lab_tasks # noqa: F401
from omni.isaac.lab_tasks.utils.parse_cfg import load_cfg_from_registry, parse_env_cfg from omni.isaac.lab_tasks.utils.parse_cfg import load_cfg_from_registry, parse_env_cfg
from omni.isaac.lab_tasks.utils.wrappers.sb3 import Sb3VecEnvWrapper from omni.isaac.lab_tasks.utils.wrappers.sb3 import Sb3VecEnvWrapper
...@@ -59,7 +57,7 @@ class TestStableBaselines3VecEnvWrapper(unittest.TestCase): ...@@ -59,7 +57,7 @@ class TestStableBaselines3VecEnvWrapper(unittest.TestCase):
# create a new stage # create a new stage
omni.usd.get_context().new_stage() omni.usd.get_context().new_stage()
# parse configuration # parse configuration
env_cfg: ManagerBasedRLEnvCfg = parse_env_cfg(task_name, device=self.device, num_envs=self.num_envs) env_cfg = parse_env_cfg(task_name, device=self.device, num_envs=self.num_envs)
agent_cfg = load_cfg_from_registry(task_name, "sb3_cfg_entry_point") # noqa: F841 agent_cfg = load_cfg_from_registry(task_name, "sb3_cfg_entry_point") # noqa: F841
# create environment # create environment
env = gym.make(task_name, cfg=env_cfg) env = gym.make(task_name, cfg=env_cfg)
......
...@@ -20,8 +20,6 @@ import unittest ...@@ -20,8 +20,6 @@ import unittest
import omni.usd import omni.usd
from omni.isaac.lab.envs import ManagerBasedRLEnvCfg
import omni.isaac.lab_tasks # noqa: F401 import omni.isaac.lab_tasks # noqa: F401
from omni.isaac.lab_tasks.utils.parse_cfg import load_cfg_from_registry, parse_env_cfg from omni.isaac.lab_tasks.utils.parse_cfg import load_cfg_from_registry, parse_env_cfg
from omni.isaac.lab_tasks.utils.wrappers.skrl import SkrlVecEnvWrapper from omni.isaac.lab_tasks.utils.wrappers.skrl import SkrlVecEnvWrapper
...@@ -58,7 +56,7 @@ class TestSKRLVecEnvWrapper(unittest.TestCase): ...@@ -58,7 +56,7 @@ class TestSKRLVecEnvWrapper(unittest.TestCase):
# create a new stage # create a new stage
omni.usd.get_context().new_stage() omni.usd.get_context().new_stage()
# parse configuration # parse configuration
env_cfg: ManagerBasedRLEnvCfg = parse_env_cfg(task_name, device=self.device, num_envs=self.num_envs) env_cfg = parse_env_cfg(task_name, device=self.device, num_envs=self.num_envs)
agent_cfg = load_cfg_from_registry(task_name, "skrl_cfg_entry_point") # noqa: F841 agent_cfg = load_cfg_from_registry(task_name, "skrl_cfg_entry_point") # noqa: F841
# create environment # create environment
env = gym.make(task_name, cfg=env_cfg) env = gym.make(task_name, cfg=env_cfg)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment