aient 1.1.60__tar.gz → 1.1.62__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {aient-1.1.60 → aient-1.1.62}/PKG-INFO +1 -1
- {aient-1.1.60 → aient-1.1.62}/aient/models/base.py +31 -15
- {aient-1.1.60 → aient-1.1.62}/aient/models/chatgpt.py +79 -37
- {aient-1.1.60 → aient-1.1.62}/aient/utils/scripts.py +2 -0
- {aient-1.1.60 → aient-1.1.62}/aient.egg-info/PKG-INFO +1 -1
- {aient-1.1.60 → aient-1.1.62}/pyproject.toml +1 -1
- {aient-1.1.60 → aient-1.1.62}/LICENSE +0 -0
- {aient-1.1.60 → aient-1.1.62}/README.md +0 -0
- {aient-1.1.60 → aient-1.1.62}/aient/__init__.py +0 -0
- {aient-1.1.60 → aient-1.1.62}/aient/core/__init__.py +0 -0
- {aient-1.1.60 → aient-1.1.62}/aient/core/log_config.py +0 -0
- {aient-1.1.60 → aient-1.1.62}/aient/core/models.py +0 -0
- {aient-1.1.60 → aient-1.1.62}/aient/core/request.py +0 -0
- {aient-1.1.60 → aient-1.1.62}/aient/core/response.py +0 -0
- {aient-1.1.60 → aient-1.1.62}/aient/core/test/test_base_api.py +0 -0
- {aient-1.1.60 → aient-1.1.62}/aient/core/test/test_geminimask.py +0 -0
- {aient-1.1.60 → aient-1.1.62}/aient/core/test/test_image.py +0 -0
- {aient-1.1.60 → aient-1.1.62}/aient/core/test/test_payload.py +0 -0
- {aient-1.1.60 → aient-1.1.62}/aient/core/utils.py +0 -0
- {aient-1.1.60 → aient-1.1.62}/aient/models/__init__.py +0 -0
- {aient-1.1.60 → aient-1.1.62}/aient/models/audio.py +0 -0
- {aient-1.1.60 → aient-1.1.62}/aient/plugins/__init__.py +0 -0
- {aient-1.1.60 → aient-1.1.62}/aient/plugins/arXiv.py +0 -0
- {aient-1.1.60 → aient-1.1.62}/aient/plugins/config.py +0 -0
- {aient-1.1.60 → aient-1.1.62}/aient/plugins/excute_command.py +0 -0
- {aient-1.1.60 → aient-1.1.62}/aient/plugins/get_time.py +0 -0
- {aient-1.1.60 → aient-1.1.62}/aient/plugins/image.py +0 -0
- {aient-1.1.60 → aient-1.1.62}/aient/plugins/list_directory.py +0 -0
- {aient-1.1.60 → aient-1.1.62}/aient/plugins/read_file.py +0 -0
- {aient-1.1.60 → aient-1.1.62}/aient/plugins/read_image.py +0 -0
- {aient-1.1.60 → aient-1.1.62}/aient/plugins/readonly.py +0 -0
- {aient-1.1.60 → aient-1.1.62}/aient/plugins/registry.py +0 -0
- {aient-1.1.60 → aient-1.1.62}/aient/plugins/run_python.py +0 -0
- {aient-1.1.60 → aient-1.1.62}/aient/plugins/websearch.py +0 -0
- {aient-1.1.60 → aient-1.1.62}/aient/plugins/write_file.py +0 -0
- {aient-1.1.60 → aient-1.1.62}/aient/utils/__init__.py +0 -0
- {aient-1.1.60 → aient-1.1.62}/aient/utils/prompt.py +0 -0
- {aient-1.1.60 → aient-1.1.62}/aient.egg-info/SOURCES.txt +0 -0
- {aient-1.1.60 → aient-1.1.62}/aient.egg-info/dependency_links.txt +0 -0
- {aient-1.1.60 → aient-1.1.62}/aient.egg-info/requires.txt +0 -0
- {aient-1.1.60 → aient-1.1.62}/aient.egg-info/top_level.txt +0 -0
- {aient-1.1.60 → aient-1.1.62}/setup.cfg +0 -0
- {aient-1.1.60 → aient-1.1.62}/test/test_Web_crawler.py +0 -0
- {aient-1.1.60 → aient-1.1.62}/test/test_ddg_search.py +0 -0
- {aient-1.1.60 → aient-1.1.62}/test/test_google_search.py +0 -0
- {aient-1.1.60 → aient-1.1.62}/test/test_ollama.py +0 -0
- {aient-1.1.60 → aient-1.1.62}/test/test_plugin.py +0 -0
- {aient-1.1.60 → aient-1.1.62}/test/test_search.py +0 -0
- {aient-1.1.60 → aient-1.1.62}/test/test_url.py +0 -0
- {aient-1.1.60 → aient-1.1.62}/test/test_whisper.py +0 -0
- {aient-1.1.60 → aient-1.1.62}/test/test_yjh.py +0 -0
@@ -53,20 +53,10 @@ class BaseLLM:
|
|
53
53
|
"https": proxy,
|
54
54
|
},
|
55
55
|
)
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
self.aclient = httpx.AsyncClient(
|
61
|
-
follow_redirects=True,
|
62
|
-
proxies=proxy,
|
63
|
-
timeout=timeout,
|
64
|
-
)
|
65
|
-
else:
|
66
|
-
self.aclient = httpx.AsyncClient(
|
67
|
-
follow_redirects=True,
|
68
|
-
timeout=timeout,
|
69
|
-
)
|
56
|
+
self._aclient = None
|
57
|
+
self._proxy = proxy
|
58
|
+
self._timeout = timeout
|
59
|
+
self._loop = None
|
70
60
|
|
71
61
|
self.conversation: dict[str, list[dict]] = {
|
72
62
|
"default": [
|
@@ -83,6 +73,33 @@ class BaseLLM:
|
|
83
73
|
self.use_plugins = use_plugins
|
84
74
|
self.print_log: bool = print_log
|
85
75
|
|
76
|
+
def _get_aclient(self):
    """Return the cached ``httpx.AsyncClient``, creating it lazily when needed.

    A new client is built when none exists yet, when the cached one reports
    ``is_closed``, or when it was created for a different event loop than
    the one selected below.

    Loop selection:
      * inside a coroutine, the currently running loop is used;
      * from synchronous code (no running loop), the loop remembered in
        ``self._loop`` is reused as long as it has not been closed; only
        when there is none (or it is closed) is a new loop created and
        installed with ``asyncio.set_event_loop``.

    Returns:
        The client stored on ``self._aclient``.
    """
    import asyncio

    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        # Synchronous caller.  Reuse the remembered loop instead of creating
        # a new one on every call: a fresh loop is never identical to
        # ``self._loop``, so it would also force a new client each time and
        # leak both the loop and the previous client.
        loop = self._loop
        if loop is None or loop.is_closed():
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

    needs_new_client = (
        self._aclient is None
        or self._aclient.is_closed
        or self._loop is not loop
    )
    if needs_new_client:
        # NOTE(review): a previously cached client that is replaced here is
        # not closed; ``aclose()`` cannot be awaited from this sync method.
        self._loop = loop
        proxy = self._proxy or os.environ.get("all_proxy") or os.environ.get("ALL_PROXY") or None
        # socks5h:// URLs are not handed to httpx; no proxy is used instead.
        # NOTE(review): presumably the installed httpx cannot handle that
        # scheme — confirm, since this silently disables the proxy.
        proxies = proxy if proxy and "socks5h" not in proxy else None
        self._aclient = httpx.AsyncClient(
            follow_redirects=True,
            proxy=proxies,
            timeout=self._timeout,
        )
    return self._aclient


@property
def aclient(self):
    """Lazily created ``httpx.AsyncClient``; delegates to ``_get_aclient``."""
    return self._get_aclient()
|
102
|
+
|
86
103
|
def add_to_conversation(
|
87
104
|
self,
|
88
105
|
message: list,
|
@@ -196,7 +213,6 @@ class BaseLLM:
|
|
196
213
|
**kwargs,
|
197
214
|
):
|
198
215
|
response += chunk
|
199
|
-
# full_response: str = "".join([r async for r in response])
|
200
216
|
full_response: str = "".join(response)
|
201
217
|
return full_response
|
202
218
|
|
@@ -17,7 +17,7 @@ from ..plugins.registry import registry
|
|
17
17
|
from ..plugins import PLUGINS, get_tools_result_async, function_call_list, update_tools_config
|
18
18
|
from ..utils.scripts import safe_get, async_generator_to_sync, parse_function_xml, parse_continuous_json, convert_functions_to_xml, remove_xml_tags_and_content
|
19
19
|
from ..core.request import prepare_request_payload
|
20
|
-
from ..core.response import fetch_response_stream
|
20
|
+
from ..core.response import fetch_response_stream, fetch_response
|
21
21
|
|
22
22
|
def get_filtered_keys_from_object(obj: object, *keys: str) -> Set[str]:
|
23
23
|
"""
|
@@ -288,6 +288,7 @@ class chatgpt(BaseLLM):
|
|
288
288
|
convo_id: str = "default",
|
289
289
|
model: str = "",
|
290
290
|
pass_history: int = 9999,
|
291
|
+
stream: bool = True,
|
291
292
|
**kwargs,
|
292
293
|
):
|
293
294
|
self.conversation[convo_id][0] = {"role": "system","content": self.system_prompt + "\n\n" + self.get_latest_file_content()}
|
@@ -309,12 +310,13 @@ class chatgpt(BaseLLM):
|
|
309
310
|
{"role": "system","content": self.system_prompt + "\n\n" + self.get_latest_file_content()},
|
310
311
|
{"role": role, "content": prompt}
|
311
312
|
],
|
312
|
-
"stream":
|
313
|
-
"stream_options": {
|
314
|
-
"include_usage": True
|
315
|
-
},
|
313
|
+
"stream": stream,
|
316
314
|
"temperature": kwargs.get("temperature", self.temperature)
|
317
315
|
}
|
316
|
+
if stream:
|
317
|
+
request_data["stream_options"] = {
|
318
|
+
"include_usage": True
|
319
|
+
}
|
318
320
|
|
319
321
|
if kwargs.get("max_tokens", self.max_tokens):
|
320
322
|
request_data["max_tokens"] = kwargs.get("max_tokens", self.max_tokens)
|
@@ -687,6 +689,7 @@ class chatgpt(BaseLLM):
|
|
687
689
|
function_call_id: str = "",
|
688
690
|
language: str = "English",
|
689
691
|
system_prompt: str = None,
|
692
|
+
stream: bool = True,
|
690
693
|
**kwargs,
|
691
694
|
):
|
692
695
|
"""
|
@@ -702,17 +705,20 @@ class chatgpt(BaseLLM):
|
|
702
705
|
json_post = None
|
703
706
|
async def get_post_body_async():
|
704
707
|
nonlocal json_post
|
705
|
-
url, headers, json_post, engine_type = await self.get_post_body(prompt, role, convo_id, model, pass_history, **kwargs)
|
708
|
+
url, headers, json_post, engine_type = await self.get_post_body(prompt, role, convo_id, model, pass_history, stream=stream, **kwargs)
|
706
709
|
return url, headers, json_post, engine_type
|
707
710
|
|
708
711
|
# 替换原来的获取请求体的代码
|
709
712
|
# json_post = next(async_generator_to_sync(get_post_body_async()))
|
710
713
|
try:
|
711
|
-
url, headers, json_post, engine_type = asyncio.run(get_post_body_async())
|
712
|
-
except RuntimeError:
|
713
|
-
# 如果已经在事件循环中,则使用不同的方法
|
714
714
|
loop = asyncio.get_event_loop()
|
715
|
-
|
715
|
+
if loop.is_closed():
|
716
|
+
loop = asyncio.new_event_loop()
|
717
|
+
asyncio.set_event_loop(loop)
|
718
|
+
except RuntimeError:
|
719
|
+
loop = asyncio.new_event_loop()
|
720
|
+
asyncio.set_event_loop(loop)
|
721
|
+
url, headers, json_post, engine_type = loop.run_until_complete(get_post_body_async())
|
716
722
|
|
717
723
|
self.truncate_conversation(convo_id=convo_id)
|
718
724
|
|
@@ -760,14 +766,24 @@ class chatgpt(BaseLLM):
|
|
760
766
|
yield f"data: {json.dumps(tmp_response)}\n\n"
|
761
767
|
async_generator = _mock_response_generator()
|
762
768
|
else:
|
763
|
-
|
764
|
-
|
765
|
-
|
766
|
-
|
767
|
-
|
768
|
-
|
769
|
-
|
770
|
-
|
769
|
+
if stream:
|
770
|
+
async_generator = fetch_response_stream(
|
771
|
+
self.aclient,
|
772
|
+
url,
|
773
|
+
headers,
|
774
|
+
json_post,
|
775
|
+
engine_type,
|
776
|
+
model or self.engine,
|
777
|
+
)
|
778
|
+
else:
|
779
|
+
async_generator = fetch_response(
|
780
|
+
self.aclient,
|
781
|
+
url,
|
782
|
+
headers,
|
783
|
+
json_post,
|
784
|
+
engine_type,
|
785
|
+
model or self.engine,
|
786
|
+
)
|
771
787
|
# 异步处理响应流
|
772
788
|
async for chunk in self._process_stream_response(
|
773
789
|
async_generator,
|
@@ -817,6 +833,7 @@ class chatgpt(BaseLLM):
|
|
817
833
|
function_call_id: str = "",
|
818
834
|
language: str = "English",
|
819
835
|
system_prompt: str = None,
|
836
|
+
stream: bool = True,
|
820
837
|
**kwargs,
|
821
838
|
):
|
822
839
|
"""
|
@@ -829,7 +846,7 @@ class chatgpt(BaseLLM):
|
|
829
846
|
self.add_to_conversation(prompt, role, convo_id=convo_id, function_name=function_name, total_tokens=total_tokens, function_arguments=function_arguments, pass_history=pass_history, function_call_id=function_call_id)
|
830
847
|
|
831
848
|
# 获取请求体
|
832
|
-
url, headers, json_post, engine_type = await self.get_post_body(prompt, role, convo_id, model, pass_history, **kwargs)
|
849
|
+
url, headers, json_post, engine_type = await self.get_post_body(prompt, role, convo_id, model, pass_history, stream=stream, **kwargs)
|
833
850
|
self.truncate_conversation(convo_id=convo_id)
|
834
851
|
|
835
852
|
# 打印日志
|
@@ -874,24 +891,24 @@ class chatgpt(BaseLLM):
|
|
874
891
|
yield f"data: {json.dumps(tmp_response)}\n\n"
|
875
892
|
generator = _mock_response_generator()
|
876
893
|
else:
|
877
|
-
|
878
|
-
|
879
|
-
|
880
|
-
|
881
|
-
|
882
|
-
|
883
|
-
|
884
|
-
|
885
|
-
|
886
|
-
|
887
|
-
|
888
|
-
|
889
|
-
|
890
|
-
|
891
|
-
|
892
|
-
|
893
|
-
|
894
|
-
|
894
|
+
if stream:
|
895
|
+
generator = fetch_response_stream(
|
896
|
+
self.aclient,
|
897
|
+
url,
|
898
|
+
headers,
|
899
|
+
json_post,
|
900
|
+
engine_type,
|
901
|
+
model or self.engine,
|
902
|
+
)
|
903
|
+
else:
|
904
|
+
generator = fetch_response(
|
905
|
+
self.aclient,
|
906
|
+
url,
|
907
|
+
headers,
|
908
|
+
json_post,
|
909
|
+
engine_type,
|
910
|
+
model or self.engine,
|
911
|
+
)
|
895
912
|
|
896
913
|
# 处理正常响应
|
897
914
|
async for processed_chunk in self._process_stream_response(
|
@@ -943,11 +960,36 @@ class chatgpt(BaseLLM):
|
|
943
960
|
convo_id=convo_id,
|
944
961
|
pass_history=pass_history,
|
945
962
|
model=model or self.engine,
|
963
|
+
stream=False,
|
946
964
|
**kwargs,
|
947
965
|
)
|
948
966
|
full_response: str = "".join([r async for r in response])
|
949
967
|
return full_response
|
950
968
|
|
969
|
+
def ask(
    self,
    prompt: str,
    role: str = "user",
    convo_id: str = "default",
    model: str = "",
    pass_history: int = 9999,
    **kwargs,
) -> str:
    """Send ``prompt`` and return the whole reply as a single string.

    Blocking, non-streaming counterpart of ``ask_stream``: it calls
    ``ask_stream`` with ``stream=False`` and concatenates every chunk that
    call yields.

    Args:
        prompt: Message content to send.
        role: Role of the message (default ``"user"``).
        convo_id: Conversation the message belongs to.
        model: Model name; an empty string falls back to ``self.engine``.
        pass_history: Forwarded unchanged to ``ask_stream``.
        **kwargs: Extra options forwarded unchanged to ``ask_stream``.

    Returns:
        The concatenated response text.
    """
    response = self.ask_stream(
        prompt=prompt,
        role=role,
        convo_id=convo_id,
        pass_history=pass_history,
        model=model or self.engine,
        stream=False,
        **kwargs,
    )
    # str.join consumes the iterator directly; no intermediate list copy.
    return "".join(response)
|
992
|
+
|
951
993
|
def rollback(self, n: int = 1, convo_id: str = "default") -> None:
|
952
994
|
"""
|
953
995
|
Rollback the conversation
|
@@ -212,6 +212,8 @@ def async_generator_to_sync(async_gen):
|
|
212
212
|
# 清理所有待处理的任务
|
213
213
|
tasks = [t for t in asyncio.all_tasks(loop) if not t.done()]
|
214
214
|
if tasks:
|
215
|
+
for task in tasks:
|
216
|
+
task.cancel()
|
215
217
|
loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True))
|
216
218
|
loop.run_until_complete(loop.shutdown_asyncgens())
|
217
219
|
loop.close()
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|