code-memory 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- .github/workflows/ci.yml +71 -0
- .github/workflows/publish.yml +33 -0
- .gitignore +40 -0
- .python-version +1 -0
- CHANGELOG.md +43 -0
- CONTRIBUTING.md +133 -0
- LICENSE +21 -0
- Makefile +33 -0
- PKG-INFO +275 -0
- README.md +233 -0
- code_memory-0.1.0.dist-info/METADATA +275 -0
- code_memory-0.1.0.dist-info/RECORD +37 -0
- code_memory-0.1.0.dist-info/WHEEL +4 -0
- code_memory-0.1.0.dist-info/entry_points.txt +2 -0
- code_memory-0.1.0.dist-info/licenses/LICENSE +21 -0
- db.py +403 -0
- doc_parser.py +494 -0
- errors.py +115 -0
- git_search.py +313 -0
- logging_config.py +191 -0
- parser.py +392 -0
- prompts/milestone_1.xml +62 -0
- prompts/milestone_2.xml +246 -0
- prompts/milestone_3.xml +214 -0
- prompts/milestone_4.xml +453 -0
- prompts/milestone_5.xml +599 -0
- pyproject.toml +92 -0
- queries.py +446 -0
- server.py +299 -0
- tests/__init__.py +1 -0
- tests/conftest.py +192 -0
- tests/test_errors.py +112 -0
- tests/test_logging.py +169 -0
- tests/test_tools.py +114 -0
- tests/test_validation.py +216 -0
- uv.lock +1921 -0
- validation.py +316 -0
queries.py
ADDED
|
@@ -0,0 +1,446 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Query layer for code-memory.
|
|
3
|
+
|
|
4
|
+
Provides hybrid retrieval (BM25 + dense vector) with Reciprocal Rank Fusion,
|
|
5
|
+
plus specialised query functions for definitions, references, and file
|
|
6
|
+
structure.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations

import sqlite3
import struct
from typing import Any

import db as db_mod
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
# ---------------------------------------------------------------------------
|
|
18
|
+
# Hybrid search (BM25 + vector → RRF)
|
|
19
|
+
# ---------------------------------------------------------------------------
|
|
20
|
+
|
|
21
|
+
_RRF_K = 60 # standard RRF constant
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def _bm25_search(query: str, db, top_k: int = 50) -> list[dict]:
|
|
25
|
+
"""Run FTS5 BM25 search against ``symbols_fts``.
|
|
26
|
+
|
|
27
|
+
Returns a ranked list of dicts with ``symbol_id`` and ``bm25_score``.
|
|
28
|
+
"""
|
|
29
|
+
# FTS5 MATCH query — escape double-quotes in user input
|
|
30
|
+
safe_query = query.replace('"', '""')
|
|
31
|
+
try:
|
|
32
|
+
rows = db.execute(
|
|
33
|
+
"""
|
|
34
|
+
SELECT s.id, s.name, s.kind, f.path, s.line_start, s.line_end,
|
|
35
|
+
s.source_text, bm25(symbols_fts) AS score
|
|
36
|
+
FROM symbols_fts
|
|
37
|
+
JOIN symbols s ON s.id = symbols_fts.rowid
|
|
38
|
+
JOIN files f ON f.id = s.file_id
|
|
39
|
+
WHERE symbols_fts MATCH ?
|
|
40
|
+
ORDER BY score -- bm25() returns negative; lower = better
|
|
41
|
+
LIMIT ?
|
|
42
|
+
""",
|
|
43
|
+
(safe_query, top_k),
|
|
44
|
+
).fetchall()
|
|
45
|
+
except Exception:
|
|
46
|
+
# FTS MATCH can fail on certain queries (e.g. operators only)
|
|
47
|
+
return []
|
|
48
|
+
|
|
49
|
+
return [
|
|
50
|
+
{
|
|
51
|
+
"symbol_id": r[0],
|
|
52
|
+
"name": r[1],
|
|
53
|
+
"kind": r[2],
|
|
54
|
+
"file_path": r[3],
|
|
55
|
+
"line_start": r[4],
|
|
56
|
+
"line_end": r[5],
|
|
57
|
+
"source_text": r[6],
|
|
58
|
+
"bm25_score": r[7],
|
|
59
|
+
}
|
|
60
|
+
for r in rows
|
|
61
|
+
]
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def _vector_search(query: str, db, top_k: int = 50) -> list[dict]:
    """Nearest-neighbour search over ``symbol_embeddings`` via sqlite-vec.

    Embeds *query*, packs the vector into the float32 blob format that
    sqlite-vec expects, and returns the ``top_k`` closest symbols as
    dicts carrying ``symbol_id`` and ``vec_distance`` plus metadata.
    """
    embedding = db_mod.embed_text(query)
    blob = struct.pack(f"{len(embedding)}f", *embedding)

    cursor = db.execute(
        """
        SELECT se.symbol_id, se.distance,
               s.name, s.kind, f.path, s.line_start, s.line_end, s.source_text
        FROM symbol_embeddings se
        JOIN symbols s ON s.id = se.symbol_id
        JOIN files f ON f.id = s.file_id
        WHERE se.embedding MATCH ?
          AND se.k = ?
        ORDER BY se.distance
        """,
        (blob, top_k),
    )

    fields = (
        "symbol_id", "vec_distance", "name", "kind",
        "file_path", "line_start", "line_end", "source_text",
    )
    return [dict(zip(fields, row)) for row in cursor.fetchall()]
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def hybrid_search(query: str, db, top_k: int = 10) -> list[dict]:
    """Hybrid BM25 + vector search with Reciprocal Rank Fusion.

    Runs both retrieval legs independently, then merges their ranked lists
    using RRF: ``rrf_score(d) = Σ 1 / (k + rank(d))`` where ``k = 60``.

    Args:
        query: Free-text search query.
        db: An open ``sqlite3.Connection`` from ``db.get_db()``.
        top_k: Number of results to return.

    Returns:
        A list of result dicts sorted by descending RRF score.
    """

    def _details(r: dict) -> dict:
        # Project a per-leg result down to the fields callers consume.
        return {
            "name": r["name"],
            "kind": r["kind"],
            "file_path": r["file_path"],
            "line_start": r["line_start"],
            "line_end": r["line_end"],
            "source_text": r["source_text"],
        }

    bm25_results = _bm25_search(query, db, top_k=50)
    vec_results = _vector_search(query, db, top_k=50)

    # Build RRF score map keyed by symbol_id. Both legs describe the same
    # underlying symbol rows, so the first-seen details suffice.
    scores: dict[int, float] = {}
    details: dict[int, dict] = {}

    for leg in (bm25_results, vec_results):
        for rank, r in enumerate(leg, start=1):
            sid = r["symbol_id"]
            scores[sid] = scores.get(sid, 0.0) + 1.0 / (_RRF_K + rank)
            if sid not in details:
                details[sid] = _details(r)

    # Sort by descending RRF score and keep the top_k.
    ranked = sorted(scores.items(), key=lambda kv: kv[1], reverse=True)[:top_k]

    return [
        {**details[sid], "score": round(score, 6)}
        for sid, score in ranked
    ]
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
# ---------------------------------------------------------------------------
|
|
157
|
+
# Tool-facing query functions
|
|
158
|
+
# ---------------------------------------------------------------------------
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
def find_definition(symbol_name: str, db) -> list[dict]:
    """Locate the definition(s) of *symbol_name* via hybrid search.

    Results whose name matches exactly (case-sensitive) are preferred;
    when none match, the top hybrid hits are returned as best guesses.

    Args:
        symbol_name: The name of the symbol to find.
        db: An open ``sqlite3.Connection``.

    Returns:
        A list of result dicts.
    """
    candidates = hybrid_search(symbol_name, db, top_k=20)

    # Case-sensitive exact-name filter first; otherwise best guesses.
    exact_hits = [c for c in candidates if c["name"] == symbol_name]
    return exact_hits if exact_hits else candidates[:5]
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
def find_references(symbol_name: str, db) -> list[dict]:
    """Return every recorded cross-reference to *symbol_name*.

    Performs an exact-match lookup in the ``references_`` table, joined
    to ``files`` for path attribution, ordered by path then line.

    Args:
        symbol_name: The name of the symbol to find references for.
        db: An open ``sqlite3.Connection``.

    Returns:
        A list of dicts with ``symbol_name``, ``file_path``, ``line_number``.
    """
    cursor = db.execute(
        """
        SELECT r.symbol_name, f.path, r.line_number
        FROM references_ r
        JOIN files f ON f.id = r.file_id
        WHERE r.symbol_name = ?
        ORDER BY f.path, r.line_number
        """,
        (symbol_name,),
    )

    fields = ("symbol_name", "file_path", "line_number")
    return [dict(zip(fields, row)) for row in cursor.fetchall()]
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
def get_file_structure(file_path: str, db) -> list[dict]:
    """Return every symbol recorded for *file_path*, in line order.

    The path is normalised with ``os.path.abspath`` before the lookup.

    Args:
        file_path: Absolute (or matching) path to the file.
        db: An open ``sqlite3.Connection``.

    Returns:
        A list of dicts with ``name``, ``kind``, ``line_start``,
        ``line_end`` and ``parent`` (the enclosing symbol's name, or
        ``None`` for top-level symbols).
    """
    import os

    resolved = os.path.abspath(file_path)

    cursor = db.execute(
        """
        SELECT s.name, s.kind, s.line_start, s.line_end,
               p.name AS parent_name
        FROM symbols s
        JOIN files f ON f.id = s.file_id
        LEFT JOIN symbols p ON p.id = s.parent_symbol_id
        WHERE f.path = ?
        ORDER BY s.line_start
        """,
        (resolved,),
    )

    fields = ("name", "kind", "line_start", "line_end", "parent")
    return [dict(zip(fields, row)) for row in cursor.fetchall()]
|
|
252
|
+
|
|
253
|
+
|
|
254
|
+
# ---------------------------------------------------------------------------
|
|
255
|
+
# Documentation search (Milestone 4)
|
|
256
|
+
# ---------------------------------------------------------------------------
|
|
257
|
+
|
|
258
|
+
|
|
259
|
+
def _doc_bm25_search(query: str, db, top_k: int = 50) -> list[dict]:
    """Run FTS5 BM25 search against ``doc_chunks_fts``.

    Args:
        query: Free-text search query; double quotes are escaped before
            being handed to the FTS5 ``MATCH`` operator.
        db: An open ``sqlite3.Connection``.
        top_k: Maximum number of rows to return.

    Returns:
        A ranked list of dicts with chunk metadata and ``bm25_score``.
        Returns ``[]`` when *query* is not valid FTS5 syntax.
    """
    safe_query = query.replace('"', '""')
    try:
        rows = db.execute(
            """
            SELECT dc.id, dc.section_title, dc.content, df.path, df.doc_type,
                   dc.line_start, dc.line_end, bm25(doc_chunks_fts) AS score
            FROM doc_chunks_fts
            JOIN doc_chunks dc ON dc.id = doc_chunks_fts.rowid
            JOIN doc_files df ON df.id = dc.doc_file_id
            WHERE doc_chunks_fts MATCH ?
            ORDER BY score
            LIMIT ?
            """,
            (safe_query, top_k),
        ).fetchall()
    except sqlite3.OperationalError:
        # Invalid FTS5 MATCH syntax (e.g. operators only). Catch only
        # sqlite's own error so genuine bugs still propagate.
        return []

    return [
        {
            "chunk_id": r[0],
            "section_title": r[1],
            "content": r[2],
            "source_file": r[3],
            "doc_type": r[4],
            "line_start": r[5],
            "line_end": r[6],
            "bm25_score": r[7],
        }
        for r in rows
    ]
|
|
295
|
+
|
|
296
|
+
|
|
297
|
+
def _doc_vector_search(query: str, db, top_k: int = 50) -> list[dict]:
    """Nearest-neighbour search over ``doc_embeddings`` via sqlite-vec."""
    embedding = db_mod.embed_text(query)
    blob = struct.pack(f"{len(embedding)}f", *embedding)

    cursor = db.execute(
        """
        SELECT de.chunk_id, de.distance,
               dc.section_title, dc.content, df.path, df.doc_type,
               dc.line_start, dc.line_end
        FROM doc_embeddings de
        JOIN doc_chunks dc ON dc.id = de.chunk_id
        JOIN doc_files df ON df.id = dc.doc_file_id
        WHERE de.embedding MATCH ?
          AND de.k = ?
        ORDER BY de.distance
        """,
        (blob, top_k),
    )

    fields = (
        "chunk_id", "vec_distance", "section_title", "content",
        "source_file", "doc_type", "line_start", "line_end",
    )
    return [dict(zip(fields, row)) for row in cursor.fetchall()]
|
|
330
|
+
|
|
331
|
+
|
|
332
|
+
def search_documentation(query: str, db, top_k: int = 10,
                         include_context: bool = False) -> list[dict]:
    """Perform hybrid search over documentation chunks.

    Uses BM25 + vector search with Reciprocal Rank Fusion.

    Args:
        query: Natural language query.
        db: Database connection.
        top_k: Maximum results to return.
        include_context: If True, include adjacent chunks for context.

    Returns:
        List of matching chunks with source attribution and RRF scores.
    """

    def _details(r: dict) -> dict:
        # Project a per-leg hit down to the fields callers consume.
        return {
            "content": r["content"],
            "source_file": r["source_file"],
            "section_title": r["section_title"],
            "line_start": r["line_start"],
            "line_end": r["line_end"],
            "doc_type": r["doc_type"],
        }

    bm25_results = _doc_bm25_search(query, db, top_k=50)
    vec_results = _doc_vector_search(query, db, top_k=50)

    # Build RRF score map keyed by chunk_id. Both legs describe the same
    # underlying chunk rows, so the first-seen details suffice.
    scores: dict[int, float] = {}
    details: dict[int, dict] = {}

    for leg in (bm25_results, vec_results):
        for rank, r in enumerate(leg, start=1):
            cid = r["chunk_id"]
            scores[cid] = scores.get(cid, 0.0) + 1.0 / (_RRF_K + rank)
            if cid not in details:
                details[cid] = _details(r)

    # Sort by descending RRF score and keep the top_k.
    ranked = sorted(scores.items(), key=lambda kv: kv[1], reverse=True)[:top_k]

    results = [
        {**details[cid], "score": round(score, 6)}
        for cid, score in ranked
    ]

    # Optionally enrich each hit with its neighbouring chunks.
    if include_context and results:
        results = _add_context_chunks(results, db)

    return results
|
|
392
|
+
|
|
393
|
+
|
|
394
|
+
def _add_context_chunks(results: list[dict], db) -> list[dict]:
    """Attach truncated neighbouring chunks to each result for context.

    Each result that can be located in ``doc_chunks`` gains a
    ``context`` list of previous/current/next entries; results that
    cannot be located are passed through unchanged.
    """

    def _neighbour_content(doc_file_id: int, index: int):
        # Fetch the content of the chunk at `index` in the same file,
        # or None when it does not exist.
        row = db.execute(
            """
            SELECT content FROM doc_chunks
            WHERE doc_file_id = ? AND chunk_index = ?
            """,
            (doc_file_id, index),
        ).fetchone()
        return row[0] if row else None

    enriched = []

    for hit in results:
        # Locate the chunk by its source file and line span.
        locator = db.execute(
            """
            SELECT dc.chunk_index, dc.doc_file_id
            FROM doc_chunks dc
            JOIN doc_files df ON df.id = dc.doc_file_id
            WHERE df.path = ? AND dc.line_start = ? AND dc.line_end = ?
            """,
            (hit["source_file"], hit["line_start"], hit["line_end"]),
        ).fetchone()

        if not locator:
            enriched.append(hit)
            continue

        chunk_index, doc_file_id = locator

        context_parts = []
        before = _neighbour_content(doc_file_id, chunk_index - 1)
        if before is not None:
            context_parts.append({"type": "previous", "content": before[:200]})

        context_parts.append({"type": "current", "content": hit["content"]})

        after = _neighbour_content(doc_file_id, chunk_index + 1)
        if after is not None:
            context_parts.append({"type": "next", "content": after[:200]})

        enriched.append({**hit, "context": context_parts})

    return enriched
|