bisheng-langchain 0.3.0rc1__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49):
  1. bisheng_langchain/document_loaders/elem_unstrcutured_loader.py +5 -3
  2. bisheng_langchain/gpts/agent_types/llm_functions_agent.py +7 -1
  3. bisheng_langchain/gpts/assistant.py +8 -5
  4. bisheng_langchain/gpts/load_tools.py +2 -0
  5. bisheng_langchain/gpts/prompts/__init__.py +4 -2
  6. bisheng_langchain/gpts/prompts/assistant_prompt_base.py +1 -0
  7. bisheng_langchain/gpts/prompts/assistant_prompt_cohere.py +19 -0
  8. bisheng_langchain/gpts/tools/api_tools/flow.py +3 -3
  9. bisheng_langchain/gpts/tools/api_tools/openapi.py +101 -0
  10. bisheng_langchain/rag/__init__.py +5 -0
  11. bisheng_langchain/rag/bisheng_rag_pipeline.py +320 -0
  12. bisheng_langchain/rag/bisheng_rag_pipeline_v2.py +359 -0
  13. bisheng_langchain/rag/bisheng_rag_pipeline_v2_cohere_raw_prompting.py +376 -0
  14. bisheng_langchain/rag/bisheng_rag_tool.py +288 -0
  15. bisheng_langchain/rag/config/baseline.yaml +86 -0
  16. bisheng_langchain/rag/config/baseline_caibao.yaml +82 -0
  17. bisheng_langchain/rag/config/baseline_caibao_knowledge_v2.yaml +110 -0
  18. bisheng_langchain/rag/config/baseline_caibao_v2.yaml +112 -0
  19. bisheng_langchain/rag/config/baseline_demo_v2.yaml +92 -0
  20. bisheng_langchain/rag/config/baseline_s2b_mix.yaml +88 -0
  21. bisheng_langchain/rag/config/baseline_v2.yaml +90 -0
  22. bisheng_langchain/rag/extract_info.py +38 -0
  23. bisheng_langchain/rag/init_retrievers/__init__.py +4 -0
  24. bisheng_langchain/rag/init_retrievers/baseline_vector_retriever.py +61 -0
  25. bisheng_langchain/rag/init_retrievers/keyword_retriever.py +65 -0
  26. bisheng_langchain/rag/init_retrievers/mix_retriever.py +103 -0
  27. bisheng_langchain/rag/init_retrievers/smaller_chunks_retriever.py +92 -0
  28. bisheng_langchain/rag/prompts/__init__.py +9 -0
  29. bisheng_langchain/rag/prompts/extract_key_prompt.py +34 -0
  30. bisheng_langchain/rag/prompts/prompt.py +47 -0
  31. bisheng_langchain/rag/prompts/prompt_cohere.py +111 -0
  32. bisheng_langchain/rag/qa_corpus/__init__.py +0 -0
  33. bisheng_langchain/rag/qa_corpus/qa_generator.py +143 -0
  34. bisheng_langchain/rag/rerank/__init__.py +5 -0
  35. bisheng_langchain/rag/rerank/rerank.py +48 -0
  36. bisheng_langchain/rag/rerank/rerank_benchmark.py +139 -0
  37. bisheng_langchain/rag/run_qa_gen_web.py +47 -0
  38. bisheng_langchain/rag/run_rag_evaluate_web.py +55 -0
  39. bisheng_langchain/rag/scoring/__init__.py +0 -0
  40. bisheng_langchain/rag/scoring/llama_index_score.py +91 -0
  41. bisheng_langchain/rag/scoring/ragas_score.py +183 -0
  42. bisheng_langchain/rag/utils.py +181 -0
  43. bisheng_langchain/retrievers/ensemble.py +2 -1
  44. bisheng_langchain/vectorstores/elastic_keywords_search.py +2 -1
  45. {bisheng_langchain-0.3.0rc1.dist-info → bisheng_langchain-0.3.1.dist-info}/METADATA +1 -1
  46. {bisheng_langchain-0.3.0rc1.dist-info → bisheng_langchain-0.3.1.dist-info}/RECORD +48 -13
  47. bisheng_langchain/gpts/prompts/base_prompt.py +0 -1
  48. {bisheng_langchain-0.3.0rc1.dist-info → bisheng_langchain-0.3.1.dist-info}/WHEEL +0 -0
  49. {bisheng_langchain-0.3.0rc1.dist-info → bisheng_langchain-0.3.1.dist-info}/top_level.txt +0 -0
@@ -85,16 +85,18 @@ class ElemUnstructuredLoader(BasePDFLoader):
85
85
  mode='partition',
86
86
  parameters=parameters)
87
87
 
88
- resp = requests.post(self.unstructured_api_url, headers=self.headers, json=payload).json()
88
+ resp = requests.post(self.unstructured_api_url, headers=self.headers, json=payload)
89
+ if resp.status_code != 200:
90
+ raise Exception(f'file partition {os.path.basename(self.file_name)} failed resp={resp.text}')
89
91
 
92
+ resp = resp.json()
90
93
  if 200 != resp.get('status_code'):
91
- logger.info(f'not return resp={resp}')
94
+ logger.info(f'file partition {os.path.basename(self.file_name)} error resp={resp}')
92
95
  partitions = resp['partitions']
93
96
  if not partitions:
94
97
  logger.info(f'partition_error resp={resp}')
95
98
  logger.info(f'unstruct_return code={resp.get("status_code")}')
96
99
 
97
- partitions = resp['partitions']
98
100
  content, metadata = merge_partitions(partitions)
99
101
  metadata['source'] = self.file_name
100
102
 
@@ -1,5 +1,5 @@
1
1
  import json
2
-
2
+ import re
3
3
  from bisheng_langchain.gpts.message_types import LiberalFunctionMessage, LiberalToolMessage
4
4
  from langchain.tools import BaseTool
5
5
  from langchain.tools.render import format_tool_to_openai_tool
@@ -39,6 +39,12 @@ def get_openai_functions_agent_executor(tools: list[BaseTool], llm: LanguageMode
39
39
  last_message = messages[-1]
40
40
  # If there is no function call, then we finish
41
41
  if 'tool_calls' not in last_message.additional_kwargs:
42
+ if '|<instruct>|' in system_message:
43
+ # cohere model
44
+ pattern = r"Answer:(.+)\nGrounded answer"
45
+ match = re.search(pattern, last_message.content)
46
+ if match:
47
+ last_message.content = match.group(1)
42
48
  return 'end'
43
49
  # Otherwise if there is, we continue
44
50
  else:
@@ -126,11 +126,14 @@ class BishengAssistant:
126
126
  if __name__ == "__main__":
127
127
  from langchain.globals import set_debug
128
128
 
129
- set_debug(True)
130
- # chat_history = []
131
- chat_history = ['你好', '你好,有什么可以帮助你吗?', '福蓉科技股价多少?', '福蓉科技(股票代码:300049)的当前股价为48.67元。']
132
- query = "去年这个时候的股价是多少?"
133
- bisheng_assistant = BishengAssistant("config/base_scene.yaml")
129
+ # set_debug(True)
130
+ chat_history = []
131
+ query = "请简要分析中科创达软件股份有限公司2019年聘任、解聘会计师事务的情况。"
132
+ # chat_history = ['你好', '你好,有什么可以帮助你吗?', '福蓉科技股价多少?', '福蓉科技(股票代码:300049)的当前股价为48.67元。']
133
+ # query = '去年这个时候的股价是多少?'
134
+ # bisheng_assistant = BishengAssistant("config/base_scene.yaml")
135
+ # bisheng_assistant = BishengAssistant("config/knowledge_scene.yaml")
136
+ bisheng_assistant = BishengAssistant("config/rag_scene.yaml")
134
137
  result = bisheng_assistant.run(query, chat_history=chat_history)
135
138
  for r in result:
136
139
  print(f'------------------')
@@ -26,6 +26,7 @@ from langchain_core.callbacks import BaseCallbackManager, Callbacks
26
26
  from langchain_core.language_models import BaseLanguageModel
27
27
  from langchain_core.tools import BaseTool, Tool
28
28
  from mypy_extensions import Arg, KwArg
29
+ from bisheng_langchain.rag import BishengRAGTool
29
30
 
30
31
 
31
32
  def _get_current_time() -> BaseTool:
@@ -86,6 +87,7 @@ _EXTRA_PARAM_TOOLS: Dict[str, Tuple[Callable[[KwArg(Any)], BaseTool], List[Optio
86
87
  'dalle_image_generator': (_get_dalle_image_generator, ['openai_api_key', 'openai_proxy'], []),
87
88
  'bing_search': (_get_bing_search, ['bing_subscription_key', 'bing_search_url'], []),
88
89
  'bisheng_code_interpreter': (_get_native_code_interpreter, ["minio"], ['files']),
90
+ 'bisheng_rag': (BishengRAGTool.get_rag_tool, ['name', 'description'], ['vector_store', 'keyword_store', 'llm', 'collection_name', 'max_content', 'sort_by_source_and_index']),
89
91
  }
90
92
 
91
93
  _API_TOOLS: Dict[str, Tuple[Callable[[KwArg(Any)], BaseTool], List[str]]] = {**ALL_API_TOOLS} # type: ignore
@@ -1,12 +1,14 @@
1
1
  from bisheng_langchain.gpts.prompts.assistant_prompt_opt import ASSISTANT_PROMPT_OPT
2
- from bisheng_langchain.gpts.prompts.base_prompt import DEFAULT_SYSTEM_MESSAGE
2
+ from bisheng_langchain.gpts.prompts.assistant_prompt_base import ASSISTANT_PROMPT_DEFAULT
3
+ from bisheng_langchain.gpts.prompts.assistant_prompt_cohere import ASSISTANT_PROMPT_COHERE
3
4
  from bisheng_langchain.gpts.prompts.breif_description_prompt import BREIF_DES_PROMPT
4
5
  from bisheng_langchain.gpts.prompts.opening_dialog_prompt import OPENDIALOG_PROMPT
5
6
  from bisheng_langchain.gpts.prompts.select_tools_prompt import HUMAN_MSG, SYS_MSG
6
7
 
7
8
 
8
9
  __all__ = [
9
- "DEFAULT_SYSTEM_MESSAGE",
10
+ "ASSISTANT_PROMPT_DEFAULT",
11
+ "ASSISTANT_PROMPT_COHERE",
10
12
  "ASSISTANT_PROMPT_OPT",
11
13
  "OPENDIALOG_PROMPT",
12
14
  "BREIF_DES_PROMPT",
@@ -0,0 +1 @@
1
+ ASSISTANT_PROMPT_DEFAULT = "You are a helpful assistant."
@@ -0,0 +1,19 @@
1
+ preamble="""You are a helpful assistant.
2
+ """
3
+
4
+ ASSISTANT_PROMPT_COHERE="""{preamble}|<instruct>|Carefully perform the following instructions, in order, starting each with a new line.
5
+ Firstly, You may need to use complex and advanced reasoning to complete your task and answer the question. Think about how you can use the provided tools to answer the question and come up with a high level plan you will execute.
6
+ Write 'Plan:' followed by an initial high level plan of how you will solve the problem including the tools and steps required.
7
+ Secondly, Carry out your plan by repeatedly using actions, reasoning over the results, and re-evaluating your plan. Perform Action, Observation, Reflection steps with the following format. Write 'Action:' followed by a json formatted action containing the "tool_name" and "parameters"
8
+ Next you will analyze the 'Observation:', this is the result of the action.
9
+ After that you should always think about what to do next. Write 'Reflection:' followed by what you've figured out so far, any changes you need to make to your plan, and what you will do next including if you know the answer to the question.
10
+ ... (this Action/Observation/Reflection can repeat N times)
11
+ Thirdly, Decide which of the retrieved documents are relevant to the user's last input by writing 'Relevant Documents:' followed by comma-separated list of document numbers. If none are relevant, you should instead write 'None'.
12
+ Fourthly, Decide which of the retrieved documents contain facts that should be cited in a good answer to the user's last input by writing 'Cited Documents:' followed a comma-separated list of document numbers. If you dont want to cite any of them, you should instead write 'None'.
13
+ Fifthly, Write 'Answer:' followed by a response to the user's last input. Use the retrieved documents to help you. Do not insert any citations or grounding markup.
14
+ Finally, Write 'Grounded answer:' followed by a response to the user's last input in high quality natural english. Use the symbols <co: doc> and </co: doc> to indicate when a fact comes from a document in the search result, e.g <co: 4>my fact</co: 4> for a fact from document 4.
15
+
16
+ Additional instructions to note:
17
+ - If the user's question is in Chinese, please answer it in Chinese.
18
+ - 当问题中有涉及到时间信息时,比如最近6个月、昨天、去年等,你需要用时间工具查询时间信息。
19
+ """.format(preamble=preamble)
@@ -1,10 +1,11 @@
1
1
  from loguru import logger
2
- from pydantic import BaseModel, Field
2
+ from langchain_core.pydantic_v1 import BaseModel, Field
3
3
  from typing import Any
4
4
  from .base import APIToolBase
5
5
  from .base import MultArgsSchemaTool
6
6
  from langchain_core.tools import BaseTool
7
7
 
8
+
8
9
  class FlowTools(APIToolBase):
9
10
 
10
11
  def run(self, query: str) -> str:
@@ -52,9 +53,8 @@ class FlowTools(APIToolBase):
52
53
  }
53
54
 
54
55
  class InputArgs(BaseModel):
55
- """args_schema"""
56
56
  query: str = Field(description='questions to ask')
57
-
57
+
58
58
  return cls(url=url, params=params, input_key=input_key, args_schema=InputArgs)
59
59
 
60
60
  @classmethod
@@ -0,0 +1,101 @@
1
+ from typing import Any
2
+
3
+ from langchain_core.tools import BaseTool
4
+ from loguru import logger
5
+ from pydantic import BaseModel, create_model
6
+
7
+ from .base import APIToolBase, MultArgsSchemaTool, Field
8
+
9
+
10
+ class OpenApiTools(APIToolBase):
11
+
12
+ def get_real_path(self):
13
+ return self.url + self.params["path"]
14
+
15
+ def get_request_method(self):
16
+ return self.params["method"]
17
+
18
+ def get_params_json(self, **kwargs):
19
+ params_define = {}
20
+ for one in self.params["parameters"]:
21
+ params_define[one["name"]] = one
22
+
23
+ params = {}
24
+ json_data = {}
25
+ for k, v in kwargs.items():
26
+ if params_define.get(k):
27
+ if params_define[k]["in"] == "query":
28
+ params[k] = v
29
+ else:
30
+ json_data[k] = v
31
+ else:
32
+ params[k] = v
33
+ return params, json_data
34
+
35
+ def parse_args_schema(self):
36
+ params = self.params["parameters"]
37
+ model_params = {}
38
+ for one in params:
39
+ field_type = one["schema"]["type"]
40
+ if field_type == "number":
41
+ field_type = "float"
42
+ elif field_type == "integer":
43
+ field_type = "int"
44
+ elif field_type == "string":
45
+ field_type = "str"
46
+ else:
47
+ raise Exception(f"schema type is not support: {field_type}")
48
+ model_params[one["name"]] = (eval(field_type), Field(description=one["description"]))
49
+ return create_model("InputArgs", __module__='bisheng_langchain.gpts.tools.api_tools.openapi',
50
+ __base__=BaseModel, **model_params)
51
+
52
+ def run(self, **kwargs) -> str:
53
+ """Run query through api and parse result."""
54
+ path = self.get_real_path()
55
+ logger.info('api_call url={}', path)
56
+ method = self.get_request_method()
57
+ params, json_data = self.get_params_json(**kwargs)
58
+
59
+ if method == "get":
60
+ resp = self.client.get(path, params=params)
61
+ elif method == 'post':
62
+ resp = self.client.post(path, params=params, json=self.params)
63
+ elif method == 'put':
64
+ resp = self.client.put(path, params=params, json=self.params)
65
+ elif method == 'delete':
66
+ resp = self.client.delete(path, params=params, json=self.params)
67
+ else:
68
+ raise Exception(f"http method is not support: {method}")
69
+ if resp.status_code != 200:
70
+ logger.info(f'api_call_fail code={resp.status_code} res={resp.text}')
71
+ raise Exception(f"api_call_fail: {resp.status_code} {resp.text}")
72
+ return resp.text
73
+
74
+ async def arun(self, **kwargs) -> str:
75
+ """Run query through api and parse result."""
76
+ path = self.get_real_path()
77
+ logger.info('api_call url={}', path)
78
+ method = self.get_request_method()
79
+ params, json_data = self.get_params_json(**kwargs)
80
+
81
+ if method == "get":
82
+ resp = await self.async_client.aget(path, params=params)
83
+ elif method == 'post':
84
+ resp = await self.async_client.apost(path, params=params, json=self.params)
85
+ elif method == 'put':
86
+ resp = await self.async_client.aput(path, params=params, json=self.params)
87
+ elif method == 'delete':
88
+ resp = await self.async_client.adelete(path, params=params, json=self.params)
89
+ else:
90
+ raise Exception(f"http method is not support: {method}")
91
+ return resp
92
+
93
+ @classmethod
94
+ def get_api_tool(cls, name, **kwargs: Any) -> BaseTool:
95
+ description = kwargs.pop("description", "")
96
+ obj = cls(**kwargs)
97
+ return MultArgsSchemaTool(name=name,
98
+ description=description,
99
+ func=obj.run,
100
+ coroutine=obj.arun,
101
+ args_schema=obj.parse_args_schema())
@@ -0,0 +1,5 @@
1
+ from bisheng_langchain.rag.bisheng_rag_tool import BishengRAGTool
2
+
3
+ __all__ = [
4
+ "BishengRAGTool",
5
+ ]
@@ -0,0 +1,320 @@
1
+ import argparse
2
+ import copy
3
+ import inspect
4
+ import time
5
+ import os
6
+ from collections import defaultdict
7
+
8
+ import httpx
9
+ import pandas as pd
10
+ import yaml
11
+ from loguru import logger
12
+ from tqdm import tqdm
13
+ from bisheng_langchain.retrievers import EnsembleRetriever
14
+ from bisheng_langchain.vectorstores import ElasticKeywordsSearch, Milvus
15
+ from langchain.chains.question_answering import load_qa_chain
16
+ from bisheng_langchain.rag.init_retrievers import (
17
+ BaselineVectorRetriever,
18
+ KeywordRetriever,
19
+ MixRetriever,
20
+ SmallerChunksVectorRetriever,
21
+ )
22
+ from bisheng_langchain.rag.scoring.ragas_score import RagScore
23
+ from bisheng_langchain.rag.utils import import_by_type, import_class
24
+
25
+
26
+ class BishengRagPipeline:
27
+
28
+ def __init__(self, yaml_path) -> None:
29
+ self.yaml_path = yaml_path
30
+ with open(self.yaml_path, 'r') as f:
31
+ self.params = yaml.safe_load(f)
32
+
33
+ # init data
34
+ self.origin_file_path = self.params['data']['origin_file_path']
35
+ self.question_path = self.params['data']['question']
36
+ self.save_answer_path = self.params['data']['save_answer']
37
+
38
+ # init embeddings
39
+ embedding_params = self.params['embedding']
40
+ embedding_object = import_by_type(_type='embeddings', name=embedding_params['type'])
41
+ if embedding_params['type'] == 'OpenAIEmbeddings' and embedding_params['openai_proxy']:
42
+ embedding_params.pop('type')
43
+ self.embeddings = embedding_object(
44
+ http_client=httpx.Client(proxies=embedding_params['openai_proxy']), **embedding_params
45
+ )
46
+ else:
47
+ embedding_params.pop('type')
48
+ self.embeddings = embedding_object(**embedding_params)
49
+
50
+ # init llm
51
+ llm_params = self.params['chat_llm']
52
+ llm_object = import_by_type(_type='llms', name=llm_params['type'])
53
+ if llm_params['type'] == 'ChatOpenAI' and llm_params['openai_proxy']:
54
+ llm_params.pop('type')
55
+ self.llm = llm_object(http_client=httpx.Client(proxies=llm_params['openai_proxy']), **llm_params)
56
+ else:
57
+ llm_params.pop('type')
58
+ self.llm = llm_object(**llm_params)
59
+
60
+ # milvus
61
+ self.vector_store = Milvus(
62
+ embedding_function=self.embeddings,
63
+ connection_args={
64
+ "host": self.params['milvus']['host'],
65
+ "port": self.params['milvus']['port'],
66
+ },
67
+ )
68
+
69
+ # es
70
+ self.keyword_store = ElasticKeywordsSearch(
71
+ index_name='default_es',
72
+ elasticsearch_url=self.params['elasticsearch']['url'],
73
+ ssl_verify=self.params['elasticsearch']['ssl_verify'],
74
+ )
75
+
76
+ # init retriever
77
+ retriever_list = []
78
+ retrievers = self.params['retriever']['retrievers']
79
+ for retriever in retrievers:
80
+ retriever_type = retriever.pop('type')
81
+ retriever_params = {
82
+ 'vector_store': self.vector_store,
83
+ 'keyword_store': self.keyword_store,
84
+ 'splitter_kwargs': retriever['splitter'],
85
+ 'retrieval_kwargs': retriever['retrieval'],
86
+ }
87
+ retriever_list.append(self._post_init_retriever(retriever_type=retriever_type, **retriever_params))
88
+ self.retriever = EnsembleRetriever(retrievers=retriever_list)
89
+
90
+ def _post_init_retriever(self, retriever_type, **kwargs):
91
+ retriever_classes = {
92
+ 'KeywordRetriever': KeywordRetriever,
93
+ 'BaselineVectorRetriever': BaselineVectorRetriever,
94
+ 'MixRetriever': MixRetriever,
95
+ 'SmallerChunksVectorRetriever': SmallerChunksVectorRetriever,
96
+ }
97
+ if retriever_type not in retriever_classes:
98
+ raise ValueError(f'Unknown retriever type: {retriever_type}')
99
+
100
+ input_kwargs = {}
101
+ splitter_params = kwargs.pop('splitter_kwargs')
102
+ for key, value in splitter_params.items():
103
+ splitter_obj = import_by_type(_type='textsplitters', name=value.pop('type'))
104
+ input_kwargs[key] = splitter_obj(**value)
105
+
106
+ retrieval_params = kwargs.pop('retrieval_kwargs')
107
+ for key, value in retrieval_params.items():
108
+ input_kwargs[key] = value
109
+
110
+ input_kwargs['vector_store'] = kwargs.pop('vector_store')
111
+ input_kwargs['keyword_store'] = kwargs.pop('keyword_store')
112
+
113
+ retriever_class = retriever_classes[retriever_type]
114
+ return retriever_class(**input_kwargs)
115
+
116
+ def file2knowledge(self):
117
+ """
118
+ file to knowledge
119
+ """
120
+ df = pd.read_excel(self.question_path)
121
+ if ('文件名' not in df.columns) or ('知识库名' not in df.columns):
122
+ raise Exception(f'文件名 or 知识库名 not in {self.question_path}.')
123
+
124
+ loader_params = self.params['loader']
125
+ loader_object = import_by_type(_type='documentloaders', name=loader_params.pop('type'))
126
+
127
+ all_questions_info = df.to_dict('records')
128
+ collectionname2filename = defaultdict(set)
129
+ for info in all_questions_info:
130
+ # 存入set,去掉重复的文件名
131
+ collectionname2filename[info['知识库名']].add(info['文件名'])
132
+
133
+ for collection_name in tqdm(collectionname2filename):
134
+ all_file_paths = []
135
+ for file_name in collectionname2filename[collection_name]:
136
+ file_path = os.path.join(self.origin_file_path, file_name)
137
+ if not os.path.exists(file_path):
138
+ raise Exception(f'{file_path} not exists.')
139
+ # file path可以是文件夹或者单个文件
140
+ if os.path.isdir(file_path):
141
+ # 文件夹包含多个文件
142
+ all_file_paths.extend(
143
+ [os.path.join(file_path, name) for name in os.listdir(file_path) if not name.startswith('.')]
144
+ )
145
+ else:
146
+ # 单个文件
147
+ all_file_paths.append(file_path)
148
+
149
+ # 当前知识库需要存储的所有文件
150
+ collection_name = f"{collection_name}_{self.params['retriever']['suffix']}"
151
+ for index, each_file_path in enumerate(all_file_paths):
152
+ logger.info(f'each_file_path: {each_file_path}')
153
+ loader = loader_object(
154
+ file_name=os.path.basename(each_file_path), file_path=each_file_path, **loader_params
155
+ )
156
+ documents = loader.load()
157
+ logger.info(f'documents: {len(documents)}')
158
+ if len(documents[0].page_content) == 0:
159
+ logger.error(f'{each_file_path} page_content is empty.')
160
+
161
+ vector_drop_old = self.params['milvus']['drop_old'] if index == 0 else False
162
+ keyword_drop_old = self.params['elasticsearch']['drop_old'] if index == 0 else False
163
+ for idx, retriever in enumerate(self.retriever.retrievers):
164
+ retriever.add_documents(documents, f"{collection_name}_{idx}", vector_drop_old)
165
+ # retriever.add_documents(documents, collection_name, vector_drop_old)
166
+
167
+ def retrieval_and_rerank(self, question, collection_name):
168
+ """
169
+ retrieval and rerank
170
+ """
171
+ collection_name = f"{collection_name}_{self.params['retriever']['suffix']}"
172
+
173
+ # EnsembleRetriever直接检索召回会默认去重
174
+ # docs = self.retriever.get_relevant_documents(query=question, collection_name=collection_name)
175
+ docs = []
176
+ for idx, retriever in enumerate(self.retriever.retrievers):
177
+ docs.extend(retriever.get_relevant_documents(query=question, collection_name=f"{collection_name}_{idx}"))
178
+ # docs.extend(retriever.get_relevant_documents(query=question, collection_name=collection_name))
179
+ logger.info(f'retrieval docs: {len(docs)}')
180
+
181
+ # delete duplicate
182
+ if self.params['post_retrieval']['delete_duplicate']:
183
+ logger.info(f'origin docs: {len(docs)}')
184
+ all_contents = []
185
+ docs_no_dup = []
186
+ for index, doc in enumerate(docs):
187
+ doc_content = doc.page_content
188
+ if doc_content in all_contents:
189
+ continue
190
+ all_contents.append(doc_content)
191
+ docs_no_dup.append(doc)
192
+ docs = docs_no_dup
193
+ logger.info(f'delete duplicate docs: {len(docs)}')
194
+
195
+ # rerank
196
+ if self.params['post_retrieval']['with_rank'] and len(docs):
197
+ if not hasattr(self, 'ranker'):
198
+ rerank_params = self.params['post_retrieval']['rerank']
199
+ rerank_type = rerank_params.pop('type')
200
+ rerank_object = import_class(f'bisheng_langchain.rag.rerank.{rerank_type}')
201
+ self.ranker = rerank_object(**rerank_params)
202
+ docs = getattr(self, 'ranker').sort_and_filter(question, docs)
203
+
204
+ return docs
205
+
206
+ def load_documents(self, file_name, max_content=100000):
207
+ """
208
+ max_content: max content len of llm
209
+ """
210
+ file_path = os.path.join(self.origin_file_path, file_name)
211
+ if not os.path.exists(file_path):
212
+ raise Exception(f'{file_path} not exists.')
213
+ if os.path.isdir(file_path):
214
+ raise Exception(f'{file_path} is a directory.')
215
+
216
+ loader_params = copy.deepcopy(self.params['loader'])
217
+ loader_object = import_by_type(_type='documentloaders', name=loader_params.pop('type'))
218
+ loader = loader_object(file_name=file_name, file_path=file_path, **loader_params)
219
+
220
+ documents = loader.load()
221
+ logger.info(f'documents: {len(documents)}, page_content: {len(documents[0].page_content)}')
222
+ for doc in documents:
223
+ doc.page_content = doc.page_content[:max_content]
224
+ return documents
225
+
226
+ def question_answering(self):
227
+ """
228
+ question answer over knowledge
229
+ """
230
+ df = pd.read_excel(self.question_path)
231
+ all_questions_info = df.to_dict('records')
232
+ if 'prompt_type' in self.params['generate']:
233
+ prompt_type = self.params['generate']['prompt_type']
234
+ prompt = import_class(f'bisheng_langchain.rag.prompts.{prompt_type}')
235
+ else:
236
+ prompt = None
237
+ qa_chain = load_qa_chain(
238
+ llm=self.llm, chain_type=self.params['generate']['chain_type'], prompt=prompt, verbose=False
239
+ )
240
+ file2docs = dict()
241
+ for questions_info in tqdm(all_questions_info):
242
+ question = questions_info['问题']
243
+ file_name = questions_info['文件名']
244
+ collection_name = questions_info['知识库名']
245
+
246
+ if self.params['generate']['with_retrieval']:
247
+ # retrieval and rerank
248
+ docs = self.retrieval_and_rerank(question, collection_name)
249
+ else:
250
+ # load document
251
+ if file_name not in file2docs:
252
+ docs = self.load_documents(file_name)
253
+ file2docs[file_name] = docs
254
+ else:
255
+ docs = file2docs[file_name]
256
+
257
+ # question answer
258
+ try:
259
+ ans = qa_chain({"input_documents": docs, "question": question}, return_only_outputs=False)
260
+ except Exception as e:
261
+ logger.error(f'question: {question}\nerror: {e}')
262
+ ans = {'output_text': str(e)}
263
+
264
+ # context = '\n\n'.join([doc.page_content for doc in docs])
265
+ # content = prompt.format(context=context, question=question)
266
+
267
+ # # for rate_limit
268
+ # time.sleep(15)
269
+
270
+ rag_answer = ans['output_text']
271
+ logger.info(f'question: {question}\nans: {rag_answer}\n')
272
+ questions_info['rag_answer'] = rag_answer
273
+ # questions_info['rag_context'] = '\n----------------\n'.join([doc.page_content for doc in docs])
274
+ # questions_info['rag_context'] = content
275
+
276
+ df = pd.DataFrame(all_questions_info)
277
+ df.to_excel(self.save_answer_path, index=False)
278
+
279
+ def score(self):
280
+ """
281
+ score
282
+ """
283
+ metric_params = self.params['metric']
284
+ if metric_params['type'] == 'bisheng-ragas':
285
+ score_params = {
286
+ 'excel_path': self.save_answer_path,
287
+ 'save_path': os.path.dirname(self.save_answer_path),
288
+ 'question_column': metric_params['question_column'],
289
+ 'gt_column': metric_params['gt_column'],
290
+ 'answer_column': metric_params['answer_column'],
291
+ 'query_type_column': metric_params.get('query_type_column', None),
292
+ 'contexts_column': metric_params.get('contexts_column', None),
293
+ 'metrics': metric_params['metrics'],
294
+ 'batch_size': metric_params['batch_size'],
295
+ 'gt_split_column': metric_params.get('gt_split_column', None),
296
+ 'whether_gtsplit': metric_params.get('whether_gtsplit', False), # 是否需要模型对gt进行要点拆分
297
+ }
298
+ rag_score = RagScore(**score_params)
299
+ rag_score.score()
300
+ else:
301
+ # todo: 其他评分方法
302
+ pass
303
+
304
+
305
+ if __name__ == '__main__':
306
+ parser = argparse.ArgumentParser(description='Process some integers.')
307
+ # 添加参数
308
+ parser.add_argument('--mode', type=str, default='qa', help='upload or qa or score')
309
+ parser.add_argument('--params', type=str, default='config/test/baseline_s2b.yaml', help='bisheng rag params')
310
+ # 解析参数
311
+ args = parser.parse_args()
312
+
313
+ rag = BishengRagPipeline(args.params)
314
+
315
+ if args.mode == 'upload':
316
+ rag.file2knowledge()
317
+ elif args.mode == 'qa':
318
+ rag.question_answering()
319
+ elif args.mode == 'score':
320
+ rag.score()