routing-memory 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rm/__init__.py ADDED
@@ -0,0 +1,28 @@
1
+ """Routing Memory — Lightweight long-term memory via vector-quantized routing."""
2
+
3
+ from .memory import RoutingMemory
4
+ from .codebook import Codebook, CentroidBucket
5
+ from .retrieval import L1Retriever, L1Result, RetrievalResult
6
+ from .filtering import filter_by_score, filter_top_n
7
+ from .drift import DriftMonitor, DriftAlarm
8
+ from .embeddings import LocalEmbeddings, EmbeddingBackend
9
+ from .storage import RMSQLiteBackend, RMStorageBackend
10
+
11
# Public API of the package, grouped by the submodule that defines each name.
__all__ = [
    # memory
    "RoutingMemory",
    # codebook
    "Codebook",
    "CentroidBucket",
    # retrieval
    "L1Retriever",
    "L1Result",
    "RetrievalResult",
    # filtering
    "filter_by_score",
    "filter_top_n",
    # drift
    "DriftMonitor",
    "DriftAlarm",
    # embeddings
    "LocalEmbeddings",
    "EmbeddingBackend",
    # storage
    "RMSQLiteBackend",
    "RMStorageBackend",
]

# Package version string; keep in sync with the published wheel version.
__version__ = "0.1.0"
rm/codebook.py ADDED
@@ -0,0 +1,337 @@
1
+ """VQ Codebook: MiniBatchKMeans with adaptive K + online adaptation.
2
+
3
+ Features:
4
+ - Adaptive K: K = ceil(N / B_target)
5
+ - EMA centroid update: centroids drift toward recent data
6
+ - Split: high-variance buckets split into 2
7
+ - Prune: idle centroids get reassigned and removed
8
+ - K range guard: K stays in [K_MIN, K_MAX]
9
+ """
10
+
11
+ import logging
12
+ import time
13
+ import numpy as np
14
+ from typing import Dict, List, Optional, Tuple
15
+ from dataclasses import dataclass, field
16
+
17
# Module logger; split/prune events are reported through it.
logger = logging.getLogger(name="rm.codebook")
18
+
19
+
20
@dataclass
class CentroidBucket:
    """Membership record for a single centroid.

    ``item_ids`` and ``embeddings`` are parallel lists: position *i* in each
    describes the same assigned item. ``last_accessed`` and ``access_count``
    record usage of the centroid and are consulted when deciding whether an
    idle centroid may be pruned.
    """

    # Parallel per-item lists (one entry per assigned item).
    item_ids: List[str] = field(default_factory=list)
    embeddings: List[np.ndarray] = field(default_factory=list)
    # Wall-clock seconds (time.time()) at creation or most recent access.
    last_accessed: float = field(default_factory=time.time)
    # Number of recorded accesses.
    access_count: int = 0
27
+
28
+
29
class Codebook:
    """Vector Quantization codebook using MiniBatchKMeans.

    Adaptive K: K = ceil(N / B_target) where B_target is items per bucket.

    ``centroids`` is a ``(K, dim)`` float32 matrix whose rows are
    L2-normalised. ``buckets`` maps a centroid row index to the items assigned
    to that centroid. Similarities are plain dot products against the centroid
    rows, so they are cosine similarities only when callers pass unit-norm
    vectors (assumed; nothing here normalises inputs).
    """

    B_TARGET = 20  # target items per bucket
    K_MIN = 2
    K_MAX = 512

    def __init__(self, dim: int = 384, seed: int = 42):
        self.dim = dim
        self.seed = seed
        self.centroids: Optional[np.ndarray] = None  # (K, dim)
        self.buckets: Dict[int, CentroidBucket] = {}
        self._n_items = 0
        self._fitted = False

    @property
    def K(self) -> int:
        """Number of centroids; 0 before the first successful fit."""
        if self.centroids is None:
            return 0
        return self.centroids.shape[0]

    @property
    def fitted(self) -> bool:
        """True once a centroid matrix has been produced by ``fit``."""
        return self._fitted and self.centroids is not None

    def _adaptive_k(self, n: int) -> int:
        """Return ceil(n / B_TARGET) clamped to [K_MIN, K_MAX]."""
        k = max(self.K_MIN, int(np.ceil(n / self.B_TARGET)))
        return min(k, self.K_MAX)

    def _has_centroid(self, centroid_id: int) -> bool:
        """True if ``centroid_id`` indexes an existing centroid row.

        Negative ids are rejected explicitly: numpy would otherwise wrap them
        around to the last rows and silently touch the wrong centroid.
        """
        return self.fitted and 0 <= centroid_id < self.K

    def fit(self, embeddings: np.ndarray, item_ids: List[str]) -> None:
        """Fit codebook from scratch using MiniBatchKMeans.

        Replaces all centroids and buckets. ``embeddings`` is indexed by row
        alongside ``item_ids`` (assumed ``(N, dim)`` and aligned). Does
        nothing when N == 0.
        """
        n = embeddings.shape[0]
        if n == 0:
            return

        k = self._adaptive_k(n)
        k = min(k, n)  # can't have more centroids than items

        from sklearn.cluster import MiniBatchKMeans
        kmeans = MiniBatchKMeans(
            n_clusters=k,
            random_state=self.seed,
            batch_size=min(256, n),
            n_init=3,
        )
        labels = kmeans.fit_predict(embeddings)

        # Normalize centroids so dot products with unit vectors are cosines.
        norms = np.linalg.norm(kmeans.cluster_centers_, axis=1, keepdims=True)
        norms = np.maximum(norms, 1e-10)
        self.centroids = (kmeans.cluster_centers_ / norms).astype(np.float32)

        # Build buckets; clusters that received no items get no bucket entry.
        self.buckets = {}
        for idx, label in enumerate(labels):
            label = int(label)
            if label not in self.buckets:
                self.buckets[label] = CentroidBucket()
            self.buckets[label].item_ids.append(item_ids[idx])
            self.buckets[label].embeddings.append(embeddings[idx])

        self._n_items = n
        self._fitted = True

    def encode(self, x: np.ndarray) -> Tuple[int, float]:
        """Find nearest centroid. Returns (centroid_id, quantization_error).

        The quantization error is ``1 - max_dot``.

        Raises:
            RuntimeError: if the codebook has not been fitted yet.
        """
        if not self.fitted:
            raise RuntimeError("Codebook not fitted. Call fit() or add items first.")
        dots = x @ self.centroids.T  # (K,)
        best = int(np.argmax(dots))
        qerr = 1.0 - float(dots[best])
        return best, qerr

    def conf(self, x: np.ndarray) -> float:
        """Confidence: 1 / (1 + qerr)."""
        _, qerr = self.encode(x)
        return 1.0 / (1.0 + qerr)

    def margin(self, x: np.ndarray) -> float:
        """Gap between top-2 centroid similarities; 0.0 if unfitted or K < 2."""
        if not self.fitted or self.K < 2:
            return 0.0
        dots = x @ self.centroids.T
        top2 = np.partition(dots, -2)[-2:]
        return float(top2.max() - top2.min())

    def top_centroids(self, x: np.ndarray, n: int = 3) -> List[int]:
        """Return up to ``n`` centroid IDs, most similar first.

        Returns plain Python ints. ``n <= 0`` yields an empty list (a bare
        ``[-0:]`` slice would otherwise return every centroid).
        """
        if not self.fitted or n <= 0:
            return []
        dots = x @ self.centroids.T
        n = min(n, self.K)
        return [int(i) for i in np.argsort(dots)[-n:][::-1]]

    def _bucket_of(self, item_id: str) -> Optional[int]:
        """Return the id of the first bucket containing ``item_id`` (linear scan)."""
        for cid, bucket in self.buckets.items():
            if item_id in bucket.item_ids:
                return cid
        return None

    def add(self, item_id: str, embedding: np.ndarray) -> int:
        """Assign item to nearest centroid. Returns centroid ID.

        Before the first fit, items are collected in a provisional bucket 0.
        Once K_MIN items are present the codebook is fitted, and the id of
        the bucket this item landed in is returned. When fitted, a full refit
        is triggered every 100 items if the ideal K has outgrown the current
        K by more than 1.5x.
        """
        if not self.fitted:
            # Bootstrap: collect items and fit when we have enough
            if 0 not in self.buckets:
                self.buckets[0] = CentroidBucket()
            self.buckets[0].item_ids.append(item_id)
            self.buckets[0].embeddings.append(embedding)
            self._n_items += 1

            if self._n_items >= self.K_MIN:
                self._refit()
                # fit() reassigned every item, so bucket 0 is no longer
                # necessarily where this item lives; report the real bucket.
                cid = self._bucket_of(item_id)
                if cid is not None:
                    return cid
            return 0

        cid, _ = self.encode(embedding)
        if cid not in self.buckets:
            self.buckets[cid] = CentroidBucket()
        self.buckets[cid].item_ids.append(item_id)
        self.buckets[cid].embeddings.append(embedding)
        self._n_items += 1

        # Check if we need to refit (bucket too large)
        if self._n_items % 100 == 0:
            ideal_k = self._adaptive_k(self._n_items)
            if ideal_k > self.K * 1.5:
                self._refit()

        return cid

    def _refit(self) -> None:
        """Refit codebook from all stored embeddings (needs >= K_MIN items)."""
        all_ids = []
        all_embs = []
        for bucket in self.buckets.values():
            all_ids.extend(bucket.item_ids)
            all_embs.extend(bucket.embeddings)
        if len(all_embs) >= self.K_MIN:
            emb_matrix = np.array(all_embs, dtype=np.float32)
            self.fit(emb_matrix, all_ids)

    def get_bucket_items(self, centroid_id: int) -> "CentroidBucket":
        """Return the bucket for ``centroid_id``, or a fresh empty (detached) one."""
        return self.buckets.get(centroid_id, CentroidBucket())

    # ─── Online Adaptation ────────────────────────────────────────────────

    EMA_ETA = 0.01  # EMA learning rate
    SPLIT_VARIANCE_THRESHOLD = 0.15  # bucket variance above this → split
    PRUNE_IDLE_SECONDS = 3600  # 1 hour idle → prune candidate
    UPDATE_INTERVAL = 50  # run adaptation every N queries

    def ema_update(self, centroid_id: int, new_embedding: np.ndarray,
                   eta: Optional[float] = None) -> None:
        """EMA update: move a centroid toward ``new_embedding``.

        Computes ``(1 - eta) * old + eta * new`` and re-normalises it. Also
        stamps the centroid's bucket (if any) as accessed, which defers idle
        pruning.

        Args:
            centroid_id: Centroid row index; out-of-range (including negative)
                ids are ignored.
            new_embedding: Vector to pull the centroid toward.
            eta: Learning rate. ``None`` means ``EMA_ETA``; an explicit 0.0 is
                honoured (no movement, access still recorded).
        """
        if not self._has_centroid(centroid_id):
            return
        if eta is None:
            eta = self.EMA_ETA
        old = self.centroids[centroid_id]
        updated = (1 - eta) * old + eta * new_embedding
        # Renormalize
        norm = np.linalg.norm(updated)
        if norm > 1e-10:
            updated = updated / norm
        self.centroids[centroid_id] = updated.astype(np.float32)

        # Track access
        if centroid_id in self.buckets:
            self.buckets[centroid_id].last_accessed = time.time()
            self.buckets[centroid_id].access_count += 1

    def bucket_variance(self, centroid_id: int) -> float:
        """Mean squared cosine distance (``1 - dot``) of bucket items to their centroid.

        Returns 0.0 for unknown centroids and buckets with fewer than 2 items.
        """
        if not self._has_centroid(centroid_id):
            return 0.0
        bucket = self.buckets.get(centroid_id)
        if bucket is None or len(bucket.embeddings) < 2:
            return 0.0
        centroid = self.centroids[centroid_id]
        embs = np.array(bucket.embeddings, dtype=np.float32)
        dists = 1.0 - embs @ centroid  # cosine distance
        return float(np.mean(dists ** 2))

    def maybe_split(self, centroid_id: int) -> bool:
        """Split a high-variance bucket into 2 if K < K_MAX.

        Requires variance >= SPLIT_VARIANCE_THRESHOLD and at least 4 items.
        The original centroid row is replaced by one half and the other half
        is appended as a new centroid with id ``K`` (pre-split).

        Returns True if split occurred.
        """
        if self.K >= self.K_MAX or not self._has_centroid(centroid_id):
            return False
        variance = self.bucket_variance(centroid_id)
        if variance < self.SPLIT_VARIANCE_THRESHOLD:
            return False

        bucket = self.buckets.get(centroid_id)
        if bucket is None or len(bucket.embeddings) < 4:
            return False

        # Mini k-means on bucket items; seeded like the main fit.
        from sklearn.cluster import KMeans
        embs = np.array(bucket.embeddings, dtype=np.float32)
        km = KMeans(n_clusters=2, random_state=self.seed, n_init=3)
        labels = km.fit_predict(embs)

        # Create new (normalized) centroids
        c0 = km.cluster_centers_[0]
        c1 = km.cluster_centers_[1]
        c0 = c0 / max(np.linalg.norm(c0), 1e-10)
        c1 = c1 / max(np.linalg.norm(c1), 1e-10)

        # Replace original centroid with c0, append c1 as a new row
        self.centroids[centroid_id] = c0.astype(np.float32)
        new_cid = self.K  # id the appended row will have
        self.centroids = np.vstack([self.centroids, c1.astype(np.float32).reshape(1, -1)])

        # Redistribute items according to the 2-way labels
        new_bucket_0 = CentroidBucket()
        new_bucket_1 = CentroidBucket()
        for item_id, emb, label in zip(bucket.item_ids, bucket.embeddings, labels):
            target = new_bucket_0 if label == 0 else new_bucket_1
            target.item_ids.append(item_id)
            target.embeddings.append(emb)

        self.buckets[centroid_id] = new_bucket_0
        self.buckets[new_cid] = new_bucket_1

        logger.info(f"Split centroid {centroid_id}: variance={variance:.4f}, "
                    f"K={self.K}, sizes={len(new_bucket_0.item_ids)}/{len(new_bucket_1.item_ids)}")
        return True

    def maybe_prune(self, centroid_id: int) -> bool:
        """Prune idle centroid: reassign items to nearest neighbors.

        A centroid is pruned only if K > K_MIN and its bucket has not been
        accessed for PRUNE_IDLE_SECONDS. Pruning removes the centroid row and
        shifts every higher centroid/bucket id down by one.

        Returns True if pruned.
        """
        if self.K <= self.K_MIN or not self._has_centroid(centroid_id):
            return False

        bucket = self.buckets.get(centroid_id)
        if bucket is None:
            return False

        idle_time = time.time() - bucket.last_accessed
        if idle_time < self.PRUNE_IDLE_SECONDS:
            return False

        # Reassign items to nearest remaining centroids (pre-removal ids)
        for item_id, emb in zip(bucket.item_ids, bucket.embeddings):
            dots = emb @ self.centroids.T
            # Exclude the centroid being pruned
            dots[centroid_id] = -np.inf
            new_cid = int(np.argmax(dots))
            if new_cid not in self.buckets:
                self.buckets[new_cid] = CentroidBucket()
            self.buckets[new_cid].item_ids.append(item_id)
            self.buckets[new_cid].embeddings.append(emb)

        # Remove centroid
        del self.buckets[centroid_id]
        # Rebuild centroids array without the pruned one
        mask = np.ones(self.K, dtype=bool)
        mask[centroid_id] = False
        self.centroids = self.centroids[mask]
        # Remap bucket keys so they keep matching centroid row indices
        self._remap_buckets_after_prune(centroid_id)

        logger.info(f"Pruned centroid {centroid_id}: idle={idle_time:.0f}s, K={self.K}")
        return True

    def _remap_buckets_after_prune(self, removed_id: int) -> None:
        """Shift bucket ids above ``removed_id`` down by one (drops ``removed_id``)."""
        new_buckets = {}
        for old_id, bucket in sorted(self.buckets.items()):
            if old_id < removed_id:
                new_buckets[old_id] = bucket
            elif old_id > removed_id:
                new_buckets[old_id - 1] = bucket
        self.buckets = new_buckets

    def run_adaptation(self, force: bool = False) -> Dict:
        """Run a full adaptation cycle: EMA already applied per-query,
        this handles splits and prunes.

        Splits are always attempted; prunes only when ``force`` is True.

        Returns dict of operations performed
        (keys ``splits``, ``prunes``, ``k_before``, ``k_after``).
        """
        ops = {"splits": 0, "prunes": 0, "k_before": self.K}

        if not self.fitted:
            ops["k_after"] = self.K
            return ops

        # Check splits on high-variance buckets. Splits only append new ids,
        # so existing ids stay valid while iterating this snapshot.
        cids_to_check = list(self.buckets.keys())
        for cid in cids_to_check:
            if cid < self.K and self.maybe_split(cid):
                ops["splits"] += 1

        # Check prunes on idle buckets (only during explicit adaptation).
        # Iterate from the highest id down: a prune shifts only ids *above*
        # the pruned one, so the ids still to visit remain valid and no
        # bucket is skipped.
        if force:
            for cid in sorted(self.buckets, reverse=True):
                if cid < self.K and self.maybe_prune(cid):
                    ops["prunes"] += 1

        ops["k_after"] = self.K
        return ops

    def get_state(self) -> Tuple[Optional[np.ndarray], Dict]:
        """For persistence: return (centroids, {cid: [item_ids]}).

        The item-id lists are copies, so mutating them does not alter the
        codebook. The centroid matrix is returned as-is (not copied).
        """
        assignments = {cid: list(bucket.item_ids) for cid, bucket in self.buckets.items()}
        return self.centroids, assignments
rm/drift.py ADDED
@@ -0,0 +1,145 @@
1
+ """Drift Monitor — Detects distribution shift via quantization error tracking.
2
+
3
+ Monitors rolling qerr and margin to detect when incoming queries diverge
4
+ from the codebook's trained distribution. When drift is detected:
5
+ - Triggers aggressive codebook adaptation (splits/updates)
6
+ - Logs alarm for observability
7
+ """
8
+
9
+ import logging
10
+ import time
11
+ from collections import deque
12
+ from dataclasses import dataclass, field
13
+ from typing import Deque, List, Optional
14
+
15
# Module logger; drift alarms, recoveries and baseline resets go here.
logger = logging.getLogger(name="rm.drift")
16
+
17
+
18
@dataclass
class DriftAlarm:
    """A single drift event emitted by the drift monitor.

    Instances are plain records; the monitor also logs ``message``.
    """

    timestamp: float  # wall-clock seconds (time.time()) when the alarm fired
    ratio: float  # qerr_recent / qerr_baseline at alarm time
    episode: int  # number of queries recorded so far (1-based)
    message: str  # human-readable summary of the event
25
+
26
+
27
class DriftMonitor:
    """Tracks rolling quantization error to detect distribution drift.

    Alarm triggers when recent qerr exceeds baseline by a configurable factor.

    The baseline is the mean qerr of the first BASELINE_WINDOW recorded
    queries (or of the last WINDOW_SIZE queries after ``reset_baseline``).
    "Recent" is the mean of the last WINDOW_SIZE queries.
    """

    WINDOW_SIZE = 20  # rolling window for recent stats
    BASELINE_WINDOW = 50  # initial baseline window
    ALARM_RATIO = 1.3  # qerr ratio threshold for alarm
    COOLDOWN_QUERIES = 30  # min queries between alarms

    def __init__(self):
        # Bounded histories: only the most recent 2000 queries are retained.
        self._qerr_history: Deque[float] = deque(maxlen=2000)
        self._margin_history: Deque[float] = deque(maxlen=2000)
        self._baseline_qerr: Optional[float] = None
        self._baseline_margin: Optional[float] = None
        self._episode = 0
        # Start "one cooldown ago" so the very first drift can alarm at once.
        self._last_alarm_episode = -self.COOLDOWN_QUERIES
        self._alarms: List[DriftAlarm] = []
        self._in_drift = False

    @property
    def episode_count(self) -> int:
        """Number of queries recorded so far."""
        return self._episode

    @property
    def in_drift(self) -> bool:
        """True while the drift ratio stays at or above ALARM_RATIO."""
        return self._in_drift

    @property
    def alarms(self) -> List["DriftAlarm"]:
        """Copy of all alarms raised so far."""
        return list(self._alarms)

    @property
    def baseline_qerr(self) -> Optional[float]:
        """Baseline mean qerr, or None before it has been established."""
        return self._baseline_qerr

    @property
    def recent_qerr(self) -> float:
        """Mean qerr of the last WINDOW_SIZE queries (0.0 with < 5 samples)."""
        if len(self._qerr_history) < 5:
            return 0.0
        recent = list(self._qerr_history)[-self.WINDOW_SIZE:]
        return sum(recent) / len(recent)

    @property
    def drift_ratio(self) -> float:
        """recent_qerr / baseline_qerr; 0.0 if no (or a near-zero) baseline."""
        if self._baseline_qerr is None or self._baseline_qerr < 1e-10:
            return 0.0
        return self.recent_qerr / self._baseline_qerr

    def record(self, qerr: float, margin: float = 0.0) -> Optional["DriftAlarm"]:
        """Record a query's quantization error and margin.

        Returns DriftAlarm if alarm was triggered, None otherwise. The query
        that completes the initial baseline window never raises an alarm.
        """
        self._episode += 1
        self._qerr_history.append(qerr)
        self._margin_history.append(margin)

        # Establish baseline from first N queries
        if self._baseline_qerr is None and len(self._qerr_history) >= self.BASELINE_WINDOW:
            baseline_data = list(self._qerr_history)[:self.BASELINE_WINDOW]
            self._baseline_qerr = sum(baseline_data) / len(baseline_data)
            baseline_margins = list(self._margin_history)[:self.BASELINE_WINDOW]
            self._baseline_margin = sum(baseline_margins) / len(baseline_margins)
            return None

        # Check for drift
        return self._check_alarm()

    def _check_alarm(self) -> Optional["DriftAlarm"]:
        """Check if current qerr indicates drift.

        While the ratio stays >= ALARM_RATIO, a new alarm is emitted whenever
        COOLDOWN_QUERIES queries have passed since the previous one. Dropping
        below the ratio clears the drift flag and logs a recovery.
        """
        if self._baseline_qerr is None:
            return None
        if len(self._qerr_history) < self.WINDOW_SIZE:
            return None

        ratio = self.drift_ratio

        if ratio >= self.ALARM_RATIO:
            if not self._in_drift:
                self._in_drift = True
            # Rate-limit alarms
            if self._episode - self._last_alarm_episode >= self.COOLDOWN_QUERIES:
                alarm = DriftAlarm(
                    timestamp=time.time(),
                    ratio=ratio,
                    episode=self._episode,
                    message=f"DRIFT_ALARM: qerr {ratio:.1f}x baseline at episode {self._episode}",
                )
                self._alarms.append(alarm)
                self._last_alarm_episode = self._episode
                logger.warning(f"[rm] {alarm.message}")
                return alarm
        else:
            if self._in_drift:
                self._in_drift = False
                logger.info(f"[rm] Drift recovered at episode {self._episode}, "
                            f"ratio={ratio:.2f}")

        return None

    def reset_baseline(self) -> None:
        """Reset baseline using recent data (after adaptation).

        Re-bases qerr (and margin) on the last WINDOW_SIZE queries. With
        fewer samples than that, the existing baseline (possibly None) is
        kept. The drift flag is cleared in both cases.
        """
        if len(self._qerr_history) >= self.WINDOW_SIZE:
            recent = list(self._qerr_history)[-self.WINDOW_SIZE:]
            self._baseline_qerr = sum(recent) / len(recent)
            recent_margins = list(self._margin_history)[-self.WINDOW_SIZE:]
            self._baseline_margin = sum(recent_margins) / len(recent_margins)
            # Only log when a baseline was actually computed: formatting a
            # None baseline with :.4f would raise TypeError.
            logger.info(f"[rm] Baseline reset: new baseline_qerr={self._baseline_qerr:.4f}")
        self._in_drift = False

    def get_stats(self) -> dict:
        """Snapshot of monitor state for observability."""
        return {
            "episode": self._episode,
            "baseline_qerr": self._baseline_qerr,
            "recent_qerr": self.recent_qerr,
            "drift_ratio": self.drift_ratio,
            "in_drift": self._in_drift,
            "alarm_count": len(self._alarms),
        }
rm/embeddings/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ from .local import LocalEmbeddings
2
+ from .base import EmbeddingBackend
3
+
4
# Public names re-exported by the embeddings subpackage.
__all__ = [
    "LocalEmbeddings",
    "EmbeddingBackend",
]
rm/embeddings/base.py ADDED
@@ -0,0 +1,22 @@
1
+ """Abstract embedding backend interface."""
2
+
3
+ from abc import ABC, abstractmethod
4
+ from typing import List
5
+ import numpy as np
6
+
7
+
8
class EmbeddingBackend(ABC):
    """Interface that every embedding provider implements.

    Concrete subclasses must define ``dim``, ``encode`` and ``encode_batch``;
    ``ABC`` refuses to instantiate a subclass that leaves any of them
    abstract.
    """

    @property
    @abstractmethod
    def dim(self) -> int:
        """Length of each vector this backend produces."""

    @abstractmethod
    def encode(self, text: str) -> np.ndarray:
        """Map one text to a normalized vector."""

    @abstractmethod
    def encode_batch(self, texts: List[str]) -> np.ndarray:
        """Map several texts at once; the result is an (N, dim) matrix."""
rm/embeddings/local.py ADDED
@@ -0,0 +1,56 @@
1
+ """Local sentence-transformers embedding backend."""
2
+
3
+ from typing import Dict, List
4
+ import numpy as np
5
+
6
+ from .base import EmbeddingBackend
7
+
8
# Most recently returned model; kept for backward compatibility with code
# that reads this module attribute directly.
_model_instance = None
# One lazily-created SentenceTransformer per model name.
_model_instances: Dict[str, object] = {}


def _get_model(model_name: str):
    """Lazily load and cache a SentenceTransformer, one instance per model name.

    The previous single global instance ignored ``model_name`` after the first
    load, so backends configured with different models silently shared the
    first one. sentence_transformers is imported only on first use.
    """
    global _model_instance
    model = _model_instances.get(model_name)
    if model is None:
        from sentence_transformers import SentenceTransformer
        model = SentenceTransformer(model_name)
        _model_instances[model_name] = model
    _model_instance = model
    return model
18
+
19
+
20
class LocalEmbeddings(EmbeddingBackend):
    """sentence-transformers wrapper with caching. Default: all-MiniLM-L6-v2 (d=384).

    Every encoded text is memoised in an unbounded in-memory dict.
    NOTE: returned vectors are the cached arrays themselves; callers should
    not mutate them in place.
    """

    def __init__(self, model_name: str = "all-MiniLM-L6-v2"):
        self.model_name = model_name
        self._cache: Dict[str, np.ndarray] = {}
        # MiniLM default; overwritten with the real width after the first encode.
        self._dim: int = 384

    @property
    def dim(self) -> int:
        """Embedding width (the default until something has been encoded)."""
        return self._dim

    def encode(self, text: str) -> np.ndarray:
        """Encode one text to a float32, L2-normalised vector (cached)."""
        if text in self._cache:
            return self._cache[text]
        model = _get_model(self.model_name)
        vec = model.encode(text, normalize_embeddings=True)
        vec = np.asarray(vec, dtype=np.float32)
        self._dim = vec.shape[0]
        self._cache[text] = vec
        return vec

    def encode_batch(self, texts: List[str]) -> np.ndarray:
        """Encode texts into an ``(len(texts), dim)`` float32 matrix.

        Only texts missing from the cache are sent to the model, each unique
        text once even if it repeats within the batch. An empty input yields
        a ``(0, dim)`` matrix rather than a shapeless 1-D array.
        """
        # dict.fromkeys de-duplicates while preserving first-seen order.
        missing = list(dict.fromkeys(t for t in texts if t not in self._cache))

        if missing:
            model = _get_model(self.model_name)
            vecs = model.encode(missing, normalize_embeddings=True, batch_size=64)
            vecs = np.asarray(vecs, dtype=np.float32)
            self._dim = vecs.shape[1]
            self._cache.update(zip(missing, vecs))

        if len(texts) == 0:
            return np.empty((0, self._dim), dtype=np.float32)
        return np.array([self._cache[t] for t in texts], dtype=np.float32)
+ return result
rm/filtering.py ADDED
@@ -0,0 +1,15 @@
1
+ """Score-based token filtering (τ threshold)."""
2
+
3
+ from typing import List
4
+ from .retrieval import RetrievalResult
5
+
6
+
7
def filter_by_score(results: List[RetrievalResult], tau: float = 0.5) -> List[RetrievalResult]:
    """Keep the results whose ``score`` is at least τ, in their original order.

    Anything scoring strictly below ``tau`` is dropped; a new list is returned.
    """
    return list(filter(lambda result: result.score >= tau, results))
10
+
11
+
12
def filter_top_n(results: List[RetrievalResult], n: int = 3) -> List[RetrievalResult]:
    """Keep only the top-``n`` results by score, highest first.

    Equal scores keep their input order (``sorted`` is stable). A
    non-positive ``n`` returns an empty list; without the guard, a negative
    ``n`` would act as a negative slice bound and drop items from the end of
    the ranking instead.
    """
    if n <= 0:
        return []
    return sorted(results, key=lambda r: r.score, reverse=True)[:n]