sqlsaber 0.16.0__py3-none-any.whl → 0.17.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlsaber might be problematic. Click here for more details.
- sqlsaber/agents/base.py +0 -108
- sqlsaber/cli/commands.py +34 -2
- sqlsaber/cli/display.py +0 -30
- sqlsaber/cli/interactive.py +76 -24
- sqlsaber/cli/streaming.py +0 -4
- sqlsaber/cli/threads.py +301 -0
- sqlsaber/database/schema.py +30 -2
- sqlsaber/threads/__init__.py +5 -0
- sqlsaber/threads/storage.py +303 -0
- sqlsaber/tools/__init__.py +0 -2
- sqlsaber/tools/base.py +0 -12
- sqlsaber/tools/enums.py +0 -2
- sqlsaber/tools/instructions.py +3 -23
- sqlsaber/tools/registry.py +0 -12
- {sqlsaber-0.16.0.dist-info → sqlsaber-0.17.0.dist-info}/METADATA +12 -3
- {sqlsaber-0.16.0.dist-info → sqlsaber-0.17.0.dist-info}/RECORD +19 -23
- sqlsaber/conversation/__init__.py +0 -12
- sqlsaber/conversation/manager.py +0 -224
- sqlsaber/conversation/models.py +0 -120
- sqlsaber/conversation/storage.py +0 -362
- sqlsaber/models/__init__.py +0 -10
- sqlsaber/models/types.py +0 -40
- sqlsaber/tools/visualization_tools.py +0 -144
- {sqlsaber-0.16.0.dist-info → sqlsaber-0.17.0.dist-info}/WHEEL +0 -0
- {sqlsaber-0.16.0.dist-info → sqlsaber-0.17.0.dist-info}/entry_points.txt +0 -0
- {sqlsaber-0.16.0.dist-info → sqlsaber-0.17.0.dist-info}/licenses/LICENSE +0 -0
sqlsaber/cli/threads.py
ADDED
|
@@ -0,0 +1,301 @@
|
|
|
1
|
+
"""Threads CLI: list, show, and resume threads (pydantic-ai message snapshots)."""
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import json
|
|
5
|
+
import time
|
|
6
|
+
from typing import Annotated
|
|
7
|
+
|
|
8
|
+
import cyclopts
|
|
9
|
+
from pydantic_ai.messages import ModelMessage
|
|
10
|
+
from rich.console import Console
|
|
11
|
+
from rich.markdown import Markdown
|
|
12
|
+
from rich.panel import Panel
|
|
13
|
+
from rich.table import Table
|
|
14
|
+
|
|
15
|
+
from sqlsaber.agents import build_sqlsaber_agent
|
|
16
|
+
from sqlsaber.cli.display import DisplayManager
|
|
17
|
+
from sqlsaber.cli.interactive import InteractiveSession
|
|
18
|
+
from sqlsaber.config.database import DatabaseConfigManager
|
|
19
|
+
from sqlsaber.database.connection import DatabaseConnection
|
|
20
|
+
from sqlsaber.database.resolver import DatabaseResolutionError, resolve_database
|
|
21
|
+
from sqlsaber.threads import ThreadStorage
|
|
22
|
+
|
|
23
|
+
# Module-level singletons shared by every command in this sub-app; the other
# CLI modules follow the same pattern.
console = Console()
config_manager = DatabaseConfigManager()

# Cyclopts sub-application grouping the thread commands defined below.
threads_app = cyclopts.App(name="threads", help="Manage SQLsaber threads")
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def _human_readable(timestamp: float | None) -> str:
|
|
35
|
+
if not timestamp:
|
|
36
|
+
return "-"
|
|
37
|
+
return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(timestamp))
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def _render_transcript(
    console: Console, all_msgs: list[ModelMessage], last_n: int | None = None
) -> None:
    """Render conversation turns from ModelMessage[] using DisplayManager.

    A turn starts at a message containing a ``user-prompt`` part and runs up to
    (but not including) the next such message. The first message of a turn is
    shown as a "User" panel; every following message is rendered part by part
    (assistant text, tool calls, tool results). A blank line separates turns.

    Args:
        console: Rich console to print to.
        all_msgs: Message history, oldest first.
        last_n: If a positive int, render only the last ``last_n`` turns;
            ``None`` or a non-positive value renders the whole transcript.
    """
    # Plain Text so raw tool output is never interpreted as Rich markup.
    from rich.text import Text

    dm = DisplayManager(console)

    def _is_user_prompt_message(message: ModelMessage) -> bool:
        """Return True if any part of ``message`` is a user prompt."""
        return any(
            getattr(part, "part_kind", "") == "user-prompt"
            for part in getattr(message, "parts", [])
        )

    def _print_tool_result(name: str, content_str: str) -> None:
        """Fallback display for a tool result: its raw text in a yellow panel."""
        console.print(
            Panel.fit(
                Text(content_str),
                title=f"Tool result: {name}",
                border_style="yellow",
            )
        )

    def _render_user(message: ModelMessage) -> None:
        """Show the first non-empty user prompt in ``message`` as a "User" panel."""
        for part in getattr(message, "parts", []):
            if getattr(part, "part_kind", "") != "user-prompt":
                continue
            content = getattr(part, "content", None)
            text: str | None = None
            if isinstance(content, str):
                text = content
            elif isinstance(content, list):  # multimodal prompt
                segments: list[str] = []
                for seg in content:
                    if isinstance(seg, str):
                        segments.append(seg)
                        continue
                    try:
                        segments.append(json.dumps(seg, ensure_ascii=False))
                    except (TypeError, ValueError):
                        # Objects json cannot encode (e.g. media segments).
                        segments.append(str(seg))
                text = "\n".join(s for s in segments if s) or None
            if text:
                console.print(
                    Panel.fit(Markdown(text), title="User", border_style="cyan")
                )
                return
        console.print(Panel.fit("(no content)", title="User", border_style="cyan"))

    def _tool_args(args: object) -> dict:
        """Normalize tool-call args (dict or JSON-object string) to a dict."""
        if isinstance(args, dict):
            return args
        if isinstance(args, str):
            try:
                parsed = json.loads(args)
            except ValueError:  # json.JSONDecodeError is a ValueError
                return {}
            if isinstance(parsed, dict):
                return parsed
        return {}

    def _render_tool_return(name: str, content: object) -> None:
        """Render a tool result using the DisplayManager view for that tool."""
        if isinstance(content, str):
            content_str = content
        elif isinstance(content, (dict, list)):
            try:
                content_str = json.dumps(content, ensure_ascii=False)
            except (TypeError, ValueError):
                # Non-serializable payload: show its text rather than aborting
                # the whole transcript.
                content_str = str(content)
        else:
            content_str = json.dumps({"return_value": str(content)})

        if name == "list_tables":
            dm.show_table_list(content_str)
        elif name == "introspect_schema":
            dm.show_schema_info(content_str)
        elif name == "execute_sql":
            try:
                data = json.loads(content_str)
                if (
                    isinstance(data, dict)
                    and data.get("success")
                    and data.get("results")
                ):
                    dm.show_query_results(data["results"])  # type: ignore[arg-type]
                    return
            except Exception:
                # Unparseable payload or a rendering failure: best-effort
                # display falls back to the raw result below.
                pass
            _print_tool_result(name, content_str)
        else:
            _print_tool_result(name, content_str)

    def _render_response(message: ModelMessage) -> None:
        """Render assistant text, tool calls and tool results in ``message``."""
        for part in getattr(message, "parts", []):
            kind = getattr(part, "part_kind", "")
            if kind == "text":
                text = getattr(part, "content", "")
                if isinstance(text, str) and text.strip():
                    console.print(
                        Panel.fit(
                            Markdown(text), title="Assistant", border_style="green"
                        )
                    )
            elif kind in ("tool-call", "builtin-tool-call"):
                dm.show_tool_executing(
                    getattr(part, "tool_name", "tool"),
                    _tool_args(getattr(part, "args", None)),
                )
            elif kind in ("tool-return", "builtin-tool-return"):
                _render_tool_return(
                    getattr(part, "tool_name", "tool"),
                    getattr(part, "content", None),
                )
            # Other part kinds (system prompts, thinking, retries) are omitted.

    user_indices = [
        idx for idx, message in enumerate(all_msgs) if _is_user_prompt_message(message)
    ]

    if not user_indices:
        # Nothing to anchor turns on: show everything as assistant/tool output
        # instead of mislabelling the first message as a user turn.
        for message in all_msgs:
            _render_response(message)
        console.print("")
        return

    # (start, end) index pairs; each turn ends where the next one begins.
    turns = list(zip(user_indices, user_indices[1:] + [len(all_msgs)]))
    if last_n is not None and last_n > 0:
        turns = turns[-last_n:]
    else:
        # Full transcript: don't silently drop messages before the first prompt.
        for message in all_msgs[: user_indices[0]]:
            _render_response(message)

    for start_idx, end_idx in turns:
        _render_user(all_msgs[start_idx])
        for message in all_msgs[start_idx + 1 : end_idx]:
            _render_response(message)
        console.print("")
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
@threads_app.command(name="list")
def list_threads(
    database: Annotated[
        str | None,
        cyclopts.Parameter(["--database", "-d"], help="Filter by database name"),
    ] = None,
    limit: Annotated[
        int,
        cyclopts.Parameter(["--limit", "-n"], help="Max threads to return"),
    ] = 50,
):
    """List threads (optionally filtered by database)."""
    # Stored values (titles especially) can contain "[...]" sequences that Rich
    # would parse as markup tags: text disappears or a MarkupError is raised.
    from rich.markup import escape

    store = ThreadStorage()
    threads = asyncio.run(store.list_threads(database_name=database, limit=limit))
    if not threads:
        console.print("No threads found.")
        return

    table = Table(title="Threads")
    table.add_column("ID", style="cyan")
    table.add_column("Database", style="magenta")
    table.add_column("Title", style="green")
    table.add_column("Last Activity", style="dim")
    table.add_column("Model", style="yellow")
    for t in threads:
        table.add_row(
            escape(t.id),
            escape(t.database_name or "-"),
            # Truncate before escaping so an escape sequence is never cut.
            escape((t.title or "-")[:60]),
            _human_readable(t.last_activity_at),
            escape(t.model_name or "-"),
        )
    console.print(table)
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
@threads_app.command
def show(
    thread_id: Annotated[str, cyclopts.Parameter(help="Thread ID")],
):
    """Show thread metadata and render the full transcript."""
    # Stored metadata and the user-supplied id may contain "[...]" sequences
    # that Rich would treat as markup; escape them so they print verbatim.
    from rich.markup import escape

    store = ThreadStorage()

    async def _load():
        # Both storage calls share one event loop instead of one asyncio.run each.
        thread = await store.get_thread(thread_id)
        if not thread:
            return None, []
        return thread, await store.get_thread_messages(thread_id)

    thread, msgs = asyncio.run(_load())
    if not thread:
        console.print(f"[red]Thread not found:[/red] {escape(thread_id)}")
        return

    console.print(f"[bold]Thread: {escape(thread.id)}[/bold]")
    console.print("")
    # "-" for missing values, matching the `threads list` table.
    console.print(f"Database: {escape(thread.database_name or '-')}")
    console.print(f"Title: {escape(thread.title or '-')}")
    console.print(f"Last activity: {_human_readable(thread.last_activity_at)}")
    console.print(f"Model: {escape(thread.model_name or '-')}")
    console.print("")

    _render_transcript(console, msgs, None)
|
|
226
|
+
|
|
227
|
+
|
|
228
|
+
@threads_app.command
def resume(
    thread_id: Annotated[str, cyclopts.Parameter(help="Thread ID to resume")],
    database: Annotated[
        str | None,
        cyclopts.Parameter(["--database", "-d"], help="Database name or DSN override"),
    ] = None,
):
    """Render transcript, then resume thread in interactive mode."""
    store = ThreadStorage()

    async def _run() -> None:
        thread = await store.get_thread(thread_id)
        if not thread:
            console.print(f"[red]Thread not found:[/red] {thread_id}")
            return

        # An explicit --database wins over the database recorded on the thread.
        selector = database or thread.database_name
        if not selector:
            console.print(
                "[red]No database specified or stored with this thread.[/red]"
            )
            return

        try:
            resolved = resolve_database(selector, config_manager)
            conn_str, db_name = resolved.connection_string, resolved.name
        except DatabaseResolutionError as exc:
            console.print(f"[red]Database resolution error:[/red] {exc}")
            return

        db_conn = DatabaseConnection(conn_str)
        try:
            agent = build_sqlsaber_agent(db_conn, db_name)
            history = await store.get_thread_messages(thread_id)
            # Replay the stored conversation before handing control to the user.
            console.print(Panel.fit(f"Thread: {thread.id}", border_style="blue"))
            _render_transcript(console, history, None)
            await InteractiveSession(
                console=console,
                agent=agent,
                db_conn=db_conn,
                database_name=db_name,
                initial_thread_id=thread_id,
                initial_history=history,
            ).run()
        finally:
            await db_conn.close()
            console.print("\n[green]Goodbye![/green]")

    asyncio.run(_run())
|
|
278
|
+
|
|
279
|
+
|
|
280
|
+
@threads_app.command
def prune(
    days: Annotated[
        int,
        cyclopts.Parameter(
            ["--days", "-n"], help="Delete threads older than this many days"
        ),
    ] = 30,
):
    """Prune old threads by last activity timestamp."""
    # A negative window puts the cutoff in the future, which would delete every
    # thread; reject it before touching storage.
    if days < 0:
        console.print(f"[red]--days must be zero or positive (got {days}).[/red]")
        return

    store = ThreadStorage()

    async def _run() -> None:
        deleted = await store.prune_threads(older_than_days=days)
        console.print(f"[green]✓ Pruned {deleted} thread(s).[/green]")

    asyncio.run(_run())
|
|
297
|
+
|
|
298
|
+
|
|
299
|
+
def create_threads_app() -> cyclopts.App:
    """Hand out the module-level ``threads`` sub-app so the root CLI can register it."""
    app: cyclopts.App = threads_app
    return app
|
sqlsaber/database/schema.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
"""Database schema introspection utilities."""
|
|
2
2
|
|
|
3
3
|
from abc import ABC, abstractmethod
|
|
4
|
-
from typing import Any
|
|
4
|
+
from typing import Any, TypedDict
|
|
5
5
|
|
|
6
6
|
import aiosqlite
|
|
7
7
|
|
|
@@ -12,7 +12,35 @@ from sqlsaber.database.connection import (
|
|
|
12
12
|
PostgreSQLConnection,
|
|
13
13
|
SQLiteConnection,
|
|
14
14
|
)
|
|
15
|
-
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class ColumnInfo(TypedDict):
|
|
18
|
+
"""Type definition for column information."""
|
|
19
|
+
|
|
20
|
+
data_type: str
|
|
21
|
+
nullable: bool
|
|
22
|
+
default: str | None
|
|
23
|
+
max_length: int | None
|
|
24
|
+
precision: int | None
|
|
25
|
+
scale: int | None
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class ForeignKeyInfo(TypedDict, total=True):
    """Typed mapping describing a foreign key on a single column."""

    column: str
    # Referenced target, shaped like {"table": "schema.table", "column": "column_name"}.
    references: dict[str, str]
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class SchemaInfo(TypedDict, total=True):
    """Typed mapping describing one introspected schema object (e.g. a table).

    ``columns`` maps a column key (presumably the column name) to its details.
    """

    schema: str
    name: str
    type: str
    columns: dict[str, ColumnInfo]
    primary_keys: list[str]
    foreign_keys: list[ForeignKeyInfo]
|
|
16
44
|
|
|
17
45
|
|
|
18
46
|
class BaseSchemaIntrospector(ABC):
|
|
@@ -0,0 +1,303 @@
|
|
|
1
|
+
"""SQLite storage for pydantic-ai thread snapshots.
|
|
2
|
+
|
|
3
|
+
Each thread represents a session (interactive or non-interactive) and stores the
|
|
4
|
+
complete pydantic-ai message history as a snapshot. On every completed run in the
|
|
5
|
+
same session, we overwrite the snapshot with the new full history.
|
|
6
|
+
|
|
7
|
+
This design intentionally avoids per-run append logs and mirrors pydantic-ai's
|
|
8
|
+
recommended approach of serializing ModelMessage[] with ModelMessagesTypeAdapter.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
import asyncio
|
|
12
|
+
import logging
|
|
13
|
+
import time
|
|
14
|
+
import uuid
|
|
15
|
+
from pathlib import Path
|
|
16
|
+
from typing import Any
|
|
17
|
+
|
|
18
|
+
import aiosqlite
|
|
19
|
+
import platformdirs
|
|
20
|
+
from pydantic_ai.messages import ModelMessage, ModelMessagesTypeAdapter
|
|
21
|
+
|
|
22
|
+
logger = logging.getLogger(__name__)


# DDL executed by ThreadStorage._init_db via executescript(). Every statement
# uses IF NOT EXISTS, so re-running it against an existing database is harmless.
#
# threads columns:
#   id               - thread identifier (ThreadStorage generates UUID4 strings)
#   database_name    - database the thread ran against (nullable)
#   title            - optional display title
#   created_at       - creation time in Unix epoch seconds (time.time())
#   ended_at         - set by ThreadStorage.end_thread; NULL until then
#   last_activity_at - refreshed on every snapshot; used for ordering and pruning
#   model_name       - model used for the thread (nullable)
#   messages_json    - full serialized pydantic-ai message-history snapshot
#   extra_metadata   - optional free-form metadata string
SCHEMA_SQL = """
CREATE TABLE IF NOT EXISTS threads (
id TEXT PRIMARY KEY,
database_name TEXT,
title TEXT,
created_at REAL NOT NULL,
ended_at REAL,
last_activity_at REAL NOT NULL,
model_name TEXT,
messages_json BLOB NOT NULL,
extra_metadata TEXT
);

CREATE INDEX IF NOT EXISTS idx_threads_dbname ON threads(database_name);
CREATE INDEX IF NOT EXISTS idx_threads_activity ON threads(last_activity_at);
"""
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class Thread:
|
|
44
|
+
"""Thread metadata."""
|
|
45
|
+
|
|
46
|
+
def __init__(
|
|
47
|
+
self,
|
|
48
|
+
id: str,
|
|
49
|
+
database_name: str | None,
|
|
50
|
+
title: str | None,
|
|
51
|
+
created_at: float,
|
|
52
|
+
ended_at: float | None,
|
|
53
|
+
last_activity_at: float,
|
|
54
|
+
model_name: str | None,
|
|
55
|
+
) -> None:
|
|
56
|
+
self.id = id
|
|
57
|
+
self.database_name = database_name
|
|
58
|
+
self.title = title
|
|
59
|
+
self.created_at = created_at
|
|
60
|
+
self.ended_at = ended_at
|
|
61
|
+
self.last_activity_at = last_activity_at
|
|
62
|
+
self.model_name = model_name
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
class ThreadStorage:
    """Handles SQLite storage for pydantic-ai thread snapshots.

    Data lives in ``threads.db`` inside the platformdirs user config directory
    for ``sqlsaber``. Every public method is best-effort: errors are logged as
    warnings and reported through the return value instead of being raised.
    """

    # Columns selected whenever a Thread is materialized; order must match
    # ``_row_to_thread``.
    _THREAD_COLUMNS = (
        "id, database_name, title, created_at, ended_at, last_activity_at, model_name"
    )

    def __init__(self) -> None:
        self.db_path = Path(platformdirs.user_config_dir("sqlsaber")) / "threads.db"
        # Serializes the write operations issued through this instance.
        self._lock = asyncio.Lock()
        self._initialized = False

    async def _init_db(self) -> None:
        """Create the database directory and schema once (best-effort).

        On failure a warning is logged and ``_initialized`` stays False, so the
        next call retries.
        """
        if self._initialized:
            return
        try:
            self.db_path.parent.mkdir(parents=True, exist_ok=True)
            async with aiosqlite.connect(self.db_path) as db:
                await db.executescript(SCHEMA_SQL)
                await db.commit()
            self._initialized = True
            logger.debug("Initialized threads database at %s", self.db_path)
        except Exception as e:  # pragma: no cover - best-effort persistence
            logger.warning("Failed to initialize threads DB: %s", e)

    @staticmethod
    def _row_to_thread(row: Any) -> Thread:
        """Build a Thread from a row selected with ``_THREAD_COLUMNS``."""
        return Thread(
            id=row[0],
            database_name=row[1],
            title=row[2],
            created_at=row[3],
            ended_at=row[4],
            last_activity_at=row[5],
            model_name=row[6],
        )

    async def save_snapshot(
        self,
        *,
        messages_json: bytes,
        database_name: str | None,
        thread_id: str | None = None,
        extra_metadata: str | None = None,
    ) -> str:
        """Create or update a thread snapshot.

        Args:
            messages_json: Serialized message history; replaces any stored snapshot.
            database_name: Database the thread belongs to. Only written when the
                row is created; updates leave it unchanged.
            thread_id: Thread to update. ``None`` creates a new thread with a
                fresh UUID4 id. An id with no stored row (e.g. because an
                earlier create failed) is created instead of silently ignored.
            extra_metadata: Optional metadata string; on update, ``None`` keeps
                the stored value.

        Returns:
            The thread id, also when persistence failed (failures are logged).
        """
        await self._init_db()
        now = time.time()
        creating_new = thread_id is None
        if thread_id is None:
            thread_id = str(uuid.uuid4())

        created = creating_new
        try:
            async with self._lock, aiosqlite.connect(self.db_path) as db:
                if not creating_new:
                    cur = await db.execute(
                        """
                        UPDATE threads
                        SET last_activity_at = ?,
                            messages_json = ?,
                            extra_metadata = COALESCE(?, extra_metadata)
                        WHERE id = ?
                        """,
                        (now, messages_json, extra_metadata, thread_id),
                    )
                    # No such row: fall through to INSERT so the snapshot
                    # is not silently dropped.
                    created = cur.rowcount == 0
                if created:
                    await db.execute(
                        """
                        INSERT INTO threads (
                            id, database_name, created_at,
                            last_activity_at, messages_json, extra_metadata
                        ) VALUES (?, ?, ?, ?, ?, ?)
                        """,
                        (
                            thread_id,
                            database_name,
                            now,
                            now,
                            messages_json,
                            extra_metadata,
                        ),
                    )
                await db.commit()
        except Exception as e:  # pragma: no cover - best-effort persistence
            if creating_new:
                logger.warning("Failed to create thread: %s", e)
            else:
                logger.warning("Failed to update thread %s: %s", thread_id, e)
            return thread_id

        if created:
            logger.debug("Created thread %s", thread_id)
        else:
            logger.debug("Updated thread %s snapshot", thread_id)
        return thread_id

    async def save_metadata(
        self,
        *,
        thread_id: str,
        title: str | None = None,
        model_name: str | None = None,
    ) -> bool:
        """Update thread title and/or model name.

        Only provided fields are updated: a ``None`` argument leaves the stored
        value untouched.

        Returns:
            True if the thread row was updated; False if no such thread exists
            or the update failed.
        """
        await self._init_db()

        try:
            async with self._lock, aiosqlite.connect(self.db_path) as db:
                cur = await db.execute(
                    """
                    UPDATE threads
                    SET title = COALESCE(?, title),
                        model_name = COALESCE(?, model_name)
                    WHERE id = ?
                    """,
                    (title, model_name, thread_id),
                )
                await db.commit()
                return cur.rowcount > 0
        except Exception as e:  # pragma: no cover
            logger.warning("Failed to update metadata for thread %s: %s", thread_id, e)
            return False

    async def end_thread(self, thread_id: str) -> bool:
        """Mark a thread as ended, stamping ``ended_at`` and ``last_activity_at``.

        Both columns receive the same timestamp.

        Returns:
            True if the thread row was updated; False otherwise.
        """
        await self._init_db()
        now = time.time()
        try:
            async with self._lock, aiosqlite.connect(self.db_path) as db:
                cur = await db.execute(
                    "UPDATE threads SET ended_at = ?, last_activity_at = ? WHERE id = ?",
                    (now, now, thread_id),
                )
                await db.commit()
                return cur.rowcount > 0
        except Exception as e:  # pragma: no cover
            logger.warning("Failed to end thread %s: %s", thread_id, e)
            return False

    async def get_thread(self, thread_id: str) -> Thread | None:
        """Return metadata for ``thread_id``, or None if missing or on error."""
        await self._init_db()
        try:
            async with aiosqlite.connect(self.db_path) as db:
                async with db.execute(
                    f"SELECT {self._THREAD_COLUMNS} FROM threads WHERE id = ?",
                    (thread_id,),
                ) as cur:
                    row = await cur.fetchone()
            return self._row_to_thread(row) if row else None
        except Exception as e:  # pragma: no cover
            logger.warning("Failed to get thread %s: %s", thread_id, e)
            return None

    async def get_thread_messages(self, thread_id: str) -> list[ModelMessage]:
        """Load the full message history for a thread as ModelMessage[].

        Returns an empty list when the thread is unknown or the stored snapshot
        cannot be read or validated.
        """
        await self._init_db()
        try:
            async with aiosqlite.connect(self.db_path) as db:
                async with db.execute(
                    "SELECT messages_json FROM threads WHERE id = ?",
                    (thread_id,),
                ) as cur:
                    row = await cur.fetchone()
            if not row:
                return []
            messages_blob: bytes = row[0]
            return ModelMessagesTypeAdapter.validate_json(messages_blob)
        except Exception as e:  # pragma: no cover
            logger.warning("Failed to load thread %s messages: %s", thread_id, e)
            return []

    async def list_threads(
        self, *, database_name: str | None = None, limit: int = 50
    ) -> list[Thread]:
        """Return up to ``limit`` threads, most recently active first.

        A falsy ``database_name`` disables the database filter.
        """
        await self._init_db()
        # Only the constant column list is interpolated; values are bound.
        query = f"SELECT {self._THREAD_COLUMNS} FROM threads"
        params: list[Any] = []
        if database_name:
            query += " WHERE database_name = ?"
            params.append(database_name)
        query += " ORDER BY last_activity_at DESC LIMIT ?"
        params.append(limit)

        try:
            async with aiosqlite.connect(self.db_path) as db:
                async with db.execute(query, params) as cur:
                    return [self._row_to_thread(row) async for row in cur]
        except Exception as e:  # pragma: no cover
            logger.warning("Failed to list threads: %s", e)
            return []

    async def delete_thread(self, thread_id: str) -> bool:
        """Delete a thread; return True if a row was removed."""
        await self._init_db()
        try:
            async with self._lock, aiosqlite.connect(self.db_path) as db:
                cur = await db.execute("DELETE FROM threads WHERE id = ?", (thread_id,))
                await db.commit()
                return cur.rowcount > 0
        except Exception as e:  # pragma: no cover
            logger.warning("Failed to delete thread %s: %s", thread_id, e)
            return False

    async def prune_threads(self, older_than_days: int = 30) -> int:
        """Delete threads whose last_activity_at is older than the cutoff.

        Args:
            older_than_days: Threads with last_activity_at older than this many
                days ago will be deleted. Must be >= 0; a negative value would
                place the cutoff in the future and wipe every thread, so it is
                refused (logged, nothing deleted).

        Returns:
            The number of rows deleted (best-effort; 0 on failure or refusal).
        """
        if older_than_days < 0:
            logger.warning(
                "Refusing to prune threads: older_than_days must be >= 0 (got %s)",
                older_than_days,
            )
            return 0
        await self._init_db()
        cutoff = time.time() - older_than_days * 24 * 3600
        try:
            async with self._lock, aiosqlite.connect(self.db_path) as db:
                cur = await db.execute(
                    "DELETE FROM threads WHERE last_activity_at < ?",
                    (cutoff,),
                )
                await db.commit()
                return cur.rowcount or 0
        except Exception as e:  # pragma: no cover
            logger.warning("Failed to prune threads: %s", e)
            return 0
|
sqlsaber/tools/__init__.py
CHANGED
|
@@ -7,7 +7,6 @@ from .registry import ToolRegistry, register_tool, tool_registry
|
|
|
7
7
|
|
|
8
8
|
# Import concrete tools to register them
|
|
9
9
|
from .sql_tools import ExecuteSQLTool, IntrospectSchemaTool, ListTablesTool, SQLTool
|
|
10
|
-
from .visualization_tools import PlotDataTool
|
|
11
10
|
|
|
12
11
|
__all__ = [
|
|
13
12
|
"Tool",
|
|
@@ -21,5 +20,4 @@ __all__ = [
|
|
|
21
20
|
"ListTablesTool",
|
|
22
21
|
"IntrospectSchemaTool",
|
|
23
22
|
"ExecuteSQLTool",
|
|
24
|
-
"PlotDataTool",
|
|
25
23
|
]
|
sqlsaber/tools/base.py
CHANGED
|
@@ -1,7 +1,6 @@
|
|
|
1
1
|
"""Base class for SQLSaber tools."""
|
|
2
2
|
|
|
3
3
|
from abc import ABC, abstractmethod
|
|
4
|
-
from types import SimpleNamespace
|
|
5
4
|
from typing import Any
|
|
6
5
|
|
|
7
6
|
from .enums import ToolCategory, WorkflowPosition
|
|
@@ -44,17 +43,6 @@ class Tool(ABC):
|
|
|
44
43
|
"""
|
|
45
44
|
pass
|
|
46
45
|
|
|
47
|
-
def to_definition(self):
|
|
48
|
-
"""Convert this tool to a ToolDefinition-like object with attributes.
|
|
49
|
-
|
|
50
|
-
Tests expect attribute access (definition.name), so return a SimpleNamespace.
|
|
51
|
-
"""
|
|
52
|
-
return SimpleNamespace(
|
|
53
|
-
name=self.name,
|
|
54
|
-
description=self.description,
|
|
55
|
-
input_schema=self.input_schema,
|
|
56
|
-
)
|
|
57
|
-
|
|
58
46
|
@property
|
|
59
47
|
def category(self) -> ToolCategory:
|
|
60
48
|
"""Return the tool category. Override to customize."""
|