cua-agent 0.3.2__py3-none-any.whl → 0.4.0b1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of cua-agent might be problematic; see the registry's security advisory page for details.
- agent/__init__.py +15 -51
- agent/__main__.py +21 -0
- agent/adapters/__init__.py +9 -0
- agent/adapters/huggingfacelocal_adapter.py +216 -0
- agent/agent.py +577 -0
- agent/callbacks/__init__.py +17 -0
- agent/callbacks/base.py +153 -0
- agent/callbacks/budget_manager.py +44 -0
- agent/callbacks/image_retention.py +139 -0
- agent/callbacks/logging.py +247 -0
- agent/callbacks/pii_anonymization.py +259 -0
- agent/callbacks/trajectory_saver.py +305 -0
- agent/cli.py +290 -0
- agent/computer_handler.py +107 -0
- agent/decorators.py +90 -0
- agent/loops/__init__.py +11 -0
- agent/loops/anthropic.py +728 -0
- agent/loops/omniparser.py +339 -0
- agent/loops/openai.py +95 -0
- agent/loops/uitars.py +688 -0
- agent/responses.py +207 -0
- agent/types.py +79 -0
- agent/ui/__init__.py +7 -1
- agent/ui/gradio/__init__.py +6 -19
- agent/ui/gradio/app.py +80 -1299
- agent/ui/gradio/ui_components.py +703 -0
- cua_agent-0.4.0b1.dist-info/METADATA +424 -0
- cua_agent-0.4.0b1.dist-info/RECORD +30 -0
- agent/core/__init__.py +0 -27
- agent/core/agent.py +0 -210
- agent/core/base.py +0 -217
- agent/core/callbacks.py +0 -200
- agent/core/experiment.py +0 -249
- agent/core/factory.py +0 -122
- agent/core/messages.py +0 -332
- agent/core/provider_config.py +0 -21
- agent/core/telemetry.py +0 -142
- agent/core/tools/__init__.py +0 -21
- agent/core/tools/base.py +0 -74
- agent/core/tools/bash.py +0 -52
- agent/core/tools/collection.py +0 -46
- agent/core/tools/computer.py +0 -113
- agent/core/tools/edit.py +0 -67
- agent/core/tools/manager.py +0 -56
- agent/core/tools.py +0 -32
- agent/core/types.py +0 -88
- agent/core/visualization.py +0 -197
- agent/providers/__init__.py +0 -4
- agent/providers/anthropic/__init__.py +0 -6
- agent/providers/anthropic/api/client.py +0 -360
- agent/providers/anthropic/api/logging.py +0 -150
- agent/providers/anthropic/api_handler.py +0 -140
- agent/providers/anthropic/callbacks/__init__.py +0 -5
- agent/providers/anthropic/callbacks/manager.py +0 -65
- agent/providers/anthropic/loop.py +0 -568
- agent/providers/anthropic/prompts.py +0 -23
- agent/providers/anthropic/response_handler.py +0 -226
- agent/providers/anthropic/tools/__init__.py +0 -33
- agent/providers/anthropic/tools/base.py +0 -88
- agent/providers/anthropic/tools/bash.py +0 -66
- agent/providers/anthropic/tools/collection.py +0 -34
- agent/providers/anthropic/tools/computer.py +0 -396
- agent/providers/anthropic/tools/edit.py +0 -326
- agent/providers/anthropic/tools/manager.py +0 -54
- agent/providers/anthropic/tools/run.py +0 -42
- agent/providers/anthropic/types.py +0 -16
- agent/providers/anthropic/utils.py +0 -381
- agent/providers/omni/__init__.py +0 -8
- agent/providers/omni/api_handler.py +0 -42
- agent/providers/omni/clients/anthropic.py +0 -103
- agent/providers/omni/clients/base.py +0 -35
- agent/providers/omni/clients/oaicompat.py +0 -195
- agent/providers/omni/clients/ollama.py +0 -122
- agent/providers/omni/clients/openai.py +0 -155
- agent/providers/omni/clients/utils.py +0 -25
- agent/providers/omni/image_utils.py +0 -34
- agent/providers/omni/loop.py +0 -990
- agent/providers/omni/parser.py +0 -307
- agent/providers/omni/prompts.py +0 -64
- agent/providers/omni/tools/__init__.py +0 -30
- agent/providers/omni/tools/base.py +0 -29
- agent/providers/omni/tools/bash.py +0 -74
- agent/providers/omni/tools/computer.py +0 -179
- agent/providers/omni/tools/manager.py +0 -61
- agent/providers/omni/utils.py +0 -236
- agent/providers/openai/__init__.py +0 -6
- agent/providers/openai/api_handler.py +0 -456
- agent/providers/openai/loop.py +0 -472
- agent/providers/openai/response_handler.py +0 -205
- agent/providers/openai/tools/__init__.py +0 -15
- agent/providers/openai/tools/base.py +0 -79
- agent/providers/openai/tools/computer.py +0 -326
- agent/providers/openai/tools/manager.py +0 -106
- agent/providers/openai/types.py +0 -36
- agent/providers/openai/utils.py +0 -98
- agent/providers/uitars/__init__.py +0 -1
- agent/providers/uitars/clients/base.py +0 -35
- agent/providers/uitars/clients/mlxvlm.py +0 -263
- agent/providers/uitars/clients/oaicompat.py +0 -214
- agent/providers/uitars/loop.py +0 -660
- agent/providers/uitars/prompts.py +0 -63
- agent/providers/uitars/tools/__init__.py +0 -1
- agent/providers/uitars/tools/computer.py +0 -283
- agent/providers/uitars/tools/manager.py +0 -60
- agent/providers/uitars/utils.py +0 -264
- agent/telemetry.py +0 -21
- agent/ui/__main__.py +0 -15
- cua_agent-0.3.2.dist-info/METADATA +0 -295
- cua_agent-0.3.2.dist-info/RECORD +0 -87
- {cua_agent-0.3.2.dist-info → cua_agent-0.4.0b1.dist-info}/WHEEL +0 -0
- {cua_agent-0.3.2.dist-info → cua_agent-0.4.0b1.dist-info}/entry_points.txt +0 -0
agent/agent.py
ADDED
|
@@ -0,0 +1,577 @@
|
|
|
1
|
+
"""
|
|
2
|
+
ComputerAgent - Main agent class that selects and runs agent loops
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import asyncio
|
|
6
|
+
from typing import Dict, List, Any, Optional, AsyncGenerator, Union, cast, Callable, Set
|
|
7
|
+
|
|
8
|
+
from litellm.responses.utils import Usage
|
|
9
|
+
from .types import Messages, Computer
|
|
10
|
+
from .decorators import find_agent_loop
|
|
11
|
+
from .computer_handler import OpenAIComputerHandler, acknowledge_safety_check_callback, check_blocklisted_url
|
|
12
|
+
import json
|
|
13
|
+
import litellm
|
|
14
|
+
import litellm.utils
|
|
15
|
+
import inspect
|
|
16
|
+
from .adapters import HuggingFaceLocalAdapter
|
|
17
|
+
from .callbacks import ImageRetentionCallback, LoggingCallback, TrajectorySaverCallback, BudgetManagerCallback
|
|
18
|
+
|
|
19
|
+
def get_json(obj: Any, max_depth: int = 10) -> Any:
    """Serialize an arbitrary object into JSON-compatible data, dropping ``None`` values.

    Supports pydantic-style objects (via ``model_dump()``), plain objects
    (via ``__dict__``), dicts/lists/tuples/sets, and primitives. Guards
    against circular references and unbounded nesting, emitting marker
    strings (``<circular_reference:...>``, ``<max_depth_exceeded:...>``)
    instead of recursing forever. Computer objects are replaced with a
    ``<computer:ClassName>`` placeholder rather than serialized.

    Args:
        obj: Object to serialize.
        max_depth: Maximum recursion depth before truncating with a marker.

    Returns:
        A JSON-compatible structure (dict/list/str/int/float/bool) with all
        ``None`` values removed from containers.
    """
    def custom_serializer(o: Any, depth: int = 0, seen: Optional[Set[int]] = None) -> Any:
        # `seen` tracks object ids along the current recursion path only
        # (children receive a copy), so shared-but-acyclic structure is
        # still serialized; only true cycles are cut.
        if seen is None:
            seen = set()

        # Use model_dump() if available (pydantic-style models)
        if hasattr(o, 'model_dump'):
            return o.model_dump()

        # Check depth limit
        if depth > max_depth:
            return f"<max_depth_exceeded:{max_depth}>"

        # Check for circular references using object id
        obj_id = id(o)
        if obj_id in seen:
            return f"<circular_reference:{type(o).__name__}>"

        # Handle Computer objects: never serialize their internals
        if hasattr(o, '__class__') and 'computer' in getattr(o, '__class__').__name__.lower():
            return f"<computer:{o.__class__.__name__}>"

        # Handle objects with __dict__
        if hasattr(o, '__dict__'):
            seen.add(obj_id)
            try:
                result = {}
                for k, v in o.__dict__.items():
                    if v is not None:
                        # Recursively serialize with updated depth and seen set
                        result[k] = custom_serializer(v, depth + 1, seen.copy())
                return result
            finally:
                seen.discard(obj_id)

        # Handle common container types that might contain nested objects
        elif isinstance(o, dict):
            seen.add(obj_id)
            try:
                return {
                    k: custom_serializer(v, depth + 1, seen.copy())
                    for k, v in o.items()
                    if v is not None
                }
            finally:
                seen.discard(obj_id)

        elif isinstance(o, (list, tuple, set)):
            seen.add(obj_id)
            try:
                return [
                    custom_serializer(item, depth + 1, seen.copy())
                    for item in o
                    if item is not None
                ]
            finally:
                seen.discard(obj_id)

        # Basic types that json.dumps can handle directly
        elif isinstance(o, (str, int, float, bool)) or o is None:
            return o

        # Fallback to string representation
        else:
            return str(o)

    def remove_nones(obj: Any) -> Any:
        # Final sweep: strip None values that may have been produced by
        # model_dump() or str() conversions above.
        if isinstance(obj, dict):
            return {k: remove_nones(v) for k, v in obj.items() if v is not None}
        elif isinstance(obj, list):
            return [remove_nones(item) for item in obj if item is not None]
        return obj

    # Serialize with circular reference and depth protection
    serialized = custom_serializer(obj)

    # Round-trip through a JSON string to guarantee JSON compatibility
    json_str = json.dumps(serialized)
    parsed = json.loads(json_str)

    # Final cleanup of any remaining None values
    return remove_nones(parsed)
|
|
103
|
+
def sanitize_message(msg: Any) -> Any:
    """Return a copy of the message with image_url omitted for computer_call_output messages.

    Messages of any other type (or with a non-dict output) are returned
    unchanged; only the redacted case allocates a new dict.
    """
    if msg.get("type") != "computer_call_output":
        return msg

    output = msg.get("output", {})
    if not isinstance(output, dict):
        return msg

    # Shallow-copy the message and its output so the original is untouched,
    # then replace the (potentially huge) base64 image payload.
    redacted = dict(msg)
    redacted["output"] = {**output, "image_url": "[omitted]"}
    return redacted
|
113
|
+
class ComputerAgent:
    """
    Main agent class that automatically selects the appropriate agent loop
    based on the model and executes tool calls.
    """

    def __init__(
        self,
        model: str,
        tools: Optional[List[Any]] = None,
        custom_loop: Optional[Callable] = None,
        only_n_most_recent_images: Optional[int] = None,
        callbacks: Optional[List[Any]] = None,
        verbosity: Optional[int] = None,
        trajectory_dir: Optional[str] = None,
        max_retries: Optional[int] = 3,
        screenshot_delay: Optional[float | int] = 0.5,
        use_prompt_caching: Optional[bool] = False,
        max_trajectory_budget: Optional[float | dict] = None,
        **kwargs
    ):
        """
        Initialize ComputerAgent.

        Args:
            model: Model name (e.g., "claude-3-5-sonnet-20241022", "computer-use-preview", "omni+vertex_ai/gemini-pro")
            tools: List of tools (computer objects, decorated functions, etc.)
            custom_loop: Custom agent loop function to use instead of auto-selection
            only_n_most_recent_images: If set, only keep the N most recent images in message history. Adds ImageRetentionCallback automatically.
            callbacks: List of AsyncCallbackHandler instances for preprocessing/postprocessing
            verbosity: Logging level (logging.DEBUG, logging.INFO, etc.). If set, adds LoggingCallback automatically
            trajectory_dir: If set, saves trajectory data (screenshots, responses) to this directory. Adds TrajectorySaverCallback automatically.
            max_retries: Maximum number of retries for failed API calls
            screenshot_delay: Delay before screenshots in seconds
            use_prompt_caching: If set, use prompt caching to avoid reprocessing the same prompt. Intended for use with anthropic providers.
            max_trajectory_budget: If set, adds BudgetManagerCallback to track usage costs and stop when budget is exceeded
            **kwargs: Additional arguments passed to the agent loop
        """
        self.model = model
        self.tools = tools or []
        self.custom_loop = custom_loop
        self.only_n_most_recent_images = only_n_most_recent_images
        # Copy the caller's list: built-in callbacks are appended below and
        # must not mutate the list object the caller passed in.
        self.callbacks = list(callbacks) if callbacks else []
        self.verbosity = verbosity
        self.trajectory_dir = trajectory_dir
        self.max_retries = max_retries
        self.screenshot_delay = screenshot_delay
        self.use_prompt_caching = use_prompt_caching
        self.kwargs = kwargs

        # == Add built-in callbacks ==

        # Add logging callback if verbosity is set
        if self.verbosity is not None:
            self.callbacks.append(LoggingCallback(level=self.verbosity))

        # Add image retention callback if only_n_most_recent_images is set
        if self.only_n_most_recent_images:
            self.callbacks.append(ImageRetentionCallback(self.only_n_most_recent_images))

        # Add trajectory saver callback if trajectory_dir is set
        if self.trajectory_dir:
            self.callbacks.append(TrajectorySaverCallback(self.trajectory_dir))

        # Add budget manager if max_trajectory_budget is set.
        # A dict is treated as keyword configuration for the callback; a
        # plain number is the budget itself.
        if max_trajectory_budget:
            if isinstance(max_trajectory_budget, dict):
                self.callbacks.append(BudgetManagerCallback(**max_trajectory_budget))
            else:
                self.callbacks.append(BudgetManagerCallback(max_trajectory_budget))

        # == Enable local model providers w/ LiteLLM ==

        # Register local model providers.
        # NOTE: this assigns a module-level litellm setting, so constructing
        # an agent overwrites any previously registered custom providers.
        hf_adapter = HuggingFaceLocalAdapter(
            device="auto"
        )
        litellm.custom_provider_map = [
            {"provider": "huggingface-local", "custom_handler": hf_adapter}
        ]

        # == Initialize computer agent ==

        # Find the appropriate agent loop: an explicit custom loop wins,
        # otherwise select by model name via the decorator registry.
        if custom_loop:
            self.agent_loop = custom_loop
            self.agent_loop_info = None
        else:
            loop_info = find_agent_loop(model)
            if not loop_info:
                raise ValueError(f"No agent loop found for model: {model}")
            self.agent_loop = loop_info.func
            self.agent_loop_info = loop_info

        # Populated lazily by _initialize_computers() on the first run().
        self.tool_schemas = []
        self.computer_handler = None

    async def _initialize_computers(self):
        """Start any computer tools and build tool schemas (idempotent)."""
        if not self.tool_schemas:
            # Boot computers that expose an _initialized flag and aren't up yet
            for tool in self.tools:
                if hasattr(tool, '_initialized') and not tool._initialized:
                    await tool.run()

            # Process tools and create tool schemas
            self.tool_schemas = self._process_tools()

            # Find the first computer tool and wrap its interface in the
            # OpenAI-style handler used by the agent loops.
            computer_handler = None
            for schema in self.tool_schemas:
                if schema["type"] == "computer":
                    computer_handler = OpenAIComputerHandler(schema["computer"].interface)
                    break
            self.computer_handler = computer_handler

    def _process_input(self, input: Messages) -> List[Dict[str, Any]]:
        """Normalize input into a list of JSON-compatible message dicts.

        A bare string becomes a single user message; an iterable of
        messages is serialized item-by-item via get_json().
        """
        if isinstance(input, str):
            return [{"role": "user", "content": input}]
        return [get_json(msg) for msg in input]

    def _process_tools(self) -> List[Dict[str, Any]]:
        """Build loop-facing schemas for each tool.

        Computer objects (anything with an ``interface`` attribute) are
        passed through as ``{"type": "computer"}`` entries; callables are
        converted to function schemas from their docstrings; anything else
        is skipped with a warning.
        """
        schemas = []

        for tool in self.tools:
            # Check if it's a computer object (has interface attribute)
            if hasattr(tool, 'interface'):
                # This is a computer tool - will be handled by agent loop
                schemas.append({
                    "type": "computer",
                    "computer": tool
                })
            elif callable(tool):
                # Use litellm.utils.function_to_dict to extract schema from docstring
                try:
                    function_schema = litellm.utils.function_to_dict(tool)
                    schemas.append({
                        "type": "function",
                        "function": function_schema
                    })
                except Exception as e:
                    print(f"Warning: Could not process tool {tool}: {e}")
            else:
                print(f"Warning: Unknown tool type: {tool}")

        return schemas

    def _get_tool(self, name: str) -> Optional[Callable]:
        """Look up a tool by function name (plain callables or wrappers with .func)."""
        for tool in self.tools:
            if hasattr(tool, '__name__') and tool.__name__ == name:
                return tool
            elif hasattr(tool, 'func') and tool.func.__name__ == name:
                return tool
        return None

    # ============================================================================
    # AGENT RUN LOOP LIFECYCLE HOOKS
    # Each hook fans out to every registered callback that implements it;
    # payloads are passed through get_json() so callbacks never see live
    # objects they could mutate.
    # ============================================================================

    async def _on_run_start(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]]) -> None:
        """Initialize run tracking by calling callbacks."""
        for callback in self.callbacks:
            if hasattr(callback, 'on_run_start'):
                await callback.on_run_start(kwargs, old_items)

    async def _on_run_end(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]], new_items: List[Dict[str, Any]]) -> None:
        """Finalize run tracking by calling callbacks."""
        for callback in self.callbacks:
            if hasattr(callback, 'on_run_end'):
                await callback.on_run_end(kwargs, old_items, new_items)

    async def _on_run_continue(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]], new_items: List[Dict[str, Any]]) -> bool:
        """Check if run should continue; any callback returning falsy stops the run."""
        for callback in self.callbacks:
            if hasattr(callback, 'on_run_continue'):
                should_continue = await callback.on_run_continue(kwargs, old_items, new_items)
                if not should_continue:
                    return False
        return True

    async def _on_llm_start(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Prepare messages for the LLM call by applying callbacks in order."""
        result = messages
        for callback in self.callbacks:
            if hasattr(callback, 'on_llm_start'):
                result = await callback.on_llm_start(result)
        return result

    async def _on_llm_end(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Postprocess messages after the LLM call by applying callbacks in order."""
        result = messages
        for callback in self.callbacks:
            if hasattr(callback, 'on_llm_end'):
                result = await callback.on_llm_end(result)
        return result

    async def _on_responses(self, kwargs: Dict[str, Any], responses: Dict[str, Any]) -> None:
        """Called when responses are received."""
        for callback in self.callbacks:
            if hasattr(callback, 'on_responses'):
                await callback.on_responses(get_json(kwargs), get_json(responses))

    async def _on_computer_call_start(self, item: Dict[str, Any]) -> None:
        """Called when a computer call is about to start."""
        for callback in self.callbacks:
            if hasattr(callback, 'on_computer_call_start'):
                await callback.on_computer_call_start(get_json(item))

    async def _on_computer_call_end(self, item: Dict[str, Any], result: List[Dict[str, Any]]) -> None:
        """Called when a computer call has completed."""
        for callback in self.callbacks:
            if hasattr(callback, 'on_computer_call_end'):
                await callback.on_computer_call_end(get_json(item), get_json(result))

    async def _on_function_call_start(self, item: Dict[str, Any]) -> None:
        """Called when a function call is about to start."""
        for callback in self.callbacks:
            if hasattr(callback, 'on_function_call_start'):
                await callback.on_function_call_start(get_json(item))

    async def _on_function_call_end(self, item: Dict[str, Any], result: List[Dict[str, Any]]) -> None:
        """Called when a function call has completed."""
        for callback in self.callbacks:
            if hasattr(callback, 'on_function_call_end'):
                await callback.on_function_call_end(get_json(item), get_json(result))

    async def _on_text(self, item: Dict[str, Any]) -> None:
        """Called when a text message is encountered."""
        for callback in self.callbacks:
            if hasattr(callback, 'on_text'):
                await callback.on_text(get_json(item))

    async def _on_api_start(self, kwargs: Dict[str, Any]) -> None:
        """Called when an LLM API call is about to start."""
        for callback in self.callbacks:
            if hasattr(callback, 'on_api_start'):
                await callback.on_api_start(get_json(kwargs))

    async def _on_api_end(self, kwargs: Dict[str, Any], result: Any) -> None:
        """Called when an LLM API call has completed."""
        for callback in self.callbacks:
            if hasattr(callback, 'on_api_end'):
                await callback.on_api_end(get_json(kwargs), get_json(result))

    async def _on_usage(self, usage: Dict[str, Any]) -> None:
        """Called when usage information is received."""
        for callback in self.callbacks:
            if hasattr(callback, 'on_usage'):
                await callback.on_usage(get_json(usage))

    async def _on_screenshot(self, screenshot: Union[str, bytes], name: str = "screenshot") -> None:
        """Called when a screenshot is taken."""
        for callback in self.callbacks:
            if hasattr(callback, 'on_screenshot'):
                await callback.on_screenshot(screenshot, name)

    # ============================================================================
    # AGENT OUTPUT PROCESSING
    # ============================================================================

    async def _handle_item(self, item: Any, computer: Optional[Computer] = None) -> List[Dict[str, Any]]:
        """Handle each item; may cause a computer action + screenshot.

        Returns the list of output messages produced by the item
        (computer_call_output / function_call_output), or [] when the item
        needs no response.
        """
        item_type = item.get("type", None)

        if item_type == "message":
            # Plain assistant text: notify callbacks, nothing to execute.
            await self._on_text(item)
            return []

        if item_type == "computer_call":
            await self._on_computer_call_start(item)
            if not computer:
                raise ValueError("Computer handler is required for computer calls")

            # Perform computer actions
            action = item.get("action")
            action_type = action.get("type")

            # Extract action arguments (all fields except 'type')
            action_args = {k: v for k, v in action.items() if k != "type"}

            # Execute the computer action by dispatching to the handler
            # method of the same name (click, type, scroll, ...).
            computer_method = getattr(computer, action_type, None)
            if computer_method:
                await computer_method(**action_args)
            else:
                print(f"Unknown computer action: {action_type}")
                return []

            # Take screenshot after action (optionally after a settle delay)
            if self.screenshot_delay and self.screenshot_delay > 0:
                await asyncio.sleep(self.screenshot_delay)
            screenshot_base64 = await computer.screenshot()
            await self._on_screenshot(screenshot_base64, "screenshot_after")

            # Handle safety checks: each must be explicitly acknowledged,
            # otherwise the run aborts.
            pending_checks = item.get("pending_safety_checks", [])
            acknowledged_checks = []
            for check in pending_checks:
                check_message = check.get("message", str(check))
                if acknowledge_safety_check_callback(check_message):
                    acknowledged_checks.append(check)
                else:
                    raise ValueError(f"Safety check failed: {check_message}")

            # Create call output
            call_output = {
                "type": "computer_call_output",
                "call_id": item.get("call_id"),
                "acknowledged_safety_checks": acknowledged_checks,
                "output": {
                    "type": "input_image",
                    "image_url": f"data:image/png;base64,{screenshot_base64}",
                },
            }

            # Additional URL safety checks for browser environments
            if await computer.get_environment() == "browser":
                current_url = await computer.get_current_url()
                call_output["output"]["current_url"] = current_url
                check_blocklisted_url(current_url)

            result = [call_output]
            await self._on_computer_call_end(item, result)
            return result

        if item_type == "function_call":
            await self._on_function_call_start(item)
            # Perform function call
            function = self._get_tool(item.get("name"))
            if not function:
                # NOTE: single quotes inside the f-string — nested double
                # quotes are a SyntaxError on Python < 3.12.
                raise ValueError(f"Function {item.get('name')} not found")

            args = json.loads(item.get("arguments"))

            # Execute function - use asyncio.to_thread for non-async functions
            if inspect.iscoroutinefunction(function):
                result = await function(**args)
            else:
                result = await asyncio.to_thread(function, **args)

            # Create function call output
            call_output = {
                "type": "function_call_output",
                "call_id": item.get("call_id"),
                "output": str(result),
            }

            result = [call_output]
            await self._on_function_call_end(item, result)
            return result

        return []

    # ============================================================================
    # MAIN AGENT LOOP
    # ============================================================================

    async def run(
        self,
        messages: Messages,
        stream: bool = False,
        **kwargs
    ) -> AsyncGenerator[Dict[str, Any], None]:
        """
        Run the agent with the given messages using Computer protocol handler pattern.

        Args:
            messages: List of message dictionaries
            stream: Whether to stream the response
            **kwargs: Additional arguments

        Returns:
            AsyncGenerator that yields response chunks
        """
        await self._initialize_computers()

        # Merge kwargs (per-call kwargs override constructor kwargs)
        merged_kwargs = {**self.kwargs, **kwargs}

        old_items = self._process_input(messages)
        new_items = []

        # Initialize run tracking
        run_kwargs = {
            "messages": messages,
            "stream": stream,
            "model": self.model,
            "agent_loop": self.agent_loop.__name__,
            **merged_kwargs
        }
        await self._on_run_start(run_kwargs, old_items)

        # Iterate until the last produced item is a plain assistant message
        # (i.e. no more tool calls to execute).
        while not new_items or new_items[-1].get("role") != "assistant":
            # Lifecycle hook: Check if we should continue based on callbacks (e.g., budget manager)
            should_continue = await self._on_run_continue(run_kwargs, old_items, new_items)
            if not should_continue:
                break

            # Lifecycle hook: Prepare messages for the LLM call
            # Use cases:
            # - PII anonymization
            # - Image retention policy
            combined_messages = old_items + new_items
            preprocessed_messages = await self._on_llm_start(combined_messages)

            loop_kwargs = {
                "messages": preprocessed_messages,
                "model": self.model,
                "tools": self.tool_schemas,
                "stream": False,
                "computer_handler": self.computer_handler,
                "max_retries": self.max_retries,
                "use_prompt_caching": self.use_prompt_caching,
                **merged_kwargs
            }

            # Run agent loop iteration
            result = await self.agent_loop(
                **loop_kwargs,
                _on_api_start=self._on_api_start,
                _on_api_end=self._on_api_end,
                _on_usage=self._on_usage,
                _on_screenshot=self._on_screenshot,
            )
            result = get_json(result)

            # Lifecycle hook: Postprocess messages after the LLM call
            # Use cases:
            # - PII deanonymization (if you want tool calls to see PII)
            result["output"] = await self._on_llm_end(result.get("output", []))
            await self._on_responses(loop_kwargs, result)

            # Yield agent response
            yield result

            # Add agent response to new_items
            new_items += result.get("output")

            # Handle computer actions
            for item in result.get("output"):
                partial_items = await self._handle_item(item, self.computer_handler)
                new_items += partial_items

                # Yield partial response (tool outputs carry no token usage)
                yield {
                    "output": partial_items,
                    "usage": Usage(
                        prompt_tokens=0,
                        completion_tokens=0,
                        total_tokens=0,
                    )
                }

        await self._on_run_end(loop_kwargs, old_items, new_items)
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Callback system for ComputerAgent preprocessing and postprocessing hooks.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from .base import AsyncCallbackHandler
|
|
6
|
+
from .image_retention import ImageRetentionCallback
|
|
7
|
+
from .logging import LoggingCallback
|
|
8
|
+
from .trajectory_saver import TrajectorySaverCallback
|
|
9
|
+
from .budget_manager import BudgetManagerCallback
|
|
10
|
+
|
|
11
|
+
__all__ = [
|
|
12
|
+
"AsyncCallbackHandler",
|
|
13
|
+
"ImageRetentionCallback",
|
|
14
|
+
"LoggingCallback",
|
|
15
|
+
"TrajectorySaverCallback",
|
|
16
|
+
"BudgetManagerCallback",
|
|
17
|
+
]
|