cua-agent 0.4.22__py3-none-any.whl → 0.7.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of cua-agent might be problematic.
- agent/__init__.py +4 -10
- agent/__main__.py +2 -1
- agent/adapters/__init__.py +4 -0
- agent/adapters/azure_ml_adapter.py +283 -0
- agent/adapters/cua_adapter.py +161 -0
- agent/adapters/huggingfacelocal_adapter.py +67 -125
- agent/adapters/human_adapter.py +116 -114
- agent/adapters/mlxvlm_adapter.py +110 -99
- agent/adapters/models/__init__.py +41 -0
- agent/adapters/models/generic.py +78 -0
- agent/adapters/models/internvl.py +290 -0
- agent/adapters/models/opencua.py +115 -0
- agent/adapters/models/qwen2_5_vl.py +78 -0
- agent/agent.py +337 -185
- agent/callbacks/__init__.py +9 -4
- agent/callbacks/base.py +45 -31
- agent/callbacks/budget_manager.py +22 -10
- agent/callbacks/image_retention.py +54 -98
- agent/callbacks/logging.py +55 -42
- agent/callbacks/operator_validator.py +35 -33
- agent/callbacks/otel.py +291 -0
- agent/callbacks/pii_anonymization.py +19 -16
- agent/callbacks/prompt_instructions.py +47 -0
- agent/callbacks/telemetry.py +99 -61
- agent/callbacks/trajectory_saver.py +95 -69
- agent/cli.py +269 -119
- agent/computers/__init__.py +14 -9
- agent/computers/base.py +32 -19
- agent/computers/cua.py +52 -25
- agent/computers/custom.py +78 -71
- agent/decorators.py +23 -14
- agent/human_tool/__init__.py +2 -7
- agent/human_tool/__main__.py +6 -2
- agent/human_tool/server.py +48 -37
- agent/human_tool/ui.py +359 -235
- agent/integrations/hud/__init__.py +38 -99
- agent/integrations/hud/agent.py +369 -0
- agent/integrations/hud/proxy.py +166 -52
- agent/loops/__init__.py +44 -14
- agent/loops/anthropic.py +579 -492
- agent/loops/base.py +19 -15
- agent/loops/composed_grounded.py +136 -150
- agent/loops/fara/__init__.py +8 -0
- agent/loops/fara/config.py +506 -0
- agent/loops/fara/helpers.py +357 -0
- agent/loops/fara/schema.py +143 -0
- agent/loops/gelato.py +183 -0
- agent/loops/gemini.py +935 -0
- agent/loops/generic_vlm.py +601 -0
- agent/loops/glm45v.py +140 -135
- agent/loops/gta1.py +48 -51
- agent/loops/holo.py +218 -0
- agent/loops/internvl.py +180 -0
- agent/loops/moondream3.py +493 -0
- agent/loops/omniparser.py +326 -226
- agent/loops/openai.py +50 -51
- agent/loops/opencua.py +134 -0
- agent/loops/uiins.py +175 -0
- agent/loops/uitars.py +247 -206
- agent/loops/uitars2.py +951 -0
- agent/playground/__init__.py +5 -0
- agent/playground/server.py +301 -0
- agent/proxy/examples.py +61 -57
- agent/proxy/handlers.py +46 -39
- agent/responses.py +447 -347
- agent/tools/__init__.py +24 -0
- agent/tools/base.py +253 -0
- agent/tools/browser_tool.py +423 -0
- agent/types.py +11 -5
- agent/ui/__init__.py +1 -1
- agent/ui/__main__.py +1 -1
- agent/ui/gradio/app.py +25 -22
- agent/ui/gradio/ui_components.py +314 -167
- cua_agent-0.7.16.dist-info/METADATA +85 -0
- cua_agent-0.7.16.dist-info/RECORD +79 -0
- {cua_agent-0.4.22.dist-info → cua_agent-0.7.16.dist-info}/WHEEL +1 -1
- cua_agent-0.4.22.dist-info/METADATA +0 -436
- cua_agent-0.4.22.dist-info/RECORD +0 -51
- {cua_agent-0.4.22.dist-info → cua_agent-0.7.16.dist-info}/entry_points.txt +0 -0
agent/loops/gta1.py
CHANGED
@@ -5,75 +5,80 @@ Code: https://github.com/Yan98/GTA1
 """
 
 import asyncio
+import base64
 import json
+import math
 import re
-import base64
-from typing import Dict, List, Any, AsyncGenerator, Union, Optional, Tuple
-from io import BytesIO
 import uuid
-from
+from io import BytesIO
+from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
+
 import litellm
-import
+from PIL import Image
 
 from ..decorators import register_agent
-from ..types import Messages, AgentResponse, Tools, AgentCapability
 from ..loops.base import AsyncAgentConfig
+from ..types import AgentCapability, AgentResponse, Messages, Tools
 
-SYSTEM_PROMPT =
+SYSTEM_PROMPT = """
 You are an expert UI element locator. Given a GUI image and a user's element description, provide the coordinates of the specified element as a single (x,y) point. The image resolution is height {height} and width {width}. For elements with area, return the center point.
 
 Output the coordinate pair exactly:
 (x,y)
-
+""".strip()
+
 
 def extract_coordinates(raw_string: str) -> Tuple[float, float]:
     """Extract coordinates from model output."""
     try:
         matches = re.findall(r"\((-?\d*\.?\d+),\s*(-?\d*\.?\d+)\)", raw_string)
-        return tuple(map(float, matches[0]))
+        return tuple(map(float, matches[0]))  # type: ignore
     except:
         return (0.0, 0.0)
 
-def smart_resize(height: int, width: int, factor: int = 28, min_pixels: int = 3136, max_pixels: int = 8847360) -> Tuple[int, int]:
+
+def smart_resize(
+    height: int, width: int, factor: int = 28, min_pixels: int = 3136, max_pixels: int = 8847360
+) -> Tuple[int, int]:
     """Smart resize function similar to qwen_vl_utils."""
     # Calculate the total pixels
     total_pixels = height * width
-
+
     # If already within bounds, return original dimensions
     if min_pixels <= total_pixels <= max_pixels:
         # Round to nearest factor
         new_height = (height // factor) * factor
         new_width = (width // factor) * factor
         return new_height, new_width
-
+
     # Calculate scaling factor
     if total_pixels > max_pixels:
         scale = (max_pixels / total_pixels) ** 0.5
     else:
         scale = (min_pixels / total_pixels) ** 0.5
-
+
     # Apply scaling
     new_height = int(height * scale)
     new_width = int(width * scale)
-
+
     # Round to nearest factor
     new_height = (new_height // factor) * factor
     new_width = (new_width // factor) * factor
-
+
     # Ensure minimum size
     new_height = max(new_height, factor)
     new_width = max(new_width, factor)
-
+
     return new_height, new_width
 
+
 @register_agent(models=r".*GTA1.*")
 class GTA1Config(AsyncAgentConfig):
     """GTA1 agent configuration implementing AsyncAgentConfig protocol for click prediction."""
-
+
     def __init__(self):
         self.current_model = None
         self.last_screenshot_b64 = None
-
 
     async def predict_step(
         self,
@@ -87,25 +92,21 @@ class GTA1Config(AsyncAgentConfig):
         _on_api_end=None,
         _on_usage=None,
         _on_screenshot=None,
-        **kwargs
+        **kwargs,
     ) -> Dict[str, Any]:
         raise NotImplementedError()
 
     async def predict_click(
-        self,
-        model: str,
-        image_b64: str,
-        instruction: str,
-        **kwargs
+        self, model: str, image_b64: str, instruction: str, **kwargs
     ) -> Optional[Tuple[float, float]]:
         """
         Predict click coordinates using GTA1 model via litellm.acompletion.
-
+
         Args:
             model: The GTA1 model name
             image_b64: Base64 encoded image
             instruction: Instruction for where to click
-
+
         Returns:
             Tuple of (x, y) coordinates or None if prediction fails
         """
@@ -113,66 +114,62 @@ class GTA1Config(AsyncAgentConfig):
         image_data = base64.b64decode(image_b64)
         image = Image.open(BytesIO(image_data))
         width, height = image.width, image.height
-
+
         # Smart resize the image (similar to qwen_vl_utils)
         resized_height, resized_width = smart_resize(
-            height,
+            height,
+            width,
             factor=28,  # Default factor for Qwen models
             min_pixels=3136,
-            max_pixels=4096 * 2160
+            max_pixels=4096 * 2160,
        )
         resized_image = image.resize((resized_width, resized_height))
         scale_x, scale_y = width / resized_width, height / resized_height
-
+
         # Convert resized image back to base64
         buffered = BytesIO()
         resized_image.save(buffered, format="PNG")
         resized_image_b64 = base64.b64encode(buffered.getvalue()).decode()
-
+
         # Prepare system and user messages
         system_message = {
             "role": "system",
-            "content": SYSTEM_PROMPT.format(height=resized_height, width=resized_width)
+            "content": SYSTEM_PROMPT.format(height=resized_height, width=resized_width),
         }
-
+
         user_message = {
             "role": "user",
             "content": [
                 {
                     "type": "image_url",
-                    "image_url": {
-                        "url": f"data:image/png;base64,{resized_image_b64}"
-                    }
+                    "image_url": {"url": f"data:image/png;base64,{resized_image_b64}"},
                 },
-                {
-                    "type": "text",
-                    "text": instruction
-                }
-            ]
+                {"type": "text", "text": instruction},
+            ],
         }
-
+
         # Prepare API call kwargs
         api_kwargs = {
             "model": model,
             "messages": [system_message, user_message],
-            "max_tokens":
+            "max_tokens": 2056,
             "temperature": 0.0,
-            **kwargs
+            **kwargs,
         }
-
+
         # Use liteLLM acompletion
         response = await litellm.acompletion(**api_kwargs)
-
+
         # Extract response text
-        output_text = response.choices[0].message.content
-
+        output_text = response.choices[0].message.content  # type: ignore
+
         # Extract and rescale coordinates
-        pred_x, pred_y = extract_coordinates(output_text)
+        pred_x, pred_y = extract_coordinates(output_text)  # type: ignore
         pred_x *= scale_x
         pred_y *= scale_y
-
+
         return (math.floor(pred_x), math.floor(pred_y))
-
+
     def get_capabilities(self) -> List[AgentCapability]:
         """Return the capabilities supported by this agent."""
         return ["click"]
agent/loops/holo.py
ADDED
@@ -0,0 +1,218 @@
+"""
+Holo 1.5 agent loop implementation for click prediction using litellm.acompletion.
+
+Implements the Holo1.5 grounding behavior:
+- Prompt asks for absolute pixel coordinates in JSON: {"action":"click_absolute","x":int,"y":int}
+- Optionally resizes the image using Qwen2-VL smart_resize parameters (via transformers AutoProcessor)
+- If resized, maps predicted coordinates back to the original screenshot resolution
+
+Note: We do NOT manually load the model; acompletions (via HuggingFaceLocalAdapter)
+will handle loading based on the provided model name.
+"""
+
+from __future__ import annotations
+
+import base64
+import json
+from io import BytesIO
+from typing import Any, Dict, List, Optional, Tuple
+
+import litellm
+from PIL import Image
+
+from ..decorators import register_agent
+from ..types import AgentCapability
+from .base import AsyncAgentConfig
+
+
+def _strip_hf_prefix(model: str) -> str:
+    """Strip provider prefixes like 'huggingface-local/' from model names for HF processor load."""
+    if "/" in model and model.lower().startswith("huggingface-local/"):
+        return model.split("/", 1)[1]
+    return model
+
+
+def _maybe_smart_resize(image: Image.Image, model: str) -> Tuple[Image.Image, Tuple[int, int]]:
+    """
+    Try to compute Qwen2-VL smart_resize output size using transformers AutoProcessor.
+
+    Returns (processed_image, (orig_w, orig_h)). If transformers or processor unavailable,
+    returns the original image and size without resizing.
+    """
+    orig_w, orig_h = image.size
+    try:
+        # Import lazily to avoid hard dependency if not installed
+        from transformers import AutoProcessor  # type: ignore
+        from transformers.models.qwen2_vl.image_processing_qwen2_vl import (  # type: ignore
+            smart_resize,
+        )
+
+        processor_name = _strip_hf_prefix(model)
+        processor = AutoProcessor.from_pretrained(processor_name)
+        image_processor = getattr(processor, "image_processor", None)
+        if image_processor is None:
+            return image, (orig_w, orig_h)
+
+        factor = getattr(image_processor, "patch_size", 14) * getattr(
+            image_processor, "merge_size", 1
+        )
+        min_pixels = getattr(image_processor, "min_pixels", 256 * 256)
+        max_pixels = getattr(image_processor, "max_pixels", 1536 * 1536)
+
+        resized_h, resized_w = smart_resize(
+            orig_h,
+            orig_w,
+            factor=factor,
+            min_pixels=min_pixels,
+            max_pixels=max_pixels,
+        )
+
+        if (resized_w, resized_h) == (orig_w, orig_h):
+            return image, (orig_w, orig_h)
+
+        processed = image.resize((resized_w, resized_h), resample=Image.Resampling.LANCZOS)
+        return processed, (orig_w, orig_h)
+    except Exception:
+        # If any failure (no transformers, processor load error), fall back to original
+        return image, (orig_w, orig_h)
+
+
+def _build_holo_prompt(instruction: str) -> str:
+    """Construct the Holo1.5 grounding prompt."""
+    # Keep it close to the cookbook while avoiding heavy schema generation
+    schema_hint = '{"action": "click_absolute", "x": <int>, "y": <int>}'
+    return (
+        "Localize an element on the GUI image according to the provided target and output a click position. "
+        f"You must output a valid JSON following the format: {schema_hint} "
+        f"Your target is: {instruction}"
+    )
+
+
+def _parse_click_json(output_text: str) -> Optional[Tuple[int, int]]:
+    """
+    Parse JSON from model output and extract x, y ints.
+    Tries to find the first JSON object substring if extra text is present.
+    """
+    try:
+        # Fast path: direct JSON
+        data = json.loads(output_text)
+    except Exception:
+        # Try to locate a JSON object within the text
+        start = output_text.find("{")
+        end = output_text.rfind("}")
+        if start == -1 or end == -1 or end <= start:
+            return None
+        try:
+            data = json.loads(output_text[start : end + 1])
+        except Exception:
+            return None
+
+    try:
+        x = int(data.get("x"))
+        y = int(data.get("y"))
+        return x, y
+    except Exception:
+        return None
+
+
+@register_agent(models=r"(?i).*(Holo1\.5|Hcompany/Holo1\.5).*")
+class HoloConfig(AsyncAgentConfig):
+    """Holo is a family of UI grounding models from H Company"""
+
+    async def predict_step(
+        self,
+        messages: List[Dict[str, Any]],
+        model: str,
+        tools: Optional[List[Dict[str, Any]]] = None,
+        max_retries: Optional[int] = None,
+        stream: bool = False,
+        computer_handler=None,
+        _on_api_start=None,
+        _on_api_end=None,
+        _on_usage=None,
+        _on_screenshot=None,
+        **kwargs,
+    ) -> Dict[str, Any]:
+        # Holo models are only trained on UI localization tasks, not all-in-one agent
+        raise NotImplementedError()
+
+    async def predict_click(
+        self,
+        model: str,
+        image_b64: str,
+        instruction: str,
+        **kwargs,
+    ) -> Optional[Tuple[int, int]]:
+        """
+        Predict click coordinates using Holo1.5 via litellm.acompletion.
+
+        - Optionally smart-resizes the image using Qwen2-VL rules if transformers are available
+        - Prompts for JSON with absolute pixel coordinates
+        - Parses x,y and maps back to original screenshot size if resized
+        """
+        try:
+            img_bytes = base64.b64decode(image_b64)
+            original_img = Image.open(BytesIO(img_bytes))
+        except Exception:
+            return None
+
+        # Optional preprocessing
+        processed_img, (orig_w, orig_h) = _maybe_smart_resize(original_img, model)
+
+        # If we resized, send the resized image; otherwise send original
+        img_to_send = processed_img
+        buf = BytesIO()
+        img_to_send.save(buf, format="PNG")
+        processed_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
+
+        prompt = _build_holo_prompt(instruction)
+
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "image_url",
+                        "image_url": {"url": f"data:image/png;base64,{processed_b64}"},
+                    },
+                    {"type": "text", "text": prompt},
+                ],
+            }
+        ]
+
+        api_kwargs = {
+            "model": model,
+            "messages": messages,
+            # Deterministic, small output
+            "max_tokens": kwargs.get("max_tokens", 256),
+            "temperature": kwargs.get("temperature", 0.0),
+        }
+
+        response = await litellm.acompletion(**api_kwargs)
+        output_text = (response.choices[0].message.content or "").strip()  # type: ignore
+
+        coords = _parse_click_json(output_text)
+        if coords is None:
+            return None
+
+        x, y = coords
+
+        # Map back to original size if we resized
+        proc_w, proc_h = img_to_send.size
+        if (proc_w, proc_h) != (orig_w, orig_h):
+            try:
+                sx = orig_w / float(proc_w)
+                sy = orig_h / float(proc_h)
+                x = int(round(x * sx))
+                y = int(round(y * sy))
+            except Exception:
+                # Fallback: clamp within original bounds
+                pass
+
+        # Clamp to original image bounds
+        x = max(0, min(orig_w - 1, x))
+        y = max(0, min(orig_h - 1, y))
+        return x, y
+
+    def get_capabilities(self) -> List[AgentCapability]:
+        return ["click"]
agent/loops/internvl.py
ADDED
@@ -0,0 +1,180 @@
+"""
+InternVL agent loop implementation for click prediction using litellm.acompletion.
+
+Implements the ScreenSpot InternVL grounding baseline behavior:
+- Uses the exact grounding prompt format with <image> and <ref> tags
+- Expects coordinates in 0-1000 normalized range in formats [[x1,y1,x2,y2]] or [[x,y]]
+- Converts to pixel coordinates relative to the original screenshot size
+
+Note: We do NOT manually load the InternVL model; acompletions (via HuggingFaceLocalAdapter)
+will handle loading based on the provided model name.
+"""
+
+from __future__ import annotations
+
+import base64
+import math
+import re
+from io import BytesIO
+from typing import Any, Dict, List, Optional, Tuple
+
+import litellm
+from PIL import Image
+
+from ..decorators import register_agent
+from ..types import AgentCapability
+from .composed_grounded import ComposedGroundedConfig
+
+# Regex patterns for extracting coordinates
+# Accept optional whitespace and optional decimal fractions
+_NUM = r"(\d+(?:\.\d+)?)"
+_POINT_PATTERN = re.compile(r"\[\[\s*" + _NUM + r"\s*,\s*" + _NUM + r"\s*\]\]")
+_BBOX_PATTERN = re.compile(
+    r"\[\[\s*" + _NUM + r"\s*,\s*" + _NUM + r"\s*,\s*" + _NUM + r"\s*,\s*" + _NUM + r"\s*\]\]"
+)
+
+
+def _extract_first_point(text: str) -> Optional[Tuple[float, float]]:
+    """Extract the first [[x,y]] as normalized (0-1000) floats."""
+    m = _POINT_PATTERN.search(text)
+    if not m:
+        return None
+    try:
+        x = float(m.group(1))
+        y = float(m.group(2))
+        return x, y
+    except Exception:
+        return None
+
+
+def _extract_last_bbox(text: str) -> Optional[Tuple[float, float, float, float]]:
+    """Extract the last [[x1,y1,x2,y2]] as normalized (0-1000) floats."""
+    matches = list(_BBOX_PATTERN.finditer(text))
+    if not matches:
+        return None
+    m = matches[-1]
+    try:
+        x1 = float(m.group(1))
+        y1 = float(m.group(2))
+        x2 = float(m.group(3))
+        y2 = float(m.group(4))
+        return x1, y1, x2, y2
+    except Exception:
+        return None
+
+
+def _scale_norm_to_pixels(x_norm: float, y_norm: float, width: int, height: int) -> Tuple[int, int]:
+    """Scale 0-1000 normalized coordinates to pixel coordinates for given image size."""
+    x_px = int(math.floor((x_norm / 1000.0) * width))
+    y_px = int(math.floor((y_norm / 1000.0) * height))
+    # Clamp to image bounds just in case
+    x_px = max(0, min(width - 1, x_px))
+    y_px = max(0, min(height - 1, y_px))
+    return x_px, y_px
+
+
+@register_agent(models=r"(?i).*InternVL.*")
+class InternVLConfig(ComposedGroundedConfig):
+    """InternVL agent configuration reusing ComposedGroundedConfig for steps and
+    overriding predict_click to implement ScreenSpot InternVL grounding baseline."""
+
+    async def predict_step(
+        self,
+        messages: List[Dict[str, Any]],
+        model: str,
+        tools: Optional[List[Dict[str, Any]]] = None,
+        max_retries: Optional[int] = None,
+        stream: bool = False,
+        computer_handler=None,
+        _on_api_start=None,
+        _on_api_end=None,
+        _on_usage=None,
+        _on_screenshot=None,
+        **kwargs,
+    ) -> Dict[str, Any]:
+        """Fallback to a self-composed model"""
+        return await super().predict_step(
+            messages=messages,
+            model=f"{model}+{model}",
+            tools=tools,
+            max_retries=max_retries,
+            stream=stream,
+            computer_handler=computer_handler,
+            _on_api_start=_on_api_start,
+            _on_api_end=_on_api_end,
+            _on_usage=_on_usage,
+            _on_screenshot=_on_screenshot,
+            **kwargs,
+        )
+
+    async def predict_click(
+        self, model: str, image_b64: str, instruction: str, **kwargs
+    ) -> Optional[Tuple[int, int]]:
+        """
+        Predict click coordinates using InternVL via litellm.acompletion.
+
+        Behavior mirrors the ScreenSpot InternVL baseline:
+        - Prompt: "<image>\nPlease provide the bounding box coordinate of the UI element this user instruction describes: <ref>{instruction}</ref>. Answer in the format of [[x1, y1, x2, y2]]"
+        - Parse either [[x,y]] point or [[x1,y1,x2,y2]] bbox, using bbox center if point missing
+        - Coordinates are 0-1000 normalized; convert to pixel coordinates for the original screenshot
+        """
+        try:
+            # Decode image dimensions to scale the normalized outputs
+            img_bytes = base64.b64decode(image_b64)
+            image = Image.open(BytesIO(img_bytes))
+            width, height = image.size
+        except Exception:
+            # If decoding fails, proceed with a safe default size to avoid crash
+            width, height = 1920, 1080
+
+        # Build grounding prompt exactly like the baseline
+        grounding_prompt = (
+            f"Please provide the bounding box coordinate of the UI element this user instruction describes: <ref>{instruction}</ref>. "
+            f"Answer in the format of [[x1, y1, x2, y2]]"
+        )
+
+        # Prepare messages for LiteLLM
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "image_url",
+                        "image_url": {"url": f"data:image/png;base64,{image_b64}"},
+                    },
+                    {"type": "text", "text": grounding_prompt},
+                ],
+            }
+        ]
+
+        # Call acompletion; HuggingFaceLocalAdapter/model handler will handle InternVL loading
+        api_kwargs = {
+            "model": model,
+            "messages": messages,
+            # Conservative generation params akin to baseline (deterministic)
+            "max_tokens": kwargs.get("max_tokens", 256),
+            "temperature": kwargs.get("temperature", 0.0),
+        }
+
+        response = await litellm.acompletion(**api_kwargs)
+        output_text = (response.choices[0].message.content or "").strip()  # type: ignore
+
+        # print(f"InternVL output: {output_text}")
+
+        # Try to parse a point first; if absent, parse bbox and take center
+        point = _extract_first_point(output_text)
+        if point is None:
+            bbox = _extract_last_bbox(output_text)
+            if bbox is None:
+                return None
+            x1, y1, x2, y2 = bbox
+            cx = (x1 + x2) / 2.0
+            cy = (y1 + y2) / 2.0
+            point = (cx, cy)
+
+        x_norm, y_norm = point
+        x_px, y_px = _scale_norm_to_pixels(x_norm, y_norm, width, height)
+        return (x_px, y_px)
+
+    def get_capabilities(self) -> List[AgentCapability]:
+        return ["click", "step"]
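
The 0-1000 normalized output convention is the part most likely to trip up integrators, so here is a small standalone sketch of the parse-then-scale path used above (the model answer and screenshot size are made up for illustration):

import re

_NUM = r"(\d+(?:\.\d+)?)"
_BBOX = re.compile(r"\[\[\s*" + _NUM + r"\s*,\s*" + _NUM + r"\s*,\s*" + _NUM + r"\s*,\s*" + _NUM + r"\s*\]\]")

def bbox_center_to_pixels(answer: str, width: int, height: int):
    # Take the last [[x1,y1,x2,y2]] box, use its center, and scale 0-1000 -> pixels
    matches = list(_BBOX.finditer(answer))
    if not matches:
        return None
    x1, y1, x2, y2 = (float(g) for g in matches[-1].groups())
    cx, cy = (x1 + x2) / 2.0, (y1 + y2) / 2.0
    return int(cx / 1000.0 * width), int(cy / 1000.0 * height)

# Hypothetical answer for a 1920x1080 screenshot:
print(bbox_center_to_pixels("The element is at [[100, 200, 300, 400]].", 1920, 1080))
# -> (384, 324)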