sourcefire 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,967 @@
1
+ """LangChain RAG chain with mode-aware retrieval and Gemini API streaming."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import asyncio
6
+ import os
7
+ import subprocess
8
+ from pathlib import Path
9
+ from typing import Any, AsyncGenerator
10
+
11
+ from langchain_google_genai import ChatGoogleGenerativeAI
12
+ from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, ToolMessage
13
+ from langchain_core.tools import tool
14
+
15
+ from sourcefire.indexer.embeddings import embed_text
16
+ from sourcefire.indexer.language_profiles import LanguageProfile
17
+ from sourcefire.retriever.search import semantic_search, get_chunks_by_filenames, parse_file_references
18
+ from sourcefire.retriever.graph import ImportGraph
19
+ from sourcefire.chain.prompts import assemble_prompt
20
+ from sourcefire.db import query_similar
21
+
22
+
23
+ # ---------------------------------------------------------------------------
24
+ # Static context loader
25
+ # ---------------------------------------------------------------------------
26
+
27
+
28
+ def _load_static_context(project_dir: Path) -> tuple[str, str]:
29
+ """Load CLAUDE.md from project_dir.
30
+
31
+ Returns:
32
+ A 2-tuple of (claude_md_content, memory_content).
33
+ """
34
+ claude_md = ""
35
+ claude_md_path = project_dir / "CLAUDE.md"
36
+ if claude_md_path.is_file():
37
+ try:
38
+ claude_md = claude_md_path.read_text(encoding="utf-8", errors="replace")
39
+ except OSError:
40
+ pass
41
+
42
+ return claude_md, ""
43
+
44
+
45
+ # ---------------------------------------------------------------------------
46
+ # Mode-specific retrievers
47
+ # ---------------------------------------------------------------------------
48
+
49
+
50
async def _retrieve_debug(
    collection: Any,
    graph: ImportGraph,
    query: str,
    query_vector: list[float],
    top_k: int,
    profile: LanguageProfile | None = None,
) -> list[dict[str, Any]]:
    """Debug mode: parse stack trace -> direct lookup -> graph expansion -> semantic."""
    results: list[dict[str, Any]] = []
    covered: set[str] = set()

    def _absorb(batch: list[dict[str, Any]], priority: str, default_relevance: float) -> None:
        # Tag every chunk in *batch* with its retrieval priority, give it a
        # fallback relevance, and record its filename as already covered.
        for chunk in batch:
            chunk["priority"] = priority
            chunk.setdefault("relevance", default_relevance)
            results.append(chunk)
            covered.add(chunk["filename"])

    patterns = profile.file_ref_patterns if profile else None
    referenced = [ref["file"] for ref in parse_file_references(query, patterns)]

    # Stage 1: files named directly in the query (e.g. from a stack trace).
    if referenced:
        _absorb(await get_chunks_by_filenames(collection, referenced), "direct", 1.0)

    # Stage 2: one-hop import-graph neighbours of the direct files.
    expanded = [
        neighbor
        for fname in referenced
        for neighbor in graph.get_neighbors(fname, hops=1)
        if neighbor not in covered
    ]
    if expanded:
        _absorb(await get_chunks_by_filenames(collection, expanded), "graph", 0.6)

    # Stage 3: semantic hits, skipping files already collected above.
    for chunk in await semantic_search(collection, query_vector, top_k=top_k):
        if chunk["filename"] not in covered:
            chunk["priority"] = "semantic"
            results.append(chunk)
            covered.add(chunk["filename"])

    return results
96
+
97
+
98
async def _retrieve_feature(
    collection: Any,
    graph: ImportGraph,
    query: str,
    query_vector: list[float],
    top_k: int,
    profile: LanguageProfile | None = None,
) -> list[dict[str, Any]]:
    """Feature mode: semantic search -> best feature -> retrieve feature chunks."""
    feature_cap = 15  # max chunks fetched for the winning feature

    seeds = await semantic_search(collection, query_vector, top_k=top_k)

    # Bucket the seed relevance scores by each chunk's feature tag.
    scores_by_feature: dict[str, list[float]] = {}
    for chunk in seeds:
        name = chunk.get("feature") or "core"
        scores_by_feature.setdefault(name, []).append(float(chunk.get("relevance", 0.0)))

    if not scores_by_feature:
        # No seed chunks at all: just return the semantic results as-is.
        for chunk in seeds:
            chunk["priority"] = "semantic"
        return seeds

    def _mean_score(name: str) -> float:
        bucket = scores_by_feature[name]
        return sum(bucket) / len(bucket)

    # Winning feature = highest average relevance across its seed chunks.
    winner = max(scores_by_feature, key=_mean_score)

    focused = await semantic_search(
        collection,
        query_vector,
        top_k=feature_cap,
        feature=winner,
    )
    for chunk in focused:
        chunk["priority"] = "semantic"

    return focused
133
+
134
+
135
async def _retrieve_explain(
    collection: Any,
    graph: ImportGraph,
    query: str,
    query_vector: list[float],
    top_k: int,
    profile: LanguageProfile | None = None,
) -> list[dict[str, Any]]:
    """Explain mode: semantic search -> import graph expansion in both directions."""
    seeds = await semantic_search(collection, query_vector, top_k=top_k)

    collected: list[dict[str, Any]] = []
    known: set[str] = set()
    for chunk in seeds:
        chunk["priority"] = "semantic"
        collected.append(chunk)
        known.add(chunk["filename"])

    # Gather one-hop import-graph neighbours of every seed file, deduplicated.
    frontier: list[str] = []
    for chunk in seeds:
        for neighbor in graph.get_neighbors(chunk["filename"], hops=1):
            if neighbor not in known:
                known.add(neighbor)
                frontier.append(neighbor)

    if frontier:
        for chunk in await get_chunks_by_filenames(collection, frontier):
            chunk["priority"] = "graph"
            chunk.setdefault("relevance", 0.5)
            collected.append(chunk)

    return collected
168
+
169
+
170
+ # ---------------------------------------------------------------------------
171
+ # Public retrieval entry point
172
+ # ---------------------------------------------------------------------------
173
+
174
+
175
async def retrieve_for_mode(
    collection: Any,
    graph: ImportGraph,
    query: str,
    mode: str,
    top_k: int = 8,
    profile: LanguageProfile | None = None,
) -> list[dict[str, Any]]:
    """Embed *query* and dispatch to the mode-specific retriever.

    Args:
        collection: Vector-store collection handle passed through to the retriever.
        graph: Import graph used for neighbour expansion.
        query: Natural-language user query.
        mode: "debug" or "feature"; any other value falls back to explain mode.
        top_k: Number of semantic hits requested from the vector store.
        profile: Optional language profile forwarded to the retriever.

    Returns:
        A list of chunk dicts, each tagged with a "priority" key.
    """
    # embed_text is synchronous/blocking, so run it off the event loop.
    # asyncio.to_thread targets the *running* loop and replaces the deprecated
    # get_event_loop()/run_in_executor pattern, which emits a DeprecationWarning
    # (and can misbehave) when called inside a coroutine on modern Python.
    query_vector: list[float] = await asyncio.to_thread(embed_text, query)

    if mode == "debug":
        return await _retrieve_debug(collection, graph, query, query_vector, top_k, profile)
    if mode == "feature":
        return await _retrieve_feature(collection, graph, query, query_vector, top_k, profile)
    return await _retrieve_explain(collection, graph, query, query_vector, top_k, profile)
193
+
194
+
195
+ # ---------------------------------------------------------------------------
196
+ # LangChain Tools
197
+ # ---------------------------------------------------------------------------
198
+
199
def _get_tools(
    graph: ImportGraph,
    profile: LanguageProfile | None = None,
    collection: Any = None,
    project_dir: Path | None = None,
) -> list[Any]:
    """Build the list of LangChain tools closed over *graph*, *collection* and *project_dir*.

    Every tool defined below is a closure over these arguments. When
    *project_dir* is None the current working directory is used; when
    *profile* is None a generic multi-language extension list is used.
    """

    _project_dir = project_dir or Path.cwd()
    # Extensions considered by the file-walking search tools.
    searchable_exts = tuple(profile.searchable_extensions) if profile else (".py", ".js", ".ts", ".go", ".rs", ".java", ".dart", ".yaml", ".json", ".md")
208
+
209
+ @tool
210
+ def read_local_file(filepath: str) -> str:
211
+ """Reads the complete content of a file from the repository.
212
+ Use when you need to see exact implementation details.
213
+ Provide the relative filepath (e.g. 'src/main.py').
214
+ """
215
+ full_path = (_project_dir / filepath).resolve()
216
+ if not str(full_path).startswith(str(_project_dir.resolve())):
217
+ return "Error: Path traversal not allowed."
218
+ if not full_path.is_file():
219
+ return f"Error: File '{filepath}' not found in the codebase."
220
+ try:
221
+ content = full_path.read_text(encoding="utf-8", errors="replace")
222
+ if len(content) > 30000:
223
+ return content[:30000] + "\n\n... [File truncated because it is too large] ..."
224
+ return content
225
+ except Exception as e:
226
+ return f"Error reading file: {e}"
227
+
228
+ @tool
229
+ def list_directory(dir_path: str) -> str:
230
+ """Lists files and folders within a specific directory in the repository.
231
+ Use to understand project structure or find files.
232
+ Provide the relative path (e.g. 'src/components'). Use '.' for root.
233
+ """
234
+ if dir_path == ".":
235
+ full_path = _project_dir.resolve()
236
+ else:
237
+ full_path = (_project_dir / dir_path).resolve()
238
+
239
+ if not str(full_path).startswith(str(_project_dir.resolve())):
240
+ return "Error: Path traversal not allowed."
241
+ if not full_path.is_dir():
242
+ return f"Error: Directory '{dir_path}' not found."
243
+
244
+ try:
245
+ items = []
246
+ for item in full_path.iterdir():
247
+ suffix = "/" if item.is_dir() else ""
248
+ items.append(item.name + suffix)
249
+ return f"Contents of '{dir_path}':\n" + "\n".join(sorted(items))
250
+ except Exception as e:
251
+ return f"Error listing directory: {e}"
252
+
253
+ @tool
254
+ def find_file_usages(filepath: str) -> str:
255
+ """Find other files that import or depend on the given file.
256
+ Provide the relative filepath (e.g. 'src/utils/auth.py').
257
+ Returns a list of files that directly import it.
258
+ """
259
+ importers = graph.get_importers(filepath)
260
+ if not importers:
261
+ return f"No direct local dependencies found importing {filepath}."
262
+ return f"Files importing '{filepath}':\n" + "\n".join(f"- {f}" for f in importers)
263
+
264
+ @tool
265
+ def search_codebase_keywords(query: str, dir_path: str = ".") -> str:
266
+ """Search for an exact string or keyword across files in a directory.
267
+ Returns files and line numbers where the keyword appears.
268
+ Use for variable names, classes, function names, or patterns.
269
+ """
270
+ full_path = _project_dir if dir_path == "." else (_project_dir / dir_path)
271
+ full_path = full_path.resolve()
272
+
273
+ if not str(full_path).startswith(str(_project_dir.resolve())):
274
+ return "Error: Path traversal not allowed."
275
+
276
+ if not full_path.is_dir() and not full_path.is_file():
277
+ return f"Error: Directory '{dir_path}' not found."
278
+
279
+ results = []
280
+ try:
281
+ for root, _, files in os.walk(full_path):
282
+ if any(x in root for x in [".git", ".claude", "node_modules", "__pycache__", "build", "dist", "target", ".dart_tool"]):
283
+ continue
284
+ for file in files:
285
+ if file.endswith(searchable_exts):
286
+ fpath = Path(root) / file
287
+ try:
288
+ lines = fpath.read_text("utf-8", "replace").splitlines()
289
+ for i, line in enumerate(lines):
290
+ if query.lower() in line.lower():
291
+ rel_path = fpath.relative_to(_project_dir)
292
+ results.append(f"{rel_path}:{i+1}: {line.strip()[:100]}")
293
+ if len(results) >= 50:
294
+ return f"Results truncated at 50 matches.\n" + "\n".join(results)
295
+ except Exception:
296
+ continue
297
+ if not results:
298
+ return f"No matches found for '{query}' in {dir_path}"
299
+ return "\n".join(results)
300
+ except Exception as e:
301
+ return f"Error searching: {e}"
302
+
303
    @tool
    def find_definition(symbol_name: str) -> str:
        """Find where a class, function, or variable is defined in the codebase.
        Searches for definition patterns like 'class Foo', 'def foo', 'function foo', etc.
        Returns file paths and line numbers.
        """
        # Definition prefixes across the supported languages (Python, JS/TS,
        # Go, Rust, Java, Dart, ...). Matching is textual, not parsed.
        patterns = [
            f"class {symbol_name}",
            f"def {symbol_name}",
            f"async def {symbol_name}",
            f"function {symbol_name}",
            f"const {symbol_name}",
            f"let {symbol_name}",
            f"var {symbol_name}",
            f"type {symbol_name}",
            f"interface {symbol_name}",
            f"enum {symbol_name}",
            f"struct {symbol_name}",
            f"trait {symbol_name}",
            f"impl {symbol_name}",
            f"func {symbol_name}",
            f"mixin {symbol_name}",
            f"extension {symbol_name}",
        ]
        results = []
        try:
            for root, _, files in os.walk(_project_dir):
                # Skip VCS / vendor / build trees.
                if any(x in root for x in [".git", "node_modules", "__pycache__", "build", "dist", "target", ".dart_tool"]):
                    continue
                for file in files:
                    if file.endswith(searchable_exts):
                        fpath = Path(root) / file
                        try:
                            lines = fpath.read_text("utf-8", "replace").splitlines()
                            for i, line in enumerate(lines):
                                stripped = line.strip()
                                for pattern in patterns:
                                    # startswith catches plain definitions; the
                                    # " {pattern}" form catches modifier-prefixed
                                    # ones like 'public class Foo' / 'export function foo'.
                                    if stripped.startswith(pattern) or f" {pattern}" in stripped:
                                        rel_path = fpath.relative_to(_project_dir)
                                        results.append(f"{rel_path}:{i+1}: {stripped[:120]}")
                                        break
                                if len(results) >= 20:
                                    # Cap output to keep the tool response small.
                                    return "\n".join(results) + "\n... (truncated)"
                        except Exception:
                            continue
            if not results:
                return f"No definition found for '{symbol_name}'"
            return "\n".join(results)
        except Exception as e:
            return f"Error searching: {e}"
353
+
354
+ @tool
355
+ def get_file_structure(dir_path: str = ".", max_depth: int = 3) -> str:
356
+ """Get a tree view of the project structure up to a given depth.
357
+ Use this to understand the overall project layout.
358
+ Provide relative path and max depth (default 3).
359
+ """
360
+ if dir_path == ".":
361
+ full_path = _project_dir.resolve()
362
+ else:
363
+ full_path = (_project_dir / dir_path).resolve()
364
+
365
+ if not str(full_path).startswith(str(_project_dir.resolve())):
366
+ return "Error: Path traversal not allowed."
367
+ if not full_path.is_dir():
368
+ return f"Error: Directory '{dir_path}' not found."
369
+
370
+ skip_dirs = {".git", "node_modules", "__pycache__", "build", "dist", "target",
371
+ ".dart_tool", ".next", "venv", ".venv", ".idea", ".vs"}
372
+ lines = []
373
+
374
+ def _walk(path: Path, prefix: str, depth: int):
375
+ if depth > max_depth:
376
+ return
377
+ try:
378
+ entries = sorted(path.iterdir(), key=lambda p: (not p.is_dir(), p.name.lower()))
379
+ except PermissionError:
380
+ return
381
+ dirs = [e for e in entries if e.is_dir() and e.name not in skip_dirs]
382
+ files = [e for e in entries if e.is_file()]
383
+
384
+ for f in files[:20]:
385
+ lines.append(f"{prefix}{f.name}")
386
+ if len(files) > 20:
387
+ lines.append(f"{prefix}... ({len(files) - 20} more files)")
388
+
389
+ for d in dirs:
390
+ lines.append(f"{prefix}{d.name}/")
391
+ _walk(d, prefix + " ", depth + 1)
392
+
393
+ lines.append(f"{dir_path}/")
394
+ _walk(full_path, " ", 1)
395
+ return "\n".join(lines[:200])
396
+
397
+ @tool
398
+ def git_file_history(filepath: str, max_commits: int = 10) -> str:
399
+ """Get recent git commit history for a specific file.
400
+ Shows who changed it, when, and why. Use to understand evolution of a file.
401
+ Provide relative filepath.
402
+ """
403
+ full_path = (_project_dir / filepath).resolve()
404
+ if not str(full_path).startswith(str(_project_dir.resolve())):
405
+ return "Error: Path traversal not allowed."
406
+ try:
407
+ result = subprocess.run(
408
+ ["git", "log", f"-{max_commits}", "--pretty=format:%h %ai %s", "--", filepath],
409
+ cwd=str(_project_dir),
410
+ capture_output=True, text=True, timeout=10,
411
+ )
412
+ if result.returncode != 0:
413
+ return f"Git error: {result.stderr.strip()}"
414
+ if not result.stdout.strip():
415
+ return f"No git history found for '{filepath}'"
416
+ return f"Recent commits for '{filepath}':\n{result.stdout}"
417
+ except Exception as e:
418
+ return f"Error: {e}"
419
+
420
+ @tool
421
+ def git_blame_lines(filepath: str, start_line: int, end_line: int) -> str:
422
+ """Show git blame for specific line range of a file.
423
+ Use when you need to know who last changed specific lines and why.
424
+ """
425
+ full_path = (_project_dir / filepath).resolve()
426
+ if not str(full_path).startswith(str(_project_dir.resolve())):
427
+ return "Error: Path traversal not allowed."
428
+ try:
429
+ result = subprocess.run(
430
+ ["git", "blame", f"-L{start_line},{end_line}", "--date=short", filepath],
431
+ cwd=str(_project_dir),
432
+ capture_output=True, text=True, timeout=10,
433
+ )
434
+ if result.returncode != 0:
435
+ return f"Git error: {result.stderr.strip()}"
436
+ return result.stdout or "No blame output"
437
+ except Exception as e:
438
+ return f"Error: {e}"
439
+
440
+ @tool
441
+ def read_lines(filepath: str, start_line: int, end_line: int) -> str:
442
+ """Read a specific range of lines from a file.
443
+ Use when you only need a portion of a large file — avoids context overflow.
444
+ Line numbers are 1-based. Returns lines with line numbers prefixed.
445
+ """
446
+ full_path = (_project_dir / filepath).resolve()
447
+ if not str(full_path).startswith(str(_project_dir.resolve())):
448
+ return "Error: Path traversal not allowed."
449
+ if not full_path.is_file():
450
+ return f"Error: File '{filepath}' not found."
451
+ try:
452
+ lines = full_path.read_text(encoding="utf-8", errors="replace").splitlines()
453
+ start = max(1, start_line) - 1
454
+ end = min(len(lines), end_line)
455
+ if start >= len(lines):
456
+ return f"Error: File has only {len(lines)} lines."
457
+ selected = lines[start:end]
458
+ return "\n".join(f"{start + i + 1:>4} | {line}" for i, line in enumerate(selected))
459
+ except Exception as e:
460
+ return f"Error: {e}"
461
+
462
+ @tool
463
+ def regex_search(pattern: str, dir_path: str = ".", max_results: int = 30) -> str:
464
+ """Search files using a regex pattern. More powerful than keyword search.
465
+ Use for complex patterns like 'def .*async', 'TODO|FIXME|HACK', etc.
466
+ Returns matching lines with file paths and line numbers.
467
+ """
468
+ import re as re_mod
469
+ full_path = _project_dir if dir_path == "." else (_project_dir / dir_path)
470
+ full_path = full_path.resolve()
471
+ if not str(full_path).startswith(str(_project_dir.resolve())):
472
+ return "Error: Path traversal not allowed."
473
+
474
+ try:
475
+ compiled = re_mod.compile(pattern, re_mod.IGNORECASE)
476
+ except re_mod.error as e:
477
+ return f"Invalid regex: {e}"
478
+
479
+ results = []
480
+ try:
481
+ for root, _, files in os.walk(full_path):
482
+ if any(x in root for x in [".git", "node_modules", "__pycache__", "build", "dist", "target", ".dart_tool"]):
483
+ continue
484
+ for file in files:
485
+ if file.endswith(searchable_exts):
486
+ fpath = Path(root) / file
487
+ try:
488
+ for i, line in enumerate(fpath.read_text("utf-8", "replace").splitlines()):
489
+ if compiled.search(line):
490
+ rel_path = fpath.relative_to(_project_dir)
491
+ results.append(f"{rel_path}:{i+1}: {line.strip()[:120]}")
492
+ if len(results) >= max_results:
493
+ return f"Found {len(results)}+ matches (truncated):\n" + "\n".join(results)
494
+ except Exception:
495
+ continue
496
+ if not results:
497
+ return f"No matches for pattern '{pattern}' in {dir_path}"
498
+ return f"Found {len(results)} matches:\n" + "\n".join(results)
499
+ except Exception as e:
500
+ return f"Error: {e}"
501
+
502
+ @tool
503
+ def find_references(symbol_name: str, dir_path: str = ".") -> str:
504
+ """Find all usages/references of a symbol (function, class, variable) across the codebase.
505
+ Unlike find_definition which finds where something is declared, this finds where it's used.
506
+ Returns file paths, line numbers, and the line content.
507
+ """
508
+ full_path = _project_dir if dir_path == "." else (_project_dir / dir_path)
509
+ full_path = full_path.resolve()
510
+ if not str(full_path).startswith(str(_project_dir.resolve())):
511
+ return "Error: Path traversal not allowed."
512
+
513
+ results = []
514
+ try:
515
+ for root, _, files in os.walk(full_path):
516
+ if any(x in root for x in [".git", "node_modules", "__pycache__", "build", "dist", "target", ".dart_tool"]):
517
+ continue
518
+ for file in files:
519
+ if file.endswith(searchable_exts):
520
+ fpath = Path(root) / file
521
+ try:
522
+ for i, line in enumerate(fpath.read_text("utf-8", "replace").splitlines()):
523
+ if symbol_name in line:
524
+ rel_path = fpath.relative_to(_project_dir)
525
+ results.append(f"{rel_path}:{i+1}: {line.strip()[:120]}")
526
+ if len(results) >= 40:
527
+ return f"Found {len(results)}+ references (truncated):\n" + "\n".join(results)
528
+ except Exception:
529
+ continue
530
+ if not results:
531
+ return f"No references found for '{symbol_name}'"
532
+ return f"Found {len(results)} references:\n" + "\n".join(results)
533
+ except Exception as e:
534
+ return f"Error: {e}"
535
+
536
+ @tool
537
+ def git_diff(ref: str = "HEAD~1", filepath: str = "") -> str:
538
+ """Show git diff for recent changes. Use to understand what changed recently.
539
+ ref: git reference like 'HEAD~1', 'HEAD~3', a branch name, or commit hash.
540
+ filepath: optional — limit diff to a specific file.
541
+ """
542
+ cmd = ["git", "diff", "--stat", "-p", ref]
543
+ if filepath:
544
+ full_path = (_project_dir / filepath).resolve()
545
+ if not str(full_path).startswith(str(_project_dir.resolve())):
546
+ return "Error: Path traversal not allowed."
547
+ cmd.extend(["--", filepath])
548
+ try:
549
+ result = subprocess.run(
550
+ cmd, cwd=str(_project_dir),
551
+ capture_output=True, text=True, timeout=15,
552
+ )
553
+ if result.returncode != 0:
554
+ return f"Git error: {result.stderr.strip()}"
555
+ output = result.stdout.strip()
556
+ if not output:
557
+ return "No changes found."
558
+ if len(output) > 15000:
559
+ return output[:15000] + "\n\n... [Diff truncated] ..."
560
+ return output
561
+ except Exception as e:
562
+ return f"Error: {e}"
563
+
564
+ @tool
565
+ def git_log_search(search_term: str, max_commits: int = 10) -> str:
566
+ """Search git commit messages and diffs for a term.
567
+ Use to find when a feature was added, a bug was introduced, or a file was changed.
568
+ Searches both commit messages (-grep) and code changes (-S).
569
+ """
570
+ results = []
571
+ try:
572
+ msg_result = subprocess.run(
573
+ ["git", "log", f"-{max_commits}", "--pretty=format:%h %ai %s", f"--grep={search_term}", "-i"],
574
+ cwd=str(_project_dir),
575
+ capture_output=True, text=True, timeout=10,
576
+ )
577
+ if msg_result.stdout.strip():
578
+ results.append("Commits mentioning '" + search_term + "':\n" + msg_result.stdout.strip())
579
+ except Exception:
580
+ pass
581
+
582
+ try:
583
+ code_result = subprocess.run(
584
+ ["git", "log", f"-{max_commits}", "--pretty=format:%h %ai %s", f"-S{search_term}"],
585
+ cwd=str(_project_dir),
586
+ capture_output=True, text=True, timeout=10,
587
+ )
588
+ if code_result.stdout.strip():
589
+ results.append("Commits changing code with '" + search_term + "':\n" + code_result.stdout.strip())
590
+ except Exception:
591
+ pass
592
+
593
+ if not results:
594
+ return f"No commits found related to '{search_term}'"
595
+ return "\n\n".join(results)
596
+
597
+ @tool
598
+ def file_stats(filepath: str) -> str:
599
+ """Get stats about a file: line count, size, last modified, language.
600
+ Use to quickly assess file complexity and recency.
601
+ """
602
+ full_path = (_project_dir / filepath).resolve()
603
+ if not str(full_path).startswith(str(_project_dir.resolve())):
604
+ return "Error: Path traversal not allowed."
605
+ if not full_path.is_file():
606
+ return f"Error: File '{filepath}' not found."
607
+ try:
608
+ content = full_path.read_text(encoding="utf-8", errors="replace")
609
+ lines = content.splitlines()
610
+ non_blank = sum(1 for l in lines if l.strip())
611
+ stat = full_path.stat()
612
+ from datetime import datetime
613
+ modified = datetime.fromtimestamp(stat.st_mtime).strftime("%Y-%m-%d %H:%M")
614
+ size_kb = stat.st_size / 1024
615
+
616
+ parts = [
617
+ f"File: {filepath}",
618
+ f"Lines: {len(lines)} ({non_blank} non-blank)",
619
+ f"Size: {size_kb:.1f} KB",
620
+ f"Last modified: {modified}",
621
+ f"Extension: {full_path.suffix}",
622
+ ]
623
+
624
+ import_count = sum(1 for l in lines if l.strip().startswith(("import ", "from ", "#include", "use ", "require")))
625
+ class_count = sum(1 for l in lines if any(l.strip().startswith(k) for k in ("class ", "struct ", "enum ", "interface ", "trait ")))
626
+ func_count = sum(1 for l in lines if any(l.strip().startswith(k) for k in ("def ", "func ", "fn ", "function ", "async def ", "pub fn ")))
627
+
628
+ if import_count: parts.append(f"Imports: {import_count}")
629
+ if class_count: parts.append(f"Classes/structs: {class_count}")
630
+ if func_count: parts.append(f"Functions: {func_count}")
631
+
632
+ return "\n".join(parts)
633
+ except Exception as e:
634
+ return f"Error: {e}"
635
+
636
+ @tool
637
+ def find_files_by_name(filename_pattern: str) -> str:
638
+ """Find files whose name matches a pattern (case-insensitive substring match).
639
+ Use when you know part of a filename but not its full path.
640
+ Example: 'auth' finds auth_service.py, AuthController.java, etc.
641
+ """
642
+ pattern_lower = filename_pattern.lower()
643
+ results = []
644
+ try:
645
+ for root, dirs, files in os.walk(_project_dir):
646
+ dirs[:] = [d for d in dirs if d not in {
647
+ ".git", "node_modules", "__pycache__", "build", "dist",
648
+ "target", ".dart_tool", ".next", "venv", ".venv",
649
+ }]
650
+ for file in files:
651
+ if pattern_lower in file.lower():
652
+ fpath = Path(root) / file
653
+ rel_path = fpath.relative_to(_project_dir)
654
+ results.append(str(rel_path))
655
+ if len(results) >= 30:
656
+ return f"Found {len(results)}+ files (truncated):\n" + "\n".join(results)
657
+ if not results:
658
+ return f"No files matching '{filename_pattern}'"
659
+ return f"Found {len(results)} files:\n" + "\n".join(results)
660
+ except Exception as e:
661
+ return f"Error: {e}"
662
+
663
+ @tool
664
+ def get_call_chain(filepath: str, function_name: str) -> str:
665
+ """Trace who calls a function and what it calls.
666
+ Returns callers (files that reference this function) and callees
667
+ (functions/methods invoked inside the given function's body).
668
+ """
669
+ full_path = (_project_dir / filepath).resolve()
670
+ if not str(full_path).startswith(str(_project_dir.resolve())):
671
+ return "Error: Path traversal not allowed."
672
+ if not full_path.is_file():
673
+ return f"Error: File '{filepath}' not found."
674
+
675
+ parts = []
676
+
677
+ importers = graph.get_importers(filepath)
678
+ callers = []
679
+ for root, _, files in os.walk(_project_dir):
680
+ if any(x in root for x in [".git", "node_modules", "__pycache__", "build", "dist", "target"]):
681
+ continue
682
+ for file in files:
683
+ if file.endswith(searchable_exts):
684
+ fpath = Path(root) / file
685
+ if fpath.resolve() == full_path:
686
+ continue
687
+ try:
688
+ content = fpath.read_text("utf-8", "replace")
689
+ if function_name in content:
690
+ rel = fpath.relative_to(_project_dir)
691
+ for i, line in enumerate(content.splitlines()):
692
+ if function_name in line:
693
+ callers.append(f" {rel}:{i+1}: {line.strip()[:100]}")
694
+ break
695
+ except Exception:
696
+ continue
697
+ if len(callers) >= 15:
698
+ break
699
+
700
+ if callers:
701
+ parts.append(f"Callers of {function_name}:\n" + "\n".join(callers))
702
+ else:
703
+ parts.append(f"No callers found for {function_name}")
704
+
705
+ try:
706
+ import re as re_mod
707
+ content = full_path.read_text("utf-8", "replace")
708
+ lines = content.splitlines()
709
+
710
+ func_start = -1
711
+ for i, line in enumerate(lines):
712
+ if function_name in line and any(k in line for k in ("def ", "func ", "fn ", "function ", "void ", "class ")):
713
+ func_start = i
714
+ break
715
+
716
+ if func_start >= 0:
717
+ body = "\n".join(lines[func_start:func_start + 50])
718
+ callees = set(re_mod.findall(r'\b([a-zA-Z_]\w+)\s*\(', body))
719
+ keywords = {"if", "for", "while", "switch", "catch", "return", "print", "throw", function_name}
720
+ callees -= keywords
721
+ if callees:
722
+ parts.append(f"Functions called by {function_name}:\n " + ", ".join(sorted(callees)))
723
+ except Exception:
724
+ pass
725
+
726
+ return "\n\n".join(parts) if parts else f"Could not trace call chain for {function_name}"
727
+
728
    @tool
    def semantic_code_search(query: str, top_k: int = 6) -> str:
        """Search the codebase using semantic similarity (vector embeddings).
        Unlike keyword search, this finds code by MEANING — even if the exact
        words aren't present. Use for conceptual queries like:
        - 'error handling logic'
        - 'user authentication flow'
        - 'database connection setup'
        Returns the most semantically relevant code chunks with file paths.
        """
        if not collection:
            return "Error: Vector database not available."
        try:
            # Embed the query with the same model used at index time, then
            # run nearest-neighbour lookup against the collection.
            query_vector = embed_text(query)
            results = query_similar(collection, query_vector, n_results=top_k)
            if not results:
                return f"No semantically relevant code found for: '{query}'"
            parts = []
            for r in results:
                score = r.get('relevance', 0)
                fname = r.get('filename', '?')
                loc = r.get('location', '')
                code = r.get('code', '')[:500]  # cap snippet size per chunk
                feature = r.get('feature', '')
                layer = r.get('layer', '')
                meta = ""
                # Only surface feature/layer tags when they carry information.
                if feature and feature != "unknown":
                    meta += f" [feature: {feature}]"
                if layer and layer != "unknown":
                    meta += f" [layer: {layer}]"
                parts.append(f"--- {fname} ({loc}) [score: {score:.2f}]{meta}\n{code}")
            return f"Found {len(results)} semantically similar chunks:\n\n" + "\n\n".join(parts)
        except Exception as e:
            return f"Error in semantic search: {e}"
762
+
763
+ @tool
764
+ def find_similar_code(filepath: str, chunk_index: int = 0, top_k: int = 5) -> str:
765
+ """Find code chunks that are semantically similar to a specific file/chunk.
766
+ Use this to discover related implementations, duplicated logic, or code
767
+ that serves a similar purpose elsewhere in the codebase.
768
+ Provide the file path and optionally a chunk index (default 0 = first chunk).
769
+ """
770
+ if not collection:
771
+ return "Error: Vector database not available."
772
+ try:
773
+ full_path = (_project_dir / filepath).resolve()
774
+ if not str(full_path).startswith(str(_project_dir.resolve())):
775
+ return "Error: Path traversal not allowed."
776
+ if not full_path.is_file():
777
+ return f"Error: File '{filepath}' not found."
778
+
779
+ content = full_path.read_text(encoding="utf-8", errors="replace")
780
+ query_text = content[:800]
781
+ query_vector = embed_text(query_text)
782
+
783
+ results = query_similar(collection, query_vector, n_results=top_k + 2)
784
+ results = [r for r in results if r.get('filename') != filepath][:top_k]
785
+
786
+ if not results:
787
+ return f"No similar code found to '{filepath}'"
788
+ parts = []
789
+ for r in results:
790
+ score = r.get('relevance', 0)
791
+ fname = r.get('filename', '?')
792
+ code = r.get('code', '')[:400]
793
+ parts.append(f"--- {fname} [similarity: {score:.2f}]\n{code}")
794
+ return f"Files similar to '{filepath}':\n\n" + "\n\n".join(parts)
795
+ except Exception as e:
796
+ return f"Error: {e}"
797
+
798
    # The complete toolbox handed to the agent: file access, line/regex/keyword
    # search, structure and call-chain exploration, git history, and the
    # vector-store semantic helpers.
    return [
        read_local_file,
        read_lines,
        list_directory,
        find_file_usages,
        search_codebase_keywords,
        regex_search,
        find_definition,
        find_references,
        find_files_by_name,
        get_file_structure,
        get_call_chain,
        file_stats,
        semantic_code_search,
        find_similar_code,
        git_file_history,
        git_blame_lines,
        git_diff,
        git_log_search,
    ]
818
+
819
+
820
+ # ---------------------------------------------------------------------------
821
+ # Streaming RAG response
822
+ # ---------------------------------------------------------------------------
823
+
824
+
825
async def stream_rag_response(
    collection: Any,
    graph: ImportGraph,
    query: str,
    mode: str,
    model: str = "gemini-2.5-flash",
    history: list[dict[str, str]] | None = None,
    profile: LanguageProfile | None = None,
    project_dir: Path | None = None,
    gemini_api_key: str = "",
) -> AsyncGenerator[dict[str, Any], None]:
    """Retrieve mode-specific context and stream a Gemini answer as events.

    This is the main agentic RAG loop: it retrieves code chunks for the
    query, assembles a prompt, then streams a tool-calling Gemini
    conversation, yielding UI-ready event dicts as it goes.

    Every yielded value is a dict with a "type" key:
        * {"type": "status", "stage": ...} -- progress markers; stages seen
          here are "retrieving", "context_found" (plus chunk/file counts),
          "thinking", "generating", "tool_call" (plus tool name and an
          args summary), and "tool_done".
        * {"type": "token", "content": str} -- one streamed text fragment.
        * {"type": "done", "sources": [...], "stats": ...} -- terminal
          success event carrying the retrieved-source list and prompt stats.
        * {"type": "error", "content": str} -- terminal failure (retrieval
          or Gemini API error); the generator returns after yielding it.

    Args:
        collection: Vector-store collection handle (opaque to this function;
            passed through to retrieval and the agent tools).
        graph: Import graph used by retrieval and the agent tools.
        query: The user's question.
        mode: Retrieval mode name forwarded to retrieve_for_mode().
        model: Gemini model identifier.
        history: Prior chat turns as {"role": ..., "content": ...} dicts;
            defaults to an empty history.
        profile: Optional language profile; only its highlight_language is
            read here (falls back to "text").
        project_dir: Project root used for CLAUDE.md loading and tool
            sandboxing; defaults to the current working directory.
        gemini_api_key: API key forwarded to ChatGoogleGenerativeAI.
    """
    if history is None:
        history = []

    _project_dir = project_dir or Path.cwd()
    highlight_language = profile.highlight_language if profile else "text"

    yield {"type": "status", "stage": "retrieving"}

    try:
        chunks = await retrieve_for_mode(collection, graph, query, mode, profile=profile)
    except Exception as exc:
        # Retrieval failure is terminal: surface it to the caller and stop.
        yield {"type": "error", "content": f"Retrieval failed: {exc}"}
        return

    # Count distinct filenames for the progress event (chunks may repeat files).
    unique_files = len(set(c.get("filename", "") for c in chunks))
    yield {"type": "status", "stage": "context_found", "chunks": len(chunks), "files": unique_files}

    claude_md, memory_content = _load_static_context(_project_dir)

    # assemble_prompt returns a dict with "system", "history", "context",
    # and "stats" keys (consumed below).
    prompt = assemble_prompt(
        mode=mode,
        query=query,
        chunks=chunks,
        claude_md=claude_md,
        memory_content=memory_content,
        history=history,
        model=model,
        highlight_language=highlight_language,
    )

    messages: list[Any] = [SystemMessage(content=prompt["system"])]

    # Replay prior turns as LangChain messages; anything that is not an
    # assistant turn is treated as a human turn.
    for turn in prompt["history"]:
        role = turn.get("role", "user")
        content = turn.get("content", "")
        if role == "assistant":
            messages.append(AIMessage(content=content))
        else:
            messages.append(HumanMessage(content=content))

    # Inline the retrieved code context ahead of the user query in the final
    # human message (only when retrieval produced anything).
    context_block = prompt["context"]
    if context_block:
        human_content = f"## Retrieved Code Context\n\n{context_block}\n\n---\n\n{query}"
    else:
        human_content = query

    messages.append(HumanMessage(content=human_content))

    # Source attributions for the terminal "done" event; "semantic" is the
    # default priority label when a chunk carries none.
    sources = [
        {"filename": c.get("filename", ""), "priority": c.get("priority", "semantic")}
        for c in chunks
    ]

    try:
        tools = _get_tools(graph, profile, collection, _project_dir)
        llm = ChatGoogleGenerativeAI(
            model=model,
            google_api_key=gemini_api_key,
            streaming=True,
        )
        llm_with_tools = llm.bind_tools(tools)

        yield {"type": "status", "stage": "thinking"}

        # Agent loop: each step is one streamed model turn; the loop repeats
        # while the model keeps requesting tools, capped at MAX_STEPS turns
        # to prevent runaway tool-calling.
        MAX_STEPS = 5
        step = 0
        first_token = True  # gates the one-time "generating" status per turn

        while step < MAX_STEPS:
            step += 1
            has_tool_calls = False
            full_chunk_msg = None  # accumulates the full AI message for this turn

            async for chunk in llm_with_tools.astream(messages):
                # LangChain message chunks support "+" merging; fold every
                # streamed piece into one message so tool_calls are complete.
                if full_chunk_msg is None:
                    full_chunk_msg = chunk
                else:
                    full_chunk_msg += chunk

                if chunk.content:
                    content_val = chunk.content
                    # Content may arrive as a list of typed blocks; keep only
                    # the text blocks and join them.
                    if isinstance(content_val, list):
                        text_parts = [
                            b.get("text", "") for b in content_val
                            if isinstance(b, dict) and b.get("type") == "text"
                        ]
                        content_str = "".join(text_parts)
                    else:
                        content_str = str(content_val)

                    if content_str:
                        # Emit "generating" once per turn, right before the
                        # first visible token.
                        if first_token:
                            yield {"type": "status", "stage": "generating"}
                            first_token = False
                        yield {"type": "token", "content": content_str}

            if full_chunk_msg:
                # The accumulated AI message must be appended before any
                # ToolMessage replies so the transcript stays well-formed.
                messages.append(full_chunk_msg)

                if full_chunk_msg.tool_calls:
                    has_tool_calls = True
                    for tool_call in full_chunk_msg.tool_calls:
                        tool_name = tool_call["name"]
                        args = tool_call["args"]
                        tool_id = tool_call["id"]

                        # Truncated repr of each arg for the status event only.
                        args_summary = ", ".join(f"{k}={repr(v)[:40]}" for k, v in args.items()) if args else ""
                        yield {"type": "status", "stage": "tool_call", "tool": tool_name, "args": args_summary}

                        selected_tool = next((t for t in tools if t.name == tool_name), None)
                        if selected_tool:
                            try:
                                # Tools are sync; run off the event loop so
                                # streaming stays responsive.
                                result = await asyncio.to_thread(selected_tool.invoke, args)
                                result_str = str(result)
                            except Exception as e:
                                # Tool errors are reported back to the model
                                # rather than aborting the stream.
                                result_str = f"Error executing tool: {e}"
                        else:
                            result_str = f"Error: Tool {tool_name} not found."

                        yield {"type": "status", "stage": "tool_done", "tool": tool_name}
                        messages.append(ToolMessage(content=result_str, tool_call_id=tool_id))

                    # Next turn gets its own "generating" status.
                    first_token = True

            # No tool requests means the model produced its final answer.
            if not has_tool_calls:
                break

        yield {"type": "done", "sources": sources, "stats": prompt["stats"]}

    except Exception as exc:
        # Boundary handler: any model/tool-binding failure becomes a terminal
        # error event instead of an unhandled exception in the consumer.
        yield {"type": "error", "content": f"Gemini API error: {exc}"}