cua-agent 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of cua-agent might be problematic. Click here for more details.
- agent/README.md +63 -0
- agent/__init__.py +10 -0
- agent/core/README.md +101 -0
- agent/core/__init__.py +34 -0
- agent/core/agent.py +284 -0
- agent/core/base_agent.py +164 -0
- agent/core/callbacks.py +147 -0
- agent/core/computer_agent.py +69 -0
- agent/core/experiment.py +222 -0
- agent/core/factory.py +102 -0
- agent/core/loop.py +244 -0
- agent/core/messages.py +230 -0
- agent/core/tools/__init__.py +21 -0
- agent/core/tools/base.py +74 -0
- agent/core/tools/bash.py +52 -0
- agent/core/tools/collection.py +46 -0
- agent/core/tools/computer.py +113 -0
- agent/core/tools/edit.py +67 -0
- agent/core/tools/manager.py +56 -0
- agent/providers/__init__.py +4 -0
- agent/providers/anthropic/__init__.py +6 -0
- agent/providers/anthropic/api/client.py +222 -0
- agent/providers/anthropic/api/logging.py +150 -0
- agent/providers/anthropic/callbacks/manager.py +55 -0
- agent/providers/anthropic/loop.py +521 -0
- agent/providers/anthropic/messages/manager.py +110 -0
- agent/providers/anthropic/prompts.py +20 -0
- agent/providers/anthropic/tools/__init__.py +33 -0
- agent/providers/anthropic/tools/base.py +88 -0
- agent/providers/anthropic/tools/bash.py +163 -0
- agent/providers/anthropic/tools/collection.py +34 -0
- agent/providers/anthropic/tools/computer.py +550 -0
- agent/providers/anthropic/tools/edit.py +326 -0
- agent/providers/anthropic/tools/manager.py +54 -0
- agent/providers/anthropic/tools/run.py +42 -0
- agent/providers/anthropic/types.py +16 -0
- agent/providers/omni/__init__.py +27 -0
- agent/providers/omni/callbacks.py +78 -0
- agent/providers/omni/clients/anthropic.py +99 -0
- agent/providers/omni/clients/base.py +44 -0
- agent/providers/omni/clients/groq.py +101 -0
- agent/providers/omni/clients/openai.py +159 -0
- agent/providers/omni/clients/utils.py +25 -0
- agent/providers/omni/experiment.py +273 -0
- agent/providers/omni/image_utils.py +106 -0
- agent/providers/omni/loop.py +961 -0
- agent/providers/omni/messages.py +168 -0
- agent/providers/omni/parser.py +252 -0
- agent/providers/omni/prompts.py +78 -0
- agent/providers/omni/tool_manager.py +91 -0
- agent/providers/omni/tools/__init__.py +13 -0
- agent/providers/omni/tools/bash.py +69 -0
- agent/providers/omni/tools/computer.py +216 -0
- agent/providers/omni/tools/manager.py +83 -0
- agent/providers/omni/types.py +30 -0
- agent/providers/omni/utils.py +155 -0
- agent/providers/omni/visualization.py +130 -0
- agent/types/__init__.py +26 -0
- agent/types/base.py +52 -0
- agent/types/messages.py +36 -0
- agent/types/tools.py +32 -0
- cua_agent-0.1.0.dist-info/METADATA +44 -0
- cua_agent-0.1.0.dist-info/RECORD +65 -0
- cua_agent-0.1.0.dist-info/WHEEL +4 -0
- cua_agent-0.1.0.dist-info/entry_points.txt +4 -0
agent/core/callbacks.py
ADDED
|
@@ -0,0 +1,147 @@
|
|
|
1
|
+
"""Callback handlers for agent."""
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import logging
|
|
5
|
+
from abc import ABC, abstractmethod
|
|
6
|
+
from datetime import datetime
|
|
7
|
+
from typing import Any, Dict, List, Optional, Protocol
|
|
8
|
+
|
|
9
|
+
logger = logging.getLogger(__name__)
|
|
10
|
+
|
|
11
|
+
class ContentCallback(Protocol):
    """Structural type for callables that receive content updates.

    Any callable taking a single ``Dict[str, Any]`` and returning ``None``
    satisfies this protocol.
    """

    def __call__(self, content: Dict[str, Any]) -> None:
        """Receive one content update."""
|
|
14
|
+
|
|
15
|
+
class ToolCallback(Protocol):
    """Structural type for callables that receive tool execution results.

    The callable is given the result object and the id of the tool that
    produced it, and returns ``None``.
    """

    def __call__(self, result: Any, tool_id: str) -> None:
        """Receive the result of the tool identified by ``tool_id``."""
|
|
18
|
+
|
|
19
|
+
class APICallback(Protocol):
    """Structural type for callables notified about API interactions.

    The callable receives the request, the response and, optionally, the
    exception raised by the call; it returns ``None``.
    """

    def __call__(self, request: Any, response: Any, error: Optional[Exception] = None) -> None:
        """Receive one request/response pair and an optional error."""
|
|
22
|
+
|
|
23
|
+
class BaseCallbackManager(ABC):
    """Abstract holder for content, tool and API callbacks.

    The three callbacks passed to the constructor are exposed as public
    attributes; concrete subclasses must implement every ``on_*`` hook.
    """

    def __init__(
        self,
        content_callback: ContentCallback,
        tool_callback: ToolCallback,
        api_callback: APICallback,
    ):
        """Store the callbacks for later use by subclasses.

        Args:
            content_callback: Invoked with content updates
            tool_callback: Invoked with tool execution results
            api_callback: Invoked with API request/response/error details
        """
        self.content_callback, self.tool_callback, self.api_callback = (
            content_callback,
            tool_callback,
            api_callback,
        )

    @abstractmethod
    def on_content(self, content: Any) -> None:
        """Process a content update (must be overridden)."""
        raise NotImplementedError

    @abstractmethod
    def on_tool_result(self, result: Any, tool_id: str) -> None:
        """Process the result of a tool execution (must be overridden)."""
        raise NotImplementedError

    @abstractmethod
    def on_api_interaction(
        self,
        request: Any,
        response: Any,
        error: Optional[Exception] = None
    ) -> None:
        """Process one API request/response and optional error (must be overridden)."""
        raise NotImplementedError
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
class CallbackManager:
    """Manager that forwards agent lifecycle events to callback handlers.

    Handlers are awaited one after another in registration order. An
    exception raised by a handler propagates to the caller, and the remaining
    handlers are not invoked for that event.
    """

    def __init__(self, handlers: Optional[List["CallbackHandler"]] = None):
        """Initialize with optional handlers.

        Args:
            handlers: Initial callback handlers. The list is copied, so later
                ``add_handler`` calls never mutate the caller's list.
        """
        # Copy rather than ``handlers or []``. That form aliased a non-empty
        # caller list, so add_handler mutated it, but not an empty one.
        self.handlers: List["CallbackHandler"] = list(handlers) if handlers else []

    def add_handler(self, handler: "CallbackHandler") -> None:
        """Add a callback handler.

        Args:
            handler: Callback handler to add
        """
        self.handlers.append(handler)

    async def on_action_start(self, action: str, **kwargs) -> None:
        """Called when an action starts.

        Args:
            action: Action name
            **kwargs: Additional data
        """
        # Iterate a snapshot: a handler registered while this event is being
        # dispatched is not invoked for the in-flight event.
        for handler in tuple(self.handlers):
            await handler.on_action_start(action, **kwargs)

    async def on_action_end(self, action: str, success: bool, **kwargs) -> None:
        """Called when an action ends.

        Args:
            action: Action name
            success: Whether the action was successful
            **kwargs: Additional data
        """
        for handler in tuple(self.handlers):
            await handler.on_action_end(action, success, **kwargs)

    async def on_error(self, error: Exception, **kwargs) -> None:
        """Called when an error occurs.

        Args:
            error: Exception that occurred
            **kwargs: Additional data
        """
        for handler in tuple(self.handlers):
            await handler.on_error(error, **kwargs)
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
class CallbackHandler(ABC):
    """Abstract interface for observers of agent lifecycle events.

    Concrete handlers must override all three coroutine hooks; the base
    implementations do nothing and return ``None``.
    """

    @abstractmethod
    async def on_action_start(self, action: str, **kwargs) -> None:
        """Hook run when an action begins.

        Args:
            action: Name of the action
            **kwargs: Extra event data
        """

    @abstractmethod
    async def on_action_end(self, action: str, success: bool, **kwargs) -> None:
        """Hook run when an action finishes.

        Args:
            action: Name of the action
            success: True if the action succeeded
            **kwargs: Extra event data
        """

    @abstractmethod
    async def on_error(self, error: Exception, **kwargs) -> None:
        """Hook run when an error is reported.

        Args:
            error: The exception being reported
            **kwargs: Extra event data
        """
|
|
agent/core/computer_agent.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
"""Main entry point for computer agents."""
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
from typing import Any, AsyncGenerator, Dict, Optional
|
|
5
|
+
|
|
6
|
+
from computer import Computer
|
|
7
|
+
from ..types.base import Provider
|
|
8
|
+
from .factory import AgentFactory
|
|
9
|
+
|
|
10
|
+
# NOTE(review): calling basicConfig at import time configures the *root*
# logger for any application that imports this module. Per the logging docs,
# libraries normally leave this to the application. basicConfig is also a
# no-op if the root logger already has handlers. Kept as-is to preserve
# current behavior.
logging.basicConfig(level=logging.INFO)
# Module-level logger used by ComputerAgent.
logger = logging.getLogger(__name__)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class ComputerAgent:
    """High-level agent that carries out natural-language tasks on a computer.

    Construction delegates to ``AgentFactory.create`` to build the
    provider-specific agent, and ``run`` streams that agent's results. The
    object can be used directly or as an ``async with`` context manager.
    """

    def __init__(self, provider: Provider, computer: Optional[Computer] = None, **kwargs):
        """Build the provider-specific agent.

        Args:
            provider: The AI provider to use (e.g., Provider.ANTHROPIC)
            computer: Optional Computer instance, forwarded to the factory.
                ``self._computer`` keeps exactly the value passed here.
            **kwargs: Additional provider-specific arguments, forwarded to the
                factory and also kept in ``self._kwargs``.
        """
        self.provider = provider
        self._computer = computer
        self._kwargs = kwargs
        self._agent = None
        self._initialized = False
        self._in_context = False

        # Delegate construction of the concrete agent to the provider factory.
        self._agent = AgentFactory.create(provider=provider, computer=computer, **kwargs)

    async def __aenter__(self):
        """Mark the agent as context-managed, initialize it and return it."""
        self._in_context = True
        await self.initialize()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Leave the context; returns None, so exceptions are not suppressed."""
        self._in_context = False

    async def initialize(self) -> None:
        """Initialize once; later calls do nothing.

        Outside ``async with``, an explicitly supplied computer is started via
        its ``run()`` coroutine. Inside the context manager this method does
        not start it.
        """
        if self._initialized:
            return
        if not self._in_context and self._computer:
            await self._computer.run()
        self._initialized = True

    async def run(self, task: str) -> AsyncGenerator[Dict[str, Any], None]:
        """Stream results from the underlying agent for ``task``.

        Initializes lazily on first use. If no agent exists, yields a single
        ``{"error": ...}`` dict and stops.
        """
        if not self._initialized:
            await self.initialize()

        agent = self._agent
        if agent is None:
            message = "Agent not initialized properly"
            logger.error(message)
            yield {"error": message}
            return

        async for item in agent.run(task):
            yield item

    @property
    def computer(self) -> Optional[Computer]:
        """The underlying agent's computer, or None when there is no agent."""
        agent = self._agent
        if not agent:
            return None
        return agent.computer
|
agent/core/experiment.py
ADDED
|
@@ -0,0 +1,222 @@
|
|
|
1
|
+
"""Core experiment management for agents."""
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
import logging
|
|
5
|
+
import base64
|
|
6
|
+
from io import BytesIO
|
|
7
|
+
from datetime import datetime
|
|
8
|
+
from typing import Any, Dict, List, Optional
|
|
9
|
+
from PIL import Image
|
|
10
|
+
import json
|
|
11
|
+
|
|
12
|
+
logger = logging.getLogger(__name__)


class ExperimentManager:
    """Manages experiment directories and logging for the agent.

    When ``base_dir`` is given, construction creates a timestamped run
    directory (``<base_dir>/<YYYYmmdd_HHMMSS>``) containing a first
    ``turn_001`` directory. Screenshots, action visualizations and API-call
    logs are written into the current turn directory. When no turn directory
    exists, those methods write nothing.
    """

    # Standard and URL-safe base64 alphabets plus padding. A long string made
    # only of these characters is treated as an encoded binary payload even
    # when it carries no "base64" marker (e.g. the raw ``data`` field of an
    # image content block).
    _BASE64_CHARS = frozenset(
        "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/-_="
    )
    # Strings no longer than this are always logged verbatim.
    _MAX_VERBATIM_STRING_LENGTH = 1000

    def __init__(
        self,
        base_dir: Optional[str] = None,
        only_n_most_recent_images: Optional[int] = None,
    ):
        """Initialize the experiment manager.

        Args:
            base_dir: Base directory for saving experiment data. When falsy,
                no directories are created and nothing is written to disk.
            only_n_most_recent_images: Maximum number of recent screenshots to
                include in API requests (stored only; not used in this class).
        """
        self.base_dir = base_dir
        self.only_n_most_recent_images = only_n_most_recent_images
        self.run_dir: Optional[str] = None
        self.current_turn_dir: Optional[str] = None
        self.turn_count = 0
        self.screenshot_count = 0
        # Paths of every screenshot/visualization saved so far, in save order.
        self.screenshot_paths: List[str] = []

        # Set up experiment directories if base_dir is provided
        if self.base_dir:
            self.setup_experiment_dirs()

    def setup_experiment_dirs(self) -> None:
        """Create ``base_dir``, a timestamped run directory and the first turn.

        Does nothing when ``base_dir`` is not set.
        """
        if not self.base_dir:
            return

        os.makedirs(self.base_dir, exist_ok=True)

        # Second-resolution name; exist_ok lets a manager created in the same
        # second reuse the directory instead of failing.
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.run_dir = os.path.join(self.base_dir, timestamp)
        os.makedirs(self.run_dir, exist_ok=True)
        logger.info(f"Created run directory: {self.run_dir}")

        self.create_turn_dir()

    def create_turn_dir(self) -> None:
        """Advance the turn counter and create ``turn_NNN`` under the run dir.

        Logs a warning and does nothing when ``run_dir`` is not set.
        """
        if not self.run_dir:
            logger.warning("Cannot create turn directory: run_dir not set")
            return

        self.turn_count += 1
        turn_name = f"turn_{self.turn_count:03d}"
        self.current_turn_dir = os.path.join(self.run_dir, turn_name)
        os.makedirs(self.current_turn_dir, exist_ok=True)
        logger.info(f"Created turn directory: {self.current_turn_dir}")

    def sanitize_log_data(self, data: Any) -> Any:
        """Sanitize log data by replacing large encoded payloads with placeholders.

        Dicts and lists are rebuilt recursively. A string longer than 1000
        characters is replaced by ``"[BASE64_DATA_LENGTH_<n>]"`` when it either
        contains "base64" (case-insensitive) or consists solely of base64
        alphabet characters. Every other value is returned unchanged.

        Args:
            data: Data to sanitize

        Returns:
            Sanitized copy of the data
        """
        if isinstance(data, dict):
            return {k: self.sanitize_log_data(v) for k, v in data.items()}
        if isinstance(data, list):
            return [self.sanitize_log_data(item) for item in data]
        if (
            isinstance(data, str)
            and len(data) > self._MAX_VERBATIM_STRING_LENGTH
            and self._looks_like_base64(data)
        ):
            return f"[BASE64_DATA_LENGTH_{len(data)}]"
        return data

    @classmethod
    def _looks_like_base64(cls, text: str) -> bool:
        """Heuristic: ``text`` mentions "base64" or uses only base64 characters."""
        return "base64" in text.lower() or cls._BASE64_CHARS.issuperset(text)

    @staticmethod
    def _unique_path(path: str) -> str:
        """Return ``path`` if it is free, else ``<stem>_<n><ext>`` for the smallest free n.

        Check-then-use (not atomic); assumes one writer per turn directory.
        """
        if not os.path.exists(path):
            return path
        stem, ext = os.path.splitext(path)
        n = 1
        while os.path.exists(f"{stem}_{n}{ext}"):
            n += 1
        return f"{stem}_{n}{ext}"

    def save_screenshot(self, img_base64: str, action_type: str = "") -> Optional[str]:
        """Save a screenshot to the experiment directory.

        Args:
            img_base64: Base64 encoded screenshot
            action_type: Type of action that triggered the screenshot

        Returns:
            Path of the written file, or ``None`` when there is no turn
            directory or saving failed (the error is logged).
        """
        if not self.current_turn_dir:
            return None

        try:
            self.screenshot_count += 1

            # The counter keeps names unique; the millisecond stamp aids ordering.
            timestamp = int(datetime.now().timestamp() * 1000)
            action_suffix = f"_{action_type}" if action_type else ""
            filename = f"screenshot_{self.screenshot_count:03d}{action_suffix}_{timestamp}.png"
            filepath = os.path.join(self.current_turn_dir, filename)

            img_data = base64.b64decode(img_base64)
            with open(filepath, "wb") as f:
                f.write(img_data)

            self.screenshot_paths.append(filepath)
            return filepath
        except Exception as e:
            logger.error(f"Error saving screenshot: {str(e)}")
            return None

    def save_action_visualization(
        self, img: "Image.Image", action_name: str, details: str = ""
    ) -> str:
        """Save a visualization of an action.

        Args:
            img: PIL image to save (saved via ``img.save(path)``)
            action_name: Name of the action
            details: Additional details about the action

        Returns:
            Path to the saved image, or ``""`` when there is no turn directory
            or saving failed (the error is logged).
        """
        if not self.current_turn_dir:
            return ""

        try:
            timestamp = int(datetime.now().timestamp() * 1000)
            details_suffix = f"_{details}" if details else ""
            filename = f"vis_{action_name}{details_suffix}_{timestamp}.png"
            # Two visualizations in the same millisecond must not overwrite
            # each other.
            filepath = self._unique_path(os.path.join(self.current_turn_dir, filename))

            img.save(filepath)

            self.screenshot_paths.append(filepath)
            return filepath
        except Exception as e:
            logger.error(f"Error saving action visualization: {str(e)}")
            return ""

    def log_api_call(
        self,
        call_type: str,
        request: Any,
        provider: str = "unknown",
        model: str = "unknown",
        response: Any = None,
        error: Optional[Exception] = None,
    ) -> None:
        """Log API call details to a JSON file in the current turn directory.

        Args:
            call_type: Type of API call (request, response, error)
            request: Request data
            provider: API provider name
            model: Model name
            response: Response data (for response logs)
            error: Error information (for error logs)
        """
        if not self.current_turn_dir:
            logger.warning("Cannot log API call: current_turn_dir not set")
            return

        try:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

            # The timestamp has second resolution. Without _unique_path, a
            # second call of the same type within one second would overwrite
            # the first log file.
            filename = f"api_call_{timestamp}_{call_type}.json"
            filepath = self._unique_path(os.path.join(self.current_turn_dir, filename))

            sanitized_request = self.sanitize_log_data(request)
            sanitized_response = self.sanitize_log_data(response) if response is not None else None

            log_data = {
                "timestamp": timestamp,
                "provider": provider,
                "model": model,
                "type": call_type,
                "request": sanitized_request,
            }
            if sanitized_response is not None:
                log_data["response"] = sanitized_response
            if error is not None:
                log_data["error"] = str(error)

            # default=str stringifies values json cannot encode natively.
            with open(filepath, "w", encoding="utf-8") as f:
                json.dump(log_data, f, indent=2, default=str)

            logger.info(f"Logged API {call_type} to {filepath}")
        except Exception as e:
            logger.error(f"Error logging API call: {str(e)}")
|
agent/core/factory.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
1
|
+
"""Factory for creating provider-specific agents."""
|
|
2
|
+
|
|
3
|
+
from typing import Optional, Dict, Any, List
|
|
4
|
+
|
|
5
|
+
from computer import Computer
|
|
6
|
+
from ..types.base import Provider
|
|
7
|
+
from .base_agent import BaseComputerAgent
|
|
8
|
+
|
|
9
|
+
# Import provider-specific implementations
_ANTHROPIC_AVAILABLE = False
_OPENAI_AVAILABLE = False
# Ollama is probed lazily inside AgentFactory.create; this flag is kept for
# backward compatibility and is not consulted.
_OLLAMA_AVAILABLE = False
_OMNI_AVAILABLE = False

# ImportError captured while probing each optional provider. AgentFactory.create
# chains it onto the error it raises, so the real cause is visible instead of
# being silently discarded.
_ANTHROPIC_IMPORT_ERROR: Optional[ImportError] = None
_OPENAI_IMPORT_ERROR: Optional[ImportError] = None
_OMNI_IMPORT_ERROR: Optional[ImportError] = None

# Try importing providers
try:
    import anthropic  # noqa: F401  (imported only to check availability)
    from ..providers.anthropic.agent import AnthropicComputerAgent

    _ANTHROPIC_AVAILABLE = True
except ImportError as e:
    _ANTHROPIC_IMPORT_ERROR = e

try:
    import openai  # noqa: F401  (imported only to check availability)

    _OPENAI_AVAILABLE = True
except ImportError as e:
    _OPENAI_IMPORT_ERROR = e

try:
    from ..providers.omni.agent import OmniComputerAgent

    _OMNI_AVAILABLE = True
except ImportError as e:
    _OMNI_IMPORT_ERROR = e


class AgentFactory:
    """Factory for creating provider-specific agent implementations."""

    @staticmethod
    def create(
        provider: Provider, computer: Optional[Computer] = None, **kwargs: Any
    ) -> BaseComputerAgent:
        """Create an agent based on the specified provider.

        Args:
            provider: The AI provider to use
            computer: Optional Computer instance; a default ``Computer()`` is
                created when omitted
            **kwargs: Additional provider-specific arguments

        Returns:
            A provider-specific agent implementation

        Raises:
            ImportError: If provider dependencies are not installed. When the
                module-level probe recorded an import failure, it is chained
                as ``__cause__``.
            NotImplementedError: For Provider.OPENAI when its dependencies are
                available
            ValueError: If provider is not supported
        """
        # Create a Computer instance if none is provided
        if computer is None:
            computer = Computer()

        if provider == Provider.ANTHROPIC:
            if not _ANTHROPIC_AVAILABLE:
                raise ImportError(
                    "Anthropic provider requires additional dependencies. "
                    "Install them with: pip install cua-agent[anthropic]"
                ) from _ANTHROPIC_IMPORT_ERROR
            return AnthropicComputerAgent(max_retries=3, computer=computer, **kwargs)
        elif provider == Provider.OPENAI:
            if not _OPENAI_AVAILABLE:
                raise ImportError(
                    "OpenAI provider requires additional dependencies. "
                    "Install them with: pip install cua-agent[openai]"
                ) from _OPENAI_IMPORT_ERROR
            raise NotImplementedError("OpenAI provider not yet implemented")
        elif provider == Provider.OLLAMA:
            # Only import ollama when actually creating an Ollama agent. The
            # old _OLLAMA_AVAILABLE gate was never set True, so it made this
            # import unreachable.
            try:
                import ollama  # noqa: F401  (imported only to check availability)
                from ..providers.ollama.agent import OllamaComputerAgent
            except ImportError as e:
                raise ImportError(
                    "Ollama provider requires additional dependencies. "
                    "Install them with: pip install cua-agent[ollama]"
                ) from e
            return OllamaComputerAgent(max_retries=3, computer=computer, **kwargs)
        elif provider == Provider.OMNI:
            if not _OMNI_AVAILABLE:
                raise ImportError(
                    "Omni provider requires additional dependencies. "
                    "Install them with: pip install cua-agent[omni]"
                ) from _OMNI_IMPORT_ERROR
            return OmniComputerAgent(max_retries=3, computer=computer, **kwargs)
        else:
            raise ValueError(f"Unsupported provider: {provider}")
|