structuremappingmemory 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sma/__init__.py +5 -0
- sma/__main__.py +5 -0
- sma/agent/__init__.py +5 -0
- sma/agent/adapter_draft.py +217 -0
- sma/agent/api.py +67 -0
- sma/agent/comparison.py +591 -0
- sma/agent/llm.py +280 -0
- sma/agent/policies.py +21 -0
- sma/agent/service.py +95 -0
- sma/cli.py +65 -0
- sma/encoders/__init__.py +38 -0
- sma/encoders/agentobs.py +27 -0
- sma/encoders/base.py +23 -0
- sma/encoders/code_treesitter.py +64 -0
- sma/encoders/coverage.py +80 -0
- sma/encoders/draft_adapter.py +183 -0
- sma/encoders/healthcare.py +207 -0
- sma/encoders/logs_drain.py +142 -0
- sma/encoders/prose_tier1.py +57 -0
- sma/encoders/structured.py +57 -0
- sma/encoders/traces.py +45 -0
- sma/eval/__init__.py +2 -0
- sma/eval/agentic/__init__.py +35 -0
- sma/eval/agentic/arms/__init__.py +0 -0
- sma/eval/agentic/arms/cyber.py +48 -0
- sma/eval/agentic/arms/discovery.py +35 -0
- sma/eval/agentic/arms/finance.py +38 -0
- sma/eval/agentic/arms/legal.py +74 -0
- sma/eval/agentic/arms/medicine.py +45 -0
- sma/eval/agentic/harness.py +275 -0
- sma/eval/agentic/memories.py +308 -0
- sma/eval/agentic/metrics.py +82 -0
- sma/eval/agentic_qa/__init__.py +27 -0
- sma/eval/agentic_qa/agent.py +383 -0
- sma/eval/agentic_qa/metrics.py +239 -0
- sma/eval/agentic_qa/pools.py +197 -0
- sma/eval/arn.py +65 -0
- sma/eval/baselines/__init__.py +6 -0
- sma/eval/baselines/bge_dense.py +54 -0
- sma/eval/baselines/bm25.py +18 -0
- sma/eval/baselines/dense.py +42 -0
- sma/eval/baselines/hipporag.py +235 -0
- sma/eval/baselines/hybrid_rrf.py +30 -0
- sma/eval/baselines/longcontext_llm.py +124 -0
- sma/eval/baselines/rerank.py +41 -0
- sma/eval/baselines/splade.py +77 -0
- sma/eval/baselines/wl_kernel.py +163 -0
- sma/eval/bugsinpy.py +358 -0
- sma/eval/bugsinpy_families.py +164 -0
- sma/eval/crossdomain.py +89 -0
- sma/eval/diabetes.py +61 -0
- sma/eval/drift_env.py +26 -0
- sma/eval/drift_metrics.py +24 -0
- sma/eval/family_labels.py +167 -0
- sma/eval/fraud_elliptic/__init__.py +29 -0
- sma/eval/fraud_elliptic/encoder.py +279 -0
- sma/eval/fraud_elliptic/eval.py +269 -0
- sma/eval/fraud_elliptic/test_encoder.py +123 -0
- sma/eval/ieee_cis.py +66 -0
- sma/eval/loghub.py +16 -0
- sma/eval/loghub_eval.py +480 -0
- sma/eval/longmemeval.py +51 -0
- sma/eval/memory_backends/__init__.py +2 -0
- sma/eval/memory_backends/base.py +22 -0
- sma/eval/memory_backends/context_only.py +14 -0
- sma/eval/memory_backends/rag_notes.py +17 -0
- sma/eval/memory_backends/shared_llm.py +30 -0
- sma/eval/memory_backends/sma_memory.py +54 -0
- sma/eval/memory_backends/zep_graphiti.py +33 -0
- sma/eval/metrics.py +32 -0
- sma/eval/ontology_bench.py +219 -0
- sma/eval/report.py +573 -0
- sma/eval/ssb_eval.py +216 -0
- sma/eval/ssb_generator.py +116 -0
- sma/eval/stats.py +108 -0
- sma/eval/transfer_eval.py +844 -0
- sma/index/__init__.py +15 -0
- sma/index/ann.py +21 -0
- sma/index/content_vectors.py +60 -0
- sma/index/inverted.py +63 -0
- sma/index/macfac.py +174 -0
- sma/ir/__init__.py +22 -0
- sma/ir/canon.py +106 -0
- sma/ir/schema.py +165 -0
- sma/ir/sexpr.py +86 -0
- sma/ir/signatures.py +76 -0
- sma/match/__init__.py +20 -0
- sma/match/conflicts.py +46 -0
- sma/match/engine.py +60 -0
- sma/match/explain.py +59 -0
- sma/match/infer.py +54 -0
- sma/match/kernels.py +54 -0
- sma/match/mdl.py +30 -0
- sma/match/merge_cpsat.py +77 -0
- sma/match/merge_greedy.py +15 -0
- sma/match/mh.py +177 -0
- sma/match/ses.py +84 -0
- sma/match/types.py +115 -0
- sma/match/verifier.py +27 -0
- sma/ontology/__init__.py +45 -0
- sma/ontology/attack.py +134 -0
- sma/ontology/cpc.py +69 -0
- sma/ontology/graph.py +58 -0
- sma/ontology/loader.py +262 -0
- sma/ontology/mitre_xml.py +67 -0
- sma/ontology/mount.py +101 -0
- sma/ontology/rdf_loader.py +75 -0
- sma/ontology/registry.py +115 -0
- sma/ontology/router.py +69 -0
- sma/ontology/usgaap.py +73 -0
- sma/sage/__init__.py +6 -0
- sma/sage/assimilate.py +12 -0
- sma/sage/pools.py +105 -0
- sma/sage/probabilities.py +10 -0
- sma/store/__init__.py +6 -0
- sma/store/lmdb_store.py +78 -0
- sma/store/registry.py +26 -0
- sma/store/wal.py +26 -0
- sma/ui/app.py +642 -0
- structuremappingmemory-1.0.0.dist-info/METADATA +190 -0
- structuremappingmemory-1.0.0.dist-info/RECORD +125 -0
- structuremappingmemory-1.0.0.dist-info/WHEEL +5 -0
- structuremappingmemory-1.0.0.dist-info/entry_points.txt +2 -0
- structuremappingmemory-1.0.0.dist-info/licenses/LICENSE +204 -0
- structuremappingmemory-1.0.0.dist-info/top_level.txt +1 -0
sma/ir/sexpr.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
1
|
+
"""Canonical S-expression codec for SMA cases."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from collections.abc import Iterable
|
|
6
|
+
|
|
7
|
+
from .schema import Entity, Node, Statement, entity
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def dumps_node(node: Node) -> str:
    """Serialize one IR node to S-expression text.

    Entities are written as their bare ``name``; any other node is treated as
    a statement and rendered by ``dumps_statement``. Names are emitted
    verbatim (no quoting or escaping), so a name containing whitespace or
    parentheses will not round-trip through ``loads_statement``/``loads_case``,
    whose tokenizer splits on exactly those characters.
    """
    if not isinstance(node, Entity):
        return dumps_statement(node)
    return node.name
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def dumps_statement(statement: Statement) -> str:
    """Render *statement* as ``(functor arg1 arg2 ...)``.

    A statement without arguments renders as ``(functor)``. Arguments are
    rendered recursively with ``dumps_node`` and separated by single spaces.
    """
    head = f"({statement.functor}"
    if not statement.args:
        return head + ")"
    tail = " ".join(map(dumps_node, statement.args))
    return f"{head} {tail})"
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def canonical_case_text(statements: Iterable[Statement]) -> str:
    """Return an order-independent text form of a case.

    Each statement is serialized with ``dumps_statement``; the lines are
    sorted lexicographically and joined with newlines, so two cases holding
    the same statements in different orders produce identical text.
    """
    lines = [dumps_statement(stmt) for stmt in statements]
    lines.sort()
    return "\n".join(lines)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def loads_statement(text: str) -> Statement:
    """Parse *text* as exactly one top-level statement.

    Raises:
        ValueError: if the input is malformed, if tokens remain after the
            first complete S-expression, or if that expression is a bare atom
            rather than a parenthesized statement.
    """
    tokens = _tokenize(text)
    parsed, consumed = _parse(tokens, 0)
    leftover = tokens[consumed:]
    if leftover:
        raise ValueError(f"trailing tokens after S-expression: {leftover}")
    if isinstance(parsed, Statement):
        return parsed
    raise ValueError("top-level S-expression must be a statement")
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def loads_case(text: str) -> tuple[Statement, ...]:
    """Parse a sequence of top-level statements (e.g. ``canonical_case_text`` output).

    Whitespace, including newlines, separates entries. Empty input yields ``()``.

    Raises:
        ValueError: on malformed input, or if any top-level entry is a bare
            atom instead of a statement.
    """
    tokens = _tokenize(text)

    def _entries():
        cursor = 0
        while cursor < len(tokens):
            node, cursor = _parse(tokens, cursor)
            if not isinstance(node, Statement):
                raise ValueError("case entries must be statements")
            yield node

    return tuple(_entries())
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def _tokenize(text: str) -> list[str]:
    """Split *text* into S-expression tokens.

    ``(`` and ``)`` are always single-character tokens. Any maximal run of
    other non-whitespace characters (per ``str.isspace``) is one atom token.
    Whitespace only separates tokens and is discarded.
    """
    out: list[str] = []
    start = None  # index where the atom currently being scanned began
    for idx, ch in enumerate(text):
        is_paren = ch in "()"
        if is_paren or ch.isspace():
            if start is not None:
                out.append(text[start:idx])
                start = None
            if is_paren:
                out.append(ch)
        elif start is None:
            start = idx
    if start is not None:
        out.append(text[start:])
    return out
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def _parse(tokens: list[str], pos: int) -> tuple[Node, int]:
    """Recursive-descent parse of one node starting at ``tokens[pos]``.

    An atom token becomes an entity via ``entity()``. The sequence
    ``( functor arg* )`` becomes a ``Statement`` whose functor is the atom
    immediately after ``(`` and whose args are parsed recursively.

    Returns:
        ``(node, next_pos)``, where ``next_pos`` is the index just past the
        parsed node.

    Raises:
        ValueError: on truncated input, an unexpected ``)``, a missing or
            non-atom functor, or an unclosed statement.
    """
    if pos >= len(tokens):
        raise ValueError("unexpected end of input")
    tok = tokens[pos]
    if tok == ")":
        # A stray closer is never a valid atom; treating it as an entity
        # would hide the real syntax error.
        raise ValueError(f"unexpected ')' at token {pos}")
    if tok != "(":
        return entity(tok), pos + 1
    if pos + 1 >= len(tokens):
        raise ValueError("missing functor")
    functor = tokens[pos + 1]
    if functor in ("(", ")"):
        # The functor must be a plain atom. Without this check "(()" would
        # parse as a statement whose functor is "(" and "())" as one whose
        # functor is ")".
        raise ValueError(f"missing functor (found {functor!r})")
    args: list[Node] = []
    pos += 2
    while pos < len(tokens) and tokens[pos] != ")":
        arg, pos = _parse(tokens, pos)
        args.append(arg)
    if pos >= len(tokens) or tokens[pos] != ")":
        raise ValueError("unclosed statement")
    return Statement(functor, tuple(args)), pos + 1
|
|
86
|
+
|
sma/ir/signatures.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
1
|
+
"""Signature registry and type validation."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
|
|
7
|
+
from .schema import Entity, Signature, Statement, SymbolKind
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
# Built-in signatures that SignatureRegistry.with_defaults() registers.
# The second positional argument is presumably the arity (cf. the keyword form
# Signature(functor=..., arity=...) used in SignatureRegistry.get) -- TODO
# confirm against sma.ir.schema.Signature.
# NOTE(review): higher_order=True appears to mark relations whose arguments are
# other statements; its exact effect is defined outside this file -- confirm.
DEFAULT_SIGNATURES = (
    Signature("before", 2, higher_order=True),
    Signature("after", 2, higher_order=True),
    Signature("cause", 2, higher_order=True),
    Signature("implies", 2, higher_order=True),
    Signature("enables", 2, higher_order=True),
    Signature("prevents", 2, higher_order=True),
    Signature("during", 2, higher_order=True),
    Signature("count", 2),
    Signature("burst", 2),
    Signature("frame", 4),
    Signature("calledFrom", 2, higher_order=True),
    Signature("defines", 2),
    Signature("calls", 2),
    Signature("imports", 2),
    Signature("throws", 2),
    Signature("catches", 2),
    Signature("adds", 1),
    Signature("removes", 1),
    Signature("modifies", 1),
)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
@dataclass
class SignatureRegistry:
    """Functor -> Signature lookup used to validate statements.

    Unregistered functors are accepted: ``get`` synthesizes an ad-hoc
    signature for them. A registered signature with a negative arity is
    treated as variadic (any arity). This matches the ``-1`` "unknown arity"
    placeholder that ``get`` uses for ad-hoc signatures and the
    ``sig.arity >= 0`` guard in ``validate_statement``.
    """

    # Registered signatures keyed by functor; re-registering a functor replaces it.
    signatures: dict[str, Signature] = field(default_factory=dict)

    @classmethod
    def with_defaults(cls) -> "SignatureRegistry":
        """Return a new registry pre-populated with ``DEFAULT_SIGNATURES``."""
        registry = cls()
        for sig in DEFAULT_SIGNATURES:
            registry.register(sig)
        return registry

    def register(self, signature: Signature) -> None:
        """Register *signature* under its functor, replacing any previous entry."""
        self.signatures[signature.functor] = signature

    def get(self, functor: str, arity: int | None = None) -> Signature:
        """Look up the signature for *functor*.

        Args:
            functor: Functor name.
            arity: Observed arity to check against the registered signature,
                or ``None`` to skip the check.

        Returns:
            The registered signature. For an unknown functor, a fresh
            ``Signature`` with the given arity (or ``-1`` when *arity* is None).

        Raises:
            ValueError: if *functor* is registered with a fixed (non-negative)
                arity that differs from *arity*.
        """
        if functor in self.signatures:
            sig = self.signatures[functor]
            # Negative registered arity means variadic: accept any arity.
            if arity is not None and sig.arity >= 0 and sig.arity != arity:
                raise ValueError(f"{functor} registered with arity {sig.arity}, got {arity}")
            return sig
        return Signature(functor=functor, arity=arity if arity is not None else -1)

    def validate_statement(self, statement: Statement) -> None:
        """Recursively validate *statement* and every nested statement.

        Raises:
            ValueError: from ``get`` when the functor is registered with a
                different fixed arity.
            TypeError: when an argument is neither a Statement nor an Entity.
            Also anything ``Signature.validate_arity`` raises for signatures
            with a fixed arity.
        """
        sig = self.get(statement.functor, statement.arity)
        if sig.arity >= 0:
            sig.validate_arity(statement.args)
        for arg in statement.args:
            if isinstance(arg, Statement):
                self.validate_statement(arg)
            elif not isinstance(arg, Entity):
                raise TypeError(f"unsupported argument node: {arg!r}")

    def validate_case(self, statements: tuple[Statement, ...]) -> None:
        """Validate every statement of a case; see ``validate_statement``."""
        for statement in statements:
            self.validate_statement(statement)
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def infer_kind(functor: str, arity: int) -> SymbolKind:
    """Classify a symbol from its arity alone.

    Arity 0 or 1 maps to ``SymbolKind.ATTRIBUTE``. Any other arity maps to
    ``SymbolKind.RELATION``, including 2+ and the ``-1`` placeholder that
    ``SignatureRegistry.get`` uses. *functor* is accepted for interface
    stability but is not consulted.
    """
    return SymbolKind.ATTRIBUTE if arity in (0, 1) else SymbolKind.RELATION
|
|
76
|
+
|
sma/match/__init__.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
from .engine import match_cases
|
|
2
|
+
from .explain import correspondence_table, explain_text
|
|
3
|
+
from .infer import candidate_inferences
|
|
4
|
+
from .types import CandidateInference, GMap, Kernel, MatchConfig, MatchHypothesis
|
|
5
|
+
from .verifier import VerificationResult, verify_inference
|
|
6
|
+
|
|
7
|
+
# Public API of sma.match: names re-exported from the submodules imported above.
__all__ = [
    "CandidateInference",
    "GMap",
    "Kernel",
    "MatchConfig",
    "MatchHypothesis",
    "VerificationResult",
    "candidate_inferences",
    "correspondence_table",
    "explain_text",
    "match_cases",
    "verify_inference",
]
|
|
20
|
+
|
sma/match/conflicts.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
"""Structural consistency checks and kernel conflict graph."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from itertools import combinations
|
|
6
|
+
|
|
7
|
+
from .types import Kernel, MatchHypothesis
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def structurally_consistent(hypotheses: tuple[MatchHypothesis, ...]) -> bool:
    """Return True iff *hypotheses* form a one-to-one correspondence.

    No base key may map to two different target keys, and no target key may
    be mapped from two different base keys. Repeating an identical pair is
    allowed.
    """
    forward: dict[str, str] = {}
    backward: dict[str, str] = {}
    for mh in hypotheses:
        b_key, t_key = mh.base_key, mh.target_key
        if forward.setdefault(b_key, t_key) != t_key:
            return False
        if backward.setdefault(t_key, b_key) != b_key:
            return False
    return True
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def kernels_conflict(left: Kernel, right: Kernel) -> bool:
    """Return True iff the two kernels cannot both be selected.

    A conflict exists when some base key is bound to different target keys
    in the two kernels, or some target key is bound from different base keys.
    Only the smaller ``bindings`` table is iterated; each entry costs two dict
    probes into the larger kernel's ``bindings``/``reverse_bindings``.
    """
    if len(left.bindings) <= len(right.bindings):
        small, large = left, right
    else:
        small, large = right, left
    fwd = large.bindings
    rev = large.reverse_bindings
    for b_key, t_key in small.bindings.items():
        mapped_t = fwd.get(b_key)
        if mapped_t is not None and mapped_t != t_key:
            return True
        mapped_b = rev.get(t_key)
        if mapped_b is not None and mapped_b != b_key:
            return True
    return False
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def conflict_edges(kernels: tuple[Kernel, ...]) -> set[tuple[int, int]]:
    """Return the conflict graph as index pairs ``(i, j)`` with ``i < j``.

    Each pair of kernels is tested once with ``kernels_conflict``.
    """
    return {
        (i, j)
        for i, j in combinations(range(len(kernels)), 2)
        if kernels_conflict(kernels[i], kernels[j])
    }
|
|
46
|
+
|
sma/match/engine.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
"""Top-level mapping engine."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from sma.ir.canon import Canonicalizer, default_canonicalizer
|
|
6
|
+
from sma.ir.schema import Case, Statement
|
|
7
|
+
|
|
8
|
+
from .kernels import build_kernels
|
|
9
|
+
from .mdl import mdl_gain
|
|
10
|
+
from .merge_cpsat import exact_or_greedy_merge
|
|
11
|
+
from .ses import normalize_score, structural_evaluation
|
|
12
|
+
from .types import GMap, MatchConfig, MatchHypothesis
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def match_cases(
    base: Case,
    target: Case,
    config: MatchConfig | None = None,
    canon: Canonicalizer | None = None,
) -> GMap:
    """Compute the structural mapping (GMap) from *base* onto *target*.

    Steps:
      1. Build kernels (``build_kernels``).
      2. Choose a conflict-free subset (``exact_or_greedy_merge``).
      3. Union the chosen kernels' hypotheses.
      4. Score the result.

    When ``config.scorer == "surprisal"`` and ``config.functor_costs`` is
    non-empty, hypotheses are weighted by a cost function:
      * statement hypotheses cost ``functor_costs`` of the canonical base
        functor, defaulting to 1.0;
      * other hypotheses cost 1.0.

    Scoring:
      * ``config.scorer == "mdl"``: the score is ``mdl_gain``, normalized by
        the target's expression count.
      * otherwise: the score is ``structural_evaluation``, normalized with
        ``normalize_score``.

    *config* defaults to ``MatchConfig()`` and *canon* to
    ``default_canonicalizer()``.
    """
    config = config or MatchConfig()
    canon = canon or default_canonicalizer()

    cost_fn = None
    use_surprisal = config.scorer == "surprisal" and config.functor_costs
    if use_surprisal:
        functor_costs = config.functor_costs

        def cost_fn(mh):
            if not isinstance(mh.base, Statement):
                return 1.0
            return functor_costs.get(canon.canonical(mh.base.functor), 1.0)

    kernels = build_kernels(base, target, config=config, canon=canon, cost_fn=cost_fn)
    selected, gap = exact_or_greedy_merge(
        kernels,
        exact_limit=config.exact_kernel_limit,
        time_budget_ms=config.cpsat_time_ms,
    )

    # Union of the selected kernels' hypotheses, deduplicated by key. The
    # first occurrence fixes the position; a later duplicate replaces the value.
    by_key: dict[tuple[str, str], MatchHypothesis] = {
        mh.key: mh for kernel in selected for mh in kernel.hypotheses
    }
    hypotheses = tuple(by_key.values())

    if config.scorer != "mdl":
        score = structural_evaluation(hypotheses, gamma=config.gamma, cost_fn=cost_fn)
        normalized = normalize_score(
            score,
            base,
            target,
            gamma=config.gamma,
            cost_fn=cost_fn,
            normalization=config.normalization,
        )
    else:
        score = mdl_gain(hypotheses, target)
        normalized = score / max(len(target.expressions()), 1)

    return GMap(
        base=base,
        target=target,
        hypotheses=hypotheses,
        kernels=selected,
        score=score,
        normalized_score=normalized,
        scorer=config.scorer,
        optimality_gap=gap,
    )
|
|
60
|
+
|
sma/match/explain.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
"""Human-readable mapping explanations."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from .types import GMap
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def correspondence_table(gmap: GMap) -> list[dict[str, str]]:
    """Return one display row per match hypothesis of *gmap*.

    Each row has the string fields ``base``, ``target``, ``ascension`` and
    ``ancestor``. ``ascension`` is formatted to three decimals, and a falsy
    ``ancestor`` becomes ``""``.
    """
    return [
        {
            "base": mh.base_key,
            "target": mh.target_key,
            "ascension": format(mh.ascension, ".3f"),
            "ancestor": mh.ancestor or "",
        }
        for mh in gmap.hypotheses
    ]
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def explain_text(gmap: GMap) -> str:
    """Render *gmap* as multi-line plain text.

    The first line is a header with the score, normalized score and scorer
    name. It is followed by a ``correspondences:`` line and one bullet per
    row of ``correspondence_table``.
    """
    header = f"gmap score={gmap.score:.3f} SES_n={gmap.normalized_score:.3f} scorer={gmap.scorer}"
    bullets = [
        f"- {row['base']} -> {row['target']} asc={row['ascension']} {row['ancestor']}"
        for row in correspondence_table(gmap)
    ]
    return "\n".join([header, "correspondences:", *bullets])
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def alignment_summary(gmap: GMap) -> str:
    """One-line 'why this precedent matched' for verbalizers and evidence panels.

    Counts statement-level correspondences by the canonical base functor and
    skips ``logSession``. The six most frequent functors are reported, so an
    LLM (or human) sees the shared STRUCTURE (e.g. 'kernelEvent x4,
    failureEvent x4, before x3') instead of re-deriving similarity from raw
    text semantics. Returns a fixed message when nothing qualifies.
    """
    from collections import Counter

    from sma.ir.canon import default_canonicalizer
    from sma.ir.schema import Statement

    canonicalize = default_canonicalizer().canonical
    functors = (
        canonicalize(mh.base.functor)
        for mh in gmap.hypotheses
        if isinstance(mh.base, Statement)
    )
    matched: Counter[str] = Counter(f for f in functors if f != "logSession")
    if not matched:
        return "no statement-level correspondences"

    shown: list[str] = []
    for functor, count in matched.most_common(6):
        shown.append(functor if count <= 1 else f"{functor} x{count}")
    parts = ", ".join(shown)
    return f"shared structure: {parts}; ses_n={gmap.normalized_score:.2f}"
|
sma/match/infer.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
"""Candidate inference projection."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from sma.ir.schema import Entity, Node, Statement
|
|
6
|
+
from sma.ir.sexpr import dumps_statement
|
|
7
|
+
|
|
8
|
+
from .types import CandidateInference, GMap, node_key
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def _project_node(
    node: Node,
    base_to_target: dict[str, Node],
    skolems: dict[str, Entity],
) -> Node:
    """Rewrite a base node into target vocabulary for one candidate inference.

    Rewriting rules:
      * A mapped node (looked up by ``node_key``) becomes its target node.
      * An unmapped entity becomes a Skolem placeholder. *skolems* memoizes
        placeholders by base node key, so every occurrence of the same base
        entity within one inference shares one placeholder.
      * An unmapped statement is rebuilt with projected arguments, keeping its
        functor and ascension.
    """
    key = node_key(node)
    mapped = base_to_target.get(key)
    if mapped is not None:
        return mapped
    if isinstance(node, Entity):
        skolem = skolems.get(key)
        if skolem is None:
            skolem = Entity(f"AnalogySkolemFn_{len(skolems) + 1}", node.type)
            skolems[key] = skolem
        return skolem
    return Statement(
        node.functor,
        tuple(_project_node(arg, base_to_target, skolems) for arg in node.args),
        ascension=node.ascension,
    )


def candidate_inferences(gmap: GMap) -> tuple[CandidateInference, ...]:
    """Project unmapped base statements into the target as candidate inferences.

    A base statement qualifies when it is not itself mapped and at least one
    of its entities is mapped. Each qualifying statement is projected with
    ``_project_node``. Unmapped entities become ``AnalogySkolemFn_<k>``
    placeholders, numbered from 1 per inference, and one placeholder is used
    per distinct base entity.

    Every result carries the same ``support`` (sorted keys of mapped base
    statements) and ``ascensions`` (hypotheses with ``distance > 0`` and an
    ancestor), since both depend only on *gmap*.
    """
    base_to_target: dict[str, Node] = {mh.base_key: mh.target for mh in gmap.hypotheses}
    mapped_statement_keys = {
        mh.base_key for mh in gmap.hypotheses if isinstance(mh.base, Statement)
    }
    support = tuple(sorted(mapped_statement_keys))
    ascensions = tuple(
        f"{mh.base_key}->{mh.ancestor}" for mh in gmap.hypotheses if mh.distance > 0 and mh.ancestor
    )
    out: list[CandidateInference] = []

    for statement in gmap.base.statements:
        if node_key(statement) in mapped_statement_keys:
            continue
        if not any(node_key(entity) in base_to_target for entity in statement.entities()):
            continue
        # Fresh memo per statement: Skolem numbering restarts for each inference.
        projected = _project_node(statement, base_to_target, {})
        skolems = tuple(
            entity.name for entity in projected.entities() if entity.name.startswith("AnalogySkolemFn_")
        )
        out.append(
            CandidateInference(
                inference_sexpr=dumps_statement(projected),
                base_case_id=gmap.base.case_id,
                target_case_id=gmap.target.case_id,
                ses_n=gmap.normalized_score,
                support=support,
                skolems=skolems,
                ascensions=ascensions,
            )
        )
    return tuple(out)
|
|
54
|
+
|
sma/match/kernels.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
"""Kernel construction."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from sma.ir.canon import Canonicalizer, default_canonicalizer
|
|
6
|
+
from sma.ir.schema import Case
|
|
7
|
+
|
|
8
|
+
from sma.ir.schema import Statement
|
|
9
|
+
|
|
10
|
+
from .conflicts import structurally_consistent
|
|
11
|
+
from .mh import seed_expression_mhs, support_closure
|
|
12
|
+
from .ses import structural_evaluation
|
|
13
|
+
from .types import Kernel, MatchConfig, node_key
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def build_kernels(
    base: Case,
    target: Case,
    config: MatchConfig | None = None,
    canon: Canonicalizer | None = None,
    cost_fn=None,
) -> tuple[Kernel, ...]:
    """Build the candidate kernels that the merge step chooses among.

    Construction:
      * Seeds come from ``seed_expression_mhs`` over both cases' expressions.
      * Each kernel is a root seed plus its ``support_closure``, weighted by
        ``structural_evaluation``.
      * A seed is discarded when its closure is ``None`` or not structurally
        consistent.
      * Identical closures (same sorted hypothesis keys) produce one kernel.

    Root MHs only (blueprint section 2.2): a seed whose (base, target) pair
    is a statement-argument pair of another seed is left to that parent's
    support closure. It becomes a root itself only when every such parent is
    discarded, so a failed parent does not silently drop consistent nested
    structure.
    """
    config = config or MatchConfig()
    canon = canon or default_canonicalizer()
    seeds = seed_expression_mhs(base.expressions(), target.expressions(), config=config, canon=canon)

    # (base_key, target_key) of a nested statement pair -> seeds containing it
    # as an argument pair.
    parents_of: dict[tuple[str, str], list] = {}
    for seed in seeds:
        if isinstance(seed.base, Statement) and isinstance(seed.target, Statement):
            for b_arg, t_arg in zip(seed.base.args, seed.target.args):
                if isinstance(b_arg, Statement) and isinstance(t_arg, Statement):
                    parents_of.setdefault((node_key(b_arg), node_key(t_arg)), []).append(seed)

    # Memo keyed by id(seed). Every seed stays referenced by ``seeds`` for the
    # whole call, so ids are stable, and a parent's closure is computed once
    # even when it is consulted for several children.
    closures: dict[int, tuple | None] = {}

    def consistent_closure(seed):
        """Return the seed's support closure, or None if it must be discarded."""
        seed_id = id(seed)
        if seed_id not in closures:
            closure = support_closure(seed, canon=canon, delta=config.delta, rho=config.rho)
            # None: the root pairs unequal constants (e.g. different
            # template-name entities under count), which is structurally
            # impossible. Inconsistent closures are discarded as well.
            if closure is not None and not structurally_consistent(closure):
                closure = None
            closures[seed_id] = closure
        return closures[seed_id]

    kernels: list[Kernel] = []
    seen: set[tuple[tuple[str, str], ...]] = set()
    for seed in seeds:
        parents = parents_of.get(seed.key, ())
        if any(consistent_closure(parent) is not None for parent in parents):
            # Covered by a surviving parent's support closure.
            continue
        closure = consistent_closure(seed)
        if closure is None:
            continue
        key = tuple(sorted(mh.key for mh in closure))
        if key in seen:
            continue
        seen.add(key)
        weight = structural_evaluation(closure, gamma=config.gamma, cost_fn=cost_fn)
        kernels.append(Kernel(seed, closure, weight=weight))
    return tuple(kernels)
|
|
54
|
+
|
sma/match/mdl.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
"""Parameter-free MDL-like scorer."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import math
|
|
6
|
+
from collections import Counter
|
|
7
|
+
|
|
8
|
+
from sma.ir.schema import Case
|
|
9
|
+
|
|
10
|
+
from .types import MatchHypothesis
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def corpus_functor_costs(cases: list[Case]) -> dict[str, float]:
    """Return per-functor surprisal in bits, estimated from *cases*.

    Functor counts from every case's ``functor_counts()`` are pooled. Each
    functor's probability is smoothed by adding 0.5:

        p(f) = (count(f) + 0.5) / (total + 0.5 * V)

    where V is the number of distinct functors (at least 1). The returned
    cost is ``-log2(p(f))``. An empty corpus yields ``{}``.
    """
    tally: Counter[str] = Counter()
    grand_total = 0
    for case in cases:
        per_case = case.functor_counts()
        tally.update(per_case)
        for n in per_case.values():
            grand_total += n
    denominator = grand_total + 0.5 * max(len(tally), 1)
    return {functor: -math.log2((count + 0.5) / denominator) for functor, count in tally.items()}
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def mdl_gain(hypotheses: tuple[MatchHypothesis, ...], target: Case) -> float:
    """Return the bits 'explained' by the mapping: summed cost of matched target statements.

    Functor costs come from ``corpus_functor_costs([target])``. Each distinct
    matched target expression (by ``target_key``) whose target has a
    ``functor`` contributes its functor's cost, or 1.0 if the functor is
    unknown.
    """
    costs = corpus_functor_costs([target])
    # Keyed by target_key so that every matched target expression counts once.
    # A set of functor names would credit N matched statements sharing a
    # functor only once, which disagrees with the per-expression
    # normalization applied to this score by the caller.
    matched_target_exprs = {
        mh.target_key: mh.target.functor for mh in hypotheses if hasattr(mh.target, "functor")
    }
    return sum(costs.get(functor, 1.0) for functor in matched_target_exprs.values())
|
|
30
|
+
|
sma/match/merge_cpsat.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
"""Exact-anytime MWIS merge using Google OR-Tools CP-SAT solver."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from .conflicts import conflict_edges
|
|
6
|
+
from .merge_greedy import greedy_merge
|
|
7
|
+
from .types import Kernel
|
|
8
|
+
|
|
9
|
+
try:
|
|
10
|
+
from ortools.sat.python import cp_model
|
|
11
|
+
except ImportError: # pragma: no cover - ortools is a hard dependency in pyproject
|
|
12
|
+
cp_model = None
|
|
13
|
+
|
|
14
|
+
# Largest kernel count solved by exhaustive subset enumeration when ortools is
# unavailable: at most 2**12 = 4096 subsets, each checked against every
# conflict edge. Larger inputs fall back to greedy_merge without a certificate.
_BRUTE_FORCE_LIMIT = 12
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def exact_or_greedy_merge(
    kernels: tuple[Kernel, ...],
    exact_limit: int = 60,
    time_budget_ms: int = 20,
) -> tuple[tuple[Kernel, ...], float | None]:
    """Select a maximum-weight independent set of kernels.

    Kernels are vertices, and ``conflict_edges`` connects every conflicting
    pair. The result is a conflict-free subset with maximum total weight.

    Strategy:
      * more than *exact_limit* kernels: ``greedy_merge``;
      * OR-Tools CP-SAT available: solve within *time_budget_ms*, and fall
        back to greedy when no solution is found in time;
      * otherwise: exhaustive enumeration for at most ``_BRUTE_FORCE_LIMIT``
        kernels, greedy beyond that.

    Both exact paths never select kernels with non-positive weight, since
    such kernels cannot increase the objective.

    Returns:
        ``(selected kernels, optimality gap)``. A gap of 0.0 means certified
        optimal. ``None`` means greedy was used, or CP-SAT stopped with a
        feasible but unproven solution.
    """
    if not kernels:
        return (), 0.0

    # Greedy is the published O(n^2 log n) fallback; use it directly when the
    # MIP model would cost more than it buys (large conflict-kernel counts).
    if len(kernels) > exact_limit:
        return greedy_merge(kernels), None

    edges = conflict_edges(kernels)

    if cp_model is not None:
        model = cp_model.CpModel()
        n = len(kernels)
        x = [model.NewBoolVar(f"x_{i}") for i in range(n)]
        for i, j in edges:
            model.Add(x[i] + x[j] <= 1)
        for i, kernel in enumerate(kernels):
            if kernel.weight <= 0.0:
                # The objective coefficient is clamped to 0 below, so CP-SAT
                # could include this kernel arbitrarily. The enumeration path
                # never picks one, so pin it out to keep both exact paths in
                # agreement.
                model.Add(x[i] == 0)
        # CP-SAT needs integer weights; 1e5 scaling keeps 5 decimal places.
        scaled_weights = [int(round(max(k.weight, 0.0) * 100000)) for k in kernels]
        model.Maximize(sum(scaled_weights[i] * x[i] for i in range(n)))

        solver = cp_model.CpSolver()
        solver.parameters.max_time_in_seconds = time_budget_ms / 1000.0
        status = solver.Solve(model)
        if status in (cp_model.OPTIMAL, cp_model.FEASIBLE):
            selected = tuple(kernels[i] for i in range(n) if solver.Value(x[i]))
            # FEASIBLE: a solution exists but optimality was not proven
            # (e.g. the time budget ran out), so no certificate is reported.
            gap = 0.0 if status == cp_model.OPTIMAL else None
            return selected, gap
        return greedy_merge(kernels), None

    # No ortools: exact enumeration for tiny instances, greedy otherwise.
    n = len(kernels)
    if n > _BRUTE_FORCE_LIMIT:
        return greedy_merge(kernels), None
    best_mask = 0
    best_weight = float("-inf")
    for mask in range(1 << n):
        ok = True
        for i, j in edges:
            if (mask & (1 << i)) and (mask & (1 << j)):
                ok = False
                break
        if not ok:
            continue
        weight = sum(kernels[i].weight for i in range(n) if mask & (1 << i))
        # Strict '>' keeps the numerically smallest optimal mask, which never
        # contains a non-positive-weight kernel.
        if weight > best_weight:
            best_weight = weight
            best_mask = mask
    selected = tuple(kernels[i] for i in range(n) if best_mask & (1 << i))
    return selected, 0.0
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
"""Published-style greedy kernel merge fallback."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from .conflicts import kernels_conflict
|
|
6
|
+
from .types import Kernel
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def greedy_merge(kernels: tuple[Kernel, ...]) -> tuple[Kernel, ...]:
    """Greedy maximum-weight independent set over kernels.

    Kernels are visited by descending weight, with ties broken by the root's
    base key and then its target key. A kernel is kept when it conflicts with
    none already kept. The result is in visit order.
    """
    def priority(kernel: Kernel):
        return (-kernel.weight, kernel.root.base_key, kernel.root.target_key)

    chosen: list[Kernel] = []
    for candidate in sorted(kernels, key=priority):
        if any(kernels_conflict(candidate, kept) for kept in chosen):
            continue
        chosen.append(candidate)
    return tuple(chosen)
|
|
15
|
+
|