structuremappingmemory 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sma/__init__.py +5 -0
- sma/__main__.py +5 -0
- sma/agent/__init__.py +5 -0
- sma/agent/adapter_draft.py +217 -0
- sma/agent/api.py +67 -0
- sma/agent/comparison.py +591 -0
- sma/agent/llm.py +280 -0
- sma/agent/policies.py +21 -0
- sma/agent/service.py +95 -0
- sma/cli.py +65 -0
- sma/encoders/__init__.py +38 -0
- sma/encoders/agentobs.py +27 -0
- sma/encoders/base.py +23 -0
- sma/encoders/code_treesitter.py +64 -0
- sma/encoders/coverage.py +80 -0
- sma/encoders/draft_adapter.py +183 -0
- sma/encoders/healthcare.py +207 -0
- sma/encoders/logs_drain.py +142 -0
- sma/encoders/prose_tier1.py +57 -0
- sma/encoders/structured.py +57 -0
- sma/encoders/traces.py +45 -0
- sma/eval/__init__.py +2 -0
- sma/eval/agentic/__init__.py +35 -0
- sma/eval/agentic/arms/__init__.py +0 -0
- sma/eval/agentic/arms/cyber.py +48 -0
- sma/eval/agentic/arms/discovery.py +35 -0
- sma/eval/agentic/arms/finance.py +38 -0
- sma/eval/agentic/arms/legal.py +74 -0
- sma/eval/agentic/arms/medicine.py +45 -0
- sma/eval/agentic/harness.py +275 -0
- sma/eval/agentic/memories.py +308 -0
- sma/eval/agentic/metrics.py +82 -0
- sma/eval/agentic_qa/__init__.py +27 -0
- sma/eval/agentic_qa/agent.py +383 -0
- sma/eval/agentic_qa/metrics.py +239 -0
- sma/eval/agentic_qa/pools.py +197 -0
- sma/eval/arn.py +65 -0
- sma/eval/baselines/__init__.py +6 -0
- sma/eval/baselines/bge_dense.py +54 -0
- sma/eval/baselines/bm25.py +18 -0
- sma/eval/baselines/dense.py +42 -0
- sma/eval/baselines/hipporag.py +235 -0
- sma/eval/baselines/hybrid_rrf.py +30 -0
- sma/eval/baselines/longcontext_llm.py +124 -0
- sma/eval/baselines/rerank.py +41 -0
- sma/eval/baselines/splade.py +77 -0
- sma/eval/baselines/wl_kernel.py +163 -0
- sma/eval/bugsinpy.py +358 -0
- sma/eval/bugsinpy_families.py +164 -0
- sma/eval/crossdomain.py +89 -0
- sma/eval/diabetes.py +61 -0
- sma/eval/drift_env.py +26 -0
- sma/eval/drift_metrics.py +24 -0
- sma/eval/family_labels.py +167 -0
- sma/eval/fraud_elliptic/__init__.py +29 -0
- sma/eval/fraud_elliptic/encoder.py +279 -0
- sma/eval/fraud_elliptic/eval.py +269 -0
- sma/eval/fraud_elliptic/test_encoder.py +123 -0
- sma/eval/ieee_cis.py +66 -0
- sma/eval/loghub.py +16 -0
- sma/eval/loghub_eval.py +480 -0
- sma/eval/longmemeval.py +51 -0
- sma/eval/memory_backends/__init__.py +2 -0
- sma/eval/memory_backends/base.py +22 -0
- sma/eval/memory_backends/context_only.py +14 -0
- sma/eval/memory_backends/rag_notes.py +17 -0
- sma/eval/memory_backends/shared_llm.py +30 -0
- sma/eval/memory_backends/sma_memory.py +54 -0
- sma/eval/memory_backends/zep_graphiti.py +33 -0
- sma/eval/metrics.py +32 -0
- sma/eval/ontology_bench.py +219 -0
- sma/eval/report.py +573 -0
- sma/eval/ssb_eval.py +216 -0
- sma/eval/ssb_generator.py +116 -0
- sma/eval/stats.py +108 -0
- sma/eval/transfer_eval.py +844 -0
- sma/index/__init__.py +15 -0
- sma/index/ann.py +21 -0
- sma/index/content_vectors.py +60 -0
- sma/index/inverted.py +63 -0
- sma/index/macfac.py +174 -0
- sma/ir/__init__.py +22 -0
- sma/ir/canon.py +106 -0
- sma/ir/schema.py +165 -0
- sma/ir/sexpr.py +86 -0
- sma/ir/signatures.py +76 -0
- sma/match/__init__.py +20 -0
- sma/match/conflicts.py +46 -0
- sma/match/engine.py +60 -0
- sma/match/explain.py +59 -0
- sma/match/infer.py +54 -0
- sma/match/kernels.py +54 -0
- sma/match/mdl.py +30 -0
- sma/match/merge_cpsat.py +77 -0
- sma/match/merge_greedy.py +15 -0
- sma/match/mh.py +177 -0
- sma/match/ses.py +84 -0
- sma/match/types.py +115 -0
- sma/match/verifier.py +27 -0
- sma/ontology/__init__.py +45 -0
- sma/ontology/attack.py +134 -0
- sma/ontology/cpc.py +69 -0
- sma/ontology/graph.py +58 -0
- sma/ontology/loader.py +262 -0
- sma/ontology/mitre_xml.py +67 -0
- sma/ontology/mount.py +101 -0
- sma/ontology/rdf_loader.py +75 -0
- sma/ontology/registry.py +115 -0
- sma/ontology/router.py +69 -0
- sma/ontology/usgaap.py +73 -0
- sma/sage/__init__.py +6 -0
- sma/sage/assimilate.py +12 -0
- sma/sage/pools.py +105 -0
- sma/sage/probabilities.py +10 -0
- sma/store/__init__.py +6 -0
- sma/store/lmdb_store.py +78 -0
- sma/store/registry.py +26 -0
- sma/store/wal.py +26 -0
- sma/ui/app.py +642 -0
- structuremappingmemory-1.0.0.dist-info/METADATA +190 -0
- structuremappingmemory-1.0.0.dist-info/RECORD +125 -0
- structuremappingmemory-1.0.0.dist-info/WHEEL +5 -0
- structuremappingmemory-1.0.0.dist-info/entry_points.txt +2 -0
- structuremappingmemory-1.0.0.dist-info/licenses/LICENSE +204 -0
- structuremappingmemory-1.0.0.dist-info/top_level.txt +1 -0
sma/index/__init__.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
from .content_vectors import CaseVector, cosine, functor_vector
|
|
2
|
+
from .inverted import InvertedIndex, histogram_intersection, ses_upper_bound
|
|
3
|
+
from .macfac import MacFacIndex, RetrievalResult
|
|
4
|
+
|
|
5
|
+
# Public API of sma.index: the names re-exported from the .content_vectors,
# .inverted and .macfac submodules imported above.
__all__ = [
    "CaseVector",
    "InvertedIndex",
    "MacFacIndex",
    "RetrievalResult",
    "cosine",
    "functor_vector",
    "histogram_intersection",
    "ses_upper_bound",
]
|
|
15
|
+
|
sma/index/ann.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
"""Small ANN facade with deterministic brute-force fallback."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
|
|
7
|
+
from .content_vectors import Vector, cosine
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@dataclass
class AnnIndex:
    """Cosine-similarity index with an exact, brute-force search.

    Every stored vector is scored against the query, so results are exact.
    Ordering is deterministic: descending similarity, ties broken by case id.
    """

    # case id -> content vector
    vectors: dict[str, Vector] = field(default_factory=dict)

    def add(self, case_id: str, vector: Vector) -> None:
        """Store ``vector`` for ``case_id``, overwriting any previous vector."""
        self.vectors.update({case_id: vector})

    def search(self, query: Vector, k: int = 200) -> list[tuple[str, float]]:
        """Return up to ``k`` ``(case_id, cosine)`` pairs, most similar first."""
        scored = sorted(
            ((cid, cosine(query, vec)) for cid, vec in self.vectors.items()),
            key=lambda pair: (-pair[1], pair[0]),
        )
        return scored[:k]
|
|
21
|
+
|
|
sma/index/content_vectors.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
"""MAC content vectors with WL-1 refinement features."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from collections import Counter
|
|
6
|
+
from dataclasses import dataclass
|
|
7
|
+
|
|
8
|
+
from sma.ir.canon import Canonicalizer, default_canonicalizer
|
|
9
|
+
from sma.ir.schema import Case, Statement
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
# Sparse feature vector: feature string -> count. functor_vector below emits
# "f:<functor>" and "wl:<functor>:<i>:<child>" keys.
Vector = Counter[str]
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def functor_vector(
    case: Case,
    wl: bool = True,
    canon: Canonicalizer | None = None,
    canonicalize: bool = True,
    delta: int = 0,
) -> Vector:
    """MAC content vector: canonical functor counts + WL-1 features.

    Features produced:
      * ``f:<functor>``: one count per expression of ``case`` (functor mapped
        through ``canon.canonical`` when ``canonicalize`` is true).
      * ``wl:<functor>:<i>:<child>``: one count per argument position ``i``.
        ``<child>`` is the argument's (canonical) functor, or ``ENTITY`` for
        non-statement arguments. These are emitted only when ``wl`` is true.

    With delta > 0, each functor also contributes its lattice ancestors within
    delta steps (blueprint 2.7: counts over the <=delta ancestor closure), so
    vocabularies bridged only by the lattice still intersect at the MAC stage.
    Ancestor features only ADD mass, keeping the Lemma-2 bound admissible.

    Args:
        case: case whose ``expressions()`` are counted.
        wl: include WL-1 (parent functor, position, child) features.
        canon: canonicalizer; ``default_canonicalizer()`` when None.
        canonicalize: map functors through ``canon.canonical`` before counting.
        delta: lattice-ancestor radius; 0 disables ancestor features.

    Returns:
        A Counter mapping feature strings to counts.
    """
    if canon is None:
        canon = default_canonicalizer()
    counts: Vector = Counter()
    for expr in case.expressions():
        functor = canon.canonical(expr.functor) if canonicalize else expr.functor
        counts[f"f:{functor}"] += 1
        if delta > 0:
            # ancestors() maps the (safe_symbol-sanitized) start symbol to
            # distance 0. Skip it by distance, not by name, so a raw unsanitized
            # functor (canonicalize=False) is not counted a second time.
            for ancestor, dist in canon.lattice.ancestors(functor, delta).items():
                if dist > 0:
                    counts[f"f:{ancestor}"] += 1
        if wl:
            for i, arg in enumerate(expr.args):
                if isinstance(arg, Statement):
                    child_functor = canon.canonical(arg.functor) if canonicalize else arg.functor
                    counts[f"wl:{functor}:{i}:{child_functor}"] += 1
                else:
                    counts[f"wl:{functor}:{i}:ENTITY"] += 1
    return counts
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def cosine(left: Vector, right: Vector) -> float:
    """Cosine similarity of two sparse count vectors.

    Returns 0.0 when either vector has zero Euclidean norm (e.g. is empty).
    """
    norm_l = sum(x * x for x in left.values()) ** 0.5
    norm_r = sum(x * x for x in right.values()) ** 0.5
    if not (norm_l and norm_r):
        return 0.0
    dot = sum(weight * right.get(feat, 0) for feat, weight in left.items())
    return dot / (norm_l * norm_r)
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
@dataclass(frozen=True)
class CaseVector:
    """A case id paired with its MAC content vector.

    ``vector`` is a mutable ``Counter`` and therefore unhashable, so the
    ``__hash__`` that a frozen dataclass generates would raise ``TypeError``.
    Hashing on ``case_id`` alone keeps instances usable in sets and as dict
    keys. It stays consistent with ``==``, because equal instances share a
    ``case_id``.
    """

    case_id: str
    vector: Vector

    def __hash__(self) -> int:
        return hash(self.case_id)
|
sma/index/inverted.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
"""Exact histogram-intersection upper bounds."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from collections import Counter, defaultdict
|
|
6
|
+
from dataclasses import dataclass, field
|
|
7
|
+
|
|
8
|
+
from .content_vectors import Vector
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def histogram_intersection(left: Vector, right: Vector) -> int:
    """Sum over shared keys of ``min(left[key], right[key])``.

    Iterates over the vector with fewer keys and probes the other one.
    """
    small, large = (right, left) if len(left) > len(right) else (left, right)
    total = 0
    for feature, count in small.items():
        total += min(count, large.get(feature, 0))
    return total
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def ses_upper_bound(
    left: Vector, right: Vector, max_score_per_mh: float = 2.0, costs: dict | None = None
) -> float:
    """Admissible bound on the (optionally surprisal-weighted) SES score.

    Unweighted: max_score_per_mh * histogram intersection (Lemma 2). With
    costs, each shared functor occurrence may carry at most cost * s-bar, so
    'f:' features are cost-weighted; WL features keep weight 1, which only
    adds slack (still admissible, blueprint section 2.7 weighted form).

    Args:
        left, right: content vectors to compare.
        max_score_per_mh: per-match-hypothesis score ceiling (the "s-bar").
        costs: functor -> cost, keyed WITHOUT the "f:" prefix. Functors
            missing from the mapping get weight 1.0.
    """
    if costs is None:
        return max_score_per_mh * histogram_intersection(left, right)
    small, large = (right, left) if len(left) > len(right) else (left, right)
    weighted = 0.0
    for feature, count in small.items():
        overlap = min(count, large.get(feature, 0))
        if overlap:
            weight = costs.get(feature.removeprefix("f:"), 1.0) if feature.startswith("f:") else 1.0
            weighted += weight * overlap
    return max_score_per_mh * weighted
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
@dataclass
class InvertedIndex:
    """Feature -> case-id postings over content vectors, plus MAC bounds.

    ``postings`` maps each feature key to the ids of cases whose current vector
    contains that key. ``vectors`` holds each case's latest vector for ``bound``.
    """

    postings: dict[str, set[str]] = field(default_factory=lambda: defaultdict(set))
    vectors: dict[str, Vector] = field(default_factory=dict)

    def add(self, case_id: str, vector: Vector) -> None:
        """Index ``vector`` under ``case_id``, replacing any earlier vector.

        On replacement, ``case_id`` is removed from the postings of features the
        new vector no longer has, and emptied posting sets are dropped. This
        keeps ``candidates`` from returning a case for a feature it lost.
        """
        previous = self.vectors.get(case_id)
        if previous is not None:
            for feature in previous:
                if feature in vector:
                    continue
                ids = self.postings.get(feature)
                if ids is None:
                    continue
                ids.discard(case_id)
                if not ids:
                    del self.postings[feature]
        self.vectors[case_id] = vector
        for feature in vector:
            self.postings.setdefault(feature, set()).add(case_id)

    def candidates(self, query: Vector) -> set[str]:
        """Ids of indexed cases sharing at least one feature key with ``query``."""
        out: set[str] = set()
        for feature in query:
            out.update(self.postings.get(feature, ()))
        return out

    def bound(
        self, query: Vector, case_id: str, max_score_per_mh: float = 2.0, costs: dict | None = None
    ) -> float:
        """``ses_upper_bound`` between ``query`` and the stored vector of ``case_id``.

        Raises:
            KeyError: if ``case_id`` was never added.
        """
        return ses_upper_bound(
            query, self.vectors[case_id], max_score_per_mh=max_score_per_mh, costs=costs
        )
|
|
63
|
+
|
sma/index/macfac.py
ADDED
|
@@ -0,0 +1,174 @@
|
|
|
1
|
+
"""Certified MAC/FAC retrieval over a case library."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
|
|
7
|
+
from sma.ir.canon import Canonicalizer, default_canonicalizer
|
|
8
|
+
from sma.ir.schema import Case
|
|
9
|
+
from sma.match.engine import match_cases
|
|
10
|
+
from sma.match.ses import self_score
|
|
11
|
+
from sma.match.types import MatchConfig
|
|
12
|
+
|
|
13
|
+
from .ann import AnnIndex
|
|
14
|
+
from .content_vectors import functor_vector
|
|
15
|
+
from .inverted import InvertedIndex
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@dataclass(frozen=True)
class RetrievalResult:
    """One retrieved case, as produced by ``MacFacIndex`` in this module.

    Fields:
        case_id: id of the indexed case.
        ses_n: normalized match score (``normalized_score`` of the match).
        score: raw match score (``score`` of the match).
        u_bound: MAC upper bound on the raw score used for candidate ordering.
        certified: whether the returned top-k was certified exact.
    """

    case_id: str
    ses_n: float
    score: float
    u_bound: float
    certified: bool
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class MacFacIndex:
    """Certified two-stage retrieval over a case library.

    Stage 1 (MAC): every case is summarized by a content vector
    (``functor_vector``). ``InvertedIndex.bound`` gives a cheap upper bound on
    each case's raw match score against the query; the bound orders and trims
    the candidates.
    Stage 2 (FAC): shortlisted cases are scored exactly with ``match_cases`` in
    bound order. Scoring stops early once no remaining bound can beat the
    current k-th best normalized score.

    NOTE(review): when a ``MatchConfig`` is passed in, this index keeps and
    mutates that same object. ``functor_costs`` is set lazily for the
    surprisal scorer and reset by ``add``, so indexes sharing one config also
    share that state.
    """

    # Below this corpus size the admissible bound orders ALL candidates
    # directly; above it the ANN cosine pre-screen trims first (scale only).
    ANN_THRESHOLD = 20000

    def __init__(self, config: MatchConfig | None = None, canon: Canonicalizer | None = None):
        self.config = config or MatchConfig()
        self.canon = canon or default_canonicalizer()
        self.cases: dict[str, Case] = {}
        self.ann = AnnIndex()
        self.inverted = InvertedIndex()
        # (case_id, query case_id, scorer, normalization) -> (score, ses_n)
        self._score_cache: dict[tuple, tuple[float, float]] = {}

    def add(self, case: Case) -> None:
        """Index ``case``, replacing any case previously stored under its id."""
        vector = functor_vector(case, canon=self.canon, delta=self.config.delta)
        replacing = case.case_id in self.cases
        self.cases[case.case_id] = case
        self.ann.add(case.case_id, vector)
        self.inverted.add(case.case_id, vector)
        if self.config.functor_costs is not None:
            # Corpus changed; surprisal costs and cached scores are stale.
            # NOTE(review): this also discards caller-supplied costs; confirm intended.
            self.config.functor_costs = None
            self._score_cache.clear()
        elif replacing:
            # Same id, possibly different content: scores cached against the
            # old case under this id are stale.
            cid = case.case_id
            self._score_cache = {
                key: value for key, value in self._score_cache.items() if key[0] != cid
            }

    def build(self, cases: list[Case]) -> None:
        """Index every case in ``cases``, in order."""
        for case in cases:
            self.add(case)

    def corpus_costs(self) -> dict[str, float]:
        """Corpus surprisal (-log2 p, KT-smoothed) per canonical functor.

        Counts sum every ``f:`` feature over all indexed content vectors; these
        include lattice-ancestor features when ``config.delta`` > 0. The result
        is smoothed as (count + 1/2) / (total + vocab/2).
        """
        import math
        from collections import Counter

        counts: Counter[str] = Counter()
        for vector in self.inverted.vectors.values():
            for feature, n in vector.items():
                if feature.startswith("f:"):
                    counts[feature[2:]] += n
        total = sum(counts.values())
        vocab = max(len(counts), 1)
        return {
            functor: -math.log2((count + 0.5) / (total + 0.5 * vocab))
            for functor, count in counts.items()
        }

    def retrieve(
        self, query: Case, k: int = 10, shortlist: int = 200, fac_budget: int | None = None
    ) -> list[RetrievalResult]:
        """Return the top-``k`` indexed cases for ``query`` by normalized score.

        Args:
            query: case matched against each library case (passed as the
                second argument of ``match_cases``).
            k: number of results wanted; ``k <= 0`` returns an empty list.
            shortlist: how many bound-ordered candidates the exact stage may examine.
            fac_budget: optional cap on the number of exact scorings.

        Returns:
            Up to ``k`` results sorted by descending ``ses_n``, ties broken by
            case id. All results share one ``certified`` flag. It is True when
            every shortlisted candidate was scored, or when the next unexamined
            candidate's bound (in ses_n units) is below the k-th best ses_n.
            Certification is relative to the shortlist, not the whole corpus.
        """
        import heapq

        if k <= 0:
            return []
        bound_costs = self._bound_costs()
        qvec = functor_vector(query, canon=self.canon, delta=self.config.delta)
        bounded = self._mac_shortlist(qvec, shortlist, bound_costs)
        ses_n_denom = self._ses_n_denom(query, bound_costs)

        scored: list[tuple[str, float, float, float]] = []
        # Min-heap of the k best ses_n values seen so far. Once it holds k
        # values its root is the k-th best, with no re-sort per iteration.
        best_k: list[float] = []
        kth_ses_n = float("-inf")
        for case_id, bound in bounded:
            if fac_budget is not None and len(scored) >= fac_budget:
                break
            if ses_n_denom is not None and len(scored) >= k and bound / ses_n_denom < kth_ses_n:
                # Bounds are sorted descending, so nothing left can enter the top-k.
                break
            score, ses_n = self._score_case(case_id, query)
            scored.append((case_id, ses_n, score, bound))
            if len(best_k) < k:
                heapq.heappush(best_k, ses_n)
            else:
                heapq.heappushpop(best_k, ses_n)
            if len(best_k) == k:
                kth_ses_n = best_k[0]
        # The top-k is certified exact over the shortlist iff no unexamined
        # candidate's bound could still beat the k-th best ses_n.
        remaining = bounded[len(scored):]
        certified = not remaining or (
            ses_n_denom is not None
            and len(scored) >= k
            and remaining[0][1] / ses_n_denom < kth_ses_n
        )
        scored.sort(key=lambda row: (-row[1], row[0]))
        return [
            RetrievalResult(case_id=cid, ses_n=ses_n, score=score, u_bound=bound, certified=certified)
            for cid, ses_n, score, bound in scored[:k]
        ]

    def brute_force(self, query: Case, k: int = 10) -> list[RetrievalResult]:
        """Exactly score every indexed case; the reference result for ``retrieve``.

        Uses the same lazily derived surprisal costs and the same
        ``u_bound`` weighting as ``retrieve``. Without this, calling
        ``brute_force`` first would cache scores computed without costs, and
        ``retrieve`` would later reuse them.
        """
        bound_costs = self._bound_costs()
        qvec = functor_vector(query, canon=self.canon, delta=self.config.delta)
        results: list[RetrievalResult] = []
        for case_id in self.cases:
            score, ses_n = self._score_case(case_id, query)
            results.append(
                RetrievalResult(
                    case_id=case_id,
                    ses_n=ses_n,
                    score=score,
                    u_bound=self.inverted.bound(
                        qvec, case_id, max_score_per_mh=4.0, costs=bound_costs
                    ),
                    certified=True,
                )
            )
        results.sort(key=lambda row: (-row.ses_n, row.case_id))
        return results[: max(k, 0)]

    def _bound_costs(self) -> dict[str, float] | None:
        """Costs for weighting MAC bounds: surprisal costs, or None for other scorers.

        For the surprisal scorer, costs are derived lazily from the indexed
        corpus (deterministic given its contents; reset by ``add``). Any
        cached scores are dropped when that happens, because they predate the
        costs.
        """
        if self.config.scorer != "surprisal":
            return None
        if self.config.functor_costs is None:
            self.config.functor_costs = self.corpus_costs()
            self._score_cache.clear()
        return self.config.functor_costs

    def _mac_shortlist(
        self, qvec: Vector, shortlist: int, bound_costs: dict[str, float] | None
    ) -> list[tuple[str, float]]:
        """MAC stage: return ``(case_id, bound)`` pairs, best bound first, trimmed to ``shortlist``.

        Lemma 2: the weighted histogram-intersection bound IS the admissible
        MAC ordering. Up to ANN_THRESHOLD cases every case is bound-ordered
        directly (the Liberty haystack showed cosine pre-screening can drop the
        true positives the bound ordering ranks first). Beyond it, the ANN
        cosine pre-screen trims the candidate set for tractability.
        """
        if len(self.cases) <= self.ANN_THRESHOLD:
            candidate_ids = set(self.inverted.candidates(qvec))
            if len(candidate_ids) < min(shortlist, len(self.cases)):
                # Zero-overlap cases still belong in the candidate pool (bound
                # 0, scored only if budget reaches them) so top-k cardinality
                # matches brute force on tiny/disjoint corpora.
                candidate_ids = set(self.cases)
        else:
            ann_k = min(max(shortlist * 5, 1000), len(self.cases))
            candidate_ids = {case_id for case_id, _ in self.ann.search(qvec, k=ann_k)}
        bounded = [
            (case_id, self.inverted.bound(qvec, case_id, max_score_per_mh=4.0, costs=bound_costs))
            for case_id in candidate_ids
        ]
        bounded.sort(key=lambda row: (-row[1], row[0]))
        return bounded[: max(shortlist, 1)]

    def _ses_n_denom(self, query: Case, bound_costs: dict[str, float] | None) -> float | None:
        """Divisor turning a raw-score bound into a ses_n bound, or None if none exists.

        ses_n = score / max(self(base), self(target)) <= U_bound / self(target),
        so dividing the raw-score bound by the query's self-score gives an
        admissible bound in ses_n units (weighted consistently for the
        surprisal scorer). The MDL scorer, and "min" normalization, have no
        such bound. In those cases the FAC stage never early-stops on bounds
        (budget only).
        """
        if self.config.scorer not in ("ses", "surprisal") or self.config.normalization == "min":
            return None
        cost_fn = None
        if bound_costs:
            from sma.ir.schema import Statement

            def cost_fn(mh):
                if isinstance(mh.base, Statement):
                    return bound_costs.get(self.canon.canonical(mh.base.functor), 1.0)
                return 1.0

        return max(self_score(query, gamma=self.config.gamma, cost_fn=cost_fn), 1e-9)

    def _score_case(self, case_id: str, query: Case) -> tuple[float, float]:
        """Exact ``(score, ses_n)`` of indexed case ``case_id`` against ``query``, memoized.

        NOTE(review): the cache key uses ``query.case_id``. Two different query
        cases that share an explicit case_id would therefore share cached scores.
        """
        key = (case_id, query.case_id, self.config.scorer, self.config.normalization)
        if key not in self._score_cache:
            gmap = match_cases(self.cases[case_id], query, config=self.config, canon=self.canon)
            self._score_cache[key] = (gmap.score, gmap.normalized_score)
        return self._score_cache[key]
|
sma/ir/__init__.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
from .canon import Canonicalizer, PredicateLattice, default_canonicalizer
|
|
2
|
+
from .schema import Case, Entity, Signature, Statement, SymbolKind, entity, make_case, stmt
|
|
3
|
+
from .sexpr import canonical_case_text, dumps_statement, loads_case, loads_statement
|
|
4
|
+
|
|
5
|
+
# Public API of sma.ir: the names re-exported from the .canon, .schema and
# .sexpr submodules imported above.
__all__ = [
    "Canonicalizer",
    "Case",
    "Entity",
    "PredicateLattice",
    "Signature",
    "Statement",
    "SymbolKind",
    "canonical_case_text",
    "default_canonicalizer",
    "dumps_statement",
    "entity",
    "loads_case",
    "loads_statement",
    "make_case",
    "stmt",
]
|
|
22
|
+
|
sma/ir/canon.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
1
|
+
"""Symbolic canonicalization and minimal ascension support."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from collections import deque
|
|
6
|
+
from dataclasses import dataclass, field
|
|
7
|
+
|
|
8
|
+
from .schema import safe_symbol
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@dataclass
class PredicateLattice:
    """Child -> parents links between predicate symbols (multiple parents allowed).

    All symbols are normalized with ``safe_symbol`` on insert and on lookup.
    """

    # child symbol -> set of direct parent symbols
    parents: dict[str, set[str]] = field(default_factory=dict)

    def add(self, child: str, parent: str) -> None:
        """Record ``parent`` as a direct parent of ``child``."""
        key = safe_symbol(child)
        if key not in self.parents:
            self.parents[key] = set()
        self.parents[key].add(safe_symbol(parent))

    def ancestors(self, symbol: str, max_depth: int = 2) -> dict[str, int]:
        """Breadth-first walk up parent links from ``symbol``.

        Returns:
            ``{ancestor: steps}``, including ``symbol`` itself at 0. Only
            ancestors within ``max_depth`` steps are included.
        """
        start = safe_symbol(symbol)
        depth_of = {start: 0}
        frontier = deque([(start, 0)])
        while frontier:
            node, level = frontier.popleft()
            if level >= max_depth:
                continue
            next_level = level + 1
            for parent in self.parents.get(node, ()):
                known = depth_of.get(parent)
                if known is None or next_level < known:
                    depth_of[parent] = next_level
                    frontier.append((parent, next_level))
        return depth_of

    def minimal_ascension(self, left: str, right: str, delta: int) -> tuple[str, int] | None:
        """Cheapest common ancestor of ``left`` and ``right``.

        Each side is searched up to ``delta`` steps. The shared symbol with the
        smallest combined distance wins, with ties broken by symbol name.

        Returns:
            ``(ancestor, combined_distance)``, or None when nothing is shared
            or the combined distance exceeds ``delta``.
        """
        up_left = self.ancestors(left, delta)
        up_right = self.ancestors(right, delta)
        shared = up_left.keys() & up_right.keys()
        if not shared:
            return None
        best = min(shared, key=lambda s: (up_left[s] + up_right[s], s))
        total = up_left[best] + up_right[best]
        return (best, total) if total <= delta else None
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
@dataclass
class Canonicalizer:
    """Maps surface symbols to canonical ones and tests lattice compatibility."""

    # surface symbol -> replacement symbol (followed transitively)
    aliases: dict[str, str] = field(default_factory=dict)
    lattice: PredicateLattice = field(default_factory=PredicateLattice)

    def canonical(self, symbol: str) -> str:
        """Follow the alias chain from ``symbol`` (sanitized) to its end.

        A cycle in the aliases stops the walk at the first repeated symbol.
        """
        # NOTE: this used to strip "far_"/"near_" prefixes - a circularity
        # with the old SSB generator (the benchmark's vocabulary bijection
        # was known to the matcher). Removed: cross-vocabulary matching now
        # goes exclusively through the lattice with ascension penalties.
        current = safe_symbol(symbol)
        visited: set[str] = set()
        while current not in visited and current in self.aliases:
            visited.add(current)
            current = safe_symbol(self.aliases[current])
        return current

    def compatible(
        self, left: str, right: str, delta: int = 0, rho: float = 1.0
    ) -> tuple[bool, float, str | None, int]:
        """Decide whether two symbols may align.

        Returns:
            ``(ok, weight, shared_symbol, distance)``. Identical canonical forms
            give ``(True, 1.0, canonical, 0)``. Otherwise, when ``delta > 0`` and
            the lattice has a common ancestor within ``delta`` combined steps, the
            result is ``(True, rho**distance, ancestor, distance)``. Every other
            case yields ``(False, 0.0, None, 0)``.
        """
        left_c, right_c = self.canonical(left), self.canonical(right)
        if left_c == right_c:
            return True, 1.0, left_c, 0
        asc = self.lattice.minimal_ascension(left_c, right_c, delta) if delta > 0 else None
        if asc is None:
            return False, 0.0, None, 0
        ancestor, dist = asc
        return True, rho**dist, ancestor, dist
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def default_canonicalizer() -> Canonicalizer:
    """Build the built-in canonicalizer used when callers supply none.

    Aliases collapse surface synonyms onto one symbol; for example, both
    ``connTimeout`` and ``connectionTimeout`` become ``timeout``. The lattice
    links a few canonical symbols to broader parents.
    """
    aliases = dict(
        connTimeout="timeout",
        connectionTimeout="timeout",
        retryStorm="retry",
        blockReceiveError="ioError",
        pressure="intensity",
        temperature="intensity",
        waterFlow="flow",
        heatFlow="flow",
        sun="centralBody",
        nucleus="centralBody",
        planet="orbitingBody",
        electron="orbitingBody",
        attractsGravity="attracts",
        attractsElectrostatic="attracts",
    )
    children_by_parent = {
        "failureEvent": ("timeout", "ioError", "exception"),
        "recoveryAction": ("retry", "restart"),
        "resourcePressure": ("saturation",),
    }
    canon = Canonicalizer(aliases=aliases)
    for parent, children in children_by_parent.items():
        for child in children:
            canon.lattice.add(child, parent)
    return canon
|
sma/ir/schema.py
ADDED
|
@@ -0,0 +1,165 @@
|
|
|
1
|
+
"""Typed predicate IR for SMA cases.
|
|
2
|
+
|
|
3
|
+
The implementation deliberately keeps the runtime representation small and
|
|
4
|
+
serializable. It is immutable enough for hashing, but remains plain Python so
|
|
5
|
+
fixtures and reports are easy to inspect.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
from dataclasses import dataclass, field
|
|
11
|
+
from enum import Enum
|
|
12
|
+
from typing import Any, Iterable, Mapping, Sequence
|
|
13
|
+
|
|
14
|
+
import blake3
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class SymbolKind(str, Enum):
    """Category tag for a functor symbol (the type of ``Signature.kind``).

    Mixes in ``str``, so each member compares equal to its string value.
    """

    ENTITY = "entity"
    FUNCTION = "function"
    ATTRIBUTE = "attribute"
    RELATION = "relation"
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@dataclass(frozen=True, slots=True)
class Signature:
    """Declared shape of a functor.

    Fields:
        functor: symbol the signature describes.
        arity: exact number of arguments ``validate_arity`` accepts.
        kind: symbol category; defaults to ``SymbolKind.RELATION``.
        arg_types: per-position type names (stored only; not checked here).
        commutative: flag stored for callers; presumably marks argument order
            as insignificant (TODO confirm with the matcher).
        higher_order: flag stored for callers; presumably allows statement
            arguments (TODO confirm).
    """

    functor: str
    arity: int
    kind: SymbolKind = SymbolKind.RELATION
    arg_types: tuple[str, ...] = ()
    commutative: bool = False
    higher_order: bool = False

    def validate_arity(self, args: Sequence[Node]) -> None:
        """Raise ``ValueError`` unless ``args`` has exactly ``self.arity`` items."""
        if len(args) == self.arity:
            return
        raise ValueError(f"{self.functor} expects {self.arity} args, got {len(args)}")
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
@dataclass(frozen=True, slots=True)
|
|
39
|
+
class Entity:
|
|
40
|
+
name: str
|
|
41
|
+
type: str = "entity"
|
|
42
|
+
|
|
43
|
+
def nodes(self) -> tuple[Entity, ...]:
|
|
44
|
+
return (self,)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
@dataclass(frozen=True, slots=True)
|
|
48
|
+
class Statement:
|
|
49
|
+
functor: str
|
|
50
|
+
args: tuple["Node", ...] = ()
|
|
51
|
+
ascension: float = 1.0
|
|
52
|
+
|
|
53
|
+
def nodes(self) -> tuple["Node", ...]:
|
|
54
|
+
out: list[Node] = [self]
|
|
55
|
+
for arg in self.args:
|
|
56
|
+
out.extend(arg.nodes())
|
|
57
|
+
return tuple(out)
|
|
58
|
+
|
|
59
|
+
def expressions(self) -> tuple["Statement", ...]:
|
|
60
|
+
out: list[Statement] = [self]
|
|
61
|
+
for arg in self.args:
|
|
62
|
+
if isinstance(arg, Statement):
|
|
63
|
+
out.extend(arg.expressions())
|
|
64
|
+
return tuple(out)
|
|
65
|
+
|
|
66
|
+
def entities(self) -> tuple[Entity, ...]:
|
|
67
|
+
out: list[Entity] = []
|
|
68
|
+
for arg in self.args:
|
|
69
|
+
if isinstance(arg, Entity):
|
|
70
|
+
out.append(arg)
|
|
71
|
+
else:
|
|
72
|
+
out.extend(arg.entities())
|
|
73
|
+
return tuple(out)
|
|
74
|
+
|
|
75
|
+
@property
|
|
76
|
+
def arity(self) -> int:
|
|
77
|
+
return len(self.args)
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
# Any IR node: an Entity leaf or a Statement (a functor applied to child nodes).
Node = Entity | Statement
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
@dataclass(frozen=True, slots=True)
|
|
84
|
+
class Case:
|
|
85
|
+
case_id: str
|
|
86
|
+
statements: tuple[Statement, ...]
|
|
87
|
+
metadata: Mapping[str, Any] = field(default_factory=dict)
|
|
88
|
+
|
|
89
|
+
def expressions(self) -> tuple[Statement, ...]:
|
|
90
|
+
seen: set[str] = set()
|
|
91
|
+
ordered: list[Statement] = []
|
|
92
|
+
from .sexpr import dumps_statement
|
|
93
|
+
|
|
94
|
+
for stmt in self.statements:
|
|
95
|
+
for expr in stmt.expressions():
|
|
96
|
+
key = dumps_statement(expr)
|
|
97
|
+
if key not in seen:
|
|
98
|
+
seen.add(key)
|
|
99
|
+
ordered.append(expr)
|
|
100
|
+
return tuple(ordered)
|
|
101
|
+
|
|
102
|
+
def entities(self) -> tuple[Entity, ...]:
|
|
103
|
+
seen: set[tuple[str, str]] = set()
|
|
104
|
+
ordered: list[Entity] = []
|
|
105
|
+
for stmt in self.statements:
|
|
106
|
+
for entity in stmt.entities():
|
|
107
|
+
key = (entity.name, entity.type)
|
|
108
|
+
if key not in seen:
|
|
109
|
+
seen.add(key)
|
|
110
|
+
ordered.append(entity)
|
|
111
|
+
return tuple(ordered)
|
|
112
|
+
|
|
113
|
+
def functor_counts(self) -> dict[str, int]:
|
|
114
|
+
counts: dict[str, int] = {}
|
|
115
|
+
for expr in self.expressions():
|
|
116
|
+
counts[expr.functor] = counts.get(expr.functor, 0) + 1
|
|
117
|
+
return counts
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
def make_case(
    statements: Iterable[Statement], metadata: Mapping[str, Any] | None = None, case_id: str = ""
) -> Case:
    """Build a ``Case`` whose statements are sorted by their canonical text.

    When ``case_id`` is empty, the id is the ``content_id`` of the canonical
    text of the sorted statements. ``metadata`` is copied into a fresh dict,
    or an empty dict when None.
    """
    from .sexpr import canonical_case_text

    def by_text(statement: Statement) -> str:
        return canonical_case_text([statement])

    ordered = tuple(sorted(statements, key=by_text))
    text = canonical_case_text(ordered)
    resolved_id = case_id if case_id else content_id(text)
    return Case(resolved_id, ordered, dict(metadata or {}))
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
def content_id(text: str) -> str:
    """Hex BLAKE3 digest of ``text`` encoded as UTF-8 (a content address)."""
    # blake3 is a hard dependency: case ids are content addresses and must be
    # identical across machines, so no fallback hash is allowed here.
    payload = text.encode("utf-8")
    digest = blake3.blake3(payload)
    return digest.hexdigest()
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def walk_statement(stmt: Statement) -> Iterable[Node]:
    """Lazily yield ``stmt`` and every node beneath it in pre-order.

    Nested statements are expanded; all other arguments are yielded as leaves.
    An explicit stack replaces recursion, with children pushed in reverse so
    they are visited left to right.
    """
    yield stmt
    pending: list[Node] = list(reversed(stmt.args))
    while pending:
        node = pending.pop()
        yield node
        if isinstance(node, Statement):
            pending.extend(reversed(node.args))
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
def entity(name: str, type: str = "entity") -> Entity:
    """Construct an ``Entity`` whose name and type are ``safe_symbol``-normalized."""
    clean_name = safe_symbol(name)
    clean_type = safe_symbol(type)
    return Entity(clean_name, clean_type)
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
def stmt(functor: str, *args: Node | str, ascension: float = 1.0) -> Statement:
    """Convenience constructor for a ``Statement``.

    Arguments that are already ``Entity`` or ``Statement`` nodes are kept
    as-is. Anything else is passed through ``str`` and wrapped with
    ``entity`` (default type "entity"). The functor is normalized with
    ``safe_symbol``.
    """

    def to_node(arg: Node | str) -> Node:
        return arg if isinstance(arg, (Entity, Statement)) else entity(str(arg))

    children = tuple(map(to_node, args))
    return Statement(safe_symbol(functor), children, ascension=ascension)
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
def safe_symbol(value: str) -> str:
    """Normalize ``value`` into a symbol token.

    ``str(value)`` is stripped of surrounding whitespace, and an empty result
    becomes ``"_"``. Each remaining character is kept if it is alphanumeric
    (per ``str.isalnum``, so Unicode letters and digits survive) or one of
    ``_-:.@/``; any other character is replaced by ``"_"``.
    """
    text = str(value).strip()
    if not text:
        return "_"
    return "".join(ch if ch.isalnum() or ch in "_-:.@/" else "_" for ch in text)
|
|
165
|
+
|