bisheng-langchain 0.3.1.1__py3-none-any.whl → 0.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bisheng_langchain/chains/__init__.py +4 -1
- bisheng_langchain/chains/qa_generation/__init__.py +0 -0
- bisheng_langchain/chains/qa_generation/base.py +128 -0
- bisheng_langchain/chains/qa_generation/base_v2.py +413 -0
- bisheng_langchain/chains/qa_generation/prompt.py +53 -0
- bisheng_langchain/chains/qa_generation/prompt_v2.py +155 -0
- bisheng_langchain/document_loaders/elem_unstrcutured_loader.py +36 -9
- bisheng_langchain/document_loaders/parsers/ellm_client.py +7 -9
- bisheng_langchain/document_loaders/universal_kv.py +4 -3
- bisheng_langchain/gpts/tools/api_tools/openapi.py +7 -7
- bisheng_langchain/rag/__init__.py +2 -0
- bisheng_langchain/rag/bisheng_rag_chain.py +164 -0
- bisheng_langchain/rag/bisheng_rag_pipeline_v2.py +8 -2
- bisheng_langchain/rag/bisheng_rag_tool.py +47 -24
- bisheng_langchain/rag/config/baseline_caibao_v2.yaml +1 -1
- bisheng_langchain/rag/config/baseline_v2.yaml +3 -2
- bisheng_langchain/rag/prompts/prompt.py +1 -1
- bisheng_langchain/rag/qa_corpus/qa_generator.py +1 -1
- bisheng_langchain/rag/scoring/ragas_score.py +2 -2
- bisheng_langchain/rag/utils.py +27 -4
- bisheng_langchain/sql/__init__.py +3 -0
- bisheng_langchain/sql/base.py +120 -0
- bisheng_langchain/text_splitter.py +1 -1
- {bisheng_langchain-0.3.1.1.dist-info → bisheng_langchain-0.3.2.dist-info}/METADATA +3 -1
- {bisheng_langchain-0.3.1.1.dist-info → bisheng_langchain-0.3.2.dist-info}/RECORD +27 -20
- bisheng_langchain/rag/bisheng_rag_pipeline_v2_cohere_raw_prompting.py +0 -376
- {bisheng_langchain-0.3.1.1.dist-info → bisheng_langchain-0.3.2.dist-info}/WHEEL +0 -0
- {bisheng_langchain-0.3.1.1.dist-info → bisheng_langchain-0.3.2.dist-info}/top_level.txt +0 -0
bisheng_langchain/chains/qa_generation/prompt_v2.py
@@ -0,0 +1,155 @@
+from langchain_core.prompts.chat import (
+    ChatPromptTemplate,
+    HumanMessagePromptTemplate,
+    SystemMessagePromptTemplate,
+)
+
+
+SEED_QUESTION_SYSTEM = SystemMessagePromptTemplate.from_template(
+    """\
+您的任务是遵循以下规则从给定的上下文中提出一个问题,规则如下:
+
+1. 即使在没有给定上下文的情况下,问题也应该对人类有意义。
+2. 应该可以从给定上下文中完全回答问题。
+3. 问题应该来自包含重要信息的上下文部分。它也可以来自表格、段落、或者代码等。
+4. 回答问题时不应包含任何链接。
+5. 问题的难度应该是中等的。
+6. 问题必须是合理的,并且必须能被人理解和回答。
+7. 不要在问题中使用“提供的上下文”等短语。
+8. 避免使用可以分解成多个问题的“和”字样来构建问题。
+9. 如果上下文是中文,那么问题也应该是中文的。
+
+Examples:
+context:武汉达梦数据库股份有限公司 招股说明书 (申报稿) 1-1-226 表中作出恰当列报。 2、研发费用 2021年度、 2020年度、 2019 年度,达梦数据 研发费用金额分别 为11,786.99 万元、 9,660.26 万元、 6,255.86万元, 各年度研发费用占营 业收入的比例分别为 15.86 % 、 21.46 %、20.74 %。 由于研发投入金额及其占当期 营业收入的比例是 达梦数据 的关键 指标之一,可能存在因为核算不准 确而导致的错报风险。因此, 中天 运会计师 将研发费用的归集和核算 确定为关键审计事项。 针对研发费用的真实性与准确性,会计师执行的 重要审计程序主要包括: (1)了解与研发费用相关的关键内部控制,评价 这些控制的设计,确定其是否得到执行,并对相关内 部控制的运行有效性进行测试; (2)获取研发项目立项、审批资料,抽查重要研 发项目的过程文档,判断研发项目的真实性; (3)获取研发费用按项目、性质分类明细表,分
+question:达梦2021年的研发费用占营业收入的比例是多少?
+
+context:武汉达梦数据库股份有限公司 招股说明书 (申报稿) 1-1-329 (2)存货周转率 公司与同行业可比公司存货周转率对比情况如下: 公司简称 2021年度 2020年度 2019年度 中望软件 6.93 5.62 10.66 星环科技 3.38 3.21 2.24 金山办公 212.60 175.46 162.91 平均值 74.30 61.43 58.60 本公司 1.13 0.57 0.87 数据来源:可比公司招股说明书、定期报告。 报告期各期, 公司存货周转率显著低于同行业可比公司存货周转率平均水平, 主要是因为公司将未验收的数据及行业解决方案项目所发生的累 计成本均作为 存货核算。报告期各期末,公司存在 “湖北省司法行政数据中心项目 ”、“政法云 大数据中心基础设施服务及大数据中心软件采购 项目”等金额较大且实施周期较 长的数据及行业解决方案项目,导致年末存货金额较大。
+question:达梦2021年的存货周转率相较于前一年有何变化?
+""" # noqa: E501
+)
+
+
+SEED_QUESTION_HUMAN = HumanMessagePromptTemplate.from_template(
+    """
+context:{context}
+question:
+"""
+)
+
+
+SEED_QUESTION_CHAT_PROMPT = ChatPromptTemplate.from_messages(
+    [
+        SEED_QUESTION_SYSTEM,
+        SEED_QUESTION_HUMAN
+    ]
+)
+
+
+SCORE_CONTEXT_SYSTEM = SystemMessagePromptTemplate.from_template(
+    """Evaluate the provided context and assign a numerical score between 0 and 10 based on the following criteria:
+1. Award a high score to context that thoroughly delves into and explains concepts.
+2. Assign a lower score to context that contains excessive references, acknowledgments, external links, personal information, or other non-essential elements.
+
+And you should only output the score.
+
+Examples:
+Context:
+01-2022.04.30 贷方发生额共 计 2535.43 万元,户名;X 贸易有限公司;\n③根据用款企业提供的增值税纳税申报表来看,2021 年度用款企业年累计开票额为\n7826.48 万元,年累计应纳税合计 95.32 万元,年累计已纳税额 86.23 万元;截止至 2022 年 3 月,用款企业累计开票额为 1986.54 万元,累计应纳税合计19.54 万元,累计已纳税额\n20.23 万元。\n根据核算用款企业的银行流水及企业会计记账系统,剔除借款人往来转账款,估算用款 企业年营业额约在 6000 万元左右(纳税申报营业额),全部营业收入约 20000 万元左右,借 款人所在 X 贸易有限公司综合毛
+利润率约为 35%,净利润约 20%左右。\n\n| 资产种类 | 坐落 | 产权人 | 建筑面积 | 现价值 | 贷款余额 | 资产净值 |\n| --- | --- | --- | --- | --- | --- | --- |\n| 房产 | HN 省 YY 市 PP 小区 5#2-101 | A | 240.20 | 365.23 万 | 165.
+Score: 4
+
+Context:
+认缴出资额 200 万元 实缴出资额 200 万元 持股比例 20% |\n| 企业所属商圈 | 无 | 是否为已准入商圈 | 是□ 否 ☑ |\n(1) 企业经营历史及现状说明\nX 贸易有限公司 (下称“用款企业”) 注册成立于 2015 年 11 月,统一社会信用代码1234567890ACBDEFGH,法定代表人 A,公司注册地址位于 M 市 N 区 JF 路 20 号 NJ 大厦 18 楼1807 室,实际办公地址位于 M 市 N 区 K 广场 C 座 19 楼 1901、1906、1908、1910、1912、1914,办公面积为 880.51 ㎡,经营场所为用款企业租赁房产,租赁期限,现阶段年租金 73 万余元。\n用款企业是著名品牌“XYZ”的运营公司,是
+以经营短袜、连裤袜、 内衣、家居服、配饰为主要品类的亲体织物公司,致力于为年轻消费群体提供“一站式”多品类亲体织物购物 体验。 作为织物文化的传播者和输出者,用款企业秉承一贯的高品质与原创精神,依托中国 研发团队,创领多项核心技术,不断建立并升级健康织物行业标准,目前拥有实用新型专利 6 项,发明专利 1 项,注册商标 30 余个,为品牌的商标保护构建了全面的商标防御体系。\n“XYZ”品牌创立于 2006 年,于 2009 年正式进入中国市场,在成立 10 年的时间里,在全国共有 400 余家店面,运营主要有以下三种模式:\n①直营模式:目前用款企业
+管控的直营店有 100 家左右,其中在 M 地区共有 9 家直营店,分别为 Y1 店、Y2 店、Y3 店、Y4 店、Y5 店、Y6 店、Y7 店、Y8 店、Y9 店。经查看用款企业相关财务系统并截屏 ,用款企业 2021 年度 、2022 年 1-4 月直营店营业收入合计分别为7623.45 万元、1987.23 万元,M 地区 9 家直营店收入合计分别为 1238.67 万元、302.54 万元。根据数据测算直营部分毛利润率65%。
+Score: 7
+""" # noqa: E501
+)
+
+
+SCORE_CONTEXT_HUMAN = HumanMessagePromptTemplate.from_template(
+    """
+Context:
+{context}
+Score:
+""" # noqa: E501
+)
+
+
+SCORE_CONTEXT_CHAT_PROMPT = ChatPromptTemplate.from_messages(
+    [
+        SCORE_CONTEXT_SYSTEM,
+        SCORE_CONTEXT_HUMAN
+    ]
+)
+
+
+FILTER_QUESTION_SYSTEM = SystemMessagePromptTemplate.from_template(
+    """\
+Determine if the given question can be clearly understood even when presented without any additional context. Specify reason and verdict is a valid json format.
+
+Examples:
+question: What is the discovery about space?
+{{
+    "reason":"The question is too vague and does not specify which discovery about space it is referring to."
+    "verdit":"No"
+}}
+
+question: What caused the Great Depression?
+{{
+    "reason":"The question is specific and refers to a well-known historical economic event, making it clear and answerable.",
+    "verdict":"Yes"
+}}
+
+question: What is the keyword that best describes the paper's focus in natural language understanding tasks?
+{{
+    "reason": "The question mentions a 'paper' in it without referring it's name which makes it unclear without it",
+    "verdict": "No"
+}}
+
+question: Who wrote 'Romeo and Juliet'?
+{{
+    "reason": "The question is clear and refers to a specific work by name therefore it is clear",
+    "verdict": "Yes"
+}}
+
+question: What did the study mention?
+{{
+    "reason": "The question is vague and does not specify which study it is referring to",
+    "verdict": "No"
+}}
+
+question: What is the focus of the REPLUG paper?
+{{
+    "reason": "The question refers to a specific work by it's name hence can be understood",
+    "verdict": "Yes"
+}}
+""" # noqa: E501
+)
+
+
+FILTER_QUESTION_HUMAN = HumanMessagePromptTemplate.from_template(
+    """\
+question:{question}
+""" # noqa: E501
+)
+
+
+FILTER_QUESTION_CHAT_PROMPT = ChatPromptTemplate.from_messages(
+    [
+        FILTER_QUESTION_SYSTEM,
+        FILTER_QUESTION_HUMAN
+    ]
+)
+
+
+ANSWER_FORMULATE = HumanMessagePromptTemplate.from_template(
+    """\
+Answer the question using the information from the given context.
+
+context:{context}
+
+question:{question}
+answer:
+""" # noqa: E501
+)
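The new prompt_v2.py ports a ragas-style QA-corpus generation flow: seed a question from a context chunk, score the chunk, filter the question for standalone clarity, and formulate an answer. A minimal sketch of driving two of these templates with any LangChain chat model (the ChatOpenAI choice and the sample chunk are placeholders, not part of the diff):

```python
from langchain_openai import ChatOpenAI

from bisheng_langchain.chains.qa_generation.prompt_v2 import (
    FILTER_QUESTION_CHAT_PROMPT,
    SEED_QUESTION_CHAT_PROMPT,
)

llm = ChatOpenAI(model='gpt-4')  # placeholder model choice
chunk = '...'                    # a document chunk selected for QA generation

# Seed a candidate question from the chunk.
question = llm.invoke(SEED_QUESTION_CHAT_PROMPT.format_messages(context=chunk)).content

# Filter: the model replies with a JSON object holding "reason" and a verdict.
verdict = llm.invoke(FILTER_QUESTION_CHAT_PROMPT.format_messages(question=question)).content
```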
bisheng_langchain/document_loaders/elem_unstrcutured_loader.py
@@ -128,16 +128,43 @@ class ElemUnstructuredLoaderV0(BasePDFLoader):
         super().__init__(file_path)
 
     def load(self) -> List[Document]:
+        page_content, metadata = self.get_text_metadata()
+        doc = Document(page_content=page_content, metadata=metadata)
+        return [doc]
+
+    def get_text_metadata(self):
         b64_data = base64.b64encode(open(self.file_path, 'rb').read()).decode()
         payload = dict(filename=os.path.basename(self.file_name), b64_data=[b64_data], mode='text')
         payload.update({'start': self.start, 'n': self.n})
         payload.update(self.extra_kwargs)
-        resp = requests.post(self.unstructured_api_url, headers=self.headers, json=payload)
-
-        if 200
-
-
-
-
-
-
+        resp = requests.post(self.unstructured_api_url, headers=self.headers, json=payload)
+        # 说明文件解析成功
+        if resp.status_code == 200 and resp.json().get('status_code') == 200:
+            res = resp.json()
+            return res['text'], {'source': self.file_name}
+        # 说明文件解析失败,pdf文件直接返回报错
+        if self.file_name.endswith('.pdf'):
+            raise Exception(f'file text {os.path.basename(self.file_name)} failed resp={resp.text}')
+        # 非pdf文件,先将文件转为pdf格式,让后再执行partition模式解析文档
+        # 把文件转为pdf
+        resp = requests.post(self.unstructured_api_url, headers=self.headers, json={
+            'filename': os.path.basename(self.file_name),
+            'b64_data': [b64_data],
+            'mode': 'topdf',
+        })
+        if resp.status_code != 200 or resp.json().get('status_code') != 200:
+            raise Exception(f'file topdf {os.path.basename(self.file_name)} failed resp={resp.text}')
+        # 解析pdf文件
+        payload['mode'] = 'partition'
+        payload['b64_data'] = [resp.json()['b64_pdf']]
+        payload['filename'] = os.path.basename(self.file_name) + '.pdf'
+        resp = requests.post(self.unstructured_api_url, headers=self.headers, json=payload)
+        if resp.status_code != 200 or resp.json().get('status_code') != 200:
+            raise Exception(f'file partition {os.path.basename(self.file_name)} failed resp={resp.text}')
+        res = resp.json()
+        partitions = res['partitions']
+        if not partitions:
+            raise Exception(f'file partition empty {os.path.basename(self.file_name)} resp={resp.text}')
+        # 拼接结果为文本
+        content, _ = merge_partitions(partitions)
+        return content, {'source': self.file_name}
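load() now delegates to get_text_metadata() and gains a fallback path: when a non-PDF file fails the plain 'text' parse, it is converted via a 'topdf' request and re-parsed in 'partition' mode, so one call always yields a single Document. A hedged usage sketch (the endpoint URL is a placeholder, and the constructor is abbreviated to arguments visible in this diff):

```python
from bisheng_langchain.document_loaders import ElemUnstructuredLoaderV0

# Placeholder ETL endpoint; the loader POSTs base64-encoded file bytes to it.
loader = ElemUnstructuredLoaderV0(
    file_path='./report.docx',
    unstructured_api_url='http://127.0.0.1:10001/v1/etl4llm/predict',
)
docs = loader.load()               # a single Document after this change
print(docs[0].metadata['source'])  # points back at the input file
```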
bisheng_langchain/document_loaders/parsers/ellm_client.py
@@ -1,13 +1,13 @@
 # import base64
 import copy
-import
+from typing import Optional
+
 import requests
-from typing import Any, Iterator, List, Mapping, Optional, Union
 
 
 class ELLMClient(object):
-
-
+
+    def __init__(self, api_base_url: Optional[str] = None):
         self.ep = api_base_url
         self.client = requests.Session()
         self.timeout = 10000
@@ -26,8 +26,8 @@ class ELLMClient(object):
             'ellm': 'ELLM'
         },
         'form': {
-            'det': '
-            'recog': '
+            'det': 'general_text_det_v2.0',
+            'recog': 'general_text_reg_nb_v1.0_faster',
             'ellm': 'ELLM'
         },
         'hand': {
@@ -48,9 +48,7 @@ class ELLMClient(object):
         req_data = {'data': [b64_image], 'param': params}
 
         try:
-            r = self.client.post(url=self.ep,
-                                 json=req_data,
-                                 timeout=self.timeout)
+            r = self.client.post(url=self.ep, json=req_data, timeout=self.timeout)
             return r.json()
         except Exception as e:
             return {'status_code': 400, 'status_message': str(e)}
bisheng_langchain/document_loaders/universal_kv.py
@@ -47,6 +47,7 @@ def transpdf2png(pdf_file):
 class UniversalKVLoader(BaseLoader):
     """Extract key-value from pdf or image.
     """
+
    def __init__(self,
                 file_path: str,
                 ellm_model_url: str = None,
@@ -83,7 +84,7 @@ class UniversalKVLoader(BaseLoader):
 
         kv_results = defaultdict(list)
         for key, value in key_values.items():
-            kv_results[key]
+            kv_results[key].extend([v['text'] for v in value])
 
         content = json.dumps(kv_results, indent=2, ensure_ascii=False)
         file_name = os.path.basename(self.file_path)
@@ -95,7 +96,7 @@ class UniversalKVLoader(BaseLoader):
             pdf_images = transpdf2png(self.file_path)
 
             kv_results = defaultdict(list)
-            for pdf_name in pdf_images:
+            for index, pdf_name in enumerate(pdf_images):
                 page = int(pdf_name.split('page_')[-1])
                 if page > self.max_pages:
                     continue
@@ -110,7 +111,7 @@ class UniversalKVLoader(BaseLoader):
                     raise ValueError(f'universal kv load failed: {resp}')
 
                 for key, value in key_values.items():
-                    kv_results[key].extend(
+                    kv_results[key].extend([v['text'] for v in value])
 
                 content = json.dumps(kv_results, indent=2, ensure_ascii=False)
                 file_name = os.path.basename(self.file_path)
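Both load branches now flatten each recognized value object down to its text field before serializing. The aggregation in isolation, with a made-up key_values payload:

```python
import json
from collections import defaultdict

# Hypothetical ELLM output: each key maps to a list of {'text': ...} hits.
key_values = {
    '发票号码': [{'text': 'No.12345678'}],
    '合计金额': [{'text': '100.00'}, {'text': '100.00'}],
}

kv_results = defaultdict(list)
for key, value in key_values.items():
    kv_results[key].extend([v['text'] for v in value])

# Keys now map to plain lists of strings, ready for serialization.
print(json.dumps(kv_results, indent=2, ensure_ascii=False))
```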
bisheng_langchain/gpts/tools/api_tools/openapi.py
@@ -13,7 +13,7 @@ class OpenApiTools(APIToolBase):
         return self.url + self.params["path"]
 
     def get_request_method(self):
-        return self.params["method"]
+        return self.params["method"].lower()
 
     def get_params_json(self, **kwargs):
         params_define = {}
@@ -59,11 +59,11 @@ class OpenApiTools(APIToolBase):
         if method == "get":
             resp = self.client.get(path, params=params)
         elif method == 'post':
-            resp = self.client.post(path, params=params, json=
+            resp = self.client.post(path, params=params, json=json_data)
         elif method == 'put':
-            resp = self.client.put(path, params=params, json=
+            resp = self.client.put(path, params=params, json=json_data)
         elif method == 'delete':
-            resp = self.client.delete(path, params=params, json=
+            resp = self.client.delete(path, params=params, json=json_data)
         else:
             raise Exception(f"http method is not support: {method}")
         if resp.status_code != 200:
@@ -81,11 +81,11 @@ class OpenApiTools(APIToolBase):
         if method == "get":
             resp = await self.async_client.aget(path, params=params)
         elif method == 'post':
-            resp = await self.async_client.apost(path, params=params, json=
+            resp = await self.async_client.apost(path, params=params, json=json_data)
         elif method == 'put':
-            resp = await self.async_client.aput(path, params=params, json=
+            resp = await self.async_client.aput(path, params=params, json=json_data)
         elif method == 'delete':
-            resp = await self.async_client.adelete(path, params=params, json=
+            resp = await self.async_client.adelete(path, params=params, json=json_data)
         else:
             raise Exception(f"http method is not support: {method}")
         return resp
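get_request_method() now lower-cases the verb, which is what makes the dispatch above reachable when an OpenAPI spec records methods in uppercase; previously 'POST' would fall through to the unsupported-method branch. A toy reproduction of that dispatch:

```python
def dispatch(method: str) -> str:
    # Mirrors the if/elif chain in the run path above.
    if method in ('get', 'post', 'put', 'delete'):
        return method
    raise Exception(f'http method is not support: {method}')

print(dispatch('POST'.lower()))  # 'post' — fine after the change
# dispatch('POST')               # pre-change behavior: raises
```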
bisheng_langchain/rag/bisheng_rag_chain.py
@@ -0,0 +1,164 @@
+"""Chain for question-answering against a vector database."""
+from __future__ import annotations
+
+import inspect
+from abc import abstractmethod
+from typing import Any, Dict, List, Optional
+
+from langchain_core.callbacks import (
+    AsyncCallbackManagerForChainRun,
+    CallbackManagerForChainRun,
+    Callbacks
+)
+from langchain_core.prompts import PromptTemplate, BasePromptTemplate, ChatPromptTemplate, HumanMessagePromptTemplate, SystemMessagePromptTemplate
+from langchain_core.language_models import BaseLanguageModel
+from langchain_core.pydantic_v1 import Extra, Field
+from bisheng_langchain.vectorstores import ElasticKeywordsSearch, Milvus
+
+from langchain.chains.base import Chain
+from .bisheng_rag_tool import BishengRAGTool
+
+
+# system_template = """Use the following pieces of context to answer the user's question.
+# If you don't know the answer, just say that you don't know, don't try to make up an answer.
+# ----------------
+# {context}"""
+# messages = [
+#     SystemMessagePromptTemplate.from_template(system_template),
+#     HumanMessagePromptTemplate.from_template("{question}"),
+# ]
+# DEFAULT_QA_PROMPT = ChatPromptTemplate.from_messages(messages)
+
+
+system_template_general = """你是一个准确且可靠的知识库问答助手,能够借助上下文知识回答问题。你需要根据以下的规则来回答问题:
+1. 如果上下文中包含了正确答案,你需要根据上下文进行准确的回答。但是在回答前,你需要注意,上下文中的信息可能存在事实性错误,如果文档中存在和事实不一致的错误,请根据事实回答。
+2. 如果上下文中不包含答案,就说你不知道,不要试图编造答案。
+3. 你需要根据上下文给出详细的回答,不要试图偷懒,不要遗漏括号中的信息,你必须回答的尽可能详细。
+"""
+human_template_general = """
+上下文:
+{context}
+
+问题:
+{question}
+"""
+messages_general = [
+    SystemMessagePromptTemplate.from_template(system_template_general),
+    HumanMessagePromptTemplate.from_template(human_template_general),
+]
+DEFAULT_QA_PROMPT = ChatPromptTemplate.from_messages(messages_general)
+
+
+class BishengRetrievalQA(Chain):
+    """Base class for question-answering chains."""
+
+    """Chain to use to combine the documents."""
+    input_key: str = "query"  #: :meta private:
+    output_key: str = "result"  #: :meta private:
+    return_source_documents: bool = False
+    """Return the source documents or not."""
+    bisheng_rag_tool: BishengRAGTool = Field(
+        default_factory=BishengRAGTool, description="RAG tool"
+    )
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        extra = Extra.forbid
+        arbitrary_types_allowed = True
+        allow_population_by_field_name = True
+
+    @property
+    def input_keys(self) -> List[str]:
+        """Input keys.
+
+        :meta private:
+        """
+        return [self.input_key]
+
+    @property
+    def output_keys(self) -> List[str]:
+        """Output keys.
+
+        :meta private:
+        """
+        _output_keys = [self.output_key]
+        if self.return_source_documents:
+            _output_keys = _output_keys + ["source_documents"]
+        return _output_keys
+
+    @classmethod
+    def from_llm(
+        cls,
+        llm: BaseLanguageModel,
+        vector_store: Milvus,
+        keyword_store: ElasticKeywordsSearch,
+        QA_PROMPT: ChatPromptTemplate = DEFAULT_QA_PROMPT,
+        max_content: int = 15000,
+        sort_by_source_and_index: bool = False,
+        callbacks: Callbacks = None,
+        **kwargs: Any,
+    ) -> BishengRetrievalQA:
+        bisheng_rag_tool = BishengRAGTool(
+            vector_store=vector_store,
+            keyword_store=keyword_store,
+            llm=llm,
+            QA_PROMPT=QA_PROMPT,
+            max_content=max_content,
+            sort_by_source_and_index=sort_by_source_and_index,
+            **kwargs
+        )
+        return cls(
+            bisheng_rag_tool=bisheng_rag_tool,
+            callbacks=callbacks,
+            **kwargs,
+        )
+
+    def _call(
+        self,
+        inputs: Dict[str, Any],
+        run_manager: Optional[CallbackManagerForChainRun] = None,
+    ) -> Dict[str, Any]:
+        """Run get_relevant_text and llm on input query.
+
+        If chain has 'return_source_documents' as 'True', returns
+        the retrieved documents as well under the key 'source_documents'.
+
+        Example:
+        .. code-block:: python
+
+            res = indexqa({'query': 'This is my query'})
+            answer, docs = res['result'], res['source_documents']
+        """
+        question = inputs[self.input_key]
+        if self.return_source_documents:
+            answer, docs = self.bisheng_rag_tool.run(question, return_only_outputs=False)
+            return {self.output_key: answer, "source_documents": docs}
+        else:
+            answer = self.bisheng_rag_tool.run(question, return_only_outputs=True)
+            return {self.output_key: answer}
+
+    async def _acall(
+        self,
+        inputs: Dict[str, Any],
+        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
+    ) -> Dict[str, Any]:
+        """Run get_relevant_text and llm on input query.
+
+        If chain has 'return_source_documents' as 'True', returns
+        the retrieved documents as well under the key 'source_documents'.
+
+        Example:
+        .. code-block:: python
+
+            res = indexqa({'query': 'This is my query'})
+            answer, docs = res['result'], res['source_documents']
+        """
+        question = inputs[self.input_key]
+
+        if self.return_source_documents:
+            answer, docs = await self.bisheng_rag_tool.arun(question, return_only_outputs=False)
+            return {self.output_key: answer, "source_documents": docs}
+        else:
+            answer = await self.bisheng_rag_tool.arun(question, return_only_outputs=True)
+            return {self.output_key: answer}
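BishengRetrievalQA wraps BishengRAGTool in a standard LangChain Chain interface, so RetrievalQA-style call sites keep working. A hedged construction sketch, assuming the chain is exported from bisheng_langchain.rag (its __init__.py gained two lines in this release) and with placeholder store arguments:

```python
from langchain_openai import ChatOpenAI, OpenAIEmbeddings

from bisheng_langchain.rag import BishengRetrievalQA
from bisheng_langchain.vectorstores import ElasticKeywordsSearch, Milvus

vector_store = Milvus(
    embedding_function=OpenAIEmbeddings(),  # placeholder embedding
    collection_name='rag_demo',             # placeholder collection
)
keyword_store = ElasticKeywordsSearch(
    index_name='rag_demo',                      # placeholder index
    elasticsearch_url='http://127.0.0.1:9200',  # placeholder endpoint
)
qa = BishengRetrievalQA.from_llm(
    llm=ChatOpenAI(model='gpt-4'),
    vector_store=vector_store,
    keyword_store=keyword_store,
    return_source_documents=True,
)
res = qa({'query': 'This is my query'})
answer, docs = res['result'], res['source_documents']
```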
bisheng_langchain/rag/bisheng_rag_pipeline_v2.py
@@ -44,7 +44,9 @@ class BishengRagPipeline:
         if embedding_params['type'] == 'OpenAIEmbeddings' and embedding_params['openai_proxy']:
             embedding_params.pop('type')
             self.embeddings = embedding_object(
-                http_client=httpx.Client(proxies=embedding_params['openai_proxy']),
+                http_client=httpx.Client(proxies=embedding_params['openai_proxy']),
+                http_async_client=httpx.AsyncClient(proxies=embedding_params['openai_proxy']),
+                **embedding_params
             )
         else:
             embedding_params.pop('type')
@@ -55,7 +57,11 @@ class BishengRagPipeline:
         llm_object = import_by_type(_type='llms', name=llm_params['type'])
         if llm_params['type'] == 'ChatOpenAI' and llm_params['openai_proxy']:
             llm_params.pop('type')
-            self.llm = llm_object(
+            self.llm = llm_object(
+                http_client=httpx.Client(proxies=llm_params['openai_proxy']),
+                http_async_client=httpx.AsyncClient(proxies=llm_params['openai_proxy']),
+                **llm_params
+            )
         else:
             llm_params.pop('type')
             self.llm = llm_object(**llm_params)
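Both hunks close the same gap: only a sync httpx client was being handed to the OpenAI wrappers, so async embedding and chat calls bypassed openai_proxy. The standalone equivalent of the new wiring (proxy URL and model are placeholders; the proxies keyword matches the httpx versions this package targets):

```python
import httpx
from langchain_openai import ChatOpenAI

proxy = 'http://127.0.0.1:7890'  # placeholder proxy address
llm = ChatOpenAI(
    model='gpt-4',                                       # placeholder model
    http_client=httpx.Client(proxies=proxy),             # sync requests
    http_async_client=httpx.AsyncClient(proxies=proxy),  # async requests, the newly wired path
)
```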
bisheng_langchain/rag/bisheng_rag_tool.py
@@ -3,10 +3,14 @@ import os
 import yaml
 import httpx
 from typing import Any, Dict, Tuple, Type, Union, Optional
+
+from langchain_core.vectorstores import VectorStoreRetriever
 from loguru import logger
 from langchain_core.tools import BaseTool, Tool
 from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator
 from langchain_core.language_models.base import LanguageModelLike
+from langchain_core.prompts import ChatPromptTemplate
+from langchain.chains.llm import LLMChain
 from langchain.chains.question_answering import load_qa_chain
 from bisheng_langchain.retrievers import EnsembleRetriever
 from bisheng_langchain.vectorstores import ElasticKeywordsSearch, Milvus
@@ -39,6 +43,7 @@ class BishengRAGTool:
         keyword_store: Optional[ElasticKeywordsSearch] = None,
         llm: Optional[LanguageModelLike] = None,
         collection_name: Optional[str] = None,
+        QA_PROMPT: Optional[ChatPromptTemplate] = None,
         **kwargs
     ) -> None:
         if collection_name is None and (keyword_store is None or vector_store is None):
@@ -54,10 +59,27 @@ class BishengRAGTool:
         sort_by_source_and_index = kwargs.get("sort_by_source_and_index", True)
         self.params['generate']['max_content'] = max_content
         self.params['post_retrieval']['sort_by_source_and_index'] = sort_by_source_and_index
+
+        # init llm
+        if llm:
+            self.llm = llm
+        else:
+            llm_params = self.params['chat_llm']
+            llm_object = import_by_type(_type='llms', name=llm_params['type'])
+            if llm_params['type'] == 'ChatOpenAI' and llm_params['openai_proxy']:
+                llm_params.pop('type')
+                self.llm = llm_object(http_client=httpx.Client(proxies=llm_params['openai_proxy']), **llm_params)
+            else:
+                llm_params.pop('type')
+                self.llm = llm_object(**llm_params)
 
         # init milvus
         if vector_store:
-
+            # if vector_store is retriever, get vector_store instance
+            if isinstance(vector_store, VectorStoreRetriever):
+                self.vector_store = vector_store.vectorstore
+            else:
+                self.vector_store = vector_store
         else:
             # init embeddings
             embedding_params = self.params['embedding']
@@ -83,24 +105,17 @@ class BishengRAGTool:
         if keyword_store:
             self.keyword_store = keyword_store
         else:
+            if self.params['elasticsearch'].get('extract_key_by_llm', False):
+                extract_key_prompt = import_class(f'bisheng_langchain.rag.prompts.EXTRACT_KEY_PROMPT')
+                llm_chain = LLMChain(llm=self.llm, prompt=extract_key_prompt)
+            else:
+                llm_chain = None
             self.keyword_store = ElasticKeywordsSearch(
                 index_name='default_es',
                 elasticsearch_url=self.params['elasticsearch']['url'],
                 ssl_verify=self.params['elasticsearch']['ssl_verify'],
+                llm_chain=llm_chain,
             )
-
-        # init llm
-        if llm:
-            self.llm = llm
-        else:
-            llm_params = self.params['chat_llm']
-            llm_object = import_by_type(_type='llms', name=llm_params['type'])
-            if llm_params['type'] == 'ChatOpenAI' and llm_params['openai_proxy']:
-                llm_params.pop('type')
-                self.llm = llm_object(http_client=httpx.Client(proxies=llm_params['openai_proxy']), **llm_params)
-            else:
-                llm_params.pop('type')
-                self.llm = llm_object(**llm_params)
 
         # init retriever
         retriever_list = []
@@ -117,11 +132,14 @@ class BishengRAGTool:
         self.retriever = EnsembleRetriever(retrievers=retriever_list)
 
         # init qa chain
-        if
-
-        prompt = import_class(f'bisheng_langchain.rag.prompts.{prompt_type}')
+        if QA_PROMPT:
+            prompt = QA_PROMPT
         else:
-
+            if 'prompt_type' in self.params['generate']:
+                prompt_type = self.params['generate']['prompt_type']
+                prompt = import_class(f'bisheng_langchain.rag.prompts.{prompt_type}')
+            else:
+                prompt = None
         self.qa_chain = load_qa_chain(
             llm=self.llm,
             chain_type=self.params['generate']['chain_type'],
@@ -218,18 +236,23 @@ class BishengRAGTool:
         docs = sorted(docs, key=lambda x: (x.metadata['source'], x.metadata['chunk_index']))
         return docs
 
-    def run(self, query) ->
+    def run(self, query, return_only_outputs=True) -> Any:
         docs = self.retrieval_and_rerank(query)
         try:
-            ans = self.qa_chain({"input_documents": docs, "question": query}, return_only_outputs=
+            ans = self.qa_chain({"input_documents": docs, "question": query}, return_only_outputs=return_only_outputs)
         except Exception as e:
             logger.error(f'question: {query}\nerror: {e}')
             ans = {'output_text': str(e)}
-
-
+        if return_only_outputs:
+            rag_answer = ans['output_text']
+            return rag_answer
+        else:
+            rag_answer = ans['output_text']
+            input_documents = ans['input_documents']
+            return rag_answer, input_documents
 
-    async def arun(self, query: str) -> str:
-        rag_answer = self.run(query)
+    async def arun(self, query: str, return_only_outputs=True) -> str:
+        rag_answer = self.run(query, return_only_outputs)
         return rag_answer
 
     @classmethod
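run() and arun() now accept return_only_outputs, which is the hook BishengRetrievalQA relies on to surface source documents. A hedged sketch, assuming BishengRAGTool is exported from bisheng_langchain.rag and that the YAML config supplies the stores and LLM when only a collection name is given:

```python
from bisheng_langchain.rag import BishengRAGTool

tool = BishengRAGTool(collection_name='rag_demo')  # placeholder collection

answer = tool.run('达梦2021年的研发费用占营业收入的比例是多少?')
answer, docs = tool.run(
    '达梦2021年的研发费用占营业收入的比例是多少?',
    return_only_outputs=False,  # also returns the documents fed to the QA chain
)
```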