cua-agent 0.4.14__py3-none-any.whl → 0.7.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of cua-agent might be problematic.

Files changed (82)
  1. agent/__init__.py +4 -19
  2. agent/__main__.py +2 -1
  3. agent/adapters/__init__.py +6 -0
  4. agent/adapters/azure_ml_adapter.py +283 -0
  5. agent/adapters/cua_adapter.py +161 -0
  6. agent/adapters/huggingfacelocal_adapter.py +67 -125
  7. agent/adapters/human_adapter.py +116 -114
  8. agent/adapters/mlxvlm_adapter.py +370 -0
  9. agent/adapters/models/__init__.py +41 -0
  10. agent/adapters/models/generic.py +78 -0
  11. agent/adapters/models/internvl.py +290 -0
  12. agent/adapters/models/opencua.py +115 -0
  13. agent/adapters/models/qwen2_5_vl.py +78 -0
  14. agent/agent.py +431 -241
  15. agent/callbacks/__init__.py +10 -3
  16. agent/callbacks/base.py +45 -31
  17. agent/callbacks/budget_manager.py +22 -10
  18. agent/callbacks/image_retention.py +54 -98
  19. agent/callbacks/logging.py +55 -42
  20. agent/callbacks/operator_validator.py +140 -0
  21. agent/callbacks/otel.py +291 -0
  22. agent/callbacks/pii_anonymization.py +19 -16
  23. agent/callbacks/prompt_instructions.py +47 -0
  24. agent/callbacks/telemetry.py +106 -69
  25. agent/callbacks/trajectory_saver.py +178 -70
  26. agent/cli.py +269 -119
  27. agent/computers/__init__.py +14 -9
  28. agent/computers/base.py +32 -19
  29. agent/computers/cua.py +52 -25
  30. agent/computers/custom.py +78 -71
  31. agent/decorators.py +23 -14
  32. agent/human_tool/__init__.py +2 -7
  33. agent/human_tool/__main__.py +6 -2
  34. agent/human_tool/server.py +48 -37
  35. agent/human_tool/ui.py +359 -235
  36. agent/integrations/hud/__init__.py +164 -74
  37. agent/integrations/hud/agent.py +338 -342
  38. agent/integrations/hud/proxy.py +297 -0
  39. agent/loops/__init__.py +44 -14
  40. agent/loops/anthropic.py +590 -492
  41. agent/loops/base.py +19 -15
  42. agent/loops/composed_grounded.py +142 -144
  43. agent/loops/fara/__init__.py +8 -0
  44. agent/loops/fara/config.py +506 -0
  45. agent/loops/fara/helpers.py +357 -0
  46. agent/loops/fara/schema.py +143 -0
  47. agent/loops/gelato.py +183 -0
  48. agent/loops/gemini.py +935 -0
  49. agent/loops/generic_vlm.py +601 -0
  50. agent/loops/glm45v.py +140 -135
  51. agent/loops/gta1.py +48 -51
  52. agent/loops/holo.py +218 -0
  53. agent/loops/internvl.py +180 -0
  54. agent/loops/moondream3.py +493 -0
  55. agent/loops/omniparser.py +326 -226
  56. agent/loops/openai.py +63 -56
  57. agent/loops/opencua.py +134 -0
  58. agent/loops/uiins.py +175 -0
  59. agent/loops/uitars.py +262 -212
  60. agent/loops/uitars2.py +951 -0
  61. agent/playground/__init__.py +5 -0
  62. agent/playground/server.py +301 -0
  63. agent/proxy/examples.py +196 -0
  64. agent/proxy/handlers.py +255 -0
  65. agent/responses.py +486 -339
  66. agent/tools/__init__.py +24 -0
  67. agent/tools/base.py +253 -0
  68. agent/tools/browser_tool.py +423 -0
  69. agent/types.py +20 -5
  70. agent/ui/__init__.py +1 -1
  71. agent/ui/__main__.py +1 -1
  72. agent/ui/gradio/app.py +25 -22
  73. agent/ui/gradio/ui_components.py +314 -167
  74. cua_agent-0.7.16.dist-info/METADATA +85 -0
  75. cua_agent-0.7.16.dist-info/RECORD +79 -0
  76. {cua_agent-0.4.14.dist-info → cua_agent-0.7.16.dist-info}/WHEEL +1 -1
  77. agent/integrations/hud/adapter.py +0 -121
  78. agent/integrations/hud/computer_handler.py +0 -187
  79. agent/telemetry.py +0 -142
  80. cua_agent-0.4.14.dist-info/METADATA +0 -436
  81. cua_agent-0.4.14.dist-info/RECORD +0 -50
  82. {cua_agent-0.4.14.dist-info → cua_agent-0.7.16.dist-info}/entry_points.txt +0 -0
agent/agent.py CHANGED
@@ -3,57 +3,87 @@ ComputerAgent - Main agent class that selects and runs agent loops
 """

 import asyncio
-from typing import Dict, List, Any, Optional, AsyncGenerator, Union, cast, Callable, Set, Tuple
-
-from litellm.responses.utils import Usage
-
-from .types import Messages, AgentCapability
-from .decorators import find_agent_config
+import inspect
 import json
+from pathlib import Path
+from typing import (
+    Any,
+    AsyncGenerator,
+    Callable,
+    Dict,
+    List,
+    Optional,
+    Set,
+    Tuple,
+    Union,
+    cast,
+)
+
 import litellm
 import litellm.utils
-import inspect
+from litellm.responses.utils import Usage
+
 from .adapters import (
+    AzureMLAdapter,
+    CUAAdapter,
     HuggingFaceLocalAdapter,
     HumanAdapter,
+    MLXVLMAdapter,
 )
 from .callbacks import (
-    ImageRetentionCallback,
-    LoggingCallback,
-    TrajectorySaverCallback,
     BudgetManagerCallback,
+    ImageRetentionCallback,
+    LoggingCallback,
+    OperatorNormalizerCallback,
+    OtelCallback,
+    PromptInstructionsCallback,
     TelemetryCallback,
+    TrajectorySaverCallback,
 )
-from .computers import (
-    AsyncComputerHandler,
-    is_agent_computer,
-    make_computer_handler
+from .computers import AsyncComputerHandler, is_agent_computer, make_computer_handler
+from .decorators import find_agent_config
+from .responses import (
+    make_tool_error_item,
+    replace_failed_computer_calls_with_function_calls,
 )
+from .tools.base import BaseComputerTool, BaseTool
+from .types import AgentCapability, IllegalArgumentError, Messages, ToolError
+
+
+def assert_callable_with(f, *args, **kwargs):
+    """Check if function can be called with given arguments."""
+    try:
+        inspect.signature(f).bind(*args, **kwargs)
+        return True
+    except TypeError as e:
+        sig = inspect.signature(f)
+        raise IllegalArgumentError(f"Expected {sig}, got args={args} kwargs={kwargs}") from e
+

 def get_json(obj: Any, max_depth: int = 10) -> Any:
     def custom_serializer(o: Any, depth: int = 0, seen: Optional[Set[int]] = None) -> Any:
         if seen is None:
             seen = set()
-
+
         # Use model_dump() if available
-        if hasattr(o, 'model_dump'):
+        if hasattr(o, "model_dump"):
             return o.model_dump()
-
+
         # Check depth limit
         if depth > max_depth:
             return f"<max_depth_exceeded:{max_depth}>"
-
+
         # Check for circular references using object id
         obj_id = id(o)
         if obj_id in seen:
             return f"<circular_reference:{type(o).__name__}>"
-
+
         # Handle Computer objects
-        if hasattr(o, '__class__') and 'computer' in getattr(o, '__class__').__name__.lower():
+        if hasattr(o, "__class__") and "computer" in o.__class__.__name__.lower():
             return f"<computer:{o.__class__.__name__}>"

         # Handle objects with __dict__
-        if hasattr(o, '__dict__'):
+        if hasattr(o, "__dict__"):
             seen.add(obj_id)
             try:
                 result = {}
@@ -65,7 +95,7 @@ def get_json(obj: Any, max_depth: int = 10) -> Any:
                 return result
             finally:
                 seen.discard(obj_id)
-
+
         # Handle common types that might contain nested objects
         elif isinstance(o, dict):
             seen.add(obj_id)
@@ -77,7 +107,7 @@ def get_json(obj: Any, max_depth: int = 10) -> Any:
                 }
             finally:
                 seen.discard(obj_id)
-
+
         elif isinstance(o, (list, tuple, set)):
             seen.add(obj_id)
             try:
@@ -88,32 +118,33 @@ def get_json(obj: Any, max_depth: int = 10) -> Any:
                 ]
             finally:
                 seen.discard(obj_id)
-
+
         # For basic types that json.dumps can handle
         elif isinstance(o, (str, int, float, bool)) or o is None:
             return o
-
+
         # Fallback to string representation
         else:
             return str(o)
-
+
     def remove_nones(obj: Any) -> Any:
         if isinstance(obj, dict):
             return {k: remove_nones(v) for k, v in obj.items() if v is not None}
         elif isinstance(obj, list):
             return [remove_nones(item) for item in obj if item is not None]
         return obj
-
+
     # Serialize with circular reference and depth protection
     serialized = custom_serializer(obj)
-
+
     # Convert to JSON string and back to ensure JSON compatibility
     json_str = json.dumps(serialized)
     parsed = json.loads(json_str)
-
+
     # Final cleanup of any remaining None values
     return remove_nones(parsed)

+
 def sanitize_message(msg: Any) -> Any:
     """Return a copy of the message with image_url omitted for computer_call_output messages."""
     if msg.get("type") == "computer_call_output":
@@ -124,19 +155,24 @@ def sanitize_message(msg: Any) -> Any:
         return sanitized
     return msg

+
 def get_output_call_ids(messages: List[Dict[str, Any]]) -> List[str]:
     call_ids = []
     for message in messages:
-        if message.get("type") == "computer_call_output" or message.get("type") == "function_call_output":
+        if (
+            message.get("type") == "computer_call_output"
+            or message.get("type") == "function_call_output"
+        ):
             call_ids.append(message.get("call_id"))
     return call_ids

+
 class ComputerAgent:
     """
     Main agent class that automatically selects the appropriate agent loop
     based on the model and executes tool calls.
     """
-
+
     def __init__(
         self,
         model: str,
@@ -144,24 +180,29 @@ class ComputerAgent:
         custom_loop: Optional[Callable] = None,
         only_n_most_recent_images: Optional[int] = None,
         callbacks: Optional[List[Any]] = None,
+        instructions: Optional[str] = None,
         verbosity: Optional[int] = None,
-        trajectory_dir: Optional[str] = None,
+        trajectory_dir: Optional[str | Path | dict] = None,
         max_retries: Optional[int] = 3,
         screenshot_delay: Optional[float | int] = 0.5,
         use_prompt_caching: Optional[bool] = False,
         max_trajectory_budget: Optional[float | dict] = None,
         telemetry_enabled: Optional[bool] = True,
-        **kwargs
+        trust_remote_code: Optional[bool] = False,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+        **additional_generation_kwargs,
     ):
         """
         Initialize ComputerAgent.
-
+
         Args:
-            model: Model name (e.g., "claude-3-5-sonnet-20241022", "computer-use-preview", "omni+vertex_ai/gemini-pro")
+            model: Model name (e.g., "claude-sonnet-4-5-20250929", "computer-use-preview", "omni+vertex_ai/gemini-pro")
             tools: List of tools (computer objects, decorated functions, etc.)
             custom_loop: Custom agent loop function to use instead of auto-selection
             only_n_most_recent_images: If set, only keep the N most recent images in message history. Adds ImageRetentionCallback automatically.
             callbacks: List of AsyncCallbackHandler instances for preprocessing/postprocessing
+            instructions: Optional system instructions to be passed to the model
             verbosity: Logging level (logging.DEBUG, logging.INFO, etc.). If set, adds LoggingCallback automatically
             trajectory_dir: If set, saves trajectory data (screenshots, responses) to this directory. Adds TrajectorySaverCallback automatically.
             max_retries: Maximum number of retries for failed API calls
@@ -169,29 +210,40 @@ class ComputerAgent:
             use_prompt_caching: If set, use prompt caching to avoid reprocessing the same prompt. Intended for use with anthropic providers.
             max_trajectory_budget: If set, adds BudgetManagerCallback to track usage costs and stop when budget is exceeded
             telemetry_enabled: If set, adds TelemetryCallback to track anonymized usage data. Enabled by default.
-            **kwargs: Additional arguments passed to the agent loop
+            trust_remote_code: If set, trust remote code when loading local models. Disabled by default.
+            api_key: Optional API key override for the model provider
+            api_base: Optional API base URL override for the model provider
+            **additional_generation_kwargs: Additional arguments passed to the model provider
         """
+        # If the loop is "human/human", we need to prefix a grounding model fallback
+        if model in ["human/human", "human"]:
+            model = "openai/computer-use-preview+human/human"
+
         self.model = model
         self.tools = tools or []
         self.custom_loop = custom_loop
         self.only_n_most_recent_images = only_n_most_recent_images
         self.callbacks = callbacks or []
+        self.instructions = instructions
         self.verbosity = verbosity
         self.trajectory_dir = trajectory_dir
         self.max_retries = max_retries
         self.screenshot_delay = screenshot_delay
         self.use_prompt_caching = use_prompt_caching
         self.telemetry_enabled = telemetry_enabled
-        self.kwargs = kwargs
+        self.kwargs = additional_generation_kwargs
+        self.trust_remote_code = trust_remote_code
+        self.api_key = api_key
+        self.api_base = api_base

         # == Add built-in callbacks ==

-        # Add telemetry callback if telemetry_enabled is set
-        if self.telemetry_enabled:
-            if isinstance(self.telemetry_enabled, bool):
-                self.callbacks.append(TelemetryCallback(self))
-            else:
-                self.callbacks.append(TelemetryCallback(self, **self.telemetry_enabled))
+        # Prepend operator normalizer callback
+        self.callbacks.insert(0, OperatorNormalizerCallback())
+
+        # Add prompt instructions callback if provided
+        if self.instructions:
+            self.callbacks.append(PromptInstructionsCallback(self.instructions))

         # Add logging callback if verbosity is set
         if self.verbosity is not None:
@@ -200,28 +252,37 @@ class ComputerAgent:
         # Add image retention callback if only_n_most_recent_images is set
         if self.only_n_most_recent_images:
             self.callbacks.append(ImageRetentionCallback(self.only_n_most_recent_images))
-
+
         # Add trajectory saver callback if trajectory_dir is set
         if self.trajectory_dir:
-            self.callbacks.append(TrajectorySaverCallback(self.trajectory_dir))
-
+            if isinstance(self.trajectory_dir, dict):
+                self.callbacks.append(TrajectorySaverCallback(**self.trajectory_dir))
+            elif isinstance(self.trajectory_dir, (str, Path)):
+                self.callbacks.append(TrajectorySaverCallback(str(self.trajectory_dir)))
+
         # Add budget manager if max_trajectory_budget is set
         if max_trajectory_budget:
             if isinstance(max_trajectory_budget, dict):
                 self.callbacks.append(BudgetManagerCallback(**max_trajectory_budget))
             else:
                 self.callbacks.append(BudgetManagerCallback(max_trajectory_budget))
-
+
         # == Enable local model providers w/ LiteLLM ==

         # Register local model providers
         hf_adapter = HuggingFaceLocalAdapter(
-            device="auto"
+            device="auto", trust_remote_code=self.trust_remote_code or False
         )
         human_adapter = HumanAdapter()
+        mlx_adapter = MLXVLMAdapter()
+        cua_adapter = CUAAdapter()
+        azure_ml_adapter = AzureMLAdapter()
         litellm.custom_provider_map = [
             {"provider": "huggingface-local", "custom_handler": hf_adapter},
-            {"provider": "human", "custom_handler": human_adapter}
+            {"provider": "human", "custom_handler": human_adapter},
+            {"provider": "mlx", "custom_handler": mlx_adapter},
+            {"provider": "cua", "custom_handler": cua_adapter},
+            {"provider": "azure_ml", "custom_handler": azure_ml_adapter},
         ]
         litellm.suppress_debug_info = True

@@ -238,24 +299,47 @@ class ComputerAgent:
         # Instantiate the agent config class
         self.agent_loop = config_info.agent_class()
         self.agent_config_info = config_info
-
+
+        # Add telemetry callbacks AFTER agent_loop is set so they can capture the correct agent_type
+        if self.telemetry_enabled:
+            # PostHog telemetry (product analytics)
+            if isinstance(self.telemetry_enabled, bool):
+                self.callbacks.append(TelemetryCallback(self))
+            else:
+                self.callbacks.append(TelemetryCallback(self, **self.telemetry_enabled))
+
+            # OpenTelemetry callback (operational metrics - Four Golden Signals)
+            # This is enabled alongside PostHog when telemetry_enabled is True
+            # Users can disable via CUA_TELEMETRY_DISABLED=true env var
+            self.callbacks.append(OtelCallback(self))
+
         self.tool_schemas = []
         self.computer_handler = None
-
+
     async def _initialize_computers(self):
         """Initialize computer objects"""
         if not self.tool_schemas:
             # Process tools and create tool schemas
             self.tool_schemas = self._process_tools()
-
+
             # Find computer tool and create interface adapter
             computer_handler = None
-            for schema in self.tool_schemas:
-                if schema["type"] == "computer":
-                    computer_handler = await make_computer_handler(schema["computer"])
+
+            # First check if any tool is a BaseComputerTool instance
+            for tool in self.tools:
+                if isinstance(tool, BaseComputerTool):
+                    computer_handler = tool
                     break
+
+            # If no BaseComputerTool found, look for traditional computer objects
+            if computer_handler is None:
+                for schema in self.tool_schemas:
+                    if schema["type"] == "computer":
+                        computer_handler = await make_computer_handler(schema["computer"])
+                        break
+
             self.computer_handler = computer_handler
-
+
     def _process_input(self, input: Messages) -> List[Dict[str, Any]]:
         """Process input messages and create schemas for the agent loop"""
         if isinstance(input, str):
@@ -265,69 +349,85 @@ class ComputerAgent:
     def _process_tools(self) -> List[Dict[str, Any]]:
         """Process tools and create schemas for the agent loop"""
         schemas = []
-
+
         for tool in self.tools:
             # Check if it's a computer object (has interface attribute)
             if is_agent_computer(tool):
                 # This is a computer tool - will be handled by agent loop
-                schemas.append({
-                    "type": "computer",
-                    "computer": tool
-                })
+                schemas.append({"type": "computer", "computer": tool})
+            elif isinstance(tool, BaseTool):
+                # BaseTool instance - extract schema from its properties
+                function_schema = {
+                    "name": tool.name,
+                    "description": tool.description,
+                    "parameters": tool.parameters,
+                }
+                schemas.append({"type": "function", "function": function_schema})
             elif callable(tool):
                 # Use litellm.utils.function_to_dict to extract schema from docstring
                 try:
                     function_schema = litellm.utils.function_to_dict(tool)
-                    schemas.append({
-                        "type": "function",
-                        "function": function_schema
-                    })
+                    schemas.append({"type": "function", "function": function_schema})
                 except Exception as e:
                     print(f"Warning: Could not process tool {tool}: {e}")
             else:
                 print(f"Warning: Unknown tool type: {tool}")
-
+
         return schemas
-
-    def _get_tool(self, name: str) -> Optional[Callable]:
+
+    def _get_tool(self, name: str) -> Optional[Union[Callable, BaseTool]]:
        """Get a tool by name"""
        for tool in self.tools:
-            if hasattr(tool, '__name__') and tool.__name__ == name:
+            # Check if it's a BaseTool instance
+            if isinstance(tool, BaseTool) and tool.name == name:
                return tool
-            elif hasattr(tool, 'func') and tool.func.__name__ == name:
+            # Check if it's a regular callable
+            elif hasattr(tool, "__name__") and tool.__name__ == name:
+                return tool
+            elif hasattr(tool, "func") and tool.func.__name__ == name:
                return tool
        return None
-
+
    # ============================================================================
    # AGENT RUN LOOP LIFECYCLE HOOKS
    # ============================================================================
-
+
    async def _on_run_start(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]]) -> None:
        """Initialize run tracking by calling callbacks."""
        for callback in self.callbacks:
-            if hasattr(callback, 'on_run_start'):
+            if hasattr(callback, "on_run_start"):
                await callback.on_run_start(kwargs, old_items)
-
-    async def _on_run_end(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]], new_items: List[Dict[str, Any]]) -> None:
+
+    async def _on_run_end(
+        self,
+        kwargs: Dict[str, Any],
+        old_items: List[Dict[str, Any]],
+        new_items: List[Dict[str, Any]],
+    ) -> None:
        """Finalize run tracking by calling callbacks."""
        for callback in self.callbacks:
-            if hasattr(callback, 'on_run_end'):
+            if hasattr(callback, "on_run_end"):
                await callback.on_run_end(kwargs, old_items, new_items)
-
-    async def _on_run_continue(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]], new_items: List[Dict[str, Any]]) -> bool:
+
+    async def _on_run_continue(
+        self,
+        kwargs: Dict[str, Any],
+        old_items: List[Dict[str, Any]],
+        new_items: List[Dict[str, Any]],
+    ) -> bool:
        """Check if run should continue by calling callbacks."""
        for callback in self.callbacks:
-            if hasattr(callback, 'on_run_continue'):
+            if hasattr(callback, "on_run_continue"):
                should_continue = await callback.on_run_continue(kwargs, old_items, new_items)
                if not should_continue:
                    return False
        return True
-
+
    async def _on_llm_start(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Prepare messages for the LLM call by applying callbacks."""
        result = messages
        for callback in self.callbacks:
-            if hasattr(callback, 'on_llm_start'):
+            if hasattr(callback, "on_llm_start"):
                result = await callback.on_llm_start(result)
        return result

@@ -335,81 +435,91 @@ class ComputerAgent:
        """Postprocess messages after the LLM call by applying callbacks."""
        result = messages
        for callback in self.callbacks:
-            if hasattr(callback, 'on_llm_end'):
+            if hasattr(callback, "on_llm_end"):
                result = await callback.on_llm_end(result)
        return result

    async def _on_responses(self, kwargs: Dict[str, Any], responses: Dict[str, Any]) -> None:
        """Called when responses are received."""
        for callback in self.callbacks:
-            if hasattr(callback, 'on_responses'):
+            if hasattr(callback, "on_responses"):
                await callback.on_responses(get_json(kwargs), get_json(responses))
-
+
    async def _on_computer_call_start(self, item: Dict[str, Any]) -> None:
        """Called when a computer call is about to start."""
        for callback in self.callbacks:
-            if hasattr(callback, 'on_computer_call_start'):
+            if hasattr(callback, "on_computer_call_start"):
                await callback.on_computer_call_start(get_json(item))
-
-    async def _on_computer_call_end(self, item: Dict[str, Any], result: List[Dict[str, Any]]) -> None:
+
+    async def _on_computer_call_end(
+        self, item: Dict[str, Any], result: List[Dict[str, Any]]
+    ) -> None:
        """Called when a computer call has completed."""
        for callback in self.callbacks:
-            if hasattr(callback, 'on_computer_call_end'):
+            if hasattr(callback, "on_computer_call_end"):
                await callback.on_computer_call_end(get_json(item), get_json(result))
-
+
    async def _on_function_call_start(self, item: Dict[str, Any]) -> None:
        """Called when a function call is about to start."""
        for callback in self.callbacks:
-            if hasattr(callback, 'on_function_call_start'):
+            if hasattr(callback, "on_function_call_start"):
                await callback.on_function_call_start(get_json(item))
-
-    async def _on_function_call_end(self, item: Dict[str, Any], result: List[Dict[str, Any]]) -> None:
+
+    async def _on_function_call_end(
+        self, item: Dict[str, Any], result: List[Dict[str, Any]]
+    ) -> None:
        """Called when a function call has completed."""
        for callback in self.callbacks:
-            if hasattr(callback, 'on_function_call_end'):
+            if hasattr(callback, "on_function_call_end"):
                await callback.on_function_call_end(get_json(item), get_json(result))
-
+
    async def _on_text(self, item: Dict[str, Any]) -> None:
        """Called when a text message is encountered."""
        for callback in self.callbacks:
-            if hasattr(callback, 'on_text'):
+            if hasattr(callback, "on_text"):
                await callback.on_text(get_json(item))
-
+
    async def _on_api_start(self, kwargs: Dict[str, Any]) -> None:
        """Called when an LLM API call is about to start."""
        for callback in self.callbacks:
-            if hasattr(callback, 'on_api_start'):
+            if hasattr(callback, "on_api_start"):
                await callback.on_api_start(get_json(kwargs))
-
+
    async def _on_api_end(self, kwargs: Dict[str, Any], result: Any) -> None:
        """Called when an LLM API call has completed."""
        for callback in self.callbacks:
-            if hasattr(callback, 'on_api_end'):
+            if hasattr(callback, "on_api_end"):
                await callback.on_api_end(get_json(kwargs), get_json(result))

    async def _on_usage(self, usage: Dict[str, Any]) -> None:
        """Called when usage information is received."""
        for callback in self.callbacks:
-            if hasattr(callback, 'on_usage'):
+            if hasattr(callback, "on_usage"):
                await callback.on_usage(get_json(usage))

    async def _on_screenshot(self, screenshot: Union[str, bytes], name: str = "screenshot") -> None:
        """Called when a screenshot is taken."""
        for callback in self.callbacks:
-            if hasattr(callback, 'on_screenshot'):
+            if hasattr(callback, "on_screenshot"):
                await callback.on_screenshot(screenshot, name)

    # ============================================================================
    # AGENT OUTPUT PROCESSING
    # ============================================================================
-
-    async def _handle_item(self, item: Any, computer: Optional[AsyncComputerHandler] = None, ignore_call_ids: Optional[List[str]] = None) -> List[Dict[str, Any]]:
+
+    async def _handle_item(
+        self,
+        item: Any,
+        computer: Optional[AsyncComputerHandler] = None,
+        ignore_call_ids: Optional[List[str]] = None,
+    ) -> List[Dict[str, Any]]:
        """Handle each item; may cause a computer action + screenshot."""
-        if ignore_call_ids and item.get("call_id") and item.get("call_id") in ignore_call_ids:
+        call_id = item.get("call_id")
+        if ignore_call_ids and call_id and call_id in ignore_call_ids:
            return []
-
+
        item_type = item.get("type", None)
-
+
        if item_type == "message":
            await self._on_text(item)
            # # Print messages
@@ -418,133 +528,156 @@ class ComputerAgent:
            #     if content_item.get("text"):
            #         print(content_item.get("text"))
            return []
-
-        if item_type == "computer_call":
-            await self._on_computer_call_start(item)
-            if not computer:
-                raise ValueError("Computer handler is required for computer calls")
-
-            # Perform computer actions
-            action = item.get("action")
-            action_type = action.get("type")
-            if action_type is None:
-                print(f"Action type cannot be `None`: action={action}, action_type={action_type}")
-                return []
-
-            # Extract action arguments (all fields except 'type')
-            action_args = {k: v for k, v in action.items() if k != "type"}
-
-            # print(f"{action_type}({action_args})")
-
-            # Execute the computer action
-            computer_method = getattr(computer, action_type, None)
-            if computer_method:
-                await computer_method(**action_args)
-            else:
-                print(f"Unknown computer action: {action_type}")
-                return []
-
-            # Take screenshot after action
-            if self.screenshot_delay and self.screenshot_delay > 0:
-                await asyncio.sleep(self.screenshot_delay)
-            screenshot_base64 = await computer.screenshot()
-            await self._on_screenshot(screenshot_base64, "screenshot_after")
-
-            # Handle safety checks
-            pending_checks = item.get("pending_safety_checks", [])
-            acknowledged_checks = []
-            for check in pending_checks:
-                check_message = check.get("message", str(check))
-                acknowledged_checks.append(check)
-                # TODO: implement a callback for safety checks
-                # if acknowledge_safety_check_callback(check_message, allow_always=True):
-                #     acknowledged_checks.append(check)
-                # else:
-                #     raise ValueError(f"Safety check failed: {check_message}")
-
-            # Create call output
-            call_output = {
-                "type": "computer_call_output",
-                "call_id": item.get("call_id"),
-                "acknowledged_safety_checks": acknowledged_checks,
-                "output": {
-                    "type": "input_image",
-                    "image_url": f"data:image/png;base64,{screenshot_base64}",
-                },
-            }
-
-            # # Additional URL safety checks for browser environments
-            # if await computer.get_environment() == "browser":
-            #     current_url = await computer.get_current_url()
-            #     call_output["output"]["current_url"] = current_url
-            #     # TODO: implement a callback for URL safety checks
-            #     # check_blocklisted_url(current_url)
-
-            result = [call_output]
-            await self._on_computer_call_end(item, result)
-            return result
-
-        if item_type == "function_call":
-            await self._on_function_call_start(item)
-            # Perform function call
-            function = self._get_tool(item.get("name"))
-            if not function:
-                raise ValueError(f"Function {item.get("name")} not found")
-
-            args = json.loads(item.get("arguments"))
-
-            # Execute function - use asyncio.to_thread for non-async functions
-            if inspect.iscoroutinefunction(function):
-                result = await function(**args)
-            else:
-                result = await asyncio.to_thread(function, **args)
-
-            # Create function call output
-            call_output = {
-                "type": "function_call_output",
-                "call_id": item.get("call_id"),
-                "output": str(result),
-            }
-
-            result = [call_output]
-            await self._on_function_call_end(item, result)
-            return result
+
+        try:
+            if item_type == "computer_call":
+                await self._on_computer_call_start(item)
+                if not computer:
+                    raise ValueError("Computer handler is required for computer calls")
+
+                # Perform computer actions
+                action = item.get("action")
+                action_type = action.get("type")
+                if action_type is None:
+                    print(
+                        f"Action type cannot be `None`: action={action}, action_type={action_type}"
+                    )
+                    return []
+
+                # Extract action arguments (all fields except 'type')
+                action_args = {k: v for k, v in action.items() if k != "type"}
+
+                # print(f"{action_type}({action_args})")
+
+                # Execute the computer action
+                computer_method = getattr(computer, action_type, None)
+                if computer_method:
+                    assert_callable_with(computer_method, **action_args)
+                    await computer_method(**action_args)
+                else:
+                    raise ToolError(f"Unknown computer action: {action_type}")
+
+                # Take screenshot after action
+                if self.screenshot_delay and self.screenshot_delay > 0:
+                    await asyncio.sleep(self.screenshot_delay)
+                screenshot_base64 = await computer.screenshot()
+                await self._on_screenshot(screenshot_base64, "screenshot_after")
+
+                # Handle safety checks
+                pending_checks = item.get("pending_safety_checks", [])
+                acknowledged_checks = []
+                for check in pending_checks:
+                    check_message = check.get("message", str(check))
+                    acknowledged_checks.append(check)
+                    # TODO: implement a callback for safety checks
+                    # if acknowledge_safety_check_callback(check_message, allow_always=True):
+                    #     acknowledged_checks.append(check)
+                    # else:
+                    #     raise ValueError(f"Safety check failed: {check_message}")
+
+                # Create call output
+                call_output = {
+                    "type": "computer_call_output",
+                    "call_id": item.get("call_id"),
+                    "acknowledged_safety_checks": acknowledged_checks,
+                    "output": {
+                        "type": "input_image",
+                        "image_url": f"data:image/png;base64,{screenshot_base64}",
+                    },
+                }
+
+                # # Additional URL safety checks for browser environments
+                # if await computer.get_environment() == "browser":
+                #     current_url = await computer.get_current_url()
+                #     call_output["output"]["current_url"] = current_url
+                #     # TODO: implement a callback for URL safety checks
+                #     # check_blocklisted_url(current_url)
+
+                result = [call_output]
+                await self._on_computer_call_end(item, result)
+                return result
+
+            if item_type == "function_call":
+                await self._on_function_call_start(item)
+                # Perform function call
+                function = self._get_tool(item.get("name"))
+                if not function:
+                    raise ToolError(f"Function {item.get('name')} not found")
+
+                args = json.loads(item.get("arguments"))
+
+                # Handle BaseTool instances
+                if isinstance(function, BaseTool):
+                    # BaseTool.call() handles its own execution
+                    result = function.call(args)
+                else:
+                    # Validate arguments before execution for regular callables
+                    assert_callable_with(function, **args)
+
+                    # Execute function - use asyncio.to_thread for non-async functions
+                    if inspect.iscoroutinefunction(function):
+                        result = await function(**args)
+                    else:
+                        result = await asyncio.to_thread(function, **args)
+
+                # Create function call output
+                call_output = {
+                    "type": "function_call_output",
+                    "call_id": item.get("call_id"),
+                    "output": str(result),
+                }
+
+                result = [call_output]
+                await self._on_function_call_end(item, result)
+                return result
+        except ToolError as e:
+            return [make_tool_error_item(repr(e), call_id)]

        return []

    # ============================================================================
    # MAIN AGENT LOOP
    # ============================================================================
-
+
    async def run(
        self,
        messages: Messages,
        stream: bool = False,
-        **kwargs
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+        **additional_generation_kwargs,
    ) -> AsyncGenerator[Dict[str, Any], None]:
        """
        Run the agent with the given messages using Computer protocol handler pattern.
-
+
        Args:
            messages: List of message dictionaries
            stream: Whether to stream the response
-            **kwargs: Additional arguments
-
+            api_key: Optional API key override for the model provider
+            api_base: Optional API base URL override for the model provider
+            **additional_generation_kwargs: Additional arguments passed to the model provider
+
        Returns:
            AsyncGenerator that yields response chunks
        """
        if not self.agent_config_info:
            raise ValueError("Agent configuration not found")
-
+
        capabilities = self.get_capabilities()
        if "step" not in capabilities:
-            raise ValueError(f"Agent loop {self.agent_config_info.agent_class.__name__} does not support step predictions")
+            raise ValueError(
+                f"Agent loop {self.agent_config_info.agent_class.__name__} does not support step predictions"
+            )

        await self._initialize_computers()
-
-        # Merge kwargs
-        merged_kwargs = {**self.kwargs, **kwargs}
-
+
+        # Merge kwargs and thread api credentials (run overrides constructor)
+        merged_kwargs = {**self.kwargs, **additional_generation_kwargs}
+        if (api_key is not None) or (self.api_key is not None):
+            merged_kwargs["api_key"] = api_key if api_key is not None else self.api_key
+        if (api_base is not None) or (self.api_base is not None):
+            merged_kwargs["api_base"] = api_base if api_base is not None else self.api_base
+
        old_items = self._process_input(messages)
        new_items = []

@@ -554,7 +687,7 @@ class ComputerAgent:
            "stream": stream,
            "model": self.model,
            "agent_loop": self.agent_config_info.agent_class.__name__,
-            **merged_kwargs
+            **merged_kwargs,
        }
        await self._on_run_start(run_kwargs, old_items)

@@ -569,8 +702,9 @@ class ComputerAgent:
            # - PII anonymization
            # - Image retention policy
            combined_messages = old_items + new_items
+            combined_messages = replace_failed_computer_calls_with_function_calls(combined_messages)
            preprocessed_messages = await self._on_llm_start(combined_messages)
-
+
            loop_kwargs = {
                "messages": preprocessed_messages,
                "model": self.model,
@@ -579,9 +713,39 @@ class ComputerAgent:
                "computer_handler": self.computer_handler,
                "max_retries": self.max_retries,
                "use_prompt_caching": self.use_prompt_caching,
-                **merged_kwargs
+                **merged_kwargs,
            }

+            # ---- Ollama image input guard ----
+            if isinstance(self.model, str) and (
+                "ollama/" in self.model or "ollama_chat/" in self.model
+            ):
+
+                def contains_image_content(msgs):
+                    for m in msgs:
+                        # 1️⃣ Check regular message content
+                        content = m.get("content")
+                        if isinstance(content, list):
+                            for item in content:
+                                if isinstance(item, dict) and item.get("type") == "image_url":
+                                    return True
+
+                        # 2️⃣ Check computer_call_output screenshots
+                        if m.get("type") == "computer_call_output":
+                            output = m.get("output", {})
+                            if output.get("type") == "input_image" and "image_url" in output:
+                                return True
+
+                    return False
+
+                if contains_image_content(preprocessed_messages):
+                    raise ValueError(
+                        "Ollama models do not support image inputs required by ComputerAgent. "
+                        "Please use a vision-capable model (e.g., OpenAI or Anthropic) "
+                        "or remove computer/screenshot actions."
+                    )
+            # ---------------------------------
+
            # Run agent loop iteration
            result = await self.agent_loop.predict_step(
                **loop_kwargs,
@@ -591,13 +755,13 @@ class ComputerAgent:
                _on_screenshot=self._on_screenshot,
            )
            result = get_json(result)
-
+
            # Lifecycle hook: Postprocess messages after the LLM call
            # Use cases:
            # - PII deanonymization (if you want tool calls to see PII)
            result["output"] = await self._on_llm_end(result.get("output", []))
            await self._on_responses(loop_kwargs, result)
-
+
            # Yield agent response
            yield result

@@ -609,64 +773,90 @@ class ComputerAgent:

            # Handle computer actions
            for item in result.get("output"):
-                partial_items = await self._handle_item(item, self.computer_handler, ignore_call_ids=output_call_ids)
+                partial_items = await self._handle_item(
+                    item, self.computer_handler, ignore_call_ids=output_call_ids
+                )
                new_items += partial_items

-                # Yield partial response
-                yield {
-                    "output": partial_items,
-                    "usage": Usage(
-                        prompt_tokens=0,
-                        completion_tokens=0,
-                        total_tokens=0,
-                    )
-                }
-
+                # Yield partial response if any
+                if partial_items:
+                    yield {
+                        "output": partial_items,
+                        "usage": Usage(
+                            prompt_tokens=0,
+                            completion_tokens=0,
+                            total_tokens=0,
+                        ),
+                    }
+
        await self._on_run_end(loop_kwargs, old_items, new_items)
-
+
    async def predict_click(
-        self,
-        instruction: str,
-        image_b64: Optional[str] = None
+        self, instruction: str, image_b64: Optional[str] = None
    ) -> Optional[Tuple[int, int]]:
        """
        Predict click coordinates based on image and instruction.
-
+
        Args:
            instruction: Instruction for where to click
            image_b64: Base64 encoded image (optional, will take screenshot if not provided)
-
+
        Returns:
            None or tuple with (x, y) coordinates
        """
        if not self.agent_config_info:
            raise ValueError("Agent configuration not found")
-
+
        capabilities = self.get_capabilities()
        if "click" not in capabilities:
-            raise ValueError(f"Agent loop {self.agent_config_info.agent_class.__name__} does not support click predictions")
-        if hasattr(self.agent_loop, 'predict_click'):
+            raise ValueError(
+                f"Agent loop {self.agent_config_info.agent_class.__name__} does not support click predictions"
+            )
+        if hasattr(self.agent_loop, "predict_click"):
            if not image_b64:
                if not self.computer_handler:
                    raise ValueError("Computer tool or image_b64 is required for predict_click")
                image_b64 = await self.computer_handler.screenshot()
+            # Pass along api credentials if available
+            click_kwargs: Dict[str, Any] = {}
+            if self.api_key is not None:
+                click_kwargs["api_key"] = self.api_key
+            if self.api_base is not None:
+                click_kwargs["api_base"] = self.api_base
            return await self.agent_loop.predict_click(
-                model=self.model,
-                image_b64=image_b64,
-                instruction=instruction
+                model=self.model, image_b64=image_b64, instruction=instruction, **click_kwargs
            )
        return None
-
+
    def get_capabilities(self) -> List[AgentCapability]:
        """
        Get list of capabilities supported by the current agent config.
-
+
        Returns:
            List of capability strings (e.g., ["step", "click"])
        """
        if not self.agent_config_info:
            raise ValueError("Agent configuration not found")
-
-        if hasattr(self.agent_loop, 'get_capabilities'):
+
+        if hasattr(self.agent_loop, "get_capabilities"):
            return self.agent_loop.get_capabilities()
-        return ["step"] # Default capability
+        return ["step"]  # Default capability
+
+    def open(self, port: Optional[int] = None):
+        """
+        Start the playground server and open it in the browser.
+
+        This method starts a local HTTP server that exposes the /responses endpoint
+        and automatically opens the Cua playground interface in the default browser.
+
+        Args:
+            port: Port to run the server on. If None, finds an available port automatically.
+
+        Example:
+            >>> agent = ComputerAgent(model="claude-sonnet-4")
+            >>> agent.open() # Starts server and opens browser
+        """
+        from .playground import PlaygroundServer
+
+        server = PlaygroundServer(agent_instance=self)
+        server.start(port=port, open_browser=True)
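
For orientation, here is a minimal usage sketch of the 0.7.x surface introduced by this diff: the new constructor options (instructions, trajectory_dir accepting str | Path | dict, max_trajectory_budget, api_key/api_base), the run() async generator with per-call credential overrides, and the new open() playground helper. This is an illustration, not code from the package: the `from agent import ComputerAgent` import path is assumed from the wheel's `agent/` layout, the model string is copied from the __init__ docstring above, and `search_files` is a hypothetical function tool.

import asyncio

from agent import ComputerAgent  # assumed import path for the installed `agent` package


async def search_files(query: str) -> str:
    """Hypothetical function tool; plain callables are converted to schemas from their docstrings."""
    return f"results for {query!r}"


async def main():
    agent = ComputerAgent(
        model="claude-sonnet-4-5-20250929",  # example model name from the __init__ docstring
        tools=[search_files],
        instructions="Prefer keyboard shortcuts over clicking through menus.",
        trajectory_dir="trajectories",       # str | Path | dict accepted as of 0.7.x
        max_trajectory_budget=5.0,           # adds BudgetManagerCallback
        api_key="sk-PLACEHOLDER",            # optional provider override, new in 0.7.x
    )

    # run() yields response chunks; api_key/api_base passed here override the constructor values.
    async for chunk in agent.run("Summarize the open windows", stream=False):
        for item in chunk.get("output", []):
            print(item.get("type"))

    # Only available when the selected loop reports the "click" capability:
    # coords = await agent.predict_click(instruction="the Save button", image_b64="<base64 PNG>")

    # New in 0.7.x: start the local playground server and open it in the browser.
    # agent.open()


asyncio.run(main())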