structuremappingmemory 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (125) hide show
  1. sma/__init__.py +5 -0
  2. sma/__main__.py +5 -0
  3. sma/agent/__init__.py +5 -0
  4. sma/agent/adapter_draft.py +217 -0
  5. sma/agent/api.py +67 -0
  6. sma/agent/comparison.py +591 -0
  7. sma/agent/llm.py +280 -0
  8. sma/agent/policies.py +21 -0
  9. sma/agent/service.py +95 -0
  10. sma/cli.py +65 -0
  11. sma/encoders/__init__.py +38 -0
  12. sma/encoders/agentobs.py +27 -0
  13. sma/encoders/base.py +23 -0
  14. sma/encoders/code_treesitter.py +64 -0
  15. sma/encoders/coverage.py +80 -0
  16. sma/encoders/draft_adapter.py +183 -0
  17. sma/encoders/healthcare.py +207 -0
  18. sma/encoders/logs_drain.py +142 -0
  19. sma/encoders/prose_tier1.py +57 -0
  20. sma/encoders/structured.py +57 -0
  21. sma/encoders/traces.py +45 -0
  22. sma/eval/__init__.py +2 -0
  23. sma/eval/agentic/__init__.py +35 -0
  24. sma/eval/agentic/arms/__init__.py +0 -0
  25. sma/eval/agentic/arms/cyber.py +48 -0
  26. sma/eval/agentic/arms/discovery.py +35 -0
  27. sma/eval/agentic/arms/finance.py +38 -0
  28. sma/eval/agentic/arms/legal.py +74 -0
  29. sma/eval/agentic/arms/medicine.py +45 -0
  30. sma/eval/agentic/harness.py +275 -0
  31. sma/eval/agentic/memories.py +308 -0
  32. sma/eval/agentic/metrics.py +82 -0
  33. sma/eval/agentic_qa/__init__.py +27 -0
  34. sma/eval/agentic_qa/agent.py +383 -0
  35. sma/eval/agentic_qa/metrics.py +239 -0
  36. sma/eval/agentic_qa/pools.py +197 -0
  37. sma/eval/arn.py +65 -0
  38. sma/eval/baselines/__init__.py +6 -0
  39. sma/eval/baselines/bge_dense.py +54 -0
  40. sma/eval/baselines/bm25.py +18 -0
  41. sma/eval/baselines/dense.py +42 -0
  42. sma/eval/baselines/hipporag.py +235 -0
  43. sma/eval/baselines/hybrid_rrf.py +30 -0
  44. sma/eval/baselines/longcontext_llm.py +124 -0
  45. sma/eval/baselines/rerank.py +41 -0
  46. sma/eval/baselines/splade.py +77 -0
  47. sma/eval/baselines/wl_kernel.py +163 -0
  48. sma/eval/bugsinpy.py +358 -0
  49. sma/eval/bugsinpy_families.py +164 -0
  50. sma/eval/crossdomain.py +89 -0
  51. sma/eval/diabetes.py +61 -0
  52. sma/eval/drift_env.py +26 -0
  53. sma/eval/drift_metrics.py +24 -0
  54. sma/eval/family_labels.py +167 -0
  55. sma/eval/fraud_elliptic/__init__.py +29 -0
  56. sma/eval/fraud_elliptic/encoder.py +279 -0
  57. sma/eval/fraud_elliptic/eval.py +269 -0
  58. sma/eval/fraud_elliptic/test_encoder.py +123 -0
  59. sma/eval/ieee_cis.py +66 -0
  60. sma/eval/loghub.py +16 -0
  61. sma/eval/loghub_eval.py +480 -0
  62. sma/eval/longmemeval.py +51 -0
  63. sma/eval/memory_backends/__init__.py +2 -0
  64. sma/eval/memory_backends/base.py +22 -0
  65. sma/eval/memory_backends/context_only.py +14 -0
  66. sma/eval/memory_backends/rag_notes.py +17 -0
  67. sma/eval/memory_backends/shared_llm.py +30 -0
  68. sma/eval/memory_backends/sma_memory.py +54 -0
  69. sma/eval/memory_backends/zep_graphiti.py +33 -0
  70. sma/eval/metrics.py +32 -0
  71. sma/eval/ontology_bench.py +219 -0
  72. sma/eval/report.py +573 -0
  73. sma/eval/ssb_eval.py +216 -0
  74. sma/eval/ssb_generator.py +116 -0
  75. sma/eval/stats.py +108 -0
  76. sma/eval/transfer_eval.py +844 -0
  77. sma/index/__init__.py +15 -0
  78. sma/index/ann.py +21 -0
  79. sma/index/content_vectors.py +60 -0
  80. sma/index/inverted.py +63 -0
  81. sma/index/macfac.py +174 -0
  82. sma/ir/__init__.py +22 -0
  83. sma/ir/canon.py +106 -0
  84. sma/ir/schema.py +165 -0
  85. sma/ir/sexpr.py +86 -0
  86. sma/ir/signatures.py +76 -0
  87. sma/match/__init__.py +20 -0
  88. sma/match/conflicts.py +46 -0
  89. sma/match/engine.py +60 -0
  90. sma/match/explain.py +59 -0
  91. sma/match/infer.py +54 -0
  92. sma/match/kernels.py +54 -0
  93. sma/match/mdl.py +30 -0
  94. sma/match/merge_cpsat.py +77 -0
  95. sma/match/merge_greedy.py +15 -0
  96. sma/match/mh.py +177 -0
  97. sma/match/ses.py +84 -0
  98. sma/match/types.py +115 -0
  99. sma/match/verifier.py +27 -0
  100. sma/ontology/__init__.py +45 -0
  101. sma/ontology/attack.py +134 -0
  102. sma/ontology/cpc.py +69 -0
  103. sma/ontology/graph.py +58 -0
  104. sma/ontology/loader.py +262 -0
  105. sma/ontology/mitre_xml.py +67 -0
  106. sma/ontology/mount.py +101 -0
  107. sma/ontology/rdf_loader.py +75 -0
  108. sma/ontology/registry.py +115 -0
  109. sma/ontology/router.py +69 -0
  110. sma/ontology/usgaap.py +73 -0
  111. sma/sage/__init__.py +6 -0
  112. sma/sage/assimilate.py +12 -0
  113. sma/sage/pools.py +105 -0
  114. sma/sage/probabilities.py +10 -0
  115. sma/store/__init__.py +6 -0
  116. sma/store/lmdb_store.py +78 -0
  117. sma/store/registry.py +26 -0
  118. sma/store/wal.py +26 -0
  119. sma/ui/app.py +642 -0
  120. structuremappingmemory-1.0.0.dist-info/METADATA +190 -0
  121. structuremappingmemory-1.0.0.dist-info/RECORD +125 -0
  122. structuremappingmemory-1.0.0.dist-info/WHEEL +5 -0
  123. structuremappingmemory-1.0.0.dist-info/entry_points.txt +2 -0
  124. structuremappingmemory-1.0.0.dist-info/licenses/LICENSE +204 -0
  125. structuremappingmemory-1.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,383 @@
1
+ """The one-shot QA agent for the Phase 5 LLM-QA "trustworthy specialist" phase.
2
+
3
+ A single :class:`QAAgent` holds the LLM and prompt FIXED and swaps only the
4
+ retrieval ``Memory`` (none / dense-RAG / SMA), exactly as registered in
5
+ ``configs/preregistration_v2_llmqa.md`` section 2. For each
6
+ :class:`~sma.eval.agentic_qa.pools.QAItem` it runs one agent turn and returns a
7
+ result dict carrying every field the trustworthy-QA metrics read
8
+ (``sma.eval.agentic_qa.metrics``): ``gold_id``, ``gold_name``, ``answerable``,
9
+ ``novel``, ``abstained``, ``pred_id``, ``answer``, ``novelty_flag``,
10
+ ``confidence``, ``grounding_score``.
11
+
12
+ Two grounding regimes:
13
+
14
+ * **grounded** (a memory is given) — retrieve top-k candidates, render them as a
15
+ numbered list, and ask the LLM for a strict one-line JSON ``{"choice": <n>}``
16
+ where ``n`` is a candidate number or ``0`` to abstain. ``pred_id`` is the
17
+ chosen candidate's key (the disease id), so correctness/citation can be checked
18
+ structurally against the gold. When a calibrated ``score_threshold`` is given,
19
+ a case whose top RAW grounding score falls below it is abstained AND flagged
20
+ novel *before* the LLM call (the structural score, not the saturated confidence
21
+ or the expectation-violation flag, is what separates known from unknown); with
22
+ no threshold the novelty flag falls back to ``memory.novelty(query)``.
23
+ * **closed-book** (``memory is None``) — the LLM answers from the case alone with
24
+ a strict one-line JSON ``{"diagnosis": "<name or ABSTAIN>"}``; ``pred_id`` is
25
+ ``None`` (no citation), ``confidence`` is a flat ``0.5``, and novelty is N/A.
26
+
27
+ JSON parsing is defensive (strips ```` ``` ```` code fences, scans for the first
28
+ ``{...}`` object) and falls back to ABSTAIN on any parse/validation failure, so a
29
+ malformed model reply degrades to the safe action rather than crashing the run.
30
+ """
31
+
32
+ from __future__ import annotations
33
+
34
+ import json
35
+ import re
36
+ from typing import Any, Protocol
37
+
38
+ from sma.eval.agentic import Query
39
+ from sma.eval.agentic_qa.pools import QAItem
40
+
41
# Upper bound on the characteristic feature names rendered per candidate. It
# keeps the prompt short; the LLM refers to a candidate by its number, so the
# feature list only needs to be indicative, not exhaustive.
_FEATURES_PER_CANDIDATE = 6

# System prompt for the grounded regime: the model must pick a numbered
# candidate (or 0 to abstain) and reply with one-line JSON.
SYSTEM_PROMPT = (
    "You are a careful diagnostic assistant. "
    "You are given a clinical case and a numbered list of candidate diseases "
    "retrieved from a grounded knowledge base, each with a few of its "
    "characteristic features. "
    "Choose the single candidate whose characteristic features best match the case. "
    "Answer ONLY when a candidate genuinely grounds the case; "
    "if none of the candidates fit, abstain. "
    "Reply with STRICT one-line JSON and nothing else: "
    '{"choice": <candidate number, or 0 for none / abstain>}.'
)

# System prompt for the closed-book regime (no memory): the model names a
# disease in free text, or the ABSTAIN sentinel, as one-line JSON.
CLOSED_BOOK_SYSTEM_PROMPT = (
    "You are a careful diagnostic assistant. "
    "You are given a clinical case and no external knowledge. "
    "Name the single most likely disease, or abstain if you are not confident. "
    "Reply with STRICT one-line JSON and nothing else: "
    '{"diagnosis": "<disease name, or ABSTAIN>"}.'
)

# Sentinel answer string recorded whenever the agent abstains.
ABSTAIN = "ABSTAIN"
63
+
64
+
65
class LLM(Protocol):
    """Structural interface of the fixed chat-completion backend.

    Any object with a compatible ``complete`` method satisfies it (the
    ``DeepSeekOrchestrator`` in real runs, a mock such as :class:`MockLLM` in
    tests); no subclassing is required.
    """

    def complete(
        self, messages: list[dict], max_tokens: int = 600, temperature: float = 0.0
    ) -> str:
        """Return the model's reply text for the chat ``messages``."""
        ...
71
+
72
+
73
class MockLLM:
    """Scripted, offline replacement for the real LLM (never calls DeepSeek).

    Lets the tests and the ``--mock`` driver exercise the whole harness at zero
    API cost. Every ``messages`` list received is appended to ``calls``.

    * ``raw`` — when not ``None``, returned verbatim from every call (useful to
      exercise the defensive JSON parsing).
    * ``diagnosis`` — returned as ``{"diagnosis": ...}`` when the first
      message's content mentions ``diagnosis`` (the closed-book prompt).
    * ``choice`` — otherwise returned as ``{"choice": ...}`` (the grounded
      prompt); ``choice=0`` scripts an abstention.
    """

    def __init__(
        self,
        choice: int = 1,
        diagnosis: str = "Mock disease",
        raw: str | None = None,
    ):
        self.calls: list[list[dict]] = []
        self.raw = raw
        self.diagnosis = diagnosis
        self.choice = choice

    def complete(
        self, messages: list[dict], max_tokens: int = 600, temperature: float = 0.0
    ) -> str:
        """Record ``messages`` and return the scripted reply.

        ``max_tokens`` and ``temperature`` exist only for interface parity with
        the real backend and are ignored.
        """
        self.calls.append(messages)
        if self.raw is not None:
            return self.raw
        system_text = messages[0]["content"] if messages else ""
        # The closed-book system prompt asks for a "diagnosis" key; the
        # grounded one asks for "choice".
        key, value = (
            ("diagnosis", self.diagnosis)
            if "diagnosis" in system_text
            else ("choice", self.choice)
        )
        return json.dumps({key: value})
105
+
106
+
107
def _strip_fences(text: str) -> str:
    """Remove a surrounding Markdown code fence, if present, and trim whitespace.

    Handles replies like ```` ```json\\n{...}\\n``` ````: when the trimmed text
    starts with a fence, the opening fence (plus an optional alphanumeric
    language tag) and a trailing closing fence are removed. Other text is only
    whitespace-stripped.
    """
    body = text.strip()
    if not body.startswith("```"):
        return body
    without_open = re.sub(r"^```[a-zA-Z0-9]*\s*", "", body)
    without_close = re.sub(r"\s*```$", "", without_open.strip())
    return without_close.strip()
115
+
116
+
117
def _parse_json_object(text: str) -> dict | None:
    """Extract a single JSON object from a model reply that may contain noise.

    The fence-stripped reply is decoded as a whole first; failing that, the
    ``{...}`` span located by ``_first_brace_object`` is tried. The first
    attempt that decodes to a ``dict`` is returned; ``None`` means neither did.
    """
    cleaned = _strip_fences(text)
    attempts = (cleaned, _first_brace_object(cleaned))
    # filter(None, ...) skips the empty string and a missing brace span.
    for attempt in filter(None, attempts):
        try:
            parsed = json.loads(attempt)
        except ValueError:  # json.JSONDecodeError subclasses ValueError
            continue
        if isinstance(parsed, dict):
            return parsed
    return None
134
+
135
+
136
def _first_brace_object(text: str) -> str | None:
    """Return the first ``{...}`` substring of ``text`` that is a JSON object.

    Each ``{`` is tried in order with ``json.JSONDecoder.raw_decode``. As a
    result, braces inside JSON string values (``{"a": "}"}``) do not end the
    object early. A non-JSON brace group that precedes the real object (as in
    ``{fever} ... {"choice": 2}``) is skipped instead of shadowing it.

    If no ``{`` starts a decodable object, fall back to the first
    brace-balanced span, counted naively (string contents are not considered),
    which may not be valid JSON. Return ``None`` when ``text`` has no ``{`` or
    the first ``{`` is never balanced.
    """
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start >= 0:
        try:
            # A successful decode starting at "{" always yields a dict.
            _, end = decoder.raw_decode(text, start)
        except ValueError:  # json.JSONDecodeError subclasses ValueError
            start = text.find("{", start + 1)
        else:
            return text[start:end]

    # Fallback: first naively balanced {...} span (may not parse as JSON).
    start = text.find("{")
    if start < 0:
        return None
    depth = 0
    for i, c in enumerate(text[start:], start):
        if c == "{":
            depth += 1
        elif c == "}":
            depth -= 1
            if depth == 0:
                return text[start : i + 1]
    return None
151
+
152
+
153
class QAAgent:
    """One-shot retrieve-then-answer agent with a swappable retrieval memory.

    ``memory`` is one of the frozen ``Memory`` retrievers (``SmaMemory`` /
    ``DenseMemory`` / ...) or ``None`` for the closed-book condition.
    ``key_to_name`` / ``key_to_terms`` map an :class:`IndexItem` key (disease
    id) to its display name and its ontology term ids, used to render the
    numbered candidate list; pass the same maps that back the indexed
    knowledge. ``k`` is the retrieval depth.

    ``novelty_threshold`` is the cut applied to ``memory.novelty(query)`` to set
    the novelty flag (only meaningful for SMA). It is consulted only when
    ``score_threshold`` is ``None``. ``score_threshold``, when given, is the
    calibrated minimum top RAW grounding score. A case scoring below it is
    abstained and flagged novel without an LLM call; a case at or above it is
    never flagged novel.
    """

    def __init__(
        self,
        llm: LLM,
        memory: Any | None,
        *,
        key_to_name: dict[str, str] | None = None,
        key_to_terms: dict[str, frozenset[str]] | None = None,
        k: int = 5,
        novelty_threshold: float = 0.5,
        score_threshold: float | None = None,
    ):
        self.llm = llm
        self.memory = memory
        self.key_to_name = key_to_name or {}
        self.key_to_terms = key_to_terms or {}
        self.k = k
        self.novelty_threshold = novelty_threshold
        # Calibrated cite-or-abstain: the RAW structural grounding score (not
        # the saturated normalized confidence, nor the expectation-violation
        # flag — both of which fail to separate known/unknown, AUROC~0.48) is
        # the abstention signal. Below this threshold the memory has no
        # grounding -> abstain + flag novel, WITHOUT spending an LLM call.
        # None = no gate (LLM-only abstention).
        self.score_threshold = score_threshold

    # -- rendering ----------------------------------------------------------
    def _feature_text(self, key: str) -> str:
        """Comma-join the names of up to ``_FEATURES_PER_CANDIDATE`` terms of ``key``.

        Terms are taken in sorted term-id order; empty names are dropped.
        """
        terms = sorted(self.key_to_terms.get(key, frozenset()))
        names = [self._term_name(t) for t in terms[:_FEATURES_PER_CANDIDATE]]
        return ", ".join(n for n in names if n)

    def _term_name(self, term_id: str) -> str:
        """Resolve a term id to a human name via the memory's mounted ontology.

        Falls back to ``term_id`` itself when the memory has no ``mounted``
        attribute, the term is unknown, or its name is empty.
        """
        mounted = getattr(self.memory, "mounted", None)
        if mounted is not None:
            term = mounted.graph.terms.get(term_id)
            if term is not None and term.name:
                return term.name
        return term_id

    def _render_candidates(self, retrieved: list) -> tuple[str, list[str]]:
        """Build the numbered candidate block and the parallel key list.

        Returns ``(text, keys)`` where ``keys[i]`` is the disease id of candidate
        ``i + 1`` (so a parsed ``{"choice": n}`` maps to ``keys[n - 1]``).
        """
        lines: list[str] = []
        keys: list[str] = []
        for i, r in enumerate(retrieved, 1):
            keys.append(r.key)
            name = self.key_to_name.get(r.key, r.key)
            features = self._feature_text(r.key)
            feat = f" -- characteristic features: {features}" if features else ""
            lines.append(f"[{i}] {name}{feat}")
        return "\n".join(lines), keys

    # -- answer -------------------------------------------------------------
    def answer(self, item: QAItem) -> dict:
        """Run one agent turn over ``item`` and return the metrics result dict.

        Dispatches to the closed-book path when ``memory`` is ``None`` and to
        the grounded (retrieve-then-choose) path otherwise.
        """
        if self.memory is None:
            return self._answer_closed_book(item)
        return self._answer_grounded(item)

    def _result(
        self,
        item: QAItem,
        *,
        abstained: bool,
        pred_id: str | None,
        answer: str,
        novelty_flag: bool,
        confidence: float,
        grounding_score: float | None,
    ) -> dict:
        """Assemble the per-item result dict the trustworthy-QA metrics read."""
        return {
            "gold_id": item.gold_id,
            "gold_name": item.gold_name,
            "answerable": item.answerable,
            "novel": item.novel,
            "abstained": abstained,
            "pred_id": pred_id,
            "answer": answer,
            "novelty_flag": novelty_flag,
            "confidence": confidence,
            # The RAW top structural grounding score (None closed-book). This is
            # the signal that actually separates known from unknown; the metrics
            # use it for threshold-free discrimination AUROC.
            "grounding_score": grounding_score,
        }

    def _answer_grounded(self, item: QAItem) -> dict:
        """Retrieve top-``k`` candidates, optionally gate on score, then ask the LLM.

        With no retrieval hits, confidence and grounding score are ``0.0``.
        """
        query = Query(item.case_terms, item.case_text)
        retrieved = self.memory.retrieve(query, self.k)
        confidence = retrieved[0].confidence if retrieved else 0.0
        grounding_score = retrieved[0].score if retrieved else 0.0

        # Calibrated cite-or-abstain. If the top RAW grounding score is below the
        # validation-calibrated threshold, the memory does not structurally ground
        # this case -> ABSTAIN and FLAG NOVEL, WITHOUT spending an LLM call. The
        # raw structural match score is the discriminating signal (answerable vs
        # out-of-knowledge AUROC ~0.84); the squashed confidence (top hit always
        # ~1.0) and the expectation-violation flag are not (AUROC ~0.48). A None
        # threshold disables the gate -> pure LLM-mediated abstention (legacy).
        if self.score_threshold is not None and grounding_score < self.score_threshold:
            return self._result(
                item,
                abstained=True,
                pred_id=None,
                answer=ABSTAIN,
                novelty_flag=True,
                confidence=confidence,
                grounding_score=grounding_score,
            )

        candidates_text, keys = self._render_candidates(retrieved)

        # With a calibrated gate, the structural signal IS the novelty signal:
        # above threshold here -> not flagged. Without a gate, fall back to the
        # memory's own expectation-violation novelty vs novelty_threshold.
        if self.score_threshold is not None:
            novelty_flag = False
        else:
            novelty_flag = bool(self.memory.novelty(query) > self.novelty_threshold)

        user = (
            f"Clinical case:\n{item.case_text}\n\n"
            f"Candidate diseases:\n{candidates_text or '(none retrieved)'}\n\n"
            "Rule: choose the candidate whose characteristic features best match "
            "the case; answer only if a candidate genuinely grounds the case, "
            "otherwise choose 0 to abstain.\n"
            'Reply with STRICT one-line JSON: {"choice": <candidate number or 0>}.'
        )
        reply = self.llm.complete(
            [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": user},
            ],
            max_tokens=600,
            temperature=0.0,
        )
        # Guaranteed to be in 0..len(keys), so keys[choice - 1] below is safe.
        choice = self._parse_choice(reply, n_candidates=len(keys))

        if choice == 0:
            pred_id: str | None = None
            answer = ABSTAIN
            abstained = True
        else:
            pred_id = keys[choice - 1]
            answer = self.key_to_name.get(pred_id, pred_id)
            abstained = False

        return self._result(
            item,
            abstained=abstained,
            pred_id=pred_id,
            answer=answer,
            novelty_flag=novelty_flag,
            confidence=confidence,
            grounding_score=grounding_score,
        )

    def _answer_closed_book(self, item: QAItem) -> dict:
        """Ask the LLM for a free-text diagnosis with no retrieval.

        The result has no citation (``pred_id`` None), a flat ``0.5``
        confidence, no novelty flag, and no grounding score.
        """
        user = (
            f"Clinical case:\n{item.case_text}\n\n"
            "Name the single most likely disease, or abstain if not confident.\n"
            'Reply with STRICT one-line JSON: {"diagnosis": "<disease name or ABSTAIN>"}.'
        )
        reply = self.llm.complete(
            [
                {"role": "system", "content": CLOSED_BOOK_SYSTEM_PROMPT},
                {"role": "user", "content": user},
            ],
            max_tokens=600,
            temperature=0.0,
        )
        diagnosis = self._parse_diagnosis(reply)
        # Case-insensitive so "abstain" / "Abstain" also count as abstaining.
        abstained = diagnosis.strip().upper() == ABSTAIN
        answer = ABSTAIN if abstained else diagnosis

        return self._result(
            item,
            abstained=abstained,
            pred_id=None,
            answer=answer,
            novelty_flag=False,
            confidence=0.5,
            grounding_score=None,
        )

    # -- parsing ------------------------------------------------------------
    @staticmethod
    def _parse_choice(reply: str, *, n_candidates: int) -> int:
        """Parse ``{"choice": n}`` -> int in ``0..n_candidates``; abstain on failure.

        Any parse error, missing/ill-typed ``choice``, or out-of-range index
        collapses to ``0`` (abstain), the safe action.
        """
        obj = _parse_json_object(reply)
        if obj is None or "choice" not in obj:
            return 0
        return QAAgent._coerce_choice(obj["choice"], n_candidates=n_candidates)

    @staticmethod
    def _coerce_choice(value: Any, *, n_candidates: int) -> int:
        """Validate a raw ``choice`` value; return it as an int, or ``0`` (abstain).

        Accepted: an integer, an integral float (``2.0``), or a string ``int()``
        can parse (``"2"``), in each case within ``0..n_candidates``. Everything
        else collapses to ``0``:

        * booleans, since ``int(True) == 1`` would silently pick candidate 1;
        * non-integral or non-finite floats, since ``int(2.7)`` truncates and
          ``int(float("inf"))`` raises ``OverflowError`` (``json.loads``
          accepts ``Infinity`` and ``NaN``);
        * anything ``int()`` rejects with ``TypeError`` / ``ValueError``.
        """
        if isinstance(value, bool):
            return 0
        if isinstance(value, float) and not value.is_integer():
            return 0
        try:
            choice = int(value)
        except (TypeError, ValueError, OverflowError):
            return 0
        if not 0 <= choice <= n_candidates:
            return 0
        return choice

    @staticmethod
    def _parse_diagnosis(reply: str) -> str:
        """Parse ``{"diagnosis": "..."}`` -> stripped str; ``ABSTAIN`` on failure.

        An unparseable reply, a missing or non-string ``diagnosis``, or a blank
        one all yield ``ABSTAIN``.
        """
        obj = _parse_json_object(reply)
        if obj is None:
            return ABSTAIN
        value = obj.get("diagnosis")
        if not isinstance(value, str) or not value.strip():
            return ABSTAIN
        return value.strip()
@@ -0,0 +1,239 @@
1
+ """Trustworthy-QA metrics for the Phase 5 LLM-QA harness (prereg v2 section 4).
2
+
3
+ Given per-item agent results, compute the trustworthy-QA axes that
4
+ distinguish a *verifiable specialist* from a confident-but-opaque RAG agent:
5
+
6
+ * :func:`accuracy` — answer correct on the **answerable** pool (the accuracy
7
+ floor; the capability gains must not cost accuracy).
8
+ * :func:`citation_faithfulness` — ALCE-style support score over **answered
9
+ answerable** items: did the cited candidate actually turn out to be the gold?
10
+ N/A (``None``) for the closed-book condition, which has no citation.
11
+ * :func:`abstention` — selective prediction over the union of **answerable**
12
+ (should answer) and **held-out / out-of-knowledge** (should abstain):
13
+ abstain-recall, false-abstain, selective-accuracy, plus the risk-coverage AURC
14
+ with confidence ``= 1 - abstain_flag``.
15
+ * :func:`grounding_auroc` — threshold-free discrimination of the RAW grounding
16
+ score: AUROC for separating answerable (high score) from held-out (low score).
17
+ The intrinsic "can the memory tell known from unknown" signal, independent of
18
+ where the abstention threshold sits.
19
+ * :func:`novelty_recall` / :func:`novelty_f1` — recall (and precision/F1 against
20
+ answerable false-alarms) of the novelty flag over the **novel** pool.
21
+
22
+ A result is a simple dict or object exposing: ``gold_id``, ``gold_name``,
23
+ ``answerable``, ``novel``, ``abstained`` (bool), ``pred_id`` (str | None),
24
+ ``answer`` (str), ``novelty_flag`` (bool), ``confidence`` (float),
25
+ ``grounding_score`` (float | None). The data have two disjoint groups:
26
+ **answerable** (the gold disease IS indexed) and **held-out** (the gold disease
27
+ is NOT indexed). A held-out case is simultaneously out-of-knowledge (the agent
28
+ should ABSTAIN) and novel (the agent should FLAG it) — both correct trustworthy
29
+ behaviours on the same unindexed case — so abstention and novelty are scored on
30
+ the same held-out items (``answerable == False``, ``novel == True``).
31
+ """
32
+
33
+ from __future__ import annotations
34
+
35
+ from typing import Any
36
+
37
+ from sma.eval.agentic.metrics import risk_coverage_aurc
38
+
39
+
40
def _get(result: Any, field: str, default: Any = None) -> Any:
    """Return ``field`` from ``result``, which may be a dict or any object.

    Dicts are read with ``.get``; other objects with ``getattr``. A missing
    field yields ``default`` either way.
    """
    if not isinstance(result, dict):
        return getattr(result, field, default)
    return result.get(field, default)
45
+
46
+
47
def _correct(result: Any) -> bool:
    """Return whether the agent named the gold entity.

    A cited answer (``pred_id`` not ``None``) is judged by exact equality with
    ``gold_id``. An uncited answer (the closed-book condition) falls back to a
    case-insensitive substring test, in either direction, between the stripped
    free-text ``answer`` and ``gold_name``. An empty answer or gold name is
    never correct.
    """
    cited = _get(result, "pred_id")
    if cited is not None:
        return cited == _get(result, "gold_id")
    said = (_get(result, "answer") or "").strip().lower()
    expected = (_get(result, "gold_name") or "").strip().lower()
    if not (said and expected):
        return False
    return expected in said or said in expected
63
+
64
+
65
def accuracy(results: list[Any]) -> float:
    """Share of **answerable** items that were answered (not abstained) correctly.

    Every answerable item is in the denominator, so abstentions count as
    misses. Returns ``0.0`` when no item is answerable (no division by zero).
    """
    total = 0
    hits = 0
    for r in results:
        if not _get(r, "answerable"):
            continue
        total += 1
        if not _get(r, "abstained") and _correct(r):
            hits += 1
    return hits / total if total else 0.0
77
+
78
+
79
def citation_faithfulness(results: list[Any]) -> float | None:
    """ALCE-style support score: how often an answered citation is the gold.

    Only **answerable** items that were answered (not abstained) *and* carry a
    citation (``pred_id`` not ``None``) are eligible. Among those, return the
    fraction with ``pred_id == gold_id``; uncited items are ignored. Returns
    ``None`` (N/A) when nothing is eligible — e.g. the closed-book condition,
    which never cites.
    """
    eligible = 0
    supported = 0
    for r in results:
        cited_id = _get(r, "pred_id")
        if not _get(r, "answerable") or _get(r, "abstained") or cited_id is None:
            continue
        eligible += 1
        if cited_id == _get(r, "gold_id"):
            supported += 1
    if eligible == 0:
        return None
    return supported / eligible
99
+
100
+
101
def abstention(results: list[Any]) -> dict[str, Any]:
    """Selective-prediction scores over answerable (should answer) + held-out (should abstain).

    Every **held-out** item (``not answerable``) is treated as out-of-knowledge
    ("ook"), so the right move on it is to abstain. The returned dict holds:

    * ``abstain_recall`` — share of ook items that abstained;
    * ``false_abstain`` — share of answerable items that abstained anyway;
    * ``selective_accuracy`` — over all items, the share that made the right
      move: answered correctly (answerable) or abstained (ook);
    * ``aurc`` / ``rc_points`` — from ``risk_coverage_aurc`` over the same
      items, with confidence ``1.0`` for an answer and ``0.0`` for an
      abstention. An item is correct only if it was answered and right, so an
      abstention is never correct for the risk curve.

    Fractions over an empty pool are ``0.0``.
    """
    answerable = [r for r in results if _get(r, "answerable")]
    ook = [r for r in results if not _get(r, "answerable")]

    abstain_recall = (
        sum(1 for r in ook if _get(r, "abstained")) / len(ook) if ook else 0.0
    )
    false_abstain = (
        sum(1 for r in answerable if _get(r, "abstained")) / len(answerable)
        if answerable
        else 0.0
    )

    right_moves = 0
    confidences: list[float] = []
    correct: list[bool] = []
    # Answerable items first, then ook items.
    for pool, should_answer in ((answerable, True), (ook, False)):
        for r in pool:
            abstained = bool(_get(r, "abstained"))
            answered_correct = (not abstained) and _correct(r)
            if (answered_correct if should_answer else abstained):
                right_moves += 1
            confidences.append(0.0 if abstained else 1.0)
            correct.append(answered_correct)

    n_total = len(answerable) + len(ook)
    selective_accuracy = right_moves / n_total if n_total else 0.0
    aurc, rc_points = risk_coverage_aurc(confidences, correct)

    return {
        "abstain_recall": abstain_recall,
        "false_abstain": false_abstain,
        "selective_accuracy": selective_accuracy,
        "aurc": aurc,
        "rc_points": rc_points,
    }
157
+
158
+
159
def novelty_recall(results: list[Any]) -> float:
    """Recall of the novelty flag: the flagged share of **novel** items.

    Returns ``0.0`` when no item is novel (no division by zero).
    """
    flags = [bool(_get(r, "novelty_flag")) for r in results if _get(r, "novel")]
    return sum(flags) / len(flags) if flags else 0.0
169
+
170
+
171
def novelty_f1(results: list[Any]) -> dict[str, float]:
    """Precision / recall / F1 / FPR of the novelty flag.

    **Novel** (held-out) items are the positives and **answerable** (indexed)
    items the negatives, so a flag on an answerable case is a false alarm.
    This penalises an agent that flags everything: recall 1.0 is cheap, F1 is
    not. ``fpr`` is the false-positive rate on the answerable pool. Any
    fraction whose denominator is zero is ``0.0``.
    """
    positives = [r for r in results if _get(r, "novel")]
    negatives = [r for r in results if _get(r, "answerable")]

    tp = sum(1 for r in positives if _get(r, "novelty_flag"))
    fp = sum(1 for r in negatives if _get(r, "novelty_flag"))
    flagged = tp + fp

    recall = tp / len(positives) if positives else 0.0
    precision = tp / flagged if flagged else 0.0
    denom = precision + recall
    f1 = 2 * precision * recall / denom if denom else 0.0
    fpr = fp / len(negatives) if negatives else 0.0
    return {"precision": precision, "recall": recall, "f1": f1, "fpr": fpr}
196
+
197
+
198
def auroc(pos: list[float], neg: list[float]) -> float:
    """Rank-based (Mann-Whitney U) AUROC that ``pos`` scores exceed ``neg`` scores.

    Concordant pairs count 1, ties 0.5. ``pos`` are the should-answer
    (high-score) class, ``neg`` the should-abstain (low-score) class. Both must
    be non-empty; an empty side raises ``ZeroDivisionError``. Scores must not
    be NaN. Reused by the driver to score the calibration split and by
    :func:`grounding_auroc` to score the test split.

    Runs in ``O((|pos| + |neg|) log |neg|)`` rather than comparing every pair.
    ``neg`` is sorted once; for each positive score, binary search counts the
    negatives strictly below it (wins) and equal to it (ties). The result is
    exactly the pairwise count, since all partial sums are multiples of 0.5.
    """
    from bisect import bisect_left, bisect_right

    ranked_neg = sorted(neg)
    wins = 0.0
    for p in pos:
        below = bisect_left(ranked_neg, p)
        ties = bisect_right(ranked_neg, p) - below
        wins += below + 0.5 * ties
    return wins / (len(pos) * len(neg))
214
+
215
+
216
def grounding_auroc(results: list[Any]) -> float | None:
    """Threshold-free AUROC of the RAW grounding score: answerable vs held-out.

    Measures how well the top structural grounding score ranks **answerable**
    items (gold disease indexed, so retrieval should ground) above **held-out**
    items (not indexed, so it should not). This is the intrinsic
    known-vs-unknown signal, independent of any abstention threshold. Items
    without a ``grounding_score`` are skipped. Returns ``None`` (N/A) when
    either group ends up empty, e.g. the closed-book condition, which records
    no score.
    """
    known: list[float] = []
    unknown: list[float] = []
    for r in results:
        score = _get(r, "grounding_score")
        if score is None:
            continue
        (known if _get(r, "answerable") else unknown).append(score)
    if not (known and unknown):
        return None
    return auroc(known, unknown)