shrinkray 0.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- shrinkray/__init__.py +1 -0
- shrinkray/__main__.py +1205 -0
- shrinkray/learning.py +221 -0
- shrinkray/passes/__init__.py +0 -0
- shrinkray/passes/bytes.py +547 -0
- shrinkray/passes/clangdelta.py +230 -0
- shrinkray/passes/definitions.py +52 -0
- shrinkray/passes/genericlanguages.py +277 -0
- shrinkray/passes/json.py +91 -0
- shrinkray/passes/patching.py +280 -0
- shrinkray/passes/python.py +176 -0
- shrinkray/passes/sat.py +176 -0
- shrinkray/passes/sequences.py +69 -0
- shrinkray/problem.py +318 -0
- shrinkray/py.typed +0 -0
- shrinkray/reducer.py +430 -0
- shrinkray/work.py +217 -0
- shrinkray-0.0.0.dist-info/LICENSE +21 -0
- shrinkray-0.0.0.dist-info/METADATA +170 -0
- shrinkray-0.0.0.dist-info/RECORD +22 -0
- shrinkray-0.0.0.dist-info/WHEEL +4 -0
- shrinkray-0.0.0.dist-info/entry_points.txt +3 -0
shrinkray/reducer.py
ADDED
|
@@ -0,0 +1,430 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from collections.abc import Generator
|
|
3
|
+
from contextlib import contextmanager
|
|
4
|
+
from typing import Any, Generic, Iterable, Optional, TypeVar
|
|
5
|
+
|
|
6
|
+
import attrs
|
|
7
|
+
import trio
|
|
8
|
+
from attrs import define
|
|
9
|
+
|
|
10
|
+
from shrinkray.passes.bytes import (
|
|
11
|
+
Split,
|
|
12
|
+
Tokenize,
|
|
13
|
+
debracket,
|
|
14
|
+
delete_byte_spans,
|
|
15
|
+
hollow,
|
|
16
|
+
lexeme_based_deletions,
|
|
17
|
+
lift_braces,
|
|
18
|
+
lower_bytes,
|
|
19
|
+
lower_individual_bytes,
|
|
20
|
+
remove_indents,
|
|
21
|
+
remove_whitespace,
|
|
22
|
+
replace_space_with_newlines,
|
|
23
|
+
short_deletions,
|
|
24
|
+
standard_substitutions,
|
|
25
|
+
)
|
|
26
|
+
from shrinkray.passes.clangdelta import C_FILE_EXTENSIONS, ClangDelta, clang_delta_pumps
|
|
27
|
+
from shrinkray.passes.definitions import Format, ReductionPass, ReductionPump, compose
|
|
28
|
+
from shrinkray.passes.genericlanguages import (
|
|
29
|
+
combine_expressions,
|
|
30
|
+
cut_comment_like_things,
|
|
31
|
+
merge_adjacent_strings,
|
|
32
|
+
normalize_identifiers,
|
|
33
|
+
reduce_integer_literals,
|
|
34
|
+
replace_falsey_with_zero,
|
|
35
|
+
simplify_brackets,
|
|
36
|
+
)
|
|
37
|
+
from shrinkray.passes.json import JSON, JSON_PASSES
|
|
38
|
+
from shrinkray.passes.patching import PatchApplier, Patches
|
|
39
|
+
from shrinkray.passes.python import PYTHON_PASSES, is_python
|
|
40
|
+
from shrinkray.passes.sat import SAT_PASSES, DimacsCNF
|
|
41
|
+
from shrinkray.passes.sequences import block_deletion, delete_duplicates
|
|
42
|
+
from shrinkray.problem import ReductionProblem, shortlex
|
|
43
|
+
|
|
44
|
+
# Generic type variables used by the reducer classes below.
S = TypeVar("S")
T = TypeVar("T")
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
@define
class Reducer(Generic[T], ABC):
    """Abstract base class for reducers.

    A reducer drives a ``ReductionProblem``, repeatedly replacing its
    current test case with smaller interesting variants until it runs
    out of ideas.
    """

    # The problem being reduced. Temporarily swapped out by ``backtrack``.
    target: ReductionProblem[T]

    @contextmanager
    def backtrack(self, restart: T) -> Generator[None, None, None]:
        """Temporarily replace ``self.target`` with a version of the
        problem restarted from ``restart``.

        The original target is restored on exit, even if the body raises.
        """
        current = self.target
        try:
            self.target = self.target.backtrack(restart)
            yield
        finally:
            self.target = current

    @abstractmethod
    async def run(self) -> None: ...

    @property
    def status(self) -> str:
        # Human-readable description of what the reducer is currently
        # doing. Subclasses override this; the default is empty.
        return ""
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
@define
class BasicReducer(Reducer[T]):
    """Reducer that runs a fixed list of reduction passes (and optional
    pumps) in a loop until a full iteration makes no progress."""

    # Passes run on every iteration, in order.
    reduction_passes: Iterable[ReductionPass[T]]
    # Pumps run after the passes on each iteration. A pump may return a
    # *different* (possibly larger) test case to explore via backtracking.
    pumps: Iterable[ReductionPump[T]] = ()
    # Progress string; shadows the base class's read-only property.
    status: str = "Starting up"

    def __attrs_post_init__(self) -> None:
        # Materialise the passes so they can be iterated repeatedly.
        self.reduction_passes = list(self.reduction_passes)

    async def run_pass(self, rp: ReductionPass[T]) -> None:
        await rp(self.target)

    async def run(self) -> None:
        await self.target.setup()

        while True:
            prev = self.target.current_test_case
            for rp in self.reduction_passes:
                self.status = f"Running reduction pass {rp.__name__}"
                await self.run_pass(rp)
            for pump in self.pumps:
                self.status = f"Pumping with {pump.__name__}"
                pumped = await pump(self.target)
                if pumped != self.target.current_test_case:
                    # Re-run every pass against the pumped variant, with
                    # the original problem view restored afterwards.
                    with self.backtrack(pumped):
                        for rp in self.reduction_passes:
                            self.status = f"Running reduction pass {rp.__name__} under pump {pump.__name__}"
                            await self.run_pass(rp)
            # Fixpoint: stop once a whole iteration changed nothing.
            if prev == self.target.current_test_case:
                return
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
class RestartPass(Exception):
    """Exception type named for restarting a pass.

    NOTE(review): not raised or caught anywhere in this file's visible
    code; confirm intended usage at call sites before removing.
    """
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
@define
class ShrinkRay(Reducer[bytes]):
    """The main single-file reducer for ``bytes`` test cases.

    Reduction passes are grouped into tiers by expected value for money:
    ``initial_cuts`` (big early deletions, see ``initial_cut``),
    ``great_passes``, ``ok_passes`` and ``last_ditch_passes``. Optional
    clang_delta transformations act as "pumps" which may produce a
    different (even larger) test case that unlocks further reduction.
    """

    # Wrapper around the clang_delta binary, if available. Exposed via
    # the ``pumps`` property.
    clang_delta: Optional[ClangDelta] = None

    # Bookkeeping used only for the ``status`` display.
    current_reduction_pass: Optional[ReductionPass[bytes]] = None
    current_pump: Optional[ReductionPump[bytes]] = None

    # Once the great passes stop making progress, the ok passes are
    # permanently unlocked (see ``run_some_passes``).
    unlocked_ok_passes: bool = False

    # Passes used for the fast initial size reduction.
    initial_cuts: list[ReductionPass[bytes]] = attrs.Factory(
        lambda: [
            cut_comment_like_things,
            hollow,
            compose(Split(b"\n"), delete_duplicates),
            compose(Split(b"\n"), block_deletion(10, 100)),
            lift_braces,
            remove_indents,
            remove_whitespace,
        ]
    )

    # Cheap, high-yield passes: run until they stop making progress.
    great_passes: list[ReductionPass[bytes]] = attrs.Factory(
        lambda: [
            compose(Split(b"\n"), delete_duplicates),
            compose(Split(b"\n"), block_deletion(1, 10)),
            compose(Split(b";"), block_deletion(1, 10)),
            remove_indents,
            hollow,
            lift_braces,
            delete_byte_spans,
            debracket,
        ]
    )

    # Passes worth running once the great passes have stalled.
    ok_passes: list[ReductionPass[bytes]] = attrs.Factory(
        lambda: [
            compose(Split(b"\n"), block_deletion(11, 20)),
            remove_indents,
            remove_whitespace,
            compose(Tokenize(), block_deletion(1, 20)),
            reduce_integer_literals,
            replace_falsey_with_zero,
            combine_expressions,
            merge_adjacent_strings,
            lexeme_based_deletions,
            short_deletions,
            normalize_identifiers,
        ]
    )

    # Expensive or rarely-useful passes, tried only when all else stalls.
    last_ditch_passes: list[ReductionPass[bytes]] = attrs.Factory(
        lambda: [
            compose(Split(b"\n"), block_deletion(21, 100)),
            replace_space_with_newlines,
            delete_byte_spans,
            lower_bytes,
            lower_individual_bytes,
            simplify_brackets,
            standard_substitutions,
            # This is in last ditch because it's probably not useful
            # to run it more than once.
            cut_comment_like_things,
        ]
    )

    def __attrs_post_init__(self) -> None:
        # Prepend format-specific passes when the initial test case is
        # recognised as Python, JSON, or DIMACS CNF.
        if is_python(self.target.current_test_case):
            self.great_passes[:0] = PYTHON_PASSES
            self.initial_cuts[:0] = PYTHON_PASSES
        self.register_format_specific_pass(JSON, JSON_PASSES)
        self.register_format_specific_pass(
            DimacsCNF,
            SAT_PASSES,
        )

    def register_format_specific_pass(
        self, format: Format[bytes, T], passes: Iterable[ReductionPass[T]]
    ) -> None:
        """If the current test case parses as ``format``, lift ``passes``
        through it and prepend them to the great and initial pass lists."""
        if format.is_valid(self.target.current_test_case):
            composed = [compose(format, p) for p in passes]
            self.great_passes[:0] = composed
            self.initial_cuts[:0] = composed

    @property
    def pumps(self) -> Iterable[ReductionPump[bytes]]:
        # Pumps are only available when a clang_delta binary was supplied.
        if self.clang_delta is None:
            return ()
        else:
            return clang_delta_pumps(self.clang_delta)

    @property
    def status(self) -> str:
        # Describes the current pass/pump combination for progress display.
        if self.current_pump is None:
            if self.current_reduction_pass is not None:
                return f"Running reduction pass {self.current_reduction_pass.__name__}"
            else:
                return "Selecting reduction pass"
        else:
            if self.current_reduction_pass is not None:
                return f"Running reduction pass {self.current_reduction_pass.__name__} under pump {self.current_pump.__name__}"
            else:
                return f"Running reduction pump {self.current_pump.__name__}"

    async def run_pass(self, rp: ReductionPass[bytes]) -> None:
        # Runs a single pass, tracking it for status display. Passes are
        # not re-entrant here: only one may be current at a time.
        try:
            assert self.current_reduction_pass is None
            self.current_reduction_pass = rp
            await rp(self.target)
        finally:
            self.current_reduction_pass = None

    async def pump(self, rp: ReductionPump[bytes]) -> None:
        """Run a single reduction pump and, if it changed the test case,
        backtrack to the pumped variant and try to reduce it back below
        the original."""
        try:
            assert self.current_pump is None
            self.current_pump = rp
            pumped = await rp(self.target)
            current = self.target.current_test_case
            if pumped == current:
                return
            with self.backtrack(pumped):
                # Escalate through the pass tiers, stopping as soon as the
                # pumped variant sorts strictly below the original.
                for f in [
                    self.run_great_passes,
                    self.run_ok_passes,
                    self.run_last_ditch_passes,
                ]:
                    await f()
                    if self.target.sort_key(
                        self.target.current_test_case
                    ) < self.target.sort_key(current):
                        break

        finally:
            self.current_pump = None

    async def run_great_passes(self) -> None:
        for rp in self.great_passes:
            await self.run_pass(rp)

    async def run_ok_passes(self) -> None:
        for rp in self.ok_passes:
            await self.run_pass(rp)

    async def run_last_ditch_passes(self) -> None:
        for rp in self.last_ditch_passes:
            await self.run_pass(rp)

    async def run_some_passes(self) -> None:
        # Run the great passes; only escalate to the ok (and then last
        # ditch) tier once the previous tiers made no progress this round.
        # The ok tier, once unlocked, stays unlocked for future rounds.
        prev = self.target.current_test_case
        await self.run_great_passes()
        if prev != self.target.current_test_case and not self.unlocked_ok_passes:
            return
        self.unlocked_ok_passes = True
        await self.run_ok_passes()
        if prev != self.target.current_test_case:
            return
        await self.run_last_ditch_passes()

    async def initial_cut(self) -> None:
        """Repeatedly run the ``initial_cuts`` passes, with a watchdog
        that abandons any pass whose rate of progress has decayed, until
        a whole round shrinks the test case by less than 1%."""
        while True:
            prev = self.target.current_size
            for rp in self.initial_cuts:
                async with trio.open_nursery() as nursery:

                    @nursery.start_soon
                    async def _() -> None:
                        """
                        Watcher task that cancels the current reduction pass as
                        soon as it stops looking like a good idea to keep running
                        it. Current criteria:

                        1. If it's been more than 5s since the last successful reduction.
                        2. If the reduction rate of the task has dropped under 50% of its
                           best so far.
                        """
                        iters = 0
                        initial_size = self.target.current_size
                        best_reduction_rate: float | None = None

                        while True:
                            iters += 1
                            deleted = initial_size - self.target.current_size

                            current = self.target.current_test_case
                            await trio.sleep(5)
                            # Average bytes deleted per 5s tick so far.
                            rate = deleted / iters

                            if (
                                best_reduction_rate is None
                                or rate > best_reduction_rate
                            ):
                                best_reduction_rate = rate

                            assert best_reduction_rate is not None

                            if (
                                rate < 0.5 * best_reduction_rate
                                or current == self.target.current_test_case
                            ):
                                nursery.cancel_scope.cancel()
                                break

                    await self.run_pass(rp)
                    # Pass finished on its own; stop the watcher.
                    nursery.cancel_scope.cancel()
            # Stop once a full round achieved less than a 1% reduction.
            if self.target.current_size >= 0.99 * prev:
                return

    async def run(self) -> None:
        """Main entry point: reduce until no pass or pump makes progress."""
        await self.target.setup()

        # Fast path: the empty test case.
        if await self.target.is_interesting(b""):
            return

        prev = 0  # NOTE(review): this assignment is unused (rebound below).
        # Fast path: single-byte test cases. Probe a few representative
        # bytes; on a hit, also try every smaller byte value once.
        for c in [0, 1, ord(b"\n"), ord(b"0"), ord(b"z"), 255]:
            if await self.target.is_interesting(bytes([c])):
                for i in range(c):
                    if await self.target.is_interesting(bytes([i])):
                        break
                return

        await self.initial_cut()

        while True:
            prev = self.target.current_test_case
            await self.run_some_passes()
            if self.target.current_test_case != prev:
                continue
            # The passes are stuck; try the pumps. Stop entirely when
            # neither passes nor pumps changed anything.
            for pump in self.pumps:
                await self.pump(pump)
            if self.target.current_test_case == prev:
                break
|
|
336
|
+
|
|
337
|
+
|
|
338
|
+
class UpdateKeys(Patches[dict[str, bytes], dict[str, bytes]]):
    """Patch type for dict-of-files targets.

    A patch is a partial mapping of keys to replacement values; applying
    it has ``dict.update`` semantics on the target.
    """

    @property
    def empty(self) -> dict[str, bytes]:
        """The identity patch: changes nothing."""
        return {}

    def combine(self, *patches: dict[str, bytes]) -> dict[str, bytes]:
        """Merge patches left to right; later patches win on key clashes."""
        result: dict[str, bytes] = {}
        for p in patches:
            # dict.update replaces the hand-rolled loop over p.items().
            result.update(p)
        return result

    def apply(
        self, patch: dict[str, bytes], target: dict[str, bytes]
    ) -> dict[str, bytes]:
        """Return a copy of ``target`` with ``patch``'s entries applied.

        ``target`` itself is not mutated.
        """
        result = target.copy()
        result.update(patch)
        return result

    def size(self, patch: dict[str, bytes]) -> int:
        """Size of a patch is the number of keys it touches."""
        return len(patch)
|
|
359
|
+
|
|
360
|
+
|
|
361
|
+
class KeyProblem(ReductionProblem[bytes]):
    """View of a single key's value in a dict-valued reduction problem,
    presented as a ``bytes`` reduction problem.

    Interestingness checks are delegated to a shared ``PatchApplier`` so
    that concurrent per-key reductions compose as patches on the
    underlying dict problem.
    """

    def __init__(
        self,
        base_problem: ReductionProblem[dict[str, bytes]],
        applier: PatchApplier[dict[str, bytes], dict[str, bytes]],
        key: str,
    ):
        super().__init__(work=base_problem.work)
        self.base_problem = base_problem
        self.applier = applier
        self.key = key

    @property
    def current_test_case(self) -> bytes:
        # The value currently stored under this problem's key.
        return self.base_problem.current_test_case[self.key]

    async def is_interesting(self, test_case: bytes) -> bool:
        # A candidate value is interesting iff the single-key patch
        # replacing this key's value applies successfully.
        return await self.applier.try_apply_patch({self.key: test_case})

    def size(self, test_case: bytes) -> int:
        return len(test_case)

    def sort_key(self, test_case: bytes) -> Any:
        # Shortlex order: shorter first, ties broken lexicographically.
        return shortlex(test_case)

    def display(self, value: bytes) -> str:
        return repr(value)
|
|
388
|
+
|
|
389
|
+
|
|
390
|
+
@define
class DirectoryShrinkRay(Reducer[dict[str, bytes]]):
    """Reducer for multi-file (directory-like) test cases, represented
    as a dict mapping file names to file contents."""

    # Optional clang_delta support, passed through to the per-file
    # ShrinkRay instances for files with C/C++ extensions.
    clang_delta: Optional[ClangDelta] = None

    async def run(self) -> None:
        # Alternate between deleting whole keys and shrinking remaining
        # values until a full cycle makes no progress.
        prev = None
        while prev != self.target.current_test_case:
            prev = self.target.current_test_case
            await self.delete_keys()
            await self.shrink_values()

    async def delete_keys(self) -> None:
        # Try removing each key, largest first (by shortlex of the value,
        # then of the key). The result of is_interesting is deliberately
        # ignored: a failed deletion simply leaves the problem unchanged.
        target = self.target.current_test_case
        keys = list(target.keys())
        keys.sort(key=lambda k: (shortlex(target[k]), shortlex(k)), reverse=True)
        for k in keys:
            attempt = self.target.current_test_case.copy()
            del attempt[k]
            await self.target.is_interesting(attempt)

    async def shrink_values(self) -> None:
        # Shrink every value concurrently: each key gets its own ShrinkRay
        # over a KeyProblem view, and a shared PatchApplier merges their
        # per-key patches onto the underlying dict problem.
        async with trio.open_nursery() as nursery:
            applier = PatchApplier(patches=UpdateKeys(), problem=self.target)
            for k in self.target.current_test_case.keys():
                key_problem = KeyProblem(
                    base_problem=self.target,
                    applier=applier,
                    key=k,
                )
                # Only enable clang_delta for files that look like C/C++.
                if self.clang_delta is not None and any(
                    k.endswith(s) for s in C_FILE_EXTENSIONS
                ):
                    clang_delta = self.clang_delta
                else:
                    clang_delta = None

                key_shrinkray = ShrinkRay(
                    clang_delta=clang_delta,
                    target=key_problem,
                )
                nursery.start_soon(key_shrinkray.run)
|
shrinkray/work.py
ADDED
|
@@ -0,0 +1,217 @@
|
|
|
1
|
+
import heapq
|
|
2
|
+
from contextlib import asynccontextmanager
|
|
3
|
+
from enum import IntEnum
|
|
4
|
+
from itertools import islice
|
|
5
|
+
from random import Random
|
|
6
|
+
from typing import Awaitable, Callable, Optional, Sequence, TypeVar
|
|
7
|
+
|
|
8
|
+
import trio
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class Volume(IntEnum):
    """Logging verbosity levels, ordered from least to most output."""

    quiet = 0
    normal = 1
    verbose = 2
    debug = 3
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
# Generic type variables for the map/filter helpers below.
S = TypeVar("S")
T = TypeVar("T")


# NOTE(review): TICK_FREQUENCY is defined here but not referenced in this
# file's visible code — presumably consumed elsewhere; confirm before removing.
TICK_FREQUENCY = 0.05
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class WorkContext:
    """A grab bag of useful tools for 'doing work'. Manages randomness,
    logging, concurrency."""

    def __init__(
        self,
        random: Optional[Random] = None,
        parallelism: int = 1,
        volume: Volume = Volume.normal,
    ):
        self.random = random or Random(0)
        self.parallelism = parallelism
        self.volume = volume
        # Timestamp of the last tick; starts at -inf so the first tick fires.
        self.last_ticked = float("-inf")

    @asynccontextmanager
    async def map(self, ls: Sequence[T], f: Callable[[T], Awaitable[S]]):
        """Lazy parallel map.

        Does a reasonable amount of fine tuning so that it doesn't race
        ahead of the current point of iteration and will generally have
        prefetched at most as many values as you've already read. This
        is especially important for its use in implementing `find_first`,
        which we want to avoid doing redundant work when there are lots of
        reduction opportunities.
        """

        async with trio.open_nursery() as nursery:
            send, receive = trio.open_memory_channel(self.parallelism + 1)

            @nursery.start_soon
            async def do_map():
                if self.parallelism > 1:
                    it = iter(ls)

                    # Process the first element serially before fanning out,
                    # so a hit on the very first element costs no extra work.
                    for x in it:
                        await send.send(await f(x))
                        break
                    else:
                        # ls was empty. Close the channel so consumers see
                        # end-of-stream instead of blocking forever.
                        # (Previously this returned without closing, which
                        # deadlocked any consumer iterating the channel.)
                        send.close()
                        return

                    # Fan out over exponentially growing batches, so the
                    # amount prefetched tracks how much has been consumed.
                    n = 2
                    while True:
                        values = list(islice(it, n))
                        if not values:
                            send.close()
                            return

                        async with parallel_map(
                            values, f, parallelism=min(self.parallelism, n)
                        ) as result:
                            async for v in result:
                                await send.send(v)

                        n *= 2
                else:
                    # Serial fallback when no parallelism is configured.
                    for x in ls:
                        await send.send(await f(x))
                    send.close()

            yield receive

    @asynccontextmanager
    async def filter(self, ls: Sequence[T], f: Callable[[T], Awaitable[bool]]):
        """Lazy parallel filter: yields a receive channel producing the
        elements of ``ls`` for which ``f`` returns True, in order."""

        async def apply(x: T) -> tuple[T, bool]:
            return (x, await f(x))

        async with trio.open_nursery() as nursery:
            send, receive = trio.open_memory_channel(float("inf"))

            @nursery.start_soon
            async def _():
                async with self.map(ls, apply) as results:
                    async for x, v in results:
                        if v:
                            await send.send(x)

            yield receive
            # Stop any outstanding mapping work once the consumer is done.
            nursery.cancel_scope.cancel()

    async def find_first_value(
        self, ls: Sequence[T], f: Callable[[T], Awaitable[bool]]
    ) -> T:
        """Returns the first element of `ls` that satisfies `f`, or
        raises `NotFound` if no such element exists.

        Will run in parallel if parallelism is enabled.
        """
        async with self.filter(ls, f) as filtered:
            async for x in filtered:
                return x
        raise NotFound()

    async def find_large_integer(self, f: Callable[[int], Awaitable[bool]]) -> int:
        """Finds a (hopefully large) integer n such that f(n) is True and f(n + 1)
        is False. Runs in O(log(n)).

        f(0) is assumed to be True and will not be checked. May not terminate unless
        f(n) is False for all sufficiently large n.
        """
        # We first do a linear scan over the small numbers and only start to do
        # anything intelligent if f(4) is true. This is because it's very hard to
        # win big when the result is small. If the result is 0 and we try 2 first
        # then we've done twice as much work as we needed to!
        for i in range(1, 5):
            if not await f(i):
                return i - 1

        # We now know that f(4) is true. We want to find some number for which
        # f(n) is *not* true.
        # lo is the largest number for which we know that f(lo) is true.
        lo = 4

        # Exponential probe upwards until we find some value hi such that f(hi)
        # is not true. Subsequently we maintain the invariant that hi is the
        # smallest number for which we know that f(hi) is not true.
        hi = 5
        while await f(hi):
            lo = hi
            hi *= 2

        # Now binary search until lo + 1 = hi. At that point we have f(lo) and not
        # f(lo + 1), as desired.
        while lo + 1 < hi:
            mid = (lo + hi) // 2
            if await f(mid):
                lo = mid
            else:
                hi = mid
        return lo

    def warn(self, msg: str) -> None:
        # NOTE(review): reports at the same level as ``note`` — confirm
        # whether warnings were meant to be louder.
        self.report(msg, Volume.normal)

    def note(self, msg: str) -> None:
        self.report(msg, Volume.normal)

    def debug(self, msg: str) -> None:
        self.report(msg, Volume.debug)

    def report(self, msg: str, level: Volume) -> None:
        # Intentionally a no-op here; subclasses or callers hook output.
        return
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
class NotFound(Exception):
    """Raised by ``WorkContext.find_first_value`` when no element of the
    sequence satisfies the predicate."""
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
@asynccontextmanager
async def parallel_map(
    ls: Sequence[T],
    f: Callable[[T], Awaitable[S]],
    parallelism: int,
):
    """Map ``f`` over ``ls`` with up to ``parallelism`` concurrent worker
    tasks, yielding a receive channel that produces results in input order.

    Usage: ``async with parallel_map(...) as recv: async for v in recv: ...``
    """
    # Buffer up to ``parallelism`` results so workers don't stall
    # immediately on a slow consumer.
    send_out_values, receive_out_values = trio.open_memory_channel(parallelism)

    # Work queue of (index, item); reversed so pop() takes items in order.
    work = list(enumerate(ls))
    work.reverse()

    # Min-heap of (index, result): lets the consolidator emit results in
    # input order even though workers finish out of order.
    result_heap: list = []

    async with trio.open_nursery() as nursery:
        results_ready = trio.Event()

        for _ in range(parallelism):

            @nursery.start_soon
            async def do_work():
                # Each worker pops items until the shared queue is drained.
                while work:
                    i, x = work.pop()
                    result = await f(x)
                    heapq.heappush(result_heap, (i, result))
                    results_ready.set()

        @nursery.start_soon
        async def consolidate() -> None:
            # Forwards results to the output channel in index order.
            # NOTE(review): ``results_ready`` is never cleared after being
            # set, so the waits below return immediately once any result
            # has arrived — the out-of-order branch effectively spins,
            # yielding to the scheduler each iteration. Confirm whether
            # that is intentional.
            i = 0

            while work or result_heap:
                while not result_heap:
                    await results_ready.wait()
                assert result_heap
                j, x = result_heap[0]
                if j == i:
                    # Next expected result is ready: forward it.
                    await send_out_values.send(x)
                    i = j + 1
                    heapq.heappop(result_heap)
                else:
                    # Results exist but not the one we need yet.
                    await results_ready.wait()
            send_out_values.close()

        yield receive_out_values
        # Consumer is done; tear down any remaining workers.
        nursery.cancel_scope.cancel()
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright © 2023 David R. MacIver
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|