speedy-utils 1.1.26__py3-none-any.whl → 1.1.28__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llm_utils/__init__.py +16 -4
- llm_utils/chat_format/__init__.py +10 -10
- llm_utils/chat_format/display.py +33 -21
- llm_utils/chat_format/transform.py +17 -19
- llm_utils/chat_format/utils.py +6 -4
- llm_utils/group_messages.py +17 -14
- llm_utils/lm/__init__.py +6 -5
- llm_utils/lm/async_lm/__init__.py +1 -0
- llm_utils/lm/async_lm/_utils.py +10 -9
- llm_utils/lm/async_lm/async_llm_task.py +141 -137
- llm_utils/lm/async_lm/async_lm.py +48 -42
- llm_utils/lm/async_lm/async_lm_base.py +59 -60
- llm_utils/lm/async_lm/lm_specific.py +4 -3
- llm_utils/lm/base_prompt_builder.py +93 -70
- llm_utils/lm/llm.py +126 -108
- llm_utils/lm/llm_signature.py +4 -2
- llm_utils/lm/lm_base.py +72 -73
- llm_utils/lm/mixins.py +102 -62
- llm_utils/lm/openai_memoize.py +124 -87
- llm_utils/lm/signature.py +105 -92
- llm_utils/lm/utils.py +42 -23
- llm_utils/scripts/vllm_load_balancer.py +23 -30
- llm_utils/scripts/vllm_serve.py +8 -7
- llm_utils/vector_cache/__init__.py +9 -3
- llm_utils/vector_cache/cli.py +1 -1
- llm_utils/vector_cache/core.py +59 -63
- llm_utils/vector_cache/types.py +7 -5
- llm_utils/vector_cache/utils.py +12 -8
- speedy_utils/__imports.py +244 -0
- speedy_utils/__init__.py +90 -194
- speedy_utils/all.py +125 -227
- speedy_utils/common/clock.py +37 -42
- speedy_utils/common/function_decorator.py +6 -12
- speedy_utils/common/logger.py +43 -52
- speedy_utils/common/notebook_utils.py +13 -21
- speedy_utils/common/patcher.py +21 -17
- speedy_utils/common/report_manager.py +42 -44
- speedy_utils/common/utils_cache.py +152 -169
- speedy_utils/common/utils_io.py +137 -103
- speedy_utils/common/utils_misc.py +15 -21
- speedy_utils/common/utils_print.py +22 -28
- speedy_utils/multi_worker/process.py +66 -79
- speedy_utils/multi_worker/thread.py +78 -155
- speedy_utils/scripts/mpython.py +38 -36
- speedy_utils/scripts/openapi_client_codegen.py +10 -10
- {speedy_utils-1.1.26.dist-info → speedy_utils-1.1.28.dist-info}/METADATA +1 -1
- speedy_utils-1.1.28.dist-info/RECORD +57 -0
- vision_utils/README.md +202 -0
- vision_utils/__init__.py +5 -0
- vision_utils/io_utils.py +470 -0
- vision_utils/plot.py +345 -0
- speedy_utils-1.1.26.dist-info/RECORD +0 -52
- {speedy_utils-1.1.26.dist-info → speedy_utils-1.1.28.dist-info}/WHEEL +0 -0
- {speedy_utils-1.1.26.dist-info → speedy_utils-1.1.28.dist-info}/entry_points.txt +0 -0
|
@@ -1,97 +1,5 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
3
|
-
# THREAD-BASED PARALLEL EXECUTION WITH PROGRESS TRACKING AND ERROR HANDLING
|
|
4
|
-
# ============================================================================= #
|
|
5
|
-
#
|
|
6
|
-
# Title & Intent:
|
|
7
|
-
# High-performance thread pool utilities for parallel processing with comprehensive error handling
|
|
8
|
-
#
|
|
9
|
-
# High-level Summary:
|
|
10
|
-
# This module provides robust thread-based parallel execution utilities designed for CPU-bound
|
|
11
|
-
# and I/O-bound tasks requiring concurrent processing. It features intelligent worker management,
|
|
12
|
-
# comprehensive error handling with detailed tracebacks, progress tracking with tqdm integration,
|
|
13
|
-
# and flexible batching strategies. The module optimizes for both throughput and reliability,
|
|
14
|
-
# making it suitable for data processing pipelines, batch operations, and concurrent API calls.
|
|
15
|
-
#
|
|
16
|
-
# Public API / Data Contracts:
|
|
17
|
-
# • multi_thread(func, inputs, *, workers=None, **kwargs) -> list[Any] - Main executor
|
|
18
|
-
# • multi_thread_standard(func, inputs, workers=4) -> list[Any] - Simple ordered helper
|
|
19
|
-
# • kill_all_thread(exc_type=SystemExit, join_timeout=0.1) -> int - Emergency stop
|
|
20
|
-
# • DEFAULT_WORKERS = (cpu_count * 2) - Default worker thread count
|
|
21
|
-
# • T = TypeVar('T'), R = TypeVar('R') - Generic type variables for input/output typing
|
|
22
|
-
# • _group_iter(src, size) -> Iterable[list[T]] - Utility for chunking iterables
|
|
23
|
-
# • _worker(item, func, fixed_kwargs) -> R - Individual worker function wrapper
|
|
24
|
-
# • _ResultCollector - Maintains ordered/unordered result aggregation
|
|
25
|
-
#
|
|
26
|
-
# Invariants / Constraints:
|
|
27
|
-
# • Worker count MUST be positive integer, defaults to (CPU cores * 2)
|
|
28
|
-
# • Input iterables MUST be finite and non-empty for meaningful processing
|
|
29
|
-
# • Functions MUST be thread-safe when used with multiple workers
|
|
30
|
-
# • Error handling MUST capture and log detailed tracebacks for debugging
|
|
31
|
-
# • Progress tracking MUST be optional and gracefully handle tqdm unavailability
|
|
32
|
-
# • Batch processing MUST maintain input order in results
|
|
33
|
-
# • MUST handle keyboard interruption gracefully with resource cleanup
|
|
34
|
-
# • Thread pool MUST be properly closed and joined after completion
|
|
35
|
-
#
|
|
36
|
-
# Usage Example:
|
|
37
|
-
# ```python
|
|
38
|
-
# from speedy_utils.multi_worker.thread import multi_thread, multi_thread_batch
|
|
39
|
-
# import requests
|
|
40
|
-
#
|
|
41
|
-
# # Simple parallel processing
|
|
42
|
-
# def square(x):
|
|
43
|
-
# return x ** 2
|
|
44
|
-
#
|
|
45
|
-
# numbers = list(range(100))
|
|
46
|
-
# results = multi_thread(square, numbers, num_workers=8)
|
|
47
|
-
# print(f"Processed {len(results)} items")
|
|
48
|
-
#
|
|
49
|
-
# # Parallel API calls with error handling
|
|
50
|
-
# def fetch_url(url):
|
|
51
|
-
# response = requests.get(url, timeout=10)
|
|
52
|
-
# return response.status_code, len(response.content)
|
|
53
|
-
#
|
|
54
|
-
# urls = ["http://example.com", "http://google.com", "http://github.com"]
|
|
55
|
-
# results = multi_thread(fetch_url, urls, num_workers=3, progress=True)
|
|
56
|
-
#
|
|
57
|
-
# # Batched processing for memory efficiency
|
|
58
|
-
# def process_batch(items):
|
|
59
|
-
# return [item.upper() for item in items]
|
|
60
|
-
#
|
|
61
|
-
# large_dataset = ["item" + str(i) for i in range(10000)]
|
|
62
|
-
# batched_results = multi_thread_batch(
|
|
63
|
-
# process_batch,
|
|
64
|
-
# large_dataset,
|
|
65
|
-
# batch_size=100,
|
|
66
|
-
# num_workers=4
|
|
67
|
-
# )
|
|
68
|
-
# ```
|
|
69
|
-
#
|
|
70
|
-
# TODO & Future Work:
|
|
71
|
-
# • Add adaptive worker count based on task characteristics
|
|
72
|
-
# • Implement priority queuing for time-sensitive tasks
|
|
73
|
-
# • Add memory usage monitoring and automatic batch size adjustment
|
|
74
|
-
# • Support for async function execution within thread pool
|
|
75
|
-
# • Add detailed performance metrics and timing analysis
|
|
76
|
-
# • Implement graceful degradation for resource-constrained environments
|
|
77
|
-
#
|
|
78
|
-
# ============================================================================= #
|
|
79
|
-
"""
|
|
80
|
-
|
|
81
|
-
import ctypes
|
|
82
|
-
import os
|
|
83
|
-
import sys
|
|
84
|
-
import threading
|
|
85
|
-
import time
|
|
86
|
-
import traceback
|
|
87
|
-
from collections.abc import Callable, Iterable, Mapping, Sequence
|
|
88
|
-
from concurrent.futures import FIRST_COMPLETED, Future, ThreadPoolExecutor, wait
|
|
89
|
-
from heapq import heappop, heappush
|
|
90
|
-
from itertools import islice
|
|
91
|
-
from types import MappingProxyType
|
|
92
|
-
from typing import Any, Generic, TypeVar, cast
|
|
93
|
-
|
|
94
|
-
from loguru import logger
|
|
1
|
+
from ..__imports import *
|
|
2
|
+
|
|
95
3
|
|
|
96
4
|
try:
|
|
97
5
|
from tqdm import tqdm
|
|
@@ -101,8 +9,8 @@ except ImportError: # pragma: no cover
|
|
|
101
9
|
# Sensible defaults
|
|
102
10
|
DEFAULT_WORKERS = (os.cpu_count() or 4) * 2
|
|
103
11
|
|
|
104
|
-
T = TypeVar(
|
|
105
|
-
R = TypeVar(
|
|
12
|
+
T = TypeVar('T')
|
|
13
|
+
R = TypeVar('R')
|
|
106
14
|
|
|
107
15
|
SPEEDY_RUNNING_THREADS: list[threading.Thread] = [] # cooperative shutdown tracking
|
|
108
16
|
_SPEEDY_THREADS_LOCK = threading.Lock()
|
|
@@ -124,11 +32,11 @@ class UserFunctionError(Exception):
|
|
|
124
32
|
self.user_traceback = user_traceback
|
|
125
33
|
|
|
126
34
|
# Create a focused error message
|
|
127
|
-
tb_str =
|
|
35
|
+
tb_str = ''.join(traceback.format_list(user_traceback))
|
|
128
36
|
msg = (
|
|
129
37
|
f'\nError in function "{func_name}" with input: {input_value!r}\n'
|
|
130
|
-
f
|
|
131
|
-
f
|
|
38
|
+
f'\nUser code traceback:\n{tb_str}'
|
|
39
|
+
f'{type(original_exception).__name__}: {original_exception}'
|
|
132
40
|
)
|
|
133
41
|
super().__init__(msg)
|
|
134
42
|
|
|
@@ -165,7 +73,7 @@ def _track_threads(threads: Iterable[threading.Thread]) -> None:
|
|
|
165
73
|
|
|
166
74
|
|
|
167
75
|
def _track_executor_threads(pool: ThreadPoolExecutor) -> None:
|
|
168
|
-
thread_set = getattr(pool,
|
|
76
|
+
thread_set = getattr(pool, '_threads', None)
|
|
169
77
|
if not thread_set:
|
|
170
78
|
return
|
|
171
79
|
_track_threads(tuple(thread_set))
|
|
@@ -188,9 +96,9 @@ def _worker(
|
|
|
188
96
|
if not callable(func):
|
|
189
97
|
func_type = type(func).__name__
|
|
190
98
|
raise TypeError(
|
|
191
|
-
f
|
|
192
|
-
f
|
|
193
|
-
f
|
|
99
|
+
f'\nmulti_thread: func parameter must be callable, '
|
|
100
|
+
f'got {func_type}: {func!r}\n'
|
|
101
|
+
f'Hint: Did you accidentally pass a {func_type} instead of a function?'
|
|
194
102
|
)
|
|
195
103
|
|
|
196
104
|
try:
|
|
@@ -205,9 +113,9 @@ def _worker(
|
|
|
205
113
|
# Filter to keep only user code frames
|
|
206
114
|
user_frames = []
|
|
207
115
|
skip_patterns = [
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
116
|
+
'multi_worker/thread.py',
|
|
117
|
+
'concurrent/futures/',
|
|
118
|
+
'threading.py',
|
|
211
119
|
]
|
|
212
120
|
|
|
213
121
|
for frame in tb_list:
|
|
@@ -216,7 +124,7 @@ def _worker(
|
|
|
216
124
|
|
|
217
125
|
# If we have user frames, wrap in our custom exception
|
|
218
126
|
if user_frames:
|
|
219
|
-
func_name = getattr(func,
|
|
127
|
+
func_name = getattr(func, '__name__', repr(func))
|
|
220
128
|
raise UserFunctionError(
|
|
221
129
|
exc,
|
|
222
130
|
func_name,
|
|
@@ -237,14 +145,14 @@ def _run_batch(
|
|
|
237
145
|
|
|
238
146
|
|
|
239
147
|
def _attach_metadata(fut: Future[Any], idx: int, logical_size: int) -> None:
|
|
240
|
-
|
|
241
|
-
|
|
148
|
+
fut._speedy_idx = idx
|
|
149
|
+
fut._speedy_size = logical_size
|
|
242
150
|
|
|
243
151
|
|
|
244
152
|
def _future_meta(fut: Future[Any]) -> tuple[int, int]:
|
|
245
153
|
return (
|
|
246
|
-
|
|
247
|
-
|
|
154
|
+
fut._speedy_idx,
|
|
155
|
+
fut._speedy_size,
|
|
248
156
|
)
|
|
249
157
|
|
|
250
158
|
|
|
@@ -292,7 +200,7 @@ def _resolve_worker_count(workers: int | None) -> int:
|
|
|
292
200
|
if workers is None:
|
|
293
201
|
return DEFAULT_WORKERS
|
|
294
202
|
if workers <= 0:
|
|
295
|
-
raise ValueError(
|
|
203
|
+
raise ValueError('workers must be a positive integer')
|
|
296
204
|
return workers
|
|
297
205
|
|
|
298
206
|
|
|
@@ -300,18 +208,16 @@ def _normalize_batch_result(result: Any, logical_size: int) -> list[Any]:
|
|
|
300
208
|
if logical_size == 1:
|
|
301
209
|
return [result]
|
|
302
210
|
if result is None:
|
|
303
|
-
raise ValueError(
|
|
211
|
+
raise ValueError('batched callable returned None for a batch result')
|
|
304
212
|
if isinstance(result, (str, bytes, bytearray)):
|
|
305
|
-
raise TypeError(
|
|
306
|
-
if isinstance(result, Sequence):
|
|
307
|
-
out = list(result)
|
|
308
|
-
elif isinstance(result, Iterable):
|
|
213
|
+
raise TypeError('batched callable must not return str/bytes when batching')
|
|
214
|
+
if isinstance(result, (Sequence, Iterable)):
|
|
309
215
|
out = list(result)
|
|
310
216
|
else:
|
|
311
|
-
raise TypeError(
|
|
217
|
+
raise TypeError('batched callable must return an iterable of results')
|
|
312
218
|
if len(out) != logical_size:
|
|
313
219
|
raise ValueError(
|
|
314
|
-
f
|
|
220
|
+
f'batched callable returned {len(out)} items, expected {logical_size}',
|
|
315
221
|
)
|
|
316
222
|
return out
|
|
317
223
|
|
|
@@ -398,7 +304,9 @@ def multi_thread(
|
|
|
398
304
|
results: list[R | None] = []
|
|
399
305
|
|
|
400
306
|
for proc_idx, chunk in enumerate(chunks):
|
|
401
|
-
with tempfile.NamedTemporaryFile(
|
|
307
|
+
with tempfile.NamedTemporaryFile(
|
|
308
|
+
delete=False, suffix='multi_thread.pkl'
|
|
309
|
+
) as fh:
|
|
402
310
|
file_pkl = fh.name
|
|
403
311
|
assert isinstance(in_process_multi_thread, Callable)
|
|
404
312
|
proc = in_process_multi_thread(
|
|
@@ -420,28 +328,28 @@ def multi_thread(
|
|
|
420
328
|
|
|
421
329
|
for proc, file_pkl in procs:
|
|
422
330
|
proc.join()
|
|
423
|
-
logger.info(
|
|
331
|
+
logger.info('process finished: %s', proc)
|
|
424
332
|
try:
|
|
425
333
|
results.extend(load_by_ext(file_pkl))
|
|
426
334
|
finally:
|
|
427
335
|
try:
|
|
428
336
|
os.unlink(file_pkl)
|
|
429
337
|
except OSError as exc: # pragma: no cover - best effort cleanup
|
|
430
|
-
logger.warning(
|
|
338
|
+
logger.warning('failed to remove temp file %s: %s', file_pkl, exc)
|
|
431
339
|
return results
|
|
432
340
|
|
|
433
341
|
try:
|
|
434
342
|
import pandas as pd
|
|
435
343
|
|
|
436
344
|
if isinstance(inputs, pd.DataFrame):
|
|
437
|
-
inputs = cast(Iterable[T], inputs.to_dict(orient=
|
|
345
|
+
inputs = cast(Iterable[T], inputs.to_dict(orient='records'))
|
|
438
346
|
except ImportError: # pragma: no cover - optional dependency
|
|
439
347
|
pass
|
|
440
348
|
|
|
441
349
|
if batch <= 0:
|
|
442
|
-
raise ValueError(
|
|
350
|
+
raise ValueError('batch must be a positive integer')
|
|
443
351
|
if prefetch_factor <= 0:
|
|
444
|
-
raise ValueError(
|
|
352
|
+
raise ValueError('prefetch_factor must be a positive integer')
|
|
445
353
|
|
|
446
354
|
workers_val = _resolve_worker_count(workers)
|
|
447
355
|
progress_update = max(progress_update, 1)
|
|
@@ -463,12 +371,19 @@ def multi_thread(
|
|
|
463
371
|
|
|
464
372
|
bar = None
|
|
465
373
|
last_bar_update = 0
|
|
466
|
-
if
|
|
374
|
+
if (
|
|
375
|
+
progress
|
|
376
|
+
and tqdm is not None
|
|
377
|
+
and logical_total is not None
|
|
378
|
+
and logical_total > 0
|
|
379
|
+
):
|
|
467
380
|
bar = tqdm(
|
|
468
381
|
total=logical_total,
|
|
469
382
|
ncols=128,
|
|
470
|
-
colour=
|
|
471
|
-
bar_format=(
|
|
383
|
+
colour='green',
|
|
384
|
+
bar_format=(
|
|
385
|
+
'{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]'
|
|
386
|
+
),
|
|
472
387
|
)
|
|
473
388
|
|
|
474
389
|
deadline = time.monotonic() + timeout if timeout is not None else None
|
|
@@ -482,9 +397,9 @@ def multi_thread(
|
|
|
482
397
|
inflight: set[Future[Any]] = set()
|
|
483
398
|
pool = ThreadPoolExecutor(
|
|
484
399
|
max_workers=workers_val,
|
|
485
|
-
thread_name_prefix=
|
|
400
|
+
thread_name_prefix='speedy-thread',
|
|
486
401
|
)
|
|
487
|
-
shutdown_kwargs: dict[str, Any] = {
|
|
402
|
+
shutdown_kwargs: dict[str, Any] = {'wait': True}
|
|
488
403
|
|
|
489
404
|
try:
|
|
490
405
|
|
|
@@ -517,7 +432,7 @@ def multi_thread(
|
|
|
517
432
|
if remaining <= 0:
|
|
518
433
|
_cancel_futures(inflight)
|
|
519
434
|
raise TimeoutError(
|
|
520
|
-
f
|
|
435
|
+
f'multi_thread timed out after {timeout} seconds',
|
|
521
436
|
)
|
|
522
437
|
wait_timeout = max(remaining, 0.0)
|
|
523
438
|
|
|
@@ -530,7 +445,7 @@ def multi_thread(
|
|
|
530
445
|
if not done:
|
|
531
446
|
_cancel_futures(inflight)
|
|
532
447
|
raise TimeoutError(
|
|
533
|
-
f
|
|
448
|
+
f'multi_thread timed out after {timeout} seconds',
|
|
534
449
|
)
|
|
535
450
|
|
|
536
451
|
for fut in done:
|
|
@@ -549,11 +464,11 @@ def multi_thread(
|
|
|
549
464
|
orig_exc = exc.original_exception
|
|
550
465
|
|
|
551
466
|
# Build new traceback from user frames only
|
|
552
|
-
tb_str =
|
|
467
|
+
tb_str = ''.join(traceback.format_list(exc.user_traceback))
|
|
553
468
|
clean_msg = (
|
|
554
469
|
f'\nError in "{exc.func_name}" '
|
|
555
|
-
f
|
|
556
|
-
f
|
|
470
|
+
f'with input: {exc.input_value!r}\n\n{tb_str}'
|
|
471
|
+
f'{type(orig_exc).__name__}: {orig_exc}'
|
|
557
472
|
)
|
|
558
473
|
|
|
559
474
|
# Raise a new instance of the original exception type
|
|
@@ -568,7 +483,7 @@ def multi_thread(
|
|
|
568
483
|
if stop_on_error:
|
|
569
484
|
_cancel_futures(inflight)
|
|
570
485
|
raise
|
|
571
|
-
logger.exception(
|
|
486
|
+
logger.exception('multi_thread task failed', exc_info=exc)
|
|
572
487
|
out_items = [None] * logical_size
|
|
573
488
|
else:
|
|
574
489
|
try:
|
|
@@ -576,7 +491,7 @@ def multi_thread(
|
|
|
576
491
|
except Exception as exc:
|
|
577
492
|
_cancel_futures(inflight)
|
|
578
493
|
raise RuntimeError(
|
|
579
|
-
|
|
494
|
+
'batched callable returned an unexpected shape',
|
|
580
495
|
) from exc
|
|
581
496
|
|
|
582
497
|
collector.add(idx, out_items)
|
|
@@ -588,10 +503,14 @@ def multi_thread(
|
|
|
588
503
|
bar.update(delta)
|
|
589
504
|
last_bar_update = completed_items
|
|
590
505
|
submitted = next_logical_idx
|
|
591
|
-
pending =
|
|
506
|
+
pending = (
|
|
507
|
+
max(logical_total - submitted, 0)
|
|
508
|
+
if logical_total is not None
|
|
509
|
+
else '-'
|
|
510
|
+
)
|
|
592
511
|
postfix = {
|
|
593
|
-
|
|
594
|
-
|
|
512
|
+
'processing': min(len(inflight), workers_val),
|
|
513
|
+
'pending': pending,
|
|
595
514
|
}
|
|
596
515
|
bar.set_postfix(postfix)
|
|
597
516
|
|
|
@@ -604,7 +523,7 @@ def multi_thread(
|
|
|
604
523
|
results = collector.finalize()
|
|
605
524
|
|
|
606
525
|
except KeyboardInterrupt:
|
|
607
|
-
shutdown_kwargs = {
|
|
526
|
+
shutdown_kwargs = {'wait': False, 'cancel_futures': True}
|
|
608
527
|
_cancel_futures(inflight)
|
|
609
528
|
kill_all_thread(SystemExit)
|
|
610
529
|
raise KeyboardInterrupt() from None
|
|
@@ -612,27 +531,29 @@ def multi_thread(
|
|
|
612
531
|
try:
|
|
613
532
|
pool.shutdown(**shutdown_kwargs)
|
|
614
533
|
except TypeError: # pragma: no cover - Python <3.9 fallback
|
|
615
|
-
pool.shutdown(shutdown_kwargs.get(
|
|
534
|
+
pool.shutdown(shutdown_kwargs.get('wait', True))
|
|
616
535
|
if bar:
|
|
617
536
|
delta = completed_items - last_bar_update
|
|
618
537
|
if delta > 0:
|
|
619
538
|
bar.update(delta)
|
|
620
539
|
bar.close()
|
|
621
540
|
|
|
622
|
-
results = collector.finalize() if
|
|
541
|
+
results = collector.finalize() if 'results' not in locals() else results
|
|
623
542
|
if store_output_pkl_file:
|
|
624
543
|
dump_json_or_pickle(results, store_output_pkl_file)
|
|
625
544
|
_prune_dead_threads()
|
|
626
545
|
return results
|
|
627
546
|
|
|
628
547
|
|
|
629
|
-
def multi_thread_standard(
|
|
548
|
+
def multi_thread_standard(
|
|
549
|
+
fn: Callable[[T], R], items: Iterable[T], workers: int = 4
|
|
550
|
+
) -> list[R]:
|
|
630
551
|
"""Execute ``fn`` across ``items`` while preserving submission order."""
|
|
631
552
|
|
|
632
553
|
workers_val = _resolve_worker_count(workers)
|
|
633
554
|
with ThreadPoolExecutor(
|
|
634
555
|
max_workers=workers_val,
|
|
635
|
-
thread_name_prefix=
|
|
556
|
+
thread_name_prefix='speedy-thread',
|
|
636
557
|
) as executor:
|
|
637
558
|
futures: list[Future[R]] = []
|
|
638
559
|
for item in items:
|
|
@@ -647,17 +568,19 @@ def _async_raise(thread_id: int, exc_type: type[BaseException]) -> bool:
|
|
|
647
568
|
if thread_id <= 0:
|
|
648
569
|
return False
|
|
649
570
|
if not issubclass(exc_type, BaseException):
|
|
650
|
-
raise TypeError(
|
|
571
|
+
raise TypeError('exc_type must derive from BaseException')
|
|
651
572
|
res = _PY_SET_ASYNC_EXC(ctypes.c_ulong(thread_id), ctypes.py_object(exc_type))
|
|
652
573
|
if res == 0:
|
|
653
574
|
return False
|
|
654
575
|
if res > 1: # pragma: no cover - defensive branch
|
|
655
576
|
_PY_SET_ASYNC_EXC(ctypes.c_ulong(thread_id), None)
|
|
656
|
-
raise SystemError(
|
|
577
|
+
raise SystemError('PyThreadState_SetAsyncExc failed')
|
|
657
578
|
return True
|
|
658
579
|
|
|
659
580
|
|
|
660
|
-
def kill_all_thread(
|
|
581
|
+
def kill_all_thread(
|
|
582
|
+
exc_type: type[BaseException] = SystemExit, join_timeout: float = 0.1
|
|
583
|
+
) -> int:
|
|
661
584
|
"""Forcefully stop tracked worker threads (dangerous; use sparingly).
|
|
662
585
|
|
|
663
586
|
Returns
|
|
@@ -682,17 +605,17 @@ def kill_all_thread(exc_type: type[BaseException] = SystemExit, join_timeout: fl
|
|
|
682
605
|
terminated += 1
|
|
683
606
|
thread.join(timeout=join_timeout)
|
|
684
607
|
else:
|
|
685
|
-
logger.warning(
|
|
608
|
+
logger.warning('Unable to signal thread %s', thread.name)
|
|
686
609
|
except Exception as exc: # pragma: no cover - defensive
|
|
687
|
-
logger.error(
|
|
610
|
+
logger.error('Failed to stop thread %s: %s', thread.name, exc)
|
|
688
611
|
_prune_dead_threads()
|
|
689
612
|
return terminated
|
|
690
613
|
|
|
691
614
|
|
|
692
615
|
__all__ = [
|
|
693
|
-
|
|
694
|
-
|
|
695
|
-
|
|
696
|
-
|
|
697
|
-
|
|
616
|
+
'SPEEDY_RUNNING_THREADS',
|
|
617
|
+
'UserFunctionError',
|
|
618
|
+
'multi_thread',
|
|
619
|
+
'multi_thread_standard',
|
|
620
|
+
'kill_all_thread',
|
|
698
621
|
]
|
speedy_utils/scripts/mpython.py
CHANGED
|
@@ -6,101 +6,103 @@ import os
|
|
|
6
6
|
import shlex # To properly escape command line arguments
|
|
7
7
|
import shutil
|
|
8
8
|
|
|
9
|
-
|
|
9
|
+
|
|
10
|
+
taskset_path = shutil.which('taskset')
|
|
10
11
|
|
|
11
12
|
|
|
12
13
|
def assert_script(python_path):
|
|
13
|
-
|
|
14
|
-
|
|
14
|
+
with open(python_path) as f:
|
|
15
|
+
code_str = f.read()
|
|
16
|
+
if 'MP_ID' not in code_str or 'MP_TOTAL' not in code_str:
|
|
15
17
|
example_code = (
|
|
16
18
|
'import os; MP_TOTAL = int(os.environ.get("MP_TOTAL"));MP_ID = int(os.environ.get("MP_ID"))\n'
|
|
17
|
-
|
|
19
|
+
'inputs = list(inputs[MP_ID::MP_TOTAL])'
|
|
18
20
|
)
|
|
19
21
|
# ANSI escape codes for coloring
|
|
20
|
-
YELLOW =
|
|
21
|
-
RESET =
|
|
22
|
+
YELLOW = '\033[93m'
|
|
23
|
+
RESET = '\033[0m'
|
|
22
24
|
raise_msg = (
|
|
23
|
-
f
|
|
24
|
-
f
|
|
25
|
+
f'MP_ID and MP_TOTAL not found in {python_path}, please add them.\n\n'
|
|
26
|
+
f'Example:\n{YELLOW}{example_code}{RESET}'
|
|
25
27
|
)
|
|
26
28
|
raise Exception(raise_msg)
|
|
27
29
|
|
|
28
30
|
|
|
29
31
|
def run_in_tmux(commands_to_run, tmux_name, num_windows):
|
|
30
|
-
with open(
|
|
32
|
+
with open('/tmp/start_multirun_tmux.sh', 'w') as script_file:
|
|
31
33
|
# first cmd is to kill the session if it exists
|
|
32
34
|
|
|
33
|
-
script_file.write(
|
|
34
|
-
script_file.write(f
|
|
35
|
-
script_file.write(f
|
|
35
|
+
script_file.write('#!/bin/bash\n\n')
|
|
36
|
+
script_file.write(f'tmux kill-session -t {tmux_name}\nsleep .1\n')
|
|
37
|
+
script_file.write(f'tmux new-session -d -s {tmux_name}\n')
|
|
36
38
|
for i, cmd in enumerate(itertools.cycle(commands_to_run)):
|
|
37
39
|
if i >= num_windows:
|
|
38
40
|
break
|
|
39
|
-
window_name = f
|
|
41
|
+
window_name = f'{tmux_name}:{i}'
|
|
40
42
|
if i == 0:
|
|
41
43
|
script_file.write(f"tmux send-keys -t {window_name} '{cmd}' C-m\n")
|
|
42
44
|
else:
|
|
43
|
-
script_file.write(f
|
|
45
|
+
script_file.write(f'tmux new-window -t {tmux_name}\n')
|
|
44
46
|
script_file.write(f"tmux send-keys -t {window_name} '{cmd}' C-m\n")
|
|
45
47
|
|
|
46
48
|
# Make the script executable
|
|
47
|
-
script_file.write(
|
|
48
|
-
print(
|
|
49
|
+
script_file.write('chmod +x /tmp/start_multirun_tmux.sh\n')
|
|
50
|
+
print('Run /tmp/start_multirun_tmux.sh')
|
|
49
51
|
|
|
50
52
|
|
|
51
53
|
def main():
|
|
52
54
|
# Assert that MP_ID and MP_TOTAL are not already set
|
|
53
55
|
|
|
54
|
-
parser = argparse.ArgumentParser(description=
|
|
56
|
+
parser = argparse.ArgumentParser(description='Process fold arguments')
|
|
55
57
|
parser.add_argument(
|
|
56
|
-
|
|
58
|
+
'--total_fold', '-t', default=16, type=int, help='total number of folds'
|
|
57
59
|
)
|
|
58
|
-
parser.add_argument(
|
|
59
|
-
parser.add_argument(
|
|
60
|
+
parser.add_argument('--gpus', type=str, default='0,1,2,3,4,5,6,7')
|
|
61
|
+
parser.add_argument('--ignore_gpus', '-ig', type=str, default='')
|
|
60
62
|
parser.add_argument(
|
|
61
|
-
|
|
63
|
+
'--total_cpu',
|
|
62
64
|
type=int,
|
|
63
65
|
default=multiprocessing.cpu_count(),
|
|
64
|
-
help=
|
|
66
|
+
help='total number of cpu cores available',
|
|
65
67
|
)
|
|
66
68
|
parser.add_argument(
|
|
67
|
-
|
|
69
|
+
'cmd', nargs=argparse.REMAINDER
|
|
68
70
|
) # This will gather the remaining unparsed arguments
|
|
69
71
|
|
|
70
72
|
args = parser.parse_args()
|
|
71
|
-
if not args.cmd or (args.cmd[0] ==
|
|
72
|
-
parser.error(
|
|
73
|
+
if not args.cmd or (args.cmd[0] == '--' and len(args.cmd) == 1):
|
|
74
|
+
parser.error('Invalid command provided')
|
|
73
75
|
assert_script(args.cmd[0])
|
|
74
76
|
|
|
75
77
|
cmd_str = None
|
|
76
|
-
if args.cmd[0] ==
|
|
78
|
+
if args.cmd[0] == '--':
|
|
77
79
|
cmd_str = shlex.join(args.cmd[1:])
|
|
78
80
|
else:
|
|
79
81
|
cmd_str = shlex.join(args.cmd)
|
|
80
82
|
|
|
81
|
-
gpus = args.gpus.split(
|
|
82
|
-
gpus = [gpu for gpu in gpus if gpu not in args.ignore_gpus.split(
|
|
83
|
+
gpus = args.gpus.split(',')
|
|
84
|
+
gpus = [gpu for gpu in gpus if gpu not in args.ignore_gpus.split(',')]
|
|
83
85
|
num_gpus = len(gpus)
|
|
84
86
|
|
|
85
87
|
cpu_per_process = max(args.total_cpu // args.total_fold, 1)
|
|
86
88
|
cmds = []
|
|
87
|
-
path_python = shutil.which(
|
|
89
|
+
path_python = shutil.which('python')
|
|
88
90
|
for i in range(args.total_fold):
|
|
89
91
|
gpu = gpus[i % num_gpus]
|
|
90
92
|
cpu_start = (i * cpu_per_process) % args.total_cpu
|
|
91
93
|
cpu_end = ((i + 1) * cpu_per_process - 1) % args.total_cpu
|
|
92
|
-
ENV = f
|
|
94
|
+
ENV = f'CUDA_VISIBLE_DEVICES={gpu} MP_ID={i} MP_TOTAL={args.total_fold}'
|
|
93
95
|
if taskset_path:
|
|
94
|
-
fold_cmd = f
|
|
96
|
+
fold_cmd = f'{ENV} {taskset_path} -c {cpu_start}-{cpu_end} {path_python} {cmd_str}'
|
|
95
97
|
else:
|
|
96
|
-
fold_cmd = f
|
|
98
|
+
fold_cmd = f'{ENV} {path_python} {cmd_str}'
|
|
97
99
|
|
|
98
100
|
cmds.append(fold_cmd)
|
|
99
101
|
|
|
100
|
-
run_in_tmux(cmds,
|
|
101
|
-
os.chmod(
|
|
102
|
-
os.system(
|
|
102
|
+
run_in_tmux(cmds, 'mpython', args.total_fold)
|
|
103
|
+
os.chmod('/tmp/start_multirun_tmux.sh', 0o755) # Make the script executable
|
|
104
|
+
os.system('/tmp/start_multirun_tmux.sh')
|
|
103
105
|
|
|
104
106
|
|
|
105
|
-
if __name__ ==
|
|
107
|
+
if __name__ == '__main__':
|
|
106
108
|
main()
|
|
@@ -33,7 +33,7 @@ def snake_case(s: str) -> str:
|
|
|
33
33
|
return "".join(out)
|
|
34
34
|
|
|
35
35
|
|
|
36
|
-
def map_openapi_type(prop:
|
|
36
|
+
def map_openapi_type(prop: dict[str, Any]) -> str:
|
|
37
37
|
t = prop.get("type")
|
|
38
38
|
if t == "string":
|
|
39
39
|
fmt = prop.get("format")
|
|
@@ -50,8 +50,8 @@ def map_openapi_type(prop: Dict[str, Any]) -> str:
|
|
|
50
50
|
return "Any"
|
|
51
51
|
|
|
52
52
|
|
|
53
|
-
def generate_models(components:
|
|
54
|
-
lines:
|
|
53
|
+
def generate_models(components: dict[str, Any]) -> list[str]:
|
|
54
|
+
lines: list[str] = []
|
|
55
55
|
schemas = components.get("schemas", {})
|
|
56
56
|
for name, schema in schemas.items():
|
|
57
57
|
if "enum" in schema:
|
|
@@ -77,10 +77,10 @@ def generate_models(components: Dict[str, Any]) -> List[str]:
|
|
|
77
77
|
return lines
|
|
78
78
|
|
|
79
79
|
|
|
80
|
-
def generate_client(spec:
|
|
80
|
+
def generate_client(spec: dict[str, Any]) -> list[str]:
|
|
81
81
|
paths = spec.get("paths", {})
|
|
82
82
|
models = spec.get("components", {}).get("schemas", {})
|
|
83
|
-
lines:
|
|
83
|
+
lines: list[str] = []
|
|
84
84
|
lines.append("class GeneratedClient:")
|
|
85
85
|
lines.append(' """Client generated from OpenAPI spec."""')
|
|
86
86
|
lines.append("")
|
|
@@ -115,8 +115,8 @@ def generate_client(spec: Dict[str, Any]) -> List[str]:
|
|
|
115
115
|
func_name = snake_case(op_id)
|
|
116
116
|
summary = op.get("summary", "").strip()
|
|
117
117
|
# collect parameters
|
|
118
|
-
req_params:
|
|
119
|
-
opt_params:
|
|
118
|
+
req_params: list[str] = ["self"]
|
|
119
|
+
opt_params: list[str] = []
|
|
120
120
|
# path params (required)
|
|
121
121
|
path_params = [p for p in op.get("parameters", []) if p.get("in") == "path"]
|
|
122
122
|
for p in path_params:
|
|
@@ -219,15 +219,15 @@ def main() -> None:
|
|
|
219
219
|
|
|
220
220
|
try:
|
|
221
221
|
spec_src = args.spec
|
|
222
|
-
if spec_src.startswith("http://"
|
|
222
|
+
if spec_src.startswith(("http://", "https://")):
|
|
223
223
|
import httpx
|
|
224
224
|
|
|
225
225
|
response = httpx.get(spec_src)
|
|
226
226
|
spec = response.json()
|
|
227
227
|
else:
|
|
228
|
-
with open(spec_src,
|
|
228
|
+
with open(spec_src, encoding="utf-8") as f:
|
|
229
229
|
spec = json.load(f)
|
|
230
|
-
out:
|
|
230
|
+
out: list[str] = []
|
|
231
231
|
# imports
|
|
232
232
|
out.append("from typing import Any, Dict, List, Optional")
|
|
233
233
|
out.append("from datetime import datetime")
|