codebase-retrieval-context-engine 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. codebase_retrieval_context_engine-2.0.0.dist-info/METADATA +505 -0
  2. codebase_retrieval_context_engine-2.0.0.dist-info/RECORD +46 -0
  3. codebase_retrieval_context_engine-2.0.0.dist-info/WHEEL +4 -0
  4. codebase_retrieval_context_engine-2.0.0.dist-info/entry_points.txt +3 -0
  5. codebase_retrieval_context_engine-2.0.0.dist-info/licenses/LICENSE +201 -0
  6. corbell/__init__.py +6 -0
  7. corbell/cli/__init__.py +1 -0
  8. corbell/cli/commands/__init__.py +1 -0
  9. corbell/cli/commands/index.py +86 -0
  10. corbell/cli/commands/query.py +71 -0
  11. corbell/cli/main.py +57 -0
  12. corbell/core/__init__.py +1 -0
  13. corbell/core/constants.py +52 -0
  14. corbell/core/embeddings/__init__.py +6 -0
  15. corbell/core/embeddings/base.py +68 -0
  16. corbell/core/embeddings/extractor.py +201 -0
  17. corbell/core/embeddings/factory.py +48 -0
  18. corbell/core/embeddings/model.py +401 -0
  19. corbell/core/embeddings/search_cache.py +95 -0
  20. corbell/core/embeddings/sqlite_store.py +271 -0
  21. corbell/core/gitignore.py +76 -0
  22. corbell/core/graph/__init__.py +1 -0
  23. corbell/core/graph/builder.py +696 -0
  24. corbell/core/graph/method_graph.py +1077 -0
  25. corbell/core/graph/providers/__init__.py +6 -0
  26. corbell/core/graph/providers/aws_patterns.py +62 -0
  27. corbell/core/graph/providers/azure_patterns.py +64 -0
  28. corbell/core/graph/providers/gcp_patterns.py +59 -0
  29. corbell/core/graph/schema.py +175 -0
  30. corbell/core/graph/sqlite_store.py +500 -0
  31. corbell/core/indexing/__init__.py +1 -0
  32. corbell/core/indexing/builder.py +608 -0
  33. corbell/core/indexing/lock.py +150 -0
  34. corbell/core/indexing/tracker.py +245 -0
  35. corbell/core/llm_client.py +677 -0
  36. corbell/core/mcp/__init__.py +1 -0
  37. corbell/core/mcp/server.py +214 -0
  38. corbell/core/query/__init__.py +1 -0
  39. corbell/core/query/diagnostics.py +38 -0
  40. corbell/core/query/engine.py +321 -0
  41. corbell/core/query/enhancer.py +102 -0
  42. corbell/core/query/formatter.py +98 -0
  43. corbell/core/query/graph_expander.py +284 -0
  44. corbell/core/query/merger.py +171 -0
  45. corbell/core/query/reranker.py +131 -0
  46. corbell/core/workspace.py +408 -0
@@ -0,0 +1,214 @@
1
+ """MCP Server for Corbell code retrieval engine.
2
+
3
+ Exposes a single tool `context_engine_codebase_retrieval` via FastMCP,
4
+ supporting both stdio and SSE transports.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import asyncio
10
+ import os
11
+ import sys
12
+ from typing import Optional
13
+
14
+ from mcp.server.fastmcp import FastMCP
15
+
16
+
17
# Module-level FastMCP server instance. The ``@mcp.tool()`` functions in this
# module register on it, and ``serve()`` runs it over the chosen transport.
mcp = FastMCP(name="corbell", dependencies=["corbell"])
19
+
20
+
21
+ # ---------------------------------------------------------------------------
22
+ # Tool: context_engine_codebase_retrieval
23
+ # ---------------------------------------------------------------------------
24
+
25
@mcp.tool()
def context_engine_codebase_retrieval(
    query: str,
    workspace_full_path: str = "",
    top_k: int = 50,
    rerank: bool = True,
) -> str:
    """Search the indexed codebase and return relevant code snippets.

    Returns formatted code blocks with absolute file paths and line numbers,
    ready for injection into an LLM context window.

    Args:
        query: Natural language description of the code you're looking for.
        workspace_full_path: Full path to the workspace (repository) root directory.
            Falls back to CORBELL_WORKSPACE env var if empty.
        top_k: Maximum number of code chunks to return (default 50).
        rerank: Whether to use LLM reranking for better relevance (default true).

    Returns:
        Formatted code snippets, or an error string on failure.
    """
    # NOTE: FastMCP publishes the docstring above to clients as the tool
    # description, so it stays user-facing; maintainer notes live in comments.
    # Every failure is reported as a returned "Error: ..." string rather than
    # raised, and the underlying cause is written to the server log.
    import logging

    log = logging.getLogger(__name__)

    try:
        workspace_path_str = _resolve_workspace(workspace_full_path)
        if workspace_path_str is None:
            return (
                "Error: workspace_full_path is required. "
                "Pass the full path to the workspace (repository) root directory."
            )

        from pathlib import Path
        from corbell.core.workspace import build_config, db_path_for_workspace
        from corbell.core.embeddings.sqlite_store import SQLiteEmbeddingStore
        from corbell.core.indexing.tracker import IndexTracker
        from corbell.core.indexing.builder import IndexBuilder

        ws_path = Path(workspace_path_str).resolve()

        # is_dir() rather than exists(): a regular file is not a usable workspace root.
        if not ws_path.is_dir():
            return (
                f"Error: Workspace directory not found: {ws_path}. "
                "Ensure the path points to a valid repository root."
            )

        cfg = build_config(ws_path)
        db_path = db_path_for_workspace(ws_path, model=cfg.storage.resolved_model())

        # Opening the store and counting chunks are the first DB accesses; any
        # failure here is reported to the client as a corrupted database.
        try:
            emb_store = SQLiteEmbeddingStore(db_path)
            chunk_count = emb_store.count()
        except Exception:
            log.warning("Cannot read embedding database at %s", db_path, exc_info=True)
            return (
                f"Error: Database corrupted at {db_path}. "
                "Run 'corbell index build --rebuild' to recreate."
            )

        # An empty index gets a blocking full build before searching.
        if chunk_count == 0:
            log.info("Index is empty — running full build now (this may take a while)...")
            IndexBuilder().build(cfg, db_path, rebuild=True)

        # Blocking incremental rebuild when files changed since the last build.
        tracker = IndexTracker(db_path)
        stale_result = tracker.get_stale_files(cfg.repos, cfg)
        if stale_result.has_changes:
            try:
                IndexBuilder().build(cfg, db_path, rebuild=False)
            except Exception:
                # Best-effort: answer from the existing index instead of failing
                # the query, but keep the cause visible in the server log.
                log.warning(
                    "Incremental index rebuild failed; continuing with the existing index",
                    exc_info=True,
                )

        # Run the retrieval pipeline
        from corbell.core.query.engine import codebase_retrieval

        return codebase_retrieval(
            query=query,
            workspace_path=ws_path,
            top_k=top_k,
            use_llm=True,
            rerank=rerank,
        )

    except Exception as exc:
        log.exception("codebase_retrieval tool call failed")
        return f"Error: Unexpected failure in codebase_retrieval: {exc}"
123
+
124
+
125
+ def _resolve_workspace(workspace_full_path: str) -> Optional[str]:
126
+ """Resolve the workspace path from parameter or env var."""
127
+ # 1. Explicit path provided
128
+ if workspace_full_path and workspace_full_path.strip():
129
+ return workspace_full_path.strip()
130
+
131
+ # 2. Environment variable
132
+ env_path = os.environ.get("CORBELL_WORKSPACE")
133
+ if env_path:
134
+ return env_path
135
+
136
+ return None
137
+
138
+
139
+ # ---------------------------------------------------------------------------
140
+ # Filtered stdin wrapper — prevents empty-line crashes in MCP SDK
141
+ # ---------------------------------------------------------------------------
142
+
143
+ class _FilteredStdin:
144
+ """Async iterator over stdin that silently drops empty/whitespace lines.
145
+
146
+ The MCP SDK's stdio transport passes every raw line from sys.stdin to
147
+ Pydantic's JSONRPCMessage.model_validate_json(). Empty newlines fail
148
+ validation and crash the server. This wrapper filters them out.
149
+ """
150
+
151
+ def __init__(self) -> None:
152
+ self._reader = None
153
+
154
+ def __aiter__(self):
155
+ return self
156
+
157
+ async def __anext__(self) -> str:
158
+ loop = asyncio.get_event_loop()
159
+ while True:
160
+ line = await loop.run_in_executor(None, sys.stdin.readline)
161
+ if not line: # EOF
162
+ raise StopAsyncIteration
163
+ if line.strip(): # Only forward non-empty lines
164
+ return line
165
+ # Empty/whitespace lines are silently dropped
166
+
167
+
168
+ # ---------------------------------------------------------------------------
169
+ # Server entry point
170
+ # ---------------------------------------------------------------------------
171
+
172
def serve(transport: str = "stdio", port: int = 8000) -> None:
    """Start the Corbell MCP server and block until it exits.

    Args:
        transport: ``"sse"`` serves over HTTP (Server-Sent Events); any other
            value falls back to stdio, the pipe-based mode used by IDE clients.
        port: TCP port for the SSE transport; unused for stdio.
    """
    if transport != "sse":
        print("Corbell MCP server starting on stdio...", file=sys.stderr)

        async def _serve_stdio() -> None:
            from mcp.server.stdio import stdio_server

            # FastMCP's underlying server object (private attribute), driven
            # directly so that a custom stdin iterator can be supplied.
            server = mcp._mcp_server
            # Blank input lines are dropped before they reach the SDK.
            async with stdio_server(stdin=_FilteredStdin()) as (incoming, outgoing):
                await server.run(
                    incoming,
                    outgoing,
                    server.create_initialization_options(),
                )

        asyncio.run(_serve_stdio())
        return

    print(f"Corbell MCP server starting on http://localhost:{port}/sse ...", file=sys.stderr)
    mcp.settings.port = port
    mcp.run(transport="sse")
198
+
199
+
200
def main() -> None:
    """Console entry point: parse CLI flags and start the MCP server.

    Used by `uvx codebase-retrieval-context-engine`.
    """
    import argparse

    cli = argparse.ArgumentParser(description="Codebase Retrieval Context Engine MCP Server")
    cli.add_argument(
        "--transport",
        "-t",
        choices=["stdio", "sse"],
        default="stdio",
        help="Transport mode (default: stdio)",
    )
    cli.add_argument(
        "--port",
        "-p",
        default=8000,
        type=int,
        help="Port for SSE transport (default: 8000)",
    )
    options = cli.parse_args()
    serve(port=options.port, transport=options.transport)
@@ -0,0 +1 @@
1
+ """Query pipeline module for Corbell code retrieval."""
@@ -0,0 +1,38 @@
1
+ """Query diagnostics for tracking and surfacing warnings during retrieval."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass, field
6
+ from typing import Optional
7
+
8
+
9
@dataclass
class QueryDiagnostics:
    """Warning counters collected while a single query runs.

    Pipeline stages bump the counters as they go. When the query finishes,
    ``summary()`` turns each counter that reached its threshold into a
    human-readable warning.
    """

    skipped_files: int = 0  # indexed files that no longer exist on disk
    skipped_methods: int = 0  # method nodes that could not be expanded
    graph_expansion_failures: int = 0  # failed graph lookups (not reported by summary())

    # Warning thresholds: dataclass fields excluded from __init__ and repr.
    _FILE_THRESHOLD: int = field(default=3, init=False, repr=False)
    _METHOD_THRESHOLD: int = field(default=5, init=False, repr=False)

    def summary(self) -> Optional[str]:
        """Describe every counter that reached its warning threshold.

        Returns:
            The triggered warnings joined with ``"; "``, or None when no
            counter reached its threshold.
        """
        candidates = [
            (
                self.skipped_files >= self._FILE_THRESHOLD,
                f"{self.skipped_files} files missing (index may be stale)",
            ),
            (
                self.skipped_methods >= self._METHOD_THRESHOLD,
                f"{self.skipped_methods} methods skipped",
            ),
        ]
        triggered = [message for hit, message in candidates if hit]
        # An empty join is "" (falsy), which maps to None.
        return "; ".join(triggered) or None
@@ -0,0 +1,321 @@
1
+ """Main query engine orchestrator for codebase retrieval."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+ import time
7
+ from pathlib import Path
8
+ from typing import Any, Dict, List, Optional
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
def codebase_retrieval(
    query: str,
    workspace_path: str | Path,
    top_k: int = 50,
    use_llm: bool = True,
    rerank: bool = True,
) -> str:
    """Execute the full code retrieval pipeline.

    Pipeline:
        1. Load workspace config and open stores.
        2. Auto-index check: an empty index triggers a full build. The stale
           check (blocking incremental rebuild) is skipped when the last build
           completed within the past 30 seconds.
        3. Embedding search via EmbeddingSearchCache (raw query used directly).
        4. Graph call-chain expansion.
        5. Merge + dedup, then cap to ``top_k``.
        6. LLM rerank (optional).
        7. Format results.

    Args:
        query: Natural language query string.
        workspace_path: Path to the workspace (repository) root directory.
        top_k: Maximum number of chunks to pass to reranker.
        use_llm: If False, skip reranking.
        rerank: If False, skip reranking even when LLM is configured.

    Returns:
        Formatted code snippet string ready for LLM context injection.
        A missing workspace or a query-encoding failure is returned as a string
        prefixed with "Error:". An empty search cache and a query with no hits
        return explanatory strings that have no "Error:" prefix. Exceptions
        from index building, graph expansion, reranking and formatting are not
        caught here and propagate to the caller.
    """
    from corbell.core.workspace import build_config, db_path_for_workspace
    from corbell.core.embeddings.sqlite_store import SQLiteEmbeddingStore
    from corbell.core.embeddings.search_cache import EmbeddingSearchCache
    from corbell.core.embeddings.model import SentenceTransformerModel, GoogleEmbeddingModel, VoyageEmbeddingModel, EmbeddingModel
    from corbell.core.graph.sqlite_store import SQLiteGraphStore
    from corbell.core.indexing.builder import IndexBuilder
    from corbell.core.indexing.tracker import IndexTracker
    from corbell.core.query.diagnostics import QueryDiagnostics
    from corbell.core.query.graph_expander import ScoredChunk, expand_via_graph
    from corbell.core.query.merger import merge_and_dedup
    from corbell.core.query.reranker import rerank_chunks
    from corbell.core.query.formatter import format_results

    workspace_path = Path(workspace_path).resolve()

    if not workspace_path.exists():
        return f"Error: Workspace directory not found: {workspace_path}. Run 'corbell index build' first."

    cfg = build_config(workspace_path)
    db_path = db_path_for_workspace(workspace_path, model=cfg.storage.resolved_model())
    emb_store = SQLiteEmbeddingStore(db_path)
    graph_store = SQLiteGraphStore(db_path)
    tracker = IndexTracker(db_path)

    # --- Auto-index check ---
    if emb_store.count() == 0:
        logger.info("Index is empty — running full build now (this may take a while)...")
        IndexBuilder().build(cfg, db_path, rebuild=True, progress_fn=logger.info)

    # Short-circuit: skip the stale scan if a build finished within the last 30 seconds.
    last_build = tracker.get_last_build_at()
    if last_build is None or (time.time() - last_build) >= 30:
        stale_result = tracker.get_stale_files(cfg.repos, cfg)
        if stale_result.has_changes:
            # Always do a blocking incremental rebuild when stale
            IndexBuilder().build(cfg, db_path, rebuild=False, progress_fn=logger.info)

    # --- LLM client setup ---
    llm_client: Optional[Any] = None
    if use_llm:
        from corbell.core.llm_client import LLMClient
        llm_cfg = cfg.llm
        llm_client = LLMClient(
            provider=llm_cfg.provider,
            model=llm_cfg.resolved_model(),
            api_key=llm_cfg.resolved_api_key(),
            aws_region=llm_cfg.aws_region,
            azure_endpoint=llm_cfg.azure_endpoint,
            azure_deployment=llm_cfg.azure_deployment,
            azure_api_version=llm_cfg.azure_api_version,
            gcp_project=llm_cfg.gcp_project,
            gcp_region=llm_cfg.gcp_region,
        )

    # --- Search queries ---
    search_queries = [query]

    # --- Embedding model (selected by model-name prefix) ---
    model_name = cfg.storage.resolved_model()
    emb_model: EmbeddingModel
    if model_name.startswith("gemini-"):
        emb_model = GoogleEmbeddingModel(model_name)
    elif model_name.startswith("voyage-"):
        emb_model = VoyageEmbeddingModel(model_name)
    else:
        emb_model = SentenceTransformerModel(model_name)

    # --- Load search cache ---
    cache = EmbeddingSearchCache()
    cache.load(emb_store)

    if not cache.is_loaded:
        return "No index found. Run 'corbell index build' first."

    # --- Embedding search ---
    import numpy as np

    # repo_id -> absolute repo root. Loop-invariant: used to absolutize chunk
    # paths below and passed to the formatter at the end.
    repo_path_map = {
        r.id: str(r.resolved_path) for r in cfg.repos if r.resolved_path
    }

    all_embedding_results: dict[str, ScoredChunk] = {}
    query_config = cfg.query

    for sq in search_queries:
        try:
            if isinstance(emb_model, GoogleEmbeddingModel):
                formatted_query = emb_model.prepare_query(sq) if emb_model.uses_prefix_format else sq
                q_vecs = emb_model.encode([formatted_query], task_type="RETRIEVAL_QUERY")
            elif isinstance(emb_model, VoyageEmbeddingModel):
                q_vecs = emb_model.encode([sq], input_type="query")
            else:
                q_vecs = emb_model.encode([sq])
        except Exception as exc:
            # The sentence-transformers install hint only applies to the local
            # backend; the Gemini/Voyage models fail for different reasons.
            if isinstance(emb_model, (GoogleEmbeddingModel, VoyageEmbeddingModel)):
                hint = "Check the embedding provider's API configuration and credentials."
            else:
                hint = "Ensure 'sentence-transformers' is installed."
            return f"Error: Failed to load embedding model '{model_name}'. {hint} ({exc})"

        q_vec = np.array(q_vecs[0], dtype=np.float32)
        hits = cache.search(q_vec, top_k=top_k)

        if not hits:
            continue

        # Fetch full records for top hits
        hit_ids = [h[0] for h in hits]
        hit_scores = {h[0]: h[1] for h in hits}

        try:
            records = emb_store.get_chunks_by_ids(hit_ids)
        except Exception:
            continue

        for record in records:
            score = hit_scores.get(record.id, 0.0)
            # Resolve absolute file path
            abs_path = record.file_path
            repo_root = repo_path_map.get(record.service_id, "")
            if repo_root and not Path(abs_path).is_absolute():
                abs_path = str((Path(repo_root) / abs_path).resolve())

            chunk = ScoredChunk(
                chunk_id=record.id,
                score=score,
                file_path=abs_path,
                start_line=record.start_line,
                end_line=record.end_line,
                content=record.content,
                repo_id=record.service_id,
                symbol=record.symbol,
                chunk_type=record.chunk_type,
                language=record.language,
            )

            # Keep max score for deduplication across queries
            existing = all_embedding_results.get(record.id)
            if existing is None or score > existing.score:
                all_embedding_results[record.id] = chunk

    if not all_embedding_results:
        return "No relevant code found for the given query."

    base_chunks = list(all_embedding_results.values())

    # --- Graph expansion ---
    diagnostics = QueryDiagnostics()
    bonus_chunks = expand_via_graph(
        embedding_results=base_chunks,
        graph_store=graph_store,
        repos=cfg.repos,
        max_depth=query_config.expand_call_depth,
        max_chunks=query_config.expand_max_chunks,
        diagnostics=diagnostics,
    )

    all_chunks = base_chunks + bonus_chunks

    # --- Merge + dedup ---
    merged = merge_and_dedup(all_chunks)

    # --- Apply top_k cap ---
    merged = merged[:top_k]

    # --- LLM rerank ---
    do_rerank = use_llm and rerank and query_config.rerank
    if do_rerank:
        # Annotate chunks with graph metadata before sending to the reranker
        graph_meta = _annotate_with_graph_meta(merged, graph_store, cfg.repos)

        rerank_start = time.time()
        reranked_ids = rerank_chunks(query, merged, llm_client, graph_meta=graph_meta)
        rerank_elapsed = time.time() - rerank_start
        logger.info(
            "Rerank complete: %.3fs, %d/%d chunks kept, order: %s",
            rerank_elapsed,
            len(reranked_ids),
            len(merged),
            reranked_ids,
        )
        # Reorder merged, keeping only chunks selected by the reranker
        id_to_chunk = {c.chunk_id: c for c in merged}
        merged = [id_to_chunk[cid] for cid in reranked_ids if cid in id_to_chunk]

    # --- Format output ---
    output = format_results(merged, repo_path_map)

    # Prepend diagnostics warning if thresholds exceeded
    warning = diagnostics.summary()
    if warning:
        output = f"[warnings: {warning}]\n\n{output}"

    return output
238
+
239
+
240
def _annotate_with_graph_meta(
    chunks: List[Any],
    graph_store: Any,
    repos: List[Any],
) -> Dict[str, Dict]:
    """Collect per-chunk call-graph statistics to hand to the reranker.

    Each chunk is matched to the MethodNodes it overlaps (located via
    ``graph_expander._find_matching_methods``). Figures are summed over all
    matched methods:

    - ``callers``: total size of ``get_callers_of_method`` for the methods.
    - ``callees``: count of outgoing dependency edges whose kind is
      ``"method_call"``.
    - ``flow``: ``flow_name`` of the first flow reported for the first
      method that has one, or None.

    Every graph lookup is best-effort: a failing lookup contributes nothing
    rather than aborting the annotation.

    Args:
        chunks: List of ScoredChunk objects.
        graph_store: SQLiteGraphStore instance.
        repos: List of RepoConfig objects for path resolution.

    Returns:
        Dict mapping chunk_id -> {"callers": int, "callees": int, "flow": str | None}.
        Chunks without a matching MethodNode are left out. An empty dict is
        returned if the service list cannot be read.
    """
    from corbell.core.query.graph_expander import _find_matching_methods

    def _best_effort(fallback, fn, *args):
        """Return ``fn(*args)``, or ``fallback`` if it raises any Exception."""
        try:
            return fn(*args)
        except Exception:
            return fallback

    def _caller_count(method):
        """Number of methods that call ``method``."""
        return len(graph_store.get_callers_of_method(method.id))

    def _callee_count(method):
        """Number of outgoing ``method_call`` edges from ``method``."""
        return sum(1 for edge in graph_store.get_dependencies(method.id) if edge.kind == "method_call")

    def _first_flow_name(method):
        """``flow_name`` of the first flow containing ``method`` (empty -> None)."""
        flows = graph_store.get_flows_for_method(method.id)
        return (flows[0].get("flow_name") or None) if flows else None

    # repo_id -> absolute repo root (same mapping graph_expander uses).
    repo_path_map: Dict[str, Path] = {
        repo.id: repo.resolved_path for repo in repos if repo.resolved_path
    }

    try:
        service_ids = [svc.id for svc in graph_store.get_all_services()]
    except Exception:
        return {}

    meta: Dict[str, Dict] = {}
    for chunk in chunks:
        methods = _best_effort(
            None, _find_matching_methods, chunk, graph_store, repo_path_map, service_ids
        )
        if not methods:
            continue

        # Several methods can overlap one chunk (e.g. nested lambdas); sum them.
        callers = 0
        callees = 0
        flow: Optional[str] = None
        for method in methods:
            callers += _best_effort(0, _caller_count, method)
            callees += _best_effort(0, _callee_count, method)
            if flow is None:
                flow = _best_effort(None, _first_flow_name, method)

        meta[chunk.chunk_id] = {"callers": callers, "callees": callees, "flow": flow}

    return meta
@@ -0,0 +1,102 @@
1
+ """Query enhancement: LLM-based query expansion and keyword extraction."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import re
6
+ from typing import Any, List, Optional, Tuple
7
+
8
+
9
def enhance_query(
    query: str,
    llm_client: Optional[Any],
) -> Tuple[List[str], List[str]]:
    """Turn a user query into search queries plus keyword hints.

    If ``llm_client`` is given and its ``is_configured`` attribute is truthy,
    the LLM rewrites the query into behaviour-focused search queries.
    Otherwise the query itself is the only search query and keywords are
    pulled out with a regex.

    Args:
        query: The user's natural language query.
        llm_client: An LLMClient instance, or None / an unconfigured client.

    Returns:
        ``(search_queries, keywords)``: strings to embed and search with, and
        extracted keywords usable as graph-expansion hints.
    """
    use_llm = llm_client is not None and getattr(llm_client, "is_configured", False)
    if not use_llm:
        return _enhance_without_llm(query)
    return _enhance_with_llm(query, llm_client)
34
+
35
+
36
def _enhance_with_llm(
    query: str, llm_client: Any
) -> Tuple[List[str], List[str]]:
    """Use the LLM to generate up to 3 code-oriented search queries.

    The model is asked for exactly 3 behaviour-describing queries, one per
    line. Blank lines are dropped, and leading list markers ("1.", "2)",
    "-", "*", bullet) are stripped, because models often number their output
    despite being told not to. If the call fails or yields no usable line,
    the original query is used alone.

    Args:
        query: The user's natural language query.
        llm_client: Object exposing ``call(system, user, max_tokens=..., temperature=...)``.

    Returns:
        ``(search_queries, keywords)``. Keywords are always extracted from the
        original ``query``, not from the LLM output.
    """
    system = (
        "You are a code search assistant. Given a user query about code, "
        "generate exactly 3 different natural-language search queries that describe "
        "what the relevant implementation code *does* (not technology names or framework names). "
        "Each query should describe behavior, logic, or data transformations. "
        "Return exactly 3 queries, one per line, no numbering, no extra text."
    )
    user = f"User query: {query}\n\nGenerate 3 code search queries:"

    # A list marker must be followed by whitespace or end the line, so text such
    # as "3D rendering" or "-based" at the start of a query is left alone.
    list_marker = re.compile(r"^(?:\d+[.)]|[-*\u2022])(?:\s+|$)")

    try:
        response = llm_client.call(system, user, max_tokens=300, temperature=0.1)
        cleaned = (
            list_marker.sub("", line.strip()).strip()
            for line in response.strip().splitlines()
        )
        # Take up to 3 non-empty lines
        search_queries = [line for line in cleaned if line][:3]
    except Exception:
        # Best-effort expansion: any LLM or parsing failure degrades to the raw query.
        search_queries = []

    if not search_queries:
        search_queries = [query]

    # Extract keywords from the original query for graph expansion
    keywords = _extract_keywords(query)
    return search_queries, keywords
62
+
63
+
64
def _enhance_without_llm(query: str) -> Tuple[List[str], List[str]]:
    """Fallback when no LLM is available.

    The query is used unchanged as the single search query, and keywords come
    from a regex scan of it.
    """
    search_queries = [query]
    return search_queries, _extract_keywords(query)
68
+
69
+
70
+ def _extract_keywords(text: str) -> List[str]:
71
+ """Extract meaningful keywords from text via regex."""
72
+ # Extract words that look like identifiers (camelCase, snake_case, etc.)
73
+ words = re.findall(r"[a-zA-Z][a-zA-Z0-9_]*", text)
74
+
75
+ # Filter stop words and short words
76
+ stop_words = {
77
+ "the", "a", "an", "and", "or", "but", "in", "on", "at", "to", "for",
78
+ "of", "with", "by", "from", "up", "about", "into", "through", "during",
79
+ "is", "are", "was", "were", "be", "been", "being", "have", "has", "had",
80
+ "do", "does", "did", "will", "would", "could", "should", "may", "might",
81
+ "must", "shall", "can", "how", "what", "when", "where", "which", "who",
82
+ "that", "this", "these", "those", "it", "its", "get", "set", "use",
83
+ "new", "return", "class", "function", "method", "var", "let", "const",
84
+ "def", "import", "from", "as", "if", "else", "while", "for", "try",
85
+ "except", "raise", "pass", "break", "continue", "not", "and", "or",
86
+ }
87
+
88
+ keywords = [
89
+ w for w in words
90
+ if len(w) > 2 and w.lower() not in stop_words
91
+ ]
92
+
93
+ # Remove duplicates while preserving order
94
+ seen: set = set()
95
+ unique_keywords = []
96
+ for kw in keywords:
97
+ lower = kw.lower()
98
+ if lower not in seen:
99
+ seen.add(lower)
100
+ unique_keywords.append(kw)
101
+
102
+ return unique_keywords[:20] # cap at 20 keywords