egogym 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- baselines/pi_policy.py +110 -0
- baselines/rum/__init__.py +1 -0
- baselines/rum/loss_fns/__init__.py +37 -0
- baselines/rum/loss_fns/abstract_loss_fn.py +13 -0
- baselines/rum/loss_fns/diffusion_policy_loss_fn.py +114 -0
- baselines/rum/loss_fns/rvq_loss_fn.py +104 -0
- baselines/rum/loss_fns/vqbet_loss_fn.py +202 -0
- baselines/rum/models/__init__.py +1 -0
- baselines/rum/models/bet/__init__.py +3 -0
- baselines/rum/models/bet/bet.py +347 -0
- baselines/rum/models/bet/gpt.py +277 -0
- baselines/rum/models/bet/tokenized_bet.py +454 -0
- baselines/rum/models/bet/utils.py +124 -0
- baselines/rum/models/bet/vqbet.py +410 -0
- baselines/rum/models/bet/vqvae/__init__.py +3 -0
- baselines/rum/models/bet/vqvae/residual_vq.py +346 -0
- baselines/rum/models/bet/vqvae/vector_quantize_pytorch.py +1194 -0
- baselines/rum/models/bet/vqvae/vqvae.py +313 -0
- baselines/rum/models/bet/vqvae/vqvae_utils.py +30 -0
- baselines/rum/models/custom.py +33 -0
- baselines/rum/models/encoders/__init__.py +0 -0
- baselines/rum/models/encoders/abstract_base_encoder.py +70 -0
- baselines/rum/models/encoders/identity.py +45 -0
- baselines/rum/models/encoders/timm_encoders.py +82 -0
- baselines/rum/models/policies/diffusion_policy.py +881 -0
- baselines/rum/models/policies/open_loop.py +122 -0
- baselines/rum/models/policies/simple_open_loop.py +108 -0
- baselines/rum/molmo/server.py +144 -0
- baselines/rum/policy.py +293 -0
- baselines/rum/utils/__init__.py +212 -0
- baselines/rum/utils/action_transforms.py +22 -0
- baselines/rum/utils/decord_transforms.py +135 -0
- baselines/rum/utils/rpc.py +249 -0
- baselines/rum/utils/schedulers.py +71 -0
- baselines/rum/utils/trajectory_vis.py +128 -0
- baselines/rum/utils/zmq_utils.py +281 -0
- baselines/rum_policy.py +108 -0
- egogym/__init__.py +8 -0
- egogym/assets/constants.py +1804 -0
- egogym/components/__init__.py +1 -0
- egogym/components/object.py +94 -0
- egogym/egogym.py +106 -0
- egogym/embodiments/__init__.py +10 -0
- egogym/embodiments/arms/__init__.py +4 -0
- egogym/embodiments/arms/arm.py +65 -0
- egogym/embodiments/arms/droid.py +49 -0
- egogym/embodiments/grippers/__init__.py +4 -0
- egogym/embodiments/grippers/floating_gripper.py +58 -0
- egogym/embodiments/grippers/rum.py +6 -0
- egogym/embodiments/robot.py +95 -0
- egogym/evaluate.py +216 -0
- egogym/managers/__init__.py +2 -0
- egogym/managers/objects_managers.py +30 -0
- egogym/managers/textures_manager.py +21 -0
- egogym/misc/molmo_client.py +49 -0
- egogym/misc/molmo_server.py +197 -0
- egogym/policies/__init__.py +1 -0
- egogym/policies/base_policy.py +13 -0
- egogym/scripts/analayze.py +834 -0
- egogym/scripts/plot.py +87 -0
- egogym/scripts/plot_correlation.py +392 -0
- egogym/scripts/plot_correlation_hardcoded.py +338 -0
- egogym/scripts/plot_failure.py +248 -0
- egogym/scripts/plot_failure_hardcoded.py +195 -0
- egogym/scripts/plot_failure_vlm.py +257 -0
- egogym/scripts/plot_failure_vlm_hardcoded.py +177 -0
- egogym/scripts/plot_line.py +303 -0
- egogym/scripts/plot_line_hardcoded.py +285 -0
- egogym/scripts/plot_pi0_bars.py +169 -0
- egogym/tasks/close.py +84 -0
- egogym/tasks/open.py +85 -0
- egogym/tasks/pick.py +121 -0
- egogym/utils.py +969 -0
- egogym/wrappers/__init__.py +20 -0
- egogym/wrappers/episode_monitor.py +282 -0
- egogym/wrappers/unprivileged_chatgpt.py +163 -0
- egogym/wrappers/unprivileged_gemini.py +157 -0
- egogym/wrappers/unprivileged_molmo.py +88 -0
- egogym/wrappers/unprivileged_moondream.py +121 -0
- egogym-0.1.0.dist-info/METADATA +52 -0
- egogym-0.1.0.dist-info/RECORD +83 -0
- egogym-0.1.0.dist-info/WHEEL +5 -0
- egogym-0.1.0.dist-info/top_level.txt +2 -0

egogym/scripts/plot_pi0_bars.py
ADDED

@@ -0,0 +1,169 @@

import os
import numpy as np
import pandas as pd
import altair as alt
from scipy.stats import beta

# Register custom font for PNG export
alt.themes.register('custom_theme', lambda: {
    'config': {
        'title': {'font': 'Produkt'},
        'axis': {'labelFont': 'Produkt', 'titleFont': 'Produkt'},
        'legend': {'labelFont': 'Produkt', 'titleFont': 'Produkt'},
        'mark': {'font': 'Produkt'},
        'text': {'font': 'Produkt'},
    }
})
alt.themes.enable('custom_theme')

BASE_DIR = "logs"
REWARD_THRESHOLD = 0.03

def compute_success_from_csv(csv_path):
    df = pd.read_csv(csv_path, sep="\t")
    successes = (df["max_reward"] > REWARD_THRESHOLD).sum()
    total = len(df)
    return successes, total


def plot_pi0_bars():
    model_folder = "pi0"
    num_objects = [1, 2, 3, 4, 5]

    rows = []

    for n_obj in num_objects:
        possible_folders = [
            f"{n_obj}_objects",
            f"{n_obj}_object",
            f"{n_obj}-objects",
            f"{n_obj}-object",
        ]

        csv_path = None
        for folder in possible_folders:
            # First try direct path
            candidate = os.path.join(BASE_DIR, model_folder, folder, "log.csv")
            if os.path.exists(candidate):
                csv_path = candidate
                break

            # Try nested evaluation folder structure
            folder_path = os.path.join(BASE_DIR, model_folder, folder)
            if os.path.exists(folder_path) and os.path.isdir(folder_path):
                for subdir in os.listdir(folder_path):
                    subdir_path = os.path.join(folder_path, subdir)
                    if os.path.isdir(subdir_path):
                        candidate = os.path.join(subdir_path, "log.csv")
                        if os.path.exists(candidate):
                            csv_path = candidate
                            break
                if csv_path:
                    break

        if csv_path is None:
            print(f"Missing data: π-0.5, {n_obj} objects")
            continue

        s, t = compute_success_from_csv(csv_path)

        # Beta posterior
        a, b = 1 + s, 1 + (t - s)
        mean = 100 * a / (a + b)
        lo = 100 * beta.ppf(0.025, a, b)
        hi = 100 * beta.ppf(0.975, a, b)

        print(f"π-0.5 | {n_obj} objects: {s}/{t} = {mean:.1f}%")

        rows.append({
            "num_objects": n_obj,
            "mean": mean,
            "lo": lo,
            "hi": hi,
            "successes": s,
            "total": t
        })

    df = pd.DataFrame(rows)

    # Create bar chart
    bars = alt.Chart(df).mark_bar(
        color='#E0BE16',
        width=60
    ).encode(
        x=alt.X(
            'num_objects:O',
            title='Number of Objects',
            axis=alt.Axis(
                labelFontSize=20,
                titleFontSize=24,
                labelAngle=0,
                titlePadding=15
            )
        ),
        y=alt.Y(
            'mean:Q',
            title='Success Rate (%)',
            scale=alt.Scale(domain=[0, 100]),
            axis=alt.Axis(
                labelFontSize=20,
                titleFontSize=24,
                titlePadding=14,
                grid=True,
                gridOpacity=0.3
            )
        ),
        tooltip=[
            alt.Tooltip('num_objects:O', title='Objects'),
            alt.Tooltip('mean:Q', title='Success Rate (%)', format='.1f'),
            alt.Tooltip('successes:Q', title='Successes'),
            alt.Tooltip('total:Q', title='Total')
        ]
    )

    # Create error bars for 95% CI
    error_bars = alt.Chart(df).mark_errorbar(
        ticks=True,
        thickness=2
    ).encode(
        x=alt.X('num_objects:O'),
        y=alt.Y('lo:Q', title=''),
        y2=alt.Y2('hi:Q')
    )

    # Add text labels on top of bars
    text = alt.Chart(df).mark_text(
        dy=-10,
        fontSize=16,
        fontWeight='bold'
    ).encode(
        x=alt.X('num_objects:O'),
        y=alt.Y('mean:Q'),
        text=alt.Text('mean:Q', format='.1f')
    )

    # Combine layers
    chart = (bars + error_bars + text).properties(
        width=500,
        height=400,
        title={
            'text': ' π-0.5 Success Rate by Number of Objects',
            'fontSize': 24,
            'anchor': 'start',
            'dx': 40,
            'dy': -10
        },
        padding={'left': 10, 'right': 10, 'top': 40, 'bottom': 40}
    ).configure_view(
        strokeWidth=0
    )

    return chart


if __name__ == "__main__":
    chart = plot_pi0_bars()
    chart.save("pi0_bars.html")
    chart.save("pi0_bars.png", scale_factor=3)
    chart.save("pi0_bars.pdf", scale_factor=3)
    print("\nPlot saved to: pi0_bars.html, pi0_bars.png, and pi0_bars.pdf")
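
The error bars in the script above come from a Bayesian read of the success counts: with a uniform Beta(1, 1) prior, s successes out of t episodes give a Beta(1 + s, 1 + t - s) posterior, whose 2.5% and 97.5% quantiles bound the 95% credible interval. A minimal standalone sketch of the same computation (the helper name and example counts are illustrative, not part of the package):

from scipy.stats import beta

def success_rate_ci(successes, total, level=0.95):
    # Beta posterior for a Bernoulli success rate under a uniform Beta(1, 1) prior.
    a, b = 1 + successes, 1 + (total - successes)
    tail = (1.0 - level) / 2.0
    mean = 100.0 * a / (a + b)               # posterior mean, in percent
    lo = 100.0 * beta.ppf(tail, a, b)        # lower credible bound
    hi = 100.0 * beta.ppf(1.0 - tail, a, b)  # upper credible bound
    return mean, lo, hi

# e.g. 14 successes in 20 episodes (illustrative counts)
print(success_rate_ci(14, 20))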
egogym/tasks/close.py
ADDED
@@ -0,0 +1,84 @@

import os
import cv2
import mujoco
from gymnasium.spaces import Box
import numpy as np
from scipy.spatial.transform import Rotation as R

from egogym.egogym import Egogym
from egogym.utils import include_in_scene, position_sampler, make_objects_manager
import egogym.assets.constants as constants

class CloseTask(Egogym):

    def __init__(self, robot="rum", action_space="delta", render_mode=None, render_size=(960, 720), num_objs=1, seed=None, objects_set=None):
        super().__init__(robot=robot, action_space=action_space, render_mode=render_mode, render_size=render_size, seed=seed)
        self.num_objs = num_objs
        if objects_set is not None:
            self.objects_manager = make_objects_manager(objects_set, self.np_random, shuffle=False)
        else:
            self.objects_manager = make_objects_manager(constants.all_close_objects_set, self.np_random)
        self.observation_space["handle_pose"] = Box(low=-np.inf, high=np.inf, shape=(16,), dtype=np.float32)

    def make_task_scene(self):
        # The close task reuses the "open" scene; the object is opened in reset().
        with open(f"{os.path.dirname(__file__)}/../assets/scenes/open.xml", "r") as f:
            scene_xml = f.read()

        self.object = self.objects_manager.sample()
        scene_xml = self.object.add_to_scene_xml(scene_xml)
        scene_xml = include_in_scene(scene_xml, f"{os.path.dirname(__file__)}/../assets/embodiments/{self.robot.name}/model_open.xml")
        return scene_xml

    def get_obs(self):
        obs = self.get_base_obs()
        obs["handle_pose"] = self.object.get_handle_pose(self.data).reshape(16).astype(np.float32)
        return obs

    def compute_reward(self) -> float:
        # Reward peaks (at 0.1) when the object is fully closed.
        percentage_opened = self.object.get_perecenttage_opened(self.model, self.data)
        return max(0.1 - percentage_opened, 0.0)

    def reset(self, seed=None, options=None):
        super().reset(seed=seed, options=options)
        scene_xml_string = self.make_task_scene()
        self.grasped_bodies = set()
        self.grasping_object = False

        gripper_init_pose = np.eye(4)
        gripper_init_pose[:3, :3] = R.from_euler("xyz", np.array([1.3, 0.000, 0.000])).as_matrix()
        self.setup_mujoco(scene_xml_string)
        self.env_step()
        # Start with the object open so the policy has to close it.
        self.object.open(self.model, self.data)
        self.env_step()
        handle_pose = self.object.get_handle_pose(self.data)
        camera_pos = handle_pose[0:3, 3].copy() + np.array([-0.1, -0.8, 0.35])
        gripper_init_pose[:3, 3] = handle_pose[0:3, 3] + np.array([0.00, -0.6, 0.2])
        self.model.cam_pos[mujoco.mj_name2id(self.model, mujoco.mjtObj.mjOBJ_CAMERA, "exocentric")] = camera_pos
        self.robot.prepare(gripper_init_pose, self.env_step)
        self.env_step(10)
        self.initial_robot_pose = self.robot.get_camera_pose()

        self.enable_sleeping_islands()

        observation = self.get_obs()
        info = {"object_name": self.object.name}

        return observation, info

    def step(self, action):
        obs, reward, terminated, truncated, info = super().step(action)
        self.grasped_bodies.update(self.robot.get_grasped_bodies())
        grasped_bodies_list = list(self.grasped_bodies)
        if f"{self.object.name}_object" in grasped_bodies_list:
            self.grasping_object = True
        info = {
            "grasped_bodies": grasped_bodies_list,
            "object_name": self.object.name,
            "initial_robot_pose": self.initial_robot_pose,
            "initial_object_pose": self.object.last_set_pose,
            "is_grasping": self.robot.get_grasp(),
            "gripper_current_position": self.robot.get_tcp_pose(),
            "grasping_object": self.grasping_object,
            **info
        }
        return obs, reward, terminated, truncated, info
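
Both this task and OpenTask below expose the handle pose to the policy as a 4x4 homogeneous transform flattened to 16 floats. A small sketch of recovering rotation and translation from that observation (the helper name is ours; the "handle_pose" key matches the observation space above):

import numpy as np

def unpack_handle_pose(obs):
    # "handle_pose" is a 4x4 homogeneous transform flattened to shape (16,).
    T = np.asarray(obs["handle_pose"], dtype=np.float64).reshape(4, 4)
    rotation = T[:3, :3]    # 3x3 rotation matrix of the handle
    translation = T[:3, 3]  # xyz position of the handle
    return rotation, translation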
egogym/tasks/open.py
ADDED
@@ -0,0 +1,85 @@

import os
import cv2
import mujoco
from gymnasium.spaces import Box
import numpy as np
from scipy.spatial.transform import Rotation as R

from egogym.egogym import Egogym
from egogym.utils import include_in_scene, position_sampler, make_objects_manager
import egogym.assets.constants as constants

class OpenTask(Egogym):

    def __init__(self, robot="rum", action_space="delta", render_mode=None, render_size=(960, 720), num_objs=1, seed=None, objects_set=None):
        super().__init__(robot=robot, action_space=action_space, render_mode=render_mode, render_size=render_size, seed=seed)
        self.num_objs = num_objs
        if objects_set is not None:
            self.objects_manager = make_objects_manager(objects_set, self.np_random, shuffle=False)
        else:
            self.objects_manager = make_objects_manager(constants.all_open_objects_set, self.np_random)
        self.observation_space["handle_pose"] = Box(low=-np.inf, high=np.inf, shape=(16,), dtype=np.float32)

    def make_task_scene(self):
        with open(f"{os.path.dirname(__file__)}/../assets/scenes/open.xml", "r") as f:
            scene_xml = f.read()

        self.object = self.objects_manager.sample()
        scene_xml = self.object.add_to_scene_xml(scene_xml)
        scene_xml = include_in_scene(scene_xml, f"{os.path.dirname(__file__)}/../assets/embodiments/{self.robot.name}/model_open.xml")
        return scene_xml

    def get_obs(self):
        obs = self.get_base_obs()
        obs["handle_pose"] = self.object.get_handle_pose(self.data).reshape(16).astype(np.float32)
        return obs

    def compute_reward(self) -> float:
        # Reward is the fraction by which the object has been opened.
        percentage_opened = self.object.get_perecenttage_opened(self.model, self.data)
        return percentage_opened

    def reset(self, seed=None, options=None):
        super().reset(seed=seed, options=options)
        scene_xml_string = self.make_task_scene()
        self.grasped_bodies = set()
        self.grasping_object = False

        gripper_init_pose = np.eye(4)
        gripper_init_pose[:3, :3] = R.from_euler("xyz", np.array([1.5, 0.000, 0.000])).as_matrix()
        self.setup_mujoco(scene_xml_string)
        mujoco.mj_step(self.model, self.data)
        handle_pose = self.object.get_handle_pose(self.data)
        # Place the exocentric camera relative to the handle.
        camera_pos = handle_pose[0:3, 3].copy()
        camera_pos += np.array([0.0, -1.0, 0.28])
        self.model.cam_pos[mujoco.mj_name2id(self.model, mujoco.mjtObj.mjOBJ_CAMERA, "exocentric")] = camera_pos
        mujoco.mj_step(self.model, self.data)
        gripper_init_pose[:3, 3] = handle_pose[:3, 3] + np.array([0.0, -0.6, 0.05])
        self.robot.prepare(gripper_init_pose, self.env_step)
        self.env_step()
        self.initial_robot_pose = self.robot.get_camera_pose()

        self.settle_env()
        self.enable_sleeping_islands()

        observation = self.get_obs()
        info = {"object_name": self.object.name}

        return observation, info

    def step(self, action):
        obs, reward, terminated, truncated, info = super().step(action)
        self.grasped_bodies.update(self.robot.get_grasped_bodies())
        grasped_bodies_list = list(self.grasped_bodies)
        if f"{self.object.name}_object" in grasped_bodies_list:
            self.grasping_object = True
        info = {
            "grasped_bodies": grasped_bodies_list,
            "object_name": self.object.name,
            "initial_robot_pose": self.initial_robot_pose,
            "initial_object_pose": self.object.last_set_pose,
            "is_grasping": self.robot.get_grasp(),
            "gripper_current_position": self.robot.get_tcp_pose(),
            "grasping_object": self.grasping_object,
            **info
        }
        return obs, reward, terminated, truncated, info
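
Taken with CloseTask above, the two articulated tasks mirror each other. Writing p for the opened fraction returned by get_perecenttage_opened, and lift for the lift height in PickTask (the next file), the per-step rewards are

    r_open = p        r_close = max(0.1 - p, 0)        r_pick = lift (meters)

so, assuming the max_reward column read by plot_pi0_bars.py is the episode maximum of these rewards, its success test max_reward > 0.03 translates to: opened by more than 3%, closed to within 7% of fully shut, or lifted more than 3 cm.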
egogym/tasks/pick.py
ADDED
@@ -0,0 +1,121 @@

import os
from gymnasium.spaces import Box, Text
import numpy as np
from scipy.spatial.transform import Rotation as R

from egogym.egogym import Egogym
from egogym.utils import include_in_scene, position_sampler, make_objects_manager, make_textures_manager
import egogym.assets.constants as constants

class PickTask(Egogym):

    def __init__(self, robot="rum", action_space="delta", render_mode=None, render_size=(960, 720), num_objs=1, seed=None, objects_set=None):
        super().__init__(robot=robot, action_space=action_space, render_mode=render_mode, render_size=render_size, seed=seed)
        self.spread = 0.22
        self.num_objs = num_objs
        if objects_set is not None:
            self.objects_manager = make_objects_manager(objects_set, self.np_random, shuffle=True)
        else:
            self.objects_manager = make_objects_manager(constants.lite_pick_objects_set, self.np_random)
        self.textures_manager = make_textures_manager([f"wood/{i}.png" for i in range(10)], self.np_random)
        self.observation_space["object_pose"] = Box(low=-np.inf, high=np.inf, shape=(16,), dtype=np.float32)
        self.observation_space["object_name"] = Text(max_length=256)

    def make_task_scene(self):
        assets_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "assets"))

        with open(os.path.join(assets_dir, "scenes", "pick.xml"), "r") as f:
            scene_xml = f.read()

        self.object = self.objects_manager.sample()
        self.secondary_objects = []
        used_objects = {self.object.name}
        scene_xml = self.object.add_to_scene_xml(scene_xml)
        # Add distractor objects, skipping duplicates and same-category objects.
        while len(used_objects) < self.num_objs:
            obj = self.objects_manager.sample(random=True)
            if obj.name not in used_objects and obj.name.split("_")[0] != self.object.name.split("_")[0]:
                self.secondary_objects.append(obj)
                used_objects.add(obj.name)
                scene_xml = obj.add_to_scene_xml(scene_xml)
        self.texture = self.textures_manager.sample(random=True)
        scene_xml = include_in_scene(scene_xml, os.path.join(assets_dir, "embodiments", self.robot.name, "model_pick.xml"))
        scene_xml = scene_xml.replace("{ASSETS_DIR}", assets_dir)
        scene_xml = scene_xml.replace("{TEXTURE_PATH}", self.texture)
        scene_xml = scene_xml.replace("{TABLE_WIDTH}", str(self.spread + 0.24))
        scene_xml = scene_xml.replace("{TABLE_HEIGHT}", str(self.spread + 0.24))
        return scene_xml

    def get_obs(self):
        obs = self.get_base_obs()
        obs["object_pose"] = self.object.get_pose(self.data).reshape(16).astype(np.float32)
        obs["object_name"] = self.object.name.split("_")[0]
        return obs

    def compute_reward(self) -> float:
        # Reward is how far (in meters) the object has been lifted from its resting height.
        object_pose = self.object.get_pose(self.data)
        lift_distance = object_pose[2, 3] - self.object.last_set_pose[2, 3]
        return lift_distance

    def reset(self, seed=None, options=None):
        super().reset(seed=seed, options=options)
        scene_xml_string = self.make_task_scene()
        self.grasped_bodies = set()
        self.grasping_object = False

        gripper_init_pose = np.eye(4)
        gripper_init_pose[:3, 3] = np.array([0.00, -0.68, 1.1])
        gripper_init_pose[:3, :3] = R.from_euler("xyz", np.array([1.3, 0.000, 0.000])).as_matrix()
        self.setup_mujoco(scene_xml_string)
        initial_positions = position_sampler(self.np_random, len(self.secondary_objects) + 1, [self.spread, self.spread], self.spread / 2, 100, start_position=gripper_init_pose[:2, 3], thickness=self.spread / 2.5)
        pose = np.eye(4)
        self.initial_robot_pose = self.robot.get_camera_pose()
        self.env_step(10)

        # Drop each object at a sampled table position with a random roll about x.
        for i, obj in enumerate([self.object] + self.secondary_objects):
            z = obj.get_center_to_bottom_z_distance(self.data) + 0.78
            pose[:2, 3] = initial_positions[i][:2]
            pose[2, 3] = z
            pose[:3, :3] = R.from_euler(
                "x", self.np_random.uniform(0, 2 * np.pi)
            ).as_matrix()
            obj.set_pose(self.model, self.data, pose)

        self.env_step(100)
        self.settle_env()
        self.env_step(100)
        for obj in [self.object] + self.secondary_objects:
            obj.last_set_pose = obj.get_pose(self.data).copy()

        self.robot.prepare(gripper_init_pose, self.env_step)
        self.env_step(100)
        self.enable_sleeping_islands()

        observation = self.get_obs()
        info = {"object_name": self.object.name}

        return observation, info

    def step(self, action):
        obs, reward, terminated, truncated, info = super().step(action)
        self.grasped_bodies.update(self.robot.get_grasped_bodies())
        grasped_bodies_list = list(self.grasped_bodies)
        if f"{self.object.name}_object" in grasped_bodies_list:
            self.grasping_object = True
        info = {
            "grasped_bodies": grasped_bodies_list,
            "object_name": self.object.name,
            "texture_name": self.texture,
            "initial_robot_pose": self.initial_robot_pose,
            "initial_object_pose": self.object.last_set_pose,
            "is_grasping": self.robot.get_grasp(),
            "gripper_current_position": self.robot.get_tcp_pose(),
            "grasping_object": self.grasping_object,
            **info
        }

        # End the episode if the camera and object drift more than 1.5x their
        # initial separation apart (e.g. the object was knocked off the table).
        if np.linalg.norm(self.robot.get_camera_pose()[:3, 3] - self.object.get_pose(self.data)[:3, 3]) > np.linalg.norm(self.initial_robot_pose[:3, 3] - self.object.last_set_pose[:3, 3]) * 1.5:
            truncated = True
            terminated = True

        return obs, reward, terminated, truncated, info
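
All three task classes follow the Gymnasium reset/step contract: reset returns (observation, info) and step returns a 5-tuple. A hypothetical driving loop, assuming the Egogym base class exposes the usual Gymnasium action_space, with a random policy standing in for a trained one and an arbitrary episode length:

from egogym.tasks.pick import PickTask

env = PickTask(robot="rum", action_space="delta", num_objs=3, seed=0)
obs, info = env.reset()
max_reward = 0.0
for _ in range(200):  # arbitrary episode length
    action = env.action_space.sample()  # stand-in for a real policy
    obs, reward, terminated, truncated, info = env.step(action)
    max_reward = max(max_reward, reward)
    if terminated or truncated:
        break
# plot_pi0_bars.py counts an episode as a success when max_reward > 0.03 (a 3 cm lift here)
print("success:", max_reward > 0.03)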