cua-agent 0.4.34 → 0.4.35 (cua_agent-0.4.34-py3-none-any.whl → cua_agent-0.4.35-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of cua-agent has been flagged as potentially problematic; the linked details from the registry page are not reproduced here.

Files changed (61):
  1. agent/__init__.py +4 -10
  2. agent/__main__.py +2 -1
  3. agent/adapters/huggingfacelocal_adapter.py +54 -61
  4. agent/adapters/human_adapter.py +116 -114
  5. agent/adapters/mlxvlm_adapter.py +110 -99
  6. agent/adapters/models/__init__.py +14 -6
  7. agent/adapters/models/generic.py +7 -4
  8. agent/adapters/models/internvl.py +66 -30
  9. agent/adapters/models/opencua.py +23 -8
  10. agent/adapters/models/qwen2_5_vl.py +7 -4
  11. agent/agent.py +184 -158
  12. agent/callbacks/__init__.py +4 -4
  13. agent/callbacks/base.py +45 -31
  14. agent/callbacks/budget_manager.py +22 -10
  15. agent/callbacks/image_retention.py +18 -13
  16. agent/callbacks/logging.py +55 -42
  17. agent/callbacks/operator_validator.py +3 -1
  18. agent/callbacks/pii_anonymization.py +19 -16
  19. agent/callbacks/telemetry.py +67 -61
  20. agent/callbacks/trajectory_saver.py +90 -70
  21. agent/cli.py +115 -110
  22. agent/computers/__init__.py +13 -8
  23. agent/computers/base.py +26 -17
  24. agent/computers/cua.py +27 -23
  25. agent/computers/custom.py +72 -69
  26. agent/decorators.py +23 -14
  27. agent/human_tool/__init__.py +2 -7
  28. agent/human_tool/__main__.py +6 -2
  29. agent/human_tool/server.py +48 -37
  30. agent/human_tool/ui.py +235 -185
  31. agent/integrations/hud/__init__.py +15 -21
  32. agent/integrations/hud/agent.py +101 -83
  33. agent/integrations/hud/proxy.py +90 -57
  34. agent/loops/__init__.py +25 -21
  35. agent/loops/anthropic.py +537 -483
  36. agent/loops/base.py +13 -14
  37. agent/loops/composed_grounded.py +135 -149
  38. agent/loops/gemini.py +31 -12
  39. agent/loops/glm45v.py +135 -133
  40. agent/loops/gta1.py +47 -50
  41. agent/loops/holo.py +4 -2
  42. agent/loops/internvl.py +6 -11
  43. agent/loops/moondream3.py +36 -12
  44. agent/loops/omniparser.py +212 -209
  45. agent/loops/openai.py +49 -50
  46. agent/loops/opencua.py +29 -41
  47. agent/loops/qwen.py +475 -0
  48. agent/loops/uitars.py +237 -202
  49. agent/proxy/examples.py +54 -50
  50. agent/proxy/handlers.py +27 -34
  51. agent/responses.py +330 -330
  52. agent/types.py +11 -5
  53. agent/ui/__init__.py +1 -1
  54. agent/ui/__main__.py +1 -1
  55. agent/ui/gradio/app.py +23 -18
  56. agent/ui/gradio/ui_components.py +310 -161
  57. {cua_agent-0.4.34.dist-info → cua_agent-0.4.35.dist-info}/METADATA +18 -10
  58. cua_agent-0.4.35.dist-info/RECORD +64 -0
  59. cua_agent-0.4.34.dist-info/RECORD +0 -63
  60. {cua_agent-0.4.34.dist-info → cua_agent-0.4.35.dist-info}/WHEEL +0 -0
  61. {cua_agent-0.4.34.dist-info → cua_agent-0.4.35.dist-info}/entry_points.txt +0 -0
agent/agent.py CHANGED
@@ -3,76 +3,83 @@ ComputerAgent - Main agent class that selects and runs agent loops
3
3
  """
4
4
 
5
5
  import asyncio
6
+ import inspect
7
+ import json
6
8
  from pathlib import Path
7
- from typing import Dict, List, Any, Optional, AsyncGenerator, Union, cast, Callable, Set, Tuple
8
-
9
- from litellm.responses.utils import Usage
10
-
11
- from .types import (
12
- Messages,
13
- AgentCapability,
14
- ToolError,
15
- IllegalArgumentError
9
+ from typing import (
10
+ Any,
11
+ AsyncGenerator,
12
+ Callable,
13
+ Dict,
14
+ List,
15
+ Optional,
16
+ Set,
17
+ Tuple,
18
+ Union,
19
+ cast,
16
20
  )
17
- from .responses import make_tool_error_item, replace_failed_computer_calls_with_function_calls
18
- from .decorators import find_agent_config
19
- import json
21
+
20
22
  import litellm
21
23
  import litellm.utils
22
- import inspect
24
+ from litellm.responses.utils import Usage
25
+
23
26
  from .adapters import (
24
27
  HuggingFaceLocalAdapter,
25
28
  HumanAdapter,
26
29
  MLXVLMAdapter,
27
30
  )
28
31
  from .callbacks import (
29
- ImageRetentionCallback,
30
- LoggingCallback,
31
- TrajectorySaverCallback,
32
32
  BudgetManagerCallback,
33
- TelemetryCallback,
33
+ ImageRetentionCallback,
34
+ LoggingCallback,
34
35
  OperatorNormalizerCallback,
35
36
  PromptInstructionsCallback,
37
+ TelemetryCallback,
38
+ TrajectorySaverCallback,
36
39
  )
37
- from .computers import (
38
- AsyncComputerHandler,
39
- is_agent_computer,
40
- make_computer_handler
40
+ from .computers import AsyncComputerHandler, is_agent_computer, make_computer_handler
41
+ from .decorators import find_agent_config
42
+ from .responses import (
43
+ make_tool_error_item,
44
+ replace_failed_computer_calls_with_function_calls,
41
45
  )
46
+ from .types import AgentCapability, IllegalArgumentError, Messages, ToolError
47
+
42
48
 
43
49
  def assert_callable_with(f, *args, **kwargs):
44
- """Check if function can be called with given arguments."""
45
- try:
46
- inspect.signature(f).bind(*args, **kwargs)
47
- return True
48
- except TypeError as e:
49
- sig = inspect.signature(f)
50
- raise IllegalArgumentError(f"Expected {sig}, got args={args} kwargs={kwargs}") from e
50
+ """Check if function can be called with given arguments."""
51
+ try:
52
+ inspect.signature(f).bind(*args, **kwargs)
53
+ return True
54
+ except TypeError as e:
55
+ sig = inspect.signature(f)
56
+ raise IllegalArgumentError(f"Expected {sig}, got args={args} kwargs={kwargs}") from e
57
+
51
58
 
52
59
  def get_json(obj: Any, max_depth: int = 10) -> Any:
53
60
  def custom_serializer(o: Any, depth: int = 0, seen: Optional[Set[int]] = None) -> Any:
54
61
  if seen is None:
55
62
  seen = set()
56
-
63
+
57
64
  # Use model_dump() if available
58
- if hasattr(o, 'model_dump'):
65
+ if hasattr(o, "model_dump"):
59
66
  return o.model_dump()
60
-
67
+
61
68
  # Check depth limit
62
69
  if depth > max_depth:
63
70
  return f"<max_depth_exceeded:{max_depth}>"
64
-
71
+
65
72
  # Check for circular references using object id
66
73
  obj_id = id(o)
67
74
  if obj_id in seen:
68
75
  return f"<circular_reference:{type(o).__name__}>"
69
-
76
+
70
77
  # Handle Computer objects
71
- if hasattr(o, '__class__') and 'computer' in getattr(o, '__class__').__name__.lower():
78
+ if hasattr(o, "__class__") and "computer" in o.__class__.__name__.lower():
72
79
  return f"<computer:{o.__class__.__name__}>"
73
80
 
74
81
  # Handle objects with __dict__
75
- if hasattr(o, '__dict__'):
82
+ if hasattr(o, "__dict__"):
76
83
  seen.add(obj_id)
77
84
  try:
78
85
  result = {}
@@ -84,7 +91,7 @@ def get_json(obj: Any, max_depth: int = 10) -> Any:
84
91
  return result
85
92
  finally:
86
93
  seen.discard(obj_id)
87
-
94
+
88
95
  # Handle common types that might contain nested objects
89
96
  elif isinstance(o, dict):
90
97
  seen.add(obj_id)
@@ -96,7 +103,7 @@ def get_json(obj: Any, max_depth: int = 10) -> Any:
96
103
  }
97
104
  finally:
98
105
  seen.discard(obj_id)
99
-
106
+
100
107
  elif isinstance(o, (list, tuple, set)):
101
108
  seen.add(obj_id)
102
109
  try:
@@ -107,32 +114,33 @@ def get_json(obj: Any, max_depth: int = 10) -> Any:
107
114
  ]
108
115
  finally:
109
116
  seen.discard(obj_id)
110
-
117
+
111
118
  # For basic types that json.dumps can handle
112
119
  elif isinstance(o, (str, int, float, bool)) or o is None:
113
120
  return o
114
-
121
+
115
122
  # Fallback to string representation
116
123
  else:
117
124
  return str(o)
118
-
125
+
119
126
  def remove_nones(obj: Any) -> Any:
120
127
  if isinstance(obj, dict):
121
128
  return {k: remove_nones(v) for k, v in obj.items() if v is not None}
122
129
  elif isinstance(obj, list):
123
130
  return [remove_nones(item) for item in obj if item is not None]
124
131
  return obj
125
-
132
+
126
133
  # Serialize with circular reference and depth protection
127
134
  serialized = custom_serializer(obj)
128
-
135
+
129
136
  # Convert to JSON string and back to ensure JSON compatibility
130
137
  json_str = json.dumps(serialized)
131
138
  parsed = json.loads(json_str)
132
-
139
+
133
140
  # Final cleanup of any remaining None values
134
141
  return remove_nones(parsed)
135
142
 
143
+
136
144
  def sanitize_message(msg: Any) -> Any:
137
145
  """Return a copy of the message with image_url omitted for computer_call_output messages."""
138
146
  if msg.get("type") == "computer_call_output":
@@ -143,19 +151,24 @@ def sanitize_message(msg: Any) -> Any:
143
151
  return sanitized
144
152
  return msg
145
153
 
154
+
146
155
  def get_output_call_ids(messages: List[Dict[str, Any]]) -> List[str]:
147
156
  call_ids = []
148
157
  for message in messages:
149
- if message.get("type") == "computer_call_output" or message.get("type") == "function_call_output":
158
+ if (
159
+ message.get("type") == "computer_call_output"
160
+ or message.get("type") == "function_call_output"
161
+ ):
150
162
  call_ids.append(message.get("call_id"))
151
163
  return call_ids
152
164
 
165
+
153
166
  class ComputerAgent:
154
167
  """
155
168
  Main agent class that automatically selects the appropriate agent loop
156
169
  based on the model and executes tool calls.
157
170
  """
158
-
171
+
159
172
  def __init__(
160
173
  self,
161
174
  model: str,
@@ -172,11 +185,11 @@ class ComputerAgent:
172
185
  max_trajectory_budget: Optional[float | dict] = None,
173
186
  telemetry_enabled: Optional[bool] = True,
174
187
  trust_remote_code: Optional[bool] = False,
175
- **kwargs
188
+ **kwargs,
176
189
  ):
177
190
  """
178
191
  Initialize ComputerAgent.
179
-
192
+
180
193
  Args:
181
194
  model: Model name (e.g., "claude-3-5-sonnet-20241022", "computer-use-preview", "omni+vertex_ai/gemini-pro")
182
195
  tools: List of tools (computer objects, decorated functions, etc.)
@@ -193,11 +206,11 @@ class ComputerAgent:
193
206
  telemetry_enabled: If set, adds TelemetryCallback to track anonymized usage data. Enabled by default.
194
207
  trust_remote_code: If set, trust remote code when loading local models. Disabled by default.
195
208
  **kwargs: Additional arguments passed to the agent loop
196
- """
209
+ """
197
210
  # If the loop is "human/human", we need to prefix a grounding model fallback
198
211
  if model in ["human/human", "human"]:
199
212
  model = "openai/computer-use-preview+human/human"
200
-
213
+
201
214
  self.model = model
202
215
  self.tools = tools or []
203
216
  self.custom_loop = custom_loop
@@ -236,34 +249,33 @@ class ComputerAgent:
236
249
  # Add image retention callback if only_n_most_recent_images is set
237
250
  if self.only_n_most_recent_images:
238
251
  self.callbacks.append(ImageRetentionCallback(self.only_n_most_recent_images))
239
-
252
+
240
253
  # Add trajectory saver callback if trajectory_dir is set
241
254
  if self.trajectory_dir:
242
255
  if isinstance(self.trajectory_dir, dict):
243
256
  self.callbacks.append(TrajectorySaverCallback(**self.trajectory_dir))
244
257
  elif isinstance(self.trajectory_dir, (str, Path)):
245
258
  self.callbacks.append(TrajectorySaverCallback(str(self.trajectory_dir)))
246
-
259
+
247
260
  # Add budget manager if max_trajectory_budget is set
248
261
  if max_trajectory_budget:
249
262
  if isinstance(max_trajectory_budget, dict):
250
263
  self.callbacks.append(BudgetManagerCallback(**max_trajectory_budget))
251
264
  else:
252
265
  self.callbacks.append(BudgetManagerCallback(max_trajectory_budget))
253
-
266
+
254
267
  # == Enable local model providers w/ LiteLLM ==
255
268
 
256
269
  # Register local model providers
257
270
  hf_adapter = HuggingFaceLocalAdapter(
258
- device="auto",
259
- trust_remote_code=self.trust_remote_code or False
271
+ device="auto", trust_remote_code=self.trust_remote_code or False
260
272
  )
261
273
  human_adapter = HumanAdapter()
262
274
  mlx_adapter = MLXVLMAdapter()
263
275
  litellm.custom_provider_map = [
264
276
  {"provider": "huggingface-local", "custom_handler": hf_adapter},
265
277
  {"provider": "human", "custom_handler": human_adapter},
266
- {"provider": "mlx", "custom_handler": mlx_adapter}
278
+ {"provider": "mlx", "custom_handler": mlx_adapter},
267
279
  ]
268
280
  litellm.suppress_debug_info = True
269
281
 
@@ -280,16 +292,16 @@ class ComputerAgent:
280
292
  # Instantiate the agent config class
281
293
  self.agent_loop = config_info.agent_class()
282
294
  self.agent_config_info = config_info
283
-
295
+
284
296
  self.tool_schemas = []
285
297
  self.computer_handler = None
286
-
298
+
287
299
  async def _initialize_computers(self):
288
300
  """Initialize computer objects"""
289
301
  if not self.tool_schemas:
290
302
  # Process tools and create tool schemas
291
303
  self.tool_schemas = self._process_tools()
292
-
304
+
293
305
  # Find computer tool and create interface adapter
294
306
  computer_handler = None
295
307
  for schema in self.tool_schemas:
@@ -297,7 +309,7 @@ class ComputerAgent:
297
309
  computer_handler = await make_computer_handler(schema["computer"])
298
310
  break
299
311
  self.computer_handler = computer_handler
300
-
312
+
301
313
  def _process_input(self, input: Messages) -> List[Dict[str, Any]]:
302
314
  """Process input messages and create schemas for the agent loop"""
303
315
  if isinstance(input, str):
@@ -307,69 +319,73 @@ class ComputerAgent:
307
319
  def _process_tools(self) -> List[Dict[str, Any]]:
308
320
  """Process tools and create schemas for the agent loop"""
309
321
  schemas = []
310
-
322
+
311
323
  for tool in self.tools:
312
324
  # Check if it's a computer object (has interface attribute)
313
325
  if is_agent_computer(tool):
314
326
  # This is a computer tool - will be handled by agent loop
315
- schemas.append({
316
- "type": "computer",
317
- "computer": tool
318
- })
327
+ schemas.append({"type": "computer", "computer": tool})
319
328
  elif callable(tool):
320
329
  # Use litellm.utils.function_to_dict to extract schema from docstring
321
330
  try:
322
331
  function_schema = litellm.utils.function_to_dict(tool)
323
- schemas.append({
324
- "type": "function",
325
- "function": function_schema
326
- })
332
+ schemas.append({"type": "function", "function": function_schema})
327
333
  except Exception as e:
328
334
  print(f"Warning: Could not process tool {tool}: {e}")
329
335
  else:
330
336
  print(f"Warning: Unknown tool type: {tool}")
331
-
337
+
332
338
  return schemas
333
-
339
+
334
340
  def _get_tool(self, name: str) -> Optional[Callable]:
335
341
  """Get a tool by name"""
336
342
  for tool in self.tools:
337
- if hasattr(tool, '__name__') and tool.__name__ == name:
343
+ if hasattr(tool, "__name__") and tool.__name__ == name:
338
344
  return tool
339
- elif hasattr(tool, 'func') and tool.func.__name__ == name:
345
+ elif hasattr(tool, "func") and tool.func.__name__ == name:
340
346
  return tool
341
347
  return None
342
-
348
+
343
349
  # ============================================================================
344
350
  # AGENT RUN LOOP LIFECYCLE HOOKS
345
351
  # ============================================================================
346
-
352
+
347
353
  async def _on_run_start(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]]) -> None:
348
354
  """Initialize run tracking by calling callbacks."""
349
355
  for callback in self.callbacks:
350
- if hasattr(callback, 'on_run_start'):
356
+ if hasattr(callback, "on_run_start"):
351
357
  await callback.on_run_start(kwargs, old_items)
352
-
353
- async def _on_run_end(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]], new_items: List[Dict[str, Any]]) -> None:
358
+
359
+ async def _on_run_end(
360
+ self,
361
+ kwargs: Dict[str, Any],
362
+ old_items: List[Dict[str, Any]],
363
+ new_items: List[Dict[str, Any]],
364
+ ) -> None:
354
365
  """Finalize run tracking by calling callbacks."""
355
366
  for callback in self.callbacks:
356
- if hasattr(callback, 'on_run_end'):
367
+ if hasattr(callback, "on_run_end"):
357
368
  await callback.on_run_end(kwargs, old_items, new_items)
358
-
359
- async def _on_run_continue(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]], new_items: List[Dict[str, Any]]) -> bool:
369
+
370
+ async def _on_run_continue(
371
+ self,
372
+ kwargs: Dict[str, Any],
373
+ old_items: List[Dict[str, Any]],
374
+ new_items: List[Dict[str, Any]],
375
+ ) -> bool:
360
376
  """Check if run should continue by calling callbacks."""
361
377
  for callback in self.callbacks:
362
- if hasattr(callback, 'on_run_continue'):
378
+ if hasattr(callback, "on_run_continue"):
363
379
  should_continue = await callback.on_run_continue(kwargs, old_items, new_items)
364
380
  if not should_continue:
365
381
  return False
366
382
  return True
367
-
383
+
368
384
  async def _on_llm_start(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
369
385
  """Prepare messages for the LLM call by applying callbacks."""
370
386
  result = messages
371
387
  for callback in self.callbacks:
372
- if hasattr(callback, 'on_llm_start'):
388
+ if hasattr(callback, "on_llm_start"):
373
389
  result = await callback.on_llm_start(result)
374
390
  return result
375
391
 
@@ -377,82 +393,91 @@ class ComputerAgent:
377
393
  """Postprocess messages after the LLM call by applying callbacks."""
378
394
  result = messages
379
395
  for callback in self.callbacks:
380
- if hasattr(callback, 'on_llm_end'):
396
+ if hasattr(callback, "on_llm_end"):
381
397
  result = await callback.on_llm_end(result)
382
398
  return result
383
399
 
384
400
  async def _on_responses(self, kwargs: Dict[str, Any], responses: Dict[str, Any]) -> None:
385
401
  """Called when responses are received."""
386
402
  for callback in self.callbacks:
387
- if hasattr(callback, 'on_responses'):
403
+ if hasattr(callback, "on_responses"):
388
404
  await callback.on_responses(get_json(kwargs), get_json(responses))
389
-
405
+
390
406
  async def _on_computer_call_start(self, item: Dict[str, Any]) -> None:
391
407
  """Called when a computer call is about to start."""
392
408
  for callback in self.callbacks:
393
- if hasattr(callback, 'on_computer_call_start'):
409
+ if hasattr(callback, "on_computer_call_start"):
394
410
  await callback.on_computer_call_start(get_json(item))
395
-
396
- async def _on_computer_call_end(self, item: Dict[str, Any], result: List[Dict[str, Any]]) -> None:
411
+
412
+ async def _on_computer_call_end(
413
+ self, item: Dict[str, Any], result: List[Dict[str, Any]]
414
+ ) -> None:
397
415
  """Called when a computer call has completed."""
398
416
  for callback in self.callbacks:
399
- if hasattr(callback, 'on_computer_call_end'):
417
+ if hasattr(callback, "on_computer_call_end"):
400
418
  await callback.on_computer_call_end(get_json(item), get_json(result))
401
-
419
+
402
420
  async def _on_function_call_start(self, item: Dict[str, Any]) -> None:
403
421
  """Called when a function call is about to start."""
404
422
  for callback in self.callbacks:
405
- if hasattr(callback, 'on_function_call_start'):
423
+ if hasattr(callback, "on_function_call_start"):
406
424
  await callback.on_function_call_start(get_json(item))
407
-
408
- async def _on_function_call_end(self, item: Dict[str, Any], result: List[Dict[str, Any]]) -> None:
425
+
426
+ async def _on_function_call_end(
427
+ self, item: Dict[str, Any], result: List[Dict[str, Any]]
428
+ ) -> None:
409
429
  """Called when a function call has completed."""
410
430
  for callback in self.callbacks:
411
- if hasattr(callback, 'on_function_call_end'):
431
+ if hasattr(callback, "on_function_call_end"):
412
432
  await callback.on_function_call_end(get_json(item), get_json(result))
413
-
433
+
414
434
  async def _on_text(self, item: Dict[str, Any]) -> None:
415
435
  """Called when a text message is encountered."""
416
436
  for callback in self.callbacks:
417
- if hasattr(callback, 'on_text'):
437
+ if hasattr(callback, "on_text"):
418
438
  await callback.on_text(get_json(item))
419
-
439
+
420
440
  async def _on_api_start(self, kwargs: Dict[str, Any]) -> None:
421
441
  """Called when an LLM API call is about to start."""
422
442
  for callback in self.callbacks:
423
- if hasattr(callback, 'on_api_start'):
443
+ if hasattr(callback, "on_api_start"):
424
444
  await callback.on_api_start(get_json(kwargs))
425
-
445
+
426
446
  async def _on_api_end(self, kwargs: Dict[str, Any], result: Any) -> None:
427
447
  """Called when an LLM API call has completed."""
428
448
  for callback in self.callbacks:
429
- if hasattr(callback, 'on_api_end'):
449
+ if hasattr(callback, "on_api_end"):
430
450
  await callback.on_api_end(get_json(kwargs), get_json(result))
431
451
 
432
452
  async def _on_usage(self, usage: Dict[str, Any]) -> None:
433
453
  """Called when usage information is received."""
434
454
  for callback in self.callbacks:
435
- if hasattr(callback, 'on_usage'):
455
+ if hasattr(callback, "on_usage"):
436
456
  await callback.on_usage(get_json(usage))
437
457
 
438
458
  async def _on_screenshot(self, screenshot: Union[str, bytes], name: str = "screenshot") -> None:
439
459
  """Called when a screenshot is taken."""
440
460
  for callback in self.callbacks:
441
- if hasattr(callback, 'on_screenshot'):
461
+ if hasattr(callback, "on_screenshot"):
442
462
  await callback.on_screenshot(screenshot, name)
443
463
 
444
464
  # ============================================================================
445
465
  # AGENT OUTPUT PROCESSING
446
466
  # ============================================================================
447
-
448
- async def _handle_item(self, item: Any, computer: Optional[AsyncComputerHandler] = None, ignore_call_ids: Optional[List[str]] = None) -> List[Dict[str, Any]]:
467
+
468
+ async def _handle_item(
469
+ self,
470
+ item: Any,
471
+ computer: Optional[AsyncComputerHandler] = None,
472
+ ignore_call_ids: Optional[List[str]] = None,
473
+ ) -> List[Dict[str, Any]]:
449
474
  """Handle each item; may cause a computer action + screenshot."""
450
475
  call_id = item.get("call_id")
451
476
  if ignore_call_ids and call_id and call_id in ignore_call_ids:
452
477
  return []
453
-
478
+
454
479
  item_type = item.get("type", None)
455
-
480
+
456
481
  if item_type == "message":
457
482
  await self._on_text(item)
458
483
  # # Print messages
@@ -461,7 +486,7 @@ class ComputerAgent:
461
486
  # if content_item.get("text"):
462
487
  # print(content_item.get("text"))
463
488
  return []
464
-
489
+
465
490
  try:
466
491
  if item_type == "computer_call":
467
492
  await self._on_computer_call_start(item)
@@ -472,14 +497,16 @@ class ComputerAgent:
472
497
  action = item.get("action")
473
498
  action_type = action.get("type")
474
499
  if action_type is None:
475
- print(f"Action type cannot be `None`: action={action}, action_type={action_type}")
500
+ print(
501
+ f"Action type cannot be `None`: action={action}, action_type={action_type}"
502
+ )
476
503
  return []
477
-
504
+
478
505
  # Extract action arguments (all fields except 'type')
479
506
  action_args = {k: v for k, v in action.items() if k != "type"}
480
-
507
+
481
508
  # print(f"{action_type}({action_args})")
482
-
509
+
483
510
  # Execute the computer action
484
511
  computer_method = getattr(computer, action_type, None)
485
512
  if computer_method:
@@ -487,13 +514,13 @@ class ComputerAgent:
487
514
  await computer_method(**action_args)
488
515
  else:
489
516
  raise ToolError(f"Unknown computer action: {action_type}")
490
-
517
+
491
518
  # Take screenshot after action
492
519
  if self.screenshot_delay and self.screenshot_delay > 0:
493
520
  await asyncio.sleep(self.screenshot_delay)
494
521
  screenshot_base64 = await computer.screenshot()
495
522
  await self._on_screenshot(screenshot_base64, "screenshot_after")
496
-
523
+
497
524
  # Handle safety checks
498
525
  pending_checks = item.get("pending_safety_checks", [])
499
526
  acknowledged_checks = []
@@ -505,7 +532,7 @@ class ComputerAgent:
505
532
  # acknowledged_checks.append(check)
506
533
  # else:
507
534
  # raise ValueError(f"Safety check failed: {check_message}")
508
-
535
+
509
536
  # Create call output
510
537
  call_output = {
511
538
  "type": "computer_call_output",
@@ -516,25 +543,25 @@ class ComputerAgent:
516
543
  "image_url": f"data:image/png;base64,{screenshot_base64}",
517
544
  },
518
545
  }
519
-
546
+
520
547
  # # Additional URL safety checks for browser environments
521
548
  # if await computer.get_environment() == "browser":
522
549
  # current_url = await computer.get_current_url()
523
550
  # call_output["output"]["current_url"] = current_url
524
551
  # # TODO: implement a callback for URL safety checks
525
552
  # # check_blocklisted_url(current_url)
526
-
553
+
527
554
  result = [call_output]
528
555
  await self._on_computer_call_end(item, result)
529
556
  return result
530
-
557
+
531
558
  if item_type == "function_call":
532
559
  await self._on_function_call_start(item)
533
560
  # Perform function call
534
561
  function = self._get_tool(item.get("name"))
535
562
  if not function:
536
- raise ToolError(f"Function {item.get("name")} not found")
537
-
563
+ raise ToolError(f"Function {item.get('name')} not found")
564
+
538
565
  args = json.loads(item.get("arguments"))
539
566
 
540
567
  # Validate arguments before execution
@@ -545,14 +572,14 @@ class ComputerAgent:
545
572
  result = await function(**args)
546
573
  else:
547
574
  result = await asyncio.to_thread(function, **args)
548
-
575
+
549
576
  # Create function call output
550
577
  call_output = {
551
578
  "type": "function_call_output",
552
579
  "call_id": item.get("call_id"),
553
580
  "output": str(result),
554
581
  }
555
-
582
+
556
583
  result = [call_output]
557
584
  await self._on_function_call_end(item, result)
558
585
  return result
@@ -564,36 +591,35 @@ class ComputerAgent:
564
591
  # ============================================================================
565
592
  # MAIN AGENT LOOP
566
593
  # ============================================================================
567
-
594
+
568
595
  async def run(
569
- self,
570
- messages: Messages,
571
- stream: bool = False,
572
- **kwargs
596
+ self, messages: Messages, stream: bool = False, **kwargs
573
597
  ) -> AsyncGenerator[Dict[str, Any], None]:
574
598
  """
575
599
  Run the agent with the given messages using Computer protocol handler pattern.
576
-
600
+
577
601
  Args:
578
602
  messages: List of message dictionaries
579
603
  stream: Whether to stream the response
580
604
  **kwargs: Additional arguments
581
-
605
+
582
606
  Returns:
583
607
  AsyncGenerator that yields response chunks
584
608
  """
585
609
  if not self.agent_config_info:
586
610
  raise ValueError("Agent configuration not found")
587
-
611
+
588
612
  capabilities = self.get_capabilities()
589
613
  if "step" not in capabilities:
590
- raise ValueError(f"Agent loop {self.agent_config_info.agent_class.__name__} does not support step predictions")
614
+ raise ValueError(
615
+ f"Agent loop {self.agent_config_info.agent_class.__name__} does not support step predictions"
616
+ )
591
617
 
592
618
  await self._initialize_computers()
593
-
619
+
594
620
  # Merge kwargs
595
621
  merged_kwargs = {**self.kwargs, **kwargs}
596
-
622
+
597
623
  old_items = self._process_input(messages)
598
624
  new_items = []
599
625
 
@@ -603,7 +629,7 @@ class ComputerAgent:
603
629
  "stream": stream,
604
630
  "model": self.model,
605
631
  "agent_loop": self.agent_config_info.agent_class.__name__,
606
- **merged_kwargs
632
+ **merged_kwargs,
607
633
  }
608
634
  await self._on_run_start(run_kwargs, old_items)
609
635
 
@@ -620,7 +646,7 @@ class ComputerAgent:
620
646
  combined_messages = old_items + new_items
621
647
  combined_messages = replace_failed_computer_calls_with_function_calls(combined_messages)
622
648
  preprocessed_messages = await self._on_llm_start(combined_messages)
623
-
649
+
624
650
  loop_kwargs = {
625
651
  "messages": preprocessed_messages,
626
652
  "model": self.model,
@@ -629,7 +655,7 @@ class ComputerAgent:
629
655
  "computer_handler": self.computer_handler,
630
656
  "max_retries": self.max_retries,
631
657
  "use_prompt_caching": self.use_prompt_caching,
632
- **merged_kwargs
658
+ **merged_kwargs,
633
659
  }
634
660
 
635
661
  # Run agent loop iteration
@@ -641,13 +667,13 @@ class ComputerAgent:
641
667
  _on_screenshot=self._on_screenshot,
642
668
  )
643
669
  result = get_json(result)
644
-
670
+
645
671
  # Lifecycle hook: Postprocess messages after the LLM call
646
672
  # Use cases:
647
673
  # - PII deanonymization (if you want tool calls to see PII)
648
674
  result["output"] = await self._on_llm_end(result.get("output", []))
649
675
  await self._on_responses(loop_kwargs, result)
650
-
676
+
651
677
  # Yield agent response
652
678
  yield result
653
679
 
@@ -659,7 +685,9 @@ class ComputerAgent:
659
685
 
660
686
  # Handle computer actions
661
687
  for item in result.get("output"):
662
- partial_items = await self._handle_item(item, self.computer_handler, ignore_call_ids=output_call_ids)
688
+ partial_items = await self._handle_item(
689
+ item, self.computer_handler, ignore_call_ids=output_call_ids
690
+ )
663
691
  new_items += partial_items
664
692
 
665
693
  # Yield partial response
@@ -669,54 +697,52 @@ class ComputerAgent:
669
697
  prompt_tokens=0,
670
698
  completion_tokens=0,
671
699
  total_tokens=0,
672
- )
700
+ ),
673
701
  }
674
-
702
+
675
703
  await self._on_run_end(loop_kwargs, old_items, new_items)
676
-
704
+
677
705
  async def predict_click(
678
- self,
679
- instruction: str,
680
- image_b64: Optional[str] = None
706
+ self, instruction: str, image_b64: Optional[str] = None
681
707
  ) -> Optional[Tuple[int, int]]:
682
708
  """
683
709
  Predict click coordinates based on image and instruction.
684
-
710
+
685
711
  Args:
686
712
  instruction: Instruction for where to click
687
713
  image_b64: Base64 encoded image (optional, will take screenshot if not provided)
688
-
714
+
689
715
  Returns:
690
716
  None or tuple with (x, y) coordinates
691
717
  """
692
718
  if not self.agent_config_info:
693
719
  raise ValueError("Agent configuration not found")
694
-
720
+
695
721
  capabilities = self.get_capabilities()
696
722
  if "click" not in capabilities:
697
- raise ValueError(f"Agent loop {self.agent_config_info.agent_class.__name__} does not support click predictions")
698
- if hasattr(self.agent_loop, 'predict_click'):
723
+ raise ValueError(
724
+ f"Agent loop {self.agent_config_info.agent_class.__name__} does not support click predictions"
725
+ )
726
+ if hasattr(self.agent_loop, "predict_click"):
699
727
  if not image_b64:
700
728
  if not self.computer_handler:
701
729
  raise ValueError("Computer tool or image_b64 is required for predict_click")
702
730
  image_b64 = await self.computer_handler.screenshot()
703
731
  return await self.agent_loop.predict_click(
704
- model=self.model,
705
- image_b64=image_b64,
706
- instruction=instruction
732
+ model=self.model, image_b64=image_b64, instruction=instruction
707
733
  )
708
734
  return None
709
-
735
+
710
736
  def get_capabilities(self) -> List[AgentCapability]:
711
737
  """
712
738
  Get list of capabilities supported by the current agent config.
713
-
739
+
714
740
  Returns:
715
741
  List of capability strings (e.g., ["step", "click"])
716
742
  """
717
743
  if not self.agent_config_info:
718
744
  raise ValueError("Agent configuration not found")
719
-
720
- if hasattr(self.agent_loop, 'get_capabilities'):
745
+
746
+ if hasattr(self.agent_loop, "get_capabilities"):
721
747
  return self.agent_loop.get_capabilities()
722
- return ["step"] # Default capability
748
+ return ["step"] # Default capability