egogym-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- baselines/pi_policy.py +110 -0
- baselines/rum/__init__.py +1 -0
- baselines/rum/loss_fns/__init__.py +37 -0
- baselines/rum/loss_fns/abstract_loss_fn.py +13 -0
- baselines/rum/loss_fns/diffusion_policy_loss_fn.py +114 -0
- baselines/rum/loss_fns/rvq_loss_fn.py +104 -0
- baselines/rum/loss_fns/vqbet_loss_fn.py +202 -0
- baselines/rum/models/__init__.py +1 -0
- baselines/rum/models/bet/__init__.py +3 -0
- baselines/rum/models/bet/bet.py +347 -0
- baselines/rum/models/bet/gpt.py +277 -0
- baselines/rum/models/bet/tokenized_bet.py +454 -0
- baselines/rum/models/bet/utils.py +124 -0
- baselines/rum/models/bet/vqbet.py +410 -0
- baselines/rum/models/bet/vqvae/__init__.py +3 -0
- baselines/rum/models/bet/vqvae/residual_vq.py +346 -0
- baselines/rum/models/bet/vqvae/vector_quantize_pytorch.py +1194 -0
- baselines/rum/models/bet/vqvae/vqvae.py +313 -0
- baselines/rum/models/bet/vqvae/vqvae_utils.py +30 -0
- baselines/rum/models/custom.py +33 -0
- baselines/rum/models/encoders/__init__.py +0 -0
- baselines/rum/models/encoders/abstract_base_encoder.py +70 -0
- baselines/rum/models/encoders/identity.py +45 -0
- baselines/rum/models/encoders/timm_encoders.py +82 -0
- baselines/rum/models/policies/diffusion_policy.py +881 -0
- baselines/rum/models/policies/open_loop.py +122 -0
- baselines/rum/models/policies/simple_open_loop.py +108 -0
- baselines/rum/molmo/server.py +144 -0
- baselines/rum/policy.py +293 -0
- baselines/rum/utils/__init__.py +212 -0
- baselines/rum/utils/action_transforms.py +22 -0
- baselines/rum/utils/decord_transforms.py +135 -0
- baselines/rum/utils/rpc.py +249 -0
- baselines/rum/utils/schedulers.py +71 -0
- baselines/rum/utils/trajectory_vis.py +128 -0
- baselines/rum/utils/zmq_utils.py +281 -0
- baselines/rum_policy.py +108 -0
- egogym/__init__.py +8 -0
- egogym/assets/constants.py +1804 -0
- egogym/components/__init__.py +1 -0
- egogym/components/object.py +94 -0
- egogym/egogym.py +106 -0
- egogym/embodiments/__init__.py +10 -0
- egogym/embodiments/arms/__init__.py +4 -0
- egogym/embodiments/arms/arm.py +65 -0
- egogym/embodiments/arms/droid.py +49 -0
- egogym/embodiments/grippers/__init__.py +4 -0
- egogym/embodiments/grippers/floating_gripper.py +58 -0
- egogym/embodiments/grippers/rum.py +6 -0
- egogym/embodiments/robot.py +95 -0
- egogym/evaluate.py +216 -0
- egogym/managers/__init__.py +2 -0
- egogym/managers/objects_managers.py +30 -0
- egogym/managers/textures_manager.py +21 -0
- egogym/misc/molmo_client.py +49 -0
- egogym/misc/molmo_server.py +197 -0
- egogym/policies/__init__.py +1 -0
- egogym/policies/base_policy.py +13 -0
- egogym/scripts/analayze.py +834 -0
- egogym/scripts/plot.py +87 -0
- egogym/scripts/plot_correlation.py +392 -0
- egogym/scripts/plot_correlation_hardcoded.py +338 -0
- egogym/scripts/plot_failure.py +248 -0
- egogym/scripts/plot_failure_hardcoded.py +195 -0
- egogym/scripts/plot_failure_vlm.py +257 -0
- egogym/scripts/plot_failure_vlm_hardcoded.py +177 -0
- egogym/scripts/plot_line.py +303 -0
- egogym/scripts/plot_line_hardcoded.py +285 -0
- egogym/scripts/plot_pi0_bars.py +169 -0
- egogym/tasks/close.py +84 -0
- egogym/tasks/open.py +85 -0
- egogym/tasks/pick.py +121 -0
- egogym/utils.py +969 -0
- egogym/wrappers/__init__.py +20 -0
- egogym/wrappers/episode_monitor.py +282 -0
- egogym/wrappers/unprivileged_chatgpt.py +163 -0
- egogym/wrappers/unprivileged_gemini.py +157 -0
- egogym/wrappers/unprivileged_molmo.py +88 -0
- egogym/wrappers/unprivileged_moondream.py +121 -0
- egogym-0.1.0.dist-info/METADATA +52 -0
- egogym-0.1.0.dist-info/RECORD +83 -0
- egogym-0.1.0.dist-info/WHEEL +5 -0
- egogym-0.1.0.dist-info/top_level.txt +2 -0
egogym/wrappers/__init__.py
@@ -0,0 +1,20 @@
from egogym.wrappers.episode_monitor import EpisodeMonitor
from egogym.wrappers.unprivileged_molmo import UnprivilegedMolmo
from egogym.wrappers.unprivileged_gemini import UnprivilegedGemini
from egogym.wrappers.unprivileged_moondream import UnprivilegedMoondream
from egogym.wrappers.unprivileged_chatgpt import UnprivilegedChatGPT

VLM_WRAPPERS = {
    "molmo": UnprivilegedMolmo,
    "gemini": UnprivilegedGemini,
    "moondream": UnprivilegedMoondream,
    "chatgpt": UnprivilegedChatGPT,
}


def get_vlm_wrapper(vlm_name: str):
    if vlm_name not in VLM_WRAPPERS:
        available = ", ".join(f"'{k}'" for k in VLM_WRAPPERS.keys())
        raise ValueError(f"Unknown VLM option: '{vlm_name}'. Choose from {available}.")
    return VLM_WRAPPERS[vlm_name]
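For orientation, a short usage sketch of the registry above. It only exercises the lookup helper; the environment that a chosen wrapper would be applied to is omitted, and the "llava" name is just an example of an unregistered option.

from egogym.wrappers import VLM_WRAPPERS, get_vlm_wrapper

wrapper_cls = get_vlm_wrapper("gemini")           # -> UnprivilegedGemini
print(wrapper_cls.__name__, sorted(VLM_WRAPPERS))

try:
    get_vlm_wrapper("llava")                      # not a registered option
except ValueError as err:
    print(err)                                    # message lists the available choices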
egogym/wrappers/episode_monitor.py
@@ -0,0 +1,282 @@
import numpy as np
import csv
import os
import cv2

class EpisodeMonitor:

    def __init__(self, env, logs_dir=None, columns=None, record=False, render_freq=0, num_envs=1):
        self.env = env
        self.logs_dir = logs_dir
        self.record = record
        self.mujoco_step_counter = 0
        self.columns = columns or [
            "episode", "max_reward", "object_name", "texture_name",
            "grasped_bodies", "steps", "initial_robot_pose",
            "initial_object_pose", "is_grasping",
            "gripper_current_position", "grasping_object"
        ]

        self.is_vector = num_envs > 1
        self.render_freq = render_freq

        if self.is_vector:
            self.num_envs = num_envs
            self.global_episode_counter = 0
            self.episode_idx = np.zeros(self.num_envs, dtype=int)
            self.max_reward = np.zeros(self.num_envs)
            self.steps = np.zeros(self.num_envs, dtype=int)
            self.episode_info = [{} for _ in range(self.num_envs)]
            if self.record:
                self.video_writers = [None] * self.num_envs
                self.video_frames = [[] for _ in range(self.num_envs)]
        else:
            self.episode_idx = 0
            self.max_reward = 0
            self.steps = 0
            self.episode_info = {}
            if self.record:
                self.video_writer = None
                self.video_frames = []

        self.csv_writer = None
        self.csv_file = None

        os.makedirs(self.logs_dir, exist_ok=True)
        csv_path = os.path.join(self.logs_dir, "log.csv")
        self.csv_file = open(csv_path, 'w', newline='')
        self.csv_writer = csv.DictWriter(self.csv_file, fieldnames=self.columns, delimiter='\t')
        self.csv_writer.writeheader()

        if self.record:
            self.video_dir = os.path.join(self.logs_dir, "videos")
            os.makedirs(self.video_dir, exist_ok=True)

        self._callbacks_initialized = False
        self._callbacks_paused = False
        self._callbacks_paused = False

    def _setup_callbacks(self):
        if self._callbacks_initialized or self.render_freq <= 0:
            return

        if self.is_vector:
            # AsyncVectorEnv doesn't provide direct access to individual envs
            # Callbacks cannot be set up for vectorized environments
            pass
        else:
            def callback():
                self._capture_frame_single()
            actual_env = self.env.unwrapped
            actual_env._render_callback = callback
            actual_env._render_freq = self.render_freq
            actual_env._should_render = actual_env.render_mode == 'human'

        self._callbacks_initialized = True

    def _capture_frame_single(self):
        if self.render_freq <= 0 or self._callbacks_paused:
            return

        try:
            actual_env = self.env.unwrapped

            if actual_env._should_render:
                actual_env.render()

            if self.record:
                frame_ego = actual_env.robot.get_camera_view(actual_env.robot.camera_names[0])
                frame_exo = actual_env.robot.get_camera_view(actual_env.robot.camera_names[1])
                frame = np.concatenate((frame_ego, frame_exo), axis=1)
                frame = self._process_frame(frame)
                self.video_frames.append(frame)
        except Exception as e:
            pass

    def _capture_frame_vectorized(self, env_idx):
        if self.render_freq <= 0 or self._callbacks_paused:
            return

        try:
            env = self.env.envs[env_idx]
            actual_env = env.unwrapped

            if actual_env._should_render:
                actual_env.render()

            if self.record:
                frame_ego = actual_env.robot.get_camera_view(actual_env.robot.camera_names[0])
                frame_exo = actual_env.robot.get_camera_view(actual_env.robot.camera_names[1])
                frame = np.concatenate((frame_ego, frame_exo), axis=1)
                frame = self._process_frame(frame)
                self.video_frames[env_idx].append(frame)
        except Exception as e:
            pass

    def _process_frame(self, frame):
        if isinstance(frame, np.ndarray):
            if frame.dtype != np.uint8:
                frame = (frame * 255).astype(np.uint8) if frame.max() <= 1.0 else frame.astype(np.uint8)
            if frame.shape[-1] == 3:
                frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
        return frame

    def reset(self, **kwargs):
        self._callbacks_paused = True

        if self.is_vector:
            options = kwargs.get('options', {})
            reset_mask = options.get('reset_mask', np.ones(self.num_envs, dtype=bool))

            obs, info = self.env.reset(**kwargs)
            self._setup_callbacks()

            for i in np.where(reset_mask)[0]:
                if self.steps[i] == 0 and self.record:
                    self.video_frames[i] = []
                self.max_reward[i] = 0
                self.steps[i] = 0
                self.episode_info[i] = {}

            self._callbacks_paused = False
        else:
            if self.steps > 0:
                if self.csv_writer:
                    self._log_episode()
                if self.record:
                    self._save_video()
                self.episode_idx += 1
            obs, info = self.env.reset(**kwargs)
            self._setup_callbacks()

            self.max_reward = 0
            self.steps = 0
            self.episode_info = {}
            if self.record:
                self.video_frames = []
            self._callbacks_paused = False
        return obs, info

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)

        if self.render_freq == 0:
            if self.is_vector:
                # AsyncVectorEnv doesn't provide direct access to individual envs
                # Recording from observations if available
                if self.record and 'rgb_ego' in obs and 'rgb_exo' in obs:
                    for i in range(self.num_envs):
                        frame_ego = obs['rgb_ego'][i]
                        frame_exo = obs['rgb_exo'][i]
                        frame = np.concatenate((frame_ego, frame_exo), axis=1)
                        frame = self._process_frame(frame)
                        self.video_frames[i].append(frame)
            else:
                actual_env = self.env.unwrapped
                if actual_env._should_render:
                    actual_env.render()
                if self.record and 'rgb_ego' in obs and 'rgb_exo' in obs:
                    frame_ego = obs['rgb_ego']
                    frame_exo = obs['rgb_exo']
                    frame = np.concatenate((frame_ego, frame_exo), axis=1)
                    frame = self._process_frame(frame)
                    self.video_frames.append(frame)

        if self.is_vector:
            self.max_reward = np.maximum(self.max_reward, reward)
            self.steps += 1
            for i in range(self.num_envs):
                if isinstance(info, dict):
                    env_info = {}
                    for k, v in info.items():
                        if k.startswith('_'):
                            continue
                        if isinstance(v, (list, np.ndarray)) and len(v) > i:
                            env_info[k] = v[i]
                        else:
                            env_info[k] = v
                else:
                    env_info = info[i] if isinstance(info, (list, tuple)) else {}
                self.episode_info[i].update(env_info)

            dones = np.logical_or(terminated, truncated)
            if isinstance(dones, bool):
                dones = np.array([dones] * self.num_envs)
        else:
            self.max_reward = max(self.max_reward, reward)
            self.steps += 1
            self.episode_info.update(info)
            info['episode_stats'] = {'episode': self.episode_idx, 'max_reward': self.max_reward, 'steps': self.steps}

        return obs, reward, terminated, truncated, info

    def _log_episode(self, env_idx=None):
        if env_idx is not None:
            log_data = {"episode": self.global_episode_counter, "max_reward": float(self.max_reward[env_idx]), "steps": int(self.steps[env_idx]), **self.episode_info[env_idx]}
            self.global_episode_counter += 1
        else:
            log_data = {"episode": self.episode_idx, "max_reward": self.max_reward, "steps": self.steps, **self.episode_info}

        filtered_data = {}
        for col in self.columns:
            value = log_data.get(col, "")
            if isinstance(value, np.ndarray):
                value = value.flatten().tolist()
            filtered_data[col] = value

        self.csv_writer.writerow(filtered_data)
        self.csv_file.flush()

    def log_episodes(self, env_indices):
        if self.is_vector:
            for i in env_indices:
                if self.csv_writer:
                    self._log_episode(i)
                if self.record:
                    self._save_video(i)
                    self.video_frames[i] = []
                self.episode_idx[i] += 1

    def _save_video(self, env_idx=None):
        if env_idx is not None:
            frames = self.video_frames[env_idx]
            episode_num = int(self.episode_idx[env_idx])
            video_filename = f"env_{env_idx}_episode_{episode_num}.mp4"
        else:
            frames = self.video_frames
            episode_num = self.episode_idx
            video_filename = f"episode_{episode_num}.mp4"

        if not frames:
            return

        video_path = os.path.join(self.video_dir, video_filename)
        height, width = frames[0].shape[:2]
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(video_path, fourcc, 30, (width, height))

        for frame in frames:
            out.write(frame)

        out.release()

    def close(self):
        if self.is_vector:
            for i in range(self.num_envs):
                if self.steps[i] > 0:
                    if self.csv_writer:
                        self._log_episode(i)
                    if self.record:
                        self._save_video(i)
        else:
            if self.steps > 0:
                if self.csv_writer:
                    self._log_episode()
                if self.record:
                    self._save_video()
        if self.csv_file:
            self.csv_file.close()
        self.env.close()

    def __getattr__(self, name):
        return getattr(self.env, name)
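A minimal, self-contained sketch of how EpisodeMonitor could be driven. The DummyEnv stand-in, the temporary logs directory, and the zero-image observations are illustrative assumptions rather than part of the package; a real egogym environment also exposes `unwrapped` and `_should_render`, which is why the stub provides them.

import tempfile
import numpy as np
from egogym.wrappers import EpisodeMonitor

class DummyEnv:
    """Hypothetical stand-in for an egogym environment."""
    _should_render = False          # EpisodeMonitor checks this before rendering

    @property
    def unwrapped(self):
        return self                 # no inner wrappers to peel off

    def reset(self, **kwargs):
        return {"rgb_ego": np.zeros((4, 4, 3))}, {"object_name": "mug"}

    def step(self, action):
        return {"rgb_ego": np.zeros((4, 4, 3))}, 1.0, False, False, {"object_name": "mug"}

    def close(self):
        pass

logs_dir = tempfile.mkdtemp()
env = EpisodeMonitor(DummyEnv(), logs_dir=logs_dir, record=False)
obs, info = env.reset()
for _ in range(3):
    obs, reward, terminated, truncated, info = env.step(None)
print(info["episode_stats"])        # {'episode': 0, 'max_reward': 1.0, 'steps': 3}
env.close()                         # flushes the tab-separated log.csv in logs_dir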
egogym/wrappers/unprivileged_chatgpt.py
@@ -0,0 +1,163 @@
import os
import numpy as np
import time
import re
import io
import json
import base64
from PIL import Image
from egogym.utils import pixel_to_world

MODEL = "gpt-4o"
PROMPT = """Get all points matching the following object: {object_name}. The label returned should be an identifying name for the object detected.
The answer should follow the json format: [{{"point": [y, x], "label": "{object_name}"}}, ...]. The points are in [y, x] format normalized to 0-1000."""

class UnprivilegedChatGPT:

    def __init__(self, env, api_key=None):
        try:
            from openai import OpenAI
        except ImportError:
            raise ImportError("openai package is not installed. Please install it with `pip install openai`.")

        self.env = env
        self.unprivileged_T_world_object = None

        if api_key is None:
            api_key = os.environ.get("OPENAI_API_KEY")

        self.client = OpenAI(api_key=api_key)

    def _call_api_with_retry(self, messages, max_retries=3, base_delay=2):
        """Call OpenAI API with retry logic for transient errors."""
        for attempt in range(max_retries):
            try:
                return self.client.chat.completions.create(
                    model=MODEL,
                    messages=messages,
                    temperature=0.5,
                    max_tokens=1024
                )
            except Exception as e:
                error_code = getattr(e, 'status_code', None)
                if error_code in [503, 429] and attempt < max_retries - 1:
                    delay = base_delay * (2 ** attempt)
                    print(f"API error {error_code}: {str(e)}. Retrying in {delay}s... (attempt {attempt + 1}/{max_retries})")
                    time.sleep(delay)
                    continue
                else:
                    raise

    def _parse_json_point_yx(self, text):
        """Parse JSON format with point coordinates in [y, x] format normalized to 0-1000."""
        try:
            json_match = re.search(r'```(?:json)?\s*(\[.*?\])\s*```', text, re.DOTALL)
            if json_match:
                json_str = json_match.group(1)
            else:
                json_match = re.search(r'(\[.*?\])', text, re.DOTALL)
                if json_match:
                    json_str = json_match.group(1)
                else:
                    return None

            data = json.loads(json_str)
            if not isinstance(data, list) or len(data) == 0:
                return None

            point_data = data[0]
            if 'point' in point_data:
                point = point_data['point']
                if isinstance(point, list) and len(point) >= 2:
                    return (float(point[0]), float(point[1]))

            return None

        except (json.JSONDecodeError, KeyError, ValueError, IndexError) as e:
            return None

    def _infer_point_sync(self, rgb: np.ndarray, object_name: str) -> np.ndarray:
        """Use OpenAI GPT-4o API to locate object in image and return normalized coordinates."""
        if rgb.dtype == np.float32 or rgb.dtype == np.float64:
            rgb = (rgb * 255).astype(np.uint8)
        image = Image.fromarray(rgb)
        img_width, img_height = image.size

        prompt = PROMPT.format(object_name=object_name)

        try:
            buf = io.BytesIO()
            image.save(buf, format='PNG')
            image_bytes = buf.getvalue()
            base64_image = base64.b64encode(image_bytes).decode('utf-8')

            messages = [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/png;base64,{base64_image}"
                            }
                        },
                        {
                            "type": "text",
                            "text": prompt
                        }
                    ]
                }
            ]

            response = self._call_api_with_retry(messages)
            generated_text = response.choices[0].message.content

            json_point = self._parse_json_point_yx(generated_text)
            if json_point:
                y_pos, x_pos = json_point
                x_norm = float(x_pos) / 1000.0
                y_norm = float(y_pos) / 1000.0

                x_norm = max(0.0, min(1.0, x_norm))
                y_norm = max(0.0, min(1.0, y_norm))

                return np.array([x_norm, y_norm], dtype=np.float32)

            return np.array([0.5, 0.5], dtype=np.float32)

        except Exception as e:
            return np.array([0.5, 0.5], dtype=np.float32)

    def reset(self, **kwargs):
        obs, info = self.env.reset(**kwargs)

        robot = self.env.unwrapped.robot
        robot.rgb_renderers[robot.camera_names[0]].enable_depth_rendering()
        robot.rgb_renderers[robot.camera_names[0]].update_scene(robot.data, robot.camera_names[0])
        depth = robot.rgb_renderers[robot.camera_names[0]].render()
        robot.rgb_renderers[robot.camera_names[0]].disable_depth_rendering()

        point_norm = self._infer_point_sync(obs["rgb_ego"], obs["object_name"])
        x_norm, y_norm = point_norm
        x = int(x_norm * self.env.unwrapped.render_width)
        y = int(y_norm * self.env.unwrapped.render_height)
        depth_value = depth[y, x] + 0.03
        self.unprivileged_T_world_object = pixel_to_world(x, y, depth_value, self.env.unwrapped.model, self.env.unwrapped.data, robot.camera_names[0], self.env.unwrapped.render_width, self.env.unwrapped.render_height)

        return obs, info

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)

        object_pose = obs["object_pose"].reshape(4, 4)
        object_pose[:3, 3] = self.unprivileged_T_world_object
        obs["object_pose"] = object_pose.flatten()

        return obs, reward, terminated, truncated, info

    def close(self):
        self.env.close()

    def __getattr__(self, name):
        return getattr(self.env, name)
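To illustrate the coordinate convention that PROMPT requests, here is a small sketch that feeds a made-up GPT-4o reply through the wrapper's parser and rescales the [y, x] point the same way _infer_point_sync and reset() do. The reply string and the 640x480 image size are assumptions for illustration only.

from egogym.wrappers.unprivileged_chatgpt import UnprivilegedChatGPT

# A made-up GPT-4o reply in the format requested by PROMPT above.
reply = '```json\n[{"point": [420, 615], "label": "mug"}]\n```'

# _parse_json_point_yx never touches `self`, so it can be exercised directly.
y_raw, x_raw = UnprivilegedChatGPT._parse_json_point_yx(None, reply)

# Same rescaling as _infer_point_sync and reset(): 0-1000 grid -> [0, 1] -> pixels.
x_norm, y_norm = x_raw / 1000.0, y_raw / 1000.0
render_width, render_height = 640, 480    # illustrative image size
x_px, y_px = int(x_norm * render_width), int(y_norm * render_height)
print((x_px, y_px))                        # (393, 201)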
egogym/wrappers/unprivileged_gemini.py
@@ -0,0 +1,157 @@
import os
import numpy as np
import time
import re
import io
import json
from PIL import Image
from egogym.utils import pixel_to_world

MODEL = "gemini-robotics-er-1.5-preview"
PROMPT = """Get all points matching the following object: {object_name}. The label returned should be an identifying name for the object detected.
The answer should follow the json format: [{{"point": [y, x], "label": "{object_name}"}}, ...]. The points are in [y, x] format normalized to 0-1000."""

class UnprivilegedGemini:

    def __init__(self, env, api_key=None):
        try:
            from google import genai
            from google.genai import errors, types
        except ImportError:
            raise ImportError("google-genai package is not installed. Please install it with `pip install google-genai`.")

        # Keep the lazily imported modules available to the other methods.
        self._genai_errors = errors
        self._genai_types = types

        self.env = env
        self.unprivileged_T_world_object = None

        if api_key is None:
            api_key = os.getenv("GEMINI_API_KEY")

        self.client = genai.Client(api_key=api_key)

    def _call_api_with_retry(self, model, contents, config, max_retries=3, base_delay=2):
        """Call Gemini API with retry logic for transient errors."""
        for attempt in range(max_retries):
            try:
                return self.client.models.generate_content(
                    model=model,
                    contents=contents,
                    config=config
                )
            except self._genai_errors.ServerError as e:
                if e.code in [503, 429] and attempt < max_retries - 1:
                    delay = base_delay * (2 ** attempt)
                    error_msg = e.message if e.message else "Service unavailable"
                    print(f"API error {e.code}: {error_msg}. Retrying in {delay}s... (attempt {attempt + 1}/{max_retries})")
                    time.sleep(delay)
                    continue
                else:
                    raise
            except Exception as e:
                raise

    def _parse_json_point_yx(self, text):
        """Parse JSON format with point coordinates in [y, x] format normalized to 0-1000."""
        try:
            json_match = re.search(r'```(?:json)?\s*(\[.*?\])\s*```', text, re.DOTALL)
            if json_match:
                json_str = json_match.group(1)
            else:
                json_match = re.search(r'(\[.*?\])', text, re.DOTALL)
                if json_match:
                    json_str = json_match.group(1)
                else:
                    return None

            data = json.loads(json_str)
            if not isinstance(data, list) or len(data) == 0:
                return None

            point_data = data[0]
            if 'point' in point_data:
                point = point_data['point']
                if isinstance(point, list) and len(point) >= 2:
                    return (float(point[0]), float(point[1]))

            return None

        except (json.JSONDecodeError, KeyError, ValueError, IndexError) as e:
            return None

    def _infer_point_sync(self, rgb: np.ndarray, object_name: str) -> np.ndarray:
        """Use Gemini API to locate object in image and return normalized coordinates."""
        types = self._genai_types  # lazily imported in __init__
        if rgb.dtype == np.float32 or rgb.dtype == np.float64:
            rgb = (rgb * 255).astype(np.uint8)
        image = Image.fromarray(rgb)
        img_width, img_height = image.size

        prompt = PROMPT.format(object_name=object_name)

        try:
            buf = io.BytesIO()
            image.save(buf, format='PNG')
            image_bytes = buf.getvalue()

            contents = [
                types.Part.from_bytes(
                    data=image_bytes,
                    mime_type='image/png',
                ),
                prompt
            ]

            config = types.GenerateContentConfig(
                temperature=0.5,
                thinking_config=types.ThinkingConfig(thinking_budget=0)
            )
            response = self._call_api_with_retry(MODEL, contents, config)
            generated_text = response.text

            json_point = self._parse_json_point_yx(generated_text)
            if json_point:
                y_pos, x_pos = json_point
                x_norm = float(x_pos) / 1000.0
                y_norm = float(y_pos) / 1000.0

                x_norm = max(0.0, min(1.0, x_norm))
                y_norm = max(0.0, min(1.0, y_norm))

                return np.array([x_norm, y_norm], dtype=np.float32)

            return np.array([0.5, 0.5], dtype=np.float32)

        except Exception as e:
            return np.array([0.5, 0.5], dtype=np.float32)

    def reset(self, **kwargs):
        obs, info = self.env.reset(**kwargs)

        robot = self.env.unwrapped.robot
        robot.rgb_renderers[robot.camera_names[0]].enable_depth_rendering()
        robot.rgb_renderers[robot.camera_names[0]].update_scene(robot.data, robot.camera_names[0])
        depth = robot.rgb_renderers[robot.camera_names[0]].render()
        robot.rgb_renderers[robot.camera_names[0]].disable_depth_rendering()

        point_norm = self._infer_point_sync(obs["rgb_ego"], obs["object_name"])
        x_norm, y_norm = point_norm
        x = int(x_norm * self.env.unwrapped.render_width)
        y = int(y_norm * self.env.unwrapped.render_height)
        depth_value = depth[y, x] + 0.03
        self.unprivileged_T_world_object = pixel_to_world(x, y, depth_value, self.env.unwrapped.model, self.env.unwrapped.data, robot.camera_names[0], self.env.unwrapped.render_width, self.env.unwrapped.render_height)

        return obs, info

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)

        object_pose = obs["object_pose"].reshape(4, 4)
        object_pose[:3, 3] = self.unprivileged_T_world_object
        obs["object_pose"] = object_pose.flatten()

        return obs, reward, terminated, truncated, info

    def close(self):
        self.env.close()

    def __getattr__(self, name):
        return getattr(self.env, name)
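Finally, a hypothetical wiring sketch showing how a wrapper chosen from the registry could replace privileged object poses with VLM-estimated ones. The make_env call and the "pick" task name are placeholders; only get_vlm_wrapper, the wrapper constructor signature, and the object_pose override come from the code in this diff.

from egogym.wrappers import get_vlm_wrapper

def wrap_with_vlm(env, vlm_name: str):
    """Wrap an egogym env so object poses come from a VLM instead of the simulator."""
    wrapper_cls = get_vlm_wrapper(vlm_name)   # e.g. UnprivilegedGemini
    return wrapper_cls(env)                   # reset() queries the VLM once per episode

# env = wrap_with_vlm(make_env("pick"), "gemini")   # make_env is a placeholder
# obs, info = env.reset()
# obs, reward, terminated, truncated, info = env.step(action)
# # step() overwrites obs["object_pose"][:3, 3] with the VLM-estimated world point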