bisheng-langchain 0.3.0rc0__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
Files changed (58)
  1. bisheng_langchain/chat_models/host_llm.py +1 -1
  2. bisheng_langchain/document_loaders/elem_unstrcutured_loader.py +5 -3
  3. bisheng_langchain/gpts/agent_types/llm_functions_agent.py +7 -1
  4. bisheng_langchain/gpts/assistant.py +8 -5
  5. bisheng_langchain/gpts/auto_optimization.py +28 -27
  6. bisheng_langchain/gpts/auto_tool_selected.py +14 -15
  7. bisheng_langchain/gpts/load_tools.py +53 -1
  8. bisheng_langchain/gpts/prompts/__init__.py +4 -2
  9. bisheng_langchain/gpts/prompts/assistant_prompt_base.py +1 -0
  10. bisheng_langchain/gpts/prompts/assistant_prompt_cohere.py +19 -0
  11. bisheng_langchain/gpts/prompts/opening_dialog_prompt.py +1 -1
  12. bisheng_langchain/gpts/tools/api_tools/__init__.py +1 -1
  13. bisheng_langchain/gpts/tools/api_tools/base.py +3 -3
  14. bisheng_langchain/gpts/tools/api_tools/flow.py +19 -7
  15. bisheng_langchain/gpts/tools/api_tools/macro_data.py +175 -4
  16. bisheng_langchain/gpts/tools/api_tools/openapi.py +101 -0
  17. bisheng_langchain/gpts/tools/api_tools/sina.py +2 -2
  18. bisheng_langchain/gpts/tools/code_interpreter/tool.py +118 -39
  19. bisheng_langchain/rag/__init__.py +5 -0
  20. bisheng_langchain/rag/bisheng_rag_pipeline.py +320 -0
  21. bisheng_langchain/rag/bisheng_rag_pipeline_v2.py +359 -0
  22. bisheng_langchain/rag/bisheng_rag_pipeline_v2_cohere_raw_prompting.py +376 -0
  23. bisheng_langchain/rag/bisheng_rag_tool.py +288 -0
  24. bisheng_langchain/rag/config/baseline.yaml +86 -0
  25. bisheng_langchain/rag/config/baseline_caibao.yaml +82 -0
  26. bisheng_langchain/rag/config/baseline_caibao_knowledge_v2.yaml +110 -0
  27. bisheng_langchain/rag/config/baseline_caibao_v2.yaml +112 -0
  28. bisheng_langchain/rag/config/baseline_demo_v2.yaml +92 -0
  29. bisheng_langchain/rag/config/baseline_s2b_mix.yaml +88 -0
  30. bisheng_langchain/rag/config/baseline_v2.yaml +90 -0
  31. bisheng_langchain/rag/extract_info.py +38 -0
  32. bisheng_langchain/rag/init_retrievers/__init__.py +4 -0
  33. bisheng_langchain/rag/init_retrievers/baseline_vector_retriever.py +61 -0
  34. bisheng_langchain/rag/init_retrievers/keyword_retriever.py +65 -0
  35. bisheng_langchain/rag/init_retrievers/mix_retriever.py +103 -0
  36. bisheng_langchain/rag/init_retrievers/smaller_chunks_retriever.py +92 -0
  37. bisheng_langchain/rag/prompts/__init__.py +9 -0
  38. bisheng_langchain/rag/prompts/extract_key_prompt.py +34 -0
  39. bisheng_langchain/rag/prompts/prompt.py +47 -0
  40. bisheng_langchain/rag/prompts/prompt_cohere.py +111 -0
  41. bisheng_langchain/rag/qa_corpus/__init__.py +0 -0
  42. bisheng_langchain/rag/qa_corpus/qa_generator.py +143 -0
  43. bisheng_langchain/rag/rerank/__init__.py +5 -0
  44. bisheng_langchain/rag/rerank/rerank.py +48 -0
  45. bisheng_langchain/rag/rerank/rerank_benchmark.py +139 -0
  46. bisheng_langchain/rag/run_qa_gen_web.py +47 -0
  47. bisheng_langchain/rag/run_rag_evaluate_web.py +55 -0
  48. bisheng_langchain/rag/scoring/__init__.py +0 -0
  49. bisheng_langchain/rag/scoring/llama_index_score.py +91 -0
  50. bisheng_langchain/rag/scoring/ragas_score.py +183 -0
  51. bisheng_langchain/rag/utils.py +181 -0
  52. bisheng_langchain/retrievers/ensemble.py +2 -1
  53. bisheng_langchain/vectorstores/elastic_keywords_search.py +2 -1
  54. {bisheng_langchain-0.3.0rc0.dist-info → bisheng_langchain-0.3.1.dist-info}/METADATA +1 -1
  55. {bisheng_langchain-0.3.0rc0.dist-info → bisheng_langchain-0.3.1.dist-info}/RECORD +57 -22
  56. bisheng_langchain/gpts/prompts/base_prompt.py +0 -1
  57. {bisheng_langchain-0.3.0rc0.dist-info → bisheng_langchain-0.3.1.dist-info}/WHEEL +0 -0
  58. {bisheng_langchain-0.3.0rc0.dist-info → bisheng_langchain-0.3.1.dist-info}/top_level.txt +0 -0
bisheng_langchain/gpts/tools/api_tools/openapi.py

@@ -0,0 +1,101 @@
+ from typing import Any
+
+ from langchain_core.tools import BaseTool
+ from loguru import logger
+ from pydantic import BaseModel, create_model
+
+ from .base import APIToolBase, MultArgsSchemaTool, Field
+
+
+ class OpenApiTools(APIToolBase):
+
+     def get_real_path(self):
+         return self.url + self.params["path"]
+
+     def get_request_method(self):
+         return self.params["method"]
+
+     def get_params_json(self, **kwargs):
+         params_define = {}
+         for one in self.params["parameters"]:
+             params_define[one["name"]] = one
+
+         params = {}
+         json_data = {}
+         for k, v in kwargs.items():
+             if params_define.get(k):
+                 if params_define[k]["in"] == "query":
+                     params[k] = v
+                 else:
+                     json_data[k] = v
+             else:
+                 params[k] = v
+         return params, json_data
+
+     def parse_args_schema(self):
+         params = self.params["parameters"]
+         model_params = {}
+         for one in params:
+             field_type = one["schema"]["type"]
+             if field_type == "number":
+                 field_type = "float"
+             elif field_type == "integer":
+                 field_type = "int"
+             elif field_type == "string":
+                 field_type = "str"
+             else:
+                 raise Exception(f"schema type is not support: {field_type}")
+             model_params[one["name"]] = (eval(field_type), Field(description=one["description"]))
+         return create_model("InputArgs", __module__='bisheng_langchain.gpts.tools.api_tools.openapi',
+                             __base__=BaseModel, **model_params)
+
+     def run(self, **kwargs) -> str:
+         """Run query through api and parse result."""
+         path = self.get_real_path()
+         logger.info('api_call url={}', path)
+         method = self.get_request_method()
+         params, json_data = self.get_params_json(**kwargs)
+
+         if method == "get":
+             resp = self.client.get(path, params=params)
+         elif method == 'post':
+             resp = self.client.post(path, params=params, json=self.params)
+         elif method == 'put':
+             resp = self.client.put(path, params=params, json=self.params)
+         elif method == 'delete':
+             resp = self.client.delete(path, params=params, json=self.params)
+         else:
+             raise Exception(f"http method is not support: {method}")
+         if resp.status_code != 200:
+             logger.info(f'api_call_fail code={resp.status_code} res={resp.text}')
+             raise Exception(f"api_call_fail: {resp.status_code} {resp.text}")
+         return resp.text
+
+     async def arun(self, **kwargs) -> str:
+         """Run query through api and parse result."""
+         path = self.get_real_path()
+         logger.info('api_call url={}', path)
+         method = self.get_request_method()
+         params, json_data = self.get_params_json(**kwargs)
+
+         if method == "get":
+             resp = await self.async_client.aget(path, params=params)
+         elif method == 'post':
+             resp = await self.async_client.apost(path, params=params, json=self.params)
+         elif method == 'put':
+             resp = await self.async_client.aput(path, params=params, json=self.params)
+         elif method == 'delete':
+             resp = await self.async_client.adelete(path, params=params, json=self.params)
+         else:
+             raise Exception(f"http method is not support: {method}")
+         return resp
+
+     @classmethod
+     def get_api_tool(cls, name, **kwargs: Any) -> BaseTool:
+         description = kwargs.pop("description", "")
+         obj = cls(**kwargs)
+         return MultArgsSchemaTool(name=name,
+                                   description=description,
+                                   func=obj.run,
+                                   coroutine=obj.arun,
+                                   args_schema=obj.parse_args_schema())
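For orientation, here is a minimal usage sketch of the new `OpenApiTools` (not part of the diff). The `params` dict mirrors a single OpenAPI operation; the `url` and `params` constructor fields and the base URL value are assumptions inferred from how the class reads `self.url` and `self.params`:

```python
# Hypothetical wiring for OpenApiTools.get_api_tool; the constructor
# fields (url, params) are assumed from the class body above.
from bisheng_langchain.gpts.tools.api_tools.openapi import OpenApiTools

weather_op = {
    "path": "/weather",
    "method": "get",
    "parameters": [
        # "in": "query" parameters end up in the query string; anything
        # else is collected into the JSON body by get_params_json().
        {"name": "city", "in": "query",
         "schema": {"type": "string"}, "description": "City to look up"},
    ],
}

tool = OpenApiTools.get_api_tool(
    name="get_weather",
    description="Query current weather by city",
    url="https://api.example.com",  # placeholder base URL
    params=weather_op,
)
# parse_args_schema() turns the parameter list into a pydantic model,
# so an agent can invoke the tool with {"city": "Shanghai"}.
```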
bisheng_langchain/gpts/tools/api_tools/sina.py

@@ -154,7 +154,7 @@ class StockInfo(APIToolBase):
          resp = super().run(query=stock_number)
          stock = self.devideStock(resp)[0]
          if isinstance(stock, Stock):
-             return json.dumps(stock.__dict__)
+             return json.dumps(stock.__dict__, ensure_ascii=False)
          else:
              return stock

@@ -183,7 +183,7 @@ class StockInfo(APIToolBase):
          resp = await super().arun(query=stock_number)
          stock = self.devideStock(resp)[0]
          if isinstance(stock, Stock):
-             return json.dumps(stock.__dict__)
+             return json.dumps(stock.__dict__, ensure_ascii=False)
          else:
              return stock

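The `ensure_ascii=False` change matters because Sina's stock payloads carry Chinese field values; with the default, `json.dumps` escapes them into `\uXXXX` sequences. A quick illustration with made-up data:

```python
import json

stock = {"name": "浦发银行", "price": 7.12}  # illustrative record

json.dumps(stock)
# '{"name": "\\u6d66\\u53d1\\u94f6\\u884c", "price": 7.12}'
json.dumps(stock, ensure_ascii=False)
# '{"name": "浦发银行", "price": 7.12}'
```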
bisheng_langchain/gpts/tools/code_interpreter/tool.py

@@ -1,6 +1,8 @@
+ import glob
  import itertools
  import os
  import pathlib
+ import re
  import subprocess
  import sys
  import tempfile
@@ -11,24 +13,18 @@ from pathlib import Path
  from typing import Dict, List, Optional, Tuple, Type
  from uuid import uuid4

- from autogen.code_utils import extract_code, infer_lang
+ import matplotlib
  from langchain_community.tools import Tool
  from langchain_core.pydantic_v1 import BaseModel, Field
  from loguru import logger

- try:
-     from termcolor import colored
- except ImportError:
-
-     def colored(x, *args, **kwargs):
-         return x
-
-
+ CODE_BLOCK_PATTERN = r"```(\w*)\n(.*?)\n```"
  DEFAULT_TIMEOUT = 600
  WIN32 = sys.platform == 'win32'
  PATH_SEPARATOR = WIN32 and '\\' or '/'
  WORKING_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'extensions')
  TIMEOUT_MSG = 'Timeout'
+ UNKNOWN = "unknown"


  def _cmd(lang):
@@ -41,6 +37,61 @@ def _cmd(lang):
      raise NotImplementedError(f'{lang} not recognized in code execution')


+ def infer_lang(code):
+     """infer the language for the code.
+     TODO: make it robust.
+     """
+     if code.startswith("python ") or code.startswith("pip") or code.startswith("python3 "):
+         return "sh"
+
+     # check if code is a valid python code
+     try:
+         compile(code, "test", "exec")
+         return "python"
+     except SyntaxError:
+         # not a valid python code
+         return UNKNOWN
+
+
+ def extract_code(
+     text: str, pattern: str = CODE_BLOCK_PATTERN, detect_single_line_code: bool = False
+ ) -> List[Tuple[str, str]]:
+     """Extract code from a text.
+
+     Args:
+         text (str): The text to extract code from.
+         pattern (str, optional): The regular expression pattern for finding the
+             code block. Defaults to CODE_BLOCK_PATTERN.
+         detect_single_line_code (bool, optional): Enable the new feature for
+             extracting single line code. Defaults to False.
+
+     Returns:
+         list: A list of tuples, each containing the language and the code.
+             If there is no code block in the input text, the language would be "unknown".
+             If there is code block but the language is not specified, the language would be "".
+     """
+     if not detect_single_line_code:
+         match = re.findall(pattern, text, flags=re.DOTALL)
+         return match if match else [(UNKNOWN, text)]
+
+     # Extract both multi-line and single-line code block, separated by the | operator
+     # `{3}(\w+)?\s*([\s\S]*?)`{3}: Matches multi-line code blocks.
+     #    The (\w+)? matches the language, where the ? indicates it is optional.
+     # `([^`]+)`: Matches inline code.
+     code_pattern = re.compile(r"`{3}(\w+)?\s*([\s\S]*?)`{3}|`([^`]+)`")
+     code_blocks = code_pattern.findall(text)
+
+     # Extract the individual code blocks and languages from the matched groups
+     extracted = []
+     for lang, group1, group2 in code_blocks:
+         if group1:
+             extracted.append((lang.strip(), group1.strip()))
+         elif group2:
+             extracted.append(("", group2.strip()))
+
+     return extracted
+
+
  def execute_code(
      code: Optional[str] = None,
      timeout: Optional[int] = None,
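`infer_lang` and `extract_code` are local replacements for the dropped `autogen.code_utils` imports. A short demonstration of their behavior (inputs are made up; the fence string is assembled at runtime to keep the example self-contained):

```python
fence = "`" * 3  # a literal ``` markdown fence

text = f"Run this:\n{fence}python\nprint('hi')\n{fence}\nDone."
extract_code(text)                 # [('python', "print('hi')")]
extract_code("no block here")      # [('unknown', 'no block here')]
infer_lang("pip install pandas")   # 'sh' (shell-style commands)
infer_lang("print('hi')")          # 'python' (compiles cleanly)
```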
bisheng_langchain/gpts/tools/code_interpreter/tool.py

@@ -121,16 +172,66 @@ def head_file(path: str, n: int) -> List[str]:
      return []


- def upload_minio(param: dict, bucket: str, object_name: str, file_path, content_type='application/text'):
+ def upload_minio(
+     param: dict,
+     bucket: str,
+     object_name: str,
+     file_path,
+     content_type='application/text',
+ ):
      # initialize the minio client
      import minio

-     minio_client = minio.Minio(**param)
-     logger.debug('upload_file obj={} bucket={} file_paht={}', object_name, bucket, file_path)
+     minio_client = minio.Minio(
+         endpoint=param.get('MINIO_ENDPOINT'),
+         access_key=param.get('MINIO_ACCESS_KEY'),
+         secret_key=param.get('MINIO_SECRET_KEY'),
+         secure=param.get('SCHEMA'),
+         cert_check=param.get('CERT_CHECK'),
+     )
+     minio_share = minio.Minio(
+         endpoint=param.get('MINIO_SHAREPOIN'),
+         access_key=param.get('MINIO_ACCESS_KEY'),
+         secret_key=param.get('MINIO_SECRET_KEY'),
+         secure=param.get('SCHEMA'),
+         cert_check=param.get('CERT_CHECK'),
+     )
+     logger.debug(
+         'upload_file obj={} bucket={} file_paht={}',
+         object_name,
+         bucket,
+         file_path,
+     )
      minio_client.fput_object(
-         bucket_name=bucket, object_name=object_name, file_path=file_path, content_type=content_type
+         bucket_name=bucket,
+         object_name=object_name,
+         file_path=file_path,
+         content_type=content_type,
+     )
+     return minio_share.presigned_get_object(
+         bucket_name=bucket,
+         object_name=object_name,
+         expires=timedelta(days=7),
      )
-     return minio_client.presigned_get_object(bucket_name=bucket, object_name=object_name, expires=timedelta(days=7))
+
+
+ def insert_set_font_code(code: str) -> str:
+     """Check whether the python code imports matplotlib; if so, insert font-setup code."""
+
+     split_code = code.split('\n')
+     cache_file = matplotlib.get_cachedir()
+     font_cache = glob.glob(f'{cache_file}/fontlist*')
+
+     for cache in font_cache:
+         os.remove(cache)
+
+     # todo: if the generated code already contains font-setup code, it may be rendered ineffective
+     if 'matplotlib' in code:
+         pattern = re.compile(r'(import matplotlib|from matplotlib)')
+         index = max(i for i, line in enumerate(split_code) if pattern.search(line))
+         split_code.insert(index + 1, 'import matplotlib\nmatplotlib.rc("font", family="WenQuanYi Zen Hei")')
+
+     return '\n'.join(split_code)


  class CodeInterpreterToolArguments(BaseModel):
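`insert_set_font_code` exists so CJK labels render in generated plots: it clears matplotlib's font-list cache, then injects a `matplotlib.rc` call after the last matplotlib import. Roughly, for a made-up input:

```python
before = "import matplotlib.pyplot as plt\nplt.plot([1, 2])"
print(insert_set_font_code(before))
# import matplotlib.pyplot as plt
# import matplotlib
# matplotlib.rc("font", family="WenQuanYi Zen Hei")
# plt.plot([1, 2])
# (side effect: the fontlist* cache files are deleted so the font is re-scanned)
```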
bisheng_langchain/gpts/tools/code_interpreter/tool.py

@@ -169,7 +270,7 @@ class FileInfo(BaseModel):
  class CodeInterpreterTool:
      """Tool for evaluating python code in native environment."""

-     name = 'code_interpreter'
+     name = 'bisheng_code_interpreter'
      args_schema: Type[BaseModel] = CodeInterpreterToolArguments

      def __init__(
@@ -204,6 +305,7 @@ class CodeInterpreterTool:
          for i, code_block in enumerate(code_blocks):
              lang, code = code_block
              lang = infer_lang(code)
+             code = insert_set_font_code(code)
              temp_dir = tempfile.TemporaryDirectory()
              exitcode, logs, _ = execute_code(
                  code,
@@ -215,7 +317,7 @@ class CodeInterpreterTool:
              return {'exitcode': exitcode, 'log': logs_all}

          # collect generated files
-         temp_output_dir = Path(temp_dir.name) / 'output'
+         temp_output_dir = Path(temp_dir.name)
          for root, dirs, files in os.walk(temp_output_dir):
              for name in files:
                  file_name = os.path.join(root, name)
@@ -236,26 +338,3 @@ class CodeInterpreterTool:
              description=self.description,
              args_schema=self.args_schema,
          )
-
-
- if __name__ == '__main__':
-     code_string = """print('hha')"""
-     code_blocks = extract_code(code_string)
-     logger.info(code_blocks)
-     logs_all = ''
-     for i, code_block in enumerate(code_blocks):
-         lang, code = code_block
-         lang = infer_lang(code)
-         print(
-             colored(
-                 f'\n>>>>>>>> EXECUTING CODE BLOCK {i} (inferred language is {lang})...',
-                 'red',
-             ),
-             flush=True,
-         )
-         exitcode, logs, image = execute_code(code, lang=lang)
-         logs_all += '\n' + logs
-         if exitcode != 0:
-             logger.error(f'{exitcode}, {logs_all}')
-
-     logger.info(logs_all)
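The deleted `__main__` demo can still be reproduced with the module's local helpers; a minimal sketch mirroring the removed block (same call signatures, minus the `termcolor` dependency):

```python
# extract_code falls back to [('unknown', <text>)] when no fenced block is found
code_blocks = extract_code("print('hello')")
logs_all = ''
for i, (lang, code) in enumerate(code_blocks):
    lang = infer_lang(code)  # 'python' here, since the text compiles
    exitcode, logs, _ = execute_code(code, lang=lang)
    logs_all += '\n' + logs
    if exitcode != 0:
        logger.error(f'{exitcode}, {logs_all}')
logger.info(logs_all)
```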
bisheng_langchain/rag/__init__.py

@@ -0,0 +1,5 @@
+ from bisheng_langchain.rag.bisheng_rag_tool import BishengRAGTool
+
+ __all__ = [
+     "BishengRAGTool",
+ ]
bisheng_langchain/rag/bisheng_rag_pipeline.py

@@ -0,0 +1,320 @@
+ import argparse
+ import copy
+ import inspect
+ import time
+ import os
+ from collections import defaultdict
+
+ import httpx
+ import pandas as pd
+ import yaml
+ from loguru import logger
+ from tqdm import tqdm
+ from bisheng_langchain.retrievers import EnsembleRetriever
+ from bisheng_langchain.vectorstores import ElasticKeywordsSearch, Milvus
+ from langchain.chains.question_answering import load_qa_chain
+ from bisheng_langchain.rag.init_retrievers import (
+     BaselineVectorRetriever,
+     KeywordRetriever,
+     MixRetriever,
+     SmallerChunksVectorRetriever,
+ )
+ from bisheng_langchain.rag.scoring.ragas_score import RagScore
+ from bisheng_langchain.rag.utils import import_by_type, import_class
+
+
+ class BishengRagPipeline:
+
+     def __init__(self, yaml_path) -> None:
+         self.yaml_path = yaml_path
+         with open(self.yaml_path, 'r') as f:
+             self.params = yaml.safe_load(f)
+
+         # init data
+         self.origin_file_path = self.params['data']['origin_file_path']
+         self.question_path = self.params['data']['question']
+         self.save_answer_path = self.params['data']['save_answer']
+
+         # init embeddings
+         embedding_params = self.params['embedding']
+         embedding_object = import_by_type(_type='embeddings', name=embedding_params['type'])
+         if embedding_params['type'] == 'OpenAIEmbeddings' and embedding_params['openai_proxy']:
+             embedding_params.pop('type')
+             self.embeddings = embedding_object(
+                 http_client=httpx.Client(proxies=embedding_params['openai_proxy']), **embedding_params
+             )
+         else:
+             embedding_params.pop('type')
+             self.embeddings = embedding_object(**embedding_params)
+
+         # init llm
+         llm_params = self.params['chat_llm']
+         llm_object = import_by_type(_type='llms', name=llm_params['type'])
+         if llm_params['type'] == 'ChatOpenAI' and llm_params['openai_proxy']:
+             llm_params.pop('type')
+             self.llm = llm_object(http_client=httpx.Client(proxies=llm_params['openai_proxy']), **llm_params)
+         else:
+             llm_params.pop('type')
+             self.llm = llm_object(**llm_params)
+
+         # milvus
+         self.vector_store = Milvus(
+             embedding_function=self.embeddings,
+             connection_args={
+                 "host": self.params['milvus']['host'],
+                 "port": self.params['milvus']['port'],
+             },
+         )
+
+         # es
+         self.keyword_store = ElasticKeywordsSearch(
+             index_name='default_es',
+             elasticsearch_url=self.params['elasticsearch']['url'],
+             ssl_verify=self.params['elasticsearch']['ssl_verify'],
+         )
+
+         # init retriever
+         retriever_list = []
+         retrievers = self.params['retriever']['retrievers']
+         for retriever in retrievers:
+             retriever_type = retriever.pop('type')
+             retriever_params = {
+                 'vector_store': self.vector_store,
+                 'keyword_store': self.keyword_store,
+                 'splitter_kwargs': retriever['splitter'],
+                 'retrieval_kwargs': retriever['retrieval'],
+             }
+             retriever_list.append(self._post_init_retriever(retriever_type=retriever_type, **retriever_params))
+         self.retriever = EnsembleRetriever(retrievers=retriever_list)
+
+     def _post_init_retriever(self, retriever_type, **kwargs):
+         retriever_classes = {
+             'KeywordRetriever': KeywordRetriever,
+             'BaselineVectorRetriever': BaselineVectorRetriever,
+             'MixRetriever': MixRetriever,
+             'SmallerChunksVectorRetriever': SmallerChunksVectorRetriever,
+         }
+         if retriever_type not in retriever_classes:
+             raise ValueError(f'Unknown retriever type: {retriever_type}')
+
+         input_kwargs = {}
+         splitter_params = kwargs.pop('splitter_kwargs')
+         for key, value in splitter_params.items():
+             splitter_obj = import_by_type(_type='textsplitters', name=value.pop('type'))
+             input_kwargs[key] = splitter_obj(**value)
+
+         retrieval_params = kwargs.pop('retrieval_kwargs')
+         for key, value in retrieval_params.items():
+             input_kwargs[key] = value
+
+         input_kwargs['vector_store'] = kwargs.pop('vector_store')
+         input_kwargs['keyword_store'] = kwargs.pop('keyword_store')
+
+         retriever_class = retriever_classes[retriever_type]
+         return retriever_class(**input_kwargs)
+
+     def file2knowledge(self):
+         """
+         file to knowledge
+         """
+         df = pd.read_excel(self.question_path)
+         if ('文件名' not in df.columns) or ('知识库名' not in df.columns):
+             raise Exception(f'文件名 or 知识库名 not in {self.question_path}.')
+
+         loader_params = self.params['loader']
+         loader_object = import_by_type(_type='documentloaders', name=loader_params.pop('type'))
+
+         all_questions_info = df.to_dict('records')
+         collectionname2filename = defaultdict(set)
+         for info in all_questions_info:
+             # store in a set to drop duplicate file names
+             collectionname2filename[info['知识库名']].add(info['文件名'])
+
+         for collection_name in tqdm(collectionname2filename):
+             all_file_paths = []
+             for file_name in collectionname2filename[collection_name]:
+                 file_path = os.path.join(self.origin_file_path, file_name)
+                 if not os.path.exists(file_path):
+                     raise Exception(f'{file_path} not exists.')
+                 # file path may be a directory or a single file
+                 if os.path.isdir(file_path):
+                     # a directory containing multiple files
+                     all_file_paths.extend(
+                         [os.path.join(file_path, name) for name in os.listdir(file_path) if not name.startswith('.')]
+                     )
+                 else:
+                     # a single file
+                     all_file_paths.append(file_path)
+
+             # all files to store for the current knowledge base
+             collection_name = f"{collection_name}_{self.params['retriever']['suffix']}"
+             for index, each_file_path in enumerate(all_file_paths):
+                 logger.info(f'each_file_path: {each_file_path}')
+                 loader = loader_object(
+                     file_name=os.path.basename(each_file_path), file_path=each_file_path, **loader_params
+                 )
+                 documents = loader.load()
+                 logger.info(f'documents: {len(documents)}')
+                 if len(documents[0].page_content) == 0:
+                     logger.error(f'{each_file_path} page_content is empty.')
+
+                 vector_drop_old = self.params['milvus']['drop_old'] if index == 0 else False
+                 keyword_drop_old = self.params['elasticsearch']['drop_old'] if index == 0 else False
+                 for idx, retriever in enumerate(self.retriever.retrievers):
+                     retriever.add_documents(documents, f"{collection_name}_{idx}", vector_drop_old)
+                     # retriever.add_documents(documents, collection_name, vector_drop_old)
+
+     def retrieval_and_rerank(self, question, collection_name):
+         """
+         retrieval and rerank
+         """
+         collection_name = f"{collection_name}_{self.params['retriever']['suffix']}"
+
+         # EnsembleRetriever deduplicates by default when retrieving directly
+         # docs = self.retriever.get_relevant_documents(query=question, collection_name=collection_name)
+         docs = []
+         for idx, retriever in enumerate(self.retriever.retrievers):
+             docs.extend(retriever.get_relevant_documents(query=question, collection_name=f"{collection_name}_{idx}"))
+             # docs.extend(retriever.get_relevant_documents(query=question, collection_name=collection_name))
+         logger.info(f'retrieval docs: {len(docs)}')
+
+         # delete duplicate
+         if self.params['post_retrieval']['delete_duplicate']:
+             logger.info(f'origin docs: {len(docs)}')
+             all_contents = []
+             docs_no_dup = []
+             for index, doc in enumerate(docs):
+                 doc_content = doc.page_content
+                 if doc_content in all_contents:
+                     continue
+                 all_contents.append(doc_content)
+                 docs_no_dup.append(doc)
+             docs = docs_no_dup
+             logger.info(f'delete duplicate docs: {len(docs)}')
+
+         # rerank
+         if self.params['post_retrieval']['with_rank'] and len(docs):
+             if not hasattr(self, 'ranker'):
+                 rerank_params = self.params['post_retrieval']['rerank']
+                 rerank_type = rerank_params.pop('type')
+                 rerank_object = import_class(f'bisheng_langchain.rag.rerank.{rerank_type}')
+                 self.ranker = rerank_object(**rerank_params)
+             docs = getattr(self, 'ranker').sort_and_filter(question, docs)
+
+         return docs
+
+     def load_documents(self, file_name, max_content=100000):
+         """
+         max_content: max content len of llm
+         """
+         file_path = os.path.join(self.origin_file_path, file_name)
+         if not os.path.exists(file_path):
+             raise Exception(f'{file_path} not exists.')
+         if os.path.isdir(file_path):
+             raise Exception(f'{file_path} is a directory.')
+
+         loader_params = copy.deepcopy(self.params['loader'])
+         loader_object = import_by_type(_type='documentloaders', name=loader_params.pop('type'))
+         loader = loader_object(file_name=file_name, file_path=file_path, **loader_params)
+
+         documents = loader.load()
+         logger.info(f'documents: {len(documents)}, page_content: {len(documents[0].page_content)}')
+         for doc in documents:
+             doc.page_content = doc.page_content[:max_content]
+         return documents
+
+     def question_answering(self):
+         """
+         question answer over knowledge
+         """
+         df = pd.read_excel(self.question_path)
+         all_questions_info = df.to_dict('records')
+         if 'prompt_type' in self.params['generate']:
+             prompt_type = self.params['generate']['prompt_type']
+             prompt = import_class(f'bisheng_langchain.rag.prompts.{prompt_type}')
+         else:
+             prompt = None
+         qa_chain = load_qa_chain(
+             llm=self.llm, chain_type=self.params['generate']['chain_type'], prompt=prompt, verbose=False
+         )
+         file2docs = dict()
+         for questions_info in tqdm(all_questions_info):
+             question = questions_info['问题']
+             file_name = questions_info['文件名']
+             collection_name = questions_info['知识库名']
+
+             if self.params['generate']['with_retrieval']:
+                 # retrieval and rerank
+                 docs = self.retrieval_and_rerank(question, collection_name)
+             else:
+                 # load document
+                 if file_name not in file2docs:
+                     docs = self.load_documents(file_name)
+                     file2docs[file_name] = docs
+                 else:
+                     docs = file2docs[file_name]
+
+             # question answer
+             try:
+                 ans = qa_chain({"input_documents": docs, "question": question}, return_only_outputs=False)
+             except Exception as e:
+                 logger.error(f'question: {question}\nerror: {e}')
+                 ans = {'output_text': str(e)}
+
+             # context = '\n\n'.join([doc.page_content for doc in docs])
+             # content = prompt.format(context=context, question=question)
+
+             # # for rate_limit
+             # time.sleep(15)
+
+             rag_answer = ans['output_text']
+             logger.info(f'question: {question}\nans: {rag_answer}\n')
+             questions_info['rag_answer'] = rag_answer
+             # questions_info['rag_context'] = '\n----------------\n'.join([doc.page_content for doc in docs])
+             # questions_info['rag_context'] = content
+
+         df = pd.DataFrame(all_questions_info)
+         df.to_excel(self.save_answer_path, index=False)
+
+     def score(self):
+         """
+         score
+         """
+         metric_params = self.params['metric']
+         if metric_params['type'] == 'bisheng-ragas':
+             score_params = {
+                 'excel_path': self.save_answer_path,
+                 'save_path': os.path.dirname(self.save_answer_path),
+                 'question_column': metric_params['question_column'],
+                 'gt_column': metric_params['gt_column'],
+                 'answer_column': metric_params['answer_column'],
+                 'query_type_column': metric_params.get('query_type_column', None),
+                 'contexts_column': metric_params.get('contexts_column', None),
+                 'metrics': metric_params['metrics'],
+                 'batch_size': metric_params['batch_size'],
+                 'gt_split_column': metric_params.get('gt_split_column', None),
+                 'whether_gtsplit': metric_params.get('whether_gtsplit', False),  # whether the model should split the ground truth into key points
+             }
+             rag_score = RagScore(**score_params)
+             rag_score.score()
+         else:
+             # todo: other scoring methods
+             pass
+
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser(description='Process some integers.')
+     # add arguments
+     parser.add_argument('--mode', type=str, default='qa', help='upload or qa or score')
+     parser.add_argument('--params', type=str, default='config/test/baseline_s2b.yaml', help='bisheng rag params')
+     # parse arguments
+     args = parser.parse_args()
+
+     rag = BishengRagPipeline(args.params)
+
+     if args.mode == 'upload':
+         rag.file2knowledge()
+     elif args.mode == 'qa':
+         rag.question_answering()
+     elif args.mode == 'score':
+         rag.score()
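For orientation, the three CLI modes map directly onto the class; a minimal sketch (the config path points at one of the YAML files added under `bisheng_langchain/rag/config/` in this release, and its key schema is defined by those files, not shown here):

```python
from bisheng_langchain.rag.bisheng_rag_pipeline import BishengRagPipeline

rag = BishengRagPipeline('bisheng_langchain/rag/config/baseline.yaml')
rag.file2knowledge()      # --mode upload: load files, split, index into Milvus/ES
rag.question_answering()  # --mode qa: retrieve (+ optional rerank), answer, save xlsx
rag.score()               # --mode score: bisheng-ragas metrics over the saved answers
```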