bisheng-langchain 0.3.7.1__py3-none-any.whl → 0.4.0.dev1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (24)
  1. bisheng_langchain/document_loaders/custom_kv.py +2 -2
  2. bisheng_langchain/document_loaders/elem_pdf.py +3 -3
  3. bisheng_langchain/document_loaders/elem_unstrcutured_loader.py +2 -2
  4. bisheng_langchain/document_loaders/parsers/image.py +1 -1
  5. bisheng_langchain/document_loaders/universal_kv.py +2 -2
  6. bisheng_langchain/gpts/agent_types/llm_functions_agent.py +35 -5
  7. bisheng_langchain/gpts/agent_types/llm_react_agent.py +27 -39
  8. bisheng_langchain/gpts/assistant.py +24 -24
  9. bisheng_langchain/gpts/tools/api_tools/base.py +1 -1
  10. bisheng_langchain/gpts/tools/api_tools/openapi.py +59 -32
  11. bisheng_langchain/rag/bisheng_rag_chain.py +26 -32
  12. bisheng_langchain/rag/bisheng_rag_tool.py +98 -98
  13. bisheng_langchain/rag/extract_info.py +0 -2
  14. bisheng_langchain/rag/init_retrievers/baseline_vector_retriever.py +8 -12
  15. bisheng_langchain/rag/init_retrievers/keyword_retriever.py +8 -16
  16. bisheng_langchain/rag/init_retrievers/mix_retriever.py +16 -17
  17. bisheng_langchain/rag/init_retrievers/smaller_chunks_retriever.py +8 -8
  18. bisheng_langchain/sql/base.py +1 -1
  19. bisheng_langchain/vectorstores/elastic_keywords_search.py +17 -2
  20. bisheng_langchain/vectorstores/milvus.py +76 -69
  21. {bisheng_langchain-0.3.7.1.dist-info → bisheng_langchain-0.4.0.dev1.dist-info}/METADATA +6 -6
  22. {bisheng_langchain-0.3.7.1.dist-info → bisheng_langchain-0.4.0.dev1.dist-info}/RECORD +24 -24
  23. {bisheng_langchain-0.3.7.1.dist-info → bisheng_langchain-0.4.0.dev1.dist-info}/WHEEL +0 -0
  24. {bisheng_langchain-0.3.7.1.dist-info → bisheng_langchain-0.4.0.dev1.dist-info}/top_level.txt +0 -0
@@ -1,25 +1,20 @@
1
1
  """Chain for question-answering against a vector database."""
2
2
  from __future__ import annotations
3
3
 
4
- import inspect
5
- from abc import abstractmethod
6
4
  from typing import Any, Dict, List, Optional
7
5
 
8
- from langchain_core.callbacks import (
9
- AsyncCallbackManagerForChainRun,
10
- CallbackManagerForChainRun,
11
- Callbacks
12
- )
13
- from langchain_core.prompts import PromptTemplate, BasePromptTemplate, ChatPromptTemplate, HumanMessagePromptTemplate, SystemMessagePromptTemplate
6
+ from bisheng_langchain.vectorstores import ElasticKeywordsSearch, Milvus
7
+ from langchain.chains.base import Chain
8
+ from langchain_core.callbacks import (AsyncCallbackManagerForChainRun, CallbackManagerForChainRun,
9
+ Callbacks)
14
10
  from langchain_core.language_models import BaseLanguageModel
11
+ from langchain_core.prompts import (ChatPromptTemplate, HumanMessagePromptTemplate,
12
+ SystemMessagePromptTemplate)
15
13
  from langchain_core.pydantic_v1 import Extra, Field
16
- from bisheng_langchain.vectorstores import ElasticKeywordsSearch, Milvus
17
14
 
18
- from langchain.chains.base import Chain
19
15
  from .bisheng_rag_tool import BishengRAGTool
20
16
 
21
-
22
- # system_template = """Use the following pieces of context to answer the user's question.
17
+ # system_template = """Use the following pieces of context to answer the user's question.
23
18
  # If you don't know the answer, just say that you don't know, don't try to make up an answer.
24
19
  # ----------------
25
20
  # {context}"""
@@ -29,7 +24,6 @@ from .bisheng_rag_tool import BishengRAGTool
29
24
  # ]
30
25
  # DEFAULT_QA_PROMPT = ChatPromptTemplate.from_messages(messages)
31
26
 
32
-
33
27
  system_template_general = """你是一个准确且可靠的知识库问答助手,能够借助上下文知识回答问题。你需要根据以下的规则来回答问题:
34
28
  1. 如果上下文中包含了正确答案,你需要根据上下文进行准确的回答。但是在回答前,你需要注意,上下文中的信息可能存在事实性错误,如果文档中存在和事实不一致的错误,请根据事实回答。
35
29
  2. 如果上下文中不包含答案,就说你不知道,不要试图编造答案。
@@ -51,15 +45,13 @@ DEFAULT_QA_PROMPT = ChatPromptTemplate.from_messages(messages_general)
51
45
 
52
46
  class BishengRetrievalQA(Chain):
53
47
  """Base class for question-answering chains."""
54
-
55
48
  """Chain to use to combine the documents."""
56
- input_key: str = "query" #: :meta private:
57
- output_key: str = "result" #: :meta private:
49
+ input_key: str = 'query' #: :meta private:
50
+ output_key: str = 'result' #: :meta private:
58
51
  return_source_documents: bool = False
59
52
  """Return the source documents or not."""
60
- bisheng_rag_tool: BishengRAGTool = Field(
61
- default_factory=BishengRAGTool, description="RAG tool"
62
- )
53
+ bisheng_rag_tool: BishengRAGTool = Field(default_factory=BishengRAGTool,
54
+ description='RAG tool')
63
55
 
64
56
  class Config:
65
57
  """Configuration for this pydantic object."""
@@ -84,7 +76,7 @@ class BishengRetrievalQA(Chain):
84
76
  """
85
77
  _output_keys = [self.output_key]
86
78
  if self.return_source_documents:
87
- _output_keys = _output_keys + ["source_documents"]
79
+ _output_keys = _output_keys + ['source_documents']
88
80
  return _output_keys
89
81
 
90
82
  @classmethod
@@ -100,15 +92,13 @@ class BishengRetrievalQA(Chain):
100
92
  return_source_documents: bool = False,
101
93
  **kwargs: Any,
102
94
  ) -> BishengRetrievalQA:
103
- bisheng_rag_tool = BishengRAGTool(
104
- vector_store=vector_store,
105
- keyword_store=keyword_store,
106
- llm=llm,
107
- QA_PROMPT=QA_PROMPT,
108
- max_content=max_content,
109
- sort_by_source_and_index=sort_by_source_and_index,
110
- **kwargs
111
- )
95
+ bisheng_rag_tool = BishengRAGTool(vector_store=vector_store,
96
+ keyword_store=keyword_store,
97
+ llm=llm,
98
+ QA_PROMPT=QA_PROMPT,
99
+ max_content=max_content,
100
+ sort_by_source_and_index=sort_by_source_and_index,
101
+ **kwargs)
112
102
  return cls(
113
103
  bisheng_rag_tool=bisheng_rag_tool,
114
104
  callbacks=callbacks,
@@ -134,8 +124,12 @@ class BishengRetrievalQA(Chain):
134
124
  """
135
125
  question = inputs[self.input_key]
136
126
  if self.return_source_documents:
137
- answer, docs = self.bisheng_rag_tool.run(question, return_only_outputs=False)
138
- return {self.output_key: answer, "source_documents": docs}
127
+ answer, docs = self.bisheng_rag_tool.run(
128
+ question,
129
+ return_only_outputs=False,
130
+ run_manager=run_manager,
131
+ )
132
+ return {self.output_key: answer, 'source_documents': docs}
139
133
  else:
140
134
  answer = self.bisheng_rag_tool.run(question, return_only_outputs=True)
141
135
  return {self.output_key: answer}
@@ -160,7 +154,7 @@ class BishengRetrievalQA(Chain):
160
154
 
161
155
  if self.return_source_documents:
162
156
  answer, docs = await self.bisheng_rag_tool.arun(question, return_only_outputs=False)
163
- return {self.output_key: answer, "source_documents": docs}
157
+ return {self.output_key: answer, 'source_documents': docs}
164
158
  else:
165
159
  answer = await self.bisheng_rag_tool.arun(question, return_only_outputs=True)
166
160
  return {self.output_key: answer}
@@ -1,27 +1,25 @@
1
- import time
2
1
  import os
3
- import yaml
4
- import httpx
5
- from typing import Any, Dict, Tuple, Type, Union, Optional
2
+ from typing import Any, Dict, Optional, Tuple, Union
6
3
 
7
- from langchain_core.vectorstores import VectorStoreRetriever
8
- from loguru import logger
9
- from langchain_core.tools import BaseTool, Tool
10
- from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator
11
- from langchain_core.language_models.base import LanguageModelLike
12
- from langchain_core.prompts import ChatPromptTemplate
13
- from langchain.chains.llm import LLMChain
14
- from langchain.chains.question_answering import load_qa_chain
4
+ import httpx
5
+ import yaml
6
+ from bisheng_langchain.rag.extract_info import extract_title
7
+ from bisheng_langchain.rag.init_retrievers import (BaselineVectorRetriever, KeywordRetriever,
8
+ MixRetriever, SmallerChunksVectorRetriever)
9
+ from bisheng_langchain.rag.utils import import_by_type, import_class
15
10
  from bisheng_langchain.retrievers import EnsembleRetriever
16
11
  from bisheng_langchain.vectorstores import ElasticKeywordsSearch, Milvus
17
- from bisheng_langchain.rag.init_retrievers import (
18
- BaselineVectorRetriever,
19
- KeywordRetriever,
20
- MixRetriever,
21
- SmallerChunksVectorRetriever,
22
- )
23
- from bisheng_langchain.rag.utils import import_by_type, import_class
24
- from bisheng_langchain.rag.extract_info import extract_title
12
+ from langchain.chains.llm import LLMChain
13
+ from langchain.chains.question_answering import load_qa_chain
14
+ from langchain.chains.combine_documents import create_stuff_documents_chain
15
+ from langchain_core.callbacks import CallbackManagerForChainRun
16
+ from langchain_core.language_models.base import LanguageModelLike
17
+ from langchain_core.prompts import ChatPromptTemplate
18
+ from langchain_core.pydantic_v1 import BaseModel, Field
19
+ from langchain_core.runnables import RunnableConfig
20
+ from langchain_core.tools import BaseTool, Tool
21
+ from langchain_core.vectorstores import VectorStoreRetriever
22
+ from loguru import logger
25
23
 
26
24
 
27
25
  class MultArgsSchemaTool(Tool):
@@ -37,26 +35,27 @@ class MultArgsSchemaTool(Tool):
37
35
 
38
36
  class BishengRAGTool:
39
37
 
40
- def __init__(
41
- self,
42
- vector_store: Optional[Milvus] = None,
43
- keyword_store: Optional[ElasticKeywordsSearch] = None,
44
- llm: Optional[LanguageModelLike] = None,
45
- collection_name: Optional[str] = None,
46
- QA_PROMPT: Optional[ChatPromptTemplate] = None,
47
- **kwargs
48
- ) -> None:
38
+ def __init__(self,
39
+ vector_store: Optional[Milvus] = None,
40
+ keyword_store: Optional[ElasticKeywordsSearch] = None,
41
+ llm: Optional[LanguageModelLike] = None,
42
+ collection_name: Optional[str] = None,
43
+ QA_PROMPT: Optional[ChatPromptTemplate] = None,
44
+ **kwargs) -> None:
49
45
  if collection_name is None and (keyword_store is None or vector_store is None):
50
- raise ValueError('collection_name must be provided if keyword_store or vector_store is not provided')
46
+ raise ValueError(
47
+ 'collection_name must be provided if keyword_store or vector_store is not provided'
48
+ )
51
49
  self.collection_name = collection_name
52
-
53
- yaml_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'config/baseline_v2.yaml')
50
+
51
+ yaml_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
52
+ 'config/baseline_v2.yaml')
54
53
  with open(yaml_path, 'r', encoding='utf-8') as f:
55
54
  self.params = yaml.safe_load(f)
56
-
55
+
57
56
  # update params
58
- max_content = kwargs.get("max_content", 15000)
59
- sort_by_source_and_index = kwargs.get("sort_by_source_and_index", True)
57
+ max_content = kwargs.get('max_content', 15000)
58
+ sort_by_source_and_index = kwargs.get('sort_by_source_and_index', True)
60
59
  self.params['generate']['max_content'] = max_content
61
60
  self.params['post_retrieval']['sort_by_source_and_index'] = sort_by_source_and_index
62
61
 
@@ -68,11 +67,12 @@ class BishengRAGTool:
68
67
  llm_object = import_by_type(_type='llms', name=llm_params['type'])
69
68
  if llm_params['type'] == 'ChatOpenAI' and llm_params['openai_proxy']:
70
69
  llm_params.pop('type')
71
- self.llm = llm_object(http_client=httpx.Client(proxies=llm_params['openai_proxy']), **llm_params)
70
+ self.llm = llm_object(http_client=httpx.Client(proxies=llm_params['openai_proxy']),
71
+ **llm_params)
72
72
  else:
73
73
  llm_params.pop('type')
74
74
  self.llm = llm_object(**llm_params)
75
-
75
+
76
76
  # init milvus
77
77
  if vector_store:
78
78
  # if vector_store is retriever, get vector_store instance
@@ -87,26 +87,27 @@ class BishengRAGTool:
87
87
  if embedding_params['type'] == 'OpenAIEmbeddings' and embedding_params['openai_proxy']:
88
88
  embedding_params.pop('type')
89
89
  self.embeddings = embedding_object(
90
- http_client=httpx.Client(proxies=embedding_params['openai_proxy']), **embedding_params
91
- )
90
+ http_client=httpx.Client(proxies=embedding_params['openai_proxy']),
91
+ **embedding_params)
92
92
  else:
93
93
  embedding_params.pop('type')
94
94
  self.embeddings = embedding_object(**embedding_params)
95
-
95
+
96
96
  self.vector_store = Milvus(
97
97
  embedding_function=self.embeddings,
98
98
  connection_args={
99
- "host": self.params['milvus']['host'],
100
- "port": self.params['milvus']['port'],
99
+ 'host': self.params['milvus']['host'],
100
+ 'port': self.params['milvus']['port'],
101
101
  },
102
102
  )
103
-
103
+
104
104
  # init keyword store
105
105
  if keyword_store:
106
106
  self.keyword_store = keyword_store
107
107
  else:
108
108
  if self.params['elasticsearch'].get('extract_key_by_llm', False):
109
- extract_key_prompt = import_class(f'bisheng_langchain.rag.prompts.EXTRACT_KEY_PROMPT')
109
+ extract_key_prompt = import_class(
110
+ 'bisheng_langchain.rag.prompts.EXTRACT_KEY_PROMPT')
110
111
  llm_chain = LLMChain(llm=self.llm, prompt=extract_key_prompt)
111
112
  else:
112
113
  llm_chain = None
@@ -128,10 +129,11 @@ class BishengRAGTool:
128
129
  'splitter_kwargs': retriever['splitter'],
129
130
  'retrieval_kwargs': retriever['retrieval'],
130
131
  }
131
- retriever_list.append(self._post_init_retriever(retriever_type=retriever_type, **retriever_params))
132
+ retriever_list.append(
133
+ self._post_init_retriever(retriever_type=retriever_type, **retriever_params))
132
134
  self.retriever = EnsembleRetriever(retrievers=retriever_list)
133
135
 
134
- # init qa chain
136
+ # init qa chain
135
137
  if QA_PROMPT:
136
138
  prompt = QA_PROMPT
137
139
  else:
@@ -140,13 +142,8 @@ class BishengRAGTool:
140
142
  prompt = import_class(f'bisheng_langchain.rag.prompts.{prompt_type}')
141
143
  else:
142
144
  prompt = None
143
- self.qa_chain = load_qa_chain(
144
- llm=self.llm,
145
- chain_type=self.params['generate']['chain_type'],
146
- prompt=prompt,
147
- verbose=False
148
- )
149
-
145
+ self.qa_chain = create_stuff_documents_chain(llm=self.llm, prompt=prompt)
146
+
150
147
  def _post_init_retriever(self, retriever_type, **kwargs):
151
148
  retriever_classes = {
152
149
  'KeywordRetriever': KeywordRetriever,
@@ -181,9 +178,9 @@ class BishengRAGTool:
181
178
  loader_object = import_by_type(_type='documentloaders', name=loader_params.pop('type'))
182
179
 
183
180
  logger.info(f'file_path: {file_path}')
184
- loader = loader_object(
185
- file_name=os.path.basename(file_path), file_path=file_path, **loader_params
186
- )
181
+ loader = loader_object(file_name=os.path.basename(file_path),
182
+ file_path=file_path,
183
+ **loader_params)
187
184
  documents = loader.load()
188
185
  logger.info(f'documents: {len(documents)}, page_content: {len(documents[0].page_content)}')
189
186
  if len(documents[0].page_content) == 0:
@@ -197,30 +194,26 @@ class BishengRAGTool:
197
194
  title = extract_title(llm=self.llm, text=doc.page_content)
198
195
  logger.info(f'extract title: {title}')
199
196
  except Exception as e:
200
- logger.error(f"Failed to extract title: {e}")
197
+ logger.error(f'Failed to extract title: {e}')
201
198
  title = ''
202
199
  doc.metadata['title'] = title
203
200
 
204
201
  for idx, retriever in enumerate(self.retriever.retrievers):
205
- retriever.add_documents(
206
- documents,
207
- self.collection_name,
208
- drop_old=drop_old,
209
- add_aux_info=add_aux_info
210
- )
211
-
202
+ retriever.add_documents(documents,
203
+ self.collection_name,
204
+ drop_old=drop_old,
205
+ add_aux_info=add_aux_info)
206
+
212
207
  def retrieval_and_rerank(self, query):
213
208
  """
214
209
  retrieval and rerank
215
210
  """
216
211
  # EnsembleRetriever直接检索召回会默认去重
217
- docs = self.retriever.get_relevant_documents(
218
- query=query,
219
- collection_name=self.collection_name
220
- )
212
+ docs = self.retriever.get_relevant_documents(query=query,
213
+ collection_name=self.collection_name)
221
214
  logger.info(f'retrieval docs origin: {len(docs)}')
222
215
 
223
- # delete redundancy according to max_content
216
+ # delete redundancy according to max_content
224
217
  doc_num, doc_content_sum = 0, 0
225
218
  for doc in docs:
226
219
  doc_content_sum += len(doc.page_content)
@@ -235,28 +228,37 @@ class BishengRAGTool:
235
228
  logger.info('sort chunks by source and chunk_index')
236
229
  docs = sorted(docs, key=lambda x: (x.metadata['source'], x.metadata['chunk_index']))
237
230
  return docs
238
-
239
- def run(self, query, return_only_outputs=True) -> Any:
231
+
232
+ def run(self,
233
+ query,
234
+ return_only_outputs=True,
235
+ run_manager: Optional[CallbackManagerForChainRun] = None) -> Any:
240
236
  docs = self.retrieval_and_rerank(query)
241
237
  try:
242
- ans = self.qa_chain({"input_documents": docs, "question": query}, return_only_outputs=return_only_outputs)
238
+ kwargs = {}
239
+ if run_manager:
240
+ kwargs['config'] = RunnableConfig(callbacks=[run_manager])
241
+ ans = self.qa_chain.invoke(
242
+ {
243
+ 'context': docs,
244
+ 'question': query
245
+ }, **kwargs
246
+ )
243
247
  except Exception as e:
244
- logger.error(f'question: {query}\nerror: {e}')
245
- ans = {'output_text': str(e)}
248
+ logger.exception(f'question: {query}\nerror: {e}')
249
+ ans = str(e)
246
250
  if return_only_outputs:
247
- rag_answer = ans['output_text']
248
- return rag_answer
251
+ return ans
249
252
  else:
250
- rag_answer = ans['output_text']
251
- input_documents = ans['input_documents']
252
- return rag_answer, input_documents
253
-
253
+ return ans, docs
254
+
254
255
  async def arun(self, query: str, return_only_outputs=True) -> str:
255
256
  rag_answer = self.run(query, return_only_outputs)
256
257
  return rag_answer
257
-
258
+
258
259
  @classmethod
259
260
  def get_rag_tool(cls, name, description, **kwargs: Any) -> BaseTool:
261
+
260
262
  class InputArgs(BaseModel):
261
263
  query: str = Field(description='question asked by the user.')
262
264
 
@@ -265,7 +267,7 @@ class BishengRAGTool:
265
267
  func=cls(**kwargs).run,
266
268
  coroutine=cls(**kwargs).arun,
267
269
  args_schema=InputArgs)
268
-
270
+
269
271
 
270
272
  if __name__ == '__main__':
271
273
  # rag_tool = BishengRAGTool(collection_name='rag_finance_report_0_test')
@@ -280,32 +282,30 @@ if __name__ == '__main__':
280
282
  collection_name = 'rag_finance_report_0_benchmark_caibao_1000_source_title'
281
283
  # milvus
282
284
  vector_store = Milvus(
283
- collection_name=collection_name,
284
- embedding_function=embeddings,
285
- connection_args={
286
- "host": '110.16.193.170',
287
- "port": '50062',
288
- },
285
+ collection_name=collection_name,
286
+ embedding_function=embeddings,
287
+ connection_args={
288
+ 'host': '110.16.193.170',
289
+ 'port': '50062',
290
+ },
289
291
  )
290
292
  # es
291
293
  keyword_store = ElasticKeywordsSearch(
292
294
  index_name=collection_name,
293
295
  elasticsearch_url='http://110.16.193.170:50062/es',
294
- ssl_verify={'basic_auth': ["elastic", "oSGL-zVvZ5P3Tm7qkDLC"]},
296
+ ssl_verify={'basic_auth': ['elastic', 'oSGL-zVvZ5P3Tm7qkDLC']},
295
297
  )
296
298
 
297
- tool = BishengRAGTool.get_rag_tool(
298
- name='rag_knowledge_retrieve',
299
- description='金融年报财报知识库问答',
300
- vector_store=vector_store,
301
- keyword_store=keyword_store,
302
- llm=llm
303
- )
299
+ tool = BishengRAGTool.get_rag_tool(name='rag_knowledge_retrieve',
300
+ description='金融年报财报知识库问答',
301
+ vector_store=vector_store,
302
+ keyword_store=keyword_store,
303
+ llm=llm)
304
304
  print(tool.run('能否根据2020年金宇生物技术股份有限公司的年报,给我简要介绍一下报告期内公司的社会责任工作情况?'))
305
305
 
306
306
  # tool = BishengRAGTool.get_rag_tool(
307
- # name='rag_knowledge_retrieve',
307
+ # name='rag_knowledge_retrieve',
308
308
  # description='金融年报财报知识库问答',
309
309
  # collection_name='rag_finance_report_0_benchmark_caibao_1000_source_title'
310
310
  # )
311
- # print(tool.run('能否根据2020年金宇生物技术股份有限公司的年报,给我简要介绍一下报告期内公司的社会责任工作情况?'))
311
+ # print(tool.run('能否根据2020年金宇生物技术股份有限公司的年报,给我简要介绍一下报告期内公司的社会责任工作情况?'))
@@ -1,5 +1,3 @@
1
- import httpx
2
- from langchain.chat_models import ChatOpenAI
3
1
  from bisheng_langchain.chat_models import ChatQWen
4
2
  from langchain.chains import LLMChain
5
3
  from langchain.prompts.chat import (
@@ -1,20 +1,15 @@
1
- import os
2
- import uuid
3
- from loguru import logger
4
- from typing import Any, Dict, Iterable, List, Optional
1
+ from typing import Any, List, Optional
5
2
 
6
- from bisheng_langchain.vectorstores.milvus import Milvus
3
+ from langchain.text_splitter import TextSplitter
7
4
  from langchain_core.documents import Document
8
5
  from langchain_core.pydantic_v1 import Field
9
6
  from langchain_core.retrievers import BaseRetriever
10
- from langchain_core.vectorstores import VectorStore
11
-
12
- from langchain.callbacks.manager import CallbackManagerForRetrieverRun
13
- from langchain.text_splitter import TextSplitter
7
+ from loguru import logger
14
8
 
15
9
 
16
10
  class BaselineVectorRetriever(BaseRetriever):
17
- vector_store: Milvus
11
+
12
+ vector_store: Any
18
13
  text_splitter: TextSplitter
19
14
  search_type: str = 'similarity'
20
15
  search_kwargs: dict = Field(default_factory=dict)
@@ -27,13 +22,14 @@ class BaselineVectorRetriever(BaseRetriever):
27
22
  **kwargs,
28
23
  ) -> None:
29
24
  split_docs = self.text_splitter.split_documents(documents)
30
- logger.info(f"BaselineVectorRetriever: split document into {len(split_docs)} chunks")
25
+ logger.info(f'BaselineVectorRetriever: split document into {len(split_docs)} chunks')
31
26
  for chunk_index, split_doc in enumerate(split_docs):
32
27
  if 'chunk_bboxes' in split_doc.metadata:
33
28
  split_doc.metadata.pop('chunk_bboxes')
34
29
  split_doc.metadata['chunk_index'] = chunk_index
35
30
  if kwargs.get('add_aux_info', False):
36
- split_doc.page_content = split_doc.metadata["source"] + '\n' + split_doc.metadata["title"] + '\n' + split_doc.page_content
31
+ split_doc.page_content = split_doc.metadata['source'] + '\n' + split_doc.metadata[
32
+ 'title'] + '\n' + split_doc.page_content
37
33
 
38
34
  connection_args = self.vector_store.connection_args
39
35
  embedding_function = self.vector_store.embedding_func
@@ -1,22 +1,14 @@
1
- import os
2
- import uuid
3
- from loguru import logger
4
- from dataclasses import dataclass
5
- from typing import Any, Dict, Iterable, List, Optional
1
+ from typing import Any, List, Optional
6
2
 
7
- from bisheng_langchain.vectorstores import ElasticKeywordsSearch
8
- from bisheng_langchain.vectorstores.milvus import Milvus
3
+ from langchain.text_splitter import TextSplitter
9
4
  from langchain_core.documents import Document
10
5
  from langchain_core.pydantic_v1 import Field
11
6
  from langchain_core.retrievers import BaseRetriever
12
- from langchain_core.vectorstores import VectorStore
13
-
14
- from langchain.callbacks.manager import CallbackManagerForRetrieverRun
15
- from langchain.text_splitter import TextSplitter
7
+ from loguru import logger
16
8
 
17
9
 
18
10
  class KeywordRetriever(BaseRetriever):
19
- keyword_store: ElasticKeywordsSearch
11
+ keyword_store: Any
20
12
  text_splitter: TextSplitter
21
13
  search_type: str = 'similarity'
22
14
  search_kwargs: dict = Field(default_factory=dict)
@@ -29,13 +21,14 @@ class KeywordRetriever(BaseRetriever):
29
21
  **kwargs,
30
22
  ) -> None:
31
23
  split_docs = self.text_splitter.split_documents(documents)
32
- logger.info(f"KeywordRetriever: split document into {len(split_docs)} chunks")
24
+ logger.info(f'KeywordRetriever: split document into {len(split_docs)} chunks')
33
25
  for chunk_index, split_doc in enumerate(split_docs):
34
26
  if 'chunk_bboxes' in split_doc.metadata:
35
27
  split_doc.metadata.pop('chunk_bboxes')
36
28
  split_doc.metadata['chunk_index'] = chunk_index
37
29
  if kwargs.get('add_aux_info', False):
38
- split_doc.page_content = split_doc.metadata["source"] + '\n' + split_doc.metadata["title"] + '\n' + split_doc.page_content
30
+ split_doc.page_content = split_doc.metadata['source'] + '\n' + split_doc.metadata[
31
+ 'title'] + '\n' + split_doc.page_content
39
32
 
40
33
  elasticsearch_url = self.keyword_store.elasticsearch_url
41
34
  ssl_verify = self.keyword_store.ssl_verify
@@ -58,8 +51,7 @@ class KeywordRetriever(BaseRetriever):
58
51
  index_name=collection_name,
59
52
  elasticsearch_url=self.keyword_store.elasticsearch_url,
60
53
  ssl_verify=self.keyword_store.ssl_verify,
61
- llm_chain=self.keyword_store.llm_chain
62
- )
54
+ llm_chain=self.keyword_store.llm_chain)
63
55
  if self.search_type == 'similarity':
64
56
  result = self.keyword_store.similarity_search(query, **self.search_kwargs)
65
57
  return result
@@ -1,17 +1,14 @@
1
- from typing import Any, Dict, Iterable, List, Optional
1
+ from typing import Any, List, Optional
2
2
 
3
3
  from bisheng_langchain.vectorstores import ElasticKeywordsSearch
4
- from bisheng_langchain.vectorstores.milvus import Milvus
4
+ from langchain.text_splitter import TextSplitter
5
5
  from langchain_core.documents import Document
6
6
  from langchain_core.pydantic_v1 import Field
7
7
  from langchain_core.retrievers import BaseRetriever
8
8
 
9
- from langchain.schema import BaseRetriever, Document
10
- from langchain.text_splitter import TextSplitter
11
-
12
9
 
13
10
  class MixRetriever(BaseRetriever):
14
- vector_store: Milvus
11
+ vector_store: Any
15
12
  keyword_store: ElasticKeywordsSearch
16
13
  vector_text_splitter: TextSplitter
17
14
  keyword_text_splitter: TextSplitter
@@ -34,14 +31,16 @@ class MixRetriever(BaseRetriever):
34
31
  split_doc.metadata.pop('chunk_bboxes')
35
32
  split_doc.metadata['chunk_index'] = chunk_index
36
33
  if kwargs.get('add_aux_info', False):
37
- split_doc.page_content = split_doc.metadata["source"] + '\n' + split_doc.metadata["title"] + '\n' + split_doc.page_content
34
+ split_doc.page_content = split_doc.metadata['source'] + '\n' + split_doc.metadata[
35
+ 'title'] + '\n' + split_doc.page_content
38
36
  keyword_split_docs = self.keyword_text_splitter.split_documents(documents)
39
37
  for chunk_index, split_doc in enumerate(keyword_split_docs):
40
38
  if 'chunk_bboxes' in split_doc.metadata:
41
39
  split_doc.metadata.pop('chunk_bboxes')
42
40
  split_doc.metadata['chunk_index'] = chunk_index
43
41
  if kwargs.get('add_aux_info', False):
44
- split_doc.page_content = split_doc.metadata["source"] + '\n' + split_doc.metadata["title"] + '\n' + split_doc.page_content
42
+ split_doc.page_content = split_doc.metadata['source'] + '\n' + split_doc.metadata[
43
+ 'title'] + '\n' + split_doc.page_content
45
44
 
46
45
  self.keyword_store.from_documents(
47
46
  keyword_split_docs,
@@ -70,15 +69,15 @@ class MixRetriever(BaseRetriever):
70
69
  index_name=collection_name,
71
70
  elasticsearch_url=self.keyword_store.elasticsearch_url,
72
71
  ssl_verify=self.keyword_store.ssl_verify,
73
- llm_chain=self.keyword_store.llm_chain
74
- )
72
+ llm_chain=self.keyword_store.llm_chain)
75
73
  self.vector_store = self.vector_store.__class__(
76
74
  collection_name=collection_name,
77
75
  embedding_function=self.vector_store.embedding_func,
78
76
  connection_args=self.vector_store.connection_args,
79
77
  )
80
78
  if self.search_type == 'similarity':
81
- keyword_docs = self.keyword_store.similarity_search(query, **self.keyword_search_kwargs)
79
+ keyword_docs = self.keyword_store.similarity_search(query,
80
+ **self.keyword_search_kwargs)
82
81
  vector_docs = self.vector_store.similarity_search(query, **self.vector_search_kwargs)
83
82
  if self.combine_strategy == 'keyword_front':
84
83
  return keyword_docs + vector_docs
@@ -94,10 +93,10 @@ class MixRetriever(BaseRetriever):
94
93
  combine_docs.extend(vector_docs[min_len:])
95
94
  return combine_docs
96
95
  else:
97
- raise ValueError(
98
- f'Expected combine_strategy to be one of '
99
- f'(keyword_front, vector_front, mix),'
100
- f'instead found {self.combine_strategy}'
101
- )
96
+ raise ValueError(f'Expected combine_strategy to be one of '
97
+ f'(keyword_front, vector_front, mix),'
98
+ f'instead found {self.combine_strategy}')
102
99
  else:
103
- raise ValueError(f'Expected search_type to be one of (similarity), instead found {self.search_type}')
100
+ raise ValueError(
101
+ f'Expected search_type to be one of (similarity), instead found {self.search_type}'
102
+ )