interruptible-threading 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,859 @@
1
+ """
2
+ Cooperative thread interruption for CPython.
3
+
4
+ Exposes an :class:`InterruptibleThread` whose :meth:`~InterruptibleThread.interrupt`
5
+ method raises an exception (``ThreadInterrupted`` by default) inside the target
6
+ thread. Pure-Python code is interrupted via ``PyThreadState_SetAsyncExc`` (an async
7
+ exception that fires at the next bytecode boundary). Because async exceptions cannot
8
+ break a thread that is parked in a C-level blocking call, ``install_patches()``
9
+ monkeypatches a curated set of stdlib blocking primitives (``time.sleep``,
10
+ ``selectors.DefaultSelector``, ``select.select``, ``threading.Condition.wait``) so
11
+ they wake promptly and re-check a durable per-thread "interrupt pending" flag.
12
+
13
+ The flag is the single source of truth: ``interrupt()`` sets it first (under one
14
+ lock) and then issues a wakeup nudge; every blocking primitive checks the flag
15
+ before parking and again after waking. This closes the races inherent in choosing a
16
+ delivery path from transient state, and lets interrupts be *masked* during critical
17
+ sections (``critical_section()``) and *polled* in CPU-bound loops (``check_interrupt()``).
18
+
19
+ Requires monkey patching some stdlib functionality via
20
+ ``InterruptibleThread.install_patches()``. CPython only; POSIX (Linux / Darwin).
21
+
22
+ Why not signals: ``signal.pthread_kill`` can unblock a syscall but cannot deliver an
23
+ exception to a worker thread -- CPython runs Python-level signal handlers only on the
24
+ main thread, and PEP 475's EINTR auto-retry loops call ``PyErr_CheckSignals()``
25
+ (a no-op off the main thread) without consulting ``tstate->async_exc``, so the
26
+ syscall is transparently retried. The self-pipe + cooperative-primitive approach is
27
+ the only way to get prompt, exception-bearing interruption of worker threads.
28
+ """
29
+ from __future__ import annotations
30
+
31
+ import contextlib
32
+ import ctypes
33
+ import os
34
+ import select
35
+ import selectors
36
+ import socket as _socket
37
+ import sys
38
+ import threading
39
+ import time
40
+ from typing import TYPE_CHECKING, Any
41
+
42
+ if TYPE_CHECKING:
43
+ from collections.abc import Iterator
44
+ from io import IOBase
45
+
46
__all__ = [
    # The thread class and the exception it delivers.
    "InterruptibleThread",
    "ThreadInterrupted",
    # Inspecting, consuming and polling the per-thread interrupt flag.
    "is_interrupted",
    "clear_interrupt",
    "check_interrupt",
    "interruptible_checkpoint",
    "periodic_checkpoint",
    # Deferring interrupt delivery around critical code.
    "critical_section",
    "interrupts_disabled",
    # Opt-in interruptible socket helpers.
    "interruptible_recv",
    "interruptible_send",
    "interruptible_accept",
    # Tuning knob for the chunked-poll primitives.
    "set_poll_interval",
]
61
+
62
+
63
class ThreadInterrupted(BaseException):
    """Exception delivered into a thread when another thread calls ``.interrupt()`` on it.

    It derives from ``BaseException`` rather than ``Exception``, so generic
    ``except Exception:`` handlers let it propagate. This is the same design
    ``KeyboardInterrupt`` and ``asyncio.CancelledError`` follow.
    """

    pass
70
+
71
+
72
class _ThreadInterruptedKeyboard(ThreadInterrupted, KeyboardInterrupt):
    """Interrupt exception matched by both ``except ThreadInterrupted`` and
    ``except KeyboardInterrupt`` handlers.

    ``install_patches(legacy_keyboardinterrupt=True)`` selects this class as the
    exception that ``interrupt()`` delivers.
    """

    pass
75
+
76
+
77
# Exception class raised by interrupt(); install_patches() may replace it and
# uninstall_patches() restores this default.
_INTERRUPT_EXC: type[BaseException] = ThreadInterrupted

# ctypes binding for the C API call
#   int PyThreadState_SetAsyncExc(unsigned long id, PyObject *exc)
# which arms (or, with a NULL exc, clears) an asynchronous exception on a thread.
_PTSSE = ctypes.pythonapi.PyThreadState_SetAsyncExc
_PTSSE.restype = ctypes.c_int
_PTSSE.argtypes = [ctypes.c_ulong, ctypes.py_object]

# Ident of the main thread; the patched primitives short-circuit to the
# original implementation when running on it.
_MAIN_IDENT = threading.main_thread().ident

# Originals captured at import time, before install_patches() swaps them out,
# so the patched versions and uninstall_patches() can still reach them.
_ORIG_THREAD = threading.Thread
_ORIG_SLEEP = time.sleep
_ORIG_SELECT = select.select
_ORIG_DEFAULT_SELECTOR = selectors.DefaultSelector
_ORIG_COND_WAIT = threading.Condition.wait

# Marker stored as the selector-key ``data`` of the per-thread wakeup pipe, so
# _InterruptibleSelector.select() can recognize and filter out its events.
_WAKEUP_TOKEN = "__interruptible_wakeup__"

# Upper bound (seconds) on how long a chunked-poll primitive (Condition.wait and
# everything built on it, such as Event and Queue) parks before re-checking for
# a pending interrupt. Tunable via set_poll_interval().
_POLL_INTERVAL = 0.05
95
+
96
+
97
def set_poll_interval(seconds: float) -> None:
    """Tune the wakeup latency of the chunked-poll blocking primitives.

    ``seconds`` is the longest a patched ``Condition.wait`` (and therefore
    ``Event.wait`` / ``queue.Queue`` waits built on it) stays parked before it
    re-checks for a pending interrupt.

    Raises:
        ValueError: if ``seconds`` is not a finite, strictly positive number.
            ``Condition.wait`` treats a zero, negative or NaN timeout as a
            non-blocking poll, which would turn every chunked wait into a busy
            loop. Infinity is rejected by the underlying lock timeout.
    """
    global _POLL_INTERVAL
    value = float(seconds)
    # The chained comparison is False for NaN as well as for out-of-range values.
    if not 0.0 < value < float("inf"):
        raise ValueError(
            f"poll interval must be a finite positive number, got {seconds!r}"
        )
    _POLL_INTERVAL = value
101
+
102
+
103
def _clear_async_exc(tid: int | None = None) -> None:
    """Disarm any pending async exception for thread ``tid`` (default: caller).

    The cooperative delivery paths call this just before raising themselves, so
    that an exception previously armed by ``interrupt()`` is not delivered a
    second time at the next bytecode boundary.
    """
    target = threading.get_ident() if tid is None else tid
    # ctypes.py_object() with no argument wraps a NULL pointer, and a NULL exc
    # tells SetAsyncExc to *clear*. Passing Python None would instead arm None
    # as the exception (surfacing later as a SystemError).
    null_exc = ctypes.py_object()
    _PTSSE(ctypes.c_ulong(target), null_exc)
116
+
117
+
118
+ def _tracing_active() -> bool:
119
+ """Whether a trace/profile hook is installed (coverage, a debugger, a
120
+ profiler) anywhere we can observe.
121
+
122
+ ``PyThreadState_SetAsyncExc`` can permanently wedge a thread that is doing
123
+ lock operations while such a hook is active (a CPython interaction), so async
124
+ injection -- the only delivery path that uses it -- is skipped while tracing.
125
+ Cooperative delivery via the durable ``pending`` flag is unaffected.
126
+ """
127
+ if sys.gettrace() is not None or sys.getprofile() is not None:
128
+ return True
129
+ # threading.gettrace/getprofile (3.10+) expose the global hook coverage
130
+ # installs for worker threads; guarded for 3.9 where they don't exist.
131
+ for name in ("gettrace", "getprofile"):
132
+ getter = getattr(threading, name, None)
133
+ if getter is not None and getter() is not None:
134
+ return True
135
+ return False
136
+
137
+
138
+ def _drain(fd: int) -> None:
139
+ """Drain a non-blocking wakeup pipe until empty."""
140
+ while True:
141
+ try:
142
+ if not os.read(fd, 65536):
143
+ break
144
+ except (BlockingIOError, OSError):
145
+ break
146
+
147
+
148
+ class _State:
149
+ registry_lock = threading.RLock()
150
+ registry: dict[int, _State] = {}
151
+
152
+ def __init__(self) -> None:
153
+ cur_thread = threading.current_thread()
154
+ if not isinstance(cur_thread, InterruptibleThread):
155
+ raise TypeError(
156
+ f"current thread should be of type {InterruptibleThread.__name__}"
157
+ )
158
+ self.thread: InterruptibleThread = cur_thread
159
+ self.children: set[int] = set()
160
+ # cancel_cond's lock is the single per-thread mutex guarding all the
161
+ # mutable interrupt state below.
162
+ self.cancel_cond = threading.Condition()
163
+ # Durable source of truth: an interrupt has been requested, not yet delivered.
164
+ self.pending = False
165
+ self.interrupt_gen = 0
166
+ self.mask_depth = 0
167
+ # True only while an async exception has actually been armed for this
168
+ # thread via PyThreadState_SetAsyncExc. Gates the ctypes "clear" call so
169
+ # the cooperative paths never touch SetAsyncExc -- calling it while a
170
+ # sys.settrace tracer (e.g. coverage) is active can deadlock the thread.
171
+ self.async_armed = False
172
+ # Hints used only to pick a wakeup nudge; never the source of truth.
173
+ self.sleeping = False
174
+ self.selecting = False
175
+ # asyncio integration (set by run_interruptible).
176
+ self.event_loop: Any = None
177
+ self.root_task: Any = None
178
+ # Lazily-allocated self-pipe; only needed on the selector/select path.
179
+ self.rfd = -1
180
+ self.wfd = -1
181
+
182
+ def ensure_pipe(self) -> None:
183
+ """Allocate the self-pipe on first use (idempotent)."""
184
+ with self.cancel_cond:
185
+ if self.rfd == -1:
186
+ r, w = os.pipe()
187
+ os.set_blocking(r, False)
188
+ # The write end must also be non-blocking: ``_pipe_write`` runs
189
+ # while holding ``cancel_cond``, and a blocking ``os.write`` to a
190
+ # full pipe would deadlock. A full pipe already means a wakeup is
191
+ # pending, so dropping the extra byte (EAGAIN) is correct.
192
+ os.set_blocking(w, False)
193
+ self.rfd, self.wfd = r, w
194
+
195
+ @classmethod
196
+ def register_current_thread(cls) -> _State:
197
+ tid = threading.get_ident()
198
+ with cls.registry_lock:
199
+ st = cls.registry.get(tid)
200
+ if st is None:
201
+ st = cls()
202
+ cls.registry[tid] = st
203
+ return st
204
+
205
+ @classmethod
206
+ def get_state_by_ident(cls, tid: int | None = None) -> _State | None:
207
+ return cls.registry.get(tid if tid is not None else threading.get_ident())
208
+
209
+ @classmethod
210
+ def unregister_current_thread(cls) -> None:
211
+ tid = threading.get_ident()
212
+ with cls.registry_lock:
213
+ st = cls.registry.pop(tid, None)
214
+ if not st:
215
+ return
216
+ for fd in (st.rfd, st.wfd):
217
+ if fd == -1:
218
+ continue
219
+ try:
220
+ os.close(fd)
221
+ except OSError:
222
+ pass
223
+ st.rfd = st.wfd = -1
224
+
225
+
226
+ def _disarm_async(st: _State) -> None:
227
+ """Clear an armed async exception for the current thread, but only if one was
228
+ actually armed. ``PyThreadState_SetAsyncExc`` must not be called speculatively:
229
+ invoking it while a ``sys.settrace`` tracer (e.g. coverage) is active can wedge
230
+ the thread, so the cooperative paths -- which never arm one -- must skip it.
231
+
232
+ Must be called holding ``st.cancel_cond``.
233
+ """
234
+ if st.async_armed:
235
+ st.async_armed = False
236
+ _clear_async_exc()
237
+
238
+
239
+ def _take_pending(st: _State) -> bool:
240
+ """If an unmasked interrupt is pending, consume it and clear any armed async
241
+ exception, returning True (caller should raise). Otherwise return False.
242
+
243
+ Must be called holding ``st.cancel_cond``.
244
+ """
245
+ if st.pending and st.mask_depth == 0:
246
+ st.pending = False
247
+ _disarm_async(st)
248
+ return True
249
+ return False
250
+
251
+
252
+ # --------------------------------------------------------------------------- #
253
+ # Public cooperative API (operates on the *current* thread) #
254
+ # --------------------------------------------------------------------------- #
255
+
256
+
257
def is_interrupted(thread: InterruptibleThread | None = None) -> bool:
    """Return whether an interrupt is pending for ``thread`` (default: current thread).

    Non-consuming. Returns ``False`` in three cases:
    - the thread was never started;
    - the thread has already exited;
    - the thread's ident now belongs to a different thread.
    """
    if thread is None:
        st = _State.get_state_by_ident()
    else:
        tid = thread.ident
        if tid is None:
            # Never started. Passing None on would make get_state_by_ident()
            # fall back to the *calling* thread and report its flag instead.
            return False
        st = _State.get_state_by_ident(tid)
        if st is not None and st.thread is not thread:
            # ``thread`` exited and a newer thread reused its ident.
            return False
    if st is None:
        return False
    with st.cancel_cond:
        return st.pending
268
+
269
+
270
def clear_interrupt() -> bool:
    """Consume the current thread's pending interrupt, if any, without raising.

    Returns whether one was pending, like Java's ``Thread.interrupted()``.

    The pending flag is durable: it survives so that an interrupt is never lost
    when the thread parks in a blocking call before async injection can fire.
    A thread that *catches* ``ThreadInterrupted`` and intends to keep running
    must therefore call this. Otherwise its next checkpoint or blocking
    primitive raises again.
    """
    state = _State.get_state_by_ident()
    if state is None:
        return False
    with state.cancel_cond:
        was_pending, state.pending = state.pending, False
        if was_pending:
            _disarm_async(state)
    return was_pending
289
+
290
+
291
def check_interrupt() -> None:
    """Raise the interrupt exception now if one is pending and not masked.

    Cheap enough for CPU-bound loops and for bracketing opaque C calls, making
    such code interruptible while still honoring ``critical_section()``.
    Threads without interrupt state return immediately.
    """
    state = _State.get_state_by_ident()
    if state is not None:
        with state.cancel_cond:
            consumed = _take_pending(state)
        if consumed:
            raise _INTERRUPT_EXC()


# Alias for check_interrupt.
interruptible_checkpoint = check_interrupt
307
+
308
+
309
@contextlib.contextmanager
def periodic_checkpoint(every: int = 1000) -> Iterator[_Ticker]:
    """Provide a ticker whose ``.tick()`` runs :func:`check_interrupt` once per
    ``every`` calls, spreading the check's cost across a tight loop::

        with periodic_checkpoint(every=1000) as ck:
            for item in huge_iterable:
                ck.tick()
                crunch(item)
    """
    ticker = _Ticker(every)
    yield ticker
320
+
321
+
322
+ class _Ticker:
323
+ __slots__ = ("_every", "_n")
324
+
325
+ def __init__(self, every: int) -> None:
326
+ self._every = max(1, int(every))
327
+ self._n = 0
328
+
329
+ def tick(self) -> bool:
330
+ """Count one iteration; every Nth call runs check_interrupt(). Returns
331
+ whether the check ran this call."""
332
+ self._n += 1
333
+ if self._n >= self._every:
334
+ self._n = 0
335
+ check_interrupt()
336
+ return True
337
+ return False
338
+
339
+
340
@contextlib.contextmanager
def interrupts_disabled() -> Iterator[None]:
    """Hold back interrupt delivery to the current thread while the block runs.

    An interrupt arriving while masked is recorded and raised exactly once, when
    the outermost mask exits; nesting is allowed. This is reliable for the
    flag-driven paths (sleep, select, Condition, checkpoints). An async
    exception already injected via ``SetAsyncExc`` just before entering cannot
    be recalled.
    """
    state = _State.get_state_by_ident()
    if state is None:
        yield
        return
    with state.cancel_cond:
        state.mask_depth += 1
    try:
        yield
    finally:
        with state.cancel_cond:
            state.mask_depth -= 1
            # Delivers only once the outermost mask is gone (mask_depth == 0).
            deliver = _take_pending(state)
        if deliver:
            raise _INTERRUPT_EXC()


# Alias for interrupts_disabled.
critical_section = interrupts_disabled
371
+
372
+
373
+ # --------------------------------------------------------------------------- #
374
+ # Patched blocking primitives #
375
+ # --------------------------------------------------------------------------- #
376
+
377
+
378
def _coop_sleep(secs: float) -> None:
    """Interruptible replacement for :func:`time.sleep`, installed by ``install_patches()``.

    On the main thread, or on a thread without interrupt state, this is exactly
    the original ``time.sleep``. Otherwise it parks on the thread's
    ``cancel_cond`` so ``interrupt()`` can wake it, and it raises the interrupt
    exception once an unmasked interrupt is pending. A masked pending interrupt
    does not cut the sleep short.

    Raises:
        ValueError / TypeError: for the same invalid arguments ``time.sleep``
            rejects (negative, NaN, non-numeric).
    """
    tid = threading.get_ident()
    if tid == _MAIN_IDENT:
        return _ORIG_SLEEP(secs)
    st = _State.get_state_by_ident(tid)
    if st is None:
        return _ORIG_SLEEP(secs)
    try:
        valid = secs >= 0
    except TypeError:
        valid = False
    if not valid:
        # Negative, NaN (all comparisons are False) or non-numeric: let the real
        # time.sleep raise its usual error. Otherwise a negative value would
        # silently return, and NaN would spin the loop below forever: the
        # remaining time never satisfies ``<= 0``, and a NaN cond-wait timeout
        # is a non-blocking poll.
        return _ORIG_SLEEP(secs)
    deadline = time.monotonic() + secs
    with st.cancel_cond:
        while True:
            if _take_pending(st):
                raise _INTERRUPT_EXC()
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                break
            st.sleeping = True
            try:
                _ORIG_COND_WAIT(st.cancel_cond, remaining)
            finally:
                st.sleeping = False
            # Loop: an unmasked pending flag raises at the top; a masked one
            # keeps sleeping out the remaining time; spurious wakeups re-park.
    if secs == 0:
        # time.sleep(0) is the conventional "yield the GIL to other threads"
        # idiom; keep that effect (called after releasing cancel_cond).
        _ORIG_SLEEP(0)
400
+
401
+
402
def _patched_cond_wait(self: threading.Condition, timeout: float | None = None) -> bool:
    """Interruptible replacement for :meth:`threading.Condition.wait`.

    On the main thread, or on a thread without interrupt state, this is the
    original ``wait``. Otherwise it waits in slices of at most
    ``_POLL_INTERVAL`` seconds and re-checks the pending-interrupt flag between
    slices. It returns True when notified and False once ``timeout`` elapses.
    """
    ident = threading.get_ident()
    state = None if ident == _MAIN_IDENT else _State.get_state_by_ident(ident)
    if state is None:
        return _ORIG_COND_WAIT(self, timeout)
    deadline = None if timeout is None else time.monotonic() + timeout
    while True:
        with state.cancel_cond:
            if _take_pending(state):
                raise _INTERRUPT_EXC()
        if deadline is None:
            slice_len = _POLL_INTERVAL
        else:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                return False
            slice_len = min(_POLL_INTERVAL, remaining)
        if _ORIG_COND_WAIT(self, slice_len):
            return True
        # Slice timed out: go around and look at the pending flag again.
424
+
425
+
426
class _InterruptibleSelector(_ORIG_DEFAULT_SELECTOR):
    """DefaultSelector that also watches the creating thread's self-pipe.

    An ``interrupt()`` can therefore wake a parked ``select``; this is what
    makes asyncio event loops interruptible. The pipe is registered on the
    selector itself with ``data=_WAKEUP_TOKEN``, so it also appears in
    ``get_map()``; ``select()`` filters its events out of the result.
    """

    def __init__(self) -> None:
        super().__init__()
        self._ithr_st = state = _State.get_state_by_ident()
        if state is None:
            return
        state.ensure_pipe()
        with contextlib.suppress(KeyError, ValueError):
            super().register(state.rfd, selectors.EVENT_READ, data=_WAKEUP_TOKEN)

    def select(self, timeout: float | None = None):
        """Like ``DefaultSelector.select``, but raise the interrupt exception if
        an unmasked interrupt is pending before or after waiting. Wakeup-pipe
        events are drained and never returned."""
        state = getattr(self, "_ithr_st", None)
        if state is None:
            return super().select(timeout)
        with state.cancel_cond:
            if _take_pending(state):
                raise _INTERRUPT_EXC()
            state.selecting = True
        try:
            ready = super().select(timeout)
        finally:
            with state.cancel_cond:
                state.selecting = False
        user_events = []
        for key, mask in ready:
            if key.data == _WAKEUP_TOKEN:
                _drain(key.fd)
            else:
                user_events.append((key, mask))
        with state.cancel_cond:
            if _take_pending(state):
                raise _INTERRUPT_EXC()
        return user_events
464
+
465
+
466
+ def _fileno(x: IOBase | int) -> int:
467
+ try:
468
+ return x.fileno() # type: ignore[union-attr]
469
+ except AttributeError:
470
+ return int(x) # type: ignore[arg-type]
471
+
472
+
473
def _patched_select(
    rlist: list[IOBase | int],
    wlist: list[IOBase | int],
    xlist: list[IOBase | int],
    timeout: float | None = None,
) -> tuple[list[IOBase | int], list[IOBase | int], list[IOBase | int]]:
    """Interruptible replacement for :func:`select.select`, installed by ``install_patches()``.

    On the main thread, or on a thread without interrupt state, this is exactly
    the original ``select.select``. Otherwise:

    - The thread's self-pipe is added to the read set, so ``interrupt()`` can
      wake the wait.
    - The interrupt exception is raised if an unmasked interrupt is pending
      before or after waiting.
    - Results contain the caller's original objects, never the pipe.

    A wakeup that delivers nothing does not end the call early. That covers a
    masked interrupt, or a stale nudge byte left behind by an interrupt already
    consumed elsewhere. Instead the wait resumes for whatever remains of
    ``timeout``, so with ``timeout=None`` the call still returns only once one
    of the caller's descriptors is ready.
    """
    tid = threading.get_ident()
    if tid == _MAIN_IDENT:
        return _ORIG_SELECT(rlist, wlist, xlist, timeout)
    st = _State.get_state_by_ident(tid)
    if st is None:
        return _ORIG_SELECT(rlist, wlist, xlist, timeout)

    st.ensure_pipe()
    rfd = st.rfd

    # Resolve each object to its fd once, remembering the original object so
    # results can be mapped back (and inputs are iterated exactly once).
    rmap = [(_fileno(o), o) for o in rlist]
    wmap = [(_fileno(o), o) for o in wlist]
    xmap = [(_fileno(o), o) for o in xlist]

    rfd_list = [fd for fd, _ in rmap]
    wfd_list = [fd for fd, _ in wmap]
    xfd_list = [fd for fd, _ in xmap]
    if rfd not in rfd_list:
        rfd_list.append(rfd)

    def map_back(
        fd_list: list[int], fmap: list[tuple[int, IOBase | int]]
    ) -> list[IOBase | int]:
        fds = set(fd_list)
        return [obj for fd, obj in fmap if fd in fds]

    deadline = None if timeout is None else time.monotonic() + timeout
    # The first wait uses the caller's timeout unchanged, so invalid values
    # (negative, NaN) are rejected by select.select exactly as before.
    wait = timeout
    while True:
        with st.cancel_cond:
            if _take_pending(st):
                raise _INTERRUPT_EXC()
            st.selecting = True
        try:
            rr, ww, xx = _ORIG_SELECT(rfd_list, wfd_list, xfd_list, wait)
        finally:
            with st.cancel_cond:
                st.selecting = False

        woke = rfd in rr
        if woke:
            _drain(rfd)
        with st.cancel_cond:
            if _take_pending(st):
                raise _INTERRUPT_EXC()

        result = (map_back(rr, rmap), map_back(ww, wmap), map_back(xx, xmap))
        if not woke or any(result):
            # Caller fds ready, or a genuine timeout.
            return result
        # Only the wakeup pipe fired and nothing was delivered: keep waiting.
        if deadline is not None:
            wait = deadline - time.monotonic()
            if wait <= 0:
                return result
530
+
531
+
532
+ # --------------------------------------------------------------------------- #
533
+ # Interruptible socket helpers (opt-in; reuse the self-pipe directly) #
534
+ # --------------------------------------------------------------------------- #
535
+
536
+
537
def _interruptible_io(sock: _socket.socket, want_write: bool, op):
    """Run the socket operation ``op`` without parking uninterruptibly.

    ``op`` is simply called in three cases:
    - on the main thread;
    - on threads without interrupt state;
    - for sockets already in non-blocking mode (``gettimeout() == 0.0``),
      whose callers expect ``BlockingIOError`` rather than a wait.

    Otherwise the socket is made non-blocking for the duration, and the thread
    waits in ``select`` on both the socket and its self-pipe. ``op`` is retried
    whenever it reports ``BlockingIOError``/``InterruptedError``. The socket's
    own timeout is honored: if it elapses first, ``socket.timeout``
    ("timed out") is raised, as the blocking call would. The original timeout
    is restored on exit.

    Args:
        sock: the socket ``op`` operates on.
        want_write: wait for writability instead of readability.
        op: zero-argument callable performing the actual socket call.

    Raises:
        The interrupt exception, if an unmasked interrupt is pending before or
        after a wait.
    """
    tid = threading.get_ident()
    st = _State.get_state_by_ident(tid)
    if st is None or tid == _MAIN_IDENT:
        return op()
    prev_timeout = sock.gettimeout()
    if prev_timeout == 0.0:
        return op()
    deadline = None if prev_timeout is None else time.monotonic() + prev_timeout
    st.ensure_pipe()
    sock.setblocking(False)
    try:
        while True:
            wait = None if deadline is None else max(0.0, deadline - time.monotonic())
            with st.cancel_cond:
                if _take_pending(st):
                    raise _INTERRUPT_EXC()
                st.selecting = True
            try:
                if want_write:
                    _rr, ww, _xx = _ORIG_SELECT([st.rfd], [sock], [], wait)
                    ready = sock in ww
                else:
                    rr, _ww, _xx = _ORIG_SELECT([sock, st.rfd], [], [], wait)
                    ready = sock in rr
            finally:
                with st.cancel_cond:
                    st.selecting = False
            _drain(st.rfd)
            with st.cancel_cond:
                if _take_pending(st):
                    raise _INTERRUPT_EXC()
            if ready:
                try:
                    return op()
                except (BlockingIOError, InterruptedError):
                    continue
            elif deadline is not None and time.monotonic() >= deadline:
                raise _socket.timeout("timed out")
    finally:
        sock.settimeout(prev_timeout)
572
+
573
+
574
def interruptible_recv(sock: _socket.socket, bufsize: int, flags: int = 0) -> bytes:
    """Receive like ``sock.recv(bufsize, flags)``.

    Raises the interrupt exception instead if this thread is interrupted while
    waiting for data.
    """

    def _recv_once() -> bytes:
        return sock.recv(bufsize, flags)

    return _interruptible_io(sock, want_write=False, op=_recv_once)
577
+
578
+
579
def interruptible_send(sock: _socket.socket, data: bytes, flags: int = 0) -> int:
    """Send like ``sock.send(data, flags)``.

    Raises the interrupt exception instead if this thread is interrupted while
    waiting for the socket to become writable.
    """

    def _send_once() -> int:
        return sock.send(data, flags)

    return _interruptible_io(sock, want_write=True, op=_send_once)
582
+
583
+
584
def interruptible_accept(sock: _socket.socket):
    """Accept like ``sock.accept()``, returning its ``(conn, address)`` pair.

    Raises the interrupt exception instead if this thread is interrupted while
    waiting for a connection.
    """
    conn, addr = _interruptible_io(sock, False, sock.accept)
    # The wait above may run sock.accept() while the listener is temporarily
    # non-blocking. socket.accept() forces the new connection into blocking mode
    # only when the listener has a non-zero timeout, and on some platforms
    # (e.g. BSD/macOS) the accepted fd inherits O_NONBLOCK. Re-apply CPython's
    # rule using the listener's real, now restored, timeout. This is idempotent
    # when accept() already ran in the original mode.
    if _socket.getdefaulttimeout() is None and sock.gettimeout() != 0.0:
        conn.setblocking(True)
    return conn, addr
587
+
588
+
589
# Unpatched socket methods, captured at import time. The patches below must
# call these directly: once install_patches(monkeypatch_socket=True) runs,
# ``sock.recv`` / ``sock.send`` / ``sock.accept`` resolve to the patches
# themselves. Routing through interruptible_recv/send/accept (which call those
# attributes) therefore recursed without bound.
_ORIG_SOCK_RECV = _socket.socket.recv
_ORIG_SOCK_SEND = _socket.socket.send
_ORIG_SOCK_ACCEPT = _socket.socket.accept


def _sock_recv_patch(self, bufsize, flags=0):
    """Process-wide interruptible replacement for ``socket.socket.recv``."""
    return _interruptible_io(self, False, lambda: _ORIG_SOCK_RECV(self, bufsize, flags))


def _sock_send_patch(self, data, flags=0):
    """Process-wide interruptible replacement for ``socket.socket.send``."""
    return _interruptible_io(self, True, lambda: _ORIG_SOCK_SEND(self, data, flags))


def _sock_accept_patch(self):
    """Process-wide interruptible replacement for ``socket.socket.accept``."""
    conn, addr = _interruptible_io(self, False, lambda: _ORIG_SOCK_ACCEPT(self))
    # The original accept may have run while the listener was temporarily
    # non-blocking, which skips CPython's "force the accepted socket blocking"
    # step. Re-apply it with the listener's restored timeout.
    if _socket.getdefaulttimeout() is None and self.gettimeout() != 0.0:
        conn.setblocking(True)
    return conn, addr
604
+
605
+
606
+ # --------------------------------------------------------------------------- #
607
+ # The thread class #
608
+ # --------------------------------------------------------------------------- #
609
+
610
+
611
class InterruptibleThread(_ORIG_THREAD):
    """Thread with a built-in cooperative :meth:`interrupt`.

    While :meth:`run` executes, the thread has a registered ``_State``. That
    state holds the thread's durable ``pending`` flag, its mask depth and its
    wakeup channels. :meth:`interrupt` sets the flag and then nudges whichever
    blocking primitive the thread may be parked in.
    """

    # Process-wide record of what install_patches() has applied.
    _patches_installed = False
    _socket_patched = False

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Creator's ident. run() links this thread into the creator's
        # ``children`` so that interrupt(recursive=True) can reach it.
        self._parent_tid = threading.get_ident()

    def run(self) -> None:
        """Register interrupt state, run the target, then tear the state down."""
        with _State.registry_lock:
            _State.register_current_thread()
            parent_st = _State.get_state_by_ident(self._parent_tid)
            if parent_st is not None:
                parent_st.children.add(threading.get_ident())
        try:
            super().run()
        finally:
            my_tid = threading.get_ident()
            with _State.registry_lock:
                parent_st = _State.get_state_by_ident(self._parent_tid)
                if parent_st is not None:
                    parent_st.children.discard(my_tid)
            st = _State.get_state_by_ident(my_tid)
            if st is not None:
                # Unregister under the per-thread lock. interrupt() re-checks
                # registration while holding it, so it either finishes first or
                # sees the thread as gone and never touches the closed pipe.
                with st.cancel_cond:
                    _State.unregister_current_thread()

    @classmethod
    def run_interruptible(cls, coro):
        """Run a coroutine via ``asyncio.run`` so that ``interrupt()`` cancels the
        event loop's root task cooperatively, giving a clean ``finally`` /
        ``async with`` unwind, instead of injecting an exception into the
        selector. Returns the coroutine's result; re-raises the interrupt
        exception if interrupted."""
        import asyncio

        st = _State.get_state_by_ident()

        async def _runner():
            if st is not None:
                with st.cancel_cond:
                    st.event_loop = asyncio.get_running_loop()
                    st.root_task = asyncio.current_task()
            try:
                return await coro
            finally:
                if st is not None:
                    with st.cancel_cond:
                        st.event_loop = None
                        st.root_task = None

        try:
            return asyncio.run(_runner())
        except asyncio.CancelledError:
            if st is not None:
                with st.cancel_cond:
                    if _take_pending(st):
                        raise _INTERRUPT_EXC() from None
            raise

    def _live_state(self) -> _State | None:
        """Return this thread's registered state, or None if it is not running.

        "Not running" means: never started (``ident`` is None), already exited,
        or its ident now belongs to another thread. Without the ``ident is None``
        check, ``get_state_by_ident(None)`` would return the *caller's* state.
        """
        tid = self.ident
        if tid is None:
            return None
        st = _State.get_state_by_ident(tid)
        if st is None or st.thread is not self:
            return None
        return st

    def _inject_exc(self) -> None:
        """Arm the interrupt exception as an async exception on this thread.

        Raises:
            ValueError: if CPython reports no thread with this ident.
            SystemError: if more than one thread state was affected; the arm is
                reverted first.
        """
        tid = ctypes.c_ulong(self.ident or 0)
        rc = _PTSSE(tid, ctypes.py_object(_INTERRUPT_EXC))
        if rc == 0:
            raise ValueError("no such thread")
        elif rc > 1:
            _PTSSE(tid, ctypes.py_object())
            raise SystemError("SetAsyncExc affected multiple threads")

    @staticmethod
    def _pipe_write(st: _State) -> None:
        """Best-effort wakeup byte on the self-pipe.

        A full pipe (EAGAIN) already implies a pending wakeup, so write errors
        are ignored.
        """
        if st.wfd != -1:
            try:
                os.write(st.wfd, b"\x00")
            except OSError:
                pass

    def _nudge(self, st: _State) -> None:
        """Wake the target out of any current park without arming an async exception.

        Must be called holding ``st.cancel_cond``.
        """
        st.cancel_cond.notify_all()
        self._pipe_write(st)

    def _cancel_event_loop(self, st: _State) -> None:
        """Cancel the root task of the thread's running asyncio loop.

        Must be called holding ``st.cancel_cond``.
        """
        loop, task = st.event_loop, st.root_task

        def _cancel():
            if task is not None and not task.done():
                task.cancel()

        try:
            loop.call_soon_threadsafe(_cancel)
        except RuntimeError:
            # Loop already closed: fall back to the generic delivery path. Wake
            # any cooperative park, and arm an async exception (unless tracing,
            # or this is a self-interrupt). Record the arm so the cooperative
            # consumer disarms it and delivery stays exactly-once.
            self._nudge(st)
            if not _tracing_active() and self.ident != threading.get_ident():
                self._inject_exc()
                st.async_armed = True

    def interrupt(self, recursive: bool = False) -> None:
        """Request that this thread raise the interrupt exception.

        First the thread's durable ``pending`` flag is set. Then the thread is
        woken according to its state:

        - Inside :meth:`run_interruptible`: the asyncio root task is cancelled.
        - Masked: the thread is only nudged; delivery happens when the mask
          exits.
        - Otherwise: the cooperative wakeup channels are poked. Unless a
          trace/profile hook is active, an async exception is also armed to
          break pure-Python loops.

        Repeated calls while an interrupt is already pending only re-nudge.
        Interrupting the calling thread itself (unmasked, outside an event
        loop) raises the interrupt exception synchronously from this call.

        Args:
            recursive: also interrupt, recursively, the live InterruptibleThreads
                this thread started.

        Raises:
            ValueError: if the thread is not running (never started, already
                exited, or its ident now belongs to another thread).
        """
        st = self._live_state()
        if st is None:
            raise ValueError("no such thread")

        if recursive:
            with _State.registry_lock:
                children = list(st.children)
            for child_tid in children:
                with _State.registry_lock:
                    child_st = _State.get_state_by_ident(child_tid)
                    child_thread = child_st.thread if child_st is not None else None
                if (
                    child_thread is None
                    or child_thread.ident != child_tid
                    or not child_thread.is_alive()
                ):
                    continue
                try:
                    child_thread.interrupt(recursive=True)
                except ValueError:
                    continue

        raise_now = False
        with st.cancel_cond:
            if self._live_state() is not st:
                # Thread terminated while we waited on the lock; nothing to do.
                return
            if st.pending:
                # Already requested, so this is idempotent. Re-nudge in case a
                # prior wakeup was missed, but never re-arm async injection.
                self._nudge(st)
                return
            st.pending = True
            st.interrupt_gen += 1

            if st.event_loop is not None:
                # Running an asyncio loop: cancel tasks cooperatively.
                self._cancel_event_loop(st)
            elif st.mask_depth > 0:
                # Masked: record and unblock so it loops to its mask check. The
                # pending flag is delivered when the mask exits.
                self._nudge(st)
            elif self.ident == threading.get_ident():
                # Self-interrupt: deliver synchronously. An async exception
                # armed on ourselves would fire inside this method, before
                # async_armed is recorded, and leave `pending` set, which
                # leads to a second delivery at the next checkpoint.
                raise_now = _take_pending(st)
            elif st.sleeping:
                st.cancel_cond.notify_all()
            elif st.selecting and st.wfd != -1:
                self._pipe_write(st)
            else:
                # Pure-Python execution, or a thread that has *decided* to block
                # but not yet committed, so the sleeping/selecting hints still
                # read False (common under tracing/coverage). Nudge every
                # cooperative channel, and async-inject to break a CPU loop.
                #
                # `pending` stays set: it is the durable source of truth. If
                # async injection cannot fire (the thread parks in a C-level
                # wait first), the primitive still sees `pending` on its
                # pre-block / post-wake check. `_take_pending` clears the flag
                # and disarms the async exception together, so delivery is
                # exactly-once via whichever path wins. (A thread that catches
                # an async-injected interrupt and continues must call
                # `clear_interrupt()`.)
                self._nudge(st)
                # SetAsyncExc can deadlock the target under an active trace
                # hook, so skip it then. Only a checkpoint-less pure-Python loop
                # stays uninterruptible while tracing.
                if not _tracing_active():
                    self._inject_exc()
                    st.async_armed = True
        if raise_now:
            raise _INTERRUPT_EXC()

    @classmethod
    def get_thread_cls_for_current_thread(cls, item: str) -> type[threading.Thread]:
        """Module ``__getattr__`` installed on ``threading``.

        Resolves ``threading.Thread`` to this class inside registered
        interruptible threads, and to the original ``Thread`` elsewhere.
        """
        if item != "Thread":
            raise AttributeError("No attribute %s in module threading" % item)
        st = _State.get_state_by_ident()
        if st is None:
            return _ORIG_THREAD
        else:
            return cls

    @classmethod
    def install_patches(
        cls,
        interrupt_exc: type[BaseException] = ThreadInterrupted,
        legacy_keyboardinterrupt: bool = False,
        monkeypatch_socket: bool = False,
    ) -> None:
        """Install the stdlib monkeypatches that make blocking calls interruptible.

        Args:
            interrupt_exc: the exception class ``interrupt()`` delivers (default
                ``ThreadInterrupted``; pass ``KeyboardInterrupt`` for legacy
                code).
            legacy_keyboardinterrupt: deliver an exception caught by *both*
                ``ThreadInterrupted`` and ``KeyboardInterrupt`` handlers.
            monkeypatch_socket: also swap blocking ``socket.recv/send/accept``
                for their interruptible variants process-wide. Off by default
                because of its large blast radius.

        Raises:
            ValueError: if the patches are already installed.
            TypeError: if ``interrupt_exc`` is not an exception class.
        """
        global _INTERRUPT_EXC
        if cls._patches_installed:
            raise ValueError("patches already installed")
        if not (isinstance(interrupt_exc, type) and issubclass(interrupt_exc, BaseException)):
            # Fail here rather than with a confusing TypeError at delivery time.
            raise TypeError("interrupt_exc must be a BaseException subclass")
        cls._patches_installed = True

        if legacy_keyboardinterrupt:
            _INTERRUPT_EXC = _ThreadInterruptedKeyboard
        else:
            _INTERRUPT_EXC = interrupt_exc

        time.sleep = _coop_sleep  # type: ignore
        selectors.DefaultSelector = _InterruptibleSelector  # type: ignore
        select.select = _patched_select  # type: ignore
        threading.Condition.wait = _patched_cond_wait  # type: ignore

        if monkeypatch_socket:
            cls._socket_patched = True
            _socket.socket.recv = _sock_recv_patch  # type: ignore
            _socket.socket.send = _sock_send_patch  # type: ignore
            _socket.socket.accept = _sock_accept_patch  # type: ignore

        # Remove the real attribute so lookups fall through to the module-level
        # __getattr__, which picks the class per calling thread.
        del threading.Thread
        threading.__getattr__ = cls.get_thread_cls_for_current_thread  # type: ignore

        def __dir__():
            return list(threading.__dict__.keys()) + [_ORIG_THREAD.__name__]

        threading.__dir__ = __dir__  # type: ignore

    @classmethod
    def uninstall_patches(cls) -> None:
        """Restore every original patched by :meth:`install_patches`.

        Raises:
            ValueError: if the patches are not installed.
        """
        global _INTERRUPT_EXC
        if not cls._patches_installed:
            raise ValueError("patches not installed")
        cls._patches_installed = False
        time.sleep = _ORIG_SLEEP
        select.select = _ORIG_SELECT  # type: ignore
        selectors.DefaultSelector = _ORIG_DEFAULT_SELECTOR  # type: ignore
        threading.Condition.wait = _ORIG_COND_WAIT  # type: ignore
        if cls._socket_patched:
            _socket.socket.recv = _ORIG_SOCK_RECV  # type: ignore
            _socket.socket.send = _ORIG_SOCK_SEND  # type: ignore
            _socket.socket.accept = _ORIG_SOCK_ACCEPT  # type: ignore
            cls._socket_patched = False
        threading.Thread = _ORIG_THREAD  # type: ignore
        del threading.__getattr__  # type: ignore
        del threading.__dir__  # type: ignore
        _INTERRUPT_EXC = ThreadInterrupted
855
+
856
+
857
# Imported at the bottom of the module (hence the E402 waiver).
from . import _version  # noqa: E402

# Package version string, read from the "version" key of the sibling
# ``_version`` module's get_versions() mapping (presumably a generated
# versioneer file -- TODO confirm).
__version__ = _version.get_versions()["version"]