cua-agent 0.4.31__py3-none-any.whl → 0.4.33__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of cua-agent might be problematic. Click here for more details.
- agent/adapters/huggingfacelocal_adapter.py +15 -66
- agent/adapters/models/__init__.py +33 -0
- agent/adapters/models/generic.py +75 -0
- agent/adapters/models/internvl.py +254 -0
- agent/adapters/models/opencua.py +100 -0
- agent/adapters/models/qwen2_5_vl.py +75 -0
- agent/agent.py +5 -1
- agent/callbacks/trajectory_saver.py +2 -0
- agent/cli.py +147 -22
- agent/loops/__init__.py +19 -1
- agent/loops/anthropic.py +3 -4
- agent/loops/composed_grounded.py +1 -1
- agent/loops/gemini.py +391 -0
- agent/loops/glm45v.py +3 -2
- agent/loops/gta1.py +1 -1
- agent/loops/holo.py +216 -0
- agent/loops/internvl.py +185 -0
- agent/loops/moondream3.py +464 -0
- agent/loops/openai.py +1 -2
- agent/loops/opencua.py +142 -0
- agent/loops/uitars.py +1 -1
- {cua_agent-0.4.31.dist-info → cua_agent-0.4.33.dist-info}/METADATA +23 -4
- {cua_agent-0.4.31.dist-info → cua_agent-0.4.33.dist-info}/RECORD +25 -15
- {cua_agent-0.4.31.dist-info → cua_agent-0.4.33.dist-info}/WHEEL +0 -0
- {cua_agent-0.4.31.dist-info → cua_agent-0.4.33.dist-info}/entry_points.txt +0 -0
agent/loops/gemini.py
ADDED
|
@@ -0,0 +1,391 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Gemini 2.5 Computer Use agent loop
|
|
3
|
+
|
|
4
|
+
Maps internal Agent SDK message format to Google's Gemini Computer Use API and back.
|
|
5
|
+
|
|
6
|
+
Key features:
|
|
7
|
+
- Lazy import of google.genai
|
|
8
|
+
- Configure Computer Use tool with excluded browser-specific predefined functions
|
|
9
|
+
- Optional custom function declarations hook for computer-call specific functions
|
|
10
|
+
- Convert Gemini function_call parts into internal computer_call actions
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
from __future__ import annotations
|
|
14
|
+
|
|
15
|
+
import base64
|
|
16
|
+
import io
|
|
17
|
+
import uuid
|
|
18
|
+
from typing import Any, Dict, List, Optional, Tuple
|
|
19
|
+
|
|
20
|
+
from PIL import Image
|
|
21
|
+
|
|
22
|
+
from ..decorators import register_agent
|
|
23
|
+
from ..loops.base import AsyncAgentConfig
|
|
24
|
+
from ..types import AgentCapability
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def _lazy_import_genai():
    """Return ``(genai, types)`` from the Google Gen AI SDK, importing on first use.

    Keeping the import inside this function means the SDK is only required
    when the Gemini loop actually runs.

    Raises:
        RuntimeError: if ``google.genai`` cannot be imported; the underlying
            import error is chained as ``__cause__``.
    """
    try:
        from google import genai  # type: ignore
        from google.genai import types as genai_types  # type: ignore
    except Exception as import_error:  # pragma: no cover
        raise RuntimeError(
            "google.genai is required for the Gemini Computer Use loop. Install the Google Gemini SDK."
        ) from import_error
    return genai, genai_types
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def _data_url_to_bytes(data_url: str) -> Tuple[bytes, str]:
    """Convert a data URL (or a bare base64 string) to raw bytes and a mime type.

    Args:
        data_url: Either ``data:<mime>[;params],<base64 payload>`` or a bare
            base64 string, which is assumed to be a PNG.

    Returns:
        ``(payload_bytes, mime_type)``. The mime type is taken from the data
        URL header and defaults to ``image/png`` when the header omits it. On
        any malformed input (no comma separator, invalid base64),
        ``(b"", "application/octet-stream")`` is returned instead of raising,
        for both the data-URL and the bare-base64 forms.
    """
    failure = (b"", "application/octet-stream")

    if not data_url.startswith("data:"):
        # Assume it's a bare base64 PNG payload.
        try:
            return base64.b64decode(data_url), "image/png"
        except ValueError:  # binascii.Error is a ValueError subclass
            return failure

    header, separator, payload = data_url.partition(",")
    if not separator:
        # "data:..." without the mandatory comma: nothing to decode.
        return failure

    # Header looks like "data:image/jpeg;base64"; the media type is everything
    # between "data:" and the first ";" (or the whole rest if there is none).
    media_type = header[len("data:"):].split(";", 1)[0] or "image/png"
    try:
        return base64.b64decode(payload), media_type
    except ValueError:
        return failure
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def _bytes_image_size(img_bytes: bytes) -> Tuple[int, int]:
    """Return the ``(width, height)`` of an encoded image.

    Falls back to ``(1024, 768)`` when the bytes cannot be opened as an image.
    """
    try:
        # Use the image as a context manager so it is closed as soon as the
        # size (read from the header) is available.
        with Image.open(io.BytesIO(img_bytes)) as img:
            return img.size
    except Exception:
        return (1024, 768)
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def _find_last_user_text(messages: List[Dict[str, Any]]) -> List[str]:
    """Return the text of the most recent user message that contains any text.

    Messages are scanned newest-first. A user message is an item whose
    ``type`` is missing or ``"message"`` and whose ``role`` is ``"user"``.
    String content is returned as a one-element list. For list content, the
    ``input_text``/``output_text`` parts with non-empty text are returned in
    their original order. A user message with no text parts (e.g. only
    images) is skipped in favour of an older one.

    Returns:
        The text parts, or ``[]`` if no user message has text.
    """
    for msg in reversed(messages):
        if msg.get("type") not in (None, "message") or msg.get("role") != "user":
            continue
        content = msg.get("content")
        if isinstance(content, str):
            return [content]
        if isinstance(content, list):
            # Keep the parts in message order; non-dict parts are ignored.
            texts = [
                part["text"]
                for part in content
                if isinstance(part, dict)
                and part.get("type") in ("input_text", "output_text")
                and part.get("text")
            ]
            if texts:
                return texts
    return []
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def _find_last_screenshot(messages: List[Dict[str, Any]]) -> Optional[bytes]:
    """Return the decoded bytes of the newest screenshot in ``messages``.

    Only ``computer_call_output`` items whose ``output`` is a dict of type
    ``input_image`` or ``computer_screenshot`` with a non-empty ``image_url``
    are considered. The URL is decoded with ``_data_url_to_bytes`` (mime type
    discarded).

    Returns:
        The screenshot bytes, or ``None`` if no such item exists.
    """
    screenshot_kinds = ("input_image", "computer_screenshot")
    for item in reversed(messages):
        if item.get("type") != "computer_call_output":
            continue
        output = item.get("output", {})
        if not isinstance(output, dict):
            continue
        if output.get("type") not in screenshot_kinds:
            continue
        url = output.get("image_url", "")
        if not url:
            continue
        payload, _mime = _data_url_to_bytes(url)
        return payload
    return None
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def _denormalize(v: int, size: int) -> int:
    """Map a coordinate on Gemini's 0-999 normalized grid to a pixel index.

    The result is clamped to ``[0, size - 1]`` (0 when ``size`` is 0). Any
    arithmetic error, such as a non-numeric ``v``, yields 0.
    """
    try:
        pixel = int(round(v / 1000 * size))
        last_index = size - 1
        return max(0, min(last_index, pixel))
    except Exception:
        return 0
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def _map_gemini_fc_to_computer_call(
    fc: Dict[str, Any],
    screen_w: int,
    screen_h: int,
) -> Optional[Dict[str, Any]]:
    """Translate one Gemini function call into an internal ``computer_call`` item.

    Args:
        fc: ``{"name": ..., "args": {...}}`` extracted from a Gemini
            ``function_call`` part. Coordinates in ``args`` are on the 0-999
            normalized grid and are converted with ``_denormalize``.
        screen_w: Screenshot width in pixels, used for x values.
        screen_h: Screenshot height in pixels, used for y values.

    Returns:
        A completed ``computer_call`` dict with a fresh hex ``call_id``, or
        ``None`` for functions with no internal equivalent (browser
        navigation, custom functions, unknown names).
    """
    name = fc.get("name")
    args = fc.get("args", {}) or {}

    def to_pixels(
        x_key: str = "x", y_key: str = "y", default_x: Any = 0, default_y: Any = 0
    ) -> Tuple[int, int]:
        # Read a normalized (x, y) pair from args and map it onto the screen.
        return (
            _denormalize(int(args.get(x_key, default_x)), screen_w),
            _denormalize(int(args.get(y_key, default_y)), screen_h),
        )

    def scroll_deltas(direction: Any, magnitude: int) -> Tuple[int, int]:
        # (scroll_x, scroll_y) for a direction; unknown directions do not scroll.
        if direction == "down":
            return 0, magnitude
        if direction == "up":
            return 0, -magnitude
        if direction == "right":
            return magnitude, 0
        if direction == "left":
            return -magnitude, 0
        return 0, 0

    action: Dict[str, Any]
    if name == "click_at":
        x, y = to_pixels()
        action = {"type": "click", "x": x, "y": y, "button": "left"}
    elif name == "type_text_at":
        x, y = to_pixels()
        text = args.get("text", "")
        # NOTE(review): Enter is only appended when press_enter is explicitly
        # true; confirm this matches Gemini's documented default for press_enter.
        if args.get("press_enter") == True:  # noqa: E712 - keep 1/1.0 accepted
            text += "\n"
        # NOTE(review): the internal "type" action carries x/y; the computer
        # handler is presumably expected to accept them - verify.
        action = {"type": "type", "x": x, "y": y, "text": text}
    elif name == "hover_at":
        x, y = to_pixels()
        action = {"type": "move", "x": x, "y": y}
    elif name == "key_combination":
        action = {"type": "keypress", "keys": str(args.get("keys", ""))}
    elif name == "scroll_document":
        # Whole-page scroll: fixed magnitude, anchored at the screen centre.
        dx, dy = scroll_deltas(args.get("direction", "down"), 800)
        action = {
            "type": "scroll",
            "scroll_x": dx,
            "scroll_y": dy,
            "x": int(screen_w / 2),
            "y": int(screen_h / 2),
        }
    elif name == "scroll_at":
        x, y = to_pixels(default_x=500, default_y=500)
        direction = args.get("direction", "down")
        magnitude = int(args.get("magnitude", 800))
        dx, dy = scroll_deltas(direction, magnitude)
        action = {"type": "scroll", "scroll_x": dx, "scroll_y": dy, "x": x, "y": y}
    elif name == "drag_and_drop":
        x, y = to_pixels()
        # A missing destination defaults to the *normalized* source point, so
        # it is denormalized exactly once (pixel values must not be rescaled).
        end_x, end_y = to_pixels(
            "destination_x",
            "destination_y",
            default_x=args.get("x", 0),
            default_y=args.get("y", 0),
        )
        action = {
            "type": "drag",
            "start_x": x,
            "start_y": y,
            "end_x": end_x,
            "end_y": end_y,
            "button": "left",
        }
    elif name == "wait_5_seconds":
        action = {"type": "wait"}
    else:
        # Unsupported / excluded browser-specific or custom function; ignore.
        return None

    return {
        "type": "computer_call",
        "call_id": uuid.uuid4().hex,
        "status": "completed",
        "action": action,
    }
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
@register_agent(models=r"^gemini-2\.5-computer-use-preview-10-2025$")
class GeminiComputerUseConfig(AsyncAgentConfig):
    """Agent loop for Gemini 2.5 Computer Use (preview, 10-2025).

    Each step sends Gemini only the text of the most recent user message and
    the most recent screenshot (earlier history is not forwarded). The
    returned function calls are then mapped onto internal ``computer_call``
    items.
    """

    # Predefined browser-navigation functions excluded from full agent steps.
    _STEP_EXCLUDED_FUNCTIONS = (
        "open_web_browser",
        "search",
        "navigate",
        "go_forward",
        "go_back",
        "scroll_document",
    )

    # Every predefined function except ``click_at``, used by predict_click.
    _CLICK_EXCLUDED_FUNCTIONS = (
        "open_web_browser",
        "wait_5_seconds",
        "go_back",
        "go_forward",
        "search",
        "navigate",
        "hover_at",
        "type_text_at",
        "key_combination",
        "scroll_document",
        "scroll_at",
        "drag_and_drop",
    )

    @staticmethod
    def _candidate_parts(response: Any) -> List[Any]:
        """Return the content parts of the first response candidate, or ``[]``.

        ``candidates``, ``content`` and ``parts`` are read defensively. A
        response without usable content (e.g. a blocked one, presumably)
        therefore yields no parts instead of raising IndexError/TypeError.
        """
        candidates = getattr(response, "candidates", None) or []
        if not candidates:
            return []
        content = getattr(candidates[0], "content", None)
        return list(getattr(content, "parts", None) or [])

    async def predict_step(
        self,
        messages: List[Dict[str, Any]],
        model: str,
        tools: Optional[List[Dict[str, Any]]] = None,
        max_retries: Optional[int] = None,
        stream: bool = False,
        computer_handler=None,
        use_prompt_caching: Optional[bool] = False,
        _on_api_start=None,
        _on_api_end=None,
        _on_usage=None,
        _on_screenshot=None,
        **kwargs,
    ) -> Dict[str, Any]:
        """Run one Gemini Computer Use turn and return internal output items.

        Args:
            messages: Conversation items in the internal Agent SDK format; only
                the latest user text and latest screenshot are sent.
            model: Gemini model name passed to ``generate_content``.
            tools, max_retries, stream, computer_handler, use_prompt_caching,
                _on_screenshot, kwargs: accepted for interface compatibility;
                not used by this loop.
            _on_api_start: optional async callback awaited with
                ``{"model", "config"}`` before the request.
            _on_api_end: optional async callback awaited with
                ``{"model", "config"}`` and the raw response after the request.
            _on_usage: optional async callback awaited with token usage when
                the response reports it.

        Returns:
            ``{"output": [...], "usage": {...}}``. The output holds an optional
            assistant text message followed by ``computer_call`` items.
        """
        genai, types = _lazy_import_genai()

        client = genai.Client()

        generate_content_config = types.GenerateContentConfig(
            tools=[
                types.Tool(
                    computer_use=types.ComputerUse(
                        environment=types.Environment.ENVIRONMENT_BROWSER,
                        excluded_predefined_functions=list(self._STEP_EXCLUDED_FUNCTIONS),
                    )
                ),
                # Custom function declarations can be added here later as
                # types.Tool(function_declarations=[...]).
            ]
        )

        # Prepare contents: last user text + latest screenshot.
        user_texts = _find_last_user_text(messages)
        screenshot_bytes = _find_last_screenshot(messages)

        parts: List[Any] = [types.Part(text=text) for text in user_texts]

        screen_w, screen_h = 1024, 768
        if screenshot_bytes:
            screen_w, screen_h = _bytes_image_size(screenshot_bytes)
            # NOTE(review): screenshots are always labelled PNG - confirm.
            parts.append(types.Part.from_bytes(data=screenshot_bytes, mime_type="image/png"))

        # Never send an empty request; nudge the model to continue instead.
        if not parts:
            parts = [types.Part(text="Proceed to the next action.")]

        api_kwargs = {
            "model": model,
            "contents": [types.Content(role="user", parts=parts)],
            "config": generate_content_config,
        }

        # Callback payloads omit "contents" (screenshot bytes) on purpose.
        if _on_api_start:
            await _on_api_start({"model": model, "config": generate_content_config})

        # The async client avoids blocking the event loop during the request.
        response = await client.aio.models.generate_content(**api_kwargs)

        if _on_api_end:
            await _on_api_end({"model": model, "config": generate_content_config}, response)

        # Usage is only populated when the SDK reports usage_metadata.
        usage: Dict[str, Any] = {}
        usage_metadata = getattr(response, "usage_metadata", None)
        if usage_metadata:
            usage = {
                "prompt_tokens": getattr(usage_metadata, "prompt_token_count", None) or 0,
                "completion_tokens": getattr(usage_metadata, "candidates_token_count", None) or 0,
                "total_tokens": getattr(usage_metadata, "total_token_count", None) or 0,
            }

        if _on_usage and usage:
            await _on_usage(usage)

        # Split the candidate into model text and function calls.
        text_parts: List[str] = []
        function_calls: List[Dict[str, Any]] = []
        for part in self._candidate_parts(response):
            if getattr(part, "text", None):
                text_parts.append(part.text)
            function_call = getattr(part, "function_call", None)
            if function_call:
                function_calls.append(
                    {
                        "name": getattr(function_call, "name", None),
                        "args": dict(getattr(function_call, "args", {}) or {}),
                    }
                )

        output_items: List[Dict[str, Any]] = []
        if text_parts:
            output_items.append(
                {
                    "type": "message",
                    "role": "assistant",
                    "content": [{"type": "output_text", "text": "\n".join(text_parts)}],
                }
            )

        # Unsupported function calls map to None and are dropped.
        for fc in function_calls:
            item = _map_gemini_fc_to_computer_call(fc, screen_w, screen_h)
            if item is not None:
                output_items.append(item)

        return {"output": output_items, "usage": usage}

    async def predict_click(
        self,
        model: str,
        image_b64: str,
        instruction: str,
        **kwargs,
    ) -> Optional[Tuple[float, float]]:
        """Ask Gemini CUA to output a single click action for the given instruction.

        Excludes all predefined tools except ``click_at`` and sends the
        screenshot (labelled PNG) when it decodes.

        Returns:
            Pixel ``(x, y)`` of the first proposed ``click_at``, or ``None`` if
            no click is proposed or its arguments cannot be parsed.
        """
        genai, types = _lazy_import_genai()

        client = genai.Client()

        config = types.GenerateContentConfig(
            tools=[
                types.Tool(
                    computer_use=types.ComputerUse(
                        environment=types.Environment.ENVIRONMENT_BROWSER,
                        excluded_predefined_functions=list(self._CLICK_EXCLUDED_FUNCTIONS),
                    )
                )
            ]
        )

        try:
            img_bytes = base64.b64decode(image_b64)
        except (TypeError, ValueError):
            img_bytes = b""

        # Real image size is needed to map the normalized click back to pixels.
        w, h = _bytes_image_size(img_bytes) if img_bytes else (1024, 768)

        parts: List[Any] = [types.Part(text=f"Click {instruction}.")]
        if img_bytes:
            parts.append(types.Part.from_bytes(data=img_bytes, mime_type="image/png"))

        response = await client.aio.models.generate_content(
            model=model,
            contents=[types.Content(role="user", parts=parts)],
            config=config,
        )

        # Parse the first click_at; malformed arguments mean "no click".
        try:
            for part in self._candidate_parts(response):
                function_call = getattr(part, "function_call", None)
                if function_call and getattr(function_call, "name", None) == "click_at":
                    args = dict(getattr(function_call, "args", {}) or {})
                    x = _denormalize(int(args.get("x", 0)), w)
                    y = _denormalize(int(args.get("y", 0)), h)
                    return float(x), float(y)
        except Exception:
            return None

        return None

    def get_capabilities(self) -> List[AgentCapability]:
        """This loop supports both single-click grounding and full agent steps."""
        return ["click", "step"]
|
agent/loops/glm45v.py
CHANGED
|
@@ -844,7 +844,7 @@ Where x,y are coordinates normalized to 0-999 range."""
|
|
|
844
844
|
api_kwargs = {
|
|
845
845
|
"model": model,
|
|
846
846
|
"messages": litellm_messages,
|
|
847
|
-
"max_tokens":
|
|
847
|
+
"max_tokens": 2056,
|
|
848
848
|
"temperature": 0.001,
|
|
849
849
|
"extra_body": {
|
|
850
850
|
"skip_special_tokens": False,
|
|
@@ -856,6 +856,7 @@ Where x,y are coordinates normalized to 0-999 range."""
|
|
|
856
856
|
|
|
857
857
|
# Extract response content
|
|
858
858
|
response_content = response.choices[0].message.content.strip()
|
|
859
|
+
print(response)
|
|
859
860
|
|
|
860
861
|
# Parse response for click coordinates
|
|
861
862
|
# Look for coordinates in the response, handling special tokens
|
|
@@ -866,7 +867,7 @@ Where x,y are coordinates normalized to 0-999 range."""
|
|
|
866
867
|
# Fallback: look for coordinates without special tokens
|
|
867
868
|
coord_pattern = r"left_click\(start_box='?\[(\d+),(\d+)\]'?\)"
|
|
868
869
|
match = re.search(coord_pattern, response_content)
|
|
869
|
-
|
|
870
|
+
|
|
870
871
|
if match:
|
|
871
872
|
x, y = int(match.group(1)), int(match.group(2))
|
|
872
873
|
|
agent/loops/gta1.py
CHANGED
agent/loops/holo.py
ADDED
|
@@ -0,0 +1,216 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Holo 1.5 agent loop implementation for click prediction using litellm.acompletion.
|
|
3
|
+
|
|
4
|
+
Implements the Holo1.5 grounding behavior:
|
|
5
|
+
- Prompt asks for absolute pixel coordinates in JSON: {"action":"click_absolute","x":int,"y":int}
|
|
6
|
+
- Optionally resizes the image using Qwen2-VL smart_resize parameters (via transformers AutoProcessor)
|
|
7
|
+
- If resized, maps predicted coordinates back to the original screenshot resolution
|
|
8
|
+
|
|
9
|
+
Note: We do NOT manually load the model; acompletions (via HuggingFaceLocalAdapter)
|
|
10
|
+
will handle loading based on the provided model name.
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
from __future__ import annotations
|
|
14
|
+
|
|
15
|
+
import base64
|
|
16
|
+
import json
|
|
17
|
+
from io import BytesIO
|
|
18
|
+
from typing import Any, Dict, List, Optional, Tuple
|
|
19
|
+
|
|
20
|
+
import litellm
|
|
21
|
+
from PIL import Image
|
|
22
|
+
|
|
23
|
+
from ..decorators import register_agent
|
|
24
|
+
from .base import AsyncAgentConfig
|
|
25
|
+
from ..types import AgentCapability
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def _strip_hf_prefix(model: str) -> str:
    """Drop a leading ``huggingface-local/`` provider prefix (case-insensitive).

    The bare repo id is what the Hugging Face processor loader expects; other
    model names are returned unchanged.
    """
    has_local_prefix = "/" in model and model.lower().startswith("huggingface-local/")
    if not has_local_prefix:
        return model
    _provider, _slash, repo_id = model.partition("/")
    return repo_id
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
# Image processors loaded via AutoProcessor, keyed by processor name. Loading
# a processor is expensive, so it is done once per model rather than on every
# predict_click call. A value of None means the processor has no
# image_processor.
_IMAGE_PROCESSOR_CACHE: Dict[str, Any] = {}


def _maybe_smart_resize(image: Image.Image, model: str) -> Tuple[Image.Image, Tuple[int, int]]:
    """
    Try to compute Qwen2-VL smart_resize output size using transformers AutoProcessor.

    Returns (processed_image, (orig_w, orig_h)). The original image is returned
    unchanged (no resizing) when transformers or the processor is unavailable,
    when smart_resize keeps the size, or when any step fails. The loaded image
    processor is cached per model name.
    """
    orig_w, orig_h = image.size
    try:
        # Import lazily to avoid a hard dependency when transformers is not installed.
        from transformers import AutoProcessor  # type: ignore
        from transformers.models.qwen2_vl.image_processing_qwen2_vl import (  # type: ignore
            smart_resize,
        )

        processor_name = _strip_hf_prefix(model)
        if processor_name not in _IMAGE_PROCESSOR_CACHE:
            processor = AutoProcessor.from_pretrained(processor_name)
            _IMAGE_PROCESSOR_CACHE[processor_name] = getattr(processor, "image_processor", None)
        image_processor = _IMAGE_PROCESSOR_CACHE[processor_name]
        if image_processor is None:
            return image, (orig_w, orig_h)

        # Resize parameters come from the processor, with fallbacks if absent.
        factor = getattr(image_processor, "patch_size", 14) * getattr(image_processor, "merge_size", 1)
        min_pixels = getattr(image_processor, "min_pixels", 256 * 256)
        max_pixels = getattr(image_processor, "max_pixels", 1536 * 1536)

        # smart_resize takes and returns (height, width).
        resized_h, resized_w = smart_resize(
            orig_h,
            orig_w,
            factor=factor,
            min_pixels=min_pixels,
            max_pixels=max_pixels,
        )

        if (resized_w, resized_h) == (orig_w, orig_h):
            return image, (orig_w, orig_h)

        processed = image.resize((resized_w, resized_h), resample=Image.Resampling.LANCZOS)
        return processed, (orig_w, orig_h)
    except Exception:
        # Any failure (no transformers, processor load error) falls back to the original.
        return image, (orig_w, orig_h)
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def _build_holo_prompt(instruction: str) -> str:
    """Return the Holo1.5 grounding prompt asking for a ``click_absolute`` JSON answer.

    A short schema hint is inlined instead of generating a full JSON schema.
    """
    schema_hint = '{"action": "click_absolute", "x": <int>, "y": <int>}'
    sentences = [
        "Localize an element on the GUI image according to the provided target and output a click position.",
        f"You must output a valid JSON following the format: {schema_hint}",
        f"Your target is: {instruction}",
    ]
    return " ".join(sentences)
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def _first_json_object(text: str) -> Optional[Dict[str, Any]]:
    """Return the first JSON object embedded in ``text``, or ``None`` if there is none."""
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            obj, _end = decoder.raw_decode(text[start:])
        except ValueError:
            pass
        else:
            if isinstance(obj, dict):
                return obj
        start = text.find("{", start + 1)
    return None


def _parse_click_json(output_text: str) -> Optional[Tuple[int, int]]:
    """
    Parse JSON from model output and extract x, y ints.

    The whole text is tried as JSON first. If that is not a JSON object, the
    first decodable ``{...}`` object inside the text is used, so surrounding
    prose, code fences, trailing objects or a wrapping list are tolerated.

    Returns:
        ``(x, y)`` as ints, or ``None`` if no object is found or it lacks
        integer-convertible ``x``/``y`` values.
    """
    try:
        data: Any = json.loads(output_text)
    except (TypeError, ValueError):
        data = None
    if not isinstance(data, dict):
        data = _first_json_object(output_text)
    if data is None:
        return None

    try:
        return int(data["x"]), int(data["y"])
    except (KeyError, TypeError, ValueError, OverflowError):
        return None
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
@register_agent(models=r"(?i).*(Holo1\.5|Hcompany/Holo1\.5).*")
class HoloConfig(AsyncAgentConfig):
    """Holo is a family of UI grounding models from H Company.

    Only click prediction is supported; full agent steps raise
    NotImplementedError.
    """

    async def predict_step(
        self,
        messages: List[Dict[str, Any]],
        model: str,
        tools: Optional[List[Dict[str, Any]]] = None,
        max_retries: Optional[int] = None,
        stream: bool = False,
        computer_handler=None,
        _on_api_start=None,
        _on_api_end=None,
        _on_usage=None,
        _on_screenshot=None,
        **kwargs,
    ) -> Dict[str, Any]:
        """Not supported: Holo models are only trained on UI localization tasks.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError(
            "Holo models only support click prediction (predict_click), not full agent steps."
        )

    async def predict_click(
        self,
        model: str,
        image_b64: str,
        instruction: str,
        **kwargs,
    ) -> Optional[Tuple[int, int]]:
        """
        Predict click coordinates using Holo1.5 via litellm.acompletion.

        - Optionally smart-resizes the image using Qwen2-VL rules if transformers are available
        - Prompts for JSON with absolute pixel coordinates
        - Parses x,y and maps back to original screenshot size if resized

        ``max_tokens`` (default 256) and ``temperature`` (default 0.0) may be
        overridden via kwargs.

        Returns:
            ``(x, y)`` in original-screenshot pixels, clamped to the image
            bounds, or ``None`` if the image cannot be decoded or the model
            output contains no parsable click.
        """
        try:
            img_bytes = base64.b64decode(image_b64)
            original_img = Image.open(BytesIO(img_bytes))
            # Image.open only reads the header; load() forces a full decode so
            # corrupt or truncated screenshots are rejected here instead of
            # raising later from resize()/save().
            original_img.load()
        except Exception:
            return None

        # Optional preprocessing; returns the original image if not resized.
        processed_img, (orig_w, orig_h) = _maybe_smart_resize(original_img, model)

        buf = BytesIO()
        processed_img.save(buf, format="PNG")
        processed_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")

        prompt = _build_holo_prompt(instruction)

        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/png;base64,{processed_b64}"},
                    },
                    {"type": "text", "text": prompt},
                ],
            }
        ]

        api_kwargs = {
            "model": model,
            "messages": messages,
            # Deterministic, small output
            "max_tokens": kwargs.get("max_tokens", 256),
            "temperature": kwargs.get("temperature", 0.0),
        }

        response = await litellm.acompletion(**api_kwargs)
        output_text = (response.choices[0].message.content or "").strip()  # type: ignore

        coords = _parse_click_json(output_text)
        if coords is None:
            return None

        x, y = coords

        # The model answered in the coordinates of the image it was sent; if
        # that image was resized, scale back to the original screenshot.
        proc_w, proc_h = processed_img.size
        if (proc_w, proc_h) != (orig_w, orig_h) and proc_w and proc_h:
            x = int(round(x * (orig_w / float(proc_w))))
            y = int(round(y * (orig_h / float(proc_h))))

        # Clamp to original image bounds
        x = max(0, min(orig_w - 1, x))
        y = max(0, min(orig_h - 1, y))
        return x, y

    def get_capabilities(self) -> List[AgentCapability]:
        """Holo supports click grounding only."""
        return ["click"]
|