Unverified Commit 83d62e21 authored by James Smith's avatar James Smith Committed by GitHub

Fixes imitation learning workflow for lift environment (#451)

# Description

This PR fixes the imitation learning workflow in that
`collect_demonstrations`, `train` and `play` scripts all don't throw
exceptions. I haven't validated that the training actually generates a
successful policy, only that the loss decreased within the first few
iterations.

A follow up task might be to make sure that the chosen observation terms
can still result in a good policy, but that's outside of the scope of
updating this workflow to API changes in Orbit and Robomimic.

Fixes #387 

## Type of change

- Bug fix (non-breaking change which fixes an issue)

## Screenshot

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./orbit.sh --format`
- [ ] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [x] I have run all the tests with `./orbit.sh --test` and they pass
- [x] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there

---------
Signed-off-by: 's avatarJames Smith <142246516+jsmith-bdai@users.noreply.github.com>
Co-authored-by: 's avatarMayank Mittal <12863862+Mayankm96@users.noreply.github.com>
parent 40e4591f
[package]
# Note: Semantic Versioning is used: https://semver.org/
version = "0.15.0"
version = "0.15.1"
# Description
title = "ORBIT framework for Robot Learning"
......
Changelog
---------
0.15.1 (2024-03-19)
~~~~~~~~~~~~~~~~~~~
Fixed
^^^^^
* Fixed the imitation learning workflow example script, updating Orbit and Robomimic API calls.
* Removed the resetting of :attr:`_term_dones` in the :meth:`omni.isaac.orbit.managers.TerminationManager.reset`.
Previously, the environment cleared out all the terms. However, it impaired reading the specific term's values externally.
0.15.0 (2024-03-17)
~~~~~~~~~~~~~~~~~~~
......
......@@ -136,8 +136,6 @@ class TerminationManager(ManagerBase):
for key in self._term_dones.keys():
# store information
extras["Episode Termination/" + key] = torch.count_nonzero(self._term_dones[key][env_ids]).item()
# reset episode dones
self._term_dones[key][env_ids] = False
# reset all the reward terms
for term_cfg in self._class_term_cfgs:
term_cfg.func.reset(env_ids=env_ids)
......
......@@ -3,8 +3,8 @@
#
# SPDX-License-Identifier: BSD-3-Clause
import gymnasium as gym
import os
from . import agents, ik_abs_env_cfg, ik_rel_env_cfg, joint_pos_env_cfg
......@@ -75,6 +75,7 @@ gym.register(
"env_cfg_entry_point": ik_rel_env_cfg.FrankaCubeLiftEnvCfg,
"rsl_rl_cfg_entry_point": agents.rsl_rl_cfg.LiftCubePPORunnerCfg,
"skrl_cfg_entry_point": f"{agents.__name__}:skrl_ppo_cfg.yaml",
"robomimic_bc_cfg_entry_point": os.path.join(agents.__path__[0], "robomimic/bc.json"),
},
disable_env_checker=True,
)
......
......@@ -40,7 +40,8 @@
"hdf5_cache_mode": "all",
"hdf5_use_swmr": true,
"hdf5_normalize_obs": false,
"hdf5_filter_key": null,
"hdf5_filter_key": "train",
"hdf5_validation_filter_key": "valid",
"seq_length": 1,
"dataset_keys": [
"actions",
......@@ -141,10 +142,10 @@
"modalities": {
"obs": {
"low_dim": [
"tool_dof_pos_scaled",
"tool_positions",
"object_relative_tool_positions",
"object_desired_positions"
"joint_pos",
"joint_vel",
"object_position",
"target_object_position"
],
"rgb": [],
"depth": [],
......
......@@ -9,3 +9,4 @@ from omni.isaac.orbit.envs.mdp import * # noqa: F401, F403
from .observations import * # noqa: F401, F403
from .rewards import * # noqa: F401, F403
from .terminations import * # noqa: F401, F403
# Copyright (c) 2022-2024, The ORBIT Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause
"""Common functions that can be used to activate certain terminations for the lift task.
The functions can be passed to the :class:`omni.isaac.orbit.managers.TerminationTermCfg` object to enable
the termination introduced by the function.
"""
from __future__ import annotations
import torch
from typing import TYPE_CHECKING
from omni.isaac.orbit.assets import RigidObject
from omni.isaac.orbit.managers import SceneEntityCfg
from omni.isaac.orbit.utils.math import combine_frame_transforms
if TYPE_CHECKING:
from omni.isaac.orbit.envs import RLTaskEnv
def object_reached_goal(
env: RLTaskEnv,
command_name: str = "object_pose",
threshold: float = 0.02,
robot_cfg: SceneEntityCfg = SceneEntityCfg("robot"),
object_cfg: SceneEntityCfg = SceneEntityCfg("object"),
) -> torch.Tensor:
"""Termination condition for the object reaching the goal position.
Args:
env: The environment.
command_name: The name of the command that is used to control the object.
threshold: The threshold for the object to reach the goal position. Defaults to 0.02.
robot_cfg: The robot configuration. Defaults to SceneEntityCfg("robot").
object_cfg: The object configuration. Defaults to SceneEntityCfg("object").
"""
# extract the used quantities (to enable type-hinting)
robot: RigidObject = env.scene[robot_cfg.name]
object: RigidObject = env.scene[object_cfg.name]
command = env.command_manager.get_command(command_name)
# compute the desired position in the world frame
des_pos_b = command[:, :3]
des_pos_w, _ = combine_frame_transforms(robot.data.root_state_w[:, :3], robot.data.root_state_w[:, 3:7], des_pos_b)
# distance of the end-effector to the object: (num_envs,)
distance = torch.norm(des_pos_w - object.data.root_pos_w[:, :3], dim=1)
# rewarded if the object is lifted above the threshold
return distance < threshold
......@@ -39,10 +39,12 @@ import os
import torch
from omni.isaac.orbit.devices import Se3Keyboard, Se3SpaceMouse
from omni.isaac.orbit.managers import TerminationTermCfg as DoneTerm
from omni.isaac.orbit.utils.io import dump_pickle, dump_yaml
import omni.isaac.contrib_tasks # noqa: F401
import omni.isaac.orbit_tasks # noqa: F401
from omni.isaac.orbit_tasks.manipulation.lift import mdp
from omni.isaac.orbit_tasks.utils.data_collector import RobomimicDataCollector
from omni.isaac.orbit_tasks.utils.parse_cfg import parse_env_cfg
......@@ -64,21 +66,30 @@ def pre_process_actions(delta_pose: torch.Tensor, gripper_command: bool) -> torc
def main():
"""Collect demonstrations from the environment using teleop interfaces."""
assert (
args_cli.task == "Isaac-Lift-Cube-Franka-IK-Rel-v0"
), "Only 'Isaac-Lift-Cube-Franka-IK-Rel-v0' is supported currently."
# parse configuration
env_cfg = parse_env_cfg(args_cli.task, use_gpu=not args_cli.cpu, num_envs=args_cli.num_envs)
# modify configuration
env_cfg.control.control_type = "inverse_kinematics"
env_cfg.control.inverse_kinematics.command_type = "pose_rel"
env_cfg.terminations.episode_timeout = False
env_cfg.terminations.is_success = True
env_cfg.observations.return_dict_obs_in_group = True
# modify configuration such that the environment runs indefinitely
# until goal is reached
env_cfg.terminations.time_out = None
# set the resampling time range to large number to avoid resampling
env_cfg.commands.object_pose.resampling_time_range = (1.0e9, 1.0e9)
# we want to have the terms in the observations returned as a dictionary
# rather than a concatenated tensor
env_cfg.observations.policy.concatenate_terms = False
# add termination condition for reaching the goal otherwise the environment won't reset
env_cfg.terminations.object_reached_goal = DoneTerm(func=mdp.object_reached_goal)
# create environment
env = gym.make(args_cli.task, cfg=env_cfg)
# create controller
if args_cli.device.lower() == "keyboard":
teleop_interface = Se3Keyboard(pos_sensitivity=0.4, rot_sensitivity=0.8)
teleop_interface = Se3Keyboard(pos_sensitivity=0.04, rot_sensitivity=0.08)
elif args_cli.device.lower() == "spacemouse":
teleop_interface = Se3SpaceMouse(pos_sensitivity=0.05, rot_sensitivity=0.005)
else:
......@@ -105,9 +116,8 @@ def main():
)
# reset environment
obs_dict = env.reset()
# robomimic only cares about policy observations
obs = obs_dict["policy"]
obs_dict, _ = env.reset()
# reset interfaces
teleop_interface.reset()
collector_interface.reset()
......@@ -126,7 +136,7 @@ def main():
# The observations need to be recollected.
# store signals before stepping
# -- obs
for key, value in obs.items():
for key, value in obs_dict["policy"].items():
collector_interface.add(f"obs/{key}", value)
# -- actions
collector_interface.add("actions", actions)
......@@ -137,23 +147,18 @@ def main():
if env.unwrapped.sim.is_stopped():
break
# robomimic only cares about policy observations
obs = obs_dict["policy"]
# store signals from the environment
# -- next_obs
for key, value in obs.items():
collector_interface.add(f"next_obs/{key}", value.cpu().numpy())
for key, value in obs_dict["policy"].items():
collector_interface.add(f"next_obs/{key}", value)
# -- rewards
collector_interface.add("rewards", rewards)
# -- dones
collector_interface.add("dones", dones)
# -- is-success label
try:
collector_interface.add("success", info["is_success"])
except KeyError:
raise RuntimeError(
"Only goal-conditioned environment supported. No attribute named"
f" 'is_success' found in {list(info.keys())}."
)
# -- is success label
collector_interface.add("success", env.termination_manager.get_term("object_reached_goal"))
# flush data from collector for successful environments
reset_env_ids = dones.nonzero(as_tuple=False).squeeze(-1)
collector_interface.flush(reset_env_ids)
......
......@@ -49,12 +49,9 @@ def main():
"""Run a trained policy from robomimic with Orbit environment."""
# parse configuration
env_cfg = parse_env_cfg(args_cli.task, use_gpu=not args_cli.cpu, num_envs=1, use_fabric=not args_cli.disable_fabric)
# modify configuration
env_cfg.control.control_type = "inverse_kinematics"
env_cfg.control.inverse_kinematics.command_type = "pose_rel"
env_cfg.terminations.episode_timeout = False
env_cfg.terminations.is_success = True
env_cfg.observations.return_dict_obs_in_group = True
# we want to have the terms in the observations returned as a dictionary
# rather than a concatenated tensor
env_cfg.observations.policy.concatenate_terms = False
# create environment
env = gym.make(args_cli.task, cfg=env_cfg)
......@@ -65,7 +62,7 @@ def main():
policy, _ = FileUtils.policy_from_checkpoint(ckpt_path=args_cli.checkpoint, device=device, verbose=True)
# reset environment
obs_dict = env.reset()
obs_dict, _ = env.reset()
# robomimic only cares about policy observations
obs = obs_dict["policy"]
# simulate environment
......@@ -74,9 +71,9 @@ def main():
with torch.inference_mode():
# compute actions
actions = policy(obs)
actions = torch.from_numpy(actions).to(device=device).view(1, env.action_space.shape[0])
actions = torch.from_numpy(actions).to(device=device).view(1, env.action_space.shape[1])
# apply actions
obs_dict, _, _, _ = env.step(actions)
obs_dict = env.step(actions)[0]
# robomimic only cares about policy observations
obs = obs_dict["policy"]
......
......@@ -75,6 +75,9 @@ from robomimic.algo import RolloutPolicy, algo_factory
from robomimic.config import config_factory
from robomimic.utils.log_utils import DataLogger, PrintLogger
# Needed so that environment is registered
import omni.isaac.orbit_tasks # noqa: F401
def train(config, device):
"""Train a model using the algorithm."""
......@@ -139,7 +142,7 @@ def train(config, device):
print("")
# setup for a new training run
data_logger = DataLogger(log_dir, log_tb=config.experiment.logging.log_tb)
data_logger = DataLogger(log_dir, config=config, log_tb=config.experiment.logging.log_tb)
model = algo_factory(
algo_name=config.algo_name,
config=config,
......@@ -347,7 +350,9 @@ def main(args):
if args.task is not None:
# obtain the configuration entry point
cfg_entry_point_key = f"robomimic_{args.algo}_cfg_entry_point"
cfg_entry_point_file = gym.spec(args.task)._kwargs.pop(cfg_entry_point_key)
print(f"Loading configuration for task: {args.task}")
cfg_entry_point_file = gym.spec(args.task).kwargs.pop(cfg_entry_point_key)
# check if entry point exists
if cfg_entry_point_file is None:
raise ValueError(
......@@ -411,11 +416,7 @@ if __name__ == "__main__":
args = parser.parse_args()
try:
# run training
main(args)
except Exception:
raise
finally:
# close sim app
simulation_app.close()
# run training
main(args)
# close sim app
simulation_app.close()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment