sql-assistant 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64)
  1. sql_assistant/__init__.py +3 -0
  2. sql_assistant/api/__init__.py +1 -0
  3. sql_assistant/api/backup.py +116 -0
  4. sql_assistant/api/config.py +183 -0
  5. sql_assistant/api/conversation.py +71 -0
  6. sql_assistant/api/dependencies.py +22 -0
  7. sql_assistant/api/history.py +61 -0
  8. sql_assistant/api/models.py +221 -0
  9. sql_assistant/api/query.py +275 -0
  10. sql_assistant/api/routes.py +19 -0
  11. sql_assistant/api/schema.py +21 -0
  12. sql_assistant/config.py +144 -0
  13. sql_assistant/database/__init__.py +1 -0
  14. sql_assistant/database/backup.py +568 -0
  15. sql_assistant/database/connectors/__init__.py +1 -0
  16. sql_assistant/database/connectors/base.py +185 -0
  17. sql_assistant/database/connectors/exceptions.py +88 -0
  18. sql_assistant/database/connectors/mongodb.py +194 -0
  19. sql_assistant/database/connectors/mysql.py +110 -0
  20. sql_assistant/database/connectors/postgresql.py +133 -0
  21. sql_assistant/database/connectors/redis.py +132 -0
  22. sql_assistant/database/connectors/sqlserver.py +140 -0
  23. sql_assistant/database/history.py +290 -0
  24. sql_assistant/database/manager.py +178 -0
  25. sql_assistant/database/security.py +230 -0
  26. sql_assistant/llm/__init__.py +1 -0
  27. sql_assistant/llm/base.py +28 -0
  28. sql_assistant/llm/exceptions.py +96 -0
  29. sql_assistant/llm/manager.py +82 -0
  30. sql_assistant/llm/prompts.py +29 -0
  31. sql_assistant/llm/providers/__init__.py +1 -0
  32. sql_assistant/llm/providers/claude.py +132 -0
  33. sql_assistant/llm/providers/gemini.py +127 -0
  34. sql_assistant/llm/providers/openai_compatible.py +103 -0
  35. sql_assistant/llm/retry.py +88 -0
  36. sql_assistant/main.py +94 -0
  37. sql_assistant/settings.py +219 -0
  38. sql_assistant/web/__init__.py +1 -0
  39. sql_assistant/web/static/css/base.css +25 -0
  40. sql_assistant/web/static/css/components/backup.css +146 -0
  41. sql_assistant/web/static/css/components/chat.css +465 -0
  42. sql_assistant/web/static/css/components/modal.css +143 -0
  43. sql_assistant/web/static/css/components/settings.css +358 -0
  44. sql_assistant/web/static/css/components/sidebar.css +235 -0
  45. sql_assistant/web/static/css/components/toast.css +30 -0
  46. sql_assistant/web/static/css/style.css +10 -0
  47. sql_assistant/web/static/css/theme.css +200 -0
  48. sql_assistant/web/static/js/api.js +38 -0
  49. sql_assistant/web/static/js/app.js +161 -0
  50. sql_assistant/web/static/js/backup.js +216 -0
  51. sql_assistant/web/static/js/chat.js +238 -0
  52. sql_assistant/web/static/js/color-theme-manager.js +121 -0
  53. sql_assistant/web/static/js/confirm.js +95 -0
  54. sql_assistant/web/static/js/conversations.js +182 -0
  55. sql_assistant/web/static/js/settings.js +425 -0
  56. sql_assistant/web/static/js/state.js +43 -0
  57. sql_assistant/web/static/js/theme-manager.js +64 -0
  58. sql_assistant/web/static/js/ui.js +53 -0
  59. sql_assistant/web/templates/index.html +373 -0
  60. sql_assistant-1.0.0.dist-info/METADATA +24 -0
  61. sql_assistant-1.0.0.dist-info/RECORD +64 -0
  62. sql_assistant-1.0.0.dist-info/WHEEL +4 -0
  63. sql_assistant-1.0.0.dist-info/entry_points.txt +2 -0
  64. sql_assistant-1.0.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,133 @@
1
+ """PostgreSQL 连接器"""
2
+
3
+ import asyncio
4
+ from typing import Any
5
+
6
+ import psycopg2
7
+ import psycopg2.extras
8
+
9
+ from .base import BaseConnector, QueryResult
10
+
11
+
12
class PostgreSQLConnector(BaseConnector):
    """PostgreSQL connector built on a single synchronous psycopg2 connection.

    psycopg2 calls block, so connect/execute/get_schema/test_connection run
    them in the event loop's default executor.
    """

    db_type = "postgresql"

    def __init__(self, host: str, port: int, user: str, password: str, database: str):
        super().__init__(host, port, user, password, database)
        self._conn: psycopg2.extensions.connection | None = None

    async def connect(self) -> None:
        """Open the psycopg2 connection in a worker thread (connect_timeout=10)."""
        loop = asyncio.get_running_loop()
        self._conn = await loop.run_in_executor(
            None,
            lambda: psycopg2.connect(
                host=self.host,
                port=self.port,
                user=self.user,
                password=self.password,
                dbname=self.database,
                connect_timeout=10,
            ),
        )

    async def disconnect(self) -> None:
        """Close and forget the connection; a no-op when not connected."""
        if self._conn:
            self._conn.close()
            self._conn = None

    @staticmethod
    def _split_statements(sql: str) -> list[str]:
        """Split *sql* into statements on top-level ``;`` only.

        A ``;`` inside any of the following is not a separator:
        single-quoted literals, double-quoted identifiers, dollar-quoted
        bodies (``$$ ... $$`` / ``$tag$ ... $tag$``), ``--`` line comments
        and ``/* ... */`` block comments. Pieces that contain only whitespace
        and/or comments are dropped. Backslash escapes in ``E'...'`` strings
        are not interpreted.
        """
        statements: list[str] = []
        start = 0  # index where the statement being scanned begins
        has_code = False  # does the current piece contain more than comments/whitespace?
        i, n = 0, len(sql)
        while i < n:
            ch = sql[i]
            if ch in ("'", '"'):
                # A doubled quote ('' or "") is an escaped quote, not the end of the token.
                end = sql.find(ch, i + 1)
                while end != -1 and sql.startswith(ch, end + 1):
                    end = sql.find(ch, end + 2)
                i = n if end == -1 else end + 1
                has_code = True
            elif sql.startswith("--", i):
                end = sql.find("\n", i)
                i = n if end == -1 else end + 1
            elif sql.startswith("/*", i):
                end = sql.find("*/", i + 2)
                i = n if end == -1 else end + 2
            elif ch == "$":
                has_code = True
                # `$` directly after an identifier character is part of that
                # identifier, not the start of a dollar quote.
                if i > 0 and (sql[i - 1].isalnum() or sql[i - 1] in "_$"):
                    i += 1
                    continue
                j = i + 1
                while j < n and (sql[j].isalnum() or sql[j] == "_"):
                    j += 1
                tag_body = sql[i + 1 : j]
                # `$1`-style parameters are not dollar quotes (tags cannot start with a digit).
                if j < n and sql[j] == "$" and not tag_body[:1].isdigit():
                    tag = sql[i : j + 1]
                    end = sql.find(tag, j + 1)
                    i = n if end == -1 else end + len(tag)
                else:
                    i += 1
            elif ch == ";":
                if has_code:
                    statements.append(sql[start:i].strip())
                start = i + 1
                has_code = False
                i += 1
            else:
                if not ch.isspace():
                    has_code = True
                i += 1
        if has_code:
            statements.append(sql[start:].strip())
        return statements

    async def execute(self, sql: str) -> QueryResult:
        """Execute one or more ``;``-separated statements as a single transaction.

        Statements run in order on one cursor. For a statement that
        ``classify_sql`` reports as ``SELECT``, the column names and all rows
        are captured. For any other statement, ``affected_rows`` is taken from
        ``cursor.rowcount``. Each statement overwrites ``sql_type``, so the
        result mostly reflects the last statement.

        The batch is committed only if every statement succeeds. On any
        error it is rolled back before the exception propagates, which keeps
        the connection usable.

        Raises:
            RuntimeError: if ``connect()`` has not been called.
            psycopg2.Error: driver errors, re-raised after the rollback.
        """
        if not self._conn:
            raise RuntimeError("PostgreSQL 未连接")

        conn = self._conn
        result = QueryResult()
        statements = self._split_statements(sql)
        if not statements:
            return result

        def _run() -> QueryResult:
            try:
                with conn.cursor() as cursor:
                    for stmt in statements:
                        cursor.execute(stmt)
                        sql_type = self.classify_sql(stmt)
                        result.sql_type = sql_type

                        if sql_type == "SELECT":
                            if cursor.description:
                                result.columns = [col[0] for col in cursor.description]
                                result.rows = cursor.fetchall()
                            else:
                                result.columns = []
                                result.rows = []
                            result.row_count = len(result.rows)
                        else:
                            result.affected_rows = cursor.rowcount
                conn.commit()
            except Exception:
                # psycopg2 leaves a failed transaction open ("current transaction is
                # aborted"), so every later query on this connection would fail.
                # Rolling back also discards earlier statements of this batch.
                try:
                    conn.rollback()
                except psycopg2.Error:
                    pass  # connection already unusable; surface the original error
                raise
            return result

        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, _run)

    async def get_schema(self) -> dict:
        """Describe every column outside ``information_schema`` / ``pg_catalog``.

        Returns ``{"db_type": "postgresql", "tables": [{"name", "columns"}]}``.
        Tables in ``public`` keep their bare name. Tables in other schemas are
        reported as ``"schema.table"``, so same-named tables in different
        schemas are not merged. Primary-key membership comes from a separate
        subquery, so a column that also belongs to foreign-key or unique
        constraints is listed only once.

        Raises:
            RuntimeError: if ``connect()`` has not been called.
        """
        if not self._conn:
            raise RuntimeError("PostgreSQL 未连接")

        conn = self._conn

        def _run() -> dict:
            schema_sql = """
                SELECT
                    c.table_schema,
                    c.table_name,
                    c.column_name,
                    c.data_type,
                    c.character_maximum_length,
                    c.is_nullable,
                    c.column_default,
                    pk.column_name IS NOT NULL AS is_primary_key
                FROM information_schema.columns c
                LEFT JOIN (
                    SELECT DISTINCT kcu.table_schema, kcu.table_name, kcu.column_name
                    FROM information_schema.table_constraints tc
                    JOIN information_schema.key_column_usage kcu
                        ON tc.constraint_schema = kcu.constraint_schema
                        AND tc.constraint_name = kcu.constraint_name
                        AND tc.table_schema = kcu.table_schema
                        AND tc.table_name = kcu.table_name
                    WHERE tc.constraint_type = 'PRIMARY KEY'
                ) pk
                    ON c.table_schema = pk.table_schema
                    AND c.table_name = pk.table_name
                    AND c.column_name = pk.column_name
                WHERE c.table_schema NOT IN ('information_schema', 'pg_catalog')
                ORDER BY c.table_schema, c.table_name, c.ordinal_position
            """
            tables_dict: dict[str, list] = {}
            with conn.cursor() as cursor:
                cursor.execute(schema_sql)
                for (schema, table, column, data_type, max_len,
                     is_nullable, default, is_pk) in cursor.fetchall():
                    name = table if schema == "public" else f"{schema}.{table}"
                    col_type = f"{data_type}({max_len})" if max_len else data_type
                    tables_dict.setdefault(name, []).append({
                        "name": column,
                        "type": col_type,
                        "nullable": is_nullable == "YES",
                        "key": "PRI" if is_pk else "",
                        "default": str(default) if default is not None else None,
                        "comment": "",
                    })

            tables = [{"name": k, "columns": v} for k, v in tables_dict.items()]
            return {"db_type": "postgresql", "tables": tables}

        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, _run)

    async def test_connection(self) -> dict:
        """Connect, run ``SELECT 1``, then always disconnect.

        Returns the dict produced by ``format_connector_result``. Failures are
        reported with ``code="CONNECTION_FAILED"`` instead of being raised.
        """
        from .exceptions import format_connector_result

        try:
            await self.connect()
            conn = self._conn
            if conn:
                def _ping() -> None:
                    with conn.cursor() as cur:
                        cur.execute("SELECT 1")

                # The round-trip blocks, so keep it off the event loop.
                await asyncio.get_running_loop().run_in_executor(None, _ping)
            return format_connector_result(True, data={"message": "PostgreSQL 连接成功"}, db_type="postgresql")
        except Exception as e:
            return format_connector_result(False, error=str(e), db_type="postgresql", code="CONNECTION_FAILED")
        finally:
            await self.disconnect()
@@ -0,0 +1,132 @@
1
+ """Redis 连接器"""
2
+
3
+ import asyncio
4
+ from typing import Any
5
+
6
+ import redis.asyncio as aioredis
7
+
8
+ from .base import BaseConnector, QueryResult
9
+
10
+
11
class RedisConnector(BaseConnector):
    """Redis connector built on ``redis.asyncio``.

    Redis has no tables. ``get_schema`` samples keys, and ``execute`` runs
    one raw command typed in redis-cli style (``COMMAND arg1 arg2 ...``).
    """

    db_type = "redis"

    # Commands whose result is reported with sql_type "SELECT"; all others are "OTHER".
    _READ_COMMANDS = frozenset({
        "GET", "HGET", "HGETALL", "LRANGE", "SMEMBERS", "ZRANGE",
        "KEYS", "MGET", "TYPE", "TTL", "EXISTS", "STRLEN",
    })

    def __init__(self, host: str, port: int, user: str, password: str, database: str):
        # For Redis, `database` holds the numeric DB index (as a string).
        super().__init__(host, port, user, password, database)
        self._conn: aioredis.Redis | None = None

    async def connect(self) -> None:
        """Create the client for DB index ``database``.

        A non-digit or empty ``database`` selects DB 0. Empty user/password
        values are passed as ``None``. No command is issued here;
        ``test_connection`` sends PING to verify reachability.
        """
        db_num = int(self.database) if self.database and self.database.isdigit() else 0
        self._conn = aioredis.Redis(
            host=self.host,
            port=self.port,
            username=self.user or None,
            password=self.password or None,
            db=db_num,
            decode_responses=True,
            socket_connect_timeout=10,
        )

    async def disconnect(self) -> None:
        """Close and forget the client; a no-op when not connected."""
        if self._conn:
            await self._conn.aclose()
            self._conn = None

    async def get_schema(self) -> dict:
        """Sample up to 50 keys with SCAN (not the blocking KEYS) and report their types.

        ``key_count`` is the number of keys sampled (at most 50), not the
        database total. Errors are returned in an ``error`` field rather than
        raised.

        Raises:
            RuntimeError: if ``connect()`` has not been called.
        """
        if not self._conn:
            raise RuntimeError("Redis 未连接")

        try:
            all_keys = []
            async for k in self._conn.scan_iter(match="*", count=50):
                all_keys.append(k)
                if len(all_keys) >= 50:
                    break

            key_count = len(all_keys)
            keys = sorted(all_keys)
            types = {}
            for k in keys:
                try:
                    types[k] = await self._conn.type(k)
                except Exception:
                    # Best effort: one unreadable key should not hide the rest.
                    types[k] = "unknown"

            return {
                "db_type": "redis",
                "keys": [
                    {"name": k, "type": types.get(k, "unknown")}
                    for k in keys
                ],
                "key_count": key_count,
            }
        except Exception as e:
            return {"db_type": "redis", "keys": [], "error": str(e)}

    @staticmethod
    def _tokenize(command_text: str) -> list[str]:
        """Split a redis-cli style command line into arguments.

        Tokens are separated by whitespace. A token that starts with ``"`` or
        ``'`` extends to the next matching quote, and the quotes are stripped.
        For example, ``SET k "hello world"`` yields
        ``["SET", "k", "hello world"]``. Quotes elsewhere in a token,
        unterminated quotes and backslashes are kept literally; backslashes
        matter in Redis glob patterns.
        """
        tokens: list[str] = []
        i, n = 0, len(command_text)
        while i < n:
            ch = command_text[i]
            if ch.isspace():
                i += 1
                continue
            if ch in "\"'":
                end = command_text.find(ch, i + 1)
                if end != -1:
                    tokens.append(command_text[i + 1:end])
                    i = end + 1
                    continue
            j = i
            while j < n and not command_text[j].isspace():
                j += 1
            tokens.append(command_text[i:j])
            i = j
        return tokens

    async def execute(self, sql: str) -> QueryResult:
        """Run one raw Redis command such as ``HGETALL user:1``.

        The tokens are sent verbatim via ``execute_command`` instead of being
        mapped to a client method by name. This supports multi-word commands
        (``CONFIG GET maxmemory``) and raw argument lists (``ZADD k 1 m``), and
        it never reaches arbitrary client attributes.

        How the reply is stored in the result:
        - dict: rows of ``[key, value]``.
        - list/tuple/set: one row per element; sets are sorted for stable output.
        - int (including bool): ``affected_rows``.
        - anything else (e.g. str, float, None): a single ``result`` row.

        Command errors are not raised; they are returned as a single row
        under an ``error`` column.

        Raises:
            RuntimeError: if ``connect()`` has not been called.
        """
        if not self._conn:
            raise RuntimeError("Redis 未连接")

        result = QueryResult(sql_type="OTHER")

        parts = self._tokenize(sql.strip())
        if not parts:
            return result

        command = parts[0].upper()
        args = parts[1:]

        try:
            value = await self._conn.execute_command(command, *args)

            result.sql_type = "SELECT" if command in self._READ_COMMANDS else "OTHER"

            if isinstance(value, dict):
                result.columns = ["key", "value"]
                result.rows = [[k, v] for k, v in value.items()]
                result.row_count = len(value)
            elif isinstance(value, (list, tuple, set)):
                items = sorted(value, key=str) if isinstance(value, set) else list(value)
                result.columns = ["result"]
                result.rows = [[v] for v in items]
                result.row_count = len(items)
            elif isinstance(value, int):
                result.affected_rows = int(value)
            else:
                # Floats (e.g. ZSCORE) are values, not counts; show them verbatim.
                result.columns = ["result"]
                result.rows = [[str(value)]]
                result.row_count = 1

        except Exception as e:
            result.columns = ["error"]
            result.rows = [[str(e)]]
            result.row_count = 1

        return result

    async def test_connection(self) -> dict:
        """Connect, send PING, then always disconnect.

        Returns the dict produced by ``format_connector_result``. Failures are
        reported with ``code="CONNECTION_FAILED"`` instead of being raised.
        """
        from .exceptions import format_connector_result

        try:
            await self.connect()
            if self._conn:
                await self._conn.ping()
            return format_connector_result(True, data={"message": "Redis 连接成功"}, db_type="redis")
        except Exception as e:
            return format_connector_result(False, error=str(e), db_type="redis", code="CONNECTION_FAILED")
        finally:
            await self.disconnect()
@@ -0,0 +1,140 @@
1
+ """SQL Server 连接器"""
2
+
3
+ import asyncio
4
+ from typing import Any
5
+
6
+ from .base import BaseConnector, QueryResult
7
+
8
+
9
class SQLServerConnector(BaseConnector):
    """SQL Server database connector (uses pymssql).

    pymssql is imported inside ``connect()``, so it is only needed when this
    connector is used. Blocking driver calls run in the default executor.
    """

    db_type = "sqlserver"

    def __init__(self, host: str, port: int, user: str, password: str, database: str):
        super().__init__(host, port, user, password, database)
        self._conn = None  # pymssql connection once connect() succeeds

    async def connect(self) -> None:
        """Open a pymssql connection in a worker thread.

        The connection is opened with ``timeout=10`` and ``login_timeout=10``.

        Raises:
            ImportError: if pymssql is not installed.
        """
        try:
            import pymssql
        except ImportError as exc:
            raise ImportError(
                "请安装 pymssql: uv pip install pymssql 或 pip install pymssql"
            ) from exc

        loop = asyncio.get_running_loop()
        self._conn = await loop.run_in_executor(
            None,
            lambda: pymssql.connect(
                server=self.host,
                port=self.port,
                user=self.user,
                password=self.password,
                database=self.database,
                timeout=10,
                login_timeout=10,
            ),
        )

    async def disconnect(self) -> None:
        """Close and forget the connection; a no-op when not connected."""
        if self._conn:
            self._conn.close()
            self._conn = None

    @staticmethod
    def _split_statements(sql: str) -> list[str]:
        """Split *sql* into statements on top-level ``;`` only.

        A ``;`` inside any of the following is not a separator:
        ``'...'`` literals, ``"..."`` and ``[...]`` identifiers (doubled
        closers are escapes), ``--`` line comments and ``/* ... */`` block
        comments. Pieces containing only whitespace and/or comments are
        dropped.
        """
        closers = {"'": "'", '"': '"', "[": "]"}
        statements: list[str] = []
        start = 0  # index where the statement being scanned begins
        has_code = False  # does the current piece contain more than comments/whitespace?
        i, n = 0, len(sql)
        while i < n:
            ch = sql[i]
            if ch in closers:
                close = closers[ch]
                end = sql.find(close, i + 1)
                # A doubled closer ('' / "" / ]]) is an escaped character, not the end.
                while end != -1 and sql.startswith(close, end + 1):
                    end = sql.find(close, end + 2)
                i = n if end == -1 else end + 1
                has_code = True
            elif sql.startswith("--", i):
                end = sql.find("\n", i)
                i = n if end == -1 else end + 1
            elif sql.startswith("/*", i):
                end = sql.find("*/", i + 2)
                i = n if end == -1 else end + 2
            elif ch == ";":
                if has_code:
                    statements.append(sql[start:i].strip())
                start = i + 1
                has_code = False
                i += 1
            else:
                if not ch.isspace():
                    has_code = True
                i += 1
        if has_code:
            statements.append(sql[start:].strip())
        return statements

    async def execute(self, sql: str) -> QueryResult:
        """Execute one or more ``;``-separated statements as a single transaction.

        Statements run in order on one cursor. For a statement that
        ``classify_sql`` reports as ``SELECT``, the column names and all rows
        are captured. For any other statement, ``affected_rows`` is taken from
        ``cursor.rowcount``. Each statement overwrites ``sql_type``, so the
        result mostly reflects the last statement.

        The batch is committed only if every statement succeeds. On any error
        it is rolled back, so no locks or partial work linger into the next
        call's commit.

        Raises:
            RuntimeError: if ``connect()`` has not been called.
        """
        if not self._conn:
            raise RuntimeError("SQL Server 未连接")

        conn = self._conn
        result = QueryResult()
        statements = self._split_statements(sql)
        if not statements:
            return result

        def _run() -> QueryResult:
            cursor = conn.cursor()
            try:
                for stmt in statements:
                    cursor.execute(stmt)
                    sql_type = self.classify_sql(stmt)
                    result.sql_type = sql_type

                    if sql_type == "SELECT":
                        if cursor.description:
                            result.columns = [col[0] for col in cursor.description]
                            result.rows = cursor.fetchall()
                        else:
                            result.columns = []
                            result.rows = []
                        result.row_count = len(result.rows)
                    else:
                        result.affected_rows = cursor.rowcount

                conn.commit()
                return result
            except Exception:
                try:
                    conn.rollback()
                except Exception:
                    pass  # connection already unusable; surface the original error
                raise
            finally:
                cursor.close()

        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, _run)

    async def get_schema(self) -> dict:
        """Describe every column of the ``dbo`` schema.

        Returns ``{"db_type": "sqlserver", "tables": [{"name", "columns"}]}``.
        The primary-key lookup matches on schema as well as table and column
        name, so same-named tables in other schemas cannot produce false
        ``PRI`` flags or duplicate columns. Lengths reported as ``-1`` (the
        ``max`` size) are shown as ``(max)`` for varchar/nvarchar/varbinary
        and omitted otherwise.

        Raises:
            RuntimeError: if ``connect()`` has not been called.
        """
        if not self._conn:
            raise RuntimeError("SQL Server 未连接")

        conn = self._conn

        def _run() -> dict:
            schema_sql = """
                SELECT
                    c.TABLE_NAME,
                    c.COLUMN_NAME,
                    c.DATA_TYPE,
                    c.CHARACTER_MAXIMUM_LENGTH,
                    CASE WHEN c.IS_NULLABLE = 'YES' THEN 1 ELSE 0 END,
                    c.COLUMN_DEFAULT,
                    CASE WHEN pk.COLUMN_NAME IS NOT NULL THEN 'PRI' ELSE '' END
                FROM INFORMATION_SCHEMA.COLUMNS c
                LEFT JOIN (
                    SELECT ku.TABLE_SCHEMA, ku.TABLE_NAME, ku.COLUMN_NAME
                    FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS tc
                    JOIN INFORMATION_SCHEMA.KEY_COLUMN_USAGE ku
                        ON tc.CONSTRAINT_SCHEMA = ku.CONSTRAINT_SCHEMA
                        AND tc.CONSTRAINT_NAME = ku.CONSTRAINT_NAME
                    WHERE tc.CONSTRAINT_TYPE = 'PRIMARY KEY'
                ) pk
                    ON c.TABLE_SCHEMA = pk.TABLE_SCHEMA
                    AND c.TABLE_NAME = pk.TABLE_NAME
                    AND c.COLUMN_NAME = pk.COLUMN_NAME
                WHERE c.TABLE_SCHEMA = 'dbo'
                ORDER BY c.TABLE_NAME, c.ORDINAL_POSITION
            """
            tables_dict: dict[str, list] = {}
            cursor = conn.cursor()
            try:
                cursor.execute(schema_sql)
                for row in cursor.fetchall():
                    tname = row[0]
                    col_type = row[2]
                    length = row[3]
                    if length == -1:
                        if col_type in ("varchar", "nvarchar", "varbinary"):
                            col_type = f"{col_type}(max)"
                    elif length:
                        col_type = f"{col_type}({length})"
                    tables_dict.setdefault(tname, []).append({
                        "name": row[1],
                        "type": col_type,
                        "nullable": bool(row[4]),
                        "key": row[6],
                        "default": str(row[5]) if row[5] is not None else None,
                        "comment": "",
                    })
            finally:
                cursor.close()

            tables = [{"name": k, "columns": v} for k, v in tables_dict.items()]
            return {"db_type": "sqlserver", "tables": tables}

        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, _run)

    async def test_connection(self) -> dict:
        """Try to connect, then always disconnect.

        Returns the dict produced by ``format_connector_result``. Failures are
        reported with ``code="CONNECTION_FAILED"`` instead of being raised.
        """
        from .exceptions import format_connector_result

        try:
            await self.connect()
            return format_connector_result(True, data={"message": "SQL Server 连接成功"}, db_type="sqlserver")
        except Exception as e:
            return format_connector_result(False, error=str(e), db_type="sqlserver", code="CONNECTION_FAILED")
        finally:
            await self.disconnect()
@@ -0,0 +1,290 @@
1
+ """查询历史管理 - 使用 SQLite 存储"""
2
+
3
+ import json
4
+ import asyncio
5
+ from datetime import datetime
6
+ from pathlib import Path
7
+ from typing import Optional
8
+
9
+ import aiosqlite
10
+
11
+ from ..config import DEFAULT_CONFIG_DIR
12
+
13
# SQLite file holding conversations and query history; HistoryManager uses it
# when no explicit db_path is given.
HISTORY_DB_PATH = DEFAULT_CONFIG_DIR / "history.db"
14
+
15
+
16
class HistoryManager:
    """Query-history store in SQLite (via aiosqlite), grouped into conversations.

    The connection is opened lazily on first use, and the tables/indexes are
    created if missing. Call ``close()`` to release the connection.
    """

    def __init__(self, db_path: Optional[Path] = None):
        # Falls back to HISTORY_DB_PATH when db_path is None (or otherwise falsy).
        self.db_path = db_path or HISTORY_DB_PATH
        self._db: Optional[aiosqlite.Connection] = None

    async def _get_db(self) -> aiosqlite.Connection:
        """Return the shared connection, opening it and creating tables on first call."""
        if self._db is None:
            self.db_path.parent.mkdir(parents=True, exist_ok=True)
            self._db = await aiosqlite.connect(str(self.db_path))
            self._db.row_factory = aiosqlite.Row
            await self._init_tables()
        return self._db

    async def _init_tables(self) -> None:
        """Create the conversations/query_history tables and their indexes if absent."""
        db = self._db
        # conversations: one row per chat session
        await db.execute("""
            CREATE TABLE IF NOT EXISTS conversations (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                title TEXT NOT NULL,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        """)
        await db.execute("""
            CREATE INDEX IF NOT EXISTS idx_conversations_updated_at
            ON conversations(updated_at DESC)
        """)

        # query_history: one row per question/SQL pair, linked to a conversation.
        # NOTE: SQLite only enforces FOREIGN KEY / ON DELETE CASCADE when
        # PRAGMA foreign_keys=ON, which this connection never sets. Child rows are
        # therefore removed explicitly in delete_conversation()/clear_history().
        await db.execute("""
            CREATE TABLE IF NOT EXISTS query_history (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                conversation_id INTEGER,
                question TEXT NOT NULL,
                sql TEXT NOT NULL,
                result_json TEXT,
                db_type TEXT DEFAULT '',
                llm_provider TEXT DEFAULT '',
                success INTEGER DEFAULT 1,
                error_message TEXT DEFAULT '',
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE
            )
        """)
        await db.execute("""
            CREATE INDEX IF NOT EXISTS idx_history_conversation_id
            ON query_history(conversation_id)
        """)
        await db.execute("""
            CREATE INDEX IF NOT EXISTS idx_history_created_at
            ON query_history(created_at DESC)
        """)
        await db.commit()

    # ============ Conversation management ============

    async def create_conversation(self, title: str = "") -> int:
        """Create a conversation (empty title -> default name) and return its id."""
        db = await self._get_db()
        if not title:
            title = "未命名对话"

        cursor = await db.execute(
            """INSERT INTO conversations (title, created_at, updated_at)
               VALUES (?, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)""",
            (title,),
        )
        await db.commit()
        return cursor.lastrowid

    async def get_conversations(self, limit: int = 50, offset: int = 0) -> list[dict]:
        """Return conversations, most recently updated first, with their message counts.

        ``id`` breaks ties, because ``CURRENT_TIMESTAMP`` only has
        one-second resolution.
        """
        db = await self._get_db()
        cursor = await db.execute(
            """SELECT c.id, c.title, c.created_at, c.updated_at,
                      COUNT(q.id) as message_count
               FROM conversations c
               LEFT JOIN query_history q ON c.id = q.conversation_id
               GROUP BY c.id
               ORDER BY c.updated_at DESC, c.id DESC
               LIMIT ? OFFSET ?""",
            (limit, offset),
        )
        rows = await cursor.fetchall()
        return [dict(row) for row in rows]

    async def get_conversation(self, conversation_id: int) -> Optional[dict]:
        """Return one conversation with its message count, or ``None`` if it does not exist."""
        db = await self._get_db()
        cursor = await db.execute(
            """SELECT c.id, c.title, c.created_at, c.updated_at,
                      COUNT(q.id) as message_count
               FROM conversations c
               LEFT JOIN query_history q ON c.id = q.conversation_id
               WHERE c.id = ?
               GROUP BY c.id""",
            (conversation_id,),
        )
        row = await cursor.fetchone()
        return dict(row) if row else None

    async def update_conversation(self, conversation_id: int, title: str) -> bool:
        """Rename a conversation and bump ``updated_at``; return whether a row changed."""
        db = await self._get_db()
        cursor = await db.execute(
            """UPDATE conversations SET title = ?, updated_at = CURRENT_TIMESTAMP
               WHERE id = ?""",
            (title, conversation_id),
        )
        await db.commit()
        return cursor.rowcount > 0

    async def delete_conversation(self, conversation_id: int) -> bool:
        """Delete a conversation and all of its query records.

        Returns whether the conversation row existed. Both deletes run in one
        transaction. The records are deleted explicitly because SQLite's
        ON DELETE CASCADE is inactive without ``PRAGMA foreign_keys=ON``.
        """
        db = await self._get_db()
        await db.execute(
            "DELETE FROM query_history WHERE conversation_id = ?",
            (conversation_id,),
        )
        cursor = await db.execute(
            "DELETE FROM conversations WHERE id = ?",
            (conversation_id,),
        )
        await db.commit()
        return cursor.rowcount > 0

    async def get_conversation_messages(self, conversation_id: int) -> list[dict]:
        """Return every record of a conversation in chronological (insertion) order."""
        db = await self._get_db()
        cursor = await db.execute(
            """SELECT q.id, q.question, q.sql, q.result_json, q.db_type,
                      q.llm_provider, q.success, q.error_message, q.created_at
               FROM query_history q
               WHERE q.conversation_id = ?
               ORDER BY q.created_at ASC, q.id ASC""",
            (conversation_id,),
        )
        rows = await cursor.fetchall()
        return [dict(row) for row in rows]

    # ============ Query history ============

    async def add_record(
        self,
        question: str,
        sql: str,
        result: Optional[dict] = None,
        db_type: str = "",
        llm_provider: str = "",
        success: bool = True,
        error_message: str = "",
        conversation_id: Optional[int] = None,
    ) -> int:
        """Store one query record and return its id.

        ``result`` is stored as JSON; ``default=str`` stringifies values that
        are not JSON-serializable. If ``conversation_id`` is truthy, that
        conversation's ``updated_at`` is bumped in the same transaction.
        """
        db = await self._get_db()
        result_json = json.dumps(result, ensure_ascii=False, default=str) if result else None

        cursor = await db.execute(
            """INSERT INTO query_history
               (conversation_id, question, sql, result_json, db_type, llm_provider, success, error_message)
               VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
            (conversation_id, question, sql, result_json, db_type, llm_provider, int(success), error_message),
        )
        record_id = cursor.lastrowid

        if conversation_id:
            await db.execute(
                "UPDATE conversations SET updated_at = CURRENT_TIMESTAMP WHERE id = ?",
                (conversation_id,),
            )
        await db.commit()

        return record_id

    async def get_records(
        self,
        limit: int = 50,
        offset: int = 0,
        conversation_id: Optional[int] = None,
    ) -> list[dict]:
        """Return records newest first, optionally limited to one conversation.

        ``id`` breaks ties between records created in the same second.
        """
        db = await self._get_db()
        if conversation_id:
            cursor = await db.execute(
                """SELECT id, conversation_id, question, sql, result_json, db_type, llm_provider,
                          success, error_message, created_at
                   FROM query_history
                   WHERE conversation_id = ?
                   ORDER BY created_at DESC, id DESC
                   LIMIT ? OFFSET ?""",
                (conversation_id, limit, offset),
            )
        else:
            cursor = await db.execute(
                """SELECT id, conversation_id, question, sql, result_json, db_type, llm_provider,
                          success, error_message, created_at
                   FROM query_history
                   ORDER BY created_at DESC, id DESC
                   LIMIT ? OFFSET ?""",
                (limit, offset),
            )
        rows = await cursor.fetchall()
        return [dict(row) for row in rows]

    async def get_record(self, record_id: int) -> Optional[dict]:
        """Return one record (all columns), or ``None`` if it does not exist."""
        db = await self._get_db()
        cursor = await db.execute(
            "SELECT * FROM query_history WHERE id = ?",
            (record_id,),
        )
        row = await cursor.fetchone()
        return dict(row) if row else None

    async def delete_record(self, record_id: int) -> bool:
        """Delete one record; return whether it existed."""
        db = await self._get_db()
        cursor = await db.execute(
            "DELETE FROM query_history WHERE id = ?",
            (record_id,),
        )
        await db.commit()
        return cursor.rowcount > 0

    async def clear_history(self) -> int:
        """Delete all records and all conversations in one transaction.

        Returns the number of query records deleted.
        """
        db = await self._get_db()
        cursor = await db.execute("DELETE FROM query_history")
        deleted = cursor.rowcount
        await db.execute("DELETE FROM conversations")
        await db.commit()
        return deleted

    async def get_count(self, conversation_id: Optional[int] = None) -> int:
        """Return the number of records, optionally within one conversation."""
        db = await self._get_db()
        if conversation_id:
            cursor = await db.execute(
                "SELECT COUNT(*) FROM query_history WHERE conversation_id = ?",
                (conversation_id,),
            )
        else:
            cursor = await db.execute("SELECT COUNT(*) FROM query_history")
        row = await cursor.fetchone()
        return row[0] if row else 0

    async def get_conversation_count(self) -> int:
        """Return the number of conversations."""
        db = await self._get_db()
        cursor = await db.execute("SELECT COUNT(*) FROM conversations")
        row = await cursor.fetchone()
        return row[0] if row else 0

    async def close(self) -> None:
        """Close the connection if open; the next call will reopen it lazily."""
        if self._db:
            await self._db.close()
            self._db = None
273
+
274
+
275
# Process-wide singleton: created lazily by get_history_manager() and reset to
# None by close_history_manager().
_history_manager: Optional[HistoryManager] = None
277
+
278
+
279
def get_history_manager() -> HistoryManager:
    """Return the shared HistoryManager, creating it on first use."""
    global _history_manager
    if _history_manager is not None:
        return _history_manager
    _history_manager = HistoryManager()
    return _history_manager
284
+
285
+
286
async def close_history_manager():
    """Close the shared HistoryManager (if any) and drop the reference to it."""
    global _history_manager
    if not _history_manager:
        return
    await _history_manager.close()
    _history_manager = None