bisheng-langchain 0.4.0.dev1__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bisheng_langchain/gpts/agent_types/llm_functions_agent.py +13 -3
- bisheng_langchain/gpts/load_tools.py +16 -10
- bisheng_langchain/gpts/tools/api_tools/openapi.py +21 -7
- bisheng_langchain/gpts/tools/sql_agent/__init__.py +0 -0
- bisheng_langchain/gpts/tools/sql_agent/tool.py +262 -0
- bisheng_langchain/rag/bisheng_rag_tool.py +7 -6
- {bisheng_langchain-0.4.0.dev1.dist-info → bisheng_langchain-0.4.1.dist-info}/METADATA +2 -2
- {bisheng_langchain-0.4.0.dev1.dist-info → bisheng_langchain-0.4.1.dist-info}/RECORD +10 -8
- {bisheng_langchain-0.4.0.dev1.dist-info → bisheng_langchain-0.4.1.dist-info}/WHEEL +0 -0
- {bisheng_langchain-0.4.0.dev1.dist-info → bisheng_langchain-0.4.1.dist-info}/top_level.txt +0 -0
@@ -62,7 +62,10 @@ def get_openai_functions_agent_executor(tools: list[BaseTool], llm: LanguageMode
|
|
62
62
|
for tool_call in last_message.additional_kwargs['tool_calls']:
|
63
63
|
function = tool_call['function']
|
64
64
|
function_name = function['name']
|
65
|
-
|
65
|
+
try:
|
66
|
+
_tool_input = json.loads(function['arguments'] or '{}')
|
67
|
+
except Exception as e:
|
68
|
+
raise Exception(f"Error parsing arguments for function: {function_name}. arguments: {function['arguments']}. error: {str(e)}")
|
66
69
|
# We construct an ToolInvocation from the function_call
|
67
70
|
actions.append(ToolInvocation(
|
68
71
|
tool=function_name,
|
@@ -89,7 +92,10 @@ def get_openai_functions_agent_executor(tools: list[BaseTool], llm: LanguageMode
|
|
89
92
|
for tool_call in last_message.additional_kwargs['tool_calls']:
|
90
93
|
function = tool_call['function']
|
91
94
|
function_name = function['name']
|
92
|
-
|
95
|
+
try:
|
96
|
+
_tool_input = json.loads(function['arguments'] or '{}')
|
97
|
+
except Exception as e:
|
98
|
+
raise Exception(f"Error parsing arguments for function: {function_name}. arguments: {function['arguments']}. error: {str(e)}")
|
93
99
|
# We construct an ToolInvocation from the function_call
|
94
100
|
actions.append(ToolInvocation(
|
95
101
|
tool=function_name,
|
@@ -200,7 +206,11 @@ def get_qwen_local_functions_agent_executor(
|
|
200
206
|
# only one function
|
201
207
|
function = last_message.additional_kwargs['function_call']
|
202
208
|
function_name = function['name']
|
203
|
-
|
209
|
+
try:
|
210
|
+
_tool_input = json.loads(function['arguments'] or '{}')
|
211
|
+
except Exception as e:
|
212
|
+
raise Exception(
|
213
|
+
f"Error parsing arguments for function: {function_name}. arguments: {function['arguments']}. error: {str(e)}")
|
204
214
|
# We construct an ToolInvocation from the function_call
|
205
215
|
actions.append(ToolInvocation(
|
206
216
|
tool=function_name,
|
@@ -6,26 +6,27 @@ from typing import Any, Callable, Dict, List, Optional, Tuple
|
|
6
6
|
import httpx
|
7
7
|
import pandas as pd
|
8
8
|
import pymysql
|
9
|
+
from dotenv import load_dotenv
|
10
|
+
from langchain_community.tools.arxiv.tool import ArxivQueryRun
|
11
|
+
from langchain_community.tools.bearly.tool import BearlyInterpreterTool
|
12
|
+
from langchain_community.utilities.arxiv import ArxivAPIWrapper
|
13
|
+
from langchain_community.utilities.bing_search import BingSearchAPIWrapper
|
14
|
+
from langchain_core.callbacks import BaseCallbackManager, Callbacks
|
15
|
+
from langchain_core.language_models import BaseLanguageModel
|
16
|
+
from langchain_core.tools import BaseTool, Tool
|
17
|
+
from mypy_extensions import Arg, KwArg
|
18
|
+
|
9
19
|
from bisheng_langchain.gpts.tools.api_tools import ALL_API_TOOLS
|
10
20
|
from bisheng_langchain.gpts.tools.bing_search.tool import BingSearchRun
|
11
21
|
from bisheng_langchain.gpts.tools.calculator.tool import calculator
|
12
22
|
from bisheng_langchain.gpts.tools.code_interpreter.tool import CodeInterpreterTool
|
13
|
-
|
14
23
|
# from langchain_community.utilities.dalle_image_generator import DallEAPIWrapper
|
15
24
|
from bisheng_langchain.gpts.tools.dalle_image_generator.tool import (
|
16
25
|
DallEAPIWrapper,
|
17
26
|
DallEImageGenerator,
|
18
27
|
)
|
19
28
|
from bisheng_langchain.gpts.tools.get_current_time.tool import get_current_time
|
20
|
-
from
|
21
|
-
from langchain_community.tools.arxiv.tool import ArxivQueryRun
|
22
|
-
from langchain_community.tools.bearly.tool import BearlyInterpreterTool
|
23
|
-
from langchain_community.utilities.arxiv import ArxivAPIWrapper
|
24
|
-
from langchain_community.utilities.bing_search import BingSearchAPIWrapper
|
25
|
-
from langchain_core.callbacks import BaseCallbackManager, Callbacks
|
26
|
-
from langchain_core.language_models import BaseLanguageModel
|
27
|
-
from langchain_core.tools import BaseTool, Tool
|
28
|
-
from mypy_extensions import Arg, KwArg
|
29
|
+
from bisheng_langchain.gpts.tools.sql_agent.tool import SqlAgentTool, SqlAgentAPIWrapper
|
29
30
|
from bisheng_langchain.rag import BishengRAGTool
|
30
31
|
from bisheng_langchain.utils.azure_dalle_image_generator import AzureDallEWrapper
|
31
32
|
|
@@ -80,6 +81,10 @@ def _get_dalle_image_generator(**kwargs: Any) -> Tool:
|
|
80
81
|
)
|
81
82
|
|
82
83
|
|
84
|
+
def _get_sql_agent(**kwargs: Any) -> BaseTool:
    """Create the SQL agent tool.

    All keyword arguments are forwarded unchanged to ``SqlAgentAPIWrapper``;
    the resulting wrapper is handed to ``SqlAgentTool`` as its ``api_wrapper``.
    """
    wrapper = SqlAgentAPIWrapper(**kwargs)
    return SqlAgentTool(api_wrapper=wrapper)
|
86
|
+
|
87
|
+
|
83
88
|
def _get_bearly_code_interpreter(**kwargs: Any) -> Tool:
|
84
89
|
return BearlyInterpreterTool(**kwargs).as_tool()
|
85
90
|
|
@@ -99,6 +104,7 @@ _EXTRA_PARAM_TOOLS: Dict[str, Tuple[Callable[[KwArg(Any)], BaseTool], List[Optio
|
|
99
104
|
'bisheng_rag': (BishengRAGTool.get_rag_tool, ['name', 'description'],
|
100
105
|
['vector_store', 'keyword_store', 'llm', 'collection_name', 'max_content',
|
101
106
|
'sort_by_source_and_index']),
|
107
|
+
'sql_agent': (_get_sql_agent, ['llm', 'sql_address'], []),
|
102
108
|
}
|
103
109
|
|
104
110
|
_API_TOOLS: Dict[str, Tuple[Callable[[KwArg(Any)], BaseTool], List[str]]] = {**ALL_API_TOOLS} # type: ignore
|
@@ -9,8 +9,11 @@ from .base import APIToolBase, Field, MultArgsSchemaTool
|
|
9
9
|
|
10
10
|
class OpenApiTools(APIToolBase):
|
11
11
|
|
12
|
-
def get_real_path(self):
|
13
|
-
|
12
|
+
def get_real_path(self, path_params: dict|None):
|
13
|
+
path = self.params['path']
|
14
|
+
if path_params:
|
15
|
+
path = path.format(**path_params)
|
16
|
+
return self.url + path
|
14
17
|
|
15
18
|
def get_request_method(self):
    """Return the HTTP method stored in ``self.params['method']``, lower-cased."""
    http_method = self.params['method']
    return http_method.lower()
|
@@ -20,17 +23,20 @@ class OpenApiTools(APIToolBase):
|
|
20
23
|
for one in self.params['parameters']:
|
21
24
|
params_define[one['name']] = one
|
22
25
|
|
26
|
+
path_params = {}
|
23
27
|
params = {}
|
24
28
|
json_data = {}
|
25
29
|
for k, v in kwargs.items():
|
26
30
|
if params_define.get(k):
|
27
31
|
if params_define[k]['in'] == 'query':
|
28
32
|
params[k] = v
|
33
|
+
elif params_define[k]['in'] == 'path':
|
34
|
+
path_params[k] = v
|
29
35
|
else:
|
30
36
|
json_data[k] = v
|
31
37
|
else:
|
32
38
|
params[k] = v
|
33
|
-
return params, json_data
|
39
|
+
return params, json_data, path_params
|
34
40
|
|
35
41
|
def parse_args_schema(self):
|
36
42
|
params = self.params['parameters']
|
@@ -43,6 +49,10 @@ class OpenApiTools(APIToolBase):
|
|
43
49
|
field_type = int
|
44
50
|
elif field_type == 'string':
|
45
51
|
field_type = str
|
52
|
+
elif field_type == 'boolean':
|
53
|
+
field_type = bool
|
54
|
+
elif field_type == 'array':
|
55
|
+
field_type = list
|
46
56
|
elif field_type in {'object', 'dict'}:
|
47
57
|
param_object_param = {}
|
48
58
|
for param in one['schema']['properties'].keys():
|
@@ -53,6 +63,10 @@ class OpenApiTools(APIToolBase):
|
|
53
63
|
field_type = int
|
54
64
|
elif field_type == 'string':
|
55
65
|
field_type = str
|
66
|
+
elif field_type == 'boolean':
|
67
|
+
field_type = bool
|
68
|
+
elif field_type == 'array':
|
69
|
+
field_type = list
|
56
70
|
param_object_param[param] = (
|
57
71
|
field_type,
|
58
72
|
Field(description=one['schema']['properties'][param]['description']))
|
@@ -74,10 +88,10 @@ class OpenApiTools(APIToolBase):
|
|
74
88
|
extra = {}
|
75
89
|
if 'proxy' in kwargs:
|
76
90
|
extra['proxy'] = kwargs.pop('proxy')
|
77
|
-
|
91
|
+
params, json_data, path_params = self.get_params_json(**kwargs)
|
92
|
+
path = self.get_real_path(path_params)
|
78
93
|
logger.info('api_call url={}', path)
|
79
94
|
method = self.get_request_method()
|
80
|
-
params, json_data = self.get_params_json(**kwargs)
|
81
95
|
|
82
96
|
if method == 'get':
|
83
97
|
resp = self.client.get(path, params=params, **extra)
|
@@ -100,10 +114,10 @@ class OpenApiTools(APIToolBase):
|
|
100
114
|
if 'proxy' in kwargs:
|
101
115
|
extra['proxy'] = kwargs.pop('proxy')
|
102
116
|
|
103
|
-
|
117
|
+
params, json_data, path_params = self.get_params_json(**kwargs)
|
118
|
+
path = self.get_real_path(path_params)
|
104
119
|
logger.info('api_call url={}', path)
|
105
120
|
method = self.get_request_method()
|
106
|
-
params, json_data = self.get_params_json(**kwargs)
|
107
121
|
|
108
122
|
if method == 'get':
|
109
123
|
resp = await self.async_client.aget(path, params=params, **extra)
|
File without changes
|
@@ -0,0 +1,262 @@
|
|
1
|
+
from typing import Type, Optional, TypedDict, Annotated, Any, Literal
|
2
|
+
|
3
|
+
from langchain_community.agent_toolkits import SQLDatabaseToolkit
|
4
|
+
from langchain_community.utilities import SQLDatabase
|
5
|
+
from langchain_core.callbacks import CallbackManagerForToolRun
|
6
|
+
from langchain_core.language_models import BaseLanguageModel
|
7
|
+
from langchain_core.messages import AnyMessage, AIMessage, ToolMessage
|
8
|
+
from langchain_core.prompts import ChatPromptTemplate
|
9
|
+
from langchain_core.runnables import RunnableLambda, RunnableWithFallbacks
|
10
|
+
from langchain_core.tools import BaseTool, tool
|
11
|
+
from langgraph.constants import END, START
|
12
|
+
from langgraph.graph import add_messages, StateGraph
|
13
|
+
from langgraph.prebuilt import ToolNode
|
14
|
+
from pydantic import BaseModel, Field
|
15
|
+
|
16
|
+
|
17
|
+
# State schema passed to StateGraph(...) by SqlAgentAPIWrapper; every graph
# node in this module reads state["messages"] and returns {"messages": [...]}.
class State(TypedDict):
    # Conversation history. The field is annotated with langgraph's
    # add_messages, which presumably merges each node's returned messages
    # into this list rather than replacing it -- confirm against langgraph docs.
    messages: Annotated[list[AnyMessage], add_messages]
|
19
|
+
|
20
|
+
|
21
|
+
def handle_tool_error(state) -> dict:
    """Turn a failed tool invocation into error ToolMessages.

    Reads the exception stored under ``state["error"]`` and the tool calls of
    the last message in ``state["messages"]``, and answers every one of those
    calls with a ToolMessage carrying the error's repr.

    Returns:
        ``{"messages": [...]}`` with one ToolMessage per pending tool call.
    """
    failure = state.get("error")
    pending_calls = state["messages"][-1].tool_calls

    replies = []
    for call in pending_calls:
        replies.append(
            ToolMessage(
                content=f"Error: {repr(failure)}\n please fix your mistakes.",
                tool_call_id=call["id"],
            )
        )
    return {"messages": replies}
|
33
|
+
|
34
|
+
|
35
|
+
def create_tool_node_with_fallback(tools: list) -> RunnableWithFallbacks[Any, dict]:
    """Wrap ``tools`` in a ToolNode whose failures are reported back as messages.

    The ToolNode gets ``handle_tool_error`` as its only fallback; the raised
    exception is passed to the fallback under the ``"error"`` key.
    """
    tool_node = ToolNode(tools)
    error_reporter = RunnableLambda(handle_tool_error)
    return tool_node.with_fallbacks([error_reporter], exception_key="error")
|
42
|
+
|
43
|
+
|
44
|
+
# Tool schema bound to the LLM in SqlAgentAPIWrapper.init_query_gen.
# SqlAgentAPIWrapper.run reads the "final_answer" argument of this tool call
# as the agent's result. The docstring and Field description below are left
# untouched because they may be sent to the model as part of the tool schema.
class SubmitFinalAnswer(BaseModel):
    """Submit the final answer to the user based on the query results."""

    # Required field (Field(...)): the answer text returned to the caller.
    final_answer: str = Field(..., description="The final answer to the user")
|
48
|
+
|
49
|
+
class QueryDBTool(BaseTool):
    """Tool that executes a single SQL statement through ``db.run_no_throw``."""

    name = "db_query_tool"
    # Continuation lines are flush-left so the description text carries no
    # extra leading whitespace.
    description = """Execute a SQL query against the database and get back the result.
If the query is not correct, an error message will be returned.
If an error is returned, rewrite the query, check the query, and try again."""

    db: SQLDatabase

    def _run(self, query: str, run_manager: Optional[CallbackManagerForToolRun] = None):
        """Run ``query``; a falsy result is reported as a failed query."""
        outcome = self.db.run_no_throw(query)
        if outcome:
            return outcome
        return "Error: Query failed. Please rewrite your query and try again."
|
62
|
+
|
63
|
+
class SqlAgentAPIWrapper(BaseModel):
    """LangGraph workflow that answers a natural-language question with SQL.

    Construction connects to ``sql_address``, collects the list-tables and
    schema tools from ``SQLDatabaseToolkit``, builds the query-generation and
    query-check chains, and compiles the graph. ``run`` executes the graph
    for one question and returns the submitted final answer.
    """

    llm: BaseLanguageModel = Field(description="llm to use for sql agent")
    sql_address: str = Field(description="sql database address for SQLDatabase uri")

    # Derived members, all filled in by __init__. They carry an explicit
    # ``None`` default so that callers only need to pass ``llm`` and
    # ``sql_address``: a bare ``Optional[X]`` annotation is still a *required*
    # field under Pydantic v2.
    db: Optional[SQLDatabase] = None
    list_tables_tool: Optional[BaseTool] = None
    get_schema_tool: Optional[BaseTool] = None
    db_query_tool: Optional[BaseTool] = None
    query_check: Optional[Any] = None
    query_gen: Optional[Any] = None
    workflow: Optional[StateGraph] = None
    app: Optional[Any] = None

    class Config:
        arbitrary_types_allowed = True

    def __init__(self, **kwargs):
        """Validate the inputs, then build the tools, chains and compiled graph.

        Keyword Args:
            llm: Language model used for every LLM step of the workflow.
            sql_address: Database URI handed to ``SQLDatabase.from_uri``.

        Raises:
            StopIteration: If the toolkit does not expose a tool named
                ``sql_db_list_tables`` or ``sql_db_schema``.
        """
        super().__init__(**kwargs)
        # Re-assigned from the raw kwargs after validation; presumably to keep
        # the caller's exact objects rather than validated copies -- TODO confirm.
        self.llm = kwargs.get('llm')
        self.sql_address = kwargs.get('sql_address')

        self.db = SQLDatabase.from_uri(self.sql_address)
        toolkit = SQLDatabaseToolkit(db=self.db, llm=self.llm)
        tools = toolkit.get_tools()
        self.list_tables_tool = next(t for t in tools if t.name == "sql_db_list_tables")
        self.get_schema_tool = next(t for t in tools if t.name == "sql_db_schema")
        self.db_query_tool = QueryDBTool(db=self.db)

        # Both chains bind tools on self.llm; query_check also needs
        # self.db_query_tool, which is why it is created just above.
        self.query_check = self.init_query_check()
        self.query_gen = self.init_query_gen()

        # Define a new graph
        self.workflow = StateGraph(State)
        self.init_workflow()
        self.app = self.workflow.compile(checkpointer=False)

    def init_workflow(self):
        """Register the nodes and edges of the SQL agent graph on ``self.workflow``.

        Flow: START -> first_tool_call -> list_tables_tool -> model_get_schema
        -> get_schema_tool -> query_gen, then ``should_continue`` routes to
        END, back to query_gen, or to correct_query -> execute_query -> query_gen.
        """
        self.workflow.add_node("first_tool_call", self.first_tool_call)
        self.workflow.add_node(
            "list_tables_tool", create_tool_node_with_fallback([self.list_tables_tool])
        )

        self.workflow.add_node("get_schema_tool", create_tool_node_with_fallback([self.get_schema_tool]))

        # LLM restricted to the schema tool: it picks the tables to inspect.
        model_get_schema = self.llm.bind_tools(
            [self.get_schema_tool]
        )
        self.workflow.add_node(
            "model_get_schema",
            lambda state: {
                "messages": [model_get_schema.invoke(state["messages"])],
            },
        )

        self.workflow.add_node("query_gen", self.query_gen_node)
        self.workflow.add_node("correct_query", self.model_check_query)

        self.workflow.add_node("execute_query", create_tool_node_with_fallback([self.db_query_tool]))

        self.workflow.add_edge(START, "first_tool_call")
        self.workflow.add_edge("first_tool_call", "list_tables_tool")
        self.workflow.add_edge("list_tables_tool", "model_get_schema")
        self.workflow.add_edge("model_get_schema", "get_schema_tool")
        self.workflow.add_edge("get_schema_tool", "query_gen")
        self.workflow.add_conditional_edges(
            "query_gen",
            self.should_continue,
        )
        self.workflow.add_edge("correct_query", "execute_query")
        self.workflow.add_edge("execute_query", "query_gen")

    @staticmethod
    def should_continue(state: State) -> Literal[END, "correct_query", "query_gen"]:
        """Choose the node that follows ``query_gen``.

        Returns:
            ``END`` when the last message carries tool calls, ``"query_gen"``
            when its content starts with ``"Error:"``, otherwise
            ``"correct_query"``.
        """
        messages = state["messages"]
        last_message = messages[-1]
        # If there is a tool call, then we finish
        if getattr(last_message, "tool_calls", None):
            return END
        if last_message.content.startswith("Error:"):
            return "query_gen"
        else:
            return "correct_query"

    def init_query_check(self):
        """Build the chain that reviews a query and must call ``db_query_tool``.

        Returns:
            The check prompt piped into ``self.llm`` bound to
            ``self.db_query_tool`` with ``tool_choice="required"``.
        """
        # Continuation lines are flush-left so the prompt text carries no
        # extra leading whitespace.
        query_check_system = """You are a SQL expert with a strong attention to detail.
Double check the SQLite query for common mistakes, including:
- Using NOT IN with NULL values
- Using UNION when UNION ALL should have been used
- Using BETWEEN for exclusive ranges
- Data type mismatch in predicates
- Properly quoting identifiers
- Using the correct number of arguments for functions
- Casting to the correct data type
- Using the proper columns for joins

If there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query.

You will call the appropriate tool to execute the query after running this check."""

        query_check_prompt = ChatPromptTemplate.from_messages(
            [("system", query_check_system), ("placeholder", "{messages}")]
        )
        query_check = query_check_prompt | self.llm.bind_tools(
            [self.db_query_tool], tool_choice="required"
        )
        return query_check

    def first_tool_call(self, state: State) -> dict[str, list[AIMessage]]:
        """Emit a fixed AI message that calls ``sql_db_list_tables``.

        ``state`` is not read; the hard-coded call makes the graph always
        start by listing the tables.
        """
        return {
            "messages": [
                AIMessage(
                    content="",
                    tool_calls=[
                        {
                            "name": "sql_db_list_tables",
                            "args": {},
                            "id": "tool_abcd123",
                        }
                    ],
                )
            ]
        }

    def model_check_query(self, state: State) -> dict[str, list[AIMessage]]:
        """
        Use this tool to double-check if your query is correct before executing it.
        """
        # Only the most recent message is handed to the check chain.
        return {"messages": [self.query_check.invoke({"messages": [state["messages"][-1]]})]}

    def init_query_gen(self):
        """Build the chain that writes SQL or submits the final answer.

        Returns:
            The generation prompt piped into ``self.llm`` bound to the
            ``SubmitFinalAnswer`` schema.
        """
        # Add a node for a model to generate a query based on the question and schema
        query_gen_system = """You are a SQL expert with a strong attention to detail.Given an input question, output a syntactically correct SQL query to run, then look at the results of the query and return the answer.DO NOT call any tool besides SubmitFinalAnswer to submit the final answer.When generating the query:Output the SQL query that answers the input question without a tool call.Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 10 results.You can order the results by a relevant column to return the most interesting examples in the database.Never query for all the columns from a specific table, only ask for the relevant columns given the question.If you get an error while executing a query, rewrite the query and try again.If you get an empty result set, you should try to rewrite the query to get a non-empty result set. NEVER make stuff up if you don't have enough information to answer the query... just say you don't have enough information.If you have enough information to answer the input question, simply invoke the appropriate tool to submit the final answer to the user.DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database."""
        query_gen_prompt = ChatPromptTemplate.from_messages(
            [("system", query_gen_system), ("placeholder", "{messages}")]
        )
        query_gen = query_gen_prompt | self.llm.bind_tools(
            [SubmitFinalAnswer]
        )
        return query_gen

    def query_gen_node(self, state: State) -> Any:
        """Run the generation chain and flag calls to any unexpected tool.

        Returns:
            ``{"messages": [...]}`` holding the model message followed by one
            error ToolMessage for each tool call that is not
            ``SubmitFinalAnswer``.
        """
        message = self.query_gen.invoke(state)

        # Sometimes, the LLM will hallucinate and call the wrong tool. We need to catch this and return an error message.
        tool_messages = [
            ToolMessage(
                content=f"Error: The wrong tool was called: {tc['name']}. Please fix your mistakes. Remember to only call SubmitFinalAnswer to submit the final answer. Generated queries should be outputted WITHOUT a tool call.",
                tool_call_id=tc["id"],
            )
            for tc in (message.tool_calls or [])
            if tc["name"] != "SubmitFinalAnswer"
        ]
        return {"messages": [message] + tool_messages}

    def run(self, query: str) -> str:
        """Answer ``query`` by invoking the compiled graph.

        Args:
            query: The user's question, sent as the initial ``user`` message.

        Returns:
            The ``final_answer`` argument of the first tool call on the last
            message of the run.
        """
        messages = self.app.invoke({"messages": [("user", query)]}, config={
            'recursion_limit': 50
        })
        return messages["messages"][-1].tool_calls[0]["args"]["final_answer"]

    def arun(self, query: str) -> str:
        """Synchronous alias of ``run``; despite the name it is not a coroutine."""
        return self.run(query)
|
229
|
+
|
230
|
+
|
231
|
+
# Argument schema for SqlAgentTool (used as its args_schema).
class SqlAgentInput(BaseModel):
    # The user's data-query request. The Chinese description reads: "the
    # user's data query requirement (should be as complete and accurate as
    # possible)". It is runtime schema text and is left as-is.
    query: str = Field(description="用户数据查询需求(需要尽可能完整、准确)")
|
233
|
+
|
234
|
+
|
235
|
+
class SqlAgentTool(BaseTool):
    """Tool facade that delegates a question to ``SqlAgentAPIWrapper.run``."""

    name = "sql_agent"
    # Runtime tool description (Chinese). In English: "Answers questions about
    # a SQL database. Given a user question, it fetches the available tables
    # and their DDL, generates and executes a SQL query, and returns the result."
    description = "回答与 SQL 数据库有关的问题。给定用户问题,将从数据库中获取可用的表以及对应 DDL,生成 SQL 查询语句并进行执行,最终得到执行结果。"
    args_schema: Type[BaseModel] = SqlAgentInput
    api_wrapper: SqlAgentAPIWrapper

    def _run(self, query: str, run_manager: Optional[CallbackManagerForToolRun] = None) -> str:
        """Forward ``query`` to the wrapper and return its answer.

        ``run_manager`` is accepted for the tool interface but is not used.
        """
        answer = self.api_wrapper.run(query)
        return answer
|
248
|
+
|
249
|
+
|
250
|
+
if __name__ == '__main__':
    # Manual smoke test: asks one question against a local Chinook SQLite
    # file using an AzureChatOpenAI model built with default arguments.
    from langchain_openai import AzureChatOpenAI

    wrapper = SqlAgentAPIWrapper(
        llm=AzureChatOpenAI(),
        sql_address="sqlite:///Chinook.db",
    )
    sql_agent_tool = SqlAgentTool(api_wrapper=wrapper)

    print(sql_agent_tool.run("Which sales agent made the most in sales in 2009?"))
|
@@ -142,6 +142,7 @@ class BishengRAGTool:
|
|
142
142
|
prompt = import_class(f'bisheng_langchain.rag.prompts.{prompt_type}')
|
143
143
|
else:
|
144
144
|
prompt = None
|
145
|
+
self.prompt_inputs = prompt.input_variables
|
145
146
|
self.qa_chain = create_stuff_documents_chain(llm=self.llm, prompt=prompt)
|
146
147
|
|
147
148
|
def _post_init_retriever(self, retriever_type, **kwargs):
|
@@ -238,12 +239,12 @@ class BishengRAGTool:
|
|
238
239
|
kwargs = {}
|
239
240
|
if run_manager:
|
240
241
|
kwargs['config'] = RunnableConfig(callbacks=[run_manager])
|
241
|
-
|
242
|
-
|
243
|
-
|
244
|
-
|
245
|
-
|
246
|
-
)
|
242
|
+
tmp_input = {
|
243
|
+
'context': docs,
|
244
|
+
}
|
245
|
+
if 'question' in self.prompt_inputs:
|
246
|
+
tmp_input['question'] = query
|
247
|
+
ans = self.qa_chain.invoke(tmp_input, **kwargs)
|
247
248
|
except Exception as e:
|
248
249
|
logger.exception(f'question: {query}\nerror: {e}')
|
249
250
|
ans = str(e)
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: bisheng-langchain
|
3
|
-
Version: 0.4.
|
3
|
+
Version: 0.4.1
|
4
4
|
Summary: bisheng langchain modules
|
5
5
|
Home-page: https://github.com/dataelement/bisheng
|
6
6
|
Author: DataElem
|
@@ -32,7 +32,7 @@ Requires-Dist: langgraph==0.2.*
|
|
32
32
|
Requires-Dist: openai==1.51.*
|
33
33
|
Requires-Dist: langchain-openai>=0.1.25
|
34
34
|
Requires-Dist: llama-index==0.9.48
|
35
|
-
Requires-Dist: bisheng-ragas==1
|
35
|
+
Requires-Dist: bisheng-ragas==1.*
|
36
36
|
|
37
37
|
## What is bisheng-langchain?
|
38
38
|
|
@@ -73,11 +73,11 @@ bisheng_langchain/gpts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3
|
|
73
73
|
bisheng_langchain/gpts/assistant.py,sha256=Xx5sNVJzB9chUHf9jYigNugWw-gyEIb6GpLOnrMWDXU,5552
|
74
74
|
bisheng_langchain/gpts/auto_optimization.py,sha256=WNsC19rgvuDYQlSIaYThq5RqCbuobDbzCwAJW4Ksw0c,3626
|
75
75
|
bisheng_langchain/gpts/auto_tool_selected.py,sha256=21WETf9o0YS-QEBwv3mmZRObKWszefQkXEqAA6KzoaM,1582
|
76
|
-
bisheng_langchain/gpts/load_tools.py,sha256=
|
76
|
+
bisheng_langchain/gpts/load_tools.py,sha256=CO6sdpkfUPa66Z-qZVWX34jrOlu8JBRC6j3jyYtnkjc,8755
|
77
77
|
bisheng_langchain/gpts/message_types.py,sha256=7EJOx62j9E1U67jxWgxE_I7a8IjAvvKANknXkD2gFm0,213
|
78
78
|
bisheng_langchain/gpts/utils.py,sha256=t3YDxaJ0OYd6EKsek7PJFRYnsezwzEFK5oVU-PRbu5g,6671
|
79
79
|
bisheng_langchain/gpts/agent_types/__init__.py,sha256=88tFt1GfrfIqa4hCg0cMJk7rTeUmCSSdiVhR41CW4rM,381
|
80
|
-
bisheng_langchain/gpts/agent_types/llm_functions_agent.py,sha256=
|
80
|
+
bisheng_langchain/gpts/agent_types/llm_functions_agent.py,sha256=LRdB6QPLz6Ztqtk0pzx_nvOczlqAisAkynMoTCp1mdE,10597
|
81
81
|
bisheng_langchain/gpts/agent_types/llm_react_agent.py,sha256=lo8Neo346aZP8tve56yiDfy6xQbtF3o_lJLIAPPgBM0,6623
|
82
82
|
bisheng_langchain/gpts/prompts/__init__.py,sha256=pOnXvk6_PjqAoLrh68sI9o3o6znKGxoLMVFP-0XTCJo,704
|
83
83
|
bisheng_langchain/gpts/prompts/assistant_prompt_base.py,sha256=Yp9M1XbZb5jHeBG_txcwWA84Euvl89t0g-GbJMa5Ur0,1133
|
@@ -92,7 +92,7 @@ bisheng_langchain/gpts/tools/api_tools/__init__.py,sha256=CkEjgIFM4GIv86V1B7SsFL
|
|
92
92
|
bisheng_langchain/gpts/tools/api_tools/base.py,sha256=zPUCM_mOM9ygsc8pwejZvngEfEvGtiWTKbavfza7Eqg,3593
|
93
93
|
bisheng_langchain/gpts/tools/api_tools/flow.py,sha256=ot2YAYgQGWgUpb2nCECAmpqHY6m0SgzwkupF9kDT3lU,2461
|
94
94
|
bisheng_langchain/gpts/tools/api_tools/macro_data.py,sha256=FyG-qtl2ECS1CDKt6olN0eDTDM91d-UvDkMDBiVLgYQ,27429
|
95
|
-
bisheng_langchain/gpts/tools/api_tools/openapi.py,sha256=
|
95
|
+
bisheng_langchain/gpts/tools/api_tools/openapi.py,sha256=7i2Dw05u7FMKWeCKBvkNuPt9z575GaYyy-VyodUc2JM,5760
|
96
96
|
bisheng_langchain/gpts/tools/api_tools/sina.py,sha256=4KpK7_HUUtjpdJ-K4LjPlb-occyAZcRtmmCWqJ2BotE,9708
|
97
97
|
bisheng_langchain/gpts/tools/api_tools/tianyancha.py,sha256=abDAz-yAH1-2rKiSmZ6TgnrNUnpgAZpDY8oDiWfWapc,6684
|
98
98
|
bisheng_langchain/gpts/tools/bing_search/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
@@ -105,6 +105,8 @@ bisheng_langchain/gpts/tools/dalle_image_generator/__init__.py,sha256=47DEQpj8HB
|
|
105
105
|
bisheng_langchain/gpts/tools/dalle_image_generator/tool.py,sha256=h_mSGn2fvw4wGufrqKYC3lI1LLo9Uu_rynDM88IonMA,7631
|
106
106
|
bisheng_langchain/gpts/tools/get_current_time/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
107
107
|
bisheng_langchain/gpts/tools/get_current_time/tool.py,sha256=3uvk7Yu07qhZy1sBrFMhGEwyxEGMB8vubizs9x-6DG8,801
|
108
|
+
bisheng_langchain/gpts/tools/sql_agent/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
109
|
+
bisheng_langchain/gpts/tools/sql_agent/tool.py,sha256=YhKfYOzy4XOMcQtW-AUe85sgWRT5k0--zW4VK6oO-Dg,11239
|
108
110
|
bisheng_langchain/input_output/__init__.py,sha256=sW_GB7MlrHYsqY1Meb_LeimQqNsMz1gH-00Tqb2BUyM,153
|
109
111
|
bisheng_langchain/input_output/input.py,sha256=I5YDmgbvvj1o2lO9wi8LE37wM0wP5jkhUREU32YrZMQ,1094
|
110
112
|
bisheng_langchain/input_output/output.py,sha256=6U-az6-Cwz665C2YmcH3SYctWVjPFjmW8s70CA_qphk,11585
|
@@ -114,7 +116,7 @@ bisheng_langchain/rag/__init__.py,sha256=Rm_cDxOJINt0H4bOeUo3JctPxaI6xKKXZcS-R_w
|
|
114
116
|
bisheng_langchain/rag/bisheng_rag_chain.py,sha256=yCgbRJ9hHOAF4lsAL2kjX-YX9J7nduIV7lsoYnuTXL4,6251
|
115
117
|
bisheng_langchain/rag/bisheng_rag_pipeline.py,sha256=neoBK3TtuQ07_WeuJCzYlvtsDQNepUa_68NT8VCgytw,13749
|
116
118
|
bisheng_langchain/rag/bisheng_rag_pipeline_v2.py,sha256=iOoF7mbLp9qDGPsV0fEmgph_Ba8VnECYvCPebXk8xmo,16144
|
117
|
-
bisheng_langchain/rag/bisheng_rag_tool.py,sha256=
|
119
|
+
bisheng_langchain/rag/bisheng_rag_tool.py,sha256=fhvHJEJuFQ2ph50-89nTny4chy0pFxvPe-a-fdfnStI,13665
|
118
120
|
bisheng_langchain/rag/extract_info.py,sha256=jtZ4Bchjv4tOaayC2MnkV-lLu3vDA0Hsk_S-ATni34g,1695
|
119
121
|
bisheng_langchain/rag/run_qa_gen_web.py,sha256=-fIvHNnD3lD0iNU5m0Me1GDwRjlcsB8tE5RnPtFRG2s,1840
|
120
122
|
bisheng_langchain/rag/run_rag_evaluate_web.py,sha256=a9vMhq-ZhEiHHr43uKUzKtjdk280uAP_UHQW_eOaQMw,2224
|
@@ -155,7 +157,7 @@ bisheng_langchain/vectorstores/__init__.py,sha256=zCZgDe7LyQ0iDkfcm5UJ5NxwKQSRHn
|
|
155
157
|
bisheng_langchain/vectorstores/elastic_keywords_search.py,sha256=sIMbud4UfbfLTLf_pIIfcVC2lHbaTWCTTABTFP5mXkE,15334
|
156
158
|
bisheng_langchain/vectorstores/milvus.py,sha256=jWq_lce-ihOz07D1kwj5ctPzElYexNCjJ-xSv-pK1CI,37172
|
157
159
|
bisheng_langchain/vectorstores/retriever.py,sha256=hj4nAAl352EV_ANnU2OHJn7omCH3nBK82ydo14KqMH4,4353
|
158
|
-
bisheng_langchain-0.4.
|
159
|
-
bisheng_langchain-0.4.
|
160
|
-
bisheng_langchain-0.4.
|
161
|
-
bisheng_langchain-0.4.
|
160
|
+
bisheng_langchain-0.4.1.dist-info/METADATA,sha256=BFNONpW_KmqEspBNvTO9cBWIHTFkFiTdVE9--OAcCDI,2459
|
161
|
+
bisheng_langchain-0.4.1.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
|
162
|
+
bisheng_langchain-0.4.1.dist-info/top_level.txt,sha256=Z6pPNyCo4ihyr9iqGQbH8sJiC4dAUwA_mAyGRQB5_Fs,18
|
163
|
+
bisheng_langchain-0.4.1.dist-info/RECORD,,
|
File without changes
|
File without changes
|