cua-agent 0.4.22__py3-none-any.whl → 0.7.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of cua-agent might be problematic.

Files changed (79)
  1. agent/__init__.py +4 -10
  2. agent/__main__.py +2 -1
  3. agent/adapters/__init__.py +4 -0
  4. agent/adapters/azure_ml_adapter.py +283 -0
  5. agent/adapters/cua_adapter.py +161 -0
  6. agent/adapters/huggingfacelocal_adapter.py +67 -125
  7. agent/adapters/human_adapter.py +116 -114
  8. agent/adapters/mlxvlm_adapter.py +110 -99
  9. agent/adapters/models/__init__.py +41 -0
  10. agent/adapters/models/generic.py +78 -0
  11. agent/adapters/models/internvl.py +290 -0
  12. agent/adapters/models/opencua.py +115 -0
  13. agent/adapters/models/qwen2_5_vl.py +78 -0
  14. agent/agent.py +337 -185
  15. agent/callbacks/__init__.py +9 -4
  16. agent/callbacks/base.py +45 -31
  17. agent/callbacks/budget_manager.py +22 -10
  18. agent/callbacks/image_retention.py +54 -98
  19. agent/callbacks/logging.py +55 -42
  20. agent/callbacks/operator_validator.py +35 -33
  21. agent/callbacks/otel.py +291 -0
  22. agent/callbacks/pii_anonymization.py +19 -16
  23. agent/callbacks/prompt_instructions.py +47 -0
  24. agent/callbacks/telemetry.py +99 -61
  25. agent/callbacks/trajectory_saver.py +95 -69
  26. agent/cli.py +269 -119
  27. agent/computers/__init__.py +14 -9
  28. agent/computers/base.py +32 -19
  29. agent/computers/cua.py +52 -25
  30. agent/computers/custom.py +78 -71
  31. agent/decorators.py +23 -14
  32. agent/human_tool/__init__.py +2 -7
  33. agent/human_tool/__main__.py +6 -2
  34. agent/human_tool/server.py +48 -37
  35. agent/human_tool/ui.py +359 -235
  36. agent/integrations/hud/__init__.py +38 -99
  37. agent/integrations/hud/agent.py +369 -0
  38. agent/integrations/hud/proxy.py +166 -52
  39. agent/loops/__init__.py +44 -14
  40. agent/loops/anthropic.py +579 -492
  41. agent/loops/base.py +19 -15
  42. agent/loops/composed_grounded.py +136 -150
  43. agent/loops/fara/__init__.py +8 -0
  44. agent/loops/fara/config.py +506 -0
  45. agent/loops/fara/helpers.py +357 -0
  46. agent/loops/fara/schema.py +143 -0
  47. agent/loops/gelato.py +183 -0
  48. agent/loops/gemini.py +935 -0
  49. agent/loops/generic_vlm.py +601 -0
  50. agent/loops/glm45v.py +140 -135
  51. agent/loops/gta1.py +48 -51
  52. agent/loops/holo.py +218 -0
  53. agent/loops/internvl.py +180 -0
  54. agent/loops/moondream3.py +493 -0
  55. agent/loops/omniparser.py +326 -226
  56. agent/loops/openai.py +50 -51
  57. agent/loops/opencua.py +134 -0
  58. agent/loops/uiins.py +175 -0
  59. agent/loops/uitars.py +247 -206
  60. agent/loops/uitars2.py +951 -0
  61. agent/playground/__init__.py +5 -0
  62. agent/playground/server.py +301 -0
  63. agent/proxy/examples.py +61 -57
  64. agent/proxy/handlers.py +46 -39
  65. agent/responses.py +447 -347
  66. agent/tools/__init__.py +24 -0
  67. agent/tools/base.py +253 -0
  68. agent/tools/browser_tool.py +423 -0
  69. agent/types.py +11 -5
  70. agent/ui/__init__.py +1 -1
  71. agent/ui/__main__.py +1 -1
  72. agent/ui/gradio/app.py +25 -22
  73. agent/ui/gradio/ui_components.py +314 -167
  74. cua_agent-0.7.16.dist-info/METADATA +85 -0
  75. cua_agent-0.7.16.dist-info/RECORD +79 -0
  76. {cua_agent-0.4.22.dist-info → cua_agent-0.7.16.dist-info}/WHEEL +1 -1
  77. cua_agent-0.4.22.dist-info/METADATA +0 -436
  78. cua_agent-0.4.22.dist-info/RECORD +0 -51
  79. {cua_agent-0.4.22.dist-info → cua_agent-0.7.16.dist-info}/entry_points.txt +0 -0
agent/loops/fara/helpers.py ADDED
@@ -0,0 +1,357 @@
+ # Source: https://github.com/QwenLM/Qwen-Agent/blob/main/qwen_agent/llm/fncall_prompts/nous_fncall_prompt.py
+
+ import copy
+ import json
+ import os
+ from typing import Any, Dict, List, Literal, Optional, Tuple, Union
+
+ from .schema import ContentItem, Message
+
+ FN_CALL_TEMPLATE_QWEN = """# Tools
+
+ You may call one or more functions to assist with the user query.
+
+ You are provided with function signatures within <tools></tools> XML tags:
+ <tools>
+ {tool_descs}
+ </tools>
+
+ For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
+ <tool_call>
+ {{"name": <function-name>, "arguments": <args-json-object>}}
+ </tool_call>"""
+
+ FN_CALL_TEMPLATE = """You are a web automation agent that performs actions on websites to fulfill user requests by calling various tools.
+ * You should stop execution at Critical Points. A Critical Point would be encountered in tasks like 'Checkout', 'Book', 'Purchase', 'Call', 'Email', 'Order', etc where a binding transaction/agreement would require the user's permission/personal or sensitive information (name, email, credit card, address, payment information, resume, etc) in order to complete a transaction (purchase, reservation, sign-up etc), or to communicate in a way that a human would be expected to do (call, email, apply to a job, etc).
+ * Solve the task as far as you can up until a Critical Point:
+ - For example, if the task is to "call a restaurant to make a reservation", you should not actually make the call but should navigate to the restaurant's page and find the phone number.
+ - Similarly, if the task is to "order new size 12 running shoes" you should not actually place the order but should instead search for the right shoes that meet the criteria and add them to the cart.
+ - Some tasks, like answering questions, may not encounter a Critical Point at all.
+
+ You are provided with function signatures within <tools></tools> XML tags:
+ <tools>
+ {tool_descs}
+ </tools>
+
+ For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
+ <tool_call>
+ {{"name": <function-name>, "arguments": <args-json-object>}}
+ </tool_call>"""
+
+
+ SPECIAL_CODE_MODE = os.getenv("SPECIAL_CODE_MODE", "false").lower() == "true"
+ CODE_TOOL_PATTERN = "code_interpreter"
+ FN_CALL_TEMPLATE_WITH_CI = """# Tools
+
+ You may call one or more functions to assist with the user query.
+
+ You are provided with function signatures within <tools></tools> XML tags:
+ <tools>
+ {tool_descs}
+ </tools>
+
+ For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
+ <tool_call>
+ {{"name": <function-name>, "arguments": <args-json-object>}}
+ </tool_call>
+ For code parameters, use placeholders first, and then put the code within <code></code> XML tags, such as:
+ <tool_call>
+ {{"name": <function-name>, "arguments": {{"code": ""}}}}
+ <code>
+ Here is the code.
+ </code>
+ </tool_call>"""
+
+
+ class NousFnCallPrompt:
+     def __init__(self, template_name: str = "default"):
+         """Initialize NousFnCallPrompt with a specific template.
+
+         Args:
+             template_name: Name of the template to use. Options:
+                 "default", "qwen", "with_ci"
+         """
+         self.template_name = template_name
+         self.template_map = {
+             "default": FN_CALL_TEMPLATE,
+             "qwen": FN_CALL_TEMPLATE_QWEN,
+             "with_ci": FN_CALL_TEMPLATE_WITH_CI,
+         }
+
+         if template_name not in self.template_map:
+             raise ValueError(
+                 f"Unknown template_name: {template_name}. "
+                 f"Available options: {list(self.template_map.keys())}"
+             )
+
+     def preprocess_fncall_messages(
+         self,
+         messages: List[Message],
+         functions: List[dict],
+         lang: Literal["en", "zh"],
+         parallel_function_calls: bool = True,
+         function_choice: Union[Literal["auto"], str] = "auto",
+     ) -> List[Message]:
+         del lang  # ignored
+         del parallel_function_calls  # ignored
+         if function_choice != "auto":
+             raise NotImplementedError
+
+         ori_messages = messages
+
+         # Change function_call responses to plaintext responses:
+         messages = []
+         for msg in copy.deepcopy(ori_messages):
+             role, content, reasoning_content = (
+                 msg.role,
+                 msg.content,
+                 msg.reasoning_content,
+             )
+             if role in ("system", "user"):
+                 messages.append(msg)
+             elif role == "assistant":
+                 content = content or []
+                 fn_call = msg.function_call
+                 if fn_call:
+                     if (not SPECIAL_CODE_MODE) or (CODE_TOOL_PATTERN not in fn_call.name):
+                         fc = {
+                             "name": fn_call.name,
+                             "arguments": json.loads(fn_call.arguments),
+                         }
+                         fc = json.dumps(fc, ensure_ascii=False)
+                         fc = f"<tool_call>\n{fc}\n</tool_call>"
+                     else:
+                         para = json.loads(fn_call.arguments)
+                         code = para["code"]
+                         para["code"] = ""
+                         fc = {"name": fn_call.name, "arguments": para}
+                         fc = json.dumps(fc, ensure_ascii=False)
+                         fc = f"<tool_call>\n{fc}\n<code>\n{code}\n</code>\n</tool_call>"
+
+                     content.append(ContentItem(text=fc))
+                 if messages[-1].role == "assistant":
+                     messages[-1].content.append(ContentItem(text="\n"))
+                     messages[-1].content.extend(content)
+                 else:
+                     # TODO: Assuming there will only be one continuous reasoning_content here
+                     messages.append(
+                         Message(
+                             role=role,
+                             content=content,
+                             reasoning_content=reasoning_content,
+                         )
+                     )
+             elif role == "function":
+                 assert isinstance(content, list)
+                 assert len(content) == 1
+                 assert content[0].text
+                 fc = f"<tool_response>\n{content[0].text}\n</tool_response>"
+                 content = [ContentItem(text=fc)]
+                 if messages[-1].role == "user":
+                     messages[-1].content.append(ContentItem(text="\n"))
+                     messages[-1].content.extend(content)
+                 else:
+                     messages.append(Message(role="user", content=content))
+             else:
+                 raise TypeError
+
+         tool_descs = [{"type": "function", "function": f} for f in functions]
+         tool_names = [
+             function.get("name_for_model", function.get("name", "")) for function in functions
+         ]
+         tool_descs = "\n".join([json.dumps(f, ensure_ascii=False) for f in tool_descs])
+
+         # Select template based on configuration
+         if SPECIAL_CODE_MODE and any([CODE_TOOL_PATTERN in x for x in tool_names]):
+             selected_template = FN_CALL_TEMPLATE_WITH_CI
+         else:
+             selected_template = self.template_map[self.template_name]
+
+         tool_system = selected_template.format(tool_descs=tool_descs)
+         if messages[0].role == "system":
+             messages[0].content.append(ContentItem(text="\n\n" + tool_system))
+         else:
+             messages = [Message(role="system", content=[ContentItem(text=tool_system)])] + messages
+         return messages
+
+
+ # Mainly for removing incomplete special tokens when streaming the output
+ # This assumes that '<tool_call>\n{"name": "' is the special token for the NousFnCallPrompt
+ def remove_incomplete_special_tokens(text: str) -> str:
+     if text in '<tool_call>\n{"name": "':
+         text = ""
+     return text
+
+
+ def extract_fn(text: str):
+     fn_name, fn_args = "", ""
+     fn_name_s = '"name": "'
+     fn_name_e = '", "'
+     fn_args_s = '"arguments": '
+     i = text.find(fn_name_s)
+     k = text.find(fn_args_s)
+     if i > 0:
+         _text = text[i + len(fn_name_s) :]
+         j = _text.find(fn_name_e)
+         if j > -1:
+             fn_name = _text[:j]
+     if k > 0:
+         fn_args = text[k + len(fn_args_s) :]
+
+     if len(fn_args) > 5:
+         fn_args = fn_args[:-5]
+     else:
+         fn_args = ""
+     return fn_name, fn_args
+
+
+ def build_nous_system(functions: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
+     """Use original FARA NousFnCallPrompt to generate a system message embedding tool schema."""
+     from .schema import ContentItem as NousContentItem
+     from .schema import Message as NousMessage
+
+     msgs = NousFnCallPrompt().preprocess_fncall_messages(
+         messages=[
+             NousMessage(
+                 role="system", content=[NousContentItem(text="You are a helpful assistant.")]
+             )
+         ],
+         functions=functions,
+         lang="en",
+     )
+     sys = msgs[0].model_dump()
+     # Convert structured content to OpenAI-style content list
+     content = [{"type": "text", "text": c["text"]} for c in sys.get("content", [])]
+     return {"role": "system", "content": content}
+
+
+ def parse_tool_call_from_text(text: str) -> Optional[Dict[str, Any]]:
+     """Extract JSON object within <tool_call>...</tool_call> from model text.
+
+     Accepts both </tool_call> and <tool_call> as closing tags for robustness.
+     Handles nested braces in JSON objects.
+     """
+     # Find the opening tag
+     start_idx = text.find("<tool_call>")
+     if start_idx == -1:
+         return None
+
+     # Find the start of JSON (first '{' after opening tag)
+     json_start = text.find("{", start_idx)
+     if json_start == -1:
+         return None
+
+     # Extract JSON by counting braces
+     brace_count = 0
+     json_end = json_start
+     for i in range(json_start, len(text)):
+         if text[i] == "{":
+             brace_count += 1
+         elif text[i] == "}":
+             brace_count -= 1
+             if brace_count == 0:
+                 json_end = i + 1
+                 break
+
+     if brace_count != 0:
+         return None
+
+     json_str = text[json_start:json_end]
+     try:
+         return json.loads(json_str)
+     except Exception:
+         return None
+
+
+ async def unnormalize_coordinate(args: Dict[str, Any], dims: Tuple[int, int]) -> Dict[str, Any]:
+     """Coordinates appear in 0..1000 space, scale to actual screen size using dims if provided."""
+     coord = args.get("coordinate")
+     if not coord or not isinstance(coord, (list, tuple)) or len(coord) < 2:
+         return args
+     x, y = float(coord[0]), float(coord[1])
+     width, height = float(dims[0]), float(dims[1])
+     x_abs = max(0.0, min(width, (x / 1000.0) * width))
+     y_abs = max(0.0, min(height, (y / 1000.0) * height))
+     args = {**args, "coordinate": [round(x_abs), round(y_abs)]}
+     return args
+
+
+ def convert_qwen_tool_args_to_computer_action(args: Dict[str, Any]) -> Optional[Dict[str, Any]]:
+     """
+     Convert Qwen computer tool arguments to the Computer Calls action schema.
+
+     Qwen (example):
+         {"action": "left_click", "coordinate": [114, 68]}
+
+     Target (example):
+         {"action": "left_click", "x": 114, "y": 68}
+
+     Other mappings:
+     - right_click, middle_click, double_click (triple_click -> double_click)
+     - mouse_move -> { action: "move", x, y }
+     - key -> { action: "keypress", keys: [...] }
+     - type -> { action: "type", text }
+     - scroll/hscroll -> { action: "scroll", scroll_x, scroll_y, x, y }
+     - wait -> { action: "wait" }
+     - terminate/answer are not direct UI actions; return None for now
+     """
+     if not isinstance(args, dict):
+         return None
+
+     action = args.get("action")
+     if not isinstance(action, str):
+         return None
+
+     # Coordinates helper
+     coord = args.get("coordinate")
+     x = y = None
+     if isinstance(coord, (list, tuple)) and len(coord) >= 2:
+         try:
+             x = int(round(float(coord[0])))
+             y = int(round(float(coord[1])))
+         except Exception:
+             x = y = None
+
+     # Map actions
+     a = action.lower()
+     if a in {"left_click", "right_click", "middle_click", "double_click"}:
+         if x is None or y is None:
+             return None
+         return {"action": a, "x": x, "y": y}
+     if a == "triple_click":
+         # Approximate as double_click
+         if x is None or y is None:
+             return None
+         return {"action": "double_click", "x": x, "y": y}
+     if a == "mouse_move":
+         if x is None or y is None:
+             return None
+         return {"action": "move", "x": x, "y": y}
+     if a == "key":
+         keys = args.get("keys")
+         if isinstance(keys, list) and all(isinstance(k, str) for k in keys):
+             return {"action": "keypress", "keys": keys}
+         return None
+     if a == "type":
+         text = args.get("text")
+         if isinstance(text, str):
+             return {"action": "type", "text": text}
+         return None
+     if a in {"scroll", "hscroll"}:
+         pixels = args.get("pixels") or 0
+         try:
+             pixels_val = int(round(float(pixels)))
+         except Exception:
+             pixels_val = 0
+         scroll_x = pixels_val if a == "hscroll" else 0
+         scroll_y = pixels_val if a == "scroll" else 0
+         # Include cursor position if available (optional)
+         out: Dict[str, Any] = {"action": "scroll", "scroll_x": scroll_x, "scroll_y": scroll_y}
+         if x is not None and y is not None:
+             out.update({"x": x, "y": y})
+         return out
+     if a == "wait":
+         return {"action": "wait"}
+
+     # Non-UI or terminal actions: terminate/answer -> not mapped here
+     return None
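
A minimal usage sketch for the parsing helpers above (hypothetical call site: the import path and the "computer_use" tool name are assumptions for illustration, not taken from this diff):

# Round trip: raw model text -> parsed tool call -> Computer Calls action dict.
from agent.loops.fara.helpers import (
    convert_qwen_tool_args_to_computer_action,
    parse_tool_call_from_text,
)

model_text = (
    "<tool_call>\n"
    '{"name": "computer_use", "arguments": {"action": "left_click", "coordinate": [114, 68]}}\n'
    "</tool_call>"
)
call = parse_tool_call_from_text(model_text)
# -> {'name': 'computer_use', 'arguments': {'action': 'left_click', 'coordinate': [114, 68]}}
action = convert_qwen_tool_args_to_computer_action(call["arguments"])
# -> {'action': 'left_click', 'x': 114, 'y': 68}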
agent/loops/fara/schema.py ADDED
@@ -0,0 +1,143 @@
+ # Source: https://github.com/QwenLM/Qwen-Agent/blob/main/qwen_agent/llm/schema.py
+
+ from typing import List, Literal, Optional, Tuple, Union
+
+ from pydantic import BaseModel, field_validator, model_validator
+
+
+ class BaseModelCompatibleDict(BaseModel):
+     def __getitem__(self, item):
+         return getattr(self, item)
+
+     def __setitem__(self, key, value):
+         setattr(self, key, value)
+
+     def model_dump(self, **kwargs):
+         if "exclude_none" not in kwargs:
+             kwargs["exclude_none"] = True
+         return super().model_dump(**kwargs)
+
+     def model_dump_json(self, **kwargs):
+         if "exclude_none" not in kwargs:
+             kwargs["exclude_none"] = True
+         return super().model_dump_json(**kwargs)
+
+     def get(self, key, default=None):
+         try:
+             return getattr(self, key)
+         except AttributeError:
+             return default
+
+     def __str__(self):
+         return f"{self.model_dump()}"
+
+
+ class FunctionCall(BaseModelCompatibleDict):
+     name: str
+     arguments: str
+
+     def __init__(self, name: str, arguments: str):
+         super().__init__(name=name, arguments=arguments)
+
+     def __repr__(self):
+         return f"FunctionCall({self.model_dump()})"
+
+
+ class ContentItem(BaseModelCompatibleDict):
+     text: Optional[str] = None
+     image: Optional[str] = None
+     file: Optional[str] = None
+     audio: Optional[Union[str, dict]] = None
+     video: Optional[Union[str, list]] = None
+
+     def __init__(
+         self,
+         text: Optional[str] = None,
+         image: Optional[str] = None,
+         file: Optional[str] = None,
+         audio: Optional[Union[str, dict]] = None,
+         video: Optional[Union[str, list]] = None,
+     ):
+         super().__init__(text=text, image=image, file=file, audio=audio, video=video)
+
+     @model_validator(mode="after")
+     def check_exclusivity(self):
+         provided_fields = 0
+         if self.text is not None:
+             provided_fields += 1
+         if self.image:
+             provided_fields += 1
+         if self.file:
+             provided_fields += 1
+         if self.audio:
+             provided_fields += 1
+         if self.video:
+             provided_fields += 1
+
+         if provided_fields != 1:
+             raise ValueError(
+                 "Exactly one of 'text', 'image', 'file', 'audio', or 'video' must be provided."
+             )
+         return self
+
+     def __repr__(self):
+         return f"ContentItem({self.model_dump()})"
+
+     def get_type_and_value(
+         self,
+     ) -> Tuple[Literal["text", "image", "file", "audio", "video"], str]:
+         ((t, v),) = self.model_dump().items()
+         assert t in ("text", "image", "file", "audio", "video")
+         return t, v
+
+     @property
+     def type(self) -> Literal["text", "image", "file", "audio", "video"]:
+         t, _ = self.get_type_and_value()
+         return t
+
+     @property
+     def value(self) -> str:
+         _, v = self.get_type_and_value()
+         return v
+
+
+ class Message(BaseModelCompatibleDict):
+     role: str
+     content: Union[str, List[ContentItem]]
+     reasoning_content: Optional[Union[str, List[ContentItem]]] = None
+     name: Optional[str] = None
+     function_call: Optional[FunctionCall] = None
+     extra: Optional[dict] = None
+
+     def __init__(
+         self,
+         role: str,
+         content: Union[str, List[ContentItem]],
+         reasoning_content: Optional[Union[str, List[ContentItem]]] = None,
+         name: Optional[str] = None,
+         function_call: Optional[FunctionCall] = None,
+         extra: Optional[dict] = None,
+         **kwargs,
+     ):
+         if content is None:
+             content = ""
+         if reasoning_content is None:
+             reasoning_content = ""
+         super().__init__(
+             role=role,
+             content=content,
+             reasoning_content=reasoning_content,
+             name=name,
+             function_call=function_call,
+             extra=extra,
+         )
+
+     def __repr__(self):
+         return f"Message({self.model_dump()})"
+
+     @field_validator("role")
+     def role_checker(cls, value: str) -> str:
+         values = ["system", "user", "assistant", "function"]
+         if value not in values:
+             raise ValueError(f'{value} must be one of {",".join(values)}')
+         return value
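
A short sketch of the schema above (assuming the module is importable as agent.loops.fara.schema): ContentItem enforces exactly one payload field, and model_dump drops None fields by default:

from agent.loops.fara.schema import ContentItem, Message

msg = Message(role="user", content=[ContentItem(text="Open the settings page")])
print(msg.model_dump())
# -> {'role': 'user', 'content': [{'text': 'Open the settings page'}], 'reasoning_content': ''}

ContentItem(text="hi", image="img.png")  # raises ValueError: exactly one field must be set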
agent/loops/gelato.py ADDED
@@ -0,0 +1,183 @@
+ """
+ Gelato agent loop implementation for click prediction using litellm.acompletion
+ Model: https://huggingface.co/mlfoundations/Gelato-30B-A3B
+ Code: https://github.com/mlfoundations/Gelato/tree/main
+ """
+
+ import base64
+ import math
+ import re
+ from io import BytesIO
+ from typing import Any, Dict, List, Optional, Tuple
+
+ import litellm
+ from PIL import Image
+
+ from ..decorators import register_agent
+ from ..loops.base import AsyncAgentConfig
+ from ..types import AgentCapability
+
+ SYSTEM_PROMPT = """
+ You are an expert UI element locator. Given a GUI image and a user's element description, provide the coordinates of the specified element as a single (x,y) point. For elements with area, return the center point.
+
+ Output the coordinate pair exactly:
+ (x,y)
+ """
+
+
+ def extract_coordinates(raw_string):
+     """
+     Extract the coordinates from the raw string.
+     Args:
+         raw_string: str (e.g. "(100, 200)")
+     Returns:
+         x: float (e.g. 100.0)
+         y: float (e.g. 200.0)
+     """
+     try:
+         matches = re.findall(r"\((-?\d*\.?\d+),\s*(-?\d*\.?\d+)\)", raw_string)
+         return [tuple(map(int, match)) for match in matches][0]
+     except:
+         return 0, 0
+
+
+ def smart_resize(
+     height: int,
+     width: int,
+     factor: int = 28,
+     min_pixels: int = 3136,
+     max_pixels: int = 8847360,
+ ) -> Tuple[int, int]:
+     """Smart resize function similar to qwen_vl_utils."""
+     # Calculate the total pixels
+     total_pixels = height * width
+
+     # If already within bounds, return original dimensions
+     if min_pixels <= total_pixels <= max_pixels:
+         # Round to nearest factor
+         new_height = (height // factor) * factor
+         new_width = (width // factor) * factor
+         return new_height, new_width
+
+     # Calculate scaling factor
+     if total_pixels > max_pixels:
+         scale = (max_pixels / total_pixels) ** 0.5
+     else:
+         scale = (min_pixels / total_pixels) ** 0.5
+
+     # Apply scaling
+     new_height = int(height * scale)
+     new_width = int(width * scale)
+
+     # Round to nearest factor
+     new_height = (new_height // factor) * factor
+     new_width = (new_width // factor) * factor
+
+     # Ensure minimum size
+     new_height = max(new_height, factor)
+     new_width = max(new_width, factor)
+
+     return new_height, new_width
+
+
+ @register_agent(models=r".*Gelato.*")
+ class GelatoConfig(AsyncAgentConfig):
+     """Gelato agent configuration implementing AsyncAgentConfig protocol for click prediction."""
+
+     def __init__(self):
+         self.current_model = None
+         self.last_screenshot_b64 = None
+
+     async def predict_step(
+         self,
+         messages: List[Dict[str, Any]],
+         model: str,
+         tools: Optional[List[Dict[str, Any]]] = None,
+         max_retries: Optional[int] = None,
+         stream: bool = False,
+         computer_handler=None,
+         _on_api_start=None,
+         _on_api_end=None,
+         _on_usage=None,
+         _on_screenshot=None,
+         **kwargs,
+     ) -> Dict[str, Any]:
+         raise NotImplementedError()
+
+     async def predict_click(
+         self, model: str, image_b64: str, instruction: str, **kwargs
+     ) -> Optional[Tuple[float, float]]:
+         """
+         Predict click coordinates using the Gelato model via litellm.acompletion.
+
+         Args:
+             model: The Gelato model name
+             image_b64: Base64 encoded image
+             instruction: Instruction for where to click
+
+         Returns:
+             Tuple of (x, y) coordinates or None if prediction fails
+         """
+         # Decode base64 image
+         image_data = base64.b64decode(image_b64)
+         image = Image.open(BytesIO(image_data))
+         width, height = image.width, image.height
+
+         # Smart resize the image (similar to qwen_vl_utils)
+         resized_height, resized_width = smart_resize(
+             height,
+             width,
+             factor=28,  # Default factor for Qwen models
+             min_pixels=3136,
+             max_pixels=4096 * 2160,
+         )
+         resized_image = image.resize((resized_width, resized_height))
+         scale_x, scale_y = width / resized_width, height / resized_height
+
+         # Convert resized image back to base64
+         buffered = BytesIO()
+         resized_image.save(buffered, format="PNG")
+         resized_image_b64 = base64.b64encode(buffered.getvalue()).decode()
+
+         # Prepare system and user messages
+         system_message = {
+             "role": "system",
+             "content": [{"type": "text", "text": SYSTEM_PROMPT.strip()}],
+         }
+
+         user_message = {
+             "role": "user",
+             "content": [
+                 {
+                     "type": "image_url",
+                     "image_url": {"url": f"data:image/png;base64,{resized_image_b64}"},
+                 },
+                 {"type": "text", "text": instruction},
+             ],
+         }
+
+         # Prepare API call kwargs
+         api_kwargs = {
+             "model": model,
+             "messages": [system_message, user_message],
+             "max_tokens": 2056,
+             "temperature": 0.0,
+             **kwargs,
+         }
+
+         # Use liteLLM acompletion
+         response = await litellm.acompletion(**api_kwargs)
+
+         # Extract response text
+         output_text = response.choices[0].message.content  # type: ignore
+
+         # Extract and rescale coordinates
+         pred_x, pred_y = extract_coordinates(output_text)  # type: ignore
+         pred_x *= scale_x
+         pred_y *= scale_y
+
+         return (math.floor(pred_x), math.floor(pred_y))
+
+     def get_capabilities(self) -> List[AgentCapability]:
+         """Return the capabilities supported by this agent."""
+         return ["click"]