cua-agent 0.4.34__py3-none-any.whl → 0.4.36__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of cua-agent might be problematic.
- agent/__init__.py +4 -10
- agent/__main__.py +2 -1
- agent/adapters/huggingfacelocal_adapter.py +54 -61
- agent/adapters/human_adapter.py +116 -114
- agent/adapters/mlxvlm_adapter.py +110 -99
- agent/adapters/models/__init__.py +14 -6
- agent/adapters/models/generic.py +7 -4
- agent/adapters/models/internvl.py +66 -30
- agent/adapters/models/opencua.py +23 -8
- agent/adapters/models/qwen2_5_vl.py +7 -4
- agent/agent.py +184 -158
- agent/callbacks/__init__.py +4 -4
- agent/callbacks/base.py +45 -31
- agent/callbacks/budget_manager.py +22 -10
- agent/callbacks/image_retention.py +18 -13
- agent/callbacks/logging.py +55 -42
- agent/callbacks/operator_validator.py +3 -1
- agent/callbacks/pii_anonymization.py +19 -16
- agent/callbacks/telemetry.py +67 -61
- agent/callbacks/trajectory_saver.py +90 -70
- agent/cli.py +115 -110
- agent/computers/__init__.py +13 -8
- agent/computers/base.py +32 -19
- agent/computers/cua.py +33 -25
- agent/computers/custom.py +78 -71
- agent/decorators.py +23 -14
- agent/human_tool/__init__.py +2 -7
- agent/human_tool/__main__.py +6 -2
- agent/human_tool/server.py +48 -37
- agent/human_tool/ui.py +235 -185
- agent/integrations/hud/__init__.py +15 -21
- agent/integrations/hud/agent.py +101 -83
- agent/integrations/hud/proxy.py +90 -57
- agent/loops/__init__.py +25 -21
- agent/loops/anthropic.py +537 -483
- agent/loops/base.py +13 -14
- agent/loops/composed_grounded.py +135 -149
- agent/loops/gemini.py +31 -12
- agent/loops/glm45v.py +135 -133
- agent/loops/gta1.py +47 -50
- agent/loops/holo.py +4 -2
- agent/loops/internvl.py +6 -11
- agent/loops/moondream3.py +36 -12
- agent/loops/omniparser.py +215 -210
- agent/loops/openai.py +49 -50
- agent/loops/opencua.py +29 -41
- agent/loops/qwen.py +510 -0
- agent/loops/uitars.py +237 -202
- agent/proxy/examples.py +54 -50
- agent/proxy/handlers.py +27 -34
- agent/responses.py +330 -330
- agent/types.py +11 -5
- agent/ui/__init__.py +1 -1
- agent/ui/__main__.py +1 -1
- agent/ui/gradio/app.py +23 -18
- agent/ui/gradio/ui_components.py +310 -161
- {cua_agent-0.4.34.dist-info → cua_agent-0.4.36.dist-info}/METADATA +18 -10
- cua_agent-0.4.36.dist-info/RECORD +64 -0
- cua_agent-0.4.34.dist-info/RECORD +0 -63
- {cua_agent-0.4.34.dist-info → cua_agent-0.4.36.dist-info}/WHEEL +0 -0
- {cua_agent-0.4.34.dist-info → cua_agent-0.4.36.dist-info}/entry_points.txt +0 -0
agent/agent.py
CHANGED
@@ -3,76 +3,83 @@ ComputerAgent - Main agent class that selects and runs agent loops
"""

import asyncio
+import inspect
+import json
from pathlib import Path
-from typing import
-
-
-
-
-
-
-
-
+from typing import (
+    Any,
+    AsyncGenerator,
+    Callable,
+    Dict,
+    List,
+    Optional,
+    Set,
+    Tuple,
+    Union,
+    cast,
)
-
-from .decorators import find_agent_config
-import json
+
import litellm
import litellm.utils
-import
+from litellm.responses.utils import Usage
+
from .adapters import (
    HuggingFaceLocalAdapter,
    HumanAdapter,
    MLXVLMAdapter,
)
from .callbacks import (
-    ImageRetentionCallback,
-    LoggingCallback,
-    TrajectorySaverCallback,
    BudgetManagerCallback,
-
+    ImageRetentionCallback,
+    LoggingCallback,
    OperatorNormalizerCallback,
    PromptInstructionsCallback,
+    TelemetryCallback,
+    TrajectorySaverCallback,
)
-from .computers import
-
-
-
+from .computers import AsyncComputerHandler, is_agent_computer, make_computer_handler
+from .decorators import find_agent_config
+from .responses import (
+    make_tool_error_item,
+    replace_failed_computer_calls_with_function_calls,
)
+from .types import AgentCapability, IllegalArgumentError, Messages, ToolError
+

def assert_callable_with(f, *args, **kwargs):
-
-
-
-
-
-
-
+    """Check if function can be called with given arguments."""
+    try:
+        inspect.signature(f).bind(*args, **kwargs)
+        return True
+    except TypeError as e:
+        sig = inspect.signature(f)
+        raise IllegalArgumentError(f"Expected {sig}, got args={args} kwargs={kwargs}") from e
+

def get_json(obj: Any, max_depth: int = 10) -> Any:
    def custom_serializer(o: Any, depth: int = 0, seen: Optional[Set[int]] = None) -> Any:
        if seen is None:
            seen = set()
-
+
        # Use model_dump() if available
-        if hasattr(o,
+        if hasattr(o, "model_dump"):
            return o.model_dump()
-
+
        # Check depth limit
        if depth > max_depth:
            return f"<max_depth_exceeded:{max_depth}>"
-
+
        # Check for circular references using object id
        obj_id = id(o)
        if obj_id in seen:
            return f"<circular_reference:{type(o).__name__}>"
-
+
        # Handle Computer objects
-        if hasattr(o,
+        if hasattr(o, "__class__") and "computer" in o.__class__.__name__.lower():
            return f"<computer:{o.__class__.__name__}>"

        # Handle objects with __dict__
-        if hasattr(o,
+        if hasattr(o, "__dict__"):
            seen.add(obj_id)
            try:
                result = {}
@@ -84,7 +91,7 @@ def get_json(obj: Any, max_depth: int = 10) -> Any:
                return result
            finally:
                seen.discard(obj_id)
-
+
        # Handle common types that might contain nested objects
        elif isinstance(o, dict):
            seen.add(obj_id)
@@ -96,7 +103,7 @@ def get_json(obj: Any, max_depth: int = 10) -> Any:
                }
            finally:
                seen.discard(obj_id)
-
+
        elif isinstance(o, (list, tuple, set)):
            seen.add(obj_id)
            try:
@@ -107,32 +114,33 @@ def get_json(obj: Any, max_depth: int = 10) -> Any:
                ]
            finally:
                seen.discard(obj_id)
-
+
        # For basic types that json.dumps can handle
        elif isinstance(o, (str, int, float, bool)) or o is None:
            return o
-
+
        # Fallback to string representation
        else:
            return str(o)
-
+
    def remove_nones(obj: Any) -> Any:
        if isinstance(obj, dict):
            return {k: remove_nones(v) for k, v in obj.items() if v is not None}
        elif isinstance(obj, list):
            return [remove_nones(item) for item in obj if item is not None]
        return obj
-
+
    # Serialize with circular reference and depth protection
    serialized = custom_serializer(obj)
-
+
    # Convert to JSON string and back to ensure JSON compatibility
    json_str = json.dumps(serialized)
    parsed = json.loads(json_str)
-
+
    # Final cleanup of any remaining None values
    return remove_nones(parsed)

+
def sanitize_message(msg: Any) -> Any:
    """Return a copy of the message with image_url omitted for computer_call_output messages."""
    if msg.get("type") == "computer_call_output":
@@ -143,19 +151,24 @@ def sanitize_message(msg: Any) -> Any:
        return sanitized
    return msg

+
def get_output_call_ids(messages: List[Dict[str, Any]]) -> List[str]:
    call_ids = []
    for message in messages:
-        if
+        if (
+            message.get("type") == "computer_call_output"
+            or message.get("type") == "function_call_output"
+        ):
            call_ids.append(message.get("call_id"))
    return call_ids

+
class ComputerAgent:
    """
    Main agent class that automatically selects the appropriate agent loop
    based on the model and executes tool calls.
    """
-
+
    def __init__(
        self,
        model: str,
@@ -172,11 +185,11 @@ class ComputerAgent:
        max_trajectory_budget: Optional[float | dict] = None,
        telemetry_enabled: Optional[bool] = True,
        trust_remote_code: Optional[bool] = False,
-        **kwargs
+        **kwargs,
    ):
        """
        Initialize ComputerAgent.
-
+
        Args:
            model: Model name (e.g., "claude-3-5-sonnet-20241022", "computer-use-preview", "omni+vertex_ai/gemini-pro")
            tools: List of tools (computer objects, decorated functions, etc.)
@@ -193,11 +206,11 @@ class ComputerAgent:
            telemetry_enabled: If set, adds TelemetryCallback to track anonymized usage data. Enabled by default.
            trust_remote_code: If set, trust remote code when loading local models. Disabled by default.
            **kwargs: Additional arguments passed to the agent loop
-        """
+        """
        # If the loop is "human/human", we need to prefix a grounding model fallback
        if model in ["human/human", "human"]:
            model = "openai/computer-use-preview+human/human"
-
+
        self.model = model
        self.tools = tools or []
        self.custom_loop = custom_loop
@@ -236,34 +249,33 @@ class ComputerAgent:
        # Add image retention callback if only_n_most_recent_images is set
        if self.only_n_most_recent_images:
            self.callbacks.append(ImageRetentionCallback(self.only_n_most_recent_images))
-
+
        # Add trajectory saver callback if trajectory_dir is set
        if self.trajectory_dir:
            if isinstance(self.trajectory_dir, dict):
                self.callbacks.append(TrajectorySaverCallback(**self.trajectory_dir))
            elif isinstance(self.trajectory_dir, (str, Path)):
                self.callbacks.append(TrajectorySaverCallback(str(self.trajectory_dir)))
-
+
        # Add budget manager if max_trajectory_budget is set
        if max_trajectory_budget:
            if isinstance(max_trajectory_budget, dict):
                self.callbacks.append(BudgetManagerCallback(**max_trajectory_budget))
            else:
                self.callbacks.append(BudgetManagerCallback(max_trajectory_budget))
-
+
        # == Enable local model providers w/ LiteLLM ==

        # Register local model providers
        hf_adapter = HuggingFaceLocalAdapter(
-            device="auto",
-            trust_remote_code=self.trust_remote_code or False
+            device="auto", trust_remote_code=self.trust_remote_code or False
        )
        human_adapter = HumanAdapter()
        mlx_adapter = MLXVLMAdapter()
        litellm.custom_provider_map = [
            {"provider": "huggingface-local", "custom_handler": hf_adapter},
            {"provider": "human", "custom_handler": human_adapter},
-            {"provider": "mlx", "custom_handler": mlx_adapter}
+            {"provider": "mlx", "custom_handler": mlx_adapter},
        ]
        litellm.suppress_debug_info = True

@@ -280,16 +292,16 @@ class ComputerAgent:
        # Instantiate the agent config class
        self.agent_loop = config_info.agent_class()
        self.agent_config_info = config_info
-
+
        self.tool_schemas = []
        self.computer_handler = None
-
+
    async def _initialize_computers(self):
        """Initialize computer objects"""
        if not self.tool_schemas:
            # Process tools and create tool schemas
            self.tool_schemas = self._process_tools()
-
+
            # Find computer tool and create interface adapter
            computer_handler = None
            for schema in self.tool_schemas:
@@ -297,7 +309,7 @@ class ComputerAgent:
                    computer_handler = await make_computer_handler(schema["computer"])
                    break
            self.computer_handler = computer_handler
-
+
    def _process_input(self, input: Messages) -> List[Dict[str, Any]]:
        """Process input messages and create schemas for the agent loop"""
        if isinstance(input, str):
@@ -307,69 +319,73 @@ class ComputerAgent:
    def _process_tools(self) -> List[Dict[str, Any]]:
        """Process tools and create schemas for the agent loop"""
        schemas = []
-
+
        for tool in self.tools:
            # Check if it's a computer object (has interface attribute)
            if is_agent_computer(tool):
                # This is a computer tool - will be handled by agent loop
-                schemas.append({
-                    "type": "computer",
-                    "computer": tool
-                })
+                schemas.append({"type": "computer", "computer": tool})
            elif callable(tool):
                # Use litellm.utils.function_to_dict to extract schema from docstring
                try:
                    function_schema = litellm.utils.function_to_dict(tool)
-                    schemas.append({
-                        "type": "function",
-                        "function": function_schema
-                    })
+                    schemas.append({"type": "function", "function": function_schema})
                except Exception as e:
                    print(f"Warning: Could not process tool {tool}: {e}")
            else:
                print(f"Warning: Unknown tool type: {tool}")
-
+
        return schemas
-
+
    def _get_tool(self, name: str) -> Optional[Callable]:
        """Get a tool by name"""
        for tool in self.tools:
-            if hasattr(tool,
+            if hasattr(tool, "__name__") and tool.__name__ == name:
                return tool
-            elif hasattr(tool,
+            elif hasattr(tool, "func") and tool.func.__name__ == name:
                return tool
        return None
-
+
    # ============================================================================
    # AGENT RUN LOOP LIFECYCLE HOOKS
    # ============================================================================
-
+
    async def _on_run_start(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]]) -> None:
        """Initialize run tracking by calling callbacks."""
        for callback in self.callbacks:
-            if hasattr(callback,
+            if hasattr(callback, "on_run_start"):
                await callback.on_run_start(kwargs, old_items)
-
-    async def _on_run_end(
+
+    async def _on_run_end(
+        self,
+        kwargs: Dict[str, Any],
+        old_items: List[Dict[str, Any]],
+        new_items: List[Dict[str, Any]],
+    ) -> None:
        """Finalize run tracking by calling callbacks."""
        for callback in self.callbacks:
-            if hasattr(callback,
+            if hasattr(callback, "on_run_end"):
                await callback.on_run_end(kwargs, old_items, new_items)
-
-    async def _on_run_continue(
+
+    async def _on_run_continue(
+        self,
+        kwargs: Dict[str, Any],
+        old_items: List[Dict[str, Any]],
+        new_items: List[Dict[str, Any]],
+    ) -> bool:
        """Check if run should continue by calling callbacks."""
        for callback in self.callbacks:
-            if hasattr(callback,
+            if hasattr(callback, "on_run_continue"):
                should_continue = await callback.on_run_continue(kwargs, old_items, new_items)
                if not should_continue:
                    return False
        return True
-
+
    async def _on_llm_start(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Prepare messages for the LLM call by applying callbacks."""
        result = messages
        for callback in self.callbacks:
-            if hasattr(callback,
+            if hasattr(callback, "on_llm_start"):
                result = await callback.on_llm_start(result)
        return result

@@ -377,82 +393,91 @@ class ComputerAgent:
        """Postprocess messages after the LLM call by applying callbacks."""
        result = messages
        for callback in self.callbacks:
-            if hasattr(callback,
+            if hasattr(callback, "on_llm_end"):
                result = await callback.on_llm_end(result)
        return result

    async def _on_responses(self, kwargs: Dict[str, Any], responses: Dict[str, Any]) -> None:
        """Called when responses are received."""
        for callback in self.callbacks:
-            if hasattr(callback,
+            if hasattr(callback, "on_responses"):
                await callback.on_responses(get_json(kwargs), get_json(responses))
-
+
    async def _on_computer_call_start(self, item: Dict[str, Any]) -> None:
        """Called when a computer call is about to start."""
        for callback in self.callbacks:
-            if hasattr(callback,
+            if hasattr(callback, "on_computer_call_start"):
                await callback.on_computer_call_start(get_json(item))
-
-    async def _on_computer_call_end(
+
+    async def _on_computer_call_end(
+        self, item: Dict[str, Any], result: List[Dict[str, Any]]
+    ) -> None:
        """Called when a computer call has completed."""
        for callback in self.callbacks:
-            if hasattr(callback,
+            if hasattr(callback, "on_computer_call_end"):
                await callback.on_computer_call_end(get_json(item), get_json(result))
-
+
    async def _on_function_call_start(self, item: Dict[str, Any]) -> None:
        """Called when a function call is about to start."""
        for callback in self.callbacks:
-            if hasattr(callback,
+            if hasattr(callback, "on_function_call_start"):
                await callback.on_function_call_start(get_json(item))
-
-    async def _on_function_call_end(
+
+    async def _on_function_call_end(
+        self, item: Dict[str, Any], result: List[Dict[str, Any]]
+    ) -> None:
        """Called when a function call has completed."""
        for callback in self.callbacks:
-            if hasattr(callback,
+            if hasattr(callback, "on_function_call_end"):
                await callback.on_function_call_end(get_json(item), get_json(result))
-
+
    async def _on_text(self, item: Dict[str, Any]) -> None:
        """Called when a text message is encountered."""
        for callback in self.callbacks:
-            if hasattr(callback,
+            if hasattr(callback, "on_text"):
                await callback.on_text(get_json(item))
-
+
    async def _on_api_start(self, kwargs: Dict[str, Any]) -> None:
        """Called when an LLM API call is about to start."""
        for callback in self.callbacks:
-            if hasattr(callback,
+            if hasattr(callback, "on_api_start"):
                await callback.on_api_start(get_json(kwargs))
-
+
    async def _on_api_end(self, kwargs: Dict[str, Any], result: Any) -> None:
        """Called when an LLM API call has completed."""
        for callback in self.callbacks:
-            if hasattr(callback,
+            if hasattr(callback, "on_api_end"):
                await callback.on_api_end(get_json(kwargs), get_json(result))

    async def _on_usage(self, usage: Dict[str, Any]) -> None:
        """Called when usage information is received."""
        for callback in self.callbacks:
-            if hasattr(callback,
+            if hasattr(callback, "on_usage"):
                await callback.on_usage(get_json(usage))

    async def _on_screenshot(self, screenshot: Union[str, bytes], name: str = "screenshot") -> None:
        """Called when a screenshot is taken."""
        for callback in self.callbacks:
-            if hasattr(callback,
+            if hasattr(callback, "on_screenshot"):
                await callback.on_screenshot(screenshot, name)

    # ============================================================================
    # AGENT OUTPUT PROCESSING
    # ============================================================================
-
-    async def _handle_item(
+
+    async def _handle_item(
+        self,
+        item: Any,
+        computer: Optional[AsyncComputerHandler] = None,
+        ignore_call_ids: Optional[List[str]] = None,
+    ) -> List[Dict[str, Any]]:
        """Handle each item; may cause a computer action + screenshot."""
        call_id = item.get("call_id")
        if ignore_call_ids and call_id and call_id in ignore_call_ids:
            return []
-
+
        item_type = item.get("type", None)
-
+
        if item_type == "message":
            await self._on_text(item)
            # # Print messages
@@ -461,7 +486,7 @@ class ComputerAgent:
            # if content_item.get("text"):
            # print(content_item.get("text"))
            return []
-
+
        try:
            if item_type == "computer_call":
                await self._on_computer_call_start(item)
@@ -472,14 +497,16 @@ class ComputerAgent:
                action = item.get("action")
                action_type = action.get("type")
                if action_type is None:
-                    print(
+                    print(
+                        f"Action type cannot be `None`: action={action}, action_type={action_type}"
+                    )
                    return []
-
+
                # Extract action arguments (all fields except 'type')
                action_args = {k: v for k, v in action.items() if k != "type"}
-
+
                # print(f"{action_type}({action_args})")
-
+
                # Execute the computer action
                computer_method = getattr(computer, action_type, None)
                if computer_method:
@@ -487,13 +514,13 @@ class ComputerAgent:
                    await computer_method(**action_args)
                else:
                    raise ToolError(f"Unknown computer action: {action_type}")
-
+
                # Take screenshot after action
                if self.screenshot_delay and self.screenshot_delay > 0:
                    await asyncio.sleep(self.screenshot_delay)
                screenshot_base64 = await computer.screenshot()
                await self._on_screenshot(screenshot_base64, "screenshot_after")
-
+
                # Handle safety checks
                pending_checks = item.get("pending_safety_checks", [])
                acknowledged_checks = []
@@ -505,7 +532,7 @@ class ComputerAgent:
                # acknowledged_checks.append(check)
                # else:
                # raise ValueError(f"Safety check failed: {check_message}")
-
+
                # Create call output
                call_output = {
                    "type": "computer_call_output",
@@ -516,25 +543,25 @@ class ComputerAgent:
                        "image_url": f"data:image/png;base64,{screenshot_base64}",
                    },
                }
-
+
                # # Additional URL safety checks for browser environments
                # if await computer.get_environment() == "browser":
                # current_url = await computer.get_current_url()
                # call_output["output"]["current_url"] = current_url
                # # TODO: implement a callback for URL safety checks
                # # check_blocklisted_url(current_url)
-
+
                result = [call_output]
                await self._on_computer_call_end(item, result)
                return result
-
+
            if item_type == "function_call":
                await self._on_function_call_start(item)
                # Perform function call
                function = self._get_tool(item.get("name"))
                if not function:
-                    raise ToolError(f"Function {item.get(
-
+                    raise ToolError(f"Function {item.get('name')} not found")
+
                args = json.loads(item.get("arguments"))

                # Validate arguments before execution
@@ -545,14 +572,14 @@ class ComputerAgent:
                    result = await function(**args)
                else:
                    result = await asyncio.to_thread(function, **args)
-
+
                # Create function call output
                call_output = {
                    "type": "function_call_output",
                    "call_id": item.get("call_id"),
                    "output": str(result),
                }
-
+
                result = [call_output]
                await self._on_function_call_end(item, result)
                return result
@@ -564,36 +591,35 @@ class ComputerAgent:
    # ============================================================================
    # MAIN AGENT LOOP
    # ============================================================================
-
+
    async def run(
-        self,
-        messages: Messages,
-        stream: bool = False,
-        **kwargs
+        self, messages: Messages, stream: bool = False, **kwargs
    ) -> AsyncGenerator[Dict[str, Any], None]:
        """
        Run the agent with the given messages using Computer protocol handler pattern.
-
+
        Args:
            messages: List of message dictionaries
            stream: Whether to stream the response
            **kwargs: Additional arguments
-
+
        Returns:
            AsyncGenerator that yields response chunks
        """
        if not self.agent_config_info:
            raise ValueError("Agent configuration not found")
-
+
        capabilities = self.get_capabilities()
        if "step" not in capabilities:
-            raise ValueError(
+            raise ValueError(
+                f"Agent loop {self.agent_config_info.agent_class.__name__} does not support step predictions"
+            )

        await self._initialize_computers()
-
+
        # Merge kwargs
        merged_kwargs = {**self.kwargs, **kwargs}
-
+
        old_items = self._process_input(messages)
        new_items = []

@@ -603,7 +629,7 @@ class ComputerAgent:
            "stream": stream,
            "model": self.model,
            "agent_loop": self.agent_config_info.agent_class.__name__,
-            **merged_kwargs
+            **merged_kwargs,
        }
        await self._on_run_start(run_kwargs, old_items)

@@ -620,7 +646,7 @@ class ComputerAgent:
            combined_messages = old_items + new_items
            combined_messages = replace_failed_computer_calls_with_function_calls(combined_messages)
            preprocessed_messages = await self._on_llm_start(combined_messages)
-
+
            loop_kwargs = {
                "messages": preprocessed_messages,
                "model": self.model,
@@ -629,7 +655,7 @@ class ComputerAgent:
                "computer_handler": self.computer_handler,
                "max_retries": self.max_retries,
                "use_prompt_caching": self.use_prompt_caching,
-                **merged_kwargs
+                **merged_kwargs,
            }

            # Run agent loop iteration
@@ -641,13 +667,13 @@ class ComputerAgent:
                _on_screenshot=self._on_screenshot,
            )
            result = get_json(result)
-
+
            # Lifecycle hook: Postprocess messages after the LLM call
            # Use cases:
            # - PII deanonymization (if you want tool calls to see PII)
            result["output"] = await self._on_llm_end(result.get("output", []))
            await self._on_responses(loop_kwargs, result)
-
+
            # Yield agent response
            yield result

@@ -659,7 +685,9 @@ class ComputerAgent:

            # Handle computer actions
            for item in result.get("output"):
-                partial_items = await self._handle_item(
+                partial_items = await self._handle_item(
+                    item, self.computer_handler, ignore_call_ids=output_call_ids
+                )
                new_items += partial_items

                # Yield partial response
@@ -669,54 +697,52 @@ class ComputerAgent:
                    prompt_tokens=0,
                    completion_tokens=0,
                    total_tokens=0,
-                )
+                ),
            }
-
+
        await self._on_run_end(loop_kwargs, old_items, new_items)
-
+
    async def predict_click(
-        self,
-        instruction: str,
-        image_b64: Optional[str] = None
+        self, instruction: str, image_b64: Optional[str] = None
    ) -> Optional[Tuple[int, int]]:
        """
        Predict click coordinates based on image and instruction.
-
+
        Args:
            instruction: Instruction for where to click
            image_b64: Base64 encoded image (optional, will take screenshot if not provided)
-
+
        Returns:
            None or tuple with (x, y) coordinates
        """
        if not self.agent_config_info:
            raise ValueError("Agent configuration not found")
-
+
        capabilities = self.get_capabilities()
        if "click" not in capabilities:
-            raise ValueError(
-
+            raise ValueError(
+                f"Agent loop {self.agent_config_info.agent_class.__name__} does not support click predictions"
+            )
+        if hasattr(self.agent_loop, "predict_click"):
            if not image_b64:
                if not self.computer_handler:
                    raise ValueError("Computer tool or image_b64 is required for predict_click")
                image_b64 = await self.computer_handler.screenshot()
            return await self.agent_loop.predict_click(
-                model=self.model,
-                image_b64=image_b64,
-                instruction=instruction
+                model=self.model, image_b64=image_b64, instruction=instruction
            )
        return None
-
+
    def get_capabilities(self) -> List[AgentCapability]:
        """
        Get list of capabilities supported by the current agent config.
-
+
        Returns:
            List of capability strings (e.g., ["step", "click"])
        """
        if not self.agent_config_info:
            raise ValueError("Agent configuration not found")
-
-        if hasattr(self.agent_loop,
+
+        if hasattr(self.agent_loop, "get_capabilities"):
            return self.agent_loop.get_capabilities()
-        return ["step"] # Default capability
+        return ["step"]  # Default capability