codemesh 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- codemesh/__init__.py +5 -0
- codemesh/__main__.py +8 -0
- codemesh/cli/__init__.py +3 -0
- codemesh/cli/init.py +208 -0
- codemesh/cli/install_cmd.py +208 -0
- codemesh/cli/main.py +469 -0
- codemesh/context/__init__.py +3 -0
- codemesh/context/builder.py +388 -0
- codemesh/db/__init__.py +3 -0
- codemesh/db/connection.py +66 -0
- codemesh/db/queries.py +696 -0
- codemesh/db/schema.py +125 -0
- codemesh/embedding/__init__.py +3 -0
- codemesh/extraction/__init__.py +7 -0
- codemesh/extraction/languages/__init__.py +95 -0
- codemesh/extraction/languages/c_family.py +614 -0
- codemesh/extraction/languages/go.py +397 -0
- codemesh/extraction/languages/java.py +603 -0
- codemesh/extraction/languages/python.py +718 -0
- codemesh/extraction/languages/rust.py +435 -0
- codemesh/extraction/languages/swift.py +464 -0
- codemesh/extraction/languages/typescript.py +1222 -0
- codemesh/extraction/orchestrator.py +218 -0
- codemesh/graph/__init__.py +8 -0
- codemesh/graph/query_manager.py +117 -0
- codemesh/graph/traverser.py +107 -0
- codemesh/indexer.py +240 -0
- codemesh/mcp/__init__.py +3 -0
- codemesh/mcp/server.py +60 -0
- codemesh/mcp/tools.py +605 -0
- codemesh/querier.py +269 -0
- codemesh/resolution/__init__.py +7 -0
- codemesh/resolution/frameworks/__init__.py +15 -0
- codemesh/resolution/frameworks/django.py +30 -0
- codemesh/resolution/frameworks/fastapi.py +23 -0
- codemesh/resolution/import_resolver.py +69 -0
- codemesh/resolution/name_matcher.py +30 -0
- codemesh/resolution/resolver.py +268 -0
- codemesh/retrieval/__init__.py +7 -0
- codemesh/search/__init__.py +3 -0
- codemesh/sync/__init__.py +3 -0
- codemesh/sync/watcher.py +135 -0
- codemesh/types.py +148 -0
- codemesh/viz/__init__.py +0 -0
- codemesh/viz/graph_builder.py +162 -0
- codemesh/viz/server.py +122 -0
- codemesh/viz/templates/index.html +359 -0
- codemesh-0.1.1.dist-info/METADATA +337 -0
- codemesh-0.1.1.dist-info/RECORD +52 -0
- codemesh-0.1.1.dist-info/WHEEL +4 -0
- codemesh-0.1.1.dist-info/entry_points.txt +2 -0
- codemesh-0.1.1.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,388 @@
|
|
|
1
|
+
"""Token-budget-aware context builder with deduplication."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import sqlite3
|
|
6
|
+
from dataclasses import dataclass
|
|
7
|
+
from enum import Enum
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
from xml.sax.saxutils import escape as xml_escape
|
|
10
|
+
|
|
11
|
+
from codemesh.types import Node
|
|
12
|
+
|
|
13
|
+
# Node kinds that provide high information value
|
|
14
|
+
_HIGH_VALUE_KINDS = {
|
|
15
|
+
"function",
|
|
16
|
+
"method",
|
|
17
|
+
"class",
|
|
18
|
+
"interface",
|
|
19
|
+
"type_alias",
|
|
20
|
+
"struct",
|
|
21
|
+
"trait",
|
|
22
|
+
"component",
|
|
23
|
+
"route",
|
|
24
|
+
"variable",
|
|
25
|
+
"constant",
|
|
26
|
+
"enum",
|
|
27
|
+
"module",
|
|
28
|
+
"namespace",
|
|
29
|
+
}
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class ContextFormat(Enum):
    """Output layouts the context builder can render."""

    XML = "xml"
    MARKDOWN = "markdown"
    # Entry Points + Related Symbols + Code
    STRUCTURED = "structured"
    # Graph-linearized: entry points ranked + call chains
    GRAPH = "graph"
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
@dataclass
class ContextOptions:
    """Limits and output settings consumed by ``ContextBuilder.build``.

    ``max_snippet_chars`` used to be declared twice (600, then 800). The later
    declaration silently won, so the single declaration below keeps the
    effective default of 800 and its original positional slot.
    """

    # Budget for the summed token estimate of all selected snippets.
    max_tokens: int = 1200
    # Maximum number of snippets in the output.
    max_snippets: int = 3
    # Maximum number of source lines per snippet.
    max_lines_per_snippet: int = 10
    # Extra lines taken before and after each node.
    context_margin: int = 0
    # Maximum number of snippets selected from any single file.
    max_per_file: int = 1
    # Per-snippet character cap; whole lines are dropped from the end to fit.
    max_snippet_chars: int = 800
    # Not read by the builder code in this module.
    include_graph_summary: bool = False
    # Output layout.
    format: ContextFormat = ContextFormat.XML
    # Skip nodes whose kind is not listed in ``_HIGH_VALUE_KINDS``.
    filter_low_value: bool = True
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
@dataclass
class Snippet:
    """A slice of source text cut from one file, together with its ranking data."""

    file_path: Path
    start_line: int
    end_line: int
    code: str
    relevance_score: float
    node_name: str = ""
    # "bm25" or "graph:calls" or "graph:contains" or "graph:references"
    source: str = "bm25"
    # the edge kind that connected this node (for graph nodes)
    edge_kind: str = ""
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def estimate_tokens(text: str) -> int:
    """Approximate the token count of *text* as one token per four characters.

    The result is never lower than 1, even for an empty string.
    """
    quarter = len(text) // 4
    return quarter if quarter > 1 else 1
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def _line_overlap(a_start: int, a_end: int, b_start: int, b_end: int) -> float:
|
|
70
|
+
"""Compute fractional overlap between two line ranges."""
|
|
71
|
+
overlap_start = max(a_start, b_start)
|
|
72
|
+
overlap_end = min(a_end, b_end)
|
|
73
|
+
if overlap_start >= overlap_end:
|
|
74
|
+
return 0.0
|
|
75
|
+
overlap_len = overlap_end - overlap_start
|
|
76
|
+
min_len = min(a_end - a_start, b_end - b_start)
|
|
77
|
+
return overlap_len / min_len if min_len > 0 else 0.0
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
class ContextBuilder:
    """Builds token-budget-aware context for LLM agents.

    Snippets are cut from the files behind the supplied nodes, overlapping
    snippets from the same file are dropped, and the survivors are rendered
    in one of the ``ContextFormat`` layouts. Node kinds outside
    ``_HIGH_VALUE_KINDS`` are skipped while ``filter_low_value`` is enabled.
    """

    # Overlap fraction at or above which two same-file snippets are duplicates.
    OVERLAP_THRESHOLD = 0.6

    def __init__(self, conn: sqlite3.Connection, root: Path) -> None:
        """Keep the index connection and the root used for relative node paths."""
        self.conn = conn
        self.root = root

    def build(
        self,
        nodes: list[tuple[Node, float]],
        query: str,
        options: ContextOptions | None = None,
        entry_points: list[tuple[Node, float]] | None = None,
        related: list[tuple[Node, float]] | None = None,
    ) -> str:
        """Render context text for *query* from scored nodes.

        Args:
            nodes: ``(node, score)`` pairs, processed in the order given.
            query: Query text echoed into the output header.
            options: Limits and layout; defaults to ``ContextOptions()``.
            entry_points: Scored nodes listed by the STRUCTURED and GRAPH layouts.
            related: Scored nodes listed by the STRUCTURED layout.

        Returns:
            The formatted context string.
        """
        if options is None:
            options = ContextOptions()

        snippets = self._extract_snippets(nodes, options)
        deduped = self._deduplicate(snippets)
        total_tokens = 0
        selected: list[Snippet] = []
        file_counts: dict[Path, int] = {}

        for snippet in deduped:
            tokens = estimate_tokens(snippet.code)
            # Selection stops at the first snippet that would exceed the token
            # budget, or as soon as the snippet limit has been reached.
            if total_tokens + tokens > options.max_tokens or len(selected) >= options.max_snippets:
                break
            fc = file_counts.get(snippet.file_path, 0)
            if fc >= options.max_per_file:
                continue
            selected.append(snippet)
            total_tokens += tokens
            file_counts[snippet.file_path] = fc + 1

        if options.format == ContextFormat.XML:
            return self._format_xml(selected, query)
        if options.format == ContextFormat.STRUCTURED:
            return self._format_structured(selected, query, entry_points, related)
        if options.format == ContextFormat.GRAPH:
            return self._format_graph(selected, query, entry_points, related)
        return self._format_markdown(selected, query)

    def _extract_snippets(
        self, nodes: list[tuple[Node, float]], options: ContextOptions
    ) -> list[Snippet]:
        """Cut one snippet per node out of the node's source file.

        Relative node paths are resolved against ``self.root``. Nodes whose
        file is missing, whose line range selects nothing, or whose processing
        raises are skipped. The reported ``end_line`` always matches the lines
        actually kept in ``code``.
        """
        snippets: list[Snippet] = []
        # Several nodes commonly share a file; read each file once per call.
        line_cache: dict[Path, list[str] | None] = {}
        for node, score in nodes:
            # Filter low-value node kinds
            if options.filter_low_value and node.kind.value not in _HIGH_VALUE_KINDS:
                continue
            try:
                file_path = (
                    self.root / node.file_path
                    if not node.file_path.is_absolute()
                    else node.file_path
                )
                if file_path not in line_cache:
                    line_cache[file_path] = (
                        file_path.read_text(encoding="utf-8", errors="replace").splitlines()
                        if file_path.exists()
                        else None
                    )
                lines = line_cache[file_path]
                if lines is None:
                    continue
                start = max(0, node.start_line - 1 - options.context_margin)
                end = min(len(lines), node.end_line + options.context_margin)
                # Cap snippet size tightly
                if end - start > options.max_lines_per_snippet:
                    node_len = node.end_line - node.start_line + 1
                    # Clamp at zero: for a node longer than the cap, a negative
                    # value would move the window past the node's first line
                    # and drop its signature.
                    extra = max(0, options.max_lines_per_snippet - node_len)
                    start = max(0, node.start_line - 1 - extra // 2)
                    end = min(len(lines), start + options.max_lines_per_snippet)
                code_lines = lines[start:end]
                if not code_lines:
                    # The recorded range lies outside the file as it is now.
                    continue
                code = "\n".join(code_lines)
                # Cap individual snippet char size
                if len(code) > options.max_snippet_chars:
                    kept_lines: list[str] = []
                    char_count = 0
                    for line in code_lines:
                        if char_count + len(line) + 1 > options.max_snippet_chars:
                            break
                        kept_lines.append(line)
                        char_count += len(line) + 1
                    if not kept_lines:
                        # Not even the first line fits: keep a cut-down first
                        # line instead of emitting an empty snippet.
                        kept_lines = [code_lines[0][: options.max_snippet_chars]]
                    code_lines = kept_lines
                    code = "\n".join(code_lines)
                snippets.append(
                    Snippet(
                        file_path=node.file_path,
                        start_line=start + 1,
                        end_line=start + len(code_lines),
                        code=code,
                        relevance_score=score,
                        node_name=node.name,
                    )
                )
            except Exception:
                # Deliberate best effort: one unreadable file or malformed node
                # must not abort context building for the remaining nodes.
                continue
        return snippets

    def _deduplicate(self, snippets: list[Snippet]) -> list[Snippet]:
        """Drop snippets that overlap an earlier snippet from the same file.

        The first snippet encountered is the one kept; scores are not
        compared, so the best snippet survives only if the input is already
        ordered by relevance.
        """
        kept: list[Snippet] = []
        for snippet in snippets:
            is_dup = any(
                existing.file_path == snippet.file_path
                and _line_overlap(
                    existing.start_line,
                    existing.end_line,
                    snippet.start_line,
                    snippet.end_line,
                )
                >= self.OVERLAP_THRESHOLD
                for existing in kept
            )
            if not is_dup:
                kept.append(snippet)
        return kept

    def _format_xml(self, snippets: list[Snippet], query: str) -> str:
        """Render snippets as a ``<code_context>`` XML document.

        Attribute values additionally escape double quotes, and ``--`` inside
        a node name is broken up because XML comments may not contain it.
        """
        quote = {'"': "&quot;"}
        lines = [f'<code_context query="{xml_escape(query, quote)}">']
        for s in snippets:
            rel = xml_escape(f"{s.relevance_score:.2f}")
            source_attr = xml_escape(s.source, quote)
            lines.append(
                f' <snippet file="{xml_escape(str(s.file_path), quote)}" '
                f'lines="{s.start_line}-{s.end_line}" relevance="{rel}" source="{source_attr}">'
            )
            if s.node_name:
                label = xml_escape(s.node_name)
                while "--" in label:
                    label = label.replace("--", "- -")
                lines.append(f" <!-- {label} -->")
            lines.append(f" {xml_escape(s.code)}")
            lines.append(" </snippet>")
        lines.append("</code_context>")
        return "\n".join(lines)

    def _format_markdown(self, snippets: list[Snippet], query: str) -> str:
        """Render snippets as Markdown: one heading and fenced block per snippet."""
        lines = [f"## Code Context: {query}", ""]
        for s in snippets:
            header = f"### {s.file_path}:{s.start_line}-{s.end_line}"
            if s.node_name:
                header += f" ({s.node_name})"
            lines.append(header)
            lines.append("```")
            lines.append(s.code)
            lines.append("```")
            lines.append("")
        return "\n".join(lines)

    @staticmethod
    def _append_code_section(lines: list[str], snippets: list[Snippet]) -> None:
        """Append the shared "Code" section to *lines*; no-op without snippets."""
        if not snippets:
            return
        lines.append("### Code")
        lines.append("")
        for s in snippets:
            header = f"#### {s.file_path}:{s.start_line}-{s.end_line}"
            if s.node_name:
                header += f" ({s.node_name})"
            if s.source and s.source != "bm25":
                header += f" [{s.source}]"
            lines.append(header)
            lines.append("```")
            lines.append(s.code)
            lines.append("```")
            lines.append("")

    def _format_structured(
        self,
        snippets: list[Snippet],
        query: str,
        entry_points: list[tuple[Node, float]] | None = None,
        related: list[tuple[Node, float]] | None = None,
    ) -> str:
        """Structured output with entry points, related symbols, and code.

        Three sections, each emitted only when it has content:
        - Entry Points: up to 5 symbols with signature and docstring excerpt
        - Related Symbols: up to 15 symbols, de-duplicated by file and name
        - Code: the selected snippets with file:line references
        """
        lines = ["## Code Context", "", f"**Query:** {query}", ""]

        # Entry Points section
        if entry_points:
            lines.append("### Entry Points")
            lines.append("")
            for node, _score in entry_points[:5]:
                sig = node.signature or ""
                vis = f" ({node.visibility})" if node.visibility != "public" else ""
                async_tag = " (async)" if node.is_async else ""
                exported_tag = " (exported)" if node.is_exported else ""
                lines.append(
                    f"- **{node.name}** ({node.kind.value}){vis}{async_tag}{exported_tag}"
                    f" - {node.file_path}:{node.start_line}"
                )
                if sig:
                    lines.append(f" `{sig[:120]}`")
                if node.docstring:
                    lines.append(f" {node.docstring[:100]}")
            lines.append("")

        # Related Symbols section
        if related:
            lines.append("### Related Symbols")
            lines.append("")
            seen: set[str] = set()
            for node, _score in related[:15]:
                key = f"{node.file_path}:{node.name}"
                if key in seen:
                    continue
                seen.add(key)
                lines.append(f"- {node.file_path}: {node.name} ({node.kind.value})")
            lines.append("")

        # Code section
        self._append_code_section(lines, snippets)

        return "\n".join(lines)

    def _format_graph(
        self,
        snippets: list[Snippet],
        query: str,
        entry_points: list[tuple[Node, float]] | None = None,
        related: list[tuple[Node, float]] | None = None,
    ) -> str:
        """Graph-linearized context: rank entry points by importance, show call chains.

        The first 10 entry points are re-scored with bonuses for being
        exported, for their kind, and for their number of outgoing ``calls``
        edges. The top 5 are listed, and callees are shown for the top 3.
        ``related`` is accepted for signature parity and is not used here.
        """
        lines = ["## Graph Context", "", f"**Query:** {query}", ""]

        ranked: list[tuple[Node, float]] = []
        if entry_points:
            for node, score in entry_points[:10]:
                bonus = 0.0
                if node.is_exported:
                    bonus += 5.0
                if node.kind.value in ("function", "method"):
                    bonus += 3.0
                elif node.kind.value in ("class", "interface"):
                    bonus += 4.0
                try:
                    rows = self.conn.execute(
                        "SELECT COUNT(*) FROM edges WHERE source_id = ? AND kind = 'calls'",
                        (node.id,),
                    ).fetchone()
                    if rows:
                        bonus += min(float(rows[0]) * 0.5, 10.0)
                except sqlite3.Error:
                    # The call-count bonus is optional; rank without it.
                    pass
                ranked.append((node, score + bonus))

        ranked.sort(key=lambda x: x[1], reverse=True)

        lines.append("### Entry Points (ranked)")
        lines.append("")
        for node, _rscore in ranked[:5]:
            sig = node.signature or ""
            tags: list[str] = []
            if node.is_exported:
                tags.append("exported")
            if node.is_async:
                tags.append("async")
            tag_str = f" [{', '.join(tags)}]" if tags else ""
            lines.append(
                f"- **{node.name}** ({node.kind.value}){tag_str}"
                f" - {node.file_path}:{node.start_line}"
            )
            if sig:
                lines.append(f" `{sig[:120]}`")
        lines.append("")

        if ranked:
            lines.append("### Call Chains")
            lines.append("")
            for node, _ in ranked[:3]:
                callees = self._get_callees(node.id)
                if callees:
                    lines.append(f"**{node.name}** ->")
                    for callee_name, callee_file in callees[:5]:
                        lines.append(f" - {callee_name} ({callee_file})")
                    lines.append("")

        self._append_code_section(lines, snippets)

        return "\n".join(lines)

    def _get_callees(self, node_id: str) -> list[tuple[str, str]]:
        """Get the names and files of nodes called by the given node.

        Returns at most 10 ``(name, file_path)`` pairs, or an empty list when
        the query fails with a database error.
        """
        try:
            rows = self.conn.execute(
                """
                SELECT n.name, n.file_path
                FROM edges e
                JOIN nodes n ON e.target_id = n.id
                WHERE e.source_id = ? AND e.kind = 'calls'
                LIMIT 10
                """,
                (node_id,),
            ).fetchall()
            return [(r[0], r[1]) for r in rows]
        except sqlite3.Error:
            return []
|
codemesh/db/__init__.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
"""SQLite connection manager with WAL mode and optimized settings."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import sqlite3
|
|
6
|
+
from contextlib import contextmanager
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import TYPE_CHECKING
|
|
9
|
+
|
|
10
|
+
if TYPE_CHECKING:
|
|
11
|
+
from collections.abc import Generator
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
# Location of the index database, relative to a project root (see get_db_path).
DB_PATH = Path(".codemesh/index.db")
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def get_db_path(root: Path | None = None) -> Path:
    """Return the index database location beneath *root*.

    When *root* is omitted, the current working directory is used as the base.
    """
    base = Path.cwd() if root is None else root
    return base / DB_PATH
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def create_connection(db_path: Path) -> sqlite3.Connection:
    """Open the SQLite database at *db_path* with tuned settings.

    The parent directory is created when missing. The connection uses
    ``sqlite3.Row`` rows, a 30 second lock timeout, WAL journaling,
    ``synchronous=NORMAL``, a 64MB cache and in-memory temp storage.
    Loading of the sqlite-vec extension is then attempted.

    Note: no ``PRAGMA foreign_keys`` statement is issued here, so foreign-key
    enforcement is left at whatever the SQLite build defaults to.

    Args:
        db_path: Database file to open (created by SQLite if absent).

    Returns:
        The configured, open connection; the caller owns and must close it.

    Raises:
        OSError: If the parent directory cannot be created.
        sqlite3.Error: If connecting or configuring fails; the half-configured
            connection is closed before the error propagates.
    """
    db_path.parent.mkdir(parents=True, exist_ok=True)
    conn = sqlite3.connect(str(db_path), timeout=30)
    try:
        conn.row_factory = sqlite3.Row
        conn.execute("PRAGMA journal_mode=WAL")
        conn.execute("PRAGMA synchronous=NORMAL")
        conn.execute("PRAGMA cache_size=-64000")  # 64MB cache
        conn.execute("PRAGMA temp_store=MEMORY")
        _load_sqlite_vec(conn)
    except BaseException:
        # Do not leak the handle when setup fails part-way through.
        conn.close()
        raise
    return conn
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def _load_sqlite_vec(conn: sqlite3.Connection) -> None:
|
|
43
|
+
"""Load the sqlite-vec extension if available."""
|
|
44
|
+
try:
|
|
45
|
+
conn.enable_load_extension(True)
|
|
46
|
+
import sqlite_vec
|
|
47
|
+
|
|
48
|
+
conn.load_extension(sqlite_vec.loadable_path())
|
|
49
|
+
except Exception:
|
|
50
|
+
pass # sqlite-vec not available; brute-force fallback will be used
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
@contextmanager
def get_connection(db_path: Path | None = None) -> Generator[sqlite3.Connection, None, None]:
    """Yield a database connection for the duration of a ``with`` block.

    Without *db_path*, the path from ``get_db_path()`` is used. The
    transaction is committed when the block finishes normally and rolled back
    when it raises an ``Exception``, which is then re-raised. The connection
    is closed in every case.
    """
    target = get_db_path() if db_path is None else db_path
    conn = create_connection(target)
    try:
        yield conn
        conn.commit()
    except Exception:
        conn.rollback()
        raise
    finally:
        conn.close()
|