Unverified Commit 5e84450c authored by Mayank Mittal's avatar Mayank Mittal Committed by GitHub

Simplifies the return type for `parse_env_cfg` method (#965)

# Description

Previously, the returned config object for `parse_env_cfg` mentioned the
dictionary. However, this is hardly used or supported in our workflows
code. This MR simplifies the return type to only be an instance of
manager-based or direct environment configuration classes. Doing so,
also cleans the code a bit by removing the need of explicit
type-hinting.

## Type of change

- Breaking change (fix or feature that would cause existing
functionality to not work as expected)

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./isaaclab.sh --format`
- [x] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there

---------
Signed-off-by: 's avatarMayank Mittal <12863862+Mayankm96@users.noreply.github.com>
parent ac71354c
......@@ -13,11 +13,10 @@ import os
import re
import yaml
from omni.isaac.lab.envs import ManagerBasedRLEnvCfg
from omni.isaac.lab.utils import update_class_from_dict, update_dict
from omni.isaac.lab.envs import DirectRLEnvCfg, ManagerBasedRLEnvCfg
def load_cfg_from_registry(task_name: str, entry_point_key: str) -> dict | ManagerBasedRLEnvCfg:
def load_cfg_from_registry(task_name: str, entry_point_key: str) -> dict | object:
"""Load default configuration given its entry point from the gym registry.
This function loads the configuration object from the gym registry for the given task name.
......@@ -46,7 +45,8 @@ def load_cfg_from_registry(task_name: str, entry_point_key: str) -> dict | Manag
entry_point_key: The entry point key to resolve the configuration file.
Returns:
The parsed configuration object. This is either a dictionary or a class object.
The parsed configuration object. If the entry point is a YAML file, it is parsed into a dictionary.
If the entry point is a Python class, it is instantiated and returned.
Raises:
ValueError: If the entry point key is not available in the gym registry for the task.
......@@ -98,7 +98,7 @@ def load_cfg_from_registry(task_name: str, entry_point_key: str) -> dict | Manag
def parse_env_cfg(
task_name: str, device: str = "cuda:0", num_envs: int | None = None, use_fabric: bool | None = None
) -> dict | ManagerBasedRLEnvCfg:
) -> ManagerBasedRLEnvCfg | DirectRLEnvCfg:
"""Parse configuration for an environment and override based on inputs.
Args:
......@@ -110,35 +110,28 @@ def parse_env_cfg(
Defaults to None, in which case it is left unchanged.
Returns:
The parsed configuration object. This is either a dictionary or a class object.
The parsed configuration object.
Raises:
ValueError: If the task name is not provided, i.e. None.
RuntimeError: If the configuration for the task is not a class. We assume users always use a class for the
environment configuration.
"""
# check if a task name is provided
if task_name is None:
raise ValueError("Please provide a valid task name. Hint: Use --task <task_name>.")
# create a dictionary to update from
args_cfg = {"sim": {"physx": dict()}, "scene": dict()}
# load the default configuration
cfg = load_cfg_from_registry(task_name, "env_cfg_entry_point")
# simulation device
args_cfg["sim"]["device"] = device
# check that it is not a dict
# we assume users always use a class for the configuration
if isinstance(cfg, dict):
raise RuntimeError(f"Configuration for the task: '{task_name}' is not a class. Please provide a class.")
# simulation device
cfg.sim.device = device
# disable fabric to read/write through USD
if use_fabric is not None:
args_cfg["sim"]["use_fabric"] = use_fabric
cfg.sim.use_fabric = use_fabric
# number of environments
if num_envs is not None:
args_cfg["scene"]["num_envs"] = num_envs
# load the default configuration
cfg = load_cfg_from_registry(task_name, "env_cfg_entry_point")
# update the main configuration
if isinstance(cfg, dict):
cfg = update_dict(cfg, args_cfg)
else:
update_class_from_dict(cfg, args_cfg)
cfg.scene.num_envs = num_envs
return cfg
......@@ -165,12 +158,13 @@ def get_checkpoint_path(
sort_alpha: Whether to sort the runs by alphabetical order. Defaults to True.
If False, the folders in :attr:`run_dir` are sorted by the last modified time.
Returns:
The path to the model checkpoint.
Raises:
ValueError: When no runs are found in the input directory.
ValueError: When no checkpoints are found in the input directory.
Returns:
The path to the model checkpoint.
"""
# check if runs present in directory
try:
......
......@@ -21,8 +21,6 @@ import unittest
import carb
import omni.usd
from omni.isaac.lab.envs import ManagerBasedRLEnv, ManagerBasedRLEnvCfg
import omni.isaac.lab_tasks # noqa: F401
from omni.isaac.lab_tasks.utils.parse_cfg import parse_env_cfg
......@@ -88,9 +86,9 @@ class TestEnvironments(unittest.TestCase):
# create a new stage
omni.usd.get_context().new_stage()
# parse configuration
env_cfg: ManagerBasedRLEnvCfg = parse_env_cfg(task_name, device=device, num_envs=num_envs)
env_cfg = parse_env_cfg(task_name, device=device, num_envs=num_envs)
# create environment
env: ManagerBasedRLEnv = gym.make(task_name, cfg=env_cfg)
env = gym.make(task_name, cfg=env_cfg)
# disable control on stop
env.unwrapped.sim._app_control_on_stop_handle = None # type: ignore
......
......@@ -20,8 +20,6 @@ import unittest
import omni.usd
from omni.isaac.lab.envs import ManagerBasedRLEnvCfg
import omni.isaac.lab_tasks # noqa: F401
from omni.isaac.lab_tasks.utils import parse_env_cfg
......@@ -60,7 +58,7 @@ class TestRecordVideoWrapper(unittest.TestCase):
omni.usd.get_context().new_stage()
# parse configuration
env_cfg: ManagerBasedRLEnvCfg = parse_env_cfg(task_name, device=self.device, num_envs=self.num_envs)
env_cfg = parse_env_cfg(task_name, device=self.device, num_envs=self.num_envs)
# create environment
env = gym.make(task_name, cfg=env_cfg, render_mode="rgb_array")
......
......@@ -20,8 +20,6 @@ import unittest
import omni.usd
from omni.isaac.lab.envs import ManagerBasedRLEnvCfg
import omni.isaac.lab_tasks # noqa: F401
from omni.isaac.lab_tasks.utils.parse_cfg import load_cfg_from_registry, parse_env_cfg
from omni.isaac.lab_tasks.utils.wrappers.rl_games import RlGamesVecEnvWrapper
......@@ -58,7 +56,7 @@ class TestRlGamesVecEnvWrapper(unittest.TestCase):
# create a new stage
omni.usd.get_context().new_stage()
# parse configuration
env_cfg: ManagerBasedRLEnvCfg = parse_env_cfg(task_name, device=self.device, num_envs=self.num_envs)
env_cfg = parse_env_cfg(task_name, device=self.device, num_envs=self.num_envs)
agent_cfg = load_cfg_from_registry(task_name, "rl_games_cfg_entry_point") # noqa: F841
# create environment
env = gym.make(task_name, cfg=env_cfg)
......
......@@ -20,8 +20,6 @@ import unittest
import omni.usd
from omni.isaac.lab.envs import ManagerBasedRLEnvCfg
import omni.isaac.lab_tasks # noqa: F401
from omni.isaac.lab_tasks.utils.parse_cfg import load_cfg_from_registry, parse_env_cfg
from omni.isaac.lab_tasks.utils.wrappers.rsl_rl import RslRlVecEnvWrapper
......@@ -58,7 +56,7 @@ class TestRslRlVecEnvWrapper(unittest.TestCase):
# create a new stage
omni.usd.get_context().new_stage()
# parse configuration
env_cfg: ManagerBasedRLEnvCfg = parse_env_cfg(task_name, device=self.device, num_envs=self.num_envs)
env_cfg = parse_env_cfg(task_name, device=self.device, num_envs=self.num_envs)
agent_cfg = load_cfg_from_registry(task_name, "rsl_rl_cfg_entry_point") # noqa: F841
# create environment
env = gym.make(task_name, cfg=env_cfg)
......@@ -94,7 +92,7 @@ class TestRslRlVecEnvWrapper(unittest.TestCase):
# create a new stage
omni.usd.get_context().new_stage()
# parse configuration
env_cfg: ManagerBasedRLEnvCfg = parse_env_cfg(task_name, device=self.device, num_envs=self.num_envs)
env_cfg = parse_env_cfg(task_name, device=self.device, num_envs=self.num_envs)
# change to finite horizon
env_cfg.is_finite_horizon = True
......
......@@ -21,8 +21,6 @@ import unittest
import omni.usd
from omni.isaac.lab.envs import ManagerBasedRLEnvCfg
import omni.isaac.lab_tasks # noqa: F401
from omni.isaac.lab_tasks.utils.parse_cfg import load_cfg_from_registry, parse_env_cfg
from omni.isaac.lab_tasks.utils.wrappers.sb3 import Sb3VecEnvWrapper
......@@ -59,7 +57,7 @@ class TestStableBaselines3VecEnvWrapper(unittest.TestCase):
# create a new stage
omni.usd.get_context().new_stage()
# parse configuration
env_cfg: ManagerBasedRLEnvCfg = parse_env_cfg(task_name, device=self.device, num_envs=self.num_envs)
env_cfg = parse_env_cfg(task_name, device=self.device, num_envs=self.num_envs)
agent_cfg = load_cfg_from_registry(task_name, "sb3_cfg_entry_point") # noqa: F841
# create environment
env = gym.make(task_name, cfg=env_cfg)
......
......@@ -20,8 +20,6 @@ import unittest
import omni.usd
from omni.isaac.lab.envs import ManagerBasedRLEnvCfg
import omni.isaac.lab_tasks # noqa: F401
from omni.isaac.lab_tasks.utils.parse_cfg import load_cfg_from_registry, parse_env_cfg
from omni.isaac.lab_tasks.utils.wrappers.skrl import SkrlVecEnvWrapper
......@@ -58,7 +56,7 @@ class TestSKRLVecEnvWrapper(unittest.TestCase):
# create a new stage
omni.usd.get_context().new_stage()
# parse configuration
env_cfg: ManagerBasedRLEnvCfg = parse_env_cfg(task_name, device=self.device, num_envs=self.num_envs)
env_cfg = parse_env_cfg(task_name, device=self.device, num_envs=self.num_envs)
agent_cfg = load_cfg_from_registry(task_name, "skrl_cfg_entry_point") # noqa: F841
# create environment
env = gym.make(task_name, cfg=env_cfg)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment