sql-assistant 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sql_assistant/__init__.py +3 -0
- sql_assistant/api/__init__.py +1 -0
- sql_assistant/api/backup.py +116 -0
- sql_assistant/api/config.py +183 -0
- sql_assistant/api/conversation.py +71 -0
- sql_assistant/api/dependencies.py +22 -0
- sql_assistant/api/history.py +61 -0
- sql_assistant/api/models.py +221 -0
- sql_assistant/api/query.py +275 -0
- sql_assistant/api/routes.py +19 -0
- sql_assistant/api/schema.py +21 -0
- sql_assistant/config.py +144 -0
- sql_assistant/database/__init__.py +1 -0
- sql_assistant/database/backup.py +568 -0
- sql_assistant/database/connectors/__init__.py +1 -0
- sql_assistant/database/connectors/base.py +185 -0
- sql_assistant/database/connectors/exceptions.py +88 -0
- sql_assistant/database/connectors/mongodb.py +194 -0
- sql_assistant/database/connectors/mysql.py +110 -0
- sql_assistant/database/connectors/postgresql.py +133 -0
- sql_assistant/database/connectors/redis.py +132 -0
- sql_assistant/database/connectors/sqlserver.py +140 -0
- sql_assistant/database/history.py +290 -0
- sql_assistant/database/manager.py +178 -0
- sql_assistant/database/security.py +230 -0
- sql_assistant/llm/__init__.py +1 -0
- sql_assistant/llm/base.py +28 -0
- sql_assistant/llm/exceptions.py +96 -0
- sql_assistant/llm/manager.py +82 -0
- sql_assistant/llm/prompts.py +29 -0
- sql_assistant/llm/providers/__init__.py +1 -0
- sql_assistant/llm/providers/claude.py +132 -0
- sql_assistant/llm/providers/gemini.py +127 -0
- sql_assistant/llm/providers/openai_compatible.py +103 -0
- sql_assistant/llm/retry.py +88 -0
- sql_assistant/main.py +94 -0
- sql_assistant/settings.py +219 -0
- sql_assistant/web/__init__.py +1 -0
- sql_assistant/web/static/css/base.css +25 -0
- sql_assistant/web/static/css/components/backup.css +146 -0
- sql_assistant/web/static/css/components/chat.css +465 -0
- sql_assistant/web/static/css/components/modal.css +143 -0
- sql_assistant/web/static/css/components/settings.css +358 -0
- sql_assistant/web/static/css/components/sidebar.css +235 -0
- sql_assistant/web/static/css/components/toast.css +30 -0
- sql_assistant/web/static/css/style.css +10 -0
- sql_assistant/web/static/css/theme.css +200 -0
- sql_assistant/web/static/js/api.js +38 -0
- sql_assistant/web/static/js/app.js +161 -0
- sql_assistant/web/static/js/backup.js +216 -0
- sql_assistant/web/static/js/chat.js +238 -0
- sql_assistant/web/static/js/color-theme-manager.js +121 -0
- sql_assistant/web/static/js/confirm.js +95 -0
- sql_assistant/web/static/js/conversations.js +182 -0
- sql_assistant/web/static/js/settings.js +425 -0
- sql_assistant/web/static/js/state.js +43 -0
- sql_assistant/web/static/js/theme-manager.js +64 -0
- sql_assistant/web/static/js/ui.js +53 -0
- sql_assistant/web/templates/index.html +373 -0
- sql_assistant-1.0.0.dist-info/METADATA +24 -0
- sql_assistant-1.0.0.dist-info/RECORD +64 -0
- sql_assistant-1.0.0.dist-info/WHEEL +4 -0
- sql_assistant-1.0.0.dist-info/entry_points.txt +2 -0
- sql_assistant-1.0.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,185 @@
|
|
|
1
|
+
"""数据库连接器抽象基类"""
|
|
2
|
+
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from dataclasses import dataclass, field
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
from .exceptions import format_connector_result
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@dataclass()
class QueryResult:
    """Connector-independent result of executing a statement.

    Each connector's ``execute`` fills in the fields that apply to the
    statement it ran; the rest keep their empty/zero defaults.
    """

    # Column names of the returned result set, in order.
    columns: list[str] = field(default_factory=lambda: [])
    # One list of values per returned row, aligned with ``columns``.
    rows: list[list[Any]] = field(default_factory=lambda: [])
    # Number of entries in ``rows``.
    row_count: int = field(default=0)
    # Number of rows/documents changed by a write statement.
    affected_rows: int = field(default=0)
    # Statement category, e.g. "SELECT", "INSERT", "UPDATE", "DELETE", "DDL", "OTHER".
    sql_type: str = field(default="")
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class BaseConnector(ABC):
    """Abstract base class for database connectors.

    Stores the connection parameters and provides shared SQL helpers
    (statement classification, statement splitting and comment stripping).
    Concrete connectors implement the async connection lifecycle, statement
    execution, connection testing and schema introspection.
    """

    db_type: str = "unknown"

    def __init__(self, host: str, port: int, user: str, password: str, database: str):
        self.host = host
        self.port = port
        self.user = user
        self.password = password
        self.database = database

    @abstractmethod
    async def connect(self) -> None:
        """Open the database connection."""
        ...

    @abstractmethod
    async def disconnect(self) -> None:
        """Close the database connection."""
        ...

    @abstractmethod
    async def execute(self, sql: str) -> "QueryResult":
        """Execute ``sql`` and return a connector-independent QueryResult."""
        ...

    @abstractmethod
    async def test_connection(self) -> dict:
        """Check whether the connection is usable.

        Returns a result dict rather than a bare bool; the MySQL and MongoDB
        connectors build it with ``format_connector_result``.
        """
        ...

    @abstractmethod
    async def get_schema(self) -> dict:
        """Return the database schema (table names plus column details).

        Expected shape::

            {
                "db_type": "mysql",
                "tables": [
                    {"name": "users", "columns": [
                        {"name": "id", "type": "INT", "nullable": False, "key": "PRI"},
                        {"name": "name", "type": "VARCHAR(100)", "nullable": True, "key": ""},
                    ]},
                ]
            }
        """
        ...

    @staticmethod
    def classify_sql(sql: str) -> str:
        """Classify SQL as "DDL", "DELETE", "UPDATE", "INSERT", "SELECT" or "OTHER".

        Comments are ignored. For multi-statement input every statement is
        inspected and the most dangerous category wins, in the order
        DDL > DELETE > UPDATE > INSERT, so that any write or schema change in
        the batch triggers the confirmation dialog. TRUNCATE and RENAME count
        as DDL. If no statement writes, the result is "SELECT" when some
        statement is a read (SELECT, SHOW, DESC/DESCRIBE, EXPLAIN) and
        "OTHER" otherwise.
        """
        danger_rank = {"INSERT": 1, "UPDATE": 2, "DELETE": 3, "DDL": 4}
        worst = None
        has_read = False

        text = BaseConnector._strip_sql_comments(sql).strip()
        for stmt in BaseConnector._split_sql_statements(text):
            # Strip comments again per statement: a comment may follow the
            # previous ";" on the same line, e.g. "SELECT 1; /* x */ DELETE ...",
            # and would otherwise hide the statement's leading keyword.
            head = BaseConnector._strip_sql_comments(stmt).strip().upper()
            if not head:
                continue

            if head.startswith(("CREATE", "ALTER", "DROP", "TRUNCATE", "RENAME")):
                kind = "DDL"
            elif head.startswith("DELETE"):
                kind = "DELETE"
            elif head.startswith("UPDATE"):
                kind = "UPDATE"
            elif head.startswith("INSERT"):
                kind = "INSERT"
            else:
                if head.startswith(("SELECT", "SHOW", "DESC", "EXPLAIN")):
                    has_read = True
                continue

            if worst is None or danger_rank[kind] > danger_rank[worst]:
                worst = kind

        if worst is not None:
            return worst
        return "SELECT" if has_read else "OTHER"

    @staticmethod
    def _split_sql_statements(sql: str) -> list[str]:
        """Split ``sql`` on semicolons that are outside quotes and comments.

        A semicolon inside a '...' or "..." literal, a ``--`` line comment or
        a ``/* ... */`` block comment does not end a statement. Statement text
        is returned unchanged (comments included, separating ";" removed). An
        empty piece after a final ";" is not returned.
        """
        statements: list[str] = []
        current: list[str] = []
        quote = None  # quote character of the literal we are inside, if any
        i, length = 0, len(sql)

        while i < length:
            char = sql[i]

            if quote is not None:
                # Inside a literal: copy until the matching quote. A doubled
                # quote ('') just closes and immediately reopens the literal.
                current.append(char)
                if char == quote:
                    quote = None
                i += 1
                continue

            if char in ("'", '"'):
                quote = char
                current.append(char)
                i += 1
                continue

            if sql.startswith("--", i):
                # Line comment: keep it verbatim up to (not including) the newline.
                end = sql.find("\n", i)
                end = length if end == -1 else end
                current.append(sql[i:end])
                i = end
                continue

            if sql.startswith("/*", i):
                # Block comment: keep it verbatim through the closing "*/"
                # (or to the end of input if it is never closed).
                end = sql.find("*/", i + 2)
                end = length if end == -1 else end + 2
                current.append(sql[i:end])
                i = end
                continue

            if char == ";":
                statements.append("".join(current))
                current = []
            else:
                current.append(char)
            i += 1

        if current:
            statements.append("".join(current))
        return statements

    @staticmethod
    def _strip_sql_comments(sql: str) -> str:
        """Drop blank lines, whole-line comments and leading block comments.

        Removes:
        - lines that are empty or start with ``--`` or ``#``;
        - ``/* ... */`` comments at the start of a line, including ones that
          span several lines or follow each other. Text after the closing
          ``*/`` is kept (stripped) unless it is itself a line comment.

        Lines that begin with code are kept verbatim, including any comment
        later on the same line.
        """
        result_lines: list[str] = []
        in_block_comment = False

        for line in sql.split("\n"):
            rest = line.strip()
            trimmed = False  # True once a block comment was cut from this line

            # Peel block comments off the front of the line: one carried over
            # from an earlier line, and/or any number opening here.
            while rest:
                if in_block_comment:
                    end_idx = rest.find("*/")
                    if end_idx == -1:
                        rest = ""
                        break
                    in_block_comment = False
                elif rest.startswith("/*"):
                    end_idx = rest.find("*/", 2)
                    if end_idx == -1:
                        in_block_comment = True
                        rest = ""
                        break
                else:
                    break
                rest = rest[end_idx + 2:].strip()
                trimmed = True

            # Skip blank lines and whole-line comments.
            if not rest or rest.startswith(("--", "#")):
                continue

            result_lines.append(rest if trimmed else line)

        return "\n".join(result_lines)
|
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
"""数据库连接器统一异常"""
|
|
2
|
+
|
|
3
|
+
from typing import Optional
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class ConnectorError(Exception):
    """Base class for errors raised by database connectors.

    Args:
        message: Human-readable description; also what ``str()`` returns.
        db_type: Database type the error came from (may be empty).
        code: Optional machine-readable error code.
    """

    def __init__(self, message: str, db_type: str = "", code: Optional[str] = None):
        super().__init__(message)
        self.db_type, self.code = db_type, code

    def to_dict(self) -> dict:
        """Serialize the error as a failure result dict."""
        return dict(success=False, error=str(self), db_type=self.db_type, code=self.code)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class ConnectionError(ConnectorError):
    """Raised when a connection to the database fails (code ``"CONNECTION_FAILED"``).

    NOTE: this name shadows Python's built-in ``ConnectionError`` in this
    module and in any module that imports it by name.
    """

    def __init__(self, message: str, db_type: str = ""):
        super().__init__(message, db_type=db_type, code="CONNECTION_FAILED")
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class QueryError(ConnectorError):
    """Raised when executing a statement fails (code ``"QUERY_FAILED"``).

    Keeps the failing statement text in ``sql``.
    """

    def __init__(self, message: str, db_type: str = "", sql: str = ""):
        super().__init__(message, db_type=db_type, code="QUERY_FAILED")
        self.sql = sql

    def to_dict(self) -> dict:
        """Return the base error dict extended with the failing ``"sql"``."""
        return {**super().to_dict(), "sql": self.sql}
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class SchemaError(ConnectorError):
    """Raised when reading the database schema fails (code ``"SCHEMA_ERROR"``)."""

    def __init__(self, message: str, db_type: str = ""):
        super().__init__(message, db_type=db_type, code="SCHEMA_ERROR")
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class AuthenticationError(ConnectorError):
    """Raised when the database rejects the credentials (code ``"AUTH_FAILED"``)."""

    def __init__(self, message: str, db_type: str = ""):
        super().__init__(message, db_type=db_type, code="AUTH_FAILED")
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class TimeoutError(ConnectorError):
    """Raised when a connection attempt times out (code ``"TIMEOUT"``).

    NOTE: this name shadows Python's built-in ``TimeoutError`` in this
    module and in any module that imports it by name.
    """

    def __init__(self, message: str, db_type: str = ""):
        super().__init__(message, db_type=db_type, code="TIMEOUT")
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def format_connector_result(
    success: bool,
    data: object = None,
    error: Optional[str] = None,
    db_type: str = "",
    code: Optional[str] = None,
) -> dict:
    """Build the uniform result dict returned by connector helpers.

    Args:
        success: Whether the operation succeeded.
        data: Payload stored under ``"data"`` on success (any value).
        error: Message stored under ``"error"`` on failure; a generic
            fallback message is used when it is empty or None.
        db_type: Database type; stored under ``"db_type"`` whenever it is
            non-empty, for successful and failed results alike.
        code: Error code; stored under ``"code"`` on failure when given.

    Returns:
        A dict that always contains ``"success"`` plus the keys above.
    """
    # `data` was previously annotated with the builtin *function* `any`;
    # `object` is the correct "accepts anything" annotation.
    result: dict = {"success": success}

    if success:
        result["data"] = data
    else:
        result["error"] = error or "未知错误"
        if code:
            result["code"] = code

    # Tag the result with its database type regardless of outcome; callers
    # pass db_type for successful results as well.
    if db_type:
        result["db_type"] = db_type

    return result
|
|
@@ -0,0 +1,194 @@
|
|
|
1
|
+
"""MongoDB 连接器"""
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import asyncio
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
from pymongo import MongoClient
|
|
8
|
+
|
|
9
|
+
from .base import BaseConnector, QueryResult
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class MongoDBConnector(BaseConnector):
    """MongoDB connector.

    ``execute`` takes a JSON document describing the operation instead of
    SQL. pymongo calls are blocking, so they run in the event loop's default
    executor.
    """

    db_type = "mongodb"

    def __init__(self, host: str, port: int, user: str, password: str, database: str):
        super().__init__(host, port, user, password, database)
        self._client: MongoClient | None = None
        self._db = None

    async def connect(self) -> None:
        """Create the MongoClient (10 s server-selection timeout) and select the database.

        The user name and password are percent-escaped with ``quote_plus``
        before being placed in the URI, as PyMongo requires; unescaped
        characters such as "@", ":" or "/" in a password would otherwise
        make the connection string invalid or misparsed.
        """
        from urllib.parse import quote_plus

        loop = asyncio.get_running_loop()
        if self.user:
            credentials = f"{quote_plus(self.user)}:{quote_plus(self.password or '')}@"
        else:
            credentials = ""
        connection_string = f"mongodb://{credentials}{self.host}:{self.port}/"
        self._client = await loop.run_in_executor(
            None,
            lambda: MongoClient(connection_string, serverSelectionTimeoutMS=10000),
        )
        self._db = self._client[self.database]

    async def disconnect(self) -> None:
        """Close the client, if any, and drop the database handle."""
        if self._client:
            self._client.close()
            self._client = None
            self._db = None

    async def get_schema(self) -> dict:
        """Infer a schema by sampling each collection.

        MongoDB has no fixed schema, so up to 5 documents per collection are
        read and every field seen is reported with the Python type names of
        its sampled values ("null" for None). All fields are reported as
        nullable; only "_id" gets a key marker.

        Raises:
            RuntimeError: if ``connect()`` has not been called.
        """
        if not self._client or self._db is None:
            raise RuntimeError("MongoDB 未连接")

        loop = asyncio.get_running_loop()

        def _run():
            collections = self._db.list_collection_names()
            tables = []
            for coll_name in collections:
                # Sample at most 5 documents to infer field names and types.
                docs = list(self._db[coll_name].find().limit(5))
                columns_set: dict[str, set] = {}
                for doc in docs:
                    for key in doc.keys():
                        if key not in columns_set:
                            columns_set[key] = set()
                        val = doc[key]
                        columns_set[key].add(type(val).__name__ if val is not None else "null")

                columns = []
                for col_name, types in columns_set.items():
                    columns.append({
                        "name": col_name,
                        "type": ", ".join(sorted(types)),
                        "nullable": True,
                        "key": "_id" if col_name == "_id" else "",
                        "default": None,
                        "comment": "",
                    })
                tables.append({"name": coll_name, "columns": columns})

            return {"db_type": "mongodb", "tables": tables}

        return await loop.run_in_executor(None, _run)

    async def execute(self, sql: str) -> QueryResult:
        """Run a MongoDB operation described by a JSON document.

        ``sql`` must be a JSON object with a ``"collection"`` name and an
        ``"operation"`` (default ``"find"``, case-insensitive):

        * ``find``: ``filter``, ``projection``, ``sort``, ``limit``
          (default 100; 0 or null means no limit)
        * ``insert``: ``documents`` (a list, or a single object)
        * ``update``: ``filter``, ``update``, ``many`` (default False)
        * ``delete``: ``filter``, ``many`` (default False)
        * ``aggregate``: ``pipeline``
        * ``count``: ``filter``

        Invalid JSON, unsupported operations and driver errors are not
        raised; they are returned as a single-row result with an "error"
        column.

        Raises:
            RuntimeError: if ``connect()`` has not been called.
        """
        if not self._client or self._db is None:
            raise RuntimeError("MongoDB 未连接")

        loop = asyncio.get_running_loop()
        result = QueryResult()

        def _run():
            try:
                query = json.loads(sql)
                collection_name = query.get("collection", "")
                operation = query.get("operation", "find").lower()
                collection = self._db[collection_name]

                if operation == "find":
                    filter_obj = query.get("filter", {})
                    projection = query.get("projection", None)
                    sort_list = query.get("sort", None)
                    limit_val = query.get("limit", 100)

                    cursor = collection.find(filter_obj, projection)
                    if sort_list:
                        cursor = cursor.sort(sort_list)
                    if limit_val:
                        cursor = cursor.limit(limit_val)

                    docs = list(cursor)
                    # Documents may have different fields: use the sorted union.
                    columns_set = set()
                    for doc in docs:
                        columns_set.update(doc.keys())
                    result.columns = sorted(columns_set)
                    result.rows = []
                    for doc in docs:
                        result.rows.append([doc.get(col, None) for col in result.columns])
                    result.row_count = len(docs)
                    result.sql_type = "SELECT"

                elif operation == "insert":
                    documents = query.get("documents", [])
                    if isinstance(documents, dict):
                        documents = [documents]
                    result_obj = collection.insert_many(documents)
                    result.affected_rows = len(result_obj.inserted_ids)
                    result.sql_type = "INSERT"

                elif operation == "update":
                    filter_obj = query.get("filter", {})
                    update_obj = query.get("update", {})
                    many = query.get("many", False)
                    if many:
                        update_result = collection.update_many(filter_obj, update_obj)
                    else:
                        update_result = collection.update_one(filter_obj, update_obj)
                    result.affected_rows = update_result.modified_count
                    result.sql_type = "UPDATE"

                elif operation == "delete":
                    filter_obj = query.get("filter", {})
                    many = query.get("many", False)
                    if many:
                        delete_result = collection.delete_many(filter_obj)
                    else:
                        delete_result = collection.delete_one(filter_obj)
                    result.affected_rows = delete_result.deleted_count
                    result.sql_type = "DELETE"

                elif operation == "aggregate":
                    pipeline = query.get("pipeline", [])
                    docs = list(collection.aggregate(pipeline))
                    columns_set = set()
                    for doc in docs:
                        columns_set.update(doc.keys())
                    result.columns = sorted(columns_set)
                    result.rows = [[doc.get(col, None) for col in result.columns] for doc in docs]
                    result.row_count = len(docs)
                    result.sql_type = "SELECT"

                elif operation == "count":
                    filter_obj = query.get("filter", {})
                    count = collection.count_documents(filter_obj)
                    result.columns = ["count"]
                    result.rows = [[count]]
                    result.row_count = 1
                    result.sql_type = "SELECT"

                else:
                    result.columns = ["error"]
                    result.rows = [[f"不支持的操作: {operation}"]]
                    result.row_count = 1

            except json.JSONDecodeError as e:
                result.columns = ["error"]
                result.rows = [[f"JSON 解析错误: {e}"]]
                result.row_count = 1
            except Exception as e:
                result.columns = ["error"]
                result.rows = [[str(e)]]
                result.row_count = 1

            return result

        return await loop.run_in_executor(None, _run)

    async def test_connection(self) -> dict:
        """Connect, send a ``ping`` admin command, then always disconnect.

        Returns:
            A ``format_connector_result`` dict: success with a message, or
            failure with code "CONNECTION_FAILED" and the error text.
        """
        from .exceptions import format_connector_result

        try:
            await self.connect()
            client = self._client
            if client is not None:
                loop = asyncio.get_running_loop()
                await loop.run_in_executor(
                    None,
                    lambda: client.admin.command("ping"),
                )
            return format_connector_result(True, data={"message": "MongoDB 连接成功"}, db_type="mongodb")
        except Exception as e:
            return format_connector_result(False, error=str(e), db_type="mongodb", code="CONNECTION_FAILED")
        finally:
            await self.disconnect()
|
|
@@ -0,0 +1,110 @@
|
|
|
1
|
+
"""MySQL 连接器"""
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
import pymysql
|
|
7
|
+
|
|
8
|
+
from .base import BaseConnector, QueryResult
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class MySQLConnector(BaseConnector):
    """MySQL connector built on the blocking ``pymysql`` driver.

    Connecting, executing, schema reads and ping run in the event loop's
    default executor so they do not block the loop.
    """

    db_type = "mysql"

    def __init__(self, host: str, port: int, user: str, password: str, database: str):
        super().__init__(host, port, user, password, database)
        self._conn: pymysql.Connection | None = None

    async def connect(self) -> None:
        """Open a pymysql connection (utf8mb4, tuple cursor, 10 s connect timeout)."""
        loop = asyncio.get_running_loop()
        self._conn = await loop.run_in_executor(
            None,
            lambda: pymysql.connect(
                host=self.host,
                port=self.port,
                user=self.user,
                password=self.password,
                database=self.database,
                charset="utf8mb4",
                cursorclass=pymysql.cursors.Cursor,
                connect_timeout=10,
            ),
        )

    async def disconnect(self) -> None:
        """Close the connection, if any."""
        if self._conn:
            self._conn.close()
            self._conn = None

    async def execute(self, sql: str) -> QueryResult:
        """Execute one or more ``;``-separated statements as one transaction.

        Statements run in order on a single cursor:
        - ``columns``/``rows``/``row_count`` come from the last statement that
          produced a result set;
        - ``affected_rows`` comes from the last statement that did not;
        - ``sql_type`` is the ``classify_sql`` category of the last statement.

        Everything is committed once at the end. If any statement fails, the
        transaction is rolled back and the error is re-raised.

        NOTE(review): statements are split on every ";", whereas
        ``classify_sql`` splits with quote awareness. A ";" inside a string
        literal or comment therefore breaks the statement here, and the
        executed statements can differ from the classified ones. This should
        use the same splitting rules as classification.

        Raises:
            RuntimeError: if ``connect()`` has not been called.
        """
        if not self._conn:
            raise RuntimeError("MySQL 未连接")

        loop = asyncio.get_running_loop()
        conn = self._conn
        result = QueryResult()

        def _run():
            statements = [s.strip() for s in sql.split(";") if s.strip()]
            if not statements:
                return result

            try:
                with conn.cursor() as cursor:
                    for stmt in statements:
                        cursor.execute(stmt)
                        sql_type = self.classify_sql(stmt)
                        result.sql_type = sql_type

                        # Any statement that returns a result set has a cursor
                        # description, including ones classify_sql reports as
                        # OTHER (e.g. CALL or WITH ... SELECT).
                        if sql_type == "SELECT" or cursor.description:
                            result.columns = [col[0] for col in cursor.description] if cursor.description else []
                            result.rows = [list(row) for row in cursor.fetchall()] if cursor.description else []
                            result.row_count = len(result.rows)
                        else:
                            result.affected_rows = cursor.rowcount

                conn.commit()
            except Exception:
                # Without a rollback, statements that succeeded before the
                # failure would stay pending on this connection and be
                # committed by the next successful execute(). (DDL is
                # committed implicitly by MySQL and cannot be undone here.)
                try:
                    conn.rollback()
                except pymysql.MySQLError:
                    pass  # keep the original error; the connection may be unusable
                raise
            return result

        return await loop.run_in_executor(None, _run)

    async def get_schema(self) -> dict:
        """List tables (SHOW TABLES) with their columns (SHOW FULL COLUMNS).

        Raises:
            RuntimeError: if ``connect()`` has not been called.
        """
        if not self._conn:
            raise RuntimeError("MySQL 未连接")

        loop = asyncio.get_running_loop()
        conn = self._conn

        def _run():
            tables = []
            with conn.cursor() as cursor:
                cursor.execute("SHOW TABLES")
                table_rows = cursor.fetchall()
                for (table_name,) in table_rows:
                    # Identifiers cannot be bound as parameters; escape any
                    # backtick in the name by doubling it inside `...`.
                    quoted = table_name.replace("`", "``")
                    cursor.execute(f"SHOW FULL COLUMNS FROM `{quoted}`")
                    columns = []
                    # Row layout: Field, Type, Collation, Null, Key, Default,
                    # Extra, Privileges, Comment.
                    for col in cursor.fetchall():
                        columns.append({
                            "name": col[0],
                            "type": col[1],
                            "nullable": col[3] == "YES",
                            "key": col[4] or "",
                            "default": str(col[5]) if col[5] is not None else None,
                            "comment": col[8] or "",
                        })
                    tables.append({"name": table_name, "columns": columns})
            return {"db_type": "mysql", "tables": tables}

        return await loop.run_in_executor(None, _run)

    async def test_connection(self) -> dict:
        """Connect, ping the server, then always disconnect.

        Returns:
            A ``format_connector_result`` dict: success with a message, or
            failure with code "CONNECTION_FAILED" and the error text.
        """
        from .exceptions import format_connector_result

        try:
            await self.connect()
            conn = self._conn
            if conn:
                loop = asyncio.get_running_loop()
                # ping() does network I/O; keep it off the event loop thread.
                await loop.run_in_executor(None, lambda: conn.ping(reconnect=False))
            return format_connector_result(True, data={"message": "MySQL 连接成功"}, db_type="mysql")
        except Exception as e:
            return format_connector_result(False, error=str(e), db_type="mysql", code="CONNECTION_FAILED")
        finally:
            await self.disconnect()
|