@pentatonic-ai/ai-agent-sdk 0.10.4 → 0.10.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,810 @@
1
+ """Unit tests for flag-gated hybrid BM25+RRF retrieval (roadmap BET 3).
2
+
3
+ Covers, without any live engine / network / docker:
4
+
5
+ - flag OFF → /search uses the legacy `search()` call (query_points
6
+ never touched) and /store writes the bare unnamed dense vector —
7
+ i.e. the request path is unchanged.
8
+ - flag ON → /search issues `query_points` with the exact two-leg
9
+ prefetch (dense on the unnamed '' vector, sparse on 'lex') fused by
10
+ FusionQuery(RRF); /store and /store-batch write the named-vector
11
+ bag {'': dense, 'lex': sparse} from FULL content.
12
+ - sparse-encode failure with flag ON → graceful dense-only fallback.
13
+ - backfill script dry-run math + state round-trip.
14
+ - eval harness metric math (recall@k / nDCG@k).
15
+ - the real fastembed encoder wrapper (skipped when fastembed is not
16
+ installed — via pytest.skip under pytest, or _Skip under the stdlib runner).
17
+
18
+ Dependency strategy: compat/server.py imports fastapi/pydantic/qdrant/
19
+ httpx/numpy/psycopg at module load. These tests use the REAL packages
20
+ when importable and install minimal in-memory stubs into sys.modules
21
+ otherwise, so the suite runs both in a full dev env (pytest) and in a
22
+ bare stdlib environment:
23
+
24
+ python3 packages/memory-engine-v2/tests/test_hybrid_retrieval.py
25
+ """
26
+
27
+ from __future__ import annotations
28
+
29
+ import asyncio
30
+ import importlib.util
31
+ import json
32
+ import os
33
+ import sys
34
+ import tempfile
35
+ import types
36
+ from pathlib import Path
37
+
38
# Package root: this test file sits one directory below it (in tests/).
_PKG = Path(__file__).resolve().parents[1]
# Source files under test; they are loaded by path, not imported as a package.
_SERVER = _PKG.joinpath("compat", "server.py")
_BACKFILL = _PKG.joinpath("scripts", "backfill_sparse_vectors.py")
_EVAL = _PKG.joinpath("eval", "recall_at_k.py")
42
+
43
+ try:
44
+ import pytest
45
+ except ImportError: # bare stdlib runner
46
+ pytest = None
47
+
48
+
49
class _Skip(Exception):
    """Raised to mark a test as skipped when running without pytest.

    Under pytest the skip is signalled with ``pytest.skip`` instead.
    """
51
+
52
+
53
def _skip(msg: str):
    """Abort the current test as skipped, with *msg* as the reason.

    Uses ``pytest.skip`` when pytest was importable; otherwise raises the
    stdlib-runner marker ``_Skip``. Never returns normally.
    """
    if pytest is None:
        raise _Skip(msg)
    pytest.skip(msg)
    # Only reached if pytest.skip() were to return; keep the skip semantics.
    raise _Skip(msg)
57
+
58
+
59
+ # ----------------------------------------------------------------------
60
+ # Minimal stubs for server.py's import surface (used only when the real
61
+ # package is not importable in this environment).
62
+ # ----------------------------------------------------------------------
63
+
64
+
65
class _Rec:
    """Keyword-argument record used in place of qdrant model classes.

    Every keyword passed to the constructor becomes an attribute; the full
    mapping is also kept in ``_kw`` and drives equality and ``repr``.
    """

    def __init__(self, **kw):
        self._kw = kw
        # Expose each keyword as a plain instance attribute.
        vars(self).update(kw)

    def __eq__(self, other):
        # Exact type match is required: a Prefetch never equals a Filter.
        if type(other) is not type(self):
            return False
        return self._kw == other._kw

    def __repr__(self):
        return "{}({})".format(type(self).__name__, self._kw)
78
+
79
+
80
def _stub_qdrant_client() -> types.ModuleType:
    """Build an in-memory stand-in for ``qdrant_client``.

    The returned module exposes ``AsyncQdrantClient`` and a ``models``
    submodule whose model classes are ``_Rec`` subclasses and whose
    enum-like holders carry plain string constants. As a side effect the
    submodule is registered as ``sys.modules["qdrant_client.models"]``.
    """
    record_models = (
        "VectorParams", "ScalarQuantization", "ScalarQuantizationConfig",
        "FieldCondition", "MatchAny", "MatchValue", "Filter", "PointStruct",
        "FilterSelector", "SparseVector", "SparseVectorParams",
        "SparseIndexParams", "Prefetch", "FusionQuery", "PointVectors",
    )

    class Distance:
        COSINE = "Cosine"

    class ScalarType:
        INT8 = "int8"

    class PayloadSchemaType:
        KEYWORD = "keyword"

    class Modifier:
        IDF = "idf"

    class Fusion:
        RRF = "rrf"
        DBSF = "dbsf"

    class AsyncQdrantClient:  # never instantiated in tests
        def __init__(self, *a, **kw):
            pass

    models = types.ModuleType("qdrant_client.models")
    # One kwargs-record subclass per model name.
    for model_name in record_models:
        setattr(models, model_name, type(model_name, (_Rec,), {}))
    # Constant holders are published under their own class names.
    for holder in (Distance, ScalarType, PayloadSchemaType, Modifier, Fusion):
        setattr(models, holder.__name__, holder)

    mod = types.ModuleType("qdrant_client")
    mod.AsyncQdrantClient = AsyncQdrantClient
    mod.models = models
    sys.modules["qdrant_client.models"] = models
    return mod
122
+
123
+
124
def _stub_pydantic() -> types.ModuleType:
    """Build a minimal stand-in for ``pydantic`` (``BaseModel`` and ``Field``).

    The stub ``BaseModel`` performs no validation: for every annotated field
    found along the MRO it stores the passed keyword, or else the class-level
    default (resolving ``Field(...)`` defaults and factories), or ``None``.
    Keywords that match no annotated field are ignored.
    """

    class _FieldInfo:
        def __init__(self, default=None, default_factory=None):
            self.default = default
            self.default_factory = default_factory

    def Field(default=None, **kw):
        return _FieldInfo(default=default,
                          default_factory=kw.get("default_factory"))

    class BaseModel:
        def __init__(self, **kwargs):
            cls = type(self)
            # Collect annotations base-first so subclasses override parents.
            declared: dict = {}
            for klass in cls.__mro__[::-1]:
                declared.update(getattr(klass, "__annotations__", {}))
            for field in declared:
                if field in kwargs:
                    setattr(self, field, kwargs[field])
                    continue
                fallback = getattr(cls, field, None)
                if isinstance(fallback, _FieldInfo):
                    factory = fallback.default_factory
                    fallback = factory() if factory else fallback.default
                setattr(self, field, fallback)

    mod = types.ModuleType("pydantic")
    mod.BaseModel = BaseModel
    mod.Field = Field
    return mod
155
+
156
+
157
def _stub_fastapi() -> types.ModuleType:
    """Build a minimal stand-in for ``fastapi``.

    ``FastAPI.get`` / ``FastAPI.post`` return a decorator that hands the
    handler back unchanged, so decorated handlers stay directly callable.
    ``HTTPException`` records ``status_code`` and ``detail``.
    """

    def _passthrough(fn):
        # Route registration is a no-op in the stub.
        return fn

    class FastAPI:
        def __init__(self, **kw):
            pass

        def get(self, path):
            return _passthrough

        def post(self, path):
            return _passthrough

    class HTTPException(Exception):
        def __init__(self, status_code, detail=None):
            super().__init__(detail)
            self.status_code = status_code
            self.detail = detail

    mod = types.ModuleType("fastapi")
    for exported in (FastAPI, HTTPException):
        setattr(mod, exported.__name__, exported)
    return mod
179
+
180
+
181
def _stub_httpx() -> types.ModuleType:
    """Build a minimal stand-in for ``httpx``.

    Provides inert ``AsyncClient`` and ``Timeout`` classes that accept any
    constructor arguments, plus the ``HTTPStatusError``,
    ``TimeoutException`` and ``NetworkError`` exception types.
    """

    class AsyncClient:
        def __init__(self, *a, **kw):
            pass

    class Timeout:
        def __init__(self, *a, **kw):
            pass

    class HTTPStatusError(Exception):
        def __init__(self, *a, **kw):
            # Keyword arguments are accepted but dropped.
            super().__init__(*a)

    class TimeoutException(Exception):
        pass

    class NetworkError(Exception):
        pass

    mod = types.ModuleType("httpx")
    exported = (AsyncClient, Timeout, HTTPStatusError,
                TimeoutException, NetworkError)
    for cls in exported:
        setattr(mod, cls.__name__, cls)
    return mod
208
+
209
+
210
def _stub_numpy() -> types.ModuleType:
    """Build a tripwire stand-in for ``numpy``.

    ``asarray`` and ``max`` raise ``AssertionError`` if called; ``float32``
    is a plain string placeholder.
    """

    def _unused(*a, **kw):  # MMR is vector-gated; tests never reach numpy
        raise AssertionError("numpy stub should not be exercised by these tests")

    mod = types.ModuleType("numpy")
    # Both entry points share the same tripwire function.
    for attr in ("asarray", "max"):
        setattr(mod, attr, _unused)
    mod.float32 = "float32"
    return mod
220
+
221
+
222
def _stub_psycopg() -> tuple[
    types.ModuleType, types.ModuleType, types.ModuleType,
    types.ModuleType, types.ModuleType,
]:
    """Build minimal stand-ins for ``psycopg`` and ``psycopg_pool``.

    Returns:
        A 5-tuple ``(psycopg, psycopg.rows, psycopg.types,
        psycopg.types.json, psycopg_pool)``. The submodules are also
        attached as attributes of their parents. The caller is responsible
        for registering whichever of them it needs in ``sys.modules``.
    """
    psycopg = types.ModuleType("psycopg")

    rows = types.ModuleType("psycopg.rows")
    rows.dict_row = object()  # opaque placeholder for the row factory

    tjson = types.ModuleType("psycopg.types.json")

    class Json:
        """Wrapper that just remembers the wrapped object."""

        def __init__(self, obj):
            self.obj = obj

    tjson.Json = Json

    tmod = types.ModuleType("psycopg.types")
    tmod.json = tjson
    psycopg.rows = rows
    psycopg.types = tmod

    pool = types.ModuleType("psycopg_pool")

    class AsyncConnectionPool:
        """Inert pool; accepts and ignores any constructor arguments."""

        def __init__(self, *a, **kw):
            pass

    pool.AsyncConnectionPool = AsyncConnectionPool
    return psycopg, rows, tmod, tjson, pool
246
+
247
+
248
def _ensure_modules():
    """Install stubs for any of server.py's deps that aren't importable.

    A dependency counts as present when it is already in ``sys.modules`` or
    when ``importlib`` can find a spec for it; only absent ones are stubbed.
    """

    def missing(name: str) -> bool:
        if name in sys.modules:
            return False
        try:
            spec = importlib.util.find_spec(name)
        except (ImportError, ValueError):
            return True
        return spec is None

    single_module_stubs = (
        ("qdrant_client", _stub_qdrant_client),
        ("pydantic", _stub_pydantic),
        ("fastapi", _stub_fastapi),
        ("httpx", _stub_httpx),
        ("numpy", _stub_numpy),
    )
    for mod_name, build in single_module_stubs:
        if missing(mod_name):
            sys.modules[mod_name] = build()

    if missing("psycopg"):
        # psycopg absent: register the whole stub family, pool included.
        psycopg, rows, tmod, tjson, pool = _stub_psycopg()
        sys.modules.update({
            "psycopg": psycopg,
            "psycopg.rows": rows,
            "psycopg.types": tmod,
            "psycopg.types.json": tjson,
            "psycopg_pool": pool,
        })
    elif missing("psycopg_pool"):
        # psycopg is available but the pool package is not: stub only that.
        sys.modules["psycopg_pool"] = _stub_psycopg()[-1]
278
+
279
+
280
+ _LOAD_SEQ = 0
281
+
282
+
283
def _load_module(path: Path, name: str):
    """Execute the source file at *path* as a module registered under *name*.

    The module is placed in ``sys.modules`` before execution so code that
    looks itself up by name during import works. If execution fails, the
    half-initialised module is removed again (and any module previously
    registered under *name* is restored) so a failed load cannot leak into
    later tests.

    Args:
        path: Location of the Python source file.
        name: Name to register the module under in ``sys.modules``.

    Returns:
        The fully executed module object.

    Raises:
        Whatever the module's top-level code raises; ``AssertionError`` if
        no import spec/loader could be created for *path*.
    """
    spec = importlib.util.spec_from_file_location(name, path)
    assert spec and spec.loader
    mod = importlib.util.module_from_spec(spec)
    previous = sys.modules.get(name)
    sys.modules[name] = mod
    try:
        spec.loader.exec_module(mod)
    except BaseException:
        # Roll back the registration: never leave a broken module behind.
        if previous is None:
            sys.modules.pop(name, None)
        else:
            sys.modules[name] = previous
        raise
    return mod
290
+
291
+
292
def load_server(hybrid: bool):
    """Load a fresh copy of the server module with the hybrid flag set.

    ``SEARCH_HYBRID_ENABLED`` is written to the environment ("1" or "0")
    before the module is executed, because the flag is read at module load.
    Each call registers the module under a new, numbered name.
    """
    global _LOAD_SEQ
    seq = _LOAD_SEQ + 1
    _LOAD_SEQ = seq
    _ensure_modules()
    flag_value = "1" if hybrid else "0"
    os.environ["SEARCH_HYBRID_ENABLED"] = flag_value
    module_name = "_compat_server_under_test_%d" % seq
    return _load_module(_SERVER, module_name)
300
+
301
+
302
+ # ----------------------------------------------------------------------
303
+ # Async fakes (qdrant client / pg pool)
304
+ # ----------------------------------------------------------------------
305
+
306
+
307
class FakeScored:
    """Stand-in for a scored qdrant hit: ``payload``, ``score``, ``vector``."""

    def __init__(self, event_id: str, score: float, source_kind: str = "note",
                 payload_extra: dict | None = None):
        payload = {
            "event_id": event_id,
            "arena": "arena-a",
            "source_kind": source_kind,
        }
        # Extra keys are layered last, so they may override the defaults.
        payload.update(payload_extra or {})
        self.payload = payload
        self.score = score
        self.vector = None  # vector-less → MMR falls back to score order
314
+
315
+
316
class FakeQdrant:
    """Recording async fake for the qdrant client.

    ``search``, ``query_points`` and ``upsert`` append their keyword
    arguments to ``search_calls`` / ``query_points_calls`` / ``upsert_calls``.
    ``search`` returns a copy of ``search_results``; ``query_points`` returns
    an object whose ``points`` is a copy of ``query_points_results``.
    """

    def __init__(self, search_results=None, query_points_results=None):
        self.search_results = search_results if search_results else []
        self.query_points_results = (
            query_points_results if query_points_results else []
        )
        self.search_calls: list[dict] = []
        self.query_points_calls: list[dict] = []
        self.upsert_calls: list[dict] = []

    async def search(self, **kw):
        self.search_calls.append(kw)
        return [*self.search_results]

    async def query_points(self, **kw):
        self.query_points_calls.append(kw)
        response = types.SimpleNamespace(points=[*self.query_points_results])
        return response

    async def upsert(self, **kw):
        self.upsert_calls.append(kw)
334
+
335
+
336
class _AsyncCM:
    """Async context manager that yields a fixed value.

    Exceptions raised inside the ``async with`` body are not suppressed.
    """

    def __init__(self, value):
        self.value = value

    async def __aenter__(self):
        return self.value

    async def __aexit__(self, *exc_info):
        # A falsy return lets any in-flight exception propagate.
        return False
345
+
346
+
347
class FakeCursor:
    """Recording async fake for a database cursor.

    ``execute`` / ``executemany`` append ``(sql, params_or_rows)`` to
    ``executed``; ``fetchall`` returns a copy of the preset ``rows``.
    """

    def __init__(self, rows=None):
        self.rows = rows if rows else []
        self.executed: list[tuple] = []

    async def execute(self, sql, params=None):
        record = (sql, params)
        self.executed.append(record)

    async def executemany(self, sql, rows):
        record = (sql, rows)
        self.executed.append(record)

    async def fetchall(self):
        return [*self.rows]
360
+
361
+
362
class FakeConn:
    """Fake connection whose ``cursor()`` yields one preset cursor object."""

    def __init__(self, cursor):
        self._cursor = cursor

    def cursor(self):
        # Wrapped so callers can write ``async with conn.cursor() as cur``.
        wrapped = _AsyncCM(self._cursor)
        return wrapped
368
+
369
+
370
class FakePool:
    """Fake connection pool sharing a single cursor across connections.

    Without an explicit cursor, an empty ``FakeCursor`` is created.
    """

    def __init__(self, cursor=None):
        self.cursor = cursor if cursor else FakeCursor()

    def connection(self):
        # Every connection hands out the same cursor, so tests can inspect
        # everything that was executed through the pool.
        conn = FakeConn(self.cursor)
        return _AsyncCM(conn)
376
+
377
+
378
def _wire_search_fakes(server, qdrant, db_rows):
    """Replace the server's embedder, qdrant client and pg pool with fakes.

    The fake embedder returns the fixed vector ``[0.1, 0.2, 0.3, 0.4]`` for
    every input text; the pool's cursor returns *db_rows* from ``fetchall``.
    """
    fixed_vector = (0.1, 0.2, 0.3, 0.4)

    async def fake_embed(texts, lane="bulk"):
        # A fresh list per text, mirroring one embedding per input.
        return [list(fixed_vector) for _ in texts]

    cursor = FakeCursor(rows=db_rows)
    server._embed_batch = fake_embed
    server._qdrant = qdrant
    server._pool = FakePool(cursor)
385
+
386
+
387
def _db_row(event_id: str, content: str = "full content", ts: str | None = None):
    """Build a fake postgres row dict with ``id``, ``content``, ``attributes``.

    ``attributes`` carries a ``timestamp`` entry only when *ts* is truthy.
    """
    attributes = {}
    if ts:
        attributes["timestamp"] = ts
    return {"id": event_id, "content": content, "attributes": attributes}
390
+
391
+
392
+ # ----------------------------------------------------------------------
393
+ # /search — flag OFF: legacy path, byte-identical behavior
394
+ # ----------------------------------------------------------------------
395
+
396
+
397
def test_flag_off_search_uses_legacy_search_not_query_points():
    """Flag off: /search goes through ``search()`` and never ``query_points``."""
    server = load_server(hybrid=False)
    dense_hits = [FakeScored("e1", 0.91), FakeScored("e2", 0.84)]
    qdrant = FakeQdrant(search_results=dense_hits)
    _wire_search_fakes(server, qdrant, [_db_row("e1"), _db_row("e2")])

    request = server.SearchRequest(query="who is pact", arena="arena-a")
    out = asyncio.run(server.search(request))

    assert len(qdrant.search_calls) == 1, "flag-off must use the legacy search()"
    assert qdrant.query_points_calls == [], "flag-off must NEVER call query_points"

    call = qdrant.search_calls[0]
    assert call["collection_name"] == "evidence"
    assert call["query_vector"] == [0.1, 0.2, 0.3, 0.4]
    assert call["limit"] == 30  # limit 10 × SEARCH_OVERFETCH_MULT 3
    assert call["score_threshold"] == 0.001
    assert call["with_payload"] is True

    results = out["results"]
    assert [hit["id"] for hit in results] == ["e1", "e2"]
    assert results[0]["content"] == "full content"
415
+
416
+
417
def test_flag_off_never_calls_sparse_encoder():
    """Flag off: the sparse encoder accessor must not be invoked at all.

    The tripwire both raises and records the call. Recording matters
    because ``AssertionError`` is an ordinary ``Exception``: if the server
    wrapped the encoder in a best-effort fallback, the raise alone would be
    swallowed and the test would pass without noticing the call.
    """
    server = load_server(hybrid=False)
    encoder_calls = []

    def boom(*a, **kw):
        encoder_calls.append((a, kw))
        raise AssertionError("sparse encoder must not be touched when flag is off")

    server._get_sparse_encoder = boom
    qdrant = FakeQdrant(search_results=[])
    _wire_search_fakes(server, qdrant, [])
    out = asyncio.run(server.search(server.SearchRequest(query="x", arena="arena-a")))
    assert encoder_calls == [], "sparse encoder must not be touched when flag is off"
    assert out == {"results": []}
    assert len(qdrant.search_calls) == 1
429
+
430
+
431
+ # ----------------------------------------------------------------------
432
+ # /search — flag ON: RRF-fused query_points with two prefetch legs
433
+ # ----------------------------------------------------------------------
434
+
435
+
436
def _sentinel_sparse(server):
    """Return a small, recognisable ``SparseVector`` built from the server's
    qdrant models, for use as an encoder return value in tests."""
    sparse_cls = server.qmodels.SparseVector
    return sparse_cls(indices=[3, 17], values=[1.0, 1.0])
438
+
439
+
440
def test_flag_on_search_uses_query_points_with_rrf_prefetch():
    """Flag on: /search issues one RRF-fused ``query_points`` call with a
    dense leg on the unnamed vector and a sparse leg on ``lex``."""
    server = load_server(hybrid=True)
    fused_hits = [FakeScored("e1", 0.0163), FakeScored("e2", 0.0161)]
    qdrant = FakeQdrant(query_points_results=fused_hits)
    _wire_search_fakes(server, qdrant, [_db_row("e1"), _db_row("e2")])
    sentinel = _sentinel_sparse(server)

    async def fake_sparse_query(text):
        return sentinel

    server._sparse_encode_query = fake_sparse_query

    request = server.SearchRequest(query="acme invoice 4711", arena="arena-a")
    out = asyncio.run(server.search(request))

    assert qdrant.search_calls == [], "flag-on must not use the legacy search()"
    assert len(qdrant.query_points_calls) == 1
    call = qdrant.query_points_calls[0]
    assert call["collection_name"] == "evidence"
    assert call["with_payload"] is True
    assert call["limit"] == 30

    legs = call["prefetch"]
    assert len(legs) == 2
    dense_leg, sparse_leg = legs

    assert dense_leg.using == ""  # unnamed dense vector
    assert dense_leg.query == [0.1, 0.2, 0.3, 0.4]
    assert dense_leg.limit == 30
    assert dense_leg.filter is not None

    assert sparse_leg.using == "lex"
    assert sparse_leg.query == sentinel
    assert sparse_leg.limit == 30
    assert sparse_leg.filter is not None

    fusion_query = call["query"]
    assert fusion_query.fusion == server.qmodels.Fusion.RRF

    # downstream pipeline (dedup → hydration) untouched: RRF score
    # surfaces as `similarity`, content hydrated from postgres.
    results = out["results"]
    assert [hit["id"] for hit in results] == ["e1", "e2"]
    top = results[0]
    assert top["similarity"] == 0.0163
    assert top["content"] == "full content"
480
+
481
+
482
def test_flag_on_sparse_query_failure_falls_back_to_dense():
    """Flag on, sparse query encoding raises: /search must fall back to the
    legacy dense ``search()`` and still return results."""
    server = load_server(hybrid=True)
    qdrant = FakeQdrant(search_results=[FakeScored("e1", 0.9)])
    _wire_search_fakes(server, qdrant, [_db_row("e1")])

    async def broken_sparse_query(text):
        raise RuntimeError("fastembed unavailable")

    server._sparse_encode_query = broken_sparse_query

    request = server.SearchRequest(query="x", arena="arena-a")
    out = asyncio.run(server.search(request))

    assert len(qdrant.search_calls) == 1, "sparse failure must fall back to dense search()"
    assert qdrant.query_points_calls == []
    returned_ids = [hit["id"] for hit in out["results"]]
    assert returned_ids == ["e1"]
496
+
497
+
498
+ # ----------------------------------------------------------------------
499
+ # /store + /store-batch — named sparse vector writes
500
+ # ----------------------------------------------------------------------
501
+
502
+
503
def _wire_store_fakes(server, qdrant):
    """Replace the server's embedder, extractor, qdrant client and pg pool
    with fakes for the /store and /store-batch tests.

    The fake embedder returns ``[0.5, 0.6]`` per text. The fake extractor
    returns an ``"evt-…"`` id derived from a SHA-256 digest of the content,
    so ids are reproducible across runs and distinct contents get distinct
    ids. (The builtin ``hash()`` of a string is salted per process, and
    reducing it modulo 10 000 could make two records collide.)
    """
    import hashlib  # local import: only this helper needs it

    async def fake_embed(texts, lane="bulk"):
        return [[0.5, 0.6] for _ in texts]

    async def fake_extract(arena, clientId, userId, source_kind, content, attributes):
        digest = hashlib.sha256(str(content).encode("utf-8")).hexdigest()
        return "evt-" + digest[:16]

    server._embed_batch = fake_embed
    server._extract = fake_extract
    server._qdrant = qdrant
    server._pool = FakePool()
514
+
515
+
516
def test_flag_off_store_writes_bare_dense_vector():
    """Flag off: /store upserts the plain dense list, not a named-vector dict."""
    server = load_server(hybrid=False)
    qdrant = FakeQdrant()
    _wire_store_fakes(server, qdrant)

    request = server.StoreRequest(content="hello world", metadata={"arena": "arena-a"})
    asyncio.run(server.store(request))

    assert len(qdrant.upsert_calls) == 1
    upserted = qdrant.upsert_calls[0]["points"]
    point = upserted[0]
    assert point.vector == [0.5, 0.6], "flag-off must keep the bare unnamed dense vector"
    assert not isinstance(point.vector, dict)
527
+
528
+
529
def test_flag_on_store_writes_named_dense_plus_lex_sparse_from_full_content():
    """Flag on: /store writes ``{'': dense, 'lex': sparse}`` and the sparse
    encoder is fed the full content rather than the truncated preview."""
    server = load_server(hybrid=True)
    qdrant = FakeQdrant()
    _wire_store_fakes(server, qdrant)
    sentinel = _sentinel_sparse(server)
    seen_texts: list[list[str]] = []

    async def fake_sparse_docs(texts):
        batch = list(texts)
        seen_texts.append(batch)
        return [sentinel for _ in texts]

    server._sparse_encode_documents = fake_sparse_docs

    long_content = "x" * 800  # > the 300-char content_preview truncation
    request = server.StoreRequest(content=long_content, metadata={"arena": "arena-a"})
    asyncio.run(server.store(request))

    point = qdrant.upsert_calls[0]["points"][0]
    vectors = point.vector
    assert isinstance(vectors, dict)
    assert vectors[""] == [0.5, 0.6]
    assert vectors["lex"] == sentinel
    # sparse encoding must see FULL content, not the 300-char preview
    assert seen_texts == [[long_content]]
    assert point.payload["content_preview"] == "x" * 300
552
+
553
+
554
def test_flag_on_store_batch_writes_named_vectors_per_record():
    """Flag on: /store-batch attaches each record's own sparse vector.

    The fake encoder records what it was given instead of asserting inline.
    An inline failure could be absorbed by a dense-only degrade path for
    sparse-encode errors and surface later as an unrelated indexing error,
    so the input check runs after the request, where a mismatch is reported
    directly.
    """
    server = load_server(hybrid=True)
    qdrant = FakeQdrant()
    _wire_store_fakes(server, qdrant)
    s1 = server.qmodels.SparseVector(indices=[1], values=[1.0])
    s2 = server.qmodels.SparseVector(indices=[2], values=[1.0])
    seen_texts: list[list[str]] = []

    async def fake_sparse_docs(texts):
        seen_texts.append(list(texts))
        return [s1, s2]

    server._sparse_encode_documents = fake_sparse_docs

    out = asyncio.run(server.store_batch(server.StoreBatchRequest(
        records=[
            {"content": "first record", "metadata": {"arena": "arena-a"}},
            {"content": "second record", "metadata": {"arena": "arena-a"}},
        ],
        arena="arena-a",
    )))

    # Encoder must have been fed the record contents, in order.
    assert seen_texts == [["first record", "second record"]]
    assert out["inserted"] == 2
    points = qdrant.upsert_calls[0]["points"]
    assert points[0].vector["lex"] == s1
    assert points[1].vector["lex"] == s2
    assert points[0].vector[""] == [0.5, 0.6]
580
+
581
+
582
def test_flag_on_store_sparse_failure_degrades_to_dense_only():
    """Flag on, sparse document encoding raises: /store must still upsert,
    with the bare dense vector."""
    server = load_server(hybrid=True)
    qdrant = FakeQdrant()
    _wire_store_fakes(server, qdrant)

    async def broken(texts):
        raise RuntimeError("model fetch failed")

    server._sparse_encode_documents = broken

    request = server.StoreRequest(content="hello", metadata={"arena": "arena-a"})
    asyncio.run(server.store(request))

    stored = qdrant.upsert_calls[0]["points"]
    point = stored[0]
    assert point.vector == [0.5, 0.6], "sparse failure must not fail ingest"
595
+
596
+
597
def test_flag_off_store_batch_keeps_bare_dense_vectors():
    """Flag off: /store-batch upserts plain dense lists, not named-vector dicts."""
    server = load_server(hybrid=False)
    qdrant = FakeQdrant()
    _wire_store_fakes(server, qdrant)

    batch = server.StoreBatchRequest(
        records=[{"content": "rec", "metadata": {"arena": "arena-a"}}],
        arena="arena-a",
    )
    asyncio.run(server.store_batch(batch))

    point = qdrant.upsert_calls[0]["points"][0]
    assert point.vector == [0.5, 0.6]
    assert not isinstance(point.vector, dict)
609
+
610
+
611
+ # ----------------------------------------------------------------------
612
+ # Collection migration helper
613
+ # ----------------------------------------------------------------------
614
+
615
+
616
class _FakeCollectionInfo:
    """Fake collection-info object exposing only
    ``config.params.sparse_vectors``."""

    def __init__(self, sparse: dict | None):
        params = types.SimpleNamespace(sparse_vectors=sparse)
        self.config = types.SimpleNamespace(params=params)
621
+
622
+
623
def test_ensure_sparse_config_adds_when_missing():
    """With no sparse vectors configured, the helper updates the collection
    with an IDF-modified, on-disk ``lex`` sparse config and returns True."""
    server = load_server(hybrid=True)
    recorded = {}

    class Q:
        async def get_collection(self, name):
            return _FakeCollectionInfo(sparse=None)

        async def update_collection(self, collection_name, sparse_vectors_config):
            recorded["collection"] = collection_name
            recorded["config"] = sparse_vectors_config

    server._qdrant = Q()
    added = asyncio.run(server._ensure_sparse_vector_config())

    assert added is True
    assert recorded["collection"] == "evidence"
    lex_cfg = recorded["config"]["lex"]
    assert lex_cfg.modifier == server.qmodels.Modifier.IDF
    assert lex_cfg.index.on_disk is True
642
+
643
+
644
def test_ensure_sparse_config_noop_when_present():
    """With ``lex`` already configured, the helper returns False and never
    calls ``update_collection``.

    The fake both raises and records the call: if the helper caught
    exceptions from the update, the raise alone would go unnoticed, so the
    recorded call list is asserted as well.
    """
    server = load_server(hybrid=True)
    update_calls = []

    class Q:
        async def get_collection(self, name):
            return _FakeCollectionInfo(sparse={"lex": object()})

        async def update_collection(self, **kw):
            update_calls.append(kw)
            raise AssertionError("must not update when 'lex' already configured")

    server._qdrant = Q()
    assert asyncio.run(server._ensure_sparse_vector_config()) is False
    assert update_calls == [], "must not update when 'lex' already configured"
656
+
657
+
658
+ # ----------------------------------------------------------------------
659
+ # Sparse encoder wrapper (real fastembed — skipped if not installed)
660
+ # ----------------------------------------------------------------------
661
+
662
+
663
def test_sparse_encoder_wrapper_roundtrip():
    """With real fastembed: document and query encoding return non-empty
    sparse vectors with matching index/value lengths and int indices.

    Skipped when fastembed cannot be imported.
    """
    try:
        import fastembed  # noqa: F401
    except ImportError:
        _skip("fastembed not installed in this test environment")

    server = load_server(hybrid=True)
    corpus = ["the quick brown fox", "pays the invoice 4711"]
    encoded_docs = asyncio.run(server._sparse_encode_documents(corpus))
    assert len(encoded_docs) == 2
    for vec in encoded_docs:
        assert len(vec.indices) == len(vec.values) > 0
        assert all(isinstance(idx, int) for idx in vec.indices)

    encoded_query = asyncio.run(server._sparse_encode_query("invoice 4711"))
    assert len(encoded_query.indices) == len(encoded_query.values) > 0
678
+
679
+
680
def test_to_sparse_vector_coerces_numpy_like_arrays():
    """``_to_sparse_vector`` carries an embedding's ``indices`` and ``values``
    over to the resulting sparse vector unchanged."""
    server = load_server(hybrid=True)

    class FakeEmb:
        # Minimal embedding shape: just the two attributes the helper reads.
        indices = [7, 11, 13]
        values = [0.5, 1.5, 2.0]

    converted = server._to_sparse_vector(FakeEmb())
    assert converted.indices == [7, 11, 13]
    assert converted.values == [0.5, 1.5, 2.0]
690
+
691
+
692
+ # ----------------------------------------------------------------------
693
+ # Backfill script — dry-run math + state handling (stdlib only)
694
+ # ----------------------------------------------------------------------
695
+
696
+
697
def _load_backfill():
    """Load the sparse-vector backfill script as a module, by file path."""
    module_name = "_backfill_sparse_under_test"
    return _load_module(_BACKFILL, module_name)
699
+
700
+
701
def test_backfill_batch_count_math():
    """``batch_count`` rounds up to whole batches and returns 0 for an empty
    total or a zero batch size."""
    bf = _load_backfill()
    cases = (
        ((0, 256), 0),
        ((1, 256), 1),
        ((256, 256), 1),
        ((257, 256), 2),
        ((745_000, 256), 2911),
        ((100, 0), 0),
    )
    for args, expected in cases:
        assert bf.batch_count(*args) == expected
709
+
710
+
711
def test_backfill_eta_math():
    """``eta_seconds`` divides points by rate (0.0 for zero points or zero
    rate); ``format_eta`` renders seconds, minutes+seconds, or hours+minutes."""
    bf = _load_backfill()

    eta_cases = (
        ((620_000, 400.0), 1550.0),
        ((0, 400.0), 0.0),
        ((100, 0), 0.0),
    )
    for args, expected in eta_cases:
        assert bf.eta_seconds(*args) == expected

    format_cases = ((1550, "25m50s"), (7325, "2h02m"), (42, "42s"))
    for seconds, rendered in format_cases:
        assert bf.format_eta(seconds) == rendered
719
+
720
+
721
def test_backfill_state_roundtrip_and_corruption_tolerance():
    """State file handling: missing file loads as ``{}``, a saved state loads
    back equal, and an unparseable file loads as ``{}``."""
    bf = _load_backfill()
    checkpoint = {"next_offset": "abc-123", "scanned": 512}
    with tempfile.TemporaryDirectory() as d:
        path = os.path.join(d, "state.json")

        # Nothing written yet.
        assert bf.load_state(path) == {}

        bf.save_state(path, checkpoint)
        assert bf.load_state(path) == {"next_offset": "abc-123", "scanned": 512}

        # Overwrite with invalid JSON.
        with open(path, "w") as f:
            f.write("{corrupt")
        assert bf.load_state(path) == {}
731
+
732
+
733
def test_backfill_defaults_are_safe():
    """With no CLI arguments the backfill is a dry run against ``evidence``
    with batch size 256 and ``force`` off."""
    bf = _load_backfill()
    parsed = bf.parse_args([])
    assert parsed.apply is False, "backfill must be dry-run by default"
    assert parsed.collection == "evidence"
    assert parsed.batch_size == 256
    assert parsed.force is False
740
+
741
+
742
+ # ----------------------------------------------------------------------
743
+ # Eval harness metric math (stdlib only)
744
+ # ----------------------------------------------------------------------
745
+
746
+
747
def _load_eval():
    """Load the recall@k eval harness as a module, by file path."""
    module_name = "_recall_at_k_under_test"
    return _load_module(_EVAL, module_name)
749
+
750
+
751
def test_eval_recall_at_k():
    """``recall_at_k`` is the fraction of relevant ids found in the top k,
    and 0.0 when either the ranking or the relevant set is empty."""
    ev = _load_eval()
    ranking = ["a", "b", "c"]
    relevant = {"a", "c"}
    assert ev.recall_at_k(ranking, relevant, 2) == 0.5
    assert ev.recall_at_k(ranking, relevant, 3) == 1.0
    assert ev.recall_at_k([], {"a"}, 5) == 0.0
    assert ev.recall_at_k(["a"], set(), 5) == 0.0
757
+
758
+
759
def test_eval_ndcg_at_k():
    """``ndcg_at_k`` is 1.0 for the ideal order, strictly between 0 and 1 for
    a swapped order, and 0.0 when nothing retrieved has gain."""
    ev = _load_eval()
    gains = {"a": 2.0, "b": 1.0}

    ideal = ev.ndcg_at_k(["a", "b"], gains, 2)
    assert abs(ideal - 1.0) < 1e-9  # ideal order

    swapped = ev.ndcg_at_k(["b", "a"], gains, 2)
    assert 0.0 < swapped < 1.0

    assert ev.ndcg_at_k(["x", "y"], gains, 2) == 0.0
766
+
767
+
768
def test_eval_skips_placeholder_questions():
    """``is_judged`` is True only for questions with a real relevant id;
    placeholder ids, an empty list and a missing key all count as unjudged."""
    ev = _load_eval()
    unjudged = (
        {"relevant": [{"event_id": "EVENT_ID_PLACEHOLDER_1A"}]},
        {"relevant": []},
        {},
    )
    for question in unjudged:
        assert ev.is_judged(question) is False
    assert ev.is_judged({"relevant": [{"event_id": "ev-real-1"}]}) is True
774
+
775
+
776
def test_eval_seed_file_parses_and_is_all_placeholders():
    """The committed golden seed parses as JSON, has example questions, and
    none of them is judged (i.e. it ships placeholders only).

    The file is opened explicitly as UTF-8 so parsing does not depend on
    the platform's locale default encoding.
    """
    ev = _load_eval()
    seed_path = _PKG / "eval" / "retrieval_golden.seed.json"
    with open(seed_path, encoding="utf-8") as f:
        golden = json.load(f)
    assert golden["questions"], "seed must ship with example questions"
    assert all(not ev.is_judged(q) for q in golden["questions"]), (
        "the committed seed must contain only placeholders — no live ids"
    )
784
+
785
+
786
+ # ----------------------------------------------------------------------
787
+ # Stdlib runner (pytest collects the same functions when available)
788
+ # ----------------------------------------------------------------------
789
+
790
if __name__ == "__main__":
    # Minimal runner for environments without pytest: call every module-level
    # ``test_*`` callable in name order, then exit non-zero if any failed.
    passed, skipped, failed = 0, 0, []
    for name, fn in sorted(globals().items()):
        if not (name.startswith("test_") and callable(fn)):
            continue
        try:
            fn()
            passed += 1
            print(f"PASS {name}")
        except _Skip as e:
            skipped += 1
            print(f"SKIP {name} ({e})")
        except (KeyboardInterrupt, SystemExit):
            # Ctrl-C or an explicit exit must stop the run, not be counted
            # as a failing test.
            raise
        except BaseException as e:  # pytest.skip raises BaseException subclass
            if pytest is not None and isinstance(e, pytest.skip.Exception):
                skipped += 1
                print(f"SKIP {name} ({e})")
            else:
                failed.append((name, e))
                print(f"FAIL {name}: {type(e).__name__}: {e}")
    print(f"\n{passed} passed, {skipped} skipped, {len(failed)} failed")
    sys.exit(1 if failed else 0)