sourcefire 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sourcefire/__init__.py +0 -0
- sourcefire/api/__init__.py +0 -0
- sourcefire/api/models.py +24 -0
- sourcefire/api/routes.py +166 -0
- sourcefire/chain/__init__.py +0 -0
- sourcefire/chain/prompts.py +195 -0
- sourcefire/chain/rag_chain.py +967 -0
- sourcefire/cli.py +293 -0
- sourcefire/config.py +148 -0
- sourcefire/db.py +196 -0
- sourcefire/indexer/__init__.py +0 -0
- sourcefire/indexer/embeddings.py +27 -0
- sourcefire/indexer/language_profiles.py +448 -0
- sourcefire/indexer/metadata.py +289 -0
- sourcefire/indexer/pipeline.py +406 -0
- sourcefire/init.py +189 -0
- sourcefire/prompts/system.md +28 -0
- sourcefire/retriever/__init__.py +0 -0
- sourcefire/retriever/graph.py +162 -0
- sourcefire/retriever/search.py +86 -0
- sourcefire/static/.DS_Store +0 -0
- sourcefire/static/app.js +414 -0
- sourcefire/static/index.html +102 -0
- sourcefire/static/styles.css +607 -0
- sourcefire/watcher.py +105 -0
- sourcefire-0.2.0.dist-info/METADATA +145 -0
- sourcefire-0.2.0.dist-info/RECORD +31 -0
- sourcefire-0.2.0.dist-info/WHEEL +5 -0
- sourcefire-0.2.0.dist-info/entry_points.txt +2 -0
- sourcefire-0.2.0.dist-info/licenses/LICENSE +21 -0
- sourcefire-0.2.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,967 @@
|
|
|
1
|
+
"""LangChain RAG chain with mode-aware retrieval and Gemini API streaming."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import asyncio
|
|
6
|
+
import os
|
|
7
|
+
import subprocess
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
from typing import Any, AsyncGenerator
|
|
10
|
+
|
|
11
|
+
from langchain_google_genai import ChatGoogleGenerativeAI
|
|
12
|
+
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, ToolMessage
|
|
13
|
+
from langchain_core.tools import tool
|
|
14
|
+
|
|
15
|
+
from sourcefire.indexer.embeddings import embed_text
|
|
16
|
+
from sourcefire.indexer.language_profiles import LanguageProfile
|
|
17
|
+
from sourcefire.retriever.search import semantic_search, get_chunks_by_filenames, parse_file_references
|
|
18
|
+
from sourcefire.retriever.graph import ImportGraph
|
|
19
|
+
from sourcefire.chain.prompts import assemble_prompt
|
|
20
|
+
from sourcefire.db import query_similar
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
# ---------------------------------------------------------------------------
|
|
24
|
+
# Static context loader
|
|
25
|
+
# ---------------------------------------------------------------------------
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def _load_static_context(project_dir: Path) -> tuple[str, str]:
|
|
29
|
+
"""Load CLAUDE.md from project_dir.
|
|
30
|
+
|
|
31
|
+
Returns:
|
|
32
|
+
A 2-tuple of (claude_md_content, memory_content).
|
|
33
|
+
"""
|
|
34
|
+
claude_md = ""
|
|
35
|
+
claude_md_path = project_dir / "CLAUDE.md"
|
|
36
|
+
if claude_md_path.is_file():
|
|
37
|
+
try:
|
|
38
|
+
claude_md = claude_md_path.read_text(encoding="utf-8", errors="replace")
|
|
39
|
+
except OSError:
|
|
40
|
+
pass
|
|
41
|
+
|
|
42
|
+
return claude_md, ""
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
# ---------------------------------------------------------------------------
|
|
46
|
+
# Mode-specific retrievers
|
|
47
|
+
# ---------------------------------------------------------------------------
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
async def _retrieve_debug(
    collection: Any,
    graph: ImportGraph,
    query: str,
    query_vector: list[float],
    top_k: int,
    profile: LanguageProfile | None = None,
) -> list[dict[str, Any]]:
    """Debug mode: parse stack trace -> direct lookup -> graph expansion -> semantic.

    Args:
        collection: Vector-store handle passed through to the search helpers.
        graph: Import graph used for one-hop neighbour expansion.
        query: User query, possibly containing a stack trace / file references.
        query_vector: Embedding of ``query``.
        top_k: Number of semantic hits to request.
        profile: Optional language profile supplying file-reference patterns.

    Returns:
        Chunks in priority order. Each is tagged with ``priority``:
        ``"direct"`` (file named in the query, default relevance 1.0),
        ``"graph"`` (one-hop import neighbour of a direct file, default
        relevance 0.6), then ``"semantic"`` (from files not already included).
    """
    chunks: list[dict[str, Any]] = []
    seen_filenames: set[str] = set()

    file_ref_patterns = profile.file_ref_patterns if profile else None
    file_refs = parse_file_references(query, file_ref_patterns)
    # A stack trace typically names the same file on several frames; dedupe
    # (keeping first-seen order) so each file is fetched and expanded only once.
    direct_filenames = list(dict.fromkeys(ref["file"] for ref in file_refs))

    if direct_filenames:
        for c in await get_chunks_by_filenames(collection, direct_filenames):
            c["priority"] = "direct"
            c.setdefault("relevance", 1.0)
            chunks.append(c)
            seen_filenames.add(c["filename"])

    # One-hop neighbours of the direct files, minus files already included.
    # Neighbour lists of different files overlap, so dedupe here as well.
    graph_filenames = list(dict.fromkeys(
        neighbor
        for fname in direct_filenames
        for neighbor in graph.get_neighbors(fname, hops=1)
        if neighbor not in seen_filenames
    ))

    if graph_filenames:
        for c in await get_chunks_by_filenames(collection, graph_filenames):
            c["priority"] = "graph"
            c.setdefault("relevance", 0.6)
            chunks.append(c)
            seen_filenames.add(c["filename"])

    # Semantic hits fill in the rest; at most one chunk per not-yet-seen file.
    semantic_chunks = await semantic_search(collection, query_vector, top_k=top_k)
    for c in semantic_chunks:
        if c["filename"] not in seen_filenames:
            c["priority"] = "semantic"
            chunks.append(c)
            seen_filenames.add(c["filename"])

    return chunks
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
async def _retrieve_feature(
    collection: Any,
    graph: ImportGraph,
    query: str,
    query_vector: list[float],
    top_k: int,
    profile: LanguageProfile | None = None,
) -> list[dict[str, Any]]:
    """Feature mode: semantic search -> best feature -> retrieve feature chunks.

    A first semantic search of ``top_k`` chunks is grouped by each chunk's
    ``feature`` value (missing/empty -> ``"core"``). The feature with the
    highest mean relevance wins, and a second semantic search restricted to
    that feature (up to 15 chunks) is returned.

    ``graph``, ``query`` and ``profile`` are unused; they keep the signature
    uniform with the other mode retrievers.

    Returns:
        Chunks tagged ``priority="semantic"``. If the feature-restricted
        search yields nothing, the seed chunks are returned instead.
    """
    _FEATURE_CAP = 15

    seed_chunks = await semantic_search(collection, query_vector, top_k=top_k)
    if not seed_chunks:
        return []

    feature_scores: dict[str, list[float]] = {}
    for c in seed_chunks:
        feat = c.get("feature") or "core"
        # ``relevance`` may be missing or None; both count as 0.0.
        feature_scores.setdefault(feat, []).append(float(c.get("relevance") or 0.0))

    best_feature = max(feature_scores, key=lambda f: sum(feature_scores[f]) / len(feature_scores[f]))

    feature_chunks = await semantic_search(
        collection,
        query_vector,
        top_k=_FEATURE_CAP,
        feature=best_feature,
    )
    if not feature_chunks:
        # The feature filter matched nothing (for example when no stored chunk
        # carries the fallback "core" label); keep the seed hits rather than
        # returning an empty context.
        feature_chunks = seed_chunks

    for c in feature_chunks:
        c["priority"] = "semantic"

    return feature_chunks
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
async def _retrieve_explain(
    collection: Any,
    graph: ImportGraph,
    query: str,
    query_vector: list[float],
    top_k: int,
    profile: LanguageProfile | None = None,
) -> list[dict[str, Any]]:
    """Explain mode: semantic seeds followed by their one-hop import neighbours.

    Every semantic hit is kept (``priority="semantic"``). Files adjacent to
    any seed file in the import graph that are not already represented are
    then fetched whole and appended with ``priority="graph"`` and a default
    relevance of 0.5.

    ``query`` and ``profile`` are unused; they keep the signature uniform with
    the other mode retrievers.
    """
    seeds = await semantic_search(collection, query_vector, top_k=top_k)

    results: list[dict[str, Any]] = []
    covered: set[str] = set()
    for seed in seeds:
        seed["priority"] = "semantic"
        results.append(seed)
        covered.add(seed["filename"])

    # Collect each new neighbour once, in discovery order.
    expansion: list[str] = []
    for seed in seeds:
        for neighbor in graph.get_neighbors(seed["filename"], hops=1):
            if neighbor in covered:
                continue
            covered.add(neighbor)
            expansion.append(neighbor)

    if not expansion:
        return results

    for extra in await get_chunks_by_filenames(collection, expansion):
        extra["priority"] = "graph"
        extra.setdefault("relevance", 0.5)
        results.append(extra)

    return results
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
# ---------------------------------------------------------------------------
|
|
171
|
+
# Public retrieval entry point
|
|
172
|
+
# ---------------------------------------------------------------------------
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
async def retrieve_for_mode(
    collection: Any,
    graph: ImportGraph,
    query: str,
    mode: str,
    top_k: int = 8,
    profile: LanguageProfile | None = None,
) -> list[dict[str, Any]]:
    """Embed *query* and dispatch to the mode-specific retriever.

    Args:
        collection: Vector-store handle forwarded to the retriever.
        graph: Import graph forwarded to the retriever.
        query: Raw user query.
        mode: ``"debug"`` or ``"feature"``; any other value uses explain mode.
        top_k: Semantic-search width passed to the retriever.
        profile: Optional language profile forwarded to the retriever.

    Returns:
        The chunk list produced by the selected retriever.
    """
    # embed_text is synchronous, so run it in the default executor to keep the
    # event loop free. get_running_loop() is the supported call from inside a
    # coroutine (get_event_loop() is discouraged there).
    loop = asyncio.get_running_loop()
    query_vector: list[float] = await loop.run_in_executor(None, embed_text, query)

    retrievers = {
        "debug": _retrieve_debug,
        "feature": _retrieve_feature,
    }
    retriever = retrievers.get(mode, _retrieve_explain)
    return await retriever(collection, graph, query, query_vector, top_k, profile)
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
# ---------------------------------------------------------------------------
|
|
196
|
+
# LangChain Tools
|
|
197
|
+
# ---------------------------------------------------------------------------
|
|
198
|
+
|
|
199
|
+
def _get_tools(
|
|
200
|
+
graph: ImportGraph,
|
|
201
|
+
profile: LanguageProfile | None = None,
|
|
202
|
+
collection: Any = None,
|
|
203
|
+
project_dir: Path | None = None,
|
|
204
|
+
) -> list[Any]:
|
|
205
|
+
|
|
206
|
+
_project_dir = project_dir or Path.cwd()
|
|
207
|
+
searchable_exts = tuple(profile.searchable_extensions) if profile else (".py", ".js", ".ts", ".go", ".rs", ".java", ".dart", ".yaml", ".json", ".md")
|
|
208
|
+
|
|
209
|
+
# LangChain's @tool exposes the docstring below to the model as the tool
# description, so its wording is kept verbatim; notes live in comments.
@tool
def read_local_file(filepath: str) -> str:
    """Reads the complete content of a file from the repository.
    Use when you need to see exact implementation details.
    Provide the relative filepath (e.g. 'src/main.py').
    """
    root = _project_dir.resolve()
    full_path = (root / filepath).resolve()
    # Check containment by path components: a plain string-prefix test would
    # also accept sibling directories such as "<root>-secrets/...".
    if full_path != root and root not in full_path.parents:
        return "Error: Path traversal not allowed."
    if not full_path.is_file():
        return f"Error: File '{filepath}' not found in the codebase."
    try:
        content = full_path.read_text(encoding="utf-8", errors="replace")
        # Cap the returned text at 30,000 characters.
        if len(content) > 30000:
            return content[:30000] + "\n\n... [File truncated because it is too large] ..."
        return content
    except Exception as e:
        return f"Error reading file: {e}"
|
|
227
|
+
|
|
228
|
+
# LangChain's @tool exposes the docstring below to the model as the tool
# description, so its wording is kept verbatim; notes live in comments.
@tool
def list_directory(dir_path: str) -> str:
    """Lists files and folders within a specific directory in the repository.
    Use to understand project structure or find files.
    Provide the relative path (e.g. 'src/components'). Use '.' for root.
    """
    root = _project_dir.resolve()
    full_path = root if dir_path == "." else (root / dir_path).resolve()

    # Component-wise containment check; a string-prefix test would accept
    # sibling directories such as "<root>-secrets".
    if full_path != root and root not in full_path.parents:
        return "Error: Path traversal not allowed."
    if not full_path.is_dir():
        return f"Error: Directory '{dir_path}' not found."

    try:
        # Directories get a trailing "/" so the model can tell them apart.
        items = sorted(
            item.name + ("/" if item.is_dir() else "")
            for item in full_path.iterdir()
        )
        return f"Contents of '{dir_path}':\n" + "\n".join(items)
    except Exception as e:
        return f"Error listing directory: {e}"
|
|
252
|
+
|
|
253
|
+
# LangChain's @tool exposes the docstring below to the model as the tool
# description, so its wording is kept verbatim.
@tool
def find_file_usages(filepath: str) -> str:
    """Find other files that import or depend on the given file.
    Provide the relative filepath (e.g. 'src/utils/auth.py').
    Returns a list of files that directly import it.
    """
    # Reverse-dependency lookup in the import graph captured by _get_tools.
    dependents = graph.get_importers(filepath)
    if dependents:
        listing = "\n".join(f"- {name}" for name in dependents)
        return f"Files importing '{filepath}':\n" + listing
    return f"No direct local dependencies found importing {filepath}."
|
|
263
|
+
|
|
264
|
+
# LangChain's @tool exposes the docstring below to the model as the tool
# description, so its wording is kept verbatim; notes live in comments.
@tool
def search_codebase_keywords(query: str, dir_path: str = ".") -> str:
    """Search for an exact string or keyword across files in a directory.
    Returns files and line numbers where the keyword appears.
    Use for variable names, classes, function names, or patterns.
    """
    # Directory *names* to skip. Matching names (not substrings of the whole
    # path) avoids skipping a project that merely lives under e.g.
    # ".../distributed/..." or ".../build-tools/...".
    skip_dirs = {".git", ".claude", "node_modules", "__pycache__", "build", "dist", "target", ".dart_tool"}

    root = _project_dir.resolve()
    full_path = root if dir_path == "." else (root / dir_path).resolve()

    # Component-wise containment check (a string prefix would accept "<root>-x").
    if full_path != root and root not in full_path.parents:
        return "Error: Path traversal not allowed."

    if not full_path.is_dir() and not full_path.is_file():
        return f"Error: Directory '{dir_path}' not found."

    needle = query.lower()

    def _candidate_files():
        # A single file given as dir_path is searched directly; otherwise walk
        # the tree, pruning skipped directories, and keep searchable extensions.
        if full_path.is_file():
            yield full_path
            return
        for dirpath, dirnames, filenames in os.walk(full_path):
            dirnames[:] = [d for d in dirnames if d not in skip_dirs]
            for name in filenames:
                if name.endswith(searchable_exts):
                    yield Path(dirpath) / name

    results = []
    try:
        for fpath in _candidate_files():
            try:
                lines = fpath.read_text("utf-8", "replace").splitlines()
            except OSError:
                continue
            for i, line in enumerate(lines, start=1):
                if needle in line.lower():
                    # Relative to the *resolved* root: the walk starts from a
                    # resolved path, so an unresolved _project_dir would make
                    # relative_to() fail and silently drop every hit.
                    rel_path = fpath.relative_to(root)
                    results.append(f"{rel_path}:{i}: {line.strip()[:100]}")
                    if len(results) >= 50:
                        return "Results truncated at 50 matches.\n" + "\n".join(results)
        if not results:
            return f"No matches found for '{query}' in {dir_path}"
        return "\n".join(results)
    except Exception as e:
        return f"Error searching: {e}"
|
|
302
|
+
|
|
303
|
+
# LangChain's @tool exposes the docstring below to the model as the tool
# description, so its wording is kept verbatim; notes live in comments.
@tool
def find_definition(symbol_name: str) -> str:
    """Find where a class, function, or variable is defined in the codebase.
    Searches for definition patterns like 'class Foo', 'def foo', 'function foo', etc.
    Returns file paths and line numbers.
    """
    import re as re_mod

    keywords = (
        "class", "def", "async def", "function", "const", "let", "var", "type",
        "interface", "enum", "struct", "trait", "impl", "func", "mixin", "extension",
    )
    # "<keyword> <symbol>" at the start of the stripped line or after a space.
    # The (?!\w) guard stops 'def foo' from also matching 'def foobar'.
    definition_re = re_mod.compile(
        r"(?:^| )(?:" + "|".join(re_mod.escape(k) for k in keywords) + r") "
        + re_mod.escape(symbol_name) + r"(?!\w)"
    )
    # Directory names to prune (exact names, not substrings of the full path).
    skip_dirs = {".git", "node_modules", "__pycache__", "build", "dist", "target", ".dart_tool"}

    root = _project_dir.resolve()
    results = []
    try:
        for dirpath, dirnames, filenames in os.walk(root):
            dirnames[:] = [d for d in dirnames if d not in skip_dirs]
            for name in filenames:
                if not name.endswith(searchable_exts):
                    continue
                fpath = Path(dirpath) / name
                try:
                    lines = fpath.read_text("utf-8", "replace").splitlines()
                except OSError:
                    continue
                for i, line in enumerate(lines, start=1):
                    stripped = line.strip()
                    if definition_re.search(stripped):
                        rel_path = fpath.relative_to(root)
                        results.append(f"{rel_path}:{i}: {stripped[:120]}")
                        if len(results) >= 20:
                            return "\n".join(results) + "\n... (truncated)"
        if not results:
            return f"No definition found for '{symbol_name}'"
        return "\n".join(results)
    except Exception as e:
        return f"Error searching: {e}"
|
|
353
|
+
|
|
354
|
+
# LangChain's @tool exposes the docstring below to the model as the tool
# description, so its wording is kept verbatim; notes live in comments.
@tool
def get_file_structure(dir_path: str = ".", max_depth: int = 3) -> str:
    """Get a tree view of the project structure up to a given depth.
    Use this to understand the overall project layout.
    Provide relative path and max depth (default 3).
    """
    root = _project_dir.resolve()
    full_path = root if dir_path == "." else (root / dir_path).resolve()

    # Component-wise containment check (a string prefix would accept "<root>-x").
    if full_path != root and root not in full_path.parents:
        return "Error: Path traversal not allowed."
    if not full_path.is_dir():
        return f"Error: Directory '{dir_path}' not found."

    skip_dirs = {".git", "node_modules", "__pycache__", "build", "dist", "target",
                 ".dart_tool", ".next", "venv", ".venv", ".idea", ".vs"}
    max_lines = 200
    lines = []

    def _walk(path: Path, prefix: str, depth: int):
        # Files first (at most 20 per directory), then sub-directories recursively.
        if depth > max_depth:
            return
        try:
            entries = sorted(path.iterdir(), key=lambda p: (not p.is_dir(), p.name.lower()))
        except OSError:
            # Unreadable directory (permissions, vanished, ...): skip it.
            return
        dirs = [e for e in entries if e.is_dir() and e.name not in skip_dirs]
        files = [e for e in entries if e.is_file()]

        for f in files[:20]:
            lines.append(f"{prefix}{f.name}")
        if len(files) > 20:
            lines.append(f"{prefix}... ({len(files) - 20} more files)")

        for d in dirs:
            lines.append(f"{prefix}{d.name}/")
            _walk(d, prefix + " ", depth + 1)

    lines.append(f"{dir_path}/")
    _walk(full_path, " ", 1)
    output = "\n".join(lines[:max_lines])
    if len(lines) > max_lines:
        # Make the cut visible instead of silently dropping entries.
        output += f"\n... ({len(lines) - max_lines} more lines truncated)"
    return output
|
|
396
|
+
|
|
397
|
+
# LangChain's @tool exposes the docstring below to the model as the tool
# description, so its wording is kept verbatim; notes live in comments.
@tool
def git_file_history(filepath: str, max_commits: int = 10) -> str:
    """Get recent git commit history for a specific file.
    Shows who changed it, when, and why. Use to understand evolution of a file.
    Provide relative filepath.
    """
    root = _project_dir.resolve()
    full_path = (root / filepath).resolve()
    # Component-wise containment check (a string prefix would accept "<root>-x").
    if full_path != root and root not in full_path.parents:
        return "Error: Path traversal not allowed."
    try:
        # "--" keeps filepath from being parsed as a revision or option.
        result = subprocess.run(
            ["git", "log", f"-{max_commits}", "--pretty=format:%h %ai %s", "--", filepath],
            cwd=str(_project_dir),
            capture_output=True,
            text=True,
            timeout=10,
        )
        if result.returncode != 0:
            return f"Git error: {result.stderr.strip()}"
        if not result.stdout.strip():
            return f"No git history found for '{filepath}'"
        return f"Recent commits for '{filepath}':\n{result.stdout}"
    except Exception as e:
        return f"Error: {e}"
|
|
419
|
+
|
|
420
|
+
# LangChain's @tool exposes the docstring below to the model as the tool
# description, so its wording is kept verbatim; notes live in comments.
@tool
def git_blame_lines(filepath: str, start_line: int, end_line: int) -> str:
    """Show git blame for specific line range of a file.
    Use when you need to know who last changed specific lines and why.
    """
    root = _project_dir.resolve()
    full_path = (root / filepath).resolve()
    # Component-wise containment check (a string prefix would accept "<root>-x").
    if full_path != root and root not in full_path.parents:
        return "Error: Path traversal not allowed."
    try:
        # "--" ensures a model-supplied path such as "--output=x" is treated as
        # a file name, never as a git option.
        result = subprocess.run(
            ["git", "blame", f"-L{start_line},{end_line}", "--date=short", "--", filepath],
            cwd=str(_project_dir),
            capture_output=True,
            text=True,
            timeout=10,
        )
        if result.returncode != 0:
            return f"Git error: {result.stderr.strip()}"
        return result.stdout or "No blame output"
    except Exception as e:
        return f"Error: {e}"
|
|
439
|
+
|
|
440
|
+
# LangChain's @tool exposes the docstring below to the model as the tool
# description, so its wording is kept verbatim; notes live in comments.
@tool
def read_lines(filepath: str, start_line: int, end_line: int) -> str:
    """Read a specific range of lines from a file.
    Use when you only need a portion of a large file — avoids context overflow.
    Line numbers are 1-based. Returns lines with line numbers prefixed.
    """
    root = _project_dir.resolve()
    full_path = (root / filepath).resolve()
    # Component-wise containment check (a string prefix would accept "<root>-x").
    if full_path != root and root not in full_path.parents:
        return "Error: Path traversal not allowed."
    if not full_path.is_file():
        return f"Error: File '{filepath}' not found."
    try:
        lines = full_path.read_text(encoding="utf-8", errors="replace").splitlines()
        # Convert the 1-based inclusive range into a 0-based slice, clamped
        # to the file's bounds.
        start = max(1, start_line) - 1
        end = min(len(lines), end_line)
        if start >= len(lines):
            return f"Error: File has only {len(lines)} lines."
        return "\n".join(
            f"{number:>4} | {text}"
            for number, text in enumerate(lines[start:end], start=start + 1)
        )
    except Exception as e:
        return f"Error: {e}"
|
|
461
|
+
|
|
462
|
+
# LangChain's @tool exposes the docstring below to the model as the tool
# description, so its wording is kept verbatim; notes live in comments.
@tool
def regex_search(pattern: str, dir_path: str = ".", max_results: int = 30) -> str:
    """Search files using a regex pattern. More powerful than keyword search.
    Use for complex patterns like 'def .*async', 'TODO|FIXME|HACK', etc.
    Returns matching lines with file paths and line numbers.
    """
    import re as re_mod

    # Directory names to prune (exact names, not substrings of the full path).
    skip_dirs = {".git", "node_modules", "__pycache__", "build", "dist", "target", ".dart_tool"}

    root = _project_dir.resolve()
    full_path = root if dir_path == "." else (root / dir_path).resolve()
    # Component-wise containment check (a string prefix would accept "<root>-x").
    if full_path != root and root not in full_path.parents:
        return "Error: Path traversal not allowed."

    try:
        compiled = re_mod.compile(pattern, re_mod.IGNORECASE)
    except re_mod.error as e:
        return f"Invalid regex: {e}"

    results = []
    try:
        for dirpath, dirnames, filenames in os.walk(full_path):
            dirnames[:] = [d for d in dirnames if d not in skip_dirs]
            for name in filenames:
                if not name.endswith(searchable_exts):
                    continue
                fpath = Path(dirpath) / name
                try:
                    lines = fpath.read_text("utf-8", "replace").splitlines()
                except OSError:
                    continue
                for i, line in enumerate(lines, start=1):
                    if compiled.search(line):
                        # Relative to the resolved root the walk started under.
                        rel_path = fpath.relative_to(root)
                        results.append(f"{rel_path}:{i}: {line.strip()[:120]}")
                        if len(results) >= max_results:
                            return f"Found {len(results)}+ matches (truncated):\n" + "\n".join(results)
        if not results:
            return f"No matches for pattern '{pattern}' in {dir_path}"
        return f"Found {len(results)} matches:\n" + "\n".join(results)
    except Exception as e:
        return f"Error: {e}"
|
|
501
|
+
|
|
502
|
+
# LangChain's @tool exposes the docstring below to the model as the tool
# description, so its wording is kept verbatim; notes live in comments.
@tool
def find_references(symbol_name: str, dir_path: str = ".") -> str:
    """Find all usages/references of a symbol (function, class, variable) across the codebase.
    Unlike find_definition which finds where something is declared, this finds where it's used.
    Returns file paths, line numbers, and the line content.
    """
    # Directory names to prune (exact names, not substrings of the full path).
    skip_dirs = {".git", "node_modules", "__pycache__", "build", "dist", "target", ".dart_tool"}

    root = _project_dir.resolve()
    full_path = root if dir_path == "." else (root / dir_path).resolve()
    # Component-wise containment check (a string prefix would accept "<root>-x").
    if full_path != root and root not in full_path.parents:
        return "Error: Path traversal not allowed."

    results = []
    try:
        for dirpath, dirnames, filenames in os.walk(full_path):
            dirnames[:] = [d for d in dirnames if d not in skip_dirs]
            for name in filenames:
                if not name.endswith(searchable_exts):
                    continue
                fpath = Path(dirpath) / name
                try:
                    lines = fpath.read_text("utf-8", "replace").splitlines()
                except OSError:
                    continue
                for i, line in enumerate(lines, start=1):
                    # Plain (case-sensitive) substring match.
                    if symbol_name in line:
                        rel_path = fpath.relative_to(root)
                        results.append(f"{rel_path}:{i}: {line.strip()[:120]}")
                        if len(results) >= 40:
                            return f"Found {len(results)}+ references (truncated):\n" + "\n".join(results)
        if not results:
            return f"No references found for '{symbol_name}'"
        return f"Found {len(results)} references:\n" + "\n".join(results)
    except Exception as e:
        return f"Error: {e}"
|
|
535
|
+
|
|
536
|
+
# LangChain's @tool exposes the docstring below to the model as the tool
# description, so its wording is kept verbatim; notes live in comments.
@tool
def git_diff(ref: str = "HEAD~1", filepath: str = "") -> str:
    """Show git diff for recent changes. Use to understand what changed recently.
    ref: git reference like 'HEAD~1', 'HEAD~3', a branch name, or commit hash.
    filepath: optional — limit diff to a specific file.
    """
    # ``ref`` sits before "--", so git would parse a leading "-" as an option
    # (e.g. "--output=<path>" makes git diff write to an arbitrary file).
    # Revisions never start with "-", so reject such values outright.
    if ref.startswith("-"):
        return f"Error: Invalid git ref '{ref}'."
    cmd = ["git", "diff", "--stat", "-p", ref]
    if filepath:
        root = _project_dir.resolve()
        full_path = (root / filepath).resolve()
        # Component-wise containment check (a string prefix would accept "<root>-x").
        if full_path != root and root not in full_path.parents:
            return "Error: Path traversal not allowed."
        cmd.extend(["--", filepath])
    try:
        result = subprocess.run(
            cmd,
            cwd=str(_project_dir),
            capture_output=True,
            text=True,
            timeout=15,
        )
        if result.returncode != 0:
            return f"Git error: {result.stderr.strip()}"
        output = result.stdout.strip()
        if not output:
            return "No changes found."
        # Cap the diff at 15,000 characters.
        if len(output) > 15000:
            return output[:15000] + "\n\n... [Diff truncated] ..."
        return output
    except Exception as e:
        return f"Error: {e}"
|
|
563
|
+
|
|
564
|
+
# LangChain's @tool exposes the docstring below to the model as the tool
# description, so its wording is kept verbatim; notes live in comments.
@tool
def git_log_search(search_term: str, max_commits: int = 10) -> str:
    """Search git commit messages and diffs for a term.
    Use to find when a feature was added, a bug was introduced, or a file was changed.
    Searches both commit messages (-grep) and code changes (-S).
    """
    # (heading, argv) for the two independent searches: commit messages
    # (case-insensitive --grep) and code changes (pickaxe -S).
    searches = (
        (
            "Commits mentioning '" + search_term + "':\n",
            ["git", "log", f"-{max_commits}", "--pretty=format:%h %ai %s", f"--grep={search_term}", "-i"],
        ),
        (
            "Commits changing code with '" + search_term + "':\n",
            ["git", "log", f"-{max_commits}", "--pretty=format:%h %ai %s", f"-S{search_term}"],
        ),
    )

    results = []
    errors = []
    for heading, cmd in searches:
        try:
            proc = subprocess.run(
                cmd,
                cwd=str(_project_dir),
                capture_output=True,
                text=True,
                timeout=10,
            )
        except Exception as e:
            # Each search is best-effort, but keep the reason instead of
            # silently reporting "no commits" when git could not run.
            errors.append(str(e))
            continue
        output = proc.stdout.strip()
        if output:
            results.append(heading + output)
        elif proc.returncode != 0 and proc.stderr.strip():
            errors.append(proc.stderr.strip())

    if results:
        return "\n\n".join(results)
    if errors:
        return f"No commits found related to '{search_term}' (git errors: " + "; ".join(errors) + ")"
    return f"No commits found related to '{search_term}'"
|
|
596
|
+
|
|
597
|
+
# LangChain's @tool exposes the docstring below to the model as the tool
# description, so its wording is kept verbatim; notes live in comments.
@tool
def file_stats(filepath: str) -> str:
    """Get stats about a file: line count, size, last modified, language.
    Use to quickly assess file complexity and recency.
    """
    from datetime import datetime

    root = _project_dir.resolve()
    full_path = (root / filepath).resolve()
    # Component-wise containment check (a string prefix would accept "<root>-x").
    if full_path != root and root not in full_path.parents:
        return "Error: Path traversal not allowed."
    if not full_path.is_file():
        return f"Error: File '{filepath}' not found."
    try:
        content = full_path.read_text(encoding="utf-8", errors="replace")
        lines = content.splitlines()
        stripped = [line.strip() for line in lines]
        non_blank = sum(1 for s in stripped if s)
        stat = full_path.stat()
        # Local (naive) time of the last modification.
        modified = datetime.fromtimestamp(stat.st_mtime).strftime("%Y-%m-%d %H:%M")
        size_kb = stat.st_size / 1024

        parts = [
            f"File: {filepath}",
            f"Lines: {len(lines)} ({non_blank} non-blank)",
            f"Size: {size_kb:.1f} KB",
            f"Last modified: {modified}",
            f"Extension: {full_path.suffix}",
        ]

        # Rough, language-agnostic counts based on line-start keywords.
        import_count = sum(1 for s in stripped if s.startswith(("import ", "from ", "#include", "use ", "require")))
        class_count = sum(1 for s in stripped if s.startswith(("class ", "struct ", "enum ", "interface ", "trait ")))
        func_count = sum(1 for s in stripped if s.startswith(("def ", "func ", "fn ", "function ", "async def ", "pub fn ")))

        if import_count:
            parts.append(f"Imports: {import_count}")
        if class_count:
            parts.append(f"Classes/structs: {class_count}")
        if func_count:
            parts.append(f"Functions: {func_count}")

        return "\n".join(parts)
    except Exception as e:
        return f"Error: {e}"
|
|
635
|
+
|
|
636
|
+
# LangChain's @tool exposes the docstring below to the model as the tool
# description, so its wording is kept verbatim.
@tool
def find_files_by_name(filename_pattern: str) -> str:
    """Find files whose name matches a pattern (case-insensitive substring match).
    Use when you know part of a filename but not its full path.
    Example: 'auth' finds auth_service.py, AuthController.java, etc.
    """
    # Directories pruned from the walk by exact name.
    excluded = {
        ".git", "node_modules", "__pycache__", "build", "dist",
        "target", ".dart_tool", ".next", "venv", ".venv",
    }
    needle = filename_pattern.lower()
    matches = []
    try:
        for current, subdirs, names in os.walk(_project_dir):
            subdirs[:] = [s for s in subdirs if s not in excluded]
            for name in names:
                if needle not in name.lower():
                    continue
                matches.append(str((Path(current) / name).relative_to(_project_dir)))
                # Stop after 30 hits.
                if len(matches) >= 30:
                    return f"Found {len(matches)}+ files (truncated):\n" + "\n".join(matches)
    except Exception as e:
        return f"Error: {e}"
    if not matches:
        return f"No files matching '{filename_pattern}'"
    return f"Found {len(matches)} files:\n" + "\n".join(matches)
|
|
662
|
+
|
|
663
|
+
@tool
def get_call_chain(filepath: str, function_name: str) -> str:
    """Trace who calls a function and what it calls.
    Returns callers (files that reference this function) and callees
    (functions/methods invoked inside the given function's body).
    Callers: first whole-name match per file, at most 15 files.
    Callees: names followed by '(' within 50 lines of the definition.
    """
    import re

    max_callers = 15
    body_window = 50  # lines scanned from the definition line for callees
    # Same exclusion set as find_files_by_name, matched against directory names.
    skip_dirs = {
        ".git", "node_modules", "__pycache__", "build", "dist",
        "target", ".dart_tool", ".next", "venv", ".venv",
    }
    # Keywords/builtins that look like calls ("if (", "while (") but are not callees.
    non_calls = {
        "if", "elif", "for", "while", "switch", "catch", "with", "return",
        "print", "throw", "not", "and", "or", "assert", "sizeof", "typeof",
        "function",
    }
    def_markers = ("def ", "func ", "fn ", "function ", "void ", "class ")

    full_path = (_project_dir / filepath).resolve()
    # Component-wise containment; a string-prefix test would accept sibling
    # directories such as "<project>-other".
    if not full_path.is_relative_to(_project_dir.resolve()):
        return "Error: Path traversal not allowed."
    if not full_path.is_file():
        return f"Error: File '{filepath}' not found."

    name = function_name.strip()
    if not name:
        # An empty name would "match" every line of every file.
        return "Error: function_name must not be empty."

    # Whole-identifier match so "get" does not match "get_user". Lookarounds are
    # used instead of \b so names that start/end with non-word chars ($) still work.
    name_re = re.compile(rf"(?<!\w){re.escape(name)}(?!\w)")

    callers: list[str] = []
    for root, dirs, files in os.walk(_project_dir):
        # Prune by directory *name*: testing substrings of the full path skipped
        # every file when the project's own path contained e.g. "dist".
        dirs[:] = [d for d in dirs if d not in skip_dirs]
        for file in files:
            if not file.endswith(searchable_exts):
                continue
            fpath = Path(root) / file
            if fpath.resolve() == full_path:
                continue
            try:
                content = fpath.read_text("utf-8", "replace")
            except OSError:
                continue
            if name not in content:  # cheap pre-filter before the per-line regex
                continue
            for lineno, line in enumerate(content.splitlines(), start=1):
                if name_re.search(line):
                    rel = fpath.relative_to(_project_dir)
                    callers.append(f" {rel}:{lineno}: {line.strip()[:100]}")
                    break
            if len(callers) >= max_callers:
                break
        # The inner break only leaves the current directory; stop the walk as well
        # so the cap is actually enforced.
        if len(callers) >= max_callers:
            break

    parts = [
        f"Callers of {name}:\n" + "\n".join(callers)
        if callers
        else f"No callers found for {name}"
    ]

    try:
        lines = full_path.read_text("utf-8", "replace").splitlines()
    except OSError as exc:
        parts.append(f"Could not read '{filepath}' to find callees: {exc}")
        return "\n\n".join(parts)

    # First line mentioning the name together with a definition keyword.
    func_start = next(
        (
            i for i, line in enumerate(lines)
            if name_re.search(line) and any(m in line for m in def_markers)
        ),
        None,
    )
    if func_start is not None:
        # Heuristic: the body is a fixed window of lines, not a parsed scope, so
        # calls just past the end of a short function may be included.
        body = "\n".join(lines[func_start:func_start + body_window])
        callees = set(re.findall(r"\b([a-zA-Z_]\w*)\s*\(", body)) - non_calls - {name}
        if callees:
            parts.append(f"Functions called by {name}:\n " + ", ".join(sorted(callees)))

    return "\n\n".join(parts)
|
|
727
|
+
|
|
728
|
+
@tool
def semantic_code_search(query: str, top_k: int = 6) -> str:
    """Search the codebase using semantic similarity (vector embeddings).
    Unlike keyword search, this finds code by MEANING — even if the exact
    words aren't present. Use for conceptual queries like:
    - 'error handling logic'
    - 'user authentication flow'
    - 'database connection setup'
    Returns the most semantically relevant code chunks with file paths.
    """
    if not collection:
        return "Error: Vector database not available."
    # top_k comes from the model; never ask the store for zero or negative results.
    n_results = max(1, top_k)
    try:
        query_vector = embed_text(query)
        results = query_similar(collection, query_vector, n_results=n_results)
        if not results:
            return f"No semantically relevant code found for: '{query}'"
        parts = []
        for r in results:
            # `or` fallbacks: a key present with value None must not break the
            # formatting (":.2f" / slicing) and lose every other result.
            score = r.get('relevance') or 0
            fname = r.get('filename') or '?'
            loc = r.get('location') or ''
            code = (r.get('code') or '')[:500]
            feature = r.get('feature', '')
            layer = r.get('layer', '')
            meta = ""
            if feature and feature != "unknown":
                meta += f" [feature: {feature}]"
            if layer and layer != "unknown":
                meta += f" [layer: {layer}]"
            parts.append(f"--- {fname} ({loc}) [score: {score:.2f}]{meta}\n{code}")
        return f"Found {len(results)} semantically similar chunks:\n\n" + "\n\n".join(parts)
    except Exception as e:
        return f"Error in semantic search: {e}"
|
|
762
|
+
|
|
763
|
+
@tool
def find_similar_code(filepath: str, chunk_index: int = 0, top_k: int = 5) -> str:
    """Find code chunks that are semantically similar to a specific file/chunk.
    Use this to discover related implementations, duplicated logic, or code
    that serves a similar purpose elsewhere in the codebase.
    Provide the file path and optionally a chunk index (default 0 = first chunk;
    chunk N is the N-th 800-character window of the file).
    """
    window = 800  # characters of the file embedded as the similarity query

    if not collection:
        return "Error: Vector database not available."
    if chunk_index < 0:
        return f"Error: chunk_index must be >= 0 (got {chunk_index})."
    # top_k comes from the model; always request at least one result.
    top_k = max(1, top_k)
    try:
        project_root = _project_dir.resolve()
        full_path = (_project_dir / filepath).resolve()
        # Component-wise containment; a string-prefix test would accept sibling
        # directories such as "<project>-other".
        if not full_path.is_relative_to(project_root):
            return "Error: Path traversal not allowed."
        if not full_path.is_file():
            return f"Error: File '{filepath}' not found."

        content = full_path.read_text(encoding="utf-8", errors="replace")
        # chunk_index used to be ignored (always the first window). Chunk 0 stays
        # valid even for an empty file, matching the previous behaviour.
        start = chunk_index * window
        if chunk_index > 0 and start >= len(content):
            total = max(1, -(-len(content) // window))  # ceil division
            return f"Error: chunk_index {chunk_index} out of range; '{filepath}' has {total} chunk(s)."
        query_vector = embed_text(content[start:start + window])

        # Over-fetch: the query file's own chunks are likely among the hits and
        # are dropped below.
        results = query_similar(collection, query_vector, n_results=top_k + 2) or []
        rel_path = full_path.relative_to(project_root)
        # Exclude the query file whichever spelling of its path the index stores
        # (as given, OS-relative, or POSIX-relative).
        own_names = {filepath, str(rel_path), rel_path.as_posix()}
        results = [r for r in results if r.get('filename') not in own_names][:top_k]

        if not results:
            return f"No similar code found to '{filepath}'"
        parts = []
        for r in results:
            # `or` fallbacks keep a None field from breaking the whole response.
            score = r.get('relevance') or 0
            fname = r.get('filename') or '?'
            code = (r.get('code') or '')[:400]
            parts.append(f"--- {fname} [similarity: {score:.2f}]\n{code}")
        return f"Files similar to '{filepath}':\n\n" + "\n\n".join(parts)
    except Exception as e:
        return f"Error: {e}"
|
|
797
|
+
|
|
798
|
+
return [
|
|
799
|
+
read_local_file,
|
|
800
|
+
read_lines,
|
|
801
|
+
list_directory,
|
|
802
|
+
find_file_usages,
|
|
803
|
+
search_codebase_keywords,
|
|
804
|
+
regex_search,
|
|
805
|
+
find_definition,
|
|
806
|
+
find_references,
|
|
807
|
+
find_files_by_name,
|
|
808
|
+
get_file_structure,
|
|
809
|
+
get_call_chain,
|
|
810
|
+
file_stats,
|
|
811
|
+
semantic_code_search,
|
|
812
|
+
find_similar_code,
|
|
813
|
+
git_file_history,
|
|
814
|
+
git_blame_lines,
|
|
815
|
+
git_diff,
|
|
816
|
+
git_log_search,
|
|
817
|
+
]
|
|
818
|
+
|
|
819
|
+
|
|
820
|
+
# ---------------------------------------------------------------------------
|
|
821
|
+
# Streaming RAG response
|
|
822
|
+
# ---------------------------------------------------------------------------
|
|
823
|
+
|
|
824
|
+
|
|
825
|
+
def _rag_chunk_text(content: Any) -> str:
    """Return the text carried by a streamed chunk's ``content``.

    ``content`` is either a string or a list of content blocks; only dict
    blocks with ``"type": "text"`` contribute text, anything else is ignored.
    """
    if isinstance(content, list):
        return "".join(
            block.get("text", "")
            for block in content
            if isinstance(block, dict) and block.get("type") == "text"
        )
    return str(content)


def _rag_build_messages(prompt: dict[str, Any], query: str) -> list[Any]:
    """Turn an ``assemble_prompt`` result into the chat message list.

    Order: system prompt, prior turns ("assistant" -> AIMessage, any other role
    -> HumanMessage), then the current query, prefixed with the retrieved code
    context when ``prompt["context"]`` is non-empty.
    """
    messages: list[Any] = [SystemMessage(content=prompt["system"])]
    for turn in prompt["history"]:
        content = turn.get("content", "")
        if turn.get("role", "user") == "assistant":
            messages.append(AIMessage(content=content))
        else:
            messages.append(HumanMessage(content=content))

    context_block = prompt["context"]
    if context_block:
        human_content = f"## Retrieved Code Context\n\n{context_block}\n\n---\n\n{query}"
    else:
        human_content = query
    messages.append(HumanMessage(content=human_content))
    return messages


async def stream_rag_response(
    collection: Any,
    graph: ImportGraph,
    query: str,
    mode: str,
    model: str = "gemini-2.5-flash",
    history: list[dict[str, str]] | None = None,
    profile: LanguageProfile | None = None,
    project_dir: Path | None = None,
    gemini_api_key: str = "",
) -> AsyncGenerator[dict[str, Any], None]:
    """Async generator that retrieves context and streams a Gemini response.

    Yields event dicts keyed by ``"type"``:

    - ``status`` with a ``stage`` of ``retrieving``, ``context_found``,
      ``thinking``, ``generating``, ``tool_call`` or ``tool_done``;
    - ``token`` with incremental answer text in ``content``;
    - ``done`` with ``sources`` and the prompt ``stats`` on success;
    - ``error`` with a message in ``content``, after which the generator stops.

    Failures during retrieval, prompt assembly and the model/tool loop are
    reported as ``error`` events instead of being raised to the consumer.
    The model may call the tools from ``_get_tools`` for at most ``MAX_STEPS``
    rounds.
    """
    MAX_STEPS = 5  # upper bound on model <-> tool round trips per request

    if history is None:
        history = []

    _project_dir = project_dir or Path.cwd()
    highlight_language = profile.highlight_language if profile else "text"

    yield {"type": "status", "stage": "retrieving"}

    try:
        chunks = await retrieve_for_mode(collection, graph, query, mode, profile=profile)
    except Exception as exc:
        yield {"type": "error", "content": f"Retrieval failed: {exc}"}
        return

    unique_files = len({c.get("filename", "") for c in chunks})
    yield {"type": "status", "stage": "context_found", "chunks": len(chunks), "files": unique_files}

    # Report failures here as an error event (like retrieval above) rather than
    # letting the exception escape and cut the stream off without one.
    try:
        claude_md, memory_content = _load_static_context(_project_dir)
        prompt = assemble_prompt(
            mode=mode,
            query=query,
            chunks=chunks,
            claude_md=claude_md,
            memory_content=memory_content,
            history=history,
            model=model,
            highlight_language=highlight_language,
        )
        messages = _rag_build_messages(prompt, query)
    except Exception as exc:
        yield {"type": "error", "content": f"Prompt assembly failed: {exc}"}
        return

    sources = [
        {"filename": c.get("filename", ""), "priority": c.get("priority", "semantic")}
        for c in chunks
    ]

    try:
        tools = _get_tools(graph, profile, collection, _project_dir)
        tools_by_name = {t.name: t for t in tools}
        llm = ChatGoogleGenerativeAI(
            model=model,
            google_api_key=gemini_api_key,
            streaming=True,
        )
        llm_with_tools = llm.bind_tools(tools)

        yield {"type": "status", "stage": "thinking"}

        first_token = True
        # NOTE: if the model is still requesting tools in the last round, the loop
        # ends without another model call, so the streamed answer may be partial.
        for _step in range(MAX_STEPS):
            response = None
            async for chunk in llm_with_tools.astream(messages):
                # Merge chunks so the complete message (incl. tool calls) is known afterwards.
                response = chunk if response is None else response + chunk
                if not chunk.content:
                    continue
                text = _rag_chunk_text(chunk.content)
                if text:
                    if first_token:
                        yield {"type": "status", "stage": "generating"}
                        first_token = False
                    yield {"type": "token", "content": text}

            if response is None:
                break
            messages.append(response)
            if not response.tool_calls:
                break

            for tool_call in response.tool_calls:
                tool_name = tool_call["name"]
                args = tool_call.get("args") or {}
                tool_id = tool_call["id"]

                args_summary = ", ".join(f"{k}={repr(v)[:40]}" for k, v in args.items()) if args else ""
                yield {"type": "status", "stage": "tool_call", "tool": tool_name, "args": args_summary}

                selected_tool = tools_by_name.get(tool_name)
                if selected_tool is None:
                    result_str = f"Error: Tool {tool_name} not found."
                else:
                    try:
                        # Tools are synchronous and do blocking file I/O
                        # (os.walk / read_text); keep that off the event loop.
                        result = await asyncio.to_thread(selected_tool.invoke, args)
                        result_str = str(result)
                    except Exception as e:
                        result_str = f"Error executing tool: {e}"

                yield {"type": "status", "stage": "tool_done", "tool": tool_name}
                messages.append(ToolMessage(content=result_str, tool_call_id=tool_id))

            # Text produced after the tool round starts a new "generating" phase.
            first_token = True

        yield {"type": "done", "sources": sources, "stats": prompt["stats"]}

    except Exception as exc:
        yield {"type": "error", "content": f"Gemini API error: {exc}"}
|