fal-1.45.2-py3-none-any.whl → fal-1.46.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of fal might be problematic.
- fal/_fal_version.py +2 -2
- fal/api/__init__.py +1 -0
- fal/api/apps.py +69 -0
- fal/api/client.py +116 -0
- fal/api/deploy.py +211 -0
- fal/api/runners.py +16 -0
- fal/cli/apps.py +51 -60
- fal/cli/deploy.py +29 -181
- fal/cli/queue.py +2 -2
- fal/cli/runners.py +45 -47
- fal/distributed/__init__.py +3 -0
- fal/distributed/utils.py +420 -0
- fal/distributed/worker.py +776 -0
- {fal-1.45.2.dist-info → fal-1.46.0.dist-info}/METADATA +1 -1
- {fal-1.45.2.dist-info → fal-1.46.0.dist-info}/RECORD +19 -11
- /fal/{api.py → api/api.py} +0 -0
- {fal-1.45.2.dist-info → fal-1.46.0.dist-info}/WHEEL +0 -0
- {fal-1.45.2.dist-info → fal-1.46.0.dist-info}/entry_points.txt +0 -0
- {fal-1.45.2.dist-info → fal-1.46.0.dist-info}/top_level.txt +0 -0
fal/distributed/worker.py (new file)

@@ -0,0 +1,776 @@
import asyncio
import inspect
import os
import pickle
import queue
import threading
import time
import traceback
import warnings
from collections.abc import AsyncIterator, Callable, Coroutine
from concurrent.futures import Future
from pathlib import Path
from typing import TYPE_CHECKING, Any, Optional, Union

from fal.distributed.utils import (
    KeepAliveTimer,
    distributed_deserialize,
    distributed_serialize,
    encode_text_event,
    launch_distributed_processes,
)

if TYPE_CHECKING:
    import torch
    import torch.multiprocessing as mp
    from zmq.sugar.socket import Socket


class DistributedWorker:
    """
    A base class for distributed workers.
    """

    queue: queue.Queue[bytes]
    loop: asyncio.AbstractEventLoop
    thread: threading.Thread

    def __init__(
        self,
        rank: int = 0,
        world_size: int = 1,
    ) -> None:
        self.rank = rank
        self.world_size = world_size
        self.queue = queue.Queue()

        try:
            import uvloop

            self.loop = uvloop.new_event_loop()
        except ImportError:
            self.loop = asyncio.new_event_loop()

        self._start_thread()

    def _start_thread(self) -> None:
        """
        Start the thread.
        """
        self.thread = threading.Thread(target=self._run_forever, daemon=True)
        self.thread.start()

    def _run_forever(self) -> None:
        """
        Run the event loop forever.
        """
        asyncio.set_event_loop(self.loop)
        self.loop.run_forever()

    # Public API

    @property
    def device(self) -> "torch.device":
        """
        :return: The device for the current worker.
        """
        import torch

        return torch.device(f"cuda:{self.rank}")

    @property
    def running(self) -> bool:
        """
        :return: Whether the event loop is running.
        """
        return self.thread.is_alive()

    def initialize(self, **kwargs: Any) -> None:
        """
        Initialize the worker.
        """
        import torch

        torch.cuda.set_device(self.device)
        self.rank_print(f"Initializing worker on device {self.device}")

        setup_start = time.time()
        future = self.run_in_worker(self.setup, **kwargs)
        future.result()
        setup_duration = time.time() - setup_start
        self.rank_print(f"Setup took {setup_duration:.2f} seconds")

    def add_streaming_result(
        self,
        result: Any,
        image_format: str = "jpeg",
        as_text_event: bool = False,
    ) -> None:
        """
        Add a streaming result to the queue.
        :param result: The result to add to the queue.
        """
        if as_text_event:
            serialized = encode_text_event(
                result, is_final=False, image_format=image_format
            )
        else:
            serialized = distributed_serialize(
                result, is_final=False, image_format=image_format
            )

        self.queue.put_nowait(serialized)

    def add_streaming_error(self, error: Exception) -> None:
        """
        Add an error to the queue.
        :param error: The error to add to the queue.
        """
        self.queue.put_nowait(
            distributed_serialize({"error": str(error)}, is_final=False)
        )

    def rank_print(self, message: str, debug: bool = False) -> None:
        """
        Print a message with the rank of the current worker.
        :param message: The message to print.
        :param debug: Whether to print the message as a debug message.
        """
        prefix = "[debug] " if debug else ""
        print(f"{prefix}[rank {self.rank}] {message}")

    def submit(self, coro: Coroutine[Any, Any, Any]) -> Future[Any]:
        """
        Submit a coroutine to the event loop.
        :param coro: The coroutine to submit to the event loop.
        :return: A future that will resolve to the result of the coroutine.
        """
        if not self.running:
            raise RuntimeError("Event loop is not running.")
        return asyncio.run_coroutine_threadsafe(coro, self.loop)

    def shutdown(self, timeout: Optional[Union[int, float]] = None) -> None:
        """
        Shutdown the event loop.
        :param timeout: The timeout for the shutdown.
        """
        try:
            self.run_in_worker(self.teardown).result()
        except Exception as e:
            self.rank_print(f"Error during teardown: {e}\n{traceback.format_exc()}")

        self.loop.call_soon_threadsafe(self.loop.stop)
        self.thread.join(timeout=timeout)

    def run_in_worker(
        self,
        func: Callable[..., Any],
        *args: Any,
        **kwargs: Any,
    ) -> Future[Any]:
        """
        Run a function in the worker.
        """
        if inspect.iscoroutinefunction(func):
            coro = func(*args, **kwargs)
        else:
            coro = asyncio.to_thread(func, *args, **kwargs)

        return self.submit(coro)

    # Overrideables

    def setup(self, **kwargs: Any) -> None:
        """
        Override this method to set up the worker.
        This method is called once per worker.
        """
        return

    def teardown(self) -> None:
        """
        Override this method to tear down the worker.
        This method is called once per worker.
        """
        return

    def __call__(self, streaming: bool = False, **kwargs: Any) -> Any:
        """
        Override this method to run the worker.
        """
        return {}


class DistributedRunner:
    """
    A class to launch and manage distributed workers.
    """

    zmq_socket: Optional["Socket[Any]"]
    context: Optional["mp.ProcessContext"]
    keepalive_timer: Optional[KeepAliveTimer]

    def __init__(
        self,
        worker_cls: type[DistributedWorker] = DistributedWorker,
        world_size: int = 1,
        master_addr: str = "127.0.0.1",
        master_port: int = 29500,
        worker_addr: str = "127.0.0.1",
        worker_port: int = 54923,
        timeout: int = 86400,  # 24 hours
        keepalive_payload: dict[str, Any] = {},
        keepalive_interval: Optional[Union[int, float]] = None,
        cwd: Optional[Union[str, Path]] = None,
        set_device: Optional[bool] = None,  # deprecated
    ) -> None:
        self.worker_cls = worker_cls
        self.world_size = world_size
        self.master_addr = master_addr
        self.master_port = master_port
        self.worker_addr = worker_addr
        self.worker_port = worker_port
        self.timeout = timeout
        self.cwd = cwd
        self.zmq_socket = None
        self.context = None
        self.keepalive_payload = keepalive_payload
        self.keepalive_interval = keepalive_interval
        self.keepalive_timer = None

        if set_device is not None:
            warnings.warn("set_device is deprecated and will be removed in the future.")

    def is_alive(self) -> bool:
        """
        Check if the distributed worker processes are alive.
        :return: True if the distributed processes are alive, False otherwise.
        """
        if self.context is None:
            return False
        for process in self.context.processes:
            if not process.is_alive():
                return False
        return True

    def terminate(self, timeout: Union[int, float] = 10) -> None:
        """
        Terminates the distributed worker processes.
        This method should be called to clean up the worker processes.
        """
        if self.context is not None:
            for process in self.context.processes:
                if process.is_alive():
                    process.terminate()
                process.join(timeout=timeout)

    def gather_errors(self) -> list[Exception]:
        """
        Gathers errors from the distributed worker processes.

        This method should be called to collect any errors that occurred
        during execution.

        :return: A list of exceptions raised by the worker processes.
        """
        errors = []

        if self.context is not None:
            for error_file in self.context.error_files:
                if os.path.exists(error_file):
                    with open(error_file, "rb") as f:
                        error = pickle.loads(f.read())
                        errors.append(error)

                    os.remove(error_file)

        return errors

    def ensure_alive(self) -> None:
        """
        Ensures that the distributed worker processes are alive.
        If the processes are not alive, it raises an error.
        """
        if not self.is_alive():
            raise RuntimeError(
                f"Distributed processes are not running. Errors: {self.gather_errors()}"
            )

    def get_zmq_socket(self) -> "Socket[Any]":
        """
        Returns a ZeroMQ socket of the specified type.
        :param socket_type: The type of the ZeroMQ socket.
        :return: A ZeroMQ socket.
        """
        if self.zmq_socket is not None:
            return self.zmq_socket

        import zmq
        import zmq.asyncio

        context = zmq.asyncio.Context()
        socket = context.socket(zmq.ROUTER)
        socket.bind(f"tcp://{self.worker_addr}:{self.worker_port}")
        self.zmq_socket = socket
        return socket

    def close_zmq_socket(self) -> None:
        """
        Closes the ZeroMQ socket.
        """
        if self.zmq_socket is not None:
            try:
                self.zmq_socket.close()
            except Exception as e:
                print(
                    f"[debug] Error closing ZeroMQ socket: {e}\n"
                    f"{traceback.format_exc()}"
                )
            self.zmq_socket = None

    def run(self, **kwargs: Any) -> None:
        """
        The main function to run the distributed worker.

        This function is called by each worker process spawned by
        `torch.multiprocessing.spawn`. This method must be synchronous.

        :param kwargs: The arguments to pass to the worker.
        """
        import torch.distributed as dist
        import zmq

        # Set up communication
        rank = int(os.environ["RANK"])
        context = zmq.Context()
        socket = context.socket(zmq.DEALER)
        socket.setsockopt(zmq.IDENTITY, str(rank).encode("utf-8"))
        socket.connect(f"tcp://{self.worker_addr}:{self.worker_port}")

        # Create and setup the worker
        worker = self.worker_cls(rank, self.world_size)
        try:
            worker.initialize(**kwargs)
        except Exception as e:
            worker.rank_print(
                f"Error during initialization: {e}\n{traceback.format_exc()}"
            )
            socket.send(b"EXIT")
            socket.close()
            return

        # Wait until all workers are ready
        socket.send(b"READY")
        dist.barrier()

        # Define execution methods to invoke from workers
        def execute(payload: bytes) -> Any:
            """
            Execute the worker function with the given payload synchronously.
            :param payload: The payload to send to the worker.
            :return: The result from the worker.
            """
            payload_dict = distributed_deserialize(payload)
            assert isinstance(payload_dict, dict)
            payload_dict["streaming"] = False

            try:
                future = worker.run_in_worker(worker.__call__, **payload_dict)
                result = future.result()
            except Exception as e:
                error_output = {"error": str(e)}
                worker.rank_print(
                    f"Error in execution: {error_output}\n{traceback.format_exc()}"
                )
                result = error_output

            dist.barrier()
            if worker.rank != 0:
                return

            socket.send(distributed_serialize(result, is_final=True))

        def stream(payload: bytes, as_text_events: bool) -> None:
            """
            Stream the result from the worker function with the given payload.
            :param payload: The payload to send to the worker.
            :return: An async iterator that yields the result from the worker.
            """
            payload_dict = distributed_deserialize(payload)
            assert isinstance(payload_dict, dict)
            payload_dict["streaming"] = True
            image_format = payload_dict.get("image_format", "jpeg")
            encoded_response: Optional[bytes] = None

            try:
                future = worker.run_in_worker(worker.__call__, **payload_dict)
                while not future.done():
                    try:
                        intermediate = worker.queue.get(timeout=0.1)
                        if intermediate is not None and worker.rank == 0:
                            socket.send(intermediate)  # already serialized
                    except queue.Empty:
                        pass
                result = future.result()
            except Exception as e:
                error_output = {"error": str(e)}
                worker.rank_print(
                    f"Error in streaming: {error_output}\n{traceback.format_exc()}"
                )
                if worker.rank == 0:
                    if as_text_events:
                        encoded_response = encode_text_event(error_output)
                    else:
                        encoded_response = distributed_serialize(error_output)
            else:
                if worker.rank == 0:
                    if as_text_events:
                        encoded_response = encode_text_event(
                            result, is_final=True, image_format=image_format
                        )
                    else:
                        encoded_response = distributed_serialize(
                            result, is_final=True, image_format=image_format
                        )

            dist.barrier()
            if worker.rank != 0:
                return

            if encoded_response is not None:
                socket.send(encoded_response)
            socket.send(b"DONE")

        # Runtime code
        if rank == 0:
            worker.rank_print("Master worker is ready to receive tasks.")
            while True:
                serialized_data = socket.recv()
                streaming = serialized_data[0] == ord("1")
                as_text_events = serialized_data[1] == ord("1")
                serialized_data = serialized_data[2:]
                params = [serialized_data, streaming, as_text_events]
                dist.broadcast_object_list(params, src=0)

                if serialized_data == b"EXIT":
                    worker.rank_print("Received exit payload, exiting.")
                    break

                if streaming:
                    stream(serialized_data, as_text_events)
                else:
                    execute(serialized_data)
        else:
            worker.rank_print("Worker waiting for tasks.")
            while True:
                try:
                    params = [None, None, None]
                    dist.broadcast_object_list(params, src=0)
                    payload, streaming, as_text_events = params  # type: ignore[assignment]
                    if payload == b"EXIT":
                        worker.rank_print("Received exit payload, exiting.")
                        break

                    if streaming:
                        stream(payload, as_text_events)  # type: ignore[arg-type]
                    else:
                        execute(payload)  # type: ignore[arg-type]
                except Exception as e:
                    worker.rank_print(f"Error in worker: {e}\n{traceback.format_exc()}")

        # Teardown
        worker.rank_print("Worker is tearing down.")
        try:
            worker.shutdown()
            worker.rank_print("Worker torn down successfully.")
        except Exception as e:
            worker.rank_print(f"Error during teardown: {e}\n{traceback.format_exc()}")

        socket.send(b"EXIT")
        socket.close()

    async def start(self, timeout: int = 1800, **kwargs: Any) -> None:
        """
        Starts the distributed worker processes.
        :param timeout: The timeout for the distributed processes.
        """
        import zmq

        if self.is_alive():
            raise RuntimeError("Distributed processes are already running.")

        self.context = launch_distributed_processes(
            self.run,
            world_size=self.world_size,
            master_addr=self.master_addr,
            master_port=self.master_port,
            timeout=self.timeout,
            cwd=self.cwd,
            **kwargs,
        )

        try:
            ready_workers: set[int] = set()
            socket = self.get_zmq_socket()
            start_time = time.perf_counter()

            while len(ready_workers) < self.world_size:
                try:
                    ident, msg = await socket.recv_multipart(flags=zmq.NOBLOCK)  # type: ignore[misc]

                    if msg != b"READY":
                        worker_id = ident.decode("utf-8")
                        worker_msg = msg.decode("utf-8")
                        raise RuntimeError(
                            f"Unexpected message from worker {worker_id}: {worker_msg}"
                        )

                    print(f"[debug] Worker {ident.decode('utf-8')} is ready.")
                    ready_workers.add(ident)
                except zmq.Again:
                    total_wait_time = time.perf_counter() - start_time
                    if total_wait_time > timeout:
                        raise TimeoutError(
                            f"Timeout reached after {timeout} seconds while "
                            f"waiting for workers to be ready."
                        )
                    await asyncio.sleep(0.5)
                    self.ensure_alive()
        except Exception as e:
            print(f"[debug] Error during startup: {e}\n{traceback.format_exc()}")
            self.terminate(timeout=timeout)
            raise RuntimeError("Failed to start distributed processes.") from e

        print("[debug] All workers are ready and running.")

        # Start the keepalive timer
        self.maybe_start_keepalive()

    def keepalive(self, timeout: Optional[Union[int, float]] = 60.0) -> None:
        """
        Sends the keepalive payload to the worker.
        """
        # Cancel the keepalive timer
        self.maybe_cancel_keepalive()
        loop_thread = None
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            loop = asyncio.new_event_loop()
            loop_thread = threading.Thread(target=loop.run_forever, daemon=True)
            loop_thread.start()

        future = asyncio.run_coroutine_threadsafe(
            self.invoke(self.keepalive_payload), loop
        )

        try:
            future.result(timeout=timeout)
        except Exception as e:
            print(f"[debug] Error during keepalive: {e}\n{traceback.format_exc()}")
            raise RuntimeError("Failed to run keepalive.") from e
        finally:
            if loop_thread is not None:
                loop.call_soon_threadsafe(loop.stop)
                loop_thread.join(timeout=timeout)
            # Restart the keepalive timer
            self.maybe_start_keepalive()

    def maybe_start_keepalive(self) -> None:
        """
        Starts the keepalive timer if it is set.
        """
        if self.keepalive_timer is None and self.keepalive_interval is not None:
            self.keepalive_timer = KeepAliveTimer(
                self.keepalive, self.keepalive_interval, start=True
            )

    def maybe_reset_keepalive(self) -> None:
        """
        Resets the keepalive timer if it is set.
        """
        if self.keepalive_timer is not None:
            self.keepalive_timer.reset()

    def maybe_cancel_keepalive(self) -> None:
        """
        Cancels the keepalive timer if it is set.
        """
        if self.keepalive_timer is not None:
            self.keepalive_timer.cancel()
            self.keepalive_timer = None

    async def stop(self, timeout: int = 10) -> None:
        """
        Stops the distributed worker processes.
        :param timeout: The timeout for the distributed processes to stop.
        """
        import zmq

        if not self.is_alive():
            raise RuntimeError("Distributed processes are not running.")

        self.maybe_cancel_keepalive()
        worker_exits: set[int] = set()
        socket = self.get_zmq_socket()
        await socket.send_multipart([b"0", b"00EXIT"])

        wait_start = time.perf_counter()
        while len(worker_exits) < self.world_size:
            try:
                ident, msg = await socket.recv_multipart(flags=zmq.NOBLOCK)  # type: ignore[misc]
                if msg == b"EXIT":
                    worker_exits.add(ident)
                    if len(worker_exits) == self.world_size:
                        print("[debug] All workers have exited.")
                        await asyncio.sleep(1)  # Allow time for cleanup
                        break

            except zmq.Again:
                # No messages available, continue waiting
                await asyncio.sleep(0.1)

            if time.perf_counter() - wait_start > timeout:
                print(
                    f"[debug] Timeout reached after {timeout} seconds, "
                    f"stopping waiting."
                )
                break

            if not self.is_alive():
                print("[debug] All workers have exited prematurely.")
                break

        if self.is_alive():
            print("[debug] Some workers did not exit cleanly, terminating them.")
            self.terminate(timeout=timeout)

        self.close_zmq_socket()

    async def stream(
        self,
        payload: dict[str, Any] = {},
        timeout: Optional[int] = None,
        streaming_timeout: Optional[int] = None,
        as_text_events: bool = False,
    ) -> AsyncIterator[Any]:
        """
        Streams the result from the distributed worker.
        :param payload: The payload to send to the worker.
        :param timeout: The timeout for the overall operation.
        :param streaming_timeout: The timeout in-between streamed results.
        :param as_text_events: Whether to yield results as text events.
        :return: An async iterator that yields the result from the worker.
        """
        import zmq

        self.ensure_alive()
        self.maybe_cancel_keepalive()  # Cancel until the streaming is done
        socket = self.get_zmq_socket()
        payload_serialized = distributed_serialize(payload, is_final=True)
        await socket.send_multipart(
            [b"0", b"1" + (b"1" if as_text_events else b"0") + payload_serialized]
        )

        start_time = time.perf_counter()
        last_yield_time = start_time
        yielded_once = False

        while True:
            iter_start_time = time.perf_counter()
            try:
                rank, response = await socket.recv_multipart(flags=zmq.NOBLOCK)  # type: ignore[misc]
            except zmq.Again:
                if timeout is not None and iter_start_time - start_time > timeout:
                    raise TimeoutError(f"Streaming timed out after {timeout} seconds.")
                if (
                    streaming_timeout is not None
                    and iter_start_time - last_yield_time > streaming_timeout
                ):
                    raise TimeoutError(
                        f"Streaming timed out after {streaming_timeout} "
                        f"seconds of inactivity."
                    )

                await asyncio.sleep(0.1)
                continue

            assert rank == b"0", "Expected response from worker with rank 0"

            if response == b"DONE":
                if not yielded_once:
                    raise RuntimeError("No data was yielded from the worker.")
                break

            if as_text_events:
                yield response
            else:
                yield distributed_deserialize(response)

            yielded_once = True
            last_yield_time = iter_start_time

        self.maybe_start_keepalive()  # Restart the keepalive timer

    async def invoke(
        self,
        payload: dict[str, Any] = {},
        timeout: Optional[int] = None,
    ) -> Any:
        """
        Invokes the distributed worker with the given payload.
        :param payload: The payload to send to the worker.
        :param timeout: The timeout for the overall operation.
        :return: The result from the worker.
        """
        import zmq

        self.ensure_alive()
        self.maybe_cancel_keepalive()  # Cancel until the invocation is done
        socket = self.get_zmq_socket()
        payload_serialized = distributed_serialize(payload, is_final=True)

        await socket.send_multipart([b"0", b"00" + payload_serialized])

        # Wait for the response from the worker
        start_time = time.perf_counter()
        while True:
            try:
                rank, response = await socket.recv_multipart(flags=zmq.NOBLOCK)  # type: ignore[misc]
                break  # Exit the loop if we received a response
            except zmq.Again:
                elapsed = time.perf_counter() - start_time
                if timeout is not None and elapsed > timeout:
                    raise TimeoutError(f"Invocation timed out after {timeout} seconds.")

                await asyncio.sleep(0.1)
                self.ensure_alive()

        self.maybe_start_keepalive()  # Restart the keepalive timer
        assert rank == b"0", "Expected response from worker with rank 0"
        return distributed_deserialize(response)

    async def __aenter__(self) -> "DistributedRunner":
        """
        Enter the context manager.
        :return: The DistributedRunner instance.
        """
        await self.start()
        return self

    async def __aexit__(
        self,
        exc_type: Optional[type[BaseException]],
        exc_value: Optional[BaseException],
        exc_traceback: Optional[traceback.StackSummary],
    ) -> None:
        """
        Exit the context manager.
        :param exc_type: The type of the exception raised, if any.
        :param exc_value: The value of the exception raised, if any.
        :param traceback: The traceback of the exception raised, if any.
        """
        try:
            await self.stop()
        except Exception as e:
            print(f"[debug] Error during cleanup: {e}\n{traceback.format_exc()}")