aient: aient-1.1.35-py3-none-any.whl → aient-1.1.36-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aient/core/models.py +1 -0
- aient/core/request.py +126 -2
- aient/core/response.py +33 -5
- aient/core/utils.py +7 -5
- {aient-1.1.35.dist-info → aient-1.1.36.dist-info}/METADATA +1 -1
- {aient-1.1.35.dist-info → aient-1.1.36.dist-info}/RECORD +9 -9
- {aient-1.1.35.dist-info → aient-1.1.36.dist-info}/WHEEL +0 -0
- {aient-1.1.35.dist-info → aient-1.1.36.dist-info}/licenses/LICENSE +0 -0
- {aient-1.1.35.dist-info → aient-1.1.36.dist-info}/top_level.txt +0 -0
aient/core/models.py
CHANGED
@@ -105,6 +105,7 @@ class RequestModel(BaseRequest):
|
|
105
105
|
response_format: Optional[ResponseFormat] = None
|
106
106
|
thinking: Optional[Thinking] = None
|
107
107
|
stream_options: Optional[StreamOptions] = None
|
108
|
+
chat_template_kwargs: Optional[Dict[str, Any]] = None
|
108
109
|
|
109
110
|
def get_last_text_message(self) -> Optional[str]:
|
110
111
|
for message in reversed(self.messages):
|
aient/core/request.py
CHANGED
@@ -341,7 +341,8 @@ async def get_vertex_gemini_payload(request, engine, provider, api_key=None):
|
|
341
341
|
else:
|
342
342
|
location = gemini1
|
343
343
|
|
344
|
-
if "gemini-2.5-
|
344
|
+
if "gemini-2.5-flash-lite-preview-06-17" == original_model or \
|
345
|
+
"gemini-2.5-pro-preview-06-05" == original_model:
|
345
346
|
location = gemini2_5_pro_exp
|
346
347
|
|
347
348
|
if "google-vertex-ai" in provider.get("base_url", ""):
|
@@ -362,7 +363,8 @@ async def get_vertex_gemini_payload(request, engine, provider, api_key=None):
|
|
362
363
|
else:
|
363
364
|
url = f"https://aiplatform.googleapis.com/v1/publishers/google/models/{original_model}:{gemini_stream}?key={api_key}"
|
364
365
|
headers.pop("Authorization", None)
|
365
|
-
elif "gemini-2.5-
|
366
|
+
elif "gemini-2.5-flash-lite-preview-06-17" == original_model or \
|
367
|
+
"gemini-2.5-pro-preview-06-05" == original_model:
|
366
368
|
url = "https://aiplatform.googleapis.com/v1/projects/{PROJECT_ID}/locations/{LOCATION}/publishers/google/models/{MODEL_ID}:{stream}".format(
|
367
369
|
LOCATION=await location.next(),
|
368
370
|
PROJECT_ID=project_id,
|
@@ -1195,6 +1197,126 @@ async def get_azure_payload(request, engine, provider, api_key=None):
|
|
1195
1197
|
|
1196
1198
|
return url, headers, payload
|
1197
1199
|
|
1200
|
+
async def get_azure_databricks_payload(request, engine, provider, api_key=None):
|
1201
|
+
api_key = base64.b64encode(f"token:{api_key}".encode()).decode()
|
1202
|
+
headers = {
|
1203
|
+
'Content-Type': 'application/json',
|
1204
|
+
'Authorization': f"Basic {api_key}",
|
1205
|
+
}
|
1206
|
+
model_dict = get_model_dict(provider)
|
1207
|
+
original_model = model_dict[request.model]
|
1208
|
+
|
1209
|
+
base_url=provider['base_url']
|
1210
|
+
url = urllib.parse.urljoin(base_url, f"/serving-endpoints/{original_model}/invocations")
|
1211
|
+
|
1212
|
+
messages = []
|
1213
|
+
for msg in request.messages:
|
1214
|
+
tool_calls = None
|
1215
|
+
tool_call_id = None
|
1216
|
+
if isinstance(msg.content, list):
|
1217
|
+
content = []
|
1218
|
+
for item in msg.content:
|
1219
|
+
if item.type == "text":
|
1220
|
+
text_message = await get_text_message(item.text, engine)
|
1221
|
+
content.append(text_message)
|
1222
|
+
elif item.type == "image_url" and provider.get("image", True) and "o1-mini" not in original_model:
|
1223
|
+
image_message = await get_image_message(item.image_url.url, engine)
|
1224
|
+
content.append(image_message)
|
1225
|
+
else:
|
1226
|
+
content = msg.content
|
1227
|
+
tool_calls = msg.tool_calls
|
1228
|
+
tool_call_id = msg.tool_call_id
|
1229
|
+
|
1230
|
+
if tool_calls:
|
1231
|
+
tool_calls_list = []
|
1232
|
+
for tool_call in tool_calls:
|
1233
|
+
tool_calls_list.append({
|
1234
|
+
"id": tool_call.id,
|
1235
|
+
"type": tool_call.type,
|
1236
|
+
"function": {
|
1237
|
+
"name": tool_call.function.name,
|
1238
|
+
"arguments": tool_call.function.arguments
|
1239
|
+
}
|
1240
|
+
})
|
1241
|
+
if provider.get("tools"):
|
1242
|
+
messages.append({"role": msg.role, "tool_calls": tool_calls_list})
|
1243
|
+
elif tool_call_id:
|
1244
|
+
if provider.get("tools"):
|
1245
|
+
messages.append({"role": msg.role, "tool_call_id": tool_call_id, "content": content})
|
1246
|
+
else:
|
1247
|
+
messages.append({"role": msg.role, "content": content})
|
1248
|
+
|
1249
|
+
if "claude-3-7-sonnet" in original_model:
|
1250
|
+
max_tokens = 128000
|
1251
|
+
elif "claude-3-5-sonnet" in original_model:
|
1252
|
+
max_tokens = 8192
|
1253
|
+
elif "claude-sonnet-4" in original_model:
|
1254
|
+
max_tokens = 64000
|
1255
|
+
elif "claude-opus-4" in original_model:
|
1256
|
+
max_tokens = 32000
|
1257
|
+
else:
|
1258
|
+
max_tokens = 4096
|
1259
|
+
|
1260
|
+
payload = {
|
1261
|
+
"model": original_model,
|
1262
|
+
"messages": messages,
|
1263
|
+
"max_tokens": max_tokens,
|
1264
|
+
}
|
1265
|
+
|
1266
|
+
if request.max_tokens:
|
1267
|
+
payload["max_tokens"] = int(request.max_tokens)
|
1268
|
+
|
1269
|
+
miss_fields = [
|
1270
|
+
'model',
|
1271
|
+
'messages',
|
1272
|
+
]
|
1273
|
+
|
1274
|
+
for field, value in request.model_dump(exclude_unset=True).items():
|
1275
|
+
if field not in miss_fields and value is not None:
|
1276
|
+
if field == "max_tokens" and "o1" in original_model:
|
1277
|
+
payload["max_completion_tokens"] = value
|
1278
|
+
else:
|
1279
|
+
payload[field] = value
|
1280
|
+
|
1281
|
+
if provider.get("tools") == False or "o1" in original_model or "chatgpt-4o-latest" in original_model or "grok" in original_model:
|
1282
|
+
payload.pop("tools", None)
|
1283
|
+
payload.pop("tool_choice", None)
|
1284
|
+
|
1285
|
+
if "think" in request.model.lower():
|
1286
|
+
payload["thinking"] = {
|
1287
|
+
"budget_tokens": 4096,
|
1288
|
+
"type": "enabled"
|
1289
|
+
}
|
1290
|
+
payload["temperature"] = 1
|
1291
|
+
payload.pop("top_p", None)
|
1292
|
+
payload.pop("top_k", None)
|
1293
|
+
if request.model.split("-")[-1].isdigit():
|
1294
|
+
think_tokens = int(request.model.split("-")[-1])
|
1295
|
+
if think_tokens < max_tokens:
|
1296
|
+
payload["thinking"] = {
|
1297
|
+
"budget_tokens": think_tokens,
|
1298
|
+
"type": "enabled"
|
1299
|
+
}
|
1300
|
+
|
1301
|
+
if request.thinking:
|
1302
|
+
payload["thinking"] = {
|
1303
|
+
"budget_tokens": request.thinking.budget_tokens,
|
1304
|
+
"type": request.thinking.type
|
1305
|
+
}
|
1306
|
+
payload["temperature"] = 1
|
1307
|
+
payload.pop("top_p", None)
|
1308
|
+
payload.pop("top_k", None)
|
1309
|
+
|
1310
|
+
if safe_get(provider, "preferences", "post_body_parameter_overrides", default=None):
|
1311
|
+
for key, value in safe_get(provider, "preferences", "post_body_parameter_overrides", default={}).items():
|
1312
|
+
if key == request.model:
|
1313
|
+
for k, v in value.items():
|
1314
|
+
payload[k] = v
|
1315
|
+
elif all(_model not in request.model.lower() for _model in ["gemini", "gpt", "claude"]):
|
1316
|
+
payload[key] = value
|
1317
|
+
|
1318
|
+
return url, headers, payload
|
1319
|
+
|
1198
1320
|
async def get_openrouter_payload(request, engine, provider, api_key=None):
|
1199
1321
|
headers = {
|
1200
1322
|
'Content-Type': 'application/json'
|
@@ -1763,6 +1885,8 @@ async def get_payload(request: RequestModel, engine, provider, api_key=None):
|
|
1763
1885
|
return await get_vertex_claude_payload(request, engine, provider, api_key)
|
1764
1886
|
elif engine == "azure":
|
1765
1887
|
return await get_azure_payload(request, engine, provider, api_key)
|
1888
|
+
elif engine == "azure-databricks":
|
1889
|
+
return await get_azure_databricks_payload(request, engine, provider, api_key)
|
1766
1890
|
elif engine == "claude":
|
1767
1891
|
return await get_claude_payload(request, engine, provider, api_key)
|
1768
1892
|
elif engine == "gpt":
|
aient/core/response.py
CHANGED
@@ -49,6 +49,30 @@ async def fetch_gemini_response_stream(client, url, headers, payload, model):
|
|
49
49
|
while "\n" in buffer:
|
50
50
|
line, buffer = buffer.split("\n", 1)
|
51
51
|
# line_index += 1
|
52
|
+
if line.startswith("data: "):
|
53
|
+
json_line = line.lstrip("data: ").strip()
|
54
|
+
response_json = json.loads(json_line)
|
55
|
+
json_data = safe_get(response_json, "candidates", 0, "content", default=None)
|
56
|
+
finishReason = safe_get(response_json, "candidates", 0 , "finishReason", default=None)
|
57
|
+
if finishReason:
|
58
|
+
promptTokenCount = safe_get(response_json, "usageMetadata", "promptTokenCount", default=0)
|
59
|
+
candidatesTokenCount = safe_get(response_json, "usageMetadata", "candidatesTokenCount", default=0)
|
60
|
+
totalTokenCount = safe_get(response_json, "usageMetadata", "totalTokenCount", default=0)
|
61
|
+
|
62
|
+
content = safe_get(json_data, "parts", 0, "text", default="")
|
63
|
+
b64_json = safe_get(json_data, "parts", 0, "inlineData", "data", default="")
|
64
|
+
if b64_json:
|
65
|
+
image_base64 = b64_json
|
66
|
+
|
67
|
+
is_thinking = safe_get(json_data, "parts", 0, "thought", default=False)
|
68
|
+
if is_thinking:
|
69
|
+
sse_string = await generate_sse_response(timestamp, model, reasoning_content=content)
|
70
|
+
yield sse_string
|
71
|
+
elif not image_base64 and content:
|
72
|
+
sse_string = await generate_sse_response(timestamp, model, content=content)
|
73
|
+
yield sse_string
|
74
|
+
|
75
|
+
continue
|
52
76
|
|
53
77
|
# https://ai.google.dev/api/generate-content?hl=zh-cn#FinishReason
|
54
78
|
if line and '\"finishReason\": \"' in line:
|
@@ -270,8 +294,15 @@ async def fetch_gpt_response_stream(client, url, headers, payload):
|
|
270
294
|
|
271
295
|
no_stream_content = safe_get(line, "choices", 0, "message", "content", default=None)
|
272
296
|
openrouter_reasoning = safe_get(line, "choices", 0, "delta", "reasoning", default="")
|
297
|
+
azure_databricks_claude_summary_content = safe_get(line, "choices", 0, "delta", "content", 0, "summary", 0, "text", default="")
|
298
|
+
azure_databricks_claude_signature_content = safe_get(line, "choices", 0, "delta", "content", 0, "summary", 0, "signature", default="")
|
273
299
|
# print("openrouter_reasoning", repr(openrouter_reasoning), openrouter_reasoning.endswith("\\\\"), openrouter_reasoning.endswith("\\"))
|
274
|
-
if openrouter_reasoning:
|
300
|
+
if azure_databricks_claude_signature_content:
|
301
|
+
pass
|
302
|
+
elif azure_databricks_claude_summary_content:
|
303
|
+
sse_string = await generate_sse_response(timestamp, payload["model"], reasoning_content=azure_databricks_claude_summary_content)
|
304
|
+
yield sse_string
|
305
|
+
elif openrouter_reasoning:
|
275
306
|
if openrouter_reasoning.endswith("\\"):
|
276
307
|
enter_buffer += openrouter_reasoning
|
277
308
|
continue
|
@@ -640,15 +671,12 @@ async def fetch_response_stream(client, url, headers, payload, engine, model):
|
|
640
671
|
elif engine == "aws":
|
641
672
|
async for chunk in fetch_aws_response_stream(client, url, headers, payload, model):
|
642
673
|
yield chunk
|
643
|
-
elif engine == "gpt":
|
674
|
+
elif engine == "gpt" or engine == "openrouter" or engine == "azure-databricks":
|
644
675
|
async for chunk in fetch_gpt_response_stream(client, url, headers, payload):
|
645
676
|
yield chunk
|
646
677
|
elif engine == "azure":
|
647
678
|
async for chunk in fetch_azure_response_stream(client, url, headers, payload):
|
648
679
|
yield chunk
|
649
|
-
elif engine == "openrouter":
|
650
|
-
async for chunk in fetch_gpt_response_stream(client, url, headers, payload):
|
651
|
-
yield chunk
|
652
680
|
elif engine == "cloudflare":
|
653
681
|
async for chunk in fetch_cloudflare_response_stream(client, url, headers, payload, model):
|
654
682
|
yield chunk
|
aient/core/utils.py
CHANGED
@@ -75,6 +75,8 @@ def get_engine(provider, endpoint=None, original_model=""):
|
|
75
75
|
engine = "vertex"
|
76
76
|
elif parsed_url.netloc.rstrip('/').endswith('azure.com'):
|
77
77
|
engine = "azure"
|
78
|
+
elif parsed_url.netloc.rstrip('/').endswith('azuredatabricks.net'):
|
79
|
+
engine = "azure-databricks"
|
78
80
|
elif parsed_url.netloc == 'api.cloudflare.com':
|
79
81
|
engine = "cloudflare"
|
80
82
|
elif parsed_url.netloc == 'api.anthropic.com' or parsed_url.path.endswith("v1/messages"):
|
@@ -482,7 +484,6 @@ async def generate_sse_response(timestamp, model, content=None, tools_id=None, f
|
|
482
484
|
if role:
|
483
485
|
sample_data["choices"][0]["delta"] = {"role": role, "content": ""}
|
484
486
|
if total_tokens:
|
485
|
-
total_tokens = prompt_tokens + completion_tokens
|
486
487
|
sample_data["usage"] = {"prompt_tokens": prompt_tokens, "completion_tokens": completion_tokens, "total_tokens": total_tokens}
|
487
488
|
sample_data["choices"] = []
|
488
489
|
if stop:
|
@@ -574,7 +575,6 @@ async def generate_no_stream_response(timestamp, model, content=None, tools_id=N
|
|
574
575
|
}
|
575
576
|
|
576
577
|
if total_tokens:
|
577
|
-
total_tokens = prompt_tokens + completion_tokens
|
578
578
|
sample_data["usage"] = {"prompt_tokens": prompt_tokens, "completion_tokens": completion_tokens, "total_tokens": total_tokens}
|
579
579
|
|
580
580
|
json_data = json.dumps(sample_data, ensure_ascii=False)
|
@@ -674,7 +674,7 @@ async def get_image_message(base64_image, engine = None):
|
|
674
674
|
base64_image = f"data:image/png;base64,{png_base64}"
|
675
675
|
image_type = "image/png"
|
676
676
|
|
677
|
-
if "gpt" == engine or "openrouter" == engine or "azure" == engine:
|
677
|
+
if "gpt" == engine or "openrouter" == engine or "azure" == engine or "azure-databricks" == engine:
|
678
678
|
return {
|
679
679
|
"type": "image_url",
|
680
680
|
"image_url": {
|
@@ -702,7 +702,9 @@ async def get_image_message(base64_image, engine = None):
|
|
702
702
|
raise ValueError("Unknown engine")
|
703
703
|
|
704
704
|
async def get_text_message(message, engine = None):
|
705
|
-
if "gpt" == engine or "claude" == engine or "openrouter" == engine or "vertex-claude" == engine or "azure" == engine or "aws" == engine:
|
705
|
+
if "gpt" == engine or "claude" == engine or "openrouter" == engine or \
|
706
|
+
"vertex-claude" == engine or "azure" == engine or "aws" == engine or \
|
707
|
+
"azure-databricks" == engine:
|
706
708
|
return {"type": "text", "text": message}
|
707
709
|
if "gemini" == engine or "vertex-gemini" == engine:
|
708
710
|
return {"text": message}
|
@@ -734,7 +736,7 @@ def parse_json_safely(json_str):
|
|
734
736
|
return json.loads(json_str, strict=False)
|
735
737
|
except json.JSONDecodeError as e:
|
736
738
|
# 两种方法都失败,抛出异常
|
737
|
-
raise Exception(f"无法解析JSON字符串: {e}")
|
739
|
+
raise Exception(f"无法解析JSON字符串: {e}, {json_str}")
|
738
740
|
|
739
741
|
if __name__ == "__main__":
|
740
742
|
provider = {
|
{aient-1.1.35.dist-info → aient-1.1.36.dist-info}/RECORD
CHANGED
@@ -3,10 +3,10 @@ aient/core/.git,sha256=lrAcW1SxzRBUcUiuKL5tS9ykDmmTXxyLP3YYU-Y-Q-I,45
|
|
3
3
|
aient/core/.gitignore,sha256=5JRRlYYsqt_yt6iFvvzhbqh2FTUQMqwo6WwIuFzlGR8,13
|
4
4
|
aient/core/__init__.py,sha256=NxjebTlku35S4Dzr16rdSqSTWUvvwEeACe8KvHJnjPg,34
|
5
5
|
aient/core/log_config.py,sha256=kz2_yJv1p-o3lUQOwA3qh-LSc3wMHv13iCQclw44W9c,274
|
6
|
-
aient/core/models.py,sha256=
|
7
|
-
aient/core/request.py,sha256=
|
8
|
-
aient/core/response.py,sha256=
|
9
|
-
aient/core/utils.py,sha256=
|
6
|
+
aient/core/models.py,sha256=d4MISNezTSe0ls0-fjuToI2SoT-sk5fWqAJuKVinIlo,7502
|
7
|
+
aient/core/request.py,sha256=6Nwduj7kFuubFaZ0ZLkT_zd03XpT-bFhgrKVOZiGBOQ,71918
|
8
|
+
aient/core/response.py,sha256=RYy70Ld_txixHHd61Dqtlo0tKHMU_OIXqxGWd6EfATI,35315
|
9
|
+
aient/core/utils.py,sha256=fhI5wBxr01lVEp8nMfjG9dQ859AE-VdrWyb9suLzzqM,27400
|
10
10
|
aient/core/test/test_base_api.py,sha256=pWnycRJbuPSXKKU9AQjWrMAX1wiLC_014Qc9hh5C2Pw,524
|
11
11
|
aient/core/test/test_geminimask.py,sha256=HFX8jDbNg_FjjgPNxfYaR-0-roUrOO-ND-FVsuxSoiw,13254
|
12
12
|
aient/core/test/test_image.py,sha256=_T4peNGdXKBHHxyQNx12u-NTyFE8TlYI6NvvagsG2LE,319
|
@@ -37,8 +37,8 @@ aient/plugins/write_file.py,sha256=7spYxloI_aUbeANEQK-oXrGPoBqSfsD7sdfMAWlNxhU,3
|
|
37
37
|
aient/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
38
38
|
aient/utils/prompt.py,sha256=UcSzKkFE4-h_1b6NofI6xgk3GoleqALRKY8VBaXLjmI,11311
|
39
39
|
aient/utils/scripts.py,sha256=_43DEeoaiNVSA7ew1UUmp-gIV6XXe6rQPc2HTRuTzkw,40944
|
40
|
-
aient-1.1.
|
41
|
-
aient-1.1.
|
42
|
-
aient-1.1.
|
43
|
-
aient-1.1.
|
44
|
-
aient-1.1.
|
40
|
+
aient-1.1.36.dist-info/licenses/LICENSE,sha256=XNdbcWldt0yaNXXWB_Bakoqnxb3OVhUft4MgMA_71ds,1051
|
41
|
+
aient-1.1.36.dist-info/METADATA,sha256=u-UNrKVDoYOocUU5VF-hi72Ej0bahPyP8SUKkj24LPU,4968
|
42
|
+
aient-1.1.36.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
43
|
+
aient-1.1.36.dist-info/top_level.txt,sha256=3oXzrP5sAVvyyqabpeq8A2_vfMtY554r4bVE-OHBrZk,6
|
44
|
+
aient-1.1.36.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|