sql-assistant 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sql_assistant/__init__.py +3 -0
- sql_assistant/api/__init__.py +1 -0
- sql_assistant/api/backup.py +116 -0
- sql_assistant/api/config.py +183 -0
- sql_assistant/api/conversation.py +71 -0
- sql_assistant/api/dependencies.py +22 -0
- sql_assistant/api/history.py +61 -0
- sql_assistant/api/models.py +221 -0
- sql_assistant/api/query.py +275 -0
- sql_assistant/api/routes.py +19 -0
- sql_assistant/api/schema.py +21 -0
- sql_assistant/config.py +144 -0
- sql_assistant/database/__init__.py +1 -0
- sql_assistant/database/backup.py +568 -0
- sql_assistant/database/connectors/__init__.py +1 -0
- sql_assistant/database/connectors/base.py +185 -0
- sql_assistant/database/connectors/exceptions.py +88 -0
- sql_assistant/database/connectors/mongodb.py +194 -0
- sql_assistant/database/connectors/mysql.py +110 -0
- sql_assistant/database/connectors/postgresql.py +133 -0
- sql_assistant/database/connectors/redis.py +132 -0
- sql_assistant/database/connectors/sqlserver.py +140 -0
- sql_assistant/database/history.py +290 -0
- sql_assistant/database/manager.py +178 -0
- sql_assistant/database/security.py +230 -0
- sql_assistant/llm/__init__.py +1 -0
- sql_assistant/llm/base.py +28 -0
- sql_assistant/llm/exceptions.py +96 -0
- sql_assistant/llm/manager.py +82 -0
- sql_assistant/llm/prompts.py +29 -0
- sql_assistant/llm/providers/__init__.py +1 -0
- sql_assistant/llm/providers/claude.py +132 -0
- sql_assistant/llm/providers/gemini.py +127 -0
- sql_assistant/llm/providers/openai_compatible.py +103 -0
- sql_assistant/llm/retry.py +88 -0
- sql_assistant/main.py +94 -0
- sql_assistant/settings.py +219 -0
- sql_assistant/web/__init__.py +1 -0
- sql_assistant/web/static/css/base.css +25 -0
- sql_assistant/web/static/css/components/backup.css +146 -0
- sql_assistant/web/static/css/components/chat.css +465 -0
- sql_assistant/web/static/css/components/modal.css +143 -0
- sql_assistant/web/static/css/components/settings.css +358 -0
- sql_assistant/web/static/css/components/sidebar.css +235 -0
- sql_assistant/web/static/css/components/toast.css +30 -0
- sql_assistant/web/static/css/style.css +10 -0
- sql_assistant/web/static/css/theme.css +200 -0
- sql_assistant/web/static/js/api.js +38 -0
- sql_assistant/web/static/js/app.js +161 -0
- sql_assistant/web/static/js/backup.js +216 -0
- sql_assistant/web/static/js/chat.js +238 -0
- sql_assistant/web/static/js/color-theme-manager.js +121 -0
- sql_assistant/web/static/js/confirm.js +95 -0
- sql_assistant/web/static/js/conversations.js +182 -0
- sql_assistant/web/static/js/settings.js +425 -0
- sql_assistant/web/static/js/state.js +43 -0
- sql_assistant/web/static/js/theme-manager.js +64 -0
- sql_assistant/web/static/js/ui.js +53 -0
- sql_assistant/web/templates/index.html +373 -0
- sql_assistant-1.0.0.dist-info/METADATA +24 -0
- sql_assistant-1.0.0.dist-info/RECORD +64 -0
- sql_assistant-1.0.0.dist-info/WHEEL +4 -0
- sql_assistant-1.0.0.dist-info/entry_points.txt +2 -0
- sql_assistant-1.0.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,133 @@
|
|
|
1
|
+
"""PostgreSQL 连接器"""
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
import psycopg2
|
|
7
|
+
import psycopg2.extras
|
|
8
|
+
|
|
9
|
+
from .base import BaseConnector, QueryResult
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class PostgreSQLConnector(BaseConnector):
    """PostgreSQL database connector built on the synchronous psycopg2 driver.

    Every blocking driver call is dispatched to the event loop's default
    executor so the loop itself is not blocked.
    """

    db_type = "postgresql"

    def __init__(self, host: str, port: int, user: str, password: str, database: str):
        super().__init__(host, port, user, password, database)
        self._conn: psycopg2.extensions.connection | None = None

    @staticmethod
    def _rollback_quietly(conn: "psycopg2.extensions.connection") -> None:
        """Roll back the current transaction, ignoring driver errors.

        Used on failure paths: after a failed statement psycopg2 leaves the
        connection in an aborted transaction and rejects every further
        statement until it is rolled back. If the rollback itself fails (e.g.
        the connection is already dead) the caller's original error is the
        useful one to surface, so the rollback error is ignored.
        """
        try:
            conn.rollback()
        except psycopg2.Error:
            pass

    async def connect(self) -> None:
        """Open a psycopg2 connection (10 s connect timeout) in a worker thread."""
        loop = asyncio.get_running_loop()
        self._conn = await loop.run_in_executor(
            None,
            lambda: psycopg2.connect(
                host=self.host,
                port=self.port,
                user=self.user,
                password=self.password,
                dbname=self.database,
                connect_timeout=10,
            ),
        )

    async def disconnect(self) -> None:
        """Close the connection if one is open and forget it."""
        if self._conn:
            self._conn.close()
            self._conn = None

    async def execute(self, sql: str) -> QueryResult:
        """Run the ';'-separated statements in *sql* and commit them together.

        Statements run in order on one cursor. Each statement overwrites the
        fields of the shared QueryResult it touches, so the result mostly
        reflects the last statement. For statements classified as "SELECT",
        columns/rows/row_count are filled; otherwise affected_rows is set.

        NOTE(review): splitting on ';' is naive. A semicolon inside a string
        literal or function body splits the statement incorrectly.

        Raises:
            RuntimeError: if not connected.
            psycopg2.Error: on statement failure. The transaction is rolled
                back first so the connection stays usable.
        """
        if not self._conn:
            raise RuntimeError("PostgreSQL 未连接")

        conn = self._conn
        result = QueryResult()

        def _run() -> QueryResult:
            statements = [s.strip() for s in sql.split(";") if s.strip()]
            if not statements:
                return result

            try:
                with conn.cursor() as cursor:
                    for stmt in statements:
                        cursor.execute(stmt)
                        sql_type = self.classify_sql(stmt)
                        result.sql_type = sql_type

                        if sql_type == "SELECT":
                            result.columns = [col[0] for col in cursor.description] if cursor.description else []
                            result.rows = cursor.fetchall() if cursor.description else []
                            result.row_count = len(result.rows)
                        else:
                            result.affected_rows = cursor.rowcount
                conn.commit()
            except Exception:
                # Without this, the connection would stay in an aborted
                # transaction and every later query would fail.
                self._rollback_quietly(conn)
                raise
            return result

        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, _run)

    async def get_schema(self) -> dict:
        """Describe every table/column outside the system schemas.

        Returns:
            ``{"db_type": "postgresql", "tables": [{"name", "columns"}]}``.
            Each column dict has name, type (with length when reported),
            nullable, key ("PRI" or ""), default and an empty comment.

        NOTE(review): tables are grouped by bare table name, so same-named
        tables from different schemas are merged into one entry.
        """
        if not self._conn:
            raise RuntimeError("PostgreSQL 未连接")

        conn = self._conn

        # Primary-key columns are resolved in a sub-query, joined on
        # schema + table + column. A column that also takes part in other
        # constraints (FK, UNIQUE) therefore still produces exactly one row,
        # and constraints of same-named objects in other schemas cannot match.
        schema_sql = """
            SELECT
                c.table_name,
                c.column_name,
                c.data_type,
                c.character_maximum_length,
                c.is_nullable,
                c.column_default,
                CASE WHEN pk.column_name IS NOT NULL THEN 'PRIMARY KEY' END
            FROM information_schema.columns c
            LEFT JOIN (
                SELECT kcu.table_schema, kcu.table_name, kcu.column_name
                FROM information_schema.table_constraints tc
                JOIN information_schema.key_column_usage kcu
                    ON tc.constraint_schema = kcu.constraint_schema
                    AND tc.constraint_name = kcu.constraint_name
                    AND tc.table_schema = kcu.table_schema
                    AND tc.table_name = kcu.table_name
                WHERE tc.constraint_type = 'PRIMARY KEY'
            ) pk
                ON c.table_schema = pk.table_schema
                AND c.table_name = pk.table_name
                AND c.column_name = pk.column_name
            WHERE c.table_schema NOT IN ('information_schema', 'pg_catalog')
            ORDER BY c.table_name, c.ordinal_position
        """

        def _run() -> dict:
            try:
                with conn.cursor() as cursor:
                    cursor.execute(schema_sql)
                    rows = cursor.fetchall()
            except Exception:
                self._rollback_quietly(conn)
                raise

            tables_dict: dict[str, list] = {}
            for row in rows:
                tname = row[0]
                col_type = row[2]
                if row[3]:
                    col_type = f"{col_type}({row[3]})"
                tables_dict.setdefault(tname, []).append({
                    "name": row[1],
                    "type": col_type,
                    "nullable": row[4] == "YES",
                    "key": "PRI" if row[6] == "PRIMARY KEY" else "",
                    "default": str(row[5]) if row[5] is not None else None,
                    "comment": "",
                })

            tables = [{"name": k, "columns": v} for k, v in tables_dict.items()]
            return {"db_type": "postgresql", "tables": tables}

        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, _run)

    async def test_connection(self) -> dict:
        """Connect, run ``SELECT 1`` and disconnect.

        Returns the ``format_connector_result`` payload. Failures are reported
        with code ``CONNECTION_FAILED`` instead of being raised.
        """
        from .exceptions import format_connector_result
        try:
            await self.connect()
            if self._conn:
                # The context manager closes the cursor even if the probe fails.
                with self._conn.cursor() as cur:
                    cur.execute("SELECT 1")
            return format_connector_result(True, data={"message": "PostgreSQL 连接成功"}, db_type="postgresql")
        except Exception as e:
            return format_connector_result(False, error=str(e), db_type="postgresql", code="CONNECTION_FAILED")
        finally:
            await self.disconnect()
|
|
@@ -0,0 +1,132 @@
|
|
|
1
|
+
"""Redis 连接器"""
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
import redis.asyncio as aioredis
|
|
7
|
+
|
|
8
|
+
from .base import BaseConnector, QueryResult
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class RedisConnector(BaseConnector):
    """Redis database connector built on redis-py's asyncio client.

    ``database`` is interpreted as the numeric Redis DB index. Any value that
    is not a plain digit string falls back to DB 0.
    """

    db_type = "redis"

    # Commands whose results are labelled "SELECT" (read) in the QueryResult.
    _READ_COMMANDS = frozenset({
        "GET", "HGET", "HGETALL", "LRANGE", "SMEMBERS", "ZRANGE",
        "KEYS", "MGET", "TYPE", "TTL", "EXISTS", "STRLEN",
    })

    def __init__(self, host: str, port: int, user: str, password: str, database: str):
        super().__init__(host, port, user, password, database)
        self._conn: aioredis.Redis | None = None

    async def connect(self) -> None:
        """Create the client object (decoded str responses, 10 s connect timeout).

        Constructing the client issues no command; ``test_connection`` pings
        to actually verify reachability.
        """
        db_num = int(self.database) if self.database and self.database.isdigit() else 0
        self._conn = aioredis.Redis(
            host=self.host,
            port=self.port,
            username=self.user or None,
            password=self.password or None,
            db=db_num,
            decode_responses=True,
            socket_connect_timeout=10,
        )

    async def disconnect(self) -> None:
        """Close the client if one exists and forget it."""
        if self._conn:
            await self._conn.aclose()
            self._conn = None

    async def get_schema(self) -> dict:
        """Return a sample of at most 50 keys (found with SCAN, not KEYS) and their types.

        ``key_count`` is the size of that capped sample, not the total number
        of keys. Errors are reported in an ``error`` field instead of raised.
        """
        if not self._conn:
            raise RuntimeError("Redis 未连接")

        try:
            all_keys = []
            # SCAN is incremental and does not block the server like KEYS.
            async for k in self._conn.scan_iter(match="*", count=50):
                all_keys.append(k)
                if len(all_keys) >= 50:
                    break

            key_count = len(all_keys)
            keys = sorted(all_keys)
            key_types = {}
            for k in keys:
                try:
                    key_types[k] = await self._conn.type(k)
                except Exception:
                    # Best effort: one unreadable key must not fail the whole schema.
                    key_types[k] = "unknown"

            return {
                "db_type": "redis",
                "keys": [
                    {"name": k, "type": key_types.get(k, "unknown")}
                    for k in keys
                ],
                "key_count": key_count,
            }
        except Exception as e:
            return {"db_type": "redis", "keys": [], "error": str(e)}

    async def execute(self, sql: str) -> QueryResult:
        """Run one raw Redis command line, e.g. ``HGETALL user:1``.

        The line is split on whitespace (quoted arguments are not supported)
        and sent verbatim with ``execute_command``. Sending it over the
        protocol, rather than calling a client method named after the command:

        * makes commands whose Python method name differs, or that take
          options, work (``DEL k``, ``SET k v EX 10``, ``CONFIG GET ...``);
        * stops command text from invoking arbitrary client attributes
          (``aclose``, ``pipeline``, ...).

        Redis errors (including unknown commands) are returned as a single
        ``error`` row instead of being raised.
        """
        if not self._conn:
            raise RuntimeError("Redis 未连接")

        result = QueryResult(sql_type="OTHER")

        # Native Redis syntax: COMMAND arg1 arg2 ...
        parts = sql.strip().split()
        if not parts:
            return result

        command = parts[0].upper()
        args = parts[1:]

        try:
            value = await self._conn.execute_command(command, *args)

            result.sql_type = "SELECT" if command in self._READ_COMMANDS else "OTHER"

            if isinstance(value, (set, frozenset)):
                # redis-py parses set replies (e.g. SMEMBERS) into Python sets.
                # List them in a stable order instead of str()-ing the whole set.
                value = sorted(value, key=str)

            if isinstance(value, list):
                result.columns = ["result"]
                result.rows = [[v] for v in value]
                result.row_count = len(value)
            elif isinstance(value, dict):
                result.columns = ["key", "value"]
                result.rows = [[k, v] for k, v in value.items()]
                result.row_count = len(value)
            elif isinstance(value, (int, float)):
                # Also covers bools (e.g. SET -> True), since bool is an int.
                result.affected_rows = int(value)
            else:
                result.columns = ["result"]
                result.rows = [[str(value)]]
                result.row_count = 1

        except Exception as e:
            result.columns = ["error"]
            result.rows = [[str(e)]]
            result.row_count = 1

        return result

    async def test_connection(self) -> dict:
        """Create the client, PING the server and close it again.

        Returns the ``format_connector_result`` payload. Failures are reported
        with code ``CONNECTION_FAILED`` instead of being raised.
        """
        from .exceptions import format_connector_result
        try:
            await self.connect()
            if self._conn:
                await self._conn.ping()
            return format_connector_result(True, data={"message": "Redis 连接成功"}, db_type="redis")
        except Exception as e:
            return format_connector_result(False, error=str(e), db_type="redis", code="CONNECTION_FAILED")
        finally:
            await self.disconnect()
|
|
@@ -0,0 +1,140 @@
|
|
|
1
|
+
"""SQL Server 连接器"""
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
from .base import BaseConnector, QueryResult
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class SQLServerConnector(BaseConnector):
    """SQL Server database connector (pymssql, imported lazily in ``connect``).

    Blocking driver calls run in the event loop's default executor.
    """

    db_type = "sqlserver"

    def __init__(self, host: str, port: int, user: str, password: str, database: str):
        super().__init__(host, port, user, password, database)
        self._conn = None

    @staticmethod
    def _rollback_quietly(conn) -> None:
        """Roll back the pending transaction, ignoring rollback failures.

        ``execute`` commits only after all statements succeed. Without a
        rollback, the earlier statements of a batch that failed part-way
        would stay pending and be committed by the next successful
        ``execute``. A failing rollback (e.g. dead connection) is ignored so
        the caller sees the original error.
        """
        try:
            conn.rollback()
        except Exception:
            pass

    @staticmethod
    def _format_type(data_type, max_length) -> str:
        """Render a column type with its length, e.g. ``nvarchar(50)``.

        INFORMATION_SCHEMA reports ``-1`` for ``(n)varchar(max)`` /
        ``varbinary(max)`` columns. That is rendered as ``(max)``, not ``(-1)``.
        """
        if max_length == -1:
            return f"{data_type}(max)"
        if max_length:
            return f"{data_type}({max_length})"
        return data_type

    async def connect(self) -> None:
        """Open a pymssql connection (10 s query and login timeouts) in a worker thread.

        Raises:
            ImportError: if pymssql is not installed.
        """
        try:
            import pymssql
        except ImportError as err:
            raise ImportError(
                "请安装 pymssql: uv pip install pymssql 或 pip install pymssql"
            ) from err

        loop = asyncio.get_running_loop()
        self._conn = await loop.run_in_executor(
            None,
            lambda: pymssql.connect(
                server=self.host,
                port=self.port,
                user=self.user,
                password=self.password,
                database=self.database,
                timeout=10,
                login_timeout=10,
            ),
        )

    async def disconnect(self) -> None:
        """Close the connection if one is open and forget it."""
        if self._conn:
            self._conn.close()
            self._conn = None

    async def execute(self, sql: str) -> QueryResult:
        """Run the ';'-separated statements in *sql* and commit them together.

        Each statement overwrites the QueryResult fields it touches, so the
        result mostly reflects the last statement. On any failure the
        transaction is rolled back before the error propagates.

        NOTE(review): splitting on ';' is naive (semicolons inside literals
        split statements).
        """
        if not self._conn:
            raise RuntimeError("SQL Server 未连接")

        conn = self._conn
        result = QueryResult()

        def _run() -> QueryResult:
            cursor = conn.cursor()
            try:
                statements = [s.strip() for s in sql.split(";") if s.strip()]
                if not statements:
                    return result

                for stmt in statements:
                    cursor.execute(stmt)
                    sql_type = self.classify_sql(stmt)
                    result.sql_type = sql_type

                    if sql_type == "SELECT":
                        result.columns = [col[0] for col in cursor.description] if cursor.description else []
                        result.rows = cursor.fetchall() if cursor.description else []
                        result.row_count = len(result.rows)
                    else:
                        result.affected_rows = cursor.rowcount

                conn.commit()
                return result
            except Exception:
                self._rollback_quietly(conn)
                raise
            finally:
                cursor.close()

        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, _run)

    async def get_schema(self) -> dict:
        """Describe every column of every table in the ``dbo`` schema.

        Returns:
            ``{"db_type": "sqlserver", "tables": [{"name", "columns"}]}``.
            Column dicts carry name, type, nullable, key ("PRI" or ""),
            default and an empty comment.
        """
        if not self._conn:
            raise RuntimeError("SQL Server 未连接")

        conn = self._conn

        # The PK sub-query and its join are scoped by schema, so a same-named
        # table with a PK in another schema cannot duplicate dbo columns.
        schema_sql = """
            SELECT
                c.TABLE_NAME,
                c.COLUMN_NAME,
                c.DATA_TYPE,
                c.CHARACTER_MAXIMUM_LENGTH,
                CASE WHEN c.IS_NULLABLE = 'YES' THEN 1 ELSE 0 END,
                c.COLUMN_DEFAULT,
                CASE WHEN pk.COLUMN_NAME IS NOT NULL THEN 'PRI' ELSE '' END
            FROM INFORMATION_SCHEMA.COLUMNS c
            LEFT JOIN (
                SELECT ku.TABLE_SCHEMA, ku.TABLE_NAME, ku.COLUMN_NAME
                FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS tc
                JOIN INFORMATION_SCHEMA.KEY_COLUMN_USAGE ku
                    ON tc.CONSTRAINT_SCHEMA = ku.CONSTRAINT_SCHEMA
                    AND tc.CONSTRAINT_NAME = ku.CONSTRAINT_NAME
                WHERE tc.CONSTRAINT_TYPE = 'PRIMARY KEY'
            ) pk
                ON c.TABLE_SCHEMA = pk.TABLE_SCHEMA
                AND c.TABLE_NAME = pk.TABLE_NAME
                AND c.COLUMN_NAME = pk.COLUMN_NAME
            WHERE c.TABLE_SCHEMA = 'dbo'
            ORDER BY c.TABLE_NAME, c.ORDINAL_POSITION
        """

        def _run() -> dict:
            cursor = conn.cursor()
            try:
                cursor.execute(schema_sql)
                rows = cursor.fetchall()
            except Exception:
                self._rollback_quietly(conn)
                raise
            finally:
                cursor.close()

            tables_dict: dict[str, list] = {}
            for row in rows:
                tables_dict.setdefault(row[0], []).append({
                    "name": row[1],
                    "type": self._format_type(row[2], row[3]),
                    "nullable": bool(row[4]),
                    "key": row[6],
                    "default": str(row[5]) if row[5] is not None else None,
                    "comment": "",
                })

            tables = [{"name": k, "columns": v} for k, v in tables_dict.items()]
            return {"db_type": "sqlserver", "tables": tables}

        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, _run)

    async def test_connection(self) -> dict:
        """Try to open and close a connection.

        Returns the ``format_connector_result`` payload. Failures are reported
        with code ``CONNECTION_FAILED`` instead of being raised.
        """
        from .exceptions import format_connector_result
        try:
            await self.connect()
            return format_connector_result(True, data={"message": "SQL Server 连接成功"}, db_type="sqlserver")
        except Exception as e:
            return format_connector_result(False, error=str(e), db_type="sqlserver", code="CONNECTION_FAILED")
        finally:
            await self.disconnect()
|
|
@@ -0,0 +1,290 @@
|
|
|
1
|
+
"""查询历史管理 - 使用 SQLite 存储"""
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import asyncio
|
|
5
|
+
from datetime import datetime
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Optional
|
|
8
|
+
|
|
9
|
+
import aiosqlite
|
|
10
|
+
|
|
11
|
+
from ..config import DEFAULT_CONFIG_DIR
|
|
12
|
+
|
|
13
|
+
# Default SQLite file for conversations and query history, placed inside
# DEFAULT_CONFIG_DIR (imported from the package-level config module).
HISTORY_DB_PATH = DEFAULT_CONFIG_DIR / "history.db"
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class HistoryManager:
    """SQLite-backed store for query history, grouped into conversations.

    The aiosqlite connection is opened lazily on first use, at which point
    the tables and indexes are created if missing.
    """

    def __init__(self, db_path: Optional[Path] = None):
        """
        Args:
            db_path: SQLite file to use. Defaults to ``HISTORY_DB_PATH``.
        """
        self.db_path = db_path or HISTORY_DB_PATH
        self._db: Optional[aiosqlite.Connection] = None

    async def _get_db(self) -> aiosqlite.Connection:
        """Return the shared connection, opening it and creating tables on first call."""
        if self._db is None:
            self.db_path.parent.mkdir(parents=True, exist_ok=True)
            self._db = await aiosqlite.connect(str(self.db_path))
            # aiosqlite.Row lets callers build dicts with dict(row).
            self._db.row_factory = aiosqlite.Row
            await self._init_tables()
        return self._db

    async def _init_tables(self):
        """Create the conversations / query_history tables and their indexes if absent."""
        db = self._db
        # Conversation sessions.
        await db.execute("""
            CREATE TABLE IF NOT EXISTS conversations (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                title TEXT NOT NULL,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        """)
        await db.execute("""
            CREATE INDEX IF NOT EXISTS idx_conversations_updated_at
            ON conversations(updated_at DESC)
        """)

        # Query history. conversation_id links a record to its conversation.
        await db.execute("""
            CREATE TABLE IF NOT EXISTS query_history (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                conversation_id INTEGER,
                question TEXT NOT NULL,
                sql TEXT NOT NULL,
                result_json TEXT,
                db_type TEXT DEFAULT '',
                llm_provider TEXT DEFAULT '',
                success INTEGER DEFAULT 1,
                error_message TEXT DEFAULT '',
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE
            )
        """)
        await db.execute("""
            CREATE INDEX IF NOT EXISTS idx_history_conversation_id
            ON query_history(conversation_id)
        """)
        await db.execute("""
            CREATE INDEX IF NOT EXISTS idx_history_created_at
            ON query_history(created_at DESC)
        """)
        await db.commit()

    # ============ Conversation management ============

    async def create_conversation(self, title: str = "") -> int:
        """Create a conversation and return its id. An empty title gets a default."""
        db = await self._get_db()
        if not title:
            title = "未命名对话"

        cursor = await db.execute(
            """INSERT INTO conversations (title, created_at, updated_at)
               VALUES (?, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)""",
            (title,),
        )
        await db.commit()
        return cursor.lastrowid

    async def get_conversations(self, limit: int = 50, offset: int = 0) -> list[dict]:
        """List conversations, most recently updated first, with their message counts."""
        db = await self._get_db()
        # CURRENT_TIMESTAMP has one-second resolution, so id breaks ties and
        # keeps paging stable.
        cursor = await db.execute(
            """SELECT c.id, c.title, c.created_at, c.updated_at,
                      COUNT(q.id) as message_count
               FROM conversations c
               LEFT JOIN query_history q ON c.id = q.conversation_id
               GROUP BY c.id
               ORDER BY c.updated_at DESC, c.id DESC
               LIMIT ? OFFSET ?""",
            (limit, offset),
        )
        rows = await cursor.fetchall()
        return [dict(row) for row in rows]

    async def get_conversation(self, conversation_id: int) -> Optional[dict]:
        """Return one conversation with its message count, or None if it does not exist."""
        db = await self._get_db()
        cursor = await db.execute(
            """SELECT c.id, c.title, c.created_at, c.updated_at,
                      COUNT(q.id) as message_count
               FROM conversations c
               LEFT JOIN query_history q ON c.id = q.conversation_id
               WHERE c.id = ?
               GROUP BY c.id""",
            (conversation_id,),
        )
        row = await cursor.fetchone()
        return dict(row) if row else None

    async def update_conversation(self, conversation_id: int, title: str) -> bool:
        """Rename a conversation and bump updated_at. Returns True if a row changed."""
        db = await self._get_db()
        cursor = await db.execute(
            """UPDATE conversations SET title = ?, updated_at = CURRENT_TIMESTAMP
               WHERE id = ?""",
            (title, conversation_id),
        )
        await db.commit()
        return cursor.rowcount > 0

    async def delete_conversation(self, conversation_id: int) -> bool:
        """Delete a conversation and all of its query records, in one commit.

        Returns:
            True if the conversation existed.
        """
        db = await self._get_db()
        # SQLite only enforces FOREIGN KEY ... ON DELETE CASCADE when
        # "PRAGMA foreign_keys = ON" is set on the connection, and this
        # manager never sets it. The child rows are therefore removed explicitly.
        await db.execute(
            "DELETE FROM query_history WHERE conversation_id = ?",
            (conversation_id,),
        )
        cursor = await db.execute(
            "DELETE FROM conversations WHERE id = ?",
            (conversation_id,),
        )
        await db.commit()
        return cursor.rowcount > 0

    async def get_conversation_messages(self, conversation_id: int) -> list[dict]:
        """Return all records of a conversation in chronological order."""
        db = await self._get_db()
        # id breaks ties between records created within the same second.
        cursor = await db.execute(
            """SELECT q.id, q.question, q.sql, q.result_json, q.db_type,
                      q.llm_provider, q.success, q.error_message, q.created_at
               FROM query_history q
               WHERE q.conversation_id = ?
               ORDER BY q.created_at ASC, q.id ASC""",
            (conversation_id,),
        )
        rows = await cursor.fetchall()
        return [dict(row) for row in rows]

    # ============ Query history ============

    async def add_record(
        self,
        question: str,
        sql: str,
        result: Optional[dict] = None,
        db_type: str = "",
        llm_provider: str = "",
        success: bool = True,
        error_message: str = "",
        conversation_id: Optional[int] = None,
    ) -> int:
        """Insert a query record and return its id.

        ``result`` is stored as JSON. Non-JSON values are stringified, and an
        empty or missing result is stored as NULL. When the record belongs to
        a conversation, the conversation's updated_at is bumped in the same
        commit, so both changes land together.
        """
        db = await self._get_db()
        result_json = json.dumps(result, ensure_ascii=False, default=str) if result else None

        cursor = await db.execute(
            """INSERT INTO query_history
               (conversation_id, question, sql, result_json, db_type, llm_provider, success, error_message)
               VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
            (conversation_id, question, sql, result_json, db_type, llm_provider, int(success), error_message),
        )

        if conversation_id:
            await db.execute(
                "UPDATE conversations SET updated_at = CURRENT_TIMESTAMP WHERE id = ?",
                (conversation_id,),
            )
        await db.commit()

        return cursor.lastrowid

    async def get_records(
        self,
        limit: int = 50,
        offset: int = 0,
        conversation_id: Optional[int] = None,
    ) -> list[dict]:
        """Return records newest first, optionally limited to one conversation."""
        db = await self._get_db()
        # id breaks ties between same-second timestamps so paging is stable.
        if conversation_id:
            cursor = await db.execute(
                """SELECT id, conversation_id, question, sql, result_json, db_type, llm_provider,
                          success, error_message, created_at
                   FROM query_history
                   WHERE conversation_id = ?
                   ORDER BY created_at DESC, id DESC
                   LIMIT ? OFFSET ?""",
                (conversation_id, limit, offset),
            )
        else:
            cursor = await db.execute(
                """SELECT id, conversation_id, question, sql, result_json, db_type, llm_provider,
                          success, error_message, created_at
                   FROM query_history
                   ORDER BY created_at DESC, id DESC
                   LIMIT ? OFFSET ?""",
                (limit, offset),
            )
        rows = await cursor.fetchall()
        return [dict(row) for row in rows]

    async def get_record(self, record_id: int) -> Optional[dict]:
        """Return one record with all columns, or None."""
        db = await self._get_db()
        cursor = await db.execute(
            "SELECT * FROM query_history WHERE id = ?",
            (record_id,),
        )
        row = await cursor.fetchone()
        return dict(row) if row else None

    async def delete_record(self, record_id: int) -> bool:
        """Delete one record. Returns True if it existed."""
        db = await self._get_db()
        cursor = await db.execute(
            "DELETE FROM query_history WHERE id = ?",
            (record_id,),
        )
        await db.commit()
        return cursor.rowcount > 0

    async def clear_history(self) -> int:
        """Delete every record and every conversation in one commit.

        Returns:
            The number of query records deleted.
        """
        db = await self._get_db()
        cursor = await db.execute("DELETE FROM query_history")
        await db.execute("DELETE FROM conversations")
        await db.commit()
        return cursor.rowcount

    async def get_count(self, conversation_id: Optional[int] = None) -> int:
        """Count records, optionally within one conversation."""
        db = await self._get_db()
        if conversation_id:
            cursor = await db.execute(
                "SELECT COUNT(*) FROM query_history WHERE conversation_id = ?",
                (conversation_id,),
            )
        else:
            cursor = await db.execute("SELECT COUNT(*) FROM query_history")
        row = await cursor.fetchone()
        return row[0] if row else 0

    async def get_conversation_count(self) -> int:
        """Count conversations."""
        db = await self._get_db()
        cursor = await db.execute("SELECT COUNT(*) FROM conversations")
        row = await cursor.fetchone()
        return row[0] if row else 0

    async def close(self):
        """Close the connection if open. The next call to _get_db reopens it."""
        if self._db:
            await self._db.close()
            self._db = None
|
|
273
|
+
|
|
274
|
+
|
|
275
|
+
# Module-level singleton, created lazily by get_history_manager().
_history_manager: Optional[HistoryManager] = None


def get_history_manager() -> HistoryManager:
    """Return the shared HistoryManager, constructing it on the first call."""
    global _history_manager
    if _history_manager is not None:
        return _history_manager
    _history_manager = HistoryManager()
    return _history_manager


async def close_history_manager():
    """Close the shared HistoryManager, if one exists, and drop the reference."""
    global _history_manager
    manager = _history_manager
    if not manager:
        return
    await manager.close()
    _history_manager = None
|