shrinkray 0.0.0__py3-none-any.whl → 25.12.26.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,253 @@
1
+ """Client for communicating with the reducer subprocess."""
2
+
3
+ import asyncio
4
+ import sys
5
+ import traceback
6
+ import uuid
7
+ from collections.abc import AsyncIterator
8
+ from typing import Any
9
+
10
+ from shrinkray.subprocess.protocol import (
11
+ ProgressUpdate,
12
+ Request,
13
+ Response,
14
+ deserialize,
15
+ serialize,
16
+ )
17
+
18
+
19
class SubprocessClient:
    """Client for communicating with the reducer subprocess via JSON protocol.

    The worker is launched as ``python -m shrinkray.subprocess.worker`` and
    spoken to over its stdin/stdout, one JSON message per line (see
    ``serialize`` / ``deserialize``). Each request carries a fresh UUID id and
    is resolved when a Response with the same id is read back. Progress
    updates are buffered in a queue consumed by ``get_progress_updates``.
    Responses with an empty id are unsolicited completion/error signals.
    """

    def __init__(self, debug_mode: bool = False):
        """Create an unstarted client; call ``start`` (or use ``async with``).

        Args:
            debug_mode: stored as ``self._debug_mode``. NOTE(review): no method
                in this class reads it; stderr is inherited unconditionally.
        """
        self._process: asyncio.subprocess.Process | None = None
        # Request id -> future resolved by the reader task when the reply arrives.
        self._pending_responses: dict[str, asyncio.Future[Response]] = {}
        self._progress_queue: asyncio.Queue[ProgressUpdate] = asyncio.Queue()
        self._reader_task: asyncio.Task | None = None
        # Set once the worker signals completion or a fatal error.
        self._completed = False
        self._error_message: str | None = None
        self._debug_mode = debug_mode

    async def start(self) -> None:
        """Launch the worker subprocess and start the background reader task."""
        # stderr is inherited so interestingness-test output goes straight to
        # this process's stderr. NOTE(review): this happens in every mode;
        # debug_mode is not consulted here.
        self._process = await asyncio.create_subprocess_exec(
            sys.executable,
            "-m",
            "shrinkray.subprocess.worker",
            stdin=asyncio.subprocess.PIPE,
            stdout=asyncio.subprocess.PIPE,
            stderr=sys.stderr,
        )
        self._reader_task = asyncio.create_task(self._read_output())

    def _fail_pending(self, message: str) -> None:
        """Fail every still-unresolved pending request with ``Exception(message)``."""
        for future in self._pending_responses.values():
            if not future.done():
                # A fresh exception per future so tracebacks are not shared.
                future.set_exception(Exception(message))

    async def _read_output(self) -> None:
        """Read newline-delimited messages from the worker's stdout and dispatch them.

        Runs until EOF, cancellation, or an unexpected error (which is
        printed). On exit, every request still awaiting a reply is failed:
        no further responses can arrive, so ``send_command`` would otherwise
        wait forever.
        """
        if self._process is None or self._process.stdout is None:
            return

        stdout = self._process.stdout
        buffer = bytearray()
        try:
            while True:
                chunk = await stdout.read(4096)
                if not chunk:
                    break
                # Only the newly read bytes can contain an unseen newline, so
                # scan from there instead of rescanning a long partial line.
                scan_from = len(buffer)
                buffer += chunk
                line_start = 0
                newline = buffer.find(b"\n", scan_from)
                while newline != -1:
                    line = buffer[line_start:newline]
                    line_start = newline + 1
                    if line:
                        # errors="replace": a stray non-UTF-8 line is then
                        # reported by deserialize instead of killing the reader.
                        await self._handle_message(
                            line.decode("utf-8", errors="replace")
                        )
                    newline = buffer.find(b"\n", line_start)
                # Keep only the trailing partial line (if any).
                del buffer[:line_start]
        except Exception:
            traceback.print_exc()
        finally:
            self._fail_pending("Subprocess output closed")

    async def _handle_message(self, line: str) -> None:
        """Dispatch one decoded line from the worker.

        Progress updates are queued; empty-id responses are treated as
        completion/error signals; other responses resolve the matching
        pending request. Lines that fail to deserialize are printed and
        dropped.
        """
        try:
            msg = deserialize(line)
        except Exception:
            traceback.print_exc()
            return

        if isinstance(msg, ProgressUpdate):
            await self._progress_queue.put(msg)
            return
        if not isinstance(msg, Response):
            # Anything else (e.g. a Request) is ignored.
            return

        if msg.id == "":
            # Unsolicited signal. The isinstance check keeps a non-dict
            # result from raising and taking down the reader task.
            if isinstance(msg.result, dict) and msg.result.get("status") == "completed":
                self._completed = True
                self._fail_pending("Subprocess completed")
            elif msg.error:
                self._completed = True
                self._error_message = msg.error
                self._fail_pending(msg.error)
            return

        # Match the response to its pending request.
        future = self._pending_responses.pop(msg.id, None)
        if future is not None and not future.done():
            future.set_result(msg)

    async def send_command(
        self, command: str, params: dict[str, Any] | None = None
    ) -> Response:
        """Send ``command`` to the worker and wait for its matching Response.

        Args:
            command: command name placed in the Request.
            params: command parameters; ``None`` sends an empty dict.

        Returns:
            The Response whose id matches this request.

        Raises:
            RuntimeError: if the subprocess was never started, or its output
                stream has already closed so no reply could ever arrive.
            Exception: if the worker signals completion or an error, or its
                output closes, before replying.
        """
        if self._process is None or self._process.stdin is None:
            raise RuntimeError("Subprocess not started")
        if self._reader_task is not None and self._reader_task.done():
            raise RuntimeError("Subprocess output closed")

        request_id = str(uuid.uuid4())
        request = Request(id=request_id, command=command, params=params or {})

        future: asyncio.Future[Response] = asyncio.get_running_loop().create_future()
        self._pending_responses[request_id] = future
        try:
            line = serialize(request) + "\n"
            self._process.stdin.write(line.encode("utf-8"))
            await self._process.stdin.drain()
            return await future
        finally:
            # On success the reader already popped this entry; on a
            # serialization/write error or cancellation it would otherwise leak.
            self._pending_responses.pop(request_id, None)

    async def start_reduction(
        self,
        file_path: str,
        test: list[str],
        parallelism: int | None = None,
        timeout: float = 1.0,
        seed: int = 0,
        input_type: str = "all",
        in_place: bool = False,
        formatter: str = "default",
        volume: str = "normal",
        no_clang_delta: bool = False,
        clang_delta: str = "",
        trivial_is_error: bool = True,
    ) -> Response:
        """Send the ``start`` command with the given reduction settings.

        ``parallelism`` is only included in the params when it is not None.
        """
        params = {
            "file_path": file_path,
            "test": test,
            "timeout": timeout,
            "seed": seed,
            "input_type": input_type,
            "in_place": in_place,
            "formatter": formatter,
            "volume": volume,
            "no_clang_delta": no_clang_delta,
            "clang_delta": clang_delta,
            "trivial_is_error": trivial_is_error,
        }
        if parallelism is not None:
            params["parallelism"] = parallelism
        return await self.send_command("start", params)

    async def get_status(self) -> Response:
        """Send the ``status`` command and return the worker's reply."""
        return await self.send_command("status")

    async def cancel(self) -> Response:
        """Cancel the reduction.

        Returns a synthetic status Response (empty id) when the reduction
        already completed, the process has exited, or the command fails
        (e.g. because the worker completes while handling it).
        """
        if self._completed:
            return Response(id="", result={"status": "already_completed"})
        if self._process is None or self._process.returncode is not None:
            return Response(id="", result={"status": "process_exited"})
        try:
            return await self.send_command("cancel")
        except Exception:
            return Response(id="", result={"status": "cancelled"})

    async def _send_unless_completed(
        self, command: str, params: dict[str, Any] | None, failure_message: str
    ) -> Response:
        """Send ``command`` unless completed; on failure print and return an error Response."""
        if self._completed:
            return Response(id="", result={"status": "already_completed"})
        try:
            return await self.send_command(command, params)
        except Exception:
            traceback.print_exc()
            return Response(id="", error=failure_message)

    async def disable_pass(self, pass_name: str) -> Response:
        """Disable a reduction pass by name."""
        return await self._send_unless_completed(
            "disable_pass", {"pass_name": pass_name}, "Failed to disable pass"
        )

    async def enable_pass(self, pass_name: str) -> Response:
        """Enable a previously disabled reduction pass."""
        return await self._send_unless_completed(
            "enable_pass", {"pass_name": pass_name}, "Failed to enable pass"
        )

    async def skip_current_pass(self) -> Response:
        """Skip the currently running pass."""
        return await self._send_unless_completed(
            "skip_pass", None, "Failed to skip pass"
        )

    async def get_progress_updates(self) -> AsyncIterator[ProgressUpdate]:
        """Yield progress updates as they arrive, until the reduction completes.

        Also stops once the reader task has finished and the queue is
        drained, because no more updates or completion signal can arrive.
        """
        while not self._completed:
            try:
                update = await asyncio.wait_for(self._progress_queue.get(), timeout=0.5)
            except asyncio.TimeoutError:
                # asyncio.TimeoutError (not the builtin) is what wait_for
                # raises before Python 3.11; from 3.11 they are the same class.
                if (
                    self._reader_task is not None
                    and self._reader_task.done()
                    and self._progress_queue.empty()
                ):
                    return
                continue
            yield update

    @property
    def is_completed(self) -> bool:
        """True once the worker has signalled completion or a fatal error."""
        return self._completed

    @property
    def error_message(self) -> str | None:
        """Error text from the worker's error signal, if one was received."""
        return self._error_message

    async def close(self) -> None:
        """Stop the reader task, close stdin, and terminate the worker if running.

        The worker gets 5 seconds to exit after ``terminate()`` before it is
        killed.
        """
        if self._reader_task is not None:
            self._reader_task.cancel()
            try:
                await self._reader_task
            except asyncio.CancelledError:
                pass

        if self._process is not None:
            if self._process.stdin is not None:
                try:
                    self._process.stdin.close()
                except Exception:
                    traceback.print_exc()
            # Only terminate if still running.
            if self._process.returncode is None:
                try:
                    self._process.terminate()
                    await asyncio.wait_for(self._process.wait(), timeout=5.0)
                except asyncio.TimeoutError:
                    self._process.kill()
                    await self._process.wait()
                except ProcessLookupError:
                    pass  # Process already exited

    async def __aenter__(self) -> "SubprocessClient":
        await self.start()
        return self

    async def __aexit__(self, *args) -> None:
        await self.close()
@@ -0,0 +1,190 @@
1
+ """Line-oriented JSON protocol for subprocess communication."""
2
+
3
+ import base64
4
+ import json
5
+ from dataclasses import dataclass, field
6
+ from typing import Any
7
+
8
+
9
@dataclass
class Request:
    """Command message sent from the UI process to the worker subprocess.

    ``params`` defaults to a fresh empty dict for every instance.
    """

    # Identifier used to match the worker's Response to this request.
    id: str
    # Name of the command the worker should run.
    command: str
    # Command arguments; serialized with json.dumps, so must be JSON-compatible.
    params: dict[str, Any] = field(default_factory=dict)
16
+
17
+
18
+ @dataclass
19
+ class Response:
20
+ """A response from the worker subprocess to the UI."""
21
+
22
+ id: str
23
+ result: Any = None
24
+ error: str | None = None
25
+
26
+
27
@dataclass
class ProgressUpdate:
    """Progress snapshot pushed by the worker subprocess without a request.

    Only the first five fields are required; every other field has a neutral
    default. List fields use ``default_factory`` so instances never share them.
    """

    status: str  # current reducer pass/pump status

    # Size information: current and original.
    size: int
    original_size: int

    # Call statistics.
    calls: int
    reductions: int
    interesting_calls: int = 0
    wasted_calls: int = 0

    runtime: float = 0.0  # runtime in seconds

    # Parallelism statistics.
    parallel_workers: int = 0
    average_parallelism: float = 0.0
    effective_parallelism: float = 0.0

    time_since_last_reduction: float = 0.0

    content_preview: str = ""  # current content, truncated for large files
    hex_mode: bool = False  # whether the preview is shown in hex mode

    # Per-pass statistics (only passes that have had test evaluations).
    pass_stats: list["PassStatsData"] = field(default_factory=list)
    current_pass_name: str = ""  # running pass, used for highlighting
    disabled_passes: list[str] = field(default_factory=list)
59
+
60
+
61
@dataclass
class PassStatsData:
    """Serializable per-pass statistics carried inside a ProgressUpdate."""

    pass_name: str
    # Counters accumulated for this pass.
    bytes_deleted: int
    run_count: int
    test_evaluations: int
    successful_reductions: int
    # Reported success ratio for the pass.
    success_rate: float
71
+
72
+
73
+ def serialize(msg: Request | Response | ProgressUpdate) -> str:
74
+ """Serialize a message to a JSON line (without newline)."""
75
+ if isinstance(msg, Request):
76
+ data = {
77
+ "id": msg.id,
78
+ "command": msg.command,
79
+ "params": msg.params,
80
+ }
81
+ elif isinstance(msg, Response):
82
+ data = {
83
+ "id": msg.id,
84
+ "result": msg.result,
85
+ "error": msg.error,
86
+ }
87
+ elif isinstance(msg, ProgressUpdate):
88
+ data = {
89
+ "type": "progress",
90
+ "data": {
91
+ "status": msg.status,
92
+ "size": msg.size,
93
+ "original_size": msg.original_size,
94
+ "calls": msg.calls,
95
+ "reductions": msg.reductions,
96
+ "interesting_calls": msg.interesting_calls,
97
+ "wasted_calls": msg.wasted_calls,
98
+ "runtime": msg.runtime,
99
+ "parallel_workers": msg.parallel_workers,
100
+ "average_parallelism": msg.average_parallelism,
101
+ "effective_parallelism": msg.effective_parallelism,
102
+ "time_since_last_reduction": msg.time_since_last_reduction,
103
+ "content_preview": msg.content_preview,
104
+ "hex_mode": msg.hex_mode,
105
+ "pass_stats": [
106
+ {
107
+ "pass_name": ps.pass_name,
108
+ "bytes_deleted": ps.bytes_deleted,
109
+ "run_count": ps.run_count,
110
+ "test_evaluations": ps.test_evaluations,
111
+ "successful_reductions": ps.successful_reductions,
112
+ "success_rate": ps.success_rate,
113
+ }
114
+ for ps in msg.pass_stats
115
+ ],
116
+ "current_pass_name": msg.current_pass_name,
117
+ "disabled_passes": msg.disabled_passes,
118
+ },
119
+ }
120
+ else:
121
+ raise TypeError(f"Cannot serialize {msg!r}")
122
+ return json.dumps(data, separators=(",", ":"))
123
+
124
+
125
def deserialize(line: str) -> Request | Response | ProgressUpdate:
    """Parse one JSON line (as produced by ``serialize``) into a message object.

    An object tagged ``"type": "progress"`` becomes a ProgressUpdate.
    Otherwise, an object containing ``"result"`` or ``"error"`` becomes a
    Response, and anything else is read as a Request. Optional fields fall
    back to the same defaults the dataclasses declare.

    Raises:
        json.JSONDecodeError: if ``line`` is not valid JSON.
        KeyError: if a required field is missing.
    """
    payload = json.loads(line)

    if payload.get("type") == "progress":
        body = payload["data"]
        # Pass statistics are decoded before the update itself.
        stats = [
            PassStatsData(
                pass_name=entry["pass_name"],
                bytes_deleted=entry["bytes_deleted"],
                run_count=entry["run_count"],
                test_evaluations=entry["test_evaluations"],
                successful_reductions=entry["successful_reductions"],
                success_rate=entry["success_rate"],
            )
            for entry in body.get("pass_stats", [])
        ]
        return ProgressUpdate(
            status=body["status"],
            size=body["size"],
            original_size=body["original_size"],
            calls=body["calls"],
            reductions=body["reductions"],
            interesting_calls=body.get("interesting_calls", 0),
            wasted_calls=body.get("wasted_calls", 0),
            runtime=body.get("runtime", 0.0),
            parallel_workers=body.get("parallel_workers", 0),
            average_parallelism=body.get("average_parallelism", 0.0),
            effective_parallelism=body.get("effective_parallelism", 0.0),
            time_since_last_reduction=body.get("time_since_last_reduction", 0.0),
            content_preview=body.get("content_preview", ""),
            hex_mode=body.get("hex_mode", False),
            pass_stats=stats,
            current_pass_name=body.get("current_pass_name", ""),
            disabled_passes=body.get("disabled_passes", []),
        )

    is_response = "result" in payload or "error" in payload
    if is_response:
        return Response(
            id=payload["id"],
            result=payload.get("result"),
            error=payload.get("error"),
        )

    # Neither a progress update nor a response: treat it as a request.
    return Request(
        id=payload["id"],
        command=payload["command"],
        params=payload.get("params", {}),
    )
181
+
182
+
183
def encode_bytes(data: bytes) -> str:
    """Return ``data`` as a standard base64 ASCII string for embedding in JSON."""
    encoded = base64.b64encode(data)
    return encoded.decode("ascii")
186
+
187
+
188
def decode_bytes(data: str) -> bytes:
    """Inverse of ``encode_bytes``: turn a base64 ASCII string back into bytes."""
    ascii_form = data.encode("ascii")
    return base64.b64decode(ascii_form)