bisheng-langchain 0.3.7.1__py3-none-any.whl → 0.4.0.dev1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bisheng_langchain/document_loaders/custom_kv.py +2 -2
- bisheng_langchain/document_loaders/elem_pdf.py +3 -3
- bisheng_langchain/document_loaders/elem_unstrcutured_loader.py +2 -2
- bisheng_langchain/document_loaders/parsers/image.py +1 -1
- bisheng_langchain/document_loaders/universal_kv.py +2 -2
- bisheng_langchain/gpts/agent_types/llm_functions_agent.py +35 -5
- bisheng_langchain/gpts/agent_types/llm_react_agent.py +27 -39
- bisheng_langchain/gpts/assistant.py +24 -24
- bisheng_langchain/gpts/tools/api_tools/base.py +1 -1
- bisheng_langchain/gpts/tools/api_tools/openapi.py +59 -32
- bisheng_langchain/rag/bisheng_rag_chain.py +26 -32
- bisheng_langchain/rag/bisheng_rag_tool.py +98 -98
- bisheng_langchain/rag/extract_info.py +0 -2
- bisheng_langchain/rag/init_retrievers/baseline_vector_retriever.py +8 -12
- bisheng_langchain/rag/init_retrievers/keyword_retriever.py +8 -16
- bisheng_langchain/rag/init_retrievers/mix_retriever.py +16 -17
- bisheng_langchain/rag/init_retrievers/smaller_chunks_retriever.py +8 -8
- bisheng_langchain/sql/base.py +1 -1
- bisheng_langchain/vectorstores/elastic_keywords_search.py +17 -2
- bisheng_langchain/vectorstores/milvus.py +76 -69
- {bisheng_langchain-0.3.7.1.dist-info → bisheng_langchain-0.4.0.dev1.dist-info}/METADATA +6 -6
- {bisheng_langchain-0.3.7.1.dist-info → bisheng_langchain-0.4.0.dev1.dist-info}/RECORD +24 -24
- {bisheng_langchain-0.3.7.1.dist-info → bisheng_langchain-0.4.0.dev1.dist-info}/WHEEL +0 -0
- {bisheng_langchain-0.3.7.1.dist-info → bisheng_langchain-0.4.0.dev1.dist-info}/top_level.txt +0 -0
bisheng_langchain/rag/bisheng_rag_chain.py

@@ -1,25 +1,20 @@
 """Chain for question-answering against a vector database."""
 from __future__ import annotations
 
-import inspect
-from abc import abstractmethod
 from typing import Any, Dict, List, Optional
 
-from langchain_core.callbacks import (
-    AsyncCallbackManagerForChainRun,
-    CallbackManagerForChainRun,
-    Callbacks,
-)
-from langchain_core.prompts import PromptTemplate, BasePromptTemplate, ChatPromptTemplate, HumanMessagePromptTemplate, SystemMessagePromptTemplate
+from bisheng_langchain.vectorstores import ElasticKeywordsSearch, Milvus
+from langchain.chains.base import Chain
+from langchain_core.callbacks import (AsyncCallbackManagerForChainRun, CallbackManagerForChainRun,
+                                      Callbacks)
 from langchain_core.language_models import BaseLanguageModel
+from langchain_core.prompts import (ChatPromptTemplate, HumanMessagePromptTemplate,
+                                    SystemMessagePromptTemplate)
 from langchain_core.pydantic_v1 import Extra, Field
-from bisheng_langchain.vectorstores import ElasticKeywordsSearch, Milvus
 
-from langchain.chains.base import Chain
 from .bisheng_rag_tool import BishengRAGTool
 
-
-# system_template = """Use the following pieces of context to answer the user's question.
+# system_template = """Use the following pieces of context to answer the user's question.
 # If you don't know the answer, just say that you don't know, don't try to make up an answer.
 # ----------------
 # {context}"""
@@ -29,7 +24,6 @@ from .bisheng_rag_tool import BishengRAGTool
 # ]
 # DEFAULT_QA_PROMPT = ChatPromptTemplate.from_messages(messages)
 
-
 system_template_general = """你是一个准确且可靠的知识库问答助手,能够借助上下文知识回答问题。你需要根据以下的规则来回答问题:
 1. 如果上下文中包含了正确答案,你需要根据上下文进行准确的回答。但是在回答前,你需要注意,上下文中的信息可能存在事实性错误,如果文档中存在和事实不一致的错误,请根据事实回答。
 2. 如果上下文中不包含答案,就说你不知道,不要试图编造答案。
@@ -51,15 +45,13 @@ DEFAULT_QA_PROMPT = ChatPromptTemplate.from_messages(messages_general)
 
 class BishengRetrievalQA(Chain):
     """Base class for question-answering chains."""
-
     """Chain to use to combine the documents."""
-    input_key: str = 'query'  #: :meta private:
-    output_key: str = 'result'  #: :meta private:
+    input_key: str = 'query'  #: :meta private:
+    output_key: str = 'result'  #: :meta private:
     return_source_documents: bool = False
     """Return the source documents or not."""
-    bisheng_rag_tool: BishengRAGTool = Field(
-        default_factory=BishengRAGTool, description='RAG tool'
-    )
+    bisheng_rag_tool: BishengRAGTool = Field(default_factory=BishengRAGTool,
+                                             description='RAG tool')
 
     class Config:
         """Configuration for this pydantic object."""
@@ -84,7 +76,7 @@ class BishengRetrievalQA(Chain):
         """
         _output_keys = [self.output_key]
         if self.return_source_documents:
-            _output_keys = _output_keys + ['source_documents']
+            _output_keys = _output_keys + ['source_documents']
         return _output_keys
 
     @classmethod
@@ -100,15 +92,13 @@ class BishengRetrievalQA(Chain):
         return_source_documents: bool = False,
         **kwargs: Any,
     ) -> BishengRetrievalQA:
-        bisheng_rag_tool = BishengRAGTool(
-            vector_store=vector_store,
-            keyword_store=keyword_store,
-            llm=llm,
-            QA_PROMPT=QA_PROMPT,
-            max_content=max_content,
-            sort_by_source_and_index=sort_by_source_and_index,
-            **kwargs
-        )
+        bisheng_rag_tool = BishengRAGTool(vector_store=vector_store,
+                                          keyword_store=keyword_store,
+                                          llm=llm,
+                                          QA_PROMPT=QA_PROMPT,
+                                          max_content=max_content,
+                                          sort_by_source_and_index=sort_by_source_and_index,
+                                          **kwargs)
         return cls(
             bisheng_rag_tool=bisheng_rag_tool,
             callbacks=callbacks,
@@ -134,8 +124,12 @@ class BishengRetrievalQA(Chain):
         """
         question = inputs[self.input_key]
         if self.return_source_documents:
-            answer, docs = self.bisheng_rag_tool.run(question, return_only_outputs=False)
-            return {self.output_key: answer, 'source_documents': docs}
+            answer, docs = self.bisheng_rag_tool.run(
+                question,
+                return_only_outputs=False,
+                run_manager=run_manager,
+            )
+            return {self.output_key: answer, 'source_documents': docs}
         else:
             answer = self.bisheng_rag_tool.run(question, return_only_outputs=True)
             return {self.output_key: answer}
@@ -160,7 +154,7 @@ class BishengRetrievalQA(Chain):
 
         if self.return_source_documents:
             answer, docs = await self.bisheng_rag_tool.arun(question, return_only_outputs=False)
-            return {self.output_key: answer, 'source_documents': docs}
+            return {self.output_key: answer, 'source_documents': docs}
         else:
             answer = await self.bisheng_rag_tool.arun(question, return_only_outputs=True)
             return {self.output_key: answer}
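Taken together, the bisheng_rag_chain.py changes reorganize the imports and, more importantly, make `_call` thread the chain's `run_manager` into `BishengRAGTool.run()` so callbacks propagate into the RAG tool. A minimal usage sketch; the factory classmethod's name falls outside the visible hunks, so `from_llm` (and the ChatOpenAI/OpenAIEmbeddings setup) is an assumption:

    from langchain_openai import ChatOpenAI, OpenAIEmbeddings  # assumed model provider
    from bisheng_langchain.vectorstores import ElasticKeywordsSearch, Milvus
    from bisheng_langchain.rag.bisheng_rag_chain import BishengRetrievalQA

    embeddings = OpenAIEmbeddings()
    vector_store = Milvus(collection_name='demo_collection',
                          embedding_function=embeddings,
                          connection_args={'host': '127.0.0.1', 'port': '19530'})
    keyword_store = ElasticKeywordsSearch(index_name='demo_collection',
                                          elasticsearch_url='http://127.0.0.1:9200')

    qa = BishengRetrievalQA.from_llm(  # classmethod name assumed, not shown in the hunks
        llm=ChatOpenAI(),
        vector_store=vector_store,
        keyword_store=keyword_store,
        return_source_documents=True)
    result = qa({'query': 'What does the report say about CSR work?'})
    answer, sources = result['result'], result['source_documents']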
bisheng_langchain/rag/bisheng_rag_tool.py

@@ -1,27 +1,25 @@
-import time
 import os
-import yaml
-import httpx
-from typing import Any, Dict, Tuple, Type, Union, Optional
+from typing import Any, Dict, Optional, Tuple, Union
 
-from loguru import logger
-from langchain_core.tools import BaseTool, Tool
-from langchain_core.pydantic_v1 import BaseModel, Field
-from langchain_core.language_models.base import LanguageModelLike
-from langchain_core.prompts import ChatPromptTemplate
-from langchain_core.vectorstores import VectorStoreRetriever
-from langchain.chains.llm import LLMChain
-from langchain.chains.question_answering import load_qa_chain
+import httpx
+import yaml
+from bisheng_langchain.rag.extract_info import extract_title
+from bisheng_langchain.rag.init_retrievers import (BaselineVectorRetriever, KeywordRetriever,
+                                                   MixRetriever, SmallerChunksVectorRetriever)
+from bisheng_langchain.rag.utils import import_by_type, import_class
 from bisheng_langchain.retrievers import EnsembleRetriever
 from bisheng_langchain.vectorstores import ElasticKeywordsSearch, Milvus
-from bisheng_langchain.rag.init_retrievers import (
-    BaselineVectorRetriever,
-    KeywordRetriever,
-    MixRetriever,
-    SmallerChunksVectorRetriever,
-)
-from bisheng_langchain.rag.extract_info import extract_title
-from bisheng_langchain.rag.utils import import_by_type, import_class
+from langchain.chains.llm import LLMChain
+from langchain.chains.question_answering import load_qa_chain
+from langchain.chains.combine_documents import create_stuff_documents_chain
+from langchain_core.callbacks import CallbackManagerForChainRun
+from langchain_core.language_models.base import LanguageModelLike
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.pydantic_v1 import BaseModel, Field
+from langchain_core.runnables import RunnableConfig
+from langchain_core.tools import BaseTool, Tool
+from langchain_core.vectorstores import VectorStoreRetriever
+from loguru import logger
 
 
 class MultArgsSchemaTool(Tool):
@@ -37,26 +35,27 @@ class MultArgsSchemaTool(Tool):
 
 class BishengRAGTool:
 
-    def __init__(
-        self,
-        vector_store: Optional[Milvus] = None,
-        keyword_store: Optional[ElasticKeywordsSearch] = None,
-        llm: Optional[LanguageModelLike] = None,
-        collection_name: Optional[str] = None,
-        QA_PROMPT: Optional[ChatPromptTemplate] = None,
-        **kwargs
-    ) -> None:
+    def __init__(self,
+                 vector_store: Optional[Milvus] = None,
+                 keyword_store: Optional[ElasticKeywordsSearch] = None,
+                 llm: Optional[LanguageModelLike] = None,
+                 collection_name: Optional[str] = None,
+                 QA_PROMPT: Optional[ChatPromptTemplate] = None,
+                 **kwargs) -> None:
         if collection_name is None and (keyword_store is None or vector_store is None):
-            raise ValueError('collection_name must be provided if keyword_store or vector_store is not provided')
+            raise ValueError(
+                'collection_name must be provided if keyword_store or vector_store is not provided'
+            )
         self.collection_name = collection_name
-
-        yaml_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'config/baseline_v2.yaml')
+
+        yaml_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
+                                 'config/baseline_v2.yaml')
         with open(yaml_path, 'r', encoding='utf-8') as f:
             self.params = yaml.safe_load(f)
-
+
         # update params
-        max_content = kwargs.get('max_content', 15000)
-        sort_by_source_and_index = kwargs.get('sort_by_source_and_index', True)
+        max_content = kwargs.get('max_content', 15000)
+        sort_by_source_and_index = kwargs.get('sort_by_source_and_index', True)
         self.params['generate']['max_content'] = max_content
         self.params['post_retrieval']['sort_by_source_and_index'] = sort_by_source_and_index
 
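The reworked `__init__` keeps every dependency optional but enforces, via the `ValueError` above, that callers pass either a `collection_name` (everything else is then built from `config/baseline_v2.yaml`) or both stores. A sketch of the two construction styles, with store and LLM objects as in the `__main__` block further down:

    from bisheng_langchain.rag.bisheng_rag_tool import BishengRAGTool

    # 1) Config-driven: stores, embeddings and LLM come from config/baseline_v2.yaml.
    tool = BishengRAGTool(collection_name='rag_finance_report_0_test')

    # 2) Dependency injection: collection_name may be omitted when both stores exist.
    tool = BishengRAGTool(vector_store=vector_store,    # a Milvus instance
                          keyword_store=keyword_store,  # an ElasticKeywordsSearch instance
                          llm=llm,                      # any LanguageModelLike
                          max_content=15000,            # kwargs still override YAML params
                          sort_by_source_and_index=True)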
@@ -68,11 +67,12 @@ class BishengRAGTool:
             llm_object = import_by_type(_type='llms', name=llm_params['type'])
             if llm_params['type'] == 'ChatOpenAI' and llm_params['openai_proxy']:
                 llm_params.pop('type')
-                self.llm = llm_object(http_client=httpx.Client(proxies=llm_params['openai_proxy']), **llm_params)
+                self.llm = llm_object(http_client=httpx.Client(proxies=llm_params['openai_proxy']),
+                                      **llm_params)
             else:
                 llm_params.pop('type')
                 self.llm = llm_object(**llm_params)
-
+
         # init milvus
         if vector_store:
             # if vector_store is retriever, get vector_store instance
@@ -87,26 +87,27 @@
             if embedding_params['type'] == 'OpenAIEmbeddings' and embedding_params['openai_proxy']:
                 embedding_params.pop('type')
                 self.embeddings = embedding_object(
-                    http_client=httpx.Client(proxies=embedding_params['openai_proxy']),
-                    **embedding_params)
+                    http_client=httpx.Client(proxies=embedding_params['openai_proxy']),
+                    **embedding_params)
             else:
                 embedding_params.pop('type')
                 self.embeddings = embedding_object(**embedding_params)
-
+
             self.vector_store = Milvus(
                 embedding_function=self.embeddings,
                 connection_args={
-                    'host': self.params['milvus']['host'],
-                    'port': self.params['milvus']['port'],
+                    'host': self.params['milvus']['host'],
+                    'port': self.params['milvus']['port'],
                 },
             )
-
+
         # init keyword store
         if keyword_store:
             self.keyword_store = keyword_store
         else:
             if self.params['elasticsearch'].get('extract_key_by_llm', False):
-                extract_key_prompt = import_class('bisheng_langchain.rag.prompts.EXTRACT_KEY_PROMPT')
+                extract_key_prompt = import_class(
+                    'bisheng_langchain.rag.prompts.EXTRACT_KEY_PROMPT')
                 llm_chain = LLMChain(llm=self.llm, prompt=extract_key_prompt)
             else:
                 llm_chain = None
@@ -128,10 +129,11 @@
                 'splitter_kwargs': retriever['splitter'],
                 'retrieval_kwargs': retriever['retrieval'],
             }
-            retriever_list.append(self._post_init_retriever(retriever_type=retriever_type, **retriever_params))
+            retriever_list.append(
+                self._post_init_retriever(retriever_type=retriever_type, **retriever_params))
         self.retriever = EnsembleRetriever(retrievers=retriever_list)
 
-        # init qa chain
+        # init qa chain
         if QA_PROMPT:
             prompt = QA_PROMPT
         else:
@@ -140,13 +142,8 @@
             prompt = import_class(f'bisheng_langchain.rag.prompts.{prompt_type}')
         else:
             prompt = None
-        self.qa_chain = load_qa_chain(
-            llm=self.llm,
-            chain_type=self.params['generate']['chain_type'],
-            prompt=prompt,
-            verbose=False
-        )
-
+        self.qa_chain = create_stuff_documents_chain(llm=self.llm, prompt=prompt)
+
     def _post_init_retriever(self, retriever_type, **kwargs):
         retriever_classes = {
            'KeywordRetriever': KeywordRetriever,
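This swap from the legacy `load_qa_chain` to LangChain's `create_stuff_documents_chain` is the behavioral core of the file's changes: the new chain is an LCEL runnable whose `.invoke` returns the answer string directly, where the old chain returned a dict keyed by `output_text` (which is why `run()` further down no longer unpacks `ans['output_text']`). A minimal standalone sketch; the model choice is illustrative:

    from langchain.chains.combine_documents import create_stuff_documents_chain
    from langchain_core.documents import Document
    from langchain_core.prompts import ChatPromptTemplate
    from langchain_openai import ChatOpenAI  # assumed model provider

    prompt = ChatPromptTemplate.from_messages([
        ('system', 'Answer using this context:\n{context}'),
        ('human', '{question}'),
    ])
    qa_chain = create_stuff_documents_chain(llm=ChatOpenAI(), prompt=prompt)

    docs = [Document(page_content='Paris is the capital of France.')]
    answer = qa_chain.invoke({'context': docs, 'question': 'Capital of France?'})
    # `answer` is a plain string, not {'output_text': ...}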
@@ -181,9 +178,9 @@
         loader_object = import_by_type(_type='documentloaders', name=loader_params.pop('type'))
 
         logger.info(f'file_path: {file_path}')
-        loader = loader_object(file_name=os.path.basename(file_path),
-                               file_path=file_path,
-                               **loader_params)
+        loader = loader_object(file_name=os.path.basename(file_path),
+                               file_path=file_path,
+                               **loader_params)
         documents = loader.load()
         logger.info(f'documents: {len(documents)}, page_content: {len(documents[0].page_content)}')
         if len(documents[0].page_content) == 0:
@@ -197,30 +194,26 @@
                 title = extract_title(llm=self.llm, text=doc.page_content)
                 logger.info(f'extract title: {title}')
             except Exception as e:
-                logger.error(f'Failed to extract title: {e}')
+                logger.error(f'Failed to extract title: {e}')
                 title = ''
             doc.metadata['title'] = title
 
         for idx, retriever in enumerate(self.retriever.retrievers):
-            retriever.add_documents(
-                documents,
-                self.collection_name,
-                drop_old=drop_old,
-                add_aux_info=add_aux_info
-            )
-
+            retriever.add_documents(documents,
+                                    self.collection_name,
+                                    drop_old=drop_old,
+                                    add_aux_info=add_aux_info)
+
     def retrieval_and_rerank(self, query):
         """
         retrieval and rerank
         """
         # EnsembleRetriever直接检索召回会默认去重
-        docs = self.retriever.get_relevant_documents(
-            query=query,
-            collection_name=self.collection_name
-        )
+        docs = self.retriever.get_relevant_documents(query=query,
+                                                     collection_name=self.collection_name)
         logger.info(f'retrieval docs origin: {len(docs)}')
 
-        # delete redundancy according to max_content
+        # delete redundancy according to max_content
         doc_num, doc_content_sum = 0, 0
         for doc in docs:
             doc_content_sum += len(doc.page_content)
@@ -235,28 +228,37 @@
             logger.info('sort chunks by source and chunk_index')
             docs = sorted(docs, key=lambda x: (x.metadata['source'], x.metadata['chunk_index']))
         return docs
-
-    def run(self, query, return_only_outputs=True) -> Any:
+
+    def run(self,
+            query,
+            return_only_outputs=True,
+            run_manager: Optional[CallbackManagerForChainRun] = None) -> Any:
         docs = self.retrieval_and_rerank(query)
         try:
-            ans = self.qa_chain({'input_documents': docs, 'question': query}, return_only_outputs=False)
+            kwargs = {}
+            if run_manager:
+                kwargs['config'] = RunnableConfig(callbacks=[run_manager])
+            ans = self.qa_chain.invoke(
+                {
+                    'context': docs,
+                    'question': query
+                }, **kwargs
+            )
         except Exception as e:
-            logger.error(f'question: {query}\nerror: {e}')
-            ans = {'output_text': str(e)}
+            logger.exception(f'question: {query}\nerror: {e}')
+            ans = str(e)
         if return_only_outputs:
-            rag_answer = ans['output_text']
-            return rag_answer
+            return ans
         else:
-            rag_answer = ans['output_text']
-            input_documents = ans['input_documents']
-            return rag_answer, input_documents
-
+            return ans, docs
+
     async def arun(self, query: str, return_only_outputs=True) -> str:
         rag_answer = self.run(query, return_only_outputs)
         return rag_answer
-
+
     @classmethod
     def get_rag_tool(cls, name, description, **kwargs: Any) -> BaseTool:
+
         class InputArgs(BaseModel):
             query: str = Field(description='question asked by the user.')
 
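`run()` now accepts an optional `CallbackManagerForChainRun` and, when one is present, hands it to the LCEL chain through a `RunnableConfig`; failures are logged with `logger.exception` and surfaced as the answer string. The callback-threading pattern in isolation (a sketch of the pattern, not bisheng API):

    from typing import Optional
    from langchain_core.callbacks import CallbackManagerForChainRun
    from langchain_core.runnables import Runnable, RunnableConfig

    def invoke_with_manager(chain: Runnable, inputs: dict,
                            run_manager: Optional[CallbackManagerForChainRun] = None):
        # Mirrors the hunk above: attach a config only when a manager was passed in.
        kwargs = {}
        if run_manager:
            kwargs['config'] = RunnableConfig(callbacks=[run_manager])
        return chain.invoke(inputs, **kwargs)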
@@ -265,7 +267,7 @@ class BishengRAGTool:
             func=cls(**kwargs).run,
             coroutine=cls(**kwargs).arun,
             args_schema=InputArgs)
-
+
 
 if __name__ == '__main__':
     # rag_tool = BishengRAGTool(collection_name='rag_finance_report_0_test')
@@ -280,32 +282,30 @@ if __name__ == '__main__':
     collection_name = 'rag_finance_report_0_benchmark_caibao_1000_source_title'
     # milvus
     vector_store = Milvus(
-        collection_name=collection_name,
-        embedding_function=embeddings,
-        connection_args={
-            'host': '110.16.193.170',
-            'port': '50062',
-        },
+        collection_name=collection_name,
+        embedding_function=embeddings,
+        connection_args={
+            'host': '110.16.193.170',
+            'port': '50062',
+        },
     )
     # es
     keyword_store = ElasticKeywordsSearch(
         index_name=collection_name,
         elasticsearch_url='http://110.16.193.170:50062/es',
-        ssl_verify={'basic_auth': ['elastic', 'oSGL-zVvZ5P3Tm7qkDLC']},
+        ssl_verify={'basic_auth': ['elastic', 'oSGL-zVvZ5P3Tm7qkDLC']},
     )
 
-    tool = BishengRAGTool.get_rag_tool(
-        name='rag_knowledge_retrieve',
-        description='金融年报财报知识库问答',
-        vector_store=vector_store,
-        keyword_store=keyword_store,
-        llm=llm
-    )
+    tool = BishengRAGTool.get_rag_tool(name='rag_knowledge_retrieve',
+                                       description='金融年报财报知识库问答',
+                                       vector_store=vector_store,
+                                       keyword_store=keyword_store,
+                                       llm=llm)
     print(tool.run('能否根据2020年金宇生物技术股份有限公司的年报,给我简要介绍一下报告期内公司的社会责任工作情况?'))
 
     # tool = BishengRAGTool.get_rag_tool(
-    #     name='rag_knowledge_retrieve',
+    #     name='rag_knowledge_retrieve',
     #     description='金融年报财报知识库问答',
     #     collection_name='rag_finance_report_0_benchmark_caibao_1000_source_title'
     # )
-    # print(tool.run('能否根据2020年金宇生物技术股份有限公司的年报,给我简要介绍一下报告期内公司的社会责任工作情况?'))
+    # print(tool.run('能否根据2020年金宇生物技术股份有限公司的年报,给我简要介绍一下报告期内公司的社会责任工作情况?'))
bisheng_langchain/rag/init_retrievers/baseline_vector_retriever.py

@@ -1,20 +1,15 @@
-import …
-import uuid
-from loguru import logger
-from typing import Any, Dict, Iterable, List, Optional
+from typing import Any, List, Optional
 
-from bisheng_langchain.vectorstores.milvus import Milvus
+from langchain.text_splitter import TextSplitter
 from langchain_core.documents import Document
 from langchain_core.pydantic_v1 import Field
 from langchain_core.retrievers import BaseRetriever
-from langchain.schema import BaseRetriever, Document
-
-from langchain.callbacks.manager import CallbackManagerForRetrieverRun
-from langchain.text_splitter import TextSplitter
+from loguru import logger
 
 
 class BaselineVectorRetriever(BaseRetriever):
-    vector_store: Milvus
+
+    vector_store: Any
     text_splitter: TextSplitter
     search_type: str = 'similarity'
     search_kwargs: dict = Field(default_factory=dict)
@@ -27,13 +22,14 @@ class BaselineVectorRetriever(BaseRetriever):
         **kwargs,
     ) -> None:
         split_docs = self.text_splitter.split_documents(documents)
-        logger.info(f'BaselineVectorRetriever: split document into {len(split_docs)} chunks')
+        logger.info(f'BaselineVectorRetriever: split document into {len(split_docs)} chunks')
         for chunk_index, split_doc in enumerate(split_docs):
             if 'chunk_bboxes' in split_doc.metadata:
                 split_doc.metadata.pop('chunk_bboxes')
             split_doc.metadata['chunk_index'] = chunk_index
             if kwargs.get('add_aux_info', False):
-                split_doc.page_content = split_doc.metadata['source'] + '\n' + split_doc.metadata['title'] + '\n' + split_doc.page_content
+                split_doc.page_content = split_doc.metadata['source'] + '\n' + split_doc.metadata[
+                    'title'] + '\n' + split_doc.page_content
 
         connection_args = self.vector_store.connection_args
         embedding_function = self.vector_store.embedding_func
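The one-line `add_aux_info` rewrap above (repeated in keyword_retriever.py and mix_retriever.py below) is formatting only; the behavior is unchanged: each chunk is prefixed with its `source` and extracted `title` before indexing, so both fields participate in vector and keyword matching. In isolation:

    from langchain_core.documents import Document

    split_doc = Document(page_content='Net profit grew 12% year over year.',
                         metadata={'source': 'annual_report_2020.pdf',
                                   'title': 'Financial Highlights'})
    # Equivalent to the add_aux_info branch:
    split_doc.page_content = (split_doc.metadata['source'] + '\n' +
                              split_doc.metadata['title'] + '\n' + split_doc.page_content)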
bisheng_langchain/rag/init_retrievers/keyword_retriever.py

@@ -1,22 +1,14 @@
-import …
-import uuid
-from loguru import logger
-from dataclasses import dataclass
-from typing import Any, Dict, Iterable, List, Optional
+from typing import Any, List, Optional
 
-from bisheng_langchain.vectorstores import ElasticKeywordsSearch
-from bisheng_langchain.vectorstores.milvus import Milvus
+from langchain.text_splitter import TextSplitter
 from langchain_core.documents import Document
 from langchain_core.pydantic_v1 import Field
 from langchain_core.retrievers import BaseRetriever
-from langchain.schema import BaseRetriever, Document
-
-from langchain.callbacks.manager import CallbackManagerForRetrieverRun
-from langchain.text_splitter import TextSplitter
+from loguru import logger
 
 
 class KeywordRetriever(BaseRetriever):
-    keyword_store: ElasticKeywordsSearch
+    keyword_store: Any
     text_splitter: TextSplitter
     search_type: str = 'similarity'
     search_kwargs: dict = Field(default_factory=dict)
@@ -29,13 +21,14 @@ class KeywordRetriever(BaseRetriever):
         **kwargs,
     ) -> None:
         split_docs = self.text_splitter.split_documents(documents)
-        logger.info(f'KeywordRetriever: split document into {len(split_docs)} chunks')
+        logger.info(f'KeywordRetriever: split document into {len(split_docs)} chunks')
         for chunk_index, split_doc in enumerate(split_docs):
             if 'chunk_bboxes' in split_doc.metadata:
                 split_doc.metadata.pop('chunk_bboxes')
             split_doc.metadata['chunk_index'] = chunk_index
             if kwargs.get('add_aux_info', False):
-                split_doc.page_content = split_doc.metadata['source'] + '\n' + split_doc.metadata['title'] + '\n' + split_doc.page_content
+                split_doc.page_content = split_doc.metadata['source'] + '\n' + split_doc.metadata[
+                    'title'] + '\n' + split_doc.page_content
 
         elasticsearch_url = self.keyword_store.elasticsearch_url
         ssl_verify = self.keyword_store.ssl_verify
@@ -58,8 +51,7 @@ class KeywordRetriever(BaseRetriever):
             index_name=collection_name,
             elasticsearch_url=self.keyword_store.elasticsearch_url,
             ssl_verify=self.keyword_store.ssl_verify,
-            llm_chain=self.keyword_store.llm_chain
-        )
+            llm_chain=self.keyword_store.llm_chain)
         if self.search_type == 'similarity':
             result = self.keyword_store.similarity_search(query, **self.search_kwargs)
             return result
bisheng_langchain/rag/init_retrievers/mix_retriever.py

@@ -1,17 +1,14 @@
-from typing import Any, Dict, Iterable, List, Optional
+from typing import Any, List, Optional
 
 from bisheng_langchain.vectorstores import ElasticKeywordsSearch
-from bisheng_langchain.vectorstores.milvus import Milvus
+from langchain.text_splitter import TextSplitter
 from langchain_core.documents import Document
 from langchain_core.pydantic_v1 import Field
 from langchain_core.retrievers import BaseRetriever
 
-from langchain.schema import BaseRetriever, Document
-from langchain.text_splitter import TextSplitter
-
 
 class MixRetriever(BaseRetriever):
-    vector_store: Milvus
+    vector_store: Any
     keyword_store: ElasticKeywordsSearch
     vector_text_splitter: TextSplitter
     keyword_text_splitter: TextSplitter
@@ -34,14 +31,16 @@ class MixRetriever(BaseRetriever):
                 split_doc.metadata.pop('chunk_bboxes')
             split_doc.metadata['chunk_index'] = chunk_index
             if kwargs.get('add_aux_info', False):
-                split_doc.page_content = split_doc.metadata['source'] + '\n' + split_doc.metadata['title'] + '\n' + split_doc.page_content
+                split_doc.page_content = split_doc.metadata['source'] + '\n' + split_doc.metadata[
+                    'title'] + '\n' + split_doc.page_content
         keyword_split_docs = self.keyword_text_splitter.split_documents(documents)
         for chunk_index, split_doc in enumerate(keyword_split_docs):
             if 'chunk_bboxes' in split_doc.metadata:
                 split_doc.metadata.pop('chunk_bboxes')
             split_doc.metadata['chunk_index'] = chunk_index
             if kwargs.get('add_aux_info', False):
-                split_doc.page_content = split_doc.metadata['source'] + '\n' + split_doc.metadata['title'] + '\n' + split_doc.page_content
+                split_doc.page_content = split_doc.metadata['source'] + '\n' + split_doc.metadata[
+                    'title'] + '\n' + split_doc.page_content
 
         self.keyword_store.from_documents(
             keyword_split_docs,
@@ -70,15 +69,15 @@ class MixRetriever(BaseRetriever):
             index_name=collection_name,
             elasticsearch_url=self.keyword_store.elasticsearch_url,
             ssl_verify=self.keyword_store.ssl_verify,
-            llm_chain=self.keyword_store.llm_chain
-        )
+            llm_chain=self.keyword_store.llm_chain)
         self.vector_store = self.vector_store.__class__(
             collection_name=collection_name,
             embedding_function=self.vector_store.embedding_func,
             connection_args=self.vector_store.connection_args,
         )
         if self.search_type == 'similarity':
-            keyword_docs = self.keyword_store.similarity_search(query, **self.keyword_search_kwargs)
+            keyword_docs = self.keyword_store.similarity_search(query,
+                                                                **self.keyword_search_kwargs)
             vector_docs = self.vector_store.similarity_search(query, **self.vector_search_kwargs)
             if self.combine_strategy == 'keyword_front':
                 return keyword_docs + vector_docs
@@ -94,10 +93,10 @@ class MixRetriever(BaseRetriever):
                 combine_docs.extend(vector_docs[min_len:])
                 return combine_docs
             else:
-                raise ValueError(
-                    f'Expected combine_strategy to be one of '
-                    f'(keyword_front, vector_front, mix),'
-                    f'instead found {self.combine_strategy}'
-                )
+                raise ValueError(f'Expected combine_strategy to be one of '
+                                 f'(keyword_front, vector_front, mix),'
+                                 f'instead found {self.combine_strategy}')
         else:
-            raise ValueError(f'Expected search_type to be one of (similarity), instead found {self.search_type}')
+            raise ValueError(
+                f'Expected search_type to be one of (similarity), instead found {self.search_type}'
+            )