sql-assistant 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64):
  1. sql_assistant/__init__.py +3 -0
  2. sql_assistant/api/__init__.py +1 -0
  3. sql_assistant/api/backup.py +116 -0
  4. sql_assistant/api/config.py +183 -0
  5. sql_assistant/api/conversation.py +71 -0
  6. sql_assistant/api/dependencies.py +22 -0
  7. sql_assistant/api/history.py +61 -0
  8. sql_assistant/api/models.py +221 -0
  9. sql_assistant/api/query.py +275 -0
  10. sql_assistant/api/routes.py +19 -0
  11. sql_assistant/api/schema.py +21 -0
  12. sql_assistant/config.py +144 -0
  13. sql_assistant/database/__init__.py +1 -0
  14. sql_assistant/database/backup.py +568 -0
  15. sql_assistant/database/connectors/__init__.py +1 -0
  16. sql_assistant/database/connectors/base.py +185 -0
  17. sql_assistant/database/connectors/exceptions.py +88 -0
  18. sql_assistant/database/connectors/mongodb.py +194 -0
  19. sql_assistant/database/connectors/mysql.py +110 -0
  20. sql_assistant/database/connectors/postgresql.py +133 -0
  21. sql_assistant/database/connectors/redis.py +132 -0
  22. sql_assistant/database/connectors/sqlserver.py +140 -0
  23. sql_assistant/database/history.py +290 -0
  24. sql_assistant/database/manager.py +178 -0
  25. sql_assistant/database/security.py +230 -0
  26. sql_assistant/llm/__init__.py +1 -0
  27. sql_assistant/llm/base.py +28 -0
  28. sql_assistant/llm/exceptions.py +96 -0
  29. sql_assistant/llm/manager.py +82 -0
  30. sql_assistant/llm/prompts.py +29 -0
  31. sql_assistant/llm/providers/__init__.py +1 -0
  32. sql_assistant/llm/providers/claude.py +132 -0
  33. sql_assistant/llm/providers/gemini.py +127 -0
  34. sql_assistant/llm/providers/openai_compatible.py +103 -0
  35. sql_assistant/llm/retry.py +88 -0
  36. sql_assistant/main.py +94 -0
  37. sql_assistant/settings.py +219 -0
  38. sql_assistant/web/__init__.py +1 -0
  39. sql_assistant/web/static/css/base.css +25 -0
  40. sql_assistant/web/static/css/components/backup.css +146 -0
  41. sql_assistant/web/static/css/components/chat.css +465 -0
  42. sql_assistant/web/static/css/components/modal.css +143 -0
  43. sql_assistant/web/static/css/components/settings.css +358 -0
  44. sql_assistant/web/static/css/components/sidebar.css +235 -0
  45. sql_assistant/web/static/css/components/toast.css +30 -0
  46. sql_assistant/web/static/css/style.css +10 -0
  47. sql_assistant/web/static/css/theme.css +200 -0
  48. sql_assistant/web/static/js/api.js +38 -0
  49. sql_assistant/web/static/js/app.js +161 -0
  50. sql_assistant/web/static/js/backup.js +216 -0
  51. sql_assistant/web/static/js/chat.js +238 -0
  52. sql_assistant/web/static/js/color-theme-manager.js +121 -0
  53. sql_assistant/web/static/js/confirm.js +95 -0
  54. sql_assistant/web/static/js/conversations.js +182 -0
  55. sql_assistant/web/static/js/settings.js +425 -0
  56. sql_assistant/web/static/js/state.js +43 -0
  57. sql_assistant/web/static/js/theme-manager.js +64 -0
  58. sql_assistant/web/static/js/ui.js +53 -0
  59. sql_assistant/web/templates/index.html +373 -0
  60. sql_assistant-1.0.0.dist-info/METADATA +24 -0
  61. sql_assistant-1.0.0.dist-info/RECORD +64 -0
  62. sql_assistant-1.0.0.dist-info/WHEEL +4 -0
  63. sql_assistant-1.0.0.dist-info/entry_points.txt +2 -0
  64. sql_assistant-1.0.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,275 @@
1
+ """查询相关 API 路由"""
2
+
3
+ import hashlib
4
+ from typing import Any
5
+
6
+ from fastapi import APIRouter, HTTPException
7
+
8
+ from ..llm.prompts import SYSTEM_PROMPT, NL_TO_SQL_PROMPT
9
+ from ..database.connectors.base import QueryResult
10
+ from ..database.security import get_security_guard
11
+
12
+ from .dependencies import get_config, get_llm, get_db, get_history
13
+ from .models import QueryRequest, QueryResponse, SQLPreviewResponse
14
+
15
# Sub-router for the query endpoints (POST /query/preview and POST /query).
router = APIRouter()
16
+
17
+
18
# Keywords that can start the first line of the actual statement in LLM output.
_SQL_KEYWORDS = (
    "SELECT", "INSERT", "UPDATE", "DELETE", "WITH", "CREATE", "ALTER",
    "DROP", "TRUNCATE", "SHOW", "DESCRIBE", "EXPLAIN", "USE", "SET",
    "GRANT", "REVOKE", "BEGIN", "COMMIT", "ROLLBACK",
)
# Line prefixes skipped as comments: SQL "--" and "/*", plus MySQL-style "#".
_SQL_COMMENT_PREFIXES = ("--", "#", "/*")
# Besides whitespace or end of line, a keyword may be directly followed by one of
# these characters and still count as a whole keyword (e.g. "BEGIN;", "SELECT*").
_SQL_KEYWORD_FOLLOWERS = ";(*"


def _starts_with_sql_keyword(line: str) -> bool:
    """Return True if *line* begins with a whole SQL keyword (case-insensitive)."""
    upper = line.upper()
    for keyword in _SQL_KEYWORDS:
        if not upper.startswith(keyword):
            continue
        rest = upper[len(keyword):]
        # Reject prefixes of longer words such as "SETTINGS" or "USERS".
        if not rest or rest[0].isspace() or rest[0] in _SQL_KEYWORD_FOLLOWERS:
            return True
    return False


def _strip_code_fence(text: str) -> str:
    """Return the stripped contents of the first Markdown code fence in *text*.

    The fence may appear anywhere; prose before and after it is discarded. The
    opening line, including a language tag such as ``sql``, is dropped, and an
    unterminated fence (e.g. truncated output) keeps everything after the
    opening line. Text without a fence is returned unchanged.
    """
    open_idx = text.find("```")
    if open_idx == -1:
        return text
    newline_idx = text.find("\n", open_idx)
    if newline_idx == -1:
        # Single-line fence such as ```SELECT 1```: keep what follows the opener.
        body = text[open_idx + 3:]
    else:
        body = text[newline_idx + 1:]
    close_idx = body.find("```")
    if close_idx != -1:
        body = body[:close_idx]
    return body.strip()


def _extract_sql(text: str) -> str:
    """Extract the SQL statement from raw LLM output.

    Handles the common shapes of LLM replies:

    - A Markdown code fence (```sql ... ```), possibly surrounded by prose;
      the contents of the first fence are used.
    - Leading comment lines, e.g. ``-- comment`` above ``DELETE FROM ...``.
    - An explanatory prefix line, e.g. ``Here is the SQL:`` above ``SELECT ...``.

    Leading lines are dropped up to the first one that starts with a known SQL
    keyword. If no such line exists (e.g. a non-SQL query language), the
    fence-stripped text is returned as is.
    """
    text = _strip_code_fence(text.strip())

    lines = text.split("\n")
    for idx, line in enumerate(lines):
        stripped = line.strip()
        if not stripped or stripped.startswith(_SQL_COMMENT_PREFIXES):
            continue
        # Any other line that is not a statement start is treated as explanatory
        # prose and skipped by falling through to the next iteration.
        if _starts_with_sql_keyword(stripped):
            if idx > 0:
                text = "\n".join(lines[idx:]).strip()
            break
    return text
64
+
65
+
66
def _result_to_dict(result: Any) -> dict:
    """Convert a connector result into a JSON-friendly dict.

    ``QueryResult`` instances become a dict with the keys ``columns``,
    ``rows``, ``row_count``, ``affected_rows`` and ``sql_type`` (in that
    order); any other value is wrapped as ``{"value": str(result)}``.
    """
    if not isinstance(result, QueryResult):
        return {"value": str(result)}
    exported_fields = ("columns", "rows", "row_count", "affected_rows", "sql_type")
    return {field: getattr(result, field) for field in exported_fields}
76
+
77
+
78
def _hash_sql(sql: str) -> str:
    """Return the hex SHA-256 digest of *sql* after trimming surrounding whitespace."""
    digest = hashlib.sha256()
    digest.update(sql.strip().encode("utf-8"))
    return digest.hexdigest()
80
+
81
+
82
def _validate_request() -> tuple:
    """Fetch the shared services and make sure an LLM and a database are selected.

    Returns:
        A 5-tuple ``(config, llm, db, active_llm, active_db)``.

    Raises:
        HTTPException: 400 when no active LLM provider, or no active database,
            is configured. The LLM check runs first.
    """
    cfg = get_config()
    llm_client = get_llm()
    db_client = get_db()

    selected_llm = cfg.get_active_llm()
    selected_db = cfg.get_active_database()

    # Checked in this order so the LLM message wins when both are missing.
    for selected, missing_detail in (
        (selected_llm, "请先在设置中配置并选择 LLM 提供商"),
        (selected_db, "请先在设置中配置并选择数据库"),
    ):
        if not selected:
            raise HTTPException(status_code=400, detail=missing_detail)

    return cfg, llm_client, db_client, selected_llm, selected_db
96
+
97
+
98
@router.post("/query/preview", response_model=SQLPreviewResponse)
async def preview_query(request: QueryRequest):
    """Generate SQL for ``request.question`` and report its safety, without running it.

    The SQL is produced by the LLM from the database schema text and then
    checked by the security guard. A successful response carries the SQL, its
    ``_hash_sql`` digest, and whether the guard requires explicit confirmation
    (see the ``confirmed`` path of ``POST /query``). Errors raised after the
    schema text has been fetched are returned as ``success=False`` responses
    instead of propagating.
    """
    _config, llm, db, _active_llm, active_db = _validate_request()
    prompt_db_type = request.db_type_override or active_db.db_type
    schema_context = await db.get_schema_text()

    try:
        system_content = SYSTEM_PROMPT.format(
            db_type=prompt_db_type,
            schema_context=schema_context,
        )
        user_content = NL_TO_SQL_PROMPT.format(
            db_type=prompt_db_type,
            question=request.question,
        )
        llm_output = await llm.chat(
            [
                {"role": "system", "content": system_content},
                {"role": "user", "content": user_content},
            ],
            temperature=0.1,
        )
        generated_sql = _extract_sql(llm_output)

        guard = get_security_guard()
        verdict = guard.check_sql_safety(generated_sql)
        if verdict.is_safe:
            needs_confirmation, reason = guard.requires_confirmation(generated_sql)
            return SQLPreviewResponse(
                success=True,
                sql=generated_sql,
                sql_hash=_hash_sql(generated_sql),
                requires_confirmation=needs_confirmation,
                confirmation_reason=reason,
                warning=verdict.warning,
                risk_level=verdict.risk_level,
            )
        return SQLPreviewResponse(
            success=False,
            error=f"SQL 安全检查失败: {verdict.blocked_reason}",
        )
    except Exception as exc:
        return SQLPreviewResponse(
            success=False,
            error=f"生成 SQL 失败: {exc}",
        )
144
+
145
+
146
async def _record_history_best_effort(history: Any, **fields: Any) -> Any:
    """Store a query-history record and return its id, or None if storing fails.

    History is auxiliary. A failure here is logged rather than allowed to mask,
    or be mistaken for, the outcome of the SQL statement itself.
    """
    try:
        return await history.add_record(**fields)
    except Exception:
        import logging

        logging.getLogger(__name__).exception("Failed to record query history")
        return None


def _paginate(result_dict: dict, page: int, page_size: int) -> dict:
    """Trim ``result_dict["rows"]`` to one page in place and return pagination metadata.

    ``row_count`` is updated to the number of rows on the returned page, while
    the metadata's ``total_rows`` keeps the pre-pagination count. Results
    without a positive ``row_count`` (e.g. the ``{"value": ...}`` form) are left
    untouched and report a single page.
    """
    total_rows = result_dict.get("row_count") or 0
    total_pages = (total_rows + page_size - 1) // page_size if total_rows > 0 else 1

    if total_rows > 0:
        # Slice on every page, page 1 included, so no page exceeds page_size rows.
        start = (max(page, 1) - 1) * page_size
        page_rows = result_dict["rows"][start:start + page_size]
        result_dict["rows"] = page_rows
        result_dict["row_count"] = len(page_rows)

    return {
        "page": page,
        "page_size": page_size,
        "total_rows": total_rows,
        "total_pages": total_pages,
    }


@router.post("/query", response_model=QueryResponse)
async def execute_query(request: QueryRequest):
    """Produce (or accept a confirmed) SQL statement, execute it and record history.

    - ``request.confirmed`` false: the LLM generates SQL from ``request.question``.
    - ``request.confirmed`` true: the SQL the client previewed (``request.sql``)
      is executed, optionally checked against ``request.sql_hash``.

    The SQL then goes through the security guard. Statements the guard flags as
    needing confirmation are refused unless ``confirmed`` is set. Most failures
    are reported as ``QueryResponse(success=False, ...)`` rather than raised.
    """
    _, llm, db, active_llm, active_db = _validate_request()
    history = get_history()
    db_type = request.db_type_override or active_db.db_type

    if request.confirmed:
        # Reuse the SQL the user previewed instead of calling the LLM again, whose
        # non-deterministic output could differ from what was approved.
        if not request.sql:
            return QueryResponse(
                success=False,
                question=request.question,
                error="确认执行时需要提供 SQL 语句",
            )
        sql_text = request.sql
        # NOTE(review): the hash is an unkeyed SHA-256 of client-supplied SQL and is
        # optional, so it only catches accidental mismatches, not tampering; the
        # security guard below is what actually gates execution.
        if request.sql_hash and request.sql_hash != _hash_sql(sql_text):
            return QueryResponse(
                success=False,
                question=request.question,
                sql=sql_text,
                error="SQL 验证失败,请重新预览并确认",
            )
    else:
        # The schema is only needed to prompt the LLM, so it is not fetched when confirmed.
        schema_text = await db.get_schema_text()
        try:
            messages = [
                {"role": "system", "content": SYSTEM_PROMPT.format(
                    db_type=db_type,
                    schema_context=schema_text,
                )},
                {"role": "user", "content": NL_TO_SQL_PROMPT.format(
                    db_type=db_type,
                    question=request.question,
                )},
            ]
            sql_text = _extract_sql(await llm.chat(messages, temperature=0.1))
        except Exception as e:
            return QueryResponse(
                success=False,
                question=request.question,
                error=f"LLM 调用失败: {e}",
            )

    security_guard = get_security_guard()
    check_result = security_guard.check_sql_safety(sql_text)
    if not check_result.is_safe:
        return QueryResponse(
            success=False,
            question=request.question,
            sql=sql_text,
            error=f"SQL 安全检查失败: {check_result.blocked_reason}",
        )

    requires_confirmation, _ = security_guard.requires_confirmation(sql_text)
    if requires_confirmation and not request.confirmed:
        return QueryResponse(
            success=False,
            question=request.question,
            sql=sql_text,
            error="此操作需要用户确认,请预览后确认执行",
        )

    # Only the statement itself is guarded here. Errors in post-processing or in the
    # history store must not report an already-executed statement (possibly a
    # write) as failed.
    try:
        result = await db.execute(sql_text)
    except Exception as e:
        history_id = await _record_history_best_effort(
            history,
            question=request.question,
            sql=sql_text,
            db_type=db_type,
            llm_provider=active_llm.provider,
            success=False,
            error_message=str(e),
            conversation_id=request.conversation_id,
        )
        return QueryResponse(
            success=False,
            question=request.question,
            sql=sql_text,
            error=f"SQL 执行失败: {e}",
            history_id=history_id,
            conversation_id=request.conversation_id,
        )

    result_dict = _result_to_dict(result)
    pagination = _paginate(result_dict, request.page, request.page_size)

    # NOTE(review): history_id is None if recording failed; other responses leave
    # history_id at its model default, which is assumed to be None — confirm in models.
    history_id = await _record_history_best_effort(
        history,
        question=request.question,
        sql=sql_text,
        result=result_dict,
        db_type=db_type,
        llm_provider=active_llm.provider,
        success=True,
        conversation_id=request.conversation_id,
    )

    return QueryResponse(
        success=True,
        question=request.question,
        sql=sql_text,
        result=result_dict,
        history_id=history_id,
        conversation_id=request.conversation_id,
        pagination=pagination,
    )
@@ -0,0 +1,19 @@
1
+ """API 路由"""
2
+
3
+ from fastapi import APIRouter
4
+
5
+ from .query import router as query_router
6
+ from .config import router as config_router
7
+ from .conversation import router as conversation_router
8
+ from .backup import router as backup_router
9
+ from .schema import router as schema_router
10
+ from .history import router as history_router
11
+
12
router = APIRouter(prefix="/api")

# Mount each feature router under the shared "/api" prefix, in a fixed order.
for _feature_router in (
    query_router,
    config_router,
    conversation_router,
    backup_router,
    schema_router,
    history_router,
):
    router.include_router(_feature_router)
del _feature_router
@@ -0,0 +1,21 @@
1
+ """Schema 相关 API 路由"""
2
+
3
+ from fastapi import APIRouter
4
+
5
+ from .dependencies import get_db
6
+
7
# Sub-router for the schema endpoints (GET /schema and POST /schema/refresh).
router = APIRouter()
8
+
9
+
10
@router.post("/schema/refresh")
async def refresh_schema():
    """Call ``refresh_schema()`` on the handle from ``get_db()`` and return its result as-is."""
    return await get_db().refresh_schema()
15
+
16
+
17
@router.get("/schema")
async def get_schema():
    """Call ``get_schema()`` on the handle from ``get_db()`` and return its result as-is."""
    return await get_db().get_schema()
@@ -0,0 +1,144 @@
1
+ """配置管理 - YAML 文件读写"""
2
+
3
+ import os
4
+ import yaml
5
+ from pathlib import Path
6
+ from typing import Optional
7
+
8
+ from .settings import AppSettings, LLMProviderConfig, DatabaseConfig
9
+
10
# Configuration files live in a ".data" folder under PROJECT_DIR.
# PROJECT_DIR is three levels above this file's resolved path, i.e. the parent of
# the directory that contains the sql_assistant package.
# NOTE(review): for an installed wheel that is the parent of site-packages
# (e.g. <venv>/lib/pythonX.Y), not a project checkout — confirm this location is
# intended and writable; ConfigManager creates the directory when constructed.
PROJECT_DIR = Path(__file__).resolve().parent.parent.parent
DEFAULT_CONFIG_DIR = PROJECT_DIR / ".data"
DEFAULT_CONFIG_FILE = DEFAULT_CONFIG_DIR / "config.yaml"
14
+
15
+
16
class ConfigManager:
    """Unified configuration manager backed by a single YAML file.

    Holds the current ``AppSettings`` in memory and offers CRUD helpers for LLM
    providers and database connections. Every mutating helper writes the whole
    settings object back to ``config_path`` via ``save()``.
    """

    def __init__(self, config_path: Optional[Path] = None):
        """Create a manager for *config_path* (default: ``DEFAULT_CONFIG_FILE``).

        Settings start as defaults; call ``load()`` to read the file. The
        config file's parent directory is created immediately.
        """
        self.config_path = Path(config_path) if config_path else DEFAULT_CONFIG_FILE
        self._settings: AppSettings = AppSettings()
        self._ensure_config_exists()

    def _ensure_config_exists(self):
        """Make sure the config file's directory exists (the file itself is not created)."""
        self.config_path.parent.mkdir(parents=True, exist_ok=True)

    # ---- Load / save ----

    def load(self) -> AppSettings:
        """Load settings from the YAML file and return them.

        A missing file yields default settings. An unreadable or invalid file
        also yields defaults so the app can still start. In that case the file
        is first copied to ``<name>.bak``, because the next ``save()`` would
        otherwise silently overwrite the user's configuration.
        """
        try:
            with open(self.config_path, "r", encoding="utf-8") as f:
                data = yaml.safe_load(f) or {}
            self._settings = AppSettings(**data)
        except FileNotFoundError:
            self._settings = AppSettings()
        except Exception:
            import logging

            logging.getLogger(__name__).exception(
                "Could not load config file %s; falling back to defaults",
                self.config_path,
            )
            self._backup_unreadable_config()
            self._settings = AppSettings()
        return self._settings

    def _backup_unreadable_config(self) -> None:
        """Best-effort copy of the current config file to ``<name>.bak`` next to it."""
        backup_path = self.config_path.with_name(self.config_path.name + ".bak")
        try:
            backup_path.write_bytes(self.config_path.read_bytes())
        except OSError:
            import logging

            logging.getLogger(__name__).exception(
                "Could not back up config file %s", self.config_path
            )

    def save(self) -> None:
        """Write the current settings to the YAML file atomically.

        The YAML is written to a sibling ``.tmp`` file, which then replaces
        ``config_path`` via ``os.replace``. A crash mid-write therefore cannot
        leave a truncated config behind. The permission bits of an existing
        config file are carried over to the replacement.
        """
        self._ensure_config_exists()
        data = self._settings.model_dump(
            exclude={"llm_providers": {"__all__": {"DEFAULT_BASE_URLS", "DEFAULT_MODELS"}},
                     "databases": {"__all__": {"DEFAULT_PORTS"}}}
        )
        tmp_path = self.config_path.with_name(self.config_path.name + ".tmp")
        try:
            with open(tmp_path, "w", encoding="utf-8") as f:
                yaml.safe_dump(data, f, allow_unicode=True, default_flow_style=False, sort_keys=False)
            try:
                # Keep e.g. a user-applied 0600 mode on the existing file.
                os.chmod(tmp_path, self.config_path.stat().st_mode & 0o7777)
            except OSError:
                pass  # first save (no existing file) or chmod unsupported: keep defaults
            os.replace(tmp_path, self.config_path)
        finally:
            # No-op after a successful replace; removes a partial file otherwise.
            tmp_path.unlink(missing_ok=True)

    def get_settings(self) -> AppSettings:
        """Return the in-memory settings object (not a copy)."""
        return self._settings

    @staticmethod
    def _upsert_by_name(items: list, config: "LLMProviderConfig | DatabaseConfig") -> None:
        """Replace the first item in *items* whose ``name`` matches *config*, else append it."""
        for idx, item in enumerate(items):
            if item.name == config.name:
                items[idx] = config
                return
        items.append(config)

    # ---- LLM Provider CRUD ----

    def get_llm_providers(self) -> list[LLMProviderConfig]:
        """Return the configured LLM providers (the live list, not a copy)."""
        return self._settings.llm_providers

    def get_llm_provider(self, name: str) -> Optional[LLMProviderConfig]:
        """Return the first LLM provider named *name*, or None."""
        for p in self._settings.llm_providers:
            if p.name == name:
                return p
        return None

    def add_llm_provider(self, config: LLMProviderConfig) -> None:
        """Add *config*, or replace the provider with the same name, then save."""
        self._upsert_by_name(self._settings.llm_providers, config)
        self.save()

    def remove_llm_provider(self, name: str) -> bool:
        """Remove the provider named *name*, clearing it as active if selected.

        Returns True (and saves) if a provider was removed, else False.
        """
        provider = self.get_llm_provider(name)
        if provider:
            self._settings.llm_providers.remove(provider)
            if self._settings.active_llm == name:
                self._settings.active_llm = ""
            self.save()
            return True
        return False

    def get_active_llm(self) -> Optional[LLMProviderConfig]:
        """Return the provider whose name is ``settings.active_llm``, or None."""
        return self.get_llm_provider(self._settings.active_llm)

    def set_active_llm(self, name: str) -> bool:
        """Select provider *name* as active and save; return False if it does not exist."""
        if self.get_llm_provider(name):
            self._settings.active_llm = name
            self.save()
            return True
        return False

    # ---- Database CRUD ----

    def get_databases(self) -> list[DatabaseConfig]:
        """Return the configured databases (the live list, not a copy)."""
        return self._settings.databases

    def get_database(self, name: str) -> Optional[DatabaseConfig]:
        """Return the first database config named *name*, or None."""
        for db in self._settings.databases:
            if db.name == name:
                return db
        return None

    def add_database(self, config: DatabaseConfig) -> None:
        """Add *config*, or replace the database with the same name, then save."""
        self._upsert_by_name(self._settings.databases, config)
        self.save()

    def remove_database(self, name: str) -> bool:
        """Remove the database named *name*, clearing it as active if selected.

        Returns True (and saves) if a database was removed, else False.
        """
        db = self.get_database(name)
        if db:
            self._settings.databases.remove(db)
            if self._settings.active_database == name:
                self._settings.active_database = ""
            self.save()
            return True
        return False

    def get_active_database(self) -> Optional[DatabaseConfig]:
        """Return the database whose name is ``settings.active_database``, or None."""
        return self.get_database(self._settings.active_database)

    def set_active_database(self, name: str) -> bool:
        """Select database *name* as active and save; return False if it does not exist."""
        if self.get_database(name):
            self._settings.active_database = name
            self.save()
            return True
        return False
132
+
133
+
134
# Global singleton, created lazily by get_config_manager().
_config_instance: Optional[ConfigManager] = None


def get_config_manager() -> ConfigManager:
    """Return the global ConfigManager, creating it and loading its file on first call."""
    global _config_instance
    if _config_instance is not None:
        return _config_instance
    _config_instance = ConfigManager()
    _config_instance.load()
    return _config_instance
@@ -0,0 +1 @@
1
+ """数据库连接器模块"""