shrinkray-0.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- shrinkray/__init__.py +1 -0
- shrinkray/__main__.py +1205 -0
- shrinkray/learning.py +221 -0
- shrinkray/passes/__init__.py +0 -0
- shrinkray/passes/bytes.py +547 -0
- shrinkray/passes/clangdelta.py +230 -0
- shrinkray/passes/definitions.py +52 -0
- shrinkray/passes/genericlanguages.py +277 -0
- shrinkray/passes/json.py +91 -0
- shrinkray/passes/patching.py +280 -0
- shrinkray/passes/python.py +176 -0
- shrinkray/passes/sat.py +176 -0
- shrinkray/passes/sequences.py +69 -0
- shrinkray/problem.py +318 -0
- shrinkray/py.typed +0 -0
- shrinkray/reducer.py +430 -0
- shrinkray/work.py +217 -0
- shrinkray-0.0.0.dist-info/LICENSE +21 -0
- shrinkray-0.0.0.dist-info/METADATA +170 -0
- shrinkray-0.0.0.dist-info/RECORD +22 -0
- shrinkray-0.0.0.dist-info/WHEEL +4 -0
- shrinkray-0.0.0.dist-info/entry_points.txt +3 -0
shrinkray/passes/patching.py
ADDED
@@ -0,0 +1,280 @@
+from abc import ABC, abstractmethod
+from enum import Enum
+from random import Random
+from typing import Any, Callable, Generic, Iterable, Sequence, TypeVar, cast
+
+import trio
+
+from shrinkray.problem import ReductionProblem
+
+Seq = TypeVar("Seq", bound=Sequence[Any])
+T = TypeVar("T")
+
+PatchType = TypeVar("PatchType")
+TargetType = TypeVar("TargetType")
+
+
+class Conflict(Exception):
+    pass
+
+
+class Patches(Generic[PatchType, TargetType], ABC):
+    @property
+    @abstractmethod
+    def empty(self) -> PatchType: ...
+
+    @abstractmethod
+    def combine(self, *patches: PatchType) -> PatchType: ...
+
+    @abstractmethod
+    def apply(self, patch: PatchType, target: TargetType) -> TargetType: ...
+
+    @abstractmethod
+    def size(self, patch: PatchType) -> int: ...
+
+
+class SetPatches(Patches[frozenset[T], TargetType]):
+    def __init__(self, apply: Callable[[frozenset[T], TargetType], TargetType]):
+        self.__apply = apply
+
+    @property
+    def empty(self):
+        return frozenset()
+
+    def combine(self, *patches: frozenset[T]) -> frozenset[T]:
+        result = set()
+        for p in patches:
+            result.update(p)
+        return frozenset(result)
+
+    def apply(self, patch: frozenset[T], target: TargetType) -> TargetType:
+        return self.__apply(patch, target)
+
+    def size(self, patch: frozenset[T]) -> int:
+        return len(patch)
+
+
+class ListPatches(Patches[list[T], TargetType]):
+    def __init__(self, apply: Callable[[list[T], TargetType], TargetType]):
+        self.__apply = apply
+
+    @property
+    def empty(self):
+        return []
+
+    def combine(self, *patches: list[T]) -> list[T]:
+        result = []
+        for p in patches:
+            result.extend(p)
+        return result
+
+    def apply(self, patch: list[T], target: TargetType) -> TargetType:
+        return self.__apply(patch, target)
+
+    def size(self, patch: list[T]) -> int:
+        return len(patch)
+
+
+class PatchApplier(Generic[PatchType, TargetType], ABC):
+    def __init__(
+        self,
+        patches: Patches[PatchType, TargetType],
+        problem: ReductionProblem[TargetType],
+    ):
+        self.__patches = patches
+        self.__problem = problem
+
+        self.__tick = 0
+        self.__merge_queue = []
+        self.__merge_lock = trio.Lock()
+
+        self.__current_patch = self.__patches.empty
+        self.__initial_test_case = problem.current_test_case
+
+    async def try_apply_patch(self, patch: PatchType) -> bool:
+        initial_patch = self.__current_patch
+        try:
+            combined_patch = self.__patches.combine(initial_patch, patch)
+        except Conflict:
+            return False
+        if combined_patch == self.__current_patch:
+            return True
+        with_patch_applied = self.__patches.apply(
+            combined_patch, self.__initial_test_case
+        )
+        if with_patch_applied == self.__problem.current_test_case:
+            return True
+        if not await self.__problem.is_interesting(with_patch_applied):
+            return False
+        send_merge_result, receive_merge_result = trio.open_memory_channel(1)
+
+        sort_key = (self.__tick, self.__problem.sort_key(with_patch_applied))
+        self.__tick += 1
+
+        self.__merge_queue.append((sort_key, patch, send_merge_result))
+
+        async with self.__merge_lock:
+            if (
+                self.__current_patch == initial_patch
+                and len(self.__merge_queue) == 1
+                and self.__merge_queue[0][1] == patch
+                and self.__problem.sort_key(with_patch_applied)
+                <= self.__problem.sort_key(self.__problem.current_test_case)
+            ):
+                self.__current_patch = combined_patch
+                self.__merge_queue.clear()
+                return True
+
+            while self.__merge_queue:
+                base_patch = self.__current_patch
+                to_merge = len(self.__merge_queue)
+
+                async def can_merge(k):
+                    if k > to_merge:
+                        return False
+                    try:
+                        attempted_patch = self.__patches.combine(
+                            base_patch, *[p for _, p, _ in self.__merge_queue[:k]]
+                        )
+                    except Conflict:
+                        return False
+                    if attempted_patch == base_patch:
+                        return True
+                    with_patch_applied = self.__patches.apply(
+                        attempted_patch, self.__initial_test_case
+                    )
+                    if await self.__problem.is_reduction(with_patch_applied):
+                        self.__current_patch = attempted_patch
+                        return True
+                    else:
+                        return False
+
+                if await can_merge(to_merge):
+                    merged = to_merge
+                else:
+                    merged = await self.__problem.work.find_large_integer(can_merge)
+
+                for _, _, send_result in self.__merge_queue[:merged]:
+                    send_result.send_nowait(True)
+
+                assert merged <= to_merge
+                if merged < to_merge:
+                    self.__merge_queue[merged][-1].send_nowait(False)
+                    del self.__merge_queue[: merged + 1]
+                else:
+                    del self.__merge_queue[:to_merge]
+
+        # This should always have been populated during the previous merge,
+        # either by us or someone else merging.
+        return receive_merge_result.receive_nowait()
+
+
+class Direction(Enum):
+    LEFT = 0
+    RIGHT = 1
+
+
+class Completed(Exception):
+    pass
+
+
+async def apply_patches(
+    problem: ReductionProblem[TargetType],
+    patch_info: Patches[PatchType, TargetType],
+    patches: Iterable[PatchType],
+) -> None:
+    if await problem.is_interesting(
+        patch_info.apply(patch_info.combine(*patches), problem.current_test_case)
+    ):
+        return
+
+    applier = PatchApplier(patch_info, problem)
+
+    send_patches, receive_patches = trio.open_memory_channel(float("inf"))
+
+    patches = list(patches)
+    problem.work.random.shuffle(patches)
+    patches.sort(key=patch_info.size, reverse=True)
+    for patch in patches:
+        send_patches.send_nowait(patch)
+    send_patches.close()
+
+    async with trio.open_nursery() as nursery:
+        for _ in range(problem.work.parallelism):
+
+            @nursery.start_soon
+            async def _():
+                while True:
+                    try:
+                        patch = await receive_patches.receive()
+                    except trio.EndOfChannel:
+                        break
+                    await applier.try_apply_patch(patch)
+
+
+class LazyMutableRange:
+    def __init__(self, n: int):
+        self.__size = n
+        self.__mask: dict[int, int] = {}
+
+    def __getitem__(self, i: int) -> int:
+        return self.__mask.get(i, i)
+
+    def __setitem__(self, i: int, v: int) -> None:
+        self.__mask[i] = v
+
+    def __len__(self) -> int:
+        return self.__size
+
+    def pop(self) -> int:
+        i = len(self) - 1
+        result = self[i]
+        self.__size = i
+        self.__mask.pop(i, None)
+        return result
+
+
+def lazy_shuffle(seq: Sequence[T], rnd: Random) -> Iterable[T]:
+    indices = LazyMutableRange(len(seq))
+    while indices:
+        j = len(indices) - 1
+        i = rnd.randrange(0, len(indices))
+        indices[i], indices[j] = indices[j], indices[i]
+        yield seq[indices.pop()]
+
+
+CutPatch = list[tuple[int, int]]
+
+
+class Cuts(Patches[CutPatch, Seq]):
+    @property
+    def empty(self) -> CutPatch:
+        return []
+
+    def combine(self, *patches: CutPatch) -> CutPatch:
+        all_cuts: CutPatch = []
+        for p in patches:
+            all_cuts.extend(p)
+        all_cuts.sort()
+        normalized: list[list[int]] = []
+        for start, end in all_cuts:
+            if normalized and normalized[-1][-1] >= start:
+                normalized[-1][-1] = max(normalized[-1][-1], end)
+            else:
+                normalized.append([start, end])
+        return [cast(tuple[int, int], tuple(x)) for x in normalized]
+
+    def apply(self, patch: CutPatch, target: Seq) -> Seq:
+        result: list[Any] = []
+        prev = 0
+        total_deleted = 0
+        for start, end in patch:
+            total_deleted += end - start
+            result.extend(target[prev:start])
+            prev = end
+        result.extend(target[prev:])
+        assert len(result) + total_deleted == len(target)
+        return type(target)(result)  # type: ignore
+
+    def size(self, patch: CutPatch) -> int:
+        return sum(v - u for u, v in patch)
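
To illustrate how the Cuts patch type above behaves, here is a minimal usage sketch (not part of the package) covering combine, apply, and size on a bytes target; it assumes the wheel is installed so that shrinkray.passes.patching is importable:

    from shrinkray.passes.patching import Cuts

    cuts = Cuts()
    # Overlapping cuts are merged into disjoint, sorted intervals.
    patch = cuts.combine([(0, 2)], [(1, 4)], [(6, 7)])
    assert patch == [(0, 4), (6, 7)]
    # Applying a patch deletes the covered index ranges from the original target.
    assert cuts.apply(patch, b"abcdefgh") == b"efh"
    # size counts how many elements the patch would delete.
    assert cuts.size(patch) == 5
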
shrinkray/passes/python.py
ADDED
@@ -0,0 +1,176 @@
+from typing import Any, AnyStr, Callable
+
+import libcst
+import libcst.matchers as m
+from libcst import CSTNode, codemod
+
+from shrinkray.problem import ReductionProblem
+from shrinkray.work import NotFound
+
+
+def is_python(source: AnyStr) -> bool:
+    try:
+        libcst.parse_module(source)
+        return True
+    except (SyntaxError, UnicodeDecodeError, libcst.ParserSyntaxError, Exception):
+        return False
+
+
+Replacement = CSTNode | libcst.RemovalSentinel | libcst.FlattenSentinel[Any]
+
+
+async def libcst_transform(
+    problem: ReductionProblem[bytes],
+    matcher: m.BaseMatcherNode,
+    transformer: Callable[
+        [CSTNode],
+        Replacement,
+    ],
+) -> None:
+    class CM(codemod.VisitorBasedCodemodCommand):
+        def __init__(self, context: codemod.CodemodContext, target_index: int):
+            super().__init__(context)
+            self.target_index = target_index
+            self.current_index = 0
+            self.fired = False
+
+        # We have to have an ignore on the return type because if we don't LibCST
+        # will do some stupid bullshit with checking if the return type is correct
+        # and we use this generically in a way that makes it hard to type correctly.
+        @m.leave(matcher)
+        def maybe_change_node(self, _, updated_node):  # type: ignore
+            if self.current_index == self.target_index:
+                self.fired = True
+                return transformer(updated_node)
+            else:
+                self.current_index += 1
+                return updated_node
+
+    try:
+        module = libcst.parse_module(problem.current_test_case)
+    except Exception:
+        return
+
+    context = codemod.CodemodContext()
+
+    counting_mod = CM(context, -1)
+    counting_mod.transform_module(module)
+
+    n = counting_mod.current_index + 1
+
+    async def can_apply(i: int) -> bool:
+        nonlocal n
+        if i >= n:
+            return False
+        initial_test_case = problem.current_test_case
+        try:
+            module = libcst.parse_module(initial_test_case)
+        except libcst.ParserSyntaxError:
+            n = 0
+            return False
+
+        codemod_i = CM(context, i)
+        try:
+            transformed = codemod_i.transform_module(module)
+        except libcst.CSTValidationError:
+            return False
+        except TypeError as e:
+            if "does not allow for it" in e.args[0]:
+                return False
+            raise
+
+        if not codemod_i.fired:
+            n = i
+            return False
+
+        transformed_test_case = transformed.code.encode(transformed.encoding)
+
+        if problem.sort_key(transformed_test_case) >= problem.sort_key(
+            initial_test_case
+        ):
+            return False
+
+        return await problem.is_interesting(transformed_test_case)
+
+    i = 0
+    while i < n:
+        try:
+            i = await problem.work.find_first_value(range(i, n), can_apply)
+        except NotFound:
+            break
+
+
+async def lift_indented_constructs(problem: ReductionProblem[bytes]) -> None:
+    await libcst_transform(
+        problem,
+        m.OneOf(m.While(), m.If(), m.Try()),
+        lambda x: x.with_changes(orelse=None),
+    )
+
+    await libcst_transform(
+        problem,
+        m.OneOf(m.While(), m.If(), m.Try(), m.With()),
+        lambda x: libcst.FlattenSentinel(x.body.body),  # type: ignore
+    )
+
+
+async def delete_statements(problem: ReductionProblem[bytes]) -> None:
+    await libcst_transform(
+        problem,
+        m.SimpleStatementLine(),
+        lambda x: libcst.RemoveFromParent(),  # type: ignore
+    )
+
+
+async def replace_statements_with_pass(problem: ReductionProblem[bytes]) -> None:
+    await libcst_transform(
+        problem,
+        m.SimpleStatementLine(),
+        lambda x: x.with_changes(body=[libcst.Pass()]),  # type: ignore
+    )
+
+
+ELLIPSIS_STATEMENT = libcst.parse_statement("...")
+
+
+async def replace_bodies_with_ellipsis(problem: ReductionProblem[bytes]) -> None:
+    await libcst_transform(
+        problem,
+        m.IndentedBlock(),
+        lambda x: x.with_changes(body=[ELLIPSIS_STATEMENT]),  # type: ignore
+    )
+
+
+async def strip_annotations(problem: ReductionProblem[bytes]) -> None:
+    await libcst_transform(
+        problem,
+        m.FunctionDef(),
+        lambda x: x.with_changes(returns=None),
+    )
+    await libcst_transform(
+        problem,
+        m.Param(),
+        lambda x: x.with_changes(annotation=None),
+    )
+    await libcst_transform(
+        problem,
+        m.AnnAssign(),
+        lambda x: (
+            libcst.Assign(
+                targets=[libcst.AssignTarget(target=x.target)],
+                value=x.value,
+                semicolon=x.semicolon,
+            )
+            if x.value
+            else libcst.RemoveFromParent()
+        ),
+    )
+
+
+PYTHON_PASSES = [
+    replace_bodies_with_ellipsis,
+    strip_annotations,
+    lift_indented_constructs,
+    delete_statements,
+    replace_statements_with_pass,
+]
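
The passes above drive LibCST codemods through shrinkray's ReductionProblem machinery. The following standalone sketch (an illustration, not shipped in the package) shows the same matcher-plus-transform pattern that libcst_transform uses, applied once to every IndentedBlock instead of one candidate node at a time:

    import libcst
    import libcst.matchers as m
    from libcst import codemod

    class ReplaceBodiesWithEllipsis(codemod.VisitorBasedCodemodCommand):
        # Same shape as replace_bodies_with_ellipsis above, but without the
        # one-node-at-a-time bookkeeping or interestingness checks.
        @m.leave(m.IndentedBlock())
        def replace_body(self, original_node, updated_node):  # type: ignore
            return updated_node.with_changes(body=[libcst.parse_statement("...")])

    module = libcst.parse_module("def f(x):\n    y = x + 1\n    return y\n")
    transform = ReplaceBodiesWithEllipsis(codemod.CodemodContext())
    print(transform.transform_module(module).code)
    # def f(x):
    #     ...
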
shrinkray/passes/sat.py
ADDED
@@ -0,0 +1,176 @@
+from shrinkray.passes.definitions import Format, ParseError, ReductionPass
+from shrinkray.passes.patching import SetPatches, apply_patches
+from shrinkray.passes.sequences import delete_elements
+from shrinkray.problem import ReductionProblem
+
+Clause = list[int]
+SAT = list[Clause]
+
+
+class _DimacsCNF(Format[bytes, SAT]):
+    @property
+    def name(self) -> str:
+        return "DimacsCNF"
+
+    def parse(self, input: bytes) -> SAT:
+        try:
+            contents = input.decode("utf-8")
+        except UnicodeDecodeError as e:
+            raise ParseError(*e.args)
+        clauses = []
+        for line in contents.splitlines():
+            line = line.strip()
+            if line.startswith("c"):
+                continue
+            if line.startswith("p"):
+                continue
+            if not line.strip():
+                continue
+            try:
+                clause = list(map(int, line.strip().split()))
+            except ValueError as e:
+                raise ParseError(*e.args)
+            if clause[-1] != 0:
+                raise ParseError(f"{line} did not end with 0")
+            clause.pop()
+            clauses.append(clause)
+        if not clauses:
+            raise ParseError("No clauses found")
+        return clauses
+
+    def dumps(self, input: SAT) -> bytes:
+        n_variables = max(abs(literal) for clause in input for literal in clause)
+
+        parts = [f"p cnf {n_variables} {len(input)}"]
+
+        for c in input:
+            parts.append(" ".join(map(repr, list(c) + [0])))
+
+        return "\n".join(parts).encode("utf-8")
+
+
+DimacsCNF = _DimacsCNF()
+
+
+async def renumber_variables(problem: ReductionProblem[SAT]):
+    renumbering = {}
+
+    def renumber(l):
+        if l < 0:
+            return -renumber(-l)
+        try:
+            return renumbering[l]
+        except KeyError:
+            pass
+        result = len(renumbering) + 1
+        renumbering[l] = result
+        return result
+
+    renumbered = [
+        [renumber(literal) for literal in clause]
+        for clause in problem.current_test_case
+    ]
+
+    await problem.is_interesting(renumbered)
+
+
+async def flip_literal_signs(problem: ReductionProblem[SAT]):
+    seen_variables = set()
+    target = problem.current_test_case
+    for i in range(len(target)):
+        for j, v in enumerate(target[i]):
+            if abs(v) not in seen_variables and v < 0:
+                attempt = []
+                for clause in target:
+                    new_clause = []
+                    for literal in clause:
+                        if abs(literal) == abs(v):
+                            new_clause.append(-literal)
+                        else:
+                            new_clause.append(literal)
+                    attempt.append(new_clause)
+                if await problem.is_interesting(attempt):
+                    target = attempt
+            seen_variables.add(abs(v))
+
+
+async def remove_redundant_clauses(problem: ReductionProblem[SAT]):
+    attempt = []
+    seen = set()
+    for clause in problem.current_test_case:
+        if len(set(map(abs, clause))) < len(set(clause)):
+            continue
+        key = tuple(clause)
+        if key in seen:
+            continue
+        seen.add(key)
+        attempt.append(clause)
+    await problem.is_interesting(attempt)
+
+
+def literals_in(sat: SAT) -> frozenset[int]:
+    return frozenset({literal for clause in sat for literal in clause})
+
+
+async def delete_literals(problem: ReductionProblem[SAT]):
+    def remove_literals(literals: frozenset[int], sat: SAT) -> SAT:
+        result = []
+        for clause in sat:
+            new_clause = [v for v in clause if v not in literals]
+            if new_clause:
+                result.append(new_clause)
+        return result
+
+    await apply_patches(
+        problem,
+        SetPatches(remove_literals),
+        [frozenset({v}) for v in literals_in(problem.current_test_case)],
+    )
+
+
+async def merge_variables(problem: ReductionProblem[SAT]):
+    i = 0
+    j = 1
+    while True:
+        variables = sorted({abs(l) for c in problem.current_test_case for l in c})
+        if j >= len(variables):
+            i += 1
+            j = i + 1
+            if j >= len(variables):
+                return
+
+        target = variables[i]
+        to_replace = variables[j]
+
+        new_clauses = []
+        for c in problem.current_test_case:
+            c = set(c)
+            if to_replace in c:
+                c.discard(to_replace)
+                c.add(target)
+            if -to_replace in c:
+                c.discard(-to_replace)
+                c.add(-target)
+            if len(set(map(abs, c))) < len(c):
+                continue
+            new_clauses.append(sorted(c))
+
+        assert new_clauses != problem.current_test_case
+        await problem.is_interesting(new_clauses)
+        if new_clauses != problem.current_test_case:
+            j += 1
+
+
+async def sort_clauses(problem: ReductionProblem[SAT]):
+    await problem.is_interesting(sorted(map(sorted, problem.current_test_case)))
+
+
+SAT_PASSES: list[ReductionPass[SAT]] = [
+    sort_clauses,
+    renumber_variables,
+    flip_literal_signs,
+    remove_redundant_clauses,
+    delete_elements,
+    delete_literals,
+    merge_variables,
+]
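
As a quick sanity check on the DIMACS handling above, this is a small usage sketch (not part of the package) of DimacsCNF.parse and DimacsCNF.dumps, again assuming the package is importable:

    from shrinkray.passes.sat import DimacsCNF

    source = b"c tiny example\np cnf 3 2\n1 -3 0\n-1 2 0\n"
    clauses = DimacsCNF.parse(source)
    assert clauses == [[1, -3], [-1, 2]]
    # dumps recomputes the "p cnf <variables> <clauses>" header from the clauses.
    assert DimacsCNF.dumps(clauses) == b"p cnf 3 2\n1 -3 0\n-1 2 0"
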
shrinkray/passes/sequences.py
ADDED
@@ -0,0 +1,69 @@
+from collections import defaultdict
+from typing import Any, Sequence, TypeVar
+
+from shrinkray.passes.definitions import ReductionPass
+from shrinkray.passes.patching import CutPatch, Cuts, apply_patches
+from shrinkray.problem import ReductionProblem
+
+Seq = TypeVar("Seq", bound=Sequence[Any])
+
+
+async def delete_elements(problem: ReductionProblem[Seq]) -> None:
+    await apply_patches(
+        problem, Cuts(), [[(i, i + 1)] for i in range(len(problem.current_test_case))]
+    )
+
+
+def merged_intervals(intervals: list[tuple[int, int]]) -> list[tuple[int, int]]:
+    normalized: list[list[int]] = []
+    for start, end in sorted(map(tuple, intervals)):
+        if normalized and normalized[-1][-1] >= start:
+            normalized[-1][-1] = max(normalized[-1][-1], end)
+        else:
+            normalized.append([start, end])
+    return list(map(tuple, normalized))  # type: ignore
+
+
+def with_deletions(target: Seq, deletions: list[tuple[int, int]]) -> Seq:
+    result: list[Any] = []
+    prev = 0
+    total_deleted = 0
+    for start, end in deletions:
+        total_deleted += end - start
+        result.extend(target[prev:start])
+        prev = end
+    result.extend(target[prev:])
+    assert len(result) + total_deleted == len(target)
+    return type(target)(result)  # type: ignore
+
+
+def block_deletion(min_block: int, max_block: int) -> ReductionPass[Seq]:
+    async def apply(problem: ReductionProblem[Seq]) -> None:
+        n = len(problem.current_test_case)
+        if n <= min_block:
+            return
+        blocks = [
+            [(i, i + block_size)]
+            for block_size in range(min_block, max_block + 1)
+            for offset in range(block_size)
+            for i in range(offset, n, block_size)
+            if i + block_size <= n
+        ]
+        await apply_patches(problem, Cuts(), blocks)
+
+    apply.__name__ = f"block_deletion({min_block}, {max_block})"
+    return apply
+
+
+async def delete_duplicates(problem: ReductionProblem[Seq]) -> None:
+    index: dict[int, list[int]] = defaultdict(list)
+
+    for i, c in enumerate(problem.current_test_case):
+        index[c].append(i)
+
+    cuts: list[CutPatch] = []
+
+    for ix in index.values():
+        if len(ix) > 1:
+            cuts.append([(i, i + 1) for i in ix])
+    await apply_patches(problem, Cuts(), cuts)
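
merged_intervals and with_deletions are plain helper functions, so their behaviour is easy to check in isolation. A small sketch (not part of the package), assuming the module above is importable:

    from shrinkray.passes.sequences import merged_intervals, with_deletions

    cuts = [(5, 7), (0, 2), (1, 3)]
    merged = merged_intervals(cuts)
    assert merged == [(0, 3), (5, 7)]
    # Deleting those ranges from an 8-byte string keeps indices 3, 4 and 7.
    assert with_deletions(b"abcdefgh", merged) == b"deh"
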