@pentatonic-ai/ai-agent-sdk 0.10.5 → 0.10.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.cjs CHANGED
@@ -878,7 +878,7 @@ function fireAndForgetEmit(clientConfig, sessionOpts, messages, result, model) {
878
878
  }
879
879
 
880
880
  // src/telemetry.js
881
- var VERSION = "0.10.5";
881
+ var VERSION = "0.10.6";
882
882
  var TELEMETRY_URL = "https://sdk-telemetry.philip-134.workers.dev";
883
883
  function machineId() {
884
884
  const raw = typeof process !== "undefined" ? `${process.env?.USER || process.env?.USERNAME || "u"}:${process.platform || "x"}:${process.arch || "x"}` : "browser";
package/dist/index.js CHANGED
@@ -847,7 +847,7 @@ function fireAndForgetEmit(clientConfig, sessionOpts, messages, result, model) {
847
847
  }
848
848
 
849
849
  // src/telemetry.js
850
- var VERSION = "0.10.5";
850
+ var VERSION = "0.10.6";
851
851
  var TELEMETRY_URL = "https://sdk-telemetry.philip-134.workers.dev";
852
852
  function machineId() {
853
853
  const raw = typeof process !== "undefined" ? `${process.env?.USER || process.env?.USERNAME || "u"}:${process.platform || "x"}:${process.arch || "x"}` : "browser";
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@pentatonic-ai/ai-agent-sdk",
3
- "version": "0.10.5",
3
+ "version": "0.10.6",
4
4
  "description": "TES SDK — LLM observability and lifecycle tracking via Pentatonic Thing Event System. Track token usage, tool calls, and conversations. Manage things through event-sourced lifecycle stages with AI enrichment and vector search.",
5
5
  "type": "module",
6
6
  "main": "./dist/index.cjs",
@@ -4,3 +4,9 @@ psycopg[binary,pool]==3.2.3
4
4
  httpx==0.27.2
5
5
  qdrant-client==1.12.1
6
6
  pydantic==2.9.2
7
+ # BET 3 (hybrid retrieval): CPU-only BM25 sparse encoder for the named
8
+ # 'lex' vector. 0.3.6 = the exact pin qdrant-client 1.12.1's own
9
+ # [fastembed] extra uses (and it requires python <3.13 — the compat
10
+ # image is python:3.12-slim). Only imported lazily when
11
+ # SEARCH_HYBRID_ENABLED is on; flag-off behavior is unchanged.
12
+ fastembed==0.3.6
@@ -102,6 +102,24 @@ SEARCH_INTENT_BOOST = os.environ.get("SEARCH_INTENT_BOOST", "1") not in ("0", "f
102
102
  # without a parseable timestamp sink to the bottom but aren't dropped.
103
103
  SEARCH_TEMPORAL_RERANK = os.environ.get("SEARCH_TEMPORAL_RERANK", "1") not in ("0", "false", "")
104
104
 
105
+ # ── Hybrid lexical+dense retrieval (roadmap BET 3) ───────────────────
106
+ # SEARCH_HYBRID_ENABLED gates EVERY hybrid behavior in one switch:
107
+ # - /store and /store-batch additionally write a NAMED sparse vector
108
+ # ("lex", BM25 term weights via fastembed, CPU-only) alongside the
109
+ # existing unnamed dense vector. The dense embedder + its vectors
110
+ # are never touched — additive only, zero dense re-embed.
111
+ # - startup runs an idempotent update_collection to add the sparse
112
+ # vector config ("lex": IDF modifier, on-disk index) when missing.
113
+ # - /search swaps the single dense search() for a server-side
114
+ # RRF-fused query_points(prefetch=[dense, lex]) — everything
115
+ # downstream (dedup → intent boost → MMR/temporal → quota →
116
+ # hydration) is untouched; the RRF score lands in r.score.
117
+ # Default OFF (env unset/0/false): the request path is byte-identical
118
+ # to pre-hybrid behavior and fastembed is never imported at all.
119
# Master switch for hybrid retrieval. Parsed case-insensitively and with
# surrounding whitespace stripped: a plain membership test against
# ("", "0", "false") would silently turn hybrid ON for spellings such as
# "False", "FALSE" or " 0 ".
SEARCH_HYBRID_ENABLED = (
    os.environ.get("SEARCH_HYBRID_ENABLED", "").strip().lower()
    not in ("", "0", "false")
)
# Name of the sparse vector stored alongside the unnamed dense vector.
SPARSE_VECTOR_NAME = "lex"
# Sparse model handed to fastembed; overridable via SEARCH_SPARSE_MODEL.
SPARSE_MODEL_NAME = os.environ.get("SEARCH_SPARSE_MODEL", "Qdrant/bm25")
122
+
105
123
  TEMPORAL_INTENT_RE = re.compile(
106
124
  r"\b(when did|when was|last (?:time|met|saw|spoke|called)|"
107
125
  r"how long ago|first time (?:i|we) (?:met|saw|spoke)|recent(?:ly)?|"
@@ -116,6 +134,18 @@ FACTUAL_INTENT_RE = re.compile(
116
134
  )
117
135
  INTENT_BOOSTS: dict[str, dict[str, float]] = {
118
136
  # source_kind -> additive boost on cosine score
137
+ #
138
+ # ⚠️ HYBRID-RRF RECALIBRATION NEEDED (BET 3): these magnitudes were
139
+ # tuned against COSINE similarity scores (typical 0.7–0.85 range,
140
+ # where +0.06 flips a near-tie). When SEARCH_HYBRID_ENABLED is on,
141
+ # /search returns RRF fusion scores instead — 1/(k+rank) with
142
+ # Qdrant's k=60, i.e. ~0.016 at rank 1 decaying to ~0.006 at rank
143
+ # 100. On that scale a +0.06 additive boost is no longer a nudge:
144
+ # it catapults any matching source_kind above EVERY un-boosted
145
+ # result regardless of rank. Do not flip the hybrid flag to
146
+ # default-on until these are recalibrated against eval-harness
147
+ # numbers (see eval/recall_at_k.py); flag-off default protects
148
+ # prod until then.
119
149
  "temporal": {"event": 0.08, "doc": 0.04, "note": 0.02},
120
150
  "factual": {"doc": 0.06, "note": 0.03, "event": 0.03},
121
151
  }
@@ -184,6 +214,64 @@ def _apply_temporal_sort(
184
214
  return sorted(results, key=neg_ts)
185
215
 
186
216
 
217
+ # ── Sparse (BM25) encoding — hybrid retrieval, BET 3 ─────────────────
218
+ # fastembed's Qdrant/bm25 sparse encoder. CPU-only — no GPU contention
219
+ # with the dense embed gateway. Lazily initialised so that (a) flag-off
220
+ # deployments never import fastembed (it isn't even a hard dependency
221
+ # of the request path) and (b) the model artifact download happens on
222
+ # first use, not at process start.
223
# Process-wide cache for the lazily built sparse encoder; None until the
# first call to _get_sparse_encoder().
_sparse_encoder: Any | None = None


def _get_sparse_encoder() -> Any:
    """Return the shared sparse encoder, constructing it on first call.

    The fastembed import is kept local on purpose: importing this module
    must not pull fastembed in, so it is only loaded once a caller
    actually asks for the encoder. Later calls return the cached
    instance.
    """
    global _sparse_encoder
    if _sparse_encoder is not None:
        return _sparse_encoder

    # Deferred import — see docstring.
    from fastembed import SparseTextEmbedding

    _sparse_encoder = SparseTextEmbedding(model_name=SPARSE_MODEL_NAME)
    log.info(f"sparse encoder initialised: {SPARSE_MODEL_NAME}")
    return _sparse_encoder
236
+
237
+
238
def _to_sparse_vector(emb: Any) -> qmodels.SparseVector:
    """Convert a fastembed sparse embedding into a Qdrant SparseVector.

    Each index is passed through ``int`` and each value through
    ``float`` so the resulting model holds plain Python scalars
    (presumably the inputs are numpy arrays — confirm against fastembed).
    """
    plain_indices = list(map(int, emb.indices))
    plain_values = list(map(float, emb.values))
    return qmodels.SparseVector(indices=plain_indices, values=plain_values)
244
+
245
+
246
# Serialises first-use construction of the sparse encoder so concurrent
# requests build it exactly once. NOTE(review): constructing an
# asyncio.Lock at import time assumes Python 3.10+ (no loop binding at
# construction) — confirm against the deployed interpreter.
_sparse_encoder_init_lock = asyncio.Lock()


async def _load_sparse_encoder() -> Any:
    """Return the shared sparse encoder without blocking the event loop.

    The first call to ``_get_sparse_encoder`` imports fastembed and
    builds the model, which is slow synchronous work. Running it in a
    worker thread keeps other requests responsive; the lock prevents
    concurrent first callers from each building their own encoder.
    """
    if _sparse_encoder is not None:
        # Fast path: already built, no thread hop or lock needed.
        return _sparse_encoder
    async with _sparse_encoder_init_lock:
        # _get_sparse_encoder re-checks the cache itself, so a caller
        # that waited on the lock gets the instance built by the winner.
        return await asyncio.to_thread(_get_sparse_encoder)


async def _sparse_encode_documents(texts: list[str]) -> list[qmodels.SparseVector]:
    """BM25-encode full document content for the named 'lex' vector.

    Both encoder construction and the encoding itself run in worker
    threads — fastembed is synchronous CPU work and must not block the
    event loop under concurrent /store-batch load.

    Returns one SparseVector per input text, in input order. Any
    exception from fastembed propagates to the caller.
    """
    enc = await _load_sparse_encoder()
    embs = await asyncio.to_thread(lambda: list(enc.embed(texts)))
    return [_to_sparse_vector(e) for e in embs]


async def _sparse_encode_query(text: str) -> qmodels.SparseVector:
    """BM25-encode a query string.

    Uses ``query_embed`` (not ``embed``) — BM25 weights documents by
    term frequency/length but queries as bare term sets; the IDF half
    lives server-side via Modifier.IDF on the collection. Encoder
    construction and encoding both run off the event loop.
    """
    enc = await _load_sparse_encoder()
    embs = await asyncio.to_thread(lambda: list(enc.query_embed(text)))
    return _to_sparse_vector(embs[0])
262
+
263
+
264
+ def _dense_vector_of(candidate: Any) -> Any:
265
+ """Extract the dense vector from a scored point. With hybrid on,
266
+ Qdrant returns the full named-vector bag ({'': dense, 'lex':
267
+ sparse}); the dense vector rides the default '' slot. Flag-off
268
+ points return the bare list unchanged."""
269
+ v = getattr(candidate, "vector", None)
270
+ if isinstance(v, dict):
271
+ return v.get("")
272
+ return v
273
+
274
+
187
275
  def _mmr_select(
188
276
  candidates: list[Any], target: int, lambda_: float
189
277
  ) -> list[Any]:
@@ -199,10 +287,12 @@ def _mmr_select(
199
287
  if not candidates or target <= 0:
200
288
  return []
201
289
  # Bail to pure-relevance ordering if vectors weren't returned.
202
- if any(getattr(c, "vector", None) is None for c in candidates):
290
+ # (_dense_vector_of unwraps the hybrid named-vector bag; flag-off
291
+ # bare-list vectors pass through unchanged.)
292
+ if any(_dense_vector_of(c) is None for c in candidates):
203
293
  return sorted(candidates, key=lambda r: r.score, reverse=True)[:target]
204
294
 
205
- vecs = np.asarray([c.vector for c in candidates], dtype=np.float32)
295
+ vecs = np.asarray([_dense_vector_of(c) for c in candidates], dtype=np.float32)
206
296
  scores = np.asarray([c.score for c in candidates], dtype=np.float32)
207
297
  # Precompute pairwise similarity matrix; cheaper than per-step
208
298
  # dot products at our scale and lets us slice into it by index.
@@ -239,6 +329,47 @@ _qdrant: AsyncQdrantClient | None = None
239
329
  _http: httpx.AsyncClient | None = None
240
330
 
241
331
 
332
def _sparse_vectors_config() -> dict[str, Any]:
    """Build the schema for the 'lex' named sparse vector (BET 3).

    Modifier.IDF — Qdrant computes and applies IDF server-side, so the
    client-side BM25 encoding only has to supply term frequency × length
    normalisation (what fastembed's Qdrant/bm25 produces).

    on_disk index — the sparse index lives on disk next to the dense
    vectors instead of competing for RAM; the 06-05 outage was disk
    pressure, not RAM, and mmap/page-cache governs the hot set the same
    way the dense side is configured.
    """
    lex_params = qmodels.SparseVectorParams(
        modifier=qmodels.Modifier.IDF,
        index=qmodels.SparseIndexParams(on_disk=True),
    )
    return {SPARSE_VECTOR_NAME: lex_params}
348
+
349
+
350
async def _ensure_sparse_vector_config() -> bool:
    """Add the 'lex' sparse vector config to the collection if missing.

    Idempotent migration: the collection's current sparse vector
    settings are read first and ``update_collection`` is only issued
    when 'lex' is not among them. Intended to be called from lifespan
    only when SEARCH_HYBRID_ENABLED, so flag-off startups never touch
    the collection config. Adding the config is additive metadata —
    existing points and the unnamed dense vector are left as they are.

    Returns:
        True if the config was added, False if it was already present.
    """
    collection = await _qdrant.get_collection(COLLECTION_NAME)
    configured = getattr(collection.config.params, "sparse_vectors", None)
    if configured and SPARSE_VECTOR_NAME in configured:
        return False

    await _qdrant.update_collection(
        collection_name=COLLECTION_NAME,
        sparse_vectors_config=_sparse_vectors_config(),
    )
    log.info(
        f"added sparse vector config '{SPARSE_VECTOR_NAME}' "
        f"(modifier=idf, on_disk=true) to collection {COLLECTION_NAME}"
    )
    return True
371
+
372
+
242
373
  @asynccontextmanager
243
374
  async def lifespan(app: FastAPI):
244
375
  global _pool, _qdrant, _http
@@ -260,6 +391,13 @@ async def lifespan(app: FastAPI):
260
391
  collections = await _qdrant.get_collections()
261
392
  names = {c.name for c in collections.collections}
262
393
  if COLLECTION_NAME not in names:
394
+ create_kwargs: dict[str, Any] = {}
395
+ if SEARCH_HYBRID_ENABLED:
396
+ # Fresh collection with the flag on gets the 'lex'
397
+ # sparse config at creation time (BET 3); existing
398
+ # collections are migrated by
399
+ # _ensure_sparse_vector_config below.
400
+ create_kwargs["sparse_vectors_config"] = _sparse_vectors_config()
263
401
  await _qdrant.create_collection(
264
402
  collection_name=COLLECTION_NAME,
265
403
  vectors_config=qmodels.VectorParams(
@@ -275,6 +413,7 @@ async def lifespan(app: FastAPI):
275
413
  always_ram=False,
276
414
  )
277
415
  ),
416
+ **create_kwargs,
278
417
  )
279
418
  log.info(f"created qdrant collection: {COLLECTION_NAME} dim={EMBED_DIM}")
280
419
  # Payload indexes for fast filtered search (this is the
@@ -286,6 +425,12 @@ async def lifespan(app: FastAPI):
286
425
  field_schema=qmodels.PayloadSchemaType.KEYWORD,
287
426
  )
288
427
  log.info("created qdrant payload indexes: arena, source_kind, clientId, userId")
428
+ if SEARCH_HYBRID_ENABLED:
429
+ # BET 3 migration — idempotent, additive-only; no-op when
430
+ # the 'lex' config is already present. Flag-off startups
431
+ # never reach this line, so the collection config is
432
+ # byte-identical to today until the flag is flipped.
433
+ await _ensure_sparse_vector_config()
289
434
  except Exception as e:
290
435
  log.error(f"qdrant init error: {e}")
291
436
  # Don't crash compat on Qdrant init failure — let liveness
@@ -553,6 +698,16 @@ async def store(req: StoreRequest):
553
698
  event_id = await _extract(arena, clientId, userId, source_kind, req.content, meta)
554
699
  embeddings = await _embed_batch([req.content])
555
700
 
701
+ # BET 3: BM25-encode the FULL content into the named 'lex' sparse
702
+ # vector. Encode failure degrades to dense-only (ingest must not
703
+ # fail on the lexical leg; the backfill script repairs gaps).
704
+ sparse_vec: Any | None = None
705
+ if SEARCH_HYBRID_ENABLED:
706
+ try:
707
+ sparse_vec = (await _sparse_encode_documents([req.content]))[0]
708
+ except Exception as e:
709
+ log.warning(f"sparse encode failed; storing dense-only (backfill repairs): {e}")
710
+
556
711
  vector_id = str(uuid.uuid4())
557
712
  # Write vector_provenance + Qdrant point in the same logical
558
713
  # operation. If Qdrant fails, the provenance row gets rolled back —
@@ -569,7 +724,15 @@ async def store(req: StoreRequest):
569
724
  points=[
570
725
  qmodels.PointStruct(
571
726
  id=vector_id,
572
- vector=embeddings[0],
727
+ # Flag-off: bare dense list — byte-identical to
728
+ # today. Flag-on: named-vector bag; the dense
729
+ # vector keeps its unnamed ('') slot, 'lex' is
730
+ # purely additive.
731
+ vector=(
732
+ embeddings[0]
733
+ if sparse_vec is None
734
+ else {"": embeddings[0], SPARSE_VECTOR_NAME: sparse_vec}
735
+ ),
573
736
  # Issue #345 (caps #342/#343/#344): Pip emits a rich
574
737
  # metadata bag — timestamp, contact_email, channel,
575
738
  # kind, direction, source, etc. Pre-fix the payload
@@ -621,6 +784,22 @@ async def store_batch(req: StoreBatchRequest):
621
784
  if len(embeddings) != len(texts):
622
785
  raise HTTPException(500, f"embed count mismatch: {len(embeddings)} vs {len(texts)}")
623
786
 
787
+ # BET 3: sparse-encode the FULL content batch for the named 'lex'
788
+ # vector. Best-effort — a sparse failure degrades the whole batch
789
+ # to dense-only rather than failing ingest (backfill repairs).
790
+ sparse_vecs: list[Any] | None = None
791
+ if SEARCH_HYBRID_ENABLED:
792
+ try:
793
+ sparse_vecs = await _sparse_encode_documents(texts)
794
+ if len(sparse_vecs) != len(texts):
795
+ log.warning(
796
+ f"sparse encode count mismatch ({len(sparse_vecs)} vs {len(texts)}); storing dense-only"
797
+ )
798
+ sparse_vecs = None
799
+ except Exception as e:
800
+ log.warning(f"sparse encode failed; storing dense-only (backfill repairs): {e}")
801
+ sparse_vecs = None
802
+
624
803
  # Resolve per-record routing fields first so we can fan out the
625
804
  # extractor-sync calls in parallel. Each _extract is a network
626
805
  # round-trip; serialising them was the dominant cost in /store-batch
@@ -644,9 +823,9 @@ async def store_batch(req: StoreBatchRequest):
644
823
  ids: list[str] = []
645
824
  points: list[qmodels.PointStruct] = []
646
825
  provenance_rows: list[tuple] = []
647
- for (arena, clientId, userId, source_kind, content, meta), vec, event_id in zip(
826
+ for idx, ((arena, clientId, userId, source_kind, content, meta), vec, event_id) in enumerate(zip(
648
827
  resolved, embeddings, event_ids
649
- ):
828
+ )):
650
829
  vector_id = str(uuid.uuid4())
651
830
  provenance_rows.append((vector_id, event_id, "nv-embed-v2", EMBED_DIM))
652
831
  # See /store above — issue #345. Spread the caller's metadata
@@ -655,7 +834,13 @@ async def store_batch(req: StoreBatchRequest):
655
834
  # work with. Structural keys override on collision.
656
835
  points.append(qmodels.PointStruct(
657
836
  id=vector_id,
658
- vector=vec,
837
+ # BET 3: flag-off keeps the bare dense list (byte-identical
838
+ # to today); flag-on adds the named 'lex' sparse vector.
839
+ vector=(
840
+ vec
841
+ if sparse_vecs is None
842
+ else {"": vec, SPARSE_VECTOR_NAME: sparse_vecs[idx]}
843
+ ),
659
844
  payload={
660
845
  **(meta or {}),
661
846
  "event_id": event_id,
@@ -896,18 +1081,73 @@ async def search(req: SearchRequest):
896
1081
  # vector-payload bandwidth (4096 × float32 × overfetch) when
897
1082
  # vectors won't be used.
898
1083
  temporal_active = (intent == "temporal") and SEARCH_TEMPORAL_RERANK
899
- raw_results = await _qdrant.search(
900
- collection_name=COLLECTION_NAME,
901
- query_vector=qvec,
902
- query_filter=filter_,
903
- limit=max(overfetch, target_limit),
904
- score_threshold=req.min_score,
905
- with_payload=True,
906
- # Phase 3 (#343): MMR needs the actual vectors to score pairwise
907
- # similarity. Only pull them when MMR is enabled AND we aren't
908
- # about to skip MMR for a temporal re-rank.
909
- with_vectors=SEARCH_MMR_ENABLED and not temporal_active,
910
- )
1084
+ fetch_limit = max(overfetch, target_limit)
1085
+ # Phase 3 (#343): MMR needs the actual vectors to score pairwise
1086
+ # similarity. Only pull them when MMR is enabled AND we aren't
1087
+ # about to skip MMR for a temporal re-rank.
1088
+ fetch_vectors = SEARCH_MMR_ENABLED and not temporal_active
1089
+
1090
+ # ── BET 3: hybrid lexical+dense retrieval ────────────────────────
1091
+ # Flag on: encode the query with BM25 and replace the single dense
1092
+ # search() with a server-side RRF fusion over two prefetch legs
1093
+ # (dense on the unnamed '' vector, lexical on the named 'lex'
1094
+ # sparse vector). Qdrant runs both legs inside one request, fuses
1095
+ # by reciprocal rank (1/(k+rank), k=60), and the fused score lands
1096
+ # in r.score — everything downstream (dedup → intent boost →
1097
+ # MMR/temporal → quota → hydration) is untouched.
1098
+ #
1099
+ # ⚠️ SCORE-SCALE CAVEAT (recalibration required before default-on):
1100
+ # RRF scores live on a ~0.006–0.033 scale, NOT the cosine 0.7–0.85
1101
+ # scale the intent-boost magnitudes (+0.02…+0.08, see INTENT_BOOSTS)
1102
+ # were tuned against. With hybrid on, those additive boosts dominate
1103
+ # the fused ranking instead of nudging it. The flag-off default
1104
+ # protects prod until eval-harness numbers (eval/recall_at_k.py)
1105
+ # exist to recalibrate them. `min_score` is likewise a cosine-scale
1106
+ # knob, so it is NOT applied to the fused path.
1107
+ #
1108
+ # A sparse-encode failure (e.g. fastembed missing/model fetch
1109
+ # failed) logs and falls back to the legacy dense-only path —
1110
+ # /search availability never depends on the lexical leg.
1111
+ sparse_qvec: Any | None = None
1112
+ if SEARCH_HYBRID_ENABLED:
1113
+ try:
1114
+ sparse_qvec = await _sparse_encode_query(req.query)
1115
+ except Exception as e:
1116
+ log.warning(f"sparse query encode failed; dense-only fallback: {e}")
1117
+
1118
+ if sparse_qvec is not None:
1119
+ fused = await _qdrant.query_points(
1120
+ collection_name=COLLECTION_NAME,
1121
+ prefetch=[
1122
+ qmodels.Prefetch(
1123
+ query=qvec,
1124
+ using="", # the unnamed dense vector's internal name
1125
+ filter=filter_,
1126
+ limit=fetch_limit,
1127
+ ),
1128
+ qmodels.Prefetch(
1129
+ query=sparse_qvec,
1130
+ using=SPARSE_VECTOR_NAME,
1131
+ filter=filter_,
1132
+ limit=fetch_limit,
1133
+ ),
1134
+ ],
1135
+ query=qmodels.FusionQuery(fusion=qmodels.Fusion.RRF),
1136
+ limit=fetch_limit,
1137
+ with_payload=True,
1138
+ with_vectors=fetch_vectors,
1139
+ )
1140
+ raw_results = fused.points
1141
+ else:
1142
+ raw_results = await _qdrant.search(
1143
+ collection_name=COLLECTION_NAME,
1144
+ query_vector=qvec,
1145
+ query_filter=filter_,
1146
+ limit=fetch_limit,
1147
+ score_threshold=req.min_score,
1148
+ with_payload=True,
1149
+ with_vectors=fetch_vectors,
1150
+ )
911
1151
 
912
1152
  # (a) dedup by event_id — first occurrence wins (highest score).
913
1153
  seen_eids: set[str] = set()
@@ -0,0 +1,242 @@
1
+ #!/usr/bin/env python3
2
+ """Retrieval eval: recall@k / nDCG@k for hybrid (SEARCH_HYBRID_ENABLED=1)
3
+ vs baseline dense-only /search (roadmap BET 3).
4
+
5
+ The hybrid flag is a SERVER-side env var, not a request parameter, so a
6
+ flag-on/flag-off comparison needs either (a) two compat instances — one
7
+ with the flag on, one off — passed as --base-url-on/--base-url-off, or
8
+ (b) two separate runs against one instance while the operator flips the
9
+ flag, each labelled with --label and saved with --out, then compared
10
+ offline with --compare run_a.json run_b.json.
11
+
12
+ Stdlib-only (urllib) — runnable on the engine box or anywhere with HTTP
13
+ access to compat. This script makes NO calls until you point it at an
14
+ engine (--base-url*); CI never runs it. Usage:
15
+
16
+ # two instances side by side
17
+ python3 recall_at_k.py --golden retrieval_golden.seed.json \
18
+ --base-url-off http://127.0.0.1:8099 \
19
+ --base-url-on http://127.0.0.1:8098 \
20
+ --k 5 10 20
21
+
22
+ # one instance, two passes (operator flips SEARCH_HYBRID_ENABLED between)
23
+ python3 recall_at_k.py --golden ... --base-url http://127.0.0.1:8099 \
24
+ --label flag-off --out runs/off.json
25
+ python3 recall_at_k.py --golden ... --base-url http://127.0.0.1:8099 \
26
+ --label flag-on --out runs/on.json
27
+ python3 recall_at_k.py --compare runs/off.json runs/on.json
28
+
29
+ Metrics per question (and mean over questions):
30
+ recall@k — |relevant ∩ top-k| / |relevant|
31
+ nDCG@k — graded (relevance 2/1), log2 discount, normalised by the
32
+ ideal ordering of that question's judged set.
33
+ Questions whose `relevant` list still contains placeholders (or is
34
+ empty) are skipped and reported as unjudged.
35
+ """
36
+
37
+ from __future__ import annotations
38
+
39
+ import argparse
40
+ import json
41
+ import math
42
+ import sys
43
+ import urllib.error
44
+ import urllib.request
45
+
46
+ DEFAULT_KS = [5, 10, 20]
47
+
48
+
49
+ # ----------------------------------------------------------------------
50
+ # Metrics — pure functions, unit-testable without any engine.
51
+ # ----------------------------------------------------------------------
52
+
53
+
54
def recall_at_k(ranked_ids: list[str], relevant_ids: set[str], k: int) -> float:
    """Return recall@k: |relevant ∩ top-k| / |relevant|.

    Args:
        ranked_ids: Result ids in ranked order, best first.
        relevant_ids: Ids judged relevant for the question.
        k: Cutoff; only the first ``k`` ranked ids are considered.

    Returns:
        A value in [0.0, 1.0]. 0.0 when there are no relevant ids or
        ``k`` is not positive.
    """
    if not relevant_ids or k <= 0:
        return 0.0
    # Set intersection so an id repeated in the ranking is counted once;
    # counting every occurrence could push recall above 1.0.
    hits = len(set(ranked_ids[:k]) & set(relevant_ids))
    return hits / len(relevant_ids)
59
+
60
+
61
def dcg_at_k(ranked_ids: list[str], gains: dict[str, float], k: int) -> float:
    """Return the discounted cumulative gain of the top-k ranking.

    Each ranked id contributes ``gain / log2(position + 2)`` with a
    zero-based position, so the first result is undiscounted. Ids
    missing from ``gains`` contribute nothing.

    An id is credited only at its first occurrence: a repeated id still
    occupies its rank slot but earns no further gain, which keeps the
    normalised score from exceeding 1.0.
    """
    if k <= 0:
        return 0.0
    credited: set[str] = set()
    total = 0.0
    for pos, rid in enumerate(ranked_ids[:k]):
        if rid in credited:
            continue
        credited.add(rid)
        total += gains.get(rid, 0.0) / math.log2(pos + 2)  # pos=0 → log2(2)=1
    return total
66
+
67
+
68
def ndcg_at_k(ranked_ids: list[str], gains: dict[str, float], k: int) -> float:
    """Return nDCG@k: DCG of the ranking divided by the ideal DCG.

    The ideal DCG is computed from the ``k`` largest judged gains in
    descending order. When that ideal is not positive (no judged gains,
    or all zero) the result is 0.0.
    """
    best_gains = sorted(gains.values(), reverse=True)[:k]
    ideal_dcg = 0
    for pos, gain in enumerate(best_gains):
        ideal_dcg += gain / math.log2(pos + 2)
    if ideal_dcg <= 0:
        return 0.0
    return dcg_at_k(ranked_ids, gains, k) / ideal_dcg
74
+
75
+
76
def is_judged(question: dict) -> bool:
    """Return True when a golden question has usable relevance judgments.

    A question is judged only if its ``relevant`` list is non-empty and
    every entry has a non-empty ``event_id`` that does not contain the
    text "PLACEHOLDER". Entries without an ``event_id`` make the
    question unjudged rather than letting it through: evaluate_ranking
    indexes ``event_id`` directly and would raise KeyError on them.
    """
    rel = question.get("relevant") or []
    if not rel:
        return False
    for entry in rel:
        event_id = entry.get("event_id")
        # str() so a non-string id cannot break the substring test.
        if not event_id or "PLACEHOLDER" in str(event_id):
            return False
    return True
81
+
82
+
83
def evaluate_ranking(ranked_ids: list[str], question: dict, ks: list[int]) -> dict:
    """Score one ranking against a question's judged set at each cutoff.

    Every entry of the question's ``relevant`` list counts as relevant
    for recall; its ``relevance`` (default 1) is the graded gain used
    for nDCG.

    Returns:
        ``{"recall": {k: value}, "ndcg": {k: value}}`` keyed by each
        cutoff in ``ks``.
    """
    judged = question.get("relevant") or []
    relevant_ids = {entry["event_id"] for entry in judged}
    gains: dict[str, float] = {}
    for entry in judged:
        gains[entry["event_id"]] = float(entry.get("relevance", 1))

    recall_by_k = {k: recall_at_k(ranked_ids, relevant_ids, k) for k in ks}
    ndcg_by_k = {k: ndcg_at_k(ranked_ids, gains, k) for k in ks}
    return {"recall": recall_by_k, "ndcg": ndcg_by_k}
91
+
92
+
93
def summarize(per_question: list[dict], ks: list[int]) -> dict:
    """Average per-question recall and nDCG at each cutoff.

    Args:
        per_question: Entries carrying ``metrics`` with ``recall`` and
            ``ndcg`` dicts keyed by cutoff.
        ks: Cutoffs to report.

    Returns:
        A dict with ``n`` (number of questions) and the mean ``recall``
        and ``ndcg`` per cutoff; all means are 0.0 when there are no
        questions.
    """
    n = len(per_question)
    if n == 0:
        return {"recall": {k: 0.0 for k in ks}, "ndcg": {k: 0.0 for k in ks}, "n": 0}

    def mean_of(metric: str, k: int) -> float:
        return sum(q["metrics"][metric][k] for q in per_question) / n

    return {
        "n": n,
        "recall": {k: mean_of("recall", k) for k in ks},
        "ndcg": {k: mean_of("ndcg", k) for k in ks},
    }
107
+
108
+
109
+ # ----------------------------------------------------------------------
110
+ # Engine I/O
111
+ # ----------------------------------------------------------------------
112
+
113
+
114
def search(base_url: str, query: str, arena: str, limit: int, timeout: float = 30.0) -> list[str]:
    """POST one query to the engine's /search endpoint.

    Sends ``query``, ``arena`` and ``limit`` as a JSON body to
    ``<base_url>/search`` and returns the ``id`` of every entry in the
    response's ``results`` list, in response order (empty list when the
    key is absent). Network and HTTP errors from urllib propagate to
    the caller.
    """
    payload = {"query": query, "arena": arena, "limit": limit}
    request = urllib.request.Request(
        base_url.rstrip("/") + "/search",
        data=json.dumps(payload).encode(),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(request, timeout=timeout) as response:
        parsed = json.loads(response.read())
    hits = parsed.get("results", [])
    return [hit["id"] for hit in hits]
125
+
126
+
127
def run_pass(base_url: str, golden: dict, ks: list[int], label: str) -> dict:
    """Run every judged golden question against one engine and score it.

    Unjudged questions (see is_judged) are skipped and listed under
    ``unjudged``. A question whose search fails — network/HTTP error, a
    body that is not valid JSON, or a result without an ``id`` — is
    reported on stderr, listed under ``failed`` and left out of the
    metrics, so one bad response cannot abort the whole pass.

    Returns:
        A run dict with ``label``, ``base_url``, ``ks``, ``unjudged``,
        ``failed``, ``per_question`` and ``summary``.
    """
    max_k = max(ks)
    per_question = []
    unjudged = []
    failed = []
    for q in golden.get("questions", []):
        if not is_judged(q):
            unjudged.append(q.get("id"))
            continue
        arena = q.get("arena") or golden.get("default_arena")
        # Read outside the try so a golden file missing "query" still
        # fails loudly instead of being reported as a search failure.
        query = q["query"]
        try:
            ranked = search(base_url, query, arena, limit=max_k)
        except (urllib.error.URLError, OSError, ValueError, KeyError) as e:
            # ValueError covers json.JSONDecodeError; KeyError covers a
            # result entry without "id".
            print(f"  [{label}] {q['id']}: SEARCH FAILED: {e}", file=sys.stderr)
            failed.append(q["id"])
            continue
        m = evaluate_ranking(ranked, q, ks)
        per_question.append({"id": q["id"], "class": q.get("class"),
                             "ranked": ranked, "metrics": m})
    return {
        "label": label,
        "base_url": base_url,
        "ks": ks,
        "unjudged": unjudged,
        "failed": failed,
        "per_question": per_question,
        "summary": summarize(per_question, ks),
    }
152
+
153
+
154
def print_run(run: dict) -> None:
    """Print one run's summary table and per-question recall to stdout.

    Expects the run dict produced by run_pass (or loaded from a saved
    run): ``label``, ``base_url``, ``ks``, ``summary``, ``per_question``
    and optionally ``unjudged``.
    """
    ks = run["ks"]
    s = run["summary"]
    print(f"\n== {run['label']} ({run['base_url']}) — {s.get('n', 0)} judged questions ==")
    if run.get("unjudged"):
        # Ids come straight from the golden file and may be None (no
        # "id" key) or non-strings; str() keeps join from raising.
        skipped = ", ".join(str(qid) for qid in run["unjudged"])
        print(f" skipped (placeholders/empty): {skipped}")
    header = "metric " + "".join(f" @{k:<5}" for k in ks)
    print(header)
    print("recall " + "".join(f" {s['recall'][k]:.3f} " for k in ks))
    print("nDCG " + "".join(f" {s['ndcg'][k]:.3f} " for k in ks))
    for q in run["per_question"]:
        r = q["metrics"]
        print(f" {q['id']:<20} ({q.get('class') or '-':<8}) "
              + " ".join(f"R@{k}={r['recall'][k]:.2f}" for k in ks))
168
+
169
+
170
def print_comparison(off: dict, on: dict) -> None:
    """Print the per-cutoff metric deltas (``on`` minus ``off``).

    Only cutoffs present in both runs are compared, in the order of
    ``off["ks"]``; two saved runs produced with different ``--k`` values
    would otherwise raise KeyError on the cutoff one of them lacks.
    """
    on_ks = set(on["ks"])
    ks = [k for k in off["ks"] if k in on_ks]
    print(f"\n== Δ ({on['label']} − {off['label']}) ==")
    if not ks:
        print("no cutoffs in common between the two runs")
        return
    for name in ("recall", "ndcg"):
        deltas = "".join(
            f" {on['summary'][name][k] - off['summary'][name][k]:+.3f}" for k in ks
        )
        print(f"{name:<8}{deltas} (k = {ks})")
178
+
179
+
180
+ def _coerce_keys(run: dict) -> dict:
181
+ """JSON round-trip turns int dict keys into strings — undo that."""
182
+ run["ks"] = [int(k) for k in run["ks"]]
183
+ for scope in [run["summary"], *[q["metrics"] for q in run["per_question"]]]:
184
+ for name in ("recall", "ndcg"):
185
+ if name in scope:
186
+ scope[name] = {int(k): v for k, v in scope[name].items()}
187
+ return run
188
+
189
+
190
def main() -> int:
    """Command-line entry point; returns the process exit status.

    Three modes, chosen by the arguments:
      * ``--compare OFF_JSON ON_JSON`` — print two saved runs and their
        delta; makes no engine calls.
      * ``--base-url-off`` + ``--base-url-on`` — run the golden set
        against both engines and print the delta.
      * ``--base-url`` — run a single labelled pass.
    ``--out`` writes the run (or ``{"runs": [...]}`` for two runs) as
    JSON.
    """
    # Local import: only needed to prepare the --out directory.
    import os

    p = argparse.ArgumentParser(description=__doc__,
                                formatter_class=argparse.RawDescriptionHelpFormatter)
    p.add_argument("--golden", default=None, help="golden questions JSON")
    p.add_argument("--base-url", default=None, help="single engine base url")
    p.add_argument("--base-url-off", default=None, help="flag-OFF engine base url")
    p.add_argument("--base-url-on", default=None, help="flag-ON engine base url")
    p.add_argument("--k", type=int, nargs="+", default=DEFAULT_KS)
    p.add_argument("--label", default="run", help="label for single-pass mode")
    p.add_argument("--out", default=None, help="write run JSON here")
    p.add_argument("--compare", nargs=2, metavar=("OFF_JSON", "ON_JSON"),
                   help="compare two previously saved runs; no engine calls")
    args = p.parse_args()

    if args.compare:
        off_path, on_path = args.compare
        with open(off_path, encoding="utf-8") as f:
            off = _coerce_keys(json.load(f))
        with open(on_path, encoding="utf-8") as f:
            on = _coerce_keys(json.load(f))
        print_run(off)
        print_run(on)
        print_comparison(off, on)
        return 0

    if not args.golden:
        p.error("--golden required unless --compare")
    with open(args.golden, encoding="utf-8") as f:
        golden = json.load(f)

    ks = sorted(set(args.k))
    if ks[0] <= 0:
        # A non-positive cutoff has no meaning for recall@k / nDCG@k.
        p.error("--k values must be positive integers")

    if args.out:
        # Create the output directory up front (e.g. "runs/" in
        # "--out runs/off.json") so a missing directory fails before,
        # not after, every engine call has been made.
        out_dir = os.path.dirname(args.out)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

    runs = []
    if args.base_url_off and args.base_url_on:
        runs.append(run_pass(args.base_url_off, golden, ks, "flag-off"))
        runs.append(run_pass(args.base_url_on, golden, ks, "flag-on"))
    elif args.base_url:
        runs.append(run_pass(args.base_url, golden, ks, args.label))
    else:
        p.error("provide --base-url, or both --base-url-off and --base-url-on")

    for r in runs:
        print_run(r)
    if len(runs) == 2:
        print_comparison(runs[0], runs[1])

    if args.out:
        with open(args.out, "w", encoding="utf-8") as f:
            json.dump(runs[0] if len(runs) == 1 else {"runs": runs}, f, indent=2)
        print(f"\nwrote {args.out}")
    return 0
239
+
240
+
241
+ if __name__ == "__main__":
242
+ sys.exit(main())