code_memory-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
queries.py ADDED
@@ -0,0 +1,446 @@
+ """
+ Query layer for code-memory.
+
+ Provides hybrid retrieval (BM25 + dense vector) with Reciprocal Rank Fusion,
+ plus specialised query functions for definitions, references, and file
+ structure.
+ """
+
+ from __future__ import annotations
+
+ import struct
+ from typing import Any
+
+ import db as db_mod
+
+
+ # ---------------------------------------------------------------------------
+ # Hybrid search (BM25 + vector → RRF)
+ # ---------------------------------------------------------------------------
+
+ _RRF_K = 60  # standard RRF constant
+
+
+ def _bm25_search(query: str, db, top_k: int = 50) -> list[dict]:
+     """Run FTS5 BM25 search against ``symbols_fts``.
+
+     Returns a ranked list of dicts with ``symbol_id`` and ``bm25_score``.
+     """
+     # FTS5 MATCH query — escape double-quotes in user input
+     safe_query = query.replace('"', '""')
+     try:
+         rows = db.execute(
+             """
+             SELECT s.id, s.name, s.kind, f.path, s.line_start, s.line_end,
+                    s.source_text, bm25(symbols_fts) AS score
+             FROM symbols_fts
+             JOIN symbols s ON s.id = symbols_fts.rowid
+             JOIN files f ON f.id = s.file_id
+             WHERE symbols_fts MATCH ?
+             ORDER BY score  -- bm25() returns negative; lower = better
+             LIMIT ?
+             """,
+             (safe_query, top_k),
+         ).fetchall()
+     except Exception:
+         # FTS MATCH can fail on certain queries (e.g. operators only)
+         return []
+
+     return [
+         {
+             "symbol_id": r[0],
+             "name": r[1],
+             "kind": r[2],
+             "file_path": r[3],
+             "line_start": r[4],
+             "line_end": r[5],
+             "source_text": r[6],
+             "bm25_score": r[7],
+         }
+         for r in rows
+     ]
+
+
+ def _vector_search(query: str, db, top_k: int = 50) -> list[dict]:
+     """Run dense vector nearest-neighbour search via ``sqlite-vec``.
+
+     Returns a ranked list of dicts with ``symbol_id`` and ``vec_distance``.
+     """
+     query_vec = db_mod.embed_text(query)
+     query_blob = struct.pack(f"{len(query_vec)}f", *query_vec)
+
+     rows = db.execute(
+         """
+         SELECT se.symbol_id, se.distance,
+                s.name, s.kind, f.path, s.line_start, s.line_end, s.source_text
+         FROM symbol_embeddings se
+         JOIN symbols s ON s.id = se.symbol_id
+         JOIN files f ON f.id = s.file_id
+         WHERE se.embedding MATCH ?
+           AND se.k = ?
+         ORDER BY se.distance
+         """,
+         (query_blob, top_k),
+     ).fetchall()
+
+     return [
+         {
+             "symbol_id": r[0],
+             "vec_distance": r[1],
+             "name": r[2],
+             "kind": r[3],
+             "file_path": r[4],
+             "line_start": r[5],
+             "line_end": r[6],
+             "source_text": r[7],
+         }
+         for r in rows
+     ]
+
+
+ def hybrid_search(query: str, db, top_k: int = 10) -> list[dict]:
+     """Hybrid BM25 + vector search with Reciprocal Rank Fusion.
+
+     Runs both retrieval legs independently, then merges their ranked lists
+     using RRF: ``rrf_score(d) = Σ 1 / (k + rank(d))`` where ``k = 60``.
+
+     Args:
+         query: Free-text search query.
+         db: An open ``sqlite3.Connection`` from ``db.get_db()``.
+         top_k: Number of results to return.
+
+     Returns:
+         A list of result dicts sorted by descending RRF score.
+     """
+     bm25_results = _bm25_search(query, db, top_k=50)
+     vec_results = _vector_search(query, db, top_k=50)
+
+     # Build RRF score map keyed by symbol_id
+     scores: dict[int, float] = {}
+     details: dict[int, dict] = {}
+
+     for rank, r in enumerate(bm25_results, start=1):
+         sid = r["symbol_id"]
+         scores[sid] = scores.get(sid, 0.0) + 1.0 / (_RRF_K + rank)
+         details[sid] = {
+             "name": r["name"],
+             "kind": r["kind"],
+             "file_path": r["file_path"],
+             "line_start": r["line_start"],
+             "line_end": r["line_end"],
+             "source_text": r["source_text"],
+         }
+
+     for rank, r in enumerate(vec_results, start=1):
+         sid = r["symbol_id"]
+         scores[sid] = scores.get(sid, 0.0) + 1.0 / (_RRF_K + rank)
+         if sid not in details:
+             details[sid] = {
+                 "name": r["name"],
+                 "kind": r["kind"],
+                 "file_path": r["file_path"],
+                 "line_start": r["line_start"],
+                 "line_end": r["line_end"],
+                 "source_text": r["source_text"],
+             }
+
+     # Sort by descending RRF score
+     ranked = sorted(scores.items(), key=lambda kv: kv[1], reverse=True)[:top_k]
+
+     return [
+         {**details[sid], "score": round(score, 6)}
+         for sid, score in ranked
+     ]
+
+
+ # ---------------------------------------------------------------------------
+ # Tool-facing query functions
+ # ---------------------------------------------------------------------------
+
+
+ def find_definition(symbol_name: str, db) -> list[dict]:
+     """Find where *symbol_name* is defined using hybrid search.
+
+     Post-filters for exact name matches first; falls back to top hybrid
+     results as "best guesses" if no exact match is found.
+
+     Args:
+         symbol_name: The name of the symbol to find.
+         db: An open ``sqlite3.Connection``.
+
+     Returns:
+         A list of result dicts.
+     """
+     results = hybrid_search(symbol_name, db, top_k=20)
+
+     # Exact-match filter (case-sensitive)
+     exact = [r for r in results if r["name"] == symbol_name]
+     if exact:
+         return exact
+
+     # Fallback: return top results as best guesses
+     return results[:5]
+
+
+ def find_references(symbol_name: str, db) -> list[dict]:
+     """Find all cross-references to *symbol_name*.
+
+     Queries the ``references_`` table for exact matches.
+
+     Args:
+         symbol_name: The name of the symbol to find references for.
+         db: An open ``sqlite3.Connection``.
+
+     Returns:
+         A list of dicts with ``symbol_name``, ``file_path``, ``line_number``.
+     """
+     rows = db.execute(
+         """
+         SELECT r.symbol_name, f.path, r.line_number
+         FROM references_ r
+         JOIN files f ON f.id = r.file_id
+         WHERE r.symbol_name = ?
+         ORDER BY f.path, r.line_number
+         """,
+         (symbol_name,),
+     ).fetchall()
+
+     return [
+         {"symbol_name": r[0], "file_path": r[1], "line_number": r[2]}
+         for r in rows
+     ]
+
+
+ def get_file_structure(file_path: str, db) -> list[dict]:
+     """List all symbols in a given file, ordered by line number.
+
+     Args:
+         file_path: Absolute (or matching) path to the file.
+         db: An open ``sqlite3.Connection``.
+
+     Returns:
+         A list of dicts with ``name``, ``kind``, ``line_start``, ``line_end``,
+         ``parent``.
+     """
+     import os
+
+     abs_path = os.path.abspath(file_path)
+
+     rows = db.execute(
+         """
+         SELECT s.name, s.kind, s.line_start, s.line_end,
+                p.name AS parent_name
+         FROM symbols s
+         JOIN files f ON f.id = s.file_id
+         LEFT JOIN symbols p ON p.id = s.parent_symbol_id
+         WHERE f.path = ?
+         ORDER BY s.line_start
+         """,
+         (abs_path,),
+     ).fetchall()
+
+     return [
+         {
+             "name": r[0],
+             "kind": r[1],
+             "line_start": r[2],
+             "line_end": r[3],
+             "parent": r[4],
+         }
+         for r in rows
+     ]
+
+
+ # ---------------------------------------------------------------------------
+ # Documentation search (Milestone 4)
+ # ---------------------------------------------------------------------------
+
+
+ def _doc_bm25_search(query: str, db, top_k: int = 50) -> list[dict]:
+     """Run FTS5 BM25 search against ``doc_chunks_fts``.
+
+     Returns a ranked list of dicts with chunk metadata and ``bm25_score``.
+     """
+     safe_query = query.replace('"', '""')
+     try:
+         rows = db.execute(
+             """
+             SELECT dc.id, dc.section_title, dc.content, df.path, df.doc_type,
+                    dc.line_start, dc.line_end, bm25(doc_chunks_fts) AS score
+             FROM doc_chunks_fts
+             JOIN doc_chunks dc ON dc.id = doc_chunks_fts.rowid
+             JOIN doc_files df ON df.id = dc.doc_file_id
+             WHERE doc_chunks_fts MATCH ?
+             ORDER BY score
+             LIMIT ?
+             """,
+             (safe_query, top_k),
+         ).fetchall()
+     except Exception:
+         return []
+
+     return [
+         {
+             "chunk_id": r[0],
+             "section_title": r[1],
+             "content": r[2],
+             "source_file": r[3],
+             "doc_type": r[4],
+             "line_start": r[5],
+             "line_end": r[6],
+             "bm25_score": r[7],
+         }
+         for r in rows
+     ]
+
+
+ def _doc_vector_search(query: str, db, top_k: int = 50) -> list[dict]:
+     """Run dense vector nearest-neighbour search on ``doc_embeddings``."""
+     query_vec = db_mod.embed_text(query)
+     query_blob = struct.pack(f"{len(query_vec)}f", *query_vec)
+
+     rows = db.execute(
+         """
+         SELECT de.chunk_id, de.distance,
+                dc.section_title, dc.content, df.path, df.doc_type,
+                dc.line_start, dc.line_end
+         FROM doc_embeddings de
+         JOIN doc_chunks dc ON dc.id = de.chunk_id
+         JOIN doc_files df ON df.id = dc.doc_file_id
+         WHERE de.embedding MATCH ?
+           AND de.k = ?
+         ORDER BY de.distance
+         """,
+         (query_blob, top_k),
+     ).fetchall()
+
+     return [
+         {
+             "chunk_id": r[0],
+             "vec_distance": r[1],
+             "section_title": r[2],
+             "content": r[3],
+             "source_file": r[4],
+             "doc_type": r[5],
+             "line_start": r[6],
+             "line_end": r[7],
+         }
+         for r in rows
+     ]
+
+
+ def search_documentation(query: str, db, top_k: int = 10,
+                          include_context: bool = False) -> list[dict]:
+     """Perform hybrid search over documentation chunks.
+
+     Uses BM25 + vector search with Reciprocal Rank Fusion.
+
+     Args:
+         query: Natural language query.
+         db: Database connection.
+         top_k: Maximum results to return.
+         include_context: If True, include adjacent chunks for context.
+
+     Returns:
+         List of matching chunks with source attribution and RRF scores.
+     """
+     bm25_results = _doc_bm25_search(query, db, top_k=50)
+     vec_results = _doc_vector_search(query, db, top_k=50)
+
+     # Build RRF score map keyed by chunk_id
+     scores: dict[int, float] = {}
+     details: dict[int, dict] = {}
+
+     for rank, r in enumerate(bm25_results, start=1):
+         cid = r["chunk_id"]
+         scores[cid] = scores.get(cid, 0.0) + 1.0 / (_RRF_K + rank)
+         details[cid] = {
+             "content": r["content"],
+             "source_file": r["source_file"],
+             "section_title": r["section_title"],
+             "line_start": r["line_start"],
+             "line_end": r["line_end"],
+             "doc_type": r["doc_type"],
+         }
+
+     for rank, r in enumerate(vec_results, start=1):
+         cid = r["chunk_id"]
+         scores[cid] = scores.get(cid, 0.0) + 1.0 / (_RRF_K + rank)
+         if cid not in details:
+             details[cid] = {
+                 "content": r["content"],
+                 "source_file": r["source_file"],
+                 "section_title": r["section_title"],
+                 "line_start": r["line_start"],
+                 "line_end": r["line_end"],
+                 "doc_type": r["doc_type"],
+             }
+
+     # Sort by descending RRF score
+     ranked = sorted(scores.items(), key=lambda kv: kv[1], reverse=True)[:top_k]
+
+     results = [
+         {**details[cid], "score": round(score, 6)}
+         for cid, score in ranked
+     ]
+
+     # Optionally include adjacent chunks for context
+     if include_context and results:
+         results = _add_context_chunks(results, db)
+
+     return results
+
+
+ def _add_context_chunks(results: list[dict], db) -> list[dict]:
+     """Add adjacent chunks to results for additional context."""
+     enriched = []
+
+     for result in results:
+         # Get the chunk's file and index
+         row = db.execute(
+             """
+             SELECT dc.chunk_index, dc.doc_file_id
+             FROM doc_chunks dc
+             JOIN doc_files df ON df.id = dc.doc_file_id
+             WHERE df.path = ? AND dc.line_start = ? AND dc.line_end = ?
+             """,
+             (result["source_file"], result["line_start"], result["line_end"]),
+         ).fetchone()
+
+         if not row:
+             enriched.append(result)
+             continue
+
+         chunk_index, doc_file_id = row
+
+         # Get previous and next chunks
+         context_parts = []
+
+         prev = db.execute(
+             """
+             SELECT content FROM doc_chunks
+             WHERE doc_file_id = ? AND chunk_index = ?
+             """,
+             (doc_file_id, chunk_index - 1),
+         ).fetchone()
+         if prev:
+             context_parts.append({"type": "previous", "content": prev[0][:200]})
+
+         context_parts.append({"type": "current", "content": result["content"]})
+
+         next_chunk = db.execute(
+             """
+             SELECT content FROM doc_chunks
+             WHERE doc_file_id = ? AND chunk_index = ?
+             """,
+             (doc_file_id, chunk_index + 1),
+         ).fetchone()
+         if next_chunk:
+             context_parts.append({"type": "next", "content": next_chunk[0][:200]})
+
+         enriched.append({
+             **result,
+             "context": context_parts,
+         })
+
+     return enriched
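
For orientation, the sketch below shows how these query functions might be called once an index exists. It is not part of the package contents above; the ``db.get_db()`` helper is taken from the ``hybrid_search`` docstring, while the ``queries`` import name and the example query strings are assumptions.

# Usage sketch (illustrative only; assumes the index has already been built).
import db as db_mod
import queries

conn = db_mod.get_db()  # open sqlite3.Connection, per the hybrid_search docstring

# Hybrid code search: BM25 and vector legs fused with RRF (k = 60).
for hit in queries.hybrid_search("parse configuration file", conn, top_k=5):
    print(hit["file_path"], hit["line_start"], hit["name"], hit["score"])

# Exact definition lookup, with top hybrid results as a fallback.
definitions = queries.find_definition("hybrid_search", conn)

# Documentation search, optionally pulling in neighbouring chunks.
docs = queries.search_documentation("how is the index built?", conn,
                                    top_k=3, include_context=True)

As a concrete instance of the fusion formula, with ``_RRF_K = 60`` a symbol ranked 1st by BM25 and 3rd by the vector leg receives 1/61 + 1/63 ≈ 0.0323, while a symbol that appears in only one leg at rank 1 receives 1/61 ≈ 0.0164.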