structuremappingmemory 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sma/__init__.py +5 -0
- sma/__main__.py +5 -0
- sma/agent/__init__.py +5 -0
- sma/agent/adapter_draft.py +217 -0
- sma/agent/api.py +67 -0
- sma/agent/comparison.py +591 -0
- sma/agent/llm.py +280 -0
- sma/agent/policies.py +21 -0
- sma/agent/service.py +95 -0
- sma/cli.py +65 -0
- sma/encoders/__init__.py +38 -0
- sma/encoders/agentobs.py +27 -0
- sma/encoders/base.py +23 -0
- sma/encoders/code_treesitter.py +64 -0
- sma/encoders/coverage.py +80 -0
- sma/encoders/draft_adapter.py +183 -0
- sma/encoders/healthcare.py +207 -0
- sma/encoders/logs_drain.py +142 -0
- sma/encoders/prose_tier1.py +57 -0
- sma/encoders/structured.py +57 -0
- sma/encoders/traces.py +45 -0
- sma/eval/__init__.py +2 -0
- sma/eval/agentic/__init__.py +35 -0
- sma/eval/agentic/arms/__init__.py +0 -0
- sma/eval/agentic/arms/cyber.py +48 -0
- sma/eval/agentic/arms/discovery.py +35 -0
- sma/eval/agentic/arms/finance.py +38 -0
- sma/eval/agentic/arms/legal.py +74 -0
- sma/eval/agentic/arms/medicine.py +45 -0
- sma/eval/agentic/harness.py +275 -0
- sma/eval/agentic/memories.py +308 -0
- sma/eval/agentic/metrics.py +82 -0
- sma/eval/agentic_qa/__init__.py +27 -0
- sma/eval/agentic_qa/agent.py +383 -0
- sma/eval/agentic_qa/metrics.py +239 -0
- sma/eval/agentic_qa/pools.py +197 -0
- sma/eval/arn.py +65 -0
- sma/eval/baselines/__init__.py +6 -0
- sma/eval/baselines/bge_dense.py +54 -0
- sma/eval/baselines/bm25.py +18 -0
- sma/eval/baselines/dense.py +42 -0
- sma/eval/baselines/hipporag.py +235 -0
- sma/eval/baselines/hybrid_rrf.py +30 -0
- sma/eval/baselines/longcontext_llm.py +124 -0
- sma/eval/baselines/rerank.py +41 -0
- sma/eval/baselines/splade.py +77 -0
- sma/eval/baselines/wl_kernel.py +163 -0
- sma/eval/bugsinpy.py +358 -0
- sma/eval/bugsinpy_families.py +164 -0
- sma/eval/crossdomain.py +89 -0
- sma/eval/diabetes.py +61 -0
- sma/eval/drift_env.py +26 -0
- sma/eval/drift_metrics.py +24 -0
- sma/eval/family_labels.py +167 -0
- sma/eval/fraud_elliptic/__init__.py +29 -0
- sma/eval/fraud_elliptic/encoder.py +279 -0
- sma/eval/fraud_elliptic/eval.py +269 -0
- sma/eval/fraud_elliptic/test_encoder.py +123 -0
- sma/eval/ieee_cis.py +66 -0
- sma/eval/loghub.py +16 -0
- sma/eval/loghub_eval.py +480 -0
- sma/eval/longmemeval.py +51 -0
- sma/eval/memory_backends/__init__.py +2 -0
- sma/eval/memory_backends/base.py +22 -0
- sma/eval/memory_backends/context_only.py +14 -0
- sma/eval/memory_backends/rag_notes.py +17 -0
- sma/eval/memory_backends/shared_llm.py +30 -0
- sma/eval/memory_backends/sma_memory.py +54 -0
- sma/eval/memory_backends/zep_graphiti.py +33 -0
- sma/eval/metrics.py +32 -0
- sma/eval/ontology_bench.py +219 -0
- sma/eval/report.py +573 -0
- sma/eval/ssb_eval.py +216 -0
- sma/eval/ssb_generator.py +116 -0
- sma/eval/stats.py +108 -0
- sma/eval/transfer_eval.py +844 -0
- sma/index/__init__.py +15 -0
- sma/index/ann.py +21 -0
- sma/index/content_vectors.py +60 -0
- sma/index/inverted.py +63 -0
- sma/index/macfac.py +174 -0
- sma/ir/__init__.py +22 -0
- sma/ir/canon.py +106 -0
- sma/ir/schema.py +165 -0
- sma/ir/sexpr.py +86 -0
- sma/ir/signatures.py +76 -0
- sma/match/__init__.py +20 -0
- sma/match/conflicts.py +46 -0
- sma/match/engine.py +60 -0
- sma/match/explain.py +59 -0
- sma/match/infer.py +54 -0
- sma/match/kernels.py +54 -0
- sma/match/mdl.py +30 -0
- sma/match/merge_cpsat.py +77 -0
- sma/match/merge_greedy.py +15 -0
- sma/match/mh.py +177 -0
- sma/match/ses.py +84 -0
- sma/match/types.py +115 -0
- sma/match/verifier.py +27 -0
- sma/ontology/__init__.py +45 -0
- sma/ontology/attack.py +134 -0
- sma/ontology/cpc.py +69 -0
- sma/ontology/graph.py +58 -0
- sma/ontology/loader.py +262 -0
- sma/ontology/mitre_xml.py +67 -0
- sma/ontology/mount.py +101 -0
- sma/ontology/rdf_loader.py +75 -0
- sma/ontology/registry.py +115 -0
- sma/ontology/router.py +69 -0
- sma/ontology/usgaap.py +73 -0
- sma/sage/__init__.py +6 -0
- sma/sage/assimilate.py +12 -0
- sma/sage/pools.py +105 -0
- sma/sage/probabilities.py +10 -0
- sma/store/__init__.py +6 -0
- sma/store/lmdb_store.py +78 -0
- sma/store/registry.py +26 -0
- sma/store/wal.py +26 -0
- sma/ui/app.py +642 -0
- structuremappingmemory-1.0.0.dist-info/METADATA +190 -0
- structuremappingmemory-1.0.0.dist-info/RECORD +125 -0
- structuremappingmemory-1.0.0.dist-info/WHEEL +5 -0
- structuremappingmemory-1.0.0.dist-info/entry_points.txt +2 -0
- structuremappingmemory-1.0.0.dist-info/licenses/LICENSE +204 -0
- structuremappingmemory-1.0.0.dist-info/top_level.txt +1 -0
sma/eval/bugsinpy.py
ADDED
|
@@ -0,0 +1,358 @@
|
|
|
1
|
+
"""BugsInPy metadata parsing, patch-structure extraction, and case assembly.
|
|
2
|
+
|
|
3
|
+
Blueprint T3 (bug-fix memory): a case is the deterministic structure of a
|
|
4
|
+
buggy program state -- the unified diff of the fix (which files/functions
|
|
5
|
+
were touched, what calls/keywords/exceptions were added or removed, how big
|
|
6
|
+
the change is) plus the failing-test context from ``run_test.sh``. Everything
|
|
7
|
+
here is Tier-0: regex/diff parsing only, no models.
|
|
8
|
+
|
|
9
|
+
Layout of the dataset (github.com/soarsmu/BugsInPy, cloned into
|
|
10
|
+
``data/raw/bugsinpy``)::
|
|
11
|
+
|
|
12
|
+
projects/<project>/bugs/<id>/bug.info # commit ids, test_file
|
|
13
|
+
projects/<project>/bugs/<id>/bug_patch.txt # unified diff buggy->fixed
|
|
14
|
+
projects/<project>/bugs/<id>/run_test.sh # failing test invocation
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
from __future__ import annotations
|
|
18
|
+
|
|
19
|
+
import pathlib
|
|
20
|
+
import re
|
|
21
|
+
from collections import Counter
|
|
22
|
+
from dataclasses import dataclass, field
|
|
23
|
+
|
|
24
|
+
from sma.encoders import get_encoder
|
|
25
|
+
from sma.ir.schema import Case, Statement, entity, make_case, stmt
|
|
26
|
+
|
|
27
|
+
# ---------------------------------------------------------------------------
|
|
28
|
+
# Discovery / loading
|
|
29
|
+
# ---------------------------------------------------------------------------
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def discover_bug_metadata(root: str | pathlib.Path) -> list[pathlib.Path]:
|
|
33
|
+
return sorted(pathlib.Path(root).glob("projects/*/bugs/*/bug.info"))
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
@dataclass
class BugRecord:
    """One BugsInPy bug: its fix patch plus failing-test context.

    ``test_files`` / ``test_names`` default to empty tuples when the bug
    directory provides no test information.
    """

    project: str
    bug_id: str
    patch_text: str  # raw unified diff of the fix
    test_files: tuple[str, ...] = ()
    test_names: tuple[str, ...] = ()

    @property
    def key(self) -> str:
        """Identifier of the form ``"<project>/<bug_id>"``."""
        return "{}/{}".format(self.project, self.bug_id)
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def _parse_bug_info(path: pathlib.Path) -> dict[str, str]:
    """Parse the ``KEY="value"`` lines of a BugsInPy ``bug.info`` file.

    Lines of any other shape are ignored; when a key repeats, its last
    value wins.
    """
    text = path.read_text(encoding="utf-8", errors="replace")
    matches = (re.match(r'^(\w+)="(.*)"\s*$', raw.strip()) for raw in text.splitlines())
    return dict(hit.groups() for hit in matches if hit is not None)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def _parse_run_test(path: pathlib.Path) -> tuple[str, ...]:
    """Extract failing-test function names from run_test.sh (deterministic).

    Recognised forms (blank lines and ``#`` lines are ignored):

    * pytest node ids ``path/to/file.py::Class::test_name`` -- the last ``::``
      segment of every such token (several per line allowed);
    * ``python -m unittest [-q] pkg.mod.Class.test_name`` on a line without
      ``::`` -- the last dotted part of the final token, if it starts with
      ``test``.

    A trailing ``[...]`` parametrize id is stripped, empty names are dropped,
    and duplicates are removed keeping first-seen order. A missing file
    yields ``()``.
    """
    if not path.exists():
        return ()

    def candidates():
        text = path.read_text(encoding="utf-8", errors="replace")
        for raw in text.splitlines():
            stripped = raw.strip()
            if not stripped or stripped.startswith("#"):
                continue
            tokens = stripped.split()
            # pytest node ids
            yield from (tok.rsplit("::", 1)[-1] for tok in tokens if "::" in tok)
            # unittest dotted path
            if "unittest" in stripped and "::" not in stripped:
                last = tokens[-1].rsplit(".", 1)[-1]
                if last.startswith("test"):
                    yield last

    unparametrized = (re.sub(r"\[.*\]$", "", name) for name in candidates())
    return tuple(dict.fromkeys(name for name in unparametrized if name))
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def load_bugs(root: str | pathlib.Path) -> list[BugRecord]:
    """Load every bug with a non-empty patch, sorted by (project, bug id).

    Bug directories without ``bug_patch.txt``, or whose patch is only
    whitespace, are skipped. ``test_files`` comes from the ``;``-separated
    ``test_file`` entry of ``bug.info``; ``test_names`` from ``run_test.sh``.
    Numeric bug ids sort numerically; non-numeric ids sort as 0 (stable).
    """

    def order(rec):
        return rec.project, (int(rec.bug_id) if rec.bug_id.isdigit() else 0)

    collected: list[BugRecord] = []
    for meta in discover_bug_metadata(root):
        bug_dir = meta.parent
        patch_file = bug_dir / "bug_patch.txt"
        if not patch_file.exists():
            continue
        patch = patch_file.read_text(encoding="utf-8", errors="replace")
        if patch.strip() == "":
            continue  # e.g. keras bug 12 ships an empty patch
        declared = _parse_bug_info(meta).get("test_file", "")
        tests = tuple(part.strip() for part in declared.split(";") if part.strip())
        collected.append(
            BugRecord(
                project=bug_dir.parent.parent.name,
                bug_id=bug_dir.name,
                patch_text=patch,
                test_files=tests,
                test_names=_parse_run_test(bug_dir / "run_test.sh"),
            )
        )
    return sorted(collected, key=order)
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
# ---------------------------------------------------------------------------
|
|
117
|
+
# Unified-diff parsing
|
|
118
|
+
# ---------------------------------------------------------------------------
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
@dataclass
class Hunk:
    """One ``@@`` hunk of a unified diff, with the +/-/space markers stripped.

    Whitespace-only lines are not recorded in any of the three lists.
    """

    file: str  # path from the preceding ``diff --git`` line (its b/ side); "" if none
    header: str  # text after the second @@ (enclosing-scope context)
    removed: list[str] = field(default_factory=list)
    added: list[str] = field(default_factory=list)
    context: list[str] = field(default_factory=list)


@dataclass
class PatchFacts:
    """Deterministic facts extracted from one unified diff by ``parse_patch``."""

    files: list[str]  # one entry per ``diff --git`` line, in patch order
    hunks: list[Hunk]
    functions: list[tuple[str, str]]  # (file basename, function) pairs
    added_lines: list[str]  # non-blank code lines added (markers stripped)
    removed_lines: list[str]
    exceptions: list[str]  # CamelCase *Error/*Exception/*Warning mentioned

    @property
    def n_added(self) -> int:
        """Number of non-blank added lines."""
        return len(self.added_lines)

    @property
    def n_removed(self) -> int:
        """Number of non-blank removed lines."""
        return len(self.removed_lines)


_EXC_RE = re.compile(r"\b([A-Z][A-Za-z0-9]*(?:Error|Exception|Warning))\b")
_DEF_RE = re.compile(r"\bdef\s+([A-Za-z_]\w*)")
# Old/new line counts of a hunk header "@@ -a[,b] +c[,d] @@"; an omitted count means 1.
_HUNK_COUNTS_RE = re.compile(r"^@@ -\d+(?:,(\d+))? \+\d+(?:,(\d+))? @@")


def _record_hunk_line(hunk: Hunk, raw: str) -> str:
    """Append one hunk-body line to ``hunk``; return its kind ('+', '-' or ' ')."""
    if raw.startswith("+"):
        kind, bucket = "+", hunk.added
    elif raw.startswith("-"):
        kind, bucket = "-", hunk.removed
    else:
        kind, bucket = " ", hunk.context
    # A context line may have lost its leading space (blank lines); keep it whole.
    line = raw[1:] if raw[:1] in ("+", "-", " ") else raw
    if line.strip():
        bucket.append(line)
    return kind


def parse_patch(patch_text: str) -> PatchFacts:
    """Parse a unified (git) diff into ``PatchFacts``.

    Hunk bodies are delimited by the line counts in their ``@@`` headers.
    As a result, a removed/added line whose *content* begins with ``--``/``++``,
    or a line such as ``index ...``, is kept as code rather than being mistaken
    for a ``---``/``+++``/``index`` header. ``\\ No newline at end of file``
    markers are skipped.

    If a header carries no parsable counts, or the body contradicts its
    counts, lines are classified by their first character instead, as in
    earlier versions.

    Returns:
        ``PatchFacts`` with the files, hunks and touched functions. A hunk with
        no ``def`` in its header or changed lines is attributed to
        ``"<module>"``. Added/removed lines are flattened across hunks.
        Exception names found in added and context lines are sorted and
        de-duplicated.
    """
    files: list[str] = []
    hunks: list[Hunk] = []
    current_file = ""
    hunk: Hunk | None = None
    old_left = new_left = 0  # body lines the current hunk header still promises
    for raw in patch_text.splitlines():
        if hunk is not None and (old_left > 0 or new_left > 0):
            if raw.startswith("\\"):
                continue  # "\ No newline at end of file" is not a code line
            if raw == "" or raw[0] in "+- ":
                kind = _record_hunk_line(hunk, raw)
                if kind != "+":
                    old_left -= 1
                if kind != "-":
                    new_left -= 1
                continue
            # The header over-promised (malformed patch): stop counting and let
            # the header / fallback handling below deal with this line.
            old_left = new_left = 0
        if raw.startswith("diff --git"):
            m = re.search(r" b/(\S+)$", raw)
            current_file = m.group(1) if m else raw.split()[-1]
            files.append(current_file)
            hunk = None
            continue
        if raw.startswith(("+++", "---", "index ", "\\")):
            continue
        if raw.startswith("@@"):
            m = re.match(r"^@@[^@]*@@\s?(.*)$", raw)
            hunk = Hunk(file=current_file, header=m.group(1) if m else "")
            hunks.append(hunk)
            counts = _HUNK_COUNTS_RE.match(raw)
            if counts:
                old_left = int(counts.group(1) or 1)
                new_left = int(counts.group(2) or 1)
            else:
                old_left = new_left = 0  # unknown length: heuristic mode
            continue
        if hunk is None:
            continue
        _record_hunk_line(hunk, raw)

    added = [line for h in hunks for line in h.added]
    removed = [line for h in hunks for line in h.removed]
    context = [line for h in hunks for line in h.context]

    functions: list[tuple[str, str]] = []
    seen_fn: set[tuple[str, str]] = set()
    for h in hunks:
        base = pathlib.PurePosixPath(h.file).name
        names = _DEF_RE.findall(h.header)
        # defs appearing inside the changed lines are also "modified functions"
        for line in h.removed + h.added:
            names.extend(_DEF_RE.findall(line))
        if not names:
            names = ["<module>"]
        for name in names:
            key = (base, name)
            if key not in seen_fn:
                seen_fn.add(key)
                functions.append(key)

    exceptions = sorted(
        {e for line in added + context for e in _EXC_RE.findall(line)}
    )
    return PatchFacts(
        files=files,
        hunks=hunks,
        functions=functions,
        added_lines=added,
        removed_lines=removed,
        exceptions=exceptions,
    )
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
# ---------------------------------------------------------------------------
|
|
220
|
+
# Case assembly (Tier-0)
|
|
221
|
+
# ---------------------------------------------------------------------------
|
|
222
|
+
|
|
223
|
+
# Names _calls() never counts as a call: Python keywords plus ``print``, so
# ``if (...)`` / ``print(...)`` are not mistaken for function calls.
_PY_KEYWORDS = {
    "if", "elif", "else", "for", "while", "try", "except", "finally", "raise",
    "return", "with", "assert", "not", "and", "or", "in", "is", "def", "class",
    "lambda", "yield", "del", "pass", "import", "from", "as", "global", "print",
}
# A (possibly dotted) identifier immediately followed by "(": a call site.
_CALL_RE = re.compile(r"([A-Za-z_][\w\.]*)\s*\(")
# Tokens whose net added-vs-removed count change becomes an addsKeyword /
# removesKeyword statement in bug_case().
_KEYWORD_TOKENS = (
    "if", "elif", "else", "for", "while", "try", "except", "finally", "raise",
    "return", "with", "assert", "not", "and", "or", "in", "is", "None",
)
# One whole-word regex per token (\b on both sides, so "in" does not match "int").
_KEYWORD_RES = {k: re.compile(rf"\b{k}\b") for k in _KEYWORD_TOKENS}

# Caps that bound the size of each bug case (applied in bug_case() and
# _encoder_statements()).
_MAX_CALLS = 12  # addsCall and removesCall statements, each
_MAX_EXCEPTIONS = 8  # mentionsException statements
_MAX_FUNCTIONS = 16  # modifies statements
_MAX_ENCODER_STMTS = 24  # statements kept from the code encoder
|
|
239
|
+
|
|
240
|
+
|
|
241
|
+
def _calls(lines: list[str]) -> Counter[str]:
    """Count call sites in ``lines`` by the last dotted component of the callee.

    ``a.b.foo(`` counts as ``foo``. A match is skipped when either the full
    dotted name or its last component is in ``_PY_KEYWORDS``.
    """
    dotted_names = (dotted for text in lines for dotted in _CALL_RE.findall(text))
    return Counter(
        dotted.rsplit(".", 1)[-1]
        for dotted in dotted_names
        if dotted not in _PY_KEYWORDS and dotted.rsplit(".", 1)[-1] not in _PY_KEYWORDS
    )
|
|
250
|
+
|
|
251
|
+
|
|
252
|
+
def _keyword_counts(lines: list[str]) -> Counter[str]:
    """Whole-word occurrences of each ``_KEYWORD_TOKENS`` entry in ``lines``.

    Text after the first ``#`` on each line is ignored. For non-empty input,
    every token gets a key, possibly with count 0.
    """
    tally: Counter[str] = Counter()
    for code in (text.partition("#")[0] for text in lines):
        tally.update({kw: len(rx.findall(code)) for kw, rx in _KEYWORD_RES.items()})
    return tally
|
|
259
|
+
|
|
260
|
+
|
|
261
|
+
def size_bucket(n: int) -> str:
    """Map a count onto a coarse, roughly power-of-two bucket label.

    Labels: "0" (n <= 0), "1", "2-3", "4-7", "8-15", "16-31", "32plus".
    """
    if n <= 0:
        return "0"
    if n == 1:
        return "1"
    for upper, label in ((3, "2-3"), (7, "4-7"), (15, "8-15"), (31, "16-31")):
        if n <= upper:
            return label
    return "32plus"
|
|
275
|
+
|
|
276
|
+
|
|
277
|
+
def _encoder_statements(added_lines: list[str]) -> list[Statement]:
    """Encode the added lines with the project's "code" encoder.

    Each line is stripped of surrounding whitespace and the lines are joined
    with newlines. The block is first encoded as ``language="python"``. If
    that yields a ``syntaxError`` statement, it is re-encoded as
    ``language="text"`` (the regex fallback). The placeholder statements
    ``syntaxError`` / ``emptyCode`` / ``rawCode`` are dropped, and at most
    ``_MAX_ENCODER_STMTS`` are returned. Blank input yields ``[]``.
    """
    block = "\n".join(text.strip() for text in added_lines)
    if block.strip() == "":
        return []
    code_encoder = get_encoder("code")
    encoded = code_encoder.encode(block, language="python")
    if any(s.functor == "syntaxError" for s in encoded.case.statements):
        # hunk fragments rarely parse; regex fallback
        encoded = code_encoder.encode(block, language="text")
    placeholders = {"syntaxError", "emptyCode", "rawCode"}
    kept = [s for s in encoded.case.statements if s.functor not in placeholders]
    return kept[:_MAX_ENCODER_STMTS]
|
|
294
|
+
|
|
295
|
+
|
|
296
|
+
def bug_case(record: BugRecord, facts: PatchFacts | None = None) -> Case:
    """Build the deterministic Tier-0 case for one bug: diff structure + failing test.

    ``facts`` may be passed in to reuse an earlier ``parse_patch`` result;
    otherwise ``record.patch_text`` is parsed here. The case contains:

    * ``modifies(file, function)`` for up to ``_MAX_FUNCTIONS`` functions;
    * bucketed ``addsLines`` / ``removesLines`` / ``touchesFiles`` counts;
    * ``touchesTestFile`` when any patched path contains "test";
    * ``mentionsException`` for up to ``_MAX_EXCEPTIONS`` exception names;
    * ``addsCall`` / ``removesCall`` for callees whose count rose / fell
      (top ``_MAX_CALLS`` of each, by count then name);
    * ``addsKeyword`` / ``removesKeyword`` for keyword tokens with a net change;
    * ``failingTest`` / ``testModule`` for the first 4 test names / files;
    * the code encoder's statements for the added lines.

    Statements are de-duplicated and sorted by ``repr`` for a stable order.
    """
    facts = facts or parse_patch(record.patch_text)
    found: set[Statement] = set()
    add = found.add

    # where the fix landed and how big it is
    for file_base, func_name in facts.functions[:_MAX_FUNCTIONS]:
        add(stmt("modifies", entity(file_base, "file"), entity(func_name, "function")))
    size_facts = (
        ("addsLines", facts.n_added),
        ("removesLines", facts.n_removed),
        ("touchesFiles", len(facts.files)),
    )
    for functor, amount in size_facts:
        add(stmt(functor, entity(size_bucket(amount), "count")))
    if any("test" in path.lower() for path in facts.files):
        add(stmt("touchesTestFile", entity("patch", "scope")))

    for exc_name in facts.exceptions[:_MAX_EXCEPTIONS]:
        add(stmt("mentionsException", entity(exc_name, "exception")))

    def net_gain(more, fewer):
        # callees seen more often in `more` than in `fewer`, busiest first
        names = [c for c in more if more[c] > fewer.get(c, 0)]
        names.sort(key=lambda c: (-more[c], c))
        return names[:_MAX_CALLS]

    calls_plus = _calls(facts.added_lines)
    calls_minus = _calls(facts.removed_lines)
    call_changes = (
        ("addsCall", net_gain(calls_plus, calls_minus)),
        ("removesCall", net_gain(calls_minus, calls_plus)),
    )
    for functor, names in call_changes:
        for callee in names:
            add(stmt(functor, entity("patch", "scope"), entity(callee, "callable")))

    kw_plus = _keyword_counts(facts.added_lines)
    kw_minus = _keyword_counts(facts.removed_lines)
    for kw in _KEYWORD_TOKENS:
        delta = kw_plus[kw] - kw_minus[kw]
        if delta == 0:
            continue
        functor = "addsKeyword" if delta > 0 else "removesKeyword"
        add(stmt(functor, entity("patch", "scope"), entity(kw, "keyword")))

    # failing-test context
    for test_name in record.test_names[:4]:
        add(stmt("failingTest", entity(test_name, "test")))
    for test_path in record.test_files[:4]:
        add(stmt("testModule", entity(pathlib.PurePosixPath(test_path).name, "file")))

    found.update(_encoder_statements(facts.added_lines))

    if not found:
        add(stmt("emptyPatch", entity(record.key, "bug")))
    metadata = {
        "adapter": "code",
        "tier": 0,
        "dataset": "bugsinpy",
        "project": record.project,
        "bug": record.bug_id,
    }
    return make_case(sorted(found, key=repr), metadata, case_id=f"bugsinpy:{record.key}")
|
|
@@ -0,0 +1,164 @@
|
|
|
1
|
+
"""Deterministic fix-pattern categories for BugsInPy patches.
|
|
2
|
+
|
|
3
|
+
Categories are assigned by ORDERED rules over the unified diff (first match
|
|
4
|
+
wins). They are the ground-truth labels for the T3 retrieval metric
|
|
5
|
+
(fix-category-hit@k): did retrieval surface a past bug fixed by the SAME
|
|
6
|
+
KIND of change, not just a textually similar file?
|
|
7
|
+
|
|
8
|
+
Rule order (per the T3 specification):
|
|
9
|
+
1. add-null-check adds ``is None`` / ``is not None`` / ``if not x``
|
|
10
|
+
2. exception-handling adds ``try:`` / ``except`` / ``raise``
|
|
11
|
+
3. boundary comparison-operator swap or +/-1 on an otherwise
|
|
12
|
+
identical line (off-by-one / boundary fixes)
|
|
13
|
+
4. type-coercion adds a builtin cast (int()/str()/list()/...)
|
|
14
|
+
5. api-substitution small hunk replaces one call with another
|
|
15
|
+
6. default-arg-change a ``def`` signature line changes its defaults/args
|
|
16
|
+
7. condition-strengthening an if/elif/while gains and/or clauses
|
|
17
|
+
8. other
|
|
18
|
+
|
|
19
|
+
"Net added" means the pattern occurs more often in added lines than in
|
|
20
|
+
removed lines, so pure code motion does not trigger a rule.
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
from __future__ import annotations
|
|
24
|
+
|
|
25
|
+
import re
|
|
26
|
+
|
|
27
|
+
from .bugsinpy import Hunk, PatchFacts, _calls
|
|
28
|
+
|
|
29
|
+
# Every label categorize() can return, in rule-priority order; "other" is the
# fallback when no rule fires.
CATEGORIES = (
    "add-null-check",
    "exception-handling",
    "boundary",
    "type-coercion",
    "api-substitution",
    "default-arg-change",
    "condition-strengthening",
    "other",
)

# Rule 1: `is None` / `is not None` anywhere, or a line starting with
# `if not <word>` / `elif not <word>`.
_NULL_CHECK = re.compile(r"\bis\s+(?:not\s+)?None\b|^\s*(?:el)?if\s+not\s+\w")
# Rule 2: a line starting with `try:`, `except` (followed by whitespace, `:`
# or `(`), or `raise`.
_EXC_HANDLING = re.compile(r"^\s*(?:try\s*:|except[\s:(]|raise\b)")
# Rule 4: a builtin cast call not preceded by a word character or `.` (so
# `myint(` / `np.int(` do not count), or any `.astype(` call.
_CAST = re.compile(
    r"(?<![\w.])(?:int|str|float|bool|list|tuple|set|dict|frozenset|bytes)\(|\.astype\("
)
# Rule 3 (boundary): comparison operators; two-character operators come first
# in the alternation so `<=` is one token rather than `<` followed by `=`.
# NOTE(review): this also matches the `>` of `->` and each char of `<<`/`>>`.
_CMP_OPS = re.compile(r"<=|>=|==|!=|<|>")
# Rule 3 (boundary): `+1` / `-1`, optionally with whitespace after the sign.
_PM_ONE = re.compile(r"[+-]\s*1\b")
# Rule 7: a line starting with if / elif / while; group 1 is the keyword.
_COND_LINE = re.compile(r"^\s*((?:el)?if|while)\b")
# Rule 7: whole-word `and` / `or` clauses counted on a condition line.
_BOOL_OP = re.compile(r"\b(?:and|or)\b")
# Rule 6: a `def` line; group 1 is the name, group 2 everything after `(` on
# that physical line (continuation lines of a long signature are not seen).
_DEF_LINE = re.compile(r"^\s*def\s+([A-Za-z_]\w*)\s*\((.*)$")
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def _strip_code(line: str) -> str:
    """Drop a trailing ``#`` comment (naive: a ``#`` inside a string also cuts)."""
    return line.split("#", 1)[0]


def _net(pattern: re.Pattern[str], added: list[str], removed: list[str]) -> bool:
    """True when ``pattern`` matches more often in ``added`` than in ``removed``.

    Comments are excluded before matching, so moving code (equal counts on both
    sides) does not trigger a rule.
    """

    def hits(lines: list[str]) -> int:
        return sum(len(pattern.findall(_strip_code(line))) for line in lines)

    return hits(added) > hits(removed)


def _paired_swap(hunk: "Hunk", pattern: re.Pattern[str]) -> bool:
    """True if an added line equals a removed line once ``pattern`` matches are
    masked out, while the matched operators/constants themselves differ (i.e.
    the ONLY change is in the matched operator/constant).

    Matched tokens are compared with whitespace removed, so pure re-spacing
    (``x  <  y`` -> ``x < y``, ``n -  1`` -> ``n - 1``) is not a swap. An added
    line whose tokens exactly equal those of some same-shaped removed line is
    treated as moved/unchanged, not swapped. Comments are ignored.
    """

    def masked(code: str) -> str:
        return re.sub(r"\s+", " ", pattern.sub("\x00", code)).strip()

    def tokens(code: str) -> list[str]:
        return [re.sub(r"\s+", "", m.group(0)) for m in pattern.finditer(code)]

    # masked line shape -> token lists of EVERY removed line with that shape
    removed: dict[str, list[list[str]]] = {}
    for line in hunk.removed:
        code = _strip_code(line)
        if pattern.search(code):
            removed.setdefault(masked(code), []).append(tokens(code))

    for line in hunk.added:
        code = _strip_code(line)
        if not pattern.search(code):
            continue
        same_shape = removed.get(masked(code))
        if same_shape and tokens(code) not in same_shape:
            return True
    return False
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def _is_null_check(facts: PatchFacts) -> bool:
    """Rule 1 (add-null-check): net-added ``is [not] None`` / ``if not x`` tests."""
    gained, lost = facts.added_lines, facts.removed_lines
    return _net(_NULL_CHECK, gained, lost)
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def _is_exception_handling(facts: PatchFacts) -> bool:
    """Rule 2 (exception-handling): net-added ``try:`` / ``except`` / ``raise`` lines."""
    gained, lost = facts.added_lines, facts.removed_lines
    return _net(_EXC_HANDLING, gained, lost)
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def _is_boundary(facts: PatchFacts) -> bool:
    """Rule 3 (boundary): some hunk changes only a comparison operator or a +/-1."""
    return any(
        _paired_swap(h, pattern)
        for h in facts.hunks
        for pattern in (_CMP_OPS, _PM_ONE)
    )
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def _is_type_coercion(facts: PatchFacts) -> bool:
    """Rule 4 (type-coercion): net-added builtin casts or ``.astype(`` calls."""
    gained, lost = facts.added_lines, facts.removed_lines
    return _net(_CAST, gained, lost)
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def _is_api_substitution(facts: PatchFacts) -> bool:
    """Rule 5 (api-substitution): a small hunk drops one callee and adds another.

    Only hunks with 1-3 removed and 1-3 added non-blank lines qualify; the hunk
    must have a callee only on the removed side AND one only on the added side.
    """

    def small(h) -> bool:
        return 1 <= len(h.removed) <= 3 and 1 <= len(h.added) <= 3

    for h in filter(small, facts.hunks):
        old_callees = set(_calls(h.removed))
        new_callees = set(_calls(h.added))
        if old_callees - new_callees and new_callees - old_callees:
            return True
    return False
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def _is_default_arg_change(facts: PatchFacts) -> bool:
    """Rule 6 (default-arg-change): a ``def`` line is rewritten in one hunk.

    Within a single hunk, a function name must appear on both a removed and an
    added ``def`` line. The text after ``(`` must differ, and either version
    must contain ``=``. Only the first physical line of a signature is compared.
    """

    def signature(text: str):
        hit = _DEF_LINE.match(_strip_code(text))
        return (hit.group(1), hit.group(2).strip()) if hit else None

    for h in facts.hunks:
        before = dict(sig for sig in map(signature, h.removed) if sig)
        for sig in map(signature, h.added):
            if sig is None:
                continue
            name, after = sig
            prior = before.get(name)
            if prior is not None and prior != after and ("=" in prior or "=" in after):
                return True
    return False
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
def _is_condition_strengthening(facts: PatchFacts) -> bool:
    """Rule 7 (condition-strengthening): an if/elif/while gains and/or clauses.

    Per hunk, the most ``and``/``or`` clauses on any removed line with a given
    keyword is the baseline. An added line with the same keyword and more
    clauses than that baseline triggers the rule.
    """

    def condition(text: str):
        code = _strip_code(text)
        hit = _COND_LINE.match(code)
        return (hit.group(1), len(_BOOL_OP.findall(code))) if hit else None

    for h in facts.hunks:
        widest: dict[str, int] = {}
        for kw, n_ops in filter(None, map(condition, h.removed)):
            widest[kw] = max(widest.get(kw, -1), n_ops)
        for kw, n_ops in filter(None, map(condition, h.added)):
            if kw in widest and n_ops > widest[kw]:
                return True
    return False
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
# (category, predicate) pairs tried in this order by categorize(); the first
# predicate that returns True names the category. The order mirrors
# CATEGORIES, which additionally ends with the "other" fallback.
_RULES = (
    ("add-null-check", _is_null_check),
    ("exception-handling", _is_exception_handling),
    ("boundary", _is_boundary),
    ("type-coercion", _is_type_coercion),
    ("api-substitution", _is_api_substitution),
    ("default-arg-change", _is_default_arg_change),
    ("condition-strengthening", _is_condition_strengthening),
)
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
def categorize(facts: PatchFacts) -> str:
    """Ordered first-match-wins fix-pattern category for one patch ("other" if none)."""
    return next((label for label, rule in _RULES if rule(facts)), "other")
|
sma/eval/crossdomain.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
1
|
+
"""Shared cross-domain arm evaluator (4b).
|
|
2
|
+
|
|
3
|
+
Guarantees an IDENTICAL index/query partition across arms (generic / drafted /
|
|
4
|
+
expert) by splitting on the original record order — NOT on the encoding-dependent
|
|
5
|
+
case_id — so the baselines are stable and the only thing that varies between arms
|
|
6
|
+
is the SMA encoding. Returns the result dict and writes cd_<domain>_<phase>.csv.
|
|
7
|
+
"""
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import csv
|
|
11
|
+
import pathlib
|
|
12
|
+
import time
|
|
13
|
+
|
|
14
|
+
from sma.index.macfac import MacFacIndex
|
|
15
|
+
from sma.ir.schema import Statement
|
|
16
|
+
from sma.eval.baselines.bm25 import rank_bm25_like
|
|
17
|
+
from sma.eval.baselines.dense import rank_tfidf_dense_batch
|
|
18
|
+
from sma.eval.metrics import macro_f1
|
|
19
|
+
from sma.eval.stats import paired_bootstrap, holm_bonferroni, cliffs_delta
|
|
20
|
+
|
|
21
|
+
# Directory receiving the cd_<domain>_<phase>.csv summaries written by
# evaluate_arm(); relative, so it resolves against the current working directory.
OUT = pathlib.Path("reports/confirmatory")
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def _ho_density(case) -> float:
    """Fraction of ``case.expressions()`` having at least one ``Statement`` argument.

    This is the share of higher-order statements. A case with no expressions
    gives 0.0.
    """
    exprs = case.expressions()
    higher_order = [s for s in exprs if any(isinstance(arg, Statement) for arg in s.args)]
    return len(higher_order) / max(len(exprs), 1)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def _vote(ranked_ids, label_of, labels):
    """Majority label among the retrieved ``ranked_ids``.

    ``labels`` is ``(pos, neg)`` and ``label_of`` maps case ids to labels.
    A tie, including an empty ranking, resolves to ``pos``. A label outside
    ``labels`` raises KeyError.
    """
    pos, neg = labels
    votes = dict.fromkeys((pos, neg), 0)
    for label in map(label_of.__getitem__, ranked_ids):
        votes[label] += 1
    return neg if votes[neg] > votes[pos] else pos
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def evaluate_arm(items, encode_fn, row_text_fn, labels, phase, domain,
                 k=10, frac=0.7):
    """Evaluate one encoding arm against BM25 / dense baselines on a fixed split.

    Args:
        items: list of records (each has ``.label``). The first
            ``int(len(items) * frac)`` form the index and the rest are queries,
            so the split never depends on the encoding.
        encode_fn: ``encode_fn(item) -> Case``.
        row_text_fn: ``row_text_fn(item) -> str``, the text for the baselines.
        labels: ``(pos, neg)``; each query is labelled by a vote over its
            top-``k`` retrievals (ties -> ``pos``).
        phase, domain: used in the log line and in the CSV file name.

    Side effects:
        Prints a summary line and writes ``OUT/cd_<domain>_<phase>.csv`` with
        paired-bootstrap deltas (Holm-corrected p-values) and Cliff's delta of
        SMA vs each baseline.

    Returns:
        ``{"f1": {...}, "ho": mean higher-order density, "summ": csv rows}``.
    """
    index_n = int(len(items) * frac)
    index_items, query_items = items[:index_n], items[index_n:]  # FIXED split

    idx = [(it, encode_fn(it)) for it in index_items]
    qry = [(it, encode_fn(it)) for it in query_items]
    # Only INDEX cases are looked up by case_id; queries keep their own record.
    # case_id is encoding-dependent, so a single map over both splits would let
    # a query that collides with an index case overwrite that case's label and
    # baseline text (leaking the query into the index).
    label_of = {c.case_id: it.label for it, c in idx}
    index_docs = [(c.case_id, row_text_fn(it)) for it, c in idx]
    query_texts = [row_text_fn(it) for it, _ in qry]
    dens = [_ho_density(c) for _, c in idx + qry]
    mean_ho = sum(dens) / max(len(dens), 1)

    t0 = time.perf_counter()
    sma_index = MacFacIndex()
    sma_index.build([c for _, c in idx])
    gold, sma_p, bm_p, dn_p = [], [], [], []
    dense_rk = rank_tfidf_dense_batch(query_texts, index_docs, k=k)
    for qi, ((it, qc), q_text) in enumerate(zip(qry, query_texts)):
        gold.append(it.label)
        res = sma_index.retrieve(qc, k=k, shortlist=60, fac_budget=25)
        sma_p.append(_vote([r.case_id for r in res], label_of, labels))
        bm = rank_bm25_like(q_text, index_docs, k=k)
        bm_p.append(_vote([cid for cid, _ in bm], label_of, labels))
        dn_p.append(_vote([cid for cid, _ in dense_rk[qi]], label_of, labels))
    dt = time.perf_counter() - t0

    f1 = {"SMA": macro_f1(gold, sma_p), "BM25": macro_f1(gold, bm_p), "Dense": macro_f1(gold, dn_p)}
    print(f"[{domain}/{phase}] HO-density={mean_ho:.4f} macro-F1: SMA {f1['SMA']:.4f} "
          f"BM25 {f1['BM25']:.4f} Dense {f1['Dense']:.4f} ({dt:.0f}s)", flush=True)

    def correct(p):
        # per-query 0/1 accuracy vector, aligned with `gold`
        return [1.0 if a == g else 0.0 for a, g in zip(p, gold)]

    sma_c = correct(sma_p)
    pv, summ = {}, []
    for name, pred in (("BM25", bm_p), ("Dense", dn_p)):
        pred_c = correct(pred)
        bs = paired_bootstrap(sma_c, pred_c)
        pv[name] = bs["p_value"]
        summ.append({"phase": phase, "domain": domain, "baseline": name,
                     "ho_density": f"{mean_ho:.4f}", "sma_f1": f"{f1['SMA']:.4f}",
                     "baseline_f1": f"{f1[name]:.4f}", "delta": f"{bs['delta']:.4f}",
                     "ci_low": f"{bs['ci_low']:.4f}", "ci_high": f"{bs['ci_high']:.4f}",
                     "cliffs": f"{cliffs_delta(sma_c, pred_c):.4f}"})
    holm = holm_bonferroni(pv)
    for s in summ:
        s["p_holm"] = f"{holm[s['baseline']]:.4f}"

    out = OUT / f"cd_{domain}_{phase}.csv"
    out.parent.mkdir(parents=True, exist_ok=True)
    with out.open("w", newline="", encoding="utf-8") as fh:
        writer = csv.DictWriter(fh, fieldnames=list(summ[0]))
        writer.writeheader()
        writer.writerows(summ)
    return {"f1": f1, "ho": mean_ho, "summ": summ}
|
sma/eval/diabetes.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
"""Diabetes-130 loader + per-encounter artifact builder (4b cross-domain).
|
|
2
|
+
|
|
3
|
+
Real UCI EHR. We build a flat CSV-row artifact for the GENERIC structured
|
|
4
|
+
adapter (the honest 'before': flat triples, no higher-order relations) and a
|
|
5
|
+
plain attr=value text for the lexical/dense baselines. The readmission label is
|
|
6
|
+
NEVER encoded (leakage guard).
|
|
7
|
+
"""
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import csv
|
|
11
|
+
import pathlib
|
|
12
|
+
import random
|
|
13
|
+
from dataclasses import dataclass
|
|
14
|
+
|
|
15
|
+
# ids, mostly-missing columns, and the LABEL are excluded from the encoding.
|
|
16
|
+
DROP = {"encounter_id", "patient_nbr", "weight", "payer_code", "readmitted"}
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@dataclass(frozen=True)
|
|
20
|
+
class Encounter:
|
|
21
|
+
eid: str
|
|
22
|
+
fields: dict[str, str] # cleaned attribute -> value (no '?'/'' )
|
|
23
|
+
label: str # "early" (readmitted <30 days) vs "not"
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def _csv_path() -> pathlib.Path:
|
|
27
|
+
return next(pathlib.Path("data/raw/diabetes130").rglob("diabetic_data.csv"))
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def load_encounters(sample: int = 1500, seed: int = 7, balanced: bool = True) -> list[Encounter]:
|
|
31
|
+
rows = list(csv.DictReader(_csv_path().open()))
|
|
32
|
+
rng = random.Random(seed)
|
|
33
|
+
rng.shuffle(rows)
|
|
34
|
+
|
|
35
|
+
def to_enc(r: dict) -> Encounter:
|
|
36
|
+
fields = {k: v for k, v in r.items() if k not in DROP and v not in ("?", "")}
|
|
37
|
+
label = "early" if r["readmitted"] == "<30" else "not"
|
|
38
|
+
return Encounter(r["encounter_id"], fields, label)
|
|
39
|
+
|
|
40
|
+
encs = [to_enc(r) for r in rows]
|
|
41
|
+
if not balanced:
|
|
42
|
+
return encs[:sample]
|
|
43
|
+
# balance the two classes so retrieval-by-analogy has signal (early is ~11%)
|
|
44
|
+
early = [e for e in encs if e.label == "early"]
|
|
45
|
+
notr = [e for e in encs if e.label == "not"]
|
|
46
|
+
half = sample // 2
|
|
47
|
+
out = early[:half] + notr[:half]
|
|
48
|
+
rng.shuffle(out)
|
|
49
|
+
return out
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def row_csv(enc: "Encounter") -> str:
    """Two-line CSV (header + one row) for the structured adapter -> flat
    (attribute row value) triples.

    Columns are sorted by attribute name and each line ends with ``"\\n"``.
    Fields are written with ``csv.writer`` (minimal quoting). A value
    containing a comma, quote or line break therefore stays in its own column.
    Values without such characters come out exactly as a plain ``",".join``
    would produce them.
    """
    import io

    keys = sorted(enc.fields)
    buf = io.StringIO()
    writer = csv.writer(buf, lineterminator="\n")
    writer.writerow(keys)
    writer.writerow([enc.fields[k] for k in keys])
    return buf.getvalue()
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def row_text(enc: Encounter) -> str:
    """Space-separated ``attr=value`` tokens, sorted by attribute name, for the
    BM25 / dense baselines."""
    parts = []
    for attr in sorted(enc.fields):
        parts.append(f"{attr}={enc.fields[attr]}")
    return " ".join(parts)
|