shrinkray 0.0.0__py3-none-any.whl → 25.12.26__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- shrinkray/__main__.py +130 -960
- shrinkray/cli.py +70 -0
- shrinkray/display.py +75 -0
- shrinkray/formatting.py +108 -0
- shrinkray/passes/bytes.py +217 -10
- shrinkray/passes/clangdelta.py +47 -17
- shrinkray/passes/definitions.py +84 -4
- shrinkray/passes/genericlanguages.py +61 -7
- shrinkray/passes/json.py +6 -0
- shrinkray/passes/patching.py +65 -57
- shrinkray/passes/python.py +66 -23
- shrinkray/passes/sat.py +505 -91
- shrinkray/passes/sequences.py +26 -6
- shrinkray/problem.py +206 -27
- shrinkray/process.py +49 -0
- shrinkray/reducer.py +187 -25
- shrinkray/state.py +599 -0
- shrinkray/subprocess/__init__.py +24 -0
- shrinkray/subprocess/client.py +253 -0
- shrinkray/subprocess/protocol.py +190 -0
- shrinkray/subprocess/worker.py +491 -0
- shrinkray/tui.py +915 -0
- shrinkray/ui.py +72 -0
- shrinkray/work.py +34 -6
- {shrinkray-0.0.0.dist-info → shrinkray-25.12.26.0.dist-info}/METADATA +44 -27
- shrinkray-25.12.26.0.dist-info/RECORD +33 -0
- {shrinkray-0.0.0.dist-info → shrinkray-25.12.26.0.dist-info}/WHEEL +2 -1
- shrinkray-25.12.26.0.dist-info/entry_points.txt +3 -0
- shrinkray-25.12.26.0.dist-info/top_level.txt +1 -0
- shrinkray/learning.py +0 -221
- shrinkray-0.0.0.dist-info/RECORD +0 -22
- shrinkray-0.0.0.dist-info/entry_points.txt +0 -3
- {shrinkray-0.0.0.dist-info → shrinkray-25.12.26.0.dist-info/licenses}/LICENSE +0 -0
shrinkray/passes/clangdelta.py
CHANGED
@@ -1,5 +1,6 @@
 import os
 import subprocess
+from functools import lru_cache
 from glob import glob
 from shutil import which
 from tempfile import NamedTemporaryFile
@@ -10,6 +11,7 @@ from shrinkray.passes.definitions import ReductionPump
 from shrinkray.problem import ReductionProblem
 from shrinkray.work import NotFound

+
 C_FILE_EXTENSIONS = (".c", ".cpp", ".h", ".hpp", ".cxx", ".cc")


@@ -24,6 +26,29 @@ def find_clang_delta():
     return clang_delta


+@lru_cache(maxsize=1)
+def clang_delta_works() -> bool:
+    """Check if clang_delta can actually execute.
+
+    This verifies not just that the binary exists, but that it can run.
+    On some systems (e.g., Ubuntu 24.04), creduce is installed but
+    clang_delta fails at runtime due to shared library issues.
+    """
+    clang_delta = find_clang_delta()
+    if not clang_delta:
+        return False
+    try:
+        # Run a simple test to verify clang_delta works
+        result = subprocess.run(
+            [clang_delta, "--help"],
+            capture_output=True,
+            timeout=5,
+        )
+        return result.returncode == 0
+    except (OSError, subprocess.TimeoutExpired):
+        return False
+
+
 TRANSFORMATIONS: list[str] = [
     "aggregate-to-scalar",
     "binop-simplification",
@@ -126,9 +151,10 @@ class ClangDelta:
             ).stdout
         except subprocess.CalledProcessError as e:
             msg = (e.stdout + e.stderr).strip()
-
-
-
+            # clang_delta has many internal assertions that can be triggered
+            # by malformed or unusual C/C++ code. These are harmless - we just
+            # report zero instances and skip this transformation.
+            if b"Assertion failed" in msg:
                 return 0
             else:
                 raise ClangDeltaError(msg)
@@ -161,13 +187,13 @@ class ClangDelta:
                 )
             ).stdout
         except subprocess.CalledProcessError as e:
-            if
-                raise ValueError("Not a C or C++ test case")
-            elif (
+            if (
                 e.stdout.strip()
                 == b"Error: No modification to the transformed program!"
             ):
                 return data
+            elif b"Assertion failed" in e.stderr.strip():
+                return data
             else:
                 raise ClangDeltaError(e.stdout + e.stderr)
         finally:
@@ -175,7 +201,9 @@ class ClangDelta:


 class ClangDeltaError(Exception):
-
+    def __init__(self, message):
+        assert b"Assertion failed" not in message, message
+        super().__init__(message)


 def clang_delta_pump(
@@ -186,10 +214,7 @@ def clang_delta_pump(
     assert target is not None
     try:
         n = await clang_delta.query_instances(transformation, target)
-    except
-        import traceback
-
-        traceback.print_exc()
+    except ClangDeltaError:
         return target
     i = 1
     while i <= n:
@@ -203,15 +228,20 @@ def clang_delta_pump(
                 return False
             return await problem.is_interesting(attempt)

+        not_found = False
+        clang_delta_failed = False
         try:
             i = await problem.work.find_first_value(range(i, n + 1), can_apply)
-        except NotFound:
+        except* NotFound:
+            not_found = True
+        except* ClangDeltaError:
+            # clang_delta assertions can be triggered by unusual C/C++ code.
+            # These are harmless - just return what we have so far.
+            clang_delta_failed = True
+        if not_found:
             break
-
-
-            # if you feed it bad enough C++. We solve this problem by ignoring it.
-            if b"Assertion failed" in e.args[0]:
-                return target
+        if clang_delta_failed:
+            return target

         target = await clang_delta.apply_transformation(transformation, i, target)
         assert target is not None
shrinkray/passes/definitions.py
CHANGED
@@ -1,30 +1,93 @@
+"""Type definitions and utilities for reduction passes.
+
+This module defines the core type aliases and abstractions for reduction:
+
+- ReductionPass[T]: A function that attempts to reduce a test case
+- ReductionPump[T]: A function that may temporarily increase test case size
+- Format[S, T]: A bidirectional transformation between types
+- compose(): Combines a Format with a pass to work on a different type
+
+These abstractions enable format-agnostic reduction: the same pass
+(e.g., "delete duplicate elements") can work on bytes, lines, tokens,
+JSON arrays, or any other sequence-like type.
+"""
+
 from abc import ABC, abstractmethod
+from collections.abc import Awaitable, Callable
 from functools import wraps
-from typing import
+from typing import TypeVar

 from shrinkray.problem import ReductionProblem

+
 S = TypeVar("S")
 T = TypeVar("T")


+# A reduction pass takes a problem and attempts to reduce it.
+# The pass modifies the problem by calling is_interesting() with smaller candidates.
+# When a reduction succeeds, problem.current_test_case is automatically updated.
 ReductionPass = Callable[[ReductionProblem[T]], Awaitable[None]]
+
+# A reduction pump can temporarily INCREASE test case size.
+# Example: inlining a function makes code larger, but may enable further reductions.
+# The reducer runs passes on the pumped result using backtrack() to try to
+# reduce it below the original size.
 ReductionPump = Callable[[ReductionProblem[T]], Awaitable[T]]


 class ParseError(Exception):
+    """Raised when a Format cannot parse its input."""
+
+    pass
+
+
+class DumpError(Exception):
+    """Raised when a Format cannot serialize its output.
+
+    This occurs because not all internal representations map to valid
+    output in the target format. For example, a reduction might create
+    an invalid AST structure that cannot be converted back to source code.
+    """
+
     pass


-class Format
+class Format[S, T](ABC):
+    """A bidirectional transformation between two types.
+
+    Formats enable format-agnostic passes by abstracting the
+    parse/serialize cycle. For example:
+
+    - Split(b"\\n"): bytes <-> list[bytes] (lines)
+    - Tokenize(): bytes <-> list[bytes] (tokens)
+    - JSON: bytes <-> Any (Python objects)
+    - DimacsCNF: bytes <-> list[list[int]] (SAT clauses)
+
+    A Format must satisfy the round-trip property:
+        dumps(parse(x)) should be equivalent to x
+        (possibly with normalization)
+
+    Example usage:
+        # Delete duplicate lines
+        compose(Split(b"\\n"), delete_duplicates)
+
+        # Reduce integer literals in source code
+        compose(IntegerFormat(), reduce_integer)
+    """
+
     @property
     def name(self) -> str:
+        """Human-readable name for this format, used in pass names."""
         return repr(self)

     @abstractmethod
-    def parse(self, input: S) -> T:
+    def parse(self, input: S) -> T:
+        """Parse input into the target type. Raises ParseError on failure."""
+        ...

     def is_valid(self, input: S) -> bool:
+        """Check if input can be parsed by this format."""
         try:
             self.parse(input)
             return True
@@ -32,10 +95,27 @@ class Format(Generic[S, T], ABC):
             return False

     @abstractmethod
-    def dumps(self, input: T) -> S:
+    def dumps(self, input: T) -> S:
+        """Serialize the target type back to the source type."""
+        ...


 def compose(format: Format[S, T], reduction_pass: ReductionPass[T]) -> ReductionPass[S]:
+    """Wrap a reduction pass to work through a Format transformation.
+
+    This is the key combinator for format-agnostic reduction. It takes
+    a pass that works on type T and returns a pass that works on type S,
+    by parsing S->T before the pass and dumping T->S after.
+
+    Example:
+        # delete_duplicates works on sequences
+        # Split(b"\\n") parses bytes into lines
+        # The composed pass deletes duplicate lines from bytes
+        line_dedup = compose(Split(b"\\n"), delete_duplicates)
+
+    If parsing fails, the composed pass returns immediately (no-op).
+    """
+
     @wraps(reduction_pass)
     async def wrapped_pass(problem: ReductionProblem[S]) -> None:
         view = problem.view(format)
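To make the new Format / compose() docstrings concrete, here is a deliberately simplified, synchronous model of the same idea. Shrinkray's real passes are async and operate on a ReductionProblem; the SplitLines, compose, and delete_duplicates definitions below are illustrative stand-ins under that assumption, not shrinkray's actual implementations.

from collections.abc import Callable


class SplitLines:
    """bytes <-> list[bytes], splitting on newlines; dumps(parse(x)) == x."""

    def parse(self, data: bytes) -> list[bytes]:
        return data.split(b"\n")

    def dumps(self, lines: list[bytes]) -> bytes:
        return b"\n".join(lines)


def compose(
    fmt: SplitLines, reduction_pass: Callable[[list[bytes]], list[bytes]]
) -> Callable[[bytes], bytes]:
    """Lift a pass over list[bytes] to a pass over bytes via parse/dumps."""

    def wrapped(data: bytes) -> bytes:
        return fmt.dumps(reduction_pass(fmt.parse(data)))

    return wrapped


def delete_duplicates(lines: list[bytes]) -> list[bytes]:
    # Keep only the first occurrence of each line.
    seen: set[bytes] = set()
    out: list[bytes] = []
    for line in lines:
        if line not in seen:
            seen.add(line)
            out.append(line)
    return out


dedupe_lines = compose(SplitLines(), delete_duplicates)
assert dedupe_lines(b"a\nb\na\nc") == b"a\nb\nc"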
shrinkray/passes/genericlanguages.py
CHANGED
@@ -3,9 +3,10 @@ Module of reduction passes designed for "things that look like programming langu
 """

 import re
+from collections.abc import Callable, Sized
 from functools import wraps
 from string import ascii_lowercase, ascii_uppercase
-from typing import AnyStr
+from typing import AnyStr

 import trio
 from attr import define
@@ -37,9 +38,9 @@ class Substring(Format[AnyStr, AnyStr]):


 class RegionReplacingPatches(Patches[dict[int, AnyStr], AnyStr]):
-    def __init__(self, regions):
+    def __init__(self, regions: list[tuple[int, int]]):
         assert regions
-        for (_, v), (u, _) in zip(regions, regions[1:]):
+        for (_, v), (u, _) in zip(regions, regions[1:], strict=False):
             assert v <= u
         self.regions = regions

@@ -69,15 +70,15 @@ class RegionReplacingPatches(Patches[dict[int, AnyStr], AnyStr]):
         return empty.join(parts)

     def size(self, patch):
-        total = 0
         for i, s in patch.items():
             u, v = self.regions[i]
             return v - u - len(s)
+        raise AssertionError(f"expected nonempty {patch=}")


 def regex_pass(
     pattern: AnyStr | re.Pattern[AnyStr],
-    flags: re.RegexFlag =
+    flags: re.RegexFlag = re.RegexFlag.NOFLAG,
 ) -> Callable[[ReductionPass[AnyStr]], ReductionPass[AnyStr]]:
     if not isinstance(pattern, re.Pattern):
         pattern = re.compile(pattern, flags=flags)
@@ -129,6 +130,11 @@ def regex_pass(


 async def reduce_integer(problem: ReductionProblem[int]) -> None:
+    """Reduce an integer to its smallest interesting value.
+
+    Uses binary search to find the smallest integer that maintains
+    interestingness. Tries 0 first, then narrows down the range.
+    """
     assert problem.current_test_case >= 0

     if await problem.is_interesting(0):
@@ -166,11 +172,21 @@ class IntegerFormat(Format[bytes, int]):

 @regex_pass(b"[0-9]+")
 async def reduce_integer_literals(problem: ReductionProblem[bytes]) -> None:
+    """Reduce integer literals in source code to smaller values.
+
+    Finds numeric literals and tries to reduce each one independently
+    using binary search.
+    """
     await reduce_integer(problem.view(IntegerFormat()))


 @regex_pass(rb"[0-9]+ [*+-/] [0-9]+")
 async def combine_expressions(problem: ReductionProblem[bytes]) -> None:
+    """Evaluate and simplify simple arithmetic expressions.
+
+    Finds expressions like "2 + 3" and replaces them with their result "5".
+    Only handles basic integer arithmetic to avoid changing program semantics.
+    """
     try:
         # NB: Use of eval is safe, as everything passed to this is a simple
         # arithmetic expression. Would ideally replace with a guaranteed
@@ -184,18 +200,39 @@ async def combine_expressions(problem: ReductionProblem[bytes]) -> None:

 @regex_pass(rb'([\'"])\s*\1')
 async def merge_adjacent_strings(problem: ReductionProblem[bytes]) -> None:
+    """Remove empty string concatenations like '' '' or "" "".
+
+    These patterns (quote, whitespace, same quote) often result from
+    other reductions and can be eliminated entirely.
+    """
     await problem.is_interesting(b"")


 @regex_pass(rb"''|\"\"|false|\(\)|\[\]", re.IGNORECASE)
 async def replace_falsey_with_zero(problem: ReductionProblem[bytes]) -> None:
+    """Replace falsey values with 0.
+
+    Tries to replace empty strings, 'false', empty parentheses, and empty
+    brackets with the single character '0', which is shorter and often
+    equivalent in boolean contexts.
+    """
     await problem.is_interesting(b"0")


 async def simplify_brackets(problem: ReductionProblem[bytes]) -> None:
+    """Try to replace bracket types with simpler ones.
+
+    Attempts to replace {} with [] or (), and [] with (). This can
+    help normalize syntax when the specific bracket type doesn't matter.
+    """
     bracket_types = [b"[]", b"{}", b"()"]

-    patches = [
+    patches = [
+        dict(zip(u, v, strict=True))
+        for u in bracket_types
+        for v in bracket_types
+        if u > v
+    ]

     await apply_patches(problem, ByteReplacement(), patches)

@@ -203,11 +240,17 @@ async def simplify_brackets(problem: ReductionProblem[bytes]) -> None:
 IDENTIFIER = re.compile(rb"(\b[A-Za-z][A-Za-z0-9_]*\b)|([0-9]+)")


-def shortlex(s):
+def shortlex[T: Sized](s: T) -> tuple[int, T]:
     return (len(s), s)


 async def normalize_identifiers(problem: ReductionProblem[bytes]) -> None:
+    """Replace identifiers with shorter alternatives.
+
+    Finds all identifiers in the source and tries to replace longer ones
+    with shorter alternatives (single letters like 'a', 'b', etc.). This
+    normalizes variable/function names to minimal forms.
+    """
     identifiers = {m.group(0) for m in IDENTIFIER.finditer(problem.current_test_case)}
     replacements = set(identifiers)

@@ -253,6 +296,12 @@ def iter_indices(s, substring):


 async def cut_comments(problem: ReductionProblem[bytes], start, end, include_end=True):
+    """Remove comment-like regions bounded by start and end markers.
+
+    Finds all regions starting with 'start' and ending with 'end', then
+    tries to delete them. Used to remove comments from various languages.
+    If include_end is False, the end marker itself is not deleted.
+    """
     cuts = []
     target = problem.current_test_case
     # python comments
@@ -271,6 +320,11 @@ async def cut_comments(problem: ReductionProblem[bytes], start, end, include_end


 async def cut_comment_like_things(problem: ReductionProblem[bytes]):
+    """Remove common comment syntaxes from source code.
+
+    Tries to delete Python-style (#), C++-style (//), Python docstrings
+    (triple quotes), and C-style block comments (/* ... */).
+    """
     await cut_comments(problem, b"#", b"\n", include_end=False)
     await cut_comments(problem, b"//", b"\n", include_end=False)
     await cut_comments(problem, b'"""', b'"""')
shrinkray/passes/json.py
CHANGED
@@ -81,6 +81,12 @@ class DeleteIdentifiers(Patches[frozenset[str], Any]):


 async def delete_identifiers(problem: ReductionProblem[Any]):
+    """Remove object keys from JSON structures.
+
+    Finds all string keys used in any nested object and tries to remove
+    them. When a key is removed, it's deleted from all objects that
+    contain it throughout the JSON tree.
+    """
     identifiers = gather_identifiers(problem.current_test_case)

     await apply_patches(
shrinkray/passes/patching.py
CHANGED
@@ -1,24 +1,22 @@
 from abc import ABC, abstractmethod
+from collections.abc import Callable, Iterable, Sequence
 from enum import Enum
 from random import Random
-from typing import Any,
+from typing import Any, TypeVar, cast

 import trio

 from shrinkray.problem import ReductionProblem

-Seq = TypeVar("Seq", bound=Sequence[Any])
-T = TypeVar("T")

-
-TargetType = TypeVar("TargetType")
+Seq = TypeVar("Seq", bound=Sequence[Any])


 class Conflict(Exception):
     pass


-class Patches
+class Patches[PatchType, TargetType](ABC):
     @property
     @abstractmethod
     def empty(self) -> PatchType: ...
@@ -33,7 +31,7 @@ class Patches(Generic[PatchType, TargetType], ABC):
     def size(self, patch: PatchType) -> int: ...


-class SetPatches(Patches[frozenset[T], TargetType]):
+class SetPatches[T, TargetType](Patches[frozenset[T], TargetType]):
     def __init__(self, apply: Callable[[frozenset[T], TargetType], TargetType]):
         self.__apply = apply

@@ -54,7 +52,7 @@ class SetPatches(Patches[frozenset[T], TargetType]):
         return len(patch)


-class ListPatches(Patches[list[T], TargetType]):
+class ListPatches[T, TargetType](Patches[list[T], TargetType]):
     def __init__(self, apply: Callable[[list[T], TargetType], TargetType]):
         self.__apply = apply

@@ -75,7 +73,7 @@ class ListPatches(Patches[list[T], TargetType]):
         return len(patch)


-class PatchApplier
+class PatchApplier[PatchType, TargetType]:
     def __init__(
         self,
         patches: Patches[PatchType, TargetType],
@@ -91,55 +89,28 @@ class PatchApplier(Generic[PatchType, TargetType], ABC):
         self.__current_patch = self.__patches.empty
         self.__initial_test_case = problem.current_test_case

-    async def
-        initial_patch = self.__current_patch
+    async def __possibly_become_merge_master(self):
         try:
-
-        except
+            self.__merge_lock.acquire_nowait()
+        except trio.WouldBlock:
             return False
-
-        return True
-        with_patch_applied = self.__patches.apply(
-            combined_patch, self.__initial_test_case
-        )
-        if with_patch_applied == self.__problem.current_test_case:
-            return True
-        if not await self.__problem.is_interesting(with_patch_applied):
-            return False
-        send_merge_result, receive_merge_result = trio.open_memory_channel(1)
-
-        sort_key = (self.__tick, self.__problem.sort_key(with_patch_applied))
-        self.__tick += 1
-
-        self.__merge_queue.append((sort_key, patch, send_merge_result))
-
-        async with self.__merge_lock:
-            if (
-                self.__current_patch == initial_patch
-                and len(self.__merge_queue) == 1
-                and self.__merge_queue[0][1] == patch
-                and self.__problem.sort_key(with_patch_applied)
-                <= self.__problem.sort_key(self.__problem.current_test_case)
-            ):
-                self.__current_patch = combined_patch
-                self.__merge_queue.clear()
-                return True
-
+        try:
             while self.__merge_queue:
                 base_patch = self.__current_patch
                 to_merge = len(self.__merge_queue)

                 async def can_merge(k):
-
-
+                    # find_large_integer doubles each time, and
+                    # if we call it then we know that can_merge(to_merge)
+                    # is False, so we should never hit this.
+                    assert k <= 2 * to_merge
                     try:
                         attempted_patch = self.__patches.combine(
-                            base_patch,
+                            base_patch,
+                            *[p for _, p, _ in self.__merge_queue[:k]],
                         )
                     except Conflict:
                         return False
-                    if attempted_patch == base_patch:
-                        return True
                     with_patch_applied = self.__patches.apply(
                         attempted_patch, self.__initial_test_case
                     )
@@ -163,10 +134,44 @@ class PatchApplier(Generic[PatchType, TargetType], ABC):
                     del self.__merge_queue[: merged + 1]
                 else:
                     del self.__merge_queue[:to_merge]
+        finally:
+            self.__merge_lock.release()

-
-
-
+        return True
+
+    async def try_apply_patch(self, patch: PatchType) -> bool:
+        initial_patch = self.__current_patch
+        try:
+            combined_patch = self.__patches.combine(initial_patch, patch)
+        except Conflict:
+            return False
+        if combined_patch == self.__current_patch:
+            return True
+        with_patch_applied = self.__patches.apply(
+            combined_patch, self.__initial_test_case
+        )
+        if with_patch_applied == self.__problem.current_test_case:
+            return True
+        if not await self.__problem.is_interesting(with_patch_applied):
+            return False
+        send_merge_result, receive_merge_result = trio.open_memory_channel(1)
+
+        sort_key = (self.__tick, self.__problem.sort_key(with_patch_applied))
+        self.__tick += 1
+
+        self.__merge_queue.append((sort_key, patch, send_merge_result))
+
+        # If nobody else is merging the queue, that's our job now. This will
+        # run until the queue is fully cleared, including the job we just
+        # put on it.
+        if await self.__possibly_become_merge_master():
+            # This should always have been populated during the merge step we just
+            # performed, so we use a nowait here to ensure it doesn't hang on a
+            # bug.
+            return receive_merge_result.receive_nowait()
+        else:
+            # Wait to clear to merge queue.
+            return await receive_merge_result.receive()


 class Direction(Enum):
@@ -178,15 +183,18 @@ class Completed(Exception):
     pass


-async def apply_patches(
+async def apply_patches[PatchType, TargetType](
     problem: ReductionProblem[TargetType],
     patch_info: Patches[PatchType, TargetType],
     patches: Iterable[PatchType],
 ) -> None:
-
-
-
-
+    try:
+        if await problem.is_interesting(
+            patch_info.apply(patch_info.combine(*patches), problem.current_test_case)
+        ):
+            return
+    except Conflict:
+        pass

     applier = PatchApplier(patch_info, problem)

@@ -200,10 +208,10 @@ async def apply_patches(
     send_patches.close()

     async with trio.open_nursery() as nursery:
-        for
+        for _i in range(problem.work.parallelism):

             @nursery.start_soon
-            async def
+            async def worker() -> None:
                 while True:
                     try:
                         patch = await receive_patches.receive()
@@ -234,7 +242,7 @@ class LazyMutableRange:
         return result


-def lazy_shuffle(seq: Sequence[T], rnd: Random) -> Iterable[T]:
+def lazy_shuffle[T](seq: Sequence[T], rnd: Random) -> Iterable[T]:
     indices = LazyMutableRange(len(seq))
     while indices:
         j = len(indices) - 1