more-compute 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,381 @@
1
+ import os
2
+ import time
3
+ import signal
4
+ from typing import TYPE_CHECKING, cast
5
+ import subprocess
6
+ import sys
7
+ import asyncio
8
+ from fastapi import WebSocket
9
+ import zmq
10
+
11
+ from ..utils.special_commands import AsyncSpecialCommandHandler
12
+
13
+ if TYPE_CHECKING:
14
+ from ..utils.error_utils import ErrorUtils
15
+
16
class NextZmqExecutor:
    """Execute notebook cells in a separate worker process reached over ZeroMQ.

    Requests are sent on a REQ socket connected to ``cmd_addr``; the worker's
    output and completion events are read from a SUB socket connected to
    ``pub_addr``. Cells recognised as special commands are not sent to the
    worker; they are delegated to an in-process ``AsyncSpecialCommandHandler``.
    """

    error_utils: "ErrorUtils"
    cmd_addr: str
    pub_addr: str
    execution_count: int
    interrupt_timeout: float
    worker_pid: int | None
    worker_proc: subprocess.Popen[bytes] | None
    interrupted_cell: int | None
    special_handler: AsyncSpecialCommandHandler | None
    ctx: object  # zmq.Context - untyped due to zmq type limitations
    req: object  # zmq.Socket - untyped due to zmq type limitations
    sub: object  # zmq.Socket - untyped due to zmq type limitations

    def __init__(self, error_utils: "ErrorUtils", cmd_addr: str | None = None, pub_addr: str | None = None, interrupt_timeout: float = 0.5) -> None:
        """Connect the REQ/SUB sockets and make sure a worker is reachable.

        Args:
            error_utils: Stored on the instance; not used by this class itself.
            cmd_addr: Command endpoint. Falls back to ``MC_ZMQ_CMD_ADDR`` and
                then to ``tcp://127.0.0.1:5555``.
            pub_addr: Publish endpoint. Falls back to ``MC_ZMQ_PUB_ADDR`` and
                then to ``tcp://127.0.0.1:5556``.
            interrupt_timeout: Seconds ``interrupt_kernel`` waits before it
                force-kills the worker.

        Raises:
            RuntimeError: If no worker answers and one cannot be started.
        """
        self.error_utils = error_utils
        self.cmd_addr = cmd_addr or os.getenv('MC_ZMQ_CMD_ADDR', 'tcp://127.0.0.1:5555')
        self.pub_addr = pub_addr or os.getenv('MC_ZMQ_PUB_ADDR', 'tcp://127.0.0.1:5556')
        self.execution_count = 0
        self.interrupt_timeout = interrupt_timeout
        self.worker_pid = None
        self.worker_proc = None
        self.interrupted_cell = None
        self.special_handler = None
        self._ensure_special_handler()
        self.ctx = zmq.Context.instance()  # type: ignore[reportUnknownMemberType]
        self.req = self.ctx.socket(zmq.REQ)  # type: ignore[reportUnknownMemberType, reportAttributeAccessIssue]
        self.req.connect(self.cmd_addr)  # type: ignore[reportAttributeAccessIssue]
        self.sub = self.ctx.socket(zmq.SUB)  # type: ignore[reportUnknownMemberType, reportAttributeAccessIssue]
        self.sub.connect(self.pub_addr)  # type: ignore[reportAttributeAccessIssue]
        self.sub.setsockopt_string(zmq.SUBSCRIBE, '')  # type: ignore[reportAttributeAccessIssue]
        self._ensure_worker()

    def _ensure_special_handler(self) -> None:
        """Create the special-command handler if it does not exist yet."""
        if self.special_handler is None:
            self.special_handler = AsyncSpecialCommandHandler({"__name__": "__main__"})

    def _ping_worker(self, timeout_ms: int = 500) -> dict[str, object] | None:
        """Ping the worker on a throwaway REQ socket.

        A temporary socket is used so that ``self.req`` is never left stuck
        waiting for a reply.

        Returns:
            The worker's reply (an empty dict if the reply was not a JSON
            object), or ``None`` when the worker did not answer in time.
        """
        probe = None
        try:
            probe = self.ctx.socket(zmq.REQ)  # type: ignore[reportUnknownMemberType, reportAttributeAccessIssue]
            probe.setsockopt(zmq.LINGER, 0)
            probe.setsockopt(zmq.RCVTIMEO, timeout_ms)
            probe.setsockopt(zmq.SNDTIMEO, timeout_ms)
            probe.connect(self.cmd_addr)
            probe.send_json({'type': 'ping'})
            reply = probe.recv_json()
            return cast(dict[str, object], reply) if isinstance(reply, dict) else {}
        except Exception:
            return None
        finally:
            if probe is not None:
                try:
                    probe.close(0)
                except Exception:
                    pass

    def _reset_req_socket(self, log_tag: str) -> bool:
        """Close ``self.req`` and replace it with a freshly connected REQ socket.

        A REQ socket that sent a request but never got the reply cannot send
        again, so it has to be recreated after the worker dies or times out.

        Returns:
            ``True`` on success, ``False`` if recreating the socket failed
            (the failure is logged to stderr with ``log_tag``).
        """
        try:
            self.req.close(0)  # type: ignore[reportAttributeAccessIssue]
            self.req = self.ctx.socket(zmq.REQ)  # type: ignore[reportUnknownMemberType, reportAttributeAccessIssue]
            self.req.connect(self.cmd_addr)  # type: ignore[reportAttributeAccessIssue]
        except Exception as e:
            print(f"{log_tag} Error resetting socket: {e}", file=sys.stderr, flush=True)
            return False
        return True

    def _ensure_worker(self) -> None:
        """Make sure a worker answers pings, spawning one if necessary.

        NOTE: this method blocks (ping timeouts plus ``time.sleep``), for up
        to roughly 30 seconds in the worst case.

        Raises:
            RuntimeError: If the worker cannot be started or never answers.
        """
        reply = self._ping_worker()
        if reply is not None:
            # Worker already alive: remember its PID so a later force-kill
            # can actually reach it.
            pid = reply.get('pid')
            self.worker_pid = pid if isinstance(pid, int) else None
            return

        # Spawn a worker if not reachable. The addresses are assigned (not
        # setdefault) so the worker always binds where this client connects,
        # even when the inherited environment names different endpoints.
        env = os.environ.copy()
        env['MC_ZMQ_CMD_ADDR'] = self.cmd_addr
        env['MC_ZMQ_PUB_ADDR'] = self.pub_addr
        try:
            self.worker_proc = subprocess.Popen(
                [sys.executable, '-m', 'morecompute.execution.worker'],
                env=env,
                stdout=subprocess.DEVNULL,
                stderr=None,  # Show errors in terminal
            )
        except Exception as exc:
            raise RuntimeError('Failed to start/connect ZMQ worker') from exc

        for _ in range(50):
            reply = self._ping_worker()
            if reply is not None:
                # Store the worker PID for force-kill if needed
                pid = reply.get('pid')
                self.worker_pid = pid if isinstance(pid, int) else None
                return
            time.sleep(0.1)
        raise RuntimeError('Failed to start/connect ZMQ worker')

    @staticmethod
    def _synthetic_error(ename: str, evalue: str, detail: str) -> dict[str, object]:
        """Build the result fields for an error produced on the client side."""
        return {
            'status': 'error',
            'error': {
                'output_type': 'error',
                'ename': ename,
                'evalue': evalue,
                'traceback': [f'{ename}: {detail}'],
            },
        }

    async def execute_cell(self, cell_index: int, source_code: str, websocket: WebSocket | None = None) -> dict[str, object]:
        """Execute one cell and return its result dict.

        Special commands run through the in-process handler; everything else
        is sent to the worker, and published messages are relayed to
        ``websocket`` (when given) until the matching ``execution_complete``
        arrives, the cell is interrupted, or 300 seconds elapse.

        Returns:
            A dict with ``outputs``, ``error``, ``status``,
            ``execution_count`` and ``execution_time`` keys.
        """
        # An interrupt flag left over from a time this cell was NOT running
        # would otherwise abort this fresh run on the first loop iteration.
        if self.interrupted_cell == cell_index:
            self.interrupted_cell = None

        self._ensure_special_handler()
        handler = self.special_handler
        if handler is not None:
            normalized_source = handler._coerce_source_to_text(source_code)  # type: ignore[reportPrivateUsage]
            if handler.is_special_command(normalized_source):
                self.execution_count += 1
                execution_count = self.execution_count
                start_time = time.time()
                result: dict[str, object] = {
                    'outputs': [],
                    'error': None,
                    'status': 'ok',
                    'execution_count': execution_count,
                    'execution_time': None,
                }
                if websocket:
                    await websocket.send_json({'type': 'execution_start', 'data': {'cell_index': cell_index, 'execution_count': execution_count}})
                result = await handler.execute_special_command(
                    normalized_source, result, start_time, execution_count, websocket, cell_index
                )
                result['execution_time'] = f"{(time.time()-start_time)*1000:.1f}ms"
                if websocket:
                    await websocket.send_json({'type': 'execution_complete', 'data': {'cell_index': cell_index, 'result': result}})
                return result

        self.execution_count += 1
        execution_count = self.execution_count
        result = {'outputs': [], 'error': None, 'status': 'ok', 'execution_count': execution_count, 'execution_time': None}
        if websocket:
            await websocket.send_json({'type': 'execution_start', 'data': {'cell_index': cell_index, 'execution_count': execution_count}})

        self.req.send_json({'type': 'execute_cell', 'code': source_code, 'cell_index': cell_index, 'execution_count': execution_count})  # type: ignore[reportAttributeAccessIssue]
        # Consume pub until we see complete for this cell
        start_time = time.time()
        max_wait = 300.0  # 5 minute timeout for really long operations
        while True:
            # Check if this cell was interrupted
            if self.interrupted_cell == cell_index:
                print(f"[EXECUTE] Cell {cell_index} was interrupted, breaking out of execution loop", file=sys.stderr, flush=True)
                self.interrupted_cell = None  # Clear the flag
                result.update(self._synthetic_error('KeyboardInterrupt', 'Execution interrupted by user', 'Execution was stopped by user'))
                break

            # Timeout check for stuck operations
            if time.time() - start_time > max_wait:
                print(f"[EXECUTE] Cell {cell_index} exceeded max wait time, timing out", file=sys.stderr, flush=True)
                result.update(self._synthetic_error('TimeoutError', 'Execution exceeded maximum time limit', 'Operation took too long'))
                break

            try:
                msg = cast(dict[str, object], self.sub.recv_json(flags=zmq.NOBLOCK))  # type: ignore[reportAttributeAccessIssue]
            except zmq.Again:
                # Nothing published yet; yield to the event loop briefly.
                await asyncio.sleep(0.01)
                continue

            kind = msg.get('type')
            if kind in ('stream', 'stream_update'):
                if websocket:
                    await websocket.send_json({'type': 'stream_output', 'data': msg})
            elif kind == 'execute_result':
                if websocket:
                    await websocket.send_json({'type': 'execution_result', 'data': msg})
            elif kind == 'display_data':
                if websocket:
                    await websocket.send_json({'type': 'execution_result', 'data': {'cell_index': msg.get('cell_index'), 'execution_count': None, 'data': msg.get('data')}})
            elif kind == 'execution_error':
                if websocket:
                    await websocket.send_json({'type': 'execution_error', 'data': msg})
                # Record the error whether or not a websocket is attached.
                if msg.get('cell_index') == cell_index:
                    result.update({'status': 'error', 'error': msg.get('error')})
            elif kind == 'execution_complete' and msg.get('cell_index') == cell_index:
                result.update(msg.get('result') or {})  # type: ignore[arg-type]
                result.setdefault('execution_count', execution_count)
                break

        # Try to receive the reply from REQ socket (if worker is still alive).
        # If we interrupted/killed the worker, this will fail and we need to
        # reset the socket so the next request can be sent.
        try:
            self.req.setsockopt(zmq.RCVTIMEO, 100)  # type: ignore[reportAttributeAccessIssue]
            _ = cast(dict[str, object], self.req.recv_json())  # type: ignore[reportAttributeAccessIssue]
            self.req.setsockopt(zmq.RCVTIMEO, -1)  # type: ignore[reportAttributeAccessIssue]
        except zmq.Again:
            # Timeout - worker didn't reply (probably killed)
            print(f"[EXECUTE] Worker didn't reply, resetting REQ socket", file=sys.stderr, flush=True)
            self._reset_req_socket('[EXECUTE]')
        except Exception as e:
            # Some other error, also reset socket to be safe
            print(f"[EXECUTE] Error receiving reply: {e}, resetting socket", file=sys.stderr, flush=True)
            self._reset_req_socket('[EXECUTE]')

        result['execution_time'] = f"{(time.time()-start_time)*1000:.1f}ms"
        if websocket:
            await websocket.send_json({'type': 'execution_complete', 'data': {'cell_index': cell_index, 'result': result}})
        return result

    async def interrupt_kernel(self, cell_index: int | None = None) -> None:
        """Interrupt the kernel, always escalating to a force-kill.

        The cell is flagged so a running ``execute_cell`` loop stops, a
        best-effort ``interrupt`` request is sent, and after
        ``interrupt_timeout`` seconds the worker is killed and respawned.
        """
        print(f"[INTERRUPT] Starting interrupt for cell {cell_index}", file=sys.stderr, flush=True)

        payload: dict[str, object] = {'type': 'interrupt'}
        if isinstance(cell_index, int):
            # Mark this cell as interrupted so execute_cell can break out
            self.interrupted_cell = cell_index
            print(f"[INTERRUPT] Marked cell {cell_index} as interrupted", file=sys.stderr, flush=True)
            payload['cell_index'] = cell_index

        # Try graceful interrupt, but don't trust it for blocking I/O
        try:
            # Very short timeout since we'll force-kill anyway
            self.req.setsockopt(zmq.SNDTIMEO, 100)  # type: ignore[reportAttributeAccessIssue]
            self.req.setsockopt(zmq.RCVTIMEO, 100)  # type: ignore[reportAttributeAccessIssue]
            self.req.send_json(payload)  # type: ignore[reportAttributeAccessIssue]
            _ = cast(dict[str, object], self.req.recv_json())  # type: ignore[reportAttributeAccessIssue]
            print(f"[INTERRUPT] Sent interrupt signal to worker", file=sys.stderr, flush=True)
        except Exception as e:
            print(f"[INTERRUPT] Could not send interrupt signal: {e}", file=sys.stderr, flush=True)
        finally:
            # Reset timeouts
            self.req.setsockopt(zmq.SNDTIMEO, -1)  # type: ignore[reportAttributeAccessIssue]
            self.req.setsockopt(zmq.RCVTIMEO, -1)  # type: ignore[reportAttributeAccessIssue]

        # Wait briefly, but DON'T read from the pub socket here:
        # execute_cell is already reading from it and we'd steal its messages.
        print(f"[INTERRUPT] Waiting {self.interrupt_timeout}s before force-kill...", file=sys.stderr, flush=True)
        await asyncio.sleep(self.interrupt_timeout)

        # The interrupted_cell flag lets execute_cell break out gracefully.
        print(f"[INTERRUPT] Force killing worker to ensure stop...", file=sys.stderr, flush=True)
        await self._force_kill_worker()
        print(f"[INTERRUPT] Force kill completed", file=sys.stderr, flush=True)

        # Interrupt special handler
        if self.special_handler:
            try:
                await self.special_handler.interrupt()
            except Exception:
                pass

        print(f"[INTERRUPT] Interrupt complete", file=sys.stderr, flush=True)

    async def _force_kill_worker(self) -> None:
        """Force kill the worker process, reset the REQ socket and respawn."""
        print(f"[FORCE_KILL] Killing worker PID={self.worker_pid}", file=sys.stderr, flush=True)

        pid = self.worker_pid
        if pid:
            try:
                print(f"[FORCE_KILL] Sending SIGKILL to {pid}", file=sys.stderr, flush=True)
                os.kill(pid, signal.SIGKILL)
                await asyncio.sleep(0.1)  # Brief wait for process to die
            except ProcessLookupError:
                print(f"[FORCE_KILL] Process {pid} already dead", file=sys.stderr, flush=True)
            except Exception as e:
                print(f"[FORCE_KILL] Error killing PID {pid}: {e}", file=sys.stderr, flush=True)

        # Also try via Popen object if available
        proc = self.worker_proc
        if proc:
            try:
                print(f"[FORCE_KILL] Killing via Popen object", file=sys.stderr, flush=True)
                proc.kill()  # SIGKILL directly
                await asyncio.sleep(0.1)
                proc.poll()  # Reap the child so it does not linger as a zombie
            except Exception as e:
                print(f"[FORCE_KILL] Error killing via Popen: {e}", file=sys.stderr, flush=True)

        # Forget the dead worker so a failed respawn cannot leave a stale PID
        # that a later kill would send signals to.
        self.worker_pid = None
        self.worker_proc = None

        # CRITICAL: the REQ socket may be waiting for a reply from the dead
        # worker, so close and recreate it.
        print(f"[FORCE_KILL] Resetting REQ socket", file=sys.stderr, flush=True)
        if self._reset_req_socket('[FORCE_KILL]'):
            print(f"[FORCE_KILL] REQ socket reset complete", file=sys.stderr, flush=True)

        # Respawn worker (best effort: report, but do not raise)
        try:
            self._ensure_worker()
        except Exception as e:
            print(f"[FORCE_KILL] Error respawning worker: {e}", file=sys.stderr, flush=True)

    def reset_kernel(self) -> None:
        """Reset the kernel by shutting down the worker and restarting it.

        Tries a ``shutdown`` request first, then SIGTERM/SIGKILL, then resets
        the execution counter, the REQ socket and the special handler before
        spawning a new worker. Failures along the way are tolerated.
        """
        # Try graceful shutdown first
        try:
            self.req.setsockopt(zmq.SNDTIMEO, 500)  # type: ignore[reportAttributeAccessIssue]
            self.req.setsockopt(zmq.RCVTIMEO, 500)  # type: ignore[reportAttributeAccessIssue]
            self.req.send_json({'type': 'shutdown'})  # type: ignore[reportAttributeAccessIssue]
            _ = cast(dict[str, object], self.req.recv_json())  # type: ignore[reportAttributeAccessIssue]
        except Exception:
            pass
        finally:
            self.req.setsockopt(zmq.SNDTIMEO, -1)  # type: ignore[reportAttributeAccessIssue]
            self.req.setsockopt(zmq.RCVTIMEO, -1)  # type: ignore[reportAttributeAccessIssue]

        # Force kill if needed
        if self.worker_pid:
            try:
                os.kill(self.worker_pid, signal.SIGTERM)
                time.sleep(0.2)
                try:
                    os.kill(self.worker_pid, 0)  # Signal 0 only checks existence
                    os.kill(self.worker_pid, signal.SIGKILL)
                except ProcessLookupError:
                    pass
            except Exception:
                pass

        if self.worker_proc:
            try:
                self.worker_proc.terminate()
                self.worker_proc.wait(timeout=1)
            except Exception:
                try:
                    self.worker_proc.kill()
                except Exception:
                    pass

        # Reset state
        self.execution_count = 0
        self.worker_pid = None
        self.worker_proc = None

        # Recreate sockets
        self._reset_req_socket('[RESET]')

        # Reset special handler
        if self.special_handler is not None:
            self.special_handler = AsyncSpecialCommandHandler({"__name__": "__main__"})

        # Respawn worker
        try:
            self._ensure_worker()
        except Exception as e:
            print(f"[RESET] Error respawning worker: {e}", file=sys.stderr, flush=True)
@@ -0,0 +1,244 @@
1
+ import os
2
+ import sys
3
+ import time
4
+ import signal
5
+ import base64
6
+ import io
7
+ import traceback
8
+ import zmq
9
+ import matplotlib
10
+ matplotlib.use('Agg')
11
+ import matplotlib.pyplot as plt
12
+ import re
13
+
14
def _setup_signals():
    """Install the worker's signal handlers (best effort).

    SIGTERM flushes the standard streams and exits the process immediately
    with status 0; SIGINT is bound to Python's default handler, which raises
    ``KeyboardInterrupt``. Any failure while installing is ignored.
    """
    def _exit_on_term(signum, frame):
        # Flush what we can; a failing flush must not prevent the exit.
        try:
            for stream in (sys.stdout, sys.stderr):
                stream.flush()
        except Exception:
            pass
        os._exit(0)

    try:
        wanted = (
            (signal.SIGTERM, _exit_on_term),
            (signal.SIGINT, signal.default_int_handler),
        )
        for signum, handler in wanted:
            signal.signal(signum, handler)
    except Exception:
        pass
26
+
27
+
28
class _StreamForwarder:
    """Forward text written to stdout/stderr as JSON messages on ``pub``.

    Text is buffered per stream and published one complete line at a time as
    ``stream`` messages. A chunk containing a carriage return but no newline
    is published as a ``stream_update`` carrying only the text after the last
    carriage return.
    """

    def __init__(self, pub, cell_index):
        self.pub = pub
        self.cell_index = cell_index
        self.out_buf = []
        self.err_buf = []

    def write_out(self, text):
        """Handle a chunk written to stdout."""
        self._write('stdout', text)

    def write_err(self, text):
        """Handle a chunk written to stderr."""
        self._write('stderr', text)

    def _publish(self, kind, name, text):
        # Single place that shapes the outgoing message.
        self.pub.send_json({'type': kind, 'name': name, 'text': text, 'cell_index': self.cell_index})

    def _write(self, name, text):
        if not text:
            return
        if '\n' not in text and '\r' in text:
            self._publish('stream_update', name, text.rsplit('\r', 1)[-1])
            return
        pending = self.err_buf if name != 'stdout' else self.out_buf
        # Everything before the final '\n' completes a line; the remainder
        # (possibly empty) stays buffered until more text or a flush arrives.
        *finished, remainder = text.split('\n')
        for piece in finished:
            pending.append(piece)
            self._publish('stream', name, ''.join(pending) + '\n')
            pending.clear()
        pending.append(remainder)

    def flush(self):
        """Publish whatever is buffered for stdout, then stderr."""
        for name, pending in (('stdout', self.out_buf), ('stderr', self.err_buf)):
            if pending:
                self._publish('stream', name, ''.join(pending))
                pending.clear()
65
+
66
+
67
def _capture_matplotlib(pub, cell_index):
    """Publish every open matplotlib figure as a base64 PNG, then close all.

    Each figure becomes one ``display_data`` message on ``pub``. A figure
    that fails to render is skipped, and any other failure ends the capture
    silently.
    """
    try:
        for fig_id in plt.get_fignums():
            try:
                png = io.BytesIO()
                plt.figure(fig_id).savefig(png, format='png', bbox_inches='tight')
                encoded = base64.b64encode(png.getvalue()).decode('ascii')
                pub.send_json({'type': 'display_data', 'data': {'image/png': encoded}, 'cell_index': cell_index})
            except Exception:
                continue
        # Close the figures so they are not published again by the next cell.
        try:
            plt.close('all')
        except Exception:
            pass
    except Exception:
        return
86
+
87
+
88
def worker_main():
    """Run the worker loop: serve requests on REP, publish output on PUB.

    Binds a REP socket to ``MC_ZMQ_CMD_ADDR`` and a PUB socket to
    ``MC_ZMQ_PUB_ADDR`` (both environment variables are required) and then
    handles ``ping``, ``shutdown``, ``interrupt`` and ``execute_cell``
    messages until a shutdown has been requested. Cell code runs in one
    persistent namespace shared across cells. Every request receives exactly
    one reply, as the REQ/REP pattern requires.

    Raises:
        KeyError: If either address environment variable is missing.
    """
    import ast  # Local import: used only to split off a trailing expression.

    _setup_signals()
    cmd_addr = os.environ['MC_ZMQ_CMD_ADDR']
    pub_addr = os.environ['MC_ZMQ_PUB_ADDR']

    ctx = zmq.Context.instance()
    rep = ctx.socket(zmq.REP)
    rep.bind(cmd_addr)
    # Receive with a 100ms timeout so the loop can send heartbeats and
    # notice a requested shutdown while idle.
    rep.setsockopt(zmq.RCVTIMEO, 100)

    pub = ctx.socket(zmq.PUB)
    pub.bind(pub_addr)

    # Persistent REPL state
    g = {"__name__": "__main__"}
    l = g
    exec_count = 0

    last_hb = time.time()
    current_cell = None
    shutdown_requested = False

    while True:
        try:
            msg = rep.recv_json()
        except zmq.Again:
            # Timeout - check if we should send heartbeat
            if time.time() - last_hb > 5.0:
                pub.send_json({'type': 'heartbeat', 'ts': time.time()})
                last_hb = time.time()
            if shutdown_requested:
                break
            continue
        except Exception:
            if shutdown_requested:
                break
            continue

        # A REP socket must answer every request before it can receive the
        # next one, so malformed and unknown messages get an error reply
        # instead of being dropped.
        if not isinstance(msg, dict):
            rep.send_json({'ok': False, 'error': 'message must be a JSON object', 'pid': os.getpid()})
            continue

        mtype = msg.get('type')
        if mtype == 'ping':
            rep.send_json({'ok': True, 'pid': os.getpid()})
            continue
        if mtype == 'shutdown':
            rep.send_json({'ok': True, 'pid': os.getpid()})
            shutdown_requested = True
            # Don't break immediately - let the loop handle cleanup
            continue
        if mtype == 'interrupt':
            # This loop is single-threaded, so an interrupt request is only
            # ever read while no cell is running. Raising SIGINT here would
            # produce a KeyboardInterrupt that nothing catches and kill the
            # worker, so the request is simply acknowledged.
            rep.send_json({'ok': True, 'pid': os.getpid()})
            continue
        if mtype != 'execute_cell':
            rep.send_json({'ok': False, 'error': f'unknown message type: {mtype!r}', 'pid': os.getpid()})
            continue

        code = msg.get('code', '')
        cell_index = msg.get('cell_index')
        requested_count = msg.get('execution_count')
        current_cell = cell_index
        if isinstance(requested_count, int):
            exec_count = requested_count - 1
        command_type = msg.get('command_type')
        pub.send_json({'type': 'execution_start', 'cell_index': cell_index, 'execution_count': exec_count + 1})

        if command_type == 'special':
            # This path should be handled in-process by the client; the
            # worker only executes Python. Complete and reply exactly once.
            exec_count += 1
            pub.send_json({'type': 'execution_complete', 'cell_index': cell_index, 'result': {'status': 'ok', 'execution_count': exec_count, 'execution_time': '0.0ms', 'outputs': [], 'error': None}})
            rep.send_json({'ok': True})
            current_cell = None
            continue

        # Redirect streams
        sf = _StreamForwarder(pub, cell_index)
        old_out, old_err = sys.stdout, sys.stderr

        class _O:
            def write(self, t): sf.write_out(t)
            def flush(self): sf.flush()

        class _E:
            def write(self, t): sf.write_err(t)
            def flush(self): sf.flush()

        sys.stdout, sys.stderr = _O(), _E()
        status = 'ok'
        error_payload = None
        start = time.time()
        try:
            # Like Jupyter: if the cell ends with an expression, run the rest
            # as statements and evaluate that expression once for display.
            # Splitting the AST avoids both evaluating it a second time and
            # guessing from the text of the last line.
            tree = ast.parse(code, '<cell>', 'exec')
            tail = None
            if tree.body and isinstance(tree.body[-1], ast.Expr):
                tail = ast.Expression(tree.body.pop().value)
            # NOTE: exec/eval of client-supplied code is this worker's purpose.
            exec(compile(tree, '<cell>', 'exec'), g, l)
            if tail is not None:
                res = eval(compile(tail, '<cell>', 'eval'), g, l)
                if res is not None:
                    pub.send_json({'type': 'execute_result', 'cell_index': cell_index, 'execution_count': exec_count + 1, 'data': {'text/plain': repr(res)}})

            _capture_matplotlib(pub, cell_index)
        except KeyboardInterrupt:
            status = 'error'
            error_payload = {'ename': 'KeyboardInterrupt', 'evalue': 'Execution interrupted by user', 'traceback': []}
        except Exception as exc:
            status = 'error'
            error_payload = {'ename': type(exc).__name__, 'evalue': str(exc), 'traceback': traceback.format_exc().split('\n')}
        finally:
            # Publish output that did not end with a newline before the
            # streams are restored; otherwise it would be lost.
            try:
                sf.flush()
            except Exception:
                pass
            sys.stdout, sys.stderr = old_out, old_err
            exec_count += 1
            duration_ms = f"{(time.time()-start)*1000:.1f}ms"
            if error_payload:
                pub.send_json({'type': 'execution_error', 'cell_index': cell_index, 'error': error_payload})
            pub.send_json({'type': 'execution_complete', 'cell_index': cell_index, 'result': {'status': status, 'execution_count': exec_count, 'execution_time': duration_ms, 'outputs': [], 'error': error_payload}})
            rep.send_json({'ok': True, 'pid': os.getpid()})
            current_cell = None

    try:
        rep.close(0); pub.close(0)
    except Exception:
        pass
    try:
        ctx.term()
    except Exception:
        pass


# Start the worker loop only when run as a script or with ``python -m``.
if __name__ == '__main__':
    worker_main()