speedy-utils 1.0.4__py3-none-any.whl → 1.0.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llm_utils/__init__.py +29 -0
- llm_utils/chat_format.py +427 -0
- llm_utils/group_messages.py +120 -0
- llm_utils/lm/__init__.py +8 -0
- llm_utils/lm/base_lm.py +304 -0
- llm_utils/lm/utils.py +130 -0
- llm_utils/scripts/vllm_load_balancer.py +353 -0
- llm_utils/scripts/vllm_serve.py +416 -0
- speedy_utils/__init__.py +85 -0
- speedy_utils/all.py +159 -0
- {speedy → speedy_utils}/common/__init__.py +0 -0
- speedy_utils/common/clock.py +215 -0
- speedy_utils/common/function_decorator.py +66 -0
- speedy_utils/common/logger.py +207 -0
- speedy_utils/common/report_manager.py +112 -0
- speedy_utils/common/utils_cache.py +264 -0
- {speedy → speedy_utils}/common/utils_io.py +66 -19
- {speedy → speedy_utils}/common/utils_misc.py +25 -11
- speedy_utils/common/utils_print.py +216 -0
- speedy_utils/multi_worker/__init__.py +0 -0
- speedy_utils/multi_worker/process.py +198 -0
- speedy_utils/multi_worker/thread.py +327 -0
- speedy_utils/scripts/mpython.py +108 -0
- speedy_utils-1.0.5.dist-info/METADATA +279 -0
- speedy_utils-1.0.5.dist-info/RECORD +27 -0
- {speedy_utils-1.0.4.dist-info → speedy_utils-1.0.5.dist-info}/WHEEL +1 -2
- speedy_utils-1.0.5.dist-info/entry_points.txt +3 -0
- speedy/__init__.py +0 -53
- speedy/common/clock.py +0 -68
- speedy/common/utils_cache.py +0 -170
- speedy/common/utils_print.py +0 -138
- speedy/multi_worker.py +0 -121
- speedy_utils-1.0.4.dist-info/METADATA +0 -22
- speedy_utils-1.0.4.dist-info/RECORD +0 -12
- speedy_utils-1.0.4.dist-info/top_level.txt +0 -1
|
@@ -0,0 +1,198 @@
|
|
|
1
|
+
import inspect
|
|
2
|
+
import multiprocessing
|
|
3
|
+
import os
|
|
4
|
+
import time
|
|
5
|
+
import traceback
|
|
6
|
+
from collections.abc import Callable, Iterable, Iterator, Sequence
|
|
7
|
+
from concurrent.futures import ProcessPoolExecutor, as_completed
|
|
8
|
+
from itertools import islice
|
|
9
|
+
from typing import Any, List, TypeVar, cast
|
|
10
|
+
|
|
11
|
+
T = TypeVar("T")

# Switch the process-wide multiprocessing start method to "spawn" as soon as
# this module is imported; ``force=True`` overrides any method chosen earlier.
if hasattr(multiprocessing, "set_start_method"):
    try:
        multiprocessing.set_start_method("spawn", force=True)
    except RuntimeError:
        pass

# tqdm is optional: when it is missing ``tqdm`` is None and no bar is drawn.
try:
    from tqdm import tqdm
except ImportError:  # pragma: no cover
    tqdm = None  # type: ignore[assignment]
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
# ──── internal helpers ────────────────────────────────────────────────────
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def _group_iter(src: Iterable[Any], size: int) -> Iterable[list[Any]]:
    """Lazily split *src* into consecutive lists of *size* items.

    The final list may be shorter; nothing is yielded for an empty *src*.
    """
    source = iter(src)
    # iter(callable, sentinel) keeps calling until a chunk equals [] (exhausted).
    yield from iter(lambda: list(islice(source, size)), [])
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def _short_tb() -> str:
    """Return the active exception's traceback with every line mentioning
    ``multi_process`` removed (lines are re-joined with ``\\n``)."""
    kept = [
        line
        for line in traceback.format_exc().splitlines()
        if "multi_process" not in line
    ]
    return "\n".join(kept)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def _safe_call(func: Callable, obj, fixed):
    """Return ``func(obj, **fixed)``.

    Any ``Exception`` is re-raised as ``RuntimeError``. The message carries the
    callable's ``__name__`` (or ``str(func)``), ``repr(obj)``, the original
    error text and the filtered traceback from ``_short_tb``. The original
    exception is chained as ``__cause__``.
    """
    try:
        outcome = func(obj, **fixed)
    except Exception as err:
        label = getattr(func, "__name__", str(func))
        details = f"{label}({obj!r}) failed: {err}\n{_short_tb()}"
        raise RuntimeError(details) from err
    return outcome
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def _worker_process(
    func: Callable,
    item_batch: Any,
    fixed_kwargs: dict,
    batch_size: int,
    stop_on_error: bool = False,
):
    """Worker function executed in each process.

    Parameters
    ----------
    func – callable applied as ``func(item, **fixed_kwargs)``.
    item_batch – a single input when ``batch_size <= 1``, otherwise a list of
        inputs (one chunk produced by ``_group_iter``).
    fixed_kwargs – keyword arguments forwarded to every call.
    batch_size – the ``batch`` value chosen by the caller.
    stop_on_error – batched calls only. If true, the first failing item
        re-raises and fails the whole chunk. If false (the default, and the
        historical behaviour), that item's slot becomes ``None``.

    Returns the single result, or a list with one entry per chunk item.
    Failures are wrapped in ``RuntimeError`` by ``_safe_call``.
    """
    if batch_size <= 1:
        return _safe_call(func, item_batch, fixed_kwargs)

    results = []
    for itm in item_batch:
        try:
            results.append(_safe_call(func, itm, fixed_kwargs))
        except Exception:
            if stop_on_error:
                raise
            results.append(None)
    return results


# ──── public API ──────────────────────────────────────────────────────────
def multi_process(
    func: Callable[[Any], Any],
    inputs: Iterable[Any],
    *,
    workers: int | None = None,
    batch: int = 1,
    ordered: bool = True,
    progress: bool = False,
    inflight: int | None = None,
    timeout: float | None = None,
    stop_on_error: bool = True,
    process_update_interval=10,
    **fixed_kwargs,
) -> list[Any]:
    """
    Simple multi‑processing parallel map that returns a *list*.

    Parameters
    ----------
    func – target callable executed in separate processes, must be of the
        form f(obj, ...). It is sent to worker processes, so it must be
        picklable (e.g. a module-level function).
    inputs – iterable with the objects.
    workers – process pool size (defaults to :pyfunc:`os.cpu_count()`).
    batch – package *batch* inputs into one call to reduce IPC overhead.
    ordered – keep original order, also for inputs without ``len()``; if
        ``False`` results are collected in completion order.
    progress – show a tqdm bar (requires *tqdm* and inputs with ``len()``).
    inflight – max logical items concurrently submitted, though at least one
        task is always in flight. Default: ``workers × 4 × batch``, i.e. about
        four tasks per worker.
    timeout – overall timeout for the mapping (seconds). When it is exceeded,
        ``concurrent.futures.TimeoutError`` is raised.
    stop_on_error – raise on the first exception (default) and cancel queued
        tasks; otherwise substitute each failing item's result with ``None``.
    process_update_interval – redraw the progress bar every N logical items.
    **fixed_kwargs – static keyword args forwarded to every ``func()`` call.
    """
    from concurrent.futures import FIRST_COMPLETED, wait
    from concurrent.futures import TimeoutError as FuturesTimeoutError

    if workers is None:
        workers = os.cpu_count() or 1
    if batch < 1:
        raise ValueError("batch must be ≥ 1")
    if inflight is None:
        # Budget is in logical items; scale by batch so ~4 tasks per worker
        # are queued instead of starving workers when batch is large.
        inflight = workers * 4 * batch
    # A budget below 1 would submit nothing and silently return all-None.
    inflight = max(inflight, 1)

    try:
        n_inputs = len(inputs)  # type: ignore[arg-type]
    except Exception:
        n_inputs = None

    src_iter: Iterator[Any] = iter(inputs)
    if batch > 1:
        src_iter = cast(Iterator[Any], _group_iter(src_iter, batch))

    # ``timeout`` bounds the whole map, so track one absolute deadline.
    deadline = None if timeout is None else time.monotonic() + timeout

    bar = None
    last_bar = 0
    if progress and tqdm is not None and n_inputs is not None:
        bar = tqdm(
            total=n_inputs,
            ncols=80,
            colour="green",
            bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt}"
            " [{elapsed}<{remaining}, {rate_fmt}{postfix}]",
        )

    # Sized + ordered: write into a pre-allocated list by logical index.
    # Unsized + ordered: keep chunks keyed by start index and stitch at the end.
    preallocated = ordered and n_inputs is not None
    results: list[Any] = [None] * n_inputs if preallocated else []
    chunks: dict[int, list[Any]] = {}
    pending: dict[Any, tuple[int, int]] = {}  # future -> (start index, n items)
    completed = 0
    next_idx = 0

    def top_up(pool: ProcessPoolExecutor) -> None:
        """Submit tasks until ``inflight`` logical items are outstanding or input ends."""
        nonlocal next_idx
        while next_idx - completed < inflight:
            try:
                arg = next(src_iter)
            except StopIteration:
                return
            fut = pool.submit(
                _worker_process, func, arg, fixed_kwargs, batch, stop_on_error
            )
            n_items = len(arg) if batch > 1 else 1
            pending[fut] = (next_idx, n_items)
            next_idx += n_items

    try:
        with ProcessPoolExecutor(max_workers=workers) as pool:
            try:
                top_up(pool)
                while pending:
                    remaining = (
                        None
                        if deadline is None
                        else max(0.0, deadline - time.monotonic())
                    )
                    done, _ = wait(
                        pending, timeout=remaining, return_when=FIRST_COMPLETED
                    )
                    if not done:
                        raise FuturesTimeoutError(
                            f"{len(pending)} task(s) unfinished after {timeout} s"
                        )
                    for fut in done:
                        start, n_items = pending.pop(fut)
                        try:
                            res = fut.result()
                        except Exception:
                            if stop_on_error:
                                raise
                            # One None per item of *this* chunk; the last
                            # chunk may be shorter than ``batch``.
                            res = [None] * n_items if batch > 1 else None
                        out_items = list(res or []) if batch > 1 else [res]

                        if preallocated:
                            for offset, item in enumerate(out_items):
                                if start + offset < len(results):
                                    results[start + offset] = item
                        elif ordered:
                            chunks[start] = out_items
                        else:
                            results.extend(out_items)

                        completed += len(out_items)
                        if (
                            bar is not None
                            and completed - last_bar >= process_update_interval
                        ):
                            bar.update(completed - last_bar)
                            last_bar = completed
                    top_up(pool)
            except BaseException:
                # Cancel queued work so leaving the ``with`` block does not
                # wait for tasks whose results will be discarded.
                for fut in pending:
                    fut.cancel()
                raise
    finally:
        if bar is not None:
            bar.update(completed - last_bar)
            bar.close()

    if ordered and not preallocated:
        results = [item for start in sorted(chunks) for item in chunks[start]]
    return results
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
# Names exported by ``from ... import *``.
__all__ = ["multi_process"]
|
|
@@ -0,0 +1,327 @@
|
|
|
1
|
+
"""Provides thread-based parallel execution utilities."""
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
import time
|
|
5
|
+
import traceback
|
|
6
|
+
from collections.abc import Callable, Iterable
|
|
7
|
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
8
|
+
from itertools import islice
|
|
9
|
+
from typing import Any, TypeVar
|
|
10
|
+
|
|
11
|
+
from loguru import logger
|
|
12
|
+
|
|
13
|
+
# tqdm is optional: when it is missing ``tqdm`` is None and no bar is drawn.
try:
    from tqdm import tqdm
except ImportError:  # pragma: no cover
    tqdm = None  # type: ignore[assignment]

# Sensible defaults: twice the CPU count, assuming 4 CPUs when it is unknown.
DEFAULT_WORKERS = 2 * (os.cpu_count() or 4)

# Generic placeholders for item (T) and result (R) types.
T = TypeVar("T")
R = TypeVar("R")
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def _group_iter(src: Iterable[T], size: int) -> Iterable[list[T]]:
    """Lazily yield lists of *size* consecutive items from *src*.

    The last list may be shorter; an empty *src* yields nothing.
    """
    iterator = iter(src)
    while True:
        chunk = list(islice(iterator, size))
        if not chunk:
            return
        yield chunk
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def _short_tb() -> str:
    """Return the active exception's traceback as text, dropping every line
    that mentions ``multi_thread.py`` (lines re-joined with ``\\n``)."""
    lines = traceback.format_exc().splitlines()
    return "\n".join(filter(lambda line: "multi_thread.py" not in line, lines))
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def _worker(item: T, func: Callable[[T], R], fixed_kwargs: dict[str, Any]) -> R:
    """Call ``func(item, **fixed_kwargs)`` and hand back its result unchanged."""
    result = func(item, **fixed_kwargs)
    return result
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def _run_batch(
    items: list[Any],
    func: Callable,
    fixed_kwargs: dict[str, Any],
    stop_on_error: bool,
) -> list[Any]:
    """Apply *func* to every element of one batch, returning results in order.

    With ``stop_on_error`` the first failure propagates and fails the whole
    batch. Otherwise a failing element yields ``None`` in its own slot, so the
    returned list always has ``len(items)`` entries.
    """
    out: list[Any] = []
    for item in items:
        try:
            out.append(_worker(item, func, fixed_kwargs))
        except Exception:
            if stop_on_error:
                raise
            out.append(None)
    return out


def _multi_thread_in_processes(
    func: Callable,
    inputs: Iterable[Any],
    n_proc: int,
    options: dict[str, Any],
    fixed_kwargs: dict[str, Any],
) -> list[Any]:
    """Run ``multi_thread`` on at most *n_proc* contiguous shards, one process each.

    *options* holds ``multi_thread``'s keyword options (workers, batch, ...).
    Only the first shard may show a progress bar. Every child dumps its results
    to a temporary ``.pkl`` file; the parent joins the children, reads the
    files back in shard order and always deletes them afterwards.
    """
    import contextlib
    import tempfile

    from fastcore.all import threaded

    from speedy_utils import load_by_ext

    items = list(inputs)
    # Ceiling division, so no more than ``n_proc`` shards (processes) are created.
    shard_size = max(-(-len(items) // n_proc), 1)
    in_process_multi_thread = threaded(process=True)(multi_thread)

    tmp_paths: list[str] = []
    procs = []
    try:
        for proc_id, start in enumerate(range(0, len(items), shard_size)):
            with tempfile.NamedTemporaryFile(
                delete=False, suffix="multi_thread.pkl"
            ) as tmp_file:
                file_pkl = tmp_file.name
            tmp_paths.append(file_pkl)
            shard_options = dict(
                options, progress=options["progress"] and proc_id == 0
            )
            proc = in_process_multi_thread(
                func,
                items[start : start + shard_size],
                **shard_options,
                n_proc=0,  # prevent recursion
                store_output_pkl_file=file_pkl,
                **fixed_kwargs,
            )
            procs.append((proc, file_pkl))

        # join
        results: list[Any] = []
        for proc, file_pkl in procs:
            proc.join()
            logger.info(f"Done proc {proc=}")
            results.extend(load_by_ext(file_pkl))
        return results
    finally:
        # The files were created with delete=False, so remove them here.
        for path in tmp_paths:
            with contextlib.suppress(OSError):
                os.remove(path)


# ────────────────────────────────────────────────────────────
# main API
# ────────────────────────────────────────────────────────────
def multi_thread(
    func: Callable,
    inputs: Iterable[Any],
    *,
    workers: int | None = DEFAULT_WORKERS,
    batch: int = 1,
    ordered: bool = True,
    progress: bool = True,
    progress_update: int = 10,
    prefetch_factor: int = 4,
    timeout: float | None = None,
    stop_on_error: bool = True,
    n_proc=0,
    store_output_pkl_file: str | None = None,
    **fixed_kwargs,
) -> list[Any]:
    """
    ThreadPool **map** that returns a *list*.

    Parameters
    ----------
    func – target callable, invoked as ``func(item, **fixed_kwargs)``.
    inputs – iterable with the arguments; a pandas ``DataFrame`` is mapped
        row‑wise as ``dict`` records.
    workers – defaults to ``os.cpu_count()*2`` (``None`` means the same).
    batch – package *batch* inputs into one call for low‑overhead. It is
        switched to 32 automatically when there are more than 20 000 inputs
        per worker.
    ordered – keep original order (costs memory), also for inputs without
        ``len()``; if ``False`` results are collected as they finish.
    progress – show a tqdm bar (requires *tqdm* installed and sized inputs).
    progress_update – bar redraw frequency (logical items, *not* batches).
    prefetch_factor – in‑flight tasks ≈ ``workers * prefetch_factor``.
    timeout – overall timeout (seconds) for the mapping. When it is exceeded,
        ``concurrent.futures.TimeoutError`` is raised.
    stop_on_error – raise immediately on first exception (default); queued
        tasks are cancelled. If ``False`` each failing item's result becomes
        ``None``.
    n_proc – if > 1, split the inputs into at most ``n_proc`` contiguous
        shards and run ``multi_thread`` on each in its own process.
    store_output_pkl_file – if given, also dump the result list to this path.
    **fixed_kwargs – static keyword args forwarded to every ``func()`` call.
    """
    from concurrent.futures import FIRST_COMPLETED, wait
    from concurrent.futures import TimeoutError as FuturesTimeoutError

    # Convert DataFrames before any splitting: list(df) would yield column names.
    try:
        import pandas as pd

        if isinstance(inputs, pd.DataFrame):
            inputs = inputs.to_dict(orient="records")
    except ImportError:
        pass

    workers_val = workers if workers is not None else DEFAULT_WORKERS

    if n_proc > 1:
        options = {
            "workers": workers,
            "batch": batch,
            "ordered": ordered,
            "progress": progress,
            "progress_update": progress_update,
            "prefetch_factor": prefetch_factor,
            "timeout": timeout,
            "stop_on_error": stop_on_error,
        }
        merged = _multi_thread_in_processes(
            func, inputs, n_proc, options, fixed_kwargs
        )
        if store_output_pkl_file:
            from speedy_utils import dump_json_or_pickle

            dump_json_or_pickle(merged, store_output_pkl_file)
        return merged

    try:
        n_inputs = len(inputs)  # type: ignore[arg-type]
    except Exception:
        n_inputs = None

    if batch == 1 and n_inputs and n_inputs / max(workers_val, 1) > 20_000:
        batch = 32  # empirically good for sub‑ms tasks

    # ── build (maybe‑batched) source iterator ────────────────────────────
    src_iter: Iterable[Any] = iter(inputs)
    if batch > 1:
        src_iter = iter(_group_iter(src_iter, batch))

    # ── progress bar (needs a known total of logical items) ──────────────
    bar = None
    last_bar_update = 0
    if progress and tqdm is not None and n_inputs is not None:
        bar = tqdm(
            total=n_inputs,
            ncols=128,
            colour="green",
            bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt}"
            " [{elapsed}<{remaining}, {rate_fmt}{postfix}]",
        )

    # ── prepare result containers ────────────────────────────────────────
    # Sized + ordered: write into a pre-allocated list by logical index.
    # Unsized + ordered: keep chunks keyed by start index and stitch at the end.
    preallocated = ordered and n_inputs is not None
    results: list[Any] = [None] * n_inputs if preallocated else []
    chunks: dict[int, list[Any]] = {}

    # ── main execution loop ──────────────────────────────────────────────
    max_inflight = workers_val * max(prefetch_factor, 1)  # counted in tasks
    deadline = None if timeout is None else time.monotonic() + timeout
    inflight: dict[Any, tuple[int, int]] = {}  # future -> (start index, n items)
    completed_items = 0
    next_logical_idx = 0  # index assigned to the next submission

    def submit_more(pool: ThreadPoolExecutor) -> None:
        """Keep up to ``max_inflight`` tasks queued until the input runs out."""
        nonlocal next_logical_idx
        while len(inflight) < max_inflight:
            try:
                arg = next(src_iter)
            except StopIteration:
                return
            if batch > 1:
                fut = pool.submit(_run_batch, arg, func, fixed_kwargs, stop_on_error)
                n_items = len(arg)
            else:
                fut = pool.submit(_worker, arg, func, fixed_kwargs)
                n_items = 1
            inflight[fut] = (next_logical_idx, n_items)
            next_logical_idx += n_items

    with ThreadPoolExecutor(max_workers=workers_val) as pool:
        try:
            submit_more(pool)
            while inflight:
                remaining = (
                    None
                    if deadline is None
                    else max(0.0, deadline - time.monotonic())
                )
                done, _ = wait(
                    inflight, timeout=remaining, return_when=FIRST_COMPLETED
                )
                if not done:
                    raise FuturesTimeoutError(
                        f"{len(inflight)} task(s) unfinished after {timeout} s"
                    )
                for fut in done:
                    idx, n_items = inflight.pop(fut)
                    try:
                        res = fut.result()
                    except Exception:
                        if stop_on_error:
                            raise
                        # One placeholder per logical item keeps indices aligned.
                        res = [None] * n_items if batch > 1 else None

                    # flatten res to list of logical outputs
                    out_items = list(res or []) if batch > 1 else [res]

                    # store outputs
                    if preallocated:
                        results[idx : idx + len(out_items)] = out_items
                    elif ordered:
                        chunks[idx] = out_items
                    else:
                        results.extend(out_items)
                    completed_items += len(out_items)

                    # progress bar update
                    if (
                        bar is not None
                        and completed_items - last_bar_update >= progress_update
                    ):
                        bar.update(completed_items - last_bar_update)
                        last_bar_update = completed_items
                        # Show pending and processing in the bar postfix
                        bar.set_postfix(
                            {
                                "pending": n_inputs - next_logical_idx,
                                "processing": min(len(inflight), workers_val),
                            }
                        )
                # keep queue full
                submit_more(pool)
        except BaseException:
            # Cancel queued work so leaving the ``with`` block does not wait
            # for tasks whose results will be discarded.
            for fut in inflight:
                fut.cancel()
            raise
        finally:
            if bar is not None:
                bar.update(completed_items - last_bar_update)
                bar.close()

    if ordered and not preallocated:
        results = [item for start in sorted(chunks) for item in chunks[start]]

    if store_output_pkl_file:
        from speedy_utils import dump_json_or_pickle

        dump_json_or_pickle(results, store_output_pkl_file)
    return results
|
|
297
|
+
|
|
298
|
+
|
|
299
|
+
def multi_thread_standard(
    fn: Callable[[Any], Any], items: Iterable[Any], workers: int = 4
) -> list[Any]:
    """Map *fn* over *items* with a plain ``ThreadPoolExecutor``.

    Every item is submitted up front. Results are then gathered in submission
    order, so the output lines up with *items*. The first exception met while
    gathering (in input order) propagates once the pool has shut down.

    Parameters
    ----------
    fn : Callable
        Called once per item as ``fn(item)``.
    items : Iterable
        The items to process.
    workers : int, optional
        Maximum number of worker threads, by default 4.

    Returns
    -------
    list
        One result per item, in input order.
    """
    with ThreadPoolExecutor(max_workers=workers) as executor:
        submitted = [executor.submit(fn, element) for element in items]
        collected = []
        for future in submitted:
            collected.append(future.result())
    return collected
|
|
325
|
+
|
|
326
|
+
|
|
327
|
+
# Names exported by ``from ... import *``.
__all__ = ["multi_thread", "multi_thread_standard"]
|
|
@@ -0,0 +1,108 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
import argparse
|
|
3
|
+
import itertools
|
|
4
|
+
import multiprocessing # Import multiprocessing module
|
|
5
|
+
import os
|
|
6
|
+
import shlex # To properly escape command line arguments
|
|
7
|
+
import shutil
|
|
8
|
+
import subprocess
|
|
9
|
+
|
|
10
|
+
# Absolute path of the ``taskset`` executable, or None when it is not on PATH.
taskset_path = shutil.which("taskset")
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def assert_script(python_path):
    """Ensure the script at *python_path* mentions both ``MP_ID`` and ``MP_TOTAL``.

    Each fold of the script is launched with ``MP_ID``/``MP_TOTAL`` set in its
    environment. The script is expected to use them to pick its share of the
    inputs, as in the example snippet below. The check is a plain substring
    search of the source text.

    Raises
    ------
    Exception
        If either name is missing; the message includes an example snippet.
    OSError
        If the file cannot be opened.
    """
    # Only ASCII names are searched for, so undecodable bytes are replaced
    # instead of aborting. ``with`` guarantees the handle is closed.
    with open(python_path, encoding="utf-8", errors="replace") as source:
        code_str = source.read()
    if "MP_ID" in code_str and "MP_TOTAL" in code_str:
        return

    example_code = (
        'import os; MP_TOTAL = int(os.environ.get("MP_TOTAL"));MP_ID = int(os.environ.get("MP_ID"))\n'
        "inputs = list(inputs[MP_ID::MP_TOTAL])"
    )
    # ANSI escape codes for coloring
    YELLOW = "\033[93m"
    RESET = "\033[0m"
    raise_msg = (
        f"MP_ID and MP_TOTAL not found in {python_path}, please add them.\n\n"
        f"Example:\n{YELLOW}{example_code}{RESET}"
    )
    raise Exception(raise_msg)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def run_in_tmux(
    commands_to_run, tmux_name, num_windows, script_path="/tmp/start_multirun_tmux.sh"
):
    """Write a bash script that runs *commands_to_run* in a fresh tmux session.

    The script kills any existing session called *tmux_name*, creates a new
    detached one and opens *num_windows* windows, typing one command into
    each. Commands are reused round-robin when there are fewer commands than
    windows. The script is only written, not executed.

    Parameters
    ----------
    commands_to_run : list of str
        Shell command lines, one per window.
    tmux_name : str
        Name of the tmux session.
    num_windows : int
        Number of windows to create (values below 1 create only the session).
    script_path : str, optional
        Where to write the script, by default ``/tmp/start_multirun_tmux.sh``.
    """
    session = shlex.quote(tmux_name)
    # "session:" addresses the session's current window, i.e. the one created
    # most recently. Unlike "session:<i>" this works whatever tmux's
    # base-index is.
    target = shlex.quote(f"{tmux_name}:")
    lines = [
        "#!/bin/bash\n\n",
        # first cmd is to kill the session if it exists
        f"tmux kill-session -t {session} 2>/dev/null\nsleep .1\n",
        f"tmux new-session -d -s {session}\n",
    ]
    windows = itertools.islice(itertools.cycle(commands_to_run), max(num_windows, 0))
    for i, cmd in enumerate(windows):
        if i > 0:
            lines.append(f"tmux new-window -t {session}\n")
        # shlex.quote keeps quotes inside *cmd* intact (shlex.join emits them).
        lines.append(f"tmux send-keys -t {target} {shlex.quote(cmd)} C-m\n")

    with open(script_path, "w") as script_file:
        script_file.writelines(lines)
    print(f"Run {script_path}")
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def main():
    """Launch ``--total_fold`` copies of a Python script in a tmux session.

    Each fold's command sets ``MP_ID``/``MP_TOTAL`` and a
    ``CUDA_VISIBLE_DEVICES``. The device is chosen round-robin from ``--gpus``
    minus ``--ignore_gpus``. When ``taskset`` is available, the command is also
    pinned to its own slice of ``--total_cpu`` cores. The commands are written
    to ``/tmp/start_multirun_tmux.sh``, which is then made executable and run.
    """
    parser = argparse.ArgumentParser(description="Process fold arguments")
    parser.add_argument(
        "--total_fold", "-t", default=16, type=int, help="total number of folds"
    )
    parser.add_argument("--gpus", type=str, default="0,1,2,3,4,5,6,7")
    parser.add_argument("--ignore_gpus", "-ig", type=str, default="")
    parser.add_argument(
        "--total_cpu",
        type=int,
        default=multiprocessing.cpu_count(),
        help="total number of cpu cores available",
    )
    # Gathers the remaining unparsed arguments: the script plus its own args.
    parser.add_argument("cmd", nargs=argparse.REMAINDER)

    args = parser.parse_args()
    # argparse keeps a leading "--" in a REMAINDER list; it only separates
    # mpython's options from the command, so drop it before using the command.
    cmd = args.cmd[1:] if args.cmd[:1] == ["--"] else args.cmd
    if not cmd:
        parser.error("Invalid command provided")
    if args.total_fold < 1:
        parser.error("--total_fold must be at least 1")
    if args.total_cpu < 1:
        parser.error("--total_cpu must be at least 1")
    assert_script(cmd[0])
    cmd_str = shlex.join(cmd)

    ignored = set(args.ignore_gpus.split(","))
    gpus = [gpu for gpu in args.gpus.split(",") if gpu not in ignored]
    if not gpus:
        parser.error("no GPUs left after applying --ignore_gpus")
    num_gpus = len(gpus)

    cpu_per_process = max(args.total_cpu // args.total_fold, 1)
    cmds = []
    for i in range(args.total_fold):
        gpu = gpus[i % num_gpus]
        cpu_start = (i * cpu_per_process) % args.total_cpu
        cpu_end = ((i + 1) * cpu_per_process - 1) % args.total_cpu
        env = f"CUDA_VISIBLE_DEVICES={gpu} MP_ID={i} MP_TOTAL={args.total_fold}"
        if taskset_path:
            fold_cmd = (
                f"{env} {taskset_path} -c {cpu_start}-{cpu_end} python {cmd_str}"
            )
        else:
            fold_cmd = f"{env} python {cmd_str}"
        cmds.append(fold_cmd)

    run_in_tmux(cmds, "mpython", args.total_fold)
    os.chmod("/tmp/start_multirun_tmux.sh", 0o755)  # Make the script executable
    subprocess.run(["/tmp/start_multirun_tmux.sh"], check=False)


if __name__ == "__main__":
    main()
|