beswarm-0.1.37-py3-none-any.whl → beswarm-0.1.39-py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- beswarm/aient/setup.py +1 -1
- beswarm/aient/src/aient/models/claude.py +0 -67
- beswarm/aient/src/aient/models/groq.py +0 -34
- beswarm/aient/src/aient/plugins/config.py +3 -12
- beswarm/aient/src/aient/plugins/websearch.py +1 -1
- beswarm/aient/src/aient/utils/scripts.py +0 -23
- {beswarm-0.1.37.dist-info → beswarm-0.1.39.dist-info}/METADATA +3 -3
- {beswarm-0.1.37.dist-info → beswarm-0.1.39.dist-info}/RECORD +10 -11
- beswarm/aient/test/test_langchain_search_old.py +0 -235
- {beswarm-0.1.37.dist-info → beswarm-0.1.39.dist-info}/WHEEL +0 -0
- {beswarm-0.1.37.dist-info → beswarm-0.1.39.dist-info}/top_level.txt +0 -0
beswarm/aient/setup.py
CHANGED
@@ -4,7 +4,7 @@ from setuptools import setup, find_packages
 
 setup(
     name="aient",
-    version="1.0.
+    version="1.0.95",
     description="Aient: The Awakening of Agent.",
     long_description=Path.open(Path("README.md"), encoding="utf-8").read(),
     long_description_content_type="text/markdown",
beswarm/aient/src/aient/models/claude.py
CHANGED
@@ -2,7 +2,6 @@ import os
 import re
 import json
 import copy
-import tiktoken
 import requests
 
 from .base import BaseLLM
@@ -65,39 +64,6 @@ class claude(BaseLLM):
         self.conversation[convo_id] = claudeConversation()
         self.system_prompt = system_prompt or self.system_prompt
 
-    def __truncate_conversation(self, convo_id: str = "default") -> None:
-        """
-        Truncate the conversation
-        """
-        while True:
-            if (
-                self.get_token_count(convo_id) > self.truncate_limit
-                and len(self.conversation[convo_id]) > 1
-            ):
-                # Don't remove the first message
-                self.conversation[convo_id].pop(1)
-            else:
-                break
-
-    def get_token_count(self, convo_id: str = "default") -> int:
-        """
-        Get token count
-        """
-        tiktoken.model.MODEL_TO_ENCODING["claude-2.1"] = "cl100k_base"
-        encoding = tiktoken.encoding_for_model(self.engine)
-
-        num_tokens = 0
-        for message in self.conversation[convo_id]:
-            # every message follows <im_start>{role/name}\n{content}<im_end>\n
-            num_tokens += 5
-            for key, value in message.items():
-                if value:
-                    num_tokens += len(encoding.encode(value))
-                if key == "name":  # if there's a name, the role is omitted
-                    num_tokens += 5  # role is always required and always 1 token
-        num_tokens += 5  # every reply is primed with <im_start>assistant
-        return num_tokens
-
     def ask_stream(
         self,
         prompt: str,
@@ -267,39 +233,6 @@ class claude3(BaseLLM):
         self.conversation[convo_id] = list()
         self.system_prompt = system_prompt or self.system_prompt
 
-    def __truncate_conversation(self, convo_id: str = "default") -> None:
-        """
-        Truncate the conversation
-        """
-        while True:
-            if (
-                self.get_token_count(convo_id) > self.truncate_limit
-                and len(self.conversation[convo_id]) > 1
-            ):
-                # Don't remove the first message
-                self.conversation[convo_id].pop(1)
-            else:
-                break
-
-    def get_token_count(self, convo_id: str = "default") -> int:
-        """
-        Get token count
-        """
-        tiktoken.model.MODEL_TO_ENCODING["claude-2.1"] = "cl100k_base"
-        encoding = tiktoken.encoding_for_model(self.engine)
-
-        num_tokens = 0
-        for message in self.conversation[convo_id]:
-            # every message follows <im_start>{role/name}\n{content}<im_end>\n
-            num_tokens += 5
-            for key, value in message.items():
-                if value:
-                    num_tokens += len(encoding.encode(value))
-                if key == "name":  # if there's a name, the role is omitted
-                    num_tokens += 5  # role is always required and always 1 token
-        num_tokens += 5  # every reply is primed with <im_start>assistant
-        return num_tokens
-
     def ask_stream(
         self,
         prompt: str,
beswarm/aient/src/aient/models/groq.py
CHANGED
@@ -1,7 +1,6 @@
 import os
 import json
 import requests
-import tiktoken
 
 from .base import BaseLLM
 
@@ -52,39 +51,6 @@ class groq(BaseLLM):
         self.conversation[convo_id] = list()
         self.system_prompt = system_prompt or self.system_prompt
 
-    def __truncate_conversation(self, convo_id: str = "default") -> None:
-        """
-        Truncate the conversation
-        """
-        while True:
-            if (
-                self.get_token_count(convo_id) > self.truncate_limit
-                and len(self.conversation[convo_id]) > 1
-            ):
-                # Don't remove the first message
-                self.conversation[convo_id].pop(1)
-            else:
-                break
-
-    def get_token_count(self, convo_id: str = "default") -> int:
-        """
-        Get token count
-        """
-        # tiktoken.model.MODEL_TO_ENCODING["mixtral-8x7b-32768"] = "cl100k_base"
-        encoding = tiktoken.get_encoding("cl100k_base")
-
-        num_tokens = 0
-        for message in self.conversation[convo_id]:
-            # every message follows <im_start>{role/name}\n{content}<im_end>\n
-            num_tokens += 5
-            for key, value in message.items():
-                if value:
-                    num_tokens += len(encoding.encode(value))
-                if key == "name":  # if there's a name, the role is omitted
-                    num_tokens += 5  # role is always required and always 1 token
-        num_tokens += 5  # every reply is primed with <im_start>assistant
-        return num_tokens
-
     def ask_stream(
         self,
         prompt: str,
beswarm/aient/src/aient/plugins/config.py
CHANGED
@@ -3,8 +3,7 @@ import json
 import inspect
 
 from .registry import registry
-from ..utils.
-from ..utils.prompt import search_key_word_prompt, arxiv_doc_user_prompt
+from ..utils.prompt import search_key_word_prompt
 
 async def get_tools_result_async(function_call_name, function_full_response, function_call_max_tokens, engine, robot, api_key, api_url, use_plugins, model, add_message, convo_id, language):
     function_response = ""
@@ -26,10 +25,7 @@ async def get_tools_result_async(function_call_name, function_full_response, fun
                 yield chunk
             else:
                 function_response = "\n\n".join(chunk)
-
-        # function_response = yield from eval(function_call_name)(prompt, keywords)
-        function_call_max_tokens = 32000
-        function_response, text_len = cut_message(function_response, function_call_max_tokens, engine)
+
         if function_response:
             function_response = (
                 f"You need to response the following question: {prompt}. Search results is provided inside <Search_results></Search_results> XML tags. Your task is to think about the question step by step and then answer the above question in {language} based on the Search results provided. Please response in {language} and adopt a style that is logical, in-depth, and detailed. Note: In order to make the answer appear highly professional, you should be an expert in textual analysis, aiming to make the answer precise and comprehensive. Directly response markdown format, without using markdown code blocks. For each sentence quoting search results, a markdown ordered superscript number url link must be used to indicate the source, e.g., [¹](https://www.example.com)"
@@ -40,18 +36,13 @@ async def get_tools_result_async(function_call_name, function_full_response, fun
             ).format(function_response)
         else:
             function_response = "无法找到相关信息,停止使用 tools"
-
-        # self.add_to_conversation(user_prompt, "user", convo_id=convo_id)
+
     elif function_to_call:
         prompt = json.loads(function_full_response)
         if inspect.iscoroutinefunction(function_to_call):
            function_response = await function_to_call(**prompt)
         else:
            function_response = function_to_call(**prompt)
-        function_response, text_len = cut_message(function_response, function_call_max_tokens, engine)
-
-        # if function_call_name == "download_read_arxiv_pdf":
-        #     add_message(arxiv_doc_user_prompt, "user", convo_id=convo_id)
 
     function_response = (
         f"function_response:{function_response}"
beswarm/aient/src/aient/plugins/websearch.py
CHANGED
@@ -6,7 +6,6 @@ import threading
 import time as record_time
 from itertools import islice
 from bs4 import BeautifulSoup
-from duckduckgo_search import DDGS
 from .registry import register_tool
 
 class ThreadWithReturnValue(threading.Thread):
@@ -178,6 +177,7 @@ def get_url_content(url: str) -> str:
 
 def getddgsearchurl(query, max_results=4):
     try:
+        from duckduckgo_search import DDGS
         results = []
         with DDGS() as ddgs:
             ddgs_gen = ddgs.text(query, safesearch='Off', timelimit='y', backend="lite")
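The two websearch.py hunks above move the duckduckgo_search import from module level to inside getddgsearchurl(), so the plugin module can still be imported when the optional dependency is absent. A minimal sketch of that deferred-import pattern (the ImportError guard and error message are illustrative additions, not code from this diff):

```python
# Sketch: import the optional dependency only when the search tool is called.
def getddgsearchurl(query, max_results=4):
    try:
        from duckduckgo_search import DDGS  # deferred, optional dependency
    except ImportError as exc:  # illustrative guard, not present in the diff
        raise RuntimeError(
            'duckduckgo-search is not installed; install the "search" extra'
        ) from exc
    urls = []
    with DDGS() as ddgs:
        # same text() call as the real function above
        for result in ddgs.text(query, safesearch='Off', timelimit='y', backend="lite"):
            urls.append(result.get("href"))
            if len(urls) >= max_results:
                break
    return urls
```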
beswarm/aient/src/aient/utils/scripts.py
CHANGED
@@ -1,33 +1,10 @@
 import os
 import json
-import base64
-import tiktoken
 import requests
 import urllib.parse
 
 from ..core.utils import get_image_message
 
-def get_encode_text(text, model_name):
-    tiktoken.get_encoding("cl100k_base")
-    model_name = "gpt-3.5-turbo"
-    encoding = tiktoken.encoding_for_model(model_name)
-    encode_text = encoding.encode(text, disallowed_special=())
-    return encoding, encode_text
-
-def get_text_token_len(text, model_name):
-    encoding, encode_text = get_encode_text(text, model_name)
-    return len(encode_text)
-
-def cut_message(message: str, max_tokens: int, model_name: str):
-    if type(message) != str:
-        message = str(message)
-    encoding, encode_text = get_encode_text(message, model_name)
-    if len(encode_text) > max_tokens:
-        encode_text = encode_text[:max_tokens]
-        message = encoding.decode(encode_text)
-        encode_text = encoding.encode(message, disallowed_special=())
-    return message, len(encode_text)
-
 def get_doc_from_url(url):
     filename = urllib.parse.unquote(url.split("/")[-1])
     response = requests.get(url, stream=True)
{beswarm-0.1.37.dist-info → beswarm-0.1.39.dist-info}/METADATA
CHANGED
@@ -1,12 +1,11 @@
 Metadata-Version: 2.4
 Name: beswarm
-Version: 0.1.
+Version: 0.1.39
 Summary: MAS
 Requires-Python: >=3.11
 Description-Content-Type: text/markdown
 Requires-Dist: beautifulsoup4>=4.13.4
 Requires-Dist: diskcache>=5.6.3
-Requires-Dist: duckduckgo-search==5.3.1
 Requires-Dist: fake-useragent>=2.2.0
 Requires-Dist: fastapi>=0.115.12
 Requires-Dist: grep-ast>=0.8.1
@@ -27,8 +26,9 @@ Requires-Dist: pyperclip>=1.9.0
 Requires-Dist: pytz>=2025.2
 Requires-Dist: requests>=2.32.3
 Requires-Dist: scipy>=1.15.2
-Requires-Dist: tiktoken==0.6.0
 Requires-Dist: tqdm>=4.67.1
+Provides-Extra: search
+Requires-Dist: duckduckgo-search==5.3.1; extra == "search"
 
 # beswarm
 
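The METADATA hunks drop the hard tiktoken and duckduckgo-search requirements and re-declare the latter behind a new "search" extra. The Provides-Extra / conditional Requires-Dist pair above is what setuptools emits for an optional extra; a hypothetical packaging sketch (beswarm's actual build configuration is not part of this diff):

```python
# Hypothetical setup.py fragment: declaring duckduckgo-search under an optional
# "search" extra produces the Provides-Extra / Requires-Dist lines shown above.
from setuptools import setup, find_packages

setup(
    name="example-package",  # placeholder, not beswarm's real config
    version="0.0.1",
    packages=find_packages(),
    extras_require={
        "search": ["duckduckgo-search==5.3.1"],
    },
)
```

With the extra in place, a default install no longer pulls in duckduckgo-search; users who need the search tool opt in with something like `pip install "beswarm[search]"`.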
{beswarm-0.1.37.dist-info → beswarm-0.1.39.dist-info}/RECORD
CHANGED
@@ -1,7 +1,7 @@
 beswarm/__init__.py,sha256=HZjUOJtZR5QhMuDbq-wukQQn1VrBusNWai_ysGo-VVI,20
 beswarm/utils.py,sha256=AdDCcqAIIKQEMl7PfryVgeT9G5sHe7QNsZnrvmTGA8E,283
 beswarm/aient/main.py,sha256=SiYAIgQlLJqYusnTVEJOx1WNkSJKMImhgn5aWjfroxg,3814
-beswarm/aient/setup.py,sha256=
+beswarm/aient/setup.py,sha256=gX6fAYtVyLi9NHeEKPMGPF4IxtQ9MvEnJy1besWMS5U,487
 beswarm/aient/src/aient/__init__.py,sha256=SRfF7oDVlOOAi6nGKiJIUK6B_arqYLO9iSMp-2IZZps,21
 beswarm/aient/src/aient/core/__init__.py,sha256=NxjebTlku35S4Dzr16rdSqSTWUvvwEeACe8KvHJnjPg,34
 beswarm/aient/src/aient/core/log_config.py,sha256=kz2_yJv1p-o3lUQOwA3qh-LSc3wMHv13iCQclw44W9c,274
@@ -17,14 +17,14 @@ beswarm/aient/src/aient/models/__init__.py,sha256=ouNDNvoBBpIFrLsk09Q_sq23HR0GbL
 beswarm/aient/src/aient/models/audio.py,sha256=kRd-8-WXzv4vwvsTGwnstK-WR8--vr9CdfCZzu8y9LA,1934
 beswarm/aient/src/aient/models/base.py,sha256=z-Z0pJfTN2x0cuwfvu0BdMRY9O-RmLwHEnBIJN1x4Fg,6719
 beswarm/aient/src/aient/models/chatgpt.py,sha256=-NWkkKxTCyraPYT0YN37NA2rUfOaDNXtvFSQmIE5tS8,45066
-beswarm/aient/src/aient/models/claude.py,sha256=
+beswarm/aient/src/aient/models/claude.py,sha256=JezghW7y0brl4Y5qiSHvnYR5prQCFywX4RViHt39pGI,26037
 beswarm/aient/src/aient/models/duckduckgo.py,sha256=1l7vYCs9SG5SWPCbcl7q6pCcB5AUF_r-a4l9frz3Ogo,8115
 beswarm/aient/src/aient/models/gemini.py,sha256=chGLc-8G_DAOxr10HPoOhvVFW1RvMgHd6mt--VyAW98,14730
-beswarm/aient/src/aient/models/groq.py,sha256=
+beswarm/aient/src/aient/models/groq.py,sha256=eXfSOaPxgQEtk4U8qseArN8rFYOFBfMsPwRcDW1nERo,8790
 beswarm/aient/src/aient/models/vertex.py,sha256=qVD5l1Q538xXUPulxG4nmDjXE1VoV4yuAkTCpIeJVw0,16795
 beswarm/aient/src/aient/plugins/__init__.py,sha256=p3KO6Aa3Lupos4i2SjzLQw1hzQTigOAfEHngsldrsyk,986
 beswarm/aient/src/aient/plugins/arXiv.py,sha256=yHjb6PS3GUWazpOYRMKMzghKJlxnZ5TX8z9F6UtUVow,1461
-beswarm/aient/src/aient/plugins/config.py,sha256=
+beswarm/aient/src/aient/plugins/config.py,sha256=Vp6CG9ocdC_FAlCMEGtKj45xamir76DFxdJVvURNtog,6539
 beswarm/aient/src/aient/plugins/excute_command.py,sha256=u-JOZ21dDcDx1j3O0KVIHAsa6MNuOxHFBdV3iCnTih0,5413
 beswarm/aient/src/aient/plugins/get_time.py,sha256=Ih5XIW5SDAIhrZ9W4Qe5Hs1k4ieKPUc_LAd6ySNyqZk,654
 beswarm/aient/src/aient/plugins/image.py,sha256=ZElCIaZznE06TN9xW3DrSukS7U3A5_cjk1Jge4NzPxw,2072
@@ -32,13 +32,13 @@ beswarm/aient/src/aient/plugins/list_directory.py,sha256=5ubm-mfrj-tanGSDp4M_Tmb
 beswarm/aient/src/aient/plugins/read_file.py,sha256=cJxGnhcz1_gjkgeemVyixLUiCvf-dWm-UtDfrbFdlLE,4857
 beswarm/aient/src/aient/plugins/registry.py,sha256=YknzhieU_8nQ3oKlUSSWDB4X7t2Jx0JnqT2Jd9Xsvfk,3574
 beswarm/aient/src/aient/plugins/run_python.py,sha256=dgcUwBunMuDkaSKR5bToudVzSdrXVewktDDFUz_iIOQ,4589
-beswarm/aient/src/aient/plugins/websearch.py,sha256=
+beswarm/aient/src/aient/plugins/websearch.py,sha256=a-JJZjEZ5MEQ9WBMkD7okBHYehLSzApoLMiyqBnxDqs,15193
 beswarm/aient/src/aient/plugins/write_file.py,sha256=qmT6iQ3mDyVAa9Sld1jfJq0KPZj0w2kRIHq0JyjpGeA,1853
 beswarm/aient/src/aient/prompt/__init__.py,sha256=GBtn6-JDT8KHFCcuPpfSNE_aGddg5p4FEyMCy4BfwGs,20
 beswarm/aient/src/aient/prompt/agent.py,sha256=3VycHGnUq9OdR5pd_RM0AeLESlpAgBcmzrsesfq82X0,23856
 beswarm/aient/src/aient/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 beswarm/aient/src/aient/utils/prompt.py,sha256=UcSzKkFE4-h_1b6NofI6xgk3GoleqALRKY8VBaXLjmI,11311
-beswarm/aient/src/aient/utils/scripts.py,sha256=
+beswarm/aient/src/aient/utils/scripts.py,sha256=JbYHsU3LLtxBcuO_2MWbSgpHpCgtVQe9FGEFJpUyejc,25926
 beswarm/aient/test/chatgpt.py,sha256=Hvl7FuDt1c74N5TVBmhErOPvJbJJzA7FNp5VoZM4u30,4957
 beswarm/aient/test/claude.py,sha256=IyB4qI1eJLwlSfDNSnt2FhbQWYyBighHUjJxEXc3osQ,1095
 beswarm/aient/test/test.py,sha256=rldnoLQdtRR8IKFSIzTti7eIK2MpPMoi9gL5qD8_K44,29
@@ -58,7 +58,6 @@ beswarm/aient/test/test_get_token_dict.py,sha256=QuR67aUbS9hTwBjndsr1u7-juopEe1K
 beswarm/aient/test/test_google_search.py,sha256=rPaKqD_N3ogHYE5DrMfRmKumcVAHKC7LcYw5euR_zGM,1035
 beswarm/aient/test/test_jieba.py,sha256=ydqIrPtJ71cgbQSXpkS_g1kSiBzEpk0mjv6N-6ETw4g,1139
 beswarm/aient/test/test_json.py,sha256=cbKSwwSwt1A9sdn5vO_5cGca0x2rR4skejAgb8uDDu0,2284
-beswarm/aient/test/test_langchain_search_old.py,sha256=QGZSYi-aBB5qrKPI64qfgENbGozfrSGQBpNZpHt0d7k,9066
 beswarm/aient/test/test_logging.py,sha256=DFZ2KqrTVH6FQ5BKJIQudZxWRUdkzWka2QjmtVYPXvw,995
 beswarm/aient/test/test_ollama.py,sha256=ywy9l06S1g1AnWQvlBbhpac7i-hBB9bpwi-pk0Afivc,1325
 beswarm/aient/test/test_plugin.py,sha256=0sBwpf1YdKba-IVPZwBMKbLR7buHfudLS9NOETm7BTc,779
@@ -128,7 +127,7 @@ beswarm/tools/repomap.py,sha256=CwvwoN5Swr42EzrORTTeV8MMb7mPviy4a4b0fxBu50k,4082
 beswarm/tools/search_arxiv.py,sha256=9slwBemXjEqrd7-YgVmyMijPXlkhZCybEDRVhWVQ9B0,7937
 beswarm/tools/think.py,sha256=WLw-7jNIsnS6n8MMSYUin_f-BGLENFmnKM2LISEp0co,1760
 beswarm/tools/worker.py,sha256=FfKCx7KFNbMRoAXtjU1_nJQjx9WHny7KBq8OXSYICJs,5334
-beswarm-0.1.
-beswarm-0.1.
-beswarm-0.1.
-beswarm-0.1.
+beswarm-0.1.39.dist-info/METADATA,sha256=pOFh4a12JYNDKCeFjK6k9LoB6JGlqvc7EFVNVSWsqg8,3208
+beswarm-0.1.39.dist-info/WHEEL,sha256=0CuiUZ_p9E4cD6NyLD6UG80LBXYyiSYZOKDm5lp32xk,91
+beswarm-0.1.39.dist-info/top_level.txt,sha256=pJw4O87wvt5882smuSO6DfByJz7FJ8SxxT8h9fHCmpo,8
+beswarm-0.1.39.dist-info/RECORD,,
beswarm/aient/test/test_langchain_search_old.py
DELETED
@@ -1,235 +0,0 @@
-import os
-import re
-
-import sys
-sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
-import config
-
-from langchain.chat_models import ChatOpenAI
-
-
-from langchain.chains import RetrievalQA, RetrievalQAWithSourcesChain
-
-from langchain.prompts.chat import (
-    ChatPromptTemplate,
-    SystemMessagePromptTemplate,
-    HumanMessagePromptTemplate,
-)
-from langchain.embeddings.openai import OpenAIEmbeddings
-from langchain.vectorstores import Chroma
-from langchain.text_splitter import CharacterTextSplitter
-
-from langchain.document_loaders import UnstructuredPDFLoader
-
-def getmd5(string):
-    import hashlib
-    md5_hash = hashlib.md5()
-    md5_hash.update(string.encode('utf-8'))
-    md5_hex = md5_hash.hexdigest()
-    return md5_hex
-
-from utils.sitemap import SitemapLoader
-async def get_doc_from_sitemap(url):
-    # https://www.langchain.asia/modules/indexes/document_loaders/examples/sitemap#%E8%BF%87%E6%BB%A4%E7%AB%99%E7%82%B9%E5%9C%B0%E5%9B%BE-url-
-    sitemap_loader = SitemapLoader(web_path=url)
-    docs = await sitemap_loader.load()
-    return docs
-
-async def get_doc_from_local(docpath, doctype="md"):
-    from langchain.document_loaders import DirectoryLoader
-    # 加载文件夹中的所有txt类型的文件
-    loader = DirectoryLoader(docpath, glob='**/*.' + doctype)
-    # 将数据转成 document 对象,每个文件会作为一个 document
-    documents = loader.load()
-    return documents
-
-system_template="""Use the following pieces of context to answer the users question.
-If you don't know the answer, just say "Hmm..., I'm not sure.", don't try to make up an answer.
-ALWAYS return a "Sources" part in your answer.
-The "Sources" part should be a reference to the source of the document from which you got your answer.
-
-Example of your response should be:
-
-```
-The answer is foo
-
-Sources:
-1. abc
-2. xyz
-```
-Begin!
-----------------
-{summaries}
-"""
-messages = [
-    SystemMessagePromptTemplate.from_template(system_template),
-    HumanMessagePromptTemplate.from_template("{question}")
-]
-prompt = ChatPromptTemplate.from_messages(messages)
-
-def get_chain(store, llm):
-    chain_type_kwargs = {"prompt": prompt}
-    chain = RetrievalQAWithSourcesChain.from_chain_type(
-        llm,
-        chain_type="stuff",
-        retriever=store.as_retriever(),
-        chain_type_kwargs=chain_type_kwargs,
-        reduce_k_below_max_tokens=True
-    )
-    return chain
-
-async def docQA(docpath, query_message, persist_db_path="db", model = "gpt-3.5-turbo"):
-    chatllm = ChatOpenAI(temperature=0.5, openai_api_base=config.bot_api_url.v1_url, model_name=model, openai_api_key=config.API)
-    embeddings = OpenAIEmbeddings(openai_api_base=config.bot_api_url.v1_url, openai_api_key=config.API)
-
-    sitemap = "sitemap.xml"
-    match = re.match(r'^(https?|ftp)://[^\s/$.?#].[^\s]*$', docpath)
-    if match:
-        doc_method = get_doc_from_sitemap
-        docpath = os.path.join(docpath, sitemap)
-    else:
-        doc_method = get_doc_from_local
-
-    persist_db_path = getmd5(docpath)
-    if not os.path.exists(persist_db_path):
-        documents = await doc_method(docpath)
-        # 初始化加载器
-        text_splitter = CharacterTextSplitter(chunk_size=100, chunk_overlap=50)
-        # 持久化数据
-        split_docs = text_splitter.split_documents(documents)
-        vector_store = Chroma.from_documents(split_docs, embeddings, persist_directory=persist_db_path)
-        vector_store.persist()
-    else:
-        # 加载数据
-        vector_store = Chroma(persist_directory=persist_db_path, embedding_function=embeddings)
-
-    # 创建问答对象
-    qa = get_chain(vector_store, chatllm)
-    # qa = RetrievalQA.from_chain_type(llm=chatllm, chain_type="stuff", retriever=vector_store.as_retriever(), return_source_documents=True)
-    # 进行问答
-    result = qa({"question": query_message})
-    return result
-
-
-def persist_emdedding_pdf(docurl, persist_db_path):
-    embeddings = OpenAIEmbeddings(openai_api_base=config.bot_api_url.v1_url, openai_api_key=os.environ.get('API', None))
-    filename = get_doc_from_url(docurl)
-    docpath = os.getcwd() + "/" + filename
-    loader = UnstructuredPDFLoader(docpath)
-    documents = loader.load()
-    # 初始化加载器
-    text_splitter = CharacterTextSplitter(chunk_size=100, chunk_overlap=25)
-    # 切割加载的 document
-    split_docs = text_splitter.split_documents(documents)
-    vector_store = Chroma.from_documents(split_docs, embeddings, persist_directory=persist_db_path)
-    vector_store.persist()
-    os.remove(docpath)
-    return vector_store
-
-async def pdfQA(docurl, docpath, query_message, model="gpt-3.5-turbo"):
-    chatllm = ChatOpenAI(temperature=0.5, openai_api_base=config.bot_api_url.v1_url, model_name=model, openai_api_key=os.environ.get('API', None))
-    embeddings = OpenAIEmbeddings(openai_api_base=config.bot_api_url.v1_url, openai_api_key=os.environ.get('API', None))
-    persist_db_path = getmd5(docpath)
-    if not os.path.exists(persist_db_path):
-        vector_store = persist_emdedding_pdf(docurl, persist_db_path)
-    else:
-        vector_store = Chroma(persist_directory=persist_db_path, embedding_function=embeddings)
-    qa = RetrievalQA.from_chain_type(llm=chatllm, chain_type="stuff", retriever=vector_store.as_retriever(), return_source_documents=True)
-    result = qa({"query": query_message})
-    return result['result']
-
-
-def pdf_search(docurl, query_message, model="gpt-3.5-turbo"):
-    chatllm = ChatOpenAI(temperature=0.5, openai_api_base=config.bot_api_url.v1_url, model_name=model, openai_api_key=os.environ.get('API', None))
-    embeddings = OpenAIEmbeddings(openai_api_base=config.bot_api_url.v1_url, openai_api_key=os.environ.get('API', None))
-    filename = get_doc_from_url(docurl)
-    docpath = os.getcwd() + "/" + filename
-    loader = UnstructuredPDFLoader(docpath)
-    try:
-        documents = loader.load()
-    except:
-        print("pdf load error! docpath:", docpath)
-        return ""
-    os.remove(docpath)
-    # 初始化加载器
-    text_splitter = CharacterTextSplitter(chunk_size=100, chunk_overlap=25)
-    # 切割加载的 document
-    split_docs = text_splitter.split_documents(documents)
-    vector_store = Chroma.from_documents(split_docs, embeddings)
-    # 创建问答对象
-    qa = RetrievalQA.from_chain_type(llm=chatllm, chain_type="stuff", retriever=vector_store.as_retriever(),return_source_documents=True)
-    # 进行问答
-    result = qa({"query": query_message})
-    return result['result']
-
-def summary_each_url(threads, chainllm, prompt):
-    summary_prompt = PromptTemplate(
-        input_variables=["web_summary", "question", "language"],
-        template=(
-            "You need to response the following question: {question}."
-            "Your task is answer the above question in {language} based on the Search results provided. Provide a detailed and in-depth response"
-            "If there is no relevant content in the search results, just answer None, do not make any explanations."
-            "Search results: {web_summary}."
-        ),
-    )
-    summary_threads = []
-
-    for t in threads:
-        tmp = t.join()
-        print(tmp)
-        chain = LLMChain(llm=chainllm, prompt=summary_prompt)
-        chain_thread = ThreadWithReturnValue(target=chain.run, args=({"web_summary": tmp, "question": prompt, "language": config.LANGUAGE},))
-        chain_thread.start()
-        summary_threads.append(chain_thread)
-
-    url_result = ""
-    for t in summary_threads:
-        tmp = t.join()
-        print("summary", tmp)
-        if tmp != "None":
-            url_result += "\n\n" + tmp
-    return url_result
-
-def get_search_results(prompt: str, context_max_tokens: int):
-
-    url_text_list = get_url_text_list(prompt)
-    useful_source_text = "\n\n".join(url_text_list)
-    # useful_source_text = summary_each_url(threads, chainllm, prompt)
-
-    useful_source_text, search_tokens_len = cut_message(useful_source_text, context_max_tokens)
-    print("search tokens len", search_tokens_len, "\n\n")
-
-    return useful_source_text
-
-from typing import Any
-from langchain.schema.output import LLMResult
-from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
-class ChainStreamHandler(StreamingStdOutCallbackHandler):
-    def __init__(self):
-        self.tokens = []
-        # 记得结束后这里置true
-        self.finish = False
-        self.answer = ""
-
-    def on_llm_new_token(self, token: str, **kwargs):
-        # print(token)
-        self.tokens.append(token)
-        # yield ''.join(self.tokens)
-        # print(''.join(self.tokens))
-
-    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
-        self.finish = 1
-
-    def on_llm_error(self, error: Exception, **kwargs: Any) -> None:
-        print(str(error))
-        self.tokens.append(str(error))
-
-    def generate_tokens(self):
-        while not self.finish or self.tokens:
-            if self.tokens:
-                data = self.tokens.pop(0)
-                self.answer += data
-                yield data
-            else:
-                pass
-        return self.answer
{beswarm-0.1.37.dist-info → beswarm-0.1.39.dist-info}/WHEEL
File without changes

{beswarm-0.1.37.dist-info → beswarm-0.1.39.dist-info}/top_level.txt
File without changes