cua-agent 0.4.22__py3-none-any.whl → 0.7.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of cua-agent might be problematic. Click here for more details.
- agent/__init__.py +4 -10
- agent/__main__.py +2 -1
- agent/adapters/__init__.py +4 -0
- agent/adapters/azure_ml_adapter.py +283 -0
- agent/adapters/cua_adapter.py +161 -0
- agent/adapters/huggingfacelocal_adapter.py +67 -125
- agent/adapters/human_adapter.py +116 -114
- agent/adapters/mlxvlm_adapter.py +110 -99
- agent/adapters/models/__init__.py +41 -0
- agent/adapters/models/generic.py +78 -0
- agent/adapters/models/internvl.py +290 -0
- agent/adapters/models/opencua.py +115 -0
- agent/adapters/models/qwen2_5_vl.py +78 -0
- agent/agent.py +337 -185
- agent/callbacks/__init__.py +9 -4
- agent/callbacks/base.py +45 -31
- agent/callbacks/budget_manager.py +22 -10
- agent/callbacks/image_retention.py +54 -98
- agent/callbacks/logging.py +55 -42
- agent/callbacks/operator_validator.py +35 -33
- agent/callbacks/otel.py +291 -0
- agent/callbacks/pii_anonymization.py +19 -16
- agent/callbacks/prompt_instructions.py +47 -0
- agent/callbacks/telemetry.py +99 -61
- agent/callbacks/trajectory_saver.py +95 -69
- agent/cli.py +269 -119
- agent/computers/__init__.py +14 -9
- agent/computers/base.py +32 -19
- agent/computers/cua.py +52 -25
- agent/computers/custom.py +78 -71
- agent/decorators.py +23 -14
- agent/human_tool/__init__.py +2 -7
- agent/human_tool/__main__.py +6 -2
- agent/human_tool/server.py +48 -37
- agent/human_tool/ui.py +359 -235
- agent/integrations/hud/__init__.py +38 -99
- agent/integrations/hud/agent.py +369 -0
- agent/integrations/hud/proxy.py +166 -52
- agent/loops/__init__.py +44 -14
- agent/loops/anthropic.py +579 -492
- agent/loops/base.py +19 -15
- agent/loops/composed_grounded.py +136 -150
- agent/loops/fara/__init__.py +8 -0
- agent/loops/fara/config.py +506 -0
- agent/loops/fara/helpers.py +357 -0
- agent/loops/fara/schema.py +143 -0
- agent/loops/gelato.py +183 -0
- agent/loops/gemini.py +935 -0
- agent/loops/generic_vlm.py +601 -0
- agent/loops/glm45v.py +140 -135
- agent/loops/gta1.py +48 -51
- agent/loops/holo.py +218 -0
- agent/loops/internvl.py +180 -0
- agent/loops/moondream3.py +493 -0
- agent/loops/omniparser.py +326 -226
- agent/loops/openai.py +50 -51
- agent/loops/opencua.py +134 -0
- agent/loops/uiins.py +175 -0
- agent/loops/uitars.py +247 -206
- agent/loops/uitars2.py +951 -0
- agent/playground/__init__.py +5 -0
- agent/playground/server.py +301 -0
- agent/proxy/examples.py +61 -57
- agent/proxy/handlers.py +46 -39
- agent/responses.py +447 -347
- agent/tools/__init__.py +24 -0
- agent/tools/base.py +253 -0
- agent/tools/browser_tool.py +423 -0
- agent/types.py +11 -5
- agent/ui/__init__.py +1 -1
- agent/ui/__main__.py +1 -1
- agent/ui/gradio/app.py +25 -22
- agent/ui/gradio/ui_components.py +314 -167
- cua_agent-0.7.16.dist-info/METADATA +85 -0
- cua_agent-0.7.16.dist-info/RECORD +79 -0
- {cua_agent-0.4.22.dist-info → cua_agent-0.7.16.dist-info}/WHEEL +1 -1
- cua_agent-0.4.22.dist-info/METADATA +0 -436
- cua_agent-0.4.22.dist-info/RECORD +0 -51
- {cua_agent-0.4.22.dist-info → cua_agent-0.7.16.dist-info}/entry_points.txt +0 -0
agent/loops/gemini.py
ADDED
|
@@ -0,0 +1,935 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Gemini Computer Use agent loop
|
|
3
|
+
|
|
4
|
+
Maps internal Agent SDK message format to Google's Gemini Computer Use API and back.
|
|
5
|
+
|
|
6
|
+
Supported models:
|
|
7
|
+
- gemini-2.5-computer-use-preview-10-2025 (uses built-in ComputerUse tool)
|
|
8
|
+
- gemini-3-flash-preview (and variants) (uses custom function declarations)
|
|
9
|
+
- gemini-3-pro-preview (and variants) (uses custom function declarations)
|
|
10
|
+
|
|
11
|
+
Key features:
|
|
12
|
+
- Lazy import of google.genai
|
|
13
|
+
- Configure Computer Use tool with excluded browser-specific predefined functions (Gemini 2.5)
|
|
14
|
+
- Custom function declarations for computer use actions (Gemini 3 models)
|
|
15
|
+
- Convert Gemini function_call parts into internal computer_call actions
|
|
16
|
+
- Gemini 3-specific: thinking_level and media_resolution parameters
|
|
17
|
+
"""
|
|
18
|
+
|
|
19
|
+
from __future__ import annotations
|
|
20
|
+
|
|
21
|
+
import base64
|
|
22
|
+
import io
|
|
23
|
+
import uuid
|
|
24
|
+
from typing import Any, Dict, List, Optional, Tuple
|
|
25
|
+
|
|
26
|
+
from PIL import Image
|
|
27
|
+
|
|
28
|
+
from ..decorators import register_agent
|
|
29
|
+
from ..loops.base import AsyncAgentConfig
|
|
30
|
+
from ..types import AgentCapability
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def _lazy_import_genai():
|
|
34
|
+
"""Import google.genai lazily to avoid hard dependency unless used."""
|
|
35
|
+
try:
|
|
36
|
+
from google import genai # type: ignore
|
|
37
|
+
from google.genai import types # type: ignore
|
|
38
|
+
|
|
39
|
+
return genai, types
|
|
40
|
+
except Exception as e: # pragma: no cover
|
|
41
|
+
raise RuntimeError(
|
|
42
|
+
"google.genai is required for the Gemini Computer Use loop. Install the Google Gemini SDK."
|
|
43
|
+
) from e
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def _data_url_to_bytes(data_url: str) -> Tuple[bytes, str]:
|
|
47
|
+
"""Convert a data URL to raw bytes and mime type."""
|
|
48
|
+
if not data_url.startswith("data:"):
|
|
49
|
+
# Assume it's base64 png payload
|
|
50
|
+
try:
|
|
51
|
+
return base64.b64decode(data_url), "image/png"
|
|
52
|
+
except Exception:
|
|
53
|
+
return b"", "application/octet-stream"
|
|
54
|
+
header, b64 = data_url.split(",", 1)
|
|
55
|
+
mime = "image/png"
|
|
56
|
+
if ";" in header:
|
|
57
|
+
mime = header.split(";")[0].split(":", 1)[1] or "image/png"
|
|
58
|
+
return base64.b64decode(b64), mime
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def _bytes_image_size(img_bytes: bytes) -> Tuple[int, int]:
|
|
62
|
+
try:
|
|
63
|
+
img = Image.open(io.BytesIO(img_bytes))
|
|
64
|
+
return img.size
|
|
65
|
+
except Exception:
|
|
66
|
+
return (1024, 768)
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def _sanitize_for_json(obj: Any) -> Any:
|
|
70
|
+
"""
|
|
71
|
+
Recursively sanitize an object for JSON serialization.
|
|
72
|
+
Handles bytes fields (like thought_signature in Gemini 3 responses).
|
|
73
|
+
"""
|
|
74
|
+
if obj is None:
|
|
75
|
+
return None
|
|
76
|
+
if isinstance(obj, bytes):
|
|
77
|
+
# Convert bytes to base64 string for JSON serialization
|
|
78
|
+
return f"<bytes:{base64.b64encode(obj).decode('ascii')}>"
|
|
79
|
+
if isinstance(obj, (str, int, float, bool)):
|
|
80
|
+
return obj
|
|
81
|
+
if isinstance(obj, dict):
|
|
82
|
+
return {k: _sanitize_for_json(v) for k, v in obj.items()}
|
|
83
|
+
if isinstance(obj, (list, tuple)):
|
|
84
|
+
return [_sanitize_for_json(item) for item in obj]
|
|
85
|
+
# Handle objects with __dict__ (like Gemini SDK response objects)
|
|
86
|
+
if hasattr(obj, "__dict__"):
|
|
87
|
+
return {k: _sanitize_for_json(v) for k, v in obj.__dict__.items() if not k.startswith("_")}
|
|
88
|
+
# Handle objects with model_dump (Pydantic models)
|
|
89
|
+
if hasattr(obj, "model_dump"):
|
|
90
|
+
return _sanitize_for_json(obj.model_dump())
|
|
91
|
+
# Fallback to string representation
|
|
92
|
+
return str(obj)
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def _find_last_user_text(messages: List[Dict[str, Any]]) -> List[str]:
|
|
96
|
+
texts: List[str] = []
|
|
97
|
+
for msg in reversed(messages):
|
|
98
|
+
if msg.get("type") in (None, "message") and msg.get("role") == "user":
|
|
99
|
+
content = msg.get("content")
|
|
100
|
+
if isinstance(content, str):
|
|
101
|
+
return [content]
|
|
102
|
+
elif isinstance(content, list):
|
|
103
|
+
for c in content:
|
|
104
|
+
if c.get("type") in ("input_text", "output_text") and c.get("text"):
|
|
105
|
+
texts.append(c["text"]) # newest first
|
|
106
|
+
if texts:
|
|
107
|
+
return list(reversed(texts))
|
|
108
|
+
return []
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def _find_last_screenshot(messages: List[Dict[str, Any]]) -> Optional[bytes]:
    """Return the raw bytes of the most recent screenshot in the transcript.

    Scans newest-first for a computer_call_output whose output is an image
    with a non-empty image_url; returns None when no such entry exists.
    """
    for msg in reversed(messages):
        if msg.get("type") != "computer_call_output":
            continue
        output = msg.get("output", {})
        if not isinstance(output, dict):
            continue
        if output.get("type") not in ("input_image", "computer_screenshot"):
            continue
        image_url = output.get("image_url", "")
        if image_url:
            data, _mime = _data_url_to_bytes(image_url)
            return data
    return None
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def _convert_messages_to_gemini_contents(
|
|
124
|
+
messages: List[Dict[str, Any]],
|
|
125
|
+
types: Any,
|
|
126
|
+
) -> Tuple[List[Any], Tuple[int, int]]:
|
|
127
|
+
"""
|
|
128
|
+
Convert internal message format to Gemini's Content format with full conversation history.
|
|
129
|
+
|
|
130
|
+
Similar to how Anthropic loop uses _convert_responses_items_to_completion_messages,
|
|
131
|
+
this converts ALL messages to Gemini's format.
|
|
132
|
+
|
|
133
|
+
Gemini requires:
|
|
134
|
+
- role: "user" or "model"
|
|
135
|
+
- parts: list of Part objects (text, image, function_call, function_response)
|
|
136
|
+
|
|
137
|
+
Returns:
|
|
138
|
+
Tuple of (list of Content objects, (screen_width, screen_height))
|
|
139
|
+
"""
|
|
140
|
+
contents: List[Any] = []
|
|
141
|
+
screen_w, screen_h = 1024, 768 # Default dimensions
|
|
142
|
+
|
|
143
|
+
for msg in messages:
|
|
144
|
+
msg_type = msg.get("type")
|
|
145
|
+
role = msg.get("role")
|
|
146
|
+
|
|
147
|
+
# User messages
|
|
148
|
+
if role == "user" or (msg_type in (None, "message") and role == "user"):
|
|
149
|
+
parts: List[Any] = []
|
|
150
|
+
content = msg.get("content")
|
|
151
|
+
|
|
152
|
+
if isinstance(content, str):
|
|
153
|
+
parts.append(types.Part(text=content))
|
|
154
|
+
elif isinstance(content, list):
|
|
155
|
+
for c in content:
|
|
156
|
+
if c.get("type") in ("input_text", "text") and c.get("text"):
|
|
157
|
+
parts.append(types.Part(text=c["text"]))
|
|
158
|
+
elif c.get("type") == "input_image" and c.get("image_url"):
|
|
159
|
+
img_bytes, _ = _data_url_to_bytes(c["image_url"])
|
|
160
|
+
if img_bytes:
|
|
161
|
+
w, h = _bytes_image_size(img_bytes)
|
|
162
|
+
screen_w, screen_h = w, h
|
|
163
|
+
parts.append(
|
|
164
|
+
types.Part.from_bytes(data=img_bytes, mime_type="image/png")
|
|
165
|
+
)
|
|
166
|
+
|
|
167
|
+
if parts:
|
|
168
|
+
contents.append(types.Content(role="user", parts=parts))
|
|
169
|
+
|
|
170
|
+
# Assistant messages
|
|
171
|
+
elif role == "assistant" or (msg_type == "message" and role == "assistant"):
|
|
172
|
+
parts = []
|
|
173
|
+
content = msg.get("content")
|
|
174
|
+
|
|
175
|
+
if isinstance(content, str):
|
|
176
|
+
parts.append(types.Part(text=content))
|
|
177
|
+
elif isinstance(content, list):
|
|
178
|
+
for c in content:
|
|
179
|
+
if c.get("type") in ("output_text", "text") and c.get("text"):
|
|
180
|
+
parts.append(types.Part(text=c["text"]))
|
|
181
|
+
|
|
182
|
+
if parts:
|
|
183
|
+
contents.append(types.Content(role="model", parts=parts))
|
|
184
|
+
|
|
185
|
+
# Reasoning (treat as model output)
|
|
186
|
+
elif msg_type == "reasoning":
|
|
187
|
+
summary = msg.get("summary", [])
|
|
188
|
+
for s in summary:
|
|
189
|
+
if s.get("type") == "summary_text" and s.get("text"):
|
|
190
|
+
contents.append(
|
|
191
|
+
types.Content(
|
|
192
|
+
role="model", parts=[types.Part(text=f"[Thinking: {s['text']}]")]
|
|
193
|
+
)
|
|
194
|
+
)
|
|
195
|
+
break
|
|
196
|
+
|
|
197
|
+
# Computer call (model action) - represent as text description for context
|
|
198
|
+
elif msg_type == "computer_call":
|
|
199
|
+
action = msg.get("action", {})
|
|
200
|
+
action_type = action.get("type", "unknown")
|
|
201
|
+
action_desc = f"[Action: {action_type}"
|
|
202
|
+
for k, v in action.items():
|
|
203
|
+
if k != "type":
|
|
204
|
+
action_desc += f", {k}={v}"
|
|
205
|
+
action_desc += "]"
|
|
206
|
+
contents.append(types.Content(role="model", parts=[types.Part(text=action_desc)]))
|
|
207
|
+
|
|
208
|
+
# Computer call output (screenshot result) - this is the key part!
|
|
209
|
+
elif msg_type == "computer_call_output":
|
|
210
|
+
out = msg.get("output", {})
|
|
211
|
+
if isinstance(out, dict) and out.get("type") in ("input_image", "computer_screenshot"):
|
|
212
|
+
image_url = out.get("image_url", "")
|
|
213
|
+
if image_url and image_url != "[omitted]":
|
|
214
|
+
img_bytes, _ = _data_url_to_bytes(image_url)
|
|
215
|
+
if img_bytes:
|
|
216
|
+
w, h = _bytes_image_size(img_bytes)
|
|
217
|
+
screen_w, screen_h = w, h
|
|
218
|
+
contents.append(
|
|
219
|
+
types.Content(
|
|
220
|
+
role="user",
|
|
221
|
+
parts=[
|
|
222
|
+
types.Part.from_bytes(data=img_bytes, mime_type="image/png")
|
|
223
|
+
],
|
|
224
|
+
)
|
|
225
|
+
)
|
|
226
|
+
else:
|
|
227
|
+
# Image was omitted (by ImageRetentionCallback)
|
|
228
|
+
contents.append(
|
|
229
|
+
types.Content(
|
|
230
|
+
role="user",
|
|
231
|
+
parts=[
|
|
232
|
+
types.Part(
|
|
233
|
+
text="[Screenshot taken - image omitted for context limit]"
|
|
234
|
+
)
|
|
235
|
+
],
|
|
236
|
+
)
|
|
237
|
+
)
|
|
238
|
+
|
|
239
|
+
# Function call (model action)
|
|
240
|
+
elif msg_type == "function_call":
|
|
241
|
+
fn_name = msg.get("name", "unknown")
|
|
242
|
+
fn_args = msg.get("arguments", "{}")
|
|
243
|
+
contents.append(
|
|
244
|
+
types.Content(
|
|
245
|
+
role="model", parts=[types.Part(text=f"[Function call: {fn_name}({fn_args})]")]
|
|
246
|
+
)
|
|
247
|
+
)
|
|
248
|
+
|
|
249
|
+
# Function call output
|
|
250
|
+
elif msg_type == "function_call_output":
|
|
251
|
+
output = msg.get("output", "")
|
|
252
|
+
contents.append(
|
|
253
|
+
types.Content(role="user", parts=[types.Part(text=f"[Function result: {output}]")])
|
|
254
|
+
)
|
|
255
|
+
|
|
256
|
+
# Gemini requires alternating user/model turns - merge consecutive same-role contents
|
|
257
|
+
merged: List[Any] = []
|
|
258
|
+
for content in contents:
|
|
259
|
+
if merged and merged[-1].role == content.role:
|
|
260
|
+
# Merge parts into the previous content of same role
|
|
261
|
+
merged[-1] = types.Content(
|
|
262
|
+
role=content.role, parts=list(merged[-1].parts) + list(content.parts)
|
|
263
|
+
)
|
|
264
|
+
else:
|
|
265
|
+
merged.append(content)
|
|
266
|
+
|
|
267
|
+
# Gemini requires conversation to start with user
|
|
268
|
+
if merged and merged[0].role == "model":
|
|
269
|
+
merged.insert(0, types.Content(role="user", parts=[types.Part(text="Begin the task.")]))
|
|
270
|
+
|
|
271
|
+
# Ensure we have at least one message
|
|
272
|
+
if not merged:
|
|
273
|
+
merged = [
|
|
274
|
+
types.Content(role="user", parts=[types.Part(text="Proceed to the next action.")])
|
|
275
|
+
]
|
|
276
|
+
|
|
277
|
+
return merged, (screen_w, screen_h)
|
|
278
|
+
|
|
279
|
+
|
|
280
|
+
def _denormalize(v: int, size: int) -> int:
|
|
281
|
+
# Gemini returns 0-999 normalized
|
|
282
|
+
try:
|
|
283
|
+
return max(0, min(size - 1, int(round(v / 1000 * size))))
|
|
284
|
+
except Exception:
|
|
285
|
+
return 0
|
|
286
|
+
|
|
287
|
+
|
|
288
|
+
def _is_gemini_3_model(model: str) -> bool:
|
|
289
|
+
"""Check if the model is a Gemini 3 model (Flash or Pro Preview)."""
|
|
290
|
+
return "gemini-3" in model.lower() or "gemini-2.0" in model.lower()
|
|
291
|
+
|
|
292
|
+
|
|
293
|
+
def _build_custom_function_declarations(types: Any) -> List[Any]:
|
|
294
|
+
"""
|
|
295
|
+
Build custom function declarations for Gemini 3 models.
|
|
296
|
+
|
|
297
|
+
These function declarations replicate the built-in ComputerUse tool actions
|
|
298
|
+
that are available in Gemini 2.5 Computer Use Preview, but using the standard
|
|
299
|
+
function calling interface.
|
|
300
|
+
|
|
301
|
+
Note: Coordinates use 0-999 normalized range for both x and y.
|
|
302
|
+
"""
|
|
303
|
+
return [
|
|
304
|
+
types.FunctionDeclaration(
|
|
305
|
+
name="click_at",
|
|
306
|
+
description="Click at the specified x,y coordinates on the screen. Coordinates are normalized 0-999.",
|
|
307
|
+
parameters={
|
|
308
|
+
"type": "object",
|
|
309
|
+
"properties": {
|
|
310
|
+
"x": {"type": "integer", "description": "X coordinate (0-999 normalized)"},
|
|
311
|
+
"y": {"type": "integer", "description": "Y coordinate (0-999 normalized)"},
|
|
312
|
+
},
|
|
313
|
+
"required": ["x", "y"],
|
|
314
|
+
},
|
|
315
|
+
),
|
|
316
|
+
types.FunctionDeclaration(
|
|
317
|
+
name="type_text_at",
|
|
318
|
+
description="Type text at the specified x,y coordinates. First clicks at the location, then types the text.",
|
|
319
|
+
parameters={
|
|
320
|
+
"type": "object",
|
|
321
|
+
"properties": {
|
|
322
|
+
"x": {"type": "integer", "description": "X coordinate (0-999 normalized)"},
|
|
323
|
+
"y": {"type": "integer", "description": "Y coordinate (0-999 normalized)"},
|
|
324
|
+
"text": {"type": "string", "description": "Text to type"},
|
|
325
|
+
"press_enter": {
|
|
326
|
+
"type": "boolean",
|
|
327
|
+
"description": "Whether to press Enter after typing",
|
|
328
|
+
},
|
|
329
|
+
},
|
|
330
|
+
"required": ["x", "y", "text"],
|
|
331
|
+
},
|
|
332
|
+
),
|
|
333
|
+
types.FunctionDeclaration(
|
|
334
|
+
name="hover_at",
|
|
335
|
+
description="Move the mouse cursor to the specified x,y coordinates without clicking.",
|
|
336
|
+
parameters={
|
|
337
|
+
"type": "object",
|
|
338
|
+
"properties": {
|
|
339
|
+
"x": {"type": "integer", "description": "X coordinate (0-999 normalized)"},
|
|
340
|
+
"y": {"type": "integer", "description": "Y coordinate (0-999 normalized)"},
|
|
341
|
+
},
|
|
342
|
+
"required": ["x", "y"],
|
|
343
|
+
},
|
|
344
|
+
),
|
|
345
|
+
types.FunctionDeclaration(
|
|
346
|
+
name="key_combination",
|
|
347
|
+
description="Press a key combination (e.g., 'ctrl+c', 'alt+tab', 'enter').",
|
|
348
|
+
parameters={
|
|
349
|
+
"type": "object",
|
|
350
|
+
"properties": {
|
|
351
|
+
"keys": {
|
|
352
|
+
"type": "string",
|
|
353
|
+
"description": "Key combination to press (e.g., 'ctrl+c', 'enter', 'alt+tab')",
|
|
354
|
+
},
|
|
355
|
+
},
|
|
356
|
+
"required": ["keys"],
|
|
357
|
+
},
|
|
358
|
+
),
|
|
359
|
+
types.FunctionDeclaration(
|
|
360
|
+
name="scroll_at",
|
|
361
|
+
description="Scroll at the specified x,y coordinates in a given direction.",
|
|
362
|
+
parameters={
|
|
363
|
+
"type": "object",
|
|
364
|
+
"properties": {
|
|
365
|
+
"x": {"type": "integer", "description": "X coordinate (0-999 normalized)"},
|
|
366
|
+
"y": {"type": "integer", "description": "Y coordinate (0-999 normalized)"},
|
|
367
|
+
"direction": {
|
|
368
|
+
"type": "string",
|
|
369
|
+
"enum": ["up", "down", "left", "right"],
|
|
370
|
+
"description": "Direction to scroll",
|
|
371
|
+
},
|
|
372
|
+
"magnitude": {
|
|
373
|
+
"type": "integer",
|
|
374
|
+
"description": "Amount to scroll in pixels (default 800)",
|
|
375
|
+
},
|
|
376
|
+
},
|
|
377
|
+
"required": ["x", "y", "direction"],
|
|
378
|
+
},
|
|
379
|
+
),
|
|
380
|
+
types.FunctionDeclaration(
|
|
381
|
+
name="scroll_document",
|
|
382
|
+
description="Scroll the entire document/page in a given direction.",
|
|
383
|
+
parameters={
|
|
384
|
+
"type": "object",
|
|
385
|
+
"properties": {
|
|
386
|
+
"direction": {
|
|
387
|
+
"type": "string",
|
|
388
|
+
"enum": ["up", "down", "left", "right"],
|
|
389
|
+
"description": "Direction to scroll",
|
|
390
|
+
},
|
|
391
|
+
},
|
|
392
|
+
"required": ["direction"],
|
|
393
|
+
},
|
|
394
|
+
),
|
|
395
|
+
types.FunctionDeclaration(
|
|
396
|
+
name="drag_and_drop",
|
|
397
|
+
description="Drag from one coordinate to another.",
|
|
398
|
+
parameters={
|
|
399
|
+
"type": "object",
|
|
400
|
+
"properties": {
|
|
401
|
+
"x": {
|
|
402
|
+
"type": "integer",
|
|
403
|
+
"description": "Starting X coordinate (0-999 normalized)",
|
|
404
|
+
},
|
|
405
|
+
"y": {
|
|
406
|
+
"type": "integer",
|
|
407
|
+
"description": "Starting Y coordinate (0-999 normalized)",
|
|
408
|
+
},
|
|
409
|
+
"destination_x": {
|
|
410
|
+
"type": "integer",
|
|
411
|
+
"description": "Destination X coordinate (0-999 normalized)",
|
|
412
|
+
},
|
|
413
|
+
"destination_y": {
|
|
414
|
+
"type": "integer",
|
|
415
|
+
"description": "Destination Y coordinate (0-999 normalized)",
|
|
416
|
+
},
|
|
417
|
+
},
|
|
418
|
+
"required": ["x", "y", "destination_x", "destination_y"],
|
|
419
|
+
},
|
|
420
|
+
),
|
|
421
|
+
types.FunctionDeclaration(
|
|
422
|
+
name="wait_5_seconds",
|
|
423
|
+
description="Wait for 5 seconds before the next action. Use this when waiting for page loads or animations.",
|
|
424
|
+
parameters={
|
|
425
|
+
"type": "object",
|
|
426
|
+
"properties": {},
|
|
427
|
+
},
|
|
428
|
+
),
|
|
429
|
+
# # Browser-specific functions -> commented out for future support of browser exposed functions
|
|
430
|
+
# types.FunctionDeclaration(
|
|
431
|
+
# name="navigate",
|
|
432
|
+
# description="Navigate the browser to a specific URL.",
|
|
433
|
+
# parameters={
|
|
434
|
+
# "type": "object",
|
|
435
|
+
# "properties": {
|
|
436
|
+
# "url": {"type": "string", "description": "URL to navigate to"},
|
|
437
|
+
# },
|
|
438
|
+
# "required": ["url"],
|
|
439
|
+
# },
|
|
440
|
+
# ),
|
|
441
|
+
# types.FunctionDeclaration(
|
|
442
|
+
# name="open_web_browser",
|
|
443
|
+
# description="Open a web browser.",
|
|
444
|
+
# parameters={
|
|
445
|
+
# "type": "object",
|
|
446
|
+
# "properties": {},
|
|
447
|
+
# },
|
|
448
|
+
# ),
|
|
449
|
+
# types.FunctionDeclaration(
|
|
450
|
+
# name="search",
|
|
451
|
+
# description="Perform a web search with the given query.",
|
|
452
|
+
# parameters={
|
|
453
|
+
# "type": "object",
|
|
454
|
+
# "properties": {
|
|
455
|
+
# "query": {"type": "string", "description": "Search query"},
|
|
456
|
+
# },
|
|
457
|
+
# "required": ["query"],
|
|
458
|
+
# },
|
|
459
|
+
# ),
|
|
460
|
+
# types.FunctionDeclaration(
|
|
461
|
+
# name="go_back",
|
|
462
|
+
# description="Go back to the previous page in the browser.",
|
|
463
|
+
# parameters={
|
|
464
|
+
# "type": "object",
|
|
465
|
+
# "properties": {},
|
|
466
|
+
# },
|
|
467
|
+
# ),
|
|
468
|
+
# types.FunctionDeclaration(
|
|
469
|
+
# name="go_forward",
|
|
470
|
+
# description="Go forward to the next page in the browser.",
|
|
471
|
+
# parameters={
|
|
472
|
+
# "type": "object",
|
|
473
|
+
# "properties": {},
|
|
474
|
+
# },
|
|
475
|
+
# ),
|
|
476
|
+
]
|
|
477
|
+
|
|
478
|
+
|
|
479
|
+
def _map_gemini_fc_to_computer_call(
    fc: Dict[str, Any],
    screen_w: int,
    screen_h: int,
) -> Optional[Dict[str, Any]]:
    """Translate a Gemini function_call dict into an internal computer_call item.

    Args:
        fc: Function call with "name" and "args"; coordinates in args use the
            Gemini 0-999 normalized range.
        screen_w: Actual screen width in pixels.
        screen_h: Actual screen height in pixels.

    Returns:
        A computer_call dict with a fresh call_id, or None when the function
        is unknown or a required argument is missing.
    """
    name = fc.get("name")
    args = fc.get("args", {}) or {}

    # Gemini 3 Flash prefixes browser functions with "web_agent_api:";
    # strip it so all models share one set of function names.
    if name and name.startswith("web_agent_api:"):
        name = name[len("web_agent_api:"):]

    action: Dict[str, Any] = {}
    if name == "click_at":
        x = _denormalize(int(args.get("x", 0)), screen_w)
        y = _denormalize(int(args.get("y", 0)), screen_h)
        action = {"type": "click", "x": x, "y": y, "button": "left"}
    elif name == "type_text_at":
        x = _denormalize(int(args.get("x", 0)), screen_w)
        y = _denormalize(int(args.get("y", 0)), screen_h)
        text = args.get("text", "")
        if args.get("press_enter"):
            # Emulate pressing Enter by appending a newline to the typed text.
            text += "\n"
        action = {"type": "type", "x": x, "y": y, "text": text}
    elif name == "hover_at":
        x = _denormalize(int(args.get("x", 0)), screen_w)
        y = _denormalize(int(args.get("y", 0)), screen_h)
        action = {"type": "move", "x": x, "y": y}
    elif name == "key_combination":
        action = {"type": "keypress", "keys": str(args.get("keys", ""))}
    elif name == "scroll_document":
        # Whole-page scroll: fixed magnitude, anchored at the screen center.
        direction = args.get("direction", "down")
        magnitude = 800
        deltas = {
            "down": (0, magnitude),
            "up": (0, -magnitude),
            "right": (magnitude, 0),
            "left": (-magnitude, 0),
        }
        scroll_x, scroll_y = deltas.get(direction, (0, 0))
        action = {
            "type": "scroll",
            "scroll_x": scroll_x,
            "scroll_y": scroll_y,
            "x": int(screen_w / 2),
            "y": int(screen_h / 2),
        }
    elif name == "scroll_at":
        x = _denormalize(int(args.get("x", 500)), screen_w)
        y = _denormalize(int(args.get("y", 500)), screen_h)
        direction = args.get("direction", "down")
        magnitude = int(args.get("magnitude", 800))
        deltas = {
            "down": (0, magnitude),
            "up": (0, -magnitude),
            "right": (magnitude, 0),
            "left": (-magnitude, 0),
        }
        scroll_x, scroll_y = deltas.get(direction, (0, 0))
        action = {"type": "scroll", "scroll_x": scroll_x, "scroll_y": scroll_y, "x": x, "y": y}
    elif name == "drag_and_drop":
        raw_x = int(args.get("x", 0))
        raw_y = int(args.get("y", 0))
        start_x = _denormalize(raw_x, screen_w)
        start_y = _denormalize(raw_y, screen_h)
        # BUGFIX: when destination_x/destination_y are missing, default to the
        # *normalized* start coordinate. Previously the default was the
        # already-denormalized pixel value, which was then denormalized a
        # second time, landing far from the intended point.
        end_x = _denormalize(int(args.get("destination_x", raw_x)), screen_w)
        end_y = _denormalize(int(args.get("destination_y", raw_y)), screen_h)
        action = {
            "type": "drag",
            "start_x": start_x,
            "start_y": start_y,
            "end_x": end_x,
            "end_y": end_y,
            "button": "left",
        }
    elif name == "wait_5_seconds":
        action = {"type": "wait"}
    # Browser-specific functions - use playwright_exec for browser control
    # (Note: Gemini API does not respect exclusions, so we implement these)
    elif name == "navigate":
        url = args.get("url", "")
        if not url:
            return None
        action = {"type": "playwright_exec", "command": "visit_url", "params": {"url": url}}
    elif name == "open_web_browser":
        # Open browser with blank page or google
        action = {
            "type": "playwright_exec",
            "command": "visit_url",
            "params": {"url": "https://www.google.com"},
        }
    elif name == "search":
        query = args.get("query", "")
        if not query:
            return None
        action = {
            "type": "playwright_exec",
            "command": "web_search",
            "params": {"query": query},
        }
    elif name == "go_back":
        # Browser back via Playwright's native navigation
        action = {"type": "playwright_exec", "command": "go_back", "params": {}}
    elif name == "go_forward":
        # Browser forward via Playwright's native navigation
        action = {"type": "playwright_exec", "command": "go_forward", "params": {}}
    else:
        # Unsupported / unknown function
        print(f"[WARN] Unsupported Gemini function: {name}")
        return None

    return {
        "type": "computer_call",
        "call_id": uuid.uuid4().hex,
        "status": "completed",
        "action": action,
    }
|
|
602
|
+
|
|
603
|
+
|
|
604
|
+
# Supported models:
|
|
605
|
+
# - gemini-2.5-computer-use-preview-* : Uses built-in ComputerUse tool
|
|
606
|
+
# - gemini-3-flash-preview-* : Uses custom function declarations
|
|
607
|
+
# - gemini-3-pro-preview-* : Uses custom function declarations
|
|
608
|
+
@register_agent(
    models=r"^(gemini-2\.5-computer-use-preview.*|gemini-3-flash-preview.*|gemini-3-pro-preview.*)$"
)
class GeminiComputerUseConfig(AsyncAgentConfig):
    """Agent configuration for Google Gemini computer-use models.

    Two tool-wiring strategies are dispatched on the model name via
    ``_is_gemini_3_model``:

    - ``gemini-2.5-computer-use-preview-*`` — uses the SDK's built-in
      ``ComputerUse`` tool with a set of excluded predefined functions.
    - ``gemini-3-flash-preview-*`` / ``gemini-3-pro-preview-*`` — uses
      custom function declarations built by
      ``_build_custom_function_declarations``.

    Authentication (both entry points): ``api_key`` kwarg or
    ``GOOGLE_API_KEY`` selects Google AI Studio; otherwise the client is
    constructed with no arguments and falls back to Vertex AI via
    Application Default Credentials and the ``GOOGLE_CLOUD_*`` env vars.
    """

    async def predict_step(
        self,
        messages: List[Dict[str, Any]],
        model: str,
        tools: Optional[List[Dict[str, Any]]] = None,
        max_retries: Optional[int] = None,
        stream: bool = False,
        computer_handler=None,
        use_prompt_caching: Optional[bool] = False,
        _on_api_start=None,
        _on_api_end=None,
        _on_usage=None,
        _on_screenshot=None,
        **kwargs,
    ) -> Dict[str, Any]:
        """Run one agent step: call Gemini and map its output to internal items.

        Converts ``messages`` to Gemini ``Contents``, issues a single
        ``generate_content`` call, and returns
        ``{"output": [...], "usage": {...}}`` where output items are either
        an assistant ``message`` or ``computer_call`` dicts produced by
        ``_map_gemini_fc_to_computer_call``.

        Recognized extra kwargs (popped, not forwarded):
        - ``thinking_level``: lowercase string ("minimal"/"low"/"medium"/"high")
          or an SDK ``ThinkingLevel`` enum value.
        - ``media_resolution``: lowercase string ("low"/"medium"/"high") or an
          SDK ``MediaResolution`` enum value.
        - ``computer_environment`` (Gemini 2.5 path only): "browser" or
          "unspecified"; unknown values fall back to ENVIRONMENT_BROWSER.

        NOTE(review): ``tools``, ``max_retries``, ``stream``,
        ``computer_handler``, ``use_prompt_caching`` and ``_on_screenshot``
        are accepted but unused in this implementation — presumably part of
        the shared ``AsyncAgentConfig`` interface; confirm against the base
        class.
        """
        genai, types = _lazy_import_genai()
        import os

        # Authentication follows two modes based on environment variables:
        # 1. Google AI Studio: Set GOOGLE_API_KEY
        # 2. Vertex AI: Set GOOGLE_CLOUD_PROJECT, GOOGLE_CLOUD_LOCATION, GOOGLE_GENAI_USE_VERTEXAI=True
        api_key = kwargs.get("api_key", os.getenv("GOOGLE_API_KEY"))

        if api_key:
            client = genai.Client(api_key=api_key)
        else:
            # Vertex AI mode - requires GOOGLE_CLOUD_PROJECT, GOOGLE_CLOUD_LOCATION env vars
            # and Application Default Credentials (ADC)
            client = genai.Client()

        # Extract Gemini 3-specific parameters
        # thinking_level: Use types.ThinkingLevel enum values (e.g., "LOW", "HIGH", "MEDIUM", "MINIMAL")
        # media_resolution: Use types.MediaResolution enum values (e.g., "MEDIA_RESOLUTION_LOW", "MEDIA_RESOLUTION_HIGH")
        thinking_level = kwargs.pop("thinking_level", None)
        media_resolution = kwargs.pop("media_resolution", None)

        # Build thinking_config for Gemini 3 models if specified
        thinking_config = None
        if thinking_level:
            # Accept string values and map to SDK enum
            level_map = {
                "minimal": types.ThinkingLevel.MINIMAL,
                "low": types.ThinkingLevel.LOW,
                "medium": types.ThinkingLevel.MEDIUM,
                "high": types.ThinkingLevel.HIGH,
            }
            # Handle both lowercase strings and SDK enum values
            if isinstance(thinking_level, str) and thinking_level.lower() in level_map:
                thinking_config = types.ThinkingConfig(
                    thinking_level=level_map[thinking_level.lower()]
                )
            else:
                # Assume it's already an SDK enum value
                thinking_config = types.ThinkingConfig(thinking_level=thinking_level)

        # Build media_resolution for Gemini 3 models if specified
        resolved_media_resolution = None
        if media_resolution:
            resolution_map = {
                "low": types.MediaResolution.MEDIA_RESOLUTION_LOW,
                "medium": types.MediaResolution.MEDIA_RESOLUTION_MEDIUM,
                "high": types.MediaResolution.MEDIA_RESOLUTION_HIGH,
            }
            if isinstance(media_resolution, str) and media_resolution.lower() in resolution_map:
                resolved_media_resolution = resolution_map[media_resolution.lower()]
            else:
                # Assume it's already an SDK enum value
                resolved_media_resolution = media_resolution

        # Compose tools config based on model type
        # Gemini 2.5 Computer Use Preview uses built-in ComputerUse tool
        # Gemini 3 Flash/Pro Preview uses custom function declarations
        is_gemini_3 = _is_gemini_3_model(model)

        if is_gemini_3:
            # Use custom function declarations for Gemini 3 models
            custom_functions = _build_custom_function_declarations(types)
            print(f"[DEBUG] Using custom function declarations for Gemini 3 model: {model}")
            print(f"[DEBUG] Number of custom functions: {len(custom_functions)}")

            generate_content_config = types.GenerateContentConfig(
                tools=[
                    types.Tool(function_declarations=custom_functions),
                ],
                thinking_config=thinking_config,
                media_resolution=resolved_media_resolution,
            )
        else:
            # Predefined ComputerUse functions suppressed on the 2.5 path;
            # browser navigation is handled elsewhere in this loop.
            excluded = [
                "open_web_browser",
                "search",
                "navigate",
                "go_forward",
                "go_back",
                "scroll_document",
            ]

            # Note: ENVIRONMENT_BROWSER biases model towards browser actions
            # Use ENVIRONMENT_UNSPECIFIED for general desktop tasks
            computer_environment = kwargs.pop("computer_environment", "browser")
            env_map = {
                "browser": types.Environment.ENVIRONMENT_BROWSER,
                "unspecified": types.Environment.ENVIRONMENT_UNSPECIFIED,
            }
            # NOTE(review): .lower() assumes computer_environment is a str;
            # passing an SDK Environment enum here would raise — confirm
            # callers only supply strings.
            resolved_environment = env_map.get(
                computer_environment.lower(), types.Environment.ENVIRONMENT_BROWSER
            )

            print(f"[DEBUG] Using built-in ComputerUse tool for Gemini 2.5 model: {model}")
            print(f"[DEBUG] Environment: {resolved_environment}")
            print(f"[DEBUG] Excluded functions: {excluded}")

            generate_content_config = types.GenerateContentConfig(
                tools=[
                    types.Tool(
                        computer_use=types.ComputerUse(
                            environment=resolved_environment,
                            excluded_predefined_functions=excluded,
                        )
                    ),
                ],
                thinking_config=thinking_config,
                media_resolution=resolved_media_resolution,
            )

        # Convert full message history to Gemini Contents format
        # (also yields the screen dimensions used to denormalize coordinates)
        contents, (screen_w, screen_h) = _convert_messages_to_gemini_contents(messages, types)

        api_kwargs = {
            "model": model,
            "contents": contents,
            "config": generate_content_config,
        }

        if _on_api_start:
            await _on_api_start(
                {
                    "model": api_kwargs["model"],
                    # "contents": api_kwargs["contents"], # Disabled for now
                    "config": api_kwargs["config"],
                }
            )

        # NOTE(review): synchronous SDK call inside an async method — this
        # blocks the event loop for the duration of the request; confirm
        # whether asyncio.to_thread (or the SDK's async client) is wanted.
        response = client.models.generate_content(**api_kwargs)

        # Debug: print raw function calls from response
        try:
            for p in response.candidates[0].content.parts:
                if hasattr(p, "function_call") and p.function_call:
                    print(
                        f"[DEBUG] Raw function_call from model: name={p.function_call.name}, args={dict(p.function_call.args or {})}"
                    )
        except Exception as e:
            print(f"[DEBUG] Error printing function calls: {e}")

        if _on_api_end:
            # Sanitize response to handle bytes fields (e.g., thought_signature in Gemini 3)
            await _on_api_end(
                {
                    "model": api_kwargs["model"],
                    # "contents": api_kwargs["contents"], # Disabled for now
                    "config": api_kwargs["config"],
                },
                _sanitize_for_json(response),
            )

        # Usage (Gemini SDK may not always provide token usage; populate when available)
        usage: Dict[str, Any] = {}
        try:
            # Some SDKs expose response.usage; if available, copy
            if getattr(response, "usage_metadata", None):
                md = response.usage_metadata
                usage = {
                    "prompt_tokens": getattr(md, "prompt_token_count", None) or 0,
                    "completion_tokens": getattr(md, "candidates_token_count", None) or 0,
                    "total_tokens": getattr(md, "total_token_count", None) or 0,
                }
        except Exception:
            # Best-effort: missing usage metadata is not an error
            pass

        if _on_usage and usage:
            await _on_usage(usage)

        # Parse output into internal items
        output_items: List[Dict[str, Any]] = []

        # NOTE(review): assumes at least one candidate is returned — an empty
        # candidates list (e.g. fully-blocked response) would raise IndexError
        # here; confirm upstream retry/guard behavior.
        candidate = response.candidates[0]
        # Text parts from the model (assistant message)
        text_parts: List[str] = []
        function_calls: List[Dict[str, Any]] = []
        for p in candidate.content.parts:
            if getattr(p, "text", None):
                text_parts.append(p.text)
            if getattr(p, "function_call", None):
                # p.function_call has name and args
                fc = {
                    "name": getattr(p.function_call, "name", None),
                    "args": dict(getattr(p.function_call, "args", {}) or {}),
                }
                function_calls.append(fc)

        if text_parts:
            output_items.append(
                {
                    "type": "message",
                    "role": "assistant",
                    "content": [{"type": "output_text", "text": "\n".join(text_parts)}],
                }
            )

        # Map function calls to internal computer_call actions
        for fc in function_calls:
            print(f"[DEBUG] Model returned function_call: {fc}")
            item = _map_gemini_fc_to_computer_call(fc, screen_w, screen_h)
            if item is not None:
                output_items.append(item)
            else:
                print(f"[DEBUG] Function '{fc.get('name')}' not mapped (excluded or unsupported)")

        return {"output": output_items, "usage": usage}

    async def predict_click(
        self,
        model: str,
        image_b64: str,
        instruction: str,
        **kwargs,
    ) -> Optional[Tuple[float, float]]:
        """Ask Gemini Cua to output a single click action for the given instruction.

        For Gemini 2.5: Excludes all predefined tools except `click_at` and sends the screenshot.
        For Gemini 3: Uses only the click_at function declaration.
        Returns pixel (x, y) if a click is proposed, else None.
        """
        genai, types = _lazy_import_genai()
        import os

        # Authentication: GOOGLE_API_KEY for AI Studio, or Vertex AI env vars
        api_key = kwargs.get("api_key", os.getenv("GOOGLE_API_KEY"))
        if api_key:
            client = genai.Client(api_key=api_key)
        else:
            client = genai.Client()

        # Build tools config based on model type
        is_gemini_3 = _is_gemini_3_model(model)

        if is_gemini_3:
            # For Gemini 3 models, use only click_at function declaration
            click_function = types.FunctionDeclaration(
                name="click_at",
                description="Click at the specified x,y coordinates on the screen. Coordinates are normalized 0-999.",
                parameters={
                    "type": "object",
                    "properties": {
                        "x": {"type": "integer", "description": "X coordinate (0-999 normalized)"},
                        "y": {"type": "integer", "description": "Y coordinate (0-999 normalized)"},
                    },
                    "required": ["x", "y"],
                },
            )
            config = types.GenerateContentConfig(
                tools=[
                    types.Tool(function_declarations=[click_function]),
                ]
            )
        else:
            # Gemini 2.5 path: exclude every predefined function except click_at
            # (the ComputerUse tool has no allow-list, only an exclusion list).
            exclude_all_but_click = [
                "open_web_browser",
                "search",
                "navigate",
                "go_forward",
                "go_back",
                "scroll_document",
            ]

            config = types.GenerateContentConfig(
                tools=[
                    types.Tool(
                        computer_use=types.ComputerUse(
                            environment=types.Environment.ENVIRONMENT_BROWSER,
                            excluded_predefined_functions=exclude_all_but_click,
                        )
                    )
                ]
            )

        # Prepare prompt parts
        try:
            img_bytes = base64.b64decode(image_b64)
        except Exception:
            # Malformed base64 -> proceed text-only with a default screen size
            img_bytes = b""

        # Pixel dimensions used to denormalize the model's 0-999 coordinates
        w, h = _bytes_image_size(img_bytes) if img_bytes else (1024, 768)

        parts: List[Any] = [types.Part(text=f"Click {instruction}.")]
        if img_bytes:
            # NOTE(review): mime type is hard-coded to PNG — confirm callers
            # only supply PNG screenshots.
            parts.append(types.Part.from_bytes(data=img_bytes, mime_type="image/png"))

        contents = [types.Content(role="user", parts=parts)]

        # NOTE(review): synchronous SDK call inside an async method — blocks
        # the event loop for the duration of the request.
        response = client.models.generate_content(
            model=model,
            contents=contents,
            config=config,
        )

        # Parse first click_at
        try:
            candidate = response.candidates[0]
            for p in candidate.content.parts:
                fc = getattr(p, "function_call", None)
                if fc and getattr(fc, "name", None) == "click_at":
                    args = dict(getattr(fc, "args", {}) or {})
                    # Model emits 0-999 normalized coords; scale to pixels
                    x = _denormalize(int(args.get("x", 0)), w)
                    y = _denormalize(int(args.get("y", 0)), h)
                    return float(x), float(y)
        except Exception:
            # Any malformed response (no candidates, bad args) -> no click
            return None

        return None

    def get_capabilities(self) -> List[AgentCapability]:
        """Report the agent capabilities this config implements."""
        return ["click", "step"]
|