cua-agent: cua_agent-0.4.22-py3-none-any.whl → cua_agent-0.7.16-py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of cua-agent might be problematic. Click here for more details.

Files changed (79):
  1. agent/__init__.py +4 -10
  2. agent/__main__.py +2 -1
  3. agent/adapters/__init__.py +4 -0
  4. agent/adapters/azure_ml_adapter.py +283 -0
  5. agent/adapters/cua_adapter.py +161 -0
  6. agent/adapters/huggingfacelocal_adapter.py +67 -125
  7. agent/adapters/human_adapter.py +116 -114
  8. agent/adapters/mlxvlm_adapter.py +110 -99
  9. agent/adapters/models/__init__.py +41 -0
  10. agent/adapters/models/generic.py +78 -0
  11. agent/adapters/models/internvl.py +290 -0
  12. agent/adapters/models/opencua.py +115 -0
  13. agent/adapters/models/qwen2_5_vl.py +78 -0
  14. agent/agent.py +337 -185
  15. agent/callbacks/__init__.py +9 -4
  16. agent/callbacks/base.py +45 -31
  17. agent/callbacks/budget_manager.py +22 -10
  18. agent/callbacks/image_retention.py +54 -98
  19. agent/callbacks/logging.py +55 -42
  20. agent/callbacks/operator_validator.py +35 -33
  21. agent/callbacks/otel.py +291 -0
  22. agent/callbacks/pii_anonymization.py +19 -16
  23. agent/callbacks/prompt_instructions.py +47 -0
  24. agent/callbacks/telemetry.py +99 -61
  25. agent/callbacks/trajectory_saver.py +95 -69
  26. agent/cli.py +269 -119
  27. agent/computers/__init__.py +14 -9
  28. agent/computers/base.py +32 -19
  29. agent/computers/cua.py +52 -25
  30. agent/computers/custom.py +78 -71
  31. agent/decorators.py +23 -14
  32. agent/human_tool/__init__.py +2 -7
  33. agent/human_tool/__main__.py +6 -2
  34. agent/human_tool/server.py +48 -37
  35. agent/human_tool/ui.py +359 -235
  36. agent/integrations/hud/__init__.py +38 -99
  37. agent/integrations/hud/agent.py +369 -0
  38. agent/integrations/hud/proxy.py +166 -52
  39. agent/loops/__init__.py +44 -14
  40. agent/loops/anthropic.py +579 -492
  41. agent/loops/base.py +19 -15
  42. agent/loops/composed_grounded.py +136 -150
  43. agent/loops/fara/__init__.py +8 -0
  44. agent/loops/fara/config.py +506 -0
  45. agent/loops/fara/helpers.py +357 -0
  46. agent/loops/fara/schema.py +143 -0
  47. agent/loops/gelato.py +183 -0
  48. agent/loops/gemini.py +935 -0
  49. agent/loops/generic_vlm.py +601 -0
  50. agent/loops/glm45v.py +140 -135
  51. agent/loops/gta1.py +48 -51
  52. agent/loops/holo.py +218 -0
  53. agent/loops/internvl.py +180 -0
  54. agent/loops/moondream3.py +493 -0
  55. agent/loops/omniparser.py +326 -226
  56. agent/loops/openai.py +50 -51
  57. agent/loops/opencua.py +134 -0
  58. agent/loops/uiins.py +175 -0
  59. agent/loops/uitars.py +247 -206
  60. agent/loops/uitars2.py +951 -0
  61. agent/playground/__init__.py +5 -0
  62. agent/playground/server.py +301 -0
  63. agent/proxy/examples.py +61 -57
  64. agent/proxy/handlers.py +46 -39
  65. agent/responses.py +447 -347
  66. agent/tools/__init__.py +24 -0
  67. agent/tools/base.py +253 -0
  68. agent/tools/browser_tool.py +423 -0
  69. agent/types.py +11 -5
  70. agent/ui/__init__.py +1 -1
  71. agent/ui/__main__.py +1 -1
  72. agent/ui/gradio/app.py +25 -22
  73. agent/ui/gradio/ui_components.py +314 -167
  74. cua_agent-0.7.16.dist-info/METADATA +85 -0
  75. cua_agent-0.7.16.dist-info/RECORD +79 -0
  76. {cua_agent-0.4.22.dist-info → cua_agent-0.7.16.dist-info}/WHEEL +1 -1
  77. cua_agent-0.4.22.dist-info/METADATA +0 -436
  78. cua_agent-0.4.22.dist-info/RECORD +0 -51
  79. {cua_agent-0.4.22.dist-info → cua_agent-0.7.16.dist-info}/entry_points.txt +0 -0
agent/agent.py CHANGED
@@ -3,75 +3,87 @@ ComputerAgent - Main agent class that selects and runs agent loops
3
3
  """
4
4
 
5
5
  import asyncio
6
+ import inspect
7
+ import json
6
8
  from pathlib import Path
7
- from typing import Dict, List, Any, Optional, AsyncGenerator, Union, cast, Callable, Set, Tuple
8
-
9
- from litellm.responses.utils import Usage
10
-
11
- from .types import (
12
- Messages,
13
- AgentCapability,
14
- ToolError,
15
- IllegalArgumentError
9
+ from typing import (
10
+ Any,
11
+ AsyncGenerator,
12
+ Callable,
13
+ Dict,
14
+ List,
15
+ Optional,
16
+ Set,
17
+ Tuple,
18
+ Union,
19
+ cast,
16
20
  )
17
- from .responses import make_tool_error_item, replace_failed_computer_calls_with_function_calls
18
- from .decorators import find_agent_config
19
- import json
21
+
20
22
  import litellm
21
23
  import litellm.utils
22
- import inspect
24
+ from litellm.responses.utils import Usage
25
+
23
26
  from .adapters import (
27
+ AzureMLAdapter,
28
+ CUAAdapter,
24
29
  HuggingFaceLocalAdapter,
25
30
  HumanAdapter,
26
31
  MLXVLMAdapter,
27
32
  )
28
33
  from .callbacks import (
29
- ImageRetentionCallback,
30
- LoggingCallback,
31
- TrajectorySaverCallback,
32
34
  BudgetManagerCallback,
35
+ ImageRetentionCallback,
36
+ LoggingCallback,
37
+ OperatorNormalizerCallback,
38
+ OtelCallback,
39
+ PromptInstructionsCallback,
33
40
  TelemetryCallback,
34
- OperatorNormalizerCallback
41
+ TrajectorySaverCallback,
35
42
  )
36
- from .computers import (
37
- AsyncComputerHandler,
38
- is_agent_computer,
39
- make_computer_handler
43
+ from .computers import AsyncComputerHandler, is_agent_computer, make_computer_handler
44
+ from .decorators import find_agent_config
45
+ from .responses import (
46
+ make_tool_error_item,
47
+ replace_failed_computer_calls_with_function_calls,
40
48
  )
49
+ from .tools.base import BaseComputerTool, BaseTool
50
+ from .types import AgentCapability, IllegalArgumentError, Messages, ToolError
51
+
41
52
 
42
53
  def assert_callable_with(f, *args, **kwargs):
43
- """Check if function can be called with given arguments."""
44
- try:
45
- inspect.signature(f).bind(*args, **kwargs)
46
- return True
47
- except TypeError as e:
48
- sig = inspect.signature(f)
49
- raise IllegalArgumentError(f"Expected {sig}, got args={args} kwargs={kwargs}") from e
54
+ """Check if function can be called with given arguments."""
55
+ try:
56
+ inspect.signature(f).bind(*args, **kwargs)
57
+ return True
58
+ except TypeError as e:
59
+ sig = inspect.signature(f)
60
+ raise IllegalArgumentError(f"Expected {sig}, got args={args} kwargs={kwargs}") from e
61
+
50
62
 
51
63
  def get_json(obj: Any, max_depth: int = 10) -> Any:
52
64
  def custom_serializer(o: Any, depth: int = 0, seen: Optional[Set[int]] = None) -> Any:
53
65
  if seen is None:
54
66
  seen = set()
55
-
67
+
56
68
  # Use model_dump() if available
57
- if hasattr(o, 'model_dump'):
69
+ if hasattr(o, "model_dump"):
58
70
  return o.model_dump()
59
-
71
+
60
72
  # Check depth limit
61
73
  if depth > max_depth:
62
74
  return f"<max_depth_exceeded:{max_depth}>"
63
-
75
+
64
76
  # Check for circular references using object id
65
77
  obj_id = id(o)
66
78
  if obj_id in seen:
67
79
  return f"<circular_reference:{type(o).__name__}>"
68
-
80
+
69
81
  # Handle Computer objects
70
- if hasattr(o, '__class__') and 'computer' in getattr(o, '__class__').__name__.lower():
82
+ if hasattr(o, "__class__") and "computer" in o.__class__.__name__.lower():
71
83
  return f"<computer:{o.__class__.__name__}>"
72
84
 
73
85
  # Handle objects with __dict__
74
- if hasattr(o, '__dict__'):
86
+ if hasattr(o, "__dict__"):
75
87
  seen.add(obj_id)
76
88
  try:
77
89
  result = {}
@@ -83,7 +95,7 @@ def get_json(obj: Any, max_depth: int = 10) -> Any:
83
95
  return result
84
96
  finally:
85
97
  seen.discard(obj_id)
86
-
98
+
87
99
  # Handle common types that might contain nested objects
88
100
  elif isinstance(o, dict):
89
101
  seen.add(obj_id)
@@ -95,7 +107,7 @@ def get_json(obj: Any, max_depth: int = 10) -> Any:
95
107
  }
96
108
  finally:
97
109
  seen.discard(obj_id)
98
-
110
+
99
111
  elif isinstance(o, (list, tuple, set)):
100
112
  seen.add(obj_id)
101
113
  try:
@@ -106,32 +118,33 @@ def get_json(obj: Any, max_depth: int = 10) -> Any:
106
118
  ]
107
119
  finally:
108
120
  seen.discard(obj_id)
109
-
121
+
110
122
  # For basic types that json.dumps can handle
111
123
  elif isinstance(o, (str, int, float, bool)) or o is None:
112
124
  return o
113
-
125
+
114
126
  # Fallback to string representation
115
127
  else:
116
128
  return str(o)
117
-
129
+
118
130
  def remove_nones(obj: Any) -> Any:
119
131
  if isinstance(obj, dict):
120
132
  return {k: remove_nones(v) for k, v in obj.items() if v is not None}
121
133
  elif isinstance(obj, list):
122
134
  return [remove_nones(item) for item in obj if item is not None]
123
135
  return obj
124
-
136
+
125
137
  # Serialize with circular reference and depth protection
126
138
  serialized = custom_serializer(obj)
127
-
139
+
128
140
  # Convert to JSON string and back to ensure JSON compatibility
129
141
  json_str = json.dumps(serialized)
130
142
  parsed = json.loads(json_str)
131
-
143
+
132
144
  # Final cleanup of any remaining None values
133
145
  return remove_nones(parsed)
134
146
 
147
+
135
148
  def sanitize_message(msg: Any) -> Any:
136
149
  """Return a copy of the message with image_url omitted for computer_call_output messages."""
137
150
  if msg.get("type") == "computer_call_output":
@@ -142,19 +155,24 @@ def sanitize_message(msg: Any) -> Any:
142
155
  return sanitized
143
156
  return msg
144
157
 
158
+
145
159
  def get_output_call_ids(messages: List[Dict[str, Any]]) -> List[str]:
146
160
  call_ids = []
147
161
  for message in messages:
148
- if message.get("type") == "computer_call_output" or message.get("type") == "function_call_output":
162
+ if (
163
+ message.get("type") == "computer_call_output"
164
+ or message.get("type") == "function_call_output"
165
+ ):
149
166
  call_ids.append(message.get("call_id"))
150
167
  return call_ids
151
168
 
169
+
152
170
  class ComputerAgent:
153
171
  """
154
172
  Main agent class that automatically selects the appropriate agent loop
155
173
  based on the model and executes tool calls.
156
174
  """
157
-
175
+
158
176
  def __init__(
159
177
  self,
160
178
  model: str,
@@ -162,6 +180,7 @@ class ComputerAgent:
162
180
  custom_loop: Optional[Callable] = None,
163
181
  only_n_most_recent_images: Optional[int] = None,
164
182
  callbacks: Optional[List[Any]] = None,
183
+ instructions: Optional[str] = None,
165
184
  verbosity: Optional[int] = None,
166
185
  trajectory_dir: Optional[str | Path | dict] = None,
167
186
  max_retries: Optional[int] = 3,
@@ -169,17 +188,21 @@ class ComputerAgent:
169
188
  use_prompt_caching: Optional[bool] = False,
170
189
  max_trajectory_budget: Optional[float | dict] = None,
171
190
  telemetry_enabled: Optional[bool] = True,
172
- **kwargs
191
+ trust_remote_code: Optional[bool] = False,
192
+ api_key: Optional[str] = None,
193
+ api_base: Optional[str] = None,
194
+ **additional_generation_kwargs,
173
195
  ):
174
196
  """
175
197
  Initialize ComputerAgent.
176
-
198
+
177
199
  Args:
178
- model: Model name (e.g., "claude-3-5-sonnet-20241022", "computer-use-preview", "omni+vertex_ai/gemini-pro")
200
+ model: Model name (e.g., "claude-sonnet-4-5-20250929", "computer-use-preview", "omni+vertex_ai/gemini-pro")
179
201
  tools: List of tools (computer objects, decorated functions, etc.)
180
202
  custom_loop: Custom agent loop function to use instead of auto-selection
181
203
  only_n_most_recent_images: If set, only keep the N most recent images in message history. Adds ImageRetentionCallback automatically.
182
204
  callbacks: List of AsyncCallbackHandler instances for preprocessing/postprocessing
205
+ instructions: Optional system instructions to be passed to the model
183
206
  verbosity: Logging level (logging.DEBUG, logging.INFO, etc.). If set, adds LoggingCallback automatically
184
207
  trajectory_dir: If set, saves trajectory data (screenshots, responses) to this directory. Adds TrajectorySaverCallback automatically.
185
208
  max_retries: Maximum number of retries for failed API calls
@@ -187,36 +210,40 @@ class ComputerAgent:
187
210
  use_prompt_caching: If set, use prompt caching to avoid reprocessing the same prompt. Intended for use with anthropic providers.
188
211
  max_trajectory_budget: If set, adds BudgetManagerCallback to track usage costs and stop when budget is exceeded
189
212
  telemetry_enabled: If set, adds TelemetryCallback to track anonymized usage data. Enabled by default.
190
- **kwargs: Additional arguments passed to the agent loop
191
- """
213
+ trust_remote_code: If set, trust remote code when loading local models. Disabled by default.
214
+ api_key: Optional API key override for the model provider
215
+ api_base: Optional API base URL override for the model provider
216
+ **additional_generation_kwargs: Additional arguments passed to the model provider
217
+ """
192
218
  # If the loop is "human/human", we need to prefix a grounding model fallback
193
219
  if model in ["human/human", "human"]:
194
220
  model = "openai/computer-use-preview+human/human"
195
-
221
+
196
222
  self.model = model
197
223
  self.tools = tools or []
198
224
  self.custom_loop = custom_loop
199
225
  self.only_n_most_recent_images = only_n_most_recent_images
200
226
  self.callbacks = callbacks or []
227
+ self.instructions = instructions
201
228
  self.verbosity = verbosity
202
229
  self.trajectory_dir = trajectory_dir
203
230
  self.max_retries = max_retries
204
231
  self.screenshot_delay = screenshot_delay
205
232
  self.use_prompt_caching = use_prompt_caching
206
233
  self.telemetry_enabled = telemetry_enabled
207
- self.kwargs = kwargs
234
+ self.kwargs = additional_generation_kwargs
235
+ self.trust_remote_code = trust_remote_code
236
+ self.api_key = api_key
237
+ self.api_base = api_base
208
238
 
209
239
  # == Add built-in callbacks ==
210
240
 
211
241
  # Prepend operator normalizer callback
212
242
  self.callbacks.insert(0, OperatorNormalizerCallback())
213
243
 
214
- # Add telemetry callback if telemetry_enabled is set
215
- if self.telemetry_enabled:
216
- if isinstance(self.telemetry_enabled, bool):
217
- self.callbacks.append(TelemetryCallback(self))
218
- else:
219
- self.callbacks.append(TelemetryCallback(self, **self.telemetry_enabled))
244
+ # Add prompt instructions callback if provided
245
+ if self.instructions:
246
+ self.callbacks.append(PromptInstructionsCallback(self.instructions))
220
247
 
221
248
  # Add logging callback if verbosity is set
222
249
  if self.verbosity is not None:
@@ -225,33 +252,37 @@ class ComputerAgent:
225
252
  # Add image retention callback if only_n_most_recent_images is set
226
253
  if self.only_n_most_recent_images:
227
254
  self.callbacks.append(ImageRetentionCallback(self.only_n_most_recent_images))
228
-
255
+
229
256
  # Add trajectory saver callback if trajectory_dir is set
230
257
  if self.trajectory_dir:
231
258
  if isinstance(self.trajectory_dir, dict):
232
259
  self.callbacks.append(TrajectorySaverCallback(**self.trajectory_dir))
233
260
  elif isinstance(self.trajectory_dir, (str, Path)):
234
261
  self.callbacks.append(TrajectorySaverCallback(str(self.trajectory_dir)))
235
-
262
+
236
263
  # Add budget manager if max_trajectory_budget is set
237
264
  if max_trajectory_budget:
238
265
  if isinstance(max_trajectory_budget, dict):
239
266
  self.callbacks.append(BudgetManagerCallback(**max_trajectory_budget))
240
267
  else:
241
268
  self.callbacks.append(BudgetManagerCallback(max_trajectory_budget))
242
-
269
+
243
270
  # == Enable local model providers w/ LiteLLM ==
244
271
 
245
272
  # Register local model providers
246
273
  hf_adapter = HuggingFaceLocalAdapter(
247
- device="auto"
274
+ device="auto", trust_remote_code=self.trust_remote_code or False
248
275
  )
249
276
  human_adapter = HumanAdapter()
250
277
  mlx_adapter = MLXVLMAdapter()
278
+ cua_adapter = CUAAdapter()
279
+ azure_ml_adapter = AzureMLAdapter()
251
280
  litellm.custom_provider_map = [
252
281
  {"provider": "huggingface-local", "custom_handler": hf_adapter},
253
282
  {"provider": "human", "custom_handler": human_adapter},
254
- {"provider": "mlx", "custom_handler": mlx_adapter}
283
+ {"provider": "mlx", "custom_handler": mlx_adapter},
284
+ {"provider": "cua", "custom_handler": cua_adapter},
285
+ {"provider": "azure_ml", "custom_handler": azure_ml_adapter},
255
286
  ]
256
287
  litellm.suppress_debug_info = True
257
288
 
@@ -268,24 +299,47 @@ class ComputerAgent:
268
299
  # Instantiate the agent config class
269
300
  self.agent_loop = config_info.agent_class()
270
301
  self.agent_config_info = config_info
271
-
302
+
303
+ # Add telemetry callbacks AFTER agent_loop is set so they can capture the correct agent_type
304
+ if self.telemetry_enabled:
305
+ # PostHog telemetry (product analytics)
306
+ if isinstance(self.telemetry_enabled, bool):
307
+ self.callbacks.append(TelemetryCallback(self))
308
+ else:
309
+ self.callbacks.append(TelemetryCallback(self, **self.telemetry_enabled))
310
+
311
+ # OpenTelemetry callback (operational metrics - Four Golden Signals)
312
+ # This is enabled alongside PostHog when telemetry_enabled is True
313
+ # Users can disable via CUA_TELEMETRY_DISABLED=true env var
314
+ self.callbacks.append(OtelCallback(self))
315
+
272
316
  self.tool_schemas = []
273
317
  self.computer_handler = None
274
-
318
+
275
319
  async def _initialize_computers(self):
276
320
  """Initialize computer objects"""
277
321
  if not self.tool_schemas:
278
322
  # Process tools and create tool schemas
279
323
  self.tool_schemas = self._process_tools()
280
-
324
+
281
325
  # Find computer tool and create interface adapter
282
326
  computer_handler = None
283
- for schema in self.tool_schemas:
284
- if schema["type"] == "computer":
285
- computer_handler = await make_computer_handler(schema["computer"])
327
+
328
+ # First check if any tool is a BaseComputerTool instance
329
+ for tool in self.tools:
330
+ if isinstance(tool, BaseComputerTool):
331
+ computer_handler = tool
286
332
  break
333
+
334
+ # If no BaseComputerTool found, look for traditional computer objects
335
+ if computer_handler is None:
336
+ for schema in self.tool_schemas:
337
+ if schema["type"] == "computer":
338
+ computer_handler = await make_computer_handler(schema["computer"])
339
+ break
340
+
287
341
  self.computer_handler = computer_handler
288
-
342
+
289
343
  def _process_input(self, input: Messages) -> List[Dict[str, Any]]:
290
344
  """Process input messages and create schemas for the agent loop"""
291
345
  if isinstance(input, str):
@@ -295,69 +349,85 @@ class ComputerAgent:
295
349
  def _process_tools(self) -> List[Dict[str, Any]]:
296
350
  """Process tools and create schemas for the agent loop"""
297
351
  schemas = []
298
-
352
+
299
353
  for tool in self.tools:
300
354
  # Check if it's a computer object (has interface attribute)
301
355
  if is_agent_computer(tool):
302
356
  # This is a computer tool - will be handled by agent loop
303
- schemas.append({
304
- "type": "computer",
305
- "computer": tool
306
- })
357
+ schemas.append({"type": "computer", "computer": tool})
358
+ elif isinstance(tool, BaseTool):
359
+ # BaseTool instance - extract schema from its properties
360
+ function_schema = {
361
+ "name": tool.name,
362
+ "description": tool.description,
363
+ "parameters": tool.parameters,
364
+ }
365
+ schemas.append({"type": "function", "function": function_schema})
307
366
  elif callable(tool):
308
367
  # Use litellm.utils.function_to_dict to extract schema from docstring
309
368
  try:
310
369
  function_schema = litellm.utils.function_to_dict(tool)
311
- schemas.append({
312
- "type": "function",
313
- "function": function_schema
314
- })
370
+ schemas.append({"type": "function", "function": function_schema})
315
371
  except Exception as e:
316
372
  print(f"Warning: Could not process tool {tool}: {e}")
317
373
  else:
318
374
  print(f"Warning: Unknown tool type: {tool}")
319
-
375
+
320
376
  return schemas
321
-
322
- def _get_tool(self, name: str) -> Optional[Callable]:
377
+
378
+ def _get_tool(self, name: str) -> Optional[Union[Callable, BaseTool]]:
323
379
  """Get a tool by name"""
324
380
  for tool in self.tools:
325
- if hasattr(tool, '__name__') and tool.__name__ == name:
381
+ # Check if it's a BaseTool instance
382
+ if isinstance(tool, BaseTool) and tool.name == name:
383
+ return tool
384
+ # Check if it's a regular callable
385
+ elif hasattr(tool, "__name__") and tool.__name__ == name:
326
386
  return tool
327
- elif hasattr(tool, 'func') and tool.func.__name__ == name:
387
+ elif hasattr(tool, "func") and tool.func.__name__ == name:
328
388
  return tool
329
389
  return None
330
-
390
+
331
391
  # ============================================================================
332
392
  # AGENT RUN LOOP LIFECYCLE HOOKS
333
393
  # ============================================================================
334
-
394
+
335
395
  async def _on_run_start(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]]) -> None:
336
396
  """Initialize run tracking by calling callbacks."""
337
397
  for callback in self.callbacks:
338
- if hasattr(callback, 'on_run_start'):
398
+ if hasattr(callback, "on_run_start"):
339
399
  await callback.on_run_start(kwargs, old_items)
340
-
341
- async def _on_run_end(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]], new_items: List[Dict[str, Any]]) -> None:
400
+
401
+ async def _on_run_end(
402
+ self,
403
+ kwargs: Dict[str, Any],
404
+ old_items: List[Dict[str, Any]],
405
+ new_items: List[Dict[str, Any]],
406
+ ) -> None:
342
407
  """Finalize run tracking by calling callbacks."""
343
408
  for callback in self.callbacks:
344
- if hasattr(callback, 'on_run_end'):
409
+ if hasattr(callback, "on_run_end"):
345
410
  await callback.on_run_end(kwargs, old_items, new_items)
346
-
347
- async def _on_run_continue(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]], new_items: List[Dict[str, Any]]) -> bool:
411
+
412
+ async def _on_run_continue(
413
+ self,
414
+ kwargs: Dict[str, Any],
415
+ old_items: List[Dict[str, Any]],
416
+ new_items: List[Dict[str, Any]],
417
+ ) -> bool:
348
418
  """Check if run should continue by calling callbacks."""
349
419
  for callback in self.callbacks:
350
- if hasattr(callback, 'on_run_continue'):
420
+ if hasattr(callback, "on_run_continue"):
351
421
  should_continue = await callback.on_run_continue(kwargs, old_items, new_items)
352
422
  if not should_continue:
353
423
  return False
354
424
  return True
355
-
425
+
356
426
  async def _on_llm_start(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
357
427
  """Prepare messages for the LLM call by applying callbacks."""
358
428
  result = messages
359
429
  for callback in self.callbacks:
360
- if hasattr(callback, 'on_llm_start'):
430
+ if hasattr(callback, "on_llm_start"):
361
431
  result = await callback.on_llm_start(result)
362
432
  return result
363
433
 
@@ -365,82 +435,91 @@ class ComputerAgent:
365
435
  """Postprocess messages after the LLM call by applying callbacks."""
366
436
  result = messages
367
437
  for callback in self.callbacks:
368
- if hasattr(callback, 'on_llm_end'):
438
+ if hasattr(callback, "on_llm_end"):
369
439
  result = await callback.on_llm_end(result)
370
440
  return result
371
441
 
372
442
  async def _on_responses(self, kwargs: Dict[str, Any], responses: Dict[str, Any]) -> None:
373
443
  """Called when responses are received."""
374
444
  for callback in self.callbacks:
375
- if hasattr(callback, 'on_responses'):
445
+ if hasattr(callback, "on_responses"):
376
446
  await callback.on_responses(get_json(kwargs), get_json(responses))
377
-
447
+
378
448
  async def _on_computer_call_start(self, item: Dict[str, Any]) -> None:
379
449
  """Called when a computer call is about to start."""
380
450
  for callback in self.callbacks:
381
- if hasattr(callback, 'on_computer_call_start'):
451
+ if hasattr(callback, "on_computer_call_start"):
382
452
  await callback.on_computer_call_start(get_json(item))
383
-
384
- async def _on_computer_call_end(self, item: Dict[str, Any], result: List[Dict[str, Any]]) -> None:
453
+
454
+ async def _on_computer_call_end(
455
+ self, item: Dict[str, Any], result: List[Dict[str, Any]]
456
+ ) -> None:
385
457
  """Called when a computer call has completed."""
386
458
  for callback in self.callbacks:
387
- if hasattr(callback, 'on_computer_call_end'):
459
+ if hasattr(callback, "on_computer_call_end"):
388
460
  await callback.on_computer_call_end(get_json(item), get_json(result))
389
-
461
+
390
462
  async def _on_function_call_start(self, item: Dict[str, Any]) -> None:
391
463
  """Called when a function call is about to start."""
392
464
  for callback in self.callbacks:
393
- if hasattr(callback, 'on_function_call_start'):
465
+ if hasattr(callback, "on_function_call_start"):
394
466
  await callback.on_function_call_start(get_json(item))
395
-
396
- async def _on_function_call_end(self, item: Dict[str, Any], result: List[Dict[str, Any]]) -> None:
467
+
468
+ async def _on_function_call_end(
469
+ self, item: Dict[str, Any], result: List[Dict[str, Any]]
470
+ ) -> None:
397
471
  """Called when a function call has completed."""
398
472
  for callback in self.callbacks:
399
- if hasattr(callback, 'on_function_call_end'):
473
+ if hasattr(callback, "on_function_call_end"):
400
474
  await callback.on_function_call_end(get_json(item), get_json(result))
401
-
475
+
402
476
  async def _on_text(self, item: Dict[str, Any]) -> None:
403
477
  """Called when a text message is encountered."""
404
478
  for callback in self.callbacks:
405
- if hasattr(callback, 'on_text'):
479
+ if hasattr(callback, "on_text"):
406
480
  await callback.on_text(get_json(item))
407
-
481
+
408
482
  async def _on_api_start(self, kwargs: Dict[str, Any]) -> None:
409
483
  """Called when an LLM API call is about to start."""
410
484
  for callback in self.callbacks:
411
- if hasattr(callback, 'on_api_start'):
485
+ if hasattr(callback, "on_api_start"):
412
486
  await callback.on_api_start(get_json(kwargs))
413
-
487
+
414
488
  async def _on_api_end(self, kwargs: Dict[str, Any], result: Any) -> None:
415
489
  """Called when an LLM API call has completed."""
416
490
  for callback in self.callbacks:
417
- if hasattr(callback, 'on_api_end'):
491
+ if hasattr(callback, "on_api_end"):
418
492
  await callback.on_api_end(get_json(kwargs), get_json(result))
419
493
 
420
494
  async def _on_usage(self, usage: Dict[str, Any]) -> None:
421
495
  """Called when usage information is received."""
422
496
  for callback in self.callbacks:
423
- if hasattr(callback, 'on_usage'):
497
+ if hasattr(callback, "on_usage"):
424
498
  await callback.on_usage(get_json(usage))
425
499
 
426
500
  async def _on_screenshot(self, screenshot: Union[str, bytes], name: str = "screenshot") -> None:
427
501
  """Called when a screenshot is taken."""
428
502
  for callback in self.callbacks:
429
- if hasattr(callback, 'on_screenshot'):
503
+ if hasattr(callback, "on_screenshot"):
430
504
  await callback.on_screenshot(screenshot, name)
431
505
 
432
506
  # ============================================================================
433
507
  # AGENT OUTPUT PROCESSING
434
508
  # ============================================================================
435
-
436
- async def _handle_item(self, item: Any, computer: Optional[AsyncComputerHandler] = None, ignore_call_ids: Optional[List[str]] = None) -> List[Dict[str, Any]]:
509
+
510
+ async def _handle_item(
511
+ self,
512
+ item: Any,
513
+ computer: Optional[AsyncComputerHandler] = None,
514
+ ignore_call_ids: Optional[List[str]] = None,
515
+ ) -> List[Dict[str, Any]]:
437
516
  """Handle each item; may cause a computer action + screenshot."""
438
517
  call_id = item.get("call_id")
439
518
  if ignore_call_ids and call_id and call_id in ignore_call_ids:
440
519
  return []
441
-
520
+
442
521
  item_type = item.get("type", None)
443
-
522
+
444
523
  if item_type == "message":
445
524
  await self._on_text(item)
446
525
  # # Print messages
@@ -449,7 +528,7 @@ class ComputerAgent:
449
528
  # if content_item.get("text"):
450
529
  # print(content_item.get("text"))
451
530
  return []
452
-
531
+
453
532
  try:
454
533
  if item_type == "computer_call":
455
534
  await self._on_computer_call_start(item)
@@ -460,14 +539,16 @@ class ComputerAgent:
460
539
  action = item.get("action")
461
540
  action_type = action.get("type")
462
541
  if action_type is None:
463
- print(f"Action type cannot be `None`: action={action}, action_type={action_type}")
542
+ print(
543
+ f"Action type cannot be `None`: action={action}, action_type={action_type}"
544
+ )
464
545
  return []
465
-
546
+
466
547
  # Extract action arguments (all fields except 'type')
467
548
  action_args = {k: v for k, v in action.items() if k != "type"}
468
-
549
+
469
550
  # print(f"{action_type}({action_args})")
470
-
551
+
471
552
  # Execute the computer action
472
553
  computer_method = getattr(computer, action_type, None)
473
554
  if computer_method:
@@ -475,13 +556,13 @@ class ComputerAgent:
475
556
  await computer_method(**action_args)
476
557
  else:
477
558
  raise ToolError(f"Unknown computer action: {action_type}")
478
-
559
+
479
560
  # Take screenshot after action
480
561
  if self.screenshot_delay and self.screenshot_delay > 0:
481
562
  await asyncio.sleep(self.screenshot_delay)
482
563
  screenshot_base64 = await computer.screenshot()
483
564
  await self._on_screenshot(screenshot_base64, "screenshot_after")
484
-
565
+
485
566
  # Handle safety checks
486
567
  pending_checks = item.get("pending_safety_checks", [])
487
568
  acknowledged_checks = []
@@ -493,7 +574,7 @@ class ComputerAgent:
493
574
  # acknowledged_checks.append(check)
494
575
  # else:
495
576
  # raise ValueError(f"Safety check failed: {check_message}")
496
-
577
+
497
578
  # Create call output
498
579
  call_output = {
499
580
  "type": "computer_call_output",
@@ -504,43 +585,48 @@ class ComputerAgent:
504
585
  "image_url": f"data:image/png;base64,{screenshot_base64}",
505
586
  },
506
587
  }
507
-
588
+
508
589
  # # Additional URL safety checks for browser environments
509
590
  # if await computer.get_environment() == "browser":
510
591
  # current_url = await computer.get_current_url()
511
592
  # call_output["output"]["current_url"] = current_url
512
593
  # # TODO: implement a callback for URL safety checks
513
594
  # # check_blocklisted_url(current_url)
514
-
595
+
515
596
  result = [call_output]
516
597
  await self._on_computer_call_end(item, result)
517
598
  return result
518
-
599
+
519
600
  if item_type == "function_call":
520
601
  await self._on_function_call_start(item)
521
602
  # Perform function call
522
603
  function = self._get_tool(item.get("name"))
523
604
  if not function:
524
- raise ToolError(f"Function {item.get("name")} not found")
525
-
526
- args = json.loads(item.get("arguments"))
605
+ raise ToolError(f"Function {item.get('name')} not found")
527
606
 
528
- # Validate arguments before execution
529
- assert_callable_with(function, **args)
607
+ args = json.loads(item.get("arguments"))
530
608
 
531
- # Execute function - use asyncio.to_thread for non-async functions
532
- if inspect.iscoroutinefunction(function):
533
- result = await function(**args)
609
+ # Handle BaseTool instances
610
+ if isinstance(function, BaseTool):
611
+ # BaseTool.call() handles its own execution
612
+ result = function.call(args)
534
613
  else:
535
- result = await asyncio.to_thread(function, **args)
536
-
614
+ # Validate arguments before execution for regular callables
615
+ assert_callable_with(function, **args)
616
+
617
+ # Execute function - use asyncio.to_thread for non-async functions
618
+ if inspect.iscoroutinefunction(function):
619
+ result = await function(**args)
620
+ else:
621
+ result = await asyncio.to_thread(function, **args)
622
+
537
623
  # Create function call output
538
624
  call_output = {
539
625
  "type": "function_call_output",
540
626
  "call_id": item.get("call_id"),
541
627
  "output": str(result),
542
628
  }
543
-
629
+
544
630
  result = [call_output]
545
631
  await self._on_function_call_end(item, result)
546
632
  return result
@@ -552,36 +638,46 @@ class ComputerAgent:
552
638
  # ============================================================================
553
639
  # MAIN AGENT LOOP
554
640
  # ============================================================================
555
-
641
+
556
642
  async def run(
557
643
  self,
558
644
  messages: Messages,
559
645
  stream: bool = False,
560
- **kwargs
646
+ api_key: Optional[str] = None,
647
+ api_base: Optional[str] = None,
648
+ **additional_generation_kwargs,
561
649
  ) -> AsyncGenerator[Dict[str, Any], None]:
562
650
  """
563
651
  Run the agent with the given messages using Computer protocol handler pattern.
564
-
652
+
565
653
  Args:
566
654
  messages: List of message dictionaries
567
655
  stream: Whether to stream the response
568
- **kwargs: Additional arguments
569
-
656
+ api_key: Optional API key override for the model provider
657
+ api_base: Optional API base URL override for the model provider
658
+ **additional_generation_kwargs: Additional arguments passed to the model provider
659
+
570
660
  Returns:
571
661
  AsyncGenerator that yields response chunks
572
662
  """
573
663
  if not self.agent_config_info:
574
664
  raise ValueError("Agent configuration not found")
575
-
665
+
576
666
  capabilities = self.get_capabilities()
577
667
  if "step" not in capabilities:
578
- raise ValueError(f"Agent loop {self.agent_config_info.agent_class.__name__} does not support step predictions")
668
+ raise ValueError(
669
+ f"Agent loop {self.agent_config_info.agent_class.__name__} does not support step predictions"
670
+ )
579
671
 
580
672
  await self._initialize_computers()
581
-
582
- # Merge kwargs
583
- merged_kwargs = {**self.kwargs, **kwargs}
584
-
673
+
674
+ # Merge kwargs and thread api credentials (run overrides constructor)
675
+ merged_kwargs = {**self.kwargs, **additional_generation_kwargs}
676
+ if (api_key is not None) or (self.api_key is not None):
677
+ merged_kwargs["api_key"] = api_key if api_key is not None else self.api_key
678
+ if (api_base is not None) or (self.api_base is not None):
679
+ merged_kwargs["api_base"] = api_base if api_base is not None else self.api_base
680
+
585
681
  old_items = self._process_input(messages)
586
682
  new_items = []
587
683
 
@@ -591,7 +687,7 @@ class ComputerAgent:
591
687
  "stream": stream,
592
688
  "model": self.model,
593
689
  "agent_loop": self.agent_config_info.agent_class.__name__,
594
- **merged_kwargs
690
+ **merged_kwargs,
595
691
  }
596
692
  await self._on_run_start(run_kwargs, old_items)
597
693
 
@@ -608,7 +704,7 @@ class ComputerAgent:
608
704
  combined_messages = old_items + new_items
609
705
  combined_messages = replace_failed_computer_calls_with_function_calls(combined_messages)
610
706
  preprocessed_messages = await self._on_llm_start(combined_messages)
611
-
707
+
612
708
  loop_kwargs = {
613
709
  "messages": preprocessed_messages,
614
710
  "model": self.model,
@@ -617,9 +713,39 @@ class ComputerAgent:
617
713
  "computer_handler": self.computer_handler,
618
714
  "max_retries": self.max_retries,
619
715
  "use_prompt_caching": self.use_prompt_caching,
620
- **merged_kwargs
716
+ **merged_kwargs,
621
717
  }
622
718
 
719
+ # ---- Ollama image input guard ----
720
+ if isinstance(self.model, str) and (
721
+ "ollama/" in self.model or "ollama_chat/" in self.model
722
+ ):
723
+
724
+ def contains_image_content(msgs):
725
+ for m in msgs:
726
+ # 1️⃣ Check regular message content
727
+ content = m.get("content")
728
+ if isinstance(content, list):
729
+ for item in content:
730
+ if isinstance(item, dict) and item.get("type") == "image_url":
731
+ return True
732
+
733
+ # 2️⃣ Check computer_call_output screenshots
734
+ if m.get("type") == "computer_call_output":
735
+ output = m.get("output", {})
736
+ if output.get("type") == "input_image" and "image_url" in output:
737
+ return True
738
+
739
+ return False
740
+
741
+ if contains_image_content(preprocessed_messages):
742
+ raise ValueError(
743
+ "Ollama models do not support image inputs required by ComputerAgent. "
744
+ "Please use a vision-capable model (e.g., OpenAI or Anthropic) "
745
+ "or remove computer/screenshot actions."
746
+ )
747
+ # ---------------------------------
748
+
623
749
  # Run agent loop iteration
624
750
  result = await self.agent_loop.predict_step(
625
751
  **loop_kwargs,
@@ -629,13 +755,13 @@ class ComputerAgent:
629
755
  _on_screenshot=self._on_screenshot,
630
756
  )
631
757
  result = get_json(result)
632
-
758
+
633
759
  # Lifecycle hook: Postprocess messages after the LLM call
634
760
  # Use cases:
635
761
  # - PII deanonymization (if you want tool calls to see PII)
636
762
  result["output"] = await self._on_llm_end(result.get("output", []))
637
763
  await self._on_responses(loop_kwargs, result)
638
-
764
+
639
765
  # Yield agent response
640
766
  yield result
641
767
 
@@ -647,64 +773,90 @@ class ComputerAgent:
647
773
 
648
774
  # Handle computer actions
649
775
  for item in result.get("output"):
650
- partial_items = await self._handle_item(item, self.computer_handler, ignore_call_ids=output_call_ids)
776
+ partial_items = await self._handle_item(
777
+ item, self.computer_handler, ignore_call_ids=output_call_ids
778
+ )
651
779
  new_items += partial_items
652
780
 
653
- # Yield partial response
654
- yield {
655
- "output": partial_items,
656
- "usage": Usage(
657
- prompt_tokens=0,
658
- completion_tokens=0,
659
- total_tokens=0,
660
- )
661
- }
662
-
781
+ # Yield partial response if any
782
+ if partial_items:
783
+ yield {
784
+ "output": partial_items,
785
+ "usage": Usage(
786
+ prompt_tokens=0,
787
+ completion_tokens=0,
788
+ total_tokens=0,
789
+ ),
790
+ }
791
+
663
792
  await self._on_run_end(loop_kwargs, old_items, new_items)
664
-
793
+
665
794
  async def predict_click(
666
- self,
667
- instruction: str,
668
- image_b64: Optional[str] = None
795
+ self, instruction: str, image_b64: Optional[str] = None
669
796
  ) -> Optional[Tuple[int, int]]:
670
797
  """
671
798
  Predict click coordinates based on image and instruction.
672
-
799
+
673
800
  Args:
674
801
  instruction: Instruction for where to click
675
802
  image_b64: Base64 encoded image (optional, will take screenshot if not provided)
676
-
803
+
677
804
  Returns:
678
805
  None or tuple with (x, y) coordinates
679
806
  """
680
807
  if not self.agent_config_info:
681
808
  raise ValueError("Agent configuration not found")
682
-
809
+
683
810
  capabilities = self.get_capabilities()
684
811
  if "click" not in capabilities:
685
- raise ValueError(f"Agent loop {self.agent_config_info.agent_class.__name__} does not support click predictions")
686
- if hasattr(self.agent_loop, 'predict_click'):
812
+ raise ValueError(
813
+ f"Agent loop {self.agent_config_info.agent_class.__name__} does not support click predictions"
814
+ )
815
+ if hasattr(self.agent_loop, "predict_click"):
687
816
  if not image_b64:
688
817
  if not self.computer_handler:
689
818
  raise ValueError("Computer tool or image_b64 is required for predict_click")
690
819
  image_b64 = await self.computer_handler.screenshot()
820
+ # Pass along api credentials if available
821
+ click_kwargs: Dict[str, Any] = {}
822
+ if self.api_key is not None:
823
+ click_kwargs["api_key"] = self.api_key
824
+ if self.api_base is not None:
825
+ click_kwargs["api_base"] = self.api_base
691
826
  return await self.agent_loop.predict_click(
692
- model=self.model,
693
- image_b64=image_b64,
694
- instruction=instruction
827
+ model=self.model, image_b64=image_b64, instruction=instruction, **click_kwargs
695
828
  )
696
829
  return None
697
-
830
+
698
831
  def get_capabilities(self) -> List[AgentCapability]:
699
832
  """
700
833
  Get list of capabilities supported by the current agent config.
701
-
834
+
702
835
  Returns:
703
836
  List of capability strings (e.g., ["step", "click"])
704
837
  """
705
838
  if not self.agent_config_info:
706
839
  raise ValueError("Agent configuration not found")
707
-
708
- if hasattr(self.agent_loop, 'get_capabilities'):
840
+
841
+ if hasattr(self.agent_loop, "get_capabilities"):
709
842
  return self.agent_loop.get_capabilities()
710
- return ["step"] # Default capability
843
+ return ["step"] # Default capability
844
+
845
+ def open(self, port: Optional[int] = None):
846
+ """
847
+ Start the playground server and open it in the browser.
848
+
849
+ This method starts a local HTTP server that exposes the /responses endpoint
850
+ and automatically opens the Cua playground interface in the default browser.
851
+
852
+ Args:
853
+ port: Port to run the server on. If None, finds an available port automatically.
854
+
855
+ Example:
856
+ >>> agent = ComputerAgent(model="claude-sonnet-4")
857
+ >>> agent.open() # Starts server and opens browser
858
+ """
859
+ from .playground import PlaygroundServer
860
+
861
+ server = PlaygroundServer(agent_instance=self)
862
+ server.start(port=port, open_browser=True)