cua-agent 0.1.5__py3-none-any.whl → 0.1.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of cua-agent might be problematic. Click here for more details.

Files changed (52)
  1. agent/__init__.py +3 -4
  2. agent/core/__init__.py +3 -10
  3. agent/core/computer_agent.py +207 -32
  4. agent/core/experiment.py +20 -3
  5. agent/core/loop.py +78 -120
  6. agent/core/messages.py +279 -125
  7. agent/core/telemetry.py +44 -32
  8. agent/core/types.py +35 -0
  9. agent/core/visualization.py +197 -0
  10. agent/providers/anthropic/api/client.py +142 -1
  11. agent/providers/anthropic/api_handler.py +140 -0
  12. agent/providers/anthropic/callbacks/__init__.py +5 -0
  13. agent/providers/anthropic/loop.py +224 -209
  14. agent/providers/anthropic/messages/manager.py +3 -1
  15. agent/providers/anthropic/response_handler.py +229 -0
  16. agent/providers/anthropic/tools/base.py +1 -1
  17. agent/providers/anthropic/tools/bash.py +0 -97
  18. agent/providers/anthropic/tools/collection.py +2 -2
  19. agent/providers/anthropic/tools/computer.py +34 -24
  20. agent/providers/anthropic/tools/manager.py +2 -2
  21. agent/providers/anthropic/utils.py +370 -0
  22. agent/providers/omni/__init__.py +1 -20
  23. agent/providers/omni/api_handler.py +42 -0
  24. agent/providers/omni/clients/anthropic.py +4 -0
  25. agent/providers/omni/image_utils.py +0 -72
  26. agent/providers/omni/loop.py +497 -607
  27. agent/providers/omni/parser.py +60 -5
  28. agent/providers/omni/tools/__init__.py +25 -8
  29. agent/providers/omni/tools/base.py +29 -0
  30. agent/providers/omni/tools/bash.py +43 -38
  31. agent/providers/omni/tools/computer.py +144 -181
  32. agent/providers/omni/tools/manager.py +26 -48
  33. agent/providers/omni/types.py +0 -4
  34. agent/providers/omni/utils.py +225 -144
  35. {cua_agent-0.1.5.dist-info → cua_agent-0.1.17.dist-info}/METADATA +6 -36
  36. cua_agent-0.1.17.dist-info/RECORD +63 -0
  37. agent/core/agent.py +0 -252
  38. agent/core/base_agent.py +0 -164
  39. agent/core/factory.py +0 -102
  40. agent/providers/omni/callbacks.py +0 -78
  41. agent/providers/omni/clients/groq.py +0 -101
  42. agent/providers/omni/experiment.py +0 -273
  43. agent/providers/omni/messages.py +0 -171
  44. agent/providers/omni/tool_manager.py +0 -91
  45. agent/providers/omni/visualization.py +0 -130
  46. agent/types/__init__.py +0 -26
  47. agent/types/base.py +0 -53
  48. agent/types/messages.py +0 -36
  49. cua_agent-0.1.5.dist-info/RECORD +0 -67
  50. /agent/{types → core}/tools.py +0 -0
  51. {cua_agent-0.1.5.dist-info → cua_agent-0.1.17.dist-info}/WHEEL +0 -0
  52. {cua_agent-0.1.5.dist-info → cua_agent-0.1.17.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,197 @@
1
+ """Core visualization utilities for agents."""
2
+
3
+ import logging
4
+ import base64
5
+ from typing import Dict, Tuple
6
+ from PIL import Image, ImageDraw
7
+ from io import BytesIO
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+
12
def visualize_click(x: int, y: int, img_base64: str) -> Image.Image:
    """Mark a click position on a screenshot with a red circle and crosshair.

    Args:
        x: Horizontal pixel coordinate of the click
        y: Vertical pixel coordinate of the click
        img_base64: Screenshot encoded as a base64 string

    Returns:
        Annotated copy of the screenshot, or a blank white 800x600 image
        when decoding or drawing raises
    """
    circle_radius = 15
    crosshair_reach = 20

    try:
        screenshot = Image.open(BytesIO(base64.b64decode(img_base64)))

        # Annotate a copy so the decoded screenshot itself stays untouched
        annotated = screenshot.copy()
        pen = ImageDraw.Draw(annotated)

        # Circle centred on the click point
        top_left = (x - circle_radius, y - circle_radius)
        bottom_right = (x + circle_radius, y + circle_radius)
        pen.ellipse([top_left, bottom_right], outline="red", width=3)

        # Horizontal, then vertical, crosshair stroke through the click point
        pen.line([(x - crosshair_reach, y), (x + crosshair_reach, y)], fill="red", width=3)
        pen.line([(x, y - crosshair_reach), (x, y + crosshair_reach)], fill="red", width=3)
    except Exception as e:
        logger.error(f"Error visualizing click: {str(e)}")
        # Best-effort fallback: callers always receive an image
        return Image.new("RGB", (800, 600), "white")

    return annotated
46
+
47
+
48
def visualize_scroll(direction: str, clicks: int, img_base64: str) -> Image.Image:
    """Overlay red arrows on a screenshot to indicate a scroll action.

    Args:
        direction: Scroll direction; "down" draws downward arrows, any other
            value draws upward arrows
        clicks: Number of scroll clicks (at most three arrows are drawn)
        img_base64: Screenshot encoded as a base64 string

    Returns:
        Annotated copy of the screenshot, or a blank white 800x600 image
        when decoding or drawing raises
    """
    try:
        screenshot = Image.open(BytesIO(base64.b64decode(img_base64)))

        # Annotate a copy so the decoded screenshot itself stays untouched
        annotated = screenshot.copy()
        pen = ImageDraw.Draw(annotated)

        img_width, img_height = screenshot.size
        mid_x = img_width // 2

        shaft_length = min(100, img_height // 4)
        head_half_width = 30 // 2
        head_size = 20
        arrow_count = min(clicks, 3)  # Cap the number of arrows drawn

        # step_sign is +1 when arrows point down the image and -1 when they point up
        if direction == "down":
            first_y, step_sign = img_height // 3, 1
        else:
            first_y, step_sign = img_height * 2 // 3, -1

        for index in range(arrow_count):
            tail_y = first_y + (index * shaft_length * step_sign * 0.7)
            tail = (mid_x, tail_y)
            tip = (mid_x, tail_y + shaft_length * step_sign)

            # Arrow shaft
            pen.line([tail, tip], fill="red", width=5)

            # Arrowhead: two barbs that meet at the tip and trail back along the shaft
            barb_y = tip[1] - head_size * step_sign
            pen.line(
                [(mid_x - head_half_width, barb_y), tip, (mid_x + head_half_width, barb_y)],
                fill="red",
                width=5,
            )
    except Exception as e:
        logger.error(f"Error visualizing scroll: {str(e)}")
        # Best-effort fallback: callers always receive an image
        return Image.new("RGB", (800, 600), "white")

    return annotated
122
+
123
+
124
def calculate_element_center(bbox: Dict[str, float], width: int, height: int) -> Tuple[int, int]:
    """Convert a bounding box into the pixel coordinates of its center.

    Args:
        bbox: Mapping with "x1", "y1", "x2", "y2" entries; each midpoint is
            scaled by the matching screen dimension
        width: Screen width in pixels
        height: Screen height in pixels

    Returns:
        (x, y) tuple of pixel coordinates, truncated to integers
    """
    mid_x = (bbox["x1"] + bbox["x2"]) / 2
    mid_y = (bbox["y1"] + bbox["y2"]) / 2
    return int(mid_x * width), int(mid_y * height)
138
+
139
+
140
class VisualizationHelper:
    """Draws agent actions onto screenshots and stores them via the agent's experiment manager."""

    def __init__(self, agent):
        """Bind the helper to an agent.

        Args:
            agent: Agent whose ``save_trajectory`` flag and
                ``experiment_manager`` attribute this helper consults
        """
        self.agent = agent

    def _recording_enabled(self) -> bool:
        """Return True when trajectory saving is on and an experiment manager is set."""
        owner = self.agent
        if not owner.save_trajectory:
            return False
        if not hasattr(owner, "experiment_manager"):
            return False
        return bool(owner.experiment_manager)

    def visualize_action(self, x: int, y: int, img_base64: str) -> None:
        """Render a click marker on the screenshot and store the result.

        Does nothing when recording is disabled; failures are logged, not raised.
        """
        if not self._recording_enabled():
            return

        try:
            annotated = visualize_click(x, y, img_base64)
            self.agent.experiment_manager.save_action_visualization(
                annotated, "click", f"x{x}_y{y}"
            )
        except Exception as e:
            logger.error(f"Error visualizing action: {str(e)}")

    def visualize_scroll(self, direction: str, clicks: int, img_base64: str) -> None:
        """Render scroll arrows on the screenshot and store the result.

        Does nothing when recording is disabled; failures are logged, not raised.
        """
        if not self._recording_enabled():
            return

        try:
            # Inside a method this name resolves to the module-level function
            annotated = visualize_scroll(direction, clicks, img_base64)
            self.agent.experiment_manager.save_action_visualization(
                annotated, "scroll", f"{direction}_{clicks}"
            )
        except Exception as e:
            logger.error(f"Error visualizing scroll: {str(e)}")

    def save_action_visualization(
        self, img: Image.Image, action_name: str, details: str = ""
    ) -> str:
        """Forward an image to the experiment manager.

        Returns:
            Whatever the experiment manager returns, or an empty string when
            the agent has no experiment manager
        """
        owner = self.agent
        if not hasattr(owner, "experiment_manager") or not owner.experiment_manager:
            return ""
        return owner.experiment_manager.save_action_visualization(img, action_name, details)
@@ -1,4 +1,4 @@
1
- from typing import Any
1
+ from typing import Any, List, Dict, cast
2
2
  import httpx
3
3
  import asyncio
4
4
  from anthropic import Anthropic, AnthropicBedrock, AnthropicVertex
@@ -80,6 +80,147 @@ class BaseAnthropicClient:
80
80
  f"Failed after {self.MAX_RETRIES} retries. " f"Last error: {str(last_error)}"
81
81
  )
82
82
 
83
async def run_interleaved(
    self, messages: List[Dict[str, Any]], system: str, max_tokens: int = 4096
) -> Any:
    """Call the Anthropic API via ``create_message``, retrying on failure.

    The messages are first passed through ``_fix_missing_tool_results``. The
    call is attempted up to ``MAX_RETRIES`` times, waiting
    ``INITIAL_RETRY_DELAY * 2 ** (attempt - 1)`` seconds between attempts.

    Args:
        messages: List of message objects
        system: System prompt
        max_tokens: Maximum tokens to generate

    Returns:
        API response returned by ``create_message``

    Raises:
        RuntimeError: If every attempt fails; the last underlying error is
            attached as ``__cause__`` and included in the message
    """
    fixed_messages = self._fix_missing_tool_results(messages)

    # Concrete clients may define `model`; fall back to a placeholder for logging
    model_name = getattr(self, "model", "unknown model")
    logger.info(f"Running Anthropic API call with model {model_name}")

    # create_message is given the system prompt wrapped in a list
    system_list = [system]
    last_error = None

    for attempt in range(1, self.MAX_RETRIES + 1):
        try:
            response = await self.create_message(
                messages=cast(list[BetaMessageParam], fixed_messages),
                system=system_list,
                tools=[],  # Tools are included in the messages
                max_tokens=max_tokens,
                betas=["tools-2023-12-13"],
            )
        except Exception as e:
            last_error = e
            if attempt >= self.MAX_RETRIES:
                # No attempts left: skip the backoff sleep and fail immediately
                logger.error(
                    f"Request failed (attempt {attempt}/{self.MAX_RETRIES}): {str(e)}"
                )
                break

            wait_time = self.INITIAL_RETRY_DELAY * (2 ** (attempt - 1))  # Exponential backoff
            logger.info(
                f"Retrying request (attempt {attempt}/{self.MAX_RETRIES}) in {wait_time:.2f} seconds after error: {str(e)}"
            )
            await asyncio.sleep(wait_time)
        else:
            logger.info("Anthropic API call successful")
            return response

    # If we get here, all retries failed
    error_message = f"Failed to call Anthropic API after {self.MAX_RETRIES} attempts"
    if last_error is not None:
        error_message += f": {str(last_error)}"
    raise RuntimeError(error_message) from last_error
133
+
134
+ def _fix_missing_tool_results(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
135
+ """Check for and fix any missing tool_result blocks after tool_use blocks.
136
+
137
+ Args:
138
+ messages: List of message objects
139
+
140
+ Returns:
141
+ Fixed messages with proper tool_result blocks
142
+ """
143
+ fixed_messages = []
144
+ pending_tool_uses = {} # Map of tool_use IDs to their details
145
+
146
+ for i, message in enumerate(messages):
147
+ # Track any tool_use blocks in this message
148
+ if message.get("role") == "assistant" and "content" in message:
149
+ content = message.get("content", [])
150
+ for block in content:
151
+ if isinstance(block, dict) and block.get("type") == "tool_use":
152
+ tool_id = block.get("id")
153
+ if tool_id:
154
+ pending_tool_uses[tool_id] = {
155
+ "name": block.get("name", ""),
156
+ "input": block.get("input", {}),
157
+ }
158
+
159
+ # Check if this message handles any pending tool_use blocks
160
+ if message.get("role") == "user" and "content" in message:
161
+ # Check for tool_result blocks in this message
162
+ content = message.get("content", [])
163
+ for block in content:
164
+ if isinstance(block, dict) and block.get("type") == "tool_result":
165
+ tool_id = block.get("tool_use_id")
166
+ if tool_id in pending_tool_uses:
167
+ # This tool_result handles a pending tool_use
168
+ pending_tool_uses.pop(tool_id)
169
+
170
+ # Add the message to our fixed list
171
+ fixed_messages.append(message)
172
+
173
+ # If this is an assistant message with tool_use blocks and there are
174
+ # pending tool uses that need to be resolved before the next assistant message
175
+ if (
176
+ i + 1 < len(messages)
177
+ and message.get("role") == "assistant"
178
+ and messages[i + 1].get("role") == "assistant"
179
+ and pending_tool_uses
180
+ ):
181
+
182
+ # We need to insert a user message with tool_results for all pending tool_uses
183
+ tool_results = []
184
+ for tool_id, tool_info in pending_tool_uses.items():
185
+ tool_results.append(
186
+ {
187
+ "type": "tool_result",
188
+ "tool_use_id": tool_id,
189
+ "content": {
190
+ "type": "error",
191
+ "message": "Tool execution was skipped or failed",
192
+ },
193
+ }
194
+ )
195
+
196
+ # Insert a synthetic user message with the tool results
197
+ if tool_results:
198
+ fixed_messages.append({"role": "user", "content": tool_results})
199
+
200
+ # Clear pending tools since we've added results for them
201
+ pending_tool_uses = {}
202
+
203
+ # Check if there are any remaining pending tool_uses at the end of the conversation
204
+ if pending_tool_uses and fixed_messages and fixed_messages[-1].get("role") == "assistant":
205
+ # Add a final user message with tool results for any pending tool_uses
206
+ tool_results = []
207
+ for tool_id, tool_info in pending_tool_uses.items():
208
+ tool_results.append(
209
+ {
210
+ "type": "tool_result",
211
+ "tool_use_id": tool_id,
212
+ "content": {
213
+ "type": "error",
214
+ "message": "Tool execution was skipped or failed",
215
+ },
216
+ }
217
+ )
218
+
219
+ if tool_results:
220
+ fixed_messages.append({"role": "user", "content": tool_results})
221
+
222
+ return fixed_messages
223
+
83
224
 
84
225
  class AnthropicDirectClient(BaseAnthropicClient):
85
226
  """Direct Anthropic API client implementation."""
@@ -0,0 +1,140 @@
1
+ """API call handling for Anthropic provider."""
2
+
3
+ import logging
4
+ import asyncio
5
+ from typing import List
6
+
7
+ from anthropic.types.beta import (
8
+ BetaMessage,
9
+ BetaMessageParam,
10
+ BetaTextBlockParam,
11
+ )
12
+
13
+ from .types import LLMProvider
14
+ from .prompts import SYSTEM_PROMPT
15
+
16
+ # Constants
17
+ COMPUTER_USE_BETA_FLAG = "computer-use-2025-01-24"
18
+ PROMPT_CACHING_BETA_FLAG = "prompt-caching-2024-07-31"
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+
23
class AnthropicAPIHandler:
    """Handles API calls to Anthropic's API with structured error handling and retries."""

    def __init__(self, loop):
        """Initialize the API handler.

        Args:
            loop: Reference to the parent loop instance that provides context.
                This class reads its ``client``, ``tool_manager``,
                ``message_manager``, ``max_retries``, ``max_tokens`` and
                ``retry_delay`` attributes and calls its ``_log_api_call``.
        """
        self.loop = loop

    @staticmethod
    def _log_tool_id_audit(messages) -> None:
        """Log each message's role and tool IDs, warning on unmatched tool_result IDs."""
        logger.info(f"Sending {len(messages)} messages to Anthropic API")

        tool_use_ids = set()
        tool_result_ids = set()

        for i, msg in enumerate(messages):
            logger.info(f"Message {i}: role={msg.get('role')}")
            content = msg.get("content")
            if not isinstance(content, list):
                continue
            for content_block in content:
                if not isinstance(content_block, dict):
                    continue
                block_type = content_block.get("type")
                if block_type == "tool_use" and "id" in content_block:
                    tool_id = content_block.get("id")
                    tool_use_ids.add(tool_id)
                    logger.info(f" - Found tool_use with ID: {tool_id}")
                elif block_type == "tool_result" and "tool_use_id" in content_block:
                    result_id = content_block.get("tool_use_id")
                    tool_result_ids.add(result_id)
                    logger.info(f" - Found tool_result referencing ID: {result_id}")

        # A tool_result whose ID never appeared in a tool_use is a mismatch
        missing_tool_uses = tool_result_ids - tool_use_ids
        if missing_tool_uses:
            logger.warning(
                f"Found tool_result IDs without matching tool_use IDs: {missing_tool_uses}"
            )

    async def make_api_call(
        self, messages: List[BetaMessageParam], system_prompt: str = SYSTEM_PROMPT
    ) -> BetaMessage:
        """Make API call to Anthropic with retry logic.

        Args:
            messages: List of messages to send to the API
            system_prompt: System prompt to use (default: SYSTEM_PROMPT)

        Returns:
            API response

        Raises:
            RuntimeError: If the client or tool manager is not initialized, or
                if the API call fails after all retries (the last underlying
                error is attached as ``__cause__``)
        """
        if self.loop.client is None:
            raise RuntimeError("Client not initialized. Call initialize_client() first.")
        if self.loop.tool_manager is None:
            raise RuntimeError("Tool manager not initialized. Call initialize_client() first.")

        last_error = None

        # Debug aid: log roles and tool_use / tool_result IDs before sending
        self._log_tool_id_audit(messages)

        for attempt in range(self.loop.max_retries):
            try:
                # Log request
                request_data = {
                    "messages": messages,
                    "max_tokens": self.loop.max_tokens,
                    "system": system_prompt,
                }
                # Let ExperimentManager handle sanitization
                self.loop._log_api_call("request", request_data)

                # Setup betas and system
                system = BetaTextBlockParam(
                    type="text",
                    text=system_prompt,
                )

                betas = [COMPUTER_USE_BETA_FLAG]
                # Add prompt caching if enabled in the message manager's config
                if self.loop.message_manager.config.enable_caching:
                    betas.append(PROMPT_CACHING_BETA_FLAG)
                    system["cache_control"] = {"type": "ephemeral"}

                # Make API call
                response = await self.loop.client.create_message(
                    messages=messages,
                    system=[system],
                    tools=self.loop.tool_manager.get_tool_params(),
                    max_tokens=self.loop.max_tokens,
                    betas=betas,
                )

                # Let ExperimentManager handle sanitization
                self.loop._log_api_call("response", request_data, response)

                return response
            except Exception as e:
                last_error = e
                logger.error(
                    f"Error in API call (attempt {attempt + 1}/{self.loop.max_retries}): {str(e)}"
                )
                self.loop._log_api_call("error", {"messages": messages}, error=e)

                if attempt < self.loop.max_retries - 1:
                    # Linear backoff: the wait grows by retry_delay with each attempt
                    await asyncio.sleep(self.loop.retry_delay * (attempt + 1))

        # If we get here, all retries failed
        error_message = f"API call failed after {self.loop.max_retries} attempts"
        if last_error:
            error_message += f": {str(last_error)}"

        logger.error(error_message)
        # Chain the cause so the original traceback survives for callers
        raise RuntimeError(error_message) from last_error
@@ -0,0 +1,5 @@
1
+ """Anthropic callbacks package."""
2
+
3
+ from .manager import CallbackManager
4
+
5
+ __all__ = ["CallbackManager"]