Unverified Commit 629136c1 authored by Chenyu Yang's avatar Chenyu Yang Committed by GitHub

Updates robomimic config for Franka-Lift environment (#69)

* Enlarges the hard-coded success range from 0.002 to 0.02
* Updates observations in BC configs from `object_positions` to `object_relative_tool_positions`
parent 5938748d
...@@ -143,7 +143,7 @@ ...@@ -143,7 +143,7 @@
"low_dim": [ "low_dim": [
"tool_dof_pos_scaled", "tool_dof_pos_scaled",
"tool_positions", "tool_positions",
"object_positions", "object_relative_tool_positions",
"object_desired_positions" "object_desired_positions"
], ],
"rgb": [], "rgb": [],
......
...@@ -179,7 +179,7 @@ ...@@ -179,7 +179,7 @@
"low_dim": [ "low_dim": [
"tool_dof_pos_scaled", "tool_dof_pos_scaled",
"tool_positions", "tool_positions",
"object_positions", "object_relative_tool_positions",
"object_desired_positions" "object_desired_positions"
], ],
"rgb": [], "rgb": [],
......
...@@ -183,7 +183,7 @@ class LiftEnv(IsaacEnv): ...@@ -183,7 +183,7 @@ class LiftEnv(IsaacEnv):
self.extras["time_outs"] = self.episode_length_buf >= self.max_episode_length self.extras["time_outs"] = self.episode_length_buf >= self.max_episode_length
# -- add information to extra if task completed # -- add information to extra if task completed
object_position_error = torch.norm(self.object.data.root_pos_w - self.object_des_pose_w[:, 0:3], dim=1) object_position_error = torch.norm(self.object.data.root_pos_w - self.object_des_pose_w[:, 0:3], dim=1)
self.extras["is_success"] = torch.where(object_position_error < 0.002, 1, self.reset_buf) self.extras["is_success"] = torch.where(object_position_error < 0.02, 1, self.reset_buf)
# -- update USD visualization # -- update USD visualization
if self.cfg.viewer.debug_vis and self.enable_render: if self.cfg.viewer.debug_vis and self.enable_render:
self._debug_vis() self._debug_vis()
...@@ -287,7 +287,7 @@ class LiftEnv(IsaacEnv): ...@@ -287,7 +287,7 @@ class LiftEnv(IsaacEnv):
# -- when task is successful # -- when task is successful
if self.cfg.terminations.is_success: if self.cfg.terminations.is_success:
object_position_error = torch.norm(self.object.data.root_pos_w - self.object_des_pose_w[:, 0:3], dim=1) object_position_error = torch.norm(self.object.data.root_pos_w - self.object_des_pose_w[:, 0:3], dim=1)
self.reset_buf = torch.where(object_position_error < 0.002, 1, self.reset_buf) self.reset_buf = torch.where(object_position_error < 0.02, 1, self.reset_buf)
# -- object fell off the table (table at height: 0.0 m) # -- object fell off the table (table at height: 0.0 m)
if self.cfg.terminations.object_falling: if self.cfg.terminations.object_falling:
self.reset_buf = torch.where(object_pos[:, 2] < -0.05, 1, self.reset_buf) self.reset_buf = torch.where(object_pos[:, 2] < -0.05, 1, self.reset_buf)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment