code-review-graph-codeblackwell 2.3.6.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- code_review_graph/__init__.py +20 -0
- code_review_graph/__main__.py +4 -0
- code_review_graph/analysis.py +410 -0
- code_review_graph/changes.py +409 -0
- code_review_graph/cli.py +1255 -0
- code_review_graph/communities.py +874 -0
- code_review_graph/constants.py +23 -0
- code_review_graph/context_savings.py +317 -0
- code_review_graph/custom_languages.py +322 -0
- code_review_graph/daemon.py +1009 -0
- code_review_graph/daemon_cli.py +320 -0
- code_review_graph/docs/LLM-OPTIMIZED-REFERENCE.md +71 -0
- code_review_graph/embeddings.py +1006 -0
- code_review_graph/enrich.py +303 -0
- code_review_graph/eval/__init__.py +33 -0
- code_review_graph/eval/benchmarks/__init__.py +1 -0
- code_review_graph/eval/benchmarks/agent_baseline.py +193 -0
- code_review_graph/eval/benchmarks/build_performance.py +60 -0
- code_review_graph/eval/benchmarks/flow_completeness.py +36 -0
- code_review_graph/eval/benchmarks/impact_accuracy.py +220 -0
- code_review_graph/eval/benchmarks/multi_hop_retrieval.py +125 -0
- code_review_graph/eval/benchmarks/search_quality.py +59 -0
- code_review_graph/eval/benchmarks/token_efficiency.py +143 -0
- code_review_graph/eval/configs/code-review-graph.yaml +50 -0
- code_review_graph/eval/configs/express.yaml +45 -0
- code_review_graph/eval/configs/fastapi.yaml +48 -0
- code_review_graph/eval/configs/flask.yaml +50 -0
- code_review_graph/eval/configs/gin.yaml +51 -0
- code_review_graph/eval/configs/httpx.yaml +48 -0
- code_review_graph/eval/reporter.py +301 -0
- code_review_graph/eval/runner.py +211 -0
- code_review_graph/eval/scorer.py +85 -0
- code_review_graph/eval/token_benchmark.py +182 -0
- code_review_graph/exports.py +409 -0
- code_review_graph/flows.py +698 -0
- code_review_graph/graph.py +1427 -0
- code_review_graph/graph_diff.py +122 -0
- code_review_graph/hints.py +384 -0
- code_review_graph/incremental.py +1245 -0
- code_review_graph/jedi_resolver.py +303 -0
- code_review_graph/main.py +1079 -0
- code_review_graph/memory.py +142 -0
- code_review_graph/migrations.py +284 -0
- code_review_graph/parser.py +6957 -0
- code_review_graph/postprocessing.py +134 -0
- code_review_graph/prompts.py +159 -0
- code_review_graph/refactor.py +852 -0
- code_review_graph/registry.py +319 -0
- code_review_graph/rescript_resolver.py +206 -0
- code_review_graph/search.py +447 -0
- code_review_graph/skills.py +1481 -0
- code_review_graph/spring_resolver.py +200 -0
- code_review_graph/temporal_resolver.py +199 -0
- code_review_graph/token_benchmark.py +125 -0
- code_review_graph/tools/__init__.py +156 -0
- code_review_graph/tools/_common.py +176 -0
- code_review_graph/tools/analysis_tools.py +184 -0
- code_review_graph/tools/build.py +541 -0
- code_review_graph/tools/community_tools.py +246 -0
- code_review_graph/tools/context.py +152 -0
- code_review_graph/tools/docs.py +274 -0
- code_review_graph/tools/flows_tools.py +176 -0
- code_review_graph/tools/query.py +692 -0
- code_review_graph/tools/refactor_tools.py +168 -0
- code_review_graph/tools/registry_tools.py +125 -0
- code_review_graph/tools/review.py +477 -0
- code_review_graph/tsconfig_resolver.py +257 -0
- code_review_graph/visualization.py +2184 -0
- code_review_graph/wiki.py +305 -0
- code_review_graph_codeblackwell-2.3.6.post1.dist-info/METADATA +718 -0
- code_review_graph_codeblackwell-2.3.6.post1.dist-info/RECORD +74 -0
- code_review_graph_codeblackwell-2.3.6.post1.dist-info/WHEEL +4 -0
- code_review_graph_codeblackwell-2.3.6.post1.dist-info/entry_points.txt +3 -0
- code_review_graph_codeblackwell-2.3.6.post1.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,1006 @@
|
|
|
1
|
+
"""Vector embedding support for semantic code search.
|
|
2
|
+
|
|
3
|
+
Supports multiple providers:
|
|
4
|
+
1. Local (sentence-transformers) - Private, fast, offline.
|
|
5
|
+
2. Google Gemini - High-quality, cloud-based. Requires explicit opt-in.
|
|
6
|
+
3. MiniMax (embo-01) - High-quality 1536-dim cloud embeddings. Requires MINIMAX_API_KEY.
|
|
7
|
+
4. OpenAI-compatible - Any endpoint speaking OpenAI /v1/embeddings (real OpenAI,
|
|
8
|
+
Azure OpenAI, self-hosted gateways like new-api / LiteLLM / vLLM / LocalAI / Ollama).
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from __future__ import annotations
|
|
12
|
+
|
|
13
|
+
import hashlib
|
|
14
|
+
import logging
|
|
15
|
+
import os
|
|
16
|
+
import re
|
|
17
|
+
import sqlite3
|
|
18
|
+
import struct
|
|
19
|
+
import sys
|
|
20
|
+
import time
|
|
21
|
+
from abc import ABC, abstractmethod
|
|
22
|
+
from pathlib import Path
|
|
23
|
+
from typing import Any
|
|
24
|
+
from urllib.parse import urlparse
|
|
25
|
+
|
|
26
|
+
from . import __version__ as _crg_version
|
|
27
|
+
from .graph import GraphNode, GraphStore, node_to_dict
|
|
28
|
+
|
|
29
|
+
logger = logging.getLogger(__name__)
|
|
30
|
+
|
|
31
|
+
# Identifying User-Agent attached to every cloud-provider HTTP request.
# urllib's stock ``Python-urllib/X.Y`` UA is rejected by some
# Cloudflare-fronted providers (e.g. Fireworks) with HTTP 403 / error 1010
# ("browser signature banned"). A descriptive UA avoids that rejection and
# lets upstream operators recognise CRG-driven traffic.
_USER_AGENT = " ".join((
    f"code-review-graph/{_crg_version}",
    "(+https://github.com/tirth8205/code-review-graph)",
))
|
|
39
|
+
|
|
40
|
+
# ---------------------------------------------------------------------------
|
|
41
|
+
# Provider Interface and Implementations
|
|
42
|
+
# ---------------------------------------------------------------------------
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class EmbeddingProvider(ABC):
    """Abstract interface implemented by every embedding backend."""

    @abstractmethod
    def embed(self, texts: list[str]) -> list[list[float]]:
        """Embed a batch of texts, returning one vector per input text."""

    @abstractmethod
    def embed_query(self, text: str) -> list[float]:
        """Embed a search query (may use a different task type than indexing)."""

    @property
    @abstractmethod
    def dimension(self) -> int:
        """Length of the vectors produced by this provider."""

    @property
    @abstractmethod
    def name(self) -> str:
        """Identifier string for this provider."""
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
# Model used by the local provider when neither an explicit name nor the
# ``CRG_EMBEDDING_MODEL`` environment variable is supplied.
LOCAL_DEFAULT_MODEL = "all-MiniLM-L6-v2"


# Process-wide registry of already-loaded sentence-transformer models, keyed
# by model name. It is filled by ``prewarm_local_embeddings()`` at server
# startup (see ``main.main``) and by ``LocalEmbeddingProvider._get_model`` on
# the first lazy load. Reusing the loaded model across
# ``LocalEmbeddingProvider`` instances avoids re-importing
# ``sentence_transformers`` + ``torch`` from worker threads, which deadlocks
# ``semantic_search_nodes_tool`` on Windows stdio MCP (#385 fixed the peer
# tools via ``asyncio.to_thread``; this cache covers the remaining case where
# torch DLL / OpenMP init runs inside an executor thread).
_MODEL_CACHE: dict[str, Any] = {}
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def prewarm_local_embeddings(model_name: str | None = None) -> None:
    """Load the local sentence-transformer model right now, on this thread.

    Intended to run on the **main thread** before an asyncio event loop is
    entered (e.g. before ``mcp.run()``). On Windows, lazily importing
    ``sentence_transformers`` + ``torch`` from a FastMCP executor worker
    thread can block forever on DLL init / OpenMP thread-pool registration;
    loading eagerly here sidesteps that.

    Does nothing when ``sentence-transformers`` is not importable (cloud-only
    setups are unaffected) or when the resolved model is already cached.
    Any failure while loading is logged as a warning and swallowed — this is
    a best-effort startup hook.

    Args:
        model_name: Optional override; when falsy, the ``CRG_EMBEDDING_MODEL``
            environment variable is used, then ``LOCAL_DEFAULT_MODEL``.
    """
    try:
        from sentence_transformers import SentenceTransformer  # noqa: F401
    except ImportError:
        # Cloud-only installation: there is no local model to warm up.
        return

    target = model_name
    if not target:
        target = os.environ.get("CRG_EMBEDDING_MODEL", LOCAL_DEFAULT_MODEL)

    if target not in _MODEL_CACHE:
        try:
            _MODEL_CACHE[target] = LocalEmbeddingProvider(target)._get_model()
        except Exception as err:  # pragma: no cover — best-effort startup hook
            logger.warning("prewarm_local_embeddings(%s) skipped: %s", target, err)
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
class LocalEmbeddingProvider(EmbeddingProvider):
    """Embedding provider backed by a local sentence-transformers model.

    The model is loaded lazily on first use and shared process-wide through
    ``_MODEL_CACHE`` so that several provider instances for the same model
    name reuse one loaded model.
    """

    def __init__(self, model_name: str | None = None) -> None:
        """Remember which model to load; nothing is loaded yet.

        Args:
            model_name: Model name/path. When falsy, falls back to the
                ``CRG_EMBEDDING_MODEL`` environment variable and then to
                ``LOCAL_DEFAULT_MODEL``.
        """
        self._model_name = model_name or os.environ.get(
            "CRG_EMBEDDING_MODEL", LOCAL_DEFAULT_MODEL
        )
        self._model = None  # Lazy-loaded

    def _get_model(self):
        """Return the loaded model, loading and caching it on first use.

        Raises:
            ImportError: If ``sentence-transformers`` cannot be imported. The
                original import failure is attached as ``__cause__``.
        """
        if self._model is not None:
            return self._model

        # Check the process-wide cache first — populated either by a prior
        # provider instance or by ``prewarm_local_embeddings`` at startup.
        cached = _MODEL_CACHE.get(self._model_name)
        if cached is not None:
            self._model = cached
            return self._model

        # Only the import itself is guarded: an ImportError raised later,
        # while the model is being constructed, is a different problem and
        # must not be reported as "sentence-transformers not installed".
        try:
            from sentence_transformers import SentenceTransformer
        except ImportError as exc:
            raise ImportError(
                "sentence-transformers not installed. "
                "Run: pip install code-review-graph[embeddings]"
            ) from exc

        # Check environment variable, default to False to prevent RCE
        _rce_val = os.environ.get("CRG_ALLOW_REMOTE_CODE", "0")
        allow_remote_code = _rce_val.lower() in ("1", "true", "yes")

        self._model = SentenceTransformer(
            self._model_name,
            trust_remote_code=allow_remote_code,
        )
        _MODEL_CACHE[self._model_name] = self._model
        return self._model

    def embed(self, texts: list[str]) -> list[list[float]]:
        """Encode ``texts`` with the local model and return plain lists."""
        model = self._get_model()
        vectors = model.encode(texts, show_progress_bar=False)
        return [v.tolist() for v in vectors]

    def embed_query(self, text: str) -> list[float]:
        """Embed a single query; identical to embedding a one-item batch."""
        return self.embed([text])[0]

    @property
    def dimension(self) -> int:
        """Embedding dimension as reported by the loaded model."""
        model = self._get_model()
        return model.get_sentence_embedding_dimension()

    @property
    def name(self) -> str:
        """Provider identity in the form ``local:{model_name}``."""
        return f"local:{self._model_name}"
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
class GoogleEmbeddingProvider(EmbeddingProvider):
    """Embedding provider that calls Gemini through the ``google.genai`` client."""

    def __init__(self, api_key: str, model: str = "gemini-embedding-001") -> None:
        """Create the API client.

        Raises:
            ImportError: If an ImportError occurs while importing or
                constructing the ``google.genai`` client.
        """
        try:
            from google import genai
            self._client = genai.Client(api_key=api_key)
        except ImportError:
            raise ImportError(
                "google-generativeai not installed. "
                "Run: pip install code-review-graph[google-embeddings]"
            )
        self.model = model
        # Filled in from the first successful response.
        self._dimension: int | None = None

    def embed(self, texts: list[str]) -> list[list[float]]:
        """Embed documents in batches of 100 using the document task type."""
        chunk_size = 100
        vectors = []
        for start in range(0, len(texts), chunk_size):
            chunk = texts[start:start + chunk_size]

            # Bind the chunk as a default so the callable is self-contained.
            def _request(contents=chunk):
                return self._client.models.embed_content(
                    model=self.model,
                    contents=contents,
                    config={"task_type": "RETRIEVAL_DOCUMENT"},
                )

            reply = self._call_with_retry(_request)
            vectors.extend(entry.values for entry in reply.embeddings)

        if vectors and self._dimension is None:
            self._dimension = len(vectors[0])
        return vectors

    @staticmethod
    def _call_with_retry(fn, max_retries: int = 3):
        """Call fn with exponential backoff on transient API errors."""
        final_attempt = max_retries - 1
        for attempt in range(max_retries):
            try:
                return fn()
            except Exception as e:
                # Retry on rate-limit (429) or server errors (5xx); detected
                # by looking for the status code in the exception text.
                text = str(e)
                transient = any(code in text for code in ("429", "500", "503"))
                if attempt == final_attempt or not transient:
                    raise
                delay = 2 ** attempt
                logger.warning("Gemini API error (attempt %d/%d), retrying in %ds: %s",
                               attempt + 1, max_retries, delay, e)
                time.sleep(delay)

    def embed_query(self, text: str) -> list[float]:
        """Embed a single search query using the query task type."""
        def _request():
            return self._client.models.embed_content(
                model=self.model,
                contents=[text],
                config={"task_type": "RETRIEVAL_QUERY"},
            )

        reply = self._call_with_retry(_request)
        vector = reply.embeddings[0].values
        if self._dimension is None:
            self._dimension = len(vector)
        return vector

    @property
    def dimension(self) -> int:
        """Observed vector length, or 768 before any response has been seen."""
        if self._dimension is None:
            # Fallback reported until the first API call records the real length.
            return 768
        return self._dimension

    @property
    def name(self) -> str:
        """Provider identity in the form ``google:{model}``."""
        return f"google:{self.model}"
|
|
234
|
+
|
|
235
|
+
|
|
236
|
+
class MiniMaxEmbeddingProvider(EmbeddingProvider):
    """MiniMax embo-01 embedding provider (1536 dimensions).

    Uses the MiniMax Embeddings API (https://api.minimax.io/v1/embeddings)
    with the embo-01 model. Requires the MINIMAX_API_KEY environment variable.
    """

    _ENDPOINT = "https://api.minimax.io/v1/embeddings"
    _MODEL = "embo-01"
    _DIMENSION = 1536

    def __init__(self, api_key: str) -> None:
        self._api_key = api_key

    def _call_api(self, texts: list[str], task_type: str) -> list[list[float]]:
        """POST one batch to the MiniMax endpoint and return its vectors.

        Args:
            texts: Batch of texts to embed.
            task_type: Value sent as the request's ``type`` field (this class
                uses ``"db"`` for indexing and ``"query"`` for searches).

        Returns:
            The ``vectors`` list from the response, one entry per input text.

        Raises:
            RuntimeError: If the response carries a non-zero
                ``base_resp.status_code`` or the number of vectors does not
                match the number of inputs.
            Exception: Any transport/parsing error, after up to 3 attempts
                for errors whose text contains 429, 500 or 503.
        """
        import json as _json
        import ssl
        import urllib.request

        payload = _json.dumps({
            "model": self._MODEL,
            "texts": texts,
            "type": task_type,
        }).encode("utf-8")

        req = urllib.request.Request(
            self._ENDPOINT,
            data=payload,
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {self._api_key}",
                "User-Agent": _USER_AGENT,
                "Accept": "application/json",
            },
        )

        # The TLS context is identical for every attempt, so build it once.
        _ssl_ctx = ssl.create_default_context()

        max_retries = 3
        for attempt in range(max_retries):
            try:
                with urllib.request.urlopen(req, timeout=60, context=_ssl_ctx) as resp:  # nosec B310
                    body = _json.loads(resp.read().decode("utf-8"))

                # ``base_resp`` may be missing or JSON null; both mean there
                # is no error envelope to inspect.
                base_resp = body.get("base_resp") or {}
                if base_resp.get("status_code", 0) != 0:
                    raise RuntimeError(
                        f"MiniMax API error: {base_resp.get('status_msg', 'unknown')}"
                    )

                vectors = body["vectors"]
            except Exception as e:
                err_str = str(e)
                is_retryable = "429" in err_str or "500" in err_str or "503" in err_str
                if not is_retryable or attempt == max_retries - 1:
                    raise
                wait = 2 ** attempt
                logger.warning(
                    "MiniMax API error (attempt %d/%d), retrying in %ds: %s",
                    attempt + 1, max_retries, wait, e,
                )
                time.sleep(wait)
                continue

            # Checked outside the ``try`` on purpose: the counts in this
            # message must never be mistaken for an HTTP status by the
            # text-based retry test above.
            received = len(vectors) if isinstance(vectors, list) else 0
            if received != len(texts):
                raise RuntimeError(
                    f"MiniMax API returned {received} vectors for "
                    f"{len(texts)} inputs — refusing to misalign vectors."
                )
            return vectors

        return []  # unreachable, but keeps mypy happy

    def embed(self, texts: list[str]) -> list[list[float]]:
        """Embed documents in batches of 100 using the ``"db"`` task type."""
        batch_size = 100
        results: list[list[float]] = []
        for i in range(0, len(texts), batch_size):
            batch = texts[i:i + batch_size]
            results.extend(self._call_api(batch, "db"))
        return results

    def embed_query(self, text: str) -> list[float]:
        """Embed a single search query using the ``"query"`` task type."""
        return self._call_api([text], "query")[0]

    @property
    def dimension(self) -> int:
        """Fixed vector length declared for the embo-01 model."""
        return self._DIMENSION

    @property
    def name(self) -> str:
        """Provider identity in the form ``minimax:{model}``."""
        return f"minimax:{self._MODEL}"
|
|
318
|
+
|
|
319
|
+
|
|
320
|
+
class OpenAIEmbeddingProvider(EmbeddingProvider):
    """OpenAI-compatible embedding provider.

    Works with any endpoint that speaks the OpenAI ``/v1/embeddings`` schema:
    - Real OpenAI API (``https://api.openai.com/v1``)
    - Azure OpenAI
    - Self-hosted gateways: new-api, LiteLLM, vLLM, LocalAI, Ollama (openai mode)

    Provider identity in ``name`` includes both the model and the endpoint
    host (``openai:{model}@{host}``), so switching base URL while keeping the
    same model ID re-partitions the embeddings table and forces a clean
    re-embed. This is the only defense against silently mixing vector spaces
    from different backends (e.g. real OpenAI vs. an OpenAI-compatible
    gateway that ships different weights under the same model name).

    Dimension handling: a ``dimension`` passed to the constructor is sent to
    the API as ``dimensions`` on every request. When none is passed, the
    dimension is detected from the first response and only *reported* through
    the ``dimension`` property — it is never sent back to the server, because
    many models and gateways reject the ``dimensions`` parameter.
    """

    _DEFAULT_BATCH_SIZE = 100

    # Default ports by scheme; stripped from the host_key so the user can't
    # accidentally force a re-embed by toggling an explicit default port.
    _DEFAULT_PORTS = {"http": 80, "https": 443}

    def __init__(
        self,
        api_key: str,
        base_url: str,
        model: str,
        dimension: int | None = None,
        timeout: int = 120,
        batch_size: int | None = None,
    ) -> None:
        """Configure the provider; no network traffic happens here.

        Args:
            api_key: Bearer token sent in the ``Authorization`` header.
            base_url: API root; ``/embeddings`` is appended per request.
            model: Model ID sent in each request body.
            dimension: Optional explicit output dimension to request.
            timeout: Per-request timeout in seconds.
            batch_size: Inputs per request; falsy means the class default.
        """
        self._api_key = api_key
        self._base_url = base_url.rstrip("/")
        self._model = model
        # The value the caller explicitly pinned. This — and only this —
        # decides whether ``dimensions`` is included in request bodies.
        self._requested_dimension = dimension
        # The value reported by ``dimension``: the pinned one if given,
        # otherwise filled in from the first response.
        self._dimension = dimension
        self._timeout = timeout
        self._batch_size = batch_size or self._DEFAULT_BATCH_SIZE
        self._host_key = self._make_host_key(self._base_url)

    @classmethod
    def _make_host_key(cls, base_url: str) -> str:
        """Normalize the identity key used in ``provider.name``.

        Codex review pushed this well past naive ``netloc`` because that
        alone has three leaks:

        1. ``netloc`` preserves ``userinfo`` (``user:pass@host``) — we'd
           persist credentials into the DB's ``embeddings.provider`` column.
           Use ``hostname`` instead.
        2. Default ports (``:80`` for http, ``:443`` for https) are
           semantically identical to omitting the port; keeping them would
           cause spurious re-embeds when the user just spelled the URL
           differently.
        3. Path is part of the backend identity for path-routed gateways:
           ``https://gw/openai/v1`` and ``https://gw/vendor-b/v1`` front
           different models and must not share cached vectors.
        """
        parsed = urlparse(base_url)
        hostname = (parsed.hostname or "").lower()
        scheme = (parsed.scheme or "").lower()
        port = parsed.port
        if port and port != cls._DEFAULT_PORTS.get(scheme):
            # Bracket IPv6 literals when appending a port.
            host_part = f"[{hostname}]:{port}" if ":" in hostname else f"{hostname}:{port}"
        else:
            host_part = hostname
        # Preserve path routing. Trim any trailing slash and any
        # ``/embeddings`` suffix that callers may have included — we append
        # that ourselves when building the request URL.
        path = (parsed.path or "").rstrip("/")
        if path.endswith("/embeddings"):
            path = path[: -len("/embeddings")].rstrip("/")
        # Include scheme: http and https to the same host+path front
        # different endpoints in practice (plaintext vs TLS, dev vs prod
        # gateway), and sharing cached vectors across them is the same
        # silent-mixing failure mode as switching base URL entirely.
        return f"{scheme}://{host_part}{path}" if path else f"{scheme}://{host_part}"

    def _call_api(self, texts: list[str]) -> list[list[float]]:
        """POST one batch to ``{base_url}/embeddings`` and return its vectors.

        Returns:
            One embedding per entry of ``texts``, in input order.

        Raises:
            RuntimeError: For non-retryable HTTP 4xx responses, an ``error``
                object in the response, empty data, or index/count data that
                cannot be aligned with the inputs.
            Exception: The last transport error after 3 attempts for
                retryable failures (HTTP 429/5xx, network, timeout, TLS).
        """
        import http.client
        import json as _json
        import socket
        import ssl
        import urllib.error
        import urllib.request

        body: dict[str, Any] = {"model": self._model, "input": texts}
        # OpenAI v3 models (text-embedding-3-*) support dimension reduction;
        # only forward the param when the user explicitly pinned one. The
        # auto-detected dimension stored in ``self._dimension`` must not
        # trigger this, or every request after the first would carry a
        # ``dimensions`` field the backend may reject.
        if self._requested_dimension is not None:
            body["dimensions"] = self._requested_dimension

        payload = _json.dumps(body).encode("utf-8")
        req = urllib.request.Request(
            f"{self._base_url}/embeddings",
            data=payload,
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {self._api_key}",
                "User-Agent": _USER_AGENT,
                "Accept": "application/json",
            },
        )

        # The TLS context is identical for every attempt, so build it once.
        _ssl_ctx = ssl.create_default_context()

        max_retries = 3
        for attempt in range(max_retries):
            try:
                try:
                    with urllib.request.urlopen(  # nosec B310
                        req, timeout=self._timeout, context=_ssl_ctx,
                    ) as resp:
                        raw = resp.read().decode("utf-8")
                except urllib.error.HTTPError as http_err:
                    # 429 / 5xx: re-raise and let the outer retry loop handle it.
                    # (We must not convert to RuntimeError here or retry below
                    # can't tell it was a transient HTTP failure.)
                    if http_err.code == 429 or 500 <= http_err.code < 600:
                        raise
                    # Other 4xx: surface the API error body instead of a bare
                    # "400 Bad Request" — gateways like new-api return JSON
                    # with the real reason (batch size limits, invalid model,
                    # etc.) which is far more actionable.
                    try:
                        err_body = http_err.read().decode("utf-8", errors="replace")
                    except Exception:
                        err_body = ""
                    err_msg = err_body or str(http_err)
                    try:
                        parsed = _json.loads(err_body)
                        if isinstance(parsed, dict) and "error" in parsed:
                            err_obj = parsed["error"]
                            err_msg = (
                                err_obj.get("message", err_msg)
                                if isinstance(err_obj, dict) else str(err_obj)
                            )
                    except Exception:  # nosec B110
                        # Non-JSON error body is fine: we already seeded
                        # err_msg with the raw body above, so fall through.
                        pass
                    raise RuntimeError(
                        f"OpenAI API HTTP {http_err.code}: {err_msg}"
                    ) from http_err

                response = _json.loads(raw)

                if "error" in response:
                    err = response["error"]
                    msg = err.get("message", "unknown") if isinstance(err, dict) else str(err)
                    raise RuntimeError(f"OpenAI API error: {msg}")

                data = response.get("data", [])
                if not data:
                    raise RuntimeError("OpenAI API returned empty data")
                # OpenAI spec: data[i].index maps to input[i], but some
                # compatible gateways re-order results or drop entries on
                # partial failure, and others omit `index` entirely. Three
                # disjoint cases:
                #   1. All items have a valid int ``index``: must form a
                #      permutation of 0..N-1, then sort and use.
                #   2. NO item carries an ``index`` field: trust server
                #      order, only verify count matches.
                #   3. Anything in between (partial indices, str indices,
                #      missing on some): refuse. Zipping server order in
                #      that case would happily misalign the indexed items.
                any_has_index = any("index" in item for item in data)
                all_int_index = all(
                    isinstance(item.get("index"), int) for item in data
                )
                if all_int_index:
                    expected = set(range(len(texts)))
                    indices = [int(item["index"]) for item in data]
                    if len(set(indices)) != len(indices) or set(indices) != expected:
                        raise RuntimeError(
                            "OpenAI API returned malformed indices "
                            f"(got {indices}, expected permutation of "
                            f"0..{len(texts) - 1}) — refusing to misalign vectors."
                        )
                    data = sorted(data, key=lambda item: int(item["index"]))
                elif not any_has_index:
                    if len(data) != len(texts):
                        raise RuntimeError(
                            f"OpenAI API returned {len(data)} embeddings for "
                            f"{len(texts)} inputs with no index field — "
                            "refusing to misalign vectors."
                        )
                else:
                    # Mixed: some items have index, others don't (or carry
                    # non-int index). Server order would silently misplace
                    # the indexed items, so we refuse.
                    raise RuntimeError(
                        "OpenAI API returned mixed indexed/unindexed data — "
                        "refusing to misalign vectors."
                    )

                vectors = [item["embedding"] for item in data]
                # Record the observed length for reporting only; see the
                # ``dimensions`` note above.
                if vectors and self._dimension is None:
                    self._dimension = len(vectors[0])
                return vectors

            except Exception as e:
                # Retryable = HTTP 429/5xx, network/timeout/TLS issues.
                # Non-retryable = HTTP 4xx (other), malformed responses,
                # misaligned data length — those are caller-side bugs that
                # will keep failing on retry.
                is_retryable = False
                if isinstance(e, urllib.error.HTTPError):
                    is_retryable = e.code == 429 or 500 <= e.code < 600
                elif isinstance(e, (
                    urllib.error.URLError,
                    socket.timeout,
                    TimeoutError,
                    ConnectionError,
                    ssl.SSLError,
                    # Reverse proxies and edge gateways surface transient
                    # disconnects as these stdlib classes. Real incidents
                    # have been observed on Cloudflare-fronted endpoints
                    # and on LiteLLM when upstream providers hiccup.
                    http.client.IncompleteRead,
                    http.client.BadStatusLine,
                    http.client.RemoteDisconnected,
                )):
                    is_retryable = True
                if not is_retryable or attempt == max_retries - 1:
                    raise
                wait = 2 ** attempt
                logger.warning(
                    "OpenAI embeddings API error (attempt %d/%d), retrying in %ds: %s",
                    attempt + 1, max_retries, wait, e,
                )
                time.sleep(wait)

        return []  # unreachable

    def embed(self, texts: list[str]) -> list[list[float]]:
        """Embed ``texts`` in batches of ``batch_size``; empty input returns ``[]``."""
        if not texts:
            return []
        results: list[list[float]] = []
        for i in range(0, len(texts), self._batch_size):
            results.extend(self._call_api(texts[i:i + self._batch_size]))
        return results

    def embed_query(self, text: str) -> list[float]:
        """Embed a single search query."""
        return self._call_api([text])[0]

    @property
    def dimension(self) -> int:
        """Pinned or detected dimension, or 1536 before either is known."""
        if self._dimension is not None:
            return self._dimension
        # Default for text-embedding-3-small; updated after first call.
        return 1536

    @property
    def name(self) -> str:
        """Provider identity in the form ``openai:{model}@{host_key}``."""
        # Endpoint-aware identity: model alone is NOT enough — two backends
        # can serve the same model ID with different weights or dimensions,
        # and re-using cached embeddings across them silently corrupts
        # semantic ranking. Including the host partitions the embeddings
        # table so switching CRG_OPENAI_BASE_URL triggers a safe re-embed.
        return f"openai:{self._model}@{self._host_key}"
|
|
584
|
+
|
|
585
|
+
|
|
586
|
+
# Names of the providers that are backed by a remote HTTP API rather than a
# locally loaded model.
CLOUD_PROVIDERS = {
    "openai",
    "minimax",
    "google",
}
|
|
587
|
+
|
|
588
|
+
|
|
589
|
+
def _is_localhost_url(url: str) -> bool:
    """Tell whether ``url`` targets this machine (never cloud egress).

    The decision is made on the parsed hostname, compared exactly, so a
    look-alike such as ``https://my-openai.127.0.0.1.nip.io`` is not
    mistaken for localhost the way a substring match would be. A URL that
    cannot be parsed is treated as not-localhost.
    """
    # nosec B104: we're *matching* a URL hostname, not binding a listener.
    local_hosts = ("127.0.0.1", "localhost", "0.0.0.0", "::1")  # nosec B104
    try:
        hostname = urlparse(url).hostname
    except Exception:
        return False
    if not hostname:
        hostname = ""
    return hostname.lower() in local_hosts
|
|
601
|
+
|
|
602
|
+
|
|
603
|
+
def _warn_cloud_egress(provider_name: str) -> None:
    """Emit a stderr notice that code is about to be sent to a cloud provider.

    Setting ``CRG_ACCEPT_CLOUD_EMBEDDINGS=1`` in the environment silences the
    notice, letting scripted / CI workloads acknowledge it once. The message
    goes to stderr (and nothing is ever read from stdin) to stay compatible
    with the MCP stdio transport — anything written to stdout would corrupt
    the JSON-RPC stream. See: #174
    """
    acknowledged = os.environ.get("CRG_ACCEPT_CLOUD_EMBEDDINGS", "").strip()
    if acknowledged != "1":
        notice = (
            f"\n⚠️ code-review-graph: about to embed code via the '{provider_name}' "
            "cloud provider.\n"
            " Your source code (function names, docstrings, file paths) will be "
            "sent to an external API.\n"
            " This is necessary for semantic search with the cloud provider you "
            "selected.\n"
            " To skip this warning in future runs, set "
            "CRG_ACCEPT_CLOUD_EMBEDDINGS=1 in your environment.\n"
            " To stay fully offline, use the default 'local' provider instead "
            "(no API key needed).\n"
        )
        print(notice, file=sys.stderr)
|
|
627
|
+
|
|
628
|
+
|
|
629
|
+
# Every provider name ``get_provider`` accepts; any other non-empty name is
# rejected with ValueError.
_VALID_PROVIDERS = {
    "local",
    "google",
    "minimax",
    "openai",
}
|
|
630
|
+
|
|
631
|
+
|
|
632
|
+
def get_provider(
|
|
633
|
+
provider: str | None = None,
|
|
634
|
+
model: str | None = None,
|
|
635
|
+
) -> EmbeddingProvider | None:
|
|
636
|
+
"""Get an embedding provider by name.
|
|
637
|
+
|
|
638
|
+
Args:
|
|
639
|
+
provider: Provider name. One of "local", "google", "minimax", "openai",
|
|
640
|
+
or None for local. Names are case-insensitive and surrounding
|
|
641
|
+
whitespace is ignored; unknown names raise ValueError instead
|
|
642
|
+
of silently falling back to the local provider.
|
|
643
|
+
Google requires GOOGLE_API_KEY env var and explicit opt-in.
|
|
644
|
+
MiniMax requires MINIMAX_API_KEY env var and explicit opt-in.
|
|
645
|
+
OpenAI requires CRG_OPENAI_API_KEY + CRG_OPENAI_BASE_URL +
|
|
646
|
+
CRG_OPENAI_MODEL env vars (or the ``model`` arg). The egress
|
|
647
|
+
warning is skipped when the base URL points to localhost.
|
|
648
|
+
Cloud providers emit a one-time stderr warning before use
|
|
649
|
+
unless ``CRG_ACCEPT_CLOUD_EMBEDDINGS=1`` is set. See: #174
|
|
650
|
+
model: Model name/path to use. For local provider this is any
|
|
651
|
+
sentence-transformers compatible model. Falls back to
|
|
652
|
+
CRG_EMBEDDING_MODEL env var, then to all-MiniLM-L6-v2.
|
|
653
|
+
For Google provider this is a Gemini model ID.
|
|
654
|
+
For OpenAI provider this overrides CRG_OPENAI_MODEL.
|
|
655
|
+
|
|
656
|
+
Raises:
|
|
657
|
+
ValueError: If the provider name is not one of the known providers,
|
|
658
|
+
or if required environment variables are missing.
|
|
659
|
+
"""
|
|
660
|
+
name = provider.strip().lower() if provider else ""
|
|
661
|
+
if name and name not in _VALID_PROVIDERS:
|
|
662
|
+
raise ValueError(
|
|
663
|
+
f"Unknown embedding provider '{name}'. "
|
|
664
|
+
"Valid: local, openai, google, minimax"
|
|
665
|
+
)
|
|
666
|
+
|
|
667
|
+
if name == "openai":
|
|
668
|
+
api_key = os.environ.get("CRG_OPENAI_API_KEY")
|
|
669
|
+
base_url = os.environ.get("CRG_OPENAI_BASE_URL")
|
|
670
|
+
resolved_model = model or os.environ.get("CRG_OPENAI_MODEL")
|
|
671
|
+
if not api_key or not base_url or not resolved_model:
|
|
672
|
+
missing = [
|
|
673
|
+
name for name, val in [
|
|
674
|
+
("CRG_OPENAI_API_KEY", api_key),
|
|
675
|
+
("CRG_OPENAI_BASE_URL", base_url),
|
|
676
|
+
("CRG_OPENAI_MODEL", resolved_model),
|
|
677
|
+
] if not val
|
|
678
|
+
]
|
|
679
|
+
raise ValueError(
|
|
680
|
+
"Missing required environment variable(s) for the OpenAI "
|
|
681
|
+
f"embedding provider: {', '.join(missing)}."
|
|
682
|
+
)
|
|
683
|
+
dim_env = os.environ.get("CRG_OPENAI_DIMENSION")
|
|
684
|
+
dimension = int(dim_env) if dim_env else None
|
|
685
|
+
batch_env = os.environ.get("CRG_OPENAI_BATCH_SIZE")
|
|
686
|
+
batch_size = int(batch_env) if batch_env else None
|
|
687
|
+
if not _is_localhost_url(base_url):
|
|
688
|
+
_warn_cloud_egress("openai")
|
|
689
|
+
return OpenAIEmbeddingProvider(
|
|
690
|
+
api_key=api_key,
|
|
691
|
+
base_url=base_url,
|
|
692
|
+
model=resolved_model,
|
|
693
|
+
dimension=dimension,
|
|
694
|
+
batch_size=batch_size,
|
|
695
|
+
)
|
|
696
|
+
|
|
697
|
+
if name == "minimax":
|
|
698
|
+
api_key = os.environ.get("MINIMAX_API_KEY")
|
|
699
|
+
if not api_key:
|
|
700
|
+
raise ValueError(
|
|
701
|
+
"MINIMAX_API_KEY environment variable is required for "
|
|
702
|
+
"the MiniMax embedding provider."
|
|
703
|
+
)
|
|
704
|
+
_warn_cloud_egress("minimax")
|
|
705
|
+
return MiniMaxEmbeddingProvider(api_key=api_key)
|
|
706
|
+
|
|
707
|
+
if name == "google":
|
|
708
|
+
api_key = os.environ.get("GOOGLE_API_KEY")
|
|
709
|
+
if not api_key:
|
|
710
|
+
raise ValueError(
|
|
711
|
+
"GOOGLE_API_KEY environment variable is required for "
|
|
712
|
+
"the Google embedding provider."
|
|
713
|
+
)
|
|
714
|
+
_warn_cloud_egress("google")
|
|
715
|
+
try:
|
|
716
|
+
return GoogleEmbeddingProvider(
|
|
717
|
+
api_key=api_key,
|
|
718
|
+
**({"model": model} if model else {}),
|
|
719
|
+
)
|
|
720
|
+
except ImportError:
|
|
721
|
+
return None
|
|
722
|
+
|
|
723
|
+
# Default: local
|
|
724
|
+
if not _check_available():
|
|
725
|
+
return None
|
|
726
|
+
try:
|
|
727
|
+
return LocalEmbeddingProvider(model_name=model)
|
|
728
|
+
except ImportError:
|
|
729
|
+
return None
|
|
730
|
+
|
|
731
|
+
|
|
732
|
+
def _check_available() -> bool:
    """Report whether the ``sentence_transformers`` package can be imported.

    Returns:
        True when the import succeeds, False when it raises ImportError.
    """
    try:
        import sentence_transformers  # noqa: F401
    except ImportError:
        return False
    else:
        return True
|
|
739
|
+
|
|
740
|
+
|
|
741
|
+
# ---------------------------------------------------------------------------
|
|
742
|
+
# SQLite vector storage
|
|
743
|
+
# ---------------------------------------------------------------------------
|
|
744
|
+
|
|
745
|
+
_EMBEDDINGS_SCHEMA = """
|
|
746
|
+
CREATE TABLE IF NOT EXISTS embeddings (
|
|
747
|
+
qualified_name TEXT PRIMARY KEY,
|
|
748
|
+
vector BLOB NOT NULL,
|
|
749
|
+
text_hash TEXT NOT NULL,
|
|
750
|
+
provider TEXT NOT NULL DEFAULT 'unknown'
|
|
751
|
+
);
|
|
752
|
+
"""
|
|
753
|
+
|
|
754
|
+
|
|
755
|
+
def _encode_vector(vec: list[float]) -> bytes:
    """Pack *vec* into bytes with ``struct``'s ``f`` (float) format.

    Args:
        vec: Values to pack, in order.

    Returns:
        The packed bytes; an empty vector yields ``b""``.
    """
    layout = struct.Struct(f"{len(vec)}f")
    return layout.pack(*vec)
|
|
758
|
+
|
|
759
|
+
|
|
760
|
+
def _decode_vector(blob: bytes) -> list[float]:
    """Unpack bytes written with ``struct``'s ``f`` format into a float list.

    Args:
        blob: Packed bytes, read as consecutive 4-byte floats.

    Returns:
        The unpacked values in order; ``b""`` yields ``[]``.

    Raises:
        struct.error: If ``len(blob)`` is not a multiple of 4.
    """
    float_count = len(blob) // 4  # each packed float occupies 4 bytes
    layout = struct.Struct(f"{float_count}f")
    return list(layout.unpack(blob))
|
|
764
|
+
|
|
765
|
+
|
|
766
|
+
def _cosine_similarity(a: list[float], b: list[float]) -> float:
    """Return the cosine of the angle between vectors *a* and *b*.

    Args:
        a: First vector.
        b: Second vector.

    Returns:
        ``dot(a, b) / (|a| * |b|)``, or ``0.0`` when the vectors differ in
        length or either one has zero magnitude.
    """
    if len(a) != len(b):
        return 0.0

    magnitude_a = sum(value * value for value in a) ** 0.5
    magnitude_b = sum(value * value for value in b) ** 0.5
    # A zero-length vector has no direction; avoid dividing by zero.
    if not (magnitude_a and magnitude_b):
        return 0.0

    inner = sum(left * right for left, right in zip(a, b))
    return inner / (magnitude_a * magnitude_b)
|
|
776
|
+
|
|
777
|
+
|
|
778
|
+
_IDENTIFIER_SPLIT_RE = re.compile(r"([a-z])([A-Z])|[_./\-]+")

# Zero-width word boundaries inside an identifier:
#   * lowercase followed by uppercase           (camelCase -> camel|Case)
#   * end of an acronym before a capitalised word (APIRoute -> API|Route)
_CASE_BOUNDARY_RE = re.compile(r"(?<=[a-z])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])")
# Runs of snake_case / dotted / path / hyphen separators.
_SEPARATOR_RUN_RE = re.compile(r"[_./\-]+")


def _split_identifier(name: str) -> str:
    """Split snake_case / camelCase / PascalCase / dotted into space-separated words.

    Examples:
        get_route_handler -> "get route handler"
        APIRoute -> "API Route"
        dispatch_request -> "dispatch request"
        full_dispatch_request -> "full dispatch request"

    Args:
        name: Identifier to split. May be empty.

    Returns:
        The words joined by single spaces, with no leading or trailing
        whitespace. An empty ``name`` yields ``""``.
    """
    if not name:
        return ""
    # Insert a space at every case boundary first, then turn separator runs
    # into spaces; the final split/join collapses any repeated whitespace.
    spaced = _CASE_BOUNDARY_RE.sub(" ", name)
    spaced = _SEPARATOR_RUN_RE.sub(" ", spaced)
    return " ".join(spaced.split())
|
|
797
|
+
|
|
798
|
+
|
|
799
|
+
def _node_to_text(node: GraphNode) -> str:
    """Build the space-joined text that gets embedded for *node*.

    The text is assembled so natural-language queries land on the right
    node, not just on the enclosing class. In order it contains: the dotted
    ``Parent.name`` form, the bare name, the name split into words, the
    lower-cased kind, an ``"in <Parent>"`` phrase plus the parent split into
    words, the parameters and return type, the directory holding the file,
    and the language. Tested by the ``multi_hop_retrieval`` benchmark — see
    ``docs/REPRODUCING.md``.

    Args:
        node: Graph node to describe.

    Returns:
        All applicable fragments joined by single spaces.
    """
    fragments: list[str] = []
    owner = node.parent_name
    is_file = node.kind == "File"

    def add_split_words(identifier: str) -> None:
        # Only worth adding when splitting actually changed the identifier.
        words = _split_identifier(identifier)
        if words and words.lower() != identifier.lower():
            fragments.append(words)

    # Dotted form first — strongest lexical signal for "method in class".
    if owner and not is_file:
        fragments.append(f"{owner}.{node.name}")

    # Bare name (always present), then its split-words form.
    fragments.append(node.name)
    add_split_words(node.name)

    # Kind ("function", "class", "test", ...).
    if not is_file:
        fragments.append(node.kind.lower())

    # Parent context, with the split form too.
    if owner:
        fragments.append(f"in {owner}")
        add_split_words(owner)

    # Signature bits.
    if node.params:
        fragments.append(node.params)
    if node.return_type:
        fragments.append(f"returns {node.return_type}")

    # Directory containing the file — gives queries a term like "routing"
    # or "client" to anchor against; generic directory names are skipped.
    if node.file_path:
        folder = Path(node.file_path).parent.name
        if folder not in ("", ".", "src", "lib"):
            fragments.append(folder)

    # Language.
    if node.language:
        fragments.append(node.language)

    return " ".join(fragments)
|
|
851
|
+
|
|
852
|
+
|
|
853
|
+
class EmbeddingStore:
    """Manages vector embeddings for graph nodes in SQLite.

    Rows live in the ``embeddings`` table and carry the name of the provider
    that produced them; ``search`` only scores rows written by the current
    provider. The store can be used as a context manager, in which case the
    connection is closed when the ``with`` block exits.
    """

    def __init__(
        self,
        db_path: str | Path,
        provider: str | None = None,
        model: str | None = None,
    ) -> None:
        """Resolve the provider, open the database and ensure the schema.

        Args:
            db_path: Path of the SQLite database file.
            provider: Provider name forwarded to ``get_provider``.
            model: Model name forwarded to ``get_provider``.

        Raises:
            ValueError: Propagated from ``get_provider``.
            sqlite3.Error: If the database cannot be opened or initialised;
                the connection is closed before the error propagates.
        """
        self.provider = get_provider(provider, model=model)
        # False when no provider could be created; embed/search then no-op.
        self.available = self.provider is not None
        self.db_path = Path(db_path)
        self._conn = sqlite3.connect(
            str(self.db_path), timeout=30, check_same_thread=False,
            isolation_level=None,
        )
        try:
            self._conn.row_factory = sqlite3.Row
            self._conn.executescript(_EMBEDDINGS_SCHEMA)

            # Migration for existing DBs missing the provider column
            try:
                self._conn.execute("SELECT provider FROM embeddings LIMIT 1")
            except sqlite3.OperationalError:
                self._conn.execute(
                    "ALTER TABLE embeddings ADD COLUMN provider "
                    "TEXT NOT NULL DEFAULT 'unknown'"
                )

            self._conn.commit()
        except Exception:
            # Do not leak the open handle when schema setup fails.
            self._conn.close()
            raise

    def __enter__(self) -> "EmbeddingStore":
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:  # type: ignore[no-untyped-def]
        self.close()

    def close(self) -> None:
        """Close the underlying SQLite connection."""
        self._conn.close()

    def embed_nodes(self, nodes: list[GraphNode], batch_size: int = 64) -> int:
        """Compute and store embeddings for a list of nodes.

        ``File`` nodes are skipped, as are nodes whose stored row already
        has the same text hash and provider name.

        Args:
            nodes: Candidate nodes.
            batch_size: Currently not used by this method; every pending
                text is handed to the provider in a single ``embed`` call.

        Returns:
            The number of rows written. This is 0 when there is no provider
            or nothing needs embedding, and can be lower than the number of
            pending nodes if the provider returns fewer vectors than texts.
        """
        if not self.provider:
            return 0

        provider_name = self.provider.name
        # (node, text, text_hash) for every node that needs (re-)embedding.
        pending: list[tuple[GraphNode, str, str]] = []

        for node in nodes:
            if node.kind == "File":
                continue
            text = _node_to_text(node)
            text_hash = hashlib.sha256(text.encode()).hexdigest()

            existing = self._conn.execute(
                "SELECT text_hash, provider FROM embeddings WHERE qualified_name = ?",
                (node.qualified_name,),
            ).fetchone()

            # Re-embed if text changed OR provider changed
            if (existing and existing["text_hash"] == text_hash
                    and existing["provider"] == provider_name):
                continue
            pending.append((node, text, text_hash))

        if not pending:
            return 0

        vectors = self.provider.embed([text for _, text, _ in pending])

        # zip() stops at the shorter input, so count what is really written
        # instead of assuming one vector came back per text.
        stored = 0
        for (node, _text, text_hash), vec in zip(pending, vectors):
            self._conn.execute(
                "INSERT OR REPLACE INTO embeddings "
                "(qualified_name, vector, text_hash, provider) "
                "VALUES (?, ?, ?, ?)",
                (node.qualified_name, _encode_vector(vec), text_hash, provider_name),
            )
            stored += 1

        self._conn.commit()
        return stored

    def search(self, query: str, limit: int = 20) -> list[tuple[str, float]]:
        """Search for nodes by semantic similarity.

        Args:
            query: Text to embed and compare against stored vectors.
            limit: Maximum number of results.

        Returns:
            Up to ``limit`` ``(qualified_name, similarity)`` pairs, best
            match first. Empty when there is no provider or ``limit`` is not
            positive; in both cases the provider is not called.
        """
        if not self.provider or limit <= 0:
            return []

        provider_name = self.provider.name
        query_vec = self.provider.embed_query(query)

        # Process in chunks, only matching current provider
        scored: list[tuple[str, float]] = []
        cursor = self._conn.execute(
            "SELECT qualified_name, vector FROM embeddings WHERE provider = ?",
            (provider_name,),
        )
        chunk_size = 500
        while True:
            rows = cursor.fetchmany(chunk_size)
            if not rows:
                break
            for row in rows:
                vec = _decode_vector(row["vector"])
                sim = _cosine_similarity(query_vec, vec)
                scored.append((row["qualified_name"], sim))

        scored.sort(key=lambda x: x[1], reverse=True)
        return scored[:limit]

    def remove_node(self, qualified_name: str) -> None:
        """Delete the stored embedding for ``qualified_name``, if any."""
        self._conn.execute(
            "DELETE FROM embeddings WHERE qualified_name = ?", (qualified_name,)
        )
        self._conn.commit()

    def count(self) -> int:
        """Return the number of stored embeddings across all providers."""
        return self._conn.execute("SELECT COUNT(*) FROM embeddings").fetchone()[0]
|
|
971
|
+
|
|
972
|
+
|
|
973
|
+
def embed_all_nodes(graph_store: GraphStore, embedding_store: EmbeddingStore) -> int:
    """Collect every node of every file in the graph and embed them.

    Args:
        graph_store: Source of files (``get_all_files``) and their nodes
            (``get_nodes_by_file``).
        embedding_store: Store that receives the nodes via ``embed_nodes``.

    Returns:
        Whatever ``embedding_store.embed_nodes`` returns, or 0 — without
        touching the graph — when the store is not available.
    """
    if embedding_store.available:
        collected: list[GraphNode] = [
            node
            for file_path in graph_store.get_all_files()
            for node in graph_store.get_nodes_by_file(file_path)
        ]
        return embedding_store.embed_nodes(collected)
    return 0
|
|
984
|
+
|
|
985
|
+
|
|
986
|
+
def semantic_search(
    query: str,
    graph_store: GraphStore,
    embedding_store: EmbeddingStore,
    limit: int = 20,
) -> list[dict[str, Any]]:
    """Search nodes using vector similarity, falling back to keyword search.

    Vector search is tried when the embedding store is available and holds
    at least one row. ``count()`` covers rows from every provider while
    ``search()`` only scores rows of the current one, so the vector path can
    come back empty even though rows exist; hits can also reference nodes
    that are no longer in the graph. In either case the keyword search is
    used instead of returning nothing.

    Args:
        query: Search text.
        graph_store: Resolves qualified names (``get_node``) and performs
            the keyword fallback (``search_nodes``).
        embedding_store: Vector index to query.
        limit: Maximum number of results requested from either search.

    Returns:
        Node dicts. Results from the vector path additionally carry a
        ``similarity_score`` rounded to 4 decimal places.
    """
    if embedding_store.available and embedding_store.count() > 0:
        output = []
        for qn, score in embedding_store.search(query, limit=limit):
            node = graph_store.get_node(qn)
            if node:
                d = node_to_dict(node)
                d["similarity_score"] = round(score, 4)
                output.append(d)
        if output:
            return output
        # Nothing usable came back from the vector index: fall through.

    # Fallback to keyword search
    nodes = graph_store.search_nodes(query, limit=limit)
    return [node_to_dict(n) for n in nodes]
|