speedy-utils 1.1.40__py3-none-any.whl → 1.1.43__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,15 +1,495 @@
1
+ import warnings
2
+ import os
3
# Ray reads these settings when it is imported / initialized, so they are
# applied at module import time, before anything else imports `ray`.
# First: hide Ray's FutureWarnings, including the pynvml deprecation notice.
warnings.filterwarnings("ignore", category=FutureWarning, module="ray.*")
warnings.filterwarnings("ignore", message=".*pynvml.*deprecated.*", category=FutureWarning)

# Then: quiet Ray's accelerator-override tip, deduplicate worker logs, keep
# Ray logs off stderr and suppress "unhandled error" reports.
os.environ.update(
    {
        "RAY_ACCEL_ENV_VAR_OVERRIDE_ON_ZERO": "0",
        "RAY_DEDUP_LOGS": "1",
        "RAY_LOG_TO_STDERR": "0",
        "RAY_IGNORE_UNHANDLED_ERRORS": "1",
    }
)
13
+
1
14
  from ..__imports import *
15
+ import tempfile
16
+ import inspect
17
+ import linecache
18
+ import traceback as tb_module
19
+ from .progress import create_progress_tracker, ProgressPoller, get_ray_progress_actor
2
20
 
21
+ # Import thread tracking functions if available
22
+ try:
23
+ from .thread import _prune_dead_threads, _track_executor_threads
24
+ except ImportError:
25
+ _prune_dead_threads = None # type: ignore[assignment]
26
+ _track_executor_threads = None # type: ignore[assignment]
3
27
 
4
- SPEEDY_RUNNING_PROCESSES: list[psutil.Process] = []
5
- _SPEEDY_PROCESSES_LOCK = threading.Lock()
6
28
 
29
+ # ─── error handler types ────────────────────────────────────────
30
# Accepted values for multi_process(error_handler=...):
#   'raise'  - stop on the first failing item,
#   'ignore' - record the failure and yield None for that item,
#   'log'    - like 'ignore', and also write a per-item error log file.
ErrorHandlerType = Literal['raise', 'ignore', 'log']
7
31
 
8
- # /mnt/data/anhvth8/venvs/Megatron-Bridge-Host/lib/python3.12/site-packages/ray/_private/worker.py:2046: FutureWarning: Tip: In future versions of Ray, Ray will no longer override accelerator visible devices env var if num_gpus=0 or num_gpus=None (default). To enable this behavior and turn off this error message, set RAY_ACCEL_ENV_VAR_OVERRIDE_ON_ZERO=0
9
- # turn off future warning and verbose task logs
10
- os.environ["RAY_ACCEL_ENV_VAR_OVERRIDE_ON_ZERO"] = "0"
11
- os.environ["RAY_DEDUP_LOGS"] = "0"
12
- os.environ["RAY_LOG_TO_STDERR"] = "0"
32
+
33
+ class ErrorStats:
34
+ """Thread-safe error statistics tracker."""
35
+
36
+ def __init__(
37
+ self,
38
+ func_name: str,
39
+ max_error_files: int = 100,
40
+ write_logs: bool = True
41
+ ):
42
+ self._lock = threading.Lock()
43
+ self._success_count = 0
44
+ self._error_count = 0
45
+ self._first_error_shown = False
46
+ self._max_error_files = max_error_files
47
+ self._write_logs = write_logs
48
+ self._error_log_dir = self._get_error_log_dir(func_name)
49
+ if write_logs:
50
+ self._error_log_dir.mkdir(parents=True, exist_ok=True)
51
+
52
+ @staticmethod
53
+ def _get_error_log_dir(func_name: str) -> Path:
54
+ """Generate unique error log directory with run counter."""
55
+ base_dir = Path('.cache/speedy_utils/error_logs')
56
+ base_dir.mkdir(parents=True, exist_ok=True)
57
+
58
+ # Find the next run counter
59
+ counter = 1
60
+ existing = list(base_dir.glob(f'{func_name}_run_*'))
61
+ if existing:
62
+ counters = []
63
+ for p in existing:
64
+ try:
65
+ parts = p.name.split('_run_')
66
+ if len(parts) == 2:
67
+ counters.append(int(parts[1]))
68
+ except (ValueError, IndexError):
69
+ pass
70
+ if counters:
71
+ counter = max(counters) + 1
72
+
73
+ return base_dir / f'{func_name}_run_{counter}'
74
+
75
+ def record_success(self) -> None:
76
+ with self._lock:
77
+ self._success_count += 1
78
+
79
+ def record_error(
80
+ self,
81
+ idx: int,
82
+ error: Exception,
83
+ input_value: Any,
84
+ func_name: str,
85
+ ) -> str | None:
86
+ """
87
+ Record an error and write to log file.
88
+ Returns the log file path if written, None otherwise.
89
+ """
90
+ with self._lock:
91
+ self._error_count += 1
92
+ should_show_first = not self._first_error_shown
93
+ if should_show_first:
94
+ self._first_error_shown = True
95
+ should_write = (
96
+ self._write_logs and self._error_count <= self._max_error_files
97
+ )
98
+
99
+ log_path = None
100
+ if should_write:
101
+ log_path = self._write_error_log(
102
+ idx, error, input_value, func_name
103
+ )
104
+
105
+ if should_show_first:
106
+ self._print_first_error(error, input_value, func_name, log_path)
107
+
108
+ return log_path
109
+
110
+ def _write_error_log(
111
+ self,
112
+ idx: int,
113
+ error: Exception,
114
+ input_value: Any,
115
+ func_name: str,
116
+ ) -> str:
117
+ """Write error details to a log file."""
118
+ log_path = self._error_log_dir / f'{idx}.log'
119
+
120
+ # Format traceback
121
+ tb_lines = self._format_traceback(error)
122
+
123
+ content = []
124
+ content.append(f'{"=" * 60}')
125
+ content.append(f'Error at index: {idx}')
126
+ content.append(f'Function: {func_name}')
127
+ content.append(f'Error Type: {type(error).__name__}')
128
+ content.append(f'Error Message: {error}')
129
+ content.append(f'{"=" * 60}')
130
+ content.append('')
131
+ content.append('Input:')
132
+ content.append('-' * 40)
133
+ try:
134
+ content.append(repr(input_value))
135
+ except Exception:
136
+ content.append('<unable to repr input>')
137
+ content.append('')
138
+ content.append('Traceback:')
139
+ content.append('-' * 40)
140
+ content.extend(tb_lines)
141
+
142
+ with open(log_path, 'w') as f:
143
+ f.write('\n'.join(content))
144
+
145
+ return str(log_path)
146
+
147
+ def _format_traceback(self, error: Exception) -> list[str]:
148
+ """Format traceback with context lines like Rich panel."""
149
+ lines = []
150
+ frames = _extract_frames_from_traceback(error)
151
+
152
+ for filepath, lineno, funcname, frame_locals in frames:
153
+ lines.append(f'│ {filepath}:{lineno} in {funcname} │')
154
+ lines.append('│' + ' ' * 70 + '│')
155
+
156
+ # Get context lines
157
+ context_size = 3
158
+ start_line = max(1, lineno - context_size)
159
+ end_line = lineno + context_size + 1
160
+
161
+ for line_num in range(start_line, end_line):
162
+ line_text = linecache.getline(filepath, line_num).rstrip()
163
+ if line_text:
164
+ num_str = str(line_num).rjust(4)
165
+ if line_num == lineno:
166
+ lines.append(f'│ {num_str} ❱ {line_text}')
167
+ else:
168
+ lines.append(f'│ {num_str} │ {line_text}')
169
+ lines.append('')
170
+
171
+ return lines
172
+
173
+ def _print_first_error(
174
+ self,
175
+ error: Exception,
176
+ input_value: Any,
177
+ func_name: str,
178
+ log_path: str | None,
179
+ ) -> None:
180
+ """Print the first error to screen with Rich formatting."""
181
+ try:
182
+ from rich.console import Console
183
+ from rich.panel import Panel
184
+ console = Console(stderr=True)
185
+
186
+ tb_lines = self._format_traceback(error)
187
+
188
+ console.print()
189
+ console.print(
190
+ Panel(
191
+ '\n'.join(tb_lines),
192
+ title='[bold red]First Error (continuing with remaining items)[/bold red]',
193
+ border_style='yellow',
194
+ expand=False,
195
+ )
196
+ )
197
+ console.print(
198
+ f'[bold red]{type(error).__name__}[/bold red]: {error}'
199
+ )
200
+ if log_path:
201
+ console.print(f'[dim]Error log: {log_path}[/dim]')
202
+ console.print()
203
+ except ImportError:
204
+ # Fallback to plain print
205
+ print(f'\n--- First Error ---', file=sys.stderr)
206
+ print(f'{type(error).__name__}: {error}', file=sys.stderr)
207
+ if log_path:
208
+ print(f'Error log: {log_path}', file=sys.stderr)
209
+ print('', file=sys.stderr)
210
+
211
+ @property
212
+ def success_count(self) -> int:
213
+ with self._lock:
214
+ return self._success_count
215
+
216
+ @property
217
+ def error_count(self) -> int:
218
+ with self._lock:
219
+ return self._error_count
220
+
221
+ def get_postfix_dict(self) -> dict[str, int]:
222
+ """Get dict for pbar postfix."""
223
+ with self._lock:
224
+ return {'ok': self._success_count, 'err': self._error_count}
225
+
226
+
227
+ def _should_skip_frame(filepath: str) -> bool:
228
+ """Check if a frame should be filtered from traceback display."""
229
+ skip_patterns = [
230
+ 'ray/_private',
231
+ 'ray/worker',
232
+ 'site-packages/ray',
233
+ 'speedy_utils/multi_worker',
234
+ 'concurrent/futures',
235
+ 'multiprocessing/',
236
+ 'fastcore/parallel',
237
+ 'fastcore/foundation',
238
+ 'fastcore/basics',
239
+ 'site-packages/fastcore',
240
+ '/threading.py',
241
+ '/concurrent/',
242
+ ]
243
+ return any(skip in filepath for skip in skip_patterns)
244
+
245
+
246
+ def _should_show_local(name: str, value: object) -> bool:
247
+ """Check if a local variable should be displayed in traceback."""
248
+ import types
249
+
250
+ # Skip dunder variables
251
+ if name.startswith('__') and name.endswith('__'):
252
+ return False
253
+
254
+ # Skip modules
255
+ if isinstance(value, types.ModuleType):
256
+ return False
257
+
258
+ # Skip type objects and classes
259
+ if isinstance(value, type):
260
+ return False
261
+
262
+ # Skip functions and methods
263
+ if isinstance(value, (types.FunctionType, types.MethodType, types.BuiltinFunctionType)):
264
+ return False
265
+
266
+ # Skip common typing aliases
267
+ value_str = str(value)
268
+ if value_str.startswith('typing.'):
269
+ return False
270
+
271
+ # Skip large objects that would clutter output
272
+ if value_str.startswith('<') and any(x in value_str for x in ['module', 'function', 'method', 'built-in']):
273
+ return False
274
+
275
+ return True
276
+
277
+
278
def _format_locals(frame_locals: dict) -> list[str]:
    """Format frame locals as Rich-markup lines for the traceback panel.

    Locals rejected by ``_should_show_local`` are dropped. Each value is
    pretty-printed with Rich (falling back to ``repr``), clipped to 100
    characters, and markup-escaped so values such as ``[x]`` are shown
    literally instead of being parsed as Rich tags.

    Returns an empty list when nothing is left to show.
    """
    from io import StringIO

    from rich.console import Console
    from rich.markup import escape
    from rich.pretty import Pretty

    max_len = 100

    def _clip(text: str) -> str:
        # Keep the panel readable: long values are truncated with '...'.
        return text if len(text) <= max_len else text[:max_len - 3] + '...'

    clean_locals = {k: v for k, v in frame_locals.items() if _should_show_local(k, v)}
    if not clean_locals:
        return []

    lines = ['[dim]╭─ locals ─╮[/dim]']
    for name, value in clean_locals.items():
        try:
            buf = StringIO()
            Console(file=buf, width=60).print(Pretty(value), end='')
            value_str = buf.getvalue().strip()
        except Exception:
            try:
                value_str = repr(value)
            except Exception:
                # A broken __repr__ must not abort the error report.
                value_str = f'<unrepresentable {type(value).__name__}>'
        lines.append(f'[dim]│[/dim] {escape(name)} = {escape(_clip(value_str))}')

    lines.append('[dim]╰──────────╯[/dim]')
    return lines
312
+
313
+
314
def _format_frame_with_context(
    filepath: str,
    lineno: int,
    funcname: str,
    frame_locals: dict | None = None,
    context_size: int = 3,
) -> list[str]:
    """Format one traceback frame as Rich-markup lines.

    Emits a ``path:line in func`` header, then the source lines from
    ``lineno - context_size`` to ``lineno + context_size`` (blank lines
    skipped, the failing line marked with a red ``❱``), then the
    formatted locals when ``frame_locals`` is non-empty.

    Path, function name and source text are markup-escaped so code such as
    ``items[idx]`` is displayed verbatim rather than parsed as Rich tags.
    """
    from rich.markup import escape

    lines = [
        f'[cyan]{escape(filepath)}[/cyan]:[yellow]{lineno}[/yellow] '
        f'in [green]{escape(funcname)}[/green]',
        '',
    ]

    start_line = max(1, lineno - context_size)
    end_line = lineno + context_size + 1
    for line_num in range(start_line, end_line):
        line_text = linecache.getline(filepath, line_num).rstrip()
        if not line_text:
            continue
        num_str = str(line_num).rjust(4)
        safe_text = escape(line_text)
        if line_num == lineno:
            lines.append(f'[dim]{num_str}[/dim] [red]❱[/red] {safe_text}')
        else:
            lines.append(f'[dim]{num_str} │[/dim] {safe_text}')

    if frame_locals:
        locals_lines = _format_locals(frame_locals)
        if locals_lines:
            lines.append('')
            lines.extend(locals_lines)

    lines.append('')
    return lines
348
+
349
+
350
def _display_formatted_error(
    exc_type_name: str,
    exc_msg: str,
    frames: list[tuple[str, int, str, dict]],
    caller_info: dict | None,
    backend: str,
    pbar=None,
) -> None:
    """Render a worker failure as a Rich traceback panel on stderr, then exit.

    Args:
        exc_type_name: Exception class name for the summary line.
        exc_msg: Exception message for the summary line.
        frames: ``(filepath, lineno, funcname, locals)`` tuples in traceback
            order (outermost first), as built by the ``_extract_frames_*``
            helpers.
        caller_info: Optional dict with ``'filename'``, ``'lineno'`` and
            ``'function'``; rendered first, without locals.
        backend: Backend label shown in the panel title.
        pbar: Optional progress bar, closed first so it does not interleave
            with the panel.

    NOTE: this never returns normally. It ends with ``sys.exit(1)``, which
    raises SystemExit rather than the original exception, so a caller's
    ``except Exception`` will not catch it.
    """
    # Suppress additional error logs
    os.environ['RAY_IGNORE_UNHANDLED_ERRORS'] = '1'

    if pbar is not None:
        pbar.close()

    from rich.console import Console
    from rich.markup import escape
    from rich.panel import Panel
    console = Console(stderr=True)

    # Exception text is arbitrary data: escape it so '[...]' is not parsed as markup.
    summary = f'[bold red]{escape(exc_type_name)}[/bold red]: {escape(exc_msg)}'

    console.print()
    if frames or caller_info:
        display_lines = []

        if caller_info:
            display_lines.extend(_format_frame_with_context(
                caller_info['filename'],
                caller_info['lineno'],
                caller_info['function'],
                None,  # Don't show locals for caller frame
            ))

        for filepath, lineno, funcname, frame_locals in frames:
            display_lines.extend(_format_frame_with_context(
                filepath, lineno, funcname, frame_locals
            ))

        # Escape the bracketed backend label; unescaped, "[ray]" is consumed as a tag.
        backend_label = escape(f'[{backend}]')
        console.print(
            Panel(
                '\n'.join(display_lines),
                title=f'[bold red]Traceback (most recent call last) {backend_label}[/bold red]',
                border_style='red',
                expand=False,
            )
        )
    console.print(summary)
    console.print()

    # Ensure output is flushed
    sys.stderr.flush()
    sys.stdout.flush()
    sys.exit(1)
410
+
411
+
412
def _extract_frames_from_traceback(error: Exception) -> list[tuple[str, int, str, dict]]:
    """Collect the non-framework frames of *error*'s traceback, with their locals.

    Walks ``error.__traceback__`` from the outermost frame to the one that
    raised. Frames rejected by ``_should_skip_frame`` are dropped. Each
    remaining frame becomes ``(filename, lineno, funcname, locals_snapshot)``.
    """
    collected: list[tuple[str, int, str, dict]] = []
    node = getattr(error, '__traceback__', None)
    while node is not None:
        code = node.tb_frame.f_code
        if not _should_skip_frame(code.co_filename):
            collected.append(
                (code.co_filename, node.tb_lineno, code.co_name, dict(node.tb_frame.f_locals))
            )
        node = node.tb_next
    return collected
430
+
431
+
432
def _extract_frames_from_ray_error(ray_task_error: Exception) -> list[tuple[str, int, str, dict]]:
    """Parse user frames out of the text of a RayTaskError.

    Ray only carries a string traceback, so each ``File "...", line N, in f``
    line is matched with a regex. Frames rejected by ``_should_skip_frame``
    are dropped, and locals are always an empty dict.
    """
    import re

    frame_line = re.compile(r'\s*File "([^"]+)", line (\d+), in (.+)')
    parsed = []
    for text_line in str(ray_task_error).split('\n'):
        match = frame_line.match(text_line)
        if match is None:
            continue
        path, line_no, func = match.groups()
        if _should_skip_frame(path):
            continue
        parsed.append((path, int(line_no), func, {}))
    return parsed
449
+
450
+
451
def _reraise_worker_error(error: Exception, pbar=None, caller_info=None, backend: str = 'unknown') -> None:
    """Show a worker exception as a formatted traceback and stop.

    Used for backends whose exceptions keep a real ``__traceback__``
    (multiprocessing, thread pool, ...), so frame locals are available.
    Despite the name, nothing is re-raised: the work is delegated to
    ``_display_formatted_error``, which finishes with ``sys.exit(1)``.
    """
    user_frames = _extract_frames_from_traceback(error)
    _display_formatted_error(
        type(error).__name__,
        str(error),
        user_frames,
        caller_info,
        backend,
        pbar,
    )
465
+
466
+
467
def _reraise_ray_error(ray_task_error: Exception, pbar=None, caller_info=None) -> None:
    """Show a RayTaskError as a formatted traceback and stop.

    The type name and message come from the wrapped cause. The ``.cause``
    attribute is tried first, then ``__cause__``. When there is no cause,
    the RayTaskError's own text is used. Frames are parsed from Ray's string
    traceback, and ``_display_formatted_error`` finishes with
    ``sys.exit(1)``.
    """
    cause = getattr(ray_task_error, 'cause', None)
    if cause is None:
        cause = ray_task_error.__cause__

    if cause:
        exc_type_name = type(cause).__name__
        exc_msg = str(cause)
    else:
        exc_type_name = 'Error'
        exc_msg = str(ray_task_error)

    _display_formatted_error(
        exc_type_name=exc_type_name,
        exc_msg=exc_msg,
        frames=_extract_frames_from_ray_error(ray_task_error),
        caller_info=caller_info,
        backend='ray',
        pbar=pbar,
    )
489
+
490
+
491
# Processes tracked by this module (pruned by _prune_dead_processes;
# presumably filled by _track_ray_processes after ray.init -- confirm).
SPEEDY_RUNNING_PROCESSES: list[psutil.Process] = []
# Guards SPEEDY_RUNNING_PROCESSES against concurrent modification.
_SPEEDY_PROCESSES_LOCK = threading.Lock()
13
493
 
14
494
  def _prune_dead_processes() -> None:
15
495
  """Remove dead processes from tracking list."""
@@ -89,7 +569,7 @@ def _build_cache_dir(func: Callable, items: list[Any]) -> Path:
89
569
  path = Path('.cache') / run_id
90
570
  path.mkdir(parents=True, exist_ok=True)
91
571
  return path
92
- _DUMP_THREADS = []
572
# Background threads started by wrap_dump() when dump_in_thread=True
# (each one saves a single result to its .pkl file).
_DUMP_INTERMEDIATE_THREADS = []
93
573
  def wrap_dump(func: Callable, cache_dir: Path | None, dump_in_thread: bool = True):
94
574
  """Wrap a function so results are dumped to .pkl when cache_dir is set."""
95
575
  if cache_dir is None:
@@ -108,7 +588,7 @@ def wrap_dump(func: Callable, cache_dir: Path | None, dump_in_thread: bool = Tru
108
588
 
109
589
  if dump_in_thread:
110
590
  thread = threading.Thread(target=save)
111
- _DUMP_THREADS.append(thread)
591
+ _DUMP_INTERMEDIATE_THREADS.append(thread)
112
592
  # count thread
113
593
  # print(f'Thread count: {threading.active_count()}')
114
594
  while threading.active_count() > 16:
@@ -121,35 +601,213 @@ def wrap_dump(func: Callable, cache_dir: Path | None, dump_in_thread: bool = Tru
121
601
  return wrapped
122
602
 
123
603
 
604
+ _LOG_GATE_CACHE: dict[str, bool] = {}
605
+
606
+
607
+ def _should_allow_worker_logs(mode: Literal['all', 'zero', 'first'], gate_path: Path | None) -> bool:
608
+ """Determine if current worker should emit logs for the given mode."""
609
+ if mode == 'all':
610
+ return True
611
+ if mode == 'zero':
612
+ return False
613
+ if mode == 'first':
614
+ if gate_path is None:
615
+ return True
616
+ key = str(gate_path)
617
+ cached = _LOG_GATE_CACHE.get(key)
618
+ if cached is not None:
619
+ return cached
620
+ gate_path.parent.mkdir(parents=True, exist_ok=True)
621
+ try:
622
+ fd = os.open(key, os.O_CREAT | os.O_EXCL | os.O_WRONLY)
623
+ except FileExistsError:
624
+ allowed = False
625
+ else:
626
+ os.close(fd)
627
+ allowed = True
628
+ _LOG_GATE_CACHE[key] = allowed
629
+ return allowed
630
+ raise ValueError(f'Unsupported log mode: {mode!r}')
631
+
632
+
633
+ def _cleanup_log_gate(gate_path: Path | None):
634
+ if gate_path is None:
635
+ return
636
+ try:
637
+ gate_path.unlink(missing_ok=True)
638
+ except OSError:
639
+ pass
640
+
641
+
642
+ @contextlib.contextmanager
643
+ def _patch_fastcore_progress_bar(*, leave: bool = True):
644
+ """Temporarily force fastcore.progress_bar to keep the bar on screen."""
645
+ try:
646
+ import fastcore.parallel as _fp
647
+ except ImportError:
648
+ yield False
649
+ return
650
+
651
+ orig = getattr(_fp, 'progress_bar', None)
652
+ if orig is None:
653
+ yield False
654
+ return
655
+
656
+ def _wrapped(*args, **kwargs):
657
+ kwargs.setdefault('leave', leave)
658
+ return orig(*args, **kwargs)
659
+
660
+ _fp.progress_bar = _wrapped
661
+ try:
662
+ yield True
663
+ finally:
664
+ _fp.progress_bar = orig
665
+
666
+
667
+ class _PrefixedWriter:
668
+ """Stream wrapper that prefixes each line with worker id."""
669
+
670
+ def __init__(self, stream, prefix: str):
671
+ self._stream = stream
672
+ self._prefix = prefix
673
+ self._at_line_start = True
674
+
675
+ def write(self, s):
676
+ if not s:
677
+ return 0
678
+ total = 0
679
+ for chunk in s.splitlines(True):
680
+ if self._at_line_start:
681
+ self._stream.write(self._prefix)
682
+ total += len(self._prefix)
683
+ self._stream.write(chunk)
684
+ total += len(chunk)
685
+ self._at_line_start = chunk.endswith('\n')
686
+ return total
687
+
688
+ def flush(self):
689
+ self._stream.flush()
690
+
691
+
692
def _call_with_log_control(
    func: Callable,
    x: Any,
    func_kwargs: dict[str, Any],
    log_mode: Literal['all', 'zero', 'first'],
    gate_path: Path | None,
):
    """Run ``func(x, **func_kwargs)`` with stdout/stderr controlled by *log_mode*.

    When ``_should_allow_worker_logs`` permits output, both streams are sent
    to stderr with a ``[worker-<pid>] `` prefix, which keeps them off the
    tqdm bar on stdout. Otherwise both are discarded into ``os.devnull``.

    NOTE(review): redirect_stdout/redirect_stderr swap the process-global
    sys.stdout/sys.stderr. Concurrent calls from several threads of one
    process could restore each other's streams out of order. Confirm that
    the thread-pool backends do not route through here concurrently.
    """
    if not _should_allow_worker_logs(log_mode, gate_path):
        with open(os.devnull, 'w') as sink:
            with contextlib.redirect_stdout(sink), contextlib.redirect_stderr(sink):
                return func(x, **func_kwargs)

    tagged = _PrefixedWriter(sys.stderr, f"[worker-{os.getpid()}] ")
    with contextlib.redirect_stdout(tagged), contextlib.redirect_stderr(tagged):
        return func(x, **func_kwargs)
710
+
711
+
124
712
  # ─── ray management ─────────────────────────────────────────
125
713
 
126
714
# CPU/worker count recorded by the most recent ensure_ray() initialization
# (None until then); compared against the requested count before restarting.
RAY_WORKER = None


def ensure_ray(workers: int | None, pbar: tqdm | None = None, ray_metrics_port: int | None = None):
    """
    Initialize or reinitialize Ray safely for both local and cluster environments.

    1. Tries to connect to an existing cluster (address='auto') first.
    2. If no cluster is found, starts a local Ray instance with 'workers' CPUs.

    An already-initialized session is reused unless RESTART_RAY is "1"/"true".
    Even then, a cluster session (RAY_ADDRESS set or RAY_CLUSTER=1) is never
    shut down. A local session is only restarted when ``workers`` differs
    from ``RAY_WORKER``.

    Args:
        workers: CPUs for a local Ray instance. None means ``os.cpu_count()``
            for the local fallback; afterwards the cluster's total CPU count
            is recorded in ``RAY_WORKER``.
        pbar: Optional tqdm bar whose postfix receives status messages.
        ray_metrics_port: If given, exported as ``RAY_metrics_export_port``
            before initialization.
    """
    import logging

    import ray as _ray_module

    global RAY_WORKER

    requested_workers = workers
    if workers is None:
        workers = os.cpu_count() or 1

    if ray_metrics_port is not None:
        os.environ['RAY_metrics_export_port'] = str(ray_metrics_port)

    allow_restart = os.environ.get("RESTART_RAY", "0").lower() in ("1", "true")
    is_cluster_env = "RAY_ADDRESS" in os.environ or os.environ.get("RAY_CLUSTER") == "1"

    def _status(message: str) -> None:
        if pbar:
            pbar.set_postfix_str(message)

    # 1. Handle an existing session.
    if _ray_module.is_initialized():
        if not allow_restart:
            _status("Using existing Ray session")
            return
        # Avoid shutting down shared cluster sessions.
        if is_cluster_env:
            _status("Cluster active: skipping restart to protect connection")
            return
        # Local restart: only if the worker count changed.
        if workers == RAY_WORKER:
            return
        _status(f'Restarting local Ray with {workers} workers')
        _ray_module.shutdown()

    # 2. Initialization: try an existing cluster first, else start locally.
    t0 = time.time()
    init_kwargs = dict(
        ignore_reinit_error=True,
        logging_level=logging.ERROR,
        log_to_driver=False,
    )

    _status("Searching for Ray cluster...")
    try:
        # MUST NOT pass num_cpus/num_gpus here to avoid ValueError on existing clusters
        _ray_module.init(address="auto", **init_kwargs)
    except Exception:
        # 3. Fallback: start a local Ray session.
        _status(f"No cluster found. Starting local Ray ({workers} CPUs)...")
        _ray_module.init(num_cpus=workers, **init_kwargs)
        if pbar:
            took = time.time() - t0
            pbar.set_postfix_str(f'ray.init local {workers} took {took:.2f}s')
    else:
        # Connected. Status reporting is kept outside the `try` so that a failure
        # here is not mistaken for "no cluster" (which would re-init locally).
        if pbar:
            try:
                cpus = _ray_module.cluster_resources().get("CPU", 0)
                pbar.set_postfix_str(f"Connected to Ray Cluster ({int(cpus)} CPUs)")
            except Exception:
                pass  # display only

    _track_ray_processes()

    if requested_workers is None:
        try:
            total_cpus = int(_ray_module.cluster_resources().get("CPU", 0))
        except Exception:
            total_cpus = 0
        if total_cpus > 0:
            workers = total_cpus

    RAY_WORKER = workers
152
808
 
809
+
810
+ # NOTE: when shared_kwargs is used with a non-'ray' backend, multi_process now switches to 'ray' and emits a UserWarning instead of raising. TODO: make backend selection smarter in general.
153
811
  def multi_process(
154
812
  func: Callable[[Any], Any],
155
813
  items: Iterable[Any] | None = None,
@@ -158,11 +816,16 @@ def multi_process(
158
816
  workers: int | None = None,
159
817
  lazy_output: bool = False,
160
818
  progress: bool = True,
161
- # backend: str = "ray", # "seq", "ray", or "fastcore"
162
- backend: Literal['seq', 'ray', 'mp', 'threadpool', 'safe'] = 'mp',
819
+ backend: Literal['seq', 'ray', 'mp', 'safe'] = 'mp',
163
820
  desc: str | None = None,
164
821
  shared_kwargs: list[str] | None = None,
165
822
  dump_in_thread: bool = True,
823
+ ray_metrics_port: int | None = None,
824
+ log_worker: Literal['zero', 'first', 'all'] = 'first',
825
+ total_items: int | None = None,
826
+ poll_interval: float = 0.3,
827
+ error_handler: ErrorHandlerType = 'raise',
828
+ max_error_files: int = 100,
166
829
  **func_kwargs: Any,
167
830
  ) -> list[Any]:
168
831
  """
@@ -171,19 +834,44 @@ def multi_process(
171
834
  backend:
172
835
  - "seq": run sequentially
173
836
  - "ray": run in parallel with Ray
174
- - "mp": run in parallel with multiprocessing (uses threadpool to avoid fork warnings)
175
- - "threadpool": run in parallel with thread pool
837
+ - "mp": run in parallel with thread pool (uses ThreadPoolExecutor)
176
838
  - "safe": run in parallel with thread pool (explicitly safe for tests)
177
839
 
178
840
  shared_kwargs:
179
- - Optional list of kwarg names that should be shared via Ray's zero-copy object store
841
+ - Optional list of kwarg names that should be shared via Ray's
842
+ zero-copy object store
180
843
  - Only works with Ray backend
181
- - Useful for large objects (e.g., models, datasets) that should be shared across workers
182
- - Example: shared_kwargs=['model', 'tokenizer'] for sharing large ML models
844
+ - Useful for large objects (e.g., models, datasets)
845
+ - Example: shared_kwargs=['model', 'tokenizer']
183
846
 
184
847
  dump_in_thread:
185
848
  - Whether to dump results to disk in a separate thread (default: True)
186
- - If False, dumping is done synchronously, which may block but ensures data is saved before returning
849
+ - If False, dumping is done synchronously
850
+
851
+ ray_metrics_port:
852
+ - Optional port for Ray metrics export (Ray backend only)
853
+
854
+ log_worker:
855
+ - Control worker stdout/stderr noise
856
+ - 'first': only first worker emits logs (default)
857
+ - 'all': allow worker prints
858
+ - 'zero': silence all worker output
859
+
860
+ total_items:
861
+ - Optional item-level total for progress tracking (Ray backend only)
862
+
863
+ poll_interval:
864
+ - Poll interval in seconds for progress actor updates (Ray only)
865
+
866
+ error_handler:
867
+ - 'raise': raise exception on first error (default)
868
+ - 'ignore': continue processing, return None for failed items
869
+ - 'log': same as ignore, but logs errors to files
870
+
871
+ max_error_files:
872
+ - Maximum number of error log files to write (default: 100)
873
+ - Error logs are written to .cache/speedy_utils/error_logs/{idx}.log
874
+ - First error is always printed to screen with the log file path
187
875
 
188
876
  If lazy_output=True, every result is saved to .pkl and
189
877
  the returned list contains file paths.
@@ -219,11 +907,13 @@ def multi_process(
219
907
  f"Valid parameters: {valid_params}"
220
908
  )
221
909
 
222
- # Only allow shared_kwargs with Ray backend
223
- if backend != 'ray':
224
- raise ValueError(
225
- f"shared_kwargs only supported with 'ray' backend, got '{backend}'"
226
- )
910
+ # Prefer Ray backend when shared kwargs are requested
911
+ if shared_kwargs and backend != 'ray':
912
+ warnings.warn(
913
+ "shared_kwargs only supported with 'ray' backend, switching backend to 'ray'",
914
+ UserWarning,
915
+ )
916
+ backend = 'ray'
227
917
 
228
918
  # unify items
229
919
  # unify items and coerce to concrete list so we can use len() and
@@ -235,35 +925,111 @@ def multi_process(
235
925
  if items is None:
236
926
  raise ValueError("'items' or 'inputs' must be provided")
237
927
 
238
- if workers is None:
928
+ if workers is None and backend != 'ray':
239
929
  workers = os.cpu_count() or 1
240
930
 
241
931
  # build cache dir + wrap func
242
932
  cache_dir = _build_cache_dir(func, items) if lazy_output else None
243
933
  f_wrapped = wrap_dump(func, cache_dir, dump_in_thread)
244
934
 
935
+ log_gate_path: Path | None = None
936
+ if log_worker == 'first':
937
+ log_gate_path = Path(tempfile.gettempdir()) / f'speedy_utils_log_gate_{os.getpid()}_{uuid.uuid4().hex}.gate'
938
+ elif log_worker not in ('zero', 'all'):
939
+ raise ValueError(f'Unsupported log_worker: {log_worker!r}')
940
+
245
941
  total = len(items)
246
942
  if desc:
247
943
  desc = desc.strip() + f'[{backend}]'
248
944
  else:
249
945
  desc = f'Multi-process [{backend}]'
250
- with tqdm(
251
- total=total, desc=desc , disable=not progress
252
- ) as pbar:
253
- # ---- sequential backend ----
254
- if backend == 'seq':
255
- for x in items:
256
- results.append(f_wrapped(x, **func_kwargs))
946
+
947
+ # Initialize error stats for error handling
948
+ func_name = getattr(func, '__name__', repr(func))
949
+ error_stats = ErrorStats(
950
+ func_name=func_name,
951
+ max_error_files=max_error_files,
952
+ write_logs=error_handler == 'log'
953
+ )
954
+
955
+ def _update_pbar_postfix(pbar: tqdm) -> None:
956
+ """Update pbar with success/error counts."""
957
+ postfix = error_stats.get_postfix_dict()
958
+ pbar.set_postfix(postfix)
959
+
960
+ def _wrap_with_error_handler(
961
+ f: Callable,
962
+ idx: int,
963
+ input_value: Any,
964
+ error_stats_ref: ErrorStats,
965
+ handler: ErrorHandlerType,
966
+ ) -> Callable:
967
+ """Wrap function to handle errors based on error_handler setting."""
968
+ @functools.wraps(f)
969
+ def wrapper(*args, **kwargs):
970
+ try:
971
+ result = f(*args, **kwargs)
972
+ error_stats_ref.record_success()
973
+ return result
974
+ except Exception as e:
975
+ if handler == 'raise':
976
+ raise
977
+ error_stats_ref.record_error(idx, e, input_value, func_name)
978
+ return None
979
+ return wrapper
980
+
981
+ # ---- sequential backend ----
982
+ if backend == 'seq':
983
+ results: list[Any] = []
984
+ with tqdm(total=total, desc=desc, disable=not progress, file=sys.stdout) as pbar:
985
+ for idx, x in enumerate(items):
986
+ try:
987
+ result = _call_with_log_control(
988
+ f_wrapped,
989
+ x,
990
+ func_kwargs,
991
+ log_worker,
992
+ log_gate_path,
993
+ )
994
+ error_stats.record_success()
995
+ results.append(result)
996
+ except Exception as e:
997
+ if error_handler == 'raise':
998
+ raise
999
+ error_stats.record_error(idx, e, x, func_name)
1000
+ results.append(None)
257
1001
  pbar.update(1)
258
- return results
1002
+ _update_pbar_postfix(pbar)
1003
+ _cleanup_log_gate(log_gate_path)
1004
+ return results
259
1005
 
260
- # ---- ray backend ----
261
- if backend == 'ray':
262
- import ray as _ray_module
1006
+ # ---- ray backend ----
1007
+ if backend == 'ray':
1008
+ import ray as _ray_module
263
1009
 
264
- ensure_ray(workers, pbar)
1010
+ # Capture caller frame for better error reporting
1011
+ caller_frame = inspect.currentframe()
1012
+ caller_info = None
1013
+ if caller_frame and caller_frame.f_back:
1014
+ caller_info = {
1015
+ 'filename': caller_frame.f_back.f_code.co_filename,
1016
+ 'lineno': caller_frame.f_back.f_lineno,
1017
+ 'function': caller_frame.f_back.f_code.co_name,
1018
+ }
1019
+
1020
+ results = []
1021
+ gate_path_str = str(log_gate_path) if log_gate_path else None
1022
+ with tqdm(total=total, desc=desc, disable=not progress, file=sys.stdout) as pbar:
1023
+ ensure_ray(workers, pbar, ray_metrics_port)
265
1024
  shared_refs = {}
266
1025
  regular_kwargs = {}
1026
+
1027
+ # Create progress actor for item-level tracking if total_items specified
1028
+ progress_actor = None
1029
+ progress_poller = None
1030
+ if total_items is not None:
1031
+ progress_actor = create_progress_tracker(total_items, desc or "Items")
1032
+ shared_refs['progress_actor'] = progress_actor
267
1033
 
268
1034
  if shared_kwargs:
269
1035
  for kw in shared_kwargs:
@@ -283,60 +1049,193 @@ def multi_process(
283
1049
  def _task(x, shared_refs_dict, regular_kwargs_dict):
284
1050
  # Dereference shared objects (zero-copy for numpy arrays)
285
1051
  import ray as _ray_in_task
286
- dereferenced = {k: _ray_in_task.get(v) for k, v in shared_refs_dict.items()}
287
- # Merge with regular kwargs
1052
+ gate = Path(gate_path_str) if gate_path_str else None
1053
+ dereferenced = {}
1054
+ for k, v in shared_refs_dict.items():
1055
+ if k == 'progress_actor':
1056
+ dereferenced[k] = v
1057
+ else:
1058
+ dereferenced[k] = _ray_in_task.get(v)
288
1059
  all_kwargs = {**dereferenced, **regular_kwargs_dict}
289
- return f_wrapped(x, **all_kwargs)
1060
+ return _call_with_log_control(
1061
+ f_wrapped,
1062
+ x,
1063
+ all_kwargs,
1064
+ log_worker,
1065
+ gate,
1066
+ )
290
1067
 
291
1068
  refs = [
292
1069
  _task.remote(x, shared_refs, regular_kwargs) for x in items
293
1070
  ]
294
1071
 
295
- results = []
296
1072
  t_start = time.time()
297
- for r in refs:
298
- results.append(_ray_module.get(r))
299
- pbar.update(1)
1073
+
1074
+ if progress_actor is not None:
1075
+ pbar.total = total_items
1076
+ pbar.refresh()
1077
+ progress_poller = ProgressPoller(progress_actor, pbar, poll_interval)
1078
+ progress_poller.start()
1079
+
1080
+ for idx, r in enumerate(refs):
1081
+ try:
1082
+ result = _ray_module.get(r)
1083
+ error_stats.record_success()
1084
+ results.append(result)
1085
+ except _ray_module.exceptions.RayTaskError as e:
1086
+ if error_handler == 'raise':
1087
+ if progress_poller is not None:
1088
+ progress_poller.stop()
1089
+ _reraise_ray_error(e, pbar, caller_info)
1090
+ # Extract original error from RayTaskError
1091
+ cause = e.cause if hasattr(e, 'cause') else e.__cause__
1092
+ original_error = cause if cause else e
1093
+ error_stats.record_error(idx, original_error, items[idx], func_name)
1094
+ results.append(None)
1095
+
1096
+ if progress_actor is None:
1097
+ pbar.update(1)
1098
+ _update_pbar_postfix(pbar)
1099
+
1100
+ if progress_poller is not None:
1101
+ progress_poller.stop()
1102
+
300
1103
  t_end = time.time()
301
- print(f"Ray processing took {t_end - t_start:.2f}s for {total} items")
302
- return results
1104
+ item_desc = f"{total_items:,} items" if total_items else f"{total} tasks"
1105
+ print(f"Ray processing took {t_end - t_start:.2f}s for {item_desc}")
1106
+ _cleanup_log_gate(log_gate_path)
1107
+ return results
303
1108
 
304
- # ---- fastcore backend ----
305
- if backend == 'mp':
306
- results = parallel(
307
- f_wrapped, items, n_workers=workers, progress=progress, threadpool=False
308
- )
309
- _track_multiprocessing_processes() # Track multiprocessing workers
310
- _prune_dead_processes() # Clean up dead processes
311
- return list(results)
312
- if backend == 'threadpool':
313
- results = parallel(
314
- f_wrapped, items, n_workers=workers, progress=progress, threadpool=True
1109
+ # ---- fastcore/thread backend (mp) ----
1110
+ if backend == 'mp':
1111
+ import concurrent.futures
1112
+
1113
+ # Capture caller frame for better error reporting
1114
+ caller_frame = inspect.currentframe()
1115
+ caller_info = None
1116
+ if caller_frame and caller_frame.f_back:
1117
+ caller_info = {
1118
+ 'filename': caller_frame.f_back.f_code.co_filename,
1119
+ 'lineno': caller_frame.f_back.f_lineno,
1120
+ 'function': caller_frame.f_back.f_code.co_name,
1121
+ }
1122
+
1123
+ def worker_func(x):
1124
+ return _call_with_log_control(
1125
+ f_wrapped,
1126
+ x,
1127
+ func_kwargs,
1128
+ log_worker,
1129
+ log_gate_path,
315
1130
  )
316
- return list(results)
317
- if backend == 'safe':
318
- # Completely safe backend for tests - no multiprocessing, no external progress bars
319
- import concurrent.futures
320
-
321
- # Import thread tracking from thread module
1131
+
1132
+ results: list[Any] = [None] * total
1133
+ with tqdm(total=total, desc=desc, disable=not progress, file=sys.stdout) as pbar:
322
1134
  try:
323
1135
  from .thread import _prune_dead_threads, _track_executor_threads
324
-
325
- with concurrent.futures.ThreadPoolExecutor(
326
- max_workers=workers
327
- ) as executor:
328
- _track_executor_threads(executor) # Track threads
329
- results = list(executor.map(f_wrapped, items))
330
- _prune_dead_threads() # Clean up dead threads
1136
+ has_thread_tracking = True
331
1137
  except ImportError:
332
- # Fallback if thread module not available
333
- with concurrent.futures.ThreadPoolExecutor(
334
- max_workers=workers
335
- ) as executor:
336
- results = list(executor.map(f_wrapped, items))
337
- return results
1138
+ has_thread_tracking = False
1139
+
1140
+ with concurrent.futures.ThreadPoolExecutor(
1141
+ max_workers=workers
1142
+ ) as executor:
1143
+ if has_thread_tracking:
1144
+ _track_executor_threads(executor)
1145
+
1146
+ # Submit all tasks
1147
+ future_to_idx = {
1148
+ executor.submit(worker_func, x): idx
1149
+ for idx, x in enumerate(items)
1150
+ }
1151
+
1152
+ # Process results as they complete
1153
+ for future in concurrent.futures.as_completed(future_to_idx):
1154
+ idx = future_to_idx[future]
1155
+ try:
1156
+ result = future.result()
1157
+ error_stats.record_success()
1158
+ results[idx] = result
1159
+ except Exception as e:
1160
+ if error_handler == 'raise':
1161
+ # Cancel remaining futures
1162
+ for f in future_to_idx:
1163
+ f.cancel()
1164
+ _reraise_worker_error(e, pbar, caller_info, backend='mp')
1165
+ error_stats.record_error(idx, e, items[idx], func_name)
1166
+ results[idx] = None
1167
+ pbar.update(1)
1168
+ _update_pbar_postfix(pbar)
1169
+
1170
+ if _prune_dead_threads is not None:
1171
+ _prune_dead_threads()
1172
+
1173
+ _track_multiprocessing_processes()
1174
+ _prune_dead_processes()
1175
+ _cleanup_log_gate(log_gate_path)
1176
+ return results
1177
+
1178
+ if backend == 'safe':
1179
+ # Completely safe backend for tests - no multiprocessing
1180
+ import concurrent.futures
1181
+
1182
+ # Capture caller frame for better error reporting
1183
+ caller_frame = inspect.currentframe()
1184
+ caller_info = None
1185
+ if caller_frame and caller_frame.f_back:
1186
+ caller_info = {
1187
+ 'filename': caller_frame.f_back.f_code.co_filename,
1188
+ 'lineno': caller_frame.f_back.f_lineno,
1189
+ 'function': caller_frame.f_back.f_code.co_name,
1190
+ }
1191
+
1192
+ def worker_func(x):
1193
+ return _call_with_log_control(
1194
+ f_wrapped,
1195
+ x,
1196
+ func_kwargs,
1197
+ log_worker,
1198
+ log_gate_path,
1199
+ )
1200
+
1201
+ results: list[Any] = [None] * total
1202
+ with tqdm(total=total, desc=desc, disable=not progress, file=sys.stdout) as pbar:
1203
+ with concurrent.futures.ThreadPoolExecutor(
1204
+ max_workers=workers
1205
+ ) as executor:
1206
+ if _track_executor_threads is not None:
1207
+ _track_executor_threads(executor)
1208
+
1209
+ # Submit all tasks
1210
+ future_to_idx = {
1211
+ executor.submit(worker_func, x): idx
1212
+ for idx, x in enumerate(items)
1213
+ }
1214
+
1215
+ # Process results as they complete
1216
+ for future in concurrent.futures.as_completed(future_to_idx):
1217
+ idx = future_to_idx[future]
1218
+ try:
1219
+ result = future.result()
1220
+ error_stats.record_success()
1221
+ results[idx] = result
1222
+ except Exception as e:
1223
+ if error_handler == 'raise':
1224
+ for f in future_to_idx:
1225
+ f.cancel()
1226
+ _reraise_worker_error(e, pbar, caller_info, backend='safe')
1227
+ error_stats.record_error(idx, e, items[idx], func_name)
1228
+ results[idx] = None
1229
+ pbar.update(1)
1230
+ _update_pbar_postfix(pbar)
1231
+
1232
+ if _prune_dead_threads is not None:
1233
+ _prune_dead_threads()
1234
+
1235
+ _cleanup_log_gate(log_gate_path)
1236
+ return results
338
1237
 
339
- raise ValueError(f'Unsupported backend: {backend!r}')
1238
+ raise ValueError(f'Unsupported backend: {backend!r}')
340
1239
 
341
1240
 
342
1241
  def cleanup_phantom_workers():
@@ -394,6 +1293,10 @@ def cleanup_phantom_workers():
394
1293
 
395
1294
  __all__ = [
396
1295
  'SPEEDY_RUNNING_PROCESSES',
1296
+ 'ErrorStats',
1297
+ 'ErrorHandlerType',
397
1298
  'multi_process',
398
1299
  'cleanup_phantom_workers',
1300
+ 'create_progress_tracker',
1301
+ 'get_ray_progress_actor',
399
1302
  ]