bisheng-langchain 0.3.0rc0__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bisheng_langchain/chat_models/host_llm.py +1 -1
- bisheng_langchain/document_loaders/elem_unstrcutured_loader.py +5 -3
- bisheng_langchain/gpts/agent_types/llm_functions_agent.py +7 -1
- bisheng_langchain/gpts/assistant.py +8 -5
- bisheng_langchain/gpts/auto_optimization.py +28 -27
- bisheng_langchain/gpts/auto_tool_selected.py +14 -15
- bisheng_langchain/gpts/load_tools.py +53 -1
- bisheng_langchain/gpts/prompts/__init__.py +4 -2
- bisheng_langchain/gpts/prompts/assistant_prompt_base.py +1 -0
- bisheng_langchain/gpts/prompts/assistant_prompt_cohere.py +19 -0
- bisheng_langchain/gpts/prompts/opening_dialog_prompt.py +1 -1
- bisheng_langchain/gpts/tools/api_tools/__init__.py +1 -1
- bisheng_langchain/gpts/tools/api_tools/base.py +3 -3
- bisheng_langchain/gpts/tools/api_tools/flow.py +19 -7
- bisheng_langchain/gpts/tools/api_tools/macro_data.py +175 -4
- bisheng_langchain/gpts/tools/api_tools/openapi.py +101 -0
- bisheng_langchain/gpts/tools/api_tools/sina.py +2 -2
- bisheng_langchain/gpts/tools/code_interpreter/tool.py +118 -39
- bisheng_langchain/rag/__init__.py +5 -0
- bisheng_langchain/rag/bisheng_rag_pipeline.py +320 -0
- bisheng_langchain/rag/bisheng_rag_pipeline_v2.py +359 -0
- bisheng_langchain/rag/bisheng_rag_pipeline_v2_cohere_raw_prompting.py +376 -0
- bisheng_langchain/rag/bisheng_rag_tool.py +288 -0
- bisheng_langchain/rag/config/baseline.yaml +86 -0
- bisheng_langchain/rag/config/baseline_caibao.yaml +82 -0
- bisheng_langchain/rag/config/baseline_caibao_knowledge_v2.yaml +110 -0
- bisheng_langchain/rag/config/baseline_caibao_v2.yaml +112 -0
- bisheng_langchain/rag/config/baseline_demo_v2.yaml +92 -0
- bisheng_langchain/rag/config/baseline_s2b_mix.yaml +88 -0
- bisheng_langchain/rag/config/baseline_v2.yaml +90 -0
- bisheng_langchain/rag/extract_info.py +38 -0
- bisheng_langchain/rag/init_retrievers/__init__.py +4 -0
- bisheng_langchain/rag/init_retrievers/baseline_vector_retriever.py +61 -0
- bisheng_langchain/rag/init_retrievers/keyword_retriever.py +65 -0
- bisheng_langchain/rag/init_retrievers/mix_retriever.py +103 -0
- bisheng_langchain/rag/init_retrievers/smaller_chunks_retriever.py +92 -0
- bisheng_langchain/rag/prompts/__init__.py +9 -0
- bisheng_langchain/rag/prompts/extract_key_prompt.py +34 -0
- bisheng_langchain/rag/prompts/prompt.py +47 -0
- bisheng_langchain/rag/prompts/prompt_cohere.py +111 -0
- bisheng_langchain/rag/qa_corpus/__init__.py +0 -0
- bisheng_langchain/rag/qa_corpus/qa_generator.py +143 -0
- bisheng_langchain/rag/rerank/__init__.py +5 -0
- bisheng_langchain/rag/rerank/rerank.py +48 -0
- bisheng_langchain/rag/rerank/rerank_benchmark.py +139 -0
- bisheng_langchain/rag/run_qa_gen_web.py +47 -0
- bisheng_langchain/rag/run_rag_evaluate_web.py +55 -0
- bisheng_langchain/rag/scoring/__init__.py +0 -0
- bisheng_langchain/rag/scoring/llama_index_score.py +91 -0
- bisheng_langchain/rag/scoring/ragas_score.py +183 -0
- bisheng_langchain/rag/utils.py +181 -0
- bisheng_langchain/retrievers/ensemble.py +2 -1
- bisheng_langchain/vectorstores/elastic_keywords_search.py +2 -1
- {bisheng_langchain-0.3.0rc0.dist-info → bisheng_langchain-0.3.1.dist-info}/METADATA +1 -1
- {bisheng_langchain-0.3.0rc0.dist-info → bisheng_langchain-0.3.1.dist-info}/RECORD +57 -22
- bisheng_langchain/gpts/prompts/base_prompt.py +0 -1
- {bisheng_langchain-0.3.0rc0.dist-info → bisheng_langchain-0.3.1.dist-info}/WHEEL +0 -0
- {bisheng_langchain-0.3.0rc0.dist-info → bisheng_langchain-0.3.1.dist-info}/top_level.txt +0 -0
@@ -163,7 +163,7 @@ class BaseHostChatLLM(BaseChatModel):
|
|
163
163
|
values[
|
164
164
|
'host_base_url'] = f"{values['host_base_url']}/{values['model_name']}/infer"
|
165
165
|
except Exception:
|
166
|
-
raise Exception(f'Update Decoupled status
|
166
|
+
raise Exception(f'Update Decoupled status failed for model {model}')
|
167
167
|
|
168
168
|
try:
|
169
169
|
if values['headers']:
|
@@ -85,16 +85,18 @@ class ElemUnstructuredLoader(BasePDFLoader):
|
|
85
85
|
mode='partition',
|
86
86
|
parameters=parameters)
|
87
87
|
|
88
|
-
resp = requests.post(self.unstructured_api_url, headers=self.headers, json=payload)
|
88
|
+
resp = requests.post(self.unstructured_api_url, headers=self.headers, json=payload)
|
89
|
+
if resp.status_code != 200:
|
90
|
+
raise Exception(f'file partition {os.path.basename(self.file_name)} failed resp={resp.text}')
|
89
91
|
|
92
|
+
resp = resp.json()
|
90
93
|
if 200 != resp.get('status_code'):
|
91
|
-
logger.info(f'
|
94
|
+
logger.info(f'file partition {os.path.basename(self.file_name)} error resp={resp}')
|
92
95
|
partitions = resp['partitions']
|
93
96
|
if not partitions:
|
94
97
|
logger.info(f'partition_error resp={resp}')
|
95
98
|
logger.info(f'unstruct_return code={resp.get("status_code")}')
|
96
99
|
|
97
|
-
partitions = resp['partitions']
|
98
100
|
content, metadata = merge_partitions(partitions)
|
99
101
|
metadata['source'] = self.file_name
|
100
102
|
|
@@ -1,5 +1,5 @@
|
|
1
1
|
import json
|
2
|
-
|
2
|
+
import re
|
3
3
|
from bisheng_langchain.gpts.message_types import LiberalFunctionMessage, LiberalToolMessage
|
4
4
|
from langchain.tools import BaseTool
|
5
5
|
from langchain.tools.render import format_tool_to_openai_tool
|
@@ -39,6 +39,12 @@ def get_openai_functions_agent_executor(tools: list[BaseTool], llm: LanguageMode
|
|
39
39
|
last_message = messages[-1]
|
40
40
|
# If there is no function call, then we finish
|
41
41
|
if 'tool_calls' not in last_message.additional_kwargs:
|
42
|
+
if '|<instruct>|' in system_message:
|
43
|
+
# cohere model
|
44
|
+
pattern = r"Answer:(.+)\nGrounded answer"
|
45
|
+
match = re.search(pattern, last_message.content)
|
46
|
+
if match:
|
47
|
+
last_message.content = match.group(1)
|
42
48
|
return 'end'
|
43
49
|
# Otherwise if there is, we continue
|
44
50
|
else:
|
@@ -126,11 +126,14 @@ class BishengAssistant:
|
|
126
126
|
if __name__ == "__main__":
|
127
127
|
from langchain.globals import set_debug
|
128
128
|
|
129
|
-
set_debug(True)
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
129
|
+
# set_debug(True)
|
130
|
+
chat_history = []
|
131
|
+
query = "请简要分析中科创达软件股份有限公司2019年聘任、解聘会计师事务的情况。"
|
132
|
+
# chat_history = ['你好', '你好,有什么可以帮助你吗?', '福蓉科技股价多少?', '福蓉科技(股票代码:300049)的当前股价为48.67元。']
|
133
|
+
# query = '去年这个时候的股价是多少?'
|
134
|
+
# bisheng_assistant = BishengAssistant("config/base_scene.yaml")
|
135
|
+
# bisheng_assistant = BishengAssistant("config/knowledge_scene.yaml")
|
136
|
+
bisheng_assistant = BishengAssistant("config/rag_scene.yaml")
|
134
137
|
result = bisheng_assistant.run(query, chat_history=chat_history)
|
135
138
|
for r in result:
|
136
139
|
print(f'------------------')
|
@@ -3,7 +3,11 @@ import os
|
|
3
3
|
import re
|
4
4
|
|
5
5
|
import httpx
|
6
|
-
from bisheng_langchain.gpts.prompts import
|
6
|
+
from bisheng_langchain.gpts.prompts import (
|
7
|
+
ASSISTANT_PROMPT_OPT,
|
8
|
+
BREIF_DES_PROMPT,
|
9
|
+
OPENDIALOG_PROMPT,
|
10
|
+
)
|
7
11
|
from langchain_core.language_models.base import LanguageModelLike
|
8
12
|
from langchain_openai.chat_models import ChatOpenAI
|
9
13
|
from loguru import logger
|
@@ -48,16 +52,13 @@ def optimize_assistant_prompt(
|
|
48
52
|
Returns:
|
49
53
|
assistant_prompt(str):
|
50
54
|
"""
|
51
|
-
chain =
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
'assistant_name': assistant_name,
|
59
|
-
'assistant_description': assistant_description,
|
60
|
-
})
|
55
|
+
chain = ASSISTANT_PROMPT_OPT | llm
|
56
|
+
chain_output = chain.invoke(
|
57
|
+
{
|
58
|
+
'assistant_name': assistant_name,
|
59
|
+
'assistant_description': assistant_description,
|
60
|
+
}
|
61
|
+
)
|
61
62
|
response = chain_output.content
|
62
63
|
assistant_prompt = parse_markdown(response)
|
63
64
|
return assistant_prompt
|
@@ -67,17 +68,15 @@ def generate_opening_dialog(
|
|
67
68
|
llm: LanguageModelLike,
|
68
69
|
description: str,
|
69
70
|
) -> str:
|
70
|
-
chain =
|
71
|
-
'description': lambda x: x['description'],
|
72
|
-
}
|
73
|
-
| OPENDIALOG_PROMPT
|
74
|
-
| llm)
|
71
|
+
chain = OPENDIALOG_PROMPT | llm
|
75
72
|
time = 0
|
76
73
|
while time <= 3:
|
77
74
|
try:
|
78
|
-
chain_output = chain.invoke(
|
79
|
-
|
80
|
-
|
75
|
+
chain_output = chain.invoke(
|
76
|
+
{
|
77
|
+
'description': description,
|
78
|
+
}
|
79
|
+
)
|
81
80
|
output = parse_json(chain_output.content)
|
82
81
|
output = json.loads(output)
|
83
82
|
opening_lines = output[0]['开场白']
|
@@ -101,20 +100,22 @@ def generate_breif_description(
|
|
101
100
|
llm: LanguageModelLike,
|
102
101
|
description: str,
|
103
102
|
) -> str:
|
104
|
-
chain =
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
'description': description,
|
111
|
-
})
|
103
|
+
chain = BREIF_DES_PROMPT | llm
|
104
|
+
chain_output = chain.invoke(
|
105
|
+
{
|
106
|
+
'description': description,
|
107
|
+
}
|
108
|
+
)
|
112
109
|
breif_description = chain_output.content
|
113
110
|
breif_description = breif_description.strip()
|
114
111
|
return breif_description
|
115
112
|
|
116
113
|
|
117
114
|
if __name__ == '__main__':
|
115
|
+
from dotenv import load_dotenv
|
116
|
+
|
117
|
+
load_dotenv('/app/.env', override=True)
|
118
|
+
|
118
119
|
httpx_client = httpx.Client(proxies=os.getenv('OPENAI_PROXY'))
|
119
120
|
llm = ChatOpenAI(model='gpt-4-0125-preview', temperature=0.01, http_client=httpx_client)
|
120
121
|
# llm = ChatQWen(model="qwen1.5-72b-chat", temperature=0.01, api_key=os.getenv('QWEN_API_KEY'))
|
@@ -1,6 +1,9 @@
|
|
1
1
|
from bisheng_langchain.gpts.prompts.select_tools_prompt import HUMAN_MSG, SYS_MSG
|
2
|
-
from langchain.prompts import (
|
3
|
-
|
2
|
+
from langchain.prompts import (
|
3
|
+
ChatPromptTemplate,
|
4
|
+
HumanMessagePromptTemplate,
|
5
|
+
SystemMessagePromptTemplate,
|
6
|
+
)
|
4
7
|
from langchain_core.language_models.base import LanguageModelLike
|
5
8
|
from pydantic import BaseModel
|
6
9
|
|
@@ -31,19 +34,15 @@ class ToolSelector:
|
|
31
34
|
HumanMessagePromptTemplate.from_template(self.human_message),
|
32
35
|
]
|
33
36
|
|
34
|
-
chain = (
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
'tool_pool': tool_pool,
|
44
|
-
'task_name': task_name,
|
45
|
-
'task_description': task_description,
|
46
|
-
})
|
37
|
+
chain = ChatPromptTemplate.from_messages(messages) | self.llm
|
38
|
+
|
39
|
+
chain_output = chain.invoke(
|
40
|
+
{
|
41
|
+
'tool_pool': tool_pool,
|
42
|
+
'task_name': task_name,
|
43
|
+
'task_description': task_description,
|
44
|
+
}
|
45
|
+
)
|
47
46
|
|
48
47
|
try:
|
49
48
|
all_tool_name = set([tool.tool_name for tool in self.tools])
|
@@ -1,7 +1,11 @@
|
|
1
|
+
import json
|
2
|
+
import os
|
1
3
|
import warnings
|
2
4
|
from typing import Any, Callable, Dict, List, Optional, Tuple
|
3
5
|
|
4
6
|
import httpx
|
7
|
+
import pandas as pd
|
8
|
+
import pymysql
|
5
9
|
from bisheng_langchain.gpts.tools.api_tools import ALL_API_TOOLS
|
6
10
|
from bisheng_langchain.gpts.tools.bing_search.tool import BingSearchRun
|
7
11
|
from bisheng_langchain.gpts.tools.calculator.tool import calculator
|
@@ -13,6 +17,7 @@ from bisheng_langchain.gpts.tools.dalle_image_generator.tool import (
|
|
13
17
|
DallEImageGenerator,
|
14
18
|
)
|
15
19
|
from bisheng_langchain.gpts.tools.get_current_time.tool import get_current_time
|
20
|
+
from dotenv import load_dotenv
|
16
21
|
from langchain_community.tools.arxiv.tool import ArxivQueryRun
|
17
22
|
from langchain_community.tools.bearly.tool import BearlyInterpreterTool
|
18
23
|
from langchain_community.utilities.arxiv import ArxivAPIWrapper
|
@@ -21,6 +26,7 @@ from langchain_core.callbacks import BaseCallbackManager, Callbacks
|
|
21
26
|
from langchain_core.language_models import BaseLanguageModel
|
22
27
|
from langchain_core.tools import BaseTool, Tool
|
23
28
|
from mypy_extensions import Arg, KwArg
|
29
|
+
from bisheng_langchain.rag import BishengRAGTool
|
24
30
|
|
25
31
|
|
26
32
|
def _get_current_time() -> BaseTool:
|
@@ -54,12 +60,14 @@ def _get_bing_search(**kwargs: Any) -> BaseTool:
|
|
54
60
|
|
55
61
|
def _get_dalle_image_generator(**kwargs: Any) -> Tool:
|
56
62
|
openai_api_key = kwargs.get('openai_api_key')
|
63
|
+
openai_api_base = kwargs.get('openai_api_base')
|
57
64
|
http_async_client = httpx.AsyncClient(proxies=kwargs.get('openai_proxy'))
|
58
65
|
httpc_client = httpx.Client(proxies=kwargs.get('openai_proxy'))
|
59
66
|
return DallEImageGenerator(
|
60
67
|
api_wrapper=DallEAPIWrapper(
|
61
68
|
model='dall-e-3',
|
62
69
|
api_key=openai_api_key,
|
70
|
+
base_url=openai_api_base,
|
63
71
|
http_client=httpc_client,
|
64
72
|
http_async_client=http_async_client,
|
65
73
|
)
|
@@ -78,7 +86,8 @@ def _get_native_code_interpreter(**kwargs: Any) -> Tool:
|
|
78
86
|
_EXTRA_PARAM_TOOLS: Dict[str, Tuple[Callable[[KwArg(Any)], BaseTool], List[Optional[str]], List[Optional[str]]]] = { # type: ignore
|
79
87
|
'dalle_image_generator': (_get_dalle_image_generator, ['openai_api_key', 'openai_proxy'], []),
|
80
88
|
'bing_search': (_get_bing_search, ['bing_subscription_key', 'bing_search_url'], []),
|
81
|
-
'
|
89
|
+
'bisheng_code_interpreter': (_get_native_code_interpreter, ["minio"], ['files']),
|
90
|
+
'bisheng_rag': (BishengRAGTool.get_rag_tool, ['name', 'description'], ['vector_store', 'keyword_store', 'llm', 'collection_name', 'max_content', 'sort_by_source_and_index']),
|
82
91
|
}
|
83
92
|
|
84
93
|
_API_TOOLS: Dict[str, Tuple[Callable[[KwArg(Any)], BaseTool], List[str]]] = {**ALL_API_TOOLS} # type: ignore
|
@@ -159,3 +168,46 @@ def load_tools(
|
|
159
168
|
def get_all_tool_names() -> List[str]:
|
160
169
|
"""Get a list of all possible tool names."""
|
161
170
|
return list(_ALL_TOOLS.keys())
|
171
|
+
|
172
|
+
|
173
|
+
def get_tool_table():
|
174
|
+
|
175
|
+
load_dotenv('.sql_env', override=True)
|
176
|
+
db = pymysql.connect(
|
177
|
+
host=os.getenv('MYSQL_HOST'),
|
178
|
+
user=os.getenv('MYSQL_USER'),
|
179
|
+
password=os.getenv('MYSQL_PASSWORD'),
|
180
|
+
database=os.getenv('MYSQL_DATABASE'),
|
181
|
+
port=int(os.getenv('MYSQL_PORT')),
|
182
|
+
)
|
183
|
+
cursor = db.cursor()
|
184
|
+
cursor.execute("SELECT name, t.desc, tool_key, extra FROM t_gpts_tools as t;")
|
185
|
+
results = cursor.fetchall()
|
186
|
+
db.close()
|
187
|
+
|
188
|
+
df = pd.DataFrame(
|
189
|
+
columns=[
|
190
|
+
'前端工具名',
|
191
|
+
'前端工具描述',
|
192
|
+
'tool_key',
|
193
|
+
'tool参数配置',
|
194
|
+
'function_name',
|
195
|
+
'function_description',
|
196
|
+
'function_args',
|
197
|
+
]
|
198
|
+
)
|
199
|
+
for i, result in enumerate(results):
|
200
|
+
name, desc, tool_key, extra = result
|
201
|
+
if not extra:
|
202
|
+
extra = '{}'
|
203
|
+
tool_func = load_tools({tool_key: json.loads(extra)})[0]
|
204
|
+
|
205
|
+
df.loc[i, '前端工具名'] = name
|
206
|
+
df.loc[i, '前端工具描述'] = desc
|
207
|
+
df.loc[i, 'tool_key'] = tool_key
|
208
|
+
df.loc[i, 'tool参数配置'] = extra
|
209
|
+
df.loc[i, 'function_name'] = tool_func.name
|
210
|
+
df.loc[i, 'function_description'] = tool_func.description
|
211
|
+
df.loc[i, 'function_args'] = f"{tool_func.args_schema.schema()['properties']}"
|
212
|
+
|
213
|
+
return df
|
@@ -1,12 +1,14 @@
|
|
1
1
|
from bisheng_langchain.gpts.prompts.assistant_prompt_opt import ASSISTANT_PROMPT_OPT
|
2
|
-
from bisheng_langchain.gpts.prompts.
|
2
|
+
from bisheng_langchain.gpts.prompts.assistant_prompt_base import ASSISTANT_PROMPT_DEFAULT
|
3
|
+
from bisheng_langchain.gpts.prompts.assistant_prompt_cohere import ASSISTANT_PROMPT_COHERE
|
3
4
|
from bisheng_langchain.gpts.prompts.breif_description_prompt import BREIF_DES_PROMPT
|
4
5
|
from bisheng_langchain.gpts.prompts.opening_dialog_prompt import OPENDIALOG_PROMPT
|
5
6
|
from bisheng_langchain.gpts.prompts.select_tools_prompt import HUMAN_MSG, SYS_MSG
|
6
7
|
|
7
8
|
|
8
9
|
__all__ = [
|
9
|
-
"
|
10
|
+
"ASSISTANT_PROMPT_DEFAULT",
|
11
|
+
"ASSISTANT_PROMPT_COHERE",
|
10
12
|
"ASSISTANT_PROMPT_OPT",
|
11
13
|
"OPENDIALOG_PROMPT",
|
12
14
|
"BREIF_DES_PROMPT",
|
@@ -0,0 +1 @@
|
|
1
|
+
ASSISTANT_PROMPT_DEFAULT = "You are a helpful assistant."
|
@@ -0,0 +1,19 @@
|
|
1
|
+
preamble="""You are a helpful assistant.
|
2
|
+
"""
|
3
|
+
|
4
|
+
ASSISTANT_PROMPT_COHERE="""{preamble}|<instruct>|Carefully perform the following instructions, in order, starting each with a new line.
|
5
|
+
Firstly, You may need to use complex and advanced reasoning to complete your task and answer the question. Think about how you can use the provided tools to answer the question and come up with a high level plan you will execute.
|
6
|
+
Write 'Plan:' followed by an initial high level plan of how you will solve the problem including the tools and steps required.
|
7
|
+
Secondly, Carry out your plan by repeatedly using actions, reasoning over the results, and re-evaluating your plan. Perform Action, Observation, Reflection steps with the following format. Write 'Action:' followed by a json formatted action containing the "tool_name" and "parameters"
|
8
|
+
Next you will analyze the 'Observation:', this is the result of the action.
|
9
|
+
After that you should always think about what to do next. Write 'Reflection:' followed by what you've figured out so far, any changes you need to make to your plan, and what you will do next including if you know the answer to the question.
|
10
|
+
... (this Action/Observation/Reflection can repeat N times)
|
11
|
+
Thirdly, Decide which of the retrieved documents are relevant to the user's last input by writing 'Relevant Documents:' followed by comma-separated list of document numbers. If none are relevant, you should instead write 'None'.
|
12
|
+
Fourthly, Decide which of the retrieved documents contain facts that should be cited in a good answer to the user's last input by writing 'Cited Documents:' followed a comma-separated list of document numbers. If you dont want to cite any of them, you should instead write 'None'.
|
13
|
+
Fifthly, Write 'Answer:' followed by a response to the user's last input. Use the retrieved documents to help you. Do not insert any citations or grounding markup.
|
14
|
+
Finally, Write 'Grounded answer:' followed by a response to the user's last input in high quality natural english. Use the symbols <co: doc> and </co: doc> to indicate when a fact comes from a document in the search result, e.g <co: 4>my fact</co: 4> for a fact from document 4.
|
15
|
+
|
16
|
+
Additional instructions to note:
|
17
|
+
- If the user's question is in Chinese, please answer it in Chinese.
|
18
|
+
- 当问题中有涉及到时间信息时,比如最近6个月、昨天、去年等,你需要用时间工具查询时间信息。
|
19
|
+
""".format(preamble=preamble)
|
@@ -6,7 +6,7 @@ from langchain_core.prompts.chat import (
|
|
6
6
|
)
|
7
7
|
|
8
8
|
system_template = """
|
9
|
-
|
9
|
+
你是一个生成开场白和预置问题的助手。接下来,你会收到一段关于任务助手的描述,你需要带入描述中的角色,以描述中的角色身份生成一段开场白,同时你还需要站在用户的角度生成几个用户可能的提问。输出格式如下:
|
10
10
|
[
|
11
11
|
{{
|
12
12
|
"开场白": "开场白内容",
|
@@ -40,7 +40,7 @@ _MACRO_TOOLS: Dict[str, Tuple[Callable[[KwArg(Any)], BaseTool], List[str]]] = {
|
|
40
40
|
|
41
41
|
_tmp_flow = ['knowledge_retrieve']
|
42
42
|
_TMP_TOOLS: Dict[str, Tuple[Callable[[KwArg(Any)], BaseTool], List[str]]] = {
|
43
|
-
f'flow_{name}': (FlowTools.get_api_tool, ['collection_id'])
|
43
|
+
f'flow_{name}': (FlowTools.get_api_tool, ['collection_id', 'description'])
|
44
44
|
for name in _tmp_flow
|
45
45
|
}
|
46
46
|
ALL_API_TOOLS = {}
|
@@ -64,7 +64,7 @@ class APIToolBase(BaseModel):
|
|
64
64
|
resp = self.client.get(url)
|
65
65
|
if resp.status_code != 200:
|
66
66
|
logger.info('api_call_fail res={}', resp.text)
|
67
|
-
return resp.text
|
67
|
+
return resp.text[:10000]
|
68
68
|
|
69
69
|
async def arun(self, query: str, **kwargs) -> str:
|
70
70
|
"""Run query through api and parse result."""
|
@@ -79,8 +79,8 @@ class APIToolBase(BaseModel):
|
|
79
79
|
url = self.url
|
80
80
|
logger.info('api_call url={}', url)
|
81
81
|
resp = await self.async_client.aget(url)
|
82
|
-
logger.info(resp)
|
83
|
-
return resp
|
82
|
+
logger.info(resp[:10000])
|
83
|
+
return resp[:10000]
|
84
84
|
|
85
85
|
@classmethod
|
86
86
|
def get_api_tool(cls, name, **kwargs: Any) -> BaseTool:
|
@@ -1,7 +1,9 @@
|
|
1
1
|
from loguru import logger
|
2
|
-
from
|
3
|
-
|
2
|
+
from langchain_core.pydantic_v1 import BaseModel, Field
|
3
|
+
from typing import Any
|
4
4
|
from .base import APIToolBase
|
5
|
+
from .base import MultArgsSchemaTool
|
6
|
+
from langchain_core.tools import BaseTool
|
5
7
|
|
6
8
|
|
7
9
|
class FlowTools(APIToolBase):
|
@@ -34,10 +36,8 @@ class FlowTools(APIToolBase):
|
|
34
36
|
return resp
|
35
37
|
|
36
38
|
@classmethod
|
37
|
-
def knowledge_retrieve(cls, collection_id: int = None) -> str:
|
38
|
-
|
39
|
-
知识库检索工具,从内部知识库进行检索总结
|
40
|
-
"""
|
39
|
+
def knowledge_retrieve(cls, collection_id: int = None) -> str:
|
40
|
+
|
41
41
|
flow_id = 'c7985115-a9d2-446a-9c55-40b5728ffb52'
|
42
42
|
url = 'http://192.168.106.120:3002/api/v1/process/{}'.format(flow_id)
|
43
43
|
input_key = 'inputs'
|
@@ -53,7 +53,19 @@ class FlowTools(APIToolBase):
|
|
53
53
|
}
|
54
54
|
|
55
55
|
class InputArgs(BaseModel):
|
56
|
-
"""args_schema"""
|
57
56
|
query: str = Field(description='questions to ask')
|
58
57
|
|
59
58
|
return cls(url=url, params=params, input_key=input_key, args_schema=InputArgs)
|
59
|
+
|
60
|
+
@classmethod
|
61
|
+
def get_api_tool(cls, name, **kwargs: Any) -> BaseTool:
|
62
|
+
attr_name = name.split('_', 1)[-1]
|
63
|
+
class_method = getattr(cls, attr_name)
|
64
|
+
function_description = kwargs.get('description','')
|
65
|
+
kwargs.pop('description')
|
66
|
+
return MultArgsSchemaTool(name=name + '_' +str(kwargs.get('collection_id')),
|
67
|
+
description=function_description,
|
68
|
+
func=class_method(**kwargs).run,
|
69
|
+
coroutine=class_method(**kwargs).arun,
|
70
|
+
args_schema=class_method(**kwargs).args_schema)
|
71
|
+
|
@@ -81,7 +81,7 @@ class MacroData(BaseModel):
|
|
81
81
|
JS_CHINA_GDP_YEARLY_URL = 'https://cdn.jin10.com/dc/reports/dc_chinese_gdp_yoy_all.js?v={}&_={}'
|
82
82
|
t = time.time()
|
83
83
|
r = requests.get(JS_CHINA_GDP_YEARLY_URL.format(str(int(round(t * 1000))), str(int(round(t * 1000)) + 90)))
|
84
|
-
json_data = json.loads(r.text[r.text.find('{')
|
84
|
+
json_data = json.loads(r.text[r.text.find('{'): r.text.rfind('}') + 1])
|
85
85
|
date_list = [item['date'] for item in json_data['list']]
|
86
86
|
value_list = [item['datas']['中国GDP年率报告'] for item in json_data['list']]
|
87
87
|
value_df = pd.DataFrame(value_list)
|
@@ -249,6 +249,60 @@ class MacroData(BaseModel):
|
|
249
249
|
temp_df = temp_df[(temp_df['月份'] >= start) & (temp_df['月份'] <= end)]
|
250
250
|
return temp_df.to_markdown()
|
251
251
|
|
252
|
+
@classmethod
|
253
|
+
def china_pmi(cls, start_date: str = '', end_date: str = '') -> str:
|
254
|
+
"""中国 PMI (采购经理人指数)月度统计数据。
|
255
|
+
返回数据包括:月份制造业 PMI,制造业 PMI 同比增长,非制造业 PMI,非制造业 PMI 同比增长。
|
256
|
+
"""
|
257
|
+
url = "https://datacenter-web.eastmoney.com/api/data/v1/get"
|
258
|
+
params = {
|
259
|
+
"columns": "REPORT_DATE,TIME,MAKE_INDEX,MAKE_SAME,NMAKE_INDEX,NMAKE_SAME",
|
260
|
+
"pageNumber": "1",
|
261
|
+
"pageSize": "2000",
|
262
|
+
"sortColumns": "REPORT_DATE",
|
263
|
+
"sortTypes": "-1",
|
264
|
+
"source": "WEB",
|
265
|
+
"client": "WEB",
|
266
|
+
"reportName": "RPT_ECONOMY_PMI",
|
267
|
+
"p": "1",
|
268
|
+
"pageNo": "1",
|
269
|
+
"pageNum": "1",
|
270
|
+
"_": "1669047266881",
|
271
|
+
}
|
272
|
+
r = requests.get(url, params=params)
|
273
|
+
data_json = r.json()
|
274
|
+
temp_df = pd.DataFrame(data_json["result"]["data"])
|
275
|
+
temp_df.columns = [
|
276
|
+
"-",
|
277
|
+
"月份",
|
278
|
+
"制造业-指数",
|
279
|
+
"制造业-同比增长",
|
280
|
+
"非制造业-指数",
|
281
|
+
"非制造业-同比增长",
|
282
|
+
]
|
283
|
+
temp_df = temp_df[
|
284
|
+
[
|
285
|
+
"月份",
|
286
|
+
"制造业-指数",
|
287
|
+
"制造业-同比增长",
|
288
|
+
"非制造业-指数",
|
289
|
+
"非制造业-同比增长",
|
290
|
+
]
|
291
|
+
]
|
292
|
+
temp_df["制造业-指数"] = pd.to_numeric(temp_df["制造业-指数"], errors="coerce")
|
293
|
+
temp_df["制造业-同比增长"] = pd.to_numeric(
|
294
|
+
temp_df["制造业-同比增长"], errors="coerce"
|
295
|
+
)
|
296
|
+
temp_df["非制造业-指数"] = pd.to_numeric(temp_df["非制造业-指数"], errors="coerce")
|
297
|
+
temp_df["非制造业-同比增长"] = pd.to_numeric(
|
298
|
+
temp_df["非制造业-同比增长"], errors="coerce"
|
299
|
+
)
|
300
|
+
if start_date and end_date:
|
301
|
+
start = start_date.split('-')[0] + '年' + start_date.split('-')[1] + '月份'
|
302
|
+
end = end_date.split('-')[0] + '年' + end_date.split('-')[1] + '月份'
|
303
|
+
temp_df = temp_df[(temp_df['月份'] >= start) & (temp_df['月份'] <= end)]
|
304
|
+
return temp_df.to_markdown()
|
305
|
+
|
252
306
|
@classmethod
|
253
307
|
def china_money_supply(cls, start_date: str = '', end_date: str = '') -> pd.DataFrame:
|
254
308
|
"""中国货币供应量(M2,M1,M0)月度统计数据。\
|
@@ -376,6 +430,121 @@ M0数量(单位:亿元),M0 同比(单位:%),M0 环比(单位
|
|
376
430
|
|
377
431
|
return temp_df.to_markdown()
|
378
432
|
|
433
|
+
@classmethod
|
434
|
+
def bond_zh_us_rate(cls, start_date: str = "", end_date: str = "") -> str:
|
435
|
+
"""
|
436
|
+
本接口返回指定时间段[start_date,end_date]内交易日的中美两国的 2 年、5 年、10 年、30 年、10 年-2 年收益率数据。
|
437
|
+
start_date表示起始日期,end_date表示结束日期,日期格式例如 2024-04-07
|
438
|
+
"""
|
439
|
+
url = "https://datacenter.eastmoney.com/api/data/get"
|
440
|
+
params = {
|
441
|
+
"type": "RPTA_WEB_TREASURYYIELD",
|
442
|
+
"sty": "ALL",
|
443
|
+
"st": "SOLAR_DATE",
|
444
|
+
"sr": "-1",
|
445
|
+
"token": "894050c76af8597a853f5b408b759f5d",
|
446
|
+
"p": "1",
|
447
|
+
"ps": "500",
|
448
|
+
"pageNo": "1",
|
449
|
+
"pageNum": "1",
|
450
|
+
"_": "1615791534490",
|
451
|
+
}
|
452
|
+
r = requests.get(url, params=params)
|
453
|
+
data_json = r.json()
|
454
|
+
total_page = data_json["result"]["pages"]
|
455
|
+
big_df = pd.DataFrame()
|
456
|
+
for page in range(1, total_page + 1):
|
457
|
+
params = {
|
458
|
+
"type": "RPTA_WEB_TREASURYYIELD",
|
459
|
+
"sty": "ALL",
|
460
|
+
"st": "SOLAR_DATE",
|
461
|
+
"sr": "-1",
|
462
|
+
"token": "894050c76af8597a853f5b408b759f5d",
|
463
|
+
"p": page,
|
464
|
+
"ps": "500",
|
465
|
+
"pageNo": page,
|
466
|
+
"pageNum": page,
|
467
|
+
"_": "1615791534490",
|
468
|
+
}
|
469
|
+
r = requests.get(url, params=params)
|
470
|
+
data_json = r.json()
|
471
|
+
# 时间过滤
|
472
|
+
if start_date and end_date:
|
473
|
+
temp_data = []
|
474
|
+
for item in data_json["result"]["data"]:
|
475
|
+
if start_date <= item["SOLAR_DATE"].split(" ")[0] <= end_date:
|
476
|
+
temp_data.append(item)
|
477
|
+
elif start_date > item["SOLAR_DATE"].split(" ")[0]:
|
478
|
+
break
|
479
|
+
else:
|
480
|
+
continue
|
481
|
+
else:
|
482
|
+
temp_data = data_json["result"]["data"]
|
483
|
+
temp_df = pd.DataFrame(temp_data)
|
484
|
+
for col in temp_df.columns:
|
485
|
+
if temp_df[col].isnull().all(): # 检查列是否包含 None 或 NaN
|
486
|
+
temp_df[col] = pd.to_numeric(temp_df[col], errors='coerce')
|
487
|
+
if big_df.empty:
|
488
|
+
big_df = temp_df
|
489
|
+
else:
|
490
|
+
big_df = pd.concat(objs=[big_df, temp_df], ignore_index=True)
|
491
|
+
|
492
|
+
big_df.rename(
|
493
|
+
columns={
|
494
|
+
"SOLAR_DATE": "日期",
|
495
|
+
"EMM00166462": "中国国债收益率5年",
|
496
|
+
"EMM00166466": "中国国债收益率10年",
|
497
|
+
"EMM00166469": "中国国债收益率30年",
|
498
|
+
"EMM00588704": "中国国债收益率2年",
|
499
|
+
"EMM01276014": "中国国债收益率10年-2年",
|
500
|
+
"EMG00001306": "美国国债收益率2年",
|
501
|
+
"EMG00001308": "美国国债收益率5年",
|
502
|
+
"EMG00001310": "美国国债收益率10年",
|
503
|
+
"EMG00001312": "美国国债收益率30年",
|
504
|
+
"EMG01339436": "美国国债收益率10年-2年",
|
505
|
+
"EMM00000024": "中国GDP年增率",
|
506
|
+
"EMG00159635": "美国GDP年增率",
|
507
|
+
},
|
508
|
+
inplace=True,
|
509
|
+
)
|
510
|
+
big_df = big_df[
|
511
|
+
[
|
512
|
+
"日期",
|
513
|
+
"中国国债收益率2年",
|
514
|
+
"中国国债收益率5年",
|
515
|
+
"中国国债收益率10年",
|
516
|
+
"中国国债收益率30年",
|
517
|
+
"中国国债收益率10年-2年",
|
518
|
+
"中国GDP年增率",
|
519
|
+
"美国国债收益率2年",
|
520
|
+
"美国国债收益率5年",
|
521
|
+
"美国国债收益率10年",
|
522
|
+
"美国国债收益率30年",
|
523
|
+
"美国国债收益率10年-2年",
|
524
|
+
"美国GDP年增率",
|
525
|
+
]
|
526
|
+
]
|
527
|
+
big_df = big_df.drop(["中国GDP年增率", "美国GDP年增率"], axis=1)
|
528
|
+
big_df["日期"] = pd.to_datetime(big_df["日期"], errors="coerce")
|
529
|
+
big_df["中国国债收益率2年"] = pd.to_numeric(big_df["中国国债收益率2年"], errors="coerce")
|
530
|
+
big_df["中国国债收益率5年"] = pd.to_numeric(big_df["中国国债收益率5年"], errors="coerce")
|
531
|
+
big_df["中国国债收益率10年"] = pd.to_numeric(big_df["中国国债收益率10年"], errors="coerce")
|
532
|
+
big_df["中国国债收益率30年"] = pd.to_numeric(big_df["中国国债收益率30年"], errors="coerce")
|
533
|
+
big_df["中国国债收益率10年-2年"] = pd.to_numeric(big_df["中国国债收益率10年-2年"], errors="coerce")
|
534
|
+
# big_df["中国GDP年增率"] = pd.to_numeric(big_df["中国GDP年增率"], errors="coerce")
|
535
|
+
big_df["美国国债收益率2年"] = pd.to_numeric(big_df["美国国债收益率2年"], errors="coerce")
|
536
|
+
big_df["美国国债收益率5年"] = pd.to_numeric(big_df["美国国债收益率5年"], errors="coerce")
|
537
|
+
big_df["美国国债收益率10年"] = pd.to_numeric(big_df["美国国债收益率10年"], errors="coerce")
|
538
|
+
big_df["美国国债收益率30年"] = pd.to_numeric(big_df["美国国债收益率30年"], errors="coerce")
|
539
|
+
big_df["美国国债收益率10年-2年"] = pd.to_numeric(big_df["美国国债收益率10年-2年"], errors="coerce")
|
540
|
+
# big_df["美国GDP年增率"] = pd.to_numeric(big_df["美国GDP年增率"], errors="coerce")
|
541
|
+
big_df.sort_values("日期", inplace=True)
|
542
|
+
big_df.set_index(["日期"], inplace=True)
|
543
|
+
big_df = big_df[pd.to_datetime(start_date):]
|
544
|
+
big_df.reset_index(inplace=True)
|
545
|
+
big_df["日期"] = pd.to_datetime(big_df["日期"]).dt.date
|
546
|
+
return big_df.to_markdown()
|
547
|
+
|
379
548
|
@classmethod
|
380
549
|
def get_api_tool(cls, name: str, **kwargs: Any) -> BaseTool:
|
381
550
|
attr_name = name.split('_', 1)[-1]
|
@@ -385,13 +554,15 @@ M0数量(单位:亿元),M0 同比(单位:%),M0 环比(单位
|
|
385
554
|
|
386
555
|
|
387
556
|
if __name__ == '__main__':
|
388
|
-
|
389
|
-
|
557
|
+
tmp_start_date = '2024-01-01'
|
558
|
+
tmp_end_date = '2024-01-03'
|
390
559
|
# start_date = ''
|
391
560
|
# end_date = ''
|
392
561
|
# print(MacroData.china_ppi(start_date=start_date, end_date=end_date))
|
393
562
|
# print(MacroData.china_shrzgm(start_date=start_date, end_date=end_date))
|
394
563
|
# print(MacroData.china_consumer_goods_retail(start_date=start_date, end_date=end_date))
|
395
564
|
# print(MacroData.china_cpi(start_date=start_date, end_date=end_date))
|
565
|
+
# print(MacroData.china_pmi(start_date=start_date, end_date=end_date))
|
396
566
|
# print(MacroData.china_money_supply(start_date=start_date, end_date=end_date))
|
397
|
-
print(MacroData.china_gdp_yearly(start_date=start_date, end_date=end_date))
|
567
|
+
# print(MacroData.china_gdp_yearly(start_date=start_date, end_date=end_date))
|
568
|
+
print(MacroData.bond_zh_us_rate(start_date=tmp_start_date, end_date=tmp_end_date))
|