speedy-utils 1.1.20__py3-none-any.whl → 1.1.22__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llm_utils/__init__.py +12 -1
- llm_utils/lm/llm_task.py +172 -251
- llm_utils/lm/utils.py +332 -110
- speedy_utils/multi_worker/process.py +128 -27
- speedy_utils/multi_worker/thread.py +341 -226
- {speedy_utils-1.1.20.dist-info → speedy_utils-1.1.22.dist-info}/METADATA +1 -1
- {speedy_utils-1.1.20.dist-info → speedy_utils-1.1.22.dist-info}/RECORD +9 -9
- {speedy_utils-1.1.20.dist-info → speedy_utils-1.1.22.dist-info}/WHEEL +0 -0
- {speedy_utils-1.1.20.dist-info → speedy_utils-1.1.22.dist-info}/entry_points.txt +0 -0
speedy_utils/multi_worker/thread.py

@@ -14,13 +14,14 @@
 # making it suitable for data processing pipelines, batch operations, and concurrent API calls.
 #
 # Public API / Data Contracts:
-# • multi_thread(func, inputs,
-# •
+# • multi_thread(func, inputs, *, workers=None, **kwargs) -> list[Any] - Main executor
+# • multi_thread_standard(func, inputs, workers=4) -> list[Any] - Simple ordered helper
+# • kill_all_thread(exc_type=SystemExit, join_timeout=0.1) -> int - Emergency stop
 # • DEFAULT_WORKERS = (cpu_count * 2) - Default worker thread count
-# • T = TypeVar(
-# • _group_iter(src, size) -> Iterable[
+# • T = TypeVar('T'), R = TypeVar('R') - Generic type variables for input/output typing
+# • _group_iter(src, size) -> Iterable[list[T]] - Utility for chunking iterables
 # • _worker(item, func, fixed_kwargs) -> R - Individual worker function wrapper
-# •
+# • _ResultCollector - Maintains ordered/unordered result aggregation
 #
 # Invariants / Constraints:
 # • Worker count MUST be positive integer, defaults to (CPU cores * 2)
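
A minimal usage sketch of the public contract documented above (hypothetical example; the import path follows the file list at the top of this diff, and square is a demo function, not part of the package):

    from speedy_utils.multi_worker.thread import multi_thread

    def square(x: int, offset: int = 0) -> int:
        return x * x + offset

    # workers=None falls back to DEFAULT_WORKERS (cpu_count * 2); extra
    # keyword arguments such as offset are forwarded to every call of func.
    results = multi_thread(square, range(8), workers=4, progress=False, offset=1)
    assert results == [x * x + 1 for x in range(8)]
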
@@ -81,11 +82,12 @@ import ctypes
 import os
 import threading
 import time
-import
-from
-from
+from collections.abc import Callable, Iterable, Mapping, Sequence
+from concurrent.futures import FIRST_COMPLETED, Future, ThreadPoolExecutor, wait
+from heapq import heappop, heappush
 from itertools import islice
-from
+from types import MappingProxyType
+from typing import Any, Generic, TypeVar, cast
 
 from loguru import logger
 
@@ -97,10 +99,10 @@ except ImportError: # pragma: no cover
 # Sensible defaults
 DEFAULT_WORKERS = (os.cpu_count() or 4) * 2
 
-T = TypeVar(
-R = TypeVar(
+T = TypeVar('T')
+R = TypeVar('R')
 
-SPEEDY_RUNNING_THREADS: list[threading.Thread] = []
+SPEEDY_RUNNING_THREADS: list[threading.Thread] = []  # cooperative shutdown tracking
 _SPEEDY_THREADS_LOCK = threading.Lock()
 
 _PY_SET_ASYNC_EXC = ctypes.pythonapi.PyThreadState_SetAsyncExc
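
The default worker count doubles the visible core count and guards against os.cpu_count() returning None, as it can in some containers. A quick check of the arithmetic:

    import os

    # Mirrors DEFAULT_WORKERS above: 16 on an 8-core machine,
    # and (4 * 2) == 8 when os.cpu_count() reports None.
    print((os.cpu_count() or 4) * 2)
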
@@ -131,7 +133,7 @@ def _track_threads(threads: Iterable[threading.Thread]) -> None:
 
 
 def _track_executor_threads(pool: ThreadPoolExecutor) -> None:
-    thread_set = getattr(pool,
+    thread_set = getattr(pool, '_threads', None)
     if not thread_set:
         return
     _track_threads(tuple(thread_set))
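
getattr(pool, '_threads', None) leans on a CPython implementation detail: ThreadPoolExecutor keeps its workers in a private _threads set, hence the defensive getattr and the tuple(...) snapshot before iterating. A small illustration, assuming CPython's executor internals:

    from concurrent.futures import ThreadPoolExecutor

    pool = ThreadPoolExecutor(max_workers=2)
    pool.submit(lambda: None)
    # Private attribute; may be absent on other implementations, which is
    # why the code above falls back to None instead of raising.
    threads = getattr(pool, '_threads', None)
    print(type(threads).__name__, len(threads or ()))  # set 1
    pool.shutdown()
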
@@ -144,55 +146,167 @@ def _group_iter(src: Iterable[T], size: int) -> Iterable[list[T]]:
         yield chunk
 
 
-def
-
-
-
-
-
-
-def _worker(item: T, func: Callable[[T], R], fixed_kwargs: dict[str, Any]) -> R:
+def _worker(
+    item: T,
+    func: Callable[[T], R],
+    fixed_kwargs: Mapping[str, Any],
+) -> R:
     """Execute the function with an item and fixed kwargs."""
     return func(item, **fixed_kwargs)
 
 
+def _run_batch(
+    items: Sequence[T],
+    func: Callable[[T], R],
+    fixed_kwargs: Mapping[str, Any],
+) -> list[R]:
+    return [_worker(item, func, fixed_kwargs) for item in items]
+
+
+def _attach_metadata(fut: Future[Any], idx: int, logical_size: int) -> None:
+    setattr(fut, '_speedy_idx', idx)
+    setattr(fut, '_speedy_size', logical_size)
+
+
+def _future_meta(fut: Future[Any]) -> tuple[int, int]:
+    return (
+        getattr(fut, '_speedy_idx'),
+        getattr(fut, '_speedy_size'),
+    )
+
+
+class _ResultCollector(Generic[R]):
+    def __init__(self, ordered: bool, logical_total: int | None) -> None:
+        self._ordered = ordered
+        self._logical_total = logical_total
+        self._results: list[R | None]
+        self._heap: list[tuple[int, list[R | None]]] | None
+        self._next_idx = 0
+        if ordered and logical_total is not None:
+            self._results = [None] * logical_total
+            self._heap = None
+        else:
+            self._results = []
+            self._heap = [] if ordered else None
+
+    def add(self, idx: int, items: Sequence[R | None]) -> None:
+        if not items:
+            return
+        if self._ordered and self._logical_total is not None:
+            self._results[idx : idx + len(items)] = list(items)
+            return
+        if self._ordered:
+            assert self._heap is not None
+            heappush(self._heap, (idx, list(items)))
+            self._flush_ready()
+            return
+        self._results.extend(items)
+
+    def _flush_ready(self) -> None:
+        if self._heap is None:
+            return
+        while self._heap and self._heap[0][0] == self._next_idx:
+            _, chunk = heappop(self._heap)
+            self._results.extend(chunk)
+            self._next_idx += len(chunk)
+
+    def finalize(self) -> list[R | None]:
+        self._flush_ready()
+        return self._results
+
+
+def _resolve_worker_count(workers: int | None) -> int:
+    if workers is None:
+        return DEFAULT_WORKERS
+    if workers <= 0:
+        raise ValueError('workers must be a positive integer')
+    return workers
+
+
+def _normalize_batch_result(result: Any, logical_size: int) -> list[Any]:
+    if logical_size == 1:
+        return [result]
+    if result is None:
+        raise ValueError('batched callable returned None for a batch result')
+    if isinstance(result, (str, bytes, bytearray)):
+        raise TypeError('batched callable must not return str/bytes when batching')
+    if isinstance(result, Sequence):
+        out = list(result)
+    elif isinstance(result, Iterable):
+        out = list(result)
+    else:
+        raise TypeError('batched callable must return an iterable of results')
+    if len(out) != logical_size:
+        raise ValueError(
+            f'batched callable returned {len(out)} items, expected {logical_size}',
+        )
+    return out
+
+
+def _cancel_futures(inflight: set[Future[Any]]) -> None:
+    for fut in inflight:
+        fut.cancel()
+    inflight.clear()
+
+
 # ────────────────────────────────────────────────────────────
 # main API
 # ────────────────────────────────────────────────────────────
 def multi_thread(
-    func: Callable,
-    inputs: Iterable[
+    func: Callable[[T], R],
+    inputs: Iterable[T],
     *,
-    workers:
+    workers: int | None = DEFAULT_WORKERS,
     batch: int = 1,
     ordered: bool = True,
     progress: bool = True,
     progress_update: int = 10,
     prefetch_factor: int = 4,
-    timeout:
+    timeout: float | None = None,
     stop_on_error: bool = True,
-    n_proc=0,
-    store_output_pkl_file:
-    **fixed_kwargs,
-) -> list[
-    """
-
+    n_proc: int = 0,
+    store_output_pkl_file: str | None = None,
+    **fixed_kwargs: Any,
+) -> list[R | None]:
+    """Execute ``func`` over ``inputs`` using a managed thread pool.
+
+    The scheduler supports batching, ordered result delivery, progress
+    reporting, cooperative error handling, and a whole-run timeout.
 
     Parameters
     ----------
-    func
-
-
-
-
-
-
-
-
-
-
-
-
+    func : Callable[[T], R]
+        Target callable applied to each logical input.
+    inputs : Iterable[T]
+        Source iterable with input payloads.
+    workers : int | None, optional
+        Worker thread count (defaults to ``cpu_count()*2`` when ``None``).
+    batch : int, optional
+        Logical items grouped per invocation. ``1`` disables batching.
+    ordered : bool, optional
+        Preserve original ordering when ``True`` (default).
+    progress : bool, optional
+        Toggle tqdm-based progress reporting.
+    progress_update : int, optional
+        Minimum logical items between progress refreshes.
+    prefetch_factor : int, optional
+        Multiplier controlling in-flight items (``workers * prefetch_factor``).
+    timeout : float | None, optional
+        Overall wall-clock timeout in seconds.
+    stop_on_error : bool, optional
+        Abort immediately on the first exception when ``True``.
+    n_proc : int, optional
+        Optional process-level fan-out; ``>1`` shards work across processes.
+    store_output_pkl_file : str | None, optional
+        When provided, persist the results to disk via speedy_utils helpers.
+    fixed_kwargs : dict[str, Any]
+        Extra kwargs forwarded to every invocation of ``func``.
+
+    Returns
+    -------
+    list[R | None]
+        Collected results, preserving order when requested. Failed tasks yield
+        ``None`` entries if ``stop_on_error`` is ``False``.
     """
     from speedy_utils import dump_json_or_pickle, load_by_ext
 
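
The heap branch of _ResultCollector covers the ordered-but-unknown-length case: chunks that finish early are parked until every earlier index has arrived. A small walkthrough using the class exactly as added above (each idx is the starting logical position of its chunk):

    collector = _ResultCollector(ordered=True, logical_total=None)
    collector.add(2, ['c'])       # finishes first; parked on the heap
    collector.add(0, ['a', 'b'])  # index 0 arrives and flushes both chunks
    assert collector.finalize() == ['a', 'b', 'c']
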
@@ -201,215 +315,223 @@ def multi_thread(
 
         from fastcore.all import threaded
 
-
-
-
-
-        for i in range(0, len(
-            proc_inputs_list.append(inputs[i : i + n_per_proc])
+        items = list(inputs)
+        if not items:
+            return []
+        n_per_proc = max(len(items) // n_proc, 1)
+        chunks = [items[i : i + n_per_proc] for i in range(0, len(items), n_per_proc)]
         procs = []
         in_process_multi_thread = threaded(process=True)(multi_thread)
+        results: list[R | None] = []
 
-        for
-            with tempfile.NamedTemporaryFile(
-
-            ) as tmp_file:
-                file_pkl = tmp_file.name
+        for proc_idx, chunk in enumerate(chunks):
+            with tempfile.NamedTemporaryFile(delete=False, suffix='multi_thread.pkl') as fh:
+                file_pkl = fh.name
             assert isinstance(in_process_multi_thread, Callable)
             proc = in_process_multi_thread(
                 func,
-
+                chunk,
                 workers=workers,
                 batch=batch,
                 ordered=ordered,
-                progress=
+                progress=proc_idx == 0,
                 progress_update=progress_update,
                 prefetch_factor=prefetch_factor,
                 timeout=timeout,
                 stop_on_error=stop_on_error,
-                n_proc=0,
+                n_proc=0,
                 store_output_pkl_file=file_pkl,
                 **fixed_kwargs,
             )
-            procs.append(
-        # join
-        results = []
+            procs.append((proc, file_pkl))
 
         for proc, file_pkl in procs:
             proc.join()
-            logger.info(
-
+            logger.info('process finished: %s', proc)
+            try:
+                results.extend(load_by_ext(file_pkl))
+            finally:
+                try:
+                    os.unlink(file_pkl)
+                except OSError as exc:  # pragma: no cover - best effort cleanup
+                    logger.warning('failed to remove temp file %s: %s', file_pkl, exc)
         return results
 
     try:
         import pandas as pd
 
         if isinstance(inputs, pd.DataFrame):
-            inputs = inputs.to_dict(orient=
-    except ImportError:
+            inputs = cast(Iterable[T], inputs.to_dict(orient='records'))
+    except ImportError:  # pragma: no cover - optional dependency
         pass
 
+    if batch <= 0:
+        raise ValueError('batch must be a positive integer')
+    if prefetch_factor <= 0:
+        raise ValueError('prefetch_factor must be a positive integer')
+
+    workers_val = _resolve_worker_count(workers)
+    progress_update = max(progress_update, 1)
+    fixed_kwargs_map: Mapping[str, Any] = MappingProxyType(dict(fixed_kwargs))
+
     try:
-
-    except Exception:
-
-    workers_val = workers if workers is not None else DEFAULT_WORKERS
+        logical_total = len(inputs)  # type: ignore[arg-type]
+    except Exception:  # pragma: no cover - generic iterable
+        logical_total = None
 
-    if batch == 1 and
-        batch = 32
+    if batch == 1 and logical_total and logical_total / max(workers_val, 1) > 20_000:
+        batch = 32
 
-    # ── build (maybe‑batched) source iterator ────────────────────────────
     src_iter: Iterable[Any] = iter(inputs)
     if batch > 1:
         src_iter = _group_iter(src_iter, batch)
-    # Ensure src_iter is always an iterator
     src_iter = iter(src_iter)
+    collector: _ResultCollector[Any] = _ResultCollector(ordered, logical_total)
 
-    # total logical items (for bar & ordered pre‑allocation)
-    logical_total = n_inputs
-    if logical_total is not None and batch > 1:
-        logical_total = n_inputs  # still number of *items*, not batches
-
-    # ── progress bar ─────────────────────────────────────────────────────
     bar = None
     last_bar_update = 0
-    if
+    if (
+        progress
+        and tqdm is not None
+        and logical_total is not None
+        and logical_total > 0
+    ):
         bar = tqdm(
             total=logical_total,
             ncols=128,
-            colour=
-            bar_format=
-
+            colour='green',
+            bar_format=(
+                '{l_bar}{bar}| {n_fmt}/{total_fmt}'
+                ' [{elapsed}<{remaining}, {rate_fmt}{postfix}]'
+            ),
         )
 
-
-
-        results: list[Any] = [None] * logical_total
-    else:
-        results = []
-
-    # ── main execution loop ──────────────────────────────────────────────────
-    workers_val = workers if workers is not None else DEFAULT_WORKERS
-    max_inflight = workers_val * max(prefetch_factor, 1)
+    deadline = time.monotonic() + timeout if timeout is not None else None
+    max_inflight = max(workers_val * prefetch_factor, 1)
     completed_items = 0
-    next_logical_idx = 0
+    next_logical_idx = 0
 
-
-
+    def items_inflight() -> int:
+        return next_logical_idx - completed_items
 
-
-
-
-
-
-
+    inflight: set[Future[Any]] = set()
+    pool = ThreadPoolExecutor(
+        max_workers=workers_val,
+        thread_name_prefix='speedy-thread',
+    )
+    shutdown_kwargs: dict[str, Any] = {'wait': True}
+
+    try:
+        def submit_arg(arg: Any) -> None:
+            nonlocal next_logical_idx
             if batch > 1:
-
-
-
-                )
-
-                inflight.add(fut)
-                next_logical_idx += len(arg)
-                _track_executor_threads(pool)
+                batch_items = list(arg)
+                if not batch_items:
+                    return
+                fut = pool.submit(_run_batch, batch_items, func, fixed_kwargs_map)
+                logical_size = len(batch_items)
             else:
-                fut = pool.submit(_worker, arg, func,
-
-
-
-
+                fut = pool.submit(_worker, arg, func, fixed_kwargs_map)
+                logical_size = 1
+            _attach_metadata(fut, next_logical_idx, logical_size)
+            next_logical_idx += logical_size
+            inflight.add(fut)
+            _track_executor_threads(pool)
 
         try:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                if isinstance(out_items, Iterable)
-                else [out_items]
-            )
-
-            # store outputs
-            if ordered and logical_total is not None:
-                results[idx : idx + len(out_items)] = out_items
-            else:
-                results.extend(out_items)
+            while items_inflight() < max_inflight:
+                submit_arg(next(src_iter))
+        except StopIteration:
+            pass
+
+        while inflight:
+            wait_timeout = None
+            if deadline is not None:
+                remaining = deadline - time.monotonic()
+                if remaining <= 0:
+                    _cancel_futures(inflight)
+                    raise TimeoutError(
+                        f'multi_thread timed out after {timeout} seconds',
+                    )
+                wait_timeout = max(remaining, 0.0)
+
+            done, _ = wait(
+                inflight,
+                timeout=wait_timeout,
+                return_when=FIRST_COMPLETED,
+            )
 
-
+            if not done:
+                _cancel_futures(inflight)
+                raise TimeoutError(
+                    f'multi_thread timed out after {timeout} seconds',
+                )
 
-
-
-
+            for fut in done:
+                inflight.remove(fut)
+                idx, logical_size = _future_meta(fut)
+                try:
+                    result = fut.result()
+                except Exception as exc:
+                    if stop_on_error:
+                        _cancel_futures(inflight)
+                        raise
+                    logger.exception('multi_thread task failed', exc_info=exc)
+                    out_items = [None] * logical_size
+                else:
+                    try:
+                        out_items = _normalize_batch_result(result, logical_size)
+                    except Exception as exc:
+                        _cancel_futures(inflight)
+                        raise RuntimeError(
+                            'batched callable returned an unexpected shape',
+                        ) from exc
+
+                collector.add(idx, out_items)
+                completed_items += len(out_items)
+
+                if bar:
+                    delta = completed_items - last_bar_update
+                    if delta >= progress_update:
+                        bar.update(delta)
                         last_bar_update = completed_items
-                    # Show pending, submitted, processing in the bar postfix
                         submitted = next_logical_idx
-                    processing = min(len(inflight), workers_val)
                         pending = (
-                        (logical_total - submitted)
+                            max(logical_total - submitted, 0)
                             if logical_total is not None
-                        else
+                            else '-'
                         )
                         postfix = {
-
-
-                        "processing": processing,
+                            'processing': min(len(inflight), workers_val),
+                            'pending': pending,
                         }
                         bar.set_postfix(postfix)
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            # Break the inner loop as we've processed one future
-            break
-
-        # If we've exhausted the inner loop without processing anything,
-        # and there are still in-flight tasks, we need to wait for them
-        if inflight and timeout is not None:
-            # Use a small timeout to avoid hanging indefinitely
-            time.sleep(min(0.01, timeout / 10))
-
-    finally:
-        if bar:
-            bar.update(completed_items - last_bar_update)
-            bar.close()
+            try:
+                while items_inflight() < max_inflight:
+                    submit_arg(next(src_iter))
+            except StopIteration:
+                pass
+
+        results = collector.finalize()
+
+    except KeyboardInterrupt:
+        shutdown_kwargs = {'wait': False, 'cancel_futures': True}
+        _cancel_futures(inflight)
+        kill_all_thread(SystemExit)
+        raise KeyboardInterrupt() from None
+    finally:
+        try:
+            pool.shutdown(**shutdown_kwargs)
+        except TypeError:  # pragma: no cover - Python <3.9 fallback
+            pool.shutdown(shutdown_kwargs.get('wait', True))
+        if bar:
+            delta = completed_items - last_bar_update
+            if delta > 0:
+                bar.update(delta)
+            bar.close()
+
+    results = collector.finalize() if 'results' not in locals() else results
     if store_output_pkl_file:
         dump_json_or_pickle(results, store_output_pkl_file)
     _prune_dead_threads()
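
When stop_on_error is disabled, the error path above substitutes None placeholders for a failed task's logical items, so the output length still matches the input. A hypothetical sketch of that contract:

    def flaky(x: int) -> int:
        if x == 3:
            raise ValueError('boom')
        return x

    out = multi_thread(flaky, range(5), stop_on_error=False, progress=False)
    assert out == [0, 1, 2, None, 4]
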
@@ -417,32 +539,19 @@ def multi_thread(
 
 
 def multi_thread_standard(
-    fn: Callable[[
-) -> list[
-    """Execute
-
-
-
-
-
-
-
-        The function to execute for each item.
-    items : Iterable
-        The items to process.
-    workers : int, optional
-        Number of worker threads, by default 4.
-
-    Returns
-    -------
-    list
-        Results in same order as input items.
-    """
-    with ThreadPoolExecutor(max_workers=workers) as executor:
-        futures = []
+    fn: Callable[[T], R], items: Iterable[T], workers: int = 4
+) -> list[R]:
+    """Execute ``fn`` across ``items`` while preserving submission order."""
+
+    workers_val = _resolve_worker_count(workers)
+    with ThreadPoolExecutor(
+        max_workers=workers_val,
+        thread_name_prefix='speedy-thread',
+    ) as executor:
+        futures: list[Future[R]] = []
         for item in items:
             futures.append(executor.submit(fn, item))
-
+        _track_executor_threads(executor)
         results = [fut.result() for fut in futures]
         _prune_dead_threads()
         return results
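
multi_thread_standard keeps the simple path: submit everything up front, then collect fut.result() in submission order, so outputs line up with inputs even when later items finish first. For example:

    from speedy_utils.multi_worker.thread import multi_thread_standard

    names = ['ada', 'grace', 'alan']
    assert multi_thread_standard(str.upper, names, workers=2) == ['ADA', 'GRACE', 'ALAN']
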
@@ -452,18 +561,24 @@ def _async_raise(thread_id: int, exc_type: type[BaseException]) -> bool:
     if thread_id <= 0:
         return False
     if not issubclass(exc_type, BaseException):
-        raise TypeError(
+        raise TypeError('exc_type must derive from BaseException')
     res = _PY_SET_ASYNC_EXC(ctypes.c_ulong(thread_id), ctypes.py_object(exc_type))
     if res == 0:
         return False
     if res > 1:  # pragma: no cover - defensive branch
         _PY_SET_ASYNC_EXC(ctypes.c_ulong(thread_id), None)
-        raise SystemError(
+        raise SystemError('PyThreadState_SetAsyncExc failed')
     return True
 
 
 def kill_all_thread(exc_type: type[BaseException] = SystemExit, join_timeout: float = 0.1) -> int:
-    """Forcefully stop tracked worker threads
+    """Forcefully stop tracked worker threads (dangerous; use sparingly).
+
+    Returns
+    -------
+    int
+        Count of threads signalled for termination.
+    """
     _prune_dead_threads()
     current = threading.current_thread()
     with _SPEEDY_THREADS_LOCK:
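
_async_raise wraps PyThreadState_SetAsyncExc, which schedules an exception inside another thread the next time it executes Python bytecode; a thread blocked in a C call will not observe it, which is why the docstring flags this as dangerous. A last-resort usage sketch:

    from speedy_utils.multi_worker.thread import kill_all_thread

    # Best effort: returns how many tracked threads were signalled; threads
    # stuck inside C extensions may never see the injected SystemExit.
    stopped = kill_all_thread()
    print(f'signalled {stopped} worker thread(s)')
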
@@ -481,16 +596,16 @@ def kill_all_thread(exc_type: type[BaseException] = SystemExit, join_timeout: fl
                 terminated += 1
                 thread.join(timeout=join_timeout)
             else:
-                logger.warning(
+                logger.warning('Unable to signal thread %s', thread.name)
         except Exception as exc:  # pragma: no cover - defensive
-            logger.error(
+            logger.error('Failed to stop thread %s: %s', thread.name, exc)
     _prune_dead_threads()
     return terminated
 
 
 __all__ = [
-
-
-
-
+    'SPEEDY_RUNNING_THREADS',
+    'multi_thread',
+    'multi_thread_standard',
+    'kill_all_thread',
 ]