cua-agent 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of cua-agent might be problematic. Click here for more details.

Files changed (65) hide show
  1. agent/README.md +63 -0
  2. agent/__init__.py +10 -0
  3. agent/core/README.md +101 -0
  4. agent/core/__init__.py +34 -0
  5. agent/core/agent.py +284 -0
  6. agent/core/base_agent.py +164 -0
  7. agent/core/callbacks.py +147 -0
  8. agent/core/computer_agent.py +69 -0
  9. agent/core/experiment.py +222 -0
  10. agent/core/factory.py +102 -0
  11. agent/core/loop.py +244 -0
  12. agent/core/messages.py +230 -0
  13. agent/core/tools/__init__.py +21 -0
  14. agent/core/tools/base.py +74 -0
  15. agent/core/tools/bash.py +52 -0
  16. agent/core/tools/collection.py +46 -0
  17. agent/core/tools/computer.py +113 -0
  18. agent/core/tools/edit.py +67 -0
  19. agent/core/tools/manager.py +56 -0
  20. agent/providers/__init__.py +4 -0
  21. agent/providers/anthropic/__init__.py +6 -0
  22. agent/providers/anthropic/api/client.py +222 -0
  23. agent/providers/anthropic/api/logging.py +150 -0
  24. agent/providers/anthropic/callbacks/manager.py +55 -0
  25. agent/providers/anthropic/loop.py +521 -0
  26. agent/providers/anthropic/messages/manager.py +110 -0
  27. agent/providers/anthropic/prompts.py +20 -0
  28. agent/providers/anthropic/tools/__init__.py +33 -0
  29. agent/providers/anthropic/tools/base.py +88 -0
  30. agent/providers/anthropic/tools/bash.py +163 -0
  31. agent/providers/anthropic/tools/collection.py +34 -0
  32. agent/providers/anthropic/tools/computer.py +550 -0
  33. agent/providers/anthropic/tools/edit.py +326 -0
  34. agent/providers/anthropic/tools/manager.py +54 -0
  35. agent/providers/anthropic/tools/run.py +42 -0
  36. agent/providers/anthropic/types.py +16 -0
  37. agent/providers/omni/__init__.py +27 -0
  38. agent/providers/omni/callbacks.py +78 -0
  39. agent/providers/omni/clients/anthropic.py +99 -0
  40. agent/providers/omni/clients/base.py +44 -0
  41. agent/providers/omni/clients/groq.py +101 -0
  42. agent/providers/omni/clients/openai.py +159 -0
  43. agent/providers/omni/clients/utils.py +25 -0
  44. agent/providers/omni/experiment.py +273 -0
  45. agent/providers/omni/image_utils.py +106 -0
  46. agent/providers/omni/loop.py +961 -0
  47. agent/providers/omni/messages.py +168 -0
  48. agent/providers/omni/parser.py +252 -0
  49. agent/providers/omni/prompts.py +78 -0
  50. agent/providers/omni/tool_manager.py +91 -0
  51. agent/providers/omni/tools/__init__.py +13 -0
  52. agent/providers/omni/tools/bash.py +69 -0
  53. agent/providers/omni/tools/computer.py +216 -0
  54. agent/providers/omni/tools/manager.py +83 -0
  55. agent/providers/omni/types.py +30 -0
  56. agent/providers/omni/utils.py +155 -0
  57. agent/providers/omni/visualization.py +130 -0
  58. agent/types/__init__.py +26 -0
  59. agent/types/base.py +52 -0
  60. agent/types/messages.py +36 -0
  61. agent/types/tools.py +32 -0
  62. cua_agent-0.1.0.dist-info/METADATA +44 -0
  63. cua_agent-0.1.0.dist-info/RECORD +65 -0
  64. cua_agent-0.1.0.dist-info/WHEEL +4 -0
  65. cua_agent-0.1.0.dist-info/entry_points.txt +4 -0
agent/core/callbacks.py ADDED
@@ -0,0 +1,147 @@
1
+ """Callback handlers for agent."""
2
+
3
+ import json
4
+ import logging
5
+ from abc import ABC, abstractmethod
6
+ from datetime import datetime
7
+ from typing import Any, Dict, List, Optional, Protocol
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
class ContentCallback(Protocol):
    """Structural type for a callable that receives a content update.

    Any callable that takes one ``content`` dictionary (string keys,
    arbitrary values) and returns nothing satisfies this protocol.
    """

    def __call__(self, content: Dict[str, Any]) -> None:
        ...
14
+
15
class ToolCallback(Protocol):
    """Structural type for a callable that receives a tool result.

    Any callable that takes a ``result`` value together with the string
    ``tool_id`` it belongs to, and returns nothing, satisfies this protocol.
    """

    def __call__(self, result: Any, tool_id: str) -> None:
        ...
18
+
19
class APICallback(Protocol):
    """Structural type for a callable that observes an API interaction.

    Any callable that takes a ``request``, a ``response`` and an optional
    ``error`` (defaulting to ``None``), and returns nothing, satisfies this
    protocol.
    """

    def __call__(
        self,
        request: Any,
        response: Any,
        error: Optional[Exception] = None,
    ) -> None:
        ...
22
+
23
class BaseCallbackManager(ABC):
    """Abstract foundation for callback managers.

    Keeps the three callbacks handed to the constructor and declares the
    hooks every concrete manager has to implement.
    """

    def __init__(
        self,
        content_callback: ContentCallback,
        tool_callback: ToolCallback,
        api_callback: APICallback,
    ):
        """Store the callbacks this manager works with.

        Args:
            content_callback: Invoked for content updates.
            tool_callback: Invoked for tool execution results.
            api_callback: Invoked for API interactions.
        """
        (
            self.content_callback,
            self.tool_callback,
            self.api_callback,
        ) = (content_callback, tool_callback, api_callback)

    @abstractmethod
    def on_content(self, content: Any) -> None:
        """Hook for a content update; concrete managers must override it."""
        raise NotImplementedError()

    @abstractmethod
    def on_tool_result(self, result: Any, tool_id: str) -> None:
        """Hook for a tool execution result; concrete managers must override it."""
        raise NotImplementedError()

    @abstractmethod
    def on_api_interaction(
        self, request: Any, response: Any, error: Optional[Exception] = None
    ) -> None:
        """Hook for an API interaction; concrete managers must override it."""
        raise NotImplementedError()
62
+
63
+
64
class CallbackManager:
    """Fans lifecycle events out to every registered callback handler."""

    def __init__(self, handlers: Optional[List["CallbackHandler"]] = None):
        """Create the manager, optionally seeded with handlers.

        Args:
            handlers: Handlers to start with. A non-empty list is kept by
                reference; ``None`` or an empty list yields a fresh list.
        """
        self.handlers = handlers if handlers else []

    def add_handler(self, handler: "CallbackHandler") -> None:
        """Register one more handler at the end of the list.

        Args:
            handler: Callback handler to add
        """
        self.handlers.append(handler)

    async def _broadcast(self, hook: str, *args, **kwargs) -> None:
        """Await the named hook on each handler, one after another, in order."""
        for registered in self.handlers:
            await getattr(registered, hook)(*args, **kwargs)

    async def on_action_start(self, action: str, **kwargs) -> None:
        """Notify every handler that an action is starting.

        Args:
            action: Action name
            **kwargs: Additional data passed through to the handlers
        """
        await self._broadcast("on_action_start", action, **kwargs)

    async def on_action_end(self, action: str, success: bool, **kwargs) -> None:
        """Notify every handler that an action has finished.

        Args:
            action: Action name
            success: Whether the action was successful
            **kwargs: Additional data passed through to the handlers
        """
        await self._broadcast("on_action_end", action, success, **kwargs)

    async def on_error(self, error: Exception, **kwargs) -> None:
        """Notify every handler that an error occurred.

        Args:
            error: Exception that occurred
            **kwargs: Additional data passed through to the handlers
        """
        await self._broadcast("on_error", error, **kwargs)
113
+
114
+
115
class CallbackHandler(ABC):
    """Abstract interface for objects that react to agent lifecycle events.

    All three hooks are abstract coroutines, so a subclass has to implement
    each of them before it can be instantiated.
    """

    @abstractmethod
    async def on_action_start(self, action: str, **kwargs) -> None:
        """React to an action that is about to start.

        Args:
            action: Action name
            **kwargs: Additional data
        """

    @abstractmethod
    async def on_action_end(self, action: str, success: bool, **kwargs) -> None:
        """React to an action that has finished.

        Args:
            action: Action name
            success: Whether the action was successful
            **kwargs: Additional data
        """

    @abstractmethod
    async def on_error(self, error: Exception, **kwargs) -> None:
        """React to an error.

        Args:
            error: Exception that occurred
            **kwargs: Additional data
        """
agent/core/computer_agent.py ADDED
@@ -0,0 +1,69 @@
1
+ """Main entry point for computer agents."""
2
+
3
+ import logging
4
+ from typing import Any, AsyncGenerator, Dict, Optional
5
+
6
+ from computer import Computer
7
+ from ..types.base import Provider
8
+ from .factory import AgentFactory
9
+
10
+ logging.basicConfig(level=logging.INFO)
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
class ComputerAgent:
    """Front door for running natural-language tasks through a provider-specific agent.

    The provider-specific agent is built by ``AgentFactory.create`` when the
    object is constructed. The object can be used as an async context
    manager, or ``run`` can be called directly.
    """

    def __init__(self, provider: Provider, computer: Optional[Computer] = None, **kwargs):
        """Build the wrapper and its provider-specific agent.

        Args:
            provider: The AI provider to use (e.g., Provider.ANTHROPIC)
            computer: Optional Computer instance, forwarded unchanged to the factory.
            **kwargs: Additional provider-specific arguments, forwarded to the factory
        """
        self.provider = provider
        self._computer = computer
        self._kwargs = kwargs
        # Start with an empty agent slot before asking the factory to fill it.
        self._agent = None
        self._initialized, self._in_context = False, False

        self._agent = AgentFactory.create(provider=provider, computer=computer, **kwargs)

    async def __aenter__(self):
        """Mark the agent as context-managed, initialize it, and return it."""
        self._in_context = True
        await self.initialize()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Clear the context-managed flag; exceptions are not suppressed."""
        self._in_context = False

    async def initialize(self) -> None:
        """Perform one-time setup; subsequent calls do nothing."""
        if self._initialized:
            return

        # Outside a context manager, a computer supplied by the caller is started here.
        if not self._in_context and self._computer:
            await self._computer.run()

        self._initialized = True

    async def run(self, task: str) -> AsyncGenerator[Dict[str, Any], None]:
        """Run ``task`` and yield each result produced by the underlying agent.

        Initializes the agent first when that has not happened yet. If no
        underlying agent exists, a single error dictionary is yielded.
        """
        if not self._initialized:
            await self.initialize()

        agent = self._agent
        if agent is None:
            logger.error("Agent not initialized properly")
            yield {"error": "Agent not initialized properly"}
            return

        async for item in agent.run(task):
            yield item

    @property
    def computer(self) -> Optional[Computer]:
        """The underlying agent's computer, or ``None`` when there is no agent."""
        if self._agent:
            return self._agent.computer
        return None
agent/core/experiment.py ADDED
@@ -0,0 +1,222 @@
1
+ """Core experiment management for agents."""
2
+
3
+ import os
4
+ import logging
5
+ import base64
6
+ from io import BytesIO
7
+ from datetime import datetime
8
+ from typing import Any, Dict, List, Optional
9
+ from PIL import Image
10
+ import json
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
class ExperimentManager:
    """Manages experiment directories and logging for the agent.

    When a ``base_dir`` is given, each manager creates a timestamped run
    directory beneath it and numbered ``turn_NNN`` directories inside that.
    Screenshots, action visualizations and API call logs are written into
    the current turn directory. Without a ``base_dir`` no directories are
    created and the save/log methods do nothing.
    """

    def __init__(
        self,
        base_dir: Optional[str] = None,
        only_n_most_recent_images: Optional[int] = None,
    ):
        """Initialize the experiment manager.

        Args:
            base_dir: Base directory for saving experiment data
            only_n_most_recent_images: Maximum number of recent screenshots to include
                in API requests. This class only stores the value.
        """
        self.base_dir = base_dir
        self.only_n_most_recent_images = only_n_most_recent_images
        self.run_dir = None
        self.current_turn_dir = None
        self.turn_count = 0
        self.screenshot_count = 0
        # Paths of every screenshot and visualization saved so far
        self.screenshot_paths = []

        # Set up experiment directories if base_dir is provided
        if self.base_dir:
            self.setup_experiment_dirs()

    def setup_experiment_dirs(self) -> None:
        """Create the base directory, a timestamped run directory and the first turn directory."""
        if not self.base_dir:
            return

        # Create base experiments directory if it doesn't exist
        os.makedirs(self.base_dir, exist_ok=True)

        # Create timestamped run directory
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.run_dir = os.path.join(self.base_dir, timestamp)
        os.makedirs(self.run_dir, exist_ok=True)
        logger.info(f"Created run directory: {self.run_dir}")

        # Create first turn directory
        self.create_turn_dir()

    def create_turn_dir(self) -> None:
        """Advance the turn counter and create the matching ``turn_NNN`` directory."""
        if not self.run_dir:
            logger.warning("Cannot create turn directory: run_dir not set")
            return

        # Increment turn counter
        self.turn_count += 1

        # Create turn directory with padded number
        turn_name = f"turn_{self.turn_count:03d}"
        self.current_turn_dir = os.path.join(self.run_dir, turn_name)
        os.makedirs(self.current_turn_dir, exist_ok=True)
        logger.info(f"Created turn directory: {self.current_turn_dir}")

    def sanitize_log_data(self, data: Any) -> Any:
        """Sanitize log data by replacing large binary data with placeholders.

        Dictionaries and lists are copied recursively. A string is replaced
        only when it is longer than 1000 characters and contains the text
        "base64" (case-insensitive); everything else is returned unchanged.

        NOTE(review): a raw base64 payload that does not contain the literal
        text "base64" is not replaced by this check; confirm that is intended.

        Args:
            data: Data to sanitize

        Returns:
            Sanitized copy of the data
        """
        if isinstance(data, dict):
            return {k: self.sanitize_log_data(v) for k, v in data.items()}
        if isinstance(data, list):
            return [self.sanitize_log_data(item) for item in data]
        if isinstance(data, str) and len(data) > 1000 and "base64" in data.lower():
            return f"[BASE64_DATA_LENGTH_{len(data)}]"
        return data

    def save_screenshot(self, img_base64: str, action_type: str = "") -> Optional[str]:
        """Save a screenshot to the experiment directory.

        Args:
            img_base64: Base64 encoded screenshot
            action_type: Type of action that triggered the screenshot

        Returns:
            Path to the saved file, or None when there is no current turn
            directory or saving failed (the failure is logged).
        """
        if not self.current_turn_dir:
            return None

        try:
            # Increment screenshot counter
            self.screenshot_count += 1

            # Create a descriptive filename
            timestamp = int(datetime.now().timestamp() * 1000)
            action_suffix = f"_{action_type}" if action_type else ""
            filename = f"screenshot_{self.screenshot_count:03d}{action_suffix}_{timestamp}.png"

            # Save directly to the turn directory
            filepath = os.path.join(self.current_turn_dir, filename)

            # Save the screenshot
            img_data = base64.b64decode(img_base64)
            with open(filepath, "wb") as f:
                f.write(img_data)

            # Keep track of the file path
            self.screenshot_paths.append(filepath)

            return filepath
        except Exception as e:
            # Best effort: a failed screenshot must not interrupt the caller.
            logger.error(f"Error saving screenshot: {str(e)}")
            return None

    def save_action_visualization(
        self, img: "Image.Image", action_name: str, details: str = ""
    ) -> str:
        """Save a visualization of an action.

        Args:
            img: Image to save
            action_name: Name of the action
            details: Additional details about the action

        Returns:
            Path to the saved image, or an empty string when there is no
            current turn directory or saving failed (the failure is logged).
        """
        if not self.current_turn_dir:
            return ""

        try:
            # Create a descriptive filename
            timestamp = int(datetime.now().timestamp() * 1000)
            details_suffix = f"_{details}" if details else ""
            filename = f"vis_{action_name}{details_suffix}_{timestamp}.png"

            # Save directly to the turn directory
            filepath = os.path.join(self.current_turn_dir, filename)

            # Save the image
            img.save(filepath)

            # Keep track of the file path
            self.screenshot_paths.append(filepath)

            return filepath
        except Exception as e:
            # Best effort: a failed visualization must not interrupt the caller.
            logger.error(f"Error saving action visualization: {str(e)}")
            return ""

    def log_api_call(
        self,
        call_type: str,
        request: Any,
        provider: str = "unknown",
        model: str = "unknown",
        response: Any = None,
        error: Optional[Exception] = None,
    ) -> None:
        """Log API call details to a JSON file in the current turn directory.

        Request and response are passed through ``sanitize_log_data`` first.
        Failures while writing are logged and not raised.

        Args:
            call_type: Type of API call (request, response, error)
            request: Request data
            provider: API provider name
            model: Model name
            response: Response data (for response logs)
            error: Error information (for error logs)
        """
        if not self.current_turn_dir:
            logger.warning("Cannot log API call: current_turn_dir not set")
            return

        try:
            # Create a timestamp for the log file
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

            # The timestamp has one-second resolution, so a second call of the
            # same type within that second would reuse the name and overwrite
            # the earlier log. Keep the original name when it is free and add
            # a numeric suffix only on collision.
            stem = f"api_call_{timestamp}_{call_type}"
            filepath = os.path.join(self.current_turn_dir, f"{stem}.json")
            attempt = 0
            while os.path.exists(filepath):
                attempt += 1
                filepath = os.path.join(self.current_turn_dir, f"{stem}_{attempt}.json")

            # Sanitize data before logging
            sanitized_request = self.sanitize_log_data(request)
            sanitized_response = self.sanitize_log_data(response) if response is not None else None

            # Prepare log data
            log_data = {
                "timestamp": timestamp,
                "provider": provider,
                "model": model,
                "type": call_type,
                "request": sanitized_request,
            }

            if sanitized_response is not None:
                log_data["response"] = sanitized_response
            if error is not None:
                log_data["error"] = str(error)

            # Write to file; default=str stringifies values json cannot encode
            with open(filepath, "w") as f:
                json.dump(log_data, f, indent=2, default=str)

            logger.info(f"Logged API {call_type} to {filepath}")

        except Exception as e:
            # Best effort: logging problems must not interrupt the caller.
            logger.error(f"Error logging API call: {str(e)}")
agent/core/factory.py ADDED
@@ -0,0 +1,102 @@
1
+ """Factory for creating provider-specific agents."""
2
+
3
+ from typing import Optional, Dict, Any, List
4
+
5
+ from computer import Computer
6
+ from ..types.base import Provider
7
+ from .base_agent import BaseComputerAgent
8
+
9
+ # Import provider-specific implementations
10
+ _ANTHROPIC_AVAILABLE = False
11
+ _OPENAI_AVAILABLE = False
12
+ _OLLAMA_AVAILABLE = False
13
+ _OMNI_AVAILABLE = False
14
+
15
+ # Try importing providers
16
+ try:
17
+ import anthropic
18
+ from ..providers.anthropic.agent import AnthropicComputerAgent
19
+
20
+ _ANTHROPIC_AVAILABLE = True
21
+ except ImportError:
22
+ pass
23
+
24
+ try:
25
+ import openai
26
+
27
+ _OPENAI_AVAILABLE = True
28
+ except ImportError:
29
+ pass
30
+
31
+ try:
32
+ from ..providers.omni.agent import OmniComputerAgent
33
+
34
+ _OMNI_AVAILABLE = True
35
+ except ImportError:
36
+ pass
37
+
38
+
39
class AgentFactory:
    """Factory for creating provider-specific agent implementations."""

    @staticmethod
    def create(
        provider: Provider, computer: Optional[Computer] = None, **kwargs: Any
    ) -> BaseComputerAgent:
        """Create an agent based on the specified provider.

        Args:
            provider: The AI provider to use
            computer: Optional Computer instance; a default ``Computer()`` is
                created when omitted
            **kwargs: Additional provider-specific arguments

        Returns:
            A provider-specific agent implementation

        Raises:
            ImportError: If provider dependencies are not installed
            NotImplementedError: If the OpenAI provider is requested and its
                dependency is installed (the provider is not implemented yet)
            ValueError: If provider is not supported
        """
        # Create a Computer instance if none is provided
        if computer is None:
            computer = Computer()

        if provider == Provider.ANTHROPIC:
            if not _ANTHROPIC_AVAILABLE:
                raise ImportError(
                    "Anthropic provider requires additional dependencies. "
                    "Install them with: pip install cua-agent[anthropic]"
                )
            return AnthropicComputerAgent(max_retries=3, computer=computer, **kwargs)
        elif provider == Provider.OPENAI:
            if not _OPENAI_AVAILABLE:
                raise ImportError(
                    "OpenAI provider requires additional dependencies. "
                    "Install them with: pip install cua-agent[openai]"
                )
            raise NotImplementedError("OpenAI provider not yet implemented")
        elif provider == Provider.OLLAMA:
            # The module-level _OLLAMA_AVAILABLE flag starts as False and no
            # visible code sets it to True, so gating on it made this branch
            # always fail. Probe with the lazy import instead; a failed import
            # raises the same ImportError message as before.
            try:
                import ollama  # noqa: F401
                from ..providers.ollama.agent import OllamaComputerAgent
            except ImportError as err:
                raise ImportError(
                    "Ollama provider requires additional dependencies. "
                    "Install them with: pip install cua-agent[ollama]"
                ) from err
            return OllamaComputerAgent(max_retries=3, computer=computer, **kwargs)
        elif provider == Provider.OMNI:
            if not _OMNI_AVAILABLE:
                raise ImportError(
                    "Omni provider requires additional dependencies. "
                    "Install them with: pip install cua-agent[omni]"
                )
            return OmniComputerAgent(max_retries=3, computer=computer, **kwargs)
        else:
            raise ValueError(f"Unsupported provider: {provider}")