bisheng-langchain 0.3.1.1__py3-none-any.whl → 0.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bisheng_langchain/chains/__init__.py +4 -1
- bisheng_langchain/chains/qa_generation/__init__.py +0 -0
- bisheng_langchain/chains/qa_generation/base.py +128 -0
- bisheng_langchain/chains/qa_generation/base_v2.py +413 -0
- bisheng_langchain/chains/qa_generation/prompt.py +53 -0
- bisheng_langchain/chains/qa_generation/prompt_v2.py +155 -0
- bisheng_langchain/document_loaders/elem_unstrcutured_loader.py +36 -9
- bisheng_langchain/document_loaders/parsers/ellm_client.py +7 -9
- bisheng_langchain/document_loaders/universal_kv.py +4 -3
- bisheng_langchain/gpts/tools/api_tools/openapi.py +7 -7
- bisheng_langchain/rag/__init__.py +2 -0
- bisheng_langchain/rag/bisheng_rag_chain.py +164 -0
- bisheng_langchain/rag/bisheng_rag_pipeline_v2.py +8 -2
- bisheng_langchain/rag/bisheng_rag_tool.py +47 -24
- bisheng_langchain/rag/config/baseline_caibao_v2.yaml +1 -1
- bisheng_langchain/rag/config/baseline_v2.yaml +3 -2
- bisheng_langchain/rag/prompts/prompt.py +1 -1
- bisheng_langchain/rag/qa_corpus/qa_generator.py +1 -1
- bisheng_langchain/rag/scoring/ragas_score.py +2 -2
- bisheng_langchain/rag/utils.py +27 -4
- bisheng_langchain/sql/__init__.py +3 -0
- bisheng_langchain/sql/base.py +120 -0
- bisheng_langchain/text_splitter.py +1 -1
- {bisheng_langchain-0.3.1.1.dist-info → bisheng_langchain-0.3.2.dist-info}/METADATA +3 -1
- {bisheng_langchain-0.3.1.1.dist-info → bisheng_langchain-0.3.2.dist-info}/RECORD +27 -20
- bisheng_langchain/rag/bisheng_rag_pipeline_v2_cohere_raw_prompting.py +0 -376
- {bisheng_langchain-0.3.1.1.dist-info → bisheng_langchain-0.3.2.dist-info}/WHEEL +0 -0
- {bisheng_langchain-0.3.1.1.dist-info → bisheng_langchain-0.3.2.dist-info}/top_level.txt +0 -0
@@ -1,376 +0,0 @@
|
|
1
|
-
import argparse
|
2
|
-
import copy
|
3
|
-
import time
|
4
|
-
import inspect
|
5
|
-
import os
|
6
|
-
import httpx
|
7
|
-
import pandas as pd
|
8
|
-
import yaml
|
9
|
-
import math
|
10
|
-
from tqdm import tqdm
|
11
|
-
from loguru import logger
|
12
|
-
from collections import defaultdict
|
13
|
-
from bisheng_langchain.retrievers import EnsembleRetriever
|
14
|
-
from bisheng_langchain.vectorstores import ElasticKeywordsSearch, Milvus
|
15
|
-
from langchain.docstore.document import Document
|
16
|
-
from langchain.chains.question_answering import load_qa_chain
|
17
|
-
from langchain_community.chat_models.cohere import ChatCohere
|
18
|
-
from bisheng_langchain.rag.init_retrievers import (
|
19
|
-
BaselineVectorRetriever,
|
20
|
-
KeywordRetriever,
|
21
|
-
MixRetriever,
|
22
|
-
SmallerChunksVectorRetriever,
|
23
|
-
)
|
24
|
-
from bisheng_langchain.rag.scoring.ragas_score import RagScore
|
25
|
-
from bisheng_langchain.rag.extract_info import extract_title
|
26
|
-
from bisheng_langchain.rag.utils import import_by_type, import_class
|
27
|
-
|
28
|
-
|
29
|
-
class BishengRagPipeline:
    """Config-driven RAG evaluation pipeline with a raw-prompting ChatCohere path.

    The YAML file passed to ``__init__`` configures data paths, embeddings, the
    chat LLM, Milvus, Elasticsearch, retrievers, post-retrieval processing,
    generation and metrics. The public stages are:

    * :meth:`file2knowledge` -- load the source files listed in the question
      sheet and add them to every configured retriever;
    * :meth:`question_answering` -- answer each question and write the sheet,
      with an extra ``rag_answer`` column, to ``data.save_answer``;
    * :meth:`score` -- grade the saved answers.
    """

    def __init__(self, yaml_path) -> None:
        """Read the YAML config and build embeddings, LLM, stores and retrievers.

        Args:
            yaml_path: path to the pipeline's YAML configuration file.
        """
        self.yaml_path = yaml_path
        # Explicit UTF-8 so configs containing non-ASCII text load the same on
        # every platform, independent of the locale's default encoding.
        with open(self.yaml_path, 'r', encoding='utf-8') as f:
            self.params = yaml.safe_load(f)

        # Data locations.
        self.origin_file_path = self.params['data']['origin_file_path']
        self.question_path = self.params['data']['question']
        self.save_answer_path = self.params['data']['save_answer']

        # Embeddings and chat LLM.
        self.embeddings = self._build_component('embeddings', self.params['embedding'], 'OpenAIEmbeddings')
        self.llm = self._build_component('llms', self.params['chat_llm'], 'ChatOpenAI')

        # Milvus vector store.
        self.vector_store = Milvus(
            embedding_function=self.embeddings,
            connection_args={
                "host": self.params['milvus']['host'],
                "port": self.params['milvus']['port'],
            },
        )

        # Elasticsearch keyword store.
        self.keyword_store = ElasticKeywordsSearch(
            index_name='default_es',
            elasticsearch_url=self.params['elasticsearch']['url'],
            ssl_verify=self.params['elasticsearch']['ssl_verify'],
        )

        # One retriever per config entry, combined by an EnsembleRetriever.
        retriever_list = []
        for retriever in self.params['retriever']['retrievers']:
            # Shallow copy so popping 'type' leaves self.params untouched.
            retriever_conf = dict(retriever)
            retriever_type = retriever_conf.pop('type')
            retriever_list.append(
                self._post_init_retriever(
                    retriever_type=retriever_type,
                    vector_store=self.vector_store,
                    keyword_store=self.keyword_store,
                    splitter_kwargs=retriever_conf['splitter'],
                    retrieval_kwargs=retriever_conf['retrieval'],
                )
            )
        self.retriever = EnsembleRetriever(retrievers=retriever_list)

    @staticmethod
    def _build_component(_type, config, proxy_class_name):
        """Instantiate the class named by ``config['type']`` from the ``_type`` registry.

        The remaining config keys are passed to the constructor unchanged. When the
        class is ``proxy_class_name`` and a truthy ``openai_proxy`` is configured,
        an ``httpx.Client`` using that proxy is also passed as ``http_client``.
        A missing ``openai_proxy`` key is treated as "no proxy".
        The caller's ``config`` dict is not modified.
        """
        kwargs = dict(config)
        class_name = kwargs.pop('type')
        component_cls = import_by_type(_type=_type, name=class_name)
        if class_name == proxy_class_name and kwargs.get('openai_proxy'):
            return component_cls(http_client=httpx.Client(proxies=kwargs['openai_proxy']), **kwargs)
        return component_cls(**kwargs)

    def _post_init_retriever(self, retriever_type, **kwargs):
        """Build a single retriever from its config.

        Args:
            retriever_type: one of ``KeywordRetriever``, ``BaselineVectorRetriever``,
                ``MixRetriever`` or ``SmallerChunksVectorRetriever``.
            **kwargs: must contain ``vector_store``, ``keyword_store``,
                ``splitter_kwargs`` (constructor argument name -> splitter config
                that carries a ``type`` key) and ``retrieval_kwargs`` (extra
                constructor arguments passed through unchanged).

        Returns:
            The constructed retriever instance.

        Raises:
            ValueError: if ``retriever_type`` is not recognised.
        """
        retriever_classes = {
            'KeywordRetriever': KeywordRetriever,
            'BaselineVectorRetriever': BaselineVectorRetriever,
            'MixRetriever': MixRetriever,
            'SmallerChunksVectorRetriever': SmallerChunksVectorRetriever,
        }
        if retriever_type not in retriever_classes:
            raise ValueError(f'Unknown retriever type: {retriever_type}')

        input_kwargs = {}
        for key, value in kwargs.pop('splitter_kwargs').items():
            # Copy before popping 'type' so the caller's config is not mutated.
            splitter_conf = dict(value)
            splitter_obj = import_by_type(_type='textsplitters', name=splitter_conf.pop('type'))
            input_kwargs[key] = splitter_obj(**splitter_conf)

        input_kwargs.update(kwargs.pop('retrieval_kwargs'))
        input_kwargs['vector_store'] = kwargs.pop('vector_store')
        input_kwargs['keyword_store'] = kwargs.pop('keyword_store')

        return retriever_classes[retriever_type](**input_kwargs)

    def _collect_file_paths(self, file_names):
        """Resolve file names under ``origin_file_path``; directories contribute their non-hidden entries.

        Raises:
            Exception: if a resolved path does not exist.
        """
        all_file_paths = []
        for file_name in file_names:
            file_path = os.path.join(self.origin_file_path, file_name)
            if not os.path.exists(file_path):
                raise Exception(f'{file_path} not exists.')
            if os.path.isdir(file_path):
                # A directory holds several files; only its direct entries are used.
                all_file_paths.extend(
                    os.path.join(file_path, name) for name in os.listdir(file_path) if not name.startswith('.')
                )
            else:
                all_file_paths.append(file_path)
        return all_file_paths

    def file2knowledge(self):
        """Index every source file listed in the question sheet.

        Reads ``self.question_path`` (Excel), which must contain the columns
        ``文件名`` (file name) and ``知识库名`` (knowledge-base name). Each knowledge
        base is stored in a collection named ``<name>_<retriever.suffix>``. Every
        file is loaded with the configured loader and added to all retrievers.
        The ``drop_old`` flags are applied only to the first file actually
        indexed into a collection.

        Raises:
            Exception: if a required column is missing or a listed path does not exist.
        """
        df = pd.read_excel(self.question_path)
        if ('文件名' not in df.columns) or ('知识库名' not in df.columns):
            raise Exception(f'文件名 or 知识库名 not in {self.question_path}.')

        # Deep copy so popping 'type' does not mutate self.params['loader'];
        # load_documents() reads the same config and needs 'type' intact.
        loader_params = copy.deepcopy(self.params['loader'])
        loader_object = import_by_type(_type='documentloaders', name=loader_params.pop('type'))

        # Group file names by knowledge base; a set removes duplicate file names.
        collectionname2filename = defaultdict(set)
        for info in df.to_dict('records'):
            collectionname2filename[info['知识库名']].add(info['文件名'])

        add_aux_info = self.params['retriever'].get('add_aux_info', False)
        for collection_name in tqdm(collectionname2filename):
            all_file_paths = self._collect_file_paths(collectionname2filename[collection_name])

            # All files of this knowledge base go into one suffixed collection.
            collection_name = f"{collection_name}_{self.params['retriever']['suffix']}"
            is_first_indexed = True
            for each_file_path in all_file_paths:
                logger.info(f'each_file_path: {each_file_path}')
                loader = loader_object(
                    file_name=os.path.basename(each_file_path), file_path=each_file_path, **loader_params
                )
                documents = loader.load()
                if not documents:
                    # Nothing to index; skipping also keeps drop_old for the next real file.
                    logger.error(f'{each_file_path} produced no documents.')
                    continue

                logger.info(f'documents: {len(documents)}, page_content: {len(documents[0].page_content)}')
                if len(documents[0].page_content) == 0:
                    logger.error(f'{each_file_path} page_content is empty.')

                # Optionally attach an LLM-extracted title to every document.
                if add_aux_info:
                    for doc in documents:
                        try:
                            title = extract_title(llm=self.llm, text=doc.page_content)
                            logger.info(f'extract title: {title}')
                        except Exception as e:
                            logger.error(f"Failed to extract title: {e}")
                            title = ''
                        doc.metadata['title'] = title

                vector_drop_old = self.params['milvus']['drop_old'] if is_first_indexed else False
                keyword_drop_old = self.params['elasticsearch']['drop_old'] if is_first_indexed else False
                for retriever in self.retriever.retrievers:
                    # KeywordRetriever honours the elasticsearch.drop_old setting; all other
                    # retrievers keep the Milvus flag.
                    # NOTE(review): assumes KeywordRetriever writes only to the keyword
                    # store, and MixRetriever still gets the Milvus flag -- confirm
                    # against init_retrievers.
                    drop_old = keyword_drop_old if isinstance(retriever, KeywordRetriever) else vector_drop_old
                    retriever.add_documents(documents, collection_name, drop_old, add_aux_info=add_aux_info)
                is_first_indexed = False

    def retrieval_and_rerank(self, question, collection_name, max_content=100000):
        """Retrieve context chunks for ``question`` and post-process them.

        Steps, each controlled by ``post_retrieval`` in the config: optional
        removal of chunks with identical ``page_content`` (first occurrence kept),
        optional reranking, truncation to the longest prefix whose total length
        fits ``max_content`` characters, and optional sorting by ``source`` and
        ``chunk_index`` metadata.

        Args:
            question: the user question.
            collection_name: knowledge-base name; the retriever suffix is appended.
            max_content: maximum total characters of returned chunks.

        Returns:
            The list of selected documents.
        """
        collection_name = f"{collection_name}_{self.params['retriever']['suffix']}"
        # EnsembleRetriever already de-duplicates its own results during retrieval.
        docs = self.retriever.get_relevant_documents(query=question, collection_name=collection_name)
        logger.info(f'retrieval docs origin: {len(docs)}')

        post_retrieval = self.params['post_retrieval']

        # Remove chunks with identical text, keeping the first occurrence (O(n) via a set).
        if post_retrieval['delete_duplicate']:
            logger.info(f'origin docs: {len(docs)}')
            seen_contents = set()
            docs_no_dup = []
            for doc in docs:
                if doc.page_content in seen_contents:
                    continue
                seen_contents.add(doc.page_content)
                docs_no_dup.append(doc)
            docs = docs_no_dup
            logger.info(f'delete duplicate docs: {len(docs)}')

        # Rerank; the ranker is created lazily once and cached on the instance.
        if post_retrieval['with_rank'] and docs:
            if not hasattr(self, 'ranker'):
                # Copy so popping 'type' leaves the config intact.
                rerank_params = dict(post_retrieval['rerank'])
                rerank_type = rerank_params.pop('type')
                rerank_object = import_class(f'bisheng_langchain.rag.rerank.{rerank_type}')
                self.ranker = rerank_object(**rerank_params)
            docs = self.ranker.sort_and_filter(question, docs)

        # Keep the longest prefix of docs whose total length stays within max_content.
        doc_num, doc_content_sum = 0, 0
        for doc in docs:
            doc_content_sum += len(doc.page_content)
            if doc_content_sum > max_content:
                break
            doc_num += 1
        docs = docs[:doc_num]
        logger.info(f'retrieval docs after delete redundancy: {len(docs)}')

        # Sort by source and chunk_index so the context reads coherently.
        if post_retrieval.get('sort_by_source_and_index', False):
            logger.info('sort chunks by source and chunk_index')
            docs = sorted(docs, key=lambda x: (x.metadata['source'], x.metadata['chunk_index']))
        return docs

    def load_documents(self, file_name, max_content=100000):
        """Load a single file directly, truncating each document to ``max_content`` characters.

        Args:
            file_name: file name relative to ``origin_file_path``.
            max_content: maximum characters kept per document (the LLM context budget).

        Returns:
            The loaded (possibly truncated) documents; may be empty.

        Raises:
            Exception: if the path does not exist or is a directory.
        """
        file_path = os.path.join(self.origin_file_path, file_name)
        if not os.path.exists(file_path):
            raise Exception(f'{file_path} not exists.')
        if os.path.isdir(file_path):
            raise Exception(f'{file_path} is a directory.')

        loader_params = copy.deepcopy(self.params['loader'])
        loader_object = import_by_type(_type='documentloaders', name=loader_params.pop('type'))
        loader = loader_object(file_name=file_name, file_path=file_path, **loader_params)

        documents = loader.load()
        first_len = len(documents[0].page_content) if documents else 0
        logger.info(f'documents: {len(documents)}, page_content: {first_len}')
        for doc in documents:
            doc.page_content = doc.page_content[:max_content]
        return documents

    def question_answering(self):
        """Answer every question in the question sheet and save the results.

        For each row of ``self.question_path``, the context comes from one of two sources:

        * retrieval from the row's ``知识库名`` collection, when ``generate.with_retrieval`` is set;
        * otherwise, a direct load of the row's ``文件名`` file, cached per file.

        ``ChatCohere`` models get the chunks concatenated into the prompt's
        ``documents`` slot and are invoked with ``raw_prompting=True``. Other
        models use a LangChain QA chain. Errors raised while generating an answer
        are logged, and their message is recorded as the answer. The sheet, with a
        new ``rag_answer`` column, is written to ``self.save_answer_path``.
        """
        df = pd.read_excel(self.question_path)
        all_questions_info = df.to_dict('records')
        generate_params = self.params['generate']
        if 'prompt_type' in generate_params:
            prompt_type = generate_params['prompt_type']
            prompt = import_class(f'bisheng_langchain.rag.prompts.{prompt_type}')
        else:
            prompt = None

        is_cohere = isinstance(self.llm, ChatCohere)
        if is_cohere:
            qa_chain = None
            if prompt is None:
                # The raw-prompting path formats `prompt` itself; without one every
                # question fails and the error text is recorded as its answer.
                logger.warning('ChatCohere raw prompting needs generate.prompt_type; no prompt configured.')
        else:
            qa_chain = load_qa_chain(
                llm=self.llm, chain_type=generate_params['chain_type'], prompt=prompt, verbose=False
            )
        max_content = generate_params.get('max_content', 100000)

        file2docs = dict()
        for questions_info in tqdm(all_questions_info):
            question = questions_info['问题']
            file_name = questions_info['文件名']
            collection_name = questions_info['知识库名']

            if generate_params['with_retrieval']:
                docs = self.retrieval_and_rerank(question, collection_name, max_content=max_content)
            else:
                # Load each file once and reuse it for later questions about it.
                if file_name not in file2docs:
                    file2docs[file_name] = self.load_documents(file_name, max_content=max_content)
                docs = file2docs[file_name]

            if is_cohere:
                try:
                    # Cohere RAG via raw prompting: stuff the chunk texts into the prompt.
                    documents = ''.join(doc.page_content + '\n\n' for doc in docs)
                    messages = prompt.format_prompt(question=question, documents=documents).to_messages()
                    ans = self.llm.invoke(messages, raw_prompting=True)
                    rag_answer = ans.content.replace('Answer: ', '')
                except Exception as e:
                    logger.error(f'question: {question}\nerror: {e}')
                    rag_answer = str(e)
            else:
                try:
                    ans = qa_chain({"input_documents": docs, "question": question}, return_only_outputs=True)
                except Exception as e:
                    logger.error(f'question: {question}\nerror: {e}')
                    ans = {'output_text': str(e)}
                rag_answer = ans['output_text']

            logger.info(f'question: {question}\nans: {rag_answer}\n')
            questions_info['rag_answer'] = rag_answer

        df = pd.DataFrame(all_questions_info)
        df.to_excel(self.save_answer_path, index=False)

    def score(self):
        """Score the saved answers according to the ``metric`` config section.

        Only ``type: bisheng-ragas`` is implemented. It runs ``RagScore`` on
        ``self.save_answer_path`` and passes that file's directory as
        ``save_path``. Any other metric type is currently a no-op.
        """
        metric_params = self.params['metric']
        if metric_params['type'] != 'bisheng-ragas':
            # TODO: other scoring methods.
            return

        score_params = {
            'excel_path': self.save_answer_path,
            'save_path': os.path.dirname(self.save_answer_path),
            'question_column': metric_params['question_column'],
            'gt_column': metric_params['gt_column'],
            'answer_column': metric_params['answer_column'],
            'query_type_column': metric_params.get('query_type_column', None),
            'contexts_column': metric_params.get('contexts_column', None),
            'metrics': metric_params['metrics'],
            'batch_size': metric_params['batch_size'],
            'gt_split_column': metric_params.get('gt_split_column', None),
            # Whether the model should split the ground truth into key points.
            'whether_gtsplit': metric_params.get('whether_gtsplit', False),
        }
        rag_score = RagScore(**score_params)
        rag_score.score()
|
359
|
-
|
360
|
-
|
361
|
-
if __name__ == '__main__':
    # Command-line entry point: build the pipeline from a YAML config, then run one stage.
    parser = argparse.ArgumentParser(description='Run the Bisheng RAG evaluation pipeline.')
    # `choices` rejects unknown modes up front, before the pipeline connects to
    # Milvus/Elasticsearch; previously an unknown mode silently did nothing.
    parser.add_argument(
        '--mode', type=str, default='qa', choices=['upload', 'qa', 'score'], help='upload or qa or score'
    )
    parser.add_argument('--params', type=str, default='config/test/baseline_s2b.yaml', help='bisheng rag params')
    args = parser.parse_args()

    rag = BishengRagPipeline(args.params)

    if args.mode == 'upload':
        rag.file2knowledge()
    elif args.mode == 'qa':
        rag.question_answering()
    elif args.mode == 'score':
        rag.score()
|
File without changes
|
File without changes
|