fal 1.45.2-py3-none-any.whl → 1.46.0-py3-none-any.whl

This diff compares publicly available package versions as they appear in their public registry. It is provided for informational purposes only.

Note: this version of fal has been flagged as potentially problematic.

@@ -0,0 +1,776 @@
+ import asyncio
+ import inspect
+ import os
+ import pickle
+ import queue
+ import threading
+ import time
+ import traceback
+ import warnings
+ from collections.abc import AsyncIterator, Callable, Coroutine
+ from concurrent.futures import Future
+ from pathlib import Path
+ from types import TracebackType
+ from typing import TYPE_CHECKING, Any, Optional, Union
+
+ from fal.distributed.utils import (
+     KeepAliveTimer,
+     distributed_deserialize,
+     distributed_serialize,
+     encode_text_event,
+     launch_distributed_processes,
+ )
+
+ if TYPE_CHECKING:
+     import torch
+     import torch.multiprocessing as mp
+     from zmq.sugar.socket import Socket
+
+
+ class DistributedWorker:
+     """
+     A base class for distributed workers.
+     """
+
+     queue: queue.Queue[bytes]
+     loop: asyncio.AbstractEventLoop
+     thread: threading.Thread
+
+     def __init__(
+         self,
+         rank: int = 0,
+         world_size: int = 1,
+     ) -> None:
+         self.rank = rank
+         self.world_size = world_size
+         self.queue = queue.Queue()
+
+         try:
+             import uvloop
+
+             self.loop = uvloop.new_event_loop()
+         except ImportError:
+             self.loop = asyncio.new_event_loop()
+
+         self._start_thread()
+
+     def _start_thread(self) -> None:
+         """
+         Start the background thread that runs the event loop.
+         """
+         self.thread = threading.Thread(target=self._run_forever, daemon=True)
+         self.thread.start()
+
+     def _run_forever(self) -> None:
+         """
+         Run the event loop forever.
+         """
+         asyncio.set_event_loop(self.loop)
+         self.loop.run_forever()
+
+     # Public API
+
+     @property
+     def device(self) -> "torch.device":
+         """
+         :return: The device for the current worker.
+         """
+         import torch
+
+         return torch.device(f"cuda:{self.rank}")
+
+     @property
+     def running(self) -> bool:
+         """
+         :return: Whether the event loop is running.
+         """
+         return self.thread.is_alive()
+
+     def initialize(self, **kwargs: Any) -> None:
+         """
+         Initialize the worker and run its setup on the worker thread.
+         """
+         import torch
+
+         torch.cuda.set_device(self.device)
+         self.rank_print(f"Initializing worker on device {self.device}")
+
+         setup_start = time.time()
+         future = self.run_in_worker(self.setup, **kwargs)
+         future.result()
+         setup_duration = time.time() - setup_start
+         self.rank_print(f"Setup took {setup_duration:.2f} seconds")
+
+     def add_streaming_result(
+         self,
+         result: Any,
+         image_format: str = "jpeg",
+         as_text_event: bool = False,
+     ) -> None:
+         """
+         Add a streaming result to the queue.
+         :param result: The result to add to the queue.
+         :param image_format: The image format to use when serializing images.
+         :param as_text_event: Whether to encode the result as a text event.
+         """
+         if as_text_event:
+             serialized = encode_text_event(
+                 result, is_final=False, image_format=image_format
+             )
+         else:
+             serialized = distributed_serialize(
+                 result, is_final=False, image_format=image_format
+             )
+
+         self.queue.put_nowait(serialized)
+
+     def add_streaming_error(self, error: Exception) -> None:
+         """
+         Add an error to the queue.
+         :param error: The error to add to the queue.
+         """
+         self.queue.put_nowait(
+             distributed_serialize({"error": str(error)}, is_final=False)
+         )
+
+     def rank_print(self, message: str, debug: bool = False) -> None:
+         """
+         Print a message prefixed with the rank of the current worker.
+         :param message: The message to print.
+         :param debug: Whether to print the message as a debug message.
+         """
+         prefix = "[debug] " if debug else ""
+         print(f"{prefix}[rank {self.rank}] {message}")
+
+     def submit(self, coro: Coroutine[Any, Any, Any]) -> Future[Any]:
+         """
+         Submit a coroutine to the event loop.
+         :param coro: The coroutine to submit to the event loop.
+         :return: A future that will resolve to the result of the coroutine.
+         """
+         if not self.running:
+             raise RuntimeError("Event loop is not running.")
+         return asyncio.run_coroutine_threadsafe(coro, self.loop)
+
+     def shutdown(self, timeout: Optional[Union[int, float]] = None) -> None:
+         """
+         Tear down the worker and stop the event loop.
+         :param timeout: The timeout for joining the event-loop thread.
+         """
+         try:
+             self.run_in_worker(self.teardown).result()
+         except Exception as e:
+             self.rank_print(f"Error during teardown: {e}\n{traceback.format_exc()}")
+
+         self.loop.call_soon_threadsafe(self.loop.stop)
+         self.thread.join(timeout=timeout)
+
+     def run_in_worker(
+         self,
+         func: Callable[..., Any],
+         *args: Any,
+         **kwargs: Any,
+     ) -> Future[Any]:
+         """
+         Run a function on the worker's event loop; synchronous functions
+         are offloaded to a thread via asyncio.to_thread.
+         :return: A future that will resolve to the function's result.
+         """
+         if inspect.iscoroutinefunction(func):
+             coro = func(*args, **kwargs)
+         else:
+             coro = asyncio.to_thread(func, *args, **kwargs)
+
+         return self.submit(coro)
+
+     # Overrideables
+
+     def setup(self, **kwargs: Any) -> None:
+         """
+         Override this method to set up the worker.
+         This method is called once per worker.
+         """
+         return
+
+     def teardown(self) -> None:
+         """
+         Override this method to tear down the worker.
+         This method is called once per worker.
+         """
+         return
+
+     def __call__(self, streaming: bool = False, **kwargs: Any) -> Any:
+         """
+         Override this method to run the worker.
+         """
+         return {}
+
+
+ class DistributedRunner:
+     """
+     A class to launch and manage distributed workers.
+     """
+
+     zmq_socket: Optional["Socket[Any]"]
+     context: Optional["mp.ProcessContext"]
+     keepalive_timer: Optional[KeepAliveTimer]
+
+     def __init__(
+         self,
+         worker_cls: type[DistributedWorker] = DistributedWorker,
+         world_size: int = 1,
+         master_addr: str = "127.0.0.1",
+         master_port: int = 29500,
+         worker_addr: str = "127.0.0.1",
+         worker_port: int = 54923,
+         timeout: int = 86400,  # 24 hours
+         keepalive_payload: dict[str, Any] = {},
+         keepalive_interval: Optional[Union[int, float]] = None,
+         cwd: Optional[Union[str, Path]] = None,
+         set_device: Optional[bool] = None,  # deprecated
+     ) -> None:
+         self.worker_cls = worker_cls
+         self.world_size = world_size
+         self.master_addr = master_addr
+         self.master_port = master_port
+         self.worker_addr = worker_addr
+         self.worker_port = worker_port
+         self.timeout = timeout
+         self.cwd = cwd
+         self.zmq_socket = None
+         self.context = None
+         self.keepalive_payload = keepalive_payload
+         self.keepalive_interval = keepalive_interval
+         self.keepalive_timer = None
+
+         if set_device is not None:
+             warnings.warn("set_device is deprecated and will be removed in the future.")
+
+     def is_alive(self) -> bool:
+         """
+         Check if the distributed worker processes are alive.
+         :return: True if the distributed processes are alive, False otherwise.
+         """
+         if self.context is None:
+             return False
+         for process in self.context.processes:
+             if not process.is_alive():
+                 return False
+         return True
+
+     def terminate(self, timeout: Union[int, float] = 10) -> None:
+         """
+         Terminates the distributed worker processes.
+         This method should be called to clean up the worker processes.
+         """
+         if self.context is not None:
+             for process in self.context.processes:
+                 if process.is_alive():
+                     process.terminate()
+                 process.join(timeout=timeout)
+
+     def gather_errors(self) -> list[Exception]:
+         """
+         Gathers errors from the distributed worker processes.
+
+         This method should be called to collect any errors that occurred
+         during execution.
+
+         :return: A list of exceptions raised by the worker processes.
+         """
+         errors = []
+
+         if self.context is not None:
+             for error_file in self.context.error_files:
+                 if os.path.exists(error_file):
+                     with open(error_file, "rb") as f:
+                         error = pickle.loads(f.read())
+                         errors.append(error)
+
+                     os.remove(error_file)
+
+         return errors
+
+     def ensure_alive(self) -> None:
+         """
+         Ensures that the distributed worker processes are alive.
+         If the processes are not alive, it raises an error.
+         """
+         if not self.is_alive():
+             raise RuntimeError(
+                 f"Distributed processes are not running. Errors: {self.gather_errors()}"
+             )
+
+     def get_zmq_socket(self) -> "Socket[Any]":
+         """
+         Returns the ZeroMQ ROUTER socket, creating and binding it on first use.
+         :return: A ZeroMQ socket.
+         """
+         if self.zmq_socket is not None:
+             return self.zmq_socket
+
+         import zmq
+         import zmq.asyncio
+
+         context = zmq.asyncio.Context()
+         socket = context.socket(zmq.ROUTER)
+         socket.bind(f"tcp://{self.worker_addr}:{self.worker_port}")
+         self.zmq_socket = socket
+         return socket
+
+     def close_zmq_socket(self) -> None:
+         """
+         Closes the ZeroMQ socket.
+         """
+         if self.zmq_socket is not None:
+             try:
+                 self.zmq_socket.close()
+             except Exception as e:
+                 print(
+                     f"[debug] Error closing ZeroMQ socket: {e}\n"
+                     f"{traceback.format_exc()}"
+                 )
+             self.zmq_socket = None
+
+     def run(self, **kwargs: Any) -> None:
+         """
+         The main function to run the distributed worker.
+
+         This function is called by each worker process spawned by
+         `torch.multiprocessing.spawn`. This method must be synchronous.
+
+         :param kwargs: The arguments to pass to the worker.
+         """
+         import torch.distributed as dist
+         import zmq
+
+         # Set up communication
+         rank = int(os.environ["RANK"])
+         context = zmq.Context()
+         socket = context.socket(zmq.DEALER)
+         socket.setsockopt(zmq.IDENTITY, str(rank).encode("utf-8"))
+         socket.connect(f"tcp://{self.worker_addr}:{self.worker_port}")
+
+         # Create and set up the worker
+         worker = self.worker_cls(rank, self.world_size)
+         try:
+             worker.initialize(**kwargs)
+         except Exception as e:
+             worker.rank_print(
+                 f"Error during initialization: {e}\n{traceback.format_exc()}"
+             )
+             socket.send(b"EXIT")
+             socket.close()
+             return
+
+         # Wait until all workers are ready
+         socket.send(b"READY")
+         dist.barrier()
+
+         # Define execution methods to invoke from workers
+         def execute(payload: bytes) -> Any:
+             """
+             Execute the worker function with the given payload synchronously.
+             :param payload: The payload to send to the worker.
+             :return: The result from the worker.
+             """
+             payload_dict = distributed_deserialize(payload)
+             assert isinstance(payload_dict, dict)
+             payload_dict["streaming"] = False
+
+             try:
+                 future = worker.run_in_worker(worker.__call__, **payload_dict)
+                 result = future.result()
+             except Exception as e:
+                 error_output = {"error": str(e)}
+                 worker.rank_print(
+                     f"Error in execution: {error_output}\n{traceback.format_exc()}"
+                 )
+                 result = error_output
+
+             dist.barrier()
+             if worker.rank != 0:
+                 return
+
+             socket.send(distributed_serialize(result, is_final=True))
+
+         def stream(payload: bytes, as_text_events: bool) -> None:
+             """
+             Stream results from the worker function with the given payload,
+             sending intermediate results over the socket as they arrive.
+             :param payload: The payload to send to the worker.
+             :param as_text_events: Whether to encode results as text events.
+             """
+             payload_dict = distributed_deserialize(payload)
+             assert isinstance(payload_dict, dict)
+             payload_dict["streaming"] = True
+             image_format = payload_dict.get("image_format", "jpeg")
+             encoded_response: Optional[bytes] = None
+
+             try:
+                 future = worker.run_in_worker(worker.__call__, **payload_dict)
+                 while not future.done():
+                     try:
+                         intermediate = worker.queue.get(timeout=0.1)
+                         if intermediate is not None and worker.rank == 0:
+                             socket.send(intermediate)  # already serialized
+                     except queue.Empty:
+                         pass
+                 result = future.result()
+             except Exception as e:
+                 error_output = {"error": str(e)}
+                 worker.rank_print(
+                     f"Error in streaming: {error_output}\n{traceback.format_exc()}"
+                 )
+                 if worker.rank == 0:
+                     if as_text_events:
+                         encoded_response = encode_text_event(error_output)
+                     else:
+                         encoded_response = distributed_serialize(error_output)
+             else:
+                 if worker.rank == 0:
+                     if as_text_events:
+                         encoded_response = encode_text_event(
+                             result, is_final=True, image_format=image_format
+                         )
+                     else:
+                         encoded_response = distributed_serialize(
+                             result, is_final=True, image_format=image_format
+                         )
+
+             dist.barrier()
+             if worker.rank != 0:
+                 return
+
+             if encoded_response is not None:
+                 socket.send(encoded_response)
+             socket.send(b"DONE")
+
+         # Runtime code
+         if rank == 0:
+             worker.rank_print("Master worker is ready to receive tasks.")
+             while True:
+                 serialized_data = socket.recv()
+                 streaming = serialized_data[0] == ord("1")
+                 as_text_events = serialized_data[1] == ord("1")
+                 serialized_data = serialized_data[2:]
+                 params = [serialized_data, streaming, as_text_events]
+                 dist.broadcast_object_list(params, src=0)
+
+                 if serialized_data == b"EXIT":
+                     worker.rank_print("Received exit payload, exiting.")
+                     break
+
+                 if streaming:
+                     stream(serialized_data, as_text_events)
+                 else:
+                     execute(serialized_data)
+         else:
+             worker.rank_print("Worker waiting for tasks.")
+             while True:
+                 try:
+                     params = [None, None, None]
+                     dist.broadcast_object_list(params, src=0)
+                     payload, streaming, as_text_events = params  # type: ignore[assignment]
+                     if payload == b"EXIT":
+                         worker.rank_print("Received exit payload, exiting.")
+                         break
+
+                     if streaming:
+                         stream(payload, as_text_events)  # type: ignore[arg-type]
+                     else:
+                         execute(payload)  # type: ignore[arg-type]
+                 except Exception as e:
+                     worker.rank_print(f"Error in worker: {e}\n{traceback.format_exc()}")
+
+         # Teardown
+         worker.rank_print("Worker is tearing down.")
+         try:
+             worker.shutdown()
+             worker.rank_print("Worker torn down successfully.")
+         except Exception as e:
+             worker.rank_print(f"Error during teardown: {e}\n{traceback.format_exc()}")
+
+         socket.send(b"EXIT")
+         socket.close()
+
+     async def start(self, timeout: int = 1800, **kwargs: Any) -> None:
+         """
+         Starts the distributed worker processes.
+         :param timeout: The timeout for the workers to become ready.
+         """
+         import zmq
+
+         if self.is_alive():
+             raise RuntimeError("Distributed processes are already running.")
+
+         self.context = launch_distributed_processes(
+             self.run,
+             world_size=self.world_size,
+             master_addr=self.master_addr,
+             master_port=self.master_port,
+             timeout=self.timeout,
+             cwd=self.cwd,
+             **kwargs,
+         )
+
+         try:
+             ready_workers: set[bytes] = set()
+             socket = self.get_zmq_socket()
+             start_time = time.perf_counter()
+
+             while len(ready_workers) < self.world_size:
+                 try:
+                     ident, msg = await socket.recv_multipart(flags=zmq.NOBLOCK)  # type: ignore[misc]
+
+                     if msg != b"READY":
+                         worker_id = ident.decode("utf-8")
+                         worker_msg = msg.decode("utf-8")
+                         raise RuntimeError(
+                             f"Unexpected message from worker {worker_id}: {worker_msg}"
+                         )
+
+                     print(f"[debug] Worker {ident.decode('utf-8')} is ready.")
+                     ready_workers.add(ident)
+                 except zmq.Again:
+                     total_wait_time = time.perf_counter() - start_time
+                     if total_wait_time > timeout:
+                         raise TimeoutError(
+                             f"Timeout reached after {timeout} seconds while "
+                             f"waiting for workers to be ready."
+                         )
+                     await asyncio.sleep(0.5)
+                     self.ensure_alive()
+         except Exception as e:
+             print(f"[debug] Error during startup: {e}\n{traceback.format_exc()}")
+             self.terminate(timeout=timeout)
+             raise RuntimeError("Failed to start distributed processes.") from e
+
+         print("[debug] All workers are ready and running.")
+
+         # Start the keepalive timer
+         self.maybe_start_keepalive()
+
+     def keepalive(self, timeout: Optional[Union[int, float]] = 60.0) -> None:
+         """
+         Sends the keepalive payload to the worker.
+         """
+         # Cancel the keepalive timer
+         self.maybe_cancel_keepalive()
+         loop_thread = None
+         try:
+             loop = asyncio.get_running_loop()
+         except RuntimeError:
+             loop = asyncio.new_event_loop()
+             loop_thread = threading.Thread(target=loop.run_forever, daemon=True)
+             loop_thread.start()
+
+         future = asyncio.run_coroutine_threadsafe(
+             self.invoke(self.keepalive_payload), loop
+         )
+
+         try:
+             future.result(timeout=timeout)
+         except Exception as e:
+             print(f"[debug] Error during keepalive: {e}\n{traceback.format_exc()}")
+             raise RuntimeError("Failed to run keepalive.") from e
+         finally:
+             if loop_thread is not None:
+                 loop.call_soon_threadsafe(loop.stop)
+                 loop_thread.join(timeout=timeout)
+             # Restart the keepalive timer
+             self.maybe_start_keepalive()
+
+     def maybe_start_keepalive(self) -> None:
+         """
+         Starts the keepalive timer if it is set.
+         """
+         if self.keepalive_timer is None and self.keepalive_interval is not None:
+             self.keepalive_timer = KeepAliveTimer(
+                 self.keepalive, self.keepalive_interval, start=True
+             )
+
+     def maybe_reset_keepalive(self) -> None:
+         """
+         Resets the keepalive timer if it is set.
+         """
+         if self.keepalive_timer is not None:
+             self.keepalive_timer.reset()
+
+     def maybe_cancel_keepalive(self) -> None:
+         """
+         Cancels the keepalive timer if it is set.
+         """
+         if self.keepalive_timer is not None:
+             self.keepalive_timer.cancel()
+             self.keepalive_timer = None
+
+     async def stop(self, timeout: int = 10) -> None:
+         """
+         Stops the distributed worker processes.
+         :param timeout: The timeout for the distributed processes to stop.
+         """
+         import zmq
+
+         if not self.is_alive():
+             raise RuntimeError("Distributed processes are not running.")
+
+         self.maybe_cancel_keepalive()
+         worker_exits: set[bytes] = set()
+         socket = self.get_zmq_socket()
+         await socket.send_multipart([b"0", b"00EXIT"])
+
+         wait_start = time.perf_counter()
+         while len(worker_exits) < self.world_size:
+             try:
+                 ident, msg = await socket.recv_multipart(flags=zmq.NOBLOCK)  # type: ignore[misc]
+                 if msg == b"EXIT":
+                     worker_exits.add(ident)
+                     if len(worker_exits) == self.world_size:
+                         print("[debug] All workers have exited.")
+                         await asyncio.sleep(1)  # Allow time for cleanup
+                         break
+
+             except zmq.Again:
+                 # No messages available, continue waiting
+                 await asyncio.sleep(0.1)
+
+                 if time.perf_counter() - wait_start > timeout:
+                     print(
+                         f"[debug] Timeout reached after {timeout} seconds, "
+                         f"stopping waiting."
+                     )
+                     break
+
+                 if not self.is_alive():
+                     print("[debug] All workers have exited prematurely.")
+                     break
+
+         if self.is_alive():
+             print("[debug] Some workers did not exit cleanly, terminating them.")
+             self.terminate(timeout=timeout)
+
+         self.close_zmq_socket()
+
+     async def stream(
+         self,
+         payload: dict[str, Any] = {},
+         timeout: Optional[int] = None,
+         streaming_timeout: Optional[int] = None,
+         as_text_events: bool = False,
+     ) -> AsyncIterator[Any]:
+         """
+         Streams the result from the distributed worker.
+         :param payload: The payload to send to the worker.
+         :param timeout: The timeout for the overall operation.
+         :param streaming_timeout: The timeout in-between streamed results.
+         :param as_text_events: Whether to yield results as text events.
+         :return: An async iterator that yields the result from the worker.
+         """
+         import zmq
+
+         self.ensure_alive()
+         self.maybe_cancel_keepalive()  # Cancel until the streaming is done
+         socket = self.get_zmq_socket()
+         payload_serialized = distributed_serialize(payload, is_final=True)
+         await socket.send_multipart(
+             [b"0", b"1" + (b"1" if as_text_events else b"0") + payload_serialized]
+         )
+
+         start_time = time.perf_counter()
+         last_yield_time = start_time
+         yielded_once = False
+
+         while True:
+             iter_start_time = time.perf_counter()
+             try:
+                 rank, response = await socket.recv_multipart(flags=zmq.NOBLOCK)  # type: ignore[misc]
+             except zmq.Again:
+                 if timeout is not None and iter_start_time - start_time > timeout:
+                     raise TimeoutError(f"Streaming timed out after {timeout} seconds.")
+                 if (
+                     streaming_timeout is not None
+                     and iter_start_time - last_yield_time > streaming_timeout
+                 ):
+                     raise TimeoutError(
+                         f"Streaming timed out after {streaming_timeout} "
+                         f"seconds of inactivity."
+                     )
+
+                 await asyncio.sleep(0.1)
+                 continue
+
+             assert rank == b"0", "Expected response from worker with rank 0"
+
+             if response == b"DONE":
+                 if not yielded_once:
+                     raise RuntimeError("No data was yielded from the worker.")
+                 break
+
+             if as_text_events:
+                 yield response
+             else:
+                 yield distributed_deserialize(response)
+
+             yielded_once = True
+             last_yield_time = iter_start_time
+
+         self.maybe_start_keepalive()  # Restart the keepalive timer
+
+     async def invoke(
+         self,
+         payload: dict[str, Any] = {},
+         timeout: Optional[int] = None,
+     ) -> Any:
+         """
+         Invokes the distributed worker with the given payload.
+         :param payload: The payload to send to the worker.
+         :param timeout: The timeout for the overall operation.
+         :return: The result from the worker.
+         """
+         import zmq
+
+         self.ensure_alive()
+         self.maybe_cancel_keepalive()  # Cancel until the invocation is done
+         socket = self.get_zmq_socket()
+         payload_serialized = distributed_serialize(payload, is_final=True)
+
+         await socket.send_multipart([b"0", b"00" + payload_serialized])
+
+         # Wait for the response from the worker
+         start_time = time.perf_counter()
+         while True:
+             try:
+                 rank, response = await socket.recv_multipart(flags=zmq.NOBLOCK)  # type: ignore[misc]
+                 break  # Exit the loop if we received a response
+             except zmq.Again:
+                 elapsed = time.perf_counter() - start_time
+                 if timeout is not None and elapsed > timeout:
+                     raise TimeoutError(f"Invocation timed out after {timeout} seconds.")
+
+                 await asyncio.sleep(0.1)
+                 self.ensure_alive()
+
+         self.maybe_start_keepalive()  # Restart the keepalive timer
+         assert rank == b"0", "Expected response from worker with rank 0"
+         return distributed_deserialize(response)
+
+     async def __aenter__(self) -> "DistributedRunner":
+         """
+         Enter the context manager.
+         :return: The DistributedRunner instance.
+         """
+         await self.start()
+         return self
+
+     async def __aexit__(
+         self,
+         exc_type: Optional[type[BaseException]],
+         exc_value: Optional[BaseException],
+         exc_traceback: Optional[TracebackType],
+     ) -> None:
+         """
+         Exit the context manager.
+         :param exc_type: The type of the exception raised, if any.
+         :param exc_value: The value of the exception raised, if any.
+         :param exc_traceback: The traceback of the exception raised, if any.
+         """
+         try:
+             await self.stop()
+         except Exception as e:
+             print(f"[debug] Error during cleanup: {e}\n{traceback.format_exc()}")