shrinkray 0.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,547 @@
+ from collections import defaultdict, deque
+ from typing import Sequence
+
+ from attrs import define
+
+ from shrinkray.passes.definitions import Format, ReductionProblem
+ from shrinkray.passes.patching import Cuts, Patches, apply_patches
+
+
+ @define(frozen=True)
+ class Encoding(Format[bytes, str]):
+     encoding: str
+
+     def __repr__(self) -> str:
+         return f"Encoding({repr(self.encoding)})"
+
+     @property
+     def name(self) -> str:
+         return self.encoding
+
+     def parse(self, input: bytes) -> str:
+         return input.decode(self.encoding)
+
+     def dumps(self, input: str) -> bytes:
+         return input.encode(self.encoding)
+
+
+ @define(frozen=True)
+ class Split(Format[bytes, list[bytes]]):
+     splitter: bytes
+
+     def __repr__(self) -> str:
+         return f"Split({repr(self.splitter)})"
+
+     @property
+     def name(self) -> str:
+         return f"split({repr(self.splitter)})"
+
+     def parse(self, input: bytes) -> list[bytes]:
+         return input.split(self.splitter)
+
+     def dumps(self, input: list[bytes]) -> bytes:
+         return self.splitter.join(input)
+
+
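+ # Finds repeated blocks: returns (k, indices) pairs where the same k-byte
+ # substring starts at each listed index. Blocks are grown to the right while
+ # every occurrence agrees on the next byte, kept only when they cannot all be
+ # extended to the left, and the breadth-first search is capped at 10000 results.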
+ def find_ngram_endpoints(value: bytes) -> list[tuple[int, list[int]]]:
+     if len(set(value)) <= 1:
+         return []
+     queue: deque[tuple[int, Sequence[int]]] = deque([(0, range(len(value)))])
+     results: list[tuple[int, list[int]]] = []
+
+     while queue and len(results) < 10000:
+         k, indices = queue.popleft()
+
+         if k > 1:
+             normalized: list[int] = []
+             for i in indices:
+                 if not normalized or i >= normalized[-1] + k:
+                     normalized.append(i)
+             indices = normalized
+
+         while (
+             indices[-1] + k < len(value) and len({value[i + k] for i in indices}) == 1
+         ):
+             k += 1
+
+         if k > 0 and (indices[0] == 0 or len({value[i - 1] for i in indices}) > 1):
+             assert isinstance(indices, list), value
+             results.append((k, indices))
+
+         split: dict[int, list[int]] = defaultdict(list)
+         for i in indices:
+             try:
+                 split[value[i + k]].append(i)
+             except IndexError:
+                 pass
+         queue.extend([(k + 1, v) for v in split.values() if len(v) > 1])
+
+     return results
+
+
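+ # Greedily groups identifier-like runs, numeric runs (digits and dots), and
+ # runs of spaces into tokens; every other byte becomes its own token. The
+ # final assert checks that the tokens concatenate back to the input.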
+ def tokenize(text: bytes) -> list[bytes]:
+     result: list[bytes] = []
+     i = 0
+     while i < len(text):
+         c = bytes([text[i]])
+         j = i + 1
+         if b"A" <= c <= b"z":
+             while j < len(text) and (
+                 b"A"[0] <= text[j] <= b"z"[0]
+                 or text[j] == b"_"[0]
+                 or b"0"[0] <= text[j] <= b"9"[0]
+             ):
+                 j += 1
+         elif b"0" <= c <= b"9":
+             while j < len(text) and (
+                 text[j] == b"."[0] or b"0"[0] <= text[j] <= b"9"[0]
+             ):
+                 j += 1
+         elif c == b" ":
+             while j < len(text) and (text[j] == b" "[0]):
+                 j += 1
+         result.append(text[i:j])
+         i = j
+     assert b"".join(result) == text
+     return result
+
+
+ MAX_DELETE_INTERVAL = 8
+
+
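+ # Builds deletion candidates from the gaps between consecutive occurrences of
+ # repeated n-grams and tries cutting them, longest n-grams and widest gaps
+ # first, skipping gaps shorter than min_size.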
+ async def lexeme_based_deletions(
+     problem: ReductionProblem[bytes], min_size: int = 8
+ ) -> None:
+     intervals_by_k: dict[int, set[tuple[int, int]]] = defaultdict(set)
+
+     for k, endpoints in find_ngram_endpoints(problem.current_test_case):
+         intervals_by_k[k].update(zip(endpoints, endpoints[1:]))
+
+     intervals_to_delete = [
+         t
+         for _, intervals in sorted(intervals_by_k.items(), reverse=True)
+         for t in sorted(intervals, key=lambda t: (t[1] - t[0], t[0]), reverse=True)
+         if t[1] - t[0] >= min_size
+     ]
+
+     await delete_intervals(problem, intervals_to_delete, shuffle=True)
+
+
+ async def delete_intervals(
+     problem: ReductionProblem[bytes],
+     intervals_to_delete: list[tuple[int, int]],
+     shuffle: bool = False,
+ ) -> None:
+     await apply_patches(problem, Cuts(), [[t] for t in intervals_to_delete])
+
+
+ def brace_intervals(target: bytes, brace: bytes) -> list[tuple[int, int]]:
+     open, close = brace
+     intervals: list[tuple[int, int]] = []
+     stack: list[int] = []
+     for i, c in enumerate(target):
+         if c == open:
+             stack.append(i)
+         elif c == close and stack:
+             start = stack.pop() + 1
+             end = i
+             if end > start:
+                 intervals.append((start, end))
+     return intervals
+
+
+ async def debracket(problem: ReductionProblem[bytes]) -> None:
+     cuts = [
+         [(u - 1, u), (v, v + 1)]
+         for brackets in [b"{}", b"()", b"[]"]
+         for u, v in brace_intervals(problem.current_test_case, brackets)
+     ]
+     await apply_patches(
+         problem,
+         Cuts(),
+         cuts,
+     )
+
+
+ def quote_intervals(target: bytes) -> list[tuple[int, int]]:
+     indices: dict[int, list[int]] = defaultdict(list)
+     for i, c in enumerate(target):
+         indices[c].append(i)
+
+     intervals: list[tuple[int, int]] = []
+     for quote in b"\"'":
+         xs = indices[quote]
+         for u, v in zip(xs, xs[1:], strict=False):
+             if u + 1 < v:
+                 intervals.append((u + 1, v))
+     return intervals
+
+
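+ # Tries emptying the interiors of quoted strings, [...] regions, and {...}
+ # regions, collecting candidates smallest-interior-first within each group.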
+ async def hollow(problem: ReductionProblem[bytes]) -> None:
+     target = problem.current_test_case
+     intervals: list[tuple[int, int]] = []
+     for b in [
+         quote_intervals(target),
+         brace_intervals(target, b"[]"),
+         brace_intervals(target, b"{}"),
+     ]:
+         b.sort(key=lambda t: (t[1] - t[0], t[0]))
+         intervals.extend(b)
+     await delete_intervals(
+         problem,
+         intervals,
+     )
+
+
+ async def short_deletions(problem: ReductionProblem[bytes]) -> None:
+     target = problem.current_test_case
+     await delete_intervals(
+         problem,
+         [
+             (i, j)
+             for i in range(len(target))
+             for j in range(i + 1, min(i + 11, len(target) + 1))
+         ],
+     )
+
+
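+ # For each {...} region, tries cutting the text between the region's interior
+ # and each of its immediate child {...} regions, effectively replacing the
+ # parent's contents with the child's.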
+ async def lift_braces(problem: ReductionProblem[bytes]) -> None:
+     target = problem.current_test_case
+
+     open_brace, close_brace = b"{}"
+     start_stack: list[int] = []
+     child_stack: list[list[tuple[int, int]]] = []
+
+     results: list[tuple[int, int, list[tuple[int, int]]]] = []
+
+     for i, c in enumerate(target):
+         if c == open_brace:
+             start_stack.append(i)
+             child_stack.append([])
+         elif c == close_brace and start_stack:
+             start = start_stack.pop() + 1
+             end = i
+             children = child_stack.pop()
+             if child_stack:
+                 child_stack[-1].append((start, end))
+             if end > start:
+                 results.append((start, end, children))
+
+     cuts: list[list[tuple[int, int]]] = []
+     for start, end, children in results:
+         for child_start, child_end in children:
+             cuts.append([(start, child_start), (child_end, end)])
+
+     await apply_patches(problem, Cuts(), cuts)
+
+
+ @define(frozen=True)
+ class Tokenize(Format[bytes, list[bytes]]):
+     def __repr__(self) -> str:
+         return "tokenize"
+
+     @property
+     def name(self) -> str:
+         return "tokenize"
+
+     def parse(self, input: bytes) -> list[bytes]:
+         return tokenize(input)
+
+     def dumps(self, input: list[bytes]) -> bytes:
+         return b"".join(input)
+
+
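+ # For every byte value that occurs more than once, tries cutting the prefix up
+ # to and including its first occurrence, the spans between consecutive
+ # occurrences, and the span from its last occurrence to the end of the input.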
+ async def delete_byte_spans(problem: ReductionProblem[bytes]) -> None:
+     indices: dict[int, list[int]] = defaultdict(list)
+     target = problem.current_test_case
+     for i, c in enumerate(target):
+         indices[c].append(i)
+
+     spans: list[tuple[int, int]] = []
+
+     for c, ix in sorted(indices.items()):
+         if len(ix) > 1:
+             spans.append((0, ix[0] + 1))
+             spans.extend(zip(ix, ix[1:]))
+             spans.append((ix[-1], len(target)))
+
+     await apply_patches(problem, Cuts(), [[s] for s in spans])
+
+
+ async def remove_indents(problem: ReductionProblem[bytes]) -> None:
+     target = problem.current_test_case
+     spans: list[list[tuple[int, int]]] = []
+
+     newline = ord(b"\n")
+     space = ord(b" ")
+
+     for i, c in enumerate(target):
+         if c == newline:
+             j = i + 1
+             while j < len(target) and target[j] == space:
+                 j += 1
+
+             if j > i + 1:
+                 spans.append([(i + 1, j)])
+
+     await apply_patches(problem, Cuts(), spans)
+
+
+ async def remove_whitespace(problem: ReductionProblem[bytes]) -> None:
+     target = problem.current_test_case
+     spans: list[list[tuple[int, int]]] = []
+
+     for i, c in enumerate(target):
+         char = bytes([c])
+         if char.isspace():
+             j = i + 1
+             while j < len(target) and target[j : j + 1].isspace():
+                 j += 1
+
+             if j > i + 1:
+                 spans.append([(i, j)])
+             if j > i + 2:
+                 spans.append([(i + 1, j)])
+
+     await apply_patches(problem, Cuts(), spans)
+
+
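+ # Patch type for replace_space_with_newlines below: a patch is a set of byte
+ # offsets, and applying it rewrites the byte at each offset to a newline.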
+ class NewlineReplacer(Patches[frozenset[int], bytes]):
+     @property
+     def empty(self) -> frozenset[int]:
+         return frozenset()
+
+     def combine(self, *patches: frozenset[int]) -> frozenset[int]:
+         result: set[int] = set()
+         for p in patches:
+             result.update(p)
+         return frozenset(result)
+
+     def apply(self, patch: frozenset[int], target: bytes) -> bytes:
+         result = bytearray()
+
+         for i, c in enumerate(target):
+             if i in patch:
+                 result.extend(b"\n")
+             else:
+                 result.append(c)
+         return bytes(result)
+
+     def size(self, patch: frozenset[int]) -> int:
+         return len(patch)
+
+
+ async def replace_space_with_newlines(problem: ReductionProblem[bytes]) -> None:
+     await apply_patches(
+         problem,
+         NewlineReplacer(),
+         [
+             frozenset({i})
+             for i, c in enumerate(problem.current_test_case)
+             if c in b" \t"
+         ],
+     )
+
+
+ ReplacementPatch = dict[int, int]
+
+
+ class ByteReplacement(Patches[ReplacementPatch, bytes]):
+     @property
+     def empty(self) -> ReplacementPatch:
+         return {}
+
+     def combine(self, *patches: ReplacementPatch) -> ReplacementPatch:
+         result = {}
+         for p in patches:
+             for k, v in p.items():
+                 if k not in result:
+                     result[k] = v
+                 else:
+                     result[k] = min(result[k], v)
+         return result
+
+     def apply(self, patch: ReplacementPatch, target: bytes) -> bytes:
+         result = bytearray()
+         for c in target:
+             result.append(patch.get(c, c))
+         return bytes(result)
+
+     def size(self, patch: ReplacementPatch) -> int:
+         return 0
+
+
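+ # Tries mapping every occurrence of a byte value (or of a pair of distinct
+ # values) to a single replacement drawn from 0, 1, half the value, the value
+ # minus one, and common whitespace bytes; the replacement must be lower than
+ # the value (or, for pairs, lower than at least one of the two).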
+ async def lower_bytes(problem: ReductionProblem[bytes]) -> None:
+     sources = sorted(set(problem.current_test_case))
+
+     patches = [
+         {c: r}
+         for c in sources
+         for r in sorted({0, 1, c // 2, c - 1} | set(b" \t\r\n"))
+         if r < c and r >= 0
+     ] + [
+         {c: r, d: r}
+         for c in sources
+         for d in sources
+         if c != d
+         for r in sorted({0, 1, c // 2, c - 1, d // 2, d - 1} | set(b" \t\r\n"))
+         if (r < c or r < d) and r >= 0
+     ]
+
+     await apply_patches(problem, ByteReplacement(), patches)
+
+
+ class IndividualByteReplacement(Patches[ReplacementPatch, bytes]):
+     @property
+     def empty(self) -> ReplacementPatch:
+         return {}
+
+     def combine(self, *patches: ReplacementPatch) -> ReplacementPatch:
+         result = {}
+         for p in patches:
+             for k, v in p.items():
+                 if k not in result:
+                     result[k] = v
+                 else:
+                     result[k] = min(result[k], v)
+         return result
+
+     def apply(self, patch: ReplacementPatch, target: bytes) -> bytes:
+         result = bytearray()
+         for i, c in enumerate(target):
+             result.append(patch.get(i, c))
+         return bytes(result)
+
+     def size(self, patch: ReplacementPatch) -> int:
+         return 0
+
+
+ async def lower_individual_bytes(problem: ReductionProblem[bytes]) -> None:
+     initial = problem.current_test_case
+     patches = [
+         {i: r}
+         for i, c in enumerate(initial)
+         for r in sorted({0, 1, c // 2, c - 1} | set(b" \t\r\n"))
+         if r < c and r >= 0
+     ] + [
+         {i - 1: initial[i - 1] - 1, i: 255}
+         for i, c in enumerate(initial)
+         if i > 0 and initial[i - 1] > 0 and c == 0
+     ]
+     await apply_patches(problem, IndividualByteReplacement(), patches)
+
+
+ RegionReplacementPatch = list[tuple[int, int, int]]
+
+
+ class RegionReplacement(Patches[RegionReplacementPatch, bytes]):
+     @property
+     def empty(self) -> RegionReplacementPatch:
+         return []
+
+     def combine(self, *patches: RegionReplacementPatch) -> RegionReplacementPatch:
+         result = []
+         for p in patches:
+             result.extend(p)
+         return result
+
+     def apply(self, patch: RegionReplacementPatch, target: bytes) -> bytes:
+         result = bytearray(target)
+         for i, j, d in patch:
+             if d < result[i]:
+                 for k in range(i, j):
+                     result[k] = d
+         return bytes(result)
+
+     def size(self, patch: RegionReplacementPatch) -> int:
+         return 0
+
+
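+ # Tries overwriting regions of up to four bytes with a single filler byte
+ # (0, 1, the characters "0", "1", whitespace, or "."), only at positions where
+ # the filler is lower than the region's current first byte.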
+ async def short_replacements(problem: ReductionProblem[bytes]) -> None:
+     target = problem.current_test_case
+     patches = [
+         [(i, j, c)]
+         for c in [0, 1] + list(b"01 \t\n\r.")
+         for i in range(len(target))
+         if target[i] > c
+         for j in range(i + 1, min(i + 5, len(target) + 1))
+     ]
+
+     await apply_patches(problem, RegionReplacement(), patches)
+
+
+ WHITESPACE = b" \t\r\n"
+
+
+ async def sort_whitespace(problem: ReductionProblem[bytes]) -> None:
+     """NB: This is a stupid pass that we only really need for artificial
+     test cases, but it's helpful for allowing those artificial test cases
+     to expose other issues."""
+
+     whitespace_up_to = 0
+     while (
+         whitespace_up_to < len(problem.current_test_case)
+         and problem.current_test_case[whitespace_up_to] not in WHITESPACE
+     ):
+         whitespace_up_to += 1
+     while (
+         whitespace_up_to < len(problem.current_test_case)
+         and problem.current_test_case[whitespace_up_to] in WHITESPACE
+     ):
+         whitespace_up_to += 1
+
+     # If the initial whitespace ends with a newline we want to keep it doing
+     # that. This is mostly for Python purposes.
+     if (
+         whitespace_up_to > 0
+         and problem.current_test_case[whitespace_up_to - 1] == b"\n"[0]
+     ):
+         whitespace_up_to -= 1
+
+     i = whitespace_up_to + 1
+
+     while i < len(problem.current_test_case):
+         if problem.current_test_case[i] not in WHITESPACE:
+             i += 1
+             continue
+
+         async def can_move_to_whitespace(k):
+             if i + k > len(problem.current_test_case):
+                 return False
+
+             base = problem.current_test_case
+             target = base[i : i + k]
+
+             if any(c not in WHITESPACE for c in target):
+                 return False
+
+             prefix = base[:whitespace_up_to]
+             attempt = prefix + target + base[whitespace_up_to:i] + base[i + k :]
+             return await problem.is_interesting(attempt)
+
+         k = await problem.work.find_large_integer(can_move_to_whitespace)
+         whitespace_up_to += k
+         i += k + 1
+     test_case = problem.current_test_case
+     await problem.is_interesting(
+         bytes(sorted(test_case[:whitespace_up_to])) + test_case[whitespace_up_to:]
+     )
+
+
+ # These are some cheat substitutions that are sometimes helpful, but mostly
+ # for passing stupid tests.
+ STANDARD_SUBSTITUTIONS = [(b"\0\0", b"\1"), (b"\0\0", b"\xff")]
+
+
+ async def standard_substitutions(problem: ReductionProblem[bytes]):
+     i = 0
+     while i < len(problem.current_test_case):
+         for k, v in STANDARD_SUBSTITUTIONS:
+             x = problem.current_test_case
+             if i + len(k) <= len(x) and x[i : i + len(k)] == k:
+                 attempt = x[:i] + v + x[i + len(k) :]
+                 if await problem.is_interesting(attempt):
+                     assert problem.current_test_case == attempt
+                     break
+         else:
+             i += 1