cua-agent 0.4.22__py3-none-any.whl → 0.7.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of cua-agent might be problematic.
- agent/__init__.py +4 -10
- agent/__main__.py +2 -1
- agent/adapters/__init__.py +4 -0
- agent/adapters/azure_ml_adapter.py +283 -0
- agent/adapters/cua_adapter.py +161 -0
- agent/adapters/huggingfacelocal_adapter.py +67 -125
- agent/adapters/human_adapter.py +116 -114
- agent/adapters/mlxvlm_adapter.py +110 -99
- agent/adapters/models/__init__.py +41 -0
- agent/adapters/models/generic.py +78 -0
- agent/adapters/models/internvl.py +290 -0
- agent/adapters/models/opencua.py +115 -0
- agent/adapters/models/qwen2_5_vl.py +78 -0
- agent/agent.py +337 -185
- agent/callbacks/__init__.py +9 -4
- agent/callbacks/base.py +45 -31
- agent/callbacks/budget_manager.py +22 -10
- agent/callbacks/image_retention.py +54 -98
- agent/callbacks/logging.py +55 -42
- agent/callbacks/operator_validator.py +35 -33
- agent/callbacks/otel.py +291 -0
- agent/callbacks/pii_anonymization.py +19 -16
- agent/callbacks/prompt_instructions.py +47 -0
- agent/callbacks/telemetry.py +99 -61
- agent/callbacks/trajectory_saver.py +95 -69
- agent/cli.py +269 -119
- agent/computers/__init__.py +14 -9
- agent/computers/base.py +32 -19
- agent/computers/cua.py +52 -25
- agent/computers/custom.py +78 -71
- agent/decorators.py +23 -14
- agent/human_tool/__init__.py +2 -7
- agent/human_tool/__main__.py +6 -2
- agent/human_tool/server.py +48 -37
- agent/human_tool/ui.py +359 -235
- agent/integrations/hud/__init__.py +38 -99
- agent/integrations/hud/agent.py +369 -0
- agent/integrations/hud/proxy.py +166 -52
- agent/loops/__init__.py +44 -14
- agent/loops/anthropic.py +579 -492
- agent/loops/base.py +19 -15
- agent/loops/composed_grounded.py +136 -150
- agent/loops/fara/__init__.py +8 -0
- agent/loops/fara/config.py +506 -0
- agent/loops/fara/helpers.py +357 -0
- agent/loops/fara/schema.py +143 -0
- agent/loops/gelato.py +183 -0
- agent/loops/gemini.py +935 -0
- agent/loops/generic_vlm.py +601 -0
- agent/loops/glm45v.py +140 -135
- agent/loops/gta1.py +48 -51
- agent/loops/holo.py +218 -0
- agent/loops/internvl.py +180 -0
- agent/loops/moondream3.py +493 -0
- agent/loops/omniparser.py +326 -226
- agent/loops/openai.py +50 -51
- agent/loops/opencua.py +134 -0
- agent/loops/uiins.py +175 -0
- agent/loops/uitars.py +247 -206
- agent/loops/uitars2.py +951 -0
- agent/playground/__init__.py +5 -0
- agent/playground/server.py +301 -0
- agent/proxy/examples.py +61 -57
- agent/proxy/handlers.py +46 -39
- agent/responses.py +447 -347
- agent/tools/__init__.py +24 -0
- agent/tools/base.py +253 -0
- agent/tools/browser_tool.py +423 -0
- agent/types.py +11 -5
- agent/ui/__init__.py +1 -1
- agent/ui/__main__.py +1 -1
- agent/ui/gradio/app.py +25 -22
- agent/ui/gradio/ui_components.py +314 -167
- cua_agent-0.7.16.dist-info/METADATA +85 -0
- cua_agent-0.7.16.dist-info/RECORD +79 -0
- {cua_agent-0.4.22.dist-info → cua_agent-0.7.16.dist-info}/WHEEL +1 -1
- cua_agent-0.4.22.dist-info/METADATA +0 -436
- cua_agent-0.4.22.dist-info/RECORD +0 -51
- {cua_agent-0.4.22.dist-info → cua_agent-0.7.16.dist-info}/entry_points.txt +0 -0
agent/agent.py
CHANGED
@@ -3,75 +3,87 @@ ComputerAgent - Main agent class that selects and runs agent loops
 """
 
 import asyncio
+import inspect
+import json
 from pathlib import Path
-from typing import
-
-
-
-
-
-
-
-
+from typing import (
+    Any,
+    AsyncGenerator,
+    Callable,
+    Dict,
+    List,
+    Optional,
+    Set,
+    Tuple,
+    Union,
+    cast,
 )
-
-from .decorators import find_agent_config
-import json
+
 import litellm
 import litellm.utils
-import
+from litellm.responses.utils import Usage
+
 from .adapters import (
+    AzureMLAdapter,
+    CUAAdapter,
     HuggingFaceLocalAdapter,
     HumanAdapter,
     MLXVLMAdapter,
 )
 from .callbacks import (
-    ImageRetentionCallback,
-    LoggingCallback,
-    TrajectorySaverCallback,
     BudgetManagerCallback,
+    ImageRetentionCallback,
+    LoggingCallback,
+    OperatorNormalizerCallback,
+    OtelCallback,
+    PromptInstructionsCallback,
     TelemetryCallback,
-
+    TrajectorySaverCallback,
 )
-from .computers import
-
-
-
+from .computers import AsyncComputerHandler, is_agent_computer, make_computer_handler
+from .decorators import find_agent_config
+from .responses import (
+    make_tool_error_item,
+    replace_failed_computer_calls_with_function_calls,
 )
+from .tools.base import BaseComputerTool, BaseTool
+from .types import AgentCapability, IllegalArgumentError, Messages, ToolError
+
 
 def assert_callable_with(f, *args, **kwargs):
-
-
-
-
-
-
-
+    """Check if function can be called with given arguments."""
+    try:
+        inspect.signature(f).bind(*args, **kwargs)
+        return True
+    except TypeError as e:
+        sig = inspect.signature(f)
+        raise IllegalArgumentError(f"Expected {sig}, got args={args} kwargs={kwargs}") from e
+
 
 def get_json(obj: Any, max_depth: int = 10) -> Any:
     def custom_serializer(o: Any, depth: int = 0, seen: Optional[Set[int]] = None) -> Any:
         if seen is None:
             seen = set()
-
+
         # Use model_dump() if available
-        if hasattr(o,
+        if hasattr(o, "model_dump"):
            return o.model_dump()
-
+
         # Check depth limit
         if depth > max_depth:
             return f"<max_depth_exceeded:{max_depth}>"
-
+
         # Check for circular references using object id
         obj_id = id(o)
         if obj_id in seen:
             return f"<circular_reference:{type(o).__name__}>"
-
+
         # Handle Computer objects
-        if hasattr(o,
+        if hasattr(o, "__class__") and "computer" in o.__class__.__name__.lower():
             return f"<computer:{o.__class__.__name__}>"
 
         # Handle objects with __dict__
-        if hasattr(o,
+        if hasattr(o, "__dict__"):
             seen.add(obj_id)
             try:
                 result = {}
@@ -83,7 +95,7 @@ def get_json(obj: Any, max_depth: int = 10) -> Any:
                 return result
             finally:
                 seen.discard(obj_id)
-
+
         # Handle common types that might contain nested objects
         elif isinstance(o, dict):
             seen.add(obj_id)
@@ -95,7 +107,7 @@ def get_json(obj: Any, max_depth: int = 10) -> Any:
                 }
             finally:
                 seen.discard(obj_id)
-
+
         elif isinstance(o, (list, tuple, set)):
             seen.add(obj_id)
             try:
@@ -106,32 +118,33 @@ def get_json(obj: Any, max_depth: int = 10) -> Any:
                 ]
             finally:
                 seen.discard(obj_id)
-
+
         # For basic types that json.dumps can handle
         elif isinstance(o, (str, int, float, bool)) or o is None:
             return o
-
+
         # Fallback to string representation
         else:
             return str(o)
-
+
     def remove_nones(obj: Any) -> Any:
         if isinstance(obj, dict):
             return {k: remove_nones(v) for k, v in obj.items() if v is not None}
         elif isinstance(obj, list):
             return [remove_nones(item) for item in obj if item is not None]
         return obj
-
+
     # Serialize with circular reference and depth protection
     serialized = custom_serializer(obj)
-
+
    # Convert to JSON string and back to ensure JSON compatibility
    json_str = json.dumps(serialized)
    parsed = json.loads(json_str)
-
+
    # Final cleanup of any remaining None values
    return remove_nones(parsed)
 
+
 def sanitize_message(msg: Any) -> Any:
     """Return a copy of the message with image_url omitted for computer_call_output messages."""
     if msg.get("type") == "computer_call_output":
@@ -142,19 +155,24 @@ def sanitize_message(msg: Any) -> Any:
         return sanitized
     return msg
 
+
 def get_output_call_ids(messages: List[Dict[str, Any]]) -> List[str]:
     call_ids = []
     for message in messages:
-        if
+        if (
+            message.get("type") == "computer_call_output"
+            or message.get("type") == "function_call_output"
+        ):
             call_ids.append(message.get("call_id"))
     return call_ids
 
+
 class ComputerAgent:
     """
     Main agent class that automatically selects the appropriate agent loop
     based on the model and executes tool calls.
     """
-
+
     def __init__(
         self,
         model: str,
@@ -162,6 +180,7 @@ class ComputerAgent:
         custom_loop: Optional[Callable] = None,
         only_n_most_recent_images: Optional[int] = None,
         callbacks: Optional[List[Any]] = None,
+        instructions: Optional[str] = None,
         verbosity: Optional[int] = None,
         trajectory_dir: Optional[str | Path | dict] = None,
         max_retries: Optional[int] = 3,
@@ -169,17 +188,21 @@ class ComputerAgent:
         use_prompt_caching: Optional[bool] = False,
         max_trajectory_budget: Optional[float | dict] = None,
         telemetry_enabled: Optional[bool] = True,
-
+        trust_remote_code: Optional[bool] = False,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+        **additional_generation_kwargs,
     ):
         """
         Initialize ComputerAgent.
-
+
         Args:
-            model: Model name (e.g., "claude-
+            model: Model name (e.g., "claude-sonnet-4-5-20250929", "computer-use-preview", "omni+vertex_ai/gemini-pro")
             tools: List of tools (computer objects, decorated functions, etc.)
             custom_loop: Custom agent loop function to use instead of auto-selection
             only_n_most_recent_images: If set, only keep the N most recent images in message history. Adds ImageRetentionCallback automatically.
             callbacks: List of AsyncCallbackHandler instances for preprocessing/postprocessing
+            instructions: Optional system instructions to be passed to the model
             verbosity: Logging level (logging.DEBUG, logging.INFO, etc.). If set, adds LoggingCallback automatically
             trajectory_dir: If set, saves trajectory data (screenshots, responses) to this directory. Adds TrajectorySaverCallback automatically.
             max_retries: Maximum number of retries for failed API calls
@@ -187,36 +210,40 @@ class ComputerAgent:
             use_prompt_caching: If set, use prompt caching to avoid reprocessing the same prompt. Intended for use with anthropic providers.
             max_trajectory_budget: If set, adds BudgetManagerCallback to track usage costs and stop when budget is exceeded
             telemetry_enabled: If set, adds TelemetryCallback to track anonymized usage data. Enabled by default.
-
-
+            trust_remote_code: If set, trust remote code when loading local models. Disabled by default.
+            api_key: Optional API key override for the model provider
+            api_base: Optional API base URL override for the model provider
+            **additional_generation_kwargs: Additional arguments passed to the model provider
+        """
         # If the loop is "human/human", we need to prefix a grounding model fallback
         if model in ["human/human", "human"]:
             model = "openai/computer-use-preview+human/human"
-
+
         self.model = model
         self.tools = tools or []
         self.custom_loop = custom_loop
         self.only_n_most_recent_images = only_n_most_recent_images
         self.callbacks = callbacks or []
+        self.instructions = instructions
        self.verbosity = verbosity
        self.trajectory_dir = trajectory_dir
        self.max_retries = max_retries
        self.screenshot_delay = screenshot_delay
        self.use_prompt_caching = use_prompt_caching
        self.telemetry_enabled = telemetry_enabled
-        self.kwargs =
+        self.kwargs = additional_generation_kwargs
+        self.trust_remote_code = trust_remote_code
+        self.api_key = api_key
+        self.api_base = api_base
 
        # == Add built-in callbacks ==
 
        # Prepend operator normalizer callback
        self.callbacks.insert(0, OperatorNormalizerCallback())
 
-        # Add
-        if self.
-
-            self.callbacks.append(TelemetryCallback(self))
-        else:
-            self.callbacks.append(TelemetryCallback(self, **self.telemetry_enabled))
+        # Add prompt instructions callback if provided
+        if self.instructions:
+            self.callbacks.append(PromptInstructionsCallback(self.instructions))
 
        # Add logging callback if verbosity is set
        if self.verbosity is not None:
@@ -225,33 +252,37 @@ class ComputerAgent:
        # Add image retention callback if only_n_most_recent_images is set
        if self.only_n_most_recent_images:
            self.callbacks.append(ImageRetentionCallback(self.only_n_most_recent_images))
-
+
        # Add trajectory saver callback if trajectory_dir is set
        if self.trajectory_dir:
            if isinstance(self.trajectory_dir, dict):
                self.callbacks.append(TrajectorySaverCallback(**self.trajectory_dir))
            elif isinstance(self.trajectory_dir, (str, Path)):
                self.callbacks.append(TrajectorySaverCallback(str(self.trajectory_dir)))
-
+
        # Add budget manager if max_trajectory_budget is set
        if max_trajectory_budget:
            if isinstance(max_trajectory_budget, dict):
                self.callbacks.append(BudgetManagerCallback(**max_trajectory_budget))
            else:
                self.callbacks.append(BudgetManagerCallback(max_trajectory_budget))
-
+
        # == Enable local model providers w/ LiteLLM ==
 
        # Register local model providers
        hf_adapter = HuggingFaceLocalAdapter(
-            device="auto"
+            device="auto", trust_remote_code=self.trust_remote_code or False
        )
        human_adapter = HumanAdapter()
        mlx_adapter = MLXVLMAdapter()
+        cua_adapter = CUAAdapter()
+        azure_ml_adapter = AzureMLAdapter()
        litellm.custom_provider_map = [
            {"provider": "huggingface-local", "custom_handler": hf_adapter},
            {"provider": "human", "custom_handler": human_adapter},
-            {"provider": "mlx", "custom_handler": mlx_adapter}
+            {"provider": "mlx", "custom_handler": mlx_adapter},
+            {"provider": "cua", "custom_handler": cua_adapter},
+            {"provider": "azure_ml", "custom_handler": azure_ml_adapter},
        ]
        litellm.suppress_debug_info = True
 
@@ -268,24 +299,47 @@ class ComputerAgent:
        # Instantiate the agent config class
        self.agent_loop = config_info.agent_class()
        self.agent_config_info = config_info
-
+
+        # Add telemetry callbacks AFTER agent_loop is set so they can capture the correct agent_type
+        if self.telemetry_enabled:
+            # PostHog telemetry (product analytics)
+            if isinstance(self.telemetry_enabled, bool):
+                self.callbacks.append(TelemetryCallback(self))
+            else:
+                self.callbacks.append(TelemetryCallback(self, **self.telemetry_enabled))
+
+            # OpenTelemetry callback (operational metrics - Four Golden Signals)
+            # This is enabled alongside PostHog when telemetry_enabled is True
+            # Users can disable via CUA_TELEMETRY_DISABLED=true env var
+            self.callbacks.append(OtelCallback(self))
+
        self.tool_schemas = []
        self.computer_handler = None
-
+
    async def _initialize_computers(self):
        """Initialize computer objects"""
        if not self.tool_schemas:
            # Process tools and create tool schemas
            self.tool_schemas = self._process_tools()
-
+
            # Find computer tool and create interface adapter
            computer_handler = None
-
-
-
+
+            # First check if any tool is a BaseComputerTool instance
+            for tool in self.tools:
+                if isinstance(tool, BaseComputerTool):
+                    computer_handler = tool
                    break
+
+            # If no BaseComputerTool found, look for traditional computer objects
+            if computer_handler is None:
+                for schema in self.tool_schemas:
+                    if schema["type"] == "computer":
+                        computer_handler = await make_computer_handler(schema["computer"])
+                        break
+
            self.computer_handler = computer_handler
-
+
    def _process_input(self, input: Messages) -> List[Dict[str, Any]]:
        """Process input messages and create schemas for the agent loop"""
        if isinstance(input, str):
@@ -295,69 +349,85 @@ class ComputerAgent:
    def _process_tools(self) -> List[Dict[str, Any]]:
        """Process tools and create schemas for the agent loop"""
        schemas = []
-
+
        for tool in self.tools:
            # Check if it's a computer object (has interface attribute)
            if is_agent_computer(tool):
                # This is a computer tool - will be handled by agent loop
-                schemas.append({
-
-
-
+                schemas.append({"type": "computer", "computer": tool})
+            elif isinstance(tool, BaseTool):
+                # BaseTool instance - extract schema from its properties
+                function_schema = {
+                    "name": tool.name,
+                    "description": tool.description,
+                    "parameters": tool.parameters,
+                }
+                schemas.append({"type": "function", "function": function_schema})
            elif callable(tool):
                # Use litellm.utils.function_to_dict to extract schema from docstring
                try:
                    function_schema = litellm.utils.function_to_dict(tool)
-                    schemas.append({
-                        "type": "function",
-                        "function": function_schema
-                    })
+                    schemas.append({"type": "function", "function": function_schema})
                except Exception as e:
                    print(f"Warning: Could not process tool {tool}: {e}")
            else:
                print(f"Warning: Unknown tool type: {tool}")
-
+
        return schemas
-
-    def _get_tool(self, name: str) -> Optional[Callable]:
+
+    def _get_tool(self, name: str) -> Optional[Union[Callable, BaseTool]]:
        """Get a tool by name"""
        for tool in self.tools:
-            if
+            # Check if it's a BaseTool instance
+            if isinstance(tool, BaseTool) and tool.name == name:
+                return tool
+            # Check if it's a regular callable
+            elif hasattr(tool, "__name__") and tool.__name__ == name:
                return tool
-            elif hasattr(tool,
+            elif hasattr(tool, "func") and tool.func.__name__ == name:
                return tool
        return None
-
+
    # ============================================================================
    # AGENT RUN LOOP LIFECYCLE HOOKS
    # ============================================================================
-
+
    async def _on_run_start(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]]) -> None:
        """Initialize run tracking by calling callbacks."""
        for callback in self.callbacks:
-            if hasattr(callback,
+            if hasattr(callback, "on_run_start"):
                await callback.on_run_start(kwargs, old_items)
-
-    async def _on_run_end(
+
+    async def _on_run_end(
+        self,
+        kwargs: Dict[str, Any],
+        old_items: List[Dict[str, Any]],
+        new_items: List[Dict[str, Any]],
+    ) -> None:
        """Finalize run tracking by calling callbacks."""
        for callback in self.callbacks:
-            if hasattr(callback,
+            if hasattr(callback, "on_run_end"):
                await callback.on_run_end(kwargs, old_items, new_items)
-
-    async def _on_run_continue(
+
+    async def _on_run_continue(
+        self,
+        kwargs: Dict[str, Any],
+        old_items: List[Dict[str, Any]],
+        new_items: List[Dict[str, Any]],
+    ) -> bool:
        """Check if run should continue by calling callbacks."""
        for callback in self.callbacks:
-            if hasattr(callback,
+            if hasattr(callback, "on_run_continue"):
                should_continue = await callback.on_run_continue(kwargs, old_items, new_items)
                if not should_continue:
                    return False
        return True
-
+
    async def _on_llm_start(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Prepare messages for the LLM call by applying callbacks."""
        result = messages
        for callback in self.callbacks:
-            if hasattr(callback,
+            if hasattr(callback, "on_llm_start"):
                result = await callback.on_llm_start(result)
        return result
 
@@ -365,82 +435,91 @@ class ComputerAgent:
        """Postprocess messages after the LLM call by applying callbacks."""
        result = messages
        for callback in self.callbacks:
-            if hasattr(callback,
+            if hasattr(callback, "on_llm_end"):
                result = await callback.on_llm_end(result)
        return result
 
    async def _on_responses(self, kwargs: Dict[str, Any], responses: Dict[str, Any]) -> None:
        """Called when responses are received."""
        for callback in self.callbacks:
-            if hasattr(callback,
+            if hasattr(callback, "on_responses"):
                await callback.on_responses(get_json(kwargs), get_json(responses))
-
+
    async def _on_computer_call_start(self, item: Dict[str, Any]) -> None:
        """Called when a computer call is about to start."""
        for callback in self.callbacks:
-            if hasattr(callback,
+            if hasattr(callback, "on_computer_call_start"):
                await callback.on_computer_call_start(get_json(item))
-
-    async def _on_computer_call_end(
+
+    async def _on_computer_call_end(
+        self, item: Dict[str, Any], result: List[Dict[str, Any]]
+    ) -> None:
        """Called when a computer call has completed."""
        for callback in self.callbacks:
-            if hasattr(callback,
+            if hasattr(callback, "on_computer_call_end"):
                await callback.on_computer_call_end(get_json(item), get_json(result))
-
+
    async def _on_function_call_start(self, item: Dict[str, Any]) -> None:
        """Called when a function call is about to start."""
        for callback in self.callbacks:
-            if hasattr(callback,
+            if hasattr(callback, "on_function_call_start"):
                await callback.on_function_call_start(get_json(item))
-
-    async def _on_function_call_end(
+
+    async def _on_function_call_end(
+        self, item: Dict[str, Any], result: List[Dict[str, Any]]
+    ) -> None:
        """Called when a function call has completed."""
        for callback in self.callbacks:
-            if hasattr(callback,
+            if hasattr(callback, "on_function_call_end"):
                await callback.on_function_call_end(get_json(item), get_json(result))
-
+
    async def _on_text(self, item: Dict[str, Any]) -> None:
        """Called when a text message is encountered."""
        for callback in self.callbacks:
-            if hasattr(callback,
+            if hasattr(callback, "on_text"):
                await callback.on_text(get_json(item))
-
+
    async def _on_api_start(self, kwargs: Dict[str, Any]) -> None:
        """Called when an LLM API call is about to start."""
        for callback in self.callbacks:
-            if hasattr(callback,
+            if hasattr(callback, "on_api_start"):
                await callback.on_api_start(get_json(kwargs))
-
+
    async def _on_api_end(self, kwargs: Dict[str, Any], result: Any) -> None:
        """Called when an LLM API call has completed."""
        for callback in self.callbacks:
-            if hasattr(callback,
+            if hasattr(callback, "on_api_end"):
                await callback.on_api_end(get_json(kwargs), get_json(result))
 
    async def _on_usage(self, usage: Dict[str, Any]) -> None:
        """Called when usage information is received."""
        for callback in self.callbacks:
-            if hasattr(callback,
+            if hasattr(callback, "on_usage"):
                await callback.on_usage(get_json(usage))
 
    async def _on_screenshot(self, screenshot: Union[str, bytes], name: str = "screenshot") -> None:
        """Called when a screenshot is taken."""
        for callback in self.callbacks:
-            if hasattr(callback,
+            if hasattr(callback, "on_screenshot"):
                await callback.on_screenshot(screenshot, name)
 
    # ============================================================================
    # AGENT OUTPUT PROCESSING
    # ============================================================================
-
-    async def _handle_item(
+
+    async def _handle_item(
+        self,
+        item: Any,
+        computer: Optional[AsyncComputerHandler] = None,
+        ignore_call_ids: Optional[List[str]] = None,
+    ) -> List[Dict[str, Any]]:
        """Handle each item; may cause a computer action + screenshot."""
        call_id = item.get("call_id")
        if ignore_call_ids and call_id and call_id in ignore_call_ids:
            return []
-
+
        item_type = item.get("type", None)
-
+
        if item_type == "message":
            await self._on_text(item)
            # # Print messages
@@ -449,7 +528,7 @@ class ComputerAgent:
            #         if content_item.get("text"):
            #             print(content_item.get("text"))
            return []
-
+
        try:
            if item_type == "computer_call":
                await self._on_computer_call_start(item)
@@ -460,14 +539,16 @@ class ComputerAgent:
                action = item.get("action")
                action_type = action.get("type")
                if action_type is None:
-                    print(
+                    print(
+                        f"Action type cannot be `None`: action={action}, action_type={action_type}"
+                    )
                    return []
-
+
                # Extract action arguments (all fields except 'type')
                action_args = {k: v for k, v in action.items() if k != "type"}
-
+
                # print(f"{action_type}({action_args})")
-
+
                # Execute the computer action
                computer_method = getattr(computer, action_type, None)
                if computer_method:
@@ -475,13 +556,13 @@ class ComputerAgent:
                    await computer_method(**action_args)
                else:
                    raise ToolError(f"Unknown computer action: {action_type}")
-
+
                # Take screenshot after action
                if self.screenshot_delay and self.screenshot_delay > 0:
                    await asyncio.sleep(self.screenshot_delay)
                screenshot_base64 = await computer.screenshot()
                await self._on_screenshot(screenshot_base64, "screenshot_after")
-
+
                # Handle safety checks
                pending_checks = item.get("pending_safety_checks", [])
                acknowledged_checks = []
@@ -493,7 +574,7 @@ class ComputerAgent:
                #         acknowledged_checks.append(check)
                #     else:
                #         raise ValueError(f"Safety check failed: {check_message}")
-
+
                # Create call output
                call_output = {
                    "type": "computer_call_output",
@@ -504,43 +585,48 @@ class ComputerAgent:
                        "image_url": f"data:image/png;base64,{screenshot_base64}",
                    },
                }
-
+
                # # Additional URL safety checks for browser environments
                # if await computer.get_environment() == "browser":
                #     current_url = await computer.get_current_url()
                #     call_output["output"]["current_url"] = current_url
                #     # TODO: implement a callback for URL safety checks
                #     # check_blocklisted_url(current_url)
-
+
                result = [call_output]
                await self._on_computer_call_end(item, result)
                return result
-
+
            if item_type == "function_call":
                await self._on_function_call_start(item)
                # Perform function call
                function = self._get_tool(item.get("name"))
                if not function:
-                    raise ToolError(f"Function {item.get(
-
-                args = json.loads(item.get("arguments"))
+                    raise ToolError(f"Function {item.get('name')} not found")
 
-
-                assert_callable_with(function, **args)
+                args = json.loads(item.get("arguments"))
 
-                #
-                if
-
+                # Handle BaseTool instances
+                if isinstance(function, BaseTool):
+                    # BaseTool.call() handles its own execution
+                    result = function.call(args)
                else:
-
-
+                    # Validate arguments before execution for regular callables
+                    assert_callable_with(function, **args)
+
+                    # Execute function - use asyncio.to_thread for non-async functions
+                    if inspect.iscoroutinefunction(function):
+                        result = await function(**args)
+                    else:
+                        result = await asyncio.to_thread(function, **args)
+
                # Create function call output
                call_output = {
                    "type": "function_call_output",
                    "call_id": item.get("call_id"),
                    "output": str(result),
                }
-
+
                result = [call_output]
                await self._on_function_call_end(item, result)
                return result
@@ -552,36 +638,46 @@ class ComputerAgent:
    # ============================================================================
    # MAIN AGENT LOOP
    # ============================================================================
-
+
    async def run(
        self,
        messages: Messages,
        stream: bool = False,
-
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+        **additional_generation_kwargs,
    ) -> AsyncGenerator[Dict[str, Any], None]:
        """
        Run the agent with the given messages using Computer protocol handler pattern.
-
+
        Args:
            messages: List of message dictionaries
            stream: Whether to stream the response
-
-
+            api_key: Optional API key override for the model provider
+            api_base: Optional API base URL override for the model provider
+            **additional_generation_kwargs: Additional arguments passed to the model provider
+
        Returns:
            AsyncGenerator that yields response chunks
        """
        if not self.agent_config_info:
            raise ValueError("Agent configuration not found")
-
+
        capabilities = self.get_capabilities()
        if "step" not in capabilities:
-            raise ValueError(
+            raise ValueError(
+                f"Agent loop {self.agent_config_info.agent_class.__name__} does not support step predictions"
+            )
 
        await self._initialize_computers()
-
-        # Merge kwargs
-        merged_kwargs = {**self.kwargs, **
-
+
+        # Merge kwargs and thread api credentials (run overrides constructor)
+        merged_kwargs = {**self.kwargs, **additional_generation_kwargs}
+        if (api_key is not None) or (self.api_key is not None):
+            merged_kwargs["api_key"] = api_key if api_key is not None else self.api_key
+        if (api_base is not None) or (self.api_base is not None):
+            merged_kwargs["api_base"] = api_base if api_base is not None else self.api_base
+
        old_items = self._process_input(messages)
        new_items = []
 
@@ -591,7 +687,7 @@ class ComputerAgent:
            "stream": stream,
            "model": self.model,
            "agent_loop": self.agent_config_info.agent_class.__name__,
-            **merged_kwargs
+            **merged_kwargs,
        }
        await self._on_run_start(run_kwargs, old_items)
 
@@ -608,7 +704,7 @@ class ComputerAgent:
            combined_messages = old_items + new_items
            combined_messages = replace_failed_computer_calls_with_function_calls(combined_messages)
            preprocessed_messages = await self._on_llm_start(combined_messages)
-
+
            loop_kwargs = {
                "messages": preprocessed_messages,
                "model": self.model,
@@ -617,9 +713,39 @@ class ComputerAgent:
                "computer_handler": self.computer_handler,
                "max_retries": self.max_retries,
                "use_prompt_caching": self.use_prompt_caching,
-                **merged_kwargs
+                **merged_kwargs,
            }
 
+            # ---- Ollama image input guard ----
+            if isinstance(self.model, str) and (
+                "ollama/" in self.model or "ollama_chat/" in self.model
+            ):
+
+                def contains_image_content(msgs):
+                    for m in msgs:
+                        # 1️⃣ Check regular message content
+                        content = m.get("content")
+                        if isinstance(content, list):
+                            for item in content:
+                                if isinstance(item, dict) and item.get("type") == "image_url":
+                                    return True
+
+                        # 2️⃣ Check computer_call_output screenshots
+                        if m.get("type") == "computer_call_output":
+                            output = m.get("output", {})
+                            if output.get("type") == "input_image" and "image_url" in output:
+                                return True
+
+                    return False
+
+                if contains_image_content(preprocessed_messages):
+                    raise ValueError(
+                        "Ollama models do not support image inputs required by ComputerAgent. "
+                        "Please use a vision-capable model (e.g., OpenAI or Anthropic) "
+                        "or remove computer/screenshot actions."
+                    )
+            # ---------------------------------
+
            # Run agent loop iteration
            result = await self.agent_loop.predict_step(
                **loop_kwargs,
@@ -629,13 +755,13 @@ class ComputerAgent:
                _on_screenshot=self._on_screenshot,
            )
            result = get_json(result)
-
+
            # Lifecycle hook: Postprocess messages after the LLM call
            # Use cases:
            # - PII deanonymization (if you want tool calls to see PII)
            result["output"] = await self._on_llm_end(result.get("output", []))
            await self._on_responses(loop_kwargs, result)
-
+
            # Yield agent response
            yield result
 
@@ -647,64 +773,90 @@ class ComputerAgent:
 
            # Handle computer actions
            for item in result.get("output"):
-                partial_items = await self._handle_item(
+                partial_items = await self._handle_item(
+                    item, self.computer_handler, ignore_call_ids=output_call_ids
+                )
                new_items += partial_items
 
-                # Yield partial response
-
-
-
-
-
-
-
-
-
+                # Yield partial response if any
+                if partial_items:
+                    yield {
+                        "output": partial_items,
+                        "usage": Usage(
+                            prompt_tokens=0,
+                            completion_tokens=0,
+                            total_tokens=0,
+                        ),
+                    }
+
        await self._on_run_end(loop_kwargs, old_items, new_items)
-
+
    async def predict_click(
-        self,
-        instruction: str,
-        image_b64: Optional[str] = None
+        self, instruction: str, image_b64: Optional[str] = None
    ) -> Optional[Tuple[int, int]]:
        """
        Predict click coordinates based on image and instruction.
-
+
        Args:
            instruction: Instruction for where to click
            image_b64: Base64 encoded image (optional, will take screenshot if not provided)
-
+
        Returns:
            None or tuple with (x, y) coordinates
        """
        if not self.agent_config_info:
            raise ValueError("Agent configuration not found")
-
+
        capabilities = self.get_capabilities()
        if "click" not in capabilities:
-            raise ValueError(
-
+            raise ValueError(
+                f"Agent loop {self.agent_config_info.agent_class.__name__} does not support click predictions"
+            )
+        if hasattr(self.agent_loop, "predict_click"):
            if not image_b64:
                if not self.computer_handler:
                    raise ValueError("Computer tool or image_b64 is required for predict_click")
                image_b64 = await self.computer_handler.screenshot()
+            # Pass along api credentials if available
+            click_kwargs: Dict[str, Any] = {}
+            if self.api_key is not None:
+                click_kwargs["api_key"] = self.api_key
+            if self.api_base is not None:
+                click_kwargs["api_base"] = self.api_base
            return await self.agent_loop.predict_click(
-                model=self.model,
-                image_b64=image_b64,
-                instruction=instruction
+                model=self.model, image_b64=image_b64, instruction=instruction, **click_kwargs
            )
        return None
-
+
    def get_capabilities(self) -> List[AgentCapability]:
        """
        Get list of capabilities supported by the current agent config.
-
+
        Returns:
            List of capability strings (e.g., ["step", "click"])
        """
        if not self.agent_config_info:
            raise ValueError("Agent configuration not found")
-
-        if hasattr(self.agent_loop,
+
+        if hasattr(self.agent_loop, "get_capabilities"):
            return self.agent_loop.get_capabilities()
-        return ["step"] # Default capability
+        return ["step"]  # Default capability
+
+    def open(self, port: Optional[int] = None):
+        """
+        Start the playground server and open it in the browser.
+
+        This method starts a local HTTP server that exposes the /responses endpoint
+        and automatically opens the Cua playground interface in the default browser.
+
+        Args:
+            port: Port to run the server on. If None, finds an available port automatically.
+
+        Example:
+            >>> agent = ComputerAgent(model="claude-sonnet-4")
+            >>> agent.open()  # Starts server and opens browser
+        """
+        from .playground import PlaygroundServer
+
+        server = PlaygroundServer(agent_instance=self)
+        server.start(port=port, open_browser=True)