sql-assistant 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64) hide show
  1. sql_assistant/__init__.py +3 -0
  2. sql_assistant/api/__init__.py +1 -0
  3. sql_assistant/api/backup.py +116 -0
  4. sql_assistant/api/config.py +183 -0
  5. sql_assistant/api/conversation.py +71 -0
  6. sql_assistant/api/dependencies.py +22 -0
  7. sql_assistant/api/history.py +61 -0
  8. sql_assistant/api/models.py +221 -0
  9. sql_assistant/api/query.py +275 -0
  10. sql_assistant/api/routes.py +19 -0
  11. sql_assistant/api/schema.py +21 -0
  12. sql_assistant/config.py +144 -0
  13. sql_assistant/database/__init__.py +1 -0
  14. sql_assistant/database/backup.py +568 -0
  15. sql_assistant/database/connectors/__init__.py +1 -0
  16. sql_assistant/database/connectors/base.py +185 -0
  17. sql_assistant/database/connectors/exceptions.py +88 -0
  18. sql_assistant/database/connectors/mongodb.py +194 -0
  19. sql_assistant/database/connectors/mysql.py +110 -0
  20. sql_assistant/database/connectors/postgresql.py +133 -0
  21. sql_assistant/database/connectors/redis.py +132 -0
  22. sql_assistant/database/connectors/sqlserver.py +140 -0
  23. sql_assistant/database/history.py +290 -0
  24. sql_assistant/database/manager.py +178 -0
  25. sql_assistant/database/security.py +230 -0
  26. sql_assistant/llm/__init__.py +1 -0
  27. sql_assistant/llm/base.py +28 -0
  28. sql_assistant/llm/exceptions.py +96 -0
  29. sql_assistant/llm/manager.py +82 -0
  30. sql_assistant/llm/prompts.py +29 -0
  31. sql_assistant/llm/providers/__init__.py +1 -0
  32. sql_assistant/llm/providers/claude.py +132 -0
  33. sql_assistant/llm/providers/gemini.py +127 -0
  34. sql_assistant/llm/providers/openai_compatible.py +103 -0
  35. sql_assistant/llm/retry.py +88 -0
  36. sql_assistant/main.py +94 -0
  37. sql_assistant/settings.py +219 -0
  38. sql_assistant/web/__init__.py +1 -0
  39. sql_assistant/web/static/css/base.css +25 -0
  40. sql_assistant/web/static/css/components/backup.css +146 -0
  41. sql_assistant/web/static/css/components/chat.css +465 -0
  42. sql_assistant/web/static/css/components/modal.css +143 -0
  43. sql_assistant/web/static/css/components/settings.css +358 -0
  44. sql_assistant/web/static/css/components/sidebar.css +235 -0
  45. sql_assistant/web/static/css/components/toast.css +30 -0
  46. sql_assistant/web/static/css/style.css +10 -0
  47. sql_assistant/web/static/css/theme.css +200 -0
  48. sql_assistant/web/static/js/api.js +38 -0
  49. sql_assistant/web/static/js/app.js +161 -0
  50. sql_assistant/web/static/js/backup.js +216 -0
  51. sql_assistant/web/static/js/chat.js +238 -0
  52. sql_assistant/web/static/js/color-theme-manager.js +121 -0
  53. sql_assistant/web/static/js/confirm.js +95 -0
  54. sql_assistant/web/static/js/conversations.js +182 -0
  55. sql_assistant/web/static/js/settings.js +425 -0
  56. sql_assistant/web/static/js/state.js +43 -0
  57. sql_assistant/web/static/js/theme-manager.js +64 -0
  58. sql_assistant/web/static/js/ui.js +53 -0
  59. sql_assistant/web/templates/index.html +373 -0
  60. sql_assistant-1.0.0.dist-info/METADATA +24 -0
  61. sql_assistant-1.0.0.dist-info/RECORD +64 -0
  62. sql_assistant-1.0.0.dist-info/WHEEL +4 -0
  63. sql_assistant-1.0.0.dist-info/entry_points.txt +2 -0
  64. sql_assistant-1.0.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,178 @@
1
+ """数据库连接管理器 - 统一管理所有数据库连接"""
2
+
3
+ from typing import Optional
4
+
5
+ from ..config import get_config_manager
6
+ from ..settings import DatabaseConfig
7
+ from .connectors.base import BaseConnector, QueryResult
8
+ from .connectors.mysql import MySQLConnector
9
+ from .connectors.sqlserver import SQLServerConnector
10
+ from .connectors.postgresql import PostgreSQLConnector
11
+ from .connectors.redis import RedisConnector
12
+ from .connectors.mongodb import MongoDBConnector
13
+
14
+
15
class DatabaseManager:
    """Manage connectors and cached schemas for the configured databases.

    Connectors are created lazily for the active database configuration
    (obtained from ``get_config_manager().get_active_database()``) and cached
    by the configuration's ``name``. Schemas returned by a connector are cached
    under the same key until ``refresh_schema`` is called.
    """

    def __init__(self):
        # Connected connectors, keyed by DatabaseConfig.name.
        self._connectors: dict[str, BaseConnector] = {}
        # Schema dicts from connector.get_schema(), keyed by DatabaseConfig.name.
        self._schema_cache: dict[str, dict] = {}

    def _create_connector(self, config: DatabaseConfig) -> BaseConnector:
        """Instantiate (without connecting) the connector for ``config.db_type``.

        Raises:
            ValueError: if ``config.db_type`` is not a supported database type.
        """
        connector_classes = {
            "mysql": MySQLConnector,
            "sqlserver": SQLServerConnector,
            "postgresql": PostgreSQLConnector,
            "redis": RedisConnector,
            "mongodb": MongoDBConnector,
        }
        cls = connector_classes.get(config.db_type)
        if cls is None:
            raise ValueError(f"不支持的数据库类型: {config.db_type}")

        return cls(
            host=config.host,
            port=config.get_port(),
            user=config.user,
            password=config.password,
            database=config.database,
        )

    @staticmethod
    async def _safe_disconnect(connector: BaseConnector) -> Optional[Exception]:
        """Disconnect ``connector``, returning (not raising) any Exception."""
        try:
            await connector.disconnect()
        except Exception as exc:
            return exc
        return None

    async def get_connector(self) -> BaseConnector:
        """Return the connector for the active database, connecting on first use.

        Raises:
            ValueError: if no active database is configured, or its type is
                unsupported.
        """
        config = get_config_manager().get_active_database()
        if not config:
            raise ValueError("未配置数据库,请先在设置中添加数据库连接")

        # NOTE(review): connectors are keyed by name only, so editing a config
        # while keeping its name reuses the old connection until
        # refresh_schema()/close_all() runs — confirm the settings API does that.
        key = config.name
        if key not in self._connectors:
            connector = self._create_connector(config)
            await connector.connect()
            self._connectors[key] = connector

        return self._connectors[key]

    async def execute(self, sql: str) -> QueryResult:
        """Execute ``sql`` on the active database and return its result."""
        connector = await self.get_connector()
        return await connector.execute(sql)

    async def get_schema(self, force_refresh: bool = False) -> dict:
        """Return the active database's schema, cached per config name.

        Failures are not raised: they are returned as a dict with an "error"
        entry and are not cached, so the next call retries.
        """
        config = get_config_manager().get_active_database()
        if not config:
            return {"db_type": "", "tables": [], "error": "未配置数据库连接"}

        cache_key = config.name
        if not force_refresh and cache_key in self._schema_cache:
            return self._schema_cache[cache_key]

        try:
            connector = await self.get_connector()
            schema = await connector.get_schema()
            self._schema_cache[cache_key] = schema
            return schema
        except Exception as e:
            return {"db_type": config.db_type, "tables": [], "error": str(e)}

    async def refresh_schema(self) -> dict:
        """Drop the cached schema and connector, then re-fetch the schema."""
        config = get_config_manager().get_active_database()
        if config:
            cache_key = config.name
            self._schema_cache.pop(cache_key, None)
            # Remove the connector from the cache *before* disconnecting, so a
            # failing disconnect() cannot leave a half-closed connector cached;
            # get_schema() below then reconnects from scratch.
            connector = self._connectors.pop(cache_key, None)
            if connector is not None:
                await self._safe_disconnect(connector)
        return await self.get_schema(force_refresh=True)

    async def get_schema_text(self) -> str:
        """Render the active database schema as text for the LLM prompt.

        Redis schemas are listed as keys with their types. Other databases are
        listed table by table, with PRIMARY KEY / NOT NULL column markers. When
        the schema could not be fetched, the error is reported instead.
        """
        schema = await self.get_schema()

        db_type = schema.get("db_type", "")
        error = schema.get("error", "")

        if db_type == "redis":
            keys = schema.get("keys", [])
            if not keys:
                # A failed fetch comes back without "keys"; report the error
                # rather than claiming the Redis database is empty.
                if error:
                    return f"(Schema 获取失败: {error})"
                return "(Redis: 当前数据库无 key)"
            lines = ["## 当前 Redis Key 列表(前50个):"]
            for k in keys:
                lines.append(f"- {k['name']} ({k['type']})")
            key_count = schema.get("key_count", 0)
            if key_count > 50:
                lines.append(f"... 共 {key_count} 个 key")
            return "\n".join(lines)

        tables = schema.get("tables", [])
        if not tables:
            if error:
                return f"(Schema 获取失败: {error})"
            return "(数据库为空,暂无表)"

        lines = ["## 当前数据库 Schema:", ""]
        for t in tables:
            cols = t.get("columns", [])
            if not cols:
                lines.append(f"- {t['name']} (无列信息)")
                continue
            col_strs = []
            for c in cols:
                extras = []
                if c.get("key") == "PRI":
                    extras.append("PRIMARY KEY")
                # Columns without a "nullable" entry are treated as nullable.
                if not c.get("nullable", True):
                    extras.append("NOT NULL")
                suffix = f" -- {', '.join(extras)}" if extras else ""
                col_strs.append(f" {c['name']} {c['type']}{suffix}")
            lines.append(f"- {t['name']}:")
            lines.extend(col_strs)
            lines.append("")

        return "\n".join(lines)

    async def test_active_connection(self) -> dict:
        """Test connectivity of the active database configuration."""
        config = get_config_manager().get_active_database()
        if not config:
            return {"success": False, "error": "未配置数据库连接"}
        return await self.test_connection(config)

    async def test_connection(self, config: DatabaseConfig) -> dict:
        """Test ``config`` with a throwaway connector that is always disconnected.

        Raises:
            ValueError: if ``config.db_type`` is unsupported.
        """
        connector = self._create_connector(config)
        try:
            return await connector.test_connection()
        except Exception as e:
            return {"success": False, "error": str(e), "code": "CONNECTION_FAILED"}
        finally:
            # A failing disconnect must not replace the test result.
            await self._safe_disconnect(connector)

    async def close_all(self):
        """Disconnect every cached connector and empty the cache.

        All connectors are disconnected even if some fail; the first
        disconnect error (if any) is re-raised afterwards.
        """
        connectors = list(self._connectors.values())
        # Clear first so closed connectors are never handed out again.
        self._connectors.clear()
        first_error: Optional[Exception] = None
        for connector in connectors:
            error = await self._safe_disconnect(connector)
            if first_error is None:
                first_error = error
        if first_error is not None:
            raise first_error
168
+
169
+
170
# Process-wide DatabaseManager instance, created lazily by get_db_manager().
_db_manager: Optional[DatabaseManager] = None


def get_db_manager() -> DatabaseManager:
    """Return the shared DatabaseManager, creating it on first use."""
    global _db_manager
    if _db_manager is not None:
        return _db_manager
    _db_manager = DatabaseManager()
    return _db_manager
@@ -0,0 +1,230 @@
1
+ """
2
+ SQL 安全保护模块 - 防止 SQL 注入攻击
3
+ """
4
+ import re
5
+ from typing import Optional, Tuple, List
6
+ from dataclasses import dataclass
7
+
8
+
9
@dataclass
class SQLCheckResult:
    """Outcome of a SQL safety check.

    Only ``is_safe`` is required; the remaining fields default to empty
    strings / "none".
    """

    # Whether the statement may be executed.
    is_safe: bool
    # Human-readable warning text; empty when nothing was flagged.
    warning: str = ""
    # Risk level: one of none, low, medium, high.
    risk_level: str = "none"
    # Reason execution was refused; empty when not blocked.
    blocked_reason: str = ""
16
+
17
+
18
class SQLSecurityGuard:
    """Heuristic guard that warns about, and blocks, suspicious SQL.

    Every regex in ``DANGEROUS_KEYWORDS`` is matched case-insensitively against
    the statement. Each match is turned into a user-facing warning by
    ``_get_warning_message``. Warnings that start with ``HIGH_RISK_MARK`` are
    high risk and make ``check_sql_safety`` report the SQL as unsafe. All other
    warnings are medium risk: they are reported, but execution is allowed.
    """

    # Dangerous SQL keywords and patterns (regex sources, matched case-insensitively).
    DANGEROUS_KEYWORDS = [
        # Comments
        r"/\*", r"\*/", r"--", r"#",
        # UNION queries
        r"\bUNION\b", r"\bUNION\s+ALL\b",
        # Nested SELECT (subquery)
        r"\bSELECT\b.*\bFROM\b.*\bSELECT\b",
        # Statement batching
        r";",
        # System schemas / catalog tables
        r"\binformation_schema\b",
        r"\bmysql\.",
        r"\bsys\.",
        # System functions / command execution
        r"\bLOAD_FILE\b",
        r"\bINTO\s+OUTFILE\b",
        r"\bINTO\s+DUMPFILE\b",
        r"\bEXEC\b",
        r"\bEXECUTE\b",
        r"\bSYSTEM\b",
        r"\bSHELL\b",
        r"\bCMD\b",
        # Destructive data operations
        r"\bDROP\b",
        r"\bTRUNCATE\b",
        r"\bALTER\b.*\bDROP\b",
        # User privileges
        r"\bGRANT\b",
        r"\bREVOKE\b",
        r"\bCREATE\s+USER\b",
        r"\bALTER\s+USER\b",
        # Hexadecimal literals
        r"0x[0-9a-fA-F]+",
        # String concatenation
        r"\|\|",
        r"CONCAT\s*\(",
        r"GROUP_CONCAT\s*\(",
        # Time-based (delay) attacks
        r"\bSLEEP\b",
        r"\bBENCHMARK\b",
        # Boolean-based injection
        r"\bOR\s+1\s*=\s*1\b",
        r"\bAND\s+1\s*=\s*1\b",
        r"\bOR\s+TRUE\b",
        r"\bAND\s+FALSE\b",
        # Blind injection
        r"\bIF\s*\(",
        r"\bCASE\s+WHEN\b",
        r"\bEXISTS\s*\(",
        # Whole-table DELETE / UPDATE
        r"\bDELETE\b.*\bWHERE\s*1\s*=\s*1\b",
        r"\bUPDATE\b.*\bWHERE\s*1\s*=\s*1\b",
    ]

    # SQL statement types this guard is meant to allow.
    # NOTE(review): not referenced by any method of this class.
    ALLOWED_SQL_TYPES = {
        "SELECT", "INSERT", "UPDATE", "DELETE",
        "SHOW", "DESCRIBE", "EXPLAIN",
        "CREATE", "ALTER", "DROP", "TRUNCATE",
        "USE", "SET"
    }

    # Prefix marking a warning message as high risk (the SQL gets blocked).
    HIGH_RISK_MARK = "🔴"

    # (prefixes of the normalized match text, warning message), checked in order.
    # Every DANGEROUS_KEYWORDS match begins with its pattern's leading token, so
    # classifying by prefix avoids false hits on identifiers that merely
    # contain a keyword (e.g. a table named "executions" inside a DELETE match).
    # "alter user" must precede the generic "alter" (ALTER ... DROP) rule.
    _WARNING_RULES = (
        (("union", "select"), "⚠️ 检测到联合查询模式,可能存在注入风险"),
        (("--", "/*", "*/", "#"), "⚠️ 检测到注释符,请确认操作意图"),
        ((";",), "⚠️ 检测到分号分隔符,可能存在多语句执行风险"),
        (("information_schema", "mysql.", "sys."), "⚠️ 检测到系统表访问,请注意数据安全"),
        (("load_file", "into outfile", "into dumpfile", "exec", "system", "shell", "cmd"),
         "🔴 检测到高风险系统函数调用"),
        (("sleep", "benchmark"), "🔴 检测到延迟注入函数"),
        (("or 1=1", "and 1=1", "or true", "and false"), "🔴 检测到布尔注入模式"),
        (("grant", "revoke", "create user", "alter user"), "⚠️ 检测到用户权限变更操作,请谨慎执行"),
        (("drop", "truncate", "alter"), "⚠️ 检测到数据删除操作,请谨慎执行"),
        (("0x",), "⚠️ 检测到十六进制编码"),
        (("concat", "group_concat", "||"), "⚠️ 检测到字符串拼接,请注意注入风险"),
        (("if", "case when", "exists"), "⚠️ 检测到条件判断表达式,可能存在盲注风险"),
        (("delete", "update"), "⚠️ 检测到 WHERE 1=1 条件,将影响全表数据"),
    )

    def __init__(self):
        # Compile the patterns once per guard instance.
        self.dangerous_patterns = [
            re.compile(pattern, re.IGNORECASE)
            for pattern in self.DANGEROUS_KEYWORDS
        ]

    def check_sql_safety(self, sql: str) -> SQLCheckResult:
        """
        Check a SQL statement against the dangerous-pattern list.

        Args:
            sql: SQL statement

        Returns:
            SQLCheckResult. Empty SQL is unsafe. Any high-risk warning makes
            the result unsafe with risk_level "high". Other warnings give
            is_safe=True with risk_level "medium". Distinct warnings are
            joined with "; ".
        """
        sql = sql.strip()

        if not sql:
            return SQLCheckResult(
                is_safe=False,
                blocked_reason="SQL 语句为空"
            )

        warnings: List[str] = []
        risk_level = "none"

        for pattern in self.dangerous_patterns:
            for match in pattern.findall(sql):
                warning_msg = self._get_warning_message(match)
                if not warning_msg:
                    continue
                # Report each distinct warning once (e.g. UNION and UNION ALL
                # map to the same message).
                if warning_msg not in warnings:
                    warnings.append(warning_msg)
                # Escalate the risk level; "high" is never downgraded.
                if warning_msg.startswith(self.HIGH_RISK_MARK):
                    risk_level = "high"
                elif risk_level != "high":
                    risk_level = "medium"

        if risk_level == "high":
            return SQLCheckResult(
                is_safe=False,
                warning="; ".join(warnings),
                risk_level=risk_level,
                blocked_reason="检测到高风险 SQL 注入模式"
            )

        if warnings:
            return SQLCheckResult(
                is_safe=True,  # medium risk is still allowed, but with a warning
                warning="; ".join(warnings),
                risk_level=risk_level
            )

        return SQLCheckResult(is_safe=True)

    def _get_warning_message(self, match: str) -> str:
        """Map a pattern match to its warning message ("" if unrecognized)."""
        # Lower-case, collapse whitespace, and drop spaces around "=" so that
        # e.g. "OR  1 = 1" and "or 1=1" classify identically.
        text = re.sub(r"\s+", " ", match.strip().lower())
        text = re.sub(r"\s*=\s*", "=", text)

        for prefixes, message in self._WARNING_RULES:
            if text.startswith(prefixes):
                return message
        return ""

    def sanitize_sql(self, sql: str) -> Tuple[str, List[str]]:
        """
        Trim the SQL and collect warnings; the statement is not rewritten.

        Args:
            sql: raw SQL

        Returns:
            (trimmed SQL, list of warnings)
        """
        warnings = []
        cleaned = sql

        # Strip surrounding whitespace.
        cleaned = cleaned.strip()

        # Only warn about dangerous patterns; do not modify the SQL.
        check_result = self.check_sql_safety(cleaned)
        if check_result.warning:
            warnings.append(check_result.warning)

        return cleaned, warnings

    def requires_confirmation(self, sql: str) -> Tuple[bool, str]:
        """
        Decide whether the SQL needs explicit user confirmation.

        Args:
            sql: SQL statement

        Returns:
            (needs confirmation, confirmation prompt)
        """
        # Local import, kept as in the original.
        from .connectors.base import BaseConnector

        sql_type = BaseConnector.classify_sql(sql)

        # Data-modifying statements and DDL require confirmation.
        if sql_type in ["INSERT", "UPDATE", "DELETE", "DDL"]:
            reason_map = {
                "INSERT": "即将执行 INSERT 操作,会添加新数据",
                "UPDATE": "即将执行 UPDATE 操作,会修改现有数据",
                "DELETE": "即将执行 DELETE 操作,会删除数据",
                "DDL": "即将执行 DDL 操作,会修改数据库结构"
            }
            return True, reason_map.get(sql_type, "即将执行数据修改操作")

        # Everything else (e.g. SELECT) needs no confirmation.
        return False, ""
219
+
220
+
221
+ # 全局单例
222
+ _security_guard: Optional[SQLSecurityGuard] = None
223
+
224
+
225
+ def get_security_guard() -> SQLSecurityGuard:
226
+ """获取 SQL 安全守卫单例"""
227
+ global _security_guard
228
+ if _security_guard is None:
229
+ _security_guard = SQLSecurityGuard()
230
+ return _security_guard
@@ -0,0 +1 @@
1
+ """LLM 集成模块"""
@@ -0,0 +1,28 @@
1
+ """LLM Provider 抽象基类"""
2
+
3
+ from abc import ABC, abstractmethod
4
+ from typing import AsyncGenerator
5
+
6
+
7
+ class BaseLLMProvider(ABC):
8
+ """LLM Provider 抽象基类"""
9
+
10
+ def __init__(self, api_key: str, base_url: str, model: str):
11
+ self.api_key = api_key
12
+ self.base_url = base_url.rstrip("/")
13
+ self.model = model
14
+
15
+ @abstractmethod
16
+ async def chat(self, messages: list[dict], temperature: float = 0.1) -> str:
17
+ """发送对话请求,返回完整响应"""
18
+ ...
19
+
20
+ @abstractmethod
21
+ async def chat_stream(self, messages: list[dict], temperature: float = 0.1) -> AsyncGenerator[str, None]:
22
+ """发送流式对话请求"""
23
+ ...
24
+
25
+ @abstractmethod
26
+ async def test_connection(self) -> dict:
27
+ """测试连接是否正常"""
28
+ ...
@@ -0,0 +1,96 @@
1
+ """LLM Provider 统一异常"""
2
+
3
+ from typing import Optional, Dict, Any
4
+
5
+
6
class LLMError(Exception):
    """Base class for all LLM provider errors.

    Args:
        message: Human-readable description; becomes ``str(exc)``.
        provider: Name of the provider that failed (empty if unknown).
        code: Optional machine-readable error code.
    """

    def __init__(self, message: str, provider: str = "", code: Optional[str] = None):
        super().__init__(message)
        self.code = code
        self.provider = provider

    def to_dict(self) -> dict:
        """Return the error as a failure payload (success is always False)."""
        return dict(
            success=False,
            error=str(self),
            provider=self.provider,
            code=self.code,
        )
20
+
21
+
22
class LLMConnectionError(LLMError):
    """Raised when the LLM provider cannot be reached (code LLM_CONNECTION_FAILED)."""

    def __init__(self, message: str, provider: str = ""):
        super().__init__(message, provider=provider, code="LLM_CONNECTION_FAILED")
26
+
27
+
28
class LLMTimeoutError(LLMError):
    """Raised when an LLM request times out (code LLM_TIMEOUT)."""

    def __init__(self, message: str, provider: str = ""):
        super().__init__(message, provider=provider, code="LLM_TIMEOUT")
32
+
33
+
34
class LLMAuthError(LLMError):
    """Raised when the LLM provider rejects the credentials (code LLM_AUTH_FAILED)."""

    def __init__(self, message: str, provider: str = ""):
        super().__init__(message, provider=provider, code="LLM_AUTH_FAILED")
38
+
39
+
40
class LLMRateLimitError(LLMError):
    """Raised when the LLM provider rate-limits a request (code LLM_RATE_LIMITED).

    Args:
        message: Human-readable description.
        provider: Name of the provider that raised the error.
        retry_after: Optional wait hint from the provider (presumably seconds —
            confirm against the providers that set it).
    """

    def __init__(self, message: str, provider: str = "", retry_after: Optional[int] = None):
        super().__init__(message, provider, code="LLM_RATE_LIMITED")
        self.retry_after = retry_after

    def to_dict(self) -> dict:
        """Serialize like LLMError.to_dict, adding retry_after when it is known."""
        result = super().to_dict()
        # Compare with None: retry_after=0 ("retry immediately") is a valid
        # hint and must not be dropped just because it is falsy.
        if self.retry_after is not None:
            result["retry_after"] = self.retry_after
        return result
51
+
52
+
53
class LLMResponseError(LLMError):
    """Raised when an LLM response cannot be parsed (code LLM_RESPONSE_ERROR).

    Args:
        message: Human-readable description.
        provider: Name of the provider that returned the response.
        raw_response: The unparsed response, kept for debugging.
    """

    def __init__(self, message: str, provider: str = "", raw_response: Any = None):
        super().__init__(message, provider, code="LLM_RESPONSE_ERROR")
        self.raw_response = raw_response

    def to_dict(self) -> dict:
        """Serialize like LLMError.to_dict, plus a truncated raw_response if truthy."""
        payload = super().to_dict()
        if not self.raw_response:
            return payload
        # Cap at 500 characters so a huge body does not bloat the payload.
        payload["raw_response"] = str(self.raw_response)[:500]
        return payload
64
+
65
+
66
def format_llm_result(
    success: bool,
    data: Any = None,
    error: Optional[str] = None,
    provider: str = "",
    code: Optional[str] = None,
) -> Dict[str, Any]:
    """Build the uniform result dict for LLM operations.

    On success the result is ``{"success": success, "data": data}``; provider
    and code are ignored. On failure it holds ``success`` and ``error``
    (falling back to "未知错误" when empty), plus ``code`` and ``provider``
    when they are truthy.

    Args:
        success: Whether the operation succeeded.
        data: Payload returned on success.
        error: Error message on failure.
        provider: Provider name, included only on failure.
        code: Error code, included only on failure.

    Returns:
        The formatted result dict.
    """
    if success:
        return {"success": success, "data": data}

    failure: Dict[str, Any] = {"success": success, "error": error or "未知错误"}
    if code:
        failure["code"] = code
    if provider:
        failure["provider"] = provider
    return failure
@@ -0,0 +1,82 @@
1
+ """LLM Manager - 统一管理所有 Provider"""
2
+
3
+ from typing import AsyncGenerator, Optional
4
+
5
+ from ..config import get_config_manager
6
+ from ..settings import LLMProviderConfig
7
+ from .base import BaseLLMProvider
8
+ from .providers.openai_compatible import OpenAICompatibleProvider
9
+ from .providers.gemini import GeminiProvider
10
+ from .providers.claude import ClaudeProvider
11
+
12
+
13
class LLMManager:
    """Create, cache and dispatch to LLM providers.

    Providers are built from LLMProviderConfig objects and cached by the
    config's ``name``. ``chat``/``chat_stream`` always use the active LLM
    returned by ``get_config_manager().get_active_llm()``.
    """

    def __init__(self):
        # Provider instances keyed by LLMProviderConfig.name.
        self._providers: dict[str, BaseLLMProvider] = {}

    def _create_provider(self, config: LLMProviderConfig) -> BaseLLMProvider:
        """Instantiate the provider class that matches ``config.provider``.

        "gemini" -> GeminiProvider; "claude" and "minimax" -> ClaudeProvider;
        any other value -> OpenAICompatibleProvider.
        """
        if config.provider == "gemini":
            return GeminiProvider(
                api_key=config.api_key,
                base_url=config.get_base_url(),
                model=config.get_model(),
            )
        elif config.provider in ("claude", "minimax"):
            # NOTE(review): minimax goes through ClaudeProvider, presumably via
            # an Anthropic-compatible endpoint — confirm.
            return ClaudeProvider(
                api_key=config.api_key,
                base_url=config.get_base_url(),
                model=config.get_model(),
            )
        else:
            # Everything else is treated as an OpenAI-compatible API.
            return OpenAICompatibleProvider(
                api_key=config.api_key,
                base_url=config.get_base_url(),
                model=config.get_model(),
                provider_name=config.provider,
            )

    def get_provider(self, config: LLMProviderConfig) -> BaseLLMProvider:
        """Return the cached provider for ``config.name``, creating it if needed."""
        # NOTE(review): cached by name only, so edits to a config that keep its
        # name keep using the old provider until close_all() — confirm the
        # settings API resets this cache after edits.
        key = config.name
        if key not in self._providers:
            self._providers[key] = self._create_provider(config)
        return self._providers[key]

    async def chat(self, messages: list[dict], temperature: float = 0.1) -> str:
        """Send a chat request through the active LLM.

        Raises:
            ValueError: if no active LLM is configured.
        """
        config = get_config_manager().get_active_llm()
        if not config:
            raise ValueError("未配置 LLM,请先在设置中添加 LLM 提供商")
        provider = self.get_provider(config)
        return await provider.chat(messages, temperature)

    async def chat_stream(self, messages: list[dict], temperature: float = 0.1) -> AsyncGenerator[str, None]:
        """Stream a chat response from the active LLM, yielding chunks.

        Raises:
            ValueError: if no active LLM is configured (on first iteration).
        """
        config = get_config_manager().get_active_llm()
        if not config:
            raise ValueError("未配置 LLM,请先在设置中添加 LLM 提供商")
        provider = self.get_provider(config)
        async for chunk in provider.chat_stream(messages, temperature):
            yield chunk

    async def close_all(self):
        """Close every cached provider that has a ``close`` method and empty the cache.

        All providers are closed even if some fail; the first close error (if
        any) is re-raised afterwards.
        """
        providers = list(self._providers.values())
        # Clear first so already-closed providers are never handed out again.
        self._providers.clear()
        first_error: Optional[Exception] = None
        for provider in providers:
            if not hasattr(provider, "close"):
                continue
            try:
                await provider.close()
            except Exception as exc:
                if first_error is None:
                    first_error = exc
        if first_error is not None:
            raise first_error
72
+
73
+
74
# Process-wide LLMManager instance, created lazily by get_llm_manager().
_llm_manager: Optional[LLMManager] = None


def get_llm_manager() -> LLMManager:
    """Return the shared LLMManager, creating it on first use."""
    global _llm_manager
    if _llm_manager is not None:
        return _llm_manager
    _llm_manager = LLMManager()
    return _llm_manager
@@ -0,0 +1,29 @@
1
+ """LLM Prompt 模板 - 自然语言转 SQL"""
2
+
3
+ SYSTEM_PROMPT = """你是一个专业的 SQL 助手。你的任务是将用户的自然语言问题转换为 {db_type} 数据库的 SQL 语句。
4
+
5
+ {schema_context}
6
+ ## 规则
7
+ 1. 只返回 SQL 语句,不要返回其他解释性文字
8
+ 2. SQL 语句必须符合 {db_type} 的语法规范
9
+ 3. **严格使用上面 Schema 中列出的表名和字段名**,不要凭空编造表或字段
10
+ 4. 如果用户问的内容在 Schema 中找不到对应的表/字段,返回:-- 数据库中不存在相关表或字段,请确认查询内容
11
+ 5. 对于 SELECT 查询,使用合理的 LIMIT 限制(默认 100 条)
12
+ 6. 对于 DELETE/UPDATE,必须包含 WHERE 条件,并在语句前加注释提醒危险操作
13
+ 7. 如果用户的问题无法转换为 SQL,返回:-- 无法理解的问题
14
+ 8. 对于 Redis,使用 Redis 命令格式
15
+ 9. 对于 MongoDB,使用 MongoDB 查询语法(JSON 格式)
16
+
17
+ ## 数据库方言注意事项
18
+ - MySQL: 使用反引号包裹标识符,支持 LIMIT
19
+ - PostgreSQL: 使用双引号包裹标识符,支持 LIMIT
20
+ - SQL Server: 使用方括号包裹标识符,使用 TOP 或 OFFSET-FETCH
21
+ - Redis: 返回 Redis 原生命令
22
+ - MongoDB: 返回 MongoDB find/aggregate 查询
23
+
24
+ 现在开始响应。"""
25
+
26
+ NL_TO_SQL_PROMPT = """将以下自然语言转换为 {db_type} SQL 语句:
27
+ {question}
28
+
29
+ 只返回 SQL 语句:"""
@@ -0,0 +1 @@
1
+ """LLM Provider 实现"""