Commit e00ec819 authored by Mayank Mittal's avatar Mayank Mittal

makes _process_info_cfg safer

parent 6c317317
......@@ -394,16 +394,18 @@ class ArticulatedObject:
"""Post processing of configuration parameters."""
# default state
# -- root state
# note: we cast to tuple to avoid torch/numpy type mismatch.
default_root_state = (
self.cfg.init_state.pos
+ self.cfg.init_state.rot
+ self.cfg.init_state.lin_vel
+ self.cfg.init_state.ang_vel
tuple(self.cfg.init_state.pos)
+ tuple(self.cfg.init_state.rot)
+ tuple(self.cfg.init_state.lin_vel)
+ tuple(self.cfg.init_state.ang_vel)
)
self._default_root_states = torch.tensor(default_root_state, device=self.device).repeat(self.count, 1)
self._default_root_states = torch.tensor(default_root_state, dtype=torch.float, device=self.device)
self._default_root_states = self._default_root_states.repeat(self.count, 1)
# -- dof state
self._default_dof_pos = torch.zeros(self.count, self.num_dof, device=self.device)
self._default_dof_vel = torch.zeros(self.count, self.num_dof, device=self.device)
self._default_dof_pos = torch.zeros(self.count, self.num_dof, dtype=torch.float, device=self.device)
self._default_dof_vel = torch.zeros(self.count, self.num_dof, dtype=torch.float, device=self.device)
for index, dof_name in enumerate(self.articulations.dof_names):
# dof pos
for re_key, value in self.cfg.init_state.dof_pos.items():
......
......@@ -246,13 +246,15 @@ class RigidObject:
"""Post processing of configuration parameters."""
# default state
# -- root state
# note: we cast to tuple to avoid torch/numpy type mismatch.
default_root_state = (
self.cfg.init_state.pos
+ self.cfg.init_state.rot
+ self.cfg.init_state.lin_vel
+ self.cfg.init_state.ang_vel
tuple(self.cfg.init_state.pos)
+ tuple(self.cfg.init_state.rot)
+ tuple(self.cfg.init_state.lin_vel)
+ tuple(self.cfg.init_state.ang_vel)
)
self._default_root_states = torch.tensor(default_root_state, device=self.device).repeat(self.count, 1)
self._default_root_states = torch.tensor(default_root_state, dtype=torch.float, device=self.device)
self._default_root_states = self._default_root_states.repeat(self.count, 1)
def _create_buffers(self):
"""Create buffers for storing data."""
......
......@@ -423,16 +423,18 @@ class RobotBase:
"""Post processing of configuration parameters."""
# default state
# -- root state
# note: we cast to tuple to avoid torch/numpy type mismatch.
default_root_state = (
self.cfg.init_state.pos
+ self.cfg.init_state.rot
+ self.cfg.init_state.lin_vel
+ self.cfg.init_state.ang_vel
tuple(self.cfg.init_state.pos)
+ tuple(self.cfg.init_state.rot)
+ tuple(self.cfg.init_state.lin_vel)
+ tuple(self.cfg.init_state.ang_vel)
)
self._default_root_states = torch.tensor(default_root_state, device=self.device).repeat(self.count, 1)
self._default_root_states = torch.tensor(default_root_state, dtype=torch.float, device=self.device)
self._default_root_states = self._default_root_states.repeat(self.count, 1)
# -- dof state
self._default_dof_pos = torch.zeros(self.count, self.num_dof, device=self.device)
self._default_dof_vel = torch.zeros(self.count, self.num_dof, device=self.device)
self._default_dof_pos = torch.zeros(self.count, self.num_dof, dtype=torch.float, device=self.device)
self._default_dof_vel = torch.zeros(self.count, self.num_dof, dtype=torch.float, device=self.device)
for index, dof_name in enumerate(self.articulations.dof_names):
# dof pos
for re_key, value in self.cfg.init_state.dof_pos.items():
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment