threadlens 1.0.0 → 1.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,652 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import json
4
- import re
5
- import sqlite3
6
- import time
7
- from pathlib import Path
8
- from typing import Any
9
-
10
- from .models import ThreadMessage
11
- from .paths import ensure_private_storage_path
12
-
13
-
14
# DDL applied every time a ThreadStore is opened. Every statement uses
# "if not exists", so re-running it against an existing database is harmless.
#
# - messages: one row per indexed message. doc_key is UNIQUE, which is what
#   lets ThreadStore.add_messages upsert with "insert or replace".
# - indexed_files: per-(source, path) record of the mtime_ns/size seen at
#   index time; ThreadStore.file_is_current compares against it.
# - messages_fts: FTS5 index declared as an external-content table over
#   messages (content='messages', content_rowid='id'). No triggers keep it in
#   sync, so after writing to messages it must be refreshed with the FTS5
#   'rebuild' command (ThreadStore.rebuild_fts).
SCHEMA = """
create table if not exists messages (
    id integer primary key,
    doc_key text not null unique,
    source text not null,
    thread_id text not null,
    message_id text not null,
    path text not null,
    line integer not null,
    timestamp text not null,
    role text not null,
    cwd text not null,
    title text not null,
    text text not null,
    metadata_json text not null
);

create table if not exists indexed_files (
    source text not null,
    path text not null,
    mtime_ns integer not null,
    size integer not null,
    message_count integer not null,
    indexed_at text not null default current_timestamp,
    primary key (source, path)
);

create virtual table if not exists messages_fts using fts5(
    text,
    title,
    cwd,
    source,
    role,
    content='messages',
    content_rowid='id'
);
"""
51
-
52
-
53
class ThreadStore:
    """SQLite-backed store of indexed thread messages with FTS5 search.

    Messages live in the ``messages`` table. ``messages_fts`` is an
    external-content FTS5 index over it that is not kept in sync by triggers,
    so every write path must eventually call :meth:`rebuild_fts`. ``reset``
    and ``delete_sources`` do so themselves; ``add_messages`` can defer it via
    ``rebuild=False``; ``delete_file`` never does. ``indexed_files`` records
    which source files were ingested and with which ``mtime_ns``/``size``.
    """

    def __init__(self, db_path: Path):
        """Open (creating if necessary) the database at *db_path* and apply SCHEMA."""
        self.db_path = db_path
        ensure_private_storage_path(self.db_path)
        self.conn = sqlite3.connect(str(db_path))
        try:
            self.conn.row_factory = sqlite3.Row
            self.conn.executescript(SCHEMA)
        except Exception:
            # Don't leak the connection (and its file handle) if setup fails.
            self.conn.close()
            raise
        # Called again after connecting. NOTE(review): presumably because
        # sqlite may have just created the file; confirm against paths.py.
        ensure_private_storage_path(self.db_path)

    def close(self) -> None:
        """Close the underlying SQLite connection."""
        self.conn.close()

    def reset(self) -> None:
        """Delete all messages and file records, rebuild the FTS index, and commit."""
        self.conn.execute("delete from messages")
        self.conn.execute("delete from indexed_files")
        self.rebuild_fts()
        self.conn.commit()

    def delete_sources(self, sources: list[str]) -> None:
        """Delete all messages and file records for *sources*, rebuild FTS, and commit.

        An empty list is a no-op.
        """
        if not sources:
            return
        # One "?" per source; the values themselves are bound, never interpolated.
        placeholders = ",".join("?" for _ in sources)
        self.conn.execute(f"delete from messages where source in ({placeholders})", sources)
        self.conn.execute(f"delete from indexed_files where source in ({placeholders})", sources)
        self.rebuild_fts()
        self.conn.commit()

    def delete_file(self, source: str, path: Path) -> None:
        """Delete the messages and the file record for one (*source*, *path*) pair.

        Neither rebuilds the FTS index nor commits; the caller must do both
        (e.g. ``add_messages(...)`` or ``rebuild_fts()`` followed by a commit).
        """
        self.conn.execute(
            "delete from messages where source = ? and path = ?",
            (source, str(path)),
        )
        self.conn.execute(
            "delete from indexed_files where source = ? and path = ?",
            (source, str(path)),
        )

    def file_is_current(self, source: str, path: Path, *, mtime_ns: int, size: int) -> bool:
        """Return True if (*source*, *path*) was recorded with exactly this mtime_ns and size."""
        row = self.conn.execute(
            """
            select 1
            from indexed_files
            where source = ? and path = ? and mtime_ns = ? and size = ?
            """,
            (source, str(path), mtime_ns, size),
        ).fetchone()
        return row is not None

    def mark_file_indexed(self, source: str, path: Path, *, mtime_ns: int, size: int, message_count: int) -> None:
        """Record (or overwrite) the indexed state of a file. Does not commit."""
        self.conn.execute(
            """
            insert or replace into indexed_files (
                source, path, mtime_ns, size, message_count, indexed_at
            )
            values (?, ?, ?, ?, ?, current_timestamp)
            """,
            (source, str(path), mtime_ns, size, message_count),
        )

    def add_messages(self, messages: list[ThreadMessage], *, rebuild: bool = True, commit: bool = True) -> int:
        """Upsert *messages* (keyed by the UNIQUE doc_key) and return how many were written.

        ``rebuild=False`` / ``commit=False`` let callers batch several writes
        and refresh the FTS index / commit once at the end.
        """
        if not messages:
            return 0
        rows = [
            (
                message.doc_key,
                message.source,
                message.thread_id,
                message.message_id,
                str(message.path),
                message.line,
                message.timestamp,
                message.role,
                message.cwd,
                message.title,
                message.text,
                json.dumps(message.metadata, sort_keys=True),
            )
            for message in messages
        ]
        self.conn.executemany(
            """
            insert or replace into messages (
                doc_key, source, thread_id, message_id, path, line, timestamp,
                role, cwd, title, text, metadata_json
            )
            values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """,
            rows,
        )
        if rebuild:
            self.rebuild_fts()
        if commit:
            self.conn.commit()
        return len(rows)

    def rebuild_fts(self) -> None:
        """Rebuild messages_fts from the messages table (FTS5 'rebuild' command)."""
        self.conn.execute("insert into messages_fts(messages_fts) values('rebuild')")

    def stats(self) -> list[sqlite3.Row]:
        """Return per-source message and distinct-thread counts, ordered by source."""
        return list(
            self.conn.execute(
                """
                select source, count(*) as messages, count(distinct thread_id) as threads
                from messages
                group by source
                order by source
                """
            )
        )

    def message_count(self, source: str | None = None) -> int:
        """Return the number of stored messages, optionally restricted to *source*."""
        if source:
            row = self.conn.execute(
                "select count(*) as messages from messages where source = ?",
                (source,),
            ).fetchone()
        else:
            row = self.conn.execute("select count(*) as messages from messages").fetchone()
        return int(row["messages"] if row else 0)

    def search(
        self,
        query: str,
        *,
        limit: int = 10,
        source: str | None = None,
        raw_fts: bool = False,
    ) -> list[sqlite3.Row]:
        """Return up to *limit* matching message rows, best BM25 rank first.

        By default *query* is tokenized and every token must match. With
        ``raw_fts=True`` it is passed to FTS5 MATCH verbatim, so malformed
        FTS syntax may raise ``sqlite3.OperationalError``. A non-positive
        *limit* returns an empty list.
        """
        if limit <= 0:
            # SQLite treats a negative LIMIT as "no limit"; match search_sessions instead.
            return []
        fts_query = query if raw_fts else make_fts_query(query)
        if not fts_query:
            return []

        params: list[Any] = [fts_query]
        where = "messages_fts match ?"
        if source:
            where += " and messages.source = ?"
            params.append(source)
        params.append(limit)

        sql = f"""
            select
                messages.*,
                snippet(messages_fts, 0, '[', ']', '...', 28) as snippet,
                bm25(messages_fts) as rank
            from messages_fts
            join messages on messages_fts.rowid = messages.id
            where {where}
            order by rank, timestamp desc
            limit ?
        """
        return list(self.conn.execute(sql, params))

    def search_sessions(
        self,
        query: str,
        *,
        limit: int = 10,
        source: str | None = None,
        cwd_prefix: str | None = None,
    ) -> list[dict[str, Any]]:
        """Search messages and return up to *limit* session-level results.

        Runs progressively looser FTS stages (all tokens exact, all tokens as
        prefixes, any token), accumulating candidates. It stops once
        ``sessions_need_fallback`` says the results suffice, or after the
        "any" stage produced anything. If every FTS stage came up empty, a
        fuzzy (edit-distance) scan of the messages table is tried last.
        """
        tokens = tokenize_query(query)
        if not tokens or limit <= 0:
            return []

        candidates: list[dict[str, Any]] = []
        # Over-fetch rows: many message rows collapse into one session when grouped.
        row_limit = max(100, limit * 30)

        stages = [
            (make_fts_query_from_tokens(tokens, operator="AND", prefix=False), "exact", 100.0),
            (make_fts_query_from_tokens(tokens, operator="AND", prefix=True), "prefix", 80.0),
            (make_fts_query_from_tokens(tokens, operator="OR", prefix=False), "any", 55.0),
        ]

        sessions: list[dict[str, Any]] = []
        for fts_query, stage, base_score in stages:
            if not fts_query:
                continue
            candidates.extend(
                self._search_message_candidates(
                    fts_query,
                    tokens,
                    source=source,
                    cwd_prefix=cwd_prefix,
                    stage=stage,
                    base_score=base_score,
                    limit=row_limit,
                )
            )
            sessions = sorted_sessions(candidates, tokens=tokens)
            if stage == "any" and sessions:
                return sessions[:limit]
            if not sessions_need_fallback(sessions, tokens, limit):
                return sessions[:limit]

        if not sessions:
            candidates.extend(
                self._fuzzy_message_candidates(
                    tokens,
                    source=source,
                    cwd_prefix=cwd_prefix,
                    limit=row_limit,
                )
            )
        return sorted_sessions(candidates, tokens=tokens)[:limit]

    @staticmethod
    def _cwd_filter(column: str, cwd_prefix: str) -> tuple[str, list[Any]]:
        """Return (SQL clause, params) matching *column* equal to *cwd_prefix* or nested under it.

        ``%``, ``_`` and the escape character ``!`` in the prefix are escaped,
        so ``/src/my_app`` does not also match ``/src/myXapp``. SQLite's LIKE
        remains ASCII case-insensitive by default, as before.
        """
        base = cwd_prefix.rstrip("/")
        # Escape the escape character first so the escapes added next are not doubled.
        escaped = base.replace("!", "!!").replace("%", "!%").replace("_", "!_")
        return (
            f"({column} = ? or {column} like ? escape '!')",
            [cwd_prefix, f"{escaped}/%"],
        )

    @staticmethod
    def _text_for_matching(row: dict[str, Any]) -> str:
        """Join title, cwd and text (missing/None treated as "") for term matching."""
        return " ".join(
            [
                row.get("title") or "",
                row.get("cwd") or "",
                row.get("text") or "",
            ]
        )

    def _search_message_candidates(
        self,
        fts_query: str,
        tokens: list[str],
        *,
        source: str | None,
        cwd_prefix: str | None,
        stage: str,
        base_score: float,
        limit: int,
    ) -> list[dict[str, Any]]:
        """Run one FTS stage and return scored per-message candidate dicts.

        Score = *base_score* + a BM25-derived bonus (capped at 25) + recency,
        project and ordered-span boosts.
        """
        params: list[Any] = [fts_query]
        where = "messages_fts match ?"
        if source:
            where += " and messages.source = ?"
            params.append(source)
        if cwd_prefix:
            # Qualified column: messages_fts also has a "cwd" column.
            clause, clause_params = self._cwd_filter("messages.cwd", cwd_prefix)
            where += f" and {clause}"
            params.extend(clause_params)
        params.append(limit)

        sql = f"""
            select
                messages.*,
                snippet(messages_fts, 0, '[', ']', '...', 32) as snippet,
                bm25(messages_fts) as rank
            from messages_fts
            join messages on messages_fts.rowid = messages.id
            where {where}
            order by rank, timestamp desc
            limit ?
        """
        rows = self.conn.execute(sql, params)

        candidates = []
        for row in rows:
            row_dict = dict(row)
            rank = float(row_dict.get("rank") or 0)
            text_for_matching = self._text_for_matching(row_dict)
            matched_terms = match_terms(tokens, text_for_matching, fuzzy=False)
            score = (
                base_score
                # bm25() is lower-is-better (negative for good matches); flip and cap it.
                + min(25.0, max(0.0, -rank * 5.0))
                + recency_boost(row_dict.get("timestamp") or "")
                + project_boost(tokens, row_dict)
                + ordered_span_boost(tokens, text_for_matching)
            )
            candidates.append(
                {
                    "row": row_dict,
                    "snippet": row_dict.get("snippet") or make_plain_snippet(row_dict.get("text") or "", tokens),
                    "score": score,
                    "stage": stage,
                    "matched_terms": matched_terms,
                }
            )
        return candidates

    def _fuzzy_message_candidates(
        self,
        tokens: list[str],
        *,
        source: str | None,
        cwd_prefix: str | None,
        limit: int,
    ) -> list[dict[str, Any]]:
        """Scan messages newest-first and return up to *limit* fuzzy-matched candidates.

        A message qualifies when at least half of *tokens* match it (substring
        or small edit distance, per ``match_terms(..., fuzzy=True)``).
        """
        if limit <= 0:
            return []

        conditions: list[str] = []
        params: list[Any] = []
        if source:
            conditions.append("source = ?")
            params.append(source)
        if cwd_prefix:
            clause, clause_params = self._cwd_filter("cwd", cwd_prefix)
            conditions.append(clause)
            params.extend(clause_params)
        where = f"where {' and '.join(conditions)}" if conditions else ""
        sql = f"""
            select *
            from messages
            {where}
            order by timestamp desc
        """
        candidates: list[dict[str, Any]] = []
        rows = self.conn.execute(sql, params)

        for row in rows:
            row_dict = dict(row)
            text_for_matching = self._text_for_matching(row_dict)
            matched_terms = match_terms(tokens, text_for_matching, fuzzy=True)
            if not matched_terms:
                continue
            ratio = len(matched_terms) / max(1, len(tokens))
            if ratio < 0.5:
                continue
            score = (
                35.0
                + (ratio * 20.0)
                + recency_boost(row_dict.get("timestamp") or "")
                + project_boost(tokens, row_dict)
                + ordered_span_boost(tokens, text_for_matching)
            )
            candidates.append(
                {
                    "row": row_dict,
                    "snippet": make_plain_snippet(row_dict.get("text") or "", tokens),
                    "score": score,
                    "stage": "fuzzy",
                    "matched_terms": matched_terms,
                }
            )
            if len(candidates) >= limit:
                break
        return candidates

    def get_session(self, source: str, thread_id: str) -> list[sqlite3.Row]:
        """Return every message of one thread, in timestamp (then insertion id) order."""
        return list(
            self.conn.execute(
                """
                select *
                from messages
                where source = ? and thread_id = ?
                order by timestamp, id
                """,
                (source, thread_id),
            )
        )

    def find_sessions(self, thread_id: str, source: str | None = None) -> list[sqlite3.Row]:
        """Return one summary row per (source, thread) with this *thread_id*, newest first."""
        params: list[Any] = [thread_id]
        where = "thread_id = ?"
        if source:
            where += " and source = ?"
            params.append(source)
        return list(
            self.conn.execute(
                f"""
                select source, thread_id, max(timestamp) as last_timestamp, max(cwd) as cwd, max(title) as title, count(*) as messages
                from messages
                where {where}
                group by source, thread_id
                order by last_timestamp desc
                """,
                params,
            )
        )
420
-
421
-
422
def make_fts_query(query: str) -> str:
    """Build the strict FTS5 query for *query*: every token must match exactly (AND, no prefix)."""
    return make_fts_query_from_tokens(
        tokenize_query(query),
        operator="AND",
        prefix=False,
    )
425
-
426
-
427
def tokenize_query(query: str) -> list[str]:
    """Split *query* into lower-cased ASCII word tokens (letters, digits, underscore).

    Dropped tokens:
    - lone digits, which are too unselective to be useful;
    - tokens made only of underscores. FTS5's default tokenizer treats ``_``
      as a separator, so these carry no searchable term, and as substrings
      they would "match" almost any text in ``match_terms``.
    Order and duplicates are preserved.
    """
    tokens = re.findall(r"[A-Za-z0-9_]+", query.lower())
    return [
        token
        for token in tokens
        if token.strip("_") and (len(token) > 1 or not token.isdigit())
    ]
431
-
432
-
433
def make_fts_query_from_tokens(tokens: list[str], *, operator: str, prefix: bool) -> str:
    """Join *tokens* into an FTS5 expression using *operator* (e.g. "AND"/"OR").

    Tokens that are not plain ASCII words ([A-Za-z0-9_]+) are skipped, so no
    FTS5 syntax can leak in. With *prefix*, tokens of two or more characters
    get a trailing ``*`` for prefix matching. Returns "" when nothing survives.
    """
    word = re.compile(r"[A-Za-z0-9_]+")
    accepted = (token for token in tokens if word.fullmatch(token))
    terms = (f"{token}*" if prefix and len(token) >= 2 else token for token in accepted)
    return f" {operator} ".join(terms)
443
-
444
-
445
def group_candidate_sessions(candidates: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Collapse per-message candidates into one scored result per (source, thread_id).

    Each message (source, thread_id, message_id) is counted once; the first
    candidate seen for it wins, so earlier (stricter) stages take precedence.
    A session's score is
        best message score
        + min(5.0, 0.35 * matching messages)
        + 12.0 * distinct matched terms,
    rounded to 4 places. Only the three highest-scoring snippets are kept.
    Results come back in first-seen order; sorting is the caller's job.
    """
    sessions: dict[tuple[str, str], dict[str, Any]] = {}
    hit_count: dict[tuple[str, str], int] = {}
    top_score: dict[tuple[str, str], float] = {}
    seen: set[tuple[str, str, str]] = set()

    for candidate in candidates:
        row = candidate["row"]
        session_key = (row["source"], row["thread_id"])
        message_key = (row["source"], row["thread_id"], row["message_id"])
        if message_key in seen:
            continue
        seen.add(message_key)

        if session_key not in sessions:
            sessions[session_key] = {
                "result_id": f"{row['source']}:{row['thread_id']}",
                "source": row["source"],
                "session_id": row["thread_id"],
                "thread_id": row["thread_id"],
                "cwd": row["cwd"],
                "title": row["title"],
                "last_timestamp": row["timestamp"],
                "score": 0.0,
                "matched_terms": set(),
                "best_snippets": [],
                "source_path": row["path"],
                "source_line": row["line"],
                "actions": {},
            }
            hit_count[session_key] = 0
            top_score[session_key] = 0.0
        session = sessions[session_key]

        hit_count[session_key] += 1
        top_score[session_key] = max(top_score[session_key], candidate["score"])
        session["matched_terms"].update(candidate.get("matched_terms") or [])
        if row["timestamp"] > session["last_timestamp"]:
            session["last_timestamp"] = row["timestamp"]
        # Back-fill cwd/title from later messages when the first one lacked them.
        for field in ("cwd", "title"):
            if row[field] and not session[field]:
                session[field] = row[field]
        session["best_snippets"].append(
            {
                "message_id": row["message_id"],
                "timestamp": row["timestamp"],
                "role": row["role"],
                "snippet": candidate["snippet"],
                "source_path": row["path"],
                "source_line": row["line"],
                "match_type": candidate["stage"],
                "score": round(candidate["score"], 4),
            }
        )

    finished: list[dict[str, Any]] = []
    for session_key, session in sessions.items():
        terms = sorted(session["matched_terms"])
        session["matched_terms"] = terms
        ranked_snippets = sorted(session["best_snippets"], key=lambda snip: snip["score"], reverse=True)
        session["best_snippets"] = ranked_snippets[:3]
        message_boost = min(5.0, hit_count[session_key] * 0.35)
        term_boost = len(terms) * 12.0
        session["score"] = round(float(top_score[session_key] + message_boost + term_boost), 4)
        finished.append(session)
    return finished
509
-
510
-
511
def sorted_sessions(candidates: list[dict[str, Any]], *, tokens: list[str] | None = None) -> list[dict[str, Any]]:
    """Group *candidates* into sessions and sort them best-first.

    When *tokens* is non-empty, sessions covering more distinct query terms
    rank first, with score as the tie-breaker; otherwise score alone decides.
    """
    grouped = group_candidate_sessions(candidates)

    def by_score(result: dict[str, Any]) -> Any:
        return result["score"]

    def by_coverage_then_score(result: dict[str, Any]) -> Any:
        return (token_coverage(result), result["score"])

    grouped.sort(key=by_coverage_then_score if tokens else by_score, reverse=True)
    return grouped
518
-
519
-
520
def sessions_need_fallback(sessions: list[dict[str, Any]], tokens: list[str], limit: int) -> bool:
    """Decide whether a looser search stage should still run.

    Returns True when there are no sessions, fewer than *limit* sessions, or
    none of the top *limit* sessions matched every distinct query token.
    With no tokens at all (and enough sessions) no fallback is needed.
    """
    if not sessions or len(sessions) < limit:
        return True
    wanted = len(set(tokens))
    if not wanted:
        return False
    # Same as token_coverage(result) for each result.
    return not any(
        len(set(result.get("matched_terms") or [])) >= wanted
        for result in sessions[:limit]
    )
529
-
530
-
531
def token_coverage(result: dict[str, Any]) -> int:
    """Return how many distinct query terms a session result matched (0 if none recorded)."""
    terms = result.get("matched_terms") or ()
    return len(set(terms))
533
-
534
-
535
def match_terms(tokens: list[str], text: str, *, fuzzy: bool) -> list[str]:
    """Return the tokens that match *text*, in token order (duplicates preserved).

    A token matches when it occurs anywhere in the lower-cased text as a
    substring. That also covers whole-word and word-prefix matches, so no
    separate prefix scan is needed. With *fuzzy*, a token that fails the
    substring test may still match a word of *text* within a small edit
    distance (see ``fuzzy_contains``).
    """
    text_lower = text.lower()
    # Built lazily: only the fuzzy path needs the individual words.
    words: set[str] | None = None
    matched = []
    for token in tokens:
        if token in text_lower:
            matched.append(token)
        elif fuzzy:
            if words is None:
                words = set(re.findall(r"[A-Za-z0-9_]+", text_lower))
            if fuzzy_contains(token, words):
                matched.append(token)
    return matched
546
-
547
-
548
def fuzzy_contains(token: str, words: set[str]) -> bool:
    """Return True if some word is within a small edit distance of *token*.

    Tokens shorter than 3 characters never fuzzy-match. The allowed distance
    is 1 for tokens of length 3-4 and 2 for longer ones. Candidate words must
    share the token's first character and differ in length by at most the
    allowed distance; those cheap filters run before the distance computation.
    """
    if len(token) < 3:
        return False
    budget = 2 if len(token) > 4 else 1
    head = token[0]
    return any(
        bounded_levenshtein(token, word, budget) <= budget
        for word in words
        if word and word[0] == head and abs(len(word) - len(token)) <= budget
    )
560
-
561
-
562
def bounded_levenshtein(left: str, right: str, max_distance: int) -> int:
    """Edit distance between *left* and *right*, giving up once it must exceed *max_distance*.

    Counts insertions, deletions, substitutions and adjacent transpositions,
    each at cost 1 (optimal-string-alignment distance). The exact distance is
    returned when it is <= *max_distance*. Otherwise the result is some value
    greater than *max_distance*; it is exactly ``max_distance + 1`` when a
    whole DP row already exceeds the bound.
    """
    two_rows_up: list[int] | None = None
    row_above = list(range(len(right) + 1))
    for i in range(1, len(left) + 1):
        a = left[i - 1]
        row = [i]
        for j in range(1, len(right) + 1):
            b = right[j - 1]
            best = min(
                row[j - 1] + 1,  # insertion
                row_above[j] + 1,  # deletion
                row_above[j - 1] + (a != b),  # substitution / match
            )
            # Adjacent transposition ("ab" <-> "ba"); only possible from the second row on.
            if two_rows_up is not None and j > 1 and a == right[j - 2] and left[i - 2] == b:
                best = min(best, two_rows_up[j - 2] + 1)
            row.append(best)
        if min(row) > max_distance:
            return max_distance + 1
        two_rows_up, row_above = row_above, row
    return row_above[-1]
588
-
589
-
590
_TIMESTAMP_TAIL_RE = re.compile(r"(?:\.\d+)?(?:([Zz])|([+-])(\d{2}):?(\d{2}))?")


def _utc_offset_seconds(tail: str) -> int | None:
    """Parse what follows the seconds field of an ISO-8601 timestamp.

    Returns the UTC offset in seconds (0 for ``Z``). Returns None when there
    is no offset or the tail is not recognised; the caller then uses local time.
    """
    match = _TIMESTAMP_TAIL_RE.fullmatch(tail)
    if match is None:
        return None
    zulu, sign, hours, minutes = match.groups()
    if zulu:
        return 0
    if sign is None:
        return None
    seconds = int(hours) * 3600 + int(minutes) * 60
    return -seconds if sign == "-" else seconds


def recency_boost(timestamp: str) -> float:
    """Return a score bonus in [0, 8] that decays linearly with the age of *timestamp*.

    A brand-new message gets 8.0. The bonus drops by 1.0 per 14 days of age
    and reaches 0.0 after 112 days; future timestamps count as age 0.

    *timestamp* must start with ``YYYY-MM-DDTHH:MM:SS``. A trailing ``Z`` or
    numeric offset (``+HH:MM`` / ``+HHMM``) is honoured, optionally after
    fractional seconds. Previously these were discarded and the wall-clock
    value was read as local time. Timestamps without an offset are still
    interpreted in local time. Unparseable input yields 0.0.
    """
    import calendar  # local import: only this helper needs it

    if not timestamp:
        return 0.0
    try:
        fields = time.strptime(timestamp[:19], "%Y-%m-%dT%H:%M:%S")
    except ValueError:
        return 0.0
    offset = _utc_offset_seconds(timestamp[19:])
    try:
        if offset is None:
            epoch = time.mktime(fields)
        else:
            # timegm reads the fields as UTC; subtracting the offset gives the real instant.
            epoch = calendar.timegm(fields) - offset
    except (OverflowError, ValueError):
        # mktime can reject dates outside the platform's time_t range.
        return 0.0
    age_days = max(0.0, (time.time() - epoch) / 86400)
    return max(0.0, 8.0 - (age_days / 14.0))
600
-
601
-
602
def project_boost(tokens: list[str], row: dict[str, Any]) -> float:
    """Bonus of 4.0 per token found (as a substring) in the row's title or cwd, capped at 12.0."""
    title = row.get("title") or ""
    cwd = row.get("cwd") or ""
    haystack = f"{title} {cwd}".lower()
    hits = len([token for token in tokens if token in haystack])
    return min(12.0, hits * 4.0)
608
-
609
-
610
def ordered_span_boost(tokens: list[str], text: str) -> float:
    """Reward text where the tokens appear in query order and close together.

    Each token is greedily matched, left to right, to the next word that
    starts with it. If any token cannot be placed, or there are fewer than two
    tokens, the boost is 0. Otherwise *slack* is the number of extra words
    inside the matched span: slack <= 2 gives 18.0, <= 6 gives 12.0,
    <= 12 gives 6.0, and anything larger gives 0.0.
    """
    if len(tokens) < 2:
        return 0.0
    words = re.findall(r"[A-Za-z0-9_]+", text.lower())
    if not words:
        return 0.0

    positions: list[int] = []
    cursor = 0
    for token in tokens:
        hit = next((idx for idx in range(cursor, len(words)) if words[idx].startswith(token)), None)
        if hit is None:
            return 0.0
        positions.append(hit)
        cursor = hit + 1

    slack = (positions[-1] - positions[0] + 1) - len(tokens)
    for max_slack, boost in ((2, 18.0), (6, 12.0), (12, 6.0)):
        if slack <= max_slack:
            return boost
    return 0.0
638
-
639
-
640
def make_plain_snippet(text: str, tokens: list[str], *, radius: int = 90) -> str:
    """Return a whitespace-collapsed excerpt of *text* around the earliest token hit.

    The window spans *radius* characters either side of the first occurrence
    of any token, with "..." marking truncation on each side. When no token
    occurs, the first ``2 * radius`` characters are returned (no ellipsis).
    """
    compact = re.sub(r"\s+", " ", text).strip()
    if not compact:
        return ""
    lowered = compact.lower()
    hits = []
    for token in tokens:
        at = lowered.find(token)
        if at >= 0:
            hits.append(at)
    if not hits:
        return compact[: radius * 2]
    anchor = min(hits)
    start = max(0, anchor - radius)
    end = min(len(compact), anchor + radius)
    lead = "..." if start > 0 else ""
    tail = "..." if end < len(compact) else ""
    return f"{lead}{compact[start:end]}{tail}"