bisheng-langchain 0.3.1.1__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28)
  1. bisheng_langchain/chains/__init__.py +4 -1
  2. bisheng_langchain/chains/qa_generation/__init__.py +0 -0
  3. bisheng_langchain/chains/qa_generation/base.py +128 -0
  4. bisheng_langchain/chains/qa_generation/base_v2.py +413 -0
  5. bisheng_langchain/chains/qa_generation/prompt.py +53 -0
  6. bisheng_langchain/chains/qa_generation/prompt_v2.py +155 -0
  7. bisheng_langchain/document_loaders/elem_unstrcutured_loader.py +36 -9
  8. bisheng_langchain/document_loaders/parsers/ellm_client.py +7 -9
  9. bisheng_langchain/document_loaders/universal_kv.py +4 -3
  10. bisheng_langchain/gpts/tools/api_tools/openapi.py +7 -7
  11. bisheng_langchain/rag/__init__.py +2 -0
  12. bisheng_langchain/rag/bisheng_rag_chain.py +164 -0
  13. bisheng_langchain/rag/bisheng_rag_pipeline_v2.py +8 -2
  14. bisheng_langchain/rag/bisheng_rag_tool.py +47 -24
  15. bisheng_langchain/rag/config/baseline_caibao_v2.yaml +1 -1
  16. bisheng_langchain/rag/config/baseline_v2.yaml +3 -2
  17. bisheng_langchain/rag/prompts/prompt.py +1 -1
  18. bisheng_langchain/rag/qa_corpus/qa_generator.py +1 -1
  19. bisheng_langchain/rag/scoring/ragas_score.py +2 -2
  20. bisheng_langchain/rag/utils.py +27 -4
  21. bisheng_langchain/sql/__init__.py +3 -0
  22. bisheng_langchain/sql/base.py +120 -0
  23. bisheng_langchain/text_splitter.py +1 -1
  24. {bisheng_langchain-0.3.1.1.dist-info → bisheng_langchain-0.3.2.dist-info}/METADATA +3 -1
  25. {bisheng_langchain-0.3.1.1.dist-info → bisheng_langchain-0.3.2.dist-info}/RECORD +27 -20
  26. bisheng_langchain/rag/bisheng_rag_pipeline_v2_cohere_raw_prompting.py +0 -376
  27. {bisheng_langchain-0.3.1.1.dist-info → bisheng_langchain-0.3.2.dist-info}/WHEEL +0 -0
  28. {bisheng_langchain-0.3.1.1.dist-info → bisheng_langchain-0.3.2.dist-info}/top_level.txt +0 -0
bisheng_langchain/rag/bisheng_rag_pipeline_v2_cohere_raw_prompting.py (file 26 above, deleted)
@@ -1,376 +0,0 @@
- import argparse
- import copy
- import time
- import inspect
- import os
- import httpx
- import pandas as pd
- import yaml
- import math
- from tqdm import tqdm
- from loguru import logger
- from collections import defaultdict
- from bisheng_langchain.retrievers import EnsembleRetriever
- from bisheng_langchain.vectorstores import ElasticKeywordsSearch, Milvus
- from langchain.docstore.document import Document
- from langchain.chains.question_answering import load_qa_chain
- from langchain_community.chat_models.cohere import ChatCohere
- from bisheng_langchain.rag.init_retrievers import (
-     BaselineVectorRetriever,
-     KeywordRetriever,
-     MixRetriever,
-     SmallerChunksVectorRetriever,
- )
- from bisheng_langchain.rag.scoring.ragas_score import RagScore
- from bisheng_langchain.rag.extract_info import extract_title
- from bisheng_langchain.rag.utils import import_by_type, import_class
-
-
- class BishengRagPipeline:
-
-     def __init__(self, yaml_path) -> None:
-         self.yaml_path = yaml_path
-         with open(self.yaml_path, 'r') as f:
-             self.params = yaml.safe_load(f)
-
-         # init data
-         self.origin_file_path = self.params['data']['origin_file_path']
-         self.question_path = self.params['data']['question']
-         self.save_answer_path = self.params['data']['save_answer']
-
-         # init embeddings
-         embedding_params = self.params['embedding']
-         embedding_object = import_by_type(_type='embeddings', name=embedding_params['type'])
-         if embedding_params['type'] == 'OpenAIEmbeddings' and embedding_params['openai_proxy']:
-             embedding_params.pop('type')
-             self.embeddings = embedding_object(
-                 http_client=httpx.Client(proxies=embedding_params['openai_proxy']), **embedding_params
-             )
-         else:
-             embedding_params.pop('type')
-             self.embeddings = embedding_object(**embedding_params)
-
-         # init llm
-         llm_params = self.params['chat_llm']
-         llm_object = import_by_type(_type='llms', name=llm_params['type'])
-         if llm_params['type'] == 'ChatOpenAI' and llm_params['openai_proxy']:
-             llm_params.pop('type')
-             self.llm = llm_object(http_client=httpx.Client(proxies=llm_params['openai_proxy']), **llm_params)
-         else:
-             llm_params.pop('type')
-             self.llm = llm_object(**llm_params)
-
-         # milvus
-         self.vector_store = Milvus(
-             embedding_function=self.embeddings,
-             connection_args={
-                 "host": self.params['milvus']['host'],
-                 "port": self.params['milvus']['port'],
-             },
-         )
-
-         # es
-         self.keyword_store = ElasticKeywordsSearch(
-             index_name='default_es',
-             elasticsearch_url=self.params['elasticsearch']['url'],
-             ssl_verify=self.params['elasticsearch']['ssl_verify'],
-         )
-
-         # init retriever
-         retriever_list = []
-         retrievers = self.params['retriever']['retrievers']
-         for retriever in retrievers:
-             retriever_type = retriever.pop('type')
-             retriever_params = {
-                 'vector_store': self.vector_store,
-                 'keyword_store': self.keyword_store,
-                 'splitter_kwargs': retriever['splitter'],
-                 'retrieval_kwargs': retriever['retrieval'],
-             }
-             retriever_list.append(self._post_init_retriever(retriever_type=retriever_type, **retriever_params))
-         self.retriever = EnsembleRetriever(retrievers=retriever_list)
-
-     def _post_init_retriever(self, retriever_type, **kwargs):
-         retriever_classes = {
-             'KeywordRetriever': KeywordRetriever,
-             'BaselineVectorRetriever': BaselineVectorRetriever,
-             'MixRetriever': MixRetriever,
-             'SmallerChunksVectorRetriever': SmallerChunksVectorRetriever,
-         }
-         if retriever_type not in retriever_classes:
-             raise ValueError(f'Unknown retriever type: {retriever_type}')
-
-         input_kwargs = {}
-         splitter_params = kwargs.pop('splitter_kwargs')
-         for key, value in splitter_params.items():
-             splitter_obj = import_by_type(_type='textsplitters', name=value.pop('type'))
-             input_kwargs[key] = splitter_obj(**value)
-
-         retrieval_params = kwargs.pop('retrieval_kwargs')
-         for key, value in retrieval_params.items():
-             input_kwargs[key] = value
-
-         input_kwargs['vector_store'] = kwargs.pop('vector_store')
-         input_kwargs['keyword_store'] = kwargs.pop('keyword_store')
-
-         retriever_class = retriever_classes[retriever_type]
-         return retriever_class(**input_kwargs)
-
-     def file2knowledge(self):
-         """
-         Load source files and index them into the vector and keyword stores.
-         """
-         # Excel columns: 问题 (question), 文件名 (file name), 知识库名 (knowledge-base name)
-         df = pd.read_excel(self.question_path)
-         if ('文件名' not in df.columns) or ('知识库名' not in df.columns):
-             raise Exception(f'Column 文件名 or 知识库名 not found in {self.question_path}.')
-
-         loader_params = self.params['loader']
-         loader_object = import_by_type(_type='documentloaders', name=loader_params.pop('type'))
-
-         all_questions_info = df.to_dict('records')
-         collectionname2filename = defaultdict(set)
-         for info in all_questions_info:
-             # store file names in a set to drop duplicates
-             collectionname2filename[info['知识库名']].add(info['文件名'])
-
-         for collection_name in tqdm(collectionname2filename):
-             all_file_paths = []
-             for file_name in collectionname2filename[collection_name]:
-                 file_path = os.path.join(self.origin_file_path, file_name)
-                 if not os.path.exists(file_path):
-                     raise Exception(f'{file_path} does not exist.')
-                 # file_path may be a directory or a single file
-                 if os.path.isdir(file_path):
-                     # a directory containing multiple files
-                     all_file_paths.extend(
-                         [os.path.join(file_path, name) for name in os.listdir(file_path) if not name.startswith('.')]
-                     )
-                 else:
-                     # a single file
-                     all_file_paths.append(file_path)
-
-             # all files to be indexed for the current knowledge base
-             collection_name = f"{collection_name}_{self.params['retriever']['suffix']}"
-             for index, each_file_path in enumerate(all_file_paths):
-                 logger.info(f'each_file_path: {each_file_path}')
-                 loader = loader_object(
-                     file_name=os.path.basename(each_file_path), file_path=each_file_path, **loader_params
-                 )
-                 documents = loader.load()
-
-                 # # load from text
-                 # if each_file_path.endswith('.pdf'):
-                 #     with open(each_file_path.replace('.pdf', '.txt'), 'r') as f:
-                 #         content = f.read()
-                 #     documents = [Document(page_content=content, metadata={'source': os.path.basename(each_file_path)})]
-
-                 logger.info(f'documents: {len(documents)}, page_content: {len(documents[0].page_content)}')
-                 if len(documents[0].page_content) == 0:
-                     logger.error(f'{each_file_path} page_content is empty.')
-
-                 # add aux info
-                 add_aux_info = self.params['retriever'].get('add_aux_info', False)
-                 if add_aux_info:
-                     for doc in documents:
-                         try:
-                             title = extract_title(llm=self.llm, text=doc.page_content)
-                             logger.info(f'extract title: {title}')
-                         except Exception as e:
-                             logger.error(f"Failed to extract title: {e}")
-                             title = ''
-                         doc.metadata['title'] = title
-
-                 vector_drop_old = self.params['milvus']['drop_old'] if index == 0 else False
-                 # note: keyword_drop_old is computed but never passed to add_documents
-                 keyword_drop_old = self.params['elasticsearch']['drop_old'] if index == 0 else False
-                 for idx, retriever in enumerate(self.retriever.retrievers):
-                     retriever.add_documents(documents, collection_name, vector_drop_old, add_aux_info=add_aux_info)
-
-     def retrieval_and_rerank(self, question, collection_name, max_content=100000):
-         """
-         Retrieve relevant chunks, then optionally deduplicate, rerank, and trim them.
-         """
-         collection_name = f"{collection_name}_{self.params['retriever']['suffix']}"
-         # EnsembleRetriever deduplicates retrieved documents by default
-         docs = self.retriever.get_relevant_documents(query=question, collection_name=collection_name)
-         logger.info(f'retrieval docs origin: {len(docs)}')
-
-         # delete duplicates
-         if self.params['post_retrieval']['delete_duplicate']:
-             logger.info(f'origin docs: {len(docs)}')
-             all_contents = []
-             docs_no_dup = []
-             for index, doc in enumerate(docs):
-                 doc_content = doc.page_content
-                 if doc_content in all_contents:
-                     continue
-                 all_contents.append(doc_content)
-                 docs_no_dup.append(doc)
-             docs = docs_no_dup
-             logger.info(f'delete duplicate docs: {len(docs)}')
-
-         # rerank
-         if self.params['post_retrieval']['with_rank'] and len(docs):
-             if not hasattr(self, 'ranker'):
-                 rerank_params = self.params['post_retrieval']['rerank']
-                 rerank_type = rerank_params.pop('type')
-                 rerank_object = import_class(f'bisheng_langchain.rag.rerank.{rerank_type}')
-                 self.ranker = rerank_object(**rerank_params)
-             docs = getattr(self, 'ranker').sort_and_filter(question, docs)
-
-         # drop trailing docs once the accumulated content exceeds max_content
-         doc_num, doc_content_sum = 0, 0
-         for doc in docs:
-             doc_content_sum += len(doc.page_content)
-             if doc_content_sum > max_content:
-                 break
-             doc_num += 1
-         docs = docs[:doc_num]
-         logger.info(f'retrieval docs after delete redundancy: {len(docs)}')
-
-         # sort chunks by source and chunk_index to keep the context coherent and consistent
-         if self.params['post_retrieval'].get('sort_by_source_and_index', False):
-             logger.info('sort chunks by source and chunk_index')
-             docs = sorted(docs, key=lambda x: (x.metadata['source'], x.metadata['chunk_index']))
-         return docs
-
-     def load_documents(self, file_name, max_content=100000):
-         """
-         Load a document directly, truncating it if it is too long;
-         max_content: max context length of the llm
-         """
-         file_path = os.path.join(self.origin_file_path, file_name)
-         if not os.path.exists(file_path):
-             raise Exception(f'{file_path} does not exist.')
-         if os.path.isdir(file_path):
-             raise Exception(f'{file_path} is a directory.')
-
-         loader_params = copy.deepcopy(self.params['loader'])
-         loader_object = import_by_type(_type='documentloaders', name=loader_params.pop('type'))
-         loader = loader_object(file_name=file_name, file_path=file_path, **loader_params)
-
-         documents = loader.load()
-         logger.info(f'documents: {len(documents)}, page_content: {len(documents[0].page_content)}')
-         for doc in documents:
-             doc.page_content = doc.page_content[:max_content]
-         return documents
-
-     def question_answering(self):
-         """
-         Question answering over the knowledge base.
-         """
-         df = pd.read_excel(self.question_path)
-         all_questions_info = df.to_dict('records')
-         if 'prompt_type' in self.params['generate']:
-             prompt_type = self.params['generate']['prompt_type']
-             prompt = import_class(f'bisheng_langchain.rag.prompts.{prompt_type}')
-         else:
-             prompt = None
-         if not isinstance(self.llm, ChatCohere):
-             qa_chain = load_qa_chain(
-                 llm=self.llm, chain_type=self.params['generate']['chain_type'], prompt=prompt, verbose=False
-             )
-         file2docs = dict()
-         for questions_info in tqdm(all_questions_info):
-             question = questions_info['问题']
-             file_name = questions_info['文件名']
-             collection_name = questions_info['知识库名']
-
-             # debug filter for a single sample question:
-             # if question != '请分析江苏中设集团股份有限公司2021年重大关联交易的情况。':
-             #     continue
-
-             if self.params['generate']['with_retrieval']:
-                 # retrieval and rerank
-                 docs = self.retrieval_and_rerank(question, collection_name, max_content=self.params['generate']['max_content'])
-             else:
-                 # load document
-                 if file_name not in file2docs:
-                     docs = self.load_documents(file_name, max_content=self.params['generate']['max_content'])
-                     file2docs[file_name] = docs
-                 else:
-                     docs = file2docs[file_name]
-
-             if isinstance(self.llm, ChatCohere):
-                 try:
-                     # cohere rag
-                     # messages = prompt.format_prompt(question=question).to_messages()
-                     # ans = self.llm.invoke(messages, source_documents=docs)
-                     # rag_answer = ans.content
-
-                     # cohere rag by raw prompt
-                     documents = ''
-                     for i, doc in enumerate(docs):
-                         # documents += f'Document: {i}\n'
-                         # documents += 'text:' + doc.page_content + '\n\n'
-                         documents += doc.page_content + '\n\n'
-                     messages = prompt.format_prompt(question=question, documents=documents).to_messages()
-                     ans = self.llm.invoke(messages, raw_prompting=True)
-                     rag_answer = ans.content.replace('Answer: ', '')
-                 except Exception as e:
-                     logger.error(f'question: {question}\nerror: {e}')
-                     ans = {'output_text': str(e)}
-                     rag_answer = ans['output_text']
-             else:
-                 # question answer
-                 try:
-                     ans = qa_chain({"input_documents": docs, "question": question}, return_only_outputs=True)
-                 except Exception as e:
-                     logger.error(f'question: {question}\nerror: {e}')
-                     ans = {'output_text': str(e)}
-                 rag_answer = ans['output_text']
-
-             # context = '\n\n'.join([doc.page_content for doc in docs])
-             # content = prompt.format(context=context, question=question)
-
-             # sleep to respect rate limits if needed
-             # time.sleep(3)
-             logger.info(f'question: {question}\nans: {rag_answer}\n')
-             questions_info['rag_answer'] = rag_answer
-             # questions_info['rag_context'] = '\n----------------\n'.join([doc.page_content for doc in docs])
-             # questions_info['rag_context'] = content
-
-         df = pd.DataFrame(all_questions_info)
-         df.to_excel(self.save_answer_path, index=False)
-
-     def score(self):
-         """
-         Score the generated answers.
-         """
-         metric_params = self.params['metric']
-         if metric_params['type'] == 'bisheng-ragas':
-             score_params = {
-                 'excel_path': self.save_answer_path,
-                 'save_path': os.path.dirname(self.save_answer_path),
-                 'question_column': metric_params['question_column'],
-                 'gt_column': metric_params['gt_column'],
-                 'answer_column': metric_params['answer_column'],
-                 'query_type_column': metric_params.get('query_type_column', None),
-                 'contexts_column': metric_params.get('contexts_column', None),
-                 'metrics': metric_params['metrics'],
-                 'batch_size': metric_params['batch_size'],
-                 'gt_split_column': metric_params.get('gt_split_column', None),
-                 'whether_gtsplit': metric_params.get('whether_gtsplit', False),  # whether the model should split the ground truth into key points
-             }
-             rag_score = RagScore(**score_params)
-             rag_score.score()
-         else:
-             # todo: other scoring methods
-             pass
-
-
- if __name__ == '__main__':
-     parser = argparse.ArgumentParser(description='Run the Bisheng RAG pipeline.')
-     # add arguments
-     parser.add_argument('--mode', type=str, default='qa', help='upload or qa or score')
-     parser.add_argument('--params', type=str, default='config/test/baseline_s2b.yaml', help='bisheng rag params')
-     # parse arguments
-     args = parser.parse_args()
-
-     rag = BishengRagPipeline(args.params)
-
-     if args.mode == 'upload':
-         rag.file2knowledge()
-     elif args.mode == 'qa':
-         rag.question_answering()
-     elif args.mode == 'score':
-         rag.score()
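
For orientation, below is a minimal sketch of the YAML layout this deleted script consumed. The top-level sections and key names are inferred from the self.params[...] reads in the code above; every value, and the splitter/retrieval/metric sub-keys, are illustrative assumptions rather than the package's shipped configuration (see rag/config/baseline_v2.yaml in the file list for the maintained config).

# Illustrative sketch only - section and key names inferred from the deleted
# script's self.params[...] reads; all values and sub-keys are assumptions.
data:
  origin_file_path: data/origin_files        # directory holding the source documents
  question: data/questions.xlsx              # Excel sheet with 问题/文件名/知识库名 columns
  save_answer: data/answers.xlsx             # where question_answering() writes results
embedding:
  type: OpenAIEmbeddings                     # any embeddings type resolvable by import_by_type
  openai_proxy: null                         # routed through httpx.Client when set
chat_llm:
  type: ChatOpenAI
  openai_proxy: null
milvus:
  host: 127.0.0.1
  port: 19530
  drop_old: true                             # rebuild the vector collection on the first file
elasticsearch:
  url: http://127.0.0.1:9200
  ssl_verify: false
  drop_old: true
loader:
  type: ElemUnstructuredLoader               # any documentloaders type resolvable by import_by_type
retriever:
  suffix: benchmark                          # appended to every collection name
  add_aux_info: false                        # run extract_title() per chunk when true
  retrievers:
    - type: BaselineVectorRetriever          # one of the four classes in _post_init_retriever
      splitter:
        text_splitter:                       # sub-key names are assumptions
          type: RecursiveCharacterTextSplitter
          chunk_size: 512
      retrieval:
        search_kwargs:
          k: 4
post_retrieval:
  delete_duplicate: true
  with_rank: false
  # rerank:                                  # required when with_rank is true
  #   type: <class under bisheng_langchain.rag.rerank>
  sort_by_source_and_index: true
generate:
  with_retrieval: true
  chain_type: stuff
  max_content: 100000
  prompt_type: BASE_PROMPT                   # optional; resolved from bisheng_langchain.rag.prompts
metric:
  type: bisheng-ragas
  question_column: 问题
  gt_column: gt
  answer_column: rag_answer
  metrics: [answer_recall]
  batch_size: 5

Per its argparse block, the script was driven as "python bisheng_rag_pipeline_v2_cohere_raw_prompting.py --mode upload|qa|score --params path/to/config.yaml", where upload runs file2knowledge(), qa runs question_answering(), and score runs score().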