wool 0.1rc13__py3-none-any.whl → 0.1rc15__py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between versions as they appear in their respective public registries.

Potentially problematic release.


This version of wool might be problematic; consult the package's registry page for more details.

wool/_worker.py CHANGED
@@ -1,75 +1,66 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import asyncio
4
- import os
5
4
  import signal
6
5
  import uuid
7
6
  from abc import ABC
8
7
  from abc import abstractmethod
9
- from contextlib import asynccontextmanager
10
8
  from contextlib import contextmanager
11
9
  from multiprocessing import Pipe
12
10
  from multiprocessing import Process
13
11
  from multiprocessing.connection import Connection
12
+ from types import MappingProxyType
14
13
  from typing import TYPE_CHECKING
15
14
  from typing import Any
16
15
  from typing import AsyncContextManager
17
- from typing import AsyncIterator
18
16
  from typing import Awaitable
19
17
  from typing import ContextManager
20
18
  from typing import Final
21
- from typing import Generic
22
19
  from typing import Protocol
23
- from typing import TypeAlias
24
- from typing import TypeVar
25
20
  from typing import final
26
21
 
27
- import cloudpickle
28
22
  import grpc.aio
29
- from grpc import StatusCode
30
- from grpc.aio import ServicerContext
31
23
 
32
24
  import wool
33
25
  from wool import _protobuf as pb
34
26
  from wool._resource_pool import ResourcePool
35
- from wool._work import WoolTask
36
- from wool._work import WoolTaskEvent
37
27
  from wool._worker_discovery import Factory
38
28
  from wool._worker_discovery import RegistrarLike
39
29
  from wool._worker_discovery import WorkerInfo
30
+ from wool._worker_service import WorkerService
40
31
 
41
32
  if TYPE_CHECKING:
42
33
  from wool._worker_proxy import WorkerProxy
43
34
 
44
35
 
45
36
  @contextmanager
46
- def _signal_handlers(service: "WorkerService"):
37
+ def _signal_handlers(service: WorkerService):
47
38
  """Context manager for setting up signal handlers for graceful shutdown.
48
39
 
49
40
  Installs SIGTERM and SIGINT handlers that gracefully shut down the worker
50
41
  service when the process receives termination signals.
51
42
 
52
43
  :param service:
53
- The :py:class:`WorkerService` instance to shut down on signal receipt.
44
+ The :class:`WorkerService` instance to shut down on signal receipt.
54
45
  :yields:
55
46
  Control to the calling context with signal handlers installed.
56
47
  """
57
- try:
58
- loop = asyncio.get_running_loop()
59
- except RuntimeError:
60
- loop = asyncio.new_event_loop()
61
- asyncio.set_event_loop(loop)
48
+ loop = asyncio.get_running_loop()
62
49
 
63
50
  def sigterm_handler(signum, frame):
64
51
  if loop.is_running():
65
52
  loop.call_soon_threadsafe(
66
- lambda: asyncio.create_task(service._stop(timeout=0))
53
+ lambda: asyncio.create_task(
54
+ service.stop(pb.worker.StopRequest(timeout=0), None)
55
+ )
67
56
  )
68
57
 
69
58
  def sigint_handler(signum, frame):
70
59
  if loop.is_running():
71
60
  loop.call_soon_threadsafe(
72
- lambda: asyncio.create_task(service._stop(timeout=None))
61
+ lambda: asyncio.create_task(
62
+ service.stop(pb.worker.StopRequest(timeout=None), None)
63
+ )
73
64
  )
74
65
 
75
66
  old_sigterm = signal.signal(signal.SIGTERM, sigterm_handler)
@@ -162,13 +153,24 @@ class Worker(ABC):
162
153
  def port(self) -> int | None: ...
163
154
 
164
155
  @final
165
- async def start(self):
156
+ async def start(self, *, timeout: float | None = None):
166
157
  """Start the worker and register it with the pool.
167
158
 
168
159
  This method is a final implementation that calls the abstract
169
160
  `_start` method to initialize the worker process and register
170
161
  it with the registrar service.
162
+
163
+ :param timeout:
164
+ Maximum time in seconds to wait for worker startup.
165
+ :raises TimeoutError:
166
+ If startup takes longer than the specified timeout.
167
+ :raises RuntimeError:
168
+ If the worker has already been started.
169
+ :raises ValueError:
170
+ If the timeout is not positive.
171
171
  """
172
+ if timeout is not None and timeout <= 0:
173
+ raise ValueError("Timeout must be positive")
172
174
  if self._started:
173
175
  raise RuntimeError("Worker has already been started")
174
176
 
@@ -178,13 +180,13 @@ class Worker(ABC):
178
180
  if not isinstance(self._registrar_service, RegistrarLike):
179
181
  raise ValueError("Registrar factory must return a RegistrarLike instance")
180
182
 
181
- await self._start()
183
+ await self._start(timeout=timeout)
182
184
  self._started = True
183
185
  assert self._info
184
186
  await self._registrar_service.register(self._info)
185
187
 
186
188
  @final
187
- async def stop(self):
189
+ async def stop(self, *, timeout: float | None = None):
188
190
  """Stop the worker and unregister it from the pool.
189
191
 
190
192
  This method is a final implementation that calls the abstract
@@ -200,7 +202,7 @@ class Worker(ABC):
200
202
  await self._registrar_service.unregister(self._info)
201
203
  finally:
202
204
  try:
203
- await self._stop()
205
+ await self._stop(timeout)
204
206
  finally:
205
207
  await self._exit_context(self._registrar_context)
206
208
  self._registrar_service = None
@@ -208,16 +210,19 @@ class Worker(ABC):
208
210
  self._started = False
209
211
 
210
212
  @abstractmethod
211
- async def _start(self):
213
+ async def _start(self, timeout: float | None):
212
214
  """Implementation-specific worker startup logic.
213
215
 
214
216
  Subclasses must implement this method to handle the actual
215
217
  startup of their worker process and gRPC server.
218
+
219
+ :param timeout:
220
+ Maximum time in seconds to wait for worker startup.
216
221
  """
217
222
  ...
218
223
 
219
224
  @abstractmethod
220
- async def _stop(self):
225
+ async def _stop(self, timeout: float | None):
221
226
  """Implementation-specific worker shutdown logic.
222
227
 
223
228
  Subclasses must implement this method to handle the graceful
@@ -246,6 +251,8 @@ class Worker(ABC):
246
251
  self, ctx: AsyncContextManager | ContextManager | None, *args
247
252
  ):
248
253
  """Exit context for context managers."""
254
+ if not args:
255
+ args = (None, None, None)
249
256
  if isinstance(ctx, AsyncContextManager):
250
257
  await ctx.__aexit__(*args)
251
258
  elif isinstance(ctx, ContextManager):
@@ -257,10 +264,10 @@ class WorkerFactory(Protocol):
257
264
  """Protocol for creating worker instances with registrar integration.
258
265
 
259
266
  Defines the callable interface for worker factory implementations
260
- that can create :py:class:`Worker` instances configured with specific
267
+ that can create :class:`Worker` instances configured with specific
261
268
  capability tags and metadata.
262
269
 
263
- Worker factories are used by :py:class:`WorkerPool` to spawn multiple
270
+ Worker factories are used by :class:`WorkerPool` to spawn multiple
264
271
  worker processes with consistent configuration.
265
272
  """
266
273
 
@@ -271,7 +278,7 @@ class WorkerFactory(Protocol):
271
278
  Additional tags to associate with this worker for discovery
272
279
  and filtering purposes.
273
280
  :returns:
274
- A new :py:class:`Worker` instance configured with the
281
+ A new :class:`Worker` instance configured with the
275
282
  specified tags and metadata.
276
283
  """
277
284
  ...
@@ -281,7 +288,7 @@ class WorkerFactory(Protocol):
281
288
  class LocalWorker(Worker):
282
289
  """Local worker implementation that runs tasks in a separate process.
283
290
 
284
- :py:class:`LocalWorker` creates and manages a dedicated worker process
291
+ :class:`LocalWorker` creates and manages a dedicated worker process
285
292
  that hosts a gRPC server for executing distributed wool tasks. Each
286
293
  worker automatically registers itself with the provided registrar service
287
294
  for discovery by client sessions.
@@ -293,8 +300,17 @@ class LocalWorker(Worker):
293
300
  :param tags:
294
301
  Capability tags to associate with this worker for filtering
295
302
  and selection by client sessions.
303
+ :param host:
304
+ Host address where the worker will listen.
305
+ :param port:
306
+ Port number where the worker will listen. If 0, a random
307
+ available port will be selected.
296
308
  :param registrar:
297
309
  Service instance or factory for worker registration and discovery.
310
+ :param shutdown_grace_period:
311
+ Graceful shutdown timeout for the gRPC server in seconds.
312
+ :param proxy_pool_ttl:
313
+ Time-to-live for the proxy resource pool in seconds.
298
314
  :param extra:
299
315
  Additional arbitrary metadata as key-value pairs.
300
316
  """
@@ -307,10 +323,17 @@ class LocalWorker(Worker):
307
323
  host: str = "127.0.0.1",
308
324
  port: int = 0,
309
325
  registrar: RegistrarLike | Factory[RegistrarLike],
326
+ shutdown_grace_period: float = 60.0,
327
+ proxy_pool_ttl: float = 60.0,
310
328
  **extra: Any,
311
329
  ):
312
330
  super().__init__(*tags, registrar=registrar, **extra)
313
- self._worker_process = WorkerProcess(host=host, port=port)
331
+ self._worker_process = WorkerProcess(
332
+ host=host,
333
+ port=port,
334
+ shutdown_grace_period=shutdown_grace_period,
335
+ proxy_pool_ttl=proxy_pool_ttl,
336
+ )
314
337
 
315
338
  @property
316
339
  def address(self) -> str | None:
@@ -339,36 +362,39 @@ class LocalWorker(Worker):
339
362
  """
340
363
  return self._info.port if self._info else None
341
364
 
342
- async def _start(self):
365
+ async def _start(self, timeout: float | None):
343
366
  """Start the worker process and register it with the pool.
344
367
 
345
368
  Initializes the registrar service, starts the worker process
346
369
  with its gRPC server, and registers the worker's network
347
370
  address with the registrar for discovery by client sessions.
371
+
372
+ :param timeout:
373
+ Maximum time in seconds to wait for worker process startup.
348
374
  """
349
375
  loop = asyncio.get_running_loop()
350
- await loop.run_in_executor(None, self._worker_process.start)
376
+ await loop.run_in_executor(
377
+ None, lambda t: self._worker_process.start(timeout=t), timeout
378
+ )
351
379
  if not self._worker_process.address:
352
380
  raise RuntimeError("Worker process failed to start - no address")
353
381
  if not self._worker_process.pid:
354
382
  raise RuntimeError("Worker process failed to start - no PID")
355
383
 
356
- # Parse host and port from address
357
384
  host, port_str = self._worker_process.address.split(":")
358
385
  port = int(port_str)
359
386
 
360
- # Create the WorkerInfo with the actual host, port, and pid
361
387
  self._info = WorkerInfo(
362
388
  uid=self._uid,
363
389
  host=host,
364
390
  port=port,
365
391
  pid=self._worker_process.pid,
366
392
  version=wool.__version__,
367
- tags=self._tags,
368
- extra=self._extra,
393
+ tags=frozenset(self._tags),
394
+ extra=MappingProxyType(self._extra),
369
395
  )
370
396
 
371
- async def _stop(self):
397
+ async def _stop(self, timeout: float | None):
372
398
  """Stop the worker process and unregister it from the pool.
373
399
 
374
400
  Unregisters the worker from the registrar service, gracefully
@@ -376,35 +402,37 @@ class LocalWorker(Worker):
376
402
  the registrar service. If graceful shutdown fails, the process
377
403
  is forcefully terminated.
378
404
  """
379
- if not self._worker_process.is_alive():
380
- return
381
- try:
382
- if self._worker_process.pid:
383
- os.kill(self._worker_process.pid, signal.SIGINT)
384
- self._worker_process.join()
385
- except OSError:
386
- if self._worker_process.is_alive():
387
- self._worker_process.kill()
405
+ if self._worker_process.is_alive():
406
+ assert self.address
407
+ channel = grpc.aio.insecure_channel(self.address)
408
+ stub = pb.worker.WorkerStub(channel)
409
+ await stub.stop(pb.worker.StopRequest(timeout=timeout))
388
410
 
389
411
 
390
412
  class WorkerProcess(Process):
391
- """A :py:class:`multiprocessing.Process` that runs a gRPC worker
413
+ """A :class:`multiprocessing.Process` that runs a gRPC worker
392
414
  server.
393
415
 
394
- :py:class:`WorkerProcess` creates an isolated Python process that hosts a
416
+ :class:`WorkerProcess` creates an isolated Python process that hosts a
395
417
  gRPC server for executing distributed tasks. Each process maintains
396
418
  its own event loop and serves as an independent worker node in the
397
419
  wool distributed runtime.
398
420
 
421
+ :param host:
422
+ Host address where the gRPC server will listen.
399
423
  :param port:
400
- Optional port number where the gRPC server will listen.
401
- If None, a random available port will be selected.
424
+ Port number where the gRPC server will listen. If 0, a random
425
+ available port will be selected.
426
+ :param shutdown_grace_period:
427
+ Graceful shutdown timeout for the gRPC server in seconds.
428
+ :param proxy_pool_ttl:
429
+ Time-to-live for the proxy resource pool in seconds.
402
430
  :param args:
403
431
  Additional positional arguments passed to the parent
404
- :py:class:`multiprocessing.Process` class.
432
+ :class:`multiprocessing.Process` class.
405
433
  :param kwargs:
406
434
  Additional keyword arguments passed to the parent
407
- :py:class:`multiprocessing.Process` class.
435
+ :class:`multiprocessing.Process` class.
408
436
 
409
437
  .. attribute:: address
410
438
  The network address where the gRPC server is listening.
@@ -413,8 +441,18 @@ class WorkerProcess(Process):
413
441
  _port: int | None
414
442
  _get_port: Connection
415
443
  _set_port: Connection
444
+ _shutdown_grace_period: float
445
+ _proxy_pool_ttl: float
416
446
 
417
- def __init__(self, *args, host: str = "127.0.0.1", port: int = 0, **kwargs):
447
+ def __init__(
448
+ self,
449
+ *args,
450
+ host: str = "127.0.0.1",
451
+ port: int = 0,
452
+ shutdown_grace_period: float = 60.0,
453
+ proxy_pool_ttl: float = 60.0,
454
+ **kwargs,
455
+ ):
418
456
  super().__init__(*args, **kwargs)
419
457
  if not host:
420
458
  raise ValueError("Host must be a non-blank string")
@@ -422,6 +460,12 @@ class WorkerProcess(Process):
422
460
  if port < 0:
423
461
  raise ValueError("Port must be a positive integer")
424
462
  self._port = port
463
+ if shutdown_grace_period <= 0:
464
+ raise ValueError("Shutdown grace period must be positive")
465
+ self._shutdown_grace_period = shutdown_grace_period
466
+ if proxy_pool_ttl <= 0:
467
+ raise ValueError("Proxy pool TTL must be positive")
468
+ self._proxy_pool_ttl = proxy_pool_ttl
425
469
  self._get_port, self._set_port = Pipe(duplex=False)
426
470
 
427
471
  @property
@@ -451,26 +495,31 @@ class WorkerProcess(Process):
451
495
  """
452
496
  return self._port or None
453
497
 
454
- def start(self):
498
+ def start(self, *, timeout: float | None = None):
455
499
  """Start the worker process.
456
500
 
457
501
  Launches the worker process and waits until it has started
458
502
  listening on a port. After starting, the :attr:`address`
459
503
  property will contain the actual network address.
460
504
 
505
+ :param timeout:
506
+ Maximum time in seconds to wait for worker process startup.
461
507
  :raises RuntimeError:
462
- If the worker process fails to start within 10 seconds.
508
+ If the worker process fails to start within the timeout.
463
509
  :raises ValueError:
464
- If the port is negative.
510
+ If the timeout is not positive.
465
511
  """
512
+ if timeout is not None and timeout <= 0:
513
+ raise ValueError("Timeout must be positive")
466
514
  super().start()
467
- # Add timeout to prevent hanging if child process fails to start
468
- if self._get_port.poll(timeout=10): # 10 second timeout
515
+ if self._get_port.poll(timeout=timeout):
469
516
  self._port = self._get_port.recv()
470
517
  else:
471
518
  self.terminate()
472
519
  self.join()
473
- raise RuntimeError("Worker process failed to start within 10 seconds")
520
+ raise RuntimeError(
521
+ f"Worker process failed to start within {timeout} seconds"
522
+ )
474
523
  self._get_port.close()
475
524
 
476
525
  def run(self) -> None:
@@ -510,7 +559,11 @@ class WorkerProcess(Process):
510
559
  pass
511
560
 
512
561
  wool.__proxy_pool__.set(
513
- ResourcePool(factory=proxy_factory, finalizer=proxy_finalizer, ttl=60)
562
+ ResourcePool(
563
+ factory=proxy_factory,
564
+ finalizer=proxy_finalizer,
565
+ ttl=self._proxy_pool_ttl,
566
+ )
514
567
  )
515
568
  asyncio.run(self._serve())
516
569
 
@@ -535,7 +588,7 @@ class WorkerProcess(Process):
535
588
  self._set_port.close()
536
589
  await service.stopped.wait()
537
590
  finally:
538
- await server.stop(grace=60)
591
+ await server.stop(grace=self._shutdown_grace_period)
539
592
 
540
593
  def _address(self, host, port) -> str:
541
594
  """Format network address for the given port.
@@ -546,367 +599,3 @@ class WorkerProcess(Process):
546
599
  Address string in "host:port" format.
547
600
  """
548
601
  return f"{host}:{port}"
549
-
550
-
551
- class WorkerService(pb.worker.WorkerServicer):
552
- """gRPC service implementation for executing distributed wool tasks.
553
-
554
- :py:class:`WorkerService` implements the gRPC WorkerServicer
555
- interface, providing remote procedure calls for task scheduling
556
- and worker lifecycle management. Tasks are executed in the same
557
- asyncio event loop as the gRPC server.
558
-
559
- .. note::
560
- Tasks are executed asynchronously in the current event loop
561
- and results are serialized for transport back to the client.
562
- The service maintains a set of running tasks for proper
563
- lifecycle management during shutdown.
564
-
565
- During shutdown, the service stops accepting new requests
566
- immediately when the :meth:`stop` RPC is called, returning
567
- UNAVAILABLE errors to new :meth:`dispatch` requests while
568
- allowing existing tasks to complete gracefully.
569
-
570
- The service provides :attr:`stopping` and
571
- :attr:`stopped` properties to access the internal shutdown
572
- state events.
573
- """
574
-
575
- _tasks: set[asyncio.Task]
576
- _stopped: asyncio.Event
577
- _stopping: asyncio.Event
578
- _task_completed: asyncio.Event
579
-
580
- def __init__(self):
581
- self._stopped = asyncio.Event()
582
- self._stopping = asyncio.Event()
583
- self._task_completed = asyncio.Event()
584
- self._tasks = set()
585
-
586
- @property
587
- def stopping(self) -> asyncio.Event:
588
- """Event signaling that the service is stopping.
589
-
590
- :returns:
591
- An :py:class:`asyncio.Event` that is set when the service
592
- begins shutdown.
593
- """
594
- return self._stopping
595
-
596
- @property
597
- def stopped(self) -> asyncio.Event:
598
- """Event signaling that the service has stopped.
599
-
600
- :returns:
601
- An :py:class:`asyncio.Event` that is set when the service
602
- has completed shutdown.
603
- """
604
- return self._stopped
605
-
606
- @contextmanager
607
- def _running(self, wool_task: WoolTask):
608
- """Context manager for tracking running tasks.
609
-
610
- Manages the lifecycle of a task execution, adding it to the
611
- active tasks set and emitting appropriate events. Ensures
612
- proper cleanup when the task completes or fails.
613
-
614
- :param wool_task:
615
- The :py:class:`WoolTask` instance to execute and track.
616
- :yields:
617
- The :py:class:`asyncio.Task` created for the wool task.
618
-
619
- .. note::
620
- Emits a :py:class:`WoolTaskEvent` with type "task-scheduled"
621
- when the task begins execution.
622
- """
623
- WoolTaskEvent("task-scheduled", task=wool_task).emit()
624
- task = asyncio.create_task(wool_task.run())
625
- self._tasks.add(task)
626
- try:
627
- yield task
628
- finally:
629
- self._tasks.remove(task)
630
-
631
- async def dispatch(
632
- self, request: pb.task.Task, context: ServicerContext
633
- ) -> AsyncIterator[pb.worker.Response]:
634
- """Execute a task in the current event loop.
635
-
636
- Deserializes the incoming task into a :py:class:`WoolTask`
637
- instance, schedules it for execution in the current asyncio
638
- event loop, and yields responses for acknowledgment and result.
639
-
640
- :param request:
641
- The protobuf task message containing the serialized task
642
- data.
643
- :param context:
644
- The :py:class:`grpc.aio.ServicerContext` for this request.
645
- :yields:
646
- First yields an Ack Response when task processing begins,
647
- then yields a Response containing the task result.
648
-
649
- .. note::
650
- Emits a :py:class:`WoolTaskEvent` when the task is
651
- scheduled for execution.
652
- """
653
- if self._stopping.is_set():
654
- await context.abort(
655
- StatusCode.UNAVAILABLE, "Worker service is shutting down"
656
- )
657
-
658
- with self._running(WoolTask.from_protobuf(request)) as task:
659
- # Yield acknowledgment that task was received and processing is starting
660
- yield pb.worker.Response(ack=pb.worker.Ack())
661
-
662
- try:
663
- result = pb.task.Result(dump=cloudpickle.dumps(await task))
664
- yield pb.worker.Response(result=result)
665
- except Exception as e:
666
- exception = pb.task.Exception(dump=cloudpickle.dumps(e))
667
- yield pb.worker.Response(exception=exception)
668
-
669
- async def stop(
670
- self, request: pb.worker.StopRequest, context: ServicerContext
671
- ) -> pb.worker.Void:
672
- """Stop the worker service and its thread.
673
-
674
- Gracefully shuts down the worker thread and signals the server
675
- to stop accepting new requests. This method is idempotent and
676
- can be called multiple times safely.
677
-
678
- :param request:
679
- The protobuf stop request containing the wait timeout.
680
- :param context:
681
- The :py:class:`grpc.aio.ServicerContext` for this request.
682
- :returns:
683
- An empty protobuf response indicating completion.
684
- """
685
- if self._stopping.is_set():
686
- return pb.worker.Void()
687
- await self._stop(timeout=request.wait)
688
- return pb.worker.Void()
689
-
690
- async def _stop(self, *, timeout: float | None = 0) -> None:
691
- self._stopping.set()
692
- await self._await_or_cancel_tasks(timeout=timeout)
693
-
694
- # Clean up the session cache to prevent issues during shutdown
695
- try:
696
- proxy_pool = wool.__proxy_pool__.get()
697
- assert proxy_pool
698
- await proxy_pool.clear()
699
- finally:
700
- self._stopped.set()
701
-
702
- async def _await_or_cancel_tasks(self, *, timeout: float | None = 0) -> None:
703
- """Stop the worker service gracefully.
704
-
705
- Gracefully shuts down the worker service by canceling or waiting
706
- for running tasks. This method is idempotent and can be called
707
- multiple times safely.
708
-
709
- :param timeout:
710
- Maximum time to wait for tasks to complete. If 0 (default),
711
- tasks are canceled immediately. If None, waits indefinitely.
712
- If a positive number, waits for that many seconds before
713
- canceling tasks.
714
-
715
- .. note::
716
- If a timeout occurs while waiting for tasks to complete,
717
- the method recursively calls itself with a timeout of 0
718
- to cancel all remaining tasks immediately.
719
- """
720
- if self._tasks and timeout == 0:
721
- await self._cancel(*self._tasks)
722
- elif self._tasks:
723
- try:
724
- await asyncio.wait_for(
725
- asyncio.gather(*self._tasks, return_exceptions=True),
726
- timeout=timeout,
727
- )
728
- except asyncio.TimeoutError:
729
- return await self._await_or_cancel_tasks(timeout=0)
730
-
731
- async def _cancel(self, *tasks: asyncio.Task):
732
- """Cancel multiple tasks safely.
733
-
734
- Cancels the provided tasks while performing safety checks to
735
- avoid canceling the current task or already completed tasks.
736
- Waits for all cancelled tasks to complete in parallel and handles
737
- cancellation exceptions.
738
-
739
- :param tasks:
740
- The :py:class:`asyncio.Task` instances to cancel.
741
-
742
- .. note::
743
- This method performs the following safety checks:
744
- - Avoids canceling the current task (would cause deadlock)
745
- - Only cancels tasks that are not already done
746
- - Properly handles :py:exc:`asyncio.CancelledError`
747
- exceptions.
748
- """
749
- current = asyncio.current_task()
750
- to_cancel = [task for task in tasks if not task.done() and task != current]
751
-
752
- # Cancel all tasks first
753
- for task in to_cancel:
754
- task.cancel()
755
-
756
- # Wait for all cancelled tasks in parallel
757
- if to_cancel:
758
- await asyncio.gather(*to_cancel, return_exceptions=True)
759
-
760
-
761
- DispatchCall: TypeAlias = grpc.aio.UnaryStreamCall[pb.task.Task, pb.worker.Response]
762
-
763
-
764
- @asynccontextmanager
765
- async def with_timeout(context, timeout):
766
- """Async context manager wrapper that adds timeout to context entry.
767
-
768
- :param context:
769
- The async context manager to wrap.
770
- :param timeout:
771
- Timeout in seconds for context entry.
772
- :yields:
773
- Control to the calling context.
774
- :raises asyncio.TimeoutError:
775
- If context entry exceeds the timeout.
776
- """
777
- await asyncio.wait_for(context.__aenter__(), timeout=timeout)
778
- exception_type = exception_value = exception_traceback = None
779
- try:
780
- yield
781
- except BaseException as exception:
782
- exception_type = type(exception)
783
- exception_value = exception
784
- exception_traceback = exception.__traceback__
785
- raise
786
- finally:
787
- await context.__aexit__(exception_type, exception_value, exception_traceback)
788
-
789
-
790
- T = TypeVar("T")
791
-
792
-
793
- class DispatchStream(Generic[T]):
794
- """Async iterator wrapper for streaming dispatch results.
795
-
796
- Simplified wrapper that focuses solely on stream iteration and response handling.
797
- Channel management is now handled by the WorkerClient.
798
- """
799
-
800
- def __init__(self, stream: DispatchCall):
801
- """Initialize the streaming dispatch result wrapper.
802
-
803
- :param stream:
804
- The underlying gRPC response stream.
805
- """
806
- self._stream = stream
807
- self._iter = aiter(stream)
808
-
809
- def __aiter__(self) -> AsyncIterator[T]:
810
- """Return self as the async iterator."""
811
- return self
812
-
813
- async def __anext__(self) -> T:
814
- """Get the next response from the stream.
815
-
816
- :returns:
817
- The next task result from the worker.
818
- :raises StopAsyncIteration:
819
- When the stream is exhausted.
820
- """
821
- try:
822
- response = await anext(self._iter)
823
- if response.HasField("result"):
824
- return cloudpickle.loads(response.result.dump)
825
- elif response.HasField("exception"):
826
- raise cloudpickle.loads(response.exception.dump)
827
- else:
828
- raise RuntimeError(f"Received unexpected response: {response}")
829
- except Exception as exception:
830
- await self._handle_exception(exception)
831
-
832
- async def _handle_exception(self, exception):
833
- try:
834
- self._stream.cancel()
835
- except Exception as cancel_exception:
836
- raise cancel_exception from exception
837
- else:
838
- raise exception
839
-
840
-
841
- class WorkerClient:
842
- """Client for dispatching tasks to a specific worker.
843
-
844
- Simplified client that maintains a persistent gRPC channel to a single
845
- worker. The client manages the channel lifecycle and provides task
846
- dispatch functionality with proper error handling.
847
-
848
- :param address:
849
- The network address of the target worker in "host:port" format.
850
- """
851
-
852
- def __init__(self, address: str):
853
- self._channel = grpc.aio.insecure_channel(
854
- address,
855
- # options=[
856
- # ("grpc.keepalive_time_ms", 10000),
857
- # ("grpc.keepalive_timeout_ms", 5000),
858
- # ("grpc.http2.max_pings_without_data", 0),
859
- # ("grpc.http2.min_time_between_pings_ms", 10000),
860
- # ("grpc.max_receive_message_length", 100 * 1024 * 1024),
861
- # ("grpc.max_send_message_length", 100 * 1024 * 1024),
862
- # ],
863
- )
864
- self._stub = pb.worker.WorkerStub(self._channel)
865
- self._semaphore = asyncio.Semaphore(100)
866
-
867
- async def dispatch(self, task: WoolTask) -> AsyncIterator[pb.task.Result]:
868
- """Dispatch task to worker with on-demand channel acquisition.
869
-
870
- Acquires a channel from the global channel pool, creates a WorkerStub,
871
- dispatches the task, and verifies the first response is an Ack.
872
- The channel is automatically managed by the underlying infrastructure.
873
-
874
- :param task:
875
- The WoolTask to dispatch to the worker.
876
- :returns:
877
- A DispatchStream for reading task results.
878
- :raises RuntimeError:
879
- If the worker doesn't acknowledge the task.
880
- """
881
- async with with_timeout(self._semaphore, timeout=60):
882
- call: DispatchCall = self._stub.dispatch(task.to_protobuf())
883
-
884
- try:
885
- first_response = await asyncio.wait_for(anext(aiter(call)), timeout=60)
886
- if not first_response.HasField("ack"):
887
- raise UnexpectedResponse("Expected Ack response")
888
- except (
889
- asyncio.CancelledError,
890
- asyncio.TimeoutError,
891
- grpc.aio.AioRpcError,
892
- UnexpectedResponse,
893
- ):
894
- try:
895
- call.cancel()
896
- except Exception:
897
- pass
898
- raise
899
-
900
- async for result in DispatchStream(call):
901
- yield result
902
-
903
- async def stop(self):
904
- """Stop the client and close the gRPC channel.
905
-
906
- Gracefully closes the underlying gRPC channel and cleans up
907
- any resources associated with this client.
908
- """
909
- await self._channel.close()
910
-
911
-
912
- class UnexpectedResponse(Exception): ...