structuremappingmemory 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sma/__init__.py +5 -0
- sma/__main__.py +5 -0
- sma/agent/__init__.py +5 -0
- sma/agent/adapter_draft.py +217 -0
- sma/agent/api.py +67 -0
- sma/agent/comparison.py +591 -0
- sma/agent/llm.py +280 -0
- sma/agent/policies.py +21 -0
- sma/agent/service.py +95 -0
- sma/cli.py +65 -0
- sma/encoders/__init__.py +38 -0
- sma/encoders/agentobs.py +27 -0
- sma/encoders/base.py +23 -0
- sma/encoders/code_treesitter.py +64 -0
- sma/encoders/coverage.py +80 -0
- sma/encoders/draft_adapter.py +183 -0
- sma/encoders/healthcare.py +207 -0
- sma/encoders/logs_drain.py +142 -0
- sma/encoders/prose_tier1.py +57 -0
- sma/encoders/structured.py +57 -0
- sma/encoders/traces.py +45 -0
- sma/eval/__init__.py +2 -0
- sma/eval/agentic/__init__.py +35 -0
- sma/eval/agentic/arms/__init__.py +0 -0
- sma/eval/agentic/arms/cyber.py +48 -0
- sma/eval/agentic/arms/discovery.py +35 -0
- sma/eval/agentic/arms/finance.py +38 -0
- sma/eval/agentic/arms/legal.py +74 -0
- sma/eval/agentic/arms/medicine.py +45 -0
- sma/eval/agentic/harness.py +275 -0
- sma/eval/agentic/memories.py +308 -0
- sma/eval/agentic/metrics.py +82 -0
- sma/eval/agentic_qa/__init__.py +27 -0
- sma/eval/agentic_qa/agent.py +383 -0
- sma/eval/agentic_qa/metrics.py +239 -0
- sma/eval/agentic_qa/pools.py +197 -0
- sma/eval/arn.py +65 -0
- sma/eval/baselines/__init__.py +6 -0
- sma/eval/baselines/bge_dense.py +54 -0
- sma/eval/baselines/bm25.py +18 -0
- sma/eval/baselines/dense.py +42 -0
- sma/eval/baselines/hipporag.py +235 -0
- sma/eval/baselines/hybrid_rrf.py +30 -0
- sma/eval/baselines/longcontext_llm.py +124 -0
- sma/eval/baselines/rerank.py +41 -0
- sma/eval/baselines/splade.py +77 -0
- sma/eval/baselines/wl_kernel.py +163 -0
- sma/eval/bugsinpy.py +358 -0
- sma/eval/bugsinpy_families.py +164 -0
- sma/eval/crossdomain.py +89 -0
- sma/eval/diabetes.py +61 -0
- sma/eval/drift_env.py +26 -0
- sma/eval/drift_metrics.py +24 -0
- sma/eval/family_labels.py +167 -0
- sma/eval/fraud_elliptic/__init__.py +29 -0
- sma/eval/fraud_elliptic/encoder.py +279 -0
- sma/eval/fraud_elliptic/eval.py +269 -0
- sma/eval/fraud_elliptic/test_encoder.py +123 -0
- sma/eval/ieee_cis.py +66 -0
- sma/eval/loghub.py +16 -0
- sma/eval/loghub_eval.py +480 -0
- sma/eval/longmemeval.py +51 -0
- sma/eval/memory_backends/__init__.py +2 -0
- sma/eval/memory_backends/base.py +22 -0
- sma/eval/memory_backends/context_only.py +14 -0
- sma/eval/memory_backends/rag_notes.py +17 -0
- sma/eval/memory_backends/shared_llm.py +30 -0
- sma/eval/memory_backends/sma_memory.py +54 -0
- sma/eval/memory_backends/zep_graphiti.py +33 -0
- sma/eval/metrics.py +32 -0
- sma/eval/ontology_bench.py +219 -0
- sma/eval/report.py +573 -0
- sma/eval/ssb_eval.py +216 -0
- sma/eval/ssb_generator.py +116 -0
- sma/eval/stats.py +108 -0
- sma/eval/transfer_eval.py +844 -0
- sma/index/__init__.py +15 -0
- sma/index/ann.py +21 -0
- sma/index/content_vectors.py +60 -0
- sma/index/inverted.py +63 -0
- sma/index/macfac.py +174 -0
- sma/ir/__init__.py +22 -0
- sma/ir/canon.py +106 -0
- sma/ir/schema.py +165 -0
- sma/ir/sexpr.py +86 -0
- sma/ir/signatures.py +76 -0
- sma/match/__init__.py +20 -0
- sma/match/conflicts.py +46 -0
- sma/match/engine.py +60 -0
- sma/match/explain.py +59 -0
- sma/match/infer.py +54 -0
- sma/match/kernels.py +54 -0
- sma/match/mdl.py +30 -0
- sma/match/merge_cpsat.py +77 -0
- sma/match/merge_greedy.py +15 -0
- sma/match/mh.py +177 -0
- sma/match/ses.py +84 -0
- sma/match/types.py +115 -0
- sma/match/verifier.py +27 -0
- sma/ontology/__init__.py +45 -0
- sma/ontology/attack.py +134 -0
- sma/ontology/cpc.py +69 -0
- sma/ontology/graph.py +58 -0
- sma/ontology/loader.py +262 -0
- sma/ontology/mitre_xml.py +67 -0
- sma/ontology/mount.py +101 -0
- sma/ontology/rdf_loader.py +75 -0
- sma/ontology/registry.py +115 -0
- sma/ontology/router.py +69 -0
- sma/ontology/usgaap.py +73 -0
- sma/sage/__init__.py +6 -0
- sma/sage/assimilate.py +12 -0
- sma/sage/pools.py +105 -0
- sma/sage/probabilities.py +10 -0
- sma/store/__init__.py +6 -0
- sma/store/lmdb_store.py +78 -0
- sma/store/registry.py +26 -0
- sma/store/wal.py +26 -0
- sma/ui/app.py +642 -0
- structuremappingmemory-1.0.0.dist-info/METADATA +190 -0
- structuremappingmemory-1.0.0.dist-info/RECORD +125 -0
- structuremappingmemory-1.0.0.dist-info/WHEEL +5 -0
- structuremappingmemory-1.0.0.dist-info/entry_points.txt +2 -0
- structuremappingmemory-1.0.0.dist-info/licenses/LICENSE +204 -0
- structuremappingmemory-1.0.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,308 @@
|
|
|
1
|
+
"""Uniform ``Memory`` interface + six retriever wrappers for the agentic suite.
|
|
2
|
+
|
|
3
|
+
The agentic harness holds everything fixed except the retrieval ``Memory``: SMA
|
|
4
|
+
(the universal ontology adapter) vs an enterprise-RAG/KG gauntlet (BM25, BGE
|
|
5
|
+
dense, Hybrid-RRF, Hybrid+Rerank, HippoRAG). Every memory implements the same
|
|
6
|
+
three-method contract so the harness never sees a retriever's internals:
|
|
7
|
+
|
|
8
|
+
index(items) -> None # build over IndexItem records
|
|
9
|
+
retrieve(query, k) -> list[Retrieved] # ranked, rank 1..k
|
|
10
|
+
novelty(query) -> float # in [0,1], higher = "this is new"
|
|
11
|
+
|
|
12
|
+
Confidence (on each :class:`Retrieved`) drives cite-or-abstain and is squashed
|
|
13
|
+
to ``[0,1]`` per method; novelty is the method's best out-of-distribution signal.
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
from __future__ import annotations
|
|
17
|
+
|
|
18
|
+
import math
|
|
19
|
+
from dataclasses import dataclass, field
|
|
20
|
+
from typing import Iterable, Protocol
|
|
21
|
+
|
|
22
|
+
from sma.eval.baselines.bm25 import rank_bm25_like
|
|
23
|
+
from sma.eval.baselines.hipporag import HippoRAGRetriever
|
|
24
|
+
from sma.index.macfac import MacFacIndex
|
|
25
|
+
from sma.ontology import MountedOntology
|
|
26
|
+
from sma.sage.pools import SagePool
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@dataclass
class IndexItem:
    """A single record handed to ``Memory.index``.

    Attributes:
        key: Entity id; this is also the gold answer.
        term_ids: Ontology term ids (used by the SMA memory).
        text: Space-joined term NAMES (used by the text baselines).
        meta: Free-form extra data; every instance gets its own empty dict.
    """

    key: str
    term_ids: frozenset[str]
    text: str
    meta: dict = field(default_factory=dict)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
@dataclass
class Query:
    """Retrieval query carried in two parallel forms.

    Attributes:
        term_ids: Ontology term ids of the query.
        text: Text form of the same query.
    """

    term_ids: frozenset[str]
    text: str
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
@dataclass
class Retrieved:
    """One ranked hit produced by a :class:`Memory`.

    Attributes:
        key: Key of the retrieved item.
        score: Raw, method-specific retrieval score.
        confidence: Method-specific value in [0, 1]; drives cite-or-abstain.
        rank: Position of the hit in the returned ranking.
    """

    key: str
    score: float
    confidence: float
    rank: int
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class Memory(Protocol):
    """Retrieval contract shared by every memory; the only part a harness run varies."""

    name: str

    def index(self, items: list[IndexItem]) -> None:
        """Build the memory over ``items``."""
        ...

    def retrieve(self, query: Query, k: int) -> list[Retrieved]:
        """Return ranked hits for ``query``."""
        ...

    def novelty(self, query: Query) -> float:
        """Return a novelty signal; higher means more "this is new"."""
        ...
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
# ---------------------------------------------------------------------------
|
|
71
|
+
# Task 2: SMA memory (universal adapter)
|
|
72
|
+
# ---------------------------------------------------------------------------
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
class SmaMemory:
    """SMA arm: cases from a mounted ontology, MacFac retrieval, SagePool novelty."""

    name = "sma"

    def __init__(self, mounted: MountedOntology):
        """Keep the mounted ontology; ``index`` must run before retrieval or novelty."""
        self.mounted = mounted

    def index(self, items: list[IndexItem]) -> None:
        """Rebuild the case-id map, the SagePool and the MacFac index from ``items``."""
        self._key: dict[str, str] = {}
        built = []
        self.pool = SagePool("agentic", assimilation_threshold=0.2)
        for item in items:
            case = self.mounted.build_case(item.term_ids, metadata={"key": item.key})
            # Retrieval results carry case ids, so remember the item key per case.
            self._key[case.case_id] = item.key
            built.append(case)
            self.pool.assimilate(case)
        self.index_ = MacFacIndex(
            config=self.mounted.config,
            canon=self.mounted.canon,
        )
        self.index_.build(built)

    def retrieve(self, query: Query, k: int) -> list[Retrieved]:
        """Rank indexed cases against ``query``.

        Confidence of each hit is its score divided by the best score among
        the returned hits, clamped to [0, 1]. A case id that is missing from
        the key map is reported with an empty key.
        """
        probe = self.mounted.build_case(query.term_ids)
        hits = self.index_.retrieve(probe, k=k, shortlist=80, fac_budget=40)
        if not hits:
            return []
        # A falsy best score (0) would divide by zero; fall back to 1.0.
        best = max(hit.score for hit in hits) or 1.0
        ranked = []
        for position, hit in enumerate(hits, 1):
            relative = min(max(hit.score / best, 0.0), 1.0)
            ranked.append(
                Retrieved(self._key.get(hit.case_id, ""), hit.score, relative, position)
            )
        return ranked

    def novelty(self, query: Query) -> float:
        """Return the pool's expectation violation for the case built from ``query``."""
        probe = self.mounted.build_case(query.term_ids)
        return self.pool.expectation_violation(probe)
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
# ---------------------------------------------------------------------------
|
|
112
|
+
# Task 3: text baselines — BM25, Dense (BGE), HippoRAG
|
|
113
|
+
# ---------------------------------------------------------------------------
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
class BM25Memory:
    """Lexical baseline: BM25-like ranking over each item's term-name text."""

    name = "bm25"

    def index(self, items: list[IndexItem]) -> None:
        """Store ``(key, text)`` pairs; ranking happens at query time."""
        pairs: list[tuple[str, str]] = []
        for item in items:
            pairs.append((item.key, item.text))
        self._docs = pairs

    @staticmethod
    def _squash(score: float) -> float:
        """Map a raw score to ``score / (score + 1)``; non-positive scores give 0.0."""
        if score > 0:
            return score / (score + 1.0)
        return 0.0

    def retrieve(self, query: Query, k: int) -> list[Retrieved]:
        """Return up to ``k`` ranked documents.

        Every hit carries the same confidence: the squashed score of the
        first-ranked document.
        """
        hits = rank_bm25_like(query.text, self._docs, k=k)
        if not hits:
            return []
        shared_conf = self._squash(hits[0][1])
        ranked = []
        for position, (doc_key, doc_score) in enumerate(hits, 1):
            ranked.append(Retrieved(doc_key, doc_score, shared_conf, position))
        return ranked

    def novelty(self, query: Query) -> float:
        """Return one minus the squashed best score; 1.0 when nothing is ranked."""
        best = rank_bm25_like(query.text, self._docs, k=1)
        if not best:
            return 1.0
        return 1.0 - self._squash(best[0][1])
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
_BGE: dict = {}
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
def _bge():
|
|
145
|
+
"""Load the BGE-small embedder once (cached at module level)."""
|
|
146
|
+
if "m" not in _BGE:
|
|
147
|
+
from sentence_transformers import SentenceTransformer
|
|
148
|
+
|
|
149
|
+
_BGE["m"] = SentenceTransformer("BAAI/bge-small-en-v1.5")
|
|
150
|
+
return _BGE["m"]
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
class DenseMemory:
    """Dense baseline: BGE-small embeddings ranked by cosine similarity."""

    name = "dense"

    def index(self, items: list[IndexItem]) -> None:
        """Record the item keys and embed the item texts."""
        self._keys = [item.key for item in items]
        corpus = [item.text for item in items]
        encoder = _bge()
        # Embeddings are requested normalized, so a dot product is the cosine.
        self._mat = encoder.encode(corpus, normalize_embeddings=True)

    def _scores(self, query: Query):
        """Return one dot-product similarity per indexed row, in index order."""
        encoder = _bge()
        q_vec = encoder.encode([query.text], normalize_embeddings=True)[0]
        dims = range(len(q_vec))
        sims = []
        for row in self._mat:
            sims.append(float(sum(q_vec[d] * row[d] for d in dims)))
        return sims

    def retrieve(self, query: Query, k: int) -> list[Retrieved]:
        """Return up to ``k`` hits ordered by descending similarity, ties by key.

        Every hit carries the same confidence: the best similarity over the
        whole index, clamped to [0, 1].
        """
        if not self._keys:
            return []
        sims = self._scores(query)
        ordering = sorted(
            range(len(sims)),
            key=lambda idx: (-sims[idx], self._keys[idx]),
        )
        chosen = ordering[:k]
        best = max(sims) if sims else 0.0
        shared_conf = min(max(best, 0.0), 1.0)
        ranked = []
        for position, idx in enumerate(chosen, 1):
            ranked.append(Retrieved(self._keys[idx], sims[idx], shared_conf, position))
        return ranked

    def novelty(self, query: Query) -> float:
        """Return one minus the clamped best similarity; 1.0 for an empty index."""
        if not self._keys:
            return 1.0
        best = max(self._scores(query))
        clamped = min(max(best, 0.0), 1.0)
        return 1.0 - clamped
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
class HippoMemory:
    """KG baseline in the HippoRAG-2 style (phrase graph + Personalized PageRank)."""

    name = "hipporag"

    def index(self, items: list[IndexItem]) -> None:
        """Build a fresh retriever over ``(key, text)`` pairs and record the corpus size."""
        self._retriever = HippoRAGRetriever()
        docs = [(item.key, item.text) for item in items]
        self._retriever.build(docs)
        self._n = len(items)

    @staticmethod
    def _top_share(ranked) -> float:
        """Share of the summed scores held by the first hit, clamped to [0, 1]."""
        # A zero total would divide by zero; fall back to 1.0.
        mass = sum(score for _, score in ranked) or 1.0
        return min(max(ranked[0][1] / mass, 0.0), 1.0)

    def retrieve(self, query: Query, k: int) -> list[Retrieved]:
        """Return up to ``k`` ranked hits.

        All hits share one confidence: the top hit's share of the scores of
        the returned hits. The normalizer covers only those hits, so the value
        depends on ``k``.
        """
        hits = self._retriever.retrieve(query.text, k=k)
        if not hits:
            return []
        shared_conf = self._top_share(hits)
        ranked = []
        for position, (doc_key, doc_score) in enumerate(hits, 1):
            ranked.append(Retrieved(doc_key, doc_score, shared_conf, position))
        return ranked

    def novelty(self, query: Query) -> float:
        """Return one minus the top hit's score share over a full-corpus ranking.

        The retriever is asked for ``max(1, n)`` hits, ``n`` being the number
        of indexed items; an empty ranking yields 1.0.
        """
        everything = self._retriever.retrieve(query.text, k=max(1, self._n))
        if not everything:
            return 1.0
        return 1.0 - self._top_share(everything)
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
# ---------------------------------------------------------------------------
|
|
214
|
+
# Task 4: SOTA hybrid — Hybrid-RRF + Hybrid+Rerank
|
|
215
|
+
# ---------------------------------------------------------------------------
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
class HybridRRFMemory:
    """Hybrid baseline: Reciprocal Rank Fusion over a BM25 and a dense member."""

    name = "hybrid_rrf"

    def __init__(self, bm25_mem: BM25Memory, dense_mem: DenseMemory, k_rrf: int = 60):
        """Keep the two member memories and the RRF constant ``k_rrf``."""
        self.bm25_mem = bm25_mem
        self.dense_mem = dense_mem
        self.k_rrf = k_rrf

    def index(self, items: list[IndexItem]) -> None:
        """Index both members over ``items`` and record the corpus size."""
        self.bm25_mem.index(items)
        self.dense_mem.index(items)
        self._n = len(items)

    def _fused(self, query: Query, k: int) -> list[tuple[str, float]]:
        """Return the top-``k`` ``(key, fused score)`` pairs.

        Each member is asked for ``max(k, n)`` hits (``n`` = indexed items);
        a hit at rank ``r`` contributes ``1 / (k_rrf + r)`` to its key. Pairs
        are ordered by descending fused score, ties broken by key.
        """
        depth = max(k, self._n)
        totals: dict[str, float] = {}
        members = (self.bm25_mem, self.dense_mem)
        for member in members:
            for hit in member.retrieve(query, depth):
                contribution = 1.0 / (self.k_rrf + hit.rank)
                totals[hit.key] = totals.get(hit.key, 0.0) + contribution
        ordered = sorted(totals.items(), key=lambda pair: (-pair[1], pair[0]))
        return ordered[:k]

    def retrieve(self, query: Query, k: int) -> list[Retrieved]:
        """Return fused hits; confidence is each fused score relative to the best one."""
        fused = self._fused(query, k)
        if not fused:
            return []
        # A falsy best score (0) would divide by zero; fall back to 1.0.
        best = fused[0][1] or 1.0
        ranked = []
        for position, (doc_key, doc_score) in enumerate(fused, 1):
            relative = min(max(doc_score / best, 0.0), 1.0)
            ranked.append(Retrieved(doc_key, doc_score, relative, position))
        return ranked

    def novelty(self, query: Query) -> float:
        """Return the smaller of the two members' novelty values."""
        lexical = self.bm25_mem.novelty(query)
        dense = self.dense_mem.novelty(query)
        return min(lexical, dense)
|
|
253
|
+
|
|
254
|
+
|
|
255
|
+
_RERANKER: dict = {}
|
|
256
|
+
|
|
257
|
+
|
|
258
|
+
def _reranker(name: str = "BAAI/bge-reranker-base"):
|
|
259
|
+
"""Load the BGE cross-encoder reranker once (cached at module level)."""
|
|
260
|
+
if name not in _RERANKER:
|
|
261
|
+
from sentence_transformers import CrossEncoder
|
|
262
|
+
|
|
263
|
+
_RERANKER[name] = CrossEncoder(name)
|
|
264
|
+
return _RERANKER[name]
|
|
265
|
+
|
|
266
|
+
|
|
267
|
+
def _sigmoid(x: float) -> float:
|
|
268
|
+
return 1.0 / (1.0 + math.exp(-x))
|
|
269
|
+
|
|
270
|
+
|
|
271
|
+
class HybridRerankMemory:
    """Cross-encoder reranking of a hybrid memory's top-n candidates.

    The wrapped hybrid memory supplies at most ``top_n`` candidates per query,
    so ``retrieve`` never returns more than ``top_n`` hits regardless of ``k``.
    """

    name = "hybrid_rerank"

    def __init__(
        self,
        hybrid: HybridRRFMemory,
        cross_encoder: str = "BAAI/bge-reranker-base",
        top_n: int = 30,
    ):
        """Store the wrapped memory, the cross-encoder model name and the candidate cap."""
        self.hybrid = hybrid
        self.cross_encoder = cross_encoder
        self.top_n = top_n

    def index(self, items: list[IndexItem]) -> None:
        """Index the wrapped memory and keep each item's text for reranking."""
        self.hybrid.index(items)
        self._text = {it.key: it.text for it in items}

    def retrieve(self, query: Query, k: int) -> list[Retrieved]:
        """Rerank the hybrid candidates with the cross-encoder and return up to ``k``.

        Hits are ordered by descending cross-encoder logit, ties broken by
        key. ``score`` is the raw logit; every hit carries the same confidence,
        the sigmoid of the best logit. Returns ``[]`` when ``k`` is not
        positive or the hybrid memory yields no candidates.
        """
        if k <= 0:
            # Nothing requested: return before the slice below can come back
            # empty and make the ``scored[0]`` lookup fail.
            return []
        candidates = self.hybrid.retrieve(query, self.top_n)
        if not candidates:
            return []
        model = _reranker(self.cross_encoder)
        # A candidate whose key has no stored text is scored against "".
        pairs = [(query.text, self._text.get(c.key, "")) for c in candidates]
        logits = model.predict(pairs)
        scored = sorted(
            zip(candidates, (float(s) for s in logits)),
            key=lambda cs: (-cs[1], cs[0].key),
        )[:k]
        top_conf = _sigmoid(scored[0][1])
        return [
            Retrieved(c.key, logit, top_conf, i)
            for i, (c, logit) in enumerate(scored, 1)
        ]

    def novelty(self, query: Query) -> float:
        """Delegate novelty to the wrapped hybrid memory."""
        return self.hybrid.novelty(query)
|
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
"""Metrics for the agentic ontology suite.
|
|
2
|
+
|
|
3
|
+
Three headline metrics, matching the design spec (prereg section 4):
|
|
4
|
+
|
|
5
|
+
* ``tail_topk`` — per-method top-k accuracy on ALL queries and on the
|
|
6
|
+
registered RARE slice (the tail). Top-k accuracy is the fraction of
|
|
7
|
+
rows whose true-entity rank is <= k, reusing the convention from
|
|
8
|
+
``sma.eval.ontology_bench``.
|
|
9
|
+
* ``risk_coverage_aurc`` — cite-or-abstain selective-prediction curve:
|
|
10
|
+
sort by confidence (desc), sweep coverage 0->1, risk is the cumulative
|
|
11
|
+
error rate over the covered head. AURC is the mean risk over the sweep;
|
|
12
|
+
LOWER is better (a well-calibrated ranker keeps its mistakes for last).
|
|
13
|
+
* ``novelty_f1`` — F1 of predicted-novel flags vs the truly held-out set.
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
from __future__ import annotations
|
|
17
|
+
|
|
18
|
+
# Sentinel rank for a true entity that never surfaced; ``_topk_acc`` also uses
# it as the default when a row has no rank recorded for a method.
ABSENT_RANK = 999
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def tail_topk(
    rows: list[dict],
    k: int,
) -> dict[str, dict[str, float]]:
    """Compute top-k accuracy per method over all rows and over the rare rows.

    Args:
        rows: One dict per query, mapping each method name to the 1-based rank
            of the true entity (``ABSENT_RANK`` when it never surfaced), plus
            a ``"rare"`` flag.
        k: Rank cutoff.

    Returns:
        ``{method: {"all": acc, "rare": acc}}`` where ``acc`` is the fraction
        of rows with ``rank <= k``. Method names are every row key other than
        ``"rare"``, visited in sorted order.
    """
    method_names: set = set()
    for row in rows:
        method_names.update(name for name in row if name != "rare")
    tail = [row for row in rows if row.get("rare")]
    return {
        name: {
            "all": _topk_acc(rows, name, k),
            "rare": _topk_acc(tail, name, k),
        }
        for name in sorted(method_names)
    }
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def _topk_acc(rows: list[dict], method: str, k: int) -> float:
|
|
44
|
+
"""Fraction of ``rows`` whose ``method`` rank is <= k."""
|
|
45
|
+
if not rows:
|
|
46
|
+
return 0.0
|
|
47
|
+
hits = sum(1 for r in rows if r.get(method, ABSENT_RANK) <= k)
|
|
48
|
+
return hits / len(rows)
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def risk_coverage_aurc(
    confidences: list[float],
    correct: list[bool],
) -> tuple[float, list[tuple[float, float]]]:
    """Area under the risk-coverage curve (lower is better).

    Predictions are ordered by descending confidence (ties keep their input
    order) and coverage sweeps from the single most-confident prediction to
    all of them. At each coverage point the risk is the cumulative error rate
    over the covered head.

    Args:
        confidences: One confidence per prediction.
        correct: Whether each prediction was correct, aligned with
            ``confidences``.

    Returns:
        ``(aurc, points)`` where ``points`` is the list of ``(coverage, risk)``
        pairs and ``aurc`` is the mean risk over them; ``(0.0, [])`` for empty
        input.

    Raises:
        ValueError: If ``confidences`` and ``correct`` differ in length.
    """
    n = len(correct)
    if len(confidences) != n:
        # Misaligned inputs would otherwise be scored silently (surplus
        # confidences ignored) or fail with an IndexError inside the sort key.
        raise ValueError(
            "confidences and correct must have the same length "
            f"(got {len(confidences)} and {n})"
        )
    if n == 0:
        return 0.0, []
    order = sorted(range(n), key=lambda i: -confidences[i])
    errors = 0
    pts: list[tuple[float, float]] = []
    for covered, i in enumerate(order, 1):
        if not correct[i]:
            errors += 1
        pts.append((covered / n, errors / covered))  # (coverage, risk)
    aurc = sum(risk for _, risk in pts) / len(pts)
    return aurc, pts
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def novelty_f1(pred: list[bool], truth: list[bool]) -> float:
    """Return the F1 score of predicted-novel flags against the truly novel set.

    ``pred`` and ``truth`` are paired position by position. Precision, recall
    and F1 each fall back to 0.0 when their denominator is zero.
    """
    tp = fp = fn = 0
    for flagged, novel in zip(pred, truth):
        if flagged and novel:
            tp += 1
        elif flagged:
            fp += 1
        elif novel:
            fn += 1
    prec = tp / (tp + fp) if tp + fp else 0.0
    rec = tp / (tp + fn) if tp + fn else 0.0
    if not (prec + rec):
        return 0.0
    return 2 * prec * rec / (prec + rec)
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
"""Phase 5 LLM-QA harness: the one-shot agent + trustworthy-specialist metrics."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from sma.eval.agentic_qa.agent import MockLLM, QAAgent
|
|
6
|
+
from sma.eval.agentic_qa.metrics import (
|
|
7
|
+
abstention,
|
|
8
|
+
accuracy,
|
|
9
|
+
citation_faithfulness,
|
|
10
|
+
grounding_auroc,
|
|
11
|
+
novelty_f1,
|
|
12
|
+
novelty_recall,
|
|
13
|
+
)
|
|
14
|
+
from sma.eval.agentic_qa.pools import QAItem, build_pools
|
|
15
|
+
|
|
16
|
+
__all__ = [
|
|
17
|
+
"QAAgent",
|
|
18
|
+
"MockLLM",
|
|
19
|
+
"QAItem",
|
|
20
|
+
"build_pools",
|
|
21
|
+
"accuracy",
|
|
22
|
+
"citation_faithfulness",
|
|
23
|
+
"abstention",
|
|
24
|
+
"grounding_auroc",
|
|
25
|
+
"novelty_recall",
|
|
26
|
+
"novelty_f1",
|
|
27
|
+
]
|