midas-memory 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- midas/__init__.py +64 -0
- midas/ann.py +156 -0
- midas/bm25.py +51 -0
- midas/embeddings.py +334 -0
- midas/entity.py +51 -0
- midas/importance.py +146 -0
- midas/integrations/__init__.py +2 -0
- midas/integrations/langgraph_store.py +150 -0
- midas/mcp_server.py +220 -0
- midas/memory.py +1028 -0
- midas/nli.py +90 -0
- midas/policy.py +70 -0
- midas/py.typed +0 -0
- midas/sqlite_store.py +120 -0
- midas/store.py +114 -0
- midas/types.py +47 -0
- midas_memory-0.0.1.dist-info/METADATA +343 -0
- midas_memory-0.0.1.dist-info/RECORD +21 -0
- midas_memory-0.0.1.dist-info/WHEEL +4 -0
- midas_memory-0.0.1.dist-info/entry_points.txt +2 -0
- midas_memory-0.0.1.dist-info/licenses/LICENSE +21 -0
midas/__init__.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
"""Midas — agentic memory.
|
|
2
|
+
|
|
3
|
+
Open-core wedge: semantic recall + budgeted context assembly behind a small,
|
|
4
|
+
store/embedder-agnostic API. This is the thing the eval harness benchmarks.
|
|
5
|
+
"""
|
|
6
|
+
from .embeddings import (
|
|
7
|
+
DiskCachedEmbedder,
|
|
8
|
+
Embedder,
|
|
9
|
+
HashingEmbedder,
|
|
10
|
+
LocalEmbedder,
|
|
11
|
+
LocalReranker,
|
|
12
|
+
OpenAIEmbedder,
|
|
13
|
+
configure_local_model_cache,
|
|
14
|
+
cosine,
|
|
15
|
+
tokenize,
|
|
16
|
+
)
|
|
17
|
+
from .importance import ContentImportance, StructuralImportance
|
|
18
|
+
from .memory import CaptureResult, ContextBlock, Memory, Reranker, approx_tokens, format_record
|
|
19
|
+
from .policy import AGENT_MEMORY_INSTRUCTIONS, DEFAULT_POLICY, MemoryPolicy, policy_summary
|
|
20
|
+
from .store import InMemoryStore
|
|
21
|
+
from .types import MEMORY_KINDS, MemoryKind, MemoryRecord, RecallHit
|
|
22
|
+
|
|
23
|
+
try:
|
|
24
|
+
from .sqlite_store import SQLiteStore
|
|
25
|
+
except ImportError:
|
|
26
|
+
SQLiteStore = None # sqlite-vec not installed
|
|
27
|
+
|
|
28
|
+
try:
|
|
29
|
+
from .ann import IVFIndex, IVFStore
|
|
30
|
+
except ImportError:
|
|
31
|
+
IVFIndex = IVFStore = None # numpy not installed (ANN backend is optional)
|
|
32
|
+
|
|
33
|
+
__all__ = [
    # core memory API and importance scoring
    "Memory",
    "ContentImportance", "StructuralImportance",
    # agent-facing policy
    "MemoryPolicy", "DEFAULT_POLICY", "AGENT_MEMORY_INSTRUCTIONS", "policy_summary",
    # recall / context-assembly helpers
    "CaptureResult", "ContextBlock", "Reranker", "approx_tokens", "format_record",
    # record types
    "MemoryRecord", "RecallHit", "MemoryKind", "MEMORY_KINDS",
    # embeddings and rerankers
    "Embedder", "DiskCachedEmbedder", "HashingEmbedder", "LocalEmbedder", "LocalReranker",
    "OpenAIEmbedder", "configure_local_model_cache", "cosine", "tokenize",
    # stores; SQLiteStore / IVFIndex / IVFStore are bound to None when their optional imports fail
    "InMemoryStore", "SQLiteStore", "IVFIndex", "IVFStore",
]
__version__ = "0.0.1"
|
midas/ann.py
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
1
|
+
"""Approximate nearest-neighbour search for scaling Midas past the exact in-memory scan.
|
|
2
|
+
|
|
3
|
+
`InMemoryStore`'s cached scan is exact but O(N) per query (~230 ms at 1M x 768-d). `IVFIndex` is an
|
|
4
|
+
**inverted-file** index built with **numpy only** (no native dependency, unlike faiss/hnswlib): it
|
|
5
|
+
clusters the corpus into `nlist` cells via k-means, and a query then compares against only the
|
|
6
|
+
`nprobe` nearest cells instead of the whole corpus. That makes search **sub-linear** -- at 1M with
|
|
7
|
+
`nlist=1000, nprobe=8` a query scans ~8K vectors instead of 1M (~100x fewer) -- trading a little
|
|
8
|
+
recall for a large speedup. `nprobe` tunes the recall/latency trade-off at query time (no rebuild).
|
|
9
|
+
|
|
10
|
+
`IVFStore` wraps the index behind the same store surface as `InMemoryStore`, so `Memory` can use it
|
|
11
|
+
unchanged. The index is built lazily and rebuilt on mutation, so it suits **read-heavy / batch-loaded**
|
|
12
|
+
corpora (build once, query many); for write-heavy workloads keep `InMemoryStore`.
|
|
13
|
+
"""
|
|
14
|
+
from __future__ import annotations
|
|
15
|
+
|
|
16
|
+
from typing import Callable, Sequence
|
|
17
|
+
|
|
18
|
+
import numpy as np
|
|
19
|
+
|
|
20
|
+
from .types import MemoryRecord
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def _kmeans(x: np.ndarray, k: int, *, n_iter: int, rng: np.random.Generator) -> np.ndarray:
    """Spherical k-means (cosine): centroids are L2-normalized, assignment is by max dot product.

    Inputs are assumed L2-normalized (Midas embeddings are), so dot product == cosine similarity.

    Args:
        x: (n, d) training matrix.
        k: number of centroids; must satisfy ``1 <= k <= n`` (seeds are k distinct rows of x).
        n_iter: number of assign/update rounds.
        rng: generator used for seeding and for reseeding empty cells.

    Returns:
        (k, d) array of centroids.

    Raises:
        ValueError: if ``k`` is outside ``[1, n]``.
    """
    n = x.shape[0]
    if not 1 <= k <= n:
        raise ValueError(f"k-means needs 1 <= k <= n_points, got k={k}, n_points={n}")
    # Init from k distinct points.
    centroids = x[rng.choice(n, k, replace=False)].copy()
    for _ in range(n_iter):
        assign = np.argmax(x @ centroids.T, axis=1)
        for c in range(k):
            members = x[assign == c]
            if len(members):
                v = members.sum(axis=0)
                norm = float(np.linalg.norm(v))
                if norm > 0.0:
                    centroids[c] = v / norm
            else:  # empty cell -> reseed from a random point so it can pick up members next round
                centroids[c] = x[rng.integers(n)]
    return centroids


class IVFIndex:
    """Inverted-file ANN index over L2-normalized vectors (cosine similarity).

    ``nlist`` is the number of k-means cells. ``None`` means ``round(sqrt(N))`` at `fit` time.
    `fit` writes the effective cell count back to ``self.nlist``, so a later `fit` on a corpus
    of a different size reuses it; set ``nlist = None`` first to re-derive it.
    """

    def __init__(self, nlist: int | None = None, *, n_iter: int = 10, train_sample: int = 50_000,
                 seed: int = 0) -> None:
        self.nlist = nlist
        self.n_iter = n_iter
        self.train_sample = train_sample
        self.seed = seed
        self._centroids: np.ndarray | None = None
        self._lists: list[np.ndarray] = []  # cell -> row indices into _vectors
        self._vectors: np.ndarray | None = None  # (N, d) float32, L2-normalized

    def fit(self, vectors: Sequence[Sequence[float]] | np.ndarray) -> "IVFIndex":
        """Cluster *vectors* into cells and index every row; returns ``self``.

        Rows are used as given (not re-normalized). When N exceeds ``train_sample``, centroids
        are trained on a random sample of that many rows, and then all N rows are assigned.

        Raises:
            ValueError: if *vectors* is not a non-empty 2-D matrix.
        """
        x = np.asarray(vectors, dtype=np.float32)
        if x.ndim != 2 or x.shape[0] == 0:
            raise ValueError("fit() needs a non-empty (N, d) matrix")
        n = x.shape[0]
        rng = np.random.default_rng(self.seed)
        # Train centroids on a sample (k-means cost is O(sample x nlist x d x iters)); assign all.
        train = x if n <= self.train_sample else x[rng.choice(n, self.train_sample, replace=False)]
        nlist = self.nlist or max(1, int(round(np.sqrt(n))))
        # k-means seeds from distinct training rows, so it cannot have more cells than
        # training points (clamping to N alone crashed when nlist > train_sample < N).
        nlist = min(nlist, train.shape[0])
        self.nlist = nlist
        centroids = _kmeans(train, nlist, n_iter=self.n_iter, rng=rng)
        assign = np.argmax(x @ centroids.T, axis=1)
        self._centroids = centroids
        self._lists = [np.where(assign == c)[0] for c in range(nlist)]
        self._vectors = x
        return self

    def search(self, query: Sequence[float], *, k: int = 10, nprobe: int = 8,
               allowed: np.ndarray | None = None) -> tuple[np.ndarray, np.ndarray]:
        """Return (row indices, scores) for the top-k, highest cosine first.

        Only the ``nprobe`` cells whose centroids are closest to *query* are scanned. ``nprobe``
        is clamped to ``[1, number of cells]``. *query* is assumed L2-normalized like the rows.
        `allowed` is an optional boolean mask over rows (predicate pushdown); candidates failing it
        are dropped before the top-k cut. An unfitted index or ``k <= 0`` yields empty arrays.
        """
        if self._vectors is None or self._centroids is None or k <= 0:
            return np.array([], dtype=int), np.array([], dtype=np.float32)
        q = np.asarray(query, dtype=np.float32)
        nprobe = max(1, min(nprobe, len(self._lists)))
        csims = self._centroids @ q
        cells = np.argpartition(-csims, nprobe - 1)[:nprobe]
        candidates = np.concatenate([self._lists[c] for c in cells])
        if candidates.size and allowed is not None:
            candidates = candidates[allowed[candidates]]
        if candidates.size == 0:
            return np.array([], dtype=int), np.array([], dtype=np.float32)
        sims = self._vectors[candidates] @ q
        kk = min(k, candidates.size)
        part = np.argpartition(-sims, kk - 1)[:kk]
        order = part[np.argsort(-sims[part], kind="stable")]
        return candidates[order], sims[order].astype(np.float32)
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
class IVFStore:
    """Store with the `InMemoryStore` surface, backed by an `IVFIndex` for sub-linear search.

    Any `put`/`delete`/`clear` marks the index stale. The next `search` rebuilds it from every
    record that has an embedding. `nprobe` (the recall/latency knob) is fixed at construction.
    Suited to read-heavy corpora; write-heavy workloads should stay on `InMemoryStore`.
    """

    def __init__(self, *, nlist: int | None = None, nprobe: int = 8, seed: int = 0) -> None:
        self._records: dict[str, MemoryRecord] = {}
        self._nlist = nlist
        self.nprobe = nprobe
        self.seed = seed
        self._index: IVFIndex | None = None
        # Row i of the fitted matrix corresponds to self._rows[i].
        self._rows: list[MemoryRecord] = []
        self._dirty = True

    def put(self, record: MemoryRecord) -> None:
        """Insert or replace *record* by id and mark the index stale."""
        self._records[record.id] = record
        self._dirty = True

    def get(self, record_id: str) -> MemoryRecord | None:
        """Return the record with *record_id*, or None."""
        return self._records.get(record_id)

    def delete(self, record_id: str) -> bool:
        """Remove *record_id*; True if it existed (only then is the index marked stale)."""
        removed = self._records.pop(record_id, None)
        if removed is None:
            return False
        self._dirty = True
        return True

    def all(self) -> list[MemoryRecord]:
        """Snapshot list of every stored record."""
        return [*self._records.values()]

    def clear(self) -> None:
        """Drop all records and the fitted index."""
        self._records.clear()
        self._index = None
        self._rows = []
        self._dirty = True

    def _ensure_index(self) -> None:
        """Refit the IVF index from embedded records if anything changed since the last fit."""
        if not self._dirty:
            return
        embedded = [rec for rec in self._records.values() if rec.embedding is not None]
        self._rows = embedded
        if embedded:
            fresh = IVFIndex(nlist=self._nlist, seed=self.seed)
            self._index = fresh.fit([rec.embedding for rec in embedded])
        else:
            self._index = None
        self._dirty = False

    def search(self, embedding: list[float], *, limit: int,
               predicate: Callable[[MemoryRecord], bool] | None = None
               ) -> list[tuple[float, MemoryRecord]]:
        """Return up to *limit* ``(score, record)`` pairs, best first.

        *predicate* is evaluated on every indexed record to build a boolean mask before the scan.
        """
        self._ensure_index()
        index = self._index
        if index is None:
            return []
        mask = None
        if predicate is not None:
            mask = np.fromiter(map(predicate, self._rows), dtype=bool, count=len(self._rows))
        hit_rows, hit_scores = index.search(embedding, k=limit, nprobe=self.nprobe, allowed=mask)
        return [(float(score), self._rows[row]) for row, score in zip(hit_rows, hit_scores)]
|
midas/bm25.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
"""BM25 lexical ranking — pure-Python, zero-dep.
|
|
2
|
+
|
|
3
|
+
Fused with semantic recall (see `Memory.recall(hybrid=True)`) to catch lexically-relevant memories
|
|
4
|
+
the bi-encoder misses — a no-LLM lever to push retrieval past the embedding-only ceiling. Standard
|
|
5
|
+
Okapi BM25 with non-negative IDF; rebuilt per recall over the candidate set (cheap at eval scale).
|
|
6
|
+
"""
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import math
|
|
10
|
+
from collections import Counter
|
|
11
|
+
|
|
12
|
+
from .embeddings import tokenize
|
|
13
|
+
from .types import MemoryRecord
|
|
14
|
+
|
|
15
|
+
_K1 = 1.5  # term-frequency saturation
_B = 0.75  # document-length normalisation strength


class BM25:
    """Okapi BM25 scorer over a fixed list of records.

    Each record's ``content`` is tokenized with `tokenize` once, at construction. `scores` then
    scores every record against a query. ``k1`` (term-frequency saturation) and ``b``
    (length normalisation) default to the module constants ``_K1`` / ``_B``.
    """

    def __init__(self, records: list[MemoryRecord], *, k1: float = _K1, b: float = _B) -> None:
        self.records = records
        self.k1 = k1
        self.b = b
        self._docs = [tokenize(r.content) for r in records]
        self._tfs = [Counter(d) for d in self._docs]
        self._dls = [len(d) for d in self._docs]
        self._avgdl = (sum(self._dls) / len(self._dls)) if self._dls else 0.0
        df: Counter = Counter()
        for doc in self._docs:
            df.update(set(doc))  # document frequency: count each term once per document
        n = len(records)
        # Lucene-style IDF, log(1 + (N - df + 0.5) / (df + 0.5)). This is strictly positive for
        # every term, so terms found in most documents add a small positive weight instead of the
        # negative weight classic Robertson-Sparck Jones IDF gives them. (This is not BM25+,
        # which instead adds a constant delta to the term-frequency component.)
        self._idf = {t: math.log(1 + (n - c + 0.5) / (c + 0.5)) for t, c in df.items()}

    def scores(self, query: str) -> dict[str, float]:
        """Map record id -> BM25 score (only records with a positive score are included).

        Query terms are tokenized like documents. A term repeated in the query contributes once
        per occurrence.
        """
        q_terms = tokenize(query)
        avgdl = self._avgdl or 1.0  # all documents empty -> avoid division by zero
        k1, b = self.k1, self.b
        out: dict[str, float] = {}
        for record, tf, dl in zip(self.records, self._tfs, self._dls):
            # Length-normalisation term is constant per document: compute it once, not per term.
            length_norm = k1 * (1 - b + b * dl / avgdl)
            s = 0.0
            for term in q_terms:
                f = tf.get(term, 0)
                if not f:
                    continue
                s += self._idf.get(term, 0.0) * (f * (k1 + 1)) / (f + length_norm)
            if s > 0.0:
                out[record.id] = s
        return out
|
midas/embeddings.py
ADDED
|
@@ -0,0 +1,334 @@
|
|
|
1
|
+
"""Pluggable embeddings.
|
|
2
|
+
|
|
3
|
+
`HashingEmbedder` is a dependency-free, deterministic, offline stand-in so the
|
|
4
|
+
whole harness runs anywhere with zero setup. It is essentially a vectorized
|
|
5
|
+
bag-of-words — good enough to exercise the pipeline, NOT a real semantic model.
|
|
6
|
+
Swap in `OpenAIEmbedder` (or a local BGE/e5) for genuine semantic retrieval;
|
|
7
|
+
that swap is the entire reason embeddings live behind a Protocol.
|
|
8
|
+
"""
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
import hashlib
|
|
12
|
+
import math
|
|
13
|
+
import os
|
|
14
|
+
import re
|
|
15
|
+
import sqlite3
|
|
16
|
+
import struct
|
|
17
|
+
import threading
|
|
18
|
+
from pathlib import Path
|
|
19
|
+
from typing import Protocol, runtime_checkable
|
|
20
|
+
|
|
21
|
+
_WORD = re.compile(r"\w+", re.UNICODE)


def tokenize(text: str) -> list[str]:
    """Split *text* into lowercase word tokens, dropping tokens of two characters or fewer.

    Word characters follow the Unicode-aware ``\\w`` class, so letters with diacritics, digits
    and underscores are all kept inside a token.
    """
    kept: list[str] = []
    for word in _WORD.findall(text.lower()):
        if len(word) > 2:
            kept.append(word)
    return kept
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@runtime_checkable
class Embedder(Protocol):
    """Structural interface for Midas embedders (checked with ``isinstance`` at runtime).

    ``dim`` is the declared vector length. ``embed`` maps one text to a vector. ``embed_many``
    maps a list of texts to a list of vectors; callers such as the disk cache pair its results
    with the inputs positionally, so one vector per text, in input order, is expected.
    """

    dim: int

    def embed(self, text: str) -> list[float]:
        """Embed a single text."""
        ...

    def embed_many(self, texts: list[str]) -> list[list[float]]:
        """Embed a batch of texts."""
        ...
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class HashingEmbedder:
    """Offline, deterministic bag-of-words hashed into a fixed-dim unit vector.

    Each token from `tokenize` is hashed with MD5. The first 8 digest bytes, read as a
    big-endian int ``h``, select the bucket ``h % dim`` and, via one higher bit of ``h``,
    a +1/-1 sign. The signed counts are then L2-normalized. Not a semantic model.
    """

    def __init__(self, dim: int = 256) -> None:
        self.dim = dim

    def embed(self, text: str) -> list[float]:
        """Embed *text* as a signed, hashed token-count vector of length ``dim`` (unit norm)."""
        dim = self.dim
        # The sign must not be decided by the bucket. `h % dim` fixes the low t bits of h,
        # where 2**t is the largest power of two dividing dim (`dim & -dim` isolates it).
        # So the sign is read from bit max(8, t). For dims not divisible by 512 (including the
        # default 256) this is bit 8. For 512, 1024, ... bit 8 would be part of the bucket
        # index, which would make every bucket's sign constant.
        sign_shift = max(8, (dim & -dim).bit_length() - 1)
        vec = [0.0] * dim
        for tok in tokenize(text):
            digest = hashlib.md5(tok.encode("utf-8")).digest()
            h = int.from_bytes(digest[:8], "big")
            vec[h % dim] += 1.0 if (h >> sign_shift) & 1 else -1.0
        return l2_normalize(vec)

    def embed_many(self, texts: list[str]) -> list[list[float]]:
        """Embed each text independently, preserving order."""
        return [self.embed(text) for text in texts]
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class OpenAIEmbedder:
    """Real semantic embeddings via OpenAI. Requires `OPENAI_API_KEY` and the
    `openai` package (`uv pip install openai`).

    ``dim`` is the requested ``dimensions`` when given. Otherwise it is the model's native size
    (1536 for unknown model names). Outputs are L2-normalized.
    """

    # Native output sizes of OpenAI embedding models.
    _MODEL_DIMS = {
        "text-embedding-3-small": 1536,
        "text-embedding-3-large": 3072,
        "text-embedding-ada-002": 1536,
    }
    # The embeddings endpoint rejects input arrays longer than 2048 items.
    _MAX_BATCH = 2048

    def __init__(self, model: str = "text-embedding-3-small", *,
                 dimensions: int | None = None) -> None:
        from openai import OpenAI  # lazy import — optional dependency

        self._client = OpenAI()
        self.model = model
        self.dimensions = dimensions
        # `dim` must match the vectors actually returned. DiskCachedEmbedder and the stores rely
        # on it, so it cannot be a fixed 1536 when text-embedding-3-large returns 3072.
        self.dim = dimensions or self._MODEL_DIMS.get(model, 1536)

    def _create(self, inputs: str | list[str]):
        """Call the embeddings endpoint, passing ``dimensions`` only when it was requested."""
        kwargs = {"model": self.model, "input": inputs}
        if self.dimensions is not None:
            kwargs["dimensions"] = self.dimensions
        return self._client.embeddings.create(**kwargs)

    def embed(self, text: str) -> list[float]:
        resp = self._create(text)
        return l2_normalize(list(resp.data[0].embedding))

    def embed_many(self, texts: list[str]) -> list[list[float]]:
        """Embed *texts* in request-sized chunks, returning vectors in input order."""
        if not texts:
            return []
        out: list[list[float]] = []
        for start in range(0, len(texts), self._MAX_BATCH):
            resp = self._create(texts[start : start + self._MAX_BATCH])
            # Each item carries its input position in `index`; sort defensively by it.
            for item in sorted(resp.data, key=lambda item: item.index):
                out.append(l2_normalize(list(item.embedding)))
        return out
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def _default_embedding_cache_path() -> Path:
    """Path of the shared embedding-cache SQLite file under the Midas cache root."""
    root = _default_cache_root()
    return root.joinpath("midas-embeddings.sqlite3")
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def _cache_namespace(embedder: object) -> str:
|
|
85
|
+
parts = [
|
|
86
|
+
f"{type(embedder).__module__}.{type(embedder).__qualname__}",
|
|
87
|
+
f"dim={getattr(embedder, 'dim', 'unknown')}",
|
|
88
|
+
]
|
|
89
|
+
for attr in ("model_name", "max_text_chars"):
|
|
90
|
+
value = getattr(embedder, attr, None)
|
|
91
|
+
if value is not None:
|
|
92
|
+
parts.append(f"{attr}={value}")
|
|
93
|
+
return "|".join(parts)
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def _encode_vector(vec: list[float]) -> bytes:
|
|
97
|
+
return struct.pack(f"<{len(vec)}f", *(float(v) for v in vec))
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
def _decode_vector(blob: bytes, *, dim: int) -> list[float]:
|
|
101
|
+
if len(blob) != dim * 4:
|
|
102
|
+
raise ValueError("cached embedding has an unexpected dimension")
|
|
103
|
+
return [float(v) for v in struct.unpack(f"<{dim}f", blob)]
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
class DiskCachedEmbedder:
    """Persistent SQLite cache wrapper for expensive embedders.

    The key includes the embedder namespace and exact input text hash, so changing the
    local model or truncation limit gets a separate cache namespace by default.

    Vectors are stored as little-endian float32 blobs in an ``embeddings`` table keyed by
    ``(namespace, sha256(text))``. ``hits`` / ``misses`` count lookups made by this instance.
    The SQLite connection stays open until `close()` is called, or the instance is used as a
    context manager and the ``with`` block exits.
    """

    def __init__(
        self,
        inner: Embedder,
        *,
        path: str | Path | None = None,
        namespace: str | None = None,
    ) -> None:
        self.inner = inner
        self.dim = inner.dim
        self.namespace = namespace or _cache_namespace(inner)
        self.path = Path(path) if path is not None else _default_embedding_cache_path()
        self.path.parent.mkdir(parents=True, exist_ok=True)
        self._lock = threading.Lock()
        # The connection may be used from several threads; every access goes through _lock.
        self._conn = sqlite3.connect(str(self.path), check_same_thread=False)
        with self._lock:
            self._conn.execute("PRAGMA journal_mode=WAL")
            # created_at uses strftime('%s', 'now') rather than unixepoch(). unixepoch() only
            # exists in SQLite >= 3.38, which is newer than the SQLite bundled with many Pythons.
            self._conn.execute(
                """
                CREATE TABLE IF NOT EXISTS embeddings (
                    namespace TEXT NOT NULL,
                    text_sha256 TEXT NOT NULL,
                    dim INTEGER NOT NULL,
                    embedding BLOB NOT NULL,
                    created_at INTEGER NOT NULL DEFAULT (CAST(strftime('%s', 'now') AS INTEGER)),
                    PRIMARY KEY (namespace, text_sha256)
                )
                """
            )
            self._conn.commit()
        self.hits = 0
        self.misses = 0

    def close(self) -> None:
        """Close the SQLite connection; the instance must not be used afterwards."""
        with self._lock:
            self._conn.close()

    def __enter__(self) -> "DiskCachedEmbedder":
        return self

    def __exit__(self, *exc_info: object) -> None:
        self.close()

    def embed(self, text: str) -> list[float]:
        return self.embed_many([text])[0]

    def embed_many(self, texts: list[str]) -> list[list[float]]:
        """Return one vector per text, serving cached ones and embedding the rest via ``inner``.

        Raises:
            ValueError: if ``inner`` returns the wrong number of vectors, or a vector whose
                length differs from its declared ``dim``. A mis-sized vector is never cached,
                because it could not be decoded on later reads.
        """
        if not texts:
            return []

        keys = [hashlib.sha256(text.encode("utf-8")).hexdigest() for text in texts]
        cached = self._read_many(set(keys))
        missing: dict[str, str] = {}  # key -> text, de-duplicated, in first-seen order
        for key, text in zip(keys, texts):
            if key not in cached and key not in missing:
                missing[key] = text

        self.hits += len(texts) - sum(1 for key in keys if key not in cached)
        self.misses += len(missing)

        if missing:
            missing_keys = list(missing)
            missing_texts = [missing[key] for key in missing_keys]
            embed_many = getattr(self.inner, "embed_many", None)
            vectors = (
                embed_many(missing_texts)
                if callable(embed_many)
                else [self.inner.embed(text) for text in missing_texts]
            )
            if len(vectors) != len(missing_keys):
                raise ValueError("cached embedder inner returned the wrong number of embeddings")
            bad_len = next((len(vec) for vec in vectors if len(vec) != self.dim), None)
            if bad_len is not None:
                raise ValueError(
                    f"cached embedder inner returned a {bad_len}-d vector but declares dim={self.dim}"
                )
            self._write_many(zip(missing_keys, vectors))
            cached.update(zip(missing_keys, vectors))

        return [cached[key] for key in keys]

    def _read_many(self, keys: set[str]) -> dict[str, list[float]]:
        """Fetch cached vectors for *keys* in this namespace (chunked to bound SQL parameters)."""
        if not keys:
            return {}
        rows: dict[str, list[float]] = {}
        ordered = list(keys)
        with self._lock:
            for i in range(0, len(ordered), 500):
                chunk = ordered[i : i + 500]
                placeholders = ",".join("?" for _ in chunk)
                params = [self.namespace, *chunk]
                cur = self._conn.execute(
                    f"""
                    SELECT text_sha256, dim, embedding
                    FROM embeddings
                    WHERE namespace = ? AND text_sha256 IN ({placeholders})
                    """,
                    params,
                )
                for key, dim, blob in cur.fetchall():
                    # Rows with another dim, or blobs that do not match it (written by an inner
                    # embedder that mis-declared its dim), count as misses. Recomputing them
                    # overwrites the bad row via INSERT OR REPLACE.
                    if dim == self.dim and len(blob) == 4 * self.dim:
                        rows[key] = _decode_vector(blob, dim=self.dim)
        return rows

    def _write_many(self, items) -> None:
        """Upsert ``(key, vector)`` pairs into this namespace and commit."""
        rows = [
            (self.namespace, key, self.dim, _encode_vector(vec))
            for key, vec in items
        ]
        if not rows:
            return
        with self._lock:
            self._conn.executemany(
                """
                INSERT OR REPLACE INTO embeddings(namespace, text_sha256, dim, embedding)
                VALUES (?, ?, ?, ?)
                """,
                rows,
            )
            self._conn.commit()
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
def _default_cache_root() -> Path:
|
|
220
|
+
if root := os.getenv("MIDAS_CACHE_ROOT"):
|
|
221
|
+
return Path(root)
|
|
222
|
+
if os.name == "nt" and Path("D:/").exists():
|
|
223
|
+
return Path("D:/hf-cache")
|
|
224
|
+
return Path.home() / ".cache" / "midas"
|
|
225
|
+
|
|
226
|
+
|
|
227
|
+
def _default_tmp_dir() -> Path:
|
|
228
|
+
if tmp := os.getenv("MIDAS_TMP_DIR"):
|
|
229
|
+
return Path(tmp)
|
|
230
|
+
if os.name == "nt" and Path("D:/").exists():
|
|
231
|
+
return Path("D:/tmp")
|
|
232
|
+
return _default_cache_root() / "tmp"
|
|
233
|
+
|
|
234
|
+
|
|
235
|
+
def configure_local_model_cache(cache_dir: str | Path | None = None) -> str:
    """Prepare on-disk caches for local models and return the fastembed model directory.

    Creates the model directory (*cache_dir*, or ``<cache root>/fastembed``) and the scratch
    directory. It then sets ``HF_HOME``, ``HF_HUB_CACHE``, ``TMP``, ``TEMP`` and
    ``HF_HUB_DISABLE_SYMLINKS_WARNING`` in ``os.environ``, each only if not already set.
    These are process-wide side effects.
    """
    base = _default_cache_root()
    model_dir = Path(cache_dir) if cache_dir is not None else base / "fastembed"
    scratch = _default_tmp_dir()
    for directory in (model_dir, scratch):
        directory.mkdir(parents=True, exist_ok=True)

    env_defaults = {
        "HF_HOME": str(base),
        "HF_HUB_CACHE": str(base / "hub"),
        "TMP": str(scratch),
        "TEMP": str(scratch),
        "HF_HUB_DISABLE_SYMLINKS_WARNING": "1",
    }
    for name, value in env_defaults.items():
        os.environ.setdefault(name, value)
    return str(model_dir)
|
|
249
|
+
|
|
250
|
+
|
|
251
|
+
class LocalEmbedder:
    """Real local semantic embeddings via fastembed (ONNX — no API key, no torch).

    Model weights are downloaded into the local model cache on first use.
    Install: `uv pip install fastembed`. Texts are truncated to ``max_text_chars`` characters
    before embedding, and outputs are L2-normalized.
    """

    _MAX_TEXT_CHARS = 2000
    # Output sizes of common fastembed models, so `dim` is known without running the model.
    # For any other model the size is measured once at construction (see `_probe_dim`).
    _KNOWN_DIMS = {
        "BAAI/bge-base-en-v1.5": 768,
        "BAAI/bge-small-en-v1.5": 384,
        "BAAI/bge-large-en-v1.5": 1024,
        "sentence-transformers/all-MiniLM-L6-v2": 384,
    }

    def __init__(
        self,
        model_name: str = "BAAI/bge-base-en-v1.5",
        *,
        cache_dir: str | Path | None = None,
        batch_size: int = 32,
        max_text_chars: int | None = None,
    ) -> None:
        self.cache_dir = configure_local_model_cache(cache_dir)
        from fastembed import TextEmbedding  # lazy import — optional dependency

        self._model = TextEmbedding(model_name=model_name, cache_dir=self.cache_dir)
        self.model_name = model_name
        # `dim` must match the real output size: DiskCachedEmbedder and the stores rely on it.
        # A fixed 768 is wrong for e.g. bge-small (384).
        self.dim = self._KNOWN_DIMS.get(model_name) or self._probe_dim()
        self.batch_size = batch_size
        self.max_text_chars = max_text_chars or self._MAX_TEXT_CHARS

    def _probe_dim(self) -> int:
        """Embed a throwaway string once and return the length of the resulting vector."""
        return len(next(iter(self._model.embed(["dimension probe"], batch_size=1))))

    def embed(self, text: str) -> list[float]:
        vec = next(iter(self._model.embed([text[: self.max_text_chars]], batch_size=1)))
        return l2_normalize([float(x) for x in vec])

    def embed_many(self, texts: list[str]) -> list[list[float]]:
        if not texts:
            return []
        capped = [text[: self.max_text_chars] for text in texts]
        return [
            l2_normalize([float(x) for x in vec])
            for vec in self._model.embed(capped, batch_size=self.batch_size)
        ]
|
|
286
|
+
|
|
287
|
+
|
|
288
|
+
class LocalReranker:
    """Local cross-encoder reranker (fastembed/ONNX, no API key, no torch). Reorders
    candidate documents by relevance to the query — sharper precision than bi-encoder
    cosine. First use downloads a small model. Used by Midas to refine the recall pool."""

    def __init__(
        self,
        model_name: str = "Xenova/ms-marco-MiniLM-L-6-v2",
        *,
        cache_dir: str | Path | None = None,
    ) -> None:
        self.cache_dir = configure_local_model_cache(cache_dir)
        from fastembed.rerank.cross_encoder import TextCrossEncoder  # lazy, optional

        self._model = TextCrossEncoder(model_name=model_name, cache_dir=self.cache_dir)
        self.model_name = model_name

    # Cross-encoders cap the query+doc PAIR at ~512 tokens; oversized input crashes the ONNX
    # runtime (seen on long real turns). Cap defensively — affects only the rerank score, not
    # the stored/assembled content.
    _MAX_QUERY_CHARS = 400
    _MAX_DOC_CHARS = 1200

    def rerank(self, query: str, documents: list[str]) -> list[float]:
        """Return a relevance score per document (higher = more relevant), in order.

        Inputs are length-capped to stay within the cross-encoder context. This is a deliberate
        best-effort boundary: if the model raises, or returns a different number of scores than
        documents, a ``RuntimeWarning`` is emitted and neutral 0.0 scores are returned instead
        of failing the whole query.
        """
        if not documents:
            return []
        q = query[: self._MAX_QUERY_CHARS]
        docs = [d[: self._MAX_DOC_CHARS] for d in documents]
        try:
            scores = [float(s) for s in self._model.rerank(q, docs)]
        except Exception as exc:  # runtime/model errors vary; degrade rather than crash
            return self._neutral_scores(len(docs), f"{type(exc).__name__}: {exc}")
        if len(scores) != len(docs):
            return self._neutral_scores(
                len(docs), f"got {len(scores)} scores for {len(docs)} documents"
            )
        return scores

    @staticmethod
    def _neutral_scores(count: int, reason: str) -> list[float]:
        """Warn that reranking was skipped, and return *count* neutral (0.0) scores."""
        import warnings

        warnings.warn(
            f"LocalReranker: rerank failed ({reason}); using neutral scores",
            RuntimeWarning,
            stacklevel=3,  # point at the caller of rerank()
        )
        return [0.0] * count
|
|
323
|
+
|
|
324
|
+
|
|
325
|
+
def l2_normalize(vec: list[float]) -> list[float]:
    """Scale *vec* to unit Euclidean length.

    An all-zero (or empty) vector cannot be normalized and is returned as-is (the same object).
    """
    length = math.sqrt(sum(component * component for component in vec))
    if length != 0.0:
        return [component / length for component in vec]
    return vec
|
|
330
|
+
|
|
331
|
+
|
|
332
|
+
def cosine(a: list[float], b: list[float]) -> float:
    """Cosine similarity of two vectors that are assumed already L2-normalized.

    Computed as a plain dot product over the shared prefix (``zip`` stops at the shorter
    vector), then clamped to [-1, 1] to absorb rounding error.
    """
    dot = sum(p * q for p, q in zip(a, b))
    capped_above = min(1.0, dot)
    return max(-1.0, capped_above)
|
midas/entity.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
"""No-LLM entity-grounded abstention — a candidate lever for the *Calibrated* frontier.
|
|
2
|
+
|
|
3
|
+
Diagnosed root cause (see docs §5/§7): the reader confabulates an answer drawn from a retrieved
|
|
4
|
+
distractor that is about a DIFFERENT entity than the question asks about (e.g. Q asks the *hamster's*
|
|
5
|
+
name, the retrieved turn is "I have a *cat* named Luna" → the reader answers "Luna"). Relevance, cosine
|
|
6
|
+
and NLI-entailment all fail here because that distractor genuinely *entails* the wrong answer.
|
|
7
|
+
|
|
8
|
+
This checks something orthogonal: is the answer's **source turn about the entity the question asks
|
|
9
|
+
about?** We pull the question's focus nouns and test whether the source mentions them. Pure regex, no
|
|
10
|
+
LLM. This module is **measured offline** before any end-to-end claim (the real test needs a capable
|
|
11
|
+
reader; the local 1B model exhibits hallucination, not confab-from-distractor — see the benchmark notes).
|
|
12
|
+
"""
|
|
13
|
+
from __future__ import annotations
|
|
14
|
+
|
|
15
|
+
import re
|
|
16
|
+
|
|
17
|
+
# Generic scaffolding that is NOT the entity a question is about: question words, copulas, and the
|
|
18
|
+
# attribute/relation words that recur across confab pairs ("favorite", "name") and so must not count
|
|
19
|
+
# as a shared entity.
|
|
20
|
+
_STOP = {
    "what", "which", "who", "whom", "whose", "where", "when", "why", "how", "is", "are", "was", "were",
    "the", "a", "an", "of", "to", "for", "in", "on", "at", "by", "with", "and", "or", "but", "do",
    "does", "did", "have", "has", "had", "his", "her", "their", "my", "your", "our", "its", "this",
    "that", "user", "users", "you", "i", "we", "they", "he", "she", "it", "name", "named", "call",
    "called", "favorite", "favourite", "kind", "type", "about", "tell", "me", "please", "they're",
}
_WORD = re.compile(r"[A-Za-z][A-Za-z'-]+")


def entities(text: str) -> set[str]:
    """Salient content words (non-stopword, >=3 chars), lowercased — a turn's rough entity fingerprint.

    The typographic apostrophe (U+2019) is treated like the ASCII one. Possessives are normalised
    to the base noun ("user's" -> "user", "users'" -> "users"), so scaffolding words are still
    filtered. Trailing apostrophes/hyphens captured by the word regex ("pre-" in "pre- and
    post-war") are dropped as well. ``None`` or empty text gives an empty set.
    """
    out: set[str] = set()
    normalised = (text or "").replace("\u2019", "'")
    for m in _WORD.finditer(normalised):
        w = m.group().lower()
        if w.endswith("'s"):
            w = w[:-2]  # singular possessive -> base noun
        w = w.rstrip("'-")  # plural possessive ("users'") and dangling hyphens
        if len(w) >= 3 and w not in _STOP:
            out.add(w)
    return out
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def question_focus(question: str) -> set[str]:
    """Salient words of *question* — currently exactly its `entities` fingerprint.

    Question words, copulas and generic attribute words are removed by the shared stopword set.
    """
    focus = entities(question)
    return focus
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def entity_grounded(question: str, source: str, *, min_overlap: int = 1) -> bool:
    """Return True when *source* shares at least *min_overlap* salient words with *question*'s focus.

    Used to check that an answer's source turn is about the entity the question asks about:
    a confabulation drawn from a wrong-entity distractor shares none of the question's focus.
    """
    shared = question_focus(question).intersection(entities(source))
    return len(shared) >= min_overlap
|