bisheng-langchain 0.3.7.dev2__py3-none-any.whl → 0.4.0.dev1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bisheng_langchain/document_loaders/custom_kv.py +2 -2
- bisheng_langchain/document_loaders/elem_pdf.py +3 -3
- bisheng_langchain/document_loaders/elem_unstrcutured_loader.py +2 -2
- bisheng_langchain/document_loaders/parsers/image.py +1 -1
- bisheng_langchain/document_loaders/universal_kv.py +2 -2
- bisheng_langchain/gpts/agent_types/llm_functions_agent.py +35 -5
- bisheng_langchain/gpts/agent_types/llm_react_agent.py +27 -39
- bisheng_langchain/gpts/assistant.py +24 -24
- bisheng_langchain/gpts/tools/api_tools/base.py +1 -1
- bisheng_langchain/gpts/tools/api_tools/openapi.py +59 -32
- bisheng_langchain/rag/bisheng_rag_chain.py +26 -32
- bisheng_langchain/rag/bisheng_rag_tool.py +98 -98
- bisheng_langchain/rag/extract_info.py +0 -2
- bisheng_langchain/rag/init_retrievers/baseline_vector_retriever.py +8 -12
- bisheng_langchain/rag/init_retrievers/keyword_retriever.py +8 -16
- bisheng_langchain/rag/init_retrievers/mix_retriever.py +16 -17
- bisheng_langchain/rag/init_retrievers/smaller_chunks_retriever.py +8 -8
- bisheng_langchain/sql/base.py +1 -1
- bisheng_langchain/vectorstores/elastic_keywords_search.py +17 -2
- bisheng_langchain/vectorstores/milvus.py +76 -69
- {bisheng_langchain-0.3.7.dev2.dist-info → bisheng_langchain-0.4.0.dev1.dist-info}/METADATA +6 -6
- {bisheng_langchain-0.3.7.dev2.dist-info → bisheng_langchain-0.4.0.dev1.dist-info}/RECORD +24 -24
- {bisheng_langchain-0.3.7.dev2.dist-info → bisheng_langchain-0.4.0.dev1.dist-info}/WHEEL +1 -1
- {bisheng_langchain-0.3.7.dev2.dist-info → bisheng_langchain-0.4.0.dev1.dist-info}/top_level.txt +0 -0
@@ -15,8 +15,8 @@ import cv2
|
|
15
15
|
import fitz
|
16
16
|
import numpy as np
|
17
17
|
from bisheng_langchain.utils.requests import Requests
|
18
|
-
from
|
19
|
-
from
|
18
|
+
from langchain_community.docstore.document import Document
|
19
|
+
from langchain_community.document_loaders.base import BaseLoader
|
20
20
|
from PIL import Image
|
21
21
|
|
22
22
|
logger = logging.getLogger(__name__)
|
@@ -12,9 +12,9 @@ from typing import List, Optional, Union
|
|
12
12
|
import fitz
|
13
13
|
import numpy as np
|
14
14
|
from bisheng_langchain.document_loaders.parsers import LayoutParser
|
15
|
-
from
|
16
|
-
from
|
17
|
-
from
|
15
|
+
from langchain_community.docstore.document import Document
|
16
|
+
from langchain_community.document_loaders.blob_loaders import Blob
|
17
|
+
from langchain_community.document_loaders.pdf import BasePDFLoader
|
18
18
|
from shapely import Polygon
|
19
19
|
from shapely import box as Rect
|
20
20
|
|
@@ -6,8 +6,8 @@ import os
|
|
6
6
|
from typing import List
|
7
7
|
|
8
8
|
import requests
|
9
|
-
from
|
10
|
-
from
|
9
|
+
from langchain_community.docstore.document import Document
|
10
|
+
from langchain_community.document_loaders.pdf import BasePDFLoader
|
11
11
|
|
12
12
|
logger = logging.getLogger(__name__)
|
13
13
|
|
@@ -11,8 +11,8 @@ import filetype
|
|
11
11
|
import fitz
|
12
12
|
import numpy as np
|
13
13
|
from bisheng_langchain.document_loaders.parsers import ELLMClient
|
14
|
-
from
|
15
|
-
from
|
14
|
+
from langchain_community.docstore.document import Document
|
15
|
+
from langchain_community.document_loaders.base import BaseLoader
|
16
16
|
from PIL import Image
|
17
17
|
|
18
18
|
|
@@ -1,5 +1,6 @@
|
|
1
1
|
import json
|
2
2
|
import re
|
3
|
+
|
3
4
|
from bisheng_langchain.gpts.message_types import LiberalFunctionMessage, LiberalToolMessage
|
4
5
|
from langchain.tools import BaseTool
|
5
6
|
from langchain.tools.render import format_tool_to_openai_tool
|
@@ -8,13 +9,14 @@ from langchain_core.messages import FunctionMessage, SystemMessage, ToolMessage
|
|
8
9
|
from langgraph.graph import END
|
9
10
|
from langgraph.graph.message import MessageGraph
|
10
11
|
from langgraph.prebuilt import ToolExecutor, ToolInvocation
|
12
|
+
from langgraph.utils.runnable import RunnableCallable
|
11
13
|
|
12
14
|
|
13
15
|
def get_openai_functions_agent_executor(tools: list[BaseTool], llm: LanguageModelLike,
|
14
16
|
system_message: str, interrupt_before_action: bool,
|
15
17
|
**kwargs):
|
16
18
|
|
17
|
-
|
19
|
+
def _get_messages(messages):
|
18
20
|
msgs = []
|
19
21
|
for m in messages:
|
20
22
|
if isinstance(m, LiberalToolMessage):
|
@@ -31,6 +33,7 @@ def get_openai_functions_agent_executor(tools: list[BaseTool], llm: LanguageMode
|
|
31
33
|
llm_with_tools = llm.bind(tools=[format_tool_to_openai_tool(t) for t in tools])
|
32
34
|
else:
|
33
35
|
llm_with_tools = llm
|
36
|
+
|
34
37
|
agent = _get_messages | llm_with_tools
|
35
38
|
tool_executor = ToolExecutor(tools)
|
36
39
|
|
@@ -41,7 +44,7 @@ def get_openai_functions_agent_executor(tools: list[BaseTool], llm: LanguageMode
|
|
41
44
|
if 'tool_calls' not in last_message.additional_kwargs:
|
42
45
|
if '|<instruct>|' in system_message:
|
43
46
|
# cohere model
|
44
|
-
pattern = r
|
47
|
+
pattern = r'Answer:(.+)\nGrounded answer'
|
45
48
|
match = re.search(pattern, last_message.content)
|
46
49
|
if match:
|
47
50
|
last_message.content = match.group(1)
|
@@ -51,7 +54,7 @@ def get_openai_functions_agent_executor(tools: list[BaseTool], llm: LanguageMode
|
|
51
54
|
return 'continue'
|
52
55
|
|
53
56
|
# Define the function to execute tools
|
54
|
-
async def
|
57
|
+
async def acall_tool(messages):
|
55
58
|
actions: list[ToolInvocation] = []
|
56
59
|
# Based on the continue condition
|
57
60
|
# we know the last message involves a function call
|
@@ -78,11 +81,38 @@ def get_openai_functions_agent_executor(tools: list[BaseTool], llm: LanguageMode
|
|
78
81
|
]
|
79
82
|
return tool_messages
|
80
83
|
|
84
|
+
def call_tool(messages):
|
85
|
+
actions: list[ToolInvocation] = []
|
86
|
+
# Based on the continue condition
|
87
|
+
# we know the last message involves a function call
|
88
|
+
last_message = messages[-1]
|
89
|
+
for tool_call in last_message.additional_kwargs['tool_calls']:
|
90
|
+
function = tool_call['function']
|
91
|
+
function_name = function['name']
|
92
|
+
_tool_input = json.loads(function['arguments'] or '{}')
|
93
|
+
# We construct an ToolInvocation from the function_call
|
94
|
+
actions.append(ToolInvocation(
|
95
|
+
tool=function_name,
|
96
|
+
tool_input=_tool_input,
|
97
|
+
))
|
98
|
+
# We call the tool_executor and get back a response
|
99
|
+
responses = tool_executor.batch(actions, **kwargs)
|
100
|
+
# We use the response to create a ToolMessage
|
101
|
+
tool_messages = [
|
102
|
+
LiberalToolMessage(
|
103
|
+
tool_call_id=tool_call['id'],
|
104
|
+
content=response,
|
105
|
+
additional_kwargs={'name': tool_call['function']['name']},
|
106
|
+
)
|
107
|
+
for tool_call, response in zip(last_message.additional_kwargs['tool_calls'], responses)
|
108
|
+
]
|
109
|
+
return tool_messages
|
110
|
+
|
81
111
|
workflow = MessageGraph()
|
82
112
|
|
83
113
|
# Define the two nodes we will cycle between
|
84
114
|
workflow.add_node('agent', agent)
|
85
|
-
workflow.add_node('action', call_tool)
|
115
|
+
workflow.add_node('action', RunnableCallable(call_tool, acall_tool))
|
86
116
|
|
87
117
|
# Set the entrypoint as `agent`
|
88
118
|
# This means that this node is the first one called
|
@@ -116,7 +146,7 @@ def get_openai_functions_agent_executor(tools: list[BaseTool], llm: LanguageMode
|
|
116
146
|
# Finally, we compile it!
|
117
147
|
# This compiles it into a LangChain Runnable,
|
118
148
|
# meaning you can use it as you would any other runnable
|
119
|
-
app = workflow.compile()
|
149
|
+
app = workflow.compile(checkpointer=False)
|
120
150
|
if interrupt_before_action:
|
121
151
|
app.interrupt = ['action:inbox']
|
122
152
|
return app
|
@@ -1,24 +1,20 @@
|
|
1
1
|
import operator
|
2
2
|
from typing import Annotated, Sequence, TypedDict, Union
|
3
|
+
|
4
|
+
from bisheng_langchain.gpts.prompts.react_agent_prompt import react_agent_prompt
|
5
|
+
from langchain.agents import create_structured_chat_agent
|
3
6
|
from langchain.tools import BaseTool
|
4
7
|
from langchain_core.agents import AgentAction, AgentFinish
|
5
|
-
from langchain_core.messages import BaseMessage
|
6
8
|
from langchain_core.language_models import LanguageModelLike
|
9
|
+
from langchain_core.messages import BaseMessage
|
7
10
|
from langgraph.graph import END, StateGraph
|
8
11
|
from langgraph.graph.state import CompiledStateGraph
|
9
12
|
from langgraph.prebuilt.tool_executor import ToolExecutor
|
10
|
-
from langgraph.utils import RunnableCallable
|
11
|
-
from langchain.agents import create_structured_chat_agent
|
12
|
-
from bisheng_langchain.gpts.prompts.react_agent_prompt import react_agent_prompt
|
13
|
+
from langgraph.utils.runnable import RunnableCallable
|
13
14
|
|
14
15
|
|
15
|
-
def get_react_agent_executor(
|
16
|
-
|
17
|
-
llm: LanguageModelLike,
|
18
|
-
system_message: str,
|
19
|
-
interrupt_before_action: bool,
|
20
|
-
**kwargs
|
21
|
-
):
|
16
|
+
def get_react_agent_executor(tools: list[BaseTool], llm: LanguageModelLike, system_message: str,
|
17
|
+
interrupt_before_action: bool, **kwargs):
|
22
18
|
prompt = react_agent_prompt
|
23
19
|
prompt = prompt.partial(assistant_message=system_message)
|
24
20
|
agent = create_structured_chat_agent(llm, tools, prompt)
|
@@ -56,9 +52,7 @@ def _get_agent_state(input_schema=None):
|
|
56
52
|
return AgentState
|
57
53
|
|
58
54
|
|
59
|
-
def create_agent_executor(
|
60
|
-
agent_runnable, tools, input_schema=None
|
61
|
-
) -> CompiledStateGraph:
|
55
|
+
def create_agent_executor(agent_runnable, tools, input_schema=None) -> CompiledStateGraph:
|
62
56
|
"""This is a helper function for creating a graph that works with LangChain Agents.
|
63
57
|
|
64
58
|
Args:
|
@@ -82,65 +76,59 @@ def create_agent_executor(
|
|
82
76
|
def should_continue(data):
|
83
77
|
# If the agent outcome is an AgentFinish, then we return `exit` string
|
84
78
|
# This will be used when setting up the graph to define the flow
|
85
|
-
if isinstance(data[
|
86
|
-
return
|
79
|
+
if isinstance(data['agent_outcome'], AgentFinish):
|
80
|
+
return 'end'
|
87
81
|
# Otherwise, an AgentAction is returned
|
88
82
|
# Here we return `continue` string
|
89
83
|
# This will be used when setting up the graph to define the flow
|
90
84
|
else:
|
91
|
-
return
|
85
|
+
return 'continue'
|
92
86
|
|
93
87
|
def run_agent(data, config):
|
94
88
|
agent_outcome = agent_runnable.invoke(data, config)
|
95
|
-
return {
|
89
|
+
return {'agent_outcome': agent_outcome}
|
96
90
|
|
97
91
|
async def arun_agent(data, config):
|
98
92
|
agent_outcome = await agent_runnable.ainvoke(data, config)
|
99
|
-
return {
|
93
|
+
return {'agent_outcome': agent_outcome}
|
100
94
|
|
101
95
|
# Define the function to execute tools
|
102
96
|
def execute_tools(data, config):
|
103
97
|
# Get the most recent agent_outcome - this is the key added in the `agent` above
|
104
|
-
agent_action = data[
|
98
|
+
agent_action = data['agent_outcome']
|
105
99
|
if not isinstance(agent_action, list):
|
106
100
|
agent_action = [agent_action]
|
107
101
|
output = tool_executor.batch(agent_action, config, return_exceptions=True)
|
108
102
|
return {
|
109
|
-
|
110
|
-
(action, str(out)) for action, out in zip(agent_action, output)
|
111
|
-
]
|
103
|
+
'intermediate_steps': [(action, str(out)) for action, out in zip(agent_action, output)]
|
112
104
|
}
|
113
105
|
|
114
106
|
async def aexecute_tools(data, config):
|
115
107
|
# Get the most recent agent_outcome - this is the key added in the `agent` above
|
116
|
-
agent_action = data[
|
108
|
+
agent_action = data['agent_outcome']
|
117
109
|
if not isinstance(agent_action, list):
|
118
110
|
agent_action = [agent_action]
|
119
|
-
output = await tool_executor.abatch(
|
120
|
-
agent_action, config, return_exceptions=True
|
121
|
-
)
|
111
|
+
output = await tool_executor.abatch(agent_action, config, return_exceptions=True)
|
122
112
|
return {
|
123
|
-
|
124
|
-
(action, str(out)) for action, out in zip(agent_action, output)
|
125
|
-
]
|
113
|
+
'intermediate_steps': [(action, str(out)) for action, out in zip(agent_action, output)]
|
126
114
|
}
|
127
115
|
|
128
116
|
# Define a new graph
|
129
117
|
workflow = StateGraph(state)
|
130
118
|
|
131
119
|
# Define the two nodes we will cycle between
|
132
|
-
workflow.add_node(
|
133
|
-
workflow.add_node(
|
120
|
+
workflow.add_node('agent', RunnableCallable(run_agent, arun_agent))
|
121
|
+
workflow.add_node('tools', RunnableCallable(execute_tools, aexecute_tools))
|
134
122
|
|
135
123
|
# Set the entrypoint as `agent`
|
136
124
|
# This means that this node is the first one called
|
137
|
-
workflow.set_entry_point(
|
125
|
+
workflow.set_entry_point('agent')
|
138
126
|
|
139
127
|
# We now add a conditional edge
|
140
128
|
workflow.add_conditional_edges(
|
141
129
|
# First, we define the start node. We use `agent`.
|
142
130
|
# This means these are the edges taken after the `agent` node is called.
|
143
|
-
|
131
|
+
'agent',
|
144
132
|
# Next, we pass in the function that will determine which node is called next.
|
145
133
|
should_continue,
|
146
134
|
# Finally we pass in a mapping.
|
@@ -151,21 +139,21 @@ def create_agent_executor(
|
|
151
139
|
# Based on which one it matches, that node will then be called.
|
152
140
|
{
|
153
141
|
# If `tools`, then we call the tool node.
|
154
|
-
|
142
|
+
'continue': 'tools',
|
155
143
|
# Otherwise we finish.
|
156
|
-
|
144
|
+
'end': END,
|
157
145
|
},
|
158
146
|
)
|
159
147
|
|
160
148
|
# We now add a normal edge from `tools` to `agent`.
|
161
149
|
# This means that after `tools` is called, `agent` node is called next.
|
162
|
-
workflow.add_edge(
|
150
|
+
workflow.add_edge('tools', 'agent')
|
163
151
|
|
164
152
|
# Finally, we compile it!
|
165
153
|
# This compiles it into a LangChain Runnable,
|
166
154
|
# meaning you can use it as you would any other runnable
|
167
|
-
return workflow.compile()
|
155
|
+
return workflow.compile(checkpointer=False)
|
168
156
|
|
169
157
|
|
170
|
-
if __name__ ==
|
158
|
+
if __name__ == '__main__':
|
171
159
|
pass
|
@@ -1,9 +1,6 @@
|
|
1
1
|
import asyncio
|
2
2
|
import logging
|
3
|
-
from enum import Enum
|
4
|
-
from functools import lru_cache
|
5
3
|
from typing import Any, Mapping, Optional, Sequence
|
6
|
-
from urllib.parse import urlparse
|
7
4
|
|
8
5
|
import httpx
|
9
6
|
import yaml
|
@@ -11,7 +8,7 @@ from bisheng_langchain.gpts.load_tools import get_all_tool_names, load_tools
|
|
11
8
|
from bisheng_langchain.gpts.utils import import_by_type, import_class
|
12
9
|
from langchain.tools import BaseTool
|
13
10
|
from langchain_core.language_models.base import LanguageModelLike
|
14
|
-
from langchain_core.messages import
|
11
|
+
from langchain_core.messages import AIMessage, HumanMessage
|
15
12
|
from langchain_core.runnables import RunnableBinding
|
16
13
|
|
17
14
|
logger = logging.getLogger(__name__)
|
@@ -38,11 +35,13 @@ class ConfigurableAssistant(RunnableBinding):
|
|
38
35
|
config: Optional[Mapping[str, Any]] = None,
|
39
36
|
**others: Any,
|
40
37
|
) -> None:
|
41
|
-
others.pop(
|
42
|
-
agent_executor_object = import_class(
|
38
|
+
others.pop('bound', None)
|
39
|
+
agent_executor_object = import_class(
|
40
|
+
f'bisheng_langchain.gpts.agent_types.{agent_executor_type}')
|
43
41
|
|
44
|
-
_agent_executor = agent_executor_object(tools, llm, assistant_message,
|
45
|
-
|
42
|
+
_agent_executor = agent_executor_object(tools, llm, assistant_message,
|
43
|
+
interrupt_before_action)
|
44
|
+
agent_executor = _agent_executor.with_config({'recursion_limit': recursion_limit})
|
46
45
|
super().__init__(
|
47
46
|
agent_executor_type=agent_executor_type,
|
48
47
|
tools=tools,
|
@@ -88,7 +87,7 @@ class BishengAssistant:
|
|
88
87
|
tool_type = tool.pop('type')
|
89
88
|
tool_config = tool if tool else {}
|
90
89
|
if tool_type not in available_tools:
|
91
|
-
raise ValueError(f
|
90
|
+
raise ValueError(f'Tool type {tool_type} not found in TOOLS')
|
92
91
|
_returned_tools = load_tools({tool_type: tool_config})
|
93
92
|
if isinstance(_returned_tools, list):
|
94
93
|
tools.extend(_returned_tools)
|
@@ -98,43 +97,44 @@ class BishengAssistant:
|
|
98
97
|
# init agent executor
|
99
98
|
agent_executor_params = self.assistant_params['agent_executor']
|
100
99
|
self.agent_executor_type = agent_executor_params.pop('type')
|
101
|
-
self.assistant = ConfigurableAssistant(
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
**agent_executor_params
|
107
|
-
)
|
100
|
+
self.assistant = ConfigurableAssistant(agent_executor_type=self.agent_executor_type,
|
101
|
+
tools=tools,
|
102
|
+
llm=llm,
|
103
|
+
assistant_message=assistant_message,
|
104
|
+
**agent_executor_params)
|
108
105
|
|
109
106
|
def run(self, query, chat_history=[], chat_round=5):
|
110
107
|
if len(chat_history) % 2 != 0:
|
111
|
-
raise ValueError(
|
112
|
-
|
108
|
+
raise ValueError('chat history should be even')
|
109
|
+
|
113
110
|
# 限制chat_history轮数
|
114
111
|
if len(chat_history) > chat_round * 2:
|
115
|
-
chat_history = chat_history[-chat_round*2:]
|
112
|
+
chat_history = chat_history[-chat_round * 2:]
|
116
113
|
|
117
114
|
inputs = []
|
118
115
|
for i in range(0, len(chat_history), 2):
|
119
116
|
inputs.append(HumanMessage(content=chat_history[i]))
|
120
|
-
inputs.append(AIMessage(content=chat_history[i+1]))
|
117
|
+
inputs.append(AIMessage(content=chat_history[i + 1]))
|
121
118
|
inputs.append(HumanMessage(content=query))
|
122
119
|
if self.agent_executor_type == 'get_react_agent_executor':
|
123
|
-
result = asyncio.run(
|
120
|
+
result = asyncio.run(
|
121
|
+
self.assistant.ainvoke({
|
122
|
+
'input': inputs[-1].content,
|
123
|
+
'chat_history': inputs[:-1]
|
124
|
+
}))
|
124
125
|
else:
|
125
126
|
result = asyncio.run(self.assistant.ainvoke(inputs))
|
126
127
|
return result
|
127
128
|
|
128
129
|
|
129
|
-
if __name__ ==
|
130
|
-
from langchain.globals import set_debug
|
130
|
+
if __name__ == '__main__':
|
131
131
|
|
132
132
|
# set_debug(True)
|
133
133
|
# chat_history = []
|
134
134
|
# query = "分析当日市场行情"
|
135
135
|
chat_history = ['你好', '你好,有什么可以帮助你吗?', '福蓉科技股价多少?', '福蓉科技(股票代码:300049)的当前股价为48.67元。']
|
136
136
|
query = '今天是什么时候?去年这个时候的股价是多少?'
|
137
|
-
bisheng_assistant = BishengAssistant(
|
137
|
+
bisheng_assistant = BishengAssistant('config/base_scene.yaml')
|
138
138
|
# bisheng_assistant = BishengAssistant("config/knowledge_scene.yaml")
|
139
139
|
# bisheng_assistant = BishengAssistant("config/rag_scene.yaml")
|
140
140
|
result = bisheng_assistant.run(query, chat_history=chat_history)
|
@@ -29,7 +29,7 @@ class APIToolBase(BaseModel):
|
|
29
29
|
headers: Dict[str, Any] = {}
|
30
30
|
request_timeout: int = 30
|
31
31
|
url: str = None
|
32
|
-
params: Dict[str, Any] =
|
32
|
+
params: Dict[str, Any] = {}
|
33
33
|
input_key: str = 'keyword'
|
34
34
|
args_schema: Type[BaseModel] = ApiArg
|
35
35
|
|
@@ -4,27 +4,27 @@ from langchain_core.tools import BaseTool
|
|
4
4
|
from loguru import logger
|
5
5
|
from pydantic import BaseModel, create_model
|
6
6
|
|
7
|
-
from .base import APIToolBase,
|
7
|
+
from .base import APIToolBase, Field, MultArgsSchemaTool
|
8
8
|
|
9
9
|
|
10
10
|
class OpenApiTools(APIToolBase):
|
11
11
|
|
12
12
|
def get_real_path(self):
|
13
|
-
return self.url + self.params[
|
13
|
+
return self.url + self.params['path']
|
14
14
|
|
15
15
|
def get_request_method(self):
|
16
|
-
return self.params[
|
16
|
+
return self.params['method'].lower()
|
17
17
|
|
18
18
|
def get_params_json(self, **kwargs):
|
19
19
|
params_define = {}
|
20
|
-
for one in self.params[
|
21
|
-
params_define[one[
|
20
|
+
for one in self.params['parameters']:
|
21
|
+
params_define[one['name']] = one
|
22
22
|
|
23
23
|
params = {}
|
24
24
|
json_data = {}
|
25
25
|
for k, v in kwargs.items():
|
26
26
|
if params_define.get(k):
|
27
|
-
if params_define[k][
|
27
|
+
if params_define[k]['in'] == 'query':
|
28
28
|
params[k] = v
|
29
29
|
else:
|
30
30
|
json_data[k] = v
|
@@ -33,66 +33,93 @@ class OpenApiTools(APIToolBase):
|
|
33
33
|
return params, json_data
|
34
34
|
|
35
35
|
def parse_args_schema(self):
|
36
|
-
params = self.params[
|
36
|
+
params = self.params['parameters']
|
37
37
|
model_params = {}
|
38
38
|
for one in params:
|
39
|
-
field_type = one[
|
40
|
-
if field_type ==
|
41
|
-
field_type =
|
42
|
-
elif field_type ==
|
43
|
-
field_type =
|
44
|
-
elif field_type ==
|
45
|
-
field_type =
|
39
|
+
field_type = one['schema']['type']
|
40
|
+
if field_type == 'number':
|
41
|
+
field_type = float
|
42
|
+
elif field_type == 'integer':
|
43
|
+
field_type = int
|
44
|
+
elif field_type == 'string':
|
45
|
+
field_type = str
|
46
|
+
elif field_type in {'object', 'dict'}:
|
47
|
+
param_object_param = {}
|
48
|
+
for param in one['schema']['properties'].keys():
|
49
|
+
field_type = one['schema']['properties'][param]['type']
|
50
|
+
if field_type == 'number':
|
51
|
+
field_type = float
|
52
|
+
elif field_type == 'integer':
|
53
|
+
field_type = int
|
54
|
+
elif field_type == 'string':
|
55
|
+
field_type = str
|
56
|
+
param_object_param[param] = (
|
57
|
+
field_type,
|
58
|
+
Field(description=one['schema']['properties'][param]['description']))
|
59
|
+
param_model = create_model(
|
60
|
+
param,
|
61
|
+
__module__='bisheng_langchain.gpts.tools.api_tools.openapi',
|
62
|
+
**param_object_param)
|
63
|
+
field_type = param_model
|
46
64
|
else:
|
47
|
-
raise Exception(f
|
48
|
-
model_params[one[
|
49
|
-
return create_model(
|
50
|
-
|
65
|
+
raise Exception(f'schema type is not support: {field_type}')
|
66
|
+
model_params[one['name']] = (field_type, Field(description=one['description']))
|
67
|
+
return create_model('InputArgs',
|
68
|
+
__module__='bisheng_langchain.gpts.tools.api_tools.openapi',
|
69
|
+
__base__=BaseModel,
|
70
|
+
**model_params)
|
51
71
|
|
52
72
|
def run(self, **kwargs) -> str:
|
53
73
|
"""Run query through api and parse result."""
|
74
|
+
extra = {}
|
75
|
+
if 'proxy' in kwargs:
|
76
|
+
extra['proxy'] = kwargs.pop('proxy')
|
54
77
|
path = self.get_real_path()
|
55
78
|
logger.info('api_call url={}', path)
|
56
79
|
method = self.get_request_method()
|
57
80
|
params, json_data = self.get_params_json(**kwargs)
|
58
81
|
|
59
|
-
if method ==
|
60
|
-
resp = self.client.get(path, params=params)
|
82
|
+
if method == 'get':
|
83
|
+
resp = self.client.get(path, params=params, **extra)
|
61
84
|
elif method == 'post':
|
62
|
-
resp = self.client.post(path, params=params, json=json_data)
|
85
|
+
resp = self.client.post(path, params=params, json=json_data, **extra)
|
63
86
|
elif method == 'put':
|
64
|
-
resp = self.client.put(path, params=params, json=json_data)
|
87
|
+
resp = self.client.put(path, params=params, json=json_data, **extra)
|
65
88
|
elif method == 'delete':
|
66
|
-
resp = self.client.delete(path, params=params, json=json_data)
|
89
|
+
resp = self.client.delete(path, params=params, json=json_data, **extra)
|
67
90
|
else:
|
68
|
-
raise Exception(f
|
91
|
+
raise Exception(f'http method is not support: {method}')
|
69
92
|
if resp.status_code != 200:
|
70
93
|
logger.info(f'api_call_fail code={resp.status_code} res={resp.text}')
|
71
|
-
raise Exception(f
|
94
|
+
raise Exception(f'api_call_fail: {resp.status_code} {resp.text}')
|
72
95
|
return resp.text
|
73
96
|
|
74
97
|
async def arun(self, **kwargs) -> str:
|
75
98
|
"""Run query through api and parse result."""
|
99
|
+
extra = {}
|
100
|
+
if 'proxy' in kwargs:
|
101
|
+
extra['proxy'] = kwargs.pop('proxy')
|
102
|
+
|
76
103
|
path = self.get_real_path()
|
77
104
|
logger.info('api_call url={}', path)
|
78
105
|
method = self.get_request_method()
|
79
106
|
params, json_data = self.get_params_json(**kwargs)
|
80
107
|
|
81
|
-
if method ==
|
82
|
-
resp = await self.async_client.aget(path, params=params)
|
108
|
+
if method == 'get':
|
109
|
+
resp = await self.async_client.aget(path, params=params, **extra)
|
83
110
|
elif method == 'post':
|
84
|
-
resp = await self.async_client.apost(path, params=params, json=json_data)
|
111
|
+
resp = await self.async_client.apost(path, params=params, json=json_data, **extra)
|
85
112
|
elif method == 'put':
|
86
|
-
resp = await self.async_client.aput(path, params=params, json=json_data)
|
113
|
+
resp = await self.async_client.aput(path, params=params, json=json_data, **extra)
|
87
114
|
elif method == 'delete':
|
88
|
-
resp = await self.async_client.adelete(path, params=params, json=json_data)
|
115
|
+
resp = await self.async_client.adelete(path, params=params, json=json_data, **extra)
|
89
116
|
else:
|
90
|
-
raise Exception(f
|
117
|
+
raise Exception(f'http method is not support: {method}')
|
91
118
|
return resp
|
92
119
|
|
93
120
|
@classmethod
|
94
121
|
def get_api_tool(cls, name, **kwargs: Any) -> BaseTool:
|
95
|
-
description = kwargs.pop(
|
122
|
+
description = kwargs.pop('description', '')
|
96
123
|
obj = cls(**kwargs)
|
97
124
|
return MultArgsSchemaTool(name=name,
|
98
125
|
description=description,
|