structuremappingmemory 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (125) hide show
  1. sma/__init__.py +5 -0
  2. sma/__main__.py +5 -0
  3. sma/agent/__init__.py +5 -0
  4. sma/agent/adapter_draft.py +217 -0
  5. sma/agent/api.py +67 -0
  6. sma/agent/comparison.py +591 -0
  7. sma/agent/llm.py +280 -0
  8. sma/agent/policies.py +21 -0
  9. sma/agent/service.py +95 -0
  10. sma/cli.py +65 -0
  11. sma/encoders/__init__.py +38 -0
  12. sma/encoders/agentobs.py +27 -0
  13. sma/encoders/base.py +23 -0
  14. sma/encoders/code_treesitter.py +64 -0
  15. sma/encoders/coverage.py +80 -0
  16. sma/encoders/draft_adapter.py +183 -0
  17. sma/encoders/healthcare.py +207 -0
  18. sma/encoders/logs_drain.py +142 -0
  19. sma/encoders/prose_tier1.py +57 -0
  20. sma/encoders/structured.py +57 -0
  21. sma/encoders/traces.py +45 -0
  22. sma/eval/__init__.py +2 -0
  23. sma/eval/agentic/__init__.py +35 -0
  24. sma/eval/agentic/arms/__init__.py +0 -0
  25. sma/eval/agentic/arms/cyber.py +48 -0
  26. sma/eval/agentic/arms/discovery.py +35 -0
  27. sma/eval/agentic/arms/finance.py +38 -0
  28. sma/eval/agentic/arms/legal.py +74 -0
  29. sma/eval/agentic/arms/medicine.py +45 -0
  30. sma/eval/agentic/harness.py +275 -0
  31. sma/eval/agentic/memories.py +308 -0
  32. sma/eval/agentic/metrics.py +82 -0
  33. sma/eval/agentic_qa/__init__.py +27 -0
  34. sma/eval/agentic_qa/agent.py +383 -0
  35. sma/eval/agentic_qa/metrics.py +239 -0
  36. sma/eval/agentic_qa/pools.py +197 -0
  37. sma/eval/arn.py +65 -0
  38. sma/eval/baselines/__init__.py +6 -0
  39. sma/eval/baselines/bge_dense.py +54 -0
  40. sma/eval/baselines/bm25.py +18 -0
  41. sma/eval/baselines/dense.py +42 -0
  42. sma/eval/baselines/hipporag.py +235 -0
  43. sma/eval/baselines/hybrid_rrf.py +30 -0
  44. sma/eval/baselines/longcontext_llm.py +124 -0
  45. sma/eval/baselines/rerank.py +41 -0
  46. sma/eval/baselines/splade.py +77 -0
  47. sma/eval/baselines/wl_kernel.py +163 -0
  48. sma/eval/bugsinpy.py +358 -0
  49. sma/eval/bugsinpy_families.py +164 -0
  50. sma/eval/crossdomain.py +89 -0
  51. sma/eval/diabetes.py +61 -0
  52. sma/eval/drift_env.py +26 -0
  53. sma/eval/drift_metrics.py +24 -0
  54. sma/eval/family_labels.py +167 -0
  55. sma/eval/fraud_elliptic/__init__.py +29 -0
  56. sma/eval/fraud_elliptic/encoder.py +279 -0
  57. sma/eval/fraud_elliptic/eval.py +269 -0
  58. sma/eval/fraud_elliptic/test_encoder.py +123 -0
  59. sma/eval/ieee_cis.py +66 -0
  60. sma/eval/loghub.py +16 -0
  61. sma/eval/loghub_eval.py +480 -0
  62. sma/eval/longmemeval.py +51 -0
  63. sma/eval/memory_backends/__init__.py +2 -0
  64. sma/eval/memory_backends/base.py +22 -0
  65. sma/eval/memory_backends/context_only.py +14 -0
  66. sma/eval/memory_backends/rag_notes.py +17 -0
  67. sma/eval/memory_backends/shared_llm.py +30 -0
  68. sma/eval/memory_backends/sma_memory.py +54 -0
  69. sma/eval/memory_backends/zep_graphiti.py +33 -0
  70. sma/eval/metrics.py +32 -0
  71. sma/eval/ontology_bench.py +219 -0
  72. sma/eval/report.py +573 -0
  73. sma/eval/ssb_eval.py +216 -0
  74. sma/eval/ssb_generator.py +116 -0
  75. sma/eval/stats.py +108 -0
  76. sma/eval/transfer_eval.py +844 -0
  77. sma/index/__init__.py +15 -0
  78. sma/index/ann.py +21 -0
  79. sma/index/content_vectors.py +60 -0
  80. sma/index/inverted.py +63 -0
  81. sma/index/macfac.py +174 -0
  82. sma/ir/__init__.py +22 -0
  83. sma/ir/canon.py +106 -0
  84. sma/ir/schema.py +165 -0
  85. sma/ir/sexpr.py +86 -0
  86. sma/ir/signatures.py +76 -0
  87. sma/match/__init__.py +20 -0
  88. sma/match/conflicts.py +46 -0
  89. sma/match/engine.py +60 -0
  90. sma/match/explain.py +59 -0
  91. sma/match/infer.py +54 -0
  92. sma/match/kernels.py +54 -0
  93. sma/match/mdl.py +30 -0
  94. sma/match/merge_cpsat.py +77 -0
  95. sma/match/merge_greedy.py +15 -0
  96. sma/match/mh.py +177 -0
  97. sma/match/ses.py +84 -0
  98. sma/match/types.py +115 -0
  99. sma/match/verifier.py +27 -0
  100. sma/ontology/__init__.py +45 -0
  101. sma/ontology/attack.py +134 -0
  102. sma/ontology/cpc.py +69 -0
  103. sma/ontology/graph.py +58 -0
  104. sma/ontology/loader.py +262 -0
  105. sma/ontology/mitre_xml.py +67 -0
  106. sma/ontology/mount.py +101 -0
  107. sma/ontology/rdf_loader.py +75 -0
  108. sma/ontology/registry.py +115 -0
  109. sma/ontology/router.py +69 -0
  110. sma/ontology/usgaap.py +73 -0
  111. sma/sage/__init__.py +6 -0
  112. sma/sage/assimilate.py +12 -0
  113. sma/sage/pools.py +105 -0
  114. sma/sage/probabilities.py +10 -0
  115. sma/store/__init__.py +6 -0
  116. sma/store/lmdb_store.py +78 -0
  117. sma/store/registry.py +26 -0
  118. sma/store/wal.py +26 -0
  119. sma/ui/app.py +642 -0
  120. structuremappingmemory-1.0.0.dist-info/METADATA +190 -0
  121. structuremappingmemory-1.0.0.dist-info/RECORD +125 -0
  122. structuremappingmemory-1.0.0.dist-info/WHEEL +5 -0
  123. structuremappingmemory-1.0.0.dist-info/entry_points.txt +2 -0
  124. structuremappingmemory-1.0.0.dist-info/licenses/LICENSE +204 -0
  125. structuremappingmemory-1.0.0.dist-info/top_level.txt +1 -0
sma/eval/bugsinpy.py ADDED
@@ -0,0 +1,358 @@
1
+ """BugsInPy metadata parsing, patch-structure extraction, and case assembly.
2
+
3
+ Blueprint T3 (bug-fix memory): a case is the deterministic structure of a
4
+ buggy program state -- the unified diff of the fix (which files/functions
5
+ were touched, what calls/keywords/exceptions were added or removed, how big
6
+ the change is) plus the failing-test context from ``run_test.sh``. Everything
7
+ here is Tier-0: regex/diff parsing only, no models.
8
+
9
+ Layout of the dataset (github.com/soarsmu/BugsInPy, cloned into
10
+ ``data/raw/bugsinpy``)::
11
+
12
+ projects/<project>/bugs/<id>/bug.info # commit ids, test_file
13
+ projects/<project>/bugs/<id>/bug_patch.txt # unified diff buggy->fixed
14
+ projects/<project>/bugs/<id>/run_test.sh # failing test invocation
15
+ """
16
+
17
+ from __future__ import annotations
18
+
19
+ import pathlib
20
+ import re
21
+ from collections import Counter
22
+ from dataclasses import dataclass, field
23
+
24
+ from sma.encoders import get_encoder
25
+ from sma.ir.schema import Case, Statement, entity, make_case, stmt
26
+
27
+ # ---------------------------------------------------------------------------
28
+ # Discovery / loading
29
+ # ---------------------------------------------------------------------------
30
+
31
+
32
+ def discover_bug_metadata(root: str | pathlib.Path) -> list[pathlib.Path]:
33
+ return sorted(pathlib.Path(root).glob("projects/*/bugs/*/bug.info"))
34
+
35
+
36
+ @dataclass
37
+ class BugRecord:
38
+ project: str
39
+ bug_id: str
40
+ patch_text: str
41
+ test_files: tuple[str, ...] = ()
42
+ test_names: tuple[str, ...] = ()
43
+
44
+ @property
45
+ def key(self) -> str:
46
+ return f"{self.project}/{self.bug_id}"
47
+
48
+
49
+ def _parse_bug_info(path: pathlib.Path) -> dict[str, str]:
50
+ info: dict[str, str] = {}
51
+ for line in path.read_text(encoding="utf-8", errors="replace").splitlines():
52
+ m = re.match(r'^(\w+)="(.*)"\s*$', line.strip())
53
+ if m:
54
+ info[m.group(1)] = m.group(2)
55
+ return info
56
+
57
+
58
+ def _parse_run_test(path: pathlib.Path) -> tuple[str, ...]:
59
+ """Extract failing-test function names from run_test.sh (deterministic)."""
60
+ if not path.exists():
61
+ return ()
62
+ names: list[str] = []
63
+ for line in path.read_text(encoding="utf-8", errors="replace").splitlines():
64
+ line = line.strip()
65
+ if not line or line.startswith("#"):
66
+ continue
67
+ # pytest path/to/file.py::Class::test_name (possibly several)
68
+ for token in line.split():
69
+ if "::" in token:
70
+ names.append(token.split("::")[-1])
71
+ # python -m unittest [-q] pkg.mod.Class.test_name
72
+ if "unittest" in line and "::" not in line:
73
+ tail = line.split()[-1]
74
+ comp = tail.split(".")[-1]
75
+ if comp.startswith("test"):
76
+ names.append(comp)
77
+ out, seen = [], set()
78
+ for n in names:
79
+ n = re.sub(r"\[.*\]$", "", n) # strip parametrize ids
80
+ if n and n not in seen:
81
+ seen.add(n)
82
+ out.append(n)
83
+ return tuple(out)
84
+
85
+
86
+ def load_bugs(root: str | pathlib.Path) -> list[BugRecord]:
87
+ """Load every bug with a non-empty patch, sorted by (project, bug id)."""
88
+ records: list[BugRecord] = []
89
+ for info_path in discover_bug_metadata(root):
90
+ bug_dir = info_path.parent
91
+ project = bug_dir.parent.parent.name
92
+ bug_id = bug_dir.name
93
+ patch_path = bug_dir / "bug_patch.txt"
94
+ if not patch_path.exists():
95
+ continue
96
+ patch_text = patch_path.read_text(encoding="utf-8", errors="replace")
97
+ if not patch_text.strip():
98
+ continue # e.g. keras bug 12 ships an empty patch
99
+ info = _parse_bug_info(info_path)
100
+ test_files = tuple(
101
+ t.strip() for t in info.get("test_file", "").split(";") if t.strip()
102
+ )
103
+ records.append(
104
+ BugRecord(
105
+ project=project,
106
+ bug_id=bug_id,
107
+ patch_text=patch_text,
108
+ test_files=test_files,
109
+ test_names=_parse_run_test(bug_dir / "run_test.sh"),
110
+ )
111
+ )
112
+ records.sort(key=lambda r: (r.project, int(r.bug_id) if r.bug_id.isdigit() else 0))
113
+ return records
114
+
115
+
116
+ # ---------------------------------------------------------------------------
117
+ # Unified-diff parsing
118
+ # ---------------------------------------------------------------------------
119
+
120
+
121
+ @dataclass
122
+ class Hunk:
123
+ file: str
124
+ header: str # text after the second @@ (enclosing-scope context)
125
+ removed: list[str] = field(default_factory=list)
126
+ added: list[str] = field(default_factory=list)
127
+ context: list[str] = field(default_factory=list)
128
+
129
+
130
+ @dataclass
131
+ class PatchFacts:
132
+ files: list[str]
133
+ hunks: list[Hunk]
134
+ functions: list[tuple[str, str]] # (file basename, function) pairs
135
+ added_lines: list[str] # non-blank code lines added (markers stripped)
136
+ removed_lines: list[str]
137
+ exceptions: list[str] # CamelCase *Error/*Exception/*Warning mentioned
138
+
139
+ @property
140
+ def n_added(self) -> int:
141
+ return len(self.added_lines)
142
+
143
+ @property
144
+ def n_removed(self) -> int:
145
+ return len(self.removed_lines)
146
+
147
+
148
+ _EXC_RE = re.compile(r"\b([A-Z][A-Za-z0-9]*(?:Error|Exception|Warning))\b")
149
+ _DEF_RE = re.compile(r"\bdef\s+([A-Za-z_]\w*)")
150
+
151
+
152
+ def parse_patch(patch_text: str) -> PatchFacts:
153
+ files: list[str] = []
154
+ hunks: list[Hunk] = []
155
+ current_file = ""
156
+ hunk: Hunk | None = None
157
+ for raw in patch_text.splitlines():
158
+ if raw.startswith("diff --git"):
159
+ m = re.search(r" b/(\S+)$", raw)
160
+ current_file = m.group(1) if m else raw.split()[-1]
161
+ files.append(current_file)
162
+ hunk = None
163
+ continue
164
+ if raw.startswith("+++") or raw.startswith("---") or raw.startswith("index "):
165
+ continue
166
+ if raw.startswith("@@"):
167
+ m = re.match(r"^@@[^@]*@@\s?(.*)$", raw)
168
+ hunk = Hunk(file=current_file, header=m.group(1) if m else "")
169
+ hunks.append(hunk)
170
+ continue
171
+ if hunk is None:
172
+ continue
173
+ if raw.startswith("+"):
174
+ line = raw[1:]
175
+ if line.strip():
176
+ hunk.added.append(line)
177
+ elif raw.startswith("-"):
178
+ line = raw[1:]
179
+ if line.strip():
180
+ hunk.removed.append(line)
181
+ else:
182
+ line = raw[1:] if raw.startswith(" ") else raw
183
+ if line.strip():
184
+ hunk.context.append(line)
185
+
186
+ added = [l for h in hunks for l in h.added]
187
+ removed = [l for h in hunks for l in h.removed]
188
+ context = [l for h in hunks for l in h.context]
189
+
190
+ functions: list[tuple[str, str]] = []
191
+ seen_fn: set[tuple[str, str]] = set()
192
+ for h in hunks:
193
+ base = pathlib.PurePosixPath(h.file).name
194
+ names = _DEF_RE.findall(h.header)
195
+ # defs appearing inside the changed lines are also "modified functions"
196
+ for l in h.removed + h.added:
197
+ names.extend(_DEF_RE.findall(l))
198
+ if not names:
199
+ names = ["<module>"]
200
+ for name in names:
201
+ key = (base, name)
202
+ if key not in seen_fn:
203
+ seen_fn.add(key)
204
+ functions.append(key)
205
+
206
+ exceptions = sorted(
207
+ {e for l in added + context for e in _EXC_RE.findall(l)}
208
+ )
209
+ return PatchFacts(
210
+ files=files,
211
+ hunks=hunks,
212
+ functions=functions,
213
+ added_lines=added,
214
+ removed_lines=removed,
215
+ exceptions=exceptions,
216
+ )
217
+
218
+
219
+ # ---------------------------------------------------------------------------
220
+ # Case assembly (Tier-0)
221
+ # ---------------------------------------------------------------------------
222
+
223
+ _PY_KEYWORDS = {
224
+ "if", "elif", "else", "for", "while", "try", "except", "finally", "raise",
225
+ "return", "with", "assert", "not", "and", "or", "in", "is", "def", "class",
226
+ "lambda", "yield", "del", "pass", "import", "from", "as", "global", "print",
227
+ }
228
+ _CALL_RE = re.compile(r"([A-Za-z_][\w\.]*)\s*\(")
229
+ _KEYWORD_TOKENS = (
230
+ "if", "elif", "else", "for", "while", "try", "except", "finally", "raise",
231
+ "return", "with", "assert", "not", "and", "or", "in", "is", "None",
232
+ )
233
+ _KEYWORD_RES = {k: re.compile(rf"\b{k}\b") for k in _KEYWORD_TOKENS}
234
+
235
+ _MAX_CALLS = 12
236
+ _MAX_EXCEPTIONS = 8
237
+ _MAX_FUNCTIONS = 16
238
+ _MAX_ENCODER_STMTS = 24
239
+
240
+
241
+ def _calls(lines: list[str]) -> Counter[str]:
242
+ counts: Counter[str] = Counter()
243
+ for line in lines:
244
+ for name in _CALL_RE.findall(line):
245
+ tail = name.split(".")[-1]
246
+ if tail in _PY_KEYWORDS or name in _PY_KEYWORDS:
247
+ continue
248
+ counts[tail] += 1
249
+ return counts
250
+
251
+
252
+ def _keyword_counts(lines: list[str]) -> Counter[str]:
253
+ counts: Counter[str] = Counter()
254
+ for line in lines:
255
+ code = line.split("#", 1)[0]
256
+ for kw, rx in _KEYWORD_RES.items():
257
+ counts[kw] += len(rx.findall(code))
258
+ return counts
259
+
260
+
261
+ def size_bucket(n: int) -> str:
262
+ if n <= 0:
263
+ return "0"
264
+ if n == 1:
265
+ return "1"
266
+ if n <= 3:
267
+ return "2-3"
268
+ if n <= 7:
269
+ return "4-7"
270
+ if n <= 15:
271
+ return "8-15"
272
+ if n <= 31:
273
+ return "16-31"
274
+ return "32plus"
275
+
276
+
277
+ def _encoder_statements(added_lines: list[str]) -> list[Statement]:
278
+ """Run the existing code adapter on the added hunks (AST if it parses,
279
+ regex fallback otherwise); keep its non-placeholder statements."""
280
+ block = "\n".join(l.strip() for l in added_lines)
281
+ if not block.strip():
282
+ return []
283
+ encoder = get_encoder("code")
284
+ result = encoder.encode(block, language="python")
285
+ functors = {s.functor for s in result.case.statements}
286
+ if "syntaxError" in functors: # hunk fragments rarely parse; regex fallback
287
+ result = encoder.encode(block, language="text")
288
+ keep = [
289
+ s
290
+ for s in result.case.statements
291
+ if s.functor not in {"syntaxError", "emptyCode", "rawCode"}
292
+ ]
293
+ return keep[:_MAX_ENCODER_STMTS]
294
+
295
+
296
+ def bug_case(record: BugRecord, facts: PatchFacts | None = None) -> Case:
297
+ """Deterministic Tier-0 case for one bug: diff structure + failing test."""
298
+ facts = facts or parse_patch(record.patch_text)
299
+ statements: set[Statement] = set()
300
+
301
+ for base, fn in facts.functions[:_MAX_FUNCTIONS]:
302
+ statements.add(stmt("modifies", entity(base, "file"), entity(fn, "function")))
303
+ statements.add(stmt("addsLines", entity(size_bucket(facts.n_added), "count")))
304
+ statements.add(stmt("removesLines", entity(size_bucket(facts.n_removed), "count")))
305
+ statements.add(
306
+ stmt("touchesFiles", entity(size_bucket(len(facts.files)), "count"))
307
+ )
308
+ if any("test" in f.lower() for f in facts.files):
309
+ statements.add(stmt("touchesTestFile", entity("patch", "scope")))
310
+
311
+ for exc in facts.exceptions[:_MAX_EXCEPTIONS]:
312
+ statements.add(stmt("mentionsException", entity(exc, "exception")))
313
+
314
+ added_calls = _calls(facts.added_lines)
315
+ removed_calls = _calls(facts.removed_lines)
316
+ net_added = sorted(
317
+ (c for c in added_calls if added_calls[c] > removed_calls.get(c, 0)),
318
+ key=lambda c: (-added_calls[c], c),
319
+ )[:_MAX_CALLS]
320
+ net_removed = sorted(
321
+ (c for c in removed_calls if removed_calls[c] > added_calls.get(c, 0)),
322
+ key=lambda c: (-removed_calls[c], c),
323
+ )[:_MAX_CALLS]
324
+ for name in net_added:
325
+ statements.add(stmt("addsCall", entity("patch", "scope"), entity(name, "callable")))
326
+ for name in net_removed:
327
+ statements.add(stmt("removesCall", entity("patch", "scope"), entity(name, "callable")))
328
+
329
+ kw_added = _keyword_counts(facts.added_lines)
330
+ kw_removed = _keyword_counts(facts.removed_lines)
331
+ for kw in _KEYWORD_TOKENS:
332
+ if kw_added[kw] > kw_removed[kw]:
333
+ statements.add(stmt("addsKeyword", entity("patch", "scope"), entity(kw, "keyword")))
334
+ elif kw_removed[kw] > kw_added[kw]:
335
+ statements.add(stmt("removesKeyword", entity("patch", "scope"), entity(kw, "keyword")))
336
+
337
+ for name in record.test_names[:4]:
338
+ statements.add(stmt("failingTest", entity(name, "test")))
339
+ for tf in record.test_files[:4]:
340
+ statements.add(
341
+ stmt("testModule", entity(pathlib.PurePosixPath(tf).name, "file"))
342
+ )
343
+
344
+ statements.update(_encoder_statements(facts.added_lines))
345
+
346
+ if not statements:
347
+ statements.add(stmt("emptyPatch", entity(record.key, "bug")))
348
+ return make_case(
349
+ sorted(statements, key=repr),
350
+ {
351
+ "adapter": "code",
352
+ "tier": 0,
353
+ "dataset": "bugsinpy",
354
+ "project": record.project,
355
+ "bug": record.bug_id,
356
+ },
357
+ case_id=f"bugsinpy:{record.key}",
358
+ )
@@ -0,0 +1,164 @@
1
+ """Deterministic fix-pattern categories for BugsInPy patches.
2
+
3
+ Categories are assigned by ORDERED rules over the unified diff (first match
4
+ wins). They are the ground-truth labels for the T3 retrieval metric
5
+ (fix-category-hit@k): did retrieval surface a past bug fixed by the SAME
6
+ KIND of change, not just a textually similar file?
7
+
8
+ Rule order (per the T3 specification):
9
+ 1. add-null-check adds ``is None`` / ``is not None`` / ``if not x``
10
+ 2. exception-handling adds ``try:`` / ``except`` / ``raise``
11
+ 3. boundary comparison-operator swap or +/-1 on an otherwise
12
+ identical line (off-by-one / boundary fixes)
13
+ 4. type-coercion adds a builtin cast (int()/str()/list()/...)
14
+ 5. api-substitution small hunk replaces one call with another
15
+ 6. default-arg-change a ``def`` signature line changes its defaults/args
16
+ 7. condition-strengthening an if/elif/while gains and/or clauses
17
+ 8. other
18
+
19
+ "Net added" means the pattern occurs more often in added lines than in
20
+ removed lines, so pure code motion does not trigger a rule.
21
+ """
22
+
23
+ from __future__ import annotations
24
+
25
+ import re
26
+
27
+ from .bugsinpy import Hunk, PatchFacts, _calls
28
+
29
+ CATEGORIES = (
30
+ "add-null-check",
31
+ "exception-handling",
32
+ "boundary",
33
+ "type-coercion",
34
+ "api-substitution",
35
+ "default-arg-change",
36
+ "condition-strengthening",
37
+ "other",
38
+ )
39
+
40
+ _NULL_CHECK = re.compile(r"\bis\s+(?:not\s+)?None\b|^\s*(?:el)?if\s+not\s+\w")
41
+ _EXC_HANDLING = re.compile(r"^\s*(?:try\s*:|except[\s:(]|raise\b)")
42
+ _CAST = re.compile(
43
+ r"(?<![\w.])(?:int|str|float|bool|list|tuple|set|dict|frozenset|bytes)\(|\.astype\("
44
+ )
45
+ _CMP_OPS = re.compile(r"<=|>=|==|!=|<|>")
46
+ _PM_ONE = re.compile(r"[+-]\s*1\b")
47
+ _COND_LINE = re.compile(r"^\s*((?:el)?if|while)\b")
48
+ _BOOL_OP = re.compile(r"\b(?:and|or)\b")
49
+ _DEF_LINE = re.compile(r"^\s*def\s+([A-Za-z_]\w*)\s*\((.*)$")
50
+
51
+
52
+ def _strip_code(line: str) -> str:
53
+ return line.split("#", 1)[0]
54
+
55
+
56
+ def _net(pattern: re.Pattern[str], added: list[str], removed: list[str]) -> bool:
57
+ n_add = sum(len(pattern.findall(_strip_code(l))) for l in added)
58
+ n_rem = sum(len(pattern.findall(_strip_code(l))) for l in removed)
59
+ return n_add > n_rem
60
+
61
+
62
+ def _paired_swap(hunk: Hunk, pattern: re.Pattern[str]) -> bool:
63
+ """True if some removed/added line pair is identical once `pattern`
64
+ occurrences are masked out, while the raw lines differ (i.e. the ONLY
65
+ change is in the matched operator/constant)."""
66
+ def norm(line: str) -> str:
67
+ return re.sub(r"\s+", " ", pattern.sub("\x00", _strip_code(line))).strip()
68
+
69
+ removed = {norm(l): _strip_code(l).strip() for l in hunk.removed if pattern.search(_strip_code(l))}
70
+ for line in hunk.added:
71
+ code = _strip_code(line)
72
+ if not pattern.search(code):
73
+ continue
74
+ key = norm(line)
75
+ if key in removed and removed[key] != code.strip() and "\x00" in key:
76
+ return True
77
+ return False
78
+
79
+
80
+ def _is_null_check(facts: PatchFacts) -> bool:
81
+ return _net(_NULL_CHECK, facts.added_lines, facts.removed_lines)
82
+
83
+
84
+ def _is_exception_handling(facts: PatchFacts) -> bool:
85
+ return _net(_EXC_HANDLING, facts.added_lines, facts.removed_lines)
86
+
87
+
88
+ def _is_boundary(facts: PatchFacts) -> bool:
89
+ for hunk in facts.hunks:
90
+ if _paired_swap(hunk, _CMP_OPS) or _paired_swap(hunk, _PM_ONE):
91
+ return True
92
+ return False
93
+
94
+
95
+ def _is_type_coercion(facts: PatchFacts) -> bool:
96
+ return _net(_CAST, facts.added_lines, facts.removed_lines)
97
+
98
+
99
+ def _is_api_substitution(facts: PatchFacts) -> bool:
100
+ for hunk in facts.hunks:
101
+ if not (1 <= len(hunk.removed) <= 3 and 1 <= len(hunk.added) <= 3):
102
+ continue
103
+ removed_calls = set(_calls(hunk.removed))
104
+ added_calls = set(_calls(hunk.added))
105
+ if (removed_calls - added_calls) and (added_calls - removed_calls):
106
+ return True
107
+ return False
108
+
109
+
110
+ def _is_default_arg_change(facts: PatchFacts) -> bool:
111
+ for hunk in facts.hunks:
112
+ removed_defs = {}
113
+ for line in hunk.removed:
114
+ m = _DEF_LINE.match(_strip_code(line))
115
+ if m:
116
+ removed_defs[m.group(1)] = m.group(2).strip()
117
+ for line in hunk.added:
118
+ m = _DEF_LINE.match(_strip_code(line))
119
+ if not m:
120
+ continue
121
+ name, args = m.group(1), m.group(2).strip()
122
+ old = removed_defs.get(name)
123
+ if old is not None and old != args and ("=" in old or "=" in args):
124
+ return True
125
+ return False
126
+
127
+
128
+ def _is_condition_strengthening(facts: PatchFacts) -> bool:
129
+ for hunk in facts.hunks:
130
+ removed_conds: dict[str, int] = {}
131
+ for line in hunk.removed:
132
+ m = _COND_LINE.match(_strip_code(line))
133
+ if m:
134
+ kw = m.group(1)
135
+ n = len(_BOOL_OP.findall(_strip_code(line)))
136
+ removed_conds[kw] = max(removed_conds.get(kw, -1), n)
137
+ for line in hunk.added:
138
+ m = _COND_LINE.match(_strip_code(line))
139
+ if not m:
140
+ continue
141
+ kw = m.group(1)
142
+ n = len(_BOOL_OP.findall(_strip_code(line)))
143
+ if kw in removed_conds and n > removed_conds[kw]:
144
+ return True
145
+ return False
146
+
147
+
148
+ _RULES = (
149
+ ("add-null-check", _is_null_check),
150
+ ("exception-handling", _is_exception_handling),
151
+ ("boundary", _is_boundary),
152
+ ("type-coercion", _is_type_coercion),
153
+ ("api-substitution", _is_api_substitution),
154
+ ("default-arg-change", _is_default_arg_change),
155
+ ("condition-strengthening", _is_condition_strengthening),
156
+ )
157
+
158
+
159
+ def categorize(facts: PatchFacts) -> str:
160
+ """Ordered first-match-wins fix-pattern category for one patch."""
161
+ for name, rule in _RULES:
162
+ if rule(facts):
163
+ return name
164
+ return "other"
@@ -0,0 +1,89 @@
1
+ """Shared cross-domain arm evaluator (4b).
2
+
3
+ Guarantees an IDENTICAL index/query partition across arms (generic / drafted /
4
+ expert) by splitting on the original record order — NOT on the encoding-dependent
5
+ case_id — so the baselines are stable and the only thing that varies between arms
6
+ is the SMA encoding. Returns the result dict and writes cd_<domain>_<phase>.csv.
7
+ """
8
+ from __future__ import annotations
9
+
10
+ import csv
11
+ import pathlib
12
+ import time
13
+
14
+ from sma.index.macfac import MacFacIndex
15
+ from sma.ir.schema import Statement
16
+ from sma.eval.baselines.bm25 import rank_bm25_like
17
+ from sma.eval.baselines.dense import rank_tfidf_dense_batch
18
+ from sma.eval.metrics import macro_f1
19
+ from sma.eval.stats import paired_bootstrap, holm_bonferroni, cliffs_delta
20
+
21
+ OUT = pathlib.Path("reports/confirmatory")
22
+
23
+
24
+ def _ho_density(case) -> float:
25
+ ex = case.expressions()
26
+ ho = sum(1 for s in ex if any(isinstance(a, Statement) for a in s.args))
27
+ return ho / max(len(ex), 1)
28
+
29
+
30
+ def _vote(ranked_ids, label_of, labels):
31
+ pos, neg = labels
32
+ tally = {pos: 0, neg: 0}
33
+ for cid in ranked_ids:
34
+ tally[label_of[cid]] += 1
35
+ return pos if tally[pos] >= tally[neg] else neg
36
+
37
+
38
+ def evaluate_arm(items, encode_fn, row_text_fn, labels, phase, domain,
39
+ k=10, frac=0.7):
40
+ """items: list of records (each has .label). encode_fn(item)->Case.
41
+ row_text_fn(item)->str for the baselines. labels=(pos,neg)."""
42
+ index_n = int(len(items) * frac)
43
+ index_items, query_items = items[:index_n], items[index_n:] # FIXED split
44
+
45
+ def enc(its):
46
+ return [(it, encode_fn(it)) for it in its]
47
+ idx, qry = enc(index_items), enc(query_items)
48
+ label_of = {c.case_id: it.label for it, c in idx + qry}
49
+ text_of = {c.case_id: row_text_fn(it) for it, c in idx + qry}
50
+ dens = [_ho_density(c) for _, c in idx + qry]
51
+ mean_ho = sum(dens) / len(dens)
52
+
53
+ index_cases = [c for _, c in idx]
54
+ index_docs = [(c.case_id, text_of[c.case_id]) for _, c in idx]
55
+ t0 = time.perf_counter()
56
+ sma_index = MacFacIndex(); sma_index.build(index_cases)
57
+ gold, sma_p, bm_p, dn_p = [], [], [], []
58
+ dense_rk = rank_tfidf_dense_batch([text_of[c.case_id] for _, c in qry], index_docs, k=k)
59
+ for qi, (_, qc) in enumerate(qry):
60
+ gold.append(label_of[qc.case_id])
61
+ res = sma_index.retrieve(qc, k=k, shortlist=60, fac_budget=25)
62
+ sma_p.append(_vote([r.case_id for r in res], label_of, labels))
63
+ bm = rank_bm25_like(text_of[qc.case_id], index_docs, k=k)
64
+ bm_p.append(_vote([cid for cid, _ in bm], label_of, labels))
65
+ dn_p.append(_vote([cid for cid, _ in dense_rk[qi]], label_of, labels))
66
+ dt = time.perf_counter() - t0
67
+
68
+ f1 = {"SMA": macro_f1(gold, sma_p), "BM25": macro_f1(gold, bm_p), "Dense": macro_f1(gold, dn_p)}
69
+ print(f"[{domain}/{phase}] HO-density={mean_ho:.4f} macro-F1: SMA {f1['SMA']:.4f} "
70
+ f"BM25 {f1['BM25']:.4f} Dense {f1['Dense']:.4f} ({dt:.0f}s)", flush=True)
71
+
72
+ def correct(p):
73
+ return [1.0 if a == g else 0.0 for a, g in zip(p, gold)]
74
+ sma_c = correct(sma_p); pv, summ = {}, []
75
+ for name, pred in (("BM25", bm_p), ("Dense", dn_p)):
76
+ bs = paired_bootstrap(sma_c, correct(pred)); pv[name] = bs["p_value"]
77
+ summ.append({"phase": phase, "domain": domain, "baseline": name,
78
+ "ho_density": f"{mean_ho:.4f}", "sma_f1": f"{f1['SMA']:.4f}",
79
+ "baseline_f1": f"{f1[name]:.4f}", "delta": f"{bs['delta']:.4f}",
80
+ "ci_low": f"{bs['ci_low']:.4f}", "ci_high": f"{bs['ci_high']:.4f}",
81
+ "cliffs": f"{cliffs_delta(sma_c, correct(pred)):.4f}"})
82
+ holm = holm_bonferroni(pv)
83
+ for s in summ:
84
+ s["p_holm"] = f"{holm[s['baseline']]:.4f}"
85
+ out = OUT / f"cd_{domain}_{phase}.csv"
86
+ out.parent.mkdir(parents=True, exist_ok=True)
87
+ with out.open("w", newline="") as fh:
88
+ w = csv.DictWriter(fh, fieldnames=list(summ[0])); w.writeheader(); w.writerows(summ)
89
+ return {"f1": f1, "ho": mean_ho, "summ": summ}
sma/eval/diabetes.py ADDED
@@ -0,0 +1,61 @@
1
+ """Diabetes-130 loader + per-encounter artifact builder (4b cross-domain).
2
+
3
+ Real UCI EHR. We build a flat CSV-row artifact for the GENERIC structured
4
+ adapter (the honest 'before': flat triples, no higher-order relations) and a
5
+ plain attr=value text for the lexical/dense baselines. The readmission label is
6
+ NEVER encoded (leakage guard).
7
+ """
8
+ from __future__ import annotations
9
+
10
+ import csv
11
+ import pathlib
12
+ import random
13
+ from dataclasses import dataclass
14
+
15
+ # ids, mostly-missing columns, and the LABEL are excluded from the encoding.
16
+ DROP = {"encounter_id", "patient_nbr", "weight", "payer_code", "readmitted"}
17
+
18
+
19
+ @dataclass(frozen=True)
20
+ class Encounter:
21
+ eid: str
22
+ fields: dict[str, str] # cleaned attribute -> value (no '?'/'' )
23
+ label: str # "early" (readmitted <30 days) vs "not"
24
+
25
+
26
+ def _csv_path() -> pathlib.Path:
27
+ return next(pathlib.Path("data/raw/diabetes130").rglob("diabetic_data.csv"))
28
+
29
+
30
+ def load_encounters(sample: int = 1500, seed: int = 7, balanced: bool = True) -> list[Encounter]:
31
+ rows = list(csv.DictReader(_csv_path().open()))
32
+ rng = random.Random(seed)
33
+ rng.shuffle(rows)
34
+
35
+ def to_enc(r: dict) -> Encounter:
36
+ fields = {k: v for k, v in r.items() if k not in DROP and v not in ("?", "")}
37
+ label = "early" if r["readmitted"] == "<30" else "not"
38
+ return Encounter(r["encounter_id"], fields, label)
39
+
40
+ encs = [to_enc(r) for r in rows]
41
+ if not balanced:
42
+ return encs[:sample]
43
+ # balance the two classes so retrieval-by-analogy has signal (early is ~11%)
44
+ early = [e for e in encs if e.label == "early"]
45
+ notr = [e for e in encs if e.label == "not"]
46
+ half = sample // 2
47
+ out = early[:half] + notr[:half]
48
+ rng.shuffle(out)
49
+ return out
50
+
51
+
52
+ def row_csv(enc: Encounter) -> str:
53
+ """Two-line CSV (header + one row) for the structured adapter -> flat
54
+ (attribute row value) triples."""
55
+ keys = sorted(enc.fields)
56
+ return ",".join(keys) + "\n" + ",".join(enc.fields[k] for k in keys) + "\n"
57
+
58
+
59
+ def row_text(enc: Encounter) -> str:
60
+ """attr=value text for BM25 / dense baselines."""
61
+ return " ".join(f"{k}={v}" for k, v in sorted(enc.fields.items()))