cua-agent 0.1.6__py3-none-any.whl → 0.1.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of cua-agent might be problematic. Click here for more details.
- agent/__init__.py +3 -2
- agent/core/__init__.py +0 -5
- agent/core/computer_agent.py +21 -28
- agent/core/loop.py +78 -124
- agent/core/messages.py +279 -125
- agent/core/types.py +35 -0
- agent/core/visualization.py +197 -0
- agent/providers/anthropic/api/client.py +142 -1
- agent/providers/anthropic/api_handler.py +140 -0
- agent/providers/anthropic/callbacks/__init__.py +5 -0
- agent/providers/anthropic/loop.py +206 -220
- agent/providers/anthropic/response_handler.py +229 -0
- agent/providers/anthropic/tools/bash.py +0 -97
- agent/providers/anthropic/utils.py +370 -0
- agent/providers/omni/__init__.py +1 -20
- agent/providers/omni/api_handler.py +42 -0
- agent/providers/omni/clients/anthropic.py +4 -0
- agent/providers/omni/image_utils.py +0 -72
- agent/providers/omni/loop.py +490 -606
- agent/providers/omni/parser.py +58 -4
- agent/providers/omni/tools/__init__.py +25 -7
- agent/providers/omni/tools/base.py +29 -0
- agent/providers/omni/tools/bash.py +43 -38
- agent/providers/omni/tools/computer.py +144 -182
- agent/providers/omni/tools/manager.py +25 -45
- agent/providers/omni/types.py +0 -4
- agent/providers/omni/utils.py +224 -145
- {cua_agent-0.1.6.dist-info → cua_agent-0.1.17.dist-info}/METADATA +6 -36
- cua_agent-0.1.17.dist-info/RECORD +63 -0
- agent/providers/omni/callbacks.py +0 -78
- agent/providers/omni/clients/groq.py +0 -101
- agent/providers/omni/experiment.py +0 -276
- agent/providers/omni/messages.py +0 -171
- agent/providers/omni/tool_manager.py +0 -91
- agent/providers/omni/visualization.py +0 -130
- agent/types/__init__.py +0 -23
- agent/types/base.py +0 -41
- agent/types/messages.py +0 -36
- cua_agent-0.1.6.dist-info/RECORD +0 -64
- /agent/{types → core}/tools.py +0 -0
- {cua_agent-0.1.6.dist-info → cua_agent-0.1.17.dist-info}/WHEEL +0 -0
- {cua_agent-0.1.6.dist-info → cua_agent-0.1.17.dist-info}/entry_points.txt +0 -0
|
@@ -1,78 +0,0 @@
|
|
|
1
|
-
"""Omni callback manager implementation."""
|
|
2
|
-
|
|
3
|
-
import logging
|
|
4
|
-
from typing import Any, Dict, Optional, Set
|
|
5
|
-
|
|
6
|
-
from ...core.callbacks import BaseCallbackManager, ContentCallback, ToolCallback, APICallback
|
|
7
|
-
from ...types.tools import ToolResult
|
|
8
|
-
|
|
9
|
-
logger = logging.getLogger(__name__)
|
|
10
|
-
|
|
11
|
-
class OmniCallbackManager(BaseCallbackManager):
|
|
12
|
-
"""Callback manager for multi-provider support."""
|
|
13
|
-
|
|
14
|
-
def __init__(
|
|
15
|
-
self,
|
|
16
|
-
content_callback: ContentCallback,
|
|
17
|
-
tool_callback: ToolCallback,
|
|
18
|
-
api_callback: APICallback,
|
|
19
|
-
):
|
|
20
|
-
"""Initialize Omni callback manager.
|
|
21
|
-
|
|
22
|
-
Args:
|
|
23
|
-
content_callback: Callback for content updates
|
|
24
|
-
tool_callback: Callback for tool execution results
|
|
25
|
-
api_callback: Callback for API interactions
|
|
26
|
-
"""
|
|
27
|
-
super().__init__(
|
|
28
|
-
content_callback=content_callback,
|
|
29
|
-
tool_callback=tool_callback,
|
|
30
|
-
api_callback=api_callback
|
|
31
|
-
)
|
|
32
|
-
self._active_tools: Set[str] = set()
|
|
33
|
-
|
|
34
|
-
def on_content(self, content: Any) -> None:
|
|
35
|
-
"""Handle content updates.
|
|
36
|
-
|
|
37
|
-
Args:
|
|
38
|
-
content: Content update data
|
|
39
|
-
"""
|
|
40
|
-
logger.debug(f"Content update: {content}")
|
|
41
|
-
self.content_callback(content)
|
|
42
|
-
|
|
43
|
-
def on_tool_result(self, result: ToolResult, tool_id: str) -> None:
|
|
44
|
-
"""Handle tool execution results.
|
|
45
|
-
|
|
46
|
-
Args:
|
|
47
|
-
result: Tool execution result
|
|
48
|
-
tool_id: ID of the tool
|
|
49
|
-
"""
|
|
50
|
-
logger.debug(f"Tool result for {tool_id}: {result}")
|
|
51
|
-
self.tool_callback(result, tool_id)
|
|
52
|
-
|
|
53
|
-
def on_api_interaction(
|
|
54
|
-
self,
|
|
55
|
-
request: Any,
|
|
56
|
-
response: Any,
|
|
57
|
-
error: Optional[Exception] = None
|
|
58
|
-
) -> None:
|
|
59
|
-
"""Handle API interactions.
|
|
60
|
-
|
|
61
|
-
Args:
|
|
62
|
-
request: API request data
|
|
63
|
-
response: API response data
|
|
64
|
-
error: Optional error that occurred
|
|
65
|
-
"""
|
|
66
|
-
if error:
|
|
67
|
-
logger.error(f"API error: {str(error)}")
|
|
68
|
-
else:
|
|
69
|
-
logger.debug(f"API interaction - Request: {request}, Response: {response}")
|
|
70
|
-
self.api_callback(request, response, error)
|
|
71
|
-
|
|
72
|
-
def get_active_tools(self) -> Set[str]:
|
|
73
|
-
"""Get currently active tools.
|
|
74
|
-
|
|
75
|
-
Returns:
|
|
76
|
-
Set of active tool names
|
|
77
|
-
"""
|
|
78
|
-
return self._active_tools.copy()
|
|
@@ -1,101 +0,0 @@
|
|
|
1
|
-
"""Groq client implementation."""
|
|
2
|
-
|
|
3
|
-
import os
|
|
4
|
-
import logging
|
|
5
|
-
from typing import Dict, List, Optional, Any, Tuple
|
|
6
|
-
|
|
7
|
-
from groq import Groq
|
|
8
|
-
import re
|
|
9
|
-
from .utils import is_image_path
|
|
10
|
-
from .base import BaseOmniClient
|
|
11
|
-
|
|
12
|
-
logger = logging.getLogger(__name__)
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
class GroqClient(BaseOmniClient):
|
|
16
|
-
"""Client for making Groq API calls."""
|
|
17
|
-
|
|
18
|
-
def __init__(
|
|
19
|
-
self,
|
|
20
|
-
api_key: Optional[str] = None,
|
|
21
|
-
model: str = "deepseek-r1-distill-llama-70b",
|
|
22
|
-
max_tokens: int = 4096,
|
|
23
|
-
temperature: float = 0.6,
|
|
24
|
-
):
|
|
25
|
-
"""Initialize Groq client.
|
|
26
|
-
|
|
27
|
-
Args:
|
|
28
|
-
api_key: Groq API key (if not provided, will try to get from env)
|
|
29
|
-
model: Model name to use
|
|
30
|
-
max_tokens: Maximum tokens to generate
|
|
31
|
-
temperature: Temperature for sampling
|
|
32
|
-
"""
|
|
33
|
-
super().__init__(api_key=api_key, model=model)
|
|
34
|
-
self.api_key = api_key or os.getenv("GROQ_API_KEY")
|
|
35
|
-
if not self.api_key:
|
|
36
|
-
raise ValueError("No Groq API key provided")
|
|
37
|
-
|
|
38
|
-
self.max_tokens = max_tokens
|
|
39
|
-
self.temperature = temperature
|
|
40
|
-
self.client = Groq(api_key=self.api_key)
|
|
41
|
-
self.model: str = model # Add explicit type annotation
|
|
42
|
-
|
|
43
|
-
def run_interleaved(
|
|
44
|
-
self, messages: List[Dict[str, Any]], system: str, max_tokens: Optional[int] = None
|
|
45
|
-
) -> tuple[str, int]:
|
|
46
|
-
"""Run interleaved chat completion.
|
|
47
|
-
|
|
48
|
-
Args:
|
|
49
|
-
messages: List of message dicts
|
|
50
|
-
system: System prompt
|
|
51
|
-
max_tokens: Optional max tokens override
|
|
52
|
-
|
|
53
|
-
Returns:
|
|
54
|
-
Tuple of (response text, token usage)
|
|
55
|
-
"""
|
|
56
|
-
# Avoid using system messages for R1
|
|
57
|
-
final_messages = [{"role": "user", "content": system}]
|
|
58
|
-
|
|
59
|
-
# Process messages
|
|
60
|
-
if isinstance(messages, list):
|
|
61
|
-
for item in messages:
|
|
62
|
-
if isinstance(item, dict):
|
|
63
|
-
# For dict items, concatenate all text content, ignoring images
|
|
64
|
-
text_contents = []
|
|
65
|
-
for cnt in item["content"]:
|
|
66
|
-
if isinstance(cnt, str):
|
|
67
|
-
if not is_image_path(cnt): # Skip image paths
|
|
68
|
-
text_contents.append(cnt)
|
|
69
|
-
else:
|
|
70
|
-
text_contents.append(str(cnt))
|
|
71
|
-
|
|
72
|
-
if text_contents: # Only add if there's text content
|
|
73
|
-
message = {"role": "user", "content": " ".join(text_contents)}
|
|
74
|
-
final_messages.append(message)
|
|
75
|
-
else: # str
|
|
76
|
-
message = {"role": "user", "content": item}
|
|
77
|
-
final_messages.append(message)
|
|
78
|
-
|
|
79
|
-
elif isinstance(messages, str):
|
|
80
|
-
final_messages.append({"role": "user", "content": messages})
|
|
81
|
-
|
|
82
|
-
try:
|
|
83
|
-
completion = self.client.chat.completions.create( # type: ignore
|
|
84
|
-
model=self.model,
|
|
85
|
-
messages=final_messages, # type: ignore
|
|
86
|
-
temperature=self.temperature,
|
|
87
|
-
max_tokens=max_tokens or self.max_tokens,
|
|
88
|
-
top_p=0.95,
|
|
89
|
-
stream=False,
|
|
90
|
-
)
|
|
91
|
-
|
|
92
|
-
response = completion.choices[0].message.content
|
|
93
|
-
final_answer = response.split("</think>\n")[-1] if "</think>" in response else response
|
|
94
|
-
final_answer = final_answer.replace("<output>", "").replace("</output>", "")
|
|
95
|
-
token_usage = completion.usage.total_tokens
|
|
96
|
-
|
|
97
|
-
return final_answer, token_usage
|
|
98
|
-
|
|
99
|
-
except Exception as e:
|
|
100
|
-
logger.error(f"Error in Groq API call: {e}")
|
|
101
|
-
raise
|
|
@@ -1,276 +0,0 @@
|
|
|
1
|
-
"""Experiment management for the Cua provider."""
|
|
2
|
-
|
|
3
|
-
import os
|
|
4
|
-
import logging
|
|
5
|
-
import copy
|
|
6
|
-
import base64
|
|
7
|
-
from io import BytesIO
|
|
8
|
-
from datetime import datetime
|
|
9
|
-
from typing import Any, Dict, List, Optional
|
|
10
|
-
from PIL import Image
|
|
11
|
-
import json
|
|
12
|
-
import time
|
|
13
|
-
|
|
14
|
-
logger = logging.getLogger(__name__)
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
class ExperimentManager:
|
|
18
|
-
"""Manages experiment directories and logging for the agent."""
|
|
19
|
-
|
|
20
|
-
def __init__(
|
|
21
|
-
self,
|
|
22
|
-
base_dir: Optional[str] = None,
|
|
23
|
-
only_n_most_recent_images: Optional[int] = None,
|
|
24
|
-
):
|
|
25
|
-
"""Initialize the experiment manager.
|
|
26
|
-
|
|
27
|
-
Args:
|
|
28
|
-
base_dir: Base directory for saving experiment data
|
|
29
|
-
only_n_most_recent_images: Maximum number of recent screenshots to include in API requests
|
|
30
|
-
"""
|
|
31
|
-
self.base_dir = base_dir
|
|
32
|
-
self.only_n_most_recent_images = only_n_most_recent_images
|
|
33
|
-
self.run_dir = None
|
|
34
|
-
self.current_turn_dir = None
|
|
35
|
-
self.turn_count = 0
|
|
36
|
-
self.screenshot_count = 0
|
|
37
|
-
# Track all screenshots for potential API request inclusion
|
|
38
|
-
self.screenshot_paths = []
|
|
39
|
-
|
|
40
|
-
# Set up experiment directories if base_dir is provided
|
|
41
|
-
if self.base_dir:
|
|
42
|
-
self.setup_experiment_dirs()
|
|
43
|
-
|
|
44
|
-
def setup_experiment_dirs(self) -> None:
|
|
45
|
-
"""Setup the experiment directory structure."""
|
|
46
|
-
if not self.base_dir:
|
|
47
|
-
return
|
|
48
|
-
|
|
49
|
-
# Create base experiments directory if it doesn't exist
|
|
50
|
-
os.makedirs(self.base_dir, exist_ok=True)
|
|
51
|
-
|
|
52
|
-
# Use the base_dir directly as the run_dir
|
|
53
|
-
self.run_dir = self.base_dir
|
|
54
|
-
logger.info(f"Using directory for experiment: {self.run_dir}")
|
|
55
|
-
|
|
56
|
-
# Create first turn directory
|
|
57
|
-
self.create_turn_dir()
|
|
58
|
-
|
|
59
|
-
def create_turn_dir(self) -> None:
|
|
60
|
-
"""Create a new directory for the current turn."""
|
|
61
|
-
if not self.run_dir:
|
|
62
|
-
return
|
|
63
|
-
|
|
64
|
-
self.turn_count += 1
|
|
65
|
-
self.current_turn_dir = os.path.join(self.run_dir, f"turn_{self.turn_count:03d}")
|
|
66
|
-
os.makedirs(self.current_turn_dir, exist_ok=True)
|
|
67
|
-
logger.info(f"Created turn directory: {self.current_turn_dir}")
|
|
68
|
-
|
|
69
|
-
def sanitize_log_data(self, data: Any) -> Any:
|
|
70
|
-
"""Sanitize data for logging by removing large base64 strings.
|
|
71
|
-
|
|
72
|
-
Args:
|
|
73
|
-
data: Data to sanitize (dict, list, or primitive)
|
|
74
|
-
|
|
75
|
-
Returns:
|
|
76
|
-
Sanitized copy of the data
|
|
77
|
-
"""
|
|
78
|
-
if isinstance(data, dict):
|
|
79
|
-
result = copy.deepcopy(data)
|
|
80
|
-
|
|
81
|
-
# Handle nested dictionaries and lists
|
|
82
|
-
for key, value in result.items():
|
|
83
|
-
# Process content arrays that contain image data
|
|
84
|
-
if key == "content" and isinstance(value, list):
|
|
85
|
-
for i, item in enumerate(value):
|
|
86
|
-
if isinstance(item, dict):
|
|
87
|
-
# Handle Anthropic format
|
|
88
|
-
if item.get("type") == "image" and isinstance(item.get("source"), dict):
|
|
89
|
-
source = item["source"]
|
|
90
|
-
if "data" in source and isinstance(source["data"], str):
|
|
91
|
-
# Replace base64 data with a placeholder and length info
|
|
92
|
-
data_len = len(source["data"])
|
|
93
|
-
source["data"] = f"[BASE64_IMAGE_DATA_LENGTH_{data_len}]"
|
|
94
|
-
|
|
95
|
-
# Handle OpenAI format
|
|
96
|
-
elif item.get("type") == "image_url" and isinstance(
|
|
97
|
-
item.get("image_url"), dict
|
|
98
|
-
):
|
|
99
|
-
url_dict = item["image_url"]
|
|
100
|
-
if "url" in url_dict and isinstance(url_dict["url"], str):
|
|
101
|
-
url = url_dict["url"]
|
|
102
|
-
if url.startswith("data:"):
|
|
103
|
-
# Replace base64 data with placeholder
|
|
104
|
-
data_len = len(url)
|
|
105
|
-
url_dict["url"] = f"[BASE64_IMAGE_URL_LENGTH_{data_len}]"
|
|
106
|
-
|
|
107
|
-
# Handle other nested structures recursively
|
|
108
|
-
if isinstance(value, dict):
|
|
109
|
-
result[key] = self.sanitize_log_data(value)
|
|
110
|
-
elif isinstance(value, list):
|
|
111
|
-
result[key] = [self.sanitize_log_data(item) for item in value]
|
|
112
|
-
|
|
113
|
-
return result
|
|
114
|
-
elif isinstance(data, list):
|
|
115
|
-
return [self.sanitize_log_data(item) for item in data]
|
|
116
|
-
else:
|
|
117
|
-
return data
|
|
118
|
-
|
|
119
|
-
def save_debug_image(self, image_data: str, filename: str) -> None:
|
|
120
|
-
"""Save a debug image to the experiment directory.
|
|
121
|
-
|
|
122
|
-
Args:
|
|
123
|
-
image_data: Base64 encoded image data
|
|
124
|
-
filename: Filename to save the image as
|
|
125
|
-
"""
|
|
126
|
-
# Since we no longer want to use the images/ folder, we'll skip this functionality
|
|
127
|
-
return
|
|
128
|
-
|
|
129
|
-
def save_screenshot(self, img_base64: str, action_type: str = "") -> Optional[str]:
|
|
130
|
-
"""Save a screenshot to the experiment directory.
|
|
131
|
-
|
|
132
|
-
Args:
|
|
133
|
-
img_base64: Base64 encoded screenshot
|
|
134
|
-
action_type: Type of action that triggered the screenshot
|
|
135
|
-
|
|
136
|
-
Returns:
|
|
137
|
-
Optional[str]: Path to the saved screenshot, or None if saving failed
|
|
138
|
-
"""
|
|
139
|
-
if not self.current_turn_dir:
|
|
140
|
-
return None
|
|
141
|
-
|
|
142
|
-
try:
|
|
143
|
-
# Increment screenshot counter
|
|
144
|
-
self.screenshot_count += 1
|
|
145
|
-
|
|
146
|
-
# Create a descriptive filename
|
|
147
|
-
timestamp = int(time.time() * 1000)
|
|
148
|
-
action_suffix = f"_{action_type}" if action_type else ""
|
|
149
|
-
filename = f"screenshot_{self.screenshot_count:03d}{action_suffix}_{timestamp}.png"
|
|
150
|
-
|
|
151
|
-
# Save directly to the turn directory (no screenshots subdirectory)
|
|
152
|
-
filepath = os.path.join(self.current_turn_dir, filename)
|
|
153
|
-
|
|
154
|
-
# Save the screenshot
|
|
155
|
-
img_data = base64.b64decode(img_base64)
|
|
156
|
-
with open(filepath, "wb") as f:
|
|
157
|
-
f.write(img_data)
|
|
158
|
-
|
|
159
|
-
# Keep track of the file path for reference
|
|
160
|
-
self.screenshot_paths.append(filepath)
|
|
161
|
-
|
|
162
|
-
return filepath
|
|
163
|
-
except Exception as e:
|
|
164
|
-
logger.error(f"Error saving screenshot: {str(e)}")
|
|
165
|
-
return None
|
|
166
|
-
|
|
167
|
-
def should_save_debug_image(self) -> bool:
|
|
168
|
-
"""Determine if debug images should be saved.
|
|
169
|
-
|
|
170
|
-
Returns:
|
|
171
|
-
Boolean indicating if debug images should be saved
|
|
172
|
-
"""
|
|
173
|
-
# We no longer need to save debug images, so always return False
|
|
174
|
-
return False
|
|
175
|
-
|
|
176
|
-
def save_action_visualization(
|
|
177
|
-
self, img: Image.Image, action_name: str, details: str = ""
|
|
178
|
-
) -> str:
|
|
179
|
-
"""Save a visualization of an action.
|
|
180
|
-
|
|
181
|
-
Args:
|
|
182
|
-
img: Image to save
|
|
183
|
-
action_name: Name of the action
|
|
184
|
-
details: Additional details about the action
|
|
185
|
-
|
|
186
|
-
Returns:
|
|
187
|
-
Path to the saved image
|
|
188
|
-
"""
|
|
189
|
-
if not self.current_turn_dir:
|
|
190
|
-
return ""
|
|
191
|
-
|
|
192
|
-
try:
|
|
193
|
-
# Create a descriptive filename
|
|
194
|
-
timestamp = int(time.time() * 1000)
|
|
195
|
-
details_suffix = f"_{details}" if details else ""
|
|
196
|
-
filename = f"vis_{action_name}{details_suffix}_{timestamp}.png"
|
|
197
|
-
|
|
198
|
-
# Save directly to the turn directory (no visualizations subdirectory)
|
|
199
|
-
filepath = os.path.join(self.current_turn_dir, filename)
|
|
200
|
-
|
|
201
|
-
# Save the image
|
|
202
|
-
img.save(filepath)
|
|
203
|
-
|
|
204
|
-
# Keep track of the file path for cleanup
|
|
205
|
-
self.screenshot_paths.append(filepath)
|
|
206
|
-
|
|
207
|
-
return filepath
|
|
208
|
-
except Exception as e:
|
|
209
|
-
logger.error(f"Error saving action visualization: {str(e)}")
|
|
210
|
-
return ""
|
|
211
|
-
|
|
212
|
-
def extract_and_save_images(self, data: Any, prefix: str) -> None:
|
|
213
|
-
"""Extract and save images from response data.
|
|
214
|
-
|
|
215
|
-
Args:
|
|
216
|
-
data: Response data to extract images from
|
|
217
|
-
prefix: Prefix for saved image filenames
|
|
218
|
-
"""
|
|
219
|
-
# Since we no longer want to save extracted images separately,
|
|
220
|
-
# we'll skip this functionality entirely
|
|
221
|
-
return
|
|
222
|
-
|
|
223
|
-
def log_api_call(
|
|
224
|
-
self,
|
|
225
|
-
call_type: str,
|
|
226
|
-
request: Any,
|
|
227
|
-
provider: str,
|
|
228
|
-
model: str,
|
|
229
|
-
response: Any = None,
|
|
230
|
-
error: Optional[Exception] = None,
|
|
231
|
-
) -> None:
|
|
232
|
-
"""Log API call details to file.
|
|
233
|
-
|
|
234
|
-
Args:
|
|
235
|
-
call_type: Type of API call (e.g., 'request', 'response', 'error')
|
|
236
|
-
request: The API request data
|
|
237
|
-
provider: The AI provider used
|
|
238
|
-
model: The AI model used
|
|
239
|
-
response: Optional API response data
|
|
240
|
-
error: Optional error information
|
|
241
|
-
"""
|
|
242
|
-
if not self.current_turn_dir:
|
|
243
|
-
return
|
|
244
|
-
|
|
245
|
-
try:
|
|
246
|
-
# Create a unique filename with timestamp
|
|
247
|
-
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
|
248
|
-
filename = f"api_call_{timestamp}_{call_type}.json"
|
|
249
|
-
filepath = os.path.join(self.current_turn_dir, filename)
|
|
250
|
-
|
|
251
|
-
# Sanitize data to remove large base64 strings
|
|
252
|
-
sanitized_request = self.sanitize_log_data(request)
|
|
253
|
-
sanitized_response = self.sanitize_log_data(response) if response is not None else None
|
|
254
|
-
|
|
255
|
-
# Prepare log data
|
|
256
|
-
log_data = {
|
|
257
|
-
"timestamp": timestamp,
|
|
258
|
-
"provider": provider,
|
|
259
|
-
"model": model,
|
|
260
|
-
"type": call_type,
|
|
261
|
-
"request": sanitized_request,
|
|
262
|
-
}
|
|
263
|
-
|
|
264
|
-
if sanitized_response is not None:
|
|
265
|
-
log_data["response"] = sanitized_response
|
|
266
|
-
if error is not None:
|
|
267
|
-
log_data["error"] = str(error)
|
|
268
|
-
|
|
269
|
-
# Write to file
|
|
270
|
-
with open(filepath, "w") as f:
|
|
271
|
-
json.dump(log_data, f, indent=2, default=str)
|
|
272
|
-
|
|
273
|
-
logger.info(f"Logged API {call_type} to {filepath}")
|
|
274
|
-
|
|
275
|
-
except Exception as e:
|
|
276
|
-
logger.error(f"Error logging API call: {str(e)}")
|
agent/providers/omni/messages.py
DELETED
|
@@ -1,171 +0,0 @@
|
|
|
1
|
-
"""Omni message manager implementation."""
|
|
2
|
-
|
|
3
|
-
import base64
|
|
4
|
-
from typing import Any, Dict, List, Optional
|
|
5
|
-
from io import BytesIO
|
|
6
|
-
from PIL import Image
|
|
7
|
-
|
|
8
|
-
from ...core.messages import BaseMessageManager, ImageRetentionConfig
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
class OmniMessageManager(BaseMessageManager):
|
|
12
|
-
"""Message manager for multi-provider support."""
|
|
13
|
-
|
|
14
|
-
def __init__(self, config: Optional[ImageRetentionConfig] = None):
|
|
15
|
-
"""Initialize the message manager.
|
|
16
|
-
|
|
17
|
-
Args:
|
|
18
|
-
config: Optional configuration for image retention
|
|
19
|
-
"""
|
|
20
|
-
super().__init__(config)
|
|
21
|
-
self.messages: List[Dict[str, Any]] = []
|
|
22
|
-
self.config = config
|
|
23
|
-
|
|
24
|
-
def add_user_message(self, content: str, images: Optional[List[bytes]] = None) -> None:
|
|
25
|
-
"""Add a user message to the history.
|
|
26
|
-
|
|
27
|
-
Args:
|
|
28
|
-
content: Message content
|
|
29
|
-
images: Optional list of image data
|
|
30
|
-
"""
|
|
31
|
-
# Add images if present
|
|
32
|
-
if images:
|
|
33
|
-
# Initialize with proper typing for mixed content
|
|
34
|
-
message_content: List[Dict[str, Any]] = [{"type": "text", "text": content}]
|
|
35
|
-
|
|
36
|
-
# Add each image
|
|
37
|
-
for img in images:
|
|
38
|
-
message_content.append(
|
|
39
|
-
{
|
|
40
|
-
"type": "image_url",
|
|
41
|
-
"image_url": {
|
|
42
|
-
"url": f"data:image/png;base64,{base64.b64encode(img).decode()}"
|
|
43
|
-
},
|
|
44
|
-
}
|
|
45
|
-
)
|
|
46
|
-
|
|
47
|
-
message = {"role": "user", "content": message_content}
|
|
48
|
-
else:
|
|
49
|
-
# Simple text message
|
|
50
|
-
message = {"role": "user", "content": content}
|
|
51
|
-
|
|
52
|
-
self.messages.append(message)
|
|
53
|
-
|
|
54
|
-
# Apply retention policy
|
|
55
|
-
if self.config and self.config.num_images_to_keep:
|
|
56
|
-
self._apply_image_retention_policy()
|
|
57
|
-
|
|
58
|
-
def add_assistant_message(self, content: str) -> None:
|
|
59
|
-
"""Add an assistant message to the history.
|
|
60
|
-
|
|
61
|
-
Args:
|
|
62
|
-
content: Message content
|
|
63
|
-
"""
|
|
64
|
-
self.messages.append({"role": "assistant", "content": content})
|
|
65
|
-
|
|
66
|
-
def add_system_message(self, content: str) -> None:
|
|
67
|
-
"""Add a system message to the history.
|
|
68
|
-
|
|
69
|
-
Args:
|
|
70
|
-
content: Message content
|
|
71
|
-
"""
|
|
72
|
-
self.messages.append({"role": "system", "content": content})
|
|
73
|
-
|
|
74
|
-
def _apply_image_retention_policy(self) -> None:
|
|
75
|
-
"""Apply image retention policy to message history."""
|
|
76
|
-
if not self.config or not self.config.num_images_to_keep:
|
|
77
|
-
return
|
|
78
|
-
|
|
79
|
-
# Count images from newest to oldest
|
|
80
|
-
image_count = 0
|
|
81
|
-
for message in reversed(self.messages):
|
|
82
|
-
if message["role"] != "user":
|
|
83
|
-
continue
|
|
84
|
-
|
|
85
|
-
# Handle multimodal messages
|
|
86
|
-
if isinstance(message["content"], list):
|
|
87
|
-
new_content = []
|
|
88
|
-
for item in message["content"]:
|
|
89
|
-
if item["type"] == "text":
|
|
90
|
-
new_content.append(item)
|
|
91
|
-
elif item["type"] == "image_url":
|
|
92
|
-
if image_count < self.config.num_images_to_keep:
|
|
93
|
-
new_content.append(item)
|
|
94
|
-
image_count += 1
|
|
95
|
-
message["content"] = new_content
|
|
96
|
-
|
|
97
|
-
def get_formatted_messages(self, provider: str) -> List[Dict[str, Any]]:
|
|
98
|
-
"""Get messages formatted for specific provider.
|
|
99
|
-
|
|
100
|
-
Args:
|
|
101
|
-
provider: Provider name to format messages for
|
|
102
|
-
|
|
103
|
-
Returns:
|
|
104
|
-
List of formatted messages
|
|
105
|
-
"""
|
|
106
|
-
# Set the provider for message formatting
|
|
107
|
-
self.set_provider(provider)
|
|
108
|
-
|
|
109
|
-
if provider == "anthropic":
|
|
110
|
-
return self._format_for_anthropic()
|
|
111
|
-
elif provider == "openai":
|
|
112
|
-
return self._format_for_openai()
|
|
113
|
-
elif provider == "groq":
|
|
114
|
-
return self._format_for_groq()
|
|
115
|
-
elif provider == "qwen":
|
|
116
|
-
return self._format_for_qwen()
|
|
117
|
-
else:
|
|
118
|
-
raise ValueError(f"Unsupported provider: {provider}")
|
|
119
|
-
|
|
120
|
-
def _format_for_anthropic(self) -> List[Dict[str, Any]]:
|
|
121
|
-
"""Format messages for Anthropic API."""
|
|
122
|
-
formatted = []
|
|
123
|
-
for msg in self.messages:
|
|
124
|
-
formatted_msg = {"role": msg["role"]}
|
|
125
|
-
|
|
126
|
-
# Handle multimodal content
|
|
127
|
-
if isinstance(msg["content"], list):
|
|
128
|
-
formatted_msg["content"] = []
|
|
129
|
-
for item in msg["content"]:
|
|
130
|
-
if item["type"] == "text":
|
|
131
|
-
formatted_msg["content"].append({"type": "text", "text": item["text"]})
|
|
132
|
-
elif item["type"] == "image_url":
|
|
133
|
-
formatted_msg["content"].append(
|
|
134
|
-
{
|
|
135
|
-
"type": "image",
|
|
136
|
-
"source": {
|
|
137
|
-
"type": "base64",
|
|
138
|
-
"media_type": "image/png",
|
|
139
|
-
"data": item["image_url"]["url"].split(",")[1],
|
|
140
|
-
},
|
|
141
|
-
}
|
|
142
|
-
)
|
|
143
|
-
else:
|
|
144
|
-
formatted_msg["content"] = msg["content"]
|
|
145
|
-
|
|
146
|
-
formatted.append(formatted_msg)
|
|
147
|
-
return formatted
|
|
148
|
-
|
|
149
|
-
def _format_for_openai(self) -> List[Dict[str, Any]]:
|
|
150
|
-
"""Format messages for OpenAI API."""
|
|
151
|
-
# OpenAI already uses the same format
|
|
152
|
-
return self.messages
|
|
153
|
-
|
|
154
|
-
def _format_for_groq(self) -> List[Dict[str, Any]]:
|
|
155
|
-
"""Format messages for Groq API."""
|
|
156
|
-
# Groq uses OpenAI-compatible format
|
|
157
|
-
return self.messages
|
|
158
|
-
|
|
159
|
-
def _format_for_qwen(self) -> List[Dict[str, Any]]:
|
|
160
|
-
"""Format messages for Qwen API."""
|
|
161
|
-
formatted = []
|
|
162
|
-
for msg in self.messages:
|
|
163
|
-
if isinstance(msg["content"], list):
|
|
164
|
-
# Convert multimodal content to text-only
|
|
165
|
-
text_content = next(
|
|
166
|
-
(item["text"] for item in msg["content"] if item["type"] == "text"), ""
|
|
167
|
-
)
|
|
168
|
-
formatted.append({"role": msg["role"], "content": text_content})
|
|
169
|
-
else:
|
|
170
|
-
formatted.append(msg)
|
|
171
|
-
return formatted
|