speedy-utils 1.1.21__py3-none-any.whl → 1.1.23__py3-none-any.whl

This diff shows the content of publicly released package versions as published to their respective public registries and is provided for informational purposes only.
@@ -14,13 +14,14 @@
14
14
  # making it suitable for data processing pipelines, batch operations, and concurrent API calls.
15
15
  #
16
16
  # Public API / Data Contracts:
17
- # • multi_thread(func, inputs, num_workers=None, progress=True, **kwargs) -> List[Any] - Main parallel execution
18
- # • multi_thread_batch(func, inputs, batch_size=10, num_workers=None, **kwargs) -> List[Any] - Batched processing
17
+ # • multi_thread(func, inputs, *, workers=None, **kwargs) -> list[Any] - Main executor
18
+ # • multi_thread_standard(func, inputs, workers=4) -> list[Any] - Simple ordered helper
19
+ # • kill_all_thread(exc_type=SystemExit, join_timeout=0.1) -> int - Emergency stop
19
20
  # • DEFAULT_WORKERS = (cpu_count * 2) - Default worker thread count
20
- # • T = TypeVar("T"), R = TypeVar("R") - Generic type variables for input/output typing
21
- # • _group_iter(src, size) -> Iterable[List[T]] - Utility for chunking iterables
21
+ # • T = TypeVar('T'), R = TypeVar('R') - Generic type variables for input/output typing
22
+ # • _group_iter(src, size) -> Iterable[list[T]] - Utility for chunking iterables
22
23
  # • _worker(item, func, fixed_kwargs) -> R - Individual worker function wrapper
23
- # • _short_tb() -> str - Shortened traceback formatter for cleaner error logs
24
+ # • _ResultCollector - Maintains ordered/unordered result aggregation
24
25
  #
25
26
  # Invariants / Constraints:
26
27
  # • Worker count MUST be positive integer, defaults to (CPU cores * 2)
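For orientation, the reworked public API reads roughly as follows. This is a sketch built only from the signatures listed in the header comments above; the top-level import path is an assumption, and kill_all_thread is omitted because it force-stops tracked threads.

from speedy_utils import multi_thread, multi_thread_standard  # import path assumed

def square(x: int) -> int:
    return x * x

ordered = multi_thread(square, range(10), workers=8)           # main executor
simple = multi_thread_standard(square, [1, 2, 3], workers=4)   # ordered helper
assert ordered[:3] == [0, 1, 4] and simple == [1, 4, 9]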
@@ -81,11 +82,12 @@ import ctypes
81
82
  import os
82
83
  import threading
83
84
  import time
84
- import traceback
85
- from collections.abc import Callable, Iterable
86
- from concurrent.futures import ThreadPoolExecutor, as_completed
85
+ from collections.abc import Callable, Iterable, Mapping, Sequence
86
+ from concurrent.futures import FIRST_COMPLETED, Future, ThreadPoolExecutor, wait
87
+ from heapq import heappop, heappush
87
88
  from itertools import islice
88
- from typing import Any, TypeVar, Union
89
+ from types import MappingProxyType
90
+ from typing import Any, Generic, TypeVar, cast
89
91
 
90
92
  from loguru import logger
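Two of the new imports carry the design change in this release: wait() with FIRST_COMPLETED replaces as_completed so the scheduler can top up the queue after every completed wave, and MappingProxyType freezes the kwargs shared across workers. A small illustration of both, independent of the package:

from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait
from types import MappingProxyType

fixed = MappingProxyType({'scale': 10})   # read-only view shared by all workers

def job(x: int, *, scale: int) -> int:
    return x * scale

with ThreadPoolExecutor(max_workers=2) as pool:
    inflight = {pool.submit(job, x, **fixed) for x in range(4)}
    results = []
    while inflight:
        done, inflight = wait(inflight, return_when=FIRST_COMPLETED)
        results.extend(f.result() for f in done)

assert sorted(results) == [0, 10, 20, 30]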
91
93
 
@@ -97,10 +99,10 @@ except ImportError: # pragma: no cover
97
99
  # Sensible defaults
98
100
  DEFAULT_WORKERS = (os.cpu_count() or 4) * 2
99
101
 
100
- T = TypeVar("T")
101
- R = TypeVar("R")
102
+ T = TypeVar('T')
103
+ R = TypeVar('R')
102
104
 
103
- SPEEDY_RUNNING_THREADS: list[threading.Thread] = []
105
+ SPEEDY_RUNNING_THREADS: list[threading.Thread] = [] # cooperative shutdown tracking
104
106
  _SPEEDY_THREADS_LOCK = threading.Lock()
105
107
 
106
108
  _PY_SET_ASYNC_EXC = ctypes.pythonapi.PyThreadState_SetAsyncExc
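The module-level registry set up here (SPEEDY_RUNNING_THREADS plus _SPEEDY_THREADS_LOCK) is what kill_all_thread later walks. A self-contained sketch of that tracking pattern follows; the signatures mirror the diff, but the helper bodies are assumptions since they sit outside the changed hunks.

import os
import threading
from collections.abc import Iterable

DEFAULT_WORKERS = (os.cpu_count() or 4) * 2   # same default as above

_RUNNING: list[threading.Thread] = []         # stands in for SPEEDY_RUNNING_THREADS
_LOCK = threading.Lock()                      # stands in for _SPEEDY_THREADS_LOCK

def _track_threads(threads: Iterable[threading.Thread]) -> None:
    # Register pool threads so they can be signalled later.
    with _LOCK:
        for t in threads:
            if t not in _RUNNING:
                _RUNNING.append(t)

def _prune_dead_threads() -> None:
    # Drop threads that have already finished.
    with _LOCK:
        _RUNNING[:] = [t for t in _RUNNING if t.is_alive()]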
@@ -131,7 +133,7 @@ def _track_threads(threads: Iterable[threading.Thread]) -> None:
131
133
 
132
134
 
133
135
  def _track_executor_threads(pool: ThreadPoolExecutor) -> None:
134
- thread_set = getattr(pool, "_threads", None)
136
+ thread_set = getattr(pool, '_threads', None)
135
137
  if not thread_set:
136
138
  return
137
139
  _track_threads(tuple(thread_set))
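Only the signature of _group_iter and its trailing yield appear in this diff; a plausible reconstruction using itertools.islice (which the module imports) looks like this, shown purely for illustration:

from collections.abc import Iterable
from itertools import islice
from typing import TypeVar

T = TypeVar('T')

def _group_iter(src: Iterable[T], size: int) -> Iterable[list[T]]:
    # Yield consecutive chunks of `size` items until the source is exhausted.
    it = iter(src)
    while True:
        chunk = list(islice(it, size))
        if not chunk:
            return
        yield chunk

assert list(_group_iter(range(7), 3)) == [[0, 1, 2], [3, 4, 5], [6]]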
@@ -144,55 +146,167 @@ def _group_iter(src: Iterable[T], size: int) -> Iterable[list[T]]:
144
146
  yield chunk
145
147
 
146
148
 
147
- def _short_tb() -> str:
148
- """Return a shortened traceback, excluding internal frames."""
149
- tb = "".join(traceback.format_exc())
150
- # hide frames inside this helper to keep logs short
151
- return "\n".join(ln for ln in tb.splitlines() if "multi_thread.py" not in ln)
152
-
153
-
154
- def _worker(item: T, func: Callable[[T], R], fixed_kwargs: dict[str, Any]) -> R:
149
+ def _worker(
150
+ item: T,
151
+ func: Callable[[T], R],
152
+ fixed_kwargs: Mapping[str, Any],
153
+ ) -> R:
155
154
  """Execute the function with an item and fixed kwargs."""
156
155
  return func(item, **fixed_kwargs)
157
156
 
158
157
 
158
+ def _run_batch(
159
+ items: Sequence[T],
160
+ func: Callable[[T], R],
161
+ fixed_kwargs: Mapping[str, Any],
162
+ ) -> list[R]:
163
+ return [_worker(item, func, fixed_kwargs) for item in items]
164
+
165
+
166
+ def _attach_metadata(fut: Future[Any], idx: int, logical_size: int) -> None:
167
+ setattr(fut, '_speedy_idx', idx)
168
+ setattr(fut, '_speedy_size', logical_size)
169
+
170
+
171
+ def _future_meta(fut: Future[Any]) -> tuple[int, int]:
172
+ return (
173
+ getattr(fut, '_speedy_idx'),
174
+ getattr(fut, '_speedy_size'),
175
+ )
176
+
177
+
178
+ class _ResultCollector(Generic[R]):
179
+ def __init__(self, ordered: bool, logical_total: int | None) -> None:
180
+ self._ordered = ordered
181
+ self._logical_total = logical_total
182
+ self._results: list[R | None]
183
+ self._heap: list[tuple[int, list[R | None]]] | None
184
+ self._next_idx = 0
185
+ if ordered and logical_total is not None:
186
+ self._results = [None] * logical_total
187
+ self._heap = None
188
+ else:
189
+ self._results = []
190
+ self._heap = [] if ordered else None
191
+
192
+ def add(self, idx: int, items: Sequence[R | None]) -> None:
193
+ if not items:
194
+ return
195
+ if self._ordered and self._logical_total is not None:
196
+ self._results[idx : idx + len(items)] = list(items)
197
+ return
198
+ if self._ordered:
199
+ assert self._heap is not None
200
+ heappush(self._heap, (idx, list(items)))
201
+ self._flush_ready()
202
+ return
203
+ self._results.extend(items)
204
+
205
+ def _flush_ready(self) -> None:
206
+ if self._heap is None:
207
+ return
208
+ while self._heap and self._heap[0][0] == self._next_idx:
209
+ _, chunk = heappop(self._heap)
210
+ self._results.extend(chunk)
211
+ self._next_idx += len(chunk)
212
+
213
+ def finalize(self) -> list[R | None]:
214
+ self._flush_ready()
215
+ return self._results
216
+
217
+
218
+ def _resolve_worker_count(workers: int | None) -> int:
219
+ if workers is None:
220
+ return DEFAULT_WORKERS
221
+ if workers <= 0:
222
+ raise ValueError('workers must be a positive integer')
223
+ return workers
224
+
225
+
226
+ def _normalize_batch_result(result: Any, logical_size: int) -> list[Any]:
227
+ if logical_size == 1:
228
+ return [result]
229
+ if result is None:
230
+ raise ValueError('batched callable returned None for a batch result')
231
+ if isinstance(result, (str, bytes, bytearray)):
232
+ raise TypeError('batched callable must not return str/bytes when batching')
233
+ if isinstance(result, Sequence):
234
+ out = list(result)
235
+ elif isinstance(result, Iterable):
236
+ out = list(result)
237
+ else:
238
+ raise TypeError('batched callable must return an iterable of results')
239
+ if len(out) != logical_size:
240
+ raise ValueError(
241
+ f'batched callable returned {len(out)} items, expected {logical_size}',
242
+ )
243
+ return out
244
+
245
+
246
+ def _cancel_futures(inflight: set[Future[Any]]) -> None:
247
+ for fut in inflight:
248
+ fut.cancel()
249
+ inflight.clear()
250
+
251
+
159
252
  # ────────────────────────────────────────────────────────────
160
253
  # main API
161
254
  # ────────────────────────────────────────────────────────────
162
255
  def multi_thread(
163
- func: Callable,
164
- inputs: Iterable[Any],
256
+ func: Callable[[T], R],
257
+ inputs: Iterable[T],
165
258
  *,
166
- workers: Union[int, None] = DEFAULT_WORKERS,
259
+ workers: int | None = DEFAULT_WORKERS,
167
260
  batch: int = 1,
168
261
  ordered: bool = True,
169
262
  progress: bool = True,
170
263
  progress_update: int = 10,
171
264
  prefetch_factor: int = 4,
172
- timeout: Union[float, None] = None,
265
+ timeout: float | None = None,
173
266
  stop_on_error: bool = True,
174
- n_proc=0,
175
- store_output_pkl_file: Union[str, None] = None,
176
- **fixed_kwargs,
177
- ) -> list[Any]:
178
- """
179
- ThreadPool **map** that returns a *list*.
267
+ n_proc: int = 0,
268
+ store_output_pkl_file: str | None = None,
269
+ **fixed_kwargs: Any,
270
+ ) -> list[R | None]:
271
+ """Execute ``func`` over ``inputs`` using a managed thread pool.
272
+
273
+ The scheduler supports batching, ordered result delivery, progress
274
+ reporting, cooperative error handling, and a whole-run timeout.
180
275
 
181
276
  Parameters
182
277
  ----------
183
- func – target callable.
184
- inputs – iterable with the arguments.
185
- workers – defaults to ``os.cpu_count()*2``.
186
- batch – package *batch* inputs into one call for low‑overhead.
187
- ordered – keep original order (costs memory); if ``False`` results
188
- are yielded as soon as they finish.
189
- progress – show a tqdm bar (requires *tqdm* installed).
190
- progress_update – bar redraw frequency (logical items, *not* batches).
191
- prefetch_factor – in‑flight tasks ≈ ``workers * prefetch_factor``.
192
- timeout – overall timeout (seconds) for the mapping.
193
- stop_on_error – raise immediately on first exception (default). If
194
- ``False`` the failing task’s result becomes ``None``.
195
- **fixed_kwargs – static keyword args forwarded to every ``func()`` call.
278
+ func : Callable[[T], R]
279
+ Target callable applied to each logical input.
280
+ inputs : Iterable[T]
281
+ Source iterable with input payloads.
282
+ workers : int | None, optional
283
+ Worker thread count (defaults to ``cpu_count()*2`` when ``None``).
284
+ batch : int, optional
285
+ Logical items grouped per invocation. ``1`` disables batching.
286
+ ordered : bool, optional
287
+ Preserve original ordering when ``True`` (default).
288
+ progress : bool, optional
289
+ Toggle tqdm-based progress reporting.
290
+ progress_update : int, optional
291
+ Minimum logical items between progress refreshes.
292
+ prefetch_factor : int, optional
293
+ Multiplier controlling in-flight items (``workers * prefetch_factor``).
294
+ timeout : float | None, optional
295
+ Overall wall-clock timeout in seconds.
296
+ stop_on_error : bool, optional
297
+ Abort immediately on the first exception when ``True``.
298
+ n_proc : int, optional
299
+ Optional process-level fan-out; ``>1`` shards work across processes.
300
+ store_output_pkl_file : str | None, optional
301
+ When provided, persist the results to disk via speedy_utils helpers.
302
+ fixed_kwargs : dict[str, Any]
303
+ Extra kwargs forwarded to every invocation of ``func``.
304
+
305
+ Returns
306
+ -------
307
+ list[R | None]
308
+ Collected results, preserving order when requested. Failed tasks yield
309
+ ``None`` entries if ``stop_on_error`` is ``False``.
196
310
  """
197
311
  from speedy_utils import dump_json_or_pickle, load_by_ext
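The _ResultCollector added in this hunk handles the ordered case with an unknown total by buffering out-of-order chunks in a min-heap keyed by submission index and flushing whenever the next expected index reaches the top. A compressed illustration of that flush rule (not the shipped class):

from heapq import heappop, heappush

heap: list[tuple[int, list[str]]] = []
out: list[str] = []
next_idx = 0

def add(idx: int, chunk: list[str]) -> None:
    # Buffer out-of-order chunks; emit everything that is now contiguous.
    global next_idx
    heappush(heap, (idx, chunk))
    while heap and heap[0][0] == next_idx:
        _, ready = heappop(heap)
        out.extend(ready)
        next_idx += len(ready)

add(2, ['c'])        # arrives early, stays buffered
add(0, ['a', 'b'])   # indices 0-1 flush, then the buffered index 2 follows
assert out == ['a', 'b', 'c']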
198
312
 
@@ -201,215 +315,223 @@ def multi_thread(
201
315
 
202
316
  from fastcore.all import threaded
203
317
 
204
- # split the inputs by nproc
205
- inputs = list(inputs)
206
- n_per_proc = max(len(inputs) // n_proc, 1)
207
- proc_inputs_list = []
208
- for i in range(0, len(inputs), n_per_proc):
209
- proc_inputs_list.append(inputs[i : i + n_per_proc])
318
+ items = list(inputs)
319
+ if not items:
320
+ return []
321
+ n_per_proc = max(len(items) // n_proc, 1)
322
+ chunks = [items[i : i + n_per_proc] for i in range(0, len(items), n_per_proc)]
210
323
  procs = []
211
324
  in_process_multi_thread = threaded(process=True)(multi_thread)
325
+ results: list[R | None] = []
212
326
 
213
- for proc_id, proc_inputs in enumerate(proc_inputs_list):
214
- with tempfile.NamedTemporaryFile(
215
- delete=False, suffix="multi_thread.pkl"
216
- ) as tmp_file:
217
- file_pkl = tmp_file.name
327
+ for proc_idx, chunk in enumerate(chunks):
328
+ with tempfile.NamedTemporaryFile(delete=False, suffix='multi_thread.pkl') as fh:
329
+ file_pkl = fh.name
218
330
  assert isinstance(in_process_multi_thread, Callable)
219
331
  proc = in_process_multi_thread(
220
332
  func,
221
- proc_inputs,
333
+ chunk,
222
334
  workers=workers,
223
335
  batch=batch,
224
336
  ordered=ordered,
225
- progress=proc_id == 0,
337
+ progress=proc_idx == 0,
226
338
  progress_update=progress_update,
227
339
  prefetch_factor=prefetch_factor,
228
340
  timeout=timeout,
229
341
  stop_on_error=stop_on_error,
230
- n_proc=0, # prevent recursion
342
+ n_proc=0,
231
343
  store_output_pkl_file=file_pkl,
232
344
  **fixed_kwargs,
233
345
  )
234
- procs.append([proc, file_pkl])
235
- # join
236
- results = []
346
+ procs.append((proc, file_pkl))
237
347
 
238
348
  for proc, file_pkl in procs:
239
349
  proc.join()
240
- logger.info(f"Done proc {proc=}")
241
- results.extend(load_by_ext(file_pkl))
350
+ logger.info('process finished: %s', proc)
351
+ try:
352
+ results.extend(load_by_ext(file_pkl))
353
+ finally:
354
+ try:
355
+ os.unlink(file_pkl)
356
+ except OSError as exc: # pragma: no cover - best effort cleanup
357
+ logger.warning('failed to remove temp file %s: %s', file_pkl, exc)
242
358
  return results
243
359
 
244
360
  try:
245
361
  import pandas as pd
246
362
 
247
363
  if isinstance(inputs, pd.DataFrame):
248
- inputs = inputs.to_dict(orient="records")
249
- except ImportError:
364
+ inputs = cast(Iterable[T], inputs.to_dict(orient='records'))
365
+ except ImportError: # pragma: no cover - optional dependency
250
366
  pass
251
367
 
368
+ if batch <= 0:
369
+ raise ValueError('batch must be a positive integer')
370
+ if prefetch_factor <= 0:
371
+ raise ValueError('prefetch_factor must be a positive integer')
372
+
373
+ workers_val = _resolve_worker_count(workers)
374
+ progress_update = max(progress_update, 1)
375
+ fixed_kwargs_map: Mapping[str, Any] = MappingProxyType(dict(fixed_kwargs))
376
+
252
377
  try:
253
- n_inputs = len(inputs) # type: ignore[arg-type]
254
- except Exception:
255
- n_inputs = None
256
- workers_val = workers if workers is not None else DEFAULT_WORKERS
378
+ logical_total = len(inputs) # type: ignore[arg-type]
379
+ except Exception: # pragma: no cover - generic iterable
380
+ logical_total = None
257
381
 
258
- if batch == 1 and n_inputs and n_inputs / max(workers_val, 1) > 20_000:
259
- batch = 32 # empirically good for sub‑ms tasks
382
+ if batch == 1 and logical_total and logical_total / max(workers_val, 1) > 20_000:
383
+ batch = 32
260
384
 
261
- # ── build (maybe‑batched) source iterator ────────────────────────────
262
385
  src_iter: Iterable[Any] = iter(inputs)
263
386
  if batch > 1:
264
387
  src_iter = _group_iter(src_iter, batch)
265
- # Ensure src_iter is always an iterator
266
388
  src_iter = iter(src_iter)
389
+ collector: _ResultCollector[Any] = _ResultCollector(ordered, logical_total)
267
390
 
268
- # total logical items (for bar & ordered pre‑allocation)
269
- logical_total = n_inputs
270
- if logical_total is not None and batch > 1:
271
- logical_total = n_inputs # still number of *items*, not batches
272
-
273
- # ── progress bar ─────────────────────────────────────────────────────
274
391
  bar = None
275
392
  last_bar_update = 0
276
- if progress and tqdm is not None and logical_total is not None:
393
+ if (
394
+ progress
395
+ and tqdm is not None
396
+ and logical_total is not None
397
+ and logical_total > 0
398
+ ):
277
399
  bar = tqdm(
278
400
  total=logical_total,
279
401
  ncols=128,
280
- colour="green",
281
- bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt}"
282
- " [{elapsed}<{remaining}, {rate_fmt}{postfix}]",
402
+ colour='green',
403
+ bar_format=(
404
+ '{l_bar}{bar}| {n_fmt}/{total_fmt}'
405
+ ' [{elapsed}<{remaining}, {rate_fmt}{postfix}]'
406
+ ),
283
407
  )
284
408
 
285
- # ── prepare result container ─────────────────────────────────────────
286
- if ordered and logical_total is not None:
287
- results: list[Any] = [None] * logical_total
288
- else:
289
- results = []
290
-
291
- # ── main execution loop ──────────────────────────────────────────────────
292
- workers_val = workers if workers is not None else DEFAULT_WORKERS
293
- max_inflight = workers_val * max(prefetch_factor, 1)
409
+ deadline = time.monotonic() + timeout if timeout is not None else None
410
+ max_inflight = max(workers_val * prefetch_factor, 1)
294
411
  completed_items = 0
295
- next_logical_idx = 0 # index assigned to the next submission
412
+ next_logical_idx = 0
296
413
 
297
- with ThreadPoolExecutor(max_workers=workers) as pool:
298
- inflight = set()
414
+ def items_inflight() -> int:
415
+ return next_logical_idx - completed_items
299
416
 
300
- # prime the pool
301
- for _ in range(max_inflight):
302
- try:
303
- arg = next(src_iter)
304
- except StopIteration:
305
- break
417
+ inflight: set[Future[Any]] = set()
418
+ pool = ThreadPoolExecutor(
419
+ max_workers=workers_val,
420
+ thread_name_prefix='speedy-thread',
421
+ )
422
+ shutdown_kwargs: dict[str, Any] = {'wait': True}
423
+
424
+ try:
425
+ def submit_arg(arg: Any) -> None:
426
+ nonlocal next_logical_idx
306
427
  if batch > 1:
307
- fut = pool.submit(
308
- lambda items: [_worker(item, func, fixed_kwargs) for item in items],
309
- arg,
310
- )
311
- fut.idx = next_logical_idx # type: ignore[attr-defined]
312
- inflight.add(fut)
313
- next_logical_idx += len(arg)
314
- _track_executor_threads(pool)
428
+ batch_items = list(arg)
429
+ if not batch_items:
430
+ return
431
+ fut = pool.submit(_run_batch, batch_items, func, fixed_kwargs_map)
432
+ logical_size = len(batch_items)
315
433
  else:
316
- fut = pool.submit(_worker, arg, func, fixed_kwargs)
317
- fut.idx = next_logical_idx # type: ignore[attr-defined]
318
- inflight.add(fut)
319
- next_logical_idx += 1
320
- _track_executor_threads(pool)
434
+ fut = pool.submit(_worker, arg, func, fixed_kwargs_map)
435
+ logical_size = 1
436
+ _attach_metadata(fut, next_logical_idx, logical_size)
437
+ next_logical_idx += logical_size
438
+ inflight.add(fut)
439
+ _track_executor_threads(pool)
321
440
 
322
441
  try:
323
- # Process futures as they complete and add new ones to keep the pool busy
324
- while inflight: # Continue until all in-flight tasks are processed
325
- for fut in as_completed(inflight, timeout=timeout):
326
- inflight.remove(fut)
327
- idx = fut.idx # type: ignore[attr-defined]
328
- try:
329
- res = fut.result()
330
- except Exception:
331
- if stop_on_error:
332
- raise
333
- res = None
334
-
335
- # flatten res to list of logical outputs
336
- out_items = res if batch > 1 else [res]
337
-
338
- # Ensure out_items is a list (and thus Sized)
339
- if out_items is None:
340
- out_items = [None]
341
- elif not isinstance(out_items, list):
342
- out_items = (
343
- list(out_items)
344
- if isinstance(out_items, Iterable)
345
- else [out_items]
346
- )
347
-
348
- # store outputs
349
- if ordered and logical_total is not None:
350
- results[idx : idx + len(out_items)] = out_items
351
- else:
352
- results.extend(out_items)
442
+ while items_inflight() < max_inflight:
443
+ submit_arg(next(src_iter))
444
+ except StopIteration:
445
+ pass
446
+
447
+ while inflight:
448
+ wait_timeout = None
449
+ if deadline is not None:
450
+ remaining = deadline - time.monotonic()
451
+ if remaining <= 0:
452
+ _cancel_futures(inflight)
453
+ raise TimeoutError(
454
+ f'multi_thread timed out after {timeout} seconds',
455
+ )
456
+ wait_timeout = max(remaining, 0.0)
457
+
458
+ done, _ = wait(
459
+ inflight,
460
+ timeout=wait_timeout,
461
+ return_when=FIRST_COMPLETED,
462
+ )
353
463
 
354
- completed_items += len(out_items)
464
+ if not done:
465
+ _cancel_futures(inflight)
466
+ raise TimeoutError(
467
+ f'multi_thread timed out after {timeout} seconds',
468
+ )
355
469
 
356
- # progress bar update
357
- if bar and completed_items - last_bar_update >= progress_update:
358
- bar.update(completed_items - last_bar_update)
470
+ for fut in done:
471
+ inflight.remove(fut)
472
+ idx, logical_size = _future_meta(fut)
473
+ try:
474
+ result = fut.result()
475
+ except Exception as exc:
476
+ if stop_on_error:
477
+ _cancel_futures(inflight)
478
+ raise
479
+ logger.exception('multi_thread task failed', exc_info=exc)
480
+ out_items = [None] * logical_size
481
+ else:
482
+ try:
483
+ out_items = _normalize_batch_result(result, logical_size)
484
+ except Exception as exc:
485
+ _cancel_futures(inflight)
486
+ raise RuntimeError(
487
+ 'batched callable returned an unexpected shape',
488
+ ) from exc
489
+
490
+ collector.add(idx, out_items)
491
+ completed_items += len(out_items)
492
+
493
+ if bar:
494
+ delta = completed_items - last_bar_update
495
+ if delta >= progress_update:
496
+ bar.update(delta)
359
497
  last_bar_update = completed_items
360
- # Show pending, submitted, processing in the bar postfix
361
498
  submitted = next_logical_idx
362
- processing = min(len(inflight), workers_val)
363
499
  pending = (
364
- (logical_total - submitted)
500
+ max(logical_total - submitted, 0)
365
501
  if logical_total is not None
366
- else None
502
+ else '-'
367
503
  )
368
504
  postfix = {
369
- "pending": pending if pending is not None else "-",
370
- # 'submitted': submitted,
371
- "processing": processing,
505
+ 'processing': min(len(inflight), workers_val),
506
+ 'pending': pending,
372
507
  }
373
508
  bar.set_postfix(postfix)
374
509
 
375
- # keep queue full
376
- try:
377
- while next_logical_idx - completed_items < max_inflight:
378
- arg = next(src_iter)
379
- if batch > 1:
380
- fut2 = pool.submit(
381
- lambda items: [
382
- _worker(item, func, fixed_kwargs)
383
- for item in items
384
- ],
385
- arg,
386
- )
387
- fut2.idx = next_logical_idx # type: ignore[attr-defined]
388
- inflight.add(fut2)
389
- next_logical_idx += len(arg)
390
- _track_executor_threads(pool)
391
- else:
392
- fut2 = pool.submit(_worker, arg, func, fixed_kwargs)
393
- fut2.idx = next_logical_idx # type: ignore[attr-defined]
394
- inflight.add(fut2)
395
- next_logical_idx += 1
396
- _track_executor_threads(pool)
397
- except StopIteration:
398
- pass
399
-
400
- # Break the inner loop as we've processed one future
401
- break
402
-
403
- # If we've exhausted the inner loop without processing anything,
404
- # and there are still in-flight tasks, we need to wait for them
405
- if inflight and timeout is not None:
406
- # Use a small timeout to avoid hanging indefinitely
407
- time.sleep(min(0.01, timeout / 10))
408
-
409
- finally:
410
- if bar:
411
- bar.update(completed_items - last_bar_update)
412
- bar.close()
510
+ try:
511
+ while items_inflight() < max_inflight:
512
+ submit_arg(next(src_iter))
513
+ except StopIteration:
514
+ pass
515
+
516
+ results = collector.finalize()
517
+
518
+ except KeyboardInterrupt:
519
+ shutdown_kwargs = {'wait': False, 'cancel_futures': True}
520
+ _cancel_futures(inflight)
521
+ kill_all_thread(SystemExit)
522
+ raise KeyboardInterrupt() from None
523
+ finally:
524
+ try:
525
+ pool.shutdown(**shutdown_kwargs)
526
+ except TypeError: # pragma: no cover - Python <3.9 fallback
527
+ pool.shutdown(shutdown_kwargs.get('wait', True))
528
+ if bar:
529
+ delta = completed_items - last_bar_update
530
+ if delta > 0:
531
+ bar.update(delta)
532
+ bar.close()
533
+
534
+ results = collector.finalize() if 'results' not in locals() else results
413
535
  if store_output_pkl_file:
414
536
  dump_json_or_pickle(results, store_output_pkl_file)
415
537
  _prune_dead_threads()
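A hedged usage sketch of the rewritten scheduler above; the keyword names mirror the signature in this diff, while the import path, URLs, and timings are illustrative only.

import time

from speedy_utils import multi_thread   # import path assumed

def fetch(url: str, *, retries: int = 1) -> str:
    time.sleep(0.01)                     # stand-in for network I/O
    return f'{url}:{retries}'

urls = [f'https://example.com/{i}' for i in range(100)]

# Batched, ordered run with a whole-run timeout and a fixed kwarg forwarded
# to every call (batch=8 groups submissions; fetch still receives one URL).
pages = multi_thread(fetch, urls, workers=16, batch=8, timeout=30, retries=2)

# Unordered, error-tolerant run: failed items come back as None.
loose = multi_thread(fetch, urls, ordered=False, stop_on_error=False)
assert len(pages) == len(loose) == len(urls)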
@@ -417,32 +539,19 @@ def multi_thread(
417
539
 
418
540
 
419
541
  def multi_thread_standard(
420
- fn: Callable[[Any], Any], items: Iterable[Any], workers: int = 4
421
- ) -> list[Any]:
422
- """Execute a function using standard ThreadPoolExecutor.
423
-
424
- A standard implementation of multi-threading using ThreadPoolExecutor.
425
- Ensures the order of results matches the input order.
426
-
427
- Parameters
428
- ----------
429
- fn : Callable
430
- The function to execute for each item.
431
- items : Iterable
432
- The items to process.
433
- workers : int, optional
434
- Number of worker threads, by default 4.
435
-
436
- Returns
437
- -------
438
- list
439
- Results in same order as input items.
440
- """
441
- with ThreadPoolExecutor(max_workers=workers) as executor:
442
- futures = []
542
+ fn: Callable[[T], R], items: Iterable[T], workers: int = 4
543
+ ) -> list[R]:
544
+ """Execute ``fn`` across ``items`` while preserving submission order."""
545
+
546
+ workers_val = _resolve_worker_count(workers)
547
+ with ThreadPoolExecutor(
548
+ max_workers=workers_val,
549
+ thread_name_prefix='speedy-thread',
550
+ ) as executor:
551
+ futures: list[Future[R]] = []
443
552
  for item in items:
444
553
  futures.append(executor.submit(fn, item))
445
- _track_executor_threads(executor)
554
+ _track_executor_threads(executor)
446
555
  results = [fut.result() for fut in futures]
447
556
  _prune_dead_threads()
448
557
  return results
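The ordering guarantee in multi_thread_standard comes purely from resolving futures in submission order, regardless of which thread finishes first. The same pattern in isolation (not the shipped helper):

from concurrent.futures import ThreadPoolExecutor

def ordered_map(fn, items, workers: int = 4):
    # Futures are collected in submission order, so results keep input order.
    with ThreadPoolExecutor(max_workers=workers) as pool:
        futures = [pool.submit(fn, item) for item in items]
        return [fut.result() for fut in futures]

assert ordered_map(lambda x: x * 10, [3, 1, 2]) == [30, 10, 20]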
@@ -452,18 +561,24 @@ def _async_raise(thread_id: int, exc_type: type[BaseException]) -> bool:
452
561
  if thread_id <= 0:
453
562
  return False
454
563
  if not issubclass(exc_type, BaseException):
455
- raise TypeError("exc_type must derive from BaseException")
564
+ raise TypeError('exc_type must derive from BaseException')
456
565
  res = _PY_SET_ASYNC_EXC(ctypes.c_ulong(thread_id), ctypes.py_object(exc_type))
457
566
  if res == 0:
458
567
  return False
459
568
  if res > 1: # pragma: no cover - defensive branch
460
569
  _PY_SET_ASYNC_EXC(ctypes.c_ulong(thread_id), None)
461
- raise SystemError("PyThreadState_SetAsyncExc failed")
570
+ raise SystemError('PyThreadState_SetAsyncExc failed')
462
571
  return True
463
572
 
464
573
 
465
574
  def kill_all_thread(exc_type: type[BaseException] = SystemExit, join_timeout: float = 0.1) -> int:
466
- """Forcefully stop tracked worker threads. Returns number of threads signalled."""
575
+ """Forcefully stop tracked worker threads (dangerous; use sparingly).
576
+
577
+ Returns
578
+ -------
579
+ int
580
+ Count of threads signalled for termination.
581
+ """
467
582
  _prune_dead_threads()
468
583
  current = threading.current_thread()
469
584
  with _SPEEDY_THREADS_LOCK:
@@ -481,16 +596,16 @@ def kill_all_thread(exc_type: type[BaseException] = SystemExit, join_timeout: fl
481
596
  terminated += 1
482
597
  thread.join(timeout=join_timeout)
483
598
  else:
484
- logger.warning("Unable to signal thread %s", thread.name)
599
+ logger.warning('Unable to signal thread %s', thread.name)
485
600
  except Exception as exc: # pragma: no cover - defensive
486
- logger.error("Failed to stop thread %s: %s", thread.name, exc)
601
+ logger.error('Failed to stop thread %s: %s', thread.name, exc)
487
602
  _prune_dead_threads()
488
603
  return terminated
489
604
 
490
605
 
491
606
  __all__ = [
492
- "SPEEDY_RUNNING_THREADS",
493
- "multi_thread",
494
- "multi_thread_standard",
495
- "kill_all_thread",
607
+ 'SPEEDY_RUNNING_THREADS',
608
+ 'multi_thread',
609
+ 'multi_thread_standard',
610
+ 'kill_all_thread',
496
611
  ]
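For context on the emergency-stop path: kill_all_thread ultimately relies on CPython's PyThreadState_SetAsyncExc, which schedules an exception to be raised in another thread the next time it executes Python bytecode. A minimal standalone demonstration of that mechanism follows; it is CPython-only, threads blocked inside C code may not see the exception promptly, and the timing here is illustrative rather than guaranteed.

import ctypes
import threading
import time

def _async_raise(thread_id: int, exc_type: type[BaseException]) -> bool:
    # Ask CPython to raise exc_type inside the target thread.
    res = ctypes.pythonapi.PyThreadState_SetAsyncExc(
        ctypes.c_ulong(thread_id), ctypes.py_object(exc_type)
    )
    if res > 1:   # more than one thread state affected: undo and bail out
        ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_ulong(thread_id), None)
        raise SystemError('PyThreadState_SetAsyncExc failed')
    return res == 1

def worker() -> None:
    try:
        while True:
            time.sleep(0.05)   # pure-Python loop, so the injected exception is seen
    except SystemExit:
        pass                   # cooperative exit on the injected exception

t = threading.Thread(target=worker)
t.start()
time.sleep(0.1)
assert t.ident is not None and _async_raise(t.ident, SystemExit)
t.join(timeout=1.0)
assert not t.is_alive()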