routing-memory 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rm/__init__.py +28 -0
- rm/codebook.py +337 -0
- rm/drift.py +145 -0
- rm/embeddings/__init__.py +4 -0
- rm/embeddings/base.py +22 -0
- rm/embeddings/local.py +56 -0
- rm/filtering.py +15 -0
- rm/memory.py +351 -0
- rm/py.typed +0 -0
- rm/retrieval.py +90 -0
- rm/storage/__init__.py +4 -0
- rm/storage/base.py +46 -0
- rm/storage/sqlite.py +122 -0
- routing_memory-0.1.0.dist-info/METADATA +212 -0
- routing_memory-0.1.0.dist-info/RECORD +18 -0
- routing_memory-0.1.0.dist-info/WHEEL +5 -0
- routing_memory-0.1.0.dist-info/licenses/LICENSE +21 -0
- routing_memory-0.1.0.dist-info/top_level.txt +1 -0
rm/__init__.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
"""Routing Memory — Lightweight long-term memory via vector-quantized routing."""
|
|
2
|
+
|
|
3
|
+
from .memory import RoutingMemory
|
|
4
|
+
from .codebook import Codebook, CentroidBucket
|
|
5
|
+
from .retrieval import L1Retriever, L1Result, RetrievalResult
|
|
6
|
+
from .filtering import filter_by_score, filter_top_n
|
|
7
|
+
from .drift import DriftMonitor, DriftAlarm
|
|
8
|
+
from .embeddings import LocalEmbeddings, EmbeddingBackend
|
|
9
|
+
from .storage import RMSQLiteBackend, RMStorageBackend
|
|
10
|
+
|
|
11
|
+
__all__ = [
|
|
12
|
+
"RoutingMemory",
|
|
13
|
+
"Codebook",
|
|
14
|
+
"CentroidBucket",
|
|
15
|
+
"L1Retriever",
|
|
16
|
+
"L1Result",
|
|
17
|
+
"RetrievalResult",
|
|
18
|
+
"filter_by_score",
|
|
19
|
+
"filter_top_n",
|
|
20
|
+
"DriftMonitor",
|
|
21
|
+
"DriftAlarm",
|
|
22
|
+
"LocalEmbeddings",
|
|
23
|
+
"EmbeddingBackend",
|
|
24
|
+
"RMSQLiteBackend",
|
|
25
|
+
"RMStorageBackend",
|
|
26
|
+
]
|
|
27
|
+
|
|
28
|
+
__version__ = "0.1.0"
|
rm/codebook.py
ADDED
|
@@ -0,0 +1,337 @@
|
|
|
1
|
+
"""VQ Codebook: MiniBatchKMeans with adaptive K + online adaptation.
|
|
2
|
+
|
|
3
|
+
Features:
|
|
4
|
+
- Adaptive K: K = ceil(N / B_target)
|
|
5
|
+
- EMA centroid update: centroids drift toward recent data
|
|
6
|
+
- Split: high-variance buckets split into 2
|
|
7
|
+
- Prune: idle centroids get reassigned and removed
|
|
8
|
+
- K range guard: K stays in [K_MIN, K_MAX]
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
import logging
|
|
12
|
+
import time
|
|
13
|
+
import numpy as np
|
|
14
|
+
from typing import Dict, List, Optional, Tuple
|
|
15
|
+
from dataclasses import dataclass, field
|
|
16
|
+
|
|
17
|
+
logger = logging.getLogger("rm.codebook")


@dataclass
class CentroidBucket:
    """Items assigned to one centroid, plus access bookkeeping."""

    # IDs of the items routed to this centroid (parallel to ``embeddings``).
    item_ids: List[str] = field(default_factory=list)
    # Embedding of each item, in the same order as ``item_ids``.
    embeddings: List[np.ndarray] = field(default_factory=list)
    # ``time.time()`` at creation, refreshed by ``Codebook.ema_update``.
    last_accessed: float = field(default_factory=time.time)
    # Number of EMA updates applied through this bucket's centroid.
    access_count: int = 0


class Codebook:
    """Vector Quantization codebook using MiniBatchKMeans.

    Adaptive K: K = ceil(N / B_target) where B_target is items per bucket,
    clamped to [K_MIN, K_MAX]. Centroids are stored L2-normalized and
    similarity is a plain dot product, so quantization error is
    ``1 - dot(x, centroid)``.
    NOTE(review): this equals cosine distance only if callers pass
    unit-norm embeddings -- confirm against the embedding backend.
    """

    B_TARGET = 20  # target items per bucket
    K_MIN = 2
    K_MAX = 512

    # --- Online adaptation parameters ---
    EMA_ETA = 0.01  # EMA learning rate
    SPLIT_VARIANCE_THRESHOLD = 0.15  # bucket variance above this -> split
    PRUNE_IDLE_SECONDS = 3600  # 1 hour idle -> prune candidate
    UPDATE_INTERVAL = 50  # run adaptation every N queries

    def __init__(self, dim: int = 384, seed: int = 42):
        self.dim = dim
        self.seed = seed
        self.centroids: Optional[np.ndarray] = None  # (K, dim)
        self.buckets: Dict[int, CentroidBucket] = {}
        self._n_items = 0
        self._fitted = False

    @property
    def K(self) -> int:
        """Current number of centroids (0 before the first fit)."""
        if self.centroids is None:
            return 0
        return self.centroids.shape[0]

    @property
    def fitted(self) -> bool:
        """True once ``fit`` has produced a centroid matrix."""
        return self._fitted and self.centroids is not None

    def _adaptive_k(self, n: int) -> int:
        """Return ceil(n / B_TARGET) clamped to [K_MIN, K_MAX]."""
        k = max(self.K_MIN, int(np.ceil(n / self.B_TARGET)))
        return min(k, self.K_MAX)

    def fit(self, embeddings: np.ndarray, item_ids: List[str]) -> None:
        """Fit codebook from scratch using MiniBatchKMeans.

        Args:
            embeddings: matrix with one row per item.
            item_ids: one ID per row of ``embeddings``.

        Raises:
            ValueError: if ``item_ids`` and ``embeddings`` differ in length.
        """
        n = embeddings.shape[0]
        if n == 0:
            return
        if len(item_ids) != n:
            raise ValueError(
                f"item_ids has {len(item_ids)} entries but embeddings has {n} rows"
            )

        # Can't have more centroids than items.
        k = min(self._adaptive_k(n), n)

        # Imported lazily so the module can be imported without scikit-learn.
        from sklearn.cluster import MiniBatchKMeans
        kmeans = MiniBatchKMeans(
            n_clusters=k,
            random_state=self.seed,
            batch_size=min(256, n),
            n_init=3,
        )
        labels = kmeans.fit_predict(embeddings)

        # Normalize centroids so dot products behave like cosine similarity.
        norms = np.linalg.norm(kmeans.cluster_centers_, axis=1, keepdims=True)
        norms = np.maximum(norms, 1e-10)
        self.centroids = (kmeans.cluster_centers_ / norms).astype(np.float32)

        # Build buckets from the k-means labels.
        self.buckets = {}
        for idx, label in enumerate(labels):
            bucket = self.buckets.setdefault(int(label), CentroidBucket())
            bucket.item_ids.append(item_ids[idx])
            bucket.embeddings.append(embeddings[idx])

        self._n_items = n
        self._fitted = True

    def encode(self, x: np.ndarray) -> Tuple[int, float]:
        """Find nearest centroid. Returns (centroid_id, quantization_error).

        Raises:
            RuntimeError: if the codebook has not been fitted yet.
        """
        if not self.fitted:
            raise RuntimeError("Codebook not fitted. Call fit() or add items first.")
        dots = x @ self.centroids.T  # (K,)
        best = int(np.argmax(dots))
        qerr = 1.0 - float(dots[best])
        return best, qerr

    def conf(self, x: np.ndarray) -> float:
        """Confidence: 1 / (1 + qerr)."""
        _, qerr = self.encode(x)
        return 1.0 / (1.0 + qerr)

    def margin(self, x: np.ndarray) -> float:
        """Gap between top-2 centroid similarities (0.0 if fewer than 2)."""
        if not self.fitted or self.K < 2:
            return 0.0
        dots = x @ self.centroids.T
        top2 = np.partition(dots, -2)[-2:]
        return float(top2.max() - top2.min())

    def top_centroids(self, x: np.ndarray, n: int = 3) -> List[int]:
        """Return up to ``n`` centroid IDs, most similar first.

        Returns an empty list when the codebook is unfitted or ``n <= 0``.
        """
        if not self.fitted:
            return []
        n = min(n, self.K)
        # Guard explicitly: slicing with ``[-0:]`` would select *every* centroid.
        if n <= 0:
            return []
        dots = x @ self.centroids.T
        ranked = np.argsort(dots)[::-1][:n]
        return [int(cid) for cid in ranked]

    def _locate(self, item_id: str, default: int) -> int:
        """Return the ID of the first bucket holding ``item_id``, else ``default``."""
        for cid, bucket in self.buckets.items():
            if item_id in bucket.item_ids:
                return cid
        return default

    def add(self, item_id: str, embedding: np.ndarray) -> int:
        """Assign item to nearest centroid. Returns centroid ID.

        Before the first fit, items are parked in bucket 0 until K_MIN items
        exist, at which point the codebook is fitted. If adding the item
        triggers a refit, the returned ID is looked up after the refit so it
        refers to the bucket that actually holds the item.
        """
        if not self.fitted:
            # Bootstrap: collect items and fit when we have enough.
            staging = self.buckets.setdefault(0, CentroidBucket())
            staging.item_ids.append(item_id)
            staging.embeddings.append(embedding)
            self._n_items += 1

            if self._n_items >= self.K_MIN:
                self._refit()
                return self._locate(item_id, default=0)
            return 0

        cid, _ = self.encode(embedding)
        bucket = self.buckets.setdefault(cid, CentroidBucket())
        bucket.item_ids.append(item_id)
        bucket.embeddings.append(embedding)
        self._n_items += 1

        # Every 100 items, refit if the ideal K has outgrown the current K.
        if self._n_items % 100 == 0:
            ideal_k = self._adaptive_k(self._n_items)
            if ideal_k > self.K * 1.5:
                self._refit()
                # Refit rebuilds all buckets, so the pre-refit ID is stale.
                cid = self._locate(item_id, default=cid)

        return cid

    def _refit(self) -> None:
        """Refit codebook from all stored embeddings (needs >= K_MIN items)."""
        all_ids: List[str] = []
        all_embs: List[np.ndarray] = []
        for bucket in self.buckets.values():
            all_ids.extend(bucket.item_ids)
            all_embs.extend(bucket.embeddings)
        if len(all_embs) >= self.K_MIN:
            emb_matrix = np.array(all_embs, dtype=np.float32)
            self.fit(emb_matrix, all_ids)

    def get_bucket_items(self, centroid_id: int) -> CentroidBucket:
        """Return the bucket for ``centroid_id`` (a fresh empty one if absent)."""
        return self.buckets.get(centroid_id, CentroidBucket())

    # --- Online Adaptation ---

    def ema_update(self, centroid_id: int, new_embedding: np.ndarray,
                   eta: Optional[float] = None) -> None:
        """EMA update: centroid drifts toward new data point.

        Args:
            centroid_id: centroid to move; out-of-range IDs are ignored.
            new_embedding: vector the centroid is pulled toward.
            eta: learning rate; ``None`` selects ``EMA_ETA``. An explicit
                ``0.0`` is honoured (centroid stays put, access is tracked).
        """
        if not self.fitted or not 0 <= centroid_id < self.K:
            return
        if eta is None:
            eta = self.EMA_ETA
        old = self.centroids[centroid_id]
        updated = (1 - eta) * old + eta * new_embedding
        # Renormalize so the centroid stays unit length.
        norm = np.linalg.norm(updated)
        if norm > 1e-10:
            updated = updated / norm
        self.centroids[centroid_id] = updated.astype(np.float32)

        # Track access (consulted by maybe_prune).
        bucket = self.buckets.get(centroid_id)
        if bucket is not None:
            bucket.last_accessed = time.time()
            bucket.access_count += 1

    def bucket_variance(self, centroid_id: int) -> float:
        """Mean squared ``1 - dot(item, centroid)`` over the bucket's items.

        Returns 0.0 for missing buckets, buckets with fewer than 2 items, or
        IDs without a centroid.
        """
        bucket = self.buckets.get(centroid_id)
        if bucket is None or len(bucket.embeddings) < 2:
            return 0.0
        if self.centroids is None or not 0 <= centroid_id < self.K:
            return 0.0
        centroid = self.centroids[centroid_id]
        embs = np.array(bucket.embeddings, dtype=np.float32)
        dists = 1.0 - embs @ centroid
        return float(np.mean(dists ** 2))

    def maybe_split(self, centroid_id: int) -> bool:
        """Split a high-variance bucket into 2 if K < K_MAX.

        Returns True if split occurred.
        """
        if not self.fitted or self.K >= self.K_MAX:
            return False
        variance = self.bucket_variance(centroid_id)
        if variance < self.SPLIT_VARIANCE_THRESHOLD:
            return False

        bucket = self.buckets.get(centroid_id)
        if bucket is None or len(bucket.embeddings) < 4:
            return False

        # Mini k-means on bucket items; seeded like fit() for reproducibility.
        from sklearn.cluster import KMeans
        embs = np.array(bucket.embeddings, dtype=np.float32)
        km = KMeans(n_clusters=2, random_state=self.seed, n_init=3)
        labels = km.fit_predict(embs)

        # A one-sided partition would leave an empty bucket; skip the split.
        n_second = int(np.count_nonzero(labels))
        if n_second == 0 or n_second == len(labels):
            return False

        # Create new (normalized) centroids.
        c0 = km.cluster_centers_[0]
        c1 = km.cluster_centers_[1]
        c0 = c0 / max(np.linalg.norm(c0), 1e-10)
        c1 = c1 / max(np.linalg.norm(c1), 1e-10)

        # Replace original centroid with c0, append c1 as a new centroid.
        self.centroids[centroid_id] = c0.astype(np.float32)
        new_cid = self.K
        self.centroids = np.vstack(
            [self.centroids, c1.astype(np.float32).reshape(1, -1)]
        )

        # Redistribute items between the two halves.
        halves = (CentroidBucket(), CentroidBucket())
        for i, label in enumerate(labels):
            target = halves[0] if label == 0 else halves[1]
            target.item_ids.append(bucket.item_ids[i])
            target.embeddings.append(bucket.embeddings[i])

        self.buckets[centroid_id] = halves[0]
        self.buckets[new_cid] = halves[1]

        logger.info(f"Split centroid {centroid_id}: variance={variance:.4f}, "
                    f"K={self.K}, sizes={len(halves[0].item_ids)}/{len(halves[1].item_ids)}")
        return True

    def maybe_prune(self, centroid_id: int) -> bool:
        """Prune idle centroid: reassign items to nearest neighbors.

        Removing a centroid shifts every higher centroid/bucket ID down by
        one, so callers pruning several IDs should go from highest to lowest.

        Returns True if pruned.
        """
        if self.K <= self.K_MIN or not 0 <= centroid_id < self.K:
            return False

        bucket = self.buckets.get(centroid_id)
        if bucket is None:
            return False

        idle_time = time.time() - bucket.last_accessed
        if idle_time < self.PRUNE_IDLE_SECONDS:
            return False

        # Reassign items to the nearest remaining centroid.
        for moved_id, emb in zip(bucket.item_ids, bucket.embeddings):
            dots = emb @ self.centroids.T
            dots[centroid_id] = -np.inf  # exclude the centroid being pruned
            new_cid = int(np.argmax(dots))
            target = self.buckets.setdefault(new_cid, CentroidBucket())
            target.item_ids.append(moved_id)
            target.embeddings.append(emb)

        # Remove the bucket and its centroid row, then renumber buckets.
        del self.buckets[centroid_id]
        mask = np.ones(self.K, dtype=bool)
        mask[centroid_id] = False
        self.centroids = self.centroids[mask]
        self._remap_buckets_after_prune(centroid_id)

        logger.info(f"Pruned centroid {centroid_id}: idle={idle_time:.0f}s, K={self.K}")
        return True

    def _remap_buckets_after_prune(self, removed_id: int) -> None:
        """Shift bucket IDs above ``removed_id`` down by one."""
        self.buckets = {
            (old_id if old_id < removed_id else old_id - 1): bucket
            for old_id, bucket in sorted(self.buckets.items())
            if old_id != removed_id
        }

    def run_adaptation(self, force: bool = False) -> Dict:
        """Run a full adaptation cycle: EMA already applied per-query,
        this handles splits and prunes.

        Prunes only run when ``force`` is True.

        Returns dict of operations performed, with keys ``splits``,
        ``prunes``, ``k_before`` and ``k_after``.
        """
        ops = {"splits": 0, "prunes": 0, "k_before": self.K}

        if not self.fitted:
            ops["k_after"] = self.K
            return ops

        # Check splits on high-variance buckets. Iterate a snapshot because
        # a split adds a new bucket.
        for cid in list(self.buckets):
            if cid < self.K and self.maybe_split(cid):
                ops["splits"] += 1

        # Check prunes on idle buckets (only during explicit adaptation).
        # Highest ID first: a prune renumbers only the IDs above it, so the
        # IDs still to be visited keep referring to the same buckets.
        if force:
            for cid in sorted(self.buckets, reverse=True):
                if cid < self.K and self.maybe_prune(cid):
                    ops["prunes"] += 1

        ops["k_after"] = self.K
        return ops

    def get_state(self) -> Tuple[Optional[np.ndarray], Dict]:
        """For persistence: return (centroids, {cid: [item_ids]}).

        The ID lists are copies, so mutating them does not alter the codebook.
        """
        assignments = {
            cid: list(bucket.item_ids) for cid, bucket in self.buckets.items()
        }
        return self.centroids, assignments
|
rm/drift.py
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
1
|
+
"""Drift Monitor — Detects distribution shift via quantization error tracking.
|
|
2
|
+
|
|
3
|
+
Monitors rolling qerr and margin to detect when incoming queries diverge
|
|
4
|
+
from the codebook's trained distribution. When drift is detected:
|
|
5
|
+
- Triggers aggressive codebook adaptation (splits/updates)
|
|
6
|
+
- Logs alarm for observability
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
import logging
|
|
10
|
+
import time
|
|
11
|
+
from collections import deque
|
|
12
|
+
from dataclasses import dataclass, field
|
|
13
|
+
from typing import Deque, List, Optional
|
|
14
|
+
|
|
15
|
+
logger = logging.getLogger("rm.drift")


@dataclass
class DriftAlarm:
    """Record of a drift event."""

    timestamp: float  # time.time() when the alarm fired
    ratio: float  # qerr_recent / qerr_baseline
    episode: int  # monitor's query counter when the alarm fired
    message: str  # human-readable summary (the text that is also logged)


class DriftMonitor:
    """Tracks rolling quantization error to detect distribution drift.

    Alarm triggers when recent qerr exceeds baseline by a configurable factor.
    """

    WINDOW_SIZE = 20  # rolling window for recent stats
    BASELINE_WINDOW = 50  # initial baseline window
    ALARM_RATIO = 1.3  # qerr ratio threshold for alarm
    COOLDOWN_QUERIES = 30  # min queries between alarms

    def __init__(self):
        self._qerr_history: Deque[float] = deque(maxlen=2000)
        self._margin_history: Deque[float] = deque(maxlen=2000)
        self._baseline_qerr: Optional[float] = None
        self._baseline_margin: Optional[float] = None
        self._episode = 0
        # Start "one cooldown ago" so the very first alarm is never suppressed.
        self._last_alarm_episode = -self.COOLDOWN_QUERIES
        self._alarms: List[DriftAlarm] = []
        self._in_drift = False

    @staticmethod
    def _tail_mean(history: Deque[float], size: int) -> float:
        """Mean of the last ``size`` entries of a non-empty ``history``."""
        tail = list(history)[-size:]
        return sum(tail) / len(tail)

    @property
    def episode_count(self) -> int:
        """Number of queries recorded so far."""
        return self._episode

    @property
    def in_drift(self) -> bool:
        """True while the recent/baseline ratio is at or above ALARM_RATIO."""
        return self._in_drift

    @property
    def alarms(self) -> List[DriftAlarm]:
        """Copy of the list of alarms raised so far."""
        return list(self._alarms)

    @property
    def baseline_qerr(self) -> Optional[float]:
        """Baseline mean qerr, or None until BASELINE_WINDOW queries are seen."""
        return self._baseline_qerr

    @property
    def recent_qerr(self) -> float:
        """Mean qerr over the last WINDOW_SIZE queries (0.0 with < 5 samples)."""
        if len(self._qerr_history) < 5:
            return 0.0
        return self._tail_mean(self._qerr_history, self.WINDOW_SIZE)

    @property
    def drift_ratio(self) -> float:
        """recent_qerr / baseline_qerr, or 0.0 without a usable baseline."""
        if self._baseline_qerr is None or self._baseline_qerr < 1e-10:
            return 0.0
        return self.recent_qerr / self._baseline_qerr

    def record(self, qerr: float, margin: float = 0.0) -> Optional[DriftAlarm]:
        """Record a query's quantization error and margin.

        The query that completes the baseline window only establishes the
        baseline; drift checks start with the following query.

        Returns DriftAlarm if alarm was triggered, None otherwise.
        """
        self._episode += 1
        self._qerr_history.append(qerr)
        self._margin_history.append(margin)

        # Establish baseline from the first BASELINE_WINDOW queries.
        if self._baseline_qerr is None and len(self._qerr_history) >= self.BASELINE_WINDOW:
            baseline_data = list(self._qerr_history)[:self.BASELINE_WINDOW]
            self._baseline_qerr = sum(baseline_data) / len(baseline_data)
            baseline_margins = list(self._margin_history)[:self.BASELINE_WINDOW]
            self._baseline_margin = sum(baseline_margins) / len(baseline_margins)
            return None

        return self._check_alarm()

    def _check_alarm(self) -> Optional[DriftAlarm]:
        """Check if current qerr indicates drift."""
        if self._baseline_qerr is None:
            return None
        if len(self._qerr_history) < self.WINDOW_SIZE:
            return None

        ratio = self.drift_ratio

        if ratio < self.ALARM_RATIO:
            if self._in_drift:
                self._in_drift = False
                logger.info(f"[rm] Drift recovered at episode {self._episode}, "
                            f"ratio={ratio:.2f}")
            return None

        self._in_drift = True
        # Rate-limit alarms: stay silent until the cooldown has elapsed.
        if self._episode - self._last_alarm_episode < self.COOLDOWN_QUERIES:
            return None

        alarm = DriftAlarm(
            timestamp=time.time(),
            ratio=ratio,
            episode=self._episode,
            message=f"DRIFT_ALARM: qerr {ratio:.1f}x baseline at episode {self._episode}",
        )
        self._alarms.append(alarm)
        self._last_alarm_episode = self._episode
        logger.warning(f"[rm] {alarm.message}")
        return alarm

    def reset_baseline(self) -> None:
        """Reset baseline using recent data (after adaptation).

        Does nothing until at least WINDOW_SIZE queries have been recorded,
        so it is safe to call on a fresh monitor. Both the qerr and the
        margin baseline are recomputed from the last WINDOW_SIZE queries.
        """
        if len(self._qerr_history) < self.WINDOW_SIZE:
            return
        self._baseline_qerr = self._tail_mean(self._qerr_history, self.WINDOW_SIZE)
        self._baseline_margin = self._tail_mean(self._margin_history, self.WINDOW_SIZE)
        self._in_drift = False
        logger.info(f"[rm] Baseline reset: new baseline_qerr={self._baseline_qerr:.4f}")

    def get_stats(self) -> dict:
        """Return a snapshot of the monitor's counters and ratios."""
        return {
            "episode": self._episode,
            "baseline_qerr": self._baseline_qerr,
            "recent_qerr": self.recent_qerr,
            "drift_ratio": self.drift_ratio,
            "in_drift": self._in_drift,
            "alarm_count": len(self._alarms),
        }
|
rm/embeddings/base.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
"""Abstract embedding backend interface."""
|
|
2
|
+
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from typing import List
|
|
5
|
+
import numpy as np
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class EmbeddingBackend(ABC):
    """Interface every embedding provider must implement.

    Concrete backends supply the vector dimensionality plus single-text and
    batch encoders; the class cannot be instantiated until all three exist.
    """

    @property
    @abstractmethod
    def dim(self) -> int:
        """Number of components in each embedding vector."""

    @abstractmethod
    def encode(self, text: str) -> np.ndarray:
        """Turn one text into a normalized embedding vector."""

    @abstractmethod
    def encode_batch(self, texts: List[str]) -> np.ndarray:
        """Turn several texts into an (N, dim) matrix, one row per text."""
|
rm/embeddings/local.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
"""Local sentence-transformers embedding backend."""
|
|
2
|
+
|
|
3
|
+
from typing import Dict, List
|
|
4
|
+
import numpy as np
|
|
5
|
+
|
|
6
|
+
from .base import EmbeddingBackend
|
|
7
|
+
|
|
8
|
+
# One loaded model per model name, shared by every LocalEmbeddings instance.
_model_instances: Dict[str, object] = {}


def _get_model(model_name: str):
    """Return the SentenceTransformer for ``model_name``, loading it lazily.

    Models are cached per name, so backends configured with different model
    names each get their own model instead of whichever was loaded first.
    """
    model = _model_instances.get(model_name)
    if model is None:
        # Imported on first use so importing this module stays cheap.
        from sentence_transformers import SentenceTransformer
        model = SentenceTransformer(model_name)
        _model_instances[model_name] = model
    return model
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class LocalEmbeddings(EmbeddingBackend):
    """sentence-transformers wrapper with caching. Default: all-MiniLM-L6-v2 (d=384).

    Embeddings are cached per instance, keyed by the exact input text; the
    cache is unbounded.
    """

    def __init__(self, model_name: str = "all-MiniLM-L6-v2"):
        self.model_name = model_name
        self._cache: Dict[str, np.ndarray] = {}
        # MiniLM default; overwritten with the real size after the first encode.
        self._dim: int = 384

    @property
    def dim(self) -> int:
        """Embedding dimensionality (384 until a vector has been encoded)."""
        return self._dim

    def encode(self, text: str) -> np.ndarray:
        """Encode one text to a float32 vector, reusing the cached result."""
        cached = self._cache.get(text)
        if cached is not None:
            return cached
        model = _get_model(self.model_name)
        vec = model.encode(text, normalize_embeddings=True)
        vec = np.asarray(vec, dtype=np.float32)
        self._dim = vec.shape[0]
        self._cache[text] = vec
        return vec

    def encode_batch(self, texts: List[str]) -> np.ndarray:
        """Encode a batch of texts. Returns (N, dim) float32 matrix.

        Only texts missing from the cache are sent to the model, each once
        even if repeated in ``texts``. An empty batch yields a (0, dim)
        matrix.
        """
        if len(texts) == 0:
            return np.empty((0, self._dim), dtype=np.float32)

        # dict.fromkeys de-duplicates while keeping first-seen order.
        missing = [t for t in dict.fromkeys(texts) if t not in self._cache]

        if missing:
            model = _get_model(self.model_name)
            vecs = model.encode(missing, normalize_embeddings=True, batch_size=64)
            vecs = np.asarray(vecs, dtype=np.float32)
            self._dim = vecs.shape[1]
            for text, vec in zip(missing, vecs):
                self._cache[text] = vec

        return np.array([self._cache[t] for t in texts], dtype=np.float32)
|
rm/filtering.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
"""Score-based token filtering (τ threshold)."""
|
|
2
|
+
|
|
3
|
+
from typing import List
|
|
4
|
+
from .retrieval import RetrievalResult
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def filter_by_score(results: List[RetrievalResult], tau: float = 0.5) -> List[RetrievalResult]:
    """Keep the results whose score is at least τ, preserving input order."""
    kept = []
    for candidate in results:
        if candidate.score >= tau:
            kept.append(candidate)
    return kept
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def filter_top_n(results: List[RetrievalResult], n: int = 3) -> List[RetrievalResult]:
    """Keep only top-n results by score, highest first.

    Results with equal scores keep their input order. A non-positive ``n``
    yields an empty list.
    """
    # Guard explicitly: a negative n would make ``[:n]`` drop items from the
    # end instead of returning nothing.
    if n <= 0:
        return []
    sorted_results = sorted(results, key=lambda r: r.score, reverse=True)
    return sorted_results[:n]
|