bisheng-langchain 0.3.0rc0__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58)
  1. bisheng_langchain/chat_models/host_llm.py +1 -1
  2. bisheng_langchain/document_loaders/elem_unstrcutured_loader.py +5 -3
  3. bisheng_langchain/gpts/agent_types/llm_functions_agent.py +7 -1
  4. bisheng_langchain/gpts/assistant.py +8 -5
  5. bisheng_langchain/gpts/auto_optimization.py +28 -27
  6. bisheng_langchain/gpts/auto_tool_selected.py +14 -15
  7. bisheng_langchain/gpts/load_tools.py +53 -1
  8. bisheng_langchain/gpts/prompts/__init__.py +4 -2
  9. bisheng_langchain/gpts/prompts/assistant_prompt_base.py +1 -0
  10. bisheng_langchain/gpts/prompts/assistant_prompt_cohere.py +19 -0
  11. bisheng_langchain/gpts/prompts/opening_dialog_prompt.py +1 -1
  12. bisheng_langchain/gpts/tools/api_tools/__init__.py +1 -1
  13. bisheng_langchain/gpts/tools/api_tools/base.py +3 -3
  14. bisheng_langchain/gpts/tools/api_tools/flow.py +19 -7
  15. bisheng_langchain/gpts/tools/api_tools/macro_data.py +175 -4
  16. bisheng_langchain/gpts/tools/api_tools/openapi.py +101 -0
  17. bisheng_langchain/gpts/tools/api_tools/sina.py +2 -2
  18. bisheng_langchain/gpts/tools/code_interpreter/tool.py +118 -39
  19. bisheng_langchain/rag/__init__.py +5 -0
  20. bisheng_langchain/rag/bisheng_rag_pipeline.py +320 -0
  21. bisheng_langchain/rag/bisheng_rag_pipeline_v2.py +359 -0
  22. bisheng_langchain/rag/bisheng_rag_pipeline_v2_cohere_raw_prompting.py +376 -0
  23. bisheng_langchain/rag/bisheng_rag_tool.py +288 -0
  24. bisheng_langchain/rag/config/baseline.yaml +86 -0
  25. bisheng_langchain/rag/config/baseline_caibao.yaml +82 -0
  26. bisheng_langchain/rag/config/baseline_caibao_knowledge_v2.yaml +110 -0
  27. bisheng_langchain/rag/config/baseline_caibao_v2.yaml +112 -0
  28. bisheng_langchain/rag/config/baseline_demo_v2.yaml +92 -0
  29. bisheng_langchain/rag/config/baseline_s2b_mix.yaml +88 -0
  30. bisheng_langchain/rag/config/baseline_v2.yaml +90 -0
  31. bisheng_langchain/rag/extract_info.py +38 -0
  32. bisheng_langchain/rag/init_retrievers/__init__.py +4 -0
  33. bisheng_langchain/rag/init_retrievers/baseline_vector_retriever.py +61 -0
  34. bisheng_langchain/rag/init_retrievers/keyword_retriever.py +65 -0
  35. bisheng_langchain/rag/init_retrievers/mix_retriever.py +103 -0
  36. bisheng_langchain/rag/init_retrievers/smaller_chunks_retriever.py +92 -0
  37. bisheng_langchain/rag/prompts/__init__.py +9 -0
  38. bisheng_langchain/rag/prompts/extract_key_prompt.py +34 -0
  39. bisheng_langchain/rag/prompts/prompt.py +47 -0
  40. bisheng_langchain/rag/prompts/prompt_cohere.py +111 -0
  41. bisheng_langchain/rag/qa_corpus/__init__.py +0 -0
  42. bisheng_langchain/rag/qa_corpus/qa_generator.py +143 -0
  43. bisheng_langchain/rag/rerank/__init__.py +5 -0
  44. bisheng_langchain/rag/rerank/rerank.py +48 -0
  45. bisheng_langchain/rag/rerank/rerank_benchmark.py +139 -0
  46. bisheng_langchain/rag/run_qa_gen_web.py +47 -0
  47. bisheng_langchain/rag/run_rag_evaluate_web.py +55 -0
  48. bisheng_langchain/rag/scoring/__init__.py +0 -0
  49. bisheng_langchain/rag/scoring/llama_index_score.py +91 -0
  50. bisheng_langchain/rag/scoring/ragas_score.py +183 -0
  51. bisheng_langchain/rag/utils.py +181 -0
  52. bisheng_langchain/retrievers/ensemble.py +2 -1
  53. bisheng_langchain/vectorstores/elastic_keywords_search.py +2 -1
  54. {bisheng_langchain-0.3.0rc0.dist-info → bisheng_langchain-0.3.1.dist-info}/METADATA +1 -1
  55. {bisheng_langchain-0.3.0rc0.dist-info → bisheng_langchain-0.3.1.dist-info}/RECORD +57 -22
  56. bisheng_langchain/gpts/prompts/base_prompt.py +0 -1
  57. {bisheng_langchain-0.3.0rc0.dist-info → bisheng_langchain-0.3.1.dist-info}/WHEEL +0 -0
  58. {bisheng_langchain-0.3.0rc0.dist-info → bisheng_langchain-0.3.1.dist-info}/top_level.txt +0 -0
@@ -163,7 +163,7 @@ class BaseHostChatLLM(BaseChatModel):
163
163
  values[
164
164
  'host_base_url'] = f"{values['host_base_url']}/{values['model_name']}/infer"
165
165
  except Exception:
166
- raise Exception(f'Update Decoupled status faild for model {model}')
166
+ raise Exception(f'Update Decoupled status failed for model {model}')
167
167
 
168
168
  try:
169
169
  if values['headers']:
@@ -85,16 +85,18 @@ class ElemUnstructuredLoader(BasePDFLoader):
85
85
  mode='partition',
86
86
  parameters=parameters)
87
87
 
88
- resp = requests.post(self.unstructured_api_url, headers=self.headers, json=payload).json()
88
+ resp = requests.post(self.unstructured_api_url, headers=self.headers, json=payload)
89
+ if resp.status_code != 200:
90
+ raise Exception(f'file partition {os.path.basename(self.file_name)} failed resp={resp.text}')
89
91
 
92
+ resp = resp.json()
90
93
  if 200 != resp.get('status_code'):
91
- logger.info(f'not return resp={resp}')
94
+ logger.info(f'file partition {os.path.basename(self.file_name)} error resp={resp}')
92
95
  partitions = resp['partitions']
93
96
  if not partitions:
94
97
  logger.info(f'partition_error resp={resp}')
95
98
  logger.info(f'unstruct_return code={resp.get("status_code")}')
96
99
 
97
- partitions = resp['partitions']
98
100
  content, metadata = merge_partitions(partitions)
99
101
  metadata['source'] = self.file_name
100
102
 
@@ -1,5 +1,5 @@
1
1
  import json
2
-
2
+ import re
3
3
  from bisheng_langchain.gpts.message_types import LiberalFunctionMessage, LiberalToolMessage
4
4
  from langchain.tools import BaseTool
5
5
  from langchain.tools.render import format_tool_to_openai_tool
@@ -39,6 +39,12 @@ def get_openai_functions_agent_executor(tools: list[BaseTool], llm: LanguageMode
39
39
  last_message = messages[-1]
40
40
  # If there is no function call, then we finish
41
41
  if 'tool_calls' not in last_message.additional_kwargs:
42
+ if '|<instruct>|' in system_message:
43
+ # cohere model
44
+ pattern = r"Answer:(.+)\nGrounded answer"
45
+ match = re.search(pattern, last_message.content)
46
+ if match:
47
+ last_message.content = match.group(1)
42
48
  return 'end'
43
49
  # Otherwise if there is, we continue
44
50
  else:
@@ -126,11 +126,14 @@ class BishengAssistant:
126
126
  if __name__ == "__main__":
127
127
  from langchain.globals import set_debug
128
128
 
129
- set_debug(True)
130
- # chat_history = []
131
- chat_history = ['你好', '你好,有什么可以帮助你吗?', '福蓉科技股价多少?', '福蓉科技(股票代码:300049)的当前股价为48.67元。']
132
- query = "去年这个时候的股价是多少?"
133
- bisheng_assistant = BishengAssistant("config/base_scene.yaml")
129
+ # set_debug(True)
130
+ chat_history = []
131
+ query = "请简要分析中科创达软件股份有限公司2019年聘任、解聘会计师事务的情况。"
132
+ # chat_history = ['你好', '你好,有什么可以帮助你吗?', '福蓉科技股价多少?', '福蓉科技(股票代码:300049)的当前股价为48.67元。']
133
+ # query = '去年这个时候的股价是多少?'
134
+ # bisheng_assistant = BishengAssistant("config/base_scene.yaml")
135
+ # bisheng_assistant = BishengAssistant("config/knowledge_scene.yaml")
136
+ bisheng_assistant = BishengAssistant("config/rag_scene.yaml")
134
137
  result = bisheng_assistant.run(query, chat_history=chat_history)
135
138
  for r in result:
136
139
  print(f'------------------')
@@ -3,7 +3,11 @@ import os
3
3
  import re
4
4
 
5
5
  import httpx
6
- from bisheng_langchain.gpts.prompts import ASSISTANT_PROMPT_OPT, BREIF_DES_PROMPT, OPENDIALOG_PROMPT
6
+ from bisheng_langchain.gpts.prompts import (
7
+ ASSISTANT_PROMPT_OPT,
8
+ BREIF_DES_PROMPT,
9
+ OPENDIALOG_PROMPT,
10
+ )
7
11
  from langchain_core.language_models.base import LanguageModelLike
8
12
  from langchain_openai.chat_models import ChatOpenAI
9
13
  from loguru import logger
@@ -48,16 +52,13 @@ def optimize_assistant_prompt(
48
52
  Returns:
49
53
  assistant_prompt(str):
50
54
  """
51
- chain = ({
52
- 'assistant_name': lambda x: x['assistant_name'],
53
- 'assistant_description': lambda x: x['assistant_description'],
54
- }
55
- | ASSISTANT_PROMPT_OPT
56
- | llm)
57
- chain_output = chain.invoke({
58
- 'assistant_name': assistant_name,
59
- 'assistant_description': assistant_description,
60
- })
55
+ chain = ASSISTANT_PROMPT_OPT | llm
56
+ chain_output = chain.invoke(
57
+ {
58
+ 'assistant_name': assistant_name,
59
+ 'assistant_description': assistant_description,
60
+ }
61
+ )
61
62
  response = chain_output.content
62
63
  assistant_prompt = parse_markdown(response)
63
64
  return assistant_prompt
@@ -67,17 +68,15 @@ def generate_opening_dialog(
67
68
  llm: LanguageModelLike,
68
69
  description: str,
69
70
  ) -> str:
70
- chain = ({
71
- 'description': lambda x: x['description'],
72
- }
73
- | OPENDIALOG_PROMPT
74
- | llm)
71
+ chain = OPENDIALOG_PROMPT | llm
75
72
  time = 0
76
73
  while time <= 3:
77
74
  try:
78
- chain_output = chain.invoke({
79
- 'description': description,
80
- })
75
+ chain_output = chain.invoke(
76
+ {
77
+ 'description': description,
78
+ }
79
+ )
81
80
  output = parse_json(chain_output.content)
82
81
  output = json.loads(output)
83
82
  opening_lines = output[0]['开场白']
@@ -101,20 +100,22 @@ def generate_breif_description(
101
100
  llm: LanguageModelLike,
102
101
  description: str,
103
102
  ) -> str:
104
- chain = ({
105
- 'description': lambda x: x['description'],
106
- }
107
- | BREIF_DES_PROMPT
108
- | llm)
109
- chain_output = chain.invoke({
110
- 'description': description,
111
- })
103
+ chain = BREIF_DES_PROMPT | llm
104
+ chain_output = chain.invoke(
105
+ {
106
+ 'description': description,
107
+ }
108
+ )
112
109
  breif_description = chain_output.content
113
110
  breif_description = breif_description.strip()
114
111
  return breif_description
115
112
 
116
113
 
117
114
  if __name__ == '__main__':
115
+ from dotenv import load_dotenv
116
+
117
+ load_dotenv('/app/.env', override=True)
118
+
118
119
  httpx_client = httpx.Client(proxies=os.getenv('OPENAI_PROXY'))
119
120
  llm = ChatOpenAI(model='gpt-4-0125-preview', temperature=0.01, http_client=httpx_client)
120
121
  # llm = ChatQWen(model="qwen1.5-72b-chat", temperature=0.01, api_key=os.getenv('QWEN_API_KEY'))
@@ -1,6 +1,9 @@
1
1
  from bisheng_langchain.gpts.prompts.select_tools_prompt import HUMAN_MSG, SYS_MSG
2
- from langchain.prompts import (ChatPromptTemplate, HumanMessagePromptTemplate,
3
- SystemMessagePromptTemplate)
2
+ from langchain.prompts import (
3
+ ChatPromptTemplate,
4
+ HumanMessagePromptTemplate,
5
+ SystemMessagePromptTemplate,
6
+ )
4
7
  from langchain_core.language_models.base import LanguageModelLike
5
8
  from pydantic import BaseModel
6
9
 
@@ -31,19 +34,15 @@ class ToolSelector:
31
34
  HumanMessagePromptTemplate.from_template(self.human_message),
32
35
  ]
33
36
 
34
- chain = ({
35
- 'tool_pool': lambda x: x['tool_pool'],
36
- 'task_name': lambda x: x['task_name'],
37
- 'task_description': lambda x: x['task_description'],
38
- }
39
- | ChatPromptTemplate.from_messages(messages)
40
- | self.llm)
41
-
42
- chain_output = chain.invoke({
43
- 'tool_pool': tool_pool,
44
- 'task_name': task_name,
45
- 'task_description': task_description,
46
- })
37
+ chain = ChatPromptTemplate.from_messages(messages) | self.llm
38
+
39
+ chain_output = chain.invoke(
40
+ {
41
+ 'tool_pool': tool_pool,
42
+ 'task_name': task_name,
43
+ 'task_description': task_description,
44
+ }
45
+ )
47
46
 
48
47
  try:
49
48
  all_tool_name = set([tool.tool_name for tool in self.tools])
@@ -1,7 +1,11 @@
1
+ import json
2
+ import os
1
3
  import warnings
2
4
  from typing import Any, Callable, Dict, List, Optional, Tuple
3
5
 
4
6
  import httpx
7
+ import pandas as pd
8
+ import pymysql
5
9
  from bisheng_langchain.gpts.tools.api_tools import ALL_API_TOOLS
6
10
  from bisheng_langchain.gpts.tools.bing_search.tool import BingSearchRun
7
11
  from bisheng_langchain.gpts.tools.calculator.tool import calculator
@@ -13,6 +17,7 @@ from bisheng_langchain.gpts.tools.dalle_image_generator.tool import (
13
17
  DallEImageGenerator,
14
18
  )
15
19
  from bisheng_langchain.gpts.tools.get_current_time.tool import get_current_time
20
+ from dotenv import load_dotenv
16
21
  from langchain_community.tools.arxiv.tool import ArxivQueryRun
17
22
  from langchain_community.tools.bearly.tool import BearlyInterpreterTool
18
23
  from langchain_community.utilities.arxiv import ArxivAPIWrapper
@@ -21,6 +26,7 @@ from langchain_core.callbacks import BaseCallbackManager, Callbacks
21
26
  from langchain_core.language_models import BaseLanguageModel
22
27
  from langchain_core.tools import BaseTool, Tool
23
28
  from mypy_extensions import Arg, KwArg
29
+ from bisheng_langchain.rag import BishengRAGTool
24
30
 
25
31
 
26
32
  def _get_current_time() -> BaseTool:
@@ -54,12 +60,14 @@ def _get_bing_search(**kwargs: Any) -> BaseTool:
54
60
 
55
61
  def _get_dalle_image_generator(**kwargs: Any) -> Tool:
56
62
  openai_api_key = kwargs.get('openai_api_key')
63
+ openai_api_base = kwargs.get('openai_api_base')
57
64
  http_async_client = httpx.AsyncClient(proxies=kwargs.get('openai_proxy'))
58
65
  httpc_client = httpx.Client(proxies=kwargs.get('openai_proxy'))
59
66
  return DallEImageGenerator(
60
67
  api_wrapper=DallEAPIWrapper(
61
68
  model='dall-e-3',
62
69
  api_key=openai_api_key,
70
+ base_url=openai_api_base,
63
71
  http_client=httpc_client,
64
72
  http_async_client=http_async_client,
65
73
  )
@@ -78,7 +86,8 @@ def _get_native_code_interpreter(**kwargs: Any) -> Tool:
78
86
  _EXTRA_PARAM_TOOLS: Dict[str, Tuple[Callable[[KwArg(Any)], BaseTool], List[Optional[str]], List[Optional[str]]]] = { # type: ignore
79
87
  'dalle_image_generator': (_get_dalle_image_generator, ['openai_api_key', 'openai_proxy'], []),
80
88
  'bing_search': (_get_bing_search, ['bing_subscription_key', 'bing_search_url'], []),
81
- 'code_interpreter': (_get_native_code_interpreter, ["minio"], ['files']),
89
+ 'bisheng_code_interpreter': (_get_native_code_interpreter, ["minio"], ['files']),
90
+ 'bisheng_rag': (BishengRAGTool.get_rag_tool, ['name', 'description'], ['vector_store', 'keyword_store', 'llm', 'collection_name', 'max_content', 'sort_by_source_and_index']),
82
91
  }
83
92
 
84
93
  _API_TOOLS: Dict[str, Tuple[Callable[[KwArg(Any)], BaseTool], List[str]]] = {**ALL_API_TOOLS} # type: ignore
@@ -159,3 +168,46 @@ def load_tools(
159
168
  def get_all_tool_names() -> List[str]:
160
169
  """Get a list of all possible tool names."""
161
170
  return list(_ALL_TOOLS.keys())
171
+
172
+
173
+ def get_tool_table():
174
+
175
+ load_dotenv('.sql_env', override=True)
176
+ db = pymysql.connect(
177
+ host=os.getenv('MYSQL_HOST'),
178
+ user=os.getenv('MYSQL_USER'),
179
+ password=os.getenv('MYSQL_PASSWORD'),
180
+ database=os.getenv('MYSQL_DATABASE'),
181
+ port=int(os.getenv('MYSQL_PORT')),
182
+ )
183
+ cursor = db.cursor()
184
+ cursor.execute("SELECT name, t.desc, tool_key, extra FROM t_gpts_tools as t;")
185
+ results = cursor.fetchall()
186
+ db.close()
187
+
188
+ df = pd.DataFrame(
189
+ columns=[
190
+ '前端工具名',
191
+ '前端工具描述',
192
+ 'tool_key',
193
+ 'tool参数配置',
194
+ 'function_name',
195
+ 'function_description',
196
+ 'function_args',
197
+ ]
198
+ )
199
+ for i, result in enumerate(results):
200
+ name, desc, tool_key, extra = result
201
+ if not extra:
202
+ extra = '{}'
203
+ tool_func = load_tools({tool_key: json.loads(extra)})[0]
204
+
205
+ df.loc[i, '前端工具名'] = name
206
+ df.loc[i, '前端工具描述'] = desc
207
+ df.loc[i, 'tool_key'] = tool_key
208
+ df.loc[i, 'tool参数配置'] = extra
209
+ df.loc[i, 'function_name'] = tool_func.name
210
+ df.loc[i, 'function_description'] = tool_func.description
211
+ df.loc[i, 'function_args'] = f"{tool_func.args_schema.schema()['properties']}"
212
+
213
+ return df
@@ -1,12 +1,14 @@
1
1
  from bisheng_langchain.gpts.prompts.assistant_prompt_opt import ASSISTANT_PROMPT_OPT
2
- from bisheng_langchain.gpts.prompts.base_prompt import DEFAULT_SYSTEM_MESSAGE
2
+ from bisheng_langchain.gpts.prompts.assistant_prompt_base import ASSISTANT_PROMPT_DEFAULT
3
+ from bisheng_langchain.gpts.prompts.assistant_prompt_cohere import ASSISTANT_PROMPT_COHERE
3
4
  from bisheng_langchain.gpts.prompts.breif_description_prompt import BREIF_DES_PROMPT
4
5
  from bisheng_langchain.gpts.prompts.opening_dialog_prompt import OPENDIALOG_PROMPT
5
6
  from bisheng_langchain.gpts.prompts.select_tools_prompt import HUMAN_MSG, SYS_MSG
6
7
 
7
8
 
8
9
  __all__ = [
9
- "DEFAULT_SYSTEM_MESSAGE",
10
+ "ASSISTANT_PROMPT_DEFAULT",
11
+ "ASSISTANT_PROMPT_COHERE",
10
12
  "ASSISTANT_PROMPT_OPT",
11
13
  "OPENDIALOG_PROMPT",
12
14
  "BREIF_DES_PROMPT",
@@ -0,0 +1 @@
1
+ ASSISTANT_PROMPT_DEFAULT = "You are a helpful assistant."
@@ -0,0 +1,19 @@
1
+ preamble="""You are a helpful assistant.
2
+ """
3
+
4
+ ASSISTANT_PROMPT_COHERE="""{preamble}|<instruct>|Carefully perform the following instructions, in order, starting each with a new line.
5
+ Firstly, You may need to use complex and advanced reasoning to complete your task and answer the question. Think about how you can use the provided tools to answer the question and come up with a high level plan you will execute.
6
+ Write 'Plan:' followed by an initial high level plan of how you will solve the problem including the tools and steps required.
7
+ Secondly, Carry out your plan by repeatedly using actions, reasoning over the results, and re-evaluating your plan. Perform Action, Observation, Reflection steps with the following format. Write 'Action:' followed by a json formatted action containing the "tool_name" and "parameters"
8
+ Next you will analyze the 'Observation:', this is the result of the action.
9
+ After that you should always think about what to do next. Write 'Reflection:' followed by what you've figured out so far, any changes you need to make to your plan, and what you will do next including if you know the answer to the question.
10
+ ... (this Action/Observation/Reflection can repeat N times)
11
+ Thirdly, Decide which of the retrieved documents are relevant to the user's last input by writing 'Relevant Documents:' followed by comma-separated list of document numbers. If none are relevant, you should instead write 'None'.
12
+ Fourthly, Decide which of the retrieved documents contain facts that should be cited in a good answer to the user's last input by writing 'Cited Documents:' followed a comma-separated list of document numbers. If you dont want to cite any of them, you should instead write 'None'.
13
+ Fifthly, Write 'Answer:' followed by a response to the user's last input. Use the retrieved documents to help you. Do not insert any citations or grounding markup.
14
+ Finally, Write 'Grounded answer:' followed by a response to the user's last input in high quality natural english. Use the symbols <co: doc> and </co: doc> to indicate when a fact comes from a document in the search result, e.g <co: 4>my fact</co: 4> for a fact from document 4.
15
+
16
+ Additional instructions to note:
17
+ - If the user's question is in Chinese, please answer it in Chinese.
18
+ - 当问题中有涉及到时间信息时,比如最近6个月、昨天、去年等,你需要用时间工具查询时间信息。
19
+ """.format(preamble=preamble)
@@ -6,7 +6,7 @@ from langchain_core.prompts.chat import (
6
6
  )
7
7
 
8
8
  system_template = """
9
- 你是一个生成开场白和预置问题的助手。接下来,你会收到一段关于任务助手的描述,你需要带入描述中的角色,以描述中的角色身份生成一段开场白,同时你还需要以描述中的角色身份生成几个预置问题。输出格式如下:
9
+ 你是一个生成开场白和预置问题的助手。接下来,你会收到一段关于任务助手的描述,你需要带入描述中的角色,以描述中的角色身份生成一段开场白,同时你还需要站在用户的角度生成几个用户可能的提问。输出格式如下:
10
10
  [
11
11
  {{
12
12
  "开场白": "开场白内容",
@@ -40,7 +40,7 @@ _MACRO_TOOLS: Dict[str, Tuple[Callable[[KwArg(Any)], BaseTool], List[str]]] = {
40
40
 
41
41
  _tmp_flow = ['knowledge_retrieve']
42
42
  _TMP_TOOLS: Dict[str, Tuple[Callable[[KwArg(Any)], BaseTool], List[str]]] = {
43
- f'flow_{name}': (FlowTools.get_api_tool, ['collection_id'])
43
+ f'flow_{name}': (FlowTools.get_api_tool, ['collection_id', 'description'])
44
44
  for name in _tmp_flow
45
45
  }
46
46
  ALL_API_TOOLS = {}
@@ -64,7 +64,7 @@ class APIToolBase(BaseModel):
64
64
  resp = self.client.get(url)
65
65
  if resp.status_code != 200:
66
66
  logger.info('api_call_fail res={}', resp.text)
67
- return resp.text
67
+ return resp.text[:10000]
68
68
 
69
69
  async def arun(self, query: str, **kwargs) -> str:
70
70
  """Run query through api and parse result."""
@@ -79,8 +79,8 @@ class APIToolBase(BaseModel):
79
79
  url = self.url
80
80
  logger.info('api_call url={}', url)
81
81
  resp = await self.async_client.aget(url)
82
- logger.info(resp)
83
- return resp
82
+ logger.info(resp[:10000])
83
+ return resp[:10000]
84
84
 
85
85
  @classmethod
86
86
  def get_api_tool(cls, name, **kwargs: Any) -> BaseTool:
@@ -1,7 +1,9 @@
1
1
  from loguru import logger
2
- from pydantic import BaseModel, Field
3
-
2
+ from langchain_core.pydantic_v1 import BaseModel, Field
3
+ from typing import Any
4
4
  from .base import APIToolBase
5
+ from .base import MultArgsSchemaTool
6
+ from langchain_core.tools import BaseTool
5
7
 
6
8
 
7
9
  class FlowTools(APIToolBase):
@@ -34,10 +36,8 @@ class FlowTools(APIToolBase):
34
36
  return resp
35
37
 
36
38
  @classmethod
37
- def knowledge_retrieve(cls, collection_id: int = None) -> str:
38
- """
39
- 知识库检索工具,从内部知识库进行检索总结
40
- """
39
+ def knowledge_retrieve(cls, collection_id: int = None) -> str:
40
+
41
41
  flow_id = 'c7985115-a9d2-446a-9c55-40b5728ffb52'
42
42
  url = 'http://192.168.106.120:3002/api/v1/process/{}'.format(flow_id)
43
43
  input_key = 'inputs'
@@ -53,7 +53,19 @@ class FlowTools(APIToolBase):
53
53
  }
54
54
 
55
55
  class InputArgs(BaseModel):
56
- """args_schema"""
57
56
  query: str = Field(description='questions to ask')
58
57
 
59
58
  return cls(url=url, params=params, input_key=input_key, args_schema=InputArgs)
59
+
60
+ @classmethod
61
+ def get_api_tool(cls, name, **kwargs: Any) -> BaseTool:
62
+ attr_name = name.split('_', 1)[-1]
63
+ class_method = getattr(cls, attr_name)
64
+ function_description = kwargs.get('description','')
65
+ kwargs.pop('description')
66
+ return MultArgsSchemaTool(name=name + '_' +str(kwargs.get('collection_id')),
67
+ description=function_description,
68
+ func=class_method(**kwargs).run,
69
+ coroutine=class_method(**kwargs).arun,
70
+ args_schema=class_method(**kwargs).args_schema)
71
+
@@ -81,7 +81,7 @@ class MacroData(BaseModel):
81
81
  JS_CHINA_GDP_YEARLY_URL = 'https://cdn.jin10.com/dc/reports/dc_chinese_gdp_yoy_all.js?v={}&_={}'
82
82
  t = time.time()
83
83
  r = requests.get(JS_CHINA_GDP_YEARLY_URL.format(str(int(round(t * 1000))), str(int(round(t * 1000)) + 90)))
84
- json_data = json.loads(r.text[r.text.find('{') : r.text.rfind('}') + 1])
84
+ json_data = json.loads(r.text[r.text.find('{'): r.text.rfind('}') + 1])
85
85
  date_list = [item['date'] for item in json_data['list']]
86
86
  value_list = [item['datas']['中国GDP年率报告'] for item in json_data['list']]
87
87
  value_df = pd.DataFrame(value_list)
@@ -249,6 +249,60 @@ class MacroData(BaseModel):
249
249
  temp_df = temp_df[(temp_df['月份'] >= start) & (temp_df['月份'] <= end)]
250
250
  return temp_df.to_markdown()
251
251
 
252
+ @classmethod
253
+ def china_pmi(cls, start_date: str = '', end_date: str = '') -> str:
254
+ """中国 PMI (采购经理人指数)月度统计数据。
255
+ 返回数据包括:月份制造业 PMI,制造业 PMI 同比增长,非制造业 PMI,非制造业 PMI 同比增长。
256
+ """
257
+ url = "https://datacenter-web.eastmoney.com/api/data/v1/get"
258
+ params = {
259
+ "columns": "REPORT_DATE,TIME,MAKE_INDEX,MAKE_SAME,NMAKE_INDEX,NMAKE_SAME",
260
+ "pageNumber": "1",
261
+ "pageSize": "2000",
262
+ "sortColumns": "REPORT_DATE",
263
+ "sortTypes": "-1",
264
+ "source": "WEB",
265
+ "client": "WEB",
266
+ "reportName": "RPT_ECONOMY_PMI",
267
+ "p": "1",
268
+ "pageNo": "1",
269
+ "pageNum": "1",
270
+ "_": "1669047266881",
271
+ }
272
+ r = requests.get(url, params=params)
273
+ data_json = r.json()
274
+ temp_df = pd.DataFrame(data_json["result"]["data"])
275
+ temp_df.columns = [
276
+ "-",
277
+ "月份",
278
+ "制造业-指数",
279
+ "制造业-同比增长",
280
+ "非制造业-指数",
281
+ "非制造业-同比增长",
282
+ ]
283
+ temp_df = temp_df[
284
+ [
285
+ "月份",
286
+ "制造业-指数",
287
+ "制造业-同比增长",
288
+ "非制造业-指数",
289
+ "非制造业-同比增长",
290
+ ]
291
+ ]
292
+ temp_df["制造业-指数"] = pd.to_numeric(temp_df["制造业-指数"], errors="coerce")
293
+ temp_df["制造业-同比增长"] = pd.to_numeric(
294
+ temp_df["制造业-同比增长"], errors="coerce"
295
+ )
296
+ temp_df["非制造业-指数"] = pd.to_numeric(temp_df["非制造业-指数"], errors="coerce")
297
+ temp_df["非制造业-同比增长"] = pd.to_numeric(
298
+ temp_df["非制造业-同比增长"], errors="coerce"
299
+ )
300
+ if start_date and end_date:
301
+ start = start_date.split('-')[0] + '年' + start_date.split('-')[1] + '月份'
302
+ end = end_date.split('-')[0] + '年' + end_date.split('-')[1] + '月份'
303
+ temp_df = temp_df[(temp_df['月份'] >= start) & (temp_df['月份'] <= end)]
304
+ return temp_df.to_markdown()
305
+
252
306
  @classmethod
253
307
  def china_money_supply(cls, start_date: str = '', end_date: str = '') -> pd.DataFrame:
254
308
  """中国货币供应量(M2,M1,M0)月度统计数据。\
@@ -376,6 +430,121 @@ M0数量(单位:亿元),M0 同比(单位:%),M0 环比(单位
376
430
 
377
431
  return temp_df.to_markdown()
378
432
 
433
+ @classmethod
434
+ def bond_zh_us_rate(cls, start_date: str = "", end_date: str = "") -> str:
435
+ """
436
+ 本接口返回指定时间段[start_date,end_date]内交易日的中美两国的 2 年、5 年、10 年、30 年、10 年-2 年收益率数据。
437
+ start_date表示起始日期,end_date表示结束日期,日期格式例如 2024-04-07
438
+ """
439
+ url = "https://datacenter.eastmoney.com/api/data/get"
440
+ params = {
441
+ "type": "RPTA_WEB_TREASURYYIELD",
442
+ "sty": "ALL",
443
+ "st": "SOLAR_DATE",
444
+ "sr": "-1",
445
+ "token": "894050c76af8597a853f5b408b759f5d",
446
+ "p": "1",
447
+ "ps": "500",
448
+ "pageNo": "1",
449
+ "pageNum": "1",
450
+ "_": "1615791534490",
451
+ }
452
+ r = requests.get(url, params=params)
453
+ data_json = r.json()
454
+ total_page = data_json["result"]["pages"]
455
+ big_df = pd.DataFrame()
456
+ for page in range(1, total_page + 1):
457
+ params = {
458
+ "type": "RPTA_WEB_TREASURYYIELD",
459
+ "sty": "ALL",
460
+ "st": "SOLAR_DATE",
461
+ "sr": "-1",
462
+ "token": "894050c76af8597a853f5b408b759f5d",
463
+ "p": page,
464
+ "ps": "500",
465
+ "pageNo": page,
466
+ "pageNum": page,
467
+ "_": "1615791534490",
468
+ }
469
+ r = requests.get(url, params=params)
470
+ data_json = r.json()
471
+ # 时间过滤
472
+ if start_date and end_date:
473
+ temp_data = []
474
+ for item in data_json["result"]["data"]:
475
+ if start_date <= item["SOLAR_DATE"].split(" ")[0] <= end_date:
476
+ temp_data.append(item)
477
+ elif start_date > item["SOLAR_DATE"].split(" ")[0]:
478
+ break
479
+ else:
480
+ continue
481
+ else:
482
+ temp_data = data_json["result"]["data"]
483
+ temp_df = pd.DataFrame(temp_data)
484
+ for col in temp_df.columns:
485
+ if temp_df[col].isnull().all(): # 检查列是否包含 None 或 NaN
486
+ temp_df[col] = pd.to_numeric(temp_df[col], errors='coerce')
487
+ if big_df.empty:
488
+ big_df = temp_df
489
+ else:
490
+ big_df = pd.concat(objs=[big_df, temp_df], ignore_index=True)
491
+
492
+ big_df.rename(
493
+ columns={
494
+ "SOLAR_DATE": "日期",
495
+ "EMM00166462": "中国国债收益率5年",
496
+ "EMM00166466": "中国国债收益率10年",
497
+ "EMM00166469": "中国国债收益率30年",
498
+ "EMM00588704": "中国国债收益率2年",
499
+ "EMM01276014": "中国国债收益率10年-2年",
500
+ "EMG00001306": "美国国债收益率2年",
501
+ "EMG00001308": "美国国债收益率5年",
502
+ "EMG00001310": "美国国债收益率10年",
503
+ "EMG00001312": "美国国债收益率30年",
504
+ "EMG01339436": "美国国债收益率10年-2年",
505
+ "EMM00000024": "中国GDP年增率",
506
+ "EMG00159635": "美国GDP年增率",
507
+ },
508
+ inplace=True,
509
+ )
510
+ big_df = big_df[
511
+ [
512
+ "日期",
513
+ "中国国债收益率2年",
514
+ "中国国债收益率5年",
515
+ "中国国债收益率10年",
516
+ "中国国债收益率30年",
517
+ "中国国债收益率10年-2年",
518
+ "中国GDP年增率",
519
+ "美国国债收益率2年",
520
+ "美国国债收益率5年",
521
+ "美国国债收益率10年",
522
+ "美国国债收益率30年",
523
+ "美国国债收益率10年-2年",
524
+ "美国GDP年增率",
525
+ ]
526
+ ]
527
+ big_df = big_df.drop(["中国GDP年增率", "美国GDP年增率"], axis=1)
528
+ big_df["日期"] = pd.to_datetime(big_df["日期"], errors="coerce")
529
+ big_df["中国国债收益率2年"] = pd.to_numeric(big_df["中国国债收益率2年"], errors="coerce")
530
+ big_df["中国国债收益率5年"] = pd.to_numeric(big_df["中国国债收益率5年"], errors="coerce")
531
+ big_df["中国国债收益率10年"] = pd.to_numeric(big_df["中国国债收益率10年"], errors="coerce")
532
+ big_df["中国国债收益率30年"] = pd.to_numeric(big_df["中国国债收益率30年"], errors="coerce")
533
+ big_df["中国国债收益率10年-2年"] = pd.to_numeric(big_df["中国国债收益率10年-2年"], errors="coerce")
534
+ # big_df["中国GDP年增率"] = pd.to_numeric(big_df["中国GDP年增率"], errors="coerce")
535
+ big_df["美国国债收益率2年"] = pd.to_numeric(big_df["美国国债收益率2年"], errors="coerce")
536
+ big_df["美国国债收益率5年"] = pd.to_numeric(big_df["美国国债收益率5年"], errors="coerce")
537
+ big_df["美国国债收益率10年"] = pd.to_numeric(big_df["美国国债收益率10年"], errors="coerce")
538
+ big_df["美国国债收益率30年"] = pd.to_numeric(big_df["美国国债收益率30年"], errors="coerce")
539
+ big_df["美国国债收益率10年-2年"] = pd.to_numeric(big_df["美国国债收益率10年-2年"], errors="coerce")
540
+ # big_df["美国GDP年增率"] = pd.to_numeric(big_df["美国GDP年增率"], errors="coerce")
541
+ big_df.sort_values("日期", inplace=True)
542
+ big_df.set_index(["日期"], inplace=True)
543
+ big_df = big_df[pd.to_datetime(start_date):]
544
+ big_df.reset_index(inplace=True)
545
+ big_df["日期"] = pd.to_datetime(big_df["日期"]).dt.date
546
+ return big_df.to_markdown()
547
+
379
548
  @classmethod
380
549
  def get_api_tool(cls, name: str, **kwargs: Any) -> BaseTool:
381
550
  attr_name = name.split('_', 1)[-1]
@@ -385,13 +554,15 @@ M0数量(单位:亿元),M0 同比(单位:%),M0 环比(单位
385
554
 
386
555
 
387
556
  if __name__ == '__main__':
388
- start_date = '2023-01-01'
389
- end_date = '2023-05-01'
557
+ tmp_start_date = '2024-01-01'
558
+ tmp_end_date = '2024-01-03'
390
559
  # start_date = ''
391
560
  # end_date = ''
392
561
  # print(MacroData.china_ppi(start_date=start_date, end_date=end_date))
393
562
  # print(MacroData.china_shrzgm(start_date=start_date, end_date=end_date))
394
563
  # print(MacroData.china_consumer_goods_retail(start_date=start_date, end_date=end_date))
395
564
  # print(MacroData.china_cpi(start_date=start_date, end_date=end_date))
565
+ # print(MacroData.china_pmi(start_date=start_date, end_date=end_date))
396
566
  # print(MacroData.china_money_supply(start_date=start_date, end_date=end_date))
397
- print(MacroData.china_gdp_yearly(start_date=start_date, end_date=end_date))
567
+ # print(MacroData.china_gdp_yearly(start_date=start_date, end_date=end_date))
568
+ print(MacroData.bond_zh_us_rate(start_date=tmp_start_date, end_date=tmp_end_date))