beswarm 0.1.12__py3-none-any.whl → 0.1.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- beswarm/aient/main.py +50 -0
- beswarm/aient/setup.py +15 -0
- beswarm/aient/src/aient/__init__.py +1 -0
- beswarm/aient/src/aient/core/__init__.py +1 -0
- beswarm/aient/src/aient/core/log_config.py +6 -0
- beswarm/aient/src/aient/core/models.py +232 -0
- beswarm/aient/src/aient/core/request.py +1665 -0
- beswarm/aient/src/aient/core/response.py +617 -0
- beswarm/aient/src/aient/core/test/test_base_api.py +18 -0
- beswarm/aient/src/aient/core/test/test_image.py +15 -0
- beswarm/aient/src/aient/core/test/test_payload.py +92 -0
- beswarm/aient/src/aient/core/utils.py +715 -0
- beswarm/aient/src/aient/models/__init__.py +9 -0
- beswarm/aient/src/aient/models/audio.py +63 -0
- beswarm/aient/src/aient/models/base.py +251 -0
- beswarm/aient/src/aient/models/chatgpt.py +941 -0
- beswarm/aient/src/aient/models/claude.py +640 -0
- beswarm/aient/src/aient/models/duckduckgo.py +241 -0
- beswarm/aient/src/aient/models/gemini.py +357 -0
- beswarm/aient/src/aient/models/groq.py +268 -0
- beswarm/aient/src/aient/models/vertex.py +420 -0
- beswarm/aient/src/aient/plugins/__init__.py +33 -0
- beswarm/aient/src/aient/plugins/arXiv.py +48 -0
- beswarm/aient/src/aient/plugins/config.py +172 -0
- beswarm/aient/src/aient/plugins/excute_command.py +35 -0
- beswarm/aient/src/aient/plugins/get_time.py +19 -0
- beswarm/aient/src/aient/plugins/image.py +72 -0
- beswarm/aient/src/aient/plugins/list_directory.py +50 -0
- beswarm/aient/src/aient/plugins/read_file.py +79 -0
- beswarm/aient/src/aient/plugins/registry.py +116 -0
- beswarm/aient/src/aient/plugins/run_python.py +156 -0
- beswarm/aient/src/aient/plugins/websearch.py +394 -0
- beswarm/aient/src/aient/plugins/write_file.py +51 -0
- beswarm/aient/src/aient/prompt/__init__.py +1 -0
- beswarm/aient/src/aient/prompt/agent.py +280 -0
- beswarm/aient/src/aient/utils/__init__.py +0 -0
- beswarm/aient/src/aient/utils/prompt.py +143 -0
- beswarm/aient/src/aient/utils/scripts.py +721 -0
- beswarm/aient/test/chatgpt.py +161 -0
- beswarm/aient/test/claude.py +32 -0
- beswarm/aient/test/test.py +2 -0
- beswarm/aient/test/test_API.py +6 -0
- beswarm/aient/test/test_Deepbricks.py +20 -0
- beswarm/aient/test/test_Web_crawler.py +262 -0
- beswarm/aient/test/test_aiwaves.py +25 -0
- beswarm/aient/test/test_aiwaves_arxiv.py +19 -0
- beswarm/aient/test/test_ask_gemini.py +8 -0
- beswarm/aient/test/test_class.py +17 -0
- beswarm/aient/test/test_claude.py +23 -0
- beswarm/aient/test/test_claude_zh_char.py +26 -0
- beswarm/aient/test/test_ddg_search.py +50 -0
- beswarm/aient/test/test_download_pdf.py +56 -0
- beswarm/aient/test/test_gemini.py +97 -0
- beswarm/aient/test/test_get_token_dict.py +21 -0
- beswarm/aient/test/test_google_search.py +35 -0
- beswarm/aient/test/test_jieba.py +32 -0
- beswarm/aient/test/test_json.py +65 -0
- beswarm/aient/test/test_langchain_search_old.py +235 -0
- beswarm/aient/test/test_logging.py +32 -0
- beswarm/aient/test/test_ollama.py +55 -0
- beswarm/aient/test/test_plugin.py +16 -0
- beswarm/aient/test/test_py_run.py +26 -0
- beswarm/aient/test/test_requests.py +162 -0
- beswarm/aient/test/test_search.py +18 -0
- beswarm/aient/test/test_tikitoken.py +19 -0
- beswarm/aient/test/test_token.py +94 -0
- beswarm/aient/test/test_url.py +33 -0
- beswarm/aient/test/test_whisper.py +14 -0
- beswarm/aient/test/test_wildcard.py +20 -0
- beswarm/aient/test/test_yjh.py +21 -0
- beswarm/tools/worker.py +3 -1
- {beswarm-0.1.12.dist-info → beswarm-0.1.14.dist-info}/METADATA +1 -1
- beswarm-0.1.14.dist-info/RECORD +131 -0
- beswarm-0.1.12.dist-info/RECORD +0 -61
- {beswarm-0.1.12.dist-info → beswarm-0.1.14.dist-info}/WHEEL +0 -0
- {beswarm-0.1.12.dist-info → beswarm-0.1.14.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1665 @@
|
|
1
|
+
import re
|
2
|
+
import json
|
3
|
+
import httpx
|
4
|
+
import base64
|
5
|
+
import urllib.parse
|
6
|
+
|
7
|
+
from .models import RequestModel
|
8
|
+
from .utils import (
|
9
|
+
c3s,
|
10
|
+
c3o,
|
11
|
+
c3h,
|
12
|
+
c35s,
|
13
|
+
gemini1,
|
14
|
+
gemini2,
|
15
|
+
BaseAPI,
|
16
|
+
safe_get,
|
17
|
+
get_engine,
|
18
|
+
get_model_dict,
|
19
|
+
get_text_message,
|
20
|
+
get_image_message,
|
21
|
+
)
|
22
|
+
|
23
|
+
async def get_gemini_payload(request, engine, provider, api_key=None):
    """Translate an OpenAI-style chat request into a Gemini streamGenerateContent call.

    Returns a (url, headers, payload) triple ready for an HTTP POST against the
    provider's Gemini endpoint (always the streaming variant).

    request  -- parsed chat-completion request (pydantic model: .messages, .model, ...).
    engine   -- engine identifier forwarded to the message converters.
    provider -- provider config dict (base_url, model mapping, feature flags).
    api_key  -- Gemini API key, appended to the URL query string.
    """
    headers = {
        'Content-Type': 'application/json'
    }

    # Resolve the client-facing model alias to the provider's real model id.
    model_dict = get_model_dict(provider)
    original_model = model_dict[request.model]

    gemini_stream = "streamGenerateContent"
    url = provider['base_url']
    parsed_url = urllib.parse.urlparse(url)
    # The API version decides where the system prompt can go (see below).
    if "/v1beta" in parsed_url.path:
        api_version = "v1beta"
    else:
        api_version = "v1"

    url = f"{parsed_url.scheme}://{parsed_url.netloc}{parsed_url.path.split('/models')[0].rstrip('/')}/models/{original_model}:{gemini_stream}?key={api_key}"

    # Convert OpenAI-style messages into Gemini "contents".
    messages = []
    systemInstruction = None
    function_arguments = None
    for msg in request.messages:
        if msg.role == "assistant":
            msg.role = "model"
        tool_calls = None
        if isinstance(msg.content, list):
            content = []
            for item in msg.content:
                if item.type == "text":
                    text_message = await get_text_message(item.text, engine)
                    content.append(text_message)
                elif item.type == "image_url" and provider.get("image", True):
                    image_message = await get_image_message(item.image_url.url, engine)
                    content.append(image_message)
        else:
            content = [{"text": msg.content}]
            tool_calls = msg.tool_calls

        if tool_calls:
            # Only the first tool call is forwarded; it is also remembered so a
            # later "tool" message can name its functionResponse.
            tool_call = tool_calls[0]
            function_arguments = {
                "functionCall": {
                    "name": tool_call.function.name,
                    "args": json.loads(tool_call.function.arguments)
                }
            }
            messages.append(
                {
                    "role": "model",
                    "parts": [function_arguments]
                }
            )
        elif msg.role == "tool":
            # NOTE(review): assumes a "tool" message is always preceded by an
            # assistant tool_calls message in the same request; otherwise
            # function_arguments is still None here -- confirm callers.
            function_call_name = function_arguments["functionCall"]["name"]
            messages.append(
                {
                    "role": "function",
                    "parts": [{
                        "functionResponse": {
                            "name": function_call_name,
                            "response": {
                                "name": function_call_name,
                                "content": {
                                    "result": msg.content,
                                }
                            }
                        }
                    }]
                }
            )
        elif msg.role != "system":
            messages.append({"role": msg.role, "parts": content})
        elif msg.role == "system":
            # Collapse runs of underscores in the system prompt text.
            content[0]["text"] = re.sub(r"_+", "_", content[0]["text"])
            systemInstruction = {"parts": content}

    # Newer model families accept safety threshold "OFF"; the rest only accept
    # "BLOCK_NONE".
    off_models = ["gemini-2.0-flash", "gemini-2.5-flash", "gemini-1.5", "gemini-2.5-pro"]
    if any(off_model in original_model for off_model in off_models):
        safety_settings = "OFF"
    else:
        safety_settings = "BLOCK_NONE"

    payload = {
        # Gemini rejects an empty contents list, so fall back to a stub message.
        "contents": messages or [{"role": "user", "parts": [{"text": "No messages"}]}],
        "safetySettings": [
            {
                "category": "HARM_CATEGORY_HARASSMENT",
                "threshold": safety_settings
            },
            {
                "category": "HARM_CATEGORY_HATE_SPEECH",
                "threshold": safety_settings
            },
            {
                "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
                "threshold": safety_settings
            },
            {
                "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
                "threshold": safety_settings
            },
            {
                "category": "HARM_CATEGORY_CIVIC_INTEGRITY",
                "threshold": "BLOCK_NONE"
            },
        ]
    }

    if systemInstruction:
        if api_version == "v1beta":
            payload["systemInstruction"] = systemInstruction
        if api_version == "v1":
            # v1 has no systemInstruction field: fold the system prompt into
            # the first message's text instead.
            first_message = safe_get(payload, "contents", 0, "parts", 0, "text", default=None)
            system_instruction = safe_get(systemInstruction, "parts", 0, "text", default=None)
            if first_message and system_instruction:
                payload["contents"][0]["parts"][0]["text"] = system_instruction + "\n" + first_message

    # OpenAI request fields with no direct Gemini equivalent.
    miss_fields = [
        'model',
        'messages',
        'stream',
        'tool_choice',
        'presence_penalty',
        'frequency_penalty',
        'n',
        'user',
        'include_usage',
        'logprobs',
        'top_logprobs',
        'response_format',
        'stream_options',
    ]
    generation_config = {}

    for field, value in request.model_dump(exclude_unset=True).items():
        if field not in miss_fields and value is not None:
            if field == "tools" and "gemini-2.0-flash-thinking" in original_model:
                # Thinking models reject function declarations entirely.
                continue
            if field == "tools":
                # Re-shape OpenAI tool definitions for Gemini.
                processed_tools = []
                for tool in value:
                    function_def = tool["function"]
                    # Gemini's parameter schema has no "default" key; fold the
                    # default value into the property description instead.
                    if safe_get(function_def, "parameters", "properties", default=None):
                        for prop_value in function_def["parameters"]["properties"].values():
                            if "default" in prop_value:
                                default_value = prop_value["default"]
                                description = prop_value.get("description", "")
                                prop_value["description"] = f"{description}\nDefault: {default_value}"
                                del prop_value["default"]
                    # googleSearch is a built-in tool, not a function
                    # declaration.  (The original tested the same name twice
                    # -- the duplicate comparison has been removed.)
                    if function_def["name"] != "googleSearch":
                        processed_tools.append({"function": function_def})

                if processed_tools:
                    payload.update({
                        "tools": [{
                            "function_declarations": [tool["function"] for tool in processed_tools]
                        }],
                        "tool_config": {
                            "function_calling_config": {
                                "mode": "AUTO"
                            }
                        }
                    })
            elif field == "temperature":
                generation_config["temperature"] = value
            elif field == "max_tokens":
                generation_config["maxOutputTokens"] = value
            elif field == "top_p":
                generation_config["topP"] = value
            else:
                payload[field] = value

    # Default output budget: 64k tokens for the long-output families, 8k otherwise.
    max_token_65k_models = ["gemini-2.5-pro", "gemini-2.0-pro", "gemini-2.0-flash-thinking", "gemini-2.5-flash"]
    payload["generationConfig"] = generation_config
    if "maxOutputTokens" not in generation_config:
        if any(pro_model in original_model for pro_model in max_token_65k_models):
            payload["generationConfig"]["maxOutputTokens"] = 65536
        else:
            payload["generationConfig"]["maxOutputTokens"] = 8192

    # A "-think-<budget>" suffix on the requested model sets a thinking budget,
    # clamped to [0, 24576].
    m = re.match(r".*-think-(-?\d+)", request.model)
    if m:
        try:
            val = int(m.group(1))
            if val < 0:
                val = 0
            elif val > 24576:
                val = 24576
            payload["generationConfig"]["thinkingConfig"] = {"thinkingBudget": val}
        except ValueError:
            # Ignore an unparsable thinking budget.
            pass

    # A "-search" suffix enables the built-in Google Search tool.
    if request.model.endswith("-search"):
        if "tools" not in payload:
            payload["tools"] = [{"googleSearch": {}}]
        else:
            payload["tools"].append({"googleSearch": {}})

    return url, headers, payload
|
232
|
+
|
233
|
+
import time
|
234
|
+
from cryptography.hazmat.primitives import hashes
|
235
|
+
from cryptography.hazmat.primitives.asymmetric import padding
|
236
|
+
from cryptography.hazmat.primitives.serialization import load_pem_private_key
|
237
|
+
|
238
|
+
def create_jwt(client_email, private_key):
    """Build a signed RS256 JWT asserting *client_email* for Google OAuth2.

    The token requests the cloud-platform scope, is addressed to Google's
    token endpoint, and expires one hour after creation.  *private_key* is a
    PEM-encoded RSA private key string.
    """
    def _b64(raw):
        # JWT segments use unpadded URL-safe base64.
        return base64.urlsafe_b64encode(raw).rstrip(b'=')

    issued_at = int(time.time())
    jwt_header = {
        "alg": "RS256",
        "typ": "JWT"
    }
    jwt_claims = {
        "iss": client_email,
        "scope": "https://www.googleapis.com/auth/cloud-platform",
        "aud": "https://oauth2.googleapis.com/token",
        "exp": issued_at + 3600,
        "iat": issued_at
    }

    parts = [
        _b64(json.dumps(jwt_header).encode()),
        _b64(json.dumps(jwt_claims).encode()),
    ]
    unsigned = b'.'.join(parts)

    rsa_key = load_pem_private_key(private_key.encode(), password=None)
    signature = rsa_key.sign(unsigned, padding.PKCS1v15(), hashes.SHA256())
    parts.append(_b64(signature))

    return b'.'.join(parts).decode()
|
272
|
+
|
273
|
+
def get_access_token(client_email, private_key):
    """Exchange a freshly signed service-account JWT for an OAuth2 access token.

    Performs a blocking POST to Google's token endpoint and returns the
    `access_token` string; raises on a non-2xx response.
    """
    assertion = create_jwt(client_email, private_key)

    form = {
        "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
        "assertion": assertion,
    }
    with httpx.Client() as client:
        resp = client.post(
            "https://oauth2.googleapis.com/token",
            data=form,
            headers={'Content-Type': "application/x-www-form-urlencoded"},
        )
        resp.raise_for_status()
        return resp.json()["access_token"]
|
287
|
+
|
288
|
+
async def get_vertex_gemini_payload(request, engine, provider, api_key=None):
    """Translate an OpenAI-style chat request into a Vertex AI Gemini call.

    Returns a (url, headers, payload) triple for the streaming
    `streamGenerateContent` endpoint.  Authentication uses a service-account
    OAuth2 token when `client_email`/`private_key` are configured.
    """
    headers = {
        'Content-Type': 'application/json'
    }
    if provider.get("client_email") and provider.get("private_key"):
        # Mint a short-lived OAuth2 access token from the service-account key.
        access_token = get_access_token(provider['client_email'], provider['private_key'])
        headers['Authorization'] = f"Bearer {access_token}"
    # NOTE(review): project_id stays unbound when the provider config lacks
    # "project_id", which would raise at URL construction below -- confirm
    # upstream config always sets it.
    if provider.get("project_id"):
        project_id = provider.get("project_id")

    gemini_stream = "streamGenerateContent"
    model_dict = get_model_dict(provider)
    # Map the client-facing alias to the provider's actual model id.
    original_model = model_dict[request.model]
    search_tool = None

    # Pro/experimental models use a different location pool and the newer
    # googleSearch tool; the rest use googleSearchRetrieval.
    pro_models = ["gemini-2.5-pro", "gemini-2.0-pro", "gemini-exp"]
    if any(pro_model in original_model for pro_model in pro_models):
        location = gemini2
        search_tool = {"googleSearch": {}}
    else:
        location = gemini1
        search_tool = {"googleSearchRetrieval": {}}

    if "google-vertex-ai" in provider.get("base_url", ""):
        # Custom gateway: keep the configured host and append the Vertex path.
        url = provider.get("base_url").rstrip('/') + "/v1/projects/{PROJECT_ID}/locations/{LOCATION}/publishers/google/models/{MODEL_ID}:{stream}".format(
            LOCATION=await location.next(),
            PROJECT_ID=project_id,
            MODEL_ID=original_model,
            stream=gemini_stream
        )
    else:
        # Direct regional Vertex endpoint; the location rotates via the pool.
        url = "https://{LOCATION}-aiplatform.googleapis.com/v1/projects/{PROJECT_ID}/locations/{LOCATION}/publishers/google/models/{MODEL_ID}:{stream}".format(
            LOCATION=await location.next(),
            PROJECT_ID=project_id,
            MODEL_ID=original_model,
            stream=gemini_stream
        )

    # Convert OpenAI-style messages into Gemini "contents".
    messages = []
    systemInstruction = None
    function_arguments = None
    for msg in request.messages:
        if msg.role == "assistant":
            msg.role = "model"
        tool_calls = None
        if isinstance(msg.content, list):
            content = []
            for item in msg.content:
                if item.type == "text":
                    text_message = await get_text_message(item.text, engine)
                    content.append(text_message)
                elif item.type == "image_url" and provider.get("image", True):
                    image_message = await get_image_message(item.image_url.url, engine)
                    content.append(image_message)
        else:
            content = [{"text": msg.content}]
            tool_calls = msg.tool_calls

        if tool_calls:
            # Only the first tool call is forwarded; it is remembered so the
            # following "tool" message can name its functionResponse.
            tool_call = tool_calls[0]
            function_arguments = {
                "functionCall": {
                    "name": tool_call.function.name,
                    "args": json.loads(tool_call.function.arguments)
                }
            }
            messages.append(
                {
                    "role": "model",
                    "parts": [function_arguments]
                }
            )
        elif msg.role == "tool":
            # NOTE(review): assumes an assistant tool_calls message came first;
            # otherwise function_arguments is still None here.
            function_call_name = function_arguments["functionCall"]["name"]
            messages.append(
                {
                    "role": "function",
                    "parts": [{
                        "functionResponse": {
                            "name": function_call_name,
                            "response": {
                                "name": function_call_name,
                                "content": {
                                    "result": msg.content,
                                }
                            }
                        }
                    }]
                }
            )
        elif msg.role != "system":
            messages.append({"role": msg.role, "parts": content})
        elif msg.role == "system":
            systemInstruction = {"parts": content}


    payload = {
        "contents": messages,
        # "safetySettings": [
        #     {
        #         "category": "HARM_CATEGORY_HARASSMENT",
        #         "threshold": "BLOCK_NONE"
        #     },
        #     {
        #         "category": "HARM_CATEGORY_HATE_SPEECH",
        #         "threshold": "BLOCK_NONE"
        #     },
        #     {
        #         "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        #         "threshold": "BLOCK_NONE"
        #     },
        #     {
        #         "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
        #         "threshold": "BLOCK_NONE"
        #     }
        # ]
    }
    if systemInstruction:
        payload["system_instruction"] = systemInstruction

    # OpenAI request fields with no direct Vertex Gemini equivalent.
    miss_fields = [
        'model',
        'messages',
        'stream',
        'tool_choice',
        'presence_penalty',
        'frequency_penalty',
        'n',
        'user',
        'include_usage',
        'logprobs',
        'top_logprobs',
        'stream_options',
    ]
    generation_config = {}

    for field, value in request.model_dump(exclude_unset=True).items():
        if field not in miss_fields and value is not None:
            if field == "tools":
                # Forward OpenAI tool definitions as Gemini function declarations.
                payload.update({
                    "tools": [{
                        "function_declarations": [tool["function"] for tool in value]
                    }],
                    "tool_config": {
                        "function_calling_config": {
                            "mode": "AUTO"
                        }
                    }
                })
            elif field == "temperature":
                generation_config["temperature"] = value
            elif field == "max_tokens":
                generation_config["max_output_tokens"] = value
            elif field == "top_p":
                generation_config["top_p"] = value
            else:
                payload[field] = value

    if generation_config:
        payload["generationConfig"] = generation_config
        # NOTE(review): snake_case keys (max_output_tokens / top_p) differ from
        # the camelCase used on the non-Vertex Gemini path -- confirm the
        # Vertex endpoint accepts them.
        if "max_output_tokens" not in generation_config:
            payload["generationConfig"]["max_output_tokens"] = 8192

    # A "-search" model suffix adds the Google Search tool selected above.
    if request.model.endswith("-search"):
        if "tools" not in payload:
            payload["tools"] = [search_tool]
        else:
            payload["tools"].append(search_tool)

    return url, headers, payload
|
458
|
+
|
459
|
+
async def get_vertex_claude_payload(request, engine, provider, api_key=None):
    """Translate an OpenAI-style chat request into a Vertex AI Anthropic call.

    Returns a (url, headers, payload) triple for the `streamRawPredict`
    endpoint of the Anthropic publisher on Vertex AI.
    """
    headers = {
        'Content-Type': 'application/json',
    }
    if provider.get("client_email") and provider.get("private_key"):
        # Mint a short-lived OAuth2 access token from the service-account key.
        access_token = get_access_token(provider['client_email'], provider['private_key'])
        headers['Authorization'] = f"Bearer {access_token}"
    # NOTE(review): project_id (and location, below) stay unbound for configs /
    # models that match no branch -- confirm upstream always provides them.
    if provider.get("project_id"):
        project_id = provider.get("project_id")

    model_dict = get_model_dict(provider)
    original_model = model_dict[request.model]
    # Pick the location pool that serves this Claude family.
    if "claude-3-5-sonnet" in original_model or "claude-3-7-sonnet" in original_model:
        location = c35s
    elif "claude-3-opus" in original_model:
        location = c3o
    elif "claude-3-sonnet" in original_model:
        location = c3s
    elif "claude-3-haiku" in original_model:
        location = c3h

    claude_stream = "streamRawPredict"
    url = "https://{LOCATION}-aiplatform.googleapis.com/v1/projects/{PROJECT_ID}/locations/{LOCATION}/publishers/anthropic/models/{MODEL}:{stream}".format(
        LOCATION=await location.next(),
        PROJECT_ID=project_id,
        MODEL=original_model,
        stream=claude_stream
    )

    # Convert OpenAI-style messages to Anthropic's message format.
    messages = []
    system_prompt = None
    tool_id = None
    for msg in request.messages:
        tool_call_id = None
        tool_calls = None
        if isinstance(msg.content, list):
            content = []
            for item in msg.content:
                if item.type == "text":
                    text_message = await get_text_message(item.text, engine)
                    content.append(text_message)
                elif item.type == "image_url" and provider.get("image", True):
                    image_message = await get_image_message(item.image_url.url, engine)
                    content.append(image_message)
        else:
            content = msg.content
            tool_calls = msg.tool_calls
            # Remember the last seen tool-call id so a later tool_result can
            # reference it.
            tool_id = tool_calls[0].id if tool_calls else None or tool_id
            tool_call_id = msg.tool_call_id

        if tool_calls:
            # Only the first tool call is forwarded as a tool_use block.
            tool_calls_list = []
            tool_call = tool_calls[0]
            tool_calls_list.append({
                "type": "tool_use",
                "id": tool_call.id,
                "name": tool_call.function.name,
                "input": json.loads(tool_call.function.arguments),
            })
            messages.append({"role": msg.role, "content": tool_calls_list})
        elif tool_call_id:
            messages.append({"role": "user", "content": [{
                "type": "tool_result",
                "tool_use_id": tool_id,
                "content": content
            }]})
        elif msg.role == "function":
            # Legacy "function" role: synthesize a matching tool_use /
            # tool_result pair.  NOTE(review): the tool_use id is hardcoded --
            # presumably a placeholder; verify it does not collide.
            messages.append({"role": "assistant", "content": [{
                "type": "tool_use",
                "id": "toolu_017r5miPMV6PGSNKmhvHPic4",
                "name": msg.name,
                "input": {"prompt": "..."}
            }]})
            messages.append({"role": "user", "content": [{
                "type": "tool_result",
                "tool_use_id": "toolu_017r5miPMV6PGSNKmhvHPic4",
                "content": msg.content
            }]})
        elif msg.role != "system":
            messages.append({"role": msg.role, "content": content})
        elif msg.role == "system":
            system_prompt = content

    # Anthropic requires strictly alternating roles: merge consecutive
    # messages that share a role into one.
    conversation_len = len(messages) - 1
    message_index = 0
    while message_index < conversation_len:
        if messages[message_index]["role"] == messages[message_index + 1]["role"]:
            if messages[message_index].get("content"):
                if isinstance(messages[message_index]["content"], list):
                    messages[message_index]["content"].extend(messages[message_index + 1]["content"])
                elif isinstance(messages[message_index]["content"], str) and isinstance(messages[message_index + 1]["content"], list):
                    # Promote the plain string to a text block before merging.
                    content_list = [{"type": "text", "text": messages[message_index]["content"]}]
                    content_list.extend(messages[message_index + 1]["content"])
                    messages[message_index]["content"] = content_list
                else:
                    messages[message_index]["content"] += messages[message_index + 1]["content"]
            messages.pop(message_index + 1)
            conversation_len = conversation_len - 1
        else:
            message_index = message_index + 1

    # Default output budget by model family (overridden by request.max_tokens).
    if "claude-3-7-sonnet" in original_model:
        max_tokens = 20000
    elif "claude-3-5-sonnet" in original_model:
        max_tokens = 8192
    else:
        max_tokens = 4096

    payload = {
        "anthropic_version": "vertex-2023-10-16",
        "messages": messages,
        "system": system_prompt or "You are Claude, a large language model trained by Anthropic.",
        "max_tokens": max_tokens,
    }

    if request.max_tokens:
        payload["max_tokens"] = int(request.max_tokens)

    # OpenAI request fields with no direct Anthropic equivalent.
    miss_fields = [
        'model',
        'messages',
        'presence_penalty',
        'frequency_penalty',
        'n',
        'user',
        'include_usage',
        'stream_options',
    ]

    for field, value in request.model_dump(exclude_unset=True).items():
        if field not in miss_fields and value is not None:
            payload[field] = value

    if request.tools and provider.get("tools"):
        # NOTE(review): gpt2claude_tools_json is defined elsewhere in this
        # module (not among the visible imports) -- confirm it is in scope.
        tools = []
        for tool in request.tools:
            json_tool = await gpt2claude_tools_json(tool.dict()["function"])
            tools.append(json_tool)
        payload["tools"] = tools
        # Translate OpenAI tool_choice values into Anthropic's schema.
        if "tool_choice" in payload:
            if isinstance(payload["tool_choice"], dict):
                if payload["tool_choice"]["type"] == "function":
                    payload["tool_choice"] = {
                        "type": "tool",
                        "name": payload["tool_choice"]["function"]["name"]
                    }
            if isinstance(payload["tool_choice"], str):
                if payload["tool_choice"] == "auto":
                    payload["tool_choice"] = {
                        "type": "auto"
                    }
                if payload["tool_choice"] == "none":
                    # NOTE(review): "none" maps to Anthropic's "any" here,
                    # which forces a tool call rather than forbidding one --
                    # verify this mapping is intentional.
                    payload["tool_choice"] = {
                        "type": "any"
                    }

    if provider.get("tools") == False:
        # Provider explicitly disables tools: strip anything tool-related.
        payload.pop("tools", None)
        payload.pop("tool_choice", None)

    return url, headers, payload
|
620
|
+
|
621
|
+
import hashlib
|
622
|
+
import hmac
|
623
|
+
import datetime
|
624
|
+
import urllib.parse
|
625
|
+
from datetime import timezone
|
626
|
+
|
627
|
+
def sign(key, msg):
    """Return the HMAC-SHA256 digest of *msg* (UTF-8 encoded) keyed with *key*."""
    mac = hmac.new(key, msg.encode('utf-8'), hashlib.sha256)
    return mac.digest()
|
629
|
+
|
630
|
+
def get_signature_key(key, date_stamp, region_name, service_name):
    """Derive the AWS Signature V4 signing key via the standard HMAC chain.

    key -- the secret access key; the remaining arguments scope the derived
    key to a date (YYYYMMDD), region, and service.
    """
    derived = sign(('AWS4' + key).encode('utf-8'), date_stamp)
    for scope_part in (region_name, service_name, 'aws4_request'):
        derived = sign(derived, scope_part)
    return derived
|
636
|
+
|
637
|
+
def get_signature(request_body, model_id, aws_access_key, aws_secret_key, aws_region, host, content_type, accept_header):
    """Compute AWS Signature V4 values for a Bedrock streaming invoke request.

    Serializes *request_body* to JSON, builds the canonical request for
    POST /model/{model_id}/invoke-with-response-stream, and returns the tuple
    (amz_date, payload_hash, authorization_header) to place in the HTTP headers.
    """
    SERVICE = "bedrock"
    body = json.dumps(request_body)
    method = 'POST'
    canonical_querystring = ''
    raw_path = f'/model/{model_id}/invoke-with-response-stream'
    canonical_uri = urllib.parse.quote(raw_path, safe='/-_.~')

    # Timestamps for the headers and the credential scope.
    now = datetime.datetime.now(timezone.utc)
    amz_date = now.strftime('%Y%m%dT%H%M%SZ')
    date_stamp = now.strftime('%Y%m%d')  # YYYYMMDD

    # --- Task 1: canonical request ---
    payload_hash = hashlib.sha256(body.encode('utf-8')).hexdigest()

    # Header names must be listed in alphabetical order, each line
    # terminated by '\n'.
    canonical_headers = ''.join([
        f'accept:{accept_header}\n',
        f'content-type:{content_type}\n',
        f'host:{host}\n',
        f'x-amz-bedrock-accept:{accept_header}\n',
        f'x-amz-content-sha256:{payload_hash}\n',
        f'x-amz-date:{amz_date}\n',
    ])
    signed_headers = 'accept;content-type;host;x-amz-bedrock-accept;x-amz-content-sha256;x-amz-date'

    canonical_request = '\n'.join([
        method,
        canonical_uri,
        canonical_querystring,
        canonical_headers,
        signed_headers,
        payload_hash,
    ])

    # --- Task 2: string to sign ---
    algorithm = 'AWS4-HMAC-SHA256'
    credential_scope = f'{date_stamp}/{aws_region}/{SERVICE}/aws4_request'
    string_to_sign = '\n'.join([
        algorithm,
        amz_date,
        credential_scope,
        hashlib.sha256(canonical_request.encode("utf-8")).hexdigest(),
    ])

    # --- Task 3: signature ---
    signing_key = get_signature_key(aws_secret_key, date_stamp, aws_region, SERVICE)
    signature = hmac.new(signing_key, string_to_sign.encode('utf-8'), hashlib.sha256).hexdigest()

    # --- Task 4: authorization header ---
    authorization_header = (
        f'{algorithm} Credential={aws_access_key}/{credential_scope}, '
        f'SignedHeaders={signed_headers}, Signature={signature}'
    )
    return amz_date, payload_hash, authorization_header
|
684
|
+
|
685
|
+
async def get_aws_payload(request, engine, provider, api_key=None):
    """Build the (url, headers, payload) triple for an AWS Bedrock streaming call.

    Converts an OpenAI-style chat request into the Anthropic-on-Bedrock message
    format, merges consecutive same-role messages (Anthropic requires strict
    role alternation), translates tools/tool_choice, and SigV4-signs the request
    via get_signature().
    """
    CONTENT_TYPE = "application/json"
    model_dict = get_model_dict(provider)
    original_model = model_dict[request.model]
    base_url = provider.get('base_url')
    # The region is taken from the host name, e.g.
    # "bedrock-runtime.us-east-1.amazonaws.com" -> "us-east-1".
    AWS_REGION = base_url.split('.')[1]
    HOST = f"bedrock-runtime.{AWS_REGION}.amazonaws.com"
    # url = f"{base_url}/model/{original_model}/invoke"
    url = f"{base_url}/model/{original_model}/invoke-with-response-stream"

    messages = []
    system_prompt = None
    tool_id = None
    for msg in request.messages:
        tool_call_id = None
        tool_calls = None
        if isinstance(msg.content, list):
            # Multimodal content: convert each part to the engine's format.
            content = []
            for item in msg.content:
                if item.type == "text":
                    text_message = await get_text_message(item.text, engine)
                    content.append(text_message)
                elif item.type == "image_url" and provider.get("image", True):
                    image_message = await get_image_message(item.image_url.url, engine)
                    content.append(image_message)
        else:
            content = msg.content
            tool_calls = msg.tool_calls
            # Precedence note: parsed as
            # `tool_calls[0].id if tool_calls else (None or tool_id)`,
            # i.e. remember the last-seen tool id across messages.
            tool_id = tool_calls[0].id if tool_calls else None or tool_id
            tool_call_id = msg.tool_call_id

        if tool_calls:
            # Only the first tool call of the message is forwarded.
            tool_calls_list = []
            tool_call = tool_calls[0]
            tool_calls_list.append({
                "type": "tool_use",
                "id": tool_call.id,
                "name": tool_call.function.name,
                "input": json.loads(tool_call.function.arguments),
            })
            messages.append({"role": msg.role, "content": tool_calls_list})
        elif tool_call_id:
            messages.append({"role": "user", "content": [{
                "type": "tool_result",
                "tool_use_id": tool_id,
                "content": content
            }]})
        elif msg.role == "function":
            # Legacy "function" role: fabricate a tool_use/tool_result pair
            # with a fixed placeholder id so the transcript stays consistent.
            messages.append({"role": "assistant", "content": [{
                "type": "tool_use",
                "id": "toolu_017r5miPMV6PGSNKmhvHPic4",
                "name": msg.name,
                "input": {"prompt": "..."}
            }]})
            messages.append({"role": "user", "content": [{
                "type": "tool_result",
                "tool_use_id": "toolu_017r5miPMV6PGSNKmhvHPic4",
                "content": msg.content
            }]})
        elif msg.role != "system":
            messages.append({"role": msg.role, "content": content})
        elif msg.role == "system":
            system_prompt = content

    # Merge consecutive messages that share a role (Anthropic requires
    # alternating user/assistant turns).
    conversation_len = len(messages) - 1
    message_index = 0
    while message_index < conversation_len:
        if messages[message_index]["role"] == messages[message_index + 1]["role"]:
            if messages[message_index].get("content"):
                if isinstance(messages[message_index]["content"], list):
                    messages[message_index]["content"].extend(messages[message_index + 1]["content"])
                elif isinstance(messages[message_index]["content"], str) and isinstance(messages[message_index + 1]["content"], list):
                    # Promote the string to a text part before extending.
                    content_list = [{"type": "text", "text": messages[message_index]["content"]}]
                    content_list.extend(messages[message_index + 1]["content"])
                    messages[message_index]["content"] = content_list
                else:
                    messages[message_index]["content"] += messages[message_index + 1]["content"]
            messages.pop(message_index + 1)
            conversation_len = conversation_len - 1
        else:
            message_index = message_index + 1

    max_tokens = 4096

    payload = {
        "messages": messages,
        "anthropic_version": "bedrock-2023-05-31",
        "max_tokens": max_tokens,
    }

    if request.max_tokens:
        payload["max_tokens"] = int(request.max_tokens)

    # Request fields that must not be copied verbatim into the payload.
    miss_fields = [
        'model',
        'messages',
        'presence_penalty',
        'frequency_penalty',
        'n',
        'user',
        'include_usage',
        'stream_options',
        'stream',
    ]

    for field, value in request.model_dump(exclude_unset=True).items():
        if field not in miss_fields and value is not None:
            payload[field] = value

    if request.tools and provider.get("tools"):
        tools = []
        for tool in request.tools:
            json_tool = await gpt2claude_tools_json(tool.dict()["function"])
            tools.append(json_tool)
        payload["tools"] = tools
        if "tool_choice" in payload:
            # Translate OpenAI tool_choice shapes into Anthropic's.
            if isinstance(payload["tool_choice"], dict):
                if payload["tool_choice"]["type"] == "function":
                    payload["tool_choice"] = {
                        "type": "tool",
                        "name": payload["tool_choice"]["function"]["name"]
                    }
            if isinstance(payload["tool_choice"], str):
                if payload["tool_choice"] == "auto":
                    payload["tool_choice"] = {
                        "type": "auto"
                    }
                # NOTE(review): "none" maps to Anthropic's "any" (force a tool
                # call) rather than disabling tools — confirm this is intended.
                if payload["tool_choice"] == "none":
                    payload["tool_choice"] = {
                        "type": "any"
                    }

    if provider.get("tools") == False:
        payload.pop("tools", None)
        payload.pop("tool_choice", None)

    # NOTE(review): if either AWS credential is missing, `headers` is never
    # bound and the return below raises NameError — verify callers always
    # configure both keys.
    if provider.get("aws_access_key") and provider.get("aws_secret_key"):
        ACCEPT_HEADER = "application/vnd.amazon.bedrock.payload+json"  # request Bedrock's streaming payload format
        amz_date, payload_hash, authorization_header = get_signature(payload, original_model, provider.get("aws_access_key"), provider.get("aws_secret_key"), AWS_REGION, HOST, CONTENT_TYPE, ACCEPT_HEADER)
        headers = {
            'Accept': ACCEPT_HEADER,
            'Content-Type': CONTENT_TYPE,
            'X-Amz-Date': amz_date,
            'X-Amz-Bedrock-Accept': ACCEPT_HEADER,  # Bedrock-specific header
            'X-Amz-Content-Sha256': payload_hash,
            'Authorization': authorization_header,
            # Add 'X-Amz-Security-Token': SESSION_TOKEN if using temporary credentials
        }

    return url, headers, payload
|
867
|
+
|
868
|
+
async def get_gpt_payload(request, engine, provider, api_key=None):
    """Build (url, headers, payload) for an OpenAI-compatible chat endpoint.

    Applies a number of per-model adjustments: o1/o3/o4 reasoning models get
    max_completion_tokens and reasoning_effort, o1-mini/o1-preview get their
    system message folded into the first user message, grok endpoints drop
    stream_options, DeepSeek-R gets a default temperature, and "-search"
    Gemini aliases get a googleSearch tool injected.
    """
    headers = {
        'Content-Type': 'application/json',
    }
    model_dict = get_model_dict(provider)
    original_model = model_dict[request.model]
    if api_key:
        headers['Authorization'] = f"Bearer {api_key}"

    url = provider['base_url']

    messages = []
    for msg in request.messages:
        tool_calls = None
        tool_call_id = None
        if isinstance(msg.content, list):
            # Multimodal content: convert each part to the engine's format.
            content = []
            for item in msg.content:
                if item.type == "text":
                    text_message = await get_text_message(item.text, engine)
                    content.append(text_message)
                elif item.type == "image_url" and provider.get("image", True) and "o1-mini" not in original_model:
                    image_message = await get_image_message(item.image_url.url, engine)
                    content.append(image_message)
        else:
            content = msg.content
            # o3-mini emits markdown only when the system prompt starts with
            # this sentinel phrase.
            if msg.role == "system" and "o3-mini" in original_model and not content.startswith("Formatting re-enabled"):
                content = "Formatting re-enabled. " + content
            tool_calls = msg.tool_calls
            tool_call_id = msg.tool_call_id

        if tool_calls:
            tool_calls_list = []
            for tool_call in tool_calls:
                tool_calls_list.append({
                    "id": tool_call.id,
                    "type": tool_call.type,
                    "function": {
                        "name": tool_call.function.name,
                        "arguments": tool_call.function.arguments
                    }
                })
            # Tool messages are dropped entirely when the provider has tools disabled.
            if provider.get("tools"):
                messages.append({"role": msg.role, "tool_calls": tool_calls_list})
        elif tool_call_id:
            if provider.get("tools"):
                messages.append({"role": msg.role, "tool_call_id": tool_call_id, "content": content})
        else:
            messages.append({"role": msg.role, "content": content})

    # o1-mini / o1-preview reject system messages: prepend the system text to
    # the first remaining message instead.
    if ("o1-mini" in original_model or "o1-preview" in original_model) and len(messages) > 1 and messages[0]["role"] == "system":
        system_msg = messages.pop(0)
        messages[0]["content"] = system_msg["content"] + messages[0]["content"]

    payload = {
        "model": original_model,
        "messages": messages,
    }

    # Request fields that must not be copied verbatim into the payload.
    miss_fields = [
        'model',
        'messages',
    ]

    for field, value in request.model_dump(exclude_unset=True).items():
        if field not in miss_fields and value is not None:
            # Reasoning models use max_completion_tokens instead of max_tokens.
            if field == "max_tokens" and ("o1" in original_model or "o3" in original_model or "o4" in original_model):
                payload["max_completion_tokens"] = value
            else:
                payload[field] = value

    if provider.get("tools") == False or "o1-mini" in original_model or "chatgpt-4o-latest" in original_model or "grok" in original_model:
        payload.pop("tools", None)
        payload.pop("tool_choice", None)

    if "api.x.ai" in url:
        payload.pop("stream_options", None)

    if "grok-3-mini" in original_model:
        # The requested alias name ("...-high"/"...-low") selects the effort.
        if request.model.endswith("high"):
            payload["reasoning_effort"] = "high"
        elif request.model.endswith("low"):
            payload["reasoning_effort"] = "low"

    if "o1" in original_model or "o3" in original_model or "o4" in original_model:
        if request.model.endswith("high"):
            payload["reasoning_effort"] = "high"
        elif request.model.endswith("low"):
            payload["reasoning_effort"] = "low"
        else:
            payload["reasoning_effort"] = "medium"

        # Reasoning models do not accept a temperature parameter.
        if "temperature" in payload:
            payload.pop("temperature")

    # DeepSeek temperature guidance:
    # coding / math           0.0
    # data extraction/analysis 1.0
    # general conversation     1.3
    # translation              1.3
    # creative writing/poetry  1.5
    if "deepseek-r" in original_model.lower():
        if "temperature" not in payload:
            payload["temperature"] = 0.6

    # "-search" Gemini aliases: ensure a googleSearch pseudo-tool is present.
    if request.model.endswith("-search") and "gemini" in original_model:
        if "tools" not in payload:
            payload["tools"] = [{
                "type": "function",
                "function": {
                    "name": "googleSearch",
                    "description": "googleSearch"
                }
            }]
        else:
            if not any(tool["function"]["name"] == "googleSearch" for tool in payload["tools"]):
                payload["tools"].append({
                    "type": "function",
                    "function": {
                        "name": "googleSearch",
                        "description": "googleSearch"
                    }
                })

    return url, headers, payload
|
996
|
+
|
997
|
+
def build_azure_endpoint(base_url, deployment_id, api_version="2024-10-21"):
    """Return the Azure OpenAI chat-completions URL for *deployment_id*.

    Appends the deployments path unless the caller already supplied a
    "models/chat/completions" URL, and appends the api-version query
    parameter unless one is already present.
    """
    endpoint = base_url.rstrip('/')

    if "models/chat/completions" not in endpoint:
        # urljoin with an absolute path keeps only the scheme+host of endpoint.
        endpoint = urllib.parse.urljoin(
            endpoint,
            f"/openai/deployments/{deployment_id}/chat/completions",
        )

    if "?api-version=" not in endpoint:
        endpoint = f"{endpoint}?api-version={api_version}"

    return endpoint
|
1013
|
+
|
1014
|
+
async def get_azure_payload(request, engine, provider, api_key=None):
    """Build (url, headers, payload) for an Azure OpenAI deployment.

    Same message conversion as get_gpt_payload, but authenticates with the
    Azure "api-key" header and routes to the deployment-specific endpoint.
    """
    headers = {
        'Content-Type': 'application/json',
    }
    model_dict = get_model_dict(provider)
    original_model = model_dict[request.model]
    # NOTE(review): unconditionally formatted, so a missing api_key becomes
    # the literal string "None" — verify callers always pass a key.
    headers['api-key'] = f"{api_key}"

    url = build_azure_endpoint(
        base_url=provider['base_url'],
        deployment_id=original_model,
    )

    messages = []
    for msg in request.messages:
        tool_calls = None
        tool_call_id = None
        if isinstance(msg.content, list):
            # Multimodal content: convert each part to the engine's format.
            content = []
            for item in msg.content:
                if item.type == "text":
                    text_message = await get_text_message(item.text, engine)
                    content.append(text_message)
                elif item.type == "image_url" and provider.get("image", True) and "o1-mini" not in original_model:
                    image_message = await get_image_message(item.image_url.url, engine)
                    content.append(image_message)
        else:
            content = msg.content
            tool_calls = msg.tool_calls
            tool_call_id = msg.tool_call_id

        if tool_calls:
            tool_calls_list = []
            for tool_call in tool_calls:
                tool_calls_list.append({
                    "id": tool_call.id,
                    "type": tool_call.type,
                    "function": {
                        "name": tool_call.function.name,
                        "arguments": tool_call.function.arguments
                    }
                })
            # Tool messages are dropped entirely when the provider has tools disabled.
            if provider.get("tools"):
                messages.append({"role": msg.role, "tool_calls": tool_calls_list})
        elif tool_call_id:
            if provider.get("tools"):
                messages.append({"role": msg.role, "tool_call_id": tool_call_id, "content": content})
        else:
            messages.append({"role": msg.role, "content": content})

    payload = {
        "model": original_model,
        "messages": messages,
    }

    # Request fields that must not be copied verbatim into the payload.
    miss_fields = [
        'model',
        'messages',
    ]

    for field, value in request.model_dump(exclude_unset=True).items():
        if field not in miss_fields and value is not None:
            # o1 deployments use max_completion_tokens instead of max_tokens.
            if field == "max_tokens" and "o1" in original_model:
                payload["max_completion_tokens"] = value
            else:
                payload[field] = value

    if provider.get("tools") == False or "o1" in original_model or "chatgpt-4o-latest" in original_model or "grok" in original_model:
        payload.pop("tools", None)
        payload.pop("tool_choice", None)

    return url, headers, payload
|
1086
|
+
|
1087
|
+
async def get_openrouter_payload(request, engine, provider, api_key=None):
    """Build (url, headers, payload) for an OpenRouter chat request.

    Multimodal messages are split: each text part and each image part becomes
    its own message entry rather than one multi-part message.
    """
    headers = {
        'Content-Type': 'application/json'
    }
    model_dict = get_model_dict(provider)
    original_model = model_dict[request.model]
    if api_key:
        headers['Authorization'] = f"Bearer {api_key}"

    url = provider['base_url']

    messages = []
    for msg in request.messages:
        name = None
        if isinstance(msg.content, list):
            content = []
            for item in msg.content:
                if item.type == "text":
                    text_message = await get_text_message(item.text, engine)
                    content.append(text_message)
                elif item.type == "image_url" and provider.get("image", True):
                    image_message = await get_image_message(item.image_url.url, engine)
                    content.append(image_message)
        else:
            content = msg.content
            name = msg.name
        if name:
            messages.append({"role": msg.role, "name": name, "content": content})
        else:
            if isinstance(content, list):
                # Flatten multi-part content into one message per part.
                for item in content:
                    if item["type"] == "text":
                        messages.append({"role": msg.role, "content": item["text"]})
                    elif item["type"] == "image_url":
                        # NOTE(review): the image was already converted above;
                        # this re-converts the converted dict's url — confirm
                        # get_image_message is idempotent for this engine.
                        messages.append({"role": msg.role, "content": [await get_image_message(item["image_url"]["url"], engine)]})
            else:
                messages.append({"role": msg.role, "content": content})

    payload = {
        "model": original_model,
        "messages": messages,
    }

    # Request fields that must not be copied verbatim into the payload.
    miss_fields = [
        'model',
        'messages',
        'n',
        'user',
        'include_usage',
        'stream_options',
    ]

    for field, value in request.model_dump(exclude_unset=True).items():
        if field not in miss_fields and value is not None:
            payload[field] = value

    return url, headers, payload
|
1145
|
+
|
1146
|
+
async def get_cohere_payload(request, engine, provider, api_key=None):
    """Build (url, headers, payload) for Cohere's chat API.

    Cohere takes the latest message as "message" and the rest as
    "chat_history", with roles renamed USER/CHATBOT/SYSTEM.
    """
    headers = {
        'Content-Type': 'application/json'
    }
    model_dict = get_model_dict(provider)
    original_model = model_dict[request.model]
    if api_key:
        headers['Authorization'] = f"Bearer {api_key}"

    url = provider['base_url']

    # OpenAI role -> Cohere role.
    role_map = {
        "user": "USER",
        "assistant" : "CHATBOT",
        "system": "SYSTEM"
    }

    messages = []
    for msg in request.messages:
        if isinstance(msg.content, list):
            # Only text parts are kept; Cohere chat here is text-only.
            content = []
            for item in msg.content:
                if item.type == "text":
                    text_message = await get_text_message(item.text, engine)
                    content.append(text_message)
        else:
            content = msg.content

        if isinstance(content, list):
            for item in content:
                if item["type"] == "text":
                    messages.append({"role": role_map[msg.role], "message": item["text"]})
        else:
            messages.append({"role": role_map[msg.role], "message": content})

    # NOTE(review): an empty request.messages list would make messages[-1]
    # raise IndexError — confirm upstream validation guarantees >= 1 message.
    chat_history = messages[:-1]
    query = messages[-1].get("message")
    payload = {
        "model": original_model,
        "message": query,
    }

    if chat_history:
        payload["chat_history"] = chat_history

    # Request fields that must not be copied verbatim into the payload.
    miss_fields = [
        'model',
        'messages',
        'tools',
        'tool_choice',
        'temperature',
        'top_p',
        'max_tokens',
        'presence_penalty',
        'frequency_penalty',
        'n',
        'user',
        'include_usage',
        'logprobs',
        'top_logprobs',
        'stream_options',
    ]

    for field, value in request.model_dump(exclude_unset=True).items():
        if field not in miss_fields and value is not None:
            payload[field] = value

    return url, headers, payload
|
1214
|
+
|
1215
|
+
async def get_cloudflare_payload(request, engine, provider, api_key=None):
    """Build (url, headers, payload) for Cloudflare Workers AI.

    Only the final message is used, as a single "prompt" string.
    """
    headers = {
        'Content-Type': 'application/json'
    }
    model_dict = get_model_dict(provider)
    original_model = model_dict[request.model]
    if api_key:
        headers['Authorization'] = f"Bearer {api_key}"

    url = "https://api.cloudflare.com/client/v4/accounts/{cf_account_id}/ai/run/{cf_model_id}".format(cf_account_id=provider['cf_account_id'], cf_model_id=original_model)

    msg = request.messages[-1]
    content = None
    if isinstance(msg.content, list):
        # Keep only the last text part; images are ignored here.
        # NOTE(review): a list with no text parts leaves content as None.
        for item in msg.content:
            if item.type == "text":
                content = await get_text_message(item.text, engine)
    else:
        content = msg.content

    payload = {
        "prompt": content,
    }

    # Request fields that must not be copied verbatim into the payload.
    miss_fields = [
        'model',
        'messages',
        'tools',
        'tool_choice',
        'temperature',
        'top_p',
        'max_tokens',
        'presence_penalty',
        'frequency_penalty',
        'n',
        'user',
        'include_usage',
        'logprobs',
        'top_logprobs',
        'stream_options',
    ]

    for field, value in request.model_dump(exclude_unset=True).items():
        if field not in miss_fields and value is not None:
            payload[field] = value

    return url, headers, payload
|
1262
|
+
|
1263
|
+
async def gpt2claude_tools_json(json_dict):
    """Convert an OpenAI function-tool schema into Anthropic tool format.

    Resolves local JSON-Schema "#/$defs/..." references in-place, removes the
    shared definitions block, and renames "parameters" to "input_schema"
    (substituting an empty object schema when parameters is None).

    The input is deep-copied; the caller's dict is never mutated.
    """
    import copy
    json_dict = copy.deepcopy(json_dict)

    def resolve_refs(obj, defs):
        # Recursively replace {"$ref": "#/$defs/Name"} nodes with their
        # definition, preserving any sibling keys of the $ref.
        if isinstance(obj, dict):
            if "$ref" in obj and obj["$ref"].startswith("#/$defs/"):
                ref_name = obj["$ref"].split("/")[-1]
                if ref_name in defs:
                    ref_obj = copy.deepcopy(defs[ref_name])
                    for k, v in obj.items():
                        if k != "$ref":
                            ref_obj[k] = v
                    return ref_obj

            for key, value in list(obj.items()):
                obj[key] = resolve_refs(value, defs)

        elif isinstance(obj, list):
            for i, item in enumerate(obj):
                obj[i] = resolve_refs(item, defs)

        return obj

    # Extract the shared definitions and drop them from the schema (Claude
    # does not accept them). JSON Schema / Pydantic v2 emit "$defs"; the bare
    # "defs" key is kept as a legacy fallback. Previously only "defs" was
    # checked, so "#/$defs/..." references were never resolved.
    defs = {}
    params = json_dict.get("parameters")
    if isinstance(params, dict):
        for defs_key in ("$defs", "defs"):
            if defs_key in params:
                defs = params.pop(defs_key)
                break

    # Inline all references.
    json_dict = resolve_refs(json_dict, defs)

    # Rename OpenAI's "parameters" key to Anthropic's "input_schema".
    keys_to_change = {
        "parameters": "input_schema",
    }
    for old_key, new_key in keys_to_change.items():
        if old_key in json_dict:
            if new_key:
                if json_dict[old_key] is None:
                    # Anthropic requires a schema object, not null.
                    json_dict[old_key] = {
                        "type": "object",
                        "properties": {}
                    }
                json_dict[new_key] = json_dict.pop(old_key)
            else:
                json_dict.pop(old_key)
    return json_dict
|
1319
|
+
|
1320
|
+
async def get_claude_payload(request, engine, provider, api_key=None):
    """Build (url, headers, payload) for a direct Anthropic Messages API call.

    Converts OpenAI-style messages (including tool calls/results and the
    legacy "function" role) into Anthropic content blocks, merges consecutive
    same-role messages, translates tools/tool_choice, and wires up extended
    thinking for "-think" model aliases or an explicit request.thinking.
    """
    model_dict = get_model_dict(provider)
    original_model = model_dict[request.model]

    # Select the beta feature flag matching the model family.
    if "claude-3-7-sonnet" in original_model:
        anthropic_beta = "output-128k-2025-02-19"
    elif "claude-3-5-sonnet" in original_model:
        anthropic_beta = "max-tokens-3-5-sonnet-2024-07-15"
    else:
        anthropic_beta = "tools-2024-05-16"

    headers = {
        "content-type": "application/json",
        "x-api-key": f"{api_key}",
        "anthropic-version": "2023-06-01",
        "anthropic-beta": anthropic_beta,
    }
    url = provider['base_url']

    messages = []
    system_prompt = None
    tool_id = None
    for msg in request.messages:
        tool_call_id = None
        tool_calls = None
        if isinstance(msg.content, list):
            # Multimodal content: convert each part to the engine's format.
            content = []
            for item in msg.content:
                if item.type == "text":
                    text_message = await get_text_message(item.text, engine)
                    content.append(text_message)
                elif item.type == "image_url" and provider.get("image", True):
                    image_message = await get_image_message(item.image_url.url, engine)
                    content.append(image_message)
        else:
            content = msg.content
            tool_calls = msg.tool_calls
            # Precedence note: parsed as
            # `tool_calls[0].id if tool_calls else (None or tool_id)`,
            # i.e. remember the last-seen tool id across messages.
            tool_id = tool_calls[0].id if tool_calls else None or tool_id
            tool_call_id = msg.tool_call_id

        if tool_calls:
            # Only the first tool call of the message is forwarded.
            tool_calls_list = []
            tool_call = tool_calls[0]
            tool_calls_list.append({
                "type": "tool_use",
                "id": tool_call.id,
                "name": tool_call.function.name,
                "input": json.loads(tool_call.function.arguments),
            })
            messages.append({"role": msg.role, "content": tool_calls_list})
        elif tool_call_id:
            messages.append({"role": "user", "content": [{
                "type": "tool_result",
                "tool_use_id": tool_id,
                "content": content
            }]})
        elif msg.role == "function":
            # Legacy "function" role: fabricate a tool_use/tool_result pair
            # with a fixed placeholder id so the transcript stays consistent.
            messages.append({"role": "assistant", "content": [{
                "type": "tool_use",
                "id": "toolu_017r5miPMV6PGSNKmhvHPic4",
                "name": msg.name,
                "input": {"prompt": "..."}
            }]})
            messages.append({"role": "user", "content": [{
                "type": "tool_result",
                "tool_use_id": "toolu_017r5miPMV6PGSNKmhvHPic4",
                "content": msg.content
            }]})
        elif msg.role != "system":
            messages.append({"role": msg.role, "content": content})
        elif msg.role == "system":
            system_prompt = content

    # Merge consecutive messages that share a role (Anthropic requires
    # alternating user/assistant turns).
    conversation_len = len(messages) - 1
    message_index = 0
    while message_index < conversation_len:
        if messages[message_index]["role"] == messages[message_index + 1]["role"]:
            if messages[message_index].get("content"):
                if isinstance(messages[message_index]["content"], list):
                    messages[message_index]["content"].extend(messages[message_index + 1]["content"])
                elif isinstance(messages[message_index]["content"], str) and isinstance(messages[message_index + 1]["content"], list):
                    # Promote the string to a text part before extending.
                    content_list = [{"type": "text", "text": messages[message_index]["content"]}]
                    content_list.extend(messages[message_index + 1]["content"])
                    messages[message_index]["content"] = content_list
                else:
                    messages[message_index]["content"] += messages[message_index + 1]["content"]
            messages.pop(message_index + 1)
            conversation_len = conversation_len - 1
        else:
            message_index = message_index + 1

    # Per-family default output budget; overridden by request.max_tokens below.
    if "claude-3-7-sonnet" in original_model:
        max_tokens = 20000
    elif "claude-3-5-sonnet" in original_model:
        max_tokens = 8192
    else:
        max_tokens = 4096

    payload = {
        "model": original_model,
        "messages": messages,
        "system": system_prompt or "You are Claude, a large language model trained by Anthropic.",
        "max_tokens": max_tokens,
    }

    if request.max_tokens:
        payload["max_tokens"] = int(request.max_tokens)

    # Request fields that must not be copied verbatim into the payload.
    miss_fields = [
        'model',
        'messages',
        'presence_penalty',
        'frequency_penalty',
        'n',
        'user',
        'include_usage',
        'stream_options',
    ]

    for field, value in request.model_dump(exclude_unset=True).items():
        if field not in miss_fields and value is not None:
            payload[field] = value

    if request.tools and provider.get("tools"):
        tools = []
        for tool in request.tools:
            json_tool = await gpt2claude_tools_json(tool.dict()["function"])
            tools.append(json_tool)
        payload["tools"] = tools
        if "tool_choice" in payload:
            # Translate OpenAI tool_choice shapes into Anthropic's.
            if isinstance(payload["tool_choice"], dict):
                if payload["tool_choice"]["type"] == "function":
                    payload["tool_choice"] = {
                        "type": "tool",
                        "name": payload["tool_choice"]["function"]["name"]
                    }
            if isinstance(payload["tool_choice"], str):
                if payload["tool_choice"] == "auto":
                    payload["tool_choice"] = {
                        "type": "auto"
                    }
                # NOTE(review): "none" maps to Anthropic's "any" (force a tool
                # call) rather than disabling tools — confirm this is intended.
                if payload["tool_choice"] == "none":
                    payload["tool_choice"] = {
                        "type": "any"
                    }

    if provider.get("tools") == False:
        payload.pop("tools", None)
        payload.pop("tool_choice", None)

    # "-think" model aliases enable extended thinking with a default budget of
    # 4096 tokens; a trailing "-<digits>" suffix overrides it when it fits
    # within max_tokens. Thinking requires temperature=1 and no top_p/top_k.
    if "think" in request.model:
        payload["thinking"] = {
            "budget_tokens": 4096,
            "type": "enabled"
        }
        payload["temperature"] = 1
        payload.pop("top_p", None)
        payload.pop("top_k", None)
        if request.model.split("-")[-1].isdigit():
            think_tokens = int(request.model.split("-")[-1])
            if think_tokens < max_tokens:
                payload["thinking"] = {
                    "budget_tokens": think_tokens,
                    "type": "enabled"
                }

    # An explicit thinking config on the request wins over the alias.
    if request.thinking:
        payload["thinking"] = {
            "budget_tokens": request.thinking.budget_tokens,
            "type": request.thinking.type
        }
        payload["temperature"] = 1
        payload.pop("top_p", None)
        payload.pop("top_k", None)

    return url, headers, payload
|
1498
|
+
|
1499
|
+
async def get_dalle_payload(request, engine, provider, api_key=None):
    """Assemble (url, headers, payload) for an image-generation request."""
    deployment = get_model_dict(provider)[request.model]

    headers = {"Content-Type": "application/json"}
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"

    # Resolve the provider base URL to its image-generation endpoint.
    endpoint = BaseAPI(provider['base_url']).image_url

    body = {
        "model": deployment,
        "prompt": request.prompt,
        "n": request.n,
        "response_format": request.response_format,
        "size": request.size,
    }

    return endpoint, headers, body
|
1519
|
+
|
1520
|
+
async def get_whisper_payload(request, engine, provider, api_key=None):
    """Assemble (url, headers, payload) for an audio-transcription request."""
    deployment = get_model_dict(provider)[request.model]

    # No Content-Type here: the HTTP client sets the multipart boundary itself.
    headers = {}
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"

    endpoint = BaseAPI(provider['base_url']).audio_transcriptions

    body = {
        "model": deployment,
        "file": request.file,
    }

    # Copy the optional parameters that were actually provided.
    for attr in ("prompt", "response_format", "temperature", "language"):
        value = getattr(request, attr)
        if value:
            body[attr] = value

    # https://platform.openai.com/docs/api-reference/audio/createTranscription
    if request.timestamp_granularities:
        body["timestamp_granularities[]"] = request.timestamp_granularities

    return endpoint, headers, body
|
1550
|
+
|
1551
|
+
async def get_moderation_payload(request, engine, provider, api_key=None):
    """Build the URL, headers and JSON payload for a moderation request.

    Returns a ``(url, headers, payload)`` tuple targeting the provider's
    moderations endpoint.
    """
    original_model = get_model_dict(provider)[request.model]

    headers = {"Content-Type": "application/json"}
    if api_key:
        headers['Authorization'] = f"Bearer {api_key}"

    url = BaseAPI(provider['base_url']).moderations

    payload = {
        "model": original_model,
        "input": request.input,
    }

    return url, headers, payload
|
1568
|
+
|
1569
|
+
async def get_embedding_payload(request, engine, provider, api_key=None):
    """Build the URL, headers and JSON payload for an embeddings request.

    Returns a ``(url, headers, payload)`` tuple targeting the provider's
    embeddings endpoint. Jina's API names the encoding option differently,
    so the field key is chosen from the resolved endpoint host.
    """
    original_model = get_model_dict(provider)[request.model]

    headers = {"Content-Type": "application/json"}
    if api_key:
        headers['Authorization'] = f"Bearer {api_key}"

    url = BaseAPI(provider['base_url']).embeddings

    payload = {
        "input": request.input,
        "model": original_model,
    }

    if request.encoding_format:
        # Jina expects "embedding_type" where OpenAI uses "encoding_format".
        if url.startswith("https://api.jina.ai"):
            field = "embedding_type"
        else:
            field = "encoding_format"
        payload[field] = request.encoding_format

    return url, headers, payload
|
1592
|
+
|
1593
|
+
async def get_tts_payload(request, engine, provider, api_key=None):
    """Build the URL, headers and JSON payload for a text-to-speech request.

    Returns a ``(url, headers, payload)`` tuple targeting the provider's
    speech endpoint.
    """
    original_model = get_model_dict(provider)[request.model]

    headers = {"Content-Type": "application/json"}
    if api_key:
        headers['Authorization'] = f"Bearer {api_key}"

    url = BaseAPI(provider['base_url']).audio_speech

    payload = {
        "model": original_model,
        "input": request.input,
        "voice": request.voice,
    }

    # Optional knobs are only forwarded when the caller supplied them.
    if request.response_format:
        payload["response_format"] = request.response_format
    if request.speed:
        payload["speed"] = request.speed
    # stream may legitimately be False, so compare against None explicitly.
    if request.stream is not None:
        payload["stream"] = request.stream

    return url, headers, payload
|
1618
|
+
|
1619
|
+
|
1620
|
+
async def get_payload(request: RequestModel, engine, provider, api_key=None):
    """Dispatch to the engine-specific payload builder.

    Returns the ``(url, headers, payload)`` tuple produced by the builder
    registered for *engine*.

    Raises:
        ValueError: if *engine* has no registered builder.
    """
    builders = {
        "gemini": get_gemini_payload,
        "vertex-gemini": get_vertex_gemini_payload,
        "aws": get_aws_payload,
        "vertex-claude": get_vertex_claude_payload,
        "azure": get_azure_payload,
        "claude": get_claude_payload,
        "gpt": get_gpt_payload,
        "openrouter": get_openrouter_payload,
        "cloudflare": get_cloudflare_payload,
        "cohere": get_cohere_payload,
        "dalle": get_dalle_payload,
        "whisper": get_whisper_payload,
        "tts": get_tts_payload,
        "moderation": get_moderation_payload,
        "embedding": get_embedding_payload,
    }

    builder = builders.get(engine)
    if builder is None:
        raise ValueError("Unknown payload")

    if engine == "gpt":
        # The gpt builder expects base_url to already point at the chat endpoint.
        provider['base_url'] = BaseAPI(provider['base_url']).chat_url

    return await builder(request, engine, provider, api_key)
|
1654
|
+
|
1655
|
+
async def prepare_request_payload(provider, request_data):
    """Validate raw request data and build the provider-specific payload.

    Parses *request_data* into a RequestModel, resolves which engine should
    serve the mapped model, and delegates payload construction to
    ``get_payload`` using the provider's API key.

    Returns a ``(url, headers, payload, engine)`` tuple.
    """
    request = RequestModel(**request_data)

    model_dict = get_model_dict(provider)
    original_model = model_dict[request.model]
    engine, _ = get_engine(provider, endpoint=None, original_model=original_model)

    url, headers, payload = await get_payload(request, engine, provider, api_key=provider['api'])

    return url, headers, payload, engine
|