cua-agent 0.3.2__py3-none-any.whl → 0.4.0b2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of cua-agent might be problematic. Click here for more details.

Files changed (111)
  1. agent/__init__.py +15 -51
  2. agent/__main__.py +21 -0
  3. agent/adapters/__init__.py +9 -0
  4. agent/adapters/huggingfacelocal_adapter.py +229 -0
  5. agent/agent.py +577 -0
  6. agent/callbacks/__init__.py +17 -0
  7. agent/callbacks/base.py +153 -0
  8. agent/callbacks/budget_manager.py +44 -0
  9. agent/callbacks/image_retention.py +139 -0
  10. agent/callbacks/logging.py +247 -0
  11. agent/callbacks/pii_anonymization.py +259 -0
  12. agent/callbacks/trajectory_saver.py +305 -0
  13. agent/cli.py +290 -0
  14. agent/computer_handler.py +107 -0
  15. agent/decorators.py +90 -0
  16. agent/loops/__init__.py +11 -0
  17. agent/loops/anthropic.py +728 -0
  18. agent/loops/omniparser.py +339 -0
  19. agent/loops/openai.py +95 -0
  20. agent/loops/uitars.py +688 -0
  21. agent/responses.py +207 -0
  22. agent/types.py +79 -0
  23. agent/ui/__init__.py +7 -1
  24. agent/ui/gradio/__init__.py +6 -19
  25. agent/ui/gradio/app.py +80 -1299
  26. agent/ui/gradio/ui_components.py +703 -0
  27. cua_agent-0.4.0b2.dist-info/METADATA +424 -0
  28. cua_agent-0.4.0b2.dist-info/RECORD +30 -0
  29. agent/core/__init__.py +0 -27
  30. agent/core/agent.py +0 -210
  31. agent/core/base.py +0 -217
  32. agent/core/callbacks.py +0 -200
  33. agent/core/experiment.py +0 -249
  34. agent/core/factory.py +0 -122
  35. agent/core/messages.py +0 -332
  36. agent/core/provider_config.py +0 -21
  37. agent/core/telemetry.py +0 -142
  38. agent/core/tools/__init__.py +0 -21
  39. agent/core/tools/base.py +0 -74
  40. agent/core/tools/bash.py +0 -52
  41. agent/core/tools/collection.py +0 -46
  42. agent/core/tools/computer.py +0 -113
  43. agent/core/tools/edit.py +0 -67
  44. agent/core/tools/manager.py +0 -56
  45. agent/core/tools.py +0 -32
  46. agent/core/types.py +0 -88
  47. agent/core/visualization.py +0 -197
  48. agent/providers/__init__.py +0 -4
  49. agent/providers/anthropic/__init__.py +0 -6
  50. agent/providers/anthropic/api/client.py +0 -360
  51. agent/providers/anthropic/api/logging.py +0 -150
  52. agent/providers/anthropic/api_handler.py +0 -140
  53. agent/providers/anthropic/callbacks/__init__.py +0 -5
  54. agent/providers/anthropic/callbacks/manager.py +0 -65
  55. agent/providers/anthropic/loop.py +0 -568
  56. agent/providers/anthropic/prompts.py +0 -23
  57. agent/providers/anthropic/response_handler.py +0 -226
  58. agent/providers/anthropic/tools/__init__.py +0 -33
  59. agent/providers/anthropic/tools/base.py +0 -88
  60. agent/providers/anthropic/tools/bash.py +0 -66
  61. agent/providers/anthropic/tools/collection.py +0 -34
  62. agent/providers/anthropic/tools/computer.py +0 -396
  63. agent/providers/anthropic/tools/edit.py +0 -326
  64. agent/providers/anthropic/tools/manager.py +0 -54
  65. agent/providers/anthropic/tools/run.py +0 -42
  66. agent/providers/anthropic/types.py +0 -16
  67. agent/providers/anthropic/utils.py +0 -381
  68. agent/providers/omni/__init__.py +0 -8
  69. agent/providers/omni/api_handler.py +0 -42
  70. agent/providers/omni/clients/anthropic.py +0 -103
  71. agent/providers/omni/clients/base.py +0 -35
  72. agent/providers/omni/clients/oaicompat.py +0 -195
  73. agent/providers/omni/clients/ollama.py +0 -122
  74. agent/providers/omni/clients/openai.py +0 -155
  75. agent/providers/omni/clients/utils.py +0 -25
  76. agent/providers/omni/image_utils.py +0 -34
  77. agent/providers/omni/loop.py +0 -990
  78. agent/providers/omni/parser.py +0 -307
  79. agent/providers/omni/prompts.py +0 -64
  80. agent/providers/omni/tools/__init__.py +0 -30
  81. agent/providers/omni/tools/base.py +0 -29
  82. agent/providers/omni/tools/bash.py +0 -74
  83. agent/providers/omni/tools/computer.py +0 -179
  84. agent/providers/omni/tools/manager.py +0 -61
  85. agent/providers/omni/utils.py +0 -236
  86. agent/providers/openai/__init__.py +0 -6
  87. agent/providers/openai/api_handler.py +0 -456
  88. agent/providers/openai/loop.py +0 -472
  89. agent/providers/openai/response_handler.py +0 -205
  90. agent/providers/openai/tools/__init__.py +0 -15
  91. agent/providers/openai/tools/base.py +0 -79
  92. agent/providers/openai/tools/computer.py +0 -326
  93. agent/providers/openai/tools/manager.py +0 -106
  94. agent/providers/openai/types.py +0 -36
  95. agent/providers/openai/utils.py +0 -98
  96. agent/providers/uitars/__init__.py +0 -1
  97. agent/providers/uitars/clients/base.py +0 -35
  98. agent/providers/uitars/clients/mlxvlm.py +0 -263
  99. agent/providers/uitars/clients/oaicompat.py +0 -214
  100. agent/providers/uitars/loop.py +0 -660
  101. agent/providers/uitars/prompts.py +0 -63
  102. agent/providers/uitars/tools/__init__.py +0 -1
  103. agent/providers/uitars/tools/computer.py +0 -283
  104. agent/providers/uitars/tools/manager.py +0 -60
  105. agent/providers/uitars/utils.py +0 -264
  106. agent/telemetry.py +0 -21
  107. agent/ui/__main__.py +0 -15
  108. cua_agent-0.3.2.dist-info/METADATA +0 -295
  109. cua_agent-0.3.2.dist-info/RECORD +0 -87
  110. {cua_agent-0.3.2.dist-info → cua_agent-0.4.0b2.dist-info}/WHEEL +0 -0
  111. {cua_agent-0.3.2.dist-info → cua_agent-0.4.0b2.dist-info}/entry_points.txt +0 -0
agent/agent.py ADDED
@@ -0,0 +1,577 @@
1
+ """
2
+ ComputerAgent - Main agent class that selects and runs agent loops
3
+ """
4
+
5
+ import asyncio
6
+ from typing import Dict, List, Any, Optional, AsyncGenerator, Union, cast, Callable, Set
7
+
8
+ from litellm.responses.utils import Usage
9
+ from .types import Messages, Computer
10
+ from .decorators import find_agent_loop
11
+ from .computer_handler import OpenAIComputerHandler, acknowledge_safety_check_callback, check_blocklisted_url
12
+ import json
13
+ import litellm
14
+ import litellm.utils
15
+ import inspect
16
+ from .adapters import HuggingFaceLocalAdapter
17
+ from .callbacks import ImageRetentionCallback, LoggingCallback, TrajectorySaverCallback, BudgetManagerCallback
18
+
19
+ def get_json(obj: Any, max_depth: int = 10) -> Any:
20
+ def custom_serializer(o: Any, depth: int = 0, seen: Set[int] = None) -> Any:
21
+ if seen is None:
22
+ seen = set()
23
+
24
+ # Use model_dump() if available
25
+ if hasattr(o, 'model_dump'):
26
+ return o.model_dump()
27
+
28
+ # Check depth limit
29
+ if depth > max_depth:
30
+ return f"<max_depth_exceeded:{max_depth}>"
31
+
32
+ # Check for circular references using object id
33
+ obj_id = id(o)
34
+ if obj_id in seen:
35
+ return f"<circular_reference:{type(o).__name__}>"
36
+
37
+ # Handle Computer objects
38
+ if hasattr(o, '__class__') and 'computer' in getattr(o, '__class__').__name__.lower():
39
+ return f"<computer:{o.__class__.__name__}>"
40
+
41
+ # Handle objects with __dict__
42
+ if hasattr(o, '__dict__'):
43
+ seen.add(obj_id)
44
+ try:
45
+ result = {}
46
+ for k, v in o.__dict__.items():
47
+ if v is not None:
48
+ # Recursively serialize with updated depth and seen set
49
+ serialized_value = custom_serializer(v, depth + 1, seen.copy())
50
+ result[k] = serialized_value
51
+ return result
52
+ finally:
53
+ seen.discard(obj_id)
54
+
55
+ # Handle common types that might contain nested objects
56
+ elif isinstance(o, dict):
57
+ seen.add(obj_id)
58
+ try:
59
+ return {
60
+ k: custom_serializer(v, depth + 1, seen.copy())
61
+ for k, v in o.items()
62
+ if v is not None
63
+ }
64
+ finally:
65
+ seen.discard(obj_id)
66
+
67
+ elif isinstance(o, (list, tuple, set)):
68
+ seen.add(obj_id)
69
+ try:
70
+ return [
71
+ custom_serializer(item, depth + 1, seen.copy())
72
+ for item in o
73
+ if item is not None
74
+ ]
75
+ finally:
76
+ seen.discard(obj_id)
77
+
78
+ # For basic types that json.dumps can handle
79
+ elif isinstance(o, (str, int, float, bool)) or o is None:
80
+ return o
81
+
82
+ # Fallback to string representation
83
+ else:
84
+ return str(o)
85
+
86
+ def remove_nones(obj: Any) -> Any:
87
+ if isinstance(obj, dict):
88
+ return {k: remove_nones(v) for k, v in obj.items() if v is not None}
89
+ elif isinstance(obj, list):
90
+ return [remove_nones(item) for item in obj if item is not None]
91
+ return obj
92
+
93
+ # Serialize with circular reference and depth protection
94
+ serialized = custom_serializer(obj)
95
+
96
+ # Convert to JSON string and back to ensure JSON compatibility
97
+ json_str = json.dumps(serialized)
98
+ parsed = json.loads(json_str)
99
+
100
+ # Final cleanup of any remaining None values
101
+ return remove_nones(parsed)
102
+
103
+ def sanitize_message(msg: Any) -> Any:
104
+ """Return a copy of the message with image_url omitted for computer_call_output messages."""
105
+ if msg.get("type") == "computer_call_output":
106
+ output = msg.get("output", {})
107
+ if isinstance(output, dict):
108
+ sanitized = msg.copy()
109
+ sanitized["output"] = {**output, "image_url": "[omitted]"}
110
+ return sanitized
111
+ return msg
112
+
113
+ class ComputerAgent:
114
+ """
115
+ Main agent class that automatically selects the appropriate agent loop
116
+ based on the model and executes tool calls.
117
+ """
118
+
119
+ def __init__(
120
+ self,
121
+ model: str,
122
+ tools: Optional[List[Any]] = None,
123
+ custom_loop: Optional[Callable] = None,
124
+ only_n_most_recent_images: Optional[int] = None,
125
+ callbacks: Optional[List[Any]] = None,
126
+ verbosity: Optional[int] = None,
127
+ trajectory_dir: Optional[str] = None,
128
+ max_retries: Optional[int] = 3,
129
+ screenshot_delay: Optional[float | int] = 0.5,
130
+ use_prompt_caching: Optional[bool] = False,
131
+ max_trajectory_budget: Optional[float | dict] = None,
132
+ **kwargs
133
+ ):
134
+ """
135
+ Initialize ComputerAgent.
136
+
137
+ Args:
138
+ model: Model name (e.g., "claude-3-5-sonnet-20241022", "computer-use-preview", "omni+vertex_ai/gemini-pro")
139
+ tools: List of tools (computer objects, decorated functions, etc.)
140
+ custom_loop: Custom agent loop function to use instead of auto-selection
141
+ only_n_most_recent_images: If set, only keep the N most recent images in message history. Adds ImageRetentionCallback automatically.
142
+ callbacks: List of AsyncCallbackHandler instances for preprocessing/postprocessing
143
+ verbosity: Logging level (logging.DEBUG, logging.INFO, etc.). If set, adds LoggingCallback automatically
144
+ trajectory_dir: If set, saves trajectory data (screenshots, responses) to this directory. Adds TrajectorySaverCallback automatically.
145
+ max_retries: Maximum number of retries for failed API calls
146
+ screenshot_delay: Delay before screenshots in seconds
147
+ use_prompt_caching: If set, use prompt caching to avoid reprocessing the same prompt. Intended for use with anthropic providers.
148
+ max_trajectory_budget: If set, adds BudgetManagerCallback to track usage costs and stop when budget is exceeded
149
+ **kwargs: Additional arguments passed to the agent loop
150
+ """
151
+ self.model = model
152
+ self.tools = tools or []
153
+ self.custom_loop = custom_loop
154
+ self.only_n_most_recent_images = only_n_most_recent_images
155
+ self.callbacks = callbacks or []
156
+ self.verbosity = verbosity
157
+ self.trajectory_dir = trajectory_dir
158
+ self.max_retries = max_retries
159
+ self.screenshot_delay = screenshot_delay
160
+ self.use_prompt_caching = use_prompt_caching
161
+ self.kwargs = kwargs
162
+
163
+ # == Add built-in callbacks ==
164
+
165
+ # Add logging callback if verbosity is set
166
+ if self.verbosity is not None:
167
+ self.callbacks.append(LoggingCallback(level=self.verbosity))
168
+
169
+ # Add image retention callback if only_n_most_recent_images is set
170
+ if self.only_n_most_recent_images:
171
+ self.callbacks.append(ImageRetentionCallback(self.only_n_most_recent_images))
172
+
173
+ # Add trajectory saver callback if trajectory_dir is set
174
+ if self.trajectory_dir:
175
+ self.callbacks.append(TrajectorySaverCallback(self.trajectory_dir))
176
+
177
+ # Add budget manager if max_trajectory_budget is set
178
+ if max_trajectory_budget:
179
+ if isinstance(max_trajectory_budget, dict):
180
+ self.callbacks.append(BudgetManagerCallback(**max_trajectory_budget))
181
+ else:
182
+ self.callbacks.append(BudgetManagerCallback(max_trajectory_budget))
183
+
184
+ # == Enable local model providers w/ LiteLLM ==
185
+
186
+ # Register local model providers
187
+ hf_adapter = HuggingFaceLocalAdapter(
188
+ device="auto"
189
+ )
190
+ litellm.custom_provider_map = [
191
+ {"provider": "huggingface-local", "custom_handler": hf_adapter}
192
+ ]
193
+
194
+ # == Initialize computer agent ==
195
+
196
+ # Find the appropriate agent loop
197
+ if custom_loop:
198
+ self.agent_loop = custom_loop
199
+ self.agent_loop_info = None
200
+ else:
201
+ loop_info = find_agent_loop(model)
202
+ if not loop_info:
203
+ raise ValueError(f"No agent loop found for model: {model}")
204
+ self.agent_loop = loop_info.func
205
+ self.agent_loop_info = loop_info
206
+
207
+ self.tool_schemas = []
208
+ self.computer_handler = None
209
+
210
+ async def _initialize_computers(self):
211
+ """Initialize computer objects"""
212
+ if not self.tool_schemas:
213
+ for tool in self.tools:
214
+ if hasattr(tool, '_initialized') and not tool._initialized:
215
+ await tool.run()
216
+
217
+ # Process tools and create tool schemas
218
+ self.tool_schemas = self._process_tools()
219
+
220
+ # Find computer tool and create interface adapter
221
+ computer_handler = None
222
+ for schema in self.tool_schemas:
223
+ if schema["type"] == "computer":
224
+ computer_handler = OpenAIComputerHandler(schema["computer"].interface)
225
+ break
226
+ self.computer_handler = computer_handler
227
+
228
+ def _process_input(self, input: Messages) -> List[Dict[str, Any]]:
229
+ """Process input messages and create schemas for the agent loop"""
230
+ if isinstance(input, str):
231
+ return [{"role": "user", "content": input}]
232
+ return [get_json(msg) for msg in input]
233
+
234
+ def _process_tools(self) -> List[Dict[str, Any]]:
235
+ """Process tools and create schemas for the agent loop"""
236
+ schemas = []
237
+
238
+ for tool in self.tools:
239
+ # Check if it's a computer object (has interface attribute)
240
+ if hasattr(tool, 'interface'):
241
+ # This is a computer tool - will be handled by agent loop
242
+ schemas.append({
243
+ "type": "computer",
244
+ "computer": tool
245
+ })
246
+ elif callable(tool):
247
+ # Use litellm.utils.function_to_dict to extract schema from docstring
248
+ try:
249
+ function_schema = litellm.utils.function_to_dict(tool)
250
+ schemas.append({
251
+ "type": "function",
252
+ "function": function_schema
253
+ })
254
+ except Exception as e:
255
+ print(f"Warning: Could not process tool {tool}: {e}")
256
+ else:
257
+ print(f"Warning: Unknown tool type: {tool}")
258
+
259
+ return schemas
260
+
261
+ def _get_tool(self, name: str) -> Optional[Callable]:
262
+ """Get a tool by name"""
263
+ for tool in self.tools:
264
+ if hasattr(tool, '__name__') and tool.__name__ == name:
265
+ return tool
266
+ elif hasattr(tool, 'func') and tool.func.__name__ == name:
267
+ return tool
268
+ return None
269
+
270
+ # ============================================================================
271
+ # AGENT RUN LOOP LIFECYCLE HOOKS
272
+ # ============================================================================
273
+
274
+ async def _on_run_start(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]]) -> None:
275
+ """Initialize run tracking by calling callbacks."""
276
+ for callback in self.callbacks:
277
+ if hasattr(callback, 'on_run_start'):
278
+ await callback.on_run_start(kwargs, old_items)
279
+
280
+ async def _on_run_end(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]], new_items: List[Dict[str, Any]]) -> None:
281
+ """Finalize run tracking by calling callbacks."""
282
+ for callback in self.callbacks:
283
+ if hasattr(callback, 'on_run_end'):
284
+ await callback.on_run_end(kwargs, old_items, new_items)
285
+
286
+ async def _on_run_continue(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]], new_items: List[Dict[str, Any]]) -> bool:
287
+ """Check if run should continue by calling callbacks."""
288
+ for callback in self.callbacks:
289
+ if hasattr(callback, 'on_run_continue'):
290
+ should_continue = await callback.on_run_continue(kwargs, old_items, new_items)
291
+ if not should_continue:
292
+ return False
293
+ return True
294
+
295
+ async def _on_llm_start(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
296
+ """Prepare messages for the LLM call by applying callbacks."""
297
+ result = messages
298
+ for callback in self.callbacks:
299
+ if hasattr(callback, 'on_llm_start'):
300
+ result = await callback.on_llm_start(result)
301
+ return result
302
+
303
+ async def _on_llm_end(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
304
+ """Postprocess messages after the LLM call by applying callbacks."""
305
+ result = messages
306
+ for callback in self.callbacks:
307
+ if hasattr(callback, 'on_llm_end'):
308
+ result = await callback.on_llm_end(result)
309
+ return result
310
+
311
+ async def _on_responses(self, kwargs: Dict[str, Any], responses: Dict[str, Any]) -> None:
312
+ """Called when responses are received."""
313
+ for callback in self.callbacks:
314
+ if hasattr(callback, 'on_responses'):
315
+ await callback.on_responses(get_json(kwargs), get_json(responses))
316
+
317
+ async def _on_computer_call_start(self, item: Dict[str, Any]) -> None:
318
+ """Called when a computer call is about to start."""
319
+ for callback in self.callbacks:
320
+ if hasattr(callback, 'on_computer_call_start'):
321
+ await callback.on_computer_call_start(get_json(item))
322
+
323
+ async def _on_computer_call_end(self, item: Dict[str, Any], result: List[Dict[str, Any]]) -> None:
324
+ """Called when a computer call has completed."""
325
+ for callback in self.callbacks:
326
+ if hasattr(callback, 'on_computer_call_end'):
327
+ await callback.on_computer_call_end(get_json(item), get_json(result))
328
+
329
+ async def _on_function_call_start(self, item: Dict[str, Any]) -> None:
330
+ """Called when a function call is about to start."""
331
+ for callback in self.callbacks:
332
+ if hasattr(callback, 'on_function_call_start'):
333
+ await callback.on_function_call_start(get_json(item))
334
+
335
+ async def _on_function_call_end(self, item: Dict[str, Any], result: List[Dict[str, Any]]) -> None:
336
+ """Called when a function call has completed."""
337
+ for callback in self.callbacks:
338
+ if hasattr(callback, 'on_function_call_end'):
339
+ await callback.on_function_call_end(get_json(item), get_json(result))
340
+
341
+ async def _on_text(self, item: Dict[str, Any]) -> None:
342
+ """Called when a text message is encountered."""
343
+ for callback in self.callbacks:
344
+ if hasattr(callback, 'on_text'):
345
+ await callback.on_text(get_json(item))
346
+
347
+ async def _on_api_start(self, kwargs: Dict[str, Any]) -> None:
348
+ """Called when an LLM API call is about to start."""
349
+ for callback in self.callbacks:
350
+ if hasattr(callback, 'on_api_start'):
351
+ await callback.on_api_start(get_json(kwargs))
352
+
353
+ async def _on_api_end(self, kwargs: Dict[str, Any], result: Any) -> None:
354
+ """Called when an LLM API call has completed."""
355
+ for callback in self.callbacks:
356
+ if hasattr(callback, 'on_api_end'):
357
+ await callback.on_api_end(get_json(kwargs), get_json(result))
358
+
359
+ async def _on_usage(self, usage: Dict[str, Any]) -> None:
360
+ """Called when usage information is received."""
361
+ for callback in self.callbacks:
362
+ if hasattr(callback, 'on_usage'):
363
+ await callback.on_usage(get_json(usage))
364
+
365
+ async def _on_screenshot(self, screenshot: Union[str, bytes], name: str = "screenshot") -> None:
366
+ """Called when a screenshot is taken."""
367
+ for callback in self.callbacks:
368
+ if hasattr(callback, 'on_screenshot'):
369
+ await callback.on_screenshot(screenshot, name)
370
+
371
+ # ============================================================================
372
+ # AGENT OUTPUT PROCESSING
373
+ # ============================================================================
374
+
375
+ async def _handle_item(self, item: Any, computer: Optional[Computer] = None) -> List[Dict[str, Any]]:
376
+ """Handle each item; may cause a computer action + screenshot."""
377
+
378
+ item_type = item.get("type", None)
379
+
380
+ if item_type == "message":
381
+ await self._on_text(item)
382
+ # # Print messages
383
+ # if item.get("content"):
384
+ # for content_item in item.get("content"):
385
+ # if content_item.get("text"):
386
+ # print(content_item.get("text"))
387
+ return []
388
+
389
+ if item_type == "computer_call":
390
+ await self._on_computer_call_start(item)
391
+ if not computer:
392
+ raise ValueError("Computer handler is required for computer calls")
393
+
394
+ # Perform computer actions
395
+ action = item.get("action")
396
+ action_type = action.get("type")
397
+
398
+ # Extract action arguments (all fields except 'type')
399
+ action_args = {k: v for k, v in action.items() if k != "type"}
400
+
401
+ # print(f"{action_type}({action_args})")
402
+
403
+ # Execute the computer action
404
+ computer_method = getattr(computer, action_type, None)
405
+ if computer_method:
406
+ await computer_method(**action_args)
407
+ else:
408
+ print(f"Unknown computer action: {action_type}")
409
+ return []
410
+
411
+ # Take screenshot after action
412
+ if self.screenshot_delay and self.screenshot_delay > 0:
413
+ await asyncio.sleep(self.screenshot_delay)
414
+ screenshot_base64 = await computer.screenshot()
415
+ await self._on_screenshot(screenshot_base64, "screenshot_after")
416
+
417
+ # Handle safety checks
418
+ pending_checks = item.get("pending_safety_checks", [])
419
+ acknowledged_checks = []
420
+ for check in pending_checks:
421
+ check_message = check.get("message", str(check))
422
+ if acknowledge_safety_check_callback(check_message):
423
+ acknowledged_checks.append(check)
424
+ else:
425
+ raise ValueError(f"Safety check failed: {check_message}")
426
+
427
+ # Create call output
428
+ call_output = {
429
+ "type": "computer_call_output",
430
+ "call_id": item.get("call_id"),
431
+ "acknowledged_safety_checks": acknowledged_checks,
432
+ "output": {
433
+ "type": "input_image",
434
+ "image_url": f"data:image/png;base64,{screenshot_base64}",
435
+ },
436
+ }
437
+
438
+ # Additional URL safety checks for browser environments
439
+ if await computer.get_environment() == "browser":
440
+ current_url = await computer.get_current_url()
441
+ call_output["output"]["current_url"] = current_url
442
+ check_blocklisted_url(current_url)
443
+
444
+ result = [call_output]
445
+ await self._on_computer_call_end(item, result)
446
+ return result
447
+
448
+ if item_type == "function_call":
449
+ await self._on_function_call_start(item)
450
+ # Perform function call
451
+ function = self._get_tool(item.get("name"))
452
+ if not function:
453
+ raise ValueError(f"Function {item.get("name")} not found")
454
+
455
+ args = json.loads(item.get("arguments"))
456
+
457
+ # Execute function - use asyncio.to_thread for non-async functions
458
+ if inspect.iscoroutinefunction(function):
459
+ result = await function(**args)
460
+ else:
461
+ result = await asyncio.to_thread(function, **args)
462
+
463
+ # Create function call output
464
+ call_output = {
465
+ "type": "function_call_output",
466
+ "call_id": item.get("call_id"),
467
+ "output": str(result),
468
+ }
469
+
470
+ result = [call_output]
471
+ await self._on_function_call_end(item, result)
472
+ return result
473
+
474
+ return []
475
+
476
+ # ============================================================================
477
+ # MAIN AGENT LOOP
478
+ # ============================================================================
479
+
480
+ async def run(
481
+ self,
482
+ messages: Messages,
483
+ stream: bool = False,
484
+ **kwargs
485
+ ) -> AsyncGenerator[Dict[str, Any], None]:
486
+ """
487
+ Run the agent with the given messages using Computer protocol handler pattern.
488
+
489
+ Args:
490
+ messages: List of message dictionaries
491
+ stream: Whether to stream the response
492
+ **kwargs: Additional arguments
493
+
494
+ Returns:
495
+ AsyncGenerator that yields response chunks
496
+ """
497
+
498
+ await self._initialize_computers()
499
+
500
+ # Merge kwargs
501
+ merged_kwargs = {**self.kwargs, **kwargs}
502
+
503
+ old_items = self._process_input(messages)
504
+ new_items = []
505
+
506
+ # Initialize run tracking
507
+ run_kwargs = {
508
+ "messages": messages,
509
+ "stream": stream,
510
+ "model": self.model,
511
+ "agent_loop": self.agent_loop.__name__,
512
+ **merged_kwargs
513
+ }
514
+ await self._on_run_start(run_kwargs, old_items)
515
+
516
+ while new_items[-1].get("role") != "assistant" if new_items else True:
517
+ # Lifecycle hook: Check if we should continue based on callbacks (e.g., budget manager)
518
+ should_continue = await self._on_run_continue(run_kwargs, old_items, new_items)
519
+ if not should_continue:
520
+ break
521
+
522
+ # Lifecycle hook: Prepare messages for the LLM call
523
+ # Use cases:
524
+ # - PII anonymization
525
+ # - Image retention policy
526
+ combined_messages = old_items + new_items
527
+ preprocessed_messages = await self._on_llm_start(combined_messages)
528
+
529
+ loop_kwargs = {
530
+ "messages": preprocessed_messages,
531
+ "model": self.model,
532
+ "tools": self.tool_schemas,
533
+ "stream": False,
534
+ "computer_handler": self.computer_handler,
535
+ "max_retries": self.max_retries,
536
+ "use_prompt_caching": self.use_prompt_caching,
537
+ **merged_kwargs
538
+ }
539
+
540
+ # Run agent loop iteration
541
+ result = await self.agent_loop(
542
+ **loop_kwargs,
543
+ _on_api_start=self._on_api_start,
544
+ _on_api_end=self._on_api_end,
545
+ _on_usage=self._on_usage,
546
+ _on_screenshot=self._on_screenshot,
547
+ )
548
+ result = get_json(result)
549
+
550
+ # Lifecycle hook: Postprocess messages after the LLM call
551
+ # Use cases:
552
+ # - PII deanonymization (if you want tool calls to see PII)
553
+ result["output"] = await self._on_llm_end(result.get("output", []))
554
+ await self._on_responses(loop_kwargs, result)
555
+
556
+ # Yield agent response
557
+ yield result
558
+
559
+ # Add agent response to new_items
560
+ new_items += result.get("output")
561
+
562
+ # Handle computer actions
563
+ for item in result.get("output"):
564
+ partial_items = await self._handle_item(item, self.computer_handler)
565
+ new_items += partial_items
566
+
567
+ # Yield partial response
568
+ yield {
569
+ "output": partial_items,
570
+ "usage": Usage(
571
+ prompt_tokens=0,
572
+ completion_tokens=0,
573
+ total_tokens=0,
574
+ )
575
+ }
576
+
577
+ await self._on_run_end(loop_kwargs, old_items, new_items)
@@ -0,0 +1,17 @@
1
"""
Callback system for ComputerAgent preprocessing and postprocessing hooks.
"""

from .base import AsyncCallbackHandler
from .image_retention import ImageRetentionCallback
from .logging import LoggingCallback
from .trajectory_saver import TrajectorySaverCallback
from .budget_manager import BudgetManagerCallback

# Public API of the callbacks package: the abstract handler base class plus
# the built-in callbacks that ComputerAgent wires up automatically.
__all__ = [
    "AsyncCallbackHandler",
    "ImageRetentionCallback",
    "LoggingCallback",
    "TrajectorySaverCallback",
    "BudgetManagerCallback",
]