shrinkray 0.0.0__py3-none-any.whl → 25.12.26__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,491 @@
1
+ """Worker subprocess that runs the reducer with trio and communicates via JSON protocol."""
2
+
3
+ import os
4
+ import sys
5
+ import time
6
+ import traceback
7
+ from typing import Any, Protocol
8
+
9
+ import trio
10
+ from binaryornot.helpers import is_binary_string
11
+
12
+ from shrinkray.problem import InvalidInitialExample
13
+ from shrinkray.subprocess.protocol import (
14
+ PassStatsData,
15
+ ProgressUpdate,
16
+ Request,
17
+ Response,
18
+ deserialize,
19
+ serialize,
20
+ )
21
+
22
+
23
+ class InputStream(Protocol):
24
+ """Protocol for async input streams."""
25
+
26
+ def __aiter__(self) -> "InputStream": ...
27
+ async def __anext__(self) -> bytes | bytearray: ...
28
+
29
+
30
class OutputStream(Protocol):
    """Structural type for an asynchronous sink of encoded message bytes."""

    async def send(self, data: bytes) -> None:
        """Deliver ``data`` to the underlying destination."""
34
+
35
+
36
class StdoutStream:
    """``OutputStream`` implementation backed by ``sys.stdout``.

    Incoming bytes are decoded as UTF-8 and written as text. The stream is
    flushed after every write so messages are not held in the buffer.
    """

    async def send(self, data: bytes) -> None:
        out = sys.stdout
        out.write(data.decode("utf-8"))
        out.flush()
42
+
43
+
44
+ class ReducerWorker:
45
+ """Runs the reducer in a subprocess with JSON protocol communication."""
46
+
47
+ def __init__(
48
+ self,
49
+ input_stream: InputStream | None = None,
50
+ output_stream: OutputStream | None = None,
51
+ ):
52
+ self.running = False
53
+ self.reducer = None
54
+ self.problem = None
55
+ self.state = None
56
+ self._cancel_scope: trio.CancelScope | None = None
57
+ # Parallelism tracking
58
+ self._parallel_samples = 0
59
+ self._parallel_total = 0
60
+ # I/O streams - None means use stdin/stdout
61
+ self._input_stream = input_stream
62
+ self._output_stream = output_stream
63
+
64
+ async def emit(self, msg: Response | ProgressUpdate) -> None:
65
+ """Write a message to the output stream."""
66
+ line = serialize(msg) + "\n"
67
+ if self._output_stream is not None:
68
+ await self._output_stream.send(line.encode("utf-8"))
69
+ else:
70
+ sys.stdout.write(line)
71
+ sys.stdout.flush()
72
+
73
+ async def read_commands(
74
+ self,
75
+ input_stream: InputStream | None = None,
76
+ task_status: trio.TaskStatus[None] = trio.TASK_STATUS_IGNORED,
77
+ ) -> None:
78
+ """Read commands from input stream and dispatch them."""
79
+ task_status.started()
80
+
81
+ # Use provided stream, or instance stream, or default to stdin
82
+ stream: InputStream
83
+ if input_stream is not None:
84
+ stream = input_stream
85
+ elif self._input_stream is not None:
86
+ stream = self._input_stream
87
+ else:
88
+ stream = trio.lowlevel.FdStream(os.dup(sys.stdin.fileno()))
89
+
90
+ buffer = b""
91
+ async for chunk in stream:
92
+ buffer += chunk
93
+ while b"\n" in buffer:
94
+ line, buffer = buffer.split(b"\n", 1)
95
+ if line:
96
+ await self.handle_line(line.decode("utf-8"))
97
+
98
+ async def handle_line(self, line: str) -> None:
99
+ """Handle a single command line."""
100
+ try:
101
+ request = deserialize(line)
102
+ if not isinstance(request, Request):
103
+ await self.emit(Response(id="", error="Expected a request"))
104
+ return
105
+ response = await self.handle_command(request)
106
+ await self.emit(response)
107
+ except Exception as e:
108
+ traceback.print_exc()
109
+ await self.emit(Response(id="", error=str(e)))
110
+
111
+ async def handle_command(self, request: Request) -> Response:
112
+ """Handle a command request and return a response."""
113
+ match request.command:
114
+ case "start":
115
+ return await self._handle_start(request.id, request.params)
116
+ case "status":
117
+ return self._handle_status(request.id)
118
+ case "cancel":
119
+ return self._handle_cancel(request.id)
120
+ case "disable_pass":
121
+ return self._handle_disable_pass(request.id, request.params)
122
+ case "enable_pass":
123
+ return self._handle_enable_pass(request.id, request.params)
124
+ case "skip_pass":
125
+ return self._handle_skip_pass(request.id)
126
+ case _:
127
+ return Response(id=request.id, error=f"Unknown command: {request.command}")
128
+
129
+ async def _handle_start(self, request_id: str, params: dict) -> Response:
130
+ """Start the reduction process."""
131
+ if self.running:
132
+ return Response(id=request_id, error="Already running")
133
+
134
+ try:
135
+ await self._start_reduction(params)
136
+ return Response(id=request_id, result={"status": "started"})
137
+ except* InvalidInitialExample as excs:
138
+ assert len(excs.exceptions) == 1
139
+ (e,) = excs.exceptions
140
+ # Build a detailed error message for invalid initial examples
141
+ if self.state is not None:
142
+ error_message = await self.state.build_error_message(e)
143
+ else:
144
+ error_message = str(e)
145
+ except* Exception as e:
146
+ traceback.print_exc()
147
+ error_message = str(e.exceptions[0])
148
+ return Response(id=request_id, error=error_message)
149
+
150
+ async def _start_reduction(self, params: dict) -> None:
151
+ """Initialize and start the reduction."""
152
+ from shrinkray.cli import InputType
153
+ from shrinkray.passes.clangdelta import (
154
+ C_FILE_EXTENSIONS,
155
+ ClangDelta,
156
+ find_clang_delta,
157
+ )
158
+ from shrinkray.state import (
159
+ ShrinkRayDirectoryState,
160
+ ShrinkRayStateSingleFile,
161
+ )
162
+ from shrinkray.work import Volume
163
+
164
+ filename = params["file_path"]
165
+ test = params["test"]
166
+ parallelism = params.get("parallelism", os.cpu_count() or 1)
167
+ timeout = params.get("timeout", 1.0)
168
+ seed = params.get("seed", 0)
169
+ input_type = InputType[params.get("input_type", "all")]
170
+ in_place = params.get("in_place", False)
171
+ formatter = params.get("formatter", "default")
172
+ volume = Volume[params.get("volume", "normal")]
173
+ no_clang_delta = params.get("no_clang_delta", False)
174
+ clang_delta_path = params.get("clang_delta", "")
175
+ trivial_is_error = params.get("trivial_is_error", True)
176
+
177
+ clang_delta_executable = None
178
+ if os.path.splitext(filename)[1] in C_FILE_EXTENSIONS and not no_clang_delta:
179
+ if not clang_delta_path:
180
+ clang_delta_path = find_clang_delta()
181
+ if clang_delta_path:
182
+ clang_delta_executable = ClangDelta(clang_delta_path)
183
+
184
+ state_kwargs: dict[str, Any] = dict(
185
+ input_type=input_type,
186
+ in_place=in_place,
187
+ test=test,
188
+ timeout=timeout,
189
+ base=os.path.basename(filename),
190
+ parallelism=parallelism,
191
+ filename=filename,
192
+ formatter=formatter,
193
+ trivial_is_error=trivial_is_error,
194
+ seed=seed,
195
+ volume=volume,
196
+ clang_delta_executable=clang_delta_executable,
197
+ )
198
+
199
+ if os.path.isdir(filename):
200
+ files = [os.path.join(d, f) for d, _, fs in os.walk(filename) for f in fs]
201
+ initial = {}
202
+ for f in files:
203
+ with open(f, "rb") as i:
204
+ initial[os.path.relpath(f, filename)] = i.read()
205
+ self.state = ShrinkRayDirectoryState(initial=initial, **state_kwargs)
206
+ else:
207
+ with open(filename, "rb") as reader:
208
+ initial = reader.read()
209
+ self.state = ShrinkRayStateSingleFile(initial=initial, **state_kwargs)
210
+
211
+ self.problem = self.state.problem
212
+ self.reducer = self.state.reducer
213
+
214
+ # Validate initial example before starting - this will raise
215
+ # InvalidInitialExample if the initial test case fails
216
+ await self.problem.setup()
217
+
218
+ self.running = True
219
+
220
+ def _handle_status(self, request_id: str) -> Response:
221
+ """Get current status."""
222
+ if not self.running or self.problem is None:
223
+ return Response(id=request_id, result={"running": False})
224
+
225
+ stats = self.problem.stats
226
+ return Response(
227
+ id=request_id,
228
+ result={
229
+ "running": True,
230
+ "status": self.reducer.status if self.reducer else "",
231
+ "size": stats.current_test_case_size,
232
+ "original_size": stats.initial_test_case_size,
233
+ "calls": stats.calls,
234
+ "reductions": stats.reductions,
235
+ },
236
+ )
237
+
238
+ def _handle_cancel(self, request_id: str) -> Response:
239
+ """Cancel the reduction."""
240
+ if self._cancel_scope is not None:
241
+ self._cancel_scope.cancel()
242
+ self.running = False
243
+ return Response(id=request_id, result={"status": "cancelled"})
244
+
245
+ def _get_known_pass_names(self) -> set[str]:
246
+ """Get the set of known pass names from pass stats."""
247
+ if self.reducer is None or self.reducer.pass_stats is None:
248
+ return set()
249
+ return set(self.reducer.pass_stats._stats.keys())
250
+
251
+ def _handle_disable_pass(self, request_id: str, params: dict) -> Response:
252
+ """Disable a reduction pass by name."""
253
+ pass_name = params.get("pass_name", "")
254
+ if not pass_name:
255
+ return Response(id=request_id, error="pass_name is required")
256
+
257
+ known_passes = self._get_known_pass_names()
258
+ if known_passes and pass_name not in known_passes:
259
+ return Response(
260
+ id=request_id,
261
+ error=f"Unknown pass '{pass_name}'. Known passes: {sorted(known_passes)}",
262
+ )
263
+
264
+ if self.reducer is not None and hasattr(self.reducer, "disable_pass"):
265
+ self.reducer.disable_pass(pass_name)
266
+ return Response(id=request_id, result={"status": "disabled", "pass_name": pass_name})
267
+ return Response(id=request_id, error="Reducer does not support pass control")
268
+
269
+ def _handle_enable_pass(self, request_id: str, params: dict) -> Response:
270
+ """Enable a previously disabled reduction pass."""
271
+ pass_name = params.get("pass_name", "")
272
+ if not pass_name:
273
+ return Response(id=request_id, error="pass_name is required")
274
+
275
+ known_passes = self._get_known_pass_names()
276
+ if known_passes and pass_name not in known_passes:
277
+ return Response(
278
+ id=request_id,
279
+ error=f"Unknown pass '{pass_name}'. Known passes: {sorted(known_passes)}",
280
+ )
281
+
282
+ if self.reducer is not None and hasattr(self.reducer, "enable_pass"):
283
+ self.reducer.enable_pass(pass_name)
284
+ return Response(id=request_id, result={"status": "enabled", "pass_name": pass_name})
285
+ return Response(id=request_id, error="Reducer does not support pass control")
286
+
287
+ def _handle_skip_pass(self, request_id: str) -> Response:
288
+ """Skip the currently running pass."""
289
+ if self.reducer is not None and hasattr(self.reducer, "skip_current_pass"):
290
+ self.reducer.skip_current_pass()
291
+ return Response(id=request_id, result={"status": "skipped"})
292
+ return Response(id=request_id, error="Reducer does not support pass control")
293
+
294
+ def _get_content_preview(self) -> tuple[str, bool]:
295
+ """Get a preview of the current test case content."""
296
+ if self.problem is None:
297
+ return "", False
298
+
299
+ test_case = self.problem.current_test_case
300
+
301
+ # Handle directory mode
302
+ if isinstance(test_case, dict):
303
+ lines = []
304
+ for name, content in sorted(test_case.items()):
305
+ size = len(content)
306
+ lines.append(f"{name}: {size} bytes")
307
+ return "\n".join(lines[:50]), False
308
+
309
+ # Handle single file mode
310
+ hex_mode = is_binary_string(
311
+ test_case[:1024] if len(test_case) > 1024 else test_case
312
+ )
313
+
314
+ if hex_mode:
315
+ # Show hex dump for binary files
316
+ preview_bytes = test_case[:512]
317
+ lines = []
318
+ for i in range(0, len(preview_bytes), 16):
319
+ chunk = preview_bytes[i : i + 16]
320
+ hex_part = " ".join(f"{b:02x}" for b in chunk)
321
+ ascii_part = "".join(chr(b) if 32 <= b < 127 else "." for b in chunk)
322
+ lines.append(f"{i:08x} {hex_part:<48} {ascii_part}")
323
+ return "\n".join(lines), True
324
+ else:
325
+ # Show text for text files
326
+ # Send up to 100KB of content - the TUI will handle display truncation
327
+ try:
328
+ text = test_case.decode("utf-8", errors="replace")
329
+ # Truncate by character count to handle files with very long lines
330
+ max_chars = 100_000
331
+ if len(text) > max_chars:
332
+ text = text[:max_chars]
333
+ return text, False
334
+ except Exception:
335
+ return "", True
336
+
337
+ async def _build_progress_update(self) -> ProgressUpdate | None:
338
+ """Build a progress update from current state."""
339
+ if self.problem is None:
340
+ return None
341
+
342
+ stats = self.problem.stats
343
+ content_preview, hex_mode = self._get_content_preview()
344
+
345
+ # Get parallel workers count and track average
346
+ parallel_workers = 0
347
+ if self.state is not None and hasattr(self.state, "parallel_tasks_running"):
348
+ parallel_workers = self.state.parallel_tasks_running
349
+ self._parallel_samples += 1
350
+ self._parallel_total += parallel_workers
351
+
352
+ # Calculate parallelism stats
353
+ average_parallelism = 0.0
354
+ effective_parallelism = 0.0
355
+ if self._parallel_samples > 0:
356
+ average_parallelism = self._parallel_total / self._parallel_samples
357
+ wasteage = (
358
+ stats.wasted_interesting_calls / stats.calls if stats.calls > 0 else 0.0
359
+ )
360
+ effective_parallelism = average_parallelism * (1.0 - wasteage)
361
+
362
+ # Collect pass statistics in run order (only those with test evaluations)
363
+ pass_stats_list = []
364
+ current_pass_name = ""
365
+ if self.reducer is not None:
366
+ # Get the currently running pass name
367
+ current_pass = self.reducer.current_reduction_pass
368
+ if current_pass is not None:
369
+ current_pass_name = getattr(current_pass, "__name__", "")
370
+
371
+ # Get all stats in the order they were first run
372
+ pass_stats = self.reducer.pass_stats
373
+ if pass_stats is not None:
374
+ all_stats = pass_stats.get_stats_in_order()
375
+
376
+ # Only include passes that have made at least one test evaluation
377
+ pass_stats_list = [
378
+ PassStatsData(
379
+ pass_name=ps.pass_name,
380
+ bytes_deleted=ps.bytes_deleted,
381
+ run_count=ps.run_count,
382
+ test_evaluations=ps.test_evaluations,
383
+ successful_reductions=ps.successful_reductions,
384
+ success_rate=ps.success_rate,
385
+ )
386
+ for ps in all_stats
387
+ if ps.test_evaluations > 0
388
+ ]
389
+
390
+ # Get disabled passes
391
+ if self.reducer is not None and hasattr(self.reducer, "disabled_passes"):
392
+ disabled_passes = list(self.reducer.disabled_passes)
393
+ else:
394
+ disabled_passes = []
395
+
396
+ return ProgressUpdate(
397
+ status=self.reducer.status if self.reducer else "",
398
+ size=stats.current_test_case_size,
399
+ original_size=stats.initial_test_case_size,
400
+ calls=stats.calls,
401
+ reductions=stats.reductions,
402
+ interesting_calls=stats.interesting_calls,
403
+ wasted_calls=stats.wasted_interesting_calls,
404
+ runtime=time.time() - stats.start_time,
405
+ parallel_workers=parallel_workers,
406
+ average_parallelism=average_parallelism,
407
+ effective_parallelism=effective_parallelism,
408
+ time_since_last_reduction=stats.time_since_last_reduction(),
409
+ content_preview=content_preview,
410
+ hex_mode=hex_mode,
411
+ pass_stats=pass_stats_list,
412
+ current_pass_name=current_pass_name,
413
+ disabled_passes=disabled_passes,
414
+ )
415
+
416
+ async def emit_progress_updates(self) -> None:
417
+ """Periodically emit progress updates."""
418
+ # Emit initial progress update immediately
419
+ update = await self._build_progress_update()
420
+ if update is not None:
421
+ await self.emit(update)
422
+
423
+ while self.running:
424
+ await trio.sleep(0.1)
425
+ update = await self._build_progress_update()
426
+ if update is not None:
427
+ await self.emit(update)
428
+
429
+ async def run_reducer(self) -> None:
430
+ """Run the reducer."""
431
+ if self.reducer is None:
432
+ return
433
+
434
+ try:
435
+ with trio.CancelScope() as scope:
436
+ self._cancel_scope = scope
437
+ await self.reducer.run()
438
+
439
+ # Check for trivial result after successful completion
440
+ if self.state is not None and self.problem is not None:
441
+ trivial_error = self.state.check_trivial_result(self.problem)
442
+ if trivial_error:
443
+ await self.emit(Response(id="", error=trivial_error))
444
+ except* InvalidInitialExample as excs:
445
+ assert len(excs.exceptions) == 1
446
+ (e,) = excs.exceptions
447
+ # Build a detailed error message for invalid initial examples
448
+ if self.state is not None:
449
+ error_message = await self.state.build_error_message(e)
450
+ else:
451
+ error_message = str(e)
452
+ await self.emit(Response(id="", error=error_message))
453
+ except* Exception as e:
454
+ # Catch any other exception during reduction and emit as error
455
+ traceback.print_exc()
456
+ await self.emit(Response(id="", error=str(e.exceptions[0])))
457
+ finally:
458
+ self._cancel_scope = None
459
+ self.running = False
460
+
461
+ async def run(self) -> None:
462
+ """Main entry point for the worker."""
463
+ async with trio.open_nursery() as nursery:
464
+ await nursery.start(self.read_commands)
465
+
466
+ # Wait for start command
467
+ while not self.running:
468
+ await trio.sleep(0.01)
469
+
470
+ # Start progress updates and reducer
471
+ nursery.start_soon(self.emit_progress_updates)
472
+ await self.run_reducer()
473
+
474
+ # Emit final progress update before completion
475
+ final_update = await self._build_progress_update()
476
+ if final_update is not None:
477
+ await self.emit(final_update)
478
+
479
+ # Signal completion
480
+ await self.emit(Response(id="", result={"status": "completed"}))
481
+ nursery.cancel_scope.cancel()
482
+
483
+
484
def main() -> None:
    """Subprocess entry point: run a stdin/stdout-bound ReducerWorker under trio.

    Blocks until the worker's ``run`` coroutine returns.
    """
    trio.run(ReducerWorker().run)


if __name__ == "__main__":
    main()