interruptible-threading 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- interruptible_threading/__init__.py +859 -0
- interruptible_threading/_version.py +683 -0
- interruptible_threading/version.py +23 -0
- interruptible_threading-0.0.1.dist-info/METADATA +191 -0
- interruptible_threading-0.0.1.dist-info/RECORD +8 -0
- interruptible_threading-0.0.1.dist-info/WHEEL +5 -0
- interruptible_threading-0.0.1.dist-info/licenses/docs/LICENSE.txt +29 -0
- interruptible_threading-0.0.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,859 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Cooperative thread interruption for CPython.
|
|
3
|
+
|
|
4
|
+
Exposes an :class:`InterruptibleThread` whose :meth:`~InterruptibleThread.interrupt`
|
|
5
|
+
method raises an exception (``ThreadInterrupted`` by default) inside the target
|
|
6
|
+
thread. Pure-Python code is interrupted via ``PyThreadState_SetAsyncExc`` (an async
|
|
7
|
+
exception that fires at the next bytecode boundary). Because async exceptions cannot
|
|
8
|
+
break a thread that is parked in a C-level blocking call, ``install_patches()``
|
|
9
|
+
monkeypatches a curated set of stdlib blocking primitives (``time.sleep``,
|
|
10
|
+
``selectors.DefaultSelector``, ``select.select``, ``threading.Condition.wait``) so
|
|
11
|
+
they wake promptly and re-check a durable per-thread "interrupt pending" flag.
|
|
12
|
+
|
|
13
|
+
The flag is the single source of truth: ``interrupt()`` sets it first (under one
|
|
14
|
+
lock) and then issues a wakeup nudge; every blocking primitive checks the flag
|
|
15
|
+
before parking and again after waking. This closes the races inherent in choosing a
|
|
16
|
+
delivery path from transient state, and lets interrupts be *masked* during critical
|
|
17
|
+
sections (``critical_section()``) and *polled* in CPU-bound loops (``check_interrupt()``).
|
|
18
|
+
|
|
19
|
+
Requires monkey patching some stdlib functionality via
|
|
20
|
+
``InterruptibleThread.install_patches()``. CPython only; POSIX (Linux / Darwin).
|
|
21
|
+
|
|
22
|
+
Why not signals: ``signal.pthread_kill`` can unblock a syscall but cannot deliver an
|
|
23
|
+
exception to a worker thread -- CPython runs Python-level signal handlers only on the
|
|
24
|
+
main thread, and PEP 475's EINTR auto-retry loops call ``PyErr_CheckSignals()``
|
|
25
|
+
(a no-op off the main thread) without consulting ``tstate->async_exc``, so the
|
|
26
|
+
syscall is transparently retried. The self-pipe + cooperative-primitive approach is
|
|
27
|
+
the only way to get prompt, exception-bearing interruption of worker threads.
|
|
28
|
+
"""
|
|
29
|
+
from __future__ import annotations
|
|
30
|
+
|
|
31
|
+
import contextlib
|
|
32
|
+
import ctypes
|
|
33
|
+
import os
|
|
34
|
+
import select
|
|
35
|
+
import selectors
|
|
36
|
+
import socket as _socket
|
|
37
|
+
import sys
|
|
38
|
+
import threading
|
|
39
|
+
import time
|
|
40
|
+
from typing import TYPE_CHECKING, Any
|
|
41
|
+
|
|
42
|
+
if TYPE_CHECKING:
|
|
43
|
+
from collections.abc import Iterator
|
|
44
|
+
from io import IOBase
|
|
45
|
+
|
|
46
|
+
# Names exported by ``from <package> import *``.
__all__ = [
    # Thread class and the exception it delivers.
    "InterruptibleThread",
    "ThreadInterrupted",
    # Cooperative interrupt state of the current thread.
    "is_interrupted",
    "clear_interrupt",
    "check_interrupt",
    "interruptible_checkpoint",
    "periodic_checkpoint",
    # Masking.
    "critical_section",
    "interrupts_disabled",
    # Opt-in interruptible socket helpers.
    "interruptible_recv",
    "interruptible_send",
    "interruptible_accept",
    # Tuning.
    "set_poll_interval",
]
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
class ThreadInterrupted(BaseException):
    """Exception delivered to a thread after another thread calls ``.interrupt()`` on it.

    It derives from ``BaseException`` rather than ``Exception`` on purpose, so
    generic ``except Exception:`` handlers let it propagate. This is the same
    convention ``KeyboardInterrupt`` and ``asyncio.CancelledError`` follow.
    """
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
class _ThreadInterruptedKeyboard(ThreadInterrupted, KeyboardInterrupt):
    """Interrupt that matches both ``except ThreadInterrupted`` and ``except KeyboardInterrupt``.

    ``install_patches(legacy_keyboardinterrupt=True)`` delivers this class, so
    older code that only handles ``KeyboardInterrupt`` still sees the interrupt.
    """
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
# Exception class delivered by interrupt(). install_patches() rebinds it to its
# ``interrupt_exc`` argument, or to _ThreadInterruptedKeyboard when
# legacy_keyboardinterrupt=True.
_INTERRUPT_EXC: type[BaseException] = ThreadInterrupted

# ctypes binding for ``int PyThreadState_SetAsyncExc(unsigned long id, PyObject *exc)``.
# It returns the number of thread states modified (checked in _inject_exc).
_PTSSE = ctypes.pythonapi.PyThreadState_SetAsyncExc
_PTSSE.argtypes, _PTSSE.restype = [ctypes.c_ulong, ctypes.py_object], ctypes.c_int

# Ident of the main thread, captured at import. The patched functions fall
# straight through to the originals when running on it.
_MAIN_IDENT = threading.main_thread().ident

# Originals captured at import time, before install_patches() replaces them.
# The patched versions delegate to these.
(
    _ORIG_SLEEP,
    _ORIG_DEFAULT_SELECTOR,
    _ORIG_SELECT,
    _ORIG_THREAD,
    _ORIG_COND_WAIT,
) = (
    time.sleep,
    selectors.DefaultSelector,
    select.select,
    threading.Thread,
    threading.Condition.wait,
)

# ``data`` attached to the wakeup-pipe registration in _InterruptibleSelector.
# It is used to filter that registration out of select() results.
_WAKEUP_TOKEN = "__interruptible_wakeup__"

# Max latency (seconds) for chunked-poll primitives (Condition.wait / Event /
# Queue). Adjustable via set_poll_interval().
_POLL_INTERVAL = 0.05
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def set_poll_interval(seconds: float) -> None:
    """Tune the wakeup latency of the chunked-poll blocking primitives.

    ``seconds`` is the longest a patched ``Condition.wait`` (and therefore
    ``Event`` / ``Queue`` waits) blocks in one chunk before re-checking the
    pending-interrupt flag.

    Raises:
        ValueError: if ``seconds`` is not in ``(0, threading.TIMEOUT_MAX]``.
            - ``Condition.wait`` treats a zero, negative or NaN timeout as
              non-blocking, so the chunked wait would busy-loop.
            - A timeout above ``TIMEOUT_MAX`` makes lock acquisition raise
              ``OverflowError``.
    """
    global _POLL_INTERVAL
    value = float(seconds)
    # NaN fails every comparison, so it is rejected by this check too.
    if not 0.0 < value <= threading.TIMEOUT_MAX:
        raise ValueError(
            f"poll interval must be in (0, {threading.TIMEOUT_MAX}], got {seconds!r}"
        )
    _POLL_INTERVAL = value
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def _clear_async_exc(tid: int | None = None) -> None:
    """Cancel any async exception pending for thread ``tid`` (default: the calling thread).

    A cooperative path calls this right before it raises. That way the
    exception ``interrupt()`` armed does not fire a second time at the next
    bytecode boundary.
    """
    target = threading.get_ident() if tid is None else tid
    # ``ctypes.py_object()`` with no argument is a NULL PyObject*, which tells
    # SetAsyncExc to *clear* the pending exception. Passing Python ``None``
    # would instead arm None as the exception (a SystemError once delivered).
    null_exc = ctypes.py_object()
    _PTSSE(ctypes.c_ulong(target), null_exc)
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
def _tracing_active() -> bool:
|
|
119
|
+
"""Whether a trace/profile hook is installed (coverage, a debugger, a
|
|
120
|
+
profiler) anywhere we can observe.
|
|
121
|
+
|
|
122
|
+
``PyThreadState_SetAsyncExc`` can permanently wedge a thread that is doing
|
|
123
|
+
lock operations while such a hook is active (a CPython interaction), so async
|
|
124
|
+
injection -- the only delivery path that uses it -- is skipped while tracing.
|
|
125
|
+
Cooperative delivery via the durable ``pending`` flag is unaffected.
|
|
126
|
+
"""
|
|
127
|
+
if sys.gettrace() is not None or sys.getprofile() is not None:
|
|
128
|
+
return True
|
|
129
|
+
# threading.gettrace/getprofile (3.10+) expose the global hook coverage
|
|
130
|
+
# installs for worker threads; guarded for 3.9 where they don't exist.
|
|
131
|
+
for name in ("gettrace", "getprofile"):
|
|
132
|
+
getter = getattr(threading, name, None)
|
|
133
|
+
if getter is not None and getter() is not None:
|
|
134
|
+
return True
|
|
135
|
+
return False
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def _drain(fd: int) -> None:
|
|
139
|
+
"""Drain a non-blocking wakeup pipe until empty."""
|
|
140
|
+
while True:
|
|
141
|
+
try:
|
|
142
|
+
if not os.read(fd, 65536):
|
|
143
|
+
break
|
|
144
|
+
except (BlockingIOError, OSError):
|
|
145
|
+
break
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
class _State:
|
|
149
|
+
registry_lock = threading.RLock()
|
|
150
|
+
registry: dict[int, _State] = {}
|
|
151
|
+
|
|
152
|
+
def __init__(self) -> None:
|
|
153
|
+
cur_thread = threading.current_thread()
|
|
154
|
+
if not isinstance(cur_thread, InterruptibleThread):
|
|
155
|
+
raise TypeError(
|
|
156
|
+
f"current thread should be of type {InterruptibleThread.__name__}"
|
|
157
|
+
)
|
|
158
|
+
self.thread: InterruptibleThread = cur_thread
|
|
159
|
+
self.children: set[int] = set()
|
|
160
|
+
# cancel_cond's lock is the single per-thread mutex guarding all the
|
|
161
|
+
# mutable interrupt state below.
|
|
162
|
+
self.cancel_cond = threading.Condition()
|
|
163
|
+
# Durable source of truth: an interrupt has been requested, not yet delivered.
|
|
164
|
+
self.pending = False
|
|
165
|
+
self.interrupt_gen = 0
|
|
166
|
+
self.mask_depth = 0
|
|
167
|
+
# True only while an async exception has actually been armed for this
|
|
168
|
+
# thread via PyThreadState_SetAsyncExc. Gates the ctypes "clear" call so
|
|
169
|
+
# the cooperative paths never touch SetAsyncExc -- calling it while a
|
|
170
|
+
# sys.settrace tracer (e.g. coverage) is active can deadlock the thread.
|
|
171
|
+
self.async_armed = False
|
|
172
|
+
# Hints used only to pick a wakeup nudge; never the source of truth.
|
|
173
|
+
self.sleeping = False
|
|
174
|
+
self.selecting = False
|
|
175
|
+
# asyncio integration (set by run_interruptible).
|
|
176
|
+
self.event_loop: Any = None
|
|
177
|
+
self.root_task: Any = None
|
|
178
|
+
# Lazily-allocated self-pipe; only needed on the selector/select path.
|
|
179
|
+
self.rfd = -1
|
|
180
|
+
self.wfd = -1
|
|
181
|
+
|
|
182
|
+
def ensure_pipe(self) -> None:
|
|
183
|
+
"""Allocate the self-pipe on first use (idempotent)."""
|
|
184
|
+
with self.cancel_cond:
|
|
185
|
+
if self.rfd == -1:
|
|
186
|
+
r, w = os.pipe()
|
|
187
|
+
os.set_blocking(r, False)
|
|
188
|
+
# The write end must also be non-blocking: ``_pipe_write`` runs
|
|
189
|
+
# while holding ``cancel_cond``, and a blocking ``os.write`` to a
|
|
190
|
+
# full pipe would deadlock. A full pipe already means a wakeup is
|
|
191
|
+
# pending, so dropping the extra byte (EAGAIN) is correct.
|
|
192
|
+
os.set_blocking(w, False)
|
|
193
|
+
self.rfd, self.wfd = r, w
|
|
194
|
+
|
|
195
|
+
@classmethod
|
|
196
|
+
def register_current_thread(cls) -> _State:
|
|
197
|
+
tid = threading.get_ident()
|
|
198
|
+
with cls.registry_lock:
|
|
199
|
+
st = cls.registry.get(tid)
|
|
200
|
+
if st is None:
|
|
201
|
+
st = cls()
|
|
202
|
+
cls.registry[tid] = st
|
|
203
|
+
return st
|
|
204
|
+
|
|
205
|
+
@classmethod
|
|
206
|
+
def get_state_by_ident(cls, tid: int | None = None) -> _State | None:
|
|
207
|
+
return cls.registry.get(tid if tid is not None else threading.get_ident())
|
|
208
|
+
|
|
209
|
+
@classmethod
|
|
210
|
+
def unregister_current_thread(cls) -> None:
|
|
211
|
+
tid = threading.get_ident()
|
|
212
|
+
with cls.registry_lock:
|
|
213
|
+
st = cls.registry.pop(tid, None)
|
|
214
|
+
if not st:
|
|
215
|
+
return
|
|
216
|
+
for fd in (st.rfd, st.wfd):
|
|
217
|
+
if fd == -1:
|
|
218
|
+
continue
|
|
219
|
+
try:
|
|
220
|
+
os.close(fd)
|
|
221
|
+
except OSError:
|
|
222
|
+
pass
|
|
223
|
+
st.rfd = st.wfd = -1
|
|
224
|
+
|
|
225
|
+
|
|
226
|
+
def _disarm_async(st: _State) -> None:
|
|
227
|
+
"""Clear an armed async exception for the current thread, but only if one was
|
|
228
|
+
actually armed. ``PyThreadState_SetAsyncExc`` must not be called speculatively:
|
|
229
|
+
invoking it while a ``sys.settrace`` tracer (e.g. coverage) is active can wedge
|
|
230
|
+
the thread, so the cooperative paths -- which never arm one -- must skip it.
|
|
231
|
+
|
|
232
|
+
Must be called holding ``st.cancel_cond``.
|
|
233
|
+
"""
|
|
234
|
+
if st.async_armed:
|
|
235
|
+
st.async_armed = False
|
|
236
|
+
_clear_async_exc()
|
|
237
|
+
|
|
238
|
+
|
|
239
|
+
def _take_pending(st: _State) -> bool:
|
|
240
|
+
"""If an unmasked interrupt is pending, consume it and clear any armed async
|
|
241
|
+
exception, returning True (caller should raise). Otherwise return False.
|
|
242
|
+
|
|
243
|
+
Must be called holding ``st.cancel_cond``.
|
|
244
|
+
"""
|
|
245
|
+
if st.pending and st.mask_depth == 0:
|
|
246
|
+
st.pending = False
|
|
247
|
+
_disarm_async(st)
|
|
248
|
+
return True
|
|
249
|
+
return False
|
|
250
|
+
|
|
251
|
+
|
|
252
|
+
# --------------------------------------------------------------------------- #
|
|
253
|
+
# Public cooperative API (operates on the *current* thread) #
|
|
254
|
+
# --------------------------------------------------------------------------- #
|
|
255
|
+
|
|
256
|
+
|
|
257
|
+
def is_interrupted(thread: InterruptibleThread | None = None) -> bool:
    """Return whether an interrupt is pending for ``thread`` (default: current).

    Non-consuming. Returns ``False`` in these cases:
    - the thread has exited, or has not started yet;
    - the thread has no registered interrupt state;
    - the registered state under its ident belongs to a different thread
      object (idents can be reused after a thread exits).
    """
    if thread is None:
        tid = threading.get_ident()
    else:
        tid = thread.ident
        if tid is None:
            # Not started. Looking up ``None`` would silently fall back to the
            # *calling* thread's state.
            return False
    st = _State.get_state_by_ident(tid)
    if st is None:
        return False
    if thread is not None and st.thread is not thread:
        # Same ident, different (newer) thread: the asked-about one has exited.
        return False
    with st.cancel_cond:
        return st.pending
|
|
268
|
+
|
|
269
|
+
|
|
270
|
+
def clear_interrupt() -> bool:
    """Consume the current thread's pending interrupt, if any, without raising.

    Returns whether one was pending (Java's ``Thread.interrupted()`` semantics).

    Call this after *catching* ``ThreadInterrupted`` when you intend to keep
    running. The pending flag is durable, so an interrupt is never lost if the
    thread parks in a blocking call before async injection can fire. The
    flip side is that, left uncleared, the flag makes the next checkpoint or
    blocking primitive raise again.
    """
    st = _State.get_state_by_ident()
    if st is None:
        return False
    with st.cancel_cond:
        was_pending, st.pending = st.pending, False
        if was_pending:
            _disarm_async(st)
    return was_pending
|
|
289
|
+
|
|
290
|
+
|
|
291
|
+
def check_interrupt() -> None:
    """Raise the interrupt exception if one is pending and not masked.

    Cheap; meant for CPU-bound loops and for bracketing opaque C calls, making
    them interruptible while still honoring ``critical_section()``. Does
    nothing in a thread without registered interrupt state.
    """
    st = _State.get_state_by_ident()
    if st is None:
        return
    with st.cancel_cond:
        deliver = _take_pending(st)
    if deliver:
        raise _INTERRUPT_EXC()


# Second public name for check_interrupt.
interruptible_checkpoint = check_interrupt
|
|
307
|
+
|
|
308
|
+
|
|
309
|
+
@contextlib.contextmanager
|
|
310
|
+
def periodic_checkpoint(every: int = 1000) -> Iterator[_Ticker]:
|
|
311
|
+
"""Yield a ticker whose ``.tick()`` runs :func:`check_interrupt` every ``every``
|
|
312
|
+
calls, amortizing the per-iteration cost in a tight loop::
|
|
313
|
+
|
|
314
|
+
with periodic_checkpoint(every=1000) as ck:
|
|
315
|
+
for item in huge_iterable:
|
|
316
|
+
ck.tick()
|
|
317
|
+
crunch(item)
|
|
318
|
+
"""
|
|
319
|
+
yield _Ticker(every)
|
|
320
|
+
|
|
321
|
+
|
|
322
|
+
class _Ticker:
|
|
323
|
+
__slots__ = ("_every", "_n")
|
|
324
|
+
|
|
325
|
+
def __init__(self, every: int) -> None:
|
|
326
|
+
self._every = max(1, int(every))
|
|
327
|
+
self._n = 0
|
|
328
|
+
|
|
329
|
+
def tick(self) -> bool:
|
|
330
|
+
"""Count one iteration; every Nth call runs check_interrupt(). Returns
|
|
331
|
+
whether the check ran this call."""
|
|
332
|
+
self._n += 1
|
|
333
|
+
if self._n >= self._every:
|
|
334
|
+
self._n = 0
|
|
335
|
+
check_interrupt()
|
|
336
|
+
return True
|
|
337
|
+
return False
|
|
338
|
+
|
|
339
|
+
|
|
340
|
+
@contextlib.contextmanager
def interrupts_disabled() -> Iterator[None]:
    """Defer interrupt delivery to the current thread for the duration of the block.

    An interrupt that arrives while masked is recorded and delivered exactly
    once, when the outermost mask exits (masks nest). Delivery is reliable for
    the flag-driven paths (sleep / select / Condition / checkpoints).

    Already-armed async exceptions:
    - On entry, an async exception that ``interrupt()`` armed but that has not
      fired yet is cancelled. Its ``pending`` flag survives, so it is delivered
      at mask exit instead of inside the critical section.
    - One that already fired (or fires before the lock is taken on entry)
      cannot be recalled.

    If the calling thread has no registered interrupt state, this is a no-op.
    """
    st = _State.get_state_by_ident()
    if st is None:
        yield
        return
    with st.cancel_cond:
        st.mask_depth += 1
        # Without this, an armed-but-unfired async exception could fire inside
        # the masked block. ``pending`` would still be set, so the interrupt
        # would be raised a second time when the mask exits.
        _disarm_async(st)
    try:
        yield
    finally:
        with st.cancel_cond:
            st.mask_depth -= 1
            raise_now = st.mask_depth == 0 and _take_pending(st)
        if raise_now:
            raise _INTERRUPT_EXC()


# Second public name for interrupts_disabled.
critical_section = interrupts_disabled
|
|
371
|
+
|
|
372
|
+
|
|
373
|
+
# --------------------------------------------------------------------------- #
|
|
374
|
+
# Patched blocking primitives #
|
|
375
|
+
# --------------------------------------------------------------------------- #
|
|
376
|
+
|
|
377
|
+
|
|
378
|
+
def _coop_sleep(secs: float) -> None:
    """Interruptible replacement for ``time.sleep`` (installed by install_patches).

    In a registered InterruptibleThread the sleep parks on the thread's own
    ``cancel_cond``, so ``interrupt()`` can wake it with ``notify_all``. The
    main thread and unregistered threads use the original ``time.sleep``.

    Durations that are not positive (0, negative, NaN) are passed to the
    original ``time.sleep`` after the pending check. That keeps its semantics:
    ``sleep(0)`` still yields to other threads, and negative or NaN values
    still raise ValueError. The Condition-based loop would otherwise return
    silently, or spin forever on NaN because ``Condition.wait`` treats a NaN
    timeout as non-blocking.

    Raises:
        The configured interrupt exception if an unmasked interrupt is pending
        before or during the sleep.
    """
    tid = threading.get_ident()
    st = None if tid == _MAIN_IDENT else _State.get_state_by_ident(tid)
    if st is None:
        return _ORIG_SLEEP(secs)
    if not secs > 0:
        with st.cancel_cond:
            if _take_pending(st):
                raise _INTERRUPT_EXC()
        return _ORIG_SLEEP(secs)
    deadline = time.monotonic() + secs
    with st.cancel_cond:
        while True:
            if _take_pending(st):
                raise _INTERRUPT_EXC()
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                return
            st.sleeping = True
            try:
                _ORIG_COND_WAIT(st.cancel_cond, remaining)
            finally:
                st.sleeping = False
            # Loop: an unmasked pending flag raises at the top; a masked one keeps
            # sleeping out the remaining time; spurious wakeups simply re-park.
|
|
400
|
+
|
|
401
|
+
|
|
402
|
+
def _patched_cond_wait(self: threading.Condition, timeout: float | None = None) -> bool:
    """Interruptible replacement for ``threading.Condition.wait``.

    In a registered InterruptibleThread the wait is split into chunks of at most
    ``_POLL_INTERVAL`` seconds, and the pending-interrupt flag is re-checked
    between chunks. Returns True if notified and False once ``timeout``
    expires, like the original. The main thread and unregistered threads call
    the original directly.

    How a chunk is waited:
    - ``cancel_cond`` is released before waiting, so ``interrupt()`` is never
      blocked behind a parked waiter.
    - The ``sleeping`` hint is set while parked. ``interrupt()`` then only
      notifies instead of async-injecting. An async exception firing inside the
      stdlib ``Condition.wait`` internals (between releasing and restoring the
      user's lock) can leave that lock unrestored. The interrupt is instead
      raised cooperatively here after at most one chunk.

    Raises:
        The configured interrupt exception if an unmasked interrupt is pending
        before a chunk starts.
    """
    tid = threading.get_ident()
    st = None if tid == _MAIN_IDENT else _State.get_state_by_ident(tid)
    if st is None:
        return _ORIG_COND_WAIT(self, timeout)
    deadline = None if timeout is None else time.monotonic() + timeout
    while True:
        with st.cancel_cond:
            if _take_pending(st):
                raise _INTERRUPT_EXC()
            if deadline is None:
                chunk = _POLL_INTERVAL
            else:
                remaining = deadline - time.monotonic()
                if remaining <= 0:
                    return False
                chunk = min(_POLL_INTERVAL, remaining)
            st.sleeping = True
        try:
            notified = _ORIG_COND_WAIT(self, chunk)
        finally:
            with st.cancel_cond:
                st.sleeping = False
        if notified:
            return True
        # Timed-out chunk: loop to re-check the pending flag.
|
|
424
|
+
|
|
425
|
+
|
|
426
|
+
class _InterruptibleSelector(_ORIG_DEFAULT_SELECTOR):
    """DefaultSelector that also watches the creating thread's self-pipe, so an
    interrupt can wake a parked ``select`` (this is what makes asyncio interruptible).

    The wakeup pipe is registered with the module-level ``_WAKEUP_TOKEN`` object
    as its ``data``. ``select()`` drains it and never reports it to the caller.
    """

    def __init__(self) -> None:
        super().__init__()
        st = self._ithr_st = _State.get_state_by_ident()
        if st is not None:
            st.ensure_pipe()
            try:
                super().register(st.rfd, selectors.EVENT_READ, data=_WAKEUP_TOKEN)
            except (KeyError, ValueError):
                pass

    def select(self, timeout: float | None = None):
        """Like ``BaseSelector.select``, but interruptible.

        A wake caused only by the wakeup pipe (e.g. a re-nudge of an interrupt
        that is currently masked) does not return a spurious empty list. The
        call re-parks for whatever is left of ``timeout``, so ``timeout=None``
        still blocks until a caller-registered object is ready.

        Raises:
            The configured interrupt exception if an unmasked interrupt is
            pending before or after the underlying select.
        """
        st = getattr(self, "_ithr_st", None)
        if st is None:
            return super().select(timeout)
        deadline = None if timeout is None else time.monotonic() + timeout
        wait = timeout
        while True:
            with st.cancel_cond:
                if _take_pending(st):
                    raise _INTERRUPT_EXC()
                st.selecting = True
            try:
                events = super().select(wait)
            finally:
                with st.cancel_cond:
                    st.selecting = False
            ready = []
            woke = False
            for key, mask in events:
                # Identity, not ``==``: caller-supplied ``data`` may be any
                # object, including ones whose __eq__ misbehaves on a str.
                if key.data is _WAKEUP_TOKEN:
                    _drain(key.fd)
                    woke = True
                else:
                    ready.append((key, mask))
            with st.cancel_cond:
                if _take_pending(st):
                    raise _INTERRUPT_EXC()
            if ready or not woke:
                return ready
            if deadline is not None:
                wait = deadline - time.monotonic()
                if wait <= 0:
                    return ready
|
|
464
|
+
|
|
465
|
+
|
|
466
|
+
def _fileno(x: IOBase | int) -> int:
|
|
467
|
+
try:
|
|
468
|
+
return x.fileno() # type: ignore[union-attr]
|
|
469
|
+
except AttributeError:
|
|
470
|
+
return int(x) # type: ignore[arg-type]
|
|
471
|
+
|
|
472
|
+
|
|
473
|
+
def _patched_select(
    rlist: list[IOBase | int],
    wlist: list[IOBase | int],
    xlist: list[IOBase | int],
    timeout: float | None = None,
) -> tuple[list[IOBase | int], list[IOBase | int], list[IOBase | int]]:
    """Interruptible replacement for ``select.select`` (installed by install_patches).

    In a registered InterruptibleThread, the calling thread's wakeup pipe is
    added to the read set so ``interrupt()`` can wake the call. Ready
    descriptors are mapped back to the caller's original objects, and the
    wakeup pipe never appears in the result. The main thread and unregistered
    threads call the original ``select.select`` unchanged.

    A wake caused only by the wakeup pipe (e.g. a re-nudge of an interrupt that
    is currently masked) does not end the call. It re-parks for whatever is
    left of ``timeout``, so ``timeout=None`` still only returns once something
    the caller asked about is ready, as ``select.select`` promises.

    Raises:
        The configured interrupt exception if an unmasked interrupt is pending
        before or after the underlying select.
    """
    tid = threading.get_ident()
    st = None if tid == _MAIN_IDENT else _State.get_state_by_ident(tid)
    if st is None:
        return _ORIG_SELECT(rlist, wlist, xlist, timeout)

    st.ensure_pipe()
    wake_fd = st.rfd

    # (fd, original object) pairs, so results can be mapped back in caller order.
    rmap = [(_fileno(o), o) for o in rlist]
    wmap = [(_fileno(o), o) for o in wlist]
    xmap = [(_fileno(o), o) for o in xlist]

    rfds = [fd for fd, _ in rmap]
    wfds = [fd for fd, _ in wmap]
    xfds = [fd for fd, _ in xmap]
    caller_watches_wake_fd = wake_fd in rfds
    if not caller_watches_wake_fd:
        rfds.append(wake_fd)

    started = time.monotonic()
    wait = timeout
    while True:
        with st.cancel_cond:
            if _take_pending(st):
                raise _INTERRUPT_EXC()
            st.selecting = True
        try:
            rr, ww, xx = _ORIG_SELECT(rfds, wfds, xfds, wait)
        finally:
            with st.cancel_cond:
                st.selecting = False

        woke = wake_fd in rr
        if woke:
            _drain(wake_fd)
        with st.cancel_cond:
            if _take_pending(st):
                raise _INTERRUPT_EXC()

        only_wakeup = woke and len(rr) == 1 and not ww and not xx
        if not only_wakeup or caller_watches_wake_fd:
            break
        if timeout is not None:
            # The first _ORIG_SELECT call has already validated ``timeout``.
            wait = timeout - (time.monotonic() - started)
            if wait <= 0:
                break

    def map_back(
        fd_list: list[int], fmap: list[tuple[int, IOBase | int]]
    ) -> list[IOBase | int]:
        ready = set(fd_list)
        return [obj for fd, obj in fmap if fd in ready]

    return (
        map_back(rr, rmap),
        map_back(ww, wmap),
        map_back(xx, xmap),
    )
|
|
530
|
+
|
|
531
|
+
|
|
532
|
+
# --------------------------------------------------------------------------- #
|
|
533
|
+
# Interruptible socket helpers (opt-in; reuse the self-pipe directly) #
|
|
534
|
+
# --------------------------------------------------------------------------- #
|
|
535
|
+
|
|
536
|
+
|
|
537
|
+
def _interruptible_io(sock: _socket.socket, want_write: bool, op):
    """Run the socket operation ``op()`` so that ``interrupt()`` can abort its wait.

    In a registered InterruptibleThread (other than the main thread):
    - The socket is switched to non-blocking mode.
    - Readiness is awaited with select() alongside the thread's wakeup pipe.
    - ``op()`` is retried until it stops raising BlockingIOError /
      InterruptedError.
    - The socket's original timeout is restored afterwards.
    Elsewhere, ``op()`` is simply called.

    ``want_write`` waits for write readiness (send) instead of read readiness
    (recv / accept).

    The socket's own timeout mode is honored:
    - A non-blocking socket (``gettimeout() == 0.0``) is handed straight to
      ``op()``, so it still raises BlockingIOError rather than waiting.
    - A socket with a positive timeout raises ``socket.timeout`` once that long
      passes without readiness, instead of blocking forever.

    Raises:
        The configured interrupt exception if an unmasked interrupt is pending.
        socket.timeout: the socket's timeout elapsed before it became ready.
    """
    tid = threading.get_ident()
    st = _State.get_state_by_ident(tid)
    if st is None or tid == _MAIN_IDENT:
        return op()
    prev_timeout = sock.gettimeout()
    if prev_timeout == 0.0:
        return op()
    st.ensure_pipe()
    deadline = None if prev_timeout is None else time.monotonic() + prev_timeout
    sock.setblocking(False)
    try:
        while True:
            with st.cancel_cond:
                if _take_pending(st):
                    raise _INTERRUPT_EXC()
                st.selecting = True
            wait = None if deadline is None else max(0.0, deadline - time.monotonic())
            try:
                if want_write:
                    _rr, ww, _xx = _ORIG_SELECT([st.rfd], [sock], [], wait)
                    ready = sock in ww
                else:
                    rr, _ww, _xx = _ORIG_SELECT([sock, st.rfd], [], [], wait)
                    ready = sock in rr
            finally:
                with st.cancel_cond:
                    st.selecting = False
            _drain(st.rfd)
            with st.cancel_cond:
                if _take_pending(st):
                    raise _INTERRUPT_EXC()
            if ready:
                try:
                    return op()
                except (BlockingIOError, InterruptedError):
                    continue
            if deadline is not None and time.monotonic() >= deadline:
                raise _socket.timeout("timed out")
    finally:
        # Skip the restore on a socket closed during the call (e.g. by another
        # thread to abort it). settimeout would raise OSError there and mask
        # the exception that is actually propagating.
        if sock.fileno() != -1:
            sock.settimeout(prev_timeout)
|
|
572
|
+
|
|
573
|
+
|
|
574
|
+
def interruptible_recv(sock: _socket.socket, bufsize: int, flags: int = 0) -> bytes:
    """``sock.recv`` that raises the interrupt exception if interrupted while blocked.

    The wait for read readiness is handled by ``_interruptible_io``; the read
    itself is ``sock.recv(bufsize, flags)``.
    """
    def receive() -> bytes:
        return sock.recv(bufsize, flags)

    return _interruptible_io(sock, False, receive)
|
|
577
|
+
|
|
578
|
+
|
|
579
|
+
def interruptible_send(sock: _socket.socket, data: bytes, flags: int = 0) -> int:
    """``sock.send`` that raises the interrupt exception if interrupted while blocked.

    The wait for write readiness is handled by ``_interruptible_io``; the write
    itself is ``sock.send(data, flags)``.
    """
    def transmit() -> int:
        return sock.send(data, flags)

    return _interruptible_io(sock, True, transmit)
|
|
582
|
+
|
|
583
|
+
|
|
584
|
+
def interruptible_accept(sock: _socket.socket):
    """``sock.accept`` that raises the interrupt exception if interrupted while blocked.

    Returns ``sock.accept()``'s ``(conn, address)`` pair.

    ``_interruptible_io`` makes the listener non-blocking for the duration of
    the call. That defeats the check inside ``socket.accept`` (CPython issue
    #7995) which forces the accepted socket blocking. On platforms where
    accepted sockets inherit the listener's non-blocking flag, ``conn`` would
    otherwise be non-blocking at the OS level while Python treats it as
    blocking. The decision is therefore taken from the listener's *original*
    timeout, and applied after the call.
    """
    force_blocking = _socket.getdefaulttimeout() is None and sock.gettimeout() != 0.0
    conn, addr = _interruptible_io(sock, False, sock.accept)
    if force_blocking:
        conn.setblocking(True)
    return conn, addr
|
|
587
|
+
|
|
588
|
+
|
|
589
|
+
# Socket methods captured at import time, before
# install_patches(monkeypatch_socket=True) replaces them. The replacements
# below call these directly.
_ORIG_SOCK_RECV = _socket.socket.recv
_ORIG_SOCK_SEND = _socket.socket.send
_ORIG_SOCK_ACCEPT = _socket.socket.accept


def _sock_recv_patch(self, bufsize, flags=0):
    """Process-wide replacement for ``socket.socket.recv`` (monkeypatch_socket=True).

    Calls the captured original rather than going through
    ``interruptible_recv``. Once patched, ``self.recv`` *is* this function, so
    the ``sock.recv(...)`` inside ``interruptible_recv`` would recurse without
    bound.
    """
    return _interruptible_io(self, False, lambda: _ORIG_SOCK_RECV(self, bufsize, flags))


def _sock_send_patch(self, data, flags=0):
    """Process-wide replacement for ``socket.socket.send`` (monkeypatch_socket=True).

    Uses the captured original for the same no-recursion reason as
    ``_sock_recv_patch``.
    """
    return _interruptible_io(self, True, lambda: _ORIG_SOCK_SEND(self, data, flags))


def _sock_accept_patch(self):
    """Process-wide replacement for ``socket.socket.accept`` (monkeypatch_socket=True).

    Uses the captured original for the same no-recursion reason as
    ``_sock_recv_patch``.

    Because the listener is temporarily non-blocking, the original accept's own
    "force the new socket blocking" check (CPython issue #7995) sees a timeout
    of 0.0 and skips. The decision is therefore made here from the listener's
    timeout before the call.
    """
    force_blocking = _socket.getdefaulttimeout() is None and self.gettimeout() != 0.0
    conn, addr = _interruptible_io(self, False, lambda: _ORIG_SOCK_ACCEPT(self))
    if force_blocking:
        conn.setblocking(True)
    return conn, addr
|
|
604
|
+
|
|
605
|
+
|
|
606
|
+
# --------------------------------------------------------------------------- #
|
|
607
|
+
# The thread class #
|
|
608
|
+
# --------------------------------------------------------------------------- #
|
|
609
|
+
|
|
610
|
+
|
|
611
|
+
class InterruptibleThread(_ORIG_THREAD):
|
|
612
|
+
"""Thread with a built-in cooperative :meth:`interrupt`."""
|
|
613
|
+
|
|
614
|
+
_patches_installed = False
|
|
615
|
+
_socket_patched = False
|
|
616
|
+
|
|
617
|
+
def __init__(self, *args, **kwargs) -> None:
    """Accept exactly the same arguments as ``threading.Thread``.

    Also records the creating thread as this thread's parent, for
    ``interrupt(recursive=True)``.
    """
    # Captured here, in the creating thread. run() executes in the new thread,
    # where get_ident() would return the child's own ident.
    self._parent_tid = threading.get_ident()
    super().__init__(*args, **kwargs)
|
|
620
|
+
|
|
621
|
+
def run(self) -> None:
    """Run the target with interrupt state registered for this thread.

    Before the target runs:
    - The thread's ``_State`` is registered.
    - Its ident is added to the parent's ``children`` (if the parent is
      registered), for recursive interrupts.

    Afterwards, under ``cancel_cond`` (so no new interrupt can be armed in
    between), any async exception that a racing ``interrupt()`` armed but that
    has not fired is cancelled, and the state is unregistered. Without the
    cancel, that exception could fire in whatever runs after ``run()``
    returns. Finally the thread is removed from the parent's ``children``.
    """
    my_tid = threading.get_ident()
    with _State.registry_lock:
        _State.register_current_thread()
        parent_st = _State.get_state_by_ident(self._parent_tid)
        if parent_st is not None:
            parent_st.children.add(my_tid)
    try:
        super().run()
    finally:
        st = _State.get_state_by_ident(my_tid)
        if st is not None:
            with st.cancel_cond:
                _disarm_async(st)
                _State.unregister_current_thread()
        with _State.registry_lock:
            parent_st = _State.get_state_by_ident(self._parent_tid)
            if parent_st is not None:
                parent_st.children.discard(my_tid)
|
|
639
|
+
|
|
640
|
+
@classmethod
def run_interruptible(cls, coro):
    """Run a coroutine via ``asyncio.run`` so that ``interrupt()`` cancels it cooperatively.

    Cancelling the event loop's root task gives a clean ``finally`` /
    ``async with`` unwind, instead of injecting an exception into the
    selector. Returns the coroutine's result. If the run ends in
    CancelledError and an unmasked interrupt is pending, the interrupt
    exception is raised instead; otherwise the CancelledError propagates.
    """
    import asyncio

    state = _State.get_state_by_ident()

    async def _runner():
        if state is not None:
            with state.cancel_cond:
                state.event_loop = asyncio.get_running_loop()
                state.root_task = asyncio.current_task()
        try:
            return await coro
        finally:
            if state is not None:
                with state.cancel_cond:
                    state.event_loop = state.root_task = None

    try:
        return asyncio.run(_runner())
    except asyncio.CancelledError:
        if state is not None:
            with state.cancel_cond:
                interrupted = _take_pending(state)
            if interrupted:
                raise _INTERRUPT_EXC() from None
        raise
|
|
671
|
+
|
|
672
|
+
def _inject_exc(self) -> None:
    """Arm ``_INTERRUPT_EXC`` as an async exception in this thread (PyThreadState_SetAsyncExc).

    Raises:
        ValueError: no thread state matched this thread's ident.
        SystemError: more than one thread state was modified. The injection is
            reverted before raising.
    """
    target = ctypes.c_ulong(self.ident or 0)
    affected = _PTSSE(target, ctypes.py_object(_INTERRUPT_EXC))
    if affected == 1:
        return
    if affected == 0:
        raise ValueError("no such thread")
    if affected > 1:
        # Revert: a NULL exception clears whatever was just armed.
        _PTSSE(target, ctypes.py_object())
        raise SystemError("SetAsyncExc affected multiple threads")
|
|
680
|
+
|
|
681
|
+
@staticmethod
def _pipe_write(st: _State) -> None:
    """Best effort: write one byte to ``st``'s wakeup pipe, if it has been allocated."""
    if st.wfd == -1:
        return
    # The write end is non-blocking. EAGAIN (pipe full) means a wakeup is
    # already queued, and any other OSError is ignored as well.
    with contextlib.suppress(OSError):
        os.write(st.wfd, b"\x00")
|
|
688
|
+
|
|
689
|
+
def _nudge(self, st: _State) -> None:
    """Wake the target out of any current park, without arming an async exception.

    Must be called holding ``st.cancel_cond``.
    """
    wake_sleepers = st.cancel_cond.notify_all  # the patched time.sleep waits on cancel_cond
    wake_sleepers()
    self._pipe_write(st)  # patched select / selector waits watch the self-pipe
|
|
694
|
+
|
|
695
|
+
def _cancel_event_loop(self, st: _State) -> None:
    """Ask ``st``'s running asyncio loop to cancel its root task (thread-safely).

    Must be called holding ``st.cancel_cond``.

    If ``call_soon_threadsafe`` raises RuntimeError (e.g. the loop is already
    closed), fall back to the same delivery ``interrupt()`` uses for
    pure-Python code:
    - nudge every cooperative channel;
    - unless a trace/profile hook is active, async-inject and record
      ``async_armed``.
    Recording the flag lets the cooperative consumer disarm the exception, so
    the interrupt is not delivered twice.
    """
    loop, task = st.event_loop, st.root_task

    def _cancel():
        if task is not None and not task.done():
            task.cancel()

    try:
        loop.call_soon_threadsafe(_cancel)
    except RuntimeError:
        self._nudge(st)
        if not _tracing_active():
            self._inject_exc()
            st.async_armed = True
|
|
707
|
+
|
|
708
|
+
def interrupt(self, recursive: bool = False) -> None:
    """Request that this thread raise the configured interrupt exception.

    The request is first recorded in the thread's durable ``pending`` flag.
    Then a wakeup is chosen from what the thread is doing:
    - running an asyncio loop via ``run_interruptible``: its root task is
      cancelled;
    - masked (inside ``critical_section()``): it is only nudged;
    - ``sleeping`` hint set (the patched ``time.sleep``): it is notified;
    - ``selecting`` hint set: a byte is written to its wakeup pipe;
    - otherwise: every cooperative channel is nudged and, unless a
      trace/profile hook is active, an async exception is armed.
    Calls made while a request is still pending only re-nudge.

    Args:
        recursive: first interrupt, recursively, the live InterruptibleThreads
            started from this thread.

    Raises:
        ValueError: this thread object has no registered interrupt state. It
            has not started running yet or has already finished, including
            when its ident has since been reused by another thread.
    """
    st = None if self.ident is None else _State.get_state_by_ident(self.ident)
    # The registry is keyed by ident, and the OS may reuse an ident after a
    # thread exits. An unstarted thread has no ident at all (looking up None
    # would return the caller's own state). Only act on state that belongs to
    # *this* thread object.
    if st is None or st.thread is not self:
        raise ValueError("no such thread")

    if recursive:
        with _State.registry_lock:
            children = list(st.children)
        for child_tid in children:
            with _State.registry_lock:
                child_st = _State.get_state_by_ident(child_tid)
                child_thread = child_st.thread if child_st is not None else None
            if (
                child_thread is None
                or child_thread.ident != child_tid
                or not child_thread.is_alive()
            ):
                continue
            try:
                child_thread.interrupt(recursive=True)
            except ValueError:
                continue

    with st.cancel_cond:
        if _State.get_state_by_ident(self.ident) is not st:
            # Thread terminated while we waited on the lock; nothing to do.
            return
        if st.pending:
            # Already requested -- idempotent. Re-nudge in case a prior wakeup
            # was missed; never re-arm async injection.
            self._nudge(st)
            return
        st.pending = True
        st.interrupt_gen += 1

        if st.event_loop is not None:
            # Running an asyncio loop: cancel tasks cooperatively.
            self._cancel_event_loop(st)
        elif st.mask_depth > 0:
            # Masked: record + unblock so it loops to its mask check; the
            # pending flag is delivered when the mask exits.
            self._nudge(st)
        elif st.sleeping:
            st.cancel_cond.notify_all()
        elif st.selecting and st.wfd != -1:
            self._pipe_write(st)
        else:
            # Pure-Python execution, or a thread that has *decided* to block
            # but not yet committed (so the sleeping/selecting hints still
            # read False -- common under tracing/coverage). Async-inject to
            # break a CPU loop, and also nudge every cooperative channel in
            # case the thread is entering a primitive right now.
            #
            # Crucially, do NOT clear `pending`. It is the durable source of
            # truth. If async injection cannot fire (the thread parks in a
            # C-level wait before the next bytecode boundary), the primitive
            # still sees `pending` on its pre-block / post-wake check and
            # raises. The cooperative consumer (`_take_pending`) clears the
            # flag and disarms the async exception together, so delivery is
            # exactly-once via whichever path wins. (A thread that *catches*
            # an async-injected interrupt and continues must call
            # `clear_interrupt()` -- see its docstring.)
            self._nudge(st)
            # Async injection uses PyThreadState_SetAsyncExc, which can
            # deadlock the target under an active trace hook (coverage /
            # debuggers), so skip it then. Any thread that reaches a
            # checkpoint or blocking primitive still gets the interrupt via
            # the durable flag. Only a checkpoint-less pure-Python loop is
            # left uninterruptible while tracing (a narrow, documented gap).
            if not _tracing_active():
                self._inject_exc()
                st.async_armed = True
|
|
780
|
+
|
|
781
|
+
@classmethod
def get_thread_cls_for_current_thread(cls, item: str) -> type[threading.Thread]:
    """Module-level ``__getattr__`` hook for :mod:`threading`.

    ``install_patches`` assigns this as ``threading.__getattr__``. Only the
    name ``"Thread"`` is served. If ``_State.get_state_by_ident()`` returns a
    state (presumably meaning the calling thread is managed by this module;
    confirm against ``_State``), ``cls`` is returned. Otherwise the original
    ``threading.Thread`` (``_ORIG_THREAD``) is returned.

    Raises:
        AttributeError: for any name other than ``"Thread"``.
    """
    if item == "Thread":
        managed = _State.get_state_by_ident() is not None
        return cls if managed else _ORIG_THREAD
    raise AttributeError("No attribute %s in module threading" % item)
|
|
790
|
+
|
|
791
|
+
@classmethod
def install_patches(
    cls,
    interrupt_exc: type[BaseException] = ThreadInterrupted,
    legacy_keyboardinterrupt: bool = False,
    monkeypatch_socket: bool = False,
) -> None:
    """Install the stdlib monkeypatches that make blocking calls interruptible.

    ``interrupt_exc``: the exception class delivered by ``interrupt()``
    (default ``ThreadInterrupted``; pass ``KeyboardInterrupt`` for legacy code).
    ``legacy_keyboardinterrupt``: deliver an exception caught by *both*
    ``ThreadInterrupted`` and ``KeyboardInterrupt`` handlers. When true it
    takes precedence over ``interrupt_exc``.
    ``monkeypatch_socket``: also swap blocking ``socket.recv/send/accept`` for
    their interruptible variants process-wide (off by default; large blast radius).

    In addition, attribute access to ``threading.Thread`` (including
    ``from threading import Thread``) is resolved per calling thread through
    :meth:`get_thread_cls_for_current_thread`.

    Raises:
        ValueError: if the patches are already installed.
    """
    global _INTERRUPT_EXC
    if cls._patches_installed:
        raise ValueError("patches already installed")
    cls._patches_installed = True

    if legacy_keyboardinterrupt:
        _INTERRUPT_EXC = _ThreadInterruptedKeyboard
    else:
        _INTERRUPT_EXC = interrupt_exc

    time.sleep = _coop_sleep  # type: ignore
    selectors.DefaultSelector = _InterruptibleSelector  # type: ignore
    select.select = _patched_select  # type: ignore
    threading.Condition.wait = _patched_cond_wait  # type: ignore

    if monkeypatch_socket:
        cls._socket_patched = True
        _socket.socket.recv = _sock_recv_patch  # type: ignore
        _socket.socket.send = _sock_send_patch  # type: ignore
        _socket.socket.accept = _sock_accept_patch  # type: ignore

    # Do NOT delete ``threading.Thread``. Code inside threading.py looks
    # ``Thread`` up as a module *global* at call time; for example,
    # ``Timer.__init__`` and ``_DummyThread.__init__`` call
    # ``Thread.__init__``. A module ``__getattr__`` only serves external
    # attribute access, so with ``Thread`` deleted those calls raise
    # NameError.
    #
    # Instead, redirect *external* access with a property on a ModuleType
    # subclass. A property is a data descriptor, so it takes precedence over
    # the module ``__dict__`` for ``threading.Thread``. Global lookups inside
    # threading.py read the dict directly and still see the original class.
    base = type(threading)
    # Avoid stacking subclasses across repeated install/uninstall cycles.
    while "_interruptible_threading_patch" in vars(base):
        base = base.__base__

    class _InterruptibleThreadingModule(base):  # type: ignore[misc, valid-type]
        _interruptible_threading_patch = True

        @property
        def Thread(self):  # noqa: N802 - must match the module attribute name
            # Once uninstall_patches() clears ``_patches_installed`` this
            # property is transparent and returns the plain dict value.
            if cls._patches_installed:
                return cls.get_thread_cls_for_current_thread("Thread")
            return self.__dict__["Thread"]

        @Thread.setter
        def Thread(self, value):  # noqa: N802
            # Plain assignment (e.g. ``threading.Thread = _ORIG_THREAD`` in
            # uninstall_patches) writes through to the module dict.
            self.__dict__["Thread"] = value

    threading.__class__ = _InterruptibleThreadingModule  # type: ignore[assignment]

    # Installed for symmetry with uninstall_patches(), which deletes both.
    # ``Thread`` stays in the module dict, so ``__getattr__`` only fires for
    # genuinely missing names and ``__dir__`` needs no extra entry.
    threading.__getattr__ = cls.get_thread_cls_for_current_thread  # type: ignore

    def __dir__():
        return list(threading.__dict__.keys())

    threading.__dir__ = __dir__  # type: ignore
|
|
835
|
+
|
|
836
|
+
@classmethod
def uninstall_patches(cls) -> None:
    """Revert the monkeypatches applied by ``install_patches``.

    Restores the original ``time.sleep``, ``select.select``,
    ``selectors.DefaultSelector`` and ``threading.Condition.wait``. The
    socket methods are restored only if they were patched. The module
    attributes added to :mod:`threading` are removed, and the delivered
    exception class is reset to ``ThreadInterrupted``.

    Raises:
        ValueError: if the patches are not currently installed.
    """
    global _INTERRUPT_EXC
    if not cls._patches_installed:
        raise ValueError("patches not installed")
    cls._patches_installed = False

    # (owner, attribute name, original object), restored in this order.
    core_originals = (
        (time, "sleep", _ORIG_SLEEP),
        (select, "select", _ORIG_SELECT),
        (selectors, "DefaultSelector", _ORIG_DEFAULT_SELECTOR),
        (threading.Condition, "wait", _ORIG_COND_WAIT),
    )
    for owner, attr_name, original in core_originals:
        setattr(owner, attr_name, original)

    if cls._socket_patched:
        socket_originals = (
            ("recv", _ORIG_SOCK_RECV),
            ("send", _ORIG_SOCK_SEND),
            ("accept", _ORIG_SOCK_ACCEPT),
        )
        for attr_name, original in socket_originals:
            setattr(_socket.socket, attr_name, original)
        cls._socket_patched = False

    threading.Thread = _ORIG_THREAD  # type: ignore
    for hook_name in ("__getattr__", "__dir__"):
        delattr(threading, hook_name)
    _INTERRUPT_EXC = ThreadInterrupted
|
|
855
|
+
|
|
856
|
+
|
|
857
|
+
# Package version, taken from the bundled ``_version`` module's
# ``get_versions()["version"]``. The import sits at the bottom of the module,
# hence the flake8 E402 ("module level import not at top of file") suppression.
from . import _version  # noqa: E402

__version__ = _version.get_versions()["version"]
|