beswarm 0.1.12__py3-none-any.whl → 0.1.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (75)
  1. beswarm/aient/main.py +50 -0
  2. beswarm/aient/setup.py +15 -0
  3. beswarm/aient/src/aient/__init__.py +1 -0
  4. beswarm/aient/src/aient/core/__init__.py +1 -0
  5. beswarm/aient/src/aient/core/log_config.py +6 -0
  6. beswarm/aient/src/aient/core/models.py +232 -0
  7. beswarm/aient/src/aient/core/request.py +1665 -0
  8. beswarm/aient/src/aient/core/response.py +617 -0
  9. beswarm/aient/src/aient/core/test/test_base_api.py +18 -0
  10. beswarm/aient/src/aient/core/test/test_image.py +15 -0
  11. beswarm/aient/src/aient/core/test/test_payload.py +92 -0
  12. beswarm/aient/src/aient/core/utils.py +715 -0
  13. beswarm/aient/src/aient/models/__init__.py +9 -0
  14. beswarm/aient/src/aient/models/audio.py +63 -0
  15. beswarm/aient/src/aient/models/base.py +251 -0
  16. beswarm/aient/src/aient/models/chatgpt.py +938 -0
  17. beswarm/aient/src/aient/models/claude.py +640 -0
  18. beswarm/aient/src/aient/models/duckduckgo.py +241 -0
  19. beswarm/aient/src/aient/models/gemini.py +357 -0
  20. beswarm/aient/src/aient/models/groq.py +268 -0
  21. beswarm/aient/src/aient/models/vertex.py +420 -0
  22. beswarm/aient/src/aient/plugins/__init__.py +33 -0
  23. beswarm/aient/src/aient/plugins/arXiv.py +48 -0
  24. beswarm/aient/src/aient/plugins/config.py +172 -0
  25. beswarm/aient/src/aient/plugins/excute_command.py +35 -0
  26. beswarm/aient/src/aient/plugins/get_time.py +19 -0
  27. beswarm/aient/src/aient/plugins/image.py +72 -0
  28. beswarm/aient/src/aient/plugins/list_directory.py +50 -0
  29. beswarm/aient/src/aient/plugins/read_file.py +79 -0
  30. beswarm/aient/src/aient/plugins/registry.py +116 -0
  31. beswarm/aient/src/aient/plugins/run_python.py +156 -0
  32. beswarm/aient/src/aient/plugins/websearch.py +394 -0
  33. beswarm/aient/src/aient/plugins/write_file.py +51 -0
  34. beswarm/aient/src/aient/prompt/__init__.py +1 -0
  35. beswarm/aient/src/aient/prompt/agent.py +280 -0
  36. beswarm/aient/src/aient/utils/__init__.py +0 -0
  37. beswarm/aient/src/aient/utils/prompt.py +143 -0
  38. beswarm/aient/src/aient/utils/scripts.py +721 -0
  39. beswarm/aient/test/chatgpt.py +161 -0
  40. beswarm/aient/test/claude.py +32 -0
  41. beswarm/aient/test/test.py +2 -0
  42. beswarm/aient/test/test_API.py +6 -0
  43. beswarm/aient/test/test_Deepbricks.py +20 -0
  44. beswarm/aient/test/test_Web_crawler.py +262 -0
  45. beswarm/aient/test/test_aiwaves.py +25 -0
  46. beswarm/aient/test/test_aiwaves_arxiv.py +19 -0
  47. beswarm/aient/test/test_ask_gemini.py +8 -0
  48. beswarm/aient/test/test_class.py +17 -0
  49. beswarm/aient/test/test_claude.py +23 -0
  50. beswarm/aient/test/test_claude_zh_char.py +26 -0
  51. beswarm/aient/test/test_ddg_search.py +50 -0
  52. beswarm/aient/test/test_download_pdf.py +56 -0
  53. beswarm/aient/test/test_gemini.py +97 -0
  54. beswarm/aient/test/test_get_token_dict.py +21 -0
  55. beswarm/aient/test/test_google_search.py +35 -0
  56. beswarm/aient/test/test_jieba.py +32 -0
  57. beswarm/aient/test/test_json.py +65 -0
  58. beswarm/aient/test/test_langchain_search_old.py +235 -0
  59. beswarm/aient/test/test_logging.py +32 -0
  60. beswarm/aient/test/test_ollama.py +55 -0
  61. beswarm/aient/test/test_plugin.py +16 -0
  62. beswarm/aient/test/test_py_run.py +26 -0
  63. beswarm/aient/test/test_requests.py +162 -0
  64. beswarm/aient/test/test_search.py +18 -0
  65. beswarm/aient/test/test_tikitoken.py +19 -0
  66. beswarm/aient/test/test_token.py +94 -0
  67. beswarm/aient/test/test_url.py +33 -0
  68. beswarm/aient/test/test_whisper.py +14 -0
  69. beswarm/aient/test/test_wildcard.py +20 -0
  70. beswarm/aient/test/test_yjh.py +21 -0
  71. {beswarm-0.1.12.dist-info → beswarm-0.1.13.dist-info}/METADATA +1 -1
  72. beswarm-0.1.13.dist-info/RECORD +131 -0
  73. beswarm-0.1.12.dist-info/RECORD +0 -61
  74. {beswarm-0.1.12.dist-info → beswarm-0.1.13.dist-info}/WHEEL +0 -0
  75. {beswarm-0.1.12.dist-info → beswarm-0.1.13.dist-info}/top_level.txt +0 -0
beswarm/aient/test/test_download_pdf.py
@@ -0,0 +1,56 @@
+ # import requests
+ # import urllib.parse
+ # import os
+ # import sys
+ # sys.path.append(os.getcwd())
+ # import config
+
+ # from langchain.chat_models import ChatOpenAI
+ # from langchain.embeddings.openai import OpenAIEmbeddings
+ # from langchain.vectorstores import Chroma
+ # from langchain.text_splitter import CharacterTextSplitter
+ # from langchain.document_loaders import UnstructuredPDFLoader
+ # from langchain.chains import RetrievalQA
+
+
+ # def get_doc_from_url(url):
+ #     filename = urllib.parse.unquote(url.split("/")[-1])
+ #     response = requests.get(url, stream=True)
+ #     with open(filename, 'wb') as f:
+ #         for chunk in response.iter_content(chunk_size=1024):
+ #             f.write(chunk)
+ #     return filename
+
+ # def pdf_search(docurl, query_message, model="gpt-3.5-turbo"):
+ #     chatllm = ChatOpenAI(temperature=0.5, openai_api_base=config.API_URL.split("chat")[0], model_name=model, openai_api_key=os.environ.get('API', None))
+ #     embeddings = OpenAIEmbeddings(openai_api_base=config.API_URL.split("chat")[0], openai_api_key=os.environ.get('API', None))
+ #     filename = get_doc_from_url(docurl)
+ #     docpath = os.getcwd() + "/" + filename
+ #     loader = UnstructuredPDFLoader(docpath)
+ #     print(docpath)
+ #     documents = loader.load()
+ #     os.remove(docpath)
+ #     # Initialize the text splitter
+ #     text_splitter = CharacterTextSplitter(chunk_size=100, chunk_overlap=25)
+ #     # Split the loaded documents
+ #     split_docs = text_splitter.split_documents(documents)
+ #     vector_store = Chroma.from_documents(split_docs, embeddings)
+ #     # Create the QA chain
+ #     qa = RetrievalQA.from_chain_type(llm=chatllm, chain_type="stuff", retriever=vector_store.as_retriever(), return_source_documents=True)
+ #     # Run the query
+ #     result = qa({"query": query_message})
+ #     return result['result']
+
+ # pdf_search("https://www.nsfc.gov.cn/csc/20345/22468/pdf/2001/%E5%86%BB%E7%BB%93%E8%A3%82%E9%9A%99%E7%A0%82%E5%B2%A9%E4%BD%8E%E5%91%A8%E5%BE%AA%E7%8E%AF%E5%8A%A8%E5%8A%9B%E7%89%B9%E6%80%A7%E8%AF%95%E9%AA%8C%E7%A0%94%E7%A9%B6.pdf", "端水实验的目的是什么?")
+
+ from PyPDF2 import PdfReader
+
+ def has_text(pdf_path):
+     # Returns the extracted text of the first page (truthy when a text
+     # layer exists), not a strict boolean
+     with open(pdf_path, 'rb') as file:
+         pdf = PdfReader(file)
+         page = pdf.pages[0]
+         text = page.extract_text()
+         return text
+
+ pdf_path = '/Users/yanyuming/Downloads/GitHub/ChatGPT-Telegram-Bot/冻结裂隙砂岩低周循环动力特性试验研究.pdf'
+ print(has_text(pdf_path))
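Editor's note: has_text only inspects the first page, so a document whose first page is a scanned image reads as empty even if later pages carry a text layer. A whole-document check, as a minimal sketch (pdf_has_text_layer is a hypothetical helper, not part of the package):

from PyPDF2 import PdfReader

def pdf_has_text_layer(pdf_path):
    # True if any page yields non-whitespace text
    reader = PdfReader(pdf_path)
    return any((page.extract_text() or "").strip() for page in reader.pages)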
beswarm/aient/test/test_gemini.py
@@ -0,0 +1,97 @@
+ import json
+
+ class JSONExtractor:
+     def __init__(self):
+         self.buffer = ""
+         self.bracket_count = 0
+         self.in_target = False
+         self.target_json = ""
+
+     def process_line(self, line):
+         self.buffer += line.strip()
+
+         for char in line:
+             if char == '{':
+                 self.bracket_count += 1
+                 # The functionCall object opens at nesting depth 5:
+                 # root -> candidates[] -> content -> parts[] -> functionCall
+                 if self.bracket_count == 5 and '"functionCall"' in self.buffer[-20:]:
+                     self.in_target = True
+                     self.target_json = '{'
+                     continue
+                 if self.in_target:
+                     self.target_json += char
+             elif char == '}':
+                 if self.in_target:
+                     self.target_json += '}'
+                 self.bracket_count -= 1
+                 if self.bracket_count == 4 and self.in_target:
+                     self.in_target = False
+                     return self.parse_target_json()
+             elif self.in_target:
+                 self.target_json += char
+
+         return None
+
+     def parse_target_json(self):
+         # target_json holds the functionCall object itself
+         try:
+             return json.loads(self.target_json)
+         except json.JSONDecodeError:
+             return None
+
+ # Example usage
+ extractor = JSONExtractor()
+
+ # Simulate receiving streamed data
+ sample_lines = [
+     '{\n',
+     '  "candidates": [\n',
+     '    {\n',
+     '      "content": {\n',
+     '        "parts": [\n',
+     '          {\n',
+     '            "functionCall": {\n',
+     '              "name": "get_search_results",\n',
+     '              "args": {\n',
+     '                "prompt": "Claude Opus 3.5 release date"\n',
+     '              }\n',
+     '            }\n',
+     '          }\n',
+     '        ],\n',
+     '        "role": "model"\n',
+     '      },\n',
+     '      "finishReason": "STOP",\n',
+     '      "index": 0,\n',
+     '      "safetyRatings": [\n',
+     '        {\n',
+     '          "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",\n',
+     '          "probability": "NEGLIGIBLE"\n',
+     '        },\n',
+     '        {\n',
+     '          "category": "HARM_CATEGORY_HATE_SPEECH",\n',
+     '          "probability": "NEGLIGIBLE"\n',
+     '        },\n',
+     '        {\n',
+     '          "category": "HARM_CATEGORY_HARASSMENT",\n',
+     '          "probability": "NEGLIGIBLE"\n',
+     '        },\n',
+     '        {\n',
+     '          "category": "HARM_CATEGORY_DANGEROUS_CONTENT",\n',
+     '          "probability": "NEGLIGIBLE"\n',
+     '        }\n',
+     '      ]\n',
+     '    }\n',
+     '  ],\n',
+     '  "usageMetadata": {\n',
+     '    "promptTokenCount": 113,\n',
+     '    "candidatesTokenCount": 55,\n',
+     '    "totalTokenCount": 168\n',
+     '  }\n',
+     '}\n'
+ ]
+
+ for line in sample_lines:
+     result = extractor.process_line(line)
+     if result:
+         print("Extracted functionCall:")
+         print(json.dumps(result, indent=2, ensure_ascii=False))
+         break
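Editor's note: counting braces breaks as soon as a brace appears inside a string value, and the depth constants above are tied to one specific response shape. An alternative sketch (editor's illustration, not part of the package) that searches the accumulated buffer for the object following the "functionCall" key and lets the standard decoder find its end:

import json

def extract_function_call(buffer):
    # Locate the JSON object that follows '"functionCall":' in buffered text
    marker = '"functionCall"'
    start = buffer.find(marker)
    if start == -1:
        return None
    brace = buffer.find('{', start + len(marker))
    if brace == -1:
        return None
    try:
        obj, _ = json.JSONDecoder().raw_decode(buffer[brace:])
        return obj
    except json.JSONDecodeError:
        return None  # object still incomplete; wait for more data

Calling this after each received chunk, on the concatenated text, returns None until the object is complete and then the parsed functionCall.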
beswarm/aient/test/test_get_token_dict.py
@@ -0,0 +1,21 @@
+ from collections import defaultdict
+
+ # Define a default value factory; int initializes missing keys to 0
+ default_dict = defaultdict(int)
+
+ # Example usage
+ print(default_dict['a'])  # Output: 0 ('a' is missing, so it is initialized to 0)
+ default_dict['a'] += 1
+ print(default_dict['a'])  # Output: 1
+
+ # Other factory functions work too, e.g. list
+ list_default_dict = defaultdict(list)
+ print(list_default_dict['b'])  # Output: [] ('b' is missing, so it is initialized to an empty list)
+ list_default_dict['b'].append(2)
+ print(list_default_dict['b'])  # Output: [2]
+
+ # An existing dict can also be converted to a defaultdict
+ existing_dict = {'c': 3, 'd': 4}
+ default_dict = defaultdict(int, existing_dict)
+ print(default_dict['c'])  # Output: 3
+ print(default_dict['e'])  # Output: 0 ('e' is missing, so it is initialized to 0)
beswarm/aient/test/test_google_search.py
@@ -0,0 +1,35 @@
+ import os
+ import requests
+ from googleapiclient.discovery import build
+ from dotenv import load_dotenv
+ load_dotenv()
+
+ search_engine_id = os.environ.get('GOOGLE_CSE_ID', None)
+ api_key = os.environ.get('GOOGLE_API_KEY', None)
+ query = "Python 编程"
+
+ def google_search1(query, api_key, search_engine_id):
+     service = build("customsearch", "v1", developerKey=api_key)
+     res = service.cse().list(q=query, cx=search_engine_id).execute()
+     link_list = [item['link'] for item in res['items']]
+     return link_list
+
+ def google_search2(query, api_key, cx):
+     url = "https://www.googleapis.com/customsearch/v1"
+     params = {
+         'q': query,
+         'key': api_key,
+         'cx': cx
+     }
+     response = requests.get(url, params=params)
+     print(response.text)
+     results = response.json()
+     link_list = [item['link'] for item in results.get('items', [])]
+
+     return link_list
+
+ # results = google_search1(query, api_key, search_engine_id)
+ # print(results)
+
+ results = google_search2(query, api_key, search_engine_id)
+ print(results)
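Editor's note: the two helpers differ in robustness. google_search1 indexes res['items'] directly and raises KeyError when a query returns no results, while google_search2 guards with .get but ignores HTTP errors. A defensive variant, as a sketch (google_search_safe is a hypothetical name; num is the CSE page-size parameter, capped at 10 by the API):

def google_search_safe(query, api_key, cx, num=10):
    resp = requests.get(
        "https://www.googleapis.com/customsearch/v1",
        params={'q': query, 'key': api_key, 'cx': cx, 'num': num},
        timeout=10,
    )
    resp.raise_for_status()  # surface HTTP 4xx/5xx instead of parsing an error body
    return [item['link'] for item in resp.json().get('items', [])]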
beswarm/aient/test/test_jieba.py
@@ -0,0 +1,32 @@
+ import jieba
+ import jieba.analyse
+
+ # Load the text
+ # text = "话说葬送的芙莉莲动漫是半年番还是季番?完结没?"
+ # text = "民进党当初为什么支持柯文哲选台北市长?"
+ text = "今天的微博热搜有哪些?"
+ # text = "How much does the 'zeabur' software service cost per month? Is it free to use? Any limitations?"
+
+ # Extract keywords with the TF-IDF algorithm
+ keywords_tfidf = jieba.analyse.extract_tags(text, topK=10, withWeight=False, allowPOS=())
+
+ # Extract keywords with the TextRank algorithm
+ keywords_textrank = jieba.analyse.textrank(text, topK=10, withWeight=False, allowPOS=('ns', 'n', 'vn', 'v'))
+
+ print("TF-IDF keywords:", keywords_tfidf)
+ print("TextRank keywords:", keywords_textrank)
+
+
+ seg_list = jieba.cut(text, cut_all=True)
+ print("Full Mode: " + " ".join(seg_list))  # full mode
+
+ seg_list = jieba.cut(text, cut_all=False)
+ print("Default Mode: " + " ".join(seg_list))  # precise mode
+
+ seg_list = jieba.cut(text)  # precise mode is the default
+ print(" ".join(seg_list))
+
+ seg_list = jieba.cut_for_search(text)  # search-engine mode
+ result = " ".join(seg_list)
+
+ print([result] * 3)
beswarm/aient/test/test_json.py
@@ -0,0 +1,65 @@
+ import json
+
+ # json_data = '爱'
+ # # json_data = '爱的主人,我会尽快为您规划一个走线到美国的安全路线。请您稍等片刻。\n\n首先,我会检查免签国家并为您提供相应的信息。接下来,我会搜索有关旅行到美国的安全建议和路线规划。{}'
+
+ def split_json_strings(input_string):
+     # Result list and the JSON string currently being accumulated
+     json_strings = []
+     current_json = ""
+     brace_count = 0
+
+     # Walk the input string character by character
+     for char in input_string:
+         current_json += char
+         if char == '{':
+             brace_count += 1
+         elif char == '}':
+             brace_count -= 1
+
+             # When the braces balance, we may have a complete JSON string
+             if brace_count == 0:
+                 # Try to parse the current candidate
+                 try:
+                     json.loads(current_json)
+                     json_strings.append(current_json)
+                     current_json = ""
+                 except json.JSONDecodeError:
+                     # Parsing failed; keep accumulating characters
+                     pass
+     if json_strings == []:
+         json_strings.append(input_string)
+     return json_strings
+
+ # Test the function (note: the input is deliberately truncated JSON)
+ input_string = '{"url": "https://github.com/fastai/fasthtml"'
+ result = split_json_strings(input_string)
+
+ for i, json_str in enumerate(result, 1):
+     print(f"JSON {i}:", json_str)
+     try:
+         print("Parsed:", json.loads(json_str))
+     except json.JSONDecodeError as e:
+         # The truncated test input above lands here
+         print("Parse error:", e)
+     print()
+
+ # def check_json(json_data):
+ #     while True:
+ #         try:
+ #             json.loads(json_data)
+ #             break
+ #         except json.decoder.JSONDecodeError as e:
+ #             print("JSON error:", e)
+ #             print("JSON body", repr(json_data))
+ #             if "Invalid control character" in str(e):
+ #                 json_data = json_data.replace("\n", "\\n")
+ #             if "Unterminated string starting" in str(e):
+ #                 json_data += '"}'
+ #             if "Expecting ',' delimiter" in str(e):
+ #                 json_data += '}'
+ #             if "Expecting value: line 1 column 1" in str(e):
+ #                 json_data = '{"prompt": ' + json.dumps(json_data) + '}'
+ #     return json_data
+ # print(json.loads(check_json(json_data)))
+
+ # a = '''
+ # '''
+
+ # print(json.loads(a))
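Editor's note: the splitter above revalidates with json.loads every time the braces balance, and miscounts braces that occur inside string values until validation rescues it. An alternative sketch (editor's illustration, not from the package) that lets the standard decoder find object boundaries directly:

import json

def split_json_strings_strict(s):
    decoder = json.JSONDecoder()
    out, i = [], 0
    while i < len(s):
        start = s.find('{', i)
        if start == -1:
            break
        try:
            _, end = decoder.raw_decode(s, start)
            out.append(s[start:end])
            i = end
        except json.JSONDecodeError:
            i = start + 1  # no valid object here; keep scanning
    return out or [s]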
beswarm/aient/test/test_langchain_search_old.py
@@ -0,0 +1,235 @@
+ import os
+ import re
+
+ import sys
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+ import config
+
+ from langchain.chat_models import ChatOpenAI
+
+
+ from langchain.chains import RetrievalQA, RetrievalQAWithSourcesChain, LLMChain
+ from langchain.prompts import PromptTemplate
+
+ from langchain.prompts.chat import (
+     ChatPromptTemplate,
+     SystemMessagePromptTemplate,
+     HumanMessagePromptTemplate,
+ )
+ from langchain.embeddings.openai import OpenAIEmbeddings
+ from langchain.vectorstores import Chroma
+ from langchain.text_splitter import CharacterTextSplitter
+
+ from langchain.document_loaders import UnstructuredPDFLoader
+
+ # NOTE: get_doc_from_url, ThreadWithReturnValue, get_url_text_list, cut_message
+ # and config.LANGUAGE are referenced below but defined elsewhere in the project
+ # this scratch file came from; the functions that use them will not run as-is.
+
+ def getmd5(string):
+     import hashlib
+     md5_hash = hashlib.md5()
+     md5_hash.update(string.encode('utf-8'))
+     md5_hex = md5_hash.hexdigest()
+     return md5_hex
+
+ from utils.sitemap import SitemapLoader
+ async def get_doc_from_sitemap(url):
+     # https://www.langchain.asia/modules/indexes/document_loaders/examples/sitemap#%E8%BF%87%E6%BB%A4%E7%AB%99%E7%82%B9%E5%9C%B0%E5%9B%BE-url-
+     sitemap_loader = SitemapLoader(web_path=url)
+     docs = await sitemap_loader.load()
+     return docs
+
+ async def get_doc_from_local(docpath, doctype="md"):
+     from langchain.document_loaders import DirectoryLoader
+     # Load every file of the given type under the directory
+     loader = DirectoryLoader(docpath, glob='**/*.' + doctype)
+     # Turn the data into document objects, one per file
+     documents = loader.load()
+     return documents
+
+ system_template = """Use the following pieces of context to answer the user's question.
+ If you don't know the answer, just say "Hmm..., I'm not sure.", don't try to make up an answer.
+ ALWAYS return a "Sources" part in your answer.
+ The "Sources" part should be a reference to the source of the document from which you got your answer.
+
+ Example of your response should be:
+
+ ```
+ The answer is foo
+
+ Sources:
+ 1. abc
+ 2. xyz
+ ```
+ Begin!
+ ----------------
+ {summaries}
+ """
+ messages = [
+     SystemMessagePromptTemplate.from_template(system_template),
+     HumanMessagePromptTemplate.from_template("{question}")
+ ]
+ prompt = ChatPromptTemplate.from_messages(messages)
+
+ def get_chain(store, llm):
+     chain_type_kwargs = {"prompt": prompt}
+     chain = RetrievalQAWithSourcesChain.from_chain_type(
+         llm,
+         chain_type="stuff",
+         retriever=store.as_retriever(),
+         chain_type_kwargs=chain_type_kwargs,
+         reduce_k_below_max_tokens=True
+     )
+     return chain
+
+ async def docQA(docpath, query_message, persist_db_path="db", model="gpt-3.5-turbo"):
+     chatllm = ChatOpenAI(temperature=0.5, openai_api_base=config.bot_api_url.v1_url, model_name=model, openai_api_key=config.API)
+     embeddings = OpenAIEmbeddings(openai_api_base=config.bot_api_url.v1_url, openai_api_key=config.API)
+
+     sitemap = "sitemap.xml"
+     match = re.match(r'^(https?|ftp)://[^\s/$.?#].[^\s]*$', docpath)
+     if match:
+         doc_method = get_doc_from_sitemap
+         docpath = os.path.join(docpath, sitemap)
+     else:
+         doc_method = get_doc_from_local
+
+     persist_db_path = getmd5(docpath)
+     if not os.path.exists(persist_db_path):
+         documents = await doc_method(docpath)
+         # Initialize the text splitter
+         text_splitter = CharacterTextSplitter(chunk_size=100, chunk_overlap=50)
+         # Persist the data
+         split_docs = text_splitter.split_documents(documents)
+         vector_store = Chroma.from_documents(split_docs, embeddings, persist_directory=persist_db_path)
+         vector_store.persist()
+     else:
+         # Load the persisted data
+         vector_store = Chroma(persist_directory=persist_db_path, embedding_function=embeddings)
+
+     # Create the QA chain
+     qa = get_chain(vector_store, chatllm)
+     # qa = RetrievalQA.from_chain_type(llm=chatllm, chain_type="stuff", retriever=vector_store.as_retriever(), return_source_documents=True)
+     # Run the query
+     result = qa({"question": query_message})
+     return result
+
+
+ def persist_emdedding_pdf(docurl, persist_db_path):
+     embeddings = OpenAIEmbeddings(openai_api_base=config.bot_api_url.v1_url, openai_api_key=os.environ.get('API', None))
+     filename = get_doc_from_url(docurl)
+     docpath = os.getcwd() + "/" + filename
+     loader = UnstructuredPDFLoader(docpath)
+     documents = loader.load()
+     # Initialize the text splitter
+     text_splitter = CharacterTextSplitter(chunk_size=100, chunk_overlap=25)
+     # Split the loaded documents
+     split_docs = text_splitter.split_documents(documents)
+     vector_store = Chroma.from_documents(split_docs, embeddings, persist_directory=persist_db_path)
+     vector_store.persist()
+     os.remove(docpath)
+     return vector_store
+
+ async def pdfQA(docurl, docpath, query_message, model="gpt-3.5-turbo"):
+     chatllm = ChatOpenAI(temperature=0.5, openai_api_base=config.bot_api_url.v1_url, model_name=model, openai_api_key=os.environ.get('API', None))
+     embeddings = OpenAIEmbeddings(openai_api_base=config.bot_api_url.v1_url, openai_api_key=os.environ.get('API', None))
+     persist_db_path = getmd5(docpath)
+     if not os.path.exists(persist_db_path):
+         vector_store = persist_emdedding_pdf(docurl, persist_db_path)
+     else:
+         vector_store = Chroma(persist_directory=persist_db_path, embedding_function=embeddings)
+     qa = RetrievalQA.from_chain_type(llm=chatllm, chain_type="stuff", retriever=vector_store.as_retriever(), return_source_documents=True)
+     result = qa({"query": query_message})
+     return result['result']
+
+
+ def pdf_search(docurl, query_message, model="gpt-3.5-turbo"):
+     chatllm = ChatOpenAI(temperature=0.5, openai_api_base=config.bot_api_url.v1_url, model_name=model, openai_api_key=os.environ.get('API', None))
+     embeddings = OpenAIEmbeddings(openai_api_base=config.bot_api_url.v1_url, openai_api_key=os.environ.get('API', None))
+     filename = get_doc_from_url(docurl)
+     docpath = os.getcwd() + "/" + filename
+     loader = UnstructuredPDFLoader(docpath)
+     try:
+         documents = loader.load()
+     except Exception:
+         print("pdf load error! docpath:", docpath)
+         return ""
+     os.remove(docpath)
+     # Initialize the text splitter
+     text_splitter = CharacterTextSplitter(chunk_size=100, chunk_overlap=25)
+     # Split the loaded documents
+     split_docs = text_splitter.split_documents(documents)
+     vector_store = Chroma.from_documents(split_docs, embeddings)
+     # Create the QA chain
+     qa = RetrievalQA.from_chain_type(llm=chatllm, chain_type="stuff", retriever=vector_store.as_retriever(), return_source_documents=True)
+     # Run the query
+     result = qa({"query": query_message})
+     return result['result']
+
+ def summary_each_url(threads, chainllm, prompt):
+     summary_prompt = PromptTemplate(
+         input_variables=["web_summary", "question", "language"],
+         template=(
+             "You need to response the following question: {question}."
+             "Your task is answer the above question in {language} based on the Search results provided. Provide a detailed and in-depth response"
+             "If there is no relevant content in the search results, just answer None, do not make any explanations."
+             "Search results: {web_summary}."
+         ),
+     )
+     summary_threads = []
+
+     for t in threads:
+         tmp = t.join()
+         print(tmp)
+         chain = LLMChain(llm=chainllm, prompt=summary_prompt)
+         chain_thread = ThreadWithReturnValue(target=chain.run, args=({"web_summary": tmp, "question": prompt, "language": config.LANGUAGE},))
+         chain_thread.start()
+         summary_threads.append(chain_thread)
+
+     url_result = ""
+     for t in summary_threads:
+         tmp = t.join()
+         print("summary", tmp)
+         if tmp != "None":
+             url_result += "\n\n" + tmp
+     return url_result
+
+ def get_search_results(prompt: str, context_max_tokens: int):
+
+     url_text_list = get_url_text_list(prompt)
+     useful_source_text = "\n\n".join(url_text_list)
+     # useful_source_text = summary_each_url(threads, chainllm, prompt)
+
+     useful_source_text, search_tokens_len = cut_message(useful_source_text, context_max_tokens)
+     print("search tokens len", search_tokens_len, "\n\n")
+
+     return useful_source_text
+
+ from typing import Any
+ from langchain.schema.output import LLMResult
+ from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
+ class ChainStreamHandler(StreamingStdOutCallbackHandler):
+     def __init__(self):
+         self.tokens = []
+         # Set to True once streaming has finished
+         self.finish = False
+         self.answer = ""
+
+     def on_llm_new_token(self, token: str, **kwargs):
+         # print(token)
+         self.tokens.append(token)
+         # yield ''.join(self.tokens)
+         # print(''.join(self.tokens))
+
+     def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
+         self.finish = True
+
+     def on_llm_error(self, error: Exception, **kwargs: Any) -> None:
+         print(str(error))
+         self.tokens.append(str(error))
+
+     def generate_tokens(self):
+         # This polls the token list in a tight loop; see the queue-based
+         # sketch after this file's diff for a non-spinning variant.
+         while not self.finish or self.tokens:
+             if self.tokens:
+                 data = self.tokens.pop(0)
+                 self.answer += data
+                 yield data
+             else:
+                 pass
+         return self.answer
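Editor's note: generate_tokens above busy-waits on an empty list while the callback thread produces tokens. A minimal non-spinning sketch (editor's illustration, assuming the same callback interface), using queue.Queue and a sentinel:

import queue
from typing import Any
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler

class QueueStreamHandler(StreamingStdOutCallbackHandler):
    _DONE = object()  # sentinel marking end of stream

    def __init__(self):
        self.q = queue.Queue()
        self.answer = ""

    def on_llm_new_token(self, token: str, **kwargs):
        self.q.put(token)

    def on_llm_end(self, response, **kwargs: Any) -> None:
        self.q.put(self._DONE)

    def on_llm_error(self, error: Exception, **kwargs: Any) -> None:
        self.q.put(str(error))
        self.q.put(self._DONE)

    def generate_tokens(self):
        while True:
            token = self.q.get()  # blocks instead of busy-waiting
            if token is self._DONE:
                break
            self.answer += token
            yield token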
beswarm/aient/test/test_logging.py
@@ -0,0 +1,32 @@
+ import logging
+
+ class SpecificStringFilter(logging.Filter):
+     def __init__(self, specific_string):
+         super().__init__()
+         self.specific_string = specific_string
+
+     def filter(self, record):
+         return self.specific_string not in record.getMessage()
+
+ # Create a logger
+ logger = logging.getLogger('my_logger')
+ logger.setLevel(logging.DEBUG)
+
+ # Create a console handler and set its level to debug
+ ch = logging.StreamHandler()
+ # ch.setLevel(logging.DEBUG)
+
+ # Create a filter instance
+ specific_string = "httpx.RemoteProtocolError: Server disconnected without sending a response."
+ my_filter = SpecificStringFilter(specific_string)
+
+ # Add the filter to the handler
+ ch.addFilter(my_filter)
+
+ # Add the handler to the logger
+ logger.addHandler(ch)
+
+ # Test log messages
+ logger.debug("This is a debug message.")
+ logger.error("This message will be ignored: ignore me.httpx.RemoteProtocolError: Server disconnected without sending a response.")
+ logger.info("Another info message.")
beswarm/aient/test/test_ollama.py
@@ -0,0 +1,55 @@
+ import os
+ from rich.console import Console
+ from rich.markdown import Markdown
+ import json
+ import requests
+
+ def query_ollama(prompt, model):
+     # Request URL and payload
+     url = 'http://localhost:11434/api/generate'
+     data = {
+         "model": model,
+         "prompt": prompt,
+         "stream": True,
+     }
+
+     response = requests.Session().post(
+         url,
+         json=data,
+         stream=True,
+     )
+     full_response: str = ""
+     for line in response.iter_lines():
+         if not line or line.decode("utf-8")[:6] == "event:" or line.decode("utf-8") == "data: {}":
+             continue
+         line = line.decode("utf-8")
+         # print(line)
+         resp: dict = json.loads(line)
+         content = resp.get("response")
+         if not content:
+             continue
+         full_response += content
+         yield content
+
+ if __name__ == "__main__":
+     console = Console()
+     # model = 'llama2'
+     # model = 'mistral'
+     # model = 'llama3:8b'
+     model = 'phi3:medium'
+     # model = 'qwen:14b'
+     # model = 'wizardlm2:7b'
+     # model = 'codeqwen:7b-chat'
+     # model = 'phi'
+
+     # Query for the answer
+     prompt = r'''
+
+
+     '''
+     answer = ""
+     for result in query_ollama(prompt, model):
+         os.system("clear")
+         answer += result
+         md = Markdown(answer)
+         console.print(md, no_wrap=False)
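Editor's note: Ollama's /api/generate streams newline-delimited JSON rather than server-sent events, so the "event:" and "data: {}" guards above are effectively dead code for this API. Each line carries a "response" fragment until a final line sets "done" to true with an empty "response" (illustrative values):

{"model":"phi3:medium","response":"Hel","done":false}
{"model":"phi3:medium","response":"lo","done":false}
{"model":"phi3:medium","response":"","done":true}

The generator concatenates the "response" fragments, and the "if not content" check is what skips that final stats line.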
beswarm/aient/test/test_plugin.py
@@ -0,0 +1,16 @@
+ import json
+
+ from ..src.aient.plugins.websearch import get_search_results
+ from ..src.aient.plugins.arXiv import download_read_arxiv_pdf
+ from ..src.aient.plugins.image import generate_image
+ from ..src.aient.plugins.get_time import get_time
+ from ..src.aient.plugins.run_python import run_python_script
+
+ from ..src.aient.plugins.config import function_to_json
+
+
+ print(json.dumps(function_to_json(get_search_results), indent=4, ensure_ascii=False))
+ print(json.dumps(function_to_json(download_read_arxiv_pdf), indent=4, ensure_ascii=False))
+ print(json.dumps(function_to_json(generate_image), indent=4, ensure_ascii=False))
+ print(json.dumps(function_to_json(get_time), indent=4, ensure_ascii=False))
+ print(json.dumps(function_to_json(run_python_script), indent=4, ensure_ascii=False))
beswarm/aient/test/test_py_run.py
@@ -0,0 +1,26 @@
+ def run_python_script(script):
+     # Dictionary to capture the script's local variables
+     local_vars = {}
+
+     try:
+         # Execute the script string
+         exec(script, {}, local_vars)
+         return local_vars
+     except Exception as e:
+         return str(e)
+
+ # Example usage
+ script = "# \u8ba1\u7b97\u524d100\u4e2a\u6590\u6ce2\u7eb3\u5207\u6570\u5217\u7684\u548c\n\ndef fibonacci_sum(n):\n    a, b = 0, 1\n    sum = 0\n    for _ in range(n):\n        sum += a\n        a, b = b, a + b\n    return sum\n\nfibonacci_sum(100)"
+ print(script)
+ output = run_python_script(script)
+ print(output)
+ # The program below is what we want to run; how should the code above be
+ # modified so that the result of fibonacci_sum can be captured?
+ def fibonacci_sum(n):
+     a, b = 0, 1
+     sum = 0
+     for _ in range(n):
+         sum += a
+         a, b = b, a + b
+     return sum
+
+ print(fibonacci_sum(100))
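Editor's note: one answer to the question in the comment above is to run the body with exec, split off a trailing expression with ast, and evaluate it separately while redirecting stdout. A sketch (editor's illustration, not part of the package):

import ast
import contextlib
import io

def run_python_script_capture(script):
    env = {}
    buf = io.StringIO()
    tree = ast.parse(script)
    # Split off a trailing expression, if any, so its value can be returned
    last_expr = None
    if tree.body and isinstance(tree.body[-1], ast.Expr):
        last_expr = ast.Expression(tree.body.pop(-1).value)
    with contextlib.redirect_stdout(buf):
        exec(compile(tree, "<script>", "exec"), env)
        value = eval(compile(last_expr, "<script>", "eval"), env) if last_expr else None
    return value, buf.getvalue()

value, printed = run_python_script_capture(script)
print(value)  # the result of fibonacci_sum(100); printed output is in `printed`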