cua-agent 0.4.6__py3-none-any.whl → 0.4.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of cua-agent might be problematic.
- agent/__init__.py +2 -2
- agent/adapters/huggingfacelocal_adapter.py +8 -5
- agent/agent.py +85 -15
- agent/cli.py +9 -3
- agent/computer_handler.py +3 -1
- agent/decorators.py +28 -66
- agent/loops/__init__.py +3 -1
- agent/loops/anthropic.py +200 -84
- agent/loops/base.py +76 -0
- agent/loops/composed_grounded.py +318 -0
- agent/loops/gta1.py +178 -0
- agent/loops/model_types.csv +6 -0
- agent/loops/omniparser.py +178 -84
- agent/loops/openai.py +198 -58
- agent/loops/uitars.py +305 -178
- agent/responses.py +477 -1
- agent/types.py +7 -5
- agent/ui/gradio/app.py +14 -7
- agent/ui/gradio/ui_components.py +18 -1
- {cua_agent-0.4.6.dist-info → cua_agent-0.4.8.dist-info}/METADATA +3 -3
- cua_agent-0.4.8.dist-info/RECORD +37 -0
- cua_agent-0.4.6.dist-info/RECORD +0 -33
- {cua_agent-0.4.6.dist-info → cua_agent-0.4.8.dist-info}/WHEEL +0 -0
- {cua_agent-0.4.6.dist-info → cua_agent-0.4.8.dist-info}/entry_points.txt +0 -0
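The headline changes in this release are two new agent loops: agent/loops/composed_grounded.py, which composes a grounding model (element detection) with a thinking model (reasoning), and agent/loops/gta1.py, which adds GTA1-based click prediction. A minimal usage sketch follows; the ComputerAgent entry point and its constructor signature are assumptions about the package's public API and are not part of this diff. Only the "grounding_model+thinking_model" string format is confirmed by the code below.

# Hypothetical usage sketch; ComputerAgent and its kwargs are assumed.
# The composed "grounding+thinking" model string is what this release adds.
from agent import ComputerAgent

agent = ComputerAgent(
    # Grounding model (left of '+') resolves element descriptions to (x, y);
    # thinking model (right of '+') plans actions and emits tool calls.
    model="huggingface-local/HelloKKMe/GTA1-7B+gemini/gemini-1.5-pro",
)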
agent/loops/composed_grounded.py
ADDED
@@ -0,0 +1,318 @@
+"""
+Composed-grounded agent loop implementation that combines grounding and thinking models.
+Uses a two-stage approach: grounding model for element detection, thinking model for reasoning.
+"""
+
+import uuid
+import asyncio
+import json
+import base64
+from typing import Dict, List, Any, Optional, Tuple
+from io import BytesIO
+from PIL import Image
+import litellm
+
+from ..decorators import register_agent
+from ..types import Messages, AgentResponse, Tools, AgentCapability
+from ..loops.base import AsyncAgentConfig
+from ..responses import (
+    convert_computer_calls_xy2desc,
+    convert_responses_items_to_completion_messages,
+    convert_completion_messages_to_responses_items,
+    convert_computer_calls_desc2xy,
+    get_all_element_descriptions
+)
+from ..agent import find_agent_config
+
+GROUNDED_COMPUTER_TOOL_SCHEMA = {
+    "type": "function",
+    "function": {
+        "name": "computer",
+        "description": "Control a computer by taking screenshots and interacting with UI elements. This tool uses element descriptions to locate and interact with UI elements on the screen (e.g., 'red submit button', 'search text field', 'hamburger menu icon', 'close button in top right corner').",
+        "parameters": {
+            "type": "object",
+            "properties": {
+                "action": {
+                    "type": "string",
+                    "enum": [
+                        "screenshot",
+                        "click",
+                        "double_click",
+                        "drag",
+                        "type",
+                        "keypress",
+                        "scroll",
+                        "move",
+                        "wait",
+                        "get_current_url",
+                        "get_dimensions",
+                        "get_environment"
+                    ],
+                    "description": "The action to perform"
+                },
+                "element_description": {
+                    "type": "string",
+                    "description": "Description of the element to interact with (required for click, double_click, move, scroll actions, and as start/end for drag)"
+                },
+                "start_element_description": {
+                    "type": "string",
+                    "description": "Description of the element to start dragging from (required for drag action)"
+                },
+                "end_element_description": {
+                    "type": "string",
+                    "description": "Description of the element to drag to (required for drag action)"
+                },
+                "text": {
+                    "type": "string",
+                    "description": "The text to type (required for type action)"
+                },
+                "keys": {
+                    "type": "string",
+                    "description": "Key combination to press (required for keypress action). Single key for individual key press, multiple keys for combinations (e.g., 'ctrl+c')"
+                },
+                "button": {
+                    "type": "string",
+                    "description": "The mouse button to use for click action (left, right, wheel, back, forward) Default: left",
+                },
+                "scroll_x": {
+                    "type": "integer",
+                    "description": "Horizontal scroll amount for scroll action (positive for right, negative for left)",
+                },
+                "scroll_y": {
+                    "type": "integer",
+                    "description": "Vertical scroll amount for scroll action (positive for down, negative for up)",
+                },
+            },
+            "required": [
+                "action"
+            ]
+        }
+    }
+}
+
+def _prepare_tools_for_grounded(tool_schemas: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+    """Prepare tools for grounded API format"""
+    grounded_tools = []
+
+    for schema in tool_schemas:
+        if schema["type"] == "computer":
+            grounded_tools.append(GROUNDED_COMPUTER_TOOL_SCHEMA)
+        else:
+            grounded_tools.append(schema)
+
+    return grounded_tools
+
+def get_last_computer_call_image(messages: List[Dict[str, Any]]) -> Optional[str]:
+    """Get the last computer call output image from messages."""
+    for message in reversed(messages):
+        if (isinstance(message, dict) and
+            message.get("type") == "computer_call_output" and
+            isinstance(message.get("output"), dict) and
+            message["output"].get("type") == "input_image"):
+            image_url = message["output"].get("image_url", "")
+            if image_url.startswith("data:image/png;base64,"):
+                return image_url.split(",", 1)[1]
+    return None
+
+
+@register_agent(r".*\+.*", priority=1)
+class ComposedGroundedConfig:
+    """
+    Composed-grounded agent configuration that uses both grounding and thinking models.
+
+    The model parameter should be in format: "grounding_model+thinking_model"
+    e.g., "huggingface-local/HelloKKMe/GTA1-7B+gemini/gemini-1.5-pro"
+    """
+
+    def __init__(self):
+        self.desc2xy: Dict[str, Tuple[float, float]] = {}
+
+    async def predict_step(
+        self,
+        messages: List[Dict[str, Any]],
+        model: str,
+        tools: Optional[List[Dict[str, Any]]] = None,
+        max_retries: Optional[int] = None,
+        stream: bool = False,
+        computer_handler=None,
+        use_prompt_caching: Optional[bool] = False,
+        _on_api_start=None,
+        _on_api_end=None,
+        _on_usage=None,
+        _on_screenshot=None,
+        **kwargs
+    ) -> Dict[str, Any]:
+        """
+        Composed-grounded predict step implementation.
+
+        Process:
+        0. Store last computer call image, if none then take a screenshot
+        1. Convert computer calls from xy to descriptions
+        2. Convert responses items to completion messages
+        3. Call thinking model with litellm.acompletion
+        4. Convert completion messages to responses items
+        5. Get all element descriptions and populate desc2xy mapping
+        6. Convert computer calls from descriptions back to xy coordinates
+        7. Return output and usage
+        """
+        # Parse the composed model
+        if "+" not in model:
+            raise ValueError(f"Composed model must be in format 'grounding_model+thinking_model', got: {model}")
+        grounding_model, thinking_model = model.split("+", 1)
+
+        pre_output_items = []
+
+        # Step 0: Store last computer call image, if none then take a screenshot
+        last_image_b64 = get_last_computer_call_image(messages)
+        if last_image_b64 is None:
+            # Take a screenshot
+            screenshot_b64 = await computer_handler.screenshot()  # type: ignore
+            if screenshot_b64:
+
+                call_id = uuid.uuid4().hex
+                pre_output_items += [
+                    {
+                        "type": "message",
+                        "role": "assistant",
+                        "content": [
+                            {
+                                "type": "output_text",
+                                "text": "Taking a screenshot to see the current computer screen."
+                            }
+                        ]
+                    },
+                    {
+                        "action": {
+                            "type": "screenshot"
+                        },
+                        "call_id": call_id,
+                        "status": "completed",
+                        "type": "computer_call"
+                    },
+                    {
+                        "type": "computer_call_output",
+                        "call_id": call_id,
+                        "output": {
+                            "type": "input_image",
+                            "image_url": f"data:image/png;base64,{screenshot_b64}"
+                        }
+                    },
+                ]
+                last_image_b64 = screenshot_b64
+
+                # Call screenshot callback if provided
+                if _on_screenshot:
+                    await _on_screenshot(screenshot_b64)
+
+        tool_schemas = _prepare_tools_for_grounded(tools)  # type: ignore
+
+        # Step 1: Convert computer calls from xy to descriptions
+        input_messages = messages + pre_output_items
+        messages_with_descriptions = convert_computer_calls_xy2desc(input_messages, self.desc2xy)
+
+        # Step 2: Convert responses items to completion messages
+        completion_messages = convert_responses_items_to_completion_messages(
+            messages_with_descriptions,
+            allow_images_in_tool_results=False
+        )
+
+        # Step 3: Call thinking model with litellm.acompletion
+        api_kwargs = {
+            "model": thinking_model,
+            "messages": completion_messages,
+            "tools": tool_schemas,
+            "max_retries": max_retries,
+            "stream": stream,
+            **kwargs
+        }
+
+        if use_prompt_caching:
+            api_kwargs["use_prompt_caching"] = use_prompt_caching
+
+        # Call API start hook
+        if _on_api_start:
+            await _on_api_start(api_kwargs)
+
+        # Make the completion call
+        response = await litellm.acompletion(**api_kwargs)
+
+        # Call API end hook
+        if _on_api_end:
+            await _on_api_end(api_kwargs, response)
+
+        # Extract usage information
+        usage = {
+            **response.usage.model_dump(),  # type: ignore
+            "response_cost": response._hidden_params.get("response_cost", 0.0),
+        }
+        if _on_usage:
+            await _on_usage(usage)
+
+        # Step 4: Convert completion messages back to responses items format
+        response_dict = response.model_dump()  # type: ignore
+        choice_messages = [choice["message"] for choice in response_dict["choices"]]
+        thinking_output_items = []
+
+        for choice_message in choice_messages:
+            thinking_output_items.extend(convert_completion_messages_to_responses_items([choice_message]))
+
+        # Step 5: Get all element descriptions and populate desc2xy mapping
+        element_descriptions = get_all_element_descriptions(thinking_output_items)
+
+        if element_descriptions and last_image_b64:
+            # Use grounding model to predict coordinates for each description
+            grounding_agent_conf = find_agent_config(grounding_model)
+            if grounding_agent_conf:
+                grounding_agent = grounding_agent_conf.agent_class()
+
+                for desc in element_descriptions:
+                    coords = await grounding_agent.predict_click(
+                        model=grounding_model,
+                        image_b64=last_image_b64,
+                        instruction=desc
+                    )
+                    if coords:
+                        self.desc2xy[desc] = coords
+
+        # Step 6: Convert computer calls from descriptions back to xy coordinates
+        final_output_items = convert_computer_calls_desc2xy(thinking_output_items, self.desc2xy)
+
+        # Step 7: Return output and usage
+        return {
+            "output": pre_output_items + final_output_items,
+            "usage": usage
+        }
+
+    async def predict_click(
+        self,
+        model: str,
+        image_b64: str,
+        instruction: str,
+        **kwargs
+    ) -> Optional[Tuple[int, int]]:
+        """
+        Predict click coordinates using the grounding model.
+
+        For composed models, uses only the grounding model part for click prediction.
+        """
+        # Parse the composed model to get grounding model
+        if "+" not in model:
+            raise ValueError(f"Composed model must be in format 'grounding_model+thinking_model', got: {model}")
+        grounding_model, thinking_model = model.split("+", 1)
+
+        # Find and use the grounding agent
+        grounding_agent_conf = find_agent_config(grounding_model)
+        if grounding_agent_conf:
+            grounding_agent = grounding_agent_conf.agent_class()
+            return await grounding_agent.predict_click(
+                model=grounding_model,
+                image_b64=image_b64,
+                instruction=instruction,
+                **kwargs
+            )
+
+        return None
+
+    def get_capabilities(self) -> List[AgentCapability]:
+        """Return the capabilities supported by this agent."""
+        return ["click", "step"]
agent/loops/gta1.py
ADDED
@@ -0,0 +1,178 @@
+"""
+GTA1 agent loop implementation for click prediction using litellm.acompletion
+Paper: https://arxiv.org/pdf/2507.05791
+Code: https://github.com/Yan98/GTA1
+"""
+
+import asyncio
+import json
+import re
+import base64
+from typing import Dict, List, Any, AsyncGenerator, Union, Optional, Tuple
+from io import BytesIO
+import uuid
+from PIL import Image
+import litellm
+import math
+
+from ..decorators import register_agent
+from ..types import Messages, AgentResponse, Tools, AgentCapability
+from ..loops.base import AsyncAgentConfig
+
+SYSTEM_PROMPT = '''
+You are an expert UI element locator. Given a GUI image and a user's element description, provide the coordinates of the specified element as a single (x,y) point. The image resolution is height {height} and width {width}. For elements with area, return the center point.
+
+Output the coordinate pair exactly:
+(x,y)
+'''.strip()
+
+def extract_coordinates(raw_string: str) -> Tuple[float, float]:
+    """Extract coordinates from model output."""
+    try:
+        matches = re.findall(r"\((-?\d*\.?\d+),\s*(-?\d*\.?\d+)\)", raw_string)
+        return tuple(map(float, matches[0]))  # type: ignore
+    except:
+        return (0.0, 0.0)
+
+def smart_resize(height: int, width: int, factor: int = 28, min_pixels: int = 3136, max_pixels: int = 8847360) -> Tuple[int, int]:
+    """Smart resize function similar to qwen_vl_utils."""
+    # Calculate the total pixels
+    total_pixels = height * width
+
+    # If already within bounds, return original dimensions
+    if min_pixels <= total_pixels <= max_pixels:
+        # Round to nearest factor
+        new_height = (height // factor) * factor
+        new_width = (width // factor) * factor
+        return new_height, new_width
+
+    # Calculate scaling factor
+    if total_pixels > max_pixels:
+        scale = (max_pixels / total_pixels) ** 0.5
+    else:
+        scale = (min_pixels / total_pixels) ** 0.5
+
+    # Apply scaling
+    new_height = int(height * scale)
+    new_width = int(width * scale)
+
+    # Round to nearest factor
+    new_height = (new_height // factor) * factor
+    new_width = (new_width // factor) * factor
+
+    # Ensure minimum size
+    new_height = max(new_height, factor)
+    new_width = max(new_width, factor)
+
+    return new_height, new_width
+
+@register_agent(models=r".*GTA1.*")
+class GTA1Config(AsyncAgentConfig):
+    """GTA1 agent configuration implementing AsyncAgentConfig protocol for click prediction."""
+
+    def __init__(self):
+        self.current_model = None
+        self.last_screenshot_b64 = None
+
+
+    async def predict_step(
+        self,
+        messages: List[Dict[str, Any]],
+        model: str,
+        tools: Optional[List[Dict[str, Any]]] = None,
+        max_retries: Optional[int] = None,
+        stream: bool = False,
+        computer_handler=None,
+        _on_api_start=None,
+        _on_api_end=None,
+        _on_usage=None,
+        _on_screenshot=None,
+        **kwargs
+    ) -> Dict[str, Any]:
+        raise NotImplementedError()
+
+    async def predict_click(
+        self,
+        model: str,
+        image_b64: str,
+        instruction: str,
+        **kwargs
+    ) -> Optional[Tuple[float, float]]:
+        """
+        Predict click coordinates using GTA1 model via litellm.acompletion.
+
+        Args:
+            model: The GTA1 model name
+            image_b64: Base64 encoded image
+            instruction: Instruction for where to click
+
+        Returns:
+            Tuple of (x, y) coordinates or None if prediction fails
+        """
+        # Decode base64 image
+        image_data = base64.b64decode(image_b64)
+        image = Image.open(BytesIO(image_data))
+        width, height = image.width, image.height
+
+        # Smart resize the image (similar to qwen_vl_utils)
+        resized_height, resized_width = smart_resize(
+            height, width,
+            factor=28,  # Default factor for Qwen models
+            min_pixels=3136,
+            max_pixels=4096 * 2160
+        )
+        resized_image = image.resize((resized_width, resized_height))
+        scale_x, scale_y = width / resized_width, height / resized_height
+
+        # Convert resized image back to base64
+        buffered = BytesIO()
+        resized_image.save(buffered, format="PNG")
+        resized_image_b64 = base64.b64encode(buffered.getvalue()).decode()
+
+        # Prepare system and user messages
+        system_message = {
+            "role": "system",
+            "content": SYSTEM_PROMPT.format(height=resized_height, width=resized_width)
+        }
+
+        user_message = {
+            "role": "user",
+            "content": [
+                {
+                    "type": "image_url",
+                    "image_url": {
+                        "url": f"data:image/png;base64,{resized_image_b64}"
+                    }
+                },
+                {
+                    "type": "text",
+                    "text": instruction
+                }
+            ]
+        }
+
+        # Prepare API call kwargs
+        api_kwargs = {
+            "model": model,
+            "messages": [system_message, user_message],
+            "max_tokens": 32,
+            "temperature": 0.0,
+            **kwargs
+        }
+
+        # Use liteLLM acompletion
+        response = await litellm.acompletion(**api_kwargs)
+
+        # Extract response text
+        output_text = response.choices[0].message.content  # type: ignore
+
+        # Extract and rescale coordinates
+        pred_x, pred_y = extract_coordinates(output_text)  # type: ignore
+        pred_x *= scale_x
+        pred_y *= scale_y
+
+        return (math.floor(pred_x), math.floor(pred_y))
+
+    def get_capabilities(self) -> List[AgentCapability]:
+        """Return the capabilities supported by this agent."""
+        return ["click"]
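To make the coordinate round trip concrete: predict_click shrinks the screenshot so both sides are multiples of 28 and the pixel count stays at or under 4096*2160, asks the model for a point in resized-image space, then scales that point back to raw pixels. A small worked sketch using the helpers above; the 5120x2880 input and the predicted point are arbitrary illustrative values:

# Worked example of the resize-then-rescale logic in predict_click.
from math import floor

from agent.loops.gta1 import smart_resize

height, width = 2880, 5120  # raw screenshot larger than the pixel cap
resized_height, resized_width = smart_resize(
    height, width, factor=28, min_pixels=3136, max_pixels=4096 * 2160
)
scale_x = width / resized_width    # > 1.0 when the image was shrunk
scale_y = height / resized_height

# Suppose the model answers "(1000,500)" in resized-image space:
pred_x, pred_y = 1000.0, 500.0
print(floor(pred_x * scale_x), floor(pred_y * scale_y))  # back in raw pixels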