bisheng-langchain 0.3.0rc1__py3-none-any.whl → 0.3.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bisheng_langchain/document_loaders/elem_unstrcutured_loader.py +5 -3
- bisheng_langchain/gpts/agent_types/llm_functions_agent.py +7 -1
- bisheng_langchain/gpts/assistant.py +8 -5
- bisheng_langchain/gpts/load_tools.py +2 -0
- bisheng_langchain/gpts/prompts/__init__.py +4 -2
- bisheng_langchain/gpts/prompts/assistant_prompt_base.py +1 -0
- bisheng_langchain/gpts/prompts/assistant_prompt_cohere.py +19 -0
- bisheng_langchain/gpts/tools/api_tools/flow.py +3 -3
- bisheng_langchain/gpts/tools/api_tools/openapi.py +101 -0
- bisheng_langchain/rag/__init__.py +5 -0
- bisheng_langchain/rag/bisheng_rag_pipeline.py +320 -0
- bisheng_langchain/rag/bisheng_rag_pipeline_v2.py +359 -0
- bisheng_langchain/rag/bisheng_rag_pipeline_v2_cohere_raw_prompting.py +376 -0
- bisheng_langchain/rag/bisheng_rag_tool.py +288 -0
- bisheng_langchain/rag/config/baseline.yaml +86 -0
- bisheng_langchain/rag/config/baseline_caibao.yaml +82 -0
- bisheng_langchain/rag/config/baseline_caibao_knowledge_v2.yaml +110 -0
- bisheng_langchain/rag/config/baseline_caibao_v2.yaml +112 -0
- bisheng_langchain/rag/config/baseline_demo_v2.yaml +92 -0
- bisheng_langchain/rag/config/baseline_s2b_mix.yaml +88 -0
- bisheng_langchain/rag/config/baseline_v2.yaml +90 -0
- bisheng_langchain/rag/extract_info.py +38 -0
- bisheng_langchain/rag/init_retrievers/__init__.py +4 -0
- bisheng_langchain/rag/init_retrievers/baseline_vector_retriever.py +61 -0
- bisheng_langchain/rag/init_retrievers/keyword_retriever.py +65 -0
- bisheng_langchain/rag/init_retrievers/mix_retriever.py +103 -0
- bisheng_langchain/rag/init_retrievers/smaller_chunks_retriever.py +92 -0
- bisheng_langchain/rag/prompts/__init__.py +9 -0
- bisheng_langchain/rag/prompts/extract_key_prompt.py +34 -0
- bisheng_langchain/rag/prompts/prompt.py +47 -0
- bisheng_langchain/rag/prompts/prompt_cohere.py +111 -0
- bisheng_langchain/rag/qa_corpus/__init__.py +0 -0
- bisheng_langchain/rag/qa_corpus/qa_generator.py +143 -0
- bisheng_langchain/rag/rerank/__init__.py +5 -0
- bisheng_langchain/rag/rerank/rerank.py +48 -0
- bisheng_langchain/rag/rerank/rerank_benchmark.py +139 -0
- bisheng_langchain/rag/run_qa_gen_web.py +47 -0
- bisheng_langchain/rag/run_rag_evaluate_web.py +55 -0
- bisheng_langchain/rag/scoring/__init__.py +0 -0
- bisheng_langchain/rag/scoring/llama_index_score.py +91 -0
- bisheng_langchain/rag/scoring/ragas_score.py +183 -0
- bisheng_langchain/rag/utils.py +181 -0
- bisheng_langchain/retrievers/ensemble.py +2 -1
- bisheng_langchain/vectorstores/elastic_keywords_search.py +2 -1
- {bisheng_langchain-0.3.0rc1.dist-info → bisheng_langchain-0.3.1.1.dist-info}/METADATA +1 -1
- {bisheng_langchain-0.3.0rc1.dist-info → bisheng_langchain-0.3.1.1.dist-info}/RECORD +48 -13
- bisheng_langchain/gpts/prompts/base_prompt.py +0 -1
- {bisheng_langchain-0.3.0rc1.dist-info → bisheng_langchain-0.3.1.1.dist-info}/WHEEL +0 -0
- {bisheng_langchain-0.3.0rc1.dist-info → bisheng_langchain-0.3.1.1.dist-info}/top_level.txt +0 -0
bisheng_langchain/document_loaders/elem_unstrcutured_loader.py

@@ -85,16 +85,18 @@ class ElemUnstructuredLoader(BasePDFLoader):
                              mode='partition',
                              parameters=parameters)

-        resp = requests.post(self.unstructured_api_url, headers=self.headers, json=payload)
+        resp = requests.post(self.unstructured_api_url, headers=self.headers, json=payload)
+        if resp.status_code != 200:
+            raise Exception(f'file partition {os.path.basename(self.file_name)} failed resp={resp.text}')

+        resp = resp.json()
         if 200 != resp.get('status_code'):
-            logger.info(f'
+            logger.info(f'file partition {os.path.basename(self.file_name)} error resp={resp}')
         partitions = resp['partitions']
         if not partitions:
             logger.info(f'partition_error resp={resp}')
         logger.info(f'unstruct_return code={resp.get("status_code")}')

-        partitions = resp['partitions']
         content, metadata = merge_partitions(partitions)
         metadata['source'] = self.file_name

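The new lines make the loader fail fast on a transport-level error instead of calling `.json()` on a non-200 response; previously only the application-level `status_code` inside the JSON body was inspected. A minimal standalone sketch of the same two-level check (the URL and payload below are hypothetical, not the loader's actual request):

```python
import requests

# Hypothetical endpoint and payload, for illustration only.
url = 'http://127.0.0.1:10001/v1/etl4llm/predict'
payload = {'filename': 'report.pdf', 'mode': 'partition'}

resp = requests.post(url, json=payload)
if resp.status_code != 200:
    # Transport-level failure: raise before attempting to parse the body.
    raise Exception(f'file partition report.pdf failed resp={resp.text}')

body = resp.json()
if body.get('status_code') != 200:
    # Application-level failure reported inside an HTTP 200 response.
    print(f'file partition report.pdf error resp={body}')
```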
bisheng_langchain/gpts/agent_types/llm_functions_agent.py

@@ -1,5 +1,5 @@
 import json
-
+import re
 from bisheng_langchain.gpts.message_types import LiberalFunctionMessage, LiberalToolMessage
 from langchain.tools import BaseTool
 from langchain.tools.render import format_tool_to_openai_tool
@@ -39,6 +39,12 @@ def get_openai_functions_agent_executor(tools: list[BaseTool], llm: LanguageMode
         last_message = messages[-1]
         # If there is no function call, then we finish
         if 'tool_calls' not in last_message.additional_kwargs:
+            if '|<instruct>|' in system_message:
+                # cohere model
+                pattern = r"Answer:(.+)\nGrounded answer"
+                match = re.search(pattern, last_message.content)
+                if match:
+                    last_message.content = match.group(1)
             return 'end'
         # Otherwise if there is, we continue
         else:
bisheng_langchain/gpts/assistant.py

@@ -126,11 +126,14 @@ class BishengAssistant:
 if __name__ == "__main__":
     from langchain.globals import set_debug

-    set_debug(True)
-
-
-
-
+    # set_debug(True)
+    chat_history = []
+    query = "请简要分析中科创达软件股份有限公司2019年聘任、解聘会计师事务的情况。"
+    # chat_history = ['你好', '你好,有什么可以帮助你吗?', '福蓉科技股价多少?', '福蓉科技(股票代码:300049)的当前股价为48.67元。']
+    # query = '去年这个时候的股价是多少?'
+    # bisheng_assistant = BishengAssistant("config/base_scene.yaml")
+    # bisheng_assistant = BishengAssistant("config/knowledge_scene.yaml")
+    bisheng_assistant = BishengAssistant("config/rag_scene.yaml")
     result = bisheng_assistant.run(query, chat_history=chat_history)
     for r in result:
         print(f'------------------')
bisheng_langchain/gpts/load_tools.py

@@ -26,6 +26,7 @@ from langchain_core.callbacks import BaseCallbackManager, Callbacks
 from langchain_core.language_models import BaseLanguageModel
 from langchain_core.tools import BaseTool, Tool
 from mypy_extensions import Arg, KwArg
+from bisheng_langchain.rag import BishengRAGTool


 def _get_current_time() -> BaseTool:

@@ -86,6 +87,7 @@ _EXTRA_PARAM_TOOLS: Dict[str, Tuple[Callable[[KwArg(Any)], BaseTool], List[Optio
     'dalle_image_generator': (_get_dalle_image_generator, ['openai_api_key', 'openai_proxy'], []),
     'bing_search': (_get_bing_search, ['bing_subscription_key', 'bing_search_url'], []),
     'bisheng_code_interpreter': (_get_native_code_interpreter, ["minio"], ['files']),
+    'bisheng_rag': (BishengRAGTool.get_rag_tool, ['name', 'description'], ['vector_store', 'keyword_store', 'llm', 'collection_name', 'max_content', 'sort_by_source_and_index']),
 }

 _API_TOOLS: Dict[str, Tuple[Callable[[KwArg(Any)], BaseTool], List[str]]] = {**ALL_API_TOOLS}  # type: ignore
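Per the registry tuple, `name` and `description` are the required parameters and the rest are optional. A sketch of what the entry resolves to when the tool loader invokes it (the concrete values are hypothetical, and whether usable defaults exist for the optional stores and llm is not visible in this diff):

```python
from bisheng_langchain.rag import BishengRAGTool

# 'name' and 'description' are required by the registry entry above;
# the remaining keyword arguments are the optional ones it lists.
tool = BishengRAGTool.get_rag_tool(
    name='bisheng_rag',
    description='Answer questions over the indexed knowledge base.',
    collection_name='rag_finance_report_benchmark',  # hypothetical
    max_content=15000,                               # hypothetical
)
```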
bisheng_langchain/gpts/prompts/__init__.py

@@ -1,12 +1,14 @@
 from bisheng_langchain.gpts.prompts.assistant_prompt_opt import ASSISTANT_PROMPT_OPT
-from bisheng_langchain.gpts.prompts.
+from bisheng_langchain.gpts.prompts.assistant_prompt_base import ASSISTANT_PROMPT_DEFAULT
+from bisheng_langchain.gpts.prompts.assistant_prompt_cohere import ASSISTANT_PROMPT_COHERE
 from bisheng_langchain.gpts.prompts.breif_description_prompt import BREIF_DES_PROMPT
 from bisheng_langchain.gpts.prompts.opening_dialog_prompt import OPENDIALOG_PROMPT
 from bisheng_langchain.gpts.prompts.select_tools_prompt import HUMAN_MSG, SYS_MSG


 __all__ = [
-    "
+    "ASSISTANT_PROMPT_DEFAULT",
+    "ASSISTANT_PROMPT_COHERE",
     "ASSISTANT_PROMPT_OPT",
     "OPENDIALOG_PROMPT",
     "BREIF_DES_PROMPT",
bisheng_langchain/gpts/prompts/assistant_prompt_base.py

@@ -0,0 +1 @@
+ASSISTANT_PROMPT_DEFAULT = "You are a helpful assistant."
bisheng_langchain/gpts/prompts/assistant_prompt_cohere.py

@@ -0,0 +1,19 @@
+preamble="""You are a helpful assistant.
+"""
+
+ASSISTANT_PROMPT_COHERE="""{preamble}|<instruct>|Carefully perform the following instructions, in order, starting each with a new line.
+Firstly, You may need to use complex and advanced reasoning to complete your task and answer the question. Think about how you can use the provided tools to answer the question and come up with a high level plan you will execute.
+Write 'Plan:' followed by an initial high level plan of how you will solve the problem including the tools and steps required.
+Secondly, Carry out your plan by repeatedly using actions, reasoning over the results, and re-evaluating your plan. Perform Action, Observation, Reflection steps with the following format. Write 'Action:' followed by a json formatted action containing the "tool_name" and "parameters"
+Next you will analyze the 'Observation:', this is the result of the action.
+After that you should always think about what to do next. Write 'Reflection:' followed by what you've figured out so far, any changes you need to make to your plan, and what you will do next including if you know the answer to the question.
+... (this Action/Observation/Reflection can repeat N times)
+Thirdly, Decide which of the retrieved documents are relevant to the user's last input by writing 'Relevant Documents:' followed by comma-separated list of document numbers. If none are relevant, you should instead write 'None'.
+Fourthly, Decide which of the retrieved documents contain facts that should be cited in a good answer to the user's last input by writing 'Cited Documents:' followed a comma-separated list of document numbers. If you dont want to cite any of them, you should instead write 'None'.
+Fifthly, Write 'Answer:' followed by a response to the user's last input. Use the retrieved documents to help you. Do not insert any citations or grounding markup.
+Finally, Write 'Grounded answer:' followed by a response to the user's last input in high quality natural english. Use the symbols <co: doc> and </co: doc> to indicate when a fact comes from a document in the search result, e.g <co: 4>my fact</co: 4> for a fact from document 4.
+
+Additional instructions to note:
+- If the user's question is in Chinese, please answer it in Chinese.
+- 当问题中有涉及到时间信息时,比如最近6个月、昨天、去年等,你需要用时间工具查询时间信息。
+""".format(preamble=preamble)
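Note the `|<instruct>|` token embedded in the template: it appears to be the project's own sentinel rather than Cohere syntax, and it is exactly what the `'|<instruct>|' in system_message` check in `llm_functions_agent.py` (above) keys on to enable the Answer-extraction regex. It also cleanly separates the preamble from the instruction block, as a quick sketch shows:

```python
from bisheng_langchain.gpts.prompts.assistant_prompt_cohere import ASSISTANT_PROMPT_COHERE

# The marker splits the formatted prompt into its two halves.
preamble, instructions = ASSISTANT_PROMPT_COHERE.split('|<instruct>|', 1)
print(preamble)           # "You are a helpful assistant.\n"
print(instructions[:40])  # start of the step-by-step instruction block
```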
bisheng_langchain/gpts/tools/api_tools/flow.py

@@ -1,10 +1,11 @@
 from loguru import logger
-from
+from langchain_core.pydantic_v1 import BaseModel, Field
 from typing import Any
 from .base import APIToolBase
 from .base import MultArgsSchemaTool
 from langchain_core.tools import BaseTool

+
 class FlowTools(APIToolBase):

     def run(self, query: str) -> str:

@@ -52,9 +53,8 @@ class FlowTools(APIToolBase):
         }

         class InputArgs(BaseModel):
-            """args_schema"""
             query: str = Field(description='questions to ask')
-
+
         return cls(url=url, params=params, input_key=input_key, args_schema=InputArgs)

     @classmethod
bisheng_langchain/gpts/tools/api_tools/openapi.py

@@ -0,0 +1,101 @@
+from typing import Any
+
+from langchain_core.tools import BaseTool
+from loguru import logger
+from pydantic import BaseModel, create_model
+
+from .base import APIToolBase, MultArgsSchemaTool, Field
+
+
+class OpenApiTools(APIToolBase):
+
+    def get_real_path(self):
+        return self.url + self.params["path"]
+
+    def get_request_method(self):
+        return self.params["method"]
+
+    def get_params_json(self, **kwargs):
+        params_define = {}
+        for one in self.params["parameters"]:
+            params_define[one["name"]] = one
+
+        params = {}
+        json_data = {}
+        for k, v in kwargs.items():
+            if params_define.get(k):
+                if params_define[k]["in"] == "query":
+                    params[k] = v
+                else:
+                    json_data[k] = v
+            else:
+                params[k] = v
+        return params, json_data
+
+    def parse_args_schema(self):
+        params = self.params["parameters"]
+        model_params = {}
+        for one in params:
+            field_type = one["schema"]["type"]
+            if field_type == "number":
+                field_type = "float"
+            elif field_type == "integer":
+                field_type = "int"
+            elif field_type == "string":
+                field_type = "str"
+            else:
+                raise Exception(f"schema type is not support: {field_type}")
+            model_params[one["name"]] = (eval(field_type), Field(description=one["description"]))
+        return create_model("InputArgs", __module__='bisheng_langchain.gpts.tools.api_tools.openapi',
+                            __base__=BaseModel, **model_params)
+
+    def run(self, **kwargs) -> str:
+        """Run query through api and parse result."""
+        path = self.get_real_path()
+        logger.info('api_call url={}', path)
+        method = self.get_request_method()
+        params, json_data = self.get_params_json(**kwargs)
+
+        if method == "get":
+            resp = self.client.get(path, params=params)
+        elif method == 'post':
+            resp = self.client.post(path, params=params, json=self.params)
+        elif method == 'put':
+            resp = self.client.put(path, params=params, json=self.params)
+        elif method == 'delete':
+            resp = self.client.delete(path, params=params, json=self.params)
+        else:
+            raise Exception(f"http method is not support: {method}")
+        if resp.status_code != 200:
+            logger.info(f'api_call_fail code={resp.status_code} res={resp.text}')
+            raise Exception(f"api_call_fail: {resp.status_code} {resp.text}")
+        return resp.text
+
+    async def arun(self, **kwargs) -> str:
+        """Run query through api and parse result."""
+        path = self.get_real_path()
+        logger.info('api_call url={}', path)
+        method = self.get_request_method()
+        params, json_data = self.get_params_json(**kwargs)
+
+        if method == "get":
+            resp = await self.async_client.aget(path, params=params)
+        elif method == 'post':
+            resp = await self.async_client.apost(path, params=params, json=self.params)
+        elif method == 'put':
+            resp = await self.async_client.aput(path, params=params, json=self.params)
+        elif method == 'delete':
+            resp = await self.async_client.adelete(path, params=params, json=self.params)
+        else:
+            raise Exception(f"http method is not support: {method}")
+        return resp
+
+    @classmethod
+    def get_api_tool(cls, name, **kwargs: Any) -> BaseTool:
+        description = kwargs.pop("description", "")
+        obj = cls(**kwargs)
+        return MultArgsSchemaTool(name=name,
+                                  description=description,
+                                  func=obj.run,
+                                  coroutine=obj.arun,
+                                  args_schema=obj.parse_args_schema())
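A usage sketch for the new class, with a hypothetical GET endpoint; the `url`/`params` constructor keywords are assumptions inferred from `get_real_path` and the fields used above, not a documented API. Note also that the non-GET branches send `json=self.params` (the endpoint definition itself) rather than the `json_data` computed by `get_params_json`, which looks like it may be unintentional.

```python
from bisheng_langchain.gpts.tools.api_tools.openapi import OpenApiTools

# Hypothetical endpoint description in the shape get_params_json and
# parse_args_schema expect: a path, a method, and OpenAPI-style parameters.
endpoint = {
    'path': '/weather',
    'method': 'get',
    'parameters': [
        {'name': 'city', 'in': 'query', 'schema': {'type': 'string'},
         'description': 'city name to query'},
    ],
}

tool = OpenApiTools.get_api_tool(
    name='weather_query',
    description='Query current weather by city.',
    url='https://api.example.com',  # assumed constructor kwarg
    params=endpoint,
)
# tool.args_schema is a pydantic model created on the fly with a `city: str` field.
```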
bisheng_langchain/rag/bisheng_rag_pipeline.py

@@ -0,0 +1,320 @@
+import argparse
+import copy
+import inspect
+import time
+import os
+from collections import defaultdict
+
+import httpx
+import pandas as pd
+import yaml
+from loguru import logger
+from tqdm import tqdm
+from bisheng_langchain.retrievers import EnsembleRetriever
+from bisheng_langchain.vectorstores import ElasticKeywordsSearch, Milvus
+from langchain.chains.question_answering import load_qa_chain
+from bisheng_langchain.rag.init_retrievers import (
+    BaselineVectorRetriever,
+    KeywordRetriever,
+    MixRetriever,
+    SmallerChunksVectorRetriever,
+)
+from bisheng_langchain.rag.scoring.ragas_score import RagScore
+from bisheng_langchain.rag.utils import import_by_type, import_class
+
+
+class BishengRagPipeline:
+
+    def __init__(self, yaml_path) -> None:
+        self.yaml_path = yaml_path
+        with open(self.yaml_path, 'r') as f:
+            self.params = yaml.safe_load(f)
+
+        # init data
+        self.origin_file_path = self.params['data']['origin_file_path']
+        self.question_path = self.params['data']['question']
+        self.save_answer_path = self.params['data']['save_answer']
+
+        # init embeddings
+        embedding_params = self.params['embedding']
+        embedding_object = import_by_type(_type='embeddings', name=embedding_params['type'])
+        if embedding_params['type'] == 'OpenAIEmbeddings' and embedding_params['openai_proxy']:
+            embedding_params.pop('type')
+            self.embeddings = embedding_object(
+                http_client=httpx.Client(proxies=embedding_params['openai_proxy']), **embedding_params
+            )
+        else:
+            embedding_params.pop('type')
+            self.embeddings = embedding_object(**embedding_params)
+
+        # init llm
+        llm_params = self.params['chat_llm']
+        llm_object = import_by_type(_type='llms', name=llm_params['type'])
+        if llm_params['type'] == 'ChatOpenAI' and llm_params['openai_proxy']:
+            llm_params.pop('type')
+            self.llm = llm_object(http_client=httpx.Client(proxies=llm_params['openai_proxy']), **llm_params)
+        else:
+            llm_params.pop('type')
+            self.llm = llm_object(**llm_params)
+
+        # milvus
+        self.vector_store = Milvus(
+            embedding_function=self.embeddings,
+            connection_args={
+                "host": self.params['milvus']['host'],
+                "port": self.params['milvus']['port'],
+            },
+        )
+
+        # es
+        self.keyword_store = ElasticKeywordsSearch(
+            index_name='default_es',
+            elasticsearch_url=self.params['elasticsearch']['url'],
+            ssl_verify=self.params['elasticsearch']['ssl_verify'],
+        )
+
+        # init retriever
+        retriever_list = []
+        retrievers = self.params['retriever']['retrievers']
+        for retriever in retrievers:
+            retriever_type = retriever.pop('type')
+            retriever_params = {
+                'vector_store': self.vector_store,
+                'keyword_store': self.keyword_store,
+                'splitter_kwargs': retriever['splitter'],
+                'retrieval_kwargs': retriever['retrieval'],
+            }
+            retriever_list.append(self._post_init_retriever(retriever_type=retriever_type, **retriever_params))
+        self.retriever = EnsembleRetriever(retrievers=retriever_list)
+
+    def _post_init_retriever(self, retriever_type, **kwargs):
+        retriever_classes = {
+            'KeywordRetriever': KeywordRetriever,
+            'BaselineVectorRetriever': BaselineVectorRetriever,
+            'MixRetriever': MixRetriever,
+            'SmallerChunksVectorRetriever': SmallerChunksVectorRetriever,
+        }
+        if retriever_type not in retriever_classes:
+            raise ValueError(f'Unknown retriever type: {retriever_type}')
+
+        input_kwargs = {}
+        splitter_params = kwargs.pop('splitter_kwargs')
+        for key, value in splitter_params.items():
+            splitter_obj = import_by_type(_type='textsplitters', name=value.pop('type'))
+            input_kwargs[key] = splitter_obj(**value)
+
+        retrieval_params = kwargs.pop('retrieval_kwargs')
+        for key, value in retrieval_params.items():
+            input_kwargs[key] = value
+
+        input_kwargs['vector_store'] = kwargs.pop('vector_store')
+        input_kwargs['keyword_store'] = kwargs.pop('keyword_store')
+
+        retriever_class = retriever_classes[retriever_type]
+        return retriever_class(**input_kwargs)
+
+    def file2knowledge(self):
+        """
+        file to knowledge
+        """
+        df = pd.read_excel(self.question_path)
+        if ('文件名' not in df.columns) or ('知识库名' not in df.columns):
+            raise Exception(f'文件名 or 知识库名 not in {self.question_path}.')
+
+        loader_params = self.params['loader']
+        loader_object = import_by_type(_type='documentloaders', name=loader_params.pop('type'))
+
+        all_questions_info = df.to_dict('records')
+        collectionname2filename = defaultdict(set)
+        for info in all_questions_info:
+            # store in a set to drop duplicate file names
+            collectionname2filename[info['知识库名']].add(info['文件名'])
+
+        for collection_name in tqdm(collectionname2filename):
+            all_file_paths = []
+            for file_name in collectionname2filename[collection_name]:
+                file_path = os.path.join(self.origin_file_path, file_name)
+                if not os.path.exists(file_path):
+                    raise Exception(f'{file_path} not exists.')
+                # file_path may be a directory or a single file
+                if os.path.isdir(file_path):
+                    # a directory containing multiple files
+                    all_file_paths.extend(
+                        [os.path.join(file_path, name) for name in os.listdir(file_path) if not name.startswith('.')]
+                    )
+                else:
+                    # a single file
+                    all_file_paths.append(file_path)
+
+            # all files to store for the current knowledge base
+            collection_name = f"{collection_name}_{self.params['retriever']['suffix']}"
+            for index, each_file_path in enumerate(all_file_paths):
+                logger.info(f'each_file_path: {each_file_path}')
+                loader = loader_object(
+                    file_name=os.path.basename(each_file_path), file_path=each_file_path, **loader_params
+                )
+                documents = loader.load()
+                logger.info(f'documents: {len(documents)}')
+                if len(documents[0].page_content) == 0:
+                    logger.error(f'{each_file_path} page_content is empty.')
+
+                vector_drop_old = self.params['milvus']['drop_old'] if index == 0 else False
+                keyword_drop_old = self.params['elasticsearch']['drop_old'] if index == 0 else False
+                for idx, retriever in enumerate(self.retriever.retrievers):
+                    retriever.add_documents(documents, f"{collection_name}_{idx}", vector_drop_old)
+                    # retriever.add_documents(documents, collection_name, vector_drop_old)
+
+    def retrieval_and_rerank(self, question, collection_name):
+        """
+        retrieval and rerank
+        """
+        collection_name = f"{collection_name}_{self.params['retriever']['suffix']}"
+
+        # EnsembleRetriever deduplicates recalled docs by default when used directly
+        # docs = self.retriever.get_relevant_documents(query=question, collection_name=collection_name)
+        docs = []
+        for idx, retriever in enumerate(self.retriever.retrievers):
+            docs.extend(retriever.get_relevant_documents(query=question, collection_name=f"{collection_name}_{idx}"))
+            # docs.extend(retriever.get_relevant_documents(query=question, collection_name=collection_name))
+        logger.info(f'retrieval docs: {len(docs)}')
+
+        # delete duplicate
+        if self.params['post_retrieval']['delete_duplicate']:
+            logger.info(f'origin docs: {len(docs)}')
+            all_contents = []
+            docs_no_dup = []
+            for index, doc in enumerate(docs):
+                doc_content = doc.page_content
+                if doc_content in all_contents:
+                    continue
+                all_contents.append(doc_content)
+                docs_no_dup.append(doc)
+            docs = docs_no_dup
+            logger.info(f'delete duplicate docs: {len(docs)}')
+
+        # rerank
+        if self.params['post_retrieval']['with_rank'] and len(docs):
+            if not hasattr(self, 'ranker'):
+                rerank_params = self.params['post_retrieval']['rerank']
+                rerank_type = rerank_params.pop('type')
+                rerank_object = import_class(f'bisheng_langchain.rag.rerank.{rerank_type}')
+                self.ranker = rerank_object(**rerank_params)
+            docs = getattr(self, 'ranker').sort_and_filter(question, docs)
+
+        return docs
+
+    def load_documents(self, file_name, max_content=100000):
+        """
+        max_content: max content len of llm
+        """
+        file_path = os.path.join(self.origin_file_path, file_name)
+        if not os.path.exists(file_path):
+            raise Exception(f'{file_path} not exists.')
+        if os.path.isdir(file_path):
+            raise Exception(f'{file_path} is a directory.')
+
+        loader_params = copy.deepcopy(self.params['loader'])
+        loader_object = import_by_type(_type='documentloaders', name=loader_params.pop('type'))
+        loader = loader_object(file_name=file_name, file_path=file_path, **loader_params)
+
+        documents = loader.load()
+        logger.info(f'documents: {len(documents)}, page_content: {len(documents[0].page_content)}')
+        for doc in documents:
+            doc.page_content = doc.page_content[:max_content]
+        return documents
+
+    def question_answering(self):
+        """
+        question answer over knowledge
+        """
+        df = pd.read_excel(self.question_path)
+        all_questions_info = df.to_dict('records')
+        if 'prompt_type' in self.params['generate']:
+            prompt_type = self.params['generate']['prompt_type']
+            prompt = import_class(f'bisheng_langchain.rag.prompts.{prompt_type}')
+        else:
+            prompt = None
+        qa_chain = load_qa_chain(
+            llm=self.llm, chain_type=self.params['generate']['chain_type'], prompt=prompt, verbose=False
+        )
+        file2docs = dict()
+        for questions_info in tqdm(all_questions_info):
+            question = questions_info['问题']
+            file_name = questions_info['文件名']
+            collection_name = questions_info['知识库名']
+
+            if self.params['generate']['with_retrieval']:
+                # retrieval and rerank
+                docs = self.retrieval_and_rerank(question, collection_name)
+            else:
+                # load document
+                if file_name not in file2docs:
+                    docs = self.load_documents(file_name)
+                    file2docs[file_name] = docs
+                else:
+                    docs = file2docs[file_name]
+
+            # question answer
+            try:
+                ans = qa_chain({"input_documents": docs, "question": question}, return_only_outputs=False)
+            except Exception as e:
+                logger.error(f'question: {question}\nerror: {e}')
+                ans = {'output_text': str(e)}
+
+            # context = '\n\n'.join([doc.page_content for doc in docs])
+            # content = prompt.format(context=context, question=question)
+
+            # # for rate_limit
+            # time.sleep(15)
+
+            rag_answer = ans['output_text']
+            logger.info(f'question: {question}\nans: {rag_answer}\n')
+            questions_info['rag_answer'] = rag_answer
+            # questions_info['rag_context'] = '\n----------------\n'.join([doc.page_content for doc in docs])
+            # questions_info['rag_context'] = content
+
+        df = pd.DataFrame(all_questions_info)
+        df.to_excel(self.save_answer_path, index=False)
+
+    def score(self):
+        """
+        score
+        """
+        metric_params = self.params['metric']
+        if metric_params['type'] == 'bisheng-ragas':
+            score_params = {
+                'excel_path': self.save_answer_path,
+                'save_path': os.path.dirname(self.save_answer_path),
+                'question_column': metric_params['question_column'],
+                'gt_column': metric_params['gt_column'],
+                'answer_column': metric_params['answer_column'],
+                'query_type_column': metric_params.get('query_type_column', None),
+                'contexts_column': metric_params.get('contexts_column', None),
+                'metrics': metric_params['metrics'],
+                'batch_size': metric_params['batch_size'],
+                'gt_split_column': metric_params.get('gt_split_column', None),
+                'whether_gtsplit': metric_params.get('whether_gtsplit', False),  # whether the model should split the ground truth into key points
+            }
+            rag_score = RagScore(**score_params)
+            rag_score.score()
+        else:
+            # todo: other scoring methods
+            pass
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='Process some integers.')
+    # add arguments
+    parser.add_argument('--mode', type=str, default='qa', help='upload or qa or score')
+    parser.add_argument('--params', type=str, default='config/test/baseline_s2b.yaml', help='bisheng rag params')
+    # parse arguments
+    args = parser.parse_args()
+
+    rag = BishengRagPipeline(args.params)
+
+    if args.mode == 'upload':
+        rag.file2knowledge()
+    elif args.mode == 'qa':
+        rag.question_answering()
+    elif args.mode == 'score':
+        rag.score()