structuremappingmemory 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sma/__init__.py +5 -0
- sma/__main__.py +5 -0
- sma/agent/__init__.py +5 -0
- sma/agent/adapter_draft.py +217 -0
- sma/agent/api.py +67 -0
- sma/agent/comparison.py +591 -0
- sma/agent/llm.py +280 -0
- sma/agent/policies.py +21 -0
- sma/agent/service.py +95 -0
- sma/cli.py +65 -0
- sma/encoders/__init__.py +38 -0
- sma/encoders/agentobs.py +27 -0
- sma/encoders/base.py +23 -0
- sma/encoders/code_treesitter.py +64 -0
- sma/encoders/coverage.py +80 -0
- sma/encoders/draft_adapter.py +183 -0
- sma/encoders/healthcare.py +207 -0
- sma/encoders/logs_drain.py +142 -0
- sma/encoders/prose_tier1.py +57 -0
- sma/encoders/structured.py +57 -0
- sma/encoders/traces.py +45 -0
- sma/eval/__init__.py +2 -0
- sma/eval/agentic/__init__.py +35 -0
- sma/eval/agentic/arms/__init__.py +0 -0
- sma/eval/agentic/arms/cyber.py +48 -0
- sma/eval/agentic/arms/discovery.py +35 -0
- sma/eval/agentic/arms/finance.py +38 -0
- sma/eval/agentic/arms/legal.py +74 -0
- sma/eval/agentic/arms/medicine.py +45 -0
- sma/eval/agentic/harness.py +275 -0
- sma/eval/agentic/memories.py +308 -0
- sma/eval/agentic/metrics.py +82 -0
- sma/eval/agentic_qa/__init__.py +27 -0
- sma/eval/agentic_qa/agent.py +383 -0
- sma/eval/agentic_qa/metrics.py +239 -0
- sma/eval/agentic_qa/pools.py +197 -0
- sma/eval/arn.py +65 -0
- sma/eval/baselines/__init__.py +6 -0
- sma/eval/baselines/bge_dense.py +54 -0
- sma/eval/baselines/bm25.py +18 -0
- sma/eval/baselines/dense.py +42 -0
- sma/eval/baselines/hipporag.py +235 -0
- sma/eval/baselines/hybrid_rrf.py +30 -0
- sma/eval/baselines/longcontext_llm.py +124 -0
- sma/eval/baselines/rerank.py +41 -0
- sma/eval/baselines/splade.py +77 -0
- sma/eval/baselines/wl_kernel.py +163 -0
- sma/eval/bugsinpy.py +358 -0
- sma/eval/bugsinpy_families.py +164 -0
- sma/eval/crossdomain.py +89 -0
- sma/eval/diabetes.py +61 -0
- sma/eval/drift_env.py +26 -0
- sma/eval/drift_metrics.py +24 -0
- sma/eval/family_labels.py +167 -0
- sma/eval/fraud_elliptic/__init__.py +29 -0
- sma/eval/fraud_elliptic/encoder.py +279 -0
- sma/eval/fraud_elliptic/eval.py +269 -0
- sma/eval/fraud_elliptic/test_encoder.py +123 -0
- sma/eval/ieee_cis.py +66 -0
- sma/eval/loghub.py +16 -0
- sma/eval/loghub_eval.py +480 -0
- sma/eval/longmemeval.py +51 -0
- sma/eval/memory_backends/__init__.py +2 -0
- sma/eval/memory_backends/base.py +22 -0
- sma/eval/memory_backends/context_only.py +14 -0
- sma/eval/memory_backends/rag_notes.py +17 -0
- sma/eval/memory_backends/shared_llm.py +30 -0
- sma/eval/memory_backends/sma_memory.py +54 -0
- sma/eval/memory_backends/zep_graphiti.py +33 -0
- sma/eval/metrics.py +32 -0
- sma/eval/ontology_bench.py +219 -0
- sma/eval/report.py +573 -0
- sma/eval/ssb_eval.py +216 -0
- sma/eval/ssb_generator.py +116 -0
- sma/eval/stats.py +108 -0
- sma/eval/transfer_eval.py +844 -0
- sma/index/__init__.py +15 -0
- sma/index/ann.py +21 -0
- sma/index/content_vectors.py +60 -0
- sma/index/inverted.py +63 -0
- sma/index/macfac.py +174 -0
- sma/ir/__init__.py +22 -0
- sma/ir/canon.py +106 -0
- sma/ir/schema.py +165 -0
- sma/ir/sexpr.py +86 -0
- sma/ir/signatures.py +76 -0
- sma/match/__init__.py +20 -0
- sma/match/conflicts.py +46 -0
- sma/match/engine.py +60 -0
- sma/match/explain.py +59 -0
- sma/match/infer.py +54 -0
- sma/match/kernels.py +54 -0
- sma/match/mdl.py +30 -0
- sma/match/merge_cpsat.py +77 -0
- sma/match/merge_greedy.py +15 -0
- sma/match/mh.py +177 -0
- sma/match/ses.py +84 -0
- sma/match/types.py +115 -0
- sma/match/verifier.py +27 -0
- sma/ontology/__init__.py +45 -0
- sma/ontology/attack.py +134 -0
- sma/ontology/cpc.py +69 -0
- sma/ontology/graph.py +58 -0
- sma/ontology/loader.py +262 -0
- sma/ontology/mitre_xml.py +67 -0
- sma/ontology/mount.py +101 -0
- sma/ontology/rdf_loader.py +75 -0
- sma/ontology/registry.py +115 -0
- sma/ontology/router.py +69 -0
- sma/ontology/usgaap.py +73 -0
- sma/sage/__init__.py +6 -0
- sma/sage/assimilate.py +12 -0
- sma/sage/pools.py +105 -0
- sma/sage/probabilities.py +10 -0
- sma/store/__init__.py +6 -0
- sma/store/lmdb_store.py +78 -0
- sma/store/registry.py +26 -0
- sma/store/wal.py +26 -0
- sma/ui/app.py +642 -0
- structuremappingmemory-1.0.0.dist-info/METADATA +190 -0
- structuremappingmemory-1.0.0.dist-info/RECORD +125 -0
- structuremappingmemory-1.0.0.dist-info/WHEEL +5 -0
- structuremappingmemory-1.0.0.dist-info/entry_points.txt +2 -0
- structuremappingmemory-1.0.0.dist-info/licenses/LICENSE +204 -0
- structuremappingmemory-1.0.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,383 @@
|
|
|
1
|
+
"""The one-shot QA agent for the Phase 5 LLM-QA "trustworthy specialist" phase.
|
|
2
|
+
|
|
3
|
+
A single :class:`QAAgent` holds the LLM and prompt FIXED and swaps only the
|
|
4
|
+
retrieval ``Memory`` (none / dense-RAG / SMA), exactly as registered in
|
|
5
|
+
``configs/preregistration_v2_llmqa.md`` section 2. For each
|
|
6
|
+
:class:`~sma.eval.agentic_qa.pools.QAItem` it runs one agent turn and returns a
|
|
7
|
+
result dict carrying every field the trustworthy-QA metrics read
|
|
8
|
+
(``sma.eval.agentic_qa.metrics``): ``gold_id``, ``gold_name``, ``answerable``,
|
|
9
|
+
``novel``, ``abstained``, ``pred_id``, ``answer``, ``novelty_flag``,
|
|
10
|
+
``confidence``, ``grounding_score``.
|
|
11
|
+
|
|
12
|
+
Two grounding regimes:
|
|
13
|
+
|
|
14
|
+
* **grounded** (a memory is given) — retrieve top-k candidates, render them as a
|
|
15
|
+
numbered list, and ask the LLM for a strict one-line JSON ``{"choice": <n>}``
|
|
16
|
+
where ``n`` is a candidate number or ``0`` to abstain. ``pred_id`` is the
|
|
17
|
+
chosen candidate's key (the disease id), so correctness/citation can be checked
|
|
18
|
+
structurally against the gold. When a calibrated ``score_threshold`` is given,
|
|
19
|
+
a case whose top RAW grounding score falls below it is abstained AND flagged
|
|
20
|
+
novel *before* the LLM call (the structural score, not the saturated confidence
|
|
21
|
+
or the expectation-violation flag, is what separates known from unknown); with
|
|
22
|
+
no threshold the novelty flag falls back to ``memory.novelty(query)``.
|
|
23
|
+
* **closed-book** (``memory is None``) — the LLM answers from the case alone with
|
|
24
|
+
a strict one-line JSON ``{"diagnosis": "<name or ABSTAIN>"}``; ``pred_id`` is
|
|
25
|
+
``None`` (no citation), ``confidence`` is a flat ``0.5``, and novelty is N/A.
|
|
26
|
+
|
|
27
|
+
JSON parsing is defensive (strips ```` ``` ```` code fences, scans for the first
|
|
28
|
+
``{...}`` object) and falls back to ABSTAIN on any parse/validation failure, so a
|
|
29
|
+
malformed model reply degrades to the safe action rather than crashing the run.
|
|
30
|
+
"""
|
|
31
|
+
|
|
32
|
+
from __future__ import annotations
|
|
33
|
+
|
|
34
|
+
import json
|
|
35
|
+
import re
|
|
36
|
+
from typing import Any, Protocol
|
|
37
|
+
|
|
38
|
+
from sma.eval.agentic import Query
|
|
39
|
+
from sma.eval.agentic_qa.pools import QAItem
|
|
40
|
+
|
|
41
|
+
# How many characteristic feature names to show per candidate (keeps the prompt
|
|
42
|
+
# bounded; the candidate is identified by its number, not by an exhaustive list).
|
|
43
|
+
_FEATURES_PER_CANDIDATE = 6
|
|
44
|
+
|
|
45
|
+
SYSTEM_PROMPT = (
|
|
46
|
+
"You are a careful diagnostic assistant. You are given a clinical case and a "
|
|
47
|
+
"numbered list of candidate diseases retrieved from a grounded knowledge base, "
|
|
48
|
+
"each with a few of its characteristic features. Choose the single candidate "
|
|
49
|
+
"whose characteristic features best match the case. Answer ONLY when a "
|
|
50
|
+
"candidate genuinely grounds the case; if none of the candidates fit, abstain. "
|
|
51
|
+
"Reply with STRICT one-line JSON and nothing else: "
|
|
52
|
+
'{"choice": <candidate number, or 0 for none / abstain>}.'
|
|
53
|
+
)
|
|
54
|
+
|
|
55
|
+
CLOSED_BOOK_SYSTEM_PROMPT = (
|
|
56
|
+
"You are a careful diagnostic assistant. You are given a clinical case and no "
|
|
57
|
+
"external knowledge. Name the single most likely disease, or abstain if you are "
|
|
58
|
+
"not confident. Reply with STRICT one-line JSON and nothing else: "
|
|
59
|
+
'{"diagnosis": "<disease name, or ABSTAIN>"}.'
|
|
60
|
+
)
|
|
61
|
+
|
|
62
|
+
ABSTAIN = "ABSTAIN"
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
class LLM(Protocol):
    """Structural interface of the fixed LLM backend.

    Any object with a matching ``complete`` method satisfies it — the real
    ``DeepSeekOrchestrator`` or a test double such as ``MockLLM``.
    """

    def complete(
        self,
        messages: list[dict],
        max_tokens: int = 600,
        temperature: float = 0.0,
    ) -> str:
        """Return the model's text reply to the chat-style ``messages`` list."""
        ...
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
class MockLLM:
|
|
74
|
+
"""A deterministic stand-in for the real LLM (NEVER calls DeepSeek).
|
|
75
|
+
|
|
76
|
+
Used by the tests and the ``--mock`` driver so the whole harness can run with
|
|
77
|
+
zero API spend. By default it picks candidate ``1`` in the grounded regime and
|
|
78
|
+
echoes a fixed diagnosis closed-book; pass ``choice`` / ``diagnosis`` to script
|
|
79
|
+
other behaviours (e.g. ``choice=0`` to exercise the abstain path). When
|
|
80
|
+
``raw`` is set it is returned verbatim, to test defensive JSON parsing.
|
|
81
|
+
"""
|
|
82
|
+
|
|
83
|
+
def __init__(
|
|
84
|
+
self,
|
|
85
|
+
choice: int = 1,
|
|
86
|
+
diagnosis: str = "Mock disease",
|
|
87
|
+
raw: str | None = None,
|
|
88
|
+
):
|
|
89
|
+
self.choice = choice
|
|
90
|
+
self.diagnosis = diagnosis
|
|
91
|
+
self.raw = raw
|
|
92
|
+
self.calls: list[list[dict]] = []
|
|
93
|
+
|
|
94
|
+
def complete(
|
|
95
|
+
self, messages: list[dict], max_tokens: int = 600, temperature: float = 0.0
|
|
96
|
+
) -> str:
|
|
97
|
+
self.calls.append(messages)
|
|
98
|
+
if self.raw is not None:
|
|
99
|
+
return self.raw
|
|
100
|
+
# Closed-book prompts ask for a "diagnosis" key; grounded ask for "choice".
|
|
101
|
+
system = messages[0]["content"] if messages else ""
|
|
102
|
+
if "diagnosis" in system:
|
|
103
|
+
return json.dumps({"diagnosis": self.diagnosis})
|
|
104
|
+
return json.dumps({"choice": self.choice})
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def _strip_fences(text: str) -> str:
|
|
108
|
+
"""Drop Markdown code fences so JSON wrapped in ```` ```json ... ``` ```` parses."""
|
|
109
|
+
t = text.strip()
|
|
110
|
+
if t.startswith("```"):
|
|
111
|
+
# Remove the opening fence (with optional language tag) and closing fence.
|
|
112
|
+
t = re.sub(r"^```[a-zA-Z0-9]*\s*", "", t)
|
|
113
|
+
t = re.sub(r"\s*```$", "", t.strip())
|
|
114
|
+
return t.strip()
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def _parse_json_object(text: str) -> dict | None:
    """Extract one JSON object from a model reply that may contain extra noise.

    The fence-stripped reply is decoded as a whole first. If that fails, the
    first balanced ``{...}`` span inside it is tried. The first attempt that
    decodes to a ``dict`` is returned; ``None`` if neither does.
    """
    body = _strip_fences(text)
    fallback = _first_brace_object(body)
    for attempt in (body, fallback):
        if not attempt:
            continue
        try:
            decoded = json.loads(attempt)
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            continue
        if isinstance(decoded, dict):
            return decoded
    return None
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def _first_brace_object(text: str) -> str | None:
|
|
137
|
+
"""Return the first balanced ``{...}`` substring, or ``None``."""
|
|
138
|
+
start = text.find("{")
|
|
139
|
+
if start < 0:
|
|
140
|
+
return None
|
|
141
|
+
depth = 0
|
|
142
|
+
for i in range(start, len(text)):
|
|
143
|
+
c = text[i]
|
|
144
|
+
if c == "{":
|
|
145
|
+
depth += 1
|
|
146
|
+
elif c == "}":
|
|
147
|
+
depth -= 1
|
|
148
|
+
if depth == 0:
|
|
149
|
+
return text[start : i + 1]
|
|
150
|
+
return None
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
class QAAgent:
    """One-shot retrieve-then-answer agent with a swappable retrieval memory.

    ``memory`` is one of the frozen ``Memory`` retrievers (``SmaMemory`` /
    ``DenseMemory`` / ...) or ``None`` for the closed-book condition.

    ``key_to_name`` / ``key_to_terms`` map an :class:`IndexItem` key (disease id)
    to its display name and its ontology term ids. They are used to render the
    numbered candidate list; pass the same maps that back the indexed knowledge.

    ``k`` is the retrieval depth. ``novelty_threshold`` is the cut for the
    ``expectation_violation`` novelty flag (only meaningful for SMA).

    ``score_threshold``, when given, is a calibrated cut on the top RAW
    grounding score. A case scoring below it is abstained and flagged novel
    without any LLM call. ``None`` disables that gate and leaves abstention
    entirely to the LLM.
    """

    def __init__(
        self,
        llm: LLM,
        memory: Any | None,
        *,
        key_to_name: dict[str, str] | None = None,
        key_to_terms: dict[str, frozenset[str]] | None = None,
        k: int = 5,
        novelty_threshold: float = 0.5,
        score_threshold: float | None = None,
    ):
        self.llm = llm
        self.memory = memory
        self.key_to_name = key_to_name or {}
        self.key_to_terms = key_to_terms or {}
        self.k = k
        self.novelty_threshold = novelty_threshold
        # Calibrated cite-or-abstain: the RAW structural grounding score (not the
        # saturated normalized confidence, nor the expectation-violation flag — both
        # of which fail to separate known/unknown, AUROC~0.48) is the abstention
        # signal. Below this threshold the memory has no grounding -> abstain + flag
        # novel, WITHOUT spending an LLM call. None = no gate (LLM-only abstention).
        self.score_threshold = score_threshold

    # -- rendering ----------------------------------------------------------
    def _feature_text(self, key: str) -> str:
        """A few characteristic feature NAMES for a candidate disease.

        Takes the first ``_FEATURES_PER_CANDIDATE`` term ids in sorted order and
        joins their resolved names with ``", "``; "" when the key has no terms.
        """
        terms = sorted(self.key_to_terms.get(key, frozenset()))
        names = [self._term_name(t) for t in terms[:_FEATURES_PER_CANDIDATE]]
        return ", ".join(n for n in names if n)

    def _term_name(self, term_id: str) -> str:
        """Resolve a term id to a human name via the SMA ontology when available.

        Falls back to the raw ``term_id`` when the memory exposes no ``mounted``
        ontology, the id is unknown, or the term has no name.
        """
        mounted = getattr(self.memory, "mounted", None)
        if mounted is not None:
            term = mounted.graph.terms.get(term_id)
            if term is not None and term.name:
                return term.name
        return term_id

    def _render_candidates(self, retrieved: list) -> tuple[str, list[str]]:
        """Build the numbered candidate block and the parallel key list.

        Returns ``(text, keys)`` where ``keys[i]`` is the disease id of candidate
        ``i + 1`` (so a parsed ``{"choice": n}`` maps to ``keys[n - 1]``).
        """
        lines: list[str] = []
        keys: list[str] = []
        for i, r in enumerate(retrieved, 1):
            keys.append(r.key)
            name = self.key_to_name.get(r.key, r.key)
            features = self._feature_text(r.key)
            feat = f" -- characteristic features: {features}" if features else ""
            lines.append(f"[{i}] {name}{feat}")
        return "\n".join(lines), keys

    # -- answer -------------------------------------------------------------
    def answer(self, item: QAItem) -> dict:
        """Run one agent turn over ``item`` and return the metrics result dict."""
        if self.memory is None:
            return self._answer_closed_book(item)
        return self._answer_grounded(item)

    def _result(
        self,
        item: QAItem,
        *,
        abstained: bool,
        pred_id: str | None,
        answer: str,
        novelty_flag: bool,
        confidence: float,
        grounding_score: float | None,
    ) -> dict:
        """Assemble the per-item result dict the trustworthy-QA metrics read."""
        return {
            "gold_id": item.gold_id,
            "gold_name": item.gold_name,
            "answerable": item.answerable,
            "novel": item.novel,
            "abstained": abstained,
            "pred_id": pred_id,
            "answer": answer,
            "novelty_flag": novelty_flag,
            "confidence": confidence,
            # The RAW top structural grounding score (None closed-book). This is
            # the signal that actually separates known from unknown; the metrics
            # use it for threshold-free discrimination AUROC.
            "grounding_score": grounding_score,
        }

    def _answer_grounded(self, item: QAItem) -> dict:
        """Retrieve top-k candidates, optionally gate on score, then ask the LLM to pick one."""
        query = Query(item.case_terms, item.case_text)
        retrieved = self.memory.retrieve(query, self.k)
        # NOTE(review): assumes ``memory.retrieve`` returns hits best-first, so
        # ``retrieved[0]`` is the top match — confirm against the Memory classes.
        confidence = retrieved[0].confidence if retrieved else 0.0
        grounding_score = retrieved[0].score if retrieved else 0.0

        # Calibrated cite-or-abstain. If the top RAW grounding score is below the
        # validation-calibrated threshold, the memory does not structurally ground
        # this case -> ABSTAIN and FLAG NOVEL, WITHOUT spending an LLM call. The
        # raw structural match score is the discriminating signal (answerable vs
        # out-of-knowledge AUROC ~0.84); the squashed confidence (top hit always
        # ~1.0) and the expectation-violation flag are not (AUROC ~0.48). A None
        # threshold disables the gate -> pure LLM-mediated abstention (legacy).
        if self.score_threshold is not None and grounding_score < self.score_threshold:
            return self._result(
                item,
                abstained=True,
                pred_id=None,
                answer=ABSTAIN,
                novelty_flag=True,
                confidence=confidence,
                grounding_score=grounding_score,
            )

        candidates_text, keys = self._render_candidates(retrieved)

        # With a calibrated gate, the structural signal IS the novelty signal:
        # above threshold here -> not flagged. Without a gate, fall back to the
        # memory's own expectation-violation novelty vs novelty_threshold.
        if self.score_threshold is not None:
            novelty_flag = False
        else:
            novelty_flag = bool(self.memory.novelty(query) > self.novelty_threshold)

        user = (
            f"Clinical case:\n{item.case_text}\n\n"
            f"Candidate diseases:\n{candidates_text or '(none retrieved)'}\n\n"
            "Rule: choose the candidate whose characteristic features best match "
            "the case; answer only if a candidate genuinely grounds the case, "
            "otherwise choose 0 to abstain.\n"
            'Reply with STRICT one-line JSON: {"choice": <candidate number or 0>}.'
        )
        reply = self.llm.complete(
            [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": user},
            ],
            max_tokens=600,
            temperature=0.0,
        )
        choice = self._parse_choice(reply, n_candidates=len(keys))

        if choice == 0:
            pred_id: str | None = None
            answer = ABSTAIN
            abstained = True
        else:
            pred_id = keys[choice - 1]
            answer = self.key_to_name.get(pred_id, pred_id)
            abstained = False

        return self._result(
            item,
            abstained=abstained,
            pred_id=pred_id,
            answer=answer,
            novelty_flag=novelty_flag,
            confidence=confidence,
            grounding_score=grounding_score,
        )

    def _answer_closed_book(self, item: QAItem) -> dict:
        """Ask the LLM for a free-text diagnosis with no retrieval (no citation, no novelty)."""
        user = (
            f"Clinical case:\n{item.case_text}\n\n"
            "Name the single most likely disease, or abstain if not confident.\n"
            'Reply with STRICT one-line JSON: {"diagnosis": "<disease name or ABSTAIN>"}.'
        )
        reply = self.llm.complete(
            [
                {"role": "system", "content": CLOSED_BOOK_SYSTEM_PROMPT},
                {"role": "user", "content": user},
            ],
            max_tokens=600,
            temperature=0.0,
        )
        diagnosis = self._parse_diagnosis(reply)
        abstained = diagnosis.strip().upper() == ABSTAIN
        answer = ABSTAIN if abstained else diagnosis

        return self._result(
            item,
            abstained=abstained,
            pred_id=None,
            answer=answer,
            novelty_flag=False,
            confidence=0.5,
            grounding_score=None,
        )

    # -- parsing ------------------------------------------------------------
    @staticmethod
    def _parse_choice(reply: str, *, n_candidates: int) -> int:
        """Parse ``{"choice": n}`` -> int in ``0..n_candidates``; abstain on failure.

        Any parse error, missing/ill-typed ``choice``, or out-of-range index
        collapses to ``0`` (abstain), the safe action.

        Accepted values are JSON integers, integral floats (``2.0``) and integer
        strings (``"2"``). Booleans and fractional numbers are treated as
        ill-typed. Non-finite values (``Infinity`` / ``NaN``, which
        :func:`json.loads` accepts) abstain instead of raising.
        """
        obj = _parse_json_object(reply)
        if obj is None or "choice" not in obj:
            return 0
        raw = obj["choice"]
        # ``bool`` subclasses ``int``: without this guard ``true`` would
        # silently select candidate 1.
        if isinstance(raw, bool):
            return 0
        # ``int(1.7)`` truncates to 1, and ``int(inf)`` raises OverflowError;
        # a fractional or non-finite index names no candidate.
        if isinstance(raw, float) and not raw.is_integer():
            return 0
        try:
            choice = int(raw)
        except (TypeError, ValueError, OverflowError):
            return 0
        if not 0 <= choice <= n_candidates:
            return 0
        return choice

    @staticmethod
    def _parse_diagnosis(reply: str) -> str:
        """Parse ``{"diagnosis": "..."}`` -> stripped str; ``ABSTAIN`` on failure.

        A missing, non-string, or blank ``diagnosis`` also yields ``ABSTAIN``.
        """
        obj = _parse_json_object(reply)
        if obj is None:
            return ABSTAIN
        value = obj.get("diagnosis")
        if not isinstance(value, str) or not value.strip():
            return ABSTAIN
        return value.strip()
|
|
@@ -0,0 +1,239 @@
|
|
|
1
|
+
"""Trustworthy-QA metrics for the Phase 5 LLM-QA harness (prereg v2 section 4).
|
|
2
|
+
|
|
3
|
+
Given per-item agent results, compute the four pre-registered axes that
|
|
4
|
+
distinguish a *verifiable specialist* from a confident-but-opaque RAG agent:
|
|
5
|
+
|
|
6
|
+
* :func:`accuracy` — answer correct on the **answerable** pool (the accuracy
|
|
7
|
+
floor; the capability gains must not cost accuracy).
|
|
8
|
+
* :func:`citation_faithfulness` — ALCE-style support score over **answered
|
|
9
|
+
answerable** items: did the cited candidate actually turn out to be the gold?
|
|
10
|
+
N/A (``None``) for the closed-book condition, which has no citation.
|
|
11
|
+
* :func:`abstention` — selective prediction over the union of **answerable**
|
|
12
|
+
(should answer) and **held-out / out-of-knowledge** (should abstain):
|
|
13
|
+
abstain-recall, false-abstain, selective-accuracy, plus the risk-coverage AURC
|
|
14
|
+
with confidence ``= 1 - abstain_flag``.
|
|
15
|
+
* :func:`grounding_auroc` — threshold-free discrimination of the RAW grounding
|
|
16
|
+
score: AUROC for separating answerable (high score) from held-out (low score).
|
|
17
|
+
The intrinsic "can the memory tell known from unknown" signal, independent of
|
|
18
|
+
where the abstention threshold sits.
|
|
19
|
+
* :func:`novelty_recall` / :func:`novelty_f1` — recall (and precision/F1 against
|
|
20
|
+
answerable false-alarms) of the novelty flag over the **novel** pool.
|
|
21
|
+
|
|
22
|
+
A result is a simple dict or object exposing: ``gold_id``, ``gold_name``,
|
|
23
|
+
``answerable``, ``novel``, ``abstained`` (bool), ``pred_id`` (str | None),
|
|
24
|
+
``answer`` (str), ``novelty_flag`` (bool), ``confidence`` (float),
|
|
25
|
+
``grounding_score`` (float | None). The data have two disjoint groups:
|
|
26
|
+
**answerable** (the gold disease IS indexed) and **held-out** (the gold disease
|
|
27
|
+
is NOT indexed). A held-out case is simultaneously out-of-knowledge (the agent
|
|
28
|
+
should ABSTAIN) and novel (the agent should FLAG it) — both correct trustworthy
|
|
29
|
+
behaviours on the same unindexed case — so abstention and novelty are scored on
|
|
30
|
+
the same held-out items (``answerable == False``, ``novel == True``).
|
|
31
|
+
"""
|
|
32
|
+
|
|
33
|
+
from __future__ import annotations
|
|
34
|
+
|
|
35
|
+
from typing import Any
|
|
36
|
+
|
|
37
|
+
from sma.eval.agentic.metrics import risk_coverage_aurc
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def _get(result: Any, field: str, default: Any = None) -> Any:
|
|
41
|
+
"""Read ``field`` from a result whether it is a dict or an object."""
|
|
42
|
+
if isinstance(result, dict):
|
|
43
|
+
return result.get(field, default)
|
|
44
|
+
return getattr(result, field, default)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def _correct(result: Any) -> bool:
    """Report whether the agent named the gold entity.

    With a citation (``pred_id`` set), correctness is an exact id match against
    ``gold_id``. Without one (closed-book), it falls back to a case-insensitive
    containment test between the free-text ``answer`` and ``gold_name``, in
    either direction. An empty answer or gold name is never correct.
    """
    pred_id = _get(result, "pred_id")
    if pred_id is not None:
        return pred_id == _get(result, "gold_id")

    def normalised(field: str) -> str:
        return (_get(result, field) or "").strip().lower()

    answer = normalised("answer")
    gold_name = normalised("gold_name")
    if answer and gold_name:
        # NOTE(review): the reverse direction (answer inside gold_name) also
        # credits generic fragments such as "syndrome" — confirm this matches
        # the pre-registered scoring rule.
        return gold_name in answer or answer in gold_name
    return False
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def accuracy(results: list[Any]) -> float:
    """Share of **answerable** items that were answered (not abstained) correctly.

    An empty answerable pool gives 0.0 rather than dividing by zero.
    """
    pool = [r for r in results if _get(r, "answerable")]
    if not pool:
        return 0.0
    n_right = 0
    for r in pool:
        if _get(r, "abstained"):
            continue
        if _correct(r):
            n_right += 1
    return n_right / len(pool)
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def citation_faithfulness(results: list[Any]) -> float | None:
    """Support score over **answered answerable** items that carry a citation.

    Restricted to answerable items that were answered (not abstained) and have a
    ``pred_id``, this is the fraction where the cited id equals ``gold_id``.
    Uncited items are ignored. Returns ``None`` (N/A) when nothing qualifies,
    which is always the case for closed-book runs.
    """

    def has_citation(r: Any) -> Any:
        return (
            _get(r, "answerable")
            and not _get(r, "abstained")
            and _get(r, "pred_id") is not None
        )

    cited = [r for r in results if has_citation(r)]
    if not cited:
        return None
    supported = len([r for r in cited if _get(r, "pred_id") == _get(r, "gold_id")])
    return supported / len(cited)
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def abstention(results: list[Any]) -> dict[str, Any]:
    """Selective-prediction metrics over answerable (should answer) + held-out (should abstain).

    Every non-answerable item counts as out-of-knowledge ("ook"). Returned keys:

    * ``abstain_recall`` — share of ook items that abstained;
    * ``false_abstain`` — share of answerable items that abstained anyway;
    * ``selective_accuracy`` — share of all items handled correctly (answered
      correctly when answerable, abstained when ook);
    * ``aurc`` / ``rc_points`` — risk-coverage curve from
      ``risk_coverage_aurc`` with confidence ``0.0`` for an abstention and ``1.0``
      otherwise, and correctness meaning "answered and right" (an abstention
      is never correct for this curve).

    Each fraction falls back to 0.0 when its pool is empty.
    """
    answerable = [r for r in results if _get(r, "answerable")]
    ook = [r for r in results if not _get(r, "answerable")]

    def share(count: int, pool: list) -> float:
        return count / len(pool) if pool else 0.0

    abstain_recall = share(sum(1 for r in ook if _get(r, "abstained")), ook)
    false_abstain = share(sum(1 for r in answerable if _get(r, "abstained")), answerable)

    union = answerable + ook
    confidences: list[float] = []
    correct: list[bool] = []
    n_selective_ok = 0
    for r in union:
        abstained = bool(_get(r, "abstained"))
        should_answer = bool(_get(r, "answerable"))
        answered_correct = (not abstained) and _correct(r)
        # The right move is answering correctly when answerable, abstaining otherwise.
        right_move = answered_correct if should_answer else abstained
        if right_move:
            n_selective_ok += 1
        confidences.append(0.0 if abstained else 1.0)
        correct.append(answered_correct)

    aurc, rc_points = risk_coverage_aurc(confidences, correct)
    return {
        "abstain_recall": abstain_recall,
        "false_abstain": false_abstain,
        "selective_accuracy": share(n_selective_ok, union),
        "aurc": aurc,
        "rc_points": rc_points,
    }
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
def novelty_recall(results: list[Any]) -> float:
    """Share of **novel** items on which the agent raised ``novelty_flag``.

    0.0 when there are no novel items (avoids dividing by zero).
    """
    flags = [bool(_get(r, "novelty_flag")) for r in results if _get(r, "novel")]
    if not flags:
        return 0.0
    return sum(flags) / len(flags)
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
def novelty_f1(results: list[Any]) -> dict[str, float]:
    """Precision / recall / F1 of the novelty flag (held-out positive, answerable negative).

    **Novel** (held-out) items are positives and **answerable** (indexed) items
    are negatives, so a flag fired on an answerable case is a false alarm. This
    penalises an agent that flags everything: recall 1.0 is cheap, F1 is not.

    Returns ``precision``, ``recall``, ``f1`` and ``fpr`` (the false-positive
    rate on answerable items). Any fraction whose denominator is empty is 0.0.
    """
    novel = [r for r in results if _get(r, "novel")]
    answerable = [r for r in results if _get(r, "answerable")]

    tp = sum(1 for r in novel if _get(r, "novelty_flag"))
    fp = sum(1 for r in answerable if _get(r, "novelty_flag"))

    recall = tp / len(novel) if novel else 0.0
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    f1 = (
        2 * precision * recall / (precision + recall)
        if (precision + recall)
        else 0.0
    )
    fpr = fp / len(answerable) if answerable else 0.0
    return {"precision": precision, "recall": recall, "f1": f1, "fpr": fpr}
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
def auroc(pos: list[float], neg: list[float]) -> float:
    """Mann-Whitney U AUROC: probability that a ``pos`` score exceeds a ``neg`` score.

    Each (pos, neg) pair counts 1 when the positive is strictly higher and 0.5
    on a tie. ``pos`` is the should-answer (high-score) class and ``neg`` the
    should-abstain (low-score) class. Both must be non-empty; an empty list
    raises ``ZeroDivisionError``. Reused by the driver to score the calibration
    split and by :func:`grounding_auroc` to score the test split.

    Rather than comparing all ``len(pos) * len(neg)`` pairs in Python, ``neg``
    is sorted once and each positive is located with two binary searches. This
    takes O((P + N) log N) and gives the same result as the pairwise count for
    non-NaN scores.
    """
    import bisect

    ordered_neg = sorted(neg)
    wins = 0.0
    for p in pos:
        below = bisect.bisect_left(ordered_neg, p)  # negatives strictly below p
        tied = bisect.bisect_right(ordered_neg, p) - below  # negatives equal to p
        wins += below + 0.5 * tied
    return wins / (len(pos) * len(neg))
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
def grounding_auroc(results: list[Any]) -> float | None:
    """Threshold-free AUROC of the RAW grounding score: answerable vs held-out.

    Scores from **answerable** items (gold indexed, so they should ground) are
    the positives and scores from **held-out** items (gold not indexed) the
    negatives. This measures the memory's intrinsic known-vs-unknown signal,
    independent of any abstention threshold. Items whose ``grounding_score`` is
    ``None`` are skipped. Returns ``None`` (N/A) when either side ends up empty,
    e.g. for closed-book runs.
    """
    pos: list[float] = []
    neg: list[float] = []
    for r in results:
        score = _get(r, "grounding_score")
        if score is None:
            continue
        bucket = pos if _get(r, "answerable") else neg
        bucket.append(score)
    if pos and neg:
        return auroc(pos, neg)
    return None
|