aient-1.0.29-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
aient/core/utils.py ADDED
@@ -0,0 +1,655 @@
+ import re
+ import io
+ import os
+ import json
+ import httpx
+ import base64
+ import random
+ import string
+ import asyncio
+ import urllib.parse
+ from time import time
+ from PIL import Image
+ from fastapi import HTTPException
+ from urllib.parse import urlparse, urlunparse
+ from collections import defaultdict
+
+ from .log_config import logger
+
+ def get_model_dict(provider):
+     model_dict = {}
+     for model in provider['model']:
+         if isinstance(model, str):
+             model_dict[model] = model
+         elif isinstance(model, dict):
+             # A dict entry maps the upstream model name to the alias exposed to clients.
+             model_dict.update({new: old for old, new in model.items()})
+     return model_dict
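+ # Illustrative example (not from the package): given
+ # provider = {"model": ["gpt-4o", {"gpt-4o": "my-alias"}]},
+ # get_model_dict returns {"gpt-4o": "gpt-4o", "my-alias": "gpt-4o"},
+ # so requested names resolve back to upstream model names.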
+
+ class BaseAPI:
+     def __init__(
+         self,
+         api_url: str = "https://api.openai.com/v1/chat/completions",
+     ):
+         if api_url == "":
+             api_url = "https://api.openai.com/v1/chat/completions"
+         self.source_api_url: str = api_url
+         parsed_url = urlparse(self.source_api_url)
+         if parsed_url.scheme == "":
+             raise Exception("Error: API_URL is not set")
+         if parsed_url.path != '/':
+             before_v1 = parsed_url.path.split("chat/completions")[0]
+             if not before_v1.endswith("/"):
+                 before_v1 = before_v1 + "/"
+         else:
+             before_v1 = ""
+         self.base_url: str = urlunparse(parsed_url[:2] + ("",) + ("",) * 3)
+         self.v1_url: str = urlunparse(parsed_url[:2] + (before_v1,) + ("",) * 3)
+         self.v1_models: str = urlunparse(parsed_url[:2] + (before_v1 + "models",) + ("",) * 3)
+         self.chat_url: str = urlunparse(parsed_url[:2] + (before_v1 + "chat/completions",) + ("",) * 3)
+         self.image_url: str = urlunparse(parsed_url[:2] + (before_v1 + "images/generations",) + ("",) * 3)
+         self.audio_transcriptions: str = urlunparse(parsed_url[:2] + (before_v1 + "audio/transcriptions",) + ("",) * 3)
+         self.moderations: str = urlunparse(parsed_url[:2] + (before_v1 + "moderations",) + ("",) * 3)
+         self.embeddings: str = urlunparse(parsed_url[:2] + (before_v1 + "embeddings",) + ("",) * 3)
+         self.audio_speech: str = urlunparse(parsed_url[:2] + (before_v1 + "audio/speech",) + ("",) * 3)
+
+         # Gemini's API is not OpenAI-shaped, so every endpoint is the URL as given.
+         if parsed_url.hostname == "generativelanguage.googleapis.com":
+             self.base_url = api_url
+             self.v1_url = api_url
+             self.chat_url = api_url
+             self.embeddings = api_url
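+ # Illustrative example (not from the package): BaseAPI("https://example.com/v1/chat/completions")
+ # derives v1_url "https://example.com/v1/", v1_models "https://example.com/v1/models",
+ # chat_url "https://example.com/v1/chat/completions", and so on, all from the one URL.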
+
+ def get_engine(provider, endpoint=None, original_model=""):
+     parsed_url = urlparse(provider['base_url'])
+     engine = None
+     stream = None
+     if parsed_url.path.endswith("/v1beta") or parsed_url.path.endswith("/v1") or parsed_url.netloc == 'generativelanguage.googleapis.com':
+         engine = "gemini"
+     elif parsed_url.netloc.rstrip('/').endswith('aiplatform.googleapis.com'):
+         engine = "vertex"
+     elif parsed_url.netloc.rstrip('/').endswith('openai.azure.com') or parsed_url.netloc.rstrip('/').endswith('services.ai.azure.com'):
+         engine = "azure"
+     elif parsed_url.netloc == 'api.cloudflare.com':
+         engine = "cloudflare"
+     elif parsed_url.netloc == 'api.anthropic.com' or parsed_url.path.endswith("v1/messages"):
+         engine = "claude"
+     elif parsed_url.netloc == 'api.cohere.com':
+         engine = "cohere"
+         stream = True
+     else:
+         engine = "gpt"
+
+     # Models not recognised as belonging to a known vendor are routed through OpenRouter.
+     original_model = original_model.lower()
+     if original_model \
+         and "claude" not in original_model \
+         and "gpt" not in original_model \
+         and "deepseek" not in original_model \
+         and "o1" not in original_model \
+         and "o3" not in original_model \
+         and "gemini" not in original_model \
+         and "learnlm" not in original_model \
+         and "grok" not in original_model \
+         and parsed_url.netloc != 'api.cloudflare.com' \
+         and parsed_url.netloc != 'api.cohere.com':
+         engine = "openrouter"
+
+     if "claude" in original_model and engine == "vertex":
+         engine = "vertex-claude"
+
+     if "gemini" in original_model and engine == "vertex":
+         engine = "vertex-gemini"
+
+     # An explicit engine on the provider overrides the inferred one.
+     if provider.get("engine"):
+         engine = provider["engine"]
+
+     # Non-chat endpoints override the engine and disable streaming where it does not apply.
+     if endpoint == "/v1/images/generations" or "stable-diffusion" in original_model:
+         engine = "dalle"
+         stream = False
+
+     if endpoint == "/v1/audio/transcriptions":
+         engine = "whisper"
+         stream = False
+
+     if endpoint == "/v1/moderations":
+         engine = "moderation"
+         stream = False
+
+     if endpoint == "/v1/embeddings":
+         engine = "embedding"
+
+     if endpoint == "/v1/audio/speech":
+         engine = "tts"
+         stream = False
+
+     return engine, stream
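+ # Illustrative example (not from the package):
+ # get_engine({"base_url": "https://api.anthropic.com/v1/messages"}, original_model="claude-3-5-sonnet")
+ # returns ("claude", None); an explicit provider["engine"] always wins.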
+
+ def get_proxy(proxy, client_config=None):
+     # Use None as the default to avoid a shared mutable default argument.
+     if client_config is None:
+         client_config = {}
+     if proxy:
+         # Parse the proxy URL.
+         parsed = urlparse(proxy)
+         scheme = parsed.scheme.rstrip('h')
+
+         if scheme == 'socks5':
+             try:
+                 from httpx_socks import AsyncProxyTransport
+                 proxy = proxy.replace('socks5h://', 'socks5://')
+                 transport = AsyncProxyTransport.from_url(proxy)
+                 client_config["transport"] = transport
+             except ImportError:
+                 logger.error("httpx-socks package is required for SOCKS proxy support")
+                 raise ImportError("Please install httpx-socks package for SOCKS proxy support: pip install httpx-socks")
+         else:
+             client_config["proxies"] = {
+                 "http://": proxy,
+                 "https://": proxy
+             }
+     return client_config
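+ # Illustrative example (not from the package, requires httpx-socks):
+ # get_proxy("socks5h://127.0.0.1:1080") returns {"transport": <AsyncProxyTransport>};
+ # an http(s) proxy instead fills client_config["proxies"] for "http://" and "https://".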
+
+ def update_initial_model(provider):
+     try:
+         engine, stream_mode = get_engine(provider, endpoint=None, original_model="")
+         api_url = provider['base_url']
+         api = provider['api']
+         proxy = safe_get(provider, "preferences", "proxy", default=None)
+         client_config = get_proxy(proxy)
+         if engine == "gemini":
+             url = "https://generativelanguage.googleapis.com/v1beta/models"
+             params = {"key": api}
+             with httpx.Client(**client_config) as client:
+                 response = client.get(url, params=params)
+
+             original_models = response.json()
+             if original_models.get("error"):
+                 raise Exception({"error": original_models.get("error"), "endpoint": url, "api": api})
+
+             models = {"data": []}
+             for model in original_models["models"]:
+                 models["data"].append({
+                     "id": model["name"].split("models/")[-1],
+                 })
+         else:
+             endpoint = BaseAPI(api_url=api_url)
+             endpoint_models_url = endpoint.v1_models
+             if isinstance(api, list):
+                 api = api[0]
+             headers = {"Authorization": f"Bearer {api}"}
+             response = httpx.get(
+                 endpoint_models_url,
+                 headers=headers,
+                 **client_config
+             )
+             models = response.json()
+             if models.get("error"):
+                 logger.error({"error": models.get("error"), "endpoint": endpoint_models_url, "api": api})
+                 return []
+
+         models_list = models["data"]
+         # Deduplicate the model IDs before returning them.
+         models_id = list({model["id"] for model in models_list})
+         return models_id
+     except Exception:
+         import traceback
+         traceback.print_exc()
+         return []
+
+ def safe_get(data, *keys, default=None):
+     for key in keys:
+         try:
+             data = data[key] if isinstance(data, (dict, list)) else data.get(key)
+         except (KeyError, IndexError, AttributeError, TypeError):
+             return default
+     # Note: falsy results (0, "", [], {}) also fall back to the default.
+     if not data:
+         return default
+     return data
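+ # Illustrative example (not from the package):
+ # safe_get(provider, "preferences", "proxy", default=None) walks one key at a
+ # time through nested dicts/lists and returns the default on any miss.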
+
+ def parse_rate_limit(limit_string):
+     # Map time units to their length in seconds.
+     time_units = {
+         's': 1, 'sec': 1, 'second': 1,
+         'm': 60, 'min': 60, 'minute': 60,
+         'h': 3600, 'hr': 3600, 'hour': 3600,
+         'd': 86400, 'day': 86400,
+         'mo': 2592000, 'month': 2592000,
+         'y': 31536000, 'year': 31536000
+     }
+
+     # A limit string may contain several comma-separated limits.
+     limits = []
+     for limit in limit_string.split(','):
+         limit = limit.strip()
+         # Match "<count>/<unit>" with a regular expression.
+         match = re.match(r'^(\d+)/(\w+)$', limit)
+         if not match:
+             raise ValueError(f"Invalid rate limit format: {limit}")
+
+         count, unit = match.groups()
+         count = int(count)
+
+         # Convert the unit to seconds.
+         if unit not in time_units:
+             raise ValueError(f"Unknown time unit: {unit}")
+
+         seconds = time_units[unit]
+         limits.append((count, seconds))
+
+     return limits
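+ # Illustrative example (not from the package):
+ # parse_rate_limit("15/min,1500/day") == [(15, 60), (1500, 86400)]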
+
+ class ThreadSafeCircularList:
+     def __init__(self, items=None, rate_limit={"default": "999999/min"}, schedule_algorithm="round_robin"):
+         # Use None as the default to avoid a shared mutable default argument.
+         if items is None:
+             items = []
+         if schedule_algorithm == "random":
+             self.items = random.sample(items, len(items))
+             self.schedule_algorithm = "random"
+         elif schedule_algorithm == "round_robin":
+             self.items = items
+             self.schedule_algorithm = "round_robin"
+         elif schedule_algorithm == "fixed_priority":
+             self.items = items
+             self.schedule_algorithm = "fixed_priority"
+         else:
+             self.items = items
+             logger.warning(f"Unknown schedule algorithm: {schedule_algorithm}, use (round_robin, random, fixed_priority) instead")
+             self.schedule_algorithm = "round_robin"
+         self.index = 0
+         self.lock = asyncio.Lock()
+         # A two-level dict: the first level is the item (API key), the second is the model.
+         self.requests = defaultdict(lambda: defaultdict(list))
+         self.cooling_until = defaultdict(float)
+         self.rate_limits = {}
+         if isinstance(rate_limit, dict):
+             for rate_limit_model, rate_limit_value in rate_limit.items():
+                 self.rate_limits[rate_limit_model] = parse_rate_limit(rate_limit_value)
+         elif isinstance(rate_limit, str):
+             self.rate_limits["default"] = parse_rate_limit(rate_limit)
+         else:
+             logger.error(f"Error ThreadSafeCircularList: Unknown rate_limit type: {type(rate_limit)}, rate_limit: {rate_limit}")
+
+     async def set_cooling(self, item: str, cooling_time: int = 60):
+         """Put an item into a cooling-off state.
+
+         Args:
+             item: the item to cool down
+             cooling_time: cooling period in seconds, 60 by default
+         """
+         if item is None:
+             return
+         now = time()
+         async with self.lock:
+             self.cooling_until[item] = now + cooling_time
+             # The item's request history is intentionally kept.
+             # self.requests[item] = []
+             logger.warning(f"API key {item} entered cooling state for {cooling_time} seconds")
+
+     async def is_rate_limited(self, item, model: str = None) -> bool:
+         now = time()
+         # Check whether the item is still cooling down.
+         if now < self.cooling_until[item]:
+             return True
+
+         # Work out which rate limit applies.
+         if model:
+             model_key = model
+         else:
+             model_key = "default"
+
+         rate_limit = None
+         # Try an exact model match first.
+         if model and model in self.rate_limits:
+             rate_limit = self.rate_limits[model]
+         else:
+             # Otherwise fall back to a substring match.
+             for limit_model in self.rate_limits:
+                 if limit_model != "default" and model and limit_model in model:
+                     rate_limit = self.rate_limits[limit_model]
+                     break
+
+         # When nothing matched, use the default limit.
+         if rate_limit is None:
+             rate_limit = self.rate_limits.get("default", [(999999, 60)])
+
+         # Check every limit in turn.
+         for limit_count, limit_period in rate_limit:
+             # Count recent requests recorded for this item and model.
+             recent_requests = sum(1 for req in self.requests[item][model_key] if req > now - limit_period)
+             if recent_requests >= limit_count:
+                 logger.warning(f"API key {item} has hit the rate limit for model {model_key} ({limit_count} per {limit_period} seconds)")
+                 return True
+
+         # Drop request records older than the longest limit period.
+         max_period = max(period for _, period in rate_limit)
+         self.requests[item][model_key] = [req for req in self.requests[item][model_key] if req > now - max_period]
+
+         # Record the new request.
+         self.requests[item][model_key].append(now)
+         return False
+
+     async def next(self, model: str = None):
+         async with self.lock:
+             if self.schedule_algorithm == "fixed_priority":
+                 self.index = 0
+             start_index = self.index
+             while True:
+                 item = self.items[self.index]
+                 self.index = (self.index + 1) % len(self.items)
+
+                 if not await self.is_rate_limited(item, model):
+                     return item
+
+                 # Every API key has been checked and all are rate limited.
+                 if self.index == start_index:
+                     logger.warning("All API keys are rate limited!")
+                     raise HTTPException(status_code=429, detail="Too many requests")
+
+     async def after_next_current(self):
+         # Return the item most recently handed out; next() has already advanced
+         # the index, so the current item is the previous one.
+         if len(self.items) == 0:
+             return None
+         async with self.lock:
+             item = self.items[(self.index - 1) % len(self.items)]
+             return item
+
+     def get_items_count(self) -> int:
+         """Return the number of items in the list.
+
+         Returns:
+             int: length of the items list
+         """
+         return len(self.items)
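+ # Illustrative usage (not from the package):
+ # keys = ThreadSafeCircularList(["sk-a", "sk-b"], rate_limit="10/min")
+ # key = await keys.next("gpt-4o")   # round-robin, skipping rate-limited keys
+ # await keys.set_cooling(key, 60)   # bench a failing key for a minute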
+
+ def circular_list_encoder(obj):
+     if isinstance(obj, ThreadSafeCircularList):
+         # Note: to_dict() is not defined in this file; the encoder assumes the
+         # instance provides one.
+         return obj.to_dict()
+     raise TypeError(f'Object of type {obj.__class__.__name__} is not JSON serializable')
+
+ provider_api_circular_list = defaultdict(ThreadSafeCircularList)
+
+ # Regions currently available for these models on GCP Vertex AI:
+ # https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/use-claude?hl=zh_cn
+ # https://cloud.google.com/vertex-ai/generative-ai/docs/learn/locations?hl=zh-cn#available-regions
+
+ # c3.5s
+ # us-east5
+ # europe-west1
+
+ # c3s
+ # us-east5
+ # us-central1
+ # asia-southeast1
+
+ # c3o
+ # us-east5
+
+ # c3h
+ # us-east5
+ # us-central1
+ # europe-west1
+ # europe-west4
+
+ c35s = ThreadSafeCircularList(["us-east5", "europe-west1"])
+ c3s = ThreadSafeCircularList(["us-east5", "us-central1", "asia-southeast1"])
+ c3o = ThreadSafeCircularList(["us-east5"])
+ c3h = ThreadSafeCircularList(["us-east5", "us-central1", "europe-west1", "europe-west4"])
+ gemini1 = ThreadSafeCircularList(["us-central1", "us-east4", "us-west1", "us-west4", "europe-west1", "europe-west2"])
+ gemini2 = ThreadSafeCircularList(["us-central1"])
+
+ # Candidate SSE delimiters that were tried: "\n\r\n", "\r\n", "\n\r", "\r", "\n"
+ end_of_line = "\n\n"
+
+ async def generate_sse_response(timestamp, model, content=None, tools_id=None, function_call_name=None, function_call_content=None, role=None, total_tokens=0, prompt_tokens=0, completion_tokens=0, reasoning_content=None):
+     random.seed(timestamp)
+     random_str = ''.join(random.choices(string.ascii_letters + string.digits, k=29))
+
+     delta_content = {"role": "assistant", "content": content} if content else {}
+     if reasoning_content:
+         delta_content = {"role": "assistant", "content": "", "reasoning_content": reasoning_content}
+
+     sample_data = {
+         "id": f"chatcmpl-{random_str}",
+         "object": "chat.completion.chunk",
+         "created": timestamp,
+         "model": model,
+         "choices": [
+             {
+                 "index": 0,
+                 "delta": delta_content,
+                 "logprobs": None,
+                 "finish_reason": None if content else "stop"
+             }
+         ],
+         "usage": None,
+         "system_fingerprint": "fp_d576307f90",
+     }
+     if function_call_content:
+         sample_data["choices"][0]["delta"] = {"tool_calls": [{"index": 0, "function": {"arguments": function_call_content}}]}
+     if tools_id and function_call_name:
+         sample_data["choices"][0]["delta"] = {"tool_calls": [{"index": 0, "id": tools_id, "type": "function", "function": {"name": function_call_name, "arguments": ""}}]}
+     if role:
+         sample_data["choices"][0]["delta"] = {"role": role, "content": ""}
+     if total_tokens:
+         total_tokens = prompt_tokens + completion_tokens
+         sample_data["usage"] = {"prompt_tokens": prompt_tokens, "completion_tokens": completion_tokens, "total_tokens": total_tokens}
+         sample_data["choices"] = []
+     json_data = json.dumps(sample_data, ensure_ascii=False)
+
+     # Build the SSE response frame.
+     sse_response = f"data: {json_data}" + end_of_line
+
+     return sse_response
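+ # Illustrative output (not from the package): each call yields one SSE frame, e.g.
+ # data: {"id": "chatcmpl-...", "object": "chat.completion.chunk", ...}
+ # followed by the blank-line delimiter in end_of_line.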
+
+ async def generate_no_stream_response(timestamp, model, content=None, tools_id=None, function_call_name=None, function_call_content=None, role=None, total_tokens=0, prompt_tokens=0, completion_tokens=0):
+     random.seed(timestamp)
+     random_str = ''.join(random.choices(string.ascii_letters + string.digits, k=29))
+     sample_data = {
+         "id": f"chatcmpl-{random_str}",
+         "object": "chat.completion",
+         "created": timestamp,
+         "model": model,
+         "choices": [
+             {
+                 "index": 0,
+                 "message": {
+                     "role": role,
+                     "content": content,
+                     "refusal": None
+                 },
+                 "logprobs": None,
+                 "finish_reason": "stop"
+             }
+         ],
+         "usage": None,
+         "system_fingerprint": "fp_a7d06e42a7"
+     }
+
+     if function_call_name:
+         if not tools_id:
+             tools_id = f"call_{random_str}"
+         sample_data = {
+             "id": f"chatcmpl-{random_str}",
+             "object": "chat.completion",
+             "created": timestamp,
+             "model": model,
+             "choices": [
+                 {
+                     "index": 0,
+                     "message": {
+                         "role": "assistant",
+                         "content": None,
+                         "tool_calls": [
+                             {
+                                 "id": tools_id,
+                                 "type": "function",
+                                 "function": {
+                                     "name": function_call_name,
+                                     "arguments": json.dumps(function_call_content, ensure_ascii=False)
+                                 }
+                             }
+                         ],
+                         "refusal": None
+                     },
+                     "logprobs": None,
+                     "finish_reason": "tool_calls"
+                 }
+             ],
+             "usage": None,
+             "service_tier": "default",
+             "system_fingerprint": "fp_4691090a87"
+         }
+
+     if total_tokens:
+         total_tokens = prompt_tokens + completion_tokens
+         sample_data["usage"] = {"prompt_tokens": prompt_tokens, "completion_tokens": completion_tokens, "total_tokens": total_tokens}
+
+     json_data = json.dumps(sample_data, ensure_ascii=False)
+
+     return json_data
+
+ def get_image_format(file_content):
+     try:
+         img = Image.open(io.BytesIO(file_content))
+         return img.format.lower()
+     except Exception:
+         return None
+
+ def encode_image(image_path):
+     with open(image_path, "rb") as image_file:
+         file_content = image_file.read()
+         img_format = get_image_format(file_content)
+         if not img_format:
+             raise ValueError("Unrecognized image format")
+         base64_encoded = base64.b64encode(file_content).decode('utf-8')
+
+         if img_format == 'png':
+             return f"data:image/png;base64,{base64_encoded}"
+         elif img_format in ['jpg', 'jpeg']:
+             return f"data:image/jpeg;base64,{base64_encoded}"
+         else:
+             raise ValueError(f"Unsupported image format: {img_format}")
+
+ async def get_doc_from_url(url):
+     filename = urllib.parse.unquote(url.split("/")[-1])
+     transport = httpx.AsyncHTTPTransport(
+         http2=True,
+         verify=False,
+         retries=1
+     )
+     async with httpx.AsyncClient(transport=transport) as client:
+         try:
+             response = await client.get(
+                 url,
+                 timeout=30.0
+             )
+             with open(filename, 'wb') as f:
+                 f.write(response.content)
+
+         except httpx.RequestError as e:
+             print(f"An error occurred while requesting {e.request.url!r}.")
+
+     return filename
+
+ async def get_encode_image(image_url):
+     filename = await get_doc_from_url(image_url)
+     image_path = os.getcwd() + "/" + filename
+     base64_image = encode_image(image_path)
+     os.remove(image_path)
+     return base64_image
+
+ # from PIL import Image
+ # import io
+ # def validate_image(image_data, image_type):
+ #     try:
+ #         decoded_image = base64.b64decode(image_data)
+ #         image = Image.open(io.BytesIO(decoded_image))
+
+ #         # Check that the image format matches the declared type.
+ #         if image_type == "image/png" and image.format != "PNG":
+ #             raise ValueError("Image is not a valid PNG")
+ #         elif image_type == "image/jpeg" and image.format not in ["JPEG", "JPG"]:
+ #             raise ValueError("Image is not a valid JPEG")
+
+ #         # No exception means the image is valid.
+ #         return True
+ #     except Exception as e:
+ #         print(f"Image validation failed: {str(e)}")
+ #         return False
594
+ async def get_image_message(base64_image, engine = None):
595
+ if base64_image.startswith("http"):
596
+ base64_image = await get_encode_image(base64_image)
597
+ colon_index = base64_image.index(":")
598
+ semicolon_index = base64_image.index(";")
599
+ image_type = base64_image[colon_index + 1:semicolon_index]
600
+
601
+ if image_type == "image/webp":
602
+ # 将webp转换为png
603
+
604
+ # 解码base64获取图片数据
605
+ image_data = base64.b64decode(base64_image.split(",")[1])
606
+
607
+ # 使用PIL打开webp图片
608
+ image = Image.open(io.BytesIO(image_data))
609
+
610
+ # 转换为PNG格式
611
+ png_buffer = io.BytesIO()
612
+ image.save(png_buffer, format="PNG")
613
+ png_base64 = base64.b64encode(png_buffer.getvalue()).decode('utf-8')
614
+
615
+ # 返回PNG格式的base64
616
+ base64_image = f"data:image/png;base64,{png_base64}"
617
+ image_type = "image/png"
618
+
619
+ if "gpt" == engine or "openrouter" == engine or "azure" == engine:
620
+ return {
621
+ "type": "image_url",
622
+ "image_url": {
623
+ "url": base64_image,
624
+ }
625
+ }
626
+ if "claude" == engine or "vertex-claude" == engine:
627
+ # if not validate_image(base64_image.split(",")[1], image_type):
628
+ # raise ValueError(f"Invalid image format. Expected {image_type}")
629
+ return {
630
+ "type": "image",
631
+ "source": {
632
+ "type": "base64",
633
+ "media_type": image_type,
634
+ "data": base64_image.split(",")[1],
635
+ }
636
+ }
637
+ if "gemini" == engine or "vertex-gemini" == engine:
638
+ return {
639
+ "inlineData": {
640
+ "mimeType": image_type,
641
+ "data": base64_image.split(",")[1],
642
+ }
643
+ }
644
+ raise ValueError("Unknown engine")
645
+
646
+ async def get_text_message(message, engine = None):
647
+ if "gpt" == engine or "claude" == engine or "openrouter" == engine or "vertex-claude" == engine or "azure" == engine:
648
+ return {"type": "text", "text": message}
649
+ if "gemini" == engine or "vertex-gemini" == engine:
650
+ return {"text": message}
651
+ if engine == "cloudflare":
652
+ return message
653
+ if engine == "cohere":
654
+ return message
655
+ raise ValueError("Unknown engine")
@@ -0,0 +1,9 @@
+ from .chatgpt import *
+ from .claude import *
+ from .gemini import *
+ from .vertex import *
+ from .groq import *
+ from .audio import *
+ from .duckduckgo import *
+
+ # __all__ = ["chatgpt", "claude", "claude3", "gemini", "groq"]