sql-assistant 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64)
  1. sql_assistant/__init__.py +3 -0
  2. sql_assistant/api/__init__.py +1 -0
  3. sql_assistant/api/backup.py +116 -0
  4. sql_assistant/api/config.py +183 -0
  5. sql_assistant/api/conversation.py +71 -0
  6. sql_assistant/api/dependencies.py +22 -0
  7. sql_assistant/api/history.py +61 -0
  8. sql_assistant/api/models.py +221 -0
  9. sql_assistant/api/query.py +275 -0
  10. sql_assistant/api/routes.py +19 -0
  11. sql_assistant/api/schema.py +21 -0
  12. sql_assistant/config.py +144 -0
  13. sql_assistant/database/__init__.py +1 -0
  14. sql_assistant/database/backup.py +568 -0
  15. sql_assistant/database/connectors/__init__.py +1 -0
  16. sql_assistant/database/connectors/base.py +185 -0
  17. sql_assistant/database/connectors/exceptions.py +88 -0
  18. sql_assistant/database/connectors/mongodb.py +194 -0
  19. sql_assistant/database/connectors/mysql.py +110 -0
  20. sql_assistant/database/connectors/postgresql.py +133 -0
  21. sql_assistant/database/connectors/redis.py +132 -0
  22. sql_assistant/database/connectors/sqlserver.py +140 -0
  23. sql_assistant/database/history.py +290 -0
  24. sql_assistant/database/manager.py +178 -0
  25. sql_assistant/database/security.py +230 -0
  26. sql_assistant/llm/__init__.py +1 -0
  27. sql_assistant/llm/base.py +28 -0
  28. sql_assistant/llm/exceptions.py +96 -0
  29. sql_assistant/llm/manager.py +82 -0
  30. sql_assistant/llm/prompts.py +29 -0
  31. sql_assistant/llm/providers/__init__.py +1 -0
  32. sql_assistant/llm/providers/claude.py +132 -0
  33. sql_assistant/llm/providers/gemini.py +127 -0
  34. sql_assistant/llm/providers/openai_compatible.py +103 -0
  35. sql_assistant/llm/retry.py +88 -0
  36. sql_assistant/main.py +94 -0
  37. sql_assistant/settings.py +219 -0
  38. sql_assistant/web/__init__.py +1 -0
  39. sql_assistant/web/static/css/base.css +25 -0
  40. sql_assistant/web/static/css/components/backup.css +146 -0
  41. sql_assistant/web/static/css/components/chat.css +465 -0
  42. sql_assistant/web/static/css/components/modal.css +143 -0
  43. sql_assistant/web/static/css/components/settings.css +358 -0
  44. sql_assistant/web/static/css/components/sidebar.css +235 -0
  45. sql_assistant/web/static/css/components/toast.css +30 -0
  46. sql_assistant/web/static/css/style.css +10 -0
  47. sql_assistant/web/static/css/theme.css +200 -0
  48. sql_assistant/web/static/js/api.js +38 -0
  49. sql_assistant/web/static/js/app.js +161 -0
  50. sql_assistant/web/static/js/backup.js +216 -0
  51. sql_assistant/web/static/js/chat.js +238 -0
  52. sql_assistant/web/static/js/color-theme-manager.js +121 -0
  53. sql_assistant/web/static/js/confirm.js +95 -0
  54. sql_assistant/web/static/js/conversations.js +182 -0
  55. sql_assistant/web/static/js/settings.js +425 -0
  56. sql_assistant/web/static/js/state.js +43 -0
  57. sql_assistant/web/static/js/theme-manager.js +64 -0
  58. sql_assistant/web/static/js/ui.js +53 -0
  59. sql_assistant/web/templates/index.html +373 -0
  60. sql_assistant-1.0.0.dist-info/METADATA +24 -0
  61. sql_assistant-1.0.0.dist-info/RECORD +64 -0
  62. sql_assistant-1.0.0.dist-info/WHEEL +4 -0
  63. sql_assistant-1.0.0.dist-info/entry_points.txt +2 -0
  64. sql_assistant-1.0.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,185 @@
1
+ """数据库连接器抽象基类"""
2
+
3
+ from abc import ABC, abstractmethod
4
+ from dataclasses import dataclass, field
5
+ from typing import Any
6
+
7
+ from .exceptions import format_connector_result
8
+
9
+
10
@dataclass
class QueryResult:
    """Connector-independent outcome of executing a request.

    Every connector fills in the fields relevant to the kind of statement it
    ran and leaves the others at their zero/empty defaults.
    """

    # Column names of the returned result set, in display order.
    columns: list[str] = field(default_factory=list)
    # Returned rows; each row carries one value per entry of ``columns``.
    rows: list[list[Any]] = field(default_factory=list)
    # Number of rows in ``rows``.
    row_count: int = 0
    # Row count reported by the backend for a write statement.
    affected_rows: int = 0
    # Statement category, e.g. SELECT, INSERT, UPDATE, DELETE, DDL, OTHER ("" when unset).
    sql_type: str = ""
18
+
19
+
20
class BaseConnector(ABC):
    """Abstract base class for all database connectors.

    Stores the connection parameters and provides static helpers that
    classify SQL text, so callers can detect data-modifying or DDL statements
    (e.g. to ask for confirmation) before running them.
    """

    db_type: str = "unknown"

    # Write categories in descending order of danger; classify_sql reports the
    # most dangerous category found in ANY statement of the input.
    _WRITE_CATEGORIES = (
        ("DDL", ("CREATE", "ALTER", "DROP", "TRUNCATE", "RENAME")),
        ("DELETE", ("DELETE",)),
        ("UPDATE", ("UPDATE",)),
        ("INSERT", ("INSERT", "REPLACE")),
    )
    # Leading keywords of statements reported as read queries ("SELECT").
    # NOTE(review): on PostgreSQL, EXPLAIN ANALYZE actually executes the wrapped
    # statement; such input is still reported as "SELECT" here.
    _QUERY_PREFIXES = ("SELECT", "SHOW", "DESCRIBE", "EXPLAIN")

    def __init__(self, host: str, port: int, user: str, password: str, database: str):
        self.host = host
        self.port = port
        self.user = user
        self.password = password
        self.database = database

    @abstractmethod
    async def connect(self) -> None:
        """Open the database connection."""
        ...

    @abstractmethod
    async def disconnect(self) -> None:
        """Close the database connection."""
        ...

    @abstractmethod
    async def execute(self, sql: str) -> "QueryResult":
        """Execute ``sql`` and return a connector-independent result."""
        ...

    @abstractmethod
    async def test_connection(self) -> dict:
        """Check that the database is reachable.

        Implementations return a response dict (``{"success": bool, ...}``)
        describing the outcome rather than raising.
        """
        ...

    @abstractmethod
    async def get_schema(self) -> dict:
        """Return the database schema (table names plus column information).

        Expected shape::

            {
                "db_type": "mysql",
                "tables": [
                    {"name": "users", "columns": [
                        {"name": "id", "type": "INT", "nullable": False, "key": "PRI"},
                        {"name": "name", "type": "VARCHAR(100)", "nullable": True, "key": ""},
                    ]},
                ]
            }
        """
        ...

    @staticmethod
    def classify_sql(sql: str) -> str:
        """Classify SQL text, ignoring comments and blank lines.

        Multi-statement input is supported. If any statement is DDL or a write,
        the most dangerous category present is returned, in the order
        DDL > DELETE > UPDATE > INSERT. This ensures a dangerous statement is
        never masked by a harmless one. Otherwise "SELECT" is returned if any
        statement is a query (SELECT/SHOW/DESCRIBE/EXPLAIN), else "OTHER".

        Note: MySQL executable comments (``/*! ... */``) are treated like
        ordinary comments.
        """
        text = BaseConnector._strip_sql_comments(sql).strip()
        # Strip comment lines again per statement: a statement may begin with a
        # comment line that followed the previous ';' on the same line.
        statements = [
            BaseConnector._strip_sql_comments(stmt).strip().upper()
            for stmt in BaseConnector._split_sql_statements(text)
        ]
        statements = [stmt for stmt in statements if stmt]

        for category, prefixes in BaseConnector._WRITE_CATEGORIES:
            if any(stmt.startswith(prefixes) for stmt in statements):
                return category

        if any(stmt.startswith(BaseConnector._QUERY_PREFIXES) for stmt in statements):
            return "SELECT"
        return "OTHER"

    @staticmethod
    def _split_sql_statements(sql: str) -> list:
        """Split ``sql`` on ``;`` characters that are outside quotes and comments.

        - Text inside ``'...'``, ``"..."`` or backtick quotes is kept verbatim,
          including any ``;`` or ``--`` it contains. Doubled quotes (``'it''s'``)
          work because each quote character toggles the state.
        - A ``--`` line comment (dash-dash followed by whitespace or end of input,
          as MySQL requires) is dropped up to, but not including, the newline.
          ``5--3`` is therefore kept as an expression.

        Backslash escapes are deliberately not interpreted. In dialects where the
        backslash is literal, honouring it could hide a following statement from
        classification.

        Returns:
            The raw statement texts, unstripped, in input order. A trailing
            empty statement (input ending in ``;``) is omitted.
        """
        statements = []
        buffer: list[str] = []
        quote = None  # the active quote character, or None when outside quotes
        i = 0
        length = len(sql)

        while i < length:
            char = sql[i]

            if quote is not None:
                buffer.append(char)
                if char == quote:
                    quote = None
                i += 1
                continue

            if char in ("'", '"', "`"):
                quote = char
                buffer.append(char)
                i += 1
                continue

            if (
                char == "-"
                and sql.startswith("--", i)
                and (i + 2 == length or sql[i + 2].isspace())
            ):
                newline = sql.find("\n", i)
                if newline == -1:
                    break  # comment runs to end of input
                i = newline  # keep the newline so adjacent tokens stay separated
                continue

            if char == ";":
                statements.append("".join(buffer))
                buffer = []
                i += 1
                continue

            buffer.append(char)
            i += 1

        if buffer:
            statements.append("".join(buffer))
        return statements

    @staticmethod
    def _strip_sql_comments(sql: str) -> str:
        """Remove whole-line comments and blank lines from ``sql``.

        Two comment styles are handled:
        - line comments: lines whose stripped text starts with ``--`` or ``#``
          are dropped wherever they appear;
        - block comments: a ``/* ... */`` that starts a line (possibly spanning
          several lines) is dropped, and any text after the closing ``*/`` is kept.

        Comments that start in the middle of a line are left untouched.
        """
        lines = sql.split("\n")
        result_lines = []
        in_block_comment = False

        for line in lines:
            stripped = line.strip()

            # Inside a multi-line block comment: wait for its closing marker.
            if in_block_comment:
                end_idx = stripped.find("*/")
                if end_idx != -1:
                    in_block_comment = False
                    remaining = stripped[end_idx + 2:].strip()
                    if remaining:
                        result_lines.append(remaining)
                continue

            # A block comment opening at the start of the line.
            if stripped.startswith("/*"):
                end_idx = stripped.find("*/", 2)
                if end_idx != -1:
                    remaining = stripped[end_idx + 2:].strip()
                    if remaining:
                        result_lines.append(remaining)
                else:
                    in_block_comment = True
                continue

            # Line comments and blank lines.
            if not stripped or stripped.startswith("--") or stripped.startswith("#"):
                continue

            result_lines.append(line)

        return "\n".join(result_lines)
@@ -0,0 +1,88 @@
1
+ """数据库连接器统一异常"""
2
+
3
+ from typing import Optional
4
+
5
+
6
class ConnectorError(Exception):
    """Root exception for database connector failures.

    Args:
        message: Human-readable description; becomes ``str(exc)``.
        db_type: Identifier of the backend that raised the error.
        code: Optional machine-readable error code.
    """

    def __init__(self, message: str, db_type: str = "", code: Optional[str] = None):
        super().__init__(message)
        self.db_type, self.code = db_type, code

    def to_dict(self) -> dict:
        """Serialize the error as a failure response (success/error/db_type/code)."""
        payload = dict(success=False, error=str(self))
        payload["db_type"] = self.db_type
        payload["code"] = self.code
        return payload
20
+
21
+
22
class ConnectionError(ConnectorError):
    """Raised when a connection to the database cannot be established.

    Carries the fixed error code "CONNECTION_FAILED". Note that this name
    shadows the built-in ``ConnectionError`` in any module importing it by name.
    """

    def __init__(self, message: str, db_type: str = ""):
        super().__init__(message, db_type=db_type, code="CONNECTION_FAILED")
26
+
27
+
28
class QueryError(ConnectorError):
    """Raised when executing a statement fails.

    Carries the fixed error code "QUERY_FAILED" and keeps the offending
    statement in ``sql`` so it can be reported back to the caller.
    """

    def __init__(self, message: str, db_type: str = "", sql: str = ""):
        super().__init__(message, db_type=db_type, code="QUERY_FAILED")
        self.sql = sql

    def to_dict(self) -> dict:
        """Return the base failure payload extended with the failing SQL."""
        return {**super().to_dict(), "sql": self.sql}
38
+
39
+
40
class SchemaError(ConnectorError):
    """Raised when schema information cannot be retrieved (code "SCHEMA_ERROR")."""

    def __init__(self, message: str, db_type: str = ""):
        super().__init__(message, db_type=db_type, code="SCHEMA_ERROR")
44
+
45
+
46
class AuthenticationError(ConnectorError):
    """Raised when the database rejects the credentials (code "AUTH_FAILED")."""

    def __init__(self, message: str, db_type: str = ""):
        super().__init__(message, db_type=db_type, code="AUTH_FAILED")
50
+
51
+
52
class TimeoutError(ConnectorError):
    """Raised when a connection attempt times out (code "TIMEOUT").

    Note that this name shadows the built-in ``TimeoutError`` in any module
    importing it by name.
    """

    def __init__(self, message: str, db_type: str = ""):
        super().__init__(message, db_type=db_type, code="TIMEOUT")
56
+
57
+
58
def format_connector_result(
    success: bool,
    data: object = None,
    error: Optional[str] = None,
    db_type: str = "",
    code: Optional[str] = None,
) -> dict:
    """Build the uniform response dict returned by connector helpers.

    Args:
        success: Whether the operation succeeded.
        data: Payload stored under ``"data"`` when ``success`` is true.
        error: Failure message stored under ``"error"`` when ``success`` is
            false; an empty/None value is replaced by "未知错误" ("unknown error").
        db_type: Database backend identifier; added whenever non-empty.
        code: Error code; added whenever non-empty.

    Returns:
        ``{"success": True, "data": ...}`` or ``{"success": False, "error": ...}``,
        plus the optional ``"code"`` / ``"db_type"`` keys described above.
    """
    result: dict = {"success": success}

    if success:
        result["data"] = data
    else:
        result["error"] = error or "未知错误"

    # Optional metadata is reported whenever supplied; callers pass db_type on
    # success as well as on failure.
    if code:
        result["code"] = code
    if db_type:
        result["db_type"] = db_type

    return result
@@ -0,0 +1,194 @@
1
+ """MongoDB 连接器"""
2
+
3
+ import json
4
+ import asyncio
5
+ from typing import Any
6
+
7
+ from pymongo import MongoClient
8
+
9
+ from .base import BaseConnector, QueryResult
10
+
11
+
12
class MongoDBConnector(BaseConnector):
    """MongoDB connector.

    Requests are JSON documents (see ``execute``) rather than SQL. Blocking
    pymongo calls are run in the event loop's default executor.
    """

    db_type = "mongodb"

    def __init__(self, host: str, port: int, user: str, password: str, database: str):
        super().__init__(host, port, user, password, database)
        self._client: MongoClient | None = None
        self._db = None

    def _connection_string(self) -> str:
        """Build the ``mongodb://`` URI, percent-escaping the credentials.

        Usernames and passwords in a MongoDB URI must be escaped with
        ``quote_plus``. Otherwise characters such as ``@``, ``:`` or ``/`` in a
        password corrupt the URI.
        """
        from urllib.parse import quote_plus

        address = f"{self.host}:{self.port}"
        if not self.user:
            return f"mongodb://{address}/"
        credentials = f"{quote_plus(self.user)}:{quote_plus(self.password or '')}"
        return f"mongodb://{credentials}@{address}/"

    async def connect(self) -> None:
        """Create the client (10 s server-selection timeout) and select ``self.database``."""
        loop = asyncio.get_running_loop()
        connection_string = self._connection_string()
        self._client = await loop.run_in_executor(
            None,
            lambda: MongoClient(connection_string, serverSelectionTimeoutMS=10000),
        )
        self._db = self._client[self.database]

    async def disconnect(self) -> None:
        """Close the client, if any, and forget the selected database."""
        if self._client:
            self._client.close()
            self._client = None
            self._db = None

    async def get_schema(self) -> dict:
        """Infer a schema by sampling each collection.

        MongoDB has no fixed schema, so up to 5 documents per collection are
        read. Each field's "type" is the sorted, comma-joined set of Python type
        names observed ("null" for None values). All fields are reported as
        nullable.

        Raises:
            RuntimeError: if the connector is not connected.
        """
        if not self._client or self._db is None:
            raise RuntimeError("MongoDB 未连接")

        loop = asyncio.get_running_loop()

        def _run():
            tables = []
            for coll_name in self._db.list_collection_names():
                # Sample at most 5 documents to infer the fields.
                docs = list(self._db[coll_name].find().limit(5))
                observed: dict[str, set] = {}
                for doc in docs:
                    for key in doc.keys():
                        val = doc[key]
                        observed.setdefault(key, set()).add(
                            type(val).__name__ if val is not None else "null"
                        )

                columns = [
                    {
                        "name": col_name,
                        "type": ", ".join(sorted(types)),
                        "nullable": True,
                        "key": "_id" if col_name == "_id" else "",
                        "default": None,
                        "comment": "",
                    }
                    for col_name, types in observed.items()
                ]
                tables.append({"name": coll_name, "columns": columns})

            return {"db_type": "mongodb", "tables": tables}

        return await loop.run_in_executor(None, _run)

    @staticmethod
    def _fill_rows(result: QueryResult, docs: list) -> None:
        """Store ``docs`` in ``result`` as a table.

        The columns are the sorted union of all keys. A document missing a
        column gets None in that position.
        """
        columns = sorted({key for doc in docs for key in doc.keys()})
        result.columns = columns
        result.rows = [[doc.get(col, None) for col in columns] for doc in docs]
        result.row_count = len(docs)
        result.sql_type = "SELECT"

    async def execute(self, sql: str) -> QueryResult:
        """Run a MongoDB request encoded as a JSON object.

        ``sql`` must be a JSON object with a ``"collection"`` key and an
        ``"operation"`` (default ``"find"``). Supported operations:

        - ``find``: ``filter``, ``projection``, ``sort``, and ``limit``
          (default 100; falsy means no limit).
        - ``insert``: ``documents``, either an object or a list.
        - ``update``: ``filter``, ``update``, ``many``.
        - ``delete``: ``filter``, ``many``.
        - ``aggregate``: ``pipeline``.
        - ``count``: ``filter``.

        Unlike the SQL connectors, failures are not raised. Invalid JSON, an
        unsupported operation or a driver error is returned as a single-row
        result with an ``"error"`` column.

        Raises:
            RuntimeError: if the connector is not connected.
        """
        if not self._client or self._db is None:
            raise RuntimeError("MongoDB 未连接")

        loop = asyncio.get_running_loop()
        result = QueryResult()

        def _run():
            try:
                query = json.loads(sql)
                collection_name = query.get("collection", "")
                operation = query.get("operation", "find").lower()
                collection = self._db[collection_name]

                if operation == "find":
                    filter_obj = query.get("filter", {})
                    projection = query.get("projection", None)
                    sort_list = query.get("sort", None)
                    limit_val = query.get("limit", 100)

                    cursor = collection.find(filter_obj, projection)
                    if sort_list:
                        cursor = cursor.sort(sort_list)
                    if limit_val:
                        cursor = cursor.limit(limit_val)
                    self._fill_rows(result, list(cursor))

                elif operation == "insert":
                    documents = query.get("documents", [])
                    if isinstance(documents, dict):
                        documents = [documents]
                    result_obj = collection.insert_many(documents)
                    result.affected_rows = len(result_obj.inserted_ids)
                    result.sql_type = "INSERT"

                elif operation == "update":
                    filter_obj = query.get("filter", {})
                    update_obj = query.get("update", {})
                    if query.get("many", False):
                        update_result = collection.update_many(filter_obj, update_obj)
                    else:
                        update_result = collection.update_one(filter_obj, update_obj)
                    result.affected_rows = update_result.modified_count
                    result.sql_type = "UPDATE"

                elif operation == "delete":
                    filter_obj = query.get("filter", {})
                    if query.get("many", False):
                        delete_result = collection.delete_many(filter_obj)
                    else:
                        delete_result = collection.delete_one(filter_obj)
                    result.affected_rows = delete_result.deleted_count
                    result.sql_type = "DELETE"

                elif operation == "aggregate":
                    pipeline = query.get("pipeline", [])
                    self._fill_rows(result, list(collection.aggregate(pipeline)))

                elif operation == "count":
                    filter_obj = query.get("filter", {})
                    count = collection.count_documents(filter_obj)
                    result.columns = ["count"]
                    result.rows = [[count]]
                    result.row_count = 1
                    result.sql_type = "SELECT"

                else:
                    result.columns = ["error"]
                    result.rows = [[f"不支持的操作: {operation}"]]
                    result.row_count = 1

            except json.JSONDecodeError as e:
                result.columns = ["error"]
                result.rows = [[f"JSON 解析错误: {e}"]]
                result.row_count = 1
            except Exception as e:
                # Deliberate: errors are reported in-band as an "error" row.
                result.columns = ["error"]
                result.rows = [[str(e)]]
                result.row_count = 1

            return result

        return await loop.run_in_executor(None, _run)

    async def test_connection(self) -> dict:
        """Connect, send a ``ping`` to the admin database, then disconnect.

        Returns:
            A ``format_connector_result`` dict. Failures are reported with code
            "CONNECTION_FAILED" instead of being raised.
        """
        from .exceptions import format_connector_result
        try:
            await self.connect()
            if self._client:
                client = self._client
                loop = asyncio.get_running_loop()
                await loop.run_in_executor(None, lambda: client.admin.command("ping"))
            return format_connector_result(True, data={"message": "MongoDB 连接成功"}, db_type="mongodb")
        except Exception as e:
            return format_connector_result(False, error=str(e), db_type="mongodb", code="CONNECTION_FAILED")
        finally:
            await self.disconnect()
@@ -0,0 +1,110 @@
1
+ """MySQL 连接器"""
2
+
3
+ import asyncio
4
+ from typing import Any
5
+
6
+ import pymysql
7
+
8
+ from .base import BaseConnector, QueryResult
9
+
10
+
11
class MySQLConnector(BaseConnector):
    """MySQL connector built on PyMySQL.

    Blocking PyMySQL calls are run in the event loop's default executor.
    """

    db_type = "mysql"

    def __init__(self, host: str, port: int, user: str, password: str, database: str):
        super().__init__(host, port, user, password, database)
        self._conn: pymysql.Connection | None = None

    async def connect(self) -> None:
        """Open a PyMySQL connection (utf8mb4, tuple cursor, 10 s connect timeout)."""
        loop = asyncio.get_running_loop()
        self._conn = await loop.run_in_executor(
            None,
            lambda: pymysql.connect(
                host=self.host,
                port=self.port,
                user=self.user,
                password=self.password,
                database=self.database,
                charset="utf8mb4",
                cursorclass=pymysql.cursors.Cursor,
                connect_timeout=10,
            ),
        )

    async def disconnect(self) -> None:
        """Close the connection if one is open."""
        if self._conn:
            self._conn.close()
            self._conn = None

    async def execute(self, sql: str) -> QueryResult:
        """Execute one or more ``;``-separated statements, then commit.

        For each statement, ``sql_type`` is set from ``classify_sql``.

        - A statement that produces a result set (SELECT, SHOW, CTEs, ...) sets
          ``columns``, ``rows`` and ``row_count``.
        - Any other statement sets ``affected_rows``.

        With several statements, each field reflects the last statement that
        set it.

        If any statement fails, the transaction is rolled back and the error is
        re-raised. This prevents a later ``execute`` from committing earlier,
        half-applied statements. MySQL implicitly commits around DDL, so DDL
        that already ran cannot be undone.

        Raises:
            RuntimeError: if the connector is not connected.
        """
        if not self._conn:
            raise RuntimeError("MySQL 未连接")

        loop = asyncio.get_running_loop()
        result = QueryResult()
        conn = self._conn

        def _run():
            # NOTE(review): this naive split also breaks on ';' inside string
            # literals or comments.
            statements = [s.strip() for s in sql.split(";") if s.strip()]
            if not statements:
                return result

            try:
                with conn.cursor() as cursor:
                    for stmt in statements:
                        cursor.execute(stmt)
                        result.sql_type = self.classify_sql(stmt)

                        # Fetch whenever the server returned a result set, not
                        # only for statements classified as SELECT.
                        if cursor.description:
                            result.columns = [col[0] for col in cursor.description]
                            result.rows = [list(row) for row in cursor.fetchall()]
                            result.row_count = len(result.rows)
                        else:
                            result.affected_rows = cursor.rowcount
                conn.commit()
            except Exception:
                try:
                    conn.rollback()
                except pymysql.MySQLError:
                    # The connection itself is likely broken; the original error
                    # is the one worth surfacing.
                    pass
                raise
            return result

        return await loop.run_in_executor(None, _run)

    async def get_schema(self) -> dict:
        """Describe every table via ``SHOW TABLES`` and ``SHOW FULL COLUMNS``.

        Raises:
            RuntimeError: if the connector is not connected.
        """
        if not self._conn:
            raise RuntimeError("MySQL 未连接")

        loop = asyncio.get_running_loop()
        conn = self._conn

        def _run():
            tables = []
            with conn.cursor() as cursor:
                cursor.execute("SHOW TABLES")
                # Fetch all names first: the cursor is reused below.
                table_rows = cursor.fetchall()
                for (table_name,) in table_rows:
                    # Double embedded backticks so the name stays one quoted identifier.
                    quoted = table_name.replace("`", "``")
                    cursor.execute(f"SHOW FULL COLUMNS FROM `{quoted}`")
                    # SHOW FULL COLUMNS row layout: Field, Type, Collation, Null,
                    # Key, Default, Extra, Privileges, Comment.
                    columns = [
                        {
                            "name": col[0],
                            "type": col[1],
                            "nullable": col[3] == "YES",
                            "key": col[4] or "",
                            "default": str(col[5]) if col[5] is not None else None,
                            "comment": col[8] or "",
                        }
                        for col in cursor.fetchall()
                    ]
                    tables.append({"name": table_name, "columns": columns})
            return {"db_type": "mysql", "tables": tables}

        return await loop.run_in_executor(None, _run)

    async def test_connection(self) -> dict:
        """Connect, ping the server without reconnecting, then disconnect.

        Returns:
            A ``format_connector_result`` dict. Failures are reported with code
            "CONNECTION_FAILED" instead of being raised.
        """
        from .exceptions import format_connector_result
        try:
            await self.connect()
            if self._conn:
                conn = self._conn
                loop = asyncio.get_running_loop()
                await loop.run_in_executor(None, lambda: conn.ping(reconnect=False))
            return format_connector_result(True, data={"message": "MySQL 连接成功"}, db_type="mysql")
        except Exception as e:
            return format_connector_result(False, error=str(e), db_type="mysql", code="CONNECTION_FAILED")
        finally:
            await self.disconnect()