sql-assistant 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sql_assistant/__init__.py +3 -0
- sql_assistant/api/__init__.py +1 -0
- sql_assistant/api/backup.py +116 -0
- sql_assistant/api/config.py +183 -0
- sql_assistant/api/conversation.py +71 -0
- sql_assistant/api/dependencies.py +22 -0
- sql_assistant/api/history.py +61 -0
- sql_assistant/api/models.py +221 -0
- sql_assistant/api/query.py +275 -0
- sql_assistant/api/routes.py +19 -0
- sql_assistant/api/schema.py +21 -0
- sql_assistant/config.py +144 -0
- sql_assistant/database/__init__.py +1 -0
- sql_assistant/database/backup.py +568 -0
- sql_assistant/database/connectors/__init__.py +1 -0
- sql_assistant/database/connectors/base.py +185 -0
- sql_assistant/database/connectors/exceptions.py +88 -0
- sql_assistant/database/connectors/mongodb.py +194 -0
- sql_assistant/database/connectors/mysql.py +110 -0
- sql_assistant/database/connectors/postgresql.py +133 -0
- sql_assistant/database/connectors/redis.py +132 -0
- sql_assistant/database/connectors/sqlserver.py +140 -0
- sql_assistant/database/history.py +290 -0
- sql_assistant/database/manager.py +178 -0
- sql_assistant/database/security.py +230 -0
- sql_assistant/llm/__init__.py +1 -0
- sql_assistant/llm/base.py +28 -0
- sql_assistant/llm/exceptions.py +96 -0
- sql_assistant/llm/manager.py +82 -0
- sql_assistant/llm/prompts.py +29 -0
- sql_assistant/llm/providers/__init__.py +1 -0
- sql_assistant/llm/providers/claude.py +132 -0
- sql_assistant/llm/providers/gemini.py +127 -0
- sql_assistant/llm/providers/openai_compatible.py +103 -0
- sql_assistant/llm/retry.py +88 -0
- sql_assistant/main.py +94 -0
- sql_assistant/settings.py +219 -0
- sql_assistant/web/__init__.py +1 -0
- sql_assistant/web/static/css/base.css +25 -0
- sql_assistant/web/static/css/components/backup.css +146 -0
- sql_assistant/web/static/css/components/chat.css +465 -0
- sql_assistant/web/static/css/components/modal.css +143 -0
- sql_assistant/web/static/css/components/settings.css +358 -0
- sql_assistant/web/static/css/components/sidebar.css +235 -0
- sql_assistant/web/static/css/components/toast.css +30 -0
- sql_assistant/web/static/css/style.css +10 -0
- sql_assistant/web/static/css/theme.css +200 -0
- sql_assistant/web/static/js/api.js +38 -0
- sql_assistant/web/static/js/app.js +161 -0
- sql_assistant/web/static/js/backup.js +216 -0
- sql_assistant/web/static/js/chat.js +238 -0
- sql_assistant/web/static/js/color-theme-manager.js +121 -0
- sql_assistant/web/static/js/confirm.js +95 -0
- sql_assistant/web/static/js/conversations.js +182 -0
- sql_assistant/web/static/js/settings.js +425 -0
- sql_assistant/web/static/js/state.js +43 -0
- sql_assistant/web/static/js/theme-manager.js +64 -0
- sql_assistant/web/static/js/ui.js +53 -0
- sql_assistant/web/templates/index.html +373 -0
- sql_assistant-1.0.0.dist-info/METADATA +24 -0
- sql_assistant-1.0.0.dist-info/RECORD +64 -0
- sql_assistant-1.0.0.dist-info/WHEEL +4 -0
- sql_assistant-1.0.0.dist-info/entry_points.txt +2 -0
- sql_assistant-1.0.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,275 @@
|
|
|
1
|
+
"""查询相关 API 路由"""
|
|
2
|
+
|
|
3
|
+
import hashlib
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
from fastapi import APIRouter, HTTPException
|
|
7
|
+
|
|
8
|
+
from ..llm.prompts import SYSTEM_PROMPT, NL_TO_SQL_PROMPT
|
|
9
|
+
from ..database.connectors.base import QueryResult
|
|
10
|
+
from ..database.security import get_security_guard
|
|
11
|
+
|
|
12
|
+
from .dependencies import get_config, get_llm, get_db, get_history
|
|
13
|
+
from .models import QueryRequest, QueryResponse, SQLPreviewResponse
|
|
14
|
+
|
|
15
|
+
# Router that the query endpoints defined in this module register on.
router = APIRouter()
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def _extract_sql(text: str) -> str:
    """Extract the SQL statement from raw LLM output.

    Handles the common response shapes:

    - a Markdown fenced block (```sql ... ```), with or without prose
      before or after the fence;
    - leading comment lines: ``-- note`` followed by ``DELETE FROM ...``;
    - an explanatory prefix: ``Here is the SQL:`` followed by ``SELECT ...``.

    Args:
        text: Raw text returned by the LLM.

    Returns:
        The text starting at the first line that begins with a known SQL
        keyword. When no such line exists (or it is already the first
        line), the stripped, un-fenced text is returned as is.
    """
    text = text.strip()

    # 1. Unwrap a Markdown code fence.
    fence = "```"
    open_at = text.find(fence)
    if open_at != -1:
        # End of the opening fence line (which may carry a language tag).
        header_end = text.find("\n", open_at + len(fence))
        close_at = text.find(fence, header_end + 1) if header_end != -1 else -1
        if close_at != -1:
            # Properly paired fence: keep only what is inside it, even when
            # prose precedes the opening fence or follows the closing one.
            text = text[header_end + 1:close_at].strip()
        elif open_at == 0:
            # Opening fence with no closing fence: drop the fence line only.
            text = text[header_end + 1:].strip() if header_end != -1 else ""
        # A stray, unpaired fence elsewhere in the text is left untouched.

    # 2. If the text still does not start with a SQL keyword, locate the
    #    first line that does.
    sql_keywords = (
        "SELECT", "INSERT", "UPDATE", "DELETE", "WITH", "CREATE", "ALTER",
        "DROP", "TRUNCATE", "SHOW", "DESCRIBE", "EXPLAIN", "USE", "SET",
        "GRANT", "REVOKE", "BEGIN", "COMMIT", "ROLLBACK",
    )
    lines = text.split("\n")
    start_idx = 0
    for i, line in enumerate(lines):
        stripped = line.strip()
        # Skip blank lines and comment lines.
        if not stripped or stripped.startswith(("--", "#", "/*")):
            continue
        upper = stripped.upper()
        # A keyword counts only as a whole word: alone on the line, or
        # followed by a space or a tab.
        if any(upper == kw or upper.startswith((kw + " ", kw + "\t"))
               for kw in sql_keywords):
            start_idx = i
            break
        # Neither a comment nor SQL: probably explanatory prose, keep scanning.

    if start_idx > 0:
        text = "\n".join(lines[start_idx:]).strip()

    return text
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def _result_to_dict(result: Any) -> dict:
    """Convert an execution result into a plain dict.

    A ``QueryResult`` is flattened into its ``columns``, ``rows``,
    ``row_count``, ``affected_rows`` and ``sql_type`` attributes; any other
    value is wrapped as ``{"value": str(result)}``.
    """
    if not isinstance(result, QueryResult):
        return {"value": str(result)}

    exported = ("columns", "rows", "row_count", "affected_rows", "sql_type")
    return {attr: getattr(result, attr) for attr in exported}
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def _hash_sql(sql: str) -> str:
    """Return the SHA-256 hex digest of *sql* after trimming surrounding whitespace.

    The trimmed text is encoded as UTF-8 before hashing.
    """
    normalized = sql.strip()
    digest = hashlib.sha256()
    digest.update(normalized.encode("utf-8"))
    return digest.hexdigest()
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def _validate_request() -> tuple:
    """Resolve the shared dependencies and require an active LLM and database.

    Returns:
        ``(config, llm, db, active_llm, active_db)``.

    Raises:
        HTTPException: status 400 when no active LLM provider, or no active
            database, is selected in the configuration.
    """
    config, llm, db = get_config(), get_llm(), get_db()
    active_llm, active_db = config.get_active_llm(), config.get_active_database()

    # Checked in order: the LLM selection is reported before the database one.
    required_selections = (
        (active_llm, "请先在设置中配置并选择 LLM 提供商"),
        (active_db, "请先在设置中配置并选择数据库"),
    )
    for selected, message in required_selections:
        if not selected:
            raise HTTPException(status_code=400, detail=message)

    return config, llm, db, active_llm, active_db
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
@router.post("/query/preview", response_model=SQLPreviewResponse)
async def preview_query(request: QueryRequest):
    """Generate SQL for a natural-language question without executing it.

    The question is sent to the LLM together with the database schema text,
    the SQL is extracted from the reply and passed through the security
    guard.

    Args:
        request: Carries the question and an optional ``db_type_override``.

    Returns:
        A ``SQLPreviewResponse``. On success it holds the SQL, its hash,
        whether confirmation is required (and why), plus the guard's warning
        and risk level. ``success=False`` with an error message is returned
        when the guard rejects the SQL or when any step raises.

    Raises:
        HTTPException: propagated from ``_validate_request`` when no active
            LLM or database is configured.
    """
    _, llm, db, _, active_db = _validate_request()
    db_type = request.db_type_override or active_db.db_type

    try:
        # Fetched inside the try so that a schema failure is reported through
        # the response body like every other failure of this endpoint,
        # instead of escaping as an unhandled server error.
        schema_text = await db.get_schema_text()

        messages = [
            {"role": "system", "content": SYSTEM_PROMPT.format(
                db_type=db_type,
                schema_context=schema_text,
            )},
            {"role": "user", "content": NL_TO_SQL_PROMPT.format(
                db_type=db_type,
                question=request.question,
            )},
        ]
        sql_text = await llm.chat(messages, temperature=0.1)
        sql_text = _extract_sql(sql_text)

        security_guard = get_security_guard()
        check_result = security_guard.check_sql_safety(sql_text)

        if not check_result.is_safe:
            return SQLPreviewResponse(
                success=False,
                error=f"SQL 安全检查失败: {check_result.blocked_reason}"
            )

        requires_confirmation, confirmation_reason = security_guard.requires_confirmation(sql_text)

        return SQLPreviewResponse(
            success=True,
            sql=sql_text,
            # The client sends this hash back when confirming execution.
            sql_hash=_hash_sql(sql_text),
            requires_confirmation=requires_confirmation,
            confirmation_reason=confirmation_reason,
            warning=check_result.warning,
            risk_level=check_result.risk_level
        )

    except Exception as e:
        return SQLPreviewResponse(
            success=False,
            error=f"生成 SQL 失败: {e}"
        )
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
@router.post("/query", response_model=QueryResponse)
async def execute_query(request: QueryRequest):
    """Produce SQL for the request, execute it, and record the outcome in history.

    Flow:
      1. ``confirmed`` requests use the SQL sent by the client (checked
         against ``sql_hash`` when one is supplied); other requests ask the
         LLM to generate the SQL.
      2. The SQL goes through the security guard; unsafe SQL, or SQL that
         needs confirmation on an unconfirmed request, is returned unexecuted.
      3. The SQL is executed, the rows are optionally sliced to the
         requested page, and a history record is written.

    Returns:
        A ``QueryResponse``; every failure path returns ``success=False``
        with an error message rather than raising.

    Raises:
        HTTPException: propagated from ``_validate_request`` when no active
            LLM or database is configured.
    """
    config, llm, db, active_llm, active_db = _validate_request()
    history = get_history()
    db_type = request.db_type_override or active_db.db_type
    # NOTE(review): fetched for every request, but only the LLM branch below
    # reads it; a failure here is not caught by any try block.
    schema_text = await db.get_schema_text()

    sql_text = ""

    # For a confirmed execution (confirmed=True) use the SQL sent by the
    # client directly, so LLM non-determinism cannot cause a hash mismatch.
    if request.confirmed:
        if not request.sql:
            return QueryResponse(
                success=False,
                question=request.question,
                error="确认执行时需要提供 SQL 语句",
            )

        sql_text = request.sql

        # When a sql_hash is supplied, verify the SQL's integrity against it.
        # Without a sql_hash this check is skipped entirely.
        if request.sql_hash:
            expected_hash = _hash_sql(sql_text)
            if request.sql_hash != expected_hash:
                return QueryResponse(
                    success=False,
                    question=request.question,
                    sql=sql_text,
                    error="SQL 验证失败,请重新预览并确认",
                )
    else:
        # Not a confirmed execution: ask the LLM to generate the SQL.
        try:
            messages = [
                {"role": "system", "content": SYSTEM_PROMPT.format(
                    db_type=db_type,
                    schema_context=schema_text,
                )},
                {"role": "user", "content": NL_TO_SQL_PROMPT.format(
                    db_type=db_type,
                    question=request.question,
                )},
            ]
            sql_text = await llm.chat(messages, temperature=0.1)
            sql_text = _extract_sql(sql_text)
        except Exception as e:
            return QueryResponse(
                success=False,
                question=request.question,
                error=f"LLM 调用失败: {e}",
            )

    # The safety check applies to client-supplied and LLM-generated SQL alike.
    security_guard = get_security_guard()
    check_result = security_guard.check_sql_safety(sql_text)
    if not check_result.is_safe:
        return QueryResponse(
            success=False,
            question=request.question,
            sql=sql_text,
            error=f"SQL 安全检查失败: {check_result.blocked_reason}",
        )

    requires_confirmation, _ = security_guard.requires_confirmation(sql_text)

    if requires_confirmation and not request.confirmed:
        return QueryResponse(
            success=False,
            question=request.question,
            sql=sql_text,
            error="此操作需要用户确认,请预览后确认执行",
        )

    try:
        result = await db.execute(sql_text)
        result_dict = _result_to_dict(result)

        # "row_count" is absent when the result was not a QueryResult.
        total_rows = result_dict.get("row_count", 0)
        # Ceiling division; an empty result still counts as one page.
        total_pages = (total_rows + request.page_size - 1) // request.page_size if total_rows > 0 else 1

        # NOTE(review): rows are sliced only for page > 1, so page 1 returns
        # every row rather than the first page_size rows — confirm this is
        # what the client expects.
        if total_rows > 0 and request.page > 1:
            start_idx = (request.page - 1) * request.page_size
            end_idx = start_idx + request.page_size
            result_dict["rows"] = result_dict["rows"][start_idx:end_idx]
            # After slicing, row_count is the size of this page; the overall
            # total is reported through the pagination dict.
            result_dict["row_count"] = len(result_dict["rows"])

        pagination = {
            "page": request.page,
            "page_size": request.page_size,
            "total_rows": total_rows,
            "total_pages": total_pages,
        }

        # NOTE(review): this write sits inside the same try as db.execute, so
        # a history failure after a successful execution is reported below as
        # an SQL execution failure — confirm that is intended.
        history_id = await history.add_record(
            question=request.question,
            sql=sql_text,
            result=result_dict,
            db_type=db_type,
            llm_provider=active_llm.provider,
            success=True,
            conversation_id=request.conversation_id,
        )

        return QueryResponse(
            success=True,
            question=request.question,
            sql=sql_text,
            result=result_dict,
            history_id=history_id,
            conversation_id=request.conversation_id,
            pagination=pagination,
        )
    except Exception as e:
        # Record the failed attempt too, then report the error to the caller.
        history_id = await history.add_record(
            question=request.question,
            sql=sql_text,
            db_type=db_type,
            llm_provider=active_llm.provider,
            success=False,
            error_message=str(e),
            conversation_id=request.conversation_id,
        )
        return QueryResponse(
            success=False,
            question=request.question,
            sql=sql_text,
            error=f"SQL 执行失败: {e}",
            history_id=history_id,
            conversation_id=request.conversation_id,
        )
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
"""API 路由"""
|
|
2
|
+
|
|
3
|
+
from fastapi import APIRouter
|
|
4
|
+
|
|
5
|
+
from .query import router as query_router
|
|
6
|
+
from .config import router as config_router
|
|
7
|
+
from .conversation import router as conversation_router
|
|
8
|
+
from .backup import router as backup_router
|
|
9
|
+
from .schema import router as schema_router
|
|
10
|
+
from .history import router as history_router
|
|
11
|
+
|
|
12
|
+
# Aggregate router: every feature router below is mounted under "/api".
router = APIRouter(prefix="/api")

router.include_router(query_router)
router.include_router(config_router)
router.include_router(conversation_router)
router.include_router(backup_router)
router.include_router(schema_router)
router.include_router(history_router)
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
"""Schema 相关 API 路由"""
|
|
2
|
+
|
|
3
|
+
from fastapi import APIRouter
|
|
4
|
+
|
|
5
|
+
from .dependencies import get_db
|
|
6
|
+
|
|
7
|
+
# Router that the schema endpoints defined in this module register on.
router = APIRouter()
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@router.post("/schema/refresh")
async def refresh_schema():
    """Call ``refresh_schema()`` on the object from ``get_db()`` and return its result."""
    return await get_db().refresh_schema()
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@router.get("/schema")
async def get_schema():
    """Call ``get_schema()`` on the object from ``get_db()`` and return its result."""
    return await get_db().get_schema()
|
sql_assistant/config.py
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
1
|
+
"""配置管理 - YAML 文件读写"""
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
import yaml
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Optional
|
|
7
|
+
|
|
8
|
+
from .settings import AppSettings, LLMProviderConfig, DatabaseConfig
|
|
9
|
+
|
|
10
|
+
# The configuration file lives in a ".data" folder under the project directory.
# PROJECT_DIR is three levels above this file.
# NOTE(review): with the package installed from a wheel (sql_assistant/config.py
# directly under site-packages) three levels up is presumably the directory
# above site-packages rather than a project root — confirm the intended location.
PROJECT_DIR = Path(__file__).resolve().parent.parent.parent
DEFAULT_CONFIG_DIR = PROJECT_DIR / ".data"
DEFAULT_CONFIG_FILE = DEFAULT_CONFIG_DIR / "config.yaml"
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class ConfigManager:
    """Unified configuration manager backed by a YAML file.

    Keeps an ``AppSettings`` instance in memory and writes it back to
    ``config_path`` after every mutating operation.
    """

    def __init__(self, config_path: Optional[Path] = None):
        """Create a manager for *config_path* (default: ``DEFAULT_CONFIG_FILE``).

        The file is not read here; settings start at their defaults until
        ``load()`` is called. The parent directory is created if missing.
        """
        self.config_path = Path(config_path) if config_path else DEFAULT_CONFIG_FILE
        self._settings: AppSettings = AppSettings()
        self._ensure_config_exists()

    def _ensure_config_exists(self):
        """Make sure the directory holding the config file exists."""
        self.config_path.parent.mkdir(parents=True, exist_ok=True)

    # ---- Load / save ----

    def load(self) -> AppSettings:
        """Load settings from the YAML file and return them.

        A missing file is the normal first-run state and quietly yields
        default settings. Any other failure (unreadable file, malformed
        YAML, data rejected by ``AppSettings``) also falls back to defaults
        but is logged with its traceback, because the next ``save()`` will
        overwrite the file with those defaults.
        """
        import logging  # local import keeps the module's import block untouched

        try:
            with open(self.config_path, "r", encoding="utf-8") as f:
                # An empty file parses to None; treat it as an empty mapping.
                data = yaml.safe_load(f) or {}
            self._settings = AppSettings(**data)
        except FileNotFoundError:
            self._settings = AppSettings()
        except Exception:
            logging.getLogger(__name__).warning(
                "Could not load configuration from %s; using default settings",
                self.config_path,
                exc_info=True,
            )
            self._settings = AppSettings()
        return self._settings

    def save(self) -> None:
        """Write the current settings to the YAML file.

        The ``DEFAULT_BASE_URLS`` / ``DEFAULT_MODELS`` fields of every LLM
        provider and the ``DEFAULT_PORTS`` field of every database are
        excluded from the dump.
        """
        self._ensure_config_exists()
        data = self._settings.model_dump(
            exclude={"llm_providers": {"__all__": {"DEFAULT_BASE_URLS", "DEFAULT_MODELS"}},
                     "databases": {"__all__": {"DEFAULT_PORTS"}}}
        )
        with open(self.config_path, "w", encoding="utf-8") as f:
            yaml.safe_dump(data, f, allow_unicode=True, default_flow_style=False, sort_keys=False)

    def get_settings(self) -> AppSettings:
        """Return the in-memory settings object."""
        return self._settings

    # ---- LLM Provider CRUD ----

    def get_llm_providers(self) -> list[LLMProviderConfig]:
        """Return the list of configured LLM providers."""
        return self._settings.llm_providers

    def get_llm_provider(self, name: str) -> Optional[LLMProviderConfig]:
        """Return the first LLM provider called *name*, or ``None``."""
        for p in self._settings.llm_providers:
            if p.name == name:
                return p
        return None

    def add_llm_provider(self, config: LLMProviderConfig) -> None:
        """Add *config*, replacing in place any provider with the same name, then save."""
        existing = self.get_llm_provider(config.name)
        if existing:
            idx = self._settings.llm_providers.index(existing)
            self._settings.llm_providers[idx] = config
        else:
            self._settings.llm_providers.append(config)
        self.save()

    def remove_llm_provider(self, name: str) -> bool:
        """Remove the provider called *name* and save.

        If it was the active provider, the active selection is cleared.

        Returns:
            ``True`` when a provider was removed, ``False`` when none matched.
        """
        provider = self.get_llm_provider(name)
        if provider:
            self._settings.llm_providers.remove(provider)
            if self._settings.active_llm == name:
                self._settings.active_llm = ""
            self.save()
            return True
        return False

    def get_active_llm(self) -> Optional[LLMProviderConfig]:
        """Return the provider named by ``active_llm``, or ``None`` if it does not exist."""
        return self.get_llm_provider(self._settings.active_llm)

    def set_active_llm(self, name: str) -> bool:
        """Select *name* as the active provider and save.

        Returns:
            ``False`` (with nothing changed) when no provider has that name.
        """
        if self.get_llm_provider(name):
            self._settings.active_llm = name
            self.save()
            return True
        return False

    # ---- Database CRUD ----

    def get_databases(self) -> list[DatabaseConfig]:
        """Return the list of configured databases."""
        return self._settings.databases

    def get_database(self, name: str) -> Optional[DatabaseConfig]:
        """Return the first database called *name*, or ``None``."""
        for db in self._settings.databases:
            if db.name == name:
                return db
        return None

    def add_database(self, config: DatabaseConfig) -> None:
        """Add *config*, replacing in place any database with the same name, then save."""
        existing = self.get_database(config.name)
        if existing:
            idx = self._settings.databases.index(existing)
            self._settings.databases[idx] = config
        else:
            self._settings.databases.append(config)
        self.save()

    def remove_database(self, name: str) -> bool:
        """Remove the database called *name* and save.

        If it was the active database, the active selection is cleared.

        Returns:
            ``True`` when a database was removed, ``False`` when none matched.
        """
        db = self.get_database(name)
        if db:
            self._settings.databases.remove(db)
            if self._settings.active_database == name:
                self._settings.active_database = ""
            self.save()
            return True
        return False

    def get_active_database(self) -> Optional[DatabaseConfig]:
        """Return the database named by ``active_database``, or ``None`` if it does not exist."""
        return self.get_database(self._settings.active_database)

    def set_active_database(self, name: str) -> bool:
        """Select *name* as the active database and save.

        Returns:
            ``False`` (with nothing changed) when no database has that name.
        """
        if self.get_database(name):
            self._settings.active_database = name
            self.save()
            return True
        return False
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
# Module-level singleton; created on first use by get_config_manager().
_config_instance: Optional[ConfigManager] = None
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def get_config_manager() -> ConfigManager:
    """Return the global ``ConfigManager``, creating and loading it on first call."""
    global _config_instance
    if _config_instance is not None:
        return _config_instance

    # First call: build the manager, then read the settings from disk.
    _config_instance = ConfigManager()
    _config_instance.load()
    return _config_instance
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""数据库连接器模块"""
|