gora-cli 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gora/__init__.py +5 -0
- gora/__main__.py +13 -0
- gora/cli.py +459 -0
- gora/go_tui/go.mod +42 -0
- gora/go_tui/go.sum +106 -0
- gora/go_tui/main.go +2634 -0
- gora/go_tui/main_test.go +626 -0
- gora/parsers.py +626 -0
- gora/store.py +935 -0
- gora/tui.py +115 -0
- gora_cli-0.1.2.dist-info/METADATA +282 -0
- gora_cli-0.1.2.dist-info/RECORD +14 -0
- gora_cli-0.1.2.dist-info/WHEEL +4 -0
- gora_cli-0.1.2.dist-info/entry_points.txt +2 -0
gora/store.py
ADDED
|
@@ -0,0 +1,935 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import hashlib
|
|
4
|
+
import os
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
import re
|
|
7
|
+
import sqlite3
|
|
8
|
+
import sys
|
|
9
|
+
from typing import Iterable
|
|
10
|
+
|
|
11
|
+
from .parsers import ChatMessage, ChatSession, is_image_reference_text, is_title_noise_text
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
# Version stamp written to sessions.parser_version on insert/update. Stored rows
# whose stamp differs from this value are not treated as current by
# upsert_session() / source_is_current().
PARSER_VERSION = 2
# Default for the ``roles`` argument of search_messages() and
# get_session_messages(); None results in no role filter being applied.
DEFAULT_ROLES: tuple[str, ...] | None = None
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def default_db_path() -> Path:
    """Return the path of the history database.

    Resolution order:

    1. ``GORA_DB`` (if non-empty), used as the full file path after ``~`` expansion.
    2. ``$XDG_DATA_HOME/gora/history.sqlite`` when ``XDG_DATA_HOME`` is non-empty.
    3. ``~/Library/Application Support/gora/history.sqlite`` on macOS.
    4. ``~/.local/share/gora/history.sqlite`` everywhere else.
    """
    explicit = os.environ.get("GORA_DB")
    if explicit:
        return Path(explicit).expanduser()

    relative = Path("gora") / "history.sqlite"

    data_home = os.environ.get("XDG_DATA_HOME")
    if data_home:
        return Path(data_home).expanduser() / relative

    if sys.platform == "darwin":
        base = Path.home() / "Library" / "Application Support"
    else:
        base = Path.home() / ".local" / "share"
    return base / relative
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def connect(db_path: Path | None = None) -> sqlite3.Connection:
    """Open (creating if needed) the history database and return a ready connection.

    Args:
        db_path: Database file to open. When omitted, ``default_db_path()`` is used.

    Returns:
        A connection whose ``row_factory`` is ``sqlite3.Row`` and whose schema has
        been initialised by ``init_db``.

    Raises:
        Whatever ``_prepare_db_path``, ``sqlite3.connect`` or ``init_db`` raise.
        If a failure happens after the connection was opened, the connection is
        closed before the exception propagates.
    """
    path = db_path or default_db_path()
    # Only tighten the parent directory's permissions for the built-in default
    # location; a caller- or GORA_DB-supplied directory is left as it is.
    _prepare_db_path(path, secure_parent=db_path is None and not os.environ.get("GORA_DB"))
    connection = sqlite3.connect(path)
    try:
        connection.row_factory = sqlite3.Row
        init_db(connection)
        # SQLite may have created journal/WAL side files during init; lock them down too.
        _secure_sqlite_files(path)
    except BaseException:
        # Do not leak the open handle when schema setup fails.
        connection.close()
        raise
    return connection
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def _prepare_db_path(path: Path, *, secure_parent: bool) -> None:
    """Make sure the database file exists with owner-only permissions.

    Args:
        path: Database file location.
        secure_parent: When true, the parent directory is created with mode
            ``0o700`` and chmod-ed to ``0o700``; otherwise it is only created.

    The file is created atomically with ``O_CREAT | O_EXCL`` and mode ``0o600``.
    If it already exists (including when another process creates it
    concurrently), its mode is set to ``0o600`` on a best-effort basis instead.
    """
    parent = path.parent
    if secure_parent:
        parent.mkdir(parents=True, mode=0o700, exist_ok=True)
        _chmod(parent, 0o700)
    else:
        parent.mkdir(parents=True, exist_ok=True)

    # EAFP: attempt the exclusive create directly rather than checking
    # ``path.exists()`` first, which left a window in which a concurrent creator
    # made the subsequent O_EXCL open fail with FileExistsError.
    try:
        descriptor = os.open(path, os.O_CREAT | os.O_EXCL | os.O_WRONLY, 0o600)
    except FileExistsError:
        _chmod(path, 0o600)
        return
    os.close(descriptor)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def _secure_sqlite_files(path: Path) -> None:
    """Best-effort chmod ``0o600`` on the database and its existing side files.

    The side files considered are ``<name>-journal``, ``<name>-wal`` and
    ``<name>-shm`` next to ``path``; files that do not exist are skipped.
    """
    side_files = [
        path.with_name(path.name + suffix) for suffix in ("-journal", "-wal", "-shm")
    ]
    for target in [path, *side_files]:
        if target.exists():
            _chmod(target, 0o600)
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def _chmod(path: Path, mode: int) -> None:
    """Set ``mode`` on ``path``, deliberately ignoring any ``OSError``.

    Permission tightening is best-effort: a failure here must not prevent the
    caller from continuing.
    """
    try:
        os.chmod(path, mode)
    except OSError:
        return
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def init_db(connection: sqlite3.Connection) -> None:
    """Create the schema if missing, apply migrations, and commit.

    Steps, in order: enable foreign keys, create the ``sessions``, ``messages``
    and ``session_models`` tables plus their indexes, create the ``message_fts``
    FTS5 virtual table, run ``_migrate_db`` and commit. Every statement uses
    ``IF NOT EXISTS`` so the function can be called on an existing database.
    """
    tables_and_indexes = """
        CREATE TABLE IF NOT EXISTS sessions (
            session_key TEXT PRIMARY KEY,
            provider TEXT NOT NULL,
            session_id TEXT NOT NULL,
            source_path TEXT NOT NULL UNIQUE,
            parent_session_key TEXT,
            parent_session_id TEXT,
            thread_source TEXT,
            source_label TEXT,
            parser_version INTEGER NOT NULL DEFAULT 2,
            cwd TEXT,
            title TEXT,
            started_at TEXT,
            updated_at TEXT,
            message_count INTEGER NOT NULL DEFAULT 0,
            source_mtime REAL NOT NULL,
            source_size INTEGER NOT NULL,
            imported_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP
        );

        CREATE TABLE IF NOT EXISTS messages (
            message_key TEXT PRIMARY KEY,
            session_key TEXT NOT NULL REFERENCES sessions(session_key) ON DELETE CASCADE,
            provider TEXT NOT NULL,
            role TEXT NOT NULL,
            timestamp TEXT,
            ordinal INTEGER NOT NULL,
            raw_type TEXT,
            model TEXT,
            model_provider TEXT,
            text TEXT NOT NULL
        );

        CREATE TABLE IF NOT EXISTS session_models (
            session_key TEXT NOT NULL REFERENCES sessions(session_key) ON DELETE CASCADE,
            model TEXT NOT NULL,
            model_provider TEXT NOT NULL DEFAULT '',
            message_count INTEGER NOT NULL DEFAULT 0,
            PRIMARY KEY (session_key, model, model_provider)
        );

        CREATE INDEX IF NOT EXISTS idx_sessions_provider_updated
            ON sessions(provider, updated_at DESC);
        CREATE INDEX IF NOT EXISTS idx_sessions_cwd
            ON sessions(cwd);
        CREATE INDEX IF NOT EXISTS idx_messages_session_ordinal
            ON messages(session_key, ordinal);
        CREATE INDEX IF NOT EXISTS idx_messages_role
            ON messages(role);
        CREATE INDEX IF NOT EXISTS idx_session_models_model
            ON session_models(model, model_provider);
        """
    full_text_index = """
        CREATE VIRTUAL TABLE IF NOT EXISTS message_fts
        USING fts5(
            text,
            message_key UNINDEXED,
            session_key UNINDEXED,
            provider UNINDEXED,
            role UNINDEXED,
            cwd UNINDEXED,
            tokenize='unicode61'
        )
        """

    connection.execute("PRAGMA foreign_keys = ON")
    connection.executescript(tables_and_indexes)
    connection.execute(full_text_index)
    # Columns and indexes introduced after the first schema are added here.
    _migrate_db(connection)
    connection.commit()
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
def _migrate_db(connection: sqlite3.Connection) -> None:
    """Bring an older database up to the current schema, idempotently.

    Adds any missing ``sessions`` columns (``parent_session_key``,
    ``parent_session_id``, ``thread_source``, ``source_label``,
    ``parser_version``) and ``messages`` columns (``model``, ``model_provider``),
    then creates the ``session_models`` table and the indexes that depend on
    those columns if they do not exist yet.

    Column names are read positionally from ``PRAGMA table_info`` so this works
    whether or not the connection uses ``sqlite3.Row`` as its row factory.
    """

    def column_names(table: str) -> set[str]:
        # table_info rows are (cid, name, type, notnull, dflt_value, pk).
        return {row[1] for row in connection.execute(f"PRAGMA table_info({table})")}

    session_columns = column_names("sessions")
    for column in ("parent_session_key", "parent_session_id", "thread_source", "source_label"):
        if column not in session_columns:
            connection.execute(f"ALTER TABLE sessions ADD COLUMN {column} TEXT")
    if "parser_version" not in session_columns:
        # Pre-existing rows get version 0, which never equals PARSER_VERSION as
        # written by the insert/update paths in this module.
        connection.execute("ALTER TABLE sessions ADD COLUMN parser_version INTEGER NOT NULL DEFAULT 0")

    message_columns = column_names("messages")
    for column in ("model", "model_provider"):
        if column not in message_columns:
            connection.execute(f"ALTER TABLE messages ADD COLUMN {column} TEXT")

    connection.execute(
        """
        CREATE TABLE IF NOT EXISTS session_models (
            session_key TEXT NOT NULL REFERENCES sessions(session_key) ON DELETE CASCADE,
            model TEXT NOT NULL,
            model_provider TEXT NOT NULL DEFAULT '',
            message_count INTEGER NOT NULL DEFAULT 0,
            PRIMARY KEY (session_key, model, model_provider)
        )
        """
    )
    # These indexes reference columns that may only just have been added above.
    connection.execute(
        "CREATE INDEX IF NOT EXISTS idx_messages_model ON messages(model)"
    )
    connection.execute(
        "CREATE INDEX IF NOT EXISTS idx_sessions_parent ON sessions(parent_session_key)"
    )
    connection.execute(
        "CREATE INDEX IF NOT EXISTS idx_session_models_model ON session_models(model, model_provider)"
    )
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
def upsert_session(
    connection: sqlite3.Connection,
    session: ChatSession,
    *,
    force: bool = False,
) -> bool:
    """Insert a parsed session, or merge it into the row with the same source path.

    Args:
        connection: Open database connection.
        session: Parsed session to store.
        force: When true, never take the "already up to date" shortcut.

    Returns:
        ``True`` if a new session was inserted or the merge reported at least
        one change; ``False`` if the stored row was considered current or the
        merge changed nothing.
    """
    stored = connection.execute(
        """
        SELECT *
        FROM sessions
        WHERE source_path = ?
        """,
        (str(session.source_path),),
    ).fetchone()

    if stored and not force:
        # Same file stamp, same parser, same metadata and no better title:
        # nothing to do.
        up_to_date = (
            stored["source_mtime"] == session.source_mtime
            and stored["source_size"] == session.source_size
            and stored["parser_version"] == PARSER_VERSION
            and _session_metadata_matches(stored, session)
            and _merged_title(stored["title"], session.title) == stored["title"]
        )
        if up_to_date:
            return False

    # One transaction per session: committed on success, rolled back on error.
    with connection:
        if stored:
            return _merge_session(connection, stored, session) > 0

        new_key = choose_session_key(connection, session)
        _insert_session(connection, new_key, session)
        _insert_messages(connection, new_key, session.messages, cwd=session.cwd, start_ordinal=0)
        _refresh_session_models(connection, new_key)
        return True
|
|
224
|
+
|
|
225
|
+
|
|
226
|
+
def source_is_current(connection: sqlite3.Connection, source_path: Path) -> bool:
    """Tell whether the stored session for ``source_path`` matches the file on disk.

    Returns ``False`` when the file cannot be stat-ed or no session row has that
    source path. Otherwise returns whether the stored mtime, size and parser
    version equal the file's current mtime, size and ``PARSER_VERSION``.
    """
    try:
        file_info = source_path.stat()
    except OSError:
        return False

    row = connection.execute(
        """
        SELECT source_mtime, source_size, parser_version
        FROM sessions
        WHERE source_path = ?
        """,
        (str(source_path),),
    ).fetchone()
    if not row:
        return False

    return (
        row["source_mtime"] == file_info.st_mtime
        and row["source_size"] == file_info.st_size
        and row["parser_version"] == PARSER_VERSION
    )
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
def _insert_session(connection: sqlite3.Connection, session_key: str, session: ChatSession) -> None:
    """Insert one ``sessions`` row for ``session`` under ``session_key``.

    ``parser_version`` is stamped with ``PARSER_VERSION`` and ``message_count``
    with ``len(session.messages)``; ``imported_at`` is left to its column default.
    """
    values = (
        session_key,
        session.provider,
        session.session_id,
        str(session.source_path),
        parent_session_key(session),
        session.parent_session_id,
        session.thread_source,
        session.source_label,
        PARSER_VERSION,
        session.cwd,
        session.title,
        session.started_at,
        session.updated_at,
        len(session.messages),
        session.source_mtime,
        session.source_size,
    )
    connection.execute(
        """
        INSERT INTO sessions (
            session_key, provider, session_id, source_path,
            parent_session_key, parent_session_id, thread_source, source_label,
            parser_version, cwd, title,
            started_at, updated_at, message_count, source_mtime, source_size
        )
        VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
        """,
        values,
    )
|
|
277
|
+
|
|
278
|
+
|
|
279
|
+
def _merge_session(
    connection: sqlite3.Connection,
    existing: sqlite3.Row,
    session: ChatSession,
) -> int:
    """Merge a re-parsed session into its stored row and messages.

    Parsed messages are matched to stored ones by fingerprint (role, timestamp,
    text), one stored row per parsed message in ordinal order. Matched rows only
    get their model metadata refreshed; unmatched parsed messages are appended
    after the highest stored ordinal. Stored messages are never deleted here.

    Returns:
        Number of new messages + number of messages whose metadata was updated
        + the session-metadata change flag (0 or 1).
    """
    session_key = existing["session_key"]
    stored_rows = connection.execute(
        """
        SELECT message_key, role, timestamp, text, ordinal, model, model_provider
        FROM messages
        WHERE session_key = ?
        ORDER BY ordinal ASC
        """,
        (session_key,),
    ).fetchall()

    # Group stored rows by fingerprint, then hand them out one at a time so that
    # repeated identical messages are paired off in ordinal order.
    grouped: dict[str, list[sqlite3.Row]] = {}
    for stored in stored_rows:
        key = _message_fingerprint(stored["role"], stored["timestamp"], stored["text"])
        grouped.setdefault(key, []).append(stored)
    unclaimed = {key: iter(rows) for key, rows in grouped.items()}

    additions: list[ChatMessage] = []
    refreshed = 0
    for message in session.messages:
        key = _message_fingerprint(message.role, message.timestamp, message.text)
        queue = unclaimed.get(key)
        counterpart = next(queue, None) if queue is not None else None
        if counterpart is None:
            additions.append(message)
        else:
            refreshed += _update_message_metadata(connection, counterpart, message)

    next_ordinal = max((stored["ordinal"] for stored in stored_rows), default=-1) + 1
    _insert_messages(connection, session_key, additions, cwd=session.cwd, start_ordinal=next_ordinal)
    # Must run after the inserts: it recounts the session's messages.
    session_changed = _update_session_metadata(connection, existing, session)
    if additions or refreshed:
        _refresh_session_models(connection, session_key)
    return len(additions) + refreshed + session_changed
|
|
317
|
+
|
|
318
|
+
|
|
319
|
+
def _insert_messages(
    connection: sqlite3.Connection,
    session_key: str,
    messages: Iterable[ChatMessage],
    *,
    cwd: str | None,
    start_ordinal: int,
) -> None:
    """Insert ``messages`` into ``messages`` and mirror them into ``message_fts``.

    Ordinals are assigned consecutively from ``start_ordinal``; each message key
    is ``"<session_key>:<ordinal>"`` and the provider is the part of
    ``session_key`` before the first ``":"``. Does nothing for an empty iterable.
    """
    provider = session_key.split(":", 1)[0]
    numbered = [
        (f"{session_key}:{start_ordinal + offset}", start_ordinal + offset, message)
        for offset, message in enumerate(messages)
    ]
    if not numbered:
        return

    message_rows = [
        (
            message_key,
            session_key,
            provider,
            message.role,
            message.timestamp,
            ordinal,
            message.raw_type,
            message.model,
            message.model_provider,
            message.text,
        )
        for message_key, ordinal, message in numbered
    ]
    fts_rows = [
        (message.text, message_key, session_key, provider, message.role, cwd)
        for message_key, _ordinal, message in numbered
    ]

    connection.executemany(
        """
        INSERT INTO messages (
            message_key, session_key, provider, role, timestamp,
            ordinal, raw_type, model, model_provider, text
        )
        VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
        """,
        message_rows,
    )
    connection.executemany(
        """
        INSERT INTO message_fts (
            text, message_key, session_key, provider, role, cwd
        )
        VALUES (?, ?, ?, ?, ?, ?)
        """,
        fts_rows,
    )
|
|
380
|
+
|
|
381
|
+
|
|
382
|
+
def _update_message_metadata(
    connection: sqlite3.Connection,
    row: sqlite3.Row,
    message: ChatMessage,
) -> int:
    """Copy a parsed message's model info onto its stored row when it differs.

    Only non-empty ``message.model`` / ``message.model_provider`` values are
    written, so an empty parsed value never clears a stored one.

    Returns:
        ``1`` if an UPDATE was issued, ``0`` if nothing needed changing.
    """
    changes: dict[str, object] = {}
    if message.model and row["model"] != message.model:
        changes["model"] = message.model
    if message.model_provider and row["model_provider"] != message.model_provider:
        changes["model_provider"] = message.model_provider
    if not changes:
        return 0

    # Column names come from the fixed keys above, never from input data.
    assignments = ", ".join(f"{column} = ?" for column in changes)
    connection.execute(
        f"""
        UPDATE messages
        SET {assignments}
        WHERE message_key = ?
        """,
        [*changes.values(), row["message_key"]],
    )
    return 1
|
|
408
|
+
|
|
409
|
+
|
|
410
|
+
def _refresh_session_models(connection: sqlite3.Connection, session_key: str) -> None:
    """Rebuild the ``session_models`` rows of one session from its messages.

    Existing rows for ``session_key`` are deleted, then one row is inserted per
    distinct (model, model_provider) pair with its message count. Messages with
    a NULL or empty model are ignored; a NULL provider is stored as ``''``.
    """
    key_param = (session_key,)
    connection.execute("DELETE FROM session_models WHERE session_key = ?", key_param)
    rebuild_sql = """
        INSERT INTO session_models (session_key, model, model_provider, message_count)
        SELECT session_key, model, COALESCE(model_provider, ''), COUNT(*) AS message_count
        FROM messages
        WHERE session_key = ?
            AND model IS NOT NULL
            AND model != ''
        GROUP BY session_key, model, COALESCE(model_provider, '')
        """
    connection.execute(rebuild_sql, key_param)
|
|
424
|
+
|
|
425
|
+
|
|
426
|
+
def _update_session_metadata(
    connection: sqlite3.Connection,
    existing: sqlite3.Row,
    session: ChatSession,
) -> int:
    """Rewrite the stored session row from a freshly parsed session.

    The UPDATE is always issued. ``cwd`` keeps the stored value when the parsed
    one is empty, the title goes through ``_merged_title``, ``started_at`` /
    ``updated_at`` only ever widen, ``message_count`` is recounted from the
    ``messages`` table, and ``imported_at`` is reset to the current timestamp.

    Returns:
        ``1`` if the title, parent key/id, thread source, source label or parser
        version differ from the stored row, otherwise ``0``.
    """
    session_key = existing["session_key"]
    parent_key = parent_session_key(session)
    merged_title = _merged_title(existing["title"], session.title)

    count_row = connection.execute(
        "SELECT COUNT(*) AS count FROM messages WHERE session_key = ?",
        (session_key,),
    ).fetchone()
    stored_messages = int(count_row["count"])

    new_values = (
        session.cwd or existing["cwd"],
        merged_title,
        parent_key,
        session.parent_session_id,
        session.thread_source,
        session.source_label,
        PARSER_VERSION,
        _earlier_time(existing["started_at"], session.started_at),
        _later_time(existing["updated_at"], session.updated_at),
        stored_messages,
        session.source_mtime,
        session.source_size,
        session_key,
    )
    connection.execute(
        """
        UPDATE sessions
        SET cwd = ?,
            title = ?,
            parent_session_key = ?,
            parent_session_id = ?,
            thread_source = ?,
            source_label = ?,
            parser_version = ?,
            started_at = ?,
            updated_at = ?,
            message_count = ?,
            source_mtime = ?,
            source_size = ?,
            imported_at = CURRENT_TIMESTAMP
        WHERE session_key = ?
        """,
        new_values,
    )

    # File-stamp, cwd and time-range changes are intentionally not counted here.
    if (
        merged_title != existing["title"]
        or existing["parent_session_key"] != parent_key
        or existing["parent_session_id"] != session.parent_session_id
        or existing["thread_source"] != session.thread_source
        or existing["source_label"] != session.source_label
        or existing["parser_version"] != PARSER_VERSION
    ):
        return 1
    return 0
|
|
481
|
+
|
|
482
|
+
|
|
483
|
+
def _session_metadata_matches(existing: sqlite3.Row, session: ChatSession) -> bool:
    """Return whether the stored parent/thread/label fields equal the parsed ones.

    Compared fields: ``parent_session_key``, ``parent_session_id``,
    ``thread_source`` and ``source_label``.
    """
    stored = (
        existing["parent_session_key"],
        existing["parent_session_id"],
        existing["thread_source"],
        existing["source_label"],
    )
    parsed = (
        parent_session_key(session),
        session.parent_session_id,
        session.thread_source,
        session.source_label,
    )
    return stored == parsed
|
|
490
|
+
|
|
491
|
+
|
|
492
|
+
def _merged_title(existing_title: str | None, parsed_title: str | None) -> str | None:
    """Choose the title to store when a session is re-imported.

    The stored title wins whenever it is non-empty and is neither title noise
    nor an image reference (per the ``parsers`` predicates). In every other case
    the parsed title is returned as-is, even if it is empty or noise itself.
    """
    keep_existing = (
        bool(existing_title)
        and not is_title_noise_text(existing_title)
        and not is_image_reference_text(existing_title)
    )
    return existing_title if keep_existing else parsed_title
|
|
499
|
+
|
|
500
|
+
|
|
501
|
+
def _message_fingerprint(role: str, timestamp: str | None, text: str) -> str:
    """Return a SHA-1 hex digest identifying a message by role, timestamp and text.

    The three fields are UTF-8 encoded (unencodable characters replaced) and
    joined with NUL bytes; a missing timestamp is treated as the empty string.
    Used for de-duplication only, not for security.
    """
    fields = (role, timestamp or "", text)
    payload = b"\0".join(field.encode("utf-8", errors="replace") for field in fields)
    return hashlib.sha1(payload).hexdigest()
|
|
509
|
+
|
|
510
|
+
|
|
511
|
+
def _earlier_time(left: str | None, right: str | None) -> str | None:
    """Return the smaller of two timestamp strings, tolerating missing values.

    If ``left`` is empty/None, ``right`` is returned unchanged; otherwise if
    ``right`` is empty/None, ``left`` is returned. Comparison is plain string
    ordering.
    """
    if left and right:
        return min(left, right)
    return left or right
|
|
517
|
+
|
|
518
|
+
|
|
519
|
+
def _later_time(left: str | None, right: str | None) -> str | None:
    """Return the larger of two timestamp strings, tolerating missing values.

    If ``left`` is empty/None, ``right`` is returned unchanged; otherwise if
    ``right`` is empty/None, ``left`` is returned. Comparison is plain string
    ordering.
    """
    if left and right:
        return max(left, right)
    return left or right
|
|
525
|
+
|
|
526
|
+
|
|
527
|
+
def list_sessions(
    connection: sqlite3.Connection,
    *,
    providers: Iterable[str] | None = None,
    cwd: str | None = None,
    models: Iterable[str] | None = None,
    include_children: bool = False,
    limit: int = 20,
) -> list[sqlite3.Row]:
    """Return the most recent sessions matching the given filters.

    Args:
        connection: Open database connection.
        providers: Restrict to these providers (no restriction when empty/None).
        cwd: Substring that the session's working directory must contain.
        models: Restrict to sessions that used at least one of these models.
        include_children: When false, sessions with a parent are excluded.
        limit: Maximum number of rows.

    Returns:
        Session rows with an extra ``models`` column (comma-joined distinct
        model names), newest first by ``updated_at``, falling back to
        ``started_at`` and then ``imported_at``.
    """
    where_sql, filter_params = _session_filters(
        providers=providers,
        cwd=cwd,
        models=models,
        include_children=include_children,
    )
    query = f"""
        SELECT s.*,
            (
                SELECT GROUP_CONCAT(model)
                FROM (
                    SELECT DISTINCT model
                    FROM session_models sm
                    WHERE sm.session_key = s.session_key
                    ORDER BY model
                )
            ) AS models
        FROM sessions s
        {where_sql}
        ORDER BY COALESCE(s.updated_at, s.started_at, s.imported_at) DESC
        LIMIT ?
        """
    return connection.execute(query, [*filter_params, limit]).fetchall()
|
|
564
|
+
|
|
565
|
+
|
|
566
|
+
def search_messages(
    connection: sqlite3.Connection,
    query: str,
    *,
    providers: Iterable[str] | None = None,
    cwd: str | None = None,
    models: Iterable[str] | None = None,
    roles: Iterable[str] | None = DEFAULT_ROLES,
    include_children: bool = False,
    limit: int = 20,
) -> list[sqlite3.Row]:
    """Full-text search over stored messages.

    Args:
        connection: Open database connection.
        query: User search text; converted with ``to_fts_query``.
        providers: Restrict to these providers (no restriction when empty/None).
        cwd: Substring that the session's working directory must contain.
        models: Restrict to sessions that used at least one of these models.
        roles: Restrict to these message roles (no restriction when empty/None).
        include_children: When false, messages of sessions with a parent are excluded.
        limit: Maximum number of rows.

    Returns:
        Message rows joined with their session columns, a comma-joined ``models``
        column and a ``snippet`` column, ordered by bm25 rank and then recency.
        If the FTS query raises ``sqlite3.OperationalError``, the result of the
        substring-based ``_like_search`` with the same filters is returned instead.
    """
    # Materialize every iterable filter exactly once: each is needed for the FTS
    # query and again for the LIKE fallback, so a one-shot iterator (for example
    # a generator passed as ``models``) would otherwise be empty the second time.
    provider_values = tuple(providers or ())
    role_values = tuple(roles or ())
    model_values = tuple(models or ())
    filters: list[str] = []
    params: list[object] = [to_fts_query(query)]

    if provider_values:
        filters.append(_in_clause("m.provider", provider_values, params))
    if role_values:
        filters.append(_in_clause("m.role", role_values, params))
    if cwd:
        filters.append("s.cwd LIKE ?")
        params.append(f"%{cwd}%")
    if not include_children:
        filters.append(_root_session_filter("s"))
    _append_model_filter(filters, params, model_values, session_column="m.session_key")

    where = ""
    if filters:
        where = "AND " + " AND ".join(filters)

    params.append(limit)
    try:
        return list(
            connection.execute(
                f"""
                SELECT
                    m.message_key, m.session_key, m.provider, m.role, m.timestamp,
                    m.ordinal, m.model, m.model_provider, m.text,
                    s.session_id, s.cwd, s.title, s.source_path,
                    s.parent_session_key, s.parent_session_id, s.thread_source, s.source_label,
                    (
                        SELECT GROUP_CONCAT(model)
                        FROM (
                            SELECT DISTINCT model
                            FROM session_models sm
                            WHERE sm.session_key = m.session_key
                            ORDER BY model
                        )
                    ) AS models,
                    snippet(message_fts, 0, '[', ']', '...', 18) AS snippet
                FROM message_fts
                JOIN messages m ON m.message_key = message_fts.message_key
                JOIN sessions s ON s.session_key = m.session_key
                WHERE message_fts MATCH ?
                {where}
                ORDER BY bm25(message_fts), COALESCE(m.timestamp, s.updated_at) DESC
                LIMIT ?
                """,
                params,
            )
        )
    except sqlite3.OperationalError:
        # The FTS query could not be executed; degrade to a substring search.
        return _like_search(
            connection,
            query,
            providers=provider_values,
            cwd=cwd,
            models=model_values,
            roles=role_values,
            include_children=include_children,
            limit=limit,
        )
|
|
639
|
+
|
|
640
|
+
|
|
641
|
+
def get_session_messages(
    connection: sqlite3.Connection,
    session_ref: str,
    *,
    roles: Iterable[str] | None = DEFAULT_ROLES,
) -> tuple[sqlite3.Row, list[sqlite3.Row]]:
    """Resolve a session reference and return the session with its messages.

    The reference is resolved in tiers, and the first tier that matches wins:

    1. exact ``session_key``;
    2. exact ``session_id``;
    3. prefix of ``session_key`` or ``session_id`` (matched literally).

    Resolving exact matches first keeps a full key usable even when other keys
    start with it, which happens for the ``"<base>:<digest>"`` keys produced by
    ``choose_session_key``.

    Args:
        connection: Open database connection.
        session_ref: Full or abbreviated session key / session id.
        roles: Restrict messages to these roles (no restriction when empty/None).

    Returns:
        ``(session_row, message_rows)``. Messages are ordered by timestamp with
        untimestamped messages last, ties broken by ordinal.

    Raises:
        LookupError: If nothing matches, or if the winning tier matches more
            than one session.
    """
    # Escape LIKE metacharacters so "_" and "%" inside the reference are literal.
    like_prefix = (
        session_ref.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_") + "%"
    )
    lookups: tuple[tuple[str, tuple[str, ...]], ...] = (
        ("session_key = ?", (session_ref,)),
        ("session_id = ?", (session_ref,)),
        (
            "session_key LIKE ? ESCAPE '\\' OR session_id LIKE ? ESCAPE '\\'",
            (like_prefix, like_prefix),
        ),
    )
    candidates: list[sqlite3.Row] = []
    for condition, arguments in lookups:
        # ``condition`` is one of the fixed strings above, never caller input.
        candidates = list(
            connection.execute(
                f"""
                SELECT *
                FROM sessions
                WHERE {condition}
                ORDER BY updated_at DESC
                LIMIT 5
                """,
                arguments,
            )
        )
        if candidates:
            break

    if not candidates:
        raise LookupError(f"no session matched {session_ref!r}")
    if len(candidates) > 1:
        names = ", ".join(row["session_key"] for row in candidates)
        raise LookupError(f"ambiguous session reference {session_ref!r}: {names}")

    session = candidates[0]
    params: list[object] = [session["session_key"]]
    role_values = tuple(roles or ())
    role_sql = ""
    if role_values:
        role_sql = "AND " + _in_clause("role", role_values, params)

    messages = list(
        connection.execute(
            f"""
            SELECT *
            FROM messages
            WHERE session_key = ?
            {role_sql}
            ORDER BY
                CASE WHEN timestamp IS NULL OR timestamp = '' THEN 1 ELSE 0 END,
                timestamp ASC,
                ordinal ASC
            """,
            params,
        )
    )
    return session, messages
|
|
691
|
+
|
|
692
|
+
|
|
693
|
+
def provider_counts(connection: sqlite3.Connection) -> dict[str, int]:
    """Return the number of stored sessions per provider, keyed in provider order."""
    cursor = connection.execute(
        "SELECT provider, COUNT(*) AS count FROM sessions GROUP BY provider ORDER BY provider"
    )
    totals: dict[str, int] = {}
    for row in cursor:
        totals[row["provider"]] = row["count"]
    return totals
|
|
700
|
+
|
|
701
|
+
|
|
702
|
+
def model_counts(connection: sqlite3.Connection) -> list[sqlite3.Row]:
    """Return usage totals per (model, model_provider) pair.

    Each row has ``model``, ``model_provider``, ``sessions`` (distinct sessions)
    and ``messages`` (summed message counts), ordered by sessions and messages
    descending, then by model and provider name.
    """
    totals_sql = """
        SELECT model,
            model_provider,
            COUNT(DISTINCT session_key) AS sessions,
            SUM(message_count) AS messages
        FROM session_models
        GROUP BY model, model_provider
        ORDER BY sessions DESC, messages DESC, model ASC, model_provider ASC
        """
    return connection.execute(totals_sql).fetchall()
|
|
716
|
+
|
|
717
|
+
|
|
718
|
+
def repo_counts(
    connection: sqlite3.Connection,
    *,
    providers: Iterable[str] | None = None,
    models: Iterable[str] | None = None,
    include_children: bool = False,
    limit: int = 20,
) -> list[sqlite3.Row]:
    """Return per-working-directory session totals.

    Sessions without a ``cwd`` are ignored. Each row has ``cwd``, ``sessions``
    (distinct session count), ``providers`` and ``models`` (comma-joined
    distinct names among the filtered sessions of that directory) and
    ``updated_at`` (latest activity). Rows are ordered by session count, then
    recency, then ``cwd``.

    Args:
        connection: Open database connection.
        providers: Restrict to these providers (no restriction when empty/None).
        models: Restrict to sessions that used at least one of these models.
        include_children: When false, sessions with a parent are excluded.
        limit: Maximum number of rows.
    """
    params: list[object] = []
    conditions = ["s.cwd IS NOT NULL", "s.cwd != ''"]

    wanted_providers = tuple(providers or ())
    if wanted_providers:
        conditions.append(_in_clause("s.provider", wanted_providers, params))
    if not include_children:
        conditions.append(_root_session_filter("s"))
    _append_model_filter(conditions, params, models, session_column="s.session_key")
    params.append(limit)

    where_sql = " AND ".join(conditions)
    query = f"""
        WITH filtered_sessions AS (
            SELECT s.*
            FROM sessions s
            WHERE {where_sql}
        )
        SELECT fs.cwd,
            COUNT(DISTINCT fs.session_key) AS sessions,
            (
                SELECT GROUP_CONCAT(provider)
                FROM (
                    SELECT DISTINCT fs_provider.provider
                    FROM filtered_sessions fs_provider
                    WHERE fs_provider.cwd = fs.cwd
                    ORDER BY fs_provider.provider
                )
            ) AS providers,
            (
                SELECT GROUP_CONCAT(model)
                FROM (
                    SELECT DISTINCT sm.model
                    FROM session_models sm
                    JOIN filtered_sessions fs_model ON fs_model.session_key = sm.session_key
                    WHERE fs_model.cwd = fs.cwd
                    ORDER BY sm.model
                )
            ) AS models,
            MAX(COALESCE(fs.updated_at, fs.started_at, fs.imported_at)) AS updated_at
        FROM filtered_sessions fs
        GROUP BY fs.cwd
        ORDER BY sessions DESC, updated_at DESC, fs.cwd ASC
        LIMIT ?
        """
    return connection.execute(query, params).fetchall()
|
|
775
|
+
|
|
776
|
+
|
|
777
|
+
def make_session_key(provider: str, session_id: str) -> str:
    """Build the ``"<provider>:<session_id>"`` key used to identify a session."""
    return "{}:{}".format(provider, session_id)
|
|
779
|
+
|
|
780
|
+
|
|
781
|
+
def parent_session_key(session: ChatSession) -> str | None:
    """Return the session key of ``session``'s parent, or ``None`` if it has none.

    The parent is assumed to belong to the same provider as ``session``.
    """
    parent_id = session.parent_session_id
    return make_session_key(session.provider, parent_id) if parent_id else None
|
|
785
|
+
|
|
786
|
+
|
|
787
|
+
def choose_session_key(connection: sqlite3.Connection, session: ChatSession) -> str:
    """Pick a ``sessions`` primary key for ``session``.

    Normally this is ``"<provider>:<session_id>"``. If that key is already taken
    by a row with a different source path, the first 10 hex characters of the
    SHA-1 of this session's source path are appended as ``":<digest>"``.
    """
    base_key = make_session_key(session.provider, session.session_id)
    source = str(session.source_path)
    holder = connection.execute(
        "SELECT source_path FROM sessions WHERE session_key = ?",
        (base_key,),
    ).fetchone()

    if holder and holder["source_path"] != source:
        suffix = hashlib.sha1(source.encode("utf-8")).hexdigest()[:10]
        return f"{base_key}:{suffix}"
    return base_key
|
|
798
|
+
|
|
799
|
+
|
|
800
|
+
def to_fts_query(query: str) -> str:
    """Turn free-form search text into an FTS5 MATCH expression.

    The text is split into whitespace-separated words and ``"double quoted"``
    phrases. Each is emitted as a quoted FTS string (embedded quotes doubled)
    and the results are AND-ed together. Text without any term yields ``'""'``.
    """

    def as_phrase(term: str) -> str:
        term = term.strip()
        # Drop the user's own surrounding quotes before re-quoting.
        if len(term) > 1 and term[0] == '"' and term[-1] == '"':
            term = term[1:-1]
        return '"' + term.replace('"', '""') + '"'

    phrases = [
        as_phrase(term)
        for term in re.findall(r'"[^"]+"|\S+', query)
        if term.strip()
    ]
    return " AND ".join(phrases) if phrases else '""'
|
|
813
|
+
|
|
814
|
+
|
|
815
|
+
def _like_search(
    connection: sqlite3.Connection,
    query: str,
    *,
    providers: Iterable[str],
    cwd: str | None,
    models: Iterable[str],
    roles: Iterable[str],
    include_children: bool,
    limit: int,
) -> list[sqlite3.Row]:
    """Substring search over message text, used when the FTS query fails.

    Applies the same filters as ``search_messages`` and returns the same
    columns; the ``snippet`` column is the full message text, and rows are
    ordered by recency only.
    """
    params: list[object] = [f"%{query}%"]
    conditions = ["m.text LIKE ?"]

    wanted_providers = tuple(providers)
    if wanted_providers:
        conditions.append(_in_clause("m.provider", wanted_providers, params))

    wanted_roles = tuple(roles)
    if wanted_roles:
        conditions.append(_in_clause("m.role", wanted_roles, params))

    if cwd:
        conditions.append("s.cwd LIKE ?")
        params.append(f"%{cwd}%")
    if not include_children:
        conditions.append(_root_session_filter("s"))
    _append_model_filter(conditions, params, models, session_column="m.session_key")
    params.append(limit)

    where_sql = " AND ".join(conditions)
    sql = f"""
        SELECT
            m.message_key, m.session_key, m.provider, m.role, m.timestamp,
            m.ordinal, m.model, m.model_provider, m.text,
            s.session_id, s.cwd, s.title, s.source_path,
            s.parent_session_key, s.parent_session_id, s.thread_source, s.source_label,
            (
                SELECT GROUP_CONCAT(model)
                FROM (
                    SELECT DISTINCT model
                    FROM session_models sm
                    WHERE sm.session_key = m.session_key
                    ORDER BY model
                )
            ) AS models,
            m.text AS snippet
        FROM messages m
        JOIN sessions s ON s.session_key = m.session_key
        WHERE {where_sql}
        ORDER BY COALESCE(m.timestamp, s.updated_at) DESC
        LIMIT ?
        """
    return connection.execute(sql, params).fetchall()
|
|
872
|
+
|
|
873
|
+
|
|
874
|
+
def _session_filters(
    *,
    providers: Iterable[str] | None,
    cwd: str | None,
    models: Iterable[str] | None,
    include_children: bool,
) -> tuple[str, list[object]]:
    """Build the WHERE clause and bound parameters for session listings.

    The clause refers to the ``sessions`` table through the alias ``s``.

    Returns:
        ``(where_sql, params)``; ``where_sql`` is ``""`` when no filter applies,
        otherwise it starts with ``"WHERE "``.
    """
    params: list[object] = []
    conditions: list[str] = []

    wanted_providers = tuple(providers or ())
    if wanted_providers:
        conditions.append(_in_clause("s.provider", wanted_providers, params))
    if cwd:
        conditions.append("s.cwd LIKE ?")
        params.append(f"%{cwd}%")
    if not include_children:
        conditions.append(_root_session_filter("s"))
    _append_model_filter(conditions, params, models, session_column="s.session_key")

    where_sql = "WHERE " + " AND ".join(conditions) if conditions else ""
    return where_sql, params
|
|
897
|
+
|
|
898
|
+
|
|
899
|
+
def _root_session_filter(alias: str) -> str:
    """Return a SQL condition selecting sessions that have no parent.

    ``alias`` is the table alias of ``sessions`` in the surrounding query.
    """
    column = f"{alias}.parent_session_key"
    return f"({column} IS NULL OR {column} = '')"
|
|
901
|
+
|
|
902
|
+
|
|
903
|
+
def _append_model_filter(
    filters: list[str],
    params: list[object],
    models: Iterable[str] | None,
    *,
    session_column: str,
) -> None:
    """Append an "uses one of these models" condition to ``filters`` in place.

    Empty model names are dropped; when none remain, nothing is appended.
    Otherwise the model names are added to ``params`` and an ``EXISTS`` subquery
    over ``session_models`` (correlated on ``session_column``) is added to
    ``filters``.
    """
    wanted = tuple(name for name in (models or ()) if name)
    if not wanted:
        return

    # _in_clause appends the bound values to ``params`` itself.
    membership = _in_clause("sm_filter.model", wanted, params)
    filters.append(
        f"""
        EXISTS (
            SELECT 1
            FROM session_models sm_filter
            WHERE sm_filter.session_key = {session_column}
                AND {membership}
        )
        """
    )
|
|
927
|
+
|
|
928
|
+
|
|
929
|
+
def _in_clause(column: str, values: Iterable[str], params: list[object]) -> str:
    """Return ``"<column> IN (?, ...)"`` and append the bound values to ``params``.

    Empty values are skipped. If no value remains, the always-true condition
    ``"1 = 1"`` is returned and ``params`` is left untouched.
    """
    kept = [value for value in values if value]
    if not kept:
        return "1 = 1"
    params.extend(kept)
    marks = ", ".join("?" * len(kept))
    return f"{column} IN ({marks})"
|