bisheng-langchain 0.3.0rc0__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58)
  1. bisheng_langchain/chat_models/host_llm.py +1 -1
  2. bisheng_langchain/document_loaders/elem_unstrcutured_loader.py +5 -3
  3. bisheng_langchain/gpts/agent_types/llm_functions_agent.py +7 -1
  4. bisheng_langchain/gpts/assistant.py +8 -5
  5. bisheng_langchain/gpts/auto_optimization.py +28 -27
  6. bisheng_langchain/gpts/auto_tool_selected.py +14 -15
  7. bisheng_langchain/gpts/load_tools.py +53 -1
  8. bisheng_langchain/gpts/prompts/__init__.py +4 -2
  9. bisheng_langchain/gpts/prompts/assistant_prompt_base.py +1 -0
  10. bisheng_langchain/gpts/prompts/assistant_prompt_cohere.py +19 -0
  11. bisheng_langchain/gpts/prompts/opening_dialog_prompt.py +1 -1
  12. bisheng_langchain/gpts/tools/api_tools/__init__.py +1 -1
  13. bisheng_langchain/gpts/tools/api_tools/base.py +3 -3
  14. bisheng_langchain/gpts/tools/api_tools/flow.py +19 -7
  15. bisheng_langchain/gpts/tools/api_tools/macro_data.py +175 -4
  16. bisheng_langchain/gpts/tools/api_tools/openapi.py +101 -0
  17. bisheng_langchain/gpts/tools/api_tools/sina.py +2 -2
  18. bisheng_langchain/gpts/tools/code_interpreter/tool.py +118 -39
  19. bisheng_langchain/rag/__init__.py +5 -0
  20. bisheng_langchain/rag/bisheng_rag_pipeline.py +320 -0
  21. bisheng_langchain/rag/bisheng_rag_pipeline_v2.py +359 -0
  22. bisheng_langchain/rag/bisheng_rag_pipeline_v2_cohere_raw_prompting.py +376 -0
  23. bisheng_langchain/rag/bisheng_rag_tool.py +288 -0
  24. bisheng_langchain/rag/config/baseline.yaml +86 -0
  25. bisheng_langchain/rag/config/baseline_caibao.yaml +82 -0
  26. bisheng_langchain/rag/config/baseline_caibao_knowledge_v2.yaml +110 -0
  27. bisheng_langchain/rag/config/baseline_caibao_v2.yaml +112 -0
  28. bisheng_langchain/rag/config/baseline_demo_v2.yaml +92 -0
  29. bisheng_langchain/rag/config/baseline_s2b_mix.yaml +88 -0
  30. bisheng_langchain/rag/config/baseline_v2.yaml +90 -0
  31. bisheng_langchain/rag/extract_info.py +38 -0
  32. bisheng_langchain/rag/init_retrievers/__init__.py +4 -0
  33. bisheng_langchain/rag/init_retrievers/baseline_vector_retriever.py +61 -0
  34. bisheng_langchain/rag/init_retrievers/keyword_retriever.py +65 -0
  35. bisheng_langchain/rag/init_retrievers/mix_retriever.py +103 -0
  36. bisheng_langchain/rag/init_retrievers/smaller_chunks_retriever.py +92 -0
  37. bisheng_langchain/rag/prompts/__init__.py +9 -0
  38. bisheng_langchain/rag/prompts/extract_key_prompt.py +34 -0
  39. bisheng_langchain/rag/prompts/prompt.py +47 -0
  40. bisheng_langchain/rag/prompts/prompt_cohere.py +111 -0
  41. bisheng_langchain/rag/qa_corpus/__init__.py +0 -0
  42. bisheng_langchain/rag/qa_corpus/qa_generator.py +143 -0
  43. bisheng_langchain/rag/rerank/__init__.py +5 -0
  44. bisheng_langchain/rag/rerank/rerank.py +48 -0
  45. bisheng_langchain/rag/rerank/rerank_benchmark.py +139 -0
  46. bisheng_langchain/rag/run_qa_gen_web.py +47 -0
  47. bisheng_langchain/rag/run_rag_evaluate_web.py +55 -0
  48. bisheng_langchain/rag/scoring/__init__.py +0 -0
  49. bisheng_langchain/rag/scoring/llama_index_score.py +91 -0
  50. bisheng_langchain/rag/scoring/ragas_score.py +183 -0
  51. bisheng_langchain/rag/utils.py +181 -0
  52. bisheng_langchain/retrievers/ensemble.py +2 -1
  53. bisheng_langchain/vectorstores/elastic_keywords_search.py +2 -1
  54. {bisheng_langchain-0.3.0rc0.dist-info → bisheng_langchain-0.3.1.dist-info}/METADATA +1 -1
  55. {bisheng_langchain-0.3.0rc0.dist-info → bisheng_langchain-0.3.1.dist-info}/RECORD +57 -22
  56. bisheng_langchain/gpts/prompts/base_prompt.py +0 -1
  57. {bisheng_langchain-0.3.0rc0.dist-info → bisheng_langchain-0.3.1.dist-info}/WHEEL +0 -0
  58. {bisheng_langchain-0.3.0rc0.dist-info → bisheng_langchain-0.3.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,376 @@
1
+ import argparse
2
+ import copy
3
+ import time
4
+ import inspect
5
+ import os
6
+ import httpx
7
+ import pandas as pd
8
+ import yaml
9
+ import math
10
+ from tqdm import tqdm
11
+ from loguru import logger
12
+ from collections import defaultdict
13
+ from bisheng_langchain.retrievers import EnsembleRetriever
14
+ from bisheng_langchain.vectorstores import ElasticKeywordsSearch, Milvus
15
+ from langchain.docstore.document import Document
16
+ from langchain.chains.question_answering import load_qa_chain
17
+ from langchain_community.chat_models.cohere import ChatCohere
18
+ from bisheng_langchain.rag.init_retrievers import (
19
+ BaselineVectorRetriever,
20
+ KeywordRetriever,
21
+ MixRetriever,
22
+ SmallerChunksVectorRetriever,
23
+ )
24
+ from bisheng_langchain.rag.scoring.ragas_score import RagScore
25
+ from bisheng_langchain.rag.extract_info import extract_title
26
+ from bisheng_langchain.rag.utils import import_by_type, import_class
27
+
28
+
29
+ class BishengRagPipeline:
30
+
31
def __init__(self, yaml_path) -> None:
    """Load the YAML configuration and build every component it describes.

    Args:
        yaml_path: Path of the YAML file parsed with ``yaml.safe_load``. The
            sections read here are ``data``, ``embedding``, ``chat_llm``,
            ``milvus``, ``elasticsearch`` and ``retriever``.

    Raises:
        KeyError: If a required configuration key is missing.
        ValueError: If a retriever entry names an unknown retriever type
            (raised by ``_post_init_retriever``).
    """
    self.yaml_path = yaml_path
    # Decode explicitly as UTF-8 so parsing does not depend on the platform's
    # default encoding.
    with open(self.yaml_path, 'r', encoding='utf-8') as f:
        self.params = yaml.safe_load(f)

    # init data
    self.origin_file_path = self.params['data']['origin_file_path']
    self.question_path = self.params['data']['question']
    self.save_answer_path = self.params['data']['save_answer']

    # init embeddings
    embedding_params = self.params['embedding']
    embedding_type = embedding_params['type']
    embedding_object = import_by_type(_type='embeddings', name=embedding_type)
    # 'type' only selects the class; it must not reach the constructor.
    embedding_params.pop('type')
    # .get(): an OpenAI configuration without an 'openai_proxy' key simply
    # means "no proxy" instead of failing with KeyError.
    embedding_proxy = embedding_params.get('openai_proxy')
    if embedding_type == 'OpenAIEmbeddings' and embedding_proxy:
        self.embeddings = embedding_object(
            http_client=httpx.Client(proxies=embedding_proxy), **embedding_params
        )
    else:
        self.embeddings = embedding_object(**embedding_params)

    # init llm
    llm_params = self.params['chat_llm']
    llm_type = llm_params['type']
    llm_object = import_by_type(_type='llms', name=llm_type)
    llm_params.pop('type')
    llm_proxy = llm_params.get('openai_proxy')
    if llm_type == 'ChatOpenAI' and llm_proxy:
        self.llm = llm_object(http_client=httpx.Client(proxies=llm_proxy), **llm_params)
    else:
        self.llm = llm_object(**llm_params)

    # milvus
    self.vector_store = Milvus(
        embedding_function=self.embeddings,
        connection_args={
            "host": self.params['milvus']['host'],
            "port": self.params['milvus']['port'],
        },
    )

    # es
    self.keyword_store = ElasticKeywordsSearch(
        index_name='default_es',
        elasticsearch_url=self.params['elasticsearch']['url'],
        ssl_verify=self.params['elasticsearch']['ssl_verify'],
    )

    # init retriever: one retriever per configured entry, combined in an ensemble
    retriever_list = []
    for retriever in self.params['retriever']['retrievers']:
        retriever_type = retriever.pop('type')
        retriever_list.append(
            self._post_init_retriever(
                retriever_type=retriever_type,
                vector_store=self.vector_store,
                keyword_store=self.keyword_store,
                splitter_kwargs=retriever['splitter'],
                retrieval_kwargs=retriever['retrieval'],
            )
        )
    self.retriever = EnsembleRetriever(retrievers=retriever_list)
92
+
93
def _post_init_retriever(self, retriever_type, **kwargs):
    """Instantiate the retriever class registered under ``retriever_type``.

    Args:
        retriever_type: One of ``KeywordRetriever``, ``BaselineVectorRetriever``,
            ``MixRetriever`` or ``SmallerChunksVectorRetriever``.
        **kwargs: Must contain ``splitter_kwargs`` (argument name -> splitter
            configuration with a ``type`` entry), ``retrieval_kwargs`` (passed
            through unchanged), ``vector_store`` and ``keyword_store``.

    Returns:
        The constructed retriever instance.

    Raises:
        ValueError: If ``retriever_type`` is not a registered name.
    """
    known_retrievers = {
        'KeywordRetriever': KeywordRetriever,
        'BaselineVectorRetriever': BaselineVectorRetriever,
        'MixRetriever': MixRetriever,
        'SmallerChunksVectorRetriever': SmallerChunksVectorRetriever,
    }
    chosen_class = known_retrievers.get(retriever_type)
    if chosen_class is None:
        raise ValueError(f'Unknown retriever type: {retriever_type}')

    init_args = {}
    # Each splitter configuration names its class under 'type'; the remaining
    # entries become that splitter's constructor arguments.
    for arg_name, splitter_cfg in kwargs.pop('splitter_kwargs').items():
        splitter_class = import_by_type(_type='textsplitters', name=splitter_cfg.pop('type'))
        init_args[arg_name] = splitter_class(**splitter_cfg)

    # Retrieval settings are forwarded as-is.
    init_args.update(kwargs.pop('retrieval_kwargs').items())

    init_args['vector_store'] = kwargs.pop('vector_store')
    init_args['keyword_store'] = kwargs.pop('keyword_store')
    return chosen_class(**init_args)
118
+
119
def file2knowledge(self):
    """Load every file referenced by the question sheet and index it.

    The Excel file at ``self.question_path`` must contain the columns
    ``文件名`` (file name) and ``知识库名`` (knowledge-base name). File names
    are resolved against ``self.origin_file_path``; a name may point at a
    single file or at a directory, in which case every non-hidden entry of
    the directory is loaded. The documents of each knowledge base are added
    to every retriever of the ensemble under the collection name
    ``<knowledge-base name>_<retriever.suffix>``.

    Raises:
        Exception: If a required column is missing or a referenced path does
            not exist.
    """
    df = pd.read_excel(self.question_path)
    if ('文件名' not in df.columns) or ('知识库名' not in df.columns):
        raise Exception(f'文件名 or 知识库名 not in {self.question_path}.')

    # Work on a copy: popping 'type' from the shared configuration would make
    # a second call (or a later load_documents call) fail with KeyError.
    loader_params = copy.deepcopy(self.params['loader'])
    loader_object = import_by_type(_type='documentloaders', name=loader_params.pop('type'))
    add_aux_info = self.params['retriever'].get('add_aux_info', False)

    all_questions_info = df.to_dict('records')
    collectionname2filename = defaultdict(set)
    for info in all_questions_info:
        # A set removes repeated file names within one knowledge base.
        collectionname2filename[info['知识库名']].add(info['文件名'])

    for collection_name in tqdm(collectionname2filename):
        all_file_paths = []
        for file_name in collectionname2filename[collection_name]:
            file_path = os.path.join(self.origin_file_path, file_name)
            if not os.path.exists(file_path):
                raise Exception(f'{file_path} not exists.')
            # file_path may be a directory or a single file.
            if os.path.isdir(file_path):
                all_file_paths.extend(
                    [os.path.join(file_path, name) for name in os.listdir(file_path) if not name.startswith('.')]
                )
            else:
                all_file_paths.append(file_path)

        # Store every file of the current knowledge base.
        target_collection = f"{collection_name}_{self.params['retriever']['suffix']}"
        # drop_old is honoured only for the first upload into a collection;
        # later files must not wipe what was just stored.
        first_upload = True
        for each_file_path in all_file_paths:
            logger.info(f'each_file_path: {each_file_path}')
            loader = loader_object(
                file_name=os.path.basename(each_file_path), file_path=each_file_path, **loader_params
            )
            documents = loader.load()
            if not documents:
                # Nothing to index; avoids an IndexError on documents[0] below.
                logger.error(f'{each_file_path} produced no documents, skipped.')
                continue

            logger.info(f'documents: {len(documents)}, page_content: {len(documents[0].page_content)}')
            if len(documents[0].page_content) == 0:
                logger.error(f'{each_file_path} page_content is empty.')

            # add aux info: attach an LLM-extracted title to every document
            if add_aux_info:
                for doc in documents:
                    try:
                        title = extract_title(llm=self.llm, text=doc.page_content)
                        logger.info(f'extract title: {title}')
                    except Exception as e:
                        # Best effort: a failed extraction leaves the title empty.
                        logger.error(f"Failed to extract title: {e}")
                        title = ''
                    doc.metadata['title'] = title

            # NOTE(review): the milvus drop_old flag is handed to every
            # retriever, keyword-based ones included; the elasticsearch
            # drop_old setting is never consulted -- confirm this is intended.
            vector_drop_old = self.params['milvus']['drop_old'] if first_upload else False
            first_upload = False
            for retriever in self.retriever.retrievers:
                retriever.add_documents(documents, target_collection, vector_drop_old, add_aux_info=add_aux_info)
187
+
188
def retrieval_and_rerank(self, question, collection_name, max_content=100000):
    """Retrieve documents for ``question`` and post-process them.

    Steps, each controlled by the ``post_retrieval`` configuration section:
    optional removal of documents with identical ``page_content``, optional
    reranking, truncation so the summed ``page_content`` length stays within
    ``max_content``, and optional ordering by ``source`` / ``chunk_index``
    metadata.

    Args:
        question: The query text.
        collection_name: Knowledge-base name; the configured retriever suffix
            is appended before querying.
        max_content: Upper bound on the total number of characters of the
            returned documents.

    Returns:
        The list of selected documents.
    """
    collection_name = f"{collection_name}_{self.params['retriever']['suffix']}"
    # Original author's note: retrieving through EnsembleRetriever already
    # de-duplicates by default.
    docs = self.retriever.get_relevant_documents(query=question, collection_name=collection_name)
    logger.info(f'retrieval docs origin: {len(docs)}')

    post_params = self.params['post_retrieval']

    # delete duplicate: keep the first document for each distinct content
    if post_params['delete_duplicate']:
        logger.info(f'origin docs: {len(docs)}')
        seen_contents = set()  # set membership keeps this pass linear
        docs_no_dup = []
        for doc in docs:
            if doc.page_content in seen_contents:
                continue
            seen_contents.add(doc.page_content)
            docs_no_dup.append(doc)
        docs = docs_no_dup
        logger.info(f'delete duplicate docs: {len(docs)}')

    # rerank
    if post_params['with_rank'] and len(docs):
        if not hasattr(self, 'ranker'):
            # The ranker is created on first use and cached on the instance.
            # Copy the settings so the shared configuration keeps its 'type'.
            rerank_params = dict(post_params['rerank'])
            rerank_type = rerank_params.pop('type')
            rerank_object = import_class(f'bisheng_langchain.rag.rerank.{rerank_type}')
            self.ranker = rerank_object(**rerank_params)
        docs = self.ranker.sort_and_filter(question, docs)

    # delete redundancy according to max_content: keep the longest prefix
    # whose cumulative length does not exceed the budget
    doc_num, doc_content_sum = 0, 0
    for doc in docs:
        doc_content_sum += len(doc.page_content)
        if doc_content_sum > max_content:
            break
        doc_num += 1
    docs = docs[:doc_num]
    logger.info(f'retrieval docs after delete redundancy: {len(docs)}')

    # Order by source and chunk_index to keep the context coherent.
    if post_params.get('sort_by_source_and_index', False):
        logger.info('sort chunks by source and chunk_index')
        docs = sorted(docs, key=lambda x: (x.metadata['source'], x.metadata['chunk_index']))
    return docs
235
+
236
def load_documents(self, file_name, max_content=100000):
    """Load one file directly, truncating over-long content.

    Args:
        file_name: Name of the file, resolved against ``self.origin_file_path``.
        max_content: Maximum number of characters kept from each document's
            ``page_content`` (the content length the LLM can take).

    Returns:
        The documents produced by the configured loader, each truncated to
        ``max_content`` characters.

    Raises:
        Exception: If the path does not exist or is a directory.
    """
    file_path = os.path.join(self.origin_file_path, file_name)
    if not os.path.exists(file_path):
        raise Exception(f'{file_path} not exists.')
    if os.path.isdir(file_path):
        raise Exception(f'{file_path} is a directory.')

    # Deep copy so that removing 'type' leaves the shared configuration intact.
    loader_settings = copy.deepcopy(self.params['loader'])
    loader_type = loader_settings.pop('type')
    loader_class = import_by_type(_type='documentloaders', name=loader_type)
    documents = loader_class(file_name=file_name, file_path=file_path, **loader_settings).load()

    logger.info(f'documents: {len(documents)}, page_content: {len(documents[0].page_content)}')
    for loaded_doc in documents:
        loaded_doc.page_content = loaded_doc.page_content[:max_content]
    return documents
256
+
257
def question_answering(self):
    """Answer every question of the question sheet and save the answers.

    Rows are read from the Excel file at ``self.question_path``; each row
    supplies the columns ``问题`` (question), ``文件名`` (file name) and
    ``知识库名`` (knowledge-base name). The context is either retrieved
    (``generate.with_retrieval``) or loaded straight from the file. For a
    ``ChatCohere`` model the prompt is formatted by hand and sent with
    ``raw_prompting=True``; any other model goes through ``load_qa_chain``.
    The answer, or the text of the exception raised while answering, is
    stored under ``rag_answer`` and the table is written to
    ``self.save_answer_path``.
    """
    all_questions_info = pd.read_excel(self.question_path).to_dict('records')
    generate_cfg = self.params['generate']

    prompt = None
    if 'prompt_type' in generate_cfg:
        prompt_type = generate_cfg['prompt_type']
        prompt = import_class(f'bisheng_langchain.rag.prompts.{prompt_type}')

    # The chain is only built (and only used) for non-Cohere models.
    if not isinstance(self.llm, ChatCohere):
        qa_chain = load_qa_chain(
            llm=self.llm, chain_type=generate_cfg['chain_type'], prompt=prompt, verbose=False
        )

    # Documents loaded without retrieval are cached per file name.
    file2docs = {}
    for questions_info in tqdm(all_questions_info):
        question = questions_info['问题']
        file_name = questions_info['文件名']
        collection_name = questions_info['知识库名']

        if generate_cfg['with_retrieval']:
            # retrieval and rerank
            docs = self.retrieval_and_rerank(question, collection_name, max_content=generate_cfg['max_content'])
        else:
            # load document
            if file_name not in file2docs:
                file2docs[file_name] = self.load_documents(file_name, max_content=generate_cfg['max_content'])
            docs = file2docs[file_name]

        if isinstance(self.llm, ChatCohere):
            try:
                # cohere rag by raw prompt: every document text followed by a blank line
                documents = ''.join(doc.page_content + '\n\n' for doc in docs)
                messages = prompt.format_prompt(question=question, documents=documents).to_messages()
                reply = self.llm.invoke(messages, raw_prompting=True)
                rag_answer = reply.content.replace('Answer: ', '')
            except Exception as e:
                logger.error(f'question: {question}\nerror: {e}')
                rag_answer = str(e)
        else:
            # question answer
            try:
                ans = qa_chain({"input_documents": docs, "question": question}, return_only_outputs=True)
            except Exception as e:
                logger.error(f'question: {question}\nerror: {e}')
                ans = {'output_text': str(e)}
            rag_answer = ans['output_text']

        logger.info(f'question: {question}\nans: {rag_answer}\n')
        questions_info['rag_answer'] = rag_answer

    pd.DataFrame(all_questions_info).to_excel(self.save_answer_path, index=False)
334
+
335
def score(self):
    """Score the saved answers with the configured metric.

    Only the ``bisheng-ragas`` metric type is handled: it runs ``RagScore``
    on the answer file, saving results into the answer file's directory.
    Any other metric type is currently a no-op.
    """
    metric_params = self.params['metric']
    if metric_params['type'] != 'bisheng-ragas':
        # todo: other scoring methods
        return

    scorer = RagScore(
        excel_path=self.save_answer_path,
        save_path=os.path.dirname(self.save_answer_path),
        question_column=metric_params['question_column'],
        gt_column=metric_params['gt_column'],
        answer_column=metric_params['answer_column'],
        query_type_column=metric_params.get('query_type_column', None),
        contexts_column=metric_params.get('contexts_column', None),
        metrics=metric_params['metrics'],
        batch_size=metric_params['batch_size'],
        gt_split_column=metric_params.get('gt_split_column', None),
        # whether the model should split the ground truth into key points
        whether_gtsplit=metric_params.get('whether_gtsplit', False),
    )
    scorer.score()
359
+
360
+
361
if __name__ == '__main__':
    # Command-line entry point: build the pipeline from a YAML file and run
    # one of its stages.
    parser = argparse.ArgumentParser(description='Run one stage of the Bisheng RAG pipeline.')
    # `choices` rejects a mistyped mode instead of silently doing nothing.
    parser.add_argument(
        '--mode', type=str, default='qa', choices=('upload', 'qa', 'score'), help='upload or qa or score'
    )
    parser.add_argument('--params', type=str, default='config/test/baseline_s2b.yaml', help='bisheng rag params')
    args = parser.parse_args()

    rag = BishengRagPipeline(args.params)

    if args.mode == 'upload':
        rag.file2knowledge()
    elif args.mode == 'qa':
        rag.question_answering()
    elif args.mode == 'score':
        rag.score()
@@ -0,0 +1,288 @@
1
+ import time
2
+ import os
3
+ import yaml
4
+ import httpx
5
+ from typing import Any, Dict, Tuple, Type, Union, Optional
6
+ from loguru import logger
7
+ from langchain_core.tools import BaseTool, Tool
8
+ from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator
9
+ from langchain_core.language_models.base import LanguageModelLike
10
+ from langchain.chains.question_answering import load_qa_chain
11
+ from bisheng_langchain.retrievers import EnsembleRetriever
12
+ from bisheng_langchain.vectorstores import ElasticKeywordsSearch, Milvus
13
+ from bisheng_langchain.rag.init_retrievers import (
14
+ BaselineVectorRetriever,
15
+ KeywordRetriever,
16
+ MixRetriever,
17
+ SmallerChunksVectorRetriever,
18
+ )
19
+ from bisheng_langchain.rag.utils import import_by_type, import_class
20
+ from bisheng_langchain.rag.extract_info import extract_title
21
+
22
+
23
+ class MultArgsSchemaTool(Tool):
24
+
25
def _to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]:
    """Split ``tool_input`` into positional and keyword arguments.

    A string is passed as a single positional argument (kept for backwards
    compatibility); anything else is returned unchanged as the keyword
    arguments.
    """
    if not isinstance(tool_input, str):
        return (), tool_input
    return (tool_input, ), {}
32
+
33
+
34
+ class BishengRAGTool:
35
+
36
def __init__(
    self,
    vector_store: Optional[Milvus] = None,
    keyword_store: Optional[ElasticKeywordsSearch] = None,
    llm: Optional[LanguageModelLike] = None,
    collection_name: Optional[str] = None,
    **kwargs
) -> None:
    """Build the RAG back end from the bundled ``config/baseline_v2.yaml``.

    Args:
        vector_store: Milvus store to use; when omitted one is created from
            the ``embedding`` and ``milvus`` configuration sections.
        keyword_store: Elasticsearch keyword store to use; when omitted one
            is created from the ``elasticsearch`` configuration section.
        llm: Language model to use; when omitted one is created from the
            ``chat_llm`` configuration section.
        collection_name: Collection passed to the retrievers. Required when
            either store is not supplied.
        **kwargs: ``max_content`` (default 15000) overrides
            ``generate.max_content``; ``sort_by_source_and_index`` (default
            True) overrides ``post_retrieval.sort_by_source_and_index``.

    Raises:
        ValueError: If ``collection_name`` is missing while a store is not
            supplied, or a retriever entry names an unknown retriever type.
    """
    if collection_name is None and (keyword_store is None or vector_store is None):
        raise ValueError('collection_name must be provided if keyword_store or vector_store is not provided')
    self.collection_name = collection_name

    yaml_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'config/baseline_v2.yaml')
    with open(yaml_path, 'r', encoding='utf-8') as f:
        self.params = yaml.safe_load(f)

    # update params
    self.params['generate']['max_content'] = kwargs.get("max_content", 15000)
    self.params['post_retrieval']['sort_by_source_and_index'] = kwargs.get("sort_by_source_and_index", True)

    # init milvus (explicit None checks, matching the guard above)
    if vector_store is not None:
        self.vector_store = vector_store
    else:
        # init embeddings
        embedding_params = self.params['embedding']
        embedding_type = embedding_params['type']
        embedding_object = import_by_type(_type='embeddings', name=embedding_type)
        # 'type' only selects the class; it must not reach the constructor.
        embedding_params.pop('type')
        # .get(): a configuration without 'openai_proxy' means "no proxy"
        # instead of failing with KeyError.
        embedding_proxy = embedding_params.get('openai_proxy')
        if embedding_type == 'OpenAIEmbeddings' and embedding_proxy:
            self.embeddings = embedding_object(
                http_client=httpx.Client(proxies=embedding_proxy), **embedding_params
            )
        else:
            self.embeddings = embedding_object(**embedding_params)

        self.vector_store = Milvus(
            embedding_function=self.embeddings,
            connection_args={
                "host": self.params['milvus']['host'],
                "port": self.params['milvus']['port'],
            },
        )

    # init keyword store
    if keyword_store is not None:
        self.keyword_store = keyword_store
    else:
        self.keyword_store = ElasticKeywordsSearch(
            index_name='default_es',
            elasticsearch_url=self.params['elasticsearch']['url'],
            ssl_verify=self.params['elasticsearch']['ssl_verify'],
        )

    # init llm
    if llm is not None:
        self.llm = llm
    else:
        llm_params = self.params['chat_llm']
        llm_type = llm_params['type']
        llm_object = import_by_type(_type='llms', name=llm_type)
        llm_params.pop('type')
        llm_proxy = llm_params.get('openai_proxy')
        if llm_type == 'ChatOpenAI' and llm_proxy:
            self.llm = llm_object(http_client=httpx.Client(proxies=llm_proxy), **llm_params)
        else:
            self.llm = llm_object(**llm_params)

    # init retriever: one retriever per configured entry, combined in an ensemble
    retriever_list = []
    for retriever in self.params['retriever']['retrievers']:
        retriever_type = retriever.pop('type')
        retriever_list.append(
            self._post_init_retriever(
                retriever_type=retriever_type,
                vector_store=self.vector_store,
                keyword_store=self.keyword_store,
                splitter_kwargs=retriever['splitter'],
                retrieval_kwargs=retriever['retrieval'],
            )
        )
    self.retriever = EnsembleRetriever(retrievers=retriever_list)

    # init qa chain
    if 'prompt_type' in self.params['generate']:
        prompt_type = self.params['generate']['prompt_type']
        prompt = import_class(f'bisheng_langchain.rag.prompts.{prompt_type}')
    else:
        prompt = None
    self.qa_chain = load_qa_chain(
        llm=self.llm,
        chain_type=self.params['generate']['chain_type'],
        prompt=prompt,
        verbose=False
    )
131
+
132
def _post_init_retriever(self, retriever_type, **kwargs):
    """Create the retriever registered under ``retriever_type``.

    Args:
        retriever_type: ``KeywordRetriever``, ``BaselineVectorRetriever``,
            ``MixRetriever`` or ``SmallerChunksVectorRetriever``.
        **kwargs: ``splitter_kwargs`` (argument name -> splitter configuration
            holding a ``type`` entry), ``retrieval_kwargs`` (forwarded
            unchanged), ``vector_store`` and ``keyword_store``.

    Returns:
        The constructed retriever.

    Raises:
        ValueError: If ``retriever_type`` is not one of the names above.
    """
    registry = {
        'KeywordRetriever': KeywordRetriever,
        'BaselineVectorRetriever': BaselineVectorRetriever,
        'MixRetriever': MixRetriever,
        'SmallerChunksVectorRetriever': SmallerChunksVectorRetriever,
    }
    if retriever_type not in registry:
        raise ValueError(f'Unknown retriever type: {retriever_type}')

    splitter_settings = kwargs.pop('splitter_kwargs')
    # 'type' picks the splitter class; what is left are its constructor arguments.
    constructor_args = {}
    for slot, settings in splitter_settings.items():
        splitter_factory = import_by_type(_type='textsplitters', name=settings.pop('type'))
        constructor_args[slot] = splitter_factory(**settings)

    retrieval_settings = kwargs.pop('retrieval_kwargs')
    for option, option_value in retrieval_settings.items():
        constructor_args[option] = option_value

    constructor_args['vector_store'] = kwargs.pop('vector_store')
    constructor_args['keyword_store'] = kwargs.pop('keyword_store')

    return registry[retriever_type](**constructor_args)
157
+
158
def file2knowledge(self, file_path, drop_old=True):
    """Load one file and add its documents to every retriever.

    Args:
        file_path: Path of the file to load with the configured loader.
        drop_old: Forwarded to each retriever's ``add_documents``.
    """
    # Copy before popping 'type': removing it from the shared configuration
    # would make every later call fail with KeyError. A shallow copy is
    # enough because only the top-level key is removed.
    loader_params = dict(self.params['loader'])
    loader_object = import_by_type(_type='documentloaders', name=loader_params.pop('type'))

    logger.info(f'file_path: {file_path}')
    loader = loader_object(
        file_name=os.path.basename(file_path), file_path=file_path, **loader_params
    )
    documents = loader.load()
    if not documents:
        # Nothing to index; avoids an IndexError on documents[0] below.
        logger.error(f'{file_path} produced no documents, nothing to store.')
        return

    logger.info(f'documents: {len(documents)}, page_content: {len(documents[0].page_content)}')
    if len(documents[0].page_content) == 0:
        logger.error(f'{file_path} page_content is empty.')

    # add aux info: attach an LLM-extracted title to every document
    add_aux_info = self.params['retriever'].get('add_aux_info', False)
    if add_aux_info:
        for doc in documents:
            try:
                title = extract_title(llm=self.llm, text=doc.page_content)
                logger.info(f'extract title: {title}')
            except Exception as e:
                # Best effort: a failed extraction leaves the title empty.
                logger.error(f"Failed to extract title: {e}")
                title = ''
            doc.metadata['title'] = title

    for retriever in self.retriever.retrievers:
        retriever.add_documents(
            documents,
            self.collection_name,
            drop_old=drop_old,
            add_aux_info=add_aux_info
        )
193
+
194
def retrieval_and_rerank(self, query):
    """Retrieve documents for ``query`` and trim them to the content budget.

    The retrieved list is cut to the longest prefix whose summed
    ``page_content`` length does not exceed ``generate.max_content`` and is
    then, if ``post_retrieval.sort_by_source_and_index`` is set, ordered by
    the ``source`` and ``chunk_index`` metadata. No reranking step is applied
    in this method.

    Returns:
        The list of selected documents.
    """
    # Original author's note: retrieving through EnsembleRetriever already
    # de-duplicates by default.
    docs = self.retriever.get_relevant_documents(
        query=query,
        collection_name=self.collection_name
    )
    logger.info(f'retrieval docs origin: {len(docs)}')

    # delete redundancy according to max_content
    keep_count = len(docs)
    running_length = 0
    for position, doc in enumerate(docs):
        running_length += len(doc.page_content)
        if running_length > self.params['generate']['max_content']:
            # This document would overflow the budget: keep only those before it.
            keep_count = position
            break
    docs = docs[:keep_count]
    logger.info(f'retrieval docs after delete redundancy: {len(docs)}')

    # Order by source and chunk_index to keep the context coherent.
    if self.params['post_retrieval'].get('sort_by_source_and_index', False):
        logger.info('sort chunks by source and chunk_index')
        docs = sorted(docs, key=lambda x: (x.metadata['source'], x.metadata['chunk_index']))
    return docs
220
+
221
def run(self, query) -> str:
    """Answer ``query`` from the retrieved documents.

    Returns:
        The chain's ``output_text``; if the chain raises, the error is logged
        and the exception text is returned instead.
    """
    docs = self.retrieval_and_rerank(query)
    chain_inputs = {"input_documents": docs, "question": query}
    try:
        ans = self.qa_chain(chain_inputs, return_only_outputs=True)
    except Exception as e:
        logger.error(f'question: {query}\nerror: {e}')
        ans = {'output_text': str(e)}
    return ans['output_text']
230
+
231
async def arun(self, query: str) -> str:
    """Coroutine wrapper around ``run``; the call itself is made synchronously."""
    return self.run(query)
234
+
235
@classmethod
def get_rag_tool(cls, name, description, **kwargs: Any) -> BaseTool:
    """Wrap a RAG back end as a tool taking a single ``query`` argument.

    Args:
        name: Tool name.
        description: Tool description.
        **kwargs: Forwarded to the class constructor.

    Returns:
        A ``MultArgsSchemaTool`` whose synchronous and asynchronous entry
        points are the ``run`` and ``arun`` methods of one shared instance.
    """
    class InputArgs(BaseModel):
        query: str = Field(description='question asked by the user.')

    # Build the back end once. Constructing it separately for `func` and
    # `coroutine` would load the configuration and create the stores, LLM and
    # QA chain twice.
    rag = cls(**kwargs)
    return MultArgsSchemaTool(name=name,
                              description=description,
                              func=rag.run,
                              coroutine=rag.arun,
                              args_schema=InputArgs)
245
+
246
+
247
if __name__ == '__main__':
    # Manual smoke test: build the tool on existing Milvus / Elasticsearch
    # stores and ask one question.
    #
    # The Elasticsearch password is read from the environment so that no
    # credential lives in the source. Hosts, ports and the user name can be
    # overridden the same way and otherwise keep their previous values.
    from langchain.chat_models import ChatOpenAI
    from langchain.embeddings import OpenAIEmbeddings

    milvus_host = os.environ.get('BISHENG_MILVUS_HOST', '110.16.193.170')
    milvus_port = os.environ.get('BISHENG_MILVUS_PORT', '50062')
    es_url = os.environ.get('BISHENG_ES_URL', 'http://110.16.193.170:50062/es')
    es_user = os.environ.get('BISHENG_ES_USER', 'elastic')
    es_password = os.environ.get('BISHENG_ES_PASSWORD')
    if not es_password:
        raise SystemExit('Set BISHENG_ES_PASSWORD before running this demo.')

    # embedding
    embeddings = OpenAIEmbeddings(model='text-embedding-ada-002')
    # llm
    llm = ChatOpenAI(model='gpt-4-1106-preview', temperature=0.01)
    collection_name = 'rag_finance_report_0_benchmark_caibao_1000_source_title'
    # milvus
    vector_store = Milvus(
        collection_name=collection_name,
        embedding_function=embeddings,
        connection_args={
            "host": milvus_host,
            "port": milvus_port,
        },
    )
    # es
    keyword_store = ElasticKeywordsSearch(
        index_name=collection_name,
        elasticsearch_url=es_url,
        ssl_verify={'basic_auth': [es_user, es_password]},
    )

    # Alternative: pass only `collection_name=...` to get_rag_tool to let the
    # tool create its stores and LLM from the bundled configuration.
    tool = BishengRAGTool.get_rag_tool(
        name='rag_knowledge_retrieve',
        description='金融年报财报知识库问答',
        vector_store=vector_store,
        keyword_store=keyword_store,
        llm=llm
    )
    print(tool.run('能否根据2020年金宇生物技术股份有限公司的年报,给我简要介绍一下报告期内公司的社会责任工作情况?'))