speedy-utils 1.0.3__py3-none-any.whl → 1.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,198 @@
+ import multiprocessing
+ import os
+ import traceback
+ from collections.abc import Callable, Iterable, Iterator
+ from concurrent.futures import ProcessPoolExecutor, as_completed
+ from itertools import islice
+ from typing import Any, cast
+
+ # Force the "spawn" start method so worker processes behave consistently
+ # across platforms.
+ if hasattr(multiprocessing, "set_start_method"):
+     try:
+         multiprocessing.set_start_method("spawn", force=True)
+     except RuntimeError:
+         pass
+
+ try:
+     from tqdm import tqdm
+ except ImportError:  # pragma: no cover
+     tqdm = None  # type: ignore[assignment]
+
+
+ # ──── internal helpers ────────────────────────────────────────────────────
+
+
+ def _group_iter(src: Iterable[Any], size: int) -> Iterable[list[Any]]:
+     """Yield *size*-sized chunks from *src*."""
+     it = iter(src)
+     while chunk := list(islice(it, size)):
+         yield chunk
+
+
+ def _short_tb() -> str:
+     """Return the current traceback with this module's frames filtered out."""
+     tb = "".join(traceback.format_exc())
+     return "\n".join(ln for ln in tb.splitlines() if "multi_process" not in ln)
+
+
+ def _safe_call(func: Callable, obj, fixed):
+     try:
+         return func(obj, **fixed)
+     except Exception as exc:
+         func_name = getattr(func, "__name__", str(func))
+         raise RuntimeError(
+             f"{func_name}({obj!r}) failed: {exc}\n{_short_tb()}"
+         ) from exc
+
+
+ def _worker_process(
+     func: Callable,
+     item_batch: Any,
+     fixed_kwargs: dict,
+     batch_size: int,
+     stop_on_error: bool = True,
+ ):
+     """Worker function executed in each process."""
+     if batch_size > 1:
+         results = []
+         for itm in item_batch:
+             try:
+                 results.append(_safe_call(func, itm, fixed_kwargs))
+             except Exception:
+                 # Honour stop_on_error inside a batch too; otherwise the
+                 # caller's setting would be silently ignored for batch > 1.
+                 if stop_on_error:
+                     raise
+                 results.append(None)
+         return results
+     return _safe_call(func, item_batch, fixed_kwargs)
+
+
+ # ──── public API ──────────────────────────────────────────────────────────
+ def multi_process(
+     func: Callable[[Any], Any],
+     inputs: Iterable[Any],
+     *,
+     workers: int | None = None,
+     batch: int = 1,
+     ordered: bool = True,
+     progress: bool = False,
+     inflight: int | None = None,
+     timeout: float | None = None,
+     stop_on_error: bool = True,
+     process_update_interval: int = 10,
+     **fixed_kwargs,
+ ) -> list[Any]:
+     """
+     Simple multi‑processing parallel map that returns a *list*.
+
+     Parameters
+     ----------
+     func – target callable executed in separate processes; must have the
+            form ``f(obj, ...)``.
+     inputs – iterable with the objects to process.
+     workers – process pool size (defaults to :pyfunc:`os.cpu_count`).
+     batch – package *batch* inputs into one call to reduce IPC overhead.
+     ordered – keep original order; if ``False`` results stream as finished.
+     progress – show a tqdm bar (requires *tqdm*).
+     inflight – max logical items concurrently submitted
+                *(default: ``workers × 4``)*.
+     timeout – overall timeout for the mapping (seconds).
+     stop_on_error – raise immediately on first exception (default) or
+                     substitute each failing result with ``None``.
+     process_update_interval – progress-bar redraw frequency (logical items).
+     **fixed_kwargs – static keyword args forwarded to every ``func()`` call.
+     """
+     if workers is None:
+         workers = os.cpu_count() or 1
+     if inflight is None:
+         inflight = workers * 4
+     if batch < 1:
+         raise ValueError("batch must be ≥ 1")
+
+     try:
+         n_inputs = len(inputs)  # type: ignore[arg-type]
+     except Exception:
+         n_inputs = None
+
+     src_iter: Iterator[Any] = iter(inputs)
+     if batch > 1:
+         src_iter = cast(Iterator[Any], _group_iter(src_iter, batch))
+
+     logical_total = n_inputs
+     bar = None
+     last_bar = 0
+     if progress and tqdm is not None and logical_total is not None:
+         bar = tqdm(
+             total=logical_total,
+             ncols=80,
+             colour="green",
+             bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt}"
+             " [{elapsed}<{remaining}, {rate_fmt}{postfix}]",
+         )
+
+     if ordered and logical_total is not None:
+         results: list[Any] = [None] * logical_total
+     else:
+         results = []
+
+     completed = 0
+     next_idx = 0
+
+     with ProcessPoolExecutor(max_workers=workers) as pool:
+         futures = set()
+
+         def _submit(arg: Any, idx: int) -> int:
+             """Submit one task (an item or a batch); return its logical size."""
+             n_items = len(arg) if batch > 1 else 1
+             fut = pool.submit(
+                 _worker_process, func, arg, fixed_kwargs, batch, stop_on_error
+             )
+             fut.idx = idx  # type: ignore[attr-defined]
+             fut.n_items = n_items  # type: ignore[attr-defined]
+             futures.add(fut)
+             return n_items
+
+         # prime the pool
+         for _ in range(min(inflight, workers)):
+             try:
+                 arg = next(src_iter)
+             except StopIteration:
+                 break
+             next_idx += _submit(arg, next_idx)
+
+         while futures:
+             for fut in as_completed(futures, timeout=timeout):
+                 futures.remove(fut)
+                 idx = fut.idx  # type: ignore[attr-defined]
+                 try:
+                     res = fut.result()
+                 except Exception:
+                     if stop_on_error:
+                         raise
+                     # substitute None for every logical item in the failed
+                     # task; use the recorded size, since the final chunk may
+                     # be shorter than *batch*
+                     n_items = fut.n_items  # type: ignore[attr-defined]
+                     res = [None] * n_items if batch > 1 else None
+
+                 out_items = res if batch > 1 else [res]
+
+                 if ordered and logical_total is not None:
+                     results[idx : idx + len(out_items)] = out_items
+                 else:
+                     results.extend(out_items)
+
+                 completed += len(out_items)
+
+                 if bar and completed - last_bar >= process_update_interval:
+                     bar.update(completed - last_bar)
+                     last_bar = completed
+
+                 # keep the queue full
+                 try:
+                     while next_idx - completed < inflight:
+                         arg = next(src_iter)
+                         next_idx += _submit(arg, next_idx)
+                 except StopIteration:
+                     pass
+                 # restart as_completed() with the updated future set
+                 break
+
+     if bar:
+         bar.update(completed - last_bar)
+         bar.close()
+
+     return results
+
+
+ __all__ = ["multi_process"]
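
For orientation, a minimal usage sketch of the `multi_process` API above. The `square` function, the inputs, and the `from speedy_utils import multi_process` path are illustrative assumptions, not part of this diff. Because the module forces the "spawn" start method, the target function must be importable and the call belongs under a `__main__` guard:

    from speedy_utils import multi_process  # assumed import path

    def square(x, *, offset=0):
        return x * x + offset

    if __name__ == "__main__":
        # 100 inputs, 4 worker processes, batches of 10 to cut IPC overhead;
        # offset=1 is forwarded to every call via **fixed_kwargs
        out = multi_process(square, range(100), workers=4, batch=10, offset=1)
        print(out[:5])  # [1, 2, 5, 10, 17]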
@@ -0,0 +1,327 @@
+ """Provides thread-based parallel execution utilities."""
+
+ import os
+ import traceback
+ from collections.abc import Callable, Iterable
+ from concurrent.futures import ThreadPoolExecutor, as_completed
+ from itertools import islice
+ from typing import Any, TypeVar
+
+ from loguru import logger
+
+ try:
+     from tqdm import tqdm
+ except ImportError:  # pragma: no cover
+     tqdm = None  # type: ignore[assignment]
+
+ # Sensible defaults
+ DEFAULT_WORKERS = (os.cpu_count() or 4) * 2
+
+ T = TypeVar("T")
+ R = TypeVar("R")
+
+
+ def _group_iter(src: Iterable[T], size: int) -> Iterable[list[T]]:
+     """Yield successive chunks of the specified size from an iterable."""
+     it = iter(src)
+     while chunk := list(islice(it, size)):
+         yield chunk
+
+
+ def _short_tb() -> str:
+     """Return a shortened traceback, excluding internal frames."""
+     tb = "".join(traceback.format_exc())
+     # hide frames inside this helper to keep logs short
+     return "\n".join(ln for ln in tb.splitlines() if "multi_thread.py" not in ln)
+
+
+ def _worker(item: T, func: Callable[[T], R], fixed_kwargs: dict[str, Any]) -> R:
+     """Execute the function with an item and fixed kwargs."""
+     return func(item, **fixed_kwargs)
+
+
+ # ────────────────────────────────────────────────────────────
+ # main API
+ # ────────────────────────────────────────────────────────────
+ def multi_thread(
+     func: Callable,
+     inputs: Iterable[Any],
+     *,
+     workers: int | None = DEFAULT_WORKERS,
+     batch: int = 1,
+     ordered: bool = True,
+     progress: bool = True,
+     progress_update: int = 10,
+     prefetch_factor: int = 4,
+     timeout: float | None = None,
+     stop_on_error: bool = True,
+     n_proc: int = 0,
+     store_output_pkl_file: str | None = None,
+     **fixed_kwargs,
+ ) -> list[Any]:
+     """
+     ThreadPool **map** that returns a *list*.
+
+     Parameters
+     ----------
+     func – target callable.
+     inputs – iterable with the arguments.
+     workers – defaults to ``os.cpu_count() * 2``.
+     batch – package *batch* inputs into one call for lower overhead.
+     ordered – keep original order (costs memory); if ``False`` results
+               are appended as soon as they finish.
+     progress – show a tqdm bar (requires *tqdm* installed).
+     progress_update – bar redraw frequency (logical items, *not* batches).
+     prefetch_factor – in‑flight tasks ≈ ``workers * prefetch_factor``.
+     timeout – overall timeout (seconds) for the mapping.
+     stop_on_error – raise immediately on first exception (default). If
+                     ``False`` the failing task’s result becomes ``None``.
+     n_proc – if > 1, split the inputs across this many worker *processes*,
+              each running its own thread pool.
+     store_output_pkl_file – optional path; results are pickled there before
+                             returning (used by the ``n_proc`` fan-out).
+     **fixed_kwargs – static keyword args forwarded to every ``func()`` call.
+     """
+ from speedy_utils import dump_json_or_pickle, load_by_ext
83
+
84
+ if n_proc > 1:
85
+ import tempfile
86
+
87
+ from fastcore.all import threaded
88
+
89
+ # split the inputs by nproc
90
+ inputs = list(inputs)
91
+ n_per_proc = max(len(inputs) // n_proc, 1)
92
+ proc_inputs_list = []
93
+ for i in range(0, len(inputs), n_per_proc):
94
+ proc_inputs_list.append(inputs[i : i + n_per_proc])
95
+ procs = []
96
+ in_process_multi_thread = threaded(process=True)(multi_thread)
97
+
98
+ for proc_id, proc_inputs in enumerate(proc_inputs_list):
99
+ with tempfile.NamedTemporaryFile(
100
+ delete=False, suffix="multi_thread.pkl"
101
+ ) as tmp_file:
102
+ file_pkl = tmp_file.name
103
+ assert isinstance(in_process_multi_thread, Callable)
104
+ proc = in_process_multi_thread(
105
+ func,
106
+ proc_inputs,
107
+ workers=workers,
108
+ batch=batch,
109
+ ordered=ordered,
110
+ progress=proc_id == 0,
111
+ progress_update=progress_update,
112
+ prefetch_factor=prefetch_factor,
113
+ timeout=timeout,
114
+ stop_on_error=stop_on_error,
115
+ n_proc=0, # prevent recursion
116
+ store_output_pkl_file=file_pkl,
117
+ **fixed_kwargs,
118
+ )
119
+ procs.append([proc, file_pkl])
120
+ # join
121
+ results = []
122
+
123
+ for proc, file_pkl in procs:
124
+ proc.join()
125
+ logger.info(f"Done proc {proc=}")
126
+ results.extend(load_by_ext(file_pkl))
127
+ return results
128
+
+     try:
+         import pandas as pd
+
+         if isinstance(inputs, pd.DataFrame):
+             inputs = inputs.to_dict(orient="records")
+     except ImportError:
+         pass
+
+     try:
+         n_inputs = len(inputs)  # type: ignore[arg-type]
+     except Exception:
+         n_inputs = None
+     workers_val = workers if workers is not None else DEFAULT_WORKERS
+
+     if batch == 1 and n_inputs and n_inputs / max(workers_val, 1) > 20_000:
+         batch = 32  # empirically good for sub‑ms tasks
+
+     # ── build (maybe‑batched) source iterator ────────────────────────────
+     src_iter: Iterable[Any] = iter(inputs)
+     if batch > 1:
+         src_iter = _group_iter(src_iter, batch)
+     # Ensure src_iter is always an iterator
+     src_iter = iter(src_iter)
+
+     # total logical items (for bar & ordered pre‑allocation); batching does
+     # not change this – it is the number of *items*, not batches
+     logical_total = n_inputs
+
+     # ── progress bar ─────────────────────────────────────────────────────
+     bar = None
+     last_bar_update = 0
+     if progress and tqdm is not None and logical_total is not None:
+         bar = tqdm(
+             total=logical_total,
+             ncols=128,
+             colour="green",
+             bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt}"
+             " [{elapsed}<{remaining}, {rate_fmt}{postfix}]",
+         )
+
+     # ── prepare result container ─────────────────────────────────────────
+     if ordered and logical_total is not None:
+         results: list[Any] = [None] * logical_total
+     else:
+         results = []
+
+     # ── main execution loop ──────────────────────────────────────────────
+     max_inflight = workers_val * max(prefetch_factor, 1)
+     completed_items = 0
+     next_logical_idx = 0  # index assigned to the next submission
+
+     with ThreadPoolExecutor(max_workers=workers_val) as pool:
+         inflight = set()
+
+         def _submit(arg: Any, idx: int) -> int:
+             """Submit one task (an item or a batch); return its logical size."""
+             if batch > 1:
+                 fut = pool.submit(
+                     lambda items: [_worker(item, func, fixed_kwargs) for item in items],
+                     arg,
+                 )
+                 n_items = len(arg)
+             else:
+                 fut = pool.submit(_worker, arg, func, fixed_kwargs)
+                 n_items = 1
+             fut.idx = idx  # type: ignore[attr-defined]
+             inflight.add(fut)
+             return n_items
+
+         # prime the pool
+         for _ in range(max_inflight):
+             try:
+                 arg = next(src_iter)
+             except StopIteration:
+                 break
+             next_logical_idx += _submit(arg, next_logical_idx)
+
+         try:
+             # Process futures as they complete, topping the queue up so the
+             # pool stays busy.
+             while inflight:
+                 for fut in as_completed(inflight, timeout=timeout):
+                     inflight.remove(fut)
+                     idx = fut.idx  # type: ignore[attr-defined]
+                     try:
+                         res = fut.result()
+                     except Exception:
+                         if stop_on_error:
+                             raise
+                         res = None
+
+                     # flatten res to a list of logical outputs
+                     out_items = res if batch > 1 else [res]
+
+                     # Ensure out_items is a list (and thus Sized)
+                     if out_items is None:
+                         out_items = [None]
+                     elif not isinstance(out_items, list):
+                         out_items = (
+                             list(out_items)
+                             if isinstance(out_items, Iterable)
+                             else [out_items]
+                         )
+
+                     # store outputs
+                     if ordered and logical_total is not None:
+                         results[idx : idx + len(out_items)] = out_items
+                     else:
+                         results.extend(out_items)
+
+                     completed_items += len(out_items)
+
+                     # progress bar update
+                     if bar and completed_items - last_bar_update >= progress_update:
+                         bar.update(completed_items - last_bar_update)
+                         last_bar_update = completed_items
+                         # show pending and processing counts in the bar postfix
+                         submitted = next_logical_idx
+                         processing = min(len(inflight), workers_val)
+                         pending = (
+                             (logical_total - submitted)
+                             if logical_total is not None
+                             else None
+                         )
+                         postfix = {
+                             "pending": pending if pending is not None else "-",
+                             "processing": processing,
+                         }
+                         bar.set_postfix(postfix)
+
+                     # keep the queue full
+                     try:
+                         while next_logical_idx - completed_items < max_inflight:
+                             arg = next(src_iter)
+                             next_logical_idx += _submit(arg, next_logical_idx)
+                     except StopIteration:
+                         pass
+
+                     # restart as_completed() with the updated future set
+                     break
+
+         finally:
+             if bar:
+                 bar.update(completed_items - last_bar_update)
+                 bar.close()
+
+     if store_output_pkl_file:
+         dump_json_or_pickle(results, store_output_pkl_file)
+     return results
+
+
+ def multi_thread_standard(
+     fn: Callable[[Any], Any], items: Iterable[Any], workers: int = 4
+ ) -> list[Any]:
+     """Execute a function using a standard ThreadPoolExecutor.
+
+     A standard implementation of multi-threading using ThreadPoolExecutor.
+     Ensures the order of results matches the input order.
+
+     Parameters
+     ----------
+     fn : Callable
+         The function to execute for each item.
+     items : Iterable
+         The items to process.
+     workers : int, optional
+         Number of worker threads, by default 4.
+
+     Returns
+     -------
+     list
+         Results in the same order as the input items.
+     """
+     with ThreadPoolExecutor(max_workers=workers) as executor:
+         futures = [executor.submit(fn, item) for item in items]
+         results = [fut.result() for fut in futures]
+     return results
+
+
+ __all__ = ["multi_thread", "multi_thread_standard"]
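
Again for orientation, a minimal usage sketch of the `multi_thread` API above. The `word_count` helper, the inputs, and the `from speedy_utils import multi_thread` path are illustrative assumptions:

    from speedy_utils import multi_thread  # assumed import path

    def word_count(text, *, sep=" "):
        return len(text.split(sep))

    docs = ["a b c", "d e", "f"]
    # 8 threads; sep is forwarded to every call via **fixed_kwargs
    counts = multi_thread(word_count, docs, workers=8, progress=False, sep=" ")
    print(counts)  # [3, 2, 1]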
@@ -0,0 +1,108 @@
+ #!/usr/bin/env python3
+ import argparse
+ import itertools
+ import multiprocessing
+ import os
+ import shlex  # to properly escape command-line arguments
+ import shutil
+
+ taskset_path = shutil.which("taskset")
+
+
+ def assert_script(python_path):
+     """Refuse to run target scripts that do not read MP_ID / MP_TOTAL."""
+     code_str = open(python_path).read()
+     if "MP_ID" not in code_str or "MP_TOTAL" not in code_str:
+         example_code = (
+             'import os; MP_TOTAL = int(os.environ.get("MP_TOTAL")); MP_ID = int(os.environ.get("MP_ID"))\n'
+             "inputs = list(inputs[MP_ID::MP_TOTAL])"
+         )
+         # ANSI escape codes for coloring
+         YELLOW = "\033[93m"
+         RESET = "\033[0m"
+         raise_msg = (
+             f"MP_ID and MP_TOTAL not found in {python_path}, please add them.\n\n"
+             f"Example:\n{YELLOW}{example_code}{RESET}"
+         )
+         raise Exception(raise_msg)
+
+
+ def run_in_tmux(commands_to_run, tmux_name, num_windows):
+     """Write a helper script that runs each command in its own tmux window."""
+     with open("/tmp/start_multirun_tmux.sh", "w") as script_file:
+         script_file.write("#!/bin/bash\n\n")
+         # kill the session first if it already exists
+         script_file.write(f"tmux kill-session -t {tmux_name}\nsleep .1\n")
+         script_file.write(f"tmux new-session -d -s {tmux_name}\n")
+         for i, cmd in enumerate(itertools.cycle(commands_to_run)):
+             if i >= num_windows:
+                 break
+             window_name = f"{tmux_name}:{i}"
+             if i > 0:
+                 script_file.write(f"tmux new-window -t {tmux_name}\n")
+             script_file.write(f"tmux send-keys -t {window_name} '{cmd}' C-m\n")
+     print("Run /tmp/start_multirun_tmux.sh")
+
+
+ def main():
+     parser = argparse.ArgumentParser(description="Process fold arguments")
+     parser.add_argument(
+         "--total_fold", "-t", default=16, type=int, help="total number of folds"
+     )
+     parser.add_argument("--gpus", type=str, default="0,1,2,3,4,5,6,7")
+     parser.add_argument("--ignore_gpus", "-ig", type=str, default="")
+     parser.add_argument(
+         "--total_cpu",
+         type=int,
+         default=multiprocessing.cpu_count(),
+         help="total number of cpu cores available",
+     )
+     # gather the remaining unparsed arguments as the target command
+     parser.add_argument("cmd", nargs=argparse.REMAINDER)
+
+     args = parser.parse_args()
+     if not args.cmd or (args.cmd[0] == "--" and len(args.cmd) == 1):
+         parser.error("Invalid command provided")
+
+     # strip a leading "--" separator before inspecting the target script
+     cmd_parts = args.cmd[1:] if args.cmd[0] == "--" else args.cmd
+     assert_script(cmd_parts[0])
+     cmd_str = shlex.join(cmd_parts)
+
+     gpus = args.gpus.split(",")
+     gpus = [gpu for gpu in gpus if gpu not in args.ignore_gpus.split(",")]
+     num_gpus = len(gpus)
+
+     cpu_per_process = max(args.total_cpu // args.total_fold, 1)
+     cmds = []
+     for i in range(args.total_fold):
+         gpu = gpus[i % num_gpus]
+         cpu_start = (i * cpu_per_process) % args.total_cpu
+         cpu_end = ((i + 1) * cpu_per_process - 1) % args.total_cpu
+         ENV = f"CUDA_VISIBLE_DEVICES={gpu} MP_ID={i} MP_TOTAL={args.total_fold}"
+         if taskset_path:
+             fold_cmd = (
+                 f"{ENV} {taskset_path} -c {cpu_start}-{cpu_end} python {cmd_str}"
+             )
+         else:
+             fold_cmd = f"{ENV} python {cmd_str}"
+
+         cmds.append(fold_cmd)
+
+     run_in_tmux(cmds, "mpython", args.total_fold)
+     os.chmod("/tmp/start_multirun_tmux.sh", 0o755)  # make the script executable
+     os.system("/tmp/start_multirun_tmux.sh")
+
+
+ if __name__ == "__main__":
+     main()
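
Finally, a minimal sketch of a target script that satisfies `assert_script`: it reads the MP_ID/MP_TOTAL environment variables injected by the launcher above and keeps only its own shard of the inputs. The file name `worker.py`, the `inputs` contents, and the `mpython` command name (suggested by the tmux session name, but an assumption) are illustrative:

    # worker.py – launched, e.g., as: mpython -t 16 worker.py
    import os

    MP_TOTAL = int(os.environ.get("MP_TOTAL", 1))
    MP_ID = int(os.environ.get("MP_ID", 0))

    inputs = list(range(1000))        # whatever this job iterates over
    inputs = inputs[MP_ID::MP_TOTAL]  # keep every MP_TOTAL-th item

    for item in inputs:
        ...  # process this shard's items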