stouputils-1.14.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- stouputils/__init__.py +40 -0
- stouputils/__main__.py +86 -0
- stouputils/_deprecated.py +37 -0
- stouputils/all_doctests.py +160 -0
- stouputils/applications/__init__.py +22 -0
- stouputils/applications/automatic_docs.py +634 -0
- stouputils/applications/upscaler/__init__.py +39 -0
- stouputils/applications/upscaler/config.py +128 -0
- stouputils/applications/upscaler/image.py +247 -0
- stouputils/applications/upscaler/video.py +287 -0
- stouputils/archive.py +344 -0
- stouputils/backup.py +488 -0
- stouputils/collections.py +244 -0
- stouputils/continuous_delivery/__init__.py +27 -0
- stouputils/continuous_delivery/cd_utils.py +243 -0
- stouputils/continuous_delivery/github.py +522 -0
- stouputils/continuous_delivery/pypi.py +130 -0
- stouputils/continuous_delivery/pyproject.py +147 -0
- stouputils/continuous_delivery/stubs.py +86 -0
- stouputils/ctx.py +408 -0
- stouputils/data_science/config/get.py +51 -0
- stouputils/data_science/config/set.py +125 -0
- stouputils/data_science/data_processing/image/__init__.py +66 -0
- stouputils/data_science/data_processing/image/auto_contrast.py +79 -0
- stouputils/data_science/data_processing/image/axis_flip.py +58 -0
- stouputils/data_science/data_processing/image/bias_field_correction.py +74 -0
- stouputils/data_science/data_processing/image/binary_threshold.py +73 -0
- stouputils/data_science/data_processing/image/blur.py +59 -0
- stouputils/data_science/data_processing/image/brightness.py +54 -0
- stouputils/data_science/data_processing/image/canny.py +110 -0
- stouputils/data_science/data_processing/image/clahe.py +92 -0
- stouputils/data_science/data_processing/image/common.py +30 -0
- stouputils/data_science/data_processing/image/contrast.py +53 -0
- stouputils/data_science/data_processing/image/curvature_flow_filter.py +74 -0
- stouputils/data_science/data_processing/image/denoise.py +378 -0
- stouputils/data_science/data_processing/image/histogram_equalization.py +123 -0
- stouputils/data_science/data_processing/image/invert.py +64 -0
- stouputils/data_science/data_processing/image/laplacian.py +60 -0
- stouputils/data_science/data_processing/image/median_blur.py +52 -0
- stouputils/data_science/data_processing/image/noise.py +59 -0
- stouputils/data_science/data_processing/image/normalize.py +65 -0
- stouputils/data_science/data_processing/image/random_erase.py +66 -0
- stouputils/data_science/data_processing/image/resize.py +69 -0
- stouputils/data_science/data_processing/image/rotation.py +80 -0
- stouputils/data_science/data_processing/image/salt_pepper.py +68 -0
- stouputils/data_science/data_processing/image/sharpening.py +55 -0
- stouputils/data_science/data_processing/image/shearing.py +64 -0
- stouputils/data_science/data_processing/image/threshold.py +64 -0
- stouputils/data_science/data_processing/image/translation.py +71 -0
- stouputils/data_science/data_processing/image/zoom.py +83 -0
- stouputils/data_science/data_processing/image_augmentation.py +118 -0
- stouputils/data_science/data_processing/image_preprocess.py +183 -0
- stouputils/data_science/data_processing/prosthesis_detection.py +359 -0
- stouputils/data_science/data_processing/technique.py +481 -0
- stouputils/data_science/dataset/__init__.py +45 -0
- stouputils/data_science/dataset/dataset.py +292 -0
- stouputils/data_science/dataset/dataset_loader.py +135 -0
- stouputils/data_science/dataset/grouping_strategy.py +296 -0
- stouputils/data_science/dataset/image_loader.py +100 -0
- stouputils/data_science/dataset/xy_tuple.py +696 -0
- stouputils/data_science/metric_dictionnary.py +106 -0
- stouputils/data_science/metric_utils.py +847 -0
- stouputils/data_science/mlflow_utils.py +206 -0
- stouputils/data_science/models/abstract_model.py +149 -0
- stouputils/data_science/models/all.py +85 -0
- stouputils/data_science/models/base_keras.py +765 -0
- stouputils/data_science/models/keras/all.py +38 -0
- stouputils/data_science/models/keras/convnext.py +62 -0
- stouputils/data_science/models/keras/densenet.py +50 -0
- stouputils/data_science/models/keras/efficientnet.py +60 -0
- stouputils/data_science/models/keras/mobilenet.py +56 -0
- stouputils/data_science/models/keras/resnet.py +52 -0
- stouputils/data_science/models/keras/squeezenet.py +233 -0
- stouputils/data_science/models/keras/vgg.py +42 -0
- stouputils/data_science/models/keras/xception.py +38 -0
- stouputils/data_science/models/keras_utils/callbacks/__init__.py +20 -0
- stouputils/data_science/models/keras_utils/callbacks/colored_progress_bar.py +219 -0
- stouputils/data_science/models/keras_utils/callbacks/learning_rate_finder.py +148 -0
- stouputils/data_science/models/keras_utils/callbacks/model_checkpoint_v2.py +31 -0
- stouputils/data_science/models/keras_utils/callbacks/progressive_unfreezing.py +249 -0
- stouputils/data_science/models/keras_utils/callbacks/warmup_scheduler.py +66 -0
- stouputils/data_science/models/keras_utils/losses/__init__.py +12 -0
- stouputils/data_science/models/keras_utils/losses/next_generation_loss.py +56 -0
- stouputils/data_science/models/keras_utils/visualizations.py +416 -0
- stouputils/data_science/models/model_interface.py +939 -0
- stouputils/data_science/models/sandbox.py +116 -0
- stouputils/data_science/range_tuple.py +234 -0
- stouputils/data_science/scripts/augment_dataset.py +77 -0
- stouputils/data_science/scripts/exhaustive_process.py +133 -0
- stouputils/data_science/scripts/preprocess_dataset.py +70 -0
- stouputils/data_science/scripts/routine.py +168 -0
- stouputils/data_science/utils.py +285 -0
- stouputils/decorators.py +605 -0
- stouputils/image.py +441 -0
- stouputils/installer/__init__.py +18 -0
- stouputils/installer/common.py +67 -0
- stouputils/installer/downloader.py +101 -0
- stouputils/installer/linux.py +144 -0
- stouputils/installer/main.py +223 -0
- stouputils/installer/windows.py +136 -0
- stouputils/io.py +486 -0
- stouputils/parallel.py +483 -0
- stouputils/print.py +482 -0
- stouputils/py.typed +1 -0
- stouputils/stouputils/__init__.pyi +15 -0
- stouputils/stouputils/_deprecated.pyi +12 -0
- stouputils/stouputils/all_doctests.pyi +46 -0
- stouputils/stouputils/applications/__init__.pyi +2 -0
- stouputils/stouputils/applications/automatic_docs.pyi +106 -0
- stouputils/stouputils/applications/upscaler/__init__.pyi +3 -0
- stouputils/stouputils/applications/upscaler/config.pyi +18 -0
- stouputils/stouputils/applications/upscaler/image.pyi +109 -0
- stouputils/stouputils/applications/upscaler/video.pyi +60 -0
- stouputils/stouputils/archive.pyi +67 -0
- stouputils/stouputils/backup.pyi +109 -0
- stouputils/stouputils/collections.pyi +86 -0
- stouputils/stouputils/continuous_delivery/__init__.pyi +5 -0
- stouputils/stouputils/continuous_delivery/cd_utils.pyi +129 -0
- stouputils/stouputils/continuous_delivery/github.pyi +162 -0
- stouputils/stouputils/continuous_delivery/pypi.pyi +53 -0
- stouputils/stouputils/continuous_delivery/pyproject.pyi +67 -0
- stouputils/stouputils/continuous_delivery/stubs.pyi +39 -0
- stouputils/stouputils/ctx.pyi +211 -0
- stouputils/stouputils/decorators.pyi +252 -0
- stouputils/stouputils/image.pyi +172 -0
- stouputils/stouputils/installer/__init__.pyi +5 -0
- stouputils/stouputils/installer/common.pyi +39 -0
- stouputils/stouputils/installer/downloader.pyi +24 -0
- stouputils/stouputils/installer/linux.pyi +39 -0
- stouputils/stouputils/installer/main.pyi +57 -0
- stouputils/stouputils/installer/windows.pyi +31 -0
- stouputils/stouputils/io.pyi +213 -0
- stouputils/stouputils/parallel.pyi +216 -0
- stouputils/stouputils/print.pyi +136 -0
- stouputils/stouputils/version_pkg.pyi +15 -0
- stouputils/version_pkg.py +189 -0
- stouputils-1.14.0.dist-info/METADATA +178 -0
- stouputils-1.14.0.dist-info/RECORD +140 -0
- stouputils-1.14.0.dist-info/WHEEL +4 -0
- stouputils-1.14.0.dist-info/entry_points.txt +3 -0
stouputils/parallel.py
ADDED
@@ -0,0 +1,483 @@
"""
This module provides utility functions for parallel processing, such as:

- multiprocessing(): Execute a function in parallel using multiprocessing
- multithreading(): Execute a function in parallel using multithreading
- run_in_subprocess(): Execute a function in a subprocess with args and kwargs

I highly encourage you to read the function docstrings to understand when to use each method.

.. image:: https://raw.githubusercontent.com/Stoupy51/stouputils/refs/heads/main/assets/parallel_module.gif
    :alt: stouputils parallel examples
"""

# Imports
import os
import time
from collections.abc import Callable, Iterable
from typing import Any, TypeVar, cast

from .ctx import SetMPStartMethod
from .print import BAR_FORMAT, MAGENTA


# Small test functions for doctests
def doctest_square(x: int) -> int:
    return x * x
def doctest_slow(x: int) -> int:
    time.sleep(0.1)
    return x

# Constants
CPU_COUNT: int = cast(int, os.cpu_count())
T = TypeVar("T")
R = TypeVar("R")

# Functions
def multiprocessing[T, R](
    func: Callable[..., R] | list[Callable[..., R]],
    args: Iterable[T],
    use_starmap: bool = False,
    chunksize: int = 1,
    desc: str = "",
    max_workers: int | float = CPU_COUNT,
    delay_first_calls: float = 0,
    color: str = MAGENTA,
    bar_format: str = BAR_FORMAT,
    ascii: bool = False,
    smooth_tqdm: bool = True,
    **tqdm_kwargs: Any
) -> list[R]:
    r""" Method to execute a function in parallel using multiprocessing

    - For CPU-bound operations where the GIL (Global Interpreter Lock) is a bottleneck.
    - When the task can be divided into smaller, independent sub-tasks that can be executed concurrently.
    - For computationally intensive tasks like scientific simulations, data analysis, or machine learning workloads.

    Args:
        func (Callable | list[Callable]): Function to execute, or list of functions (one per argument)
        args (Iterable): Iterable of arguments to pass to the function(s)
        use_starmap (bool): Whether to use starmap or not (Defaults to False):
            True means the function will be called like func(\*args[i]) instead of func(args[i])
        chunksize (int): Number of arguments to process at a time
            (Defaults to 1 for proper progress bar display)
        desc (str): Description displayed in the progress bar
            (if not provided no progress bar will be displayed)
        max_workers (int | float): Number of workers to use (Defaults to CPU_COUNT), -1 means CPU_COUNT.
            If float between 0 and 1, it's treated as a percentage of CPU_COUNT.
            If negative float between -1 and 0, it's treated as a percentage of len(args).
        delay_first_calls (float): Apply i*delay_first_calls seconds delay to the first "max_workers" calls.
            For instance, the first process will be delayed by 0 seconds, the second by 1 second, etc.
            (Defaults to 0): This can be useful to avoid functions being called in the same second.
        color (str): Color of the progress bar (Defaults to MAGENTA)
        bar_format (str): Format of the progress bar (Defaults to BAR_FORMAT)
        ascii (bool): Whether to use ASCII or Unicode characters for the progress bar
        smooth_tqdm (bool): Whether to enable smooth progress bar updates by setting miniters and mininterval (Defaults to True)
        **tqdm_kwargs (Any): Additional keyword arguments to pass to tqdm

    Returns:
        list[object]: Results of the function execution

    Examples:
    .. code-block:: python

        > multiprocessing(doctest_square, args=[1, 2, 3])
        [1, 4, 9]

        > multiprocessing(int.__mul__, [(1,2), (3,4), (5,6)], use_starmap=True)
        [2, 12, 30]

        > # Using a list of functions (one per argument)
        > multiprocessing([doctest_square, doctest_square, doctest_square], [1, 2, 3])
        [1, 4, 9]

        > # Will process in parallel with progress bar
        > multiprocessing(doctest_slow, range(10), desc="Processing")
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

        > # Will process in parallel with progress bar and delay the first threads
        > multiprocessing(
        .     doctest_slow,
        .     range(10),
        .     desc="Processing with delay",
        .     max_workers=2,
        .     delay_first_calls=0.6
        . )
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
    """
    # Imports
    import multiprocessing as mp
    from multiprocessing import Pool

    from tqdm.auto import tqdm
    from tqdm.contrib.concurrent import process_map  # pyright: ignore[reportUnknownVariableType]

    # Handle parameters
    args = list(args)  # Ensure we have a list (not other iterable)
    if max_workers == -1:
        max_workers = CPU_COUNT
    if isinstance(max_workers, float):
        if max_workers > 0:
            assert max_workers <= 1, "max_workers as positive float must be between 0 and 1 (percentage of CPU_COUNT)"
            max_workers = int(max_workers * CPU_COUNT)
        else:
            assert -1 <= max_workers < 0, "max_workers as negative float must be between -1 and 0 (percentage of len(args))"
            max_workers = int(-max_workers * len(args))
    verbose: bool = desc != ""
    desc, func, args = _handle_parameters(func, args, use_starmap, delay_first_calls, max_workers, desc, color)
    if bar_format == BAR_FORMAT:
        bar_format = bar_format.replace(MAGENTA, color)
    if smooth_tqdm:
        tqdm_kwargs.setdefault("mininterval", 0.0)
        try:
            total = len(args)  # type: ignore
            import shutil
            width = shutil.get_terminal_size().columns
            tqdm_kwargs.setdefault("miniters", max(1, total // width))
        except (TypeError, OSError):
            tqdm_kwargs.setdefault("miniters", 1)

    # Do multiprocessing only if there is more than 1 argument and more than 1 CPU
    if max_workers > 1 and len(args) > 1:
        def process() -> list[Any]:
            if verbose:
                return list(process_map(
                    func, args, max_workers=max_workers, chunksize=chunksize, desc=desc, bar_format=bar_format, ascii=ascii, **tqdm_kwargs
                ))  # type: ignore
            else:
                with Pool(max_workers) as pool:
                    return list(pool.map(func, args, chunksize=chunksize))  # type: ignore
        try:
            return process()
        except RuntimeError as e:
            if "SemLock created in a fork context is being shared with a process in a spawn context" in str(e):

                # Try with alternate start method
                with SetMPStartMethod("spawn" if mp.get_start_method() != "spawn" else "fork"):
                    return process()
            else:  # Re-raise if it's not the SemLock error
                raise

    # Single process execution
    else:
        if verbose:
            return [func(arg) for arg in tqdm(args, total=len(args), desc=desc, bar_format=bar_format, ascii=ascii, **tqdm_kwargs)]
        else:
            return [func(arg) for arg in args]


def multithreading[T, R](
    func: Callable[..., R] | list[Callable[..., R]],
    args: Iterable[T],
    use_starmap: bool = False,
    desc: str = "",
    max_workers: int | float = CPU_COUNT,
    delay_first_calls: float = 0,
    color: str = MAGENTA,
    bar_format: str = BAR_FORMAT,
    ascii: bool = False,
    smooth_tqdm: bool = True,
    **tqdm_kwargs: Any
) -> list[R]:
    r""" Method to execute a function in parallel using multithreading, you should use it:

    - For I/O-bound operations where the GIL is not a bottleneck, such as network requests or disk operations.
    - When the task involves waiting for external resources, such as network responses or user input.
    - For operations that involve a lot of waiting, such as GUI event handling or handling user input.

    Args:
        func (Callable | list[Callable]): Function to execute, or list of functions (one per argument)
        args (Iterable): Iterable of arguments to pass to the function(s)
        use_starmap (bool): Whether to use starmap or not (Defaults to False):
            True means the function will be called like func(\*args[i]) instead of func(args[i])
        desc (str): Description displayed in the progress bar
            (if not provided no progress bar will be displayed)
        max_workers (int | float): Number of workers to use (Defaults to CPU_COUNT), -1 means CPU_COUNT.
            If float between 0 and 1, it's treated as a percentage of CPU_COUNT.
            If negative float between -1 and 0, it's treated as a percentage of len(args).
        delay_first_calls (float): Apply i*delay_first_calls seconds delay to the first "max_workers" calls.
            For instance with value to 1, the first thread will be delayed by 0 seconds, the second by 1 second, etc.
            (Defaults to 0): This can be useful to avoid functions being called in the same second.
        color (str): Color of the progress bar (Defaults to MAGENTA)
        bar_format (str): Format of the progress bar (Defaults to BAR_FORMAT)
        ascii (bool): Whether to use ASCII or Unicode characters for the progress bar
        smooth_tqdm (bool): Whether to enable smooth progress bar updates by setting miniters and mininterval (Defaults to True)
        **tqdm_kwargs (Any): Additional keyword arguments to pass to tqdm

    Returns:
        list[object]: Results of the function execution

    Examples:
    .. code-block:: python

        > multithreading(doctest_square, args=[1, 2, 3])
        [1, 4, 9]

        > multithreading(int.__mul__, [(1,2), (3,4), (5,6)], use_starmap=True)
        [2, 12, 30]

        > # Using a list of functions (one per argument)
        > multithreading([doctest_square, doctest_square, doctest_square], [1, 2, 3])
        [1, 4, 9]

        > # Will process in parallel with progress bar
        > multithreading(doctest_slow, range(10), desc="Threading")
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

        > # Will process in parallel with progress bar and delay the first threads
        > multithreading(
        .     doctest_slow,
        .     range(10),
        .     desc="Threading with delay",
        .     max_workers=2,
        .     delay_first_calls=0.6
        . )
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
    """
    # Imports
    from concurrent.futures import ThreadPoolExecutor

    from tqdm.auto import tqdm

    # Handle parameters
    args = list(args)  # Ensure we have a list (not other iterable)
    if max_workers == -1:
        max_workers = CPU_COUNT
    if isinstance(max_workers, float):
        if max_workers > 0:
            assert max_workers <= 1, "max_workers as positive float must be between 0 and 1 (percentage of CPU_COUNT)"
            max_workers = int(max_workers * CPU_COUNT)
        else:
            assert -1 <= max_workers < 0, "max_workers as negative float must be between -1 and 0 (percentage of len(args))"
            max_workers = int(-max_workers * len(args))
    verbose: bool = desc != ""
    desc, func, args = _handle_parameters(func, args, use_starmap, delay_first_calls, max_workers, desc, color)
    if bar_format == BAR_FORMAT:
        bar_format = bar_format.replace(MAGENTA, color)
    if smooth_tqdm:
        tqdm_kwargs.setdefault("mininterval", 0.0)
        try:
            total = len(args)  # type: ignore
            import shutil
            width = shutil.get_terminal_size().columns
            tqdm_kwargs.setdefault("miniters", max(1, total // width))
        except (TypeError, OSError):
            tqdm_kwargs.setdefault("miniters", 1)

    # Do multithreading only if there is more than 1 argument and more than 1 CPU
    if max_workers > 1 and len(args) > 1:
        if verbose:
            with ThreadPoolExecutor(max_workers) as executor:
                return list(tqdm(executor.map(func, args), total=len(args), desc=desc, bar_format=bar_format, ascii=ascii, **tqdm_kwargs))
        else:
            with ThreadPoolExecutor(max_workers) as executor:
                return list(executor.map(func, args))

    # Single process execution
    else:
        if verbose:
            return [func(arg) for arg in tqdm(args, total=len(args), desc=desc, bar_format=bar_format, ascii=ascii, **tqdm_kwargs)]
        else:
            return [func(arg) for arg in args]


def run_in_subprocess[R](
    func: Callable[..., R],
    *args: Any,
    timeout: float | None = None,
    no_join: bool = False,
    **kwargs: Any
) -> R:
    """ Execute a function in a subprocess with positional and keyword arguments.

    This is useful when you need to run a function in isolation to avoid memory leaks,
    resource conflicts, or to ensure a clean execution environment. The subprocess will
    be created, run the function with the provided arguments, and return the result.

    Args:
        func (Callable): The function to execute in a subprocess.
            (SHOULD BE A TOP-LEVEL FUNCTION TO BE PICKLABLE)
        *args (Any): Positional arguments to pass to the function.
        timeout (float | None): Maximum time in seconds to wait for the subprocess.
            If None, wait indefinitely. If the subprocess exceeds this time, it will be terminated.
        no_join (bool): If True, do not wait for the subprocess to finish (fire-and-forget).
        **kwargs (Any): Keyword arguments to pass to the function.

    Returns:
        R: The return value of the function.

    Raises:
        RuntimeError: If the subprocess exits with a non-zero exit code or times out.
        TimeoutError: If the subprocess exceeds the specified timeout.

    Examples:
    .. code-block:: python

        > # Simple function execution
        > run_in_subprocess(doctest_square, 5)
        25

        > # Function with multiple arguments
        > def add(a: int, b: int) -> int:
        .     return a + b
        > run_in_subprocess(add, 10, 20)
        30

        > # Function with keyword arguments
        > def greet(name: str, greeting: str = "Hello") -> str:
        .     return f"{greeting}, {name}!"
        > run_in_subprocess(greet, "World", greeting="Hi")
        'Hi, World!'

        > # With timeout to prevent hanging
        > run_in_subprocess(some_gpu_func, data, timeout=300.0)
    """
    import multiprocessing as mp
    from multiprocessing import Queue

    # Create a queue to get the result from the subprocess
    result_queue: Queue[R | Exception] = Queue()

    # Create and start the subprocess using the module-level wrapper
    process: mp.Process = mp.Process(
        target=_subprocess_wrapper,
        args=(result_queue, func, args, kwargs)
    )
    process.start()

    # Join with timeout to prevent indefinite hanging
    if no_join:
        return None  # type: ignore
    process.join(timeout=timeout)

    # Check if process is still alive (timed out)
    if process.is_alive():
        process.terminate()
        time.sleep(0.5)  # Give it a moment to terminate gracefully
        if process.is_alive():
            process.kill()
        process.join()
        raise TimeoutError(f"Subprocess exceeded timeout of {timeout} seconds and was terminated")

    # Check exit code
    if process.exitcode != 0:
        # Try to get any exception from the queue (non-blocking)
        error_msg = f"Subprocess failed with exit code {process.exitcode}"
        try:
            if not result_queue.empty():
                result_or_exception = result_queue.get_nowait()
                if isinstance(result_or_exception, Exception):
                    raise result_or_exception
        except Exception:
            pass
        raise RuntimeError(error_msg)

    # Retrieve the result
    try:
        result_or_exception = result_queue.get_nowait()
        if isinstance(result_or_exception, Exception):
            raise result_or_exception
        return result_or_exception
    except Exception as e:
        raise RuntimeError("Subprocess did not return any result") from e


# "Private" function for subprocess wrapper (must be at module level for pickling on Windows)
def _subprocess_wrapper[R](
    result_queue: Any,
    func: Callable[..., R],
    args: tuple[Any, ...],
    kwargs: dict[str, Any]
) -> None:
    """ Wrapper function to execute the target function and store the result in the queue.

    Must be at module level to be pickable on Windows (spawn context).

    Args:
        result_queue (multiprocessing.Queue): Queue to store the result or exception.
        func (Callable): The target function to execute.
        args (tuple): Positional arguments for the function.
        kwargs (dict): Keyword arguments for the function.
    """
    try:
        result: R = func(*args, **kwargs)
        result_queue.put(result)
    except Exception as e:
        result_queue.put(e)

# "Private" function to use starmap
def _starmap[T, R](args: tuple[Callable[[T], R], list[T]]) -> R:
    r""" Private function to use starmap using args[0](\*args[1])

    Args:
        args (tuple): Tuple containing the function and the arguments list to pass to the function
    Returns:
        object: Result of the function execution
    """
    func, arguments = args
    return func(*arguments)

# "Private" function to apply delay before calling the target function
def _delayed_call[T, R](args: tuple[Callable[[T], R], float, T]) -> R:
    """ Private function to apply delay before calling the target function

    Args:
        args (tuple): Tuple containing the function, delay in seconds, and the argument to pass to the function
    Returns:
        object: Result of the function execution
    """
    func, delay, arg = args
    time.sleep(delay)
    return func(arg)

# "Private" function to handle parameters for multiprocessing or multithreading functions
def _handle_parameters[T, R](
    func: Callable[[T], R] | list[Callable[[T], R]],
    args: list[T],
    use_starmap: bool,
    delay_first_calls: float,
    max_workers: int,
    desc: str,
    color: str
) -> tuple[str, Callable[[T], R], list[T]]:
    r""" Private function to handle the parameters for multiprocessing or multithreading functions

    Args:
        func (Callable | list[Callable]): Function to execute, or list of functions (one per argument)
        args (list): List of arguments to pass to the function(s)
        use_starmap (bool): Whether to use starmap or not (Defaults to False):
            True means the function will be called like func(\*args[i]) instead of func(args[i])
        delay_first_calls (int): Apply i*delay_first_calls seconds delay to the first "max_workers" calls.
            For instance, the first process will be delayed by 0 seconds, the second by 1 second, etc. (Defaults to 0):
            This can be useful to avoid functions being called in the same second.
        max_workers (int): Number of workers to use (Defaults to CPU_COUNT)
        desc (str): Description of the function execution displayed in the progress bar
        color (str): Color of the progress bar

    Returns:
        tuple[str, Callable[[T], R], list[T]]: Tuple containing the description, function, and arguments
    """
    desc = color + desc

    # Handle list of functions: validate and convert to starmap format
    if isinstance(func, list):
        func = cast(list[Callable[[T], R]], func)
        assert len(func) == len(args), f"Length mismatch: {len(func)} functions but {len(args)} arguments"
        args = [(f, arg if use_starmap else (arg,)) for f, arg in zip(func, args, strict=False)]  # type: ignore
        func = _starmap  # type: ignore

    # If use_starmap is True, we use the _starmap function
    elif use_starmap:
        args = [(func, arg) for arg in args]  # type: ignore
        func = _starmap  # type: ignore

    # Prepare delayed function calls if delay_first_calls is set
    if delay_first_calls > 0:
        args = [
            (func, i * delay_first_calls if i < max_workers else 0, arg)  # type: ignore
            for i, arg in enumerate(args)
        ]
        func = _delayed_call  # type: ignore

    return desc, func, args  # type: ignore
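
For orientation, here is a minimal usage sketch of the three public helpers added in this file, adapted from the docstring examples above. It assumes stouputils 1.14.0 is installed; square is a hypothetical user function, kept at module level so it stays picklable for the process-based helpers.

# Minimal sketch, assuming stouputils 1.14.0 is installed.
# square is a hypothetical user function (module-level so it can be pickled).
from stouputils.parallel import multiprocessing, multithreading, run_in_subprocess

def square(x: int) -> int:
    return x * x

if __name__ == "__main__":
    # CPU-bound work spread across processes, with a progress bar
    print(multiprocessing(square, range(10), desc="Squaring"))

    # Thread-based variant; use_starmap unpacks each tuple into the call
    print(multithreading(int.__mul__, [(1, 2), (3, 4), (5, 6)], use_starmap=True))

    # Run a top-level function in an isolated subprocess, bounded by a timeout
    print(run_in_subprocess(square, 7, timeout=30.0))

The if __name__ == "__main__" guard matters here: multiprocessing and run_in_subprocess may spawn new interpreters (notably on Windows), which re-import the calling module.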