cua-agent 0.3.2__py3-none-any.whl → 0.4.0b1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of cua-agent might be problematic.
- agent/__init__.py +15 -51
- agent/__main__.py +21 -0
- agent/adapters/__init__.py +9 -0
- agent/adapters/huggingfacelocal_adapter.py +216 -0
- agent/agent.py +577 -0
- agent/callbacks/__init__.py +17 -0
- agent/callbacks/base.py +153 -0
- agent/callbacks/budget_manager.py +44 -0
- agent/callbacks/image_retention.py +139 -0
- agent/callbacks/logging.py +247 -0
- agent/callbacks/pii_anonymization.py +259 -0
- agent/callbacks/trajectory_saver.py +305 -0
- agent/cli.py +290 -0
- agent/computer_handler.py +107 -0
- agent/decorators.py +90 -0
- agent/loops/__init__.py +11 -0
- agent/loops/anthropic.py +728 -0
- agent/loops/omniparser.py +339 -0
- agent/loops/openai.py +95 -0
- agent/loops/uitars.py +688 -0
- agent/responses.py +207 -0
- agent/types.py +79 -0
- agent/ui/__init__.py +7 -1
- agent/ui/gradio/__init__.py +6 -19
- agent/ui/gradio/app.py +80 -1299
- agent/ui/gradio/ui_components.py +703 -0
- cua_agent-0.4.0b1.dist-info/METADATA +424 -0
- cua_agent-0.4.0b1.dist-info/RECORD +30 -0
- agent/core/__init__.py +0 -27
- agent/core/agent.py +0 -210
- agent/core/base.py +0 -217
- agent/core/callbacks.py +0 -200
- agent/core/experiment.py +0 -249
- agent/core/factory.py +0 -122
- agent/core/messages.py +0 -332
- agent/core/provider_config.py +0 -21
- agent/core/telemetry.py +0 -142
- agent/core/tools/__init__.py +0 -21
- agent/core/tools/base.py +0 -74
- agent/core/tools/bash.py +0 -52
- agent/core/tools/collection.py +0 -46
- agent/core/tools/computer.py +0 -113
- agent/core/tools/edit.py +0 -67
- agent/core/tools/manager.py +0 -56
- agent/core/tools.py +0 -32
- agent/core/types.py +0 -88
- agent/core/visualization.py +0 -197
- agent/providers/__init__.py +0 -4
- agent/providers/anthropic/__init__.py +0 -6
- agent/providers/anthropic/api/client.py +0 -360
- agent/providers/anthropic/api/logging.py +0 -150
- agent/providers/anthropic/api_handler.py +0 -140
- agent/providers/anthropic/callbacks/__init__.py +0 -5
- agent/providers/anthropic/callbacks/manager.py +0 -65
- agent/providers/anthropic/loop.py +0 -568
- agent/providers/anthropic/prompts.py +0 -23
- agent/providers/anthropic/response_handler.py +0 -226
- agent/providers/anthropic/tools/__init__.py +0 -33
- agent/providers/anthropic/tools/base.py +0 -88
- agent/providers/anthropic/tools/bash.py +0 -66
- agent/providers/anthropic/tools/collection.py +0 -34
- agent/providers/anthropic/tools/computer.py +0 -396
- agent/providers/anthropic/tools/edit.py +0 -326
- agent/providers/anthropic/tools/manager.py +0 -54
- agent/providers/anthropic/tools/run.py +0 -42
- agent/providers/anthropic/types.py +0 -16
- agent/providers/anthropic/utils.py +0 -381
- agent/providers/omni/__init__.py +0 -8
- agent/providers/omni/api_handler.py +0 -42
- agent/providers/omni/clients/anthropic.py +0 -103
- agent/providers/omni/clients/base.py +0 -35
- agent/providers/omni/clients/oaicompat.py +0 -195
- agent/providers/omni/clients/ollama.py +0 -122
- agent/providers/omni/clients/openai.py +0 -155
- agent/providers/omni/clients/utils.py +0 -25
- agent/providers/omni/image_utils.py +0 -34
- agent/providers/omni/loop.py +0 -990
- agent/providers/omni/parser.py +0 -307
- agent/providers/omni/prompts.py +0 -64
- agent/providers/omni/tools/__init__.py +0 -30
- agent/providers/omni/tools/base.py +0 -29
- agent/providers/omni/tools/bash.py +0 -74
- agent/providers/omni/tools/computer.py +0 -179
- agent/providers/omni/tools/manager.py +0 -61
- agent/providers/omni/utils.py +0 -236
- agent/providers/openai/__init__.py +0 -6
- agent/providers/openai/api_handler.py +0 -456
- agent/providers/openai/loop.py +0 -472
- agent/providers/openai/response_handler.py +0 -205
- agent/providers/openai/tools/__init__.py +0 -15
- agent/providers/openai/tools/base.py +0 -79
- agent/providers/openai/tools/computer.py +0 -326
- agent/providers/openai/tools/manager.py +0 -106
- agent/providers/openai/types.py +0 -36
- agent/providers/openai/utils.py +0 -98
- agent/providers/uitars/__init__.py +0 -1
- agent/providers/uitars/clients/base.py +0 -35
- agent/providers/uitars/clients/mlxvlm.py +0 -263
- agent/providers/uitars/clients/oaicompat.py +0 -214
- agent/providers/uitars/loop.py +0 -660
- agent/providers/uitars/prompts.py +0 -63
- agent/providers/uitars/tools/__init__.py +0 -1
- agent/providers/uitars/tools/computer.py +0 -283
- agent/providers/uitars/tools/manager.py +0 -60
- agent/providers/uitars/utils.py +0 -264
- agent/telemetry.py +0 -21
- agent/ui/__main__.py +0 -15
- cua_agent-0.3.2.dist-info/METADATA +0 -295
- cua_agent-0.3.2.dist-info/RECORD +0 -87
- {cua_agent-0.3.2.dist-info → cua_agent-0.4.0b1.dist-info}/WHEEL +0 -0
- {cua_agent-0.3.2.dist-info → cua_agent-0.4.0b1.dist-info}/entry_points.txt +0 -0
--- agent/providers/uitars/clients/mlxvlm.py
+++ /dev/null
@@ -1,263 +0,0 @@
-"""MLX LVM client implementation."""
-
-import io
-import logging
-import base64
-import tempfile
-import os
-import re
-import math
-from typing import Dict, List, Optional, Any, cast, Tuple
-from PIL import Image
-
-from .base import BaseUITarsClient
-import mlx.core as mx
-from mlx_vlm import load, generate
-from mlx_vlm.prompt_utils import apply_chat_template
-from mlx_vlm.utils import load_config
-from transformers.tokenization_utils import PreTrainedTokenizer
-
-logger = logging.getLogger(__name__)
-
-# Constants for smart_resize
-IMAGE_FACTOR = 28
-MIN_PIXELS = 100 * 28 * 28
-MAX_PIXELS = 16384 * 28 * 28
-MAX_RATIO = 200
-
-def round_by_factor(number: float, factor: int) -> int:
-    """Returns the closest integer to 'number' that is divisible by 'factor'."""
-    return round(number / factor) * factor
-
-def ceil_by_factor(number: float, factor: int) -> int:
-    """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
-    return math.ceil(number / factor) * factor
-
-def floor_by_factor(number: float, factor: int) -> int:
-    """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
-    return math.floor(number / factor) * factor
-
-def smart_resize(
-    height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
-) -> tuple[int, int]:
-    """
-    Rescales the image so that the following conditions are met:
-
-    1. Both dimensions (height and width) are divisible by 'factor'.
-    2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
-    3. The aspect ratio of the image is maintained as closely as possible.
-    """
-    if max(height, width) / min(height, width) > MAX_RATIO:
-        raise ValueError(
-            f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}"
-        )
-    h_bar = max(factor, round_by_factor(height, factor))
-    w_bar = max(factor, round_by_factor(width, factor))
-    if h_bar * w_bar > max_pixels:
-        beta = math.sqrt((height * width) / max_pixels)
-        h_bar = floor_by_factor(height / beta, factor)
-        w_bar = floor_by_factor(width / beta, factor)
-    elif h_bar * w_bar < min_pixels:
-        beta = math.sqrt(min_pixels / (height * width))
-        h_bar = ceil_by_factor(height * beta, factor)
-        w_bar = ceil_by_factor(width * beta, factor)
-    return h_bar, w_bar
-
-class MLXVLMUITarsClient(BaseUITarsClient):
-    """MLX LVM client implementation class."""
-
-    def __init__(
-        self,
-        model: str = "mlx-community/UI-TARS-1.5-7B-4bit"
-    ):
-        """Initialize MLX LVM client.
-
-        Args:
-            model: Model name or path (defaults to mlx-community/UI-TARS-1.5-7B-4bit)
-        """
-        # Load model and processor
-        model_obj, processor = load(
-            model,
-            processor_kwargs={"min_pixels": MIN_PIXELS, "max_pixels": MAX_PIXELS}
-        )
-        self.config = load_config(model)
-        self.model = model_obj
-        self.processor = processor
-        self.model_name = model
-
-    def _process_coordinates(self, text: str, original_size: Tuple[int, int], model_size: Tuple[int, int]) -> str:
-        """Process coordinates in box tokens based on image resizing using smart_resize approach.
-
-        Args:
-            text: Text containing box tokens
-            original_size: Original image size (width, height)
-            model_size: Model processed image size (width, height)
-
-        Returns:
-            Text with processed coordinates
-        """
-        # Find all box tokens
-        box_pattern = r"<\|box_start\|>\((\d+),\s*(\d+)\)<\|box_end\|>"
-
-        def process_coords(match):
-            model_x, model_y = int(match.group(1)), int(match.group(2))
-            # Scale coordinates from model space to original image space
-            # Both original_size and model_size are in (width, height) format
-            new_x = int(model_x * original_size[0] / model_size[0])  # Width
-            new_y = int(model_y * original_size[1] / model_size[1])  # Height
-            return f"<|box_start|>({new_x},{new_y})<|box_end|>"
-
-        return re.sub(box_pattern, process_coords, text)
-
-    async def run_interleaved(
-        self, messages: List[Dict[str, Any]], system: str, max_tokens: Optional[int] = None
-    ) -> Dict[str, Any]:
-        """Run interleaved chat completion.
-
-        Args:
-            messages: List of message dicts
-            system: System prompt
-            max_tokens: Optional max tokens override
-
-        Returns:
-            Response dict
-        """
-        # Ensure the system message is included
-        if not any(msg.get("role") == "system" for msg in messages):
-            messages = [{"role": "system", "content": system}] + messages
-
-        # Create a deep copy of messages to avoid modifying the original
-        processed_messages = messages.copy()
-
-        # Extract images and process messages
-        images = []
-        original_sizes = {}  # Track original sizes of images for coordinate mapping
-        model_sizes = {}  # Track model processed sizes
-        image_index = 0
-
-        for msg_idx, msg in enumerate(messages):
-            content = msg.get("content", [])
-            if not isinstance(content, list):
-                continue
-
-            # Create a copy of the content list to modify
-            processed_content = []
-
-            for item_idx, item in enumerate(content):
-                if item.get("type") == "image_url":
-                    image_url = item.get("image_url", {}).get("url", "")
-                    pil_image = None
-
-                    if image_url.startswith("data:image/"):
-                        # Extract base64 data
-                        base64_data = image_url.split(',')[1]
-                        # Convert base64 to PIL Image
-                        image_data = base64.b64decode(base64_data)
-                        pil_image = Image.open(io.BytesIO(image_data))
-                    else:
-                        # Handle file path or URL
-                        pil_image = Image.open(image_url)
-
-                    # Store original image size for coordinate mapping
-                    original_size = pil_image.size
-                    original_sizes[image_index] = original_size
-
-                    # Use smart_resize to determine model size
-                    # Note: smart_resize expects (height, width) but PIL gives (width, height)
-                    height, width = original_size[1], original_size[0]
-                    new_height, new_width = smart_resize(height, width)
-                    # Store model size in (width, height) format for consistent coordinate processing
-                    model_sizes[image_index] = (new_width, new_height)
-
-                    # Resize the image using the calculated dimensions from smart_resize
-                    resized_image = pil_image.resize((new_width, new_height))
-                    images.append(resized_image)
-                    image_index += 1
-
-                # Copy items to processed content list
-                processed_content.append(item.copy())
-
-            # Update the processed message content
-            processed_messages[msg_idx] = msg.copy()
-            processed_messages[msg_idx]["content"] = processed_content
-
-        logger.info(f"resized {len(images)} from {original_sizes[0]} to {model_sizes[0]}")
-
-        # Process user text input with box coordinates after image processing
-        # Swap original_size and model_size arguments for inverse transformation
-        for msg_idx, msg in enumerate(processed_messages):
-            if msg.get("role") == "user" and isinstance(msg.get("content"), str):
-                if "<|box_start|>" in msg.get("content") and original_sizes and model_sizes and 0 in original_sizes and 0 in model_sizes:
-                    orig_size = original_sizes[0]
-                    model_size = model_sizes[0]
-                    # Swap arguments to perform inverse transformation for user input
-                    processed_messages[msg_idx]["content"] = self._process_coordinates(msg["content"], model_size, orig_size)
-
-        try:
-            # Format prompt according to model requirements using the processor directly
-            prompt = self.processor.apply_chat_template(
-                processed_messages,
-                tokenize=False,
-                add_generation_prompt=True
-            )
-            tokenizer = cast(PreTrainedTokenizer, self.processor)
-
-            print("generating response...")
-
-            # Generate response
-            text_content, usage = generate(
-                self.model,
-                tokenizer,
-                str(prompt),
-                images,
-                verbose=False,
-                max_tokens=max_tokens
-            )
-
-            from pprint import pprint
-            print("DEBUG - AGENT GENERATION --------")
-            pprint(text_content)
-            print("DEBUG - AGENT GENERATION --------")
-        except Exception as e:
-            logger.error(f"Error generating response: {str(e)}")
-            return {
-                "choices": [
-                    {
-                        "message": {
-                            "role": "assistant",
-                            "content": f"Error generating response: {str(e)}"
-                        },
-                        "finish_reason": "error"
-                    }
-                ],
-                "model": self.model_name,
-                "error": str(e)
-            }
-
-        # Process coordinates in the response back to original image space
-        if original_sizes and model_sizes and 0 in original_sizes and 0 in model_sizes:
-            # Get original image size and model size (using the first image)
-            orig_size = original_sizes[0]
-            model_size = model_sizes[0]
-
-            # Check if output contains box tokens that need processing
-            if "<|box_start|>" in text_content:
-                # Process coordinates from model space back to original image space
-                text_content = self._process_coordinates(text_content, orig_size, model_size)
-
-        # Format response to match OpenAI format
-        response = {
-            "choices": [
-                {
-                    "message": {
-                        "role": "assistant",
-                        "content": text_content
-                    },
-                    "finish_reason": "stop"
-                }
-            ],
-            "model": self.model_name,
-            "usage": usage
-        }
-
-        return response
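The removed mlxvlm.py keeps screen coordinates consistent across resizing: smart_resize snaps both dimensions to multiples of IMAGE_FACTOR (28) while keeping the pixel count inside [MIN_PIXELS, MAX_PIXELS], and _process_coordinates rescales <|box_start|>(x,y)<|box_end|> tokens between the resized and original spaces. A minimal, self-contained sketch of that round-trip (the snap helper inlines round_by_factor; the 1920x1080 example already fits the pixel budget, so the budget branches are omitted):

def snap(n: float, factor: int = 28) -> int:
    # round_by_factor from the deleted module: nearest multiple of factor
    return round(n / factor) * factor

width, height = 1920, 1080                # PIL reports size as (width, height)
new_w, new_h = snap(width), snap(height)  # 1932, 1092 -- both divisible by 28

# The model answers in resized space, e.g. <|box_start|>(966,546)<|box_end|>;
# _process_coordinates scales it back to the original screenshot:
model_x, model_y = 966, 546
orig_x = int(model_x * width / new_w)     # 960
orig_y = int(model_y * height / new_h)    # 540
print((orig_x, orig_y))                   # (960, 540)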
--- agent/providers/uitars/clients/oaicompat.py
+++ /dev/null
@@ -1,214 +0,0 @@
-"""OpenAI-compatible client implementation."""
-
-import os
-import logging
-from typing import Dict, List, Optional, Any
-import aiohttp
-import re
-from .base import BaseUITarsClient
-import asyncio
-
-logger = logging.getLogger(__name__)
-
-
-# OpenAI-compatible client for the UI_Tars
-class OAICompatClient(BaseUITarsClient):
-    """OpenAI-compatible API client implementation.
-
-    This client can be used with any service that implements the OpenAI API protocol, including:
-    - Huggingface Text Generation Interface endpoints
-    - vLLM
-    - LM Studio
-    - LocalAI
-    - Ollama (with OpenAI compatibility)
-    - Text Generation WebUI
-    - Any other service with OpenAI API compatibility
-    """
-
-    def __init__(
-        self,
-        api_key: Optional[str] = None,
-        model: str = "Qwen2.5-VL-7B-Instruct",
-        provider_base_url: Optional[str] = "http://localhost:8000/v1",
-        max_tokens: int = 4096,
-        temperature: float = 0.0,
-    ):
-        """Initialize the OpenAI-compatible client.
-
-        Args:
-            api_key: Not used for local endpoints, usually set to "EMPTY"
-            model: Model name to use
-            provider_base_url: API base URL. Typically in the format "http://localhost:PORT/v1"
-                Examples:
-                - vLLM: "http://localhost:8000/v1"
-                - LM Studio: "http://localhost:1234/v1"
-                - LocalAI: "http://localhost:8080/v1"
-                - Ollama: "http://localhost:11434/v1"
-            max_tokens: Maximum tokens to generate
-            temperature: Generation temperature
-        """
-        super().__init__(api_key=api_key or "EMPTY", model=model)
-        self.api_key = api_key or "EMPTY"  # Local endpoints typically don't require an API key
-        self.model = model
-        self.provider_base_url = (
-            provider_base_url or "http://localhost:8000/v1"
-        )  # Use default if None
-        self.max_tokens = max_tokens
-        self.temperature = temperature
-
-    def _extract_base64_image(self, text: str) -> Optional[str]:
-        """Extract base64 image data from an HTML img tag."""
-        pattern = r'data:image/[^;]+;base64,([^"]+)'
-        match = re.search(pattern, text)
-        return match.group(1) if match else None
-
-    def _get_loggable_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
-        """Create a loggable version of messages with image data truncated."""
-        loggable_messages = []
-        for msg in messages:
-            if isinstance(msg.get("content"), list):
-                new_content = []
-                for content in msg["content"]:
-                    if content.get("type") == "image":
-                        new_content.append(
-                            {"type": "image", "image_url": {"url": "[BASE64_IMAGE_DATA]"}}
-                        )
-                    else:
-                        new_content.append(content)
-                loggable_messages.append({"role": msg["role"], "content": new_content})
-            else:
-                loggable_messages.append(msg)
-        return loggable_messages
-
-    async def run_interleaved(
-        self, messages: List[Dict[str, Any]], system: str, max_tokens: Optional[int] = None
-    ) -> Dict[str, Any]:
-        """Run interleaved chat completion.
-
-        Args:
-            messages: List of message dicts
-            system: System prompt
-            max_tokens: Optional max tokens override
-
-        Returns:
-            Response dict
-        """
-        headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
-
-        final_messages = [
-            {
-                "role": "system",
-                "content": [
-                    { "type": "text", "text": system }
-                ]
-            }
-        ]
-
-        # Process messages
-        for item in messages:
-            if isinstance(item, dict):
-                if isinstance(item["content"], list):
-                    # Content is already in the correct format
-                    final_messages.append(item)
-                else:
-                    # Single string content, check for image
-                    base64_img = self._extract_base64_image(item["content"])
-                    if base64_img:
-                        message = {
-                            "role": item["role"],
-                            "content": [
-                                {
-                                    "type": "image_url",
-                                    "image_url": {"url": f"data:image/jpeg;base64,{base64_img}"},
-                                }
-                            ],
-                        }
-                    else:
-                        message = {
-                            "role": item["role"],
-                            "content": [{"type": "text", "text": item["content"]}],
-                        }
-                    final_messages.append(message)
-            else:
-                # String content, check for image
-                base64_img = self._extract_base64_image(item)
-                if base64_img:
-                    message = {
-                        "role": "user",
-                        "content": [
-                            {
-                                "type": "image_url",
-                                "image_url": {"url": f"data:image/jpeg;base64,{base64_img}"},
-                            }
-                        ],
-                    }
-                else:
-                    message = {"role": "user", "content": [{"type": "text", "text": item}]}
-                final_messages.append(message)
-
-        payload = {
-            "model": self.model,
-            "messages": final_messages,
-            "max_tokens": max_tokens or self.max_tokens,
-            "temperature": self.temperature,
-            "top_p": 0.7,
-        }
-
-        try:
-            async with aiohttp.ClientSession() as session:
-                # Use default base URL if none provided
-                base_url = self.provider_base_url or "http://localhost:8000/v1"
-
-                # Check if the base URL already includes the chat/completions endpoint
-
-                endpoint_url = base_url
-                if not endpoint_url.endswith("/chat/completions"):
-                    # If URL is RunPod format, make it OpenAI compatible
-                    if endpoint_url.startswith("https://api.runpod.ai/v2/"):
-                        # Extract RunPod endpoint ID
-                        parts = endpoint_url.split("/")
-                        if len(parts) >= 5:
-                            runpod_id = parts[4]
-                            endpoint_url = f"https://api.runpod.ai/v2/{runpod_id}/openai/v1/chat/completions"
-                    # If the URL ends with /v1, append /chat/completions
-                    elif endpoint_url.endswith("/v1"):
-                        endpoint_url = f"{endpoint_url}/chat/completions"
-                    # If the URL doesn't end with /v1, make sure it has a proper structure
-                    elif not endpoint_url.endswith("/"):
-                        endpoint_url = f"{endpoint_url}/chat/completions"
-                    else:
-                        endpoint_url = f"{endpoint_url}chat/completions"
-
-                # Log the endpoint URL for debugging
-                logger.debug(f"Using endpoint URL: {endpoint_url}")
-
-                async with session.post(endpoint_url, headers=headers, json=payload) as response:
-                    # Log the status and content type
-                    logger.debug(f"Status: {response.status}")
-                    logger.debug(f"Content-Type: {response.headers.get('Content-Type')}")
-
-                    # Get the raw text of the response
-                    response_text = await response.text()
-                    logger.debug(f"Response content: {response_text}")
-
-                    # if 503, then the endpoint is still warming up
-                    if response.status == 503:
-                        logger.error(f"Endpoint is still warming up, trying again in 30 seconds...")
-                        await asyncio.sleep(30)
-                        raise Exception(f"Endpoint is still warming up: {response_text}")
-
-                    # Try to parse as JSON if the content type is appropriate
-                    if "application/json" in response.headers.get('Content-Type', ''):
-                        response_json = await response.json()
-                    else:
-                        raise Exception(f"Response is not JSON format")
-
-                    if response.status != 200:
-                        logger.error(f"Error in API call: {response_text}")
-                        raise Exception(f"API error: {response_text}")
-
-                    return response_json
-
-        except Exception as e:
-            logger.error(f"Error in API call: {str(e)}")
-            raise
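The removed OAICompatClient normalized every message into OpenAI-style content parts, rewrote the base URL into a full /chat/completions endpoint (including the RunPod https://api.runpod.ai/v2/ form), and POSTed the payload with aiohttp. A minimal sketch of how it was driven under 0.3.2 (the task text is illustrative and assumes a local vLLM server on the default port):

import asyncio

from agent.providers.uitars.clients.oaicompat import OAICompatClient  # removed in 0.4.0b1

async def main() -> None:
    client = OAICompatClient(
        api_key="EMPTY",  # local endpoints typically ignore the key
        model="Qwen2.5-VL-7B-Instruct",
        provider_base_url="http://localhost:8000/v1",  # rewritten to .../v1/chat/completions
    )
    # run_interleaved prepends the system prompt and wraps plain strings
    # as {"type": "text"} / {"type": "image_url"} content parts.
    response = await client.run_interleaved(
        messages=[{"role": "user", "content": "Describe the current screen."}],
        system="You are a GUI agent.",
    )
    print(response["choices"][0]["message"]["content"])

asyncio.run(main())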