cua-agent 0.4.30__py3-none-any.whl → 0.4.32__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of cua-agent might be problematic.
- agent/adapters/huggingfacelocal_adapter.py +15 -66
- agent/adapters/models/__init__.py +33 -0
- agent/adapters/models/generic.py +75 -0
- agent/adapters/models/internvl.py +254 -0
- agent/adapters/models/opencua.py +100 -0
- agent/adapters/models/qwen2_5_vl.py +75 -0
- agent/agent.py +5 -1
- agent/callbacks/trajectory_saver.py +2 -0
- agent/cli.py +90 -1
- agent/integrations/hud/__init__.py +19 -0
- agent/loops/__init__.py +15 -1
- agent/loops/anthropic.py +2 -3
- agent/loops/composed_grounded.py +1 -1
- agent/loops/glm45v.py +3 -2
- agent/loops/gta1.py +1 -1
- agent/loops/holo.py +216 -0
- agent/loops/internvl.py +185 -0
- agent/loops/opencua.py +142 -0
- agent/loops/uitars.py +1 -1
- {cua_agent-0.4.30.dist-info → cua_agent-0.4.32.dist-info}/METADATA +20 -4
- {cua_agent-0.4.30.dist-info → cua_agent-0.4.32.dist-info}/RECORD +23 -15
- {cua_agent-0.4.30.dist-info → cua_agent-0.4.32.dist-info}/WHEEL +0 -0
- {cua_agent-0.4.30.dist-info → cua_agent-0.4.32.dist-info}/entry_points.txt +0 -0
agent/callbacks/trajectory_saver.py
CHANGED
@@ -188,6 +188,8 @@ class TrajectorySaverCallback(AsyncCallbackHandler):
         model_name_short = model.split("+")[-1].split("/")[-1].lower()[:16]
         if "+" in model:
             model_name_short = model.split("+")[0].lower()[:4] + "_" + model_name_short
+        # strip non-alphanumeric characters from model_name_short
+        model_name_short = ''.join(c for c in model_name_short if c.isalnum() or c == '_')

         # id format: yyyy-mm-dd_model_hhmmss_uuid[:4]
         now = datetime.now()
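The two added lines keep trajectory directory names free of path-unfriendly characters coming from model identifiers. A minimal sketch of the effect, using a hypothetical model string (not taken from this diff):

    model = "huggingface-local/Some-Org/Some.Model-7B+openai/gpt-4o"  # hypothetical identifier
    model_name_short = model.split("+")[-1].split("/")[-1].lower()[:16]  # "gpt-4o"
    if "+" in model:
        model_name_short = model.split("+")[0].lower()[:4] + "_" + model_name_short  # "hugg_gpt-4o"
    # new in 0.4.32: drop anything that is not alphanumeric or underscore
    model_name_short = ''.join(c for c in model_name_short if c.isalnum() or c == '_')  # "hugg_gpt4o"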
agent/cli.py
CHANGED
@@ -18,6 +18,15 @@ try:
     import json
     from typing import List, Dict, Any
     import dotenv
+    import base64
+    import time
+    import platform
+    from pathlib import Path
+    try:
+        from PIL import Image, ImageDraw
+        PIL_AVAILABLE = True
+    except Exception:
+        PIL_AVAILABLE = False
     from yaspin import yaspin
 except ImportError:
     if __name__ == "__main__":
@@ -248,6 +257,13 @@ Examples:
         help="Initial prompt to send to the agent. Leave blank for interactive mode."
     )

+    parser.add_argument(
+        "--predict-click",
+        dest="predict_click",
+        type=str,
+        help="Instruction for click prediction. If set, runs predict_click, draws crosshair on a fresh screenshot, saves and opens it."
+    )
+
     parser.add_argument(
         "-c", "--cache",
         action="store_true",
@@ -331,6 +347,7 @@ Examples:
     agent_kwargs = {
         "model": args.model,
         "tools": [computer],
+        "trust_remote_code": True,  # needed for some local models (e.g., InternVL, OpenCUA)
         "verbosity": 20 if args.verbose else 30,  # DEBUG vs WARNING
         "max_retries": args.max_retries
     }
@@ -353,7 +370,79 @@ Examples:

     agent = ComputerAgent(**agent_kwargs)

-    #
+    # If predict-click mode is requested, run once and exit
+    if args.predict_click:
+        if not PIL_AVAILABLE:
+            print_colored("❌ Pillow (PIL) is required for --predict-click visualization. Install with: pip install pillow", Colors.RED, bold=True)
+            sys.exit(1)
+
+        instruction = args.predict_click
+        print_colored(f"Predicting click for: '{instruction}'", Colors.CYAN)
+
+        # Take a fresh screenshot FIRST
+        try:
+            img_bytes = await computer.interface.screenshot()
+        except Exception as e:
+            print_colored(f"❌ Failed to take screenshot: {e}", Colors.RED, bold=True)
+            sys.exit(1)
+
+        # Encode screenshot to base64 for predict_click
+        try:
+            image_b64 = base64.b64encode(img_bytes).decode("utf-8")
+        except Exception as e:
+            print_colored(f"❌ Failed to encode screenshot: {e}", Colors.RED, bold=True)
+            sys.exit(1)
+
+        try:
+            coords = await agent.predict_click(instruction, image_b64=image_b64)
+        except Exception as e:
+            print_colored(f"❌ predict_click failed: {e}", Colors.RED, bold=True)
+            sys.exit(1)
+
+        if not coords:
+            print_colored("⚠️ No coordinates returned.", Colors.YELLOW)
+            sys.exit(2)
+
+        x, y = coords
+        print_colored(f"✅ Predicted coordinates: ({x}, {y})", Colors.GREEN)
+
+        try:
+            from io import BytesIO
+            with Image.open(BytesIO(img_bytes)) as img:
+                img = img.convert("RGB")
+                draw = ImageDraw.Draw(img)
+                # Draw crosshair
+                size = 12
+                color = (255, 0, 0)
+                draw.line([(x - size, y), (x + size, y)], fill=color, width=3)
+                draw.line([(x, y - size), (x, y + size)], fill=color, width=3)
+                # Optional small circle
+                r = 6
+                draw.ellipse([(x - r, y - r), (x + r, y + r)], outline=color, width=2)
+
+                out_path = Path.cwd() / f"predict_click_{int(time.time())}.png"
+                img.save(out_path)
+                print_colored(f"🖼️ Saved to {out_path}")
+
+                # Open the image with default viewer
+                try:
+                    system = platform.system().lower()
+                    if system == "windows":
+                        os.startfile(str(out_path))  # type: ignore[attr-defined]
+                    elif system == "darwin":
+                        os.system(f"open \"{out_path}\"")
+                    else:
+                        os.system(f"xdg-open \"{out_path}\"")
+                except Exception:
+                    pass
+        except Exception as e:
+            print_colored(f"❌ Failed to render/save screenshot: {e}", Colors.RED, bold=True)
+            sys.exit(1)
+
+        # Done
+        sys.exit(0)
+
+    # Start chat loop (default interactive mode)
     await chat_loop(agent, args.model, container_name, args.prompt, args.usage)


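Note: the new flag runs a single grounding pass instead of the interactive chat loop, e.g. passing `--predict-click "Click the Submit button"` together with the usual `--model` option; the exact console entry-point name is defined in entry_points.txt and is not shown in this diff.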
agent/integrations/hud/__init__.py
CHANGED
@@ -11,6 +11,7 @@ Exports:
 import time
 from typing import Any, Optional

+from agent.computers import is_agent_computer
 from datasets import load_dataset, Dataset
 from hud.datasets import Task, run_dataset
 from hud import trace
@@ -55,6 +56,15 @@ async def run_single_task(
     sample_task = dataset[task_id]  # type: ignore[index]
     task_prompt = sample_task.get("prompt", f"Task {sample_task.get('id', 0)}")  # type: ignore[attr-defined]

+    # Filter any existing Computer tools
+    # The eval framework will add its own Computer tool per task
+    if tools:
+        tools = [
+            tool
+            for tool in tools
+            if not is_agent_computer(tool)
+        ]
+
     with trace(name=task_prompt):
         task = Task(**sample_task)  # type: ignore[arg-type]

@@ -118,6 +128,15 @@ async def run_full_dataset(
     dataset_name = "custom"
     job_name = job_name or f"Evaluation {time.strftime('%H:%M %Y-%m-%d')}"

+    # Filter any existing Computer tools
+    # The eval framework will add its own Computer tool per task
+    if tools:
+        tools = [
+            tool
+            for tool in tools
+            if not is_agent_computer(tool)
+        ]
+
     # Execute evaluation
     return await run_dataset(
         name=job_name,
agent/loops/__init__.py
CHANGED
@@ -10,5 +10,19 @@ from . import omniparser
 from . import gta1
 from . import composed_grounded
 from . import glm45v
+from . import opencua
+from . import internvl
+from . import holo

-__all__ = [
+__all__ = [
+    "anthropic",
+    "openai",
+    "uitars",
+    "omniparser",
+    "gta1",
+    "composed_grounded",
+    "glm45v",
+    "opencua",
+    "internvl",
+    "holo",
+]
agent/loops/anthropic.py
CHANGED
@@ -1577,11 +1577,10 @@ Task: Click {instruction}. Output ONLY a click action on the target element."""
                     isinstance(item.get("action"), dict)):

                     action = item["action"]
-                    if action.get("
+                    if action.get("x") and action.get("y"):
                         x = action.get("x")
                         y = action.get("y")
-
-                        return (int(x), int(y))
+                        return (int(x), int(y))

         return None

agent/loops/composed_grounded.py
CHANGED
@@ -126,7 +126,7 @@ def get_last_computer_call_image(messages: List[Dict[str, Any]]) -> Optional[str]:


 @register_agent(r".*\+.*", priority=1)
-class ComposedGroundedConfig:
+class ComposedGroundedConfig(AsyncAgentConfig):
     """
     Composed-grounded agent configuration that uses both grounding and thinking models.

agent/loops/glm45v.py
CHANGED
@@ -844,7 +844,7 @@ Where x,y are coordinates normalized to 0-999 range."""
         api_kwargs = {
             "model": model,
             "messages": litellm_messages,
-            "max_tokens":
+            "max_tokens": 2056,
             "temperature": 0.001,
             "extra_body": {
                 "skip_special_tokens": False,
@@ -856,6 +856,7 @@ Where x,y are coordinates normalized to 0-999 range."""

         # Extract response content
         response_content = response.choices[0].message.content.strip()
+        print(response)

         # Parse response for click coordinates
         # Look for coordinates in the response, handling special tokens
@@ -866,7 +867,7 @@ Where x,y are coordinates normalized to 0-999 range."""
             # Fallback: look for coordinates without special tokens
             coord_pattern = r"left_click\(start_box='?\[(\d+),(\d+)\]'?\)"
             match = re.search(coord_pattern, response_content)
-
+
             if match:
                 x, y = int(match.group(1)), int(match.group(2))

agent/loops/gta1.py
CHANGED
agent/loops/holo.py
ADDED
@@ -0,0 +1,216 @@
"""
Holo 1.5 agent loop implementation for click prediction using litellm.acompletion.

Implements the Holo1.5 grounding behavior:
- Prompt asks for absolute pixel coordinates in JSON: {"action":"click_absolute","x":int,"y":int}
- Optionally resizes the image using Qwen2-VL smart_resize parameters (via transformers AutoProcessor)
- If resized, maps predicted coordinates back to the original screenshot resolution

Note: We do NOT manually load the model; acompletions (via HuggingFaceLocalAdapter)
will handle loading based on the provided model name.
"""

from __future__ import annotations

import base64
import json
from io import BytesIO
from typing import Any, Dict, List, Optional, Tuple

import litellm
from PIL import Image

from ..decorators import register_agent
from .base import AsyncAgentConfig
from ..types import AgentCapability


def _strip_hf_prefix(model: str) -> str:
    """Strip provider prefixes like 'huggingface-local/' from model names for HF processor load."""
    if "/" in model and model.lower().startswith("huggingface-local/"):
        return model.split("/", 1)[1]
    return model


def _maybe_smart_resize(image: Image.Image, model: str) -> Tuple[Image.Image, Tuple[int, int]]:
    """
    Try to compute Qwen2-VL smart_resize output size using transformers AutoProcessor.

    Returns (processed_image, (orig_w, orig_h)). If transformers or processor unavailable,
    returns the original image and size without resizing.
    """
    orig_w, orig_h = image.size
    try:
        # Import lazily to avoid hard dependency if not installed
        from transformers import AutoProcessor  # type: ignore
        from transformers.models.qwen2_vl.image_processing_qwen2_vl import (  # type: ignore
            smart_resize,
        )

        processor_name = _strip_hf_prefix(model)
        processor = AutoProcessor.from_pretrained(processor_name)
        image_processor = getattr(processor, "image_processor", None)
        if image_processor is None:
            return image, (orig_w, orig_h)

        factor = getattr(image_processor, "patch_size", 14) * getattr(image_processor, "merge_size", 1)
        min_pixels = getattr(image_processor, "min_pixels", 256 * 256)
        max_pixels = getattr(image_processor, "max_pixels", 1536 * 1536)

        resized_h, resized_w = smart_resize(
            orig_h,
            orig_w,
            factor=factor,
            min_pixels=min_pixels,
            max_pixels=max_pixels,
        )

        if (resized_w, resized_h) == (orig_w, orig_h):
            return image, (orig_w, orig_h)

        processed = image.resize((resized_w, resized_h), resample=Image.Resampling.LANCZOS)
        return processed, (orig_w, orig_h)
    except Exception:
        # If any failure (no transformers, processor load error), fall back to original
        return image, (orig_w, orig_h)


def _build_holo_prompt(instruction: str) -> str:
    """Construct the Holo1.5 grounding prompt."""
    # Keep it close to the cookbook while avoiding heavy schema generation
    schema_hint = '{"action": "click_absolute", "x": <int>, "y": <int>}'
    return (
        "Localize an element on the GUI image according to the provided target and output a click position. "
        f"You must output a valid JSON following the format: {schema_hint} "
        f"Your target is: {instruction}"
    )


def _parse_click_json(output_text: str) -> Optional[Tuple[int, int]]:
    """
    Parse JSON from model output and extract x, y ints.
    Tries to find the first JSON object substring if extra text is present.
    """
    try:
        # Fast path: direct JSON
        data = json.loads(output_text)
    except Exception:
        # Try to locate a JSON object within the text
        start = output_text.find("{")
        end = output_text.rfind("}")
        if start == -1 or end == -1 or end <= start:
            return None
        try:
            data = json.loads(output_text[start : end + 1])
        except Exception:
            return None

    try:
        x = int(data.get("x"))
        y = int(data.get("y"))
        return x, y
    except Exception:
        return None


@register_agent(models=r"(?i).*(Holo1\.5|Hcompany/Holo1\.5).*")
class HoloConfig(AsyncAgentConfig):
    """Holo is a family of UI grounding models from H Company"""

    async def predict_step(
        self,
        messages: List[Dict[str, Any]],
        model: str,
        tools: Optional[List[Dict[str, Any]]] = None,
        max_retries: Optional[int] = None,
        stream: bool = False,
        computer_handler=None,
        _on_api_start=None,
        _on_api_end=None,
        _on_usage=None,
        _on_screenshot=None,
        **kwargs,
    ) -> Dict[str, Any]:
        # Holo models are only trained on UI localization tasks, not all-in-one agent
        raise NotImplementedError()

    async def predict_click(
        self,
        model: str,
        image_b64: str,
        instruction: str,
        **kwargs,
    ) -> Optional[Tuple[int, int]]:
        """
        Predict click coordinates using Holo1.5 via litellm.acompletion.

        - Optionally smart-resizes the image using Qwen2-VL rules if transformers are available
        - Prompts for JSON with absolute pixel coordinates
        - Parses x,y and maps back to original screenshot size if resized
        """
        try:
            img_bytes = base64.b64decode(image_b64)
            original_img = Image.open(BytesIO(img_bytes))
        except Exception:
            return None

        # Optional preprocessing
        processed_img, (orig_w, orig_h) = _maybe_smart_resize(original_img, model)

        # If we resized, send the resized image; otherwise send original
        img_to_send = processed_img
        buf = BytesIO()
        img_to_send.save(buf, format="PNG")
        processed_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")

        prompt = _build_holo_prompt(instruction)

        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/png;base64,{processed_b64}"},
                    },
                    {"type": "text", "text": prompt},
                ],
            }
        ]

        api_kwargs = {
            "model": model,
            "messages": messages,
            # Deterministic, small output
            "max_tokens": kwargs.get("max_tokens", 256),
            "temperature": kwargs.get("temperature", 0.0),
        }

        response = await litellm.acompletion(**api_kwargs)
        output_text = (response.choices[0].message.content or "").strip()  # type: ignore

        coords = _parse_click_json(output_text)
        if coords is None:
            return None

        x, y = coords

        # Map back to original size if we resized
        proc_w, proc_h = img_to_send.size
        if (proc_w, proc_h) != (orig_w, orig_h):
            try:
                sx = orig_w / float(proc_w)
                sy = orig_h / float(proc_h)
                x = int(round(x * sx))
                y = int(round(y * sy))
            except Exception:
                # Fallback: clamp within original bounds
                pass

        # Clamp to original image bounds
        x = max(0, min(orig_w - 1, x))
        y = max(0, min(orig_h - 1, y))
        return x, y

    def get_capabilities(self) -> List[AgentCapability]:
        return ["click"]
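Because HoloConfig only implements predict_click (predict_step raises NotImplementedError), the new loop is exercised through the agent's grounding path. A minimal sketch of programmatic use, mirroring the CLI code above; the import path and model string are assumptions, not taken from this diff:

    import asyncio
    import base64

    from agent import ComputerAgent  # import path assumed

    async def main():
        # illustrative Holo 1.5 checkpoint name; any model matching the registered regex would route here
        agent = ComputerAgent(model="huggingface-local/Hcompany/Holo1.5-7B")
        with open("screenshot.png", "rb") as f:
            image_b64 = base64.b64encode(f.read()).decode("utf-8")
        coords = await agent.predict_click("Click the Submit button", image_b64=image_b64)
        print(coords)  # (x, y) in original screenshot pixels, or None if parsing failed

    asyncio.run(main())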
agent/loops/internvl.py
ADDED
@@ -0,0 +1,185 @@
"""
InternVL agent loop implementation for click prediction using litellm.acompletion.

Implements the ScreenSpot InternVL grounding baseline behavior:
- Uses the exact grounding prompt format with <image> and <ref> tags
- Expects coordinates in 0-1000 normalized range in formats [[x1,y1,x2,y2]] or [[x,y]]
- Converts to pixel coordinates relative to the original screenshot size

Note: We do NOT manually load the InternVL model; acompletions (via HuggingFaceLocalAdapter)
will handle loading based on the provided model name.
"""

from __future__ import annotations

import base64
import math
import re
from io import BytesIO
from typing import Any, Dict, List, Optional, Tuple

from PIL import Image
import litellm

from ..decorators import register_agent
from .composed_grounded import ComposedGroundedConfig
from ..types import AgentCapability


# Regex patterns for extracting coordinates
# Accept optional whitespace and optional decimal fractions
_NUM = r"(\d+(?:\.\d+)?)"
_POINT_PATTERN = re.compile(r"\[\[\s*" + _NUM + r"\s*,\s*" + _NUM + r"\s*\]\]")
_BBOX_PATTERN = re.compile(
    r"\[\[\s*" + _NUM + r"\s*,\s*" + _NUM + r"\s*,\s*" + _NUM + r"\s*,\s*" + _NUM + r"\s*\]\]"
)


def _extract_first_point(text: str) -> Optional[Tuple[float, float]]:
    """Extract the first [[x,y]] as normalized (0-1000) floats."""
    m = _POINT_PATTERN.search(text)
    if not m:
        return None
    try:
        x = float(m.group(1))
        y = float(m.group(2))
        return x, y
    except Exception:
        return None


def _extract_last_bbox(text: str) -> Optional[Tuple[float, float, float, float]]:
    """Extract the last [[x1,y1,x2,y2]] as normalized (0-1000) floats."""
    matches = list(_BBOX_PATTERN.finditer(text))
    if not matches:
        return None
    m = matches[-1]
    try:
        x1 = float(m.group(1))
        y1 = float(m.group(2))
        x2 = float(m.group(3))
        y2 = float(m.group(4))
        return x1, y1, x2, y2
    except Exception:
        return None


def _scale_norm_to_pixels(x_norm: float, y_norm: float, width: int, height: int) -> Tuple[int, int]:
    """Scale 0-1000 normalized coordinates to pixel coordinates for given image size."""
    x_px = int(math.floor((x_norm / 1000.0) * width))
    y_px = int(math.floor((y_norm / 1000.0) * height))
    # Clamp to image bounds just in case
    x_px = max(0, min(width - 1, x_px))
    y_px = max(0, min(height - 1, y_px))
    return x_px, y_px


@register_agent(models=r"(?i).*InternVL.*")
class InternVLConfig(ComposedGroundedConfig):
    """InternVL agent configuration reusing ComposedGroundedConfig for steps and
    overriding predict_click to implement ScreenSpot InternVL grounding baseline."""

    async def predict_step(
        self,
        messages: List[Dict[str, Any]],
        model: str,
        tools: Optional[List[Dict[str, Any]]] = None,
        max_retries: Optional[int] = None,
        stream: bool = False,
        computer_handler=None,
        _on_api_start=None,
        _on_api_end=None,
        _on_usage=None,
        _on_screenshot=None,
        **kwargs
    ) -> Dict[str, Any]:
        """Fallback to a self-composed model"""
        return await super().predict_step(
            messages=messages,
            model=f"{model}+{model}",
            tools=tools,
            max_retries=max_retries,
            stream=stream,
            computer_handler=computer_handler,
            _on_api_start=_on_api_start,
            _on_api_end=_on_api_end,
            _on_usage=_on_usage,
            _on_screenshot=_on_screenshot,
            **kwargs
        )

    async def predict_click(
        self,
        model: str,
        image_b64: str,
        instruction: str,
        **kwargs
    ) -> Optional[Tuple[int, int]]:
        """
        Predict click coordinates using InternVL via litellm.acompletion.

        Behavior mirrors the ScreenSpot InternVL baseline:
        - Prompt: "<image>\nPlease provide the bounding box coordinate of the UI element this user instruction describes: <ref>{instruction}</ref>. Answer in the format of [[x1, y1, x2, y2]]"
        - Parse either [[x,y]] point or [[x1,y1,x2,y2]] bbox, using bbox center if point missing
        - Coordinates are 0-1000 normalized; convert to pixel coordinates for the original screenshot
        """
        try:
            # Decode image dimensions to scale the normalized outputs
            img_bytes = base64.b64decode(image_b64)
            image = Image.open(BytesIO(img_bytes))
            width, height = image.size
        except Exception:
            # If decoding fails, proceed with a safe default size to avoid crash
            width, height = 1920, 1080

        # Build grounding prompt exactly like the baseline
        grounding_prompt = (
            f"Please provide the bounding box coordinate of the UI element this user instruction describes: <ref>{instruction}</ref>. "
            f"Answer in the format of [[x1, y1, x2, y2]]"
        )

        # Prepare messages for LiteLLM
        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/png;base64,{image_b64}"},
                    },
                    {"type": "text", "text": grounding_prompt},
                ],
            }
        ]

        # Call acompletion; HuggingFaceLocalAdapter/model handler will handle InternVL loading
        api_kwargs = {
            "model": model,
            "messages": messages,
            # Conservative generation params akin to baseline (deterministic)
            "max_tokens": kwargs.get("max_tokens", 256),
            "temperature": kwargs.get("temperature", 0.0),
        }

        response = await litellm.acompletion(**api_kwargs)
        output_text = (response.choices[0].message.content or "").strip()  # type: ignore

        # print(f"InternVL output: {output_text}")

        # Try to parse a point first; if absent, parse bbox and take center
        point = _extract_first_point(output_text)
        if point is None:
            bbox = _extract_last_bbox(output_text)
            if bbox is None:
                return None
            x1, y1, x2, y2 = bbox
            cx = (x1 + x2) / 2.0
            cy = (y1 + y2) / 2.0
            point = (cx, cy)

        x_norm, y_norm = point
        x_px, y_px = _scale_norm_to_pixels(x_norm, y_norm, width, height)
        return (x_px, y_px)

    def get_capabilities(self) -> List[AgentCapability]:
        return ["click", "step"]
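InternVL reports positions on a 0-1000 grid regardless of the actual screenshot size; a quick sketch of the conversion performed by _scale_norm_to_pixels (the numbers are made up for illustration):

    # bbox [[250, 400, 350, 500]] returned for a 1920x1080 screenshot
    cx, cy = (250 + 350) / 2.0, (400 + 500) / 2.0   # center = (300.0, 450.0) in 0-1000 space
    x_px = int((300.0 / 1000.0) * 1920)             # 576
    y_px = int((450.0 / 1000.0) * 1080)             # 486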