bisheng-langchain 1.1.1__py3-none-any.whl → 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bisheng_langchain/agents/chatglm_functions_agent/base.py +6 -3
- bisheng_langchain/agents/llm_functions_agent/base.py +6 -3
- bisheng_langchain/chains/qa_generation/base.py +1 -1
- bisheng_langchain/chains/transform.py +1 -1
- bisheng_langchain/chat_models/__init__.py +2 -3
- bisheng_langchain/chat_models/host_llm.py +5 -7
- bisheng_langchain/chat_models/minimax.py +4 -7
- bisheng_langchain/chat_models/proxy_llm.py +5 -7
- bisheng_langchain/chat_models/qwen.py +5 -7
- bisheng_langchain/chat_models/sensetime.py +5 -7
- bisheng_langchain/chat_models/wenxin.py +4 -7
- bisheng_langchain/chat_models/xunfeiai.py +4 -7
- bisheng_langchain/chat_models/zhipuai.py +4 -7
- bisheng_langchain/embeddings/host_embedding.py +6 -4
- bisheng_langchain/embeddings/huggingfacegte.py +2 -2
- bisheng_langchain/embeddings/huggingfacemultilingual.py +2 -2
- bisheng_langchain/embeddings/wenxin.py +5 -8
- bisheng_langchain/gpts/agent_types/llm_functions_agent.py +6 -78
- bisheng_langchain/gpts/agent_types/llm_react_agent.py +2 -5
- bisheng_langchain/gpts/tools/api_tools/base.py +7 -9
- bisheng_langchain/gpts/tools/api_tools/firecrawl.py +1 -1
- bisheng_langchain/gpts/tools/api_tools/flow.py +1 -1
- bisheng_langchain/gpts/tools/api_tools/macro_data.py +3 -3
- bisheng_langchain/gpts/tools/api_tools/openapi.py +3 -3
- bisheng_langchain/gpts/tools/api_tools/sina.py +1 -1
- bisheng_langchain/gpts/tools/api_tools/tianyancha.py +6 -3
- bisheng_langchain/gpts/tools/bing_search/tool.py +2 -2
- bisheng_langchain/gpts/tools/calculator/tool.py +2 -2
- bisheng_langchain/gpts/tools/code_interpreter/tool.py +2 -2
- bisheng_langchain/gpts/tools/dalle_image_generator/tool.py +7 -11
- bisheng_langchain/gpts/tools/get_current_time/tool.py +1 -1
- bisheng_langchain/gpts/tools/message/dingding.py +1 -2
- bisheng_langchain/gpts/tools/message/email.py +6 -8
- bisheng_langchain/gpts/tools/message/feishu.py +10 -11
- bisheng_langchain/gpts/tools/message/wechat.py +2 -3
- bisheng_langchain/gpts/tools/sql_agent/tool.py +44 -203
- bisheng_langchain/input_output/input.py +7 -11
- bisheng_langchain/input_output/output.py +2 -6
- bisheng_langchain/memory/redis.py +3 -3
- bisheng_langchain/rag/bisheng_rag_chain.py +2 -8
- bisheng_langchain/rag/bisheng_rag_tool.py +2 -2
- bisheng_langchain/rag/init_retrievers/baseline_vector_retriever.py +1 -1
- bisheng_langchain/rag/init_retrievers/keyword_retriever.py +1 -1
- bisheng_langchain/rag/init_retrievers/mix_retriever.py +1 -1
- bisheng_langchain/rag/init_retrievers/smaller_chunks_retriever.py +2 -2
- bisheng_langchain/retrievers/ensemble.py +3 -2
- bisheng_langchain/utils/azure_dalle_image_generator.py +3 -2
- bisheng_langchain/utils/requests.py +3 -13
- bisheng_langchain/vectorstores/retriever.py +4 -7
- {bisheng_langchain-1.1.1.dist-info → bisheng_langchain-1.2.0.dist-info}/METADATA +5 -5
- {bisheng_langchain-1.1.1.dist-info → bisheng_langchain-1.2.0.dist-info}/RECORD +53 -53
- {bisheng_langchain-1.1.1.dist-info → bisheng_langchain-1.2.0.dist-info}/WHEEL +0 -0
- {bisheng_langchain-1.1.1.dist-info → bisheng_langchain-1.2.0.dist-info}/top_level.txt +0 -0
@@ -4,15 +4,15 @@ from typing import Any
|
|
4
4
|
|
5
5
|
import pandas as pd
|
6
6
|
import requests
|
7
|
-
from
|
7
|
+
from pydantic import BaseModel, Field
|
8
8
|
from langchain_core.tools import BaseTool
|
9
9
|
|
10
10
|
from .base import MultArgsSchemaTool
|
11
11
|
|
12
12
|
|
13
13
|
class QueryArg(BaseModel):
|
14
|
-
start_date: str = Field(default='', description='开始月份, 使用YYYY-MM-DD 方式表示',
|
15
|
-
end_date: str = Field(default='', description='结束月份,使用YYYY-MM-DD 方式表示',
|
14
|
+
start_date: str = Field(default='', description='开始月份, 使用YYYY-MM-DD 方式表示', examples=['2023-01-01'])
|
15
|
+
end_date: str = Field(default='', description='结束月份,使用YYYY-MM-DD 方式表示', examples=['2023-05-01'])
|
16
16
|
|
17
17
|
|
18
18
|
class MacroData(BaseModel):
|
@@ -9,9 +9,9 @@ from .base import APIToolBase, Field, MultArgsSchemaTool
|
|
9
9
|
|
10
10
|
class OpenApiTools(APIToolBase):
|
11
11
|
|
12
|
-
api_key: Optional[str]
|
13
|
-
api_location: Optional[str]
|
14
|
-
parameter_name: Optional[str]
|
12
|
+
api_key: Optional[str] = None
|
13
|
+
api_location: Optional[str] = None
|
14
|
+
parameter_name: Optional[str] = None
|
15
15
|
|
16
16
|
def get_real_path(self, path_params: dict | None):
|
17
17
|
path = self.params['path']
|
@@ -3,8 +3,9 @@ from __future__ import annotations
|
|
3
3
|
|
4
4
|
from typing import Any, Dict, Type
|
5
5
|
|
6
|
+
from pydantic import model_validator, BaseModel, Field
|
7
|
+
|
6
8
|
from bisheng_langchain.utils.requests import Requests, RequestsWrapper
|
7
|
-
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
|
8
9
|
|
9
10
|
from .base import APIToolBase
|
10
11
|
|
@@ -19,7 +20,8 @@ class CompanyInfo(APIToolBase):
|
|
19
20
|
api_key: str = None
|
20
21
|
args_schema: Type[BaseModel] = InputArgs
|
21
22
|
|
22
|
-
@
|
23
|
+
@model_validator(mode='before')
|
24
|
+
@classmethod
|
23
25
|
def build_header(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
24
26
|
"""Build headers that were passed in."""
|
25
27
|
if not values.get('api_key'):
|
@@ -30,7 +32,8 @@ class CompanyInfo(APIToolBase):
|
|
30
32
|
values['headers'] = headers
|
31
33
|
return values
|
32
34
|
|
33
|
-
@
|
35
|
+
@model_validator(mode='before')
|
36
|
+
@classmethod
|
34
37
|
def validate_environment(cls, values: Dict) -> Dict:
|
35
38
|
"""Validate that api key and python package exists in environment."""
|
36
39
|
timeout = values.get('request_timeout', 30)
|
@@ -2,7 +2,7 @@
|
|
2
2
|
|
3
3
|
from typing import Optional, Type
|
4
4
|
|
5
|
-
from
|
5
|
+
from pydantic import BaseModel, Field
|
6
6
|
from langchain_community.utilities.bing_search import BingSearchAPIWrapper
|
7
7
|
from langchain_core.callbacks import CallbackManagerForToolRun
|
8
8
|
from langchain_core.tools import BaseTool
|
@@ -43,7 +43,7 @@ class BingSearchResults(BaseTool):
|
|
43
43
|
"Input should be a search query. Output is a JSON array of the query results"
|
44
44
|
)
|
45
45
|
num_results: int = 5
|
46
|
-
args_schema = BingSearchInput
|
46
|
+
args_schema: Type[BaseModel] = BingSearchInput
|
47
47
|
api_wrapper: BingSearchAPIWrapper
|
48
48
|
|
49
49
|
def _run(
|
@@ -2,7 +2,7 @@ import math
|
|
2
2
|
from math import *
|
3
3
|
|
4
4
|
import sympy
|
5
|
-
from
|
5
|
+
from pydantic import BaseModel, Field
|
6
6
|
from langchain.tools import tool
|
7
7
|
from sympy import *
|
8
8
|
|
@@ -10,7 +10,7 @@ from sympy import *
|
|
10
10
|
class CalculatorInput(BaseModel):
|
11
11
|
expression: str = Field(
|
12
12
|
description="The input to this tool should be a mathematical expression using only Python's built-in mathematical operators.",
|
13
|
-
|
13
|
+
examples=['200*7'],
|
14
14
|
)
|
15
15
|
|
16
16
|
|
@@ -15,7 +15,7 @@ from uuid import uuid4
|
|
15
15
|
|
16
16
|
import matplotlib
|
17
17
|
from langchain_community.tools import Tool
|
18
|
-
from
|
18
|
+
from pydantic import BaseModel, Field
|
19
19
|
from loguru import logger
|
20
20
|
|
21
21
|
CODE_BLOCK_PATTERN = r"```(\w*)\n(.*?)\n```"
|
@@ -239,7 +239,7 @@ class CodeInterpreterToolArguments(BaseModel):
|
|
239
239
|
|
240
240
|
python_code: str = Field(
|
241
241
|
...,
|
242
|
-
|
242
|
+
examples=["print('Hello World')"],
|
243
243
|
description=(
|
244
244
|
'The pure python script to be evaluated. '
|
245
245
|
'The contents will be in main.py. '
|
@@ -2,13 +2,11 @@ import logging
|
|
2
2
|
import os
|
3
3
|
from typing import Any, Dict, Mapping, Optional, Tuple, Type, Union
|
4
4
|
|
5
|
-
from langchain.pydantic_v1 import BaseModel, Field
|
6
|
-
from langchain_community.utilities.dalle_image_generator import DallEAPIWrapper
|
7
5
|
from langchain_community.utils.openai import is_openai_v1
|
8
6
|
from langchain_core.callbacks import CallbackManagerForToolRun
|
9
|
-
from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator
|
10
7
|
from langchain_core.tools import BaseTool
|
11
8
|
from langchain_core.utils import get_from_dict_or_env, get_pydantic_field_names
|
9
|
+
from pydantic import ConfigDict, model_validator, BaseModel, Field
|
12
10
|
|
13
11
|
from bisheng_langchain.utils.azure_dalle_image_generator import AzureDallEWrapper
|
14
12
|
|
@@ -26,7 +24,7 @@ class DallEAPIWrapper(BaseModel):
|
|
26
24
|
2. save your OPENAI_API_KEY in an environment variable
|
27
25
|
"""
|
28
26
|
|
29
|
-
client: Any #: :meta private:
|
27
|
+
client: Any = None #: :meta private:
|
30
28
|
async_client: Any = Field(default=None, exclude=True) #: :meta private:
|
31
29
|
model_name: str = Field(default="dall-e-2", alias="model")
|
32
30
|
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
@@ -59,13 +57,10 @@ class DallEAPIWrapper(BaseModel):
|
|
59
57
|
http_async_client: Union[Any, None] = None
|
60
58
|
"""Optional httpx.AsyncClient. Only used for async invocations. Must specify
|
61
59
|
http_client as well if you'd like a custom client for sync invocations."""
|
60
|
+
model_config = ConfigDict(extra='forbid')
|
62
61
|
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
extra = Extra.forbid
|
67
|
-
|
68
|
-
@root_validator(pre=True)
|
62
|
+
@model_validator(mode='before')
|
63
|
+
@classmethod
|
69
64
|
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
70
65
|
"""Build extra kwargs from additional params that were passed in."""
|
71
66
|
all_required_field_names = get_pydantic_field_names(cls)
|
@@ -91,7 +86,8 @@ class DallEAPIWrapper(BaseModel):
|
|
91
86
|
values["model_kwargs"] = extra
|
92
87
|
return values
|
93
88
|
|
94
|
-
@
|
89
|
+
@model_validator(mode='before')
|
90
|
+
@classmethod
|
95
91
|
def validate_environment(cls, values: Dict) -> Dict:
|
96
92
|
"""Validate that api key and python package exists in environment."""
|
97
93
|
values["openai_api_key"] = get_from_dict_or_env(values, "openai_api_key", "OPENAI_API_KEY")
|
@@ -1,8 +1,7 @@
|
|
1
1
|
from typing import Any, Optional, Type
|
2
2
|
|
3
3
|
import requests
|
4
|
-
from
|
5
|
-
from loguru import logger
|
4
|
+
from pydantic import BaseModel, Field
|
6
5
|
|
7
6
|
from bisheng_langchain.gpts.tools.api_tools.base import (APIToolBase,
|
8
7
|
MultArgsSchemaTool)
|
@@ -1,11 +1,9 @@
|
|
1
|
-
import os
|
2
1
|
import smtplib
|
3
|
-
from email.mime.application import MIMEApplication
|
4
2
|
from email.mime.multipart import MIMEMultipart
|
5
3
|
from email.mime.text import MIMEText
|
6
|
-
from typing import Any
|
4
|
+
from typing import Any
|
7
5
|
|
8
|
-
from
|
6
|
+
from pydantic import BaseModel, Field
|
9
7
|
|
10
8
|
from bisheng_langchain.gpts.tools.api_tools.base import (APIToolBase,
|
11
9
|
MultArgsSchemaTool)
|
@@ -17,7 +15,7 @@ class InputArgs(BaseModel):
|
|
17
15
|
content: str = Field(description="邮件正文内容")
|
18
16
|
|
19
17
|
|
20
|
-
class EmailMessageTool(
|
18
|
+
class EmailMessageTool(BaseModel):
|
21
19
|
|
22
20
|
email_account: str = Field(description="发件人邮箱")
|
23
21
|
email_password: str = Field(description="邮箱授权码/密码")
|
@@ -27,9 +25,9 @@ class EmailMessageTool(APIToolBase):
|
|
27
25
|
|
28
26
|
def send_email(
|
29
27
|
self,
|
30
|
-
receiver,
|
31
|
-
subject,
|
32
|
-
content,
|
28
|
+
receiver: str = None,
|
29
|
+
subject: str = None,
|
30
|
+
content: str = None,
|
33
31
|
):
|
34
32
|
"""
|
35
33
|
发送电子邮件函数
|
@@ -1,29 +1,28 @@
|
|
1
1
|
from typing import Any, Optional, Type
|
2
2
|
|
3
3
|
import requests
|
4
|
-
from
|
5
|
-
from loguru import logger
|
4
|
+
from pydantic import BaseModel, Field
|
6
5
|
|
7
6
|
from bisheng_langchain.gpts.tools.api_tools.base import (APIToolBase,
|
8
7
|
MultArgsSchemaTool)
|
9
8
|
|
10
9
|
|
11
10
|
class InputArgs(BaseModel):
|
12
|
-
message: Optional[str] = Field(description="需要发送的钉钉消息")
|
13
|
-
receive_id: Optional[str] = Field(description="接收的ID")
|
14
|
-
receive_id_type: Optional[str] = Field(description="接收的ID类型")
|
15
|
-
container_id: Optional[str] = Field(description="container_id")
|
16
|
-
start_time: Optional[str] = Field(description="start_time")
|
17
|
-
end_time: Optional[str] = Field(description="end_time")
|
11
|
+
message: Optional[str] = Field(None, description="需要发送的钉钉消息")
|
12
|
+
receive_id: Optional[str] = Field(None, description="接收的ID")
|
13
|
+
receive_id_type: Optional[str] = Field(None, description="接收的ID类型")
|
14
|
+
container_id: Optional[str] = Field(None, description="container_id")
|
15
|
+
start_time: Optional[str] = Field(None, description="start_time")
|
16
|
+
end_time: Optional[str] = Field(None, description="end_time")
|
18
17
|
# page_token: Optional[str] = Field(description="page_token")
|
19
|
-
container_id_type: Optional[str] = Field(description="container_id_type")
|
18
|
+
container_id_type: Optional[str] = Field(None, description="container_id_type")
|
20
19
|
page_size: Optional[int] = Field(default=20,description="page_size")
|
21
|
-
page_token: Optional[str] = Field(description="page_token")
|
20
|
+
page_token: Optional[str] = Field(None, description="page_token")
|
22
21
|
sort_type: Optional[str] = Field(description="sort_type",default="ByCreateTimeAsc")
|
23
22
|
|
24
23
|
|
25
24
|
class FeishuMessageTool(BaseModel):
|
26
|
-
API_BASE_URL = "https://open.feishu.cn/open-apis"
|
25
|
+
API_BASE_URL: str = "https://open.feishu.cn/open-apis"
|
27
26
|
app_id: str = Field(description="app_id")
|
28
27
|
app_secret: str = Field(description="app_secret")
|
29
28
|
|
@@ -1,8 +1,7 @@
|
|
1
|
-
from typing import Any
|
1
|
+
from typing import Any
|
2
2
|
|
3
3
|
import requests
|
4
|
-
from
|
5
|
-
from loguru import logger
|
4
|
+
from pydantic import BaseModel, Field
|
6
5
|
|
7
6
|
from bisheng_langchain.gpts.tools.api_tools.base import (APIToolBase,
|
8
7
|
MultArgsSchemaTool)
|
@@ -1,236 +1,77 @@
|
|
1
|
-
from typing import Type, Optional
|
1
|
+
from typing import Type, Optional
|
2
2
|
|
3
3
|
from langchain_community.agent_toolkits import SQLDatabaseToolkit
|
4
4
|
from langchain_community.utilities import SQLDatabase
|
5
5
|
from langchain_core.callbacks import CallbackManagerForToolRun
|
6
6
|
from langchain_core.language_models import BaseLanguageModel
|
7
|
-
from langchain_core.messages import
|
8
|
-
from langchain_core.
|
9
|
-
from
|
10
|
-
from
|
11
|
-
from
|
12
|
-
from langgraph.graph import add_messages, StateGraph
|
13
|
-
from langgraph.prebuilt import ToolNode
|
14
|
-
from pydantic import BaseModel, Field
|
7
|
+
from langchain_core.messages import HumanMessage
|
8
|
+
from langchain_core.tools import BaseTool
|
9
|
+
from langgraph.graph.graph import CompiledGraph
|
10
|
+
from langgraph.prebuilt import create_react_agent
|
11
|
+
from pydantic import BaseModel, Field, ConfigDict
|
15
12
|
|
13
|
+
_agent_system_prompt = """You are an autonomous agent that answers user questions by querying an SQL database through the provided tools.
|
16
14
|
|
17
|
-
|
18
|
-
messages: Annotated[list[AnyMessage], add_messages]
|
15
|
+
When a new question arrives, follow the steps *in order*:
|
19
16
|
|
17
|
+
1. ALWAYS call `sql_db_list_tables` first.
|
18
|
+
Purpose: discover what tables are available. Never skip this step.
|
20
19
|
|
21
|
-
|
22
|
-
|
23
|
-
tool_calls = state["messages"][-1].tool_calls
|
24
|
-
return {
|
25
|
-
"messages": [
|
26
|
-
ToolMessage(
|
27
|
-
content=f"Error: {repr(error)}\n please fix your mistakes.",
|
28
|
-
tool_call_id=tc["id"],
|
29
|
-
)
|
30
|
-
for tc in tool_calls
|
31
|
-
]
|
32
|
-
}
|
20
|
+
2. Choose the table(s) that are probably relevant, then call `sql_db_schema`
|
21
|
+
once for each of those tables to obtain their schemas.
|
33
22
|
|
23
|
+
3. Write one syntactically-correct {dialect} SELECT statement.
|
24
|
+
Guidelines for this query:
|
25
|
+
- Return no more than 50 rows **unless** the user explicitly requests another limit.
|
26
|
+
- Select only the columns needed to answer the question; avoid `SELECT *`.
|
27
|
+
- If helpful, add `ORDER BY` on a meaningful column so the most interesting rows appear first.
|
28
|
+
- ABSOLUTELY NO data-modification statements (INSERT, UPDATE, DELETE, DROP, …).
|
29
|
+
- Double-check the SQL before executing.
|
34
30
|
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
return ToolNode(tools).with_fallbacks(
|
40
|
-
[RunnableLambda(handle_tool_error)], exception_key="error"
|
41
|
-
)
|
42
|
-
|
31
|
+
4. Execute the query with the execution tool `sql_db_query`.
|
32
|
+
If execution fails, inspect the error, revise the SQL, and try again.
|
33
|
+
Repeat until the query runs successfully or you are certain the request
|
34
|
+
cannot be satisfied.
|
43
35
|
|
44
|
-
|
45
|
-
|
36
|
+
5. Read the resulting rows and craft a concise, direct answer for the user.
|
37
|
+
If the result set is empty, explain that no matching data was found.
|
46
38
|
|
47
|
-
|
39
|
+
6. Include the final SQL query in your answer unless the user asks you not to.
|
48
40
|
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
41
|
+
Remember:
|
42
|
+
- List tables → fetch schemas → write & verify SELECT → execute → answer.
|
43
|
+
- Never skip steps 1 or 2.
|
44
|
+
- Never perform DML.
|
45
|
+
- Keep answers focused on the user's question."""
|
54
46
|
|
55
|
-
db: SQLDatabase
|
56
|
-
|
57
|
-
def _run(self, query: str, run_manager: Optional[CallbackManagerForToolRun] = None):
|
58
|
-
result = self.db.run_no_throw(query)
|
59
|
-
if not result:
|
60
|
-
return "Error: Query failed. Please rewrite your query and try again."
|
61
|
-
return result
|
62
47
|
|
63
48
|
class SqlAgentAPIWrapper(BaseModel):
|
49
|
+
model_config = ConfigDict(arbitrary_types_allowed=True)
|
50
|
+
|
64
51
|
llm: BaseLanguageModel = Field(description="llm to use for sql agent")
|
65
52
|
sql_address: str = Field(description="sql database address for SQLDatabase uri")
|
66
53
|
|
67
|
-
db: Optional[SQLDatabase]
|
68
|
-
|
69
|
-
get_schema_tool: Optional[BaseTool]
|
70
|
-
db_query_tool: Optional[BaseTool]
|
71
|
-
query_check: Optional[Any]
|
72
|
-
query_gen: Optional[Any]
|
73
|
-
workflow: Optional[StateGraph]
|
74
|
-
app: Optional[Any]
|
75
|
-
schema_llm: Optional[Any]
|
76
|
-
query_check_llm: Optional[Any]
|
77
|
-
query_gen_llm: Optional[Any]
|
78
|
-
|
79
|
-
class Config:
|
80
|
-
arbitrary_types_allowed = True
|
54
|
+
db: Optional[SQLDatabase] = None
|
55
|
+
agent: Optional[CompiledGraph] = None
|
81
56
|
|
82
57
|
def __init__(self, **kwargs):
|
83
58
|
super().__init__(**kwargs)
|
84
59
|
self.llm = kwargs.get('llm')
|
85
|
-
|
86
|
-
# todo 修改sql agent实现逻辑。此处逻辑只支持bishengLLM组件。原因是因为目前的实现必须实例化多个llm对象,每个llm对象绑定不同的tool
|
87
|
-
self.schema_llm = self.llm.__class__(model_id=self.llm.model_id, model_name=self.llm.model_name)
|
88
|
-
self.query_check_llm = self.llm.__class__(model_id=self.llm.model_id, model_name=self.llm.model_name)
|
89
|
-
self.query_gen_llm = self.llm.__class__(model_id=self.llm.model_id, model_name=self.llm.model_name)
|
90
60
|
self.sql_address = kwargs.get('sql_address')
|
91
61
|
|
92
62
|
self.db = SQLDatabase.from_uri(self.sql_address)
|
93
63
|
toolkit = SQLDatabaseToolkit(db=self.db, llm=self.llm)
|
94
64
|
tools = toolkit.get_tools()
|
95
|
-
self.
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
self.query_gen = self.init_query_gen()
|
101
|
-
|
102
|
-
# Define a new graph
|
103
|
-
self.workflow = StateGraph(State)
|
104
|
-
self.init_workflow()
|
105
|
-
self.app = self.workflow.compile(checkpointer=False, debug=True)
|
106
|
-
|
107
|
-
def init_workflow(self):
|
108
|
-
self.workflow.add_node("first_tool_call", self.first_tool_call)
|
109
|
-
self.workflow.add_node(
|
110
|
-
"list_tables_tool", create_tool_node_with_fallback([self.list_tables_tool])
|
111
|
-
)
|
112
|
-
|
113
|
-
self.workflow.add_node("get_schema_tool", create_tool_node_with_fallback([self.get_schema_tool]))
|
114
|
-
|
115
|
-
model_get_schema = self.schema_llm.bind_tools(
|
116
|
-
[self.get_schema_tool]
|
65
|
+
self.agent = create_react_agent(
|
66
|
+
self.llm,
|
67
|
+
tools,
|
68
|
+
prompt=_agent_system_prompt.format(dialect=self.db.dialect),
|
69
|
+
checkpointer=False,
|
117
70
|
)
|
118
|
-
self.workflow.add_node(
|
119
|
-
"model_get_schema",
|
120
|
-
lambda state: {
|
121
|
-
"messages": [model_get_schema.invoke(state["messages"])],
|
122
|
-
},
|
123
|
-
)
|
124
|
-
|
125
|
-
self.workflow.add_node("query_gen", self.query_gen_node)
|
126
|
-
self.workflow.add_node("correct_query", self.model_check_query)
|
127
|
-
|
128
|
-
self.workflow.add_node("execute_query", create_tool_node_with_fallback([self.db_query_tool]))
|
129
|
-
|
130
|
-
self.workflow.add_edge(START, "first_tool_call")
|
131
|
-
self.workflow.add_edge("first_tool_call", "list_tables_tool")
|
132
|
-
self.workflow.add_edge("list_tables_tool", "model_get_schema")
|
133
|
-
self.workflow.add_edge("model_get_schema", "get_schema_tool")
|
134
|
-
self.workflow.add_edge("get_schema_tool", "query_gen")
|
135
|
-
self.workflow.add_conditional_edges(
|
136
|
-
"query_gen",
|
137
|
-
self.should_continue,
|
138
|
-
)
|
139
|
-
self.workflow.add_edge("correct_query", "execute_query")
|
140
|
-
self.workflow.add_edge("execute_query", "query_gen")
|
141
|
-
|
142
|
-
@staticmethod
|
143
|
-
def should_continue(state: State) -> Literal[END, "correct_query", "query_gen"]:
|
144
|
-
messages = state["messages"]
|
145
|
-
last_message = messages[-1]
|
146
|
-
# If there is a tool call, then we finish
|
147
|
-
if getattr(last_message, "tool_calls", None):
|
148
|
-
return END
|
149
|
-
if last_message.content.startswith("Error:"):
|
150
|
-
return "query_gen"
|
151
|
-
else:
|
152
|
-
return "correct_query"
|
153
|
-
|
154
|
-
def init_query_check(self):
|
155
|
-
query_check_system = """You are a SQL expert with a strong attention to detail.
|
156
|
-
Double check the SQLite query for common mistakes, including:
|
157
|
-
- Using NOT IN with NULL values
|
158
|
-
- Using UNION when UNION ALL should have been used
|
159
|
-
- Using BETWEEN for exclusive ranges
|
160
|
-
- Data type mismatch in predicates
|
161
|
-
- Properly quoting identifiers
|
162
|
-
- Using the correct number of arguments for functions
|
163
|
-
- Casting to the correct data type
|
164
|
-
- Using the proper columns for joins
|
165
|
-
|
166
|
-
If there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query.
|
167
|
-
|
168
|
-
You will call the appropriate tool to execute the query after running this check."""
|
169
|
-
|
170
|
-
query_check_prompt = ChatPromptTemplate.from_messages(
|
171
|
-
[("system", query_check_system), ("placeholder", "{messages}")]
|
172
|
-
)
|
173
|
-
query_check = query_check_prompt | self.query_check_llm.bind_tools(
|
174
|
-
[self.db_query_tool]
|
175
|
-
)
|
176
|
-
return query_check
|
177
|
-
|
178
|
-
def first_tool_call(self, state: State) -> dict[str, list[AIMessage]]:
|
179
|
-
return {
|
180
|
-
"messages": [
|
181
|
-
AIMessage(
|
182
|
-
content="",
|
183
|
-
tool_calls=[
|
184
|
-
{
|
185
|
-
"name": "sql_db_list_tables",
|
186
|
-
"args": {},
|
187
|
-
"id": "tool_abcd123",
|
188
|
-
}
|
189
|
-
],
|
190
|
-
)
|
191
|
-
]
|
192
|
-
}
|
193
|
-
|
194
|
-
def model_check_query(self, state: State) -> dict[str, list[AIMessage]]:
|
195
|
-
"""
|
196
|
-
Use this tool to double-check if your query is correct before executing it.
|
197
|
-
"""
|
198
|
-
return {"messages": [self.query_check.invoke({"messages": [state["messages"][-1]]})]}
|
199
|
-
|
200
|
-
def init_query_gen(self):
|
201
|
-
# Add a node for a model to generate a query based on the question and schema
|
202
|
-
query_gen_system = """You are a SQL expert with a strong attention to detail.Given an input question, output a syntactically correct SQL query to run, then look at the results of the query and return the answer.DO NOT call any tool besides SubmitFinalAnswer to submit the final answer.When generating the query:Output the SQL query that answers the input question without a tool call.Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 10 results.You can order the results by a relevant column to return the most interesting examples in the database.Never query for all the columns from a specific table, only ask for the relevant columns given the question.If you get an error while executing a query, rewrite the query and try again.If you get an empty result set, you should try to rewrite the query to get a non-empty result set. NEVER make stuff up if you don't have enough information to answer the query... just say you don't have enough information.If you have enough information to answer the input question, simply invoke the appropriate tool to submit the final answer to the user.DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database."""
|
203
|
-
query_gen_prompt = ChatPromptTemplate.from_messages(
|
204
|
-
[("system", query_gen_system), ("placeholder", "{messages}")]
|
205
|
-
)
|
206
|
-
query_gen = query_gen_prompt | self.query_gen_llm.bind_tools(
|
207
|
-
[SubmitFinalAnswer]
|
208
|
-
)
|
209
|
-
return query_gen
|
210
|
-
|
211
|
-
def query_gen_node(self, state: State) -> Any:
|
212
|
-
message = self.query_gen.invoke(state)
|
213
|
-
|
214
|
-
# Sometimes, the LLM will hallucinate and call the wrong tool. We need to catch this and return an error message.
|
215
|
-
tool_messages = []
|
216
|
-
if message.tool_calls:
|
217
|
-
for tc in message.tool_calls:
|
218
|
-
if tc["name"] != "SubmitFinalAnswer":
|
219
|
-
tool_messages.append(
|
220
|
-
ToolMessage(
|
221
|
-
content=f"Error: The wrong tool was called: {tc['name']}. Please fix your mistakes. Remember to only call SubmitFinalAnswer to submit the final answer. Generated queries should be outputted WITHOUT a tool call.",
|
222
|
-
tool_call_id=tc["id"],
|
223
|
-
)
|
224
|
-
)
|
225
|
-
else:
|
226
|
-
tool_messages = []
|
227
|
-
return {"messages": [message] + tool_messages}
|
228
71
|
|
229
72
|
def run(self, query: str) -> str:
|
230
|
-
messages = self.
|
231
|
-
|
232
|
-
})
|
233
|
-
return messages["messages"][-1].tool_calls[0]["args"]["final_answer"]
|
73
|
+
messages = self.agent.invoke({"messages": [HumanMessage(content=query)]})
|
74
|
+
return messages["messages"][-1].content
|
234
75
|
|
235
76
|
def arun(self, query: str) -> str:
|
236
77
|
return self.run(query)
|
@@ -241,8 +82,8 @@ class SqlAgentInput(BaseModel):
|
|
241
82
|
|
242
83
|
|
243
84
|
class SqlAgentTool(BaseTool):
|
244
|
-
name = "sql_agent"
|
245
|
-
description = "回答与 SQL 数据库有关的问题。给定用户问题,将从数据库中获取可用的表以及对应 DDL,生成 SQL 查询语句并进行执行,最终得到执行结果。"
|
85
|
+
name: str = "sql_agent"
|
86
|
+
description: str = "回答与 SQL 数据库有关的问题。给定用户问题,将从数据库中获取可用的表以及对应 DDL,生成 SQL 查询语句并进行执行,最终得到执行结果。"
|
246
87
|
args_schema: Type[BaseModel] = SqlAgentInput
|
247
88
|
api_wrapper: SqlAgentAPIWrapper
|
248
89
|
|
@@ -1,12 +1,12 @@
|
|
1
1
|
|
2
2
|
from typing import List, Optional
|
3
3
|
|
4
|
-
from pydantic import
|
4
|
+
from pydantic import ConfigDict, BaseModel
|
5
5
|
|
6
6
|
|
7
7
|
class InputNode(BaseModel):
|
8
8
|
"""Input组件,用来控制输入"""
|
9
|
-
input: Optional[List[str]]
|
9
|
+
input: Optional[List[str]] = None
|
10
10
|
|
11
11
|
def text(self):
|
12
12
|
return self.input
|
@@ -15,14 +15,10 @@ class InputNode(BaseModel):
|
|
15
15
|
class VariableNode(BaseModel):
|
16
16
|
"""用来设置变量"""
|
17
17
|
# key
|
18
|
-
variables: Optional[List[str]]
|
18
|
+
variables: Optional[List[str]] = None
|
19
19
|
# vaulues
|
20
20
|
variable_value: Optional[List[str]] = []
|
21
|
-
|
22
|
-
class Config:
|
23
|
-
"""Configuration for this pydantic object."""
|
24
|
-
|
25
|
-
extra = Extra.forbid
|
21
|
+
model_config = ConfigDict(extra="forbid")
|
26
22
|
|
27
23
|
def text(self):
|
28
24
|
if self.variable_value:
|
@@ -36,9 +32,9 @@ class VariableNode(BaseModel):
|
|
36
32
|
|
37
33
|
|
38
34
|
class InputFileNode(BaseModel):
|
39
|
-
file_path: Optional[str]
|
40
|
-
file_name: Optional[str]
|
41
|
-
file_type: Optional[str] # tips for file
|
35
|
+
file_path: Optional[str] = None
|
36
|
+
file_name: Optional[str] = None
|
37
|
+
file_type: Optional[str] = None # tips for file
|
42
38
|
"""Output组件,用来控制输出"""
|
43
39
|
|
44
40
|
def text(self):
|
@@ -5,7 +5,7 @@ from venv import logger
|
|
5
5
|
from bisheng_langchain.chains import LoaderOutputChain
|
6
6
|
from langchain.callbacks.manager import AsyncCallbackManagerForChainRun, CallbackManagerForChainRun
|
7
7
|
from langchain.chains.base import Chain
|
8
|
-
from pydantic import
|
8
|
+
from pydantic import ConfigDict, BaseModel
|
9
9
|
|
10
10
|
_TEXT_COLOR_MAPPING = {
|
11
11
|
'blue': '36;1',
|
@@ -52,11 +52,7 @@ class Report(Chain):
|
|
52
52
|
|
53
53
|
input_key: str = 'report_name' #: :meta private:
|
54
54
|
output_key: str = 'text' #: :meta private:
|
55
|
-
|
56
|
-
class Config:
|
57
|
-
"""Configuration for this pydantic object."""
|
58
|
-
extra = Extra.forbid
|
59
|
-
arbitrary_types_allowed = True
|
55
|
+
model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
|
60
56
|
|
61
57
|
@property
|
62
58
|
def input_keys(self) -> List[str]:
|