pr-context-engine 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
src/config.py ADDED
@@ -0,0 +1,118 @@
1
+ """Provider factory — reads LLM_PROVIDER env var and returns the right LLMProvider.
2
+
3
+ Also exposes is_fixes_enabled() for the ENABLE_FIXES kill switch (Milestone 8).
4
+ """
5
+ from __future__ import annotations
6
+
7
+ import logging
8
+ import os
9
+ from typing import TYPE_CHECKING
10
+
11
+ from src.llm.base import LLMProvider
12
+
13
+ if TYPE_CHECKING:
14
+ from src.llm import FailoverProvider
15
+
16
logger = logging.getLogger(__name__)

# Environment variable that selects the primary LLM backend.
_PROVIDER_ENV = "LLM_PROVIDER"
# Provider used when LLM_PROVIDER is unset or set to an empty string.
_DEFAULT = "groq"
# Accepted provider names. Each must have a branch in _build_single_provider(),
# which raises AssertionError for any name it does not handle.
_VALID = {"groq", "gemini", "ollama", "anthropic"}
21
+
22
+
23
def get_provider() -> LLMProvider:
    """Instantiate and return the LLMProvider named by LLM_PROVIDER (default: groq).

    The name is matched case-insensitively and surrounding whitespace is ignored.
    This is the same normalisation ``is_fixes_enabled`` applies to ENABLE_FIXES,
    so a value such as ``"Groq "`` coming from a .env file still resolves. An
    unset, empty or whitespace-only LLM_PROVIDER selects the default.

    Reads provider-specific keys from env vars — nothing provider-specific leaks
    into callers.

    Returns:
        A newly constructed provider for the selected backend.

    Raises:
        ValueError: If LLM_PROVIDER names an unrecognised provider.
        RuntimeError: If the selected provider's required API key is missing.
    """
    name = (os.getenv(_PROVIDER_ENV) or "").strip().lower() or _DEFAULT
    if name not in _VALID:
        raise ValueError(
            f"Unknown LLM provider: {name!r}. Valid choices: {', '.join(sorted(_VALID))}"
        )
    logger.info("LLM provider: %s", name)
    return _build_single_provider(name)
37
+
38
+
39
def get_failover_provider() -> FailoverProvider:
    """Build a FailoverProvider with automatic Gemini fallback when key is present.

    The primary provider is determined by LLM_PROVIDER (default: groq). The name
    is matched case-insensitively and surrounding whitespace is ignored, which is
    the same normalisation ``is_fixes_enabled`` applies to ENABLE_FIXES. If
    GEMINI_API_KEY is set and the primary provider is not Gemini, Gemini is added
    as a fallback. This is the runtime payoff for ADR-0.

    The returned provider's .attribution() gives a footer-friendly string such as
    "groq" or "gemini (groq rate-limited)".

    Returns:
        A FailoverProvider built from ``(name, provider)`` pairs, primary first.

    Raises:
        RuntimeError: If the primary provider's required API key is missing.
        ValueError: If LLM_PROVIDER names an unrecognised provider.
    """
    # Imported locally; the module only imports it under TYPE_CHECKING at top
    # level (presumably to avoid an import cycle).
    from src.llm import FailoverProvider

    name = (os.getenv(_PROVIDER_ENV) or "").strip().lower() or _DEFAULT
    if name not in _VALID:
        raise ValueError(
            f"Unknown LLM provider: {name!r}. Valid choices: {', '.join(sorted(_VALID))}"
        )

    # The primary is built first, so a missing primary key fails before any
    # fallback is constructed.
    providers: list[tuple[str, LLMProvider]] = [(name, _build_single_provider(name))]

    # Auto-add Gemini fallback when the key is present and the primary is not Gemini.
    gemini_key = os.getenv("GEMINI_API_KEY")
    if name != "gemini" and gemini_key:
        from src.llm.gemini_provider import GeminiProvider

        providers.append(("gemini", GeminiProvider(api_key=gemini_key)))
        # Log the primary as well, so it is visible even when failover is on.
        logger.info("LLM provider: %s with Gemini failover (GEMINI_API_KEY present)", name)
    else:
        logger.info("LLM provider: %s (no failover configured)", name)

    return FailoverProvider(providers=providers)
74
+
75
+
76
+ def _build_single_provider(name: str) -> LLMProvider:
77
+ """Instantiate the named provider. Caller is responsible for name validation."""
78
+ if name == "groq":
79
+ from src.llm.groq_provider import GroqProvider
80
+
81
+ api_key = os.getenv("GROQ_API_KEY")
82
+ if not api_key:
83
+ raise RuntimeError("GROQ_API_KEY is required when LLM_PROVIDER=groq")
84
+ return GroqProvider(api_key=api_key)
85
+
86
+ if name == "gemini":
87
+ from src.llm.gemini_provider import GeminiProvider
88
+
89
+ api_key = os.getenv("GEMINI_API_KEY")
90
+ if not api_key:
91
+ raise RuntimeError("GEMINI_API_KEY is required when LLM_PROVIDER=gemini")
92
+ return GeminiProvider(api_key=api_key)
93
+
94
+ if name == "ollama":
95
+ from src.llm.ollama_provider import OllamaProvider
96
+
97
+ base_url = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")
98
+ model = os.getenv("OLLAMA_MODEL", "qwen2.5-coder:7b")
99
+ return OllamaProvider(base_url=base_url, model=model)
100
+
101
+ if name == "anthropic":
102
+ from src.llm.anthropic_provider import AnthropicProvider
103
+
104
+ api_key = os.getenv("ANTHROPIC_API_KEY")
105
+ if not api_key:
106
+ raise RuntimeError("ANTHROPIC_API_KEY is required when LLM_PROVIDER=anthropic")
107
+ return AnthropicProvider(api_key=api_key)
108
+
109
+ raise AssertionError(f"unreachable: unhandled provider {name!r}")
110
+
111
+
112
def is_fixes_enabled() -> bool:
    """Report whether the ENABLE_FIXES kill switch is turned on.

    Fix suggestions are opt-in per repo. Only ``true``, ``1`` or ``yes`` enable
    them; matching is case-insensitive and ignores surrounding whitespace. Any
    other value, including an unset variable, leaves the feature off, so the M4
    briefing behaviour stays unchanged unless a caller explicitly opts in.
    """
    raw_flag = os.getenv("ENABLE_FIXES", "false")
    normalized = raw_flag.strip().lower()
    truthy_values = ("true", "1", "yes")
    return normalized in truthy_values
@@ -0,0 +1 @@
1
+ """Context gathering: codebase index (RAG) and git history."""
@@ -0,0 +1,382 @@
1
+ """Codebase vector index for semantic similarity search across repo files.
2
+
3
+ Chunks repo source files by function/class (Python) or sliding window (others),
4
+ embeds them with a local fastembed model, and stores them in a sqlite-vec database.
5
+ Subsequent runs only re-embed files whose git blob hash changed.
6
+ """
7
+ import ast
8
+ import logging
9
+ import os
10
+ import sqlite3
11
+ import subprocess
12
+ from dataclasses import dataclass
13
+ from pathlib import Path
14
+
15
+ import sqlite_vec
16
+ from fastembed import TextEmbedding
17
+
18
logger = logging.getLogger(__name__)

# fastembed model used for all embeddings. Its vector width must equal
# _EMBEDDING_DIM, which is baked into the vec0 table schema.
_EMBEDDING_MODEL = "BAAI/bge-small-en-v1.5"
_EMBEDDING_DIM = 384
# Sliding-window chunking (used for non-Python files, or as a fallback):
# window length and overlap between consecutive windows, in lines.
_CHUNK_WINDOW = 60
_CHUNK_OVERLAP = 10
# The sliding-window chunker ignores lines beyond this count.
_MAX_FILE_LINES = 2000

# File extensions (compared lowercased) that are eligible for indexing.
_INDEXABLE_EXTS = {
    ".py", ".js", ".jsx", ".ts", ".tsx", ".go", ".rb", ".java",
    ".rs", ".c", ".cpp", ".cs", ".php", ".sh",
}

# Directory names that exclude a file when they appear anywhere in its path.
_SKIP_DIRS = {
    ".git", "__pycache__", "node_modules", ".venv", "venv",
    ".tox", "dist", "build", ".mypy_cache", ".ruff_cache",
}
35
+
36
+
37
@dataclass
class RelatedChunk:
    """A code chunk retrieved as semantically similar to a query.

    Produced by ``CodebaseIndex.query``, which returns them nearest-first.
    """

    file_path: str  # repo-relative path of the file the chunk came from
    label: str  # locator such as "func", "Class.method" or "lines 1-60"
    chunk_text: str  # source text of the chunk as stored in the index
    distance: float  # sqlite-vec KNN distance to the query; results sorted ascending
45
+
46
+
47
@dataclass
class _Chunk:
    """Internal representation of a code chunk before indexing."""

    file_path: str  # repo-relative path of the source file
    git_hash: str  # content hash of the source file (git blob hash, or the scan fallback's hash)
    label: str  # human-readable locator: "func", "Class.method" or "lines A-B"
    chunk_text: str  # the text that gets embedded and stored
    start_line: int  # 1-based line number where the chunk starts
56
+
57
+
58
class CodebaseIndex:
    """Manages a sqlite-vec index of repo code chunks for semantic search.

    On first run, walks the repo, chunks files (see ``chunk_file``), embeds the
    chunks with a local fastembed model, and stores them in the sqlite database
    at ``db_path``. Subsequent runs only re-embed files whose content hash
    changed, making incremental updates fast.

    The embedding model and the database connection are both opened lazily on
    first use; call :meth:`close` to release the connection when finished.
    """

    def __init__(self, db_path: str = "index.db", repo_root: str = ".") -> None:
        """Configure the index; nothing is opened or loaded until first use.

        Args:
            db_path: Path of the sqlite database file holding the index.
            repo_root: Repository root; resolved to an absolute path.
        """
        self._db_path = db_path
        self._repo_root = Path(repo_root).resolve()
        self._model: TextEmbedding | None = None
        self._db: sqlite3.Connection | None = None

    def close(self) -> None:
        """Close the database connection if it is open; safe to call repeatedly.

        The index stays usable afterwards: the next operation reopens the database.
        """
        if self._db is not None:
            self._db.close()
            self._db = None

    def _get_model(self) -> TextEmbedding:
        """Return the fastembed model, constructing it on first call."""
        if self._model is None:
            logger.info("Loading embedding model %s (first run downloads weights)", _EMBEDDING_MODEL)
            self._model = TextEmbedding(model_name=_EMBEDDING_MODEL)
        return self._model

    def _get_db(self) -> sqlite3.Connection:
        """Return the sqlite connection, opening it and ensuring the schema on first call."""
        if self._db is None:
            db = sqlite3.connect(self._db_path)
            try:
                # Extension loading is enabled only long enough to load sqlite-vec.
                db.enable_load_extension(True)
                sqlite_vec.load(db)
                db.enable_load_extension(False)
                self._db = db
                self._init_schema()
            except Exception:
                # Don't leak the connection, and don't cache one whose schema may
                # be missing. The next call starts over.
                self._db = None
                db.close()
                raise
        return self._db

    def _init_schema(self) -> None:
        """Create the chunk, file-hash and vector tables if they don't exist yet.

        ``vec_chunks`` rows share their rowid with ``chunks.id`` (see
        ``build_or_update``). ``file_hashes`` records the hash at which each file
        was last indexed.
        """
        self._db.executescript(
            f"""
            CREATE TABLE IF NOT EXISTS chunks (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                file_path TEXT NOT NULL,
                git_hash TEXT NOT NULL,
                label TEXT NOT NULL,
                chunk_text TEXT NOT NULL,
                start_line INTEGER NOT NULL DEFAULT 0
            );
            CREATE TABLE IF NOT EXISTS file_hashes (
                file_path TEXT PRIMARY KEY,
                git_hash TEXT NOT NULL
            );
            CREATE VIRTUAL TABLE IF NOT EXISTS vec_chunks
                USING vec0(embedding float[{_EMBEDDING_DIM}]);
            """
        )
        self._db.commit()

    def build_or_update(self) -> None:
        """Walk the repo and (re-)embed any files whose content changed since last run.

        Files that disappeared from the repo are dropped from the index. A file's
        hash is recorded only after the file was read successfully, so an
        unreadable file is retried on the next run instead of being marked as
        indexed. Files that yield no chunks (e.g. empty files) are still recorded,
        so they are not re-processed on every run.
        """
        db = self._get_db()

        current_files = self._list_repo_files()
        indexed: dict[str, str] = dict(
            db.execute("SELECT file_path, git_hash FROM file_hashes").fetchall()
        )

        to_remove = [path for path in indexed if path not in current_files]
        to_index = [
            (path, git_hash)
            for path, git_hash in current_files.items()
            if indexed.get(path) != git_hash
        ]

        if to_remove:
            self._remove_files(to_remove)
            logger.info("Removed %d deleted/untracked files from index", len(to_remove))

        if not to_index:
            logger.info("Index up to date (%d files indexed)", len(current_files))
            return

        logger.info("Indexing %d new/changed files", len(to_index))

        # Remove stale chunks (and recorded hashes) for changed files before
        # re-adding. If anything below fails, these files have no hash any more
        # and are picked up again on the next run.
        changed_paths = [path for path, _ in to_index if path in indexed]
        if changed_paths:
            self._remove_files(changed_paths)

        all_chunks: list[_Chunk] = []
        read_ok: list[tuple[str, str]] = []
        for path, git_hash in to_index:
            try:
                text = (self._repo_root / path).read_text(encoding="utf-8", errors="replace")
            except OSError as exc:
                # Leave the hash unrecorded so the file is retried next run.
                logger.warning("Cannot read %s: %s", path, exc)
                continue
            read_ok.append((path, git_hash))
            all_chunks.extend(chunk_file(path, git_hash, text))

        # Embed everything before writing, so a model failure leaves no partial rows.
        embeddings = []
        if all_chunks:
            model = self._get_model()
            embeddings = list(model.embed([c.chunk_text for c in all_chunks]))

        with db:  # one transaction: commit on success, roll back on error
            for chunk, embedding in zip(all_chunks, embeddings, strict=True):
                cursor = db.execute(
                    "INSERT INTO chunks (file_path, git_hash, label, chunk_text, start_line)"
                    " VALUES (?, ?, ?, ?, ?)",
                    (chunk.file_path, chunk.git_hash, chunk.label, chunk.chunk_text, chunk.start_line),
                )
                db.execute(
                    "INSERT INTO vec_chunks(rowid, embedding) VALUES (?, ?)",
                    (cursor.lastrowid, sqlite_vec.serialize_float32(embedding.tolist())),
                )
            db.executemany(
                "INSERT OR REPLACE INTO file_hashes (file_path, git_hash) VALUES (?, ?)",
                read_ok,
            )

        if all_chunks:
            logger.info("Indexed %d chunks from %d files", len(all_chunks), len(read_ok))
        else:
            logger.info("No chunks produced from %d files", len(to_index))

    def query(
        self,
        text: str,
        exclude_paths: set[str] | None = None,
        top_k: int = 5,
    ) -> list[RelatedChunk]:
        """Find the *top_k* chunks most similar to *text*, skipping *exclude_paths*.

        Returns at most *top_k* results, nearest first. It returns fewer when the
        index is small, or when chunks from excluded files crowd out the
        over-fetched candidates. A non-positive *top_k* returns an empty list.
        """
        if top_k <= 0:
            return []
        excluded = exclude_paths or set()

        db = self._get_db()
        embedding = next(iter(self._get_model().embed([text])))
        emb_bytes = sqlite_vec.serialize_float32(embedding.tolist())

        # Over-fetch to make up for candidates dropped by the exclusion filter,
        # capped so a large exclusion set cannot blow up the KNN k.
        fetch_k = min(top_k + len(excluded) * 3 + 20, top_k + 100)

        # sqlite-vec requires k = ? in the WHERE clause for knn queries
        rows = db.execute(
            """
            SELECT c.file_path, c.label, c.chunk_text, v.distance
            FROM vec_chunks v
            JOIN chunks c ON c.id = v.rowid
            WHERE v.embedding MATCH ?
              AND k = ?
            ORDER BY v.distance
            """,
            (emb_bytes, fetch_k),
        ).fetchall()

        results: list[RelatedChunk] = []
        for file_path, label, chunk_text, distance in rows:
            if file_path in excluded:
                continue
            results.append(
                RelatedChunk(
                    file_path=file_path,
                    label=label,
                    chunk_text=chunk_text,
                    distance=distance,
                )
            )
            if len(results) >= top_k:
                break

        return results

    def _remove_files(self, paths: list[str]) -> None:
        """Delete every chunk, vector and recorded hash for *paths*, then commit."""
        db = self._get_db()
        with db:
            for path in paths:
                chunk_ids = db.execute(
                    "SELECT id FROM chunks WHERE file_path = ?", (path,)
                ).fetchall()
                # vec_chunks rows are keyed by the owning chunk's id.
                db.executemany("DELETE FROM vec_chunks WHERE rowid = ?", chunk_ids)
                db.execute("DELETE FROM chunks WHERE file_path = ?", (path,))
                db.execute("DELETE FROM file_hashes WHERE file_path = ?", (path,))

    def _list_repo_files(self) -> dict[str, str]:
        """Return {relative_path: git_blob_hash} for all indexable tracked files.

        Uses ``git ls-files -s -z`` so that paths come back verbatim. Without
        ``-z``, git C-quotes paths containing non-ASCII or special characters,
        and those quoted paths would not resolve on disk. Falls back to a
        directory scan when git is missing or the command fails (e.g. not a
        repository).
        """
        try:
            result = subprocess.run(
                ["git", "ls-files", "-s", "-z"],
                cwd=self._repo_root,
                capture_output=True,
                encoding="utf-8",
                errors="replace",
                check=True,
            )
        except (subprocess.CalledProcessError, FileNotFoundError):
            logger.warning("git ls-files failed; falling back to directory scan")
            return _scan_directory(self._repo_root)

        files: dict[str, str] = {}
        for entry in result.stdout.split("\0"):
            # format: <mode> <hash> <stage>\t<path>
            meta, sep, path = entry.partition("\t")
            if not sep:
                continue
            fields = meta.split()
            if len(fields) < 2:
                continue
            if is_indexable(path):
                files[path] = fields[1]

        return files
270
+
271
+
272
+ # ---------------------------------------------------------------------------
273
+ # Module-level helpers (pure functions — unit-testable without the class)
274
+ # ---------------------------------------------------------------------------
275
+
276
+
277
def is_indexable(path: str) -> bool:
    """Decide whether *path* belongs in the codebase index.

    A path qualifies when none of its components is a skipped directory name
    and its extension, compared lowercased, is an indexable source type.
    """
    candidate = Path(path)
    inside_skipped_dir = not _SKIP_DIRS.isdisjoint(candidate.parts)
    return not inside_skipped_dir and candidate.suffix.lower() in _INDEXABLE_EXTS
283
+
284
+
285
def chunk_file(file_path: str, git_hash: str, text: str) -> list[_Chunk]:
    """Break one source file into chunks ready for embedding.

    Python files are split along function/method boundaries via the AST. The
    fixed-size sliding window over lines is used instead when the file is not
    Python, or when the AST pass yields nothing (no functions, or a parse failure).
    """
    ast_chunks = chunk_python(file_path, git_hash, text) if file_path.endswith(".py") else []
    return ast_chunks or chunk_window(file_path, git_hash, text)
296
+
297
+
298
def chunk_python(file_path: str, git_hash: str, text: str) -> list[_Chunk]:
    """Chunk a Python file by top-level functions and class methods.

    Chunking rules:

    * Each module-level function becomes one chunk.
    * Each method of a class becomes one chunk, including classes nested inside
      other classes. The chunk spans the ``def`` line through its last line;
      decorators are not included.
    * Methods of nested classes are labelled with the full dotted class path,
      e.g. ``Outer.Inner.method``.
    * Functions nested inside other functions stay part of their parent's chunk.
    * Module-level statements outside any function are not chunked.

    Returns an empty list when *text* cannot be parsed, so callers can fall back
    to another strategy.
    """
    try:
        tree = ast.parse(text)
    except (SyntaxError, ValueError):
        # Older Python versions raise ValueError (not SyntaxError) for source
        # containing NUL bytes.
        return []

    # ast line numbers treat only \n, \r\n and \r as line breaks. str.splitlines()
    # additionally splits on \f, \v, \x1c-\x1e, \x85, \u2028 and \u2029, which can
    # legally occur in Python source (form feeds, string literals, comments).
    # Splitting that way would shift every later chunk, so split on \n only.
    lines = text.replace("\r\n", "\n").replace("\r", "\n").split("\n")
    chunks: list[_Chunk] = []

    def _extract(node: ast.AST, prefix: str = "") -> None:
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            start = node.lineno - 1
            end = node.end_lineno or node.lineno
            chunks.append(
                _Chunk(
                    file_path=file_path,
                    git_hash=git_hash,
                    label=f"{prefix}{node.name}",
                    chunk_text="\n".join(lines[start:end]),
                    start_line=node.lineno,
                )
            )
        elif isinstance(node, ast.ClassDef):
            # Keep the outer prefix so nested classes retain their full dotted path.
            class_prefix = f"{prefix}{node.name}."
            for child in node.body:
                _extract(child, prefix=class_prefix)

    for node in tree.body:
        _extract(node)

    return chunks
332
+
333
+
334
def chunk_window(file_path: str, git_hash: str, text: str) -> list[_Chunk]:
    """Chunk a file with a sliding window of lines.

    Windows are ``_CHUNK_WINDOW`` lines long, and consecutive windows overlap by
    ``_CHUNK_OVERLAP`` lines. Only the first ``_MAX_FILE_LINES`` lines are
    considered. Chunking stops as soon as a window reaches the end of the file,
    so no trailing chunk is emitted that is wholly contained in the previous one.
    Labels and ``start_line`` are 1-based.
    """
    lines = text.splitlines()[:_MAX_FILE_LINES]
    if not lines:
        return []

    step = max(1, _CHUNK_WINDOW - _CHUNK_OVERLAP)
    total = len(lines)
    chunks: list[_Chunk] = []

    for start in range(0, total, step):
        end = min(start + _CHUNK_WINDOW, total)
        chunks.append(
            _Chunk(
                file_path=file_path,
                git_hash=git_hash,
                label=f"lines {start + 1}-{end}",
                chunk_text="\n".join(lines[start:end]),
                start_line=start + 1,
            )
        )
        if end == total:
            # This window already covers the tail. Any later start would only
            # produce a sub-slice of it.
            break

    return chunks
361
+
362
+
363
def _scan_directory(repo_root: Path) -> dict[str, str]:
    """Fallback: scan directory and hash file contents when git is unavailable.

    Returns ``{relative_path: content_hash}`` for indexable files under
    *repo_root*, skipping ``_SKIP_DIRS`` subtrees. Unreadable files are skipped.

    Paths use forward slashes, and each hash is git's blob object id: SHA-1 over
    a ``blob <size>`` header, a NUL byte, then the raw content. In a SHA-1 git
    repository, both therefore line up with what ``git ls-files -s`` reports, so
    switching between git and this fallback does not force a full re-index.
    Files rewritten by git clean filters (e.g. line-ending conversion) can
    still differ.
    """
    import hashlib

    files: dict[str, str] = {}
    for root, dirs, filenames in os.walk(repo_root):
        # Prune in place so os.walk never descends into skipped directories.
        dirs[:] = [d for d in dirs if d not in _SKIP_DIRS]
        for filename in filenames:
            abs_path = Path(root) / filename
            rel_path = abs_path.relative_to(repo_root).as_posix()
            if not is_indexable(rel_path):
                continue
            try:
                content = abs_path.read_bytes()
            except OSError:
                continue
            # Change detection only, not security: usedforsecurity=False keeps
            # this working on FIPS-restricted builds.
            blob = hashlib.sha1(b"blob %d\x00" % len(content), usedforsecurity=False)
            blob.update(content)
            files[rel_path] = blob.hexdigest()

    return files