midas-memory 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
midas/__init__.py ADDED
@@ -0,0 +1,64 @@
1
+ """Midas — agentic memory.
2
+
3
+ Open-core wedge: semantic recall + budgeted context assembly behind a small,
4
+ store/embedder-agnostic API. This is the thing the eval harness benchmarks.
5
+ """
6
+ from .embeddings import (
7
+ DiskCachedEmbedder,
8
+ Embedder,
9
+ HashingEmbedder,
10
+ LocalEmbedder,
11
+ LocalReranker,
12
+ OpenAIEmbedder,
13
+ configure_local_model_cache,
14
+ cosine,
15
+ tokenize,
16
+ )
17
+ from .importance import ContentImportance, StructuralImportance
18
+ from .memory import CaptureResult, ContextBlock, Memory, Reranker, approx_tokens, format_record
19
+ from .policy import AGENT_MEMORY_INSTRUCTIONS, DEFAULT_POLICY, MemoryPolicy, policy_summary
20
+ from .store import InMemoryStore
21
+ from .types import MEMORY_KINDS, MemoryKind, MemoryRecord, RecallHit
22
+
23
+ try:
24
+ from .sqlite_store import SQLiteStore
25
+ except ImportError:
26
+ SQLiteStore = None # sqlite-vec not installed
27
+
28
+ try:
29
+ from .ann import IVFIndex, IVFStore
30
+ except ImportError:
31
+ IVFIndex = IVFStore = None # numpy not installed (ANN backend is optional)
32
+
33
+ __all__ = [
34
+ "Memory",
35
+ "ContentImportance",
36
+ "StructuralImportance",
37
+ "MemoryPolicy",
38
+ "DEFAULT_POLICY",
39
+ "AGENT_MEMORY_INSTRUCTIONS",
40
+ "policy_summary",
41
+ "CaptureResult",
42
+ "ContextBlock",
43
+ "Reranker",
44
+ "approx_tokens",
45
+ "format_record",
46
+ "MemoryRecord",
47
+ "RecallHit",
48
+ "MemoryKind",
49
+ "MEMORY_KINDS",
50
+ "Embedder",
51
+ "DiskCachedEmbedder",
52
+ "HashingEmbedder",
53
+ "LocalEmbedder",
54
+ "LocalReranker",
55
+ "OpenAIEmbedder",
56
+ "configure_local_model_cache",
57
+ "cosine",
58
+ "tokenize",
59
+ "InMemoryStore",
60
+ "SQLiteStore",
61
+ "IVFIndex",
62
+ "IVFStore",
63
+ ]
64
+ __version__ = "0.0.1"
midas/ann.py ADDED
@@ -0,0 +1,156 @@
1
+ """Approximate nearest-neighbour search for scaling Midas past the exact in-memory scan.
2
+
3
+ `InMemoryStore`'s cached scan is exact but O(N) per query (~230 ms at 1M x 768-d). `IVFIndex` is an
4
+ **inverted-file** index built with **numpy only** (no native dependency, unlike faiss/hnswlib): it
5
+ clusters the corpus into `nlist` cells via k-means, and a query then compares against only the
6
+ `nprobe` nearest cells instead of the whole corpus. That makes search **sub-linear** -- at 1M with
7
+ `nlist=1000, nprobe=8` a query scans ~8K vectors instead of 1M (~100x fewer) -- trading a little
8
+ recall for a large speedup. `nprobe` tunes the recall/latency trade-off at query time (no rebuild).
9
+
10
+ `IVFStore` wraps the index behind the same store surface as `InMemoryStore`, so `Memory` can use it
11
+ unchanged. The index is built lazily and rebuilt on mutation, so it suits **read-heavy / batch-loaded**
12
+ corpora (build once, query many); for write-heavy workloads keep `InMemoryStore`.
13
+ """
14
+ from __future__ import annotations
15
+
16
+ from typing import Callable, Sequence
17
+
18
+ import numpy as np
19
+
20
+ from .types import MemoryRecord
21
+
22
+
23
def _kmeans(x: np.ndarray, k: int, *, n_iter: int, rng: np.random.Generator) -> np.ndarray:
    """Cluster the rows of `x` into `k` cells with spherical (cosine) k-means.

    Rows are expected to be L2-normalized, so the dot product used for assignment equals cosine
    similarity. Each centroid is the re-normalized sum of its members.

    Args:
        x: (N, d) matrix of points.
        k: number of centroids; seeded from `k` distinct rows of `x`.
        n_iter: number of assign/update rounds.
        rng: random generator used for seeding and for reseeding empty cells.

    Returns:
        (k, d) array of centroids.
    """
    n_points = x.shape[0]
    seeds = rng.choice(n_points, k, replace=False)
    centroids = x[seeds].copy()
    for _ in range(n_iter):
        nearest = (x @ centroids.T).argmax(axis=1)
        for cell in range(k):
            members = x[nearest == cell]
            if len(members) == 0:
                # Empty cell: restart it from a random point so it can attract members next round.
                centroids[cell] = x[rng.integers(n_points)]
                continue
            total = members.sum(axis=0)
            length = float(np.linalg.norm(total))
            if length > 0.0:  # members that cancel out leave the centroid unchanged
                centroids[cell] = total / length
    return centroids
42
+
43
+
44
class IVFIndex:
    """Inverted-file ANN index over L2-normalized vectors (cosine similarity).

    `fit()` clusters the corpus into cells with `_kmeans`; `search()` scores only the rows that
    live in the `nprobe` cells whose centroids are closest to the query.

    Args:
        nlist: number of cells; `None` (or 0) means "derive from the corpus size" (sqrt(N)).
        n_iter: k-means rounds used while training the centroids.
        train_sample: maximum number of rows used to train the centroids.
        seed: seed for the random generator used during training.
    """

    def __init__(self, nlist: int | None = None, *, n_iter: int = 10, train_sample: int = 50_000,
                 seed: int = 0) -> None:
        self.nlist = nlist
        self.n_iter = n_iter
        self.train_sample = train_sample
        self.seed = seed
        # What the caller asked for vs. what the last fit() actually used. fit() publishes the
        # effective value on `self.nlist`, so the request has to be remembered separately or a
        # refit on a different-sized corpus would reuse a stale cell count.
        self._requested_nlist: int | None = nlist
        self._effective_nlist: int | None = None
        self._centroids: np.ndarray | None = None
        self._lists: list[np.ndarray] = []  # cell -> row indices into _vectors
        self._vectors: np.ndarray | None = None  # (N, d) float32, L2-normalized

    def fit(self, vectors: Sequence[Sequence[float]] | np.ndarray) -> "IVFIndex":
        """Train centroids and build the per-cell posting lists.

        Args:
            vectors: non-empty (N, d) matrix of L2-normalized rows.

        Returns:
            `self`, so construction and fitting can be chained.

        Raises:
            ValueError: if `vectors` is not a non-empty 2-D matrix.
        """
        x = np.asarray(vectors, dtype=np.float32)
        if x.ndim != 2 or x.shape[0] == 0:
            raise ValueError("fit() needs a non-empty (N, d) matrix")
        n = x.shape[0]
        if self.nlist != self._effective_nlist:
            # The caller reassigned `.nlist` since the last fit (or this is the first fit).
            self._requested_nlist = self.nlist
        nlist = self._requested_nlist or max(1, int(round(np.sqrt(n))))
        nlist = max(1, min(int(nlist), n))
        rng = np.random.default_rng(self.seed)
        # Train centroids on a sample (k-means cost is O(sample x nlist x d x iters)); assign all.
        # The sample must hold at least `nlist` rows because k-means seeds from distinct points.
        sample = max(self.train_sample, nlist)
        train = x if n <= sample else x[rng.choice(n, sample, replace=False)]
        centroids = _kmeans(train, nlist, n_iter=self.n_iter, rng=rng)
        assign = np.argmax(x @ centroids.T, axis=1)
        # Group row indices by cell with one stable sort instead of one full mask per cell;
        # stability keeps the indices inside each cell in ascending order.
        order = np.argsort(assign, kind="stable")
        counts = np.bincount(assign, minlength=nlist)
        self._centroids = centroids
        self._lists = np.split(order, np.cumsum(counts)[:-1])
        self._vectors = x
        self.nlist = nlist
        self._effective_nlist = nlist
        return self

    def search(self, query: Sequence[float], *, k: int = 10, nprobe: int = 8,
               allowed: np.ndarray | None = None) -> tuple[np.ndarray, np.ndarray]:
        """Return (row indices, scores) for the top-k, highest cosine first.

        Args:
            query: L2-normalized query vector of length d.
            k: maximum number of hits; a non-positive value yields no hits.
            nprobe: number of nearest cells to scan (clamped to [1, number of cells]).
            allowed: optional boolean mask over rows (predicate pushdown); candidates failing it
                are dropped before the top-k cut.

        Returns:
            Two parallel arrays: row indices into the fitted matrix and their float32 scores.
            Both are empty when the index is unfitted or no candidate survives.
        """
        empty = (np.array([], dtype=int), np.array([], dtype=np.float32))
        if self._vectors is None or self._centroids is None or k <= 0:
            return empty
        q = np.asarray(query, dtype=np.float32)
        nprobe = max(1, min(nprobe, len(self._lists)))
        csims = self._centroids @ q
        cells = np.argpartition(-csims, nprobe - 1)[:nprobe]
        candidates = np.concatenate([self._lists[c] for c in cells]) if len(cells) else np.empty(0, int)
        if candidates.size and allowed is not None:
            candidates = candidates[allowed[candidates]]
        if candidates.size == 0:
            return empty
        sims = self._vectors[candidates] @ q
        kk = min(k, candidates.size)
        # Partial selection first, then sort only the kk survivors.
        part = np.argpartition(-sims, kk - 1)[:kk]
        order = part[np.argsort(-sims[part], kind="stable")]
        return candidates[order], sims[order].astype(np.float32)
98
+
99
+
100
class IVFStore:
    """Record store exposing put/get/delete/all/clear/search, with search served by an `IVFIndex`.

    Any mutation marks the store dirty; the index is rebuilt from scratch on the next search.
    `nprobe` is fixed at construction and forwarded to every index query.
    """

    def __init__(self, *, nlist: int | None = None, nprobe: int = 8, seed: int = 0) -> None:
        self._records: dict[str, MemoryRecord] = {}
        self._rows: list[MemoryRecord] = []  # row number in the fitted matrix -> record
        self._index: IVFIndex | None = None
        self._dirty = True
        self._nlist = nlist
        self.nprobe = nprobe
        self.seed = seed

    def put(self, record: MemoryRecord) -> None:
        """Insert or overwrite `record` under its id and invalidate the index."""
        self._records[record.id] = record
        self._dirty = True

    def get(self, record_id: str) -> MemoryRecord | None:
        """Return the record stored under `record_id`, or None."""
        return self._records.get(record_id)

    def delete(self, record_id: str) -> bool:
        """Remove a record; return whether it existed. Only a real removal invalidates the index."""
        removed = self._records.pop(record_id, None)
        if removed is None:
            return False
        self._dirty = True
        return True

    def all(self) -> list[MemoryRecord]:
        """Return every stored record as a new list."""
        return list(self._records.values())

    def clear(self) -> None:
        """Drop all records and the current index."""
        self._records.clear()
        self._index = None
        self._rows = []
        self._dirty = True

    def _ensure_index(self) -> None:
        """Rebuild the index over records that have an embedding, if anything changed."""
        if not self._dirty:
            return
        embedded = [rec for rec in self._records.values() if rec.embedding is not None]
        self._rows = embedded
        if embedded:
            index = IVFIndex(nlist=self._nlist, seed=self.seed)
            self._index = index.fit([rec.embedding for rec in embedded])
        else:
            self._index = None
        self._dirty = False

    def search(self, embedding: list[float], *, limit: int,
               predicate: Callable[[MemoryRecord], bool] | None = None
               ) -> list[tuple[float, MemoryRecord]]:
        """Return up to `limit` (score, record) pairs, best first.

        `predicate`, when given, is evaluated for every indexed record and passed to the index
        as a row mask.
        """
        self._ensure_index()
        index = self._index
        if index is None:
            return []
        mask = None
        if predicate is not None:
            row_count = len(self._rows)
            mask = np.fromiter(map(predicate, self._rows), dtype=bool, count=row_count)
        rows, scores = index.search(embedding, k=limit, nprobe=self.nprobe, allowed=mask)
        return [(float(score), self._rows[row]) for row, score in zip(rows, scores)]
midas/bm25.py ADDED
@@ -0,0 +1,51 @@
1
+ """BM25 lexical ranking — pure-Python, zero-dep.
2
+
3
+ Fused with semantic recall (see `Memory.recall(hybrid=True)`) to catch lexically-relevant memories
4
+ the bi-encoder misses — a no-LLM lever to push retrieval past the embedding-only ceiling. Standard
5
+ Okapi BM25 with non-negative IDF; rebuilt per recall over the candidate set (cheap at eval scale).
6
+ """
7
+ from __future__ import annotations
8
+
9
+ import math
10
+ from collections import Counter
11
+
12
+ from .embeddings import tokenize
13
+ from .types import MemoryRecord
14
+
15
+ _K1 = 1.5
16
+ _B = 0.75
17
+
18
+
19
class BM25:
    """Okapi BM25 scorer over a fixed list of records (tokenized from `record.content`)."""

    def __init__(self, records: list[MemoryRecord]) -> None:
        self.records = records
        self._docs = [tokenize(record.content) for record in records]
        self._tfs = [Counter(tokens) for tokens in self._docs]
        self._dls = [len(tokens) for tokens in self._docs]
        self._avgdl = sum(self._dls) / len(self._dls) if self._dls else 0.0
        # Document frequency: each term counts once per document it appears in.
        doc_freq: Counter = Counter()
        for tokens in self._docs:
            doc_freq.update(set(tokens))
        total = len(records)
        # Smoothed IDF, log(1 + (N - df + 0.5) / (df + 0.5)); the "1 +" keeps it non-negative so
        # common terms never push scores below zero.
        self._idf = {
            term: math.log(1 + (total - count + 0.5) / (count + 0.5))
            for term, count in doc_freq.items()
        }

    def scores(self, query: str) -> dict[str, float]:
        """Map record id -> BM25 score (only records with a positive score are included)."""
        query_terms = tokenize(query)
        avgdl = self._avgdl or 1.0  # avoid dividing by zero on an empty/all-empty corpus
        result: dict[str, float] = {}
        for record, tf, dl in zip(self.records, self._tfs, self._dls):
            length_norm = _K1 * (1 - _B + _B * dl / avgdl)
            score = 0.0
            for term in query_terms:
                freq = tf[term]  # Counter yields 0 for absent terms
                if freq:
                    idf = self._idf.get(term, 0.0)
                    score += idf * (freq * (_K1 + 1)) / (freq + length_norm)
            if score > 0.0:
                result[record.id] = score
        return result
midas/embeddings.py ADDED
@@ -0,0 +1,334 @@
1
+ """Pluggable embeddings.
2
+
3
+ `HashingEmbedder` is a dependency-free, deterministic, offline stand-in so the
4
+ whole harness runs anywhere with zero setup. It is essentially a vectorized
5
+ bag-of-words — good enough to exercise the pipeline, NOT a real semantic model.
6
+ Swap in `OpenAIEmbedder` (or a local BGE/e5) for genuine semantic retrieval;
7
+ that swap is the entire reason embeddings live behind a Protocol.
8
+ """
9
+ from __future__ import annotations
10
+
11
+ import hashlib
12
+ import math
13
+ import os
14
+ import re
15
+ import sqlite3
16
+ import struct
17
+ import threading
18
+ from pathlib import Path
19
+ from typing import Protocol, runtime_checkable
20
+
21
_WORD = re.compile(r"\w+", re.UNICODE)


def tokenize(text: str) -> list[str]:
    """Split `text` into lowercase `\\w+` runs and keep those with at least three characters."""
    lowered = text.lower()
    return [word for word in _WORD.findall(lowered) if len(word) >= 3]
27
+
28
+
29
@runtime_checkable
class Embedder(Protocol):
    """Structural interface every embedding backend in this module implements."""

    dim: int  # declared length of the vectors the embedder produces

    def embed(self, text: str) -> list[float]:
        """Return the embedding of a single text."""
        ...

    def embed_many(self, texts: list[str]) -> list[list[float]]:
        """Return one embedding per input text, in input order."""
        ...
36
+
37
+
38
class HashingEmbedder:
    """Deterministic, offline embedder: signed bag-of-words hashed into `dim` buckets.

    Every token from `tokenize` is hashed with MD5; the first 8 digest bytes select a bucket
    (`h % dim`) and bit 8 of the same integer selects the sign. The bucket vector is then
    L2-normalized.
    """

    def __init__(self, dim: int = 256) -> None:
        self.dim = dim

    def embed(self, text: str) -> list[float]:
        """Return the normalized hashed bag-of-words vector for `text`."""
        buckets = [0.0] * self.dim
        for token in tokenize(text):
            h = int.from_bytes(hashlib.md5(token.encode("utf-8")).digest()[:8], "big")
            # NOTE(review): the sign bit (bit 8) is independent of the bucket for the default
            # dim=256, but for power-of-two dims above 256 it is also part of `h % dim`.
            if (h >> 8) & 1:
                buckets[h % self.dim] += 1.0
            else:
                buckets[h % self.dim] -= 1.0
        return l2_normalize(buckets)

    def embed_many(self, texts: list[str]) -> list[list[float]]:
        """Embed each text independently, preserving order."""
        return list(map(self.embed, texts))
56
+
57
+
58
class OpenAIEmbedder:
    """Real semantic embeddings via OpenAI. Requires `OPENAI_API_KEY` and the
    `openai` package (`uv pip install openai`).

    `dim` is derived from the model name for the known OpenAI embedding models and falls back
    to 1536 for anything else, so wrappers that trust `dim` see the real vector length for
    `text-embedding-3-large` as well.
    """

    # Native output sizes of the OpenAI embedding models.
    _KNOWN_DIMS = {
        "text-embedding-3-small": 1536,
        "text-embedding-3-large": 3072,
        "text-embedding-ada-002": 1536,
    }
    _DEFAULT_DIM = 1536

    def __init__(self, model: str = "text-embedding-3-small") -> None:
        from openai import OpenAI  # lazy import — optional dependency

        self._client = OpenAI()
        self.model = model
        self.dim = self._dim_for(model)

    @classmethod
    def _dim_for(cls, model: str) -> int:
        """Return the declared vector length for `model` (1536 when the model is unknown)."""
        return cls._KNOWN_DIMS.get(model, cls._DEFAULT_DIM)

    def embed(self, text: str) -> list[float]:
        """Embed one text with a single API call and L2-normalize the result."""
        resp = self._client.embeddings.create(model=self.model, input=text)
        return l2_normalize(list(resp.data[0].embedding))

    def embed_many(self, texts: list[str]) -> list[list[float]]:
        """Embed all texts in one API call; an empty input makes no call and returns []."""
        if not texts:
            return []
        resp = self._client.embeddings.create(model=self.model, input=texts)
        return [l2_normalize(list(item.embedding)) for item in resp.data]
78
+
79
+
80
def _default_embedding_cache_path() -> Path:
    """Return the default SQLite cache file, located directly under the default cache root."""
    cache_root = _default_cache_root()
    return cache_root.joinpath("midas-embeddings.sqlite3")
82
+
83
+
84
def _cache_namespace(embedder: object) -> str:
    """Build a cache namespace string that identifies an embedder's configuration.

    The namespace is the embedder's fully qualified class name, its `dim` (or "unknown"), and
    every identifying attribute that is present and not None, joined with "|". Both
    `model_name` (used by `LocalEmbedder`) and `model` (used by `OpenAIEmbedder`) are included
    so that two different models never share cached vectors.
    """
    parts = [
        f"{type(embedder).__module__}.{type(embedder).__qualname__}",
        f"dim={getattr(embedder, 'dim', 'unknown')}",
    ]
    for attr in ("model_name", "model", "max_text_chars"):
        value = getattr(embedder, attr, None)
        if value is not None:
            parts.append(f"{attr}={value}")
    return "|".join(parts)
94
+
95
+
96
def _encode_vector(vec: list[float]) -> bytes:
    """Serialize `vec` as consecutive little-endian 32-bit floats."""
    layout = f"<{len(vec)}f"
    return struct.pack(layout, *map(float, vec))
98
+
99
+
100
def _decode_vector(blob: bytes, *, dim: int) -> list[float]:
    """Deserialize `dim` little-endian 32-bit floats from `blob`.

    Raises:
        ValueError: if `blob` is not exactly `dim * 4` bytes long.
    """
    expected_size = dim * 4
    if len(blob) != expected_size:
        raise ValueError("cached embedding has an unexpected dimension")
    layout = f"<{dim}f"
    return list(map(float, struct.unpack(layout, blob)))
104
+
105
+
106
class DiskCachedEmbedder:
    """Persistent SQLite cache wrapper for expensive embedders.

    The key includes the embedder namespace and exact input text hash, so changing the
    local model or truncation limit gets a separate cache namespace by default.

    The wrapper owns one SQLite connection guarded by a lock. Call `close()` (or use the
    instance as a context manager) to release it.

    Attributes:
        inner: the wrapped embedder.
        dim: copied from `inner.dim`; cached rows with a different `dim` are ignored.
        namespace: cache partition key.
        path: location of the SQLite file.
        hits / misses: running counters updated by `embed_many`.
    """

    def __init__(
        self,
        inner: Embedder,
        *,
        path: str | Path | None = None,
        namespace: str | None = None,
    ) -> None:
        self.inner = inner
        self.dim = inner.dim
        self.namespace = namespace or _cache_namespace(inner)
        self.path = Path(path) if path is not None else _default_embedding_cache_path()
        self.path.parent.mkdir(parents=True, exist_ok=True)
        self._lock = threading.Lock()
        self.hits = 0
        self.misses = 0
        self._conn = sqlite3.connect(str(self.path), check_same_thread=False)
        try:
            with self._lock:
                self._conn.execute("PRAGMA journal_mode=WAL")
                # strftime('%s','now') instead of unixepoch(): the latter only exists in
                # SQLite >= 3.38, and older bundled SQLite builds reject the schema.
                self._conn.execute(
                    """
                    CREATE TABLE IF NOT EXISTS embeddings (
                        namespace TEXT NOT NULL,
                        text_sha256 TEXT NOT NULL,
                        dim INTEGER NOT NULL,
                        embedding BLOB NOT NULL,
                        created_at INTEGER NOT NULL DEFAULT (CAST(strftime('%s','now') AS INTEGER)),
                        PRIMARY KEY (namespace, text_sha256)
                    )
                    """
                )
                self._conn.commit()
        except Exception:
            # Do not leak the connection when the schema cannot be created.
            self._conn.close()
            raise

    def close(self) -> None:
        """Close the underlying SQLite connection."""
        with self._lock:
            self._conn.close()

    def __enter__(self) -> "DiskCachedEmbedder":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.close()

    def embed(self, text: str) -> list[float]:
        """Embed one text through the cache."""
        return self.embed_many([text])[0]

    def embed_many(self, texts: list[str]) -> list[list[float]]:
        """Return one embedding per text, computing and storing only the uncached ones.

        Duplicate texts in one call are sent to the inner embedder once.

        Raises:
            ValueError: if the inner embedder returns a different number of vectors than
                the number of texts it was given.
        """
        if not texts:
            return []

        keys = [hashlib.sha256(text.encode("utf-8")).hexdigest() for text in texts]
        cached = self._read_many(set(keys))
        missing: dict[str, str] = {}  # key -> text, de-duplicated, in first-seen order
        for key, text in zip(keys, texts):
            if key not in cached and key not in missing:
                missing[key] = text

        # hits counts input texts served from disk; misses counts distinct texts computed.
        self.hits += sum(1 for key in keys if key in cached)
        self.misses += len(missing)

        if missing:
            missing_keys = list(missing)
            missing_texts = [missing[key] for key in missing_keys]
            embed_many = getattr(self.inner, "embed_many", None)
            vectors = (
                embed_many(missing_texts)
                if callable(embed_many)
                else [self.inner.embed(text) for text in missing_texts]
            )
            if len(vectors) != len(missing_keys):
                raise ValueError("cached embedder inner returned the wrong number of embeddings")
            self._write_many(zip(missing_keys, vectors))
            cached.update(zip(missing_keys, vectors))

        return [cached[key] for key in keys]

    def _read_many(self, keys: set[str]) -> dict[str, list[float]]:
        """Fetch cached vectors for `keys` in this namespace; rows with another dim are skipped."""
        if not keys:
            return {}
        rows: dict[str, list[float]] = {}
        ordered = list(keys)
        with self._lock:
            # Query in chunks of 500 keys to keep the number of bound parameters small.
            for i in range(0, len(ordered), 500):
                chunk = ordered[i : i + 500]
                placeholders = ",".join("?" for _ in chunk)
                params = [self.namespace, *chunk]
                cur = self._conn.execute(
                    f"""
                    SELECT text_sha256, dim, embedding
                    FROM embeddings
                    WHERE namespace = ? AND text_sha256 IN ({placeholders})
                    """,
                    params,
                )
                for key, dim, blob in cur.fetchall():
                    if dim == self.dim:
                        rows[key] = _decode_vector(blob, dim=self.dim)
        return rows

    def _write_many(self, items) -> None:
        """Insert or replace `(key, vector)` pairs for this namespace and commit."""
        rows = [
            (self.namespace, key, self.dim, _encode_vector(vec))
            for key, vec in items
        ]
        if not rows:
            return
        with self._lock:
            self._conn.executemany(
                """
                INSERT OR REPLACE INTO embeddings(namespace, text_sha256, dim, embedding)
                VALUES (?, ?, ?, ?)
                """,
                rows,
            )
            self._conn.commit()
217
+
218
+
219
def _default_cache_root() -> Path:
    """Return the directory under which Midas keeps its caches.

    Resolution order: the `MIDAS_CACHE_ROOT` environment variable when set and non-empty;
    `D:/hf-cache` on Windows when a `D:/` drive exists; otherwise `~/.cache/midas`.
    """
    override = os.getenv("MIDAS_CACHE_ROOT")
    if override:
        return Path(override)
    # NOTE(review): the Windows branch hard-codes D:/hf-cache whenever D:/ exists; confirm
    # this is intended for every user rather than one development machine.
    has_windows_d_drive = os.name == "nt" and Path("D:/").exists()
    return Path("D:/hf-cache") if has_windows_d_drive else Path.home() / ".cache" / "midas"
225
+
226
+
227
def _default_tmp_dir() -> Path:
    """Return the scratch directory for temporary files.

    Resolution order: the `MIDAS_TMP_DIR` environment variable when set and non-empty;
    `D:/tmp` on Windows when a `D:/` drive exists; otherwise `tmp` under the default cache root.
    """
    configured = os.getenv("MIDAS_TMP_DIR")
    if configured:
        return Path(configured)
    if os.name == "nt" and Path("D:/").exists():
        return Path("D:/tmp")
    return _default_cache_root().joinpath("tmp")
233
+
234
+
235
def configure_local_model_cache(cache_dir: str | Path | None = None) -> str:
    """Prepare cache directories and environment defaults for local (fastembed) models.

    Creates the model cache directory (`cache_dir`, or `fastembed` under the default cache
    root) and the temp directory, then fills in `HF_HOME`, `HF_HUB_CACHE`, `TMP`, `TEMP` and
    `HF_HUB_DISABLE_SYMLINKS_WARNING` only where the process has not set them already.

    Returns:
        The model cache directory as a string.
    """
    cache_root = _default_cache_root()
    if cache_dir is not None:
        root = Path(cache_dir)
    else:
        root = cache_root / "fastembed"
    tmp = _default_tmp_dir()
    for directory in (root, tmp):
        directory.mkdir(parents=True, exist_ok=True)

    env_defaults = {
        "HF_HOME": str(cache_root),
        "HF_HUB_CACHE": str(cache_root / "hub"),
        "TMP": str(tmp),
        "TEMP": str(tmp),
        "HF_HUB_DISABLE_SYMLINKS_WARNING": "1",
    }
    for name, value in env_defaults.items():
        os.environ.setdefault(name, value)
    return str(root)
249
+
250
+
251
class LocalEmbedder:
    """Real local semantic embeddings via fastembed (ONNX — no API key, no torch).
    First use downloads a small model (~130 MB). Install: `uv pip install fastembed`.

    `dim` reflects the chosen model: it comes from a table of known models and, for any other
    model, from the length of one probe embedding, so it is not fixed to the default model's
    768 dimensions.

    Args:
        model_name: fastembed model identifier.
        cache_dir: model cache directory; resolved by `configure_local_model_cache`.
        batch_size: batch size passed to fastembed in `embed_many`.
        max_text_chars: inputs are truncated to this many characters (default 2000).
    """

    _MAX_TEXT_CHARS = 2000
    # Output sizes of the BGE English v1.5 family; other models are probed at construction.
    _KNOWN_DIMS = {
        "BAAI/bge-small-en-v1.5": 384,
        "BAAI/bge-base-en-v1.5": 768,
        "BAAI/bge-large-en-v1.5": 1024,
    }

    def __init__(
        self,
        model_name: str = "BAAI/bge-base-en-v1.5",
        *,
        cache_dir: str | Path | None = None,
        batch_size: int = 32,
        max_text_chars: int | None = None,
    ) -> None:
        self.cache_dir = configure_local_model_cache(cache_dir)
        from fastembed import TextEmbedding  # lazy import — optional dependency

        self._model = TextEmbedding(model_name=model_name, cache_dir=self.cache_dir)
        self.model_name = model_name
        self.batch_size = batch_size
        self.max_text_chars = max_text_chars or self._MAX_TEXT_CHARS
        known_dim = self._KNOWN_DIMS.get(model_name)
        self.dim = known_dim if known_dim is not None else self._probe_dim()

    def _probe_dim(self) -> int:
        """Return the model's vector length by embedding one short probe string."""
        return len(next(iter(self._model.embed(["dimension probe"], batch_size=1))))

    def embed(self, text: str) -> list[float]:
        """Embed one (truncated) text and L2-normalize it."""
        vec = next(iter(self._model.embed([text[: self.max_text_chars]], batch_size=1)))
        return l2_normalize([float(x) for x in vec])

    def embed_many(self, texts: list[str]) -> list[list[float]]:
        """Embed all (truncated) texts in batches; returns [] for an empty input."""
        if not texts:
            return []
        capped = [text[: self.max_text_chars] for text in texts]
        return [
            l2_normalize([float(x) for x in vec])
            for vec in self._model.embed(capped, batch_size=self.batch_size)
        ]
286
+
287
+
288
class LocalReranker:
    """Local cross-encoder reranker (fastembed/ONNX, no API key, no torch). Reorders
    candidate documents by relevance to the query — sharper precision than bi-encoder
    cosine. First use downloads a small model. Used by Midas to refine the recall pool."""

    # Cross-encoders cap the query+doc PAIR at ~512 tokens; oversized input crashes the ONNX
    # runtime (seen on long real turns). Cap defensively — affects only the rerank score, not
    # the stored/assembled content.
    _MAX_QUERY_CHARS = 400
    _MAX_DOC_CHARS = 1200

    def __init__(
        self,
        model_name: str = "Xenova/ms-marco-MiniLM-L-6-v2",
        *,
        cache_dir: str | Path | None = None,
    ) -> None:
        self.cache_dir = configure_local_model_cache(cache_dir)
        from fastembed.rerank.cross_encoder import TextCrossEncoder  # lazy, optional

        self._model = TextCrossEncoder(model_name=model_name, cache_dir=self.cache_dir)
        self.model_name = model_name

    def rerank(self, query: str, documents: list[str]) -> list[float]:
        """Return a relevance score per document (higher = more relevant), in order.

        Query and documents are length-capped to stay within the cross-encoder context. On any
        runtime error the method still degrades to neutral scores (all 0.0) rather than
        crashing the whole query, but the failure is logged so that a broken reranker does not
        go unnoticed.
        """
        if not documents:
            return []
        import logging  # local import: only this method needs it

        q = query[: self._MAX_QUERY_CHARS]
        docs = [d[: self._MAX_DOC_CHARS] for d in documents]
        try:
            return [float(s) for s in self._model.rerank(q, docs)]
        except Exception:
            logging.getLogger(__name__).warning(
                "rerank failed for %d documents; returning neutral scores",
                len(docs),
                exc_info=True,
            )
            return [0.0] * len(docs)
323
+
324
+
325
def l2_normalize(vec: list[float]) -> list[float]:
    """Return `vec` scaled to unit Euclidean length.

    A zero vector cannot be scaled and is returned unchanged in value. The result is always
    a new list, so mutating it never affects the caller's input.
    """
    norm = math.sqrt(sum(v * v for v in vec))
    if norm == 0.0:
        return list(vec)
    return [v / norm for v in vec]
330
+
331
+
332
def cosine(a: list[float], b: list[float]) -> float:
    """Dot product of `a` and `b`, clamped to [-1, 1].

    This equals cosine similarity when both inputs are L2-normalized. The vectors are paired
    with `zip`, so extra trailing components of the longer one are ignored.
    """
    dot = sum(x * y for x, y in zip(a, b))
    capped = min(1.0, dot)
    return max(-1.0, capped)
midas/entity.py ADDED
@@ -0,0 +1,51 @@
1
+ """No-LLM entity-grounded abstention — a candidate lever for the *Calibrated* frontier.
2
+
3
+ Diagnosed root cause (see docs §5/§7): the reader confabulates an answer drawn from a retrieved
4
+ distractor that is about a DIFFERENT entity than the question asks about (e.g. Q asks the *hamster's*
5
+ name, the retrieved turn is "I have a *cat* named Luna" → the reader answers "Luna"). Relevance, cosine
6
+ and NLI-entailment all fail here because that distractor genuinely *entails* the wrong answer.
7
+
8
+ This checks something orthogonal: is the answer's **source turn about the entity the question asks
9
+ about?** We pull the question's focus nouns and test whether the source mentions them. Pure regex, no
10
+ LLM. This module is **measured offline** before any end-to-end claim (the real test needs a capable
11
+ reader; the local 1B model exhibits hallucination, not confab-from-distractor — see the benchmark notes).
12
+ """
13
+ from __future__ import annotations
14
+
15
+ import re
16
+
17
# Generic scaffolding that is NOT the entity a question is about: question words, copulas, and the
# attribute/relation words that recur across confab pairs ("favorite", "name") and so must not count
# as a shared entity.
_STOP = {
    "what", "which", "who", "whom", "whose", "where", "when", "why", "how", "is", "are", "was", "were",
    "the", "a", "an", "of", "to", "for", "in", "on", "at", "by", "with", "and", "or", "but", "do",
    "does", "did", "have", "has", "had", "his", "her", "their", "my", "your", "our", "its", "this",
    "that", "user", "users", "you", "i", "we", "they", "he", "she", "it", "name", "named", "call",
    "called", "favorite", "favourite", "kind", "type", "about", "tell", "me", "please", "they're",
}
# A word starts AND ends with a letter; apostrophes/hyphens are only allowed inside it. Requiring a
# final letter keeps closing quotes and plural-possessive apostrophes out of the token, so
# "'Luna'" yields "luna" and "dogs'" yields "dogs".
_WORD = re.compile(r"[A-Za-z][A-Za-z'-]*[A-Za-z]")


def entities(text: str) -> set[str]:
    """Salient content words (non-stopword, >=3 chars), lowercased — a turn's rough entity fingerprint.

    Possessives are normalised to the base noun ("user's" -> "user") so scaffolding is filtered,
    and punctuation left dangling at the end of a word is trimmed so quoted or plural-possessive
    mentions compare equal to the bare word. `None` or empty text yields an empty set.
    """
    out: set[str] = set()
    for m in _WORD.finditer(text or ""):
        w = m.group().lower()
        if w.endswith("'s"):
            w = w[:-2]  # possessive -> base noun
        w = w.rstrip("'-")  # e.g. "x-'s" -> "x-" -> "x"
        if len(w) >= 3 and w not in _STOP:
            out.add(w)
    return out
41
+
42
+
43
def question_focus(question: str) -> set[str]:
    """Return the words a question is about: its `entities()` fingerprint, unchanged."""
    focus = entities(question)
    return focus
46
+
47
+
48
def entity_grounded(question: str, source: str, *, min_overlap: int = 1) -> bool:
    """Report whether `source` mentions what `question` asks about.

    True when the question's focus words and the source's salient words have at least
    `min_overlap` words in common; a source about a different entity shares none of them.
    """
    focus = question_focus(question)
    mentioned = entities(source)
    shared = focus.intersection(mentioned)
    return len(shared) >= min_overlap