speedy-utils 1.1.42__py3-none-any.whl → 1.1.44__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- speedy_utils/__init__.py +7 -0
- speedy_utils/multi_worker/process.py +695 -85
- speedy_utils/multi_worker/thread.py +202 -42
- {speedy_utils-1.1.42.dist-info → speedy_utils-1.1.44.dist-info}/METADATA +158 -9
- {speedy_utils-1.1.42.dist-info → speedy_utils-1.1.44.dist-info}/RECORD +7 -7
- {speedy_utils-1.1.42.dist-info → speedy_utils-1.1.44.dist-info}/WHEEL +0 -0
- {speedy_utils-1.1.42.dist-info → speedy_utils-1.1.44.dist-info}/entry_points.txt +0 -0
|
@@ -9,11 +9,491 @@ os.environ["RAY_ACCEL_ENV_VAR_OVERRI" \
|
|
|
9
9
|
"DE_ON_ZERO"] = "0"
|
|
10
10
|
os.environ["RAY_DEDUP_LOGS"] = "1"
|
|
11
11
|
os.environ["RAY_LOG_TO_STDERR"] = "0"
|
|
12
|
+
os.environ["RAY_IGNORE_UNHANDLED_ERRORS"] = "1"
|
|
12
13
|
|
|
13
14
|
from ..__imports import *
|
|
14
15
|
import tempfile
|
|
16
|
+
import inspect
|
|
17
|
+
import linecache
|
|
18
|
+
import traceback as tb_module
|
|
15
19
|
from .progress import create_progress_tracker, ProgressPoller, get_ray_progress_actor
|
|
16
20
|
|
|
21
|
+
# Import thread tracking functions if available
|
|
22
|
+
try:
|
|
23
|
+
from .thread import _prune_dead_threads, _track_executor_threads
|
|
24
|
+
except ImportError:
|
|
25
|
+
_prune_dead_threads = None # type: ignore[assignment]
|
|
26
|
+
_track_executor_threads = None # type: ignore[assignment]
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
# ─── error handler types ────────────────────────────────────────
|
|
30
|
+
ErrorHandlerType = Literal['raise', 'ignore', 'log']
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class ErrorStats:
|
|
34
|
+
"""Thread-safe error statistics tracker."""
|
|
35
|
+
|
|
36
|
+
def __init__(
|
|
37
|
+
self,
|
|
38
|
+
func_name: str,
|
|
39
|
+
max_error_files: int = 100,
|
|
40
|
+
write_logs: bool = True
|
|
41
|
+
):
|
|
42
|
+
self._lock = threading.Lock()
|
|
43
|
+
self._success_count = 0
|
|
44
|
+
self._error_count = 0
|
|
45
|
+
self._first_error_shown = False
|
|
46
|
+
self._max_error_files = max_error_files
|
|
47
|
+
self._write_logs = write_logs
|
|
48
|
+
self._error_log_dir = self._get_error_log_dir(func_name)
|
|
49
|
+
if write_logs:
|
|
50
|
+
self._error_log_dir.mkdir(parents=True, exist_ok=True)
|
|
51
|
+
|
|
52
|
+
@staticmethod
|
|
53
|
+
def _get_error_log_dir(func_name: str) -> Path:
|
|
54
|
+
"""Generate unique error log directory with run counter."""
|
|
55
|
+
base_dir = Path('.cache/speedy_utils/error_logs')
|
|
56
|
+
base_dir.mkdir(parents=True, exist_ok=True)
|
|
57
|
+
|
|
58
|
+
# Find the next run counter
|
|
59
|
+
counter = 1
|
|
60
|
+
existing = list(base_dir.glob(f'{func_name}_run_*'))
|
|
61
|
+
if existing:
|
|
62
|
+
counters = []
|
|
63
|
+
for p in existing:
|
|
64
|
+
try:
|
|
65
|
+
parts = p.name.split('_run_')
|
|
66
|
+
if len(parts) == 2:
|
|
67
|
+
counters.append(int(parts[1]))
|
|
68
|
+
except (ValueError, IndexError):
|
|
69
|
+
pass
|
|
70
|
+
if counters:
|
|
71
|
+
counter = max(counters) + 1
|
|
72
|
+
|
|
73
|
+
return base_dir / f'{func_name}_run_{counter}'
|
|
74
|
+
|
|
75
|
+
def record_success(self) -> None:
|
|
76
|
+
with self._lock:
|
|
77
|
+
self._success_count += 1
|
|
78
|
+
|
|
79
|
+
def record_error(
|
|
80
|
+
self,
|
|
81
|
+
idx: int,
|
|
82
|
+
error: Exception,
|
|
83
|
+
input_value: Any,
|
|
84
|
+
func_name: str,
|
|
85
|
+
) -> str | None:
|
|
86
|
+
"""
|
|
87
|
+
Record an error and write to log file.
|
|
88
|
+
Returns the log file path if written, None otherwise.
|
|
89
|
+
"""
|
|
90
|
+
with self._lock:
|
|
91
|
+
self._error_count += 1
|
|
92
|
+
should_show_first = not self._first_error_shown
|
|
93
|
+
if should_show_first:
|
|
94
|
+
self._first_error_shown = True
|
|
95
|
+
should_write = (
|
|
96
|
+
self._write_logs and self._error_count <= self._max_error_files
|
|
97
|
+
)
|
|
98
|
+
|
|
99
|
+
log_path = None
|
|
100
|
+
if should_write:
|
|
101
|
+
log_path = self._write_error_log(
|
|
102
|
+
idx, error, input_value, func_name
|
|
103
|
+
)
|
|
104
|
+
|
|
105
|
+
if should_show_first:
|
|
106
|
+
self._print_first_error(error, input_value, func_name, log_path)
|
|
107
|
+
|
|
108
|
+
return log_path
|
|
109
|
+
|
|
110
|
+
def _write_error_log(
|
|
111
|
+
self,
|
|
112
|
+
idx: int,
|
|
113
|
+
error: Exception,
|
|
114
|
+
input_value: Any,
|
|
115
|
+
func_name: str,
|
|
116
|
+
) -> str:
|
|
117
|
+
"""Write error details to a log file."""
|
|
118
|
+
from io import StringIO
|
|
119
|
+
from rich.console import Console
|
|
120
|
+
|
|
121
|
+
log_path = self._error_log_dir / f'{idx}.log'
|
|
122
|
+
|
|
123
|
+
output = StringIO()
|
|
124
|
+
console = Console(file=output, width=120, no_color=False)
|
|
125
|
+
|
|
126
|
+
# Format traceback
|
|
127
|
+
tb_lines = self._format_traceback(error)
|
|
128
|
+
|
|
129
|
+
console.print(f'{"=" * 60}')
|
|
130
|
+
console.print(f'Error at index: {idx}')
|
|
131
|
+
console.print(f'Function: {func_name}')
|
|
132
|
+
console.print(f'Error Type: {type(error).__name__}')
|
|
133
|
+
console.print(f'Error Message: {error}')
|
|
134
|
+
console.print(f'{"=" * 60}')
|
|
135
|
+
console.print('')
|
|
136
|
+
console.print('Input:')
|
|
137
|
+
console.print('-' * 40)
|
|
138
|
+
try:
|
|
139
|
+
import json
|
|
140
|
+
console.print(json.dumps(input_value, indent=2))
|
|
141
|
+
except Exception:
|
|
142
|
+
console.print(repr(input_value))
|
|
143
|
+
console.print('')
|
|
144
|
+
console.print('Traceback:')
|
|
145
|
+
console.print('-' * 40)
|
|
146
|
+
for line in tb_lines:
|
|
147
|
+
console.print(line)
|
|
148
|
+
|
|
149
|
+
with open(log_path, 'w') as f:
|
|
150
|
+
f.write(output.getvalue())
|
|
151
|
+
|
|
152
|
+
return str(log_path)
|
|
153
|
+
|
|
154
|
+
def _format_traceback(self, error: Exception) -> list[str]:
|
|
155
|
+
"""Format traceback with context lines like Rich panel."""
|
|
156
|
+
lines = []
|
|
157
|
+
frames = _extract_frames_from_traceback(error)
|
|
158
|
+
|
|
159
|
+
for filepath, lineno, funcname, frame_locals in frames:
|
|
160
|
+
lines.append(f'│ {filepath}:{lineno} in {funcname} │')
|
|
161
|
+
lines.append('│' + ' ' * 70 + '│')
|
|
162
|
+
|
|
163
|
+
# Get context lines
|
|
164
|
+
context_size = 3
|
|
165
|
+
start_line = max(1, lineno - context_size)
|
|
166
|
+
end_line = lineno + context_size + 1
|
|
167
|
+
|
|
168
|
+
for line_num in range(start_line, end_line):
|
|
169
|
+
line_text = linecache.getline(filepath, line_num).rstrip()
|
|
170
|
+
if line_text:
|
|
171
|
+
num_str = str(line_num).rjust(4)
|
|
172
|
+
if line_num == lineno:
|
|
173
|
+
lines.append(f'│ {num_str} ❱ {line_text}')
|
|
174
|
+
else:
|
|
175
|
+
lines.append(f'│ {num_str} │ {line_text}')
|
|
176
|
+
lines.append('')
|
|
177
|
+
|
|
178
|
+
return lines
|
|
179
|
+
|
|
180
|
+
def _print_first_error(
|
|
181
|
+
self,
|
|
182
|
+
error: Exception,
|
|
183
|
+
input_value: Any,
|
|
184
|
+
func_name: str,
|
|
185
|
+
log_path: str | None,
|
|
186
|
+
) -> None:
|
|
187
|
+
"""Print the first error to screen with Rich formatting."""
|
|
188
|
+
try:
|
|
189
|
+
from rich.console import Console
|
|
190
|
+
from rich.panel import Panel
|
|
191
|
+
console = Console(stderr=True)
|
|
192
|
+
|
|
193
|
+
tb_lines = self._format_traceback(error)
|
|
194
|
+
|
|
195
|
+
console.print()
|
|
196
|
+
console.print(
|
|
197
|
+
Panel(
|
|
198
|
+
'\n'.join(tb_lines),
|
|
199
|
+
title='[bold red]First Error (continuing with remaining items)[/bold red]',
|
|
200
|
+
border_style='yellow',
|
|
201
|
+
expand=False,
|
|
202
|
+
)
|
|
203
|
+
)
|
|
204
|
+
console.print(
|
|
205
|
+
f'[bold red]{type(error).__name__}[/bold red]: {error}'
|
|
206
|
+
)
|
|
207
|
+
if log_path:
|
|
208
|
+
console.print(f'[dim]Error log: {log_path}[/dim]')
|
|
209
|
+
console.print()
|
|
210
|
+
except ImportError:
|
|
211
|
+
# Fallback to plain print
|
|
212
|
+
print(f'\n--- First Error ---', file=sys.stderr)
|
|
213
|
+
print(f'{type(error).__name__}: {error}', file=sys.stderr)
|
|
214
|
+
if log_path:
|
|
215
|
+
print(f'Error log: {log_path}', file=sys.stderr)
|
|
216
|
+
print('', file=sys.stderr)
|
|
217
|
+
|
|
218
|
+
@property
|
|
219
|
+
def success_count(self) -> int:
|
|
220
|
+
with self._lock:
|
|
221
|
+
return self._success_count
|
|
222
|
+
|
|
223
|
+
@property
|
|
224
|
+
def error_count(self) -> int:
|
|
225
|
+
with self._lock:
|
|
226
|
+
return self._error_count
|
|
227
|
+
|
|
228
|
+
def get_postfix_dict(self) -> dict[str, int]:
|
|
229
|
+
"""Get dict for pbar postfix."""
|
|
230
|
+
with self._lock:
|
|
231
|
+
return {'ok': self._success_count, 'err': self._error_count}
|
|
232
|
+
|
|
233
|
+
|
|
234
|
+
def _should_skip_frame(filepath: str) -> bool:
|
|
235
|
+
"""Check if a frame should be filtered from traceback display."""
|
|
236
|
+
skip_patterns = [
|
|
237
|
+
'ray/_private',
|
|
238
|
+
'ray/worker',
|
|
239
|
+
'site-packages/ray',
|
|
240
|
+
'speedy_utils/multi_worker',
|
|
241
|
+
'concurrent/futures',
|
|
242
|
+
'multiprocessing/',
|
|
243
|
+
'fastcore/parallel',
|
|
244
|
+
'fastcore/foundation',
|
|
245
|
+
'fastcore/basics',
|
|
246
|
+
'site-packages/fastcore',
|
|
247
|
+
'/threading.py',
|
|
248
|
+
'/concurrent/',
|
|
249
|
+
]
|
|
250
|
+
return any(skip in filepath for skip in skip_patterns)
|
|
251
|
+
|
|
252
|
+
|
|
253
|
+
def _should_show_local(name: str, value: object) -> bool:
|
|
254
|
+
"""Check if a local variable should be displayed in traceback."""
|
|
255
|
+
import types
|
|
256
|
+
|
|
257
|
+
# Skip dunder variables
|
|
258
|
+
if name.startswith('__') and name.endswith('__'):
|
|
259
|
+
return False
|
|
260
|
+
|
|
261
|
+
# Skip modules
|
|
262
|
+
if isinstance(value, types.ModuleType):
|
|
263
|
+
return False
|
|
264
|
+
|
|
265
|
+
# Skip type objects and classes
|
|
266
|
+
if isinstance(value, type):
|
|
267
|
+
return False
|
|
268
|
+
|
|
269
|
+
# Skip functions and methods
|
|
270
|
+
if isinstance(value, (types.FunctionType, types.MethodType, types.BuiltinFunctionType)):
|
|
271
|
+
return False
|
|
272
|
+
|
|
273
|
+
# Skip common typing aliases
|
|
274
|
+
value_str = str(value)
|
|
275
|
+
if value_str.startswith('typing.'):
|
|
276
|
+
return False
|
|
277
|
+
|
|
278
|
+
# Skip large objects that would clutter output
|
|
279
|
+
if value_str.startswith('<') and any(x in value_str for x in ['module', 'function', 'method', 'built-in']):
|
|
280
|
+
return False
|
|
281
|
+
|
|
282
|
+
return True
|
|
283
|
+
|
|
284
|
+
|
|
285
|
+
def _format_locals(frame_locals: dict) -> list[str]:
|
|
286
|
+
"""Format local variables for display, filtering out noisy imports."""
|
|
287
|
+
from rich.pretty import Pretty
|
|
288
|
+
from rich.console import Console
|
|
289
|
+
from io import StringIO
|
|
290
|
+
|
|
291
|
+
# Filter locals
|
|
292
|
+
clean_locals = {k: v for k, v in frame_locals.items() if _should_show_local(k, v)}
|
|
293
|
+
|
|
294
|
+
if not clean_locals:
|
|
295
|
+
return []
|
|
296
|
+
|
|
297
|
+
lines = []
|
|
298
|
+
lines.append('[dim]╭─ locals ─╮[/dim]')
|
|
299
|
+
|
|
300
|
+
# Format each local variable
|
|
301
|
+
for name, value in clean_locals.items():
|
|
302
|
+
# Use Rich's Pretty for nice formatting
|
|
303
|
+
try:
|
|
304
|
+
console = Console(file=StringIO(), width=60)
|
|
305
|
+
console.print(Pretty(value), end='')
|
|
306
|
+
value_str = console.file.getvalue().strip()
|
|
307
|
+
# Limit length
|
|
308
|
+
if len(value_str) > 100:
|
|
309
|
+
value_str = value_str[:97] + '...'
|
|
310
|
+
except Exception:
|
|
311
|
+
value_str = repr(value)
|
|
312
|
+
if len(value_str) > 100:
|
|
313
|
+
value_str = value_str[:97] + '...'
|
|
314
|
+
|
|
315
|
+
lines.append(f'[dim]│[/dim] {name} = {value_str}')
|
|
316
|
+
|
|
317
|
+
lines.append('[dim]╰──────────╯[/dim]')
|
|
318
|
+
return lines
|
|
319
|
+
|
|
320
|
+
|
|
321
|
+
def _format_frame_with_context(filepath: str, lineno: int, funcname: str, frame_locals: dict | None = None) -> list[str]:
|
|
322
|
+
"""Format a single frame with context lines and optional locals."""
|
|
323
|
+
lines = []
|
|
324
|
+
# Frame header
|
|
325
|
+
lines.append(
|
|
326
|
+
f'[cyan]{filepath}[/cyan]:[yellow]{lineno}[/yellow] '
|
|
327
|
+
f'in [green]{funcname}[/green]'
|
|
328
|
+
)
|
|
329
|
+
lines.append('')
|
|
330
|
+
|
|
331
|
+
# Get context lines
|
|
332
|
+
context_size = 3
|
|
333
|
+
start_line = max(1, lineno - context_size)
|
|
334
|
+
end_line = lineno + context_size + 1
|
|
335
|
+
|
|
336
|
+
for line_num in range(start_line, end_line):
|
|
337
|
+
import linecache
|
|
338
|
+
line_text = linecache.getline(filepath, line_num).rstrip()
|
|
339
|
+
if line_text:
|
|
340
|
+
num_str = str(line_num).rjust(4)
|
|
341
|
+
if line_num == lineno:
|
|
342
|
+
lines.append(f'[dim]{num_str}[/dim] [red]❱[/red] {line_text}')
|
|
343
|
+
else:
|
|
344
|
+
lines.append(f'[dim]{num_str} │[/dim] {line_text}')
|
|
345
|
+
|
|
346
|
+
# Add locals if available
|
|
347
|
+
if frame_locals:
|
|
348
|
+
locals_lines = _format_locals(frame_locals)
|
|
349
|
+
if locals_lines:
|
|
350
|
+
lines.append('')
|
|
351
|
+
lines.extend(locals_lines)
|
|
352
|
+
|
|
353
|
+
lines.append('')
|
|
354
|
+
return lines
|
|
355
|
+
|
|
356
|
+
|
|
357
|
+
def _display_formatted_error(
|
|
358
|
+
exc_type_name: str,
|
|
359
|
+
exc_msg: str,
|
|
360
|
+
frames: list[tuple[str, int, str, dict]],
|
|
361
|
+
caller_info: dict | None,
|
|
362
|
+
backend: str,
|
|
363
|
+
pbar=None,
|
|
364
|
+
) -> None:
|
|
365
|
+
|
|
366
|
+
# Suppress additional error logs
|
|
367
|
+
os.environ['RAY_IGNORE_UNHANDLED_ERRORS'] = '1'
|
|
368
|
+
|
|
369
|
+
# Close progress bar cleanly if provided
|
|
370
|
+
if pbar is not None:
|
|
371
|
+
pbar.close()
|
|
372
|
+
|
|
373
|
+
from rich.console import Console
|
|
374
|
+
from rich.panel import Panel
|
|
375
|
+
console = Console(stderr=True)
|
|
376
|
+
|
|
377
|
+
if frames or caller_info:
|
|
378
|
+
display_lines = []
|
|
379
|
+
|
|
380
|
+
# Add caller frame first if available (no locals for caller)
|
|
381
|
+
if caller_info:
|
|
382
|
+
display_lines.extend(_format_frame_with_context(
|
|
383
|
+
caller_info['filename'],
|
|
384
|
+
caller_info['lineno'],
|
|
385
|
+
caller_info['function'],
|
|
386
|
+
None # Don't show locals for caller frame
|
|
387
|
+
))
|
|
388
|
+
|
|
389
|
+
# Add error frames with locals
|
|
390
|
+
for filepath, lineno, funcname, frame_locals in frames:
|
|
391
|
+
display_lines.extend(_format_frame_with_context(
|
|
392
|
+
filepath, lineno, funcname, frame_locals
|
|
393
|
+
))
|
|
394
|
+
|
|
395
|
+
# Display the traceback
|
|
396
|
+
console.print()
|
|
397
|
+
console.print(
|
|
398
|
+
Panel(
|
|
399
|
+
'\n'.join(display_lines),
|
|
400
|
+
title=f'[bold red]Traceback (most recent call last) [{backend}][/bold red]',
|
|
401
|
+
border_style='red',
|
|
402
|
+
expand=False,
|
|
403
|
+
)
|
|
404
|
+
)
|
|
405
|
+
console.print(f'[bold red]{exc_type_name}[/bold red]: {exc_msg}')
|
|
406
|
+
console.print()
|
|
407
|
+
else:
|
|
408
|
+
# No frames found, minimal output
|
|
409
|
+
console.print()
|
|
410
|
+
console.print(f'[bold red]{exc_type_name}[/bold red]: {exc_msg}')
|
|
411
|
+
console.print()
|
|
412
|
+
|
|
413
|
+
# Ensure output is flushed
|
|
414
|
+
sys.stderr.flush()
|
|
415
|
+
sys.stdout.flush()
|
|
416
|
+
sys.exit(1)
|
|
417
|
+
|
|
418
|
+
|
|
419
|
+
def _extract_frames_from_traceback(error: Exception) -> list[tuple[str, int, str, dict]]:
|
|
420
|
+
"""Extract user frames from exception traceback object with locals."""
|
|
421
|
+
frames = []
|
|
422
|
+
if hasattr(error, '__traceback__') and error.__traceback__ is not None:
|
|
423
|
+
tb = error.__traceback__
|
|
424
|
+
while tb is not None:
|
|
425
|
+
frame = tb.tb_frame
|
|
426
|
+
filename = frame.f_code.co_filename
|
|
427
|
+
lineno = tb.tb_lineno
|
|
428
|
+
funcname = frame.f_code.co_name
|
|
429
|
+
|
|
430
|
+
if not _should_skip_frame(filename):
|
|
431
|
+
# Get local variables from the frame
|
|
432
|
+
frame_locals = dict(frame.f_locals)
|
|
433
|
+
frames.append((filename, lineno, funcname, frame_locals))
|
|
434
|
+
|
|
435
|
+
tb = tb.tb_next
|
|
436
|
+
return frames
|
|
437
|
+
|
|
438
|
+
|
|
439
|
+
def _extract_frames_from_ray_error(ray_task_error: Exception) -> list[tuple[str, int, str, dict]]:
|
|
440
|
+
"""Extract user frames from Ray's string traceback representation."""
|
|
441
|
+
frames = []
|
|
442
|
+
error_str = str(ray_task_error)
|
|
443
|
+
lines = error_str.split('\n')
|
|
444
|
+
|
|
445
|
+
import re
|
|
446
|
+
for i, line in enumerate(lines):
|
|
447
|
+
# Match: File "path", line N, in func
|
|
448
|
+
file_match = re.match(r'\s*File "([^"]+)", line (\d+), in (.+)', line)
|
|
449
|
+
if file_match:
|
|
450
|
+
filepath, lineno, funcname = file_match.groups()
|
|
451
|
+
if not _should_skip_frame(filepath):
|
|
452
|
+
# Ray doesn't preserve locals, so use empty dict
|
|
453
|
+
frames.append((filepath, int(lineno), funcname, {}))
|
|
454
|
+
|
|
455
|
+
return frames
|
|
456
|
+
|
|
457
|
+
|
|
458
|
+
def _reraise_worker_error(error: Exception, pbar=None, caller_info=None, backend: str = 'unknown') -> None:
    """
    Report a worker exception with a cleaned-up traceback.

    Intended for backends whose exceptions carry a real traceback object.
    The user-code frames are pulled from ``error`` and handed, together with
    the exception's type name and message, to ``_display_formatted_error``,
    which does the printing.
    """
    user_frames = _extract_frames_from_traceback(error)
    error_type = type(error).__name__
    error_text = str(error)
    _display_formatted_error(
        error_type,
        error_text,
        user_frames,
        caller_info,
        backend,
        pbar,
    )
|
|
472
|
+
|
|
473
|
+
|
|
474
|
+
def _reraise_ray_error(ray_task_error: Exception, pbar=None, caller_info=None) -> None:
    """
    Report the underlying exception of a Ray task error with a clean traceback.

    The original exception is looked up on ``.cause`` first and then on
    ``__cause__``. Frames are parsed from the error's text form, and the
    result is handed to ``_display_formatted_error`` with backend ``'ray'``.
    """
    # Prefer the explicit `.cause` attribute; fall back to the chained cause.
    cause = getattr(ray_task_error, 'cause', None)
    if cause is None:
        cause = ray_task_error.__cause__

    if cause:
        exc_type_name = type(cause).__name__
        exc_msg = str(cause)
    else:
        # No underlying exception available: describe the wrapper itself.
        exc_type_name = 'Error'
        exc_msg = str(ray_task_error)

    parsed_frames = _extract_frames_from_ray_error(ray_task_error)
    _display_formatted_error(
        exc_type_name,
        exc_msg,
        parsed_frames,
        caller_info,
        'ray',
        pbar,
    )
|
|
496
|
+
|
|
17
497
|
|
|
18
498
|
SPEEDY_RUNNING_PROCESSES: list[psutil.Process] = []
|
|
19
499
|
_SPEEDY_PROCESSES_LOCK = threading.Lock()
|
|
@@ -96,7 +576,7 @@ def _build_cache_dir(func: Callable, items: list[Any]) -> Path:
|
|
|
96
576
|
path = Path('.cache') / run_id
|
|
97
577
|
path.mkdir(parents=True, exist_ok=True)
|
|
98
578
|
return path
|
|
99
|
-
|
|
579
|
+
_DUMP_INTERMEDIATE_THREADS = []
|
|
100
580
|
def wrap_dump(func: Callable, cache_dir: Path | None, dump_in_thread: bool = True):
|
|
101
581
|
"""Wrap a function so results are dumped to .pkl when cache_dir is set."""
|
|
102
582
|
if cache_dir is None:
|
|
@@ -115,7 +595,7 @@ def wrap_dump(func: Callable, cache_dir: Path | None, dump_in_thread: bool = Tru
|
|
|
115
595
|
|
|
116
596
|
if dump_in_thread:
|
|
117
597
|
thread = threading.Thread(target=save)
|
|
118
|
-
|
|
598
|
+
_DUMP_INTERMEDIATE_THREADS.append(thread)
|
|
119
599
|
# count thread
|
|
120
600
|
# print(f'Thread count: {threading.active_count()}')
|
|
121
601
|
while threading.active_count() > 16:
|
|
@@ -343,8 +823,7 @@ def multi_process(
|
|
|
343
823
|
workers: int | None = None,
|
|
344
824
|
lazy_output: bool = False,
|
|
345
825
|
progress: bool = True,
|
|
346
|
-
|
|
347
|
-
backend: Literal['seq', 'ray', 'mp', 'threadpool', 'safe'] = 'mp',
|
|
826
|
+
backend: Literal['seq', 'ray', 'mp', 'safe'] = 'mp',
|
|
348
827
|
desc: str | None = None,
|
|
349
828
|
shared_kwargs: list[str] | None = None,
|
|
350
829
|
dump_in_thread: bool = True,
|
|
@@ -352,6 +831,8 @@ def multi_process(
|
|
|
352
831
|
log_worker: Literal['zero', 'first', 'all'] = 'first',
|
|
353
832
|
total_items: int | None = None,
|
|
354
833
|
poll_interval: float = 0.3,
|
|
834
|
+
error_handler: ErrorHandlerType = 'log',
|
|
835
|
+
max_error_files: int = 100,
|
|
355
836
|
**func_kwargs: Any,
|
|
356
837
|
) -> list[Any]:
|
|
357
838
|
"""
|
|
@@ -360,36 +841,44 @@ def multi_process(
|
|
|
360
841
|
backend:
|
|
361
842
|
- "seq": run sequentially
|
|
362
843
|
- "ray": run in parallel with Ray
|
|
363
|
-
- "mp": run in parallel with
|
|
364
|
-
- "threadpool": run in parallel with thread pool
|
|
844
|
+
- "mp": run in parallel with thread pool (uses ThreadPoolExecutor)
|
|
365
845
|
- "safe": run in parallel with thread pool (explicitly safe for tests)
|
|
366
846
|
|
|
367
847
|
shared_kwargs:
|
|
368
|
-
- Optional list of kwarg names that should be shared via Ray's
|
|
848
|
+
- Optional list of kwarg names that should be shared via Ray's
|
|
849
|
+
zero-copy object store
|
|
369
850
|
- Only works with Ray backend
|
|
370
|
-
- Useful for large objects (e.g., models, datasets)
|
|
371
|
-
- Example: shared_kwargs=['model', 'tokenizer']
|
|
851
|
+
- Useful for large objects (e.g., models, datasets)
|
|
852
|
+
- Example: shared_kwargs=['model', 'tokenizer']
|
|
372
853
|
|
|
373
854
|
dump_in_thread:
|
|
374
855
|
- Whether to dump results to disk in a separate thread (default: True)
|
|
375
|
-
- If False, dumping is done synchronously
|
|
856
|
+
- If False, dumping is done synchronously
|
|
376
857
|
|
|
377
858
|
ray_metrics_port:
|
|
378
859
|
- Optional port for Ray metrics export (Ray backend only)
|
|
379
|
-
- Set to 0 to disable Ray metrics
|
|
380
|
-
- If None, uses Ray's default behavior
|
|
381
860
|
|
|
382
861
|
log_worker:
|
|
383
862
|
- Control worker stdout/stderr noise
|
|
384
|
-
- 'first': only first worker emits logs
|
|
385
|
-
- 'all': allow worker prints
|
|
386
|
-
- 'zero': silence worker
|
|
863
|
+
- 'first': only first worker emits logs (default)
|
|
864
|
+
- 'all': allow worker prints
|
|
865
|
+
- 'zero': silence all worker output
|
|
387
866
|
|
|
388
867
|
total_items:
|
|
389
868
|
- Optional item-level total for progress tracking (Ray backend only)
|
|
390
869
|
|
|
391
870
|
poll_interval:
|
|
392
|
-
- Poll interval in seconds for progress actor updates (Ray
|
|
871
|
+
- Poll interval in seconds for progress actor updates (Ray only)
|
|
872
|
+
|
|
873
|
+
error_handler:
|
|
874
|
+
- 'raise': raise exception on first error
|
|
875
|
+
- 'ignore': continue processing, return None for failed items
|
|
876
|
+
- 'log': same as ignore, but logs errors to files (default)
|
|
877
|
+
|
|
878
|
+
max_error_files:
|
|
879
|
+
- Maximum number of error log files to write (default: 100)
|
|
880
|
+
- Error logs are written to .cache/speedy_utils/error_logs/{func_name}_run_{N}/{idx}.log
|
|
881
|
+
- First error is always printed to screen with the log file path
|
|
393
882
|
|
|
394
883
|
If lazy_output=True, every result is saved to .pkl and
|
|
395
884
|
the returned list contains file paths.
|
|
@@ -426,7 +915,7 @@ def multi_process(
|
|
|
426
915
|
)
|
|
427
916
|
|
|
428
917
|
# Prefer Ray backend when shared kwargs are requested
|
|
429
|
-
if backend != 'ray':
|
|
918
|
+
if shared_kwargs and backend != 'ray':
|
|
430
919
|
warnings.warn(
|
|
431
920
|
"shared_kwargs only supported with 'ray' backend, switching backend to 'ray'",
|
|
432
921
|
UserWarning,
|
|
@@ -462,21 +951,62 @@ def multi_process(
|
|
|
462
951
|
else:
|
|
463
952
|
desc = f'Multi-process [{backend}]'
|
|
464
953
|
|
|
954
|
+
# Initialize error stats for error handling
|
|
955
|
+
func_name = getattr(func, '__name__', repr(func))
|
|
956
|
+
error_stats = ErrorStats(
|
|
957
|
+
func_name=func_name,
|
|
958
|
+
max_error_files=max_error_files,
|
|
959
|
+
write_logs=error_handler == 'log'
|
|
960
|
+
)
|
|
961
|
+
|
|
962
|
+
def _update_pbar_postfix(pbar: tqdm) -> None:
|
|
963
|
+
"""Update pbar with success/error counts."""
|
|
964
|
+
postfix = error_stats.get_postfix_dict()
|
|
965
|
+
pbar.set_postfix(postfix)
|
|
966
|
+
|
|
967
|
+
def _wrap_with_error_handler(
|
|
968
|
+
f: Callable,
|
|
969
|
+
idx: int,
|
|
970
|
+
input_value: Any,
|
|
971
|
+
error_stats_ref: ErrorStats,
|
|
972
|
+
handler: ErrorHandlerType,
|
|
973
|
+
) -> Callable:
|
|
974
|
+
"""Wrap function to handle errors based on error_handler setting."""
|
|
975
|
+
@functools.wraps(f)
|
|
976
|
+
def wrapper(*args, **kwargs):
|
|
977
|
+
try:
|
|
978
|
+
result = f(*args, **kwargs)
|
|
979
|
+
error_stats_ref.record_success()
|
|
980
|
+
return result
|
|
981
|
+
except Exception as e:
|
|
982
|
+
if handler == 'raise':
|
|
983
|
+
raise
|
|
984
|
+
error_stats_ref.record_error(idx, e, input_value, func_name)
|
|
985
|
+
return None
|
|
986
|
+
return wrapper
|
|
987
|
+
|
|
465
988
|
# ---- sequential backend ----
|
|
466
989
|
if backend == 'seq':
|
|
467
990
|
results: list[Any] = []
|
|
468
991
|
with tqdm(total=total, desc=desc, disable=not progress, file=sys.stdout) as pbar:
|
|
469
|
-
for x in items:
|
|
470
|
-
|
|
471
|
-
_call_with_log_control(
|
|
992
|
+
for idx, x in enumerate(items):
|
|
993
|
+
try:
|
|
994
|
+
result = _call_with_log_control(
|
|
472
995
|
f_wrapped,
|
|
473
996
|
x,
|
|
474
997
|
func_kwargs,
|
|
475
998
|
log_worker,
|
|
476
999
|
log_gate_path,
|
|
477
1000
|
)
|
|
478
|
-
|
|
1001
|
+
error_stats.record_success()
|
|
1002
|
+
results.append(result)
|
|
1003
|
+
except Exception as e:
|
|
1004
|
+
if error_handler == 'raise':
|
|
1005
|
+
raise
|
|
1006
|
+
error_stats.record_error(idx, e, x, func_name)
|
|
1007
|
+
results.append(None)
|
|
479
1008
|
pbar.update(1)
|
|
1009
|
+
_update_pbar_postfix(pbar)
|
|
480
1010
|
_cleanup_log_gate(log_gate_path)
|
|
481
1011
|
return results
|
|
482
1012
|
|
|
@@ -484,6 +1014,16 @@ def multi_process(
|
|
|
484
1014
|
if backend == 'ray':
|
|
485
1015
|
import ray as _ray_module
|
|
486
1016
|
|
|
1017
|
+
# Capture caller frame for better error reporting
|
|
1018
|
+
caller_frame = inspect.currentframe()
|
|
1019
|
+
caller_info = None
|
|
1020
|
+
if caller_frame and caller_frame.f_back:
|
|
1021
|
+
caller_info = {
|
|
1022
|
+
'filename': caller_frame.f_back.f_code.co_filename,
|
|
1023
|
+
'lineno': caller_frame.f_back.f_lineno,
|
|
1024
|
+
'function': caller_frame.f_back.f_code.co_name,
|
|
1025
|
+
}
|
|
1026
|
+
|
|
487
1027
|
results = []
|
|
488
1028
|
gate_path_str = str(log_gate_path) if log_gate_path else None
|
|
489
1029
|
with tqdm(total=total, desc=desc, disable=not progress, file=sys.stdout) as pbar:
|
|
@@ -496,7 +1036,7 @@ def multi_process(
|
|
|
496
1036
|
progress_poller = None
|
|
497
1037
|
if total_items is not None:
|
|
498
1038
|
progress_actor = create_progress_tracker(total_items, desc or "Items")
|
|
499
|
-
shared_refs['progress_actor'] = progress_actor
|
|
1039
|
+
shared_refs['progress_actor'] = progress_actor
|
|
500
1040
|
|
|
501
1041
|
if shared_kwargs:
|
|
502
1042
|
for kw in shared_kwargs:
|
|
@@ -520,11 +1060,9 @@ def multi_process(
|
|
|
520
1060
|
dereferenced = {}
|
|
521
1061
|
for k, v in shared_refs_dict.items():
|
|
522
1062
|
if k == 'progress_actor':
|
|
523
|
-
# Pass actor handle directly (don't dereference)
|
|
524
1063
|
dereferenced[k] = v
|
|
525
1064
|
else:
|
|
526
1065
|
dereferenced[k] = _ray_in_task.get(v)
|
|
527
|
-
# Merge with regular kwargs
|
|
528
1066
|
all_kwargs = {**dereferenced, **regular_kwargs_dict}
|
|
529
1067
|
return _call_with_log_control(
|
|
530
1068
|
f_wrapped,
|
|
@@ -540,21 +1078,32 @@ def multi_process(
|
|
|
540
1078
|
|
|
541
1079
|
t_start = time.time()
|
|
542
1080
|
|
|
543
|
-
# Start progress poller if using item-level progress
|
|
544
1081
|
if progress_actor is not None:
|
|
545
|
-
# Update pbar total to show items instead of tasks
|
|
546
1082
|
pbar.total = total_items
|
|
547
1083
|
pbar.refresh()
|
|
548
1084
|
progress_poller = ProgressPoller(progress_actor, pbar, poll_interval)
|
|
549
1085
|
progress_poller.start()
|
|
550
1086
|
|
|
551
|
-
for r in refs:
|
|
552
|
-
|
|
1087
|
+
for idx, r in enumerate(refs):
|
|
1088
|
+
try:
|
|
1089
|
+
result = _ray_module.get(r)
|
|
1090
|
+
error_stats.record_success()
|
|
1091
|
+
results.append(result)
|
|
1092
|
+
except _ray_module.exceptions.RayTaskError as e:
|
|
1093
|
+
if error_handler == 'raise':
|
|
1094
|
+
if progress_poller is not None:
|
|
1095
|
+
progress_poller.stop()
|
|
1096
|
+
_reraise_ray_error(e, pbar, caller_info)
|
|
1097
|
+
# Extract original error from RayTaskError
|
|
1098
|
+
cause = e.cause if hasattr(e, 'cause') else e.__cause__
|
|
1099
|
+
original_error = cause if cause else e
|
|
1100
|
+
error_stats.record_error(idx, original_error, items[idx], func_name)
|
|
1101
|
+
results.append(None)
|
|
1102
|
+
|
|
553
1103
|
if progress_actor is None:
|
|
554
|
-
# Only update task-level progress if not using item-level
|
|
555
1104
|
pbar.update(1)
|
|
1105
|
+
_update_pbar_postfix(pbar)
|
|
556
1106
|
|
|
557
|
-
# Stop progress poller
|
|
558
1107
|
if progress_poller is not None:
|
|
559
1108
|
progress_poller.stop()
|
|
560
1109
|
|
|
@@ -564,73 +1113,132 @@ def multi_process(
|
|
|
564
1113
|
_cleanup_log_gate(log_gate_path)
|
|
565
1114
|
return results
|
|
566
1115
|
|
|
567
|
-
# ---- fastcore backend ----
|
|
1116
|
+
# ---- fastcore/thread backend (mp) ----
|
|
568
1117
|
if backend == 'mp':
|
|
569
|
-
worker_func = functools.partial(
|
|
570
|
-
_call_with_log_control,
|
|
571
|
-
f_wrapped,
|
|
572
|
-
func_kwargs=func_kwargs,
|
|
573
|
-
log_mode=log_worker,
|
|
574
|
-
gate_path=log_gate_path,
|
|
575
|
-
)
|
|
576
|
-
with _patch_fastcore_progress_bar(leave=True):
|
|
577
|
-
results = list(parallel(
|
|
578
|
-
worker_func, items, n_workers=workers, progress=progress, threadpool=False
|
|
579
|
-
))
|
|
580
|
-
_track_multiprocessing_processes() # Track multiprocessing workers
|
|
581
|
-
_prune_dead_processes() # Clean up dead processes
|
|
582
|
-
_cleanup_log_gate(log_gate_path)
|
|
583
|
-
return results
|
|
584
|
-
|
|
585
|
-
if backend == 'threadpool':
|
|
586
|
-
worker_func = functools.partial(
|
|
587
|
-
_call_with_log_control,
|
|
588
|
-
f_wrapped,
|
|
589
|
-
func_kwargs=func_kwargs,
|
|
590
|
-
log_mode=log_worker,
|
|
591
|
-
gate_path=log_gate_path,
|
|
592
|
-
)
|
|
593
|
-
with _patch_fastcore_progress_bar(leave=True):
|
|
594
|
-
results = list(parallel(
|
|
595
|
-
worker_func, items, n_workers=workers, progress=progress, threadpool=True
|
|
596
|
-
))
|
|
597
|
-
_cleanup_log_gate(log_gate_path)
|
|
598
|
-
return results
|
|
599
|
-
|
|
600
|
-
if backend == 'safe':
|
|
601
|
-
# Completely safe backend for tests - no multiprocessing, no external progress bars
|
|
602
1118
|
import concurrent.futures
|
|
603
|
-
|
|
604
|
-
#
|
|
605
|
-
|
|
606
|
-
|
|
607
|
-
|
|
608
|
-
|
|
609
|
-
|
|
1119
|
+
|
|
1120
|
+
# Capture caller frame for better error reporting
|
|
1121
|
+
caller_frame = inspect.currentframe()
|
|
1122
|
+
caller_info = None
|
|
1123
|
+
if caller_frame and caller_frame.f_back:
|
|
1124
|
+
caller_info = {
|
|
1125
|
+
'filename': caller_frame.f_back.f_code.co_filename,
|
|
1126
|
+
'lineno': caller_frame.f_back.f_lineno,
|
|
1127
|
+
'function': caller_frame.f_back.f_code.co_name,
|
|
1128
|
+
}
|
|
1129
|
+
|
|
1130
|
+
def worker_func(x):
|
|
1131
|
+
return _call_with_log_control(
|
|
610
1132
|
f_wrapped,
|
|
611
|
-
|
|
612
|
-
|
|
613
|
-
|
|
1133
|
+
x,
|
|
1134
|
+
func_kwargs,
|
|
1135
|
+
log_worker,
|
|
1136
|
+
log_gate_path,
|
|
614
1137
|
)
|
|
1138
|
+
|
|
1139
|
+
results: list[Any] = [None] * total
|
|
1140
|
+
with tqdm(total=total, desc=desc, disable=not progress, file=sys.stdout) as pbar:
|
|
1141
|
+
try:
|
|
1142
|
+
from .thread import _prune_dead_threads, _track_executor_threads
|
|
1143
|
+
has_thread_tracking = True
|
|
1144
|
+
except ImportError:
|
|
1145
|
+
has_thread_tracking = False
|
|
1146
|
+
|
|
615
1147
|
with concurrent.futures.ThreadPoolExecutor(
|
|
616
1148
|
max_workers=workers
|
|
617
1149
|
) as executor:
|
|
618
|
-
|
|
619
|
-
|
|
620
|
-
|
|
621
|
-
|
|
622
|
-
|
|
623
|
-
|
|
624
|
-
|
|
1150
|
+
if has_thread_tracking:
|
|
1151
|
+
_track_executor_threads(executor)
|
|
1152
|
+
|
|
1153
|
+
# Submit all tasks
|
|
1154
|
+
future_to_idx = {
|
|
1155
|
+
executor.submit(worker_func, x): idx
|
|
1156
|
+
for idx, x in enumerate(items)
|
|
1157
|
+
}
|
|
1158
|
+
|
|
1159
|
+
# Process results as they complete
|
|
1160
|
+
for future in concurrent.futures.as_completed(future_to_idx):
|
|
1161
|
+
idx = future_to_idx[future]
|
|
1162
|
+
try:
|
|
1163
|
+
result = future.result()
|
|
1164
|
+
error_stats.record_success()
|
|
1165
|
+
results[idx] = result
|
|
1166
|
+
except Exception as e:
|
|
1167
|
+
if error_handler == 'raise':
|
|
1168
|
+
# Cancel remaining futures
|
|
1169
|
+
for f in future_to_idx:
|
|
1170
|
+
f.cancel()
|
|
1171
|
+
_reraise_worker_error(e, pbar, caller_info, backend='mp')
|
|
1172
|
+
error_stats.record_error(idx, e, items[idx], func_name)
|
|
1173
|
+
results[idx] = None
|
|
1174
|
+
pbar.update(1)
|
|
1175
|
+
_update_pbar_postfix(pbar)
|
|
1176
|
+
|
|
1177
|
+
if _prune_dead_threads is not None:
|
|
1178
|
+
_prune_dead_threads()
|
|
1179
|
+
|
|
1180
|
+
_track_multiprocessing_processes()
|
|
1181
|
+
_prune_dead_processes()
|
|
1182
|
+
_cleanup_log_gate(log_gate_path)
|
|
1183
|
+
return results
|
|
1184
|
+
|
|
1185
|
+
if backend == 'safe':
|
|
1186
|
+
# Completely safe backend for tests - no multiprocessing
|
|
1187
|
+
import concurrent.futures
|
|
1188
|
+
|
|
1189
|
+
# Capture caller frame for better error reporting
|
|
1190
|
+
caller_frame = inspect.currentframe()
|
|
1191
|
+
caller_info = None
|
|
1192
|
+
if caller_frame and caller_frame.f_back:
|
|
1193
|
+
caller_info = {
|
|
1194
|
+
'filename': caller_frame.f_back.f_code.co_filename,
|
|
1195
|
+
'lineno': caller_frame.f_back.f_lineno,
|
|
1196
|
+
'function': caller_frame.f_back.f_code.co_name,
|
|
1197
|
+
}
|
|
1198
|
+
|
|
1199
|
+
def worker_func(x):
|
|
1200
|
+
return _call_with_log_control(
|
|
625
1201
|
f_wrapped,
|
|
626
|
-
|
|
627
|
-
|
|
628
|
-
|
|
1202
|
+
x,
|
|
1203
|
+
func_kwargs,
|
|
1204
|
+
log_worker,
|
|
1205
|
+
log_gate_path,
|
|
629
1206
|
)
|
|
1207
|
+
|
|
1208
|
+
results: list[Any] = [None] * total
|
|
1209
|
+
with tqdm(total=total, desc=desc, disable=not progress, file=sys.stdout) as pbar:
|
|
630
1210
|
with concurrent.futures.ThreadPoolExecutor(
|
|
631
1211
|
max_workers=workers
|
|
632
1212
|
) as executor:
|
|
633
|
-
|
|
1213
|
+
if _track_executor_threads is not None:
|
|
1214
|
+
_track_executor_threads(executor)
|
|
1215
|
+
|
|
1216
|
+
# Submit all tasks
|
|
1217
|
+
future_to_idx = {
|
|
1218
|
+
executor.submit(worker_func, x): idx
|
|
1219
|
+
for idx, x in enumerate(items)
|
|
1220
|
+
}
|
|
1221
|
+
|
|
1222
|
+
# Process results as they complete
|
|
1223
|
+
for future in concurrent.futures.as_completed(future_to_idx):
|
|
1224
|
+
idx = future_to_idx[future]
|
|
1225
|
+
try:
|
|
1226
|
+
result = future.result()
|
|
1227
|
+
error_stats.record_success()
|
|
1228
|
+
results[idx] = result
|
|
1229
|
+
except Exception as e:
|
|
1230
|
+
if error_handler == 'raise':
|
|
1231
|
+
for f in future_to_idx:
|
|
1232
|
+
f.cancel()
|
|
1233
|
+
_reraise_worker_error(e, pbar, caller_info, backend='safe')
|
|
1234
|
+
error_stats.record_error(idx, e, items[idx], func_name)
|
|
1235
|
+
results[idx] = None
|
|
1236
|
+
pbar.update(1)
|
|
1237
|
+
_update_pbar_postfix(pbar)
|
|
1238
|
+
|
|
1239
|
+
if _prune_dead_threads is not None:
|
|
1240
|
+
_prune_dead_threads()
|
|
1241
|
+
|
|
634
1242
|
_cleanup_log_gate(log_gate_path)
|
|
635
1243
|
return results
|
|
636
1244
|
|
|
@@ -692,6 +1300,8 @@ def cleanup_phantom_workers():
|
|
|
692
1300
|
|
|
693
1301
|
__all__ = [
|
|
694
1302
|
'SPEEDY_RUNNING_PROCESSES',
|
|
1303
|
+
'ErrorStats',
|
|
1304
|
+
'ErrorHandlerType',
|
|
695
1305
|
'multi_process',
|
|
696
1306
|
'cleanup_phantom_workers',
|
|
697
1307
|
'create_progress_tracker',
|