cua-agent 0.4.14__py3-none-any.whl → 0.7.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of cua-agent might be problematic.
- agent/__init__.py +4 -19
- agent/__main__.py +2 -1
- agent/adapters/__init__.py +6 -0
- agent/adapters/azure_ml_adapter.py +283 -0
- agent/adapters/cua_adapter.py +161 -0
- agent/adapters/huggingfacelocal_adapter.py +67 -125
- agent/adapters/human_adapter.py +116 -114
- agent/adapters/mlxvlm_adapter.py +370 -0
- agent/adapters/models/__init__.py +41 -0
- agent/adapters/models/generic.py +78 -0
- agent/adapters/models/internvl.py +290 -0
- agent/adapters/models/opencua.py +115 -0
- agent/adapters/models/qwen2_5_vl.py +78 -0
- agent/agent.py +431 -241
- agent/callbacks/__init__.py +10 -3
- agent/callbacks/base.py +45 -31
- agent/callbacks/budget_manager.py +22 -10
- agent/callbacks/image_retention.py +54 -98
- agent/callbacks/logging.py +55 -42
- agent/callbacks/operator_validator.py +140 -0
- agent/callbacks/otel.py +291 -0
- agent/callbacks/pii_anonymization.py +19 -16
- agent/callbacks/prompt_instructions.py +47 -0
- agent/callbacks/telemetry.py +106 -69
- agent/callbacks/trajectory_saver.py +178 -70
- agent/cli.py +269 -119
- agent/computers/__init__.py +14 -9
- agent/computers/base.py +32 -19
- agent/computers/cua.py +52 -25
- agent/computers/custom.py +78 -71
- agent/decorators.py +23 -14
- agent/human_tool/__init__.py +2 -7
- agent/human_tool/__main__.py +6 -2
- agent/human_tool/server.py +48 -37
- agent/human_tool/ui.py +359 -235
- agent/integrations/hud/__init__.py +164 -74
- agent/integrations/hud/agent.py +338 -342
- agent/integrations/hud/proxy.py +297 -0
- agent/loops/__init__.py +44 -14
- agent/loops/anthropic.py +590 -492
- agent/loops/base.py +19 -15
- agent/loops/composed_grounded.py +142 -144
- agent/loops/fara/__init__.py +8 -0
- agent/loops/fara/config.py +506 -0
- agent/loops/fara/helpers.py +357 -0
- agent/loops/fara/schema.py +143 -0
- agent/loops/gelato.py +183 -0
- agent/loops/gemini.py +935 -0
- agent/loops/generic_vlm.py +601 -0
- agent/loops/glm45v.py +140 -135
- agent/loops/gta1.py +48 -51
- agent/loops/holo.py +218 -0
- agent/loops/internvl.py +180 -0
- agent/loops/moondream3.py +493 -0
- agent/loops/omniparser.py +326 -226
- agent/loops/openai.py +63 -56
- agent/loops/opencua.py +134 -0
- agent/loops/uiins.py +175 -0
- agent/loops/uitars.py +262 -212
- agent/loops/uitars2.py +951 -0
- agent/playground/__init__.py +5 -0
- agent/playground/server.py +301 -0
- agent/proxy/examples.py +196 -0
- agent/proxy/handlers.py +255 -0
- agent/responses.py +486 -339
- agent/tools/__init__.py +24 -0
- agent/tools/base.py +253 -0
- agent/tools/browser_tool.py +423 -0
- agent/types.py +20 -5
- agent/ui/__init__.py +1 -1
- agent/ui/__main__.py +1 -1
- agent/ui/gradio/app.py +25 -22
- agent/ui/gradio/ui_components.py +314 -167
- cua_agent-0.7.16.dist-info/METADATA +85 -0
- cua_agent-0.7.16.dist-info/RECORD +79 -0
- {cua_agent-0.4.14.dist-info → cua_agent-0.7.16.dist-info}/WHEEL +1 -1
- agent/integrations/hud/adapter.py +0 -121
- agent/integrations/hud/computer_handler.py +0 -187
- agent/telemetry.py +0 -142
- cua_agent-0.4.14.dist-info/METADATA +0 -436
- cua_agent-0.4.14.dist-info/RECORD +0 -50
- {cua_agent-0.4.14.dist-info → cua_agent-0.7.16.dist-info}/entry_points.txt +0 -0
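
The additions above (new adapters, agent loops, callbacks, the tools package, and the playground server) surface mainly through the ComputerAgent constructor, whose updated signature appears in the agent/agent.py diff below. The following is a minimal configuration sketch based only on the parameter names visible in that diff; the import path and the model string are illustrative assumptions, and the tools list is left empty:

    from agent import ComputerAgent  # assumed top-level export

    agent = ComputerAgent(
        model="anthropic/claude-sonnet-4-5-20250929",  # illustrative model id
        tools=[],                        # add a computer object or decorated functions
        instructions="You are operating a test VM.",   # new in 0.7.x
        only_n_most_recent_images=3,
        trajectory_dir="trajectories",   # now also accepts a Path or a dict of saver kwargs
        max_trajectory_budget=5.0,       # adds BudgetManagerCallback
        trust_remote_code=False,         # new flag for local HuggingFace models
        api_key=None,                    # optional per-agent provider override
    )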
agent/agent.py
CHANGED
@@ -3,57 +3,87 @@ ComputerAgent - Main agent class that selects and runs agent loops
 """
 
 import asyncio
-
-
-from litellm.responses.utils import Usage
-
-from .types import Messages, AgentCapability
-from .decorators import find_agent_config
+import inspect
 import json
+from pathlib import Path
+from typing import (
+    Any,
+    AsyncGenerator,
+    Callable,
+    Dict,
+    List,
+    Optional,
+    Set,
+    Tuple,
+    Union,
+    cast,
+)
+
 import litellm
 import litellm.utils
-import
+from litellm.responses.utils import Usage
+
 from .adapters import (
+    AzureMLAdapter,
+    CUAAdapter,
     HuggingFaceLocalAdapter,
     HumanAdapter,
+    MLXVLMAdapter,
 )
 from .callbacks import (
-    ImageRetentionCallback,
-    LoggingCallback,
-    TrajectorySaverCallback,
     BudgetManagerCallback,
+    ImageRetentionCallback,
+    LoggingCallback,
+    OperatorNormalizerCallback,
+    OtelCallback,
+    PromptInstructionsCallback,
     TelemetryCallback,
+    TrajectorySaverCallback,
 )
-from .computers import
-
-
-
+from .computers import AsyncComputerHandler, is_agent_computer, make_computer_handler
+from .decorators import find_agent_config
+from .responses import (
+    make_tool_error_item,
+    replace_failed_computer_calls_with_function_calls,
 )
+from .tools.base import BaseComputerTool, BaseTool
+from .types import AgentCapability, IllegalArgumentError, Messages, ToolError
+
+
+def assert_callable_with(f, *args, **kwargs):
+    """Check if function can be called with given arguments."""
+    try:
+        inspect.signature(f).bind(*args, **kwargs)
+        return True
+    except TypeError as e:
+        sig = inspect.signature(f)
+        raise IllegalArgumentError(f"Expected {sig}, got args={args} kwargs={kwargs}") from e
+
 
 def get_json(obj: Any, max_depth: int = 10) -> Any:
     def custom_serializer(o: Any, depth: int = 0, seen: Optional[Set[int]] = None) -> Any:
         if seen is None:
             seen = set()
-
+
         # Use model_dump() if available
-        if hasattr(o,
+        if hasattr(o, "model_dump"):
             return o.model_dump()
-
+
         # Check depth limit
         if depth > max_depth:
             return f"<max_depth_exceeded:{max_depth}>"
-
+
         # Check for circular references using object id
         obj_id = id(o)
         if obj_id in seen:
             return f"<circular_reference:{type(o).__name__}>"
-
+
         # Handle Computer objects
-        if hasattr(o,
+        if hasattr(o, "__class__") and "computer" in o.__class__.__name__.lower():
             return f"<computer:{o.__class__.__name__}>"
 
         # Handle objects with __dict__
-        if hasattr(o,
+        if hasattr(o, "__dict__"):
             seen.add(obj_id)
             try:
                 result = {}
@@ -65,7 +95,7 @@ def get_json(obj: Any, max_depth: int = 10) -> Any:
                 return result
             finally:
                 seen.discard(obj_id)
-
+
         # Handle common types that might contain nested objects
         elif isinstance(o, dict):
             seen.add(obj_id)
@@ -77,7 +107,7 @@ def get_json(obj: Any, max_depth: int = 10) -> Any:
                 }
             finally:
                 seen.discard(obj_id)
-
+
         elif isinstance(o, (list, tuple, set)):
             seen.add(obj_id)
             try:
@@ -88,32 +118,33 @@ def get_json(obj: Any, max_depth: int = 10) -> Any:
                 ]
             finally:
                 seen.discard(obj_id)
-
+
         # For basic types that json.dumps can handle
         elif isinstance(o, (str, int, float, bool)) or o is None:
             return o
-
+
         # Fallback to string representation
         else:
             return str(o)
-
+
     def remove_nones(obj: Any) -> Any:
         if isinstance(obj, dict):
             return {k: remove_nones(v) for k, v in obj.items() if v is not None}
         elif isinstance(obj, list):
            return [remove_nones(item) for item in obj if item is not None]
        return obj
-
+
    # Serialize with circular reference and depth protection
    serialized = custom_serializer(obj)
-
+
    # Convert to JSON string and back to ensure JSON compatibility
    json_str = json.dumps(serialized)
    parsed = json.loads(json_str)
-
+
    # Final cleanup of any remaining None values
    return remove_nones(parsed)
 
+
 def sanitize_message(msg: Any) -> Any:
     """Return a copy of the message with image_url omitted for computer_call_output messages."""
     if msg.get("type") == "computer_call_output":
@@ -124,19 +155,24 @@ def sanitize_message(msg: Any) -> Any:
         return sanitized
     return msg
 
+
 def get_output_call_ids(messages: List[Dict[str, Any]]) -> List[str]:
     call_ids = []
     for message in messages:
-        if
+        if (
+            message.get("type") == "computer_call_output"
+            or message.get("type") == "function_call_output"
+        ):
             call_ids.append(message.get("call_id"))
     return call_ids
 
+
 class ComputerAgent:
     """
     Main agent class that automatically selects the appropriate agent loop
     based on the model and executes tool calls.
     """
-
+
     def __init__(
         self,
         model: str,
@@ -144,24 +180,29 @@ class ComputerAgent:
         custom_loop: Optional[Callable] = None,
         only_n_most_recent_images: Optional[int] = None,
         callbacks: Optional[List[Any]] = None,
+        instructions: Optional[str] = None,
         verbosity: Optional[int] = None,
-        trajectory_dir: Optional[str] = None,
+        trajectory_dir: Optional[str | Path | dict] = None,
         max_retries: Optional[int] = 3,
         screenshot_delay: Optional[float | int] = 0.5,
         use_prompt_caching: Optional[bool] = False,
         max_trajectory_budget: Optional[float | dict] = None,
         telemetry_enabled: Optional[bool] = True,
-
+        trust_remote_code: Optional[bool] = False,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+        **additional_generation_kwargs,
     ):
         """
         Initialize ComputerAgent.
-
+
         Args:
-            model: Model name (e.g., "claude-
+            model: Model name (e.g., "claude-sonnet-4-5-20250929", "computer-use-preview", "omni+vertex_ai/gemini-pro")
             tools: List of tools (computer objects, decorated functions, etc.)
             custom_loop: Custom agent loop function to use instead of auto-selection
             only_n_most_recent_images: If set, only keep the N most recent images in message history. Adds ImageRetentionCallback automatically.
             callbacks: List of AsyncCallbackHandler instances for preprocessing/postprocessing
+            instructions: Optional system instructions to be passed to the model
             verbosity: Logging level (logging.DEBUG, logging.INFO, etc.). If set, adds LoggingCallback automatically
             trajectory_dir: If set, saves trajectory data (screenshots, responses) to this directory. Adds TrajectorySaverCallback automatically.
             max_retries: Maximum number of retries for failed API calls
@@ -169,29 +210,40 @@ class ComputerAgent:
             use_prompt_caching: If set, use prompt caching to avoid reprocessing the same prompt. Intended for use with anthropic providers.
             max_trajectory_budget: If set, adds BudgetManagerCallback to track usage costs and stop when budget is exceeded
             telemetry_enabled: If set, adds TelemetryCallback to track anonymized usage data. Enabled by default.
-
+            trust_remote_code: If set, trust remote code when loading local models. Disabled by default.
+            api_key: Optional API key override for the model provider
+            api_base: Optional API base URL override for the model provider
+            **additional_generation_kwargs: Additional arguments passed to the model provider
         """
+        # If the loop is "human/human", we need to prefix a grounding model fallback
+        if model in ["human/human", "human"]:
+            model = "openai/computer-use-preview+human/human"
+
         self.model = model
         self.tools = tools or []
         self.custom_loop = custom_loop
         self.only_n_most_recent_images = only_n_most_recent_images
         self.callbacks = callbacks or []
+        self.instructions = instructions
         self.verbosity = verbosity
         self.trajectory_dir = trajectory_dir
         self.max_retries = max_retries
         self.screenshot_delay = screenshot_delay
         self.use_prompt_caching = use_prompt_caching
         self.telemetry_enabled = telemetry_enabled
-        self.kwargs =
+        self.kwargs = additional_generation_kwargs
+        self.trust_remote_code = trust_remote_code
+        self.api_key = api_key
+        self.api_base = api_base
 
         # == Add built-in callbacks ==
 
-        #
-
-
-
-
-
+        # Prepend operator normalizer callback
+        self.callbacks.insert(0, OperatorNormalizerCallback())
+
+        # Add prompt instructions callback if provided
+        if self.instructions:
+            self.callbacks.append(PromptInstructionsCallback(self.instructions))
 
         # Add logging callback if verbosity is set
         if self.verbosity is not None:
@@ -200,28 +252,37 @@ class ComputerAgent:
         # Add image retention callback if only_n_most_recent_images is set
         if self.only_n_most_recent_images:
             self.callbacks.append(ImageRetentionCallback(self.only_n_most_recent_images))
-
+
         # Add trajectory saver callback if trajectory_dir is set
         if self.trajectory_dir:
-
-
+            if isinstance(self.trajectory_dir, dict):
+                self.callbacks.append(TrajectorySaverCallback(**self.trajectory_dir))
+            elif isinstance(self.trajectory_dir, (str, Path)):
+                self.callbacks.append(TrajectorySaverCallback(str(self.trajectory_dir)))
+
         # Add budget manager if max_trajectory_budget is set
         if max_trajectory_budget:
             if isinstance(max_trajectory_budget, dict):
                 self.callbacks.append(BudgetManagerCallback(**max_trajectory_budget))
             else:
                 self.callbacks.append(BudgetManagerCallback(max_trajectory_budget))
-
+
         # == Enable local model providers w/ LiteLLM ==
 
         # Register local model providers
         hf_adapter = HuggingFaceLocalAdapter(
-            device="auto"
+            device="auto", trust_remote_code=self.trust_remote_code or False
         )
         human_adapter = HumanAdapter()
+        mlx_adapter = MLXVLMAdapter()
+        cua_adapter = CUAAdapter()
+        azure_ml_adapter = AzureMLAdapter()
         litellm.custom_provider_map = [
             {"provider": "huggingface-local", "custom_handler": hf_adapter},
-            {"provider": "human", "custom_handler": human_adapter}
+            {"provider": "human", "custom_handler": human_adapter},
+            {"provider": "mlx", "custom_handler": mlx_adapter},
+            {"provider": "cua", "custom_handler": cua_adapter},
+            {"provider": "azure_ml", "custom_handler": azure_ml_adapter},
         ]
         litellm.suppress_debug_info = True
 
@@ -238,24 +299,47 @@ class ComputerAgent:
         # Instantiate the agent config class
         self.agent_loop = config_info.agent_class()
         self.agent_config_info = config_info
-
+
+        # Add telemetry callbacks AFTER agent_loop is set so they can capture the correct agent_type
+        if self.telemetry_enabled:
+            # PostHog telemetry (product analytics)
+            if isinstance(self.telemetry_enabled, bool):
+                self.callbacks.append(TelemetryCallback(self))
+            else:
+                self.callbacks.append(TelemetryCallback(self, **self.telemetry_enabled))
+
+            # OpenTelemetry callback (operational metrics - Four Golden Signals)
+            # This is enabled alongside PostHog when telemetry_enabled is True
+            # Users can disable via CUA_TELEMETRY_DISABLED=true env var
+            self.callbacks.append(OtelCallback(self))
+
         self.tool_schemas = []
         self.computer_handler = None
-
+
     async def _initialize_computers(self):
         """Initialize computer objects"""
         if not self.tool_schemas:
             # Process tools and create tool schemas
             self.tool_schemas = self._process_tools()
-
+
             # Find computer tool and create interface adapter
             computer_handler = None
-
-
-
+
+            # First check if any tool is a BaseComputerTool instance
+            for tool in self.tools:
+                if isinstance(tool, BaseComputerTool):
+                    computer_handler = tool
                     break
+
+            # If no BaseComputerTool found, look for traditional computer objects
+            if computer_handler is None:
+                for schema in self.tool_schemas:
+                    if schema["type"] == "computer":
+                        computer_handler = await make_computer_handler(schema["computer"])
+                        break
+
             self.computer_handler = computer_handler
-
+
     def _process_input(self, input: Messages) -> List[Dict[str, Any]]:
         """Process input messages and create schemas for the agent loop"""
         if isinstance(input, str):
@@ -265,69 +349,85 @@ class ComputerAgent:
     def _process_tools(self) -> List[Dict[str, Any]]:
         """Process tools and create schemas for the agent loop"""
         schemas = []
-
+
         for tool in self.tools:
             # Check if it's a computer object (has interface attribute)
             if is_agent_computer(tool):
                 # This is a computer tool - will be handled by agent loop
-                schemas.append({
-
-
-
+                schemas.append({"type": "computer", "computer": tool})
+            elif isinstance(tool, BaseTool):
+                # BaseTool instance - extract schema from its properties
+                function_schema = {
+                    "name": tool.name,
+                    "description": tool.description,
+                    "parameters": tool.parameters,
+                }
+                schemas.append({"type": "function", "function": function_schema})
             elif callable(tool):
                 # Use litellm.utils.function_to_dict to extract schema from docstring
                 try:
                     function_schema = litellm.utils.function_to_dict(tool)
-                    schemas.append({
-                        "type": "function",
-                        "function": function_schema
-                    })
+                    schemas.append({"type": "function", "function": function_schema})
                 except Exception as e:
                     print(f"Warning: Could not process tool {tool}: {e}")
             else:
                 print(f"Warning: Unknown tool type: {tool}")
-
+
         return schemas
-
-    def _get_tool(self, name: str) -> Optional[Callable]:
+
+    def _get_tool(self, name: str) -> Optional[Union[Callable, BaseTool]]:
         """Get a tool by name"""
         for tool in self.tools:
-            if
+            # Check if it's a BaseTool instance
+            if isinstance(tool, BaseTool) and tool.name == name:
                 return tool
-
+            # Check if it's a regular callable
+            elif hasattr(tool, "__name__") and tool.__name__ == name:
+                return tool
+            elif hasattr(tool, "func") and tool.func.__name__ == name:
                 return tool
         return None
-
+
     # ============================================================================
     # AGENT RUN LOOP LIFECYCLE HOOKS
     # ============================================================================
-
+
     async def _on_run_start(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]]) -> None:
         """Initialize run tracking by calling callbacks."""
         for callback in self.callbacks:
-            if hasattr(callback,
+            if hasattr(callback, "on_run_start"):
                 await callback.on_run_start(kwargs, old_items)
-
-    async def _on_run_end(
+
+    async def _on_run_end(
+        self,
+        kwargs: Dict[str, Any],
+        old_items: List[Dict[str, Any]],
+        new_items: List[Dict[str, Any]],
+    ) -> None:
         """Finalize run tracking by calling callbacks."""
         for callback in self.callbacks:
-            if hasattr(callback,
+            if hasattr(callback, "on_run_end"):
                 await callback.on_run_end(kwargs, old_items, new_items)
-
-    async def _on_run_continue(
+
+    async def _on_run_continue(
+        self,
+        kwargs: Dict[str, Any],
+        old_items: List[Dict[str, Any]],
+        new_items: List[Dict[str, Any]],
+    ) -> bool:
         """Check if run should continue by calling callbacks."""
         for callback in self.callbacks:
-            if hasattr(callback,
+            if hasattr(callback, "on_run_continue"):
                 should_continue = await callback.on_run_continue(kwargs, old_items, new_items)
                 if not should_continue:
                     return False
         return True
-
+
     async def _on_llm_start(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
         """Prepare messages for the LLM call by applying callbacks."""
         result = messages
         for callback in self.callbacks:
-            if hasattr(callback,
+            if hasattr(callback, "on_llm_start"):
                 result = await callback.on_llm_start(result)
         return result
 
@@ -335,81 +435,91 @@ class ComputerAgent:
         """Postprocess messages after the LLM call by applying callbacks."""
         result = messages
         for callback in self.callbacks:
-            if hasattr(callback,
+            if hasattr(callback, "on_llm_end"):
                 result = await callback.on_llm_end(result)
         return result
 
     async def _on_responses(self, kwargs: Dict[str, Any], responses: Dict[str, Any]) -> None:
         """Called when responses are received."""
         for callback in self.callbacks:
-            if hasattr(callback,
+            if hasattr(callback, "on_responses"):
                 await callback.on_responses(get_json(kwargs), get_json(responses))
-
+
     async def _on_computer_call_start(self, item: Dict[str, Any]) -> None:
         """Called when a computer call is about to start."""
         for callback in self.callbacks:
-            if hasattr(callback,
+            if hasattr(callback, "on_computer_call_start"):
                 await callback.on_computer_call_start(get_json(item))
-
-    async def _on_computer_call_end(
+
+    async def _on_computer_call_end(
+        self, item: Dict[str, Any], result: List[Dict[str, Any]]
+    ) -> None:
         """Called when a computer call has completed."""
         for callback in self.callbacks:
-            if hasattr(callback,
+            if hasattr(callback, "on_computer_call_end"):
                 await callback.on_computer_call_end(get_json(item), get_json(result))
-
+
     async def _on_function_call_start(self, item: Dict[str, Any]) -> None:
         """Called when a function call is about to start."""
         for callback in self.callbacks:
-            if hasattr(callback,
+            if hasattr(callback, "on_function_call_start"):
                 await callback.on_function_call_start(get_json(item))
-
-    async def _on_function_call_end(
+
+    async def _on_function_call_end(
+        self, item: Dict[str, Any], result: List[Dict[str, Any]]
+    ) -> None:
         """Called when a function call has completed."""
         for callback in self.callbacks:
-            if hasattr(callback,
+            if hasattr(callback, "on_function_call_end"):
                 await callback.on_function_call_end(get_json(item), get_json(result))
-
+
     async def _on_text(self, item: Dict[str, Any]) -> None:
         """Called when a text message is encountered."""
         for callback in self.callbacks:
-            if hasattr(callback,
+            if hasattr(callback, "on_text"):
                 await callback.on_text(get_json(item))
-
+
     async def _on_api_start(self, kwargs: Dict[str, Any]) -> None:
         """Called when an LLM API call is about to start."""
         for callback in self.callbacks:
-            if hasattr(callback,
+            if hasattr(callback, "on_api_start"):
                 await callback.on_api_start(get_json(kwargs))
-
+
     async def _on_api_end(self, kwargs: Dict[str, Any], result: Any) -> None:
         """Called when an LLM API call has completed."""
         for callback in self.callbacks:
-            if hasattr(callback,
+            if hasattr(callback, "on_api_end"):
                 await callback.on_api_end(get_json(kwargs), get_json(result))
 
     async def _on_usage(self, usage: Dict[str, Any]) -> None:
         """Called when usage information is received."""
         for callback in self.callbacks:
-            if hasattr(callback,
+            if hasattr(callback, "on_usage"):
                 await callback.on_usage(get_json(usage))
 
     async def _on_screenshot(self, screenshot: Union[str, bytes], name: str = "screenshot") -> None:
         """Called when a screenshot is taken."""
         for callback in self.callbacks:
-            if hasattr(callback,
+            if hasattr(callback, "on_screenshot"):
                 await callback.on_screenshot(screenshot, name)
 
     # ============================================================================
     # AGENT OUTPUT PROCESSING
     # ============================================================================
-
-    async def _handle_item(
+
+    async def _handle_item(
+        self,
+        item: Any,
+        computer: Optional[AsyncComputerHandler] = None,
+        ignore_call_ids: Optional[List[str]] = None,
+    ) -> List[Dict[str, Any]]:
         """Handle each item; may cause a computer action + screenshot."""
-
+        call_id = item.get("call_id")
+        if ignore_call_ids and call_id and call_id in ignore_call_ids:
             return []
-
+
         item_type = item.get("type", None)
-
+
         if item_type == "message":
             await self._on_text(item)
             # # Print messages
@@ -418,133 +528,156 @@ class ComputerAgent:
             #         if content_item.get("text"):
             #             print(content_item.get("text"))
             return []
-
-        [... old lines 421-511: the previous inline computer_call / function_call handling; content not preserved in this diff view (only fragments such as "acknowledged_checks" remain) ...]
+
+        try:
+            if item_type == "computer_call":
+                await self._on_computer_call_start(item)
+                if not computer:
+                    raise ValueError("Computer handler is required for computer calls")
+
+                # Perform computer actions
+                action = item.get("action")
+                action_type = action.get("type")
+                if action_type is None:
+                    print(
+                        f"Action type cannot be `None`: action={action}, action_type={action_type}"
+                    )
+                    return []
+
+                # Extract action arguments (all fields except 'type')
+                action_args = {k: v for k, v in action.items() if k != "type"}
+
+                # print(f"{action_type}({action_args})")
+
+                # Execute the computer action
+                computer_method = getattr(computer, action_type, None)
+                if computer_method:
+                    assert_callable_with(computer_method, **action_args)
+                    await computer_method(**action_args)
+                else:
+                    raise ToolError(f"Unknown computer action: {action_type}")
+
+                # Take screenshot after action
+                if self.screenshot_delay and self.screenshot_delay > 0:
+                    await asyncio.sleep(self.screenshot_delay)
+                screenshot_base64 = await computer.screenshot()
+                await self._on_screenshot(screenshot_base64, "screenshot_after")
+
+                # Handle safety checks
+                pending_checks = item.get("pending_safety_checks", [])
+                acknowledged_checks = []
+                for check in pending_checks:
+                    check_message = check.get("message", str(check))
+                    acknowledged_checks.append(check)
+                    # TODO: implement a callback for safety checks
+                    # if acknowledge_safety_check_callback(check_message, allow_always=True):
+                    #     acknowledged_checks.append(check)
+                    # else:
+                    #     raise ValueError(f"Safety check failed: {check_message}")
+
+                # Create call output
+                call_output = {
+                    "type": "computer_call_output",
+                    "call_id": item.get("call_id"),
+                    "acknowledged_safety_checks": acknowledged_checks,
+                    "output": {
+                        "type": "input_image",
+                        "image_url": f"data:image/png;base64,{screenshot_base64}",
+                    },
+                }
+
+                # # Additional URL safety checks for browser environments
+                # if await computer.get_environment() == "browser":
+                #     current_url = await computer.get_current_url()
+                #     call_output["output"]["current_url"] = current_url
+                #     # TODO: implement a callback for URL safety checks
+                #     # check_blocklisted_url(current_url)
+
+                result = [call_output]
+                await self._on_computer_call_end(item, result)
+                return result
+
+            if item_type == "function_call":
+                await self._on_function_call_start(item)
+                # Perform function call
+                function = self._get_tool(item.get("name"))
+                if not function:
+                    raise ToolError(f"Function {item.get('name')} not found")
+
+                args = json.loads(item.get("arguments"))
+
+                # Handle BaseTool instances
+                if isinstance(function, BaseTool):
+                    # BaseTool.call() handles its own execution
+                    result = function.call(args)
+                else:
+                    # Validate arguments before execution for regular callables
+                    assert_callable_with(function, **args)
+
+                    # Execute function - use asyncio.to_thread for non-async functions
+                    if inspect.iscoroutinefunction(function):
+                        result = await function(**args)
+                    else:
+                        result = await asyncio.to_thread(function, **args)
+
+                # Create function call output
+                call_output = {
+                    "type": "function_call_output",
+                    "call_id": item.get("call_id"),
+                    "output": str(result),
+                }
+
+                result = [call_output]
+                await self._on_function_call_end(item, result)
+                return result
+        except ToolError as e:
+            return [make_tool_error_item(repr(e), call_id)]
 
         return []
 
     # ============================================================================
     # MAIN AGENT LOOP
     # ============================================================================
-
+
     async def run(
         self,
         messages: Messages,
         stream: bool = False,
-
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+        **additional_generation_kwargs,
     ) -> AsyncGenerator[Dict[str, Any], None]:
         """
         Run the agent with the given messages using Computer protocol handler pattern.
-
+
         Args:
             messages: List of message dictionaries
             stream: Whether to stream the response
-
-
+            api_key: Optional API key override for the model provider
+            api_base: Optional API base URL override for the model provider
+            **additional_generation_kwargs: Additional arguments passed to the model provider
+
         Returns:
             AsyncGenerator that yields response chunks
         """
         if not self.agent_config_info:
             raise ValueError("Agent configuration not found")
-
+
         capabilities = self.get_capabilities()
         if "step" not in capabilities:
-            raise ValueError(
+            raise ValueError(
+                f"Agent loop {self.agent_config_info.agent_class.__name__} does not support step predictions"
+            )
 
         await self._initialize_computers()
-
-        # Merge kwargs
-        merged_kwargs = {**self.kwargs, **
-
+
+        # Merge kwargs and thread api credentials (run overrides constructor)
+        merged_kwargs = {**self.kwargs, **additional_generation_kwargs}
+        if (api_key is not None) or (self.api_key is not None):
+            merged_kwargs["api_key"] = api_key if api_key is not None else self.api_key
+        if (api_base is not None) or (self.api_base is not None):
+            merged_kwargs["api_base"] = api_base if api_base is not None else self.api_base
+
         old_items = self._process_input(messages)
         new_items = []
 
@@ -554,7 +687,7 @@ class ComputerAgent:
             "stream": stream,
             "model": self.model,
             "agent_loop": self.agent_config_info.agent_class.__name__,
-            **merged_kwargs
+            **merged_kwargs,
         }
         await self._on_run_start(run_kwargs, old_items)
 
@@ -569,8 +702,9 @@ class ComputerAgent:
            # - PII anonymization
            # - Image retention policy
            combined_messages = old_items + new_items
+            combined_messages = replace_failed_computer_calls_with_function_calls(combined_messages)
            preprocessed_messages = await self._on_llm_start(combined_messages)
-
+
            loop_kwargs = {
                "messages": preprocessed_messages,
                "model": self.model,
@@ -579,9 +713,39 @@ class ComputerAgent:
                "computer_handler": self.computer_handler,
                "max_retries": self.max_retries,
                "use_prompt_caching": self.use_prompt_caching,
-                **merged_kwargs
+                **merged_kwargs,
            }
 
+            # ---- Ollama image input guard ----
+            if isinstance(self.model, str) and (
+                "ollama/" in self.model or "ollama_chat/" in self.model
+            ):
+
+                def contains_image_content(msgs):
+                    for m in msgs:
+                        # 1️⃣ Check regular message content
+                        content = m.get("content")
+                        if isinstance(content, list):
+                            for item in content:
+                                if isinstance(item, dict) and item.get("type") == "image_url":
+                                    return True
+
+                        # 2️⃣ Check computer_call_output screenshots
+                        if m.get("type") == "computer_call_output":
+                            output = m.get("output", {})
+                            if output.get("type") == "input_image" and "image_url" in output:
+                                return True
+
+                    return False
+
+                if contains_image_content(preprocessed_messages):
+                    raise ValueError(
+                        "Ollama models do not support image inputs required by ComputerAgent. "
+                        "Please use a vision-capable model (e.g., OpenAI or Anthropic) "
+                        "or remove computer/screenshot actions."
+                    )
+            # ---------------------------------
+
            # Run agent loop iteration
            result = await self.agent_loop.predict_step(
                **loop_kwargs,
@@ -591,13 +755,13 @@ class ComputerAgent:
                _on_screenshot=self._on_screenshot,
            )
            result = get_json(result)
-
+
            # Lifecycle hook: Postprocess messages after the LLM call
            # Use cases:
            # - PII deanonymization (if you want tool calls to see PII)
            result["output"] = await self._on_llm_end(result.get("output", []))
            await self._on_responses(loop_kwargs, result)
-
+
            # Yield agent response
            yield result
 
@@ -609,64 +773,90 @@ class ComputerAgent:
 
            # Handle computer actions
            for item in result.get("output"):
-                partial_items = await self._handle_item(
+                partial_items = await self._handle_item(
+                    item, self.computer_handler, ignore_call_ids=output_call_ids
+                )
                new_items += partial_items
 
-                # Yield partial response
-
-
-
-
-
-
-
-
-
+                # Yield partial response if any
+                if partial_items:
+                    yield {
+                        "output": partial_items,
+                        "usage": Usage(
+                            prompt_tokens=0,
+                            completion_tokens=0,
+                            total_tokens=0,
+                        ),
+                    }
+
        await self._on_run_end(loop_kwargs, old_items, new_items)
-
+
    async def predict_click(
-        self,
-        instruction: str,
-        image_b64: Optional[str] = None
+        self, instruction: str, image_b64: Optional[str] = None
    ) -> Optional[Tuple[int, int]]:
        """
        Predict click coordinates based on image and instruction.
-
+
        Args:
            instruction: Instruction for where to click
            image_b64: Base64 encoded image (optional, will take screenshot if not provided)
-
+
        Returns:
            None or tuple with (x, y) coordinates
        """
        if not self.agent_config_info:
            raise ValueError("Agent configuration not found")
-
+
        capabilities = self.get_capabilities()
        if "click" not in capabilities:
-            raise ValueError(
-
+            raise ValueError(
+                f"Agent loop {self.agent_config_info.agent_class.__name__} does not support click predictions"
+            )
+        if hasattr(self.agent_loop, "predict_click"):
            if not image_b64:
                if not self.computer_handler:
                    raise ValueError("Computer tool or image_b64 is required for predict_click")
                image_b64 = await self.computer_handler.screenshot()
+            # Pass along api credentials if available
+            click_kwargs: Dict[str, Any] = {}
+            if self.api_key is not None:
+                click_kwargs["api_key"] = self.api_key
+            if self.api_base is not None:
+                click_kwargs["api_base"] = self.api_base
            return await self.agent_loop.predict_click(
-                model=self.model,
-                image_b64=image_b64,
-                instruction=instruction
+                model=self.model, image_b64=image_b64, instruction=instruction, **click_kwargs
            )
        return None
-
+
    def get_capabilities(self) -> List[AgentCapability]:
        """
        Get list of capabilities supported by the current agent config.
-
+
        Returns:
            List of capability strings (e.g., ["step", "click"])
        """
        if not self.agent_config_info:
            raise ValueError("Agent configuration not found")
-
-        if hasattr(self.agent_loop,
+
+        if hasattr(self.agent_loop, "get_capabilities"):
            return self.agent_loop.get_capabilities()
-        return ["step"] # Default capability
+        return ["step"]  # Default capability
+
+    def open(self, port: Optional[int] = None):
+        """
+        Start the playground server and open it in the browser.
+
+        This method starts a local HTTP server that exposes the /responses endpoint
+        and automatically opens the Cua playground interface in the default browser.
+
+        Args:
+            port: Port to run the server on. If None, finds an available port automatically.
+
+        Example:
+            >>> agent = ComputerAgent(model="claude-sonnet-4")
+            >>> agent.open()  # Starts server and opens browser
+        """
+        from .playground import PlaygroundServer
+
+        server = PlaygroundServer(agent_instance=self)
+        server.start(port=port, open_browser=True)