speedy-utils 1.1.46__py3-none-any.whl → 1.1.48__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llm_utils/__init__.py +1 -3
- llm_utils/chat_format/__init__.py +0 -2
- llm_utils/chat_format/display.py +283 -364
- llm_utils/lm/llm.py +62 -22
- speedy_utils/__init__.py +4 -0
- speedy_utils/multi_worker/__init__.py +4 -0
- speedy_utils/multi_worker/_multi_process.py +425 -0
- speedy_utils/multi_worker/_multi_process_ray.py +308 -0
- speedy_utils/multi_worker/common.py +879 -0
- speedy_utils/multi_worker/dataset_sharding.py +203 -0
- speedy_utils/multi_worker/process.py +53 -1234
- speedy_utils/multi_worker/progress.py +71 -1
- speedy_utils/multi_worker/thread.py +45 -0
- speedy_utils/scripts/mpython.py +19 -12
- {speedy_utils-1.1.46.dist-info → speedy_utils-1.1.48.dist-info}/METADATA +1 -1
- {speedy_utils-1.1.46.dist-info → speedy_utils-1.1.48.dist-info}/RECORD +18 -14
- {speedy_utils-1.1.46.dist-info → speedy_utils-1.1.48.dist-info}/WHEEL +0 -0
- {speedy_utils-1.1.46.dist-info → speedy_utils-1.1.48.dist-info}/entry_points.txt +0 -0
|
@@ -1,1248 +1,67 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Multi-process map with selectable backends.
|
|
3
|
+
|
|
4
|
+
This module re-exports from the refactored implementation files:
|
|
5
|
+
- common.py: shared utilities (error formatting, log gating, cache, tracking)
|
|
6
|
+
- _multi_process.py: sequential + threadpool backends + dispatcher
|
|
7
|
+
- _multi_process_ray.py: Ray-specific backend
|
|
8
|
+
|
|
9
|
+
For backward compatibility, all public symbols are re-exported here.
|
|
10
|
+
"""
|
|
1
11
|
import warnings
|
|
2
12
|
import os
|
|
13
|
+
|
|
3
14
|
# Suppress Ray FutureWarnings before any imports
|
|
4
|
-
warnings.filterwarnings(
|
|
5
|
-
warnings.filterwarnings(
|
|
15
|
+
warnings.filterwarnings('ignore', category=FutureWarning, module='ray.*')
|
|
16
|
+
warnings.filterwarnings(
|
|
17
|
+
'ignore',
|
|
18
|
+
message='.*pynvml.*deprecated.*',
|
|
19
|
+
category=FutureWarning,
|
|
20
|
+
)
|
|
6
21
|
|
|
7
22
|
# Set environment variables before Ray is imported anywhere
|
|
8
|
-
os.environ[
|
|
9
|
-
|
|
10
|
-
os.environ[
|
|
11
|
-
os.environ[
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
from
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
#
|
|
23
|
+
os.environ['RAY_ACCEL_ENV_VAR_OVERRIDE_ON_ZERO'] = '0'
|
|
24
|
+
os.environ['RAY_DEDUP_LOGS'] = '1'
|
|
25
|
+
os.environ['RAY_LOG_TO_STDERR'] = '0'
|
|
26
|
+
os.environ['RAY_IGNORE_UNHANDLED_ERRORS'] = '1'
|
|
27
|
+
|
|
28
|
+
# Re-export public API from common
|
|
29
|
+
from .common import (
|
|
30
|
+
ErrorHandlerType,
|
|
31
|
+
ErrorStats,
|
|
32
|
+
SPEEDY_RUNNING_PROCESSES,
|
|
33
|
+
cleanup_phantom_workers,
|
|
34
|
+
)
|
|
35
|
+
|
|
36
|
+
# Re-export main dispatcher
|
|
37
|
+
from ._multi_process import multi_process
|
|
38
|
+
|
|
39
|
+
# Re-export Ray utilities (lazy import to avoid requiring Ray)
|
|
22
40
|
try:
|
|
23
|
-
from .
|
|
41
|
+
from ._multi_process_ray import ensure_ray, RAY_WORKER
|
|
24
42
|
except ImportError:
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
# ─── error handler types ────────────────────────────────────────
|
|
30
|
-
ErrorHandlerType = Literal['raise', 'ignore', 'log']
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
class ErrorStats:
|
|
34
|
-
"""Thread-safe error statistics tracker."""
|
|
35
|
-
|
|
36
|
-
def __init__(
|
|
37
|
-
self,
|
|
38
|
-
func_name: str,
|
|
39
|
-
max_error_files: int = 100,
|
|
40
|
-
write_logs: bool = True
|
|
41
|
-
):
|
|
42
|
-
self._lock = threading.Lock()
|
|
43
|
-
self._success_count = 0
|
|
44
|
-
self._error_count = 0
|
|
45
|
-
self._first_error_shown = False
|
|
46
|
-
self._max_error_files = max_error_files
|
|
47
|
-
self._write_logs = write_logs
|
|
48
|
-
self._error_log_dir = self._get_error_log_dir(func_name)
|
|
49
|
-
if write_logs:
|
|
50
|
-
self._error_log_dir.mkdir(parents=True, exist_ok=True)
|
|
51
|
-
|
|
52
|
-
@staticmethod
|
|
53
|
-
def _get_error_log_dir(func_name: str) -> Path:
|
|
54
|
-
"""Generate unique error log directory with run counter."""
|
|
55
|
-
base_dir = Path('.cache/speedy_utils/error_logs')
|
|
56
|
-
base_dir.mkdir(parents=True, exist_ok=True)
|
|
57
|
-
|
|
58
|
-
# Find the next run counter
|
|
59
|
-
counter = 1
|
|
60
|
-
existing = list(base_dir.glob(f'{func_name}_run_*'))
|
|
61
|
-
if existing:
|
|
62
|
-
counters = []
|
|
63
|
-
for p in existing:
|
|
64
|
-
try:
|
|
65
|
-
parts = p.name.split('_run_')
|
|
66
|
-
if len(parts) == 2:
|
|
67
|
-
counters.append(int(parts[1]))
|
|
68
|
-
except (ValueError, IndexError):
|
|
69
|
-
pass
|
|
70
|
-
if counters:
|
|
71
|
-
counter = max(counters) + 1
|
|
72
|
-
|
|
73
|
-
return base_dir / f'{func_name}_run_{counter}'
|
|
74
|
-
|
|
75
|
-
def record_success(self) -> None:
|
|
76
|
-
with self._lock:
|
|
77
|
-
self._success_count += 1
|
|
78
|
-
|
|
79
|
-
def record_error(
|
|
80
|
-
self,
|
|
81
|
-
idx: int,
|
|
82
|
-
error: Exception,
|
|
83
|
-
input_value: Any,
|
|
84
|
-
func_name: str,
|
|
85
|
-
) -> str | None:
|
|
86
|
-
"""
|
|
87
|
-
Record an error and write to log file.
|
|
88
|
-
Returns the log file path if written, None otherwise.
|
|
89
|
-
"""
|
|
90
|
-
with self._lock:
|
|
91
|
-
self._error_count += 1
|
|
92
|
-
should_show_first = not self._first_error_shown
|
|
93
|
-
if should_show_first:
|
|
94
|
-
self._first_error_shown = True
|
|
95
|
-
should_write = (
|
|
96
|
-
self._write_logs and self._error_count <= self._max_error_files
|
|
97
|
-
)
|
|
98
|
-
|
|
99
|
-
log_path = None
|
|
100
|
-
if should_write:
|
|
101
|
-
log_path = self._write_error_log(
|
|
102
|
-
idx, error, input_value, func_name
|
|
103
|
-
)
|
|
104
|
-
|
|
105
|
-
if should_show_first:
|
|
106
|
-
self._print_first_error(error, input_value, func_name, log_path)
|
|
107
|
-
|
|
108
|
-
return log_path
|
|
109
|
-
|
|
110
|
-
def _write_error_log(
|
|
111
|
-
self,
|
|
112
|
-
idx: int,
|
|
113
|
-
error: Exception,
|
|
114
|
-
input_value: Any,
|
|
115
|
-
func_name: str,
|
|
116
|
-
) -> str:
|
|
117
|
-
"""Write error details to a log file."""
|
|
118
|
-
from io import StringIO
|
|
119
|
-
from rich.console import Console
|
|
120
|
-
|
|
121
|
-
log_path = self._error_log_dir / f'{idx}.log'
|
|
122
|
-
|
|
123
|
-
output = StringIO()
|
|
124
|
-
console = Console(file=output, width=120, no_color=False)
|
|
125
|
-
|
|
126
|
-
# Format traceback
|
|
127
|
-
tb_lines = self._format_traceback(error)
|
|
128
|
-
|
|
129
|
-
console.print(f'{"=" * 60}')
|
|
130
|
-
console.print(f'Error at index: {idx}')
|
|
131
|
-
console.print(f'Function: {func_name}')
|
|
132
|
-
console.print(f'Error Type: {type(error).__name__}')
|
|
133
|
-
console.print(f'Error Message: {error}')
|
|
134
|
-
console.print(f'{"=" * 60}')
|
|
135
|
-
console.print('')
|
|
136
|
-
console.print('Input:')
|
|
137
|
-
console.print('-' * 40)
|
|
138
|
-
try:
|
|
139
|
-
import json
|
|
140
|
-
console.print(json.dumps(input_value, indent=2))
|
|
141
|
-
except Exception:
|
|
142
|
-
console.print(repr(input_value))
|
|
143
|
-
console.print('')
|
|
144
|
-
console.print('Traceback:')
|
|
145
|
-
console.print('-' * 40)
|
|
146
|
-
for line in tb_lines:
|
|
147
|
-
console.print(line)
|
|
148
|
-
|
|
149
|
-
with open(log_path, 'w') as f:
|
|
150
|
-
f.write(output.getvalue())
|
|
151
|
-
|
|
152
|
-
return str(log_path)
|
|
153
|
-
|
|
154
|
-
def _format_traceback(self, error: Exception) -> list[str]:
|
|
155
|
-
"""Format traceback with context lines like Rich panel."""
|
|
156
|
-
lines = []
|
|
157
|
-
frames = _extract_frames_from_traceback(error)
|
|
158
|
-
|
|
159
|
-
for filepath, lineno, funcname, frame_locals in frames:
|
|
160
|
-
lines.append(f'│ {filepath}:{lineno} in {funcname} │')
|
|
161
|
-
lines.append('│' + ' ' * 70 + '│')
|
|
162
|
-
|
|
163
|
-
# Get context lines
|
|
164
|
-
context_size = 3
|
|
165
|
-
start_line = max(1, lineno - context_size)
|
|
166
|
-
end_line = lineno + context_size + 1
|
|
167
|
-
|
|
168
|
-
for line_num in range(start_line, end_line):
|
|
169
|
-
line_text = linecache.getline(filepath, line_num).rstrip()
|
|
170
|
-
if line_text:
|
|
171
|
-
num_str = str(line_num).rjust(4)
|
|
172
|
-
if line_num == lineno:
|
|
173
|
-
lines.append(f'│ {num_str} ❱ {line_text}')
|
|
174
|
-
else:
|
|
175
|
-
lines.append(f'│ {num_str} │ {line_text}')
|
|
176
|
-
lines.append('')
|
|
177
|
-
|
|
178
|
-
return lines
|
|
179
|
-
|
|
180
|
-
def _print_first_error(
|
|
181
|
-
self,
|
|
182
|
-
error: Exception,
|
|
183
|
-
input_value: Any,
|
|
184
|
-
func_name: str,
|
|
185
|
-
log_path: str | None,
|
|
186
|
-
) -> None:
|
|
187
|
-
"""Print the first error to screen with Rich formatting."""
|
|
188
|
-
try:
|
|
189
|
-
from rich.console import Console
|
|
190
|
-
from rich.panel import Panel
|
|
191
|
-
console = Console(stderr=True)
|
|
192
|
-
|
|
193
|
-
tb_lines = self._format_traceback(error)
|
|
194
|
-
|
|
195
|
-
console.print()
|
|
196
|
-
console.print(
|
|
197
|
-
Panel(
|
|
198
|
-
'\n'.join(tb_lines),
|
|
199
|
-
title='[bold red]First Error (continuing with remaining items)[/bold red]',
|
|
200
|
-
border_style='yellow',
|
|
201
|
-
expand=False,
|
|
202
|
-
)
|
|
203
|
-
)
|
|
204
|
-
console.print(
|
|
205
|
-
f'[bold red]{type(error).__name__}[/bold red]: {error}'
|
|
206
|
-
)
|
|
207
|
-
if log_path:
|
|
208
|
-
console.print(f'[dim]Error log: {log_path}[/dim]')
|
|
209
|
-
console.print()
|
|
210
|
-
except ImportError:
|
|
211
|
-
# Fallback to plain print
|
|
212
|
-
print(f'\n--- First Error ---', file=sys.stderr)
|
|
213
|
-
print(f'{type(error).__name__}: {error}', file=sys.stderr)
|
|
214
|
-
if log_path:
|
|
215
|
-
print(f'Error log: {log_path}', file=sys.stderr)
|
|
216
|
-
print('', file=sys.stderr)
|
|
217
|
-
|
|
218
|
-
@property
|
|
219
|
-
def success_count(self) -> int:
|
|
220
|
-
with self._lock:
|
|
221
|
-
return self._success_count
|
|
222
|
-
|
|
223
|
-
@property
|
|
224
|
-
def error_count(self) -> int:
|
|
225
|
-
with self._lock:
|
|
226
|
-
return self._error_count
|
|
227
|
-
|
|
228
|
-
def get_postfix_dict(self) -> dict[str, int]:
|
|
229
|
-
"""Get dict for pbar postfix."""
|
|
230
|
-
with self._lock:
|
|
231
|
-
return {'ok': self._success_count, 'err': self._error_count}
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
def _should_skip_frame(filepath: str) -> bool:
|
|
235
|
-
"""Check if a frame should be filtered from traceback display."""
|
|
236
|
-
skip_patterns = [
|
|
237
|
-
'ray/_private',
|
|
238
|
-
'ray/worker',
|
|
239
|
-
'site-packages/ray',
|
|
240
|
-
'speedy_utils/multi_worker',
|
|
241
|
-
'concurrent/futures',
|
|
242
|
-
'multiprocessing/',
|
|
243
|
-
'fastcore/parallel',
|
|
244
|
-
'fastcore/foundation',
|
|
245
|
-
'fastcore/basics',
|
|
246
|
-
'site-packages/fastcore',
|
|
247
|
-
'/threading.py',
|
|
248
|
-
'/concurrent/',
|
|
249
|
-
]
|
|
250
|
-
return any(skip in filepath for skip in skip_patterns)
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
def _should_show_local(name: str, value: object) -> bool:
|
|
254
|
-
"""Check if a local variable should be displayed in traceback."""
|
|
255
|
-
import types
|
|
256
|
-
|
|
257
|
-
# Skip dunder variables
|
|
258
|
-
if name.startswith('__') and name.endswith('__'):
|
|
259
|
-
return False
|
|
260
|
-
|
|
261
|
-
# Skip modules
|
|
262
|
-
if isinstance(value, types.ModuleType):
|
|
263
|
-
return False
|
|
264
|
-
|
|
265
|
-
# Skip type objects and classes
|
|
266
|
-
if isinstance(value, type):
|
|
267
|
-
return False
|
|
268
|
-
|
|
269
|
-
# Skip functions and methods
|
|
270
|
-
if isinstance(value, (types.FunctionType, types.MethodType, types.BuiltinFunctionType)):
|
|
271
|
-
return False
|
|
272
|
-
|
|
273
|
-
# Skip common typing aliases
|
|
274
|
-
value_str = str(value)
|
|
275
|
-
if value_str.startswith('typing.'):
|
|
276
|
-
return False
|
|
277
|
-
|
|
278
|
-
# Skip large objects that would clutter output
|
|
279
|
-
if value_str.startswith('<') and any(x in value_str for x in ['module', 'function', 'method', 'built-in']):
|
|
280
|
-
return False
|
|
281
|
-
|
|
282
|
-
return True
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
def _format_locals(frame_locals: dict) -> list[str]:
|
|
286
|
-
"""Format local variables for display, filtering out noisy imports."""
|
|
287
|
-
from rich.pretty import Pretty
|
|
288
|
-
from rich.console import Console
|
|
289
|
-
from io import StringIO
|
|
290
|
-
|
|
291
|
-
# Filter locals
|
|
292
|
-
clean_locals = {k: v for k, v in frame_locals.items() if _should_show_local(k, v)}
|
|
293
|
-
|
|
294
|
-
if not clean_locals:
|
|
295
|
-
return []
|
|
296
|
-
|
|
297
|
-
lines = []
|
|
298
|
-
lines.append('[dim]╭─ locals ─╮[/dim]')
|
|
299
|
-
|
|
300
|
-
# Format each local variable
|
|
301
|
-
for name, value in clean_locals.items():
|
|
302
|
-
# Use Rich's Pretty for nice formatting
|
|
303
|
-
try:
|
|
304
|
-
console = Console(file=StringIO(), width=60)
|
|
305
|
-
console.print(Pretty(value), end='')
|
|
306
|
-
value_str = console.file.getvalue().strip()
|
|
307
|
-
# Limit length
|
|
308
|
-
if len(value_str) > 100:
|
|
309
|
-
value_str = value_str[:97] + '...'
|
|
310
|
-
except Exception:
|
|
311
|
-
value_str = repr(value)
|
|
312
|
-
if len(value_str) > 100:
|
|
313
|
-
value_str = value_str[:97] + '...'
|
|
314
|
-
|
|
315
|
-
lines.append(f'[dim]│[/dim] {name} = {value_str}')
|
|
316
|
-
|
|
317
|
-
lines.append('[dim]╰──────────╯[/dim]')
|
|
318
|
-
return lines
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
def _format_frame_with_context(filepath: str, lineno: int, funcname: str, frame_locals: dict | None = None) -> list[str]:
|
|
322
|
-
"""Format a single frame with context lines and optional locals."""
|
|
323
|
-
lines = []
|
|
324
|
-
# Frame header
|
|
325
|
-
lines.append(
|
|
326
|
-
f'[cyan]{filepath}[/cyan]:[yellow]{lineno}[/yellow] '
|
|
327
|
-
f'in [green]{funcname}[/green]'
|
|
328
|
-
)
|
|
329
|
-
lines.append('')
|
|
330
|
-
|
|
331
|
-
# Get context lines
|
|
332
|
-
context_size = 3
|
|
333
|
-
start_line = max(1, lineno - context_size)
|
|
334
|
-
end_line = lineno + context_size + 1
|
|
335
|
-
|
|
336
|
-
for line_num in range(start_line, end_line):
|
|
337
|
-
import linecache
|
|
338
|
-
line_text = linecache.getline(filepath, line_num).rstrip()
|
|
339
|
-
if line_text:
|
|
340
|
-
num_str = str(line_num).rjust(4)
|
|
341
|
-
if line_num == lineno:
|
|
342
|
-
lines.append(f'[dim]{num_str}[/dim] [red]❱[/red] {line_text}')
|
|
343
|
-
else:
|
|
344
|
-
lines.append(f'[dim]{num_str} │[/dim] {line_text}')
|
|
345
|
-
|
|
346
|
-
# Add locals if available
|
|
347
|
-
if frame_locals:
|
|
348
|
-
locals_lines = _format_locals(frame_locals)
|
|
349
|
-
if locals_lines:
|
|
350
|
-
lines.append('')
|
|
351
|
-
lines.extend(locals_lines)
|
|
352
|
-
|
|
353
|
-
lines.append('')
|
|
354
|
-
return lines
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
def _display_formatted_error(
|
|
358
|
-
exc_type_name: str,
|
|
359
|
-
exc_msg: str,
|
|
360
|
-
frames: list[tuple[str, int, str, dict]],
|
|
361
|
-
caller_info: dict | None,
|
|
362
|
-
backend: str,
|
|
363
|
-
pbar=None,
|
|
364
|
-
) -> None:
|
|
365
|
-
|
|
366
|
-
# Suppress additional error logs
|
|
367
|
-
os.environ['RAY_IGNORE_UNHANDLED_ERRORS'] = '1'
|
|
368
|
-
|
|
369
|
-
# Close progress bar cleanly if provided
|
|
370
|
-
if pbar is not None:
|
|
371
|
-
pbar.close()
|
|
372
|
-
|
|
373
|
-
from rich.console import Console
|
|
374
|
-
from rich.panel import Panel
|
|
375
|
-
console = Console(stderr=True)
|
|
376
|
-
|
|
377
|
-
if frames or caller_info:
|
|
378
|
-
display_lines = []
|
|
379
|
-
|
|
380
|
-
# Add caller frame first if available (no locals for caller)
|
|
381
|
-
if caller_info:
|
|
382
|
-
display_lines.extend(_format_frame_with_context(
|
|
383
|
-
caller_info['filename'],
|
|
384
|
-
caller_info['lineno'],
|
|
385
|
-
caller_info['function'],
|
|
386
|
-
None # Don't show locals for caller frame
|
|
387
|
-
))
|
|
388
|
-
|
|
389
|
-
# Add error frames with locals
|
|
390
|
-
for filepath, lineno, funcname, frame_locals in frames:
|
|
391
|
-
display_lines.extend(_format_frame_with_context(
|
|
392
|
-
filepath, lineno, funcname, frame_locals
|
|
393
|
-
))
|
|
394
|
-
|
|
395
|
-
# Display the traceback
|
|
396
|
-
console.print()
|
|
397
|
-
console.print(
|
|
398
|
-
Panel(
|
|
399
|
-
'\n'.join(display_lines),
|
|
400
|
-
title=f'[bold red]Traceback (most recent call last) [{backend}][/bold red]',
|
|
401
|
-
border_style='red',
|
|
402
|
-
expand=False,
|
|
403
|
-
)
|
|
404
|
-
)
|
|
405
|
-
console.print(f'[bold red]{exc_type_name}[/bold red]: {exc_msg}')
|
|
406
|
-
console.print()
|
|
407
|
-
else:
|
|
408
|
-
# No frames found, minimal output
|
|
409
|
-
console.print()
|
|
410
|
-
console.print(f'[bold red]{exc_type_name}[/bold red]: {exc_msg}')
|
|
411
|
-
console.print()
|
|
412
|
-
|
|
413
|
-
# Ensure output is flushed
|
|
414
|
-
sys.stderr.flush()
|
|
415
|
-
sys.stdout.flush()
|
|
416
|
-
sys.exit(1)
|
|
417
|
-
|
|
418
|
-
|
|
419
|
-
def _extract_frames_from_traceback(error: Exception) -> list[tuple[str, int, str, dict]]:
|
|
420
|
-
"""Extract user frames from exception traceback object with locals."""
|
|
421
|
-
frames = []
|
|
422
|
-
if hasattr(error, '__traceback__') and error.__traceback__ is not None:
|
|
423
|
-
tb = error.__traceback__
|
|
424
|
-
while tb is not None:
|
|
425
|
-
frame = tb.tb_frame
|
|
426
|
-
filename = frame.f_code.co_filename
|
|
427
|
-
lineno = tb.tb_lineno
|
|
428
|
-
funcname = frame.f_code.co_name
|
|
429
|
-
|
|
430
|
-
if not _should_skip_frame(filename):
|
|
431
|
-
# Get local variables from the frame
|
|
432
|
-
frame_locals = dict(frame.f_locals)
|
|
433
|
-
frames.append((filename, lineno, funcname, frame_locals))
|
|
434
|
-
|
|
435
|
-
tb = tb.tb_next
|
|
436
|
-
return frames
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
def _extract_frames_from_ray_error(ray_task_error: Exception) -> list[tuple[str, int, str, dict]]:
|
|
440
|
-
"""Extract user frames from Ray's string traceback representation."""
|
|
441
|
-
frames = []
|
|
442
|
-
error_str = str(ray_task_error)
|
|
443
|
-
lines = error_str.split('\n')
|
|
444
|
-
|
|
445
|
-
import re
|
|
446
|
-
for i, line in enumerate(lines):
|
|
447
|
-
# Match: File "path", line N, in func
|
|
448
|
-
file_match = re.match(r'\s*File "([^"]+)", line (\d+), in (.+)', line)
|
|
449
|
-
if file_match:
|
|
450
|
-
filepath, lineno, funcname = file_match.groups()
|
|
451
|
-
if not _should_skip_frame(filepath):
|
|
452
|
-
# Ray doesn't preserve locals, so use empty dict
|
|
453
|
-
frames.append((filepath, int(lineno), funcname, {}))
|
|
454
|
-
|
|
455
|
-
return frames
|
|
456
|
-
|
|
457
|
-
|
|
458
|
-
def _reraise_worker_error(error: Exception, pbar=None, caller_info=None, backend: str = 'unknown') -> None:
|
|
459
|
-
"""
|
|
460
|
-
Re-raise the original exception from a worker error with clean traceback.
|
|
461
|
-
Works for multiprocessing, threadpool, and other backends with real tracebacks.
|
|
462
|
-
"""
|
|
463
|
-
frames = _extract_frames_from_traceback(error)
|
|
464
|
-
_display_formatted_error(
|
|
465
|
-
exc_type_name=type(error).__name__,
|
|
466
|
-
exc_msg=str(error),
|
|
467
|
-
frames=frames,
|
|
468
|
-
caller_info=caller_info,
|
|
469
|
-
backend=backend,
|
|
470
|
-
pbar=pbar,
|
|
471
|
-
)
|
|
472
|
-
|
|
473
|
-
|
|
474
|
-
def _reraise_ray_error(ray_task_error: Exception, pbar=None, caller_info=None) -> None:
|
|
475
|
-
"""
|
|
476
|
-
Re-raise the original exception from a RayTaskError with clean traceback.
|
|
477
|
-
Parses Ray's string traceback and displays with full context.
|
|
478
|
-
"""
|
|
479
|
-
# Get the exception info
|
|
480
|
-
cause = ray_task_error.cause if hasattr(ray_task_error, 'cause') else None
|
|
481
|
-
if cause is None:
|
|
482
|
-
cause = ray_task_error.__cause__
|
|
483
|
-
|
|
484
|
-
exc_type_name = type(cause).__name__ if cause else 'Error'
|
|
485
|
-
exc_msg = str(cause) if cause else str(ray_task_error)
|
|
486
|
-
|
|
487
|
-
frames = _extract_frames_from_ray_error(ray_task_error)
|
|
488
|
-
_display_formatted_error(
|
|
489
|
-
exc_type_name=exc_type_name,
|
|
490
|
-
exc_msg=exc_msg,
|
|
491
|
-
frames=frames,
|
|
492
|
-
caller_info=caller_info,
|
|
493
|
-
backend='ray',
|
|
494
|
-
pbar=pbar,
|
|
495
|
-
)
|
|
496
|
-
|
|
497
|
-
|
|
498
|
-
SPEEDY_RUNNING_PROCESSES: list[psutil.Process] = []
|
|
499
|
-
_SPEEDY_PROCESSES_LOCK = threading.Lock()
|
|
500
|
-
|
|
501
|
-
def _prune_dead_processes() -> None:
|
|
502
|
-
"""Remove dead processes from tracking list."""
|
|
503
|
-
with _SPEEDY_PROCESSES_LOCK:
|
|
504
|
-
SPEEDY_RUNNING_PROCESSES[:] = [
|
|
505
|
-
p for p in SPEEDY_RUNNING_PROCESSES if p.is_running()
|
|
506
|
-
]
|
|
507
|
-
|
|
508
|
-
|
|
509
|
-
def _track_processes(processes: list[psutil.Process]) -> None:
|
|
510
|
-
"""Add processes to global tracking list."""
|
|
511
|
-
if not processes:
|
|
512
|
-
return
|
|
513
|
-
with _SPEEDY_PROCESSES_LOCK:
|
|
514
|
-
living = [p for p in SPEEDY_RUNNING_PROCESSES if p.is_running()]
|
|
515
|
-
for candidate in processes:
|
|
516
|
-
if not candidate.is_running():
|
|
517
|
-
continue
|
|
518
|
-
if any(existing.pid == candidate.pid for existing in living):
|
|
519
|
-
continue
|
|
520
|
-
living.append(candidate)
|
|
521
|
-
SPEEDY_RUNNING_PROCESSES[:] = living
|
|
522
|
-
|
|
523
|
-
|
|
524
|
-
def _track_ray_processes() -> None:
|
|
525
|
-
"""Track Ray worker processes when Ray is initialized."""
|
|
526
|
-
|
|
527
|
-
try:
|
|
528
|
-
# Get Ray worker processes
|
|
529
|
-
current_pid = os.getpid()
|
|
530
|
-
parent = psutil.Process(current_pid)
|
|
531
|
-
ray_processes = []
|
|
532
|
-
for child in parent.children(recursive=True):
|
|
533
|
-
try:
|
|
534
|
-
if 'ray' in child.name().lower() or 'worker' in child.name().lower():
|
|
535
|
-
ray_processes.append(child)
|
|
536
|
-
except (psutil.NoSuchProcess, psutil.AccessDenied):
|
|
537
|
-
continue
|
|
538
|
-
_track_processes(ray_processes)
|
|
539
|
-
except Exception:
|
|
540
|
-
# Don't fail if process tracking fails
|
|
541
|
-
pass
|
|
542
|
-
|
|
543
|
-
|
|
544
|
-
def _track_multiprocessing_processes() -> None:
|
|
545
|
-
"""Track multiprocessing worker processes."""
|
|
546
|
-
try:
|
|
547
|
-
# Find recently created child processes that might be multiprocessing workers
|
|
548
|
-
current_pid = os.getpid()
|
|
549
|
-
parent = psutil.Process(current_pid)
|
|
550
|
-
new_processes = []
|
|
551
|
-
for child in parent.children(recursive=False): # Only direct children
|
|
552
|
-
try:
|
|
553
|
-
# Basic heuristic: if it's a recent child process, it might be a worker
|
|
554
|
-
if (
|
|
555
|
-
time.time() - child.create_time() < 5
|
|
556
|
-
): # Created within last 5 seconds
|
|
557
|
-
new_processes.append(child)
|
|
558
|
-
except (psutil.NoSuchProcess, psutil.AccessDenied):
|
|
559
|
-
continue
|
|
560
|
-
_track_processes(new_processes)
|
|
561
|
-
except Exception:
|
|
562
|
-
# Don't fail if process tracking fails
|
|
563
|
-
pass
|
|
564
|
-
|
|
565
|
-
|
|
566
|
-
# ─── cache helpers ──────────────────────────────────────────
|
|
567
|
-
|
|
568
|
-
|
|
569
|
-
def _build_cache_dir(func: Callable, items: list[Any]) -> Path:
|
|
570
|
-
"""Build cache dir with function name + timestamp."""
|
|
571
|
-
import datetime
|
|
572
|
-
func_name = getattr(func, '__name__', 'func')
|
|
573
|
-
now = datetime.datetime.now()
|
|
574
|
-
stamp = now.strftime('%m%d_%Hh%Mm%Ss')
|
|
575
|
-
run_id = f'{func_name}_{stamp}_{uuid.uuid4().hex[:6]}'
|
|
576
|
-
path = Path('.cache') / run_id
|
|
577
|
-
path.mkdir(parents=True, exist_ok=True)
|
|
578
|
-
return path
|
|
579
|
-
_DUMP_INTERMEDIATE_THREADS = []
|
|
580
|
-
def wrap_dump(func: Callable, cache_dir: Path | None, dump_in_thread: bool = True):
|
|
581
|
-
"""Wrap a function so results are dumped to .pkl when cache_dir is set."""
|
|
582
|
-
if cache_dir is None:
|
|
583
|
-
return func
|
|
584
|
-
|
|
585
|
-
def wrapped(x, *args, **kwargs):
|
|
586
|
-
res = func(x, *args, **kwargs)
|
|
587
|
-
p = cache_dir / f'{uuid.uuid4().hex}.pkl'
|
|
588
|
-
|
|
589
|
-
def save():
|
|
590
|
-
with open(p, 'wb') as fh:
|
|
591
|
-
pickle.dump(res, fh)
|
|
592
|
-
# Clean trash to avoid bloating memory
|
|
593
|
-
# print(f'Thread count: {threading.active_count()}')
|
|
594
|
-
# print(f'Saved result to {p}')
|
|
595
|
-
|
|
596
|
-
if dump_in_thread:
|
|
597
|
-
thread = threading.Thread(target=save)
|
|
598
|
-
_DUMP_INTERMEDIATE_THREADS.append(thread)
|
|
599
|
-
# count thread
|
|
600
|
-
# print(f'Thread count: {threading.active_count()}')
|
|
601
|
-
while threading.active_count() > 16:
|
|
602
|
-
time.sleep(0.1)
|
|
603
|
-
thread.start()
|
|
604
|
-
else:
|
|
605
|
-
save()
|
|
606
|
-
return str(p)
|
|
607
|
-
|
|
608
|
-
return wrapped
|
|
609
|
-
|
|
610
|
-
|
|
611
|
-
_LOG_GATE_CACHE: dict[str, bool] = {}
|
|
612
|
-
|
|
613
|
-
|
|
614
|
-
def _should_allow_worker_logs(mode: Literal['all', 'zero', 'first'], gate_path: Path | None) -> bool:
|
|
615
|
-
"""Determine if current worker should emit logs for the given mode."""
|
|
616
|
-
if mode == 'all':
|
|
617
|
-
return True
|
|
618
|
-
if mode == 'zero':
|
|
619
|
-
return False
|
|
620
|
-
if mode == 'first':
|
|
621
|
-
if gate_path is None:
|
|
622
|
-
return True
|
|
623
|
-
key = str(gate_path)
|
|
624
|
-
cached = _LOG_GATE_CACHE.get(key)
|
|
625
|
-
if cached is not None:
|
|
626
|
-
return cached
|
|
627
|
-
gate_path.parent.mkdir(parents=True, exist_ok=True)
|
|
628
|
-
try:
|
|
629
|
-
fd = os.open(key, os.O_CREAT | os.O_EXCL | os.O_WRONLY)
|
|
630
|
-
except FileExistsError:
|
|
631
|
-
allowed = False
|
|
632
|
-
else:
|
|
633
|
-
os.close(fd)
|
|
634
|
-
allowed = True
|
|
635
|
-
_LOG_GATE_CACHE[key] = allowed
|
|
636
|
-
return allowed
|
|
637
|
-
raise ValueError(f'Unsupported log mode: {mode!r}')
|
|
638
|
-
|
|
639
|
-
|
|
640
|
-
def _cleanup_log_gate(gate_path: Path | None):
|
|
641
|
-
if gate_path is None:
|
|
642
|
-
return
|
|
643
|
-
try:
|
|
644
|
-
gate_path.unlink(missing_ok=True)
|
|
645
|
-
except OSError:
|
|
646
|
-
pass
|
|
647
|
-
|
|
648
|
-
|
|
649
|
-
@contextlib.contextmanager
|
|
650
|
-
def _patch_fastcore_progress_bar(*, leave: bool = True):
|
|
651
|
-
"""Temporarily force fastcore.progress_bar to keep the bar on screen."""
|
|
652
|
-
try:
|
|
653
|
-
import fastcore.parallel as _fp
|
|
654
|
-
except ImportError:
|
|
655
|
-
yield False
|
|
656
|
-
return
|
|
657
|
-
|
|
658
|
-
orig = getattr(_fp, 'progress_bar', None)
|
|
659
|
-
if orig is None:
|
|
660
|
-
yield False
|
|
661
|
-
return
|
|
662
|
-
|
|
663
|
-
def _wrapped(*args, **kwargs):
|
|
664
|
-
kwargs.setdefault('leave', leave)
|
|
665
|
-
return orig(*args, **kwargs)
|
|
666
|
-
|
|
667
|
-
_fp.progress_bar = _wrapped
|
|
668
|
-
try:
|
|
669
|
-
yield True
|
|
670
|
-
finally:
|
|
671
|
-
_fp.progress_bar = orig
|
|
672
|
-
|
|
673
|
-
|
|
674
|
-
class _PrefixedWriter:
|
|
675
|
-
"""Stream wrapper that prefixes each line with worker id."""
|
|
676
|
-
|
|
677
|
-
def __init__(self, stream, prefix: str):
|
|
678
|
-
self._stream = stream
|
|
679
|
-
self._prefix = prefix
|
|
680
|
-
self._at_line_start = True
|
|
681
|
-
|
|
682
|
-
def write(self, s):
|
|
683
|
-
if not s:
|
|
684
|
-
return 0
|
|
685
|
-
total = 0
|
|
686
|
-
for chunk in s.splitlines(True):
|
|
687
|
-
if self._at_line_start:
|
|
688
|
-
self._stream.write(self._prefix)
|
|
689
|
-
total += len(self._prefix)
|
|
690
|
-
self._stream.write(chunk)
|
|
691
|
-
total += len(chunk)
|
|
692
|
-
self._at_line_start = chunk.endswith('\n')
|
|
693
|
-
return total
|
|
694
|
-
|
|
695
|
-
def flush(self):
|
|
696
|
-
self._stream.flush()
|
|
697
|
-
|
|
698
|
-
|
|
699
|
-
def _call_with_log_control(
|
|
700
|
-
func: Callable,
|
|
701
|
-
x: Any,
|
|
702
|
-
func_kwargs: dict[str, Any],
|
|
703
|
-
log_mode: Literal['all', 'zero', 'first'],
|
|
704
|
-
gate_path: Path | None,
|
|
705
|
-
):
|
|
706
|
-
"""Call a function, silencing stdout/stderr based on log mode."""
|
|
707
|
-
allow_logs = _should_allow_worker_logs(log_mode, gate_path)
|
|
708
|
-
if allow_logs:
|
|
709
|
-
prefix = f"[worker-{os.getpid()}] "
|
|
710
|
-
# Route worker logs to stderr to reduce clobbering tqdm/progress output on stdout
|
|
711
|
-
out = _PrefixedWriter(sys.stderr, prefix)
|
|
712
|
-
err = out
|
|
713
|
-
with contextlib.redirect_stdout(out), contextlib.redirect_stderr(err):
|
|
714
|
-
return func(x, **func_kwargs)
|
|
715
|
-
with open(os.devnull, 'w') as devnull, contextlib.redirect_stdout(devnull), contextlib.redirect_stderr(devnull):
|
|
716
|
-
return func(x, **func_kwargs)
|
|
717
|
-
|
|
718
|
-
|
|
719
|
-
# ─── ray management ─────────────────────────────────────────
|
|
720
|
-
|
|
721
|
-
RAY_WORKER = None
|
|
722
|
-
|
|
723
|
-
|
|
724
|
-
def ensure_ray(workers: int | None, pbar: tqdm | None = None, ray_metrics_port: int | None = None):
|
|
725
|
-
"""
|
|
726
|
-
Initialize or reinitialize Ray safely for both local and cluster environments.
|
|
727
|
-
|
|
728
|
-
1. Tries to connect to an existing cluster (address='auto') first.
|
|
729
|
-
2. If no cluster is found, starts a local Ray instance with 'workers' CPUs.
|
|
730
|
-
"""
|
|
731
|
-
import ray as _ray_module
|
|
732
|
-
import logging
|
|
733
|
-
|
|
734
|
-
global RAY_WORKER
|
|
735
|
-
requested_workers = workers
|
|
736
|
-
if workers is None:
|
|
737
|
-
workers = os.cpu_count() or 1
|
|
738
|
-
|
|
739
|
-
if ray_metrics_port is not None:
|
|
740
|
-
os.environ['RAY_metrics_export_port'] = str(ray_metrics_port)
|
|
741
|
-
|
|
742
|
-
allow_restart = os.environ.get("RESTART_RAY", "0").lower() in ("1", "true")
|
|
743
|
-
is_cluster_env = "RAY_ADDRESS" in os.environ or os.environ.get("RAY_CLUSTER") == "1"
|
|
744
|
-
|
|
745
|
-
# 1. Handle existing session
|
|
746
|
-
if _ray_module.is_initialized():
|
|
747
|
-
if not allow_restart:
|
|
748
|
-
if pbar:
|
|
749
|
-
pbar.set_postfix_str("Using existing Ray session")
|
|
750
|
-
return
|
|
751
|
-
|
|
752
|
-
# Avoid shutting down shared cluster sessions.
|
|
753
|
-
if is_cluster_env:
|
|
754
|
-
if pbar:
|
|
755
|
-
pbar.set_postfix_str("Cluster active: skipping restart to protect connection")
|
|
756
|
-
return
|
|
757
|
-
|
|
758
|
-
# Local restart: only if worker count changed
|
|
759
|
-
if workers != RAY_WORKER:
|
|
760
|
-
if pbar:
|
|
761
|
-
pbar.set_postfix_str(f'Restarting local Ray with {workers} workers')
|
|
762
|
-
_ray_module.shutdown()
|
|
763
|
-
else:
|
|
764
|
-
return
|
|
765
|
-
|
|
766
|
-
# 2. Initialization logic
|
|
767
|
-
t0 = time.time()
|
|
768
|
-
|
|
769
|
-
# Try to connect to existing cluster FIRST (address="auto")
|
|
770
|
-
try:
|
|
771
|
-
if pbar:
|
|
772
|
-
pbar.set_postfix_str("Searching for Ray cluster...")
|
|
773
|
-
|
|
774
|
-
# MUST NOT pass num_cpus/num_gpus here to avoid ValueError on existing clusters
|
|
775
|
-
_ray_module.init(
|
|
776
|
-
address="auto",
|
|
777
|
-
ignore_reinit_error=True,
|
|
778
|
-
logging_level=logging.ERROR,
|
|
779
|
-
log_to_driver=False
|
|
780
|
-
)
|
|
781
|
-
|
|
782
|
-
if pbar:
|
|
783
|
-
resources = _ray_module.cluster_resources()
|
|
784
|
-
cpus = resources.get("CPU", 0)
|
|
785
|
-
pbar.set_postfix_str(f"Connected to Ray Cluster ({int(cpus)} CPUs)")
|
|
786
|
-
|
|
787
|
-
except Exception:
|
|
788
|
-
# 3. Fallback: Start a local Ray session
|
|
789
|
-
if pbar:
|
|
790
|
-
pbar.set_postfix_str(f"No cluster found. Starting local Ray ({workers} CPUs)...")
|
|
791
|
-
|
|
792
|
-
_ray_module.init(
|
|
793
|
-
num_cpus=workers,
|
|
794
|
-
ignore_reinit_error=True,
|
|
795
|
-
logging_level=logging.ERROR,
|
|
796
|
-
log_to_driver=False,
|
|
797
|
-
)
|
|
798
|
-
|
|
799
|
-
if pbar:
|
|
800
|
-
took = time.time() - t0
|
|
801
|
-
pbar.set_postfix_str(f'ray.init local {workers} took {took:.2f}s')
|
|
802
|
-
|
|
803
|
-
_track_ray_processes()
|
|
804
|
-
|
|
805
|
-
if requested_workers is None:
|
|
806
|
-
try:
|
|
807
|
-
resources = _ray_module.cluster_resources()
|
|
808
|
-
total_cpus = int(resources.get("CPU", 0))
|
|
809
|
-
if total_cpus > 0:
|
|
810
|
-
workers = total_cpus
|
|
811
|
-
except Exception:
|
|
812
|
-
pass
|
|
813
|
-
|
|
814
|
-
RAY_WORKER = workers
|
|
815
|
-
|
|
816
|
-
|
|
817
|
-
# TODO: make smarter backend selection, when shared_kwargs is used, and backend != 'ray', do not raise error but change to ray and warning user about this
|
|
818
|
-
def multi_process(
|
|
819
|
-
func: Callable[[Any], Any],
|
|
820
|
-
items: Iterable[Any] | None = None,
|
|
821
|
-
*,
|
|
822
|
-
inputs: Iterable[Any] | None = None,
|
|
823
|
-
workers: int | None = None,
|
|
824
|
-
lazy_output: bool = False,
|
|
825
|
-
progress: bool = True,
|
|
826
|
-
backend: Literal['seq', 'ray', 'mp', 'safe'] = 'mp',
|
|
827
|
-
desc: str | None = None,
|
|
828
|
-
shared_kwargs: list[str] | None = None,
|
|
829
|
-
dump_in_thread: bool = True,
|
|
830
|
-
ray_metrics_port: int | None = None,
|
|
831
|
-
log_worker: Literal['zero', 'first', 'all'] = 'first',
|
|
832
|
-
total_items: int | None = None,
|
|
833
|
-
poll_interval: float = 0.3,
|
|
834
|
-
error_handler: ErrorHandlerType = 'log',
|
|
835
|
-
max_error_files: int = 100,
|
|
836
|
-
**func_kwargs: Any,
|
|
837
|
-
) -> list[Any]:
|
|
838
|
-
"""
|
|
839
|
-
Multi-process map with selectable backend.
|
|
840
|
-
|
|
841
|
-
backend:
|
|
842
|
-
- "seq": run sequentially
|
|
843
|
-
- "ray": run in parallel with Ray
|
|
844
|
-
- "mp": run in parallel with thread pool (uses ThreadPoolExecutor)
|
|
845
|
-
- "safe": run in parallel with thread pool (explicitly safe for tests)
|
|
43
|
+
ensure_ray = None # type: ignore[assignment,misc]
|
|
44
|
+
RAY_WORKER = None # type: ignore[assignment,misc]
|
|
846
45
|
|
|
847
|
-
|
|
848
|
-
|
|
849
|
-
zero-copy object store
|
|
850
|
-
- Only works with Ray backend
|
|
851
|
-
- Useful for large objects (e.g., models, datasets)
|
|
852
|
-
- Example: shared_kwargs=['model', 'tokenizer']
|
|
46
|
+
# Re-export progress utilities
|
|
47
|
+
from .progress import create_progress_tracker, get_ray_progress_actor
|
|
853
48
|
|
|
854
|
-
|
|
855
|
-
|
|
856
|
-
- If False, dumping is done synchronously
|
|
49
|
+
# Re-export tqdm for backward compatibility
|
|
50
|
+
from tqdm import tqdm
|
|
857
51
|
|
|
858
|
-
ray_metrics_port:
|
|
859
|
-
- Optional port for Ray metrics export (Ray backend only)
|
|
860
52
|
|
|
861
|
-
|
|
862
|
-
|
|
863
|
-
|
|
864
|
-
|
|
865
|
-
|
|
866
|
-
|
|
867
|
-
|
|
868
|
-
|
|
869
|
-
|
|
870
|
-
|
|
871
|
-
|
|
872
|
-
|
|
873
|
-
error_handler:
|
|
874
|
-
- 'raise': raise exception on first error (default)
|
|
875
|
-
- 'ignore': continue processing, return None for failed items
|
|
876
|
-
- 'log': same as ignore, but logs errors to files
|
|
877
|
-
|
|
878
|
-
max_error_files:
|
|
879
|
-
- Maximum number of error log files to write (default: 100)
|
|
880
|
-
- Error logs are written to .cache/speedy_utils/error_logs/{idx}.log
|
|
881
|
-
- First error is always printed to screen with the log file path
|
|
882
|
-
|
|
883
|
-
If lazy_output=True, every result is saved to .pkl and
|
|
884
|
-
the returned list contains file paths.
|
|
885
|
-
"""
|
|
886
|
-
|
|
887
|
-
# default backend selection
|
|
888
|
-
if backend is None:
|
|
889
|
-
try:
|
|
890
|
-
import ray as _ray_module
|
|
891
|
-
backend = 'ray'
|
|
892
|
-
except ImportError:
|
|
893
|
-
backend = 'mp'
|
|
894
|
-
|
|
895
|
-
# Validate shared_kwargs
|
|
896
|
-
if shared_kwargs:
|
|
897
|
-
# Validate that all shared_kwargs are valid kwargs for the function
|
|
898
|
-
sig = inspect.signature(func)
|
|
899
|
-
valid_params = set(sig.parameters.keys())
|
|
900
|
-
|
|
901
|
-
for kw in shared_kwargs:
|
|
902
|
-
if kw not in func_kwargs:
|
|
903
|
-
raise ValueError(
|
|
904
|
-
f"shared_kwargs key '{kw}' not found in provided func_kwargs"
|
|
905
|
-
)
|
|
906
|
-
# Check if parameter exists in function signature or if function accepts **kwargs
|
|
907
|
-
has_var_keyword = any(
|
|
908
|
-
p.kind == inspect.Parameter.VAR_KEYWORD
|
|
909
|
-
for p in sig.parameters.values()
|
|
910
|
-
)
|
|
911
|
-
if kw not in valid_params and not has_var_keyword:
|
|
912
|
-
raise ValueError(
|
|
913
|
-
f"shared_kwargs key '{kw}' is not a valid parameter for function '{func.__name__}'. "
|
|
914
|
-
f"Valid parameters: {valid_params}"
|
|
915
|
-
)
|
|
916
|
-
|
|
917
|
-
# Prefer Ray backend when shared kwargs are requested
|
|
918
|
-
if shared_kwargs and backend != 'ray':
|
|
919
|
-
warnings.warn(
|
|
920
|
-
"shared_kwargs only supported with 'ray' backend, switching backend to 'ray'",
|
|
921
|
-
UserWarning,
|
|
922
|
-
)
|
|
923
|
-
backend = 'ray'
|
|
924
|
-
|
|
925
|
-
# unify items
|
|
926
|
-
# unify items and coerce to concrete list so we can use len() and
|
|
927
|
-
# iterate multiple times. This accepts ranges and other iterables.
|
|
928
|
-
if items is None and inputs is not None:
|
|
929
|
-
items = list(inputs)
|
|
930
|
-
if items is not None and not isinstance(items, list):
|
|
931
|
-
items = list(items)
|
|
932
|
-
if items is None:
|
|
933
|
-
raise ValueError("'items' or 'inputs' must be provided")
|
|
934
|
-
|
|
935
|
-
if workers is None and backend != 'ray':
|
|
936
|
-
workers = os.cpu_count() or 1
|
|
937
|
-
|
|
938
|
-
# build cache dir + wrap func
|
|
939
|
-
cache_dir = _build_cache_dir(func, items) if lazy_output else None
|
|
940
|
-
f_wrapped = wrap_dump(func, cache_dir, dump_in_thread)
|
|
941
|
-
|
|
942
|
-
log_gate_path: Path | None = None
|
|
943
|
-
if log_worker == 'first':
|
|
944
|
-
log_gate_path = Path(tempfile.gettempdir()) / f'speedy_utils_log_gate_{os.getpid()}_{uuid.uuid4().hex}.gate'
|
|
945
|
-
elif log_worker not in ('zero', 'all'):
|
|
946
|
-
raise ValueError(f'Unsupported log_worker: {log_worker!r}')
|
|
947
|
-
|
|
948
|
-
total = len(items)
|
|
949
|
-
if desc:
|
|
950
|
-
desc = desc.strip() + f'[{backend}]'
|
|
951
|
-
else:
|
|
952
|
-
desc = f'Multi-process [{backend}]'
|
|
953
|
-
|
|
954
|
-
# Initialize error stats for error handling
|
|
955
|
-
func_name = getattr(func, '__name__', repr(func))
|
|
956
|
-
error_stats = ErrorStats(
|
|
957
|
-
func_name=func_name,
|
|
958
|
-
max_error_files=max_error_files,
|
|
959
|
-
write_logs=error_handler == 'log'
|
|
960
|
-
)
|
|
961
|
-
|
|
962
|
-
def _update_pbar_postfix(pbar: tqdm) -> None:
|
|
963
|
-
"""Update pbar with success/error counts."""
|
|
964
|
-
postfix = error_stats.get_postfix_dict()
|
|
965
|
-
pbar.set_postfix(postfix)
|
|
966
|
-
|
|
967
|
-
def _wrap_with_error_handler(
|
|
968
|
-
f: Callable,
|
|
969
|
-
idx: int,
|
|
970
|
-
input_value: Any,
|
|
971
|
-
error_stats_ref: ErrorStats,
|
|
972
|
-
handler: ErrorHandlerType,
|
|
973
|
-
) -> Callable:
|
|
974
|
-
"""Wrap function to handle errors based on error_handler setting."""
|
|
975
|
-
@functools.wraps(f)
|
|
976
|
-
def wrapper(*args, **kwargs):
|
|
977
|
-
try:
|
|
978
|
-
result = f(*args, **kwargs)
|
|
979
|
-
error_stats_ref.record_success()
|
|
980
|
-
return result
|
|
981
|
-
except Exception as e:
|
|
982
|
-
if handler == 'raise':
|
|
983
|
-
raise
|
|
984
|
-
error_stats_ref.record_error(idx, e, input_value, func_name)
|
|
985
|
-
return None
|
|
986
|
-
return wrapper
|
|
987
|
-
|
|
988
|
-
# ---- sequential backend ----
|
|
989
|
-
if backend == 'seq':
|
|
990
|
-
results: list[Any] = []
|
|
991
|
-
with tqdm(total=total, desc=desc, disable=not progress, file=sys.stdout) as pbar:
|
|
992
|
-
for idx, x in enumerate(items):
|
|
993
|
-
try:
|
|
994
|
-
result = _call_with_log_control(
|
|
995
|
-
f_wrapped,
|
|
996
|
-
x,
|
|
997
|
-
func_kwargs,
|
|
998
|
-
log_worker,
|
|
999
|
-
log_gate_path,
|
|
1000
|
-
)
|
|
1001
|
-
error_stats.record_success()
|
|
1002
|
-
results.append(result)
|
|
1003
|
-
except Exception as e:
|
|
1004
|
-
if error_handler == 'raise':
|
|
1005
|
-
raise
|
|
1006
|
-
error_stats.record_error(idx, e, x, func_name)
|
|
1007
|
-
results.append(None)
|
|
1008
|
-
pbar.update(1)
|
|
1009
|
-
_update_pbar_postfix(pbar)
|
|
1010
|
-
_cleanup_log_gate(log_gate_path)
|
|
1011
|
-
return results
|
|
1012
|
-
|
|
1013
|
-
# ---- ray backend ----
|
|
1014
|
-
if backend == 'ray':
|
|
1015
|
-
import ray as _ray_module
|
|
1016
|
-
|
|
1017
|
-
# Capture caller frame for better error reporting
|
|
1018
|
-
caller_frame = inspect.currentframe()
|
|
1019
|
-
caller_info = None
|
|
1020
|
-
if caller_frame and caller_frame.f_back:
|
|
1021
|
-
caller_info = {
|
|
1022
|
-
'filename': caller_frame.f_back.f_code.co_filename,
|
|
1023
|
-
'lineno': caller_frame.f_back.f_lineno,
|
|
1024
|
-
'function': caller_frame.f_back.f_code.co_name,
|
|
1025
|
-
}
|
|
1026
|
-
|
|
1027
|
-
results = []
|
|
1028
|
-
gate_path_str = str(log_gate_path) if log_gate_path else None
|
|
1029
|
-
with tqdm(total=total, desc=desc, disable=not progress, file=sys.stdout) as pbar:
|
|
1030
|
-
ensure_ray(workers, pbar, ray_metrics_port)
|
|
1031
|
-
shared_refs = {}
|
|
1032
|
-
regular_kwargs = {}
|
|
1033
|
-
|
|
1034
|
-
# Create progress actor for item-level tracking if total_items specified
|
|
1035
|
-
progress_actor = None
|
|
1036
|
-
progress_poller = None
|
|
1037
|
-
if total_items is not None:
|
|
1038
|
-
progress_actor = create_progress_tracker(total_items, desc or "Items")
|
|
1039
|
-
shared_refs['progress_actor'] = progress_actor
|
|
1040
|
-
|
|
1041
|
-
if shared_kwargs:
|
|
1042
|
-
for kw in shared_kwargs:
|
|
1043
|
-
# Put large objects in Ray's object store (zero-copy)
|
|
1044
|
-
shared_refs[kw] = _ray_module.put(func_kwargs[kw])
|
|
1045
|
-
pbar.set_postfix_str(f'ray: shared `{kw}` via object store')
|
|
1046
|
-
|
|
1047
|
-
# Remaining kwargs are regular
|
|
1048
|
-
regular_kwargs = {
|
|
1049
|
-
k: v for k, v in func_kwargs.items()
|
|
1050
|
-
if k not in shared_kwargs
|
|
1051
|
-
}
|
|
1052
|
-
else:
|
|
1053
|
-
regular_kwargs = func_kwargs
|
|
1054
|
-
|
|
1055
|
-
@_ray_module.remote
|
|
1056
|
-
def _task(x, shared_refs_dict, regular_kwargs_dict):
|
|
1057
|
-
# Dereference shared objects (zero-copy for numpy arrays)
|
|
1058
|
-
import ray as _ray_in_task
|
|
1059
|
-
gate = Path(gate_path_str) if gate_path_str else None
|
|
1060
|
-
dereferenced = {}
|
|
1061
|
-
for k, v in shared_refs_dict.items():
|
|
1062
|
-
if k == 'progress_actor':
|
|
1063
|
-
dereferenced[k] = v
|
|
1064
|
-
else:
|
|
1065
|
-
dereferenced[k] = _ray_in_task.get(v)
|
|
1066
|
-
all_kwargs = {**dereferenced, **regular_kwargs_dict}
|
|
1067
|
-
return _call_with_log_control(
|
|
1068
|
-
f_wrapped,
|
|
1069
|
-
x,
|
|
1070
|
-
all_kwargs,
|
|
1071
|
-
log_worker,
|
|
1072
|
-
gate,
|
|
1073
|
-
)
|
|
1074
|
-
|
|
1075
|
-
refs = [
|
|
1076
|
-
_task.remote(x, shared_refs, regular_kwargs) for x in items
|
|
1077
|
-
]
|
|
1078
|
-
|
|
1079
|
-
t_start = time.time()
|
|
1080
|
-
|
|
1081
|
-
if progress_actor is not None:
|
|
1082
|
-
pbar.total = total_items
|
|
1083
|
-
pbar.refresh()
|
|
1084
|
-
progress_poller = ProgressPoller(progress_actor, pbar, poll_interval)
|
|
1085
|
-
progress_poller.start()
|
|
1086
|
-
|
|
1087
|
-
for idx, r in enumerate(refs):
|
|
1088
|
-
try:
|
|
1089
|
-
result = _ray_module.get(r)
|
|
1090
|
-
error_stats.record_success()
|
|
1091
|
-
results.append(result)
|
|
1092
|
-
except _ray_module.exceptions.RayTaskError as e:
|
|
1093
|
-
if error_handler == 'raise':
|
|
1094
|
-
if progress_poller is not None:
|
|
1095
|
-
progress_poller.stop()
|
|
1096
|
-
_reraise_ray_error(e, pbar, caller_info)
|
|
1097
|
-
# Extract original error from RayTaskError
|
|
1098
|
-
cause = e.cause if hasattr(e, 'cause') else e.__cause__
|
|
1099
|
-
original_error = cause if cause else e
|
|
1100
|
-
error_stats.record_error(idx, original_error, items[idx], func_name)
|
|
1101
|
-
results.append(None)
|
|
1102
|
-
|
|
1103
|
-
if progress_actor is None:
|
|
1104
|
-
pbar.update(1)
|
|
1105
|
-
_update_pbar_postfix(pbar)
|
|
1106
|
-
|
|
1107
|
-
if progress_poller is not None:
|
|
1108
|
-
progress_poller.stop()
|
|
1109
|
-
|
|
1110
|
-
t_end = time.time()
|
|
1111
|
-
item_desc = f"{total_items:,} items" if total_items else f"{total} tasks"
|
|
1112
|
-
print(f"Ray processing took {t_end - t_start:.2f}s for {item_desc}")
|
|
1113
|
-
_cleanup_log_gate(log_gate_path)
|
|
1114
|
-
return results
|
|
1115
|
-
|
|
1116
|
-
# ---- fastcore/thread backend (mp) ----
|
|
1117
|
-
if backend == 'mp':
|
|
1118
|
-
import concurrent.futures
|
|
1119
|
-
|
|
1120
|
-
# Capture caller frame for better error reporting
|
|
1121
|
-
caller_frame = inspect.currentframe()
|
|
1122
|
-
caller_info = None
|
|
1123
|
-
if caller_frame and caller_frame.f_back:
|
|
1124
|
-
caller_info = {
|
|
1125
|
-
'filename': caller_frame.f_back.f_code.co_filename,
|
|
1126
|
-
'lineno': caller_frame.f_back.f_lineno,
|
|
1127
|
-
'function': caller_frame.f_back.f_code.co_name,
|
|
1128
|
-
}
|
|
1129
|
-
|
|
1130
|
-
def worker_func(x):
|
|
1131
|
-
return _call_with_log_control(
|
|
1132
|
-
f_wrapped,
|
|
1133
|
-
x,
|
|
1134
|
-
func_kwargs,
|
|
1135
|
-
log_worker,
|
|
1136
|
-
log_gate_path,
|
|
1137
|
-
)
|
|
1138
|
-
|
|
1139
|
-
results: list[Any] = [None] * total
|
|
1140
|
-
with tqdm(total=total, desc=desc, disable=not progress, file=sys.stdout) as pbar:
|
|
1141
|
-
try:
|
|
1142
|
-
from .thread import _prune_dead_threads, _track_executor_threads
|
|
1143
|
-
has_thread_tracking = True
|
|
1144
|
-
except ImportError:
|
|
1145
|
-
has_thread_tracking = False
|
|
1146
|
-
|
|
1147
|
-
with concurrent.futures.ThreadPoolExecutor(
|
|
1148
|
-
max_workers=workers
|
|
1149
|
-
) as executor:
|
|
1150
|
-
if has_thread_tracking:
|
|
1151
|
-
_track_executor_threads(executor)
|
|
1152
|
-
|
|
1153
|
-
# Submit all tasks
|
|
1154
|
-
future_to_idx = {
|
|
1155
|
-
executor.submit(worker_func, x): idx
|
|
1156
|
-
for idx, x in enumerate(items)
|
|
1157
|
-
}
|
|
1158
|
-
|
|
1159
|
-
# Process results as they complete
|
|
1160
|
-
for future in concurrent.futures.as_completed(future_to_idx):
|
|
1161
|
-
idx = future_to_idx[future]
|
|
1162
|
-
try:
|
|
1163
|
-
result = future.result()
|
|
1164
|
-
error_stats.record_success()
|
|
1165
|
-
results[idx] = result
|
|
1166
|
-
except Exception as e:
|
|
1167
|
-
if error_handler == 'raise':
|
|
1168
|
-
# Cancel remaining futures
|
|
1169
|
-
for f in future_to_idx:
|
|
1170
|
-
f.cancel()
|
|
1171
|
-
_reraise_worker_error(e, pbar, caller_info, backend='mp')
|
|
1172
|
-
error_stats.record_error(idx, e, items[idx], func_name)
|
|
1173
|
-
results[idx] = None
|
|
1174
|
-
pbar.update(1)
|
|
1175
|
-
_update_pbar_postfix(pbar)
|
|
1176
|
-
|
|
1177
|
-
if _prune_dead_threads is not None:
|
|
1178
|
-
_prune_dead_threads()
|
|
1179
|
-
|
|
1180
|
-
_track_multiprocessing_processes()
|
|
1181
|
-
_prune_dead_processes()
|
|
1182
|
-
_cleanup_log_gate(log_gate_path)
|
|
1183
|
-
return results
|
|
1184
|
-
|
|
1185
|
-
if backend == 'safe':
|
|
1186
|
-
# Completely safe backend for tests - no multiprocessing
|
|
1187
|
-
import concurrent.futures
|
|
1188
|
-
|
|
1189
|
-
# Capture caller frame for better error reporting
|
|
1190
|
-
caller_frame = inspect.currentframe()
|
|
1191
|
-
caller_info = None
|
|
1192
|
-
if caller_frame and caller_frame.f_back:
|
|
1193
|
-
caller_info = {
|
|
1194
|
-
'filename': caller_frame.f_back.f_code.co_filename,
|
|
1195
|
-
'lineno': caller_frame.f_back.f_lineno,
|
|
1196
|
-
'function': caller_frame.f_back.f_code.co_name,
|
|
1197
|
-
}
|
|
1198
|
-
|
|
1199
|
-
def worker_func(x):
|
|
1200
|
-
return _call_with_log_control(
|
|
1201
|
-
f_wrapped,
|
|
1202
|
-
x,
|
|
1203
|
-
func_kwargs,
|
|
1204
|
-
log_worker,
|
|
1205
|
-
log_gate_path,
|
|
1206
|
-
)
|
|
1207
|
-
|
|
1208
|
-
results: list[Any] = [None] * total
|
|
1209
|
-
with tqdm(total=total, desc=desc, disable=not progress, file=sys.stdout) as pbar:
|
|
1210
|
-
with concurrent.futures.ThreadPoolExecutor(
|
|
1211
|
-
max_workers=workers
|
|
1212
|
-
) as executor:
|
|
1213
|
-
if _track_executor_threads is not None:
|
|
1214
|
-
_track_executor_threads(executor)
|
|
1215
|
-
|
|
1216
|
-
# Submit all tasks
|
|
1217
|
-
future_to_idx = {
|
|
1218
|
-
executor.submit(worker_func, x): idx
|
|
1219
|
-
for idx, x in enumerate(items)
|
|
1220
|
-
}
|
|
1221
|
-
|
|
1222
|
-
# Process results as they complete
|
|
1223
|
-
for future in concurrent.futures.as_completed(future_to_idx):
|
|
1224
|
-
idx = future_to_idx[future]
|
|
1225
|
-
try:
|
|
1226
|
-
result = future.result()
|
|
1227
|
-
error_stats.record_success()
|
|
1228
|
-
results[idx] = result
|
|
1229
|
-
except Exception as e:
|
|
1230
|
-
if error_handler == 'raise':
|
|
1231
|
-
for f in future_to_idx:
|
|
1232
|
-
f.cancel()
|
|
1233
|
-
_reraise_worker_error(e, pbar, caller_info, backend='safe')
|
|
1234
|
-
error_stats.record_error(idx, e, items[idx], func_name)
|
|
1235
|
-
results[idx] = None
|
|
1236
|
-
pbar.update(1)
|
|
1237
|
-
_update_pbar_postfix(pbar)
|
|
1238
|
-
|
|
1239
|
-
if _prune_dead_threads is not None:
|
|
1240
|
-
_prune_dead_threads()
|
|
1241
|
-
|
|
1242
|
-
_cleanup_log_gate(log_gate_path)
|
|
1243
|
-
return results
|
|
1244
|
-
|
|
1245
|
-
raise ValueError(f'Unsupported backend: {backend!r}')
|
|
53
|
+
__all__ = [
|
|
54
|
+
'SPEEDY_RUNNING_PROCESSES',
|
|
55
|
+
'ErrorStats',
|
|
56
|
+
'ErrorHandlerType',
|
|
57
|
+
'multi_process',
|
|
58
|
+
'cleanup_phantom_workers',
|
|
59
|
+
'create_progress_tracker',
|
|
60
|
+
'get_ray_progress_actor',
|
|
61
|
+
'ensure_ray',
|
|
62
|
+
'RAY_WORKER',
|
|
63
|
+
'tqdm',
|
|
64
|
+
]
|
|
1246
65
|
|
|
1247
66
|
|
|
1248
67
|
def cleanup_phantom_workers():
|