bisheng-langchain 0.3.0rc0__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bisheng_langchain/chat_models/host_llm.py +1 -1
- bisheng_langchain/document_loaders/elem_unstrcutured_loader.py +5 -3
- bisheng_langchain/gpts/agent_types/llm_functions_agent.py +7 -1
- bisheng_langchain/gpts/assistant.py +8 -5
- bisheng_langchain/gpts/auto_optimization.py +28 -27
- bisheng_langchain/gpts/auto_tool_selected.py +14 -15
- bisheng_langchain/gpts/load_tools.py +53 -1
- bisheng_langchain/gpts/prompts/__init__.py +4 -2
- bisheng_langchain/gpts/prompts/assistant_prompt_base.py +1 -0
- bisheng_langchain/gpts/prompts/assistant_prompt_cohere.py +19 -0
- bisheng_langchain/gpts/prompts/opening_dialog_prompt.py +1 -1
- bisheng_langchain/gpts/tools/api_tools/__init__.py +1 -1
- bisheng_langchain/gpts/tools/api_tools/base.py +3 -3
- bisheng_langchain/gpts/tools/api_tools/flow.py +19 -7
- bisheng_langchain/gpts/tools/api_tools/macro_data.py +175 -4
- bisheng_langchain/gpts/tools/api_tools/openapi.py +101 -0
- bisheng_langchain/gpts/tools/api_tools/sina.py +2 -2
- bisheng_langchain/gpts/tools/code_interpreter/tool.py +118 -39
- bisheng_langchain/rag/__init__.py +5 -0
- bisheng_langchain/rag/bisheng_rag_pipeline.py +320 -0
- bisheng_langchain/rag/bisheng_rag_pipeline_v2.py +359 -0
- bisheng_langchain/rag/bisheng_rag_pipeline_v2_cohere_raw_prompting.py +376 -0
- bisheng_langchain/rag/bisheng_rag_tool.py +288 -0
- bisheng_langchain/rag/config/baseline.yaml +86 -0
- bisheng_langchain/rag/config/baseline_caibao.yaml +82 -0
- bisheng_langchain/rag/config/baseline_caibao_knowledge_v2.yaml +110 -0
- bisheng_langchain/rag/config/baseline_caibao_v2.yaml +112 -0
- bisheng_langchain/rag/config/baseline_demo_v2.yaml +92 -0
- bisheng_langchain/rag/config/baseline_s2b_mix.yaml +88 -0
- bisheng_langchain/rag/config/baseline_v2.yaml +90 -0
- bisheng_langchain/rag/extract_info.py +38 -0
- bisheng_langchain/rag/init_retrievers/__init__.py +4 -0
- bisheng_langchain/rag/init_retrievers/baseline_vector_retriever.py +61 -0
- bisheng_langchain/rag/init_retrievers/keyword_retriever.py +65 -0
- bisheng_langchain/rag/init_retrievers/mix_retriever.py +103 -0
- bisheng_langchain/rag/init_retrievers/smaller_chunks_retriever.py +92 -0
- bisheng_langchain/rag/prompts/__init__.py +9 -0
- bisheng_langchain/rag/prompts/extract_key_prompt.py +34 -0
- bisheng_langchain/rag/prompts/prompt.py +47 -0
- bisheng_langchain/rag/prompts/prompt_cohere.py +111 -0
- bisheng_langchain/rag/qa_corpus/__init__.py +0 -0
- bisheng_langchain/rag/qa_corpus/qa_generator.py +143 -0
- bisheng_langchain/rag/rerank/__init__.py +5 -0
- bisheng_langchain/rag/rerank/rerank.py +48 -0
- bisheng_langchain/rag/rerank/rerank_benchmark.py +139 -0
- bisheng_langchain/rag/run_qa_gen_web.py +47 -0
- bisheng_langchain/rag/run_rag_evaluate_web.py +55 -0
- bisheng_langchain/rag/scoring/__init__.py +0 -0
- bisheng_langchain/rag/scoring/llama_index_score.py +91 -0
- bisheng_langchain/rag/scoring/ragas_score.py +183 -0
- bisheng_langchain/rag/utils.py +181 -0
- bisheng_langchain/retrievers/ensemble.py +2 -1
- bisheng_langchain/vectorstores/elastic_keywords_search.py +2 -1
- {bisheng_langchain-0.3.0rc0.dist-info → bisheng_langchain-0.3.1.dist-info}/METADATA +1 -1
- {bisheng_langchain-0.3.0rc0.dist-info → bisheng_langchain-0.3.1.dist-info}/RECORD +57 -22
- bisheng_langchain/gpts/prompts/base_prompt.py +0 -1
- {bisheng_langchain-0.3.0rc0.dist-info → bisheng_langchain-0.3.1.dist-info}/WHEEL +0 -0
- {bisheng_langchain-0.3.0rc0.dist-info → bisheng_langchain-0.3.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,101 @@
|
|
1
|
+
from typing import Any
|
2
|
+
|
3
|
+
from langchain_core.tools import BaseTool
|
4
|
+
from loguru import logger
|
5
|
+
from pydantic import BaseModel, create_model
|
6
|
+
|
7
|
+
from .base import APIToolBase, MultArgsSchemaTool, Field
|
8
|
+
|
9
|
+
|
10
|
+
class OpenApiTools(APIToolBase):
    """Tool that calls one HTTP operation described by an OpenAPI-style spec.

    ``self.params`` is read as the operation spec: ``path`` (appended to
    ``self.url``), ``method`` and ``parameters`` (a list of dicts with
    ``name``, ``in``, ``schema.type`` and ``description``).
    """

    def get_real_path(self):
        """Return the full request URL: ``self.url`` followed by the spec path."""
        return self.url + self.params["path"]

    def get_request_method(self):
        """Return the HTTP method exactly as written in the spec."""
        return self.params["method"]

    def get_params_json(self, **kwargs):
        """Split call arguments into query-string parameters and a JSON body.

        Arguments declared in the spec with ``in == "query"`` go to the query
        string, other declared arguments go to the JSON body, and arguments not
        declared in the spec go to the query string.

        Returns:
            A ``(params, json_data)`` tuple of dicts.
        """
        # ``or []``: an operation without parameters may omit the key or set it to null.
        params_define = {one["name"]: one for one in self.params.get("parameters") or []}

        params = {}
        json_data = {}
        for k, v in kwargs.items():
            spec = params_define.get(k)
            if spec and spec["in"] != "query":
                json_data[k] = v
            else:
                params[k] = v
        return params, json_data

    def parse_args_schema(self):
        """Build a pydantic model describing the tool's input arguments.

        Returns:
            A model class named ``InputArgs`` with one field per spec parameter.

        Raises:
            Exception: if a parameter's ``schema.type`` is not one of
                ``number``, ``integer``, ``string`` or ``boolean``.
        """
        # Explicit mapping instead of eval() on a type name taken from the spec.
        type_map = {"number": float, "integer": int, "string": str, "boolean": bool}
        model_params = {}
        for one in self.params.get("parameters") or []:
            field_type = one["schema"]["type"]
            python_type = type_map.get(field_type) if isinstance(field_type, str) else None
            if python_type is None:
                raise Exception(f"schema type is not support: {field_type}")
            model_params[one["name"]] = (python_type, Field(description=one.get("description", "")))
        return create_model("InputArgs", __module__='bisheng_langchain.gpts.tools.api_tools.openapi',
                            __base__=BaseModel, **model_params)

    def run(self, **kwargs) -> str:
        """Call the API synchronously and return the response body text.

        Raises:
            Exception: if the HTTP method is unsupported or the response status
                code is not 200.
        """
        path = self.get_real_path()
        logger.info('api_call url={}', path)
        # Specs may spell the method in upper case ("GET"); compare case-insensitively.
        method = self.get_request_method().lower()
        params, json_data = self.get_params_json(**kwargs)

        # The request body is the arguments routed to json_data, not the tool's spec.
        if method == "get":
            resp = self.client.get(path, params=params)
        elif method == 'post':
            resp = self.client.post(path, params=params, json=json_data)
        elif method == 'put':
            resp = self.client.put(path, params=params, json=json_data)
        elif method == 'delete':
            resp = self.client.delete(path, params=params, json=json_data)
        else:
            raise Exception(f"http method is not support: {method}")
        if resp.status_code != 200:
            logger.info(f'api_call_fail code={resp.status_code} res={resp.text}')
            raise Exception(f"api_call_fail: {resp.status_code} {resp.text}")
        return resp.text

    async def arun(self, **kwargs) -> str:
        """Call the API asynchronously and return whatever the async client returns.

        Raises:
            Exception: if the HTTP method is unsupported.
        """
        path = self.get_real_path()
        logger.info('api_call url={}', path)
        method = self.get_request_method().lower()
        params, json_data = self.get_params_json(**kwargs)

        if method == "get":
            resp = await self.async_client.aget(path, params=params)
        elif method == 'post':
            resp = await self.async_client.apost(path, params=params, json=json_data)
        elif method == 'put':
            resp = await self.async_client.aput(path, params=params, json=json_data)
        elif method == 'delete':
            resp = await self.async_client.adelete(path, params=params, json=json_data)
        else:
            raise Exception(f"http method is not support: {method}")
        # NOTE(review): no status check here, unlike run(); the shape of the
        # async client's return value is not visible from this file.
        return resp

    @classmethod
    def get_api_tool(cls, name, **kwargs: Any) -> BaseTool:
        """Build a multi-argument tool wrapping ``run``/``arun`` of a new instance.

        ``description`` is taken out of ``kwargs``; the rest initialises the
        instance.
        """
        description = kwargs.pop("description", "")
        obj = cls(**kwargs)
        return MultArgsSchemaTool(name=name,
                                  description=description,
                                  func=obj.run,
                                  coroutine=obj.arun,
                                  args_schema=obj.parse_args_schema())
|
@@ -154,7 +154,7 @@ class StockInfo(APIToolBase):
|
|
154
154
|
resp = super().run(query=stock_number)
|
155
155
|
stock = self.devideStock(resp)[0]
|
156
156
|
if isinstance(stock, Stock):
|
157
|
-
return json.dumps(stock.__dict__)
|
157
|
+
return json.dumps(stock.__dict__, ensure_ascii=False)
|
158
158
|
else:
|
159
159
|
return stock
|
160
160
|
|
@@ -183,7 +183,7 @@ class StockInfo(APIToolBase):
|
|
183
183
|
resp = await super().arun(query=stock_number)
|
184
184
|
stock = self.devideStock(resp)[0]
|
185
185
|
if isinstance(stock, Stock):
|
186
|
-
return json.dumps(stock.__dict__)
|
186
|
+
return json.dumps(stock.__dict__, ensure_ascii=False)
|
187
187
|
else:
|
188
188
|
return stock
|
189
189
|
|
@@ -1,6 +1,8 @@
|
|
1
|
+
import glob
|
1
2
|
import itertools
|
2
3
|
import os
|
3
4
|
import pathlib
|
5
|
+
import re
|
4
6
|
import subprocess
|
5
7
|
import sys
|
6
8
|
import tempfile
|
@@ -11,24 +13,18 @@ from pathlib import Path
|
|
11
13
|
from typing import Dict, List, Optional, Tuple, Type
|
12
14
|
from uuid import uuid4
|
13
15
|
|
14
|
-
|
16
|
+
import matplotlib
|
15
17
|
from langchain_community.tools import Tool
|
16
18
|
from langchain_core.pydantic_v1 import BaseModel, Field
|
17
19
|
from loguru import logger
|
18
20
|
|
19
|
-
|
20
|
-
from termcolor import colored
|
21
|
-
except ImportError:
|
22
|
-
|
23
|
-
def colored(x, *args, **kwargs):
|
24
|
-
return x
|
25
|
-
|
26
|
-
|
21
|
+
CODE_BLOCK_PATTERN = r"```(\w*)\n(.*?)\n```"
|
27
22
|
DEFAULT_TIMEOUT = 600
|
28
23
|
WIN32 = sys.platform == 'win32'
|
29
24
|
PATH_SEPARATOR = WIN32 and '\\' or '/'
|
30
25
|
WORKING_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'extensions')
|
31
26
|
TIMEOUT_MSG = 'Timeout'
|
27
|
+
UNKNOWN = "unknown"
|
32
28
|
|
33
29
|
|
34
30
|
def _cmd(lang):
|
@@ -41,6 +37,61 @@ def _cmd(lang):
|
|
41
37
|
raise NotImplementedError(f'{lang} not recognized in code execution')
|
42
38
|
|
43
39
|
|
40
|
+
def infer_lang(code):
    """Infer the language of a code snippet.

    Returns ``"sh"`` when the first word of the snippet is ``python``,
    ``python3``, ``pip`` or ``pip3``. Returns ``"python"`` when the snippet
    compiles as Python source, and ``UNKNOWN`` otherwise.

    TODO: make it robust.
    """
    # Compare the whole first word rather than a raw prefix, so Python code
    # such as ``pipeline = make()`` is not mistaken for a ``pip`` command.
    words = code.split(maxsplit=1)
    if words and words[0] in ("python", "python3", "pip", "pip3"):
        return "sh"

    # check if code is valid python code
    try:
        compile(code, "test", "exec")
    except (SyntaxError, ValueError):
        # not valid python code (ValueError: null bytes on Python < 3.12)
        return UNKNOWN
    return "python"
|
54
|
+
|
55
|
+
|
56
|
+
def extract_code(
    text: str, pattern: str = CODE_BLOCK_PATTERN, detect_single_line_code: bool = False
) -> List[Tuple[str, str]]:
    """Pull code snippets out of a piece of text.

    Args:
        text (str): Text that may contain code.
        pattern (str, optional): Regex used to find fenced code blocks in the
            default mode. Defaults to CODE_BLOCK_PATTERN.
        detect_single_line_code (bool, optional): When True, also pick up
            inline `code` spans, using a built-in pattern that ignores
            ``pattern``. Defaults to False.

    Returns:
        list: ``(language, code)`` tuples.
            Default mode: the regex matches. If there are none, a single
            ``(UNKNOWN, text)`` tuple is returned.
            Single-line mode: fenced blocks yield their (possibly empty)
            language tag, and inline spans yield ``""`` as the language. In
            this mode an empty list is returned when nothing is found.
    """
    if detect_single_line_code:
        # Alternation of a fenced block with an optional language tag
        # (groups 1-2) and an inline backtick span (group 3).
        combined = re.compile(r"`{3}(\w+)?\s*([\s\S]*?)`{3}|`([^`]+)`")
        snippets = []
        for tag, fenced_body, inline_body in combined.findall(text):
            if fenced_body:
                snippets.append((tag.strip(), fenced_body.strip()))
            elif inline_body:
                snippets.append(("", inline_body.strip()))
        return snippets

    found = re.findall(pattern, text, flags=re.DOTALL)
    if not found:
        return [(UNKNOWN, text)]
    return found
|
93
|
+
|
94
|
+
|
44
95
|
def execute_code(
|
45
96
|
code: Optional[str] = None,
|
46
97
|
timeout: Optional[int] = None,
|
@@ -121,16 +172,66 @@ def head_file(path: str, n: int) -> List[str]:
|
|
121
172
|
return []
|
122
173
|
|
123
174
|
|
124
|
-
def upload_minio(
    param: dict,
    bucket: str,
    object_name: str,
    file_path,
    content_type='application/text',
):
    """Upload a local file to MinIO and return a presigned download URL.

    Args:
        param: connection settings. Reads ``MINIO_ENDPOINT`` (upload client),
            ``MINIO_SHAREPOIN`` (URL-signing client), ``MINIO_ACCESS_KEY``,
            ``MINIO_SECRET_KEY``, ``SCHEMA`` (passed as ``secure``) and
            ``CERT_CHECK``.
        bucket: target bucket name.
        object_name: object key inside the bucket.
        file_path: local path of the file to upload.
        content_type: content type stored with the object.

    Returns:
        A presigned GET URL valid for 7 days. It is signed by the client built
        on ``MINIO_SHAREPOIN``, presumably an externally reachable endpoint
        (confirm against the deployment config).
    """
    # Local imports, as in the original: minio is only needed when uploading.
    from datetime import timedelta

    import minio

    def _make_client(endpoint):
        # Both clients share credentials and TLS settings; only the endpoint differs.
        return minio.Minio(
            endpoint=endpoint,
            access_key=param.get('MINIO_ACCESS_KEY'),
            secret_key=param.get('MINIO_SECRET_KEY'),
            secure=param.get('SCHEMA'),
            cert_check=param.get('CERT_CHECK'),
        )

    minio_client = _make_client(param.get('MINIO_ENDPOINT'))
    minio_share = _make_client(param.get('MINIO_SHAREPOIN'))
    logger.debug(
        'upload_file obj={} bucket={} file_path={}',
        object_name,
        bucket,
        file_path,
    )
    minio_client.fput_object(
        bucket_name=bucket,
        object_name=object_name,
        file_path=file_path,
        content_type=content_type,
    )
    return minio_share.presigned_get_object(
        bucket_name=bucket,
        object_name=object_name,
        expires=timedelta(days=7),
    )
|
133
|
-
|
216
|
+
|
217
|
+
|
218
|
+
def insert_set_font_code(code: str) -> str:
    """Insert a matplotlib font setting after the last matplotlib import in *code*.

    It also deletes matplotlib's cached ``fontlist*`` files from
    ``matplotlib.get_cachedir()``. This is done for every call, whether or not
    *code* uses matplotlib.

    If *code* mentions ``matplotlib`` and has an ``import matplotlib`` or
    ``from matplotlib`` line, two lines are inserted right after the last such
    line: ``import matplotlib`` and
    ``matplotlib.rc("font", family="WenQuanYi Zen Hei")``. They use the same
    indentation as that import. Otherwise *code* is returned unchanged.
    """
    split_code = code.split('\n')

    # Drop the cached font lists; presumably so matplotlib rebuilds its font
    # list and discovers the WenQuanYi font (TODO confirm intent).
    cache_dir = matplotlib.get_cachedir()
    for cache in glob.glob(f'{cache_dir}/fontlist*'):
        try:
            os.remove(cache)
        except FileNotFoundError:
            # Already removed (e.g. by a concurrent run); nothing to do.
            pass

    # todo: if the generated code already sets a font itself, the inserted setting may not take effect
    if 'matplotlib' in code:
        pattern = re.compile(r'(import matplotlib|from matplotlib)')
        matches = [i for i, line in enumerate(split_code) if pattern.search(line)]
        # 'matplotlib' may appear only in a comment or string with no import
        # line; then there is no anchor to insert after.
        if matches:
            index = matches[-1]
            anchor = split_code[index]
            # Reuse the import's indentation so an import nested in a block
            # (e.g. under ``try:``) stays syntactically valid.
            # NOTE(review): a parenthesised multi-line import is still not handled.
            indent = anchor[: len(anchor) - len(anchor.lstrip())]
            split_code.insert(
                index + 1,
                f'{indent}import matplotlib\n{indent}matplotlib.rc("font", family="WenQuanYi Zen Hei")',
            )

    return '\n'.join(split_code)
|
134
235
|
|
135
236
|
|
136
237
|
class CodeInterpreterToolArguments(BaseModel):
|
@@ -169,7 +270,7 @@ class FileInfo(BaseModel):
|
|
169
270
|
class CodeInterpreterTool:
|
170
271
|
"""Tool for evaluating python code in native environment."""
|
171
272
|
|
172
|
-
name = '
|
273
|
+
name = 'bisheng_code_interpreter'
|
173
274
|
args_schema: Type[BaseModel] = CodeInterpreterToolArguments
|
174
275
|
|
175
276
|
def __init__(
|
@@ -204,6 +305,7 @@ class CodeInterpreterTool:
|
|
204
305
|
for i, code_block in enumerate(code_blocks):
|
205
306
|
lang, code = code_block
|
206
307
|
lang = infer_lang(code)
|
308
|
+
code = insert_set_font_code(code)
|
207
309
|
temp_dir = tempfile.TemporaryDirectory()
|
208
310
|
exitcode, logs, _ = execute_code(
|
209
311
|
code,
|
@@ -215,7 +317,7 @@ class CodeInterpreterTool:
|
|
215
317
|
return {'exitcode': exitcode, 'log': logs_all}
|
216
318
|
|
217
319
|
# 获取文件
|
218
|
-
temp_output_dir = Path(temp_dir.name)
|
320
|
+
temp_output_dir = Path(temp_dir.name)
|
219
321
|
for root, dirs, files in os.walk(temp_output_dir):
|
220
322
|
for name in files:
|
221
323
|
file_name = os.path.join(root, name)
|
@@ -236,26 +338,3 @@ class CodeInterpreterTool:
|
|
236
338
|
description=self.description,
|
237
339
|
args_schema=self.args_schema,
|
238
340
|
)
|
239
|
-
|
240
|
-
|
241
|
-
if __name__ == '__main__':
|
242
|
-
code_string = """print('hha')"""
|
243
|
-
code_blocks = extract_code(code_string)
|
244
|
-
logger.info(code_blocks)
|
245
|
-
logs_all = ''
|
246
|
-
for i, code_block in enumerate(code_blocks):
|
247
|
-
lang, code = code_block
|
248
|
-
lang = infer_lang(code)
|
249
|
-
print(
|
250
|
-
colored(
|
251
|
-
f'\n>>>>>>>> EXECUTING CODE BLOCK {i} (inferred language is {lang})...',
|
252
|
-
'red',
|
253
|
-
),
|
254
|
-
flush=True,
|
255
|
-
)
|
256
|
-
exitcode, logs, image = execute_code(code, lang=lang)
|
257
|
-
logs_all += '\n' + logs
|
258
|
-
if exitcode != 0:
|
259
|
-
logger.error(f'{exitcode}, {logs_all}')
|
260
|
-
|
261
|
-
logger.info(logs_all)
|
@@ -0,0 +1,320 @@
|
|
1
|
+
import argparse
|
2
|
+
import copy
|
3
|
+
import inspect
|
4
|
+
import time
|
5
|
+
import os
|
6
|
+
from collections import defaultdict
|
7
|
+
|
8
|
+
import httpx
|
9
|
+
import pandas as pd
|
10
|
+
import yaml
|
11
|
+
from loguru import logger
|
12
|
+
from tqdm import tqdm
|
13
|
+
from bisheng_langchain.retrievers import EnsembleRetriever
|
14
|
+
from bisheng_langchain.vectorstores import ElasticKeywordsSearch, Milvus
|
15
|
+
from langchain.chains.question_answering import load_qa_chain
|
16
|
+
from bisheng_langchain.rag.init_retrievers import (
|
17
|
+
BaselineVectorRetriever,
|
18
|
+
KeywordRetriever,
|
19
|
+
MixRetriever,
|
20
|
+
SmallerChunksVectorRetriever,
|
21
|
+
)
|
22
|
+
from bisheng_langchain.rag.scoring.ragas_score import RagScore
|
23
|
+
from bisheng_langchain.rag.utils import import_by_type, import_class
|
24
|
+
|
25
|
+
|
26
|
+
class BishengRagPipeline:
    """YAML-configured RAG pipeline: index files, answer questions, score answers.

    The YAML file at ``yaml_path`` is read for the sections ``data``,
    ``embedding``, ``chat_llm``, ``milvus``, ``elasticsearch``, ``retriever``,
    ``loader``, ``post_retrieval``, ``generate`` and ``metric``.

    Questions come from an Excel sheet with these columns:
    ``问题`` (question), ``文件名`` (file name) and ``知识库名`` (knowledge-base
    / collection name).
    """

    def __init__(self, yaml_path) -> None:
        """Load the config and build the embeddings, LLM, stores and retrievers.

        Args:
            yaml_path: path of the pipeline YAML config.

        Note:
            The ``type`` keys of the ``embedding``, ``chat_llm`` and
            ``retriever.retrievers`` entries are popped from ``self.params``
            while the objects are built.
        """
        self.yaml_path = yaml_path
        # YAML is UTF-8 by spec; don't depend on the platform default encoding.
        with open(self.yaml_path, 'r', encoding='utf-8') as f:
            self.params = yaml.safe_load(f)

        # init data
        self.origin_file_path = self.params['data']['origin_file_path']
        self.question_path = self.params['data']['question']
        self.save_answer_path = self.params['data']['save_answer']

        # init embeddings
        embedding_params = self.params['embedding']
        embedding_type = embedding_params.pop('type')
        embedding_object = import_by_type(_type='embeddings', name=embedding_type)
        # .get: a config without an 'openai_proxy' key uses the plain constructor.
        if embedding_type == 'OpenAIEmbeddings' and embedding_params.get('openai_proxy'):
            self.embeddings = embedding_object(
                http_client=httpx.Client(proxies=embedding_params['openai_proxy']), **embedding_params
            )
        else:
            self.embeddings = embedding_object(**embedding_params)

        # init llm
        llm_params = self.params['chat_llm']
        llm_type = llm_params.pop('type')
        llm_object = import_by_type(_type='llms', name=llm_type)
        if llm_type == 'ChatOpenAI' and llm_params.get('openai_proxy'):
            self.llm = llm_object(http_client=httpx.Client(proxies=llm_params['openai_proxy']), **llm_params)
        else:
            self.llm = llm_object(**llm_params)

        # milvus
        self.vector_store = Milvus(
            embedding_function=self.embeddings,
            connection_args={
                "host": self.params['milvus']['host'],
                "port": self.params['milvus']['port'],
            },
        )

        # es
        self.keyword_store = ElasticKeywordsSearch(
            index_name='default_es',
            elasticsearch_url=self.params['elasticsearch']['url'],
            ssl_verify=self.params['elasticsearch']['ssl_verify'],
        )

        # init retriever
        retriever_list = []
        for retriever in self.params['retriever']['retrievers']:
            retriever_type = retriever.pop('type')
            retriever_list.append(
                self._post_init_retriever(
                    retriever_type=retriever_type,
                    vector_store=self.vector_store,
                    keyword_store=self.keyword_store,
                    splitter_kwargs=retriever['splitter'],
                    retrieval_kwargs=retriever['retrieval'],
                )
            )
        self.retriever = EnsembleRetriever(retrievers=retriever_list)

    def _post_init_retriever(self, retriever_type, **kwargs):
        """Instantiate one retriever from its config.

        Args:
            retriever_type: a key of ``retriever_classes`` below.
            **kwargs: must contain ``vector_store``, ``keyword_store``,
                ``splitter_kwargs`` and ``retrieval_kwargs``.
                ``splitter_kwargs`` maps an argument name to a text-splitter
                config that has a ``type`` key. ``retrieval_kwargs`` is passed
                through to the retriever unchanged.

        Raises:
            ValueError: if ``retriever_type`` is unknown.
        """
        retriever_classes = {
            'KeywordRetriever': KeywordRetriever,
            'BaselineVectorRetriever': BaselineVectorRetriever,
            'MixRetriever': MixRetriever,
            'SmallerChunksVectorRetriever': SmallerChunksVectorRetriever,
        }
        if retriever_type not in retriever_classes:
            raise ValueError(f'Unknown retriever type: {retriever_type}')

        input_kwargs = {}
        # Build each splitter from its 'type' plus the remaining config keys.
        for key, value in kwargs.pop('splitter_kwargs').items():
            splitter_obj = import_by_type(_type='textsplitters', name=value.pop('type'))
            input_kwargs[key] = splitter_obj(**value)

        input_kwargs.update(kwargs.pop('retrieval_kwargs'))
        input_kwargs['vector_store'] = kwargs.pop('vector_store')
        input_kwargs['keyword_store'] = kwargs.pop('keyword_store')

        return retriever_classes[retriever_type](**input_kwargs)

    def file2knowledge(self):
        """Index every file referenced by the question sheet into its collection.

        Files are grouped by the ``知识库名`` column. Each ``文件名`` entry may be
        a single file or a directory; hidden entries in a directory are
        skipped. For every retriever ``idx``, documents are added to the
        collection ``{知识库名}_{retriever.suffix}_{idx}``.

        Raises:
            Exception: if a required column is missing or a path does not exist.
        """
        df = pd.read_excel(self.question_path)
        if ('文件名' not in df.columns) or ('知识库名' not in df.columns):
            raise Exception(f'文件名 or 知识库名 not in {self.question_path}.')

        # Work on a copy: popping 'type' from self.params['loader'] itself would
        # make a later load_documents() on this pipeline fail with KeyError.
        loader_params = copy.deepcopy(self.params['loader'])
        loader_object = import_by_type(_type='documentloaders', name=loader_params.pop('type'))

        collectionname2filename = defaultdict(set)
        for info in df.to_dict('records'):
            # a set drops duplicate file names within a collection
            collectionname2filename[info['知识库名']].add(info['文件名'])

        for collection_name in tqdm(collectionname2filename):
            all_file_paths = []
            for file_name in collectionname2filename[collection_name]:
                file_path = os.path.join(self.origin_file_path, file_name)
                if not os.path.exists(file_path):
                    raise Exception(f'{file_path} not exists.')
                # file_path may be a directory or a single file
                if os.path.isdir(file_path):
                    # a directory holding several files
                    all_file_paths.extend(
                        os.path.join(file_path, name) for name in os.listdir(file_path) if not name.startswith('.')
                    )
                else:
                    all_file_paths.append(file_path)

            # all files that must be stored in the current collection
            collection_name = f"{collection_name}_{self.params['retriever']['suffix']}"
            is_first_file = True
            for each_file_path in all_file_paths:
                logger.info(f'each_file_path: {each_file_path}')
                loader = loader_object(
                    file_name=os.path.basename(each_file_path), file_path=each_file_path, **loader_params
                )
                documents = loader.load()
                logger.info(f'documents: {len(documents)}')
                if not documents:
                    # Nothing to index. Skip it so that drop_old applies to the
                    # first file that is actually indexed.
                    logger.error(f'{each_file_path} produced no documents, skipped.')
                    continue
                if len(documents[0].page_content) == 0:
                    logger.error(f'{each_file_path} page_content is empty.')

                # Only the first indexed file of a collection may drop old data.
                vector_drop_old = self.params['milvus']['drop_old'] if is_first_file else False
                is_first_file = False
                for idx, retriever in enumerate(self.retriever.retrievers):
                    retriever.add_documents(documents, f"{collection_name}_{idx}", vector_drop_old)

    def retrieval_and_rerank(self, question, collection_name):
        """Retrieve candidate documents for ``question``, then optionally dedupe and rerank.

        Each retriever ``idx`` queries its own collection,
        ``{collection_name}_{retriever.suffix}_{idx}``, and the results are
        concatenated. Exact duplicate ``page_content`` is removed when
        ``post_retrieval.delete_duplicate`` is set. The reranker from
        ``post_retrieval.rerank`` is applied when ``post_retrieval.with_rank``
        is set.

        Returns:
            The list of retrieved (and possibly filtered) documents.
        """
        collection_name = f"{collection_name}_{self.params['retriever']['suffix']}"

        # Query each retriever separately: retrieving through the
        # EnsembleRetriever directly deduplicates by default.
        docs = []
        for idx, retriever in enumerate(self.retriever.retrievers):
            docs.extend(retriever.get_relevant_documents(query=question, collection_name=f"{collection_name}_{idx}"))
        logger.info(f'retrieval docs: {len(docs)}')

        # delete duplicate (keeps the first occurrence, preserves order)
        if self.params['post_retrieval']['delete_duplicate']:
            logger.info(f'origin docs: {len(docs)}')
            seen_contents = set()
            docs_no_dup = []
            for doc in docs:
                if doc.page_content in seen_contents:
                    continue
                seen_contents.add(doc.page_content)
                docs_no_dup.append(doc)
            docs = docs_no_dup
            logger.info(f'delete duplicate docs: {len(docs)}')

        # rerank
        if self.params['post_retrieval']['with_rank'] and docs:
            if not hasattr(self, 'ranker'):
                # Copy so the config keeps its 'type' key.
                rerank_params = copy.deepcopy(self.params['post_retrieval']['rerank'])
                rerank_type = rerank_params.pop('type')
                rerank_object = import_class(f'bisheng_langchain.rag.rerank.{rerank_type}')
                self.ranker = rerank_object(**rerank_params)
            docs = self.ranker.sort_and_filter(question, docs)

        return docs

    def load_documents(self, file_name, max_content=100000):
        """Load a single source file and truncate each document's content.

        Args:
            file_name: file name relative to ``data.origin_file_path``.
            max_content: maximum number of characters kept per document
                (meant as the LLM's maximum context length).

        Raises:
            Exception: if the path does not exist or is a directory.
        """
        file_path = os.path.join(self.origin_file_path, file_name)
        if not os.path.exists(file_path):
            raise Exception(f'{file_path} not exists.')
        if os.path.isdir(file_path):
            raise Exception(f'{file_path} is a directory.')

        loader_params = copy.deepcopy(self.params['loader'])
        loader_object = import_by_type(_type='documentloaders', name=loader_params.pop('type'))
        loader = loader_object(file_name=file_name, file_path=file_path, **loader_params)

        documents = loader.load()
        # Guard the log line against a loader that returns no documents.
        first_len = len(documents[0].page_content) if documents else 0
        logger.info(f'documents: {len(documents)}, page_content: {first_len}')
        for doc in documents:
            doc.page_content = doc.page_content[:max_content]
        return documents

    def question_answering(self):
        """Answer every question in the sheet and write the results to ``data.save_answer``.

        The context depends on ``generate.with_retrieval``. When it is set, the
        context comes from :meth:`retrieval_and_rerank`. Otherwise the whole
        (truncated) source file from :meth:`load_documents` is used.

        Answers go into a ``rag_answer`` column. When the chain raises, the
        error text is stored as the answer and the run continues.
        """
        df = pd.read_excel(self.question_path)
        all_questions_info = df.to_dict('records')
        if 'prompt_type' in self.params['generate']:
            prompt_type = self.params['generate']['prompt_type']
            prompt = import_class(f'bisheng_langchain.rag.prompts.{prompt_type}')
        else:
            prompt = None
        qa_chain = load_qa_chain(
            llm=self.llm, chain_type=self.params['generate']['chain_type'], prompt=prompt, verbose=False
        )
        file2docs = {}
        for questions_info in tqdm(all_questions_info):
            question = questions_info['问题']
            file_name = questions_info['文件名']
            collection_name = questions_info['知识库名']

            if self.params['generate']['with_retrieval']:
                # retrieval and rerank
                docs = self.retrieval_and_rerank(question, collection_name)
            else:
                # load document (cached per file; many questions may share one)
                if file_name not in file2docs:
                    file2docs[file_name] = self.load_documents(file_name)
                docs = file2docs[file_name]

            # question answer
            try:
                ans = qa_chain({"input_documents": docs, "question": question}, return_only_outputs=False)
            except Exception as e:
                # Best effort: record the error as the answer and keep going.
                logger.error(f'question: {question}\nerror: {e}')
                ans = {'output_text': str(e)}

            rag_answer = ans['output_text']
            logger.info(f'question: {question}\nans: {rag_answer}\n')
            questions_info['rag_answer'] = rag_answer

        df = pd.DataFrame(all_questions_info)
        df.to_excel(self.save_answer_path, index=False)

    def score(self):
        """Score the saved answers with the metric configured in ``metric``.

        Only ``metric.type == 'bisheng-ragas'`` is implemented, using
        ``RagScore``. Any other type does nothing.
        """
        metric_params = self.params['metric']
        if metric_params['type'] == 'bisheng-ragas':
            score_params = {
                'excel_path': self.save_answer_path,
                'save_path': os.path.dirname(self.save_answer_path),
                'question_column': metric_params['question_column'],
                'gt_column': metric_params['gt_column'],
                'answer_column': metric_params['answer_column'],
                'query_type_column': metric_params.get('query_type_column', None),
                'contexts_column': metric_params.get('contexts_column', None),
                'metrics': metric_params['metrics'],
                'batch_size': metric_params['batch_size'],
                'gt_split_column': metric_params.get('gt_split_column', None),
                # whether the model should split the ground truth into key points
                'whether_gtsplit': metric_params.get('whether_gtsplit', False),
            }
            rag_score = RagScore(**score_params)
            rag_score.score()
        else:
            # todo: other scoring methods
            pass
|
303
|
+
|
304
|
+
|
305
|
+
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Bisheng RAG pipeline: upload files, answer questions or score answers.')
    # add arguments; `choices` rejects an unknown mode before the pipeline
    # (and its Milvus/Elasticsearch connections) is built for nothing.
    parser.add_argument(
        '--mode', type=str, default='qa', choices=['upload', 'qa', 'score'], help='upload or qa or score'
    )
    parser.add_argument('--params', type=str, default='config/test/baseline_s2b.yaml', help='bisheng rag params')
    # parse arguments
    args = parser.parse_args()

    rag = BishengRagPipeline(args.params)

    if args.mode == 'upload':
        rag.file2knowledge()
    elif args.mode == 'qa':
        rag.question_answering()
    elif args.mode == 'score':
        rag.score()
|