fal 1.45.2__py3-none-any.whl → 1.46.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of fal might be problematic; see the package's registry page for more details.

@@ -0,0 +1,791 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import inspect
5
+ import os
6
+ import pickle
7
+ import queue
8
+ import threading
9
+ import time
10
+ import traceback
11
+ import warnings
12
+ from collections.abc import AsyncIterator, Callable, Coroutine
13
+ from concurrent.futures import Future
14
+ from functools import partial
15
+ from pathlib import Path
16
+ from typing import TYPE_CHECKING, Any, Optional, Union
17
+
18
+ from fal.distributed.utils import (
19
+ KeepAliveTimer,
20
+ distributed_deserialize,
21
+ distributed_serialize,
22
+ encode_text_event,
23
+ launch_distributed_processes,
24
+ )
25
+
26
+ if TYPE_CHECKING:
27
+ import torch
28
+ import torch.multiprocessing as mp
29
+ from zmq.sugar.socket import Socket
30
+
31
+
32
class DistributedWorker:
    """
    A base class for distributed workers.

    Each instance owns a private event loop (uvloop when importable,
    otherwise a stock asyncio loop) that runs forever on a daemon thread.
    Work is scheduled onto that loop from other threads via ``submit`` and
    ``run_in_worker``. Subclasses override ``setup``, ``teardown`` and
    ``__call__``.
    """

    # Serialized intermediate (streaming) results produced by this worker.
    queue: queue.Queue[bytes]
    # Event loop driven by ``thread``.
    loop: asyncio.AbstractEventLoop
    # Daemon thread that calls ``loop.run_forever()``.
    thread: threading.Thread

    def __init__(
        self,
        rank: int = 0,
        world_size: int = 1,
    ) -> None:
        self.rank = rank
        self.world_size = world_size
        self.queue = queue.Queue()

        try:
            import uvloop

            self.loop = uvloop.new_event_loop()
        except ImportError:
            self.loop = asyncio.new_event_loop()

        self._start_thread()

    def _start_thread(self) -> None:
        """
        Start the daemon thread that drives the event loop.
        """
        self.thread = threading.Thread(target=self._run_forever, daemon=True)
        self.thread.start()

    def _run_forever(self) -> None:
        """
        Run the event loop forever (thread target).
        """
        asyncio.set_event_loop(self.loop)
        self.loop.run_forever()

    # Public API

    @property
    def device(self) -> torch.device:
        """
        :return: The device for the current worker (``cuda:<rank>``).
        """
        import torch

        return torch.device(f"cuda:{self.rank}")

    @property
    def running(self) -> bool:
        """
        :return: Whether the event loop thread is still alive.
        """
        return self.thread.is_alive()

    def initialize(self, **kwargs: Any) -> None:
        """
        Initialize the worker: select its CUDA device, then run ``setup``
        on the worker loop and block until it finishes.
        :param kwargs: Forwarded to ``setup``.
        """
        import torch

        torch.cuda.set_device(self.device)
        self.rank_print(f"Initializing worker on device {self.device}")

        setup_start = time.time()
        future = self.run_in_worker(self.setup, **kwargs)
        future.result()
        setup_duration = time.time() - setup_start
        self.rank_print(f"Setup took {setup_duration:.2f} seconds")

    def add_streaming_result(
        self,
        result: Any,
        image_format: str = "jpeg",
        as_text_event: bool = False,
    ) -> None:
        """
        Serialize an intermediate result and add it to the queue.
        :param result: The result to add to the queue.
        :param image_format: Image format passed to the serializer.
        :param as_text_event: Encode as a text event instead of the
            distributed binary format.
        """
        if as_text_event:
            serialized = encode_text_event(
                result, is_final=False, image_format=image_format
            )
        else:
            serialized = distributed_serialize(
                result, is_final=False, image_format=image_format
            )

        self.queue.put_nowait(serialized)

    def add_streaming_error(self, error: Exception) -> None:
        """
        Add an error to the queue as a non-final ``{"error": ...}`` payload.
        :param error: The error to add to the queue.
        """
        self.queue.put_nowait(
            distributed_serialize({"error": str(error)}, is_final=False)
        )

    def rank_print(self, message: str, debug: bool = False) -> None:
        """
        Print a message with the rank of the current worker.
        :param message: The message to print.
        :param debug: Whether to print the message as a debug message.
        """
        prefix = "[debug] " if debug else ""
        print(f"{prefix}[rank {self.rank}] {message}")

    def submit(self, coro: Coroutine[Any, Any, Any]) -> Future[Any]:
        """
        Submit a coroutine to the event loop.
        :param coro: The coroutine to submit to the event loop.
        :return: A future that will resolve to the result of the coroutine.
        :raises RuntimeError: If the event loop thread is not running.
        """
        if not self.running:
            # The coroutine will never be scheduled; close it so Python does
            # not emit a "coroutine ... was never awaited" RuntimeWarning.
            if inspect.iscoroutine(coro):
                coro.close()
            raise RuntimeError("Event loop is not running.")
        return asyncio.run_coroutine_threadsafe(coro, self.loop)

    def shutdown(self, timeout: Optional[Union[int, float]] = None) -> None:
        """
        Run ``teardown``, then stop and close the event loop.
        :param timeout: The timeout for joining the loop thread.
        """
        try:
            self.run_in_worker(self.teardown).result()
        except Exception as e:
            self.rank_print(f"Error during teardown: {e}\n{traceback.format_exc()}")

        if self.loop.is_closed():
            # Already shut down; scheduling on a closed loop would raise.
            return

        self.loop.call_soon_threadsafe(self.loop.stop)
        self.thread.join(timeout=timeout)
        # Release the loop's resources (selector, default executor), but only
        # once the thread has actually stopped -- a running loop cannot close.
        if not self.thread.is_alive():
            self.loop.close()

    async def _run_sync_in_executor(
        self,
        func: Callable[..., Any],
        *args: Any,
        **kwargs: Any,
    ) -> Any:
        """Run a synchronous function in the loop's default executor."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, partial(func, *args, **kwargs))

    def run_in_worker(
        self,
        func: Callable[..., Any],
        *args: Any,
        **kwargs: Any,
    ) -> Future[Any]:
        """
        Run a function (sync or async) on the worker's event loop.
        :return: A future resolving to the function's result.
        """
        if inspect.iscoroutinefunction(func):
            coro = func(*args, **kwargs)
        else:
            # Using in place of asyncio.to_thread
            # since it's not available in Python 3.8
            coro = self._run_sync_in_executor(func, *args, **kwargs)

        return self.submit(coro)

    # Overrideables

    def setup(self, **kwargs: Any) -> None:
        """
        Override this method to set up the worker.
        This method is called once per worker.
        """
        return

    def teardown(self) -> None:
        """
        Override this method to tear down the worker.
        This method is called once per worker.
        """
        return

    def __call__(self, streaming: bool = False, **kwargs: Any) -> Any:
        """
        Override this method to run the worker.
        """
        return {}
217
+
218
+
219
class DistributedRunner:
    """
    A class to launch and manage distributed workers.

    The runner binds a ZeroMQ ROUTER socket on ``worker_addr:worker_port``;
    every launched rank connects a DEALER socket whose identity is its rank
    encoded as text. Requests are sent to rank 0, which broadcasts them to
    the other ranks; only rank 0 sends results back.
    """

    zmq_socket: Optional[Socket[Any]]
    context: Optional[mp.ProcessContext]
    keepalive_timer: Optional[KeepAliveTimer]

    def __init__(
        self,
        worker_cls: type[DistributedWorker] = DistributedWorker,
        world_size: int = 1,
        master_addr: str = "127.0.0.1",
        master_port: int = 29500,
        worker_addr: str = "127.0.0.1",
        worker_port: int = 54923,
        timeout: int = 86400,  # 24 hours
        keepalive_payload: Optional[dict[str, Any]] = None,
        keepalive_interval: Optional[Union[int, float]] = None,
        cwd: Optional[Union[str, Path]] = None,
        set_device: Optional[bool] = None,  # deprecated
    ) -> None:
        self.worker_cls = worker_cls
        self.world_size = world_size
        self.master_addr = master_addr
        self.master_port = master_port
        self.worker_addr = worker_addr
        self.worker_port = worker_port
        self.timeout = timeout
        self.cwd = cwd
        self.zmq_socket = None
        self.context = None
        # Avoid a shared mutable default: each runner gets its own dict.
        self.keepalive_payload = (
            {} if keepalive_payload is None else keepalive_payload
        )
        self.keepalive_interval = keepalive_interval
        self.keepalive_timer = None

        if set_device is not None:
            warnings.warn(
                "set_device is deprecated and will be removed in the future.",
                stacklevel=2,
            )

    def is_alive(self) -> bool:
        """
        Check if the distributed worker processes are alive.
        :return: True if *every* distributed process is alive, False
            otherwise (including when nothing was started).
        """
        if self.context is None:
            return False
        for process in self.context.processes:
            if not process.is_alive():
                return False
        return True

    def _any_alive(self) -> bool:
        """
        :return: True if at least one distributed process is still alive.
        """
        if self.context is None:
            return False
        return any(process.is_alive() for process in self.context.processes)

    def terminate(self, timeout: Union[int, float] = 10) -> None:
        """
        Terminates the distributed worker processes.
        This method should be called to clean up the worker processes.
        :param timeout: Per-process join timeout after terminating.
        """
        if self.context is not None:
            for process in self.context.processes:
                if process.is_alive():
                    process.terminate()
                    process.join(timeout=timeout)

    def gather_errors(self) -> list[Exception]:
        """
        Gathers errors from the distributed worker processes.

        This method should be called to collect any errors that occurred
        during execution. Each error file that exists is unpickled and then
        deleted.

        NOTE(review): if ``context`` is a torch ``ProcessContext``, the
        pickled objects are likely formatted traceback strings rather than
        exception instances -- confirm against ``launch_distributed_processes``.

        :return: A list of errors raised by the worker processes.
        """
        errors = []

        if self.context is not None:
            for error_file in self.context.error_files:
                if os.path.exists(error_file):
                    # Error files are written locally by the spawned children.
                    with open(error_file, "rb") as f:
                        error = pickle.loads(f.read())
                        errors.append(error)

                    os.remove(error_file)

        return errors

    def ensure_alive(self) -> None:
        """
        Ensures that the distributed worker processes are alive.
        :raises RuntimeError: If any process is not alive; the message
            includes the gathered worker errors.
        """
        if not self.is_alive():
            raise RuntimeError(
                f"Distributed processes are not running. Errors: {self.gather_errors()}"
            )

    def get_zmq_socket(self) -> Socket[Any]:
        """
        Returns the runner's asyncio ZeroMQ ROUTER socket, creating and
        binding it on ``worker_addr:worker_port`` on first use.
        :return: A ZeroMQ socket.
        """
        if self.zmq_socket is not None:
            return self.zmq_socket

        import zmq
        import zmq.asyncio

        context = zmq.asyncio.Context()
        socket = context.socket(zmq.ROUTER)
        socket.bind(f"tcp://{self.worker_addr}:{self.worker_port}")
        self.zmq_socket = socket
        return socket

    def close_zmq_socket(self) -> None:
        """
        Closes the ZeroMQ socket (best effort; errors are printed).
        """
        if self.zmq_socket is not None:
            try:
                self.zmq_socket.close()
            except Exception as e:
                print(
                    f"[debug] Error closing ZeroMQ socket: {e}\n"
                    f"{traceback.format_exc()}"
                )
            self.zmq_socket = None

    def run(self, **kwargs: Any) -> None:
        """
        The main function to run the distributed worker.

        This function is called by each worker process spawned by
        `torch.multiprocessing.spawn`. This method must be synchronous.

        Rank 0 reads requests from its DEALER socket, broadcasts them to all
        ranks with ``dist.broadcast_object_list``, and every rank runs the
        worker; only rank 0 sends results back to the runner.

        :param kwargs: The arguments to pass to the worker.
        """
        import torch.distributed as dist
        import zmq

        # Set up communication
        rank = int(os.environ["RANK"])
        context = zmq.Context()
        socket = context.socket(zmq.DEALER)
        # The identity lets the runner's ROUTER tell the ranks apart.
        socket.setsockopt(zmq.IDENTITY, str(rank).encode("utf-8"))
        socket.connect(f"tcp://{self.worker_addr}:{self.worker_port}")

        # Create and setup the worker
        worker = self.worker_cls(rank, self.world_size)
        try:
            worker.initialize(**kwargs)
        except Exception as e:
            worker.rank_print(
                f"Error during initialization: {e}\n{traceback.format_exc()}"
            )
            socket.send(b"EXIT")
            socket.close()
            return

        # Wait until all workers are ready
        socket.send(b"READY")
        dist.barrier()

        def forward(intermediate: Optional[bytes]) -> None:
            """Send an already-serialized intermediate result (rank 0 only)."""
            if intermediate is not None and worker.rank == 0:
                socket.send(intermediate)  # already serialized

        # Define execution methods to invoke from workers
        def execute(payload: bytes) -> Any:
            """
            Execute the worker function with the given payload synchronously.
            :param payload: The payload to send to the worker.
            """
            payload_dict = distributed_deserialize(payload)
            assert isinstance(payload_dict, dict)
            payload_dict["streaming"] = False

            try:
                future = worker.run_in_worker(worker.__call__, **payload_dict)
                result = future.result()
            except Exception as e:
                error_output = {"error": str(e)}
                worker.rank_print(
                    f"Error in execution: {error_output}\n{traceback.format_exc()}"
                )
                result = error_output

            dist.barrier()
            if worker.rank != 0:
                return

            socket.send(distributed_serialize(result, is_final=True))

        def stream(payload: bytes, as_text_events: bool) -> None:
            """
            Stream the result from the worker function with the given payload.
            Intermediate results are forwarded as they are queued; the final
            result (or error) is sent last, followed by ``DONE``.
            :param payload: The payload to send to the worker.
            :param as_text_events: Encode the final result as a text event.
            """
            payload_dict = distributed_deserialize(payload)
            assert isinstance(payload_dict, dict)
            payload_dict["streaming"] = True
            image_format = payload_dict.get("image_format", "jpeg")
            encoded_response: Optional[bytes] = None

            try:
                future = worker.run_in_worker(worker.__call__, **payload_dict)
                while not future.done():
                    try:
                        intermediate = worker.queue.get(timeout=0.1)
                    except queue.Empty:
                        continue
                    forward(intermediate)
                # Flush results queued between the last poll and completion;
                # otherwise they would leak into the next streaming request.
                while True:
                    try:
                        intermediate = worker.queue.get_nowait()
                    except queue.Empty:
                        break
                    forward(intermediate)
                result = future.result()
            except Exception as e:
                error_output = {"error": str(e)}
                worker.rank_print(
                    f"Error in streaming: {error_output}\n{traceback.format_exc()}"
                )
                if worker.rank == 0:
                    if as_text_events:
                        encoded_response = encode_text_event(error_output)
                    else:
                        encoded_response = distributed_serialize(error_output)
            else:
                if worker.rank == 0:
                    if as_text_events:
                        encoded_response = encode_text_event(
                            result, is_final=True, image_format=image_format
                        )
                    else:
                        encoded_response = distributed_serialize(
                            result, is_final=True, image_format=image_format
                        )

            dist.barrier()
            if worker.rank != 0:
                return

            if encoded_response is not None:
                socket.send(encoded_response)
            socket.send(b"DONE")

        # Runtime code
        if rank == 0:
            worker.rank_print("Master worker is ready to receive tasks.")
            while True:
                serialized_data = socket.recv()
                # Two ASCII flag bytes prefix the payload: streaming, text events.
                streaming = serialized_data[0] == ord("1")
                as_text_events = serialized_data[1] == ord("1")
                serialized_data = serialized_data[2:]
                params = [serialized_data, streaming, as_text_events]
                dist.broadcast_object_list(params, src=0)

                if serialized_data == b"EXIT":
                    worker.rank_print("Received exit payload, exiting.")
                    break

                if streaming:
                    stream(serialized_data, as_text_events)
                else:
                    execute(serialized_data)
        else:
            worker.rank_print("Worker waiting for tasks.")
            while True:
                try:
                    params = [None, None, None]
                    dist.broadcast_object_list(params, src=0)
                    payload, streaming, as_text_events = params  # type: ignore[assignment]
                    if payload == b"EXIT":
                        worker.rank_print("Received exit payload, exiting.")
                        break

                    if streaming:
                        stream(payload, as_text_events)  # type: ignore[arg-type]
                    else:
                        execute(payload)  # type: ignore[arg-type]
                except Exception as e:
                    worker.rank_print(f"Error in worker: {e}\n{traceback.format_exc()}")

        # Teardown
        worker.rank_print("Worker is tearing down.")
        try:
            worker.shutdown()
            worker.rank_print("Worker torn down successfully.")
        except Exception as e:
            worker.rank_print(f"Error during teardown: {e}\n{traceback.format_exc()}")

        socket.send(b"EXIT")
        socket.close()

    async def start(self, timeout: int = 1800, **kwargs: Any) -> None:
        """
        Starts the distributed worker processes and waits until every rank
        has reported ``READY``.
        :param timeout: The timeout for the distributed processes.
        :raises RuntimeError: If already running or startup fails.
        """
        import zmq

        if self.is_alive():
            raise RuntimeError("Distributed processes are already running.")

        # A partially-dead group from a previous start is unusable; terminate
        # its survivors so they are not orphaned by the new launch.
        if self._any_alive():
            self.terminate()

        self.context = launch_distributed_processes(
            self.run,
            world_size=self.world_size,
            master_addr=self.master_addr,
            master_port=self.master_port,
            timeout=self.timeout,
            cwd=self.cwd,
            **kwargs,
        )

        try:
            ready_workers: set[bytes] = set()
            socket = self.get_zmq_socket()
            start_time = time.perf_counter()

            while len(ready_workers) < self.world_size:
                try:
                    ident, msg = await socket.recv_multipart(flags=zmq.NOBLOCK)  # type: ignore[misc]

                    if msg != b"READY":
                        worker_id = ident.decode("utf-8")
                        worker_msg = msg.decode("utf-8")
                        raise RuntimeError(
                            f"Unexpected message from worker {worker_id}: {worker_msg}"
                        )

                    print(f"[debug] Worker {ident.decode('utf-8')} is ready.")
                    ready_workers.add(ident)
                except zmq.Again:
                    total_wait_time = time.perf_counter() - start_time
                    if total_wait_time > timeout:
                        raise TimeoutError(
                            f"Timeout reached after {timeout} seconds while "
                            f"waiting for workers to be ready."
                        )
                    await asyncio.sleep(0.5)
                    self.ensure_alive()
        except Exception as e:
            print(f"[debug] Error during startup: {e}\n{traceback.format_exc()}")
            self.terminate(timeout=timeout)
            # Release the port and drop any stale messages from this attempt.
            self.close_zmq_socket()
            raise RuntimeError("Failed to start distributed processes.") from e

        print("[debug] All workers are ready and running.")

        # Start the keepalive timer
        self.maybe_start_keepalive()

    def keepalive(self, timeout: Optional[Union[int, float]] = 60.0) -> None:
        """
        Sends the keepalive payload to the worker and blocks for the result.

        When no event loop is running in the calling thread, a temporary loop
        is run on a helper thread for the duration of the call.
        NOTE(review): if called from a thread whose loop IS running, this
        blocks that loop while waiting on a coroutine scheduled on it.

        :param timeout: Seconds to wait for the keepalive invocation.
        :raises RuntimeError: If the keepalive invocation fails.
        """
        # Cancel the keepalive timer
        self.maybe_cancel_keepalive()
        loop_thread = None
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            loop = asyncio.new_event_loop()
            loop_thread = threading.Thread(target=loop.run_forever, daemon=True)
            loop_thread.start()

        future = asyncio.run_coroutine_threadsafe(
            self.invoke(self.keepalive_payload), loop
        )

        try:
            future.result(timeout=timeout)
        except Exception as e:
            print(f"[debug] Error during keepalive: {e}\n{traceback.format_exc()}")
            raise RuntimeError("Failed to run keepalive.") from e
        finally:
            if loop_thread is not None:
                loop.call_soon_threadsafe(loop.stop)
                loop_thread.join(timeout=timeout)
                # Close the temporary loop so each keepalive tick does not
                # leak a selector; a still-running loop cannot be closed.
                if not loop_thread.is_alive():
                    loop.close()
            # Restart the keepalive timer
            self.maybe_start_keepalive()

    def maybe_start_keepalive(self) -> None:
        """
        Starts the keepalive timer if an interval is configured and no
        timer is currently active.
        """
        if self.keepalive_timer is None and self.keepalive_interval is not None:
            self.keepalive_timer = KeepAliveTimer(
                self.keepalive, self.keepalive_interval, start=True
            )

    def maybe_reset_keepalive(self) -> None:
        """
        Resets the keepalive timer if it is set.
        """
        if self.keepalive_timer is not None:
            self.keepalive_timer.reset()

    def maybe_cancel_keepalive(self) -> None:
        """
        Cancels and clears the keepalive timer if it is set.
        """
        if self.keepalive_timer is not None:
            self.keepalive_timer.cancel()
            self.keepalive_timer = None

    async def stop(self, timeout: int = 10) -> None:
        """
        Stops the distributed worker processes.

        Sends ``EXIT`` to rank 0, waits for every rank to report ``EXIT``,
        terminates any process still running afterwards, and closes the
        ZeroMQ socket.
        :param timeout: The timeout for the distributed processes to stop.
        :raises RuntimeError: If the processes are not (all) running.
        """
        import zmq

        if not self.is_alive():
            # Clean up whatever is left of a partially-dead group before
            # reporting that the group is not running.
            self.maybe_cancel_keepalive()
            self.terminate(timeout=timeout)
            self.close_zmq_socket()
            raise RuntimeError("Distributed processes are not running.")

        self.maybe_cancel_keepalive()
        worker_exits: set[bytes] = set()
        socket = self.get_zmq_socket()
        await socket.send_multipart([b"0", b"00EXIT"])

        wait_start = time.perf_counter()
        while len(worker_exits) < self.world_size:
            try:
                ident, msg = await socket.recv_multipart(flags=zmq.NOBLOCK)  # type: ignore[misc]
                if msg == b"EXIT":
                    worker_exits.add(ident)
                    if len(worker_exits) == self.world_size:
                        print("[debug] All workers have exited.")
                        await asyncio.sleep(1)  # Allow time for cleanup
                        break

            except zmq.Again:
                # No messages available, continue waiting
                await asyncio.sleep(0.1)

            if time.perf_counter() - wait_start > timeout:
                print(
                    f"[debug] Timeout reached after {timeout} seconds, "
                    f"stopping waiting."
                )
                break

            # is_alive() turns False as soon as *any* process exits, which is
            # routine here (rank 0 may finish teardown first); only stop
            # waiting once every process is gone.
            if not self._any_alive():
                print("[debug] All workers have exited prematurely.")
                break

        # Terminate every process that is still running, even if others
        # have already exited.
        if self._any_alive():
            print("[debug] Some workers did not exit cleanly, terminating them.")
            self.terminate(timeout=timeout)

        self.close_zmq_socket()

    async def stream(
        self,
        payload: Optional[dict[str, Any]] = None,
        timeout: Optional[int] = None,
        streaming_timeout: Optional[int] = None,
        as_text_events: bool = False,
    ) -> AsyncIterator[Any]:
        """
        Streams the result from the distributed worker.
        :param payload: The payload to send to the worker (default: empty).
        :param timeout: The timeout for the overall operation.
        :param streaming_timeout: The timeout in-between streamed results.
        :param as_text_events: Whether to yield results as text events.
        :return: An async iterator that yields the result from the worker.
        :raises TimeoutError: If either timeout elapses.
        :raises RuntimeError: If the worker finishes without yielding data.
        """
        import zmq

        if payload is None:
            payload = {}

        self.ensure_alive()
        self.maybe_cancel_keepalive()  # Cancel until the streaming is done
        try:
            socket = self.get_zmq_socket()
            payload_serialized = distributed_serialize(payload, is_final=True)
            # Flag bytes: streaming=1, then text-events flag.
            await socket.send_multipart(
                [b"0", b"1" + (b"1" if as_text_events else b"0") + payload_serialized]
            )

            start_time = time.perf_counter()
            last_yield_time = start_time
            yielded_once = False

            while True:
                iter_start_time = time.perf_counter()
                try:
                    rank, response = await socket.recv_multipart(flags=zmq.NOBLOCK)  # type: ignore[misc]
                except zmq.Again:
                    if timeout is not None and iter_start_time - start_time > timeout:
                        raise TimeoutError(f"Streaming timed out after {timeout} seconds.")
                    if (
                        streaming_timeout is not None
                        and iter_start_time - last_yield_time > streaming_timeout
                    ):
                        raise TimeoutError(
                            f"Streaming timed out after {streaming_timeout} "
                            f"seconds of inactivity."
                        )

                    await asyncio.sleep(0.1)
                    continue

                assert rank == b"0", "Expected response from worker with rank 0"

                if response == b"DONE":
                    if not yielded_once:
                        raise RuntimeError("No data was yielded from the worker.")
                    break

                if as_text_events:
                    yield response
                else:
                    yield distributed_deserialize(response)

                yielded_once = True
                last_yield_time = iter_start_time
        finally:
            # Restart the keepalive timer even on error or early consumer exit
            # (mirrors keepalive(), which restarts it in ``finally``).
            self.maybe_start_keepalive()

    async def invoke(
        self,
        payload: Optional[dict[str, Any]] = None,
        timeout: Optional[int] = None,
    ) -> Any:
        """
        Invokes the distributed worker with the given payload.
        :param payload: The payload to send to the worker (default: empty).
        :param timeout: The timeout for the overall operation.
        :return: The result from the worker.
        :raises TimeoutError: If no response arrives within ``timeout``.
        """
        import zmq

        if payload is None:
            payload = {}

        self.ensure_alive()
        self.maybe_cancel_keepalive()  # Cancel until the invocation is done
        try:
            socket = self.get_zmq_socket()
            payload_serialized = distributed_serialize(payload, is_final=True)

            # Flag bytes "00": non-streaming, no text events.
            await socket.send_multipart([b"0", b"00" + payload_serialized])

            # Wait for the response from the worker
            start_time = time.perf_counter()
            while True:
                try:
                    rank, response = await socket.recv_multipart(flags=zmq.NOBLOCK)  # type: ignore[misc]
                    break  # Exit the loop if we received a response
                except zmq.Again:
                    elapsed = time.perf_counter() - start_time
                    if timeout is not None and elapsed > timeout:
                        raise TimeoutError(f"Invocation timed out after {timeout} seconds.")

                    await asyncio.sleep(0.1)
                    self.ensure_alive()
        finally:
            # Restart the keepalive timer on success and failure alike.
            self.maybe_start_keepalive()

        assert rank == b"0", "Expected response from worker with rank 0"
        return distributed_deserialize(response)

    async def __aenter__(self) -> DistributedRunner:
        """
        Enter the context manager by starting the workers.
        :return: The DistributedRunner instance.
        """
        await self.start()
        return self

    async def __aexit__(
        self,
        exc_type: Optional[type[BaseException]],
        exc_value: Optional[BaseException],
        exc_traceback: Optional[traceback.StackSummary],
    ) -> None:
        """
        Exit the context manager by stopping the workers (errors printed).
        :param exc_type: The type of the exception raised, if any.
        :param exc_value: The value of the exception raised, if any.
        :param exc_traceback: The traceback of the exception raised, if any.
        """
        try:
            await self.stop()
        except Exception as e:
            print(f"[debug] Error during cleanup: {e}\n{traceback.format_exc()}")