egogym 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- baselines/pi_policy.py +110 -0
- baselines/rum/__init__.py +1 -0
- baselines/rum/loss_fns/__init__.py +37 -0
- baselines/rum/loss_fns/abstract_loss_fn.py +13 -0
- baselines/rum/loss_fns/diffusion_policy_loss_fn.py +114 -0
- baselines/rum/loss_fns/rvq_loss_fn.py +104 -0
- baselines/rum/loss_fns/vqbet_loss_fn.py +202 -0
- baselines/rum/models/__init__.py +1 -0
- baselines/rum/models/bet/__init__.py +3 -0
- baselines/rum/models/bet/bet.py +347 -0
- baselines/rum/models/bet/gpt.py +277 -0
- baselines/rum/models/bet/tokenized_bet.py +454 -0
- baselines/rum/models/bet/utils.py +124 -0
- baselines/rum/models/bet/vqbet.py +410 -0
- baselines/rum/models/bet/vqvae/__init__.py +3 -0
- baselines/rum/models/bet/vqvae/residual_vq.py +346 -0
- baselines/rum/models/bet/vqvae/vector_quantize_pytorch.py +1194 -0
- baselines/rum/models/bet/vqvae/vqvae.py +313 -0
- baselines/rum/models/bet/vqvae/vqvae_utils.py +30 -0
- baselines/rum/models/custom.py +33 -0
- baselines/rum/models/encoders/__init__.py +0 -0
- baselines/rum/models/encoders/abstract_base_encoder.py +70 -0
- baselines/rum/models/encoders/identity.py +45 -0
- baselines/rum/models/encoders/timm_encoders.py +82 -0
- baselines/rum/models/policies/diffusion_policy.py +881 -0
- baselines/rum/models/policies/open_loop.py +122 -0
- baselines/rum/models/policies/simple_open_loop.py +108 -0
- baselines/rum/molmo/server.py +144 -0
- baselines/rum/policy.py +293 -0
- baselines/rum/utils/__init__.py +212 -0
- baselines/rum/utils/action_transforms.py +22 -0
- baselines/rum/utils/decord_transforms.py +135 -0
- baselines/rum/utils/rpc.py +249 -0
- baselines/rum/utils/schedulers.py +71 -0
- baselines/rum/utils/trajectory_vis.py +128 -0
- baselines/rum/utils/zmq_utils.py +281 -0
- baselines/rum_policy.py +108 -0
- egogym/__init__.py +8 -0
- egogym/assets/constants.py +1804 -0
- egogym/components/__init__.py +1 -0
- egogym/components/object.py +94 -0
- egogym/egogym.py +106 -0
- egogym/embodiments/__init__.py +10 -0
- egogym/embodiments/arms/__init__.py +4 -0
- egogym/embodiments/arms/arm.py +65 -0
- egogym/embodiments/arms/droid.py +49 -0
- egogym/embodiments/grippers/__init__.py +4 -0
- egogym/embodiments/grippers/floating_gripper.py +58 -0
- egogym/embodiments/grippers/rum.py +6 -0
- egogym/embodiments/robot.py +95 -0
- egogym/evaluate.py +216 -0
- egogym/managers/__init__.py +2 -0
- egogym/managers/objects_managers.py +30 -0
- egogym/managers/textures_manager.py +21 -0
- egogym/misc/molmo_client.py +49 -0
- egogym/misc/molmo_server.py +197 -0
- egogym/policies/__init__.py +1 -0
- egogym/policies/base_policy.py +13 -0
- egogym/scripts/analayze.py +834 -0
- egogym/scripts/plot.py +87 -0
- egogym/scripts/plot_correlation.py +392 -0
- egogym/scripts/plot_correlation_hardcoded.py +338 -0
- egogym/scripts/plot_failure.py +248 -0
- egogym/scripts/plot_failure_hardcoded.py +195 -0
- egogym/scripts/plot_failure_vlm.py +257 -0
- egogym/scripts/plot_failure_vlm_hardcoded.py +177 -0
- egogym/scripts/plot_line.py +303 -0
- egogym/scripts/plot_line_hardcoded.py +285 -0
- egogym/scripts/plot_pi0_bars.py +169 -0
- egogym/tasks/close.py +84 -0
- egogym/tasks/open.py +85 -0
- egogym/tasks/pick.py +121 -0
- egogym/utils.py +969 -0
- egogym/wrappers/__init__.py +20 -0
- egogym/wrappers/episode_monitor.py +282 -0
- egogym/wrappers/unprivileged_chatgpt.py +163 -0
- egogym/wrappers/unprivileged_gemini.py +157 -0
- egogym/wrappers/unprivileged_molmo.py +88 -0
- egogym/wrappers/unprivileged_moondream.py +121 -0
- egogym-0.1.0.dist-info/METADATA +52 -0
- egogym-0.1.0.dist-info/RECORD +83 -0
- egogym-0.1.0.dist-info/WHEEL +5 -0
- egogym-0.1.0.dist-info/top_level.txt +2 -0
egogym/evaluate.py
ADDED
|
@@ -0,0 +1,216 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import time
|
|
3
|
+
import cv2
|
|
4
|
+
import gymnasium as gym
|
|
5
|
+
import numpy as np
|
|
6
|
+
from tqdm import tqdm
|
|
7
|
+
from itertools import cycle, islice
|
|
8
|
+
import multiprocessing
|
|
9
|
+
import json
|
|
10
|
+
|
|
11
|
+
import egogym
|
|
12
|
+
from egogym.wrappers import EpisodeMonitor, get_vlm_wrapper
|
|
13
|
+
import egogym.assets.constants as constants
|
|
14
|
+
|
|
15
|
+
def evaluate_policy(
    task_name="pick",
    policy=None,
    robot="rum",
    action_space="delta",
    num_objs=1,
    num_episodes=100,
    max_steps=50,
    render=False,
    record=False,
    render_freq=0,
    render_size=(299, 224),
    num_envs=1,
    seed=42,
    log_file=None,
    objects_set=None,
    reward_threshold=None,
    use_unprivileged_vlm=None,
    logs_dir="logs"
):
    """Roll out ``policy`` on an Egogym task and return its success rate.

    An episode counts as a success as soon as a step reward exceeds
    ``reward_threshold``. With ``num_envs > 1`` the episodes are spread over
    an ``AsyncVectorEnv``; otherwise a single environment is stepped
    sequentially.

    Args:
        task_name: Task id (case-insensitive); used to build the gym id
            ``Egogym-<Task>-v0`` and to pick defaults.
        policy: Object exposing ``get_action(obs)`` and ``reset(...)``.
        robot, action_space, num_objs, render_size: Forwarded to ``gym.make``.
        num_episodes: Total number of episodes to run.
        max_steps: Step budget per episode.
        render: If true, environments are created with ``render_mode='human'``;
            in the multi-env path the ego/exo frames are also shown via OpenCV.
        record, render_freq: Forwarded to ``EpisodeMonitor``. ``render_freq``
            is forced to 0 when more than one environment is used.
        num_envs: Requested number of parallel environments (clamped to
            ``[1, num_episodes]``).
        seed: Base seed; environment ``i`` receives ``seed + i``.
        log_file: Run name; defaults to ``evaluation_<unix time>``.
        objects_set: ``None`` (task default for pick/open) or one of the
            named sets listed in ``objects_set_options`` below.
        reward_threshold: Success threshold; defaults exist only for the
            ``pick`` (0.05) and ``open`` (0.5) tasks.
        use_unprivileged_vlm: If not ``None``, passed to ``get_vlm_wrapper``
            and the returned wrapper is applied to every environment.
        logs_dir: Parent directory; the run writes to ``<logs_dir>/<log_file>``.

    Returns:
        Success rate as a percentage of ``num_episodes`` (0.0 when no episode
        was requested).

    Raises:
        ValueError: If ``objects_set`` is an unknown name, or if no
            ``reward_threshold`` is given for a task without a default.
    """
    num_envs = max(1, min(num_envs, num_episodes))
    if num_envs > 1 and render_freq > 0:
        print(f"Warning: render_freq={render_freq} is not supported with multiple environments. Setting render_freq=0.")
        render_freq = 0
    task_name = task_name.lower()
    if objects_set is None:
        if task_name == "pick":
            objects_set = constants.lite_pick_objects_set
        elif task_name == "open":
            objects_set = constants.all_open_objects_set
    else:
        objects_set_options = {
            'all_pick': constants.all_pick_objects_set,
            'lite_pick': constants.lite_pick_objects_set,
            'diverse_pick': constants.diverse_pick_objects_set,
            'full_eval': constants.full_eval_objects_set,
            'all_open': constants.all_open_objects_set,
            "all_close": constants.all_close_objects_set,
            'cabinet': constants.cabinet_objects_set,
            'drawer': constants.drawer_objects_set,
        }
        if objects_set in objects_set_options:
            objects_set = objects_set_options[objects_set]
        else:
            raise ValueError(f"Unknown objects_set: {objects_set}")

    if reward_threshold is None:
        if task_name == "pick":
            reward_threshold = 0.05
        elif task_name == "open":
            reward_threshold = 0.5
        else:
            # Comparing a reward against None would only fail at the first
            # environment step, after log directories were already created.
            raise ValueError(
                f"reward_threshold must be provided for task '{task_name}' "
                "(defaults exist only for 'pick' and 'open')"
            )

    if not log_file:
        log_file = f"evaluation_{int(time.time())}"

    logs_dir = f"{logs_dir}/{log_file}"
    os.makedirs(logs_dir, exist_ok=True)

    eval_args = {
        "task_name": task_name,
        "robot": robot,
        "action_space": action_space,
        "num_objs": num_objs,
        "num_episodes": num_episodes,
        "max_steps": max_steps,
        "render": render,
        "record": record,
        "render_freq": render_freq,
        "render_size": render_size,
        "num_envs": num_envs,
        "seed": seed,
        "log_file": log_file,
        "objects_set": objects_set,
        "reward_threshold": reward_threshold,
    }

    # Persist the resolved configuration next to the episode logs.
    with open(os.path.join(logs_dir, "eval_args.json"), "w") as f:
        json.dump(eval_args, f, indent=4)

    if num_envs > 1:
        multiprocessing.set_start_method("spawn", force=True)

        # Repeat the object list until there is one entry per episode, then
        # deal the entries round-robin so each worker gets its own share.
        expanded_objects = list(islice(cycle(objects_set), num_episodes))
        env_objects = [expanded_objects[i::num_envs] for i in range(num_envs)]

        def make_env(env_id):
            def _init():
                env = gym.make(
                    f"Egogym-{task_name.capitalize()}-v0",
                    robot=robot,
                    action_space=action_space,
                    render_mode='human' if render else None,
                    render_size=render_size,
                    num_objs=num_objs,
                    seed=seed + env_id,
                    objects_set=env_objects[env_id]
                )
                if use_unprivileged_vlm is not None:
                    wrapper_cls = get_vlm_wrapper(use_unprivileged_vlm)
                    env = wrapper_cls(env)
                return env
            return _init

        env = gym.vector.AsyncVectorEnv(
            [make_env(i) for i in range(num_envs)],
            autoreset_mode=gym.vector.AutoresetMode.DISABLED
        )
        try:
            env = EpisodeMonitor(env, logs_dir=logs_dir, record=record, render_freq=render_freq, num_envs=num_envs)
            obs, _ = env.reset()
            env_episode_count = np.zeros(num_envs, dtype=int)
            env_step_count = np.zeros(num_envs, dtype=int)
            env_success_count = 0
            episodes_per_env = np.array([len(env_objects[i]) for i in range(num_envs)])

            with tqdm(total=num_episodes, desc="Evaluating | Success Rate: 0.0%") as pbar:
                while np.any(env_episode_count < episodes_per_env):
                    actions = policy.get_action(obs)
                    # Workers that already finished their share get a None action.
                    inactive = env_episode_count >= episodes_per_env
                    actions[inactive] = None

                    obs, rewards, terminated, truncated, info = env.step(actions)
                    env_step_count += 1

                    if render:
                        ego_exo = np.concatenate([obs["rgb_ego"], obs["rgb_exo"]], axis=2)
                        combined_view = np.concatenate(ego_exo, axis=0)
                        cv2.imshow("Ego View", cv2.cvtColor(combined_view, cv2.COLOR_RGB2BGR))
                        cv2.waitKey(1)

                    # NOTE(review): `truncated` is not part of the done condition
                    # here, unlike the single-env path below - confirm intent.
                    dones = (rewards > reward_threshold) | terminated | (env_step_count >= max_steps)
                    done_indices = np.where(dones)[0]

                    # Ignore "done" signals from workers with no episodes left.
                    active_done = done_indices[env_episode_count[done_indices] < episodes_per_env[done_indices]]

                    if len(active_done) > 0:
                        env.log_episodes(active_done)

                        successes = rewards[active_done] > reward_threshold
                        env_success_count += np.sum(successes)

                        env_episode_count[active_done] += 1
                        pbar.update(len(active_done))

                        total_completed = np.sum(env_episode_count)
                        success_rate = (env_success_count / total_completed) * 100 if total_completed > 0 else 0
                        pbar.set_description(f"Evaluating | Success Rate: {success_rate:.1f}%")

                        policy.reset(active_done.tolist())

                        # Reset only the sub-environments that just finished.
                        reset_mask = np.zeros(num_envs, dtype=bool)
                        reset_mask[active_done] = True
                        obs, _ = env.reset(options={"reset_mask": reset_mask})
                        env_step_count[active_done] = 0
        finally:
            # Always shut the worker processes down, even if a step raised.
            env.close()

        success_rate = (env_success_count / num_episodes) * 100
        return success_rate

    else:

        env = gym.make(
            f"Egogym-{task_name.capitalize()}-v0",
            robot=robot,
            action_space=action_space,
            render_mode='human' if render else None,
            render_size=render_size,
            num_objs=num_objs,
            seed=seed,
            objects_set=objects_set
        )
        try:
            if use_unprivileged_vlm is not None:
                wrapper_cls = get_vlm_wrapper(use_unprivileged_vlm)
                env = wrapper_cls(env)
            env = EpisodeMonitor(env, logs_dir=logs_dir, record=record, render_freq=render_freq, num_envs=num_envs)

            success_count = 0
            with tqdm(range(num_episodes), desc="Evaluating | Success Rate: 0.0%") as pbar:
                for episode_idx in pbar:
                    obs, _ = env.reset()
                    policy.reset()

                    success = False
                    for _ in range(max_steps):
                        action = policy.get_action(obs)
                        # The policy returns a batch; take the only entry.
                        obs, reward, terminated, truncated, _ = env.step(action[0])

                        if reward > reward_threshold:
                            success = True
                            break
                        if terminated or truncated:
                            break

                    if success:
                        success_count += 1

                    success_rate = (success_count / (episode_idx + 1)) * 100
                    pbar.set_description(f"Evaluating | Success Rate: {success_rate:.1f}%")
        finally:
            env.close()

        # Guard the division: zero requested episodes means nothing succeeded.
        return (success_count / num_episodes) * 100 if num_episodes > 0 else 0.0
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
from egogym.utils import make_cabinet_xml, make_drawer_xml
|
|
2
|
+
from egogym.components import Object, ArticulableObject
|
|
3
|
+
|
|
4
|
+
class ObjectsManager:
    """Hands out scene objects from a fixed pool, in order or at random.

    Despite its name, ``objects_dict`` is used as a sequence of entries
    where item 0 is the object name and item 1 is either an asset path or
    the marker string ``"procedurally_generated"``.
    """

    def __init__(self, objects_dict, np_random, shuffle=True):
        """Sort the pool by name, then optionally shuffle it with ``np_random``."""
        self.np_random = np_random
        ordered = sorted(objects_dict, key=lambda entry: entry[0])
        if shuffle:
            np_random.shuffle(ordered)
        self.objects_dict = ordered
        # -1 so that the first sequential sample() lands on position 0.
        self.index = -1

    def sample(self, random=False):
        """Return the next object (round-robin) or a random one.

        Sequential sampling advances and wraps ``self.index``; random
        sampling draws from ``np_random`` and leaves ``self.index`` alone.
        """
        pool_size = len(self.objects_dict)
        if random:
            index = self.np_random.integers(pool_size)
        else:
            index = self.index = (self.index + 1) % pool_size

        entry = self.objects_dict[index]
        source = entry[1]
        if source != "procedurally_generated":
            return Object(
                name=entry[0],
                path=source,
                np_random=self.np_random,
            )

        # Procedural objects are built by a module-level ``make_<name>_xml``
        # function, looked up by name.
        return ArticulableObject(
            name=entry[0],
            func=globals()[f"make_{entry[0]}_xml"],
            np_random=self.np_random,
        )
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
|
|
3
|
+
from egogym.components import Object
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class TexturesManager:
    """Serves textures from a shuffled pool, sequentially or at random."""

    def __init__(self, textures_list, np_random):
        """Sort the textures, then shuffle the sorted copy with ``np_random``.

        The caller's list is not modified; ``sorted`` builds a new one.
        """
        self.np_random = np_random
        shuffled = sorted(textures_list)
        np_random.shuffle(shuffled)
        self.textures_list = shuffled
        # -1 so that the first sequential sample() returns position 0.
        self.index = -1

    def sample(self, random=False):
        """Return the next texture (wrapping around) or a random one."""
        count = len(self.textures_list)
        if not random:
            self.index = (self.index + 1) % count
            return self.textures_list[self.index]
        # Random draws do not move the sequential cursor.
        return self.textures_list[self.np_random.integers(count)]
|
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import numpy as np
|
|
3
|
+
import websockets
|
|
4
|
+
|
|
5
|
+
class MolmoClient:
    """Async WebSocket client that asks a Molmo server to point at something.

    Usable either through ``connect()`` / ``disconnect()`` or as an async
    context manager.
    """

    def __init__(self, host="localhost", port=8765):
        self.uri = f"ws://{host}:{port}"
        # Live connection, or None while disconnected.
        self.websocket = None

    async def connect(self):
        """Open the WebSocket connection to ``self.uri``."""
        self.websocket = await websockets.connect(
            self.uri,
            max_size=50 * 1024 * 1024,
        )

    async def disconnect(self):
        """Close the connection, if any, and forget the handle.

        Clearing ``self.websocket`` makes a later ``infer_point`` fail with
        the explicit "Not connected" error instead of writing to a closed
        socket.
        """
        websocket, self.websocket = self.websocket, None
        if websocket is not None:
            await websocket.close()

    async def infer_point(self, rgb: np.ndarray, object_name: str = None, prompt: str = None):
        """Send ``rgb`` to the server and return the predicted point.

        Exactly one of ``object_name`` or ``prompt`` is sent; ``object_name``
        wins when both are given.

        Returns:
            The ``"point"`` field of the response as a float32 array.

        Raises:
            RuntimeError: If not connected, or if the server reports an error.
            ValueError: If neither ``object_name`` nor ``prompt`` is given.
        """
        if self.websocket is None:
            raise RuntimeError("Not connected to server. Call connect() first.")

        # The image travels as nested JSON lists.
        req = {
            "action": "infer_point",
            "rgb": rgb.tolist(),
        }

        if object_name:
            req["object_name"] = object_name
        elif prompt:
            req["prompt"] = prompt
        else:
            raise ValueError("Provide either 'object_name' or 'prompt'")

        await self.websocket.send(json.dumps(req))
        resp = json.loads(await self.websocket.recv())

        if resp["status"] == "error":
            raise RuntimeError(f"Server error: {resp['message']}")

        return np.array(resp["point"], dtype=np.float32)

    async def __aenter__(self):
        await self.connect()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.disconnect()
|
|
@@ -0,0 +1,197 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
import asyncio
|
|
3
|
+
import json
|
|
4
|
+
import logging
|
|
5
|
+
import re
|
|
6
|
+
import numpy as np
|
|
7
|
+
import torch
|
|
8
|
+
import websockets
|
|
9
|
+
from PIL import Image
|
|
10
|
+
|
|
11
|
+
logging.basicConfig(level=logging.INFO)
|
|
12
|
+
log = logging.getLogger(__name__)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def extract_molmo_points(molmo_output: str):
    """Pull ``x="..." y="..."`` coordinate pairs out of Molmo's text output.

    Attribute names may carry a numeric suffix (``x1``/``y1``, ...). Pairs
    with a coordinate above 100 are dropped; the rest are divided by 100.

    Returns:
        A list of float32 arrays ``[x, y]``, in order of appearance.
    """
    coord_pattern = r'x\d*="\s*([0-9]+(?:\.[0-9]+)?)"\s+y\d*="\s*([0-9]+(?:\.[0-9]+)?)"'
    normalized = []

    for found in re.finditer(coord_pattern, molmo_output):
        x_text, y_text = found.groups()
        try:
            xy = np.array([float(x_text), float(y_text)], dtype=np.float32)
        except ValueError:
            continue

        # Keep only pairs whose largest coordinate is at most 100.
        if np.max(xy) <= 100:
            normalized.append(xy / 100.0)

    return normalized
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class Molmo:
    """Wrapper around a Molmo checkpoint for text generation and pointing."""

    def __init__(self, model_name="allenai/Molmo-7B-D-0924"):
        """Load the processor and model for ``model_name``.

        Raises:
            ImportError: If the ``transformers`` package is not installed.
        """
        try:
            from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
        except ImportError as exc:
            raise ImportError("transformers package is not installed. Please install it with `pip install transformers`.") from exc

        # The import above is local to __init__, so keep a handle on the
        # class for infer(); a bare `GenerationConfig` there is a NameError.
        self._generation_config_cls = GenerationConfig

        self.processor = AutoProcessor.from_pretrained(
            model_name,
            trust_remote_code=True,
            torch_dtype="auto",
            device_map="auto",
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            trust_remote_code=True,
            torch_dtype="auto",
            device_map="auto",
        )

    def infer(self, rgb: np.ndarray, prompt: str) -> str:
        """Run the model on one image and prompt; return the generated text."""
        image = Image.fromarray(rgb)
        inputs = self.processor.process(images=[image], text=prompt)
        # Move each tensor to the model's device and add a leading dimension.
        inputs = {k: v.to(self.model.device).unsqueeze(0) for k, v in inputs.items()}

        generation_config = self._generation_config_cls(
            max_new_tokens=200, stop_strings="<|endoftext|>"
        )

        with torch.inference_mode():
            output = self.model.generate_from_batch(
                inputs,
                generation_config,
                tokenizer=self.processor.tokenizer,
            )

        # Drop the prompt tokens; decode only what was generated.
        gen_tokens = output[0, inputs["input_ids"].size(1):]
        return self.processor.tokenizer.decode(gen_tokens, skip_special_tokens=True)

    def infer_point(self, rgb: np.ndarray, prompt: str) -> np.ndarray:
        """Return the first point found in the model output.

        Falls back to ``[0.5, 0.5]`` when no point can be extracted.
        """
        text = self.infer(rgb, prompt)
        log.info(f"Molmo output: {text}")
        points = extract_molmo_points(text)
        return points[0] if points else np.array([0.5, 0.5], dtype=np.float32)
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
class MolmoClient:
    """Async WebSocket client for the Molmo pointing server (with logging).

    Usable either through ``connect()`` / ``disconnect()`` or as an async
    context manager.
    """

    def __init__(self, host="localhost", port=8765):
        self.uri = f"ws://{host}:{port}"
        # Live connection, or None while disconnected.
        self.websocket = None

    async def connect(self):
        """Open the WebSocket connection to ``self.uri``."""
        self.websocket = await websockets.connect(
            self.uri,
            max_size=50 * 1024 * 1024,
        )
        log.info(f"Connected to Molmo server at {self.uri}")

    async def disconnect(self):
        """Close the connection, if any, and forget the handle.

        Clearing ``self.websocket`` makes a later ``infer_point`` fail with
        the explicit "Not connected" error instead of writing to a closed
        socket.
        """
        websocket, self.websocket = self.websocket, None
        if websocket is not None:
            await websocket.close()
            log.info("Disconnected from Molmo server")

    async def infer_point(self, rgb: np.ndarray, object_name: str = None, prompt: str = None):
        """Send ``rgb`` to the server and return the predicted point.

        Exactly one of ``object_name`` or ``prompt`` is sent; ``object_name``
        wins when both are given.

        Returns:
            The ``"point"`` field of the response as a float32 array.

        Raises:
            RuntimeError: If not connected, or if the server reports an error.
            ValueError: If neither ``object_name`` nor ``prompt`` is given.
        """
        if self.websocket is None:
            raise RuntimeError("Not connected to server. Call connect() first.")

        # The image travels as nested JSON lists.
        req = {
            "action": "infer_point",
            "rgb": rgb.tolist(),
        }

        if object_name:
            req["object_name"] = object_name
        elif prompt:
            req["prompt"] = prompt
        else:
            raise ValueError("Provide either 'object_name' or 'prompt'")

        await self.websocket.send(json.dumps(req))
        resp = json.loads(await self.websocket.recv())

        if resp["status"] == "error":
            raise RuntimeError(f"Server error: {resp['message']}")

        return np.array(resp["point"], dtype=np.float32)

    async def __aenter__(self):
        await self.connect()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.disconnect()
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
class MolmoWebSocketServer:
    """WebSocket server answering ``infer_point`` requests with a Molmo model."""

    def __init__(self, host="0.0.0.0", port=8765):
        self.host = host
        self.port = port
        # Constructing Molmo loads the model.
        self.molmo = Molmo()

    @staticmethod
    def _build_prompt(req):
        """Return the text prompt described by a decoded request.

        ``object_name`` takes precedence over ``prompt`` when both are present.

        Raises:
            ValueError: If neither key is present, or ``object_name`` is not
                a string.
        """
        if "object_name" in req:
            object_name = req["object_name"]
            if not isinstance(object_name, str):
                raise ValueError("'object_name' must be a string")
            return f"Point to the center of the {object_name}."
        if "prompt" in req:
            return req["prompt"]
        raise ValueError("Provide either 'object_name' or 'prompt'")

    async def handle_client(self, websocket):
        """Serve one client until it disconnects.

        Each message gets exactly one JSON reply: ``status: ok`` with the
        point, or ``status: error`` with the exception text.
        """
        client_id = id(websocket)
        log.info(f"Client connected: {client_id}")

        try:
            async for msg in websocket:
                try:
                    req = json.loads(msg)
                    action = req.get("action")

                    if action != "infer_point":
                        raise ValueError("Only 'infer_point' action is supported")

                    rgb = np.array(req["rgb"], dtype=np.uint8)
                    prompt = self._build_prompt(req)

                    point = self.molmo.infer_point(rgb, prompt)

                    resp = {
                        "status": "ok",
                        "action": "infer_point",
                        "point": point.tolist(),
                    }

                except Exception as e:
                    # Report any per-request failure to the client and keep
                    # the connection open for the next message.
                    log.exception("Request error")
                    resp = {"status": "error", "message": str(e)}

                await websocket.send(json.dumps(resp))

        except websockets.exceptions.ConnectionClosed:
            log.info(f"Client disconnected: {client_id}")

    async def start(self):
        """Start listening and serve until the task is cancelled."""
        log.info(f"Starting Molmo server on ws://{self.host}:{self.port}")
        async with websockets.serve(
            self.handle_client,
            self.host,
            self.port,
            max_size=50 * 1024 * 1024,
        ):
            # A future that never completes keeps the server alive.
            await asyncio.Future()
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
def main():
    """Parse ``--host`` / ``--port`` and run the Molmo WebSocket server."""
    cli = argparse.ArgumentParser("Molmo Pointing Server")
    cli.add_argument("--host", default="0.0.0.0")
    cli.add_argument("--port", type=int, default=8765)
    options = cli.parse_args()

    # Building the server loads the model; start() then serves until stopped.
    molmo_server = MolmoWebSocketServer(host=options.host, port=options.port)
    asyncio.run(molmo_server.start())


if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .base_policy import BasePolicy
|