aient-1.1.74-py3-none-any.whl → aient-1.1.76-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aient/core/response.py +28 -22
- aient/models/chatgpt.py +53 -11
- {aient-1.1.74.dist-info → aient-1.1.76.dist-info}/METADATA +1 -1
- {aient-1.1.74.dist-info → aient-1.1.76.dist-info}/RECORD +7 -7
- {aient-1.1.74.dist-info → aient-1.1.76.dist-info}/WHEEL +0 -0
- {aient-1.1.74.dist-info → aient-1.1.76.dist-info}/licenses/LICENSE +0 -0
- {aient-1.1.74.dist-info → aient-1.1.76.dist-info}/top_level.txt +0 -0
aient/core/response.py
CHANGED
```diff
@@ -3,6 +3,7 @@ import json
 import random
 import string
 import base64
+import asyncio
 from datetime import datetime
 
 from .log_config import logger
@@ -14,19 +15,19 @@ async def check_response(response, error_log):
         error_message = await response.aread()
         error_str = error_message.decode('utf-8', errors='replace')
         try:
-            error_json = json.loads(error_str)
+            error_json = await asyncio.to_thread(json.loads, error_str)
         except json.JSONDecodeError:
             error_json = error_str
         return {"error": f"{error_log} HTTP Error", "status_code": response.status_code, "details": error_json}
     return None
 
-def gemini_json_poccess(response_str):
+async def gemini_json_poccess(response_str):
     promptTokenCount = 0
     candidatesTokenCount = 0
     totalTokenCount = 0
     image_base64 = None
 
-    response_json = json.loads(response_str)
+    response_json = await asyncio.to_thread(json.loads, response_str)
     json_data = safe_get(response_json, "candidates", 0, "content", default=None)
     finishReason = safe_get(response_json, "candidates", 0 , "finishReason", default=None)
     if finishReason:
```
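The recurring change in this file is mechanical: every synchronous `json.loads(...)` becomes `await asyncio.to_thread(json.loads, ...)`, so parsing a body no longer blocks the event loop while other streams are being served. A minimal runnable sketch of the pattern (the payload and function name here are illustrative, not from the package):

```python
import asyncio
import json

async def parse_off_thread(raw: str) -> dict:
    # json.loads is CPU-bound; to_thread hands it to a worker thread
    # so the event loop stays responsive for other connections.
    return await asyncio.to_thread(json.loads, raw)

async def main():
    raw = '{"candidates": [{"content": {"parts": [{"text": "hi"}]}}]}'
    data = await parse_off_thread(raw)
    print(data["candidates"][0]["content"]["parts"][0]["text"])

asyncio.run(main())
```

For tiny chunks the thread handoff costs more than it saves; the payoff is on large bodies and on servers multiplexing many concurrent streams.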
```diff
@@ -75,7 +76,7 @@ async def fetch_gemini_response_stream(client, url, headers, payload, model, tim
             if line.startswith("data: "):
                 parts_json = line.lstrip("data: ").strip()
                 try:
-                    json.loads(parts_json)
+                    await asyncio.to_thread(json.loads, parts_json)
                 except json.JSONDecodeError:
                     logger.error(f"JSON decode error: {parts_json}")
                     continue
@@ -83,12 +84,12 @@ async def fetch_gemini_response_stream(client, url, headers, payload, model, tim
                 parts_json += line
                 parts_json = parts_json.lstrip("[,")
                 try:
-                    json.loads(parts_json)
+                    await asyncio.to_thread(json.loads, parts_json)
                 except json.JSONDecodeError:
                     continue
 
             # https://ai.google.dev/api/generate-content?hl=zh-cn#FinishReason
-            is_thinking, reasoning_content, content, image_base64, function_call_name, function_full_response, finishReason, blockReason, promptTokenCount, candidatesTokenCount, totalTokenCount = gemini_json_poccess(parts_json)
+            is_thinking, reasoning_content, content, image_base64, function_call_name, function_full_response, finishReason, blockReason, promptTokenCount, candidatesTokenCount, totalTokenCount = await gemini_json_poccess(parts_json)
 
             if is_thinking:
                 sse_string = await generate_sse_response(timestamp, model, reasoning_content=reasoning_content)
@@ -159,7 +160,7 @@ async def fetch_vertex_claude_response_stream(client, url, headers, payload, mod
 
         if line and '"text": "' in line and is_finish == False:
             try:
-                json_data = json.loads("{" + line.strip().rstrip(",") + "}")
+                json_data = await asyncio.to_thread(json.loads, "{" + line.strip().rstrip(",") + "}")
                 content = json_data.get('text', '')
                 sse_string = await generate_sse_response(timestamp, model, content=content)
                 yield sse_string
@@ -176,7 +177,7 @@ async def fetch_vertex_claude_response_stream(client, url, headers, payload, mod
             function_full_response += line
 
         if need_function_call:
-            function_call = json.loads(function_full_response)
+            function_call = await asyncio.to_thread(json.loads, function_full_response)
             function_call_name = function_call["name"]
             function_call_id = function_call["id"]
             sse_string = await generate_sse_response(timestamp, model, content=None, tools_id=function_call_id, function_call_name=function_call_name)
@@ -213,7 +214,7 @@ async def fetch_gpt_response_stream(client, url, headers, payload, timeout):
             if line and not line.startswith(":") and (result:=line.lstrip("data: ").strip()):
                 if result.strip() == "[DONE]":
                     break
-                line = json.loads(result)
+                line = await asyncio.to_thread(json.loads, result)
                 line['id'] = f"chatcmpl-{random_str}"
 
                 # Handle <think> tags
@@ -327,7 +328,7 @@ async def fetch_azure_response_stream(client, url, headers, payload, timeout):
             if line and not line.startswith(":") and (result:=line.lstrip("data: ").strip()):
                 if result.strip() == "[DONE]":
                     break
-                line = json.loads(result)
+                line = await asyncio.to_thread(json.loads, result)
                 no_stream_content = safe_get(line, "choices", 0, "message", "content", default="")
                 content = safe_get(line, "choices", 0, "delta", "content", default="")
 
@@ -380,7 +381,7 @@ async def fetch_cloudflare_response_stream(client, url, headers, payload, model,
                 line = line.lstrip("data: ")
                 if line == "[DONE]":
                     break
-                resp: dict = json.loads(line)
+                resp: dict = await asyncio.to_thread(json.loads, line)
                 message = resp.get("response")
                 if message:
                     sse_string = await generate_sse_response(timestamp, model, content=message)
@@ -401,7 +402,7 @@ async def fetch_cohere_response_stream(client, url, headers, payload, model, tim
             while "\n" in buffer:
                 line, buffer = buffer.split("\n", 1)
                 # logger.info("line: %s", repr(line))
-                resp: dict = json.loads(line)
+                resp: dict = await asyncio.to_thread(json.loads, line)
                 if resp.get("is_finished") == True:
                     break
                 if resp.get("event_type") == "text-generation":
@@ -427,7 +428,7 @@ async def fetch_claude_response_stream(client, url, headers, payload, model, tim
                 # logger.info(line)
 
                 if line.startswith("data:") and (line := line.lstrip("data: ")):
-                    resp: dict = json.loads(line)
+                    resp: dict = await asyncio.to_thread(json.loads, line)
 
                     input_tokens = input_tokens or safe_get(resp, "message", "usage", "input_tokens", default=0)
                     # cache_creation_input_tokens = safe_get(resp, "message", "usage", "cache_creation_input_tokens", default=0)
@@ -486,7 +487,7 @@ async def fetch_aws_response_stream(client, url, headers, payload, model, timeou
             if not json_match:
                 continue
             try:
-                chunk_data = json.loads(json_match.group(0).lstrip('event'))
+                chunk_data = await asyncio.to_thread(json.loads, json_match.group(0).lstrip('event'))
             except json.JSONDecodeError:
                 logger.error(f"DEBUG json.JSONDecodeError: {json_match.group(0).lstrip('event')!r}")
                 continue
@@ -496,7 +497,7 @@ async def fetch_aws_response_stream(client, url, headers, payload, model, timeou
             # Decode the Base64-encoded bytes
             decoded_bytes = base64.b64decode(chunk_data["bytes"])
             # Parse the decoded bytes back into JSON
-            payload_chunk = json.loads(decoded_bytes.decode('utf-8'))
+            payload_chunk = await asyncio.to_thread(json.loads, decoded_bytes.decode('utf-8'))
             # print(f"DEBUG payload_chunk: {payload_chunk!r}")
 
             text = safe_get(payload_chunk, "delta", "text", default="")
```
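All the per-provider stream readers above share one loop shape: strip the SSE `data: ` prefix, stop at the `[DONE]` sentinel, and parse each event off-thread. A generic sketch of that loop, with a stubbed line source standing in for the real HTTP stream (`iter_sse_events` and `fake_lines` are illustrative names, not package APIs):

```python
import asyncio
import json

async def iter_sse_events(lines):
    # `lines` stands in for an httpx aiter_lines() stream in the real code.
    async for line in lines:
        if not line.startswith("data: "):
            continue
        result = line[len("data: "):].strip()
        if result == "[DONE]":
            break
        yield await asyncio.to_thread(json.loads, result)

async def fake_lines():
    for line in ['data: {"delta": "he"}', 'data: {"delta": "llo"}', 'data: [DONE]']:
        yield line

async def main():
    async for event in iter_sse_events(fake_lines()):
        print(event["delta"], end="")
    print()

asyncio.run(main())
```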
```diff
@@ -514,7 +515,7 @@ async def fetch_aws_response_stream(client, url, headers, payload, model, timeou
 
     yield "data: [DONE]" + end_of_line
 
-async def fetch_response(client, url, headers, payload, engine, model, timeout):
+async def fetch_response(client, url, headers, payload, engine, model, timeout=200):
     response = None
     if payload.get("file"):
         file = payload.pop("file")
@@ -530,7 +531,8 @@ async def fetch_response(client, url, headers, payload, engine, model, timeout):
         yield response.read()
 
     elif engine == "gemini" or engine == "vertex-gemini" or engine == "aws":
-        response_json = response.json()
+        response_bytes = await response.aread()
+        response_json = await asyncio.to_thread(json.loads, response_bytes)
         # print("response_json", json.dumps(response_json, indent=4, ensure_ascii=False))
 
         if isinstance(response_json, str):
@@ -585,7 +587,8 @@ async def fetch_response(client, url, headers, payload, engine, model, timeout):
         yield await generate_no_stream_response(timestamp, model, content=content, tools_id=None, function_call_name=function_call_name, function_call_content=function_call_content, role=role, total_tokens=total_tokens, prompt_tokens=prompt_tokens, completion_tokens=candidates_tokens, reasoning_content=reasoning_content, image_base64=image_base64)
 
     elif engine == "claude":
-        response_json = response.json()
+        response_bytes = await response.aread()
+        response_json = await asyncio.to_thread(json.loads, response_bytes)
         # print("response_json", json.dumps(response_json, indent=4, ensure_ascii=False))
 
         content = safe_get(response_json, "content", 0, "text")
@@ -604,7 +607,8 @@ async def fetch_response(client, url, headers, payload, engine, model, timeout):
         yield await generate_no_stream_response(timestamp, model, content=content, tools_id=tools_id, function_call_name=function_call_name, function_call_content=function_call_content, role=role, total_tokens=total_tokens, prompt_tokens=prompt_tokens, completion_tokens=output_tokens)
 
     elif engine == "azure":
-        response_json = response.json()
+        response_bytes = await response.aread()
+        response_json = await asyncio.to_thread(json.loads, response_bytes)
         # Remove content_filter_results
         if "choices" in response_json:
             for choice in response_json["choices"]:
@@ -618,14 +622,16 @@ async def fetch_response(client, url, headers, payload, engine, model, timeout):
         yield response_json
 
     elif "dashscope.aliyuncs.com" in url and "multimodal-generation" in url:
-        response_json = response.json()
+        response_bytes = await response.aread()
+        response_json = await asyncio.to_thread(json.loads, response_bytes)
         content = safe_get(response_json, "output", "choices", 0, "message", "content", 0, default=None)
         yield content
     else:
-        response_json = response.json()
+        response_bytes = await response.aread()
+        response_json = await asyncio.to_thread(json.loads, response_bytes)
         yield response_json
 
-async def fetch_response_stream(client, url, headers, payload, engine, model, timeout):
+async def fetch_response_stream(client, url, headers, payload, engine, model, timeout=200):
     if engine == "gemini" or engine == "vertex-gemini":
         async for chunk in fetch_gemini_response_stream(client, url, headers, payload, model, timeout):
             yield chunk
```
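For the non-streaming branches, the one-line `response.json()` is replaced by an explicit `aread()` followed by an off-thread parse, and both `fetch_response` and `fetch_response_stream` gain a `timeout=200` default. A sketch of the new read-then-parse sequence against an httpx response (the URL and function name are placeholders):

```python
import asyncio
import json
import httpx

async def fetch_json(url: str, timeout: int = 200) -> dict:
    # timeout=200 mirrors the new default on fetch_response / fetch_response_stream.
    async with httpx.AsyncClient(timeout=timeout) as client:
        response = await client.get(url)
        response_bytes = await response.aread()  # read the full body asynchronously
        # json.loads accepts bytes directly; parse off the event loop
        return await asyncio.to_thread(json.loads, response_bytes)

# asyncio.run(fetch_json("https://example.com/api"))  # placeholder URL
```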
aient/models/chatgpt.py
CHANGED
```diff
@@ -15,6 +15,31 @@ from ..utils.scripts import safe_get, async_generator_to_sync, parse_function_xm
 from ..core.request import prepare_request_payload
 from ..core.response import fetch_response_stream, fetch_response
 
+class APITimeoutError(Exception):
+    """Custom exception for API timeout errors."""
+    pass
+
+class ValidationError(Exception):
+    """Custom exception for response validation errors."""
+    def __init__(self, message, response_text):
+        super().__init__(message)
+        self.response_text = response_text
+
+class EmptyResponseError(Exception):
+    """Custom exception for empty API responses."""
+    pass
+
+class ModelNotFoundError(Exception):
+    """Custom exception for model not found (404) errors."""
+    pass
+
+class TaskComplete(Exception):
+    """Exception-like signal to indicate the task is complete."""
+    def __init__(self, message):
+        self.completion_message = message
+        super().__init__(f"Task completed with message: {message}")
+
+
 class chatgpt(BaseLLM):
     """
     Official ChatGPT API
```
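The new exception hierarchy lets the retry loop further down branch on types instead of matching substrings of `str(e)`, and `TaskComplete` doubles as a control-flow signal that the agent's `task_complete` tool was invoked. A self-contained sketch of how a caller might consume it (`run_agent` is an illustrative stand-in, not the package's API):

```python
import asyncio

class TaskComplete(Exception):
    """Exception-like signal to indicate the task is complete."""
    def __init__(self, message):
        self.completion_message = message
        super().__init__(f"Task completed with message: {message}")

async def run_agent():
    # Stand-in for a streaming ask; real request plumbing omitted.
    yield "working..."
    raise TaskComplete("report summarized")

async def main():
    try:
        async for chunk in run_agent():
            print(chunk)
    except TaskComplete as e:
        print("Done:", e.completion_message)

asyncio.run(main())
```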
```diff
@@ -436,7 +461,7 @@ class chatgpt(BaseLLM):
                 yield chunk
 
         if not full_response.strip():
-            raise
+            raise EmptyResponseError("Response is empty")
 
         if self.print_log:
             self.logger.info(f"total_tokens: {total_tokens}")
@@ -450,7 +475,7 @@ class chatgpt(BaseLLM):
         if self.check_done:
             # self.logger.info(f"worker Response: {full_response}")
             if not full_response.strip().endswith('[done]'):
-                raise
+                raise ValidationError("Response is not ended with [done]", response_text=full_response)
             else:
                 full_response = full_response.strip().rstrip('[done]')
                 full_response = full_response.replace("<tool_code>", "").replace("</tool_code>", "")
@@ -494,6 +519,8 @@ class chatgpt(BaseLLM):
             # Drop the case where task_complete is called together with other tools, because task_complete must be called on its own
             if len(function_parameter) > 1:
                 function_parameter = [tool_dict for tool_dict in function_parameter if tool_dict.get("function_name", "") != "task_complete"]
+            if len(function_parameter) == 1 and function_parameter[0].get("function_name", "") == "task_complete":
+                raise TaskComplete(safe_get(function_parameter, 0, "parameter", "message", default="The task has been completed."))
 
             if self.print_log and invalid_tools:
                 self.logger.error(f"invalid_tools: {invalid_tools}")
@@ -739,13 +766,20 @@ class chatgpt(BaseLLM):
                 )
 
                 # Handle the normal response
+                index = 0
                 async for processed_chunk in self._process_stream_response(
                     generator, convo_id=convo_id, function_name=function_name,
                     total_tokens=total_tokens, function_arguments=function_arguments,
                     function_call_id=function_call_id, model=model, language=language,
                     system_prompt=system_prompt, pass_history=pass_history, is_async=True, stream=stream, **kwargs
                 ):
+                    if index == 0:
+                        if "HTTP Error', 'status_code': 524" in processed_chunk:
+                            raise APITimeoutError("Response timeout")
+                        if "HTTP Error', 'status_code': 404" in processed_chunk:
+                            raise ModelNotFoundError(f"Model: {model or self.engine} not found!")
                     yield processed_chunk
+                    index += 1
 
                 # Processed successfully; break out of the retry loop
                 break
```
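Upstream HTTP failures surface here as error text embedded in the first streamed chunk rather than as transport exceptions, so the loop sniffs chunk zero and promotes the two recognizable cases to typed exceptions. A reduced sketch of the check (chunk contents and the function name are illustrative):

```python
class APITimeoutError(Exception):
    pass

class ModelNotFoundError(Exception):
    pass

def classify_first_chunk(chunk: str, model: str) -> None:
    # Same substring probes as in the diff above; raising converts an
    # in-band error payload into ordinary exception control flow.
    if "HTTP Error', 'status_code': 524" in chunk:
        raise APITimeoutError("Response timeout")
    if "HTTP Error', 'status_code': 404" in chunk:
        raise ModelNotFoundError(f"Model: {model} not found!")

classify_first_chunk('{"content": "hello"}', "gpt-4o")  # a normal chunk passes silently
```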
```diff
@@ -754,17 +788,25 @@ class chatgpt(BaseLLM):
                 return # Stop iteration
             except httpx.RemoteProtocolError:
                 continue
+            except APITimeoutError:
+                self.logger.warning("API response timeout (524), retrying...")
+                continue
+            except ValidationError as e:
+                self.logger.warning(f"Validation failed: {e}. Retrying with corrective prompt.")
+                need_done_prompt = [
+                    {"role": "assistant", "content": e.response_text},
+                    {"role": "user", "content": "你的消息没有以[done]结尾,请重新输出"}
+                ]
+                continue
+            except EmptyResponseError as e:
+                self.logger.warning(f"{e}, retrying...")
+                continue
+            except TaskComplete as e:
+                raise
+            except ModelNotFoundError as e:
+                raise
             except Exception as e:
                 self.logger.error(f"{e}")
-                if "validation_error" in str(e):
-                    bad_assistant_message = json.loads(str(e))["response"]
-                    need_done_prompt = [
-                        {"role": "assistant", "content": bad_assistant_message},
-                        {"role": "user", "content": "你的消息没有以[done]结尾,请重新输出"}
-                    ]
-                    continue
-                if "response_empty_error" in str(e):
-                    continue
                 import traceback
                 self.logger.error(traceback.format_exc())
                 if "Invalid URL" in str(e):
```
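Taken together, the handlers replace the old substring checks on `str(e)` with typed control flow: timeouts and empty responses simply retry, `ValidationError` retries after re-injecting the bad answer plus a corrective user turn, and `TaskComplete` / `ModelNotFoundError` propagate to the caller. A condensed sketch of that shape, with a stubbed `attempt()` standing in for one full request cycle (names here are illustrative):

```python
class ValidationError(Exception):
    def __init__(self, message, response_text):
        super().__init__(message)
        self.response_text = response_text

class EmptyResponseError(Exception):
    pass

async def ask_with_retries(attempt, max_tries=3):
    need_done_prompt = None
    for _ in range(max_tries):
        try:
            # attempt() stands in for one full request/stream cycle.
            return await attempt(need_done_prompt)
        except ValidationError as e:
            # Re-ask with the bad answer and a corrective user turn appended.
            need_done_prompt = [
                {"role": "assistant", "content": e.response_text},
                {"role": "user", "content": "你的消息没有以[done]结尾,请重新输出"}
            ]
        except EmptyResponseError:
            continue
    raise RuntimeError("all retries exhausted")
```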
{aient-1.1.74.dist-info → aient-1.1.76.dist-info}/RECORD
CHANGED
```diff
@@ -3,7 +3,7 @@ aient/core/__init__.py,sha256=NxjebTlku35S4Dzr16rdSqSTWUvvwEeACe8KvHJnjPg,34
 aient/core/log_config.py,sha256=kz2_yJv1p-o3lUQOwA3qh-LSc3wMHv13iCQclw44W9c,274
 aient/core/models.py,sha256=KMlCRLjtq1wQHZTJGqnbWhPS2cHq6eLdnk7peKDrzR8,7490
 aient/core/request.py,sha256=vfwi3ZGYp2hQzSJ6mPXJVgcV_uu5AJ_NAL84mLfF8WA,76674
-aient/core/response.py,sha256=
+aient/core/response.py,sha256=vQFuc3amHiD1hv_OiINRJnh33n79PnbdzMSBSRlqR5E,34309
 aient/core/utils.py,sha256=D98d5Cy1h4ejKtuxS0EEDtL4YqpaZLB5tuXoVP0IBWQ,28462
 aient/core/test/test_base_api.py,sha256=pWnycRJbuPSXKKU9AQjWrMAX1wiLC_014Qc9hh5C2Pw,524
 aient/core/test/test_geminimask.py,sha256=HFX8jDbNg_FjjgPNxfYaR-0-roUrOO-ND-FVsuxSoiw,13254
@@ -12,7 +12,7 @@ aient/core/test/test_payload.py,sha256=8jBiJY1uidm1jzL-EiK0s6UGmW9XkdsuuKFGrwFhF
 aient/models/__init__.py,sha256=ZTiZgbfBPTjIPSKURE7t6hlFBVLRS9lluGbmqc1WjxQ,43
 aient/models/audio.py,sha256=kRd-8-WXzv4vwvsTGwnstK-WR8--vr9CdfCZzu8y9LA,1934
 aient/models/base.py,sha256=-nnihYnx-vHZMqeVO9ljjt3k4FcD3n-iMk4tT-10nRQ,7232
-aient/models/chatgpt.py,sha256=
+aient/models/chatgpt.py,sha256=q62B6cbtHqKrqsQjM24k_1wi_5-UiuxkXa7e2yG_Clg,44661
 aient/plugins/__init__.py,sha256=p3KO6Aa3Lupos4i2SjzLQw1hzQTigOAfEHngsldrsyk,986
 aient/plugins/arXiv.py,sha256=yHjb6PS3GUWazpOYRMKMzghKJlxnZ5TX8z9F6UtUVow,1461
 aient/plugins/config.py,sha256=TGgZ5SnNKZ8MmdznrZ-TEq7s2ulhAAwTSKH89bci3dA,7079
@@ -30,8 +30,8 @@ aient/plugins/write_file.py,sha256=Jt8fOEwqhYiSWpCbwfAr1xoi_BmFnx3076GMhuL06uI,3
 aient/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 aient/utils/prompt.py,sha256=UcSzKkFE4-h_1b6NofI6xgk3GoleqALRKY8VBaXLjmI,11311
 aient/utils/scripts.py,sha256=VqtK4RFEx7KxkmcqG3lFDS1DxoNlFFGErEjopVcc8IE,40974
-aient-1.1.74.dist-info/licenses/LICENSE,sha256=XNdbcWldt0yaNXXWB_Bakoqnxb3OVhUft4MgMA_71ds,1051
-aient-1.1.74.dist-info/METADATA,sha256=
-aient-1.1.74.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-aient-1.1.74.dist-info/top_level.txt,sha256=3oXzrP5sAVvyyqabpeq8A2_vfMtY554r4bVE-OHBrZk,6
-aient-1.1.74.dist-info/RECORD,,
+aient-1.1.76.dist-info/licenses/LICENSE,sha256=XNdbcWldt0yaNXXWB_Bakoqnxb3OVhUft4MgMA_71ds,1051
+aient-1.1.76.dist-info/METADATA,sha256=nOBPFlGsNRfFqblnwjC4Z36Dq8TkUMcsdTDrI9Gcm8E,4842
+aient-1.1.76.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+aient-1.1.76.dist-info/top_level.txt,sha256=3oXzrP5sAVvyyqabpeq8A2_vfMtY554r4bVE-OHBrZk,6
+aient-1.1.76.dist-info/RECORD,,
```
{aient-1.1.74.dist-info → aient-1.1.76.dist-info}/WHEEL
File without changes

{aient-1.1.74.dist-info → aient-1.1.76.dist-info}/licenses/LICENSE
File without changes

{aient-1.1.74.dist-info → aient-1.1.76.dist-info}/top_level.txt
File without changes