speedy-utils 1.1.40__py3-none-any.whl → 1.1.42__py3-none-any.whl
This diff compares publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
- llm_utils/__init__.py +2 -0
- llm_utils/llm_ray.py +370 -0
- llm_utils/lm/llm.py +36 -29
- speedy_utils/__init__.py +3 -0
- speedy_utils/common/utils_io.py +3 -1
- speedy_utils/multi_worker/__init__.py +12 -0
- speedy_utils/multi_worker/dataset_ray.py +303 -0
- speedy_utils/multi_worker/parallel_gpu_pool.py +178 -0
- speedy_utils/multi_worker/process.py +375 -75
- speedy_utils/multi_worker/progress.py +140 -0
- speedy_utils/scripts/mpython.py +49 -4
- {speedy_utils-1.1.40.dist-info → speedy_utils-1.1.42.dist-info}/METADATA +3 -2
- {speedy_utils-1.1.40.dist-info → speedy_utils-1.1.42.dist-info}/RECORD +15 -11
- {speedy_utils-1.1.40.dist-info → speedy_utils-1.1.42.dist-info}/WHEEL +0 -0
- {speedy_utils-1.1.40.dist-info → speedy_utils-1.1.42.dist-info}/entry_points.txt +0 -0
```diff
--- a/speedy_utils/multi_worker/process.py
+++ b/speedy_utils/multi_worker/process.py
@@ -1,16 +1,23 @@
+import warnings
+import os
+# Suppress Ray FutureWarnings before any imports
+warnings.filterwarnings("ignore", category=FutureWarning, module="ray.*")
+warnings.filterwarnings("ignore", message=".*pynvml.*deprecated.*", category=FutureWarning)
+
+# Set environment variables before Ray is imported anywhere
+os.environ["RAY_ACCEL_ENV_VAR_OVERRI" \
+    "DE_ON_ZERO"] = "0"
+os.environ["RAY_DEDUP_LOGS"] = "1"
+os.environ["RAY_LOG_TO_STDERR"] = "0"
+
 from ..__imports import *
+import tempfile
+from .progress import create_progress_tracker, ProgressPoller, get_ray_progress_actor
 
 
 SPEEDY_RUNNING_PROCESSES: list[psutil.Process] = []
 _SPEEDY_PROCESSES_LOCK = threading.Lock()
 
-
-# /mnt/data/anhvth8/venvs/Megatron-Bridge-Host/lib/python3.12/site-packages/ray/_private/worker.py:2046: FutureWarning: Tip: In future versions of Ray, Ray will no longer override accelerator visible devices env var if num_gpus=0 or num_gpus=None (default). To enable this behavior and turn off this error message, set RAY_ACCEL_ENV_VAR_OVERRIDE_ON_ZERO=0
-# turn off future warning and verbose task logs
-os.environ["RAY_ACCEL_ENV_VAR_OVERRIDE_ON_ZERO"] = "0"
-os.environ["RAY_DEDUP_LOGS"] = "0"
-os.environ["RAY_LOG_TO_STDERR"] = "0"
-
 def _prune_dead_processes() -> None:
     """Remove dead processes from tracking list."""
     with _SPEEDY_PROCESSES_LOCK:
```
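Two details worth noting in this hunk. First, the filters and environment variables now sit above `from ..__imports import *` so they take effect before Ray is imported anywhere, as the added comment says. Second, the oddly split key is still correct: Python concatenates adjacent string literals at compile time, so the two fragments form the full variable name. A quick check:

```python
# Adjacent string literals are concatenated at compile time, so the
# split key in the hunk above resolves to the full variable name.
key = "RAY_ACCEL_ENV_VAR_OVERRI" \
      "DE_ON_ZERO"
assert key == "RAY_ACCEL_ENV_VAR_OVERRIDE_ON_ZERO"
```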
```diff
@@ -121,35 +128,213 @@ def wrap_dump(func: Callable, cache_dir: Path | None, dump_in_thread: bool = True):
     return wrapped
 
 
+_LOG_GATE_CACHE: dict[str, bool] = {}
+
+
+def _should_allow_worker_logs(mode: Literal['all', 'zero', 'first'], gate_path: Path | None) -> bool:
+    """Determine if current worker should emit logs for the given mode."""
+    if mode == 'all':
+        return True
+    if mode == 'zero':
+        return False
+    if mode == 'first':
+        if gate_path is None:
+            return True
+        key = str(gate_path)
+        cached = _LOG_GATE_CACHE.get(key)
+        if cached is not None:
+            return cached
+        gate_path.parent.mkdir(parents=True, exist_ok=True)
+        try:
+            fd = os.open(key, os.O_CREAT | os.O_EXCL | os.O_WRONLY)
+        except FileExistsError:
+            allowed = False
+        else:
+            os.close(fd)
+            allowed = True
+        _LOG_GATE_CACHE[key] = allowed
+        return allowed
+    raise ValueError(f'Unsupported log mode: {mode!r}')
+
+
+def _cleanup_log_gate(gate_path: Path | None):
+    if gate_path is None:
+        return
+    try:
+        gate_path.unlink(missing_ok=True)
+    except OSError:
+        pass
+
+
+@contextlib.contextmanager
+def _patch_fastcore_progress_bar(*, leave: bool = True):
+    """Temporarily force fastcore.progress_bar to keep the bar on screen."""
+    try:
+        import fastcore.parallel as _fp
+    except ImportError:
+        yield False
+        return
+
+    orig = getattr(_fp, 'progress_bar', None)
+    if orig is None:
+        yield False
+        return
+
+    def _wrapped(*args, **kwargs):
+        kwargs.setdefault('leave', leave)
+        return orig(*args, **kwargs)
+
+    _fp.progress_bar = _wrapped
+    try:
+        yield True
+    finally:
+        _fp.progress_bar = orig
+
+
+class _PrefixedWriter:
+    """Stream wrapper that prefixes each line with worker id."""
+
+    def __init__(self, stream, prefix: str):
+        self._stream = stream
+        self._prefix = prefix
+        self._at_line_start = True
+
+    def write(self, s):
+        if not s:
+            return 0
+        total = 0
+        for chunk in s.splitlines(True):
+            if self._at_line_start:
+                self._stream.write(self._prefix)
+                total += len(self._prefix)
+            self._stream.write(chunk)
+            total += len(chunk)
+            self._at_line_start = chunk.endswith('\n')
+        return total
+
+    def flush(self):
+        self._stream.flush()
+
+
+def _call_with_log_control(
+    func: Callable,
+    x: Any,
+    func_kwargs: dict[str, Any],
+    log_mode: Literal['all', 'zero', 'first'],
+    gate_path: Path | None,
+):
+    """Call a function, silencing stdout/stderr based on log mode."""
+    allow_logs = _should_allow_worker_logs(log_mode, gate_path)
+    if allow_logs:
+        prefix = f"[worker-{os.getpid()}] "
+        # Route worker logs to stderr to reduce clobbering tqdm/progress output on stdout
+        out = _PrefixedWriter(sys.stderr, prefix)
+        err = out
+        with contextlib.redirect_stdout(out), contextlib.redirect_stderr(err):
+            return func(x, **func_kwargs)
+    with open(os.devnull, 'w') as devnull, contextlib.redirect_stdout(devnull), contextlib.redirect_stderr(devnull):
+        return func(x, **func_kwargs)
+
+
 # ─── ray management ─────────────────────────────────────────
 
 RAY_WORKER = None
 
 
-def ensure_ray(workers: int, pbar: tqdm | None = None):
-    """
+def ensure_ray(workers: int | None, pbar: tqdm | None = None, ray_metrics_port: int | None = None):
+    """
+    Initialize or reinitialize Ray safely for both local and cluster environments.
+
+    1. Tries to connect to an existing cluster (address='auto') first.
+    2. If no cluster is found, starts a local Ray instance with 'workers' CPUs.
+    """
     import ray as _ray_module
     import logging
 
     global RAY_WORKER
-
-    if
-
-
+    requested_workers = workers
+    if workers is None:
+        workers = os.cpu_count() or 1
+
+    if ray_metrics_port is not None:
+        os.environ['RAY_metrics_export_port'] = str(ray_metrics_port)
+
+    allow_restart = os.environ.get("RESTART_RAY", "0").lower() in ("1", "true")
+    is_cluster_env = "RAY_ADDRESS" in os.environ or os.environ.get("RAY_CLUSTER") == "1"
+
+    # 1. Handle existing session
+    if _ray_module.is_initialized():
+        if not allow_restart:
+            if pbar:
+                pbar.set_postfix_str("Using existing Ray session")
+            return
+
+        # Avoid shutting down shared cluster sessions.
+        if is_cluster_env:
+            if pbar:
+                pbar.set_postfix_str("Cluster active: skipping restart to protect connection")
+            return
+
+        # Local restart: only if worker count changed
+        if workers != RAY_WORKER:
+            if pbar:
+                pbar.set_postfix_str(f'Restarting local Ray with {workers} workers')
             _ray_module.shutdown()
-
+        else:
+            return
+
+    # 2. Initialization logic
+    t0 = time.time()
+
+    # Try to connect to existing cluster FIRST (address="auto")
+    try:
+        if pbar:
+            pbar.set_postfix_str("Searching for Ray cluster...")
+
+        # MUST NOT pass num_cpus/num_gpus here to avoid ValueError on existing clusters
+        _ray_module.init(
+            address="auto",
+            ignore_reinit_error=True,
+            logging_level=logging.ERROR,
+            log_to_driver=False
+        )
+
+        if pbar:
+            resources = _ray_module.cluster_resources()
+            cpus = resources.get("CPU", 0)
+            pbar.set_postfix_str(f"Connected to Ray Cluster ({int(cpus)} CPUs)")
+
+    except Exception:
+        # 3. Fallback: Start a local Ray session
+        if pbar:
+            pbar.set_postfix_str(f"No cluster found. Starting local Ray ({workers} CPUs)...")
+
         _ray_module.init(
             num_cpus=workers,
             ignore_reinit_error=True,
             logging_level=logging.ERROR,
             log_to_driver=False,
         )
-
-    _track_ray_processes()  # Track Ray worker processes
+
     if pbar:
-
-
+        took = time.time() - t0
+        pbar.set_postfix_str(f'ray.init local {workers} took {took:.2f}s')
 
+    _track_ray_processes()
+
+    if requested_workers is None:
+        try:
+            resources = _ray_module.cluster_resources()
+            total_cpus = int(resources.get("CPU", 0))
+            if total_cpus > 0:
+                workers = total_cpus
+        except Exception:
+            pass
+
+    RAY_WORKER = workers
+
+
+# TODO: make smarter backend selection, when shared_kwargs is used, and backend != 'ray', do not raise error but change to ray and warning user about this
 def multi_process(
     func: Callable[[Any], Any],
     items: Iterable[Any] | None = None,
```
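The 'first' log mode added here elects exactly one worker by atomic file creation: `os.open` with `O_CREAT | O_EXCL` succeeds only for the process that creates the gate file, and every other process gets `FileExistsError` and is silenced. A minimal standalone sketch of the same pattern (the path and names are illustrative, not part of the package):

```python
# Sketch of the first-worker election in _should_allow_worker_logs:
# O_CREAT | O_EXCL is atomic, so exactly one process wins the gate.
import os
import tempfile

GATE = os.path.join(tempfile.gettempdir(), "demo_log_gate.gate")  # hypothetical

def claim_gate(path: str) -> bool:
    try:
        fd = os.open(path, os.O_CREAT | os.O_EXCL | os.O_WRONLY)
    except FileExistsError:
        return False  # another process created the file first
    os.close(fd)
    return True

print(claim_gate(GATE))  # True only for the first caller
os.unlink(GATE)          # what _cleanup_log_gate does afterwards
```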
```diff
@@ -163,6 +348,10 @@ def multi_process(
     desc: str | None = None,
     shared_kwargs: list[str] | None = None,
     dump_in_thread: bool = True,
+    ray_metrics_port: int | None = None,
+    log_worker: Literal['zero', 'first', 'all'] = 'first',
+    total_items: int | None = None,
+    poll_interval: float = 0.3,
     **func_kwargs: Any,
 ) -> list[Any]:
     """
```
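The reworked `ensure_ray` above changes the initialization order: an existing session is reused unless `RESTART_RAY=1`, a detected cluster is never shut down, and otherwise `address="auto"` is tried first with a local session as the fallback. A hypothetical driver-side call, assuming the function is importable from the module being diffed:

```python
# Hedged usage sketch of ensure_ray; the import path mirrors the diffed file.
from tqdm import tqdm
from speedy_utils.multi_worker.process import ensure_ray

with tqdm(total=0, desc="ray-init") as pbar:
    # workers=None falls back to os.cpu_count(), then is replaced by the
    # cluster CPU total once connected; ray_metrics_port sets
    # RAY_metrics_export_port before ray.init runs.
    ensure_ray(None, pbar, ray_metrics_port=8090)
```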
```diff
@@ -185,6 +374,23 @@ def multi_process(
     - Whether to dump results to disk in a separate thread (default: True)
     - If False, dumping is done synchronously, which may block but ensures data is saved before returning
 
+    ray_metrics_port:
+    - Optional port for Ray metrics export (Ray backend only)
+    - Set to 0 to disable Ray metrics
+    - If None, uses Ray's default behavior
+
+    log_worker:
+    - Control worker stdout/stderr noise
+    - 'first': only first worker emits logs, others are silenced (default)
+    - 'all': allow worker prints (may overlap tqdm)
+    - 'zero': silence worker stdout/stderr to keep progress bar clean
+
+    total_items:
+    - Optional item-level total for progress tracking (Ray backend only)
+
+    poll_interval:
+    - Poll interval in seconds for progress actor updates (Ray backend only)
+
     If lazy_output=True, every result is saved to .pkl and
     the returned list contains file paths.
     """
```
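Putting the documented parameters together: when `total_items` is set on the Ray backend, the progress actor handle is merged into each worker's kwargs (see the `_task` changes further down), so a function that accepts an optional `progress_actor` can report per-item progress while the driver's bar counts items instead of tasks. A hypothetical call, assuming `multi_process` is exposed at the package top level:

```python
from speedy_utils import multi_process  # assumed public import

def handle_batch(batch, progress_actor=None):
    done = 0
    for _ in batch:
        done += 1  # real per-item work goes here
    if progress_actor is not None:
        progress_actor.update.remote(done)  # fire-and-forget report
    return done

batches = [list(range(100))] * 50
results = multi_process(
    handle_batch,
    batches,
    backend='ray',
    log_worker='first',  # one worker keeps its stdout/stderr
    total_items=5_000,   # bar tracks items, not tasks
    poll_interval=0.3,   # seconds between progress-actor polls
)
```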
```diff
@@ -219,11 +425,13 @@ def multi_process(
             f"Valid parameters: {valid_params}"
         )
 
-
-
-
-
-
+        # Prefer Ray backend when shared kwargs are requested
+        if backend != 'ray':
+            warnings.warn(
+                "shared_kwargs only supported with 'ray' backend, switching backend to 'ray'",
+                UserWarning,
+            )
+            backend = 'ray'
 
     # unify items
     # unify items and coerce to concrete list so we can use len() and
```
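The removed lines are not rendered by the viewer, but the added block makes the new behavior explicit: combining `shared_kwargs` with a non-Ray backend now warns and coerces the backend instead of failing, matching the TODO added above `multi_process`. An illustrative sketch, reusing `multi_process` from the example above (the function and array names are hypothetical):

```python
import numpy as np

def score(i, big_array=None):
    return float(big_array[i % len(big_array)])

big_array = np.arange(1_000_000, dtype=np.float64)

# backend='mp' is coerced to 'ray' here, with:
# UserWarning: shared_kwargs only supported with 'ray' backend, switching backend to 'ray'
results = multi_process(
    score,
    list(range(10)),
    backend='mp',
    shared_kwargs=['big_array'],  # shipped to workers once via ray.put
    big_array=big_array,
)
```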
```diff
@@ -235,35 +443,60 @@ def multi_process(
     if items is None:
         raise ValueError("'items' or 'inputs' must be provided")
 
-    if workers is None:
+    if workers is None and backend != 'ray':
         workers = os.cpu_count() or 1
 
     # build cache dir + wrap func
     cache_dir = _build_cache_dir(func, items) if lazy_output else None
     f_wrapped = wrap_dump(func, cache_dir, dump_in_thread)
 
+    log_gate_path: Path | None = None
+    if log_worker == 'first':
+        log_gate_path = Path(tempfile.gettempdir()) / f'speedy_utils_log_gate_{os.getpid()}_{uuid.uuid4().hex}.gate'
+    elif log_worker not in ('zero', 'all'):
+        raise ValueError(f'Unsupported log_worker: {log_worker!r}')
+
     total = len(items)
     if desc:
         desc = desc.strip() + f'[{backend}]'
     else:
         desc = f'Multi-process [{backend}]'
-
-
-
-
-
+
+    # ---- sequential backend ----
+    if backend == 'seq':
+        results: list[Any] = []
+        with tqdm(total=total, desc=desc, disable=not progress, file=sys.stdout) as pbar:
             for x in items:
-                results.append(
+                results.append(
+                    _call_with_log_control(
+                        f_wrapped,
+                        x,
+                        func_kwargs,
+                        log_worker,
+                        log_gate_path,
+                    )
+                )
                 pbar.update(1)
-
+        _cleanup_log_gate(log_gate_path)
+        return results
 
-
-
-
+    # ---- ray backend ----
+    if backend == 'ray':
+        import ray as _ray_module
 
-
+        results = []
+        gate_path_str = str(log_gate_path) if log_gate_path else None
+        with tqdm(total=total, desc=desc, disable=not progress, file=sys.stdout) as pbar:
+            ensure_ray(workers, pbar, ray_metrics_port)
             shared_refs = {}
             regular_kwargs = {}
+
+            # Create progress actor for item-level tracking if total_items specified
+            progress_actor = None
+            progress_poller = None
+            if total_items is not None:
+                progress_actor = create_progress_tracker(total_items, desc or "Items")
+                shared_refs['progress_actor'] = progress_actor  # Pass actor handle directly (not via put)
 
             if shared_kwargs:
                 for kw in shared_kwargs:
```
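The `Pass actor handle directly (not via put)` comment leans on a Ray detail that also explains the dereference loop added to `_task` below: ObjectRefs nested inside a container argument are not auto-resolved, so each shared value must be fetched with `ray.get`, whereas an actor handle is passed along as-is and called directly. A standalone sketch:

```python
# Sketch: refs nested in a dict need ray.get; actor handles do not.
import numpy as np
import ray

ray.init(ignore_reinit_error=True)
big_ref = ray.put(np.arange(1_000_000))  # ObjectRef in the object store

@ray.remote
def task(shared: dict):
    arr = ray.get(shared["big"])  # nested refs are NOT auto-dereferenced
    # An actor handle in the same dict would be used directly, e.g.
    # shared["progress_actor"].update.remote(len(arr)), with no ray.get.
    return int(arr[-1])

print(ray.get(task.remote({"big": big_ref})))  # 999999
```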
```diff
@@ -283,60 +516,125 @@ def multi_process(
             def _task(x, shared_refs_dict, regular_kwargs_dict):
                 # Dereference shared objects (zero-copy for numpy arrays)
                 import ray as _ray_in_task
-
+                gate = Path(gate_path_str) if gate_path_str else None
+                dereferenced = {}
+                for k, v in shared_refs_dict.items():
+                    if k == 'progress_actor':
+                        # Pass actor handle directly (don't dereference)
+                        dereferenced[k] = v
+                    else:
+                        dereferenced[k] = _ray_in_task.get(v)
                 # Merge with regular kwargs
                 all_kwargs = {**dereferenced, **regular_kwargs_dict}
-                return
+                return _call_with_log_control(
+                    f_wrapped,
+                    x,
+                    all_kwargs,
+                    log_worker,
+                    gate,
+                )
 
             refs = [
                 _task.remote(x, shared_refs, regular_kwargs) for x in items
             ]
 
-            results = []
             t_start = time.time()
+
+            # Start progress poller if using item-level progress
+            if progress_actor is not None:
+                # Update pbar total to show items instead of tasks
+                pbar.total = total_items
+                pbar.refresh()
+                progress_poller = ProgressPoller(progress_actor, pbar, poll_interval)
+                progress_poller.start()
+
             for r in refs:
                 results.append(_ray_module.get(r))
-
+                if progress_actor is None:
+                    # Only update task-level progress if not using item-level
+                    pbar.update(1)
+
+            # Stop progress poller
+            if progress_poller is not None:
+                progress_poller.stop()
+
             t_end = time.time()
-
-
-
-
-
-
-
+            item_desc = f"{total_items:,} items" if total_items else f"{total} tasks"
+            print(f"Ray processing took {t_end - t_start:.2f}s for {item_desc}")
+        _cleanup_log_gate(log_gate_path)
+        return results
+
+    # ---- fastcore backend ----
+    if backend == 'mp':
+        worker_func = functools.partial(
+            _call_with_log_control,
+            f_wrapped,
+            func_kwargs=func_kwargs,
+            log_mode=log_worker,
+            gate_path=log_gate_path,
+        )
+        with _patch_fastcore_progress_bar(leave=True):
+            results = list(parallel(
+                worker_func, items, n_workers=workers, progress=progress, threadpool=False
+            ))
+        _track_multiprocessing_processes()  # Track multiprocessing workers
+        _prune_dead_processes()  # Clean up dead processes
+        _cleanup_log_gate(log_gate_path)
+        return results
+
+    if backend == 'threadpool':
+        worker_func = functools.partial(
+            _call_with_log_control,
+            f_wrapped,
+            func_kwargs=func_kwargs,
+            log_mode=log_worker,
+            gate_path=log_gate_path,
+        )
+        with _patch_fastcore_progress_bar(leave=True):
+            results = list(parallel(
+                worker_func, items, n_workers=workers, progress=progress, threadpool=True
+            ))
+        _cleanup_log_gate(log_gate_path)
+        return results
+
+    if backend == 'safe':
+        # Completely safe backend for tests - no multiprocessing, no external progress bars
+        import concurrent.futures
+
+        # Import thread tracking from thread module
+        try:
+            from .thread import _prune_dead_threads, _track_executor_threads
+
+            worker_func = functools.partial(
+                _call_with_log_control,
+                f_wrapped,
+                func_kwargs=func_kwargs,
+                log_mode=log_worker,
+                gate_path=log_gate_path,
             )
-
-
-
-
-
-
+            with concurrent.futures.ThreadPoolExecutor(
+                max_workers=workers
+            ) as executor:
+                _track_executor_threads(executor)  # Track threads
+                results = list(executor.map(worker_func, items))
+            _prune_dead_threads()  # Clean up dead threads
+        except ImportError:
+            # Fallback if thread module not available
+            worker_func = functools.partial(
+                _call_with_log_control,
+                f_wrapped,
+                func_kwargs=func_kwargs,
+                log_mode=log_worker,
+                gate_path=log_gate_path,
             )
-
-
-
-
-
-
-        try:
-            from .thread import _prune_dead_threads, _track_executor_threads
-
-            with concurrent.futures.ThreadPoolExecutor(
-                max_workers=workers
-            ) as executor:
-                _track_executor_threads(executor)  # Track threads
-                results = list(executor.map(f_wrapped, items))
-                _prune_dead_threads()  # Clean up dead threads
-        except ImportError:
-            # Fallback if thread module not available
-            with concurrent.futures.ThreadPoolExecutor(
-                max_workers=workers
-            ) as executor:
-                results = list(executor.map(f_wrapped, items))
-        return results
+            with concurrent.futures.ThreadPoolExecutor(
+                max_workers=workers
+            ) as executor:
+                results = list(executor.map(worker_func, items))
+        _cleanup_log_gate(log_gate_path)
+        return results
 
-
+    raise ValueError(f'Unsupported backend: {backend!r}')
 
 
 def cleanup_phantom_workers():
```
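All three pool-style backends above build `worker_func` the same way: `functools.partial` binds everything except the item, so APIs that pass a single argument (fastcore's `parallel`, `executor.map`) can drive `_call_with_log_control` unchanged. A self-contained sketch of the pattern with a stand-in function:

```python
# Sketch of the partial-binding pattern used by the mp/threadpool/safe
# backends; `call` stands in for _call_with_log_control.
import functools
from concurrent.futures import ThreadPoolExecutor

def call(func, x, func_kwargs, log_mode, gate_path):
    return func(x, **func_kwargs)

worker = functools.partial(
    call, str.upper, func_kwargs={}, log_mode='all', gate_path=None
)
with ThreadPoolExecutor(max_workers=4) as pool:
    print(list(pool.map(worker, ["a", "b"])))  # ['A', 'B']
```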
```diff
@@ -396,4 +694,6 @@ __all__ = [
     'SPEEDY_RUNNING_PROCESSES',
     'multi_process',
     'cleanup_phantom_workers',
+    'create_progress_tracker',
+    'get_ray_progress_actor',
 ]
```
```diff
--- /dev/null
+++ b/speedy_utils/multi_worker/progress.py
@@ -0,0 +1,140 @@
+"""
+Real-time progress tracking for distributed Ray tasks.
+
+This module provides a ProgressActor that allows workers to report item-level
+progress in real-time, giving users visibility into actual items processed
+rather than just task completion.
+"""
+import time
+import threading
+from typing import Optional, Callable
+
+__all__ = ['ProgressActor', 'create_progress_tracker', 'get_ray_progress_actor']
+
+
+def get_ray_progress_actor():
+    """Get the Ray-decorated ProgressActor class (lazy import to avoid Ray at module load)."""
+    import ray
+
+    @ray.remote
+    class ProgressActor:
+        """
+        A Ray actor for tracking real-time progress across distributed workers.
+
+        Workers call `update(n)` to report items processed, and the main process
+        can poll `get_progress()` to update a tqdm bar in real-time.
+        """
+        def __init__(self, total: int, desc: str = "Items"):
+            self.total = total
+            self.processed = 0
+            self.desc = desc
+            self.start_time = time.time()
+            self._lock = threading.Lock()
+
+        def update(self, n: int = 1) -> int:
+            """Increment processed count by n. Returns new total."""
+            with self._lock:
+                self.processed += n
+                return self.processed
+
+        def get_progress(self) -> dict:
+            """Get current progress stats."""
+            with self._lock:
+                elapsed = time.time() - self.start_time
+                rate = self.processed / elapsed if elapsed > 0 else 0
+                return {
+                    "processed": self.processed,
+                    "total": self.total,
+                    "elapsed": elapsed,
+                    "rate": rate,
+                    "desc": self.desc,
+                }
+
+        def set_total(self, total: int):
+            """Update total (useful if exact count unknown at start)."""
+            with self._lock:
+                self.total = total
+
+        def reset(self):
+            """Reset progress counter."""
+            with self._lock:
+                self.processed = 0
+                self.start_time = time.time()
+
+    return ProgressActor
+
+
+def create_progress_tracker(total: int, desc: str = "Items"):
+    """
+    Create a progress tracker actor for use with Ray distributed tasks.
+
+    Args:
+        total: Total number of items to process
+        desc: Description for the progress bar
+
+    Returns:
+        A Ray actor handle that workers can use to report progress
+
+    Example:
+        progress_actor = create_progress_tracker(1000000, "Processing items")
+
+        @ray.remote
+        def worker(items, progress_actor):
+            for item in items:
+                process(item)
+                ray.get(progress_actor.update.remote(1))
+
+        # In main process, poll progress:
+        while not done:
+            stats = ray.get(progress_actor.get_progress.remote())
+            pbar.n = stats["processed"]
+            pbar.refresh()
+    """
+    import ray
+    ProgressActor = get_ray_progress_actor()
+    return ProgressActor.remote(total, desc)
+
+
+class ProgressPoller:
+    """
+    Background thread that polls a Ray progress actor and updates a tqdm bar.
+    """
+    def __init__(self, progress_actor, pbar, poll_interval: float = 0.5):
+        import ray
+        self._ray = ray
+        self.progress_actor = progress_actor
+        self.pbar = pbar
+        self.poll_interval = poll_interval
+        self._stop_event = threading.Event()
+        self._thread: Optional[threading.Thread] = None
+
+    def start(self):
+        """Start the polling thread."""
+        self._thread = threading.Thread(target=self._poll_loop, daemon=True)
+        self._thread.start()
+
+    def stop(self):
+        """Stop the polling thread."""
+        self._stop_event.set()
+        if self._thread:
+            self._thread.join(timeout=2.0)
+
+    def _poll_loop(self):
+        """Poll the progress actor and update tqdm."""
+        while not self._stop_event.is_set():
+            try:
+                stats = self._ray.get(self.progress_actor.get_progress.remote())
+                self.pbar.n = stats["processed"]
+                self.pbar.set_postfix_str(f'{stats["rate"]:.1f} items/s')
+                self.pbar.refresh()
+            except Exception:
+                pass  # Ignore errors during polling
+            self._stop_event.wait(self.poll_interval)
+
+        # Final update
+        try:
+            stats = self._ray.get(self.progress_actor.get_progress.remote())
+            self.pbar.n = stats["processed"]
+            self.pbar.refresh()
+        except Exception:
+            pass
```
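A hypothetical end-to-end use of the new module, wiring a tracker actor to workers and a `ProgressPoller` to a local tqdm bar (the import path mirrors the new file; `crunch` and the data are illustrative):

```python
import ray
from tqdm import tqdm
from speedy_utils.multi_worker.progress import create_progress_tracker, ProgressPoller

ray.init(ignore_reinit_error=True)
actor = create_progress_tracker(total=400, desc="Items")

@ray.remote
def crunch(chunk, tracker):
    for _ in chunk:
        pass  # real work goes here
    tracker.update.remote(len(chunk))  # fire-and-forget progress report
    return len(chunk)

chunks = [list(range(100))] * 4
with tqdm(total=400, desc="Items") as pbar:
    poller = ProgressPoller(actor, pbar, poll_interval=0.2)
    poller.start()
    sizes = ray.get([crunch.remote(c, actor) for c in chunks])
    poller.stop()
print(sum(sizes))  # 400
```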