speedy-utils 1.1.46__py3-none-any.whl → 1.1.47__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llm_utils/lm/llm.py +41 -12
- speedy_utils/__init__.py +4 -0
- speedy_utils/multi_worker/__init__.py +4 -0
- speedy_utils/multi_worker/_multi_process.py +425 -0
- speedy_utils/multi_worker/_multi_process_ray.py +308 -0
- speedy_utils/multi_worker/common.py +879 -0
- speedy_utils/multi_worker/dataset_sharding.py +203 -0
- speedy_utils/multi_worker/process.py +53 -1234
- speedy_utils/multi_worker/progress.py +71 -1
- speedy_utils/multi_worker/thread.py +45 -0
- speedy_utils/scripts/mpython.py +19 -12
- {speedy_utils-1.1.46.dist-info → speedy_utils-1.1.47.dist-info}/METADATA +1 -1
- {speedy_utils-1.1.46.dist-info → speedy_utils-1.1.47.dist-info}/RECORD +15 -11
- {speedy_utils-1.1.46.dist-info → speedy_utils-1.1.47.dist-info}/WHEEL +0 -0
- {speedy_utils-1.1.46.dist-info → speedy_utils-1.1.47.dist-info}/entry_points.txt +0 -0
|
@@ -0,0 +1,879 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Common utilities shared across multi_process backends.
|
|
3
|
+
|
|
4
|
+
Includes:
|
|
5
|
+
- Error formatting and logging
|
|
6
|
+
- Log gating (stdout/stderr control)
|
|
7
|
+
- Cache helpers
|
|
8
|
+
- Process/thread tracking
|
|
9
|
+
"""
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
|
|
12
|
+
import contextlib
|
|
13
|
+
import linecache
|
|
14
|
+
import os
|
|
15
|
+
import pickle
|
|
16
|
+
import sys
|
|
17
|
+
import tempfile
|
|
18
|
+
import threading
|
|
19
|
+
import time
|
|
20
|
+
import uuid
|
|
21
|
+
from pathlib import Path
|
|
22
|
+
from typing import TYPE_CHECKING, Any, Callable, Literal
|
|
23
|
+
|
|
24
|
+
import psutil
|
|
25
|
+
|
|
26
|
+
if TYPE_CHECKING:
|
|
27
|
+
from tqdm import tqdm
|
|
28
|
+
|
|
29
|
+
# ─── error handler types ────────────────────────────────────────
|
|
30
|
+
ErrorHandlerType = Literal['raise', 'ignore', 'log']
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class ErrorStats:
    """Thread-safe error statistics tracker.

    Counts successes and failures under a lock, writes one log file per
    failure (up to ``max_error_files``) and prints the first failure to
    stderr so the user sees it while the remaining items keep running.
    """

    def __init__(
        self,
        func_name: str,
        max_error_files: int = 100,
        write_logs: bool = True
    ):
        """
        Args:
            func_name: Name used to build the per-run error log directory.
            max_error_files: Errors beyond this count are still counted but
                no longer written to disk.
            write_logs: When False, no per-error log files are written.
        """
        self._lock = threading.Lock()
        self._success_count = 0
        self._error_count = 0
        self._first_error_shown = False
        self._max_error_files = max_error_files
        self._write_logs = write_logs
        self._error_log_dir = self._get_error_log_dir(func_name)
        if write_logs:
            self._error_log_dir.mkdir(parents=True, exist_ok=True)

    @staticmethod
    def _get_error_log_dir(func_name: str) -> Path:
        """Generate unique error log directory with run counter.

        Scans ``.cache/speedy_utils/error_logs`` (created if missing) for
        ``<func_name>_run_<N>`` entries and returns the path for ``N + 1``
        (or run 1 when none exist). The returned directory itself is not
        created here.
        """
        base_dir = Path('.cache/speedy_utils/error_logs')
        base_dir.mkdir(parents=True, exist_ok=True)

        prefix = f'{func_name}_run_'
        counters = []
        for p in base_dir.glob(f'{prefix}*'):
            # A func_name containing glob metacharacters could match
            # entries that do not literally start with the prefix.
            if not p.name.startswith(prefix):
                continue
            # Slice off the known prefix rather than splitting on '_run_':
            # the function name itself may contain '_run_' (e.g.
            # 'do_run_job'), which previously made every entry unparsable
            # and pinned the counter at 1.
            try:
                counters.append(int(p.name[len(prefix):]))
            except ValueError:
                continue

        counter = max(counters) + 1 if counters else 1
        return base_dir / f'{func_name}_run_{counter}'

    def record_success(self) -> None:
        """Increment the success counter."""
        with self._lock:
            self._success_count += 1

    def record_error(
        self,
        idx: int,
        error: Exception,
        input_value: Any,
        func_name: str,
        ray_task_error: Exception | None = None,
    ) -> str | None:
        """
        Record an error and write to log file.
        Returns the log file path if written, None otherwise.

        Args:
            idx: Index of the failing item; used as the log file name.
            error: The exception raised for this item.
            input_value: The input that triggered the error.
            func_name: Name of the function that failed.
            ray_task_error: Optional RayTaskError for fallback frame extraction
                when the native traceback is unavailable.
        """
        # Only the counter/flag updates happen under the lock; the file
        # write and console output below run outside it.
        with self._lock:
            self._error_count += 1
            should_show_first = not self._first_error_shown
            if should_show_first:
                self._first_error_shown = True
            should_write = (
                self._write_logs and self._error_count <= self._max_error_files
            )

        log_path = None
        if should_write:
            log_path = self._write_error_log(
                idx, error, input_value, func_name, ray_task_error
            )

        if should_show_first:
            self._print_first_error(
                error, input_value, func_name, log_path, ray_task_error
            )

        return log_path

    def _write_error_log(
        self,
        idx: int,
        error: Exception,
        input_value: Any,
        func_name: str,
        ray_task_error: Exception | None = None,
    ) -> str:
        """Write error details to ``<log dir>/<idx>.log`` and return its path."""
        from io import StringIO
        from rich.console import Console

        log_path = self._error_log_dir / f'{idx}.log'

        output = StringIO()
        console = Console(file=output, width=120, no_color=False)

        # Format traceback using unified extraction
        tb_lines = _format_traceback_lines(
            _extract_frames(error, ray_task_error),
            include_locals=False,
        )

        # Dynamic text is printed with markup disabled: error messages and
        # inputs routinely contain '[...]', which Rich would otherwise
        # interpret as style tags.
        console.print(f'{"=" * 60}')
        console.print(f'Error at index: {idx}', markup=False)
        console.print(f'Function: {func_name}', markup=False)
        console.print(f'Error Type: {type(error).__name__}', markup=False)
        console.print(f'Error Message: {error}', markup=False)
        console.print(f'{"=" * 60}')
        console.print('')
        console.print('Input:')
        console.print('-' * 40)
        try:
            import json
            console.print(json.dumps(input_value, indent=2), markup=False)
        except Exception:
            # Not JSON-serialisable: fall back to repr.
            console.print(repr(input_value), markup=False)
        console.print('')
        console.print('Traceback:')
        console.print('-' * 40)
        for line in tb_lines:
            # Traceback lines carry intentional Rich markup.
            console.print(line)

        # Explicit UTF-8: the traceback uses non-ASCII glyphs that the
        # platform default encoding may not be able to represent.
        with open(log_path, 'w', encoding='utf-8') as f:
            f.write(output.getvalue())

        return str(log_path)

    def _print_first_error(
        self,
        error: Exception,
        input_value: Any,
        func_name: str,
        log_path: str | None,
        ray_task_error: Exception | None = None,
    ) -> None:
        """Print the first error to stderr with Rich formatting.

        Falls back to plain ``print`` when Rich cannot be imported.
        ``input_value`` and ``func_name`` are accepted but not displayed.
        """
        try:
            from rich.console import Console
            from rich.markup import escape
            from rich.panel import Panel
            console = Console(stderr=True)

            # Use unified frame extraction
            tb_lines = _format_traceback_lines(
                _extract_frames(error, ray_task_error),
                include_locals=False,
            )

            console.print()
            console.print(
                Panel(
                    '\n'.join(tb_lines),
                    title=(
                        '[bold red]First Error '
                        '(continuing with remaining items)[/bold red]'
                    ),
                    border_style='yellow',
                    expand=False,
                )
            )
            # Escape dynamic text so brackets in the message or path are
            # shown literally instead of being parsed as markup.
            console.print(
                f'[bold red]{type(error).__name__}[/bold red]: '
                f'{escape(str(error))}'
            )
            if log_path:
                console.print(f'[dim]Error log: {escape(log_path)}[/dim]')
            console.print()
        except ImportError:
            # Fallback to plain print
            print('\n--- First Error ---', file=sys.stderr)
            print(f'{type(error).__name__}: {error}', file=sys.stderr)
            if log_path:
                print(f'Error log: {log_path}', file=sys.stderr)
            print('', file=sys.stderr)

    @property
    def success_count(self) -> int:
        """Number of successes recorded so far."""
        with self._lock:
            return self._success_count

    @property
    def error_count(self) -> int:
        """Number of errors recorded so far."""
        with self._lock:
            return self._error_count

    def get_postfix_dict(self) -> dict[str, int]:
        """Get dict for pbar postfix."""
        with self._lock:
            return {'ok': self._success_count, 'err': self._error_count}
|
|
225
|
+
|
|
226
|
+
|
|
227
|
+
# ─── Traceback formatting utilities ─────────────────────────────
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
def _should_skip_frame(filepath: str) -> bool:
|
|
231
|
+
"""Check if a frame should be filtered from traceback display."""
|
|
232
|
+
skip_patterns = [
|
|
233
|
+
'ray/_private',
|
|
234
|
+
'ray/worker',
|
|
235
|
+
'site-packages/ray',
|
|
236
|
+
'speedy_utils/multi_worker',
|
|
237
|
+
'concurrent/futures',
|
|
238
|
+
'multiprocessing/',
|
|
239
|
+
'fastcore/parallel',
|
|
240
|
+
'fastcore/foundation',
|
|
241
|
+
'fastcore/basics',
|
|
242
|
+
'site-packages/fastcore',
|
|
243
|
+
'/threading.py',
|
|
244
|
+
'/concurrent/',
|
|
245
|
+
]
|
|
246
|
+
return any(skip in filepath for skip in skip_patterns)
|
|
247
|
+
|
|
248
|
+
|
|
249
|
+
def _should_show_local(name: str, value: object) -> bool:
|
|
250
|
+
"""Check if a local variable should be displayed in traceback."""
|
|
251
|
+
import types
|
|
252
|
+
|
|
253
|
+
# Skip dunder variables
|
|
254
|
+
if name.startswith('__') and name.endswith('__'):
|
|
255
|
+
return False
|
|
256
|
+
|
|
257
|
+
# Skip modules
|
|
258
|
+
if isinstance(value, types.ModuleType):
|
|
259
|
+
return False
|
|
260
|
+
|
|
261
|
+
# Skip type objects and classes
|
|
262
|
+
if isinstance(value, type):
|
|
263
|
+
return False
|
|
264
|
+
|
|
265
|
+
# Skip functions and methods
|
|
266
|
+
if isinstance(
|
|
267
|
+
value,
|
|
268
|
+
(types.FunctionType, types.MethodType, types.BuiltinFunctionType)
|
|
269
|
+
):
|
|
270
|
+
return False
|
|
271
|
+
|
|
272
|
+
# Skip common typing aliases
|
|
273
|
+
value_str = str(value)
|
|
274
|
+
if value_str.startswith('typing.'):
|
|
275
|
+
return False
|
|
276
|
+
|
|
277
|
+
# Skip large objects that would clutter output
|
|
278
|
+
skip_markers = ['module', 'function', 'method', 'built-in']
|
|
279
|
+
if value_str.startswith('<') and any(x in value_str for x in skip_markers):
|
|
280
|
+
return False
|
|
281
|
+
|
|
282
|
+
return True
|
|
283
|
+
|
|
284
|
+
|
|
285
|
+
def _format_locals(frame_locals: dict) -> list[str]:
    """Format local variables for display, filtering out noisy imports.

    Args:
        frame_locals: Mapping of local variable names to values.

    Returns:
        Rich-markup lines forming a small "locals" box, or an empty list
        when no variable passes ``_should_show_local``.
    """
    from io import StringIO
    from rich.console import Console
    from rich.markup import escape
    from rich.pretty import Pretty

    # Filter locals
    clean_locals = {
        k: v for k, v in frame_locals.items() if _should_show_local(k, v)
    }

    if not clean_locals:
        return []

    lines = ['[dim]╭─ locals ─╮[/dim]']

    for name, value in clean_locals.items():
        # Use Rich's Pretty for nice formatting
        try:
            buffer = StringIO()
            Console(file=buffer, width=60).print(Pretty(value), end='')
            value_str = buffer.getvalue().strip()
        except Exception:
            try:
                value_str = repr(value)
            except Exception:
                # Even repr() can raise on a broken object; never let that
                # mask the error being reported.
                value_str = f'<unrepresentable {type(value).__name__}>'

        # Limit length (before escaping, so an escape sequence is never cut)
        if len(value_str) > 100:
            value_str = value_str[:97] + '...'

        # The line is later rendered as Rich markup; escape the value so
        # text such as '[tensor(1.)]' is not parsed as a style tag.
        lines.append(f'[dim]│[/dim] {name} = {escape(value_str)}')

    lines.append('[dim]╰──────────╯[/dim]')
    return lines
|
|
321
|
+
|
|
322
|
+
|
|
323
|
+
def _format_frame_with_context(
|
|
324
|
+
filepath: str,
|
|
325
|
+
lineno: int,
|
|
326
|
+
funcname: str,
|
|
327
|
+
frame_locals: dict | None = None,
|
|
328
|
+
) -> list[str]:
|
|
329
|
+
"""Format a single frame with context lines and optional locals."""
|
|
330
|
+
lines = []
|
|
331
|
+
# Frame header
|
|
332
|
+
lines.append(
|
|
333
|
+
f'[cyan]{filepath}[/cyan]:[yellow]{lineno}[/yellow] '
|
|
334
|
+
f'in [green]{funcname}[/green]'
|
|
335
|
+
)
|
|
336
|
+
lines.append('')
|
|
337
|
+
|
|
338
|
+
# Get context lines
|
|
339
|
+
context_size = 3
|
|
340
|
+
start_line = max(1, lineno - context_size)
|
|
341
|
+
end_line = lineno + context_size + 1
|
|
342
|
+
|
|
343
|
+
for line_num in range(start_line, end_line):
|
|
344
|
+
line_text = linecache.getline(filepath, line_num).rstrip()
|
|
345
|
+
if line_text:
|
|
346
|
+
num_str = str(line_num).rjust(4)
|
|
347
|
+
if line_num == lineno:
|
|
348
|
+
lines.append(f'[dim]{num_str}[/dim] [red]❱[/red] {line_text}')
|
|
349
|
+
else:
|
|
350
|
+
lines.append(f'[dim]{num_str} │[/dim] {line_text}')
|
|
351
|
+
|
|
352
|
+
# Add locals if available
|
|
353
|
+
if frame_locals:
|
|
354
|
+
locals_lines = _format_locals(frame_locals)
|
|
355
|
+
if locals_lines:
|
|
356
|
+
lines.append('')
|
|
357
|
+
lines.extend(locals_lines)
|
|
358
|
+
|
|
359
|
+
lines.append('')
|
|
360
|
+
return lines
|
|
361
|
+
|
|
362
|
+
|
|
363
|
+
def _format_traceback_lines(
|
|
364
|
+
frames: list[tuple[str, int, str, dict]],
|
|
365
|
+
*,
|
|
366
|
+
caller_info: dict | None = None,
|
|
367
|
+
include_locals: bool = True,
|
|
368
|
+
) -> list[str]:
|
|
369
|
+
"""Format frames into a list of Rich-markup lines for display/logging."""
|
|
370
|
+
display_lines: list[str] = []
|
|
371
|
+
if caller_info:
|
|
372
|
+
display_lines.extend(
|
|
373
|
+
_format_frame_with_context(
|
|
374
|
+
caller_info['filename'],
|
|
375
|
+
caller_info['lineno'],
|
|
376
|
+
caller_info['function'],
|
|
377
|
+
None,
|
|
378
|
+
)
|
|
379
|
+
)
|
|
380
|
+
for filepath, lineno, funcname, frame_locals in frames:
|
|
381
|
+
locals_for_frame = frame_locals if include_locals else None
|
|
382
|
+
display_lines.extend(
|
|
383
|
+
_format_frame_with_context(
|
|
384
|
+
filepath,
|
|
385
|
+
lineno,
|
|
386
|
+
funcname,
|
|
387
|
+
locals_for_frame,
|
|
388
|
+
)
|
|
389
|
+
)
|
|
390
|
+
return display_lines
|
|
391
|
+
|
|
392
|
+
|
|
393
|
+
def _extract_frames_from_traceback(
|
|
394
|
+
error: Exception,
|
|
395
|
+
) -> list[tuple[str, int, str, dict]]:
|
|
396
|
+
"""Extract user frames from exception traceback object with locals."""
|
|
397
|
+
frames = []
|
|
398
|
+
if hasattr(error, '__traceback__') and error.__traceback__ is not None:
|
|
399
|
+
tb = error.__traceback__
|
|
400
|
+
while tb is not None:
|
|
401
|
+
frame = tb.tb_frame
|
|
402
|
+
filename = frame.f_code.co_filename
|
|
403
|
+
lineno = tb.tb_lineno
|
|
404
|
+
funcname = frame.f_code.co_name
|
|
405
|
+
|
|
406
|
+
if not _should_skip_frame(filename):
|
|
407
|
+
# Get local variables from the frame
|
|
408
|
+
frame_locals = dict(frame.f_locals)
|
|
409
|
+
frames.append((filename, lineno, funcname, frame_locals))
|
|
410
|
+
|
|
411
|
+
tb = tb.tb_next
|
|
412
|
+
return frames
|
|
413
|
+
|
|
414
|
+
|
|
415
|
+
def _extract_frames_from_ray_error(
|
|
416
|
+
ray_task_error: Exception,
|
|
417
|
+
) -> list[tuple[str, int, str, dict]]:
|
|
418
|
+
"""Extract user frames from Ray's string traceback representation."""
|
|
419
|
+
import re
|
|
420
|
+
|
|
421
|
+
frames = []
|
|
422
|
+
error_str = str(ray_task_error)
|
|
423
|
+
lines = error_str.split('\n')
|
|
424
|
+
|
|
425
|
+
for line in lines:
|
|
426
|
+
# Match: File "path", line N, in func
|
|
427
|
+
file_match = re.match(r'\s*File "([^"]+)", line (\d+), in (.+)', line)
|
|
428
|
+
if file_match:
|
|
429
|
+
filepath, lineno, funcname = file_match.groups()
|
|
430
|
+
if not _should_skip_frame(filepath):
|
|
431
|
+
# Ray doesn't preserve locals, so use empty dict
|
|
432
|
+
frames.append((filepath, int(lineno), funcname, {}))
|
|
433
|
+
|
|
434
|
+
return frames
|
|
435
|
+
|
|
436
|
+
|
|
437
|
+
def _extract_frames(
    error: Exception,
    ray_task_error: Exception | None = None,
) -> list[tuple[str, int, str, dict]]:
    """
    Extract frames from a native exception, falling back to a Ray error.

    The error's own ``__traceback__`` is tried first. Only when that
    yields nothing and ``ray_task_error`` is given is the Ray error's
    string representation parsed instead.
    """
    native = _extract_frames_from_traceback(error)
    if native or ray_task_error is None:
        return native
    return _extract_frames_from_ray_error(ray_task_error)
|
|
456
|
+
|
|
457
|
+
|
|
458
|
+
def _display_formatted_error_and_exit(
    exc_type_name: str,
    exc_msg: str,
    frames: list[tuple[str, int, str, dict]],
    caller_info: dict | None,
    backend: str,
    pbar: tqdm | None = None,
) -> None:
    """Display a formatted error and exit the process.

    Sets ``RAY_IGNORE_UNHANDLED_ERRORS=1``, closes ``pbar`` if given,
    prints the traceback panel (when frames or caller info exist) and the
    ``Type: message`` summary to stderr, then calls ``sys.exit(1)``.

    Args:
        exc_type_name: Exception class name for the summary line.
        exc_msg: Exception message for the summary line.
        frames: ``(filepath, lineno, funcname, locals)`` tuples.
        caller_info: Optional caller frame description.
        backend: Backend label shown in the panel title.
        pbar: Progress bar to close before printing.

    Raises:
        SystemExit: Always, with exit code 1.
    """
    # Suppress additional error logs
    os.environ['RAY_IGNORE_UNHANDLED_ERRORS'] = '1'

    # Close progress bar cleanly if provided
    if pbar is not None:
        pbar.close()

    from rich.console import Console
    from rich.markup import escape
    from rich.panel import Panel
    console = Console(stderr=True)

    # Dynamic text is escaped so brackets in it are shown literally rather
    # than parsed as Rich style tags.
    summary = (
        f'[bold red]{escape(exc_type_name)}[/bold red]: {escape(exc_msg)}'
    )
    # Unescaped, '[ray]' would be consumed as a style tag and the backend
    # label would disappear from the title.
    backend_label = escape(f'[{backend}]')

    console.print()
    if frames or caller_info:
        display_lines = _format_traceback_lines(
            frames,
            caller_info=caller_info,
            include_locals=True,
        )

        # Display the traceback
        console.print(
            Panel(
                '\n'.join(display_lines),
                title=(
                    f'[bold red]Traceback (most recent call last) '
                    f'{backend_label}[/bold red]'
                ),
                border_style='red',
                expand=False,
            )
        )
    console.print(summary)
    console.print()

    # Ensure output is flushed
    sys.stderr.flush()
    sys.stdout.flush()
    sys.exit(1)
|
|
510
|
+
|
|
511
|
+
|
|
512
|
+
def _exit_on_worker_error(
    error: Exception,
    pbar: tqdm | None = None,
    caller_info: dict | None = None,
    backend: str = 'unknown',
) -> None:
    """Show a clean traceback for a worker error, then exit the process.

    Frames come from the error's own traceback; display and exit are
    delegated to ``_display_formatted_error_and_exit``.
    """
    user_frames = _extract_frames_from_traceback(error)
    error_name = type(error).__name__
    error_text = str(error)
    _display_formatted_error_and_exit(
        error_name,
        error_text,
        user_frames,
        caller_info,
        backend,
        pbar,
    )
|
|
528
|
+
|
|
529
|
+
|
|
530
|
+
def _exit_on_ray_error(
    ray_task_error: Exception,
    pbar: tqdm | None = None,
    caller_info: dict | None = None,
) -> None:
    """Show a clean traceback for a RayTaskError, then exit the process.

    The underlying exception is taken from ``ray_task_error.cause`` when
    that attribute exists and is not None, otherwise from ``__cause__``.
    Frames come from that exception's traceback, falling back to parsing
    the Ray error's string form when none are found.
    """
    # Resolve the underlying exception.
    cause = getattr(ray_task_error, 'cause', None)
    if cause is None:
        cause = ray_task_error.__cause__

    if cause:
        exc_type_name = type(cause).__name__
        exc_msg = str(cause)
        frames = _extract_frames_from_traceback(cause)
    else:
        exc_type_name = 'Error'
        exc_msg = str(ray_task_error)
        frames = []

    if not frames:
        frames = _extract_frames_from_ray_error(ray_task_error)

    _display_formatted_error_and_exit(
        exc_type_name,
        exc_msg,
        frames,
        caller_info,
        'ray',
        pbar,
    )
|
|
561
|
+
|
|
562
|
+
|
|
563
|
+
# ─── Process/thread tracking ────────────────────────────────────
|
|
564
|
+
|
|
565
|
+
SPEEDY_RUNNING_PROCESSES: list[psutil.Process] = []
|
|
566
|
+
_SPEEDY_PROCESSES_LOCK = threading.Lock()
|
|
567
|
+
|
|
568
|
+
|
|
569
|
+
def _prune_dead_processes() -> None:
    """Drop processes that are no longer running from the tracking list.

    The global list is updated in place while holding the tracking lock.
    """
    with _SPEEDY_PROCESSES_LOCK:
        still_running = []
        for proc in SPEEDY_RUNNING_PROCESSES:
            if proc.is_running():
                still_running.append(proc)
        SPEEDY_RUNNING_PROCESSES[:] = still_running
|
|
575
|
+
|
|
576
|
+
|
|
577
|
+
def _track_processes(processes: list[psutil.Process]) -> None:
|
|
578
|
+
"""Add processes to global tracking list."""
|
|
579
|
+
if not processes:
|
|
580
|
+
return
|
|
581
|
+
with _SPEEDY_PROCESSES_LOCK:
|
|
582
|
+
living = [p for p in SPEEDY_RUNNING_PROCESSES if p.is_running()]
|
|
583
|
+
for candidate in processes:
|
|
584
|
+
if not candidate.is_running():
|
|
585
|
+
continue
|
|
586
|
+
if any(existing.pid == candidate.pid for existing in living):
|
|
587
|
+
continue
|
|
588
|
+
living.append(candidate)
|
|
589
|
+
SPEEDY_RUNNING_PROCESSES[:] = living
|
|
590
|
+
|
|
591
|
+
|
|
592
|
+
def _track_multiprocessing_processes() -> None:
    """Track direct child processes created within the last 5 seconds.

    Best effort: any exception raised while inspecting processes is
    swallowed so that tracking can never break the caller.
    """
    with contextlib.suppress(Exception):
        me = psutil.Process(os.getpid())
        fresh = []
        for child in me.children(recursive=False):
            try:
                age = time.time() - child.create_time()
            except (psutil.NoSuchProcess, psutil.AccessDenied):
                # Child vanished or cannot be inspected; ignore it.
                continue
            if age < 5:
                fresh.append(child)
        _track_processes(fresh)
|
|
610
|
+
|
|
611
|
+
|
|
612
|
+
def _track_ray_processes() -> None:
    """Track descendant processes whose name contains 'ray' or 'worker'.

    The match is case-insensitive and covers all descendants. Best effort:
    any exception raised while inspecting processes is swallowed.
    """
    with contextlib.suppress(Exception):
        root = psutil.Process(os.getpid())
        matched = []
        for child in root.children(recursive=True):
            try:
                lowered = child.name().lower()
            except (psutil.NoSuchProcess, psutil.AccessDenied):
                # Child vanished or cannot be inspected; ignore it.
                continue
            if 'ray' in lowered or 'worker' in lowered:
                matched.append(child)
        _track_processes(matched)
|
|
628
|
+
|
|
629
|
+
|
|
630
|
+
# ─── Log gating utilities ───────────────────────────────────────
|
|
631
|
+
|
|
632
|
+
_LOG_GATE_CACHE: dict[str, bool] = {}
|
|
633
|
+
|
|
634
|
+
|
|
635
|
+
def _should_allow_worker_logs(
|
|
636
|
+
mode: Literal['all', 'zero', 'first'],
|
|
637
|
+
gate_path: Path | None,
|
|
638
|
+
) -> bool:
|
|
639
|
+
"""Determine if current worker should emit logs for the given mode."""
|
|
640
|
+
if mode == 'all':
|
|
641
|
+
return True
|
|
642
|
+
if mode == 'zero':
|
|
643
|
+
return False
|
|
644
|
+
if mode == 'first':
|
|
645
|
+
if gate_path is None:
|
|
646
|
+
return True
|
|
647
|
+
key = str(gate_path)
|
|
648
|
+
cached = _LOG_GATE_CACHE.get(key)
|
|
649
|
+
if cached is not None:
|
|
650
|
+
return cached
|
|
651
|
+
gate_path.parent.mkdir(parents=True, exist_ok=True)
|
|
652
|
+
try:
|
|
653
|
+
fd = os.open(key, os.O_CREAT | os.O_EXCL | os.O_WRONLY)
|
|
654
|
+
except FileExistsError:
|
|
655
|
+
allowed = False
|
|
656
|
+
else:
|
|
657
|
+
os.close(fd)
|
|
658
|
+
allowed = True
|
|
659
|
+
_LOG_GATE_CACHE[key] = allowed
|
|
660
|
+
return allowed
|
|
661
|
+
raise ValueError(f'Unsupported log mode: {mode!r}')
|
|
662
|
+
|
|
663
|
+
|
|
664
|
+
def _cleanup_log_gate(gate_path: Path | None) -> None:
|
|
665
|
+
"""Remove the log gate file if it exists."""
|
|
666
|
+
if gate_path is None:
|
|
667
|
+
return
|
|
668
|
+
try:
|
|
669
|
+
gate_path.unlink(missing_ok=True)
|
|
670
|
+
except OSError:
|
|
671
|
+
pass
|
|
672
|
+
|
|
673
|
+
|
|
674
|
+
class _PrefixedWriter:
|
|
675
|
+
"""Stream wrapper that prefixes each line with worker id."""
|
|
676
|
+
|
|
677
|
+
def __init__(self, stream, prefix: str):
|
|
678
|
+
self._stream = stream
|
|
679
|
+
self._prefix = prefix
|
|
680
|
+
self._at_line_start = True
|
|
681
|
+
|
|
682
|
+
def write(self, s):
|
|
683
|
+
if not s:
|
|
684
|
+
return 0
|
|
685
|
+
total = 0
|
|
686
|
+
for chunk in s.splitlines(True):
|
|
687
|
+
if self._at_line_start:
|
|
688
|
+
self._stream.write(self._prefix)
|
|
689
|
+
total += len(self._prefix)
|
|
690
|
+
self._stream.write(chunk)
|
|
691
|
+
total += len(chunk)
|
|
692
|
+
self._at_line_start = chunk.endswith('\n')
|
|
693
|
+
return total
|
|
694
|
+
|
|
695
|
+
def flush(self):
|
|
696
|
+
self._stream.flush()
|
|
697
|
+
|
|
698
|
+
|
|
699
|
+
def _call_with_log_control(
    func: Callable,
    x: Any,
    func_kwargs: dict[str, Any],
    log_mode: Literal['all', 'zero', 'first'],
    gate_path: Path | None,
):
    """Run ``func(x, **func_kwargs)`` with stdout/stderr gated by log mode.

    When logs are allowed, both streams are redirected to stderr with a
    ``[worker-<pid>] `` prefix on each line; otherwise both are sent to
    the null device.

    Returns:
        Whatever ``func`` returns.
    """
    if _should_allow_worker_logs(log_mode, gate_path):
        # Route worker logs to stderr to reduce clobbering tqdm on stdout
        tagged = _PrefixedWriter(sys.stderr, f'[worker-{os.getpid()}] ')
        with contextlib.redirect_stdout(tagged), contextlib.redirect_stderr(tagged):
            return func(x, **func_kwargs)

    with contextlib.ExitStack() as stack:
        sink = stack.enter_context(open(os.devnull, 'w'))
        stack.enter_context(contextlib.redirect_stdout(sink))
        stack.enter_context(contextlib.redirect_stderr(sink))
        return func(x, **func_kwargs)
|
|
721
|
+
|
|
722
|
+
|
|
723
|
+
# ─── Cache helpers ──────────────────────────────────────────────
|
|
724
|
+
|
|
725
|
+
|
|
726
|
+
def _build_cache_dir(func: Callable, items: list[Any]) -> Path:
|
|
727
|
+
"""Build cache dir with function name + timestamp."""
|
|
728
|
+
import datetime
|
|
729
|
+
func_name = getattr(func, '__name__', 'func')
|
|
730
|
+
now = datetime.datetime.now()
|
|
731
|
+
stamp = now.strftime('%m%d_%Hh%Mm%Ss')
|
|
732
|
+
run_id = f'{func_name}_{stamp}_{uuid.uuid4().hex[:6]}'
|
|
733
|
+
path = Path('.cache') / run_id
|
|
734
|
+
path.mkdir(parents=True, exist_ok=True)
|
|
735
|
+
return path
|
|
736
|
+
|
|
737
|
+
|
|
738
|
+
def wrap_dump(
    func: Callable,
    cache_dir: Path | None,
    dump_in_thread: bool = True,
):
    """Wrap a function so results are dumped to .pkl when cache_dir is set.

    Args:
        func: Function to wrap; called as ``func(x, *args, **kwargs)``.
        cache_dir: Directory receiving one ``<uuid>.pkl`` per call. When
            None, ``func`` is returned unchanged.
        dump_in_thread: When True the pickle is written by a background
            thread, so the returned path may not exist yet when the
            wrapper returns.

    Returns:
        ``func`` itself, or a wrapper that returns the pickle path as a
        string instead of the result.
    """
    if cache_dir is None:
        return func

    # Upper bound on concurrently running writer threads per process.
    max_pending_dumps = 16
    dump_thread_name = 'speedy-utils-dump'

    def wrapped(x, *args, **kwargs):
        res = func(x, *args, **kwargs)
        p = cache_dir / f'{uuid.uuid4().hex}.pkl'

        def save():
            # Write to a sibling temp file and rename it into place, so a
            # reader holding the returned path never sees a partial pickle.
            tmp_path = p.with_name(p.name + '.tmp')
            with open(tmp_path, 'wb') as fh:
                pickle.dump(res, fh)
            os.replace(tmp_path, p)

        def pending_dumps():
            return sum(
                t.name == dump_thread_name for t in threading.enumerate()
            )

        if dump_in_thread:
            # Throttle on our own writer threads only. Counting every
            # thread in the process (threading.active_count()) blocks
            # forever when the host application already runs more threads
            # than the limit.
            while pending_dumps() >= max_pending_dumps:
                time.sleep(0.1)
            threading.Thread(target=save, name=dump_thread_name).start()
        else:
            save()
        return str(p)

    return wrapped
|
|
765
|
+
|
|
766
|
+
|
|
767
|
+
# ─── Log gate path helper ───────────────────────────────────────
|
|
768
|
+
|
|
769
|
+
|
|
770
|
+
def create_log_gate_path(
    log_worker: Literal['zero', 'first', 'all'],
) -> Path | None:
    """Build the gate-file path used for first-worker-only logging.

    The file itself is not created here.

    Returns:
        For ``'first'``, a unique path in the system temp directory named
        after the current pid and a random UUID. For ``'zero'`` and
        ``'all'``, None.

    Raises:
        ValueError: If ``log_worker`` is not a supported value.
    """
    if log_worker == 'first':
        temp_root = Path(tempfile.gettempdir())
        gate_name = f'speedy_utils_log_gate_{os.getpid()}_{uuid.uuid4().hex}.gate'
        return temp_root / gate_name
    if log_worker in ('zero', 'all'):
        return None
    raise ValueError(f'Unsupported log_worker: {log_worker!r}')
|
|
782
|
+
|
|
783
|
+
|
|
784
|
+
# ─── Cleanup utility ────────────────────────────────────────────
|
|
785
|
+
|
|
786
|
+
|
|
787
|
+
def cleanup_phantom_workers() -> None:
    """
    Kill all tracked processes and threads (phantom workers).

    Steps, in order:
    1. Kill every process in the tracking list and clear the list.
    2. Kill any remaining descendant process of the current process.
    3. Kill tracked threads via the sibling ``thread`` module; if that
       module cannot be imported, only report non-daemon threads that are
       still running.
    """
    # Clean up tracked processes first
    _prune_dead_processes()
    killed_processes = 0
    with _SPEEDY_PROCESSES_LOCK:
        for process in SPEEDY_RUNNING_PROCESSES[:]:
            try:
                print(
                    f'🔪 Killing tracked process {process.pid} '
                    f'({process.name()})'
                )
                process.kill()
                killed_processes += 1
            except (psutil.NoSuchProcess, psutil.AccessDenied) as e:
                print(f'⚠️ Could not kill process {process.pid}: {e}')
        SPEEDY_RUNNING_PROCESSES.clear()

    # Also kill any remaining child processes (fallback)
    parent = psutil.Process(os.getpid())
    for child in parent.children(recursive=True):
        try:
            print(f'🔪 Killing child process {child.pid} ({child.name()})')
            child.kill()
        except psutil.NoSuchProcess:
            # Already gone; nothing to do.
            pass
        except psutil.AccessDenied as e:
            # Handled like the tracked-process loop above: one child we
            # are not allowed to signal must not abort the whole cleanup.
            print(f'⚠️ Could not kill process {child.pid}: {e}')

    # Try to clean up threads using thread module functions if available
    try:
        from .thread import (
            _prune_dead_threads,
            kill_all_thread,
        )

        _prune_dead_threads()
        killed_threads = kill_all_thread()
        if killed_threads > 0:
            print(f'🔪 Killed {killed_threads} tracked threads')
    except ImportError:
        # Fallback: just report stray threads
        for t in threading.enumerate():
            if t is threading.current_thread():
                continue
            if not t.daemon:
                print(
                    f'⚠️ Thread {t.name} is still running '
                    f'(cannot be force-killed).'
                )

    print(
        f'✅ Cleaned up {killed_processes} tracked processes and '
        f'child processes (kernel untouched).'
    )
|
|
845
|
+
|
|
846
|
+
|
|
847
|
+
__all__ = [
|
|
848
|
+
# Types
|
|
849
|
+
'ErrorHandlerType',
|
|
850
|
+
'ErrorStats',
|
|
851
|
+
# Process tracking globals
|
|
852
|
+
'SPEEDY_RUNNING_PROCESSES',
|
|
853
|
+
'_SPEEDY_PROCESSES_LOCK',
|
|
854
|
+
# Error utilities
|
|
855
|
+
'_should_skip_frame',
|
|
856
|
+
'_format_traceback_lines',
|
|
857
|
+
'_extract_frames_from_traceback',
|
|
858
|
+
'_extract_frames_from_ray_error',
|
|
859
|
+
'_display_formatted_error_and_exit',
|
|
860
|
+
'_exit_on_worker_error',
|
|
861
|
+
'_exit_on_ray_error',
|
|
862
|
+
# Process tracking
|
|
863
|
+
'_prune_dead_processes',
|
|
864
|
+
'_track_processes',
|
|
865
|
+
'_track_multiprocessing_processes',
|
|
866
|
+
'_track_ray_processes',
|
|
867
|
+
# Log gating
|
|
868
|
+
'_LOG_GATE_CACHE',
|
|
869
|
+
'_should_allow_worker_logs',
|
|
870
|
+
'_cleanup_log_gate',
|
|
871
|
+
'_PrefixedWriter',
|
|
872
|
+
'_call_with_log_control',
|
|
873
|
+
'create_log_gate_path',
|
|
874
|
+
# Cache helpers
|
|
875
|
+
'_build_cache_dir',
|
|
876
|
+
'wrap_dump',
|
|
877
|
+
# Cleanup
|
|
878
|
+
'cleanup_phantom_workers',
|
|
879
|
+
]
|