egogym 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. baselines/pi_policy.py +110 -0
  2. baselines/rum/__init__.py +1 -0
  3. baselines/rum/loss_fns/__init__.py +37 -0
  4. baselines/rum/loss_fns/abstract_loss_fn.py +13 -0
  5. baselines/rum/loss_fns/diffusion_policy_loss_fn.py +114 -0
  6. baselines/rum/loss_fns/rvq_loss_fn.py +104 -0
  7. baselines/rum/loss_fns/vqbet_loss_fn.py +202 -0
  8. baselines/rum/models/__init__.py +1 -0
  9. baselines/rum/models/bet/__init__.py +3 -0
  10. baselines/rum/models/bet/bet.py +347 -0
  11. baselines/rum/models/bet/gpt.py +277 -0
  12. baselines/rum/models/bet/tokenized_bet.py +454 -0
  13. baselines/rum/models/bet/utils.py +124 -0
  14. baselines/rum/models/bet/vqbet.py +410 -0
  15. baselines/rum/models/bet/vqvae/__init__.py +3 -0
  16. baselines/rum/models/bet/vqvae/residual_vq.py +346 -0
  17. baselines/rum/models/bet/vqvae/vector_quantize_pytorch.py +1194 -0
  18. baselines/rum/models/bet/vqvae/vqvae.py +313 -0
  19. baselines/rum/models/bet/vqvae/vqvae_utils.py +30 -0
  20. baselines/rum/models/custom.py +33 -0
  21. baselines/rum/models/encoders/__init__.py +0 -0
  22. baselines/rum/models/encoders/abstract_base_encoder.py +70 -0
  23. baselines/rum/models/encoders/identity.py +45 -0
  24. baselines/rum/models/encoders/timm_encoders.py +82 -0
  25. baselines/rum/models/policies/diffusion_policy.py +881 -0
  26. baselines/rum/models/policies/open_loop.py +122 -0
  27. baselines/rum/models/policies/simple_open_loop.py +108 -0
  28. baselines/rum/molmo/server.py +144 -0
  29. baselines/rum/policy.py +293 -0
  30. baselines/rum/utils/__init__.py +212 -0
  31. baselines/rum/utils/action_transforms.py +22 -0
  32. baselines/rum/utils/decord_transforms.py +135 -0
  33. baselines/rum/utils/rpc.py +249 -0
  34. baselines/rum/utils/schedulers.py +71 -0
  35. baselines/rum/utils/trajectory_vis.py +128 -0
  36. baselines/rum/utils/zmq_utils.py +281 -0
  37. baselines/rum_policy.py +108 -0
  38. egogym/__init__.py +8 -0
  39. egogym/assets/constants.py +1804 -0
  40. egogym/components/__init__.py +1 -0
  41. egogym/components/object.py +94 -0
  42. egogym/egogym.py +106 -0
  43. egogym/embodiments/__init__.py +10 -0
  44. egogym/embodiments/arms/__init__.py +4 -0
  45. egogym/embodiments/arms/arm.py +65 -0
  46. egogym/embodiments/arms/droid.py +49 -0
  47. egogym/embodiments/grippers/__init__.py +4 -0
  48. egogym/embodiments/grippers/floating_gripper.py +58 -0
  49. egogym/embodiments/grippers/rum.py +6 -0
  50. egogym/embodiments/robot.py +95 -0
  51. egogym/evaluate.py +216 -0
  52. egogym/managers/__init__.py +2 -0
  53. egogym/managers/objects_managers.py +30 -0
  54. egogym/managers/textures_manager.py +21 -0
  55. egogym/misc/molmo_client.py +49 -0
  56. egogym/misc/molmo_server.py +197 -0
  57. egogym/policies/__init__.py +1 -0
  58. egogym/policies/base_policy.py +13 -0
  59. egogym/scripts/analayze.py +834 -0
  60. egogym/scripts/plot.py +87 -0
  61. egogym/scripts/plot_correlation.py +392 -0
  62. egogym/scripts/plot_correlation_hardcoded.py +338 -0
  63. egogym/scripts/plot_failure.py +248 -0
  64. egogym/scripts/plot_failure_hardcoded.py +195 -0
  65. egogym/scripts/plot_failure_vlm.py +257 -0
  66. egogym/scripts/plot_failure_vlm_hardcoded.py +177 -0
  67. egogym/scripts/plot_line.py +303 -0
  68. egogym/scripts/plot_line_hardcoded.py +285 -0
  69. egogym/scripts/plot_pi0_bars.py +169 -0
  70. egogym/tasks/close.py +84 -0
  71. egogym/tasks/open.py +85 -0
  72. egogym/tasks/pick.py +121 -0
  73. egogym/utils.py +969 -0
  74. egogym/wrappers/__init__.py +20 -0
  75. egogym/wrappers/episode_monitor.py +282 -0
  76. egogym/wrappers/unprivileged_chatgpt.py +163 -0
  77. egogym/wrappers/unprivileged_gemini.py +157 -0
  78. egogym/wrappers/unprivileged_molmo.py +88 -0
  79. egogym/wrappers/unprivileged_moondream.py +121 -0
  80. egogym-0.1.0.dist-info/METADATA +52 -0
  81. egogym-0.1.0.dist-info/RECORD +83 -0
  82. egogym-0.1.0.dist-info/WHEEL +5 -0
  83. egogym-0.1.0.dist-info/top_level.txt +2 -0
@@ -0,0 +1 @@
1
+ from .object import Object, ArticulableObject
@@ -0,0 +1,94 @@
1
+ import mujoco
2
+ from scipy.spatial.transform import Rotation as R
3
+ import numpy as np
4
+
5
+ from egogym.utils import get_pose, include_in_scene, add_xml_to_scene
6
+
7
+
8
class Object:
    """A scene object whose XML comes either from a file or from a factory.

    The object's main body is named ``"<name>_object"`` and its bottom site
    ``"<name>_bottom_site"``; the pose helpers below look those names up.
    """

    def __init__(self, name, path=None, func=None, np_random=None):
        """
        Args:
            name: Object name. The last ``_``-separated suffix is dropped and the
                remaining underscores become spaces to form ``english_name``
                (e.g. ``"red_mug_3"`` -> ``"red mug"``).
            path: XML file to include in the scene when no ``func`` is given.
            func: Optional factory, called as ``func(np_random=np_random)``,
                returning the object's XML string.
            np_random: Random generator forwarded to ``func`` and kept on the
                instance for later use.
        """
        self.name = name
        self.english_name = self.name.rsplit("_", 1)[0].replace("_", " ")
        self.path = path
        self.last_set_pose = None
        self.object_xml = None
        self.handle_type = None
        self.object_geoms_ids = None
        self.np_random = np_random
        if func is not None:
            self.object_xml = func(np_random=np_random)
        self.body_name = f"{self.name}_object"

    @staticmethod
    def _to_mujoco_quat(rotation):
        """Convert a 3x3 rotation matrix to MuJoCo's scalar-first (w, x, y, z) quaternion."""
        # SciPy's as_quat() returns scalar-last (x, y, z, w); roll w to the front.
        return np.roll(R.from_matrix(rotation).as_quat(), 1)

    def get_pose(self, data):
        """Return the pose of the object's main body via ``egogym.utils.get_pose``."""
        return get_pose(data, self.body_name, "body")

    def get_geom_ids(self, model):
        """Return the ids of every geom whose body name contains ``"<name>_object"``.

        The list is computed once and cached on the instance.
        NOTE(review): the cache is not invalidated if a different ``model`` is
        passed later (e.g. after the scene is rebuilt).
        """
        if self.object_geoms_ids is not None:
            return self.object_geoms_ids
        object_geoms_ids = []
        for geom_id in range(model.ngeom):
            body_name = mujoco.mj_id2name(model, mujoco.mjtObj.mjOBJ_BODY, model.geom_bodyid[geom_id])
            if body_name is not None and self.body_name in body_name:
                object_geoms_ids.append(geom_id)
        self.object_geoms_ids = object_geoms_ids
        return object_geoms_ids

    def get_bottom_pose(self, data):
        """Return the pose of the ``"<name>_bottom_site"`` site."""
        return get_pose(data, f"{self.name}_bottom_site", "site")

    def get_center_to_bottom_z_distance(self, data):
        """Return the z offset between the body origin and the bottom site."""
        # Poses are indexed as 4x4 transforms; [2, 3] is the z translation.
        bottom_z = self.get_bottom_pose(data)[2, 3]
        center_z = self.get_pose(data)[2, 3]
        return center_z - bottom_z

    def set_pose(self, model, data, pose):
        """Write a 4x4 pose into ``model.body_pos``/``model.body_quat`` for this body.

        MuJoCo stores these relative to the parent body. ``data`` is unused. A copy
        of ``pose`` is remembered in ``last_set_pose``.
        """
        body_id = mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_BODY, self.body_name)
        model.body_pos[body_id] = pose[0:3, 3]
        # body_quat is (w, x, y, z); writing SciPy's (x, y, z, w) directly would
        # produce a different rotation.
        model.body_quat[body_id] = self._to_mujoco_quat(pose[0:3, 0:3])
        self.last_set_pose = pose.copy()

    def add_to_scene_xml(self, scene_xml):
        """Return ``scene_xml`` with this object's XML added.

        A generated ``object_xml`` takes precedence over ``path``.

        Raises:
            ValueError: If neither a factory nor a path was provided.
        """
        if self.object_xml is not None:
            return add_xml_to_scene(scene_xml, self.object_xml)
        elif self.path is not None:
            return include_in_scene(scene_xml, self.path)
        else:
            raise ValueError("Either func or path must be provided")
60
+
61
+
62
class ArticulableObject(Object):
    """An ``Object`` with a single articulated joint named ``"task_joint"``.

    The joint's range and its position at first use are captured lazily and
    serve as the reference for measuring how far it has been opened.
    """

    def __init__(self, name, path=None, func=None, np_random=None):
        super().__init__(name, path, func, np_random)
        self.joint_range = None
        self.initial_joint = None

    def _ensure_joint_reference(self, model, data):
        """Cache ``task_joint``'s range and its current position on first call."""
        if self.joint_range is None:
            joint_id = mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_JOINT, "task_joint")
            self.joint_range = model.jnt_range[joint_id]
            self.initial_joint = self.get_jotinpos(data)

    def get_handle_pose(self, data):
        """Return the pose of the geom named ``"handle"``."""
        return get_pose(data, "handle", "geom")

    def get_jotinpos(self, data):
        """Return the first qpos entry of ``"task_joint"``."""
        return data.joint("task_joint").qpos[0]

    def get_joint_pos(self, data):
        """Correctly spelled alias of ``get_jotinpos`` (kept for existing callers)."""
        return self.get_jotinpos(data)

    def get_perecenttage_opened(self, model, data):
        """Return the joint's travel from its reference position as a fraction of its range.

        The reference position is captured on the first call to this method or
        to ``open``. The value is not clipped. A small epsilon guards against a
        zero-width range.
        """
        self._ensure_joint_reference(model, data)
        travel = abs(self.get_jotinpos(data) - self.initial_joint)
        span = abs(self.joint_range[1] - self.joint_range[0])
        return travel / (span + 1e-8)

    def get_percentage_opened(self, model, data):
        """Correctly spelled alias of ``get_perecenttage_opened``."""
        return self.get_perecenttage_opened(model, data)

    def open(self, model, data):
        """Move ``task_joint`` to a random 50-90% of one range limit, then step 100 times.

        The upper limit is used when the lower limit is 0; otherwise the lower
        limit is used. Requires ``np_random`` to have been provided.
        """
        self._ensure_joint_reference(model, data)
        limit = self.joint_range[1] if self.joint_range[0] == 0 else self.joint_range[0]
        data.joint("task_joint").qpos = limit * self.np_random.uniform(0.5, 0.9)
        mujoco.mj_step(model, data, nstep=100)
94
+
egogym/egogym.py ADDED
@@ -0,0 +1,106 @@
1
+ from gymnasium.utils import seeding
2
+ import gymnasium as gym
3
+ from gymnasium.spaces import Box, Dict
4
+ import mujoco
5
+ import numpy as np
6
+ import cv2
7
+
8
+ from egogym.embodiments import ROBOT_REGISTRY
9
+
10
class Egogym(gym.Env):
    """Base MuJoCo gymnasium environment coupling a scene with a robot embodiment.

    Subclasses build the scene (``setup_mujoco``) and provide ``get_obs`` and
    ``compute_reward``; ``step`` calls both, but this class does not define them.
    """

    metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 30}

    # A MuJoCo free joint occupies 7 qpos entries: position (3) + quaternion (4).
    _FREE_JOINT_NQ = 7

    def __init__(self, robot="rum", action_space="delta", render_mode=None, render_size=(960, 720), seed=None):
        """
        Args:
            robot: Key into ``ROBOT_REGISTRY`` selecting the embodiment.
            action_space: Action mode forwarded to the robot; the robots in this
                package treat ``"delta"`` as relative and anything else as absolute.
            render_mode: ``"human"``, ``"rgb_array"`` or ``None``.
            render_size: ``(width, height)`` of each camera image.
            seed: Seed for the environment RNG, which is shared with the robot.

        Raises:
            ValueError: If ``robot`` is not a registered embodiment.
        """
        super().__init__()
        self.model = None
        self.data = None
        self.step_idx = 0
        self.render_mode = render_mode
        self.np_random, _ = seeding.np_random(seed)

        self.render_width = render_size[0]
        self.render_height = render_size[1]

        self.robot_type = robot
        self.action_space_type = action_space

        if self.robot_type not in ROBOT_REGISTRY:
            raise ValueError(f"Unknown robot type '{self.robot_type}'. Available robots: {list(ROBOT_REGISTRY.keys())}")

        robot_class = ROBOT_REGISTRY[self.robot_type]
        self.robot = robot_class(action_space=self.action_space_type, np_random=self.np_random)

        obs_spaces = {
            "rgb_ego": Box(low=0, high=255, shape=(self.render_height, self.render_width, 3), dtype=np.uint8),
            "rgb_exo": Box(low=0, high=255, shape=(self.render_height, self.render_width, 3), dtype=np.uint8),
            "tcp_pose": Box(low=-np.inf, high=np.inf, shape=(16,), dtype=np.float32),
            "camera_pose": Box(low=-np.inf, high=np.inf, shape=(16,), dtype=np.float32),
            "grasp": Box(low=-np.inf, high=np.inf, shape=(1,), dtype=np.float32),
        }
        obs_spaces.update(self.robot.additional_obs_spaces())
        self.observation_space = Dict(obs_spaces)
        self.action_space = Box(low=-np.inf, high=np.inf, shape=self.robot.action_space_shape(), dtype=np.float32)
        self._mujoco_step_counter = 0
        self._render_freq = 0
        self._render_callback = None
        self._should_render = self.render_mode == 'human'

    def setup_mujoco(self, scene_xml: str):
        """(Re)build the MuJoCo model/data from ``scene_xml`` and bind them to the robot."""
        if hasattr(self, "model") and self.model is not None:
            del self.model
            del self.data
        self.model = mujoco.MjModel.from_xml_string(scene_xml)
        self.data = mujoco.MjData(self.model)
        self.robot.setup_mj(self.model, self.data)
        self.robot.setup_renderers(self.render_width, self.render_height)

    def env_step(self, nsteps=1):
        """Advance physics ``nsteps`` times, firing the render callback every ``_render_freq`` steps."""
        for _ in range(nsteps):
            mujoco.mj_step(self.model, self.data)
            self._mujoco_step_counter += 1
            if self._render_freq > 0 and self._render_callback is not None and self._mujoco_step_counter >= self._render_freq:
                self._render_callback()
                self._mujoco_step_counter = 0

    def get_base_obs(self):
        """Return the observation entries shared by all tasks plus the robot's extras."""
        obs = {
            "rgb_ego": self.robot.get_camera_view(self.robot.camera_names[0]),
            "rgb_exo": self.robot.get_camera_view(self.robot.camera_names[-1]),
            "tcp_pose": self.robot.get_tcp_pose().reshape(16).astype(np.float32),
            "camera_pose": self.robot.get_camera_pose().reshape(16).astype(np.float32),
            "grasp": np.array([self.robot.get_grasp()], dtype=np.float32),
        }
        obs.update(self.robot.get_additional_obs())
        return obs

    def render(self):
        """Render the ego and exo views side by side.

        Returns:
            The RGB frame when ``render_mode == "rgb_array"``; otherwise ``None``
            after showing the frame in an OpenCV window.
        """
        # Same camera choice as the "rgb_ego"/"rgb_exo" observations.
        rgb_ego = self.robot.get_camera_view(self.robot.camera_names[0])
        rgb_exo = self.robot.get_camera_view(self.robot.camera_names[-1])
        rgb_stack = np.concatenate((rgb_ego, rgb_exo), axis=1)
        if self.render_mode == "rgb_array":
            return rgb_stack
        # OpenCV expects BGR channel order.
        cv2.imshow("Egogym", cv2.cvtColor(rgb_stack, cv2.COLOR_RGB2BGR))
        cv2.waitKey(1)
        return None

    def close(self):
        """Close the OpenCV window opened in human render mode."""
        if self.render_mode == "human":
            cv2.destroyAllWindows()

    def step(self, action):
        """Apply ``action`` and return ``(obs, reward, terminated, truncated, info)``.

        A ``None`` action or one containing NaN is skipped: the current
        observation is returned with zero reward and no physics step.
        """
        if action is None or np.isnan(action).any():
            return self.get_obs(), 0.0, False, False, {}
        self.robot.step(action, step_fn=self.env_step)
        obs = self.get_obs()
        reward = self.compute_reward()
        terminated = False
        truncated = False
        return obs, reward, terminated, truncated, {}

    def enable_sleeping_islands(self):
        """Set MuJoCo's sleep enable flag when this MuJoCo build defines it; otherwise do nothing."""
        try:
            self.model.opt.enableflags |= mujoco.mjtEnableBit.mjENBL_SLEEP
        except AttributeError:
            pass

    def settle_env(self, nsteps=200):
        """Zero every qpos entry except the robot's free joint.

        Only applies when the robot exposes a non-None ``free_joint_id`` (its
        free joint's qpos address).
        NOTE(review): ``nsteps`` is currently unused; no physics steps are taken here.
        """
        # TODO: check if its floating_gripper instead
        free_qposadr = getattr(self.robot, "free_joint_id", None)
        if free_qposadr is not None:
            idx = np.arange(len(self.data.qpos))
            # Keep all 7 free-joint entries; using 6 would zero a quaternion component.
            keep = (idx >= free_qposadr) & (idx < free_qposadr + self._FREE_JOINT_NQ)
            self.data.qpos[:] = np.where(keep, self.data.qpos, 0.0)
@@ -0,0 +1,10 @@
1
+ from .robot import Robot
2
+ from .grippers import FloatingGripper, RUMGripper
3
+ from .arms import Arm, DroidArm
4
+
5
# Maps the ``robot`` argument of ``Egogym(robot=...)`` to the embodiment class it instantiates.
ROBOT_REGISTRY = dict(
    rum=RUMGripper,
    droid=DroidArm,
)

# Public API of egogym.embodiments.
__all__ = [
    "Robot",
    "FloatingGripper",
    "RUMGripper",
    "Arm",
    "DroidArm",
    "ROBOT_REGISTRY",
]
@@ -0,0 +1,4 @@
1
+ from .arm import Arm
2
+ from .droid import DroidArm
3
+
4
# Public API of this subpackage: the joint-space Arm base class and the DROID arm.
__all__ = ["Arm", "DroidArm"]
@@ -0,0 +1,65 @@
1
+ from egogym.embodiments.robot import Robot, requires_mj
2
+ from gymnasium.spaces import Box
3
+ import numpy as np
4
+
5
+
6
class Arm(Robot):
    """Joint-position-controlled arm with an optional gripper actuator.

    Subclasses fill ``joint_names``; each name is looked up both as a joint and
    as an actuator. Subclasses may also set ``gripper_actuator_name``. Actions
    are ``len(joint_names)`` joint values followed by an optional grip command
    in [0, 1].
    """

    def __init__(self, control_steps=2000, grasping_steps=1000, action_space="delta", np_random=None):
        super().__init__(control_steps=control_steps, grasping_steps=grasping_steps, action_space=action_space, np_random=np_random)
        self.joint_names = []
        self.qpos_ids = None
        self.act_ids = None
        self.gripper_act_id = None

    def additional_obs_spaces(self):
        """Return the extra ``joint_positions`` observation space."""
        num_joints = len(self.joint_names)
        return {
            "joint_positions": Box(low=-np.inf, high=np.inf, shape=(num_joints,), dtype=np.float32)
        }

    def action_space_shape(self):
        """One entry per joint plus one grip command."""
        return (len(self.joint_names) + 1,)

    @requires_mj
    def get_additional_obs(self):
        """Return the current joint positions as float32."""
        return {
            "joint_positions": self.get_joints().astype(np.float32)
        }

    def setup_mj(self, model, data):
        """Bind model/data and resolve joint qpos addresses and actuator ids."""
        super().setup_mj(model, data)
        if self.joint_names:
            self.qpos_ids = np.array([self.model.joint(name).qposadr[0] for name in self.joint_names])
            self.act_ids = np.array([self.model.actuator(name).id for name in self.joint_names])
        if hasattr(self, 'gripper_actuator_name'):
            self.gripper_act_id = self.model.actuator(self.gripper_actuator_name).id

    @requires_mj
    def get_joints(self):
        """Return a copy of the arm's joint positions."""
        return self.data.qpos[self.qpos_ids].copy()

    @requires_mj
    def set_joints(self, joint_positions):
        """Write joint position targets to the arm's actuators."""
        self.data.ctrl[self.act_ids] = joint_positions

    @requires_mj
    def step(self, action, step_fn):
        """Apply one action and advance ``control_steps`` physics steps via ``step_fn``.

        In ``"delta"`` mode the joint part is added to the current joint
        positions; otherwise it is used as absolute targets. A 2-D action uses
        its first row.
        """
        action = np.asarray(action)
        if action.ndim == 2:
            action = action[0]

        num_joints = len(self.joint_names)
        desired_joints = action[:num_joints]
        grip_action = action[num_joints] if len(action) > num_joints else 0.0

        if self.action_space == "delta":
            target_joints = self.get_joints() + desired_joints
        else:
            target_joints = desired_joints

        self.set_joints(target_joints)
        # Without a gripper actuator, ctrl[None] would broadcast over *all*
        # actuators and overwrite the joint targets just written.
        if self.gripper_act_id is not None:
            self.data.ctrl[self.gripper_act_id] = 255.0 * np.clip(grip_action, 0.0, 1.0)

        step_fn(self.control_steps)
@@ -0,0 +1,49 @@
1
+ from egogym.embodiments.arms.arm import Arm
2
+ from egogym.embodiments.robot import requires_mj
3
+ from egogym.utils import get_pose
4
+ import numpy as np
5
+
6
+
7
class DroidArm(Arm):
    """DROID-style 7-joint arm whose gripper is driven by ``fingers_actuator``."""

    # Divisor that maps the ``right_driver_joint`` position to the [0, 1] grasp
    # value; presumably the joint's fully-closed position — TODO confirm in the XML.
    _GRASP_JOINT_SCALE = 0.8102

    def __init__(self, control_steps=600, grasping_steps=100, action_space="delta", np_random=None):
        super().__init__(control_steps=control_steps, grasping_steps=grasping_steps, action_space=action_space, np_random=np_random)
        self.name = "droid"
        self.initial_q = np.array([0, -1/5 * np.pi, 0, -4/5 * np.pi, 0, 3/5 * np.pi, 0.0])
        self.joint_names = ["joint1", "joint2", "joint3", "joint4", "joint5", "joint6", "joint7"]
        self.gripper_actuator_name = "fingers_actuator"
        # set_grasp compares against the previous command; this was never
        # initialised for arms, so the first call raised AttributeError.
        self.last_grasp = 0.0

    @requires_mj
    def prepare(self, desired_pose, step_fn=None):
        """Place the arm so its TCP starts at ``desired_pose``.

        Steps:
          1. Shift ``desired_pose``'s translation by (0, 0.6, 0.075). This
             modifies the caller's array in place.
          2. Add uniform noise in [-0.03, 0.03] when ``add_noise_init`` is set.
          3. Command ``initial_q`` and run 500 steps.
          4. Translate the ``link0`` base body so the TCP lands on the target,
             then run 100 more steps.

        ``step_fn`` must be provided despite its ``None`` default.
        """
        desired_pose[:3, 3] += np.array([0.0, 0.6, 0.075])
        if self.add_noise_init:
            position_noise = self.np_random.uniform(-0.03, 0.03, size=3)
            desired_pose[:3, 3] += position_noise

        self.data.ctrl[self.act_ids] = self.initial_q
        step_fn(500)
        gripper_position = self.get_tcp_pose()[:3, 3]
        link_position = get_pose(self.data, "link0", "body")[:3, 3]
        delta_position = gripper_position - link_position
        desired_gripper_position = desired_pose[:3, 3]
        desired_link_position = desired_gripper_position - delta_position
        self.model.body_pos[self.model.body("link0").id] = desired_link_position
        step_fn(100)

    @requires_mj
    def get_grasp(self):
        """Return the normalized gripper closure as a ``np.float32`` in [0, 1].

        Values below 0.01 are snapped to 0.
        """
        qpos_adr = self.model.joint("right_driver_joint").qposadr[0]
        val = self.data.qpos[qpos_adr] / self._GRASP_JOINT_SCALE
        if val < 0.01:
            val = 0.0
        return np.float32(np.clip(val, 0.0, 1.0))

    @requires_mj
    def set_grasp(self, grip_action):
        """Command the gripper to ``255 * grip_action``.

        Returns:
            1 if the command differs from the previous one, else 0.
        """
        # np.asarray(...).item() accepts Python floats, numpy scalars and 1-element arrays.
        desired_grasp = np.asarray(255.0 * grip_action).item()
        self.data.ctrl[self.model.actuator(self.gripper_actuator_name).id] = desired_grasp
        if desired_grasp != self.last_grasp:
            self.last_grasp = desired_grasp
            return 1
        return 0
@@ -0,0 +1,4 @@
1
+ from .floating_gripper import FloatingGripper
2
+ from .rum import RUMGripper
3
+
4
# Public API of this subpackage: the mocap-driven floating gripper and its RUM preset.
__all__ = ["FloatingGripper", "RUMGripper"]
@@ -0,0 +1,58 @@
1
+ from scipy.spatial.transform import Rotation as R
2
+ import numpy as np
3
+
4
+ from egogym.utils import get_pose
5
+ from egogym.embodiments.robot import Robot, requires_mj
6
+
7
+
8
class FloatingGripper(Robot):
    """Free-floating gripper driven through mocap body 0.

    Actions are 17 values: a flattened 4x4 pose (16 entries) followed by a
    grip command.
    """

    def __init__(self, control_steps=50, grasping_steps=1, action_space="delta", np_random=None):
        super().__init__(control_steps=control_steps, grasping_steps=grasping_steps, action_space=action_space, np_random=np_random)
        self.last_grasp = 0.0
        self.last_teleop_pose = None
        # qpos address of "free_joint"; resolved in setup_mj.
        self.free_joint_id = None

    def action_space_shape(self):
        """16 pose entries plus one grip command."""
        return (17,)

    def setup_mj(self, model, data):
        """Bind model/data, reset the grasp memory and record the free joint's qpos address."""
        super().setup_mj(model, data)
        self.last_grasp = 0.0
        # Despite the name, this stores the qpos *address* (qposadr), not the joint id.
        self.free_joint_id = model.joint("free_joint").qposadr[0]

    @staticmethod
    def _mocap_quat(rotation):
        """Return MuJoCo's scalar-first (w, x, y, z) quaternion for a 3x3 rotation."""
        # SciPy's as_quat() is scalar-last (x, y, z, w).
        return np.roll(R.from_matrix(rotation).as_quat(), 1)

    @requires_mj
    def prepare(self, desired_pose, step_fn=None):
        """Move the gripper to ``desired_pose`` and run ``control_steps`` steps.

        When ``add_noise_init`` is set, the pose is perturbed in place. The
        position gets uniform noise of ±0.03. The xyz Euler angles get ±0.1,
        ±0.05 and ±0.1 respectively. ``step_fn`` must be provided despite its
        ``None`` default.
        """
        if self.add_noise_init:
            position_noise = self.np_random.uniform(-0.03, 0.03, size=3)
            desired_pose[:3, 3] += position_noise
            euler_angles = R.from_matrix(desired_pose[:3, :3]).as_euler("xyz")
            euler_angles[0] += self.np_random.uniform(-0.1, 0.1)
            euler_angles[1] += self.np_random.uniform(-0.05, 0.05)
            euler_angles[2] += self.np_random.uniform(-0.1, 0.1)
            desired_pose[:3, :3] = R.from_euler("xyz", euler_angles).as_matrix()
        self.last_teleop_pose = desired_pose
        self.move(desired_pose)
        step_fn(self.control_steps)

    @requires_mj
    def move(self, desired_pose):
        """Set mocap body 0 to the position/orientation of a 4x4 pose."""
        self.data.mocap_pos[0] = desired_pose[:3, 3]
        self.data.mocap_quat[0] = self._mocap_quat(desired_pose[:3, :3])

    @requires_mj
    def step(self, action, step_fn):
        """Apply one 17-value action.

        In ``"delta"`` mode the pose is composed onto the current camera pose;
        otherwise it is absolute. If the grip command changed, ``grasping_steps``
        are run first, then ``control_steps``. A 2-D action uses its first row.
        """
        action = np.asarray(action)
        if action.ndim == 2:
            action = action[0]

        pose = action[:16].reshape(4, 4)
        grip_action = action[16]

        if self.action_space == "delta":
            goal_pose = self.get_camera_pose() @ pose
        else:
            goal_pose = pose

        self.move(goal_pose)
        updated = self.set_grasp(grip_action)
        if updated:
            step_fn(self.grasping_steps)
        step_fn(self.control_steps)
@@ -0,0 +1,6 @@
1
+ from egogym.embodiments.grippers.floating_gripper import FloatingGripper
2
+
3
class RUMGripper(FloatingGripper):
    """Floating-gripper preset for the RUM embodiment.

    Uses 200 control steps and 2000 grasping steps per action by default,
    versus the base class's 50 and 1.
    """

    def __init__(self, control_steps=200, grasping_steps=2000, action_space="delta", np_random=None):
        super().__init__(
            control_steps=control_steps,
            grasping_steps=grasping_steps,
            action_space=action_space,
            np_random=np_random,
        )
        self.name = "rum"
@@ -0,0 +1,95 @@
1
+ import mujoco
2
+ from functools import wraps
3
+ import numpy as np
4
+ from egogym.utils import get_pose
5
+
6
+
7
def requires_mj(func):
    """Decorate a method so it fails fast until ``setup_mj`` has bound model/data.

    The wrapped method raises ``RuntimeError`` when ``self.model`` or
    ``self.data`` is missing or ``None``.
    """

    @wraps(func)
    def guarded(self, *args, **kwargs):
        model_ready = getattr(self, 'model', None) is not None
        data_ready = getattr(self, 'data', None) is not None
        if model_ready and data_ready:
            return func(self, *args, **kwargs)
        raise RuntimeError(f"Must call setup_mj before using the {self.__class__.__name__}.")

    return guarded
14
+
15
+
16
class Robot:
    """Base embodiment: holds MuJoCo model/data, camera renderers and grasp state.

    Subclasses override ``action_space_shape``/``step`` and may extend the
    observation via ``additional_obs_spaces``/``get_additional_obs``.
    """

    def __init__(self, control_steps=50, grasping_steps=1, action_space="delta", np_random=None):
        self.name = None
        self.model = None
        self.data = None
        self.control_steps = control_steps
        self.grasping_steps = grasping_steps
        self.action_space = action_space
        self.np_random = np_random
        self.camera_names = ["egocentric", "exocentric"]
        self.rgb_renderers = {}
        self.add_noise_init = True
        # Last gripper command written by set_grasp; initialised so the first
        # comparison there does not raise AttributeError.
        self.last_grasp = 0.0

    def additional_obs_spaces(self):
        """Extra observation spaces contributed by this embodiment (none by default)."""
        return {}

    def get_additional_obs(self):
        """Extra observation values matching ``additional_obs_spaces``."""
        return {}

    def action_space_shape(self):
        """Shape of the action vector (empty by default)."""
        return (0,)

    def setup_mj(self, model, data):
        """Bind the MuJoCo model and data used by every ``requires_mj`` method."""
        self.model = model
        self.data = data

    @requires_mj
    def get_camera_pose(self):
        """Return the pose of the ``"egocentric"`` camera."""
        return get_pose(self.data, "egocentric", "camera")

    @requires_mj
    def setup_renderers(self, render_width=960, render_height=720):
        """Create one ``mujoco.Renderer`` per camera, closing any previous ones first."""
        # Release the old renderers' GL resources explicitly instead of waiting for GC.
        for renderer in self.rgb_renderers.values():
            close = getattr(renderer, "close", None)
            if close is not None:
                close()
        self.rgb_renderers = {
            camera_name: mujoco.Renderer(self.model, height=render_height, width=render_width)
            for camera_name in self.camera_names
        }

    @requires_mj
    def get_camera_view(self, camera_name):
        """Render and return the RGB image from ``camera_name``."""
        scene_opt = mujoco.MjvOption()
        self.rgb_renderers[camera_name].update_scene(self.data, camera_name, scene_opt)
        return self.rgb_renderers[camera_name].render()

    @requires_mj
    def get_tcp_pose(self):
        """Return the pose of the ``"grasping_center"`` site."""
        return get_pose(self.data, "grasping_center", "site")

    @requires_mj
    def get_grasp(self):
        """Return the gripper state read from ``qpos``.

        NOTE(review): this indexes ``qpos`` with the *actuator* id of
        ``fingers_actuator``. That only matches the finger joint's qpos address
        if the model happens to be laid out that way. ``DroidArm`` overrides
        this with an explicit joint lookup. Confirm for the RUM gripper model.
        """
        return np.array(self.data.qpos[self.model.actuator("fingers_actuator").id])

    @requires_mj
    def get_grasped_bodies(self):
        """Return names of bodies in contact with any body whose name contains "left" or "right"."""
        body_names = set()
        for contact in self.data.contact:
            body1 = mujoco.mj_id2name(self.model, mujoco.mjtObj.mjOBJ_BODY, self.model.geom_bodyid[contact.geom1])
            body2 = mujoco.mj_id2name(self.model, mujoco.mjtObj.mjOBJ_BODY, self.model.geom_bodyid[contact.geom2])
            if body1 is None or body2 is None:
                continue
            if "left" in body1 or "right" in body1:
                body_names.add(body2)
            if "left" in body2 or "right" in body2:
                body_names.add(body1)
        return body_names

    @requires_mj
    def set_grasp(self, grip_action):
        """Command ``fingers_actuator`` to ``-255 * grip_action``.

        Returns:
            1 if the command differs from the previous one, else 0.
        """
        # np.asarray(...).item() accepts Python floats, numpy scalars and 1-element arrays.
        desired_grasp = np.asarray(-255.0 * grip_action).item()
        self.data.ctrl[self.model.actuator("fingers_actuator").id] = desired_grasp
        if desired_grasp != self.last_grasp:
            self.last_grasp = desired_grasp
            return 1
        return 0