speedy-utils 1.1.42__py3-none-any.whl → 1.1.44__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -9,11 +9,491 @@ os.environ["RAY_ACCEL_ENV_VAR_OVERRI" \
9
9
  "DE_ON_ZERO"] = "0"
10
10
  os.environ["RAY_DEDUP_LOGS"] = "1"
11
11
  os.environ["RAY_LOG_TO_STDERR"] = "0"
12
+ os.environ["RAY_IGNORE_UNHANDLED_ERRORS"] = "1"
12
13
 
13
14
  from ..__imports import *
14
15
  import tempfile
16
+ import inspect
17
+ import linecache
18
+ import traceback as tb_module
15
19
  from .progress import create_progress_tracker, ProgressPoller, get_ray_progress_actor
16
20
 
21
+ # Import thread tracking functions if available
22
+ try:
23
+ from .thread import _prune_dead_threads, _track_executor_threads
24
+ except ImportError:
25
+ _prune_dead_threads = None # type: ignore[assignment]
26
+ _track_executor_threads = None # type: ignore[assignment]
27
+
28
+
29
+ # ─── error handler types ────────────────────────────────────────
30
+ ErrorHandlerType = Literal['raise', 'ignore', 'log']
31
+
32
+
33
+ class ErrorStats:
34
+ """Thread-safe error statistics tracker."""
35
+
36
+ def __init__(
37
+ self,
38
+ func_name: str,
39
+ max_error_files: int = 100,
40
+ write_logs: bool = True
41
+ ):
42
+ self._lock = threading.Lock()
43
+ self._success_count = 0
44
+ self._error_count = 0
45
+ self._first_error_shown = False
46
+ self._max_error_files = max_error_files
47
+ self._write_logs = write_logs
48
+ self._error_log_dir = self._get_error_log_dir(func_name)
49
+ if write_logs:
50
+ self._error_log_dir.mkdir(parents=True, exist_ok=True)
51
+
52
+ @staticmethod
53
+ def _get_error_log_dir(func_name: str) -> Path:
54
+ """Generate unique error log directory with run counter."""
55
+ base_dir = Path('.cache/speedy_utils/error_logs')
56
+ base_dir.mkdir(parents=True, exist_ok=True)
57
+
58
+ # Find the next run counter
59
+ counter = 1
60
+ existing = list(base_dir.glob(f'{func_name}_run_*'))
61
+ if existing:
62
+ counters = []
63
+ for p in existing:
64
+ try:
65
+ parts = p.name.split('_run_')
66
+ if len(parts) == 2:
67
+ counters.append(int(parts[1]))
68
+ except (ValueError, IndexError):
69
+ pass
70
+ if counters:
71
+ counter = max(counters) + 1
72
+
73
+ return base_dir / f'{func_name}_run_{counter}'
74
+
75
+ def record_success(self) -> None:
76
+ with self._lock:
77
+ self._success_count += 1
78
+
79
+ def record_error(
80
+ self,
81
+ idx: int,
82
+ error: Exception,
83
+ input_value: Any,
84
+ func_name: str,
85
+ ) -> str | None:
86
+ """
87
+ Record an error and write to log file.
88
+ Returns the log file path if written, None otherwise.
89
+ """
90
+ with self._lock:
91
+ self._error_count += 1
92
+ should_show_first = not self._first_error_shown
93
+ if should_show_first:
94
+ self._first_error_shown = True
95
+ should_write = (
96
+ self._write_logs and self._error_count <= self._max_error_files
97
+ )
98
+
99
+ log_path = None
100
+ if should_write:
101
+ log_path = self._write_error_log(
102
+ idx, error, input_value, func_name
103
+ )
104
+
105
+ if should_show_first:
106
+ self._print_first_error(error, input_value, func_name, log_path)
107
+
108
+ return log_path
109
+
110
+ def _write_error_log(
111
+ self,
112
+ idx: int,
113
+ error: Exception,
114
+ input_value: Any,
115
+ func_name: str,
116
+ ) -> str:
117
+ """Write error details to a log file."""
118
+ from io import StringIO
119
+ from rich.console import Console
120
+
121
+ log_path = self._error_log_dir / f'{idx}.log'
122
+
123
+ output = StringIO()
124
+ console = Console(file=output, width=120, no_color=False)
125
+
126
+ # Format traceback
127
+ tb_lines = self._format_traceback(error)
128
+
129
+ console.print(f'{"=" * 60}')
130
+ console.print(f'Error at index: {idx}')
131
+ console.print(f'Function: {func_name}')
132
+ console.print(f'Error Type: {type(error).__name__}')
133
+ console.print(f'Error Message: {error}')
134
+ console.print(f'{"=" * 60}')
135
+ console.print('')
136
+ console.print('Input:')
137
+ console.print('-' * 40)
138
+ try:
139
+ import json
140
+ console.print(json.dumps(input_value, indent=2))
141
+ except Exception:
142
+ console.print(repr(input_value))
143
+ console.print('')
144
+ console.print('Traceback:')
145
+ console.print('-' * 40)
146
+ for line in tb_lines:
147
+ console.print(line)
148
+
149
+ with open(log_path, 'w') as f:
150
+ f.write(output.getvalue())
151
+
152
+ return str(log_path)
153
+
154
+ def _format_traceback(self, error: Exception) -> list[str]:
155
+ """Format traceback with context lines like Rich panel."""
156
+ lines = []
157
+ frames = _extract_frames_from_traceback(error)
158
+
159
+ for filepath, lineno, funcname, frame_locals in frames:
160
+ lines.append(f'│ {filepath}:{lineno} in {funcname} │')
161
+ lines.append('│' + ' ' * 70 + '│')
162
+
163
+ # Get context lines
164
+ context_size = 3
165
+ start_line = max(1, lineno - context_size)
166
+ end_line = lineno + context_size + 1
167
+
168
+ for line_num in range(start_line, end_line):
169
+ line_text = linecache.getline(filepath, line_num).rstrip()
170
+ if line_text:
171
+ num_str = str(line_num).rjust(4)
172
+ if line_num == lineno:
173
+ lines.append(f'│ {num_str} ❱ {line_text}')
174
+ else:
175
+ lines.append(f'│ {num_str} │ {line_text}')
176
+ lines.append('')
177
+
178
+ return lines
179
+
180
+ def _print_first_error(
181
+ self,
182
+ error: Exception,
183
+ input_value: Any,
184
+ func_name: str,
185
+ log_path: str | None,
186
+ ) -> None:
187
+ """Print the first error to screen with Rich formatting."""
188
+ try:
189
+ from rich.console import Console
190
+ from rich.panel import Panel
191
+ console = Console(stderr=True)
192
+
193
+ tb_lines = self._format_traceback(error)
194
+
195
+ console.print()
196
+ console.print(
197
+ Panel(
198
+ '\n'.join(tb_lines),
199
+ title='[bold red]First Error (continuing with remaining items)[/bold red]',
200
+ border_style='yellow',
201
+ expand=False,
202
+ )
203
+ )
204
+ console.print(
205
+ f'[bold red]{type(error).__name__}[/bold red]: {error}'
206
+ )
207
+ if log_path:
208
+ console.print(f'[dim]Error log: {log_path}[/dim]')
209
+ console.print()
210
+ except ImportError:
211
+ # Fallback to plain print
212
+ print(f'\n--- First Error ---', file=sys.stderr)
213
+ print(f'{type(error).__name__}: {error}', file=sys.stderr)
214
+ if log_path:
215
+ print(f'Error log: {log_path}', file=sys.stderr)
216
+ print('', file=sys.stderr)
217
+
218
+ @property
219
+ def success_count(self) -> int:
220
+ with self._lock:
221
+ return self._success_count
222
+
223
+ @property
224
+ def error_count(self) -> int:
225
+ with self._lock:
226
+ return self._error_count
227
+
228
+ def get_postfix_dict(self) -> dict[str, int]:
229
+ """Get dict for pbar postfix."""
230
+ with self._lock:
231
+ return {'ok': self._success_count, 'err': self._error_count}
232
+
233
+
234
+ def _should_skip_frame(filepath: str) -> bool:
235
+ """Check if a frame should be filtered from traceback display."""
236
+ skip_patterns = [
237
+ 'ray/_private',
238
+ 'ray/worker',
239
+ 'site-packages/ray',
240
+ 'speedy_utils/multi_worker',
241
+ 'concurrent/futures',
242
+ 'multiprocessing/',
243
+ 'fastcore/parallel',
244
+ 'fastcore/foundation',
245
+ 'fastcore/basics',
246
+ 'site-packages/fastcore',
247
+ '/threading.py',
248
+ '/concurrent/',
249
+ ]
250
+ return any(skip in filepath for skip in skip_patterns)
251
+
252
+
253
+ def _should_show_local(name: str, value: object) -> bool:
254
+ """Check if a local variable should be displayed in traceback."""
255
+ import types
256
+
257
+ # Skip dunder variables
258
+ if name.startswith('__') and name.endswith('__'):
259
+ return False
260
+
261
+ # Skip modules
262
+ if isinstance(value, types.ModuleType):
263
+ return False
264
+
265
+ # Skip type objects and classes
266
+ if isinstance(value, type):
267
+ return False
268
+
269
+ # Skip functions and methods
270
+ if isinstance(value, (types.FunctionType, types.MethodType, types.BuiltinFunctionType)):
271
+ return False
272
+
273
+ # Skip common typing aliases
274
+ value_str = str(value)
275
+ if value_str.startswith('typing.'):
276
+ return False
277
+
278
+ # Skip large objects that would clutter output
279
+ if value_str.startswith('<') and any(x in value_str for x in ['module', 'function', 'method', 'built-in']):
280
+ return False
281
+
282
+ return True
283
+
284
+
285
def _format_locals(frame_locals: dict) -> list[str]:
    """Format local variables as Rich-markup lines, filtering noisy entries.

    Returns an empty list when nothing is left after filtering. Each value is
    rendered with Rich's ``Pretty`` (falling back to ``repr``) and truncated
    to 100 characters. Names and values are escaped so that bracketed text in
    them is displayed literally instead of being parsed as markup.
    """
    from io import StringIO

    from rich.console import Console
    from rich.markup import escape
    from rich.pretty import Pretty

    # Filter locals
    clean_locals = {k: v for k, v in frame_locals.items() if _should_show_local(k, v)}

    if not clean_locals:
        return []

    max_len = 100
    lines = ['[dim]╭─ locals ─╮[/dim]']

    for name, value in clean_locals.items():
        try:
            buffer = StringIO()
            Console(file=buffer, width=60).print(Pretty(value), end='')
            value_str = buffer.getvalue().strip()
        except Exception:
            try:
                value_str = repr(value)
            except Exception:
                # Even repr() can fail on a misbehaving object; never let
                # that replace the error being reported.
                value_str = f'<unprintable {type(value).__name__}>'

        if len(value_str) > max_len:
            value_str = value_str[:max_len - 3] + '...'

        # The surrounding line is markup; the name and value are data.
        lines.append(f'[dim]│[/dim] {escape(str(name))} = {escape(value_str)}')

    lines.append('[dim]╰──────────╯[/dim]')
    return lines
319
+
320
+
321
def _format_frame_with_context(filepath: str, lineno: int, funcname: str, frame_locals: dict | None = None) -> list[str]:
    """Format a single frame as Rich-markup lines with source context.

    The output is a header line, up to three source lines on either side of
    ``lineno`` (the failing line is marked with '❱'), an optional locals
    section when ``frame_locals`` is non-empty, and a trailing blank line.

    File path, function name and source text are escaped: they are data, and
    unescaped code such as ``results[idx]`` would otherwise be parsed as a
    markup tag and shown garbled.
    """
    import linecache

    from rich.markup import escape

    # Frame header
    lines = [
        f'[cyan]{escape(filepath)}[/cyan]:[yellow]{lineno}[/yellow] '
        f'in [green]{escape(funcname)}[/green]',
        '',
    ]

    # Get context lines
    context_size = 3
    start_line = max(1, lineno - context_size)
    end_line = lineno + context_size + 1

    for line_num in range(start_line, end_line):
        line_text = linecache.getline(filepath, line_num).rstrip()
        if not line_text:
            # Blank lines and lines past the end of the file are omitted.
            continue
        num_str = str(line_num).rjust(4)
        code = escape(line_text)
        if line_num == lineno:
            lines.append(f'[dim]{num_str}[/dim] [red]❱[/red] {code}')
        else:
            lines.append(f'[dim]{num_str} │[/dim] {code}')

    # Add locals if available
    if frame_locals:
        locals_lines = _format_locals(frame_locals)
        if locals_lines:
            lines.append('')
            lines.extend(locals_lines)

    lines.append('')
    return lines
355
+
356
+
357
def _display_formatted_error(
    exc_type_name: str,
    exc_msg: str,
    frames: list[tuple[str, int, str, dict]],
    caller_info: dict | None,
    backend: str,
    pbar=None,
) -> None:
    """Print a formatted traceback panel to stderr and exit the process.

    Args:
        exc_type_name: Exception class name shown in the summary line.
        exc_msg: Exception message shown in the summary line.
        frames: ``(filepath, lineno, funcname, frame_locals)`` tuples, shown
            with their locals.
        caller_info: Optional dict with 'filename', 'lineno' and 'function'
            keys; shown first and without locals.
        backend: Backend label shown in the panel title.
        pbar: Optional progress bar; closed before anything is printed.

    Raises:
        SystemExit: Always, with exit code 1, once output has been flushed.
    """
    # Suppress additional error logs
    os.environ['RAY_IGNORE_UNHANDLED_ERRORS'] = '1'

    # Close progress bar cleanly if provided
    if pbar is not None:
        pbar.close()

    from rich.console import Console
    from rich.markup import escape
    from rich.panel import Panel
    console = Console(stderr=True)

    # Exception text and the backend label are data, not markup. Unescaped,
    # "[mp]" in the title is parsed as a style tag and never displayed, and
    # bracketed text in a message is swallowed the same way.
    summary = f'[bold red]{escape(exc_type_name)}[/bold red]: {escape(exc_msg)}'
    backend_label = escape(f'[{backend}]')

    if frames or caller_info:
        display_lines = []

        # Add caller frame first if available (no locals for caller)
        if caller_info:
            display_lines.extend(_format_frame_with_context(
                caller_info['filename'],
                caller_info['lineno'],
                caller_info['function'],
                None,
            ))

        # Add error frames with locals
        for filepath, lineno, funcname, frame_locals in frames:
            display_lines.extend(_format_frame_with_context(
                filepath, lineno, funcname, frame_locals
            ))

        # Display the traceback
        console.print()
        console.print(
            Panel(
                '\n'.join(display_lines),
                title=f'[bold red]Traceback (most recent call last) {backend_label}[/bold red]',
                border_style='red',
                expand=False,
            )
        )
    else:
        # No frames found, minimal output
        console.print()

    console.print(summary)
    console.print()

    # Ensure output is flushed
    sys.stderr.flush()
    sys.stdout.flush()
    sys.exit(1)
417
+
418
+
419
+ def _extract_frames_from_traceback(error: Exception) -> list[tuple[str, int, str, dict]]:
420
+ """Extract user frames from exception traceback object with locals."""
421
+ frames = []
422
+ if hasattr(error, '__traceback__') and error.__traceback__ is not None:
423
+ tb = error.__traceback__
424
+ while tb is not None:
425
+ frame = tb.tb_frame
426
+ filename = frame.f_code.co_filename
427
+ lineno = tb.tb_lineno
428
+ funcname = frame.f_code.co_name
429
+
430
+ if not _should_skip_frame(filename):
431
+ # Get local variables from the frame
432
+ frame_locals = dict(frame.f_locals)
433
+ frames.append((filename, lineno, funcname, frame_locals))
434
+
435
+ tb = tb.tb_next
436
+ return frames
437
+
438
+
439
+ def _extract_frames_from_ray_error(ray_task_error: Exception) -> list[tuple[str, int, str, dict]]:
440
+ """Extract user frames from Ray's string traceback representation."""
441
+ frames = []
442
+ error_str = str(ray_task_error)
443
+ lines = error_str.split('\n')
444
+
445
+ import re
446
+ for i, line in enumerate(lines):
447
+ # Match: File "path", line N, in func
448
+ file_match = re.match(r'\s*File "([^"]+)", line (\d+), in (.+)', line)
449
+ if file_match:
450
+ filepath, lineno, funcname = file_match.groups()
451
+ if not _should_skip_frame(filepath):
452
+ # Ray doesn't preserve locals, so use empty dict
453
+ frames.append((filepath, int(lineno), funcname, {}))
454
+
455
+ return frames
456
+
457
+
458
def _reraise_worker_error(error: Exception, pbar=None, caller_info=None, backend: str = 'unknown') -> None:
    """
    Report a worker exception through the formatted error display.
    Extracts the displayable frames from the exception's traceback and hands
    them, with the exception's type name and message, to
    _display_formatted_error (which, as defined in this file, ends by
    calling sys.exit(1) rather than re-raising).
    """
    user_frames = _extract_frames_from_traceback(error)
    type_name = type(error).__name__
    message = str(error)
    _display_formatted_error(
        type_name,
        message,
        user_frames,
        caller_info,
        backend,
        pbar,
    )
472
+
473
+
474
def _reraise_ray_error(ray_task_error: Exception, pbar=None, caller_info=None) -> None:
    """
    Report a Ray task error through the formatted error display.
    The underlying exception is taken from the error's `cause` attribute
    when present, otherwise from `__cause__`. Frames are parsed from the
    error's string form and passed to _display_formatted_error with
    backend 'ray' (that helper, as defined in this file, ends by calling
    sys.exit(1) rather than re-raising).
    """
    # Prefer the `cause` attribute, then the standard chained cause.
    cause = getattr(ray_task_error, 'cause', None)
    if cause is None:
        cause = ray_task_error.__cause__

    if cause:
        exc_type_name = type(cause).__name__
        exc_msg = str(cause)
    else:
        # No underlying exception available: describe the error itself.
        exc_type_name = 'Error'
        exc_msg = str(ray_task_error)

    parsed_frames = _extract_frames_from_ray_error(ray_task_error)
    _display_formatted_error(
        exc_type_name,
        exc_msg,
        parsed_frames,
        caller_info,
        'ray',
        pbar,
    )
496
+
17
497
 
18
498
  SPEEDY_RUNNING_PROCESSES: list[psutil.Process] = []
19
499
  _SPEEDY_PROCESSES_LOCK = threading.Lock()
@@ -96,7 +576,7 @@ def _build_cache_dir(func: Callable, items: list[Any]) -> Path:
96
576
  path = Path('.cache') / run_id
97
577
  path.mkdir(parents=True, exist_ok=True)
98
578
  return path
99
- _DUMP_THREADS = []
579
+ _DUMP_INTERMEDIATE_THREADS = []
100
580
  def wrap_dump(func: Callable, cache_dir: Path | None, dump_in_thread: bool = True):
101
581
  """Wrap a function so results are dumped to .pkl when cache_dir is set."""
102
582
  if cache_dir is None:
@@ -115,7 +595,7 @@ def wrap_dump(func: Callable, cache_dir: Path | None, dump_in_thread: bool = Tru
115
595
 
116
596
  if dump_in_thread:
117
597
  thread = threading.Thread(target=save)
118
- _DUMP_THREADS.append(thread)
598
+ _DUMP_INTERMEDIATE_THREADS.append(thread)
119
599
  # count thread
120
600
  # print(f'Thread count: {threading.active_count()}')
121
601
  while threading.active_count() > 16:
@@ -343,8 +823,7 @@ def multi_process(
343
823
  workers: int | None = None,
344
824
  lazy_output: bool = False,
345
825
  progress: bool = True,
346
- # backend: str = "ray", # "seq", "ray", or "fastcore"
347
- backend: Literal['seq', 'ray', 'mp', 'threadpool', 'safe'] = 'mp',
826
+ backend: Literal['seq', 'ray', 'mp', 'safe'] = 'mp',
348
827
  desc: str | None = None,
349
828
  shared_kwargs: list[str] | None = None,
350
829
  dump_in_thread: bool = True,
@@ -352,6 +831,8 @@ def multi_process(
352
831
  log_worker: Literal['zero', 'first', 'all'] = 'first',
353
832
  total_items: int | None = None,
354
833
  poll_interval: float = 0.3,
834
+ error_handler: ErrorHandlerType = 'log',
835
+ max_error_files: int = 100,
355
836
  **func_kwargs: Any,
356
837
  ) -> list[Any]:
357
838
  """
@@ -360,36 +841,44 @@ def multi_process(
360
841
  backend:
361
842
  - "seq": run sequentially
362
843
  - "ray": run in parallel with Ray
363
- - "mp": run in parallel with multiprocessing (uses threadpool to avoid fork warnings)
364
- - "threadpool": run in parallel with thread pool
844
+ - "mp": run in parallel with thread pool (uses ThreadPoolExecutor)
365
845
  - "safe": run in parallel with thread pool (explicitly safe for tests)
366
846
 
367
847
  shared_kwargs:
368
- - Optional list of kwarg names that should be shared via Ray's zero-copy object store
848
+ - Optional list of kwarg names that should be shared via Ray's
849
+ zero-copy object store
369
850
  - Only works with Ray backend
370
- - Useful for large objects (e.g., models, datasets) that should be shared across workers
371
- - Example: shared_kwargs=['model', 'tokenizer'] for sharing large ML models
851
+ - Useful for large objects (e.g., models, datasets)
852
+ - Example: shared_kwargs=['model', 'tokenizer']
372
853
 
373
854
  dump_in_thread:
374
855
  - Whether to dump results to disk in a separate thread (default: True)
375
- - If False, dumping is done synchronously, which may block but ensures data is saved before returning
856
+ - If False, dumping is done synchronously
376
857
 
377
858
  ray_metrics_port:
378
859
  - Optional port for Ray metrics export (Ray backend only)
379
- - Set to 0 to disable Ray metrics
380
- - If None, uses Ray's default behavior
381
860
 
382
861
  log_worker:
383
862
  - Control worker stdout/stderr noise
384
- - 'first': only first worker emits logs, others are silenced (default)
385
- - 'all': allow worker prints (may overlap tqdm)
386
- - 'zero': silence worker stdout/stderr to keep progress bar clean
863
+ - 'first': only first worker emits logs (default)
864
+ - 'all': allow worker prints
865
+ - 'zero': silence all worker output
387
866
 
388
867
  total_items:
389
868
  - Optional item-level total for progress tracking (Ray backend only)
390
869
 
391
870
  poll_interval:
392
- - Poll interval in seconds for progress actor updates (Ray backend only)
871
+ - Poll interval in seconds for progress actor updates (Ray only)
872
+
873
+ error_handler:
874
+ - 'raise': raise exception on first error
875
+ - 'ignore': continue processing, return None for failed items
876
+ - 'log': same as ignore, but logs errors to files (default)
877
+
878
+ max_error_files:
879
+ - Maximum number of error log files to write (default: 100)
880
+ - Error logs are written to .cache/speedy_utils/error_logs/{func_name}_run_{N}/{idx}.log
881
+ - First error is always printed to screen with the log file path
393
882
 
394
883
  If lazy_output=True, every result is saved to .pkl and
395
884
  the returned list contains file paths.
@@ -426,7 +915,7 @@ def multi_process(
426
915
  )
427
916
 
428
917
  # Prefer Ray backend when shared kwargs are requested
429
- if backend != 'ray':
918
+ if shared_kwargs and backend != 'ray':
430
919
  warnings.warn(
431
920
  "shared_kwargs only supported with 'ray' backend, switching backend to 'ray'",
432
921
  UserWarning,
@@ -462,21 +951,62 @@ def multi_process(
462
951
  else:
463
952
  desc = f'Multi-process [{backend}]'
464
953
 
954
+ # Initialize error stats for error handling
955
+ func_name = getattr(func, '__name__', repr(func))
956
+ error_stats = ErrorStats(
957
+ func_name=func_name,
958
+ max_error_files=max_error_files,
959
+ write_logs=error_handler == 'log'
960
+ )
961
+
962
+ def _update_pbar_postfix(pbar: tqdm) -> None:
963
+ """Update pbar with success/error counts."""
964
+ postfix = error_stats.get_postfix_dict()
965
+ pbar.set_postfix(postfix)
966
+
967
+ def _wrap_with_error_handler(
968
+ f: Callable,
969
+ idx: int,
970
+ input_value: Any,
971
+ error_stats_ref: ErrorStats,
972
+ handler: ErrorHandlerType,
973
+ ) -> Callable:
974
+ """Wrap function to handle errors based on error_handler setting."""
975
+ @functools.wraps(f)
976
+ def wrapper(*args, **kwargs):
977
+ try:
978
+ result = f(*args, **kwargs)
979
+ error_stats_ref.record_success()
980
+ return result
981
+ except Exception as e:
982
+ if handler == 'raise':
983
+ raise
984
+ error_stats_ref.record_error(idx, e, input_value, func_name)
985
+ return None
986
+ return wrapper
987
+
465
988
  # ---- sequential backend ----
466
989
  if backend == 'seq':
467
990
  results: list[Any] = []
468
991
  with tqdm(total=total, desc=desc, disable=not progress, file=sys.stdout) as pbar:
469
- for x in items:
470
- results.append(
471
- _call_with_log_control(
992
+ for idx, x in enumerate(items):
993
+ try:
994
+ result = _call_with_log_control(
472
995
  f_wrapped,
473
996
  x,
474
997
  func_kwargs,
475
998
  log_worker,
476
999
  log_gate_path,
477
1000
  )
478
- )
1001
+ error_stats.record_success()
1002
+ results.append(result)
1003
+ except Exception as e:
1004
+ if error_handler == 'raise':
1005
+ raise
1006
+ error_stats.record_error(idx, e, x, func_name)
1007
+ results.append(None)
479
1008
  pbar.update(1)
1009
+ _update_pbar_postfix(pbar)
480
1010
  _cleanup_log_gate(log_gate_path)
481
1011
  return results
482
1012
 
@@ -484,6 +1014,16 @@ def multi_process(
484
1014
  if backend == 'ray':
485
1015
  import ray as _ray_module
486
1016
 
1017
+ # Capture caller frame for better error reporting
1018
+ caller_frame = inspect.currentframe()
1019
+ caller_info = None
1020
+ if caller_frame and caller_frame.f_back:
1021
+ caller_info = {
1022
+ 'filename': caller_frame.f_back.f_code.co_filename,
1023
+ 'lineno': caller_frame.f_back.f_lineno,
1024
+ 'function': caller_frame.f_back.f_code.co_name,
1025
+ }
1026
+
487
1027
  results = []
488
1028
  gate_path_str = str(log_gate_path) if log_gate_path else None
489
1029
  with tqdm(total=total, desc=desc, disable=not progress, file=sys.stdout) as pbar:
@@ -496,7 +1036,7 @@ def multi_process(
496
1036
  progress_poller = None
497
1037
  if total_items is not None:
498
1038
  progress_actor = create_progress_tracker(total_items, desc or "Items")
499
- shared_refs['progress_actor'] = progress_actor # Pass actor handle directly (not via put)
1039
+ shared_refs['progress_actor'] = progress_actor
500
1040
 
501
1041
  if shared_kwargs:
502
1042
  for kw in shared_kwargs:
@@ -520,11 +1060,9 @@ def multi_process(
520
1060
  dereferenced = {}
521
1061
  for k, v in shared_refs_dict.items():
522
1062
  if k == 'progress_actor':
523
- # Pass actor handle directly (don't dereference)
524
1063
  dereferenced[k] = v
525
1064
  else:
526
1065
  dereferenced[k] = _ray_in_task.get(v)
527
- # Merge with regular kwargs
528
1066
  all_kwargs = {**dereferenced, **regular_kwargs_dict}
529
1067
  return _call_with_log_control(
530
1068
  f_wrapped,
@@ -540,21 +1078,32 @@ def multi_process(
540
1078
 
541
1079
  t_start = time.time()
542
1080
 
543
- # Start progress poller if using item-level progress
544
1081
  if progress_actor is not None:
545
- # Update pbar total to show items instead of tasks
546
1082
  pbar.total = total_items
547
1083
  pbar.refresh()
548
1084
  progress_poller = ProgressPoller(progress_actor, pbar, poll_interval)
549
1085
  progress_poller.start()
550
1086
 
551
- for r in refs:
552
- results.append(_ray_module.get(r))
1087
+ for idx, r in enumerate(refs):
1088
+ try:
1089
+ result = _ray_module.get(r)
1090
+ error_stats.record_success()
1091
+ results.append(result)
1092
+ except _ray_module.exceptions.RayTaskError as e:
1093
+ if error_handler == 'raise':
1094
+ if progress_poller is not None:
1095
+ progress_poller.stop()
1096
+ _reraise_ray_error(e, pbar, caller_info)
1097
+ # Extract original error from RayTaskError
1098
+ cause = e.cause if hasattr(e, 'cause') else e.__cause__
1099
+ original_error = cause if cause else e
1100
+ error_stats.record_error(idx, original_error, items[idx], func_name)
1101
+ results.append(None)
1102
+
553
1103
  if progress_actor is None:
554
- # Only update task-level progress if not using item-level
555
1104
  pbar.update(1)
1105
+ _update_pbar_postfix(pbar)
556
1106
 
557
- # Stop progress poller
558
1107
  if progress_poller is not None:
559
1108
  progress_poller.stop()
560
1109
 
@@ -564,73 +1113,132 @@ def multi_process(
564
1113
  _cleanup_log_gate(log_gate_path)
565
1114
  return results
566
1115
 
567
- # ---- fastcore backend ----
1116
+ # ---- fastcore/thread backend (mp) ----
568
1117
  if backend == 'mp':
569
- worker_func = functools.partial(
570
- _call_with_log_control,
571
- f_wrapped,
572
- func_kwargs=func_kwargs,
573
- log_mode=log_worker,
574
- gate_path=log_gate_path,
575
- )
576
- with _patch_fastcore_progress_bar(leave=True):
577
- results = list(parallel(
578
- worker_func, items, n_workers=workers, progress=progress, threadpool=False
579
- ))
580
- _track_multiprocessing_processes() # Track multiprocessing workers
581
- _prune_dead_processes() # Clean up dead processes
582
- _cleanup_log_gate(log_gate_path)
583
- return results
584
-
585
- if backend == 'threadpool':
586
- worker_func = functools.partial(
587
- _call_with_log_control,
588
- f_wrapped,
589
- func_kwargs=func_kwargs,
590
- log_mode=log_worker,
591
- gate_path=log_gate_path,
592
- )
593
- with _patch_fastcore_progress_bar(leave=True):
594
- results = list(parallel(
595
- worker_func, items, n_workers=workers, progress=progress, threadpool=True
596
- ))
597
- _cleanup_log_gate(log_gate_path)
598
- return results
599
-
600
- if backend == 'safe':
601
- # Completely safe backend for tests - no multiprocessing, no external progress bars
602
1118
  import concurrent.futures
603
-
604
- # Import thread tracking from thread module
605
- try:
606
- from .thread import _prune_dead_threads, _track_executor_threads
607
-
608
- worker_func = functools.partial(
609
- _call_with_log_control,
1119
+
1120
+ # Capture caller frame for better error reporting
1121
+ caller_frame = inspect.currentframe()
1122
+ caller_info = None
1123
+ if caller_frame and caller_frame.f_back:
1124
+ caller_info = {
1125
+ 'filename': caller_frame.f_back.f_code.co_filename,
1126
+ 'lineno': caller_frame.f_back.f_lineno,
1127
+ 'function': caller_frame.f_back.f_code.co_name,
1128
+ }
1129
+
1130
+ def worker_func(x):
1131
+ return _call_with_log_control(
610
1132
  f_wrapped,
611
- func_kwargs=func_kwargs,
612
- log_mode=log_worker,
613
- gate_path=log_gate_path,
1133
+ x,
1134
+ func_kwargs,
1135
+ log_worker,
1136
+ log_gate_path,
614
1137
  )
1138
+
1139
+ results: list[Any] = [None] * total
1140
+ with tqdm(total=total, desc=desc, disable=not progress, file=sys.stdout) as pbar:
1141
+ try:
1142
+ from .thread import _prune_dead_threads, _track_executor_threads
1143
+ has_thread_tracking = True
1144
+ except ImportError:
1145
+ has_thread_tracking = False
1146
+
615
1147
  with concurrent.futures.ThreadPoolExecutor(
616
1148
  max_workers=workers
617
1149
  ) as executor:
618
- _track_executor_threads(executor) # Track threads
619
- results = list(executor.map(worker_func, items))
620
- _prune_dead_threads() # Clean up dead threads
621
- except ImportError:
622
- # Fallback if thread module not available
623
- worker_func = functools.partial(
624
- _call_with_log_control,
1150
+ if has_thread_tracking:
1151
+ _track_executor_threads(executor)
1152
+
1153
+ # Submit all tasks
1154
+ future_to_idx = {
1155
+ executor.submit(worker_func, x): idx
1156
+ for idx, x in enumerate(items)
1157
+ }
1158
+
1159
+ # Process results as they complete
1160
+ for future in concurrent.futures.as_completed(future_to_idx):
1161
+ idx = future_to_idx[future]
1162
+ try:
1163
+ result = future.result()
1164
+ error_stats.record_success()
1165
+ results[idx] = result
1166
+ except Exception as e:
1167
+ if error_handler == 'raise':
1168
+ # Cancel remaining futures
1169
+ for f in future_to_idx:
1170
+ f.cancel()
1171
+ _reraise_worker_error(e, pbar, caller_info, backend='mp')
1172
+ error_stats.record_error(idx, e, items[idx], func_name)
1173
+ results[idx] = None
1174
+ pbar.update(1)
1175
+ _update_pbar_postfix(pbar)
1176
+
1177
+ if _prune_dead_threads is not None:
1178
+ _prune_dead_threads()
1179
+
1180
+ _track_multiprocessing_processes()
1181
+ _prune_dead_processes()
1182
+ _cleanup_log_gate(log_gate_path)
1183
+ return results
1184
+
1185
+ if backend == 'safe':
1186
+ # Completely safe backend for tests - no multiprocessing
1187
+ import concurrent.futures
1188
+
1189
+ # Capture caller frame for better error reporting
1190
+ caller_frame = inspect.currentframe()
1191
+ caller_info = None
1192
+ if caller_frame and caller_frame.f_back:
1193
+ caller_info = {
1194
+ 'filename': caller_frame.f_back.f_code.co_filename,
1195
+ 'lineno': caller_frame.f_back.f_lineno,
1196
+ 'function': caller_frame.f_back.f_code.co_name,
1197
+ }
1198
+
1199
+ def worker_func(x):
1200
+ return _call_with_log_control(
625
1201
  f_wrapped,
626
- func_kwargs=func_kwargs,
627
- log_mode=log_worker,
628
- gate_path=log_gate_path,
1202
+ x,
1203
+ func_kwargs,
1204
+ log_worker,
1205
+ log_gate_path,
629
1206
  )
1207
+
1208
+ results: list[Any] = [None] * total
1209
+ with tqdm(total=total, desc=desc, disable=not progress, file=sys.stdout) as pbar:
630
1210
  with concurrent.futures.ThreadPoolExecutor(
631
1211
  max_workers=workers
632
1212
  ) as executor:
633
- results = list(executor.map(worker_func, items))
1213
+ if _track_executor_threads is not None:
1214
+ _track_executor_threads(executor)
1215
+
1216
+ # Submit all tasks
1217
+ future_to_idx = {
1218
+ executor.submit(worker_func, x): idx
1219
+ for idx, x in enumerate(items)
1220
+ }
1221
+
1222
+ # Process results as they complete
1223
+ for future in concurrent.futures.as_completed(future_to_idx):
1224
+ idx = future_to_idx[future]
1225
+ try:
1226
+ result = future.result()
1227
+ error_stats.record_success()
1228
+ results[idx] = result
1229
+ except Exception as e:
1230
+ if error_handler == 'raise':
1231
+ for f in future_to_idx:
1232
+ f.cancel()
1233
+ _reraise_worker_error(e, pbar, caller_info, backend='safe')
1234
+ error_stats.record_error(idx, e, items[idx], func_name)
1235
+ results[idx] = None
1236
+ pbar.update(1)
1237
+ _update_pbar_postfix(pbar)
1238
+
1239
+ if _prune_dead_threads is not None:
1240
+ _prune_dead_threads()
1241
+
634
1242
  _cleanup_log_gate(log_gate_path)
635
1243
  return results
636
1244
 
@@ -692,6 +1300,8 @@ def cleanup_phantom_workers():
692
1300
 
693
1301
  __all__ = [
694
1302
  'SPEEDY_RUNNING_PROCESSES',
1303
+ 'ErrorStats',
1304
+ 'ErrorHandlerType',
695
1305
  'multi_process',
696
1306
  'cleanup_phantom_workers',
697
1307
  'create_progress_tracker',