cua-agent 0.3.1__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in that registry.
Potentially problematic release.
This version of cua-agent might be problematic.
- agent/__init__.py +21 -12
- agent/__main__.py +21 -0
- agent/adapters/__init__.py +9 -0
- agent/adapters/huggingfacelocal_adapter.py +229 -0
- agent/agent.py +594 -0
- agent/callbacks/__init__.py +19 -0
- agent/callbacks/base.py +153 -0
- agent/callbacks/budget_manager.py +44 -0
- agent/callbacks/image_retention.py +139 -0
- agent/callbacks/logging.py +247 -0
- agent/callbacks/pii_anonymization.py +259 -0
- agent/callbacks/telemetry.py +210 -0
- agent/callbacks/trajectory_saver.py +305 -0
- agent/cli.py +297 -0
- agent/computer_handler.py +107 -0
- agent/decorators.py +90 -0
- agent/loops/__init__.py +11 -0
- agent/loops/anthropic.py +728 -0
- agent/loops/omniparser.py +339 -0
- agent/loops/openai.py +95 -0
- agent/loops/uitars.py +688 -0
- agent/responses.py +207 -0
- agent/telemetry.py +135 -14
- agent/types.py +79 -0
- agent/ui/__init__.py +7 -1
- agent/ui/__main__.py +2 -13
- agent/ui/gradio/__init__.py +6 -19
- agent/ui/gradio/app.py +94 -1313
- agent/ui/gradio/ui_components.py +721 -0
- cua_agent-0.4.0.dist-info/METADATA +424 -0
- cua_agent-0.4.0.dist-info/RECORD +33 -0
- {cua_agent-0.3.1.dist-info → cua_agent-0.4.0.dist-info}/WHEEL +1 -1
- agent/core/__init__.py +0 -27
- agent/core/agent.py +0 -210
- agent/core/base.py +0 -217
- agent/core/callbacks.py +0 -200
- agent/core/experiment.py +0 -249
- agent/core/factory.py +0 -122
- agent/core/messages.py +0 -332
- agent/core/provider_config.py +0 -21
- agent/core/telemetry.py +0 -142
- agent/core/tools/__init__.py +0 -21
- agent/core/tools/base.py +0 -74
- agent/core/tools/bash.py +0 -52
- agent/core/tools/collection.py +0 -46
- agent/core/tools/computer.py +0 -113
- agent/core/tools/edit.py +0 -67
- agent/core/tools/manager.py +0 -56
- agent/core/tools.py +0 -32
- agent/core/types.py +0 -88
- agent/core/visualization.py +0 -197
- agent/providers/__init__.py +0 -4
- agent/providers/anthropic/__init__.py +0 -6
- agent/providers/anthropic/api/client.py +0 -360
- agent/providers/anthropic/api/logging.py +0 -150
- agent/providers/anthropic/api_handler.py +0 -140
- agent/providers/anthropic/callbacks/__init__.py +0 -5
- agent/providers/anthropic/callbacks/manager.py +0 -65
- agent/providers/anthropic/loop.py +0 -568
- agent/providers/anthropic/prompts.py +0 -23
- agent/providers/anthropic/response_handler.py +0 -226
- agent/providers/anthropic/tools/__init__.py +0 -33
- agent/providers/anthropic/tools/base.py +0 -88
- agent/providers/anthropic/tools/bash.py +0 -66
- agent/providers/anthropic/tools/collection.py +0 -34
- agent/providers/anthropic/tools/computer.py +0 -396
- agent/providers/anthropic/tools/edit.py +0 -326
- agent/providers/anthropic/tools/manager.py +0 -54
- agent/providers/anthropic/tools/run.py +0 -42
- agent/providers/anthropic/types.py +0 -16
- agent/providers/anthropic/utils.py +0 -367
- agent/providers/omni/__init__.py +0 -8
- agent/providers/omni/api_handler.py +0 -42
- agent/providers/omni/clients/anthropic.py +0 -103
- agent/providers/omni/clients/base.py +0 -35
- agent/providers/omni/clients/oaicompat.py +0 -195
- agent/providers/omni/clients/ollama.py +0 -122
- agent/providers/omni/clients/openai.py +0 -155
- agent/providers/omni/clients/utils.py +0 -25
- agent/providers/omni/image_utils.py +0 -34
- agent/providers/omni/loop.py +0 -990
- agent/providers/omni/parser.py +0 -307
- agent/providers/omni/prompts.py +0 -64
- agent/providers/omni/tools/__init__.py +0 -30
- agent/providers/omni/tools/base.py +0 -29
- agent/providers/omni/tools/bash.py +0 -74
- agent/providers/omni/tools/computer.py +0 -179
- agent/providers/omni/tools/manager.py +0 -61
- agent/providers/omni/utils.py +0 -236
- agent/providers/openai/__init__.py +0 -6
- agent/providers/openai/api_handler.py +0 -456
- agent/providers/openai/loop.py +0 -472
- agent/providers/openai/response_handler.py +0 -205
- agent/providers/openai/tools/__init__.py +0 -15
- agent/providers/openai/tools/base.py +0 -79
- agent/providers/openai/tools/computer.py +0 -326
- agent/providers/openai/tools/manager.py +0 -106
- agent/providers/openai/types.py +0 -36
- agent/providers/openai/utils.py +0 -98
- agent/providers/uitars/__init__.py +0 -1
- agent/providers/uitars/clients/base.py +0 -35
- agent/providers/uitars/clients/mlxvlm.py +0 -263
- agent/providers/uitars/clients/oaicompat.py +0 -214
- agent/providers/uitars/loop.py +0 -660
- agent/providers/uitars/prompts.py +0 -63
- agent/providers/uitars/tools/__init__.py +0 -1
- agent/providers/uitars/tools/computer.py +0 -283
- agent/providers/uitars/tools/manager.py +0 -60
- agent/providers/uitars/utils.py +0 -264
- cua_agent-0.3.1.dist-info/METADATA +0 -295
- cua_agent-0.3.1.dist-info/RECORD +0 -87
- {cua_agent-0.3.1.dist-info → cua_agent-0.4.0.dist-info}/entry_points.txt +0 -0
agent/providers/omni/clients/oaicompat.py (deleted)
@@ -1,195 +0,0 @@
-"""OpenAI-compatible client implementation."""
-
-import os
-import logging
-from typing import Dict, List, Optional, Any
-import aiohttp
-import re
-from .base import BaseOmniClient
-
-logger = logging.getLogger(__name__)
-
-
-# OpenAI-compatible client for the OmniLoop
-class OAICompatClient(BaseOmniClient):
-    """OpenAI-compatible API client implementation.
-
-    This client can be used with any service that implements the OpenAI API protocol, including:
-    - vLLM
-    - LM Studio
-    - LocalAI
-    - Ollama (with OpenAI compatibility)
-    - Text Generation WebUI
-    - Any other service with OpenAI API compatibility
-    """
-
-    def __init__(
-        self,
-        api_key: Optional[str] = None,
-        model: str = "Qwen2.5-VL-7B-Instruct",
-        provider_base_url: Optional[str] = "http://localhost:8000/v1",
-        max_tokens: int = 4096,
-        temperature: float = 0.0,
-    ):
-        """Initialize the OpenAI-compatible client.
-
-        Args:
-            api_key: Not used for local endpoints, usually set to "EMPTY"
-            model: Model name to use
-            provider_base_url: API base URL. Typically in the format "http://localhost:PORT/v1"
-                Examples:
-                - vLLM: "http://localhost:8000/v1"
-                - LM Studio: "http://localhost:1234/v1"
-                - LocalAI: "http://localhost:8080/v1"
-                - Ollama: "http://localhost:11434/v1"
-            max_tokens: Maximum tokens to generate
-            temperature: Generation temperature
-        """
-        super().__init__(api_key=api_key or "EMPTY", model=model)
-        self.api_key = api_key or "EMPTY"  # Local endpoints typically don't require an API key
-        self.model = model
-        self.provider_base_url = (
-            provider_base_url or "http://localhost:8000/v1"
-        )  # Use default if None
-        self.max_tokens = max_tokens
-        self.temperature = temperature
-
-    def _extract_base64_image(self, text: str) -> Optional[str]:
-        """Extract base64 image data from an HTML img tag."""
-        pattern = r'data:image/[^;]+;base64,([^"]+)'
-        match = re.search(pattern, text)
-        return match.group(1) if match else None
-
-    def _get_loggable_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
-        """Create a loggable version of messages with image data truncated."""
-        loggable_messages = []
-        for msg in messages:
-            if isinstance(msg.get("content"), list):
-                new_content = []
-                for content in msg["content"]:
-                    if content.get("type") == "image":
-                        new_content.append(
-                            {"type": "image", "image_url": {"url": "[BASE64_IMAGE_DATA]"}}
-                        )
-                    else:
-                        new_content.append(content)
-                loggable_messages.append({"role": msg["role"], "content": new_content})
-            else:
-                loggable_messages.append(msg)
-        return loggable_messages
-
-    async def run_interleaved(
-        self, messages: List[Dict[str, Any]], system: str, max_tokens: Optional[int] = None
-    ) -> Dict[str, Any]:
-        """Run interleaved chat completion.
-
-        Args:
-            messages: List of message dicts
-            system: System prompt
-            max_tokens: Optional max tokens override
-
-        Returns:
-            Response dict
-        """
-        headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
-
-        final_messages = [
-            {
-                "role": "system",
-                "content": [
-                    { "type": "text", "text": system }
-                ]
-            }
-        ]
-
-        # Process messages
-        for item in messages:
-            if isinstance(item, dict):
-                if isinstance(item["content"], list):
-                    # Content is already in the correct format
-                    final_messages.append(item)
-                else:
-                    # Single string content, check for image
-                    base64_img = self._extract_base64_image(item["content"])
-                    if base64_img:
-                        message = {
-                            "role": item["role"],
-                            "content": [
-                                {
-                                    "type": "image_url",
-                                    "image_url": {"url": f"data:image/jpeg;base64,{base64_img}"},
-                                }
-                            ],
-                        }
-                    else:
-                        message = {
-                            "role": item["role"],
-                            "content": [{
-                                "type": "text",
-                                "text": item["content"]
-                            }],
-                        }
-                    final_messages.append(message)
-            else:
-                # String content, check for image
-                base64_img = self._extract_base64_image(item)
-                if base64_img:
-                    message = {
-                        "role": "user",
-                        "content": [
-                            {
-                                "type": "image_url",
-                                "image_url": {"url": f"data:image/jpeg;base64,{base64_img}"},
-                            }
-                        ],
-                    }
-                else:
-                    message = {"role": "user", "content": [{"type": "text", "text": item}]}
-                final_messages.append(message)
-
-        payload = {"model": self.model, "messages": final_messages, "temperature": self.temperature}
-        payload["max_tokens"] = max_tokens or self.max_tokens
-
-        try:
-            async with aiohttp.ClientSession() as session:
-                # Use default base URL if none provided
-                base_url = self.provider_base_url or "http://localhost:8000/v1"
-
-                # Check if the base URL already includes the chat/completions endpoint
-
-                endpoint_url = base_url
-                if not endpoint_url.endswith("/chat/completions"):
-                    # If URL is RunPod format, make it OpenAI compatible
-                    if endpoint_url.startswith("https://api.runpod.ai/v2/"):
-                        # Extract RunPod endpoint ID
-                        parts = endpoint_url.split("/")
-                        if len(parts) >= 5:
-                            runpod_id = parts[4]
-                            endpoint_url = f"https://api.runpod.ai/v2/{runpod_id}/openai/v1/chat/completions"
-                    # If the URL ends with /v1, append /chat/completions
-                    elif endpoint_url.endswith("/v1"):
-                        endpoint_url = f"{endpoint_url}/chat/completions"
-                    # If the URL doesn't end with /v1, make sure it has a proper structure
-                    elif not endpoint_url.endswith("/"):
-                        endpoint_url = f"{endpoint_url}/chat/completions"
-                    else:
-                        endpoint_url = f"{endpoint_url}chat/completions"
-
-                # Log the endpoint URL for debugging
-                logger.debug(f"Using endpoint URL: {endpoint_url}")
-
-                async with session.post(endpoint_url, headers=headers, json=payload) as response:
-                    response_json = await response.json()
-
-                    if response.status != 200:
-                        error_msg = response_json.get("error", {}).get(
-                            "message", str(response_json)
-                        )
-                        logger.error(f"Error in API call: {error_msg}")
-                        raise Exception(f"API error: {error_msg}")
-
-                    return response_json
-
-        except Exception as e:
-            logger.error(f"Error in API call: {str(e)}")
-            raise
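For orientation, here is a minimal sketch of how the deleted OAICompatClient was typically driven. The constructor defaults and the run_interleaved signature come from the removed code above; the import path is the 0.3.1 module location, while the asyncio scaffolding, prompt text, and response access below are illustrative assumptions rather than code from the package.

```python
import asyncio

# 0.3.1 import path; this module is removed in 0.4.0.
from agent.providers.omni.clients.oaicompat import OAICompatClient


async def main() -> None:
    # Defaults point at a local OpenAI-compatible server (e.g. vLLM on port 8000).
    client = OAICompatClient(
        api_key="EMPTY",  # local endpoints typically ignore the key
        model="Qwen2.5-VL-7B-Instruct",
        provider_base_url="http://localhost:8000/v1",
    )
    response = await client.run_interleaved(
        messages=[{"role": "user", "content": "Describe the current screen."}],
        system="You are a GUI automation assistant.",  # illustrative prompt
        max_tokens=512,
    )
    # The client returns the raw chat-completions JSON from the backend.
    print(response["choices"][0]["message"]["content"])


asyncio.run(main())
```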
agent/providers/omni/clients/ollama.py (deleted)
@@ -1,122 +0,0 @@
-"""Ollama API client implementation."""
-
-import logging
-from typing import Any, Dict, List, Optional, Tuple, cast
-import asyncio
-from httpx import ConnectError, ReadTimeout
-
-from ollama import AsyncClient, Options
-from ollama import Message
-from .base import BaseOmniClient
-
-logger = logging.getLogger(__name__)
-
-
-class OllamaClient(BaseOmniClient):
-    """Client for making calls to Ollama API."""
-
-    def __init__(self, api_key: str, model: str, max_retries: int = 3, retry_delay: float = 1.0):
-        """Initialize the Ollama client.
-
-        Args:
-            api_key: Not used
-            model: Ollama model name (e.g. "gemma3:4b-it-q4_K_M")
-            max_retries: Maximum number of retries for API calls
-            retry_delay: Base delay between retries in seconds
-        """
-        if not model:
-            raise ValueError("Model name must be provided")
-
-        self.client = AsyncClient(
-            host="http://localhost:11434",
-        )
-        self.model: str = model  # Add explicit type annotation
-        self.max_retries = max_retries
-        self.retry_delay = retry_delay
-
-    def _convert_message_format(self, system: str, messages: List[Dict[str, Any]]) -> List[Any]:
-        """Convert messages from standard format to Ollama format.
-
-        Args:
-            messages: Messages in standard format
-
-        Returns:
-            Messages in Ollama format
-        """
-        ollama_messages = []
-
-        # Add system message
-        ollama_messages.append(
-            {
-                "role": "system",
-                "content": system,
-            }
-        )
-
-        for message in messages:
-            # Skip messages with empty content
-            if not message.get("content"):
-                continue
-            content = message.get("content", [{}])[0]
-            isImage = content.get("type", "") == "image_url"
-            isText = content.get("type", "") == "text"
-            if isText:
-                data = content.get("text", "")
-                ollama_messages.append({"role": message["role"], "content": data})
-            if isImage:
-                data = content.get("image_url", {}).get("url", "")
-                # remove header
-                data = data.removeprefix("data:image/png;base64,")
-                ollama_messages.append(
-                    {"role": message["role"], "content": "Use this image", "images": [data]}
-                )
-
-        # Cast the list to the correct type expected by Ollama
-        return cast(List[Any], ollama_messages)
-
-    async def run_interleaved(
-        self, messages: List[Dict[str, Any]], system: str, max_tokens: int
-    ) -> Any:
-        """Run model with interleaved conversation format.
-
-        Args:
-            messages: List of messages to process
-            system: System prompt
-            max_tokens: Not used
-
-        Returns:
-            Model response
-        """
-        last_error = None
-
-        for attempt in range(self.max_retries):
-            try:
-                # Convert messages to Ollama format
-                ollama_messages = self._convert_message_format(system, messages)
-
-                response = await self.client.chat(
-                    model=self.model,
-                    options=Options(
-                        temperature=0,
-                    ),
-                    messages=ollama_messages,
-                    format="json",
-                )
-
-                return response
-
-            except (ConnectError, ReadTimeout) as e:
-                last_error = e
-                logger.warning(
-                    f"Connection error on attempt {attempt + 1}/{self.max_retries}: {str(e)}"
-                )
-                if attempt < self.max_retries - 1:
-                    await asyncio.sleep(self.retry_delay * (attempt + 1))  # Exponential backoff
-                    continue
-
-            except Exception as e:
-                logger.error(f"Unexpected error in Ollama API call: {str(e)}")
-                raise RuntimeError(f"Ollama API call failed: {str(e)}")
-
-        # If we get here, all retries failed
-        raise RuntimeError(f"Connection error after {self.max_retries} retries: {str(last_error)}")
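The most load-bearing piece of the deleted OllamaClient is _convert_message_format, which rewrites OpenAI-style content blocks (the "standard format" its docstring refers to) into Ollama's flat message shape. A small before/after sketch; the message text, system prompt, and truncated base64 data are placeholders invented for illustration:

```python
# OpenAI-style content blocks, as accepted by _convert_message_format():
standard_messages = [
    {"role": "user", "content": [{"type": "text", "text": "Click the OK button"}]},
    {
        "role": "user",
        "content": [
            {"type": "image_url", "image_url": {"url": "data:image/png;base64,iVBORw0KGgo..."}}
        ],
    },
]

# What the removed conversion produced: the system prompt is prepended, text blocks
# become plain "content" strings, and images lose the data-URL header and move into
# the "images" list that ollama.AsyncClient.chat() accepts.
ollama_messages = [
    {"role": "system", "content": "You are a GUI automation assistant."},
    {"role": "user", "content": "Click the OK button"},
    {"role": "user", "content": "Use this image", "images": ["iVBORw0KGgo..."]},
]
```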
agent/providers/omni/clients/openai.py (deleted)
@@ -1,155 +0,0 @@
-"""OpenAI client implementation."""
-
-import os
-import logging
-from typing import Dict, List, Optional, Any
-import aiohttp
-import re
-from datetime import datetime
-from .base import BaseOmniClient
-
-logger = logging.getLogger(__name__)
-
-# OpenAI specific client for the OmniLoop
-
-
-class OpenAIClient(BaseOmniClient):
-    """OpenAI vision API client implementation."""
-
-    def __init__(
-        self,
-        api_key: Optional[str] = None,
-        model: str = "gpt-4o",
-        provider_base_url: str = "https://api.openai.com/v1",
-        max_tokens: int = 4096,
-        temperature: float = 0.0,
-    ):
-        """Initialize the OpenAI client.
-
-        Args:
-            api_key: OpenAI API key
-            model: Model to use
-            provider_base_url: API endpoint
-            max_tokens: Maximum tokens to generate
-            temperature: Generation temperature
-        """
-        super().__init__(api_key=api_key, model=model)
-        self.api_key = api_key or os.getenv("OPENAI_API_KEY")
-        if not self.api_key:
-            raise ValueError("No OpenAI API key provided")
-
-        self.model = model
-        self.provider_base_url = provider_base_url
-        self.max_tokens = max_tokens
-        self.temperature = temperature
-
-    def _extract_base64_image(self, text: str) -> Optional[str]:
-        """Extract base64 image data from an HTML img tag."""
-        pattern = r'data:image/[^;]+;base64,([^"]+)'
-        match = re.search(pattern, text)
-        return match.group(1) if match else None
-
-    def _get_loggable_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
-        """Create a loggable version of messages with image data truncated."""
-        loggable_messages = []
-        for msg in messages:
-            if isinstance(msg.get("content"), list):
-                new_content = []
-                for content in msg["content"]:
-                    if content.get("type") == "image":
-                        new_content.append(
-                            {"type": "image", "image_url": {"url": "[BASE64_IMAGE_DATA]"}}
-                        )
-                    else:
-                        new_content.append(content)
-                loggable_messages.append({"role": msg["role"], "content": new_content})
-            else:
-                loggable_messages.append(msg)
-        return loggable_messages
-
-    async def run_interleaved(
-        self, messages: List[Dict[str, Any]], system: str, max_tokens: Optional[int] = None
-    ) -> Dict[str, Any]:
-        """Run interleaved chat completion.
-
-        Args:
-            messages: List of message dicts
-            system: System prompt
-            max_tokens: Optional max tokens override
-
-        Returns:
-            Response dict
-        """
-        headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
-
-        final_messages = [{"role": "system", "content": system}]
-
-        # Process messages
-        for item in messages:
-            if isinstance(item, dict):
-                if isinstance(item["content"], list):
-                    # Content is already in the correct format
-                    final_messages.append(item)
-                else:
-                    # Single string content, check for image
-                    base64_img = self._extract_base64_image(item["content"])
-                    if base64_img:
-                        message = {
-                            "role": item["role"],
-                            "content": [
-                                {
-                                    "type": "image_url",
-                                    "image_url": {"url": f"data:image/jpeg;base64,{base64_img}"},
-                                }
-                            ],
-                        }
-                    else:
-                        message = {
-                            "role": item["role"],
-                            "content": [{"type": "text", "text": item["content"]}],
-                        }
-                    final_messages.append(message)
-            else:
-                # String content, check for image
-                base64_img = self._extract_base64_image(item)
-                if base64_img:
-                    message = {
-                        "role": "user",
-                        "content": [
-                            {
-                                "type": "image_url",
-                                "image_url": {"url": f"data:image/jpeg;base64,{base64_img}"},
-                            }
-                        ],
-                    }
-                else:
-                    message = {"role": "user", "content": [{"type": "text", "text": item}]}
-                final_messages.append(message)
-
-        payload = {"model": self.model, "messages": final_messages, "temperature": self.temperature}
-
-        if "o1" in self.model or "o3-mini" in self.model:
-            payload["reasoning_effort"] = "low"
-            payload["max_completion_tokens"] = max_tokens or self.max_tokens
-        else:
-            payload["max_tokens"] = max_tokens or self.max_tokens
-
-        try:
-            async with aiohttp.ClientSession() as session:
-                async with session.post(
-                    f"{self.provider_base_url}/chat/completions", headers=headers, json=payload
-                ) as response:
-                    response_json = await response.json()
-
-                    if response.status != 200:
-                        error_msg = response_json.get("error", {}).get(
-                            "message", str(response_json)
-                        )
-                        logger.error(f"Error in OpenAI API call: {error_msg}")
-                        raise Exception(f"OpenAI API error: {error_msg}")
-
-                    return response_json
-
-        except Exception as e:
-            logger.error(f"Error in OpenAI API call: {str(e)}")
-            raise
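One detail worth noting in the deleted OpenAIClient is the payload branch for reasoning models: any model matching "o1" or "o3-mini" was sent reasoning_effort="low" and max_completion_tokens instead of max_tokens. A standalone restatement of that branch for clarity; the helper name build_payload is invented for illustration and did not exist in the package:

```python
from typing import Any, Dict, List


def build_payload(
    model: str, messages: List[Dict[str, Any]], temperature: float, max_tokens: int
) -> Dict[str, Any]:
    """Mirror of the request-shaping logic in the removed OpenAIClient.run_interleaved()."""
    payload: Dict[str, Any] = {"model": model, "messages": messages, "temperature": temperature}
    if "o1" in model or "o3-mini" in model:
        # Reasoning models take max_completion_tokens rather than max_tokens.
        payload["reasoning_effort"] = "low"
        payload["max_completion_tokens"] = max_tokens
    else:
        payload["max_tokens"] = max_tokens
    return payload


print(build_payload("o3-mini", [], 0.0, 4096))
print(build_payload("gpt-4o", [], 0.0, 4096))
```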
agent/providers/omni/clients/utils.py (deleted)
@@ -1,25 +0,0 @@
-import base64
-
-def is_image_path(text: str) -> bool:
-    """Check if a text string is an image file path.
-
-    Args:
-        text: Text string to check
-
-    Returns:
-        True if text ends with image extension, False otherwise
-    """
-    image_extensions = (".jpg", ".jpeg", ".png", ".gif", ".bmp", ".tiff", ".tif")
-    return text.endswith(image_extensions)
-
-def encode_image(image_path: str) -> str:
-    """Encode image file to base64.
-
-    Args:
-        image_path: Path to image file
-
-    Returns:
-        Base64 encoded image string
-    """
-    with open(image_path, "rb") as image_file:
-        return base64.b64encode(image_file.read()).decode("utf-8")
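For completeness, a tiny usage sketch of the two removed helpers; the screenshot path is a placeholder and the import is the 0.3.1 module location:

```python
from agent.providers.omni.clients.utils import encode_image, is_image_path  # removed in 0.4.0

path = "screenshot.png"  # placeholder path
if is_image_path(path):
    b64 = encode_image(path)
    data_url = f"data:image/png;base64,{b64}"  # a data URL of the kind the omni clients pass around
```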
agent/providers/omni/image_utils.py (deleted)
@@ -1,34 +0,0 @@
-"""Image processing utilities for the Cua provider."""
-
-import base64
-import logging
-import re
-from io import BytesIO
-from typing import Optional, Tuple
-from PIL import Image
-
-logger = logging.getLogger(__name__)
-
-
-def decode_base64_image(img_base64: str) -> Optional[Image.Image]:
-    """Decode a base64 encoded image to a PIL Image.
-
-    Args:
-        img_base64: Base64 encoded image, may include data URL prefix
-
-    Returns:
-        PIL Image or None if decoding fails
-    """
-    try:
-        # Remove data URL prefix if present
-        if img_base64.startswith("data:image"):
-            img_base64 = img_base64.split(",")[1]
-
-        # Decode base64 to bytes
-        img_data = base64.b64decode(img_base64)
-
-        # Convert bytes to PIL Image
-        return Image.open(BytesIO(img_data))
-    except Exception as e:
-        logger.error(f"Error decoding base64 image: {str(e)}")
-        return None
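And the counterpart on the decoding side, again with a truncated placeholder data URL and the 0.3.1 import path:

```python
from agent.providers.omni.image_utils import decode_base64_image  # removed in 0.4.0

img = decode_base64_image("data:image/png;base64,iVBORw0KGgo...")  # truncated placeholder data
if img is not None:  # the helper returns None when decoding fails
    print(img.size)
```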