speedy-utils 1.1.40__py3-none-any.whl → 1.1.43__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llm_utils/__init__.py +2 -0
- llm_utils/llm_ray.py +370 -0
- llm_utils/lm/llm.py +36 -29
- speedy_utils/__init__.py +10 -0
- speedy_utils/common/utils_io.py +3 -1
- speedy_utils/multi_worker/__init__.py +12 -0
- speedy_utils/multi_worker/dataset_ray.py +303 -0
- speedy_utils/multi_worker/parallel_gpu_pool.py +178 -0
- speedy_utils/multi_worker/process.py +989 -86
- speedy_utils/multi_worker/progress.py +140 -0
- speedy_utils/multi_worker/thread.py +202 -42
- speedy_utils/scripts/mpython.py +49 -4
- {speedy_utils-1.1.40.dist-info → speedy_utils-1.1.43.dist-info}/METADATA +5 -3
- {speedy_utils-1.1.40.dist-info → speedy_utils-1.1.43.dist-info}/RECORD +16 -12
- {speedy_utils-1.1.40.dist-info → speedy_utils-1.1.43.dist-info}/WHEEL +0 -0
- {speedy_utils-1.1.40.dist-info → speedy_utils-1.1.43.dist-info}/entry_points.txt +0 -0
|
@@ -0,0 +1,140 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Real-time progress tracking for distributed Ray tasks.
|
|
3
|
+
|
|
4
|
+
This module provides a ProgressActor that allows workers to report item-level
|
|
5
|
+
progress in real-time, giving users visibility into actual items processed
|
|
6
|
+
rather than just task completion.
|
|
7
|
+
"""
|
|
8
|
+
import time
|
|
9
|
+
import threading
|
|
10
|
+
from typing import Optional, Callable
|
|
11
|
+
|
|
12
|
+
__all__ = ['ProgressActor', 'create_progress_tracker', 'get_ray_progress_actor']
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def get_ray_progress_actor():
    """Return the Ray-decorated ``ProgressActor`` class.

    Ray is imported lazily inside this function so that importing the module
    does not require Ray. Each call decorates and returns a new actor class;
    instantiate it with ``ProgressActor.remote(total, desc)``.

    Returns:
        The ``@ray.remote``-decorated ``ProgressActor`` class.
    """
    import ray

    @ray.remote
    class ProgressActor:
        """
        A Ray actor for tracking real-time progress across distributed workers.

        Workers call `update(n)` to report items processed, and the main process
        can poll `get_progress()` to update a tqdm bar in real-time.
        """

        def __init__(self, total: int, desc: str = "Items"):
            self.total = total
            self.processed = 0
            self.desc = desc
            # Monotonic clock: the elapsed time (and therefore the rate) must
            # not jump or go negative when the system wall clock is adjusted.
            self.start_time = time.monotonic()
            # Guards the counters in case the actor's methods ever run
            # concurrently (e.g. a threaded actor).
            self._lock = threading.Lock()

        def update(self, n: int = 1) -> int:
            """Increment processed count by n. Returns the new processed count."""
            with self._lock:
                self.processed += n
                return self.processed

        def get_progress(self) -> dict:
            """Return a snapshot with processed, total, elapsed (s), rate (items/s) and desc."""
            with self._lock:
                elapsed = time.monotonic() - self.start_time
                rate = self.processed / elapsed if elapsed > 0 else 0.0
                return {
                    "processed": self.processed,
                    "total": self.total,
                    "elapsed": elapsed,
                    "rate": rate,
                    "desc": self.desc,
                }

        def set_total(self, total: int):
            """Update total (useful if exact count unknown at start)."""
            with self._lock:
                self.total = total

        def reset(self):
            """Reset the processed counter and restart the elapsed-time clock."""
            with self._lock:
                self.processed = 0
                self.start_time = time.monotonic()

    return ProgressActor
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def create_progress_tracker(total: int, desc: str = "Items"):
    """
    Create a progress tracker actor for use with Ray distributed tasks.

    Args:
        total: Total number of items to process
        desc: Description for the progress bar

    Returns:
        A Ray actor handle that workers can use to report progress

    Example:
        progress_actor = create_progress_tracker(1000000, "Processing items")

        @ray.remote
        def worker(items, progress_actor):
            for item in items:
                process(item)
                # Fire-and-forget: do not ray.get() here, otherwise every
                # item blocks on a round-trip to the actor. For very large
                # inputs, prefer reporting in chunks, e.g. update.remote(100).
                progress_actor.update.remote(1)

        # In main process, poll progress:
        while not done:
            stats = ray.get(progress_actor.get_progress.remote())
            pbar.n = stats["processed"]
            pbar.refresh()
    """
    # Ray is imported lazily by get_ray_progress_actor(); no import needed here.
    actor_cls = get_ray_progress_actor()
    return actor_cls.remote(total, desc)
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
class ProgressPoller:
    """
    Background thread that polls a Ray progress actor and updates a tqdm bar.

    ``progress_actor.get_progress()`` must return a dict with at least the
    ``"processed"`` and ``"rate"`` keys (as ``ProgressActor`` does). ``pbar``
    must be tqdm-like: a settable ``n`` attribute plus ``set_postfix_str()``
    and ``refresh()``.

    Use ``start()`` / ``stop()`` directly, or use the poller as a context
    manager::

        with ProgressPoller(actor, pbar):
            ray.get(refs)
    """

    def __init__(self, progress_actor, pbar, poll_interval: float = 0.5):
        import ray

        self._ray = ray
        self.progress_actor = progress_actor
        self.pbar = pbar
        self.poll_interval = poll_interval
        self._stop_event = threading.Event()
        self._thread: Optional[threading.Thread] = None

    def __enter__(self) -> "ProgressPoller":
        self.start()
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.stop()

    def start(self):
        """Start the polling thread; a no-op if it is already running."""
        if self._thread is not None and self._thread.is_alive():
            # Avoid two threads fighting over the same bar.
            return
        # Clear the flag so that a poller stopped earlier can be restarted;
        # otherwise the new thread would exit immediately.
        self._stop_event.clear()
        self._thread = threading.Thread(target=self._poll_loop, daemon=True)
        self._thread.start()

    def stop(self):
        """Signal the polling thread to stop and wait up to 2 s for it to finish."""
        self._stop_event.set()
        if self._thread is not None:
            self._thread.join(timeout=2.0)

    def _refresh(self) -> None:
        """Fetch one progress snapshot from the actor and push it to the bar."""
        stats = self._ray.get(self.progress_actor.get_progress.remote())
        self.pbar.n = stats["processed"]
        self.pbar.set_postfix_str(f'{stats["rate"]:.1f} items/s')
        self.pbar.refresh()

    def _poll_loop(self):
        """Poll the progress actor until stopped, then do one final update."""
        while not self._stop_event.is_set():
            try:
                self._refresh()
            except Exception:
                pass  # Best-effort display: a failed poll must not kill the thread
            # Event.wait returns early as soon as stop() sets the event.
            self._stop_event.wait(self.poll_interval)

        # Final update so the bar shows the last count reported before stop().
        try:
            self._refresh()
        except Exception:
            pass
|
|
@@ -1,4 +1,7 @@
|
|
|
1
1
|
from ..__imports import *
|
|
2
|
+
import linecache
|
|
3
|
+
|
|
4
|
+
from .process import ErrorStats, ErrorHandlerType
|
|
2
5
|
|
|
3
6
|
|
|
4
7
|
try:
|
|
@@ -6,6 +9,17 @@ try:
|
|
|
6
9
|
except ImportError: # pragma: no cover
|
|
7
10
|
tqdm = None # type: ignore[assignment]
|
|
8
11
|
|
|
12
|
+
try:
|
|
13
|
+
from rich.console import Console
|
|
14
|
+
from rich.panel import Panel
|
|
15
|
+
from rich.syntax import Syntax
|
|
16
|
+
from rich.text import Text
|
|
17
|
+
except ImportError: # pragma: no cover
|
|
18
|
+
Console = None # type: ignore[assignment, misc]
|
|
19
|
+
Panel = None # type: ignore[assignment, misc]
|
|
20
|
+
Syntax = None # type: ignore[assignment, misc]
|
|
21
|
+
Text = None # type: ignore[assignment, misc]
|
|
22
|
+
|
|
9
23
|
# Sensible defaults
|
|
10
24
|
DEFAULT_WORKERS = (os.cpu_count() or 4) * 2
|
|
11
25
|
|
|
@@ -25,11 +39,13 @@ class UserFunctionError(Exception):
|
|
|
25
39
|
func_name: str,
|
|
26
40
|
input_value: Any,
|
|
27
41
|
user_traceback: list[traceback.FrameSummary],
|
|
42
|
+
caller_frame: traceback.FrameSummary | None = None,
|
|
28
43
|
) -> None:
|
|
29
44
|
self.original_exception = original_exception
|
|
30
45
|
self.func_name = func_name
|
|
31
46
|
self.input_value = input_value
|
|
32
47
|
self.user_traceback = user_traceback
|
|
48
|
+
self.caller_frame = caller_frame
|
|
33
49
|
|
|
34
50
|
# Create a focused error message
|
|
35
51
|
tb_str = ''.join(traceback.format_list(user_traceback))
|
|
@@ -44,6 +60,95 @@ class UserFunctionError(Exception):
|
|
|
44
60
|
# Return focused error without infrastructure frames
|
|
45
61
|
return super().__str__()
|
|
46
62
|
|
|
63
|
+
def format_rich(self) -> None:
    """Print this error to stderr as a rich panel with source-code context.

    The captured caller frame (if any) is shown first, followed by each user
    frame that has a line number. Each frame shows three lines of source on
    either side, and the original exception's type and message come last.
    When rich is unavailable, falls back to printing ``str(self)`` to stderr.

    Filenames, function names and the exception message are escaped before
    they are embedded in rich markup. Text such as ``data[idx]`` or ``[/x]``
    is therefore shown literally instead of being swallowed or rejected as
    markup.
    """
    if Console is None or Panel is None or Text is None:
        # Fallback to plain text
        print(str(self), file=sys.stderr)
        return

    # Safe to import here: the guard above means rich is installed.
    from rich.markup import escape

    console = Console(stderr=True, force_terminal=True)

    def frame_parts(frame: traceback.FrameSummary) -> list[str]:
        # "file:line in func" header, blank line, code context, blank line.
        return [
            f'[cyan]{escape(frame.filename)}[/cyan]:[yellow]{frame.lineno}[/yellow] '
            f'in [green]{escape(frame.name)}[/green]',
            '',
            *_get_code_context_rich(frame.filename, frame.lineno, 3),
            '',
        ]

    # Build traceback display with code context
    tb_parts: list[str] = []

    # Show caller frame first if available
    if self.caller_frame and self.caller_frame.lineno is not None:
        tb_parts.extend(frame_parts(self.caller_frame))

    # Show user code frames with context
    for frame in self.user_traceback:
        if frame.lineno is not None:
            tb_parts.extend(frame_parts(frame))

    # Print with rich Panel
    console.print()
    console.print(
        Panel(
            '\n'.join(tb_parts),
            title='[bold red]Traceback (most recent call last)[/bold red]',
            border_style='red',
            expand=False,
        )
    )
    console.print(
        f'[bold red]{type(self.original_exception).__name__}[/bold red]: '
        f'{escape(str(self.original_exception))}'
    )
    console.print()
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def _get_code_context(filename: str, lineno: int, context_lines: int = 3) -> list[str]:
|
|
116
|
+
"""Get code context around a line with line numbers and highlighting."""
|
|
117
|
+
lines: list[str] = []
|
|
118
|
+
start = max(1, lineno - context_lines)
|
|
119
|
+
end = lineno + context_lines
|
|
120
|
+
|
|
121
|
+
for i in range(start, end + 1):
|
|
122
|
+
line = linecache.getline(filename, i)
|
|
123
|
+
if not line:
|
|
124
|
+
continue
|
|
125
|
+
line = line.rstrip()
|
|
126
|
+
marker = '❱' if i == lineno else ' '
|
|
127
|
+
lines.append(f' {i:4d} {marker} {line}')
|
|
128
|
+
|
|
129
|
+
return lines
|
|
130
|
+
|
|
131
|
+
def _get_code_context_rich(filename: str, lineno: int, context_lines: int = 3) -> list[str]:
|
|
132
|
+
"""Get code context with rich formatting (colors)."""
|
|
133
|
+
lines: list[str] = []
|
|
134
|
+
start = max(1, lineno - context_lines)
|
|
135
|
+
end = lineno + context_lines
|
|
136
|
+
|
|
137
|
+
for i in range(start, end + 1):
|
|
138
|
+
line = linecache.getline(filename, i)
|
|
139
|
+
if not line:
|
|
140
|
+
continue
|
|
141
|
+
line = line.rstrip()
|
|
142
|
+
num_str = f'{i:4d}'
|
|
143
|
+
|
|
144
|
+
if i == lineno:
|
|
145
|
+
# Highlight error line
|
|
146
|
+
lines.append(f'[dim]{num_str}[/dim] [red]❱[/red] {line}')
|
|
147
|
+
else:
|
|
148
|
+
# Normal context line
|
|
149
|
+
lines.append(f'[dim]{num_str} │[/dim] {line}')
|
|
150
|
+
|
|
151
|
+
return lines
|
|
47
152
|
|
|
48
153
|
_PY_SET_ASYNC_EXC = ctypes.pythonapi.PyThreadState_SetAsyncExc
|
|
49
154
|
try:
|
|
@@ -90,6 +195,7 @@ def _worker(
|
|
|
90
195
|
item: T,
|
|
91
196
|
func: Callable[[T], R],
|
|
92
197
|
fixed_kwargs: Mapping[str, Any],
|
|
198
|
+
caller_frame: traceback.FrameSummary | None = None,
|
|
93
199
|
) -> R:
|
|
94
200
|
"""Execute the function with an item and fixed kwargs."""
|
|
95
201
|
# Validate func is callable before attempting to call it
|
|
@@ -102,7 +208,7 @@ def _worker(
|
|
|
102
208
|
)
|
|
103
209
|
|
|
104
210
|
try:
|
|
105
|
-
return func(item
|
|
211
|
+
return func(item)
|
|
106
212
|
except Exception as exc:
|
|
107
213
|
# Extract user code traceback (filter out infrastructure)
|
|
108
214
|
exc_tb = sys.exc_info()[2]
|
|
@@ -114,8 +220,11 @@ def _worker(
|
|
|
114
220
|
user_frames = []
|
|
115
221
|
skip_patterns = [
|
|
116
222
|
'multi_worker/thread.py',
|
|
223
|
+
'multi_worker/process.py',
|
|
117
224
|
'concurrent/futures/',
|
|
118
225
|
'threading.py',
|
|
226
|
+
'multiprocessing/',
|
|
227
|
+
'site-packages/ray/',
|
|
119
228
|
]
|
|
120
229
|
|
|
121
230
|
for frame in tb_list:
|
|
@@ -130,6 +239,7 @@ def _worker(
|
|
|
130
239
|
func_name,
|
|
131
240
|
item,
|
|
132
241
|
user_frames,
|
|
242
|
+
caller_frame,
|
|
133
243
|
) from exc
|
|
134
244
|
|
|
135
245
|
# Fallback: re-raise original if we couldn't extract frames
|
|
@@ -140,8 +250,9 @@ def _run_batch(
|
|
|
140
250
|
items: Sequence[T],
|
|
141
251
|
func: Callable[[T], R],
|
|
142
252
|
fixed_kwargs: Mapping[str, Any],
|
|
253
|
+
caller_frame: traceback.FrameSummary | None = None,
|
|
143
254
|
) -> list[R]:
|
|
144
|
-
return [_worker(item, func, fixed_kwargs) for item in items]
|
|
255
|
+
return [_worker(item, func, fixed_kwargs, caller_frame) for item in items]
|
|
145
256
|
|
|
146
257
|
|
|
147
258
|
def _attach_metadata(fut: Future[Any], idx: int, logical_size: int) -> None:
|
|
@@ -242,7 +353,9 @@ def multi_thread(
|
|
|
242
353
|
progress_update: int = 10,
|
|
243
354
|
prefetch_factor: int = 4,
|
|
244
355
|
timeout: float | None = None,
|
|
245
|
-
stop_on_error: bool =
|
|
356
|
+
stop_on_error: bool | None = None,
|
|
357
|
+
error_handler: ErrorHandlerType = 'raise',
|
|
358
|
+
max_error_files: int = 100,
|
|
246
359
|
n_proc: int = 0,
|
|
247
360
|
store_output_pkl_file: str | None = None,
|
|
248
361
|
**fixed_kwargs: Any,
|
|
@@ -272,8 +385,16 @@ def multi_thread(
|
|
|
272
385
|
Multiplier controlling in-flight items (``workers * prefetch_factor``).
|
|
273
386
|
timeout : float | None, optional
|
|
274
387
|
Overall wall-clock timeout in seconds.
|
|
275
|
-
stop_on_error : bool, optional
|
|
276
|
-
|
|
388
|
+
stop_on_error : bool | None, optional
|
|
389
|
+
Deprecated. Use error_handler instead.
|
|
390
|
+
When True -> error_handler='raise', when False -> error_handler='log'.
|
|
391
|
+
error_handler : 'raise' | 'ignore' | 'log', optional
|
|
392
|
+
- 'raise': raise exception on first error (default)
|
|
393
|
+
- 'ignore': continue, return None for failed items
|
|
394
|
+
- 'log': same as ignore, but logs errors to files
|
|
395
|
+
max_error_files : int, optional
|
|
396
|
+
Maximum number of error log files to write (default: 100).
|
|
397
|
+
Error logs are written to .cache/speedy_utils/error_logs/{idx}.log
|
|
277
398
|
n_proc : int, optional
|
|
278
399
|
Optional process-level fan-out; ``>1`` shards work across processes.
|
|
279
400
|
store_output_pkl_file : str | None, optional
|
|
@@ -285,10 +406,20 @@ def multi_thread(
|
|
|
285
406
|
-------
|
|
286
407
|
list[R | None]
|
|
287
408
|
Collected results, preserving order when requested. Failed tasks yield
|
|
288
|
-
``None`` entries if ``
|
|
409
|
+
``None`` entries if ``error_handler`` is not 'raise'.
|
|
289
410
|
"""
|
|
290
411
|
from speedy_utils import dump_json_or_pickle, load_by_ext
|
|
291
412
|
|
|
413
|
+
# Handle deprecated stop_on_error parameter
|
|
414
|
+
if stop_on_error is not None:
|
|
415
|
+
import warnings
|
|
416
|
+
warnings.warn(
|
|
417
|
+
"stop_on_error is deprecated, use error_handler instead",
|
|
418
|
+
DeprecationWarning,
|
|
419
|
+
stacklevel=2
|
|
420
|
+
)
|
|
421
|
+
error_handler = 'raise' if stop_on_error else 'log'
|
|
422
|
+
|
|
292
423
|
if n_proc > 1:
|
|
293
424
|
import tempfile
|
|
294
425
|
|
|
@@ -319,7 +450,8 @@ def multi_thread(
|
|
|
319
450
|
progress_update=progress_update,
|
|
320
451
|
prefetch_factor=prefetch_factor,
|
|
321
452
|
timeout=timeout,
|
|
322
|
-
|
|
453
|
+
error_handler=error_handler,
|
|
454
|
+
max_error_files=max_error_files,
|
|
323
455
|
n_proc=0,
|
|
324
456
|
store_output_pkl_file=file_pkl,
|
|
325
457
|
**fixed_kwargs,
|
|
@@ -363,12 +495,30 @@ def multi_thread(
|
|
|
363
495
|
if batch == 1 and logical_total and logical_total / max(workers_val, 1) > 20_000:
|
|
364
496
|
batch = 32
|
|
365
497
|
|
|
366
|
-
src_iter:
|
|
498
|
+
src_iter: Iterator[Any] = iter(inputs)
|
|
367
499
|
if batch > 1:
|
|
368
|
-
src_iter = _group_iter(src_iter, batch)
|
|
369
|
-
src_iter = iter(src_iter)
|
|
500
|
+
src_iter = iter(_group_iter(src_iter, batch))
|
|
370
501
|
collector: _ResultCollector[Any] = _ResultCollector(ordered, logical_total)
|
|
371
502
|
|
|
503
|
+
# Initialize error stats for error handling
|
|
504
|
+
func_name = getattr(func, '__name__', repr(func))
|
|
505
|
+
error_stats = ErrorStats(
|
|
506
|
+
func_name=func_name,
|
|
507
|
+
max_error_files=max_error_files,
|
|
508
|
+
write_logs=error_handler == 'log'
|
|
509
|
+
)
|
|
510
|
+
|
|
511
|
+
# Convert inputs to list for index access in error logging
|
|
512
|
+
items_list: list[Any] | None = None
|
|
513
|
+
if error_handler != 'raise':
|
|
514
|
+
try:
|
|
515
|
+
items_list = list(inputs)
|
|
516
|
+
src_iter = iter(items_list)
|
|
517
|
+
if batch > 1:
|
|
518
|
+
src_iter = iter(_group_iter(src_iter, batch))
|
|
519
|
+
except Exception:
|
|
520
|
+
items_list = None
|
|
521
|
+
|
|
372
522
|
bar = None
|
|
373
523
|
last_bar_update = 0
|
|
374
524
|
if (
|
|
@@ -382,10 +532,22 @@ def multi_thread(
|
|
|
382
532
|
ncols=128,
|
|
383
533
|
colour='green',
|
|
384
534
|
bar_format=(
|
|
385
|
-
'{l_bar}{bar}| {n_fmt}/{total_fmt}
|
|
535
|
+
'{l_bar}{bar}| {n_fmt}/{total_fmt} '
|
|
536
|
+
'[{elapsed}<{remaining}, {rate_fmt}{postfix}]'
|
|
386
537
|
),
|
|
387
538
|
)
|
|
388
539
|
|
|
540
|
+
# Capture caller context for error reporting
|
|
541
|
+
caller_frame_obj = inspect.currentframe()
|
|
542
|
+
caller_context: traceback.FrameSummary | None = None
|
|
543
|
+
if caller_frame_obj and caller_frame_obj.f_back:
|
|
544
|
+
caller_info = inspect.getframeinfo(caller_frame_obj.f_back)
|
|
545
|
+
caller_context = traceback.FrameSummary(
|
|
546
|
+
caller_info.filename,
|
|
547
|
+
caller_info.lineno,
|
|
548
|
+
caller_info.function,
|
|
549
|
+
)
|
|
550
|
+
|
|
389
551
|
deadline = time.monotonic() + timeout if timeout is not None else None
|
|
390
552
|
max_inflight = max(workers_val * prefetch_factor, 1)
|
|
391
553
|
completed_items = 0
|
|
@@ -409,10 +571,10 @@ def multi_thread(
|
|
|
409
571
|
batch_items = list(arg)
|
|
410
572
|
if not batch_items:
|
|
411
573
|
return
|
|
412
|
-
fut = pool.submit(_run_batch, batch_items, func, fixed_kwargs_map)
|
|
574
|
+
fut = pool.submit(_run_batch, batch_items, func, fixed_kwargs_map, caller_context)
|
|
413
575
|
logical_size = len(batch_items)
|
|
414
576
|
else:
|
|
415
|
-
fut = pool.submit(_worker, arg, func, fixed_kwargs_map)
|
|
577
|
+
fut = pool.submit(_worker, arg, func, fixed_kwargs_map, caller_context)
|
|
416
578
|
logical_size = 1
|
|
417
579
|
_attach_metadata(fut, next_logical_idx, logical_size)
|
|
418
580
|
next_logical_idx += logical_size
|
|
@@ -453,37 +615,37 @@ def multi_thread(
|
|
|
453
615
|
idx, logical_size = _future_meta(fut)
|
|
454
616
|
try:
|
|
455
617
|
result = fut.result()
|
|
618
|
+
# Record success for each item in the batch
|
|
619
|
+
for _ in range(logical_size):
|
|
620
|
+
error_stats.record_success()
|
|
456
621
|
except UserFunctionError as exc:
|
|
457
|
-
# User function error
|
|
458
|
-
|
|
459
|
-
|
|
460
|
-
|
|
622
|
+
# User function error
|
|
623
|
+
if error_handler == 'raise':
|
|
624
|
+
sys.stderr.flush()
|
|
625
|
+
sys.stdout.flush()
|
|
626
|
+
exc.format_rich()
|
|
627
|
+
sys.stderr.flush()
|
|
461
628
|
_cancel_futures(inflight)
|
|
462
|
-
|
|
463
|
-
|
|
464
|
-
|
|
465
|
-
|
|
466
|
-
|
|
467
|
-
|
|
468
|
-
|
|
469
|
-
|
|
470
|
-
|
|
471
|
-
f'{type(orig_exc).__name__}: {orig_exc}'
|
|
472
|
-
)
|
|
473
|
-
|
|
474
|
-
# Raise a new instance of the original exception type
|
|
475
|
-
# with our clean message
|
|
476
|
-
new_exc = type(orig_exc)(clean_msg)
|
|
477
|
-
# Suppress the "from" chain to avoid showing infrastructure
|
|
478
|
-
raise new_exc from None
|
|
479
|
-
|
|
629
|
+
sys.exit(1)
|
|
630
|
+
|
|
631
|
+
# Log error with ErrorStats
|
|
632
|
+
input_val = None
|
|
633
|
+
if items_list is not None and idx < len(items_list):
|
|
634
|
+
input_val = items_list[idx]
|
|
635
|
+
error_stats.record_error(
|
|
636
|
+
idx, exc.original_exception, input_val, func_name
|
|
637
|
+
)
|
|
480
638
|
out_items = [None] * logical_size
|
|
481
639
|
except Exception as exc:
|
|
482
640
|
# Other errors (infrastructure, batching, etc.)
|
|
483
|
-
if
|
|
641
|
+
if error_handler == 'raise':
|
|
484
642
|
_cancel_futures(inflight)
|
|
485
643
|
raise
|
|
486
|
-
|
|
644
|
+
|
|
645
|
+
input_val = None
|
|
646
|
+
if items_list is not None and idx < len(items_list):
|
|
647
|
+
input_val = items_list[idx]
|
|
648
|
+
error_stats.record_error(idx, exc, input_val, func_name)
|
|
487
649
|
out_items = [None] * logical_size
|
|
488
650
|
else:
|
|
489
651
|
try:
|
|
@@ -503,15 +665,13 @@ def multi_thread(
|
|
|
503
665
|
bar.update(delta)
|
|
504
666
|
last_bar_update = completed_items
|
|
505
667
|
submitted = next_logical_idx
|
|
506
|
-
pending = (
|
|
668
|
+
pending: int | str = (
|
|
507
669
|
max(logical_total - submitted, 0)
|
|
508
670
|
if logical_total is not None
|
|
509
671
|
else '-'
|
|
510
672
|
)
|
|
511
|
-
postfix =
|
|
512
|
-
|
|
513
|
-
'pending': pending,
|
|
514
|
-
}
|
|
673
|
+
postfix: dict[str, Any] = error_stats.get_postfix_dict()
|
|
674
|
+
postfix['pending'] = pending
|
|
515
675
|
bar.set_postfix(postfix)
|
|
516
676
|
|
|
517
677
|
try:
|
speedy_utils/scripts/mpython.py
CHANGED
|
@@ -3,13 +3,59 @@ import argparse
|
|
|
3
3
|
import itertools
|
|
4
4
|
import multiprocessing # Import multiprocessing module
|
|
5
5
|
import os
|
|
6
|
+
import re
|
|
6
7
|
import shlex # To properly escape command line arguments
|
|
7
8
|
import shutil
|
|
9
|
+
import subprocess
|
|
8
10
|
|
|
9
11
|
|
|
10
12
|
taskset_path = shutil.which('taskset')
|
|
11
13
|
|
|
12
14
|
|
|
15
|
+
def get_existing_tmux_sessions():
    """Get list of existing tmux session names.

    Runs ``tmux list-sessions -F '#{session_name}'``. Returns an empty list
    when tmux cannot be executed (not installed, not executable) or exits
    with a non-zero status, for example when no tmux server is running.
    Empty output yields ``[]`` rather than ``['']``.
    """
    try:
        result = subprocess.run(
            ['tmux', 'list-sessions', '-F', '#{session_name}'],
            capture_output=True,
            text=True,
        )
    except OSError:
        # FileNotFoundError: tmux not installed; PermissionError and other
        # OSErrors: tmux present but not runnable. No sessions are visible.
        return []
    if result.returncode != 0:
        return []
    # splitlines() on an empty string gives [], unlike split('\n').
    return result.stdout.strip().splitlines()
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def get_next_session_name(base_name='mpython', existing_sessions=None):
    """Get next available session name.

    If ``base_name`` is not in use, it is returned unchanged. Otherwise the
    result is ``f'{base_name}-{N}'``, where ``N`` is one more than the highest
    numeric suffix already used with ``base_name`` (``1`` if there is none).
    Gaps left by killed sessions are not reused.

    Args:
        base_name: Preferred session name.
        existing_sessions: Session names to check against. If ``None`` (the
            default), the live list is fetched with
            ``get_existing_tmux_sessions()``. Passing a list avoids an extra
            ``tmux`` call when the caller already has one.
    """
    if existing_sessions is None:
        existing_sessions = get_existing_tmux_sessions()
    taken = set(existing_sessions)

    if base_name not in taken:
        return base_name

    # Collect N from every existing "<base_name>-N" session.
    suffix_re = re.compile(rf'^{re.escape(base_name)}-(\d+)$')
    used = [int(m.group(1)) for m in map(suffix_re.match, taken) if m]

    return f'{base_name}-{max(used, default=0) + 1}'
|
|
57
|
+
|
|
58
|
+
|
|
13
59
|
def assert_script(python_path):
|
|
14
60
|
with open(python_path) as f:
|
|
15
61
|
code_str = f.read()
|
|
@@ -30,10 +76,7 @@ def assert_script(python_path):
|
|
|
30
76
|
|
|
31
77
|
def run_in_tmux(commands_to_run, tmux_name, num_windows):
|
|
32
78
|
with open('/tmp/start_multirun_tmux.sh', 'w') as script_file:
|
|
33
|
-
# first cmd is to kill the session if it exists
|
|
34
|
-
|
|
35
79
|
script_file.write('#!/bin/bash\n\n')
|
|
36
|
-
script_file.write(f'tmux kill-session -t {tmux_name}\nsleep .1\n')
|
|
37
80
|
script_file.write(f'tmux new-session -d -s {tmux_name}\n')
|
|
38
81
|
for i, cmd in enumerate(itertools.cycle(commands_to_run)):
|
|
39
82
|
if i >= num_windows:
|
|
@@ -99,9 +142,11 @@ def main():
|
|
|
99
142
|
|
|
100
143
|
cmds.append(fold_cmd)
|
|
101
144
|
|
|
102
|
-
|
|
145
|
+
session_name = get_next_session_name('mpython')
|
|
146
|
+
run_in_tmux(cmds, session_name, args.total_fold)
|
|
103
147
|
os.chmod('/tmp/start_multirun_tmux.sh', 0o755) # Make the script executable
|
|
104
148
|
os.system('/tmp/start_multirun_tmux.sh')
|
|
149
|
+
print(f'Started tmux session: {session_name}')
|
|
105
150
|
|
|
106
151
|
|
|
107
152
|
if __name__ == '__main__':
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: speedy-utils
|
|
3
|
-
Version: 1.1.
|
|
3
|
+
Version: 1.1.43
|
|
4
4
|
Summary: Fast and easy-to-use package for data science
|
|
5
5
|
Project-URL: Homepage, https://github.com/anhvth/speedy
|
|
6
6
|
Project-URL: Repository, https://github.com/anhvth/speedy
|
|
@@ -17,7 +17,7 @@ Classifier: Programming Language :: Python :: 3.11
|
|
|
17
17
|
Classifier: Programming Language :: Python :: 3.12
|
|
18
18
|
Classifier: Programming Language :: Python :: 3.13
|
|
19
19
|
Classifier: Programming Language :: Python :: 3.14
|
|
20
|
-
Requires-Python: >=3.
|
|
20
|
+
Requires-Python: >=3.9
|
|
21
21
|
Requires-Dist: aiohttp
|
|
22
22
|
Requires-Dist: bump2version
|
|
23
23
|
Requires-Dist: cachetools
|
|
@@ -39,13 +39,15 @@ Requires-Dist: pydantic
|
|
|
39
39
|
Requires-Dist: pytest
|
|
40
40
|
Requires-Dist: ray
|
|
41
41
|
Requires-Dist: requests
|
|
42
|
+
Requires-Dist: rich>=14.3.1
|
|
42
43
|
Requires-Dist: ruff
|
|
43
44
|
Requires-Dist: scikit-learn
|
|
44
45
|
Requires-Dist: tabulate
|
|
45
46
|
Requires-Dist: tqdm
|
|
46
47
|
Requires-Dist: xxhash
|
|
47
48
|
Provides-Extra: ray
|
|
48
|
-
Requires-Dist: ray>=2.
|
|
49
|
+
Requires-Dist: ray[data,llm]>=2.40.0; extra == 'ray'
|
|
50
|
+
Requires-Dist: vllm>=0.6.3; extra == 'ray'
|
|
49
51
|
Description-Content-Type: text/markdown
|
|
50
52
|
|
|
51
53
|
# Speedy Utils
|