bisheng-langchain 0.3.0rc1__py3-none-any.whl → 0.3.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. bisheng_langchain/document_loaders/elem_unstrcutured_loader.py +5 -3
  2. bisheng_langchain/gpts/agent_types/llm_functions_agent.py +7 -1
  3. bisheng_langchain/gpts/assistant.py +8 -5
  4. bisheng_langchain/gpts/load_tools.py +2 -0
  5. bisheng_langchain/gpts/prompts/__init__.py +4 -2
  6. bisheng_langchain/gpts/prompts/assistant_prompt_base.py +1 -0
  7. bisheng_langchain/gpts/prompts/assistant_prompt_cohere.py +19 -0
  8. bisheng_langchain/gpts/tools/api_tools/flow.py +3 -3
  9. bisheng_langchain/gpts/tools/api_tools/openapi.py +101 -0
  10. bisheng_langchain/rag/__init__.py +5 -0
  11. bisheng_langchain/rag/bisheng_rag_pipeline.py +320 -0
  12. bisheng_langchain/rag/bisheng_rag_pipeline_v2.py +359 -0
  13. bisheng_langchain/rag/bisheng_rag_pipeline_v2_cohere_raw_prompting.py +376 -0
  14. bisheng_langchain/rag/bisheng_rag_tool.py +288 -0
  15. bisheng_langchain/rag/config/baseline.yaml +86 -0
  16. bisheng_langchain/rag/config/baseline_caibao.yaml +82 -0
  17. bisheng_langchain/rag/config/baseline_caibao_knowledge_v2.yaml +110 -0
  18. bisheng_langchain/rag/config/baseline_caibao_v2.yaml +112 -0
  19. bisheng_langchain/rag/config/baseline_demo_v2.yaml +92 -0
  20. bisheng_langchain/rag/config/baseline_s2b_mix.yaml +88 -0
  21. bisheng_langchain/rag/config/baseline_v2.yaml +90 -0
  22. bisheng_langchain/rag/extract_info.py +38 -0
  23. bisheng_langchain/rag/init_retrievers/__init__.py +4 -0
  24. bisheng_langchain/rag/init_retrievers/baseline_vector_retriever.py +61 -0
  25. bisheng_langchain/rag/init_retrievers/keyword_retriever.py +65 -0
  26. bisheng_langchain/rag/init_retrievers/mix_retriever.py +103 -0
  27. bisheng_langchain/rag/init_retrievers/smaller_chunks_retriever.py +92 -0
  28. bisheng_langchain/rag/prompts/__init__.py +9 -0
  29. bisheng_langchain/rag/prompts/extract_key_prompt.py +34 -0
  30. bisheng_langchain/rag/prompts/prompt.py +47 -0
  31. bisheng_langchain/rag/prompts/prompt_cohere.py +111 -0
  32. bisheng_langchain/rag/qa_corpus/__init__.py +0 -0
  33. bisheng_langchain/rag/qa_corpus/qa_generator.py +143 -0
  34. bisheng_langchain/rag/rerank/__init__.py +5 -0
  35. bisheng_langchain/rag/rerank/rerank.py +48 -0
  36. bisheng_langchain/rag/rerank/rerank_benchmark.py +139 -0
  37. bisheng_langchain/rag/run_qa_gen_web.py +47 -0
  38. bisheng_langchain/rag/run_rag_evaluate_web.py +55 -0
  39. bisheng_langchain/rag/scoring/__init__.py +0 -0
  40. bisheng_langchain/rag/scoring/llama_index_score.py +91 -0
  41. bisheng_langchain/rag/scoring/ragas_score.py +183 -0
  42. bisheng_langchain/rag/utils.py +181 -0
  43. bisheng_langchain/retrievers/ensemble.py +2 -1
  44. bisheng_langchain/vectorstores/elastic_keywords_search.py +2 -1
  45. {bisheng_langchain-0.3.0rc1.dist-info → bisheng_langchain-0.3.1.1.dist-info}/METADATA +1 -1
  46. {bisheng_langchain-0.3.0rc1.dist-info → bisheng_langchain-0.3.1.1.dist-info}/RECORD +48 -13
  47. bisheng_langchain/gpts/prompts/base_prompt.py +0 -1
  48. {bisheng_langchain-0.3.0rc1.dist-info → bisheng_langchain-0.3.1.1.dist-info}/WHEEL +0 -0
  49. {bisheng_langchain-0.3.0rc1.dist-info → bisheng_langchain-0.3.1.1.dist-info}/top_level.txt +0 -0
bisheng_langchain/rag/bisheng_rag_pipeline_v2.py
@@ -0,0 +1,359 @@
+ import argparse
+ import copy
+ import time
+ import inspect
+ import os
+ import httpx
+ import pandas as pd
+ import yaml
+ import math
+ from collections import defaultdict
+ from loguru import logger
+ from tqdm import tqdm
+ from langchain.docstore.document import Document
+ from langchain.chains.question_answering import load_qa_chain
+ from langchain.chains.llm import LLMChain
+ from bisheng_langchain.retrievers import EnsembleRetriever
+ from bisheng_langchain.vectorstores import ElasticKeywordsSearch, Milvus
+ from bisheng_langchain.rag.init_retrievers import (
+     BaselineVectorRetriever,
+     KeywordRetriever,
+     MixRetriever,
+     SmallerChunksVectorRetriever,
+ )
+ from bisheng_langchain.rag.scoring.ragas_score import RagScore
+ from bisheng_langchain.rag.extract_info import extract_title
+ from bisheng_langchain.rag.utils import import_by_type, import_class
+
+
+ class BishengRagPipeline:
+
+     def __init__(self, yaml_path) -> None:
+         self.yaml_path = yaml_path
+         with open(self.yaml_path, 'r') as f:
+             self.params = yaml.safe_load(f)
+
+         # init data
+         self.origin_file_path = self.params['data']['origin_file_path']
+         self.question_path = self.params['data']['question']
+         self.save_answer_path = self.params['data']['save_answer']
+
+         # init embeddings
+         embedding_params = self.params['embedding']
+         embedding_object = import_by_type(_type='embeddings', name=embedding_params['type'])
+         if embedding_params['type'] == 'OpenAIEmbeddings' and embedding_params['openai_proxy']:
+             embedding_params.pop('type')
+             self.embeddings = embedding_object(
+                 http_client=httpx.Client(proxies=embedding_params['openai_proxy']), **embedding_params
+             )
+         else:
+             embedding_params.pop('type')
+             self.embeddings = embedding_object(**embedding_params)
+
+         # init llm
+         llm_params = self.params['chat_llm']
+         llm_object = import_by_type(_type='llms', name=llm_params['type'])
+         if llm_params['type'] == 'ChatOpenAI' and llm_params['openai_proxy']:
+             llm_params.pop('type')
+             self.llm = llm_object(http_client=httpx.Client(proxies=llm_params['openai_proxy']), **llm_params)
+         else:
+             llm_params.pop('type')
+             self.llm = llm_object(**llm_params)
+
+         # milvus
+         self.vector_store = Milvus(
+             embedding_function=self.embeddings,
+             connection_args={
+                 "host": self.params['milvus']['host'],
+                 "port": self.params['milvus']['port'],
+             },
+         )
+
+         # es
+         if self.params['elasticsearch'].get('extract_key_by_llm', False):
+             extract_key_prompt = import_class('bisheng_langchain.rag.prompts.EXTRACT_KEY_PROMPT')
+             llm_chain = LLMChain(llm=self.llm, prompt=extract_key_prompt)
+         else:
+             llm_chain = None
+         self.keyword_store = ElasticKeywordsSearch(
+             index_name='default_es',
+             elasticsearch_url=self.params['elasticsearch']['url'],
+             ssl_verify=self.params['elasticsearch']['ssl_verify'],
+             llm_chain=llm_chain
+         )
+
+         # init retriever
+         retriever_list = []
+         retrievers = self.params['retriever']['retrievers']
+         for retriever in retrievers:
+             retriever_type = retriever.pop('type')
+             retriever_params = {
+                 'vector_store': self.vector_store,
+                 'keyword_store': self.keyword_store,
+                 'splitter_kwargs': retriever['splitter'],
+                 'retrieval_kwargs': retriever['retrieval'],
+             }
+             retriever_list.append(self._post_init_retriever(retriever_type=retriever_type, **retriever_params))
+         self.retriever = EnsembleRetriever(retrievers=retriever_list)
+
+     def _post_init_retriever(self, retriever_type, **kwargs):
+         retriever_classes = {
+             'KeywordRetriever': KeywordRetriever,
+             'BaselineVectorRetriever': BaselineVectorRetriever,
+             'MixRetriever': MixRetriever,
+             'SmallerChunksVectorRetriever': SmallerChunksVectorRetriever,
+         }
+         if retriever_type not in retriever_classes:
+             raise ValueError(f'Unknown retriever type: {retriever_type}')
+
+         input_kwargs = {}
+         splitter_params = kwargs.pop('splitter_kwargs')
+         for key, value in splitter_params.items():
+             splitter_obj = import_by_type(_type='textsplitters', name=value.pop('type'))
+             input_kwargs[key] = splitter_obj(**value)
+
+         retrieval_params = kwargs.pop('retrieval_kwargs')
+         for key, value in retrieval_params.items():
+             input_kwargs[key] = value
+
+         input_kwargs['vector_store'] = kwargs.pop('vector_store')
+         input_kwargs['keyword_store'] = kwargs.pop('keyword_store')
+
+         retriever_class = retriever_classes[retriever_type]
+         return retriever_class(**input_kwargs)
+
+     def file2knowledge(self):
+         """
+         file to knowledge
+         """
+         df = pd.read_excel(self.question_path)
+         if ('文件名' not in df.columns) or ('知识库名' not in df.columns):
+             raise Exception(f'文件名 or 知识库名 not in {self.question_path}.')
+
+         loader_params = self.params['loader']
+         loader_object = import_by_type(_type='documentloaders', name=loader_params.pop('type'))
+
+         all_questions_info = df.to_dict('records')
+         collectionname2filename = defaultdict(set)
+         for info in all_questions_info:
+             # store in a set to drop duplicate file names
+             collectionname2filename[info['知识库名']].add(info['文件名'])
+
+         for collection_name in tqdm(collectionname2filename):
+             all_file_paths = []
+             for file_name in collectionname2filename[collection_name]:
+                 file_path = os.path.join(self.origin_file_path, file_name)
+                 if not os.path.exists(file_path):
+                     raise Exception(f'{file_path} not exists.')
+                 # file_path may be a directory or a single file
+                 if os.path.isdir(file_path):
+                     # a directory containing multiple files
+                     all_file_paths.extend(
+                         [os.path.join(file_path, name) for name in os.listdir(file_path) if not name.startswith('.')]
+                     )
+                 else:
+                     # a single file
+                     all_file_paths.append(file_path)
+
+             # all files to store in the current knowledge base
+             collection_name = f"{collection_name}_{self.params['retriever']['suffix']}"
+             for index, each_file_path in enumerate(all_file_paths):
+                 logger.info(f'each_file_path: {each_file_path}')
+                 loader = loader_object(
+                     file_name=os.path.basename(each_file_path), file_path=each_file_path, **loader_params
+                 )
+                 documents = loader.load()
+
+                 # # load from text
+                 # if each_file_path.endswith('.pdf'):
+                 #     with open(each_file_path.replace('.pdf', '.txt'), 'r') as f:
+                 #         content = f.read()
+                 #     documents = [Document(page_content=content, metadata={'source': os.path.basename(each_file_path)})]
+
+                 logger.info(f'documents: {len(documents)}, page_content: {len(documents[0].page_content)}')
+                 if len(documents[0].page_content) == 0:
+                     logger.error(f'{each_file_path} page_content is empty.')
+
+                 # add aux info
+                 add_aux_info = self.params['retriever'].get('add_aux_info', False)
+                 if add_aux_info:
+                     for doc in documents:
+                         try:
+                             title = extract_title(llm=self.llm, text=doc.page_content)
+                             logger.info(f'extract title: {title}')
+                         except Exception as e:
+                             logger.error(f"Failed to extract title: {e}")
+                             title = ''
+                         doc.metadata['title'] = title
+
+                 vector_drop_old = self.params['milvus']['drop_old'] if index == 0 else False
+                 keyword_drop_old = self.params['elasticsearch']['drop_old'] if index == 0 else False
+                 for idx, retriever in enumerate(self.retriever.retrievers):
+                     retriever.add_documents(documents, collection_name, vector_drop_old, add_aux_info=add_aux_info)
+
+     def retrieval_and_rerank(self, question, collection_name, max_content=100000):
+         """
+         retrieval and rerank
+         """
+         collection_name = f"{collection_name}_{self.params['retriever']['suffix']}"
+         # EnsembleRetriever deduplicates the recalled documents by default
+         docs = self.retriever.get_relevant_documents(query=question, collection_name=collection_name)
+         logger.info(f'retrieval docs origin: {len(docs)}')
+
+         # delete duplicate
+         if self.params['post_retrieval']['delete_duplicate']:
+             logger.info(f'origin docs: {len(docs)}')
+             all_contents = []
+             docs_no_dup = []
+             for index, doc in enumerate(docs):
+                 doc_content = doc.page_content
+                 if doc_content in all_contents:
+                     continue
+                 all_contents.append(doc_content)
+                 docs_no_dup.append(doc)
+             docs = docs_no_dup
+             logger.info(f'delete duplicate docs: {len(docs)}')
+
+         # rerank
+         if self.params['post_retrieval']['with_rank'] and len(docs):
+             if not hasattr(self, 'ranker'):
+                 rerank_params = self.params['post_retrieval']['rerank']
+                 rerank_type = rerank_params.pop('type')
+                 rerank_object = import_class(f'bisheng_langchain.rag.rerank.{rerank_type}')
+                 self.ranker = rerank_object(**rerank_params)
+             docs = self.ranker.sort_and_filter(question, docs)
+
+         # delete redundancy according to max_content
+         doc_num, doc_content_sum = 0, 0
+         for doc in docs:
+             doc_content_sum += len(doc.page_content)
+             if doc_content_sum > max_content:
+                 break
+             doc_num += 1
+         docs = docs[:doc_num]
+         logger.info(f'retrieval docs after delete redundancy: {len(docs)}')
+
+         # sort chunks by source and chunk_index to keep the context coherent and consistent
+         if self.params['post_retrieval'].get('sort_by_source_and_index', False):
+             logger.info('sort chunks by source and chunk_index')
+             docs = sorted(docs, key=lambda x: (x.metadata['source'], x.metadata['chunk_index']))
+         return docs
+
+     def load_documents(self, file_name, max_content=100000):
+         """
+         Load the document directly; if it is too long, simply truncate it.
+         max_content: max content len of llm
+         """
+         file_path = os.path.join(self.origin_file_path, file_name)
+         if not os.path.exists(file_path):
+             raise Exception(f'{file_path} not exists.')
+         if os.path.isdir(file_path):
+             raise Exception(f'{file_path} is a directory.')
+
+         loader_params = copy.deepcopy(self.params['loader'])
+         loader_object = import_by_type(_type='documentloaders', name=loader_params.pop('type'))
+         loader = loader_object(file_name=file_name, file_path=file_path, **loader_params)
+
+         documents = loader.load()
+         logger.info(f'documents: {len(documents)}, page_content: {len(documents[0].page_content)}')
+         for doc in documents:
+             doc.page_content = doc.page_content[:max_content]
+         return documents
+
+     def question_answering(self):
+         """
+         question answer over knowledge
+         """
+         df = pd.read_excel(self.question_path)
+         all_questions_info = df.to_dict('records')
+         if 'prompt_type' in self.params['generate']:
+             prompt_type = self.params['generate']['prompt_type']
+             prompt = import_class(f'bisheng_langchain.rag.prompts.{prompt_type}')
+         else:
+             prompt = None
+         qa_chain = load_qa_chain(
+             llm=self.llm, chain_type=self.params['generate']['chain_type'], prompt=prompt, verbose=False
+         )
+         file2docs = dict()
+         for questions_info in tqdm(all_questions_info):
+             question = questions_info['问题']
+             file_name = questions_info['文件名']
+             collection_name = questions_info['知识库名']
+
+             # if question != '请分析江苏中设集团股份有限公司2021年重大关联交易的情况。':
+             #     continue
+
+             if self.params['generate']['with_retrieval']:
+                 # retrieval and rerank
+                 docs = self.retrieval_and_rerank(question, collection_name, max_content=self.params['generate']['max_content'])
+             else:
+                 # load document
+                 if file_name not in file2docs:
+                     docs = self.load_documents(file_name, max_content=self.params['generate']['max_content'])
+                     file2docs[file_name] = docs
+                 else:
+                     docs = file2docs[file_name]
+
+             # question answer
+             try:
+                 ans = qa_chain({"input_documents": docs, "question": question}, return_only_outputs=True)
+             except Exception as e:
+                 logger.error(f'question: {question}\nerror: {e}')
+                 ans = {'output_text': str(e)}
+             rag_answer = ans['output_text']
+
+             # context = '\n\n'.join([doc.page_content for doc in docs])
+             # content = prompt.format(context=context, question=question)
+
+             # for rate_limit
+             # time.sleep(30)
+             logger.info(f'question: {question}\nans: {rag_answer}\n')
+             questions_info['rag_answer'] = rag_answer
+             # questions_info['rag_context'] = '\n----------------\n'.join([doc.page_content for doc in docs])
+             # questions_info['rag_context'] = content
+
+         df = pd.DataFrame(all_questions_info)
+         df.to_excel(self.save_answer_path, index=False)
+
+     def score(self):
+         """
+         score
+         """
+         metric_params = self.params['metric']
+         if metric_params['type'] == 'bisheng-ragas':
+             score_params = {
+                 'excel_path': self.save_answer_path,
+                 'save_path': os.path.dirname(self.save_answer_path),
+                 'question_column': metric_params['question_column'],
+                 'gt_column': metric_params['gt_column'],
+                 'answer_column': metric_params['answer_column'],
+                 'query_type_column': metric_params.get('query_type_column', None),
+                 'contexts_column': metric_params.get('contexts_column', None),
+                 'metrics': metric_params['metrics'],
+                 'batch_size': metric_params['batch_size'],
+                 'gt_split_column': metric_params.get('gt_split_column', None),
+                 'whether_gtsplit': metric_params.get('whether_gtsplit', False),  # whether the model should split the ground truth into key points
+             }
+             rag_score = RagScore(**score_params)
+             rag_score.score()
+         else:
+             # todo: other scoring methods
+             pass
+
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser(description='Process some integers.')
+     # add arguments
+     parser.add_argument('--mode', type=str, default='qa', help='upload or qa or score')
+     parser.add_argument('--params', type=str, default='config/test/baseline_s2b.yaml', help='bisheng rag params')
+     # parse arguments
+     args = parser.parse_args()
+
+     rag = BishengRagPipeline(args.params)
+
+     if args.mode == 'upload':
+         rag.file2knowledge()
+     elif args.mode == 'qa':
+         rag.question_answering()
+     elif args.mode == 'score':
+         rag.score()
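For orientation, here is a minimal sketch of the YAML schema that BishengRagPipeline reads. The key names are taken from the code above; every value (hosts, class names, splitter settings, column names) is a placeholder rather than a shipped default, and the nested splitter/retrieval keys are assumptions about what the retriever classes accept. The released configurations are the baseline_*.yaml files under bisheng_langchain/rag/config/.

# Sketch of the config schema read by BishengRagPipeline; all values are placeholders.
data:
  origin_file_path: /path/to/source/files
  question: /path/to/questions.xlsx        # must contain 文件名 / 知识库名 / 问题 columns
  save_answer: /path/to/answers.xlsx
embedding:
  type: OpenAIEmbeddings                   # remaining keys become constructor kwargs
  openai_proxy: null
chat_llm:
  type: ChatOpenAI
  openai_proxy: null
milvus: {host: 127.0.0.1, port: 19530, drop_old: true}
elasticsearch: {url: 'http://127.0.0.1:9200', ssl_verify: false, drop_old: true, extract_key_by_llm: false}
loader:
  type: PyPDFLoader                        # placeholder; any 'documentloaders' class import_by_type can resolve
retriever:
  suffix: demo                             # appended to every collection name
  add_aux_info: false
  retrievers:
    - type: BaselineVectorRetriever        # or KeywordRetriever / MixRetriever / SmallerChunksVectorRetriever
      splitter:                            # each entry: a 'textsplitters' class plus its kwargs
        text_splitter: {type: RecursiveCharacterTextSplitter, chunk_size: 512}
      retrieval:                           # copied verbatim into the retriever's kwargs
        search_kwargs: {k: 4}
post_retrieval:
  delete_duplicate: true
  with_rank: false
  sort_by_source_and_index: false
generate:
  with_retrieval: true
  chain_type: stuff
  max_content: 100000
metric:
  type: bisheng-ragas
  question_column: 问题                    # placeholder column names
  gt_column: gt
  answer_column: rag_answer
  metrics: [answer_correctness]
  batch_size: 4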
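And a minimal driver sketch mirroring the __main__ block, assuming the three CLI modes map onto the three public methods as above; the config path is a placeholder.

# Driver sketch for the new pipeline; config path is a placeholder.
from bisheng_langchain.rag.bisheng_rag_pipeline_v2 import BishengRagPipeline

rag = BishengRagPipeline('config/baseline_v2.yaml')
rag.file2knowledge()       # --mode upload: split source files and index them into Milvus + ES
rag.question_answering()   # --mode qa: retrieve, optionally rerank, and answer each question
rag.score()                # --mode score: evaluate the saved answers with bisheng-ragas

Note that file2knowledge consumes the same questions spreadsheet as question_answering, using its 文件名 and 知识库名 columns to decide which files are indexed into which collection.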