cua-agent 0.1.6__py3-none-any.whl → 0.1.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of cua-agent might be problematic. Click here for more details.

Files changed (42)
  1. agent/__init__.py +3 -2
  2. agent/core/__init__.py +0 -5
  3. agent/core/computer_agent.py +21 -28
  4. agent/core/loop.py +78 -124
  5. agent/core/messages.py +279 -125
  6. agent/core/types.py +35 -0
  7. agent/core/visualization.py +197 -0
  8. agent/providers/anthropic/api/client.py +142 -1
  9. agent/providers/anthropic/api_handler.py +140 -0
  10. agent/providers/anthropic/callbacks/__init__.py +5 -0
  11. agent/providers/anthropic/loop.py +206 -220
  12. agent/providers/anthropic/response_handler.py +229 -0
  13. agent/providers/anthropic/tools/bash.py +0 -97
  14. agent/providers/anthropic/utils.py +370 -0
  15. agent/providers/omni/__init__.py +1 -20
  16. agent/providers/omni/api_handler.py +42 -0
  17. agent/providers/omni/clients/anthropic.py +4 -0
  18. agent/providers/omni/image_utils.py +0 -72
  19. agent/providers/omni/loop.py +490 -606
  20. agent/providers/omni/parser.py +58 -4
  21. agent/providers/omni/tools/__init__.py +25 -7
  22. agent/providers/omni/tools/base.py +29 -0
  23. agent/providers/omni/tools/bash.py +43 -38
  24. agent/providers/omni/tools/computer.py +144 -182
  25. agent/providers/omni/tools/manager.py +25 -45
  26. agent/providers/omni/types.py +0 -4
  27. agent/providers/omni/utils.py +224 -145
  28. {cua_agent-0.1.6.dist-info → cua_agent-0.1.17.dist-info}/METADATA +6 -36
  29. cua_agent-0.1.17.dist-info/RECORD +63 -0
  30. agent/providers/omni/callbacks.py +0 -78
  31. agent/providers/omni/clients/groq.py +0 -101
  32. agent/providers/omni/experiment.py +0 -276
  33. agent/providers/omni/messages.py +0 -171
  34. agent/providers/omni/tool_manager.py +0 -91
  35. agent/providers/omni/visualization.py +0 -130
  36. agent/types/__init__.py +0 -23
  37. agent/types/base.py +0 -41
  38. agent/types/messages.py +0 -36
  39. cua_agent-0.1.6.dist-info/RECORD +0 -64
  40. /agent/{types → core}/tools.py +0 -0
  41. {cua_agent-0.1.6.dist-info → cua_agent-0.1.17.dist-info}/WHEEL +0 -0
  42. {cua_agent-0.1.6.dist-info → cua_agent-0.1.17.dist-info}/entry_points.txt +0 -0
@@ -1,78 +0,0 @@
1
- """Omni callback manager implementation."""
2
-
3
- import logging
4
- from typing import Any, Dict, Optional, Set
5
-
6
- from ...core.callbacks import BaseCallbackManager, ContentCallback, ToolCallback, APICallback
7
- from ...types.tools import ToolResult
8
-
9
- logger = logging.getLogger(__name__)
10
-
11
- class OmniCallbackManager(BaseCallbackManager):
12
- """Callback manager for multi-provider support."""
13
-
14
- def __init__(
15
- self,
16
- content_callback: ContentCallback,
17
- tool_callback: ToolCallback,
18
- api_callback: APICallback,
19
- ):
20
- """Initialize Omni callback manager.
21
-
22
- Args:
23
- content_callback: Callback for content updates
24
- tool_callback: Callback for tool execution results
25
- api_callback: Callback for API interactions
26
- """
27
- super().__init__(
28
- content_callback=content_callback,
29
- tool_callback=tool_callback,
30
- api_callback=api_callback
31
- )
32
- self._active_tools: Set[str] = set()
33
-
34
- def on_content(self, content: Any) -> None:
35
- """Handle content updates.
36
-
37
- Args:
38
- content: Content update data
39
- """
40
- logger.debug(f"Content update: {content}")
41
- self.content_callback(content)
42
-
43
- def on_tool_result(self, result: ToolResult, tool_id: str) -> None:
44
- """Handle tool execution results.
45
-
46
- Args:
47
- result: Tool execution result
48
- tool_id: ID of the tool
49
- """
50
- logger.debug(f"Tool result for {tool_id}: {result}")
51
- self.tool_callback(result, tool_id)
52
-
53
- def on_api_interaction(
54
- self,
55
- request: Any,
56
- response: Any,
57
- error: Optional[Exception] = None
58
- ) -> None:
59
- """Handle API interactions.
60
-
61
- Args:
62
- request: API request data
63
- response: API response data
64
- error: Optional error that occurred
65
- """
66
- if error:
67
- logger.error(f"API error: {str(error)}")
68
- else:
69
- logger.debug(f"API interaction - Request: {request}, Response: {response}")
70
- self.api_callback(request, response, error)
71
-
72
- def get_active_tools(self) -> Set[str]:
73
- """Get currently active tools.
74
-
75
- Returns:
76
- Set of active tool names
77
- """
78
- return self._active_tools.copy()
@@ -1,101 +0,0 @@
1
- """Groq client implementation."""
2
-
3
- import os
4
- import logging
5
- from typing import Dict, List, Optional, Any, Tuple
6
-
7
- from groq import Groq
8
- import re
9
- from .utils import is_image_path
10
- from .base import BaseOmniClient
11
-
12
- logger = logging.getLogger(__name__)
13
-
14
-
15
- class GroqClient(BaseOmniClient):
16
- """Client for making Groq API calls."""
17
-
18
- def __init__(
19
- self,
20
- api_key: Optional[str] = None,
21
- model: str = "deepseek-r1-distill-llama-70b",
22
- max_tokens: int = 4096,
23
- temperature: float = 0.6,
24
- ):
25
- """Initialize Groq client.
26
-
27
- Args:
28
- api_key: Groq API key (if not provided, will try to get from env)
29
- model: Model name to use
30
- max_tokens: Maximum tokens to generate
31
- temperature: Temperature for sampling
32
- """
33
- super().__init__(api_key=api_key, model=model)
34
- self.api_key = api_key or os.getenv("GROQ_API_KEY")
35
- if not self.api_key:
36
- raise ValueError("No Groq API key provided")
37
-
38
- self.max_tokens = max_tokens
39
- self.temperature = temperature
40
- self.client = Groq(api_key=self.api_key)
41
- self.model: str = model # Add explicit type annotation
42
-
43
- def run_interleaved(
44
- self, messages: List[Dict[str, Any]], system: str, max_tokens: Optional[int] = None
45
- ) -> tuple[str, int]:
46
- """Run interleaved chat completion.
47
-
48
- Args:
49
- messages: List of message dicts
50
- system: System prompt
51
- max_tokens: Optional max tokens override
52
-
53
- Returns:
54
- Tuple of (response text, token usage)
55
- """
56
- # Avoid using system messages for R1
57
- final_messages = [{"role": "user", "content": system}]
58
-
59
- # Process messages
60
- if isinstance(messages, list):
61
- for item in messages:
62
- if isinstance(item, dict):
63
- # For dict items, concatenate all text content, ignoring images
64
- text_contents = []
65
- for cnt in item["content"]:
66
- if isinstance(cnt, str):
67
- if not is_image_path(cnt): # Skip image paths
68
- text_contents.append(cnt)
69
- else:
70
- text_contents.append(str(cnt))
71
-
72
- if text_contents: # Only add if there's text content
73
- message = {"role": "user", "content": " ".join(text_contents)}
74
- final_messages.append(message)
75
- else: # str
76
- message = {"role": "user", "content": item}
77
- final_messages.append(message)
78
-
79
- elif isinstance(messages, str):
80
- final_messages.append({"role": "user", "content": messages})
81
-
82
- try:
83
- completion = self.client.chat.completions.create( # type: ignore
84
- model=self.model,
85
- messages=final_messages, # type: ignore
86
- temperature=self.temperature,
87
- max_tokens=max_tokens or self.max_tokens,
88
- top_p=0.95,
89
- stream=False,
90
- )
91
-
92
- response = completion.choices[0].message.content
93
- final_answer = response.split("</think>\n")[-1] if "</think>" in response else response
94
- final_answer = final_answer.replace("<output>", "").replace("</output>", "")
95
- token_usage = completion.usage.total_tokens
96
-
97
- return final_answer, token_usage
98
-
99
- except Exception as e:
100
- logger.error(f"Error in Groq API call: {e}")
101
- raise
@@ -1,276 +0,0 @@
1
- """Experiment management for the Cua provider."""
2
-
3
- import os
4
- import logging
5
- import copy
6
- import base64
7
- from io import BytesIO
8
- from datetime import datetime
9
- from typing import Any, Dict, List, Optional
10
- from PIL import Image
11
- import json
12
- import time
13
-
14
- logger = logging.getLogger(__name__)
15
-
16
-
17
- class ExperimentManager:
18
- """Manages experiment directories and logging for the agent."""
19
-
20
- def __init__(
21
- self,
22
- base_dir: Optional[str] = None,
23
- only_n_most_recent_images: Optional[int] = None,
24
- ):
25
- """Initialize the experiment manager.
26
-
27
- Args:
28
- base_dir: Base directory for saving experiment data
29
- only_n_most_recent_images: Maximum number of recent screenshots to include in API requests
30
- """
31
- self.base_dir = base_dir
32
- self.only_n_most_recent_images = only_n_most_recent_images
33
- self.run_dir = None
34
- self.current_turn_dir = None
35
- self.turn_count = 0
36
- self.screenshot_count = 0
37
- # Track all screenshots for potential API request inclusion
38
- self.screenshot_paths = []
39
-
40
- # Set up experiment directories if base_dir is provided
41
- if self.base_dir:
42
- self.setup_experiment_dirs()
43
-
44
- def setup_experiment_dirs(self) -> None:
45
- """Setup the experiment directory structure."""
46
- if not self.base_dir:
47
- return
48
-
49
- # Create base experiments directory if it doesn't exist
50
- os.makedirs(self.base_dir, exist_ok=True)
51
-
52
- # Use the base_dir directly as the run_dir
53
- self.run_dir = self.base_dir
54
- logger.info(f"Using directory for experiment: {self.run_dir}")
55
-
56
- # Create first turn directory
57
- self.create_turn_dir()
58
-
59
- def create_turn_dir(self) -> None:
60
- """Create a new directory for the current turn."""
61
- if not self.run_dir:
62
- return
63
-
64
- self.turn_count += 1
65
- self.current_turn_dir = os.path.join(self.run_dir, f"turn_{self.turn_count:03d}")
66
- os.makedirs(self.current_turn_dir, exist_ok=True)
67
- logger.info(f"Created turn directory: {self.current_turn_dir}")
68
-
69
- def sanitize_log_data(self, data: Any) -> Any:
70
- """Sanitize data for logging by removing large base64 strings.
71
-
72
- Args:
73
- data: Data to sanitize (dict, list, or primitive)
74
-
75
- Returns:
76
- Sanitized copy of the data
77
- """
78
- if isinstance(data, dict):
79
- result = copy.deepcopy(data)
80
-
81
- # Handle nested dictionaries and lists
82
- for key, value in result.items():
83
- # Process content arrays that contain image data
84
- if key == "content" and isinstance(value, list):
85
- for i, item in enumerate(value):
86
- if isinstance(item, dict):
87
- # Handle Anthropic format
88
- if item.get("type") == "image" and isinstance(item.get("source"), dict):
89
- source = item["source"]
90
- if "data" in source and isinstance(source["data"], str):
91
- # Replace base64 data with a placeholder and length info
92
- data_len = len(source["data"])
93
- source["data"] = f"[BASE64_IMAGE_DATA_LENGTH_{data_len}]"
94
-
95
- # Handle OpenAI format
96
- elif item.get("type") == "image_url" and isinstance(
97
- item.get("image_url"), dict
98
- ):
99
- url_dict = item["image_url"]
100
- if "url" in url_dict and isinstance(url_dict["url"], str):
101
- url = url_dict["url"]
102
- if url.startswith("data:"):
103
- # Replace base64 data with placeholder
104
- data_len = len(url)
105
- url_dict["url"] = f"[BASE64_IMAGE_URL_LENGTH_{data_len}]"
106
-
107
- # Handle other nested structures recursively
108
- if isinstance(value, dict):
109
- result[key] = self.sanitize_log_data(value)
110
- elif isinstance(value, list):
111
- result[key] = [self.sanitize_log_data(item) for item in value]
112
-
113
- return result
114
- elif isinstance(data, list):
115
- return [self.sanitize_log_data(item) for item in data]
116
- else:
117
- return data
118
-
119
- def save_debug_image(self, image_data: str, filename: str) -> None:
120
- """Save a debug image to the experiment directory.
121
-
122
- Args:
123
- image_data: Base64 encoded image data
124
- filename: Filename to save the image as
125
- """
126
- # Since we no longer want to use the images/ folder, we'll skip this functionality
127
- return
128
-
129
- def save_screenshot(self, img_base64: str, action_type: str = "") -> Optional[str]:
130
- """Save a screenshot to the experiment directory.
131
-
132
- Args:
133
- img_base64: Base64 encoded screenshot
134
- action_type: Type of action that triggered the screenshot
135
-
136
- Returns:
137
- Optional[str]: Path to the saved screenshot, or None if saving failed
138
- """
139
- if not self.current_turn_dir:
140
- return None
141
-
142
- try:
143
- # Increment screenshot counter
144
- self.screenshot_count += 1
145
-
146
- # Create a descriptive filename
147
- timestamp = int(time.time() * 1000)
148
- action_suffix = f"_{action_type}" if action_type else ""
149
- filename = f"screenshot_{self.screenshot_count:03d}{action_suffix}_{timestamp}.png"
150
-
151
- # Save directly to the turn directory (no screenshots subdirectory)
152
- filepath = os.path.join(self.current_turn_dir, filename)
153
-
154
- # Save the screenshot
155
- img_data = base64.b64decode(img_base64)
156
- with open(filepath, "wb") as f:
157
- f.write(img_data)
158
-
159
- # Keep track of the file path for reference
160
- self.screenshot_paths.append(filepath)
161
-
162
- return filepath
163
- except Exception as e:
164
- logger.error(f"Error saving screenshot: {str(e)}")
165
- return None
166
-
167
- def should_save_debug_image(self) -> bool:
168
- """Determine if debug images should be saved.
169
-
170
- Returns:
171
- Boolean indicating if debug images should be saved
172
- """
173
- # We no longer need to save debug images, so always return False
174
- return False
175
-
176
- def save_action_visualization(
177
- self, img: Image.Image, action_name: str, details: str = ""
178
- ) -> str:
179
- """Save a visualization of an action.
180
-
181
- Args:
182
- img: Image to save
183
- action_name: Name of the action
184
- details: Additional details about the action
185
-
186
- Returns:
187
- Path to the saved image
188
- """
189
- if not self.current_turn_dir:
190
- return ""
191
-
192
- try:
193
- # Create a descriptive filename
194
- timestamp = int(time.time() * 1000)
195
- details_suffix = f"_{details}" if details else ""
196
- filename = f"vis_{action_name}{details_suffix}_{timestamp}.png"
197
-
198
- # Save directly to the turn directory (no visualizations subdirectory)
199
- filepath = os.path.join(self.current_turn_dir, filename)
200
-
201
- # Save the image
202
- img.save(filepath)
203
-
204
- # Keep track of the file path for cleanup
205
- self.screenshot_paths.append(filepath)
206
-
207
- return filepath
208
- except Exception as e:
209
- logger.error(f"Error saving action visualization: {str(e)}")
210
- return ""
211
-
212
- def extract_and_save_images(self, data: Any, prefix: str) -> None:
213
- """Extract and save images from response data.
214
-
215
- Args:
216
- data: Response data to extract images from
217
- prefix: Prefix for saved image filenames
218
- """
219
- # Since we no longer want to save extracted images separately,
220
- # we'll skip this functionality entirely
221
- return
222
-
223
- def log_api_call(
224
- self,
225
- call_type: str,
226
- request: Any,
227
- provider: str,
228
- model: str,
229
- response: Any = None,
230
- error: Optional[Exception] = None,
231
- ) -> None:
232
- """Log API call details to file.
233
-
234
- Args:
235
- call_type: Type of API call (e.g., 'request', 'response', 'error')
236
- request: The API request data
237
- provider: The AI provider used
238
- model: The AI model used
239
- response: Optional API response data
240
- error: Optional error information
241
- """
242
- if not self.current_turn_dir:
243
- return
244
-
245
- try:
246
- # Create a unique filename with timestamp
247
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
248
- filename = f"api_call_{timestamp}_{call_type}.json"
249
- filepath = os.path.join(self.current_turn_dir, filename)
250
-
251
- # Sanitize data to remove large base64 strings
252
- sanitized_request = self.sanitize_log_data(request)
253
- sanitized_response = self.sanitize_log_data(response) if response is not None else None
254
-
255
- # Prepare log data
256
- log_data = {
257
- "timestamp": timestamp,
258
- "provider": provider,
259
- "model": model,
260
- "type": call_type,
261
- "request": sanitized_request,
262
- }
263
-
264
- if sanitized_response is not None:
265
- log_data["response"] = sanitized_response
266
- if error is not None:
267
- log_data["error"] = str(error)
268
-
269
- # Write to file
270
- with open(filepath, "w") as f:
271
- json.dump(log_data, f, indent=2, default=str)
272
-
273
- logger.info(f"Logged API {call_type} to {filepath}")
274
-
275
- except Exception as e:
276
- logger.error(f"Error logging API call: {str(e)}")
@@ -1,171 +0,0 @@
1
- """Omni message manager implementation."""
2
-
3
- import base64
4
- from typing import Any, Dict, List, Optional
5
- from io import BytesIO
6
- from PIL import Image
7
-
8
- from ...core.messages import BaseMessageManager, ImageRetentionConfig
9
-
10
-
11
- class OmniMessageManager(BaseMessageManager):
12
- """Message manager for multi-provider support."""
13
-
14
- def __init__(self, config: Optional[ImageRetentionConfig] = None):
15
- """Initialize the message manager.
16
-
17
- Args:
18
- config: Optional configuration for image retention
19
- """
20
- super().__init__(config)
21
- self.messages: List[Dict[str, Any]] = []
22
- self.config = config
23
-
24
- def add_user_message(self, content: str, images: Optional[List[bytes]] = None) -> None:
25
- """Add a user message to the history.
26
-
27
- Args:
28
- content: Message content
29
- images: Optional list of image data
30
- """
31
- # Add images if present
32
- if images:
33
- # Initialize with proper typing for mixed content
34
- message_content: List[Dict[str, Any]] = [{"type": "text", "text": content}]
35
-
36
- # Add each image
37
- for img in images:
38
- message_content.append(
39
- {
40
- "type": "image_url",
41
- "image_url": {
42
- "url": f"data:image/png;base64,{base64.b64encode(img).decode()}"
43
- },
44
- }
45
- )
46
-
47
- message = {"role": "user", "content": message_content}
48
- else:
49
- # Simple text message
50
- message = {"role": "user", "content": content}
51
-
52
- self.messages.append(message)
53
-
54
- # Apply retention policy
55
- if self.config and self.config.num_images_to_keep:
56
- self._apply_image_retention_policy()
57
-
58
- def add_assistant_message(self, content: str) -> None:
59
- """Add an assistant message to the history.
60
-
61
- Args:
62
- content: Message content
63
- """
64
- self.messages.append({"role": "assistant", "content": content})
65
-
66
- def add_system_message(self, content: str) -> None:
67
- """Add a system message to the history.
68
-
69
- Args:
70
- content: Message content
71
- """
72
- self.messages.append({"role": "system", "content": content})
73
-
74
- def _apply_image_retention_policy(self) -> None:
75
- """Apply image retention policy to message history."""
76
- if not self.config or not self.config.num_images_to_keep:
77
- return
78
-
79
- # Count images from newest to oldest
80
- image_count = 0
81
- for message in reversed(self.messages):
82
- if message["role"] != "user":
83
- continue
84
-
85
- # Handle multimodal messages
86
- if isinstance(message["content"], list):
87
- new_content = []
88
- for item in message["content"]:
89
- if item["type"] == "text":
90
- new_content.append(item)
91
- elif item["type"] == "image_url":
92
- if image_count < self.config.num_images_to_keep:
93
- new_content.append(item)
94
- image_count += 1
95
- message["content"] = new_content
96
-
97
- def get_formatted_messages(self, provider: str) -> List[Dict[str, Any]]:
98
- """Get messages formatted for specific provider.
99
-
100
- Args:
101
- provider: Provider name to format messages for
102
-
103
- Returns:
104
- List of formatted messages
105
- """
106
- # Set the provider for message formatting
107
- self.set_provider(provider)
108
-
109
- if provider == "anthropic":
110
- return self._format_for_anthropic()
111
- elif provider == "openai":
112
- return self._format_for_openai()
113
- elif provider == "groq":
114
- return self._format_for_groq()
115
- elif provider == "qwen":
116
- return self._format_for_qwen()
117
- else:
118
- raise ValueError(f"Unsupported provider: {provider}")
119
-
120
- def _format_for_anthropic(self) -> List[Dict[str, Any]]:
121
- """Format messages for Anthropic API."""
122
- formatted = []
123
- for msg in self.messages:
124
- formatted_msg = {"role": msg["role"]}
125
-
126
- # Handle multimodal content
127
- if isinstance(msg["content"], list):
128
- formatted_msg["content"] = []
129
- for item in msg["content"]:
130
- if item["type"] == "text":
131
- formatted_msg["content"].append({"type": "text", "text": item["text"]})
132
- elif item["type"] == "image_url":
133
- formatted_msg["content"].append(
134
- {
135
- "type": "image",
136
- "source": {
137
- "type": "base64",
138
- "media_type": "image/png",
139
- "data": item["image_url"]["url"].split(",")[1],
140
- },
141
- }
142
- )
143
- else:
144
- formatted_msg["content"] = msg["content"]
145
-
146
- formatted.append(formatted_msg)
147
- return formatted
148
-
149
- def _format_for_openai(self) -> List[Dict[str, Any]]:
150
- """Format messages for OpenAI API."""
151
- # OpenAI already uses the same format
152
- return self.messages
153
-
154
- def _format_for_groq(self) -> List[Dict[str, Any]]:
155
- """Format messages for Groq API."""
156
- # Groq uses OpenAI-compatible format
157
- return self.messages
158
-
159
- def _format_for_qwen(self) -> List[Dict[str, Any]]:
160
- """Format messages for Qwen API."""
161
- formatted = []
162
- for msg in self.messages:
163
- if isinstance(msg["content"], list):
164
- # Convert multimodal content to text-only
165
- text_content = next(
166
- (item["text"] for item in msg["content"] if item["type"] == "text"), ""
167
- )
168
- formatted.append({"role": msg["role"], "content": text_content})
169
- else:
170
- formatted.append(msg)
171
- return formatted