cua-agent 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of cua-agent might be problematic. Click here for more details.

Files changed (65) hide show
  1. agent/README.md +63 -0
  2. agent/__init__.py +10 -0
  3. agent/core/README.md +101 -0
  4. agent/core/__init__.py +34 -0
  5. agent/core/agent.py +284 -0
  6. agent/core/base_agent.py +164 -0
  7. agent/core/callbacks.py +147 -0
  8. agent/core/computer_agent.py +69 -0
  9. agent/core/experiment.py +222 -0
  10. agent/core/factory.py +102 -0
  11. agent/core/loop.py +244 -0
  12. agent/core/messages.py +230 -0
  13. agent/core/tools/__init__.py +21 -0
  14. agent/core/tools/base.py +74 -0
  15. agent/core/tools/bash.py +52 -0
  16. agent/core/tools/collection.py +46 -0
  17. agent/core/tools/computer.py +113 -0
  18. agent/core/tools/edit.py +67 -0
  19. agent/core/tools/manager.py +56 -0
  20. agent/providers/__init__.py +4 -0
  21. agent/providers/anthropic/__init__.py +6 -0
  22. agent/providers/anthropic/api/client.py +222 -0
  23. agent/providers/anthropic/api/logging.py +150 -0
  24. agent/providers/anthropic/callbacks/manager.py +55 -0
  25. agent/providers/anthropic/loop.py +521 -0
  26. agent/providers/anthropic/messages/manager.py +110 -0
  27. agent/providers/anthropic/prompts.py +20 -0
  28. agent/providers/anthropic/tools/__init__.py +33 -0
  29. agent/providers/anthropic/tools/base.py +88 -0
  30. agent/providers/anthropic/tools/bash.py +163 -0
  31. agent/providers/anthropic/tools/collection.py +34 -0
  32. agent/providers/anthropic/tools/computer.py +550 -0
  33. agent/providers/anthropic/tools/edit.py +326 -0
  34. agent/providers/anthropic/tools/manager.py +54 -0
  35. agent/providers/anthropic/tools/run.py +42 -0
  36. agent/providers/anthropic/types.py +16 -0
  37. agent/providers/omni/__init__.py +27 -0
  38. agent/providers/omni/callbacks.py +78 -0
  39. agent/providers/omni/clients/anthropic.py +99 -0
  40. agent/providers/omni/clients/base.py +44 -0
  41. agent/providers/omni/clients/groq.py +101 -0
  42. agent/providers/omni/clients/openai.py +159 -0
  43. agent/providers/omni/clients/utils.py +25 -0
  44. agent/providers/omni/experiment.py +273 -0
  45. agent/providers/omni/image_utils.py +106 -0
  46. agent/providers/omni/loop.py +961 -0
  47. agent/providers/omni/messages.py +168 -0
  48. agent/providers/omni/parser.py +252 -0
  49. agent/providers/omni/prompts.py +78 -0
  50. agent/providers/omni/tool_manager.py +91 -0
  51. agent/providers/omni/tools/__init__.py +13 -0
  52. agent/providers/omni/tools/bash.py +69 -0
  53. agent/providers/omni/tools/computer.py +216 -0
  54. agent/providers/omni/tools/manager.py +83 -0
  55. agent/providers/omni/types.py +30 -0
  56. agent/providers/omni/utils.py +155 -0
  57. agent/providers/omni/visualization.py +130 -0
  58. agent/types/__init__.py +26 -0
  59. agent/types/base.py +52 -0
  60. agent/types/messages.py +36 -0
  61. agent/types/tools.py +32 -0
  62. cua_agent-0.1.0.dist-info/METADATA +44 -0
  63. cua_agent-0.1.0.dist-info/RECORD +65 -0
  64. cua_agent-0.1.0.dist-info/WHEEL +4 -0
  65. cua_agent-0.1.0.dist-info/entry_points.txt +4 -0
agent/core/loop.py ADDED
@@ -0,0 +1,244 @@
1
+ """Base agent loop implementation."""
2
+
3
+ import logging
4
+ import asyncio
5
+ import json
6
+ import os
7
+ from abc import ABC, abstractmethod
8
+ from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
9
+ from datetime import datetime
10
+ import base64
11
+
12
+ from computer import Computer
13
+ from .experiment import ExperimentManager
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
class BaseLoop(ABC):
    """Base class for agent loops that handle message processing and tool execution.

    The base class stores the shared loop configuration, optionally creates an
    ``ExperimentManager`` used for trajectory logging, and offers a retrying
    ``initialize`` plus a screenshot-capture helper. Concrete loops must
    implement ``initialize_client``, ``run`` and ``_process_screen``.
    """

    def __init__(
        self,
        computer: Computer,
        model: str,
        api_key: str,
        max_tokens: int = 4096,
        max_retries: int = 3,
        retry_delay: float = 1.0,
        base_dir: Optional[str] = "trajectories",
        save_trajectory: bool = True,
        only_n_most_recent_images: Optional[int] = 2,
        **kwargs,
    ):
        """Initialize base agent loop.

        Args:
            computer: Computer instance to control
            model: Model name to use
            api_key: API key for provider
            max_tokens: Maximum tokens to generate
            max_retries: Maximum number of initialization attempts
            retry_delay: Delay between retries in seconds
            base_dir: Base directory for saving experiment data
            save_trajectory: Whether to save trajectory data
            only_n_most_recent_images: Maximum number of recent screenshots to include in API requests
            **kwargs: Additional provider-specific arguments (stored as ``self._kwargs``)
        """
        self.computer = computer
        self.model = model
        self.api_key = api_key
        self.max_tokens = max_tokens
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        self.base_dir = base_dir
        self.save_trajectory = save_trajectory
        self.only_n_most_recent_images = only_n_most_recent_images
        self._kwargs = kwargs
        self.message_history = []

        # An experiment manager only exists when trajectory saving is enabled
        # AND a base directory was given; every helper below checks for None.
        if self.save_trajectory and self.base_dir:
            self.experiment_manager = ExperimentManager(
                base_dir=self.base_dir,
                only_n_most_recent_images=only_n_most_recent_images,
            )
            # Track directories for convenience
            self.run_dir = self.experiment_manager.run_dir
            self.current_turn_dir = self.experiment_manager.current_turn_dir
        else:
            self.experiment_manager = None
            self.run_dir = None
            self.current_turn_dir = None

        # Initialize basic tracking
        self.turn_count = 0

    def _setup_experiment_dirs(self) -> None:
        """Setup the experiment directory structure.

        Delegates to the experiment manager (when present) and then mirrors its
        ``run_dir`` / ``current_turn_dir`` onto this instance.
        """
        if self.experiment_manager:
            self.experiment_manager.setup_experiment_dirs()

            # Update local tracking variables
            self.run_dir = self.experiment_manager.run_dir
            self.current_turn_dir = self.experiment_manager.current_turn_dir

    def _create_turn_dir(self) -> None:
        """Create a new directory for the current turn.

        Delegates to the experiment manager (when present) and then mirrors its
        ``current_turn_dir`` / ``turn_count`` onto this instance.
        """
        if self.experiment_manager:
            self.experiment_manager.create_turn_dir()

            # Update local tracking variables
            self.current_turn_dir = self.experiment_manager.current_turn_dir
            self.turn_count = self.experiment_manager.turn_count

    def _log_api_call(
        self, call_type: str, request: Any, response: Any = None, error: Optional[Exception] = None
    ) -> None:
        """Log API call details through the experiment manager.

        Does nothing when no experiment manager is configured.

        Args:
            call_type: Type of API call (e.g., 'request', 'response', 'error')
            request: The API request data
            response: Optional API response data
            error: Optional error information
        """
        if self.experiment_manager:
            # Subclasses may define a ``provider`` attribute; fall back to
            # "unknown" when it is missing or falsy.
            provider = getattr(self, "provider", "unknown")
            provider_str = str(provider) if provider else "unknown"

            self.experiment_manager.log_api_call(
                call_type=call_type,
                request=request,
                provider=provider_str,
                model=self.model,
                response=response,
                error=error,
            )

    def _save_screenshot(self, img_base64: str, action_type: str = "") -> None:
        """Save a screenshot through the experiment manager.

        Does nothing when no experiment manager is configured.

        Args:
            img_base64: Base64 encoded screenshot
            action_type: Type of action that triggered the screenshot
        """
        if self.experiment_manager:
            self.experiment_manager.save_screenshot(img_base64, action_type)

    async def initialize(self) -> None:
        """Initialize both the API client and computer interface with retries.

        At least one attempt is always made, even if ``max_retries`` is zero or
        negative, so the method never returns without having initialized.

        Raises:
            RuntimeError: If every attempt fails; the last underlying exception
                is attached as ``__cause__``.
        """
        # Guard against max_retries <= 0, which would otherwise skip the loop
        # entirely and report success without initializing anything.
        total_attempts = max(1, self.max_retries)

        for attempt in range(1, total_attempts + 1):
            try:
                logger.info(f"Starting initialization (attempt {attempt}/{total_attempts})...")

                # Initialize API client
                await self.initialize_client()

                # Initialize computer
                await self.computer.initialize()

                logger.info("Initialization complete.")
                return
            except Exception as e:
                if attempt < total_attempts:
                    logger.warning(
                        f"Initialization failed (attempt {attempt}/{total_attempts}): {str(e)}. Retrying..."
                    )
                    await asyncio.sleep(self.retry_delay)
                else:
                    logger.error(
                        f"Initialization failed after {total_attempts} attempts: {str(e)}"
                    )
                    raise RuntimeError(f"Failed to initialize: {str(e)}") from e

    async def _get_parsed_screen_som(self) -> Dict[str, Any]:
        """Take a screenshot and wrap it in a parsed-screen dictionary.

        Returns:
            Dict with ``width``, ``height``, ``parsed_content_list`` (always an
            empty list here), ``timestamp`` and ``screenshot_base64``. If taking
            the screenshot fails, the same keys are returned with default
            dimensions, an empty ``screenshot_base64`` and an extra ``error``
            message instead of raising.
        """
        try:
            # Take screenshot
            screenshot = await self.computer.screenshot()

            # Defaults used when the screenshot object carries no dimensions
            width, height = 1024, 768
            base64_image = ""

            # Handle different types of screenshot returns
            if isinstance(screenshot, bytes):
                # Raw bytes screenshot
                base64_image = base64.b64encode(screenshot).decode("utf-8")
            elif hasattr(screenshot, "base64_image"):
                # Object-style screenshot with attributes
                base64_image = screenshot.base64_image
                if hasattr(screenshot, "width") and hasattr(screenshot, "height"):
                    width = screenshot.width
                    height = screenshot.height

            # Create parsed screen data
            parsed_screen = {
                "width": width,
                "height": height,
                "parsed_content_list": [],
                "timestamp": datetime.now().isoformat(),
                "screenshot_base64": base64_image,
            }

            # Save screenshot if requested; a failure to save is logged but
            # must not prevent the screen data from being returned.
            if self.save_trajectory and self.experiment_manager:
                try:
                    img_data = base64_image
                    if "," in img_data:
                        # Keep only the part after the first comma (drops a
                        # "data:...;base64," style prefix if one is present).
                        img_data = img_data.split(",")[1]
                    self._save_screenshot(img_data, action_type="state")
                except Exception as e:
                    logger.error(f"Error saving screenshot: {str(e)}")

            return parsed_screen
        except Exception as e:
            logger.error(f"Error taking screenshot: {str(e)}")
            return {
                "width": 1024,
                "height": 768,
                "parsed_content_list": [],
                "timestamp": datetime.now().isoformat(),
                "error": f"Error taking screenshot: {str(e)}",
                "screenshot_base64": "",
            }

    @abstractmethod
    async def initialize_client(self) -> None:
        """Initialize the API client and any provider-specific components."""
        raise NotImplementedError

    @abstractmethod
    async def run(self, messages: List[Dict[str, Any]]) -> AsyncGenerator[Dict[str, Any], None]:
        """Run the agent loop with provided messages.

        Args:
            messages: List of message objects

        Yields:
            Dict containing response data
        """
        raise NotImplementedError

    @abstractmethod
    async def _process_screen(
        self, parsed_screen: Dict[str, Any], messages: List[Dict[str, Any]]
    ) -> None:
        """Process screen information and add to messages.

        Args:
            parsed_screen: Dictionary containing parsed screen info
            messages: List of messages to update
        """
        raise NotImplementedError
agent/core/messages.py ADDED
@@ -0,0 +1,230 @@
1
+ """Message handling utilities for agent."""
2
+
3
+ import base64
4
+ from datetime import datetime
5
+ from io import BytesIO
6
+ import logging
7
+ from typing import Any, Dict, List, Optional, Union
8
+ from PIL import Image
9
+ from dataclasses import dataclass
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
@dataclass
class ImageRetentionConfig:
    """Settings that control how many images are kept in the message history."""

    # Number of most recent images to keep; None leaves retention switched off.
    num_images_to_keep: Optional[int] = None
    # Images are removed in multiples of this value.
    min_removal_threshold: int = 1
    # Whether cache-control markers are injected into messages.
    enable_caching: bool = True

    def should_retain_images(self) -> bool:
        """Return True when a positive number of images to keep is configured."""
        keep = self.num_images_to_keep
        if keep is None:
            return False
        return keep > 0
25
+
26
+
27
class BaseMessageManager:
    """Base class for message preparation and management.

    Applies two optional in-place transformations to a message list: dropping
    the oldest tool-result images beyond a configured limit, and marking the
    most recent user turns with ephemeral cache-control entries.
    """

    def __init__(self, image_retention_config: Optional["ImageRetentionConfig"] = None):
        """Initialize the message manager.

        Args:
            image_retention_config: Configuration for image retention; a default
                ``ImageRetentionConfig`` is created when omitted.

        Raises:
            ValueError: If ``min_removal_threshold`` is smaller than 1.
        """
        self.image_retention_config = image_retention_config or ImageRetentionConfig()
        if self.image_retention_config.min_removal_threshold < 1:
            raise ValueError("min_removal_threshold must be at least 1")

    def prepare_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Prepare messages by applying image retention and caching as configured.

        The list is modified in place and the same list object is returned.

        Args:
            messages: List of messages to prepare

        Returns:
            Prepared messages
        """
        if self.image_retention_config.should_retain_images():
            self._filter_images(messages)
        if self.image_retention_config.enable_caching:
            self._inject_caching(messages)
        return messages

    def _filter_images(self, messages: List[Dict[str, Any]]) -> None:
        """Filter messages to retain only the specified number of most recent images.

        Only ``tool_result`` blocks whose ``content`` is a list are considered;
        results with string, ``None`` or missing content carry no removable
        images and are left untouched.

        Args:
            messages: Messages to filter
        """

        def is_image(block: Any) -> bool:
            return isinstance(block, dict) and block.get("type") == "image"

        # Find all tool result blocks that can contain images. Restricting to
        # list content keeps the image count consistent with what the removal
        # loop below is able to edit (and avoids iterating over None).
        tool_results = [
            item
            for message in messages
            for item in (message["content"] if isinstance(message["content"], list) else [])
            if isinstance(item, dict)
            and item.get("type") == "tool_result"
            and isinstance(item.get("content"), list)
        ]

        total_images = sum(
            1 for result in tool_results for block in result["content"] if is_image(block)
        )

        # Calculate how many images to remove, rounded down to a multiple of
        # the removal threshold.
        config = self.image_retention_config
        images_to_remove = total_images - (config.num_images_to_keep or 0)
        images_to_remove -= images_to_remove % config.min_removal_threshold
        if images_to_remove <= 0:
            return

        # Remove oldest images first (messages are walked in list order).
        for result in tool_results:
            kept = []
            for block in result["content"]:
                if images_to_remove > 0 and is_image(block):
                    images_to_remove -= 1
                    continue
                kept.append(block)
            result["content"] = kept

    def _inject_caching(self, messages: List[Dict[str, Any]]) -> None:
        """Inject caching control for recent message turns.

        Walking from the newest message backwards, the last block of each of
        the three most recent user messages with list content is marked with
        an ephemeral ``cache_control`` entry. The next older such message has
        its marker removed, and the walk stops there. User messages whose
        content list is empty, or whose last block is not a dict, cannot carry
        a marker and are skipped without counting as a turn.

        Args:
            messages: Messages to inject caching into
        """
        # Default to caching last 3 turns
        turns_to_cache = 3
        for message in reversed(messages):
            content = message["content"]
            if message["role"] != "user" or not isinstance(content, list):
                continue
            if not content or not isinstance(content[-1], dict):
                continue
            last_block = content[-1]
            if turns_to_cache:
                turns_to_cache -= 1
                last_block["cache_control"] = {"type": "ephemeral"}
            else:
                last_block.pop("cache_control", None)
                break
109
+
110
+
111
def create_user_message(text: str) -> Dict[str, str]:
    """Build a message dictionary attributed to the user.

    Args:
        text: The message text

    Returns:
        Dictionary with ``role`` set to ``"user"`` and ``content`` set to ``text``
    """
    message: Dict[str, str] = {"role": "user"}
    message["content"] = text
    return message
124
+
125
+
126
def create_assistant_message(text: str) -> Dict[str, str]:
    """Build a message dictionary attributed to the assistant.

    Args:
        text: The message text

    Returns:
        Dictionary with ``role`` set to ``"assistant"`` and ``content`` set to ``text``
    """
    message: Dict[str, str] = {"role": "assistant"}
    message["content"] = text
    return message
139
+
140
+
141
def create_system_message(text: str) -> Dict[str, str]:
    """Build a message dictionary carrying system instructions.

    Args:
        text: The message text

    Returns:
        Dictionary with ``role`` set to ``"system"`` and ``content`` set to ``text``
    """
    message: Dict[str, str] = {"role": "system"}
    message["content"] = text
    return message
154
+
155
+
156
def create_image_message(
    image_base64: Optional[str] = None,
    image_path: Optional[str] = None,
    image_obj: Optional["Image.Image"] = None,
) -> Dict[str, Union[str, List[Dict[str, Any]]]]:
    """Create a user message containing a single image.

    The sources are tried in order: ``image_base64`` is used as given,
    otherwise the file at ``image_path`` is read and encoded, otherwise
    ``image_obj`` is serialized as PNG and encoded.

    Args:
        image_base64: Base64 encoded image, either bare or already wrapped in a
            ``data:`` URL
        image_path: Path to image file
        image_obj: PIL Image object

    Returns:
        Message dictionary with a one-element content list holding the image
        as a data URL

    Raises:
        ValueError: If no image source is provided
    """
    if not any([image_base64, image_path, image_obj]):
        raise ValueError("Must provide one of image_base64, image_path, or image_obj")

    # Convert to base64 if needed
    if not image_base64:
        if image_path:
            with open(image_path, "rb") as f:
                image_base64 = base64.b64encode(f.read()).decode("utf-8")
        elif image_obj:
            buffer = BytesIO()
            image_obj.save(buffer, format="PNG")
            image_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")

    # Input that is already a data URL must not be prefixed a second time,
    # otherwise the result is a malformed "data:image/png;base64,data:..." URL.
    if image_base64.startswith("data:"):
        url = image_base64
    else:
        url = f"data:image/png;base64,{image_base64}"

    return {
        "role": "user",
        "content": [{"type": "image", "image_url": {"url": url}}],
    }
193
+
194
+
195
def create_screen_message(
    parsed_screen: Dict[str, Any],
    include_raw: bool = False,
) -> Dict[str, Union[str, List[Dict[str, Any]]]]:
    """Create a user message with screen information.

    Args:
        parsed_screen: Dictionary containing parsed screen info; ``width`` and
            ``height`` are required, ``screenshot_base64`` is optional
        include_raw: Whether to include raw screenshot base64

    Returns:
        Message dictionary. When ``include_raw`` is set and a non-empty
        screenshot is available, the content is a list with an image block and
        a text block; otherwise the content is the dimensions text alone.
    """
    dimensions = f"Screen dimensions: {parsed_screen['width']}x{parsed_screen['height']}"

    # An empty screenshot string (e.g. after a failed capture) would produce an
    # image block with no data, so it is treated the same as a missing one.
    screenshot = parsed_screen.get("screenshot_base64") if include_raw else None
    if not screenshot:
        # Create text-only message with screen info
        return {
            "role": "user",
            "content": dimensions,
        }

    # Do not prefix a value that is already a data URL a second time.
    if screenshot.startswith("data:"):
        url = screenshot
    else:
        url = f"data:image/png;base64,{screenshot}"

    # Create content list with both image and text
    return {
        "role": "user",
        "content": [
            {"type": "image", "image_url": {"url": url}},
            {"type": "text", "text": dimensions},
        ],
    }
@@ -0,0 +1,21 @@
1
+ """Core tools package."""
2
+
3
+ from .base import BaseTool, ToolResult, ToolError, ToolFailure, CLIResult
4
+ from .bash import BaseBashTool
5
+ from .collection import ToolCollection
6
+ from .computer import BaseComputerTool
7
+ from .edit import BaseEditTool
8
+ from .manager import BaseToolManager
9
+
10
+ __all__ = [
11
+ "BaseTool",
12
+ "ToolResult",
13
+ "ToolError",
14
+ "ToolFailure",
15
+ "CLIResult",
16
+ "BaseBashTool",
17
+ "BaseComputerTool",
18
+ "BaseEditTool",
19
+ "ToolCollection",
20
+ "BaseToolManager",
21
+ ]
@@ -0,0 +1,74 @@
1
+ """Abstract base classes for tools that can be used with any provider."""
2
+
3
+ from abc import ABCMeta, abstractmethod
4
+ from dataclasses import dataclass, fields, replace
5
+ from typing import Any, Dict
6
+
7
+
8
class BaseTool(metaclass=ABCMeta):
    """Provider-agnostic tool interface.

    Concrete tools set ``name`` and implement both the async call operator
    and ``to_params``.
    """

    name: str

    @abstractmethod
    async def __call__(self, **kwargs) -> Any:
        """Run the tool using the supplied keyword arguments."""

    @abstractmethod
    def to_params(self) -> Dict[str, Any]:
        """Describe this tool as API parameters for the LLM provider.

        Returns:
            Dictionary with tool parameters specific to the LLM provider
        """
        raise NotImplementedError
26
+
27
+
28
+ @dataclass(kw_only=True, frozen=True)
29
+ class ToolResult:
30
+ """Represents the result of a tool execution."""
31
+
32
+ output: str | None = None
33
+ error: str | None = None
34
+ base64_image: str | None = None
35
+ system: str | None = None
36
+ content: list[dict] | None = None
37
+
38
+ def __bool__(self):
39
+ return any(getattr(self, field.name) for field in fields(self))
40
+
41
+ def __add__(self, other: "ToolResult"):
42
+ def combine_fields(field: str | None, other_field: str | None, concatenate: bool = True):
43
+ if field and other_field:
44
+ if concatenate:
45
+ return field + other_field
46
+ raise ValueError("Cannot combine tool results")
47
+ return field or other_field
48
+
49
+ return ToolResult(
50
+ output=combine_fields(self.output, other.output),
51
+ error=combine_fields(self.error, other.error),
52
+ base64_image=combine_fields(self.base64_image, other.base64_image, False),
53
+ system=combine_fields(self.system, other.system),
54
+ content=self.content or other.content, # Use first non-None content
55
+ )
56
+
57
+ def replace(self, **kwargs):
58
+ """Returns a new ToolResult with the given fields replaced."""
59
+ return replace(self, **kwargs)
60
+
61
+
62
class CLIResult(ToolResult):
    """ToolResult variant intended to be rendered as command-line output."""
64
+
65
+
66
class ToolFailure(ToolResult):
    """ToolResult variant that signals the tool run failed."""
68
+
69
+
70
class ToolError(Exception):
    """Raised when a tool encounters an error.

    The message is available both as ``self.message`` and through the normal
    exception machinery (``str(exc)`` and ``exc.args``).
    """

    def __init__(self, message):
        # Forward the message to Exception so str(exc) and exc.args are
        # populated even when the message is passed by keyword.
        super().__init__(message)
        self.message = message
@@ -0,0 +1,52 @@
1
+ """Abstract base bash/shell tool implementation."""
2
+
3
+ import asyncio
4
+ import logging
5
+ from abc import abstractmethod
6
+ from typing import Any, Dict, Tuple
7
+
8
+ from computer.computer import Computer
9
+
10
+ from .base import BaseTool, ToolResult
11
+
12
+
13
class BaseBashTool(BaseTool):
    """Base class for bash/shell command execution tools across different providers.

    Provides ``run_command`` as a shared helper; subclasses implement the call
    operator that turns provider arguments into a ``ToolResult``.
    """

    name = "bash"
    logger = logging.getLogger(__name__)
    computer: Computer

    def __init__(self, computer: Computer):
        """Initialize the BashTool.

        Args:
            computer: Computer instance, may be used for related operations
        """
        self.computer = computer

    async def run_command(self, command: str) -> Tuple[int, str, str]:
        """Run a shell command and return exit code, stdout, and stderr.

        The command is started with ``asyncio.create_subprocess_shell``;
        ``self.computer`` is not consulted by this helper.

        Args:
            command: Shell command to execute

        Returns:
            Tuple containing (exit_code, stdout, stderr). If the process cannot
            be started or awaited, ``(1, "", <error text>)`` is returned
            instead of raising.
        """
        # NOTE(review): the command string is handed to the shell verbatim;
        # callers are responsible for only passing trusted input.
        try:
            process = await asyncio.create_subprocess_shell(
                command,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
            )
            stdout, stderr = await process.communicate()
        except Exception as e:
            self.logger.error(f"Error running command: {str(e)}")
            return 1, "", str(e)

        # Decode leniently: output that is not valid UTF-8 must not turn a
        # finished command into a reported failure with its output discarded.
        return (
            process.returncode or 0,
            stdout.decode(errors="replace"),
            stderr.decode(errors="replace"),
        )

    @abstractmethod
    async def __call__(self, **kwargs) -> ToolResult:
        """Execute the tool with the provided arguments."""
        raise NotImplementedError
@@ -0,0 +1,46 @@
1
+ """Collection classes for managing multiple tools."""
2
+
3
+ from typing import Any, Dict, List, Type
4
+
5
+ from .base import (
6
+ BaseTool,
7
+ ToolError,
8
+ ToolFailure,
9
+ ToolResult,
10
+ )
11
+
12
+
13
class ToolCollection:
    """Groups tools and dispatches calls to them by name, for any provider."""

    def __init__(self, *tools: BaseTool):
        lookup = {}
        for tool in tools:
            # A later tool with the same name replaces an earlier one.
            lookup[tool.name] = tool
        self.tools = tools
        self.tool_map = lookup

    def to_params(self) -> List[Dict[str, Any]]:
        """Collect the provider-specific parameters of every tool.

        Returns:
            List of dictionaries with tool parameters
        """
        params = []
        for tool in self.tools:
            params.append(tool.to_params())
        return params

    async def run(self, *, name: str, tool_input: Dict[str, Any]) -> ToolResult:
        """Look up a tool by name and execute it.

        Errors are reported as ``ToolFailure`` results rather than raised.

        Args:
            name: Name of the tool to run
            tool_input: Input parameters for the tool

        Returns:
            Result of the tool execution
        """
        selected = self.tool_map.get(name)
        if selected:
            try:
                return await selected(**tool_input)
            except ToolError as exc:
                return ToolFailure(error=exc.message)
            except Exception as exc:
                return ToolFailure(error=f"Unexpected error in tool {name}: {str(exc)}")
        return ToolFailure(error=f"Tool {name} is invalid")