Unverified Commit 8ddc4830 authored by Kelly Guo's avatar Kelly Guo Committed by GitHub

Adds position threshold check for state transitions (#1544)

# Description

Adds a position threshold check, resolving 3 TODO error comments, to
ensure the robot's end effector is within a specified distance from the
target position before transitioning between states in the pick and lift
state machine. Improves the precision of state transitions and helps
prevent premature actions during object manipulation. I.e, the threshold
ensures the robot is "close enough" to the target position before
proceeding, reducing the likelihood of failed grasps or incorrect
movements.

PR adapted from https://github.com/isaac-sim/IsaacLab/pull/1273/ by
@DorsaRoh.

## Type of change

<!-- As you go through the list, delete the ones that are not
applicable. -->

- Bug fix (non-breaking change which fixes an issue)
- New feature (non-breaking change which adds functionality)

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./isaaclab.sh --format`
- [ ] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [x] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [ ] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there

<!--
As you go through the checklist above, you can mark something as done by
putting an x character in it

For example,
- [x] I have done this task
- [ ] I have not done this task
-->

---------
Signed-off-by: 's avatarKelly Guo <kellyg@nvidia.com>
Co-authored-by: 's avatarDorsaRoh <dorsa.rohani@gmail.com>
parent f01c6f9e
[package]
# Note: Semantic Versioning is used: https://semver.org/
version = "0.27.27"
version = "0.27.28"
# Description
title = "Isaac Lab framework for Robot Learning"
......
Changelog
---------
0.27.28 (2024-12-14)
~~~~~~~~~~~~~~~~~~~~
Changed
^^^^^^^
* Added check for error below threshold in state machines to ensure the state has been reached.
0.27.27 (2024-12-13)
~~~~~~~~~~~~~~~~~~~~
......
......@@ -81,6 +81,11 @@ class PickSmWaitTime:
LIFT_OBJECT = wp.constant(1.0)
@wp.func
def distance_below_threshold(current_pos: wp.vec3, desired_pos: wp.vec3, threshold: float) -> bool:
return wp.length(current_pos - desired_pos) < threshold
@wp.kernel
def infer_state_machine(
dt: wp.array(dtype=float),
......@@ -92,6 +97,7 @@ def infer_state_machine(
des_ee_pose: wp.array(dtype=wp.transform),
gripper_state: wp.array(dtype=float),
offset: wp.array(dtype=wp.transform),
position_threshold: float,
):
# retrieve thread id
tid = wp.tid()
......@@ -109,21 +115,28 @@ def infer_state_machine(
elif state == PickSmState.APPROACH_ABOVE_OBJECT:
des_ee_pose[tid] = wp.transform_multiply(offset[tid], object_pose[tid])
gripper_state[tid] = GripperState.OPEN
# TODO: error between current and desired ee pose below threshold
# wait for a while
if sm_wait_time[tid] >= PickSmWaitTime.APPROACH_OBJECT:
# move to next state and reset wait time
sm_state[tid] = PickSmState.APPROACH_OBJECT
sm_wait_time[tid] = 0.0
if distance_below_threshold(
wp.transform_get_translation(ee_pose[tid]),
wp.transform_get_translation(des_ee_pose[tid]),
position_threshold,
):
# wait for a while
if sm_wait_time[tid] >= PickSmWaitTime.APPROACH_OBJECT:
# move to next state and reset wait time
sm_state[tid] = PickSmState.APPROACH_OBJECT
sm_wait_time[tid] = 0.0
elif state == PickSmState.APPROACH_OBJECT:
des_ee_pose[tid] = object_pose[tid]
gripper_state[tid] = GripperState.OPEN
# TODO: error between current and desired ee pose below threshold
# wait for a while
if sm_wait_time[tid] >= PickSmWaitTime.APPROACH_OBJECT:
# move to next state and reset wait time
sm_state[tid] = PickSmState.GRASP_OBJECT
sm_wait_time[tid] = 0.0
if distance_below_threshold(
wp.transform_get_translation(ee_pose[tid]),
wp.transform_get_translation(des_ee_pose[tid]),
position_threshold,
):
if sm_wait_time[tid] >= PickSmWaitTime.APPROACH_OBJECT:
# move to next state and reset wait time
sm_state[tid] = PickSmState.GRASP_OBJECT
sm_wait_time[tid] = 0.0
elif state == PickSmState.GRASP_OBJECT:
des_ee_pose[tid] = object_pose[tid]
gripper_state[tid] = GripperState.CLOSE
......@@ -135,12 +148,16 @@ def infer_state_machine(
elif state == PickSmState.LIFT_OBJECT:
des_ee_pose[tid] = des_object_pose[tid]
gripper_state[tid] = GripperState.CLOSE
# TODO: error between current and desired ee pose below threshold
# wait for a while
if sm_wait_time[tid] >= PickSmWaitTime.LIFT_OBJECT:
# move to next state and reset wait time
sm_state[tid] = PickSmState.LIFT_OBJECT
sm_wait_time[tid] = 0.0
if distance_below_threshold(
wp.transform_get_translation(ee_pose[tid]),
wp.transform_get_translation(des_ee_pose[tid]),
position_threshold,
):
# wait for a while
if sm_wait_time[tid] >= PickSmWaitTime.LIFT_OBJECT:
# move to next state and reset wait time
sm_state[tid] = PickSmState.LIFT_OBJECT
sm_wait_time[tid] = 0.0
# increment wait time
sm_wait_time[tid] = sm_wait_time[tid] + dt[tid]
......@@ -160,7 +177,7 @@ class PickAndLiftSm:
5. LIFT_OBJECT: The robot lifts the object to the desired pose. This is the final state.
"""
def __init__(self, dt: float, num_envs: int, device: torch.device | str = "cpu"):
def __init__(self, dt: float, num_envs: int, device: torch.device | str = "cpu", position_threshold=0.01):
"""Initialize the state machine.
Args:
......@@ -172,6 +189,7 @@ class PickAndLiftSm:
self.dt = float(dt)
self.num_envs = num_envs
self.device = device
self.position_threshold = position_threshold
# initialize state machine
self.sm_dt = torch.full((self.num_envs,), self.dt, device=self.device)
self.sm_state = torch.full((self.num_envs,), 0, dtype=torch.int32, device=self.device)
......@@ -201,7 +219,7 @@ class PickAndLiftSm:
self.sm_state[env_ids] = 0
self.sm_wait_time[env_ids] = 0.0
def compute(self, ee_pose: torch.Tensor, object_pose: torch.Tensor, des_object_pose: torch.Tensor):
def compute(self, ee_pose: torch.Tensor, object_pose: torch.Tensor, des_object_pose: torch.Tensor) -> torch.Tensor:
"""Compute the desired state of the robot's end-effector and the gripper."""
# convert all transformations from (w, x, y, z) to (x, y, z, w)
ee_pose = ee_pose[:, [0, 1, 2, 4, 5, 6, 3]]
......@@ -227,6 +245,7 @@ class PickAndLiftSm:
self.des_ee_pose_wp,
self.des_gripper_state_wp,
self.offset_wp,
self.position_threshold,
],
device=self.device,
)
......@@ -257,7 +276,9 @@ def main():
desired_orientation = torch.zeros((env.unwrapped.num_envs, 4), device=env.unwrapped.device)
desired_orientation[:, 1] = 1.0
# create state machine
pick_sm = PickAndLiftSm(env_cfg.sim.dt * env_cfg.decimation, env.unwrapped.num_envs, env.unwrapped.device)
pick_sm = PickAndLiftSm(
env_cfg.sim.dt * env_cfg.decimation, env.unwrapped.num_envs, env.unwrapped.device, position_threshold=0.01
)
while simulation_app.is_running():
# run everything in inference mode
......
......@@ -80,6 +80,11 @@ class PickSmWaitTime:
OPEN_GRIPPER = wp.constant(0.0)
@wp.func
def distance_below_threshold(current_pos: wp.vec3, desired_pos: wp.vec3, threshold: float) -> bool:
return wp.length(current_pos - desired_pos) < threshold
@wp.kernel
def infer_state_machine(
dt: wp.array(dtype=float),
......@@ -91,6 +96,7 @@ def infer_state_machine(
des_ee_pose: wp.array(dtype=wp.transform),
gripper_state: wp.array(dtype=float),
offset: wp.array(dtype=wp.transform),
position_threshold: float,
):
# retrieve thread id
tid = wp.tid()
......@@ -108,21 +114,29 @@ def infer_state_machine(
elif state == PickSmState.APPROACH_ABOVE_OBJECT:
des_ee_pose[tid] = wp.transform_multiply(offset[tid], object_pose[tid])
gripper_state[tid] = GripperState.OPEN
# TODO: error between current and desired ee pose below threshold
# wait for a while
if sm_wait_time[tid] >= PickSmWaitTime.APPROACH_OBJECT:
# move to next state and reset wait time
sm_state[tid] = PickSmState.APPROACH_OBJECT
sm_wait_time[tid] = 0.0
if distance_below_threshold(
wp.transform_get_translation(ee_pose[tid]),
wp.transform_get_translation(des_ee_pose[tid]),
position_threshold,
):
# wait for a while
if sm_wait_time[tid] >= PickSmWaitTime.APPROACH_OBJECT:
# move to next state and reset wait time
sm_state[tid] = PickSmState.APPROACH_OBJECT
sm_wait_time[tid] = 0.0
elif state == PickSmState.APPROACH_OBJECT:
des_ee_pose[tid] = object_pose[tid]
gripper_state[tid] = GripperState.OPEN
# TODO: error between current and desired ee pose below threshold
# wait for a while
if sm_wait_time[tid] >= PickSmWaitTime.APPROACH_OBJECT:
# move to next state and reset wait time
sm_state[tid] = PickSmState.GRASP_OBJECT
sm_wait_time[tid] = 0.0
if distance_below_threshold(
wp.transform_get_translation(ee_pose[tid]),
wp.transform_get_translation(des_ee_pose[tid]),
position_threshold,
):
# wait for a while
if sm_wait_time[tid] >= PickSmWaitTime.APPROACH_OBJECT:
# move to next state and reset wait time
sm_state[tid] = PickSmState.GRASP_OBJECT
sm_wait_time[tid] = 0.0
elif state == PickSmState.GRASP_OBJECT:
des_ee_pose[tid] = object_pose[tid]
gripper_state[tid] = GripperState.CLOSE
......@@ -134,12 +148,16 @@ def infer_state_machine(
elif state == PickSmState.LIFT_OBJECT:
des_ee_pose[tid] = des_object_pose[tid]
gripper_state[tid] = GripperState.CLOSE
# TODO: error between current and desired ee pose below threshold
# wait for a while
if sm_wait_time[tid] >= PickSmWaitTime.LIFT_OBJECT:
# move to next state and reset wait time
sm_state[tid] = PickSmState.OPEN_GRIPPER
sm_wait_time[tid] = 0.0
if distance_below_threshold(
wp.transform_get_translation(ee_pose[tid]),
wp.transform_get_translation(des_ee_pose[tid]),
position_threshold,
):
# wait for a while
if sm_wait_time[tid] >= PickSmWaitTime.LIFT_OBJECT:
# move to next state and reset wait time
sm_state[tid] = PickSmState.OPEN_GRIPPER
sm_wait_time[tid] = 0.0
elif state == PickSmState.OPEN_GRIPPER:
# des_ee_pose[tid] = object_pose[tid]
gripper_state[tid] = GripperState.OPEN
......@@ -167,7 +185,7 @@ class PickAndLiftSm:
5. LIFT_OBJECT: The robot lifts the object to the desired pose. This is the final state.
"""
def __init__(self, dt: float, num_envs: int, device: torch.device | str = "cpu"):
def __init__(self, dt: float, num_envs: int, device: torch.device | str = "cpu", position_threshold=0.01):
"""Initialize the state machine.
Args:
......@@ -179,6 +197,7 @@ class PickAndLiftSm:
self.dt = float(dt)
self.num_envs = num_envs
self.device = device
self.position_threshold = position_threshold
# initialize state machine
self.sm_dt = torch.full((self.num_envs,), self.dt, device=self.device)
self.sm_state = torch.full((self.num_envs,), 0, dtype=torch.int32, device=self.device)
......@@ -234,6 +253,7 @@ class PickAndLiftSm:
self.des_ee_pose_wp,
self.des_gripper_state_wp,
self.offset_wp,
self.position_threshold,
],
device=self.device,
)
......
......@@ -83,6 +83,11 @@ class OpenDrawerSmWaitTime:
RELEASE_HANDLE = wp.constant(0.2)
@wp.func
def distance_below_threshold(current_pos: wp.vec3, desired_pos: wp.vec3, threshold: float) -> bool:
return wp.length(current_pos - desired_pos) < threshold
@wp.kernel
def infer_state_machine(
dt: wp.array(dtype=float),
......@@ -95,6 +100,7 @@ def infer_state_machine(
handle_approach_offset: wp.array(dtype=wp.transform),
handle_grasp_offset: wp.array(dtype=wp.transform),
drawer_opening_rate: wp.array(dtype=wp.transform),
position_threshold: float,
):
# retrieve thread id
tid = wp.tid()
......@@ -112,21 +118,29 @@ def infer_state_machine(
elif state == OpenDrawerSmState.APPROACH_INFRONT_HANDLE:
des_ee_pose[tid] = wp.transform_multiply(handle_approach_offset[tid], handle_pose[tid])
gripper_state[tid] = GripperState.OPEN
# TODO: error between current and desired ee pose below threshold
# wait for a while
if sm_wait_time[tid] >= OpenDrawerSmWaitTime.APPROACH_INFRONT_HANDLE:
# move to next state and reset wait time
sm_state[tid] = OpenDrawerSmState.APPROACH_HANDLE
sm_wait_time[tid] = 0.0
if distance_below_threshold(
wp.transform_get_translation(ee_pose[tid]),
wp.transform_get_translation(des_ee_pose[tid]),
position_threshold,
):
# wait for a while
if sm_wait_time[tid] >= OpenDrawerSmWaitTime.APPROACH_INFRONT_HANDLE:
# move to next state and reset wait time
sm_state[tid] = OpenDrawerSmState.APPROACH_HANDLE
sm_wait_time[tid] = 0.0
elif state == OpenDrawerSmState.APPROACH_HANDLE:
des_ee_pose[tid] = handle_pose[tid]
gripper_state[tid] = GripperState.OPEN
# TODO: error between current and desired ee pose below threshold
# wait for a while
if sm_wait_time[tid] >= OpenDrawerSmWaitTime.APPROACH_HANDLE:
# move to next state and reset wait time
sm_state[tid] = OpenDrawerSmState.GRASP_HANDLE
sm_wait_time[tid] = 0.0
if distance_below_threshold(
wp.transform_get_translation(ee_pose[tid]),
wp.transform_get_translation(des_ee_pose[tid]),
position_threshold,
):
# wait for a while
if sm_wait_time[tid] >= OpenDrawerSmWaitTime.APPROACH_HANDLE:
# move to next state and reset wait time
sm_state[tid] = OpenDrawerSmState.GRASP_HANDLE
sm_wait_time[tid] = 0.0
elif state == OpenDrawerSmState.GRASP_HANDLE:
des_ee_pose[tid] = wp.transform_multiply(handle_grasp_offset[tid], handle_pose[tid])
gripper_state[tid] = GripperState.CLOSE
......@@ -170,7 +184,7 @@ class OpenDrawerSm:
5. RELEASE_HANDLE: The robot releases the handle of the drawer. This is the final state.
"""
def __init__(self, dt: float, num_envs: int, device: torch.device | str = "cpu"):
def __init__(self, dt: float, num_envs: int, device: torch.device | str = "cpu", position_threshold=0.01):
"""Initialize the state machine.
Args:
......@@ -182,6 +196,7 @@ class OpenDrawerSm:
self.dt = float(dt)
self.num_envs = num_envs
self.device = device
self.position_threshold = position_threshold
# initialize state machine
self.sm_dt = torch.full((self.num_envs,), self.dt, device=self.device)
self.sm_state = torch.full((self.num_envs,), 0, dtype=torch.int32, device=self.device)
......@@ -248,6 +263,7 @@ class OpenDrawerSm:
self.handle_approach_offset_wp,
self.handle_grasp_offset_wp,
self.drawer_opening_rate_wp,
self.position_threshold,
],
device=self.device,
)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment