aient-1.0.29-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aient/__init__.py +1 -0
- aient/core/.git +1 -0
- aient/core/__init__.py +1 -0
- aient/core/log_config.py +6 -0
- aient/core/models.py +227 -0
- aient/core/request.py +1361 -0
- aient/core/response.py +531 -0
- aient/core/test/test_base_api.py +17 -0
- aient/core/test/test_image.py +15 -0
- aient/core/test/test_payload.py +92 -0
- aient/core/utils.py +655 -0
- aient/models/__init__.py +9 -0
- aient/models/audio.py +63 -0
- aient/models/base.py +270 -0
- aient/models/chatgpt.py +856 -0
- aient/models/claude.py +640 -0
- aient/models/duckduckgo.py +241 -0
- aient/models/gemini.py +357 -0
- aient/models/groq.py +268 -0
- aient/models/vertex.py +420 -0
- aient/plugins/__init__.py +32 -0
- aient/plugins/arXiv.py +48 -0
- aient/plugins/config.py +178 -0
- aient/plugins/image.py +72 -0
- aient/plugins/registry.py +116 -0
- aient/plugins/run_python.py +156 -0
- aient/plugins/today.py +19 -0
- aient/plugins/websearch.py +393 -0
- aient/utils/__init__.py +0 -0
- aient/utils/prompt.py +143 -0
- aient/utils/scripts.py +235 -0
- aient-1.0.29.dist-info/METADATA +119 -0
- aient-1.0.29.dist-info/RECORD +36 -0
- aient-1.0.29.dist-info/WHEEL +5 -0
- aient-1.0.29.dist-info/licenses/LICENSE +7 -0
- aient-1.0.29.dist-info/top_level.txt +1 -0
aient/core/utils.py
ADDED
@@ -0,0 +1,655 @@
import re
import io
import os
import json
import httpx
import base64
import asyncio
import urllib.parse
from time import time
from PIL import Image
from fastapi import HTTPException
from urllib.parse import urlparse
from collections import defaultdict

from .log_config import logger

def get_model_dict(provider):
    model_dict = {}
    for model in provider['model']:
        if type(model) == str:
            model_dict[model] = model
        if isinstance(model, dict):
            model_dict.update({new: old for old, new in model.items()})
    return model_dict

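# Illustrative usage (a sketch with a hypothetical provider config): plain strings
# map to themselves; dict entries {upstream_name: alias} are inverted so lookups
# go alias -> upstream_name.
#
#     >>> get_model_dict({"model": ["gpt-4o", {"claude-3-5-sonnet": "sonnet"}]})
#     {'gpt-4o': 'gpt-4o', 'sonnet': 'claude-3-5-sonnet'}
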
class BaseAPI:
    def __init__(
        self,
        api_url: str = "https://api.openai.com/v1/chat/completions",
    ):
        if api_url == "":
            api_url = "https://api.openai.com/v1/chat/completions"
        self.source_api_url: str = api_url
        from urllib.parse import urlparse, urlunparse
        parsed_url = urlparse(self.source_api_url)
        # print("parsed_url", parsed_url)
        if parsed_url.scheme == "":
            raise Exception("Error: API_URL is not set")
        if parsed_url.path != '/':
            before_v1 = parsed_url.path.split("chat/completions")[0]
            if not before_v1.endswith("/"):
                before_v1 = before_v1 + "/"
        else:
            before_v1 = ""
        self.base_url: str = urlunparse(parsed_url[:2] + ("",) + ("",) * 3)
        self.v1_url: str = urlunparse(parsed_url[:2] + (before_v1,) + ("",) * 3)
        self.v1_models: str = urlunparse(parsed_url[:2] + (before_v1 + "models",) + ("",) * 3)
        self.chat_url: str = urlunparse(parsed_url[:2] + (before_v1 + "chat/completions",) + ("",) * 3)
        self.image_url: str = urlunparse(parsed_url[:2] + (before_v1 + "images/generations",) + ("",) * 3)
        self.audio_transcriptions: str = urlunparse(parsed_url[:2] + (before_v1 + "audio/transcriptions",) + ("",) * 3)
        self.moderations: str = urlunparse(parsed_url[:2] + (before_v1 + "moderations",) + ("",) * 3)
        self.embeddings: str = urlunparse(parsed_url[:2] + (before_v1 + "embeddings",) + ("",) * 3)
        self.audio_speech: str = urlunparse(parsed_url[:2] + (before_v1 + "audio/speech",) + ("",) * 3)

        if parsed_url.hostname == "generativelanguage.googleapis.com":
            self.base_url = api_url
            self.v1_url = api_url
            self.chat_url = api_url
            self.embeddings = api_url

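# Illustrative endpoint derivation (a sketch with a hypothetical base URL; sibling
# endpoints are rebuilt from the path prefix before "chat/completions"):
#
#     >>> api = BaseAPI("https://example.com/v1/chat/completions")
#     >>> api.v1_models
#     'https://example.com/v1/models'
#     >>> api.image_url
#     'https://example.com/v1/images/generations'
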
def get_engine(provider, endpoint=None, original_model=""):
    parsed_url = urlparse(provider['base_url'])
    # print("parsed_url", parsed_url)
    engine = None
    stream = None
    if parsed_url.path.endswith("/v1beta") or parsed_url.path.endswith("/v1") or parsed_url.netloc == 'generativelanguage.googleapis.com':
        engine = "gemini"
    elif parsed_url.netloc.rstrip('/').endswith('aiplatform.googleapis.com'):
        engine = "vertex"
    elif parsed_url.netloc.rstrip('/').endswith('openai.azure.com') or parsed_url.netloc.rstrip('/').endswith('services.ai.azure.com'):
        engine = "azure"
    elif parsed_url.netloc == 'api.cloudflare.com':
        engine = "cloudflare"
    elif parsed_url.netloc == 'api.anthropic.com' or parsed_url.path.endswith("v1/messages"):
        engine = "claude"
    elif parsed_url.netloc == 'api.cohere.com':
        engine = "cohere"
        stream = True
    else:
        engine = "gpt"

    original_model = original_model.lower()
    if original_model \
            and "claude" not in original_model \
            and "gpt" not in original_model \
            and "deepseek" not in original_model \
            and "o1" not in original_model \
            and "o3" not in original_model \
            and "gemini" not in original_model \
            and "learnlm" not in original_model \
            and "grok" not in original_model \
            and parsed_url.netloc != 'api.cloudflare.com' \
            and parsed_url.netloc != 'api.cohere.com':
        engine = "openrouter"

    if "claude" in original_model and engine == "vertex":
        engine = "vertex-claude"

    if "gemini" in original_model and engine == "vertex":
        engine = "vertex-gemini"

    if provider.get("engine"):
        engine = provider["engine"]

    if endpoint == "/v1/images/generations" or "stable-diffusion" in original_model:
        engine = "dalle"
        stream = False

    if endpoint == "/v1/audio/transcriptions":
        engine = "whisper"
        stream = False

    if endpoint == "/v1/moderations":
        engine = "moderation"
        stream = False

    if endpoint == "/v1/embeddings":
        engine = "embedding"

    if endpoint == "/v1/audio/speech":
        engine = "tts"
        stream = False

    return engine, stream

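# Illustrative routing (a sketch with hypothetical provider configs): the engine is
# inferred from the base_url host/path, then refined by model name and endpoint.
#
#     >>> get_engine({"base_url": "https://api.anthropic.com/v1/messages"})
#     ('claude', None)
#     >>> get_engine({"base_url": "https://generativelanguage.googleapis.com/v1beta"}, original_model="gemini-1.5-pro")
#     ('gemini', None)
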
def get_proxy(proxy, client_config=None):
    # Avoid a shared mutable default argument
    if client_config is None:
        client_config = {}
    if proxy:
        # Parse the proxy URL
        parsed = urlparse(proxy)
        scheme = parsed.scheme.rstrip('h')

        if scheme == 'socks5':
            try:
                from httpx_socks import AsyncProxyTransport
                proxy = proxy.replace('socks5h://', 'socks5://')
                transport = AsyncProxyTransport.from_url(proxy)
                client_config["transport"] = transport
                # print("proxy", proxy)
            except ImportError:
                logger.error("httpx-socks package is required for SOCKS proxy support")
                raise ImportError("Please install httpx-socks package for SOCKS proxy support: pip install httpx-socks")
        else:
            client_config["proxies"] = {
                "http://": proxy,
                "https://": proxy
            }
    return client_config

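# Illustrative result (a sketch with a hypothetical proxy URL): an HTTP(S) proxy is
# passed through as an httpx "proxies" mapping, while a socks5:// URL is wrapped in
# an httpx_socks AsyncProxyTransport under the "transport" key instead.
#
#     >>> get_proxy("http://127.0.0.1:7890")
#     {'proxies': {'http://': 'http://127.0.0.1:7890', 'https://': 'http://127.0.0.1:7890'}}
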
def update_initial_model(provider):
    try:
        engine, stream_mode = get_engine(provider, endpoint=None, original_model="")
        # print("engine", engine, provider)
        api_url = provider['base_url']
        api = provider['api']
        proxy = safe_get(provider, "preferences", "proxy", default=None)
        client_config = get_proxy(proxy)
        if engine == "gemini":
            url = "https://generativelanguage.googleapis.com/v1beta/models"
            params = {"key": api}
            with httpx.Client(**client_config) as client:
                response = client.get(url, params=params)

            original_models = response.json()
            if original_models.get("error"):
                raise Exception({"error": original_models.get("error"), "endpoint": url, "api": api})

            models = {"data": []}
            for model in original_models["models"]:
                models["data"].append({
                    "id": model["name"].split("models/")[-1],
                })
        else:
            endpoint = BaseAPI(api_url=api_url)
            endpoint_models_url = endpoint.v1_models
            if isinstance(api, list):
                api = api[0]
            headers = {"Authorization": f"Bearer {api}"}
            response = httpx.get(
                endpoint_models_url,
                headers=headers,
                **client_config
            )
            models = response.json()
            if models.get("error"):
                logger.error({"error": models.get("error"), "endpoint": endpoint_models_url, "api": api})
                return []

        # print(models)
        models_list = models["data"]
        models_id = [model["id"] for model in models_list]
        set_models = set()
        for model_item in models_id:
            set_models.add(model_item)
        models_id = list(set_models)
        # print(models_id)
        return models_id
    except Exception as e:
        # print("error:", e)
        import traceback
        traceback.print_exc()
        return []

def safe_get(data, *keys, default=None):
    for key in keys:
        try:
            data = data[key] if isinstance(data, (dict, list)) else data.get(key)
        except (KeyError, IndexError, AttributeError, TypeError):
            return default
    if not data:
        return default
    return data

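# Illustrative usage (a sketch): walk nested dicts/lists and fall back to `default`
# on any missing key/index or falsy leaf, e.g. as used for provider preferences above.
#
#     >>> safe_get({"preferences": {"proxy": "socks5://127.0.0.1:1080"}}, "preferences", "proxy")
#     'socks5://127.0.0.1:1080'
#     >>> safe_get({"preferences": {}}, "preferences", "proxy", default=None) is None
#     True
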
def parse_rate_limit(limit_string):
    # Map time units to seconds
    time_units = {
        's': 1, 'sec': 1, 'second': 1,
        'm': 60, 'min': 60, 'minute': 60,
        'h': 3600, 'hr': 3600, 'hour': 3600,
        'd': 86400, 'day': 86400,
        'mo': 2592000, 'month': 2592000,
        'y': 31536000, 'year': 31536000
    }

    # Handle multiple limit clauses
    limits = []
    for limit in limit_string.split(','):
        limit = limit.strip()
        # Match the count and unit with a regex
        match = re.match(r'^(\d+)/(\w+)$', limit)
        if not match:
            raise ValueError(f"Invalid rate limit format: {limit}")

        count, unit = match.groups()
        count = int(count)

        # Convert the unit to seconds
        if unit not in time_units:
            raise ValueError(f"Unknown time unit: {unit}")

        seconds = time_units[unit]
        limits.append((count, seconds))

    return limits

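# Illustrative parse (a sketch): each "count/unit" clause becomes a (count, seconds)
# pair, so several limits can be enforced at once.
#
#     >>> parse_rate_limit("15/min,200/day")
#     [(15, 60), (200, 86400)]
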
class ThreadSafeCircularList:
    def __init__(self, items = [], rate_limit={"default": "999999/min"}, schedule_algorithm="round_robin"):
        if schedule_algorithm == "random":
            import random
            self.items = random.sample(items, len(items))
            self.schedule_algorithm = "random"
        elif schedule_algorithm == "round_robin":
            self.items = items
            self.schedule_algorithm = "round_robin"
        elif schedule_algorithm == "fixed_priority":
            self.items = items
            self.schedule_algorithm = "fixed_priority"
        else:
            self.items = items
            logger.warning(f"Unknown schedule algorithm: {schedule_algorithm}, use (round_robin, random, fixed_priority) instead")
            self.schedule_algorithm = "round_robin"
        self.index = 0
        self.lock = asyncio.Lock()
        # Two-level dict: the first level is the item, the second level is the model
        self.requests = defaultdict(lambda: defaultdict(list))
        self.cooling_until = defaultdict(float)
        self.rate_limits = {}
        if isinstance(rate_limit, dict):
            for rate_limit_model, rate_limit_value in rate_limit.items():
                self.rate_limits[rate_limit_model] = parse_rate_limit(rate_limit_value)
        elif isinstance(rate_limit, str):
            self.rate_limits["default"] = parse_rate_limit(rate_limit)
        else:
            logger.error(f"Error ThreadSafeCircularList: Unknown rate_limit type: {type(rate_limit)}, rate_limit: {rate_limit}")

    async def set_cooling(self, item: str, cooling_time: int = 60):
        """Put an item into the cooling state.

        Args:
            item: the item to cool down
            cooling_time: cooling period in seconds, 60 by default
        """
        if item is None:
            return
        now = time()
        async with self.lock:
            self.cooling_until[item] = now + cooling_time
            # Clear the request history for this item
            # self.requests[item] = []
            logger.warning(f"API key {item} is now cooling down for {cooling_time} seconds")

    async def is_rate_limited(self, item, model: str = None) -> bool:
        now = time()
        # Check whether the item is still cooling down
        if now < self.cooling_until[item]:
            return True

        # Determine the applicable rate limit
        if model:
            model_key = model
        else:
            model_key = "default"

        rate_limit = None
        # Try an exact match first
        if model and model in self.rate_limits:
            rate_limit = self.rate_limits[model]
        else:
            # Without an exact match, fall back to substring matching
            for limit_model in self.rate_limits:
                if limit_model != "default" and model and limit_model in model:
                    rate_limit = self.rate_limits[limit_model]
                    break

        # If nothing matched, use the default
        if rate_limit is None:
            rate_limit = self.rate_limits.get("default", [(999999, 60)])  # default limit

        # Check every rate limit clause
        for limit_count, limit_period in rate_limit:
            # Count requests using this model's own request history
            recent_requests = sum(1 for req in self.requests[item][model_key] if req > now - limit_period)
            if recent_requests >= limit_count:
                logger.warning(f"API key {item} has hit the rate limit for model {model_key} ({limit_count}/{limit_period} seconds)")
                return True

        # Prune request records that are too old
        max_period = max(period for _, period in rate_limit)
        self.requests[item][model_key] = [req for req in self.requests[item][model_key] if req > now - max_period]

        # Record the new request
        self.requests[item][model_key].append(now)
        return False

    async def next(self, model: str = None):
        async with self.lock:
            if self.schedule_algorithm == "fixed_priority":
                self.index = 0
            start_index = self.index
            while True:
                item = self.items[self.index]
                self.index = (self.index + 1) % len(self.items)

                if not await self.is_rate_limited(item, model):
                    return item

                # Every API key has been checked and all of them are rate limited
                if self.index == start_index:
                    logger.warning("All API keys are rate limited!")
                    raise HTTPException(status_code=429, detail="Too many requests")

    async def after_next_current(self):
        # Return the API currently in use; next() has already been called, so the current API is the previous item
        if len(self.items) == 0:
            return None
        async with self.lock:
            item = self.items[(self.index - 1) % len(self.items)]
            return item

    def get_items_count(self) -> int:
        """Return the number of items in the list.

        Returns:
            int: the length of the items list
        """
        return len(self.items)

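# Illustrative usage (a sketch with hypothetical keys, inside an async context):
# next() rotates round-robin, skipping keys that are cooling down or over their
# per-model rate limit, and raises HTTP 429 once every key is exhausted.
#
#     keys = ThreadSafeCircularList(["sk-a", "sk-b"], rate_limit="2/min")
#     first = await keys.next("gpt-4o")    # "sk-a"
#     second = await keys.next("gpt-4o")   # "sk-b"
#     await keys.set_cooling("sk-a", 60)   # "sk-a" is skipped for the next 60 s
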
def circular_list_encoder(obj):
    if isinstance(obj, ThreadSafeCircularList):
        return obj.to_dict()
    raise TypeError(f'Object of type {obj.__class__.__name__} is not JSON serializable')

provider_api_circular_list = defaultdict(ThreadSafeCircularList)

# [Regions currently available for GCP Vertex AI] https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/use-claude?hl=zh_cn
# https://cloud.google.com/vertex-ai/generative-ai/docs/learn/locations?hl=zh-cn#available-regions

# c3.5s
# us-east5
# europe-west1

# c3s
# us-east5
# us-central1
# asia-southeast1

# c3o
# us-east5

# c3h
# us-east5
# us-central1
# europe-west1
# europe-west4


c35s = ThreadSafeCircularList(["us-east5", "europe-west1"])
c3s = ThreadSafeCircularList(["us-east5", "us-central1", "asia-southeast1"])
c3o = ThreadSafeCircularList(["us-east5"])
c3h = ThreadSafeCircularList(["us-east5", "us-central1", "europe-west1", "europe-west4"])
gemini1 = ThreadSafeCircularList(["us-central1", "us-east4", "us-west1", "us-west4", "europe-west1", "europe-west2"])
gemini2 = ThreadSafeCircularList(["us-central1"])


# end_of_line = "\n\r\n"
# end_of_line = "\r\n"
# end_of_line = "\n\r"
end_of_line = "\n\n"
# end_of_line = "\r"
# end_of_line = "\n"

import random
import string
async def generate_sse_response(timestamp, model, content=None, tools_id=None, function_call_name=None, function_call_content=None, role=None, total_tokens=0, prompt_tokens=0, completion_tokens=0, reasoning_content=None):
    random.seed(timestamp)
    random_str = ''.join(random.choices(string.ascii_letters + string.digits, k=29))

    delta_content = {"role": "assistant", "content": content} if content else {}
    if reasoning_content:
        delta_content = {"role": "assistant", "content": "", "reasoning_content": reasoning_content}

    sample_data = {
        "id": f"chatcmpl-{random_str}",
        "object": "chat.completion.chunk",
        "created": timestamp,
        "model": model,
        "choices": [
            {
                "index": 0,
                "delta": delta_content,
                "logprobs": None,
                "finish_reason": None if content else "stop"
            }
        ],
        "usage": None,
        "system_fingerprint": "fp_d576307f90",
    }
    if function_call_content:
        sample_data["choices"][0]["delta"] = {"tool_calls": [{"index": 0, "function": {"arguments": function_call_content}}]}
    if tools_id and function_call_name:
        sample_data["choices"][0]["delta"] = {"tool_calls": [{"index": 0, "id": tools_id, "type": "function", "function": {"name": function_call_name, "arguments": ""}}]}
        # sample_data["choices"][0]["delta"] = {"tool_calls": [{"index": 0, "function": {"id": tools_id, "name": function_call_name}}]}
    if role:
        sample_data["choices"][0]["delta"] = {"role": role, "content": ""}
    if total_tokens:
        total_tokens = prompt_tokens + completion_tokens
        sample_data["usage"] = {"prompt_tokens": prompt_tokens, "completion_tokens": completion_tokens, "total_tokens": total_tokens}
        sample_data["choices"] = []
    json_data = json.dumps(sample_data, ensure_ascii=False)

    # Build the SSE response
    sse_response = f"data: {json_data}" + end_of_line

    return sse_response

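# Illustrative output (a sketch; the id suffix is pseudo-random, seeded from the
# timestamp): each call yields one OpenAI-style SSE chunk terminated by end_of_line.
#
#     await generate_sse_response(1700000000, "gpt-4o", content="Hi")
#     # 'data: {"id": "chatcmpl-...", "object": "chat.completion.chunk", ...}\n\n'
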
async def generate_no_stream_response(timestamp, model, content=None, tools_id=None, function_call_name=None, function_call_content=None, role=None, total_tokens=0, prompt_tokens=0, completion_tokens=0):
    random.seed(timestamp)
    random_str = ''.join(random.choices(string.ascii_letters + string.digits, k=29))
    sample_data = {
        "id": f"chatcmpl-{random_str}",
        "object": "chat.completion",
        "created": timestamp,
        "model": model,
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": role,
                    "content": content,
                    "refusal": None
                },
                "logprobs": None,
                "finish_reason": "stop"
            }
        ],
        "usage": None,
        "system_fingerprint": "fp_a7d06e42a7"
    }

    if function_call_name:
        if not tools_id:
            tools_id = f"call_{random_str}"
        sample_data = {
            "id": f"chatcmpl-{random_str}",
            "object": "chat.completion",
            "created": timestamp,
            "model": model,
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": None,
                        "tool_calls": [
                            {
                                "id": tools_id,
                                "type": "function",
                                "function": {
                                    "name": function_call_name,
                                    "arguments": json.dumps(function_call_content, ensure_ascii=False)
                                }
                            }
                        ],
                        "refusal": None
                    },
                    "logprobs": None,
                    "finish_reason": "tool_calls"
                }
            ],
            "usage": None,
            "service_tier": "default",
            "system_fingerprint": "fp_4691090a87"
        }

    if total_tokens:
        total_tokens = prompt_tokens + completion_tokens
        sample_data["usage"] = {"prompt_tokens": prompt_tokens, "completion_tokens": completion_tokens, "total_tokens": total_tokens}

    json_data = json.dumps(sample_data, ensure_ascii=False)

    return json_data

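# Illustrative output (a sketch): the non-stream variant returns a single JSON body
# shaped like an OpenAI chat.completion object instead of an SSE "data:" line.
#
#     await generate_no_stream_response(1700000000, "gpt-4o", content="Hi", role="assistant")
#     # '{"id": "chatcmpl-...", "object": "chat.completion", ..., "finish_reason": "stop", ...}'
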
def get_image_format(file_content):
    try:
        img = Image.open(io.BytesIO(file_content))
        return img.format.lower()
    except Exception:
        return None

def encode_image(image_path):
    with open(image_path, "rb") as image_file:
        file_content = image_file.read()
        img_format = get_image_format(file_content)
        if not img_format:
            raise ValueError("Unrecognized image format")
        base64_encoded = base64.b64encode(file_content).decode('utf-8')

        if img_format == 'png':
            return f"data:image/png;base64,{base64_encoded}"
        elif img_format in ['jpg', 'jpeg']:
            return f"data:image/jpeg;base64,{base64_encoded}"
        else:
            raise ValueError(f"Unsupported image format: {img_format}")

async def get_doc_from_url(url):
    filename = urllib.parse.unquote(url.split("/")[-1])
    transport = httpx.AsyncHTTPTransport(
        http2=True,
        verify=False,
        retries=1
    )
    async with httpx.AsyncClient(transport=transport) as client:
        try:
            response = await client.get(
                url,
                timeout=30.0
            )
            with open(filename, 'wb') as f:
                f.write(response.content)

        except httpx.RequestError as e:
            print(f"An error occurred while requesting {e.request.url!r}.")

    return filename

async def get_encode_image(image_url):
    filename = await get_doc_from_url(image_url)
    image_path = os.getcwd() + "/" + filename
    base64_image = encode_image(image_path)
    os.remove(image_path)
    return base64_image

# from PIL import Image
# import io
# def validate_image(image_data, image_type):
#     try:
#         decoded_image = base64.b64decode(image_data)
#         image = Image.open(io.BytesIO(decoded_image))

#         # Check whether the image format matches the declared type
#         # print("image.format", image.format)
#         if image_type == "image/png" and image.format != "PNG":
#             raise ValueError("Image is not a valid PNG")
#         elif image_type == "image/jpeg" and image.format not in ["JPEG", "JPG"]:
#             raise ValueError("Image is not a valid JPEG")

#         # If no exception was raised, the image is valid
#         return True
#     except Exception as e:
#         print(f"Image validation failed: {str(e)}")
#         return False

async def get_image_message(base64_image, engine = None):
    if base64_image.startswith("http"):
        base64_image = await get_encode_image(base64_image)
    colon_index = base64_image.index(":")
    semicolon_index = base64_image.index(";")
    image_type = base64_image[colon_index + 1:semicolon_index]

    if image_type == "image/webp":
        # Convert webp to png

        # Decode the base64 to get the raw image data
        image_data = base64.b64decode(base64_image.split(",")[1])

        # Open the webp image with PIL
        image = Image.open(io.BytesIO(image_data))

        # Convert it to PNG
        png_buffer = io.BytesIO()
        image.save(png_buffer, format="PNG")
        png_base64 = base64.b64encode(png_buffer.getvalue()).decode('utf-8')

        # Return the base64 of the PNG
        base64_image = f"data:image/png;base64,{png_base64}"
        image_type = "image/png"

    if "gpt" == engine or "openrouter" == engine or "azure" == engine:
        return {
            "type": "image_url",
            "image_url": {
                "url": base64_image,
            }
        }
    if "claude" == engine or "vertex-claude" == engine:
        # if not validate_image(base64_image.split(",")[1], image_type):
        #     raise ValueError(f"Invalid image format. Expected {image_type}")
        return {
            "type": "image",
            "source": {
                "type": "base64",
                "media_type": image_type,
                "data": base64_image.split(",")[1],
            }
        }
    if "gemini" == engine or "vertex-gemini" == engine:
        return {
            "inlineData": {
                "mimeType": image_type,
                "data": base64_image.split(",")[1],
            }
        }
    raise ValueError("Unknown engine")

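# Illustrative per-engine shapes (a sketch; data URLs abbreviated):
#
#     gpt / openrouter / azure  -> {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}}
#     claude / vertex-claude    -> {"type": "image", "source": {"type": "base64", "media_type": "image/png", "data": "..."}}
#     gemini / vertex-gemini    -> {"inlineData": {"mimeType": "image/png", "data": "..."}}
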
async def get_text_message(message, engine = None):
    if "gpt" == engine or "claude" == engine or "openrouter" == engine or "vertex-claude" == engine or "azure" == engine:
        return {"type": "text", "text": message}
    if "gemini" == engine or "vertex-gemini" == engine:
        return {"text": message}
    if engine == "cloudflare":
        return message
    if engine == "cohere":
        return message
    raise ValueError("Unknown engine")
aient/models/__init__.py
ADDED