structuremappingmemory 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (125)
  1. sma/__init__.py +5 -0
  2. sma/__main__.py +5 -0
  3. sma/agent/__init__.py +5 -0
  4. sma/agent/adapter_draft.py +217 -0
  5. sma/agent/api.py +67 -0
  6. sma/agent/comparison.py +591 -0
  7. sma/agent/llm.py +280 -0
  8. sma/agent/policies.py +21 -0
  9. sma/agent/service.py +95 -0
  10. sma/cli.py +65 -0
  11. sma/encoders/__init__.py +38 -0
  12. sma/encoders/agentobs.py +27 -0
  13. sma/encoders/base.py +23 -0
  14. sma/encoders/code_treesitter.py +64 -0
  15. sma/encoders/coverage.py +80 -0
  16. sma/encoders/draft_adapter.py +183 -0
  17. sma/encoders/healthcare.py +207 -0
  18. sma/encoders/logs_drain.py +142 -0
  19. sma/encoders/prose_tier1.py +57 -0
  20. sma/encoders/structured.py +57 -0
  21. sma/encoders/traces.py +45 -0
  22. sma/eval/__init__.py +2 -0
  23. sma/eval/agentic/__init__.py +35 -0
  24. sma/eval/agentic/arms/__init__.py +0 -0
  25. sma/eval/agentic/arms/cyber.py +48 -0
  26. sma/eval/agentic/arms/discovery.py +35 -0
  27. sma/eval/agentic/arms/finance.py +38 -0
  28. sma/eval/agentic/arms/legal.py +74 -0
  29. sma/eval/agentic/arms/medicine.py +45 -0
  30. sma/eval/agentic/harness.py +275 -0
  31. sma/eval/agentic/memories.py +308 -0
  32. sma/eval/agentic/metrics.py +82 -0
  33. sma/eval/agentic_qa/__init__.py +27 -0
  34. sma/eval/agentic_qa/agent.py +383 -0
  35. sma/eval/agentic_qa/metrics.py +239 -0
  36. sma/eval/agentic_qa/pools.py +197 -0
  37. sma/eval/arn.py +65 -0
  38. sma/eval/baselines/__init__.py +6 -0
  39. sma/eval/baselines/bge_dense.py +54 -0
  40. sma/eval/baselines/bm25.py +18 -0
  41. sma/eval/baselines/dense.py +42 -0
  42. sma/eval/baselines/hipporag.py +235 -0
  43. sma/eval/baselines/hybrid_rrf.py +30 -0
  44. sma/eval/baselines/longcontext_llm.py +124 -0
  45. sma/eval/baselines/rerank.py +41 -0
  46. sma/eval/baselines/splade.py +77 -0
  47. sma/eval/baselines/wl_kernel.py +163 -0
  48. sma/eval/bugsinpy.py +358 -0
  49. sma/eval/bugsinpy_families.py +164 -0
  50. sma/eval/crossdomain.py +89 -0
  51. sma/eval/diabetes.py +61 -0
  52. sma/eval/drift_env.py +26 -0
  53. sma/eval/drift_metrics.py +24 -0
  54. sma/eval/family_labels.py +167 -0
  55. sma/eval/fraud_elliptic/__init__.py +29 -0
  56. sma/eval/fraud_elliptic/encoder.py +279 -0
  57. sma/eval/fraud_elliptic/eval.py +269 -0
  58. sma/eval/fraud_elliptic/test_encoder.py +123 -0
  59. sma/eval/ieee_cis.py +66 -0
  60. sma/eval/loghub.py +16 -0
  61. sma/eval/loghub_eval.py +480 -0
  62. sma/eval/longmemeval.py +51 -0
  63. sma/eval/memory_backends/__init__.py +2 -0
  64. sma/eval/memory_backends/base.py +22 -0
  65. sma/eval/memory_backends/context_only.py +14 -0
  66. sma/eval/memory_backends/rag_notes.py +17 -0
  67. sma/eval/memory_backends/shared_llm.py +30 -0
  68. sma/eval/memory_backends/sma_memory.py +54 -0
  69. sma/eval/memory_backends/zep_graphiti.py +33 -0
  70. sma/eval/metrics.py +32 -0
  71. sma/eval/ontology_bench.py +219 -0
  72. sma/eval/report.py +573 -0
  73. sma/eval/ssb_eval.py +216 -0
  74. sma/eval/ssb_generator.py +116 -0
  75. sma/eval/stats.py +108 -0
  76. sma/eval/transfer_eval.py +844 -0
  77. sma/index/__init__.py +15 -0
  78. sma/index/ann.py +21 -0
  79. sma/index/content_vectors.py +60 -0
  80. sma/index/inverted.py +63 -0
  81. sma/index/macfac.py +174 -0
  82. sma/ir/__init__.py +22 -0
  83. sma/ir/canon.py +106 -0
  84. sma/ir/schema.py +165 -0
  85. sma/ir/sexpr.py +86 -0
  86. sma/ir/signatures.py +76 -0
  87. sma/match/__init__.py +20 -0
  88. sma/match/conflicts.py +46 -0
  89. sma/match/engine.py +60 -0
  90. sma/match/explain.py +59 -0
  91. sma/match/infer.py +54 -0
  92. sma/match/kernels.py +54 -0
  93. sma/match/mdl.py +30 -0
  94. sma/match/merge_cpsat.py +77 -0
  95. sma/match/merge_greedy.py +15 -0
  96. sma/match/mh.py +177 -0
  97. sma/match/ses.py +84 -0
  98. sma/match/types.py +115 -0
  99. sma/match/verifier.py +27 -0
  100. sma/ontology/__init__.py +45 -0
  101. sma/ontology/attack.py +134 -0
  102. sma/ontology/cpc.py +69 -0
  103. sma/ontology/graph.py +58 -0
  104. sma/ontology/loader.py +262 -0
  105. sma/ontology/mitre_xml.py +67 -0
  106. sma/ontology/mount.py +101 -0
  107. sma/ontology/rdf_loader.py +75 -0
  108. sma/ontology/registry.py +115 -0
  109. sma/ontology/router.py +69 -0
  110. sma/ontology/usgaap.py +73 -0
  111. sma/sage/__init__.py +6 -0
  112. sma/sage/assimilate.py +12 -0
  113. sma/sage/pools.py +105 -0
  114. sma/sage/probabilities.py +10 -0
  115. sma/store/__init__.py +6 -0
  116. sma/store/lmdb_store.py +78 -0
  117. sma/store/registry.py +26 -0
  118. sma/store/wal.py +26 -0
  119. sma/ui/app.py +642 -0
  120. structuremappingmemory-1.0.0.dist-info/METADATA +190 -0
  121. structuremappingmemory-1.0.0.dist-info/RECORD +125 -0
  122. structuremappingmemory-1.0.0.dist-info/WHEEL +5 -0
  123. structuremappingmemory-1.0.0.dist-info/entry_points.txt +2 -0
  124. structuremappingmemory-1.0.0.dist-info/licenses/LICENSE +204 -0
  125. structuremappingmemory-1.0.0.dist-info/top_level.txt +1 -0
sma/ir/sexpr.py ADDED
@@ -0,0 +1,86 @@
1
+ """Canonical S-expression codec for SMA cases."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from collections.abc import Iterable
6
+
7
+ from .schema import Entity, Node, Statement, entity
8
+
9
+
10
def dumps_node(node: Node) -> str:
    """Serialize one IR node.

    Entities are written as their bare name; anything else is treated as a
    statement and rendered by ``dumps_statement``.
    """
    if not isinstance(node, Entity):
        return dumps_statement(node)
    return node.name
14
+
15
+
16
def dumps_statement(statement: Statement) -> str:
    """Serialize a statement as ``(functor arg1 arg2 ...)``.

    Arguments are rendered recursively with ``dumps_node``; a statement with
    no arguments becomes ``(functor)``.
    """
    if statement.args:
        rendered_args = " ".join(map(dumps_node, statement.args))
        return f"({statement.functor} {rendered_args})"
    return f"({statement.functor})"
21
+
22
+
23
def canonical_case_text(statements: Iterable[Statement]) -> str:
    """Render a case as newline-separated S-expressions in sorted order.

    Sorting the rendered lines makes the text independent of the input
    statement order.
    """
    lines = [dumps_statement(stmt) for stmt in statements]
    lines.sort()
    return "\n".join(lines)
25
+
26
+
27
def loads_statement(text: str) -> Statement:
    """Parse exactly one statement from *text*.

    Raises:
        ValueError: if tokens remain after the first S-expression, if the
            top-level expression is a bare entity, or if parsing fails.
    """
    tokens = _tokenize(text)
    parsed, end = _parse(tokens, 0)
    if end != len(tokens):
        leftover = tokens[end:]
        raise ValueError(f"trailing tokens after S-expression: {leftover}")
    if isinstance(parsed, Statement):
        return parsed
    raise ValueError("top-level S-expression must be a statement")
35
+
36
+
37
def loads_case(text: str) -> tuple[Statement, ...]:
    """Parse a sequence of top-level statements from *text*.

    Empty (or whitespace-only) text yields an empty tuple.

    Raises:
        ValueError: if any top-level entry is not a statement, or if parsing
            fails.
    """
    tokens = _tokenize(text)
    parsed: list[Statement] = []
    cursor = 0
    while cursor < len(tokens):
        node, cursor = _parse(tokens, cursor)
        if isinstance(node, Statement):
            parsed.append(node)
            continue
        raise ValueError("case entries must be statements")
    return tuple(parsed)
47
+
48
+
49
+ def _tokenize(text: str) -> list[str]:
50
+ tokens: list[str] = []
51
+ token = []
52
+ for ch in text:
53
+ if ch in "()":
54
+ if token:
55
+ tokens.append("".join(token))
56
+ token = []
57
+ tokens.append(ch)
58
+ elif ch.isspace():
59
+ if token:
60
+ tokens.append("".join(token))
61
+ token = []
62
+ else:
63
+ token.append(ch)
64
+ if token:
65
+ tokens.append("".join(token))
66
+ return tokens
67
+
68
+
69
+ def _parse(tokens: list[str], pos: int) -> tuple[Node, int]:
70
+ if pos >= len(tokens):
71
+ raise ValueError("unexpected end of input")
72
+ tok = tokens[pos]
73
+ if tok != "(":
74
+ return entity(tok), pos + 1
75
+ if pos + 1 >= len(tokens):
76
+ raise ValueError("missing functor")
77
+ functor = tokens[pos + 1]
78
+ args: list[Node] = []
79
+ pos += 2
80
+ while pos < len(tokens) and tokens[pos] != ")":
81
+ arg, pos = _parse(tokens, pos)
82
+ args.append(arg)
83
+ if pos >= len(tokens) or tokens[pos] != ")":
84
+ raise ValueError("unclosed statement")
85
+ return Statement(functor, tuple(args)), pos + 1
86
+
sma/ir/signatures.py ADDED
@@ -0,0 +1,76 @@
1
+ """Signature registry and type validation."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass, field
6
+
7
+ from .schema import Entity, Signature, Statement, SymbolKind
8
+
9
+
10
# Signatures pre-registered by ``SignatureRegistry.with_defaults()``, in
# registration order (a later entry with the same functor would replace an
# earlier one). ``higher_order=True`` is passed only where shown.
# NOTE(review): the two positional arguments are presumably (functor, arity)
# and ``higher_order`` presumably marks relations whose arguments are
# themselves statements -- confirm against ``Signature`` in schema.py.
DEFAULT_SIGNATURES = (
    # Binary relations flagged higher_order.
    Signature("before", 2, higher_order=True),
    Signature("after", 2, higher_order=True),
    Signature("cause", 2, higher_order=True),
    Signature("implies", 2, higher_order=True),
    Signature("enables", 2, higher_order=True),
    Signature("prevents", 2, higher_order=True),
    Signature("during", 2, higher_order=True),
    # Relations without the higher_order flag.
    Signature("count", 2),
    Signature("burst", 2),
    Signature("frame", 4),
    # calledFrom is flagged higher_order; the rest of this group is not.
    Signature("calledFrom", 2, higher_order=True),
    Signature("defines", 2),
    Signature("calls", 2),
    Signature("imports", 2),
    Signature("throws", 2),
    Signature("catches", 2),
    # Unary signatures.
    Signature("adds", 1),
    Signature("removes", 1),
    Signature("modifies", 1),
)
31
+
32
+
33
@dataclass
class SignatureRegistry:
    """Mutable mapping from functor name to its declared ``Signature``.

    Lookups of unregistered functors do not fail: ``get`` synthesizes an
    ad-hoc signature whose arity is the requested one, or -1 when unknown.
    """

    # functor name -> registered Signature
    signatures: dict[str, Signature] = field(default_factory=dict)

    @classmethod
    def with_defaults(cls) -> "SignatureRegistry":
        """Create a registry pre-populated with ``DEFAULT_SIGNATURES``."""
        registry = cls()
        for default_sig in DEFAULT_SIGNATURES:
            registry.register(default_sig)
        return registry

    def register(self, signature: Signature) -> None:
        """Add *signature*, replacing any entry with the same functor."""
        self.signatures[signature.functor] = signature

    def get(self, functor: str, arity: int | None = None) -> Signature:
        """Return the signature for *functor*.

        Unregistered functors yield a fresh ``Signature`` with the given
        *arity* (or -1 when *arity* is None).

        Raises:
            ValueError: if *functor* is registered and *arity* is given but
                differs from the registered arity.
        """
        if functor not in self.signatures:
            fallback_arity = -1 if arity is None else arity
            return Signature(functor=functor, arity=fallback_arity)
        registered = self.signatures[functor]
        if arity is None or registered.arity == arity:
            return registered
        raise ValueError(f"{functor} registered with arity {registered.arity}, got {arity}")

    def validate_statement(self, statement: Statement) -> None:
        """Recursively validate *statement* and every nested statement.

        Raises:
            ValueError: on an arity clash with a registered signature (from
                ``get``); ``Signature.validate_arity`` may raise as well.
            TypeError: if an argument is neither a Statement nor an Entity.
        """
        signature = self.get(statement.functor, statement.arity)
        # Negative arity means "unknown", so there is nothing to check.
        if signature.arity > -1:
            signature.validate_arity(statement.args)
        for child in statement.args:
            if isinstance(child, Statement):
                self.validate_statement(child)
                continue
            if not isinstance(child, Entity):
                raise TypeError(f"unsupported argument node: {child!r}")

    def validate_case(self, statements: tuple[Statement, ...]) -> None:
        """Validate each top-level statement of a case, in order."""
        for stmt in statements:
            self.validate_statement(stmt)
68
+
69
+
70
def infer_kind(functor: str, arity: int) -> SymbolKind:
    """Classify a symbol by arity.

    Arity 0 or 1 maps to ``SymbolKind.ATTRIBUTE``. Any other arity, including
    a negative "unknown" arity, maps to ``SymbolKind.RELATION``. *functor* is
    currently unused.
    """
    return SymbolKind.ATTRIBUTE if arity in (0, 1) else SymbolKind.RELATION
76
+
sma/match/__init__.py ADDED
@@ -0,0 +1,20 @@
1
+ from .engine import match_cases
2
+ from .explain import correspondence_table, explain_text
3
+ from .infer import candidate_inferences
4
+ from .types import CandidateInference, GMap, Kernel, MatchConfig, MatchHypothesis
5
+ from .verifier import VerificationResult, verify_inference
6
+
7
# Public names re-exported by ``sma.match`` (all imported above from the
# engine, explain, infer, types and verifier submodules).
# NOTE(review): ``explain.alignment_summary`` is neither imported nor listed
# here; add it if it is meant to be part of the public API.
__all__ = [
    "CandidateInference",
    "GMap",
    "Kernel",
    "MatchConfig",
    "MatchHypothesis",
    "VerificationResult",
    "candidate_inferences",
    "correspondence_table",
    "explain_text",
    "match_cases",
    "verify_inference",
]
20
+
sma/match/conflicts.py ADDED
@@ -0,0 +1,46 @@
1
+ """Structural consistency checks and kernel conflict graph."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from itertools import combinations
6
+
7
+ from .types import Kernel, MatchHypothesis
8
+
9
+
10
def structurally_consistent(hypotheses: tuple[MatchHypothesis, ...]) -> bool:
    """Return True iff the hypotheses form a one-to-one base <-> target mapping.

    Repeating the same pair is allowed. A base key paired with two different
    targets, or a target key paired with two different bases, makes the set
    inconsistent.
    """
    forward: dict[str, str] = {}
    backward: dict[str, str] = {}
    for hyp in hypotheses:
        b_key, t_key = hyp.base_key, hyp.target_key
        # setdefault records the first pairing and returns whatever is stored.
        if forward.setdefault(b_key, t_key) != t_key:
            return False
        if backward.setdefault(t_key, b_key) != b_key:
            return False
    return True
21
+
22
+
23
def kernels_conflict(left: Kernel, right: Kernel) -> bool:
    """Return True if the two kernels' bindings cannot coexist one-to-one.

    A conflict is a base key bound to different targets, or a target key
    bound from different bases. Only the smaller binding table is walked;
    each step does two dict probes against the larger kernel's forward and
    reverse tables.
    """
    if len(left.bindings) > len(right.bindings):
        small, large = right, left
    else:
        small, large = left, right
    forward = large.bindings
    reverse = large.reverse_bindings
    for base_key, target_key in small.bindings.items():
        paired_target = forward.get(base_key)
        if paired_target is not None and paired_target != target_key:
            return True
        paired_base = reverse.get(target_key)
        if paired_base is not None and paired_base != base_key:
            return True
    return False
38
+
39
+
40
def conflict_edges(kernels: tuple[Kernel, ...]) -> set[tuple[int, int]]:
    """Return every index pair ``(i, j)`` with ``i < j`` whose kernels conflict.

    The result is the edge set of the kernel conflict graph; every pair is
    tested, so the cost is quadratic in ``len(kernels)``.
    """
    return {
        (i, j)
        for i, j in combinations(range(len(kernels)), 2)
        if kernels_conflict(kernels[i], kernels[j])
    }
46
+
sma/match/engine.py ADDED
@@ -0,0 +1,60 @@
1
+ """Top-level mapping engine."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from sma.ir.canon import Canonicalizer, default_canonicalizer
6
+ from sma.ir.schema import Case, Statement
7
+
8
+ from .kernels import build_kernels
9
+ from .mdl import mdl_gain
10
+ from .merge_cpsat import exact_or_greedy_merge
11
+ from .ses import normalize_score, structural_evaluation
12
+ from .types import GMap, MatchConfig, MatchHypothesis
13
+
14
+
15
def match_cases(
    base: Case,
    target: Case,
    config: MatchConfig | None = None,
    canon: Canonicalizer | None = None,
) -> GMap:
    """Map *base* onto *target* and return the resulting global mapping.

    Steps:

    1. Build kernels with ``build_kernels``.
    2. Select a mutually consistent subset with ``exact_or_greedy_merge``.
    3. Union the selected kernels' hypotheses, deduplicated by
       ``MatchHypothesis.key``.
    4. Score the union.

    Scoring follows ``config.scorer``:

    * ``"mdl"``: ``mdl_gain`` divided by the number of target expressions
      (at least 1).
    * anything else: ``structural_evaluation`` normalized by
      ``normalize_score``. When the scorer is ``"surprisal"`` and
      ``config.functor_costs`` is non-empty, each statement hypothesis is
      weighted by the cost of its canonicalized base functor (1.0 if
      absent); non-statement hypotheses weigh 1.0.

    ``config`` and ``canon`` default to ``MatchConfig()`` and
    ``default_canonicalizer()``.
    """
    config = config or MatchConfig()
    canon = canon or default_canonicalizer()

    cost_fn = None
    if config.scorer == "surprisal" and config.functor_costs:
        functor_costs = config.functor_costs

        def cost_fn(mh):
            if not isinstance(mh.base, Statement):
                return 1.0
            return functor_costs.get(canon.canonical(mh.base.functor), 1.0)

    kernels = build_kernels(base, target, config=config, canon=canon, cost_fn=cost_fn)
    selected, gap = exact_or_greedy_merge(
        kernels,
        exact_limit=config.exact_kernel_limit,
        time_budget_ms=config.cpsat_time_ms,
    )

    # Union of the selected hypotheses. A repeated key keeps its first
    # position but holds the hypothesis object from the last kernel seen.
    by_key: dict[tuple[str, str], MatchHypothesis] = {}
    for kernel in selected:
        by_key.update((mh.key, mh) for mh in kernel.hypotheses)
    hypotheses = tuple(by_key.values())

    if config.scorer == "mdl":
        score = mdl_gain(hypotheses, target)
        normalized = score / max(len(target.expressions()), 1)
    else:
        score = structural_evaluation(hypotheses, gamma=config.gamma, cost_fn=cost_fn)
        normalized = normalize_score(
            score,
            base,
            target,
            gamma=config.gamma,
            cost_fn=cost_fn,
            normalization=config.normalization,
        )

    return GMap(
        base=base,
        target=target,
        hypotheses=hypotheses,
        kernels=selected,
        score=score,
        normalized_score=normalized,
        scorer=config.scorer,
        optimality_gap=gap,
    )
60
+
sma/match/explain.py ADDED
@@ -0,0 +1,59 @@
1
+ """Human-readable mapping explanations."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from .types import GMap
6
+
7
+
8
def correspondence_table(gmap: GMap) -> list[dict[str, str]]:
    """Tabulate each hypothesis of *gmap* as a row of display strings.

    Each row has these keys:

    * ``base``: the hypothesis's base key.
    * ``target``: its target key.
    * ``ascension``: formatted with three decimals.
    * ``ancestor``: an empty string when the hypothesis has no (falsy)
      ancestor.
    """
    return [
        {
            "base": hyp.base_key,
            "target": hyp.target_key,
            "ascension": f"{hyp.ascension:.3f}",
            "ancestor": hyp.ancestor or "",
        }
        for hyp in gmap.hypotheses
    ]
20
+
21
+
22
def explain_text(gmap: GMap) -> str:
    """Render *gmap* as a plain-text report.

    The report has a score header line, a ``correspondences:`` line, and then
    one bullet line per row of ``correspondence_table``.
    """
    header = f"gmap score={gmap.score:.3f} SES_n={gmap.normalized_score:.3f} scorer={gmap.scorer}"
    bullets = [
        f"- {row['base']} -> {row['target']} asc={row['ascension']} {row['ancestor']}"
        for row in correspondence_table(gmap)
    ]
    return "\n".join([header, "correspondences:", *bullets])
30
+
31
+
32
+
33
def alignment_summary(gmap: GMap) -> str:
    """One-line 'why this precedent matched' for verbalizers and evidence panels.

    Counts the statement-level correspondences by canonical base functor,
    skipping ``logSession``. An LLM (or human) then sees the shared STRUCTURE
    (e.g. 'kernelEvent x4, failureEvent x4, before x3') instead of
    re-deriving similarity from raw text semantics. At most the six most
    frequent functors are listed.
    """
    from collections import Counter

    from sma.ir.canon import default_canonicalizer
    from sma.ir.schema import Statement

    canon = default_canonicalizer()
    canonical_functors = (
        canon.canonical(mh.base.functor)
        for mh in gmap.hypotheses
        if isinstance(mh.base, Statement)
    )
    matched: Counter[str] = Counter(f for f in canonical_functors if f != "logSession")
    if not matched:
        return "no statement-level correspondences"
    parts = ", ".join(
        name if count == 1 else f"{name} x{count}"
        for name, count in matched.most_common(6)
    )
    return f"shared structure: {parts}; ses_n={gmap.normalized_score:.2f}"
sma/match/infer.py ADDED
@@ -0,0 +1,54 @@
1
+ """Candidate inference projection."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from sma.ir.schema import Entity, Node, Statement
6
+ from sma.ir.sexpr import dumps_statement
7
+
8
+ from .types import CandidateInference, GMap, node_key
9
+
10
+
11
def candidate_inferences(gmap: GMap) -> tuple[CandidateInference, ...]:
    """Project unmapped base statements into the target as candidate inferences.

    A base statement becomes a candidate when it is not itself mapped but at
    least one of its entities is.

    Projection replaces every mapped sub-node with its target counterpart.
    Each distinct unmapped base entity becomes a fresh
    ``AnalogySkolemFn_<k>`` entity, numbered from 1 within each inference,
    that keeps the base entity's type. Repeated occurrences of the same
    unmapped entity share one Skolem, so co-reference survives projection.

    Every candidate carries the same ``support`` (sorted keys of all mapped
    base statements). It also carries the same ``ascensions``
    (``"base_key->ancestor"`` for each hypothesis with positive distance and
    an ancestor).
    """
    base_to_target: dict[str, Node] = {mh.base_key: mh.target for mh in gmap.hypotheses}
    mapped_statement_keys = {
        mh.base_key for mh in gmap.hypotheses if isinstance(mh.base, Statement)
    }
    # Identical for every candidate, so computed once instead of per statement.
    support = tuple(sorted(mapped_statement_keys))
    ascensions = tuple(
        f"{mh.base_key}->{mh.ancestor}" for mh in gmap.hypotheses if mh.distance > 0 and mh.ancestor
    )

    def project(node: Node, skolem_for: dict[str, Entity]) -> Node:
        key = node_key(node)
        mapped = base_to_target.get(key)
        if mapped is not None:
            return mapped
        if isinstance(node, Entity):
            # Memoize per inference: the same unmapped entity must project to
            # the same Skolem, otherwise shared arguments become unrelated.
            skolem = skolem_for.get(key)
            if skolem is None:
                skolem = Entity(f"AnalogySkolemFn_{len(skolem_for) + 1}", node.type)
                skolem_for[key] = skolem
            return skolem
        return Statement(
            node.functor,
            tuple(project(arg, skolem_for) for arg in node.args),
            ascension=node.ascension,
        )

    out: list[CandidateInference] = []
    for statement in gmap.base.statements:
        if node_key(statement) in mapped_statement_keys:
            continue
        if not any(node_key(ent) in base_to_target for ent in statement.entities()):
            continue
        # A fresh Skolem table restarts numbering at 1 for each inference.
        projected = project(statement, {})
        # dict.fromkeys de-duplicates (a shared Skolem may occur several
        # times) while preserving first-occurrence order.
        skolems = tuple(
            dict.fromkeys(
                ent.name
                for ent in projected.entities()
                if ent.name.startswith("AnalogySkolemFn_")
            )
        )
        out.append(
            CandidateInference(
                inference_sexpr=dumps_statement(projected),
                base_case_id=gmap.base.case_id,
                target_case_id=gmap.target.case_id,
                ses_n=gmap.normalized_score,
                support=support,
                skolems=skolems,
                ascensions=ascensions,
            )
        )
    return tuple(out)
54
+
sma/match/kernels.py ADDED
@@ -0,0 +1,54 @@
1
+ """Kernel construction."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from sma.ir.canon import Canonicalizer, default_canonicalizer
6
+ from sma.ir.schema import Case
7
+
8
+ from sma.ir.schema import Statement
9
+
10
+ from .conflicts import structurally_consistent
11
+ from .mh import seed_expression_mhs, support_closure
12
+ from .ses import structural_evaluation
13
+ from .types import Kernel, MatchConfig, node_key
14
+
15
+
16
def build_kernels(
    base: Case,
    target: Case,
    config: MatchConfig | None = None,
    canon: Canonicalizer | None = None,
    cost_fn=None,
) -> tuple[Kernel, ...]:
    """Build the scored, structurally consistent kernels for the merge step.

    Each kernel is the support closure of a seed match hypothesis, weighted by
    ``structural_evaluation``. A closure is discarded if it is None, if it is
    not structurally consistent, or if it duplicates an earlier closure.

    Root MHs only (blueprint section 2.2): a seed that is an argument pair of
    another seed is deferred, since its parent's closure normally covers it.
    A deferred seed that no kept kernel actually contains is promoted to a
    root afterwards, so its correspondence is not lost. This happens, e.g.,
    when the parent's closure was discarded.
    """
    config = config or MatchConfig()
    canon = canon or default_canonicalizer()
    seeds = seed_expression_mhs(base.expressions(), target.expressions(), config=config, canon=canon)

    # Keys of seeds that appear as a (statement, statement) argument pair of
    # another seed.
    child_keys: set[tuple[str, str]] = set()
    for seed in seeds:
        if isinstance(seed.base, Statement) and isinstance(seed.target, Statement):
            for b_arg, t_arg in zip(seed.base.args, seed.target.args):
                if isinstance(b_arg, Statement) and isinstance(t_arg, Statement):
                    child_keys.add((node_key(b_arg), node_key(t_arg)))

    kernels: list[Kernel] = []
    seen: set[tuple[tuple[str, str], ...]] = set()
    # Every MH key contained in some kept kernel.
    covered: set[tuple[str, str]] = set()

    def add_kernel(seed) -> None:
        closure = support_closure(seed, canon=canon, delta=config.delta, rho=config.rho)
        if closure is None:
            # Root pairs unequal constants (e.g. different template-name
            # entities under count): structurally impossible, discard.
            return
        if not structurally_consistent(closure):
            return
        key = tuple(sorted(mh.key for mh in closure))
        if key in seen:
            return
        seen.add(key)
        covered.update(key)
        weight = structural_evaluation(closure, gamma=config.gamma, cost_fn=cost_fn)
        kernels.append(Kernel(seed, closure, weight=weight))

    deferred = []
    for seed in seeds:
        if seed.key in child_keys:
            deferred.append(seed)
        else:
            add_kernel(seed)

    # Promote children whose parents did not yield a kernel containing them.
    # NOTE(review): with multi-level nesting a promoted grandchild may also be
    # covered by a later-promoted parent; such subset kernels never conflict
    # with their superset, so they only enlarge the merge problem slightly.
    for seed in deferred:
        if seed.key not in covered:
            add_kernel(seed)
    return tuple(kernels)
54
+
sma/match/mdl.py ADDED
@@ -0,0 +1,30 @@
1
+ """Parameter-free MDL-like scorer."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+ from collections import Counter
7
+
8
+ from sma.ir.schema import Case
9
+
10
+ from .types import MatchHypothesis
11
+
12
+
13
def corpus_functor_costs(cases: list[Case]) -> dict[str, float]:
    """Return a smoothed surprisal cost, in bits, for each functor seen in *cases*.

    ``cost(f) = -log2((count(f) + 0.5) / (total + 0.5 * V))``. Here ``total``
    is the summed functor count over all cases and ``V`` is the number of
    distinct functors (at least 1). Rarer functors therefore cost more.
    An empty corpus yields ``{}``.
    """
    tally: Counter[str] = Counter()
    total = 0
    for per_case in (case.functor_counts() for case in cases):
        for functor, n in per_case.items():
            tally[functor] += n
            total += n
    denominator = total + 0.5 * max(len(tally), 1)
    return {functor: -math.log2((count + 0.5) / denominator) for functor, count in tally.items()}
22
+
23
+
24
def mdl_gain(hypotheses: tuple[MatchHypothesis, ...], target: Case) -> float:
    """Sum the surprisal cost of every distinct matched target statement.

    Costs come from ``corpus_functor_costs([target])``; functors missing
    from that table cost 1.0. Each matched target expression counts once,
    keyed by ``target_key``, so several matched statements sharing a
    functor each contribute. Previously the matches collapsed into a set of
    functors, which undercounted against the engine's per-expression
    normalizer. Hypotheses whose target has no ``functor`` attribute
    contribute nothing.
    """
    costs = corpus_functor_costs([target])
    matched_target_functors = {
        mh.target_key: mh.target.functor
        for mh in hypotheses
        if hasattr(mh.target, "functor")
    }
    return sum(costs.get(functor, 1.0) for functor in matched_target_functors.values())
30
+
sma/match/merge_cpsat.py ADDED
@@ -0,0 +1,77 @@
1
+ """Exact-anytime MWIS merge using Google OR-Tools CP-SAT solver."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from .conflicts import conflict_edges
6
+ from .merge_greedy import greedy_merge
7
+ from .types import Kernel
8
+
9
+ try:
10
+ from ortools.sat.python import cp_model
11
+ except ImportError: # pragma: no cover - ortools is a hard dependency in pyproject
12
+ cp_model = None
13
+
14
# Largest kernel count solved by exhaustive subset enumeration when ortools is
# unavailable (2^12 = 4096 masks max); larger inputs fall back to greedy_merge.
_BRUTE_FORCE_LIMIT = 12
16
+
17
+
18
def exact_or_greedy_merge(
    kernels: tuple[Kernel, ...],
    exact_limit: int = 60,
    time_budget_ms: int = 20,
) -> tuple[tuple[Kernel, ...], float | None]:
    """Select a maximum-weight independent set of kernels.

    Kernels are vertices weighted by ``Kernel.weight``. Two kernels are
    adjacent when ``conflict_edges`` reports a binding conflict between them.

    Strategy:

    * More than *exact_limit* kernels: ``greedy_merge``. This is the
      published O(n^2 log n) fallback, used when the exact model would cost
      more than it buys.
    * ortools available: CP-SAT capped at *time_budget_ms* of wall time.
      An OPTIMAL result is certified. A merely FEASIBLE (time-limited) result
      is compared with the greedy selection, and the heavier one is returned.
    * No ortools: exhaustive enumeration up to ``_BRUTE_FORCE_LIMIT``
      kernels, greedy beyond that.

    Returns:
        ``(selected kernels, optimality gap)``. A gap of 0.0 means certified
        optimal; None means a greedy or anytime answer without a certificate.
    """
    if not kernels:
        return (), 0.0
    if len(kernels) > exact_limit:
        return greedy_merge(kernels), None

    edges = conflict_edges(kernels)
    if cp_model is not None:
        return _cpsat_merge(kernels, edges, time_budget_ms)
    if len(kernels) > _BRUTE_FORCE_LIMIT:
        return greedy_merge(kernels), None
    return _brute_force_merge(kernels, edges), 0.0


def _total_weight(kernels: tuple[Kernel, ...]) -> float:
    """Sum of ``Kernel.weight`` over *kernels*."""
    return sum(k.weight for k in kernels)


def _cpsat_merge(
    kernels: tuple[Kernel, ...],
    edges: set[tuple[int, int]],
    time_budget_ms: int,
) -> tuple[tuple[Kernel, ...], float | None]:
    """Solve the kernel MWIS with CP-SAT, never returning worse than greedy."""
    model = cp_model.CpModel()
    picks = [model.NewBoolVar(f"x_{i}") for i in range(len(kernels))]
    for i, j in edges:
        model.Add(picks[i] + picks[j] <= 1)
    # CP-SAT needs integer weights; 1e5 scaling keeps 5 decimal places.
    # Negative weights are clamped to 0.
    scaled = [int(round(max(k.weight, 0.0) * 100000)) for k in kernels]
    model.Maximize(sum(w * var for w, var in zip(scaled, picks)))

    solver = cp_model.CpSolver()
    solver.parameters.max_time_in_seconds = time_budget_ms / 1000.0
    status = solver.Solve(model)
    if status == cp_model.OPTIMAL:
        return tuple(k for k, var in zip(kernels, picks) if solver.Value(var)), 0.0

    greedy = greedy_merge(kernels)
    if status == cp_model.FEASIBLE:
        # A time-limited incumbent can be worse than the greedy baseline;
        # keep whichever is heavier (ties go to CP-SAT).
        anytime = tuple(k for k, var in zip(kernels, picks) if solver.Value(var))
        if _total_weight(anytime) >= _total_weight(greedy):
            return anytime, None
    return greedy, None


def _brute_force_merge(
    kernels: tuple[Kernel, ...],
    edges: set[tuple[int, int]],
) -> tuple[Kernel, ...]:
    """Enumerate all subsets; return the first heaviest conflict-free one."""
    n = len(kernels)
    # Per-kernel bitmask of conflicting neighbours. A subset is independent
    # iff no member's neighbour mask intersects it.
    neighbours = [0] * n
    for i, j in edges:
        neighbours[i] |= 1 << j
        neighbours[j] |= 1 << i
    best_mask = 0
    best_weight = float("-inf")
    for mask in range(1 << n):
        members = [i for i in range(n) if (mask >> i) & 1]
        if any(mask & neighbours[i] for i in members):
            continue
        weight = sum(kernels[i].weight for i in members)
        # Strict '>' keeps the lowest-numbered mask among equal weights.
        if weight > best_weight:
            best_weight = weight
            best_mask = mask
    return tuple(kernels[i] for i in range(n) if (best_mask >> i) & 1)
sma/match/merge_greedy.py ADDED
@@ -0,0 +1,15 @@
1
+ """Published-style greedy kernel merge fallback."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from .conflicts import kernels_conflict
6
+ from .types import Kernel
7
+
8
+
9
def greedy_merge(kernels: tuple[Kernel, ...]) -> tuple[Kernel, ...]:
    """Greedily build a conflict-free kernel set, heaviest first.

    Kernels are visited by descending weight, with ties broken by the root's
    ``(base_key, target_key)``. Each kernel is kept unless it conflicts with
    one already kept.
    """
    def rank(kernel: Kernel):
        return (-kernel.weight, kernel.root.base_key, kernel.root.target_key)

    chosen: list[Kernel] = []
    for candidate in sorted(kernels, key=rank):
        if any(kernels_conflict(candidate, kept) for kept in chosen):
            continue
        chosen.append(candidate)
    return tuple(chosen)
15
+