cua-agent 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of cua-agent might be problematic. Click here for more details.

Files changed (65):
  1. agent/README.md +63 -0
  2. agent/__init__.py +10 -0
  3. agent/core/README.md +101 -0
  4. agent/core/__init__.py +34 -0
  5. agent/core/agent.py +284 -0
  6. agent/core/base_agent.py +164 -0
  7. agent/core/callbacks.py +147 -0
  8. agent/core/computer_agent.py +69 -0
  9. agent/core/experiment.py +222 -0
  10. agent/core/factory.py +102 -0
  11. agent/core/loop.py +244 -0
  12. agent/core/messages.py +230 -0
  13. agent/core/tools/__init__.py +21 -0
  14. agent/core/tools/base.py +74 -0
  15. agent/core/tools/bash.py +52 -0
  16. agent/core/tools/collection.py +46 -0
  17. agent/core/tools/computer.py +113 -0
  18. agent/core/tools/edit.py +67 -0
  19. agent/core/tools/manager.py +56 -0
  20. agent/providers/__init__.py +4 -0
  21. agent/providers/anthropic/__init__.py +6 -0
  22. agent/providers/anthropic/api/client.py +222 -0
  23. agent/providers/anthropic/api/logging.py +150 -0
  24. agent/providers/anthropic/callbacks/manager.py +55 -0
  25. agent/providers/anthropic/loop.py +521 -0
  26. agent/providers/anthropic/messages/manager.py +110 -0
  27. agent/providers/anthropic/prompts.py +20 -0
  28. agent/providers/anthropic/tools/__init__.py +33 -0
  29. agent/providers/anthropic/tools/base.py +88 -0
  30. agent/providers/anthropic/tools/bash.py +163 -0
  31. agent/providers/anthropic/tools/collection.py +34 -0
  32. agent/providers/anthropic/tools/computer.py +550 -0
  33. agent/providers/anthropic/tools/edit.py +326 -0
  34. agent/providers/anthropic/tools/manager.py +54 -0
  35. agent/providers/anthropic/tools/run.py +42 -0
  36. agent/providers/anthropic/types.py +16 -0
  37. agent/providers/omni/__init__.py +27 -0
  38. agent/providers/omni/callbacks.py +78 -0
  39. agent/providers/omni/clients/anthropic.py +99 -0
  40. agent/providers/omni/clients/base.py +44 -0
  41. agent/providers/omni/clients/groq.py +101 -0
  42. agent/providers/omni/clients/openai.py +159 -0
  43. agent/providers/omni/clients/utils.py +25 -0
  44. agent/providers/omni/experiment.py +273 -0
  45. agent/providers/omni/image_utils.py +106 -0
  46. agent/providers/omni/loop.py +961 -0
  47. agent/providers/omni/messages.py +168 -0
  48. agent/providers/omni/parser.py +252 -0
  49. agent/providers/omni/prompts.py +78 -0
  50. agent/providers/omni/tool_manager.py +91 -0
  51. agent/providers/omni/tools/__init__.py +13 -0
  52. agent/providers/omni/tools/bash.py +69 -0
  53. agent/providers/omni/tools/computer.py +216 -0
  54. agent/providers/omni/tools/manager.py +83 -0
  55. agent/providers/omni/types.py +30 -0
  56. agent/providers/omni/utils.py +155 -0
  57. agent/providers/omni/visualization.py +130 -0
  58. agent/types/__init__.py +26 -0
  59. agent/types/base.py +52 -0
  60. agent/types/messages.py +36 -0
  61. agent/types/tools.py +32 -0
  62. cua_agent-0.1.0.dist-info/METADATA +44 -0
  63. cua_agent-0.1.0.dist-info/RECORD +65 -0
  64. cua_agent-0.1.0.dist-info/WHEEL +4 -0
  65. cua_agent-0.1.0.dist-info/entry_points.txt +4 -0
@@ -0,0 +1,101 @@ agent/providers/omni/clients/groq.py
1
+ """Groq client implementation."""
2
+
3
+ import os
4
+ import logging
5
+ from typing import Dict, List, Optional, Any, Tuple
6
+
7
+ from groq import Groq
8
+ import re
9
+ from .utils import is_image_path
10
+ from .base import BaseOmniClient
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
class GroqClient(BaseOmniClient):
    """Client for making Groq API calls.

    The client sends text-only chat completions: image paths found in the
    message content are dropped before the request is built.
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        model: str = "deepseek-r1-distill-llama-70b",
        max_tokens: int = 4096,
        temperature: float = 0.6,
    ):
        """Initialize Groq client.

        Args:
            api_key: Groq API key (if not provided, will try to get from env)
            model: Model name to use
            max_tokens: Maximum tokens to generate
            temperature: Temperature for sampling

        Raises:
            ValueError: If no API key is passed and GROQ_API_KEY is unset.
        """
        super().__init__(api_key=api_key, model=model)
        self.api_key = api_key or os.getenv("GROQ_API_KEY")
        if not self.api_key:
            raise ValueError("No Groq API key provided")

        self.max_tokens = max_tokens
        self.temperature = temperature
        self.client = Groq(api_key=self.api_key)
        self.model: str = model  # Add explicit type annotation

    @staticmethod
    def _content_to_text(content: Any) -> str:
        """Flatten a message's content into a single text string.

        A plain string is treated as ONE piece of text (iterating it would
        split it into characters). In a list, string entries that look like
        image paths are skipped and non-string entries are stringified.
        """
        if content is None:
            return ""
        parts = [content] if isinstance(content, str) else content

        text_contents = []
        for part in parts:
            if isinstance(part, str):
                if not is_image_path(part):  # Skip image paths
                    text_contents.append(part)
            else:
                text_contents.append(str(part))
        return " ".join(text_contents)

    @staticmethod
    def _strip_reasoning(response: Optional[str]) -> str:
        """Drop the model's <think> section and the <output> wrapper tags."""
        if not response:
            return ""
        if "</think>" in response:
            # Keep only what follows the last closing tag, whether or not a
            # newline comes right after it.
            response = response.rpartition("</think>")[2].removeprefix("\n")
        return response.replace("<output>", "").replace("</output>", "")

    def run_interleaved(
        self, messages: List[Dict[str, Any]], system: str, max_tokens: Optional[int] = None
    ) -> tuple[str, int]:
        """Run interleaved chat completion.

        Args:
            messages: List of message dicts (plain strings are accepted too)
            system: System prompt
            max_tokens: Optional max tokens override

        Returns:
            Tuple of (response text, token usage)

        Raises:
            Exception: Any error raised by the Groq API call is logged and
                re-raised unchanged.
        """
        # Avoid using system messages for R1: the system prompt is sent as the
        # first user message, and every later message also uses the user role.
        final_messages = [{"role": "user", "content": system}]

        # Process messages
        if isinstance(messages, list):
            for item in messages:
                if isinstance(item, dict):
                    text = self._content_to_text(item.get("content"))
                    if text:  # Only add if there's text content
                        final_messages.append({"role": "user", "content": text})
                else:  # str
                    final_messages.append({"role": "user", "content": item})
        elif isinstance(messages, str):
            final_messages.append({"role": "user", "content": messages})

        try:
            completion = self.client.chat.completions.create(  # type: ignore
                model=self.model,
                messages=final_messages,  # type: ignore
                temperature=self.temperature,
                max_tokens=max_tokens or self.max_tokens,
                top_p=0.95,
                stream=False,
            )

            final_answer = self._strip_reasoning(completion.choices[0].message.content)
            # Usage may be absent on the response; report 0 rather than crash.
            usage = completion.usage
            token_usage = usage.total_tokens if usage is not None else 0

            return final_answer, token_usage

        except Exception as e:
            logger.error(f"Error in Groq API call: {e}")
            raise
@@ -0,0 +1,159 @@ agent/providers/omni/clients/openai.py
1
+ """OpenAI client implementation."""
2
+
3
+ import os
4
+ import logging
5
+ from typing import Dict, List, Optional, Any
6
+ import aiohttp
7
+ import base64
8
+ import re
9
+ import json
10
+ import ssl
11
+ import certifi
12
+ from datetime import datetime
13
+ from .base import BaseOmniClient
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+ # OpenAI specific client for the OmniLoop
18
+
19
+
20
class OpenAIClient(BaseOmniClient):
    """OpenAI vision API client implementation."""

    def __init__(
        self,
        api_key: Optional[str] = None,
        model: str = "gpt-4o",
        provider_base_url: str = "https://api.openai.com/v1",
        max_tokens: int = 4096,
        temperature: float = 0.0,
    ):
        """Initialize the OpenAI client.

        Args:
            api_key: OpenAI API key (falls back to the OPENAI_API_KEY env var)
            model: Model to use
            provider_base_url: API endpoint
            max_tokens: Maximum tokens to generate
            temperature: Generation temperature

        Raises:
            ValueError: If no API key is passed and OPENAI_API_KEY is unset.
        """
        super().__init__(api_key=api_key, model=model)
        self.api_key = api_key or os.getenv("OPENAI_API_KEY")
        if not self.api_key:
            raise ValueError("No OpenAI API key provided")

        self.model = model
        self.provider_base_url = provider_base_url
        self.max_tokens = max_tokens
        self.temperature = temperature

    def _extract_base64_image(self, text: str) -> Optional[str]:
        """Extract base64 image data from an HTML img tag.

        Returns the payload following ``data:image/<type>;base64,`` up to the
        next double quote, or None when the text holds no such data URL.
        """
        pattern = r'data:image/[^;]+;base64,([^"]+)'
        match = re.search(pattern, text)
        return match.group(1) if match else None

    def _get_loggable_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Create a loggable version of messages with image data truncated.

        Both ``image`` parts and the ``image_url`` parts this client builds
        are replaced by a placeholder; every other part is passed through.
        """
        loggable_messages = []
        for msg in messages:
            content = msg.get("content")
            if not isinstance(content, list):
                loggable_messages.append(msg)
                continue

            new_content = []
            for part in content:
                # Parts may be plain strings; only dicts carry a "type".
                part_type = part.get("type") if isinstance(part, dict) else None
                if part_type in ("image", "image_url"):
                    new_content.append(
                        {"type": part_type, "image_url": {"url": "[BASE64_IMAGE_DATA]"}}
                    )
                else:
                    new_content.append(part)
            loggable_messages.append({"role": msg["role"], "content": new_content})
        return loggable_messages

    def _build_message(self, role: str, text: str) -> Dict[str, Any]:
        """Wrap a string as a one-part message: an image if it embeds one, else text."""
        base64_img = self._extract_base64_image(text)
        if base64_img:
            # NOTE(review): the URL is always rebuilt as image/jpeg, whatever
            # MIME type the source data URL declared.
            part = {
                "type": "image_url",
                "image_url": {"url": f"data:image/jpeg;base64,{base64_img}"},
            }
        else:
            part = {"type": "text", "text": text}
        return {"role": role, "content": [part]}

    async def run_interleaved(
        self, messages: List[Dict[str, Any]], system: str, max_tokens: Optional[int] = None
    ) -> Dict[str, Any]:
        """Run interleaved chat completion.

        Args:
            messages: List of message dicts (plain strings are sent as user
                messages)
            system: System prompt
            max_tokens: Optional max tokens override

        Returns:
            Response dict (the decoded JSON body of the API response)

        Raises:
            Exception: If the API answers with a non-200 status, or if the
                request itself fails; the error is logged and re-raised.
        """
        headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}

        final_messages = [{"role": "system", "content": system}]

        # Process messages
        for item in messages:
            if isinstance(item, dict):
                if isinstance(item["content"], list):
                    # Content is already in the correct format
                    final_messages.append(item)
                else:
                    # Single string content, check for image
                    final_messages.append(self._build_message(item["role"], item["content"]))
            else:
                # String content, check for image
                final_messages.append(self._build_message("user", item))

        payload = {"model": self.model, "messages": final_messages, "temperature": self.temperature}

        if "o1" in self.model or "o3-mini" in self.model:
            payload["reasoning_effort"] = "low"
            payload["max_completion_tokens"] = max_tokens or self.max_tokens
        else:
            payload["max_tokens"] = max_tokens or self.max_tokens

        # Tolerate a trailing slash in the configured base URL.
        url = f"{self.provider_base_url.rstrip('/')}/chat/completions"

        try:
            async with aiohttp.ClientSession() as session:
                async with session.post(url, headers=headers, json=payload) as response:
                    response_json = await response.json()

                    if response.status != 200:
                        error_msg = response_json.get("error", {}).get(
                            "message", str(response_json)
                        )
                        # Logged once by the handler below.
                        raise Exception(f"OpenAI API error: {error_msg}")

                    return response_json

        except Exception as e:
            logger.error(f"Error in OpenAI API call: {str(e)}")
            raise
@@ -0,0 +1,25 @@ agent/providers/omni/clients/utils.py
1
+ import base64
2
+
3
def is_image_path(text: str) -> bool:
    """Check if a text string is an image file path.

    The comparison ignores case, so ``shot.PNG`` and ``shot.png`` are both
    recognised.

    Args:
        text: Text string to check

    Returns:
        True if text ends with image extension, False otherwise
    """
    image_extensions = (".jpg", ".jpeg", ".png", ".gif", ".bmp", ".tiff", ".tif")
    return text.lower().endswith(image_extensions)
14
+
15
def encode_image(image_path: str) -> str:
    """Read an image file and return its contents as base64 text.

    Args:
        image_path: Path to image file

    Returns:
        Base64 encoded image string
    """
    with open(image_path, mode="rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("utf-8")
@@ -0,0 +1,273 @@ agent/providers/omni/experiment.py
1
+ """Experiment management for the Cua provider."""
2
+
3
+ import os
4
+ import logging
5
+ import copy
6
+ import base64
7
+ from io import BytesIO
8
+ from datetime import datetime
9
+ from typing import Any, Dict, List, Optional
10
+ from PIL import Image
11
+ import json
12
+ import time
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
class ExperimentManager:
    """Manages experiment directories and logging for the agent.

    When a base directory is given, the manager creates numbered turn
    directories inside it and writes screenshots, action visualizations and
    API-call logs into the current turn directory. Without a base directory
    every save/log method is a no-op.
    """

    def __init__(
        self,
        base_dir: Optional[str] = None,
        only_n_most_recent_images: Optional[int] = None,
    ):
        """Initialize the experiment manager.

        Args:
            base_dir: Base directory for saving experiment data
            only_n_most_recent_images: Maximum number of recent screenshots to include in API requests
        """
        self.base_dir = base_dir
        self.only_n_most_recent_images = only_n_most_recent_images
        self.run_dir: Optional[str] = None
        self.current_turn_dir: Optional[str] = None
        self.turn_count = 0
        self.screenshot_count = 0
        # Track all screenshots for potential API request inclusion
        self.screenshot_paths: List[str] = []

        # Set up experiment directories if base_dir is provided
        if self.base_dir:
            self.setup_experiment_dirs()

    def setup_experiment_dirs(self) -> None:
        """Setup the experiment directory structure."""
        if not self.base_dir:
            return

        # Create base experiments directory if it doesn't exist
        os.makedirs(self.base_dir, exist_ok=True)

        # Use the base_dir directly as the run_dir
        self.run_dir = self.base_dir
        logger.info(f"Using directory for experiment: {self.run_dir}")

        # Create first turn directory
        self.create_turn_dir()

    def create_turn_dir(self) -> None:
        """Create a new directory for the current turn."""
        if not self.run_dir:
            return

        self.turn_count += 1
        self.current_turn_dir = os.path.join(self.run_dir, f"turn_{self.turn_count:03d}")
        os.makedirs(self.current_turn_dir, exist_ok=True)
        logger.info(f"Created turn directory: {self.current_turn_dir}")

    @staticmethod
    def _redact_image_item(item: Dict[str, Any]) -> None:
        """Replace the base64 payload of one content item with a placeholder, in place."""
        item_type = item.get("type")

        # Handle Anthropic format
        if item_type == "image" and isinstance(item.get("source"), dict):
            source = item["source"]
            if "data" in source and isinstance(source["data"], str):
                # Replace base64 data with a placeholder and length info
                data_len = len(source["data"])
                source["data"] = f"[BASE64_IMAGE_DATA_LENGTH_{data_len}]"

        # Handle OpenAI format
        elif item_type == "image_url" and isinstance(item.get("image_url"), dict):
            url_dict = item["image_url"]
            url = url_dict.get("url")
            if isinstance(url, str) and url.startswith("data:"):
                # Replace base64 data with placeholder
                url_dict["url"] = f"[BASE64_IMAGE_URL_LENGTH_{len(url)}]"

    def _redact_images_in_place(self, node: Any) -> None:
        """Walk nested dicts/lists and redact image items found under a "content" key."""
        if isinstance(node, dict):
            for key, value in node.items():
                # Process content arrays that contain image data
                if key == "content" and isinstance(value, list):
                    for item in value:
                        if isinstance(item, dict):
                            self._redact_image_item(item)
                # Handle other nested structures recursively
                if isinstance(value, (dict, list)):
                    self._redact_images_in_place(value)
        elif isinstance(node, list):
            for item in node:
                self._redact_images_in_place(item)

    def sanitize_log_data(self, data: Any) -> Any:
        """Sanitize data for logging by removing large base64 strings.

        The input is never modified. A dict is deep-copied once and the copy
        is redacted in place, instead of being re-copied at every nesting
        level.

        Args:
            data: Data to sanitize (dict, list, or primitive)

        Returns:
            Sanitized copy of the data; values that are neither dict nor list
            are returned as-is
        """
        if isinstance(data, dict):
            result = copy.deepcopy(data)
            self._redact_images_in_place(result)
            return result
        if isinstance(data, list):
            return [self.sanitize_log_data(item) for item in data]
        return data

    def save_debug_image(self, image_data: str, filename: str) -> None:
        """Save a debug image to the experiment directory.

        Currently a no-op kept for interface compatibility.

        Args:
            image_data: Base64 encoded image data
            filename: Filename to save the image as
        """
        # Since we no longer want to use the images/ folder, we'll skip this functionality
        return

    def save_screenshot(self, img_base64: str, action_type: str = "") -> Optional[str]:
        """Save a screenshot to the experiment directory.

        Args:
            img_base64: Base64 encoded screenshot
            action_type: Type of action that triggered the screenshot

        Returns:
            Path of the written file, or None when there is no turn directory
            or the save failed (the failure is logged, not raised)
        """
        if not self.current_turn_dir:
            return None

        try:
            # Increment screenshot counter
            self.screenshot_count += 1

            # Create a descriptive filename
            timestamp = int(time.time() * 1000)
            action_suffix = f"_{action_type}" if action_type else ""
            filename = f"screenshot_{self.screenshot_count:03d}{action_suffix}_{timestamp}.png"

            # Save directly to the turn directory (no screenshots subdirectory)
            filepath = os.path.join(self.current_turn_dir, filename)

            # Save the screenshot
            img_data = base64.b64decode(img_base64)
            with open(filepath, "wb") as f:
                f.write(img_data)

            # Keep track of the file path for reference
            self.screenshot_paths.append(filepath)

            return filepath
        except Exception as e:
            logger.error(f"Error saving screenshot: {str(e)}")
            return None

    def should_save_debug_image(self) -> bool:
        """Determine if debug images should be saved.

        Returns:
            Boolean indicating if debug images should be saved
        """
        # We no longer need to save debug images, so always return False
        return False

    def save_action_visualization(
        self, img: "Image.Image", action_name: str, details: str = ""
    ) -> str:
        """Save a visualization of an action.

        Args:
            img: Image to save
            action_name: Name of the action
            details: Additional details about the action

        Returns:
            Path to the saved image, or "" when there is no turn directory or
            the save failed (the failure is logged, not raised)
        """
        if not self.current_turn_dir:
            return ""

        try:
            # Create a descriptive filename
            timestamp = int(time.time() * 1000)
            details_suffix = f"_{details}" if details else ""
            filename = f"vis_{action_name}{details_suffix}_{timestamp}.png"

            # Save directly to the turn directory (no visualizations subdirectory)
            filepath = os.path.join(self.current_turn_dir, filename)

            # Save the image
            img.save(filepath)

            # Keep track of the file path for cleanup
            self.screenshot_paths.append(filepath)

            return filepath
        except Exception as e:
            logger.error(f"Error saving action visualization: {str(e)}")
            return ""

    def extract_and_save_images(self, data: Any, prefix: str) -> None:
        """Extract and save images from response data.

        Currently a no-op kept for interface compatibility.

        Args:
            data: Response data to extract images from
            prefix: Prefix for saved image filenames
        """
        # Since we no longer want to save extracted images separately,
        # we'll skip this functionality entirely
        return

    def log_api_call(
        self,
        call_type: str,
        request: Any,
        provider: str,
        model: str,
        response: Any = None,
        error: Optional[Exception] = None,
    ) -> None:
        """Log API call details to file.

        Writes one JSON file into the current turn directory. Failures are
        logged and swallowed so that logging never interrupts the agent.

        Args:
            call_type: Type of API call (e.g., 'request', 'response', 'error')
            request: The API request data
            provider: The AI provider used
            model: The AI model used
            response: Optional API response data
            error: Optional error information
        """
        if not self.current_turn_dir:
            return

        try:
            # NOTE(review): the timestamp has one-second resolution, so two
            # calls of the same call_type within a second share a filename and
            # the later one overwrites the earlier.
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            filename = f"api_call_{timestamp}_{call_type}.json"
            filepath = os.path.join(self.current_turn_dir, filename)

            # Sanitize data to remove large base64 strings
            sanitized_request = self.sanitize_log_data(request)
            sanitized_response = self.sanitize_log_data(response) if response is not None else None

            # Prepare log data
            log_data = {
                "timestamp": timestamp,
                "provider": provider,
                "model": model,
                "type": call_type,
                "request": sanitized_request,
            }

            if sanitized_response is not None:
                log_data["response"] = sanitized_response
            if error is not None:
                log_data["error"] = str(error)

            # Write to file; default=str stringifies anything json cannot encode
            with open(filepath, "w", encoding="utf-8") as f:
                json.dump(log_data, f, indent=2, default=str)

            logger.info(f"Logged API {call_type} to {filepath}")

        except Exception as e:
            logger.error(f"Error logging API call: {str(e)}")
@@ -0,0 +1,106 @@ agent/providers/omni/image_utils.py
1
+ """Image processing utilities for the Cua provider."""
2
+
3
+ import base64
4
+ import logging
5
+ import re
6
+ from io import BytesIO
7
+ from typing import Optional, Tuple
8
+ from PIL import Image
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
def decode_base64_image(img_base64: str) -> Optional[Image.Image]:
    """Turn base64 image text into a PIL Image.

    Args:
        img_base64: Base64 encoded image, may include data URL prefix

    Returns:
        PIL Image or None if decoding fails (the failure is logged)
    """
    try:
        payload = img_base64
        # A data URL carries the base64 payload after the first comma
        if payload.startswith("data:image"):
            payload = payload.split(",")[1]

        raw_bytes = base64.b64decode(payload)
        stream = BytesIO(raw_bytes)
        return Image.open(stream)
    except Exception as e:
        logger.error(f"Error decoding base64 image: {str(e)}")
        return None
35
+
36
+
37
def encode_image_base64(img: Image.Image, format: str = "PNG") -> str:
    """Serialize a PIL Image and return it as base64 text.

    Args:
        img: PIL Image to encode
        format: Image format (PNG, JPEG, etc.)

    Returns:
        Base64 encoded image string, or "" if encoding fails (the failure is
        logged)
    """
    try:
        stream = BytesIO()
        img.save(stream, format=format)
        raw_bytes = stream.getvalue()
        encoded = base64.b64encode(raw_bytes)
        return encoded.decode("utf-8")
    except Exception as e:
        logger.error(f"Error encoding image to base64: {str(e)}")
        return ""
54
+
55
+
56
def clean_base64_data(img_base64: str) -> str:
    """Clean base64 image data by removing data URL prefix.

    Args:
        img_base64: Base64 encoded image, may include data URL prefix

    Returns:
        Clean base64 string without prefix. Input that has no data URL
        prefix, or a prefix with no comma-separated payload after it, is
        returned unchanged.
    """
    if img_base64.startswith("data:image"):
        parts = img_base64.split(",")
        # A malformed data URL without a comma has no payload to extract;
        # return it untouched instead of raising IndexError.
        if len(parts) > 1:
            return parts[1]
    return img_base64
68
+
69
+
70
def extract_base64_from_text(text: str) -> Optional[str]:
    """Find base64 image data embedded in a text string.

    Args:
        text: Text potentially containing base64 image data

    Returns:
        Base64 string or None if not found
    """
    # Tried in order: an explicit data URL first, then (as a basic heuristic)
    # any run of at least 100 base64-alphabet characters.
    patterns = (
        r"data:image/[^;]+;base64,([a-zA-Z0-9+/=]+)",
        r"([a-zA-Z0-9+/=]{100,})",
    )
    for pattern in patterns:
        found = re.search(pattern, text)
        if found is not None:
            return found.group(1)

    return None
92
+
93
+
94
def get_image_dimensions(img_base64: str) -> Tuple[int, int]:
    """Report the size of a base64 encoded image.

    Args:
        img_base64: Base64 encoded image

    Returns:
        Tuple of (width, height) or (0, 0) if decoding fails
    """
    decoded = decode_base64_image(img_base64)
    return decoded.size if decoded else (0, 0)