speedy-utils 1.1.40__py3-none-any.whl → 1.1.43__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- llm_utils/__init__.py +2 -0
- llm_utils/llm_ray.py +370 -0
- llm_utils/lm/llm.py +36 -29
- speedy_utils/__init__.py +10 -0
- speedy_utils/common/utils_io.py +3 -1
- speedy_utils/multi_worker/__init__.py +12 -0
- speedy_utils/multi_worker/dataset_ray.py +303 -0
- speedy_utils/multi_worker/parallel_gpu_pool.py +178 -0
- speedy_utils/multi_worker/process.py +989 -86
- speedy_utils/multi_worker/progress.py +140 -0
- speedy_utils/multi_worker/thread.py +202 -42
- speedy_utils/scripts/mpython.py +49 -4
- {speedy_utils-1.1.40.dist-info → speedy_utils-1.1.43.dist-info}/METADATA +5 -3
- {speedy_utils-1.1.40.dist-info → speedy_utils-1.1.43.dist-info}/RECORD +16 -12
- {speedy_utils-1.1.40.dist-info → speedy_utils-1.1.43.dist-info}/WHEEL +0 -0
- {speedy_utils-1.1.40.dist-info → speedy_utils-1.1.43.dist-info}/entry_points.txt +0 -0
speedy_utils/multi_worker/process.py
@@ -1,15 +1,495 @@
+import warnings
+import os
+# Suppress Ray FutureWarnings before any imports
+warnings.filterwarnings("ignore", category=FutureWarning, module="ray.*")
+warnings.filterwarnings("ignore", message=".*pynvml.*deprecated.*", category=FutureWarning)
+
+# Set environment variables before Ray is imported anywhere
+os.environ["RAY_ACCEL_ENV_VAR_OVERRI" \
+    "DE_ON_ZERO"] = "0"
+os.environ["RAY_DEDUP_LOGS"] = "1"
+os.environ["RAY_LOG_TO_STDERR"] = "0"
+os.environ["RAY_IGNORE_UNHANDLED_ERRORS"] = "1"
+
 from ..__imports import *
+import tempfile
+import inspect
+import linecache
+import traceback as tb_module
+from .progress import create_progress_tracker, ProgressPoller, get_ray_progress_actor

+# Import thread tracking functions if available
+try:
+    from .thread import _prune_dead_threads, _track_executor_threads
+except ImportError:
+    _prune_dead_threads = None # type: ignore[assignment]
+    _track_executor_threads = None # type: ignore[assignment]

-SPEEDY_RUNNING_PROCESSES: list[psutil.Process] = []
-_SPEEDY_PROCESSES_LOCK = threading.Lock()

+# ─── error handler types ────────────────────────────────────────
+ErrorHandlerType = Literal['raise', 'ignore', 'log']

-
-
-
-
-
+
+class ErrorStats:
+    """Thread-safe error statistics tracker."""
+
+    def __init__(
+        self,
+        func_name: str,
+        max_error_files: int = 100,
+        write_logs: bool = True
+    ):
+        self._lock = threading.Lock()
+        self._success_count = 0
+        self._error_count = 0
+        self._first_error_shown = False
+        self._max_error_files = max_error_files
+        self._write_logs = write_logs
+        self._error_log_dir = self._get_error_log_dir(func_name)
+        if write_logs:
+            self._error_log_dir.mkdir(parents=True, exist_ok=True)
+
+    @staticmethod
+    def _get_error_log_dir(func_name: str) -> Path:
+        """Generate unique error log directory with run counter."""
+        base_dir = Path('.cache/speedy_utils/error_logs')
+        base_dir.mkdir(parents=True, exist_ok=True)
+
+        # Find the next run counter
+        counter = 1
+        existing = list(base_dir.glob(f'{func_name}_run_*'))
+        if existing:
+            counters = []
+            for p in existing:
+                try:
+                    parts = p.name.split('_run_')
+                    if len(parts) == 2:
+                        counters.append(int(parts[1]))
+                except (ValueError, IndexError):
+                    pass
+            if counters:
+                counter = max(counters) + 1
+
+        return base_dir / f'{func_name}_run_{counter}'
+
+    def record_success(self) -> None:
+        with self._lock:
+            self._success_count += 1
+
+    def record_error(
+        self,
+        idx: int,
+        error: Exception,
+        input_value: Any,
+        func_name: str,
+    ) -> str | None:
+        """
+        Record an error and write to log file.
+        Returns the log file path if written, None otherwise.
+        """
+        with self._lock:
+            self._error_count += 1
+            should_show_first = not self._first_error_shown
+            if should_show_first:
+                self._first_error_shown = True
+            should_write = (
+                self._write_logs and self._error_count <= self._max_error_files
+            )
+
+        log_path = None
+        if should_write:
+            log_path = self._write_error_log(
+                idx, error, input_value, func_name
+            )
+
+        if should_show_first:
+            self._print_first_error(error, input_value, func_name, log_path)
+
+        return log_path
+
+    def _write_error_log(
+        self,
+        idx: int,
+        error: Exception,
+        input_value: Any,
+        func_name: str,
+    ) -> str:
+        """Write error details to a log file."""
+        log_path = self._error_log_dir / f'{idx}.log'
+
+        # Format traceback
+        tb_lines = self._format_traceback(error)
+
+        content = []
+        content.append(f'{"=" * 60}')
+        content.append(f'Error at index: {idx}')
+        content.append(f'Function: {func_name}')
+        content.append(f'Error Type: {type(error).__name__}')
+        content.append(f'Error Message: {error}')
+        content.append(f'{"=" * 60}')
+        content.append('')
+        content.append('Input:')
+        content.append('-' * 40)
+        try:
+            content.append(repr(input_value))
+        except Exception:
+            content.append('<unable to repr input>')
+        content.append('')
+        content.append('Traceback:')
+        content.append('-' * 40)
+        content.extend(tb_lines)
+
+        with open(log_path, 'w') as f:
+            f.write('\n'.join(content))
+
+        return str(log_path)
+
+    def _format_traceback(self, error: Exception) -> list[str]:
+        """Format traceback with context lines like Rich panel."""
+        lines = []
+        frames = _extract_frames_from_traceback(error)
+
+        for filepath, lineno, funcname, frame_locals in frames:
+            lines.append(f'│ {filepath}:{lineno} in {funcname} │')
+            lines.append('│' + ' ' * 70 + '│')
+
+            # Get context lines
+            context_size = 3
+            start_line = max(1, lineno - context_size)
+            end_line = lineno + context_size + 1
+
+            for line_num in range(start_line, end_line):
+                line_text = linecache.getline(filepath, line_num).rstrip()
+                if line_text:
+                    num_str = str(line_num).rjust(4)
+                    if line_num == lineno:
+                        lines.append(f'│ {num_str} ❱ {line_text}')
+                    else:
+                        lines.append(f'│ {num_str} │ {line_text}')
+            lines.append('')
+
+        return lines
+
+    def _print_first_error(
+        self,
+        error: Exception,
+        input_value: Any,
+        func_name: str,
+        log_path: str | None,
+    ) -> None:
+        """Print the first error to screen with Rich formatting."""
+        try:
+            from rich.console import Console
+            from rich.panel import Panel
+            console = Console(stderr=True)
+
+            tb_lines = self._format_traceback(error)
+
+            console.print()
+            console.print(
+                Panel(
+                    '\n'.join(tb_lines),
+                    title='[bold red]First Error (continuing with remaining items)[/bold red]',
+                    border_style='yellow',
+                    expand=False,
+                )
+            )
+            console.print(
+                f'[bold red]{type(error).__name__}[/bold red]: {error}'
+            )
+            if log_path:
+                console.print(f'[dim]Error log: {log_path}[/dim]')
+            console.print()
+        except ImportError:
+            # Fallback to plain print
+            print(f'\n--- First Error ---', file=sys.stderr)
+            print(f'{type(error).__name__}: {error}', file=sys.stderr)
+            if log_path:
+                print(f'Error log: {log_path}', file=sys.stderr)
+            print('', file=sys.stderr)
+
+    @property
+    def success_count(self) -> int:
+        with self._lock:
+            return self._success_count
+
+    @property
+    def error_count(self) -> int:
+        with self._lock:
+            return self._error_count
+
+    def get_postfix_dict(self) -> dict[str, int]:
+        """Get dict for pbar postfix."""
+        with self._lock:
+            return {'ok': self._success_count, 'err': self._error_count}
+
+
+def _should_skip_frame(filepath: str) -> bool:
+    """Check if a frame should be filtered from traceback display."""
+    skip_patterns = [
+        'ray/_private',
+        'ray/worker',
+        'site-packages/ray',
+        'speedy_utils/multi_worker',
+        'concurrent/futures',
+        'multiprocessing/',
+        'fastcore/parallel',
+        'fastcore/foundation',
+        'fastcore/basics',
+        'site-packages/fastcore',
+        '/threading.py',
+        '/concurrent/',
+    ]
+    return any(skip in filepath for skip in skip_patterns)
+
+
+def _should_show_local(name: str, value: object) -> bool:
+    """Check if a local variable should be displayed in traceback."""
+    import types
+
+    # Skip dunder variables
+    if name.startswith('__') and name.endswith('__'):
+        return False
+
+    # Skip modules
+    if isinstance(value, types.ModuleType):
+        return False
+
+    # Skip type objects and classes
+    if isinstance(value, type):
+        return False
+
+    # Skip functions and methods
+    if isinstance(value, (types.FunctionType, types.MethodType, types.BuiltinFunctionType)):
+        return False
+
+    # Skip common typing aliases
+    value_str = str(value)
+    if value_str.startswith('typing.'):
+        return False
+
+    # Skip large objects that would clutter output
+    if value_str.startswith('<') and any(x in value_str for x in ['module', 'function', 'method', 'built-in']):
+        return False
+
+    return True
+
+
+def _format_locals(frame_locals: dict) -> list[str]:
+    """Format local variables for display, filtering out noisy imports."""
+    from rich.pretty import Pretty
+    from rich.console import Console
+    from io import StringIO
+
+    # Filter locals
+    clean_locals = {k: v for k, v in frame_locals.items() if _should_show_local(k, v)}
+
+    if not clean_locals:
+        return []
+
+    lines = []
+    lines.append('[dim]╭─ locals ─╮[/dim]')
+
+    # Format each local variable
+    for name, value in clean_locals.items():
+        # Use Rich's Pretty for nice formatting
+        try:
+            console = Console(file=StringIO(), width=60)
+            console.print(Pretty(value), end='')
+            value_str = console.file.getvalue().strip()
+            # Limit length
+            if len(value_str) > 100:
+                value_str = value_str[:97] + '...'
+        except Exception:
+            value_str = repr(value)
+            if len(value_str) > 100:
+                value_str = value_str[:97] + '...'
+
+        lines.append(f'[dim]│[/dim] {name} = {value_str}')
+
+    lines.append('[dim]╰──────────╯[/dim]')
+    return lines
+
+
+def _format_frame_with_context(filepath: str, lineno: int, funcname: str, frame_locals: dict | None = None) -> list[str]:
+    """Format a single frame with context lines and optional locals."""
+    lines = []
+    # Frame header
+    lines.append(
+        f'[cyan]{filepath}[/cyan]:[yellow]{lineno}[/yellow] '
+        f'in [green]{funcname}[/green]'
+    )
+    lines.append('')
+
+    # Get context lines
+    context_size = 3
+    start_line = max(1, lineno - context_size)
+    end_line = lineno + context_size + 1
+
+    for line_num in range(start_line, end_line):
+        import linecache
+        line_text = linecache.getline(filepath, line_num).rstrip()
+        if line_text:
+            num_str = str(line_num).rjust(4)
+            if line_num == lineno:
+                lines.append(f'[dim]{num_str}[/dim] [red]❱[/red] {line_text}')
+            else:
+                lines.append(f'[dim]{num_str} │[/dim] {line_text}')
+
+    # Add locals if available
+    if frame_locals:
+        locals_lines = _format_locals(frame_locals)
+        if locals_lines:
+            lines.append('')
+            lines.extend(locals_lines)
+
+    lines.append('')
+    return lines
+
+
+def _display_formatted_error(
+    exc_type_name: str,
+    exc_msg: str,
+    frames: list[tuple[str, int, str, dict]],
+    caller_info: dict | None,
+    backend: str,
+    pbar=None,
+) -> None:
+
+    # Suppress additional error logs
+    os.environ['RAY_IGNORE_UNHANDLED_ERRORS'] = '1'
+
+    # Close progress bar cleanly if provided
+    if pbar is not None:
+        pbar.close()
+
+    from rich.console import Console
+    from rich.panel import Panel
+    console = Console(stderr=True)
+
+    if frames or caller_info:
+        display_lines = []
+
+        # Add caller frame first if available (no locals for caller)
+        if caller_info:
+            display_lines.extend(_format_frame_with_context(
+                caller_info['filename'],
+                caller_info['lineno'],
+                caller_info['function'],
+                None  # Don't show locals for caller frame
+            ))
+
+        # Add error frames with locals
+        for filepath, lineno, funcname, frame_locals in frames:
+            display_lines.extend(_format_frame_with_context(
+                filepath, lineno, funcname, frame_locals
+            ))
+
+        # Display the traceback
+        console.print()
+        console.print(
+            Panel(
+                '\n'.join(display_lines),
+                title=f'[bold red]Traceback (most recent call last) [{backend}][/bold red]',
+                border_style='red',
+                expand=False,
+            )
+        )
+        console.print(f'[bold red]{exc_type_name}[/bold red]: {exc_msg}')
+        console.print()
+    else:
+        # No frames found, minimal output
+        console.print()
+        console.print(f'[bold red]{exc_type_name}[/bold red]: {exc_msg}')
+        console.print()
+
+    # Ensure output is flushed
+    sys.stderr.flush()
+    sys.stdout.flush()
+    sys.exit(1)
+
+
+def _extract_frames_from_traceback(error: Exception) -> list[tuple[str, int, str, dict]]:
+    """Extract user frames from exception traceback object with locals."""
+    frames = []
+    if hasattr(error, '__traceback__') and error.__traceback__ is not None:
+        tb = error.__traceback__
+        while tb is not None:
+            frame = tb.tb_frame
+            filename = frame.f_code.co_filename
+            lineno = tb.tb_lineno
+            funcname = frame.f_code.co_name
+
+            if not _should_skip_frame(filename):
+                # Get local variables from the frame
+                frame_locals = dict(frame.f_locals)
+                frames.append((filename, lineno, funcname, frame_locals))
+
+            tb = tb.tb_next
+    return frames
+
+
+def _extract_frames_from_ray_error(ray_task_error: Exception) -> list[tuple[str, int, str, dict]]:
+    """Extract user frames from Ray's string traceback representation."""
+    frames = []
+    error_str = str(ray_task_error)
+    lines = error_str.split('\n')
+
+    import re
+    for i, line in enumerate(lines):
+        # Match: File "path", line N, in func
+        file_match = re.match(r'\s*File "([^"]+)", line (\d+), in (.+)', line)
+        if file_match:
+            filepath, lineno, funcname = file_match.groups()
+            if not _should_skip_frame(filepath):
+                # Ray doesn't preserve locals, so use empty dict
+                frames.append((filepath, int(lineno), funcname, {}))
+
+    return frames
+
+
+def _reraise_worker_error(error: Exception, pbar=None, caller_info=None, backend: str = 'unknown') -> None:
+    """
+    Re-raise the original exception from a worker error with clean traceback.
+    Works for multiprocessing, threadpool, and other backends with real tracebacks.
+    """
+    frames = _extract_frames_from_traceback(error)
+    _display_formatted_error(
+        exc_type_name=type(error).__name__,
+        exc_msg=str(error),
+        frames=frames,
+        caller_info=caller_info,
+        backend=backend,
+        pbar=pbar,
+    )
+
+
+def _reraise_ray_error(ray_task_error: Exception, pbar=None, caller_info=None) -> None:
+    """
+    Re-raise the original exception from a RayTaskError with clean traceback.
+    Parses Ray's string traceback and displays with full context.
+    """
+    # Get the exception info
+    cause = ray_task_error.cause if hasattr(ray_task_error, 'cause') else None
+    if cause is None:
+        cause = ray_task_error.__cause__
+
+    exc_type_name = type(cause).__name__ if cause else 'Error'
+    exc_msg = str(cause) if cause else str(ray_task_error)
+
+    frames = _extract_frames_from_ray_error(ray_task_error)
+    _display_formatted_error(
+        exc_type_name=exc_type_name,
+        exc_msg=exc_msg,
+        frames=frames,
+        caller_info=caller_info,
+        backend='ray',
+        pbar=pbar,
+    )
+
+
+SPEEDY_RUNNING_PROCESSES: list[psutil.Process] = []
+_SPEEDY_PROCESSES_LOCK = threading.Lock()

 def _prune_dead_processes() -> None:
     """Remove dead processes from tracking list."""
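The traceback helpers added in this hunk walk `error.__traceback__` frame by frame instead of calling `traceback.format_exc()`, which is what lets them drop framework frames and capture frame locals. A minimal, self-contained sketch of that pattern (the `user_frames` helper and its simplified filter are illustrative, not part of the package):

    # Sketch of the frame-walking pattern used by _extract_frames_from_traceback.
    def user_frames(error: Exception) -> list[tuple[str, int, str]]:
        frames = []
        tb = error.__traceback__
        while tb is not None:
            code = tb.tb_frame.f_code
            # Keep only frames that look like user code (cf. _should_skip_frame).
            if 'site-packages' not in code.co_filename:
                frames.append((code.co_filename, tb.tb_lineno, code.co_name))
            tb = tb.tb_next
        return frames

    try:
        1 / 0
    except ZeroDivisionError as e:
        print(user_frames(e))  # e.g. [('script.py', 14, '<module>')]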
@@ -89,7 +569,7 @@ def _build_cache_dir(func: Callable, items: list[Any]) -> Path:
     path = Path('.cache') / run_id
     path.mkdir(parents=True, exist_ok=True)
     return path
-
+_DUMP_INTERMEDIATE_THREADS = []
 def wrap_dump(func: Callable, cache_dir: Path | None, dump_in_thread: bool = True):
     """Wrap a function so results are dumped to .pkl when cache_dir is set."""
     if cache_dir is None:
@@ -108,7 +588,7 @@ def wrap_dump(func: Callable, cache_dir: Path | None, dump_in_thread: bool = True):

         if dump_in_thread:
             thread = threading.Thread(target=save)
-
+            _DUMP_INTERMEDIATE_THREADS.append(thread)
             # count thread
             # print(f'Thread count: {threading.active_count()}')
             while threading.active_count() > 16:
@@ -121,35 +601,213 @@ def wrap_dump(func: Callable, cache_dir: Path | None, dump_in_thread: bool = True):
     return wrapped


+_LOG_GATE_CACHE: dict[str, bool] = {}
+
+
+def _should_allow_worker_logs(mode: Literal['all', 'zero', 'first'], gate_path: Path | None) -> bool:
+    """Determine if current worker should emit logs for the given mode."""
+    if mode == 'all':
+        return True
+    if mode == 'zero':
+        return False
+    if mode == 'first':
+        if gate_path is None:
+            return True
+        key = str(gate_path)
+        cached = _LOG_GATE_CACHE.get(key)
+        if cached is not None:
+            return cached
+        gate_path.parent.mkdir(parents=True, exist_ok=True)
+        try:
+            fd = os.open(key, os.O_CREAT | os.O_EXCL | os.O_WRONLY)
+        except FileExistsError:
+            allowed = False
+        else:
+            os.close(fd)
+            allowed = True
+        _LOG_GATE_CACHE[key] = allowed
+        return allowed
+    raise ValueError(f'Unsupported log mode: {mode!r}')
+
+
+def _cleanup_log_gate(gate_path: Path | None):
+    if gate_path is None:
+        return
+    try:
+        gate_path.unlink(missing_ok=True)
+    except OSError:
+        pass
+
+
+@contextlib.contextmanager
+def _patch_fastcore_progress_bar(*, leave: bool = True):
+    """Temporarily force fastcore.progress_bar to keep the bar on screen."""
+    try:
+        import fastcore.parallel as _fp
+    except ImportError:
+        yield False
+        return
+
+    orig = getattr(_fp, 'progress_bar', None)
+    if orig is None:
+        yield False
+        return
+
+    def _wrapped(*args, **kwargs):
+        kwargs.setdefault('leave', leave)
+        return orig(*args, **kwargs)
+
+    _fp.progress_bar = _wrapped
+    try:
+        yield True
+    finally:
+        _fp.progress_bar = orig
+
+
+class _PrefixedWriter:
+    """Stream wrapper that prefixes each line with worker id."""
+
+    def __init__(self, stream, prefix: str):
+        self._stream = stream
+        self._prefix = prefix
+        self._at_line_start = True
+
+    def write(self, s):
+        if not s:
+            return 0
+        total = 0
+        for chunk in s.splitlines(True):
+            if self._at_line_start:
+                self._stream.write(self._prefix)
+                total += len(self._prefix)
+            self._stream.write(chunk)
+            total += len(chunk)
+            self._at_line_start = chunk.endswith('\n')
+        return total
+
+    def flush(self):
+        self._stream.flush()
+
+
+def _call_with_log_control(
+    func: Callable,
+    x: Any,
+    func_kwargs: dict[str, Any],
+    log_mode: Literal['all', 'zero', 'first'],
+    gate_path: Path | None,
+):
+    """Call a function, silencing stdout/stderr based on log mode."""
+    allow_logs = _should_allow_worker_logs(log_mode, gate_path)
+    if allow_logs:
+        prefix = f"[worker-{os.getpid()}] "
+        # Route worker logs to stderr to reduce clobbering tqdm/progress output on stdout
+        out = _PrefixedWriter(sys.stderr, prefix)
+        err = out
+        with contextlib.redirect_stdout(out), contextlib.redirect_stderr(err):
+            return func(x, **func_kwargs)
+    with open(os.devnull, 'w') as devnull, contextlib.redirect_stdout(devnull), contextlib.redirect_stderr(devnull):
+        return func(x, **func_kwargs)
+
+
 # ─── ray management ─────────────────────────────────────────

 RAY_WORKER = None


-def ensure_ray(workers: int, pbar: tqdm | None = None):
-    """
+def ensure_ray(workers: int | None, pbar: tqdm | None = None, ray_metrics_port: int | None = None):
+    """
+    Initialize or reinitialize Ray safely for both local and cluster environments.
+
+    1. Tries to connect to an existing cluster (address='auto') first.
+    2. If no cluster is found, starts a local Ray instance with 'workers' CPUs.
+    """
     import ray as _ray_module
     import logging

     global RAY_WORKER
-
-    if
-
-
+    requested_workers = workers
+    if workers is None:
+        workers = os.cpu_count() or 1
+
+    if ray_metrics_port is not None:
+        os.environ['RAY_metrics_export_port'] = str(ray_metrics_port)
+
+    allow_restart = os.environ.get("RESTART_RAY", "0").lower() in ("1", "true")
+    is_cluster_env = "RAY_ADDRESS" in os.environ or os.environ.get("RAY_CLUSTER") == "1"
+
+    # 1. Handle existing session
+    if _ray_module.is_initialized():
+        if not allow_restart:
+            if pbar:
+                pbar.set_postfix_str("Using existing Ray session")
+            return
+
+        # Avoid shutting down shared cluster sessions.
+        if is_cluster_env:
+            if pbar:
+                pbar.set_postfix_str("Cluster active: skipping restart to protect connection")
+            return
+
+        # Local restart: only if worker count changed
+        if workers != RAY_WORKER:
+            if pbar:
+                pbar.set_postfix_str(f'Restarting local Ray with {workers} workers')
             _ray_module.shutdown()
-
+        else:
+            return
+
+    # 2. Initialization logic
+    t0 = time.time()
+
+    # Try to connect to existing cluster FIRST (address="auto")
+    try:
+        if pbar:
+            pbar.set_postfix_str("Searching for Ray cluster...")
+
+        # MUST NOT pass num_cpus/num_gpus here to avoid ValueError on existing clusters
+        _ray_module.init(
+            address="auto",
+            ignore_reinit_error=True,
+            logging_level=logging.ERROR,
+            log_to_driver=False
+        )
+
+        if pbar:
+            resources = _ray_module.cluster_resources()
+            cpus = resources.get("CPU", 0)
+            pbar.set_postfix_str(f"Connected to Ray Cluster ({int(cpus)} CPUs)")
+
+    except Exception:
+        # 3. Fallback: Start a local Ray session
+        if pbar:
+            pbar.set_postfix_str(f"No cluster found. Starting local Ray ({workers} CPUs)...")
+
         _ray_module.init(
             num_cpus=workers,
             ignore_reinit_error=True,
             logging_level=logging.ERROR,
             log_to_driver=False,
         )
-
-    _track_ray_processes() # Track Ray worker processes
+
     if pbar:
-
-
+        took = time.time() - t0
+        pbar.set_postfix_str(f'ray.init local {workers} took {took:.2f}s')
+
+    _track_ray_processes()
+
+    if requested_workers is None:
+        try:
+            resources = _ray_module.cluster_resources()
+            total_cpus = int(resources.get("CPU", 0))
+            if total_cpus > 0:
+                workers = total_cpus
+        except Exception:
+            pass
+
+    RAY_WORKER = workers

+
+# TODO: make smarter backend selection, when shared_kwargs is used, and backend != 'ray', do not raise error but change to ray and warning user about this
 def multi_process(
     func: Callable[[Any], Any],
     items: Iterable[Any] | None = None,
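Two patterns in this hunk are worth isolating. The `log_worker='first'` gate elects exactly one logging worker via an atomic file create: `os.open` with `O_CREAT | O_EXCL` succeeds in only one process. A self-contained sketch of that election (the gate file name is arbitrary here):

    import os
    import tempfile

    def i_am_first(gate_path: str) -> bool:
        # O_CREAT | O_EXCL makes creation atomic: exactly one caller succeeds.
        try:
            fd = os.open(gate_path, os.O_CREAT | os.O_EXCL | os.O_WRONLY)
        except FileExistsError:
            return False
        os.close(fd)
        return True

    gate = os.path.join(tempfile.gettempdir(), 'demo.gate')
    print(i_am_first(gate))  # True: this call created the gate file
    print(i_am_first(gate))  # False: the gate already exists
    os.remove(gate)

The rewritten `ensure_ray` has the same defensive shape: it joins an existing cluster first with `ray.init(address="auto")` and only falls back to a local `ray.init(num_cpus=...)` when no cluster answers.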
@@ -158,11 +816,16 @@ def multi_process(
     workers: int | None = None,
     lazy_output: bool = False,
     progress: bool = True,
-
-    backend: Literal['seq', 'ray', 'mp', 'threadpool', 'safe'] = 'mp',
+    backend: Literal['seq', 'ray', 'mp', 'safe'] = 'mp',
     desc: str | None = None,
     shared_kwargs: list[str] | None = None,
     dump_in_thread: bool = True,
+    ray_metrics_port: int | None = None,
+    log_worker: Literal['zero', 'first', 'all'] = 'first',
+    total_items: int | None = None,
+    poll_interval: float = 0.3,
+    error_handler: ErrorHandlerType = 'raise',
+    max_error_files: int = 100,
     **func_kwargs: Any,
 ) -> list[Any]:
     """
@@ -171,19 +834,44 @@ def multi_process(
     backend:
     - "seq": run sequentially
     - "ray": run in parallel with Ray
-    - "mp": run in parallel with
-    - "threadpool": run in parallel with thread pool
+    - "mp": run in parallel with thread pool (uses ThreadPoolExecutor)
     - "safe": run in parallel with thread pool (explicitly safe for tests)

     shared_kwargs:
-    - Optional list of kwarg names that should be shared via Ray's
+    - Optional list of kwarg names that should be shared via Ray's
+      zero-copy object store
     - Only works with Ray backend
-    - Useful for large objects (e.g., models, datasets)
-    - Example: shared_kwargs=['model', 'tokenizer']
+    - Useful for large objects (e.g., models, datasets)
+    - Example: shared_kwargs=['model', 'tokenizer']

     dump_in_thread:
     - Whether to dump results to disk in a separate thread (default: True)
-    - If False, dumping is done synchronously
+    - If False, dumping is done synchronously
+
+    ray_metrics_port:
+    - Optional port for Ray metrics export (Ray backend only)
+
+    log_worker:
+    - Control worker stdout/stderr noise
+    - 'first': only first worker emits logs (default)
+    - 'all': allow worker prints
+    - 'zero': silence all worker output
+
+    total_items:
+    - Optional item-level total for progress tracking (Ray backend only)
+
+    poll_interval:
+    - Poll interval in seconds for progress actor updates (Ray only)
+
+    error_handler:
+    - 'raise': raise exception on first error (default)
+    - 'ignore': continue processing, return None for failed items
+    - 'log': same as ignore, but logs errors to files
+
+    max_error_files:
+    - Maximum number of error log files to write (default: 100)
+    - Error logs are written to .cache/speedy_utils/error_logs/{idx}.log
+    - First error is always printed to screen with the log file path

     If lazy_output=True, every result is saved to .pkl and
     the returned list contains file paths.
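A usage sketch based on the docstring above (the failing `parse` function is invented for illustration, and the top-level `multi_process` import assumes the package's usual re-export):

    from speedy_utils import multi_process

    def parse(x):
        if x == 3:
            raise ValueError('bad item')
        return x * 2

    # error_handler='log' keeps going on failures, returns None for failed
    # items, and writes per-item logs under .cache/speedy_utils/error_logs/.
    results = multi_process(parse, items=range(6), backend='mp',
                            error_handler='log', log_worker='first')
    print(results)  # [0, 2, 4, None, 8, 10]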
@@ -219,11 +907,13 @@ def multi_process(
             f"Valid parameters: {valid_params}"
         )

-
-
-
-
-
+    # Prefer Ray backend when shared kwargs are requested
+    if shared_kwargs and backend != 'ray':
+        warnings.warn(
+            "shared_kwargs only supported with 'ray' backend, switching backend to 'ray'",
+            UserWarning,
+        )
+        backend = 'ray'

     # unify items
     # unify items and coerce to concrete list so we can use len() and
@@ -235,35 +925,111 @@ def multi_process(
     if items is None:
         raise ValueError("'items' or 'inputs' must be provided")

-    if workers is None:
+    if workers is None and backend != 'ray':
         workers = os.cpu_count() or 1

     # build cache dir + wrap func
     cache_dir = _build_cache_dir(func, items) if lazy_output else None
     f_wrapped = wrap_dump(func, cache_dir, dump_in_thread)

+    log_gate_path: Path | None = None
+    if log_worker == 'first':
+        log_gate_path = Path(tempfile.gettempdir()) / f'speedy_utils_log_gate_{os.getpid()}_{uuid.uuid4().hex}.gate'
+    elif log_worker not in ('zero', 'all'):
+        raise ValueError(f'Unsupported log_worker: {log_worker!r}')
+
     total = len(items)
     if desc:
         desc = desc.strip() + f'[{backend}]'
     else:
         desc = f'Multi-process [{backend}]'
-
-
-
-
-
-
-
+
+    # Initialize error stats for error handling
+    func_name = getattr(func, '__name__', repr(func))
+    error_stats = ErrorStats(
+        func_name=func_name,
+        max_error_files=max_error_files,
+        write_logs=error_handler == 'log'
+    )
+
+    def _update_pbar_postfix(pbar: tqdm) -> None:
+        """Update pbar with success/error counts."""
+        postfix = error_stats.get_postfix_dict()
+        pbar.set_postfix(postfix)
+
+    def _wrap_with_error_handler(
+        f: Callable,
+        idx: int,
+        input_value: Any,
+        error_stats_ref: ErrorStats,
+        handler: ErrorHandlerType,
+    ) -> Callable:
+        """Wrap function to handle errors based on error_handler setting."""
+        @functools.wraps(f)
+        def wrapper(*args, **kwargs):
+            try:
+                result = f(*args, **kwargs)
+                error_stats_ref.record_success()
+                return result
+            except Exception as e:
+                if handler == 'raise':
+                    raise
+                error_stats_ref.record_error(idx, e, input_value, func_name)
+                return None
+        return wrapper
+
+    # ---- sequential backend ----
+    if backend == 'seq':
+        results: list[Any] = []
+        with tqdm(total=total, desc=desc, disable=not progress, file=sys.stdout) as pbar:
+            for idx, x in enumerate(items):
+                try:
+                    result = _call_with_log_control(
+                        f_wrapped,
+                        x,
+                        func_kwargs,
+                        log_worker,
+                        log_gate_path,
+                    )
+                    error_stats.record_success()
+                    results.append(result)
+                except Exception as e:
+                    if error_handler == 'raise':
+                        raise
+                    error_stats.record_error(idx, e, x, func_name)
+                    results.append(None)
                 pbar.update(1)
-
+                _update_pbar_postfix(pbar)
+        _cleanup_log_gate(log_gate_path)
+        return results

-
-
-
+    # ---- ray backend ----
+    if backend == 'ray':
+        import ray as _ray_module

-
+        # Capture caller frame for better error reporting
+        caller_frame = inspect.currentframe()
+        caller_info = None
+        if caller_frame and caller_frame.f_back:
+            caller_info = {
+                'filename': caller_frame.f_back.f_code.co_filename,
+                'lineno': caller_frame.f_back.f_lineno,
+                'function': caller_frame.f_back.f_code.co_name,
+            }
+
+        results = []
+        gate_path_str = str(log_gate_path) if log_gate_path else None
+        with tqdm(total=total, desc=desc, disable=not progress, file=sys.stdout) as pbar:
+            ensure_ray(workers, pbar, ray_metrics_port)
             shared_refs = {}
             regular_kwargs = {}
+
+            # Create progress actor for item-level tracking if total_items specified
+            progress_actor = None
+            progress_poller = None
+            if total_items is not None:
+                progress_actor = create_progress_tracker(total_items, desc or "Items")
+                shared_refs['progress_actor'] = progress_actor

             if shared_kwargs:
                 for kw in shared_kwargs:
@@ -283,60 +1049,193 @@ def multi_process(
             def _task(x, shared_refs_dict, regular_kwargs_dict):
                 # Dereference shared objects (zero-copy for numpy arrays)
                 import ray as _ray_in_task
-
-
+                gate = Path(gate_path_str) if gate_path_str else None
+                dereferenced = {}
+                for k, v in shared_refs_dict.items():
+                    if k == 'progress_actor':
+                        dereferenced[k] = v
+                    else:
+                        dereferenced[k] = _ray_in_task.get(v)
                 all_kwargs = {**dereferenced, **regular_kwargs_dict}
-                return
+                return _call_with_log_control(
+                    f_wrapped,
+                    x,
+                    all_kwargs,
+                    log_worker,
+                    gate,
+                )

             refs = [
                 _task.remote(x, shared_refs, regular_kwargs) for x in items
             ]

-            results = []
             t_start = time.time()
-
-
-            pbar.
+
+            if progress_actor is not None:
+                pbar.total = total_items
+                pbar.refresh()
+                progress_poller = ProgressPoller(progress_actor, pbar, poll_interval)
+                progress_poller.start()
+
+            for idx, r in enumerate(refs):
+                try:
+                    result = _ray_module.get(r)
+                    error_stats.record_success()
+                    results.append(result)
+                except _ray_module.exceptions.RayTaskError as e:
+                    if error_handler == 'raise':
+                        if progress_poller is not None:
+                            progress_poller.stop()
+                        _reraise_ray_error(e, pbar, caller_info)
+                    # Extract original error from RayTaskError
+                    cause = e.cause if hasattr(e, 'cause') else e.__cause__
+                    original_error = cause if cause else e
+                    error_stats.record_error(idx, original_error, items[idx], func_name)
+                    results.append(None)
+
+                if progress_actor is None:
+                    pbar.update(1)
+                    _update_pbar_postfix(pbar)
+
+            if progress_poller is not None:
+                progress_poller.stop()
+
             t_end = time.time()
-
-
+            item_desc = f"{total_items:,} items" if total_items else f"{total} tasks"
+            print(f"Ray processing took {t_end - t_start:.2f}s for {item_desc}")
+        _cleanup_log_gate(log_gate_path)
+        return results

-
-
-
-
-
-
-
-
-
-
-
+    # ---- fastcore/thread backend (mp) ----
+    if backend == 'mp':
+        import concurrent.futures
+
+        # Capture caller frame for better error reporting
+        caller_frame = inspect.currentframe()
+        caller_info = None
+        if caller_frame and caller_frame.f_back:
+            caller_info = {
+                'filename': caller_frame.f_back.f_code.co_filename,
+                'lineno': caller_frame.f_back.f_lineno,
+                'function': caller_frame.f_back.f_code.co_name,
+            }
+
+        def worker_func(x):
+            return _call_with_log_control(
+                f_wrapped,
+                x,
+                func_kwargs,
+                log_worker,
+                log_gate_path,
             )
-
-
-
-    import concurrent.futures
-
-    # Import thread tracking from thread module
+
+        results: list[Any] = [None] * total
+        with tqdm(total=total, desc=desc, disable=not progress, file=sys.stdout) as pbar:
             try:
                 from .thread import _prune_dead_threads, _track_executor_threads
-
-    with concurrent.futures.ThreadPoolExecutor(
-        max_workers=workers
-    ) as executor:
-        _track_executor_threads(executor) # Track threads
-        results = list(executor.map(f_wrapped, items))
-    _prune_dead_threads() # Clean up dead threads
+                has_thread_tracking = True
             except ImportError:
-
-
-
-
-
-
+                has_thread_tracking = False
+
+            with concurrent.futures.ThreadPoolExecutor(
+                max_workers=workers
+            ) as executor:
+                if has_thread_tracking:
+                    _track_executor_threads(executor)
+
+                # Submit all tasks
+                future_to_idx = {
+                    executor.submit(worker_func, x): idx
+                    for idx, x in enumerate(items)
+                }
+
+                # Process results as they complete
+                for future in concurrent.futures.as_completed(future_to_idx):
+                    idx = future_to_idx[future]
+                    try:
+                        result = future.result()
+                        error_stats.record_success()
+                        results[idx] = result
+                    except Exception as e:
+                        if error_handler == 'raise':
+                            # Cancel remaining futures
+                            for f in future_to_idx:
+                                f.cancel()
+                            _reraise_worker_error(e, pbar, caller_info, backend='mp')
+                        error_stats.record_error(idx, e, items[idx], func_name)
+                        results[idx] = None
+                    pbar.update(1)
+                    _update_pbar_postfix(pbar)
+
+        if _prune_dead_threads is not None:
+            _prune_dead_threads()
+
+        _track_multiprocessing_processes()
+        _prune_dead_processes()
+        _cleanup_log_gate(log_gate_path)
+        return results
+
+    if backend == 'safe':
+        # Completely safe backend for tests - no multiprocessing
+        import concurrent.futures
+
+        # Capture caller frame for better error reporting
+        caller_frame = inspect.currentframe()
+        caller_info = None
+        if caller_frame and caller_frame.f_back:
+            caller_info = {
+                'filename': caller_frame.f_back.f_code.co_filename,
+                'lineno': caller_frame.f_back.f_lineno,
+                'function': caller_frame.f_back.f_code.co_name,
+            }
+
+        def worker_func(x):
+            return _call_with_log_control(
+                f_wrapped,
+                x,
+                func_kwargs,
+                log_worker,
+                log_gate_path,
+            )
+
+        results: list[Any] = [None] * total
+        with tqdm(total=total, desc=desc, disable=not progress, file=sys.stdout) as pbar:
+            with concurrent.futures.ThreadPoolExecutor(
+                max_workers=workers
+            ) as executor:
+                if _track_executor_threads is not None:
+                    _track_executor_threads(executor)
+
+                # Submit all tasks
+                future_to_idx = {
+                    executor.submit(worker_func, x): idx
+                    for idx, x in enumerate(items)
+                }
+
+                # Process results as they complete
+                for future in concurrent.futures.as_completed(future_to_idx):
+                    idx = future_to_idx[future]
+                    try:
+                        result = future.result()
+                        error_stats.record_success()
+                        results[idx] = result
+                    except Exception as e:
+                        if error_handler == 'raise':
+                            for f in future_to_idx:
+                                f.cancel()
+                            _reraise_worker_error(e, pbar, caller_info, backend='safe')
+                        error_stats.record_error(idx, e, items[idx], func_name)
+                        results[idx] = None
+                    pbar.update(1)
+                    _update_pbar_postfix(pbar)
+
+        if _prune_dead_threads is not None:
+            _prune_dead_threads()
+
+        _cleanup_log_gate(log_gate_path)
+        return results

-
+    raise ValueError(f'Unsupported backend: {backend!r}')


 def cleanup_phantom_workers():
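Both thread-based backends in this hunk replace the old `executor.map` call with a `future_to_idx` mapping, so results are collected as they finish while still landing at their input positions. The core pattern in isolation:

    import concurrent.futures

    def square(x):
        return x * x

    items = [3, 1, 2]
    results = [None] * len(items)
    with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
        future_to_idx = {executor.submit(square, x): i for i, x in enumerate(items)}
        for future in concurrent.futures.as_completed(future_to_idx):
            # as_completed yields in completion order; the index restores input order.
            results[future_to_idx[future]] = future.result()
    print(results)  # [9, 1, 4]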
@@ -394,6 +1293,10 @@ def cleanup_phantom_workers():

 __all__ = [
     'SPEEDY_RUNNING_PROCESSES',
+    'ErrorStats',
+    'ErrorHandlerType',
     'multi_process',
     'cleanup_phantom_workers',
+    'create_progress_tracker',
+    'get_ray_progress_actor',
 ]