cua-agent 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of cua-agent might be problematic. Click here for more details.
- agent/README.md +63 -0
- agent/__init__.py +10 -0
- agent/core/README.md +101 -0
- agent/core/__init__.py +34 -0
- agent/core/agent.py +284 -0
- agent/core/base_agent.py +164 -0
- agent/core/callbacks.py +147 -0
- agent/core/computer_agent.py +69 -0
- agent/core/experiment.py +222 -0
- agent/core/factory.py +102 -0
- agent/core/loop.py +244 -0
- agent/core/messages.py +230 -0
- agent/core/tools/__init__.py +21 -0
- agent/core/tools/base.py +74 -0
- agent/core/tools/bash.py +52 -0
- agent/core/tools/collection.py +46 -0
- agent/core/tools/computer.py +113 -0
- agent/core/tools/edit.py +67 -0
- agent/core/tools/manager.py +56 -0
- agent/providers/__init__.py +4 -0
- agent/providers/anthropic/__init__.py +6 -0
- agent/providers/anthropic/api/client.py +222 -0
- agent/providers/anthropic/api/logging.py +150 -0
- agent/providers/anthropic/callbacks/manager.py +55 -0
- agent/providers/anthropic/loop.py +521 -0
- agent/providers/anthropic/messages/manager.py +110 -0
- agent/providers/anthropic/prompts.py +20 -0
- agent/providers/anthropic/tools/__init__.py +33 -0
- agent/providers/anthropic/tools/base.py +88 -0
- agent/providers/anthropic/tools/bash.py +163 -0
- agent/providers/anthropic/tools/collection.py +34 -0
- agent/providers/anthropic/tools/computer.py +550 -0
- agent/providers/anthropic/tools/edit.py +326 -0
- agent/providers/anthropic/tools/manager.py +54 -0
- agent/providers/anthropic/tools/run.py +42 -0
- agent/providers/anthropic/types.py +16 -0
- agent/providers/omni/__init__.py +27 -0
- agent/providers/omni/callbacks.py +78 -0
- agent/providers/omni/clients/anthropic.py +99 -0
- agent/providers/omni/clients/base.py +44 -0
- agent/providers/omni/clients/groq.py +101 -0
- agent/providers/omni/clients/openai.py +159 -0
- agent/providers/omni/clients/utils.py +25 -0
- agent/providers/omni/experiment.py +273 -0
- agent/providers/omni/image_utils.py +106 -0
- agent/providers/omni/loop.py +961 -0
- agent/providers/omni/messages.py +168 -0
- agent/providers/omni/parser.py +252 -0
- agent/providers/omni/prompts.py +78 -0
- agent/providers/omni/tool_manager.py +91 -0
- agent/providers/omni/tools/__init__.py +13 -0
- agent/providers/omni/tools/bash.py +69 -0
- agent/providers/omni/tools/computer.py +216 -0
- agent/providers/omni/tools/manager.py +83 -0
- agent/providers/omni/types.py +30 -0
- agent/providers/omni/utils.py +155 -0
- agent/providers/omni/visualization.py +130 -0
- agent/types/__init__.py +26 -0
- agent/types/base.py +52 -0
- agent/types/messages.py +36 -0
- agent/types/tools.py +32 -0
- cua_agent-0.1.0.dist-info/METADATA +44 -0
- cua_agent-0.1.0.dist-info/RECORD +65 -0
- cua_agent-0.1.0.dist-info/WHEEL +4 -0
- cua_agent-0.1.0.dist-info/entry_points.txt +4 -0
agent/core/loop.py
ADDED
|
@@ -0,0 +1,244 @@
|
|
|
1
|
+
"""Base agent loop implementation."""
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
import asyncio
|
|
5
|
+
import json
|
|
6
|
+
import os
|
|
7
|
+
from abc import ABC, abstractmethod
|
|
8
|
+
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
|
|
9
|
+
from datetime import datetime
|
|
10
|
+
import base64
|
|
11
|
+
|
|
12
|
+
from computer import Computer
|
|
13
|
+
from .experiment import ExperimentManager
|
|
14
|
+
|
|
15
|
+
logger = logging.getLogger(__name__)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class BaseLoop(ABC):
    """Base class for agent loops that handle message processing and tool execution.

    Concrete loops implement ``initialize_client``, ``run`` and
    ``_process_screen``. This base class:

    * stores shared configuration;
    * wires up an optional ``ExperimentManager`` that saves trajectories
      (API calls, screenshots, per-turn directories);
    * retries start-up;
    * captures screenshots into a simple "parsed screen" dictionary.
    """

    def __init__(
        self,
        computer: Computer,
        model: str,
        api_key: str,
        max_tokens: int = 4096,
        max_retries: int = 3,
        retry_delay: float = 1.0,
        base_dir: Optional[str] = "trajectories",
        save_trajectory: bool = True,
        only_n_most_recent_images: Optional[int] = 2,
        **kwargs,
    ):
        """Initialize base agent loop.

        Args:
            computer: Computer instance to control.
            model: Model name to use.
            api_key: API key for the provider.
            max_tokens: Maximum tokens to generate.
            max_retries: Number of attempts ``initialize`` makes. At least
                one attempt is always made, even if this is 0 or negative.
            retry_delay: Delay between initialization attempts, in seconds.
            base_dir: Base directory for saving experiment data. Trajectory
                saving is disabled when this is ``None`` or empty.
            save_trajectory: Whether to save trajectory data.
            only_n_most_recent_images: Maximum number of recent screenshots
                to include in API requests. It is also forwarded to the
                experiment manager.
            **kwargs: Additional provider-specific arguments, kept in
                ``self._kwargs``.
        """
        self.computer = computer
        self.model = model
        self.api_key = api_key
        self.max_tokens = max_tokens
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        self.base_dir = base_dir
        self.save_trajectory = save_trajectory
        self.only_n_most_recent_images = only_n_most_recent_images
        self._kwargs = kwargs
        self.message_history = []

        # Trajectory logging is active only when the flag AND a directory are set.
        if self.save_trajectory and self.base_dir:
            self.experiment_manager = ExperimentManager(
                base_dir=self.base_dir,
                only_n_most_recent_images=only_n_most_recent_images,
            )
            # Mirror the manager's directories for convenient access.
            self.run_dir = self.experiment_manager.run_dir
            self.current_turn_dir = self.experiment_manager.current_turn_dir
        else:
            self.experiment_manager = None
            self.run_dir = None
            self.current_turn_dir = None

        # Basic tracking; kept in sync with the manager by _create_turn_dir.
        self.turn_count = 0

    def _setup_experiment_dirs(self) -> None:
        """Set up the experiment directory structure (no-op without a manager)."""
        if self.experiment_manager:
            self.experiment_manager.setup_experiment_dirs()

            # Refresh the locally mirrored directories.
            self.run_dir = self.experiment_manager.run_dir
            self.current_turn_dir = self.experiment_manager.current_turn_dir

    def _create_turn_dir(self) -> None:
        """Create a new directory for the current turn (no-op without a manager)."""
        if self.experiment_manager:
            self.experiment_manager.create_turn_dir()

            # Refresh the locally mirrored turn directory and counter.
            self.current_turn_dir = self.experiment_manager.current_turn_dir
            self.turn_count = self.experiment_manager.turn_count

    def _log_api_call(
        self, call_type: str, request: Any, response: Any = None, error: Optional[Exception] = None
    ) -> None:
        """Log API call details via the experiment manager (no-op without one).

        Args:
            call_type: Type of API call (e.g., 'request', 'response', 'error').
            request: The API request data.
            response: Optional API response data.
            error: Optional error information.
        """
        if self.experiment_manager:
            # Subclasses may define a ``provider`` attribute; fall back to "unknown".
            provider = getattr(self, "provider", "unknown")
            provider_str = str(provider) if provider else "unknown"

            self.experiment_manager.log_api_call(
                call_type=call_type,
                request=request,
                provider=provider_str,
                model=self.model,
                response=response,
                error=error,
            )

    def _save_screenshot(self, img_base64: str, action_type: str = "") -> None:
        """Save a screenshot through the experiment manager (no-op without one).

        Args:
            img_base64: Base64 encoded screenshot.
            action_type: Type of action that triggered the screenshot.
        """
        if self.experiment_manager:
            self.experiment_manager.save_screenshot(img_base64, action_type)

    async def initialize(self) -> None:
        """Initialize both the API client and the computer interface, with retries.

        At least one attempt is always made. Between failed attempts the loop
        sleeps ``retry_delay`` seconds.

        Raises:
            RuntimeError: If every attempt fails. The last underlying
                exception is chained as ``__cause__``.
        """
        # Guard against max_retries <= 0, which previously skipped
        # initialization entirely without raising.
        attempts = max(1, self.max_retries)
        for attempt in range(attempts):
            try:
                logger.info(f"Starting initialization (attempt {attempt + 1}/{attempts})...")

                # Initialize API client
                await self.initialize_client()

                # Initialize computer
                await self.computer.initialize()

                logger.info("Initialization complete.")
                return
            except Exception as e:
                if attempt < attempts - 1:
                    logger.warning(
                        f"Initialization failed (attempt {attempt + 1}/{attempts}): {str(e)}. Retrying..."
                    )
                    await asyncio.sleep(self.retry_delay)
                else:
                    logger.error(f"Initialization failed after {attempts} attempts: {str(e)}")
                    raise RuntimeError(f"Failed to initialize: {str(e)}") from e

    @staticmethod
    def _png_dimensions(data: bytes) -> Optional[Tuple[int, int]]:
        """Return ``(width, height)`` from a PNG header, or ``None`` if not a PNG.

        A PNG starts with an 8-byte signature followed by the IHDR chunk
        (4-byte length, ``b"IHDR"``). Width and height are then stored as
        big-endian uint32 values at byte offsets 16 and 20.
        """
        if len(data) >= 24 and data[:8] == b"\x89PNG\r\n\x1a\n" and data[12:16] == b"IHDR":
            return int.from_bytes(data[16:20], "big"), int.from_bytes(data[20:24], "big")
        return None

    async def _get_parsed_screen_som(self) -> Dict[str, Any]:
        """Take a screenshot and wrap it in a parsed-screen dictionary.

        The returned dict has the keys ``width``, ``height``,
        ``parsed_content_list`` (always empty here), ``timestamp`` and
        ``screenshot_base64``. When trajectory saving is enabled, the
        screenshot is also saved with ``action_type="state"``. On failure,
        the error is logged and a fallback dict with an ``error`` key and an
        empty screenshot is returned instead of raising.

        Returns:
            Dict containing screen information.
        """
        try:
            screenshot = await self.computer.screenshot()

            # Fallback values, used when the screenshot does not reveal its size.
            width, height = 1024, 768
            base64_image = ""

            # Handle the different shapes a screenshot may come back in.
            if isinstance(screenshot, bytes):
                base64_image = base64.b64encode(screenshot).decode("utf-8")
                # Raw bytes carry no size attributes; read the size from the PNG header.
                dimensions = self._png_dimensions(screenshot)
                if dimensions is not None:
                    width, height = dimensions
            elif hasattr(screenshot, "base64_image"):
                # Object-style screenshot with attributes
                base64_image = screenshot.base64_image
                if hasattr(screenshot, "width") and hasattr(screenshot, "height"):
                    width = screenshot.width
                    height = screenshot.height

            parsed_screen = {
                "width": width,
                "height": height,
                "parsed_content_list": [],
                "timestamp": datetime.now().isoformat(),
                "screenshot_base64": base64_image,
            }

            if self.save_trajectory and self.experiment_manager:
                try:
                    img_data = base64_image
                    # Strip a "data:<mime>;base64," prefix if one is present.
                    if "," in img_data:
                        img_data = img_data.split(",")[1]
                    self._save_screenshot(img_data, action_type="state")
                except Exception as e:
                    # Saving is best-effort; a failure must not abort the turn.
                    logger.error(f"Error saving screenshot: {str(e)}")

            return parsed_screen
        except Exception as e:
            logger.error(f"Error taking screenshot: {str(e)}")
            return {
                "width": 1024,
                "height": 768,
                "parsed_content_list": [],
                "timestamp": datetime.now().isoformat(),
                "error": f"Error taking screenshot: {str(e)}",
                "screenshot_base64": "",
            }

    @abstractmethod
    async def initialize_client(self) -> None:
        """Initialize the API client and any provider-specific components."""
        raise NotImplementedError

    @abstractmethod
    async def run(self, messages: List[Dict[str, Any]]) -> AsyncGenerator[Dict[str, Any], None]:
        """Run the agent loop with provided messages.

        Args:
            messages: List of message objects.

        Yields:
            Dict containing response data.
        """
        raise NotImplementedError

    @abstractmethod
    async def _process_screen(
        self, parsed_screen: Dict[str, Any], messages: List[Dict[str, Any]]
    ) -> None:
        """Process screen information and add it to messages.

        Args:
            parsed_screen: Dictionary containing parsed screen info.
            messages: List of messages to update.
        """
        raise NotImplementedError
|
agent/core/messages.py
ADDED
|
@@ -0,0 +1,230 @@
|
|
|
1
|
+
"""Message handling utilities for agent."""
|
|
2
|
+
|
|
3
|
+
import base64
|
|
4
|
+
from datetime import datetime
|
|
5
|
+
from io import BytesIO
|
|
6
|
+
import logging
|
|
7
|
+
from typing import Any, Dict, List, Optional, Union
|
|
8
|
+
from PIL import Image
|
|
9
|
+
from dataclasses import dataclass
|
|
10
|
+
|
|
11
|
+
logger = logging.getLogger(__name__)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@dataclass
class ImageRetentionConfig:
    """Settings that control how many images survive message preparation.

    Attributes:
        num_images_to_keep: How many of the most recent images to keep;
            ``None`` (the default) or a non-positive value disables filtering.
        min_removal_threshold: Removals are rounded down to a multiple of
            this value.
        enable_caching: Whether cache-control markers should be injected.
    """

    num_images_to_keep: Optional[int] = None
    min_removal_threshold: int = 1
    enable_caching: bool = True

    def should_retain_images(self) -> bool:
        """Return True only when a positive image limit has been configured."""
        limit = self.num_images_to_keep
        if limit is None:
            return False
        return limit > 0
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class BaseMessageManager:
    """Base class for message preparation and management.

    ``prepare_messages`` mutates the given message list in place and also
    returns it. It can drop the oldest images nested inside ``tool_result``
    blocks, and mark the most recent user turns with an ephemeral
    ``cache_control`` entry.
    """

    def __init__(self, image_retention_config: Optional["ImageRetentionConfig"] = None):
        """Initialize the message manager.

        Args:
            image_retention_config: Retention/caching settings. When omitted
                (or falsy), a default ``ImageRetentionConfig()`` is used.

        Raises:
            ValueError: If ``min_removal_threshold`` is less than 1; it is
                used as a modulus in ``_filter_images``.
        """
        self.image_retention_config = image_retention_config or ImageRetentionConfig()
        if self.image_retention_config.min_removal_threshold < 1:
            raise ValueError("min_removal_threshold must be at least 1")

    def prepare_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Prepare messages by applying image retention and caching as configured.

        Args:
            messages: List of messages to prepare; modified in place.

        Returns:
            The same list object, after preparation.
        """
        config = self.image_retention_config
        if config.should_retain_images():
            self._filter_images(messages)
        if config.enable_caching:
            self._inject_caching(messages)
        return messages

    @staticmethod
    def _is_image_block(block: Any) -> bool:
        """Return True for a content block of the form ``{"type": "image", ...}``."""
        return isinstance(block, dict) and block.get("type") == "image"

    def _filter_images(self, messages: List[Dict[str, Any]]) -> None:
        """Keep only the configured number of most recent images in tool results.

        Images are removed oldest-first, i.e. in message order. Only images
        nested inside list-valued ``tool_result`` content are considered.
        Messages or tool results whose content is missing, ``None`` or a
        string are skipped rather than raising.

        Args:
            messages: Messages to filter, modified in place.
        """
        # Collect every tool_result block from list-valued message contents.
        tool_results = [
            item
            for message in messages
            if isinstance(message.get("content"), list)
            for item in message["content"]
            if isinstance(item, dict) and item.get("type") == "tool_result"
        ]

        total_images = sum(
            1
            for result in tool_results
            if isinstance(result.get("content"), list)
            for block in result["content"]
            if self._is_image_block(block)
        )

        config = self.image_retention_config
        images_to_remove = total_images - (config.num_images_to_keep or 0)
        # Round down to a multiple of min_removal_threshold so images are dropped
        # in batches (presumably to keep the cached prompt prefix stable between
        # turns). A negative count stays non-positive, so nothing is removed.
        images_to_remove -= images_to_remove % config.min_removal_threshold

        # Remove oldest images first.
        for result in tool_results:
            content = result.get("content")
            if not isinstance(content, list):
                continue
            kept = []
            for block in content:
                if self._is_image_block(block) and images_to_remove > 0:
                    images_to_remove -= 1
                    continue
                kept.append(block)
            result["content"] = kept

    def _inject_caching(self, messages: List[Dict[str, Any]]) -> None:
        """Mark the last block of the 3 most recent user turns as cacheable.

        Walking backwards, the last block of each of the 3 most recent
        list-content user messages gets ``{"type": "ephemeral"}`` as
        ``cache_control``. The next older such message has any existing
        marker removed, and the walk stops there.

        Args:
            messages: Messages to inject caching into, modified in place.
        """
        turns_to_cache = 3
        for message in reversed(messages):
            content = message.get("content")
            if message.get("role") != "user" or not isinstance(content, list):
                continue
            # An empty content list, or a non-dict last block, has nowhere to
            # hold a cache marker; skip it instead of raising.
            if not content or not isinstance(content[-1], dict):
                continue
            if turns_to_cache:
                turns_to_cache -= 1
                content[-1]["cache_control"] = {"type": "ephemeral"}
            else:
                content[-1].pop("cache_control", None)
                break
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def create_user_message(text: str) -> Dict[str, str]:
    """Wrap ``text`` in a chat message attributed to the user.

    Args:
        text: Message body.

    Returns:
        A ``{"role": "user", "content": text}`` dictionary.
    """
    return dict(role="user", content=text)
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
def create_assistant_message(text: str) -> Dict[str, str]:
    """Wrap ``text`` in a chat message attributed to the assistant.

    Args:
        text: Message body.

    Returns:
        A ``{"role": "assistant", "content": text}`` dictionary.
    """
    role = "assistant"
    return {"role": role, "content": text}
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
def create_system_message(text: str) -> Dict[str, str]:
    """Wrap ``text`` in a system-role chat message.

    Args:
        text: Message body.

    Returns:
        A ``{"role": "system", "content": text}`` dictionary.
    """
    message: Dict[str, str] = {}
    message["role"] = "system"
    message["content"] = text
    return message
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
def create_image_message(
    image_base64: Optional[str] = None,
    image_path: Optional[str] = None,
    # Quoted so the signature does not need PIL at definition time.
    image_obj: Optional["Image.Image"] = None,
) -> Dict[str, Union[str, List[Dict[str, Any]]]]:
    """Create a user message containing a single image as a data URL.

    Source precedence: a non-empty ``image_base64`` is used as-is. Otherwise
    the file at ``image_path`` is read, and otherwise ``image_obj`` is
    encoded as PNG.

    The data URL's MIME type is guessed from ``image_path``'s extension
    (e.g. ``.jpg`` -> ``image/jpeg``) when the image is loaded from a file.
    In every other case, or when the guess is not an ``image/*`` type, it
    is ``image/png``.

    Args:
        image_base64: Base64 encoded image (no ``data:`` prefix).
        image_path: Path to an image file.
        image_obj: PIL Image object.

    Returns:
        Message dictionary with a content list holding one image block.

    Raises:
        ValueError: If no image source is provided.
    """
    import mimetypes

    if not image_base64 and not image_path and image_obj is None:
        raise ValueError("Must provide one of image_base64, image_path, or image_obj")

    mime_type = "image/png"
    if image_path and not image_base64:
        with open(image_path, "rb") as f:
            image_base64 = base64.b64encode(f.read()).decode("utf-8")
        # Previously every file was labelled image/png regardless of format.
        guessed, _ = mimetypes.guess_type(image_path)
        if guessed and guessed.startswith("image/"):
            mime_type = guessed
    elif image_obj is not None and not image_base64:
        buffer = BytesIO()
        image_obj.save(buffer, format="PNG")
        image_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")

    return {
        "role": "user",
        "content": [
            {"type": "image", "image_url": {"url": f"data:{mime_type};base64,{image_base64}"}}
        ],
    }
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
def create_screen_message(
    parsed_screen: Dict[str, Any],
    include_raw: bool = False,
) -> Dict[str, Union[str, List[Dict[str, Any]]]]:
    """Create a user message describing the screen.

    With ``include_raw`` and a non-empty ``screenshot_base64``, the content
    is a list with an image block followed by a text block with the
    dimensions. Otherwise the content is just the dimensions text. An empty
    screenshot (as produced when a capture fails) therefore no longer yields
    an image block with an empty data URL.

    A screenshot value that is already a ``data:`` URL is used unchanged.
    Anything else is treated as raw base64 PNG data.

    Args:
        parsed_screen: Dictionary containing parsed screen info; must contain
            ``width`` and ``height`` (``KeyError`` otherwise).
        include_raw: Whether to include the raw screenshot image.

    Returns:
        Message dictionary with content.
    """
    dimensions_text = f"Screen dimensions: {parsed_screen['width']}x{parsed_screen['height']}"
    screenshot = parsed_screen.get("screenshot_base64")

    if not (include_raw and screenshot):
        # Text-only message with screen info
        return {
            "role": "user",
            "content": dimensions_text,
        }

    # Avoid double-prefixing values that are already data URLs.
    if isinstance(screenshot, str) and screenshot.startswith("data:"):
        url = screenshot
    else:
        url = f"data:image/png;base64,{screenshot}"

    return {
        "role": "user",
        "content": [
            {
                "type": "image",
                "image_url": {"url": url},
            },
            {
                "type": "text",
                "text": dimensions_text,
            },
        ],
    }
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
"""Core tools package."""
|
|
2
|
+
|
|
3
|
+
from .base import BaseTool, ToolResult, ToolError, ToolFailure, CLIResult
|
|
4
|
+
from .bash import BaseBashTool
|
|
5
|
+
from .collection import ToolCollection
|
|
6
|
+
from .computer import BaseComputerTool
|
|
7
|
+
from .edit import BaseEditTool
|
|
8
|
+
from .manager import BaseToolManager
|
|
9
|
+
|
|
10
|
+
__all__ = [
|
|
11
|
+
"BaseTool",
|
|
12
|
+
"ToolResult",
|
|
13
|
+
"ToolError",
|
|
14
|
+
"ToolFailure",
|
|
15
|
+
"CLIResult",
|
|
16
|
+
"BaseBashTool",
|
|
17
|
+
"BaseComputerTool",
|
|
18
|
+
"BaseEditTool",
|
|
19
|
+
"ToolCollection",
|
|
20
|
+
"BaseToolManager",
|
|
21
|
+
]
|
agent/core/tools/base.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
"""Abstract base classes for tools that can be used with any provider."""
|
|
2
|
+
|
|
3
|
+
from abc import ABCMeta, abstractmethod
|
|
4
|
+
from dataclasses import dataclass, fields, replace
|
|
5
|
+
from typing import Any, Dict
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class BaseTool(metaclass=ABCMeta):
    """Interface shared by every provider-agnostic agent tool.

    Subclasses must set ``name`` and implement both ``to_params`` (the
    provider-facing schema) and the async ``__call__`` (the execution).
    """

    # Identifier of the tool; concrete subclasses assign it.
    name: str

    @abstractmethod
    def to_params(self) -> Dict[str, Any]:
        """Describe this tool in the format the LLM provider expects.

        Returns:
            Provider-specific dictionary of tool parameters.
        """
        raise NotImplementedError

    @abstractmethod
    async def __call__(self, **kwargs) -> Any:
        """Run the tool using the supplied keyword arguments."""
        ...
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@dataclass(kw_only=True, frozen=True)
|
|
29
|
+
class ToolResult:
|
|
30
|
+
"""Represents the result of a tool execution."""
|
|
31
|
+
|
|
32
|
+
output: str | None = None
|
|
33
|
+
error: str | None = None
|
|
34
|
+
base64_image: str | None = None
|
|
35
|
+
system: str | None = None
|
|
36
|
+
content: list[dict] | None = None
|
|
37
|
+
|
|
38
|
+
def __bool__(self):
|
|
39
|
+
return any(getattr(self, field.name) for field in fields(self))
|
|
40
|
+
|
|
41
|
+
def __add__(self, other: "ToolResult"):
|
|
42
|
+
def combine_fields(field: str | None, other_field: str | None, concatenate: bool = True):
|
|
43
|
+
if field and other_field:
|
|
44
|
+
if concatenate:
|
|
45
|
+
return field + other_field
|
|
46
|
+
raise ValueError("Cannot combine tool results")
|
|
47
|
+
return field or other_field
|
|
48
|
+
|
|
49
|
+
return ToolResult(
|
|
50
|
+
output=combine_fields(self.output, other.output),
|
|
51
|
+
error=combine_fields(self.error, other.error),
|
|
52
|
+
base64_image=combine_fields(self.base64_image, other.base64_image, False),
|
|
53
|
+
system=combine_fields(self.system, other.system),
|
|
54
|
+
content=self.content or other.content, # Use first non-None content
|
|
55
|
+
)
|
|
56
|
+
|
|
57
|
+
def replace(self, **kwargs):
|
|
58
|
+
"""Returns a new ToolResult with the given fields replaced."""
|
|
59
|
+
return replace(self, **kwargs)
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
class CLIResult(ToolResult):
    """Tool result whose output is intended to be displayed as CLI text.

    It adds no fields; it only marks the kind of result.
    """

    pass
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
class ToolFailure(ToolResult):
    """Tool result signalling that the tool did not succeed.

    It adds no fields; the failure text is carried in ``error``.
    """

    pass
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
class ToolError(Exception):
    """Raised when a tool encounters an error.

    Attributes:
        message: The error message passed to the constructor.
    """

    def __init__(self, message):
        # Forward to Exception so str(e), e.args and pickling work; previously
        # str(e) was "" and unpickling failed because args was empty.
        super().__init__(message)
        self.message = message
|
agent/core/tools/bash.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
"""Abstract base bash/shell tool implementation."""
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import logging
|
|
5
|
+
from abc import abstractmethod
|
|
6
|
+
from typing import Any, Dict, Tuple
|
|
7
|
+
|
|
8
|
+
from computer.computer import Computer
|
|
9
|
+
|
|
10
|
+
from .base import BaseTool, ToolResult
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class BaseBashTool(BaseTool):
    """Base class for bash/shell command execution tools across different providers.

    Subclasses implement ``__call__``; ``run_command`` is a shared helper.

    NOTE(review): ``run_command`` launches the command with
    ``asyncio.create_subprocess_shell`` in the process running this code,
    not through ``self.computer`` — confirm that is intended.
    """

    name = "bash"
    logger = logging.getLogger(__name__)
    computer: Computer

    def __init__(self, computer: Computer):
        """Initialize the BashTool.

        Args:
            computer: Computer instance, stored on ``self.computer``.
        """
        self.computer = computer

    async def run_command(
        self, command: str, timeout: float | None = None
    ) -> Tuple[int, str, str]:
        """Run a shell command and return exit code, stdout, and stderr.

        Args:
            command: Shell command to execute. It is handed to the shell
                verbatim, so it must not contain unsanitized untrusted input.
            timeout: Optional limit in seconds. ``None`` (the default) waits
                indefinitely. On timeout the process is killed and
                ``(1, "", "Command timed out after <timeout> seconds")`` is
                returned.

        Returns:
            Tuple ``(exit_code, stdout, stderr)``.

            * Output is decoded as UTF-8, with undecodable bytes replaced;
              non-UTF-8 output no longer turns a finished command into an
              error.
            * Any exception while launching or communicating is logged and
              reported as ``(1, "", str(error))``.
        """
        try:
            process = await asyncio.create_subprocess_shell(
                command,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
            )
        except Exception as e:
            self.logger.error(f"Error running command: {str(e)}")
            return 1, "", str(e)

        try:
            # wait_for with timeout=None simply awaits communicate().
            stdout, stderr = await asyncio.wait_for(process.communicate(), timeout=timeout)
        except asyncio.TimeoutError:
            try:
                process.kill()
            except ProcessLookupError:
                pass  # the process already exited
            await process.wait()
            message = f"Command timed out after {timeout} seconds"
            self.logger.error(message)
            return 1, "", message
        except Exception as e:
            self.logger.error(f"Error running command: {str(e)}")
            return 1, "", str(e)

        return (
            process.returncode or 0,
            stdout.decode(errors="replace"),
            stderr.decode(errors="replace"),
        )

    @abstractmethod
    async def __call__(self, **kwargs) -> ToolResult:
        """Execute the tool with the provided arguments."""
        raise NotImplementedError
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
"""Collection classes for managing multiple tools."""
|
|
2
|
+
|
|
3
|
+
from typing import Any, Dict, List, Type
|
|
4
|
+
|
|
5
|
+
from .base import (
|
|
6
|
+
BaseTool,
|
|
7
|
+
ToolError,
|
|
8
|
+
ToolFailure,
|
|
9
|
+
ToolResult,
|
|
10
|
+
)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class ToolCollection:
    """Registry of tools, addressable by name, usable with any provider.

    Tools are kept in registration order (``tools``) and in a
    ``name -> tool`` mapping (``tool_map``). When two tools share a name,
    the later one wins in the mapping.
    """

    def __init__(self, *tools: BaseTool):
        self.tools = tools
        self.tool_map = {}
        for registered in tools:
            self.tool_map[registered.name] = registered

    def to_params(self) -> List[Dict[str, Any]]:
        """Gather the provider-specific parameter dict of every tool, in order.

        Returns:
            One parameter dictionary per registered tool.
        """
        return [registered.to_params() for registered in self.tools]

    async def run(self, *, name: str, tool_input: Dict[str, Any]) -> ToolResult:
        """Look up a tool by name and execute it with ``tool_input``.

        Errors never propagate. An unknown name, a ``ToolError`` or any other
        exception are all turned into a ``ToolFailure``.

        Args:
            name: Name of the tool to run.
            tool_input: Keyword arguments forwarded to the tool.

        Returns:
            Whatever the tool returned, or a ``ToolFailure``.
        """
        selected = self.tool_map.get(name)
        if not selected:
            return ToolFailure(error=f"Tool {name} is invalid")
        try:
            outcome = await selected(**tool_input)
        except ToolError as exc:
            return ToolFailure(error=exc.message)
        except Exception as exc:
            return ToolFailure(error=f"Unexpected error in tool {name}: {str(exc)}")
        return outcome
|