speedy-utils 1.1.46__py3-none-any.whl → 1.1.48__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llm_utils/__init__.py +1 -3
- llm_utils/chat_format/__init__.py +0 -2
- llm_utils/chat_format/display.py +283 -364
- llm_utils/lm/llm.py +62 -22
- speedy_utils/__init__.py +4 -0
- speedy_utils/multi_worker/__init__.py +4 -0
- speedy_utils/multi_worker/_multi_process.py +425 -0
- speedy_utils/multi_worker/_multi_process_ray.py +308 -0
- speedy_utils/multi_worker/common.py +879 -0
- speedy_utils/multi_worker/dataset_sharding.py +203 -0
- speedy_utils/multi_worker/process.py +53 -1234
- speedy_utils/multi_worker/progress.py +71 -1
- speedy_utils/multi_worker/thread.py +45 -0
- speedy_utils/scripts/mpython.py +19 -12
- {speedy_utils-1.1.46.dist-info → speedy_utils-1.1.48.dist-info}/METADATA +1 -1
- {speedy_utils-1.1.46.dist-info → speedy_utils-1.1.48.dist-info}/RECORD +18 -14
- {speedy_utils-1.1.46.dist-info → speedy_utils-1.1.48.dist-info}/WHEEL +0 -0
- {speedy_utils-1.1.46.dist-info → speedy_utils-1.1.48.dist-info}/entry_points.txt +0 -0
|
@@ -0,0 +1,308 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Ray-specific multi_process backend implementation.
|
|
3
|
+
|
|
4
|
+
Contains:
|
|
5
|
+
- ensure_ray(): Ray initialization/lifecycle management
|
|
6
|
+
- run_ray_backend(): Ray-based parallel execution
|
|
7
|
+
"""
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import inspect
|
|
11
|
+
import logging
|
|
12
|
+
import os
|
|
13
|
+
import sys
|
|
14
|
+
import time
|
|
15
|
+
from pathlib import Path
|
|
16
|
+
from typing import TYPE_CHECKING, Any, Callable, Literal
|
|
17
|
+
|
|
18
|
+
from tqdm import tqdm
|
|
19
|
+
|
|
20
|
+
from .common import (
|
|
21
|
+
ErrorStats,
|
|
22
|
+
_call_with_log_control,
|
|
23
|
+
_cleanup_log_gate,
|
|
24
|
+
_exit_on_ray_error,
|
|
25
|
+
_track_ray_processes,
|
|
26
|
+
)
|
|
27
|
+
from .progress import ProgressPoller, create_progress_tracker
|
|
28
|
+
|
|
29
|
+
if TYPE_CHECKING:
|
|
30
|
+
from .common import ErrorHandlerType
|
|
31
|
+
|
|
32
|
+
# ─── Ray management ─────────────────────────────────────────────
|
|
33
|
+
|
|
34
|
+
# Worker count of the Ray session this module last initialized; None until
# ensure_ray() has run. ensure_ray() compares a new request against this to
# decide whether a local session must be restarted (when RESTART_RAY is set).
RAY_WORKER = None
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def ensure_ray(
    workers: int | None,
    pbar: tqdm | None = None,
    ray_metrics_port: int | None = None,
) -> None:
    """
    Initialize or reinitialize Ray safely for both local and cluster envs.

    1. Tries to connect to an existing cluster (address='auto') first.
    2. If no cluster is found, starts a local Ray instance with 'workers' CPUs.
    """
    import ray as _ray_module

    global RAY_WORKER

    # None means "auto-detect": remember that before defaulting to cpu_count.
    explicit_workers = workers
    if workers is None:
        workers = os.cpu_count() or 1

    if ray_metrics_port is not None:
        os.environ['RAY_metrics_export_port'] = str(ray_metrics_port)

    restart_allowed = os.environ.get('RESTART_RAY', '0').lower() in ('1', 'true')
    on_cluster = (
        'RAY_ADDRESS' in os.environ
        or os.environ.get('RAY_CLUSTER') == '1'
    )

    # 1. A session already exists: reuse it, protect it, or restart it.
    if _ray_module.is_initialized():
        if not restart_allowed:
            if pbar:
                pbar.set_postfix_str('Using existing Ray session')
            return

        if on_cluster:
            # Avoid shutting down shared cluster sessions.
            if pbar:
                pbar.set_postfix_str(
                    'Cluster active: skipping restart to protect connection'
                )
            return

        if workers == RAY_WORKER:
            # Local session already has the requested size: nothing to do.
            return

        # Local restart: worker count changed.
        if pbar:
            pbar.set_postfix_str(
                f'Restarting local Ray with {workers} workers'
            )
        _ray_module.shutdown()

    # 2. Initialization logic
    start = time.time()

    # Try to connect to an existing cluster FIRST (address="auto").
    try:
        if pbar:
            pbar.set_postfix_str('Searching for Ray cluster...')

        # MUST NOT pass num_cpus/num_gpus here to avoid ValueError
        _ray_module.init(
            address='auto',
            ignore_reinit_error=True,
            logging_level=logging.ERROR,
            log_to_driver=False
        )

        if pbar:
            cluster_cpus = _ray_module.cluster_resources().get('CPU', 0)
            pbar.set_postfix_str(f'Connected to Ray Cluster ({int(cluster_cpus)} CPUs)')

    except Exception:
        # 3. Fallback: start a local Ray session instead.
        if pbar:
            pbar.set_postfix_str(
                f'No cluster found. Starting local Ray ({workers} CPUs)...'
            )

        _ray_module.init(
            num_cpus=workers,
            ignore_reinit_error=True,
            logging_level=logging.ERROR,
            log_to_driver=False,
        )

        if pbar:
            elapsed = time.time() - start
            pbar.set_postfix_str(f'ray.init local {workers} took {elapsed:.2f}s')

    _track_ray_processes()

    # No explicit request: record the actual cluster-wide CPU count so a
    # later call with the same auto-detected size does not force a restart.
    if explicit_workers is None:
        try:
            detected = int(_ray_module.cluster_resources().get('CPU', 0))
        except Exception:
            detected = 0
        if detected > 0:
            workers = detected

    RAY_WORKER = workers
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def run_ray_backend(
    *,
    f_wrapped: Callable,
    items: list[Any],
    total: int,
    workers: int | None,
    progress: bool,
    desc: str,
    func_kwargs: dict[str, Any],
    shared_kwargs: list[str] | None,
    log_worker: Literal['zero', 'first', 'all'],
    log_gate_path: Path | None,
    total_items: int | None,
    poll_interval: float,
    ray_metrics_port: int | None,
    error_handler: ErrorHandlerType,
    error_stats: ErrorStats,
    func_name: str,
) -> list[Any]:
    """
    Run the Ray backend for multi_process.

    Submits one Ray task per item, optionally sharing large kwargs via the
    object store, and gathers results in submission order.

    Args:
        f_wrapped: Worker callable invoked once per item.
        items: Inputs; one Ray task is submitted per element.
        total: Number of tasks; progress-bar total when item-level
            tracking is disabled.
        workers: CPU count forwarded to ensure_ray(); None = auto-detect.
        progress: Whether the tqdm bar is shown.
        desc: Progress-bar description.
        func_kwargs: Extra keyword arguments forwarded to ``f_wrapped``.
        shared_kwargs: Names of kwargs placed in Ray's object store
            (zero-copy sharing of large objects between tasks).
        log_worker: Which workers may emit logs ('zero', 'first', 'all').
        log_gate_path: Path of the file gate coordinating worker logging.
        total_items: If not None, enables item-level progress via a
            progress actor polled in a background thread.
        poll_interval: Seconds between progress-actor polls.
        ray_metrics_port: Optional RAY_metrics_export_port override.
        error_handler: 'raise' hands the first failing task to
            _exit_on_ray_error; any other value records the error in
            ``error_stats`` and appends None for that item.
        error_stats: Mutable success/error counters, updated in place.
        func_name: User function name, used in error reporting.

    Returns:
        A list of results in the same order as items (None for failed
        items when errors are not raised).
    """
    import ray as _ray_module

    # Capture caller frame for better error reporting:
    # run_ray_backend -> multi_process -> user code (two frames up).
    caller_frame = inspect.currentframe()
    caller_info = None
    if (
        caller_frame
        and caller_frame.f_back
        and caller_frame.f_back.f_back
    ):
        outer = caller_frame.f_back.f_back
        caller_info = {
            'filename': outer.f_code.co_filename,
            'lineno': outer.f_lineno,
            'function': outer.f_code.co_name,
        }

    results: list[Any] = []
    gate_path_str = str(log_gate_path) if log_gate_path else None

    # Item-level tracking (progress actor) vs task-level tracking (pbar).
    use_item_tracking = total_items is not None
    pbar_total = total_items if use_item_tracking else total
    # FIX: original read `desc if use_item_tracking else desc` — a
    # redundant conditional with identical branches.
    pbar_desc = desc

    with tqdm(
        total=pbar_total,
        desc=pbar_desc,
        disable=not progress,
        file=sys.stdout,
        unit='items' if use_item_tracking else 'tasks',
    ) as pbar:
        ensure_ray(workers, pbar, ray_metrics_port)

        shared_refs: dict[str, Any] = {}
        regular_kwargs: dict[str, Any] = {}

        # Create progress actor for item-level tracking if total_items specified.
        progress_actor = None
        progress_poller = None
        if use_item_tracking:
            progress_actor = create_progress_tracker(total_items, pbar_desc or 'Items')
            shared_refs['progress_actor'] = progress_actor

        if shared_kwargs:
            for kw in shared_kwargs:
                # Put large objects in Ray's object store (zero-copy).
                shared_refs[kw] = _ray_module.put(func_kwargs[kw])
                pbar.set_postfix_str(f'ray: shared `{kw}` via object store')

            # Remaining kwargs are passed by value.
            regular_kwargs = {
                k: v for k, v in func_kwargs.items()
                if k not in shared_kwargs
            }
        else:
            regular_kwargs = func_kwargs

        @_ray_module.remote
        def _task(x, shared_refs_dict, regular_kwargs_dict):
            # Dereference shared objects (zero-copy for numpy arrays).
            import ray as _ray_in_task
            from .progress import set_progress_context

            gate = Path(gate_path_str) if gate_path_str else None
            dereferenced = {}
            progress_actor_ref = None

            for k, v in shared_refs_dict.items():
                if k == 'progress_actor':
                    # Not a kwarg — used only to set the progress context.
                    progress_actor_ref = v
                else:
                    dereferenced[k] = _ray_in_task.get(v)

            # Set progress context for this worker so user code can call
            # report_progress() directly.
            if progress_actor_ref is not None:
                set_progress_context(progress_actor_ref)

            all_kwargs = {**dereferenced, **regular_kwargs_dict}
            return _call_with_log_control(
                f_wrapped,
                x,
                all_kwargs,
                log_worker,
                gate,
            )

        refs = [
            _task.remote(x, shared_refs, regular_kwargs) for x in items
        ]

        t_start = time.time()

        # Start progress poller if item-level tracking enabled.
        if use_item_tracking and progress_actor is not None:
            progress_poller = ProgressPoller(progress_actor, pbar, poll_interval)
            progress_poller.start()

        # FIX: try/finally guarantees the poller thread is stopped even when
        # an unexpected exception (or SystemExit) escapes the gather loop;
        # the original only stopped it on the success and 'raise' paths.
        try:
            for idx, r in enumerate(refs):
                try:
                    result = _ray_module.get(r)
                    error_stats.record_success()
                    results.append(result)
                except _ray_module.exceptions.RayTaskError as e:
                    if error_handler == 'raise':
                        # Stop (and clear, to avoid a double stop in the
                        # finally block) before delegating to the exit path.
                        if progress_poller is not None:
                            progress_poller.stop()
                            progress_poller = None
                        _exit_on_ray_error(e, pbar, caller_info)
                    # Extract the original error from the RayTaskError.
                    cause = e.cause if hasattr(e, 'cause') else e.__cause__
                    original_error = cause if cause else e
                    # Pass the full RayTaskError for fallback frame extraction.
                    error_stats.record_error(
                        idx, original_error, items[idx], func_name,
                        ray_task_error=e
                    )
                    results.append(None)

                # Task-level tracking updates the bar per task; item-level
                # tracking is handled by progress_poller instead.
                if not use_item_tracking:
                    pbar.update(1)
                    # Show running success/error counts.
                    pbar.set_postfix(error_stats.get_postfix_dict())
        finally:
            if progress_poller is not None:
                progress_poller.stop()

        t_end = time.time()
        item_desc = (
            f'{total_items:,} items' if total_items else f'{total} tasks'
        )
        print(f'Ray processing took {t_end - t_start:.2f}s for {item_desc}')

    _cleanup_log_gate(log_gate_path)
    return results
|
|
306
|
+
|
|
307
|
+
|
|
308
|
+
# Public API of this Ray backend module.
__all__ = ['ensure_ray', 'run_ray_backend', 'RAY_WORKER']
|