egogym 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. baselines/pi_policy.py +110 -0
  2. baselines/rum/__init__.py +1 -0
  3. baselines/rum/loss_fns/__init__.py +37 -0
  4. baselines/rum/loss_fns/abstract_loss_fn.py +13 -0
  5. baselines/rum/loss_fns/diffusion_policy_loss_fn.py +114 -0
  6. baselines/rum/loss_fns/rvq_loss_fn.py +104 -0
  7. baselines/rum/loss_fns/vqbet_loss_fn.py +202 -0
  8. baselines/rum/models/__init__.py +1 -0
  9. baselines/rum/models/bet/__init__.py +3 -0
  10. baselines/rum/models/bet/bet.py +347 -0
  11. baselines/rum/models/bet/gpt.py +277 -0
  12. baselines/rum/models/bet/tokenized_bet.py +454 -0
  13. baselines/rum/models/bet/utils.py +124 -0
  14. baselines/rum/models/bet/vqbet.py +410 -0
  15. baselines/rum/models/bet/vqvae/__init__.py +3 -0
  16. baselines/rum/models/bet/vqvae/residual_vq.py +346 -0
  17. baselines/rum/models/bet/vqvae/vector_quantize_pytorch.py +1194 -0
  18. baselines/rum/models/bet/vqvae/vqvae.py +313 -0
  19. baselines/rum/models/bet/vqvae/vqvae_utils.py +30 -0
  20. baselines/rum/models/custom.py +33 -0
  21. baselines/rum/models/encoders/__init__.py +0 -0
  22. baselines/rum/models/encoders/abstract_base_encoder.py +70 -0
  23. baselines/rum/models/encoders/identity.py +45 -0
  24. baselines/rum/models/encoders/timm_encoders.py +82 -0
  25. baselines/rum/models/policies/diffusion_policy.py +881 -0
  26. baselines/rum/models/policies/open_loop.py +122 -0
  27. baselines/rum/models/policies/simple_open_loop.py +108 -0
  28. baselines/rum/molmo/server.py +144 -0
  29. baselines/rum/policy.py +293 -0
  30. baselines/rum/utils/__init__.py +212 -0
  31. baselines/rum/utils/action_transforms.py +22 -0
  32. baselines/rum/utils/decord_transforms.py +135 -0
  33. baselines/rum/utils/rpc.py +249 -0
  34. baselines/rum/utils/schedulers.py +71 -0
  35. baselines/rum/utils/trajectory_vis.py +128 -0
  36. baselines/rum/utils/zmq_utils.py +281 -0
  37. baselines/rum_policy.py +108 -0
  38. egogym/__init__.py +8 -0
  39. egogym/assets/constants.py +1804 -0
  40. egogym/components/__init__.py +1 -0
  41. egogym/components/object.py +94 -0
  42. egogym/egogym.py +106 -0
  43. egogym/embodiments/__init__.py +10 -0
  44. egogym/embodiments/arms/__init__.py +4 -0
  45. egogym/embodiments/arms/arm.py +65 -0
  46. egogym/embodiments/arms/droid.py +49 -0
  47. egogym/embodiments/grippers/__init__.py +4 -0
  48. egogym/embodiments/grippers/floating_gripper.py +58 -0
  49. egogym/embodiments/grippers/rum.py +6 -0
  50. egogym/embodiments/robot.py +95 -0
  51. egogym/evaluate.py +216 -0
  52. egogym/managers/__init__.py +2 -0
  53. egogym/managers/objects_managers.py +30 -0
  54. egogym/managers/textures_manager.py +21 -0
  55. egogym/misc/molmo_client.py +49 -0
  56. egogym/misc/molmo_server.py +197 -0
  57. egogym/policies/__init__.py +1 -0
  58. egogym/policies/base_policy.py +13 -0
  59. egogym/scripts/analayze.py +834 -0
  60. egogym/scripts/plot.py +87 -0
  61. egogym/scripts/plot_correlation.py +392 -0
  62. egogym/scripts/plot_correlation_hardcoded.py +338 -0
  63. egogym/scripts/plot_failure.py +248 -0
  64. egogym/scripts/plot_failure_hardcoded.py +195 -0
  65. egogym/scripts/plot_failure_vlm.py +257 -0
  66. egogym/scripts/plot_failure_vlm_hardcoded.py +177 -0
  67. egogym/scripts/plot_line.py +303 -0
  68. egogym/scripts/plot_line_hardcoded.py +285 -0
  69. egogym/scripts/plot_pi0_bars.py +169 -0
  70. egogym/tasks/close.py +84 -0
  71. egogym/tasks/open.py +85 -0
  72. egogym/tasks/pick.py +121 -0
  73. egogym/utils.py +969 -0
  74. egogym/wrappers/__init__.py +20 -0
  75. egogym/wrappers/episode_monitor.py +282 -0
  76. egogym/wrappers/unprivileged_chatgpt.py +163 -0
  77. egogym/wrappers/unprivileged_gemini.py +157 -0
  78. egogym/wrappers/unprivileged_molmo.py +88 -0
  79. egogym/wrappers/unprivileged_moondream.py +121 -0
  80. egogym-0.1.0.dist-info/METADATA +52 -0
  81. egogym-0.1.0.dist-info/RECORD +83 -0
  82. egogym-0.1.0.dist-info/WHEEL +5 -0
  83. egogym-0.1.0.dist-info/top_level.txt +2 -0
@@ -0,0 +1,20 @@
1
from egogym.wrappers.episode_monitor import EpisodeMonitor
from egogym.wrappers.unprivileged_molmo import UnprivilegedMolmo
from egogym.wrappers.unprivileged_gemini import UnprivilegedGemini
from egogym.wrappers.unprivileged_moondream import UnprivilegedMoondream
from egogym.wrappers.unprivileged_chatgpt import UnprivilegedChatGPT

# Registry mapping a VLM identifier string to its unprivileged wrapper class.
VLM_WRAPPERS = {
    "molmo": UnprivilegedMolmo,
    "gemini": UnprivilegedGemini,
    "moondream": UnprivilegedMoondream,
    "chatgpt": UnprivilegedChatGPT,
}


def get_vlm_wrapper(vlm_name: str):
    """Look up the unprivileged wrapper class registered under *vlm_name*.

    Raises:
        ValueError: if *vlm_name* is not one of the registered VLM options.
    """
    if vlm_name in VLM_WRAPPERS:
        return VLM_WRAPPERS[vlm_name]
    available = ", ".join(f"'{k}'" for k in VLM_WRAPPERS)
    raise ValueError(f"Unknown VLM option: '{vlm_name}'. Choose from {available}.")
@@ -0,0 +1,282 @@
1
+ import numpy as np
2
+ import csv
3
+ import os
4
+ import cv2
5
+
6
class EpisodeMonitor:
    """Env wrapper that tracks per-episode statistics, writes them to a
    tab-separated ``log.csv``, and optionally records ego/exo camera videos.

    Works with a single environment or a vectorized one (``num_envs > 1``);
    in the vectorized case all bookkeeping is kept per sub-environment.
    """

    def __init__(self, env, logs_dir=None, columns=None, record=False, render_freq=0, num_envs=1):
        """
        Args:
            env: the environment (or vector env) to wrap.
            logs_dir: directory for ``log.csv`` and videos (defaults to ``"logs"``).
            columns: column names for the TSV header; a default set is used when None.
            record: when True, buffer frames and write per-episode ``.mp4`` files.
            render_freq: >0 enables callback-driven frame capture instead of
                capturing frames from observations inside :meth:`step`.
            num_envs: number of sub-environments; >1 switches to vectorized mode.
        """
        self.env = env
        # Fix: the original passed logs_dir=None straight to os.makedirs,
        # which raises TypeError; fall back to a local "logs" directory.
        self.logs_dir = logs_dir if logs_dir is not None else "logs"
        self.record = record
        self.mujoco_step_counter = 0
        self.columns = columns or [
            "episode", "max_reward", "object_name", "texture_name",
            "grasped_bodies", "steps", "initial_robot_pose",
            "initial_object_pose", "is_grasping",
            "gripper_current_position", "grasping_object"
        ]

        self.is_vector = num_envs > 1
        self.render_freq = render_freq

        if self.is_vector:
            self.num_envs = num_envs
            self.global_episode_counter = 0
            self.episode_idx = np.zeros(self.num_envs, dtype=int)
            self.max_reward = np.zeros(self.num_envs)
            self.steps = np.zeros(self.num_envs, dtype=int)
            self.episode_info = [{} for _ in range(self.num_envs)]
            if self.record:
                self.video_writers = [None] * self.num_envs
                self.video_frames = [[] for _ in range(self.num_envs)]
        else:
            self.episode_idx = 0
            self.max_reward = 0
            self.steps = 0
            self.episode_info = {}
            if self.record:
                self.video_writer = None
                self.video_frames = []

        os.makedirs(self.logs_dir, exist_ok=True)
        csv_path = os.path.join(self.logs_dir, "log.csv")
        self.csv_file = open(csv_path, 'w', newline='')
        # Tab-delimited so list-valued cells (poses, positions) stay in one column.
        self.csv_writer = csv.DictWriter(self.csv_file, fieldnames=self.columns, delimiter='\t')
        self.csv_writer.writeheader()

        if self.record:
            self.video_dir = os.path.join(self.logs_dir, "videos")
            os.makedirs(self.video_dir, exist_ok=True)

        self._callbacks_initialized = False
        # Fix: the original assigned this flag twice in a row.
        self._callbacks_paused = False

    def _setup_callbacks(self):
        """Install a render callback on the wrapped env (single-env only)."""
        if self._callbacks_initialized or self.render_freq <= 0:
            return

        if not self.is_vector:
            # AsyncVectorEnv gives no direct access to its sub-envs, so
            # callbacks can only be wired up for a single environment.
            def callback():
                self._capture_frame_single()
            actual_env = self.env.unwrapped
            actual_env._render_callback = callback
            actual_env._render_freq = self.render_freq
            actual_env._should_render = actual_env.render_mode == 'human'

        self._callbacks_initialized = True

    def _capture_frame_single(self):
        """Render and (optionally) buffer one ego+exo frame for a single env."""
        if self.render_freq <= 0 or self._callbacks_paused:
            return

        try:
            actual_env = self.env.unwrapped

            if actual_env._should_render:
                actual_env.render()

            if self.record:
                frame_ego = actual_env.robot.get_camera_view(actual_env.robot.camera_names[0])
                frame_exo = actual_env.robot.get_camera_view(actual_env.robot.camera_names[1])
                frame = np.concatenate((frame_ego, frame_exo), axis=1)
                self.video_frames.append(self._process_frame(frame))
        except Exception:
            # Best-effort capture: a failed render must never kill the episode.
            pass

    def _capture_frame_vectorized(self, env_idx):
        """Render and (optionally) buffer one frame for sub-env *env_idx*."""
        if self.render_freq <= 0 or self._callbacks_paused:
            return

        try:
            actual_env = self.env.envs[env_idx].unwrapped

            if actual_env._should_render:
                actual_env.render()

            if self.record:
                frame_ego = actual_env.robot.get_camera_view(actual_env.robot.camera_names[0])
                frame_exo = actual_env.robot.get_camera_view(actual_env.robot.camera_names[1])
                frame = np.concatenate((frame_ego, frame_exo), axis=1)
                self.video_frames[env_idx].append(self._process_frame(frame))
        except Exception:
            # Best-effort capture: a failed render must never kill the episode.
            pass

    def _process_frame(self, frame):
        """Convert *frame* to uint8 BGR as expected by ``cv2.VideoWriter``."""
        if isinstance(frame, np.ndarray):
            if frame.dtype != np.uint8:
                # Heuristic: floats in [0, 1] are rescaled; anything else is cast.
                frame = (frame * 255).astype(np.uint8) if frame.max() <= 1.0 else frame.astype(np.uint8)
            if frame.shape[-1] == 3:
                frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
        return frame

    def reset(self, **kwargs):
        """Reset the env; finalize (log/save) the finished episode in single-env mode."""
        # Pause capture so callbacks fired mid-reset don't record stale frames.
        self._callbacks_paused = True

        if self.is_vector:
            options = kwargs.get('options', {})
            reset_mask = options.get('reset_mask', np.ones(self.num_envs, dtype=bool))

            obs, info = self.env.reset(**kwargs)
            self._setup_callbacks()

            for i in np.where(reset_mask)[0]:
                if self.steps[i] == 0 and self.record:
                    self.video_frames[i] = []
                self.max_reward[i] = 0
                self.steps[i] = 0
                self.episode_info[i] = {}

            self._callbacks_paused = False
        else:
            if self.steps > 0:
                # A previous episode ran: log it and persist its video.
                if self.csv_writer:
                    self._log_episode()
                if self.record:
                    self._save_video()
                self.episode_idx += 1
            obs, info = self.env.reset(**kwargs)
            self._setup_callbacks()

            self.max_reward = 0
            self.steps = 0
            self.episode_info = {}
            if self.record:
                self.video_frames = []
            self._callbacks_paused = False
        return obs, info

    def step(self, action):
        """Step the env, accumulate episode stats, and optionally buffer frames."""
        obs, reward, terminated, truncated, info = self.env.step(action)

        if self.render_freq == 0:
            # No render callbacks: capture frames straight from observations.
            if self.is_vector:
                if self.record and 'rgb_ego' in obs and 'rgb_exo' in obs:
                    for i in range(self.num_envs):
                        frame = np.concatenate((obs['rgb_ego'][i], obs['rgb_exo'][i]), axis=1)
                        self.video_frames[i].append(self._process_frame(frame))
            else:
                actual_env = self.env.unwrapped
                if actual_env._should_render:
                    actual_env.render()
                if self.record and 'rgb_ego' in obs and 'rgb_exo' in obs:
                    frame = np.concatenate((obs['rgb_ego'], obs['rgb_exo']), axis=1)
                    self.video_frames.append(self._process_frame(frame))

        if self.is_vector:
            self.max_reward = np.maximum(self.max_reward, reward)
            self.steps += 1
            for i in range(self.num_envs):
                if isinstance(info, dict):
                    # Gymnasium-style vector info: per-env values stacked under
                    # shared keys; keys starting with "_" are presence masks.
                    env_info = {}
                    for k, v in info.items():
                        if k.startswith('_'):
                            continue
                        if isinstance(v, (list, np.ndarray)) and len(v) > i:
                            env_info[k] = v[i]
                        else:
                            env_info[k] = v
                else:
                    env_info = info[i] if isinstance(info, (list, tuple)) else {}
                self.episode_info[i].update(env_info)
            # Fix: the original built a `dones` array here and never used it.
        else:
            self.max_reward = max(self.max_reward, reward)
            self.steps += 1
            self.episode_info.update(info)
            info['episode_stats'] = {'episode': self.episode_idx, 'max_reward': self.max_reward, 'steps': self.steps}

        return obs, reward, terminated, truncated, info

    def _log_episode(self, env_idx=None):
        """Write one episode row to the TSV; *env_idx* selects a sub-env."""
        if env_idx is not None:
            log_data = {"episode": self.global_episode_counter, "max_reward": float(self.max_reward[env_idx]), "steps": int(self.steps[env_idx]), **self.episode_info[env_idx]}
            self.global_episode_counter += 1
        else:
            log_data = {"episode": self.episode_idx, "max_reward": self.max_reward, "steps": self.steps, **self.episode_info}

        # Keep only the configured columns; flatten arrays so cells stay scalar-ish.
        filtered_data = {}
        for col in self.columns:
            value = log_data.get(col, "")
            if isinstance(value, np.ndarray):
                value = value.flatten().tolist()
            filtered_data[col] = value

        self.csv_writer.writerow(filtered_data)
        self.csv_file.flush()

    def log_episodes(self, env_indices):
        """Finalize (log + save video + bump episode counter) the given sub-envs."""
        if self.is_vector:
            for i in env_indices:
                if self.csv_writer:
                    self._log_episode(i)
                if self.record:
                    self._save_video(i)
                    self.video_frames[i] = []
                self.episode_idx[i] += 1

    def _save_video(self, env_idx=None):
        """Encode the buffered frames of one episode into an ``.mp4`` file."""
        if env_idx is not None:
            frames = self.video_frames[env_idx]
            episode_num = int(self.episode_idx[env_idx])
            video_filename = f"env_{env_idx}_episode_{episode_num}.mp4"
        else:
            frames = self.video_frames
            episode_num = self.episode_idx
            video_filename = f"episode_{episode_num}.mp4"

        if not frames:
            return

        video_path = os.path.join(self.video_dir, video_filename)
        height, width = frames[0].shape[:2]
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(video_path, fourcc, 30, (width, height))
        try:
            for frame in frames:
                out.write(frame)
        finally:
            # Fix: release the encoder even if a write fails, so the file
            # handle is not leaked and the container gets finalized.
            out.release()

    def close(self):
        """Flush any in-progress episode, close the TSV file, and close the env."""
        if self.is_vector:
            for i in range(self.num_envs):
                if self.steps[i] > 0:
                    if self.csv_writer:
                        self._log_episode(i)
                    if self.record:
                        self._save_video(i)
        else:
            if self.steps > 0:
                if self.csv_writer:
                    self._log_episode()
                if self.record:
                    self._save_video()
        if self.csv_file:
            self.csv_file.close()
        self.env.close()

    def __getattr__(self, name):
        # Delegate everything this wrapper does not define to the wrapped env.
        return getattr(self.env, name)
@@ -0,0 +1,163 @@
1
+ import os
2
+ import numpy as np
3
+ import time
4
+ import re
5
+ import io
6
+ import json
7
+ import base64
8
+ from PIL import Image
9
+ from egogym.utils import pixel_to_world
10
+
11
MODEL = "gpt-4o"
PROMPT = """Get all points matching the following object: {object_name}. The label returned should be an identifying name for the object detected.
The answer should follow the json format: [{{"point": [y, x], "label": "{object_name}"}}, ...]. The points are in [y, x] format normalized to 0-1000."""

class UnprivilegedChatGPT:
    """Env wrapper that replaces the privileged object position in
    ``obs["object_pose"]`` with one estimated by GPT-4o from the ego camera.

    On every :meth:`reset` the ego RGB image is sent to the OpenAI API, the
    returned 2D point is back-projected to a 3D world position using the
    rendered depth image, and that position overwrites the object-pose
    translation on every subsequent :meth:`step`.
    """

    def __init__(self, env, api_key=None):
        """
        Args:
            env: environment to wrap.
            api_key: OpenAI API key; falls back to the OPENAI_API_KEY env var.

        Raises:
            ImportError: if the ``openai`` package is not installed.
        """
        try:
            from openai import OpenAI
        except ImportError:
            raise ImportError("openai package is not installed. Please install it with `pip install openai`.")

        self.env = env
        self.unprivileged_T_world_object = None

        if api_key is None:
            api_key = os.environ.get("OPENAI_API_KEY")

        self.client = OpenAI(api_key=api_key)

    def _call_api_with_retry(self, messages, max_retries=3, base_delay=2):
        """Call the chat-completions API, retrying 503/429 with exponential backoff."""
        for attempt in range(max_retries):
            try:
                return self.client.chat.completions.create(
                    model=MODEL,
                    messages=messages,
                    temperature=0.5,
                    max_tokens=1024
                )
            except Exception as e:
                error_code = getattr(e, 'status_code', None)
                if error_code in [503, 429] and attempt < max_retries - 1:
                    delay = base_delay * (2 ** attempt)
                    print(f"API error {error_code}: {str(e)}. Retrying in {delay}s... (attempt {attempt + 1}/{max_retries})")
                    time.sleep(delay)
                    continue
                else:
                    raise

    def _parse_json_point_yx(self, text):
        """Parse the first ``{"point": [y, x]}`` entry out of *text*.

        Returns a (y, x) tuple of floats normalized to 0-1000, or None when
        no well-formed point can be extracted.
        """
        try:
            # Prefer a fenced ```json ... ``` block if the model produced one.
            json_match = re.search(r'```(?:json)?\s*(\[.*?\])\s*```', text, re.DOTALL)
            if json_match:
                json_str = json_match.group(1)
            else:
                # BUGFIX: the original fallback was non-greedy (r'(\[.*?\])'),
                # which stopped at the first ']' — inside the nested
                # "point": [y, x] list — so every unfenced reply failed to parse.
                json_match = re.search(r'(\[.*\])', text, re.DOTALL)
                if json_match:
                    json_str = json_match.group(1)
                else:
                    return None

            data = json.loads(json_str)
            if not isinstance(data, list) or len(data) == 0:
                return None

            point_data = data[0]
            if 'point' in point_data:
                point = point_data['point']
                if isinstance(point, list) and len(point) >= 2:
                    return (float(point[0]), float(point[1]))

            return None

        except (json.JSONDecodeError, KeyError, ValueError, IndexError):
            return None

    def _infer_point_sync(self, rgb: np.ndarray, object_name: str) -> np.ndarray:
        """Use OpenAI GPT-4o API to locate object in image and return normalized coordinates.

        Returns [x, y] in [0, 1]; falls back to the image center on any failure.
        """
        if rgb.dtype == np.float32 or rgb.dtype == np.float64:
            # assumes float images are in [0, 1] — TODO confirm with caller
            rgb = (rgb * 255).astype(np.uint8)
        image = Image.fromarray(rgb)

        prompt = PROMPT.format(object_name=object_name)

        try:
            buf = io.BytesIO()
            image.save(buf, format='PNG')
            base64_image = base64.b64encode(buf.getvalue()).decode('utf-8')

            messages = [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/png;base64,{base64_image}"
                            }
                        },
                        {
                            "type": "text",
                            "text": prompt
                        }
                    ]
                }
            ]

            response = self._call_api_with_retry(messages)
            generated_text = response.choices[0].message.content

            json_point = self._parse_json_point_yx(generated_text)
            if json_point:
                # Model coordinates are [y, x] in 0-1000; normalize and clamp.
                y_pos, x_pos = json_point
                x_norm = max(0.0, min(1.0, float(x_pos) / 1000.0))
                y_norm = max(0.0, min(1.0, float(y_pos) / 1000.0))
                return np.array([x_norm, y_norm], dtype=np.float32)

            return np.array([0.5, 0.5], dtype=np.float32)

        except Exception:
            # Best effort: any API/parsing failure degrades to the image center.
            return np.array([0.5, 0.5], dtype=np.float32)

    def reset(self, **kwargs):
        """Reset the env and estimate the object's world position from the ego view."""
        obs, info = self.env.reset(**kwargs)

        robot = self.env.unwrapped.robot
        # Render a depth image from the ego camera to back-project the 2D point.
        ego_renderer = robot.rgb_renderers[robot.camera_names[0]]
        ego_renderer.enable_depth_rendering()
        ego_renderer.update_scene(robot.data, robot.camera_names[0])
        depth = ego_renderer.render()
        ego_renderer.disable_depth_rendering()

        point_norm = self._infer_point_sync(obs["rgb_ego"], obs["object_name"])
        x_norm, y_norm = point_norm
        # BUGFIX: clamp to the last valid pixel — a normalized coordinate of
        # exactly 1.0 previously produced an out-of-bounds depth index.
        x = min(int(x_norm * self.env.unwrapped.render_width), self.env.unwrapped.render_width - 1)
        y = min(int(y_norm * self.env.unwrapped.render_height), self.env.unwrapped.render_height - 1)
        # Offset the surface depth slightly (3 cm) to approximate the object center.
        depth_value = depth[y, x] + 0.03
        self.unprivileged_T_world_object = pixel_to_world(x, y, depth_value, self.env.unwrapped.model, self.env.unwrapped.data, robot.camera_names[0], self.env.unwrapped.render_width, self.env.unwrapped.render_height)

        return obs, info

    def step(self, action):
        """Step the env, overwriting the object-pose translation with the estimate."""
        obs, reward, terminated, truncated, info = self.env.step(action)

        object_pose = obs["object_pose"].reshape(4, 4)
        object_pose[:3, 3] = self.unprivileged_T_world_object
        obs["object_pose"] = object_pose.flatten()

        return obs, reward, terminated, truncated, info

    def close(self):
        """Close the wrapped environment."""
        self.env.close()

    def __getattr__(self, name):
        # Delegate unknown attribute lookups to the wrapped env.
        return getattr(self.env, name)
@@ -0,0 +1,157 @@
1
+ import os
2
+ import numpy as np
3
+ import time
4
+ import re
5
+ import io
6
+ import json
7
+ from PIL import Image
8
+ from egogym.utils import pixel_to_world
9
+
10
MODEL = "gemini-robotics-er-1.5-preview"
PROMPT = """Get all points matching the following object: {object_name}. The label returned should be an identifying name for the object detected.
The answer should follow the json format: [{{"point": [y, x], "label": "{object_name}"}}, ...]. The points are in [y, x] format normalized to 0-1000."""

class UnprivilegedGemini:
    """Env wrapper that replaces the privileged object position in
    ``obs["object_pose"]`` with one estimated by Gemini from the ego camera.

    On every :meth:`reset` the ego RGB image is sent to the Gemini API, the
    returned 2D point is back-projected to a 3D world position using the
    rendered depth image, and that position overwrites the object-pose
    translation on every subsequent :meth:`step`.
    """

    def __init__(self, env, api_key=None):
        """
        Args:
            env: environment to wrap.
            api_key: Gemini API key; falls back to the GEMINI_API_KEY env var.

        Raises:
            ImportError: if the ``google-genai`` package is not installed.
            ValueError: if no API key is provided or found in the environment.
        """
        try:
            # BUGFIX: the original used `import google.genai`, which binds the
            # name `google` only — the later `genai.Client(...)` was a NameError.
            from google import genai
            from google.genai import errors, types
        except ImportError:
            raise ImportError("google-genai package is not installed. Please install it with `pip install google-genai`.")

        self.env = env
        self.unprivileged_T_world_object = None
        # BUGFIX: `errors`/`types` were locals of __init__ in the original but
        # were referenced as bare names from other methods (NameError at call
        # time); keep them on the instance instead.
        self._errors = errors
        self._types = types

        if api_key is None:
            # SECURITY FIX: the original shipped a real API key as the
            # fallback default. Credentials must come from the caller or the
            # environment, never from source code.
            api_key = os.getenv("GEMINI_API_KEY")
        if api_key is None:
            raise ValueError("No Gemini API key provided. Pass api_key or set the GEMINI_API_KEY environment variable.")

        self.client = genai.Client(api_key=api_key)

    def _call_api_with_retry(self, model, contents, config, max_retries=3, base_delay=2):
        """Call the Gemini API, retrying 503/429 server errors with exponential backoff."""
        for attempt in range(max_retries):
            try:
                return self.client.models.generate_content(
                    model=model,
                    contents=contents,
                    config=config
                )
            except self._errors.ServerError as e:
                if e.code in [503, 429] and attempt < max_retries - 1:
                    delay = base_delay * (2 ** attempt)
                    error_msg = e.message if e.message else "Service unavailable"
                    print(f"API error {e.code}: {error_msg}. Retrying in {delay}s... (attempt {attempt + 1}/{max_retries})")
                    time.sleep(delay)
                    continue
                else:
                    raise

    def _parse_json_point_yx(self, text):
        """Parse the first ``{"point": [y, x]}`` entry out of *text*.

        Returns a (y, x) tuple of floats normalized to 0-1000, or None when
        no well-formed point can be extracted.
        """
        try:
            # Prefer a fenced ```json ... ``` block if the model produced one.
            json_match = re.search(r'```(?:json)?\s*(\[.*?\])\s*```', text, re.DOTALL)
            if json_match:
                json_str = json_match.group(1)
            else:
                # BUGFIX: the original fallback was non-greedy (r'(\[.*?\])'),
                # which stopped at the first ']' — inside the nested
                # "point": [y, x] list — so every unfenced reply failed to parse.
                json_match = re.search(r'(\[.*\])', text, re.DOTALL)
                if json_match:
                    json_str = json_match.group(1)
                else:
                    return None

            data = json.loads(json_str)
            if not isinstance(data, list) or len(data) == 0:
                return None

            point_data = data[0]
            if 'point' in point_data:
                point = point_data['point']
                if isinstance(point, list) and len(point) >= 2:
                    return (float(point[0]), float(point[1]))

            return None

        except (json.JSONDecodeError, KeyError, ValueError, IndexError):
            return None

    def _infer_point_sync(self, rgb: np.ndarray, object_name: str) -> np.ndarray:
        """Use Gemini API to locate object in image and return normalized coordinates.

        Returns [x, y] in [0, 1]; falls back to the image center on any failure.
        """
        if rgb.dtype == np.float32 or rgb.dtype == np.float64:
            # assumes float images are in [0, 1] — TODO confirm with caller
            rgb = (rgb * 255).astype(np.uint8)
        image = Image.fromarray(rgb)

        prompt = PROMPT.format(object_name=object_name)

        try:
            buf = io.BytesIO()
            image.save(buf, format='PNG')
            image_bytes = buf.getvalue()

            contents = [
                self._types.Part.from_bytes(
                    data=image_bytes,
                    mime_type='image/png',
                ),
                prompt
            ]

            # Disable "thinking" to keep latency and token usage low.
            config = self._types.GenerateContentConfig(
                temperature=0.5,
                thinking_config=self._types.ThinkingConfig(thinking_budget=0)
            )
            response = self._call_api_with_retry(MODEL, contents, config)
            generated_text = response.text

            json_point = self._parse_json_point_yx(generated_text)
            if json_point:
                # Model coordinates are [y, x] in 0-1000; normalize and clamp.
                y_pos, x_pos = json_point
                x_norm = max(0.0, min(1.0, float(x_pos) / 1000.0))
                y_norm = max(0.0, min(1.0, float(y_pos) / 1000.0))
                return np.array([x_norm, y_norm], dtype=np.float32)

            return np.array([0.5, 0.5], dtype=np.float32)

        except Exception:
            # Best effort: any API/parsing failure degrades to the image center.
            return np.array([0.5, 0.5], dtype=np.float32)

    def reset(self, **kwargs):
        """Reset the env and estimate the object's world position from the ego view."""
        obs, info = self.env.reset(**kwargs)

        robot = self.env.unwrapped.robot
        # Render a depth image from the ego camera to back-project the 2D point.
        ego_renderer = robot.rgb_renderers[robot.camera_names[0]]
        ego_renderer.enable_depth_rendering()
        ego_renderer.update_scene(robot.data, robot.camera_names[0])
        depth = ego_renderer.render()
        ego_renderer.disable_depth_rendering()

        point_norm = self._infer_point_sync(obs["rgb_ego"], obs["object_name"])
        x_norm, y_norm = point_norm
        # BUGFIX: clamp to the last valid pixel — a normalized coordinate of
        # exactly 1.0 previously produced an out-of-bounds depth index.
        x = min(int(x_norm * self.env.unwrapped.render_width), self.env.unwrapped.render_width - 1)
        y = min(int(y_norm * self.env.unwrapped.render_height), self.env.unwrapped.render_height - 1)
        # Offset the surface depth slightly (3 cm) to approximate the object center.
        depth_value = depth[y, x] + 0.03
        self.unprivileged_T_world_object = pixel_to_world(x, y, depth_value, self.env.unwrapped.model, self.env.unwrapped.data, robot.camera_names[0], self.env.unwrapped.render_width, self.env.unwrapped.render_height)

        return obs, info

    def step(self, action):
        """Step the env, overwriting the object-pose translation with the estimate."""
        obs, reward, terminated, truncated, info = self.env.step(action)

        object_pose = obs["object_pose"].reshape(4, 4)
        object_pose[:3, 3] = self.unprivileged_T_world_object
        obs["object_pose"] = object_pose.flatten()

        return obs, reward, terminated, truncated, info

    def close(self):
        """Close the wrapped environment."""
        self.env.close()

    def __getattr__(self, name):
        # Delegate unknown attribute lookups to the wrapped env.
        return getattr(self.env, name)