cua-agent 0.1.5__py3-none-any.whl → 0.1.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of cua-agent might be problematic. Click here for more details.
- agent/__init__.py +3 -4
- agent/core/__init__.py +3 -10
- agent/core/computer_agent.py +207 -32
- agent/core/experiment.py +20 -3
- agent/core/loop.py +78 -120
- agent/core/messages.py +279 -125
- agent/core/telemetry.py +44 -32
- agent/core/types.py +35 -0
- agent/core/visualization.py +197 -0
- agent/providers/anthropic/api/client.py +142 -1
- agent/providers/anthropic/api_handler.py +140 -0
- agent/providers/anthropic/callbacks/__init__.py +5 -0
- agent/providers/anthropic/loop.py +224 -209
- agent/providers/anthropic/messages/manager.py +3 -1
- agent/providers/anthropic/response_handler.py +229 -0
- agent/providers/anthropic/tools/base.py +1 -1
- agent/providers/anthropic/tools/bash.py +0 -97
- agent/providers/anthropic/tools/collection.py +2 -2
- agent/providers/anthropic/tools/computer.py +34 -24
- agent/providers/anthropic/tools/manager.py +2 -2
- agent/providers/anthropic/utils.py +370 -0
- agent/providers/omni/__init__.py +1 -20
- agent/providers/omni/api_handler.py +42 -0
- agent/providers/omni/clients/anthropic.py +4 -0
- agent/providers/omni/image_utils.py +0 -72
- agent/providers/omni/loop.py +497 -607
- agent/providers/omni/parser.py +60 -5
- agent/providers/omni/tools/__init__.py +25 -8
- agent/providers/omni/tools/base.py +29 -0
- agent/providers/omni/tools/bash.py +43 -38
- agent/providers/omni/tools/computer.py +144 -181
- agent/providers/omni/tools/manager.py +26 -48
- agent/providers/omni/types.py +0 -4
- agent/providers/omni/utils.py +225 -144
- {cua_agent-0.1.5.dist-info → cua_agent-0.1.17.dist-info}/METADATA +6 -36
- cua_agent-0.1.17.dist-info/RECORD +63 -0
- agent/core/agent.py +0 -252
- agent/core/base_agent.py +0 -164
- agent/core/factory.py +0 -102
- agent/providers/omni/callbacks.py +0 -78
- agent/providers/omni/clients/groq.py +0 -101
- agent/providers/omni/experiment.py +0 -273
- agent/providers/omni/messages.py +0 -171
- agent/providers/omni/tool_manager.py +0 -91
- agent/providers/omni/visualization.py +0 -130
- agent/types/__init__.py +0 -26
- agent/types/base.py +0 -53
- agent/types/messages.py +0 -36
- cua_agent-0.1.5.dist-info/RECORD +0 -67
- /agent/{types → core}/tools.py +0 -0
- {cua_agent-0.1.5.dist-info → cua_agent-0.1.17.dist-info}/WHEEL +0 -0
- {cua_agent-0.1.5.dist-info → cua_agent-0.1.17.dist-info}/entry_points.txt +0 -0
|
@@ -0,0 +1,197 @@
|
|
|
1
|
+
"""Core visualization utilities for agents."""
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
import base64
|
|
5
|
+
from typing import Dict, Tuple
|
|
6
|
+
from PIL import Image, ImageDraw
|
|
7
|
+
from io import BytesIO
|
|
8
|
+
|
|
9
|
+
# Logger for the visualization helpers defined in this module.
logger: logging.Logger = logging.getLogger(__name__)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def visualize_click(
    x: int,
    y: int,
    img_base64: str,
    *,
    radius: int = 15,
    line_length: int = 20,
    color: str = "red",
) -> Image.Image:
    """Visualize a click action by drawing a circle and crosshairs on the screenshot.

    Args:
        x: X coordinate of the click, in screenshot pixels
        y: Y coordinate of the click, in screenshot pixels
        img_base64: Base64-encoded screenshot
        radius: Radius of the circle drawn around the click point
        line_length: Half-length of each crosshair arm
        color: Color used for the circle outline and crosshairs (PIL color spec)

    Returns:
        A copy of the screenshot with the click marker drawn on it. If the
        screenshot cannot be decoded or drawn on, the error is logged and a
        blank 800x600 white RGB image is returned instead.
    """
    try:
        # Decode the base64 image
        image_data = base64.b64decode(img_base64)
        # Copy the decoded image so the source can be closed right away;
        # all drawing happens on the independent copy.
        with Image.open(BytesIO(image_data)) as img:
            draw_img = img.copy()
        draw = ImageDraw.Draw(draw_img)

        # Circle around the click location
        draw.ellipse(
            [(x - radius, y - radius), (x + radius, y + radius)],
            outline=color,
            width=3,
        )

        # Crosshairs through the click location
        draw.line([(x - line_length, y), (x + line_length, y)], fill=color, width=3)
        draw.line([(x, y - line_length), (x, y + line_length)], fill=color, width=3)

        return draw_img
    except Exception as e:
        logger.error(f"Error visualizing click: {str(e)}")
        # Best-effort fallback: callers always receive an image.
        return Image.new("RGB", (800, 600), "white")
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def visualize_scroll(direction: str, clicks: int, img_base64: str) -> Image.Image:
    """Visualize a scroll action by drawing arrows on the screenshot.

    Up to three arrows are drawn through the middle of the image, pointing in
    the scroll direction. Vertical scrolls ("up"/"down") are drawn along the
    horizontal center line; horizontal scrolls ("left"/"right") along the
    vertical center line. Any other direction value is drawn as an upward
    arrow.

    Args:
        direction: Direction of scroll ('up', 'down', 'left' or 'right')
        clicks: Number of scroll clicks; at most three arrows are drawn, and
            none when ``clicks`` is zero or negative
        img_base64: Base64-encoded screenshot

    Returns:
        PIL Image with visualization. If the screenshot cannot be decoded or
        drawn on, the error is logged and a blank 800x600 white RGB image is
        returned instead.
    """
    try:
        # Decode the base64 image and draw on an independent copy so the
        # source image can be closed immediately.
        image_data = base64.b64decode(img_base64)
        with Image.open(BytesIO(image_data)) as img:
            draw_img = img.copy()
        draw = ImageDraw.Draw(draw_img)

        width, height = draw_img.size
        horizontal = direction in ("left", "right")
        # Sign of travel along the scroll axis: +1 for down/right, -1 otherwise.
        arrow_dir = 1 if direction in ("down", "right") else -1

        # Length of the scroll axis, and the fixed position of the arrows on
        # the perpendicular axis (the image center line).
        axis_len = width if horizontal else height
        center = height // 2 if horizontal else width // 2

        arrow_length = min(100, axis_len // 4)
        arrow_width = 30
        arrowhead_size = 20
        num_arrows = min(clicks, 3)  # Don't draw too many arrows

        # Start in the third of the image opposite to the direction of travel.
        start = axis_len // 3 if arrow_dir == 1 else axis_len * 2 // 3

        def point(along, across):
            # Map (scroll-axis, perpendicular-axis) coordinates to PIL (x, y).
            return (along, across) if horizontal else (across, along)

        for i in range(num_arrows):
            tail = start + (i * arrow_length * arrow_dir * 0.7)
            tip = tail + arrow_length * arrow_dir

            # Main shaft
            draw.line([point(tail, center), point(tip, center)], fill="red", width=5)

            # Arrowhead: two barbs set back from the tip, against the travel.
            barb = tip - arrowhead_size * arrow_dir
            draw.line(
                [
                    point(barb, center - arrow_width // 2),
                    point(tip, center),
                    point(barb, center + arrow_width // 2),
                ],
                fill="red",
                width=5,
            )

        return draw_img
    except Exception as e:
        logger.error(f"Error visualizing scroll: {str(e)}")
        # Return a blank image as fallback
        return Image.new("RGB", (800, 600), "white")
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def calculate_element_center(bbox: Dict[str, float], width: int, height: int) -> Tuple[int, int]:
    """Calculate the center point of a UI element in pixel coordinates.

    Args:
        bbox: Bounding box dictionary with x1, y1, x2, y2 coordinates (0-1 normalized)
        width: Screen width in pixels
        height: Screen height in pixels

    Returns:
        (x, y) tuple with pixel coordinates, rounded to the nearest pixel.

    Raises:
        KeyError: If ``bbox`` is missing any of the x1, y1, x2, y2 keys.
    """
    # Round instead of truncating: int() turns float artifacts such as
    # 0.29 * 100 == 28.999999999999996 into an off-by-one pixel (28).
    center_x = round((bbox["x1"] + bbox["x2"]) / 2 * width)
    center_y = round((bbox["y1"] + bbox["y2"]) / 2 * height)
    return center_x, center_y
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
class VisualizationHelper:
    """Helper class for visualizing agent actions.

    Drawing is delegated to the module-level ``visualize_click`` and
    ``visualize_scroll`` functions; saving is delegated to the agent's
    ``experiment_manager.save_action_visualization``. The ``visualize_*``
    methods do nothing unless ``agent.save_trajectory`` is truthy and the
    agent has a truthy ``experiment_manager``; missing attributes count as
    falsy.
    """

    def __init__(self, agent):
        """Initialize visualization helper.

        Args:
            agent: Reference to the agent that will use this helper
        """
        self.agent = agent

    def _experiment_manager(self):
        """Return the agent's experiment manager, or None when it has none."""
        return getattr(self.agent, "experiment_manager", None)

    def _should_save(self) -> bool:
        """Return True when trajectory saving is enabled and a manager exists."""
        # getattr default: tolerate agents that lack a save_trajectory attribute,
        # just like agents that lack an experiment_manager.
        return bool(getattr(self.agent, "save_trajectory", False)) and bool(
            self._experiment_manager()
        )

    def visualize_action(self, x: int, y: int, img_base64: str) -> None:
        """Visualize a click action by drawing on the screenshot and saving it.

        Drawing/saving errors are logged rather than raised.
        """
        if not self._should_save():
            return

        try:
            img = visualize_click(x, y, img_base64)
            self._experiment_manager().save_action_visualization(img, "click", f"x{x}_y{y}")
        except Exception as e:
            logger.error(f"Error visualizing action: {str(e)}")

    def visualize_scroll(self, direction: str, clicks: int, img_base64: str) -> None:
        """Visualize a scroll action by drawing arrows on the screenshot and saving it.

        Drawing/saving errors are logged rather than raised.
        """
        if not self._should_save():
            return

        try:
            # Module-level visualize_scroll (method names are not in function scope).
            img = visualize_scroll(direction, clicks, img_base64)
            self._experiment_manager().save_action_visualization(
                img, "scroll", f"{direction}_{clicks}"
            )
        except Exception as e:
            logger.error(f"Error visualizing scroll: {str(e)}")

    def save_action_visualization(
        self, img: "Image.Image", action_name: str, details: str = ""
    ) -> str:
        """Save a visualization of an action.

        Unlike the ``visualize_*`` methods, this does not check
        ``save_trajectory``.

        Returns:
            The experiment manager's return value, or "" when the agent has no
            experiment manager.
        """
        manager = self._experiment_manager()
        if manager:
            return manager.save_action_visualization(img, action_name, details)
        return ""
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
from typing import Any
|
|
1
|
+
from typing import Any, List, Dict, cast
|
|
2
2
|
import httpx
|
|
3
3
|
import asyncio
|
|
4
4
|
from anthropic import Anthropic, AnthropicBedrock, AnthropicVertex
|
|
@@ -80,6 +80,147 @@ class BaseAnthropicClient:
|
|
|
80
80
|
f"Failed after {self.MAX_RETRIES} retries. " f"Last error: {str(last_error)}"
|
|
81
81
|
)
|
|
82
82
|
|
|
83
|
+
async def run_interleaved(
    self, messages: List[Dict[str, Any]], system: str, max_tokens: int = 4096
) -> Any:
    """Run the Anthropic API with the Claude model, supports interleaved tool calling.

    Messages are first passed through ``_fix_missing_tool_results``. The
    request is then sent via the subclass's ``create_message``, retrying up to
    ``MAX_RETRIES`` times with exponential backoff starting at
    ``INITIAL_RETRY_DELAY`` seconds.

    Args:
        messages: List of message objects
        system: System prompt
        max_tokens: Maximum tokens to generate

    Returns:
        API response

    Raises:
        RuntimeError: If every attempt fails; the last underlying exception is
            attached as ``__cause__`` and included in the message.
    """
    # Add the tool_result check/fix logic here
    fixed_messages = self._fix_missing_tool_results(messages)

    # Get model name from concrete implementation if available
    model_name = getattr(self, "model", "unknown model")
    logger.info(f"Running Anthropic API call with model {model_name}")

    # create_message expects the system prompt as a list; build it once.
    system_list = [system]
    last_error = None

    for attempt in range(1, self.MAX_RETRIES + 1):
        try:
            response = await self.create_message(
                messages=cast(list[BetaMessageParam], fixed_messages),
                system=system_list,
                # NOTE(review): no tool definitions are passed on this path; the
                # original author noted "tools are included in the messages".
                tools=[],
                max_tokens=max_tokens,
                betas=["tools-2023-12-13"],
            )
            logger.info("Anthropic API call successful")
            return response
        except Exception as e:
            last_error = e
            if attempt >= self.MAX_RETRIES:
                # Final attempt failed: don't sleep before giving up.
                break
            wait_time = self.INITIAL_RETRY_DELAY * (2 ** (attempt - 1))  # Exponential backoff
            logger.info(
                f"Retrying request (attempt {attempt}/{self.MAX_RETRIES}) in {wait_time:.2f} seconds after error: {str(e)}"
            )
            await asyncio.sleep(wait_time)

    # If we get here, all retries failed
    error_message = f"Failed to call Anthropic API after {self.MAX_RETRIES} attempts"
    if last_error is not None:
        error_message += f". Last error: {str(last_error)}"
    raise RuntimeError(error_message) from last_error
|
|
133
|
+
|
|
134
|
+
def _fix_missing_tool_results(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Ensure every tool_use block is answered by a tool_result in the next message.

    For each assistant message containing ``tool_use`` blocks, any tool ID not
    answered by a ``tool_result`` in the immediately following message gets a
    synthetic error result:

    * If the next message is a user message, the missing results are
      prepended to a copy of its content. Tool results go first; string
      content is wrapped in a text block.
    * Otherwise (the next message is an assistant message, or the conversation
      ends), a new user message holding the results is inserted right after
      the assistant message.

    The input list and its message dicts are not modified.

    Args:
        messages: List of message objects

    Returns:
        Fixed messages with proper tool_result blocks
    """

    def content_blocks(message: Dict[str, Any]) -> List[Dict[str, Any]]:
        # Only list content can hold tool blocks; string/None content has none.
        content = message.get("content")
        if not isinstance(content, list):
            return []
        return [block for block in content if isinstance(block, dict)]

    def skipped_result(tool_id: str) -> Dict[str, Any]:
        # tool_result content must be a string or a list of content blocks;
        # is_error marks the result as a failure for the model.
        return {
            "type": "tool_result",
            "tool_use_id": tool_id,
            "content": "Tool execution was skipped or failed",
            "is_error": True,
        }

    fixed_messages: List[Dict[str, Any]] = []
    # Synthetic results that must be merged into the current (user) message,
    # owed by the assistant message just before it.
    owed_results: List[Dict[str, Any]] = []

    for i, message in enumerate(messages):
        if owed_results:
            existing = message.get("content")
            if isinstance(existing, str):
                existing = [{"type": "text", "text": existing}] if existing else []
            elif not isinstance(existing, list):
                existing = []
            # Copy instead of mutating the caller's message.
            message = {**message, "content": owed_results + existing}
            owed_results = []

        fixed_messages.append(message)

        if message.get("role") != "assistant":
            continue

        # De-duplicated tool_use IDs, in order of appearance.
        tool_ids = list(
            dict.fromkeys(
                block["id"]
                for block in content_blocks(message)
                if block.get("type") == "tool_use" and block.get("id")
            )
        )
        if not tool_ids:
            continue

        next_message = messages[i + 1] if i + 1 < len(messages) else None
        if next_message is not None and next_message.get("role") == "user":
            answered = {
                block.get("tool_use_id")
                for block in content_blocks(next_message)
                if block.get("type") == "tool_result"
            }
            owed_results = [skipped_result(t) for t in tool_ids if t not in answered]
        else:
            # Insert a synthetic user message with the tool results
            fixed_messages.append(
                {"role": "user", "content": [skipped_result(t) for t in tool_ids]}
            )

    return fixed_messages
|
|
223
|
+
|
|
83
224
|
|
|
84
225
|
class AnthropicDirectClient(BaseAnthropicClient):
|
|
85
226
|
"""Direct Anthropic API client implementation."""
|
|
@@ -0,0 +1,140 @@
|
|
|
1
|
+
"""API call handling for Anthropic provider."""
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
import asyncio
|
|
5
|
+
from typing import List
|
|
6
|
+
|
|
7
|
+
from anthropic.types.beta import (
|
|
8
|
+
BetaMessage,
|
|
9
|
+
BetaMessageParam,
|
|
10
|
+
BetaTextBlockParam,
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
from .types import LLMProvider
|
|
14
|
+
from .prompts import SYSTEM_PROMPT
|
|
15
|
+
|
|
16
|
+
# Anthropic beta feature flags used when building requests in this module.
COMPUTER_USE_BETA_FLAG: str = "computer-use-2025-01-24"
PROMPT_CACHING_BETA_FLAG: str = "prompt-caching-2024-07-31"

# Logger for this module.
logger: logging.Logger = logging.getLogger(__name__)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class AnthropicAPIHandler:
    """Handles API calls to Anthropic's API with structured error handling and retries.

    Configuration is read from the parent loop at call time: ``client``,
    ``tool_manager``, ``message_manager.config.enable_caching``,
    ``max_tokens``, ``max_retries`` and ``retry_delay``. Request, response and
    error events are reported through ``loop._log_api_call``.
    """

    def __init__(self, loop):
        """Initialize the API handler.

        Args:
            loop: Reference to the parent loop instance that provides context
        """
        self.loop = loop

    def _check_tool_ids(self, messages: List[BetaMessageParam]) -> None:
        """Log tool_use / tool_result IDs at DEBUG level and warn about unpaired ones."""
        tool_use_ids = set()
        tool_result_ids = set()

        for i, msg in enumerate(messages):
            logger.debug(f"Message {i}: role={msg.get('role')}")
            content = msg.get("content")
            if not isinstance(content, list):
                continue
            for content_block in content:
                if not isinstance(content_block, dict):
                    continue
                block_type = content_block.get("type")
                if block_type == "tool_use" and "id" in content_block:
                    tool_id = content_block.get("id")
                    tool_use_ids.add(tool_id)
                    logger.debug(f" - Found tool_use with ID: {tool_id}")
                elif block_type == "tool_result" and "tool_use_id" in content_block:
                    result_id = content_block.get("tool_use_id")
                    tool_result_ids.add(result_id)
                    logger.debug(f" - Found tool_result referencing ID: {result_id}")

        missing_tool_uses = tool_result_ids - tool_use_ids
        if missing_tool_uses:
            logger.warning(
                f"Found tool_result IDs without matching tool_use IDs: {missing_tool_uses}"
            )
        missing_tool_results = tool_use_ids - tool_result_ids
        if missing_tool_results:
            logger.warning(
                f"Found tool_use IDs without matching tool_result IDs: {missing_tool_results}"
            )

    def _build_system_and_betas(self, system_prompt: str):
        """Build the system text block and the beta flag list for one request.

        When the loop's message manager config has ``enable_caching`` set, the
        prompt-caching beta flag is added and the system block is marked with an
        ephemeral ``cache_control``.
        """
        system = BetaTextBlockParam(
            type="text",
            text=system_prompt,
        )
        betas = [COMPUTER_USE_BETA_FLAG]
        if self.loop.message_manager.config.enable_caching:
            betas.append(PROMPT_CACHING_BETA_FLAG)
            system["cache_control"] = {"type": "ephemeral"}
        return system, betas

    async def make_api_call(
        self, messages: List[BetaMessageParam], system_prompt: str = SYSTEM_PROMPT
    ) -> BetaMessage:
        """Make API call to Anthropic with retry logic.

        Each attempt is retried up to ``loop.max_retries`` times. The wait
        between attempts is ``loop.retry_delay`` seconds and doubles after each
        failure.

        Args:
            messages: List of messages to send to the API
            system_prompt: System prompt to use (default: SYSTEM_PROMPT)

        Returns:
            API response

        Raises:
            RuntimeError: If the client or tool manager is not initialized, or
                if the API call fails after all retries. In the latter case the
                last underlying exception is chained as ``__cause__``.
        """
        if self.loop.client is None:
            raise RuntimeError("Client not initialized. Call initialize_client() first.")
        if self.loop.tool_manager is None:
            raise RuntimeError("Tool manager not initialized. Call initialize_client() first.")

        logger.info(f"Sending {len(messages)} messages to Anthropic API")
        self._check_tool_ids(messages)

        last_error = None
        for attempt in range(self.loop.max_retries):
            try:
                # Log request
                request_data = {
                    "messages": messages,
                    "max_tokens": self.loop.max_tokens,
                    "system": system_prompt,
                }
                # Let ExperimentManager handle sanitization
                self.loop._log_api_call("request", request_data)

                # Built inside the try so config errors follow the retry/RuntimeError path.
                system, betas = self._build_system_and_betas(system_prompt)

                response = await self.loop.client.create_message(
                    messages=messages,
                    system=[system],
                    tools=self.loop.tool_manager.get_tool_params(),
                    max_tokens=self.loop.max_tokens,
                    betas=betas,
                )

                # Let ExperimentManager handle sanitization
                self.loop._log_api_call("response", request_data, response)

                return response
            except Exception as e:
                last_error = e
                logger.error(
                    f"Error in API call (attempt {attempt + 1}/{self.loop.max_retries}): {str(e)}"
                )
                self.loop._log_api_call("error", {"messages": messages}, error=e)

                if attempt < self.loop.max_retries - 1:
                    # Exponential backoff: retry_delay, 2x, 4x, ...
                    await asyncio.sleep(self.loop.retry_delay * (2 ** attempt))

        # If we get here, all retries failed
        error_message = f"API call failed after {self.loop.max_retries} attempts"
        if last_error is not None:
            error_message += f": {str(last_error)}"

        logger.error(error_message)
        raise RuntimeError(error_message) from last_error
|