gora-cli 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
gora/store.py ADDED
@@ -0,0 +1,935 @@
1
+ from __future__ import annotations
2
+
3
+ import hashlib
4
+ import os
5
+ from pathlib import Path
6
+ import re
7
+ import sqlite3
8
+ import sys
9
+ from typing import Iterable
10
+
11
+ from .parsers import ChatMessage, ChatSession, is_image_reference_text, is_title_noise_text
12
+
13
+
14
# Stored with every session row; a row whose parser_version differs from this
# value is not considered current and gets re-merged on the next import.
PARSER_VERSION = 2
# Default role filter for queries; None (or an empty tuple) means "all roles".
DEFAULT_ROLES: tuple[str, ...] | None = None
16
+
17
+
18
def default_db_path() -> Path:
    """Return the path of the history database.

    Resolution order:
      1. ``$GORA_DB`` (used verbatim, with ``~`` expanded),
      2. ``$XDG_DATA_HOME/gora/history.sqlite``,
      3. ``~/Library/Application Support/gora/history.sqlite`` on macOS,
      4. ``~/.local/share/gora/history.sqlite`` everywhere else.
    """
    explicit = os.environ.get("GORA_DB")
    if explicit:
        return Path(explicit).expanduser()

    data_home = os.environ.get("XDG_DATA_HOME")
    if data_home:
        base_dir = Path(data_home).expanduser()
    elif sys.platform == "darwin":
        base_dir = Path.home() / "Library" / "Application Support"
    else:
        base_dir = Path.home() / ".local" / "share"
    return base_dir / "gora" / "history.sqlite"
31
+
32
+
33
def connect(db_path: Path | None = None) -> sqlite3.Connection:
    """Open (creating if necessary) the history database and return a ready connection.

    Args:
        db_path: Explicit database location. When omitted, ``default_db_path()``
            decides.

    Returns:
        A connection whose ``row_factory`` is ``sqlite3.Row`` and whose schema
        has been initialised via ``init_db``.

    The parent directory is only locked down to 0o700 when the location was
    chosen by gora itself (no ``db_path`` argument and no ``GORA_DB`` override);
    a user-supplied location keeps its directory permissions.
    """
    path = db_path or default_db_path()
    _prepare_db_path(path, secure_parent=db_path is None and not os.environ.get("GORA_DB"))
    connection = sqlite3.connect(path)
    try:
        connection.row_factory = sqlite3.Row
        init_db(connection)
        _secure_sqlite_files(path)
    except BaseException:
        # Do not leak the open handle when schema setup fails.
        connection.close()
        raise
    return connection
41
+
42
+
43
+ def _prepare_db_path(path: Path, *, secure_parent: bool) -> None:
44
+ if secure_parent:
45
+ path.parent.mkdir(parents=True, mode=0o700, exist_ok=True)
46
+ _chmod(path.parent, 0o700)
47
+ else:
48
+ path.parent.mkdir(parents=True, exist_ok=True)
49
+
50
+ if path.exists():
51
+ _chmod(path, 0o600)
52
+ return
53
+
54
+ descriptor = os.open(path, os.O_CREAT | os.O_EXCL | os.O_WRONLY, 0o600)
55
+ os.close(descriptor)
56
+
57
+
58
+ def _secure_sqlite_files(path: Path) -> None:
59
+ for candidate in (path, path.with_name(f"{path.name}-journal"), path.with_name(f"{path.name}-wal"), path.with_name(f"{path.name}-shm")):
60
+ if candidate.exists():
61
+ _chmod(candidate, 0o600)
62
+
63
+
64
+ def _chmod(path: Path, mode: int) -> None:
65
+ try:
66
+ path.chmod(mode)
67
+ except OSError:
68
+ pass
69
+
70
+
71
def init_db(connection: sqlite3.Connection) -> None:
    """Create the schema on *connection* if it is missing, migrate it, and commit.

    Enables foreign-key enforcement, creates the ``sessions``, ``messages`` and
    ``session_models`` tables with their indexes, creates the ``message_fts``
    FTS5 virtual table, then runs ``_migrate_db`` to upgrade older databases.
    """
    tables_and_indexes = """
        CREATE TABLE IF NOT EXISTS sessions (
            session_key TEXT PRIMARY KEY,
            provider TEXT NOT NULL,
            session_id TEXT NOT NULL,
            source_path TEXT NOT NULL UNIQUE,
            parent_session_key TEXT,
            parent_session_id TEXT,
            thread_source TEXT,
            source_label TEXT,
            parser_version INTEGER NOT NULL DEFAULT 2,
            cwd TEXT,
            title TEXT,
            started_at TEXT,
            updated_at TEXT,
            message_count INTEGER NOT NULL DEFAULT 0,
            source_mtime REAL NOT NULL,
            source_size INTEGER NOT NULL,
            imported_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP
        );

        CREATE TABLE IF NOT EXISTS messages (
            message_key TEXT PRIMARY KEY,
            session_key TEXT NOT NULL REFERENCES sessions(session_key) ON DELETE CASCADE,
            provider TEXT NOT NULL,
            role TEXT NOT NULL,
            timestamp TEXT,
            ordinal INTEGER NOT NULL,
            raw_type TEXT,
            model TEXT,
            model_provider TEXT,
            text TEXT NOT NULL
        );

        CREATE TABLE IF NOT EXISTS session_models (
            session_key TEXT NOT NULL REFERENCES sessions(session_key) ON DELETE CASCADE,
            model TEXT NOT NULL,
            model_provider TEXT NOT NULL DEFAULT '',
            message_count INTEGER NOT NULL DEFAULT 0,
            PRIMARY KEY (session_key, model, model_provider)
        );

        CREATE INDEX IF NOT EXISTS idx_sessions_provider_updated
            ON sessions(provider, updated_at DESC);
        CREATE INDEX IF NOT EXISTS idx_sessions_cwd
            ON sessions(cwd);
        CREATE INDEX IF NOT EXISTS idx_messages_session_ordinal
            ON messages(session_key, ordinal);
        CREATE INDEX IF NOT EXISTS idx_messages_role
            ON messages(role);
        CREATE INDEX IF NOT EXISTS idx_session_models_model
            ON session_models(model, model_provider);
        """
    # Only ``text`` is indexed for full-text search; the remaining columns are
    # carried along UNINDEXED.
    full_text_table = """
        CREATE VIRTUAL TABLE IF NOT EXISTS message_fts
        USING fts5(
            text,
            message_key UNINDEXED,
            session_key UNINDEXED,
            provider UNINDEXED,
            role UNINDEXED,
            cwd UNINDEXED,
            tokenize='unicode61'
        )
        """

    connection.execute("PRAGMA foreign_keys = ON")
    connection.executescript(tables_and_indexes)
    connection.execute(full_text_table)
    _migrate_db(connection)
    connection.commit()
144
+
145
+
146
+ def _migrate_db(connection: sqlite3.Connection) -> None:
147
+ session_columns = {
148
+ row["name"]
149
+ for row in connection.execute("PRAGMA table_info(sessions)")
150
+ }
151
+ for column in ("parent_session_key", "parent_session_id", "thread_source", "source_label"):
152
+ if column not in session_columns:
153
+ connection.execute(f"ALTER TABLE sessions ADD COLUMN {column} TEXT")
154
+ if "parser_version" not in session_columns:
155
+ connection.execute("ALTER TABLE sessions ADD COLUMN parser_version INTEGER NOT NULL DEFAULT 0")
156
+
157
+ message_columns = {
158
+ row["name"]
159
+ for row in connection.execute("PRAGMA table_info(messages)")
160
+ }
161
+ if "model" not in message_columns:
162
+ connection.execute("ALTER TABLE messages ADD COLUMN model TEXT")
163
+ if "model_provider" not in message_columns:
164
+ connection.execute("ALTER TABLE messages ADD COLUMN model_provider TEXT")
165
+
166
+ connection.execute(
167
+ """
168
+ CREATE TABLE IF NOT EXISTS session_models (
169
+ session_key TEXT NOT NULL REFERENCES sessions(session_key) ON DELETE CASCADE,
170
+ model TEXT NOT NULL,
171
+ model_provider TEXT NOT NULL DEFAULT '',
172
+ message_count INTEGER NOT NULL DEFAULT 0,
173
+ PRIMARY KEY (session_key, model, model_provider)
174
+ )
175
+ """
176
+ )
177
+ connection.execute(
178
+ "CREATE INDEX IF NOT EXISTS idx_messages_model ON messages(model)"
179
+ )
180
+ connection.execute(
181
+ "CREATE INDEX IF NOT EXISTS idx_sessions_parent ON sessions(parent_session_key)"
182
+ )
183
+ connection.execute(
184
+ "CREATE INDEX IF NOT EXISTS idx_session_models_model ON session_models(model, model_provider)"
185
+ )
186
+
187
+
188
def upsert_session(
    connection: sqlite3.Connection,
    session: ChatSession,
    *,
    force: bool = False,
) -> bool:
    """Insert *session* or merge it into the row already stored for its source file.

    Args:
        connection: Open database connection (rows must be name-addressable).
        session: Freshly parsed session.
        force: Skip the "nothing changed" shortcut and always merge.

    Returns:
        True when anything was written that counts as a change, False when the
        stored copy was already up to date.
    """
    stored = connection.execute(
        """
        SELECT *
        FROM sessions
        WHERE source_path = ?
        """,
        (str(session.source_path),),
    ).fetchone()

    if stored and not force:
        # Same file stat, same parser, same metadata and no better title:
        # the import would be a no-op.
        up_to_date = (
            stored["source_mtime"] == session.source_mtime
            and stored["source_size"] == session.source_size
            and stored["parser_version"] == PARSER_VERSION
            and _session_metadata_matches(stored, session)
            and _merged_title(stored["title"], session.title) == stored["title"]
        )
        if up_to_date:
            return False

    # One transaction per session: committed on success, rolled back on error.
    with connection:
        if not stored:
            new_key = choose_session_key(connection, session)
            _insert_session(connection, new_key, session)
            _insert_messages(connection, new_key, session.messages, cwd=session.cwd, start_ordinal=0)
            _refresh_session_models(connection, new_key)
            return True

        change_count = _merge_session(connection, stored, session)
    return change_count > 0
224
+
225
+
226
def source_is_current(connection: sqlite3.Connection, source_path: Path) -> bool:
    """Tell whether the stored import of *source_path* matches the file on disk.

    Returns False when the file cannot be stat-ed, when no session row exists
    for it, or when the recorded mtime, size or parser version differs.
    """
    try:
        file_stat = source_path.stat()
    except OSError:
        return False

    row = connection.execute(
        """
        SELECT source_mtime, source_size, parser_version
        FROM sessions
        WHERE source_path = ?
        """,
        (str(source_path),),
    ).fetchone()
    if not row:
        return False
    return bool(
        row["source_mtime"] == file_stat.st_mtime
        and row["source_size"] == file_stat.st_size
        and row["parser_version"] == PARSER_VERSION
    )
245
+
246
+
247
def _insert_session(connection: sqlite3.Connection, session_key: str, session: ChatSession) -> None:
    """Insert a brand-new ``sessions`` row for *session* under *session_key*.

    ``parser_version`` is stamped with the current ``PARSER_VERSION`` and
    ``message_count`` with the number of parsed messages.
    """
    # Order must match the column list in the INSERT statement below.
    row_values = (
        session_key,
        session.provider,
        session.session_id,
        str(session.source_path),
        parent_session_key(session),
        session.parent_session_id,
        session.thread_source,
        session.source_label,
        PARSER_VERSION,
        session.cwd,
        session.title,
        session.started_at,
        session.updated_at,
        len(session.messages),
        session.source_mtime,
        session.source_size,
    )
    connection.execute(
        """
        INSERT INTO sessions (
            session_key, provider, session_id, source_path,
            parent_session_key, parent_session_id, thread_source, source_label,
            parser_version, cwd, title,
            started_at, updated_at, message_count, source_mtime, source_size
        )
        VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
        """,
        row_values,
    )
277
+
278
+
279
def _merge_session(
    connection: sqlite3.Connection,
    existing: sqlite3.Row,
    session: ChatSession,
) -> int:
    """Merge a re-parsed *session* into its stored row and messages.

    Parsed messages are paired with stored ones by fingerprint (role,
    timestamp, text); each stored message is paired at most once, in ordinal
    order. Paired messages only get their model metadata refreshed; unpaired
    parsed messages are appended after the highest stored ordinal. Stored
    messages are never removed.

    Returns:
        Number of inserted messages, plus messages whose metadata was updated,
        plus 1 if the session metadata changed.
    """
    session_key = existing["session_key"]
    stored_rows = connection.execute(
        """
        SELECT message_key, role, timestamp, text, ordinal, model, model_provider
        FROM messages
        WHERE session_key = ?
        ORDER BY ordinal ASC
        """,
        (session_key,),
    ).fetchall()

    # fingerprint -> stored rows not yet paired with a parsed message
    unpaired: dict[str, list[sqlite3.Row]] = {}
    for stored in stored_rows:
        key = _message_fingerprint(stored["role"], stored["timestamp"], stored["text"])
        unpaired.setdefault(key, []).append(stored)

    additions: list[ChatMessage] = []
    refreshed = 0
    for message in session.messages:
        bucket = unpaired.get(_message_fingerprint(message.role, message.timestamp, message.text))
        if not bucket:
            additions.append(message)
            continue
        refreshed += _update_message_metadata(connection, bucket.pop(0), message)

    next_ordinal = 1 + max((stored["ordinal"] for stored in stored_rows), default=-1)
    _insert_messages(connection, session_key, additions, cwd=session.cwd, start_ordinal=next_ordinal)
    session_changed = _update_session_metadata(connection, existing, session)
    if additions or refreshed:
        _refresh_session_models(connection, session_key)
    return len(additions) + refreshed + session_changed
317
+
318
+
319
+ def _insert_messages(
320
+ connection: sqlite3.Connection,
321
+ session_key: str,
322
+ messages: Iterable[ChatMessage],
323
+ *,
324
+ cwd: str | None,
325
+ start_ordinal: int,
326
+ ) -> None:
327
+ message_rows: list[tuple[object, ...]] = []
328
+ fts_rows: list[tuple[object, ...]] = []
329
+ for offset, message in enumerate(messages):
330
+ ordinal = start_ordinal + offset
331
+ message_key = f"{session_key}:{ordinal}"
332
+ provider = session_key.split(":", 1)[0]
333
+ message_rows.append(
334
+ (
335
+ message_key,
336
+ session_key,
337
+ provider,
338
+ message.role,
339
+ message.timestamp,
340
+ ordinal,
341
+ message.raw_type,
342
+ message.model,
343
+ message.model_provider,
344
+ message.text,
345
+ )
346
+ )
347
+ fts_rows.append(
348
+ (
349
+ message.text,
350
+ message_key,
351
+ session_key,
352
+ provider,
353
+ message.role,
354
+ cwd,
355
+ )
356
+ )
357
+
358
+ if not message_rows:
359
+ return
360
+
361
+ connection.executemany(
362
+ """
363
+ INSERT INTO messages (
364
+ message_key, session_key, provider, role, timestamp,
365
+ ordinal, raw_type, model, model_provider, text
366
+ )
367
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
368
+ """,
369
+ message_rows,
370
+ )
371
+ connection.executemany(
372
+ """
373
+ INSERT INTO message_fts (
374
+ text, message_key, session_key, provider, role, cwd
375
+ )
376
+ VALUES (?, ?, ?, ?, ?, ?)
377
+ """,
378
+ fts_rows,
379
+ )
380
+
381
+
382
+ def _update_message_metadata(
383
+ connection: sqlite3.Connection,
384
+ row: sqlite3.Row,
385
+ message: ChatMessage,
386
+ ) -> int:
387
+ updates: list[str] = []
388
+ params: list[object] = []
389
+ if message.model and row["model"] != message.model:
390
+ updates.append("model = ?")
391
+ params.append(message.model)
392
+ if message.model_provider and row["model_provider"] != message.model_provider:
393
+ updates.append("model_provider = ?")
394
+ params.append(message.model_provider)
395
+ if not updates:
396
+ return 0
397
+
398
+ params.append(row["message_key"])
399
+ connection.execute(
400
+ f"""
401
+ UPDATE messages
402
+ SET {", ".join(updates)}
403
+ WHERE message_key = ?
404
+ """,
405
+ params,
406
+ )
407
+ return 1
408
+
409
+
410
+ def _refresh_session_models(connection: sqlite3.Connection, session_key: str) -> None:
411
+ connection.execute("DELETE FROM session_models WHERE session_key = ?", (session_key,))
412
+ connection.execute(
413
+ """
414
+ INSERT INTO session_models (session_key, model, model_provider, message_count)
415
+ SELECT session_key, model, COALESCE(model_provider, ''), COUNT(*) AS message_count
416
+ FROM messages
417
+ WHERE session_key = ?
418
+ AND model IS NOT NULL
419
+ AND model != ''
420
+ GROUP BY session_key, model, COALESCE(model_provider, '')
421
+ """,
422
+ (session_key,),
423
+ )
424
+
425
+
426
def _update_session_metadata(
    connection: sqlite3.Connection,
    existing: sqlite3.Row,
    session: ChatSession,
) -> int:
    """Rewrite the stored session row from the re-parsed *session*.

    The UPDATE always runs (it also refreshes file stat, message count and
    ``imported_at``). The stored cwd is kept when the parsed one is empty, the
    title goes through ``_merged_title``, and the time range only ever widens.

    Returns:
        1 if the title, parent/thread/source-label metadata or parser version
        differ from what *existing* held, otherwise 0.
    """
    session_key = existing["session_key"]
    stored_messages = int(
        connection.execute(
            "SELECT COUNT(*) AS count FROM messages WHERE session_key = ?",
            (session_key,),
        ).fetchone()["count"]
    )
    merged_title = _merged_title(existing["title"], session.title)
    parent_key = parent_session_key(session)

    connection.execute(
        """
        UPDATE sessions
        SET cwd = ?,
            title = ?,
            parent_session_key = ?,
            parent_session_id = ?,
            thread_source = ?,
            source_label = ?,
            parser_version = ?,
            started_at = ?,
            updated_at = ?,
            message_count = ?,
            source_mtime = ?,
            source_size = ?,
            imported_at = CURRENT_TIMESTAMP
        WHERE session_key = ?
        """,
        (
            session.cwd or existing["cwd"],
            merged_title,
            parent_key,
            session.parent_session_id,
            session.thread_source,
            session.source_label,
            PARSER_VERSION,
            _earlier_time(existing["started_at"], session.started_at),
            _later_time(existing["updated_at"], session.updated_at),
            stored_messages,
            session.source_mtime,
            session.source_size,
            session_key,
        ),
    )

    # (stored value, new value) pairs that count as a reportable change.
    tracked = (
        (existing["title"], merged_title),
        (existing["parent_session_key"], parent_key),
        (existing["parent_session_id"], session.parent_session_id),
        (existing["thread_source"], session.thread_source),
        (existing["source_label"], session.source_label),
        (existing["parser_version"], PARSER_VERSION),
    )
    if any(before != after for before, after in tracked):
        return 1
    return 0
481
+
482
+
483
def _session_metadata_matches(existing: sqlite3.Row, session: ChatSession) -> bool:
    """Return True when the stored parent/thread/label metadata equals the parsed one."""
    if existing["parent_session_key"] != parent_session_key(session):
        return False
    if existing["parent_session_id"] != session.parent_session_id:
        return False
    if existing["thread_source"] != session.thread_source:
        return False
    return existing["source_label"] == session.source_label
490
+
491
+
492
def _merged_title(existing_title: str | None, parsed_title: str | None) -> str | None:
    """Pick the title to store when a session is re-imported.

    A stored title wins as long as it is non-empty and is neither title noise
    nor an image reference; in every other case the parsed title is returned
    as-is (even if it is itself empty or noise).
    """
    stored_is_usable = (
        bool(existing_title)
        and not is_title_noise_text(existing_title)
        and not is_image_reference_text(existing_title)
    )
    return existing_title if stored_is_usable else parsed_title
499
+
500
+
501
+ def _message_fingerprint(role: str, timestamp: str | None, text: str) -> str:
502
+ digest = hashlib.sha1()
503
+ digest.update(role.encode("utf-8", errors="replace"))
504
+ digest.update(b"\0")
505
+ digest.update((timestamp or "").encode("utf-8", errors="replace"))
506
+ digest.update(b"\0")
507
+ digest.update(text.encode("utf-8", errors="replace"))
508
+ return digest.hexdigest()
509
+
510
+
511
+ def _earlier_time(left: str | None, right: str | None) -> str | None:
512
+ if not left:
513
+ return right
514
+ if not right:
515
+ return left
516
+ return min(left, right)
517
+
518
+
519
+ def _later_time(left: str | None, right: str | None) -> str | None:
520
+ if not left:
521
+ return right
522
+ if not right:
523
+ return left
524
+ return max(left, right)
525
+
526
+
527
def list_sessions(
    connection: sqlite3.Connection,
    *,
    providers: Iterable[str] | None = None,
    cwd: str | None = None,
    models: Iterable[str] | None = None,
    include_children: bool = False,
    limit: int = 20,
) -> list[sqlite3.Row]:
    """Return the most recent sessions matching the given filters.

    Each row carries every ``sessions`` column plus ``models``, a
    comma-joined list of the distinct models recorded for the session.
    Rows are ordered newest first by updated, started, then imported time.
    """
    where_sql, bound = _session_filters(
        providers=providers,
        cwd=cwd,
        models=models,
        include_children=include_children,
    )
    statement = f"""
        SELECT s.*,
        (
            SELECT GROUP_CONCAT(model)
            FROM (
                SELECT DISTINCT model
                FROM session_models sm
                WHERE sm.session_key = s.session_key
                ORDER BY model
            )
        ) AS models
        FROM sessions s
        {where_sql}
        ORDER BY COALESCE(s.updated_at, s.started_at, s.imported_at) DESC
        LIMIT ?
        """
    # The LIMIT placeholder comes last, after all filter parameters.
    return connection.execute(statement, [*bound, limit]).fetchall()
564
+
565
+
566
def search_messages(
    connection: sqlite3.Connection,
    query: str,
    *,
    providers: Iterable[str] | None = None,
    cwd: str | None = None,
    models: Iterable[str] | None = None,
    roles: Iterable[str] | None = DEFAULT_ROLES,
    include_children: bool = False,
    limit: int = 20,
) -> list[sqlite3.Row]:
    """Full-text search over stored messages.

    Args:
        connection: Open database connection.
        query: Free-text query; converted with ``to_fts_query``.
        providers: Restrict to these providers (empty/None = all).
        cwd: Substring that the session's cwd must contain.
        models: Restrict to sessions that used any of these models.
        roles: Restrict to these message roles (empty/None = all).
        include_children: Also return messages of sessions that have a parent.
        limit: Maximum number of rows.

    Returns:
        Matching rows (message columns, session columns, ``models`` and
        ``snippet``), ranked by bm25 then recency. If SQLite raises
        ``OperationalError`` for the FTS query, the result of a plain
        ``LIKE`` search with the same filters is returned instead.
    """
    # Materialise every iterable once: the values are needed again if the FTS
    # query fails, and a one-shot iterator would be empty by then.
    provider_values = tuple(providers or ())
    role_values = tuple(roles or ())
    model_values = tuple(models or ())

    filters: list[str] = []
    params: list[object] = [to_fts_query(query)]

    if provider_values:
        filters.append(_in_clause("m.provider", provider_values, params))
    if role_values:
        filters.append(_in_clause("m.role", role_values, params))
    if cwd:
        filters.append("s.cwd LIKE ?")
        params.append(f"%{cwd}%")
    if not include_children:
        filters.append(_root_session_filter("s"))
    _append_model_filter(filters, params, model_values, session_column="m.session_key")

    where = ""
    if filters:
        where = "AND " + " AND ".join(filters)

    params.append(limit)
    try:
        return list(
            connection.execute(
                f"""
                SELECT
                    m.message_key, m.session_key, m.provider, m.role, m.timestamp,
                    m.ordinal, m.model, m.model_provider, m.text,
                    s.session_id, s.cwd, s.title, s.source_path,
                    s.parent_session_key, s.parent_session_id, s.thread_source, s.source_label,
                    (
                        SELECT GROUP_CONCAT(model)
                        FROM (
                            SELECT DISTINCT model
                            FROM session_models sm
                            WHERE sm.session_key = m.session_key
                            ORDER BY model
                        )
                    ) AS models,
                    snippet(message_fts, 0, '[', ']', '...', 18) AS snippet
                FROM message_fts
                JOIN messages m ON m.message_key = message_fts.message_key
                JOIN sessions s ON s.session_key = m.session_key
                WHERE message_fts MATCH ?
                {where}
                ORDER BY bm25(message_fts), COALESCE(m.timestamp, s.updated_at) DESC
                LIMIT ?
                """,
                params,
            )
        )
    except sqlite3.OperationalError:
        return _like_search(
            connection,
            query,
            providers=provider_values,
            cwd=cwd,
            models=model_values,
            roles=role_values,
            include_children=include_children,
            limit=limit,
        )
639
+
640
+
641
def get_session_messages(
    connection: sqlite3.Connection,
    session_ref: str,
    *,
    roles: Iterable[str] | None = DEFAULT_ROLES,
) -> tuple[sqlite3.Row, list[sqlite3.Row]]:
    """Resolve *session_ref* to one session and return it with its messages.

    Resolution tries, in order, an exact ``session_key`` match, an exact
    ``session_id`` match, and finally a prefix match on either column. The
    first step that finds anything decides, so a complete key is never
    reported as ambiguous merely because it is also a prefix of another key
    (for example a collision-suffixed ``provider:id:digest`` key).

    Args:
        connection: Open database connection.
        session_ref: Full or abbreviated session key / session id.
        roles: Restrict messages to these roles (empty/None = all).

    Returns:
        ``(session_row, message_rows)``; messages with a timestamp come first
        in timestamp order, the rest follow, ties broken by ordinal.

    Raises:
        LookupError: No session matched, or several matched at the same step.
    """
    # Treat % and _ in the reference literally when prefix matching.
    like_prefix = (
        session_ref.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_") + "%"
    )
    lookups: tuple[tuple[str, tuple[object, ...]], ...] = (
        ("session_key = ?", (session_ref,)),
        ("session_id = ?", (session_ref,)),
        (
            "session_key LIKE ? ESCAPE '\\' OR session_id LIKE ? ESCAPE '\\'",
            (like_prefix, like_prefix),
        ),
    )
    candidates: list[sqlite3.Row] = []
    for condition, lookup_params in lookups:
        candidates = list(
            connection.execute(
                f"""
                SELECT *
                FROM sessions
                WHERE {condition}
                ORDER BY updated_at DESC
                LIMIT 5
                """,
                lookup_params,
            )
        )
        if candidates:
            break

    if not candidates:
        raise LookupError(f"no session matched {session_ref!r}")
    if len(candidates) > 1:
        names = ", ".join(row["session_key"] for row in candidates)
        raise LookupError(f"ambiguous session reference {session_ref!r}: {names}")

    session = candidates[0]
    params: list[object] = [session["session_key"]]
    role_values = tuple(roles or ())
    role_sql = ""
    if role_values:
        role_sql = "AND " + _in_clause("role", role_values, params)

    messages = list(
        connection.execute(
            f"""
            SELECT *
            FROM messages
            WHERE session_key = ?
            {role_sql}
            ORDER BY
                CASE WHEN timestamp IS NULL OR timestamp = '' THEN 1 ELSE 0 END,
                timestamp ASC,
                ordinal ASC
            """,
            params,
        )
    )
    return session, messages
691
+
692
+
693
def provider_counts(connection: sqlite3.Connection) -> dict[str, int]:
    """Return the number of stored sessions per provider, keyed in provider order."""
    cursor = connection.execute(
        "SELECT provider, COUNT(*) AS count FROM sessions GROUP BY provider ORDER BY provider"
    )
    counts: dict[str, int] = {}
    for row in cursor:
        counts[row["provider"]] = row["count"]
    return counts
700
+
701
+
702
def model_counts(connection: sqlite3.Connection) -> list[sqlite3.Row]:
    """Return per-(model, model_provider) usage totals.

    Each row has ``model``, ``model_provider``, ``sessions`` (distinct session
    count) and ``messages`` (summed message count), most-used first.
    """
    statement = """
        SELECT model,
            model_provider,
            COUNT(DISTINCT session_key) AS sessions,
            SUM(message_count) AS messages
        FROM session_models
        GROUP BY model, model_provider
        ORDER BY sessions DESC, messages DESC, model ASC, model_provider ASC
        """
    return connection.execute(statement).fetchall()
716
+
717
+
718
def repo_counts(
    connection: sqlite3.Connection,
    *,
    providers: Iterable[str] | None = None,
    models: Iterable[str] | None = None,
    include_children: bool = False,
    limit: int = 20,
) -> list[sqlite3.Row]:
    """Summarise sessions per working directory.

    Sessions without a cwd are excluded. Each row has ``cwd``, ``sessions``
    (count), ``providers`` and ``models`` (comma-joined distinct values among
    the filtered sessions of that cwd) and ``updated_at`` (latest activity).
    Rows are ordered by session count, then recency, then cwd.
    """
    conditions = ["s.cwd IS NOT NULL", "s.cwd != ''"]
    bound: list[object] = []

    wanted_providers = tuple(providers or ())
    if wanted_providers:
        conditions.append(_in_clause("s.provider", wanted_providers, bound))
    if not include_children:
        conditions.append(_root_session_filter("s"))
    _append_model_filter(conditions, bound, models, session_column="s.session_key")
    bound.append(limit)

    statement = f"""
        WITH filtered_sessions AS (
            SELECT s.*
            FROM sessions s
            WHERE {" AND ".join(conditions)}
        )
        SELECT fs.cwd,
            COUNT(DISTINCT fs.session_key) AS sessions,
            (
                SELECT GROUP_CONCAT(provider)
                FROM (
                    SELECT DISTINCT fs_provider.provider
                    FROM filtered_sessions fs_provider
                    WHERE fs_provider.cwd = fs.cwd
                    ORDER BY fs_provider.provider
                )
            ) AS providers,
            (
                SELECT GROUP_CONCAT(model)
                FROM (
                    SELECT DISTINCT sm.model
                    FROM session_models sm
                    JOIN filtered_sessions fs_model ON fs_model.session_key = sm.session_key
                    WHERE fs_model.cwd = fs.cwd
                    ORDER BY sm.model
                )
            ) AS models,
            MAX(COALESCE(fs.updated_at, fs.started_at, fs.imported_at)) AS updated_at
        FROM filtered_sessions fs
        GROUP BY fs.cwd
        ORDER BY sessions DESC, updated_at DESC, fs.cwd ASC
        LIMIT ?
        """
    return connection.execute(statement, bound).fetchall()
775
+
776
+
777
def make_session_key(provider: str, session_id: str) -> str:
    """Build the canonical session key ``"<provider>:<session_id>"``."""
    return "{}:{}".format(provider, session_id)
779
+
780
+
781
def parent_session_key(session: ChatSession) -> str | None:
    """Return the session key of *session*'s parent, or None when it has no parent id.

    The parent is assumed to belong to the same provider as the session.
    """
    parent_id = session.parent_session_id
    return make_session_key(session.provider, parent_id) if parent_id else None
785
+
786
+
787
def choose_session_key(connection: sqlite3.Connection, session: ChatSession) -> str:
    """Pick a ``sessions`` primary key for *session* that does not collide.

    The plain ``provider:session_id`` key is used unless a row with that key
    already exists for a different source file; in that case the first ten hex
    digits of the SHA-1 of the source path are appended as a third component.
    """
    source = str(session.source_path)
    base_key = make_session_key(session.provider, session.session_id)
    holder = connection.execute(
        "SELECT source_path FROM sessions WHERE session_key = ?",
        (base_key,),
    ).fetchone()

    taken_by_other_file = bool(holder) and holder["source_path"] != source
    if not taken_by_other_file:
        return base_key

    suffix = hashlib.sha1(source.encode("utf-8")).hexdigest()[:10]
    return f"{base_key}:{suffix}"
798
+
799
+
800
def to_fts_query(query: str) -> str:
    """Turn free text into an FTS5 MATCH expression of AND-ed quoted phrases.

    The input is split on whitespace, except that ``"double quoted"`` runs stay
    together. Every piece is emitted as a quoted string (embedded quotes
    doubled). Input without any term yields ``'""'``.
    """
    phrases: list[str] = []
    for raw in re.findall(r'"[^"]+"|\S+', query):
        token = raw.strip()
        if not token:
            continue
        # Drop the user's own surrounding quotes; they are re-added below.
        if len(token) > 1 and token.startswith('"') and token.endswith('"'):
            token = token[1:-1]
        phrases.append('"' + token.replace('"', '""') + '"')

    if not phrases:
        return '""'
    return " AND ".join(phrases)
813
+
814
+
815
def _like_search(
    connection: sqlite3.Connection,
    query: str,
    *,
    providers: Iterable[str],
    cwd: str | None,
    models: Iterable[str],
    roles: Iterable[str],
    include_children: bool,
    limit: int,
) -> list[sqlite3.Row]:
    """Substring search over message text using ``LIKE``.

    Returns the same columns as the full-text search, except that ``snippet``
    is the whole message text. Results are ordered newest first.
    """
    conditions = ["m.text LIKE ?"]
    bound: list[object] = [f"%{query}%"]

    wanted_providers = tuple(providers)
    if wanted_providers:
        conditions.append(_in_clause("m.provider", wanted_providers, bound))

    wanted_roles = tuple(roles)
    if wanted_roles:
        conditions.append(_in_clause("m.role", wanted_roles, bound))

    if cwd:
        conditions.append("s.cwd LIKE ?")
        bound.append(f"%{cwd}%")
    if not include_children:
        conditions.append(_root_session_filter("s"))
    _append_model_filter(conditions, bound, models, session_column="m.session_key")
    bound.append(limit)

    statement = f"""
        SELECT
            m.message_key, m.session_key, m.provider, m.role, m.timestamp,
            m.ordinal, m.model, m.model_provider, m.text,
            s.session_id, s.cwd, s.title, s.source_path,
            s.parent_session_key, s.parent_session_id, s.thread_source, s.source_label,
            (
                SELECT GROUP_CONCAT(model)
                FROM (
                    SELECT DISTINCT model
                    FROM session_models sm
                    WHERE sm.session_key = m.session_key
                    ORDER BY model
                )
            ) AS models,
            m.text AS snippet
        FROM messages m
        JOIN sessions s ON s.session_key = m.session_key
        WHERE {" AND ".join(conditions)}
        ORDER BY COALESCE(m.timestamp, s.updated_at) DESC
        LIMIT ?
        """
    return connection.execute(statement, bound).fetchall()
872
+
873
+
874
def _session_filters(
    *,
    providers: Iterable[str] | None,
    cwd: str | None,
    models: Iterable[str] | None,
    include_children: bool,
) -> tuple[str, list[object]]:
    """Build the WHERE clause (for alias ``s``) and parameters of a session listing.

    Returns:
        ``(where_sql, params)``; ``where_sql`` is an empty string when no
        filter applies, otherwise it starts with ``WHERE``.
    """
    conditions: list[str] = []
    bound: list[object] = []

    wanted_providers = tuple(providers or ())
    if wanted_providers:
        conditions.append(_in_clause("s.provider", wanted_providers, bound))
    if cwd:
        # Substring match on the working directory.
        conditions.append("s.cwd LIKE ?")
        bound.append(f"%{cwd}%")
    if not include_children:
        conditions.append(_root_session_filter("s"))
    _append_model_filter(conditions, bound, models, session_column="s.session_key")

    where_sql = "WHERE " + " AND ".join(conditions) if conditions else ""
    return where_sql, bound
897
+
898
+
899
+ def _root_session_filter(alias: str) -> str:
900
+ return f"({alias}.parent_session_key IS NULL OR {alias}.parent_session_key = '')"
901
+
902
+
903
+ def _append_model_filter(
904
+ filters: list[str],
905
+ params: list[object],
906
+ models: Iterable[str] | None,
907
+ *,
908
+ session_column: str,
909
+ ) -> None:
910
+ model_values = tuple(model for model in (models or ()) if model)
911
+ if not model_values:
912
+ return
913
+
914
+ model_params: list[object] = []
915
+ clause = _in_clause("sm_filter.model", model_values, model_params)
916
+ params.extend(model_params)
917
+ filters.append(
918
+ f"""
919
+ EXISTS (
920
+ SELECT 1
921
+ FROM session_models sm_filter
922
+ WHERE sm_filter.session_key = {session_column}
923
+ AND {clause}
924
+ )
925
+ """
926
+ )
927
+
928
+
929
+ def _in_clause(column: str, values: Iterable[str], params: list[object]) -> str:
930
+ clean_values = tuple(value for value in values if value)
931
+ if not clean_values:
932
+ return "1 = 1"
933
+ params.extend(clean_values)
934
+ placeholders = ", ".join("?" for _ in clean_values)
935
+ return f"{column} IN ({placeholders})"