egogym 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. baselines/pi_policy.py +110 -0
  2. baselines/rum/__init__.py +1 -0
  3. baselines/rum/loss_fns/__init__.py +37 -0
  4. baselines/rum/loss_fns/abstract_loss_fn.py +13 -0
  5. baselines/rum/loss_fns/diffusion_policy_loss_fn.py +114 -0
  6. baselines/rum/loss_fns/rvq_loss_fn.py +104 -0
  7. baselines/rum/loss_fns/vqbet_loss_fn.py +202 -0
  8. baselines/rum/models/__init__.py +1 -0
  9. baselines/rum/models/bet/__init__.py +3 -0
  10. baselines/rum/models/bet/bet.py +347 -0
  11. baselines/rum/models/bet/gpt.py +277 -0
  12. baselines/rum/models/bet/tokenized_bet.py +454 -0
  13. baselines/rum/models/bet/utils.py +124 -0
  14. baselines/rum/models/bet/vqbet.py +410 -0
  15. baselines/rum/models/bet/vqvae/__init__.py +3 -0
  16. baselines/rum/models/bet/vqvae/residual_vq.py +346 -0
  17. baselines/rum/models/bet/vqvae/vector_quantize_pytorch.py +1194 -0
  18. baselines/rum/models/bet/vqvae/vqvae.py +313 -0
  19. baselines/rum/models/bet/vqvae/vqvae_utils.py +30 -0
  20. baselines/rum/models/custom.py +33 -0
  21. baselines/rum/models/encoders/__init__.py +0 -0
  22. baselines/rum/models/encoders/abstract_base_encoder.py +70 -0
  23. baselines/rum/models/encoders/identity.py +45 -0
  24. baselines/rum/models/encoders/timm_encoders.py +82 -0
  25. baselines/rum/models/policies/diffusion_policy.py +881 -0
  26. baselines/rum/models/policies/open_loop.py +122 -0
  27. baselines/rum/models/policies/simple_open_loop.py +108 -0
  28. baselines/rum/molmo/server.py +144 -0
  29. baselines/rum/policy.py +293 -0
  30. baselines/rum/utils/__init__.py +212 -0
  31. baselines/rum/utils/action_transforms.py +22 -0
  32. baselines/rum/utils/decord_transforms.py +135 -0
  33. baselines/rum/utils/rpc.py +249 -0
  34. baselines/rum/utils/schedulers.py +71 -0
  35. baselines/rum/utils/trajectory_vis.py +128 -0
  36. baselines/rum/utils/zmq_utils.py +281 -0
  37. baselines/rum_policy.py +108 -0
  38. egogym/__init__.py +8 -0
  39. egogym/assets/constants.py +1804 -0
  40. egogym/components/__init__.py +1 -0
  41. egogym/components/object.py +94 -0
  42. egogym/egogym.py +106 -0
  43. egogym/embodiments/__init__.py +10 -0
  44. egogym/embodiments/arms/__init__.py +4 -0
  45. egogym/embodiments/arms/arm.py +65 -0
  46. egogym/embodiments/arms/droid.py +49 -0
  47. egogym/embodiments/grippers/__init__.py +4 -0
  48. egogym/embodiments/grippers/floating_gripper.py +58 -0
  49. egogym/embodiments/grippers/rum.py +6 -0
  50. egogym/embodiments/robot.py +95 -0
  51. egogym/evaluate.py +216 -0
  52. egogym/managers/__init__.py +2 -0
  53. egogym/managers/objects_managers.py +30 -0
  54. egogym/managers/textures_manager.py +21 -0
  55. egogym/misc/molmo_client.py +49 -0
  56. egogym/misc/molmo_server.py +197 -0
  57. egogym/policies/__init__.py +1 -0
  58. egogym/policies/base_policy.py +13 -0
  59. egogym/scripts/analayze.py +834 -0
  60. egogym/scripts/plot.py +87 -0
  61. egogym/scripts/plot_correlation.py +392 -0
  62. egogym/scripts/plot_correlation_hardcoded.py +338 -0
  63. egogym/scripts/plot_failure.py +248 -0
  64. egogym/scripts/plot_failure_hardcoded.py +195 -0
  65. egogym/scripts/plot_failure_vlm.py +257 -0
  66. egogym/scripts/plot_failure_vlm_hardcoded.py +177 -0
  67. egogym/scripts/plot_line.py +303 -0
  68. egogym/scripts/plot_line_hardcoded.py +285 -0
  69. egogym/scripts/plot_pi0_bars.py +169 -0
  70. egogym/tasks/close.py +84 -0
  71. egogym/tasks/open.py +85 -0
  72. egogym/tasks/pick.py +121 -0
  73. egogym/utils.py +969 -0
  74. egogym/wrappers/__init__.py +20 -0
  75. egogym/wrappers/episode_monitor.py +282 -0
  76. egogym/wrappers/unprivileged_chatgpt.py +163 -0
  77. egogym/wrappers/unprivileged_gemini.py +157 -0
  78. egogym/wrappers/unprivileged_molmo.py +88 -0
  79. egogym/wrappers/unprivileged_moondream.py +121 -0
  80. egogym-0.1.0.dist-info/METADATA +52 -0
  81. egogym-0.1.0.dist-info/RECORD +83 -0
  82. egogym-0.1.0.dist-info/WHEEL +5 -0
  83. egogym-0.1.0.dist-info/top_level.txt +2 -0
@@ -0,0 +1,169 @@
1
+ import os
2
+ import numpy as np
3
+ import pandas as pd
4
+ import altair as alt
5
+ from scipy.stats import beta
6
+
7
# Register custom font for PNG export: every text-bearing chart element is
# configured to use the "Produkt" font family.
_THEME_FONT = 'Produkt'


def _custom_theme_config():
    """Return the Altair theme dict that applies ``_THEME_FONT`` to all text."""
    font = _THEME_FONT
    return {
        'config': {
            'title': {'font': font},
            'axis': {'labelFont': font, 'titleFont': font},
            'legend': {'labelFont': font, 'titleFont': font},
            'mark': {'font': font},
            'text': {'font': font},
        }
    }


alt.themes.register('custom_theme', _custom_theme_config)
alt.themes.enable('custom_theme')
18
+
19
BASE_DIR = "logs"
# A row counts as a success when its ``max_reward`` is strictly above this.
REWARD_THRESHOLD = 0.03


def compute_success_from_csv(csv_path):
    """Count successful rows in a tab-separated evaluation log.

    Args:
        csv_path: Path to a TSV file containing a ``max_reward`` column.

    Returns:
        ``(successes, total)``: the number of rows whose ``max_reward``
        exceeds ``REWARD_THRESHOLD``, and the total number of rows.
    """
    episodes = pd.read_csv(csv_path, sep="\t")
    succeeded = episodes["max_reward"] > REWARD_THRESHOLD
    return succeeded.sum(), len(episodes)
27
+
28
+
29
def _find_log_csv(base_dir, model_folder, n_obj):
    """Locate the ``log.csv`` for the *n_obj*-object evaluation, or return None.

    The folder spellings ``{n}_objects``, ``{n}_object``, ``{n}-objects`` and
    ``{n}-object`` are tried in that order under ``base_dir/model_folder``.
    For each spelling, a ``log.csv`` directly inside the folder wins.
    Otherwise the first immediate subdirectory (by sorted name) that holds a
    ``log.csv`` is used.
    """
    for folder in (
        f"{n_obj}_objects",
        f"{n_obj}_object",
        f"{n_obj}-objects",
        f"{n_obj}-object",
    ):
        folder_path = os.path.join(base_dir, model_folder, folder)
        direct = os.path.join(folder_path, "log.csv")
        if os.path.exists(direct):
            return direct
        if not os.path.isdir(folder_path):
            continue
        # os.listdir returns entries in arbitrary order; sort so that, when
        # several nested runs exist, every invocation picks the same one.
        for subdir in sorted(os.listdir(folder_path)):
            subdir_path = os.path.join(folder_path, subdir)
            candidate = os.path.join(subdir_path, "log.csv")
            if os.path.isdir(subdir_path) and os.path.exists(candidate):
                return candidate
    return None


def _beta_posterior_percent(successes, total):
    """Return ``(mean, lo, hi)``, in percent, of the Beta(1 + s, 1 + f) posterior.

    ``lo`` and ``hi`` are the 2.5% and 97.5% quantiles, i.e. the central 95%
    interval. The ``1 +`` terms correspond to a uniform Beta(1, 1) prior.
    """
    a, b = 1 + successes, 1 + (total - successes)
    mean = 100 * a / (a + b)
    lo = 100 * beta.ppf(0.025, a, b)
    hi = 100 * beta.ppf(0.975, a, b)
    return mean, lo, hi


def plot_pi0_bars(base_dir=None, model_folder="pi0", num_objects=(1, 2, 3, 4, 5), label="π-0.5"):
    """Build a bar chart of success rate versus number of objects.

    Each bar is the Beta-posterior mean success rate, with a 95% interval
    drawn as an error bar and the mean printed above the bar. Object counts
    with no log file are reported on stdout and skipped.

    Args:
        base_dir: Root log directory. ``None`` means the module's ``BASE_DIR``.
        model_folder: Sub-directory of ``base_dir`` holding this model's runs.
        num_objects: Object counts to look up and plot.
        label: Model name used in console messages and the chart title.

    Returns:
        A layered ``alt.Chart`` (bars + error bars + text labels).
    """
    if base_dir is None:
        base_dir = BASE_DIR

    rows = []
    for n_obj in num_objects:
        csv_path = _find_log_csv(base_dir, model_folder, n_obj)
        if csv_path is None:
            print(f"Missing data: {label}, {n_obj} objects")
            continue

        s, t = compute_success_from_csv(csv_path)
        mean, lo, hi = _beta_posterior_percent(s, t)

        # ``mean`` is the posterior mean (s+1)/(t+2), not s/t, so report the
        # raw empirical rate separately instead of labelling one as the other.
        raw = 100 * s / t if t else float("nan")
        print(
            f"{label} | {n_obj} objects: {s}/{t} = {raw:.1f}% "
            f"(posterior mean {mean:.1f}%, 95% CI [{lo:.1f}, {hi:.1f}])"
        )

        rows.append({
            "num_objects": n_obj,
            "mean": mean,
            "lo": lo,
            "hi": hi,
            "successes": s,
            "total": t,
        })

    df = pd.DataFrame(rows)

    # Bars: posterior mean success rate per object count.
    bars = alt.Chart(df).mark_bar(
        color='#E0BE16',
        width=60,
    ).encode(
        x=alt.X(
            'num_objects:O',
            title='Number of Objects',
            axis=alt.Axis(
                labelFontSize=20,
                titleFontSize=24,
                labelAngle=0,
                titlePadding=15,
            ),
        ),
        y=alt.Y(
            'mean:Q',
            title='Success Rate (%)',
            scale=alt.Scale(domain=[0, 100]),
            axis=alt.Axis(
                labelFontSize=20,
                titleFontSize=24,
                titlePadding=14,
                grid=True,
                gridOpacity=0.3,
            ),
        ),
        tooltip=[
            alt.Tooltip('num_objects:O', title='Objects'),
            alt.Tooltip('mean:Q', title='Success Rate (%)', format='.1f'),
            alt.Tooltip('successes:Q', title='Successes'),
            alt.Tooltip('total:Q', title='Total'),
        ],
    )

    # Error bars spanning the 95% posterior interval.
    error_bars = alt.Chart(df).mark_errorbar(
        ticks=True,
        thickness=2,
    ).encode(
        x=alt.X('num_objects:O'),
        y=alt.Y('lo:Q', title=''),
        y2=alt.Y2('hi:Q'),
    )

    # Numeric labels just above each bar.
    text = alt.Chart(df).mark_text(
        dy=-10,
        fontSize=16,
        fontWeight='bold',
    ).encode(
        x=alt.X('num_objects:O'),
        y=alt.Y('mean:Q'),
        text=alt.Text('mean:Q', format='.1f'),
    )

    chart = (bars + error_bars + text).properties(
        width=500,
        height=400,
        title={
            'text': f' {label} Success Rate by Number of Objects',
            'fontSize': 24,
            'anchor': 'start',
            'dx': 40,
            'dy': -10,
        },
        padding={'left': 10, 'right': 10, 'top': 40, 'bottom': 40},
    ).configure_view(
        strokeWidth=0,
    )

    return chart
162
+
163
+
164
if __name__ == "__main__":
    # Render the chart once, then export it as HTML plus PNG and PDF at a
    # scale factor of 3.
    pi0_chart = plot_pi0_bars()
    pi0_chart.save("pi0_bars.html")
    for extension in ("png", "pdf"):
        pi0_chart.save(f"pi0_bars.{extension}", scale_factor=3)
    print("\nPlot saved to: pi0_bars.html, pi0_bars.png, and pi0_bars.pdf")
egogym/tasks/close.py ADDED
@@ -0,0 +1,84 @@
1
+ import os
2
+ import cv2
3
+ import mujoco
4
+ from gymnasium.spaces import Box
5
+ import numpy as np
6
+ from scipy.spatial.transform import Rotation as R
7
+
8
+ from egogym.egogym import Egogym
9
+ from egogym.utils import include_in_scene, position_sampler, make_objects_manager
10
+ import egogym.assets.constants as constants
11
+
12
class CloseTask(Egogym):
    """Task where the robot must close an articulated object.

    During ``reset`` the sampled object is opened via ``object.open``. The
    reward is ``max(0.1 - opened, 0.0)``, where ``opened`` is the value
    returned by ``object.get_perecenttage_opened``.
    """

    def __init__(self, robot="rum", action_space="delta", render_mode=None,
                 render_size=(960, 720), num_objs=1, seed=None, objects_set=None):
        super().__init__(
            robot=robot,
            action_space=action_space,
            render_mode=render_mode,
            render_size=render_size,
            seed=seed,
        )
        self.num_objs = num_objs
        # A caller-supplied set is used with shuffle=False; the default set is
        # built with the manager's default ``shuffle`` setting.
        if objects_set is None:
            self.objects_manager = make_objects_manager(constants.all_close_objects_set, self.np_random)
        else:
            self.objects_manager = make_objects_manager(objects_set, self.np_random, shuffle=False)
        self.observation_space["handle_pose"] = Box(
            low=-np.inf, high=np.inf, shape=(16,), dtype=np.float32
        )

    def make_task_scene(self):
        """Sample an object and return the scene XML containing it.

        The scene is built from the ``open.xml`` base scene. The sampled
        object is inserted into it, and the robot's ``model_open.xml`` is
        included.
        """
        here = os.path.dirname(__file__)
        with open(f"{here}/../assets/scenes/open.xml", "r") as scene_file:
            xml = scene_file.read()

        self.object = self.objects_manager.sample()
        xml = self.object.add_to_scene_xml(xml)
        return include_in_scene(xml, f"{here}/../assets/embodiments/{self.robot.name}/model_open.xml")

    def get_obs(self):
        """Return the base observation plus the handle pose flattened to 16 float32s."""
        observation = self.get_base_obs()
        handle = self.object.get_handle_pose(self.data)
        observation["handle_pose"] = handle.reshape(16).astype(np.float32)
        return observation

    def compute_reward(self) -> float:
        """Return ``max(0.1 - opened, 0.0)``; larger when the object is more closed."""
        opened = self.object.get_perecenttage_opened(self.model, self.data)
        return max(0.1 - opened, 0.0)

    def reset(self, seed=None, options=None):
        """Rebuild the scene, open the object, and place camera and gripper.

        Returns:
            ``(observation, info)``, where ``info`` holds ``object_name``.
        """
        super().reset(seed=seed, options=options)
        xml = self.make_task_scene()
        self.grasped_bodies = set()
        self.grasping_object = False

        start_pose = np.eye(4)
        start_pose[:3, :3] = R.from_euler("xyz", np.array([1.3, 0.000, 0.000])).as_matrix()
        self.setup_mujoco(xml)
        self.env_step()
        # Open the object first so there is something to close.
        self.object.open(self.model, self.data)
        self.env_step()

        # The camera and the gripper start at fixed offsets from the handle.
        handle_pose = self.object.get_handle_pose(self.data)
        exo_camera_pos = handle_pose[0:3, 3].copy() + np.array([-0.1, -0.8, 0.35])
        start_pose[:3, 3] = handle_pose[0:3, 3] + np.array([0.00, -0.6, 0.2])
        exo_camera_id = mujoco.mj_name2id(self.model, mujoco.mjtObj.mjOBJ_CAMERA, "exocentric")
        self.model.cam_pos[exo_camera_id] = exo_camera_pos

        self.robot.prepare(start_pose, self.env_step)
        self.env_step(10)
        self.initial_robot_pose = self.robot.get_camera_pose()

        self.enable_sleeping_islands()

        return self.get_obs(), {"object_name": self.object.name}

    def step(self, action):
        """Step the base env and augment ``info`` with grasp and pose details."""
        obs, reward, terminated, truncated, base_info = super().step(action)
        self.grasped_bodies.update(self.robot.get_grasped_bodies())
        # Sticky flag: once the target body has been grasped it stays True
        # until the next reset.
        if f"{self.object.name}_object" in self.grasped_bodies:
            self.grasping_object = True
        step_info = {
            "grasped_bodies": list(self.grasped_bodies),
            "object_name": self.object.name,
            "initial_robot_pose": self.initial_robot_pose,
            "initial_object_pose": self.object.last_set_pose,
            "is_grasping": self.robot.get_grasp(),
            "gripper_current_position": self.robot.get_tcp_pose(),
            "grasping_object": self.grasping_object,
            # Keys from the base env's info take precedence.
            **base_info,
        }
        return obs, reward, terminated, truncated, step_info
egogym/tasks/open.py ADDED
@@ -0,0 +1,85 @@
1
+ import os
2
+ import cv2
3
+ import mujoco
4
+ from gymnasium.spaces import Box
5
+ import numpy as np
6
+ from scipy.spatial.transform import Rotation as R
7
+
8
+ from egogym.egogym import Egogym
9
+ from egogym.utils import include_in_scene, position_sampler, make_objects_manager
10
+ import egogym.assets.constants as constants
11
+
12
class OpenTask(Egogym):
    """Task where the robot must open an articulated object.

    The reward is the value returned by ``object.get_perecenttage_opened``.
    """

    def __init__(self, robot="rum", action_space="delta", render_mode=None,
                 render_size=(960, 720), num_objs=1, seed=None, objects_set=None):
        super().__init__(
            robot=robot,
            action_space=action_space,
            render_mode=render_mode,
            render_size=render_size,
            seed=seed,
        )
        self.num_objs = num_objs
        # A caller-supplied set is used with shuffle=False; the default set is
        # built with the manager's default ``shuffle`` setting.
        if objects_set is None:
            self.objects_manager = make_objects_manager(constants.all_open_objects_set, self.np_random)
        else:
            self.objects_manager = make_objects_manager(objects_set, self.np_random, shuffle=False)
        self.observation_space["handle_pose"] = Box(
            low=-np.inf, high=np.inf, shape=(16,), dtype=np.float32
        )

    def make_task_scene(self):
        """Sample an object and return the ``open.xml`` scene XML containing it."""
        here = os.path.dirname(__file__)
        with open(f"{here}/../assets/scenes/open.xml", "r") as scene_file:
            xml = scene_file.read()

        self.object = self.objects_manager.sample()
        xml = self.object.add_to_scene_xml(xml)
        return include_in_scene(xml, f"{here}/../assets/embodiments/{self.robot.name}/model_open.xml")

    def get_obs(self):
        """Return the base observation plus the handle pose flattened to 16 float32s."""
        observation = self.get_base_obs()
        handle = self.object.get_handle_pose(self.data)
        observation["handle_pose"] = handle.reshape(16).astype(np.float32)
        return observation

    def compute_reward(self) -> float:
        """Return the object's opened amount as reported by the object."""
        return self.object.get_perecenttage_opened(self.model, self.data)

    def reset(self, seed=None, options=None):
        """Rebuild the scene, position camera and gripper relative to the handle.

        Returns:
            ``(observation, info)``, where ``info`` holds ``object_name``.
        """
        super().reset(seed=seed, options=options)
        xml = self.make_task_scene()
        self.grasped_bodies = set()
        self.grasping_object = False

        start_pose = np.eye(4)
        start_pose[:3, :3] = R.from_euler("xyz", np.array([1.5, 0.000, 0.000])).as_matrix()
        self.setup_mujoco(xml)
        mujoco.mj_step(self.model, self.data)

        handle_pose = self.object.get_handle_pose(self.data)
        exo_camera_pos = handle_pose[0:3, 3].copy()
        exo_camera_pos += np.array([0.0, -1.0, 0.28])
        exo_camera_id = mujoco.mj_name2id(self.model, mujoco.mjtObj.mjOBJ_CAMERA, "exocentric")
        self.model.cam_pos[exo_camera_id] = exo_camera_pos
        mujoco.mj_step(self.model, self.data)

        # handle_pose was fetched before the second mj_step and is read again
        # only here, after that step; keep this ordering.
        start_pose[:3, 3] = handle_pose[:3, 3] + np.array([0.0, -0.6, 0.05])
        self.robot.prepare(start_pose, self.env_step)
        self.env_step()
        self.initial_robot_pose = self.robot.get_camera_pose()

        self.settle_env()
        self.enable_sleeping_islands()

        return self.get_obs(), {"object_name": self.object.name}

    def step(self, action):
        """Step the base env and augment ``info`` with grasp and pose details."""
        obs, reward, terminated, truncated, base_info = super().step(action)
        self.grasped_bodies.update(self.robot.get_grasped_bodies())
        # Sticky flag: once the target body has been grasped it stays True
        # until the next reset.
        if f"{self.object.name}_object" in self.grasped_bodies:
            self.grasping_object = True
        step_info = {
            "grasped_bodies": list(self.grasped_bodies),
            "object_name": self.object.name,
            "initial_robot_pose": self.initial_robot_pose,
            "initial_object_pose": self.object.last_set_pose,
            "is_grasping": self.robot.get_grasp(),
            "gripper_current_position": self.robot.get_tcp_pose(),
            "grasping_object": self.grasping_object,
            # Keys from the base env's info take precedence.
            **base_info,
        }
        return obs, reward, terminated, truncated, step_info
egogym/tasks/pick.py ADDED
@@ -0,0 +1,121 @@
1
+ import os
2
+ from gymnasium.spaces import Box, Text
3
+ import numpy as np
4
+ from scipy.spatial.transform import Rotation as R
5
+
6
+ from egogym.egogym import Egogym
7
+ from egogym.utils import include_in_scene, position_sampler, make_objects_manager, make_textures_manager
8
+ import egogym.assets.constants as constants
9
+
10
class PickTask(Egogym):
    """Tabletop pick task: lift a target object, optionally among distractors.

    The reward is the target's height gain (z) relative to the pose recorded
    at the end of ``reset``.
    """

    # Upper bound on ``objects_manager.sample`` calls made while collecting
    # distractors in ``make_task_scene``. The original unbounded loop spun
    # forever when the object set could not satisfy ``num_objs``.
    _MAX_DISTRACTOR_DRAWS = 10_000

    def __init__(self, robot="rum", action_space="delta", render_mode=None,
                 render_size=(960, 720), num_objs=1, seed=None, objects_set=None):
        super().__init__(
            robot=robot,
            action_space=action_space,
            render_mode=render_mode,
            render_size=render_size,
            seed=seed,
        )
        # Side length of the placement area; also sizes the table (+0.24).
        self.spread = 0.22
        self.num_objs = num_objs
        if objects_set is not None:
            self.objects_manager = make_objects_manager(objects_set, self.np_random, shuffle=True)
        else:
            self.objects_manager = make_objects_manager(constants.lite_pick_objects_set, self.np_random)
        self.textures_manager = make_textures_manager([f"wood/{i}.png" for i in range(10)], self.np_random)
        self.observation_space["object_pose"] = Box(low=-np.inf, high=np.inf, shape=(16,), dtype=np.float32)
        self.observation_space["object_name"] = Text(max_length=256)

    def make_task_scene(self):
        """Sample the target, distractors, and texture, and return the scene XML.

        Distractors must have a name not already used and a name prefix (the
        text before the first ``"_"``) different from the target's prefix.

        Raises:
            ValueError: if ``num_objs - 1`` such distractors cannot be found
                within ``_MAX_DISTRACTOR_DRAWS`` samples.
        """
        assets_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "assets"))

        with open(os.path.join(assets_dir, "scenes", "pick.xml"), "r") as f:
            scene_xml = f.read()

        self.object = self.objects_manager.sample()
        self.secondary_objects = []
        used_names = {self.object.name}
        target_prefix = self.object.name.split("_")[0]
        scene_xml = self.object.add_to_scene_xml(scene_xml)

        draws = 0
        while len(used_names) < self.num_objs:
            if draws >= self._MAX_DISTRACTOR_DRAWS:
                raise ValueError(
                    f"Could not find {self.num_objs - 1} distractor object(s) with a "
                    f"name prefix other than '{target_prefix}' after {draws} draws; "
                    f"the object set may be too small for num_objs={self.num_objs}."
                )
            draws += 1
            candidate = self.objects_manager.sample(random=True)
            if candidate.name not in used_names and candidate.name.split("_")[0] != target_prefix:
                self.secondary_objects.append(candidate)
                used_names.add(candidate.name)
                scene_xml = candidate.add_to_scene_xml(scene_xml)

        self.texture = self.textures_manager.sample(random=True)
        scene_xml = include_in_scene(scene_xml, os.path.join(assets_dir, "embodiments", self.robot.name, "model_pick.xml"))
        scene_xml = scene_xml.replace("{ASSETS_DIR}", assets_dir)
        scene_xml = scene_xml.replace("{TEXTURE_PATH}", self.texture)
        scene_xml = scene_xml.replace("{TABLE_WIDTH}", str(self.spread + 0.24))
        scene_xml = scene_xml.replace("{TABLE_HEIGHT}", str(self.spread + 0.24))
        return scene_xml

    def get_obs(self):
        """Return the base observation plus the target's flattened pose and name prefix."""
        obs = self.get_base_obs()
        obs["object_pose"] = self.object.get_pose(self.data).reshape(16).astype(np.float32)
        obs["object_name"] = self.object.name.split("_")[0]
        return obs

    def compute_reward(self) -> float:
        """Return the target's z displacement since its recorded reset pose."""
        object_pose = self.object.get_pose(self.data)
        return object_pose[2][3] - self.object.last_set_pose[2][3]

    def reset(self, seed=None, options=None):
        """Rebuild the scene, scatter and settle the objects, then place the gripper.

        Returns:
            ``(observation, info)``, where ``info`` holds ``object_name``.
        """
        super().reset(seed=seed, options=options)
        scene_xml_string = self.make_task_scene()
        self.grasped_bodies = set()
        self.grasping_object = False

        gripper_init_pose = np.eye(4)
        gripper_init_pose[:3, 3] = np.array([0.00, -0.68, 1.1])
        gripper_init_pose[:3, :3] = R.from_euler("xyz", np.array([1.3, 0.000, 0.000])).as_matrix()
        self.setup_mujoco(scene_xml_string)
        initial_positions = position_sampler(
            self.np_random,
            len(self.secondary_objects) + 1,
            [self.spread, self.spread],
            self.spread / 2,
            100,
            start_position=gripper_init_pose[:2, 3],
            thickness=self.spread / 2.5,
        )
        # NOTE(review): unlike CloseTask/OpenTask, this is captured *before*
        # robot.prepare() moves the gripper to gripper_init_pose, and it feeds
        # the distance-based termination in step(). Confirm that is intended.
        self.initial_robot_pose = self.robot.get_camera_pose()
        self.env_step(10)

        for i, obj in enumerate([self.object] + self.secondary_objects):
            # 0.78 is presumably the table-top height — TODO confirm.
            z = obj.get_center_to_bottom_z_distance(self.data) + 0.78
            # Build a fresh matrix per object so no two objects can end up
            # sharing (and mutating) the same pose array.
            pose = np.eye(4)
            pose[:2, 3] = initial_positions[i][:2]
            pose[2, 3] = z
            pose[:3, :3] = R.from_euler(
                "x", self.np_random.uniform(0, 2 * np.pi)
            ).as_matrix()
            obj.set_pose(self.model, self.data, pose)

        self.env_step(100)
        self.settle_env()
        self.env_step(100)
        # Record the settled poses; compute_reward and step measure from these.
        for obj in [self.object] + self.secondary_objects:
            obj.last_set_pose = obj.get_pose(self.data).copy()

        self.robot.prepare(gripper_init_pose, self.env_step)
        self.env_step(100)
        self.enable_sleeping_islands()

        observation = self.get_obs()
        info = {"object_name": self.object.name}

        return observation, info

    def step(self, action):
        """Step the base env, augment ``info``, and end episodes that drift away.

        The episode is marked both terminated and truncated once the camera is
        more than 1.5x farther from the target than the reference distance
        computed from ``initial_robot_pose`` and the target's recorded pose.
        """
        obs, reward, terminated, truncated, info = super().step(action)
        self.grasped_bodies.update(self.robot.get_grasped_bodies())
        # Sticky flag: once the target body has been grasped it stays True
        # until the next reset.
        if f"{self.object.name}_object" in self.grasped_bodies:
            self.grasping_object = True
        info = {
            "grasped_bodies": list(self.grasped_bodies),
            "object_name": self.object.name,
            "texture_name": self.texture,
            "initial_robot_pose": self.initial_robot_pose,
            "initial_object_pose": self.object.last_set_pose,
            "is_grasping": self.robot.get_grasp(),
            "gripper_current_position": self.robot.get_tcp_pose(),
            "grasping_object": self.grasping_object,
            # Keys from the base env's info take precedence.
            **info,
        }

        camera_to_object = np.linalg.norm(
            self.robot.get_camera_pose()[:3, 3] - self.object.get_pose(self.data)[:3, 3]
        )
        reference_distance = np.linalg.norm(
            self.initial_robot_pose[:3, 3] - self.object.last_set_pose[:3, 3]
        )
        if camera_to_object > reference_distance * 1.5:
            truncated = True
            terminated = True

        return obs, reward, terminated, truncated, info