beswarm 0.1.12__py3-none-any.whl → 0.1.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (75)
  1. beswarm/aient/main.py +50 -0
  2. beswarm/aient/setup.py +15 -0
  3. beswarm/aient/src/aient/__init__.py +1 -0
  4. beswarm/aient/src/aient/core/__init__.py +1 -0
  5. beswarm/aient/src/aient/core/log_config.py +6 -0
  6. beswarm/aient/src/aient/core/models.py +232 -0
  7. beswarm/aient/src/aient/core/request.py +1665 -0
  8. beswarm/aient/src/aient/core/response.py +617 -0
  9. beswarm/aient/src/aient/core/test/test_base_api.py +18 -0
  10. beswarm/aient/src/aient/core/test/test_image.py +15 -0
  11. beswarm/aient/src/aient/core/test/test_payload.py +92 -0
  12. beswarm/aient/src/aient/core/utils.py +715 -0
  13. beswarm/aient/src/aient/models/__init__.py +9 -0
  14. beswarm/aient/src/aient/models/audio.py +63 -0
  15. beswarm/aient/src/aient/models/base.py +251 -0
  16. beswarm/aient/src/aient/models/chatgpt.py +938 -0
  17. beswarm/aient/src/aient/models/claude.py +640 -0
  18. beswarm/aient/src/aient/models/duckduckgo.py +241 -0
  19. beswarm/aient/src/aient/models/gemini.py +357 -0
  20. beswarm/aient/src/aient/models/groq.py +268 -0
  21. beswarm/aient/src/aient/models/vertex.py +420 -0
  22. beswarm/aient/src/aient/plugins/__init__.py +33 -0
  23. beswarm/aient/src/aient/plugins/arXiv.py +48 -0
  24. beswarm/aient/src/aient/plugins/config.py +172 -0
  25. beswarm/aient/src/aient/plugins/excute_command.py +35 -0
  26. beswarm/aient/src/aient/plugins/get_time.py +19 -0
  27. beswarm/aient/src/aient/plugins/image.py +72 -0
  28. beswarm/aient/src/aient/plugins/list_directory.py +50 -0
  29. beswarm/aient/src/aient/plugins/read_file.py +79 -0
  30. beswarm/aient/src/aient/plugins/registry.py +116 -0
  31. beswarm/aient/src/aient/plugins/run_python.py +156 -0
  32. beswarm/aient/src/aient/plugins/websearch.py +394 -0
  33. beswarm/aient/src/aient/plugins/write_file.py +51 -0
  34. beswarm/aient/src/aient/prompt/__init__.py +1 -0
  35. beswarm/aient/src/aient/prompt/agent.py +280 -0
  36. beswarm/aient/src/aient/utils/__init__.py +0 -0
  37. beswarm/aient/src/aient/utils/prompt.py +143 -0
  38. beswarm/aient/src/aient/utils/scripts.py +721 -0
  39. beswarm/aient/test/chatgpt.py +161 -0
  40. beswarm/aient/test/claude.py +32 -0
  41. beswarm/aient/test/test.py +2 -0
  42. beswarm/aient/test/test_API.py +6 -0
  43. beswarm/aient/test/test_Deepbricks.py +20 -0
  44. beswarm/aient/test/test_Web_crawler.py +262 -0
  45. beswarm/aient/test/test_aiwaves.py +25 -0
  46. beswarm/aient/test/test_aiwaves_arxiv.py +19 -0
  47. beswarm/aient/test/test_ask_gemini.py +8 -0
  48. beswarm/aient/test/test_class.py +17 -0
  49. beswarm/aient/test/test_claude.py +23 -0
  50. beswarm/aient/test/test_claude_zh_char.py +26 -0
  51. beswarm/aient/test/test_ddg_search.py +50 -0
  52. beswarm/aient/test/test_download_pdf.py +56 -0
  53. beswarm/aient/test/test_gemini.py +97 -0
  54. beswarm/aient/test/test_get_token_dict.py +21 -0
  55. beswarm/aient/test/test_google_search.py +35 -0
  56. beswarm/aient/test/test_jieba.py +32 -0
  57. beswarm/aient/test/test_json.py +65 -0
  58. beswarm/aient/test/test_langchain_search_old.py +235 -0
  59. beswarm/aient/test/test_logging.py +32 -0
  60. beswarm/aient/test/test_ollama.py +55 -0
  61. beswarm/aient/test/test_plugin.py +16 -0
  62. beswarm/aient/test/test_py_run.py +26 -0
  63. beswarm/aient/test/test_requests.py +162 -0
  64. beswarm/aient/test/test_search.py +18 -0
  65. beswarm/aient/test/test_tikitoken.py +19 -0
  66. beswarm/aient/test/test_token.py +94 -0
  67. beswarm/aient/test/test_url.py +33 -0
  68. beswarm/aient/test/test_whisper.py +14 -0
  69. beswarm/aient/test/test_wildcard.py +20 -0
  70. beswarm/aient/test/test_yjh.py +21 -0
  71. {beswarm-0.1.12.dist-info → beswarm-0.1.13.dist-info}/METADATA +1 -1
  72. beswarm-0.1.13.dist-info/RECORD +131 -0
  73. beswarm-0.1.12.dist-info/RECORD +0 -61
  74. {beswarm-0.1.12.dist-info → beswarm-0.1.13.dist-info}/WHEEL +0 -0
  75. {beswarm-0.1.12.dist-info → beswarm-0.1.13.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,715 @@
1
+ import re
2
+ import io
3
+ import os
4
+ import ast
5
+ import json
6
+ import httpx
7
+ import base64
8
+ import asyncio
9
+ import urllib.parse
10
+ from time import time
11
+ from PIL import Image
12
+ from fastapi import HTTPException
13
+ from urllib.parse import urlparse
14
+ from collections import defaultdict
15
+
16
+ from .log_config import logger
17
+
18
def get_model_dict(provider):
    """Build a name -> name lookup from a provider's ``'model'`` list.

    Each entry of ``provider['model']`` is either a plain string (mapped to
    itself) or a ``{old: new}`` dict, in which case every ``new`` key maps back
    to its ``old`` value (presumably ``new`` is an alias exposed to clients and
    ``old`` the upstream name — confirm against the config docs).

    Args:
        provider: Provider config dict containing a ``'model'`` list.

    Returns:
        dict: The combined lookup; later entries overwrite earlier ones.
    """
    model_dict = {}
    for model in provider['model']:
        if isinstance(model, str):
            model_dict[model] = model
        elif isinstance(model, dict):
            model_dict.update({new: old for old, new in model.items()})
    return model_dict
26
+
27
class BaseAPI:
    """Derive a family of OpenAI-style endpoint URLs from a single API URL.

    Whatever precedes ``chat/completions`` in the URL path is taken as the API
    root, and the sibling endpoints (``models``, ``embeddings``, ...) are built
    beneath it using only the scheme and host of ``api_url``. For the host
    ``generativelanguage.googleapis.com`` the ``base_url``, ``v1_url``,
    ``chat_url`` and ``embeddings`` attributes are set to ``api_url`` verbatim.
    """

    def __init__(
        self,
        api_url: str = "https://api.openai.com/v1/chat/completions",
    ):
        from urllib.parse import urlparse, urlunparse

        # An empty string selects the OpenAI default endpoint.
        if api_url == "":
            api_url = "https://api.openai.com/v1/chat/completions"
        self.source_api_url: str = api_url

        parsed_url = urlparse(self.source_api_url)
        if parsed_url.scheme == "":
            raise Exception("Error: API_URL is not set")

        # API root: everything in the path before "chat/completions",
        # normalised to end with "/". A bare "/" path means "no root".
        if parsed_url.path == '/':
            api_root = ""
        else:
            api_root = parsed_url.path.split("chat/completions")[0]
            if not api_root.endswith("/"):
                api_root += "/"

        scheme_and_host = tuple(parsed_url[:2])

        def build(path: str) -> str:
            # Keep scheme + host only; params/query/fragment are dropped.
            return urlunparse(scheme_and_host + (path, "", "", ""))

        self.base_url: str = build("")
        self.v1_url: str = build(api_root)
        self.v1_models: str = build(api_root + "models")
        self.chat_url: str = build(api_root + "chat/completions")
        self.image_url: str = build(api_root + "images/generations")
        self.audio_transcriptions: str = build(api_root + "audio/transcriptions")
        self.moderations: str = build(api_root + "moderations")
        self.embeddings: str = build(api_root + "embeddings")
        self.audio_speech: str = build(api_root + "audio/speech")

        if parsed_url.hostname == "generativelanguage.googleapis.com":
            self.base_url = api_url
            self.v1_url = api_url
            self.chat_url = api_url
            self.embeddings = api_url
61
+
62
def get_engine(provider, endpoint=None, original_model=""):
    """Decide which upstream API dialect a provider/model/endpoint combination uses.

    Resolution order (each step may override the previous ones):
      1. From ``provider['base_url']``: "gemini", "vertex", "azure",
         "cloudflare", "claude", "aws", "cohere" (which also sets
         ``stream=True``), otherwise "gpt".
      2. A non-empty ``original_model`` that contains none of the known family
         names (claude, gpt, deepseek, o1, o3, o4, gemini, learnlm, grok) on a
         host other than api.cloudflare.com / api.cohere.com -> "openrouter".
      3. "vertex" plus a claude/gemini model -> "vertex-claude"/"vertex-gemini".
      4. A truthy ``provider['engine']`` replaces the result of steps 1-3.
      5. Endpoint rules: images endpoint or a "stable-diffusion" model ->
         "dalle"; then transcriptions -> "whisper", moderations ->
         "moderation", embeddings -> "embedding", speech -> "tts". Every one
         of these except "embedding" forces ``stream=False``.

    Returns:
        tuple: ``(engine, stream)``; ``stream`` stays ``None`` unless a rule sets it.
    """
    parsed_url = urlparse(provider['base_url'])
    netloc = parsed_url.netloc
    host = netloc.rstrip('/')
    path = parsed_url.path
    stream = None

    is_native_gemini = netloc == 'generativelanguage.googleapis.com' and "openai/chat/completions" not in path
    is_vertex_gateway = host.endswith('gateway.ai.cloudflare.com') and "google-vertex-ai" in path

    if path.endswith(("/v1beta", "/v1")) or is_native_gemini:
        engine = "gemini"
    elif host.endswith('aiplatform.googleapis.com') or is_vertex_gateway:
        engine = "vertex"
    elif host.endswith(('openai.azure.com', 'services.ai.azure.com')):
        engine = "azure"
    elif netloc == 'api.cloudflare.com':
        engine = "cloudflare"
    elif netloc == 'api.anthropic.com' or path.endswith("v1/messages"):
        engine = "claude"
    elif 'amazonaws.com' in netloc:
        engine = "aws"
    elif netloc == 'api.cohere.com':
        engine = "cohere"
        stream = True
    else:
        engine = "gpt"

    model_name = original_model.lower()
    known_families = ("claude", "gpt", "deepseek", "o1", "o3", "o4", "gemini", "learnlm", "grok")
    is_unknown_family = bool(model_name) and not any(family in model_name for family in known_families)
    if is_unknown_family and netloc not in ('api.cloudflare.com', 'api.cohere.com'):
        engine = "openrouter"

    if engine == "vertex":
        if "claude" in model_name:
            engine = "vertex-claude"
        elif "gemini" in model_name:
            engine = "vertex-gemini"

    if provider.get("engine"):
        engine = provider["engine"]

    if endpoint == "/v1/images/generations" or "stable-diffusion" in model_name:
        engine, stream = "dalle", False

    if endpoint == "/v1/audio/transcriptions":
        engine, stream = "whisper", False
    elif endpoint == "/v1/moderations":
        engine, stream = "moderation", False
    elif endpoint == "/v1/embeddings":
        engine = "embedding"
    elif endpoint == "/v1/audio/speech":
        engine, stream = "tts", False

    return engine, stream
131
+
132
+ from httpx_socks import AsyncProxyTransport
133
def get_proxy(proxy, client_config=None):
    """Add proxy settings for an httpx client to ``client_config``.

    SOCKS5 proxies (``socks5://`` or ``socks5h://``) are installed as an
    ``httpx_socks.AsyncProxyTransport`` under the ``"transport"`` key (the
    ``h`` variant is rewritten to plain ``socks5://``). Any other scheme is
    passed through a ``"proxies"`` mapping for both http and https.

    Args:
        proxy: Proxy URL, or a falsy value for "no proxy".
        client_config: Dict of httpx client kwargs, updated in place. When
            omitted a new dict is created, so separate calls never share state.

    Returns:
        dict: ``client_config`` with the proxy settings added (unchanged when
        ``proxy`` is falsy).
    """
    # A fresh dict per call: the former ``client_config={}`` default was a
    # single shared object, so a proxy set by one call leaked into every later
    # call made without one.
    if client_config is None:
        client_config = {}
    if not proxy:
        return client_config

    scheme = urlparse(proxy).scheme.rstrip('h')  # "socks5h" -> "socks5"
    if scheme == 'socks5':
        proxy = proxy.replace('socks5h://', 'socks5://')
        client_config["transport"] = AsyncProxyTransport.from_url(proxy)
    else:
        # NOTE(review): httpx 0.28 removed the ``proxies`` argument in favour of
        # ``proxy``/``mounts``; confirm the pinned httpx version accepts this.
        client_config["proxies"] = {
            "http://": proxy,
            "https://": proxy,
        }
    return client_config
150
+
151
def update_initial_model(provider):
    """Fetch the model ids a provider exposes.

    For the "gemini" engine this calls ``GET <root>/v1beta/models?key=<api>``;
    for every other engine it calls ``GET <v1>/models`` with a Bearer token.
    When ``provider['api']`` is a list, only its first key is used. Any failure
    is reported (``logger.error`` for an upstream error payload, otherwise a
    printed traceback) and an empty list is returned.

    Args:
        provider: Provider config with ``'base_url'``, ``'api'`` and an
            optional ``preferences.proxy``.

    Returns:
        list: Unique model ids (order not guaranteed), or ``[]`` on failure.
    """
    try:
        engine, _ = get_engine(provider, endpoint=None, original_model="")
        api_url = provider['base_url']
        api = provider['api']
        # Use a single key in BOTH branches; the Gemini branch used to send the
        # whole list as repeated ``key`` query parameters.
        if isinstance(api, list):
            api = api[0]
        proxy = safe_get(provider, "preferences", "proxy", default=None)
        # Pass a fresh dict so proxy settings never carry over between calls.
        client_config = get_proxy(proxy, {})

        if engine == "gemini":
            url = api_url.split("/v1beta")[0] + "/v1beta/models"
            with httpx.Client(**client_config) as client:
                response = client.get(url, params={"key": api})

            original_models = response.json()
            if original_models.get("error"):
                raise Exception({"error": original_models.get("error"), "endpoint": url, "api": api})

            # Gemini returns names like "models/gemini-pro"; keep the bare id.
            models_list = [
                {"id": model["name"].split("models/")[-1]}
                for model in original_models.get("models", [])
            ]
        else:
            endpoint_models_url = BaseAPI(api_url=api_url).v1_models
            response = httpx.get(
                endpoint_models_url,
                headers={"Authorization": f"Bearer {api}"},
                **client_config,
            )
            models = response.json()
            if models.get("error"):
                logger.error({"error": models.get("error"), "endpoint": endpoint_models_url, "api": api})
                return []
            models_list = models["data"]

        return list({model["id"] for model in models_list})
    except Exception:
        import traceback
        traceback.print_exc()
        return []
205
+
206
def safe_get(data, *keys, default=None):
    """Follow ``keys`` through nested dicts/lists (or ``.get``-able objects).

    Dicts and lists are indexed with ``[]``; any other object is queried with
    ``.get(key)``. A KeyError, IndexError, AttributeError or TypeError at any
    step yields ``default``, and so does a falsy final value (``None``, ``0``,
    ``""``, empty containers).
    """
    current = data
    for key in keys:
        try:
            if isinstance(current, (dict, list)):
                current = current[key]
            else:
                current = current.get(key)
        except (KeyError, IndexError, AttributeError, TypeError):
            return default
    return current if current else default
215
+
216
def parse_rate_limit(limit_string):
    """Parse a rate-limit spec such as ``"10/min,1000/day"``.

    Each comma-separated part must look like ``<count>/<unit>``. Supported
    units: s/sec/second, m/min/minute, h/hr/hour, d/day, mo/month (30 days)
    and y/year (365 days).

    Returns:
        list[tuple[int, int]]: ``(count, period_in_seconds)`` for each part, in order.

    Raises:
        ValueError: For a malformed part or an unknown unit.
    """
    seconds_per_unit = {
        **dict.fromkeys(('s', 'sec', 'second'), 1),
        **dict.fromkeys(('m', 'min', 'minute'), 60),
        **dict.fromkeys(('h', 'hr', 'hour'), 3600),
        **dict.fromkeys(('d', 'day'), 86400),
        **dict.fromkeys(('mo', 'month'), 2592000),
        **dict.fromkeys(('y', 'year'), 31536000),
    }

    parsed = []
    for part in map(str.strip, limit_string.split(',')):
        match = re.match(r'^(\d+)/(\w+)$', part)
        if match is None:
            raise ValueError(f"Invalid rate limit format: {part}")
        count_text, unit = match.groups()
        period = seconds_per_unit.get(unit)
        if period is None:
            raise ValueError(f"Unknown time unit: {unit}")
        parsed.append((int(count_text), period))
    return parsed
247
+
248
class ThreadSafeCircularList:
    """Rotating pool of items (API keys, regions, ...) with rate limiting.

    ``next()`` hands items out in the configured order and skips items that
    are cooling down or have reached a rate limit. Request timestamps are
    tracked per item *and* per model. The public coroutines serialize on one
    ``asyncio.Lock``. That protects concurrent tasks on a single event loop;
    an asyncio lock is not an OS-thread lock, despite the class name.
    """

    def __init__(self, items=None, rate_limit=None, schedule_algorithm="round_robin"):
        """
        Args:
            items: Items to rotate through; ``None`` means an empty pool. The
                list is kept by reference for "round_robin"/"fixed_priority"
                and copied into a shuffled list for "random".
            rate_limit: Either a ``{model_or_"default": "N/unit,..."}`` dict or
                one spec string used as "default". ``None`` means
                ``{"default": "999999/min"}``.
            schedule_algorithm: "round_robin" (default), "random" (shuffled
                once here, then rotated) or "fixed_priority" (every ``next()``
                starts from the first item). Unknown values log a warning and
                fall back to "round_robin".
        """
        # ``None`` sentinels replace the former mutable defaults. With
        # ``items=[]``, every instance created without arguments (e.g. through
        # ``defaultdict(ThreadSafeCircularList)``) aliased the same list.
        if items is None:
            items = []
        if rate_limit is None:
            rate_limit = {"default": "999999/min"}

        if schedule_algorithm == "random":
            import random
            self.items = random.sample(items, len(items))
            self.schedule_algorithm = "random"
        elif schedule_algorithm in ("round_robin", "fixed_priority"):
            self.items = items
            self.schedule_algorithm = schedule_algorithm
        else:
            self.items = items
            logger.warning(f"Unknown schedule algorithm: {schedule_algorithm}, use (round_robin, random, fixed_priority) instead")
            self.schedule_algorithm = "round_robin"
        self.index = 0
        self.lock = asyncio.Lock()
        # Two-level map: requests[item][model] -> timestamps of recent requests.
        self.requests = defaultdict(lambda: defaultdict(list))
        # cooling_until[item] -> epoch seconds until which the item is skipped.
        self.cooling_until = defaultdict(float)
        # rate_limits[model_or_"default"] -> [(count, period_seconds), ...]
        self.rate_limits = {}
        if isinstance(rate_limit, dict):
            for rate_limit_model, rate_limit_value in rate_limit.items():
                self.rate_limits[rate_limit_model] = parse_rate_limit(rate_limit_value)
        elif isinstance(rate_limit, str):
            self.rate_limits["default"] = parse_rate_limit(rate_limit)
        else:
            logger.error(f"Error ThreadSafeCircularList: Unknown rate_limit type: {type(rate_limit)}, rate_limit: {rate_limit}")

    async def set_cooling(self, item: str, cooling_time: int = 60):
        """Put ``item`` into cooldown for ``cooling_time`` seconds (default 60).

        While cooling, ``is_rate_limited`` reports the item as limited, so
        ``next()`` skips it. The item's request history is left untouched.
        A ``None`` item is ignored.
        """
        if item is None:
            return
        now = time()
        async with self.lock:
            self.cooling_until[item] = now + cooling_time
        logger.warning(f"API key {item} 已进入冷却状态,冷却时间 {cooling_time} 秒")

    async def is_rate_limited(self, item, model: str = None, is_check: bool = False) -> bool:
        """Return True if ``item`` may not be used for ``model`` right now.

        An item is limited while it is cooling down, or when any applicable
        ``(count, period)`` limit already has ``count`` requests recorded for
        this item/model within the last ``period`` seconds. Limits are looked
        up by exact model name, then by the first non-default configured key
        that is a substring of ``model``, then by ``"default"``.

        When the item is not limited, timestamps older than the longest period
        are pruned and, unless ``is_check`` is set, the current time is
        recorded as a new request. ``is_check`` also suppresses the warning.
        Within this class it is only called while ``self.lock`` is held.
        """
        now = time()
        if now < self.cooling_until[item]:
            return True

        model_key = model if model else "default"

        rate_limit = None
        if model and model in self.rate_limits:
            rate_limit = self.rate_limits[model]
        else:
            # Fuzzy match: first non-default key contained in the model name.
            for limit_model in self.rate_limits:
                if limit_model != "default" and model and limit_model in model:
                    rate_limit = self.rate_limits[limit_model]
                    break
        if rate_limit is None:
            rate_limit = self.rate_limits.get("default", [(999999, 60)])

        history = self.requests[item][model_key]
        for limit_count, limit_period in rate_limit:
            recent_requests = sum(1 for req in history if req > now - limit_period)
            if recent_requests >= limit_count:
                if not is_check:
                    logger.warning(f"API key {item}: model: {model_key} has been rate limited ({limit_count}/{limit_period} seconds)")
                return True

        max_period = max(period for _, period in rate_limit)
        self.requests[item][model_key] = [req for req in history if req > now - max_period]

        if not is_check:
            self.requests[item][model_key].append(now)

        return False

    async def next(self, model: str = None):
        """Return the next item that is not rate limited for ``model``.

        Rotates from the current position ("fixed_priority" always restarts at
        the first item) and records the request against the returned item.

        Raises:
            HTTPException: 429 when every item is limited.
        """
        # NOTE(review): an empty ``items`` list raises IndexError here.
        async with self.lock:
            if self.schedule_algorithm == "fixed_priority":
                self.index = 0
            start_index = self.index
            while True:
                item = self.items[self.index]
                self.index = (self.index + 1) % len(self.items)

                if not await self.is_rate_limited(item, model):
                    return item

                # Wrapped around: every item has been tried and is limited.
                if self.index == start_index:
                    logger.warning("All API keys are rate limited!")
                    raise HTTPException(status_code=429, detail="Too many requests")

    async def is_all_rate_limited(self, model: str = None) -> bool:
        """Return True if every item is currently rate limited for ``model``.

        Unlike ``next()``, this neither moves ``self.index`` nor records a
        request. An empty pool returns False.
        """
        if len(self.items) == 0:
            return False

        async with self.lock:
            for item in self.items:
                if not await self.is_rate_limited(item, model, is_check=True):
                    return False
            return True

    async def after_next_current(self):
        """Return the item handed out by the most recent ``next()`` call.

        ``next()`` advances ``self.index`` past the item it returns, so this is
        the item just before the current index. Returns None for an empty pool.
        """
        if len(self.items) == 0:
            return None
        async with self.lock:
            item = self.items[(self.index - 1) % len(self.items)]
            return item

    def get_items_count(self) -> int:
        """Return the number of items in the pool."""
        return len(self.items)

    def to_dict(self):
        """Return a JSON-serializable snapshot of the pool's configuration."""
        return {
            "items": list(self.items),
            "schedule_algorithm": self.schedule_algorithm,
            "index": self.index,
        }
397
+
398
def circular_list_encoder(obj):
    """``json.dumps(default=...)`` hook that serializes ThreadSafeCircularList.

    Uses the object's ``to_dict()`` when it provides one; otherwise it returns
    a minimal snapshot of the items and scheduling state instead of failing
    with AttributeError.

    Raises:
        TypeError: For any other type, as ``json.dumps`` expects.
    """
    if isinstance(obj, ThreadSafeCircularList):
        to_dict = getattr(obj, "to_dict", None)
        if callable(to_dict):
            return to_dict()
        return {
            "items": list(obj.items),
            "schedule_algorithm": obj.schedule_algorithm,
            "index": obj.index,
        }
    raise TypeError(f'Object of type {obj.__class__.__name__} is not JSON serializable')
402
+
403
# Registry of ThreadSafeCircularList pools, created on first access with the
# class defaults (presumably keyed by provider name — confirm against callers).
provider_api_circular_list = defaultdict(ThreadSafeCircularList)

# GCP Vertex AI regions currently available for each model family, per:
#   https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/use-claude?hl=zh_cn
#   https://cloud.google.com/vertex-ai/generative-ai/docs/learn/locations?hl=zh-cn#available-regions
# The short names presumably stand for Claude 3.5 Sonnet (c35s), Claude 3
# Sonnet (c3s), Claude 3 Opus (c3o), Claude 3 Haiku (c3h) and Gemini
# generations 1 and 2 — verify against the request-building code.
c35s = ThreadSafeCircularList(items=["us-east5", "europe-west1"])
c3s = ThreadSafeCircularList(items=["us-east5", "us-central1", "asia-southeast1"])
c3o = ThreadSafeCircularList(items=["us-east5"])
c3h = ThreadSafeCircularList(items=["us-east5", "us-central1", "europe-west1", "europe-west4"])
gemini1 = ThreadSafeCircularList(items=["us-central1", "us-east4", "us-west1", "us-west4", "europe-west1", "europe-west2"])
gemini2 = ThreadSafeCircularList(items=["us-central1"])

# Terminator appended after each SSE "data: ..." line, so every event ends
# with a blank line. Alternatives tried earlier: "\n\r\n", "\r\n", "\n\r",
# "\r" and "\n".
end_of_line = "\n\n"
442
+
443
+ import random
444
+ import string
445
async def generate_sse_response(timestamp, model, content=None, tools_id=None, function_call_name=None, function_call_content=None, role=None, total_tokens=0, prompt_tokens=0, completion_tokens=0, reasoning_content=None):
    """Render one OpenAI-style ``chat.completion.chunk`` as an SSE ``data:`` line.

    The chunk id is derived deterministically from ``timestamp``. The delta is
    chosen by precedence: ``role`` > tool-call header (``tools_id`` together
    with ``function_call_name``) > tool-call argument fragment
    (``function_call_content``) > ``reasoning_content`` > ``content``. A chunk
    with none of these is the terminal chunk (empty delta, finish_reason
    "stop"). A truthy ``total_tokens`` turns the chunk into a usage chunk:
    ``usage`` is filled (total recomputed as prompt + completion) and
    ``choices`` is emptied.

    Returns:
        str: ``"data: <json>"`` followed by the module-level ``end_of_line``.
    """
    # A private RNG seeded with the timestamp gives the same id as before, but
    # no longer reseeds the process-wide ``random`` module, which made every
    # other user of ``random`` predictable.
    rng = random.Random(timestamp)
    random_str = ''.join(rng.choices(string.ascii_letters + string.digits, k=29))

    delta = {"role": "assistant", "content": content} if content else {}
    if reasoning_content:
        delta = {"role": "assistant", "content": "", "reasoning_content": reasoning_content}
    if function_call_content:
        delta = {"tool_calls": [{"index": 0, "function": {"arguments": function_call_content}}]}
    if tools_id and function_call_name:
        delta = {"tool_calls": [{"index": 0, "id": tools_id, "type": "function", "function": {"name": function_call_name, "arguments": ""}}]}
    if role:
        delta = {"role": role, "content": ""}

    # Only a chunk without any delta payload is the final "stop" chunk.
    # Reasoning, tool-call and role chunks are in-progress chunks like content
    # chunks, so they keep finish_reason null.
    finish_reason = None if delta else "stop"

    sample_data = {
        "id": f"chatcmpl-{random_str}",
        "object": "chat.completion.chunk",
        "created": timestamp,
        "model": model,
        "choices": [
            {
                "index": 0,
                "delta": delta,
                "logprobs": None,
                "finish_reason": finish_reason,
            }
        ],
        "usage": None,
        "system_fingerprint": "fp_d576307f90",
    }
    if total_tokens:
        sample_data["usage"] = {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        }
        sample_data["choices"] = []

    json_data = json.dumps(sample_data, ensure_ascii=False)
    return f"data: {json_data}" + end_of_line
486
+
487
async def generate_no_stream_response(timestamp, model, content=None, tools_id=None, function_call_name=None, function_call_content=None, role=None, total_tokens=0, prompt_tokens=0, completion_tokens=0):
    """Render a complete OpenAI-style ``chat.completion`` response as a JSON string.

    If ``function_call_name`` is set, the message is a single tool call whose
    ``arguments`` are ``json.dumps(function_call_content)``, and
    ``finish_reason`` is "tool_calls". ``tools_id`` defaults to
    ``"call_<random>"``. Otherwise the message carries ``role``/``content``
    and ``finish_reason`` is "stop". A truthy ``total_tokens`` adds ``usage``
    with the total recomputed as prompt + completion. The id is derived
    deterministically from ``timestamp``.
    """
    # Private RNG: same ids as seeding the global generator, without clobbering
    # the process-wide ``random`` state.
    rng = random.Random(timestamp)
    random_str = ''.join(rng.choices(string.ascii_letters + string.digits, k=29))

    if function_call_name:
        if not tools_id:
            tools_id = f"call_{random_str}"
        sample_data = {
            "id": f"chatcmpl-{random_str}",
            "object": "chat.completion",
            "created": timestamp,
            "model": model,
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": None,
                        "tool_calls": [
                            {
                                "id": tools_id,
                                "type": "function",
                                "function": {
                                    "name": function_call_name,
                                    "arguments": json.dumps(function_call_content, ensure_ascii=False),
                                },
                            }
                        ],
                        "refusal": None,
                    },
                    "logprobs": None,
                    "finish_reason": "tool_calls",
                }
            ],
            "usage": None,
            "service_tier": "default",
            "system_fingerprint": "fp_4691090a87",
        }
    else:
        sample_data = {
            "id": f"chatcmpl-{random_str}",
            "object": "chat.completion",
            "created": timestamp,
            "model": model,
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": role,
                        "content": content,
                        "refusal": None,
                    },
                    "logprobs": None,
                    "finish_reason": "stop",
                }
            ],
            "usage": None,
            "system_fingerprint": "fp_a7d06e42a7",
        }

    if total_tokens:
        sample_data["usage"] = {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        }

    return json.dumps(sample_data, ensure_ascii=False)
553
+
554
def get_image_format(file_content):
    """Return the lower-cased PIL format name of image bytes, or None.

    Args:
        file_content: Raw image bytes.

    Returns:
        str | None: e.g. ``"png"`` or ``"jpeg"``; ``None`` if PIL cannot
        identify the data.
    """
    try:
        with Image.open(io.BytesIO(file_content)) as img:
            return img.format.lower()
    except Exception:
        # Ordinary failures only (unidentified/corrupt data, missing format).
        # The former bare ``except:`` also swallowed KeyboardInterrupt/SystemExit.
        return None
560
+
561
def encode_image(image_path):
    """Read an image file and return it as a base64 ``data:`` URL.

    Only PNG and JPEG images are accepted.

    Raises:
        ValueError: If the format cannot be identified, or is neither PNG nor JPEG.
    """
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()

    detected = get_image_format(raw_bytes)
    if not detected:
        raise ValueError("无法识别的图片格式")
    encoded = base64.b64encode(raw_bytes).decode('utf-8')

    if detected == 'png':
        mime = "image/png"
    elif detected in ('jpg', 'jpeg'):
        mime = "image/jpeg"
    else:
        raise ValueError(f"不支持的图片格式: {detected}")
    return f"data:{mime};base64,{encoded}"
575
+
576
async def get_doc_from_url(url):
    """Download ``url`` into the current working directory.

    The local name is the percent-decoded last segment of the URL path, with
    the query string and fragment dropped and reduced to a bare file name.
    On ``httpx.RequestError`` the error is printed and the name is still
    returned, but no file is written.

    Returns:
        str: File name relative to the current working directory.
    """
    # ``url`` may be untrusted. Decoding happens before taking the basename so
    # that an encoded "..%2F" cannot write outside the working directory.
    last_segment = urllib.parse.urlparse(url).path.split("/")[-1]
    filename = os.path.basename(urllib.parse.unquote(last_segment).replace("\\", "/"))
    if filename in ("", ".", ".."):
        filename = "downloaded_file"
    # NOTE(review): concurrent downloads resolving to the same name overwrite
    # each other in the CWD; consider unique temp files.

    # NOTE(review): TLS verification is disabled for user-supplied URLs (also
    # an SSRF surface); confirm verify=False is really required.
    transport = httpx.AsyncHTTPTransport(
        http2=True,
        verify=False,
        retries=1,
    )
    async with httpx.AsyncClient(transport=transport) as client:
        try:
            response = await client.get(url, timeout=30.0)
            with open(filename, 'wb') as f:
                f.write(response.content)
        except httpx.RequestError as e:
            print(f"An error occurred while requesting {e.request.url!r}.")

    return filename
596
+
597
async def get_encode_image(image_url):
    """Download an image and return it as a base64 ``data:`` URL.

    The downloaded file is always deleted afterwards, including when
    ``encode_image`` raises (unreadable or unsupported image).
    """
    filename = await get_doc_from_url(image_url)
    image_path = os.path.join(os.getcwd(), filename)
    try:
        return encode_image(image_path)
    finally:
        # Best-effort cleanup; the file is absent when the download failed.
        try:
            os.remove(image_path)
        except FileNotFoundError:
            pass
603
+
604
+ # from PIL import Image
605
+ # import io
606
+ # def validate_image(image_data, image_type):
607
+ # try:
608
+ # decoded_image = base64.b64decode(image_data)
609
+ # image = Image.open(io.BytesIO(decoded_image))
610
+
611
+ # # 检查图片格式是否与声明的类型匹配
612
+ # # print("image.format", image.format)
613
+ # if image_type == "image/png" and image.format != "PNG":
614
+ # raise ValueError("Image is not a valid PNG")
615
+ # elif image_type == "image/jpeg" and image.format not in ["JPEG", "JPG"]:
616
+ # raise ValueError("Image is not a valid JPEG")
617
+
618
+ # # 如果没有异常,则图片有效
619
+ # return True
620
+ # except Exception as e:
621
+ # print(f"Image validation failed: {str(e)}")
622
+ # return False
623
+
624
async def get_image_message(base64_image, engine = None):
    """Build an engine-specific image content part.

    Args:
        base64_image: A ``data:<mime>;base64,<payload>`` URL, or a URL starting
            with ``http`` that is downloaded and encoded first.
        engine: Engine name (see ``get_engine``).

    Returns:
        dict: An ``image_url`` part for gpt/openrouter/azure, an ``image`` part
        with a base64 source for claude/vertex-claude/aws, or an ``inlineData``
        part for gemini/vertex-gemini. WebP input is re-encoded as PNG first.

    Raises:
        ValueError: "Unknown engine" for any other engine.
    """
    data_url = base64_image
    if data_url.startswith("http"):
        data_url = await get_encode_image(data_url)
    mime_type = data_url[data_url.index(":") + 1:data_url.index(";")]

    if mime_type == "image/webp":
        # Re-encode WebP as PNG.
        webp_bytes = base64.b64decode(data_url.split(",")[1])
        png_buffer = io.BytesIO()
        Image.open(io.BytesIO(webp_bytes)).save(png_buffer, format="PNG")
        data_url = "data:image/png;base64," + base64.b64encode(png_buffer.getvalue()).decode('utf-8')
        mime_type = "image/png"

    if engine in ("gpt", "openrouter", "azure"):
        return {"type": "image_url", "image_url": {"url": data_url}}

    if engine in ("claude", "vertex-claude", "aws"):
        return {
            "type": "image",
            "source": {
                "type": "base64",
                "media_type": mime_type,
                "data": data_url.split(",")[1],
            },
        }

    if engine in ("gemini", "vertex-gemini"):
        return {
            "inlineData": {
                "mimeType": mime_type,
                "data": data_url.split(",")[1],
            }
        }

    raise ValueError("Unknown engine")
675
+
676
async def get_text_message(message, engine = None):
    """Wrap a text string in the content-part shape ``engine`` expects.

    OpenAI/Anthropic-style engines get ``{"type": "text", "text": ...}``,
    Gemini engines get ``{"text": ...}``, and Cloudflare/Cohere get the raw
    string.

    Raises:
        ValueError: "Unknown engine" for any other engine.
    """
    if engine in ("gpt", "claude", "openrouter", "vertex-claude", "azure", "aws"):
        return {"type": "text", "text": message}
    if engine in ("gemini", "vertex-gemini"):
        return {"text": message}
    if engine in ("cloudflare", "cohere"):
        return message
    raise ValueError("Unknown engine")
686
+
687
def parse_json_safely(json_str):
    """Parse JSON text, falling back to Python-literal syntax.

    ``json.loads`` runs first (with ``strict=False``, so raw control
    characters such as newlines inside strings are accepted). Running it
    first means standard JSON escapes, including escaped slashes and UTF-16
    surrogate pairs, decode correctly; ``ast.literal_eval`` would keep the
    backslash or leave lone surrogates. If JSON parsing fails,
    ``ast.literal_eval`` handles Python-style literals, e.g. single-quoted
    keys produced by some models.

    Args:
        json_str: The text to parse.

    Returns:
        The parsed Python object.

    Raises:
        Exception: When neither parser accepts the input.
    """
    try:
        return json.loads(json_str, strict=False)
    except json.JSONDecodeError as json_error:
        try:
            # literal_eval only evaluates literals, never arbitrary code.
            return ast.literal_eval(json_str)
        except (SyntaxError, ValueError, TypeError) as literal_error:
            raise Exception(f"无法解析JSON字符串: {json_error}") from literal_error
710
+
711
if __name__ == "__main__":
    # Manual smoke check: a Cloudflare AI Gateway URL whose path proxies Vertex
    # AI ("google-vertex-ai") should be classified as the "vertex" engine.
    demo_provider = {
        "base_url": "https://gateway.ai.cloudflare.com/v1/%7Baccount_id%7D/%7Bgateway_id%7D/google-vertex-ai",
    }
    print(get_engine(demo_provider))