shrinkray 0.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,280 @@
1
+ from abc import ABC, abstractmethod
2
+ from enum import Enum
3
+ from random import Random
4
+ from typing import Any, Callable, Generic, Iterable, Sequence, TypeVar, cast
5
+
6
+ import trio
7
+
8
+ from shrinkray.problem import ReductionProblem
9
+
10
# Generic type variables shared by the patching machinery.
Seq = TypeVar("Seq", bound=Sequence[Any])  # any sequence-like test case
T = TypeVar("T")  # element type of set/list patches and shuffled sequences

# A patch describes an edit; the target is the value it is applied to.
PatchType = TypeVar("PatchType")
TargetType = TypeVar("TargetType")
15
+
16
+
17
class Conflict(Exception):
    """Raised by ``Patches.combine`` when the given patches cannot be combined."""
19
+
20
+
21
class Patches(Generic[PatchType, TargetType], ABC):
    """Abstract description of how one kind of patch behaves.

    Subclasses define the starting (empty) patch, how several patches merge
    into one, how a patch is applied to a target, and how big a patch is.
    """

    @property
    @abstractmethod
    def empty(self) -> PatchType:
        """The patch to start accumulating from."""

    @abstractmethod
    def combine(self, *patches: PatchType) -> PatchType:
        """Merge ``patches`` into a single patch; may raise ``Conflict``."""

    @abstractmethod
    def apply(self, patch: PatchType, target: TargetType) -> TargetType:
        """Return the result of applying ``patch`` to ``target``."""

    @abstractmethod
    def size(self, patch: PatchType) -> int:
        """Return a measure of how large ``patch`` is."""
34
+
35
+
36
class SetPatches(Patches[frozenset[T], TargetType]):
    """Patches represented as frozensets and combined by set union.

    Applying a patch is delegated to the ``apply`` callback given at
    construction time.
    """

    def __init__(self, apply: Callable[[frozenset[T], TargetType], TargetType]):
        self.__apply = apply

    @property
    def empty(self):
        return frozenset()

    def combine(self, *patches: frozenset[T]) -> frozenset[T]:
        # Union of all patches; with no arguments this is the empty frozenset.
        return frozenset().union(*patches)

    def apply(self, patch: frozenset[T], target: TargetType) -> TargetType:
        return self.__apply(patch, target)

    def size(self, patch: frozenset[T]) -> int:
        return len(patch)
55
+
56
+
57
class ListPatches(Patches[list[T], TargetType]):
    """Patches represented as lists and combined by concatenation.

    Applying a patch is delegated to the ``apply`` callback given at
    construction time.
    """

    def __init__(self, apply: Callable[[list[T], TargetType], TargetType]):
        self.__apply = apply

    @property
    def empty(self):
        return []

    def combine(self, *patches: list[T]) -> list[T]:
        # Concatenate in argument order into a fresh list.
        return [item for patch in patches for item in patch]

    def apply(self, patch: list[T], target: TargetType) -> TargetType:
        return self.__apply(patch, target)

    def size(self, patch: list[T]) -> int:
        return len(patch)
76
+
77
+
78
class PatchApplier(Generic[PatchType, TargetType], ABC):
    """Accumulate patches that keep ``problem`` interesting, one candidate at a time.

    Each candidate is combined with the patch accumulated so far and applied
    to the test case captured at construction time. ``try_apply_patch`` may be
    awaited from several tasks at once (``apply_patches`` does this).
    Candidates that pass the interestingness test are queued and then merged
    into the accumulated patch while holding a lock.
    """

    def __init__(
        self,
        patches: Patches[PatchType, TargetType],
        problem: ReductionProblem[TargetType],
    ):
        self.__patches = patches
        self.__problem = problem

        # Counter incremented once per queued candidate; used as the first
        # element of each queue entry's sort key.
        self.__tick = 0
        # Entries are (sort_key, patch, send_channel) for interesting
        # candidates that are waiting to be merged.
        self.__merge_queue = []
        self.__merge_lock = trio.Lock()

        # The accumulated patch starts empty. Every patch is applied relative
        # to the test case as it was when this applier was created.
        self.__current_patch = self.__patches.empty
        self.__initial_test_case = problem.current_test_case

    async def try_apply_patch(self, patch: PatchType) -> bool:
        """Try to fold ``patch`` into the accumulated patch.

        Returns True in three cases: the patch is merged; combining it leaves
        the current patch unchanged; or applying it yields the problem's
        current test case. In that last case the patch is not recorded.
        Returns False in three cases: it conflicts with the current patch;
        the result is not interesting; or it was the first queued candidate
        that could not be merged together with the others.
        """
        initial_patch = self.__current_patch
        try:
            combined_patch = self.__patches.combine(initial_patch, patch)
        except Conflict:
            return False
        if combined_patch == self.__current_patch:
            return True
        with_patch_applied = self.__patches.apply(
            combined_patch, self.__initial_test_case
        )
        # Already the current test case: report success without recording the patch.
        if with_patch_applied == self.__problem.current_test_case:
            return True
        if not await self.__problem.is_interesting(with_patch_applied):
            return False
        # Capacity-1 channel through which whichever task performs the merge
        # reports this candidate's outcome.
        send_merge_result, receive_merge_result = trio.open_memory_channel(1)

        # NOTE(review): the sort_key stored in queue entries is never read in
        # this class; the queue is processed in insertion order.
        sort_key = (self.__tick, self.__problem.sort_key(with_patch_applied))
        self.__tick += 1

        self.__merge_queue.append((sort_key, patch, send_merge_result))

        async with self.__merge_lock:
            # Fast path: nothing was merged since we started, ours is the only
            # queued candidate, and its result is no worse (by sort_key) than
            # the current test case. Adopt our combined patch directly.
            if (
                self.__current_patch == initial_patch
                and len(self.__merge_queue) == 1
                and self.__merge_queue[0][1] == patch
                and self.__problem.sort_key(with_patch_applied)
                <= self.__problem.sort_key(self.__problem.current_test_case)
            ):
                self.__current_patch = combined_patch
                self.__merge_queue.clear()
                return True

            # Otherwise merge the longest queue prefix that still works,
            # repeating until the queue is drained.
            while self.__merge_queue:
                base_patch = self.__current_patch
                to_merge = len(self.__merge_queue)

                # True if the first k queued patches combine with base_patch into a
                # reduction. Side effect: on success, records that combined patch as
                # the current patch.
                async def can_merge(k):
                    if k > to_merge:
                        return False
                    try:
                        attempted_patch = self.__patches.combine(
                            base_patch, *[p for _, p, _ in self.__merge_queue[:k]]
                        )
                    except Conflict:
                        return False
                    if attempted_patch == base_patch:
                        return True
                    with_patch_applied = self.__patches.apply(
                        attempted_patch, self.__initial_test_case
                    )
                    if await self.__problem.is_reduction(with_patch_applied):
                        self.__current_patch = attempted_patch
                        return True
                    else:
                        return False

                # Try the whole queue first. Otherwise search for a large
                # mergeable prefix: find_large_integer presumably returns the
                # largest k with can_merge(k) true (TODO confirm against shrinkray.work).
                if await can_merge(to_merge):
                    merged = to_merge
                else:
                    merged = await self.__problem.work.find_large_integer(can_merge)

                for _, _, send_result in self.__merge_queue[:merged]:
                    send_result.send_nowait(True)

                # The entry immediately after the merged prefix is rejected.
                # Later entries stay queued for the next loop iteration.
                assert merged <= to_merge
                if merged < to_merge:
                    self.__merge_queue[merged][-1].send_nowait(False)
                    del self.__merge_queue[: merged + 1]
                else:
                    del self.__merge_queue[:to_merge]

        # This should always have been populated during the previous merge,
        # either by us or someone else merging.
        return receive_merge_result.receive_nowait()
170
+
171
+
172
class Direction(Enum):
    """A left/right direction marker."""

    LEFT = 0
    RIGHT = 1
175
+
176
+
177
class Completed(Exception):
    """Exception named ``Completed``.

    NOTE(review): it is neither raised nor caught in the code shown here.
    """
179
+
180
+
181
async def apply_patches(
    problem: ReductionProblem[TargetType],
    patch_info: Patches[PatchType, TargetType],
    patches: Iterable[PatchType],
) -> None:
    """Try to apply as many of ``patches`` to ``problem`` as possible.

    First tries all patches combined at once. If that result is interesting,
    there is nothing more to do. Otherwise the patches are fed, largest first
    (ties in random order), to ``problem.work.parallelism`` concurrent workers
    that share one ``PatchApplier``.

    ``patches`` may be any iterable, including a one-shot iterator.
    """
    # Materialise once: ``patches`` is consumed twice below (the combined
    # attempt and the work queue), and an iterator would be empty the second time.
    patches = list(patches)

    try:
        everything = patch_info.combine(*patches)
    except Conflict:
        # The patches cannot all be combined at once. Skip the shortcut and
        # let the applier merge compatible subsets instead.
        pass
    else:
        if await problem.is_interesting(
            patch_info.apply(everything, problem.current_test_case)
        ):
            return

    applier = PatchApplier(patch_info, problem)

    send_patches, receive_patches = trio.open_memory_channel(float("inf"))

    # Shuffle first so the stable sort leaves equal-sized patches in random order.
    problem.work.random.shuffle(patches)
    patches.sort(key=patch_info.size, reverse=True)
    for patch in patches:
        send_patches.send_nowait(patch)
    send_patches.close()

    async def worker() -> None:
        # Pull patches until the (closed) channel is drained.
        while True:
            try:
                patch = await receive_patches.receive()
            except trio.EndOfChannel:
                break
            await applier.try_apply_patch(patch)

    async with trio.open_nursery() as nursery:
        for _ in range(problem.work.parallelism):
            nursery.start_soon(worker)
213
+
214
+
215
class LazyMutableRange:
    """A mutable stand-in for ``list(range(n))`` that stores only overwritten entries.

    Index ``i`` reads as ``i`` until it is assigned, so construction costs
    the same regardless of ``n``. Only ``pop`` (from the end) shrinks it.
    """

    def __init__(self, n: int):
        self.__size = n
        # Sparse overrides: index -> value. Missing keys read as the index itself.
        self.__mask: dict[int, int] = {}

    def __getitem__(self, i: int) -> int:
        return self.__mask.get(i, i)

    def __setitem__(self, i: int, v: int) -> None:
        self.__mask[i] = v

    def __len__(self) -> int:
        return self.__size

    def pop(self) -> int:
        """Remove and return the last element.

        Raises:
            IndexError: if the range is empty.
        """
        if self.__size <= 0:
            raise IndexError("pop from empty LazyMutableRange")
        i = self.__size - 1
        result = self[i]
        self.__size = i
        # Drop any override so the mask never keeps entries past the end.
        self.__mask.pop(i, None)
        return result
235
+
236
+
237
def lazy_shuffle(seq: Sequence[T], rnd: Random) -> Iterable[T]:
    """Yield the elements of ``seq`` in a random order drawn from ``rnd``.

    This is a Fisher–Yates shuffle performed on demand. Each step swaps a
    randomly chosen remaining index into the last slot and pops it, so
    elements that are never consumed cost nothing.
    """
    remaining = LazyMutableRange(len(seq))
    while len(remaining) > 0:
        last = len(remaining) - 1
        pick = rnd.randrange(0, len(remaining))
        remaining[pick], remaining[last] = remaining[last], remaining[pick]
        yield seq[remaining.pop()]
244
+
245
+
246
# A cut patch: a list of half-open (start, end) index ranges to delete.
CutPatch = list[tuple[int, int]]
247
+
248
+
249
class Cuts(Patches[CutPatch, Seq]):
    """Patches that delete ranges of a sequence.

    A patch is a list of half-open ``(start, end)`` ranges. ``combine``
    normalises patches into sorted, non-overlapping ranges, and ``apply``
    requires that normalised form.
    """

    @property
    def empty(self) -> CutPatch:
        return []

    def combine(self, *patches: CutPatch) -> CutPatch:
        """Union all cuts into sorted ranges, merging overlapping or touching ones."""
        all_cuts: CutPatch = []
        for p in patches:
            all_cuts.extend(p)
        all_cuts.sort()
        normalized: list[list[int]] = []
        for start, end in all_cuts:
            if normalized and normalized[-1][-1] >= start:
                normalized[-1][-1] = max(normalized[-1][-1], end)
            else:
                normalized.append([start, end])
        return [cast(tuple[int, int], tuple(x)) for x in normalized]

    def apply(self, patch: CutPatch, target: Seq) -> Seq:
        """Return ``target`` with every range in ``patch`` deleted.

        ``patch`` must be sorted and non-overlapping, as produced by
        ``combine``. The result has the same type as ``target``.
        """
        result: list[Any] = []
        prev = 0
        total_deleted = 0
        for start, end in patch:
            total_deleted += end - start
            result.extend(target[prev:start])
            prev = end
        result.extend(target[prev:])
        assert len(result) + total_deleted == len(target)
        if isinstance(target, str):
            # str(list_of_chars) would give the list's repr, not the remaining text.
            return type(target)("".join(result))  # type: ignore
        return type(target)(result)  # type: ignore

    def size(self, patch: CutPatch) -> int:
        """Sum of the lengths of the ranges in ``patch``."""
        return sum(v - u for u, v in patch)
@@ -0,0 +1,176 @@
1
+ from typing import Any, AnyStr, Callable
2
+
3
+ import libcst
4
+ import libcst.matchers as m
5
+ from libcst import CSTNode, codemod
6
+
7
+ from shrinkray.problem import ReductionProblem
8
+ from shrinkray.work import NotFound
9
+
10
+
11
def is_python(source: AnyStr) -> bool:
    """Return True if libcst can parse ``source`` as a Python module.

    Any failure to parse, whatever libcst raises (syntax errors, undecodable
    bytes, ...), is treated as "not Python". ``Exception`` already covers every
    case the previous explicit tuple listed.
    """
    try:
        libcst.parse_module(source)
    except Exception:
        return False
    return True
17
+
18
+
19
# What a transformer may return in place of a matched node: a new node,
# a removal sentinel, or a flattened sequence of nodes.
Replacement = CSTNode | libcst.RemovalSentinel | libcst.FlattenSentinel[Any]
20
+
21
+
22
async def libcst_transform(
    problem: ReductionProblem[bytes],
    matcher: m.BaseMatcherNode,
    transformer: Callable[
        [CSTNode],
        Replacement,
    ],
) -> None:
    """Try rewriting, one at a time, each node of ``problem`` matched by ``matcher``.

    Matching nodes are numbered in the order libcst leaves them. For a given
    index, only that node is replaced with ``transformer(node)``. The rewrite
    is kept when it is smaller under ``problem.sort_key`` and still
    interesting. Test cases that libcst cannot parse are left alone.
    """

    class CM(codemod.VisitorBasedCodemodCommand):
        """Codemod that transforms only the ``target_index``-th matching node."""

        def __init__(self, context: codemod.CodemodContext, target_index: int):
            super().__init__(context)
            self.target_index = target_index
            self.current_index = 0
            self.fired = False

        # The return type is deliberately unannotated (hence the ignore):
        # LibCST checks returned values against the annotation, and this
        # generic use cannot be typed precisely.
        @m.leave(matcher)
        def maybe_change_node(self, _, updated_node):  # type: ignore
            index = self.current_index
            # Advance on every match, including the target. Exactly one node is
            # then transformed, and current_index ends equal to the match count.
            self.current_index += 1
            if index == self.target_index:
                self.fired = True
                return transformer(updated_node)
            return updated_node

    try:
        module = libcst.parse_module(problem.current_test_case)
    except Exception:
        return

    context = codemod.CodemodContext()

    # A target of -1 never fires, so this pass only counts matching nodes.
    counting_mod = CM(context, -1)
    counting_mod.transform_module(module)

    n = counting_mod.current_index

    async def can_apply(i: int) -> bool:
        nonlocal n
        if i >= n:
            return False
        initial_test_case = problem.current_test_case
        try:
            module = libcst.parse_module(initial_test_case)
        except libcst.ParserSyntaxError:
            n = 0
            return False

        codemod_i = CM(context, i)
        try:
            transformed = codemod_i.transform_module(module)
        except libcst.CSTValidationError:
            return False
        except TypeError as e:
            # Presumably libcst rejecting a replacement in this position.
            # Use str(e): args may be empty or non-string.
            if "does not allow for it" in str(e):
                return False
            raise

        if not codemod_i.fired:
            # Fewer than i + 1 matches remain, so no index >= i can fire.
            n = i
            return False

        transformed_test_case = transformed.code.encode(transformed.encoding)

        if problem.sort_key(transformed_test_case) >= problem.sort_key(
            initial_test_case
        ):
            return False

        return await problem.is_interesting(transformed_test_case)

    i = 0
    while i < n:
        try:
            i = await problem.work.find_first_value(range(i, n), can_apply)
        except NotFound:
            break
101
+
102
+
103
async def lift_indented_constructs(problem: ReductionProblem[bytes]) -> None:
    """Simplify compound statements.

    First tries dropping the ``orelse`` clause of ``while``/``if``/``try``
    statements. Then tries replacing ``while``/``if``/``try``/``with``
    statements with the statements of their body.
    """
    drop_else = m.OneOf(m.While(), m.If(), m.Try())
    await libcst_transform(problem, drop_else, lambda x: x.with_changes(orelse=None))

    inline_body = m.OneOf(m.While(), m.If(), m.Try(), m.With())
    await libcst_transform(
        problem,
        inline_body,
        lambda x: libcst.FlattenSentinel(x.body.body),  # type: ignore
    )
115
+
116
+
117
async def delete_statements(problem: ReductionProblem[bytes]) -> None:
    """Try removing simple statement lines via ``libcst_transform``."""

    def remove(_node):
        return libcst.RemoveFromParent()

    await libcst_transform(problem, m.SimpleStatementLine(), remove)  # type: ignore
123
+
124
+
125
async def replace_statements_with_pass(problem: ReductionProblem[bytes]) -> None:
    """Try replacing the body of simple statement lines with a single ``pass``."""

    def to_pass(node):
        return node.with_changes(body=[libcst.Pass()])

    await libcst_transform(problem, m.SimpleStatementLine(), to_pass)  # type: ignore
131
+
132
+
133
# A standalone ``...`` statement node, reused as a placeholder block body.
ELLIPSIS_STATEMENT = libcst.parse_statement("...")
134
+
135
+
136
async def replace_bodies_with_ellipsis(problem: ReductionProblem[bytes]) -> None:
    """Try replacing the contents of indented blocks with a lone ``...``."""

    def to_ellipsis(block):
        return block.with_changes(body=[ELLIPSIS_STATEMENT])

    await libcst_transform(problem, m.IndentedBlock(), to_ellipsis)  # type: ignore
142
+
143
+
144
async def strip_annotations(problem: ReductionProblem[bytes]) -> None:
    """Try removing type annotations.

    Return annotations and parameter annotations are dropped. An annotated
    assignment ``x: T = v`` becomes ``x = v``, and a bare ``x: T`` is removed.
    """

    def without_annotation(ann):
        if not ann.value:
            return libcst.RemoveFromParent()
        return libcst.Assign(
            targets=[libcst.AssignTarget(target=ann.target)],
            value=ann.value,
            semicolon=ann.semicolon,
        )

    await libcst_transform(
        problem, m.FunctionDef(), lambda x: x.with_changes(returns=None)
    )
    await libcst_transform(
        problem, m.Param(), lambda x: x.with_changes(annotation=None)
    )
    await libcst_transform(problem, m.AnnAssign(), without_annotation)
168
+
169
+
170
# Python-specific reduction passes, listed in this order.
PYTHON_PASSES = [
    replace_bodies_with_ellipsis,
    strip_annotations,
    lift_indented_constructs,
    delete_statements,
    replace_statements_with_pass,
]
@@ -0,0 +1,176 @@
1
+ from shrinkray.passes.definitions import Format, ParseError, ReductionPass
2
+ from shrinkray.passes.patching import SetPatches, apply_patches
3
+ from shrinkray.passes.sequences import delete_elements
4
+ from shrinkray.problem import ReductionProblem
5
+
6
# A clause: integer literals, where a negative literal negates variable abs(literal).
Clause = list[int]
# A formula: a list of clauses.
SAT = list[Clause]
8
+
9
+
10
class _DimacsCNF(Format[bytes, SAT]):
    """Format converting between DIMACS CNF bytes and a list of clauses."""

    @property
    def name(self) -> str:
        return "DimacsCNF"

    def parse(self, input: bytes) -> SAT:
        """Parse DIMACS CNF text into clauses.

        Lines starting with ``c`` or ``p``, and blank lines, are skipped.
        Every other line must be whitespace-separated integers ending in
        ``0``; the trailing ``0`` is dropped.

        Raises:
            ParseError: if the input is not UTF-8, a line is not all
                integers, a clause line does not end in ``0``, or no clauses
                are found.
        """
        try:
            contents = input.decode("utf-8")
        except UnicodeDecodeError as e:
            raise ParseError(*e.args) from e
        clauses = []
        for raw_line in contents.splitlines():
            line = raw_line.strip()
            if not line or line.startswith(("c", "p")):
                continue
            try:
                clause = list(map(int, line.split()))
            except ValueError as e:
                raise ParseError(*e.args) from e
            if clause[-1] != 0:
                raise ParseError(f"{line} did not end with 0")
            clause.pop()
            clauses.append(clause)
        if not clauses:
            raise ParseError("No clauses found")
        return clauses

    def dumps(self, input: SAT) -> bytes:
        """Serialise clauses as DIMACS CNF text with a ``p cnf`` header.

        An empty formula, or one containing only empty clauses (both can be
        produced by deletion passes), is written with zero variables instead
        of crashing.
        """
        n_variables = max(
            (abs(literal) for clause in input for literal in clause), default=0
        )

        parts = [f"p cnf {n_variables} {len(input)}"]

        for c in input:
            parts.append(" ".join(map(repr, list(c) + [0])))

        return "\n".join(parts).encode("utf-8")
50
+
51
+
52
# Shared instance of the DIMACS CNF format.
DimacsCNF = _DimacsCNF()
53
+
54
+
55
async def renumber_variables(problem: ReductionProblem[SAT]):
    """Try renaming variables to 1, 2, 3, ... in order of first appearance.

    Literal signs are preserved. The renumbered formula is offered to
    ``problem.is_interesting`` exactly once.
    """
    mapping = {}

    def renumber(literal):
        if literal < 0:
            return -renumber(-literal)
        if literal not in mapping:
            mapping[literal] = len(mapping) + 1
        return mapping[literal]

    await problem.is_interesting(
        [[renumber(literal) for literal in clause] for clause in problem.current_test_case]
    )
75
+
76
+
77
async def flip_literal_signs(problem: ReductionProblem[SAT]):
    """Try to make the first occurrence of each variable a positive literal.

    Clauses are scanned in order. The first time a variable is met, if that
    occurrence is negative, every literal of the variable is negated and the
    result is offered to ``problem.is_interesting``. Accepted flips are kept,
    and later attempts build on them.
    """
    seen_variables = set()
    target = problem.current_test_case
    for clause_index in range(len(target)):
        for literal in target[clause_index]:
            variable = abs(literal)
            if literal < 0 and variable not in seen_variables:
                attempt = [
                    [-lit if abs(lit) == variable else lit for lit in clause]
                    for clause in target
                ]
                if await problem.is_interesting(attempt):
                    target = attempt
            seen_variables.add(variable)
95
+
96
+
97
async def remove_redundant_clauses(problem: ReductionProblem[SAT]):
    """Try dropping tautological and duplicate clauses.

    A clause is tautological when it contains both a literal and its
    negation. Duplicates are clauses with an identical literal sequence; the
    first copy of each clause is kept, in the original order.
    """
    kept = []
    seen = set()
    for clause in problem.current_test_case:
        is_tautology = len(set(map(abs, clause))) < len(set(clause))
        key = tuple(clause)
        if is_tautology or key in seen:
            continue
        seen.add(key)
        kept.append(clause)
    await problem.is_interesting(kept)
109
+
110
+
111
def literals_in(sat: SAT) -> frozenset[int]:
    """Return the set of distinct literals appearing anywhere in ``sat``."""
    return frozenset(literal for clause in sat for literal in clause)
113
+
114
+
115
async def delete_literals(problem: ReductionProblem[SAT]):
    """Try deleting individual literals from the formula.

    Each distinct literal is a candidate patch. Applying a set of literals
    removes every occurrence of them and drops any clause left empty.
    """

    def remove_literals(literals: frozenset[int], sat: SAT) -> SAT:
        stripped = ([v for v in clause if v not in literals] for clause in sat)
        return [clause for clause in stripped if clause]

    candidates = [
        frozenset({literal}) for literal in literals_in(problem.current_test_case)
    ]
    await apply_patches(problem, SetPatches(remove_literals), candidates)
129
+
130
+
131
async def merge_variables(problem: ReductionProblem[SAT]):
    """Try replacing variables with lower-numbered ones, pair by pair.

    For a pair of variables ``a < b``, every occurrence of ``b`` becomes
    ``a`` with the same sign. Clauses that then contain a literal and its
    negation are dropped, and the remaining clauses are sorted. The variable
    list is recomputed after every attempt so accepted merges take effect.
    """
    i, j = 0, 1
    while True:
        variables = sorted(
            {abs(literal) for clause in problem.current_test_case for literal in clause}
        )
        if j >= len(variables):
            i += 1
            j = i + 1
            if j >= len(variables):
                return

        target = variables[i]
        to_replace = variables[j]

        new_clauses = []
        for clause in problem.current_test_case:
            literals = set(clause)
            if to_replace in literals:
                literals.discard(to_replace)
                literals.add(target)
            if -to_replace in literals:
                literals.discard(-to_replace)
                literals.add(-target)
            if len(set(map(abs, literals))) < len(literals):
                continue
            new_clauses.append(sorted(literals))

        assert new_clauses != problem.current_test_case
        await problem.is_interesting(new_clauses)
        # If the merge was not adopted, move on to the next candidate. If it
        # was, index j now names the next variable.
        if new_clauses != problem.current_test_case:
            j += 1
162
+
163
+
164
async def sort_clauses(problem: ReductionProblem[SAT]):
    """Try sorting the literals within each clause, then the clauses themselves."""
    normalised = sorted(sorted(clause) for clause in problem.current_test_case)
    await problem.is_interesting(normalised)
166
+
167
+
168
# SAT-specific reduction passes, listed in this order.
SAT_PASSES: list[ReductionPass[SAT]] = [
    sort_clauses,
    renumber_variables,
    flip_literal_signs,
    remove_redundant_clauses,
    delete_elements,
    delete_literals,
    merge_variables,
]
@@ -0,0 +1,69 @@
1
+ from collections import defaultdict
2
+ from typing import Any, Sequence, TypeVar
3
+
4
+ from shrinkray.passes.definitions import ReductionPass
5
+ from shrinkray.passes.patching import CutPatch, Cuts, apply_patches
6
+ from shrinkray.problem import ReductionProblem
7
+
8
# Type variable for the sequence-valued test cases these passes operate on.
Seq = TypeVar("Seq", bound=Sequence[Any])
9
+
10
+
11
async def delete_elements(problem: ReductionProblem[Seq]) -> None:
    """Try deleting individual elements of the sequence test case."""
    single_element_cuts = [
        [(i, i + 1)] for i in range(len(problem.current_test_case))
    ]
    await apply_patches(problem, Cuts(), single_element_cuts)
15
+
16
+
17
def merged_intervals(intervals: list[tuple[int, int]]) -> list[tuple[int, int]]:
    """Return ``intervals`` sorted, with overlapping or touching ones merged.

    >>> merged_intervals([(3, 5), (0, 2), (1, 3)])
    [(0, 5)]
    """
    merged: list[tuple[int, int]] = []
    for start, end in sorted(map(tuple, intervals)):
        if merged and start <= merged[-1][1]:
            last_start, last_end = merged[-1]
            merged[-1] = (last_start, max(last_end, end))
        else:
            merged.append((start, end))
    return merged
25
+
26
+
27
def with_deletions(target: Seq, deletions: list[tuple[int, int]]) -> Seq:
    """Return ``target`` with each half-open ``(start, end)`` range removed.

    ``deletions`` must be sorted and non-overlapping (``merged_intervals``
    produces this form). The result has the same type as ``target``; ``str``
    targets are joined back into text rather than passed through ``str(list)``.
    """
    result: list[Any] = []
    prev = 0
    total_deleted = 0
    for start, end in deletions:
        total_deleted += end - start
        result.extend(target[prev:start])
        prev = end
    result.extend(target[prev:])
    assert len(result) + total_deleted == len(target)
    if isinstance(target, str):
        # str(list_of_chars) would give the list's repr, not the remaining text.
        return type(target)("".join(result))  # type: ignore
    return type(target)(result)  # type: ignore
38
+
39
+
40
def block_deletion(min_block: int, max_block: int) -> ReductionPass[Seq]:
    """Build a pass that tries deleting contiguous blocks of the test case.

    Block sizes run from ``min_block`` to ``max_block`` inclusive. For each
    size, candidate blocks start at every position where they fit, grouped
    by starting offset modulo the block size. The pass does nothing when the
    test case has at most ``min_block`` elements.
    """

    async def apply(problem: ReductionProblem[Seq]) -> None:
        n = len(problem.current_test_case)
        if n <= min_block:
            return
        blocks = []
        for block_size in range(min_block, max_block + 1):
            for offset in range(block_size):
                for start in range(offset, n, block_size):
                    if start + block_size <= n:
                        blocks.append([(start, start + block_size)])
        await apply_patches(problem, Cuts(), blocks)

    apply.__name__ = f"block_deletion({min_block}, {max_block})"
    return apply
56
+
57
+
58
async def delete_duplicates(problem: ReductionProblem[Seq]) -> None:
    """Try deleting all copies of any element that occurs more than once.

    Each candidate patch removes every occurrence of one repeated element.
    Elements of the test case must be hashable.
    """
    positions: dict[int, list[int]] = defaultdict(list)
    for i, element in enumerate(problem.current_test_case):
        positions[element].append(i)

    cuts: list[CutPatch] = [
        [(i, i + 1) for i in occurrences]
        for occurrences in positions.values()
        if len(occurrences) > 1
    ]
    await apply_patches(problem, Cuts(), cuts)
+ await apply_patches(problem, Cuts(), cuts)