beswarm 0.1.12__py3-none-any.whl → 0.1.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76) hide show
  1. beswarm/aient/main.py +50 -0
  2. beswarm/aient/setup.py +15 -0
  3. beswarm/aient/src/aient/__init__.py +1 -0
  4. beswarm/aient/src/aient/core/__init__.py +1 -0
  5. beswarm/aient/src/aient/core/log_config.py +6 -0
  6. beswarm/aient/src/aient/core/models.py +232 -0
  7. beswarm/aient/src/aient/core/request.py +1665 -0
  8. beswarm/aient/src/aient/core/response.py +617 -0
  9. beswarm/aient/src/aient/core/test/test_base_api.py +18 -0
  10. beswarm/aient/src/aient/core/test/test_image.py +15 -0
  11. beswarm/aient/src/aient/core/test/test_payload.py +92 -0
  12. beswarm/aient/src/aient/core/utils.py +715 -0
  13. beswarm/aient/src/aient/models/__init__.py +9 -0
  14. beswarm/aient/src/aient/models/audio.py +63 -0
  15. beswarm/aient/src/aient/models/base.py +251 -0
  16. beswarm/aient/src/aient/models/chatgpt.py +941 -0
  17. beswarm/aient/src/aient/models/claude.py +640 -0
  18. beswarm/aient/src/aient/models/duckduckgo.py +241 -0
  19. beswarm/aient/src/aient/models/gemini.py +357 -0
  20. beswarm/aient/src/aient/models/groq.py +268 -0
  21. beswarm/aient/src/aient/models/vertex.py +420 -0
  22. beswarm/aient/src/aient/plugins/__init__.py +33 -0
  23. beswarm/aient/src/aient/plugins/arXiv.py +48 -0
  24. beswarm/aient/src/aient/plugins/config.py +172 -0
  25. beswarm/aient/src/aient/plugins/excute_command.py +35 -0
  26. beswarm/aient/src/aient/plugins/get_time.py +19 -0
  27. beswarm/aient/src/aient/plugins/image.py +72 -0
  28. beswarm/aient/src/aient/plugins/list_directory.py +50 -0
  29. beswarm/aient/src/aient/plugins/read_file.py +79 -0
  30. beswarm/aient/src/aient/plugins/registry.py +116 -0
  31. beswarm/aient/src/aient/plugins/run_python.py +156 -0
  32. beswarm/aient/src/aient/plugins/websearch.py +394 -0
  33. beswarm/aient/src/aient/plugins/write_file.py +51 -0
  34. beswarm/aient/src/aient/prompt/__init__.py +1 -0
  35. beswarm/aient/src/aient/prompt/agent.py +280 -0
  36. beswarm/aient/src/aient/utils/__init__.py +0 -0
  37. beswarm/aient/src/aient/utils/prompt.py +143 -0
  38. beswarm/aient/src/aient/utils/scripts.py +721 -0
  39. beswarm/aient/test/chatgpt.py +161 -0
  40. beswarm/aient/test/claude.py +32 -0
  41. beswarm/aient/test/test.py +2 -0
  42. beswarm/aient/test/test_API.py +6 -0
  43. beswarm/aient/test/test_Deepbricks.py +20 -0
  44. beswarm/aient/test/test_Web_crawler.py +262 -0
  45. beswarm/aient/test/test_aiwaves.py +25 -0
  46. beswarm/aient/test/test_aiwaves_arxiv.py +19 -0
  47. beswarm/aient/test/test_ask_gemini.py +8 -0
  48. beswarm/aient/test/test_class.py +17 -0
  49. beswarm/aient/test/test_claude.py +23 -0
  50. beswarm/aient/test/test_claude_zh_char.py +26 -0
  51. beswarm/aient/test/test_ddg_search.py +50 -0
  52. beswarm/aient/test/test_download_pdf.py +56 -0
  53. beswarm/aient/test/test_gemini.py +97 -0
  54. beswarm/aient/test/test_get_token_dict.py +21 -0
  55. beswarm/aient/test/test_google_search.py +35 -0
  56. beswarm/aient/test/test_jieba.py +32 -0
  57. beswarm/aient/test/test_json.py +65 -0
  58. beswarm/aient/test/test_langchain_search_old.py +235 -0
  59. beswarm/aient/test/test_logging.py +32 -0
  60. beswarm/aient/test/test_ollama.py +55 -0
  61. beswarm/aient/test/test_plugin.py +16 -0
  62. beswarm/aient/test/test_py_run.py +26 -0
  63. beswarm/aient/test/test_requests.py +162 -0
  64. beswarm/aient/test/test_search.py +18 -0
  65. beswarm/aient/test/test_tikitoken.py +19 -0
  66. beswarm/aient/test/test_token.py +94 -0
  67. beswarm/aient/test/test_url.py +33 -0
  68. beswarm/aient/test/test_whisper.py +14 -0
  69. beswarm/aient/test/test_wildcard.py +20 -0
  70. beswarm/aient/test/test_yjh.py +21 -0
  71. beswarm/tools/worker.py +3 -1
  72. {beswarm-0.1.12.dist-info → beswarm-0.1.14.dist-info}/METADATA +1 -1
  73. beswarm-0.1.14.dist-info/RECORD +131 -0
  74. beswarm-0.1.12.dist-info/RECORD +0 -61
  75. {beswarm-0.1.12.dist-info → beswarm-0.1.14.dist-info}/WHEEL +0 -0
  76. {beswarm-0.1.12.dist-info → beswarm-0.1.14.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1665 @@
1
+ import re
2
+ import json
3
+ import httpx
4
+ import base64
5
+ import urllib.parse
6
+
7
+ from .models import RequestModel
8
+ from .utils import (
9
+ c3s,
10
+ c3o,
11
+ c3h,
12
+ c35s,
13
+ gemini1,
14
+ gemini2,
15
+ BaseAPI,
16
+ safe_get,
17
+ get_engine,
18
+ get_model_dict,
19
+ get_text_message,
20
+ get_image_message,
21
+ )
22
+
23
async def get_gemini_payload(request, engine, provider, api_key=None):
    """Build the ``(url, headers, payload)`` triple for a Gemini streaming call.

    Converts an OpenAI-style chat request into the Gemini
    ``streamGenerateContent`` wire format: role mapping, tool-call /
    tool-result parts, safety settings, generation config, and the optional
    ``-think-N`` / ``-search`` suffixes encoded in the requested model name.

    Args:
        request: incoming chat request (pydantic model; uses ``.messages``,
            ``.model`` and ``model_dump(exclude_unset=True)``).
        engine: engine identifier forwarded to the text/image converters.
        provider: provider config dict (``base_url``, model mapping, flags).
        api_key: Gemini API key, appended to the URL query string.

    Returns:
        Tuple ``(url, headers, payload)`` ready for an HTTP POST.
    """
    headers = {
        'Content-Type': 'application/json'
    }

    # Resolve the provider-mapped model ID actually sent upstream.
    model_dict = get_model_dict(provider)
    original_model = model_dict[request.model]

    gemini_stream = "streamGenerateContent"
    url = provider['base_url']
    parsed_url = urllib.parse.urlparse(url)
    # API version decides how the system instruction is delivered (below).
    if "/v1beta" in parsed_url.path:
        api_version = "v1beta"
    else:
        api_version = "v1"

    url = f"{parsed_url.scheme}://{parsed_url.netloc}{parsed_url.path.split('/models')[0].rstrip('/')}/models/{original_model}:{gemini_stream}?key={api_key}"

    messages = []
    systemInstruction = None
    function_arguments = None
    for msg in request.messages:
        if msg.role == "assistant":
            msg.role = "model"
        tool_calls = None
        if isinstance(msg.content, list):
            content = []
            for item in msg.content:
                if item.type == "text":
                    text_message = await get_text_message(item.text, engine)
                    content.append(text_message)
                elif item.type == "image_url" and provider.get("image", True):
                    image_message = await get_image_message(item.image_url.url, engine)
                    content.append(image_message)
        else:
            content = [{"text": msg.content}]
            tool_calls = msg.tool_calls

        if tool_calls:
            # Gemini represents an assistant tool call as a functionCall part.
            # NOTE(review): only the first tool call is forwarded — confirm
            # multi-tool-call requests are not expected here.
            tool_call = tool_calls[0]
            function_arguments = {
                "functionCall": {
                    "name": tool_call.function.name,
                    "args": json.loads(tool_call.function.arguments)
                }
            }
            messages.append(
                {
                    "role": "model",
                    "parts": [function_arguments]
                }
            )
        elif msg.role == "tool":
            # Tool results answer the most recent functionCall recorded above.
            function_call_name = function_arguments["functionCall"]["name"]
            messages.append(
                {
                    "role": "function",
                    "parts": [{
                        "functionResponse": {
                            "name": function_call_name,
                            "response": {
                                "name": function_call_name,
                                "content": {
                                    "result": msg.content,
                                }
                            }
                        }
                    }]
                }
            )
        elif msg.role != "system":
            messages.append({"role": msg.role, "parts": content})
        elif msg.role == "system":
            # Collapse runs of underscores in the system text (reason not
            # documented here — presumably upstream validation; keep as-is).
            content[0]["text"] = re.sub(r"_+", "_", content[0]["text"])
            systemInstruction = {"parts": content}

    # Threshold value differs per model family ("OFF" vs "BLOCK_NONE").
    off_models = ["gemini-2.0-flash", "gemini-2.5-flash", "gemini-1.5", "gemini-2.5-pro"]
    if any(off_model in original_model for off_model in off_models):
        safety_settings = "OFF"
    else:
        safety_settings = "BLOCK_NONE"

    payload = {
        # Gemini rejects an empty contents list, so fall back to a stub turn.
        "contents": messages or [{"role": "user", "parts": [{"text": "No messages"}]}],
        "safetySettings": [
            {
                "category": "HARM_CATEGORY_HARASSMENT",
                "threshold": safety_settings
            },
            {
                "category": "HARM_CATEGORY_HATE_SPEECH",
                "threshold": safety_settings
            },
            {
                "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
                "threshold": safety_settings
            },
            {
                "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
                "threshold": safety_settings
            },
            {
                "category": "HARM_CATEGORY_CIVIC_INTEGRITY",
                "threshold": "BLOCK_NONE"
            },
        ]
    }

    if systemInstruction:
        if api_version == "v1beta":
            payload["systemInstruction"] = systemInstruction
        if api_version == "v1":
            # v1 endpoint: fold the system text into the first message instead.
            first_message = safe_get(payload, "contents", 0, "parts", 0, "text", default=None)
            system_instruction = safe_get(systemInstruction, "parts", 0, "text", default=None)
            if first_message and system_instruction:
                payload["contents"][0]["parts"][0]["text"] = system_instruction + "\n" + first_message

    # OpenAI-style fields that have no direct Gemini equivalent (or are
    # handled elsewhere) and must not be copied verbatim into the payload.
    miss_fields = [
        'model',
        'messages',
        'stream',
        'tool_choice',
        'presence_penalty',
        'frequency_penalty',
        'n',
        'user',
        'include_usage',
        'logprobs',
        'top_logprobs',
        'response_format',
        'stream_options',
    ]
    generation_config = {}

    for field, value in request.model_dump(exclude_unset=True).items():
        if field not in miss_fields and value is not None:
            if field == "tools" and "gemini-2.0-flash-thinking" in original_model:
                # The thinking model does not take tool declarations here.
                continue
            if field == "tools":
                # Rewrite each tool's function definition for Gemini.
                processed_tools = []
                for tool in value:
                    function_def = tool["function"]
                    # Gemini's schema has no "default" keyword: move default
                    # values into the property description instead.
                    if safe_get(function_def, "parameters", "properties", default=None):
                        for prop_value in function_def["parameters"]["properties"].values():
                            if "default" in prop_value:
                                default_value = prop_value["default"]
                                description = prop_value.get("description", "")
                                prop_value["description"] = f"{description}\nDefault: {default_value}"
                                del prop_value["default"]
                    # BUG FIX: the original compared the name against
                    # "googleSearch" twice ("... != 'googleSearch' and
                    # ... != 'googleSearch'"); a single check is equivalent.
                    if function_def["name"] != "googleSearch":
                        processed_tools.append({"function": function_def})

                if processed_tools:
                    payload.update({
                        "tools": [{
                            "function_declarations": [tool["function"] for tool in processed_tools]
                        }],
                        "tool_config": {
                            "function_calling_config": {
                                "mode": "AUTO"
                            }
                        }
                    })
            elif field == "temperature":
                generation_config["temperature"] = value
            elif field == "max_tokens":
                generation_config["maxOutputTokens"] = value
            elif field == "top_p":
                generation_config["topP"] = value
            else:
                payload[field] = value

    # Default output budget: 64K-class models get the larger ceiling.
    max_token_65k_models = ["gemini-2.5-pro", "gemini-2.0-pro", "gemini-2.0-flash-thinking", "gemini-2.5-flash"]
    payload["generationConfig"] = generation_config
    if "maxOutputTokens" not in generation_config:
        if any(pro_model in original_model for pro_model in max_token_65k_models):
            payload["generationConfig"]["maxOutputTokens"] = 65536
        else:
            payload["generationConfig"]["maxOutputTokens"] = 8192

    # Detect a thinking-budget suffix in the *requested* model name,
    # e.g. "gemini-2.5-flash-think-1024"; clamp to [0, 24576].
    m = re.match(r".*-think-(-?\d+)", request.model)
    if m:
        try:
            val = int(m.group(1))
            if val < 0:
                val = 0
            elif val > 24576:
                val = 24576
            payload["generationConfig"]["thinkingConfig"] = {"thinkingBudget": val}
        except ValueError:
            # Ignore the thinking budget if the captured group is not an int.
            pass

    # Detect the "-search" suffix: enable the Google Search tool.
    if request.model.endswith("-search"):
        if "tools" not in payload:
            payload["tools"] = [{"googleSearch": {}}]
        else:
            payload["tools"].append({"googleSearch": {}})

    return url, headers, payload
232
+
233
+ import time
234
+ from cryptography.hazmat.primitives import hashes
235
+ from cryptography.hazmat.primitives.asymmetric import padding
236
+ from cryptography.hazmat.primitives.serialization import load_pem_private_key
237
+
238
def create_jwt(client_email, private_key):
    """Build a signed RS256 JWT for the Google OAuth2 service-account flow.

    Args:
        client_email: service-account e-mail, used as the ``iss`` claim.
        private_key: PEM-encoded RSA private key as a string.

    Returns:
        The compact ``header.payload.signature`` JWT as a str.
    """
    # Header and claims are serialized exactly as the token endpoint expects.
    header_bytes = json.dumps({
        "alg": "RS256",
        "typ": "JWT"
    }).encode()

    issued_at = int(time.time())
    claims_bytes = json.dumps({
        "iss": client_email,
        "scope": "https://www.googleapis.com/auth/cloud-platform",
        "aud": "https://oauth2.googleapis.com/token",
        "exp": issued_at + 3600,
        "iat": issued_at
    }).encode()

    # Base64url-encode both parts, stripping the '=' padding per RFC 7515.
    token_parts = [
        base64.urlsafe_b64encode(header_bytes).rstrip(b'='),
        base64.urlsafe_b64encode(claims_bytes).rstrip(b'=')
    ]

    # Sign "<header>.<payload>" with RSASSA-PKCS1-v1_5 / SHA-256.
    signing_input = b'.'.join(token_parts)
    key_obj = load_pem_private_key(private_key.encode(), password=None)
    signature = key_obj.sign(
        signing_input,
        padding.PKCS1v15(),
        hashes.SHA256()
    )

    token_parts.append(base64.urlsafe_b64encode(signature).rstrip(b'='))
    return b'.'.join(token_parts).decode()
272
+
273
def get_access_token(client_email, private_key):
    """Exchange a freshly minted service-account JWT for an OAuth2 access token.

    Raises:
        httpx.HTTPStatusError: if the token endpoint responds with an error.
    """
    assertion = create_jwt(client_email, private_key)

    token_request = {
        "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
        "assertion": assertion
    }
    with httpx.Client() as http_client:
        resp = http_client.post(
            "https://oauth2.googleapis.com/token",
            data=token_request,
            headers={'Content-Type': "application/x-www-form-urlencoded"}
        )
        resp.raise_for_status()
        return resp.json()["access_token"]
287
+
288
async def get_vertex_gemini_payload(request, engine, provider, api_key=None):
    """Build (url, headers, payload) for a Vertex AI Gemini streaming call.

    Mirrors get_gemini_payload but targets the Vertex AI endpoint: auth is a
    Bearer token minted from the provider's service-account credentials, the
    URL embeds project/location, and generationConfig uses snake_case keys.
    """
    headers = {
        'Content-Type': 'application/json'
    }
    # Service-account auth: mint an OAuth2 access token from the configured
    # client_email / private_key pair.
    if provider.get("client_email") and provider.get("private_key"):
        access_token = get_access_token(provider['client_email'], provider['private_key'])
        headers['Authorization'] = f"Bearer {access_token}"
    # NOTE(review): project_id stays unbound if the provider config lacks
    # "project_id", and the URL .format() below would raise NameError —
    # confirm project_id is always configured for Vertex providers.
    if provider.get("project_id"):
        project_id = provider.get("project_id")

    gemini_stream = "streamGenerateContent"
    model_dict = get_model_dict(provider)
    original_model = model_dict[request.model]
    search_tool = None

    # Pro-class models use the gemini2 location rotation and the newer
    # googleSearch tool; everything else uses gemini1 + googleSearchRetrieval.
    pro_models = ["gemini-2.5-pro", "gemini-2.0-pro", "gemini-exp"]
    if any(pro_model in original_model for pro_model in pro_models):
        location = gemini2
        search_tool = {"googleSearch": {}}
    else:
        location = gemini1
        search_tool = {"googleSearchRetrieval": {}}

    # Custom proxy base URLs (containing "google-vertex-ai") keep their host;
    # otherwise build the regional aiplatform.googleapis.com endpoint.
    if "google-vertex-ai" in provider.get("base_url", ""):
        url = provider.get("base_url").rstrip('/') + "/v1/projects/{PROJECT_ID}/locations/{LOCATION}/publishers/google/models/{MODEL_ID}:{stream}".format(
            LOCATION=await location.next(),
            PROJECT_ID=project_id,
            MODEL_ID=original_model,
            stream=gemini_stream
        )
    else:
        url = "https://{LOCATION}-aiplatform.googleapis.com/v1/projects/{PROJECT_ID}/locations/{LOCATION}/publishers/google/models/{MODEL_ID}:{stream}".format(
            LOCATION=await location.next(),
            PROJECT_ID=project_id,
            MODEL_ID=original_model,
            stream=gemini_stream
        )

    # Convert OpenAI-style messages into Gemini "contents" parts.
    messages = []
    systemInstruction = None
    function_arguments = None
    for msg in request.messages:
        if msg.role == "assistant":
            msg.role = "model"
        tool_calls = None
        if isinstance(msg.content, list):
            content = []
            for item in msg.content:
                if item.type == "text":
                    text_message = await get_text_message(item.text, engine)
                    content.append(text_message)
                elif item.type == "image_url" and provider.get("image", True):
                    image_message = await get_image_message(item.image_url.url, engine)
                    content.append(image_message)
        else:
            content = [{"text": msg.content}]
            tool_calls = msg.tool_calls

        if tool_calls:
            # Only the first tool call is forwarded as a functionCall part.
            tool_call = tool_calls[0]
            function_arguments = {
                "functionCall": {
                    "name": tool_call.function.name,
                    "args": json.loads(tool_call.function.arguments)
                }
            }
            messages.append(
                {
                    "role": "model",
                    "parts": [function_arguments]
                }
            )
        elif msg.role == "tool":
            # Tool results reference the most recent functionCall seen above.
            function_call_name = function_arguments["functionCall"]["name"]
            messages.append(
                {
                    "role": "function",
                    "parts": [{
                        "functionResponse": {
                            "name": function_call_name,
                            "response": {
                                "name": function_call_name,
                                "content": {
                                    "result": msg.content,
                                }
                            }
                        }
                    }]
                }
            )
        elif msg.role != "system":
            messages.append({"role": msg.role, "parts": content})
        elif msg.role == "system":
            systemInstruction = {"parts": content}


    payload = {
        "contents": messages,
        # safetySettings intentionally omitted on the Vertex path (the
        # commented-out block mirrored the direct-API BLOCK_NONE settings).
        # "safetySettings": [
        #     {
        #         "category": "HARM_CATEGORY_HARASSMENT",
        #         "threshold": "BLOCK_NONE"
        #     },
        #     {
        #         "category": "HARM_CATEGORY_HATE_SPEECH",
        #         "threshold": "BLOCK_NONE"
        #     },
        #     {
        #         "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        #         "threshold": "BLOCK_NONE"
        #     },
        #     {
        #         "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
        #         "threshold": "BLOCK_NONE"
        #     }
        # ]
    }
    if systemInstruction:
        payload["system_instruction"] = systemInstruction

    # OpenAI-style fields with no direct Vertex Gemini equivalent.
    miss_fields = [
        'model',
        'messages',
        'stream',
        'tool_choice',
        'presence_penalty',
        'frequency_penalty',
        'n',
        'user',
        'include_usage',
        'logprobs',
        'top_logprobs',
        'stream_options',
    ]
    generation_config = {}

    for field, value in request.model_dump(exclude_unset=True).items():
        if field not in miss_fields and value is not None:
            if field == "tools":
                payload.update({
                    "tools": [{
                        "function_declarations": [tool["function"] for tool in value]
                    }],
                    "tool_config": {
                        "function_calling_config": {
                            "mode": "AUTO"
                        }
                    }
                })
            elif field == "temperature":
                generation_config["temperature"] = value
            elif field == "max_tokens":
                # Vertex uses snake_case keys here, unlike the direct API.
                generation_config["max_output_tokens"] = value
            elif field == "top_p":
                generation_config["top_p"] = value
            else:
                payload[field] = value

    # NOTE(review): generationConfig (and the 8192 default budget) is only
    # attached when at least one generation field was set — unlike the
    # direct-API path, which always sets it. Confirm this asymmetry is wanted.
    if generation_config:
        payload["generationConfig"] = generation_config
        if "max_output_tokens" not in generation_config:
            payload["generationConfig"]["max_output_tokens"] = 8192

    # "-search" model suffix: attach the family-appropriate search tool.
    if request.model.endswith("-search"):
        if "tools" not in payload:
            payload["tools"] = [search_tool]
        else:
            payload["tools"].append(search_tool)

    return url, headers, payload
458
+
459
async def get_vertex_claude_payload(request, engine, provider, api_key=None):
    """Build (url, headers, payload) for an Anthropic Claude call on Vertex AI.

    Translates OpenAI-style messages (including tool calls / tool results)
    into the Anthropic messages format, merges consecutive same-role turns,
    and signs the request with a service-account access token.
    """
    headers = {
        'Content-Type': 'application/json',
    }
    if provider.get("client_email") and provider.get("private_key"):
        access_token = get_access_token(provider['client_email'], provider['private_key'])
        headers['Authorization'] = f"Bearer {access_token}"
    # NOTE(review): project_id stays unbound if the provider config lacks
    # "project_id"; the URL .format() below would then raise NameError.
    if provider.get("project_id"):
        project_id = provider.get("project_id")

    model_dict = get_model_dict(provider)
    original_model = model_dict[request.model]
    # Pick the location rotator per Claude family.
    # NOTE(review): `location` stays unbound for models matching none of these
    # branches — a final else/default would make the failure mode explicit.
    if "claude-3-5-sonnet" in original_model or "claude-3-7-sonnet" in original_model:
        location = c35s
    elif "claude-3-opus" in original_model:
        location = c3o
    elif "claude-3-sonnet" in original_model:
        location = c3s
    elif "claude-3-haiku" in original_model:
        location = c3h

    claude_stream = "streamRawPredict"
    url = "https://{LOCATION}-aiplatform.googleapis.com/v1/projects/{PROJECT_ID}/locations/{LOCATION}/publishers/anthropic/models/{MODEL}:{stream}".format(
        LOCATION=await location.next(),
        PROJECT_ID=project_id,
        MODEL=original_model,
        stream=claude_stream
    )

    # Convert messages into the Anthropic format.
    messages = []
    system_prompt = None
    tool_id = None
    for msg in request.messages:
        tool_call_id = None
        tool_calls = None
        if isinstance(msg.content, list):
            content = []
            for item in msg.content:
                if item.type == "text":
                    text_message = await get_text_message(item.text, engine)
                    content.append(text_message)
                elif item.type == "image_url" and provider.get("image", True):
                    image_message = await get_image_message(item.image_url.url, engine)
                    content.append(image_message)
        else:
            content = msg.content
            tool_calls = msg.tool_calls
            # Remember the last seen tool-call id so a later tool-result
            # message can reference it.
            tool_id = tool_calls[0].id if tool_calls else None or tool_id
            tool_call_id = msg.tool_call_id

        if tool_calls:
            # Assistant tool call -> Anthropic "tool_use" content block
            # (only the first tool call is forwarded).
            tool_calls_list = []
            tool_call = tool_calls[0]
            tool_calls_list.append({
                "type": "tool_use",
                "id": tool_call.id,
                "name": tool_call.function.name,
                "input": json.loads(tool_call.function.arguments),
            })
            messages.append({"role": msg.role, "content": tool_calls_list})
        elif tool_call_id:
            # Tool result -> user-role "tool_result" block tied to tool_id.
            messages.append({"role": "user", "content": [{
                "type": "tool_result",
                "tool_use_id": tool_id,
                "content": content
            }]})
        elif msg.role == "function":
            # Legacy "function" role: synthesize a matching tool_use /
            # tool_result pair with a fixed placeholder id.
            messages.append({"role": "assistant", "content": [{
                "type": "tool_use",
                "id": "toolu_017r5miPMV6PGSNKmhvHPic4",
                "name": msg.name,
                "input": {"prompt": "..."}
            }]})
            messages.append({"role": "user", "content": [{
                "type": "tool_result",
                "tool_use_id": "toolu_017r5miPMV6PGSNKmhvHPic4",
                "content": msg.content
            }]})
        elif msg.role != "system":
            messages.append({"role": msg.role, "content": content})
        elif msg.role == "system":
            system_prompt = content

    # Anthropic requires strictly alternating roles: merge any run of
    # consecutive same-role messages into a single message.
    conversation_len = len(messages) - 1
    message_index = 0
    while message_index < conversation_len:
        if messages[message_index]["role"] == messages[message_index + 1]["role"]:
            if messages[message_index].get("content"):
                if isinstance(messages[message_index]["content"], list):
                    messages[message_index]["content"].extend(messages[message_index + 1]["content"])
                elif isinstance(messages[message_index]["content"], str) and isinstance(messages[message_index + 1]["content"], list):
                    content_list = [{"type": "text", "text": messages[message_index]["content"]}]
                    content_list.extend(messages[message_index + 1]["content"])
                    messages[message_index]["content"] = content_list
                else:
                    messages[message_index]["content"] += messages[message_index + 1]["content"]
            messages.pop(message_index + 1)
            conversation_len = conversation_len - 1
        else:
            message_index = message_index + 1

    # Per-family default output budget (overridden by request.max_tokens).
    if "claude-3-7-sonnet" in original_model:
        max_tokens = 20000
    elif "claude-3-5-sonnet" in original_model:
        max_tokens = 8192
    else:
        max_tokens = 4096

    payload = {
        "anthropic_version": "vertex-2023-10-16",
        "messages": messages,
        "system": system_prompt or "You are Claude, a large language model trained by Anthropic.",
        "max_tokens": max_tokens,
    }

    if request.max_tokens:
        payload["max_tokens"] = int(request.max_tokens)

    # OpenAI-style fields not to be copied verbatim into the payload.
    miss_fields = [
        'model',
        'messages',
        'presence_penalty',
        'frequency_penalty',
        'n',
        'user',
        'include_usage',
        'stream_options',
    ]

    for field, value in request.model_dump(exclude_unset=True).items():
        if field not in miss_fields and value is not None:
            payload[field] = value

    if request.tools and provider.get("tools"):
        tools = []
        for tool in request.tools:
            # gpt2claude_tools_json is defined elsewhere in this module —
            # converts an OpenAI function schema to Anthropic's tool schema.
            json_tool = await gpt2claude_tools_json(tool.dict()["function"])
            tools.append(json_tool)
        payload["tools"] = tools
        # Map OpenAI tool_choice forms onto Anthropic's.
        # NOTE(review): "none" maps to {"type": "any"}, which *forces* a tool
        # call in the Anthropic API — looks inverted; confirm intent.
        if "tool_choice" in payload:
            if isinstance(payload["tool_choice"], dict):
                if payload["tool_choice"]["type"] == "function":
                    payload["tool_choice"] = {
                        "type": "tool",
                        "name": payload["tool_choice"]["function"]["name"]
                    }
            if isinstance(payload["tool_choice"], str):
                if payload["tool_choice"] == "auto":
                    payload["tool_choice"] = {
                        "type": "auto"
                    }
                if payload["tool_choice"] == "none":
                    payload["tool_choice"] = {
                        "type": "any"
                    }

    # Providers may disable tools entirely.
    if provider.get("tools") == False:
        payload.pop("tools", None)
        payload.pop("tool_choice", None)

    return url, headers, payload
620
+
621
+ import hashlib
622
+ import hmac
623
+ import datetime
624
+ import urllib.parse
625
+ from datetime import timezone
626
+
627
def sign(key, msg):
    """One HMAC-SHA256 step of the AWS SigV4 key-derivation chain.

    Args:
        key: raw key bytes.
        msg: message string; encoded as UTF-8 before hashing.

    Returns:
        The 32-byte HMAC digest.
    """
    mac = hmac.new(key, msg.encode('utf-8'), hashlib.sha256)
    return mac.digest()
629
+
630
def get_signature_key(key, date_stamp, region_name, service_name):
    """Derive the AWS Signature V4 signing key.

    Chains HMAC-SHA256 over date -> region -> service -> "aws4_request",
    seeded with "AWS4" + secret key, per the SigV4 specification.
    """
    derived = sign(('AWS4' + key).encode('utf-8'), date_stamp)
    for component in (region_name, service_name, 'aws4_request'):
        derived = sign(derived, component)
    return derived
636
+
637
def get_signature(request_body, model_id, aws_access_key, aws_secret_key, aws_region, host, content_type, accept_header):
    """Compute AWS SigV4 headers for a Bedrock invoke-with-response-stream POST.

    Follows the four SigV4 steps (canonical request, string-to-sign, signing
    key, authorization header). Byte-for-byte header ordering matters:
    canonical headers and the signed-headers list must be alphabetical.

    Returns:
        (amz_date, payload_hash, authorization_header) — the values the
        caller must send as X-Amz-Date, X-Amz-Content-Sha256 and
        Authorization respectively.
    """
    # The signature is computed over the exact serialized body bytes.
    request_body = json.dumps(request_body)
    SERVICE = "bedrock"
    canonical_querystring = ''
    method = 'POST'
    raw_path = f'/model/{model_id}/invoke-with-response-stream'
    # Percent-encode the path; ':' in model IDs must be escaped.
    canonical_uri = urllib.parse.quote(raw_path, safe='/-_.~')
    # Create a date for headers and the credential string (UTC).
    t = datetime.datetime.now(timezone.utc)
    amz_date = t.strftime('%Y%m%dT%H%M%SZ')
    date_stamp = t.strftime('%Y%m%d') # Date YYYYMMDD

    # --- Task 1: Create a Canonical Request ---
    payload_hash = hashlib.sha256(request_body.encode('utf-8')).hexdigest()

    canonical_headers = f'accept:{accept_header}\n' \
                        f'content-type:{content_type}\n' \
                        f'host:{host}\n' \
                        f'x-amz-bedrock-accept:{accept_header}\n' \
                        f'x-amz-content-sha256:{payload_hash}\n' \
                        f'x-amz-date:{amz_date}\n'
    # Note: header names must appear in alphabetical order.

    signed_headers = 'accept;content-type;host;x-amz-bedrock-accept;x-amz-content-sha256;x-amz-date' # alphabetical order

    canonical_request = f'{method}\n' \
                        f'{canonical_uri}\n' \
                        f'{canonical_querystring}\n' \
                        f'{canonical_headers}\n' \
                        f'{signed_headers}\n' \
                        f'{payload_hash}'

    # --- Task 2: Create the String to Sign ---
    algorithm = 'AWS4-HMAC-SHA256'
    credential_scope = f'{date_stamp}/{aws_region}/{SERVICE}/aws4_request'
    string_to_sign = f'{algorithm}\n' \
                     f'{amz_date}\n' \
                     f'{credential_scope}\n' \
                     f'{hashlib.sha256(canonical_request.encode("utf-8")).hexdigest()}'

    # --- Task 3: Calculate the Signature ---
    signing_key = get_signature_key(aws_secret_key, date_stamp, aws_region, SERVICE)
    signature = hmac.new(signing_key, string_to_sign.encode('utf-8'), hashlib.sha256).hexdigest()

    # --- Task 4: Add Signing Information to the Request ---
    authorization_header = f'{algorithm} Credential={aws_access_key}/{credential_scope}, SignedHeaders={signed_headers}, Signature={signature}'
    return amz_date, payload_hash, authorization_header
684
+
685
async def get_aws_payload(request, engine, provider, api_key=None):
    """Build (url, headers, payload) for a Claude call on AWS Bedrock.

    Converts OpenAI-style messages to the Anthropic Bedrock format, merges
    consecutive same-role turns, and signs the request with SigV4 headers
    when the provider supplies AWS credentials.
    """
    CONTENT_TYPE = "application/json"
    # AWS_REGION = "us-east-1"
    model_dict = get_model_dict(provider)
    original_model = model_dict[request.model]
    # MODEL_ID = "anthropic.claude-3-5-sonnet-20240620-v1:0"
    base_url = provider.get('base_url')
    # Region is parsed out of the Bedrock base URL,
    # e.g. "bedrock-runtime.us-east-1.amazonaws.com" -> "us-east-1".
    AWS_REGION = base_url.split('.')[1]
    HOST = f"bedrock-runtime.{AWS_REGION}.amazonaws.com"
    # url = f"{base_url}/model/{original_model}/invoke"
    url = f"{base_url}/model/{original_model}/invoke-with-response-stream"

    # Convert messages into the Anthropic format (same scheme as the
    # Vertex Claude path above).
    messages = []
    system_prompt = None
    tool_id = None
    for msg in request.messages:
        tool_call_id = None
        tool_calls = None
        if isinstance(msg.content, list):
            content = []
            for item in msg.content:
                if item.type == "text":
                    text_message = await get_text_message(item.text, engine)
                    content.append(text_message)
                elif item.type == "image_url" and provider.get("image", True):
                    image_message = await get_image_message(item.image_url.url, engine)
                    content.append(image_message)
        else:
            content = msg.content
            tool_calls = msg.tool_calls
            # Remember the last seen tool-call id for later tool results.
            tool_id = tool_calls[0].id if tool_calls else None or tool_id
            tool_call_id = msg.tool_call_id

        if tool_calls:
            # Assistant tool call -> "tool_use" block (first call only).
            tool_calls_list = []
            tool_call = tool_calls[0]
            tool_calls_list.append({
                "type": "tool_use",
                "id": tool_call.id,
                "name": tool_call.function.name,
                "input": json.loads(tool_call.function.arguments),
            })
            messages.append({"role": msg.role, "content": tool_calls_list})
        elif tool_call_id:
            messages.append({"role": "user", "content": [{
                "type": "tool_result",
                "tool_use_id": tool_id,
                "content": content
            }]})
        elif msg.role == "function":
            # Legacy "function" role: synthesize a tool_use / tool_result
            # pair with a fixed placeholder id.
            messages.append({"role": "assistant", "content": [{
                "type": "tool_use",
                "id": "toolu_017r5miPMV6PGSNKmhvHPic4",
                "name": msg.name,
                "input": {"prompt": "..."}
            }]})
            messages.append({"role": "user", "content": [{
                "type": "tool_result",
                "tool_use_id": "toolu_017r5miPMV6PGSNKmhvHPic4",
                "content": msg.content
            }]})
        elif msg.role != "system":
            messages.append({"role": msg.role, "content": content})
        elif msg.role == "system":
            system_prompt = content

    # Merge runs of consecutive same-role messages (Anthropic requires
    # strictly alternating roles).
    conversation_len = len(messages) - 1
    message_index = 0
    while message_index < conversation_len:
        if messages[message_index]["role"] == messages[message_index + 1]["role"]:
            if messages[message_index].get("content"):
                if isinstance(messages[message_index]["content"], list):
                    messages[message_index]["content"].extend(messages[message_index + 1]["content"])
                elif isinstance(messages[message_index]["content"], str) and isinstance(messages[message_index + 1]["content"], list):
                    content_list = [{"type": "text", "text": messages[message_index]["content"]}]
                    content_list.extend(messages[message_index + 1]["content"])
                    messages[message_index]["content"] = content_list
                else:
                    messages[message_index]["content"] += messages[message_index + 1]["content"]
            messages.pop(message_index + 1)
            conversation_len = conversation_len - 1
        else:
            message_index = message_index + 1

    # Flat default budget here (the per-family tiers used on the Vertex
    # path are commented out upstream); overridden by request.max_tokens.
    max_tokens = 4096

    payload = {
        "messages": messages,
        "anthropic_version": "bedrock-2023-05-31",
        "max_tokens": max_tokens,
    }
    # NOTE(review): unlike the Vertex path, system_prompt is collected but
    # never placed into the payload ("system" key is absent) — confirm
    # whether Bedrock requests should carry it.

    if request.max_tokens:
        payload["max_tokens"] = int(request.max_tokens)

    # OpenAI-style fields not to be copied verbatim (note 'stream' is
    # excluded here because the streaming mode is baked into the URL).
    miss_fields = [
        'model',
        'messages',
        'presence_penalty',
        'frequency_penalty',
        'n',
        'user',
        'include_usage',
        'stream_options',
        'stream',
    ]

    for field, value in request.model_dump(exclude_unset=True).items():
        if field not in miss_fields and value is not None:
            payload[field] = value

    if request.tools and provider.get("tools"):
        tools = []
        for tool in request.tools:
            # gpt2claude_tools_json (defined elsewhere in this module)
            # converts an OpenAI function schema to Anthropic's tool schema.
            json_tool = await gpt2claude_tools_json(tool.dict()["function"])
            tools.append(json_tool)
        payload["tools"] = tools
        # Map OpenAI tool_choice forms onto Anthropic's.
        # NOTE(review): "none" maps to {"type": "any"} (forces a tool call
        # in the Anthropic API) — looks inverted; confirm intent.
        if "tool_choice" in payload:
            if isinstance(payload["tool_choice"], dict):
                if payload["tool_choice"]["type"] == "function":
                    payload["tool_choice"] = {
                        "type": "tool",
                        "name": payload["tool_choice"]["function"]["name"]
                    }
            if isinstance(payload["tool_choice"], str):
                if payload["tool_choice"] == "auto":
                    payload["tool_choice"] = {
                        "type": "auto"
                    }
                if payload["tool_choice"] == "none":
                    payload["tool_choice"] = {
                        "type": "any"
                    }

    if provider.get("tools") == False:
        payload.pop("tools", None)
        payload.pop("tool_choice", None)

    # NOTE(review): `headers` is only assigned inside this branch; if the
    # provider lacks aws_access_key/aws_secret_key the return below raises
    # UnboundLocalError — confirm credentials are always present.
    if provider.get("aws_access_key") and provider.get("aws_secret_key"):
        ACCEPT_HEADER = "application/vnd.amazon.bedrock.payload+json" # accept the Bedrock stream format
        amz_date, payload_hash, authorization_header = get_signature(payload, original_model, provider.get("aws_access_key"), provider.get("aws_secret_key"), AWS_REGION, HOST, CONTENT_TYPE, ACCEPT_HEADER)
        headers = {
            'Accept': ACCEPT_HEADER,
            'Content-Type': CONTENT_TYPE,
            'X-Amz-Date': amz_date,
            'X-Amz-Bedrock-Accept': ACCEPT_HEADER, # Bedrock-specific header
            'X-Amz-Content-Sha256': payload_hash,
            'Authorization': authorization_header,
            # Add 'X-Amz-Security-Token': SESSION_TOKEN if using temporary credentials
        }

    return url, headers, payload
867
+
868
async def get_gpt_payload(request, engine, provider, api_key=None):
    """Build the URL, headers and JSON payload for an OpenAI-compatible chat request.

    Args:
        request: Incoming RequestModel-like object describing the chat call.
        engine: Engine name, forwarded to the message converters.
        provider: Provider configuration dict (base_url, model map, feature flags).
        api_key: Optional bearer token for the Authorization header.

    Returns:
        Tuple of ``(url, headers, payload)`` ready to send to the provider.
    """
    headers = {
        'Content-Type': 'application/json',
    }
    model_dict = get_model_dict(provider)
    original_model = model_dict[request.model]
    if api_key:
        headers['Authorization'] = f"Bearer {api_key}"

    url = provider['base_url']

    messages = []
    for msg in request.messages:
        tool_calls = None
        tool_call_id = None
        if isinstance(msg.content, list):
            content = []
            for item in msg.content:
                if item.type == "text":
                    text_message = await get_text_message(item.text, engine)
                    content.append(text_message)
                # o1-mini has no vision support, so image parts are dropped for it.
                elif item.type == "image_url" and provider.get("image", True) and "o1-mini" not in original_model:
                    image_message = await get_image_message(item.image_url.url, engine)
                    content.append(image_message)
        else:
            content = msg.content
            # o3-mini only emits markdown when the system prompt opts in.
            # isinstance guard fixes a crash when content is None (e.g. a
            # tool-call-only message) or otherwise not a plain string.
            if msg.role == "system" and "o3-mini" in original_model and isinstance(content, str) and not content.startswith("Formatting re-enabled"):
                content = "Formatting re-enabled. " + content
            tool_calls = msg.tool_calls
            tool_call_id = msg.tool_call_id

        if tool_calls:
            tool_calls_list = []
            for tool_call in tool_calls:
                tool_calls_list.append({
                    "id": tool_call.id,
                    "type": tool_call.type,
                    "function": {
                        "name": tool_call.function.name,
                        "arguments": tool_call.function.arguments
                    }
                })
            if provider.get("tools"):
                messages.append({"role": msg.role, "tool_calls": tool_calls_list})
        elif tool_call_id:
            if provider.get("tools"):
                messages.append({"role": msg.role, "tool_call_id": tool_call_id, "content": content})
        else:
            messages.append({"role": msg.role, "content": content})

    # o1-mini / o1-preview reject system messages: fold the system prompt
    # into the first remaining message.
    if ("o1-mini" in original_model or "o1-preview" in original_model) and len(messages) > 1 and messages[0]["role"] == "system":
        system_msg = messages.pop(0)
        messages[0]["content"] = system_msg["content"] + messages[0]["content"]

    payload = {
        "model": original_model,
        "messages": messages,
    }

    miss_fields = [
        'model',
        'messages',
    ]

    for field, value in request.model_dump(exclude_unset=True).items():
        if field not in miss_fields and value is not None:
            # Reasoning models take max_completion_tokens instead of max_tokens.
            if field == "max_tokens" and ("o1" in original_model or "o3" in original_model or "o4" in original_model):
                payload["max_completion_tokens"] = value
            else:
                payload[field] = value

    if provider.get("tools") == False or "o1-mini" in original_model or "chatgpt-4o-latest" in original_model or "grok" in original_model:
        payload.pop("tools", None)
        payload.pop("tool_choice", None)

    if "api.x.ai" in url:
        payload.pop("stream_options", None)

    # Model aliases ending in "high"/"low" select the reasoning effort.
    if "grok-3-mini" in original_model:
        if request.model.endswith("high"):
            payload["reasoning_effort"] = "high"
        elif request.model.endswith("low"):
            payload["reasoning_effort"] = "low"

    if "o1" in original_model or "o3" in original_model or "o4" in original_model:
        if request.model.endswith("high"):
            payload["reasoning_effort"] = "high"
        elif request.model.endswith("low"):
            payload["reasoning_effort"] = "low"
        else:
            payload["reasoning_effort"] = "medium"

        # Reasoning models do not accept a temperature parameter.
        if "temperature" in payload:
            payload.pop("temperature")

    # Recommended temperatures (upstream guidance):
    # coding / math: 0.0, data extraction / analysis: 1.0,
    # general conversation: 1.3, translation: 1.3, creative writing: 1.5
    if "deepseek-r" in original_model.lower():
        if "temperature" not in payload:
            payload["temperature"] = 0.6

    # "-search" model aliases enable Gemini's googleSearch tool.
    if request.model.endswith("-search") and "gemini" in original_model:
        if "tools" not in payload:
            payload["tools"] = [{
                "type": "function",
                "function": {
                    "name": "googleSearch",
                    "description": "googleSearch"
                }
            }]
        else:
            if not any(tool["function"]["name"] == "googleSearch" for tool in payload["tools"]):
                payload["tools"].append({
                    "type": "function",
                    "function": {
                        "name": "googleSearch",
                        "description": "googleSearch"
                    }
                })

    return url, headers, payload
996
+
997
def build_azure_endpoint(base_url, deployment_id, api_version="2024-10-21"):
    """Construct the full Azure OpenAI chat-completions endpoint URL.

    The deployment path and ``api-version`` query parameter are appended
    only when they are not already present in *base_url*.
    """
    endpoint = base_url.rstrip('/')

    if "models/chat/completions" not in endpoint:
        deployment_path = f"/openai/deployments/{deployment_id}/chat/completions"
        # urljoin replaces any existing path on the host with deployment_path.
        endpoint = urllib.parse.urljoin(endpoint, deployment_path)

    if "?api-version=" not in endpoint:
        endpoint = f"{endpoint}?api-version={api_version}"

    return endpoint
1013
+
1014
async def get_azure_payload(request, engine, provider, api_key=None):
    """Build the URL, headers and JSON payload for an Azure OpenAI chat request."""
    headers = {
        'Content-Type': 'application/json',
    }
    original_model = get_model_dict(provider)[request.model]
    headers['api-key'] = f"{api_key}"

    url = build_azure_endpoint(
        base_url=provider['base_url'],
        deployment_id=original_model,
    )

    messages = []
    for msg in request.messages:
        tool_calls = None
        tool_call_id = None
        if isinstance(msg.content, list):
            content = []
            for part in msg.content:
                if part.type == "text":
                    content.append(await get_text_message(part.text, engine))
                elif part.type == "image_url" and provider.get("image", True) and "o1-mini" not in original_model:
                    content.append(await get_image_message(part.image_url.url, engine))
        else:
            content = msg.content
            tool_calls = msg.tool_calls
            tool_call_id = msg.tool_call_id

        if tool_calls:
            converted_calls = [
                {
                    "id": call.id,
                    "type": call.type,
                    "function": {
                        "name": call.function.name,
                        "arguments": call.function.arguments,
                    },
                }
                for call in tool_calls
            ]
            if provider.get("tools"):
                messages.append({"role": msg.role, "tool_calls": converted_calls})
        elif tool_call_id:
            if provider.get("tools"):
                messages.append({"role": msg.role, "tool_call_id": tool_call_id, "content": content})
        else:
            messages.append({"role": msg.role, "content": content})

    payload = {
        "model": original_model,
        "messages": messages,
    }

    excluded_fields = ('model', 'messages')
    for field, value in request.model_dump(exclude_unset=True).items():
        if field in excluded_fields or value is None:
            continue
        if field == "max_tokens" and "o1" in original_model:
            # o1 models take max_completion_tokens instead of max_tokens.
            payload["max_completion_tokens"] = value
        else:
            payload[field] = value

    if provider.get("tools") == False or "o1" in original_model or "chatgpt-4o-latest" in original_model or "grok" in original_model:
        payload.pop("tools", None)
        payload.pop("tool_choice", None)

    return url, headers, payload
1086
+
1087
async def get_openrouter_payload(request, engine, provider, api_key=None):
    """Build the URL, headers and JSON payload for an OpenRouter chat request."""
    headers = {
        'Content-Type': 'application/json'
    }
    if api_key:
        headers['Authorization'] = f"Bearer {api_key}"
    original_model = get_model_dict(provider)[request.model]
    url = provider['base_url']

    messages = []
    for msg in request.messages:
        name = None
        if isinstance(msg.content, list):
            content = []
            for part in msg.content:
                if part.type == "text":
                    content.append(await get_text_message(part.text, engine))
                elif part.type == "image_url" and provider.get("image", True):
                    content.append(await get_image_message(part.image_url.url, engine))
        else:
            content = msg.content
            name = msg.name

        if name:
            messages.append({"role": msg.role, "name": name, "content": content})
        elif isinstance(content, list):
            # Split multimodal content into one message per part.
            for part in content:
                if part["type"] == "text":
                    messages.append({"role": msg.role, "content": part["text"]})
                elif part["type"] == "image_url":
                    messages.append({"role": msg.role, "content": [await get_image_message(part["image_url"]["url"], engine)]})
        else:
            messages.append({"role": msg.role, "content": content})

    payload = {
        "model": original_model,
        "messages": messages,
    }

    excluded_fields = (
        'model',
        'messages',
        'n',
        'user',
        'include_usage',
        'stream_options',
    )
    for field, value in request.model_dump(exclude_unset=True).items():
        if field not in excluded_fields and value is not None:
            payload[field] = value

    return url, headers, payload
1145
+
1146
async def get_cohere_payload(request, engine, provider, api_key=None):
    """Build the URL, headers and JSON payload for a Cohere chat request."""
    headers = {
        'Content-Type': 'application/json'
    }
    if api_key:
        headers['Authorization'] = f"Bearer {api_key}"
    original_model = get_model_dict(provider)[request.model]
    url = provider['base_url']

    # Cohere uses its own role vocabulary.
    role_map = {
        "user": "USER",
        "assistant": "CHATBOT",
        "system": "SYSTEM"
    }

    messages = []
    for msg in request.messages:
        if isinstance(msg.content, list):
            content = []
            for part in msg.content:
                if part.type == "text":
                    content.append(await get_text_message(part.text, engine))
        else:
            content = msg.content

        if isinstance(content, list):
            for part in content:
                if part["type"] == "text":
                    messages.append({"role": role_map[msg.role], "message": part["text"]})
        else:
            messages.append({"role": role_map[msg.role], "message": content})

    # The final message is the query; everything before it is chat history.
    chat_history = messages[:-1]
    payload = {
        "model": original_model,
        "message": messages[-1].get("message"),
    }
    if chat_history:
        payload["chat_history"] = chat_history

    excluded_fields = (
        'model', 'messages', 'tools', 'tool_choice', 'temperature', 'top_p',
        'max_tokens', 'presence_penalty', 'frequency_penalty', 'n', 'user',
        'include_usage', 'logprobs', 'top_logprobs', 'stream_options',
    )
    for field, value in request.model_dump(exclude_unset=True).items():
        if field not in excluded_fields and value is not None:
            payload[field] = value

    return url, headers, payload
1214
+
1215
async def get_cloudflare_payload(request, engine, provider, api_key=None):
    """Build the URL, headers and JSON payload for a Cloudflare Workers AI request."""
    headers = {
        'Content-Type': 'application/json'
    }
    if api_key:
        headers['Authorization'] = f"Bearer {api_key}"
    original_model = get_model_dict(provider)[request.model]
    url = "https://api.cloudflare.com/client/v4/accounts/{cf_account_id}/ai/run/{cf_model_id}".format(cf_account_id=provider['cf_account_id'], cf_model_id=original_model)

    # Cloudflare's run endpoint takes a single prompt: only the last message is sent.
    last_msg = request.messages[-1]
    content = None
    if isinstance(last_msg.content, list):
        for part in last_msg.content:
            if part.type == "text":
                content = await get_text_message(part.text, engine)
    else:
        content = last_msg.content

    payload = {
        "prompt": content,
    }

    excluded_fields = (
        'model', 'messages', 'tools', 'tool_choice', 'temperature', 'top_p',
        'max_tokens', 'presence_penalty', 'frequency_penalty', 'n', 'user',
        'include_usage', 'logprobs', 'top_logprobs', 'stream_options',
    )
    for field, value in request.model_dump(exclude_unset=True).items():
        if field not in excluded_fields and value is not None:
            payload[field] = value

    return url, headers, payload
1262
+
1263
async def gpt2claude_tools_json(json_dict):
    """Convert an OpenAI function-tool schema into Anthropic's tool format.

    Inlines local JSON-schema references and renames ``parameters`` to
    ``input_schema``. The input dict is deep-copied and never mutated.

    Fix: schema definitions are emitted by pydantic under the ``"$defs"``
    key, but the original code only extracted a ``"defs"`` key, so
    ``#/$defs/...`` references were never resolved. Both keys are now
    accepted (``"$defs"`` preferred, ``"defs"`` for backward compatibility).
    """
    import copy
    json_dict = copy.deepcopy(json_dict)

    def resolve_refs(obj, defs):
        """Recursively replace ``{"$ref": "#/$defs/Name"}`` nodes with their definition."""
        if isinstance(obj, dict):
            if "$ref" in obj and obj["$ref"].startswith("#/$defs/"):
                ref_name = obj["$ref"].split("/")[-1]
                if ref_name in defs:
                    ref_obj = copy.deepcopy(defs[ref_name])
                    # Keep any sibling keys that were set alongside the $ref.
                    for k, v in obj.items():
                        if k != "$ref":
                            ref_obj[k] = v
                    return ref_obj

            # Recurse into all values (list() guards against in-place updates).
            for key, value in list(obj.items()):
                obj[key] = resolve_refs(value, defs)

        elif isinstance(obj, list):
            for i, item in enumerate(obj):
                obj[i] = resolve_refs(item, defs)

        return obj

    # Extract the definitions and remove them from the schema — Claude does
    # not understand "$defs".
    defs = {}
    parameters = json_dict.get("parameters")
    if isinstance(parameters, dict):
        for defs_key in ("$defs", "defs"):
            if defs_key in parameters:
                defs = parameters.pop(defs_key)
                break

    # Inline every reference.
    json_dict = resolve_refs(json_dict, defs)

    # Anthropic expects the JSON schema under "input_schema".
    keys_to_change = {
        "parameters": "input_schema",
    }
    for old_key, new_key in keys_to_change.items():
        if old_key in json_dict:
            if new_key:
                if json_dict[old_key] is None:
                    # Tools without parameters still need an empty object schema.
                    json_dict[old_key] = {
                        "type": "object",
                        "properties": {}
                    }
                json_dict[new_key] = json_dict.pop(old_key)
            else:
                json_dict.pop(old_key)
    return json_dict
1319
+
1320
async def get_claude_payload(request, engine, provider, api_key=None):
    """Build the URL, headers and JSON payload for an Anthropic Messages request.

    Converts OpenAI-style messages (including tool calls / tool results) to
    Anthropic content blocks, merges consecutive same-role messages, and
    applies model-family defaults for max_tokens and extended thinking.

    Returns:
        Tuple of ``(url, headers, payload)``.
    """
    model_dict = get_model_dict(provider)
    original_model = model_dict[request.model]

    # Beta header that unlocks the model family's larger output limit.
    if "claude-3-7-sonnet" in original_model:
        anthropic_beta = "output-128k-2025-02-19"
    elif "claude-3-5-sonnet" in original_model:
        anthropic_beta = "max-tokens-3-5-sonnet-2024-07-15"
    else:
        anthropic_beta = "tools-2024-05-16"

    headers = {
        "content-type": "application/json",
        "x-api-key": f"{api_key}",
        "anthropic-version": "2023-06-01",
        "anthropic-beta": anthropic_beta,
    }
    url = provider['base_url']

    messages = []
    system_prompt = None
    tool_id = None
    for msg in request.messages:
        tool_call_id = None
        tool_calls = None
        if isinstance(msg.content, list):
            content = []
            for item in msg.content:
                if item.type == "text":
                    text_message = await get_text_message(item.text, engine)
                    content.append(text_message)
                elif item.type == "image_url" and provider.get("image", True):
                    image_message = await get_image_message(item.image_url.url, engine)
                    content.append(image_message)
        else:
            content = msg.content
            tool_calls = msg.tool_calls
            # Remember the most recent tool-call id so the following
            # tool-result message can reference it. (Replaces the original
            # precedence-trap expression
            # `tool_calls[0].id if tool_calls else None or tool_id`,
            # which parsed as `... else (None or tool_id)` — same behavior.)
            if tool_calls:
                tool_id = tool_calls[0].id
            tool_call_id = msg.tool_call_id

        if tool_calls:
            # Claude expects tool calls as "tool_use" content blocks.
            tool_call = tool_calls[0]
            tool_calls_list = [{
                "type": "tool_use",
                "id": tool_call.id,
                "name": tool_call.function.name,
                "input": json.loads(tool_call.function.arguments),
            }]
            messages.append({"role": msg.role, "content": tool_calls_list})
        elif tool_call_id:
            messages.append({"role": "user", "content": [{
                "type": "tool_result",
                "tool_use_id": tool_id,
                "content": content
            }]})
        elif msg.role == "function":
            # Legacy "function" role: fabricate a matching tool_use /
            # tool_result pair with a fixed placeholder id.
            messages.append({"role": "assistant", "content": [{
                "type": "tool_use",
                "id": "toolu_017r5miPMV6PGSNKmhvHPic4",
                "name": msg.name,
                "input": {"prompt": "..."}
            }]})
            messages.append({"role": "user", "content": [{
                "type": "tool_result",
                "tool_use_id": "toolu_017r5miPMV6PGSNKmhvHPic4",
                "content": msg.content
            }]})
        elif msg.role != "system":
            messages.append({"role": msg.role, "content": content})
        elif msg.role == "system":
            system_prompt = content

    # Anthropic rejects consecutive messages with the same role: merge them.
    conversation_len = len(messages) - 1
    message_index = 0
    while message_index < conversation_len:
        if messages[message_index]["role"] == messages[message_index + 1]["role"]:
            if messages[message_index].get("content"):
                if isinstance(messages[message_index]["content"], list):
                    messages[message_index]["content"].extend(messages[message_index + 1]["content"])
                elif isinstance(messages[message_index]["content"], str) and isinstance(messages[message_index + 1]["content"], list):
                    content_list = [{"type": "text", "text": messages[message_index]["content"]}]
                    content_list.extend(messages[message_index + 1]["content"])
                    messages[message_index]["content"] = content_list
                else:
                    messages[message_index]["content"] += messages[message_index + 1]["content"]
            messages.pop(message_index + 1)
            conversation_len = conversation_len - 1
        else:
            message_index = message_index + 1

    # Default output budget per model family (overridden by request.max_tokens).
    if "claude-3-7-sonnet" in original_model:
        max_tokens = 20000
    elif "claude-3-5-sonnet" in original_model:
        max_tokens = 8192
    else:
        max_tokens = 4096

    payload = {
        "model": original_model,
        "messages": messages,
        "system": system_prompt or "You are Claude, a large language model trained by Anthropic.",
        "max_tokens": max_tokens,
    }

    if request.max_tokens:
        payload["max_tokens"] = int(request.max_tokens)

    # OpenAI-only fields that must not be forwarded to Anthropic.
    miss_fields = [
        'model',
        'messages',
        'presence_penalty',
        'frequency_penalty',
        'n',
        'user',
        'include_usage',
        'stream_options',
    ]

    for field, value in request.model_dump(exclude_unset=True).items():
        if field not in miss_fields and value is not None:
            payload[field] = value

    if request.tools and provider.get("tools"):
        tools = []
        for tool in request.tools:
            json_tool = await gpt2claude_tools_json(tool.dict()["function"])
            tools.append(json_tool)
        payload["tools"] = tools
        if "tool_choice" in payload:
            # Translate OpenAI tool_choice values into Anthropic's format.
            if isinstance(payload["tool_choice"], dict):
                if payload["tool_choice"]["type"] == "function":
                    payload["tool_choice"] = {
                        "type": "tool",
                        "name": payload["tool_choice"]["function"]["name"]
                    }
            if isinstance(payload["tool_choice"], str):
                if payload["tool_choice"] == "auto":
                    payload["tool_choice"] = {
                        "type": "auto"
                    }
                if payload["tool_choice"] == "none":
                    # NOTE(review): "none" is mapped to Anthropic's "any",
                    # which *forces* tool use rather than disabling it —
                    # preserved as-is, but worth confirming upstream intent.
                    payload["tool_choice"] = {
                        "type": "any"
                    }

    if provider.get("tools") == False:
        payload.pop("tools", None)
        payload.pop("tool_choice", None)

    # "think"-suffixed aliases enable extended thinking; a trailing number
    # (e.g. "...-think-2048") sets the thinking budget when below max_tokens.
    if "think" in request.model:
        payload["thinking"] = {
            "budget_tokens": 4096,
            "type": "enabled"
        }
        # Extended thinking requires temperature 1 and no top_p/top_k.
        payload["temperature"] = 1
        payload.pop("top_p", None)
        payload.pop("top_k", None)
        if request.model.split("-")[-1].isdigit():
            think_tokens = int(request.model.split("-")[-1])
            if think_tokens < max_tokens:
                payload["thinking"] = {
                    "budget_tokens": think_tokens,
                    "type": "enabled"
                }

    # An explicit thinking config on the request takes precedence.
    if request.thinking:
        payload["thinking"] = {
            "budget_tokens": request.thinking.budget_tokens,
            "type": request.thinking.type
        }
        payload["temperature"] = 1
        payload.pop("top_p", None)
        payload.pop("top_k", None)

    return url, headers, payload
1498
+
1499
async def get_dalle_payload(request, engine, provider, api_key=None):
    """Build the URL, headers and JSON payload for an image-generation request."""
    original_model = get_model_dict(provider)[request.model]
    headers = {
        "Content-Type": "application/json",
    }
    if api_key:
        headers['Authorization'] = f"Bearer {api_key}"
    url = BaseAPI(provider['base_url']).image_url

    payload = {
        "model": original_model,
        "prompt": request.prompt,
        "n": request.n,
        "response_format": request.response_format,
        "size": request.size
    }

    return url, headers, payload
1519
+
1520
async def get_whisper_payload(request, engine, provider, api_key=None):
    """Build the URL, headers and form payload for an audio-transcription request.

    Fix: ``temperature`` is now tested with ``is not None`` so an explicit
    ``temperature=0`` (fully deterministic decoding) is still forwarded
    instead of being silently dropped by the falsy check.
    """
    model_dict = get_model_dict(provider)
    original_model = model_dict[request.model]
    # No Content-Type here: the HTTP client sets the multipart boundary itself.
    headers = {}
    if api_key:
        headers['Authorization'] = f"Bearer {api_key}"
    url = provider['base_url']
    url = BaseAPI(url).audio_transcriptions

    payload = {
        "model": original_model,
        "file": request.file,
    }

    if request.prompt:
        payload["prompt"] = request.prompt
    if request.response_format:
        payload["response_format"] = request.response_format
    if request.temperature is not None:
        payload["temperature"] = request.temperature
    if request.language:
        payload["language"] = request.language

    # https://platform.openai.com/docs/api-reference/audio/createTranscription
    if request.timestamp_granularities:
        payload["timestamp_granularities[]"] = request.timestamp_granularities

    return url, headers, payload
1550
+
1551
async def get_moderation_payload(request, engine, provider, api_key=None):
    """Build the URL, headers and JSON payload for a moderation request."""
    original_model = get_model_dict(provider)[request.model]
    headers = {
        "Content-Type": "application/json",
    }
    if api_key:
        headers['Authorization'] = f"Bearer {api_key}"
    url = BaseAPI(provider['base_url']).moderations

    payload = {
        "model": original_model,
        "input": request.input,
    }

    return url, headers, payload
1568
+
1569
async def get_embedding_payload(request, engine, provider, api_key=None):
    """Build the URL, headers and JSON payload for an embeddings request."""
    original_model = get_model_dict(provider)[request.model]
    headers = {
        "Content-Type": "application/json",
    }
    if api_key:
        headers['Authorization'] = f"Bearer {api_key}"
    url = BaseAPI(provider['base_url']).embeddings

    payload = {
        "input": request.input,
        "model": original_model,
    }

    if request.encoding_format:
        # Jina's API calls this field "embedding_type"; OpenAI uses "encoding_format".
        if url.startswith("https://api.jina.ai"):
            payload["embedding_type"] = request.encoding_format
        else:
            payload["encoding_format"] = request.encoding_format

    return url, headers, payload
1592
+
1593
async def get_tts_payload(request, engine, provider, api_key=None):
    """Build the URL, headers and JSON payload for a text-to-speech request."""
    original_model = get_model_dict(provider)[request.model]
    headers = {
        "Content-Type": "application/json",
    }
    if api_key:
        headers['Authorization'] = f"Bearer {api_key}"
    url = BaseAPI(provider['base_url']).audio_speech

    payload = {
        "model": original_model,
        "input": request.input,
        "voice": request.voice,
    }

    if request.response_format:
        payload["response_format"] = request.response_format
    if request.speed:
        payload["speed"] = request.speed
    # stream may legitimately be False, so compare against None explicitly.
    if request.stream is not None:
        payload["stream"] = request.stream

    return url, headers, payload
1618
+
1619
+
1620
async def get_payload(request: RequestModel, engine, provider, api_key=None):
    """Dispatch to the engine-specific payload builder.

    Args:
        request: The validated request model.
        engine: Engine identifier selected for the provider/model.
        provider: Provider configuration dict (mutated for the "gpt" engine:
            its base_url is resolved to the chat-completions endpoint).
        api_key: Credential forwarded to the builder.

    Returns:
        ``(url, headers, payload)`` from the selected builder.

    Raises:
        ValueError: If *engine* has no registered builder.
    """
    # The generic OpenAI-compatible engine needs its chat endpoint resolved first.
    if engine == "gpt":
        provider['base_url'] = BaseAPI(provider['base_url']).chat_url

    payload_builders = {
        "gemini": get_gemini_payload,
        "vertex-gemini": get_vertex_gemini_payload,
        "aws": get_aws_payload,
        "vertex-claude": get_vertex_claude_payload,
        "azure": get_azure_payload,
        "claude": get_claude_payload,
        "gpt": get_gpt_payload,
        "openrouter": get_openrouter_payload,
        "cloudflare": get_cloudflare_payload,
        "cohere": get_cohere_payload,
        "dalle": get_dalle_payload,
        "whisper": get_whisper_payload,
        "tts": get_tts_payload,
        "moderation": get_moderation_payload,
        "embedding": get_embedding_payload,
    }

    builder = payload_builders.get(engine)
    if builder is None:
        # Name the offending engine so misconfiguration is easy to diagnose.
        raise ValueError(f"Unknown payload engine: {engine}")
    return await builder(request, engine, provider, api_key)
1654
+
1655
async def prepare_request_payload(provider, request_data):
    """Parse raw request data and build the provider-specific payload.

    Validates *request_data* into a RequestModel, resolves the engine for the
    requested model, and delegates to :func:`get_payload`.

    Returns:
        ``(url, headers, payload, engine)``.
    """
    parsed_request = RequestModel(**request_data)
    upstream_model = get_model_dict(provider)[parsed_request.model]
    engine, _ = get_engine(provider, endpoint=None, original_model=upstream_model)

    url, headers, payload = await get_payload(parsed_request, engine, provider, api_key=provider['api'])

    return url, headers, payload, engine