structuremappingmemory 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sma/__init__.py +5 -0
- sma/__main__.py +5 -0
- sma/agent/__init__.py +5 -0
- sma/agent/adapter_draft.py +217 -0
- sma/agent/api.py +67 -0
- sma/agent/comparison.py +591 -0
- sma/agent/llm.py +280 -0
- sma/agent/policies.py +21 -0
- sma/agent/service.py +95 -0
- sma/cli.py +65 -0
- sma/encoders/__init__.py +38 -0
- sma/encoders/agentobs.py +27 -0
- sma/encoders/base.py +23 -0
- sma/encoders/code_treesitter.py +64 -0
- sma/encoders/coverage.py +80 -0
- sma/encoders/draft_adapter.py +183 -0
- sma/encoders/healthcare.py +207 -0
- sma/encoders/logs_drain.py +142 -0
- sma/encoders/prose_tier1.py +57 -0
- sma/encoders/structured.py +57 -0
- sma/encoders/traces.py +45 -0
- sma/eval/__init__.py +2 -0
- sma/eval/agentic/__init__.py +35 -0
- sma/eval/agentic/arms/__init__.py +0 -0
- sma/eval/agentic/arms/cyber.py +48 -0
- sma/eval/agentic/arms/discovery.py +35 -0
- sma/eval/agentic/arms/finance.py +38 -0
- sma/eval/agentic/arms/legal.py +74 -0
- sma/eval/agentic/arms/medicine.py +45 -0
- sma/eval/agentic/harness.py +275 -0
- sma/eval/agentic/memories.py +308 -0
- sma/eval/agentic/metrics.py +82 -0
- sma/eval/agentic_qa/__init__.py +27 -0
- sma/eval/agentic_qa/agent.py +383 -0
- sma/eval/agentic_qa/metrics.py +239 -0
- sma/eval/agentic_qa/pools.py +197 -0
- sma/eval/arn.py +65 -0
- sma/eval/baselines/__init__.py +6 -0
- sma/eval/baselines/bge_dense.py +54 -0
- sma/eval/baselines/bm25.py +18 -0
- sma/eval/baselines/dense.py +42 -0
- sma/eval/baselines/hipporag.py +235 -0
- sma/eval/baselines/hybrid_rrf.py +30 -0
- sma/eval/baselines/longcontext_llm.py +124 -0
- sma/eval/baselines/rerank.py +41 -0
- sma/eval/baselines/splade.py +77 -0
- sma/eval/baselines/wl_kernel.py +163 -0
- sma/eval/bugsinpy.py +358 -0
- sma/eval/bugsinpy_families.py +164 -0
- sma/eval/crossdomain.py +89 -0
- sma/eval/diabetes.py +61 -0
- sma/eval/drift_env.py +26 -0
- sma/eval/drift_metrics.py +24 -0
- sma/eval/family_labels.py +167 -0
- sma/eval/fraud_elliptic/__init__.py +29 -0
- sma/eval/fraud_elliptic/encoder.py +279 -0
- sma/eval/fraud_elliptic/eval.py +269 -0
- sma/eval/fraud_elliptic/test_encoder.py +123 -0
- sma/eval/ieee_cis.py +66 -0
- sma/eval/loghub.py +16 -0
- sma/eval/loghub_eval.py +480 -0
- sma/eval/longmemeval.py +51 -0
- sma/eval/memory_backends/__init__.py +2 -0
- sma/eval/memory_backends/base.py +22 -0
- sma/eval/memory_backends/context_only.py +14 -0
- sma/eval/memory_backends/rag_notes.py +17 -0
- sma/eval/memory_backends/shared_llm.py +30 -0
- sma/eval/memory_backends/sma_memory.py +54 -0
- sma/eval/memory_backends/zep_graphiti.py +33 -0
- sma/eval/metrics.py +32 -0
- sma/eval/ontology_bench.py +219 -0
- sma/eval/report.py +573 -0
- sma/eval/ssb_eval.py +216 -0
- sma/eval/ssb_generator.py +116 -0
- sma/eval/stats.py +108 -0
- sma/eval/transfer_eval.py +844 -0
- sma/index/__init__.py +15 -0
- sma/index/ann.py +21 -0
- sma/index/content_vectors.py +60 -0
- sma/index/inverted.py +63 -0
- sma/index/macfac.py +174 -0
- sma/ir/__init__.py +22 -0
- sma/ir/canon.py +106 -0
- sma/ir/schema.py +165 -0
- sma/ir/sexpr.py +86 -0
- sma/ir/signatures.py +76 -0
- sma/match/__init__.py +20 -0
- sma/match/conflicts.py +46 -0
- sma/match/engine.py +60 -0
- sma/match/explain.py +59 -0
- sma/match/infer.py +54 -0
- sma/match/kernels.py +54 -0
- sma/match/mdl.py +30 -0
- sma/match/merge_cpsat.py +77 -0
- sma/match/merge_greedy.py +15 -0
- sma/match/mh.py +177 -0
- sma/match/ses.py +84 -0
- sma/match/types.py +115 -0
- sma/match/verifier.py +27 -0
- sma/ontology/__init__.py +45 -0
- sma/ontology/attack.py +134 -0
- sma/ontology/cpc.py +69 -0
- sma/ontology/graph.py +58 -0
- sma/ontology/loader.py +262 -0
- sma/ontology/mitre_xml.py +67 -0
- sma/ontology/mount.py +101 -0
- sma/ontology/rdf_loader.py +75 -0
- sma/ontology/registry.py +115 -0
- sma/ontology/router.py +69 -0
- sma/ontology/usgaap.py +73 -0
- sma/sage/__init__.py +6 -0
- sma/sage/assimilate.py +12 -0
- sma/sage/pools.py +105 -0
- sma/sage/probabilities.py +10 -0
- sma/store/__init__.py +6 -0
- sma/store/lmdb_store.py +78 -0
- sma/store/registry.py +26 -0
- sma/store/wal.py +26 -0
- sma/ui/app.py +642 -0
- structuremappingmemory-1.0.0.dist-info/METADATA +190 -0
- structuremappingmemory-1.0.0.dist-info/RECORD +125 -0
- structuremappingmemory-1.0.0.dist-info/WHEEL +5 -0
- structuremappingmemory-1.0.0.dist-info/entry_points.txt +2 -0
- structuremappingmemory-1.0.0.dist-info/licenses/LICENSE +204 -0
- structuremappingmemory-1.0.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
"""Structured JSON/CSV/XML-ish Tier-0 encoder."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import csv
|
|
6
|
+
import io
|
|
7
|
+
import json
|
|
8
|
+
from collections.abc import Mapping
|
|
9
|
+
|
|
10
|
+
from sma.ir.schema import Statement, entity, make_case, stmt
|
|
11
|
+
|
|
12
|
+
from .base import EncodeResult
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class StructuredEncoder:
    """Tier-0 encoder turning JSON or CSV text into a structure-mapping case.

    JSON input is walked recursively by :func:`encode_json`. Anything else
    (including XML, which gets no special handling) is parsed as CSV with a
    header row, producing one ``row_<i>`` entity per data row.
    """

    adapter_id = "structured"
    version = "0.1.0"

    def encode(self, artifact: str, **kwargs) -> EncodeResult:
        """Encode ``artifact`` into an :class:`EncodeResult`.

        Args:
            artifact: Raw JSON or CSV text.
            **kwargs: ``format`` may force ``"json"``/``"csv"``
                (case-insensitive). Otherwise the format is sniffed by
                :func:`infer_format`. Every non-JSON format is parsed as CSV.

        Returns:
            An ``EncodeResult`` wrapping the case. The second field is always
            an empty tuple. An input yielding no statements is encoded as a
            single ``emptyStructured`` statement.

        Raises:
            json.JSONDecodeError: if the artifact is treated as JSON but is
                not valid JSON.
        """
        # Normalise case so callers passing "JSON" are not silently CSV-parsed.
        fmt = str(kwargs.get("format") or infer_format(artifact)).lower()
        statements: list[Statement] = []
        if fmt == "json":
            encode_json(json.loads(artifact), statements, "root")
        else:
            reader = csv.DictReader(io.StringIO(artifact))
            for i, row in enumerate(reader):
                row_ent = entity(f"row_{i}", "row")
                for key, value in row.items():
                    if key is None:
                        # DictReader stores cells beyond the header width as a
                        # list under the ``None`` restkey. Emit each non-empty
                        # one instead of passing ``None``/a list into stmt().
                        for extra in value:
                            if extra:
                                statements.append(stmt("extraField", row_ent, entity(extra, "value")))
                    elif value is not None and value != "":
                        # Short rows fill missing cells with None (restval).
                        statements.append(stmt(key, row_ent, entity(value, "value")))
        case = make_case(
            statements or [stmt("emptyStructured", entity("root"))],
            {"adapter": self.adapter_id, "tier": 0},
        )
        return EncodeResult(case, ())
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def infer_format(artifact: str) -> str:
    """Guess the artifact's format from its first non-whitespace character.

    Returns ``"json"`` when the text opens with ``{`` or ``[``. Returns
    ``"csv"`` for everything else, including empty or all-whitespace input.
    """
    first_char = artifact.lstrip()[:1]
    return "json" if first_char in ("{", "[") else "csv"
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def encode_json(value, statements: list[Statement], path: str) -> None:
    """Append statements describing the JSON value at ``path`` to ``statements``.

    Every node becomes a ``json_node`` entity named by its dotted path.
    - Mapping keys are visited in sorted order. Each links parent to child
      through a relation named ``str(key)``.
    - List elements link through ``item`` and use their index as the path
      segment.
    - Any other value attaches a ``value`` statement holding ``str(value)``.

    Recursion depth follows the nesting depth of ``value``.

    NOTE(review): paths are plain ``"."`` concatenations, so a key containing
    a dot (``{"a.b": 1}``) yields the same path as nested keys
    (``{"a": {"b": 1}}``).
    """
    node = entity(path, "json_node")
    if isinstance(value, Mapping):
        edges = [(str(key), f"{path}.{key}", child) for key, child in sorted(value.items())]
    elif isinstance(value, list):
        edges = [("item", f"{path}.{idx}", child) for idx, child in enumerate(value)]
    else:
        statements.append(stmt("value", node, entity(str(value), "value")))
        return
    # Emit the edge statement before descending, so a parent's edge precedes
    # its child's sub-statements.
    for relation, child_path, child in edges:
        statements.append(stmt(relation, node, entity(child_path, "json_node")))
        encode_json(child, statements, child_path)
|
|
57
|
+
|
sma/encoders/traces.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
"""Stack trace and exception grammar encoder."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import re
|
|
6
|
+
|
|
7
|
+
from sma.ir.schema import Statement, entity, make_case, stmt
|
|
8
|
+
|
|
9
|
+
from .base import EncodeResult
|
|
10
|
+
|
|
11
|
+
# Python traceback frame: ``File "path.py", line 12, in func``. The function
# may be a dotted qualname or a bracketed pseudo-name such as ``<module>``,
# ``<lambda>`` or ``<listcomp>``. Those frames are common (every script's
# top frame is ``<module>``), so they must match too.
FRAME_RE = re.compile(
    r'File "?(?P<file>[^",\n]+)"?, line (?P<line>\d+), in (?P<func><[\w.]+>|[A-Za-z_][\w.]*)'
)
# Java stack frame: ``at pkg.Class.method(File.java:42)``. The pattern also
# accepts:
# - an optional module / class-loader prefix (``java.base/``, ``app//``);
# - constructor frames (``<init>``, ``<clinit>``);
# - synthetic ``$`` method names such as ``lambda$run$0``.
JAVA_FRAME_RE = re.compile(
    r"\s*at (?:[^\s(]*/)?(?P<class>[\w.$]+)\.(?P<func><\w+>|[\w$]+)\((?P<file>[^:]+):(?P<line>\d+)\)"
)
# Exception type name, optionally after a ``Caused by:`` / ``The above
# exception ...:`` prefix. The word boundaries stop matches inside longer
# identifiers (``MyErrorHandler``). The optional prefix lets bare ``Error`` and
# ``Exception`` match.
CAUSE_RE = re.compile(
    r"(?:(?:Caused by|The above exception).*?:\s*)?"
    r"(?P<exc>\b(?:[A-Za-z_][\w.]*)?(?:Error|Exception)\b)"
)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class TraceEncoder:
    """Tier-0 encoder for Python tracebacks and Java stack traces.

    Each input line is handled as follows:
    - a line matching ``FRAME_RE`` or ``JAVA_FRAME_RE`` becomes a ``frame``
      statement;
    - any other line naming an exception type (``CAUSE_RE``) becomes an
      ``exception`` statement.

    Consecutive frames, in text order, are linked by ``calledFrom`` statements.
    """

    adapter_id = "traces"
    version = "0.1.0"

    def encode(self, artifact: str, **kwargs) -> EncodeResult:
        """Encode the stack-trace text ``artifact`` into an :class:`EncodeResult`.

        ``kwargs`` are accepted for interface compatibility and ignored. An
        artifact with no recognised frame or exception yields a single
        ``emptyTrace`` statement, so the case is never empty. The second field
        of the result is always an empty tuple.
        """
        statements: list[Statement] = []
        frames: list[Statement] = []
        for i, line in enumerate(artifact.splitlines()):
            frame = FRAME_RE.search(line) or JAVA_FRAME_RE.search(line)
            if frame:
                frame_stmt = stmt(
                    "frame",
                    entity(f"f{i}", "frame"),  # keyed by line index, so unique per frame
                    entity(frame.group("file"), "file"),
                    entity(frame.group("func"), "function"),
                    entity(frame.group("line"), "line"),
                )
                frames.append(frame_stmt)
                statements.append(frame_stmt)
                # A frame line names a call site, not a raised exception.
                # Running CAUSE_RE here turned e.g. a function ``raise_ValueError``
                # or ``at com.x.FooException.create(...)`` into a spurious
                # ``exception`` statement.
                continue
            cause = CAUSE_RE.search(line)
            if cause:
                statements.append(stmt("exception", entity(cause.group("exc"), "exception")))
        # Link neighbouring frames in the order they appear in the text.
        for left, right in zip(frames, frames[1:], strict=False):
            statements.append(stmt("calledFrom", left, right))
        case = make_case(
            statements or [stmt("emptyTrace", entity("trace_0"))],
            {"adapter": self.adapter_id, "tier": 0},
        )
        return EncodeResult(case, ())
|
|
45
|
+
|
sma/eval/__init__.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
"""Agentic ontology benchmark harness (memory-swap suite)."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from sma.eval.agentic.memories import (
|
|
6
|
+
BM25Memory,
|
|
7
|
+
DenseMemory,
|
|
8
|
+
HippoMemory,
|
|
9
|
+
HybridRerankMemory,
|
|
10
|
+
HybridRRFMemory,
|
|
11
|
+
IndexItem,
|
|
12
|
+
Memory,
|
|
13
|
+
Query,
|
|
14
|
+
Retrieved,
|
|
15
|
+
SmaMemory,
|
|
16
|
+
)
|
|
17
|
+
from sma.eval.agentic.metrics import novelty_f1, risk_coverage_aurc, tail_topk
|
|
18
|
+
from sma.eval.agentic.harness import run_oneshot
|
|
19
|
+
|
|
20
|
+
# Public API of the agentic benchmark package, all re-exported from the
# harness, metrics and memories submodules imported above:
# - the one-shot harness entry point;
# - its metric functions;
# - the query/index data types;
# - the Memory implementations it compares.
__all__ = [
    "run_oneshot",
    "tail_topk",
    "risk_coverage_aurc",
    "novelty_f1",
    "IndexItem",
    "Query",
    "Retrieved",
    "Memory",
    "SmaMemory",
    "BM25Memory",
    "DenseMemory",
    "HippoMemory",
    "HybridRRFMemory",
    "HybridRerankMemory",
]
|
|
File without changes
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
"""Cyber arm: threat-group attribution over MITRE ATT&CK.
|
|
2
|
+
|
|
3
|
+
Entities are intrusion-sets (threat groups); each is annotated with the set of
|
|
4
|
+
ATT&CK techniques it is documented to USE (gold, from the STIX ``uses``
|
|
5
|
+
relationships). The hard query simulates a partial/imprecise incident (a few
|
|
6
|
+
observed TTPs, climbed up the sub-technique lattice, plus noise); the task is to
|
|
7
|
+
attribute it to the right group. ATT&CK's sub-technique tree is the is-a lattice.
|
|
8
|
+
"""
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
import json
|
|
12
|
+
import pathlib
|
|
13
|
+
|
|
14
|
+
from sma.ontology import load_attack_stix, mount
|
|
15
|
+
|
|
16
|
+
# parents[4] climbs arms/ -> agentic/ -> eval/ -> sma/ -> the directory that
# contains the ``sma`` package.
# NOTE(review): that is the repo root in a source checkout. In an installed
# wheel it is site-packages, where ``data/raw`` is presumably absent. Confirm
# how data is located when installed.
ROOT = pathlib.Path(__file__).resolve().parents[4]
# MITRE ATT&CK Enterprise STIX bundle, read by load() for both the ontology and
# the group -> technique relationships.
STIX = ROOT / "data/raw/attack/enterprise-attack.json"
# Band matches the medicine arm: >=7 techniques to be non-trivial, <=30 to keep
# structure-mapping cases tractable (prolific groups with 100+ techniques blow up
# SME kernel enumeration). Excludes the most-documented groups; n reported.
MIN_TERMS, MAX_TERMS = 7, 30
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def load():
    """Mount ATT&CK and build ``group name -> {technique external id}`` records.

    The STIX bundle is read twice:
    - through ``load_attack_stix``, to build the ontology;
    - as raw JSON, to extract ``intrusion-set --uses--> attack-pattern``
      relationships.

    Only these relationships contribute:
    - the group is not revoked;
    - the relationship itself is not revoked;
    - the technique's ATT&CK id is a term of the loaded graph.

    Groups are then filtered to the ``MIN_TERMS..MAX_TERMS`` band.

    Returns:
        ``(mounted, records)``, where ``records`` maps group name to its set of
        technique ids (e.g. ``"T1059.001"``).
    """
    graph = load_attack_stix(str(STIX), name="attack")
    mounted = mount(graph)
    # STIX bundles are UTF-8 JSON; don't depend on the platform's locale encoding.
    objs = json.loads(STIX.read_text(encoding="utf-8"))["objects"]

    # stix-id -> ATT&CK external_id (e.g. "T1059.001"), for attack-patterns.
    ext: dict[str, str] = {}
    for o in objs:
        if o.get("type") == "attack-pattern":
            for ref in o.get("external_references", []):
                if ref.get("source_name") == "mitre-attack" and ref.get("external_id"):
                    ext[o["id"]] = ref["external_id"]
    groups = {
        o["id"]: o.get("name", o["id"])
        for o in objs
        if o.get("type") == "intrusion-set" and not o.get("revoked")
    }

    recs: dict[str, set[str]] = {}
    for o in objs:
        if o.get("type") != "relationship" or o.get("relationship_type") != "uses":
            continue
        # A revoked relationship has been retracted by its publisher, so it is
        # not gold usage (groups are already filtered the same way).
        if o.get("revoked"):
            continue
        s, t = o.get("source_ref"), o.get("target_ref")
        if s in groups and t in ext and ext[t] in graph.terms:
            recs.setdefault(groups[s], set()).add(ext[t])

    records = {g: ts for g, ts in recs.items() if MIN_TERMS <= len(ts) <= MAX_TERMS}
    return mounted, records
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
"""Discovery arm: gene-function retrieval over the Gene Ontology (GO).
|
|
2
|
+
|
|
3
|
+
Entities are human proteins; each is annotated with its set of GO biological-process
|
|
4
|
+
terms (gold, from the GOA GAF). The hard query is a partial/imprecise functional
|
|
5
|
+
profile; the task is to retrieve the right protein by its function signature. GO's
|
|
6
|
+
is-a tree is the ascension lattice and its part_of/regulates relations become
|
|
7
|
+
higher-order statements.
|
|
8
|
+
"""
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
import pathlib
|
|
12
|
+
|
|
13
|
+
from sma.ontology import load_obo, mount
|
|
14
|
+
|
|
15
|
+
# parents[4] is the directory that contains the ``sma`` package.
# NOTE(review): that is the repo root in a source checkout, but site-packages
# in an installed wheel, where ``data/raw`` is presumably absent. Confirm how
# data is located.
ROOT = pathlib.Path(__file__).resolve().parents[4]
# Gene Ontology (basic release, OBO format), mounted as the is-a lattice.
GO_OBO = ROOT / "data/raw/obo/go-basic.obo"
# Human GO annotation file (GAF), supplying protein -> GO term gold sets.
GAF = ROOT / "data/raw/go/goa_human.gaf"
ASPECT = "P"  # biological process
# Keep proteins with 7..30 known annotated terms.
MIN_TERMS, MAX_TERMS = 7, 30
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def load():
    """Mount GO and build ``protein -> {GO biological-process term}`` records.

    Parses the GAF at ``GAF`` as follows:
    - skips ``!`` comment lines;
    - keeps rows whose aspect column equals ``ASPECT`` and whose GO id is a
      term of the mounted graph;
    - drops negated (``NOT``-qualified) annotations.

    Proteins are then filtered to the ``MIN_TERMS..MAX_TERMS`` band.

    Returns:
        ``(mounted, records)``, where ``records`` maps the DB object id
        (GAF column 2) to its set of GO ids.
    """
    mounted = mount(load_obo(str(GO_OBO), name="go"))
    known = mounted.graph.terms
    recs: dict[str, set[str]] = {}
    # ``with`` closes the handle; the bare ``GAF.open()`` loop leaked it.
    with GAF.open(encoding="utf-8") as fh:
        for line in fh:
            if line.startswith("!"):
                continue
            p = line.rstrip("\n").split("\t")
            if len(p) < 9 or p[8] != ASPECT:
                continue
            # Column 4 (Qualifier) is pipe-separated and may contain ``NOT``.
            # Such a row states the protein is *not* associated with the term,
            # so it must not enter the gold set.
            if "NOT" in p[3].split("|"):
                continue
            if p[4] in known:
                recs.setdefault(p[1], set()).add(p[4])  # protein -> GO term
    records = {k: v for k, v in recs.items() if MIN_TERMS <= len(v) <= MAX_TERMS}
    return mounted, records
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
"""Finance arm: SEC filing classification over the US-GAAP reporting taxonomy.
|
|
2
|
+
|
|
3
|
+
Entities are SEC filings (accession numbers); each is annotated with the set of
|
|
4
|
+
US-GAAP concepts it reports (gold, from the SEC Financial Statement Data Sets'
|
|
5
|
+
num.txt). The hard query is a partial/imprecise set of reported concepts climbed
|
|
6
|
+
up the statement hierarchy plus noise; the task is to retrieve the filing by its
|
|
7
|
+
financial-reporting signature. The US-GAAP presentation hierarchy (statement
|
|
8
|
+
headers subsume line items) is the is-a lattice (FIBO, a schema, has no instance
|
|
9
|
+
corpus, so US-GAAP provides the financial entity->concept gold).
|
|
10
|
+
"""
|
|
11
|
+
from __future__ import annotations
|
|
12
|
+
|
|
13
|
+
import pathlib
|
|
14
|
+
|
|
15
|
+
from sma.ontology import mount
|
|
16
|
+
from sma.ontology.usgaap import load_usgaap
|
|
17
|
+
|
|
18
|
+
# parents[4] is the directory that contains the ``sma`` package.
# NOTE(review): that is the repo root in a source checkout, but site-packages
# in an installed wheel, where ``data/raw`` is presumably absent. Confirm how
# data is located.
ROOT = pathlib.Path(__file__).resolve().parents[4]
# Directory of US-GAAP taxonomy files passed to load_usgaap.
USGAAP_DIR = ROOT / "data/raw/finance/usgaap"
# Tab-separated numeric table of the SEC Financial Statement Data Sets
# (the directory name suggests the 2024 Q1 release).
NUM = ROOT / "data/raw/finance/sec_2024q1/num.txt"
# Filings must report 10..40 known US-GAAP concepts to be eligible.
MIN_TERMS, MAX_TERMS = 10, 40
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def load():
    """Mount US-GAAP and build ``accession number -> {US-GAAP concept}`` records.

    Reads the tab-separated ``num.txt``, locating its ``adsh``, ``tag`` and
    ``version`` columns from the header row. A row contributes only when:
    - its version starts with ``us-gaap`` (custom company tags are skipped);
    - its tag is a term of the mounted taxonomy.

    Filings are then filtered to the ``MIN_TERMS..MAX_TERMS`` band.

    Returns:
        ``(mounted, records)``.

    Raises:
        ValueError: if the header lacks one of the required columns.
    """
    mounted = mount(load_usgaap(str(USGAAP_DIR), name="usgaap"))
    known = mounted.graph.terms
    recs: dict[str, set[str]] = {}
    # "utf-8-sig" strips a leading BOM, which would otherwise turn the first
    # header name into "\ufeffadsh" and make header.index("adsh") raise.
    with NUM.open(encoding="utf-8-sig", errors="ignore") as fh:
        header = fh.readline().rstrip("\r\n").split("\t")
        ia, it, iv = header.index("adsh"), header.index("tag"), header.index("version")
        # Shortest row that actually contains all three columns.
        need = max(ia, it, iv)
        for line in fh:
            p = line.rstrip("\r\n").split("\t")
            if len(p) <= need or not p[iv].startswith("us-gaap"):
                continue
            if p[it] in known:
                recs.setdefault(p[ia], set()).add(p[it])
    records = {k: v for k, v in recs.items() if MIN_TERMS <= len(v) <= MAX_TERMS}
    return mounted, records
|
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
"""Legal/IP arm: patent classification over the Cooperative Patent Classification.
|
|
2
|
+
|
|
3
|
+
Entities are granted patents; each is annotated with the set of CPC subgroup codes
|
|
4
|
+
its examiner assigned (gold, from one week of USPTO grant full-text XML). The hard
|
|
5
|
+
query is a partial/imprecise technical profile (a few codes, climbed up the CPC
|
|
6
|
+
hierarchy, plus noise); the task is to retrieve the patent by its classification
|
|
7
|
+
signature. CPC's deep section->class->subclass->group->subgroup tree is the is-a
|
|
8
|
+
lattice (254k nodes), mounted via sma.ontology.load_cpc.
|
|
9
|
+
"""
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
|
|
12
|
+
import pathlib
|
|
13
|
+
import re
|
|
14
|
+
import xml.etree.ElementTree as ET
|
|
15
|
+
|
|
16
|
+
from sma.ontology import load_cpc, mount
|
|
17
|
+
|
|
18
|
+
# parents[4] is the directory that contains the ``sma`` package.
# NOTE(review): that is the repo root in a source checkout, but site-packages
# in an installed wheel, where ``data/raw`` is presumably absent. Confirm how
# data is located.
ROOT = pathlib.Path(__file__).resolve().parents[4]
# CPC scheme directory, passed to load_cpc.
CPC_DIR = ROOT / "data/raw/legal/cpc"
# One week of USPTO grant full-text XML (many concatenated XML documents).
GRANT_XML = ROOT / "data/raw/patents/ipg161011.xml"
# Keep patents with 7..30 known CPC subgroup codes.
MIN_TERMS, MAX_TERMS = 7, 30
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def _grant_blocks(path):
|
|
25
|
+
"""Yield each <us-patent-grant>...</us-patent-grant> block (the file is many
|
|
26
|
+
concatenated XML documents, so we can't parse it as one tree)."""
|
|
27
|
+
buf, inside = [], False
|
|
28
|
+
with open(path, "r", encoding="utf-8", errors="ignore") as fh:
|
|
29
|
+
for line in fh:
|
|
30
|
+
if line.startswith("<us-patent-grant"):
|
|
31
|
+
inside, buf = True, [line]
|
|
32
|
+
elif inside:
|
|
33
|
+
buf.append(line)
|
|
34
|
+
if line.startswith("</us-patent-grant>"):
|
|
35
|
+
yield "".join(buf)
|
|
36
|
+
inside, buf = False, []
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def _cpc_codes(grant_el) -> set[str]:
|
|
40
|
+
"""CPC subgroup codes from the bibliographic block only (not cited references)."""
|
|
41
|
+
block = grant_el.find(".//us-bibliographic-data-grant/classifications-cpc")
|
|
42
|
+
if block is None:
|
|
43
|
+
return set()
|
|
44
|
+
codes = set()
|
|
45
|
+
for c in block.iter("classification-cpc"):
|
|
46
|
+
try:
|
|
47
|
+
sec = c.findtext("section", "").strip()
|
|
48
|
+
cls = c.findtext("class", "").strip()
|
|
49
|
+
sub = c.findtext("subclass", "").strip()
|
|
50
|
+
mg = c.findtext("main-group", "").strip()
|
|
51
|
+
sg = c.findtext("subgroup", "").strip()
|
|
52
|
+
except AttributeError:
|
|
53
|
+
continue
|
|
54
|
+
if sec and cls and sub and mg and sg:
|
|
55
|
+
codes.add(f"{sec}{cls}{sub}{mg}/{sg}")
|
|
56
|
+
return codes
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def load():
    """Mount CPC and build ``patent number -> {CPC subgroup code}`` records.

    Grant blocks are streamed from ``GRANT_XML``. A block is skipped silently
    when it fails to parse or has no publication doc-number. Codes absent from
    the mounted graph are dropped, and patents outside
    ``MIN_TERMS..MAX_TERMS`` are excluded. Keys are ``"US" + doc-number``.

    Returns:
        ``(mounted, records)``.
    """
    mounted = mount(load_cpc(str(CPC_DIR), name="cpc"))
    vocabulary = mounted.graph.terms
    records: dict[str, set[str]] = {}
    for raw in _grant_blocks(GRANT_XML):
        try:
            grant = ET.fromstring(raw)
        except ET.ParseError:
            # NOTE(review): malformed grants are dropped without being counted.
            continue
        doc_number = grant.findtext(".//publication-reference//doc-number")
        if not doc_number:
            continue
        codes = {code for code in _cpc_codes(grant) if code in vocabulary}
        if not MIN_TERMS <= len(codes) <= MAX_TERMS:
            continue
        records[f"US{doc_number}"] = codes
    return mounted, records
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
"""Medicine arm: HPO ontology + rare-disease phenotype records.
|
|
2
|
+
|
|
3
|
+
``load()`` mounts the Human Phenotype Ontology and parses ``phenotype.hpoa`` into
|
|
4
|
+
``disease_id -> {hpo_term_id}`` records, restricted to phenotypic-abnormality
|
|
5
|
+
(aspect ``P``) annotations of diseases carrying 7..30 phenotypes -- the same
|
|
6
|
+
record-construction used by ``scripts/bench_ontology_suite.load_hpo_records`` and
|
|
7
|
+
the 7..30 eligibility band of ``sma.eval.ontology_bench.run_arm``.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
|
|
12
|
+
import pathlib
|
|
13
|
+
|
|
14
|
+
from sma.ontology import MountedOntology, load_obo, mount
|
|
15
|
+
|
|
16
|
+
# parents[4] is the directory that contains the ``sma`` package.
# NOTE(review): that is the repo root in a source checkout, but site-packages
# in an installed wheel, where ``data/raw`` is presumably absent. Confirm how
# data is located.
ROOT = pathlib.Path(__file__).resolve().parents[4]
# Human Phenotype Ontology (OBO), mounted by load().
HP_OBO = ROOT / "data/raw/hpo/hp.obo"
# Disease -> phenotype annotation table parsed by load_hpo_records().
HPOA = ROOT / "data/raw/hpo/phenotype.hpoa"

# Eligibility band: diseases with 7..30 aspect-P phenotypes are kept.
MIN_TERMS = 7
MAX_TERMS = 30
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def load_hpo_records(
    path: pathlib.Path = HPOA, *, exclude_negated: bool = True
) -> dict[str, set[str]]:
    """Parse ``phenotype.hpoa`` into ``disease_id -> {hpo_term_id}`` (aspect P).

    Skips header/comment lines (``#`` or ``database_id``) and keeps only
    phenotypic-abnormality annotations (column 10 == ``"P"``). Diseases with
    ``MIN_TERMS..MAX_TERMS`` phenotypes are retained.

    Args:
        path: The HPOA file (UTF-8, tab-separated).
        exclude_negated: Drop rows whose qualifier column (index 2) is
            ``"NOT"``. Such a row records that the disease does *not* show the
            phenotype, so it is not positive gold. The module docstring says
            this construction mirrors
            ``scripts/bench_ontology_suite.load_hpo_records``. Pass ``False``
            to reproduce the earlier, unfiltered records if that parity matters.

    Returns:
        Mapping of disease id to its set of HPO term ids.
    """
    rec: dict[str, set[str]] = {}
    # ``with`` closes the handle (previously leaked). HPOA disease names contain
    # non-ASCII text, so don't rely on the locale encoding.
    with path.open(encoding="utf-8") as fh:
        for line in fh:
            if line.startswith(("#", "database_id")):
                continue
            p = line.rstrip("\n").split("\t")
            if len(p) < 11 or p[10] != "P":
                continue
            if exclude_negated and p[2] == "NOT":
                continue
            rec.setdefault(p[0], set()).add(p[3])
    return {d: terms for d, terms in rec.items() if MIN_TERMS <= len(terms) <= MAX_TERMS}
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def load() -> tuple[MountedOntology, dict[str, set[str]]]:
    """Mount HPO from ``HP_OBO`` and pair it with the default HPOA records.

    Returns:
        ``(mounted, records)``, where ``records`` is
        :func:`load_hpo_records` applied to its default path.
    """
    ontology = load_obo(str(HP_OBO), name="hpo")
    # Tuple elements evaluate left to right: mount first, then parse records.
    return mount(ontology), load_hpo_records()
|
|
@@ -0,0 +1,275 @@
|
|
|
1
|
+
"""One-shot agentic harness: swap the :class:`Memory`, hold everything else fixed.
|
|
2
|
+
|
|
3
|
+
``run_oneshot`` builds one benchmark from an arm ``(mounted ontology, entity ->
|
|
4
|
+
term-set records)`` and scores every memory identically:
|
|
5
|
+
|
|
6
|
+
* a deterministic ``holdout_frac`` of entities is reserved as NOVEL (their
|
|
7
|
+
queries are unanswerable and feed the abstain/novelty metrics); only the rest
|
|
8
|
+
is indexed in every memory;
|
|
9
|
+
* hard queries are generated for both answerable and novel entities (sample a
|
|
10
|
+
few terms, climb 0-2 is-a levels for imprecision, add noise terms);
|
|
11
|
+
* each memory retrieves the true key's rank, a confidence, and a novelty signal;
|
|
12
|
+
* metrics: tail top-k (answerable; all + rare slice), risk-coverage AURC
|
|
13
|
+
(answerable), novelty F1 over ALL queries (novelty > 0.5 vs truly novel);
|
|
14
|
+
* primary stat: SMA vs the best enterprise-RAG memory (best by tail top-5) on
|
|
15
|
+
per-query top-5 correctness via paired bootstrap.
|
|
16
|
+
|
|
17
|
+
Determinism: every set is sorted to a list before use and every RNG is
|
|
18
|
+
explicitly seeded, so identical ``seeds`` yield identical result dicts.
|
|
19
|
+
"""
|
|
20
|
+
|
|
21
|
+
from __future__ import annotations
|
|
22
|
+
|
|
23
|
+
import math
|
|
24
|
+
import random
|
|
25
|
+
import statistics
|
|
26
|
+
from typing import Iterable
|
|
27
|
+
|
|
28
|
+
from sma.eval.agentic.memories import IndexItem, Memory, Query
|
|
29
|
+
from sma.eval.agentic.metrics import (
|
|
30
|
+
ABSENT_RANK,
|
|
31
|
+
novelty_f1,
|
|
32
|
+
risk_coverage_aurc,
|
|
33
|
+
tail_topk,
|
|
34
|
+
)
|
|
35
|
+
from sma.eval.stats import cliffs_delta, paired_bootstrap
|
|
36
|
+
from sma.ontology import MountedOntology
|
|
37
|
+
|
|
38
|
+
# Enterprise-RAG/KG gauntlet — SMA's primary comparison is the best of these.
# Entries are matched against the keys of run_oneshot's top-5 tail results;
# names with no results in a run are ignored.
ENTERPRISE_NAMES = ("bm25", "dense", "hybrid_rrf", "hybrid_rerank", "hippo")

# A memory "predicts novel" when its novelty score is strictly above this value.
NOVELTY_THRESHOLD = 0.5  # fixed (not per-method tuned); noted caveat in the spec.
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
# --- ontology IC machinery (closure-propagated term frequency) -------------
|
|
45
|
+
# Mirrors sma/eval/ontology_bench.py so the "rare" slice is defined identically.
|
|
46
|
+
def _ancestors(term: str, parents: dict[str, tuple[str, ...]], cache: dict[str, set]) -> set:
|
|
47
|
+
if term in cache:
|
|
48
|
+
return cache[term]
|
|
49
|
+
# Iterative closure with a visited set: cycle-safe (some ontologies, e.g. CPC
|
|
50
|
+
# built from repeated scheme symbols, contain is-a cycles) and immune to the
|
|
51
|
+
# recursion limit on deep hierarchies. Same result as the recursive form on
|
|
52
|
+
# acyclic graphs.
|
|
53
|
+
acc: set[str] = set()
|
|
54
|
+
stack = list(parents.get(term, ()))
|
|
55
|
+
while stack:
|
|
56
|
+
p = stack.pop()
|
|
57
|
+
if p in acc or p == term:
|
|
58
|
+
continue
|
|
59
|
+
acc.add(p)
|
|
60
|
+
stack.extend(parents.get(p, ()))
|
|
61
|
+
cache[term] = acc
|
|
62
|
+
return acc
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def _build_ic(
    entity_terms: list[set[str]],
    parents: dict[str, tuple[str, ...]],
    anc_cache: dict[str, set],
) -> dict[str, float]:
    """Information content ``-log(p)`` of every term reachable from ``entity_terms``.

    ``p`` is the fraction of entities whose ancestor-closed term set contains
    the term. A term is therefore credited once per entity annotated to it or
    to any of its descendants. ``anc_cache`` is forwarded to
    :func:`_ancestors` for memoisation.

    Returns ``{}`` when ``entity_terms`` is empty.
    """
    total = len(entity_terms)
    counts: dict[str, int] = {}
    for annotated in entity_terms:
        closure = set(annotated)
        for t in annotated:
            closure.update(_ancestors(t, parents, anc_cache))
        for t in closure:
            counts[t] = counts.get(t, 0) + 1
    if not total:
        return {}
    return {t: -math.log(c / total) for t, c in counts.items()}
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def run_oneshot(
    name: str,
    mounted: MountedOntology,
    records: dict[str, set[str]],
    memories: list[Memory],
    *,
    seeds: Iterable[int] = (7, 17, 23),
    n_index: int = 2000,
    n_query: int = 120,
    holdout_frac: float = 0.1,
) -> dict:
    """Run the one-shot agentic benchmark and return a result dict.

    ``records`` maps ``entity_id -> set(term_id)``. Returns
    ``{"arm", "memories", "n_all", "n_rare", "n_novel", "per_memory", "primary"}``
    where ``per_memory[name]`` carries tail top-k (all + rare slices), AURC, and
    novelty F1, and ``primary`` is the SMA-vs-best-enterprise paired bootstrap.

    Args:
        name: Arm label, echoed as ``result["arm"]``.
        mounted: Mounted ontology. Its ``graph.terms`` supplies the known
            terms, the term names used as index/query text, and the is-a
            parents.
        records: ``entity_id -> set(term_id)``. Entities with no known term
            are ignored.
        memories: Memories to compare. Per seed, each one is:
            - indexed via ``index(items)``;
            - queried via ``retrieve(query, k=10)``;
            - scored via ``novelty(query)``.
        seeds: One pool/holdout/query draw per seed, each from its own
            ``random.Random(seed)``. Rows are pooled across seeds.
        n_index: Maximum number of entities drawn per seed
            (indexed + held out).
        n_query: Query budget per seed, split between answerable and novel
            entities by ``holdout_frac``.
        holdout_frac: Fraction of each seed's pool held out as NOVEL
            (never indexed).

    Returns:
        The result dict described above, where:
        - ``n_all`` and ``n_rare`` count answerable query rows pooled over
          seeds;
        - ``n_novel`` counts novel queries;
        - ``primary`` is ``None`` unless ``"sma"`` and at least one
          ``ENTERPRISE_NAMES`` entry are keys of the top-5 ``tail_topk``
          result.
    """
    graph = mounted.graph
    # term id -> tuple of is-a parent ids (IC closure and query climbing).
    parents = {tid: tuple(t.parents) for tid, t in graph.terms.items()}

    # Text for a term: its ontology name when known and non-empty, else the raw id.
    def term_text(t: str) -> str:
        nm = graph.terms[t].name if t in graph.terms else ""
        return nm or t

    mem_names = [m.name for m in memories]

    # Eligible entities: those with at least one known term (SORTED for determinism).
    eligible = sorted(
        eid
        for eid, terms in records.items()
        if any(t in graph.terms for t in terms)
    )

    # Pooled per-query rows across seeds. Each answerable row holds the "rare"
    # flag plus every memory's rank. Confidence and novelty flags are kept
    # separately in per_mem.
    answerable_rows: list[dict] = []  # rank rows for tail_topk on answerable queries
    per_mem: dict[str, dict[str, list]] = {
        m: {"ans_conf": [], "ans_correct": [], "all_pred_novel": [], "all_is_novel": []}
        for m in mem_names
    }
    # Per-query top-5 correctness on ALL queries (answerable + novel) for the
    # paired bootstrap. Every novel query is scored 0.0 for every memory.
    # Abstention is not credited here; it is measured by novelty F1.
    # NOTE(review): because these shared zeros are included, delta_top5 is the
    # answerable-only delta scaled by n_answerable / n_total. Meanwhile `best`
    # is picked on the answerable-only tail5. Confirm this dilution is intended.
    top5_ans: dict[str, list[float]] = {m: [] for m in mem_names}

    n_novel_total = 0

    for seed in seeds:
        rng = random.Random(seed)
        ids = list(eligible)
        rng.shuffle(ids)
        pool = ids[:n_index]

        # Deterministic NOVEL holdout: SORTED -> shuffle -> slice.
        pool_sorted = sorted(pool)
        rng.shuffle(pool_sorted)
        n_holdout = int(round(len(pool_sorted) * holdout_frac))
        novel_ids = sorted(pool_sorted[:n_holdout])
        index_ids = sorted(pool_sorted[n_holdout:])

        # Indexed term-sets (known terms only, SORTED to lists).
        dz = {e: sorted(t for t in records[e] if t in graph.terms) for e in index_ids}
        dz_novel = {e: sorted(t for t in records[e] if t in graph.terms) for e in novel_ids}

        # IC + rare threshold over the INDEXED records only.
        anc_cache: dict[str, set] = {}
        ic = _build_ic([set(v) for v in dz.values()], parents, anc_cache)
        median_ic = statistics.median(ic.values()) if ic else 0.0
        noise_pool = sorted(ic) or sorted({t for v in dz.values() for t in v})

        # Build IndexItems and index every memory (identical input).
        items = [
            IndexItem(
                key=e,
                term_ids=frozenset(dz[e]),
                text=" ".join(term_text(t) for t in dz[e]),
                meta={"id": e},
            )
            for e in index_ids
        ]
        # NOTE(review): assumes Memory.index replaces the previous seed's index
        # rather than adding to it. Confirm in sma.eval.agentic.memories.
        for mem in memories:
            mem.index(items)

        # Query specs: hard partial/imprecise observations for answerable AND novel.
        # make_qspec closes over this seed's rng and noise_pool. Every call
        # advances the seeded RNG, so call order is part of the determinism
        # contract.
        def make_qspec(terms: list[str]) -> list[str]:
            keep = rng.sample(terms, min(5, len(terms)))
            q: list[str] = []
            for t in keep:
                cur = t
                # Climb 0-2 is-a levels (weights 2:2:1) to make the term imprecise.
                for _ in range(rng.choice([0, 0, 1, 1, 2])):
                    ps = parents.get(cur)
                    if ps:
                        cur = rng.choice(sorted(ps))
                q.append(cur)
            if noise_pool:
                # Up to 3 noise terms drawn from the indexed vocabulary.
                q += rng.sample(noise_pool, min(3, len(noise_pool)))
            return q

        # Allocate the n_query budget between answerable and novel entities,
        # preserving the holdout proportion. Both kinds get hard queries.
        ans_candidates = [e for e in index_ids if dz[e]]
        nov_candidates = [e for e in novel_ids if dz_novel[e]]
        n_nov = min(len(nov_candidates), int(round(n_query * holdout_frac)))
        n_ans = min(len(ans_candidates), n_query - n_nov)
        # NOTE(review): index_ids/novel_ids are sorted, so these slices take the
        # lexicographically smallest ids of this seed's pool. They are not a
        # random sample, which may bias the queried entities toward low ids.
        ans_q = ans_candidates[:n_ans]
        nov_q = nov_candidates[:n_nov]
        n_novel_total += len(nov_q)

        qspecs: list[tuple[str, list[str], bool]] = []
        for e in ans_q:
            qspecs.append((e, make_qspec(dz[e]), False))
        for e in nov_q:
            qspecs.append((e, make_qspec(dz_novel[e]), True))

        for e, qterms, is_novel in qspecs:
            query = Query(
                term_ids=frozenset(qterms),
                text=" ".join(term_text(t) for t in qterms),
            )
            # "Rare" = the entity's most informative term is above the median
            # IC of this seed's indexed terms. Terms absent from ic count as 0.
            rare = (
                max((ic.get(t, 0.0) for t in (dz[e] if not is_novel else dz_novel[e])), default=0.0)
                > median_ic
            )
            rank_row = {"rare": rare}
            for mem in memories:
                res = mem.retrieve(query, k=10)
                rank = next((r.rank for r in res if r.key == e), ABSENT_RANK)
                if is_novel:
                    rank = ABSENT_RANK  # unanswerable: true key is not indexed
                # Confidence is read from the first retrieved result (0.0 if none).
                conf = res[0].confidence if res else 0.0
                nov = mem.novelty(query)

                rank_row[mem.name] = rank
                top5_ans[mem.name].append(1.0 if (not is_novel and rank <= 5) else 0.0)
                per_mem[mem.name]["all_pred_novel"].append(nov > NOVELTY_THRESHOLD)
                per_mem[mem.name]["all_is_novel"].append(is_novel)
                if not is_novel:
                    correct = rank <= 5
                    per_mem[mem.name]["ans_conf"].append(conf)
                    per_mem[mem.name]["ans_correct"].append(correct)

            if not is_novel:
                answerable_rows.append(rank_row)

    # --- aggregate ---------------------------------------------------------
    tail5 = tail_topk(answerable_rows, k=5)
    tail1 = tail_topk(answerable_rows, k=1)
    tail10 = tail_topk(answerable_rows, k=10)

    per_memory: dict[str, dict] = {}
    for m in mem_names:
        aurc, _curve = risk_coverage_aurc(
            per_mem[m]["ans_conf"], per_mem[m]["ans_correct"]
        )
        f1 = novelty_f1(per_mem[m]["all_pred_novel"], per_mem[m]["all_is_novel"])
        per_memory[m] = {
            "tail": {
                "top1": tail1.get(m, {"all": 0.0, "rare": 0.0}),
                "top5": tail5.get(m, {"all": 0.0, "rare": 0.0}),
                "top10": tail10.get(m, {"all": 0.0, "rare": 0.0}),
            },
            "aurc": aurc,
            "novelty_f1": f1,
        }

    # --- primary: SMA vs best enterprise-RAG on per-query top-5 correctness --
    primary: dict | None = None
    present_enterprise = [m for m in ENTERPRISE_NAMES if m in tail5]
    if "sma" in tail5 and present_enterprise:
        # Best enterprise memory by answerable tail top-5 over all answerable rows.
        best = max(present_enterprise, key=lambda m: tail5[m]["all"])
        a = top5_ans["sma"]
        b = top5_ans[best]
        bs = paired_bootstrap(a, b)
        primary = {
            "a": "sma",
            "b": best,
            "best_enterprise": best,
            "delta_top5": bs["delta"],
            "ci_low": bs["ci_low"],
            "ci_high": bs["ci_high"],
            "p_value": bs["p_value"],
            "cliffs": cliffs_delta(a, b),
        }

    return {
        "arm": name,
        "memories": mem_names,
        "n_all": len(answerable_rows),
        "n_rare": sum(1 for r in answerable_rows if r["rare"]),
        "n_novel": n_novel_total,
        "per_memory": per_memory,
        "primary": primary,
    }
|