speedy-utils 1.1.40__py3-none-any.whl → 1.1.42__py3-none-any.whl

This diff shows the content of publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
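For orientation, the sketch below shows how the parameters introduced in this release (log_worker, total_items, poll_interval, ray_metrics_port) might be passed to multi_process. It is a minimal example, not taken from the package: the import path assumes multi_process is re-exported at the package root, and process_item is a hypothetical worker function.

    from speedy_utils import multi_process  # assumed re-export; adjust to the package's actual layout

    def process_item(x, scale=1.0):
        # hypothetical per-item work
        return x * scale

    results = multi_process(
        process_item,
        items=range(1_000),
        backend='ray',
        log_worker='first',   # new: only the first worker prints; 'zero' silences all, 'all' lets every worker print
        total_items=1_000,    # new: enables item-level progress via the Ray progress actor
        poll_interval=0.3,    # new: how often the driver polls the progress actor
        scale=2.0,            # forwarded to process_item via **func_kwargs
    )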
@@ -1,16 +1,23 @@
1
+ import warnings
2
+ import os
3
+ # Suppress Ray FutureWarnings before any imports
4
+ warnings.filterwarnings("ignore", category=FutureWarning, module="ray.*")
5
+ warnings.filterwarnings("ignore", message=".*pynvml.*deprecated.*", category=FutureWarning)
6
+
7
+ # Set environment variables before Ray is imported anywhere
8
+ os.environ["RAY_ACCEL_ENV_VAR_OVERRI" \
9
+ "DE_ON_ZERO"] = "0"
10
+ os.environ["RAY_DEDUP_LOGS"] = "1"
11
+ os.environ["RAY_LOG_TO_STDERR"] = "0"
12
+
1
13
  from ..__imports import *
14
+ import tempfile
15
+ from .progress import create_progress_tracker, ProgressPoller, get_ray_progress_actor
2
16
 
3
17
 
4
18
  SPEEDY_RUNNING_PROCESSES: list[psutil.Process] = []
5
19
  _SPEEDY_PROCESSES_LOCK = threading.Lock()
6
20
 
7
-
8
- # /mnt/data/anhvth8/venvs/Megatron-Bridge-Host/lib/python3.12/site-packages/ray/_private/worker.py:2046: FutureWarning: Tip: In future versions of Ray, Ray will no longer override accelerator visible devices env var if num_gpus=0 or num_gpus=None (default). To enable this behavior and turn off this error message, set RAY_ACCEL_ENV_VAR_OVERRIDE_ON_ZERO=0
9
- # turn off future warning and verbose task logs
10
- os.environ["RAY_ACCEL_ENV_VAR_OVERRIDE_ON_ZERO"] = "0"
11
- os.environ["RAY_DEDUP_LOGS"] = "0"
12
- os.environ["RAY_LOG_TO_STDERR"] = "0"
13
-
14
21
  def _prune_dead_processes() -> None:
15
22
  """Remove dead processes from tracking list."""
16
23
  with _SPEEDY_PROCESSES_LOCK:
@@ -121,35 +128,213 @@ def wrap_dump(func: Callable, cache_dir: Path | None, dump_in_thread: bool = Tru
121
128
  return wrapped
122
129
 
123
130
 
131
+ _LOG_GATE_CACHE: dict[str, bool] = {}
132
+
133
+
134
+ def _should_allow_worker_logs(mode: Literal['all', 'zero', 'first'], gate_path: Path | None) -> bool:
135
+ """Determine if current worker should emit logs for the given mode."""
136
+ if mode == 'all':
137
+ return True
138
+ if mode == 'zero':
139
+ return False
140
+ if mode == 'first':
141
+ if gate_path is None:
142
+ return True
143
+ key = str(gate_path)
144
+ cached = _LOG_GATE_CACHE.get(key)
145
+ if cached is not None:
146
+ return cached
147
+ gate_path.parent.mkdir(parents=True, exist_ok=True)
148
+ try:
149
+ fd = os.open(key, os.O_CREAT | os.O_EXCL | os.O_WRONLY)
150
+ except FileExistsError:
151
+ allowed = False
152
+ else:
153
+ os.close(fd)
154
+ allowed = True
155
+ _LOG_GATE_CACHE[key] = allowed
156
+ return allowed
157
+ raise ValueError(f'Unsupported log mode: {mode!r}')
158
+
159
+
160
+ def _cleanup_log_gate(gate_path: Path | None):
161
+ if gate_path is None:
162
+ return
163
+ try:
164
+ gate_path.unlink(missing_ok=True)
165
+ except OSError:
166
+ pass
167
+
168
+
169
+ @contextlib.contextmanager
170
+ def _patch_fastcore_progress_bar(*, leave: bool = True):
171
+ """Temporarily force fastcore.progress_bar to keep the bar on screen."""
172
+ try:
173
+ import fastcore.parallel as _fp
174
+ except ImportError:
175
+ yield False
176
+ return
177
+
178
+ orig = getattr(_fp, 'progress_bar', None)
179
+ if orig is None:
180
+ yield False
181
+ return
182
+
183
+ def _wrapped(*args, **kwargs):
184
+ kwargs.setdefault('leave', leave)
185
+ return orig(*args, **kwargs)
186
+
187
+ _fp.progress_bar = _wrapped
188
+ try:
189
+ yield True
190
+ finally:
191
+ _fp.progress_bar = orig
192
+
193
+
194
+ class _PrefixedWriter:
195
+ """Stream wrapper that prefixes each line with worker id."""
196
+
197
+ def __init__(self, stream, prefix: str):
198
+ self._stream = stream
199
+ self._prefix = prefix
200
+ self._at_line_start = True
201
+
202
+ def write(self, s):
203
+ if not s:
204
+ return 0
205
+ total = 0
206
+ for chunk in s.splitlines(True):
207
+ if self._at_line_start:
208
+ self._stream.write(self._prefix)
209
+ total += len(self._prefix)
210
+ self._stream.write(chunk)
211
+ total += len(chunk)
212
+ self._at_line_start = chunk.endswith('\n')
213
+ return total
214
+
215
+ def flush(self):
216
+ self._stream.flush()
217
+
218
+
219
+ def _call_with_log_control(
220
+ func: Callable,
221
+ x: Any,
222
+ func_kwargs: dict[str, Any],
223
+ log_mode: Literal['all', 'zero', 'first'],
224
+ gate_path: Path | None,
225
+ ):
226
+ """Call a function, silencing stdout/stderr based on log mode."""
227
+ allow_logs = _should_allow_worker_logs(log_mode, gate_path)
228
+ if allow_logs:
229
+ prefix = f"[worker-{os.getpid()}] "
230
+ # Route worker logs to stderr to reduce clobbering tqdm/progress output on stdout
231
+ out = _PrefixedWriter(sys.stderr, prefix)
232
+ err = out
233
+ with contextlib.redirect_stdout(out), contextlib.redirect_stderr(err):
234
+ return func(x, **func_kwargs)
235
+ with open(os.devnull, 'w') as devnull, contextlib.redirect_stdout(devnull), contextlib.redirect_stderr(devnull):
236
+ return func(x, **func_kwargs)
237
+
238
+
124
239
  # ─── ray management ─────────────────────────────────────────
125
240
 
126
241
  RAY_WORKER = None
127
242
 
128
243
 
129
- def ensure_ray(workers: int, pbar: tqdm | None = None):
130
- """Initialize or reinitialize Ray with a given worker count, log to bar postfix."""
244
+ def ensure_ray(workers: int | None, pbar: tqdm | None = None, ray_metrics_port: int | None = None):
245
+ """
246
+ Initialize or reinitialize Ray safely for both local and cluster environments.
247
+
248
+ 1. Tries to connect to an existing cluster (address='auto') first.
249
+ 2. If no cluster is found, starts a local Ray instance with 'workers' CPUs.
250
+ """
131
251
  import ray as _ray_module
132
252
  import logging
133
253
 
134
254
  global RAY_WORKER
135
- # shutdown when worker count changes or if Ray not initialized
136
- if not _ray_module.is_initialized() or workers != RAY_WORKER:
137
- if _ray_module.is_initialized() and pbar:
138
- pbar.set_postfix_str(f'Restarting Ray {workers} workers')
255
+ requested_workers = workers
256
+ if workers is None:
257
+ workers = os.cpu_count() or 1
258
+
259
+ if ray_metrics_port is not None:
260
+ os.environ['RAY_metrics_export_port'] = str(ray_metrics_port)
261
+
262
+ allow_restart = os.environ.get("RESTART_RAY", "0").lower() in ("1", "true")
263
+ is_cluster_env = "RAY_ADDRESS" in os.environ or os.environ.get("RAY_CLUSTER") == "1"
264
+
265
+ # 1. Handle existing session
266
+ if _ray_module.is_initialized():
267
+ if not allow_restart:
268
+ if pbar:
269
+ pbar.set_postfix_str("Using existing Ray session")
270
+ return
271
+
272
+ # Avoid shutting down shared cluster sessions.
273
+ if is_cluster_env:
274
+ if pbar:
275
+ pbar.set_postfix_str("Cluster active: skipping restart to protect connection")
276
+ return
277
+
278
+ # Local restart: only if worker count changed
279
+ if workers != RAY_WORKER:
280
+ if pbar:
281
+ pbar.set_postfix_str(f'Restarting local Ray with {workers} workers')
139
282
  _ray_module.shutdown()
140
- t0 = time.time()
283
+ else:
284
+ return
285
+
286
+ # 2. Initialization logic
287
+ t0 = time.time()
288
+
289
+ # Try to connect to existing cluster FIRST (address="auto")
290
+ try:
291
+ if pbar:
292
+ pbar.set_postfix_str("Searching for Ray cluster...")
293
+
294
+ # MUST NOT pass num_cpus/num_gpus here to avoid ValueError on existing clusters
295
+ _ray_module.init(
296
+ address="auto",
297
+ ignore_reinit_error=True,
298
+ logging_level=logging.ERROR,
299
+ log_to_driver=False
300
+ )
301
+
302
+ if pbar:
303
+ resources = _ray_module.cluster_resources()
304
+ cpus = resources.get("CPU", 0)
305
+ pbar.set_postfix_str(f"Connected to Ray Cluster ({int(cpus)} CPUs)")
306
+
307
+ except Exception:
308
+ # 3. Fallback: Start a local Ray session
309
+ if pbar:
310
+ pbar.set_postfix_str(f"No cluster found. Starting local Ray ({workers} CPUs)...")
311
+
141
312
  _ray_module.init(
142
313
  num_cpus=workers,
143
314
  ignore_reinit_error=True,
144
315
  logging_level=logging.ERROR,
145
316
  log_to_driver=False,
146
317
  )
147
- took = time.time() - t0
148
- _track_ray_processes() # Track Ray worker processes
318
+
149
319
  if pbar:
150
- pbar.set_postfix_str(f'ray.init {workers} took {took:.2f}s')
151
- RAY_WORKER = workers
320
+ took = time.time() - t0
321
+ pbar.set_postfix_str(f'ray.init local {workers} took {took:.2f}s')
152
322
 
323
+ _track_ray_processes()
324
+
325
+ if requested_workers is None:
326
+ try:
327
+ resources = _ray_module.cluster_resources()
328
+ total_cpus = int(resources.get("CPU", 0))
329
+ if total_cpus > 0:
330
+ workers = total_cpus
331
+ except Exception:
332
+ pass
333
+
334
+ RAY_WORKER = workers
335
+
336
+
337
+ # TODO: smarter backend selection: when shared_kwargs is used and backend != 'ray', switch to 'ray' and warn the user instead of raising an error
153
338
  def multi_process(
154
339
  func: Callable[[Any], Any],
155
340
  items: Iterable[Any] | None = None,
@@ -163,6 +348,10 @@ def multi_process(
163
348
  desc: str | None = None,
164
349
  shared_kwargs: list[str] | None = None,
165
350
  dump_in_thread: bool = True,
351
+ ray_metrics_port: int | None = None,
352
+ log_worker: Literal['zero', 'first', 'all'] = 'first',
353
+ total_items: int | None = None,
354
+ poll_interval: float = 0.3,
166
355
  **func_kwargs: Any,
167
356
  ) -> list[Any]:
168
357
  """
@@ -185,6 +374,23 @@ def multi_process(
185
374
  - Whether to dump results to disk in a separate thread (default: True)
186
375
  - If False, dumping is done synchronously, which may block but ensures data is saved before returning
187
376
 
377
+ ray_metrics_port:
378
+ - Optional port for Ray metrics export (Ray backend only)
379
+ - Set to 0 to disable Ray metrics
380
+ - If None, uses Ray's default behavior
381
+
382
+ log_worker:
383
+ - Control worker stdout/stderr noise
384
+ - 'first': only first worker emits logs, others are silenced (default)
385
+ - 'all': allow worker prints (may overlap tqdm)
386
+ - 'zero': silence worker stdout/stderr to keep progress bar clean
387
+
388
+ total_items:
389
+ - Optional item-level total for progress tracking (Ray backend only)
390
+
391
+ poll_interval:
392
+ - Poll interval in seconds for progress actor updates (Ray backend only)
393
+
188
394
  If lazy_output=True, every result is saved to .pkl and
189
395
  the returned list contains file paths.
190
396
  """
@@ -219,11 +425,13 @@ def multi_process(
219
425
  f"Valid parameters: {valid_params}"
220
426
  )
221
427
 
222
- # Only allow shared_kwargs with Ray backend
223
- if backend != 'ray':
224
- raise ValueError(
225
- f"shared_kwargs only supported with 'ray' backend, got '{backend}'"
226
- )
428
+ # Prefer Ray backend when shared kwargs are requested
429
+ if backend != 'ray':
430
+ warnings.warn(
431
+ "shared_kwargs only supported with 'ray' backend, switching backend to 'ray'",
432
+ UserWarning,
433
+ )
434
+ backend = 'ray'
227
435
 
228
436
  # unify items
229
437
  # unify items and coerce to concrete list so we can use len() and
@@ -235,35 +443,60 @@ def multi_process(
235
443
  if items is None:
236
444
  raise ValueError("'items' or 'inputs' must be provided")
237
445
 
238
- if workers is None:
446
+ if workers is None and backend != 'ray':
239
447
  workers = os.cpu_count() or 1
240
448
 
241
449
  # build cache dir + wrap func
242
450
  cache_dir = _build_cache_dir(func, items) if lazy_output else None
243
451
  f_wrapped = wrap_dump(func, cache_dir, dump_in_thread)
244
452
 
453
+ log_gate_path: Path | None = None
454
+ if log_worker == 'first':
455
+ log_gate_path = Path(tempfile.gettempdir()) / f'speedy_utils_log_gate_{os.getpid()}_{uuid.uuid4().hex}.gate'
456
+ elif log_worker not in ('zero', 'all'):
457
+ raise ValueError(f'Unsupported log_worker: {log_worker!r}')
458
+
245
459
  total = len(items)
246
460
  if desc:
247
461
  desc = desc.strip() + f'[{backend}]'
248
462
  else:
249
463
  desc = f'Multi-process [{backend}]'
250
- with tqdm(
251
- total=total, desc=desc , disable=not progress
252
- ) as pbar:
253
- # ---- sequential backend ----
254
- if backend == 'seq':
464
+
465
+ # ---- sequential backend ----
466
+ if backend == 'seq':
467
+ results: list[Any] = []
468
+ with tqdm(total=total, desc=desc, disable=not progress, file=sys.stdout) as pbar:
255
469
  for x in items:
256
- results.append(f_wrapped(x, **func_kwargs))
470
+ results.append(
471
+ _call_with_log_control(
472
+ f_wrapped,
473
+ x,
474
+ func_kwargs,
475
+ log_worker,
476
+ log_gate_path,
477
+ )
478
+ )
257
479
  pbar.update(1)
258
- return results
480
+ _cleanup_log_gate(log_gate_path)
481
+ return results
259
482
 
260
- # ---- ray backend ----
261
- if backend == 'ray':
262
- import ray as _ray_module
483
+ # ---- ray backend ----
484
+ if backend == 'ray':
485
+ import ray as _ray_module
263
486
 
264
- ensure_ray(workers, pbar)
487
+ results = []
488
+ gate_path_str = str(log_gate_path) if log_gate_path else None
489
+ with tqdm(total=total, desc=desc, disable=not progress, file=sys.stdout) as pbar:
490
+ ensure_ray(workers, pbar, ray_metrics_port)
265
491
  shared_refs = {}
266
492
  regular_kwargs = {}
493
+
494
+ # Create progress actor for item-level tracking if total_items specified
495
+ progress_actor = None
496
+ progress_poller = None
497
+ if total_items is not None:
498
+ progress_actor = create_progress_tracker(total_items, desc or "Items")
499
+ shared_refs['progress_actor'] = progress_actor # Pass actor handle directly (not via put)
267
500
 
268
501
  if shared_kwargs:
269
502
  for kw in shared_kwargs:
@@ -283,60 +516,125 @@ def multi_process(
283
516
  def _task(x, shared_refs_dict, regular_kwargs_dict):
284
517
  # Dereference shared objects (zero-copy for numpy arrays)
285
518
  import ray as _ray_in_task
286
- dereferenced = {k: _ray_in_task.get(v) for k, v in shared_refs_dict.items()}
519
+ gate = Path(gate_path_str) if gate_path_str else None
520
+ dereferenced = {}
521
+ for k, v in shared_refs_dict.items():
522
+ if k == 'progress_actor':
523
+ # Pass actor handle directly (don't dereference)
524
+ dereferenced[k] = v
525
+ else:
526
+ dereferenced[k] = _ray_in_task.get(v)
287
527
  # Merge with regular kwargs
288
528
  all_kwargs = {**dereferenced, **regular_kwargs_dict}
289
- return f_wrapped(x, **all_kwargs)
529
+ return _call_with_log_control(
530
+ f_wrapped,
531
+ x,
532
+ all_kwargs,
533
+ log_worker,
534
+ gate,
535
+ )
290
536
 
291
537
  refs = [
292
538
  _task.remote(x, shared_refs, regular_kwargs) for x in items
293
539
  ]
294
540
 
295
- results = []
296
541
  t_start = time.time()
542
+
543
+ # Start progress poller if using item-level progress
544
+ if progress_actor is not None:
545
+ # Update pbar total to show items instead of tasks
546
+ pbar.total = total_items
547
+ pbar.refresh()
548
+ progress_poller = ProgressPoller(progress_actor, pbar, poll_interval)
549
+ progress_poller.start()
550
+
297
551
  for r in refs:
298
552
  results.append(_ray_module.get(r))
299
- pbar.update(1)
553
+ if progress_actor is None:
554
+ # Only update task-level progress if not using item-level
555
+ pbar.update(1)
556
+
557
+ # Stop progress poller
558
+ if progress_poller is not None:
559
+ progress_poller.stop()
560
+
300
561
  t_end = time.time()
301
- print(f"Ray processing took {t_end - t_start:.2f}s for {total} items")
302
- return results
303
-
304
- # ---- fastcore backend ----
305
- if backend == 'mp':
306
- results = parallel(
307
- f_wrapped, items, n_workers=workers, progress=progress, threadpool=False
562
+ item_desc = f"{total_items:,} items" if total_items else f"{total} tasks"
563
+ print(f"Ray processing took {t_end - t_start:.2f}s for {item_desc}")
564
+ _cleanup_log_gate(log_gate_path)
565
+ return results
566
+
567
+ # ---- fastcore backend ----
568
+ if backend == 'mp':
569
+ worker_func = functools.partial(
570
+ _call_with_log_control,
571
+ f_wrapped,
572
+ func_kwargs=func_kwargs,
573
+ log_mode=log_worker,
574
+ gate_path=log_gate_path,
575
+ )
576
+ with _patch_fastcore_progress_bar(leave=True):
577
+ results = list(parallel(
578
+ worker_func, items, n_workers=workers, progress=progress, threadpool=False
579
+ ))
580
+ _track_multiprocessing_processes() # Track multiprocessing workers
581
+ _prune_dead_processes() # Clean up dead processes
582
+ _cleanup_log_gate(log_gate_path)
583
+ return results
584
+
585
+ if backend == 'threadpool':
586
+ worker_func = functools.partial(
587
+ _call_with_log_control,
588
+ f_wrapped,
589
+ func_kwargs=func_kwargs,
590
+ log_mode=log_worker,
591
+ gate_path=log_gate_path,
592
+ )
593
+ with _patch_fastcore_progress_bar(leave=True):
594
+ results = list(parallel(
595
+ worker_func, items, n_workers=workers, progress=progress, threadpool=True
596
+ ))
597
+ _cleanup_log_gate(log_gate_path)
598
+ return results
599
+
600
+ if backend == 'safe':
601
+ # Completely safe backend for tests - no multiprocessing, no external progress bars
602
+ import concurrent.futures
603
+
604
+ # Import thread tracking from thread module
605
+ try:
606
+ from .thread import _prune_dead_threads, _track_executor_threads
607
+
608
+ worker_func = functools.partial(
609
+ _call_with_log_control,
610
+ f_wrapped,
611
+ func_kwargs=func_kwargs,
612
+ log_mode=log_worker,
613
+ gate_path=log_gate_path,
308
614
  )
309
- _track_multiprocessing_processes() # Track multiprocessing workers
310
- _prune_dead_processes() # Clean up dead processes
311
- return list(results)
312
- if backend == 'threadpool':
313
- results = parallel(
314
- f_wrapped, items, n_workers=workers, progress=progress, threadpool=True
615
+ with concurrent.futures.ThreadPoolExecutor(
616
+ max_workers=workers
617
+ ) as executor:
618
+ _track_executor_threads(executor) # Track threads
619
+ results = list(executor.map(worker_func, items))
620
+ _prune_dead_threads() # Clean up dead threads
621
+ except ImportError:
622
+ # Fallback if thread module not available
623
+ worker_func = functools.partial(
624
+ _call_with_log_control,
625
+ f_wrapped,
626
+ func_kwargs=func_kwargs,
627
+ log_mode=log_worker,
628
+ gate_path=log_gate_path,
315
629
  )
316
- return list(results)
317
- if backend == 'safe':
318
- # Completely safe backend for tests - no multiprocessing, no external progress bars
319
- import concurrent.futures
320
-
321
- # Import thread tracking from thread module
322
- try:
323
- from .thread import _prune_dead_threads, _track_executor_threads
324
-
325
- with concurrent.futures.ThreadPoolExecutor(
326
- max_workers=workers
327
- ) as executor:
328
- _track_executor_threads(executor) # Track threads
329
- results = list(executor.map(f_wrapped, items))
330
- _prune_dead_threads() # Clean up dead threads
331
- except ImportError:
332
- # Fallback if thread module not available
333
- with concurrent.futures.ThreadPoolExecutor(
334
- max_workers=workers
335
- ) as executor:
336
- results = list(executor.map(f_wrapped, items))
337
- return results
630
+ with concurrent.futures.ThreadPoolExecutor(
631
+ max_workers=workers
632
+ ) as executor:
633
+ results = list(executor.map(worker_func, items))
634
+ _cleanup_log_gate(log_gate_path)
635
+ return results
338
636
 
339
- raise ValueError(f'Unsupported backend: {backend!r}')
637
+ raise ValueError(f'Unsupported backend: {backend!r}')
340
638
 
341
639
 
342
640
  def cleanup_phantom_workers():
@@ -396,4 +694,6 @@ __all__ = [
396
694
  'SPEEDY_RUNNING_PROCESSES',
397
695
  'multi_process',
398
696
  'cleanup_phantom_workers',
697
+ 'create_progress_tracker',
698
+ 'get_ray_progress_actor',
399
699
  ]
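As an aside, the 'first' log mode above relies on an atomic create-exclusive gate file: the first worker to create the file wins the right to log, and every later worker hits FileExistsError and stays quiet. A standalone sketch of that trick (the file name is illustrative):

    import os
    import tempfile
    from pathlib import Path

    gate = Path(tempfile.gettempdir()) / "example_log_gate"

    def i_may_log(gate_path: Path) -> bool:
        # O_CREAT | O_EXCL fails if the file already exists, so exactly one caller succeeds.
        try:
            fd = os.open(str(gate_path), os.O_CREAT | os.O_EXCL | os.O_WRONLY)
        except FileExistsError:
            return False
        os.close(fd)
        return True

    print(i_may_log(gate))  # True for the first caller
    print(i_may_log(gate))  # False for everyone after, until the gate file is removed
    gate.unlink(missing_ok=True)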
@@ -0,0 +1,140 @@
1
+ """
2
+ Real-time progress tracking for distributed Ray tasks.
3
+
4
+ This module provides a ProgressActor that allows workers to report item-level
5
+ progress in real-time, giving users visibility into actual items processed
6
+ rather than just task completion.
7
+ """
8
+ import time
9
+ import threading
10
+ from typing import Optional, Callable
11
+
12
+ __all__ = ['ProgressActor', 'create_progress_tracker', 'get_ray_progress_actor']
13
+
14
+
15
+ def get_ray_progress_actor():
16
+ """Get the Ray-decorated ProgressActor class (lazy import to avoid Ray at module load)."""
17
+ import ray
18
+
19
+ @ray.remote
20
+ class ProgressActor:
21
+ """
22
+ A Ray actor for tracking real-time progress across distributed workers.
23
+
24
+ Workers call `update(n)` to report items processed, and the main process
25
+ can poll `get_progress()` to update a tqdm bar in real-time.
26
+ """
27
+ def __init__(self, total: int, desc: str = "Items"):
28
+ self.total = total
29
+ self.processed = 0
30
+ self.desc = desc
31
+ self.start_time = time.time()
32
+ self._lock = threading.Lock()
33
+
34
+ def update(self, n: int = 1) -> int:
35
+ """Increment processed count by n. Returns new total."""
36
+ with self._lock:
37
+ self.processed += n
38
+ return self.processed
39
+
40
+ def get_progress(self) -> dict:
41
+ """Get current progress stats."""
42
+ with self._lock:
43
+ elapsed = time.time() - self.start_time
44
+ rate = self.processed / elapsed if elapsed > 0 else 0
45
+ return {
46
+ "processed": self.processed,
47
+ "total": self.total,
48
+ "elapsed": elapsed,
49
+ "rate": rate,
50
+ "desc": self.desc,
51
+ }
52
+
53
+ def set_total(self, total: int):
54
+ """Update total (useful if exact count unknown at start)."""
55
+ with self._lock:
56
+ self.total = total
57
+
58
+ def reset(self):
59
+ """Reset progress counter."""
60
+ with self._lock:
61
+ self.processed = 0
62
+ self.start_time = time.time()
63
+
64
+ return ProgressActor
65
+
66
+
67
+ def create_progress_tracker(total: int, desc: str = "Items"):
68
+ """
69
+ Create a progress tracker actor for use with Ray distributed tasks.
70
+
71
+ Args:
72
+ total: Total number of items to process
73
+ desc: Description for the progress bar
74
+
75
+ Returns:
76
+ A Ray actor handle that workers can use to report progress
77
+
78
+ Example:
79
+ progress_actor = create_progress_tracker(1000000, "Processing items")
80
+
81
+ @ray.remote
82
+ def worker(items, progress_actor):
83
+ for item in items:
84
+ process(item)
85
+ ray.get(progress_actor.update.remote(1))
86
+
87
+ # In main process, poll progress:
88
+ while not done:
89
+ stats = ray.get(progress_actor.get_progress.remote())
90
+ pbar.n = stats["processed"]
91
+ pbar.refresh()
92
+ """
93
+ import ray
94
+ ProgressActor = get_ray_progress_actor()
95
+ return ProgressActor.remote(total, desc)
96
+
97
+
98
+ class ProgressPoller:
99
+ """
100
+ Background thread that polls a Ray progress actor and updates a tqdm bar.
101
+ """
102
+ def __init__(self, progress_actor, pbar, poll_interval: float = 0.5):
103
+ import ray
104
+ self._ray = ray
105
+ self.progress_actor = progress_actor
106
+ self.pbar = pbar
107
+ self.poll_interval = poll_interval
108
+ self._stop_event = threading.Event()
109
+ self._thread: Optional[threading.Thread] = None
110
+
111
+ def start(self):
112
+ """Start the polling thread."""
113
+ self._thread = threading.Thread(target=self._poll_loop, daemon=True)
114
+ self._thread.start()
115
+
116
+ def stop(self):
117
+ """Stop the polling thread."""
118
+ self._stop_event.set()
119
+ if self._thread:
120
+ self._thread.join(timeout=2.0)
121
+
122
+ def _poll_loop(self):
123
+ """Poll the progress actor and update tqdm."""
124
+ while not self._stop_event.is_set():
125
+ try:
126
+ stats = self._ray.get(self.progress_actor.get_progress.remote())
127
+ self.pbar.n = stats["processed"]
128
+ self.pbar.set_postfix_str(f'{stats["rate"]:.1f} items/s')
129
+ self.pbar.refresh()
130
+ except Exception:
131
+ pass # Ignore errors during polling
132
+ self._stop_event.wait(self.poll_interval)
133
+
134
+ # Final update
135
+ try:
136
+ stats = self._ray.get(self.progress_actor.get_progress.remote())
137
+ self.pbar.n = stats["processed"]
138
+ self.pbar.refresh()
139
+ except Exception:
140
+ pass
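
Taken together, the actor and poller above support a driver-side pattern like the following sketch. It assumes Ray is already initialized and that the module is importable as speedy_utils.multi_worker.progress (the exact package path, chunking, and worker function are illustrative, not part of the package).

    import ray
    from tqdm import tqdm
    from speedy_utils.multi_worker.progress import (  # assumed module path
        ProgressPoller,
        create_progress_tracker,
    )

    actor = create_progress_tracker(total=1_000, desc="Items")

    @ray.remote
    def worker(chunk, progress_actor):
        for _ in chunk:
            # hypothetical per-item work goes here
            progress_actor.update.remote(1)  # progress report; wrap in ray.get() to block until acknowledged
        return len(chunk)

    chunks = [list(range(i, i + 100)) for i in range(0, 1_000, 100)]
    with tqdm(total=1_000, desc="Items") as pbar:
        poller = ProgressPoller(actor, pbar, poll_interval=0.3)
        poller.start()
        ray.get([worker.remote(c, actor) for c in chunks])
        poller.stop()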