java-codebase-rag 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
search_lancedb.py ADDED
@@ -0,0 +1,1075 @@
1
+ #!/usr/bin/env python3
2
+ """Semantic search over LanceDB tables built by CocoIndex (java_index_flow_lancedb)."""
3
+
4
+ from __future__ import annotations
5
+
6
+ import argparse
7
+ import json
8
+ import os
9
+ import sys
10
+ import threading
11
+ from collections.abc import Callable
12
+ from pathlib import Path
13
+
14
+ import lancedb
15
+ import numpy as np
16
+ from sentence_transformers import SentenceTransformer
17
+
18
+ from chunk_heuristics import analyze_chunk, looks_like_code_identifier
19
+ from index_common import SBERT_MODEL
20
+ from java_codebase_rag.config import maybe_expand_embedding_model_path, resolved_sbert_model_for_process_env
21
+
22
# Logical table key (the values accepted by `--table` / `table_keys`) mapped to
# the LanceDB table name that holds the chunks for that kind of source.
TABLES: dict[str, str] = {
    "java": "javacodeindex_java_code",
    "sql": "sqlschemaindex_sql_schema",
    "yaml": "yamlconfigindex_yaml_config",
}

# Optional enrichment columns on the java chunk table (absent on older indexes).
# `_search_one_table` selects only the ones that exist in the table schema.
JAVA_ENRICHED_COLUMNS: tuple[str, ...] = (
    "package",
    "module",
    "microservice",
    "primary_type_fqn",
    "primary_type_kind",
    "role",
    "annotations_on_type",
    "symbols",
    "symbol_id",
    "metadata",
    "ontology_version",
    "capabilities",
)

# Name of the vector column handed to LanceDB searches.
VECTOR_COLUMN = "embedding"
# (uri, table name) pairs for which `ensure_text_fts_index` has already run in
# this process; guarded by `_FTS_LOCK`.
_FTS_READY: set[tuple[str, str]] = set()
_FTS_LOCK = threading.Lock()
# (uri, table name) -> column names, filled by `_table_columns`; guarded by
# `_SCHEMA_LOCK`.
_SCHEMA_CACHE: dict[tuple[str, str], set[str]] = {}
_SCHEMA_LOCK = threading.Lock()
49
+
50
+
51
def _table_columns(uri: str, lance_table_name: str, db_obj: object | None = None) -> set[str]:
    """Return the column names of a LanceDB table, memoised per (uri, table).

    An already-open connection may be passed as ``db_obj``; otherwise a new
    one is created from ``uri``. The returned set is the cached object itself,
    so callers should treat it as read-only.
    """
    cache_key = (uri, lance_table_name)
    with _SCHEMA_LOCK:
        known = _SCHEMA_CACHE.get(cache_key)
    if known is not None:
        return known

    # The schema read happens outside the lock; only the cache accesses are guarded.
    connection = lancedb.connect(uri) if db_obj is None else db_obj
    table = connection.open_table(lance_table_name)
    column_names = {field.name for field in table.schema}
    with _SCHEMA_LOCK:
        _SCHEMA_CACHE[cache_key] = column_names
    return column_names
63
+
64
+
65
def _escape_sql_str(s: str) -> str:
    """Double every single quote so ``s`` can be embedded in a SQL '...' literal."""
    return "''".join(s.split("'"))
67
+
68
+
69
def _build_extra_predicates(
    *,
    columns: set[str],
    role: str | None,
    module: str | None,
    microservice: str | None,
    package_prefix: str | None,
    fqn_in: list[str] | None,
    role_in: list[str] | None = None,
    exclude_roles: list[str] | None = None,
    capability: str | None = None,
    capability_in: list[str] | None = None,
) -> list[str]:
    """Translate optional metadata filters into SQL predicate strings.

    A filter is emitted only when its value is truthy AND the column it
    targets is present in ``columns``; filters on columns missing from the
    index are silently dropped. The returned predicates are meant to be
    AND-ed together by the caller.

    Args:
        columns: Column names available on the table being filtered.
        role: Exact match on ``role``.
        module: Exact match on ``module``.
        microservice: Exact match on ``microservice``.
        package_prefix: Matches the package itself or any sub-package of it.
        fqn_in: Allowed values for ``primary_type_fqn``.
        role_in: Allowed values for ``role``.
        exclude_roles: Roles to reject; rows with a NULL role are kept.
        capability: Value that must be contained in ``capabilities``.
        capability_in: Values of which at least one must be in ``capabilities``.

    Returns:
        SQL predicate fragments, in a stable order.
    """

    def _quoted_list(values: list[str]) -> str:
        # 'a', 'b', ... with each value quote-escaped for a SQL IN (...) list.
        return ", ".join(f"'{_escape_sql_str(v)}'" for v in values)

    preds: list[str] = []
    if role and "role" in columns:
        preds.append(f"role = '{_escape_sql_str(role)}'")

    # When both role_in and capability_in are set, combine as OR so that
    # capability-only entrypoints (e.g. role=OTHER with MESSAGE_LISTENER)
    # are not silently excluded by the role filter.
    role_pred: str | None = None
    if role_in and "role" in columns:
        role_pred = f"role IN ({_quoted_list(role_in)})"

    cap_in_pred: str | None = None
    if capability_in and "capabilities" in columns:
        # array_has is the preferred form in LanceDB >= 0.10 (verified against 0.30.2).
        parts = [
            f"array_has(capabilities, '{_escape_sql_str(c)}')"
            for c in capability_in
        ]
        cap_in_pred = "(" + " OR ".join(parts) + ")"

    if role_pred and cap_in_pred:
        preds.append(f"({role_pred} OR {cap_in_pred})")
    elif role_pred:
        preds.append(role_pred)
    elif cap_in_pred:
        preds.append(cap_in_pred)

    if exclude_roles and "role" in columns:
        preds.append(f"(role IS NULL OR role NOT IN ({_quoted_list(exclude_roles)}))")
    if module and "module" in columns:
        preds.append(f"module = '{_escape_sql_str(module)}'")
    if microservice and "microservice" in columns:
        preds.append(f"microservice = '{_escape_sql_str(microservice)}'")
    if package_prefix and "package" in columns:
        esc = _escape_sql_str(package_prefix)
        if any(ch in package_prefix for ch in "\\%_"):
            # `%`, `_` and `\` are LIKE metacharacters. Unescaped, a prefix such
            # as `com.my_app` would also match `com.myXapp.*`. Escape them the
            # same way `_build_path_predicate` does (LIKE-escape first, then
            # quote-escape) and declare the escape character.
            like = _escape_sql_str(_escape_sql_like_pattern(package_prefix))
            preds.append(
                f"(package = '{esc}' OR package LIKE '{like}.%' ESCAPE '\\')"
            )
        else:
            # No metacharacters: keep the plain form so the emitted SQL is
            # unchanged for ordinary package names.
            preds.append(f"(package = '{esc}' OR package LIKE '{esc}.%')")
    if fqn_in and "primary_type_fqn" in columns:
        # LanceDB/Arrow SQL supports IN; quote each.
        preds.append(f"primary_type_fqn IN ({_quoted_list(fqn_in)})")
    if capability and "capabilities" in columns:
        preds.append(f"array_has(capabilities, '{_escape_sql_str(capability)}')")
    return preds
127
+
128
+
129
def coerce_position_field(val: object) -> dict[str, object]:
    """Normalise a position value to a dict.

    LanceDB may return struct columns as JSON strings. Dicts pass through
    unchanged (same object), JSON strings that decode to an object are
    returned decoded, and everything else yields an empty dict.
    """
    if isinstance(val, dict):
        return val
    if not isinstance(val, str):
        # Covers None as well as any unexpected type.
        return {}
    try:
        decoded = json.loads(val)
    except json.JSONDecodeError:
        return {}
    if isinstance(decoded, dict):
        return decoded
    return {}
142
+
143
+
144
# Ranking penalties for chunks whose `_hints.import_heavy` flag is set:
# `_vector_sort_key` adds the first to the distance, `_hybrid_sort_key`
# multiplies the relevance score by the second.
_IMPORT_DISTANCE_PENALTY = 0.08
_IMPORT_HYBRID_SCORE_FACTOR = 0.88

# Bonus for chunks whose declared symbols (method / field names) share tokens with
# the query. Behavioural queries like "what happens when a client message arrives"
# should float chunks containing `processClientMessage` above ones that only
# enqueue; this is a cheap, query-dependent signal computed at rank time.
_SYMBOL_MATCH_BONUS_PER_HIT = 0.03
_SYMBOL_MATCH_BONUS_CAP = 0.06

# Action verbs that typically mark behavioural entry points in this codebase.
# A chunk whose symbols begin with one of these verbs earns a small flat bump
# — again only for java chunks and only when role-filtering is off.
_ACTION_VERB_PREFIXES: tuple[str, ...] = (
    "process", "handle", "on", "pick", "select", "assign",
    "notify", "dispatch", "publish", "consume", "route",
    "trigger", "enqueue", "distribute", "update", "create",
    "apply", "resolve", "reassign", "close", "open",
)
_ACTION_VERB_BONUS = 0.02

# Type-name overlap bonus. The class name is a much stronger discovery signal
# than any individual method, because class naming in this codebase encodes
# the domain concept (`DistributionChunkService`, `OperatorSessionService`,
# `JoinOperatorController`). So we reward overlap between query tokens and the
# simple name of `primary_type_fqn` more heavily than per-method overlap, and
# we stack it on top of the existing `_symbol_bonus`.
_TYPE_MATCH_BONUS_PER_HIT = 0.05
_TYPE_MATCH_BONUS_CAP = 0.10

# Query words dropped by `_query_tokens` before symbol / type-name overlap is
# computed.
_STOPWORDS: frozenset[str] = frozenset({
    "a", "an", "the", "is", "are", "was", "were", "be", "been", "being",
    "to", "of", "in", "on", "at", "by", "for", "with", "from", "as", "or",
    "and", "but", "if", "then", "else", "when", "what", "how", "why", "does",
    "do", "did", "has", "have", "had", "this", "that", "these", "those", "it",
    "its", "new", "no", "not", "will", "would", "should", "can", "could",
    "may", "might", "happens", "happen", "happened", "get", "gets", "got",
})

# Role-aware reweighting for Java chunks. Positive values favour actionable
# behavioural code (entrypoints, orchestrators, integrations) over configuration,
# schema, and persistence stubs for "what happens when..."-style queries.
# Applied to the similarity score (higher = better); distance-based sort subtracts
# the weight. Skipped when caller filters explicitly by role.
# Roles not listed here get a weight of 0.0 (see `_role_weight`).
_ROLE_SCORE_WEIGHTS: dict[str, float] = {
    "CONTROLLER": 0.10,
    "SERVICE": 0.08,
    "CLIENT": 0.06,
    "COMPONENT": 0.03,
    "REPOSITORY": 0.02,
    "MAPPER": 0.00,
    "OTHER": 0.00,
    "ENTITY": -0.06,
    "CONFIG": -0.10,
    # DTOs are passive data carriers; they almost never answer "how/what
    # happens" queries. Penalty is slightly stronger than ENTITY so a DTO
    # with a great embedding match still loses to a mediocre SERVICE hit.
    "DTO": -0.08,
}
203
+
204
+
205
def _query_tokens(query: str) -> set[str]:
    """Extract the meaningful words of a query for symbol-overlap scoring.

    A token is a maximal run of alphabetic characters, lowercased. Tokens
    shorter than three characters and stopwords are discarded.
    """
    # Turn every non-letter into a separator, then let split() find the runs.
    spaced = "".join(ch if ch.isalpha() else " " for ch in query)
    lowered = (word.lower() for word in spaced.split())
    return {tok for tok in lowered if len(tok) >= 3 and tok not in _STOPWORDS}
227
+
228
+
229
def _split_identifier(name: str) -> list[str]:
    """Split a camelCase / PascalCase / snake_case identifier into lowercase words.

    Underscores are separators. A new word starts at an uppercase letter that
    follows a non-uppercase character, or at the last capital of an uppercase
    run when it is followed by a lowercase letter. Runs of capitals therefore
    stay together as one word instead of being shredded into single letters:

        processClientMessage -> ["process", "client", "message"]
        HTTPClient           -> ["http", "client"]
        parseXML             -> ["parse", "xml"]
        MAX_SIZE             -> ["max", "size"]

    Returns an empty list for an empty or underscore-only name.
    """
    parts: list[str] = []
    cur: list[str] = []
    length = len(name)
    for i, c in enumerate(name):
        if c == "_":
            if cur:
                parts.append("".join(cur).lower())
                cur = []
            continue
        if c.isupper() and cur:
            prev_is_upper = cur[-1].isupper()
            next_is_lower = i + 1 < length and name[i + 1].islower()
            # Inside an acronym (prev upper, next not lower) keep accumulating;
            # otherwise this capital opens a new word.
            if not prev_is_upper or next_is_lower:
                parts.append("".join(cur).lower())
                cur = []
        cur.append(c)
    if cur:
        parts.append("".join(cur).lower())
    return parts
246
+
247
+
248
def _symbol_bonus(r: dict, query_toks: set[str]) -> float:
    """Lexical ranking bonus for a java chunk; 0.0 for any other kind.

    Three parts are summed:
      * symbol overlap: per declared symbol sharing a word with the query,
        capped at `_SYMBOL_MATCH_BONUS_CAP`;
      * a flat `_ACTION_VERB_BONUS` when any symbol starts with an action verb;
      * type-name overlap on the simple name of `primary_type_fqn`, capped at
        `_TYPE_MATCH_BONUS_CAP`.
    """
    if str(r.get("_kind", "")) != "java":
        return 0.0

    declared = r.get("symbols") or []
    if isinstance(declared, str):
        # Legacy JSON-encoded list column; anything unparsable means "no symbols".
        try:
            decoded = json.loads(declared)
        except Exception:
            declared = []
        else:
            declared = decoded if isinstance(decoded, list) else []
    symbol_names = [str(entry) for entry in declared if entry]

    matching_symbols = 0
    saw_action_verb = False
    for symbol in symbol_names:
        # Drop the parameter list: "foo(int)" -> "foo".
        bare_name = symbol.split("(", 1)[0].strip()
        words = _split_identifier(bare_name) if bare_name else []
        if not words:
            continue
        if words[0] in _ACTION_VERB_PREFIXES:
            saw_action_verb = True
        if query_toks & set(words):
            matching_symbols += 1

    total = min(matching_symbols * _SYMBOL_MATCH_BONUS_PER_HIT, _SYMBOL_MATCH_BONUS_CAP)
    if saw_action_verb:
        total += _ACTION_VERB_BONUS

    # Type-name overlap is stacked on top of the symbol-level bonus.
    type_fqn = str(r.get("primary_type_fqn") or "")
    if type_fqn:
        simple_name = type_fqn.rsplit(".", 1)[-1]
        shared_words = len(query_toks & set(_split_identifier(simple_name)))
        if shared_words:
            total += min(shared_words * _TYPE_MATCH_BONUS_PER_HIT, _TYPE_MATCH_BONUS_CAP)
    return total
293
+
294
+
295
def _apply_chunk_hints(rows: list[dict]) -> None:
    """Attach a `_hints` dict (primary type hint, import-heavy flag) to each row.

    The hints come from `analyze_chunk`. For sql / yaml rows without a
    `language` value, the row kind is used as the language.
    """
    for row in rows:
        kind = str(row.get("_kind", ""))
        language = row.get("language") or ""
        if not language and kind in ("sql", "yaml"):
            language = kind
        analysis = analyze_chunk(row.get("text"), language=str(language), kind=kind)
        row["_hints"] = {
            "primary_type_hint": analysis.primary_type_hint,
            "import_heavy": analysis.import_heavy,
        }
308
+
309
+
310
def _role_weight(r: dict) -> float:
    """Return the role weight for a row, memoised in `_score_components.role_weight`.

    The weight is 0.0 for non-java rows, for rows flagged `_skip_role_weight`,
    and for roles missing from `_ROLE_SCORE_WEIGHTS`.
    """
    components = r.setdefault("_score_components", {})
    memo = components.get("role_weight")
    if memo is not None:
        return float(memo)

    is_java = str(r.get("_kind", "")) == "java"
    if r.get("_skip_role_weight") or not is_java:
        weight = 0.0
    else:
        role_name = (r.get("role") or "").upper()
        weight = _ROLE_SCORE_WEIGHTS.get(role_name, 0.0)
    components["role_weight"] = weight
    return weight
323
+
324
+
325
def _apply_symbol_bonus(rows: list[dict], query_toks: set[str]) -> None:
    """Store each row's non-zero symbol bonus in `_score_components.symbol_bonus`.

    Does nothing when the query produced no tokens. Rows flagged
    `_skip_role_weight` are left untouched: when the caller locked the role,
    that intent is respected everywhere.
    """
    if not query_toks:
        return
    eligible = (row for row in rows if not row.get("_skip_role_weight"))
    for row in eligible:
        bonus = _symbol_bonus(row, query_toks)
        if bonus:
            row.setdefault("_score_components", {})["symbol_bonus"] = bonus
336
+
337
+
338
def _vector_sort_key(r: dict) -> float:
    """Sort key for vector-mode rows: adjusted distance, smaller ranks first.

    Starts from `_distance`, adds the import-heavy penalty, then subtracts the
    role weight and any pre-computed symbol bonus. The raw distance and the
    applied penalty are recorded in `_score_components` as a side effect.
    """
    adjusted = float(r["_distance"])
    components = r.setdefault("_score_components", {})
    components["distance"] = adjusted
    import_heavy = r.get("_hints", {}).get("import_heavy")
    if import_heavy:
        components["import_penalty"] = _IMPORT_DISTANCE_PENALTY
        adjusted = adjusted + _IMPORT_DISTANCE_PENALTY
    adjusted = adjusted - _role_weight(r)
    adjusted = adjusted - float(components.get("symbol_bonus", 0.0))
    return adjusted
348
+
349
+
350
def _hybrid_sort_key(r: dict) -> float:
    """Sort key for hybrid-mode rows: negated adjusted score, so best ranks first.

    Starts from `_score` (0.0 when absent), scales it down for import-heavy
    chunks, then adds the role weight and any pre-computed symbol bonus. The
    raw score and the applied factor are recorded in `_score_components`.
    """
    score = float(r.get("_score", 0.0))
    components = r.setdefault("_score_components", {})
    components["hybrid_rrf"] = score
    import_heavy = r.get("_hints", {}).get("import_heavy")
    if import_heavy:
        components["import_penalty"] = _IMPORT_HYBRID_SCORE_FACTOR
        score = score * _IMPORT_HYBRID_SCORE_FACTOR
    score = score + _role_weight(r)
    score = score + float(components.get("symbol_bonus", 0.0))
    return -score
360
+
361
+
362
def explain_score_components(
    comps: dict[str, float] | None,
    *,
    role: str | None = None,
    hybrid: bool = False,
    graph_expanded: bool = False,
) -> str:
    """Render `_score_components` as a short, space-separated 'why' string.

    Fields appear in a fixed order: the base score (`rrf=` in hybrid mode,
    `dist=` otherwise), then role weight, symbol bonus and import penalty when
    they are non-zero, then the literal `graph` for graph-expanded hits.
    Returns "" when there is nothing to report.
    """
    values = comps or {}
    pieces: list[str] = []

    base_key, base_fmt = ("hybrid_rrf", "rrf={:.3f}") if hybrid else ("distance", "dist={:.2f}")
    base = values.get(base_key)
    if base is not None:
        pieces.append(base_fmt.format(float(base)))

    role_weight = values.get("role_weight")
    if role_weight:
        prefix = f"role:{role}" if role else "role"
        pieces.append(f"{prefix}:{float(role_weight):+.02f}")

    for comp_key, shown_as in (("symbol_bonus", "symbol"), ("import_penalty", "import_penalty")):
        amount = values.get(comp_key)
        if amount:
            pieces.append(f"{shown_as}:{float(amount):+.02f}")

    if graph_expanded:
        pieces.append("graph")
    return " ".join(pieces)
399
+
400
+
401
def l2_distance_to_score(distance: float) -> float:
    """Convert an L2 distance into a similarity score.

    Computes ``1 - d^2 / 2``; the formula is intended for unit-normalized
    embeddings.
    """
    half_squared = distance * distance / 2.0
    return 1.0 - half_squared
404
+
405
+
406
def _escape_like_fragment(s: str) -> str:
    """Double single quotes so ``s`` can be placed inside a quoted SQL LIKE pattern."""
    return "''".join(s.split("'"))
408
+
409
+
410
def _escape_sql_like_pattern(s: str) -> str:
    """Backslash-escape the LIKE metacharacters (``\\``, ``%``, ``_``) in ``s``."""
    special = ("\\", "%", "_")
    return "".join("\\" + ch if ch in special else ch for ch in s)
418
+
419
+
420
def _build_path_predicate(path_substring: str) -> str:
    """Build a ``filename LIKE '%...%'`` predicate matching ``path_substring`` literally.

    LIKE metacharacters are escaped first, then single quotes are doubled for
    the SQL string literal; backslash is declared as the escape character.
    """
    literal = _escape_like_fragment(_escape_sql_like_pattern(path_substring))
    return f"filename LIKE '%{literal}%' ESCAPE '\\'"
424
+
425
+
426
def ensure_text_fts_index(uri: str, lance_table_name: str) -> None:
    """Create the full-text index on ``text`` once per (uri, table) per process.

    Errors whose message suggests the index already exists are ignored; any
    other error is re-raised and the table is not marked as ready.
    """
    ready_key = (uri, lance_table_name)
    already_there_markers = ("exist", "duplicate", "already", "same name")
    with _FTS_LOCK:
        if ready_key in _FTS_READY:
            return
        table = lancedb.connect(uri).open_table(lance_table_name)
        try:
            table.create_fts_index("text", replace=False)
        except Exception as err:
            message = str(err).lower()
            # Only an "index already exists" style failure is tolerated.
            if not any(marker in message for marker in already_there_markers):
                raise
        _FTS_READY.add(ready_key)
445
+
446
+
447
def _query_vector(model: SentenceTransformer, text: str) -> np.ndarray:
    """Encode ``text`` with ``model`` and return the embedding as a float32 array.

    The encoder is asked for normalized embeddings with the progress bar off.
    """
    encode_options = {
        "convert_to_numpy": True,
        "normalize_embeddings": True,
        "show_progress_bar": False,
    }
    embedding = model.encode(text, **encode_options)
    return np.asarray(embedding, dtype=np.float32)
455
+
456
+
457
def _combine_predicates(parts: list[str | None]) -> str | None:
    """AND together the non-empty predicates in ``parts``.

    Returns None when nothing is left, the lone predicate unchanged when there
    is exactly one, and otherwise each predicate parenthesised and joined
    with ``AND``.
    """
    kept = list(filter(None, parts))
    if len(kept) > 1:
        return " AND ".join("(" + pred + ")" for pred in kept)
    return kept[0] if kept else None
464
+
465
+
466
def _search_one_table(
    table_name: str,
    *,
    uri: str,
    db: object,
    query_vec: np.ndarray,
    limit: int,
    path_predicate: str | None,
    kind: str,
    hybrid: bool,
    fts_text: str | None,
    extra_predicates: list[str] | None = None,
) -> list[dict]:
    """Run one vector or hybrid query against a single LanceDB table.

    Args:
        table_name: LanceDB table to open on ``db``.
        uri: Index location; used as cache key for schema / FTS bookkeeping.
        db: Open LanceDB connection.
        query_vec: Query embedding.
        limit: Maximum number of rows to request.
        path_predicate: Optional filename filter, AND-ed with the others.
        kind: Logical table key ("java", "sql", "yaml"); stored on each row as
            ``_kind``. Only "java" selects ``language`` and enrichment columns.
        hybrid: When true, run a hybrid (vector + full-text) query.
        fts_text: Full-text query for hybrid mode; None is sent as "".
        extra_predicates: Additional SQL predicates to AND in.

    Returns:
        Result rows as dicts, tagged with ``_kind`` and ``_hybrid``. Hybrid
        rows carry ``_score`` (taken from ``_relevance_score`` when present);
        vector rows carry ``_distance``. ``start`` / ``end`` are normalised
        to dicts.
    """
    tbl = db.open_table(table_name)
    has_lang = kind == "java"
    table_cols = _table_columns(uri, table_name, db)
    enriched_cols = table_cols if has_lang else set()
    # `range_start` / `range_end` are needed downstream by `_attach_neighbor_context`
    # to locate the chunk inside its file; select them whenever the schema has them.
    base_cols = ["filename", "text", "start", "end"]
    for col in ("range_start", "range_end"):
        if col in table_cols:
            base_cols.append(col)
    # Only request enrichment columns that this index actually has.
    java_extra = [c for c in JAVA_ENRICHED_COLUMNS if c in enriched_cols] if has_lang else []
    combined_pred = _combine_predicates([path_predicate, *(extra_predicates or [])])

    if hybrid:
        # Hybrid search needs a full-text index on `text`; create it on first use.
        ensure_text_fts_index(uri, table_name)
        text_for_fts = fts_text if fts_text is not None else ""
        columns = (
            [*base_cols, "language", *java_extra]
            if has_lang
            else [*base_cols]
        )
        q = (
            tbl.search(
                query_type="hybrid",
                vector_column_name=VECTOR_COLUMN,
            )
            .vector(query_vec)
            .text(text_for_fts)
            .select(columns)
            .limit(limit)
        )
        if combined_pred:
            q = q.where(combined_pred, prefilter=True)
        rows = q.to_list()
        for r in rows:
            r["_kind"] = kind
            # Expose the relevance score under the name the sort keys read.
            rs = r.pop("_relevance_score", None)
            r["_hybrid"] = True
            if rs is not None:
                r["_score"] = float(rs)
            r["start"] = coerce_position_field(r.get("start"))
            r["end"] = coerce_position_field(r.get("end"))
        return rows

    # Plain vector search: same columns plus the `_distance` of each hit.
    columns = (
        [*base_cols, "language", *java_extra, "_distance"]
        if has_lang
        else [*base_cols, "_distance"]
    )
    q = tbl.search(query_vec, vector_column_name=VECTOR_COLUMN).select(
        columns
    ).limit(limit)
    if combined_pred:
        q = q.where(combined_pred, prefilter=True)
    rows = q.to_list()
    for r in rows:
        r["_kind"] = kind
        r["_hybrid"] = False
        r["start"] = coerce_position_field(r.get("start"))
        r["end"] = coerce_position_field(r.get("end"))
    return rows
540
+
541
+
542
def _debug_ctx(msg: str) -> None:
    """Print a context-expansion diagnostic when JAVA_CODEBASE_RAG_DEBUG_CONTEXT is set.

    Output goes to stderr so it doesn't pollute MCP stdout. With the variable
    unset or empty this is a no-op.
    """
    if not os.environ.get("JAVA_CODEBASE_RAG_DEBUG_CONTEXT"):
        return
    print(f"[context_neighbors] {msg}", file=sys.stderr)
549
+
550
+
551
def _locate_chunk_index(bucket: list[dict], row: dict, *, has_range: bool) -> int | None:
    """Return the position of ``row`` inside its file's chunk list, or None.

    Tries an exact (range_start, range_end) match first when the index carries
    ranges and the row has both values, then falls back to exact text match.
    A range value of 0 is a valid offset and is compared as such.
    """
    if has_range:
        start, end = row.get("range_start"), row.get("range_end")
        if start is not None and end is not None:
            wanted = (int(start), int(end))
            for i, chunk in enumerate(bucket):
                c_start, c_end = chunk.get("range_start"), chunk.get("range_end")
                if c_start is None or c_end is None:
                    continue
                if (int(c_start), int(c_end)) == wanted:
                    return i

    row_text = str(row.get("text") or "")
    if row_text:
        for i, chunk in enumerate(bucket):
            if str(chunk.get("text") or "") == row_text:
                return i
    return None


def _attach_neighbor_context(
    rows: list[dict], *, db: object, neighbors: int, uri: str | None = None,
) -> None:
    """Populate `_context_before` / `_context_after` with adjacent Java chunk text.

    Strategy (in order):
      1. Schema-aware scan of the java table, selecting only columns that exist
         (`filename` + `text` always; `range_start`/`range_end` when present).
      2. Sort the per-file bucket by `range_start` if available; otherwise keep
         the table's natural order (good enough because chunks are produced in
         file order by CocoIndex).
      3. Locate each row's index via (a) range tuple match, (b) exact text match
         as fallback. Missing both -> log and skip.
      4. Any exception is logged (behind env flag) and the field stays empty; we
         never break search because of context expansion.

    Args:
        rows: Result rows; only those with ``_kind == "java"`` are touched,
            and they are modified in place.
        db: Open LanceDB connection.
        neighbors: Number of chunks to attach on each side; <= 0 disables.
        uri: Index location; when given, the schema lookup goes through the
            `_table_columns` cache.
    """
    if neighbors <= 0:
        return
    java_rows = [r for r in rows if str(r.get("_kind", "")) == "java"]
    if not java_rows:
        _debug_ctx("no java rows in window; nothing to expand")
        return
    filenames = {str(r.get("filename", "")) for r in java_rows if r.get("filename")}
    if not filenames:
        _debug_ctx("java rows had no filename field; skipping")
        return

    java_table = TABLES["java"]
    try:
        tbl = db.open_table(java_table)
    except Exception as exc:
        _debug_ctx(f"open_table({java_table}) failed: {exc!r}")
        return

    # Discover which positional columns the index actually carries. Older
    # indexes may predate `range_start`/`range_end`; newer ones always have
    # them. Asking for a missing column makes the whole scan fail.
    try:
        schema_cols = _table_columns(uri, java_table, db) if uri else {f.name for f in tbl.schema}
    except Exception as exc:
        _debug_ctx(f"schema lookup failed: {exc!r}")
        schema_cols = set()

    has_range = {"range_start", "range_end"}.issubset(schema_cols)
    scan_cols = ["filename", "text"]
    if has_range:
        scan_cols.extend(("range_start", "range_end"))

    try:
        in_list = ", ".join(f"'{_escape_sql_str(f)}'" for f in filenames)
        scanner = tbl.to_lance().scanner(
            filter=f"filename IN ({in_list})",
            columns=scan_cols,
        )
        all_chunks = scanner.to_table().to_pylist()
    except Exception as exc:
        _debug_ctx(f"bucket scan failed (cols={scan_cols}): {exc!r}")
        return

    if not all_chunks:
        _debug_ctx(f"bucket scan returned 0 chunks for {len(filenames)} filenames")
        return

    by_file: dict[str, list[dict]] = {}
    for ch in all_chunks:
        by_file.setdefault(str(ch.get("filename", "")), []).append(ch)
    if has_range:
        for lst in by_file.values():
            lst.sort(
                key=lambda c: (int(c.get("range_start") or 0), int(c.get("range_end") or 0))
            )

    attached = 0
    for r in java_rows:
        fn = str(r.get("filename", ""))
        bucket = by_file.get(fn, [])
        if not bucket:
            _debug_ctx(f"no bucket for filename={fn!r}")
            continue

        try:
            idx = _locate_chunk_index(bucket, r, has_range=has_range)
        except (TypeError, ValueError) as exc:
            # Non-numeric range values: treat as "not found" rather than
            # failing the whole search.
            _debug_ctx(f"range comparison failed for file={fn!r}: {exc!r}")
            idx = None

        if idx is None:
            _debug_ctx(
                f"could not locate chunk in bucket (file={fn!r}, "
                f"has_range={has_range}, bucket_size={len(bucket)})"
            )
            continue

        before_parts = [str(c.get("text") or "") for c in bucket[max(0, idx - neighbors):idx]]
        after_parts = [str(c.get("text") or "") for c in bucket[idx + 1 : idx + 1 + neighbors]]
        r["_context_before"] = "\n".join(before_parts)
        r["_context_after"] = "\n".join(after_parts)
        attached += 1

    _debug_ctx(f"attached context to {attached}/{len(java_rows)} java rows")
667
+
668
+
669
def _graph_expand_merge(
    vector_rows: list[dict],
    *,
    query_vec: np.ndarray,
    db: object,
    uri: str,
    limit: int,
    extra_predicates: list[str],
    expand_depth: int,
    kuzu_path: str | None,
) -> list[dict]:
    """Expand vector top-k through the Kuzu graph and fuse (RRF) with the original list.

    The `primary_type_fqn` values of ``vector_rows`` seed a graph expansion.
    Types reached that are not already seeds are searched in the java table
    (restricted by ``extra_predicates``), and the two ranked lists are merged
    with `_rrf_merge`, weighting each graph row by its expansion confidence.

    This is best-effort: if the graph module cannot be imported, the graph
    does not exist, there are no seeds or no novel neighbours, or any graph /
    search call fails, ``vector_rows`` is returned unchanged.

    Args:
        vector_rows: Already-ranked rows from the primary search.
        query_vec: Query embedding, reused for the neighbour search.
        db: Open LanceDB connection.
        uri: Index location.
        limit: Requested result count; the neighbour search asks for at least 20.
        extra_predicates: Filters from the primary search, also applied here.
        expand_depth: Depth passed to the graph expansion calls.
        kuzu_path: Location of the Kuzu graph.

    Returns:
        The fused list, or ``vector_rows`` when expansion is not possible.
    """
    # Lazy import so the module works without kuzu installed when graph_expand=False.
    try:
        from kuzu_queries import KuzuGraph
    except Exception:
        return vector_rows

    if not KuzuGraph.exists(kuzu_path):
        return vector_rows

    seed_set = {r.get("primary_type_fqn") for r in vector_rows if r.get("primary_type_fqn")}
    if not seed_set:
        return vector_rows
    seed_fqns = sorted(seed_set)

    try:
        graph = KuzuGraph.get(kuzu_path)
        structural = graph.expand_fqns(seed_fqns, depth=expand_depth)
        method_pairs = graph.expand_methods(
            seed_fqns, depth=expand_depth, exclude_external=True,
        )
        # Per-type weight: structural neighbours count fully (1.0), method
        # neighbours by their confidence; the strongest signal wins.
        expand_weight_by_fqn: dict[str, float] = {}
        for f in structural:
            if f:
                expand_weight_by_fqn[f] = max(expand_weight_by_fqn.get(f, 0.0), 1.0)
        for f, conf in method_pairs:
            if f:
                expand_weight_by_fqn[f] = max(expand_weight_by_fqn.get(f, 0.0), conf)
        # Order-preserving de-duplication of all neighbours.
        neighbor_fqns = list(dict.fromkeys(
            list(structural) + [f for f, _ in method_pairs],
        ))
    except Exception:
        return vector_rows

    # Membership is tested against the seed set built once above, not a set
    # rebuilt for every neighbour.
    novel = [fqn for fqn in neighbor_fqns if fqn and fqn not in seed_set]
    if not novel:
        return vector_rows

    extra = list(extra_predicates)
    extra.extend(_build_extra_predicates(
        columns=_table_columns(uri, TABLES["java"], db),
        role=None, module=None, microservice=None,
        package_prefix=None, fqn_in=novel,
    ))

    try:
        graph_rows = _search_one_table(
            TABLES["java"],
            uri=uri, db=db, query_vec=query_vec,
            limit=max(limit, 20),
            path_predicate=None, kind="java",
            hybrid=False, fts_text=None,
            extra_predicates=extra,
        )
    except Exception:
        return vector_rows
    _apply_chunk_hints(graph_rows)
    graph_rows.sort(key=_vector_sort_key)
    for r in graph_rows:
        r["_graph_expanded"] = True
        r["_graph_expand_weight"] = expand_weight_by_fqn.get(
            r.get("primary_type_fqn"), 1.0,
        )
    fused = _rrf_merge(
        [vector_rows, graph_rows],
        row_weight_for_list_index=[
            None,
            lambda row: float(row.get("_graph_expand_weight", 1.0)),
        ],
    )
    return fused
750
+
751
+
752
def _rrf_merge(
    lists: list[list[dict]],
    *,
    k: int = 60,
    row_weight_for_list_index: list[Callable[[dict], float] | None] | None = None,
) -> list[dict]:
    """Reciprocal-rank-fuse several ranked lists of chunk rows.

    Rows are deduplicated by (filename, range_start, range_end). When a row
    has neither range field (indexes that predate those columns), the chunk
    text is used to tell chunks of the same file apart, so distinct chunks are
    not collapsed into one. The first occurrence of a chunk is the row kept.
    The merged rows get a `_rrf_score` field so callers can inspect or re-sort.

    When ``row_weight_for_list_index`` is set, entry ``i`` applies to
    ``lists[i]``: a non-None entry is a callable ``row -> weight`` multiplied
    into that list's rank contribution (``None``, or a missing entry, means
    weight ``1.0`` for every row).

    Args:
        lists: Ranked lists, best row first. Rows are mutated in place.
        k: RRF smoothing constant; rank ``r`` (0-based) contributes
            ``1 / (k + r + 1)``.
        row_weight_for_list_index: Optional per-list weight callables.

    Returns:
        Deduplicated rows sorted by descending `_rrf_score`.
    """

    def _dedup_key(row: dict) -> tuple:
        range_start, range_end = row.get("range_start"), row.get("range_end")
        if range_start is None and range_end is None:
            # No positional identity available: fall back to the chunk text.
            # The 4-tuple shape can never collide with a range-based key.
            return (row.get("filename"), None, None, row.get("text"))
        return (row.get("filename"), range_start, range_end)

    pool: dict[tuple, dict] = {}
    for li, ranked in enumerate(lists):
        wfn: Callable[[dict], float] | None = None
        if row_weight_for_list_index is not None and li < len(row_weight_for_list_index):
            wfn = row_weight_for_list_index[li]
        for rank, row in enumerate(ranked):
            key = _dedup_key(row)
            weight = 1.0 if wfn is None else float(wfn(row))
            contribution = weight * (1.0 / (k + rank + 1))
            existing = pool.get(key)
            if existing is None:
                row["_rrf_score"] = contribution
                pool[key] = row
            else:
                existing["_rrf_score"] = float(existing.get("_rrf_score", 0.0)) + contribution
    merged = list(pool.values())
    merged.sort(key=lambda r: -float(r.get("_rrf_score", 0.0)))
    return merged
785
+
786
+
787
def run_search(
    query: str,
    *,
    uri: str,
    table_keys: list[str],
    limit: int,
    path_substring: str | None,
    model_name: str,
    device: str | None,
    offset: int = 0,
    model: SentenceTransformer | None = None,
    hybrid: bool = False,
    fts_text: str | None = None,
    auto_hybrid: bool = False,
    role: str | None = None,
    module: str | None = None,
    microservice: str | None = None,
    package_prefix: str | None = None,
    graph_expand: bool = False,
    expand_depth: int = 1,
    kuzu_path: str | None = None,
    context_neighbors: int = 0,
    role_in: list[str] | None = None,
    exclude_roles: list[str] | None = None,
    capability: str | None = None,
    capability_in: list[str] | None = None,
) -> list[dict]:
    """Search one or more index tables and return the re-ranked result window.

    With a single table the query runs in vector or hybrid mode, is re-ranked
    with chunk hints / role weights / symbol bonus, optionally fused with
    graph-expanded neighbours (java only), and sliced to
    ``rows[offset : offset + limit]``. With several tables each one is queried
    in vector mode and the combined rows are re-ranked together before slicing.

    Args:
        query: Natural-language (or identifier) query; embedded with the model.
        uri: LanceDB location.
        table_keys: Keys of `TABLES` to search.
        limit: Page size.
        path_substring: Optional substring the filename must contain.
        model_name: Model to load when ``model`` is not supplied.
        device: Device passed to the model constructor.
        offset: Number of ranked rows to skip.
        model: Pre-loaded encoder; avoids loading one per call.
        hybrid: Force hybrid (vector + full-text) search; single table only.
        fts_text: Full-text query for hybrid mode; defaults to ``query``.
        auto_hybrid: Switch to hybrid automatically when a single table is
            searched and `looks_like_code_identifier(query)` is true.
        role, module, microservice, package_prefix, role_in, exclude_roles,
        capability, capability_in: Metadata filters, applied to the java table
            only (see `_build_extra_predicates`).
        graph_expand: Fuse in graph neighbours (single java table only).
        expand_depth: Graph expansion depth; must be > 0 to have any effect.
        kuzu_path: Location of the Kuzu graph.
        context_neighbors: Number of adjacent chunks to attach per java hit.

    Returns:
        The requested window of ranked rows.

    Raises:
        ValueError: If hybrid search is in effect and ``table_keys`` does not
            contain exactly one table.
    """
    effective_hybrid = hybrid
    effective_fts = fts_text
    if (
        auto_hybrid
        and not hybrid
        and len(table_keys) == 1
        and looks_like_code_identifier(query)
    ):
        # Identifier-like queries benefit from exact-term matching.
        effective_hybrid = True
        if effective_fts is None:
            effective_fts = query.strip()

    if effective_hybrid and len(table_keys) != 1:
        raise ValueError(
            "hybrid search requires exactly one table; "
            "use table java, sql, or yaml (not all)."
        )

    path_predicate = (
        _build_path_predicate(path_substring) if path_substring else None
    )

    if model is None:
        # NOTE(review): trust_remote_code=True presumably allows the model
        # repository to run its own code at load time — confirm that
        # `model_name` always comes from a trusted source.
        model = SentenceTransformer(
            model_name,
            device=device,
            trust_remote_code=True,
        )
    query_vec = _query_vector(model, query)
    fts_for_hybrid = effective_fts if effective_fts is not None else query

    db = lancedb.connect(uri)
    # Fetch enough rows to cover the requested page; never ask for zero.
    need = max(limit + offset, 1)

    # Metadata filters only exist on the java table.
    extra_java = _build_extra_predicates(
        columns=_table_columns(uri, TABLES["java"], db),
        role=role, module=module, microservice=microservice,
        package_prefix=package_prefix, fqn_in=None,
        role_in=role_in, exclude_roles=exclude_roles,
        capability=capability, capability_in=capability_in,
    ) if "java" in table_keys else []

    # An explicit role filter disables role reweighting and the symbol bonus.
    skip_role_weight = bool(role or role_in)
    query_toks = _query_tokens(query)

    if len(table_keys) == 1:
        key = table_keys[0]
        preds = extra_java if key == "java" else []
        rows = _search_one_table(
            TABLES[key],
            uri=uri,
            db=db,
            query_vec=query_vec,
            limit=need,
            path_predicate=path_predicate,
            kind=key,
            hybrid=effective_hybrid,
            fts_text=fts_for_hybrid,
            extra_predicates=preds,
        )
        _apply_chunk_hints(rows)
        if skip_role_weight:
            for r in rows:
                r["_skip_role_weight"] = True
        _apply_symbol_bonus(rows, query_toks)
        if effective_hybrid:
            rows.sort(key=_hybrid_sort_key)
        else:
            rows.sort(key=_vector_sort_key)

        if graph_expand and key == "java" and expand_depth > 0:
            rows = _graph_expand_merge(
                rows,
                query_vec=query_vec,
                db=db,
                uri=uri,
                limit=need,
                extra_predicates=extra_java,
                expand_depth=expand_depth,
                kuzu_path=kuzu_path,
            )

        window = rows[offset : offset + limit]
        if context_neighbors > 0 and key == "java":
            _attach_neighbor_context(window, db=db, neighbors=context_neighbors, uri=uri)
        return window

    # Multi-table path: vector search only, then one combined re-rank.
    merged: list[dict] = []
    # Over-fetch per table, presumably so the combined re-rank has spare
    # candidates. Since `need` >= 1, the max() always resolves to need * 3.
    per_table = max(need * 3, need)
    for key in table_keys:
        preds = extra_java if key == "java" else []
        merged.extend(
            _search_one_table(
                TABLES[key],
                uri=uri,
                db=db,
                query_vec=query_vec,
                limit=per_table,
                path_predicate=path_predicate,
                kind=key,
                hybrid=False,
                fts_text=None,
                extra_predicates=preds,
            )
        )
    _apply_chunk_hints(merged)
    if skip_role_weight:
        for r in merged:
            r["_skip_role_weight"] = True
    _apply_symbol_bonus(merged, query_toks)
    merged.sort(key=_vector_sort_key)
    window = merged[offset : offset + limit]
    if context_neighbors > 0:
        # Only java rows in the window receive context; others are ignored there.
        _attach_neighbor_context(window, db=db, neighbors=context_neighbors, uri=uri)
    return window
929
+
930
+
931
def _cli_build_parser() -> argparse.ArgumentParser:
    """Build the argument parser for the command-line search tool."""
    parser = argparse.ArgumentParser(
        description="Vector search in LanceDB index.",
    )
    parser.add_argument("query", help="Natural-language search query")
    parser.add_argument(
        "--table",
        choices=["java", "sql", "yaml", "all"],
        default="java",
    )
    parser.add_argument("--limit", type=int, default=10)
    parser.add_argument(
        "--lancedb-uri",
        # An unset or empty env var falls back to ./.java-codebase-rag under the cwd.
        default=os.environ.get("JAVA_CODEBASE_RAG_INDEX_DIR", "")
        or str((Path.cwd() / ".java-codebase-rag").resolve()),
    )
    parser.add_argument("--path-contains", metavar="SUBSTR", default=None)
    parser.add_argument(
        "--model",
        default=None,
        help=(
            "sentence-transformers hub id or local model directory "
            f"(default: SBERT_MODEL env or {SBERT_MODEL!r})"
        ),
    )
    parser.add_argument("--device", default=None)
    parser.add_argument("--text-width", type=int, default=320)
    parser.add_argument("--hybrid", action="store_true")
    parser.add_argument("--fts-text", metavar="TEXT", default=None)
    parser.add_argument("--auto-hybrid", action="store_true")
    parser.add_argument("--role", default=None)
    parser.add_argument(
        "--module",
        default=None,
        help="Filter to a single Maven/Gradle module name.",
    )
    parser.add_argument(
        "--microservice",
        default=None,
        help="Filter to a single deployable microservice (top-level dir under project root).",
    )
    parser.add_argument("--package-prefix", default=None)
    parser.add_argument("--graph-expand", action="store_true")
    parser.add_argument("--expand-depth", type=int, default=1)
    parser.add_argument("--kuzu-path", default=None)
    parser.add_argument(
        "--context-neighbors",
        type=int,
        default=0,
        help="Attach N adjacent chunks per hit as surrounding context (Java only).",
    )
    return parser


def _cli_preview(text: str, width: int) -> str:
    """Return *text* flattened to one line and truncated to at most *width* chars.

    Text longer than *width* is cut and suffixed with ``...``. When *width* is
    below 3 there is no room for the ellipsis, so the text is hard-cut instead;
    a plain ``text[: width - 3]`` would use a negative end index there and
    produce a preview longer than requested.
    """
    flat = text.replace("\n", " ")
    if len(flat) <= width:
        return flat
    if width < 3:
        return flat[: max(width, 0)]
    return flat[: width - 3] + "..."


def _cli_hit_header(index: int, row: dict) -> str:
    """Format the one-line header printed above the preview of a search hit.

    Raises:
        KeyError: if *row* lacks ``_kind`` or ``filename``, or lacks
            ``_distance`` when the row is not flagged ``_hybrid``.
    """
    kind = row["_kind"]
    filename = row["filename"]
    language = row.get("language", "—")

    # Line range suffix; the end line falls back to the start line when absent.
    start = row.get("start") or {}
    end = row.get("end") or {}
    line_hint = ""
    if isinstance(start, dict) and "line" in start:
        if isinstance(end, dict) and "line" in end:
            end_line = end["line"]
        else:
            end_line = start["line"]
        line_hint = f" L{start['line']}-{end_line}"

    if row.get("_hybrid"):
        rank = f"hybrid RRF={float(row.get('_score', 0.0)):.4f}"
    else:
        rank = f"L2 distance={float(row['_distance']):.4f}"

    # Optional annotations, emitted only when the value is present and truthy.
    notes: list[str] = []
    hints = row.get("_hints") or {}
    if hints.get("primary_type_hint"):
        notes.append(f"type:{hints['primary_type_hint']}")
    if hints.get("import_heavy"):
        notes.append("mostly-imports")
    role = row.get("role") or ""
    if role:
        notes.append(f"role:{role}")
    microservice = row.get("microservice") or ""
    if microservice:
        notes.append(f"microservice:{microservice}")
    module = row.get("module") or ""
    # Skip the module when it would just repeat the microservice name.
    if module and module != microservice:
        notes.append(f"module:{module}")
    components = row.get("_score_components") or {}
    role_weight = components.get("role_weight")
    if role_weight:
        notes.append(f"role_weight:{role_weight:+.2f}")
    symbol_bonus = components.get("symbol_bonus")
    if symbol_bonus:
        notes.append(f"symbol_bonus:{symbol_bonus:+.2f}")
    if row.get("_graph_expanded"):
        notes.append("graph")
    suffix = "".join(f" | {note}" for note in notes)

    return f"--- {index}. [{kind}] {rank} | {filename}{line_hint} | lang={language}{suffix}"


def main() -> None:
    """Command-line entry point: parse arguments, run the search, print hits.

    Exits with status 1 when the LanceDB path does not exist or the search
    raises, and with status 2 when ``--hybrid`` or ``--auto-hybrid`` is
    combined with ``--table all``. Prints ``No results.`` when the search
    returns nothing.
    """
    args = _cli_build_parser().parse_args()

    uri_path = Path(args.lancedb_uri)
    if not uri_path.exists():
        print(f"Error: LanceDB path missing: {uri_path.resolve()}", file=sys.stderr)
        sys.exit(1)

    keys = list(TABLES) if args.table == "all" else [args.table]
    if args.table == "all":
        # The CLI rejects both hybrid modes when all tables are searched at once.
        for flag, enabled in (("--hybrid", args.hybrid), ("--auto-hybrid", args.auto_hybrid)):
            if enabled:
                print(f"Error: {flag} needs a single --table.", file=sys.stderr)
                sys.exit(2)

    # A missing or blank --model defers to the process-environment resolution.
    raw_model = args.model
    if raw_model is None or not str(raw_model).strip():
        model_name = resolved_sbert_model_for_process_env(SBERT_MODEL)
    else:
        model_name = maybe_expand_embedding_model_path(str(raw_model).strip())

    try:
        results = run_search(
            args.query,
            uri=str(uri_path),
            table_keys=keys,
            limit=args.limit,
            path_substring=args.path_contains,
            model_name=model_name,
            device=args.device,
            hybrid=args.hybrid,
            fts_text=args.fts_text,
            auto_hybrid=args.auto_hybrid,
            role=args.role,
            module=args.module,
            microservice=args.microservice,
            package_prefix=args.package_prefix,
            graph_expand=args.graph_expand,
            expand_depth=args.expand_depth,
            kuzu_path=args.kuzu_path,
            context_neighbors=args.context_neighbors,
        )
    except Exception as e:
        # Top-level CLI boundary: report any failure as a message, not a traceback.
        print(f"Search failed: {e}", file=sys.stderr)
        sys.exit(1)

    if not results:
        print("No results.")
        return

    for position, row in enumerate(results, start=1):
        print(_cli_hit_header(position, row))
        print(_cli_preview(row.get("text") or "", args.text_width))
        print()
1072
+
1073
+
1074
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()