cua-agent 0.3.1__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of cua-agent might be problematic. Click here for more details.
- agent/__init__.py +21 -12
- agent/__main__.py +21 -0
- agent/adapters/__init__.py +9 -0
- agent/adapters/huggingfacelocal_adapter.py +229 -0
- agent/agent.py +594 -0
- agent/callbacks/__init__.py +19 -0
- agent/callbacks/base.py +153 -0
- agent/callbacks/budget_manager.py +44 -0
- agent/callbacks/image_retention.py +139 -0
- agent/callbacks/logging.py +247 -0
- agent/callbacks/pii_anonymization.py +259 -0
- agent/callbacks/telemetry.py +210 -0
- agent/callbacks/trajectory_saver.py +305 -0
- agent/cli.py +297 -0
- agent/computer_handler.py +107 -0
- agent/decorators.py +90 -0
- agent/loops/__init__.py +11 -0
- agent/loops/anthropic.py +728 -0
- agent/loops/omniparser.py +339 -0
- agent/loops/openai.py +95 -0
- agent/loops/uitars.py +688 -0
- agent/responses.py +207 -0
- agent/telemetry.py +135 -14
- agent/types.py +79 -0
- agent/ui/__init__.py +7 -1
- agent/ui/__main__.py +2 -13
- agent/ui/gradio/__init__.py +6 -19
- agent/ui/gradio/app.py +94 -1313
- agent/ui/gradio/ui_components.py +721 -0
- cua_agent-0.4.0.dist-info/METADATA +424 -0
- cua_agent-0.4.0.dist-info/RECORD +33 -0
- {cua_agent-0.3.1.dist-info → cua_agent-0.4.0.dist-info}/WHEEL +1 -1
- agent/core/__init__.py +0 -27
- agent/core/agent.py +0 -210
- agent/core/base.py +0 -217
- agent/core/callbacks.py +0 -200
- agent/core/experiment.py +0 -249
- agent/core/factory.py +0 -122
- agent/core/messages.py +0 -332
- agent/core/provider_config.py +0 -21
- agent/core/telemetry.py +0 -142
- agent/core/tools/__init__.py +0 -21
- agent/core/tools/base.py +0 -74
- agent/core/tools/bash.py +0 -52
- agent/core/tools/collection.py +0 -46
- agent/core/tools/computer.py +0 -113
- agent/core/tools/edit.py +0 -67
- agent/core/tools/manager.py +0 -56
- agent/core/tools.py +0 -32
- agent/core/types.py +0 -88
- agent/core/visualization.py +0 -197
- agent/providers/__init__.py +0 -4
- agent/providers/anthropic/__init__.py +0 -6
- agent/providers/anthropic/api/client.py +0 -360
- agent/providers/anthropic/api/logging.py +0 -150
- agent/providers/anthropic/api_handler.py +0 -140
- agent/providers/anthropic/callbacks/__init__.py +0 -5
- agent/providers/anthropic/callbacks/manager.py +0 -65
- agent/providers/anthropic/loop.py +0 -568
- agent/providers/anthropic/prompts.py +0 -23
- agent/providers/anthropic/response_handler.py +0 -226
- agent/providers/anthropic/tools/__init__.py +0 -33
- agent/providers/anthropic/tools/base.py +0 -88
- agent/providers/anthropic/tools/bash.py +0 -66
- agent/providers/anthropic/tools/collection.py +0 -34
- agent/providers/anthropic/tools/computer.py +0 -396
- agent/providers/anthropic/tools/edit.py +0 -326
- agent/providers/anthropic/tools/manager.py +0 -54
- agent/providers/anthropic/tools/run.py +0 -42
- agent/providers/anthropic/types.py +0 -16
- agent/providers/anthropic/utils.py +0 -367
- agent/providers/omni/__init__.py +0 -8
- agent/providers/omni/api_handler.py +0 -42
- agent/providers/omni/clients/anthropic.py +0 -103
- agent/providers/omni/clients/base.py +0 -35
- agent/providers/omni/clients/oaicompat.py +0 -195
- agent/providers/omni/clients/ollama.py +0 -122
- agent/providers/omni/clients/openai.py +0 -155
- agent/providers/omni/clients/utils.py +0 -25
- agent/providers/omni/image_utils.py +0 -34
- agent/providers/omni/loop.py +0 -990
- agent/providers/omni/parser.py +0 -307
- agent/providers/omni/prompts.py +0 -64
- agent/providers/omni/tools/__init__.py +0 -30
- agent/providers/omni/tools/base.py +0 -29
- agent/providers/omni/tools/bash.py +0 -74
- agent/providers/omni/tools/computer.py +0 -179
- agent/providers/omni/tools/manager.py +0 -61
- agent/providers/omni/utils.py +0 -236
- agent/providers/openai/__init__.py +0 -6
- agent/providers/openai/api_handler.py +0 -456
- agent/providers/openai/loop.py +0 -472
- agent/providers/openai/response_handler.py +0 -205
- agent/providers/openai/tools/__init__.py +0 -15
- agent/providers/openai/tools/base.py +0 -79
- agent/providers/openai/tools/computer.py +0 -326
- agent/providers/openai/tools/manager.py +0 -106
- agent/providers/openai/types.py +0 -36
- agent/providers/openai/utils.py +0 -98
- agent/providers/uitars/__init__.py +0 -1
- agent/providers/uitars/clients/base.py +0 -35
- agent/providers/uitars/clients/mlxvlm.py +0 -263
- agent/providers/uitars/clients/oaicompat.py +0 -214
- agent/providers/uitars/loop.py +0 -660
- agent/providers/uitars/prompts.py +0 -63
- agent/providers/uitars/tools/__init__.py +0 -1
- agent/providers/uitars/tools/computer.py +0 -283
- agent/providers/uitars/tools/manager.py +0 -60
- agent/providers/uitars/utils.py +0 -264
- cua_agent-0.3.1.dist-info/METADATA +0 -295
- cua_agent-0.3.1.dist-info/RECORD +0 -87
- {cua_agent-0.3.1.dist-info → cua_agent-0.4.0.dist-info}/entry_points.txt +0 -0
agent/core/visualization.py
DELETED
|
@@ -1,197 +0,0 @@
|
|
|
1
|
-
"""Core visualization utilities for agents."""
|
|
2
|
-
|
|
3
|
-
import logging
|
|
4
|
-
import base64
|
|
5
|
-
from typing import Dict, Tuple
|
|
6
|
-
from PIL import Image, ImageDraw
|
|
7
|
-
from io import BytesIO
|
|
8
|
-
|
|
9
|
-
logger = logging.getLogger(__name__)
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
def visualize_click(
    x: int,
    y: int,
    img_base64: str,
    *,
    radius: int = 15,
    line_length: int = 20,
    color: str = "red",
    line_width: int = 3,
) -> Image.Image:
    """Visualize a click action by drawing a circle and crosshairs on the screenshot.

    Args:
        x: X coordinate of the click, in pixels.
        y: Y coordinate of the click, in pixels.
        img_base64: Base64-encoded screenshot.
        radius: Radius of the circle drawn around the click point.
        line_length: Half-length of each crosshair arm.
        color: Colour passed to PIL for the circle outline and crosshairs.
        line_width: Stroke width, in pixels, of the circle and crosshairs.

    Returns:
        A copy of the decoded screenshot with the click marked. If decoding or
        drawing fails, the error is logged with its traceback and a blank
        800x600 white RGB image is returned instead.
    """
    try:
        image_data = base64.b64decode(img_base64)
        img = Image.open(BytesIO(image_data))

        # Draw on a copy so the decoded source image is left untouched.
        draw_img = img.copy()
        draw = ImageDraw.Draw(draw_img)

        # Circle centred on the click point.
        draw.ellipse(
            [(x - radius, y - radius), (x + radius, y + radius)],
            outline=color,
            width=line_width,
        )

        # Horizontal and vertical crosshair arms through the click point.
        draw.line([(x - line_length, y), (x + line_length, y)], fill=color, width=line_width)
        draw.line([(x, y - line_length), (x, y + line_length)], fill=color, width=line_width)

        return draw_img
    except Exception:
        # Best-effort: a visualization failure must not propagate to the caller.
        # logger.exception keeps the traceback, which f"{str(e)}" discarded.
        logger.exception("Error visualizing click")
        return Image.new("RGB", (800, 600), "white")
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
def visualize_scroll(direction: str, clicks: int, img_base64: str) -> Image.Image:
    """Overlay up to three arrows on a screenshot to illustrate a scroll action.

    Args:
        direction: ``"down"`` draws downward-pointing arrows; any other value
            draws upward-pointing arrows.
        clicks: Number of scroll clicks; one arrow is drawn per click, capped at 3.
        img_base64: Base64-encoded screenshot.

    Returns:
        An annotated copy of the screenshot, or a blank 800x600 white RGB image
        if anything goes wrong (the error is logged).
    """
    try:
        source = Image.open(BytesIO(base64.b64decode(img_base64)))
        canvas = source.copy()
        pen = ImageDraw.Draw(canvas)

        img_w, img_h = source.size
        mid_x = img_w // 2
        shaft_len = min(100, img_h // 4)
        half_head_w = 30 // 2
        head_len = 20

        # Downward scrolls start in the upper third and step downwards;
        # everything else starts in the lower third and steps upwards.
        going_down = direction == "down"
        step = 1 if going_down else -1
        origin_y = img_h // 3 if going_down else img_h * 2 // 3

        for idx in range(min(clicks, 3)):
            tail_y = origin_y + (idx * shaft_len * step * 0.7)
            tip_y = tail_y + shaft_len * step
            tip = (mid_x, tip_y)

            pen.line([(mid_x, tail_y), tip], fill="red", width=5)

            # The chevron's barbs sit behind the tip, opposite to the travel direction.
            barb_y = tip_y - head_len * step
            pen.line(
                [(mid_x - half_head_w, barb_y), tip, (mid_x + half_head_w, barb_y)],
                fill="red",
                width=5,
            )

        return canvas
    except Exception as e:
        logger.error(f"Error visualizing scroll: {str(e)}")
        # Blank fallback so callers always receive an image.
        return Image.new("RGB", (800, 600), "white")
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
def calculate_element_center(bbox: Dict[str, float], width: int, height: int) -> Tuple[int, int]:
    """Convert a normalized bounding box into the pixel coordinates of its center.

    Args:
        bbox: Mapping with ``x1``, ``y1``, ``x2``, ``y2`` in the 0-1 range.
        width: Screen width in pixels.
        height: Screen height in pixels.

    Returns:
        ``(x, y)`` pixel coordinates, truncated toward zero by ``int``.
    """
    mid_x_norm = (bbox["x1"] + bbox["x2"]) / 2
    mid_y_norm = (bbox["y1"] + bbox["y2"]) / 2
    return int(mid_x_norm * width), int(mid_y_norm * height)
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
class VisualizationHelper:
    """Render agent actions and store them through the agent's experiment manager.

    Visualizations are only produced when the agent has a truthy
    ``save_trajectory`` attribute and a truthy ``experiment_manager`` attribute.
    """

    def __init__(self, agent):
        """Initialize visualization helper.

        Args:
            agent: Agent whose ``save_trajectory`` flag and ``experiment_manager``
                control whether and where visualizations are saved.
        """
        self.agent = agent

    def _should_visualize(self) -> bool:
        """Return True when trajectory saving is enabled and an experiment manager exists."""
        # getattr: an agent without ``save_trajectory`` simply opts out instead of
        # raising AttributeError outside the try/except of the callers.
        return bool(getattr(self.agent, "save_trajectory", False)) and bool(
            getattr(self.agent, "experiment_manager", None)
        )

    def visualize_action(self, x: int, y: int, img_base64: str) -> None:
        """Draw a click marker at ``(x, y)`` on the screenshot and save it.

        Errors are logged (with traceback) and swallowed; this is best-effort.
        """
        if not self._should_visualize():
            return

        try:
            # Module-level helper (not a method of this class).
            img = visualize_click(x, y, img_base64)
            self.agent.experiment_manager.save_action_visualization(img, "click", f"x{x}_y{y}")
        except Exception:
            logger.exception("Error visualizing action")

    def visualize_scroll(self, direction: str, clicks: int, img_base64: str) -> None:
        """Draw scroll arrows on the screenshot and save the result.

        Errors are logged (with traceback) and swallowed; this is best-effort.
        """
        if not self._should_visualize():
            return

        try:
            # Inside a method body this bare name resolves to the module-level
            # ``visualize_scroll`` function, not to this method.
            img = visualize_scroll(direction, clicks, img_base64)
            self.agent.experiment_manager.save_action_visualization(
                img, "scroll", f"{direction}_{clicks}"
            )
        except Exception:
            logger.exception("Error visualizing scroll")

    def save_action_visualization(
        self, img: Image.Image, action_name: str, details: str = ""
    ) -> str:
        """Save a visualization of an action.

        Returns:
            Whatever the experiment manager's ``save_action_visualization``
            returns, or ``""`` when no experiment manager is available.
        """
        manager = getattr(self.agent, "experiment_manager", None)
        if manager:
            return manager.save_action_visualization(img, action_name, details)
        return ""
|
agent/providers/anthropic/api/client.py
DELETED
|
@@ -1,360 +0,0 @@
|
|
|
1
|
-
from typing import Any, List, Dict, cast
|
|
2
|
-
import httpx
|
|
3
|
-
import asyncio
|
|
4
|
-
from anthropic import Anthropic, AnthropicBedrock, AnthropicVertex
|
|
5
|
-
from anthropic.types.beta import BetaMessage, BetaMessageParam, BetaToolUnionParam
|
|
6
|
-
from ..types import LLMProvider
|
|
7
|
-
from .logging import log_api_interaction
|
|
8
|
-
import random
|
|
9
|
-
import logging
|
|
10
|
-
|
|
11
|
-
logger = logging.getLogger(__name__)
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
class APIConnectionError(Exception):
    """Raised when an Anthropic API call still fails after all retry attempts."""
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
class BaseAnthropicClient:
    """Base class for Anthropic API clients.

    Subclasses implement :meth:`create_message` against a concrete backend.
    This base class provides:

    * exponential-backoff retry helpers, and
    * a repair pass that ensures every ``tool_use`` block is answered by a
      ``tool_result`` block before the conversation is sent.
    """

    MAX_RETRIES = 10
    INITIAL_RETRY_DELAY = 1.0
    MAX_RETRY_DELAY = 60.0
    JITTER_FACTOR = 0.1

    # Text placed in synthetic tool_result blocks for tool_use blocks that never
    # received a result.
    MISSING_TOOL_RESULT_MESSAGE = "Tool execution was skipped or failed"

    async def create_message(
        self,
        *,
        messages: list[BetaMessageParam],
        system: list[Any],
        tools: list[BetaToolUnionParam],
        max_tokens: int,
        betas: list[str],
    ) -> BetaMessage:
        """Create a message using the Anthropic API (implemented by subclasses)."""
        raise NotImplementedError

    def _compute_retry_delay(self, attempt: int) -> float:
        """Return the delay in seconds before retry number ``attempt`` (1-based).

        The delay starts at INITIAL_RETRY_DELAY, doubles with each attempt and is
        capped at MAX_RETRY_DELAY. A random +/- JITTER_FACTOR jitter is then
        applied so concurrent clients do not retry in lockstep.
        """
        delay = min(self.INITIAL_RETRY_DELAY * (2 ** (attempt - 1)), self.MAX_RETRY_DELAY)
        jitter = delay * self.JITTER_FACTOR * (2 * random.random() - 1)
        return delay + jitter

    async def _make_api_call_with_retries(self, api_call):
        """Make an API call with exponential backoff retry logic.

        Args:
            api_call: Zero-argument async function that makes the actual API call.

        Returns:
            The result of ``api_call``.

        Raises:
            APIConnectionError: If all MAX_RETRIES attempts fail. The last
                underlying error is attached as ``__cause__``.
        """
        last_error = None

        for attempt in range(1, self.MAX_RETRIES + 1):
            try:
                return await api_call()
            except Exception as e:
                last_error = e
                if attempt == self.MAX_RETRIES:
                    break  # No point sleeping after the final attempt.

                final_delay = self._compute_retry_delay(attempt)
                logger.info(
                    f"Retrying request (attempt {attempt}/{self.MAX_RETRIES}) "
                    f"in {final_delay:.2f} seconds after error: {str(e)}"
                )
                await asyncio.sleep(final_delay)

        raise APIConnectionError(
            f"Failed after {self.MAX_RETRIES} retries. Last error: {str(last_error)}"
        ) from last_error

    async def run_interleaved(
        self, messages: List[Dict[str, Any]], system: str, max_tokens: int = 4096
    ) -> Any:
        """Run the Anthropic API with the Claude model, supports interleaved tool calling.

        Args:
            messages: List of message dicts. Unanswered ``tool_use`` blocks are
                repaired first (see ``_fix_missing_tool_results``).
            system: System prompt.
            max_tokens: Maximum tokens to generate.

        Returns:
            The response returned by ``create_message``.

        Raises:
            RuntimeError: If every attempt fails. The last underlying error is
                attached as ``__cause__``.
        """
        fixed_messages = self._fix_missing_tool_results(messages)

        # Concrete subclasses set ``self.model``; the base class may not have it.
        model_name = getattr(self, "model", "unknown model")
        logger.info(f"Running Anthropic API call with model {model_name}")

        last_error = None
        # NOTE(review): the concrete create_message implementations in this module
        # also retry internally, so a persistent failure can be retried up to
        # MAX_RETRIES * MAX_RETRIES times overall.
        for attempt in range(1, self.MAX_RETRIES + 1):
            try:
                response = await self.create_message(
                    messages=cast(list[BetaMessageParam], fixed_messages),
                    # create_message expects the system prompt as a list.
                    system=[system],
                    tools=[],  # Tools are included in the messages
                    max_tokens=max_tokens,
                    betas=["tools-2023-12-13"],
                )
                logger.info("Anthropic API call successful")
                return response
            except Exception as e:
                last_error = e
                if attempt == self.MAX_RETRIES:
                    break  # No point sleeping after the final attempt.

                wait_time = self._compute_retry_delay(attempt)
                logger.info(
                    f"Retrying request (attempt {attempt}/{self.MAX_RETRIES}) in "
                    f"{wait_time:.2f} seconds after error: {str(e)}"
                )
                await asyncio.sleep(wait_time)

        raise RuntimeError(
            f"Failed to call Anthropic API after {self.MAX_RETRIES} attempts"
        ) from last_error

    def _synthetic_tool_results_message(
        self, pending_tool_uses: Dict[str, Dict[str, Any]]
    ) -> Dict[str, Any]:
        """Build a user message that answers each pending tool_use with an error result."""
        return {
            "role": "user",
            "content": [
                {
                    "type": "tool_result",
                    "tool_use_id": tool_id,
                    # The Messages API accepts a string (or a list of content
                    # blocks) as tool_result content; ``is_error`` tells the
                    # model the tool did not run successfully.
                    "content": self.MISSING_TOOL_RESULT_MESSAGE,
                    "is_error": True,
                }
                for tool_id in pending_tool_uses
            ],
        }

    def _fix_missing_tool_results(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Check for and fix any missing tool_result blocks after tool_use blocks.

        A synthetic user message with error tool_results is inserted in two cases:

        * when an assistant message with unanswered tool_use blocks is followed
          directly by another assistant message, and
        * when the conversation ends on an assistant message with unanswered
          tool_use blocks.

        Args:
            messages: List of message dicts with ``role`` and ``content`` keys.

        Returns:
            A new list containing the original messages, plus any inserted
            synthetic messages. The input list is not modified.
        """
        fixed_messages: List[Dict[str, Any]] = []
        pending_tool_uses: Dict[str, Dict[str, Any]] = {}  # tool_use id -> details

        for i, message in enumerate(messages):
            role = message.get("role")

            if "content" in message:
                # ``or []`` guards against an explicit ``content: None``. Plain-string
                # content iterates as characters and is skipped by the dict check.
                blocks = message.get("content") or []

                if role == "assistant":
                    # Record tool_use blocks that now await a result.
                    for block in blocks:
                        if isinstance(block, dict) and block.get("type") == "tool_use":
                            tool_id = block.get("id")
                            if tool_id:
                                pending_tool_uses[tool_id] = {
                                    "name": block.get("name", ""),
                                    "input": block.get("input", {}),
                                }
                elif role == "user":
                    # tool_result blocks resolve their matching pending tool_use.
                    for block in blocks:
                        if isinstance(block, dict) and block.get("type") == "tool_result":
                            pending_tool_uses.pop(block.get("tool_use_id"), None)

            fixed_messages.append(message)

            # Two assistant turns in a row with unanswered tool_uses: answer them
            # in between with a synthetic user message.
            next_is_assistant = (
                i + 1 < len(messages) and messages[i + 1].get("role") == "assistant"
            )
            if role == "assistant" and next_is_assistant and pending_tool_uses:
                fixed_messages.append(self._synthetic_tool_results_message(pending_tool_uses))
                pending_tool_uses = {}

        # The conversation ends on an assistant turn with unanswered tool_uses.
        if pending_tool_uses and fixed_messages and fixed_messages[-1].get("role") == "assistant":
            fixed_messages.append(self._synthetic_tool_results_message(pending_tool_uses))

        return fixed_messages
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
class AnthropicDirectClient(BaseAnthropicClient):
    """Direct Anthropic API client implementation (API-key authenticated)."""

    def __init__(self, api_key: str, model: str):
        """Create the client.

        Args:
            api_key: Anthropic API key.
            model: Model identifier sent with every request.
        """
        self.model = model
        self.client = Anthropic(api_key=api_key, http_client=self._create_http_client())

    def _create_http_client(self) -> httpx.Client:
        """Create the httpx client used by the SDK.

        Uses a long 300 s read timeout (shorter connect/write/pool timeouts) and
        a transport with ``retries=3`` and a bounded connection pool.
        """
        return httpx.Client(
            verify=True,
            timeout=httpx.Timeout(connect=30.0, read=300.0, write=30.0, pool=30.0),
            transport=httpx.HTTPTransport(
                retries=3,
                verify=True,
                limits=httpx.Limits(max_keepalive_connections=5, max_connections=10),
            ),
        )

    async def create_message(
        self,
        *,
        messages: list[BetaMessageParam],
        system: list[Any],
        tools: list[BetaToolUnionParam],
        max_tokens: int,
        betas: list[str],
    ) -> BetaMessage:
        """Create a message using the direct Anthropic API with retry logic.

        The SDK client is synchronous, so each attempt runs in a worker thread
        via ``asyncio.to_thread``. This keeps the event loop free during the
        (potentially minutes-long) request.

        Raises:
            APIConnectionError: If every retry fails. The failure is also passed
                to ``log_api_interaction`` before re-raising.
        """

        def _blocking_call() -> BetaMessage:
            response = self.client.beta.messages.with_raw_response.create(
                max_tokens=max_tokens,
                messages=messages,
                model=self.model,
                system=system,
                tools=tools,
                betas=betas,
            )
            log_api_interaction(response.http_response.request, response.http_response, None)
            return response.parse()

        async def api_call():
            return await asyncio.to_thread(_blocking_call)

        try:
            return await self._make_api_call_with_retries(api_call)
        except Exception as e:
            log_api_interaction(None, None, e)
            raise
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
class AnthropicVertexClient(BaseAnthropicClient):
    """Google Cloud Vertex AI implementation of Anthropic client."""

    def __init__(self, model: str):
        """Create the client.

        Args:
            model: Model identifier sent with every request.
        """
        self.model = model
        # NOTE(review): constructed with no arguments, so project/region/credentials
        # presumably come from the SDK's environment defaults; confirm for deployments.
        self.client = AnthropicVertex()

    async def create_message(
        self,
        *,
        messages: list[BetaMessageParam],
        system: list[Any],
        tools: list[BetaToolUnionParam],
        max_tokens: int,
        betas: list[str],
    ) -> BetaMessage:
        """Create a message using Vertex AI with retry logic.

        The SDK client is synchronous, so each attempt runs in a worker thread
        via ``asyncio.to_thread`` to avoid blocking the event loop.

        Raises:
            APIConnectionError: If every retry fails. The failure is also passed
                to ``log_api_interaction`` before re-raising.
        """

        def _blocking_call() -> BetaMessage:
            response = self.client.beta.messages.with_raw_response.create(
                max_tokens=max_tokens,
                messages=messages,
                model=self.model,
                system=system,
                tools=tools,
                betas=betas,
            )
            log_api_interaction(response.http_response.request, response.http_response, None)
            return response.parse()

        async def api_call():
            return await asyncio.to_thread(_blocking_call)

        try:
            return await self._make_api_call_with_retries(api_call)
        except Exception as e:
            log_api_interaction(None, None, e)
            raise
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
class AnthropicBedrockClient(BaseAnthropicClient):
    """AWS Bedrock implementation of Anthropic client."""

    def __init__(self, model: str):
        """Create the client.

        Args:
            model: Model identifier sent with every request.
        """
        self.model = model
        # NOTE(review): constructed with no arguments, so AWS region/credentials
        # presumably come from the SDK's environment defaults; confirm for deployments.
        self.client = AnthropicBedrock()

    async def create_message(
        self,
        *,
        messages: list[BetaMessageParam],
        system: list[Any],
        tools: list[BetaToolUnionParam],
        max_tokens: int,
        betas: list[str],
    ) -> BetaMessage:
        """Create a message using AWS Bedrock with retry logic.

        The SDK client is synchronous, so each attempt runs in a worker thread
        via ``asyncio.to_thread`` to avoid blocking the event loop.

        Raises:
            APIConnectionError: If every retry fails. The failure is also passed
                to ``log_api_interaction`` before re-raising.
        """

        def _blocking_call() -> BetaMessage:
            response = self.client.beta.messages.with_raw_response.create(
                max_tokens=max_tokens,
                messages=messages,
                model=self.model,
                system=system,
                tools=tools,
                betas=betas,
            )
            log_api_interaction(response.http_response.request, response.http_response, None)
            return response.parse()

        async def api_call():
            return await asyncio.to_thread(_blocking_call)

        try:
            return await self._make_api_call_with_retries(api_call)
        except Exception as e:
            log_api_interaction(None, None, e)
            raise
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
class AnthropicClientFactory:
    """Builds the Anthropic client implementation that matches an ``LLMProvider``."""

    @staticmethod
    def create_client(provider: LLMProvider, api_key: str, model: str) -> BaseAnthropicClient:
        """Return a client for ``provider``.

        Only the direct Anthropic backend uses ``api_key``; the Vertex and Bedrock
        clients are built from ``model`` alone.

        Raises:
            ValueError: If ``provider`` is not ANTHROPIC, VERTEX or BEDROCK.
        """
        if provider == LLMProvider.ANTHROPIC:
            client = AnthropicDirectClient(api_key, model)
        elif provider == LLMProvider.VERTEX:
            client = AnthropicVertexClient(model)
        elif provider == LLMProvider.BEDROCK:
            client = AnthropicBedrockClient(model)
        else:
            raise ValueError(f"Unsupported provider: {provider}")
        return client
|