cua-agent 0.4.12__py3-none-any.whl → 0.4.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of cua-agent might be problematic.
- agent/adapters/__init__.py +2 -0
- agent/adapters/huggingfacelocal_adapter.py +15 -3
- agent/adapters/human_adapter.py +348 -0
- agent/agent.py +29 -21
- agent/callbacks/trajectory_saver.py +35 -26
- agent/cli.py +1 -1
- agent/computers/__init__.py +41 -0
- agent/computers/base.py +70 -0
- agent/{computer_handler.py → computers/cua.py} +26 -23
- agent/computers/custom.py +209 -0
- agent/human_tool/__init__.py +29 -0
- agent/human_tool/__main__.py +38 -0
- agent/human_tool/server.py +234 -0
- agent/human_tool/ui.py +630 -0
- agent/integrations/hud/__init__.py +77 -0
- agent/integrations/hud/adapter.py +121 -0
- agent/integrations/hud/agent.py +373 -0
- agent/integrations/hud/computer_handler.py +187 -0
- agent/loops/uitars.py +9 -1
- agent/types.py +1 -53
- agent/ui/gradio/app.py +1 -0
- agent/ui/gradio/ui_components.py +20 -9
- {cua_agent-0.4.12.dist-info → cua_agent-0.4.14.dist-info}/METADATA +9 -6
- cua_agent-0.4.14.dist-info/RECORD +50 -0
- cua_agent-0.4.12.dist-info/RECORD +0 -38
- {cua_agent-0.4.12.dist-info → cua_agent-0.4.14.dist-info}/WHEEL +0 -0
- {cua_agent-0.4.12.dist-info → cua_agent-0.4.14.dist-info}/entry_points.txt +0 -0
agent/adapters/__init__.py
CHANGED

agent/adapters/huggingfacelocal_adapter.py
CHANGED

@@ -1,5 +1,7 @@
 import asyncio
+import functools
 import warnings
+from concurrent.futures import ThreadPoolExecutor
 from typing import Iterator, AsyncIterator, Dict, List, Any, Optional
 from litellm.types.utils import GenericStreamingChunk, ModelResponse
 from litellm.llms.custom_llm import CustomLLM
@@ -28,6 +30,7 @@ class HuggingFaceLocalAdapter(CustomLLM):
         self.device = device
         self.models = {}  # Cache for loaded models
         self.processors = {}  # Cache for loaded processors
+        self._executor = ThreadPoolExecutor(max_workers=1)  # Single thread pool
 
     def _load_model_and_processor(self, model_name: str):
         """Load model and processor if not already cached.
@@ -51,7 +54,8 @@ class HuggingFaceLocalAdapter(CustomLLM):
         processor = AutoProcessor.from_pretrained(
             model_name,
             min_pixels=3136,
-            max_pixels=4096 * 2160
+            max_pixels=4096 * 2160,
+            device_map=self.device
         )
 
         # Cache them
@@ -185,7 +189,11 @@ class HuggingFaceLocalAdapter(CustomLLM):
             ModelResponse with generated text
         """
         # Run _generate in thread pool to avoid blocking
-
+        loop = asyncio.get_event_loop()
+        generated_text = await loop.run_in_executor(
+            self._executor,
+            functools.partial(self._generate, **kwargs)
+        )
 
         return await acompletion(
             model=f"huggingface-local/{kwargs['model']}",
@@ -218,7 +226,11 @@ class HuggingFaceLocalAdapter(CustomLLM):
             AsyncIterator of GenericStreamingChunk
         """
         # Run _generate in thread pool to avoid blocking
-
+        loop = asyncio.get_event_loop()
+        generated_text = await loop.run_in_executor(
+            self._executor,
+            functools.partial(self._generate, **kwargs)
+        )
 
         generic_streaming_chunk: GenericStreamingChunk = {
             "finish_reason": "stop",
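The last two hunks above move the _generate call onto the adapter's single-worker thread pool via loop.run_in_executor. A minimal standalone sketch of that pattern, not taken from the package (blocking_generate is a hypothetical stand-in for the synchronous HuggingFace generation call):

import asyncio
import functools
from concurrent.futures import ThreadPoolExecutor

_executor = ThreadPoolExecutor(max_workers=1)  # one worker serializes GPU-bound work

def blocking_generate(prompt: str, max_tokens: int = 16) -> str:
    # Stand-in for a slow, synchronous model.generate(...) call.
    return prompt[:max_tokens].upper()

async def generate_async(**kwargs) -> str:
    loop = asyncio.get_event_loop()
    # run_in_executor takes a plain callable, so functools.partial binds the
    # keyword arguments first, mirroring the adapter code above.
    return await loop.run_in_executor(_executor, functools.partial(blocking_generate, **kwargs))

print(asyncio.run(generate_async(prompt="hello world", max_tokens=5)))

Because max_workers=1, concurrent requests are queued rather than run in parallel, which keeps a single local model from being driven by multiple threads at once.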
agent/adapters/human_adapter.py
ADDED (+348 lines)

import os
import asyncio
import requests
from typing import List, Dict, Any, Iterator, AsyncIterator
from litellm.types.utils import GenericStreamingChunk, ModelResponse
from litellm.llms.custom_llm import CustomLLM
from litellm import completion, acompletion


class HumanAdapter(CustomLLM):
    """Human Adapter for human-in-the-loop completions.

    This adapter sends completion requests to a human completion server
    where humans can review and respond to AI requests.
    """

    def __init__(self, base_url: str | None = None, timeout: float = 300.0, **kwargs):
        """Initialize the human adapter.

        Args:
            base_url: Base URL for the human completion server.
                Defaults to HUMAN_BASE_URL environment variable or http://localhost:8002
            timeout: Timeout in seconds for waiting for human response
            **kwargs: Additional arguments
        """
        super().__init__()
        self.base_url = base_url or os.getenv('HUMAN_BASE_URL', 'http://localhost:8002')
        self.timeout = timeout

        # Ensure base_url doesn't end with slash
        self.base_url = self.base_url.rstrip('/')

    def _queue_completion(self, messages: List[Dict[str, Any]], model: str) -> str:
        """Queue a completion request and return the call ID.

        Args:
            messages: Messages in OpenAI format
            model: Model name

        Returns:
            Call ID for tracking the request

        Raises:
            Exception: If queueing fails
        """
        try:
            response = requests.post(
                f"{self.base_url}/queue",
                json={"messages": messages, "model": model},
                timeout=10
            )
            response.raise_for_status()
            return response.json()["id"]
        except requests.RequestException as e:
            raise Exception(f"Failed to queue completion request: {e}")

    def _wait_for_completion(self, call_id: str) -> Dict[str, Any]:
        """Wait for human to complete the call.

        Args:
            call_id: ID of the queued completion call

        Returns:
            Dict containing response and/or tool_calls

        Raises:
            TimeoutError: If timeout is exceeded
            Exception: If completion fails
        """
        import time

        start_time = time.time()

        while True:
            try:
                # Check status
                status_response = requests.get(f"{self.base_url}/status/{call_id}")
                status_response.raise_for_status()
                status_data = status_response.json()

                if status_data["status"] == "completed":
                    result = {}
                    if "response" in status_data and status_data["response"]:
                        result["response"] = status_data["response"]
                    if "tool_calls" in status_data and status_data["tool_calls"]:
                        result["tool_calls"] = status_data["tool_calls"]
                    return result
                elif status_data["status"] == "failed":
                    error_msg = status_data.get("error", "Unknown error")
                    raise Exception(f"Completion failed: {error_msg}")

                # Check timeout
                if time.time() - start_time > self.timeout:
                    raise TimeoutError(f"Timeout waiting for human response after {self.timeout} seconds")

                # Wait before checking again
                time.sleep(1.0)

            except requests.RequestException as e:
                if time.time() - start_time > self.timeout:
                    raise TimeoutError(f"Timeout waiting for human response: {e}")
                # Continue trying if we haven't timed out
                time.sleep(1.0)

    async def _async_wait_for_completion(self, call_id: str) -> Dict[str, Any]:
        """Async version of wait_for_completion.

        Args:
            call_id: ID of the queued completion call

        Returns:
            Dict containing response and/or tool_calls

        Raises:
            TimeoutError: If timeout is exceeded
            Exception: If completion fails
        """
        import aiohttp
        import time

        start_time = time.time()

        async with aiohttp.ClientSession() as session:
            while True:
                try:
                    # Check status
                    async with session.get(f"{self.base_url}/status/{call_id}") as response:
                        response.raise_for_status()
                        status_data = await response.json()

                    if status_data["status"] == "completed":
                        result = {}
                        if "response" in status_data and status_data["response"]:
                            result["response"] = status_data["response"]
                        if "tool_calls" in status_data and status_data["tool_calls"]:
                            result["tool_calls"] = status_data["tool_calls"]
                        return result
                    elif status_data["status"] == "failed":
                        error_msg = status_data.get("error", "Unknown error")
                        raise Exception(f"Completion failed: {error_msg}")

                    # Check timeout
                    if time.time() - start_time > self.timeout:
                        raise TimeoutError(f"Timeout waiting for human response after {self.timeout} seconds")

                    # Wait before checking again
                    await asyncio.sleep(1.0)

                except Exception as e:
                    if time.time() - start_time > self.timeout:
                        raise TimeoutError(f"Timeout waiting for human response: {e}")
                    # Continue trying if we haven't timed out
                    await asyncio.sleep(1.0)

    def _generate_response(self, messages: List[Dict[str, Any]], model: str) -> Dict[str, Any]:
        """Generate a human response for the given messages.

        Args:
            messages: Messages in OpenAI format
            model: Model name

        Returns:
            Dict containing response and/or tool_calls
        """
        # Queue the completion request
        call_id = self._queue_completion(messages, model)

        # Wait for human response
        response = self._wait_for_completion(call_id)

        return response

    async def _async_generate_response(self, messages: List[Dict[str, Any]], model: str) -> Dict[str, Any]:
        """Async version of _generate_response.

        Args:
            messages: Messages in OpenAI format
            model: Model name

        Returns:
            Dict containing response and/or tool_calls
        """
        # Queue the completion request (sync operation)
        call_id = self._queue_completion(messages, model)

        # Wait for human response (async)
        response = await self._async_wait_for_completion(call_id)

        return response

    def completion(self, *args, **kwargs) -> ModelResponse:
        """Synchronous completion method.

        Returns:
            ModelResponse with human-generated text or tool calls
        """
        messages = kwargs.get('messages', [])
        model = kwargs.get('model', 'human')

        # Generate human response
        human_response_data = self._generate_response(messages, model)

        # Create ModelResponse with proper structure
        from litellm.types.utils import ModelResponse, Choices, Message
        import uuid
        import time

        # Create message content based on response type
        if "tool_calls" in human_response_data and human_response_data["tool_calls"]:
            # Tool calls response
            message = Message(
                role="assistant",
                content=human_response_data.get("response", ""),
                tool_calls=human_response_data["tool_calls"]
            )
        else:
            # Text response
            message = Message(
                role="assistant",
                content=human_response_data.get("response", "")
            )

        choice = Choices(
            finish_reason="stop",
            index=0,
            message=message
        )

        result = ModelResponse(
            id=f"human-{uuid.uuid4()}",
            choices=[choice],
            created=int(time.time()),
            model=f"human/{model}",
            object="chat.completion"
        )

        return result

    async def acompletion(self, *args, **kwargs) -> ModelResponse:
        """Asynchronous completion method.

        Returns:
            ModelResponse with human-generated text or tool calls
        """
        messages = kwargs.get('messages', [])
        model = kwargs.get('model', 'human')

        # Generate human response
        human_response_data = await self._async_generate_response(messages, model)

        # Create ModelResponse with proper structure
        from litellm.types.utils import ModelResponse, Choices, Message
        import uuid
        import time

        # Create message content based on response type
        if "tool_calls" in human_response_data and human_response_data["tool_calls"]:
            # Tool calls response
            message = Message(
                role="assistant",
                content=human_response_data.get("response", ""),
                tool_calls=human_response_data["tool_calls"]
            )
        else:
            # Text response
            message = Message(
                role="assistant",
                content=human_response_data.get("response", "")
            )

        choice = Choices(
            finish_reason="stop",
            index=0,
            message=message
        )

        result = ModelResponse(
            id=f"human-{uuid.uuid4()}",
            choices=[choice],
            created=int(time.time()),
            model=f"human/{model}",
            object="chat.completion"
        )

        return result

    def streaming(self, *args, **kwargs) -> Iterator[GenericStreamingChunk]:
        """Synchronous streaming method.

        Yields:
            Streaming chunks with human-generated text or tool calls
        """
        messages = kwargs.get('messages', [])
        model = kwargs.get('model', 'human')

        # Generate human response
        human_response_data = self._generate_response(messages, model)

        import time

        # Handle tool calls vs text response
        if "tool_calls" in human_response_data and human_response_data["tool_calls"]:
            # Stream tool calls as a single chunk
            generic_chunk: GenericStreamingChunk = {
                "finish_reason": "tool_calls",
                "index": 0,
                "is_finished": True,
                "text": human_response_data.get("response", ""),
                "tool_use": human_response_data["tool_calls"],
                "usage": {"completion_tokens": 1, "prompt_tokens": 0, "total_tokens": 1},
            }
            yield generic_chunk
        else:
            # Stream text response
            response_text = human_response_data.get("response", "")
            generic_chunk: GenericStreamingChunk = {
                "finish_reason": "stop",
                "index": 0,
                "is_finished": True,
                "text": response_text,
                "tool_use": None,
                "usage": {"completion_tokens": len(response_text.split()), "prompt_tokens": 0, "total_tokens": len(response_text.split())},
            }
            yield generic_chunk

    async def astreaming(self, *args, **kwargs) -> AsyncIterator[GenericStreamingChunk]:
        """Asynchronous streaming method.

        Yields:
            Streaming chunks with human-generated text or tool calls
        """
        messages = kwargs.get('messages', [])
        model = kwargs.get('model', 'human')

        # Generate human response
        human_response = await self._async_generate_response(messages, model)

        # Return as single streaming chunk
        generic_streaming_chunk: GenericStreamingChunk = {
            "finish_reason": "stop",
            "index": 0,
            "is_finished": True,
            "text": human_response,
            "tool_use": None,
            "usage": {"completion_tokens": len(human_response.split()), "prompt_tokens": 0, "total_tokens": len(human_response.split())},
        }

        yield generic_streaming_chunk
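HumanAdapter subclasses litellm's CustomLLM, so it can be registered as a custom provider the same way agent.py does below. A hedged usage sketch (the human_tool server from this release must be running at the base URL; the model name "human/demo" and the message text are illustrative):

import litellm
from agent.adapters import HumanAdapter

human_adapter = HumanAdapter(base_url="http://localhost:8002", timeout=300.0)
litellm.custom_provider_map = [
    {"provider": "human", "custom_handler": human_adapter},
]

# Blocks until a person answers the queued request via the /queue and /status endpoints.
response = litellm.completion(
    model="human/demo",
    messages=[{"role": "user", "content": "Please approve or edit this draft reply."}],
)
print(response.choices[0].message.content)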
agent/agent.py
CHANGED
@@ -7,14 +7,16 @@ from typing import Dict, List, Any, Optional, AsyncGenerator, Union, cast, Calla
 
 from litellm.responses.utils import Usage
 
-from .types import Messages,
+from .types import Messages, AgentCapability
 from .decorators import find_agent_config
-from .computer_handler import OpenAIComputerHandler, acknowledge_safety_check_callback, check_blocklisted_url
 import json
 import litellm
 import litellm.utils
 import inspect
-from .adapters import
+from .adapters import (
+    HuggingFaceLocalAdapter,
+    HumanAdapter,
+)
 from .callbacks import (
     ImageRetentionCallback,
     LoggingCallback,
@@ -22,9 +24,14 @@ from .callbacks import (
     BudgetManagerCallback,
     TelemetryCallback,
 )
+from .computers import (
+    AsyncComputerHandler,
+    is_agent_computer,
+    make_computer_handler
+)
 
 def get_json(obj: Any, max_depth: int = 10) -> Any:
-    def custom_serializer(o: Any, depth: int = 0, seen: Set[int] = None) -> Any:
+    def custom_serializer(o: Any, depth: int = 0, seen: Optional[Set[int]] = None) -> Any:
         if seen is None:
             seen = set()
 
@@ -211,8 +218,10 @@ class ComputerAgent:
         hf_adapter = HuggingFaceLocalAdapter(
             device="auto"
         )
+        human_adapter = HumanAdapter()
         litellm.custom_provider_map = [
-            {"provider": "huggingface-local", "custom_handler": hf_adapter}
+            {"provider": "huggingface-local", "custom_handler": hf_adapter},
+            {"provider": "human", "custom_handler": human_adapter}
         ]
         litellm.suppress_debug_info = True
 
@@ -236,10 +245,6 @@ class ComputerAgent:
     async def _initialize_computers(self):
         """Initialize computer objects"""
         if not self.tool_schemas:
-            for tool in self.tools:
-                if hasattr(tool, '_initialized') and not tool._initialized:
-                    await tool.run()
-
             # Process tools and create tool schemas
             self.tool_schemas = self._process_tools()
 
@@ -247,7 +252,7 @@ class ComputerAgent:
             computer_handler = None
             for schema in self.tool_schemas:
                 if schema["type"] == "computer":
-                    computer_handler =
+                    computer_handler = await make_computer_handler(schema["computer"])
                     break
             self.computer_handler = computer_handler
 
@@ -263,7 +268,7 @@ class ComputerAgent:
 
         for tool in self.tools:
            # Check if it's a computer object (has interface attribute)
-            if
+            if is_agent_computer(tool):
                # This is a computer tool - will be handled by agent loop
                schemas.append({
                    "type": "computer",
@@ -398,7 +403,7 @@ class ComputerAgent:
    # AGENT OUTPUT PROCESSING
    # ============================================================================
 
-    async def _handle_item(self, item: Any, computer: Optional[
+    async def _handle_item(self, item: Any, computer: Optional[AsyncComputerHandler] = None, ignore_call_ids: Optional[List[str]] = None) -> List[Dict[str, Any]]:
        """Handle each item; may cause a computer action + screenshot."""
        if ignore_call_ids and item.get("call_id") and item.get("call_id") in ignore_call_ids:
            return []
@@ -450,10 +455,12 @@ class ComputerAgent:
        acknowledged_checks = []
        for check in pending_checks:
            check_message = check.get("message", str(check))
-
-
-
-
+            acknowledged_checks.append(check)
+            # TODO: implement a callback for safety checks
+            # if acknowledge_safety_check_callback(check_message, allow_always=True):
+            #     acknowledged_checks.append(check)
+            # else:
+            #     raise ValueError(f"Safety check failed: {check_message}")
 
        # Create call output
        call_output = {
@@ -466,11 +473,12 @@ class ComputerAgent:
            },
        }
 
-        # Additional URL safety checks for browser environments
-        if await computer.get_environment() == "browser":
-
-
-
+        # # Additional URL safety checks for browser environments
+        # if await computer.get_environment() == "browser":
+        #     current_url = await computer.get_current_url()
+        #     call_output["output"]["current_url"] = current_url
+        #     # TODO: implement a callback for URL safety checks
+        #     # check_blocklisted_url(current_url)
 
        result = [call_output]
        await self._on_computer_call_end(item, result)
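The custom_serializer signature change above (seen: Optional[Set[int]] = None) is the usual fix for a None default: annotate the parameter as Optional and create the set inside the function. A minimal illustration, independent of this package:

from typing import Optional, Set

def visit(obj: object, seen: Optional[Set[int]] = None) -> Set[int]:
    # Create the set lazily so each top-level call gets its own 'seen' state,
    # avoiding both the Optional/None typing error and a shared mutable default.
    if seen is None:
        seen = set()
    seen.add(id(obj))
    return seen

print(len(visit("example")))  # 1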
agent/callbacks/trajectory_saver.py
CHANGED

@@ -51,12 +51,14 @@ class TrajectorySaverCallback(AsyncCallbackHandler):
    within the trajectory gets its own folder with screenshots and responses.
    """
 
-    def __init__(self, trajectory_dir: str):
+    def __init__(self, trajectory_dir: str, reset_on_run: bool = True):
        """
        Initialize trajectory saver.
 
        Args:
            trajectory_dir: Base directory to save trajectories
+            reset_on_run: If True, reset trajectory_id/turn/artifact on each run.
+                If False, continue using existing trajectory_id if set.
        """
        self.trajectory_dir = Path(trajectory_dir)
        self.trajectory_id: Optional[str] = None
@@ -64,6 +66,7 @@ class TrajectorySaverCallback(AsyncCallbackHandler):
        self.current_artifact: int = 0
        self.model: Optional[str] = None
        self.total_usage: Dict[str, Any] = {}
+        self.reset_on_run = reset_on_run
 
        # Ensure trajectory directory exists
        self.trajectory_dir.mkdir(parents=True, exist_ok=True)
@@ -113,32 +116,38 @@ class TrajectorySaverCallback(AsyncCallbackHandler):
    async def on_run_start(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]]) -> None:
        """Initialize trajectory tracking for a new run."""
        model = kwargs.get("model", "unknown")
-        model_name_short = model.split("+")[-1].split("/")[-1].lower()[:16]
-        if "+" in model:
-            model_name_short = model.split("+")[0].lower()[:4] + "_" + model_name_short
-
-        # id format: yyyy-mm-dd_model_hhmmss_uuid[:4]
-        now = datetime.now()
-        self.trajectory_id = f"{now.strftime('%Y-%m-%d')}_{model_name_short}_{now.strftime('%H%M%S')}_{str(uuid.uuid4())[:4]}"
-        self.current_turn = 0
-        self.current_artifact = 0
-        self.model = model
-        self.total_usage = {}
-
-        # Create trajectory directory
-        trajectory_path = self.trajectory_dir / self.trajectory_id
-        trajectory_path.mkdir(parents=True, exist_ok=True)
-
-        # Save trajectory metadata
-        metadata = {
-            "trajectory_id": self.trajectory_id,
-            "created_at": str(uuid.uuid1().time),
-            "status": "running",
-            "kwargs": kwargs,
-        }
 
-
-
+        # Only reset trajectory state if reset_on_run is True or no trajectory exists
+        if self.reset_on_run or not self.trajectory_id:
+            model_name_short = model.split("+")[-1].split("/")[-1].lower()[:16]
+            if "+" in model:
+                model_name_short = model.split("+")[0].lower()[:4] + "_" + model_name_short
+
+            # id format: yyyy-mm-dd_model_hhmmss_uuid[:4]
+            now = datetime.now()
+            self.trajectory_id = f"{now.strftime('%Y-%m-%d')}_{model_name_short}_{now.strftime('%H%M%S')}_{str(uuid.uuid4())[:4]}"
+            self.current_turn = 0
+            self.current_artifact = 0
+            self.model = model
+            self.total_usage = {}
+
+            # Create trajectory directory
+            trajectory_path = self.trajectory_dir / self.trajectory_id
+            trajectory_path.mkdir(parents=True, exist_ok=True)
+
+            # Save trajectory metadata
+            metadata = {
+                "trajectory_id": self.trajectory_id,
+                "created_at": str(uuid.uuid1().time),
+                "status": "running",
+                "kwargs": kwargs,
+            }
+
+            with open(trajectory_path / "metadata.json", "w") as f:
+                json.dump(metadata, f, indent=2)
+        else:
+            # Continue with existing trajectory - just update model if needed
+            self.model = model
 
    @override
    async def on_run_end(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]], new_items: List[Dict[str, Any]]) -> None:
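With the new reset_on_run flag, one trajectory folder can span several agent runs instead of a fresh folder per run. A hedged construction sketch (how the callback is attached to ComputerAgent is not shown in this diff, so only instantiation is illustrated; the directory names are arbitrary):

from agent.callbacks.trajectory_saver import TrajectorySaverCallback

# Default behaviour: each run() starts a new trajectory id and folder.
per_run_saver = TrajectorySaverCallback("trajectories")

# New behaviour: keep appending turns to the same trajectory across runs,
# resetting only if no trajectory_id has been created yet.
session_saver = TrajectorySaverCallback("trajectories/session-01", reset_on_run=False)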
agent/cli.py
CHANGED
@@ -94,7 +94,7 @@ def print_action(action_type: str, details: Dict[str, Any], total_cost: float):
    # Format action details
    args_str = ""
    if action_type == "click" and "x" in details and "y" in details:
-        args_str = f"_{details
+        args_str = f"_{details.get('button', 'left')}({details['x']}, {details['y']})"
    elif action_type == "type" and "text" in details:
        text = details["text"]
        if len(text) > 50:
agent/computers/__init__.py
ADDED (+41 lines)

"""
Computer handler factory and interface definitions.

This module provides a factory function to create computer handlers from different
computer interface types, supporting both the ComputerHandler protocol and the
Computer library interface.
"""

from .base import AsyncComputerHandler
from .cua import cuaComputerHandler
from .custom import CustomComputerHandler
from computer import Computer as cuaComputer

def is_agent_computer(computer):
    """Check if the given computer is a ComputerHandler or CUA Computer."""
    return isinstance(computer, AsyncComputerHandler) or \
        isinstance(computer, cuaComputer) or \
        (isinstance(computer, dict)) #and "screenshot" in computer)

async def make_computer_handler(computer):
    """
    Create a computer handler from a computer interface.

    Args:
        computer: Either a ComputerHandler instance, Computer instance, or dict of functions

    Returns:
        ComputerHandler: A computer handler instance

    Raises:
        ValueError: If the computer type is not supported
    """
    if isinstance(computer, AsyncComputerHandler):
        return computer
    if isinstance(computer, cuaComputer):
        computer_handler = cuaComputerHandler(computer)
        await computer_handler._initialize()
        return computer_handler
    if isinstance(computer, dict):
        return CustomComputerHandler(computer)
    raise ValueError(f"Unsupported computer type: {type(computer)}")
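The dict branch above routes any plain mapping to CustomComputerHandler, which suggests a computer can be defined from loose async functions. A hedged sketch of that path (the dict keys are assumptions; this diff does not show which names CustomComputerHandler actually looks up, though the commented-out check in is_agent_computer hints at "screenshot"):

import asyncio
from agent.computers import is_agent_computer, make_computer_handler

async def take_screenshot() -> bytes:
    return b""  # placeholder for real screenshot bytes

async def click(x: int, y: int) -> None:
    print(f"click({x}, {y})")

custom_computer = {"screenshot": take_screenshot, "click": click}
print(is_agent_computer(custom_computer))  # True: any dict currently passes the check

async def main() -> None:
    # Assumed to wrap the dict in CustomComputerHandler, per the factory above.
    handler = await make_computer_handler(custom_computer)
    print(type(handler).__name__)

asyncio.run(main())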