aient 1.0.49__py3-none-any.whl → 1.0.51__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aient/core/request.py +15 -7
- aient/core/response.py +41 -5
- aient/core/utils.py +33 -2
- aient/models/base.py +0 -1
- aient/models/chatgpt.py +2 -2
- aient/plugins/config.py +1 -1
- {aient-1.0.49.dist-info → aient-1.0.51.dist-info}/METADATA +1 -1
- {aient-1.0.49.dist-info → aient-1.0.51.dist-info}/RECORD +11 -11
- {aient-1.0.49.dist-info → aient-1.0.51.dist-info}/WHEEL +0 -0
- {aient-1.0.49.dist-info → aient-1.0.51.dist-info}/licenses/LICENSE +0 -0
- {aient-1.0.49.dist-info → aient-1.0.51.dist-info}/top_level.txt +0 -0
aient/core/request.py
CHANGED
@@ -96,7 +96,7 @@ async def get_gemini_payload(request, engine, provider, api_key=None):
|
|
96
96
|
content[0]["text"] = re.sub(r"_+", "_", content[0]["text"])
|
97
97
|
systemInstruction = {"parts": content}
|
98
98
|
|
99
|
-
off_models = ["gemini-2.0-flash
|
99
|
+
off_models = ["gemini-2.0-flash", "gemini-1.5", "gemini-2.5-pro"]
|
100
100
|
if any(off_model in original_model for off_model in off_models):
|
101
101
|
safety_settings = "OFF"
|
102
102
|
else:
|
@@ -294,12 +294,20 @@ async def get_vertex_gemini_payload(request, engine, provider, api_key=None):
|
|
294
294
|
location = gemini1
|
295
295
|
search_tool = {"googleSearchRetrieval": {}}
|
296
296
|
|
297
|
-
|
298
|
-
|
299
|
-
|
300
|
-
|
301
|
-
|
302
|
-
|
297
|
+
if "google-vertex-ai" in provider.get("base_url", ""):
|
298
|
+
url = provider.get("base_url").rstrip('/') + "/v1/projects/{PROJECT_ID}/locations/{LOCATION}/publishers/google/models/{MODEL_ID}:{stream}".format(
|
299
|
+
LOCATION=await location.next(),
|
300
|
+
PROJECT_ID=project_id,
|
301
|
+
MODEL_ID=original_model,
|
302
|
+
stream=gemini_stream
|
303
|
+
)
|
304
|
+
else:
|
305
|
+
url = "https://{LOCATION}-aiplatform.googleapis.com/v1/projects/{PROJECT_ID}/locations/{LOCATION}/publishers/google/models/{MODEL_ID}:{stream}".format(
|
306
|
+
LOCATION=await location.next(),
|
307
|
+
PROJECT_ID=project_id,
|
308
|
+
MODEL_ID=original_model,
|
309
|
+
stream=gemini_stream
|
310
|
+
)
|
303
311
|
|
304
312
|
messages = []
|
305
313
|
systemInstruction = None
|
aient/core/response.py
CHANGED
@@ -5,7 +5,7 @@ from datetime import datetime
|
|
5
5
|
|
6
6
|
from .log_config import logger
|
7
7
|
|
8
|
-
from .utils import safe_get, generate_sse_response, generate_no_stream_response, end_of_line
|
8
|
+
from .utils import safe_get, generate_sse_response, generate_no_stream_response, end_of_line, parse_json_safely
|
9
9
|
|
10
10
|
async def check_response(response, error_log):
|
11
11
|
if response and not (200 <= response.status_code < 300):
|
@@ -30,6 +30,9 @@ async def fetch_gemini_response_stream(client, url, headers, payload, model):
|
|
30
30
|
function_full_response = "{"
|
31
31
|
need_function_call = False
|
32
32
|
is_finish = False
|
33
|
+
promptTokenCount = 0
|
34
|
+
candidatesTokenCount = 0
|
35
|
+
totalTokenCount = 0
|
33
36
|
# line_index = 0
|
34
37
|
# last_text_line = 0
|
35
38
|
# if "thinking" in model:
|
@@ -42,9 +45,19 @@ async def fetch_gemini_response_stream(client, url, headers, payload, model):
|
|
42
45
|
while "\n" in buffer:
|
43
46
|
line, buffer = buffer.split("\n", 1)
|
44
47
|
# line_index += 1
|
48
|
+
|
45
49
|
if line and '\"finishReason\": \"' in line:
|
46
50
|
is_finish = True
|
47
|
-
|
51
|
+
if is_finish and '\"promptTokenCount\": ' in line:
|
52
|
+
json_data = parse_json_safely( "{" + line + "}")
|
53
|
+
promptTokenCount = json_data.get('promptTokenCount', 0)
|
54
|
+
if is_finish and '\"candidatesTokenCount\": ' in line:
|
55
|
+
json_data = parse_json_safely( "{" + line + "}")
|
56
|
+
candidatesTokenCount = json_data.get('candidatesTokenCount', 0)
|
57
|
+
if is_finish and '\"totalTokenCount\": ' in line:
|
58
|
+
json_data = parse_json_safely( "{" + line + "}")
|
59
|
+
totalTokenCount = json_data.get('totalTokenCount', 0)
|
60
|
+
|
48
61
|
# print(line)
|
49
62
|
if line and '\"text\": \"' in line:
|
50
63
|
try:
|
@@ -73,9 +86,6 @@ async def fetch_gemini_response_stream(client, url, headers, payload, model):
|
|
73
86
|
|
74
87
|
function_full_response += line
|
75
88
|
|
76
|
-
if is_finish:
|
77
|
-
break
|
78
|
-
|
79
89
|
if need_function_call:
|
80
90
|
function_call = json.loads(function_full_response)
|
81
91
|
function_call_name = function_call["functionCall"]["name"]
|
@@ -84,6 +94,10 @@ async def fetch_gemini_response_stream(client, url, headers, payload, model):
|
|
84
94
|
function_full_response = json.dumps(function_call["functionCall"]["args"])
|
85
95
|
sse_string = await generate_sse_response(timestamp, model, content=None, tools_id="chatcmpl-9inWv0yEtgn873CxMBzHeCeiHctTV", function_call_name=None, function_call_content=function_full_response)
|
86
96
|
yield sse_string
|
97
|
+
|
98
|
+
sse_string = await generate_sse_response(timestamp, model, None, None, None, None, None, totalTokenCount, promptTokenCount, candidatesTokenCount)
|
99
|
+
yield sse_string
|
100
|
+
|
87
101
|
yield "data: [DONE]" + end_of_line
|
88
102
|
|
89
103
|
async def fetch_vertex_claude_response_stream(client, url, headers, payload, model):
|
@@ -98,11 +112,29 @@ async def fetch_vertex_claude_response_stream(client, url, headers, payload, mod
|
|
98
112
|
revicing_function_call = False
|
99
113
|
function_full_response = "{"
|
100
114
|
need_function_call = False
|
115
|
+
is_finish = False
|
116
|
+
promptTokenCount = 0
|
117
|
+
candidatesTokenCount = 0
|
118
|
+
totalTokenCount = 0
|
119
|
+
|
101
120
|
async for chunk in response.aiter_text():
|
102
121
|
buffer += chunk
|
103
122
|
while "\n" in buffer:
|
104
123
|
line, buffer = buffer.split("\n", 1)
|
105
124
|
# logger.info(f"{line}")
|
125
|
+
|
126
|
+
if line and '\"finishReason\": \"' in line:
|
127
|
+
is_finish = True
|
128
|
+
if is_finish and '\"promptTokenCount\": ' in line:
|
129
|
+
json_data = parse_json_safely( "{" + line + "}")
|
130
|
+
promptTokenCount = json_data.get('promptTokenCount', 0)
|
131
|
+
if is_finish and '\"candidatesTokenCount\": ' in line:
|
132
|
+
json_data = parse_json_safely( "{" + line + "}")
|
133
|
+
candidatesTokenCount = json_data.get('candidatesTokenCount', 0)
|
134
|
+
if is_finish and '\"totalTokenCount\": ' in line:
|
135
|
+
json_data = parse_json_safely( "{" + line + "}")
|
136
|
+
totalTokenCount = json_data.get('totalTokenCount', 0)
|
137
|
+
|
106
138
|
if line and '\"text\": \"' in line:
|
107
139
|
try:
|
108
140
|
json_data = json.loads( "{" + line + "}")
|
@@ -130,6 +162,10 @@ async def fetch_vertex_claude_response_stream(client, url, headers, payload, mod
|
|
130
162
|
function_full_response = json.dumps(function_call["input"])
|
131
163
|
sse_string = await generate_sse_response(timestamp, model, content=None, tools_id=function_call_id, function_call_name=None, function_call_content=function_full_response)
|
132
164
|
yield sse_string
|
165
|
+
|
166
|
+
sse_string = await generate_sse_response(timestamp, model, None, None, None, None, None, totalTokenCount, promptTokenCount, candidatesTokenCount)
|
167
|
+
yield sse_string
|
168
|
+
|
133
169
|
yield "data: [DONE]" + end_of_line
|
134
170
|
|
135
171
|
async def fetch_gpt_response_stream(client, url, headers, payload):
|
aient/core/utils.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1
1
|
import re
|
2
2
|
import io
|
3
3
|
import os
|
4
|
+
import ast
|
4
5
|
import json
|
5
6
|
import httpx
|
6
7
|
import base64
|
@@ -65,7 +66,7 @@ def get_engine(provider, endpoint=None, original_model=""):
|
|
65
66
|
stream = None
|
66
67
|
if parsed_url.path.endswith("/v1beta") or parsed_url.path.endswith("/v1") or parsed_url.netloc == 'generativelanguage.googleapis.com':
|
67
68
|
engine = "gemini"
|
68
|
-
elif parsed_url.netloc.rstrip('/').endswith('aiplatform.googleapis.com'):
|
69
|
+
elif parsed_url.netloc.rstrip('/').endswith('aiplatform.googleapis.com') or (parsed_url.netloc.rstrip('/').endswith('gateway.ai.cloudflare.com') and "google-vertex-ai" in parsed_url.path):
|
69
70
|
engine = "vertex"
|
70
71
|
elif parsed_url.netloc.rstrip('/').endswith('openai.azure.com') or parsed_url.netloc.rstrip('/').endswith('services.ai.azure.com'):
|
71
72
|
engine = "azure"
|
@@ -680,4 +681,34 @@ async def get_text_message(message, engine = None):
|
|
680
681
|
return message
|
681
682
|
if engine == "cohere":
|
682
683
|
return message
|
683
|
-
raise ValueError("Unknown engine")
|
684
|
+
raise ValueError("Unknown engine")
|
685
|
+
|
686
|
+
def parse_json_safely(json_str):
|
687
|
+
"""
|
688
|
+
尝试解析JSON字符串,先使用ast.literal_eval,失败则使用json.loads
|
689
|
+
|
690
|
+
Args:
|
691
|
+
json_str: 要解析的JSON字符串
|
692
|
+
|
693
|
+
Returns:
|
694
|
+
解析后的Python对象
|
695
|
+
|
696
|
+
Raises:
|
697
|
+
Exception: 当两种方法都失败时抛出异常
|
698
|
+
"""
|
699
|
+
try:
|
700
|
+
# 首先尝试使用ast.literal_eval解析
|
701
|
+
return ast.literal_eval(json_str)
|
702
|
+
except (SyntaxError, ValueError):
|
703
|
+
try:
|
704
|
+
# 如果失败,尝试使用json.loads解析
|
705
|
+
return json.loads(json_str, strict=False)
|
706
|
+
except json.JSONDecodeError as e:
|
707
|
+
# 两种方法都失败,抛出异常
|
708
|
+
raise Exception(f"无法解析JSON字符串: {e}")
|
709
|
+
|
710
|
+
if __name__ == "__main__":
|
711
|
+
provider = {
|
712
|
+
"base_url": "https://gateway.ai.cloudflare.com/v1/%7Baccount_id%7D/%7Bgateway_id%7D/google-vertex-ai",
|
713
|
+
}
|
714
|
+
print(get_engine(provider))
|
aient/models/base.py
CHANGED
aient/models/chatgpt.py
CHANGED
@@ -445,7 +445,7 @@ class chatgpt(BaseLLM):
|
|
445
445
|
async for chunk in get_tools_result_async(
|
446
446
|
tool_name, tool_args, function_call_max_tokens,
|
447
447
|
model or self.engine, chatgpt, kwargs.get('api_key', self.api_key),
|
448
|
-
self.api_url, use_plugins=False, model=model or self.engine,
|
448
|
+
kwargs.get('api_url', self.api_url.chat_url), use_plugins=False, model=model or self.engine,
|
449
449
|
add_message=self.add_to_conversation, convo_id=convo_id, language=language
|
450
450
|
):
|
451
451
|
yield chunk
|
@@ -454,7 +454,7 @@ class chatgpt(BaseLLM):
|
|
454
454
|
async for chunk in get_tools_result_async(
|
455
455
|
tool_name, tool_args, function_call_max_tokens,
|
456
456
|
model or self.engine, chatgpt, kwargs.get('api_key', self.api_key),
|
457
|
-
self.api_url, use_plugins=False, model=model or self.engine,
|
457
|
+
kwargs.get('api_url', self.api_url.chat_url), use_plugins=False, model=model or self.engine,
|
458
458
|
add_message=self.add_to_conversation, convo_id=convo_id, language=language
|
459
459
|
):
|
460
460
|
yield chunk
|
aient/plugins/config.py
CHANGED
@@ -13,7 +13,7 @@ async def get_tools_result_async(function_call_name, function_full_response, fun
|
|
13
13
|
if function_call_name == "get_search_results":
|
14
14
|
prompt = json.loads(function_full_response)["query"]
|
15
15
|
yield "message_search_stage_1"
|
16
|
-
llm = robot(api_key=api_key, api_url=api_url
|
16
|
+
llm = robot(api_key=api_key, api_url=api_url, engine=engine, use_plugins=use_plugins)
|
17
17
|
keywords = (await llm.ask_async(search_key_word_prompt.format(source=prompt), model=model)).split("\n")
|
18
18
|
print("keywords", keywords)
|
19
19
|
keywords = [item.replace("三行关键词是:", "") for item in keywords if "\\x" not in item if item != ""]
|
@@ -3,16 +3,16 @@ aient/core/.git,sha256=lrAcW1SxzRBUcUiuKL5tS9ykDmmTXxyLP3YYU-Y-Q-I,45
|
|
3
3
|
aient/core/__init__.py,sha256=NxjebTlku35S4Dzr16rdSqSTWUvvwEeACe8KvHJnjPg,34
|
4
4
|
aient/core/log_config.py,sha256=kz2_yJv1p-o3lUQOwA3qh-LSc3wMHv13iCQclw44W9c,274
|
5
5
|
aient/core/models.py,sha256=H3_XuWA7aS25MWZPK1c-5RBiiuxWJbTfE3RAk0Pkc9A,7504
|
6
|
-
aient/core/request.py,sha256=
|
7
|
-
aient/core/response.py,sha256=
|
8
|
-
aient/core/utils.py,sha256=
|
6
|
+
aient/core/request.py,sha256=c963OuUEBe7j1jxiiwipUyzGrbsCwXQIw_XGF_KdL-4,49491
|
7
|
+
aient/core/response.py,sha256=uK6a--vHDk61yB4a-0og36S-d4FSO5X6cLeYSxY9G-A,27726
|
8
|
+
aient/core/utils.py,sha256=DFpFU8Y-8lzgQlhaDUnao8HmviGoh3-oN8jZR3Dha7E,26150
|
9
9
|
aient/core/test/test_base_api.py,sha256=CjfFzMG26r8C4xCPoVkKb3Ac6pp9gy5NUCbZJHoSSsM,393
|
10
10
|
aient/core/test/test_image.py,sha256=_T4peNGdXKBHHxyQNx12u-NTyFE8TlYI6NvvagsG2LE,319
|
11
11
|
aient/core/test/test_payload.py,sha256=8jBiJY1uidm1jzL-EiK0s6UGmW9XkdsuuKFGrwFhFkw,2755
|
12
12
|
aient/models/__init__.py,sha256=ouNDNvoBBpIFrLsk09Q_sq23HR0GbLAKfGLIFmfEuXE,219
|
13
13
|
aient/models/audio.py,sha256=kRd-8-WXzv4vwvsTGwnstK-WR8--vr9CdfCZzu8y9LA,1934
|
14
|
-
aient/models/base.py,sha256=
|
15
|
-
aient/models/chatgpt.py,sha256=
|
14
|
+
aient/models/base.py,sha256=PBfQTQI73OE2OuHvw-XN6vD-_y6k_vLjmT_8J4Gelvo,6813
|
15
|
+
aient/models/chatgpt.py,sha256=TRvUpgBsQW0oJpxOpT0__fYooy2hbkMSYVdM7CqusnY,42296
|
16
16
|
aient/models/claude.py,sha256=thK9P8qkaaoUN3OOJ9Shw4KDs-pAGKPoX4FOPGFXva8,28597
|
17
17
|
aient/models/duckduckgo.py,sha256=1l7vYCs9SG5SWPCbcl7q6pCcB5AUF_r-a4l9frz3Ogo,8115
|
18
18
|
aient/models/gemini.py,sha256=chGLc-8G_DAOxr10HPoOhvVFW1RvMgHd6mt--VyAW98,14730
|
@@ -20,7 +20,7 @@ aient/models/groq.py,sha256=2JCB0QE1htOprJHI5fZ11R2RtOhsHlsTjbmFyzc8oSM,10084
|
|
20
20
|
aient/models/vertex.py,sha256=qVD5l1Q538xXUPulxG4nmDjXE1VoV4yuAkTCpIeJVw0,16795
|
21
21
|
aient/plugins/__init__.py,sha256=KrCM6kFD1NB96hfhwUZIG8vJcdZVnfpACMew5YOWxSo,956
|
22
22
|
aient/plugins/arXiv.py,sha256=yHjb6PS3GUWazpOYRMKMzghKJlxnZ5TX8z9F6UtUVow,1461
|
23
|
-
aient/plugins/config.py,sha256=
|
23
|
+
aient/plugins/config.py,sha256=1eOFH0KVDV5qa-n_r-nyTx4sCstlcmoiGoaHqPjaSpI,7528
|
24
24
|
aient/plugins/excute_command.py,sha256=eAoBR6OmEbP7nzUScfRHHK3UwypuE5lxamUro8HmBMk,911
|
25
25
|
aient/plugins/get_time.py,sha256=Ih5XIW5SDAIhrZ9W4Qe5Hs1k4ieKPUc_LAd6ySNyqZk,654
|
26
26
|
aient/plugins/image.py,sha256=ZElCIaZznE06TN9xW3DrSukS7U3A5_cjk1Jge4NzPxw,2072
|
@@ -34,8 +34,8 @@ aient/prompt/agent.py,sha256=fYaLRcMZHgM35IIJkeYGeOpJvIhzfXLhpPR4Q3CYIRU,22258
|
|
34
34
|
aient/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
35
35
|
aient/utils/prompt.py,sha256=UcSzKkFE4-h_1b6NofI6xgk3GoleqALRKY8VBaXLjmI,11311
|
36
36
|
aient/utils/scripts.py,sha256=XCXMRdpWRJb34Znk4t9JkFnvzDzGHVA5Vv5WpUgP2_0,27152
|
37
|
-
aient-1.0.
|
38
|
-
aient-1.0.
|
39
|
-
aient-1.0.
|
40
|
-
aient-1.0.
|
41
|
-
aient-1.0.
|
37
|
+
aient-1.0.51.dist-info/licenses/LICENSE,sha256=XNdbcWldt0yaNXXWB_Bakoqnxb3OVhUft4MgMA_71ds,1051
|
38
|
+
aient-1.0.51.dist-info/METADATA,sha256=t_zvFpUBihcSS_ke8pZk0DltHeTBj_r8af3f7sO5gAc,4973
|
39
|
+
aient-1.0.51.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
|
40
|
+
aient-1.0.51.dist-info/top_level.txt,sha256=3oXzrP5sAVvyyqabpeq8A2_vfMtY554r4bVE-OHBrZk,6
|
41
|
+
aient-1.0.51.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|