cua-agent 0.3.1__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of cua-agent might be problematic. Click here for more details.

Files changed (112) hide show
  1. agent/__init__.py +21 -12
  2. agent/__main__.py +21 -0
  3. agent/adapters/__init__.py +9 -0
  4. agent/adapters/huggingfacelocal_adapter.py +229 -0
  5. agent/agent.py +594 -0
  6. agent/callbacks/__init__.py +19 -0
  7. agent/callbacks/base.py +153 -0
  8. agent/callbacks/budget_manager.py +44 -0
  9. agent/callbacks/image_retention.py +139 -0
  10. agent/callbacks/logging.py +247 -0
  11. agent/callbacks/pii_anonymization.py +259 -0
  12. agent/callbacks/telemetry.py +210 -0
  13. agent/callbacks/trajectory_saver.py +305 -0
  14. agent/cli.py +297 -0
  15. agent/computer_handler.py +107 -0
  16. agent/decorators.py +90 -0
  17. agent/loops/__init__.py +11 -0
  18. agent/loops/anthropic.py +728 -0
  19. agent/loops/omniparser.py +339 -0
  20. agent/loops/openai.py +95 -0
  21. agent/loops/uitars.py +688 -0
  22. agent/responses.py +207 -0
  23. agent/telemetry.py +135 -14
  24. agent/types.py +79 -0
  25. agent/ui/__init__.py +7 -1
  26. agent/ui/__main__.py +2 -13
  27. agent/ui/gradio/__init__.py +6 -19
  28. agent/ui/gradio/app.py +94 -1313
  29. agent/ui/gradio/ui_components.py +721 -0
  30. cua_agent-0.4.0.dist-info/METADATA +424 -0
  31. cua_agent-0.4.0.dist-info/RECORD +33 -0
  32. {cua_agent-0.3.1.dist-info → cua_agent-0.4.0.dist-info}/WHEEL +1 -1
  33. agent/core/__init__.py +0 -27
  34. agent/core/agent.py +0 -210
  35. agent/core/base.py +0 -217
  36. agent/core/callbacks.py +0 -200
  37. agent/core/experiment.py +0 -249
  38. agent/core/factory.py +0 -122
  39. agent/core/messages.py +0 -332
  40. agent/core/provider_config.py +0 -21
  41. agent/core/telemetry.py +0 -142
  42. agent/core/tools/__init__.py +0 -21
  43. agent/core/tools/base.py +0 -74
  44. agent/core/tools/bash.py +0 -52
  45. agent/core/tools/collection.py +0 -46
  46. agent/core/tools/computer.py +0 -113
  47. agent/core/tools/edit.py +0 -67
  48. agent/core/tools/manager.py +0 -56
  49. agent/core/tools.py +0 -32
  50. agent/core/types.py +0 -88
  51. agent/core/visualization.py +0 -197
  52. agent/providers/__init__.py +0 -4
  53. agent/providers/anthropic/__init__.py +0 -6
  54. agent/providers/anthropic/api/client.py +0 -360
  55. agent/providers/anthropic/api/logging.py +0 -150
  56. agent/providers/anthropic/api_handler.py +0 -140
  57. agent/providers/anthropic/callbacks/__init__.py +0 -5
  58. agent/providers/anthropic/callbacks/manager.py +0 -65
  59. agent/providers/anthropic/loop.py +0 -568
  60. agent/providers/anthropic/prompts.py +0 -23
  61. agent/providers/anthropic/response_handler.py +0 -226
  62. agent/providers/anthropic/tools/__init__.py +0 -33
  63. agent/providers/anthropic/tools/base.py +0 -88
  64. agent/providers/anthropic/tools/bash.py +0 -66
  65. agent/providers/anthropic/tools/collection.py +0 -34
  66. agent/providers/anthropic/tools/computer.py +0 -396
  67. agent/providers/anthropic/tools/edit.py +0 -326
  68. agent/providers/anthropic/tools/manager.py +0 -54
  69. agent/providers/anthropic/tools/run.py +0 -42
  70. agent/providers/anthropic/types.py +0 -16
  71. agent/providers/anthropic/utils.py +0 -367
  72. agent/providers/omni/__init__.py +0 -8
  73. agent/providers/omni/api_handler.py +0 -42
  74. agent/providers/omni/clients/anthropic.py +0 -103
  75. agent/providers/omni/clients/base.py +0 -35
  76. agent/providers/omni/clients/oaicompat.py +0 -195
  77. agent/providers/omni/clients/ollama.py +0 -122
  78. agent/providers/omni/clients/openai.py +0 -155
  79. agent/providers/omni/clients/utils.py +0 -25
  80. agent/providers/omni/image_utils.py +0 -34
  81. agent/providers/omni/loop.py +0 -990
  82. agent/providers/omni/parser.py +0 -307
  83. agent/providers/omni/prompts.py +0 -64
  84. agent/providers/omni/tools/__init__.py +0 -30
  85. agent/providers/omni/tools/base.py +0 -29
  86. agent/providers/omni/tools/bash.py +0 -74
  87. agent/providers/omni/tools/computer.py +0 -179
  88. agent/providers/omni/tools/manager.py +0 -61
  89. agent/providers/omni/utils.py +0 -236
  90. agent/providers/openai/__init__.py +0 -6
  91. agent/providers/openai/api_handler.py +0 -456
  92. agent/providers/openai/loop.py +0 -472
  93. agent/providers/openai/response_handler.py +0 -205
  94. agent/providers/openai/tools/__init__.py +0 -15
  95. agent/providers/openai/tools/base.py +0 -79
  96. agent/providers/openai/tools/computer.py +0 -326
  97. agent/providers/openai/tools/manager.py +0 -106
  98. agent/providers/openai/types.py +0 -36
  99. agent/providers/openai/utils.py +0 -98
  100. agent/providers/uitars/__init__.py +0 -1
  101. agent/providers/uitars/clients/base.py +0 -35
  102. agent/providers/uitars/clients/mlxvlm.py +0 -263
  103. agent/providers/uitars/clients/oaicompat.py +0 -214
  104. agent/providers/uitars/loop.py +0 -660
  105. agent/providers/uitars/prompts.py +0 -63
  106. agent/providers/uitars/tools/__init__.py +0 -1
  107. agent/providers/uitars/tools/computer.py +0 -283
  108. agent/providers/uitars/tools/manager.py +0 -60
  109. agent/providers/uitars/utils.py +0 -264
  110. cua_agent-0.3.1.dist-info/METADATA +0 -295
  111. cua_agent-0.3.1.dist-info/RECORD +0 -87
  112. {cua_agent-0.3.1.dist-info → cua_agent-0.4.0.dist-info}/entry_points.txt +0 -0
@@ -1,197 +0,0 @@
1
- """Core visualization utilities for agents."""
2
-
3
- import logging
4
- import base64
5
- from typing import Dict, Tuple
6
- from PIL import Image, ImageDraw
7
- from io import BytesIO
8
-
9
- logger = logging.getLogger(__name__)
10
-
11
-
12
def visualize_click(x: int, y: int, img_base64: str) -> Image.Image:
    """Render a click marker (ring plus crosshair) onto a screenshot.

    Args:
        x: X coordinate of the click.
        y: Y coordinate of the click.
        img_base64: Base64-encoded screenshot.

    Returns:
        A copy of the screenshot with the marker drawn in red. If decoding
        or drawing raises, the error is logged and a blank white 800x600
        image is returned instead.
    """
    ring_radius = 15
    crosshair_half = 20

    try:
        screenshot = Image.open(BytesIO(base64.b64decode(img_base64)))

        # Draw on a copy so the decoded source image is left untouched.
        annotated = screenshot.copy()
        pen = ImageDraw.Draw(annotated)

        # Ring centred on the click position.
        ring_box = [(x - ring_radius, y - ring_radius), (x + ring_radius, y + ring_radius)]
        pen.ellipse(ring_box, outline="red", width=3)

        # Horizontal stroke first, then vertical, forming the crosshair.
        pen.line([(x - crosshair_half, y), (x + crosshair_half, y)], fill="red", width=3)
        pen.line([(x, y - crosshair_half), (x, y + crosshair_half)], fill="red", width=3)

        return annotated
    except Exception as exc:
        logger.error(f"Error visualizing click: {exc!s}")
        # Fall back to a blank canvas rather than propagating the failure.
        return Image.new("RGB", (800, 600), "white")
46
-
47
-
48
def visualize_scroll(direction: str, clicks: int, img_base64: str) -> Image.Image:
    """Render scroll arrows onto a screenshot.

    Args:
        direction: Direction of scroll. "down" draws downward arrows; any
            other value is drawn as an upward scroll.
        clicks: Number of scroll clicks; at most three arrows are drawn.
        img_base64: Base64-encoded screenshot.

    Returns:
        A copy of the screenshot with red arrows along its vertical centre
        line. If decoding or drawing raises, the error is logged and a blank
        white 800x600 image is returned instead.
    """
    try:
        screenshot = Image.open(BytesIO(base64.b64decode(img_base64)))

        # Draw on a copy so the decoded source image is left untouched.
        annotated = screenshot.copy()
        pen = ImageDraw.Draw(annotated)

        img_width, img_height = screenshot.size
        mid_x = img_width // 2

        shaft_length = min(100, img_height // 4)
        head_half_width = 30 // 2
        head_size = 20
        arrow_count = min(clicks, 3)  # Don't draw too many arrows

        # `step` is +1 when arrows point down the image and -1 when they point up.
        if direction == "down":
            first_y, step = img_height // 3, 1
        else:
            first_y, step = img_height * 2 // 3, -1

        for index in range(arrow_count):
            # Successive arrows overlap: each starts 70% of a shaft further on.
            tail_y = first_y + (index * shaft_length * step * 0.7)
            tail = (mid_x, tail_y)
            tip = (mid_x, tail_y + shaft_length * step)

            # Shaft
            pen.line([tail, tip], fill="red", width=5)

            # Arrowhead: two barbs meeting at the tip, set back along the shaft.
            barb_y = tip[1] - head_size * step
            pen.line(
                [
                    (mid_x - head_half_width, barb_y),
                    tip,
                    (mid_x + head_half_width, barb_y),
                ],
                fill="red",
                width=5,
            )

        return annotated
    except Exception as exc:
        logger.error(f"Error visualizing scroll: {exc!s}")
        # Fall back to a blank canvas rather than propagating the failure.
        return Image.new("RGB", (800, 600), "white")
122
-
123
-
124
def calculate_element_center(bbox: Dict[str, float], width: int, height: int) -> Tuple[int, int]:
    """Convert a bounding box into the pixel coordinates of its centre.

    Args:
        bbox: Bounding box dictionary with "x1", "y1", "x2", "y2" entries
            (documented by the caller contract as 0-1 normalized).
        width: Screen width in pixels.
        height: Screen height in pixels.

    Returns:
        (x, y) tuple with pixel coordinates, each truncated with ``int()``.
    """

    def _scaled_midpoint(low_key: str, high_key: str, extent: int) -> int:
        # Average the two edges, then scale by the screen extent.
        return int((bbox[low_key] + bbox[high_key]) / 2 * extent)

    return _scaled_midpoint("x1", "x2", width), _scaled_midpoint("y1", "y2", height)
138
-
139
-
140
class VisualizationHelper:
    """Saves action visualizations through an agent's experiment manager."""

    def __init__(self, agent):
        """Store the agent whose settings and experiment manager are used.

        Args:
            agent: Reference to the agent that will use this helper.
        """
        self.agent = agent

    def visualize_action(self, x: int, y: int, img_base64: str) -> None:
        """Draw a click marker on the screenshot and save it.

        Does nothing unless the agent has ``save_trajectory`` enabled and a
        truthy ``experiment_manager``. Errors while drawing or saving are
        logged, not raised.
        """
        agent = self.agent
        can_save = (
            agent.save_trajectory
            and hasattr(agent, "experiment_manager")
            and agent.experiment_manager
        )
        if not can_save:
            return

        try:
            marked = visualize_click(x, y, img_base64)
            self.agent.experiment_manager.save_action_visualization(
                marked, "click", f"x{x}_y{y}"
            )
        except Exception as exc:
            logger.error(f"Error visualizing action: {exc!s}")

    def visualize_scroll(self, direction: str, clicks: int, img_base64: str) -> None:
        """Draw scroll arrows on the screenshot and save it.

        Does nothing unless the agent has ``save_trajectory`` enabled and a
        truthy ``experiment_manager``. Errors while drawing or saving are
        logged, not raised.
        """
        agent = self.agent
        can_save = (
            agent.save_trajectory
            and hasattr(agent, "experiment_manager")
            and agent.experiment_manager
        )
        if not can_save:
            return

        try:
            # Resolves to the module-level function, not this method.
            marked = visualize_scroll(direction, clicks, img_base64)
            self.agent.experiment_manager.save_action_visualization(
                marked, "scroll", f"{direction}_{clicks}"
            )
        except Exception as exc:
            logger.error(f"Error visualizing scroll: {exc!s}")

    def save_action_visualization(
        self, img: Image.Image, action_name: str, details: str = ""
    ) -> str:
        """Forward an image to the experiment manager for saving.

        Returns:
            Whatever the experiment manager's ``save_action_visualization``
            returns, or an empty string when the agent has no truthy
            ``experiment_manager``.
        """
        agent = self.agent
        if not (hasattr(agent, "experiment_manager") and agent.experiment_manager):
            return ""
        return agent.experiment_manager.save_action_visualization(img, action_name, details)
@@ -1,4 +0,0 @@
1
- """Provider implementations for different AI services."""
2
-
3
- # Import specific providers only when needed to avoid circular imports
4
- __all__ = [] # Let each provider module handle its own exports
@@ -1,6 +0,0 @@
1
- """Anthropic provider implementation."""
2
-
3
- from .loop import AnthropicLoop
4
- from .types import LLMProvider
5
-
6
- __all__ = ["AnthropicLoop", "LLMProvider"]
@@ -1,360 +0,0 @@
1
- from typing import Any, List, Dict, cast
2
- import httpx
3
- import asyncio
4
- from anthropic import Anthropic, AnthropicBedrock, AnthropicVertex
5
- from anthropic.types.beta import BetaMessage, BetaMessageParam, BetaToolUnionParam
6
- from ..types import LLMProvider
7
- from .logging import log_api_interaction
8
- import random
9
- import logging
10
-
11
- logger = logging.getLogger(__name__)
12
-
13
-
14
class APIConnectionError(Exception):
    """Raised when there are connection issues with the API."""
18
-
19
-
20
class BaseAnthropicClient:
    """Base class for Anthropic API clients.

    Subclasses implement :meth:`create_message`. This class supplies the
    retry/backoff loops and the message-history repair that runs before a
    request is sent.
    """

    # Retry policy used by both retry loops in this class.
    MAX_RETRIES = 10
    INITIAL_RETRY_DELAY = 1.0
    MAX_RETRY_DELAY = 60.0
    JITTER_FACTOR = 0.1

    # Text placed in tool_result blocks synthesized for unanswered tool_use blocks.
    _SKIPPED_TOOL_MESSAGE = "Tool execution was skipped or failed"

    async def create_message(
        self,
        *,
        messages: list[BetaMessageParam],
        system: list[Any],
        tools: list[BetaToolUnionParam],
        max_tokens: int,
        betas: list[str],
    ) -> BetaMessage:
        """Create a message using the Anthropic API.

        Raises:
            NotImplementedError: Always; subclasses must override this.
        """
        raise NotImplementedError

    def _backoff_delay(self, attempt: int) -> float:
        """Return the exponential delay before retry ``attempt`` (1-based), capped at MAX_RETRY_DELAY."""
        return min(self.INITIAL_RETRY_DELAY * (2 ** (attempt - 1)), self.MAX_RETRY_DELAY)

    async def _make_api_call_with_retries(self, api_call):
        """Make an API call with exponential backoff retry logic.

        Args:
            api_call: Async function that makes the actual API call

        Returns:
            API response

        Raises:
            APIConnectionError: If all retries fail. The last underlying
                error is attached as ``__cause__``.
        """
        last_error = None

        for attempt in range(1, self.MAX_RETRIES + 1):
            try:
                return await api_call()
            except Exception as e:
                last_error = e

                # No point sleeping after the final attempt.
                if attempt == self.MAX_RETRIES:
                    break

                delay = self._backoff_delay(attempt)
                # Add jitter to avoid thundering herd
                jitter = delay * self.JITTER_FACTOR * (2 * random.random() - 1)
                final_delay = delay + jitter

                logger.info(
                    f"Retrying request (attempt {attempt}/{self.MAX_RETRIES}) "
                    f"in {final_delay:.2f} seconds after error: {str(e)}"
                )
                await asyncio.sleep(final_delay)

        raise APIConnectionError(
            f"Failed after {self.MAX_RETRIES} retries. " f"Last error: {str(last_error)}"
        ) from last_error

    async def run_interleaved(
        self, messages: List[Dict[str, Any]], system: str, max_tokens: int = 4096
    ) -> Any:
        """Run the Anthropic API with the Claude model, supports interleaved tool calling.

        The message history is first repaired so every ``tool_use`` block has
        a ``tool_result``; the request is then sent through
        :meth:`create_message` with retries.

        Args:
            messages: List of message objects
            system: System prompt
            max_tokens: Maximum tokens to generate

        Returns:
            API response

        Raises:
            RuntimeError: If every attempt fails. The last underlying error
                is attached as ``__cause__``.
        """
        fixed_messages = self._fix_missing_tool_results(messages)

        # Get model name from concrete implementation if available
        model_name = getattr(self, "model", "unknown model")
        logger.info(f"Running Anthropic API call with model {model_name}")

        last_error = None

        for attempt in range(1, self.MAX_RETRIES + 1):
            try:
                # create_message takes the system prompt as a list; concrete
                # implementations may do further conversion.
                response = await self.create_message(
                    messages=cast(list[BetaMessageParam], fixed_messages),
                    system=[system],
                    tools=[],  # Tools are included in the messages
                    max_tokens=max_tokens,
                    betas=["tools-2023-12-13"],
                )
            except Exception as e:
                last_error = e

                # Skip the pointless sleep once the retry budget is spent.
                if attempt == self.MAX_RETRIES:
                    break

                # Capped so late retries do not wait for minutes.
                wait_time = self._backoff_delay(attempt)
                logger.info(
                    f"Retrying request (attempt {attempt}/{self.MAX_RETRIES}) in {wait_time:.2f} seconds after error: {str(e)}"
                )
                await asyncio.sleep(wait_time)
            else:
                logger.info("Anthropic API call successful")
                return response

        # If we get here, all retries failed
        raise RuntimeError(
            f"Failed to call Anthropic API after {self.MAX_RETRIES} attempts"
        ) from last_error

    def _synthetic_tool_results(self, tool_ids) -> Dict[str, Any]:
        """Build a user message holding an error tool_result for each id in ``tool_ids``."""
        return {
            "role": "user",
            "content": [
                {
                    "type": "tool_result",
                    "tool_use_id": tool_id,
                    # The Messages API takes tool_result content as a string
                    # (or a list of content blocks) and flags failures with
                    # is_error, not with a nested error object.
                    "content": self._SKIPPED_TOOL_MESSAGE,
                    "is_error": True,
                }
                for tool_id in tool_ids
            ],
        }

    def _fix_missing_tool_results(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Check for and fix any missing tool_result blocks after tool_use blocks.

        A synthetic user message with error tool_results is inserted when an
        assistant message is directly followed by another assistant message
        while tool_use ids are still unanswered, and again at the end if the
        history finishes on an assistant message with unanswered ids.

        Args:
            messages: List of message objects

        Returns:
            A new list containing the original message objects plus any
            synthesized tool_result messages. The input list is not modified.
        """
        fixed_messages: List[Dict[str, Any]] = []
        # Insertion-ordered collection of tool_use ids still awaiting a tool_result.
        pending_tool_ids: Dict[Any, None] = {}

        for i, message in enumerate(messages):
            role = message.get("role")
            # `content` may be missing or None; a plain string has no dict blocks.
            content = message.get("content") or []

            if role == "assistant":
                # Track any tool_use blocks in this message
                for block in content:
                    if isinstance(block, dict) and block.get("type") == "tool_use":
                        tool_id = block.get("id")
                        if tool_id:
                            pending_tool_ids[tool_id] = None
            elif role == "user":
                # A tool_result here settles the matching pending tool_use
                for block in content:
                    if isinstance(block, dict) and block.get("type") == "tool_result":
                        pending_tool_ids.pop(block.get("tool_use_id"), None)

            fixed_messages.append(message)

            next_is_assistant = (
                i + 1 < len(messages) and messages[i + 1].get("role") == "assistant"
            )
            # Two assistant messages in a row leave no slot for real results,
            # so answer everything still pending before the next one.
            if role == "assistant" and next_is_assistant and pending_tool_ids:
                fixed_messages.append(self._synthetic_tool_results(pending_tool_ids))
                pending_tool_ids = {}

        # History ends on an assistant message with unanswered tool_use blocks
        if pending_tool_ids and fixed_messages and fixed_messages[-1].get("role") == "assistant":
            fixed_messages.append(self._synthetic_tool_results(pending_tool_ids))

        return fixed_messages
223
-
224
-
225
class AnthropicDirectClient(BaseAnthropicClient):
    """Direct Anthropic API client implementation."""

    def __init__(self, api_key: str, model: str):
        """Create the SDK client for ``model`` using ``api_key`` and a custom HTTP client."""
        self.model = model
        self.client = Anthropic(api_key=api_key, http_client=self._create_http_client())

    def _create_http_client(self) -> httpx.Client:
        """Create an HTTP client with appropriate settings."""
        return httpx.Client(
            verify=True,
            timeout=httpx.Timeout(connect=30.0, read=300.0, write=30.0, pool=30.0),
            transport=httpx.HTTPTransport(
                retries=3,
                verify=True,
                limits=httpx.Limits(max_keepalive_connections=5, max_connections=10),
            ),
        )

    async def create_message(
        self,
        *,
        messages: list[BetaMessageParam],
        system: list[Any],
        tools: list[BetaToolUnionParam],
        max_tokens: int,
        betas: list[str],
    ) -> BetaMessage:
        """Create a message using the direct Anthropic API with retry logic.

        Raises:
            Exception: Whatever the retry helper raises once retries are
                exhausted; the failure is logged before being re-raised.
        """

        async def api_call():
            # self.client is the synchronous SDK client, so the blocking HTTP
            # request runs in a worker thread to keep the event loop free.
            response = await asyncio.to_thread(
                self.client.beta.messages.with_raw_response.create,
                max_tokens=max_tokens,
                messages=messages,
                model=self.model,
                system=system,
                tools=tools,
                betas=betas,
            )
            log_api_interaction(response.http_response.request, response.http_response, None)
            return response.parse()

        try:
            return await self._make_api_call_with_retries(api_call)
        except Exception as e:
            log_api_interaction(None, None, e)
            raise
272
-
273
-
274
class AnthropicVertexClient(BaseAnthropicClient):
    """Google Cloud Vertex AI implementation of Anthropic client."""

    def __init__(self, model: str):
        """Create a default-configured ``AnthropicVertex`` client for ``model``."""
        self.model = model
        self.client = AnthropicVertex()

    async def create_message(
        self,
        *,
        messages: list[BetaMessageParam],
        system: list[Any],
        tools: list[BetaToolUnionParam],
        max_tokens: int,
        betas: list[str],
    ) -> BetaMessage:
        """Create a message using Vertex AI with retry logic.

        Raises:
            Exception: Whatever the retry helper raises once retries are
                exhausted; the failure is logged before being re-raised.
        """

        async def api_call():
            # self.client is the synchronous SDK client, so the blocking HTTP
            # request runs in a worker thread to keep the event loop free.
            response = await asyncio.to_thread(
                self.client.beta.messages.with_raw_response.create,
                max_tokens=max_tokens,
                messages=messages,
                model=self.model,
                system=system,
                tools=tools,
                betas=betas,
            )
            log_api_interaction(response.http_response.request, response.http_response, None)
            return response.parse()

        try:
            return await self._make_api_call_with_retries(api_call)
        except Exception as e:
            log_api_interaction(None, None, e)
            raise
309
-
310
-
311
class AnthropicBedrockClient(BaseAnthropicClient):
    """AWS Bedrock implementation of Anthropic client."""

    def __init__(self, model: str):
        """Create a default-configured ``AnthropicBedrock`` client for ``model``."""
        self.model = model
        self.client = AnthropicBedrock()

    async def create_message(
        self,
        *,
        messages: list[BetaMessageParam],
        system: list[Any],
        tools: list[BetaToolUnionParam],
        max_tokens: int,
        betas: list[str],
    ) -> BetaMessage:
        """Create a message using AWS Bedrock with retry logic.

        Raises:
            Exception: Whatever the retry helper raises once retries are
                exhausted; the failure is logged before being re-raised.
        """

        async def api_call():
            # self.client is the synchronous SDK client, so the blocking HTTP
            # request runs in a worker thread to keep the event loop free.
            response = await asyncio.to_thread(
                self.client.beta.messages.with_raw_response.create,
                max_tokens=max_tokens,
                messages=messages,
                model=self.model,
                system=system,
                tools=tools,
                betas=betas,
            )
            log_api_interaction(response.http_response.request, response.http_response, None)
            return response.parse()

        try:
            return await self._make_api_call_with_retries(api_call)
        except Exception as e:
            log_api_interaction(None, None, e)
            raise
346
-
347
-
348
class AnthropicClientFactory:
    """Builds the Anthropic client implementation that matches a provider."""

    @staticmethod
    def create_client(provider: LLMProvider, api_key: str, model: str) -> BaseAnthropicClient:
        """Return a client for ``provider``.

        Args:
            provider: Which backend to talk to.
            api_key: API key; only passed to the direct Anthropic client.
            model: Model identifier handed to the created client.

        Raises:
            ValueError: If ``provider`` is not ANTHROPIC, VERTEX or BEDROCK.
        """
        # Only the direct client takes an API key; the other two receive
        # just the model.
        if provider == LLMProvider.ANTHROPIC:
            return AnthropicDirectClient(api_key, model)

        if provider == LLMProvider.VERTEX:
            return AnthropicVertexClient(model)

        if provider == LLMProvider.BEDROCK:
            return AnthropicBedrockClient(model)

        raise ValueError(f"Unsupported provider: {provider}")