mcp-plesk-dev-docs 0.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,218 @@
1
+ """
2
Model profile configuration for mcp-plesk-dev-docs.

Profiles let you trade RAM/VRAM footprint against retrieval quality.
Select a profile via the PLESK_MODEL_PROFILE environment variable.

PLESK_MODEL_PROFILE=light    ~200 MB total (M2 MacBook Air, memory-constrained)
PLESK_MODEL_PROFILE=full     ~1.8 GB total (RTX 4070 Super, max quality)
PLESK_MODEL_PROFILE=medium   ~600 MB total (balanced middle ground)
PLESK_MODEL_PROFILE=full-tq  ~1.3 GB total (TurboQuant 4-bit index; the default)

You can also override individual components without changing the profile:
PLESK_EMBED_MODEL=BAAI/bge-base-en-v1.5
PLESK_EMBED_DIM=768  (set this whenever the overriding embed model's dimension differs)
PLESK_RERANKER_MODEL=cross-encoder/ms-marco-MiniLM-L-6-v2  (an empty value disables the reranker)
PLESK_RERANKER_ENABLED=false
15
+ """
16
+
17
+ from __future__ import annotations
18
+
19
+ import logging
20
+ from dataclasses import dataclass
21
+
22
+ from plesk_unified.platform_utils import get_optimal_device, get_platform_info
23
+ from plesk_unified.settings import settings
24
+
25
# Every module in the package logs through the shared "plesk_unified" logger.
logger: logging.Logger = logging.getLogger("plesk_unified")
26
+
27
+ # ---------------------------------------------------------------------------
28
+ # Profile definitions
29
+ # ---------------------------------------------------------------------------
30
+
31
+
32
@dataclass(frozen=True)
class ModelProfile:
    """Immutable bundle of model choices and retrieval knobs for one profile.

    Field order defines the positional constructor signature of this
    dataclass, so keep it stable when adding fields (append with defaults).
    """

    name: str  # profile key, e.g. "light" or "full-tq"
    embed_model: str  # embedding model identifier, e.g. "BAAI/bge-small-en-v1.5"
    embed_dim: int  # must match LanceDB vector column dimension
    reranker_model: str | None  # cross-encoder identifier; None means no reranker
    reranker_enabled: bool  # get_active_profile() also requires reranker_model to be set
    description: str  # human-readable summary, exposed by list_profiles()
    approx_ram_mb: int  # rough footprint; compared with free VRAM on CUDA hosts
    rerank_candidates: int = 35  # number of candidates handed to the reranker
    use_turboquant: bool = False  # selects the TurboQuant retrieval path (consumer not shown here)
    tq_bits: int = 5  # presumably the TurboQuant quantization bit width — confirm with the consumer
    tq_top_k: int = 25  # presumably the result count taken from the TurboQuant index — confirm
+
46
+
47
# Built-in profiles, keyed by the (lower-case) name accepted in PLESK_MODEL_PROFILE.
# Each embed_dim must equal its embed model's output size because it sizes the
# LanceDB vector column; switching models on an existing index needs a reindex.
_PROFILES: dict[str, ModelProfile] = {
    # Smallest footprint: 384-dim embeddings plus a small MiniLM cross-encoder.
    "light": ModelProfile(
        name="light",
        embed_model="BAAI/bge-small-en-v1.5",
        embed_dim=384,
        reranker_model="cross-encoder/ms-marco-MiniLM-L4-v2",
        reranker_enabled=True,
        description=(
            "~200 MB total. Ideal for M2 MacBook Air or any memory-constrained host."
        ),
        approx_ram_mb=200,
        rerank_candidates=35,
    ),
    # Larger 768-dim embedder, same reranker as "light".
    "medium": ModelProfile(
        name="medium",
        embed_model="BAAI/bge-base-en-v1.5",
        embed_dim=768,
        reranker_model="cross-encoder/ms-marco-MiniLM-L4-v2",
        reranker_enabled=True,
        description="~600 MB total. Good quality with moderate memory use.",
        approx_ram_mb=600,
        rerank_candidates=35,
    ),
    # bge-m3 (1024-dim) with the bge reranker; rerank_candidates uses the
    # dataclass default.
    "full": ModelProfile(
        name="full",
        embed_model="BAAI/bge-m3",
        embed_dim=1024,
        reranker_model="BAAI/bge-reranker-base",
        reranker_enabled=True,
        description=(
            "~1.8 GB total. Maximum quality. Recommended for RTX 4070 Super / CUDA."
        ),
        approx_ram_mb=1800,
    ),
    # Same models as "full" but with the TurboQuant 4-bit index enabled
    # (this is DEFAULT_PROFILE).
    "full-tq": ModelProfile(
        name="full-tq",
        embed_model="BAAI/bge-m3",
        embed_dim=1024,
        reranker_model="BAAI/bge-reranker-base",
        reranker_enabled=True,
        description=(
            "TurboQuant 4-bit profile with category-aware retrieval. "
            "Quality parity target with significantly lower latency."
        ),
        approx_ram_mb=1300,
        use_turboquant=True,
        tq_bits=4,
        tq_top_k=25,
    ),
}
97
+
98
# Profile used when PLESK_MODEL_PROFILE names an unknown profile. Keep this in
# sync with the default of `plesk_model_profile` in settings.py ("full-tq").
DEFAULT_PROFILE = "full-tq"
99
+
100
+
101
# ---------------------------------------------------------------------------
# Resolution logic (env vars -> profile -> individual overrides)
# ---------------------------------------------------------------------------
104
+
105
+
106
def get_active_profile() -> ModelProfile:
    """
    Resolve the active model profile from settings / environment variables.

    Priority (highest to lowest):
      1. Per-component overrides: PLESK_EMBED_MODEL (+ PLESK_EMBED_DIM),
         PLESK_RERANKER_MODEL (an empty string means "no reranker"),
         PLESK_RERANKER_ENABLED and PLESK_RERANK_CANDIDATES
      2. PLESK_MODEL_PROFILE (named profile; unknown names log a warning)
      3. Compiled-in default (DEFAULT_PROFILE, "full-tq")

    PLESK_EMBED_DIM is applied only when PLESK_EMBED_MODEL actually replaces
    the profile's embed model. If it is set without such a change and
    disagrees with the profile, a warning is logged and the profile's
    dimension is kept. The reranker counts as enabled only when a reranker
    model is configured. PLESK_RERANK_CANDIDATES wins whenever it is truthy,
    so the profile's ``rerank_candidates`` applies only when that setting is
    unset or 0 (settings.py gives it a non-None default of 50).

    On CUDA hosts a warning is logged when the profile's approximate
    footprint exceeds the currently free VRAM (including when 0 MB is free).

    Returns:
        A new ModelProfile with every override applied. The result is also
        logged at INFO level.
    """
    profile_name = settings.plesk_model_profile.lower().strip()
    if profile_name not in _PROFILES:
        logger.warning(
            "Unknown PLESK_MODEL_PROFILE=%r. Valid options: %s. Falling back to '%s'.",
            profile_name,
            ", ".join(_PROFILES),
            DEFAULT_PROFILE,
        )
        profile_name = DEFAULT_PROFILE

    base = _PROFILES[profile_name]

    # --- Per-component overrides on top of the profile ----------------------
    embed_model = (settings.plesk_embed_model or base.embed_model).strip()

    # An explicitly empty PLESK_RERANKER_MODEL means "no reranker" (None),
    # which differs from "not set" (inherit the profile's reranker).
    if settings.plesk_reranker_model is not None:
        reranker_model = settings.plesk_reranker_model.strip() or None
    else:
        reranker_model = base.reranker_model

    # Only an explicit True/False overrides; None (unset) inherits the profile.
    enabled_override = settings.plesk_reranker_enabled
    reranker_enabled = (
        enabled_override
        if isinstance(enabled_override, bool)
        else base.reranker_enabled
    )

    # --- Vector dimension ---------------------------------------------------
    # The dimension sizes the LanceDB vector column, so a changed embed model
    # needs a matching PLESK_EMBED_DIM; without one we keep the profile default
    # and warn.
    embed_dim = base.embed_dim
    if embed_model != base.embed_model:
        if settings.plesk_embed_dim:
            embed_dim = settings.plesk_embed_dim
        else:
            logger.warning(
                "PLESK_EMBED_MODEL overridden to %r but PLESK_EMBED_DIM not set. "
                "Using profile default dim=%d. Set PLESK_EMBED_DIM if the model "
                "uses a different dimension, then delete storage/lancedb and reindex.",
                embed_model,
                embed_dim,
            )
    elif settings.plesk_embed_dim and settings.plesk_embed_dim != base.embed_dim:
        # Surface a silent conflict: the profile keeps its own dimension here,
        # but other code may read PLESK_EMBED_DIM directly
        # (e.g. settings.embedding_model_dimensions).
        logger.warning(
            "PLESK_EMBED_DIM=%d is not applied because PLESK_EMBED_MODEL does not "
            "change the embed model (%r, dim=%d). Remove PLESK_EMBED_DIM or set "
            "PLESK_EMBED_MODEL to the model it belongs to.",
            settings.plesk_embed_dim,
            embed_model,
            base.embed_dim,
        )

    active = ModelProfile(
        name=profile_name,
        embed_model=embed_model,
        embed_dim=embed_dim,
        reranker_model=reranker_model,
        # A reranker cannot be enabled without a model to run.
        reranker_enabled=reranker_enabled and (reranker_model is not None),
        description=base.description,
        approx_ram_mb=base.approx_ram_mb,
        rerank_candidates=settings.plesk_rerank_candidates or base.rerank_candidates,
        use_turboquant=base.use_turboquant,
        tq_bits=base.tq_bits,
        tq_top_k=base.tq_top_k,
    )

    # --- VRAM auto-tuning check ---------------------------------------------
    device = get_optimal_device()
    if device == "cuda":
        info = get_platform_info()
        free_vram = info.get("vram_free_mb")
        # Compare against None explicitly: 0 MB free is the most urgent case
        # and must not be skipped as "falsy".
        if free_vram is not None and active.approx_ram_mb > free_vram:
            logger.warning(
                "VRAM Auto-Tune: Profile '%s' requires ~%d MB but only %d MB is free. "
                "Consider switching to a lighter profile (e.g., 'medium' or 'light') "
                "to avoid Out-Of-Memory (OOM) errors.",
                active.name,
                active.approx_ram_mb,
                free_vram,
            )

    logger.info(
        "Active model profile: %s | embed=%s (dim=%d) | reranker=%s "
        "(enabled=%s) | ~%d MB",
        active.name,
        active.embed_model,
        active.embed_dim,
        active.reranker_model,
        active.reranker_enabled,
        active.approx_ram_mb,
    )

    return active
201
+
202
+
203
def list_profiles() -> dict[str, dict]:
    """Return a JSON-serialisable summary of all built-in profiles.

    Keys are profile names. Each value mirrors the ModelProfile fields except
    the redundant ``name``, including ``rerank_candidates``, which earlier
    summaries omitted.
    """
    return {
        name: {
            "embed_model": p.embed_model,
            "embed_dim": p.embed_dim,
            "reranker_model": p.reranker_model,
            "reranker_enabled": p.reranker_enabled,
            "rerank_candidates": p.rerank_candidates,
            "approx_ram_mb": p.approx_ram_mb,
            "use_turboquant": p.use_turboquant,
            "tq_bits": p.tq_bits,
            "tq_top_k": p.tq_top_k,
            "description": p.description,
        }
        for name, p in _PROFILES.items()
    }
@@ -0,0 +1,214 @@
1
+ """
2
+ Platform detection and GPU configuration utilities.
3
+
4
+ Provides cross-platform support for:
5
+ - Windows: CUDA GPU acceleration (if available)
6
+ - macOS: Apple Silicon MPS acceleration (M1/M2/M3) or CPU fallback
7
+ - Linux: CUDA or CPU fallback
8
+ """
9
+
10
+ import logging
11
+ import os
12
+ import platform
13
+ import sys
14
+ from typing import Any, Optional
15
+
16
# Cached torch module. It stays None until _get_torch() first imports it, so
# the heavy import is deferred until a device query actually needs it.
_torch: Optional[Any] = None

# Use the application-wide "plesk_unified" logger (configured in server.py).
logger = logging.getLogger("plesk_unified")
22
+
23
+
24
def _get_torch() -> Any:
    """Return the ``torch`` module, importing it on first use.

    A successful import is memoised in the module-level ``_torch``. A failed
    import logs a warning and re-raises ``ImportError``. Nothing is cached in
    that case, so the next call tries again.
    """
    global _torch
    if _torch is not None:
        return _torch
    try:
        import torch as torch_module
    except ImportError:
        logger.warning(
            "PyTorch import failed. GPU acceleration will be unavailable."
        )
        raise
    _torch = torch_module
    return _torch
38
+
39
+
40
def get_platform_info() -> dict:
    """
    Return a dictionary with detailed platform information.

    Always contains ``system``, ``machine``, ``python_version`` and
    ``python_executable``. If torch can be imported, it adds:

    - ``torch_available`` (True), ``torch_version``, ``cuda_available`` and
      ``mps_available``.
    - With CUDA, ``cuda_version`` and ``gpu_count``.
    - When at least one GPU is visible, ``gpu_name`` (device 0) plus
      best-effort ``vram_free_mb`` / ``vram_total_mb`` for the current CUDA
      device.

    A missing torch sets ``torch_available`` to False. Any other failure while
    probing torch is recorded under ``torch_error`` and logged at DEBUG.
    """
    info: dict[str, Any] = {
        "system": platform.system(),
        "machine": platform.machine(),
        "python_version": platform.python_version(),
        "python_executable": sys.executable,
    }

    try:
        torch = _get_torch()
        info["torch_available"] = True
        info["torch_version"] = str(torch.__version__)

        cuda_available = torch.cuda.is_available()
        info["cuda_available"] = cuda_available
        # MPS only exists on macOS builds of torch.
        info["mps_available"] = (
            torch.backends.mps.is_available()
            if platform.system() == "Darwin" and hasattr(torch.backends, "mps")
            else False
        )

        if cuda_available:
            info["cuda_version"] = str(torch.version.cuda)
            gpu_count = torch.cuda.device_count()
            info["gpu_count"] = gpu_count
            if gpu_count > 0:
                info["gpu_name"] = str(torch.cuda.get_device_name(0))
                try:
                    free_vram, total_vram = torch.cuda.mem_get_info()
                except Exception as exc:
                    # VRAM figures are optional extras; record why they are
                    # missing instead of failing silently.
                    logger.debug("Could not query CUDA memory info: %s", exc)
                else:
                    info["vram_free_mb"] = free_vram // (1024**2)
                    info["vram_total_mb"] = total_vram // (1024**2)

    except ImportError:
        info["torch_available"] = False
    except Exception as e:
        info["torch_error"] = str(e)
        logger.debug("Error gathering detailed platform info: %s", e)

    return info
82
+
83
+
84
def get_optimal_device() -> str:
    """
    Determine the optimal compute device: "cuda", "mps" or "cpu".

    Priority: FORCE_DEVICE environment variable -> MPS (macOS) -> CUDA
    (Windows/Linux) -> CPU.

    A recognised FORCE_DEVICE value ("cuda", "mps", "cpu", case-insensitive)
    is returned as-is, without checking that the hardware exists. Any other
    non-empty value is reported with a warning and then ignored. Errors while
    probing torch are logged at DEBUG and fall back to "cpu".
    """
    forced_device = os.environ.get("FORCE_DEVICE", "").lower().strip()
    if forced_device in ("cuda", "mps", "cpu"):
        logger.info("Device forced via env var: %s", forced_device)
        return forced_device
    if forced_device:
        # Previously a typo such as FORCE_DEVICE=gpu was dropped silently.
        logger.warning(
            "Ignoring unrecognised FORCE_DEVICE=%r (expected 'cuda', 'mps' or 'cpu').",
            forced_device,
        )

    system = platform.system()

    # macOS: Apple Silicon MPS, otherwise CPU (CUDA is not probed on macOS).
    if system == "Darwin":
        try:
            torch = _get_torch()
            if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
                return "mps"
        except Exception as exc:
            logger.debug("MPS probe failed, using CPU: %s", exc)
        return "cpu"

    # Windows/Linux: CUDA when available.
    if system in ("Windows", "Linux"):
        try:
            torch = _get_torch()
            if torch.cuda.is_available():
                return "cuda"
        except Exception as exc:
            logger.debug("CUDA probe failed, using CPU: %s", exc)

    return "cpu"
117
+
118
+
119
def log_hardware_degradation(
    failed_device: str, reason: Exception | str, fallback: str = "cpu"
) -> None:
    """
    Log a WARNING that a hardware accelerator failed and work falls back.

    Args:
        failed_device: Accelerator that failed to initialise (e.g. "cuda",
            "mps"); shown upper-cased.
        reason: Exception or message describing the failure; shown via str().
        fallback: Device used instead; shown upper-cased. Defaults to "cpu".
    """
    device_label = failed_device.upper()
    reason_text = str(reason)
    fallback_label = fallback.upper()
    logger.warning(
        "Hardware acceleration failure: %s initialization failed (%s). "
        "Falling back to %s. Performance may be degraded.",
        device_label,
        reason_text,
        fallback_label,
    )
133
+
134
+
135
def get_device_config() -> dict:
    """
    Build the device/precision settings used for model initialization.

    The result always has ``device`` (from get_optimal_device()) and
    ``precision``. On MPS, and on CUDA when torch confirms CUDA is available,
    precision becomes "float16" and a matching ``torch_dtype`` key is added.
    Otherwise precision stays "float32" and ``torch_dtype`` is absent. Errors
    while probing torch are logged at DEBUG and keep the float32 defaults.
    """
    device = get_optimal_device()
    config: dict[str, Any] = {"device": device, "precision": "float32"}

    try:
        use_half_precision = device == "mps" or (
            device == "cuda" and _get_torch().cuda.is_available()
        )
    except Exception as exc:
        logger.debug("Failed to set precision config: %s", exc)
        use_half_precision = False

    if use_half_precision:
        config.update(precision="float16", torch_dtype="float16")
    return config
158
+
159
+
160
def log_platform_info() -> None:
    """Write a one-line platform/device summary to the shared logger at INFO."""
    info = get_platform_info()
    device = get_optimal_device()

    os_part = f"OS: {info.get('system')} {info.get('machine')}"
    python_part = f"Python: {info.get('python_version')}"
    device_part = f"Device: {device.upper()}"
    parts = [os_part, python_part, device_part]

    # CUDA details take precedence; MPS is reported only when CUDA is absent.
    if info.get("cuda_available"):
        gpu_name = info.get("gpu_name", "Unknown")
        parts.append(f"GPU: {gpu_name} (CUDA {info.get('cuda_version')})")
    elif info.get("mps_available"):
        parts.append("GPU: Apple Silicon (MPS)")

    logger.info("Platform Check: %s", " | ".join(parts))
180
+
181
+
182
def is_apple_silicon() -> bool:
    """Return True when running on an arm64 Mac (Apple M1/M2/M3)."""
    if platform.system() != "Darwin":
        return False
    return platform.machine() == "arm64"
185
+
186
+
187
def is_windows() -> bool:
    """Return True when the host OS reports itself as Windows."""
    system_name = platform.system()
    return system_name == "Windows"
190
+
191
+
192
def is_linux() -> bool:
    """Return True when the host OS reports itself as Linux."""
    system_name = platform.system()
    return system_name == "Linux"
195
+
196
+
197
def is_macos() -> bool:
    """Return True when the host OS is macOS (platform name "Darwin")."""
    system_name = platform.system()
    return system_name == "Darwin"
200
+
201
+
202
def get_model_cache_dir() -> str:
    """
    Get the appropriate model cache directory for the platform.

    Explicit Hugging Face overrides take precedence, in the same order the
    huggingface_hub library uses:

    1. HF_HUB_CACHE
    2. The legacy HUGGINGFACE_HUB_CACHE
    3. HF_HOME with "hub" appended

    Without any of them, a platform default is returned:

    - Windows: <LOCALAPPDATA>/huggingface/hub (the temp dir stands in for
      LOCALAPPDATA when it is unset)
    - macOS: ~/Library/Caches/huggingface/hub
    - Other: ~/.cache/huggingface/hub
    """
    import tempfile

    # Honour the user's explicit cache configuration before platform defaults.
    for var_name in ("HF_HUB_CACHE", "HUGGINGFACE_HUB_CACHE"):
        hub_cache = os.environ.get(var_name)
        if hub_cache:
            return os.path.expanduser(hub_cache)
    hf_home = os.environ.get("HF_HOME")
    if hf_home:
        return os.path.join(os.path.expanduser(hf_home), "hub")

    if is_windows():
        base = os.environ.get("LOCALAPPDATA", tempfile.gettempdir())
        return os.path.join(base, "huggingface", "hub")
    elif is_macos():
        return os.path.expanduser("~/Library/Caches/huggingface/hub")
    else:
        return os.path.expanduser("~/.cache/huggingface/hub")
@@ -0,0 +1,93 @@
1
+ import os
2
+ from pathlib import Path
3
+ from typing import Literal, Optional
4
+
5
+ from pydantic_settings import BaseSettings, SettingsConfigDict
6
+
7
+
8
class PleskSettings(BaseSettings):
    """
    Configuration settings for mcp-plesk-dev-docs.

    Fields map to environment variables (e.g., log_level maps to LOG_LEVEL)
    and may also be read from an env file (``.env`` by default). Variables
    that match no field are ignored (``extra="ignore"``).
    """

    model_config = SettingsConfigDict(
        env_file=".env",
        env_file_encoding="utf-8",
        extra="ignore",
    )

    # Logging
    log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO"
    log_file: Optional[str] = None  # explicit log path; see effective_log_file
    log_handler: Literal["os", "file", "both"] = "os"

    # Model Profile & Overrides (resolved by model_config.get_active_profile)
    plesk_model_profile: str = "full-tq"
    plesk_embed_model: Optional[str] = None
    plesk_reranker_model: Optional[str] = None  # "" means "no reranker"
    plesk_reranker_enabled: Optional[bool] = None  # None = use the profile's value
    plesk_embed_dim: Optional[int] = None

    # Operational Behaviors
    plesk_daemon_auto_warmup: bool = False
    plesk_auto_refresh_on_startup: bool = True
    plesk_index_summaries: bool = False
    plesk_enable_fts: bool = True
    plesk_enable_ast_chunking: bool = False
    plesk_enable_sampling: bool = False
    # Default number of candidates to send to the reranker when not overridden.
    # Historically this default was 50; keep that default for backward compatibility.
    # Because it is non-None, get_active_profile() normally uses this value
    # instead of the profile's own rerank_candidates.
    plesk_rerank_candidates: Optional[int] = 50
    plesk_min_relevance_threshold: Optional[float] = None

    # External APIs & Hardware
    openrouter_api_key: str = ""
    # NOTE(review): platform_utils.get_optimal_device() reads FORCE_DEVICE from
    # os.environ directly, so a value that only appears in .env is not seen
    # there. Confirm where this field is consumed.
    force_device: Optional[str] = None
    plesk_html_llm_table_normalize: bool = False

    # Third-party library silencing
    tqdm_disable: bool = True
    transformers_verbosity: str = "error"

    @property
    def effective_log_file(self) -> str:
        """Resolve the log file path.

        An explicit ``log_file`` (LOG_FILE) is returned unchanged, and its
        parent directory is NOT created here. Otherwise the default
        ``storage/logs/plesk_unified.log`` is used, located two levels above
        this module's file. Its ``storage/logs`` directory is created when
        missing, as a side effect of reading this property.
        """
        if self.log_file:
            return self.log_file
        base_dir = Path(__file__).parent.parent
        log_dir = base_dir / "storage" / "logs"
        log_dir.mkdir(parents=True, exist_ok=True)
        return str(log_dir / "plesk_unified.log")

    @property
    def embedding_model_dimensions(self) -> int:
        """Return embedding vector dimension: explicit override or profile default.

        Priority: `plesk_embed_dim` (explicit env override) -> profile default.
        Lazy-imports model_config to avoid expensive imports at module load time.

        Note: get_active_profile() applies PLESK_EMBED_DIM only when
        PLESK_EMBED_MODEL changes the model, so this value can differ from
        get_active_profile().embed_dim. The fallback path calls
        get_active_profile(), which probes the device and logs on every access.
        """
        if self.plesk_embed_dim:
            return self.plesk_embed_dim
        # Lazy import so we don't cause circular imports at module import time.
        from plesk_unified.model_config import get_active_profile

        return get_active_profile().embed_dim
77
+
78
+
79
# PLESK_ENV_FILE selects the env file to load. Tests may point it at another
# file, or set it to "" to skip .env loading entirely (mapped to None below).
_env_file = os.environ.get("PLESK_ENV_FILE", ".env") or None

# Replace the class-level model_config before the singleton is built. env_file
# is set even when it is None, because None is what disables .env loading.
PleskSettings.model_config = SettingsConfigDict(
    env_file=_env_file,
    env_file_encoding="utf-8",
    extra="ignore",
)

# Process-wide settings instance imported by the rest of the package.
settings = PleskSettings()
@@ -0,0 +1,55 @@
1
+ import json
2
+ import logging
3
+ import hashlib
4
+ from pathlib import Path
5
+ from typing import Dict, Optional
6
+
7
logger = logging.getLogger(__name__)

# Persisted summary cache. The path is relative, so it resolves against the
# process's current working directory.
CACHE_FILE = Path("storage") / "summaries_cache.json"
10
+
11
+
12
class SummaryCache:
    """A persistent cache for AI-generated summaries indexed by file content hash.

    Summaries are keyed by the MD5 hex digest of the source file's bytes, so
    an edited file misses the cache and byte-identical files share one
    entry. The cache lives in memory and is written to ``CACHE_FILE`` only
    when :meth:`save` is called.
    """

    def __init__(self):
        self.cache: Dict[str, str] = {}
        self._load()

    def _get_file_hash(self, file_path: Path) -> str:
        """Return the MD5 hex digest of the file's content.

        MD5 serves only as a change-detection fingerprint, not for security.
        ``usedforsecurity=False`` keeps it available on FIPS-restricted
        Python builds.
        """
        hasher = hashlib.md5(usedforsecurity=False)
        with open(file_path, "rb") as f:
            while chunk := f.read(65536):
                hasher.update(chunk)
        return hasher.hexdigest()

    def _load(self):
        """Load the cache from disk.

        A missing file leaves the cache empty. An unreadable file, invalid
        JSON, or JSON that is not an object is logged as an error and also
        yields an empty cache.
        """
        if not CACHE_FILE.exists():
            return
        try:
            with open(CACHE_FILE, "r", encoding="utf-8") as f:
                data = json.load(f)
        except Exception as e:
            logger.error("Failed to load summary cache: %s", e)
            self.cache = {}
            return
        if not isinstance(data, dict):
            # A list/str/number would make later .get()/item assignment fail.
            logger.error(
                "Failed to load summary cache: expected a JSON object, got %s",
                type(data).__name__,
            )
            self.cache = {}
            return
        self.cache = data
        logger.info("Loaded %d summaries from cache.", len(self.cache))

    def save(self):
        """Write the cache to ``CACHE_FILE`` atomically; errors are logged, not raised.

        The JSON goes to a sibling ``.tmp`` file first and is then moved over
        the real file, so a crash mid-write cannot leave a truncated cache.
        """
        tmp_file = CACHE_FILE.with_name(CACHE_FILE.name + ".tmp")
        try:
            CACHE_FILE.parent.mkdir(parents=True, exist_ok=True)
            with open(tmp_file, "w", encoding="utf-8") as f:
                json.dump(self.cache, f, indent=2, ensure_ascii=False)
            tmp_file.replace(CACHE_FILE)
        except Exception as e:
            logger.error("Failed to save summary cache: %s", e)
            try:
                tmp_file.unlink(missing_ok=True)
            except OSError:
                pass  # best-effort cleanup of the partial temp file

    def get(self, file_path: Path) -> Optional[str]:
        """Return the cached summary for the file's current content, or None.

        Reads and hashes the file, so a missing file raises like ``open``.
        """
        file_hash = self._get_file_hash(file_path)
        return self.cache.get(file_hash)

    def set(self, file_path: Path, summary: str):
        """Store ``summary`` under the hash of the file's current content (in memory)."""
        file_hash = self._get_file_hash(file_path)
        self.cache[file_hash] = summary
@@ -0,0 +1,85 @@
1
+ """TurboQuantIndex loading strategy."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import numpy as np
6
+ import torch
7
+
8
+ from plesk_unified.turboquant import TurboQuantProd
9
+
10
+
11
class TurboQuantIndex:
    """In-memory vector index over TurboQuant-compressed embeddings.

    ``add`` L2-normalises vectors, compresses them with ``TurboQuantProd``,
    and appends the compressed tensors (kept in CPU memory) plus one metadata
    dict per row. ``search`` scores a normalised query against all rows, or
    only the rows of one ``category``, using the quantizer's inner-product
    estimate. It returns the best ``top_k`` ``(metadata, score)`` pairs.
    """

    def __init__(self, dim: int, bits: int = 3, device: str = "cpu"):
        self.dim = dim
        self.bits = bits
        self.device = device
        self.quantizer = TurboQuantProd(dim, bits, device=device)

        # Compressed rows: a dict of batched CPU tensors, one entry per key
        # returned by quantizer.quantize(). None until the first add().
        self.compressed_db = None
        # Row-aligned metadata: self._meta[i] describes row i of every tensor.
        self._meta = []
        # category string -> row indices that carry that category.
        self._category_to_indices: dict[str, list[int]] = {}

    def add(self, vecs: np.ndarray, metas: list[dict]) -> None:
        """Compress and append a batch of vectors with their metadata.

        Args:
            vecs: Array-like of shape ``(n, dim)``; rows need not be normalised.
            metas: ``n`` metadata dicts, one per row. A non-empty string under
                ``"category"`` makes the row searchable by that category.

        Raises:
            ValueError: if ``vecs`` is not ``(len(metas), dim)``. The check runs
                before anything is stored, so a bad batch cannot misalign the
                metadata with the compressed rows.
        """
        vecs = np.asarray(vecs)
        if not metas and vecs.size == 0:
            return  # empty batch: nothing to index
        if vecs.ndim != 2 or vecs.shape[0] != len(metas) or vecs.shape[1] != self.dim:
            raise ValueError(
                f"expected vecs of shape ({len(metas)}, {self.dim}) to match "
                f"{len(metas)} metadata entries, got shape {vecs.shape}"
            )

        # 1. L2-normalise the input vectors (critical for TurboQuant accuracy).
        norms = np.linalg.norm(vecs, axis=-1, keepdims=True)
        vecs_normalized = vecs / np.maximum(norms, 1e-12)

        # 2. Quantize the entire batch at once.
        x = torch.from_numpy(vecs_normalized).to(self.device)
        compressed = self.quantizer.quantize(x)

        # 3. Store or append to the batched tensors in CPU memory.
        if self.compressed_db is None:
            self.compressed_db = {k: v.cpu() for k, v in compressed.items()}
        else:
            for k in self.compressed_db:
                self.compressed_db[k] = torch.cat(
                    [self.compressed_db[k], compressed[k].cpu()], dim=0
                )

        start_idx = len(self._meta)
        self._meta.extend(metas)
        for offset, meta in enumerate(metas):
            category = meta.get("category")
            if isinstance(category, str) and category:
                self._category_to_indices.setdefault(category, []).append(
                    start_idx + offset
                )

    def search(
        self, query_vec: np.ndarray, top_k: int = 25, category: str | None = None
    ):
        """Return up to ``top_k`` ``(metadata, score)`` pairs, best first.

        Args:
            query_vec: Query embedding. It is presumably 1-D of length ``dim``:
                it is normalised as a whole and given a batch axis of size 1.
            top_k: Maximum number of results; ``top_k <= 0`` yields ``[]``.
            category: If truthy, only rows added with this category are
                scored, and an unknown category yields ``[]``.
        """
        # Non-positive top_k would otherwise slice as [:-k] and return almost
        # everything.
        if self.compressed_db is None or top_k <= 0:
            return []

        # Candidate row numbers. None means "every row", which avoids copying
        # the whole database through index_select on unfiltered searches.
        row_ids: list[int] | None = None
        if category:
            row_ids = self._category_to_indices.get(category, [])
            if not row_ids:
                return []

        # 1. L2-normalise the query.
        norm = np.linalg.norm(query_vec)
        query_normalized = query_vec / max(norm, 1e-12)

        # 2. Prepare the query as a batched tensor (1, dim).
        q = torch.from_numpy(query_normalized).to(self.device).unsqueeze(0)

        # 3. Gather the candidate rows (when filtering) and move them to device.
        if row_ids is None:
            db_on_device = {
                k: v.to(self.device) for k, v in self.compressed_db.items()
            }
        else:
            selected_tensor = torch.as_tensor(row_ids, dtype=torch.long)
            db_on_device = {
                k: v.index_select(0, selected_tensor).to(self.device)
                for k, v in self.compressed_db.items()
            }

        # 4. A single batched inner-product estimate for all candidates.
        with torch.no_grad():
            scores = self.quantizer.inner_product(q, db_on_device).squeeze(0)

        # 5. Rank, then map candidate positions back to stored metadata.
        scores_np = scores.cpu().numpy()
        order = np.argsort(-scores_np)[:top_k]
        if row_ids is None:
            return [(self._meta[i], float(scores_np[i])) for i in order]
        return [(self._meta[row_ids[i]], float(scores_np[i])) for i in order]
@@ -0,0 +1,21 @@
1
+ """TurboQuant helpers used by the unified retrieval path.
2
+
3
+ Base implementation: https://github.com/tonbistudio/turboquant-pytorch
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ from .compressors import TurboQuantCompressorMSE, TurboQuantCompressorV2
9
+ from .lloyd_max import LloydMaxCodebook, compute_expected_distortion, solve_lloyd_max
10
+ from .turboquant import TurboQuantKVCache, TurboQuantMSE, TurboQuantProd
11
+
12
# Public API of the turboquant package: the names re-exported from the
# compressors, lloyd_max and turboquant submodules imported above.
__all__ = [
    "TurboQuantCompressorMSE",
    "TurboQuantCompressorV2",
    "TurboQuantKVCache",
    "TurboQuantMSE",
    "TurboQuantProd",
    "LloydMaxCodebook",
    "compute_expected_distortion",
    "solve_lloyd_max",
]
+ ]