beswarm 0.1.12__py3-none-any.whl → 0.1.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- beswarm/aient/main.py +50 -0
- beswarm/aient/setup.py +15 -0
- beswarm/aient/src/aient/__init__.py +1 -0
- beswarm/aient/src/aient/core/__init__.py +1 -0
- beswarm/aient/src/aient/core/log_config.py +6 -0
- beswarm/aient/src/aient/core/models.py +232 -0
- beswarm/aient/src/aient/core/request.py +1665 -0
- beswarm/aient/src/aient/core/response.py +617 -0
- beswarm/aient/src/aient/core/test/test_base_api.py +18 -0
- beswarm/aient/src/aient/core/test/test_image.py +15 -0
- beswarm/aient/src/aient/core/test/test_payload.py +92 -0
- beswarm/aient/src/aient/core/utils.py +715 -0
- beswarm/aient/src/aient/models/__init__.py +9 -0
- beswarm/aient/src/aient/models/audio.py +63 -0
- beswarm/aient/src/aient/models/base.py +251 -0
- beswarm/aient/src/aient/models/chatgpt.py +938 -0
- beswarm/aient/src/aient/models/claude.py +640 -0
- beswarm/aient/src/aient/models/duckduckgo.py +241 -0
- beswarm/aient/src/aient/models/gemini.py +357 -0
- beswarm/aient/src/aient/models/groq.py +268 -0
- beswarm/aient/src/aient/models/vertex.py +420 -0
- beswarm/aient/src/aient/plugins/__init__.py +33 -0
- beswarm/aient/src/aient/plugins/arXiv.py +48 -0
- beswarm/aient/src/aient/plugins/config.py +172 -0
- beswarm/aient/src/aient/plugins/excute_command.py +35 -0
- beswarm/aient/src/aient/plugins/get_time.py +19 -0
- beswarm/aient/src/aient/plugins/image.py +72 -0
- beswarm/aient/src/aient/plugins/list_directory.py +50 -0
- beswarm/aient/src/aient/plugins/read_file.py +79 -0
- beswarm/aient/src/aient/plugins/registry.py +116 -0
- beswarm/aient/src/aient/plugins/run_python.py +156 -0
- beswarm/aient/src/aient/plugins/websearch.py +394 -0
- beswarm/aient/src/aient/plugins/write_file.py +51 -0
- beswarm/aient/src/aient/prompt/__init__.py +1 -0
- beswarm/aient/src/aient/prompt/agent.py +280 -0
- beswarm/aient/src/aient/utils/__init__.py +0 -0
- beswarm/aient/src/aient/utils/prompt.py +143 -0
- beswarm/aient/src/aient/utils/scripts.py +721 -0
- beswarm/aient/test/chatgpt.py +161 -0
- beswarm/aient/test/claude.py +32 -0
- beswarm/aient/test/test.py +2 -0
- beswarm/aient/test/test_API.py +6 -0
- beswarm/aient/test/test_Deepbricks.py +20 -0
- beswarm/aient/test/test_Web_crawler.py +262 -0
- beswarm/aient/test/test_aiwaves.py +25 -0
- beswarm/aient/test/test_aiwaves_arxiv.py +19 -0
- beswarm/aient/test/test_ask_gemini.py +8 -0
- beswarm/aient/test/test_class.py +17 -0
- beswarm/aient/test/test_claude.py +23 -0
- beswarm/aient/test/test_claude_zh_char.py +26 -0
- beswarm/aient/test/test_ddg_search.py +50 -0
- beswarm/aient/test/test_download_pdf.py +56 -0
- beswarm/aient/test/test_gemini.py +97 -0
- beswarm/aient/test/test_get_token_dict.py +21 -0
- beswarm/aient/test/test_google_search.py +35 -0
- beswarm/aient/test/test_jieba.py +32 -0
- beswarm/aient/test/test_json.py +65 -0
- beswarm/aient/test/test_langchain_search_old.py +235 -0
- beswarm/aient/test/test_logging.py +32 -0
- beswarm/aient/test/test_ollama.py +55 -0
- beswarm/aient/test/test_plugin.py +16 -0
- beswarm/aient/test/test_py_run.py +26 -0
- beswarm/aient/test/test_requests.py +162 -0
- beswarm/aient/test/test_search.py +18 -0
- beswarm/aient/test/test_tikitoken.py +19 -0
- beswarm/aient/test/test_token.py +94 -0
- beswarm/aient/test/test_url.py +33 -0
- beswarm/aient/test/test_whisper.py +14 -0
- beswarm/aient/test/test_wildcard.py +20 -0
- beswarm/aient/test/test_yjh.py +21 -0
- {beswarm-0.1.12.dist-info → beswarm-0.1.13.dist-info}/METADATA +1 -1
- beswarm-0.1.13.dist-info/RECORD +131 -0
- beswarm-0.1.12.dist-info/RECORD +0 -61
- {beswarm-0.1.12.dist-info → beswarm-0.1.13.dist-info}/WHEEL +0 -0
- {beswarm-0.1.12.dist-info → beswarm-0.1.13.dist-info}/top_level.txt +0 -0
beswarm/aient/test/test_download_pdf.py
@@ -0,0 +1,56 @@
+# import requests
+# import urllib.parse
+# import os
+# import sys
+# sys.path.append(os.getcwd())
+# import config
+
+# from langchain.chat_models import ChatOpenAI
+# from langchain.embeddings.openai import OpenAIEmbeddings
+# from langchain.vectorstores import Chroma
+# from langchain.text_splitter import CharacterTextSplitter
+# from langchain.document_loaders import UnstructuredPDFLoader
+# from langchain.chains import RetrievalQA
+
+
+# def get_doc_from_url(url):
+#     filename = urllib.parse.unquote(url.split("/")[-1])
+#     response = requests.get(url, stream=True)
+#     with open(filename, 'wb') as f:
+#         for chunk in response.iter_content(chunk_size=1024):
+#             f.write(chunk)
+#     return filename
+
+# def pdf_search(docurl, query_message, model="gpt-3.5-turbo"):
+#     chatllm = ChatOpenAI(temperature=0.5, openai_api_base=config.API_URL.split("chat")[0], model_name=model, openai_api_key=os.environ.get('API', None))
+#     embeddings = OpenAIEmbeddings(openai_api_base=config.API_URL.split("chat")[0], openai_api_key=os.environ.get('API', None))
+#     filename = get_doc_from_url(docurl)
+#     docpath = os.getcwd() + "/" + filename
+#     loader = UnstructuredPDFLoader(docpath)
+#     print(docpath)
+#     documents = loader.load()
+#     os.remove(docpath)
+#     # Initialize the text splitter
+#     text_splitter = CharacterTextSplitter(chunk_size=100, chunk_overlap=25)
+#     # Split the loaded documents
+#     split_docs = text_splitter.split_documents(documents)
+#     vector_store = Chroma.from_documents(split_docs, embeddings)
+#     # Build the QA chain
+#     qa = RetrievalQA.from_chain_type(llm=chatllm, chain_type="stuff", retriever=vector_store.as_retriever(),return_source_documents=True)
+#     # Run the query
+#     result = qa({"query": query_message})
+#     return result['result']
+
+# pdf_search("https://www.nsfc.gov.cn/csc/20345/22468/pdf/2001/%E5%86%BB%E7%BB%93%E8%A3%82%E9%9A%99%E7%A0%82%E5%B2%A9%E4%BD%8E%E5%91%A8%E5%BE%AA%E7%8E%AF%E5%8A%A8%E5%8A%9B%E7%89%B9%E6%80%A7%E8%AF%95%E9%AA%8C%E7%A0%94%E7%A9%B6.pdf", "端水实验的目的是什么?")
+
+from PyPDF2 import PdfReader
+
+def has_text(pdf_path):
+    with open(pdf_path, 'rb') as file:
+        pdf = PdfReader(file)
+        page = pdf.pages[0]
+        text = page.extract_text()
+    return text
+
+pdf_path = '/Users/yanyuming/Downloads/GitHub/ChatGPT-Telegram-Bot/冻结裂隙砂岩低周循环动力特性试验研究.pdf'
+print(has_text(pdf_path))
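The `has_text` check above reads only the first page, so a PDF with an image cover but a text body would look empty. A minimal whole-document sketch using the same PyPDF2 API (`pdf_text_ratio` and its `min_chars` threshold are illustrative, not part of the package):

from PyPDF2 import PdfReader

def pdf_text_ratio(pdf_path, min_chars=20):
    # Fraction of pages carrying at least `min_chars` characters of embedded
    # text; a ratio near 0.0 suggests a scanned PDF that would need OCR.
    reader = PdfReader(pdf_path)
    pages_with_text = sum(
        1 for page in reader.pages
        if len((page.extract_text() or "").strip()) >= min_chars
    )
    return pages_with_text / max(len(reader.pages), 1)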
beswarm/aient/test/test_gemini.py
@@ -0,0 +1,97 @@
+import json
+
+class JSONExtractor:
+    def __init__(self):
+        self.buffer = ""
+        self.bracket_count = 0
+        self.in_target = False
+        self.target_json = ""
+
+    def process_line(self, line):
+        self.buffer += line.strip()
+
+        for char in line:
+            if char == '{':
+                self.bracket_count += 1
+                if self.bracket_count == 4 and '"functionCall"' in self.buffer[-20:]:
+                    self.in_target = True
+                    self.target_json = '{'
+            elif char == '}':
+                if self.in_target:
+                    self.target_json += '}'
+                self.bracket_count -= 1
+                if self.bracket_count == 3 and self.in_target:
+                    self.in_target = False
+                    return self.parse_target_json()
+
+            if self.in_target:
+                self.target_json += char
+
+        return None
+
+    def parse_target_json(self):
+        try:
+            parsed = json.loads(self.target_json)
+            if 'functionCall' in parsed:
+                return parsed['functionCall']
+        except json.JSONDecodeError:
+            pass
+        return None
+
+# Usage example
+extractor = JSONExtractor()
+
+# Simulate receiving streamed data
+sample_lines = [
+    '{\n',
+    '  "candidates": [\n',
+    '    {\n',
+    '      "content": {\n',
+    '        "parts": [\n',
+    '          {\n',
+    '            "functionCall": {\n',
+    '              "name": "get_search_results",\n',
+    '              "args": {\n',
+    '                "prompt": "Claude Opus 3.5 release date"\n',
+    '              }\n',
+    '            }\n',
+    '          }\n',
+    '        ],\n',
+    '        "role": "model"\n',
+    '      },\n',
+    '      "finishReason": "STOP",\n',
+    '      "index": 0,\n',
+    '      "safetyRatings": [\n',
+    '        {\n',
+    '          "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",\n',
+    '          "probability": "NEGLIGIBLE"\n',
+    '        },\n',
+    '        {\n',
+    '          "category": "HARM_CATEGORY_HATE_SPEECH",\n',
+    '          "probability": "NEGLIGIBLE"\n',
+    '        },\n',
+    '        {\n',
+    '          "category": "HARM_CATEGORY_HARASSMENT",\n',
+    '          "probability": "NEGLIGIBLE"\n',
+    '        },\n',
+    '        {\n',
+    '          "category": "HARM_CATEGORY_DANGEROUS_CONTENT",\n',
+    '          "probability": "NEGLIGIBLE"\n',
+    '        }\n',
+    '      ]\n',
+    '    }\n',
+    '  ],\n',
+    '  "usageMetadata": {\n',
+    '    "promptTokenCount": 113,\n',
+    '    "candidatesTokenCount": 55,\n',
+    '    "totalTokenCount": 168\n',
+    '  }\n',
+    '}\n'
+]
+
+for line in sample_lines:
+    result = extractor.process_line(line)
+    if result:
+        print("Extracted functionCall:")
+        print(json.dumps(result, indent=2, ensure_ascii=False))
+        break
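The brace counter above miscounts when `{` or `}` occurs inside a string value, and its `bracket_count == 4` trigger is tied to one exact nesting depth. A string-aware sketch built on the standard library's `json.JSONDecoder.raw_decode` (`drain_json` is a hypothetical helper, not part of the test):

import json

def drain_json(buffer):
    # Split `buffer` into fully parseable JSON objects plus the unparsed tail.
    decoder = json.JSONDecoder()
    objects, idx = [], 0
    while (start := buffer.find('{', idx)) != -1:
        try:
            obj, end = decoder.raw_decode(buffer, start)
        except json.JSONDecodeError:
            break  # object still incomplete; keep the tail for the next chunk
        objects.append(obj)
        idx = end
    return objects, buffer[idx:]

Feeding each streamed chunk into a growing tail string yields complete objects as soon as they close, regardless of nesting depth or braces inside strings.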
beswarm/aient/test/test_get_token_dict.py
@@ -0,0 +1,21 @@
+from collections import defaultdict
+
+# Define a default factory; int initializes missing keys to 0
+default_dict = defaultdict(int)
+
+# Example usage
+print(default_dict['a'])  # Output: 0, because 'a' does not exist and is auto-initialized to 0
+default_dict['a'] += 1
+print(default_dict['a'])  # Output: 1
+
+# Other factory functions work too, for example list
+list_default_dict = defaultdict(list)
+print(list_default_dict['b'])  # Output: [], because 'b' does not exist and is auto-initialized to an empty list
+list_default_dict['b'].append(2)
+print(list_default_dict['b'])  # Output: [2]
+
+# An existing dict can also be converted to a defaultdict
+existing_dict = {'c': 3, 'd': 4}
+default_dict = defaultdict(int, existing_dict)
+print(default_dict['c'])  # Output: 3
+print(default_dict['e'])  # Output: 0, because 'e' does not exist and is auto-initialized to 0
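For the token-count bookkeeping the filename suggests, `collections.Counter` and nested defaultdicts are the usual next step; a short sketch (the keys are illustrative):

from collections import Counter, defaultdict

# Counter covers the plain "count things" case directly:
token_usage = Counter()
token_usage["gpt-3.5-turbo"] += 113
token_usage["gpt-3.5-turbo"] += 55
print(token_usage.most_common(1))  # [('gpt-3.5-turbo', 168)]

# Nested defaultdicts give per-conversation, per-field tallies:
usage = defaultdict(lambda: defaultdict(int))
usage["chat-42"]["prompt_tokens"] += 113
usage["chat-42"]["completion_tokens"] += 55
print({k: dict(v) for k, v in usage.items()})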
beswarm/aient/test/test_google_search.py
@@ -0,0 +1,35 @@
+import os
+import requests
+from googleapiclient.discovery import build
+from dotenv import load_dotenv
+load_dotenv()
+
+search_engine_id = os.environ.get('GOOGLE_CSE_ID', None)
+api_key = os.environ.get('GOOGLE_API_KEY', None)
+query = "Python 编程"
+
+def google_search1(query, api_key, search_engine_id):
+    service = build("customsearch", "v1", developerKey=api_key)
+    res = service.cse().list(q=query, cx=search_engine_id).execute()
+    link_list = [item['link'] for item in res['items']]
+    return link_list
+
+def google_search2(query, api_key, cx):
+    url = "https://www.googleapis.com/customsearch/v1"
+    params = {
+        'q': query,
+        'key': api_key,
+        'cx': cx
+    }
+    response = requests.get(url, params=params)
+    print(response.text)
+    results = response.json()
+    link_list = [item['link'] for item in results.get('items', [])]
+
+    return link_list
+
+# results = google_search1(query, api_key, search_engine_id)
+# print(results)
+
+results = google_search2(query, api_key, search_engine_id)
+print(results)
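Both helpers fetch a single page of results, and `google_search1` raises a KeyError whenever a quota or auth error leaves `items` out of the response. A sketch of a paged variant with explicit error handling, assuming the Custom Search JSON API's documented `num` and `start` parameters (`google_search_paged` is an illustrative name, not part of the test):

import requests

def google_search_paged(query, api_key, cx, pages=2, per_page=10):
    links = []
    for page in range(pages):
        params = {
            "q": query,
            "key": api_key,
            "cx": cx,
            "num": per_page,               # at most 10 per request
            "start": 1 + page * per_page,  # 1-based result offset
        }
        resp = requests.get("https://www.googleapis.com/customsearch/v1",
                            params=params, timeout=10)
        resp.raise_for_status()  # surface quota/auth errors instead of a KeyError
        links.extend(item["link"] for item in resp.json().get("items", []))
    return links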
beswarm/aient/test/test_jieba.py
@@ -0,0 +1,32 @@
+import jieba
+import jieba.analyse
+
+# Load the text
+# text = "话说葬送的芙莉莲动漫是半年番还是季番?完结没?"
+# text = "民进党当初为什么支持柯文哲选台北市长?"
+text = "今天的微博热搜有哪些?"
+# text = "How much does the 'zeabur' software service cost per month? Is it free to use? Any limitations?"
+
+# Extract keywords with the TF-IDF algorithm
+keywords_tfidf = jieba.analyse.extract_tags(text, topK=10, withWeight=False, allowPOS=())
+
+# Extract keywords with the TextRank algorithm
+keywords_textrank = jieba.analyse.textrank(text, topK=10, withWeight=False, allowPOS=('ns', 'n', 'vn', 'v'))
+
+print("TF-IDF keywords:", keywords_tfidf)
+print("TextRank keywords:", keywords_textrank)
+
+
+seg_list = jieba.cut(text, cut_all=True)
+print("Full Mode: " + " ".join(seg_list))  # full mode
+
+seg_list = jieba.cut(text, cut_all=False)
+print("Default Mode: " + " ".join(seg_list))  # precise mode
+
+seg_list = jieba.cut(text)  # precise mode by default
+print(" ".join(seg_list))
+
+seg_list = jieba.cut_for_search(text)  # search-engine mode
+result = " ".join(seg_list)
+
+print([result] * 3)
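With `withWeight=True`, `extract_tags` returns (keyword, weight) pairs instead of bare strings, which makes it easier to judge where a sensible `topK` cutoff sits for a given corpus; a quick sketch:

import jieba.analyse

text = "今天的微博热搜有哪些?"
for word, weight in jieba.analyse.extract_tags(text, topK=10, withWeight=True):
    print(f"{word}\t{weight:.4f}")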
beswarm/aient/test/test_json.py
@@ -0,0 +1,65 @@
+import json
+
+# json_data = '爱'
+# # json_data = '爱的主人,我会尽快为您规划一个走线到美国的安全路线。请您稍等片刻。\n\n首先,我会检查免签国家并为您提供相应的信息。接下来,我会搜索有关旅行到美国的安全建议和路线规划。{}'
+
+def split_json_strings(input_string):
+    # Initialize the result list and the current JSON string
+    json_strings = []
+    current_json = ""
+    brace_count = 0
+
+    # Walk the input string character by character
+    for char in input_string:
+        current_json += char
+        if char == '{':
+            brace_count += 1
+        elif char == '}':
+            brace_count -= 1
+
+        # When the braces balance, we may have a complete JSON string
+        if brace_count == 0:
+            # Try to parse the current JSON string
+            try:
+                json.loads(current_json)
+                json_strings.append(current_json)
+                current_json = ""
+            except json.JSONDecodeError:
+                # If parsing fails, keep appending characters
+                pass
+    if json_strings == []:
+        json_strings.append(input_string)
+    return json_strings
+
+# Test the function
+input_string = '{"url": "https://github.com/fastai/fasthtml"'
+result = split_json_strings(input_string)
+
+for i, json_str in enumerate(result, 1):
+    print(f"JSON {i}:", json_str)
+    print("Parsed:", json.loads(json_str))
+    print()
+
+# def check_json(json_data):
+#     while True:
+#         try:
+#             json.loads(json_data)
+#             break
+#         except json.decoder.JSONDecodeError as e:
+#             print("JSON error:", e)
+#             print("JSON body", repr(json_data))
+#             if "Invalid control character" in str(e):
+#                 json_data = json_data.replace("\n", "\\n")
+#             if "Unterminated string starting" in str(e):
+#                 json_data += '"}'
+#             if "Expecting ',' delimiter" in str(e):
+#                 json_data += '}'
+#             if "Expecting value: line 1 column 1" in str(e):
+#                 json_data = '{"prompt": ' + json.dumps(json_data) + '}'
+#     return json_data
+# print(json.loads(check_json(json_data)))
+
+# a = '''
+# '''
+
+# print(json.loads(a))
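For reference, a usage sketch of how `split_json_strings` behaves on well-formed versus truncated input (not part of the test file):

# Complete, back-to-back objects split cleanly:
print(split_json_strings('{"a": 1}{"b": 2}'))  # ['{"a": 1}', '{"b": 2}']

# A truncated object falls through to the whole-input fallback:
print(split_json_strings('{"url": "https://github.com/fastai/fasthtml"'))

Note that the test's own `json.loads(json_str)` call raises on that fallback string, and braces inside string values defeat the brace counter here just as in test_gemini.py.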
beswarm/aient/test/test_langchain_search_old.py
@@ -0,0 +1,235 @@
+import os
+import re
+
+import sys
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+import config
+
+from langchain.chat_models import ChatOpenAI
+
+
+from langchain.chains import RetrievalQA, RetrievalQAWithSourcesChain
+
+from langchain.prompts.chat import (
+    ChatPromptTemplate,
+    SystemMessagePromptTemplate,
+    HumanMessagePromptTemplate,
+)
+from langchain.embeddings.openai import OpenAIEmbeddings
+from langchain.vectorstores import Chroma
+from langchain.text_splitter import CharacterTextSplitter
+
+from langchain.document_loaders import UnstructuredPDFLoader
+
+def getmd5(string):
+    import hashlib
+    md5_hash = hashlib.md5()
+    md5_hash.update(string.encode('utf-8'))
+    md5_hex = md5_hash.hexdigest()
+    return md5_hex
+
+from utils.sitemap import SitemapLoader
+async def get_doc_from_sitemap(url):
+    # https://www.langchain.asia/modules/indexes/document_loaders/examples/sitemap#%E8%BF%87%E6%BB%A4%E7%AB%99%E7%82%B9%E5%9C%B0%E5%9B%BE-url-
+    sitemap_loader = SitemapLoader(web_path=url)
+    docs = await sitemap_loader.load()
+    return docs
+
+async def get_doc_from_local(docpath, doctype="md"):
+    from langchain.document_loaders import DirectoryLoader
+    # Load every file of the given type from the folder
+    loader = DirectoryLoader(docpath, glob='**/*.' + doctype)
+    # Convert the data into document objects, one document per file
+    documents = loader.load()
+    return documents
+
+system_template="""Use the following pieces of context to answer the users question.
+If you don't know the answer, just say "Hmm..., I'm not sure.", don't try to make up an answer.
+ALWAYS return a "Sources" part in your answer.
+The "Sources" part should be a reference to the source of the document from which you got your answer.
+
+Example of your response should be:
+
+```
+The answer is foo
+
+Sources:
+1. abc
+2. xyz
+```
+Begin!
+----------------
+{summaries}
+"""
+messages = [
+    SystemMessagePromptTemplate.from_template(system_template),
+    HumanMessagePromptTemplate.from_template("{question}")
+]
+prompt = ChatPromptTemplate.from_messages(messages)
+
+def get_chain(store, llm):
+    chain_type_kwargs = {"prompt": prompt}
+    chain = RetrievalQAWithSourcesChain.from_chain_type(
+        llm,
+        chain_type="stuff",
+        retriever=store.as_retriever(),
+        chain_type_kwargs=chain_type_kwargs,
+        reduce_k_below_max_tokens=True
+    )
+    return chain
+
+async def docQA(docpath, query_message, persist_db_path="db", model = "gpt-3.5-turbo"):
+    chatllm = ChatOpenAI(temperature=0.5, openai_api_base=config.bot_api_url.v1_url, model_name=model, openai_api_key=config.API)
+    embeddings = OpenAIEmbeddings(openai_api_base=config.bot_api_url.v1_url, openai_api_key=config.API)
+
+    sitemap = "sitemap.xml"
+    match = re.match(r'^(https?|ftp)://[^\s/$.?#].[^\s]*$', docpath)
+    if match:
+        doc_method = get_doc_from_sitemap
+        docpath = os.path.join(docpath, sitemap)
+    else:
+        doc_method = get_doc_from_local
+
+    persist_db_path = getmd5(docpath)
+    if not os.path.exists(persist_db_path):
+        documents = await doc_method(docpath)
+        # Initialize the text splitter
+        text_splitter = CharacterTextSplitter(chunk_size=100, chunk_overlap=50)
+        # Persist the data
+        split_docs = text_splitter.split_documents(documents)
+        vector_store = Chroma.from_documents(split_docs, embeddings, persist_directory=persist_db_path)
+        vector_store.persist()
+    else:
+        # Load the persisted data
+        vector_store = Chroma(persist_directory=persist_db_path, embedding_function=embeddings)
+
+    # Build the QA chain
+    qa = get_chain(vector_store, chatllm)
+    # qa = RetrievalQA.from_chain_type(llm=chatllm, chain_type="stuff", retriever=vector_store.as_retriever(), return_source_documents=True)
+    # Run the query
+    result = qa({"question": query_message})
+    return result
+
+
+def persist_emdedding_pdf(docurl, persist_db_path):
+    embeddings = OpenAIEmbeddings(openai_api_base=config.bot_api_url.v1_url, openai_api_key=os.environ.get('API', None))
+    filename = get_doc_from_url(docurl)
+    docpath = os.getcwd() + "/" + filename
+    loader = UnstructuredPDFLoader(docpath)
+    documents = loader.load()
+    # Initialize the text splitter
+    text_splitter = CharacterTextSplitter(chunk_size=100, chunk_overlap=25)
+    # Split the loaded documents
+    split_docs = text_splitter.split_documents(documents)
+    vector_store = Chroma.from_documents(split_docs, embeddings, persist_directory=persist_db_path)
+    vector_store.persist()
+    os.remove(docpath)
+    return vector_store
+
+async def pdfQA(docurl, docpath, query_message, model="gpt-3.5-turbo"):
+    chatllm = ChatOpenAI(temperature=0.5, openai_api_base=config.bot_api_url.v1_url, model_name=model, openai_api_key=os.environ.get('API', None))
+    embeddings = OpenAIEmbeddings(openai_api_base=config.bot_api_url.v1_url, openai_api_key=os.environ.get('API', None))
+    persist_db_path = getmd5(docpath)
+    if not os.path.exists(persist_db_path):
+        vector_store = persist_emdedding_pdf(docurl, persist_db_path)
+    else:
+        vector_store = Chroma(persist_directory=persist_db_path, embedding_function=embeddings)
+    qa = RetrievalQA.from_chain_type(llm=chatllm, chain_type="stuff", retriever=vector_store.as_retriever(), return_source_documents=True)
+    result = qa({"query": query_message})
+    return result['result']
+
+
+def pdf_search(docurl, query_message, model="gpt-3.5-turbo"):
+    chatllm = ChatOpenAI(temperature=0.5, openai_api_base=config.bot_api_url.v1_url, model_name=model, openai_api_key=os.environ.get('API', None))
+    embeddings = OpenAIEmbeddings(openai_api_base=config.bot_api_url.v1_url, openai_api_key=os.environ.get('API', None))
+    filename = get_doc_from_url(docurl)
+    docpath = os.getcwd() + "/" + filename
+    loader = UnstructuredPDFLoader(docpath)
+    try:
+        documents = loader.load()
+    except:
+        print("pdf load error! docpath:", docpath)
+        return ""
+    os.remove(docpath)
+    # Initialize the text splitter
+    text_splitter = CharacterTextSplitter(chunk_size=100, chunk_overlap=25)
+    # Split the loaded documents
+    split_docs = text_splitter.split_documents(documents)
+    vector_store = Chroma.from_documents(split_docs, embeddings)
+    # Build the QA chain
+    qa = RetrievalQA.from_chain_type(llm=chatllm, chain_type="stuff", retriever=vector_store.as_retriever(),return_source_documents=True)
+    # Run the query
+    result = qa({"query": query_message})
+    return result['result']
+
+def summary_each_url(threads, chainllm, prompt):
+    summary_prompt = PromptTemplate(
+        input_variables=["web_summary", "question", "language"],
+        template=(
+            "You need to response the following question: {question}."
+            "Your task is answer the above question in {language} based on the Search results provided. Provide a detailed and in-depth response"
+            "If there is no relevant content in the search results, just answer None, do not make any explanations."
+            "Search results: {web_summary}."
+        ),
+    )
+    summary_threads = []
+
+    for t in threads:
+        tmp = t.join()
+        print(tmp)
+        chain = LLMChain(llm=chainllm, prompt=summary_prompt)
+        chain_thread = ThreadWithReturnValue(target=chain.run, args=({"web_summary": tmp, "question": prompt, "language": config.LANGUAGE},))
+        chain_thread.start()
+        summary_threads.append(chain_thread)
+
+    url_result = ""
+    for t in summary_threads:
+        tmp = t.join()
+        print("summary", tmp)
+        if tmp != "None":
+            url_result += "\n\n" + tmp
+    return url_result
+
+def get_search_results(prompt: str, context_max_tokens: int):
+
+    url_text_list = get_url_text_list(prompt)
+    useful_source_text = "\n\n".join(url_text_list)
+    # useful_source_text = summary_each_url(threads, chainllm, prompt)
+
+    useful_source_text, search_tokens_len = cut_message(useful_source_text, context_max_tokens)
+    print("search tokens len", search_tokens_len, "\n\n")
+
+    return useful_source_text
+
+from typing import Any
+from langchain.schema.output import LLMResult
+from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
+class ChainStreamHandler(StreamingStdOutCallbackHandler):
+    def __init__(self):
+        self.tokens = []
+        # Remember to set this to True when finished
+        self.finish = False
+        self.answer = ""
+
+    def on_llm_new_token(self, token: str, **kwargs):
+        # print(token)
+        self.tokens.append(token)
+        # yield ''.join(self.tokens)
+        # print(''.join(self.tokens))
+
+    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
+        self.finish = 1
+
+    def on_llm_error(self, error: Exception, **kwargs: Any) -> None:
+        print(str(error))
+        self.tokens.append(str(error))
+
+    def generate_tokens(self):
+        while not self.finish or self.tokens:
+            if self.tokens:
+                data = self.tokens.pop(0)
+                self.answer += data
+                yield data
+            else:
+                pass
+        return self.answer
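`ChainStreamHandler.generate_tokens` spins in its `else: pass` branch while waiting for tokens. A sketch of the same producer/consumer bridge backed by a blocking `queue.Queue` (a hypothetical variant, not part of the file):

import queue
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler

class QueueStreamHandler(StreamingStdOutCallbackHandler):
    def __init__(self):
        self.q = queue.Queue()

    def on_llm_new_token(self, token, **kwargs):
        self.q.put(token)

    def on_llm_end(self, response, **kwargs):
        self.q.put(None)  # sentinel: stream finished

    def on_llm_error(self, error, **kwargs):
        self.q.put(str(error))
        self.q.put(None)

    def generate_tokens(self):
        while (token := self.q.get()) is not None:  # blocks instead of spinning
            yield token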
beswarm/aient/test/test_logging.py
@@ -0,0 +1,32 @@
+import logging
+
+class SpecificStringFilter(logging.Filter):
+    def __init__(self, specific_string):
+        super().__init__()
+        self.specific_string = specific_string
+
+    def filter(self, record):
+        return self.specific_string not in record.getMessage()
+
+# Create a logger
+logger = logging.getLogger('my_logger')
+logger.setLevel(logging.DEBUG)
+
+# Create a console handler and set its level to debug
+ch = logging.StreamHandler()
+# ch.setLevel(logging.DEBUG)
+
+# Create a filter instance
+specific_string = "httpx.RemoteProtocolError: Server disconnected without sending a response."
+my_filter = SpecificStringFilter(specific_string)
+
+# Add the filter to the handler
+ch.addFilter(my_filter)
+
+# Add the handler to the logger
+logger.addHandler(ch)
+
+# Test log messages
+logger.debug("This is a debug message.")
+logger.error("This message will be ignored: ignore me.httpx.RemoteProtocolError: Server disconnected without sending a response.")
+logger.info("Another info message.")
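The same idea extends from substring matching to patterns, and attaching the filter to the logger rather than to one handler suppresses matching records on every handler; a small regex-based sketch (`RegexFilter` is an illustrative name):

import logging
import re

class RegexFilter(logging.Filter):
    def __init__(self, pattern):
        super().__init__()
        self.pattern = re.compile(pattern)

    def filter(self, record):
        # Drop any record whose formatted message matches the pattern
        return not self.pattern.search(record.getMessage())

logger = logging.getLogger('my_logger')
logger.addFilter(RegexFilter(r"Server disconnected without sending a response"))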
beswarm/aient/test/test_ollama.py
@@ -0,0 +1,55 @@
+import os
+from rich.console import Console
+from rich.markdown import Markdown
+import json
+import requests
+
+def query_ollama(prompt, model):
+    # Set the request URL and payload
+    url = 'http://localhost:11434/api/generate'
+    data = {
+        "model": model,
+        "prompt": prompt,
+        "stream": True,
+    }
+
+    response = requests.Session().post(
+        url,
+        json=data,
+        stream=True,
+    )
+    full_response: str = ""
+    for line in response.iter_lines():
+        if not line or line.decode("utf-8")[:6] == "event:" or line.decode("utf-8") == "data: {}":
+            continue
+        line = line.decode("utf-8")
+        # print(line)
+        resp: dict = json.loads(line)
+        content = resp.get("response")
+        if not content:
+            continue
+        full_response += content
+        yield content
+
+if __name__ == "__main__":
+    console = Console()
+    # model = 'llama2'
+    # model = 'mistral'
+    # model = 'llama3:8b'
+    model = 'phi3:medium'
+    # model = 'qwen:14b'
+    # model = 'wizardlm2:7b'
+    # model = 'codeqwen:7b-chat'
+    # model = 'phi'
+
+    # Query for the answer
+    prompt = r'''
+
+
+    '''
+    answer = ""
+    for result in query_ollama(prompt, model):
+        os.system("clear")
+        answer += result
+        md = Markdown(answer)
+        console.print(md, no_wrap=False)
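For comparison, the non-streaming form of the same endpoint returns one JSON body whose `response` field holds the whole completion; a sketch assuming the same local Ollama server as above (`query_ollama_once` is an illustrative name):

import requests

def query_ollama_once(prompt, model):
    resp = requests.post(
        "http://localhost:11434/api/generate",
        json={"model": model, "prompt": prompt, "stream": False},
        timeout=300,
    )
    resp.raise_for_status()
    return resp.json()["response"]  # the full completion in one field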
beswarm/aient/test/test_plugin.py
@@ -0,0 +1,16 @@
+import json
+
+from ..src.aient.plugins.websearch import get_search_results
+from ..src.aient.plugins.arXiv import download_read_arxiv_pdf
+from ..src.aient.plugins.image import generate_image
+from ..src.aient.plugins.get_time import get_time
+from ..src.aient.plugins.run_python import run_python_script
+
+from ..src.aient.plugins.config import function_to_json
+
+
+print(json.dumps(function_to_json(get_search_results), indent=4, ensure_ascii=False))
+print(json.dumps(function_to_json(download_read_arxiv_pdf), indent=4, ensure_ascii=False))
+print(json.dumps(function_to_json(generate_image), indent=4, ensure_ascii=False))
+print(json.dumps(function_to_json(get_time), indent=4, ensure_ascii=False))
+print(json.dumps(function_to_json(run_python_script), indent=4, ensure_ascii=False))
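The prints above dump whatever schema `function_to_json` derives from each plugin. As a rough illustration of the general technique, an independent sketch that maps a Python signature onto an OpenAI-style tool schema (aient's actual converter and output shape may differ):

import inspect

def sketch_function_to_json(func):
    type_map = {int: "integer", float: "number", bool: "boolean", str: "string"}
    properties = {}
    for name, param in inspect.signature(func).parameters.items():
        # Unannotated parameters default to "string"
        properties[name] = {"type": type_map.get(param.annotation, "string")}
    return {
        "name": func.__name__,
        "description": inspect.getdoc(func) or "",
        "parameters": {"type": "object", "properties": properties},
    }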
beswarm/aient/test/test_py_run.py
@@ -0,0 +1,26 @@
+def run_python_script(script):
+    # Create a dict to store the script's local variables
+    local_vars = {}
+
+    try:
+        # Execute the script string
+        exec(script, {}, local_vars)
+        return local_vars
+    except Exception as e:
+        return str(e)
+
+# Example usage
+script = "# \u8ba1\u7b97\u524d100\u4e2a\u6590\u6ce2\u7eb3\u5207\u6570\u5217\u7684\u548c\n\ndef fibonacci_sum(n):\n    a, b = 0, 1\n    sum = 0\n    for _ in range(n):\n        sum += a\n        a, b = b, a + b\n    return sum\n\nfibonacci_sum(100)"
+print(script)
+output = run_python_script(script)
+print(output)
+# Below is the program to run; how should the code above be modified to capture fibonacci_sum's output?
+def fibonacci_sum(n):
+    a, b = 0, 1
+    sum = 0
+    for _ in range(n):
+        sum += a
+        a, b = b, a + b
+    return sum
+
+print(fibonacci_sum(100))
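On the question left in the file: redirecting stdout around `exec` captures anything the script prints, though a bare trailing expression like `fibonacci_sum(100)` still produces no output unless the script prints it (or the runner evaluates the final expression via `ast`). A minimal sketch (`run_python_script_captured` is a hypothetical variant, not part of the package):

import contextlib
import io

def run_python_script_captured(script):
    local_vars = {}
    buf = io.StringIO()
    try:
        with contextlib.redirect_stdout(buf):
            exec(script, {}, local_vars)
    except Exception as e:
        return {"error": str(e)}
    return {"stdout": buf.getvalue(), "locals": local_vars}

print(run_python_script_captured("print(sum(range(10)))"))
# {'stdout': '45\n', 'locals': {}}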