cua-agent 0.4.22__py3-none-any.whl → 0.7.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of cua-agent might be problematic. Click here for more details.
- agent/__init__.py +4 -10
- agent/__main__.py +2 -1
- agent/adapters/__init__.py +4 -0
- agent/adapters/azure_ml_adapter.py +283 -0
- agent/adapters/cua_adapter.py +161 -0
- agent/adapters/huggingfacelocal_adapter.py +67 -125
- agent/adapters/human_adapter.py +116 -114
- agent/adapters/mlxvlm_adapter.py +110 -99
- agent/adapters/models/__init__.py +41 -0
- agent/adapters/models/generic.py +78 -0
- agent/adapters/models/internvl.py +290 -0
- agent/adapters/models/opencua.py +115 -0
- agent/adapters/models/qwen2_5_vl.py +78 -0
- agent/agent.py +337 -185
- agent/callbacks/__init__.py +9 -4
- agent/callbacks/base.py +45 -31
- agent/callbacks/budget_manager.py +22 -10
- agent/callbacks/image_retention.py +54 -98
- agent/callbacks/logging.py +55 -42
- agent/callbacks/operator_validator.py +35 -33
- agent/callbacks/otel.py +291 -0
- agent/callbacks/pii_anonymization.py +19 -16
- agent/callbacks/prompt_instructions.py +47 -0
- agent/callbacks/telemetry.py +99 -61
- agent/callbacks/trajectory_saver.py +95 -69
- agent/cli.py +269 -119
- agent/computers/__init__.py +14 -9
- agent/computers/base.py +32 -19
- agent/computers/cua.py +52 -25
- agent/computers/custom.py +78 -71
- agent/decorators.py +23 -14
- agent/human_tool/__init__.py +2 -7
- agent/human_tool/__main__.py +6 -2
- agent/human_tool/server.py +48 -37
- agent/human_tool/ui.py +359 -235
- agent/integrations/hud/__init__.py +38 -99
- agent/integrations/hud/agent.py +369 -0
- agent/integrations/hud/proxy.py +166 -52
- agent/loops/__init__.py +44 -14
- agent/loops/anthropic.py +579 -492
- agent/loops/base.py +19 -15
- agent/loops/composed_grounded.py +136 -150
- agent/loops/fara/__init__.py +8 -0
- agent/loops/fara/config.py +506 -0
- agent/loops/fara/helpers.py +357 -0
- agent/loops/fara/schema.py +143 -0
- agent/loops/gelato.py +183 -0
- agent/loops/gemini.py +935 -0
- agent/loops/generic_vlm.py +601 -0
- agent/loops/glm45v.py +140 -135
- agent/loops/gta1.py +48 -51
- agent/loops/holo.py +218 -0
- agent/loops/internvl.py +180 -0
- agent/loops/moondream3.py +493 -0
- agent/loops/omniparser.py +326 -226
- agent/loops/openai.py +50 -51
- agent/loops/opencua.py +134 -0
- agent/loops/uiins.py +175 -0
- agent/loops/uitars.py +247 -206
- agent/loops/uitars2.py +951 -0
- agent/playground/__init__.py +5 -0
- agent/playground/server.py +301 -0
- agent/proxy/examples.py +61 -57
- agent/proxy/handlers.py +46 -39
- agent/responses.py +447 -347
- agent/tools/__init__.py +24 -0
- agent/tools/base.py +253 -0
- agent/tools/browser_tool.py +423 -0
- agent/types.py +11 -5
- agent/ui/__init__.py +1 -1
- agent/ui/__main__.py +1 -1
- agent/ui/gradio/app.py +25 -22
- agent/ui/gradio/ui_components.py +314 -167
- cua_agent-0.7.16.dist-info/METADATA +85 -0
- cua_agent-0.7.16.dist-info/RECORD +79 -0
- {cua_agent-0.4.22.dist-info → cua_agent-0.7.16.dist-info}/WHEEL +1 -1
- cua_agent-0.4.22.dist-info/METADATA +0 -436
- cua_agent-0.4.22.dist-info/RECORD +0 -51
- {cua_agent-0.4.22.dist-info → cua_agent-0.7.16.dist-info}/entry_points.txt +0 -0
|
@@ -0,0 +1,357 @@
|
|
|
1
|
+
# Source: https://github.com/QwenLM/Qwen-Agent/blob/main/qwen_agent/llm/fncall_prompts/nous_fncall_prompt.py
|
|
2
|
+
|
|
3
|
+
import copy
|
|
4
|
+
import json
|
|
5
|
+
import os
|
|
6
|
+
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
|
|
7
|
+
|
|
8
|
+
from .schema import ContentItem, Message
|
|
9
|
+
|
|
10
|
+
FN_CALL_TEMPLATE_QWEN = """# Tools
|
|
11
|
+
|
|
12
|
+
You may call one or more functions to assist with the user query.
|
|
13
|
+
|
|
14
|
+
You are provided with function signatures within <tools></tools> XML tags:
|
|
15
|
+
<tools>
|
|
16
|
+
{tool_descs}
|
|
17
|
+
</tools>
|
|
18
|
+
|
|
19
|
+
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
|
|
20
|
+
<tool_call>
|
|
21
|
+
{{"name": <function-name>, "arguments": <args-json-object>}}
|
|
22
|
+
</tool_call>"""
|
|
23
|
+
|
|
24
|
+
FN_CALL_TEMPLATE = """You are a web automation agent that performs actions on websites to fulfill user requests by calling various tools.
|
|
25
|
+
* You should stop execution at Critical Points. A Critical Point would be encountered in tasks like 'Checkout', 'Book', 'Purchase', 'Call', 'Email', 'Order', etc where a binding transaction/agreement would require the user's permission/personal or sensitive information (name, email, credit card, address, payment information, resume, etc) in order to complete a transaction (purchase, reservation, sign-up etc), or to communicate in a way that a human would be expected to do (call, email, apply to a job, etc).
|
|
26
|
+
* Solve the task as far as you can up until a Critical Point:
|
|
27
|
+
- For example, if the task is to "call a restaurant to make a reservation", you should not actually make the call but should navigate to the restaurant's page and find the phone number.
|
|
28
|
+
- Similarly, if the task is to "order new size 12 running shoes" you should not actually place the order but should instead search for the right shoes that meet the criteria and add them to the cart.
|
|
29
|
+
- Some tasks, like answering questions, may not encounter a Critical Point at all.
|
|
30
|
+
|
|
31
|
+
You are provided with function signatures within <tools></tools> XML tags:
|
|
32
|
+
<tools>
|
|
33
|
+
{tool_descs}
|
|
34
|
+
</tools>
|
|
35
|
+
|
|
36
|
+
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
|
|
37
|
+
<tool_call>
|
|
38
|
+
{{"name": <function-name>, "arguments": <args-json-object>}}
|
|
39
|
+
</tool_call>"""
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
# Opt-in via environment variable: when true, code-interpreter tool calls carry
# their code in a separate <code> block instead of inline JSON arguments.
SPECIAL_CODE_MODE = os.getenv("SPECIAL_CODE_MODE", "false").lower() == "true"
# Substring that identifies a code-interpreter tool by name.
CODE_TOOL_PATTERN = "code_interpreter"
|
|
44
|
+
FN_CALL_TEMPLATE_WITH_CI = """# Tools
|
|
45
|
+
|
|
46
|
+
You may call one or more functions to assist with the user query.
|
|
47
|
+
|
|
48
|
+
You are provided with function signatures within <tools></tools> XML tags:
|
|
49
|
+
<tools>
|
|
50
|
+
{tool_descs}
|
|
51
|
+
</tools>
|
|
52
|
+
|
|
53
|
+
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
|
|
54
|
+
<tool_call>
|
|
55
|
+
{{"name": <function-name>, "arguments": <args-json-object>}}
|
|
56
|
+
</tool_call>
|
|
57
|
+
For code parameters, use placeholders first, and then put the code within <code></code> XML tags, such as:
|
|
58
|
+
<tool_call>
|
|
59
|
+
{{"name": <function-name>, "arguments": {{"code": ""}}}}
|
|
60
|
+
<code>
|
|
61
|
+
Here is the code.
|
|
62
|
+
</code>
|
|
63
|
+
</tool_call>"""
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
class NousFnCallPrompt:
    """Render function-calling messages into the Nous/Hermes plaintext protocol.

    Converts structured `function_call` / function-result messages into plain
    text wrapped in <tool_call>/<tool_response> tags, and injects a system
    prompt advertising the available tool schemas.
    """

    def __init__(self, template_name: str = "default"):
        """Initialize NousFnCallPrompt with a specific template.

        Args:
            template_name: Name of the template to use. Options:
                "default", "qwen", "with_ci"

        Raises:
            ValueError: if template_name is not one of the known templates.
        """
        self.template_name = template_name
        self.template_map = {
            "default": FN_CALL_TEMPLATE,
            "qwen": FN_CALL_TEMPLATE_QWEN,
            "with_ci": FN_CALL_TEMPLATE_WITH_CI,
        }

        if template_name not in self.template_map:
            raise ValueError(
                f"Unknown template_name: {template_name}. "
                f"Available options: {list(self.template_map.keys())}"
            )

    def preprocess_fncall_messages(
        self,
        messages: List[Message],
        functions: List[dict],
        lang: Literal["en", "zh"],
        parallel_function_calls: bool = True,
        function_choice: Union[Literal["auto"], str] = "auto",
    ) -> List[Message]:
        """Flatten function-call messages to plaintext and prepend a tool system prompt.

        Assistant `function_call`s become <tool_call> text appended to the
        previous assistant turn (or a new one); "function" results become
        <tool_response> text merged into the adjacent user turn. The input
        list is not mutated (a deepcopy is taken).

        NOTE(review): assumes the first message is system/user — an input that
        starts with an assistant or function message would hit messages[-1]
        on an empty list and raise IndexError. TODO confirm callers guarantee
        this ordering.
        """
        del lang  # ignored
        del parallel_function_calls  # ignored
        if function_choice != "auto":
            raise NotImplementedError

        ori_messages = messages

        # Change function_call responses to plaintext responses:
        messages = []
        for msg in copy.deepcopy(ori_messages):
            role, content, reasoning_content = (
                msg.role,
                msg.content,
                msg.reasoning_content,
            )
            if role in ("system", "user"):
                messages.append(msg)
            elif role == "assistant":
                content = content or []
                fn_call = msg.function_call
                if fn_call:
                    if (not SPECIAL_CODE_MODE) or (CODE_TOOL_PATTERN not in fn_call.name):
                        # Ordinary tool call: arguments inline as JSON.
                        fc = {
                            "name": fn_call.name,
                            "arguments": json.loads(fn_call.arguments),
                        }
                        fc = json.dumps(fc, ensure_ascii=False)
                        fc = f"<tool_call>\n{fc}\n</tool_call>"
                    else:
                        # Code-interpreter call: hoist the code into a <code>
                        # block and leave an empty placeholder in the JSON.
                        para = json.loads(fn_call.arguments)
                        code = para["code"]
                        para["code"] = ""
                        fc = {"name": fn_call.name, "arguments": para}
                        fc = json.dumps(fc, ensure_ascii=False)
                        fc = f"<tool_call>\n{fc}\n<code>\n{code}\n</code>\n</tool_call>"

                    content.append(ContentItem(text=fc))
                if messages[-1].role == "assistant":
                    # Merge consecutive assistant turns into one.
                    messages[-1].content.append(ContentItem(text="\n"))
                    messages[-1].content.extend(content)
                else:
                    # TODO: Assuming there will only be one continuous reasoning_content here
                    messages.append(
                        Message(
                            role=role,
                            content=content,
                            reasoning_content=reasoning_content,
                        )
                    )
            elif role == "function":
                # Tool results must be a single text item; rendered as a
                # <tool_response> inside a user turn.
                assert isinstance(content, list)
                assert len(content) == 1
                assert content[0].text
                fc = f"<tool_response>\n{content[0].text}\n</tool_response>"
                content = [ContentItem(text=fc)]
                if messages[-1].role == "user":
                    messages[-1].content.append(ContentItem(text="\n"))
                    messages[-1].content.extend(content)
                else:
                    messages.append(Message(role="user", content=content))
            else:
                raise TypeError

        tool_descs = [{"type": "function", "function": f} for f in functions]
        tool_names = [
            function.get("name_for_model", function.get("name", "")) for function in functions
        ]
        tool_descs = "\n".join([json.dumps(f, ensure_ascii=False) for f in tool_descs])

        # Select template based on configuration
        if SPECIAL_CODE_MODE and any([CODE_TOOL_PATTERN in x for x in tool_names]):
            selected_template = FN_CALL_TEMPLATE_WITH_CI
        else:
            selected_template = self.template_map[self.template_name]

        tool_system = selected_template.format(tool_descs=tool_descs)
        if messages[0].role == "system":
            # Append the tool prompt to an existing system message...
            messages[0].content.append(ContentItem(text="\n\n" + tool_system))
        else:
            # ...or synthesize a new system message at the front.
            messages = [Message(role="system", content=[ContentItem(text=tool_system)])] + messages
        return messages
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
# Mainly for removing incomplete special tokens when streaming the output
# This assumes that '<tool_call>\n{"name": "' is the special token for the NousFnCallPrompt
def remove_incomplete_special_tokens(text: str) -> str:
    """Blank out *text* when it is a substring of the tool-call opener.

    While streaming, a chunk that is only a prefix of the special token
    should not be surfaced to the caller; anything else passes through.
    """
    opener = '<tool_call>\n{"name": "'
    return "" if text in opener else text
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
def extract_fn(text: str):
    """Best-effort extraction of a function name and partial arguments from streamed text.

    Returns a (fn_name, fn_args) pair of strings; either may be "" when the
    corresponding marker has not appeared yet. The last 5 characters of the
    arguments tail are dropped — presumably to trim a partially streamed
    closing tag; TODO confirm against the streaming caller.
    """
    name_marker = '"name": "'
    name_end_marker = '", "'
    args_marker = '"arguments": '

    fn_name = ""
    fn_args = ""

    name_pos = text.find(name_marker)
    args_pos = text.find(args_marker)

    if name_pos > 0:
        tail = text[name_pos + len(name_marker) :]
        end_pos = tail.find(name_end_marker)
        if end_pos > -1:
            fn_name = tail[:end_pos]

    if args_pos > 0:
        fn_args = text[args_pos + len(args_marker) :]

    fn_args = fn_args[:-5] if len(fn_args) > 5 else ""
    return fn_name, fn_args
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
def build_nous_system(functions: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
    """Use original FARA NousFnCallPrompt to generate a system message embedding tool schema."""
    from .schema import ContentItem as NousContentItem
    from .schema import Message as NousMessage

    seed_system = NousMessage(
        role="system", content=[NousContentItem(text="You are a helpful assistant.")]
    )
    processed = NousFnCallPrompt().preprocess_fncall_messages(
        messages=[seed_system],
        functions=functions,
        lang="en",
    )
    system_dict = processed[0].model_dump()
    # Convert structured content to OpenAI-style content list
    openai_content = [
        {"type": "text", "text": item["text"]} for item in system_dict.get("content", [])
    ]
    return {"role": "system", "content": openai_content}
|
|
226
|
+
|
|
227
|
+
|
|
228
|
+
def parse_tool_call_from_text(text: str) -> Optional[Dict[str, Any]]:
    """Extract the JSON object inside <tool_call>...</tool_call> from model text.

    Scans for the first '{' after the opening tag and walks forward counting
    brace depth, so nested objects in the arguments are handled. Returns the
    parsed dict, or None when no tag/JSON is present, the braces never
    balance (truncated output), or the payload is not valid JSON.
    """
    tag_pos = text.find("<tool_call>")
    if tag_pos == -1:
        return None

    opening = text.find("{", tag_pos)
    if opening == -1:
        return None

    # Walk the text tracking brace depth until the opener is balanced.
    depth = 0
    closing = opening
    for idx in range(opening, len(text)):
        ch = text[idx]
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                closing = idx + 1
                break

    if depth != 0:
        # Ran off the end without balancing — truncated payload.
        return None

    try:
        return json.loads(text[opening:closing])
    except Exception:
        return None
|
|
264
|
+
|
|
265
|
+
|
|
266
|
+
async def unnormalize_coordinate(args: Dict[str, Any], dims: Tuple[int, int]) -> Dict[str, Any]:
    """Coordinates appear in 0..1000 space, scale to actual screen size using dims if provided."""
    raw = args.get("coordinate")
    if not raw or not isinstance(raw, (list, tuple)) or len(raw) < 2:
        # Nothing to rescale — hand the args back untouched.
        return args
    norm_x, norm_y = float(raw[0]), float(raw[1])
    screen_w, screen_h = float(dims[0]), float(dims[1])
    # Scale from the model's 0..1000 grid, clamped to the screen bounds.
    abs_x = min(screen_w, max(0.0, norm_x / 1000.0 * screen_w))
    abs_y = min(screen_h, max(0.0, norm_y / 1000.0 * screen_h))
    return {**args, "coordinate": [round(abs_x), round(abs_y)]}
|
|
277
|
+
|
|
278
|
+
|
|
279
|
+
def convert_qwen_tool_args_to_computer_action(args: Dict[str, Any]) -> Optional[Dict[str, Any]]:
    """
    Convert Qwen computer tool arguments to the Computer Calls action schema.

    Qwen (example):
        {"action": "left_click", "coordinate": [114, 68]}

    Target (example):
        {"action": "left_click", "x": 114, "y": 68}

    Other mappings:
    - right_click, middle_click, double_click (triple_click -> double_click)
    - mouse_move -> { action: "move", x, y }
    - key -> { action: "keypress", keys: [...] }
    - type -> { action: "type", text }
    - scroll/hscroll -> { action: "scroll", scroll_x, scroll_y, x, y }
    - wait -> { action: "wait" }
    - terminate/answer are not direct UI actions; return None for now
    """
    if not isinstance(args, dict):
        return None

    raw_action = args.get("action")
    if not isinstance(raw_action, str):
        return None

    # Parse the optional [x, y] coordinate, tolerating float-ish values.
    point = args.get("coordinate")
    px: Optional[int] = None
    py: Optional[int] = None
    if isinstance(point, (list, tuple)) and len(point) >= 2:
        try:
            px = int(round(float(point[0])))
            py = int(round(float(point[1])))
        except Exception:
            px = py = None

    name = raw_action.lower()

    if name in {"left_click", "right_click", "middle_click", "double_click"}:
        return None if px is None or py is None else {"action": name, "x": px, "y": py}

    if name == "triple_click":
        # Approximate as double_click
        return None if px is None or py is None else {"action": "double_click", "x": px, "y": py}

    if name == "mouse_move":
        return None if px is None or py is None else {"action": "move", "x": px, "y": py}

    if name == "key":
        keys = args.get("keys")
        if isinstance(keys, list) and all(isinstance(k, str) for k in keys):
            return {"action": "keypress", "keys": keys}
        return None

    if name == "type":
        text = args.get("text")
        return {"action": "type", "text": text} if isinstance(text, str) else None

    if name in {"scroll", "hscroll"}:
        try:
            amount = int(round(float(args.get("pixels") or 0)))
        except Exception:
            amount = 0
        result: Dict[str, Any] = {
            "action": "scroll",
            "scroll_x": amount if name == "hscroll" else 0,
            "scroll_y": amount if name == "scroll" else 0,
        }
        # Include cursor position if available (optional)
        if px is not None and py is not None:
            result["x"] = px
            result["y"] = py
        return result

    if name == "wait":
        return {"action": "wait"}

    # Non-UI or terminal actions: terminate/answer -> not mapped here
    return None
|
|
@@ -0,0 +1,143 @@
|
|
|
1
|
+
# Source: https://github.com/QwenLM/Qwen-Agent/blob/main/qwen_agent/llm/schema.py
|
|
2
|
+
|
|
3
|
+
from typing import List, Literal, Optional, Tuple, Union
|
|
4
|
+
|
|
5
|
+
from pydantic import BaseModel, field_validator, model_validator
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class BaseModelCompatibleDict(BaseModel):
    """Pydantic model that can also be read and written like a dict.

    Adds __getitem__/__setitem__/get so callers may treat instances as
    mappings, and defaults ``exclude_none=True`` when dumping so unset
    fields disappear from serialized output.
    """

    def __getitem__(self, item):
        # Dict-style read maps straight onto attribute access.
        return getattr(self, item)

    def __setitem__(self, key, value):
        # Dict-style write maps straight onto attribute assignment.
        setattr(self, key, value)

    def model_dump(self, **kwargs):
        # Drop None-valued fields by default (callers may override).
        if "exclude_none" not in kwargs:
            kwargs["exclude_none"] = True
        return super().model_dump(**kwargs)

    def model_dump_json(self, **kwargs):
        # Same None-dropping default for the JSON form.
        if "exclude_none" not in kwargs:
            kwargs["exclude_none"] = True
        return super().model_dump_json(**kwargs)

    def get(self, key, default=None):
        # dict.get-style lookup: missing attributes yield *default*.
        try:
            return getattr(self, key)
        except AttributeError:
            return default

    def __str__(self):
        return f"{self.model_dump()}"
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class FunctionCall(BaseModelCompatibleDict):
    """A tool/function invocation: name plus a JSON-encoded arguments string."""

    # Tool name as reported by the model.
    name: str
    # Raw JSON string; parsed with json.loads by consumers.
    arguments: str

    def __init__(self, name: str, arguments: str):
        super().__init__(name=name, arguments=arguments)

    def __repr__(self):
        return f"FunctionCall({self.model_dump()})"
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class ContentItem(BaseModelCompatibleDict):
    """One piece of multimodal message content.

    Exactly one of the modality fields may be populated; this is enforced
    by the ``check_exclusivity`` validator.
    """

    text: Optional[str] = None
    image: Optional[str] = None
    file: Optional[str] = None
    audio: Optional[Union[str, dict]] = None
    video: Optional[Union[str, list]] = None

    def __init__(
        self,
        text: Optional[str] = None,
        image: Optional[str] = None,
        file: Optional[str] = None,
        audio: Optional[Union[str, dict]] = None,
        video: Optional[Union[str, list]] = None,
    ):
        super().__init__(text=text, image=image, file=file, audio=audio, video=video)

    @model_validator(mode="after")
    def check_exclusivity(self):
        # NOTE: `text` counts when it is not None (so "" is valid text),
        # while the other fields count only when truthy.
        provided_fields = 0
        if self.text is not None:
            provided_fields += 1
        if self.image:
            provided_fields += 1
        if self.file:
            provided_fields += 1
        if self.audio:
            provided_fields += 1
        if self.video:
            provided_fields += 1

        if provided_fields != 1:
            raise ValueError(
                "Exactly one of 'text', 'image', 'file', 'audio', or 'video' must be provided."
            )
        return self

    def __repr__(self):
        return f"ContentItem({self.model_dump()})"

    def get_type_and_value(
        self,
    ) -> Tuple[Literal["text", "image", "file", "audio", "video"], str]:
        """Return (field_name, value) for the single populated field.

        Relies on model_dump()'s exclude_none default leaving exactly one
        entry in the dump.
        """
        ((t, v),) = self.model_dump().items()
        assert t in ("text", "image", "file", "audio", "video")
        return t, v

    @property
    def type(self) -> Literal["text", "image", "file", "audio", "video"]:
        # Which modality this item carries.
        t, _ = self.get_type_and_value()
        return t

    @property
    def value(self) -> str:
        # The payload of the populated modality field.
        _, v = self.get_type_and_value()
        return v
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
class Message(BaseModelCompatibleDict):
    """A chat message (role + content), optionally carrying a function call."""

    # One of: system, user, assistant, function (validated below).
    role: str
    content: Union[str, List[ContentItem]]
    reasoning_content: Optional[Union[str, List[ContentItem]]] = None
    name: Optional[str] = None
    function_call: Optional[FunctionCall] = None
    extra: Optional[dict] = None

    def __init__(
        self,
        role: str,
        content: Union[str, List[ContentItem]],
        reasoning_content: Optional[Union[str, List[ContentItem]]] = None,
        name: Optional[str] = None,
        function_call: Optional[FunctionCall] = None,
        extra: Optional[dict] = None,
        **kwargs,
    ):
        # None content is normalized to "" so downstream string handling is safe.
        if content is None:
            content = ""
        if reasoning_content is None:
            reasoning_content = ""
        # NOTE(review): **kwargs is accepted but dropped — unknown keys are
        # silently ignored rather than rejected; confirm this is intentional.
        super().__init__(
            role=role,
            content=content,
            reasoning_content=reasoning_content,
            name=name,
            function_call=function_call,
            extra=extra,
        )

    def __repr__(self):
        return f"Message({self.model_dump()})"

    @field_validator("role")
    def role_checker(cls, value: str) -> str:
        # Restrict to the four roles the fncall preprocessor understands.
        values = ["system", "user", "assistant", "function"]
        if value not in values:
            raise ValueError(f'{value} must be one of {",".join(values)}')
        return value
|
agent/loops/gelato.py
ADDED
|
@@ -0,0 +1,183 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Gelato agent loop implementation for click prediction using litellm.acompletion
|
|
3
|
+
Model: https://huggingface.co/mlfoundations/Gelato-30B-A3B
|
|
4
|
+
Code: https://github.com/mlfoundations/Gelato/tree/main
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import base64
|
|
8
|
+
import math
|
|
9
|
+
import re
|
|
10
|
+
from io import BytesIO
|
|
11
|
+
from typing import Any, Dict, List, Optional, Tuple
|
|
12
|
+
|
|
13
|
+
import litellm
|
|
14
|
+
from PIL import Image
|
|
15
|
+
|
|
16
|
+
from ..decorators import register_agent
|
|
17
|
+
from ..loops.base import AsyncAgentConfig
|
|
18
|
+
from ..types import AgentCapability
|
|
19
|
+
|
|
20
|
+
SYSTEM_PROMPT = """
|
|
21
|
+
You are an expert UI element locator. Given a GUI image and a user's element description, provide the coordinates of the specified element as a single (x,y) point. For elements with area, return the center point.
|
|
22
|
+
|
|
23
|
+
Output the coordinate pair exactly:
|
|
24
|
+
(x,y)
|
|
25
|
+
"""
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def extract_coordinates(raw_string):
    """
    Extract the first "(x, y)" coordinate pair from the model's raw output.

    Args:
        raw_string: str (e.g. "(100, 200)")

    Returns:
        Tuple of ints (x, y); falls back to (0, 0) when no pair can be parsed.
    """
    try:
        matches = re.findall(r"\((-?\d*\.?\d+),\s*(-?\d*\.?\d+)\)", raw_string)
        # Only the first match is used. int() raises on float-like captures
        # such as "10.5", which we treat the same as "no match".
        return [tuple(map(int, match)) for match in matches][0]
    except (ValueError, IndexError, TypeError):
        # Was a bare `except:` — narrowed so KeyboardInterrupt/SystemExit
        # are no longer swallowed. IndexError covers "no match at all",
        # ValueError covers non-integer captures, TypeError non-str input.
        return 0, 0
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def smart_resize(
    height: int,
    width: int,
    factor: int = 28,
    min_pixels: int = 3136,
    max_pixels: int = 8847360,
) -> Tuple[int, int]:
    """Smart resize function similar to qwen_vl_utils.

    Snaps both dimensions to multiples of *factor*, scaling the image when
    its pixel count falls outside [min_pixels, max_pixels]. Both returned
    dimensions are guaranteed to be at least *factor* (one patch), so the
    result is always safe to pass to PIL's Image.resize.

    Returns:
        (new_height, new_width) — note the (h, w) order.
    """
    total_pixels = height * width

    # If already within bounds, just snap to the factor grid.
    if min_pixels <= total_pixels <= max_pixels:
        new_height = (height // factor) * factor
        new_width = (width // factor) * factor
        # FIX: flooring could collapse a side shorter than `factor` to 0
        # (e.g. height=20, factor=28 -> 0), which would crash the later
        # Image.resize. Clamp to one patch, matching the scaled branch.
        return max(new_height, factor), max(new_width, factor)

    # Out of bounds: compute a uniform scale toward the violated limit.
    if total_pixels > max_pixels:
        scale = (max_pixels / total_pixels) ** 0.5
    else:
        scale = (min_pixels / total_pixels) ** 0.5

    # Apply scaling, then snap to the factor grid.
    new_height = (int(height * scale) // factor) * factor
    new_width = (int(width * scale) // factor) * factor

    # Ensure minimum size of one patch per side.
    new_height = max(new_height, factor)
    new_width = max(new_width, factor)

    return new_height, new_width
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
@register_agent(models=r".*Gelato.*")
class GelatoConfig(AsyncAgentConfig):
    """Gelato agent configuration implementing AsyncAgentConfig protocol for click prediction.

    Grounding-only loop: only predict_click is implemented; predict_step
    raises NotImplementedError.
    """

    def __init__(self):
        # Scratch state; neither field is read by predict_click yet.
        self.current_model = None
        self.last_screenshot_b64 = None

    async def predict_step(
        self,
        messages: List[Dict[str, Any]],
        model: str,
        tools: Optional[List[Dict[str, Any]]] = None,
        max_retries: Optional[int] = None,
        stream: bool = False,
        computer_handler=None,
        _on_api_start=None,
        _on_api_end=None,
        _on_usage=None,
        _on_screenshot=None,
        **kwargs,
    ) -> Dict[str, Any]:
        # Gelato only grounds clicks; full agent stepping is unsupported.
        raise NotImplementedError()

    async def predict_click(
        self, model: str, image_b64: str, instruction: str, **kwargs
    ) -> Optional[Tuple[float, float]]:
        """
        Predict click coordinates using the Gelato model via litellm.acompletion.

        Args:
            model: The Gelato model name
            image_b64: Base64 encoded image
            instruction: Instruction for where to click

        Returns:
            Tuple of (x, y) coordinates in the ORIGINAL image's pixel space.
            NOTE(review): when the model output cannot be parsed,
            extract_coordinates yields (0, 0), so this returns (0, 0)
            rather than None despite the Optional return type.
        """
        # Decode base64 image
        image_data = base64.b64decode(image_b64)
        image = Image.open(BytesIO(image_data))
        width, height = image.width, image.height

        # Smart resize the image (similar to qwen_vl_utils)
        resized_height, resized_width = smart_resize(
            height,
            width,
            factor=28,  # Default factor for Qwen models
            min_pixels=3136,
            max_pixels=4096 * 2160,
        )
        resized_image = image.resize((resized_width, resized_height))
        # The model predicts on the resized frame; these factors map its
        # output back to original-image pixel coordinates.
        scale_x, scale_y = width / resized_width, height / resized_height

        # Convert resized image back to base64
        buffered = BytesIO()
        resized_image.save(buffered, format="PNG")
        resized_image_b64 = base64.b64encode(buffered.getvalue()).decode()

        # Prepare system and user messages
        system_message = {
            "role": "system",
            "content": [{"type": "text", "text": SYSTEM_PROMPT.strip()}],
        }

        user_message = {
            "role": "user",
            "content": [
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/png;base64,{resized_image_b64}"},
                },
                {"type": "text", "text": instruction},
            ],
        }

        # Prepare API call kwargs (temperature 0 for deterministic grounding)
        api_kwargs = {
            "model": model,
            "messages": [system_message, user_message],
            "max_tokens": 2056,
            "temperature": 0.0,
            **kwargs,
        }

        # Use liteLLM acompletion
        response = await litellm.acompletion(**api_kwargs)

        # Extract response text
        output_text = response.choices[0].message.content  # type: ignore

        # Extract and rescale coordinates back to the original image size
        pred_x, pred_y = extract_coordinates(output_text)  # type: ignore
        pred_x *= scale_x
        pred_y *= scale_y

        return (math.floor(pred_x), math.floor(pred_y))

    def get_capabilities(self) -> List[AgentCapability]:
        """Return the capabilities supported by this agent."""
        return ["click"]
|