cua-agent 0.3.1-py3-none-any.whl → 0.4.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of cua-agent has been flagged as potentially problematic; see the registry's advisory page for details.
- agent/__init__.py +21 -12
- agent/__main__.py +21 -0
- agent/adapters/__init__.py +9 -0
- agent/adapters/huggingfacelocal_adapter.py +229 -0
- agent/agent.py +594 -0
- agent/callbacks/__init__.py +19 -0
- agent/callbacks/base.py +153 -0
- agent/callbacks/budget_manager.py +44 -0
- agent/callbacks/image_retention.py +139 -0
- agent/callbacks/logging.py +247 -0
- agent/callbacks/pii_anonymization.py +259 -0
- agent/callbacks/telemetry.py +210 -0
- agent/callbacks/trajectory_saver.py +305 -0
- agent/cli.py +297 -0
- agent/computer_handler.py +107 -0
- agent/decorators.py +90 -0
- agent/loops/__init__.py +11 -0
- agent/loops/anthropic.py +728 -0
- agent/loops/omniparser.py +339 -0
- agent/loops/openai.py +95 -0
- agent/loops/uitars.py +688 -0
- agent/responses.py +207 -0
- agent/telemetry.py +135 -14
- agent/types.py +79 -0
- agent/ui/__init__.py +7 -1
- agent/ui/__main__.py +2 -13
- agent/ui/gradio/__init__.py +6 -19
- agent/ui/gradio/app.py +94 -1313
- agent/ui/gradio/ui_components.py +721 -0
- cua_agent-0.4.0.dist-info/METADATA +424 -0
- cua_agent-0.4.0.dist-info/RECORD +33 -0
- {cua_agent-0.3.1.dist-info → cua_agent-0.4.0.dist-info}/WHEEL +1 -1
- agent/core/__init__.py +0 -27
- agent/core/agent.py +0 -210
- agent/core/base.py +0 -217
- agent/core/callbacks.py +0 -200
- agent/core/experiment.py +0 -249
- agent/core/factory.py +0 -122
- agent/core/messages.py +0 -332
- agent/core/provider_config.py +0 -21
- agent/core/telemetry.py +0 -142
- agent/core/tools/__init__.py +0 -21
- agent/core/tools/base.py +0 -74
- agent/core/tools/bash.py +0 -52
- agent/core/tools/collection.py +0 -46
- agent/core/tools/computer.py +0 -113
- agent/core/tools/edit.py +0 -67
- agent/core/tools/manager.py +0 -56
- agent/core/tools.py +0 -32
- agent/core/types.py +0 -88
- agent/core/visualization.py +0 -197
- agent/providers/__init__.py +0 -4
- agent/providers/anthropic/__init__.py +0 -6
- agent/providers/anthropic/api/client.py +0 -360
- agent/providers/anthropic/api/logging.py +0 -150
- agent/providers/anthropic/api_handler.py +0 -140
- agent/providers/anthropic/callbacks/__init__.py +0 -5
- agent/providers/anthropic/callbacks/manager.py +0 -65
- agent/providers/anthropic/loop.py +0 -568
- agent/providers/anthropic/prompts.py +0 -23
- agent/providers/anthropic/response_handler.py +0 -226
- agent/providers/anthropic/tools/__init__.py +0 -33
- agent/providers/anthropic/tools/base.py +0 -88
- agent/providers/anthropic/tools/bash.py +0 -66
- agent/providers/anthropic/tools/collection.py +0 -34
- agent/providers/anthropic/tools/computer.py +0 -396
- agent/providers/anthropic/tools/edit.py +0 -326
- agent/providers/anthropic/tools/manager.py +0 -54
- agent/providers/anthropic/tools/run.py +0 -42
- agent/providers/anthropic/types.py +0 -16
- agent/providers/anthropic/utils.py +0 -367
- agent/providers/omni/__init__.py +0 -8
- agent/providers/omni/api_handler.py +0 -42
- agent/providers/omni/clients/anthropic.py +0 -103
- agent/providers/omni/clients/base.py +0 -35
- agent/providers/omni/clients/oaicompat.py +0 -195
- agent/providers/omni/clients/ollama.py +0 -122
- agent/providers/omni/clients/openai.py +0 -155
- agent/providers/omni/clients/utils.py +0 -25
- agent/providers/omni/image_utils.py +0 -34
- agent/providers/omni/loop.py +0 -990
- agent/providers/omni/parser.py +0 -307
- agent/providers/omni/prompts.py +0 -64
- agent/providers/omni/tools/__init__.py +0 -30
- agent/providers/omni/tools/base.py +0 -29
- agent/providers/omni/tools/bash.py +0 -74
- agent/providers/omni/tools/computer.py +0 -179
- agent/providers/omni/tools/manager.py +0 -61
- agent/providers/omni/utils.py +0 -236
- agent/providers/openai/__init__.py +0 -6
- agent/providers/openai/api_handler.py +0 -456
- agent/providers/openai/loop.py +0 -472
- agent/providers/openai/response_handler.py +0 -205
- agent/providers/openai/tools/__init__.py +0 -15
- agent/providers/openai/tools/base.py +0 -79
- agent/providers/openai/tools/computer.py +0 -326
- agent/providers/openai/tools/manager.py +0 -106
- agent/providers/openai/types.py +0 -36
- agent/providers/openai/utils.py +0 -98
- agent/providers/uitars/__init__.py +0 -1
- agent/providers/uitars/clients/base.py +0 -35
- agent/providers/uitars/clients/mlxvlm.py +0 -263
- agent/providers/uitars/clients/oaicompat.py +0 -214
- agent/providers/uitars/loop.py +0 -660
- agent/providers/uitars/prompts.py +0 -63
- agent/providers/uitars/tools/__init__.py +0 -1
- agent/providers/uitars/tools/computer.py +0 -283
- agent/providers/uitars/tools/manager.py +0 -60
- agent/providers/uitars/utils.py +0 -264
- cua_agent-0.3.1.dist-info/METADATA +0 -295
- cua_agent-0.3.1.dist-info/RECORD +0 -87
- {cua_agent-0.3.1.dist-info → cua_agent-0.4.0.dist-info}/entry_points.txt +0 -0
|
agent/callbacks/trajectory_saver.py
ADDED
@@ -0,0 +1,305 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Trajectory saving callback handler for ComputerAgent.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import os
|
|
6
|
+
import json
|
|
7
|
+
import uuid
|
|
8
|
+
from datetime import datetime
|
|
9
|
+
import base64
|
|
10
|
+
from pathlib import Path
|
|
11
|
+
from typing import List, Dict, Any, Optional, Union, override
|
|
12
|
+
from PIL import Image, ImageDraw
|
|
13
|
+
import io
|
|
14
|
+
from .base import AsyncCallbackHandler
|
|
15
|
+
|
|
16
|
+
def sanitize_image_urls(data: Any) -> Any:
    """
    Return a copy of ``data`` in which every ``image_url`` value is masked.

    Dictionaries and lists are walked recursively and rebuilt as plain
    ``dict``/``list`` objects. The value of any dictionary entry whose key is
    exactly ``"image_url"`` is replaced by the string ``"[omitted]"`` without
    being inspected. Anything that is neither a dict nor a list is returned
    unchanged (the same object, not a copy).

    Args:
        data: Arbitrary nested structure of dicts, lists and leaf values.

    Returns:
        The sanitized structure.
    """
    if isinstance(data, list):
        return list(map(sanitize_image_urls, data))

    if not isinstance(data, dict):
        # Leaf value (str, int, bool, None, ...): nothing to mask.
        return data

    return {
        key: "[omitted]" if key == "image_url" else sanitize_image_urls(value)
        for key, value in data.items()
    }
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class TrajectorySaverCallback(AsyncCallbackHandler):
    """
    Callback handler that saves agent trajectories to disk.

    Each run (``on_run_start`` .. ``on_run_end``) is saved as a separate
    trajectory directory with a unique ID. That directory holds a
    ``metadata.json`` (run kwargs, status, accumulated usage, new items) and
    one ``turn_NNN`` folder per turn containing the turn's screenshots
    (``.png``) and JSON dumps of API calls, responses and computer-call
    results. ``image_url`` values are replaced by ``"[omitted]"`` in the
    saved JSON to keep the files small.

    Layout::

        <trajectory_dir>/<trajectory_id>/metadata.json
        <trajectory_dir>/<trajectory_id>/turn_000/0000_api_start.json
        <trajectory_dir>/<trajectory_id>/turn_000/0001_api_result.json
        ...
    """

    def __init__(self, trajectory_dir: str):
        """
        Initialize trajectory saver.

        Args:
            trajectory_dir: Base directory to save trajectories. It is created
                (including parents) if it does not exist.
        """
        self.trajectory_dir = Path(trajectory_dir)
        self.trajectory_id: Optional[str] = None
        self.current_turn: int = 0
        # The artifact counter is not reset per turn, so the numeric filename
        # prefix gives a trajectory-wide ordering of artifacts.
        self.current_artifact: int = 0
        self.model: Optional[str] = None
        self.total_usage: Dict[str, Any] = {}

        # Ensure trajectory directory exists
        self.trajectory_dir.mkdir(parents=True, exist_ok=True)

    # ------------------------------------------------------------------ helpers

    def _get_turn_dir(self) -> Path:
        """
        Return the directory for the current turn, creating it if needed.

        Raises:
            ValueError: If no run is in progress (``trajectory_id`` unset).
        """
        if not self.trajectory_id:
            raise ValueError("Trajectory not initialized - call _on_run_start first")

        # format: trajectory_id/turn_000
        turn_dir = self.trajectory_dir / self.trajectory_id / f"turn_{self.current_turn:03d}"
        turn_dir.mkdir(parents=True, exist_ok=True)
        return turn_dir

    @staticmethod
    def _safe_path_component(text: str) -> str:
        """Replace every character except letters, digits, ``-``, ``_`` and ``.`` with ``_``."""
        return "".join(ch if ch.isalnum() or ch in "-_." else "_" for ch in text)

    @staticmethod
    def _decode_base64_image(data: str) -> bytes:
        """Decode a base64 image given either as bare base64 or as a ``data:...;base64,`` URL."""
        if data.startswith("data:"):
            # Format: data:image/png;base64,<base64_data>
            data = data.split(",", 1)[-1]
        return base64.b64decode(data)

    @staticmethod
    def _json_default(obj: Any) -> Any:
        """
        Fallback serializer for objects ``json`` cannot encode natively.

        Objects exposing a callable ``model_dump()`` (for example pydantic
        models) are converted to plain data and sanitized. Anything else is
        stored as ``str(obj)``. Without this fallback, ``json.dump`` raises
        ``TypeError`` part-way through writing, which leaves a truncated file
        and aborts the agent run over a debug artifact.
        """
        model_dump = getattr(obj, "model_dump", None)
        if callable(model_dump):
            try:
                return sanitize_image_urls(model_dump())
            except Exception:
                # Fall through to the plain string form below.
                pass
        return str(obj)

    def _write_json(self, path: Path, data: Any) -> None:
        """Write ``data`` to ``path`` as indented JSON, using ``_json_default`` for unknown types."""
        with open(path, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2, default=self._json_default)

    def _save_artifact(self, name: str, artifact: Union[str, bytes, Dict[str, Any]]) -> None:
        """
        Save an artifact to the current turn directory.

        ``bytes`` artifacts are written verbatim as ``NNNN_<name>.png``. The
        extension is fixed, so the bytes are presumably PNG-encoded images;
        confirm with callers. Anything else is passed through
        ``sanitize_image_urls`` and written as ``NNNN_<name>.json``. ``NNNN``
        is the trajectory-wide artifact counter, incremented after each save.
        """
        turn_dir = self._get_turn_dir()
        artifact_filename = f"{self.current_artifact:04d}_{name}"
        if isinstance(artifact, bytes):
            # format: turn_000/0000_name.png
            (turn_dir / f"{artifact_filename}.png").write_bytes(artifact)
        else:
            # format: turn_000/0000_name.json
            self._write_json(turn_dir / f"{artifact_filename}.json", sanitize_image_urls(artifact))
        self.current_artifact += 1

    def _update_usage(self, usage: Dict[str, Any]) -> None:
        """
        Accumulate ``usage`` into ``self.total_usage``.

        Numeric values are summed key by key and nested dicts are merged
        recursively. ``None`` and other non-numeric leaves cannot be summed
        and are skipped, instead of raising ``TypeError`` on ``+=``. A
        ``None``/empty ``usage`` is ignored.
        """
        def add_dicts(target: Dict[str, Any], source: Dict[str, Any]) -> None:
            for key, value in source.items():
                if isinstance(value, dict):
                    child = target.get(key)
                    if not isinstance(child, dict):
                        child = target[key] = {}
                    add_dicts(child, value)
                elif isinstance(value, (int, float)):
                    target[key] = target.get(key, 0) + value
                # Anything else (None, strings, lists, ...) is not summable: skip.

        if usage:
            add_dicts(self.total_usage, usage)

    @staticmethod
    def _find_screenshot_url(result: List[Dict[str, Any]]) -> Optional[str]:
        """Return the ``image_url`` of the first ``computer_call_output`` screenshot in ``result``, if any."""
        for result_item in result or []:
            if not isinstance(result_item, dict) or result_item.get("type") != "computer_call_output":
                continue
            output = result_item.get("output")
            if isinstance(output, dict) and output.get("type") == "input_image":
                image_url = output.get("image_url")
                if isinstance(image_url, str):
                    return image_url
        return None

    def _draw_crosshair_on_image(self, image_bytes: bytes, x: int, y: int) -> bytes:
        """
        Draw a red dot and crosshair at the specified coordinates on the image.

        Images that are not RGB/RGBA (e.g. grayscale or bilevel) are converted
        first so the annotation actually renders in red.

        Args:
            image_bytes: The original image as bytes (any format PIL can open)
            x: X coordinate for the crosshair
            y: Y coordinate for the crosshair

        Returns:
            Modified image as PNG bytes with red dot and crosshair
        """
        image = Image.open(io.BytesIO(image_bytes))
        if image.mode not in ("RGB", "RGBA"):
            image = image.convert("RGBA" if "A" in image.mode else "RGB")
        draw = ImageDraw.Draw(image)

        # Crosshair: two red 2px lines extending 20px from the centre.
        crosshair_size = 20
        line_width = 2
        color = "red"

        # Horizontal line
        draw.line([(x - crosshair_size, y), (x + crosshair_size, y)], fill=color, width=line_width)
        # Vertical line
        draw.line([(x, y - crosshair_size), (x, y + crosshair_size)], fill=color, width=line_width)

        # Draw center dot (filled circle)
        dot_radius = 3
        draw.ellipse([(x - dot_radius, y - dot_radius), (x + dot_radius, y + dot_radius)], fill=color)

        # Convert back to bytes
        output = io.BytesIO()
        image.save(output, format='PNG')
        return output.getvalue()

    # -------------------------------------------------------------------- hooks

    @override
    async def on_run_start(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]]) -> None:
        """
        Initialize trajectory tracking for a new run.

        Generates a fresh trajectory ID, resets the turn/artifact counters and
        usage totals, creates the trajectory directory and writes an initial
        ``metadata.json`` with ``status: "running"``.
        """
        model = str(kwargs.get("model") or "unknown")
        model_name_short = model.split("+")[-1].split("/")[-1].lower()[:16]
        if "+" in model:
            model_name_short = model.split("+")[0].lower()[:4] + "_" + model_name_short
        # Model names can contain characters that are not valid in a directory
        # name on every platform (e.g. ':' in "llava:7b", or '/' in a short
        # prefix before '+'); keep the ID a single, portable path component.
        model_name_short = self._safe_path_component(model_name_short)

        # id format: yyyy-mm-dd_model_hhmmss_uuid[:4]
        now = datetime.now()
        self.trajectory_id = f"{now.strftime('%Y-%m-%d')}_{model_name_short}_{now.strftime('%H%M%S')}_{str(uuid.uuid4())[:4]}"
        self.current_turn = 0
        self.current_artifact = 0
        self.model = model
        self.total_usage = {}

        # Create trajectory directory
        trajectory_path = self.trajectory_dir / self.trajectory_id
        trajectory_path.mkdir(parents=True, exist_ok=True)

        # Save trajectory metadata
        metadata = {
            "trajectory_id": self.trajectory_id,
            # NOTE: UUID1 timestamp (count of 100 ns intervals since
            # 1582-10-15), not an ISO date; kept as-is for existing readers.
            "created_at": str(uuid.uuid1().time),
            "status": "running",
            # Sanitized like every other artifact so inline base64 images in
            # the input messages don't bloat metadata.json.
            "kwargs": sanitize_image_urls(kwargs),
        }
        self._write_json(trajectory_path / "metadata.json", metadata)

    @override
    async def on_run_end(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]], new_items: List[Dict[str, Any]]) -> None:
        """Finalize run tracking by updating metadata with completion status, usage, and new items."""
        if not self.trajectory_id:
            return

        trajectory_path = self.trajectory_dir / self.trajectory_id
        metadata_path = trajectory_path / "metadata.json"

        # Read existing metadata; start fresh if it is missing or unreadable.
        metadata: Dict[str, Any] = {}
        if metadata_path.exists():
            try:
                with open(metadata_path, "r", encoding="utf-8") as f:
                    metadata = json.load(f)
            except json.JSONDecodeError:
                metadata = {"trajectory_id": self.trajectory_id}

        # Update metadata with completion info
        metadata.update({
            "status": "completed",
            "completed_at": str(uuid.uuid1().time),
            "total_usage": self.total_usage,
            "new_items": sanitize_image_urls(new_items),
            "total_turns": self.current_turn
        })

        # Save updated metadata
        self._write_json(metadata_path, metadata)

    @override
    async def on_api_start(self, kwargs: Dict[str, Any]) -> None:
        """Save the (sanitized) kwargs of an outgoing API call."""
        if not self.trajectory_id:
            return

        self._save_artifact("api_start", {"kwargs": kwargs})

    @override
    async def on_api_end(self, kwargs: Dict[str, Any], result: Any) -> None:
        """Save API call result."""
        if not self.trajectory_id:
            return

        self._save_artifact("api_result", {"kwargs": kwargs, "result": result})

    @override
    async def on_screenshot(self, screenshot: Union[str, bytes], name: str = "screenshot") -> None:
        """
        Save a screenshot into the current turn directory.

        ``screenshot`` may be raw bytes, bare base64 or a base64 ``data:`` URL.
        Like the other hooks, this is a no-op outside a run instead of raising
        from ``_get_turn_dir``.
        """
        if not self.trajectory_id:
            return
        if isinstance(screenshot, str):
            screenshot = self._decode_base64_image(screenshot)
        self._save_artifact(name, screenshot)

    @override
    async def on_usage(self, usage: Dict[str, Any]) -> None:
        """Called when usage information is received; accumulates it into ``total_usage``."""
        self._update_usage(usage)

    @override
    async def on_responses(self, kwargs: Dict[str, Any], responses: Dict[str, Any]) -> None:
        """Save the agent's responses to the current turn directory and advance the turn counter."""
        if not self.trajectory_id:
            return

        response_data = {
            "timestamp": str(uuid.uuid1().time),
            "model": self.model,
            "kwargs": kwargs,
            "response": responses
        }
        self._save_artifact("agent_response", response_data)

        # Increment turn counter
        self.current_turn += 1

    @override
    async def on_computer_call_end(self, item: Dict[str, Any], result: List[Dict[str, Any]]) -> None:
        """
        Called when a computer call has completed.

        Saves the call and its output as ``computer_call_result``. If the
        action carries ``x``/``y`` and the output contains a screenshot
        (a ``computer_call_output`` whose output type is ``input_image``), an
        annotated copy with a red crosshair at the action point is saved as
        ``screenshot_action``. Finally advances the turn counter.
        """
        if not self.trajectory_id:
            return

        self._save_artifact("computer_call_result", {"item": item, "result": result})

        # Check if action has x/y coordinates and there's a screenshot in the result
        action = item.get("action") or {}
        if isinstance(action, dict) and "x" in action and "y" in action:
            image_url = self._find_screenshot_url(result)
            if image_url is not None:
                try:
                    image_bytes = self._decode_base64_image(image_url)
                    annotated_image = self._draw_crosshair_on_image(
                        image_bytes,
                        int(action["x"]),
                        int(action["y"])
                    )
                    self._save_artifact("screenshot_action", annotated_image)
                except Exception as e:
                    # Annotation is best-effort debug output: report and continue.
                    print(f"Failed to annotate screenshot: {e}")

        # Increment turn counter
        self.current_turn += 1
|
agent/cli.py
ADDED
|
@@ -0,0 +1,297 @@
|
|
|
1
|
+
"""
|
|
2
|
+
CLI chat interface for agent - Computer Use Agent
|
|
3
|
+
|
|
4
|
+
Usage:
|
|
5
|
+
python -m agent.cli <model_string>
|
|
6
|
+
|
|
7
|
+
Examples:
|
|
8
|
+
python -m agent.cli openai/computer-use-preview
|
|
9
|
+
python -m agent.cli anthropic/claude-3-5-sonnet-20241022
|
|
10
|
+
python -m agent.cli omniparser+anthropic/claude-3-5-sonnet-20241022
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
try:
|
|
14
|
+
import asyncio
|
|
15
|
+
import argparse
|
|
16
|
+
import os
|
|
17
|
+
import sys
|
|
18
|
+
import json
|
|
19
|
+
from typing import List, Dict, Any
|
|
20
|
+
import dotenv
|
|
21
|
+
from yaspin import yaspin
|
|
22
|
+
except ImportError:
|
|
23
|
+
if __name__ == "__main__":
|
|
24
|
+
raise ImportError(
|
|
25
|
+
"CLI dependencies not found. "
|
|
26
|
+
"Please install with: pip install \"cua-agent[cli]\""
|
|
27
|
+
)
|
|
28
|
+
|
|
29
|
+
# Load environment variables from a local .env file.
# The CLI dependencies (python-dotenv, yaspin) are imported in a guarded block
# that only re-raises when this module is executed as a script. When the
# module is merely imported without the `cli` extra, `dotenv` is unbound, and
# calling it unconditionally would raise NameError at import time.
if "dotenv" in globals():
    dotenv.load_dotenv()
|
|
31
|
+
|
|
32
|
+
# Color codes for terminal output
|
|
33
|
+
class Colors:
    """ANSI escape sequences for styling terminal output.

    Prefix text with one or more of these codes and finish with ``RESET`` to
    restore the terminal's default attributes.
    """

    # Text attributes
    RESET = "\x1b[0m"
    BOLD = "\x1b[1m"
    DIM = "\x1b[2m"

    # Foreground colors
    RED = "\x1b[31m"
    GREEN = "\x1b[32m"
    YELLOW = "\x1b[33m"
    BLUE = "\x1b[34m"
    MAGENTA = "\x1b[35m"
    CYAN = "\x1b[36m"
    WHITE = "\x1b[37m"
    GRAY = "\x1b[90m"

    # Background colors
    BG_RED = "\x1b[41m"
    BG_GREEN = "\x1b[42m"
    BG_YELLOW = "\x1b[43m"
    BG_BLUE = "\x1b[44m"
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def print_colored(text: str, color: str = "", bold: bool = False, dim: bool = False, end: str = "\n"):
    """
    Write ``text`` to stdout wrapped in ANSI styling codes.

    Args:
        text: The message to print (formatted with ``str()``).
        color: Escape sequence to apply, e.g. a ``Colors`` attribute; empty
            for none.
        bold: Prepend the bold attribute.
        dim: Prepend the dim attribute.
        end: Terminator passed through to ``print``.

    The output always ends with ``Colors.RESET`` so styling doesn't leak
    into subsequent terminal output.
    """
    # Codes are applied in a fixed order: bold, dim, then the color.
    requested = ((bold, Colors.BOLD), (dim, Colors.DIM), (color, color))
    codes = "".join(code for wanted, code in requested if wanted)
    print(f"{codes}{text}{Colors.RESET}", end=end)
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def print_action(action_type: str, details: Dict[str, Any]):
    """
    Print a dimmed, one-line summary of a computer action.

    Args:
        action_type: The action's ``type`` (e.g. ``"click"``, ``"type"``).
        details: The full action dict, used to render short arguments.

    Rendering rules:
      * ``type``: the text, quoted and truncated to 50 characters.
      * ``key``: the single ``key`` value.
      * ``keypress``: the ``keys`` list joined with ``+``.
      * any other action carrying both ``x`` and ``y`` (click, double_click,
        move, scroll, ...): the coordinates.
    """
    args_str = ""
    if action_type == "type" and "text" in details:
        text = str(details["text"])
        if len(text) > 50:
            text = text[:47] + "..."
        args_str = f'"{text}"'
    elif action_type == "key" and "key" in details:
        args_str = f"'{details['key']}'"
    elif action_type == "keypress" and details.get("keys"):
        keys = details["keys"]
        joined = "+".join(str(k) for k in keys) if isinstance(keys, (list, tuple)) else str(keys)
        args_str = f"'{joined}'"
    elif "x" in details and "y" in details:
        args_str = f"({details['x']}, {details['y']})"

    print_colored(f"🛠️ {action_type}{args_str}", dim=True)
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def print_welcome(model: str, agent_loop: str, container_name: str):
    """Announce the connected container, model and agent loop, then how to exit."""
    banner = f"Connected to {container_name} ({model}, {agent_loop})"
    exit_hint = "Type 'exit' to quit."
    print_colored(banner)
    print_colored(exit_hint, dim=True)
|
|
91
|
+
|
|
92
|
+
async def ainput(prompt: str = ""):
    """Read a line from stdin without blocking the event loop.

    The built-in ``input`` runs in the running loop's default thread-pool
    executor and its result is awaited. Exceptions raised by ``input``
    (such as ``EOFError`` on end of input) propagate to the caller.
    """
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, input, prompt)
|
|
94
|
+
|
|
95
|
+
def _display_output_item(item: Dict[str, Any], spinner) -> None:
    """
    Print one agent output item and update the spinner around it.

    Handles ``message`` (text parts), ``computer_call`` (action summary),
    ``function_call`` (function name) and ``function_call_output`` (dimmed
    output) items; other item types are ignored.
    """
    item_type = item.get("type")

    if item_type == "message":
        # Display agent text response. Content is normally a list of parts;
        # tolerate a bare string or a missing value.
        content = item.get("content") or []
        if isinstance(content, str):
            content = [{"text": content}]
        for content_part in content:
            if not isinstance(content_part, dict):
                continue
            text = str(content_part.get("text") or "").strip()
            if text:
                spinner.hide()
                print_colored(text)

    elif item_type == "computer_call":
        # Display computer action
        action = item.get("action", {})
        action_type = action.get("type", "")
        if action_type:
            spinner.hide()
            print_action(action_type, action)
            spinner.text = f"Performing {action_type}..."
            spinner.show()

    elif item_type == "function_call":
        # Display function call
        function_name = item.get("name", "")
        spinner.hide()
        print_colored(f"🔧 Calling function: {function_name}", dim=True)
        spinner.text = f"Calling {function_name}..."
        spinner.show()

    elif item_type == "function_call_output":
        # Display function output (dimmed). Outputs are not guaranteed to be
        # strings, so render anything else via str() instead of crashing.
        output = item.get("output", "")
        if output is None:
            output = ""
        elif not isinstance(output, str):
            output = str(output)
        if output.strip():
            spinner.hide()
            print_colored(f"📤 {output}", dim=True)


async def chat_loop(agent, model: str, container_name: str):
    """
    Run the interactive chat loop against ``agent``.

    Each user line is appended to a running ``history`` list. ``agent.run``
    is streamed over that history, and every output item is printed and
    appended so follow-up prompts keep context. Typing ``exit``, ``quit`` or
    ``q`` (case-insensitive, surrounding whitespace ignored) ends the loop.
    Blank lines are ignored.

    Args:
        agent: The agent to drive (``main`` passes a ``ComputerAgent``).
        model: Model string, shown in the welcome banner.
        container_name: Name of the connected container, shown in the banner.
    """
    print_welcome(model, agent.agent_loop.__name__, container_name)

    history: List[Dict[str, Any]] = []

    while True:
        # Get user input with prompt. Strip it so " exit " quits and a line of
        # spaces counts as empty instead of being sent to the model.
        print_colored("> ", end="")
        user_input = (await ainput()).strip()

        if user_input.lower() in ['exit', 'quit', 'q']:
            print_colored("\n👋 Goodbye!")
            break

        if not user_input:
            continue

        # Add user message to history
        history.append({"role": "user", "content": user_input})

        # Stream responses from the agent with spinner
        with yaspin(text="Thinking...", spinner="line", attrs=["dark"]) as spinner:
            spinner.hide()

            async for result in agent.run(history):
                output_items = result.get("output", [])
                # Add agent responses to history, then display them
                history.extend(output_items)
                for item in output_items:
                    _display_output_item(item, spinner)

            spinner.hide()
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the argument parser for ``python -m agent.cli``."""
    parser = argparse.ArgumentParser(
        description="CUA Agent CLI - Interactive computer use assistant",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python -m agent.cli openai/computer-use-preview
  python -m agent.cli anthropic/claude-3-5-sonnet-20241022
  python -m agent.cli omniparser+anthropic/claude-3-5-sonnet-20241022
  python -m agent.cli huggingface-local/ByteDance-Seed/UI-TARS-1.5-7B
        """
    )

    parser.add_argument(
        "model",
        help="Model string (e.g., 'openai/computer-use-preview', 'anthropic/claude-3-5-sonnet-20241022')"
    )

    parser.add_argument(
        "--images",
        type=int,
        default=3,
        help="Number of recent images to keep in context (default: 3)"
    )

    parser.add_argument(
        "--trajectory",
        action="store_true",
        help="Save trajectory for debugging"
    )

    parser.add_argument(
        "--budget",
        type=float,
        help="Maximum budget for the session (in dollars)"
    )

    parser.add_argument(
        "--verbose",
        action="store_true",
        help="Enable verbose logging"
    )

    return parser


# Environment variable holding the API key for each provider prefix. Matched
# against the last "+"-separated component of the model string.
_PROVIDER_API_KEYS = {
    "openai/": "OPENAI_API_KEY",
    "anthropic/": "ANTHROPIC_API_KEY",
}


def _ensure_provider_api_key(model: str) -> None:
    """
    Make sure the API key for ``model``'s provider is available.

    The provider is taken from the last ``+``-separated component of the
    model string. ``omniparser+openai/gpt-4o`` therefore needs
    ``OPENAI_API_KEY`` and ``omniparser+anthropic/...`` needs
    ``ANTHROPIC_API_KEY``. Models with no known prefix need no key. If the
    key is missing from the environment, the user is prompted and the value
    is exported to ``os.environ`` for this process. Exits with status 1 if
    nothing is entered.
    """
    provider_model = model.split("+")[-1]
    for prefix, env_var in _PROVIDER_API_KEYS.items():
        if not provider_model.startswith(prefix):
            continue
        if not os.getenv(env_var):
            label = env_var.replace('_', ' ').title()
            print_colored(f"{env_var} not set.", dim=True)
            api_key = input(f"Enter your {label}: ").strip()
            if not api_key:
                print_colored(f"❌ {label} is required.")
                sys.exit(1)
            # Set the environment variable for the session
            os.environ[env_var] = api_key
        break


async def main():
    """
    CLI entry point.

    Parses arguments, collects credentials (prompting for any that are
    missing), connects to the cloud container and runs the interactive chat
    loop. Exits with status 1 when a required credential is not provided or
    the agent/computer libraries cannot be imported.
    """
    args = _build_arg_parser().parse_args()

    # Check for required environment variables
    container_name = os.getenv("CUA_CONTAINER_NAME")
    cua_api_key = os.getenv("CUA_API_KEY")

    # Prompt for missing environment variables
    if not container_name:
        print_colored("CUA_CONTAINER_NAME not set.", dim=True)
        print_colored("You can get a CUA container at https://www.trycua.com/", dim=True)
        container_name = input("Enter your CUA container name: ").strip()
        if not container_name:
            print_colored("❌ Container name is required.")
            sys.exit(1)

    if not cua_api_key:
        print_colored("CUA_API_KEY not set.", dim=True)
        cua_api_key = input("Enter your CUA API key: ").strip()
        if not cua_api_key:
            print_colored("❌ API key is required.")
            sys.exit(1)

    # Check for the provider-specific API key based on the model
    _ensure_provider_api_key(args.model)

    # Import here to avoid import errors if dependencies are missing
    try:
        from agent import ComputerAgent
        from computer import Computer
    except ImportError as e:
        print_colored(f"❌ Import error: {e}", Colors.RED, bold=True)
        print_colored("Make sure agent and computer libraries are installed.", Colors.YELLOW)
        sys.exit(1)

    # Create computer instance
    async with Computer(
        os_type="linux",
        provider_type="cloud",
        name=container_name,
        api_key=cua_api_key
    ) as computer:

        # Create agent
        agent_kwargs = {
            "model": args.model,
            "tools": [computer],
            "only_n_most_recent_images": args.images,
            # Standard logging levels: 20 = INFO (--verbose), 30 = WARNING.
            "verbosity": 20 if args.verbose else 30,
        }

        if args.trajectory:
            agent_kwargs["trajectory_dir"] = "trajectories"

        # `is not None` so an explicit `--budget 0` is honoured rather than
        # silently treated as "no budget".
        if args.budget is not None:
            agent_kwargs["max_trajectory_budget"] = {
                "max_budget": args.budget,
                "raise_error": True,
                "reset_after_each_run": False
            }

        agent = ComputerAgent(**agent_kwargs)

        # Start chat loop
        await chat_loop(agent, args.model, container_name)
|
|
290
|
+
|
|
291
|
+
|
|
292
|
+
|
|
293
|
+
if __name__ == "__main__":
    try:
        asyncio.run(main())
    except (KeyboardInterrupt, EOFError):
        # Ctrl+C or end of input (e.g. Ctrl+D at the prompt) ends the session quietly.
        print_colored("\n\n👋 Goodbye!")
|