speedy-utils 1.1.45__py3-none-any.whl → 1.1.47__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,308 @@
+ """
+ Ray-specific multi_process backend implementation.
+
+ Contains:
+ - ensure_ray(): Ray initialization/lifecycle management
+ - run_ray_backend(): Ray-based parallel execution
+ """
+ from __future__ import annotations
+
+ import inspect
+ import logging
+ import os
+ import sys
+ import time
+ from pathlib import Path
+ from typing import TYPE_CHECKING, Any, Callable, Literal
+
+ from tqdm import tqdm
+
+ from .common import (
+     ErrorStats,
+     _call_with_log_control,
+     _cleanup_log_gate,
+     _exit_on_ray_error,
+     _track_ray_processes,
+ )
+ from .progress import ProgressPoller, create_progress_tracker
+
+ if TYPE_CHECKING:
+     from .common import ErrorHandlerType
+
+ # ─── Ray management ─────────────────────────────────────────────
+
+ RAY_WORKER = None
+
+
+ def ensure_ray(
+     workers: int | None,
+     pbar: tqdm | None = None,
+     ray_metrics_port: int | None = None,
+ ) -> None:
+     """
+     Initialize or reinitialize Ray safely for both local and cluster envs.
+
+     1. Tries to connect to an existing cluster (address='auto') first.
+     2. If no cluster is found, starts a local Ray instance with 'workers' CPUs.
+     """
+     import ray as _ray_module
+
+     global RAY_WORKER
+     requested_workers = workers
+     if workers is None:
+         workers = os.cpu_count() or 1
+
+     if ray_metrics_port is not None:
+         os.environ['RAY_metrics_export_port'] = str(ray_metrics_port)
+
+     allow_restart = os.environ.get('RESTART_RAY', '0').lower() in ('1', 'true')
+     is_cluster_env = (
+         'RAY_ADDRESS' in os.environ
+         or os.environ.get('RAY_CLUSTER') == '1'
+     )
+
+     # 1. Handle existing session
+     if _ray_module.is_initialized():
+         if not allow_restart:
+             if pbar:
+                 pbar.set_postfix_str('Using existing Ray session')
+             return
+
+         # Avoid shutting down shared cluster sessions.
+         if is_cluster_env:
+             if pbar:
+                 pbar.set_postfix_str(
+                     'Cluster active: skipping restart to protect connection'
+                 )
+             return
+
+         # Local restart: only if worker count changed
+         if workers != RAY_WORKER:
+             if pbar:
+                 pbar.set_postfix_str(
+                     f'Restarting local Ray with {workers} workers'
+                 )
+             _ray_module.shutdown()
+         else:
+             return
+
+     # 2. Initialization logic
+     t0 = time.time()
+
+     # Try to connect to existing cluster FIRST (address="auto")
+     try:
+         if pbar:
+             pbar.set_postfix_str('Searching for Ray cluster...')
+
+         # MUST NOT pass num_cpus/num_gpus here to avoid ValueError
+         _ray_module.init(
+             address='auto',
+             ignore_reinit_error=True,
+             logging_level=logging.ERROR,
+             log_to_driver=False,
+         )
+
+         if pbar:
+             resources = _ray_module.cluster_resources()
+             cpus = resources.get('CPU', 0)
+             pbar.set_postfix_str(f'Connected to Ray Cluster ({int(cpus)} CPUs)')
+
+     except Exception:
+         # 3. Fallback: Start a local Ray session
+         if pbar:
+             pbar.set_postfix_str(
+                 f'No cluster found. Starting local Ray ({workers} CPUs)...'
+             )
+
+         _ray_module.init(
+             num_cpus=workers,
+             ignore_reinit_error=True,
+             logging_level=logging.ERROR,
+             log_to_driver=False,
+         )
+
+         if pbar:
+             took = time.time() - t0
+             pbar.set_postfix_str(f'ray.init local {workers} took {took:.2f}s')
+
+     _track_ray_processes()
+
+     if requested_workers is None:
+         try:
+             resources = _ray_module.cluster_resources()
+             total_cpus = int(resources.get('CPU', 0))
+             if total_cpus > 0:
+                 workers = total_cpus
+         except Exception:
+             pass
+
+     RAY_WORKER = workers
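+
+
+ # Usage sketch for ensure_ray() (illustrative; the worker count and metrics
+ # port below are hypothetical values, not defaults):
+ #
+ #     ensure_ray(workers=8, pbar=None, ray_metrics_port=9999)
+ #
+ # Set RESTART_RAY=1 to let a local session restart when the worker count
+ # changes; sessions on a cluster (RAY_ADDRESS/RAY_CLUSTER set) are never
+ # shut down by this helper.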
+
+
+ def run_ray_backend(
+     *,
+     f_wrapped: Callable,
+     items: list[Any],
+     total: int,
+     workers: int | None,
+     progress: bool,
+     desc: str,
+     func_kwargs: dict[str, Any],
+     shared_kwargs: list[str] | None,
+     log_worker: Literal['zero', 'first', 'all'],
+     log_gate_path: Path | None,
+     total_items: int | None,
+     poll_interval: float,
+     ray_metrics_port: int | None,
+     error_handler: ErrorHandlerType,
+     error_stats: ErrorStats,
+     func_name: str,
+ ) -> list[Any]:
+     """
+     Run the Ray backend for multi_process.
+
+     Returns a list of results in the same order as items.
+     """
+     import ray as _ray_module
+
+     # Capture caller frame for better error reporting
+     # Go back to multi_process -> user code
+     caller_frame = inspect.currentframe()
+     caller_info = None
+     if (
+         caller_frame
+         and caller_frame.f_back
+         and caller_frame.f_back.f_back
+     ):
+         outer = caller_frame.f_back.f_back
+         caller_info = {
+             'filename': outer.f_code.co_filename,
+             'lineno': outer.f_lineno,
+             'function': outer.f_code.co_name,
+         }
+
+     results = []
+     gate_path_str = str(log_gate_path) if log_gate_path else None
+
+     # Determine if we're doing item-level or task-level tracking
+     use_item_tracking = total_items is not None
+     pbar_total = total_items if use_item_tracking else total
+     pbar_desc = desc
+
+     with tqdm(
+         total=pbar_total,
+         desc=pbar_desc,
+         disable=not progress,
+         file=sys.stdout,
+         unit='items' if use_item_tracking else 'tasks',
+     ) as pbar:
+         ensure_ray(workers, pbar, ray_metrics_port)
+
+         shared_refs: dict[str, Any] = {}
+         regular_kwargs: dict[str, Any] = {}
+
+         # Create progress actor for item-level tracking if total_items specified
+         progress_actor = None
+         progress_poller = None
+         if use_item_tracking:
+             progress_actor = create_progress_tracker(total_items, pbar_desc or 'Items')
+             shared_refs['progress_actor'] = progress_actor
+
+         if shared_kwargs:
+             for kw in shared_kwargs:
+                 # Put large objects in Ray's object store (zero-copy)
+                 shared_refs[kw] = _ray_module.put(func_kwargs[kw])
+                 pbar.set_postfix_str(f'ray: shared `{kw}` via object store')
+
+             # Remaining kwargs are regular
+             regular_kwargs = {
+                 k: v for k, v in func_kwargs.items()
+                 if k not in shared_kwargs
+             }
+         else:
+             regular_kwargs = func_kwargs
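+
+         # Example (sketch): given shared_kwargs=['embeddings'] and a large
+         # numpy array in func_kwargs['embeddings'], the array goes into the
+         # object store once via ray.put() and is fetched zero-copy inside
+         # each task, rather than being re-pickled for every task.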
+
+         @_ray_module.remote
+         def _task(x, shared_refs_dict, regular_kwargs_dict):
+             # Dereference shared objects (zero-copy for numpy arrays)
+             import ray as _ray_in_task
+             from .progress import set_progress_context
+
+             gate = Path(gate_path_str) if gate_path_str else None
+             dereferenced = {}
+             progress_actor_ref = None
+
+             for k, v in shared_refs_dict.items():
+                 if k == 'progress_actor':
+                     progress_actor_ref = v
+                     # Don't add progress_actor to kwargs - it's for context only
+                 else:
+                     dereferenced[k] = _ray_in_task.get(v)
+
+             # Set progress context for this worker thread
+             # This allows user code to call report_progress() directly
+             if progress_actor_ref is not None:
+                 set_progress_context(progress_actor_ref)
+
+             all_kwargs = {**dereferenced, **regular_kwargs_dict}
+             return _call_with_log_control(
+                 f_wrapped,
+                 x,
+                 all_kwargs,
+                 log_worker,
+                 gate,
+             )
+
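+         # Worker-side sketch (hypothetical user function; assumes `.progress`
+         # exposes report_progress() as the comment above implies, taking an
+         # item-count increment):
+         #
+         #     from .progress import report_progress
+         #
+         #     def process_chunk(chunk):
+         #         for row in chunk:
+         #             handle(row)         # hypothetical per-item work
+         #             report_progress(1)  # counted by the progress actor
+         #         return len(chunk)
+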
+         refs = [
+             _task.remote(x, shared_refs, regular_kwargs) for x in items
+         ]
+
+         t_start = time.time()
+
+         # Start progress poller if item-level tracking enabled
+         if use_item_tracking and progress_actor is not None:
+             progress_poller = ProgressPoller(progress_actor, pbar, poll_interval)
+             progress_poller.start()
+
+         for idx, r in enumerate(refs):
+             try:
+                 result = _ray_module.get(r)
+                 error_stats.record_success()
+                 results.append(result)
+             except _ray_module.exceptions.RayTaskError as e:
+                 if error_handler == 'raise':
+                     if progress_poller is not None:
+                         progress_poller.stop()
+                     _exit_on_ray_error(e, pbar, caller_info)
+                 # Extract original error from RayTaskError
+                 cause = e.cause if hasattr(e, 'cause') else e.__cause__
+                 original_error = cause if cause else e
+                 # Pass full RayTaskError for fallback frame extraction
+                 error_stats.record_error(
+                     idx, original_error, items[idx], func_name,
+                     ray_task_error=e
+                 )
+                 results.append(None)
+
+             # Only update progress bar for task-level tracking
+             # Item-level tracking is handled by progress_poller
+             if not use_item_tracking:
+                 pbar.update(1)
+                 # Update pbar with success/error counts
+                 postfix = error_stats.get_postfix_dict()
+                 pbar.set_postfix(postfix)
+
+         if progress_poller is not None:
+             progress_poller.stop()
+
+     t_end = time.time()
+     item_desc = (
+         f'{total_items:,} items' if total_items else f'{total} tasks'
+     )
+     print(f'Ray processing took {t_end - t_start:.2f}s for {item_desc}')
+
+     _cleanup_log_gate(log_gate_path)
+     return results
+
+
+ __all__ = ['ensure_ray', 'run_ray_backend', 'RAY_WORKER']
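+
+ # End-to-end usage sketch (illustrative only: `square`, the literal argument
+ # values, and a no-argument ErrorStats() constructor are assumptions, not
+ # documented defaults):
+ #
+ #     from .common import ErrorStats
+ #
+ #     def square(x):
+ #         return x * x
+ #
+ #     results = run_ray_backend(
+ #         f_wrapped=square, items=[1, 2, 3], total=3, workers=4,
+ #         progress=True, desc='squares', func_kwargs={}, shared_kwargs=None,
+ #         log_worker='zero', log_gate_path=None, total_items=None,
+ #         poll_interval=0.5, ray_metrics_port=None, error_handler='raise',
+ #         error_stats=ErrorStats(), func_name='square',
+ #     )
+ #     # -> [1, 4, 9], in the same order as `items`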