bisheng-langchain 0.3.0rc0__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.
- bisheng_langchain/chat_models/host_llm.py +1 -1
- bisheng_langchain/document_loaders/elem_unstrcutured_loader.py +5 -3
- bisheng_langchain/gpts/agent_types/llm_functions_agent.py +7 -1
- bisheng_langchain/gpts/assistant.py +8 -5
- bisheng_langchain/gpts/auto_optimization.py +28 -27
- bisheng_langchain/gpts/auto_tool_selected.py +14 -15
- bisheng_langchain/gpts/load_tools.py +53 -1
- bisheng_langchain/gpts/prompts/__init__.py +4 -2
- bisheng_langchain/gpts/prompts/assistant_prompt_base.py +1 -0
- bisheng_langchain/gpts/prompts/assistant_prompt_cohere.py +19 -0
- bisheng_langchain/gpts/prompts/opening_dialog_prompt.py +1 -1
- bisheng_langchain/gpts/tools/api_tools/__init__.py +1 -1
- bisheng_langchain/gpts/tools/api_tools/base.py +3 -3
- bisheng_langchain/gpts/tools/api_tools/flow.py +19 -7
- bisheng_langchain/gpts/tools/api_tools/macro_data.py +175 -4
- bisheng_langchain/gpts/tools/api_tools/openapi.py +101 -0
- bisheng_langchain/gpts/tools/api_tools/sina.py +2 -2
- bisheng_langchain/gpts/tools/code_interpreter/tool.py +118 -39
- bisheng_langchain/rag/__init__.py +5 -0
- bisheng_langchain/rag/bisheng_rag_pipeline.py +320 -0
- bisheng_langchain/rag/bisheng_rag_pipeline_v2.py +359 -0
- bisheng_langchain/rag/bisheng_rag_pipeline_v2_cohere_raw_prompting.py +376 -0
- bisheng_langchain/rag/bisheng_rag_tool.py +288 -0
- bisheng_langchain/rag/config/baseline.yaml +86 -0
- bisheng_langchain/rag/config/baseline_caibao.yaml +82 -0
- bisheng_langchain/rag/config/baseline_caibao_knowledge_v2.yaml +110 -0
- bisheng_langchain/rag/config/baseline_caibao_v2.yaml +112 -0
- bisheng_langchain/rag/config/baseline_demo_v2.yaml +92 -0
- bisheng_langchain/rag/config/baseline_s2b_mix.yaml +88 -0
- bisheng_langchain/rag/config/baseline_v2.yaml +90 -0
- bisheng_langchain/rag/extract_info.py +38 -0
- bisheng_langchain/rag/init_retrievers/__init__.py +4 -0
- bisheng_langchain/rag/init_retrievers/baseline_vector_retriever.py +61 -0
- bisheng_langchain/rag/init_retrievers/keyword_retriever.py +65 -0
- bisheng_langchain/rag/init_retrievers/mix_retriever.py +103 -0
- bisheng_langchain/rag/init_retrievers/smaller_chunks_retriever.py +92 -0
- bisheng_langchain/rag/prompts/__init__.py +9 -0
- bisheng_langchain/rag/prompts/extract_key_prompt.py +34 -0
- bisheng_langchain/rag/prompts/prompt.py +47 -0
- bisheng_langchain/rag/prompts/prompt_cohere.py +111 -0
- bisheng_langchain/rag/qa_corpus/__init__.py +0 -0
- bisheng_langchain/rag/qa_corpus/qa_generator.py +143 -0
- bisheng_langchain/rag/rerank/__init__.py +5 -0
- bisheng_langchain/rag/rerank/rerank.py +48 -0
- bisheng_langchain/rag/rerank/rerank_benchmark.py +139 -0
- bisheng_langchain/rag/run_qa_gen_web.py +47 -0
- bisheng_langchain/rag/run_rag_evaluate_web.py +55 -0
- bisheng_langchain/rag/scoring/__init__.py +0 -0
- bisheng_langchain/rag/scoring/llama_index_score.py +91 -0
- bisheng_langchain/rag/scoring/ragas_score.py +183 -0
- bisheng_langchain/rag/utils.py +181 -0
- bisheng_langchain/retrievers/ensemble.py +2 -1
- bisheng_langchain/vectorstores/elastic_keywords_search.py +2 -1
- {bisheng_langchain-0.3.0rc0.dist-info → bisheng_langchain-0.3.1.dist-info}/METADATA +1 -1
- {bisheng_langchain-0.3.0rc0.dist-info → bisheng_langchain-0.3.1.dist-info}/RECORD +57 -22
- bisheng_langchain/gpts/prompts/base_prompt.py +0 -1
- {bisheng_langchain-0.3.0rc0.dist-info → bisheng_langchain-0.3.1.dist-info}/WHEEL +0 -0
- {bisheng_langchain-0.3.0rc0.dist-info → bisheng_langchain-0.3.1.dist-info}/top_level.txt +0 -0
bisheng_langchain/rag/bisheng_rag_pipeline_v2_cohere_raw_prompting.py (new file)

@@ -0,0 +1,376 @@

```python
import argparse
import copy
import time
import inspect
import os
import httpx
import pandas as pd
import yaml
import math
from tqdm import tqdm
from loguru import logger
from collections import defaultdict
from bisheng_langchain.retrievers import EnsembleRetriever
from bisheng_langchain.vectorstores import ElasticKeywordsSearch, Milvus
from langchain.docstore.document import Document
from langchain.chains.question_answering import load_qa_chain
from langchain_community.chat_models.cohere import ChatCohere
from bisheng_langchain.rag.init_retrievers import (
    BaselineVectorRetriever,
    KeywordRetriever,
    MixRetriever,
    SmallerChunksVectorRetriever,
)
from bisheng_langchain.rag.scoring.ragas_score import RagScore
from bisheng_langchain.rag.extract_info import extract_title
from bisheng_langchain.rag.utils import import_by_type, import_class


class BishengRagPipeline:

    def __init__(self, yaml_path) -> None:
        self.yaml_path = yaml_path
        with open(self.yaml_path, 'r') as f:
            self.params = yaml.safe_load(f)

        # init data
        self.origin_file_path = self.params['data']['origin_file_path']
        self.question_path = self.params['data']['question']
        self.save_answer_path = self.params['data']['save_answer']

        # init embeddings
        embedding_params = self.params['embedding']
        embedding_object = import_by_type(_type='embeddings', name=embedding_params['type'])
        if embedding_params['type'] == 'OpenAIEmbeddings' and embedding_params['openai_proxy']:
            embedding_params.pop('type')
            self.embeddings = embedding_object(
                http_client=httpx.Client(proxies=embedding_params['openai_proxy']), **embedding_params
            )
        else:
            embedding_params.pop('type')
            self.embeddings = embedding_object(**embedding_params)

        # init llm
        llm_params = self.params['chat_llm']
        llm_object = import_by_type(_type='llms', name=llm_params['type'])
        if llm_params['type'] == 'ChatOpenAI' and llm_params['openai_proxy']:
            llm_params.pop('type')
            self.llm = llm_object(http_client=httpx.Client(proxies=llm_params['openai_proxy']), **llm_params)
        else:
            llm_params.pop('type')
            self.llm = llm_object(**llm_params)

        # milvus
        self.vector_store = Milvus(
            embedding_function=self.embeddings,
            connection_args={
                "host": self.params['milvus']['host'],
                "port": self.params['milvus']['port'],
            },
        )

        # es
        self.keyword_store = ElasticKeywordsSearch(
            index_name='default_es',
            elasticsearch_url=self.params['elasticsearch']['url'],
            ssl_verify=self.params['elasticsearch']['ssl_verify'],
        )

        # init retriever
        retriever_list = []
        retrievers = self.params['retriever']['retrievers']
        for retriever in retrievers:
            retriever_type = retriever.pop('type')
            retriever_params = {
                'vector_store': self.vector_store,
                'keyword_store': self.keyword_store,
                'splitter_kwargs': retriever['splitter'],
                'retrieval_kwargs': retriever['retrieval'],
            }
            retriever_list.append(self._post_init_retriever(retriever_type=retriever_type, **retriever_params))
        self.retriever = EnsembleRetriever(retrievers=retriever_list)

    def _post_init_retriever(self, retriever_type, **kwargs):
        retriever_classes = {
            'KeywordRetriever': KeywordRetriever,
            'BaselineVectorRetriever': BaselineVectorRetriever,
            'MixRetriever': MixRetriever,
            'SmallerChunksVectorRetriever': SmallerChunksVectorRetriever,
        }
        if retriever_type not in retriever_classes:
            raise ValueError(f'Unknown retriever type: {retriever_type}')

        input_kwargs = {}
        splitter_params = kwargs.pop('splitter_kwargs')
        for key, value in splitter_params.items():
            splitter_obj = import_by_type(_type='textsplitters', name=value.pop('type'))
            input_kwargs[key] = splitter_obj(**value)

        retrieval_params = kwargs.pop('retrieval_kwargs')
        for key, value in retrieval_params.items():
            input_kwargs[key] = value

        input_kwargs['vector_store'] = kwargs.pop('vector_store')
        input_kwargs['keyword_store'] = kwargs.pop('keyword_store')

        retriever_class = retriever_classes[retriever_type]
        return retriever_class(**input_kwargs)

    def file2knowledge(self):
        """
        file to knowledge
        """
        df = pd.read_excel(self.question_path)
        if ('文件名' not in df.columns) or ('知识库名' not in df.columns):
            raise Exception(f'文件名 or 知识库名 not in {self.question_path}.')

        loader_params = self.params['loader']
        loader_object = import_by_type(_type='documentloaders', name=loader_params.pop('type'))

        all_questions_info = df.to_dict('records')
        collectionname2filename = defaultdict(set)
        for info in all_questions_info:
            # store in a set to de-duplicate file names
            collectionname2filename[info['知识库名']].add(info['文件名'])

        for collection_name in tqdm(collectionname2filename):
            all_file_paths = []
            for file_name in collectionname2filename[collection_name]:
                file_path = os.path.join(self.origin_file_path, file_name)
                if not os.path.exists(file_path):
                    raise Exception(f'{file_path} not exists.')
                # file_path may be a directory or a single file
                if os.path.isdir(file_path):
                    # a directory containing multiple files
                    all_file_paths.extend(
                        [os.path.join(file_path, name) for name in os.listdir(file_path) if not name.startswith('.')]
                    )
                else:
                    # a single file
                    all_file_paths.append(file_path)

            # all files to be stored for the current knowledge base
            collection_name = f"{collection_name}_{self.params['retriever']['suffix']}"
            for index, each_file_path in enumerate(all_file_paths):
                logger.info(f'each_file_path: {each_file_path}')
                loader = loader_object(
                    file_name=os.path.basename(each_file_path), file_path=each_file_path, **loader_params
                )
                documents = loader.load()

                # # load from text
                # if each_file_path.endswith('.pdf'):
                #     with open(each_file_path.replace('.pdf', '.txt'), 'r') as f:
                #         content = f.read()
                #     documents = [Document(page_content=content, metadata={'source': os.path.basename(each_file_path)})]

                logger.info(f'documents: {len(documents)}, page_content: {len(documents[0].page_content)}')
                if len(documents[0].page_content) == 0:
                    logger.error(f'{each_file_path} page_content is empty.')

                # add aux info
                add_aux_info = self.params['retriever'].get('add_aux_info', False)
                if add_aux_info:
                    for doc in documents:
                        try:
                            title = extract_title(llm=self.llm, text=doc.page_content)
                            logger.info(f'extract title: {title}')
                        except Exception as e:
                            logger.error(f"Failed to extract title: {e}")
                            title = ''
                        doc.metadata['title'] = title

                vector_drop_old = self.params['milvus']['drop_old'] if index == 0 else False
                keyword_drop_old = self.params['elasticsearch']['drop_old'] if index == 0 else False
                for idx, retriever in enumerate(self.retriever.retrievers):
                    retriever.add_documents(documents, collection_name, vector_drop_old, add_aux_info=add_aux_info)

    def retrieval_and_rerank(self, question, collection_name, max_content=100000):
        """
        retrieval and rerank
        """
        collection_name = f"{collection_name}_{self.params['retriever']['suffix']}"
        # EnsembleRetriever de-duplicates recalled documents by default
        docs = self.retriever.get_relevant_documents(query=question, collection_name=collection_name)
        logger.info(f'retrieval docs origin: {len(docs)}')

        # delete duplicate
        if self.params['post_retrieval']['delete_duplicate']:
            logger.info(f'origin docs: {len(docs)}')
            all_contents = []
            docs_no_dup = []
            for index, doc in enumerate(docs):
                doc_content = doc.page_content
                if doc_content in all_contents:
                    continue
                all_contents.append(doc_content)
                docs_no_dup.append(doc)
            docs = docs_no_dup
            logger.info(f'delete duplicate docs: {len(docs)}')

        # rerank
        if self.params['post_retrieval']['with_rank'] and len(docs):
            if not hasattr(self, 'ranker'):
                rerank_params = self.params['post_retrieval']['rerank']
                rerank_type = rerank_params.pop('type')
                rerank_object = import_class(f'bisheng_langchain.rag.rerank.{rerank_type}')
                self.ranker = rerank_object(**rerank_params)
            docs = getattr(self, 'ranker').sort_and_filter(question, docs)

        # delete redundancy according to max_content
        doc_num, doc_content_sum = 0, 0
        for doc in docs:
            doc_content_sum += len(doc.page_content)
            if doc_content_sum > max_content:
                break
            doc_num += 1
        docs = docs[:doc_num]
        logger.info(f'retrieval docs after delete redundancy: {len(docs)}')

        # sort by document source and chunk_index to keep the context coherent and consistent
        if self.params['post_retrieval'].get('sort_by_source_and_index', False):
            logger.info('sort chunks by source and chunk_index')
            docs = sorted(docs, key=lambda x: (x.metadata['source'], x.metadata['chunk_index']))
        return docs

    def load_documents(self, file_name, max_content=100000):
        """
        Load the documents directly; overly long documents are simply truncated.
        max_content: maximum content length for the llm
        """
        file_path = os.path.join(self.origin_file_path, file_name)
        if not os.path.exists(file_path):
            raise Exception(f'{file_path} not exists.')
        if os.path.isdir(file_path):
            raise Exception(f'{file_path} is a directory.')

        loader_params = copy.deepcopy(self.params['loader'])
        loader_object = import_by_type(_type='documentloaders', name=loader_params.pop('type'))
        loader = loader_object(file_name=file_name, file_path=file_path, **loader_params)

        documents = loader.load()
        logger.info(f'documents: {len(documents)}, page_content: {len(documents[0].page_content)}')
        for doc in documents:
            doc.page_content = doc.page_content[:max_content]
        return documents

    def question_answering(self):
        """
        question answering over the knowledge base
        """
        df = pd.read_excel(self.question_path)
        all_questions_info = df.to_dict('records')
        if 'prompt_type' in self.params['generate']:
            prompt_type = self.params['generate']['prompt_type']
            prompt = import_class(f'bisheng_langchain.rag.prompts.{prompt_type}')
        else:
            prompt = None
        if not isinstance(self.llm, ChatCohere):
            qa_chain = load_qa_chain(
                llm=self.llm, chain_type=self.params['generate']['chain_type'], prompt=prompt, verbose=False
            )
        file2docs = dict()
        for questions_info in tqdm(all_questions_info):
            question = questions_info['问题']
            file_name = questions_info['文件名']
            collection_name = questions_info['知识库名']

            # if question != '请分析江苏中设集团股份有限公司2021年重大关联交易的情况。':
            #     continue

            if self.params['generate']['with_retrieval']:
                # retrieval and rerank
                docs = self.retrieval_and_rerank(question, collection_name, max_content=self.params['generate']['max_content'])
            else:
                # load document
                if file_name not in file2docs:
                    docs = self.load_documents(file_name, max_content=self.params['generate']['max_content'])
                    file2docs[file_name] = docs
                else:
                    docs = file2docs[file_name]

            if isinstance(self.llm, ChatCohere):
                try:
                    # cohere rag
                    # messages = prompt.format_prompt(question=question).to_messages()
                    # ans = self.llm.invoke(messages, source_documents=docs)
                    # rag_answer = ans.content

                    # cohere rag by raw prompt
                    documents = ''
                    for i, doc in enumerate(docs):
                        # documents += f'Document: {i}\n'
                        # documents += 'text:' + doc.page_content + '\n\n'
                        documents += doc.page_content + '\n\n'
                    messages = prompt.format_prompt(question=question, documents=documents).to_messages()
                    ans = self.llm.invoke(messages, raw_prompting=True)
                    rag_answer = ans.content.replace('Answer: ', '')
                except Exception as e:
                    logger.error(f'question: {question}\nerror: {e}')
                    ans = {'output_text': str(e)}
                    rag_answer = ans['output_text']
            else:
                # question answer
                try:
                    ans = qa_chain({"input_documents": docs, "question": question}, return_only_outputs=True)
                except Exception as e:
                    logger.error(f'question: {question}\nerror: {e}')
                    ans = {'output_text': str(e)}
                rag_answer = ans['output_text']

            # context = '\n\n'.join([doc.page_content for doc in docs])
            # content = prompt.format(context=context, question=question)

            # for rate_limit
            # time.sleep(3)
            logger.info(f'question: {question}\nans: {rag_answer}\n')
            questions_info['rag_answer'] = rag_answer
            # questions_info['rag_context'] = '\n----------------\n'.join([doc.page_content for doc in docs])
            # questions_info['rag_context'] = content

        df = pd.DataFrame(all_questions_info)
        df.to_excel(self.save_answer_path, index=False)

    def score(self):
        """
        score
        """
        metric_params = self.params['metric']
        if metric_params['type'] == 'bisheng-ragas':
            score_params = {
                'excel_path': self.save_answer_path,
                'save_path': os.path.dirname(self.save_answer_path),
                'question_column': metric_params['question_column'],
                'gt_column': metric_params['gt_column'],
                'answer_column': metric_params['answer_column'],
                'query_type_column': metric_params.get('query_type_column', None),
                'contexts_column': metric_params.get('contexts_column', None),
                'metrics': metric_params['metrics'],
                'batch_size': metric_params['batch_size'],
                'gt_split_column': metric_params.get('gt_split_column', None),
                'whether_gtsplit': metric_params.get('whether_gtsplit', False),  # whether the model should split the ground truth into key points
            }
            rag_score = RagScore(**score_params)
            rag_score.score()
        else:
            # todo: other scoring methods
            pass


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Process some integers.')
    # add arguments
    parser.add_argument('--mode', type=str, default='qa', help='upload or qa or score')
    parser.add_argument('--params', type=str, default='config/test/baseline_s2b.yaml', help='bisheng rag params')
    # parse the arguments
    args = parser.parse_args()

    rag = BishengRagPipeline(args.params)

    if args.mode == 'upload':
        rag.file2knowledge()
    elif args.mode == 'qa':
        rag.question_answering()
    elif args.mode == 'score':
        rag.score()
```
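The pipeline above is driven entirely by its YAML config and the `--mode` flag in its `__main__` block. Below is a minimal, hedged usage sketch, not part of the diff: the import path follows the file's location in the package, while the config path is a placeholder for any config providing the `data`, `embedding`, `chat_llm`, `milvus`, `elasticsearch`, `retriever`, `post_retrieval`, `generate`, and `metric` sections the constructor reads.

```python
# Hedged usage sketch, not shipped with the package.
# Assumes a config with the sections BishengRagPipeline reads, plus an Excel
# question sheet containing the 文件名 / 知识库名 / 问题 columns the code expects.
from bisheng_langchain.rag.bisheng_rag_pipeline_v2_cohere_raw_prompting import (
    BishengRagPipeline,
)

rag = BishengRagPipeline('config/baseline_v2.yaml')  # hypothetical path
rag.file2knowledge()      # --mode upload: split and index the source files
rag.question_answering()  # --mode qa: answer each question over the index
rag.score()               # --mode score: evaluate answers with bisheng-ragas
```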
bisheng_langchain/rag/bisheng_rag_tool.py (new file)

@@ -0,0 +1,288 @@

```python
import time
import os
import yaml
import httpx
from typing import Any, Dict, Tuple, Type, Union, Optional
from loguru import logger
from langchain_core.tools import BaseTool, Tool
from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator
from langchain_core.language_models.base import LanguageModelLike
from langchain.chains.question_answering import load_qa_chain
from bisheng_langchain.retrievers import EnsembleRetriever
from bisheng_langchain.vectorstores import ElasticKeywordsSearch, Milvus
from bisheng_langchain.rag.init_retrievers import (
    BaselineVectorRetriever,
    KeywordRetriever,
    MixRetriever,
    SmallerChunksVectorRetriever,
)
from bisheng_langchain.rag.utils import import_by_type, import_class
from bisheng_langchain.rag.extract_info import extract_title


class MultArgsSchemaTool(Tool):

    def _to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]:
        # For backwards compatibility, if run_input is a string,
        # pass as a positional argument.
        if isinstance(tool_input, str):
            return (tool_input, ), {}
        else:
            return (), tool_input


class BishengRAGTool:

    def __init__(
        self,
        vector_store: Optional[Milvus] = None,
        keyword_store: Optional[ElasticKeywordsSearch] = None,
        llm: Optional[LanguageModelLike] = None,
        collection_name: Optional[str] = None,
        **kwargs
    ) -> None:
        if collection_name is None and (keyword_store is None or vector_store is None):
            raise ValueError('collection_name must be provided if keyword_store or vector_store is not provided')
        self.collection_name = collection_name

        yaml_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'config/baseline_v2.yaml')
        with open(yaml_path, 'r', encoding='utf-8') as f:
            self.params = yaml.safe_load(f)

        # update params
        max_content = kwargs.get("max_content", 15000)
        sort_by_source_and_index = kwargs.get("sort_by_source_and_index", True)
        self.params['generate']['max_content'] = max_content
        self.params['post_retrieval']['sort_by_source_and_index'] = sort_by_source_and_index

        # init milvus
        if vector_store:
            self.vector_store = vector_store
        else:
            # init embeddings
            embedding_params = self.params['embedding']
            embedding_object = import_by_type(_type='embeddings', name=embedding_params['type'])
            if embedding_params['type'] == 'OpenAIEmbeddings' and embedding_params['openai_proxy']:
                embedding_params.pop('type')
                self.embeddings = embedding_object(
                    http_client=httpx.Client(proxies=embedding_params['openai_proxy']), **embedding_params
                )
            else:
                embedding_params.pop('type')
                self.embeddings = embedding_object(**embedding_params)

            self.vector_store = Milvus(
                embedding_function=self.embeddings,
                connection_args={
                    "host": self.params['milvus']['host'],
                    "port": self.params['milvus']['port'],
                },
            )

        # init keyword store
        if keyword_store:
            self.keyword_store = keyword_store
        else:
            self.keyword_store = ElasticKeywordsSearch(
                index_name='default_es',
                elasticsearch_url=self.params['elasticsearch']['url'],
                ssl_verify=self.params['elasticsearch']['ssl_verify'],
            )

        # init llm
        if llm:
            self.llm = llm
        else:
            llm_params = self.params['chat_llm']
            llm_object = import_by_type(_type='llms', name=llm_params['type'])
            if llm_params['type'] == 'ChatOpenAI' and llm_params['openai_proxy']:
                llm_params.pop('type')
                self.llm = llm_object(http_client=httpx.Client(proxies=llm_params['openai_proxy']), **llm_params)
            else:
                llm_params.pop('type')
                self.llm = llm_object(**llm_params)

        # init retriever
        retriever_list = []
        retrievers = self.params['retriever']['retrievers']
        for retriever in retrievers:
            retriever_type = retriever.pop('type')
            retriever_params = {
                'vector_store': self.vector_store,
                'keyword_store': self.keyword_store,
                'splitter_kwargs': retriever['splitter'],
                'retrieval_kwargs': retriever['retrieval'],
            }
            retriever_list.append(self._post_init_retriever(retriever_type=retriever_type, **retriever_params))
        self.retriever = EnsembleRetriever(retrievers=retriever_list)

        # init qa chain
        if 'prompt_type' in self.params['generate']:
            prompt_type = self.params['generate']['prompt_type']
            prompt = import_class(f'bisheng_langchain.rag.prompts.{prompt_type}')
        else:
            prompt = None
        self.qa_chain = load_qa_chain(
            llm=self.llm,
            chain_type=self.params['generate']['chain_type'],
            prompt=prompt,
            verbose=False
        )

    def _post_init_retriever(self, retriever_type, **kwargs):
        retriever_classes = {
            'KeywordRetriever': KeywordRetriever,
            'BaselineVectorRetriever': BaselineVectorRetriever,
            'MixRetriever': MixRetriever,
            'SmallerChunksVectorRetriever': SmallerChunksVectorRetriever,
        }
        if retriever_type not in retriever_classes:
            raise ValueError(f'Unknown retriever type: {retriever_type}')

        input_kwargs = {}
        splitter_params = kwargs.pop('splitter_kwargs')
        for key, value in splitter_params.items():
            splitter_obj = import_by_type(_type='textsplitters', name=value.pop('type'))
            input_kwargs[key] = splitter_obj(**value)

        retrieval_params = kwargs.pop('retrieval_kwargs')
        for key, value in retrieval_params.items():
            input_kwargs[key] = value

        input_kwargs['vector_store'] = kwargs.pop('vector_store')
        input_kwargs['keyword_store'] = kwargs.pop('keyword_store')

        retriever_class = retriever_classes[retriever_type]
        return retriever_class(**input_kwargs)

    def file2knowledge(self, file_path, drop_old=True):
        """
        file to knowledge
        """
        loader_params = self.params['loader']
        loader_object = import_by_type(_type='documentloaders', name=loader_params.pop('type'))

        logger.info(f'file_path: {file_path}')
        loader = loader_object(
            file_name=os.path.basename(file_path), file_path=file_path, **loader_params
        )
        documents = loader.load()
        logger.info(f'documents: {len(documents)}, page_content: {len(documents[0].page_content)}')
        if len(documents[0].page_content) == 0:
            logger.error(f'{file_path} page_content is empty.')

        # add aux info
        add_aux_info = self.params['retriever'].get('add_aux_info', False)
        if add_aux_info:
            for doc in documents:
                try:
                    title = extract_title(llm=self.llm, text=doc.page_content)
                    logger.info(f'extract title: {title}')
                except Exception as e:
                    logger.error(f"Failed to extract title: {e}")
                    title = ''
                doc.metadata['title'] = title

        for idx, retriever in enumerate(self.retriever.retrievers):
            retriever.add_documents(
                documents,
                self.collection_name,
                drop_old=drop_old,
                add_aux_info=add_aux_info
            )

    def retrieval_and_rerank(self, query):
        """
        retrieval and rerank
        """
        # EnsembleRetriever de-duplicates recalled documents by default
        docs = self.retriever.get_relevant_documents(
            query=query,
            collection_name=self.collection_name
        )
        logger.info(f'retrieval docs origin: {len(docs)}')

        # delete redundancy according to max_content
        doc_num, doc_content_sum = 0, 0
        for doc in docs:
            doc_content_sum += len(doc.page_content)
            if doc_content_sum > self.params['generate']['max_content']:
                break
            doc_num += 1
        docs = docs[:doc_num]
        logger.info(f'retrieval docs after delete redundancy: {len(docs)}')

        # sort by document source and chunk_index to keep the context coherent and consistent
        if self.params['post_retrieval'].get('sort_by_source_and_index', False):
            logger.info('sort chunks by source and chunk_index')
            docs = sorted(docs, key=lambda x: (x.metadata['source'], x.metadata['chunk_index']))
        return docs

    def run(self, query) -> str:
        docs = self.retrieval_and_rerank(query)
        try:
            ans = self.qa_chain({"input_documents": docs, "question": query}, return_only_outputs=True)
        except Exception as e:
            logger.error(f'question: {query}\nerror: {e}')
            ans = {'output_text': str(e)}
        rag_answer = ans['output_text']
        return rag_answer

    async def arun(self, query: str) -> str:
        rag_answer = self.run(query)
        return rag_answer

    @classmethod
    def get_rag_tool(cls, name, description, **kwargs: Any) -> BaseTool:
        class InputArgs(BaseModel):
            query: str = Field(description='question asked by the user.')

        return MultArgsSchemaTool(name=name,
                                  description=description,
                                  func=cls(**kwargs).run,
                                  coroutine=cls(**kwargs).arun,
                                  args_schema=InputArgs)


if __name__ == '__main__':
    # rag_tool = BishengRAGTool(collection_name='rag_finance_report_0_test')
    # rag_tool.file2knowledge(file_path='/home/public/rag_benchmark_finance_report/金融年报财报的来源文件/2021-04-23__金宇生物技术股份有限公司__600201__生物股份__2020年__年度报告.pdf')

    from langchain.chat_models import ChatOpenAI
    from langchain.embeddings import OpenAIEmbeddings

    # embedding
    embeddings = OpenAIEmbeddings(model='text-embedding-ada-002')
    # llm
    llm = ChatOpenAI(model='gpt-4-1106-preview', temperature=0.01)
    collection_name = 'rag_finance_report_0_benchmark_caibao_1000_source_title'
    # milvus
    vector_store = Milvus(
        collection_name=collection_name,
        embedding_function=embeddings,
        connection_args={
            "host": '110.16.193.170',
            "port": '50062',
        },
    )
    # es
    keyword_store = ElasticKeywordsSearch(
        index_name=collection_name,
        elasticsearch_url='http://110.16.193.170:50062/es',
        ssl_verify={'basic_auth': ["elastic", "oSGL-zVvZ5P3Tm7qkDLC"]},
    )

    tool = BishengRAGTool.get_rag_tool(
        name='rag_knowledge_retrieve',
        description='金融年报财报知识库问答',
        vector_store=vector_store,
        keyword_store=keyword_store,
        llm=llm
    )
    print(tool.run('能否根据2020年金宇生物技术股份有限公司的年报,给我简要介绍一下报告期内公司的社会责任工作情况?'))

    # tool = BishengRAGTool.get_rag_tool(
    #     name='rag_knowledge_retrieve',
    #     description='金融年报财报知识库问答',
    #     collection_name='rag_finance_report_0_benchmark_caibao_1000_source_title'
    # )
    # print(tool.run('能否根据2020年金宇生物技术股份有限公司的年报,给我简要介绍一下报告期内公司的社会责任工作情况?'))
```
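As the commented-out block at the end of the file shows, the tool can also bootstrap its own Milvus and Elasticsearch clients plus the LLM from the bundled `config/baseline_v2.yaml` when only a collection name is supplied. A hedged sketch of that path follows (the collection name is taken from the example above; the English description string is an illustrative stand-in):

```python
# Hedged sketch mirroring the commented-out example above: with only
# collection_name given, BishengRAGTool builds the vector store, keyword
# store, and LLM itself from the bundled config/baseline_v2.yaml.
from bisheng_langchain.rag.bisheng_rag_tool import BishengRAGTool

tool = BishengRAGTool.get_rag_tool(
    name='rag_knowledge_retrieve',
    description='Q&A over the financial annual-report knowledge base',
    collection_name='rag_finance_report_0_benchmark_caibao_1000_source_title',
)
# The tool exposes the standard LangChain BaseTool interface, so it can be
# passed to an agent or invoked directly:
print(tool.run('能否根据2020年金宇生物技术股份有限公司的年报,给我简要介绍一下报告期内公司的社会责任工作情况?'))
```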