java-codebase-rag 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ast_java.py +2813 -0
- brownfield_events.py +58 -0
- build_ast_graph.py +3081 -0
- chunk_heuristics.py +62 -0
- graph_enrich.py +1681 -0
- index_common.py +10 -0
- java_codebase_rag/__init__.py +1 -0
- java_codebase_rag/cli.py +761 -0
- java_codebase_rag/cli_progress.py +52 -0
- java_codebase_rag/config.py +327 -0
- java_codebase_rag/pipeline.py +189 -0
- java_codebase_rag-0.1.0.dist-info/METADATA +818 -0
- java_codebase_rag-0.1.0.dist-info/RECORD +27 -0
- java_codebase_rag-0.1.0.dist-info/WHEEL +5 -0
- java_codebase_rag-0.1.0.dist-info/entry_points.txt +3 -0
- java_codebase_rag-0.1.0.dist-info/licenses/LICENSE +21 -0
- java_codebase_rag-0.1.0.dist-info/top_level.txt +17 -0
- java_index_flow_lancedb.py +398 -0
- java_index_v1_common.py +33 -0
- java_ontology.py +446 -0
- kuzu_queries.py +1989 -0
- mcp_hints.py +748 -0
- mcp_v2.py +1957 -0
- path_filtering.py +472 -0
- pr_analysis.py +534 -0
- search_lancedb.py +1075 -0
- server.py +578 -0
search_lancedb.py
ADDED
|
@@ -0,0 +1,1075 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""Semantic search over LanceDB tables built by CocoIndex (java_index_flow_lancedb)."""
|
|
3
|
+
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
import argparse
|
|
7
|
+
import json
|
|
8
|
+
import os
|
|
9
|
+
import sys
|
|
10
|
+
import threading
|
|
11
|
+
from collections.abc import Callable
|
|
12
|
+
from pathlib import Path
|
|
13
|
+
|
|
14
|
+
import lancedb
|
|
15
|
+
import numpy as np
|
|
16
|
+
from sentence_transformers import SentenceTransformer
|
|
17
|
+
|
|
18
|
+
from chunk_heuristics import analyze_chunk, looks_like_code_identifier
|
|
19
|
+
from index_common import SBERT_MODEL
|
|
20
|
+
from java_codebase_rag.config import maybe_expand_embedding_model_path, resolved_sbert_model_for_process_env
|
|
21
|
+
|
|
22
|
+
# Logical corpus name -> LanceDB table holding that corpus's chunks.
TABLES: dict[str, str] = {
    "java": "javacodeindex_java_code",
    "sql": "sqlschemaindex_sql_schema",
    "yaml": "yamlconfigindex_yaml_config",
}

# Optional enrichment columns on the java chunk table (absent on older indexes).
# Search selects only the ones actually present in the table's schema.
JAVA_ENRICHED_COLUMNS: tuple[str, ...] = (
    "package",
    "module",
    "microservice",
    "primary_type_fqn",
    "primary_type_kind",
    "role",
    "annotations_on_type",
    "symbols",
    "symbol_id",
    "metadata",
    "ontology_version",
    "capabilities",
)

# Column holding the chunk embedding used for vector and hybrid search.
VECTOR_COLUMN = "embedding"
# (uri, table name) pairs whose `text` full-text index has already been
# ensured in this process; guarded by _FTS_LOCK.
_FTS_READY: set[tuple[str, str]] = set()
_FTS_LOCK = threading.Lock()
# (uri, table name) -> set of column names, memoized per process; guarded by
# _SCHEMA_LOCK. NOTE(review): no eviction is visible here, so a schema change
# made while the process runs is presumably not picked up — confirm.
_SCHEMA_CACHE: dict[tuple[str, str], set[str]] = {}
_SCHEMA_LOCK = threading.Lock()
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def _table_columns(uri: str, lance_table_name: str, db_obj: object | None = None) -> set[str]:
    """Return the column names of a LanceDB table, memoized per (uri, table).

    ``db_obj`` lets callers reuse an already-open connection; when it is None a
    fresh connection to ``uri`` is opened. The schema read happens outside the
    lock, so two threads racing on a cold key may both read it; the last write
    wins and both results are equal.
    """
    cache_key = (uri, lance_table_name)
    with _SCHEMA_LOCK:
        hit = _SCHEMA_CACHE.get(cache_key)
    if hit is not None:
        return hit

    connection = lancedb.connect(uri) if db_obj is None else db_obj
    table_schema = connection.open_table(lance_table_name).schema
    names = set(field.name for field in table_schema)

    with _SCHEMA_LOCK:
        _SCHEMA_CACHE[cache_key] = names
    return names
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def _escape_sql_str(s: str) -> str:
    """Double every single quote so *s* can sit inside a '...' SQL string literal."""
    return "''".join(s.split("'"))
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def _build_extra_predicates(
    *,
    columns: set[str],
    role: str | None,
    module: str | None,
    microservice: str | None,
    package_prefix: str | None,
    fqn_in: list[str] | None,
    role_in: list[str] | None = None,
    exclude_roles: list[str] | None = None,
    capability: str | None = None,
    capability_in: list[str] | None = None,
) -> list[str]:
    """Translate optional metadata filters into SQL ``WHERE`` fragments.

    A filter is emitted only when its argument is truthy AND its backing
    column is listed in ``columns``. Filters on columns the table lacks are
    silently dropped. String values are single-quote escaped. The returned
    fragments are independent and meant to be AND-ed together by the caller.

    Args:
        columns: Column names of the target table.
        role: Exact ``role`` match.
        module: Exact ``module`` match.
        microservice: Exact ``microservice`` match.
        package_prefix: ``package`` equal to this value or nested below it
            (``<prefix>.<anything>``), matched literally.
        fqn_in: ``primary_type_fqn`` must be one of these values.
        role_in: ``role`` must be one of these values (OR-ed with
            ``capability_in`` when both are given).
        exclude_roles: ``role`` must be NULL or not one of these values.
        capability: the ``capabilities`` array must contain this value.
        capability_in: ``capabilities`` must contain at least one of these.

    Returns:
        Predicate strings, in a stable order; empty when no filter applies.
    """

    def _in_list(values: list[str]) -> str:
        # Comma-separated, individually quoted SQL literals for an IN (...) list.
        return ", ".join(f"'{_escape_sql_str(v)}'" for v in values)

    preds: list[str] = []
    has_role = "role" in columns
    has_caps = "capabilities" in columns

    if role and has_role:
        preds.append(f"role = '{_escape_sql_str(role)}'")

    # When both role_in and capability_in are set, combine as OR so that
    # capability-only entrypoints (e.g. role=OTHER with MESSAGE_LISTENER)
    # are not silently excluded by the role filter.
    role_pred: str | None = None
    if role_in and has_role:
        role_pred = f"role IN ({_in_list(role_in)})"

    cap_in_pred: str | None = None
    if capability_in and has_caps:
        # array_has is the preferred form in LanceDB >= 0.10 (verified against 0.30.2).
        parts = [f"array_has(capabilities, '{_escape_sql_str(c)}')" for c in capability_in]
        cap_in_pred = "(" + " OR ".join(parts) + ")"

    if role_pred and cap_in_pred:
        preds.append(f"({role_pred} OR {cap_in_pred})")
    elif role_pred:
        preds.append(role_pred)
    elif cap_in_pred:
        preds.append(cap_in_pred)

    if exclude_roles and has_role:
        preds.append(f"(role IS NULL OR role NOT IN ({_in_list(exclude_roles)}))")
    if module and "module" in columns:
        preds.append(f"module = '{_escape_sql_str(module)}'")
    if microservice and "microservice" in columns:
        preds.append(f"microservice = '{_escape_sql_str(microservice)}'")
    if package_prefix and "package" in columns:
        esc = _escape_sql_str(package_prefix)
        # `_` and `%` are LIKE wildcards, and `_` is legal in Java package
        # names: without escaping, "com.my_app" would also match
        # "com.myXapp.*". Same escaping scheme as _build_path_predicate.
        like = _escape_sql_str(_escape_sql_like_pattern(package_prefix))
        preds.append(f"(package = '{esc}' OR package LIKE '{like}.%' ESCAPE '\\')")
    if fqn_in and "primary_type_fqn" in columns:
        # LanceDB/Arrow SQL supports IN; quote each.
        preds.append(f"primary_type_fqn IN ({_in_list(fqn_in)})")
    if capability and has_caps:
        preds.append(f"array_has(capabilities, '{_escape_sql_str(capability)}')")
    return preds
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def coerce_position_field(val: object) -> dict[str, object]:
    """Normalize a LanceDB position column value into a plain dict.

    Struct columns may come back either as dicts or as JSON-encoded strings.
    Dicts are returned as-is (same object), and JSON strings encoding an object
    are decoded. Anything else yields an empty dict: None, other types,
    malformed JSON, or JSON that is not an object.
    """
    if isinstance(val, dict):
        return val
    if not isinstance(val, str):
        return {}
    try:
        decoded = json.loads(val)
    except json.JSONDecodeError:
        return {}
    if isinstance(decoded, dict):
        return decoded
    return {}
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
# Demotion for chunks flagged `import_heavy` by analyze_chunk. Vector ranking
# adds this amount to the distance (see _vector_sort_key) ...
_IMPORT_DISTANCE_PENALTY = 0.08
# ... while hybrid ranking multiplies the relevance score by this factor
# (see _hybrid_sort_key).
_IMPORT_HYBRID_SCORE_FACTOR = 0.88

# Bonus for chunks whose declared symbols (method / field names) share tokens with
# the query. Behavioural queries like "what happens when a client message arrives"
# should float chunks containing `processClientMessage` above ones that only
# enqueue; this is a cheap, query-dependent signal computed at rank time.
# Per matching symbol, capped in total.
_SYMBOL_MATCH_BONUS_PER_HIT = 0.03
_SYMBOL_MATCH_BONUS_CAP = 0.06

# Action verbs that typically mark behavioural entry points in this codebase.
# A chunk with a symbol whose first identifier word is one of these verbs earns
# a small flat bump — again only for java chunks and only when role-filtering
# is off.
_ACTION_VERB_PREFIXES: tuple[str, ...] = (
    "process", "handle", "on", "pick", "select", "assign",
    "notify", "dispatch", "publish", "consume", "route",
    "trigger", "enqueue", "distribute", "update", "create",
    "apply", "resolve", "reassign", "close", "open",
)
_ACTION_VERB_BONUS = 0.02

# Type-name overlap bonus. The class name is a much stronger discovery signal
# than any individual method, because class naming in this codebase encodes
# the domain concept (`DistributionChunkService`, `OperatorSessionService`,
# `JoinOperatorController`). So we reward overlap between query tokens and the
# simple name of `primary_type_fqn` more heavily than per-method overlap, and
# we stack it on top of the existing `_symbol_bonus`.
_TYPE_MATCH_BONUS_PER_HIT = 0.05
_TYPE_MATCH_BONUS_CAP = 0.10

# Lowercase words dropped by _query_tokens when extracting query tokens for
# the lexical bonuses above (question words, auxiliaries, articles, ...).
_STOPWORDS: frozenset[str] = frozenset({
    "a", "an", "the", "is", "are", "was", "were", "be", "been", "being",
    "to", "of", "in", "on", "at", "by", "for", "with", "from", "as", "or",
    "and", "but", "if", "then", "else", "when", "what", "how", "why", "does",
    "do", "did", "has", "have", "had", "this", "that", "these", "those", "it",
    "its", "new", "no", "not", "will", "would", "should", "can", "could",
    "may", "might", "happens", "happen", "happened", "get", "gets", "got",
})

# Role-aware reweighting for Java chunks. Positive values favour actionable
# behavioural code (entrypoints, orchestrators, integrations) over configuration,
# schema, and persistence stubs for "what happens when..."-style queries.
# Applied to the similarity score (higher = better); distance-based sort subtracts
# the weight. Skipped when caller filters explicitly by role. Roles missing from
# this map (or rows without a role) get 0.0.
_ROLE_SCORE_WEIGHTS: dict[str, float] = {
    "CONTROLLER": 0.10,
    "SERVICE": 0.08,
    "CLIENT": 0.06,
    "COMPONENT": 0.03,
    "REPOSITORY": 0.02,
    "MAPPER": 0.00,
    "OTHER": 0.00,
    "ENTITY": -0.06,
    "CONFIG": -0.10,
    # DTOs are passive data carriers; they almost never answer "how/what
    # happens" queries. Penalty is slightly stronger than ENTITY so a DTO
    # with a great embedding match still loses to a mediocre SERVICE hit.
    "DTO": -0.08,
}
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
def _query_tokens(query: str) -> set[str]:
    """Extract the distinctive words of a search query.

    Every non-alphabetic character acts as a word separator. Words are
    lowercased, and only those at least three characters long that are not
    listed in ``_STOPWORDS`` are kept. Used to score symbol-name overlap;
    deliberately simple and locale-free.
    """
    spaced = "".join(ch if ch.isalpha() else " " for ch in query)
    words = (word.lower() for word in spaced.split())
    return {word for word in words if len(word) >= 3 and word not in _STOPWORDS}
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
def _split_identifier(name: str) -> list[str]:
    """Split a camelCase / PascalCase / snake_case identifier into lowercase words.

    Underscores are separators. A new word starts at an uppercase letter that
    follows a non-uppercase character (``fooBar`` -> ``foo``, ``bar``), and at
    the last capital of an uppercase run when a lowercase letter follows it
    (``HTTPClient`` -> ``http``, ``client``). Uppercase runs otherwise stay
    whole, so ``MAX_SIZE`` -> ``max``, ``size`` and ``getURL`` -> ``get``,
    ``url``. Splitting at every capital would shred acronyms and constants into
    single letters, which can never match a query token (_query_tokens keeps
    only words of 3+ characters).
    """
    parts: list[str] = []
    cur: list[str] = []

    def _flush() -> None:
        # Emit the word accumulated so far (if any) and start a new one.
        if cur:
            parts.append("".join(cur).lower())
            cur.clear()

    last = len(name) - 1
    for i, c in enumerate(name):
        if c == "_":
            _flush()
            continue
        if c.isupper() and cur:
            prev_upper = cur[-1].isupper()
            next_lower = i < last and name[i + 1].islower()
            # Boundary: lower/digit -> Upper, or the capital that begins a
            # PascalCase word right after an acronym run.
            if not prev_upper or next_lower:
                _flush()
        cur.append(c)
    _flush()
    return parts
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
def _symbol_bonus(r: dict, query_toks: set[str]) -> float:
    """Lexical ranking bonus for a java chunk, from its symbols and type name.

    The bonus has three additive parts:

    * symbol overlap: ``_SYMBOL_MATCH_BONUS_PER_HIT`` per distinct declared
      symbol name whose identifier words intersect ``query_toks``, capped at
      ``_SYMBOL_MATCH_BONUS_CAP``;
    * action verb: a flat ``_ACTION_VERB_BONUS`` when any symbol's first
      identifier word is in ``_ACTION_VERB_PREFIXES``;
    * type overlap: ``_TYPE_MATCH_BONUS_PER_HIT`` per query token found in the
      simple name of ``primary_type_fqn``, capped at ``_TYPE_MATCH_BONUS_CAP``.

    The total is therefore bounded by ``_SYMBOL_MATCH_BONUS_CAP +
    _ACTION_VERB_BONUS + _TYPE_MATCH_BONUS_CAP``. Non-java rows get 0.0.
    """
    if str(r.get("_kind", "")) != "java":
        return 0.0

    raw = r.get("symbols") or []
    if isinstance(raw, str):
        # Legacy JSON-encoded list column; parse defensively.
        try:
            parsed = json.loads(raw)
        except Exception:
            parsed = None
        raw = parsed if isinstance(parsed, list) else []

    overlap_hits = 0
    has_action = False
    seen: set[str] = set()
    for s in raw:
        if not s:
            continue
        bare = str(s).split("(", 1)[0].strip()
        # Overloads ("process(int)", "process(String)") share one name; count
        # each name once so overloading alone cannot saturate the cap.
        if not bare or bare in seen:
            continue
        seen.add(bare)
        toks = _split_identifier(bare)
        if not toks:
            continue
        if toks[0] in _ACTION_VERB_PREFIXES:
            has_action = True
        if query_toks.intersection(toks):
            overlap_hits += 1

    bonus = min(overlap_hits * _SYMBOL_MATCH_BONUS_PER_HIT, _SYMBOL_MATCH_BONUS_CAP)
    if has_action:
        bonus += _ACTION_VERB_BONUS

    # Type-name overlap: strongest single lexical signal for "which class is
    # the answer?" queries. Uses the simple name of primary_type_fqn.
    fqn = str(r.get("primary_type_fqn") or "")
    if fqn:
        simple = fqn.rsplit(".", 1)[-1]
        type_hits = len(query_toks & set(_split_identifier(simple)))
        if type_hits:
            bonus += min(type_hits * _TYPE_MATCH_BONUS_PER_HIT, _TYPE_MATCH_BONUS_CAP)
    return bonus
|
|
293
|
+
|
|
294
|
+
|
|
295
|
+
def _apply_chunk_hints(rows: list[dict]) -> None:
    """Attach `_hints` (primary type hint, import-heaviness) to each row in place.

    SQL and YAML rows without a `language` value are analyzed as "sql" / "yaml"
    respectively; other rows use their `language` field (or "").
    """
    for row in rows:
        kind = str(row.get("_kind", ""))
        language = row.get("language") or ""
        if not language and kind in ("sql", "yaml"):
            language = kind
        hints = analyze_chunk(row.get("text"), language=str(language), kind=kind)
        row["_hints"] = {
            "primary_type_hint": hints.primary_type_hint,
            "import_heavy": hints.import_heavy,
        }
|
|
308
|
+
|
|
309
|
+
|
|
310
|
+
def _role_weight(r: dict) -> float:
    """Role-based score adjustment for a row, memoized in `_score_components`.

    Only java rows are weighted, and only when the caller did not pin the role
    (`_skip_role_weight`); all other rows get 0.0, as do roles missing from
    `_ROLE_SCORE_WEIGHTS`. The value is stored under
    `_score_components["role_weight"]` and reused on later calls.
    """
    components = r.setdefault("_score_components", {})
    memo = components.get("role_weight")
    if memo is not None:
        return float(memo)

    eligible = not r.get("_skip_role_weight") and str(r.get("_kind", "")) == "java"
    if eligible:
        weight = _ROLE_SCORE_WEIGHTS.get((r.get("role") or "").upper(), 0.0)
    else:
        weight = 0.0
    components["role_weight"] = weight
    return weight
|
|
323
|
+
|
|
324
|
+
|
|
325
|
+
def _apply_symbol_bonus(rows: list[dict], query_toks: set[str]) -> None:
    """Record each eligible row's symbol bonus in `_score_components.symbol_bonus`.

    No-op when there are no query tokens. Rows flagged `_skip_role_weight`
    (the caller locked the role) are left untouched, and zero bonuses are not
    recorded.
    """
    if not query_toks:
        return
    eligible_rows = (row for row in rows if not row.get("_skip_role_weight"))
    for row in eligible_rows:
        bonus = _symbol_bonus(row, query_toks)
        if not bonus:
            continue
        row.setdefault("_score_components", {})["symbol_bonus"] = bonus
|
|
336
|
+
|
|
337
|
+
|
|
338
|
+
def _vector_sort_key(r: dict) -> float:
    """Ascending sort key for vector hits: the adjusted distance (lower is better).

    Starts from `_distance`, adds `_IMPORT_DISTANCE_PENALTY` for import-heavy
    chunks, then subtracts the role weight and any precomputed symbol bonus.
    The raw distance and the applied penalty are recorded in
    `_score_components`.
    """
    distance = float(r["_distance"])
    components = r.setdefault("_score_components", {})
    components["distance"] = distance

    if r.get("_hints", {}).get("import_heavy"):
        components["import_penalty"] = _IMPORT_DISTANCE_PENALTY
        distance = distance + _IMPORT_DISTANCE_PENALTY

    adjusted = distance - _role_weight(r)
    return adjusted - float(components.get("symbol_bonus", 0.0))
|
|
348
|
+
|
|
349
|
+
|
|
350
|
+
def _hybrid_sort_key(r: dict) -> float:
    """Ascending sort key for hybrid hits: the negated adjusted relevance score.

    Starts from `_score`, scales it by `_IMPORT_HYBRID_SCORE_FACTOR` for
    import-heavy chunks, then adds the role weight and any precomputed symbol
    bonus. The raw score and the applied factor are recorded in
    `_score_components`.
    """
    score = float(r.get("_score", 0.0))
    components = r.setdefault("_score_components", {})
    components["hybrid_rrf"] = score

    if r.get("_hints", {}).get("import_heavy"):
        components["import_penalty"] = _IMPORT_HYBRID_SCORE_FACTOR
        score = score * _IMPORT_HYBRID_SCORE_FACTOR

    boosted = score + _role_weight(r) + float(components.get("symbol_bonus", 0.0))
    return -boosted
|
|
360
|
+
|
|
361
|
+
|
|
362
|
+
def explain_score_components(
|
|
363
|
+
comps: dict[str, float] | None,
|
|
364
|
+
*,
|
|
365
|
+
role: str | None = None,
|
|
366
|
+
hybrid: bool = False,
|
|
367
|
+
graph_expanded: bool = False,
|
|
368
|
+
) -> str:
|
|
369
|
+
"""Compact human-readable 'why' string for a ranked hit.
|
|
370
|
+
|
|
371
|
+
Joins the interesting components of `_score_components` in a stable order
|
|
372
|
+
so agents can reason about rankings without chasing raw floats. Returns
|
|
373
|
+
"" if there's nothing worth mentioning.
|
|
374
|
+
"""
|
|
375
|
+
if not comps:
|
|
376
|
+
comps = {}
|
|
377
|
+
parts: list[str] = []
|
|
378
|
+
if hybrid:
|
|
379
|
+
rrf = comps.get("hybrid_rrf")
|
|
380
|
+
if rrf is not None:
|
|
381
|
+
parts.append(f"rrf={float(rrf):.3f}")
|
|
382
|
+
else:
|
|
383
|
+
d = comps.get("distance")
|
|
384
|
+
if d is not None:
|
|
385
|
+
parts.append(f"dist={float(d):.2f}")
|
|
386
|
+
rw = comps.get("role_weight")
|
|
387
|
+
if rw:
|
|
388
|
+
label = f"role:{role}" if role else "role"
|
|
389
|
+
parts.append(f"{label}:{float(rw):+.02f}")
|
|
390
|
+
sb = comps.get("symbol_bonus")
|
|
391
|
+
if sb:
|
|
392
|
+
parts.append(f"symbol:{float(sb):+.02f}")
|
|
393
|
+
ip = comps.get("import_penalty")
|
|
394
|
+
if ip:
|
|
395
|
+
parts.append(f"import_penalty:{float(ip):+.02f}")
|
|
396
|
+
if graph_expanded:
|
|
397
|
+
parts.append("graph")
|
|
398
|
+
return " ".join(parts)
|
|
399
|
+
|
|
400
|
+
|
|
401
|
+
def l2_distance_to_score(distance: float) -> float:
    """Convert an L2 distance into cosine similarity for unit-length embeddings.

    For unit vectors ``|a - b|^2 = 2 - 2 cos(a, b)``, hence
    ``cos = 1 - d^2 / 2``. NOTE(review): if the backend reports *squared* L2
    distances, callers should not pass them here unchanged — confirm.
    """
    squared = distance * distance
    return 1.0 - squared / 2.0
|
|
404
|
+
|
|
405
|
+
|
|
406
|
+
def _escape_like_fragment(s: str) -> str:
    """Double single quotes in a LIKE pattern fragment for a '...' SQL literal."""
    return "".join("''" if ch == "'" else ch for ch in s)
|
|
408
|
+
|
|
409
|
+
|
|
410
|
+
def _escape_sql_like_pattern(s: str) -> str:
    """Prefix each LIKE metacharacter (backslash, percent, underscore) with a backslash.

    The result matches *s* literally when used with ``ESCAPE '\\'``.
    """
    specials = {"\\", "%", "_"}
    return "".join(f"\\{ch}" if ch in specials else ch for ch in s)
|
|
418
|
+
|
|
419
|
+
|
|
420
|
+
def _build_path_predicate(path_substring: str) -> str:
    """SQL predicate: `filename` contains *path_substring* as a literal substring.

    LIKE metacharacters are backslash-escaped first, then single quotes are
    doubled for the surrounding string literal.
    """
    escaped = _escape_like_fragment(_escape_sql_like_pattern(path_substring))
    return "filename LIKE '%" + escaped + "%' ESCAPE '\\'"
|
|
424
|
+
|
|
425
|
+
|
|
426
|
+
def ensure_text_fts_index(uri: str, lance_table_name: str) -> None:
    """Ensure the table's `text` column has a full-text-search index.

    Runs the creation attempt at most once per (uri, table) per process:
    success is remembered in ``_FTS_READY``. Creation uses ``replace=False``.
    An error whose message says the index already exists is treated as
    success; any other error propagates and the table is not marked ready.
    The check-and-create runs under ``_FTS_LOCK`` so concurrent callers do
    not both try to build the index.
    """
    key = (uri, lance_table_name)
    # Message fragments that mean "the index is already there" ...
    benign = ("exist", "duplicate", "already", "same name")
    # ... unless the message is about something missing: "exist" alone also
    # matches "does not exist" / "nonexistent", which previously got swallowed
    # and left the table wrongly marked as FTS-ready.
    missing = ("not exist", "nonexist", "no such", "not found")
    with _FTS_LOCK:
        if key in _FTS_READY:
            return
        db = lancedb.connect(uri)
        tbl = db.open_table(lance_table_name)
        try:
            tbl.create_fts_index("text", replace=False)
        except Exception as e:
            low = str(e).lower()
            if any(w in low for w in missing) or not any(w in low for w in benign):
                raise
        _FTS_READY.add(key)
|
|
445
|
+
|
|
446
|
+
|
|
447
|
+
def _query_vector(model: SentenceTransformer, text: str) -> np.ndarray:
    """Embed one query string as a normalized float32 NumPy vector."""
    encode_options = {
        "convert_to_numpy": True,
        "normalize_embeddings": True,
        "show_progress_bar": False,
    }
    embedding = model.encode(text, **encode_options)
    return np.asarray(embedding, dtype=np.float32)
|
|
455
|
+
|
|
456
|
+
|
|
457
|
+
def _combine_predicates(parts: list[str | None]) -> str | None:
|
|
458
|
+
clean = [p for p in parts if p]
|
|
459
|
+
if not clean:
|
|
460
|
+
return None
|
|
461
|
+
if len(clean) == 1:
|
|
462
|
+
return clean[0]
|
|
463
|
+
return " AND ".join(f"({p})" for p in clean)
|
|
464
|
+
|
|
465
|
+
|
|
466
|
+
def _search_one_table(
    table_name: str,
    *,
    uri: str,
    db: object,
    query_vec: np.ndarray,
    limit: int,
    path_predicate: str | None,
    kind: str,
    hybrid: bool,
    fts_text: str | None,
    extra_predicates: list[str] | None = None,
) -> list[dict]:
    """Run a vector (or hybrid vector + full-text) search against one table.

    ``path_predicate`` and ``extra_predicates`` are AND-ed and applied as a
    prefilter. Selected columns: filename/text/start/end, `range_start` /
    `range_end` when the schema has them and, for java tables, `language` plus
    whichever JAVA_ENRICHED_COLUMNS exist. Vector searches also select
    `_distance`.

    Every returned row gets `_kind` and `_hybrid`, and has `start`/`end`
    normalized to dicts. Hybrid rows have `_relevance_score` moved to `_score`
    (when present). Hybrid mode first ensures the `text` FTS index exists.
    """
    tbl = db.open_table(table_name)
    is_java = kind == "java"
    schema = _table_columns(uri, table_name, db)

    # `range_start` / `range_end` are needed downstream by `_attach_neighbor_context`
    # to locate the chunk inside its file; select them whenever the schema has them.
    selected = ["filename", "text", "start", "end"]
    selected.extend(col for col in ("range_start", "range_end") if col in schema)
    if is_java:
        selected.append("language")
        selected.extend(col for col in JAVA_ENRICHED_COLUMNS if col in schema)

    where_clause = _combine_predicates([path_predicate, *(extra_predicates or [])])

    if hybrid:
        ensure_text_fts_index(uri, table_name)
        query = (
            tbl.search(query_type="hybrid", vector_column_name=VECTOR_COLUMN)
            .vector(query_vec)
            .text("" if fts_text is None else fts_text)
            .select(selected)
            .limit(limit)
        )
    else:
        query = (
            tbl.search(query_vec, vector_column_name=VECTOR_COLUMN)
            .select([*selected, "_distance"])
            .limit(limit)
        )
    if where_clause:
        query = query.where(where_clause, prefilter=True)

    rows = query.to_list()
    for row in rows:
        row["_kind"] = kind
        if hybrid:
            relevance = row.pop("_relevance_score", None)
            row["_hybrid"] = True
            if relevance is not None:
                row["_score"] = float(relevance)
        else:
            row["_hybrid"] = False
        row["start"] = coerce_position_field(row.get("start"))
        row["end"] = coerce_position_field(row.get("end"))
    return rows
|
|
540
|
+
|
|
541
|
+
|
|
542
|
+
def _debug_ctx(msg: str) -> None:
    """Print a context-expansion diagnostic to stderr when debugging is enabled.

    Enabled by a non-empty JAVA_CODEBASE_RAG_DEBUG_CONTEXT environment
    variable. Output goes to stderr so MCP stdout stays clean; otherwise this
    does nothing.
    """
    enabled = bool(os.environ.get("JAVA_CODEBASE_RAG_DEBUG_CONTEXT"))
    if not enabled:
        return
    print(f"[context_neighbors] {msg}", file=sys.stderr)
|
|
549
|
+
|
|
550
|
+
|
|
551
|
+
def _locate_in_bucket(bucket: list[dict], row: dict, has_range: bool) -> int | None:
    """Return the position of ``row``'s chunk within its per-file ``bucket``, or None.

    Tries an exact ``(range_start, range_end)`` match first (when ranges are
    available and the row's range is not all-zero), then an exact ``text``
    match.
    """
    if has_range:
        start = int(row.get("range_start") or 0)
        end = int(row.get("range_end") or 0)
        if start or end:
            for i, c in enumerate(bucket):
                c_start, c_end = c.get("range_start"), c.get("range_end")
                # Compare against None explicitly instead of `or -1`: a chunk
                # whose range_start is 0 is falsy and would otherwise never
                # match its own range.
                if c_start is None or c_end is None:
                    continue
                if int(c_start) == start and int(c_end) == end:
                    return i

    r_text = str(row.get("text") or "")
    if r_text:
        for i, c in enumerate(bucket):
            if str(c.get("text") or "") == r_text:
                return i
    return None


def _attach_neighbor_context(
    rows: list[dict], *, db: object, neighbors: int, uri: str | None = None,
) -> None:
    """Populate `_context_before` / `_context_after` with adjacent Java chunk text.

    For each java row, up to ``neighbors`` preceding and following chunks of
    the same file are joined with newlines and stored on the row in place.

    Strategy (in order):
      1. Schema-aware scan of the java table, selecting only columns that exist
         (`filename` + `text` always; `range_start`/`range_end` when present).
      2. Sort the per-file bucket by `range_start` if available. Otherwise keep
         the order the scan returned (NOTE(review): assumed to be file order,
         as CocoIndex produces chunks — confirm).
      3. Locate each row's index via (a) range tuple match, (b) exact text match
         as fallback. Missing both -> log and skip.
      4. Failures opening the table or scanning the bucket are logged (behind
         the env flag) and leave the context fields unset. A failed schema
         lookup falls back to a scan without range columns.
    """
    if neighbors <= 0:
        return
    java_rows = [r for r in rows if str(r.get("_kind", "")) == "java"]
    if not java_rows:
        _debug_ctx("no java rows in window; nothing to expand")
        return
    filenames = {str(r.get("filename", "")) for r in java_rows if r.get("filename")}
    if not filenames:
        _debug_ctx("java rows had no filename field; skipping")
        return

    java_table = TABLES["java"]
    try:
        tbl = db.open_table(java_table)
    except Exception as exc:
        _debug_ctx(f"open_table({java_table}) failed: {exc!r}")
        return

    # Discover which positional columns the index actually carries. Older
    # indexes may predate `range_start`/`range_end`; asking for a missing
    # column makes the whole scan fail.
    try:
        schema_cols = _table_columns(uri, java_table, db) if uri else {f.name for f in tbl.schema}
    except Exception as exc:
        _debug_ctx(f"schema lookup failed: {exc!r}")
        schema_cols = set()

    has_range = {"range_start", "range_end"}.issubset(schema_cols)
    scan_cols = ["filename", "text"]
    if has_range:
        scan_cols.extend(("range_start", "range_end"))

    try:
        in_list = ", ".join(f"'{_escape_sql_str(f)}'" for f in filenames)
        scanner = tbl.to_lance().scanner(
            filter=f"filename IN ({in_list})",
            columns=scan_cols,
        )
        all_chunks = scanner.to_table().to_pylist()
    except Exception as exc:
        _debug_ctx(f"bucket scan failed (cols={scan_cols}): {exc!r}")
        return

    if not all_chunks:
        _debug_ctx(f"bucket scan returned 0 chunks for {len(filenames)} filenames")
        return

    by_file: dict[str, list[dict]] = {}
    for ch in all_chunks:
        by_file.setdefault(str(ch.get("filename", "")), []).append(ch)
    if has_range:
        for lst in by_file.values():
            lst.sort(
                key=lambda c: (int(c.get("range_start") or 0), int(c.get("range_end") or 0))
            )

    attached = 0
    for r in java_rows:
        fn = str(r.get("filename", ""))
        bucket = by_file.get(fn, [])
        if not bucket:
            _debug_ctx(f"no bucket for filename={fn!r}")
            continue

        idx = _locate_in_bucket(bucket, r, has_range)
        if idx is None:
            _debug_ctx(
                f"could not locate chunk in bucket (file={fn!r}, "
                f"has_range={has_range}, bucket_size={len(bucket)})"
            )
            continue

        before_parts = [str(c.get("text") or "") for c in bucket[max(0, idx - neighbors):idx]]
        after_parts = [str(c.get("text") or "") for c in bucket[idx + 1 : idx + 1 + neighbors]]
        r["_context_before"] = "\n".join(before_parts)
        r["_context_after"] = "\n".join(after_parts)
        attached += 1

    _debug_ctx(f"attached context to {attached}/{len(java_rows)} java rows")
|
|
667
|
+
|
|
668
|
+
|
|
669
|
+
def _graph_expand_merge(
    vector_rows: list[dict],
    *,
    query_vec: np.ndarray,
    db: object,
    uri: str,
    limit: int,
    extra_predicates: list[str],
    expand_depth: int,
    kuzu_path: str | None,
) -> list[dict]:
    """Expand vector top-k through the Kuzu graph and fuse (RRF) with the original list.

    The `primary_type_fqn` values of ``vector_rows`` seed a structural and a
    method-level expansion in the Kuzu graph. Java chunks of the newly reached
    types are then vector-searched (same ``extra_predicates``, restricted to
    those FQNs), marked `_graph_expanded`, and RRF-merged with the original
    rows. In the merge, each graph row's contribution is weighted by the
    expansion confidence of its type (1.0 for structural neighbours).

    Best-effort: if kuzu is unavailable, there are no seeds or novel
    neighbours, or the graph query / follow-up search fails, ``vector_rows``
    is returned unchanged.
    """
    # Lazy import so the module works without kuzu installed when graph_expand=False.
    try:
        from kuzu_queries import KuzuGraph
    except Exception:
        return vector_rows

    if not KuzuGraph.exists(kuzu_path):
        return vector_rows

    seed_set = {r.get("primary_type_fqn") for r in vector_rows if r.get("primary_type_fqn")}
    if not seed_set:
        return vector_rows
    seed_fqns = sorted(seed_set)

    try:
        graph = KuzuGraph.get(kuzu_path)
        # Materialize both results: each is iterated twice below, and a
        # one-shot iterator would silently yield nothing the second time.
        structural = list(graph.expand_fqns(seed_fqns, depth=expand_depth))
        method_pairs = list(graph.expand_methods(
            seed_fqns, depth=expand_depth, exclude_external=True,
        ))
        expand_weight_by_fqn: dict[str, float] = {}
        for f in structural:
            if f:
                expand_weight_by_fqn[f] = max(expand_weight_by_fqn.get(f, 0.0), 1.0)
        for f, conf in method_pairs:
            if f:
                expand_weight_by_fqn[f] = max(expand_weight_by_fqn.get(f, 0.0), conf)
        neighbor_fqns = list(dict.fromkeys(
            structural + [f for f, _ in method_pairs],
        ))
    except Exception:
        return vector_rows

    # seed_set is built once; the original rebuilt it for every neighbour.
    novel = [fqn for fqn in neighbor_fqns if fqn and fqn not in seed_set]
    if not novel:
        return vector_rows

    try:
        # The schema lookup is inside the guard too, so a failure here stays
        # best-effort like every other step instead of breaking the search.
        extra = list(extra_predicates or [])
        extra.extend(_build_extra_predicates(
            columns=_table_columns(uri, TABLES["java"], db),
            role=None, module=None, microservice=None,
            package_prefix=None, fqn_in=novel,
        ))
        graph_rows = _search_one_table(
            TABLES["java"],
            uri=uri, db=db, query_vec=query_vec,
            limit=max(limit, 20),
            path_predicate=None, kind="java",
            hybrid=False, fts_text=None,
            extra_predicates=extra,
        )
    except Exception:
        return vector_rows

    _apply_chunk_hints(graph_rows)
    graph_rows.sort(key=_vector_sort_key)
    for r in graph_rows:
        r["_graph_expanded"] = True
        r["_graph_expand_weight"] = expand_weight_by_fqn.get(
            r.get("primary_type_fqn"), 1.0,
        )
    return _rrf_merge(
        [vector_rows, graph_rows],
        row_weight_for_list_index=[
            None,
            lambda row: float(row.get("_graph_expand_weight", 1.0)),
        ],
    )
|
|
750
|
+
|
|
751
|
+
|
|
752
|
+
def _rrf_merge(
|
|
753
|
+
lists: list[list[dict]],
|
|
754
|
+
*,
|
|
755
|
+
k: int = 60,
|
|
756
|
+
row_weight_for_list_index: list[Callable[[dict], float] | None] | None = None,
|
|
757
|
+
) -> list[dict]:
|
|
758
|
+
"""Reciprocal-rank-fuse several ranked lists of chunk rows.
|
|
759
|
+
|
|
760
|
+
Rows are deduplicated by (filename, range_start, range_end). The merged
|
|
761
|
+
rows get a `_rrf_score` field so callers can inspect or re-sort.
|
|
762
|
+
|
|
763
|
+
When ``row_weight_for_list_index`` is set, its length must match ``lists``;
|
|
764
|
+
a non-None entry is a callable ``row -> weight`` multiplied into that list's
|
|
765
|
+
rank contribution (``None`` means weight ``1.0`` for every row).
|
|
766
|
+
"""
|
|
767
|
+
pool: dict[tuple, dict] = {}
|
|
768
|
+
for li, ranked in enumerate(lists):
|
|
769
|
+
wfn: Callable[[dict], float] | None = None
|
|
770
|
+
if row_weight_for_list_index is not None and li < len(row_weight_for_list_index):
|
|
771
|
+
wfn = row_weight_for_list_index[li]
|
|
772
|
+
for rank, row in enumerate(ranked):
|
|
773
|
+
key = (row.get("filename"), row.get("range_start"), row.get("range_end"))
|
|
774
|
+
existing = pool.get(key)
|
|
775
|
+
weight = 1.0 if wfn is None else float(wfn(row))
|
|
776
|
+
contribution = weight * (1.0 / (k + rank + 1))
|
|
777
|
+
if existing is None:
|
|
778
|
+
row["_rrf_score"] = contribution
|
|
779
|
+
pool[key] = row
|
|
780
|
+
else:
|
|
781
|
+
existing["_rrf_score"] = float(existing.get("_rrf_score", 0.0)) + contribution
|
|
782
|
+
merged = list(pool.values())
|
|
783
|
+
merged.sort(key=lambda r: -float(r.get("_rrf_score", 0.0)))
|
|
784
|
+
return merged
|
|
785
|
+
|
|
786
|
+
|
|
787
|
+
def run_search(
    query: str,
    *,
    uri: str,
    table_keys: list[str],
    limit: int,
    path_substring: str | None,
    model_name: str,
    device: str | None,
    offset: int = 0,
    model: SentenceTransformer | None = None,
    hybrid: bool = False,
    fts_text: str | None = None,
    auto_hybrid: bool = False,
    role: str | None = None,
    module: str | None = None,
    microservice: str | None = None,
    package_prefix: str | None = None,
    graph_expand: bool = False,
    expand_depth: int = 1,
    kuzu_path: str | None = None,
    context_neighbors: int = 0,
    role_in: list[str] | None = None,
    exclude_roles: list[str] | None = None,
    capability: str | None = None,
    capability_in: list[str] | None = None,
) -> list[dict]:
    """Search one or more LanceDB tables and return the page ``[offset, offset + limit)``.

    Behaviour:
      * ``auto_hybrid`` promotes a single-table search to hybrid (vector + FTS)
        when ``looks_like_code_identifier(query)`` is true. In that case the FTS
        text defaults to ``query.strip()``.
      * Hybrid search over more than one table raises ``ValueError``.
      * ``model`` is loaded via ``SentenceTransformer(model_name, device=device,
        trust_remote_code=True)`` when not supplied.
      * The role/module/microservice/package/capability filters only apply to
        the ``"java"`` table.
      * Single table: fetch ``max(limit + offset, 1)`` rows, re-score them, and
        sort them (hybrid or vector key). Then optionally run a graph expansion
        (Java only, ``expand_depth > 0``) and attach neighbour chunks (Java only).
      * Several tables: vector-only search with 3x over-fetch per table, then
        one merged vector sort. ``graph_expand`` is not used on this path.
    """
    single_table = len(table_keys) == 1

    # Resolve whether this call runs as a hybrid search and with which FTS text.
    use_hybrid = hybrid
    fts_query = fts_text
    promote_to_hybrid = (
        auto_hybrid
        and not hybrid
        and single_table
        and looks_like_code_identifier(query)
    )
    if promote_to_hybrid:
        use_hybrid = True
        if fts_query is None:
            fts_query = query.strip()

    if use_hybrid and not single_table:
        raise ValueError(
            "hybrid search requires exactly one table; "
            "use table java, sql, or yaml (not all)."
        )

    path_predicate = None
    if path_substring:
        path_predicate = _build_path_predicate(path_substring)

    if model is None:
        model = SentenceTransformer(model_name, device=device, trust_remote_code=True)
    query_vec = _query_vector(model, query)
    fts_for_hybrid = query if fts_query is None else fts_query

    db = lancedb.connect(uri)
    need = max(limit + offset, 1)

    # Metadata filters only exist on the Java table.
    if "java" in table_keys:
        extra_java = _build_extra_predicates(
            columns=_table_columns(uri, TABLES["java"], db),
            role=role,
            module=module,
            microservice=microservice,
            package_prefix=package_prefix,
            fqn_in=None,
            role_in=role_in,
            exclude_roles=exclude_roles,
            capability=capability,
            capability_in=capability_in,
        )
    else:
        extra_java = []

    # An explicit role filter makes role-based re-weighting pointless.
    skip_role_weight = bool(role or role_in)
    query_toks = _query_tokens(query)

    def _rescore(found: list[dict]) -> None:
        """Apply chunk hints, the role-weight opt-out flag, and the symbol bonus in place."""
        _apply_chunk_hints(found)
        if skip_role_weight:
            for hit in found:
                hit["_skip_role_weight"] = True
        _apply_symbol_bonus(found, query_toks)

    if single_table:
        table_key = table_keys[0]
        is_java = table_key == "java"
        rows = _search_one_table(
            TABLES[table_key],
            uri=uri,
            db=db,
            query_vec=query_vec,
            limit=need,
            path_predicate=path_predicate,
            kind=table_key,
            hybrid=use_hybrid,
            fts_text=fts_for_hybrid,
            extra_predicates=extra_java if is_java else [],
        )
        _rescore(rows)
        rows.sort(key=_hybrid_sort_key if use_hybrid else _vector_sort_key)

        if graph_expand and is_java and expand_depth > 0:
            rows = _graph_expand_merge(
                rows,
                query_vec=query_vec,
                db=db,
                uri=uri,
                limit=need,
                extra_predicates=extra_java,
                expand_depth=expand_depth,
                kuzu_path=kuzu_path,
            )

        page = rows[offset : offset + limit]
        if context_neighbors > 0 and is_java:
            _attach_neighbor_context(page, db=db, neighbors=context_neighbors, uri=uri)
        return page

    # Multi-table: over-fetch from each table so the merged page is well populated
    # (need >= 1, so this is always at least `need`).
    per_table = need * 3
    pooled: list[dict] = []
    for table_key in table_keys:
        pooled.extend(
            _search_one_table(
                TABLES[table_key],
                uri=uri,
                db=db,
                query_vec=query_vec,
                limit=per_table,
                path_predicate=path_predicate,
                kind=table_key,
                hybrid=False,
                fts_text=None,
                extra_predicates=extra_java if table_key == "java" else [],
            )
        )
    _rescore(pooled)
    pooled.sort(key=_vector_sort_key)

    page = pooled[offset : offset + limit]
    if context_neighbors > 0:
        # NOTE(review): unlike the single-table path this is not restricted to Java
        # rows; confirm _attach_neighbor_context handles sql/yaml rows.
        _attach_neighbor_context(page, db=db, neighbors=context_neighbors, uri=uri)
    return page
|
|
929
|
+
|
|
930
|
+
|
|
931
|
+
def main() -> None:
    """Command-line entry point: parse arguments, run ``run_search`` and print the hits.

    Exit status:
      * 1 when the LanceDB path does not exist or the search raises.
      * 2 on invalid arguments: ``--hybrid`` / ``--auto-hybrid`` combined with
        ``--table all``, or a negative ``--limit``.
    """
    parser = argparse.ArgumentParser(
        description="Vector search in LanceDB index.",
    )
    parser.add_argument("query", help="Natural-language search query")
    parser.add_argument(
        "--table",
        choices=["java", "sql", "yaml", "all"],
        default="java",
    )
    parser.add_argument("--limit", type=int, default=10)
    parser.add_argument(
        "--lancedb-uri",
        default=os.environ.get("JAVA_CODEBASE_RAG_INDEX_DIR", "")
        or str((Path.cwd() / ".java-codebase-rag").resolve()),
    )
    parser.add_argument("--path-contains", metavar="SUBSTR", default=None)
    parser.add_argument(
        "--model",
        default=None,
        help=(
            "sentence-transformers hub id or local model directory "
            f"(default: SBERT_MODEL env or {SBERT_MODEL!r})"
        ),
    )
    parser.add_argument("--device", default=None)
    parser.add_argument("--text-width", type=int, default=320)
    parser.add_argument("--hybrid", action="store_true")
    parser.add_argument("--fts-text", metavar="TEXT", default=None)
    parser.add_argument("--auto-hybrid", action="store_true")
    parser.add_argument("--role", default=None)
    parser.add_argument("--module", default=None,
                        help="Filter to a single Maven/Gradle module name.")
    parser.add_argument("--microservice", default=None,
                        help="Filter to a single deployable microservice (top-level dir under project root).")
    parser.add_argument("--package-prefix", default=None)
    parser.add_argument("--graph-expand", action="store_true")
    parser.add_argument("--expand-depth", type=int, default=1)
    parser.add_argument("--kuzu-path", default=None)
    parser.add_argument(
        "--context-neighbors", type=int, default=0,
        help="Attach N adjacent chunks per hit as surrounding context (Java only).",
    )
    args = parser.parse_args()

    # A negative limit would turn rows[offset:offset + limit] into "drop the last
    # N rows" instead of paging; reject it up front (parser.error exits with 2).
    if args.limit < 0:
        parser.error("--limit must be >= 0")

    uri_path = Path(args.lancedb_uri)
    if not uri_path.exists():
        print(f"Error: LanceDB path missing: {uri_path.resolve()}", file=sys.stderr)
        sys.exit(1)

    keys = list(TABLES) if args.table == "all" else [args.table]
    if args.hybrid and args.table == "all":
        print("Error: --hybrid needs a single --table.", file=sys.stderr)
        sys.exit(2)
    if args.auto_hybrid and args.table == "all":
        print("Error: --auto-hybrid needs a single --table.", file=sys.stderr)
        sys.exit(2)

    raw_model = args.model
    if raw_model is None or not str(raw_model).strip():
        model_name = resolved_sbert_model_for_process_env(SBERT_MODEL)
    else:
        model_name = maybe_expand_embedding_model_path(str(raw_model).strip())

    try:
        results = run_search(
            args.query,
            uri=str(uri_path),
            table_keys=keys,
            limit=args.limit,
            path_substring=args.path_contains,
            model_name=model_name,
            device=args.device,
            hybrid=args.hybrid,
            fts_text=args.fts_text,
            auto_hybrid=args.auto_hybrid,
            role=args.role,
            module=args.module,
            microservice=args.microservice,
            package_prefix=args.package_prefix,
            graph_expand=args.graph_expand,
            expand_depth=args.expand_depth,
            kuzu_path=args.kuzu_path,
            context_neighbors=args.context_neighbors,
        )
    except Exception as e:
        # Top-level CLI boundary: report the failure instead of a traceback.
        print(f"Search failed: {e}", file=sys.stderr)
        sys.exit(1)

    if not results:
        print("No results.")
        return

    w = args.text_width
    for i, row in enumerate(results, start=1):
        kind = row["_kind"]
        fn = row["filename"]
        lang = row.get("language", "—")
        start = row.get("start") or {}
        end = row.get("end") or {}
        line_hint = ""
        if isinstance(start, dict) and "line" in start:
            el = (
                end["line"]
                if isinstance(end, dict) and "line" in end
                else start["line"]
            )
            line_hint = f" L{start['line']}-{el}"
        text = (row.get("text") or "").replace("\n", " ")
        if len(text) <= w:
            preview = text
        elif w > 3:
            preview = text[: w - 3] + "..."
        else:
            # Too narrow for an ellipsis. Hard-truncate: text[: w - 3] would use a
            # negative bound here and keep almost the whole text.
            preview = text[: max(w, 0)]
        if row.get("_hybrid"):
            rank_s = f"hybrid RRF={float(row.get('_score', 0.0)):.4f}"
        else:
            rank_s = f"L2 distance={float(row['_distance']):.4f}"
        hints = row.get("_hints") or {}
        hint_s = ""
        if hints.get("primary_type_hint"):
            hint_s += f" | type:{hints['primary_type_hint']}"
        if hints.get("import_heavy"):
            hint_s += " | mostly-imports"
        role = row.get("role") or ""
        if role:
            hint_s += f" | role:{role}"
        ms = row.get("microservice") or ""
        if ms:
            hint_s += f" | microservice:{ms}"
        mod = row.get("module") or ""
        # Skip the module tag when it merely repeats the microservice name.
        if mod and mod != ms:
            hint_s += f" | module:{mod}"
        comps = row.get("_score_components") or {}
        rw = comps.get("role_weight")
        if rw:
            hint_s += f" | role_weight:{rw:+.2f}"
        sb = comps.get("symbol_bonus")
        if sb:
            hint_s += f" | symbol_bonus:{sb:+.2f}"
        if row.get("_graph_expanded"):
            hint_s += " | graph"
        print(f"--- {i}. [{kind}] {rank_s} | {fn}{line_hint} | lang={lang}{hint_s}")
        print(preview)
        print()
|
|
1072
|
+
|
|
1073
|
+
|
|
1074
|
+
if __name__ == "__main__":
    # Run the CLI only when executed as a script, never on import.
    main()
|