shrinkray 0.0.0__py3-none-any.whl → 25.12.26.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- shrinkray/__main__.py +130 -960
- shrinkray/cli.py +70 -0
- shrinkray/display.py +75 -0
- shrinkray/formatting.py +108 -0
- shrinkray/passes/bytes.py +217 -10
- shrinkray/passes/clangdelta.py +47 -17
- shrinkray/passes/definitions.py +84 -4
- shrinkray/passes/genericlanguages.py +61 -7
- shrinkray/passes/json.py +6 -0
- shrinkray/passes/patching.py +65 -57
- shrinkray/passes/python.py +66 -23
- shrinkray/passes/sat.py +505 -91
- shrinkray/passes/sequences.py +26 -6
- shrinkray/problem.py +206 -27
- shrinkray/process.py +49 -0
- shrinkray/reducer.py +187 -25
- shrinkray/state.py +599 -0
- shrinkray/subprocess/__init__.py +24 -0
- shrinkray/subprocess/client.py +253 -0
- shrinkray/subprocess/protocol.py +190 -0
- shrinkray/subprocess/worker.py +491 -0
- shrinkray/tui.py +915 -0
- shrinkray/ui.py +72 -0
- shrinkray/work.py +34 -6
- {shrinkray-0.0.0.dist-info → shrinkray-25.12.26.0.dist-info}/METADATA +44 -27
- shrinkray-25.12.26.0.dist-info/RECORD +33 -0
- {shrinkray-0.0.0.dist-info → shrinkray-25.12.26.0.dist-info}/WHEEL +2 -1
- shrinkray-25.12.26.0.dist-info/entry_points.txt +3 -0
- shrinkray-25.12.26.0.dist-info/top_level.txt +1 -0
- shrinkray/learning.py +0 -221
- shrinkray-0.0.0.dist-info/RECORD +0 -22
- shrinkray-0.0.0.dist-info/entry_points.txt +0 -3
- {shrinkray-0.0.0.dist-info → shrinkray-25.12.26.0.dist-info/licenses}/LICENSE +0 -0
|
@@ -0,0 +1,491 @@
|
|
|
1
|
+
"""Worker subprocess that runs the reducer with trio and communicates via JSON protocol."""
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
import sys
|
|
5
|
+
import time
|
|
6
|
+
import traceback
|
|
7
|
+
from typing import Any, Protocol
|
|
8
|
+
|
|
9
|
+
import trio
|
|
10
|
+
from binaryornot.helpers import is_binary_string
|
|
11
|
+
|
|
12
|
+
from shrinkray.problem import InvalidInitialExample
|
|
13
|
+
from shrinkray.subprocess.protocol import (
|
|
14
|
+
PassStatsData,
|
|
15
|
+
ProgressUpdate,
|
|
16
|
+
Request,
|
|
17
|
+
Response,
|
|
18
|
+
deserialize,
|
|
19
|
+
serialize,
|
|
20
|
+
)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class InputStream(Protocol):
|
|
24
|
+
"""Protocol for async input streams."""
|
|
25
|
+
|
|
26
|
+
def __aiter__(self) -> "InputStream": ...
|
|
27
|
+
async def __anext__(self) -> bytes | bytearray: ...
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class OutputStream(Protocol):
    """Structural type for an asynchronous sink that accepts raw bytes."""

    async def send(self, data: bytes) -> None:
        """Deliver ``data`` to the underlying transport."""
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class StdoutStream:
    """Adapter that lets the process's standard output act as an OutputStream.

    Each ``send`` decodes the payload as UTF-8, writes the text to whatever
    object ``sys.stdout`` refers to at call time, and flushes after every write.
    """

    async def send(self, data: bytes) -> None:
        text = data.decode("utf-8")
        out = sys.stdout
        out.write(text)
        out.flush()
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class ReducerWorker:
|
|
45
|
+
"""Runs the reducer in a subprocess with JSON protocol communication."""
|
|
46
|
+
|
|
47
|
+
def __init__(
|
|
48
|
+
self,
|
|
49
|
+
input_stream: InputStream | None = None,
|
|
50
|
+
output_stream: OutputStream | None = None,
|
|
51
|
+
):
|
|
52
|
+
self.running = False
|
|
53
|
+
self.reducer = None
|
|
54
|
+
self.problem = None
|
|
55
|
+
self.state = None
|
|
56
|
+
self._cancel_scope: trio.CancelScope | None = None
|
|
57
|
+
# Parallelism tracking
|
|
58
|
+
self._parallel_samples = 0
|
|
59
|
+
self._parallel_total = 0
|
|
60
|
+
# I/O streams - None means use stdin/stdout
|
|
61
|
+
self._input_stream = input_stream
|
|
62
|
+
self._output_stream = output_stream
|
|
63
|
+
|
|
64
|
+
async def emit(self, msg: Response | ProgressUpdate) -> None:
|
|
65
|
+
"""Write a message to the output stream."""
|
|
66
|
+
line = serialize(msg) + "\n"
|
|
67
|
+
if self._output_stream is not None:
|
|
68
|
+
await self._output_stream.send(line.encode("utf-8"))
|
|
69
|
+
else:
|
|
70
|
+
sys.stdout.write(line)
|
|
71
|
+
sys.stdout.flush()
|
|
72
|
+
|
|
73
|
+
async def read_commands(
|
|
74
|
+
self,
|
|
75
|
+
input_stream: InputStream | None = None,
|
|
76
|
+
task_status: trio.TaskStatus[None] = trio.TASK_STATUS_IGNORED,
|
|
77
|
+
) -> None:
|
|
78
|
+
"""Read commands from input stream and dispatch them."""
|
|
79
|
+
task_status.started()
|
|
80
|
+
|
|
81
|
+
# Use provided stream, or instance stream, or default to stdin
|
|
82
|
+
stream: InputStream
|
|
83
|
+
if input_stream is not None:
|
|
84
|
+
stream = input_stream
|
|
85
|
+
elif self._input_stream is not None:
|
|
86
|
+
stream = self._input_stream
|
|
87
|
+
else:
|
|
88
|
+
stream = trio.lowlevel.FdStream(os.dup(sys.stdin.fileno()))
|
|
89
|
+
|
|
90
|
+
buffer = b""
|
|
91
|
+
async for chunk in stream:
|
|
92
|
+
buffer += chunk
|
|
93
|
+
while b"\n" in buffer:
|
|
94
|
+
line, buffer = buffer.split(b"\n", 1)
|
|
95
|
+
if line:
|
|
96
|
+
await self.handle_line(line.decode("utf-8"))
|
|
97
|
+
|
|
98
|
+
async def handle_line(self, line: str) -> None:
|
|
99
|
+
"""Handle a single command line."""
|
|
100
|
+
try:
|
|
101
|
+
request = deserialize(line)
|
|
102
|
+
if not isinstance(request, Request):
|
|
103
|
+
await self.emit(Response(id="", error="Expected a request"))
|
|
104
|
+
return
|
|
105
|
+
response = await self.handle_command(request)
|
|
106
|
+
await self.emit(response)
|
|
107
|
+
except Exception as e:
|
|
108
|
+
traceback.print_exc()
|
|
109
|
+
await self.emit(Response(id="", error=str(e)))
|
|
110
|
+
|
|
111
|
+
async def handle_command(self, request: Request) -> Response:
|
|
112
|
+
"""Handle a command request and return a response."""
|
|
113
|
+
match request.command:
|
|
114
|
+
case "start":
|
|
115
|
+
return await self._handle_start(request.id, request.params)
|
|
116
|
+
case "status":
|
|
117
|
+
return self._handle_status(request.id)
|
|
118
|
+
case "cancel":
|
|
119
|
+
return self._handle_cancel(request.id)
|
|
120
|
+
case "disable_pass":
|
|
121
|
+
return self._handle_disable_pass(request.id, request.params)
|
|
122
|
+
case "enable_pass":
|
|
123
|
+
return self._handle_enable_pass(request.id, request.params)
|
|
124
|
+
case "skip_pass":
|
|
125
|
+
return self._handle_skip_pass(request.id)
|
|
126
|
+
case _:
|
|
127
|
+
return Response(id=request.id, error=f"Unknown command: {request.command}")
|
|
128
|
+
|
|
129
|
+
async def _handle_start(self, request_id: str, params: dict) -> Response:
|
|
130
|
+
"""Start the reduction process."""
|
|
131
|
+
if self.running:
|
|
132
|
+
return Response(id=request_id, error="Already running")
|
|
133
|
+
|
|
134
|
+
try:
|
|
135
|
+
await self._start_reduction(params)
|
|
136
|
+
return Response(id=request_id, result={"status": "started"})
|
|
137
|
+
except* InvalidInitialExample as excs:
|
|
138
|
+
assert len(excs.exceptions) == 1
|
|
139
|
+
(e,) = excs.exceptions
|
|
140
|
+
# Build a detailed error message for invalid initial examples
|
|
141
|
+
if self.state is not None:
|
|
142
|
+
error_message = await self.state.build_error_message(e)
|
|
143
|
+
else:
|
|
144
|
+
error_message = str(e)
|
|
145
|
+
except* Exception as e:
|
|
146
|
+
traceback.print_exc()
|
|
147
|
+
error_message = str(e.exceptions[0])
|
|
148
|
+
return Response(id=request_id, error=error_message)
|
|
149
|
+
|
|
150
|
+
async def _start_reduction(self, params: dict) -> None:
|
|
151
|
+
"""Initialize and start the reduction."""
|
|
152
|
+
from shrinkray.cli import InputType
|
|
153
|
+
from shrinkray.passes.clangdelta import (
|
|
154
|
+
C_FILE_EXTENSIONS,
|
|
155
|
+
ClangDelta,
|
|
156
|
+
find_clang_delta,
|
|
157
|
+
)
|
|
158
|
+
from shrinkray.state import (
|
|
159
|
+
ShrinkRayDirectoryState,
|
|
160
|
+
ShrinkRayStateSingleFile,
|
|
161
|
+
)
|
|
162
|
+
from shrinkray.work import Volume
|
|
163
|
+
|
|
164
|
+
filename = params["file_path"]
|
|
165
|
+
test = params["test"]
|
|
166
|
+
parallelism = params.get("parallelism", os.cpu_count() or 1)
|
|
167
|
+
timeout = params.get("timeout", 1.0)
|
|
168
|
+
seed = params.get("seed", 0)
|
|
169
|
+
input_type = InputType[params.get("input_type", "all")]
|
|
170
|
+
in_place = params.get("in_place", False)
|
|
171
|
+
formatter = params.get("formatter", "default")
|
|
172
|
+
volume = Volume[params.get("volume", "normal")]
|
|
173
|
+
no_clang_delta = params.get("no_clang_delta", False)
|
|
174
|
+
clang_delta_path = params.get("clang_delta", "")
|
|
175
|
+
trivial_is_error = params.get("trivial_is_error", True)
|
|
176
|
+
|
|
177
|
+
clang_delta_executable = None
|
|
178
|
+
if os.path.splitext(filename)[1] in C_FILE_EXTENSIONS and not no_clang_delta:
|
|
179
|
+
if not clang_delta_path:
|
|
180
|
+
clang_delta_path = find_clang_delta()
|
|
181
|
+
if clang_delta_path:
|
|
182
|
+
clang_delta_executable = ClangDelta(clang_delta_path)
|
|
183
|
+
|
|
184
|
+
state_kwargs: dict[str, Any] = dict(
|
|
185
|
+
input_type=input_type,
|
|
186
|
+
in_place=in_place,
|
|
187
|
+
test=test,
|
|
188
|
+
timeout=timeout,
|
|
189
|
+
base=os.path.basename(filename),
|
|
190
|
+
parallelism=parallelism,
|
|
191
|
+
filename=filename,
|
|
192
|
+
formatter=formatter,
|
|
193
|
+
trivial_is_error=trivial_is_error,
|
|
194
|
+
seed=seed,
|
|
195
|
+
volume=volume,
|
|
196
|
+
clang_delta_executable=clang_delta_executable,
|
|
197
|
+
)
|
|
198
|
+
|
|
199
|
+
if os.path.isdir(filename):
|
|
200
|
+
files = [os.path.join(d, f) for d, _, fs in os.walk(filename) for f in fs]
|
|
201
|
+
initial = {}
|
|
202
|
+
for f in files:
|
|
203
|
+
with open(f, "rb") as i:
|
|
204
|
+
initial[os.path.relpath(f, filename)] = i.read()
|
|
205
|
+
self.state = ShrinkRayDirectoryState(initial=initial, **state_kwargs)
|
|
206
|
+
else:
|
|
207
|
+
with open(filename, "rb") as reader:
|
|
208
|
+
initial = reader.read()
|
|
209
|
+
self.state = ShrinkRayStateSingleFile(initial=initial, **state_kwargs)
|
|
210
|
+
|
|
211
|
+
self.problem = self.state.problem
|
|
212
|
+
self.reducer = self.state.reducer
|
|
213
|
+
|
|
214
|
+
# Validate initial example before starting - this will raise
|
|
215
|
+
# InvalidInitialExample if the initial test case fails
|
|
216
|
+
await self.problem.setup()
|
|
217
|
+
|
|
218
|
+
self.running = True
|
|
219
|
+
|
|
220
|
+
def _handle_status(self, request_id: str) -> Response:
|
|
221
|
+
"""Get current status."""
|
|
222
|
+
if not self.running or self.problem is None:
|
|
223
|
+
return Response(id=request_id, result={"running": False})
|
|
224
|
+
|
|
225
|
+
stats = self.problem.stats
|
|
226
|
+
return Response(
|
|
227
|
+
id=request_id,
|
|
228
|
+
result={
|
|
229
|
+
"running": True,
|
|
230
|
+
"status": self.reducer.status if self.reducer else "",
|
|
231
|
+
"size": stats.current_test_case_size,
|
|
232
|
+
"original_size": stats.initial_test_case_size,
|
|
233
|
+
"calls": stats.calls,
|
|
234
|
+
"reductions": stats.reductions,
|
|
235
|
+
},
|
|
236
|
+
)
|
|
237
|
+
|
|
238
|
+
def _handle_cancel(self, request_id: str) -> Response:
|
|
239
|
+
"""Cancel the reduction."""
|
|
240
|
+
if self._cancel_scope is not None:
|
|
241
|
+
self._cancel_scope.cancel()
|
|
242
|
+
self.running = False
|
|
243
|
+
return Response(id=request_id, result={"status": "cancelled"})
|
|
244
|
+
|
|
245
|
+
def _get_known_pass_names(self) -> set[str]:
|
|
246
|
+
"""Get the set of known pass names from pass stats."""
|
|
247
|
+
if self.reducer is None or self.reducer.pass_stats is None:
|
|
248
|
+
return set()
|
|
249
|
+
return set(self.reducer.pass_stats._stats.keys())
|
|
250
|
+
|
|
251
|
+
def _handle_disable_pass(self, request_id: str, params: dict) -> Response:
|
|
252
|
+
"""Disable a reduction pass by name."""
|
|
253
|
+
pass_name = params.get("pass_name", "")
|
|
254
|
+
if not pass_name:
|
|
255
|
+
return Response(id=request_id, error="pass_name is required")
|
|
256
|
+
|
|
257
|
+
known_passes = self._get_known_pass_names()
|
|
258
|
+
if known_passes and pass_name not in known_passes:
|
|
259
|
+
return Response(
|
|
260
|
+
id=request_id,
|
|
261
|
+
error=f"Unknown pass '{pass_name}'. Known passes: {sorted(known_passes)}",
|
|
262
|
+
)
|
|
263
|
+
|
|
264
|
+
if self.reducer is not None and hasattr(self.reducer, "disable_pass"):
|
|
265
|
+
self.reducer.disable_pass(pass_name)
|
|
266
|
+
return Response(id=request_id, result={"status": "disabled", "pass_name": pass_name})
|
|
267
|
+
return Response(id=request_id, error="Reducer does not support pass control")
|
|
268
|
+
|
|
269
|
+
def _handle_enable_pass(self, request_id: str, params: dict) -> Response:
|
|
270
|
+
"""Enable a previously disabled reduction pass."""
|
|
271
|
+
pass_name = params.get("pass_name", "")
|
|
272
|
+
if not pass_name:
|
|
273
|
+
return Response(id=request_id, error="pass_name is required")
|
|
274
|
+
|
|
275
|
+
known_passes = self._get_known_pass_names()
|
|
276
|
+
if known_passes and pass_name not in known_passes:
|
|
277
|
+
return Response(
|
|
278
|
+
id=request_id,
|
|
279
|
+
error=f"Unknown pass '{pass_name}'. Known passes: {sorted(known_passes)}",
|
|
280
|
+
)
|
|
281
|
+
|
|
282
|
+
if self.reducer is not None and hasattr(self.reducer, "enable_pass"):
|
|
283
|
+
self.reducer.enable_pass(pass_name)
|
|
284
|
+
return Response(id=request_id, result={"status": "enabled", "pass_name": pass_name})
|
|
285
|
+
return Response(id=request_id, error="Reducer does not support pass control")
|
|
286
|
+
|
|
287
|
+
def _handle_skip_pass(self, request_id: str) -> Response:
|
|
288
|
+
"""Skip the currently running pass."""
|
|
289
|
+
if self.reducer is not None and hasattr(self.reducer, "skip_current_pass"):
|
|
290
|
+
self.reducer.skip_current_pass()
|
|
291
|
+
return Response(id=request_id, result={"status": "skipped"})
|
|
292
|
+
return Response(id=request_id, error="Reducer does not support pass control")
|
|
293
|
+
|
|
294
|
+
def _get_content_preview(self) -> tuple[str, bool]:
|
|
295
|
+
"""Get a preview of the current test case content."""
|
|
296
|
+
if self.problem is None:
|
|
297
|
+
return "", False
|
|
298
|
+
|
|
299
|
+
test_case = self.problem.current_test_case
|
|
300
|
+
|
|
301
|
+
# Handle directory mode
|
|
302
|
+
if isinstance(test_case, dict):
|
|
303
|
+
lines = []
|
|
304
|
+
for name, content in sorted(test_case.items()):
|
|
305
|
+
size = len(content)
|
|
306
|
+
lines.append(f"{name}: {size} bytes")
|
|
307
|
+
return "\n".join(lines[:50]), False
|
|
308
|
+
|
|
309
|
+
# Handle single file mode
|
|
310
|
+
hex_mode = is_binary_string(
|
|
311
|
+
test_case[:1024] if len(test_case) > 1024 else test_case
|
|
312
|
+
)
|
|
313
|
+
|
|
314
|
+
if hex_mode:
|
|
315
|
+
# Show hex dump for binary files
|
|
316
|
+
preview_bytes = test_case[:512]
|
|
317
|
+
lines = []
|
|
318
|
+
for i in range(0, len(preview_bytes), 16):
|
|
319
|
+
chunk = preview_bytes[i : i + 16]
|
|
320
|
+
hex_part = " ".join(f"{b:02x}" for b in chunk)
|
|
321
|
+
ascii_part = "".join(chr(b) if 32 <= b < 127 else "." for b in chunk)
|
|
322
|
+
lines.append(f"{i:08x} {hex_part:<48} {ascii_part}")
|
|
323
|
+
return "\n".join(lines), True
|
|
324
|
+
else:
|
|
325
|
+
# Show text for text files
|
|
326
|
+
# Send up to 100KB of content - the TUI will handle display truncation
|
|
327
|
+
try:
|
|
328
|
+
text = test_case.decode("utf-8", errors="replace")
|
|
329
|
+
# Truncate by character count to handle files with very long lines
|
|
330
|
+
max_chars = 100_000
|
|
331
|
+
if len(text) > max_chars:
|
|
332
|
+
text = text[:max_chars]
|
|
333
|
+
return text, False
|
|
334
|
+
except Exception:
|
|
335
|
+
return "", True
|
|
336
|
+
|
|
337
|
+
async def _build_progress_update(self) -> ProgressUpdate | None:
|
|
338
|
+
"""Build a progress update from current state."""
|
|
339
|
+
if self.problem is None:
|
|
340
|
+
return None
|
|
341
|
+
|
|
342
|
+
stats = self.problem.stats
|
|
343
|
+
content_preview, hex_mode = self._get_content_preview()
|
|
344
|
+
|
|
345
|
+
# Get parallel workers count and track average
|
|
346
|
+
parallel_workers = 0
|
|
347
|
+
if self.state is not None and hasattr(self.state, "parallel_tasks_running"):
|
|
348
|
+
parallel_workers = self.state.parallel_tasks_running
|
|
349
|
+
self._parallel_samples += 1
|
|
350
|
+
self._parallel_total += parallel_workers
|
|
351
|
+
|
|
352
|
+
# Calculate parallelism stats
|
|
353
|
+
average_parallelism = 0.0
|
|
354
|
+
effective_parallelism = 0.0
|
|
355
|
+
if self._parallel_samples > 0:
|
|
356
|
+
average_parallelism = self._parallel_total / self._parallel_samples
|
|
357
|
+
wasteage = (
|
|
358
|
+
stats.wasted_interesting_calls / stats.calls if stats.calls > 0 else 0.0
|
|
359
|
+
)
|
|
360
|
+
effective_parallelism = average_parallelism * (1.0 - wasteage)
|
|
361
|
+
|
|
362
|
+
# Collect pass statistics in run order (only those with test evaluations)
|
|
363
|
+
pass_stats_list = []
|
|
364
|
+
current_pass_name = ""
|
|
365
|
+
if self.reducer is not None:
|
|
366
|
+
# Get the currently running pass name
|
|
367
|
+
current_pass = self.reducer.current_reduction_pass
|
|
368
|
+
if current_pass is not None:
|
|
369
|
+
current_pass_name = getattr(current_pass, "__name__", "")
|
|
370
|
+
|
|
371
|
+
# Get all stats in the order they were first run
|
|
372
|
+
pass_stats = self.reducer.pass_stats
|
|
373
|
+
if pass_stats is not None:
|
|
374
|
+
all_stats = pass_stats.get_stats_in_order()
|
|
375
|
+
|
|
376
|
+
# Only include passes that have made at least one test evaluation
|
|
377
|
+
pass_stats_list = [
|
|
378
|
+
PassStatsData(
|
|
379
|
+
pass_name=ps.pass_name,
|
|
380
|
+
bytes_deleted=ps.bytes_deleted,
|
|
381
|
+
run_count=ps.run_count,
|
|
382
|
+
test_evaluations=ps.test_evaluations,
|
|
383
|
+
successful_reductions=ps.successful_reductions,
|
|
384
|
+
success_rate=ps.success_rate,
|
|
385
|
+
)
|
|
386
|
+
for ps in all_stats
|
|
387
|
+
if ps.test_evaluations > 0
|
|
388
|
+
]
|
|
389
|
+
|
|
390
|
+
# Get disabled passes
|
|
391
|
+
if self.reducer is not None and hasattr(self.reducer, "disabled_passes"):
|
|
392
|
+
disabled_passes = list(self.reducer.disabled_passes)
|
|
393
|
+
else:
|
|
394
|
+
disabled_passes = []
|
|
395
|
+
|
|
396
|
+
return ProgressUpdate(
|
|
397
|
+
status=self.reducer.status if self.reducer else "",
|
|
398
|
+
size=stats.current_test_case_size,
|
|
399
|
+
original_size=stats.initial_test_case_size,
|
|
400
|
+
calls=stats.calls,
|
|
401
|
+
reductions=stats.reductions,
|
|
402
|
+
interesting_calls=stats.interesting_calls,
|
|
403
|
+
wasted_calls=stats.wasted_interesting_calls,
|
|
404
|
+
runtime=time.time() - stats.start_time,
|
|
405
|
+
parallel_workers=parallel_workers,
|
|
406
|
+
average_parallelism=average_parallelism,
|
|
407
|
+
effective_parallelism=effective_parallelism,
|
|
408
|
+
time_since_last_reduction=stats.time_since_last_reduction(),
|
|
409
|
+
content_preview=content_preview,
|
|
410
|
+
hex_mode=hex_mode,
|
|
411
|
+
pass_stats=pass_stats_list,
|
|
412
|
+
current_pass_name=current_pass_name,
|
|
413
|
+
disabled_passes=disabled_passes,
|
|
414
|
+
)
|
|
415
|
+
|
|
416
|
+
async def emit_progress_updates(self) -> None:
|
|
417
|
+
"""Periodically emit progress updates."""
|
|
418
|
+
# Emit initial progress update immediately
|
|
419
|
+
update = await self._build_progress_update()
|
|
420
|
+
if update is not None:
|
|
421
|
+
await self.emit(update)
|
|
422
|
+
|
|
423
|
+
while self.running:
|
|
424
|
+
await trio.sleep(0.1)
|
|
425
|
+
update = await self._build_progress_update()
|
|
426
|
+
if update is not None:
|
|
427
|
+
await self.emit(update)
|
|
428
|
+
|
|
429
|
+
async def run_reducer(self) -> None:
|
|
430
|
+
"""Run the reducer."""
|
|
431
|
+
if self.reducer is None:
|
|
432
|
+
return
|
|
433
|
+
|
|
434
|
+
try:
|
|
435
|
+
with trio.CancelScope() as scope:
|
|
436
|
+
self._cancel_scope = scope
|
|
437
|
+
await self.reducer.run()
|
|
438
|
+
|
|
439
|
+
# Check for trivial result after successful completion
|
|
440
|
+
if self.state is not None and self.problem is not None:
|
|
441
|
+
trivial_error = self.state.check_trivial_result(self.problem)
|
|
442
|
+
if trivial_error:
|
|
443
|
+
await self.emit(Response(id="", error=trivial_error))
|
|
444
|
+
except* InvalidInitialExample as excs:
|
|
445
|
+
assert len(excs.exceptions) == 1
|
|
446
|
+
(e,) = excs.exceptions
|
|
447
|
+
# Build a detailed error message for invalid initial examples
|
|
448
|
+
if self.state is not None:
|
|
449
|
+
error_message = await self.state.build_error_message(e)
|
|
450
|
+
else:
|
|
451
|
+
error_message = str(e)
|
|
452
|
+
await self.emit(Response(id="", error=error_message))
|
|
453
|
+
except* Exception as e:
|
|
454
|
+
# Catch any other exception during reduction and emit as error
|
|
455
|
+
traceback.print_exc()
|
|
456
|
+
await self.emit(Response(id="", error=str(e.exceptions[0])))
|
|
457
|
+
finally:
|
|
458
|
+
self._cancel_scope = None
|
|
459
|
+
self.running = False
|
|
460
|
+
|
|
461
|
+
async def run(self) -> None:
|
|
462
|
+
"""Main entry point for the worker."""
|
|
463
|
+
async with trio.open_nursery() as nursery:
|
|
464
|
+
await nursery.start(self.read_commands)
|
|
465
|
+
|
|
466
|
+
# Wait for start command
|
|
467
|
+
while not self.running:
|
|
468
|
+
await trio.sleep(0.01)
|
|
469
|
+
|
|
470
|
+
# Start progress updates and reducer
|
|
471
|
+
nursery.start_soon(self.emit_progress_updates)
|
|
472
|
+
await self.run_reducer()
|
|
473
|
+
|
|
474
|
+
# Emit final progress update before completion
|
|
475
|
+
final_update = await self._build_progress_update()
|
|
476
|
+
if final_update is not None:
|
|
477
|
+
await self.emit(final_update)
|
|
478
|
+
|
|
479
|
+
# Signal completion
|
|
480
|
+
await self.emit(Response(id="", result={"status": "completed"}))
|
|
481
|
+
nursery.cancel_scope.cancel()
|
|
482
|
+
|
|
483
|
+
|
|
484
|
+
def main() -> None:
    """Run a ReducerWorker on stdin/stdout under trio until it finishes.

    This is the entry point for the worker subprocess.
    """
    trio.run(ReducerWorker().run)


if __name__ == "__main__":
    main()
|