speedy-utils 1.1.45__py3-none-any.whl → 1.1.47__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,1248 +1,67 @@
1
+ """
2
+ Multi-process map with selectable backends.
3
+
4
+ This module re-exports from the refactored implementation files:
5
+ - common.py: shared utilities (error formatting, log gating, cache, tracking)
6
+ - _multi_process.py: sequential + threadpool backends + dispatcher
7
+ - _multi_process_ray.py: Ray-specific backend
8
+
9
+ For backward compatibility, all public symbols are re-exported here.
10
+ """
1
11
  import warnings
2
12
  import os
13
+
3
14
  # Suppress Ray FutureWarnings before any imports
4
- warnings.filterwarnings("ignore", category=FutureWarning, module="ray.*")
5
- warnings.filterwarnings("ignore", message=".*pynvml.*deprecated.*", category=FutureWarning)
15
+ warnings.filterwarnings('ignore', category=FutureWarning, module='ray.*')
16
+ warnings.filterwarnings(
17
+ 'ignore',
18
+ message='.*pynvml.*deprecated.*',
19
+ category=FutureWarning,
20
+ )
6
21
 
7
22
  # Set environment variables before Ray is imported anywhere
8
- os.environ["RAY_ACCEL_ENV_VAR_OVERRI" \
9
- "DE_ON_ZERO"] = "0"
10
- os.environ["RAY_DEDUP_LOGS"] = "1"
11
- os.environ["RAY_LOG_TO_STDERR"] = "0"
12
- os.environ["RAY_IGNORE_UNHANDLED_ERRORS"] = "1"
13
-
14
- from ..__imports import *
15
- import tempfile
16
- import inspect
17
- import linecache
18
- import traceback as tb_module
19
- from .progress import create_progress_tracker, ProgressPoller, get_ray_progress_actor
20
-
21
- # Import thread tracking functions if available
23
+ os.environ['RAY_ACCEL_ENV_VAR_OVERRIDE_ON_ZERO'] = '0'
24
+ os.environ['RAY_DEDUP_LOGS'] = '1'
25
+ os.environ['RAY_LOG_TO_STDERR'] = '0'
26
+ os.environ['RAY_IGNORE_UNHANDLED_ERRORS'] = '1'
27
+
28
+ # Re-export public API from common
29
+ from .common import (
30
+ ErrorHandlerType,
31
+ ErrorStats,
32
+ SPEEDY_RUNNING_PROCESSES,
33
+ cleanup_phantom_workers,
34
+ )
35
+
36
+ # Re-export main dispatcher
37
+ from ._multi_process import multi_process
38
+
39
+ # Re-export Ray utilities (lazy import to avoid requiring Ray)
22
40
  try:
23
- from .thread import _prune_dead_threads, _track_executor_threads
41
+ from ._multi_process_ray import ensure_ray, RAY_WORKER
24
42
  except ImportError:
25
- _prune_dead_threads = None # type: ignore[assignment]
26
- _track_executor_threads = None # type: ignore[assignment]
27
-
28
-
29
- # ─── error handler types ────────────────────────────────────────
30
- ErrorHandlerType = Literal['raise', 'ignore', 'log']
31
-
32
-
33
- class ErrorStats:
34
- """Thread-safe error statistics tracker."""
35
-
36
- def __init__(
37
- self,
38
- func_name: str,
39
- max_error_files: int = 100,
40
- write_logs: bool = True
41
- ):
42
- self._lock = threading.Lock()
43
- self._success_count = 0
44
- self._error_count = 0
45
- self._first_error_shown = False
46
- self._max_error_files = max_error_files
47
- self._write_logs = write_logs
48
- self._error_log_dir = self._get_error_log_dir(func_name)
49
- if write_logs:
50
- self._error_log_dir.mkdir(parents=True, exist_ok=True)
51
-
52
- @staticmethod
53
- def _get_error_log_dir(func_name: str) -> Path:
54
- """Generate unique error log directory with run counter."""
55
- base_dir = Path('.cache/speedy_utils/error_logs')
56
- base_dir.mkdir(parents=True, exist_ok=True)
57
-
58
- # Find the next run counter
59
- counter = 1
60
- existing = list(base_dir.glob(f'{func_name}_run_*'))
61
- if existing:
62
- counters = []
63
- for p in existing:
64
- try:
65
- parts = p.name.split('_run_')
66
- if len(parts) == 2:
67
- counters.append(int(parts[1]))
68
- except (ValueError, IndexError):
69
- pass
70
- if counters:
71
- counter = max(counters) + 1
72
-
73
- return base_dir / f'{func_name}_run_{counter}'
74
-
75
- def record_success(self) -> None:
76
- with self._lock:
77
- self._success_count += 1
78
-
79
- def record_error(
80
- self,
81
- idx: int,
82
- error: Exception,
83
- input_value: Any,
84
- func_name: str,
85
- ) -> str | None:
86
- """
87
- Record an error and write to log file.
88
- Returns the log file path if written, None otherwise.
89
- """
90
- with self._lock:
91
- self._error_count += 1
92
- should_show_first = not self._first_error_shown
93
- if should_show_first:
94
- self._first_error_shown = True
95
- should_write = (
96
- self._write_logs and self._error_count <= self._max_error_files
97
- )
98
-
99
- log_path = None
100
- if should_write:
101
- log_path = self._write_error_log(
102
- idx, error, input_value, func_name
103
- )
104
-
105
- if should_show_first:
106
- self._print_first_error(error, input_value, func_name, log_path)
107
-
108
- return log_path
109
-
110
- def _write_error_log(
111
- self,
112
- idx: int,
113
- error: Exception,
114
- input_value: Any,
115
- func_name: str,
116
- ) -> str:
117
- """Write error details to a log file."""
118
- from io import StringIO
119
- from rich.console import Console
120
-
121
- log_path = self._error_log_dir / f'{idx}.log'
122
-
123
- output = StringIO()
124
- console = Console(file=output, width=120, no_color=False)
125
-
126
- # Format traceback
127
- tb_lines = self._format_traceback(error)
128
-
129
- console.print(f'{"=" * 60}')
130
- console.print(f'Error at index: {idx}')
131
- console.print(f'Function: {func_name}')
132
- console.print(f'Error Type: {type(error).__name__}')
133
- console.print(f'Error Message: {error}')
134
- console.print(f'{"=" * 60}')
135
- console.print('')
136
- console.print('Input:')
137
- console.print('-' * 40)
138
- try:
139
- import json
140
- console.print(json.dumps(input_value, indent=2))
141
- except Exception:
142
- console.print(repr(input_value))
143
- console.print('')
144
- console.print('Traceback:')
145
- console.print('-' * 40)
146
- for line in tb_lines:
147
- console.print(line)
148
-
149
- with open(log_path, 'w') as f:
150
- f.write(output.getvalue())
151
-
152
- return str(log_path)
153
-
154
- def _format_traceback(self, error: Exception) -> list[str]:
155
- """Format traceback with context lines like Rich panel."""
156
- lines = []
157
- frames = _extract_frames_from_traceback(error)
158
-
159
- for filepath, lineno, funcname, frame_locals in frames:
160
- lines.append(f'│ {filepath}:{lineno} in {funcname} │')
161
- lines.append('│' + ' ' * 70 + '│')
162
-
163
- # Get context lines
164
- context_size = 3
165
- start_line = max(1, lineno - context_size)
166
- end_line = lineno + context_size + 1
167
-
168
- for line_num in range(start_line, end_line):
169
- line_text = linecache.getline(filepath, line_num).rstrip()
170
- if line_text:
171
- num_str = str(line_num).rjust(4)
172
- if line_num == lineno:
173
- lines.append(f'│ {num_str} ❱ {line_text}')
174
- else:
175
- lines.append(f'│ {num_str} │ {line_text}')
176
- lines.append('')
177
-
178
- return lines
179
-
180
- def _print_first_error(
181
- self,
182
- error: Exception,
183
- input_value: Any,
184
- func_name: str,
185
- log_path: str | None,
186
- ) -> None:
187
- """Print the first error to screen with Rich formatting."""
188
- try:
189
- from rich.console import Console
190
- from rich.panel import Panel
191
- console = Console(stderr=True)
192
-
193
- tb_lines = self._format_traceback(error)
194
-
195
- console.print()
196
- console.print(
197
- Panel(
198
- '\n'.join(tb_lines),
199
- title='[bold red]First Error (continuing with remaining items)[/bold red]',
200
- border_style='yellow',
201
- expand=False,
202
- )
203
- )
204
- console.print(
205
- f'[bold red]{type(error).__name__}[/bold red]: {error}'
206
- )
207
- if log_path:
208
- console.print(f'[dim]Error log: {log_path}[/dim]')
209
- console.print()
210
- except ImportError:
211
- # Fallback to plain print
212
- print(f'\n--- First Error ---', file=sys.stderr)
213
- print(f'{type(error).__name__}: {error}', file=sys.stderr)
214
- if log_path:
215
- print(f'Error log: {log_path}', file=sys.stderr)
216
- print('', file=sys.stderr)
217
-
218
- @property
219
- def success_count(self) -> int:
220
- with self._lock:
221
- return self._success_count
222
-
223
- @property
224
- def error_count(self) -> int:
225
- with self._lock:
226
- return self._error_count
227
-
228
- def get_postfix_dict(self) -> dict[str, int]:
229
- """Get dict for pbar postfix."""
230
- with self._lock:
231
- return {'ok': self._success_count, 'err': self._error_count}
232
-
233
-
234
- def _should_skip_frame(filepath: str) -> bool:
235
- """Check if a frame should be filtered from traceback display."""
236
- skip_patterns = [
237
- 'ray/_private',
238
- 'ray/worker',
239
- 'site-packages/ray',
240
- 'speedy_utils/multi_worker',
241
- 'concurrent/futures',
242
- 'multiprocessing/',
243
- 'fastcore/parallel',
244
- 'fastcore/foundation',
245
- 'fastcore/basics',
246
- 'site-packages/fastcore',
247
- '/threading.py',
248
- '/concurrent/',
249
- ]
250
- return any(skip in filepath for skip in skip_patterns)
251
-
252
-
253
- def _should_show_local(name: str, value: object) -> bool:
254
- """Check if a local variable should be displayed in traceback."""
255
- import types
256
-
257
- # Skip dunder variables
258
- if name.startswith('__') and name.endswith('__'):
259
- return False
260
-
261
- # Skip modules
262
- if isinstance(value, types.ModuleType):
263
- return False
264
-
265
- # Skip type objects and classes
266
- if isinstance(value, type):
267
- return False
268
-
269
- # Skip functions and methods
270
- if isinstance(value, (types.FunctionType, types.MethodType, types.BuiltinFunctionType)):
271
- return False
272
-
273
- # Skip common typing aliases
274
- value_str = str(value)
275
- if value_str.startswith('typing.'):
276
- return False
277
-
278
- # Skip large objects that would clutter output
279
- if value_str.startswith('<') and any(x in value_str for x in ['module', 'function', 'method', 'built-in']):
280
- return False
281
-
282
- return True
283
-
284
-
285
- def _format_locals(frame_locals: dict) -> list[str]:
286
- """Format local variables for display, filtering out noisy imports."""
287
- from rich.pretty import Pretty
288
- from rich.console import Console
289
- from io import StringIO
290
-
291
- # Filter locals
292
- clean_locals = {k: v for k, v in frame_locals.items() if _should_show_local(k, v)}
293
-
294
- if not clean_locals:
295
- return []
296
-
297
- lines = []
298
- lines.append('[dim]╭─ locals ─╮[/dim]')
299
-
300
- # Format each local variable
301
- for name, value in clean_locals.items():
302
- # Use Rich's Pretty for nice formatting
303
- try:
304
- console = Console(file=StringIO(), width=60)
305
- console.print(Pretty(value), end='')
306
- value_str = console.file.getvalue().strip()
307
- # Limit length
308
- if len(value_str) > 100:
309
- value_str = value_str[:97] + '...'
310
- except Exception:
311
- value_str = repr(value)
312
- if len(value_str) > 100:
313
- value_str = value_str[:97] + '...'
314
-
315
- lines.append(f'[dim]│[/dim] {name} = {value_str}')
316
-
317
- lines.append('[dim]╰──────────╯[/dim]')
318
- return lines
319
-
320
-
321
- def _format_frame_with_context(filepath: str, lineno: int, funcname: str, frame_locals: dict | None = None) -> list[str]:
322
- """Format a single frame with context lines and optional locals."""
323
- lines = []
324
- # Frame header
325
- lines.append(
326
- f'[cyan]{filepath}[/cyan]:[yellow]{lineno}[/yellow] '
327
- f'in [green]{funcname}[/green]'
328
- )
329
- lines.append('')
330
-
331
- # Get context lines
332
- context_size = 3
333
- start_line = max(1, lineno - context_size)
334
- end_line = lineno + context_size + 1
335
-
336
- for line_num in range(start_line, end_line):
337
- import linecache
338
- line_text = linecache.getline(filepath, line_num).rstrip()
339
- if line_text:
340
- num_str = str(line_num).rjust(4)
341
- if line_num == lineno:
342
- lines.append(f'[dim]{num_str}[/dim] [red]❱[/red] {line_text}')
343
- else:
344
- lines.append(f'[dim]{num_str} │[/dim] {line_text}')
345
-
346
- # Add locals if available
347
- if frame_locals:
348
- locals_lines = _format_locals(frame_locals)
349
- if locals_lines:
350
- lines.append('')
351
- lines.extend(locals_lines)
352
-
353
- lines.append('')
354
- return lines
355
-
356
-
357
- def _display_formatted_error(
358
- exc_type_name: str,
359
- exc_msg: str,
360
- frames: list[tuple[str, int, str, dict]],
361
- caller_info: dict | None,
362
- backend: str,
363
- pbar=None,
364
- ) -> None:
365
-
366
- # Suppress additional error logs
367
- os.environ['RAY_IGNORE_UNHANDLED_ERRORS'] = '1'
368
-
369
- # Close progress bar cleanly if provided
370
- if pbar is not None:
371
- pbar.close()
372
-
373
- from rich.console import Console
374
- from rich.panel import Panel
375
- console = Console(stderr=True)
376
-
377
- if frames or caller_info:
378
- display_lines = []
379
-
380
- # Add caller frame first if available (no locals for caller)
381
- if caller_info:
382
- display_lines.extend(_format_frame_with_context(
383
- caller_info['filename'],
384
- caller_info['lineno'],
385
- caller_info['function'],
386
- None # Don't show locals for caller frame
387
- ))
388
-
389
- # Add error frames with locals
390
- for filepath, lineno, funcname, frame_locals in frames:
391
- display_lines.extend(_format_frame_with_context(
392
- filepath, lineno, funcname, frame_locals
393
- ))
394
-
395
- # Display the traceback
396
- console.print()
397
- console.print(
398
- Panel(
399
- '\n'.join(display_lines),
400
- title=f'[bold red]Traceback (most recent call last) [{backend}][/bold red]',
401
- border_style='red',
402
- expand=False,
403
- )
404
- )
405
- console.print(f'[bold red]{exc_type_name}[/bold red]: {exc_msg}')
406
- console.print()
407
- else:
408
- # No frames found, minimal output
409
- console.print()
410
- console.print(f'[bold red]{exc_type_name}[/bold red]: {exc_msg}')
411
- console.print()
412
-
413
- # Ensure output is flushed
414
- sys.stderr.flush()
415
- sys.stdout.flush()
416
- sys.exit(1)
417
-
418
-
419
- def _extract_frames_from_traceback(error: Exception) -> list[tuple[str, int, str, dict]]:
420
- """Extract user frames from exception traceback object with locals."""
421
- frames = []
422
- if hasattr(error, '__traceback__') and error.__traceback__ is not None:
423
- tb = error.__traceback__
424
- while tb is not None:
425
- frame = tb.tb_frame
426
- filename = frame.f_code.co_filename
427
- lineno = tb.tb_lineno
428
- funcname = frame.f_code.co_name
429
-
430
- if not _should_skip_frame(filename):
431
- # Get local variables from the frame
432
- frame_locals = dict(frame.f_locals)
433
- frames.append((filename, lineno, funcname, frame_locals))
434
-
435
- tb = tb.tb_next
436
- return frames
437
-
438
-
439
- def _extract_frames_from_ray_error(ray_task_error: Exception) -> list[tuple[str, int, str, dict]]:
440
- """Extract user frames from Ray's string traceback representation."""
441
- frames = []
442
- error_str = str(ray_task_error)
443
- lines = error_str.split('\n')
444
-
445
- import re
446
- for i, line in enumerate(lines):
447
- # Match: File "path", line N, in func
448
- file_match = re.match(r'\s*File "([^"]+)", line (\d+), in (.+)', line)
449
- if file_match:
450
- filepath, lineno, funcname = file_match.groups()
451
- if not _should_skip_frame(filepath):
452
- # Ray doesn't preserve locals, so use empty dict
453
- frames.append((filepath, int(lineno), funcname, {}))
454
-
455
- return frames
456
-
457
-
458
- def _reraise_worker_error(error: Exception, pbar=None, caller_info=None, backend: str = 'unknown') -> None:
459
- """
460
- Re-raise the original exception from a worker error with clean traceback.
461
- Works for multiprocessing, threadpool, and other backends with real tracebacks.
462
- """
463
- frames = _extract_frames_from_traceback(error)
464
- _display_formatted_error(
465
- exc_type_name=type(error).__name__,
466
- exc_msg=str(error),
467
- frames=frames,
468
- caller_info=caller_info,
469
- backend=backend,
470
- pbar=pbar,
471
- )
472
-
473
-
474
- def _reraise_ray_error(ray_task_error: Exception, pbar=None, caller_info=None) -> None:
475
- """
476
- Re-raise the original exception from a RayTaskError with clean traceback.
477
- Parses Ray's string traceback and displays with full context.
478
- """
479
- # Get the exception info
480
- cause = ray_task_error.cause if hasattr(ray_task_error, 'cause') else None
481
- if cause is None:
482
- cause = ray_task_error.__cause__
483
-
484
- exc_type_name = type(cause).__name__ if cause else 'Error'
485
- exc_msg = str(cause) if cause else str(ray_task_error)
486
-
487
- frames = _extract_frames_from_ray_error(ray_task_error)
488
- _display_formatted_error(
489
- exc_type_name=exc_type_name,
490
- exc_msg=exc_msg,
491
- frames=frames,
492
- caller_info=caller_info,
493
- backend='ray',
494
- pbar=pbar,
495
- )
496
-
497
-
498
- SPEEDY_RUNNING_PROCESSES: list[psutil.Process] = []
499
- _SPEEDY_PROCESSES_LOCK = threading.Lock()
500
-
501
- def _prune_dead_processes() -> None:
502
- """Remove dead processes from tracking list."""
503
- with _SPEEDY_PROCESSES_LOCK:
504
- SPEEDY_RUNNING_PROCESSES[:] = [
505
- p for p in SPEEDY_RUNNING_PROCESSES if p.is_running()
506
- ]
507
-
508
-
509
- def _track_processes(processes: list[psutil.Process]) -> None:
510
- """Add processes to global tracking list."""
511
- if not processes:
512
- return
513
- with _SPEEDY_PROCESSES_LOCK:
514
- living = [p for p in SPEEDY_RUNNING_PROCESSES if p.is_running()]
515
- for candidate in processes:
516
- if not candidate.is_running():
517
- continue
518
- if any(existing.pid == candidate.pid for existing in living):
519
- continue
520
- living.append(candidate)
521
- SPEEDY_RUNNING_PROCESSES[:] = living
522
-
523
-
524
- def _track_ray_processes() -> None:
525
- """Track Ray worker processes when Ray is initialized."""
526
-
527
- try:
528
- # Get Ray worker processes
529
- current_pid = os.getpid()
530
- parent = psutil.Process(current_pid)
531
- ray_processes = []
532
- for child in parent.children(recursive=True):
533
- try:
534
- if 'ray' in child.name().lower() or 'worker' in child.name().lower():
535
- ray_processes.append(child)
536
- except (psutil.NoSuchProcess, psutil.AccessDenied):
537
- continue
538
- _track_processes(ray_processes)
539
- except Exception:
540
- # Don't fail if process tracking fails
541
- pass
542
-
543
-
544
- def _track_multiprocessing_processes() -> None:
545
- """Track multiprocessing worker processes."""
546
- try:
547
- # Find recently created child processes that might be multiprocessing workers
548
- current_pid = os.getpid()
549
- parent = psutil.Process(current_pid)
550
- new_processes = []
551
- for child in parent.children(recursive=False): # Only direct children
552
- try:
553
- # Basic heuristic: if it's a recent child process, it might be a worker
554
- if (
555
- time.time() - child.create_time() < 5
556
- ): # Created within last 5 seconds
557
- new_processes.append(child)
558
- except (psutil.NoSuchProcess, psutil.AccessDenied):
559
- continue
560
- _track_processes(new_processes)
561
- except Exception:
562
- # Don't fail if process tracking fails
563
- pass
564
-
565
-
566
- # ─── cache helpers ──────────────────────────────────────────
567
-
568
-
569
- def _build_cache_dir(func: Callable, items: list[Any]) -> Path:
570
- """Build cache dir with function name + timestamp."""
571
- import datetime
572
- func_name = getattr(func, '__name__', 'func')
573
- now = datetime.datetime.now()
574
- stamp = now.strftime('%m%d_%Hh%Mm%Ss')
575
- run_id = f'{func_name}_{stamp}_{uuid.uuid4().hex[:6]}'
576
- path = Path('.cache') / run_id
577
- path.mkdir(parents=True, exist_ok=True)
578
- return path
579
- _DUMP_INTERMEDIATE_THREADS = []
580
- def wrap_dump(func: Callable, cache_dir: Path | None, dump_in_thread: bool = True):
581
- """Wrap a function so results are dumped to .pkl when cache_dir is set."""
582
- if cache_dir is None:
583
- return func
584
-
585
- def wrapped(x, *args, **kwargs):
586
- res = func(x, *args, **kwargs)
587
- p = cache_dir / f'{uuid.uuid4().hex}.pkl'
588
-
589
- def save():
590
- with open(p, 'wb') as fh:
591
- pickle.dump(res, fh)
592
- # Clean trash to avoid bloating memory
593
- # print(f'Thread count: {threading.active_count()}')
594
- # print(f'Saved result to {p}')
595
-
596
- if dump_in_thread:
597
- thread = threading.Thread(target=save)
598
- _DUMP_INTERMEDIATE_THREADS.append(thread)
599
- # count thread
600
- # print(f'Thread count: {threading.active_count()}')
601
- while threading.active_count() > 16:
602
- time.sleep(0.1)
603
- thread.start()
604
- else:
605
- save()
606
- return str(p)
607
-
608
- return wrapped
609
-
610
-
611
- _LOG_GATE_CACHE: dict[str, bool] = {}
612
-
613
-
614
- def _should_allow_worker_logs(mode: Literal['all', 'zero', 'first'], gate_path: Path | None) -> bool:
615
- """Determine if current worker should emit logs for the given mode."""
616
- if mode == 'all':
617
- return True
618
- if mode == 'zero':
619
- return False
620
- if mode == 'first':
621
- if gate_path is None:
622
- return True
623
- key = str(gate_path)
624
- cached = _LOG_GATE_CACHE.get(key)
625
- if cached is not None:
626
- return cached
627
- gate_path.parent.mkdir(parents=True, exist_ok=True)
628
- try:
629
- fd = os.open(key, os.O_CREAT | os.O_EXCL | os.O_WRONLY)
630
- except FileExistsError:
631
- allowed = False
632
- else:
633
- os.close(fd)
634
- allowed = True
635
- _LOG_GATE_CACHE[key] = allowed
636
- return allowed
637
- raise ValueError(f'Unsupported log mode: {mode!r}')
638
-
639
-
640
- def _cleanup_log_gate(gate_path: Path | None):
641
- if gate_path is None:
642
- return
643
- try:
644
- gate_path.unlink(missing_ok=True)
645
- except OSError:
646
- pass
647
-
648
-
649
- @contextlib.contextmanager
650
- def _patch_fastcore_progress_bar(*, leave: bool = True):
651
- """Temporarily force fastcore.progress_bar to keep the bar on screen."""
652
- try:
653
- import fastcore.parallel as _fp
654
- except ImportError:
655
- yield False
656
- return
657
-
658
- orig = getattr(_fp, 'progress_bar', None)
659
- if orig is None:
660
- yield False
661
- return
662
-
663
- def _wrapped(*args, **kwargs):
664
- kwargs.setdefault('leave', leave)
665
- return orig(*args, **kwargs)
666
-
667
- _fp.progress_bar = _wrapped
668
- try:
669
- yield True
670
- finally:
671
- _fp.progress_bar = orig
672
-
673
-
674
- class _PrefixedWriter:
675
- """Stream wrapper that prefixes each line with worker id."""
676
-
677
- def __init__(self, stream, prefix: str):
678
- self._stream = stream
679
- self._prefix = prefix
680
- self._at_line_start = True
681
-
682
- def write(self, s):
683
- if not s:
684
- return 0
685
- total = 0
686
- for chunk in s.splitlines(True):
687
- if self._at_line_start:
688
- self._stream.write(self._prefix)
689
- total += len(self._prefix)
690
- self._stream.write(chunk)
691
- total += len(chunk)
692
- self._at_line_start = chunk.endswith('\n')
693
- return total
694
-
695
- def flush(self):
696
- self._stream.flush()
697
-
698
-
699
- def _call_with_log_control(
700
- func: Callable,
701
- x: Any,
702
- func_kwargs: dict[str, Any],
703
- log_mode: Literal['all', 'zero', 'first'],
704
- gate_path: Path | None,
705
- ):
706
- """Call a function, silencing stdout/stderr based on log mode."""
707
- allow_logs = _should_allow_worker_logs(log_mode, gate_path)
708
- if allow_logs:
709
- prefix = f"[worker-{os.getpid()}] "
710
- # Route worker logs to stderr to reduce clobbering tqdm/progress output on stdout
711
- out = _PrefixedWriter(sys.stderr, prefix)
712
- err = out
713
- with contextlib.redirect_stdout(out), contextlib.redirect_stderr(err):
714
- return func(x, **func_kwargs)
715
- with open(os.devnull, 'w') as devnull, contextlib.redirect_stdout(devnull), contextlib.redirect_stderr(devnull):
716
- return func(x, **func_kwargs)
717
-
718
-
719
- # ─── ray management ─────────────────────────────────────────
720
-
721
- RAY_WORKER = None
722
-
723
-
724
- def ensure_ray(workers: int | None, pbar: tqdm | None = None, ray_metrics_port: int | None = None):
725
- """
726
- Initialize or reinitialize Ray safely for both local and cluster environments.
727
-
728
- 1. Tries to connect to an existing cluster (address='auto') first.
729
- 2. If no cluster is found, starts a local Ray instance with 'workers' CPUs.
730
- """
731
- import ray as _ray_module
732
- import logging
733
-
734
- global RAY_WORKER
735
- requested_workers = workers
736
- if workers is None:
737
- workers = os.cpu_count() or 1
738
-
739
- if ray_metrics_port is not None:
740
- os.environ['RAY_metrics_export_port'] = str(ray_metrics_port)
741
-
742
- allow_restart = os.environ.get("RESTART_RAY", "0").lower() in ("1", "true")
743
- is_cluster_env = "RAY_ADDRESS" in os.environ or os.environ.get("RAY_CLUSTER") == "1"
744
-
745
- # 1. Handle existing session
746
- if _ray_module.is_initialized():
747
- if not allow_restart:
748
- if pbar:
749
- pbar.set_postfix_str("Using existing Ray session")
750
- return
751
-
752
- # Avoid shutting down shared cluster sessions.
753
- if is_cluster_env:
754
- if pbar:
755
- pbar.set_postfix_str("Cluster active: skipping restart to protect connection")
756
- return
757
-
758
- # Local restart: only if worker count changed
759
- if workers != RAY_WORKER:
760
- if pbar:
761
- pbar.set_postfix_str(f'Restarting local Ray with {workers} workers')
762
- _ray_module.shutdown()
763
- else:
764
- return
765
-
766
- # 2. Initialization logic
767
- t0 = time.time()
768
-
769
- # Try to connect to existing cluster FIRST (address="auto")
770
- try:
771
- if pbar:
772
- pbar.set_postfix_str("Searching for Ray cluster...")
773
-
774
- # MUST NOT pass num_cpus/num_gpus here to avoid ValueError on existing clusters
775
- _ray_module.init(
776
- address="auto",
777
- ignore_reinit_error=True,
778
- logging_level=logging.ERROR,
779
- log_to_driver=False
780
- )
781
-
782
- if pbar:
783
- resources = _ray_module.cluster_resources()
784
- cpus = resources.get("CPU", 0)
785
- pbar.set_postfix_str(f"Connected to Ray Cluster ({int(cpus)} CPUs)")
786
-
787
- except Exception:
788
- # 3. Fallback: Start a local Ray session
789
- if pbar:
790
- pbar.set_postfix_str(f"No cluster found. Starting local Ray ({workers} CPUs)...")
791
-
792
- _ray_module.init(
793
- num_cpus=workers,
794
- ignore_reinit_error=True,
795
- logging_level=logging.ERROR,
796
- log_to_driver=False,
797
- )
798
-
799
- if pbar:
800
- took = time.time() - t0
801
- pbar.set_postfix_str(f'ray.init local {workers} took {took:.2f}s')
802
-
803
- _track_ray_processes()
804
-
805
- if requested_workers is None:
806
- try:
807
- resources = _ray_module.cluster_resources()
808
- total_cpus = int(resources.get("CPU", 0))
809
- if total_cpus > 0:
810
- workers = total_cpus
811
- except Exception:
812
- pass
813
-
814
- RAY_WORKER = workers
815
-
816
-
817
- # TODO: make smarter backend selection, when shared_kwargs is used, and backend != 'ray', do not raise error but change to ray and warning user about this
818
- def multi_process(
819
- func: Callable[[Any], Any],
820
- items: Iterable[Any] | None = None,
821
- *,
822
- inputs: Iterable[Any] | None = None,
823
- workers: int | None = None,
824
- lazy_output: bool = False,
825
- progress: bool = True,
826
- backend: Literal['seq', 'ray', 'mp', 'safe'] = 'mp',
827
- desc: str | None = None,
828
- shared_kwargs: list[str] | None = None,
829
- dump_in_thread: bool = True,
830
- ray_metrics_port: int | None = None,
831
- log_worker: Literal['zero', 'first', 'all'] = 'first',
832
- total_items: int | None = None,
833
- poll_interval: float = 0.3,
834
- error_handler: ErrorHandlerType = 'log',
835
- max_error_files: int = 100,
836
- **func_kwargs: Any,
837
- ) -> list[Any]:
838
- """
839
- Multi-process map with selectable backend.
840
-
841
- backend:
842
- - "seq": run sequentially
843
- - "ray": run in parallel with Ray
844
- - "mp": run in parallel with thread pool (uses ThreadPoolExecutor)
845
- - "safe": run in parallel with thread pool (explicitly safe for tests)
43
+ ensure_ray = None # type: ignore[assignment,misc]
44
+ RAY_WORKER = None # type: ignore[assignment,misc]
846
45
 
847
- shared_kwargs:
848
- - Optional list of kwarg names that should be shared via Ray's
849
- zero-copy object store
850
- - Only works with Ray backend
851
- - Useful for large objects (e.g., models, datasets)
852
- - Example: shared_kwargs=['model', 'tokenizer']
46
+ # Re-export progress utilities
47
+ from .progress import create_progress_tracker, get_ray_progress_actor
853
48
 
854
- dump_in_thread:
855
- - Whether to dump results to disk in a separate thread (default: True)
856
- - If False, dumping is done synchronously
49
+ # Re-export tqdm for backward compatibility
50
+ from tqdm import tqdm
857
51
 
858
- ray_metrics_port:
859
- - Optional port for Ray metrics export (Ray backend only)
860
52
 
861
- log_worker:
862
- - Control worker stdout/stderr noise
863
- - 'first': only first worker emits logs (default)
864
- - 'all': allow worker prints
865
- - 'zero': silence all worker output
866
-
867
- total_items:
868
- - Optional item-level total for progress tracking (Ray backend only)
869
-
870
- poll_interval:
871
- - Poll interval in seconds for progress actor updates (Ray only)
872
-
873
- error_handler:
874
- - 'raise': raise exception on first error (default)
875
- - 'ignore': continue processing, return None for failed items
876
- - 'log': same as ignore, but logs errors to files
877
-
878
- max_error_files:
879
- - Maximum number of error log files to write (default: 100)
880
- - Error logs are written to .cache/speedy_utils/error_logs/{idx}.log
881
- - First error is always printed to screen with the log file path
882
-
883
- If lazy_output=True, every result is saved to .pkl and
884
- the returned list contains file paths.
885
- """
886
-
887
- # default backend selection
888
- if backend is None:
889
- try:
890
- import ray as _ray_module
891
- backend = 'ray'
892
- except ImportError:
893
- backend = 'mp'
894
-
895
- # Validate shared_kwargs
896
- if shared_kwargs:
897
- # Validate that all shared_kwargs are valid kwargs for the function
898
- sig = inspect.signature(func)
899
- valid_params = set(sig.parameters.keys())
900
-
901
- for kw in shared_kwargs:
902
- if kw not in func_kwargs:
903
- raise ValueError(
904
- f"shared_kwargs key '{kw}' not found in provided func_kwargs"
905
- )
906
- # Check if parameter exists in function signature or if function accepts **kwargs
907
- has_var_keyword = any(
908
- p.kind == inspect.Parameter.VAR_KEYWORD
909
- for p in sig.parameters.values()
910
- )
911
- if kw not in valid_params and not has_var_keyword:
912
- raise ValueError(
913
- f"shared_kwargs key '{kw}' is not a valid parameter for function '{func.__name__}'. "
914
- f"Valid parameters: {valid_params}"
915
- )
916
-
917
- # Prefer Ray backend when shared kwargs are requested
918
- if shared_kwargs and backend != 'ray':
919
- warnings.warn(
920
- "shared_kwargs only supported with 'ray' backend, switching backend to 'ray'",
921
- UserWarning,
922
- )
923
- backend = 'ray'
924
-
925
- # unify items
926
- # unify items and coerce to concrete list so we can use len() and
927
- # iterate multiple times. This accepts ranges and other iterables.
928
- if items is None and inputs is not None:
929
- items = list(inputs)
930
- if items is not None and not isinstance(items, list):
931
- items = list(items)
932
- if items is None:
933
- raise ValueError("'items' or 'inputs' must be provided")
934
-
935
- if workers is None and backend != 'ray':
936
- workers = os.cpu_count() or 1
937
-
938
- # build cache dir + wrap func
939
- cache_dir = _build_cache_dir(func, items) if lazy_output else None
940
- f_wrapped = wrap_dump(func, cache_dir, dump_in_thread)
941
-
942
- log_gate_path: Path | None = None
943
- if log_worker == 'first':
944
- log_gate_path = Path(tempfile.gettempdir()) / f'speedy_utils_log_gate_{os.getpid()}_{uuid.uuid4().hex}.gate'
945
- elif log_worker not in ('zero', 'all'):
946
- raise ValueError(f'Unsupported log_worker: {log_worker!r}')
947
-
948
- total = len(items)
949
- if desc:
950
- desc = desc.strip() + f'[{backend}]'
951
- else:
952
- desc = f'Multi-process [{backend}]'
953
-
954
- # Initialize error stats for error handling
955
- func_name = getattr(func, '__name__', repr(func))
956
- error_stats = ErrorStats(
957
- func_name=func_name,
958
- max_error_files=max_error_files,
959
- write_logs=error_handler == 'log'
960
- )
961
-
962
- def _update_pbar_postfix(pbar: tqdm) -> None:
963
- """Update pbar with success/error counts."""
964
- postfix = error_stats.get_postfix_dict()
965
- pbar.set_postfix(postfix)
966
-
967
- def _wrap_with_error_handler(
968
- f: Callable,
969
- idx: int,
970
- input_value: Any,
971
- error_stats_ref: ErrorStats,
972
- handler: ErrorHandlerType,
973
- ) -> Callable:
974
- """Wrap function to handle errors based on error_handler setting."""
975
- @functools.wraps(f)
976
- def wrapper(*args, **kwargs):
977
- try:
978
- result = f(*args, **kwargs)
979
- error_stats_ref.record_success()
980
- return result
981
- except Exception as e:
982
- if handler == 'raise':
983
- raise
984
- error_stats_ref.record_error(idx, e, input_value, func_name)
985
- return None
986
- return wrapper
987
-
988
- # ---- sequential backend ----
989
- if backend == 'seq':
990
- results: list[Any] = []
991
- with tqdm(total=total, desc=desc, disable=not progress, file=sys.stdout) as pbar:
992
- for idx, x in enumerate(items):
993
- try:
994
- result = _call_with_log_control(
995
- f_wrapped,
996
- x,
997
- func_kwargs,
998
- log_worker,
999
- log_gate_path,
1000
- )
1001
- error_stats.record_success()
1002
- results.append(result)
1003
- except Exception as e:
1004
- if error_handler == 'raise':
1005
- raise
1006
- error_stats.record_error(idx, e, x, func_name)
1007
- results.append(None)
1008
- pbar.update(1)
1009
- _update_pbar_postfix(pbar)
1010
- _cleanup_log_gate(log_gate_path)
1011
- return results
1012
-
1013
- # ---- ray backend ----
1014
- if backend == 'ray':
1015
- import ray as _ray_module
1016
-
1017
- # Capture caller frame for better error reporting
1018
- caller_frame = inspect.currentframe()
1019
- caller_info = None
1020
- if caller_frame and caller_frame.f_back:
1021
- caller_info = {
1022
- 'filename': caller_frame.f_back.f_code.co_filename,
1023
- 'lineno': caller_frame.f_back.f_lineno,
1024
- 'function': caller_frame.f_back.f_code.co_name,
1025
- }
1026
-
1027
- results = []
1028
- gate_path_str = str(log_gate_path) if log_gate_path else None
1029
- with tqdm(total=total, desc=desc, disable=not progress, file=sys.stdout) as pbar:
1030
- ensure_ray(workers, pbar, ray_metrics_port)
1031
- shared_refs = {}
1032
- regular_kwargs = {}
1033
-
1034
- # Create progress actor for item-level tracking if total_items specified
1035
- progress_actor = None
1036
- progress_poller = None
1037
- if total_items is not None:
1038
- progress_actor = create_progress_tracker(total_items, desc or "Items")
1039
- shared_refs['progress_actor'] = progress_actor
1040
-
1041
- if shared_kwargs:
1042
- for kw in shared_kwargs:
1043
- # Put large objects in Ray's object store (zero-copy)
1044
- shared_refs[kw] = _ray_module.put(func_kwargs[kw])
1045
- pbar.set_postfix_str(f'ray: shared `{kw}` via object store')
1046
-
1047
- # Remaining kwargs are regular
1048
- regular_kwargs = {
1049
- k: v for k, v in func_kwargs.items()
1050
- if k not in shared_kwargs
1051
- }
1052
- else:
1053
- regular_kwargs = func_kwargs
1054
-
1055
- @_ray_module.remote
1056
- def _task(x, shared_refs_dict, regular_kwargs_dict):
1057
- # Dereference shared objects (zero-copy for numpy arrays)
1058
- import ray as _ray_in_task
1059
- gate = Path(gate_path_str) if gate_path_str else None
1060
- dereferenced = {}
1061
- for k, v in shared_refs_dict.items():
1062
- if k == 'progress_actor':
1063
- dereferenced[k] = v
1064
- else:
1065
- dereferenced[k] = _ray_in_task.get(v)
1066
- all_kwargs = {**dereferenced, **regular_kwargs_dict}
1067
- return _call_with_log_control(
1068
- f_wrapped,
1069
- x,
1070
- all_kwargs,
1071
- log_worker,
1072
- gate,
1073
- )
1074
-
1075
- refs = [
1076
- _task.remote(x, shared_refs, regular_kwargs) for x in items
1077
- ]
1078
-
1079
- t_start = time.time()
1080
-
1081
- if progress_actor is not None:
1082
- pbar.total = total_items
1083
- pbar.refresh()
1084
- progress_poller = ProgressPoller(progress_actor, pbar, poll_interval)
1085
- progress_poller.start()
1086
-
1087
- for idx, r in enumerate(refs):
1088
- try:
1089
- result = _ray_module.get(r)
1090
- error_stats.record_success()
1091
- results.append(result)
1092
- except _ray_module.exceptions.RayTaskError as e:
1093
- if error_handler == 'raise':
1094
- if progress_poller is not None:
1095
- progress_poller.stop()
1096
- _reraise_ray_error(e, pbar, caller_info)
1097
- # Extract original error from RayTaskError
1098
- cause = e.cause if hasattr(e, 'cause') else e.__cause__
1099
- original_error = cause if cause else e
1100
- error_stats.record_error(idx, original_error, items[idx], func_name)
1101
- results.append(None)
1102
-
1103
- if progress_actor is None:
1104
- pbar.update(1)
1105
- _update_pbar_postfix(pbar)
1106
-
1107
- if progress_poller is not None:
1108
- progress_poller.stop()
1109
-
1110
- t_end = time.time()
1111
- item_desc = f"{total_items:,} items" if total_items else f"{total} tasks"
1112
- print(f"Ray processing took {t_end - t_start:.2f}s for {item_desc}")
1113
- _cleanup_log_gate(log_gate_path)
1114
- return results
1115
-
1116
- # ---- fastcore/thread backend (mp) ----
1117
- if backend == 'mp':
1118
- import concurrent.futures
1119
-
1120
- # Capture caller frame for better error reporting
1121
- caller_frame = inspect.currentframe()
1122
- caller_info = None
1123
- if caller_frame and caller_frame.f_back:
1124
- caller_info = {
1125
- 'filename': caller_frame.f_back.f_code.co_filename,
1126
- 'lineno': caller_frame.f_back.f_lineno,
1127
- 'function': caller_frame.f_back.f_code.co_name,
1128
- }
1129
-
1130
- def worker_func(x):
1131
- return _call_with_log_control(
1132
- f_wrapped,
1133
- x,
1134
- func_kwargs,
1135
- log_worker,
1136
- log_gate_path,
1137
- )
1138
-
1139
- results: list[Any] = [None] * total
1140
- with tqdm(total=total, desc=desc, disable=not progress, file=sys.stdout) as pbar:
1141
- try:
1142
- from .thread import _prune_dead_threads, _track_executor_threads
1143
- has_thread_tracking = True
1144
- except ImportError:
1145
- has_thread_tracking = False
1146
-
1147
- with concurrent.futures.ThreadPoolExecutor(
1148
- max_workers=workers
1149
- ) as executor:
1150
- if has_thread_tracking:
1151
- _track_executor_threads(executor)
1152
-
1153
- # Submit all tasks
1154
- future_to_idx = {
1155
- executor.submit(worker_func, x): idx
1156
- for idx, x in enumerate(items)
1157
- }
1158
-
1159
- # Process results as they complete
1160
- for future in concurrent.futures.as_completed(future_to_idx):
1161
- idx = future_to_idx[future]
1162
- try:
1163
- result = future.result()
1164
- error_stats.record_success()
1165
- results[idx] = result
1166
- except Exception as e:
1167
- if error_handler == 'raise':
1168
- # Cancel remaining futures
1169
- for f in future_to_idx:
1170
- f.cancel()
1171
- _reraise_worker_error(e, pbar, caller_info, backend='mp')
1172
- error_stats.record_error(idx, e, items[idx], func_name)
1173
- results[idx] = None
1174
- pbar.update(1)
1175
- _update_pbar_postfix(pbar)
1176
-
1177
- if _prune_dead_threads is not None:
1178
- _prune_dead_threads()
1179
-
1180
- _track_multiprocessing_processes()
1181
- _prune_dead_processes()
1182
- _cleanup_log_gate(log_gate_path)
1183
- return results
1184
-
1185
- if backend == 'safe':
1186
- # Completely safe backend for tests - no multiprocessing
1187
- import concurrent.futures
1188
-
1189
- # Capture caller frame for better error reporting
1190
- caller_frame = inspect.currentframe()
1191
- caller_info = None
1192
- if caller_frame and caller_frame.f_back:
1193
- caller_info = {
1194
- 'filename': caller_frame.f_back.f_code.co_filename,
1195
- 'lineno': caller_frame.f_back.f_lineno,
1196
- 'function': caller_frame.f_back.f_code.co_name,
1197
- }
1198
-
1199
- def worker_func(x):
1200
- return _call_with_log_control(
1201
- f_wrapped,
1202
- x,
1203
- func_kwargs,
1204
- log_worker,
1205
- log_gate_path,
1206
- )
1207
-
1208
- results: list[Any] = [None] * total
1209
- with tqdm(total=total, desc=desc, disable=not progress, file=sys.stdout) as pbar:
1210
- with concurrent.futures.ThreadPoolExecutor(
1211
- max_workers=workers
1212
- ) as executor:
1213
- if _track_executor_threads is not None:
1214
- _track_executor_threads(executor)
1215
-
1216
- # Submit all tasks
1217
- future_to_idx = {
1218
- executor.submit(worker_func, x): idx
1219
- for idx, x in enumerate(items)
1220
- }
1221
-
1222
- # Process results as they complete
1223
- for future in concurrent.futures.as_completed(future_to_idx):
1224
- idx = future_to_idx[future]
1225
- try:
1226
- result = future.result()
1227
- error_stats.record_success()
1228
- results[idx] = result
1229
- except Exception as e:
1230
- if error_handler == 'raise':
1231
- for f in future_to_idx:
1232
- f.cancel()
1233
- _reraise_worker_error(e, pbar, caller_info, backend='safe')
1234
- error_stats.record_error(idx, e, items[idx], func_name)
1235
- results[idx] = None
1236
- pbar.update(1)
1237
- _update_pbar_postfix(pbar)
1238
-
1239
- if _prune_dead_threads is not None:
1240
- _prune_dead_threads()
1241
-
1242
- _cleanup_log_gate(log_gate_path)
1243
- return results
1244
-
1245
- raise ValueError(f'Unsupported backend: {backend!r}')
53
+ __all__ = [
54
+ 'SPEEDY_RUNNING_PROCESSES',
55
+ 'ErrorStats',
56
+ 'ErrorHandlerType',
57
+ 'multi_process',
58
+ 'cleanup_phantom_workers',
59
+ 'create_progress_tracker',
60
+ 'get_ray_progress_actor',
61
+ 'ensure_ray',
62
+ 'RAY_WORKER',
63
+ 'tqdm',
64
+ ]
1246
65
 
1247
66
 
1248
67
  def cleanup_phantom_workers():