cua-agent 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of cua-agent might be problematic. Click here for more details.

Files changed (65) hide show
  1. agent/README.md +63 -0
  2. agent/__init__.py +10 -0
  3. agent/core/README.md +101 -0
  4. agent/core/__init__.py +34 -0
  5. agent/core/agent.py +284 -0
  6. agent/core/base_agent.py +164 -0
  7. agent/core/callbacks.py +147 -0
  8. agent/core/computer_agent.py +69 -0
  9. agent/core/experiment.py +222 -0
  10. agent/core/factory.py +102 -0
  11. agent/core/loop.py +244 -0
  12. agent/core/messages.py +230 -0
  13. agent/core/tools/__init__.py +21 -0
  14. agent/core/tools/base.py +74 -0
  15. agent/core/tools/bash.py +52 -0
  16. agent/core/tools/collection.py +46 -0
  17. agent/core/tools/computer.py +113 -0
  18. agent/core/tools/edit.py +67 -0
  19. agent/core/tools/manager.py +56 -0
  20. agent/providers/__init__.py +4 -0
  21. agent/providers/anthropic/__init__.py +6 -0
  22. agent/providers/anthropic/api/client.py +222 -0
  23. agent/providers/anthropic/api/logging.py +150 -0
  24. agent/providers/anthropic/callbacks/manager.py +55 -0
  25. agent/providers/anthropic/loop.py +521 -0
  26. agent/providers/anthropic/messages/manager.py +110 -0
  27. agent/providers/anthropic/prompts.py +20 -0
  28. agent/providers/anthropic/tools/__init__.py +33 -0
  29. agent/providers/anthropic/tools/base.py +88 -0
  30. agent/providers/anthropic/tools/bash.py +163 -0
  31. agent/providers/anthropic/tools/collection.py +34 -0
  32. agent/providers/anthropic/tools/computer.py +550 -0
  33. agent/providers/anthropic/tools/edit.py +326 -0
  34. agent/providers/anthropic/tools/manager.py +54 -0
  35. agent/providers/anthropic/tools/run.py +42 -0
  36. agent/providers/anthropic/types.py +16 -0
  37. agent/providers/omni/__init__.py +27 -0
  38. agent/providers/omni/callbacks.py +78 -0
  39. agent/providers/omni/clients/anthropic.py +99 -0
  40. agent/providers/omni/clients/base.py +44 -0
  41. agent/providers/omni/clients/groq.py +101 -0
  42. agent/providers/omni/clients/openai.py +159 -0
  43. agent/providers/omni/clients/utils.py +25 -0
  44. agent/providers/omni/experiment.py +273 -0
  45. agent/providers/omni/image_utils.py +106 -0
  46. agent/providers/omni/loop.py +961 -0
  47. agent/providers/omni/messages.py +168 -0
  48. agent/providers/omni/parser.py +252 -0
  49. agent/providers/omni/prompts.py +78 -0
  50. agent/providers/omni/tool_manager.py +91 -0
  51. agent/providers/omni/tools/__init__.py +13 -0
  52. agent/providers/omni/tools/bash.py +69 -0
  53. agent/providers/omni/tools/computer.py +216 -0
  54. agent/providers/omni/tools/manager.py +83 -0
  55. agent/providers/omni/types.py +30 -0
  56. agent/providers/omni/utils.py +155 -0
  57. agent/providers/omni/visualization.py +130 -0
  58. agent/types/__init__.py +26 -0
  59. agent/types/base.py +52 -0
  60. agent/types/messages.py +36 -0
  61. agent/types/tools.py +32 -0
  62. cua_agent-0.1.0.dist-info/METADATA +44 -0
  63. cua_agent-0.1.0.dist-info/RECORD +65 -0
  64. cua_agent-0.1.0.dist-info/WHEEL +4 -0
  65. cua_agent-0.1.0.dist-info/entry_points.txt +4 -0
@@ -0,0 +1,113 @@
1
+ """Abstract base computer tool implementation."""
2
+
3
+ import asyncio
4
+ import base64
5
+ import io
6
+ import logging
7
+ from abc import abstractmethod
8
+ from typing import Any, Dict, Optional, Tuple
9
+
10
+ from PIL import Image
11
+ from computer.computer import Computer
12
+
13
+ from .base import BaseTool, ToolError, ToolResult
14
+
15
+
16
class BaseComputerTool(BaseTool):
    """Base class for computer interaction tools across different providers.

    Provides screen-dimension discovery, screenshot capture, and rescaling of
    screenshots to the dimensions reported by the computer interface.
    Subclasses implement ``__call__``.
    """

    name = "computer"
    logger = logging.getLogger(__name__)

    width: Optional[int] = None
    height: Optional[int] = None
    display_num: Optional[int] = None
    computer: Computer

    _screenshot_delay = 1.0  # Default delay for most platforms
    _scaling_enabled = True

    def __init__(self, computer: Computer):
        """Initialize the ComputerTool.

        Args:
            computer: Computer instance for screen interactions
        """
        self.computer = computer

    async def initialize_dimensions(self):
        """Initialize screen dimensions from the computer interface.

        Reads ``width``/``height`` from the mapping returned by
        ``computer.interface.get_screen_size()``.
        """
        display_size = await self.computer.interface.get_screen_size()
        self.width = display_size["width"]
        self.height = display_size["height"]
        self.logger.info("Initialized screen dimensions to %sx%s", self.width, self.height)

    @property
    def options(self) -> Dict[str, Any]:
        """Get the options for the tool.

        Returns:
            Dictionary with tool options

        Raises:
            RuntimeError: If ``initialize_dimensions()`` has not been awaited yet.
        """
        if self.width is None or self.height is None:
            raise RuntimeError(
                "Screen dimensions not initialized. Call initialize_dimensions() first."
            )
        return {
            "display_width_px": self.width,
            "display_height_px": self.height,
            "display_number": self.display_num,
        }

    async def resize_screenshot_if_needed(self, screenshot: bytes) -> bytes:
        """Resize a screenshot to match the expected dimensions.

        Args:
            screenshot: Raw screenshot data in any format Pillow can open

        Returns:
            The original bytes unchanged when the image is already
            ``(self.width, self.height)``; otherwise PNG bytes of the image
            resized with LANCZOS (flattened to RGB first if it has transparency).

        Raises:
            ToolError: If dimensions are not initialized, or the image cannot be
                decoded or resized.
        """
        if self.width is None or self.height is None:
            raise ToolError("Screen dimensions not initialized")

        target_size = (self.width, self.height)
        try:
            with Image.open(io.BytesIO(screenshot)) as original:
                # Image.open is lazy, so reading .size is cheap. When no resize is
                # needed, return the input as-is instead of decoding/converting
                # pixels whose result would be discarded anyway.
                if original.size == target_size:
                    return screenshot

                self.logger.info(
                    f"Scaling image from {original.size} to {self.width}x{self.height} to match screen dimensions"
                )
                img = original
                if img.mode in ("RGBA", "LA") or (img.mode == "P" and "transparency" in img.info):
                    img = img.convert("RGB")
                resized = img.resize(target_size, Image.Resampling.LANCZOS)

            # Save back to bytes
            buffer = io.BytesIO()
            resized.save(buffer, format="PNG")
            return buffer.getvalue()
        except Exception as e:
            self.logger.error(f"Error during screenshot resizing: {str(e)}")
            raise ToolError(f"Failed to resize screenshot: {str(e)}") from e

    async def screenshot(self) -> ToolResult:
        """Take a screenshot and return it as a ToolResult with base64-encoded image.

        Failures are reported through ``ToolResult.error`` rather than raised.

        Returns:
            ToolResult with the screenshot, or with ``error`` set on failure
        """
        try:
            screenshot = await self.computer.interface.screenshot()
            screenshot = await self.resize_screenshot_if_needed(screenshot)
            return ToolResult(base64_image=base64.b64encode(screenshot).decode())
        except Exception as e:
            self.logger.error(f"Error taking screenshot: {str(e)}")
            return ToolResult(error=f"Failed to take screenshot: {str(e)}")

    @abstractmethod
    async def __call__(self, **kwargs) -> ToolResult:
        """Execute the tool with the provided arguments."""
        raise NotImplementedError
@@ -0,0 +1,67 @@
1
+ """Abstract base edit tool implementation."""
2
+
3
+ import asyncio
4
+ import logging
5
+ import os
6
+ from abc import abstractmethod
7
+ from pathlib import Path
8
+ from typing import Any, Dict, Optional
9
+
10
+ from computer.computer import Computer
11
+
12
+ from .base import BaseTool, ToolError, ToolResult
13
+
14
+
15
class BaseEditTool(BaseTool):
    """Base class for text editor tools across different providers.

    Provides file read/write helpers that report failures as ``ToolError``.
    Subclasses implement ``__call__``.
    """

    name = "edit"
    logger = logging.getLogger(__name__)
    computer: Computer

    def __init__(self, computer: Computer):
        """Initialize the EditTool.

        Args:
            computer: Computer instance, may be used for related operations
        """
        self.computer = computer

    async def read_file(self, path: str) -> str:
        """Read a file and return its contents.

        Args:
            path: Path to the file to read

        Returns:
            File contents as a string, decoded with the platform default
            encoding (no explicit encoding is passed to ``read_text``)

        Raises:
            ToolError: If the file does not exist or cannot be read.
        """
        try:
            path_obj = Path(path)
            if not path_obj.exists():
                raise ToolError(f"File does not exist: {path}")
            return path_obj.read_text()
        except ToolError:
            # Already a descriptive tool error: propagate it unchanged instead of
            # letting the generic handler below wrap it a second time.
            raise
        except Exception as e:
            self.logger.error(f"Error reading file: {str(e)}")
            raise ToolError(f"Failed to read file: {str(e)}") from e

    async def write_file(self, path: str, content: str) -> None:
        """Write content to a file, creating parent directories as needed.

        Args:
            path: Path to the file to write
            content: Content to write to the file

        Raises:
            ToolError: If the directories or the file cannot be written.
        """
        try:
            path_obj = Path(path)
            # Create parent directories if they don't exist
            path_obj.parent.mkdir(parents=True, exist_ok=True)
            path_obj.write_text(content)
        except Exception as e:
            self.logger.error(f"Error writing file: {str(e)}")
            raise ToolError(f"Failed to write file: {str(e)}") from e

    @abstractmethod
    async def __call__(self, **kwargs) -> ToolResult:
        """Execute the tool with the provided arguments."""
        raise NotImplementedError
@@ -0,0 +1,56 @@
1
+ """Tool manager for initializing and running tools."""
2
+
3
+ from abc import ABC, abstractmethod
4
+ from typing import Any, Dict, List
5
+
6
+ from computer.computer import Computer
7
+
8
+ from .base import BaseTool, ToolResult
9
+ from .collection import ToolCollection
10
+
11
+
12
class BaseToolManager(ABC):
    """Abstract coordinator that builds a provider's tools and dispatches calls to them.

    ``initialize()`` must be awaited before ``execute_tool()`` can be used.
    """

    def __init__(self, computer: Computer):
        """Store the computer and start without a tool collection.

        Args:
            computer: Computer instance for computer-related tools
        """
        self.computer = computer
        # Filled in by initialize(); None means the tools are not ready yet.
        self.tools: ToolCollection | None = None

    @abstractmethod
    def _initialize_tools(self) -> ToolCollection:
        """Build and return the collection of every available tool."""

    async def initialize(self) -> None:
        """Run provider-specific setup first, then build the tool collection."""
        await self._initialize_tools_specific()
        collection = self._initialize_tools()
        self.tools = collection

    @abstractmethod
    async def _initialize_tools_specific(self) -> None:
        """Perform whatever provider-specific preparation the tools require."""

    @abstractmethod
    def get_tool_params(self) -> List[Dict[str, Any]]:
        """Return the tool parameter definitions that accompany API calls."""

    async def execute_tool(self, name: str, tool_input: Dict[str, Any]) -> ToolResult:
        """Dispatch a call to the named tool.

        Args:
            name: Name of the tool to execute
            tool_input: Input parameters for the tool

        Returns:
            Result produced by the tool collection

        Raises:
            RuntimeError: If ``initialize()`` has not completed yet.
        """
        tools = self.tools
        if tools is None:
            raise RuntimeError("Tools not initialized. Call initialize() first.")
        return await tools.run(name=name, tool_input=tool_input)
@@ -0,0 +1,4 @@
1
+ """Provider implementations for different AI services."""
2
+
3
# Provider subpackages are deliberately not imported here so that importing
# ``agent.providers`` cannot trigger circular imports; each provider package
# declares its own exports.
__all__ = []
@@ -0,0 +1,6 @@
1
+ """Anthropic provider implementation."""
2
+
3
+ from .loop import AnthropicLoop
4
+ from .types import APIProvider
5
+
6
# Public API of the Anthropic provider package.
__all__ = [
    "AnthropicLoop",
    "APIProvider",
]
@@ -0,0 +1,222 @@
1
+ from typing import Any
2
+ import httpx
3
+ import asyncio
4
+ from anthropic import Anthropic, AnthropicBedrock, AnthropicVertex
5
+ from anthropic.types.beta import BetaMessage, BetaMessageParam, BetaToolUnionParam
6
+ from ..types import APIProvider
7
+ from .logging import log_api_interaction
8
+ import random
9
+ import logging
10
+
11
# Module logger, named after this module.
logger = logging.getLogger(__name__)


class APIConnectionError(Exception):
    """Raised when calls to the API cannot be completed, e.g. due to connection problems."""
16
+
17
class BaseAnthropicClient:
    """Base class for Anthropic API clients.

    Subclasses implement ``create_message``; this base supplies the shared
    exponential-backoff retry helper.
    """

    MAX_RETRIES = 10  # Total number of attempts made by _make_api_call_with_retries
    INITIAL_RETRY_DELAY = 1.0
    MAX_RETRY_DELAY = 60.0
    JITTER_FACTOR = 0.1

    async def create_message(
        self,
        *,
        messages: list[BetaMessageParam],
        system: list[Any],
        tools: list[BetaToolUnionParam],
        max_tokens: int,
        betas: list[str],
    ) -> BetaMessage:
        """Create a message using the Anthropic API."""
        raise NotImplementedError

    async def _make_api_call_with_retries(self, api_call):
        """Make an API call with exponential backoff retry logic.

        Every exception raised by ``api_call`` is treated as retryable. The delay
        before retry *k* is ``INITIAL_RETRY_DELAY * 2**(k-1)`` capped at
        ``MAX_RETRY_DELAY``, plus or minus up to ``JITTER_FACTOR`` of that value.

        Args:
            api_call: Zero-argument async function that makes the actual API call

        Returns:
            API response

        Raises:
            APIConnectionError: If every attempt fails; chained to the last error.
        """
        # Always make at least one attempt, even if MAX_RETRIES is misconfigured.
        attempts = max(1, self.MAX_RETRIES)
        last_error = None

        for attempt in range(1, attempts + 1):
            try:
                return await api_call()
            except Exception as e:
                last_error = e
                if attempt == attempts:
                    break

                # Calculate delay with exponential backoff and jitter
                delay = min(
                    self.INITIAL_RETRY_DELAY * (2 ** (attempt - 1)),
                    self.MAX_RETRY_DELAY,
                )
                # Add jitter to avoid thundering herd
                jitter = delay * self.JITTER_FACTOR * (2 * random.random() - 1)
                final_delay = delay + jitter

                logger.info(
                    f"Retrying request (attempt {attempt}/{attempts}) "
                    f"in {final_delay:.2f} seconds after error: {str(e)}"
                )
                await asyncio.sleep(final_delay)

        # Chain the last failure so its traceback is preserved for callers.
        raise APIConnectionError(
            f"Failed after {attempts} attempts. "
            f"Last error: {str(last_error)}"
        ) from last_error
81
+
82
class AnthropicDirectClient(BaseAnthropicClient):
    """Direct Anthropic API client implementation."""

    def __init__(self, api_key: str, model: str):
        """Create a client for the first-party Anthropic API.

        Args:
            api_key: Anthropic API key
            model: Model name sent with every request
        """
        self.model = model
        self.client = Anthropic(
            api_key=api_key,
            http_client=self._create_http_client()
        )

    def _create_http_client(self) -> httpx.Client:
        """Create an HTTP client with appropriate settings."""
        return httpx.Client(
            verify=True,
            timeout=httpx.Timeout(
                connect=30.0,
                read=300.0,
                write=30.0,
                pool=30.0
            ),
            transport=httpx.HTTPTransport(
                retries=3,
                verify=True,
                limits=httpx.Limits(
                    max_keepalive_connections=5,
                    max_connections=10
                )
            )
        )

    async def create_message(
        self,
        *,
        messages: list[BetaMessageParam],
        system: list[Any],
        tools: list[BetaToolUnionParam],
        max_tokens: int,
        betas: list[str],
    ) -> BetaMessage:
        """Create a message using the direct Anthropic API with retry logic.

        ``self.client`` is the synchronous SDK client, so each attempt runs in a
        worker thread. This keeps the event loop responsive during the request,
        which can take up to the 300 s read timeout.

        Raises:
            APIConnectionError: If all retry attempts fail.
        """
        def send_request():
            # Blocking HTTP round trip; executed via asyncio.to_thread below.
            response = self.client.beta.messages.with_raw_response.create(
                max_tokens=max_tokens,
                messages=messages,
                model=self.model,
                system=system,
                tools=tools,
                betas=betas,
            )
            log_api_interaction(response.http_response.request, response.http_response, None)
            return response.parse()

        async def api_call():
            return await asyncio.to_thread(send_request)

        try:
            return await self._make_api_call_with_retries(api_call)
        except Exception as e:
            log_api_interaction(None, None, e)
            raise
139
+
140
class AnthropicVertexClient(BaseAnthropicClient):
    """Google Cloud Vertex AI implementation of Anthropic client."""

    def __init__(self, model: str):
        """Create a Vertex AI client.

        ``AnthropicVertex`` is constructed without arguments, so its project and
        region settings come from the SDK's own defaults/environment.

        Args:
            model: Model name sent with every request
        """
        self.model = model
        self.client = AnthropicVertex()

    async def create_message(
        self,
        *,
        messages: list[BetaMessageParam],
        system: list[Any],
        tools: list[BetaToolUnionParam],
        max_tokens: int,
        betas: list[str],
    ) -> BetaMessage:
        """Create a message using Vertex AI with retry logic.

        ``self.client`` is the synchronous SDK client, so each attempt runs in a
        worker thread to avoid blocking the event loop.

        Raises:
            APIConnectionError: If all retry attempts fail.
        """
        def send_request():
            # Blocking HTTP round trip; executed via asyncio.to_thread below.
            response = self.client.beta.messages.with_raw_response.create(
                max_tokens=max_tokens,
                messages=messages,
                model=self.model,
                system=system,
                tools=tools,
                betas=betas,
            )
            log_api_interaction(response.http_response.request, response.http_response, None)
            return response.parse()

        async def api_call():
            return await asyncio.to_thread(send_request)

        try:
            return await self._make_api_call_with_retries(api_call)
        except Exception as e:
            log_api_interaction(None, None, e)
            raise
174
+
175
class AnthropicBedrockClient(BaseAnthropicClient):
    """AWS Bedrock implementation of Anthropic client."""

    def __init__(self, model: str):
        """Create an AWS Bedrock client.

        ``AnthropicBedrock`` is constructed without arguments, so its credentials
        and region come from the SDK's own defaults/environment.

        Args:
            model: Model name sent with every request
        """
        self.model = model
        self.client = AnthropicBedrock()

    async def create_message(
        self,
        *,
        messages: list[BetaMessageParam],
        system: list[Any],
        tools: list[BetaToolUnionParam],
        max_tokens: int,
        betas: list[str],
    ) -> BetaMessage:
        """Create a message using AWS Bedrock with retry logic.

        ``self.client`` is the synchronous SDK client, so each attempt runs in a
        worker thread to avoid blocking the event loop.

        Raises:
            APIConnectionError: If all retry attempts fail.
        """
        def send_request():
            # Blocking HTTP round trip; executed via asyncio.to_thread below.
            response = self.client.beta.messages.with_raw_response.create(
                max_tokens=max_tokens,
                messages=messages,
                model=self.model,
                system=system,
                tools=tools,
                betas=betas,
            )
            log_api_interaction(response.http_response.request, response.http_response, None)
            return response.parse()

        async def api_call():
            return await asyncio.to_thread(send_request)

        try:
            return await self._make_api_call_with_retries(api_call)
        except Exception as e:
            log_api_interaction(None, None, e)
            raise
209
+
210
+ class AnthropicClientFactory:
211
+ """Factory for creating appropriate Anthropic client implementations."""
212
+
213
+ @staticmethod
214
+ def create_client(provider: APIProvider, api_key: str, model: str) -> BaseAnthropicClient:
215
+ """Create an appropriate client based on the provider."""
216
+ if provider == APIProvider.ANTHROPIC:
217
+ return AnthropicDirectClient(api_key, model)
218
+ elif provider == APIProvider.VERTEX:
219
+ return AnthropicVertexClient(model)
220
+ elif provider == APIProvider.BEDROCK:
221
+ return AnthropicBedrockClient(model)
222
+ raise ValueError(f"Unsupported provider: {provider}")
@@ -0,0 +1,150 @@
1
+ """API logging functionality."""
2
+
3
+ import json
4
+ import logging
5
+ from datetime import datetime
6
+ from pathlib import Path
7
+ import httpx
8
+ from typing import Any
9
+
10
logger = logging.getLogger(__name__)


def _filter_base64_images(content: Any) -> Any:
    """Filter out base64 image data from content.

    Recursively walks dicts and lists. Any dict of the form
    ``{"type": "image", "source": {"type": "base64", ...}}``, whether it is the
    top-level value, a dict value, or a list element, is copied with
    ``source["data"]`` replaced by a placeholder. The input is never mutated.

    Args:
        content: Content to filter

    Returns:
        Filtered content with base64 data replaced by placeholder
    """
    if isinstance(content, dict):
        source = content.get("source")
        if (
            content.get("type") == "image"
            and isinstance(source, dict)
            and source.get("type") == "base64"
        ):
            # Replace base64 data with placeholder
            return {**content, "source": {**source, "data": "<base64_image_data>"}}
        return {key: _filter_base64_images(value) for key, value in content.items()}
    if isinstance(content, list):
        return [_filter_base64_images(item) for item in content]
    return content
43
+
44
# Header names (compared case-insensitively) whose values carry credentials and
# must never be written to the on-disk API logs.
_SENSITIVE_HEADERS = frozenset(
    {
        "authorization",
        "proxy-authorization",
        "x-api-key",
        "cookie",
        "set-cookie",
        "x-amz-security-token",
    }
)


def _redact_headers(headers: Any) -> dict:
    """Return a plain-dict copy of ``headers`` with credential values masked."""
    return {
        key: "<redacted>" if str(key).lower() in _SENSITIVE_HEADERS else value
        for key, value in dict(headers).items()
    }


def _safe_json_decode(content: Any) -> Any:
    """Decode JSON from bytes/str, pass dicts through, and return None otherwise.

    Undecodable input yields ``{"error": ..., "raw": str(content)}`` instead of
    raising.
    """
    if not content:
        return None
    try:
        if isinstance(content, bytes):
            return json.loads(content.decode())
        if isinstance(content, str):
            return json.loads(content)
        if isinstance(content, dict):
            return content
        return None
    except (json.JSONDecodeError, UnicodeDecodeError):
        return {"error": "Could not decode JSON", "raw": str(content)}


def log_api_interaction(
    request: httpx.Request | None,
    response: httpx.Response | object | None,
    error: Exception | None,
    log_dir: Path = Path("/tmp/claude_logs")
) -> None:
    """Log API request, response, and any errors in a structured way.

    Writes one JSON file per call into ``log_dir`` and a one-line summary to the
    module logger. Before anything is written, base64 image payloads are
    replaced with a placeholder and credential-bearing headers (API keys,
    authorization, cookies) are masked. Failing to write the log file only
    produces a warning, so logging can never fail the caller's request.

    Args:
        request: The HTTP request if available
        response: The HTTP response or response object
        error: Any error that occurred
        log_dir: Directory to store log files (created if missing)
    """
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")
    is_http_response = isinstance(response, httpx.Response)

    # Process request content
    request_entry = None
    if request:
        request_content = None
        if request.content:
            request_content = _filter_base64_images(_safe_json_decode(request.content))
        request_entry = {
            "method": request.method,
            "url": str(request.url),
            "headers": _redact_headers(request.headers),
            "content": request_content,
        }

    # Process response content
    response_content = None
    response_entry = None
    if response:
        if is_http_response:
            try:
                response_content = response.json()
            except ValueError:  # JSONDecodeError or UnicodeDecodeError
                response_content = {"error": "Could not decode JSON", "raw": response.text}
        else:
            response_content = _safe_json_decode(response)
        response_content = _filter_base64_images(response_content)
        response_entry = {
            "status_code": response.status_code if is_http_response else None,
            "headers": _redact_headers(response.headers) if is_http_response else None,
            "content": response_content,
        }

    error_entry = None
    if error:
        error_entry = {
            "type": type(error).__name__,
            "message": str(error),
        }

    log_entry = {
        "timestamp": timestamp,
        "request": request_entry,
        "response": response_entry,
        "error": error_entry,
    }

    # Log to file with timestamp in filename
    log_file = log_dir / f"claude_api_{timestamp.replace(' ', '_').replace(':', '-')}.json"
    try:
        log_dir.mkdir(parents=True, exist_ok=True)
        with open(log_file, "w", encoding="utf-8") as f:
            json.dump(log_entry, f, indent=2)
    except (OSError, TypeError, ValueError) as write_error:
        # Best-effort: this function is called on the request path, so a
        # logging failure must not surface as an API failure.
        logger.warning(f"Could not write API log to {log_file}: {write_error}")

    # Also log a summary to the console
    if error:
        logger.error(f"API Error at {timestamp}: {error}")
    else:
        logger.info(
            f"API Call at {timestamp}: "
            f"{request.method if request else 'No request'} -> "
            f"{response.status_code if is_http_response else 'No response'}"
        )

    # Log if there are any images in the content
    if response_content:
        image_count = count_images(response_content)
        if image_count > 0:
            logger.info(f"Response contains {image_count} images")
134
+
135
+ def count_images(content: dict | list | Any) -> int:
136
+ """Count the number of images in the content.
137
+
138
+ Args:
139
+ content: Content to search for images
140
+
141
+ Returns:
142
+ Number of images found
143
+ """
144
+ if isinstance(content, dict):
145
+ if content.get("type") == "image":
146
+ return 1
147
+ return sum(count_images(v) for v in content.values())
148
+ elif isinstance(content, list):
149
+ return sum(count_images(item) for item in content)
150
+ return 0
@@ -0,0 +1,55 @@
1
+ from typing import Callable, Protocol
2
+ import httpx
3
+ from anthropic.types.beta import BetaContentBlockParam
4
+ from ..tools import ToolResult
5
+
6
class APICallback(Protocol):
    """Structural type for callables that observe API interactions.

    Receives the HTTP request and response (either may be ``None``) plus the
    exception raised during the call, if any.
    """

    def __call__(
        self,
        request: httpx.Request | None,
        response: httpx.Response | object | None,
        error: Exception | None,
    ) -> None:
        """Handle one API interaction."""
11
+
12
class ContentCallback(Protocol):
    """Structural type for callables that receive content blocks."""

    def __call__(self, content: BetaContentBlockParam) -> None:
        """Handle a single content block."""
15
+
16
class ToolCallback(Protocol):
    """Structural type for callables notified of tool execution results."""

    def __call__(self, result: ToolResult, tool_id: str) -> None:
        """Handle ``result`` produced by the tool call identified by ``tool_id``."""
19
+
20
class CallbackManager:
    """Holds the agent's callbacks and relays each kind of event to its handler."""

    def __init__(
        self,
        content_callback: ContentCallback,
        tool_callback: ToolCallback,
        api_callback: APICallback,
    ):
        """Store the three event handlers.

        Args:
            content_callback: Invoked with each content block
            tool_callback: Invoked with each tool result and its tool id
            api_callback: Invoked with each request/response/error triple
        """
        self.content_callback = content_callback
        self.tool_callback = tool_callback
        self.api_callback = api_callback

    def on_content(self, content: BetaContentBlockParam) -> None:
        """Forward a content block to the content callback."""
        handler = self.content_callback
        handler(content)

    def on_tool_result(self, result: ToolResult, tool_id: str) -> None:
        """Forward a tool result and its tool id to the tool callback."""
        handler = self.tool_callback
        handler(result, tool_id)

    def on_api_interaction(
        self,
        request: httpx.Request | None,
        response: httpx.Response | object | None,
        error: Exception | None
    ) -> None:
        """Forward an API request/response/error triple to the API callback."""
        handler = self.api_callback
        handler(request, response, error)