aient-1.1.12-py3-none-any.whl → aient-1.1.13-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aient/core/request.py +81 -27
- aient/core/response.py +13 -3
- aient/core/utils.py +1 -0
- {aient-1.1.12.dist-info → aient-1.1.13.dist-info}/METADATA +1 -1
- {aient-1.1.12.dist-info → aient-1.1.13.dist-info}/RECORD +8 -8
- {aient-1.1.12.dist-info → aient-1.1.13.dist-info}/WHEEL +0 -0
- {aient-1.1.12.dist-info → aient-1.1.13.dist-info}/licenses/LICENSE +0 -0
- {aient-1.1.12.dist-info → aient-1.1.13.dist-info}/top_level.txt +0 -0
aient/core/request.py
CHANGED
@@ -48,6 +48,7 @@ async def get_gemini_payload(request, engine, provider, api_key=None):
 
     messages = []
     systemInstruction = None
+    system_prompt = ""
     function_arguments = None
     for msg in request.messages:
         if msg.role == "assistant":
@@ -102,7 +103,8 @@ async def get_gemini_payload(request, engine, provider, api_key=None):
             messages.append({"role": msg.role, "parts": content})
         elif msg.role == "system":
             content[0]["text"] = re.sub(r"_+", "_", content[0]["text"])
-
+            system_prompt = system_prompt + "\n\n" + content[0]["text"]
+            systemInstruction = {"parts": [{"text": system_prompt}]}
 
     if any(off_model in original_model for off_model in gemini_max_token_65k_models):
         safety_settings = "OFF"
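The two added lines accumulate every system message into a single system_prompt string and rebuild systemInstruction from the running total, so a request carrying several system messages ends up with one merged Gemini systemInstruction. A minimal standalone sketch of that effect (the message texts are made up):

# Sketch: merging multiple system messages into one Gemini
# systemInstruction, mirroring the accumulation added above.
system_messages = ["You are concise.", "Answer in English."]

system_prompt = ""
for text in system_messages:
    system_prompt = system_prompt + "\n\n" + text
systemInstruction = {"parts": [{"text": system_prompt}]}

print(systemInstruction)
# {'parts': [{'text': '\n\nYou are concise.\n\nAnswer in English.'}]}

Note the merged text starts with "\n\n" because the accumulator begins as an empty string.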
@@ -212,23 +214,35 @@ async def get_gemini_payload(request, engine, provider, api_key=None):
     else:
         payload["generationConfig"]["maxOutputTokens"] = 8192
 
-
-
-
-
-
-
-
-
-                val =
-
-
-
-
-
-
-
-
+    if "gemini-2.5" in original_model:
+        payload["generationConfig"]["thinkingConfig"] = {
+            "includeThoughts": True,
+        }
+        # Detect a thinking-budget setting in the requested model name
+        m = re.match(r".*-think-(-?\d+)", request.model)
+        if m:
+            try:
+                val = int(m.group(1))
+                if val < 0:
+                    val = 0
+                elif val > 24576:
+                    val = 24576
+                payload["generationConfig"]["thinkingConfig"]["thinkingBudget"] = val
+            except ValueError:
+                # If the integer conversion fails, ignore the thinking-budget setting
+                pass
+
+    # # Detect the "-search" tag
+    # if request.model.endswith("-search"):
+    #     payload["tools"] = [{"googleSearch": {}}]
+
+    if safe_get(provider, "preferences", "post_body_parameter_overrides", default=None):
+        for key, value in safe_get(provider, "preferences", "post_body_parameter_overrides", default={}).items():
+            if key == request.model:
+                for k, v in value.items():
+                    payload[k] = v
+            elif all(_model not in request.model for _model in ["gemini", "gpt", "claude"]):
+                payload[key] = value
 
     return url, headers, payload
 
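For gemini-2.5 models the payload now always requests thought summaries ("includeThoughts": True), and an optional -think-N suffix on the requested model name sets thinkingBudget, clamped to the 0–24576 range. A standalone sketch of the suffix parsing (the model names are illustrative):

import re

def parse_thinking_budget(model_name):
    """Extract an optional -think-N suffix and clamp it to [0, 24576]."""
    m = re.match(r".*-think-(-?\d+)", model_name)
    if not m:
        return None
    return min(max(int(m.group(1)), 0), 24576)

print(parse_thinking_budget("gemini-2.5-flash-think-1024"))   # 1024
print(parse_thinking_budget("gemini-2.5-flash-think--1"))     # 0 (negatives clamp to 0)
print(parse_thinking_budget("gemini-2.5-flash-think-99999"))  # 24576 (upper clamp)
print(parse_thinking_budget("gemini-2.5-flash"))              # None (no suffix)

The except ValueError in the diff is defensive; a string matching -?\d+ always converts cleanly with int().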
@@ -303,16 +317,16 @@ async def get_vertex_gemini_payload(request, engine, provider, api_key=None):
     gemini_stream = "generateContent"
     model_dict = get_model_dict(provider)
     original_model = model_dict[request.model]
-    search_tool = None
+    # search_tool = None
 
     # https://cloud.google.com/vertex-ai/generative-ai/docs/models/gemini/2-0-flash?hl=zh-cn
     pro_models = ["gemini-2.5", "gemini-2.0"]
     if any(pro_model in original_model for pro_model in pro_models):
         location = gemini2
-        search_tool = {"googleSearch": {}}
+        # search_tool = {"googleSearch": {}}
     else:
         location = gemini1
-        search_tool = {"googleSearchRetrieval": {}}
+        # search_tool = {"googleSearchRetrieval": {}}
 
     if "google-vertex-ai" in provider.get("base_url", ""):
         url = provider.get("base_url").rstrip('/') + "/v1/projects/{PROJECT_ID}/locations/{LOCATION}/publishers/google/models/{MODEL_ID}:{stream}".format(
@@ -334,6 +348,7 @@ async def get_vertex_gemini_payload(request, engine, provider, api_key=None):
 
     messages = []
     systemInstruction = None
+    system_prompt = ""
     function_arguments = None
     for msg in request.messages:
         if msg.role == "assistant":
@@ -387,7 +402,8 @@ async def get_vertex_gemini_payload(request, engine, provider, api_key=None):
         elif msg.role != "system":
             messages.append({"role": msg.role, "parts": content})
         elif msg.role == "system":
-
+            system_prompt = system_prompt + "\n\n" + content[0]["text"]
+            systemInstruction = {"parts": [{"text": system_prompt}]}
 
     if any(off_model in original_model for off_model in gemini_max_token_65k_models):
         safety_settings = "OFF"
@@ -469,8 +485,34 @@ async def get_vertex_gemini_payload(request, engine, provider, api_key=None):
     else:
         payload["generationConfig"]["max_output_tokens"] = 8192
 
-    if
-    payload["
+    if "gemini-2.5" in original_model:
+        payload["generationConfig"]["thinkingConfig"] = {
+            "includeThoughts": True,
+        }
+        # Detect a thinking-budget setting in the requested model name
+        m = re.match(r".*-think-(-?\d+)", request.model)
+        if m:
+            try:
+                val = int(m.group(1))
+                if val < 0:
+                    val = 0
+                elif val > 24576:
+                    val = 24576
+                payload["generationConfig"]["thinkingConfig"]["thinkingBudget"] = val
+            except ValueError:
+                # If the integer conversion fails, ignore the thinking-budget setting
+                pass
+
+    # if request.model.endswith("-search"):
+    #     payload["tools"] = [search_tool]
+
+    if safe_get(provider, "preferences", "post_body_parameter_overrides", default=None):
+        for key, value in safe_get(provider, "preferences", "post_body_parameter_overrides", default={}).items():
+            if key == request.model:
+                for k, v in value.items():
+                    payload[k] = v
+            elif all(_model not in request.model for _model in ["gemini", "gpt", "claude"]):
+                payload[key] = value
 
     return url, headers, payload
 
@@ -1010,7 +1052,11 @@ async def get_gpt_payload(request, engine, provider, api_key=None):
 
     if safe_get(provider, "preferences", "post_body_parameter_overrides", default=None):
         for key, value in safe_get(provider, "preferences", "post_body_parameter_overrides", default={}).items():
-
+            if key == request.model:
+                for k, v in value.items():
+                    payload[k] = v
+            elif all(_model not in request.model for _model in ["gemini", "gpt", "claude"]):
+                payload[key] = value
 
     return url, headers, payload
 
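All four payload builders (Gemini, Vertex Gemini, GPT, Azure) now apply post_body_parameter_overrides with the same two-tier rule: an override keyed by the exact requested model name is merged into the payload field by field, while any other key is copied into the payload verbatim, but only for models whose name contains none of "gemini", "gpt", or "claude". A hedged sketch of how a provider preferences block might exercise both tiers (the keys and values are illustrative, not taken from the package):

# Hypothetical provider preferences, for illustration only.
exact_overrides = {"gpt-4o-mini": {"temperature": 0.2, "top_p": 0.9}}
loose_overrides = {"max_tokens": 512}

def apply_overrides(payload, overrides, requested_model):
    # Mirrors the override loop added in this release.
    for key, value in overrides.items():
        if key == requested_model:
            for k, v in value.items():
                payload[k] = v
        elif all(m not in requested_model for m in ["gemini", "gpt", "claude"]):
            payload[key] = value
    return payload

print(apply_overrides({}, exact_overrides, "gpt-4o-mini"))
# {'temperature': 0.2, 'top_p': 0.9} -- exact model match merged field by field
print(apply_overrides({}, loose_overrides, "gpt-4o-mini"))
# {} -- loose keys are skipped for models containing "gpt"
print(apply_overrides({}, loose_overrides, "qwen-72b"))
# {'max_tokens': 512} -- loose keys apply verbatim to other models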
@@ -1104,7 +1150,11 @@ async def get_azure_payload(request, engine, provider, api_key=None):
 
     if safe_get(provider, "preferences", "post_body_parameter_overrides", default=None):
         for key, value in safe_get(provider, "preferences", "post_body_parameter_overrides", default={}).items():
-
+            if key == request.model:
+                for k, v in value.items():
+                    payload[k] = v
+            elif all(_model not in request.model for _model in ["gemini", "gpt", "claude"]):
+                payload[key] = value
 
     return url, headers, payload
 
@@ -1433,9 +1483,13 @@ async def get_claude_payload(request, engine, provider, api_key=None):
         message_index = message_index + 1
 
     if "claude-3-7-sonnet" in original_model:
-        max_tokens =
+        max_tokens = 128000
     elif "claude-3-5-sonnet" in original_model:
         max_tokens = 8192
+    elif "claude-sonnet-4" in original_model:
+        max_tokens = 64000
+    elif "claude-opus-4" in original_model:
+        max_tokens = 32000
     else:
         max_tokens = 4096
 
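The default max_tokens ceiling now also covers the Claude 4 family via substring matches on the upstream model name. A standalone sketch of the resulting mapping (the model IDs are examples):

def default_max_tokens(original_model):
    # Substring-based ceilings, as wired up in this release.
    if "claude-3-7-sonnet" in original_model:
        return 128000
    elif "claude-3-5-sonnet" in original_model:
        return 8192
    elif "claude-sonnet-4" in original_model:
        return 64000
    elif "claude-opus-4" in original_model:
        return 32000
    return 4096

print(default_max_tokens("claude-sonnet-4-20250514"))  # 64000
print(default_max_tokens("claude-opus-4-20250514"))    # 32000
print(default_max_tokens("claude-3-opus-20240229"))    # 4096 (no branch matches)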
aient/core/response.py
CHANGED
@@ -535,15 +535,25 @@ async def fetch_response(client, url, headers, payload, engine, model):
     # print("parsed_data", json.dumps(parsed_data, indent=4, ensure_ascii=False))
     content = ""
     reasoning_content = ""
-
-
-
+    parts_list = safe_get(parsed_data, 0, "candidates", 0, "content", "parts", default=[])
+    for item in parts_list:
+        chunk = safe_get(item, "text")
+        is_think = safe_get(item, "thought", default=False)
         # logger.info(f"chunk: {repr(chunk)}")
         if chunk:
             if is_think:
                 reasoning_content += chunk
             else:
                 content += chunk
+    # for item in parsed_data:
+    #     chunk = safe_get(item, "candidates", 0, "content", "parts", 0, "text")
+    #     is_think = safe_get(item, "candidates", 0, "content", "parts", 0, "thought", default=False)
+    #     # logger.info(f"chunk: {repr(chunk)}")
+    #     if chunk:
+    #         if is_think:
+    #             reasoning_content += chunk
+    #         else:
+    #             content += chunk
 
     usage_metadata = safe_get(parsed_data, -1, "usageMetadata")
     prompt_tokens = safe_get(usage_metadata, "promptTokenCount", default=0)
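The non-streaming Gemini handler now walks every part of the first candidate instead of reading only parts[0], routing parts flagged "thought": true into reasoning_content and the rest into content. A minimal sketch against a fabricated response fixture (not a verbatim API reply):

# Fabricated fixture shaped like the Gemini JSON the new loop expects.
parsed_data = [{
    "candidates": [{
        "content": {"parts": [
            {"text": "Let me reason this through...", "thought": True},
            {"text": "The answer is 42."},
        ]}
    }]
}]

content, reasoning_content = "", ""
for item in parsed_data[0]["candidates"][0]["content"]["parts"]:
    chunk = item.get("text")
    if chunk:
        if item.get("thought", False):
            reasoning_content += chunk
        else:
            content += chunk

print(reasoning_content)  # Let me reason this through...
print(content)            # The answer is 42.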
aient/core/utils.py
CHANGED
@@ -96,6 +96,7 @@ def get_engine(provider, endpoint=None, original_model=""):
         and "o3" not in original_model \
         and "o4" not in original_model \
         and "gemini" not in original_model \
+        and "gemma" not in original_model \
         and "learnlm" not in original_model \
         and "grok" not in original_model \
         and parsed_url.netloc != 'api.cloudflare.com' \
{aient-1.1.12.dist-info → aient-1.1.13.dist-info}/RECORD
CHANGED
@@ -4,9 +4,9 @@ aient/core/.gitignore,sha256=5JRRlYYsqt_yt6iFvvzhbqh2FTUQMqwo6WwIuFzlGR8,13
 aient/core/__init__.py,sha256=NxjebTlku35S4Dzr16rdSqSTWUvvwEeACe8KvHJnjPg,34
 aient/core/log_config.py,sha256=kz2_yJv1p-o3lUQOwA3qh-LSc3wMHv13iCQclw44W9c,274
 aient/core/models.py,sha256=kF-HLi1I2k_G5r153ZHuiGH8_NmpTlFMfK0_myB28YQ,7366
-aient/core/request.py,sha256=
-aient/core/response.py,sha256=
-aient/core/utils.py,sha256
+aient/core/request.py,sha256=AmTnQ_Ri_ACRxDsWmPhhD6e79hNfwLxbsyBnpbAnmNA,64490
+aient/core/response.py,sha256=Z0Bjl_QvpUguyky1LIcsVks4BKKqT0eYEpDmKa_cwpQ,31978
+aient/core/utils.py,sha256=-naFCv8V-qhnqvDUd8BNbW1HR9CVAPxISrXoAz464Qg,26580
 aient/core/test/test_base_api.py,sha256=pWnycRJbuPSXKKU9AQjWrMAX1wiLC_014Qc9hh5C2Pw,524
 aient/core/test/test_geminimask.py,sha256=HFX8jDbNg_FjjgPNxfYaR-0-roUrOO-ND-FVsuxSoiw,13254
 aient/core/test/test_image.py,sha256=_T4peNGdXKBHHxyQNx12u-NTyFE8TlYI6NvvagsG2LE,319
@@ -38,8 +38,8 @@ aient/prompt/agent.py,sha256=y2GETN6ScC5yQVs75VFfzm4YUWzblbqLYz0Sy6JnPRw,24950
 aient/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 aient/utils/prompt.py,sha256=UcSzKkFE4-h_1b6NofI6xgk3GoleqALRKY8VBaXLjmI,11311
 aient/utils/scripts.py,sha256=wutPtgbs-WXo5AACLpnCJaRQBOSKXWNnsf2grbYDzyQ,29098
-aient-1.1.
-aient-1.1.
-aient-1.1.
-aient-1.1.
-aient-1.1.
+aient-1.1.13.dist-info/licenses/LICENSE,sha256=XNdbcWldt0yaNXXWB_Bakoqnxb3OVhUft4MgMA_71ds,1051
+aient-1.1.13.dist-info/METADATA,sha256=uguXOGwPo-gVDnS50EHIjb9nm6KtITwdbNXH9a3IQzA,4968
+aient-1.1.13.dist-info/WHEEL,sha256=zaaOINJESkSfm_4HQVc5ssNzHCPXhJm0kEUakpsEHaU,91
+aient-1.1.13.dist-info/top_level.txt,sha256=3oXzrP5sAVvyyqabpeq8A2_vfMtY554r4bVE-OHBrZk,6
+aient-1.1.13.dist-info/RECORD,,
{aient-1.1.12.dist-info → aient-1.1.13.dist-info}/WHEEL
File without changes
{aient-1.1.12.dist-info → aient-1.1.13.dist-info}/licenses/LICENSE
File without changes
{aient-1.1.12.dist-info → aient-1.1.13.dist-info}/top_level.txt
File without changes