@pentatonic-ai/ai-agent-sdk 0.10.4 → 0.10.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/index.cjs +1 -1
- package/dist/index.js +1 -1
- package/package.json +1 -1
- package/packages/memory-engine-v2/compat/requirements.txt +6 -0
- package/packages/memory-engine-v2/compat/server.py +258 -18
- package/packages/memory-engine-v2/eval/recall_at_k.py +242 -0
- package/packages/memory-engine-v2/eval/retrieval_golden.seed.json +69 -0
- package/packages/memory-engine-v2/extractor-async/Dockerfile +1 -1
- package/packages/memory-engine-v2/extractor-async/extraction_schema.py +246 -0
- package/packages/memory-engine-v2/extractor-async/test_guided_json_parser.py +411 -0
- package/packages/memory-engine-v2/extractor-async/worker.py +417 -31
- package/packages/memory-engine-v2/resolution-queue-design.md +165 -0
- package/packages/memory-engine-v2/scripts/backfill_entity_reconciliation.py +11 -2
- package/packages/memory-engine-v2/scripts/backfill_sparse_vectors.py +369 -0
- package/packages/memory-engine-v2/scripts/bakeoff_guided_vs_kv.py +607 -0
- package/packages/memory-engine-v2/scripts/entity_resolution_v2.py +1041 -0
- package/packages/memory-engine-v2/tests/test_entity_resolution_v2.py +507 -0
- package/packages/memory-engine-v2/tests/test_hybrid_retrieval.py +810 -0
|
@@ -0,0 +1,810 @@
|
|
|
1
|
+
"""Unit tests for flag-gated hybrid BM25+RRF retrieval (roadmap BET 3).
|
|
2
|
+
|
|
3
|
+
Covers, without any live engine / network / docker:
|
|
4
|
+
|
|
5
|
+
- flag OFF → /search uses the legacy `search()` call (query_points
|
|
6
|
+
never touched) and /store writes the bare unnamed dense vector —
|
|
7
|
+
i.e. the request path is unchanged.
|
|
8
|
+
- flag ON → /search issues `query_points` with the exact two-leg
|
|
9
|
+
prefetch (dense on the unnamed '' vector, sparse on 'lex') fused by
|
|
10
|
+
FusionQuery(RRF); /store and /store-batch write the named-vector
|
|
11
|
+
bag {'': dense, 'lex': sparse} from FULL content.
|
|
12
|
+
- sparse-encode failure with flag ON → graceful dense-only fallback.
|
|
13
|
+
- backfill script dry-run math + state round-trip.
|
|
14
|
+
- eval harness metric math (recall@k / nDCG@k).
|
|
15
|
+
- the real fastembed encoder wrapper (skipped when fastembed is not
|
|
16
|
+
installed — pytest.importorskip / stdlib-runner skip).
|
|
17
|
+
|
|
18
|
+
Dependency strategy: compat/server.py imports fastapi/pydantic/qdrant/
|
|
19
|
+
httpx/numpy/psycopg at module load. These tests use the REAL packages
|
|
20
|
+
when importable and install minimal in-memory stubs into sys.modules
|
|
21
|
+
otherwise, so the suite runs both in a full dev env (pytest) and in a
|
|
22
|
+
bare stdlib environment:
|
|
23
|
+
|
|
24
|
+
python3 packages/memory-engine-v2/tests/test_hybrid_retrieval.py
|
|
25
|
+
"""
|
|
26
|
+
|
|
27
|
+
from __future__ import annotations
|
|
28
|
+
|
|
29
|
+
import asyncio
|
|
30
|
+
import importlib.util
|
|
31
|
+
import json
|
|
32
|
+
import os
|
|
33
|
+
import sys
|
|
34
|
+
import tempfile
|
|
35
|
+
import types
|
|
36
|
+
from pathlib import Path
|
|
37
|
+
|
|
38
|
+
# Filesystem anchors, resolved from this file's own location so the suite
# can be launched from any working directory.
_PKG = Path(__file__).resolve().parents[1]
_SERVER = _PKG.joinpath("compat", "server.py")
_BACKFILL = _PKG.joinpath("scripts", "backfill_sparse_vectors.py")
_EVAL = _PKG.joinpath("eval", "recall_at_k.py")

# pytest is optional: this module doubles as a stdlib-only script.
try:
    import pytest
except ImportError:  # bare stdlib runner
    pytest = None
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class _Skip(Exception):
    """Skip marker for the stdlib runner.

    Raised by ``_skip`` when pytest is not importable; the pytest path
    uses ``pytest.skip`` instead.
    """
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def _skip(msg: str):
    """Abort the current test as skipped, with ``msg`` as the reason.

    Without pytest this raises ``_Skip``. With pytest it delegates to
    ``pytest.skip``; ``_Skip`` is still raised afterwards should that
    call ever return.
    """
    if pytest is None:
        raise _Skip(msg)
    pytest.skip(msg)
    raise _Skip(msg)
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
# ----------------------------------------------------------------------
|
|
60
|
+
# Minimal stubs for server.py's import surface (used only when the real
|
|
61
|
+
# package is not importable in this environment).
|
|
62
|
+
# ----------------------------------------------------------------------
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
class _Rec:
    """Generic kwargs-record stand-in for qdrant model classes.

    Every keyword argument is kept in ``_kw`` and also exposed as an
    attribute. Two records compare equal only when they are of exactly
    the same class and were built from equal keyword arguments.
    """

    def __init__(self, **fields):
        self._kw = fields
        for attr_name, attr_value in fields.items():
            setattr(self, attr_name, attr_value)

    def __eq__(self, other):
        if type(other) is not type(self):
            return False
        return self._kw == other._kw

    def __repr__(self):
        cls_name = type(self).__name__
        return f"{cls_name}({self._kw})"
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def _stub_qdrant_client() -> types.ModuleType:
    """Build an in-memory replacement for the ``qdrant_client`` package.

    The model classes are ``_Rec`` subclasses, the enum-like holders are
    plain classes carrying string constants, and ``AsyncQdrantClient``
    is an inert placeholder.

    Side effect: the models submodule is registered in ``sys.modules``
    as ``qdrant_client.models``. Registering the returned top-level
    module is left to the caller.
    """
    models = types.ModuleType("qdrant_client.models")

    record_names = (
        "VectorParams", "ScalarQuantization", "ScalarQuantizationConfig",
        "FieldCondition", "MatchAny", "MatchValue", "Filter", "PointStruct",
        "FilterSelector", "SparseVector", "SparseVectorParams",
        "SparseIndexParams", "Prefetch", "FusionQuery", "PointVectors",
    )
    for record_name in record_names:
        setattr(models, record_name, type(record_name, (_Rec,), {}))

    # Enum-like holders: class name -> {member name: string value}.
    constant_holders = {
        "Distance": {"COSINE": "Cosine"},
        "ScalarType": {"INT8": "int8"},
        "PayloadSchemaType": {"KEYWORD": "keyword"},
        "Modifier": {"IDF": "idf"},
        "Fusion": {"RRF": "rrf", "DBSF": "dbsf"},
    }
    for holder_name, members in constant_holders.items():
        setattr(models, holder_name, type(holder_name, (), dict(members)))

    class AsyncQdrantClient:  # never instantiated in tests
        def __init__(self, *a, **kw):
            pass

    client_mod = types.ModuleType("qdrant_client")
    client_mod.AsyncQdrantClient = AsyncQdrantClient
    client_mod.models = models
    sys.modules["qdrant_client.models"] = models
    return client_mod
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def _stub_pydantic() -> types.ModuleType:
    """Build a minimal in-memory replacement for ``pydantic``.

    Only ``BaseModel`` and ``Field`` are provided. ``BaseModel`` performs
    no validation or coercion: for every annotated field (collected
    across the MRO) it stores the matching keyword argument, or else the
    class-level default. A ``Field(default_factory=...)`` default is
    produced by calling the factory; any other default is deep-copied so
    that a mutable class-level default (``tags: list = []``) is never
    shared between instances.
    """
    import copy  # local import: only needed when the stub is in use

    mod = types.ModuleType("pydantic")

    class _FieldInfo:
        def __init__(self, default=None, default_factory=None):
            self.default = default
            self.default_factory = default_factory

    def Field(default=None, **kw):
        # Only default / default_factory are honoured; other options are ignored.
        return _FieldInfo(default, kw.get("default_factory"))

    class BaseModel:
        def __init__(self, **kwargs):
            # Walk the MRO base-first so subclass annotations win.
            ann: dict = {}
            for klass in reversed(type(self).__mro__):
                ann.update(getattr(klass, "__annotations__", {}))
            for name in ann:
                if name in kwargs:
                    value = kwargs[name]
                else:
                    default = getattr(type(self), name, None)
                    if isinstance(default, _FieldInfo):
                        if default.default_factory:
                            value = default.default_factory()
                        else:
                            value = copy.deepcopy(default.default)
                    else:
                        # Copy so instances never alias the class-level default.
                        value = copy.deepcopy(default)
                setattr(self, name, value)

    mod.BaseModel = BaseModel
    mod.Field = Field
    return mod
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
def _stub_fastapi() -> types.ModuleType:
    """Build a minimal in-memory replacement for ``fastapi``.

    ``FastAPI.get`` / ``FastAPI.post`` return a decorator that hands the
    handler back untouched, so decorated functions stay directly
    callable. ``HTTPException`` records ``status_code`` and ``detail``.
    """
    def _passthrough(fn):
        return fn

    class FastAPI:
        def __init__(self, **kw):
            pass

        def get(self, path):
            return _passthrough

        def post(self, path):
            return _passthrough

    class HTTPException(Exception):
        def __init__(self, status_code, detail=None):
            super().__init__(detail)
            self.status_code = status_code
            self.detail = detail

    stub = types.ModuleType("fastapi")
    stub.FastAPI = FastAPI
    stub.HTTPException = HTTPException
    return stub
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
def _stub_httpx() -> types.ModuleType:
    """Build a minimal in-memory replacement for ``httpx``.

    Exposes inert ``AsyncClient`` / ``Timeout`` constructors plus the
    ``HTTPStatusError``, ``TimeoutException`` and ``NetworkError``
    exception classes.
    """
    class AsyncClient:
        def __init__(self, *a, **kw):
            pass

    class Timeout:
        def __init__(self, *a, **kw):
            pass

    class HTTPStatusError(Exception):
        def __init__(self, *a, **kw):
            # Keyword arguments are accepted and dropped.
            super().__init__(*a)

    class TimeoutException(Exception):
        pass

    class NetworkError(Exception):
        pass

    stub = types.ModuleType("httpx")
    exported = (AsyncClient, Timeout, HTTPStatusError, TimeoutException, NetworkError)
    for cls in exported:
        setattr(stub, cls.__name__, cls)
    return stub
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
def _stub_numpy() -> types.ModuleType:
    """Build a tripwire replacement for ``numpy``.

    ``asarray`` and ``max`` raise ``AssertionError`` when called, so a
    test that unexpectedly reaches numpy-backed code fails loudly.
    ``float32`` is the plain string ``"float32"``.
    """
    def _unused(*a, **kw):  # MMR is vector-gated; tests never reach numpy
        raise AssertionError("numpy stub should not be exercised by these tests")

    stub = types.ModuleType("numpy")
    for attr in ("asarray", "max"):
        setattr(stub, attr, _unused)
    stub.float32 = "float32"
    return stub
|
|
220
|
+
|
|
221
|
+
|
|
222
|
+
def _stub_psycopg() -> tuple[
    types.ModuleType,
    types.ModuleType,
    types.ModuleType,
    types.ModuleType,
    types.ModuleType,
]:
    """Build in-memory replacements for ``psycopg`` and ``psycopg_pool``.

    Returns five modules, in this order:
    ``(psycopg, psycopg.rows, psycopg.types, psycopg.types.json,
    psycopg_pool)``. Nothing is registered in ``sys.modules`` here.
    """
    psycopg = types.ModuleType("psycopg")
    rows = types.ModuleType("psycopg.rows")
    rows.dict_row = object()  # opaque placeholder
    tjson = types.ModuleType("psycopg.types.json")

    class Json:
        def __init__(self, obj):
            self.obj = obj

    tjson.Json = Json
    tmod = types.ModuleType("psycopg.types")
    tmod.json = tjson
    psycopg.rows = rows
    psycopg.types = tmod

    pool = types.ModuleType("psycopg_pool")

    class AsyncConnectionPool:
        def __init__(self, *a, **kw):
            pass

    pool.AsyncConnectionPool = AsyncConnectionPool
    return psycopg, rows, tmod, tjson, pool
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
def _ensure_modules():
    """Install stubs for any of server.py's deps that aren't importable.

    A dependency that is already in ``sys.modules``, or that
    ``importlib`` can locate, is left alone. ``psycopg`` and
    ``psycopg_pool`` are handled together: a missing ``psycopg`` installs
    the whole stub family, otherwise only a missing ``psycopg_pool`` is
    stubbed.
    """
    def missing(name: str) -> bool:
        if name in sys.modules:
            return False
        try:
            spec = importlib.util.find_spec(name)
        except (ImportError, ValueError):
            return True
        return spec is None

    single_module_stubs = (
        ("qdrant_client", _stub_qdrant_client),
        ("pydantic", _stub_pydantic),
        ("fastapi", _stub_fastapi),
        ("httpx", _stub_httpx),
        ("numpy", _stub_numpy),
    )
    for mod_name, factory in single_module_stubs:
        if missing(mod_name):
            sys.modules[mod_name] = factory()

    if missing("psycopg"):
        psycopg, rows, tmod, tjson, pool = _stub_psycopg()
        sys.modules.update({
            "psycopg": psycopg,
            "psycopg.rows": rows,
            "psycopg.types": tmod,
            "psycopg.types.json": tjson,
            "psycopg_pool": pool,
        })
    elif missing("psycopg_pool"):
        sys.modules["psycopg_pool"] = _stub_psycopg()[-1]
|
|
278
|
+
|
|
279
|
+
|
|
280
|
+
_LOAD_SEQ = 0


def _load_module(path: Path, name: str):
    """Execute the source file at ``path`` as a module called ``name``.

    The module is registered in ``sys.modules`` before its body runs. If
    executing the body raises, the registration is rolled back (any
    module previously stored under ``name`` is restored) so that no
    half-initialised module stays importable, and the exception
    propagates unchanged.

    Raises:
        ImportError: no import spec/loader can be built for ``path``.
    """
    spec = importlib.util.spec_from_file_location(name, path)
    if spec is None or spec.loader is None:
        raise ImportError(f"cannot build an import spec for {path}", name=name)
    mod = importlib.util.module_from_spec(spec)
    previous = sys.modules.get(name)
    sys.modules[name] = mod
    try:
        spec.loader.exec_module(mod)
    except BaseException:
        if previous is None:
            sys.modules.pop(name, None)
        else:
            sys.modules[name] = previous
        raise
    return mod
|
|
290
|
+
|
|
291
|
+
|
|
292
|
+
def load_server(hybrid: bool):
    """Return a fresh instance of the compat server module.

    ``SEARCH_HYBRID_ENABLED`` is written to the environment ("1" or "0")
    before the import because the flag is read at module load. Each call
    loads the file under a new, sequence-numbered module name.
    """
    global _LOAD_SEQ
    _LOAD_SEQ += 1
    _ensure_modules()
    flag_value = "1" if hybrid else "0"
    os.environ["SEARCH_HYBRID_ENABLED"] = flag_value
    module_name = f"_compat_server_under_test_{_LOAD_SEQ}"
    return _load_module(_SERVER, module_name)
|
|
300
|
+
|
|
301
|
+
|
|
302
|
+
# ----------------------------------------------------------------------
|
|
303
|
+
# Async fakes (qdrant client / pg pool)
|
|
304
|
+
# ----------------------------------------------------------------------
|
|
305
|
+
|
|
306
|
+
|
|
307
|
+
class FakeScored:
|
|
308
|
+
def __init__(self, event_id: str, score: float, source_kind: str = "note",
|
|
309
|
+
payload_extra: dict | None = None):
|
|
310
|
+
self.payload = {"event_id": event_id, "arena": "arena-a",
|
|
311
|
+
"source_kind": source_kind, **(payload_extra or {})}
|
|
312
|
+
self.score = score
|
|
313
|
+
self.vector = None # vector-less → MMR falls back to score order
|
|
314
|
+
|
|
315
|
+
|
|
316
|
+
class FakeQdrant:
    """Recording fake for the async qdrant client.

    ``search``, ``query_points`` and ``upsert`` each append their keyword
    arguments to the matching ``*_calls`` list. ``search`` returns a copy
    of ``search_results``; ``query_points`` returns an object whose
    ``points`` is a copy of ``query_points_results``.
    """

    def __init__(self, search_results=None, query_points_results=None):
        self.search_results = search_results if search_results else []
        self.query_points_results = (
            query_points_results if query_points_results else []
        )
        self.search_calls: list[dict] = []
        self.query_points_calls: list[dict] = []
        self.upsert_calls: list[dict] = []

    async def search(self, **kw):
        self.search_calls.append(kw)
        return [*self.search_results]

    async def query_points(self, **kw):
        self.query_points_calls.append(kw)
        points = [*self.query_points_results]
        return types.SimpleNamespace(points=points)

    async def upsert(self, **kw):
        self.upsert_calls.append(kw)
|
|
334
|
+
|
|
335
|
+
|
|
336
|
+
class _AsyncCM:
    """Async context manager that yields a fixed value.

    ``__aexit__`` returns ``False``, so exceptions raised inside the
    ``async with`` body propagate.
    """

    def __init__(self, value):
        self.value = value

    async def __aenter__(self):
        yielded = self.value
        return yielded

    async def __aexit__(self, *exc_info):
        return False
|
|
345
|
+
|
|
346
|
+
|
|
347
|
+
class FakeCursor:
    """Recording fake for an async database cursor.

    ``execute`` / ``executemany`` append ``(sql, params_or_rows)`` to
    ``executed``; ``fetchall`` returns a copy of the canned ``rows``.
    """

    def __init__(self, rows=None):
        self.rows = rows if rows else []
        self.executed: list[tuple] = []

    async def execute(self, sql, params=None):
        record = (sql, params)
        self.executed.append(record)

    async def executemany(self, sql, rows):
        record = (sql, rows)
        self.executed.append(record)

    async def fetchall(self):
        return [*self.rows]
|
|
360
|
+
|
|
361
|
+
|
|
362
|
+
class FakeConn:
    """Fake connection whose ``cursor()`` yields one pre-built cursor."""

    def __init__(self, cursor):
        self._cursor = cursor

    def cursor(self):
        # Wrapped so callers can write ``async with conn.cursor() as cur``.
        wrapped = _AsyncCM(self._cursor)
        return wrapped
|
|
368
|
+
|
|
369
|
+
|
|
370
|
+
class FakePool:
    """Fake connection pool backed by a single shared cursor.

    Every ``connection()`` call wraps a new ``FakeConn`` around the same
    cursor, which is exposed as ``self.cursor`` for inspection.
    """

    def __init__(self, cursor=None):
        self.cursor = cursor if cursor else FakeCursor()

    def connection(self):
        conn = FakeConn(self.cursor)
        return _AsyncCM(conn)
|
|
376
|
+
|
|
377
|
+
|
|
378
|
+
def _wire_search_fakes(server, qdrant, db_rows):
    """Patch ``server`` for a /search test.

    Replaces ``_embed_batch`` with a fake returning a fixed 4-dim vector
    per text, ``_qdrant`` with ``qdrant``, and ``_pool`` with a
    ``FakePool`` whose cursor returns ``db_rows``.
    """
    async def fake_embed(texts, lane="bulk"):
        vectors = []
        for _ in texts:
            vectors.append([0.1, 0.2, 0.3, 0.4])
        return vectors

    server._embed_batch = fake_embed
    server._qdrant = qdrant
    cursor = FakeCursor(rows=db_rows)
    server._pool = FakePool(cursor)
|
|
385
|
+
|
|
386
|
+
|
|
387
|
+
def _db_row(event_id: str, content: str = "full content", ts: str | None = None):
|
|
388
|
+
attrs = {"timestamp": ts} if ts else {}
|
|
389
|
+
return {"id": event_id, "content": content, "attributes": attrs}
|
|
390
|
+
|
|
391
|
+
|
|
392
|
+
# ----------------------------------------------------------------------
|
|
393
|
+
# /search — flag OFF: legacy path, byte-identical behavior
|
|
394
|
+
# ----------------------------------------------------------------------
|
|
395
|
+
|
|
396
|
+
|
|
397
|
+
def test_flag_off_search_uses_legacy_search_not_query_points():
    """Flag OFF: /search makes exactly one legacy ``search()`` call."""
    server = load_server(hybrid=False)
    hits = [FakeScored("e1", 0.91), FakeScored("e2", 0.84)]
    qdrant = FakeQdrant(search_results=hits)
    _wire_search_fakes(server, qdrant, [_db_row("e1"), _db_row("e2")])

    request = server.SearchRequest(query="who is pact", arena="arena-a")
    out = asyncio.run(server.search(request))

    assert len(qdrant.search_calls) == 1, "flag-off must use the legacy search()"
    assert qdrant.query_points_calls == [], "flag-off must NEVER call query_points"

    call = qdrant.search_calls[0]
    expected_args = {
        "collection_name": "evidence",
        "query_vector": [0.1, 0.2, 0.3, 0.4],
        "limit": 30,  # limit 10 × SEARCH_OVERFETCH_MULT 3
        "score_threshold": 0.001,
    }
    for key, want in expected_args.items():
        assert call[key] == want
    assert call["with_payload"] is True

    results = out["results"]
    assert [r["id"] for r in results] == ["e1", "e2"]
    assert results[0]["content"] == "full content"
|
|
415
|
+
|
|
416
|
+
|
|
417
|
+
def test_flag_off_never_calls_sparse_encoder():
    """Flag OFF: /search must not touch the sparse encoder at all."""
    server = load_server(hybrid=False)

    def boom(*a, **kw):
        raise AssertionError("sparse encoder must not be touched when flag is off")

    server._get_sparse_encoder = boom
    qdrant = FakeQdrant(search_results=[])
    _wire_search_fakes(server, qdrant, [])

    request = server.SearchRequest(query="x", arena="arena-a")
    out = asyncio.run(server.search(request))

    assert out == {"results": []}
    assert len(qdrant.search_calls) == 1
|
|
429
|
+
|
|
430
|
+
|
|
431
|
+
# ----------------------------------------------------------------------
|
|
432
|
+
# /search — flag ON: RRF-fused query_points with two prefetch legs
|
|
433
|
+
# ----------------------------------------------------------------------
|
|
434
|
+
|
|
435
|
+
|
|
436
|
+
def _sentinel_sparse(server):
    """Return a fixed two-entry ``SparseVector`` built from ``server.qmodels``."""
    sparse_cls = server.qmodels.SparseVector
    return sparse_cls(indices=[3, 17], values=[1.0, 1.0])
|
|
438
|
+
|
|
439
|
+
|
|
440
|
+
def test_flag_on_search_uses_query_points_with_rrf_prefetch():
    """Flag ON: /search issues one RRF-fused ``query_points`` call.

    Checks the two prefetch legs (dense on the unnamed vector, sparse on
    'lex'), the fusion query, and that hydration still works.
    """
    server = load_server(hybrid=True)
    fused_hits = [FakeScored("e1", 0.0163), FakeScored("e2", 0.0161)]
    qdrant = FakeQdrant(query_points_results=fused_hits)
    _wire_search_fakes(server, qdrant, [_db_row("e1"), _db_row("e2")])
    sentinel = _sentinel_sparse(server)

    async def fake_sparse_query(text):
        return sentinel

    server._sparse_encode_query = fake_sparse_query

    request = server.SearchRequest(query="acme invoice 4711", arena="arena-a")
    out = asyncio.run(server.search(request))

    assert qdrant.search_calls == [], "flag-on must not use the legacy search()"
    assert len(qdrant.query_points_calls) == 1
    call = qdrant.query_points_calls[0]
    assert call["collection_name"] == "evidence"
    assert call["with_payload"] is True
    assert call["limit"] == 30

    prefetch = call["prefetch"]
    assert len(prefetch) == 2
    dense_leg, sparse_leg = prefetch

    assert dense_leg.using == ""  # unnamed dense vector
    assert dense_leg.query == [0.1, 0.2, 0.3, 0.4]
    assert dense_leg.limit == 30
    assert dense_leg.filter is not None

    assert sparse_leg.using == "lex"
    assert sparse_leg.query == sentinel
    assert sparse_leg.limit == 30
    assert sparse_leg.filter is not None

    fusion_query = call["query"]
    assert fusion_query.fusion == server.qmodels.Fusion.RRF

    # downstream pipeline (dedup → hydration) untouched: RRF score
    # surfaces as `similarity`, content hydrated from postgres.
    results = out["results"]
    assert [r["id"] for r in results] == ["e1", "e2"]
    top = results[0]
    assert top["similarity"] == 0.0163
    assert top["content"] == "full content"
|
|
480
|
+
|
|
481
|
+
|
|
482
|
+
def test_flag_on_sparse_query_failure_falls_back_to_dense():
    """Flag ON: a failing sparse query encoder degrades to legacy ``search()``."""
    server = load_server(hybrid=True)
    qdrant = FakeQdrant(search_results=[FakeScored("e1", 0.9)])
    _wire_search_fakes(server, qdrant, [_db_row("e1")])

    async def broken_sparse_query(text):
        raise RuntimeError("fastembed unavailable")

    server._sparse_encode_query = broken_sparse_query

    request = server.SearchRequest(query="x", arena="arena-a")
    out = asyncio.run(server.search(request))

    assert len(qdrant.search_calls) == 1, "sparse failure must fall back to dense search()"
    assert qdrant.query_points_calls == []
    returned_ids = [r["id"] for r in out["results"]]
    assert returned_ids == ["e1"]
|
|
496
|
+
|
|
497
|
+
|
|
498
|
+
# ----------------------------------------------------------------------
|
|
499
|
+
# /store + /store-batch — named sparse vector writes
|
|
500
|
+
# ----------------------------------------------------------------------
|
|
501
|
+
|
|
502
|
+
|
|
503
|
+
def _wire_store_fakes(server, qdrant):
    """Patch ``server`` for a /store or /store-batch test.

    Replaces ``_embed_batch`` with a fake returning ``[0.5, 0.6]`` per
    text, ``_extract`` with a fake that derives an event id from the
    content, ``_qdrant`` with ``qdrant``, and ``_pool`` with an empty
    ``FakePool``.
    """
    import zlib  # local import: only this helper needs it

    async def fake_embed(texts, lane="bulk"):
        return [[0.5, 0.6] for _ in texts]

    async def fake_extract(arena, clientId, userId, source_kind, content, attributes):
        # CRC32 rather than built-in hash(): str hashes are salted per
        # process, which made the ids differ on every run. The full
        # 32-bit value is kept so distinct records stay distinct.
        digest = zlib.crc32(content.encode("utf-8"))
        return f"evt-{digest:08x}"

    server._embed_batch = fake_embed
    server._extract = fake_extract
    server._qdrant = qdrant
    server._pool = FakePool()
|
|
514
|
+
|
|
515
|
+
|
|
516
|
+
def test_flag_off_store_writes_bare_dense_vector():
    """Flag OFF: /store upserts one point with the bare dense vector."""
    server = load_server(hybrid=False)
    qdrant = FakeQdrant()
    _wire_store_fakes(server, qdrant)

    request = server.StoreRequest(content="hello world", metadata={"arena": "arena-a"})
    asyncio.run(server.store(request))

    assert len(qdrant.upsert_calls) == 1
    stored_vector = qdrant.upsert_calls[0]["points"][0].vector
    assert stored_vector == [0.5, 0.6], "flag-off must keep the bare unnamed dense vector"
    assert not isinstance(stored_vector, dict)
|
|
527
|
+
|
|
528
|
+
|
|
529
|
+
def test_flag_on_store_writes_named_dense_plus_lex_sparse_from_full_content():
    """Flag ON: /store writes ``{'': dense, 'lex': sparse}`` from full content."""
    server = load_server(hybrid=True)
    qdrant = FakeQdrant()
    _wire_store_fakes(server, qdrant)
    sentinel = _sentinel_sparse(server)
    seen_texts: list[list[str]] = []

    async def fake_sparse_docs(texts):
        seen_texts.append(list(texts))
        return [sentinel for _ in texts]

    server._sparse_encode_documents = fake_sparse_docs

    long_content = "x" * 800  # > the 300-char content_preview truncation
    request = server.StoreRequest(content=long_content, metadata={"arena": "arena-a"})
    asyncio.run(server.store(request))

    point = qdrant.upsert_calls[0]["points"][0]
    vector_bag = point.vector
    assert isinstance(vector_bag, dict)
    assert vector_bag[""] == [0.5, 0.6]
    assert vector_bag["lex"] == sentinel
    # sparse encoding must see FULL content, not the 300-char preview
    assert seen_texts == [[long_content]]
    assert point.payload["content_preview"] == "x" * 300
|
|
552
|
+
|
|
553
|
+
|
|
554
|
+
def test_flag_on_store_batch_writes_named_vectors_per_record():
    """Flag ON: /store-batch attaches each record's own sparse vector."""
    server = load_server(hybrid=True)
    qdrant = FakeQdrant()
    _wire_store_fakes(server, qdrant)
    s1 = server.qmodels.SparseVector(indices=[1], values=[1.0])
    s2 = server.qmodels.SparseVector(indices=[2], values=[1.0])

    async def fake_sparse_docs(texts):
        assert texts == ["first record", "second record"]
        return [s1, s2]

    server._sparse_encode_documents = fake_sparse_docs

    records = [
        {"content": "first record", "metadata": {"arena": "arena-a"}},
        {"content": "second record", "metadata": {"arena": "arena-a"}},
    ]
    request = server.StoreBatchRequest(records=records, arena="arena-a")
    out = asyncio.run(server.store_batch(request))

    assert out["inserted"] == 2
    first, second = qdrant.upsert_calls[0]["points"][:2]
    assert first.vector["lex"] == s1
    assert second.vector["lex"] == s2
    assert first.vector[""] == [0.5, 0.6]
|
|
580
|
+
|
|
581
|
+
|
|
582
|
+
def test_flag_on_store_sparse_failure_degrades_to_dense_only():
    """Flag ON: a failing sparse encoder still stores the dense vector."""
    server = load_server(hybrid=True)
    qdrant = FakeQdrant()
    _wire_store_fakes(server, qdrant)

    async def broken(texts):
        raise RuntimeError("model fetch failed")

    server._sparse_encode_documents = broken

    request = server.StoreRequest(content="hello", metadata={"arena": "arena-a"})
    asyncio.run(server.store(request))

    stored_vector = qdrant.upsert_calls[0]["points"][0].vector
    assert stored_vector == [0.5, 0.6], "sparse failure must not fail ingest"
|
|
595
|
+
|
|
596
|
+
|
|
597
|
+
def test_flag_off_store_batch_keeps_bare_dense_vectors():
    """Flag OFF: /store-batch upserts the bare dense vector."""
    server = load_server(hybrid=False)
    qdrant = FakeQdrant()
    _wire_store_fakes(server, qdrant)

    records = [{"content": "rec", "metadata": {"arena": "arena-a"}}]
    request = server.StoreBatchRequest(records=records, arena="arena-a")
    asyncio.run(server.store_batch(request))

    stored_vector = qdrant.upsert_calls[0]["points"][0].vector
    assert stored_vector == [0.5, 0.6]
    assert not isinstance(stored_vector, dict)
|
|
609
|
+
|
|
610
|
+
|
|
611
|
+
# ----------------------------------------------------------------------
|
|
612
|
+
# Collection migration helper
|
|
613
|
+
# ----------------------------------------------------------------------
|
|
614
|
+
|
|
615
|
+
|
|
616
|
+
class _FakeCollectionInfo:
|
|
617
|
+
def __init__(self, sparse: dict | None):
|
|
618
|
+
self.config = types.SimpleNamespace(
|
|
619
|
+
params=types.SimpleNamespace(sparse_vectors=sparse)
|
|
620
|
+
)
|
|
621
|
+
|
|
622
|
+
|
|
623
|
+
def test_ensure_sparse_config_adds_when_missing():
    """The migration helper adds the 'lex' sparse config when it is absent."""
    server = load_server(hybrid=True)
    calls = {}

    class UnmigratedQdrant:
        async def get_collection(self, name):
            return _FakeCollectionInfo(sparse=None)

        async def update_collection(self, collection_name, sparse_vectors_config):
            calls.update(collection=collection_name, config=sparse_vectors_config)

    server._qdrant = UnmigratedQdrant()
    added = asyncio.run(server._ensure_sparse_vector_config())

    assert added is True
    assert calls["collection"] == "evidence"
    lex_cfg = calls["config"]["lex"]
    assert lex_cfg.modifier == server.qmodels.Modifier.IDF
    assert lex_cfg.index.on_disk is True
|
|
642
|
+
|
|
643
|
+
|
|
644
|
+
def test_ensure_sparse_config_noop_when_present():
    """The migration helper leaves a collection that already has 'lex' alone."""
    server = load_server(hybrid=True)

    class MigratedQdrant:
        async def get_collection(self, name):
            return _FakeCollectionInfo(sparse={"lex": object()})

        async def update_collection(self, **kw):
            raise AssertionError("must not update when 'lex' already configured")

    server._qdrant = MigratedQdrant()
    added = asyncio.run(server._ensure_sparse_vector_config())
    assert added is False
|
|
656
|
+
|
|
657
|
+
|
|
658
|
+
# ----------------------------------------------------------------------
|
|
659
|
+
# Sparse encoder wrapper (real fastembed — skipped if not installed)
|
|
660
|
+
# ----------------------------------------------------------------------
|
|
661
|
+
|
|
662
|
+
|
|
663
|
+
def test_sparse_encoder_wrapper_roundtrip():
    """Run the real sparse encoder wrapper (skipped without fastembed)."""
    try:
        import fastembed  # noqa: F401
    except ImportError:
        _skip("fastembed not installed in this test environment")

    server = load_server(hybrid=True)
    corpus = ["the quick brown fox", "pays the invoice 4711"]
    docs = asyncio.run(server._sparse_encode_documents(corpus))
    assert len(docs) == 2
    for doc_vec in docs:
        assert len(doc_vec.indices) == len(doc_vec.values) > 0
        assert all(isinstance(i, int) for i in doc_vec.indices)

    query_vec = asyncio.run(server._sparse_encode_query("invoice 4711"))
    assert len(query_vec.indices) == len(query_vec.values) > 0
|
|
678
|
+
|
|
679
|
+
|
|
680
|
+
def test_to_sparse_vector_coerces_numpy_like_arrays():
    """``_to_sparse_vector`` copies indices/values off an embedding-like object."""
    server = load_server(hybrid=True)

    class FakeEmb:
        indices = [7, 11, 13]
        values = [0.5, 1.5, 2.0]

    converted = server._to_sparse_vector(FakeEmb())
    assert converted.indices == [7, 11, 13]
    assert converted.values == [0.5, 1.5, 2.0]
|
|
690
|
+
|
|
691
|
+
|
|
692
|
+
# ----------------------------------------------------------------------
|
|
693
|
+
# Backfill script — dry-run math + state handling (stdlib only)
|
|
694
|
+
# ----------------------------------------------------------------------
|
|
695
|
+
|
|
696
|
+
|
|
697
|
+
def _load_backfill():
    """Load the sparse-vector backfill script as a module (re-run on every call)."""
    module_name = "_backfill_sparse_under_test"
    return _load_module(_BACKFILL, module_name)
|
|
699
|
+
|
|
700
|
+
|
|
701
|
+
def test_backfill_batch_count_math():
    """``batch_count`` rounds up and returns 0 for empty input or batch size 0."""
    bf = _load_backfill()
    cases = [
        ((0, 256), 0),
        ((1, 256), 1),
        ((256, 256), 1),
        ((257, 256), 2),
        ((745_000, 256), 2911),
        ((100, 0), 0),
    ]
    for (total, batch_size), expected in cases:
        assert bf.batch_count(total, batch_size) == expected
|
|
709
|
+
|
|
710
|
+
|
|
711
|
+
def test_backfill_eta_math():
    """``eta_seconds`` divides work by rate; ``format_eta`` renders h/m/s."""
    bf = _load_backfill()

    eta_cases = [
        ((620_000, 400.0), 1550.0),
        ((0, 400.0), 0.0),
        ((100, 0), 0.0),
    ]
    for (remaining, rate), expected in eta_cases:
        assert bf.eta_seconds(remaining, rate) == expected

    format_cases = [(1550, "25m50s"), (7325, "2h02m"), (42, "42s")]
    for seconds, rendered in format_cases:
        assert bf.format_eta(seconds) == rendered
|
|
719
|
+
|
|
720
|
+
|
|
721
|
+
def test_backfill_state_roundtrip_and_corruption_tolerance():
    """State file: missing → ``{}``, round-trips when saved, corrupt → ``{}``."""
    bf = _load_backfill()
    with tempfile.TemporaryDirectory() as tmp_dir:
        path = os.path.join(tmp_dir, "state.json")

        # No file yet.
        assert bf.load_state(path) == {}

        # Save then load.
        bf.save_state(path, {"next_offset": "abc-123", "scanned": 512})
        assert bf.load_state(path) == {"next_offset": "abc-123", "scanned": 512}

        # Overwrite with invalid JSON.
        with open(path, "w") as fh:
            fh.write("{corrupt")
        assert bf.load_state(path) == {}
|
|
731
|
+
|
|
732
|
+
|
|
733
|
+
def test_backfill_defaults_are_safe():
    """With no CLI arguments the backfill is a dry run on 'evidence'."""
    bf = _load_backfill()
    defaults = bf.parse_args([])
    assert defaults.apply is False, "backfill must be dry-run by default"
    assert defaults.collection == "evidence"
    assert defaults.batch_size == 256
    assert defaults.force is False
|
|
740
|
+
|
|
741
|
+
|
|
742
|
+
# ----------------------------------------------------------------------
|
|
743
|
+
# Eval harness metric math (stdlib only)
|
|
744
|
+
# ----------------------------------------------------------------------
|
|
745
|
+
|
|
746
|
+
|
|
747
|
+
def _load_eval():
    """Load the recall@k eval harness as a module (re-run on every call)."""
    module_name = "_recall_at_k_under_test"
    return _load_module(_EVAL, module_name)
|
|
749
|
+
|
|
750
|
+
|
|
751
|
+
def test_eval_recall_at_k():
    """``recall_at_k``: partial hit, full hit, empty ranking, empty relevant set."""
    ev = _load_eval()
    cases = [
        ((["a", "b", "c"], {"a", "c"}, 2), 0.5),
        ((["a", "b", "c"], {"a", "c"}, 3), 1.0),
        (([], {"a"}, 5), 0.0),
        ((["a"], set(), 5), 0.0),
    ]
    for (ranked, relevant, k), expected in cases:
        assert ev.recall_at_k(ranked, relevant, k) == expected
|
|
757
|
+
|
|
758
|
+
|
|
759
|
+
def test_eval_ndcg_at_k():
    """``ndcg_at_k``: ideal order is 1, swapped order lies in (0, 1), no hits is 0."""
    ev = _load_eval()
    gains = {"a": 2.0, "b": 1.0}

    ideal = ev.ndcg_at_k(["a", "b"], gains, 2)
    assert abs(ideal - 1.0) < 1e-9  # ideal order

    worse = ev.ndcg_at_k(["b", "a"], gains, 2)
    assert 0.0 < worse < 1.0

    assert ev.ndcg_at_k(["x", "y"], gains, 2) == 0.0
|
|
766
|
+
|
|
767
|
+
|
|
768
|
+
def test_eval_skips_placeholder_questions():
    """``is_judged`` is True only for questions with a real relevant event id."""
    ev = _load_eval()
    cases = [
        ({"relevant": [{"event_id": "EVENT_ID_PLACEHOLDER_1A"}]}, False),
        ({"relevant": []}, False),
        ({}, False),
        ({"relevant": [{"event_id": "ev-real-1"}]}, True),
    ]
    for question, expected in cases:
        assert ev.is_judged(question) is expected
|
|
774
|
+
|
|
775
|
+
|
|
776
|
+
def test_eval_seed_file_parses_and_is_all_placeholders():
    """The committed golden seed parses and contains only placeholder ids."""
    ev = _load_eval()
    seed_path = _PKG / "eval" / "retrieval_golden.seed.json"
    # JSON is UTF-8; do not depend on the platform's locale encoding.
    with open(seed_path, encoding="utf-8") as f:
        golden = json.load(f)
    assert golden["questions"], "seed must ship with example questions"
    assert all(not ev.is_judged(q) for q in golden["questions"]), (
        "the committed seed must contain only placeholders — no live ids"
    )
|
|
784
|
+
|
|
785
|
+
|
|
786
|
+
# ----------------------------------------------------------------------
|
|
787
|
+
# Stdlib runner (pytest collects the same functions when available)
|
|
788
|
+
# ----------------------------------------------------------------------
|
|
789
|
+
|
|
790
|
+
if __name__ == "__main__":
    # Minimal stdlib runner: call every module-level ``test_*`` callable
    # in name order, print one PASS / SKIP / FAIL line per test, and exit
    # with status 1 when anything failed.
    passed, skipped, failed = 0, 0, []
    for name, fn in sorted(globals().items(), key=lambda item: item[0]):
        if not (name.startswith("test_") and callable(fn)):
            continue
        try:
            fn()
        except _Skip as e:
            skipped += 1
            print(f"SKIP {name} ({e})")
        except (KeyboardInterrupt, SystemExit):
            # Let Ctrl-C and explicit exits stop the run instead of
            # being counted as a test failure.
            raise
        except BaseException as e:  # pytest.skip raises BaseException subclass
            if pytest is not None and isinstance(e, pytest.skip.Exception):
                skipped += 1
                print(f"SKIP {name} ({e})")
            else:
                failed.append((name, e))
                print(f"FAIL {name}: {type(e).__name__}: {e}")
        else:
            passed += 1
            print(f"PASS {name}")
    print(f"\n{passed} passed, {skipped} skipped, {len(failed)} failed")
    sys.exit(1 if failed else 0)
|