aient 1.0.29__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
aient/core/request.py ADDED
@@ -0,0 +1,1361 @@
1
+ import re
2
+ import json
3
+ import httpx
4
+ import base64
5
+ import urllib.parse
6
+
7
+ from .models import RequestModel
8
+ from .utils import (
9
+ c3s,
10
+ c3o,
11
+ c3h,
12
+ c35s,
13
+ gemini1,
14
+ gemini2,
15
+ BaseAPI,
16
+ safe_get,
17
+ get_engine,
18
+ get_model_dict,
19
+ get_text_message,
20
+ get_image_message,
21
+ )
22
+
23
async def get_gemini_payload(request, engine, provider, api_key=None):
    """Build (url, headers, payload) for a Gemini ``streamGenerateContent`` call.

    Converts the OpenAI-style ``request`` into Google's generativelanguage
    REST format: role mapping (assistant -> model), tool-call round-trips,
    safety settings, ``generationConfig`` and optional Google-search grounding
    for ``*-search`` model aliases.

    Fix over the previous revision: the tool filter tested
    ``name != "googleSearch"`` twice; the redundant duplicate clause is removed
    (behaviour unchanged).
    """
    headers = {
        'Content-Type': 'application/json'
    }
    model_dict = get_model_dict(provider)
    original_model = model_dict[request.model]
    gemini_stream = "streamGenerateContent"
    url = provider['base_url']
    parsed_url = urllib.parse.urlparse(url)
    # The API version is inferred from the configured base_url path.
    if "/v1beta" in parsed_url.path:
        api_version = "v1beta"
    else:
        api_version = "v1"

    # e.g. https://generativelanguage.googleapis.com/v1beta/models/<model>:streamGenerateContent?key=...
    url = f"{parsed_url.scheme}://{parsed_url.netloc}{parsed_url.path.split('/models')[0].rstrip('/')}/models/{original_model}:{gemini_stream}?key={api_key}"

    messages = []
    systemInstruction = None
    function_arguments = None
    for msg in request.messages:
        if msg.role == "assistant":
            msg.role = "model"
        tool_calls = None
        if isinstance(msg.content, list):
            content = []
            for item in msg.content:
                if item.type == "text":
                    text_message = await get_text_message(item.text, engine)
                    content.append(text_message)
                elif item.type == "image_url" and provider.get("image", True):
                    image_message = await get_image_message(item.image_url.url, engine)
                    content.append(image_message)
        else:
            content = [{"text": msg.content}]
            tool_calls = msg.tool_calls

        if tool_calls:
            # Only the first tool call of a message is forwarded to Gemini.
            tool_call = tool_calls[0]
            function_arguments = {
                "functionCall": {
                    "name": tool_call.function.name,
                    "args": json.loads(tool_call.function.arguments)
                }
            }
            messages.append(
                {
                    "role": "model",
                    "parts": [function_arguments]
                }
            )
        elif msg.role == "tool":
            # Pair the tool result with the name of the preceding functionCall
            # (remembered in function_arguments).
            function_call_name = function_arguments["functionCall"]["name"]
            messages.append(
                {
                    "role": "function",
                    "parts": [{
                        "functionResponse": {
                            "name": function_call_name,
                            "response": {
                                "name": function_call_name,
                                "content": {
                                    "result": msg.content,
                                }
                            }
                        }
                    }]
                }
            )
        elif msg.role != "system":
            messages.append({"role": msg.role, "parts": content})
        elif msg.role == "system":
            # Collapse runs of underscores in the system prompt.
            # NOTE(review): assumes the system message yields at least one text
            # part; an image-only system message would raise IndexError here.
            content[0]["text"] = re.sub(r"_+", "_", content[0]["text"])
            systemInstruction = {"parts": content}

    # Newer models accept the stronger "OFF" threshold.
    if "gemini-2.0-flash-exp" in original_model or "gemini-1.5" in original_model:
        safety_settings = "OFF"
    else:
        safety_settings = "BLOCK_NONE"

    payload = {
        # Gemini rejects an empty contents array, so substitute a stub turn.
        "contents": messages or [{"role": "user", "parts": [{"text": "No messages"}]}],
        "safetySettings": [
            {
                "category": "HARM_CATEGORY_HARASSMENT",
                "threshold": safety_settings
            },
            {
                "category": "HARM_CATEGORY_HATE_SPEECH",
                "threshold": safety_settings
            },
            {
                "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
                "threshold": safety_settings
            },
            {
                "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
                "threshold": safety_settings
            }
        ]
    }

    if systemInstruction:
        if api_version == "v1beta":
            payload["systemInstruction"] = systemInstruction
        if api_version == "v1":
            # v1 has no systemInstruction field: prepend the system prompt to
            # the first user message instead.
            first_message = safe_get(payload, "contents", 0, "parts", 0, "text", default=None)
            system_instruction = safe_get(systemInstruction, "parts", 0, "text", default=None)
            if first_message and system_instruction:
                payload["contents"][0]["parts"][0]["text"] = system_instruction + "\n" + first_message

    # OpenAI request fields that have no Gemini equivalent (or are handled
    # explicitly below) and must not be copied verbatim.
    miss_fields = [
        'model',
        'messages',
        'stream',
        'tool_choice',
        'presence_penalty',
        'frequency_penalty',
        'n',
        'user',
        'include_usage',
        'logprobs',
        'top_logprobs',
        'response_format'
    ]
    generation_config = {}

    for field, value in request.model_dump(exclude_unset=True).items():
        if field not in miss_fields and value is not None:
            # The thinking models do not support function calling.
            if field == "tools" and "gemini-2.0-flash-thinking" in original_model:
                continue
            if field == "tools":
                # Rewrite each OpenAI function definition into a Gemini
                # function declaration.
                processed_tools = []
                for tool in value:
                    function_def = tool["function"]
                    # Gemini's schema has no "default"; fold it into the
                    # property description instead.
                    if safe_get(function_def, "parameters", "properties", default=None):
                        for prop_value in function_def["parameters"]["properties"].values():
                            if "default" in prop_value:
                                default_value = prop_value["default"]
                                description = prop_value.get("description", "")
                                prop_value["description"] = f"{description}\nDefault: {default_value}"
                                del prop_value["default"]
                    # googleSearch is a built-in tool, not a function
                    # declaration; it is attached separately below.
                    # (Previously this condition was duplicated verbatim.)
                    if function_def["name"] != "googleSearch":
                        processed_tools.append({"function": function_def})

                if processed_tools:
                    payload.update({
                        "tools": [{
                            "function_declarations": [tool["function"] for tool in processed_tools]
                        }],
                        "tool_config": {
                            "function_calling_config": {
                                "mode": "AUTO"
                            }
                        }
                    })
            elif field == "temperature":
                generation_config["temperature"] = value
            elif field == "max_tokens":
                generation_config["maxOutputTokens"] = value
            elif field == "top_p":
                generation_config["topP"] = value
            else:
                payload[field] = value

    if generation_config:
        payload["generationConfig"] = generation_config
        # Default output budget when the caller did not set max_tokens.
        if "maxOutputTokens" not in generation_config:
            payload["generationConfig"]["maxOutputTokens"] = 8192

    # "-search" model aliases get Google-search grounding attached.
    if request.model.endswith("-search"):
        if "tools" not in payload:
            payload["tools"] = [{
                "googleSearch": {}
            }]
        else:
            payload["tools"].append({
                "googleSearch": {}
            })

    return url, headers, payload
209
+
210
+ import time
211
+ from cryptography.hazmat.primitives import hashes
212
+ from cryptography.hazmat.primitives.asymmetric import padding
213
+ from cryptography.hazmat.primitives.serialization import load_pem_private_key
214
+
215
def create_jwt(client_email, private_key):
    """Build a signed RS256 JWT assertion for Google's OAuth2 token endpoint.

    ``client_email`` becomes the issuer claim; ``private_key`` is the
    service account's PEM-encoded RSA key used to sign the token. The token
    is valid for one hour from creation.
    """
    def _b64url(raw):
        # JWTs use unpadded URL-safe base64.
        return base64.urlsafe_b64encode(raw).rstrip(b'=')

    issued_at = int(time.time())
    header_json = json.dumps({
        "alg": "RS256",
        "typ": "JWT"
    }).encode()
    claims_json = json.dumps({
        "iss": client_email,
        "scope": "https://www.googleapis.com/auth/cloud-platform",
        "aud": "https://oauth2.googleapis.com/token",
        "exp": issued_at + 3600,
        "iat": issued_at
    }).encode()

    # header.payload is what gets signed.
    unsigned = _b64url(header_json) + b'.' + _b64url(claims_json)
    signing_key = load_pem_private_key(private_key.encode(), password=None)
    signature = signing_key.sign(
        unsigned,
        padding.PKCS1v15(),
        hashes.SHA256()
    )

    return (unsigned + b'.' + _b64url(signature)).decode()
249
+
250
def get_access_token(client_email, private_key):
    """Exchange a service-account JWT for a Google OAuth2 access token.

    Raises ``httpx.HTTPStatusError`` when the token endpoint rejects the
    assertion.
    """
    assertion = create_jwt(client_email, private_key)

    form = {
        "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
        "assertion": assertion
    }
    with httpx.Client() as http_client:
        token_response = http_client.post(
            "https://oauth2.googleapis.com/token",
            data=form,
            headers={'Content-Type': "application/x-www-form-urlencoded"}
        )
        token_response.raise_for_status()
        return token_response.json()["access_token"]
264
+
265
async def get_vertex_gemini_payload(request, engine, provider, api_key=None):
    """Build (url, headers, payload) for Gemini on Vertex AI.

    Authenticates with a service-account JWT exchanged for an OAuth2 bearer
    token, picks a regional endpoint via the rotating ``gemini1``/``gemini2``
    location pools, and converts the OpenAI-style request into Vertex's
    generateContent format.
    """
    headers = {
        'Content-Type': 'application/json'
    }
    if provider.get("client_email") and provider.get("private_key"):
        access_token = get_access_token(provider['client_email'], provider['private_key'])
        headers['Authorization'] = f"Bearer {access_token}"
    # NOTE(review): project_id is only bound when the provider config carries
    # "project_id"; the URL .format() below would raise NameError otherwise —
    # presumably validated upstream, confirm.
    if provider.get("project_id"):
        project_id = provider.get("project_id")

    gemini_stream = "streamGenerateContent"
    model_dict = get_model_dict(provider)
    original_model = model_dict[request.model]
    search_tool = None

    # Newer model families live in a different region pool and use the
    # googleSearch built-in; older ones use googleSearchRetrieval.
    if "gemini-2.0" in original_model or "gemini-exp" in original_model:
        location = gemini2
        search_tool = {"googleSearch": {}}
    else:
        location = gemini1
        search_tool = {"googleSearchRetrieval": {}}

    # location.next() rotates through the pool of regions (async iterator).
    url = "https://{LOCATION}-aiplatform.googleapis.com/v1/projects/{PROJECT_ID}/locations/{LOCATION}/publishers/google/models/{MODEL_ID}:{stream}".format(
        LOCATION=await location.next(),
        PROJECT_ID=project_id,
        MODEL_ID=original_model,
        stream=gemini_stream
    )

    messages = []
    systemInstruction = None
    # Remembers the last functionCall so a following "tool" message can be
    # paired with its function name.
    function_arguments = None
    for msg in request.messages:
        if msg.role == "assistant":
            msg.role = "model"
        tool_calls = None
        if isinstance(msg.content, list):
            content = []
            for item in msg.content:
                if item.type == "text":
                    text_message = await get_text_message(item.text, engine)
                    content.append(text_message)
                elif item.type == "image_url" and provider.get("image", True):
                    image_message = await get_image_message(item.image_url.url, engine)
                    content.append(image_message)
        else:
            content = [{"text": msg.content}]
            tool_calls = msg.tool_calls

        if tool_calls:
            # Only the first tool call of a message is forwarded.
            tool_call = tool_calls[0]
            function_arguments = {
                "functionCall": {
                    "name": tool_call.function.name,
                    "args": json.loads(tool_call.function.arguments)
                }
            }
            messages.append(
                {
                    "role": "model",
                    "parts": [function_arguments]
                }
            )
        elif msg.role == "tool":
            function_call_name = function_arguments["functionCall"]["name"]
            messages.append(
                {
                    "role": "function",
                    "parts": [{
                        "functionResponse": {
                            "name": function_call_name,
                            "response": {
                                "name": function_call_name,
                                "content": {
                                    "result": msg.content,
                                }
                            }
                        }
                    }]
                }
            )
        elif msg.role != "system":
            messages.append({"role": msg.role, "parts": content})
        elif msg.role == "system":
            systemInstruction = {"parts": content}


    payload = {
        "contents": messages,
        # safetySettings intentionally disabled for Vertex (kept for reference):
        # "safetySettings": [
        #     {
        #         "category": "HARM_CATEGORY_HARASSMENT",
        #         "threshold": "BLOCK_NONE"
        #     },
        #     {
        #         "category": "HARM_CATEGORY_HATE_SPEECH",
        #         "threshold": "BLOCK_NONE"
        #     },
        #     {
        #         "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        #         "threshold": "BLOCK_NONE"
        #     },
        #     {
        #         "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
        #         "threshold": "BLOCK_NONE"
        #     }
        # ]
    }
    # Vertex uses snake_case "system_instruction" (unlike the v1beta
    # generativelanguage API's "systemInstruction").
    if systemInstruction:
        payload["system_instruction"] = systemInstruction

    # OpenAI request fields with no Vertex equivalent (or handled below).
    miss_fields = [
        'model',
        'messages',
        'stream',
        'tool_choice',
        'presence_penalty',
        'frequency_penalty',
        'n',
        'user',
        'include_usage',
        'logprobs',
        'top_logprobs'
    ]
    generation_config = {}

    for field, value in request.model_dump(exclude_unset=True).items():
        if field not in miss_fields and value is not None:
            if field == "tools":
                payload.update({
                    "tools": [{
                        "function_declarations": [tool["function"] for tool in value]
                    }],
                    "tool_config": {
                        "function_calling_config": {
                            "mode": "AUTO"
                        }
                    }
                })
            elif field == "temperature":
                generation_config["temperature"] = value
            elif field == "max_tokens":
                generation_config["max_output_tokens"] = value
            elif field == "top_p":
                generation_config["top_p"] = value
            else:
                payload[field] = value

    if generation_config:
        payload["generationConfig"] = generation_config
        # Default output budget when the caller did not set max_tokens.
        if "max_output_tokens" not in generation_config:
            payload["generationConfig"]["max_output_tokens"] = 8192

    # "-search" model aliases get the search grounding tool chosen above.
    if request.model.endswith("-search"):
        if "tools" not in payload:
            payload["tools"] = [search_tool]
        else:
            payload["tools"].append(search_tool)

    return url, headers, payload
425
+
426
async def get_vertex_claude_payload(request, engine, provider, api_key=None):
    """Build (url, headers, payload) for Anthropic Claude on Vertex AI.

    Authenticates with a service-account OAuth2 token, picks a rotating
    regional endpoint per Claude model family, converts the OpenAI-style
    conversation into Anthropic's Messages format, and merges consecutive
    same-role messages (Anthropic requires strictly alternating roles).
    """
    headers = {
        'Content-Type': 'application/json',
    }
    if provider.get("client_email") and provider.get("private_key"):
        access_token = get_access_token(provider['client_email'], provider['private_key'])
        headers['Authorization'] = f"Bearer {access_token}"
    # NOTE(review): project_id is only bound when the provider config carries
    # "project_id"; the URL .format() below would raise NameError otherwise.
    if provider.get("project_id"):
        project_id = provider.get("project_id")

    model_dict = get_model_dict(provider)
    original_model = model_dict[request.model]
    # Each Claude family has its own region pool (c35s/c3o/c3s/c3h).
    # NOTE(review): location stays unbound for any other model name —
    # the await below would raise NameError; confirm model names are
    # validated upstream.
    if "claude-3-5-sonnet" in original_model or "claude-3-7-sonnet" in original_model:
        location = c35s
    elif "claude-3-opus" in original_model:
        location = c3o
    elif "claude-3-sonnet" in original_model:
        location = c3s
    elif "claude-3-haiku" in original_model:
        location = c3h

    claude_stream = "streamRawPredict"
    url = "https://{LOCATION}-aiplatform.googleapis.com/v1/projects/{PROJECT_ID}/locations/{LOCATION}/publishers/anthropic/models/{MODEL}:{stream}".format(
        LOCATION=await location.next(),
        PROJECT_ID=project_id,
        MODEL=original_model,
        stream=claude_stream
    )

    messages = []
    system_prompt = None
    # Carries the most recent tool-call id forward so a later tool_result
    # can reference it.
    tool_id = None
    for msg in request.messages:
        tool_call_id = None
        tool_calls = None
        if isinstance(msg.content, list):
            content = []
            for item in msg.content:
                if item.type == "text":
                    text_message = await get_text_message(item.text, engine)
                    content.append(text_message)
                elif item.type == "image_url" and provider.get("image", True):
                    image_message = await get_image_message(item.image_url.url, engine)
                    content.append(image_message)
        else:
            content = msg.content
            tool_calls = msg.tool_calls
            # Keep the previous tool_id when this message has no tool calls.
            tool_id = tool_calls[0].id if tool_calls else None or tool_id
            tool_call_id = msg.tool_call_id

        if tool_calls:
            # Only the first tool call of a message is forwarded.
            tool_calls_list = []
            tool_call = tool_calls[0]
            tool_calls_list.append({
                "type": "tool_use",
                "id": tool_call.id,
                "name": tool_call.function.name,
                "input": json.loads(tool_call.function.arguments),
            })
            messages.append({"role": msg.role, "content": tool_calls_list})
        elif tool_call_id:
            # OpenAI "tool" message -> Anthropic tool_result (user role).
            messages.append({"role": "user", "content": [{
                "type": "tool_result",
                "tool_use_id": tool_id,
                "content": content
            }]})
        elif msg.role == "function":
            # Legacy function-role messages: synthesize a matching
            # tool_use/tool_result pair.
            # NOTE(review): the tool_use id here is a hard-coded placeholder.
            messages.append({"role": "assistant", "content": [{
                "type": "tool_use",
                "id": "toolu_017r5miPMV6PGSNKmhvHPic4",
                "name": msg.name,
                "input": {"prompt": "..."}
            }]})
            messages.append({"role": "user", "content": [{
                "type": "tool_result",
                "tool_use_id": "toolu_017r5miPMV6PGSNKmhvHPic4",
                "content": msg.content
            }]})
        elif msg.role != "system":
            messages.append({"role": msg.role, "content": content})
        elif msg.role == "system":
            system_prompt = content

    # Merge consecutive messages that share a role: Anthropic requires
    # alternating user/assistant turns. The list is edited in place, so the
    # index only advances when no merge happened.
    conversation_len = len(messages) - 1
    message_index = 0
    while message_index < conversation_len:
        if messages[message_index]["role"] == messages[message_index + 1]["role"]:
            if messages[message_index].get("content"):
                if isinstance(messages[message_index]["content"], list):
                    messages[message_index]["content"].extend(messages[message_index + 1]["content"])
                elif isinstance(messages[message_index]["content"], str) and isinstance(messages[message_index + 1]["content"], list):
                    # str + list: promote the string to a text block first.
                    content_list = [{"type": "text", "text": messages[message_index]["content"]}]
                    content_list.extend(messages[message_index + 1]["content"])
                    messages[message_index]["content"] = content_list
                else:
                    messages[message_index]["content"] += messages[message_index + 1]["content"]
            messages.pop(message_index + 1)
            conversation_len = conversation_len - 1
        else:
            message_index = message_index + 1

    # Per-family default output budgets (overridden by request.max_tokens).
    if "claude-3-7-sonnet" in original_model:
        max_tokens = 20000
    elif "claude-3-5-sonnet" in original_model:
        max_tokens = 8192
    else:
        max_tokens = 4096

    payload = {
        "anthropic_version": "vertex-2023-10-16",
        "messages": messages,
        "system": system_prompt or "You are Claude, a large language model trained by Anthropic.",
        "max_tokens": max_tokens,
    }

    if request.max_tokens:
        payload["max_tokens"] = int(request.max_tokens)

    # OpenAI request fields that must not be copied verbatim.
    miss_fields = [
        'model',
        'messages',
        'presence_penalty',
        'frequency_penalty',
        'n',
        'user',
        'include_usage',
    ]

    for field, value in request.model_dump(exclude_unset=True).items():
        if field not in miss_fields and value is not None:
            payload[field] = value

    if request.tools and provider.get("tools"):
        tools = []
        for tool in request.tools:
            # Convert each OpenAI function schema to Anthropic's tool schema.
            json_tool = await gpt2claude_tools_json(tool.dict()["function"])
            tools.append(json_tool)
        payload["tools"] = tools
        # Translate OpenAI tool_choice values to Anthropic's variants.
        if "tool_choice" in payload:
            if isinstance(payload["tool_choice"], dict):
                if payload["tool_choice"]["type"] == "function":
                    payload["tool_choice"] = {
                        "type": "tool",
                        "name": payload["tool_choice"]["function"]["name"]
                    }
            if isinstance(payload["tool_choice"], str):
                if payload["tool_choice"] == "auto":
                    payload["tool_choice"] = {
                        "type": "auto"
                    }
                if payload["tool_choice"] == "none":
                    payload["tool_choice"] = {
                        "type": "any"
                    }

    # Providers can disable tools entirely.
    if provider.get("tools") == False:
        payload.pop("tools", None)
        payload.pop("tool_choice", None)

    return url, headers, payload
586
+
587
async def get_gpt_payload(request, engine, provider, api_key=None):
    """Build (url, headers, payload) for an OpenAI-compatible chat endpoint.

    Applies a number of model-specific fixups: o1/o3 token-parameter renaming
    and temperature stripping, o1-mini image/system-message restrictions,
    o3-mini reasoning-effort aliases, deepseek-r default temperature, and
    googleSearch tool injection for gemini "-search" aliases.
    """
    headers = {
        'Content-Type': 'application/json',
    }
    model_dict = get_model_dict(provider)
    original_model = model_dict[request.model]
    if api_key:
        headers['Authorization'] = f"Bearer {api_key}"

    url = provider['base_url']

    messages = []
    for msg in request.messages:
        tool_calls = None
        tool_call_id = None
        if isinstance(msg.content, list):
            content = []
            for item in msg.content:
                if item.type == "text":
                    text_message = await get_text_message(item.text, engine)
                    content.append(text_message)
                # o1-mini does not accept image input.
                elif item.type == "image_url" and provider.get("image", True) and "o1-mini" not in original_model:
                    image_message = await get_image_message(item.image_url.url, engine)
                    content.append(image_message)
        else:
            content = msg.content
            # o3-mini disables markdown unless the system prompt opts back in.
            # NOTE(review): assumes string system content here (non-list branch).
            if msg.role == "system" and "o3-mini" in original_model and not content.startswith("Formatting re-enabled"):
                content = "Formatting re-enabled. " + content
            tool_calls = msg.tool_calls
            tool_call_id = msg.tool_call_id

        if tool_calls:
            tool_calls_list = []
            for tool_call in tool_calls:
                tool_calls_list.append({
                    "id": tool_call.id,
                    "type": tool_call.type,
                    "function": {
                        "name": tool_call.function.name,
                        "arguments": tool_call.function.arguments
                    }
                })
            # Tool messages are dropped entirely for providers without tools.
            if provider.get("tools"):
                messages.append({"role": msg.role, "tool_calls": tool_calls_list})
        elif tool_call_id:
            if provider.get("tools"):
                messages.append({"role": msg.role, "tool_call_id": tool_call_id, "content": content})
        else:
            messages.append({"role": msg.role, "content": content})

    # o1-mini/o1-preview reject system messages: fold the system prompt into
    # the first remaining message.
    if ("o1-mini" in original_model or "o1-preview" in original_model) and len(messages) > 1 and messages[0]["role"] == "system":
        system_msg = messages.pop(0)
        messages[0]["content"] = system_msg["content"] + messages[0]["content"]

    payload = {
        "model": original_model,
        "messages": messages,
    }

    miss_fields = [
        'model',
        'messages',
    ]

    for field, value in request.model_dump(exclude_unset=True).items():
        if field not in miss_fields and value is not None:
            # Reasoning models take max_completion_tokens instead of max_tokens.
            if field == "max_tokens" and ("o1" in original_model or "o3" in original_model):
                payload["max_completion_tokens"] = value
            else:
                payload[field] = value

    # Models/providers that reject tool definitions.
    if provider.get("tools") == False or "o1-mini" in original_model or "chatgpt-4o-latest" in original_model or "grok" in original_model:
        payload.pop("tools", None)
        payload.pop("tool_choice", None)
    # Azure's model-inference endpoint does not support streaming options.
    if "models.inference.ai.azure.com" in url:
        payload["stream"] = False
        # request.stream = False
        payload.pop("stream_options", None)

    # Map "-high"/"-low" model aliases to o3-mini's reasoning_effort.
    if "o3-mini" in original_model:
        if request.model.endswith("high"):
            payload["reasoning_effort"] = "high"
        elif request.model.endswith("low"):
            payload["reasoning_effort"] = "low"
        else:
            payload["reasoning_effort"] = "medium"

    # Reasoning models do not accept temperature.
    if "o3-mini" in original_model or "o1" in original_model:
        if "temperature" in payload:
            payload.pop("temperature")

    # DeepSeek-R models recommend temperature 0.6 by default.
    if "deepseek-r" in original_model.lower():
        if "temperature" not in payload:
            payload["temperature"] = 0.6

    # gemini "-search" aliases served through an OpenAI-compatible endpoint
    # get a googleSearch pseudo-function appended (once).
    if request.model.endswith("-search") and "gemini" in original_model:
        if "tools" not in payload:
            payload["tools"] = [{
                "type": "function",
                "function": {
                    "name": "googleSearch",
                    "description": "googleSearch"
                }
            }]
        else:
            if not any(tool["function"]["name"] == "googleSearch" for tool in payload["tools"]):
                payload["tools"].append({
                    "type": "function",
                    "function": {
                        "name": "googleSearch",
                        "description": "googleSearch"
                    }
                })

    return url, headers, payload
702
+
703
def build_azure_endpoint(base_url, deployment_id, api_version="2024-10-21"):
    """Return the Azure OpenAI chat-completions URL for a deployment.

    Leaves *base_url* as-is when it already targets a
    ``models/chat/completions`` route, and appends the ``api-version`` query
    parameter only when the URL does not already carry one.
    """
    endpoint = base_url.rstrip('/')

    if "models/chat/completions" not in endpoint:
        deployment_path = f"/openai/deployments/{deployment_id}/chat/completions"
        # NOTE: urljoin with an absolute path replaces any path segment
        # already present on base_url.
        endpoint = urllib.parse.urljoin(endpoint, deployment_path)

    if "?api-version=" not in endpoint:
        endpoint = f"{endpoint}?api-version={api_version}"

    return endpoint
719
+
720
async def get_azure_payload(request, engine, provider, api_key=None):
    """Build (url, headers, payload) for an Azure OpenAI deployment.

    Azure authenticates with an ``api-key`` header (not a Bearer token) and
    addresses models by deployment id via build_azure_endpoint(). Applies the
    same o1/tooling fixups as the generic OpenAI path.
    """
    headers = {
        'Content-Type': 'application/json',
    }
    model_dict = get_model_dict(provider)
    original_model = model_dict[request.model]
    headers['api-key'] = f"{api_key}"

    url = build_azure_endpoint(
        base_url=provider['base_url'],
        deployment_id=original_model,
    )

    messages = []
    for msg in request.messages:
        tool_calls = None
        tool_call_id = None
        if isinstance(msg.content, list):
            content = []
            for item in msg.content:
                if item.type == "text":
                    text_message = await get_text_message(item.text, engine)
                    content.append(text_message)
                # o1-mini does not accept image input.
                elif item.type == "image_url" and provider.get("image", True) and "o1-mini" not in original_model:
                    image_message = await get_image_message(item.image_url.url, engine)
                    content.append(image_message)
        else:
            content = msg.content
            tool_calls = msg.tool_calls
            tool_call_id = msg.tool_call_id

        if tool_calls:
            tool_calls_list = []
            for tool_call in tool_calls:
                tool_calls_list.append({
                    "id": tool_call.id,
                    "type": tool_call.type,
                    "function": {
                        "name": tool_call.function.name,
                        "arguments": tool_call.function.arguments
                    }
                })
            # Tool messages are dropped entirely for providers without tools.
            if provider.get("tools"):
                messages.append({"role": msg.role, "tool_calls": tool_calls_list})
        elif tool_call_id:
            if provider.get("tools"):
                messages.append({"role": msg.role, "tool_call_id": tool_call_id, "content": content})
        else:
            messages.append({"role": msg.role, "content": content})

    payload = {
        "model": original_model,
        "messages": messages,
    }

    miss_fields = [
        'model',
        'messages',
    ]

    for field, value in request.model_dump(exclude_unset=True).items():
        if field not in miss_fields and value is not None:
            # o1 models take max_completion_tokens instead of max_tokens.
            if field == "max_tokens" and "o1" in original_model:
                payload["max_completion_tokens"] = value
            else:
                payload[field] = value

    # Models/providers that reject tool definitions.
    if provider.get("tools") == False or "o1" in original_model or "chatgpt-4o-latest" in original_model or "grok" in original_model:
        payload.pop("tools", None)
        payload.pop("tool_choice", None)

    return url, headers, payload
792
+
793
async def get_openrouter_payload(request, engine, provider, api_key=None):
    """Build (url, headers, payload) for an OpenRouter chat request.

    Multimodal content lists are flattened into one message per part, since
    OpenRouter expects plain text content (images are wrapped in a
    single-element content list).
    """
    headers = {'Content-Type': 'application/json'}
    if api_key:
        headers['Authorization'] = f"Bearer {api_key}"

    upstream_model = get_model_dict(provider)[request.model]
    url = provider['base_url']

    messages = []
    for message in request.messages:
        author_name = None
        if isinstance(message.content, list):
            parts = []
            for part in message.content:
                if part.type == "text":
                    parts.append(await get_text_message(part.text, engine))
                elif part.type == "image_url" and provider.get("image", True):
                    parts.append(await get_image_message(part.image_url.url, engine))
        else:
            parts = message.content
            author_name = message.name

        if author_name:
            messages.append({"role": message.role, "name": author_name, "content": parts})
        elif isinstance(parts, list):
            # Flatten multimodal content: one outgoing message per part.
            for part in parts:
                if part["type"] == "text":
                    messages.append({"role": message.role, "content": part["text"]})
                elif part["type"] == "image_url":
                    messages.append({"role": message.role, "content": [await get_image_message(part["image_url"]["url"], engine)]})
        else:
            messages.append({"role": message.role, "content": parts})

    payload = {
        "model": upstream_model,
        "messages": messages,
    }

    # Request fields not forwarded to OpenRouter.
    excluded_fields = (
        'model',
        'messages',
        'n',
        'user',
        'include_usage',
    )
    for field, value in request.model_dump(exclude_unset=True).items():
        if field not in excluded_fields and value is not None:
            payload[field] = value

    return url, headers, payload
850
+
851
async def get_cohere_payload(request, engine, provider, api_key=None):
    """Build (url, headers, payload) for Cohere's chat API.

    The last conversation turn becomes the ``message`` query; all earlier
    turns are sent as ``chat_history`` with roles mapped to Cohere's
    USER/CHATBOT/SYSTEM vocabulary. Non-text parts are ignored.
    """
    headers = {'Content-Type': 'application/json'}
    if api_key:
        headers['Authorization'] = f"Bearer {api_key}"

    upstream_model = get_model_dict(provider)[request.model]
    url = provider['base_url']

    cohere_roles = {
        "user": "USER",
        "assistant": "CHATBOT",
        "system": "SYSTEM",
    }

    turns = []
    for message in request.messages:
        if isinstance(message.content, list):
            parts = []
            for part in message.content:
                if part.type == "text":
                    parts.append(await get_text_message(part.text, engine))
        else:
            parts = message.content

        if isinstance(parts, list):
            for part in parts:
                if part["type"] == "text":
                    turns.append({"role": cohere_roles[message.role], "message": part["text"]})
        else:
            turns.append({"role": cohere_roles[message.role], "message": parts})

    # Cohere takes the newest turn as the query, the rest as history.
    chat_history = turns[:-1]
    payload = {
        "model": upstream_model,
        "message": turns[-1].get("message"),
    }
    if chat_history:
        payload["chat_history"] = chat_history

    # Request fields Cohere's chat endpoint does not accept.
    excluded_fields = (
        'model',
        'messages',
        'tools',
        'tool_choice',
        'temperature',
        'top_p',
        'max_tokens',
        'presence_penalty',
        'frequency_penalty',
        'n',
        'user',
        'include_usage',
        'logprobs',
        'top_logprobs',
    )
    for field, value in request.model_dump(exclude_unset=True).items():
        if field not in excluded_fields and value is not None:
            payload[field] = value

    return url, headers, payload
918
+
919
async def get_cloudflare_payload(request, engine, provider, api_key=None):
    """Build (url, headers, payload) for Cloudflare Workers AI.

    Only the last request message is forwarded, as a plain-text ``prompt``;
    when that message is multimodal, the final text part wins.
    """
    headers = {'Content-Type': 'application/json'}
    if api_key:
        headers['Authorization'] = f"Bearer {api_key}"

    upstream_model = get_model_dict(provider)[request.model]
    url = (
        "https://api.cloudflare.com/client/v4/accounts/"
        f"{provider['cf_account_id']}/ai/run/{upstream_model}"
    )

    last_message = request.messages[-1]
    prompt = None
    if isinstance(last_message.content, list):
        for part in last_message.content:
            if part.type == "text":
                prompt = await get_text_message(part.text, engine)
    else:
        prompt = last_message.content

    payload = {
        "prompt": prompt,
    }

    # Request fields Workers AI does not accept.
    excluded_fields = (
        'model',
        'messages',
        'tools',
        'tool_choice',
        'temperature',
        'top_p',
        'max_tokens',
        'presence_penalty',
        'frequency_penalty',
        'n',
        'user',
        'include_usage',
        'logprobs',
        'top_logprobs',
    )
    for field, value in request.model_dump(exclude_unset=True).items():
        if field not in excluded_fields and value is not None:
            payload[field] = value

    return url, headers, payload
965
+
966
async def gpt2claude_tools_json(json_dict):
    """Convert an OpenAI-style function tool definition to Anthropic's tool schema.

    Takes the ``function`` part of an OpenAI tool (``name``/``description``/
    ``parameters``), inlines any JSON-schema ``$ref`` references that point
    into the schema's definitions block, and renames ``parameters`` to
    ``input_schema`` as Claude expects.

    Args:
        json_dict: The tool's function definition. Not mutated; a deep copy
            is transformed and returned.

    Returns:
        A new dict in Anthropic tool format.
    """
    import copy
    json_dict = copy.deepcopy(json_dict)

    def resolve_refs(obj, defs):
        # Recursively replace {"$ref": "#/$defs/Name", ...} nodes with the
        # referenced definition, keeping any sibling keys of the $ref node.
        if isinstance(obj, dict):
            if "$ref" in obj and obj["$ref"].startswith("#/$defs/"):
                ref_name = obj["$ref"].split("/")[-1]
                if ref_name in defs:
                    ref_obj = copy.deepcopy(defs[ref_name])
                    # Sibling properties on the $ref node override the
                    # referenced definition's fields.
                    for k, v in obj.items():
                        if k != "$ref":
                            ref_obj[k] = v
                    return ref_obj

            # Recurse into every value (list() so we can reassign in place).
            for key, value in list(obj.items()):
                obj[key] = resolve_refs(value, defs)

        elif isinstance(obj, list):
            for i, item in enumerate(obj):
                obj[i] = resolve_refs(item, defs)

        return obj

    # Extract and strip the schema definitions block. Pydantic's
    # model_json_schema() emits "$defs" (which the "#/$defs/" refs above
    # point at); some schema generators strip the "$" and use "defs".
    # Accept both — Claude needs neither once refs are inlined.
    # Guard with isinstance: "parameters" may legitimately be None
    # (handled further below), and `"defs" in None` would raise TypeError.
    defs = {}
    parameters = json_dict.get("parameters")
    if isinstance(parameters, dict):
        for defs_key in ("$defs", "defs"):
            if defs_key in parameters:
                defs.update(parameters.pop(defs_key))

    # Inline all references now that the definitions are collected.
    json_dict = resolve_refs(json_dict, defs)

    # Rename OpenAI keys to their Anthropic equivalents.
    keys_to_change = {
        "parameters": "input_schema",
    }
    for old_key, new_key in keys_to_change.items():
        if old_key in json_dict:
            if new_key:
                if json_dict[old_key] is None:
                    # Claude requires a schema object even for no-arg tools.
                    json_dict[old_key] = {
                        "type": "object",
                        "properties": {}
                    }
                json_dict[new_key] = json_dict.pop(old_key)
            else:
                json_dict.pop(old_key)
    return json_dict
1022
+
1023
async def get_claude_payload(request, engine, provider, api_key=None):
    """Build (url, headers, payload) for an Anthropic Messages API request.

    Translates an OpenAI-style chat request (roles, tool_calls, tool results,
    multimodal content parts) into Claude's message format, merges adjacent
    same-role messages, maps tool definitions/tool_choice, and configures
    extended-thinking parameters for "think" model variants.
    """
    model_dict = get_model_dict(provider)
    original_model = model_dict[request.model]

    # Pick the beta feature flag matching the target model family.
    if "claude-3-7-sonnet" in original_model:
        anthropic_beta = "output-128k-2025-02-19"
    elif "claude-3-5-sonnet" in original_model:
        anthropic_beta = "max-tokens-3-5-sonnet-2024-07-15"
    else:
        anthropic_beta = "tools-2024-05-16"

    headers = {
        "content-type": "application/json",
        "x-api-key": f"{api_key}",
        "anthropic-version": "2023-06-01",
        "anthropic-beta": anthropic_beta,
    }
    url = provider['base_url']

    messages = []
    system_prompt = None
    # tool_id carries the most recent assistant tool_call id forward so a
    # later tool-result message can reference it.
    tool_id = None
    for msg in request.messages:
        tool_call_id = None
        tool_calls = None
        if isinstance(msg.content, list):
            # Multimodal content: convert each part to Claude's block format.
            content = []
            for item in msg.content:
                if item.type == "text":
                    text_message = await get_text_message(item.text, engine)
                    content.append(text_message)
                elif item.type == "image_url" and provider.get("image", True):
                    image_message = await get_image_message(item.image_url.url, engine)
                    content.append(image_message)
        else:
            content = msg.content
            tool_calls = msg.tool_calls
            # NOTE: due to conditional-expression precedence this parses as
            # `tool_calls[0].id if tool_calls else (None or tool_id)`, i.e.
            # it keeps the previous tool_id when this message has no calls.
            tool_id = tool_calls[0].id if tool_calls else None or tool_id
            tool_call_id = msg.tool_call_id

        if tool_calls:
            # Assistant message invoking a tool -> Claude tool_use block.
            # Only the first tool call is forwarded.
            tool_calls_list = []
            tool_call = tool_calls[0]
            tool_calls_list.append({
                "type": "tool_use",
                "id": tool_call.id,
                "name": tool_call.function.name,
                "input": json.loads(tool_call.function.arguments),
            })
            messages.append({"role": msg.role, "content": tool_calls_list})
        elif tool_call_id:
            # Tool result message -> Claude tool_result block, linked via the
            # carried-over tool_id.
            messages.append({"role": "user", "content": [{
                "type": "tool_result",
                "tool_use_id": tool_id,
                "content": content
            }]})
        elif msg.role == "function":
            # Legacy "function" role: synthesize a matched tool_use /
            # tool_result pair with a fixed placeholder id.
            messages.append({"role": "assistant", "content": [{
                "type": "tool_use",
                "id": "toolu_017r5miPMV6PGSNKmhvHPic4",
                "name": msg.name,
                "input": {"prompt": "..."}
            }]})
            messages.append({"role": "user", "content": [{
                "type": "tool_result",
                "tool_use_id": "toolu_017r5miPMV6PGSNKmhvHPic4",
                "content": msg.content
            }]})
        elif msg.role != "system":
            messages.append({"role": msg.role, "content": content})
        elif msg.role == "system":
            # Claude takes the system prompt as a top-level field, not a message.
            system_prompt = content

    # Claude rejects consecutive messages with the same role: merge each
    # run of same-role neighbors into a single message, promoting plain
    # strings to text blocks when mixing with list content.
    conversation_len = len(messages) - 1
    message_index = 0
    while message_index < conversation_len:
        if messages[message_index]["role"] == messages[message_index + 1]["role"]:
            if messages[message_index].get("content"):
                if isinstance(messages[message_index]["content"], list):
                    messages[message_index]["content"].extend(messages[message_index + 1]["content"])
                elif isinstance(messages[message_index]["content"], str) and isinstance(messages[message_index + 1]["content"], list):
                    content_list = [{"type": "text", "text": messages[message_index]["content"]}]
                    content_list.extend(messages[message_index + 1]["content"])
                    messages[message_index]["content"] = content_list
                else:
                    # str + str concatenation.
                    messages[message_index]["content"] += messages[message_index + 1]["content"]
            messages.pop(message_index + 1)
            conversation_len = conversation_len - 1
        else:
            message_index = message_index + 1

    # Default max_tokens per model family (overridden by the request below).
    if "claude-3-7-sonnet" in original_model:
        max_tokens = 20000
    elif "claude-3-5-sonnet" in original_model:
        max_tokens = 8192
    else:
        max_tokens = 4096

    payload = {
        "model": original_model,
        "messages": messages,
        "system": system_prompt or "You are Claude, a large language model trained by Anthropic.",
        "max_tokens": max_tokens,
    }

    if request.max_tokens:
        payload["max_tokens"] = int(request.max_tokens)

    # Request fields never forwarded to the Anthropic API.
    miss_fields = [
        'model',
        'messages',
        'presence_penalty',
        'frequency_penalty',
        'n',
        'user',
        'include_usage',
    ]

    # Pass through every explicitly-set, non-None remaining field as-is.
    for field, value in request.model_dump(exclude_unset=True).items():
        if field not in miss_fields and value is not None:
            payload[field] = value

    if request.tools and provider.get("tools"):
        # Convert OpenAI tool definitions to Claude's input_schema format.
        tools = []
        for tool in request.tools:
            # print("tool", type(tool), tool)
            json_tool = await gpt2claude_tools_json(tool.dict()["function"])
            tools.append(json_tool)
        payload["tools"] = tools
        if "tool_choice" in payload:
            # Map OpenAI tool_choice forms onto Claude's equivalents.
            if isinstance(payload["tool_choice"], dict):
                if payload["tool_choice"]["type"] == "function":
                    payload["tool_choice"] = {
                        "type": "tool",
                        "name": payload["tool_choice"]["function"]["name"]
                    }
            if isinstance(payload["tool_choice"], str):
                if payload["tool_choice"] == "auto":
                    payload["tool_choice"] = {
                        "type": "auto"
                    }
                # NOTE(review): "none" is mapped to Claude's "any" (which
                # *forces* a tool call) — looks intentional here but the
                # semantics differ from OpenAI's "none"; confirm.
                if payload["tool_choice"] == "none":
                    payload["tool_choice"] = {
                        "type": "any"
                    }

    if provider.get("tools") == False:
        payload.pop("tools", None)
        payload.pop("tool_choice", None)

    if "think" in request.model:
        # "think" model alias: enable extended thinking. Anthropic requires
        # temperature == 1 and no top_p/top_k alongside thinking.
        payload["thinking"] = {
            "budget_tokens": 4096,
            "type": "enabled"
        }
        payload["temperature"] = 1
        payload.pop("top_p", None)
        payload.pop("top_k", None)
        # A trailing numeric suffix (e.g. "...-think-2048") sets the budget,
        # provided it fits under the model-family max_tokens default.
        if request.model.split("-")[-1].isdigit():
            think_tokens = int(request.model.split("-")[-1])
            if think_tokens < max_tokens:
                payload["thinking"] = {
                    "budget_tokens": think_tokens,
                    "type": "enabled"
                }

    if request.thinking:
        # Explicit thinking config on the request wins over the alias logic.
        payload["thinking"] = {
            "budget_tokens": request.thinking.budget_tokens,
            "type": request.thinking.type
        }
        payload["temperature"] = 1
        payload.pop("top_p", None)
        payload.pop("top_k", None)
    # print("payload", json.dumps(payload, indent=2, ensure_ascii=False))

    return url, headers, payload
1200
+
1201
async def get_dalle_payload(request, engine, provider, api_key=None):
    """Build (url, headers, payload) for an image-generation request."""
    target_model = get_model_dict(provider)[request.model]

    headers = {
        "Content-Type": "application/json",
    }
    if api_key:
        headers['Authorization'] = f"Bearer {api_key}"

    url = BaseAPI(provider['base_url']).image_url

    payload = {
        "model": target_model,
        "prompt": request.prompt,
        "n": request.n,
        "response_format": request.response_format,
        "size": request.size,
    }

    return url, headers, payload
1221
+
1222
async def get_whisper_payload(request, engine, provider, api_key=None):
    """Build (url, headers, payload) for an audio-transcription request."""
    target_model = get_model_dict(provider)[request.model]

    # No explicit Content-Type: the HTTP client sets the multipart/form-data
    # boundary itself.
    headers = {}
    if api_key:
        headers['Authorization'] = f"Bearer {api_key}"

    url = BaseAPI(provider['base_url']).audio_transcriptions

    payload = {
        "model": target_model,
        "file": request.file,
    }

    # Forward optional fields only when they carry a truthy value.
    for optional in ("prompt", "response_format", "temperature", "language"):
        value = getattr(request, optional)
        if value:
            payload[optional] = value

    return url, headers, payload
1248
+
1249
async def get_moderation_payload(request, engine, provider, api_key=None):
    """Build (url, headers, payload) for a content-moderation request."""
    target_model = get_model_dict(provider)[request.model]

    headers = {
        "Content-Type": "application/json",
    }
    if api_key:
        headers['Authorization'] = f"Bearer {api_key}"

    url = BaseAPI(provider['base_url']).moderations

    payload = {
        "model": target_model,
        "input": request.input,
    }

    return url, headers, payload
1266
+
1267
async def get_embedding_payload(request, engine, provider, api_key=None):
    """Build (url, headers, payload) for an embeddings request."""
    target_model = get_model_dict(provider)[request.model]

    headers = {
        "Content-Type": "application/json",
    }
    if api_key:
        headers['Authorization'] = f"Bearer {api_key}"

    url = BaseAPI(provider['base_url']).embeddings

    payload = {
        "input": request.input,
        "model": target_model,
    }

    if request.encoding_format:
        # Jina's API names this field "embedding_type" rather than the
        # OpenAI-style "encoding_format".
        if url.startswith("https://api.jina.ai"):
            field_name = "embedding_type"
        else:
            field_name = "encoding_format"
        payload[field_name] = request.encoding_format

    return url, headers, payload
1290
+
1291
async def get_tts_payload(request, engine, provider, api_key=None):
    """Build (url, headers, payload) for a text-to-speech request."""
    target_model = get_model_dict(provider)[request.model]

    headers = {
        "Content-Type": "application/json",
    }
    if api_key:
        headers['Authorization'] = f"Bearer {api_key}"

    url = BaseAPI(provider['base_url']).audio_speech

    payload = {
        "model": target_model,
        "input": request.input,
        "voice": request.voice,
    }

    for optional in ("response_format", "speed"):
        value = getattr(request, optional)
        if value:
            payload[optional] = value
    # stream may legitimately be False, so compare against None explicitly.
    if request.stream is not None:
        payload["stream"] = request.stream

    return url, headers, payload
1316
+
1317
+
1318
async def get_payload(request: RequestModel, engine, provider, api_key=None):
    """Dispatch to the engine-specific payload builder.

    Returns the (url, headers, payload) triple produced by the builder for
    the given engine; raises ValueError for an unrecognized engine.
    """
    if engine == "gpt":
        # The generic OpenAI-style engine posts to the chat-completions URL.
        provider['base_url'] = BaseAPI(provider['base_url']).chat_url

    builders = {
        "gemini": get_gemini_payload,
        "vertex-gemini": get_vertex_gemini_payload,
        "vertex-claude": get_vertex_claude_payload,
        "azure": get_azure_payload,
        "claude": get_claude_payload,
        "gpt": get_gpt_payload,
        "openrouter": get_openrouter_payload,
        "cloudflare": get_cloudflare_payload,
        "cohere": get_cohere_payload,
        "dalle": get_dalle_payload,
        "whisper": get_whisper_payload,
        "tts": get_tts_payload,
        "moderation": get_moderation_payload,
        "embedding": get_embedding_payload,
    }

    builder = builders.get(engine)
    if builder is None:
        raise ValueError("Unknown payload")
    return await builder(request, engine, provider, api_key)
1350
+
1351
async def prepare_request_payload(provider, request_data):
    """Validate raw request data and build the outbound request.

    Parses request_data into a RequestModel, resolves the provider-specific
    engine for the requested model, and returns (url, headers, payload, engine).
    """
    request = RequestModel(**request_data)

    # Map the public model name to the provider's actual model identifier,
    # then decide which backend engine will serve it.
    target_model = get_model_dict(provider)[request.model]
    engine, _ = get_engine(provider, endpoint=None, original_model=target_model)

    url, headers, payload = await get_payload(request, engine, provider, api_key=provider['api'])

    return url, headers, payload, engine