cua-agent 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of cua-agent might be problematic. Click here for more details.
- agent/README.md +63 -0
- agent/__init__.py +10 -0
- agent/core/README.md +101 -0
- agent/core/__init__.py +34 -0
- agent/core/agent.py +284 -0
- agent/core/base_agent.py +164 -0
- agent/core/callbacks.py +147 -0
- agent/core/computer_agent.py +69 -0
- agent/core/experiment.py +222 -0
- agent/core/factory.py +102 -0
- agent/core/loop.py +244 -0
- agent/core/messages.py +230 -0
- agent/core/tools/__init__.py +21 -0
- agent/core/tools/base.py +74 -0
- agent/core/tools/bash.py +52 -0
- agent/core/tools/collection.py +46 -0
- agent/core/tools/computer.py +113 -0
- agent/core/tools/edit.py +67 -0
- agent/core/tools/manager.py +56 -0
- agent/providers/__init__.py +4 -0
- agent/providers/anthropic/__init__.py +6 -0
- agent/providers/anthropic/api/client.py +222 -0
- agent/providers/anthropic/api/logging.py +150 -0
- agent/providers/anthropic/callbacks/manager.py +55 -0
- agent/providers/anthropic/loop.py +521 -0
- agent/providers/anthropic/messages/manager.py +110 -0
- agent/providers/anthropic/prompts.py +20 -0
- agent/providers/anthropic/tools/__init__.py +33 -0
- agent/providers/anthropic/tools/base.py +88 -0
- agent/providers/anthropic/tools/bash.py +163 -0
- agent/providers/anthropic/tools/collection.py +34 -0
- agent/providers/anthropic/tools/computer.py +550 -0
- agent/providers/anthropic/tools/edit.py +326 -0
- agent/providers/anthropic/tools/manager.py +54 -0
- agent/providers/anthropic/tools/run.py +42 -0
- agent/providers/anthropic/types.py +16 -0
- agent/providers/omni/__init__.py +27 -0
- agent/providers/omni/callbacks.py +78 -0
- agent/providers/omni/clients/anthropic.py +99 -0
- agent/providers/omni/clients/base.py +44 -0
- agent/providers/omni/clients/groq.py +101 -0
- agent/providers/omni/clients/openai.py +159 -0
- agent/providers/omni/clients/utils.py +25 -0
- agent/providers/omni/experiment.py +273 -0
- agent/providers/omni/image_utils.py +106 -0
- agent/providers/omni/loop.py +961 -0
- agent/providers/omni/messages.py +168 -0
- agent/providers/omni/parser.py +252 -0
- agent/providers/omni/prompts.py +78 -0
- agent/providers/omni/tool_manager.py +91 -0
- agent/providers/omni/tools/__init__.py +13 -0
- agent/providers/omni/tools/bash.py +69 -0
- agent/providers/omni/tools/computer.py +216 -0
- agent/providers/omni/tools/manager.py +83 -0
- agent/providers/omni/types.py +30 -0
- agent/providers/omni/utils.py +155 -0
- agent/providers/omni/visualization.py +130 -0
- agent/types/__init__.py +26 -0
- agent/types/base.py +52 -0
- agent/types/messages.py +36 -0
- agent/types/tools.py +32 -0
- cua_agent-0.1.0.dist-info/METADATA +44 -0
- cua_agent-0.1.0.dist-info/RECORD +65 -0
- cua_agent-0.1.0.dist-info/WHEEL +4 -0
- cua_agent-0.1.0.dist-info/entry_points.txt +4 -0
|
@@ -0,0 +1,113 @@
|
|
|
1
|
+
"""Abstract base computer tool implementation."""
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import base64
|
|
5
|
+
import io
|
|
6
|
+
import logging
|
|
7
|
+
from abc import abstractmethod
|
|
8
|
+
from typing import Any, Dict, Optional, Tuple
|
|
9
|
+
|
|
10
|
+
from PIL import Image
|
|
11
|
+
from computer.computer import Computer
|
|
12
|
+
|
|
13
|
+
from .base import BaseTool, ToolError, ToolResult
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class BaseComputerTool(BaseTool):
    """Base class for computer interaction tools across different providers.

    Holds the shared ``Computer`` handle and the screen geometry, and provides
    screenshot capture plus resizing of screenshots to the recorded screen
    dimensions. Subclasses implement ``__call__``.
    """

    name = "computer"
    logger = logging.getLogger(__name__)

    # Screen geometry; width/height are filled in by initialize_dimensions().
    width: Optional[int] = None
    height: Optional[int] = None
    display_num: Optional[int] = None
    computer: Computer

    _screenshot_delay = 1.0  # Default delay for most platforms
    _scaling_enabled = True

    def __init__(self, computer: Computer):
        """Initialize the ComputerTool.

        Args:
            computer: Computer instance for screen interactions
        """
        self.computer = computer

    async def initialize_dimensions(self):
        """Initialize screen dimensions from the computer interface.

        Reads the ``"width"`` and ``"height"`` entries of the mapping returned
        by ``computer.interface.get_screen_size()``.
        """
        display_size = await self.computer.interface.get_screen_size()
        self.width = display_size["width"]
        self.height = display_size["height"]
        self.logger.info(f"Initialized screen dimensions to {self.width}x{self.height}")

    @property
    def options(self) -> Dict[str, Any]:
        """Get the options for the tool.

        Returns:
            Dictionary with display width, display height and display number.

        Raises:
            RuntimeError: If initialize_dimensions() has not been called yet.
        """
        if self.width is None or self.height is None:
            raise RuntimeError(
                "Screen dimensions not initialized. Call initialize_dimensions() first."
            )
        return {
            "display_width_px": self.width,
            "display_height_px": self.height,
            "display_number": self.display_num,
        }

    async def resize_screenshot_if_needed(self, screenshot: bytes) -> bytes:
        """Resize a screenshot to match the expected dimensions.

        When the decoded image already has size ``(width, height)`` the
        original bytes are returned unchanged (any mode conversion done while
        inspecting it is discarded). Otherwise the image is resized with
        LANCZOS resampling and re-encoded as PNG.

        Args:
            screenshot: Raw screenshot data

        Returns:
            Resized screenshot data

        Raises:
            ToolError: If the dimensions are not initialized, or if decoding,
                resizing or re-encoding the image fails.
        """
        if self.width is None or self.height is None:
            raise ToolError("Screen dimensions not initialized")

        try:
            img = Image.open(io.BytesIO(screenshot))
            # Flatten modes that carry transparency to RGB before any resize.
            if img.mode in ("RGBA", "LA") or (img.mode == "P" and "transparency" in img.info):
                img = img.convert("RGB")

            # Resize if dimensions don't match
            if img.size != (self.width, self.height):
                self.logger.info(
                    f"Scaling image from {img.size} to {self.width}x{self.height} to match screen dimensions"
                )
                img = img.resize((self.width, self.height), Image.Resampling.LANCZOS)

                # Save back to bytes
                buffer = io.BytesIO()
                img.save(buffer, format="PNG")
                return buffer.getvalue()

            return screenshot
        except Exception as e:
            self.logger.error(f"Error during screenshot resizing: {str(e)}")
            # Chain the original exception so the root cause stays in the traceback.
            raise ToolError(f"Failed to resize screenshot: {str(e)}") from e

    async def screenshot(self) -> ToolResult:
        """Take a screenshot and return it as a ToolResult with base64-encoded image.

        Failures are not raised; they are logged and reported through the
        ``error`` field of the returned ToolResult.

        Returns:
            ToolResult with the screenshot
        """
        try:
            screenshot = await self.computer.interface.screenshot()
            screenshot = await self.resize_screenshot_if_needed(screenshot)
            return ToolResult(base64_image=base64.b64encode(screenshot).decode())
        except Exception as e:
            self.logger.error(f"Error taking screenshot: {str(e)}")
            return ToolResult(error=f"Failed to take screenshot: {str(e)}")

    @abstractmethod
    async def __call__(self, **kwargs) -> ToolResult:
        """Execute the tool with the provided arguments."""
        raise NotImplementedError
|
agent/core/tools/edit.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
"""Abstract base edit tool implementation."""
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import logging
|
|
5
|
+
import os
|
|
6
|
+
from abc import abstractmethod
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import Any, Dict, Optional
|
|
9
|
+
|
|
10
|
+
from computer.computer import Computer
|
|
11
|
+
|
|
12
|
+
from .base import BaseTool, ToolError, ToolResult
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class BaseEditTool(BaseTool):
    """Base class for text editor tools across different providers.

    Provides file read/write helpers that report failures as ``ToolError``.
    Subclasses implement ``__call__``.
    """

    name = "edit"
    logger = logging.getLogger(__name__)
    computer: Computer

    def __init__(self, computer: Computer):
        """Initialize the EditTool.

        Args:
            computer: Computer instance, may be used for related operations
        """
        self.computer = computer

    async def read_file(self, path: str) -> str:
        """Read a file and return its contents.

        Args:
            path: Path to the file to read

        Returns:
            File contents as a string

        Raises:
            ToolError: If the file does not exist ("File does not exist: …")
                or cannot be read ("Failed to read file: …").
        """
        try:
            path_obj = Path(path)
            if not path_obj.exists():
                raise ToolError(f"File does not exist: {path}")
            return path_obj.read_text()
        except Exception as e:
            self.logger.error(f"Error reading file: {str(e)}")
            if isinstance(e, ToolError):
                # Already a descriptive ToolError; re-raise it as-is instead of
                # wrapping it a second time.
                raise
            raise ToolError(f"Failed to read file: {str(e)}") from e

    async def write_file(self, path: str, content: str) -> None:
        """Write content to a file.

        Missing parent directories are created first.

        Args:
            path: Path to the file to write
            content: Content to write to the file

        Raises:
            ToolError: If the directories or the file cannot be written.
        """
        try:
            path_obj = Path(path)
            # Create parent directories if they don't exist
            path_obj.parent.mkdir(parents=True, exist_ok=True)
            path_obj.write_text(content)
        except Exception as e:
            self.logger.error(f"Error writing file: {str(e)}")
            raise ToolError(f"Failed to write file: {str(e)}") from e

    @abstractmethod
    async def __call__(self, **kwargs) -> ToolResult:
        """Execute the tool with the provided arguments."""
        raise NotImplementedError
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
"""Tool manager for initializing and running tools."""
|
|
2
|
+
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from typing import Any, Dict, List
|
|
5
|
+
|
|
6
|
+
from computer.computer import Computer
|
|
7
|
+
|
|
8
|
+
from .base import BaseTool, ToolResult
|
|
9
|
+
from .collection import ToolCollection
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class BaseToolManager(ABC):
    """Shared scaffolding for provider-specific tool managers.

    Subclasses provide the provider setup hook, the factory that builds the
    tool collection, and the API parameter description. This base wires them
    together and dispatches tool execution to the collection.
    """

    def __init__(self, computer: Computer):
        """Store the computer and start without a tool collection.

        Args:
            computer: Computer instance for computer-related tools
        """
        self.computer = computer
        # Stays None until initialize() has built the collection.
        self.tools: ToolCollection | None = None

    async def initialize(self) -> None:
        """Run provider-specific setup first, then build the tool collection."""
        await self._initialize_tools_specific()
        self.tools = self._initialize_tools()

    @abstractmethod
    def _initialize_tools(self) -> ToolCollection:
        """Build and return the collection of available tools."""
        ...

    @abstractmethod
    async def _initialize_tools_specific(self) -> None:
        """Perform any provider-specific preparation the tools need."""
        ...

    @abstractmethod
    def get_tool_params(self) -> List[Dict[str, Any]]:
        """Describe the tools in the shape the provider API expects."""
        ...

    async def execute_tool(self, name: str, tool_input: Dict[str, Any]) -> ToolResult:
        """Run the named tool from the collection.

        Args:
            name: Name of the tool to execute
            tool_input: Input parameters for the tool

        Returns:
            Whatever the collection's ``run`` returns for this tool.

        Raises:
            RuntimeError: If initialize() has not been awaited yet.
        """
        collection = self.tools
        if collection is None:
            raise RuntimeError("Tools not initialized. Call initialize() first.")
        return await collection.run(name=name, tool_input=tool_input)
|
|
@@ -0,0 +1,222 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
import httpx
|
|
3
|
+
import asyncio
|
|
4
|
+
from anthropic import Anthropic, AnthropicBedrock, AnthropicVertex
|
|
5
|
+
from anthropic.types.beta import BetaMessage, BetaMessageParam, BetaToolUnionParam
|
|
6
|
+
from ..types import APIProvider
|
|
7
|
+
from .logging import log_api_interaction
|
|
8
|
+
import random
|
|
9
|
+
import logging
|
|
10
|
+
|
|
11
|
+
logger = logging.getLogger(__name__)
|
|
12
|
+
|
|
13
|
+
class APIConnectionError(Exception):
    """Signals that an API request could not be completed, e.g. once every retry attempt has failed."""
|
|
16
|
+
|
|
17
|
+
class BaseAnthropicClient:
    """Base class for Anthropic API clients.

    Subclasses implement ``create_message``. This base provides a retry helper
    with capped exponential backoff and random jitter.
    """

    # Total number of attempts made by _make_api_call_with_retries
    # (the initial call plus MAX_RETRIES - 1 retries).
    MAX_RETRIES = 10
    INITIAL_RETRY_DELAY = 1.0  # seconds before the first retry
    MAX_RETRY_DELAY = 60.0  # cap on the pre-jitter delay, in seconds
    JITTER_FACTOR = 0.1  # delay is perturbed by up to +/- 10%

    async def create_message(
        self,
        *,
        messages: list[BetaMessageParam],
        system: list[Any],
        tools: list[BetaToolUnionParam],
        max_tokens: int,
        betas: list[str],
    ) -> BetaMessage:
        """Create a message using the Anthropic API."""
        raise NotImplementedError

    async def _make_api_call_with_retries(self, api_call):
        """Make an API call with exponential backoff retry logic.

        Every ``Exception`` raised by ``api_call`` triggers another attempt,
        up to ``MAX_RETRIES`` attempts in total. Before attempt ``k + 1`` the
        helper sleeps ``min(INITIAL_RETRY_DELAY * 2**(k-1), MAX_RETRY_DELAY)``
        seconds, scaled by a random factor in ``[1 - JITTER_FACTOR, 1 + JITTER_FACTOR]``.

        Args:
            api_call: Async function that makes the actual API call

        Returns:
            API response

        Raises:
            APIConnectionError: If every attempt fails; the last error is
                attached as ``__cause__``.
        """
        last_error = None

        for attempt in range(1, self.MAX_RETRIES + 1):
            try:
                return await api_call()
            except Exception as e:
                last_error = e

                # No sleep after the final attempt; fall through to the raise.
                if attempt == self.MAX_RETRIES:
                    break

                # Calculate delay with exponential backoff and jitter
                delay = min(
                    self.INITIAL_RETRY_DELAY * (2 ** (attempt - 1)),
                    self.MAX_RETRY_DELAY
                )
                # Add jitter to avoid thundering herd
                jitter = delay * self.JITTER_FACTOR * (2 * random.random() - 1)
                final_delay = delay + jitter

                logger.info(
                    f"Retrying request (attempt {attempt}/{self.MAX_RETRIES}) "
                    f"in {final_delay:.2f} seconds after error: {str(e)}"
                )
                await asyncio.sleep(final_delay)

        # MAX_RETRIES counts attempts, not retries, so report it as such.
        raise APIConnectionError(
            f"Failed after {self.MAX_RETRIES} attempts. "
            f"Last error: {str(last_error)}"
        ) from last_error
|
|
81
|
+
|
|
82
|
+
class AnthropicDirectClient(BaseAnthropicClient):
    """Direct Anthropic API client implementation.

    Wraps the synchronous ``Anthropic`` SDK client. Requests are executed in a
    worker thread so that awaiting ``create_message`` does not block the
    event loop.
    """

    def __init__(self, api_key: str, model: str):
        self.model = model
        self.client = Anthropic(
            api_key=api_key,
            http_client=self._create_http_client()
        )

    def _create_http_client(self) -> httpx.Client:
        """Create an HTTP client with appropriate settings.

        TLS verification on; 30 s connect/write/pool and 300 s read timeouts;
        the transport retries failed connections up to 3 times and keeps at
        most 10 connections (5 keep-alive).
        """
        return httpx.Client(
            verify=True,
            timeout=httpx.Timeout(
                connect=30.0,
                read=300.0,
                write=30.0,
                pool=30.0
            ),
            transport=httpx.HTTPTransport(
                retries=3,
                verify=True,
                limits=httpx.Limits(
                    max_keepalive_connections=5,
                    max_connections=10
                )
            )
        )

    async def create_message(
        self,
        *,
        messages: list[BetaMessageParam],
        system: list[Any],
        tools: list[BetaToolUnionParam],
        max_tokens: int,
        betas: list[str],
    ) -> BetaMessage:
        """Create a message using the direct Anthropic API with retry logic.

        Successful HTTP exchanges are passed to ``log_api_interaction``; the
        final error (after retries are exhausted) is logged and re-raised.
        """
        async def api_call():
            # self.client is the synchronous SDK client: run the blocking HTTP
            # request in a worker thread instead of stalling the event loop.
            response = await asyncio.to_thread(
                self.client.beta.messages.with_raw_response.create,
                max_tokens=max_tokens,
                messages=messages,
                model=self.model,
                system=system,
                tools=tools,
                betas=betas,
            )
            log_api_interaction(response.http_response.request, response.http_response, None)
            return response.parse()

        try:
            return await self._make_api_call_with_retries(api_call)
        except Exception as e:
            log_api_interaction(None, None, e)
            raise
|
|
139
|
+
|
|
140
|
+
class AnthropicVertexClient(BaseAnthropicClient):
    """Google Cloud Vertex AI implementation of Anthropic client.

    Wraps the synchronous ``AnthropicVertex`` SDK client; requests run in a
    worker thread so awaiting ``create_message`` does not block the event loop.
    """

    def __init__(self, model: str):
        self.model = model
        self.client = AnthropicVertex()

    async def create_message(
        self,
        *,
        messages: list[BetaMessageParam],
        system: list[Any],
        tools: list[BetaToolUnionParam],
        max_tokens: int,
        betas: list[str],
    ) -> BetaMessage:
        """Create a message using Vertex AI with retry logic.

        Successful HTTP exchanges are passed to ``log_api_interaction``; the
        final error (after retries are exhausted) is logged and re-raised.
        """
        async def api_call():
            # The SDK call is blocking; offload it so the event loop keeps running.
            response = await asyncio.to_thread(
                self.client.beta.messages.with_raw_response.create,
                max_tokens=max_tokens,
                messages=messages,
                model=self.model,
                system=system,
                tools=tools,
                betas=betas,
            )
            log_api_interaction(response.http_response.request, response.http_response, None)
            return response.parse()

        try:
            return await self._make_api_call_with_retries(api_call)
        except Exception as e:
            log_api_interaction(None, None, e)
            raise
|
|
174
|
+
|
|
175
|
+
class AnthropicBedrockClient(BaseAnthropicClient):
    """AWS Bedrock implementation of Anthropic client.

    Wraps the synchronous ``AnthropicBedrock`` SDK client; requests run in a
    worker thread so awaiting ``create_message`` does not block the event loop.
    """

    def __init__(self, model: str):
        self.model = model
        self.client = AnthropicBedrock()

    async def create_message(
        self,
        *,
        messages: list[BetaMessageParam],
        system: list[Any],
        tools: list[BetaToolUnionParam],
        max_tokens: int,
        betas: list[str],
    ) -> BetaMessage:
        """Create a message using AWS Bedrock with retry logic.

        Successful HTTP exchanges are passed to ``log_api_interaction``; the
        final error (after retries are exhausted) is logged and re-raised.
        """
        async def api_call():
            # The SDK call is blocking; offload it so the event loop keeps running.
            response = await asyncio.to_thread(
                self.client.beta.messages.with_raw_response.create,
                max_tokens=max_tokens,
                messages=messages,
                model=self.model,
                system=system,
                tools=tools,
                betas=betas,
            )
            log_api_interaction(response.http_response.request, response.http_response, None)
            return response.parse()

        try:
            return await self._make_api_call_with_retries(api_call)
        except Exception as e:
            log_api_interaction(None, None, e)
            raise
|
|
209
|
+
|
|
210
|
+
class AnthropicClientFactory:
    """Builds the Anthropic client implementation that matches an API provider."""

    @staticmethod
    def create_client(provider: APIProvider, api_key: str, model: str) -> BaseAnthropicClient:
        """Return a client for ``provider``.

        Only the direct Anthropic client receives ``api_key``; the Vertex and
        Bedrock clients are constructed from ``model`` alone.

        Raises:
            ValueError: If ``provider`` matches none of the supported providers.
        """
        if provider == APIProvider.ANTHROPIC:
            return AnthropicDirectClient(api_key, model)
        if provider == APIProvider.VERTEX:
            return AnthropicVertexClient(model)
        if provider == APIProvider.BEDROCK:
            return AnthropicBedrockClient(model)
        raise ValueError(f"Unsupported provider: {provider}")
|
|
@@ -0,0 +1,150 @@
|
|
|
1
|
+
"""API logging functionality."""
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import logging
|
|
5
|
+
from datetime import datetime
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
import httpx
|
|
8
|
+
from typing import Any
|
|
9
|
+
|
|
10
|
+
logger = logging.getLogger(__name__)
|
|
11
|
+
|
|
12
|
+
def _filter_base64_images(content: Any) -> Any:
|
|
13
|
+
"""Filter out base64 image data from content.
|
|
14
|
+
|
|
15
|
+
Args:
|
|
16
|
+
content: Content to filter
|
|
17
|
+
|
|
18
|
+
Returns:
|
|
19
|
+
Filtered content with base64 data replaced by placeholder
|
|
20
|
+
"""
|
|
21
|
+
if isinstance(content, dict):
|
|
22
|
+
filtered = {}
|
|
23
|
+
for key, value in content.items():
|
|
24
|
+
if (
|
|
25
|
+
isinstance(value, dict)
|
|
26
|
+
and value.get("type") == "image"
|
|
27
|
+
and value.get("source", {}).get("type") == "base64"
|
|
28
|
+
):
|
|
29
|
+
# Replace base64 data with placeholder
|
|
30
|
+
filtered[key] = {
|
|
31
|
+
**value,
|
|
32
|
+
"source": {
|
|
33
|
+
**value["source"],
|
|
34
|
+
"data": "<base64_image_data>"
|
|
35
|
+
}
|
|
36
|
+
}
|
|
37
|
+
else:
|
|
38
|
+
filtered[key] = _filter_base64_images(value)
|
|
39
|
+
return filtered
|
|
40
|
+
elif isinstance(content, list):
|
|
41
|
+
return [_filter_base64_images(item) for item in content]
|
|
42
|
+
return content
|
|
43
|
+
|
|
44
|
+
def log_api_interaction(
    request: httpx.Request | None,
    response: httpx.Response | object | None,
    error: Exception | None,
    log_dir: Path = Path("/tmp/claude_logs")
) -> None:
    """Log API request, response, and any errors in a structured way.

    Writes one JSON file per call into ``log_dir`` and a one-line summary to
    the module logger. Base64 image payloads are replaced by a placeholder and
    credential-bearing headers are redacted before anything is written.

    Writing the file is best-effort: if the directory cannot be created or the
    file cannot be written, a warning is logged instead of raising. This
    matters because this function is called from inside the retried API-call
    closures, where an exception would make a successful call look failed.

    Args:
        request: The HTTP request if available
        response: The HTTP response or response object
        error: Any error that occurred
        log_dir: Directory to store log files (created, with parents, if missing)
    """
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")

    # Lower-case header names whose values must never be written to disk.
    sensitive_headers = {
        "authorization",
        "proxy-authorization",
        "x-api-key",
        "x-amz-security-token",
        "cookie",
        "set-cookie",
    }

    # Helper function to safely decode JSON content
    def safe_json_decode(content):
        """Decode bytes/str as JSON; pass dicts through; anything else -> None."""
        if not content:
            return None
        try:
            if isinstance(content, bytes):
                return json.loads(content.decode())
            elif isinstance(content, str):
                return json.loads(content)
            elif isinstance(content, dict):
                return content
            return None
        except (json.JSONDecodeError, UnicodeDecodeError):
            return {"error": "Could not decode JSON", "raw": str(content)}

    def redact_headers(headers):
        """Copy headers into a plain dict, masking credential values."""
        return {
            key: "<redacted>" if key.lower() in sensitive_headers else value
            for key, value in dict(headers).items()
        }

    # Process request content
    request_content = None
    if request and request.content:
        request_content = safe_json_decode(request.content)
        request_content = _filter_base64_images(request_content)

    # Process response content
    is_http_response = isinstance(response, httpx.Response)
    response_content = None
    if response:
        if is_http_response:
            try:
                response_content = response.json()
            except (json.JSONDecodeError, UnicodeDecodeError):
                response_content = {"error": "Could not decode JSON", "raw": response.text}
        else:
            response_content = safe_json_decode(response)
        response_content = _filter_base64_images(response_content)

    log_entry = {
        "timestamp": timestamp,
        "request": {
            "method": request.method,
            "url": str(request.url),
            "headers": redact_headers(request.headers),
            "content": request_content,
        } if request else None,
        "response": {
            "status_code": response.status_code if is_http_response else None,
            "headers": redact_headers(response.headers) if is_http_response else None,
            "content": response_content,
        } if response else None,
        "error": {
            "type": type(error).__name__,
            "message": str(error),
        } if error else None
    }

    # Log to file with timestamp in filename (best-effort, see docstring).
    try:
        log_dir.mkdir(parents=True, exist_ok=True)
        log_file = log_dir / f"claude_api_{timestamp.replace(' ', '_').replace(':', '-')}.json"
        with open(log_file, 'w') as f:
            json.dump(log_entry, f, indent=2)
    except (OSError, TypeError, ValueError) as write_error:
        logger.warning(f"Could not write API log file in {log_dir}: {write_error}")

    # Also log a summary to the console
    if error:
        logger.error(f"API Error at {timestamp}: {error}")
    else:
        logger.info(
            f"API Call at {timestamp}: "
            f"{request.method if request else 'No request'} -> "
            f"{response.status_code if is_http_response else 'No response'}"
        )

    # Log if there are any images in the content
    if response_content:
        image_count = count_images(response_content)
        if image_count > 0:
            logger.info(f"Response contains {image_count} images")
|
|
134
|
+
|
|
135
|
+
def count_images(content: dict | list | Any) -> int:
|
|
136
|
+
"""Count the number of images in the content.
|
|
137
|
+
|
|
138
|
+
Args:
|
|
139
|
+
content: Content to search for images
|
|
140
|
+
|
|
141
|
+
Returns:
|
|
142
|
+
Number of images found
|
|
143
|
+
"""
|
|
144
|
+
if isinstance(content, dict):
|
|
145
|
+
if content.get("type") == "image":
|
|
146
|
+
return 1
|
|
147
|
+
return sum(count_images(v) for v in content.values())
|
|
148
|
+
elif isinstance(content, list):
|
|
149
|
+
return sum(count_images(item) for item in content)
|
|
150
|
+
return 0
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
from typing import Callable, Protocol
|
|
2
|
+
import httpx
|
|
3
|
+
from anthropic.types.beta import BetaContentBlockParam
|
|
4
|
+
from ..tools import ToolResult
|
|
5
|
+
|
|
6
|
+
class APICallback(Protocol):
    """Structural type for callables that receive an API request, response and error."""

    def __call__(
        self,
        request: httpx.Request | None,
        response: httpx.Response | object | None,
        error: Exception | None,
    ) -> None:
        ...
|
|
11
|
+
|
|
12
|
+
class ContentCallback(Protocol):
    """Structural type for callables that receive a content block."""

    def __call__(self, content: BetaContentBlockParam) -> None:
        ...
|
|
15
|
+
|
|
16
|
+
class ToolCallback(Protocol):
    """Structural type for callables that receive a tool result and its tool id."""

    def __call__(self, result: ToolResult, tool_id: str) -> None:
        ...
|
|
19
|
+
|
|
20
|
+
class CallbackManager:
    """Forwards agent events to the content, tool and API callbacks it was built with."""

    def __init__(
        self,
        content_callback: ContentCallback,
        tool_callback: ToolCallback,
        api_callback: APICallback,
    ):
        """Keep references to the three callbacks.

        Args:
            content_callback: Invoked by on_content with each content block
            tool_callback: Invoked by on_tool_result with a result and tool id
            api_callback: Invoked by on_api_interaction with request/response/error
        """
        self.content_callback = content_callback
        self.tool_callback = tool_callback
        self.api_callback = api_callback

    def on_content(self, content: BetaContentBlockParam) -> None:
        """Pass a content block to the content callback."""
        notify = self.content_callback
        notify(content)

    def on_tool_result(self, result: ToolResult, tool_id: str) -> None:
        """Pass a tool result and its tool id to the tool callback."""
        notify = self.tool_callback
        notify(result, tool_id)

    def on_api_interaction(
        self,
        request: httpx.Request | None,
        response: httpx.Response | object | None,
        error: Exception | None
    ) -> None:
        """Pass the request, response and error to the API callback."""
        notify = self.api_callback
        notify(request, response, error)
|