wool-0.1rc9-py3-none-any.whl → wool-0.1rc10-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of wool might be problematic; see the release advisory for more details.

Files changed (44)
  1. wool/__init__.py +71 -50
  2. wool/_protobuf/__init__.py +12 -5
  3. wool/_protobuf/exception.py +3 -0
  4. wool/_protobuf/task.py +11 -0
  5. wool/_protobuf/task_pb2.py +42 -0
  6. wool/_protobuf/task_pb2.pyi +43 -0
  7. wool/_protobuf/{mempool/metadata_pb2_grpc.py → task_pb2_grpc.py} +2 -2
  8. wool/_protobuf/worker.py +24 -0
  9. wool/_protobuf/worker_pb2.py +47 -0
  10. wool/_protobuf/worker_pb2.pyi +39 -0
  11. wool/_protobuf/worker_pb2_grpc.py +141 -0
  12. wool/_resource_pool.py +376 -0
  13. wool/_typing.py +0 -10
  14. wool/_work.py +553 -0
  15. wool/_worker.py +843 -169
  16. wool/_worker_discovery.py +1223 -0
  17. wool/_worker_pool.py +331 -0
  18. wool/_worker_proxy.py +515 -0
  19. {wool-0.1rc9.dist-info → wool-0.1rc10.dist-info}/METADATA +7 -7
  20. wool-0.1rc10.dist-info/RECORD +22 -0
  21. wool-0.1rc10.dist-info/entry_points.txt +2 -0
  22. wool/_cli.py +0 -262
  23. wool/_event.py +0 -109
  24. wool/_future.py +0 -171
  25. wool/_logging.py +0 -44
  26. wool/_manager.py +0 -181
  27. wool/_mempool/__init__.py +0 -4
  28. wool/_mempool/_client.py +0 -167
  29. wool/_mempool/_mempool.py +0 -311
  30. wool/_mempool/_metadata.py +0 -35
  31. wool/_mempool/_service.py +0 -227
  32. wool/_pool.py +0 -524
  33. wool/_protobuf/mempool/metadata_pb2.py +0 -36
  34. wool/_protobuf/mempool/metadata_pb2.pyi +0 -17
  35. wool/_protobuf/mempool/service_pb2.py +0 -66
  36. wool/_protobuf/mempool/service_pb2.pyi +0 -108
  37. wool/_protobuf/mempool/service_pb2_grpc.py +0 -355
  38. wool/_queue.py +0 -32
  39. wool/_session.py +0 -429
  40. wool/_task.py +0 -366
  41. wool/_utils.py +0 -63
  42. wool-0.1rc9.dist-info/RECORD +0 -29
  43. wool-0.1rc9.dist-info/entry_points.txt +0 -2
  44. {wool-0.1rc9.dist-info → wool-0.1rc10.dist-info}/WHEEL +0 -0
wool/_worker.py CHANGED
@@ -1,201 +1,875 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import asyncio
4
- import logging
5
- from contextvars import ContextVar
4
+ import os
5
+ import signal
6
+ import uuid
7
+ from abc import ABC
8
+ from abc import abstractmethod
9
+ from contextlib import asynccontextmanager
10
+ from contextlib import contextmanager
6
11
  from multiprocessing import Pipe
7
12
  from multiprocessing import Process
8
- from multiprocessing import current_process
9
- from queue import Empty
10
- from signal import Signals
11
- from signal import signal
12
- from threading import Event
13
- from threading import Thread
14
- from time import sleep
13
+ from multiprocessing.connection import Connection
15
14
  from typing import TYPE_CHECKING
15
+ from typing import Any
16
+ from typing import AsyncIterator
17
+ from typing import Final
18
+ from typing import Generic
19
+ from typing import Protocol
20
+ from typing import TypeAlias
21
+ from typing import TypeVar
22
+ from typing import final
23
+
24
+ import cloudpickle
25
+ import grpc.aio
26
+ from grpc import StatusCode
27
+ from grpc.aio import ServicerContext
16
28
 
17
29
  import wool
18
- from wool._event import TaskEvent
19
- from wool._future import fulfill
20
- from wool._future import poll
21
- from wool._session import WorkerPoolSession
22
- from wool._session import WorkerSession
30
+ from wool import _protobuf as pb
31
+ from wool._resource_pool import ResourcePool
32
+ from wool._work import WoolTask
33
+ from wool._work import WoolTaskEvent
34
+ from wool._worker_discovery import RegistryServiceLike
35
+ from wool._worker_discovery import WorkerInfo
23
36
 
24
37
  if TYPE_CHECKING:
25
- from wool._task import Task
38
+ from wool._worker_proxy import WorkerProxy
39
+
40
+ _ip_address: str | None = None
41
+ _EXTERNAL_DNS_SERVER: Final[str] = "8.8.8.8" # Google DNS for IP detection
42
+
43
+
44
@contextmanager
def _signal_handlers(service: "WorkerService"):
    """Temporarily install SIGTERM/SIGINT handlers that stop *service*.

    While the context is active, SIGTERM triggers an immediate shutdown
    (``timeout=0``) and SIGINT a graceful one (``timeout=None``), each by
    scheduling ``service._stop`` onto the event loop. The previous
    handlers are restored on exit.

    :param service:
        The :py:class:`WorkerService` to shut down when a signal arrives.
    :yields:
        Control to the calling context with the handlers installed.
    """
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        # No loop running yet — create one so the handlers have a target.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

    def _make_handler(stop_timeout):
        # Handlers run on the main thread; hop onto the loop thread-safely
        # and spawn the shutdown coroutine there.
        def _handler(signum, frame):
            if loop.is_running():
                loop.call_soon_threadsafe(
                    lambda: asyncio.create_task(service._stop(timeout=stop_timeout))
                )

        return _handler

    previous = {
        signal.SIGTERM: signal.signal(signal.SIGTERM, _make_handler(0)),
        signal.SIGINT: signal.signal(signal.SIGINT, _make_handler(None)),
    }
    try:
        yield
    finally:
        for signum, old_handler in previous.items():
            signal.signal(signum, old_handler)
26
81
 
27
82
 
28
- def _noop(*_):
29
- pass
83
+ _T_RegistryService = TypeVar(
84
+ "_T_RegistryService", bound=RegistryServiceLike, covariant=True
85
+ )
30
86
 
31
87
 
32
- class Scheduler(Thread):
88
# public
class Worker(ABC, Generic[_T_RegistryService]):
    """Abstract base for worker implementations in the wool framework.

    A worker is an independent process that executes distributed tasks for
    a pool. It hosts a gRPC server and announces itself through a registry
    service so client sessions can discover it.

    Subclasses supply the concrete process/server lifecycle through the
    :meth:`_start` and :meth:`_stop` hooks; this base class owns identity,
    metadata, and registry bookkeeping.

    :param tags:
        Capability tags used by client sessions to filter and select
        workers.
    :param registry_service:
        Service used to register and unregister this worker for discovery
        within the distributed pool.
    :param extra:
        Arbitrary additional metadata as key-value pairs.
    """

    _info: WorkerInfo | None = None
    _started: bool = False
    _registry_service: RegistryServiceLike
    _uid: Final[str]
    _tags: Final[set[str]]
    _extra: Final[dict[str, Any]]

    def __init__(
        self,
        *tags: str,
        registry_service: _T_RegistryService,
        **extra: Any,
    ):
        self._uid = f"worker-{uuid.uuid4().hex}"
        self._tags = set(tags)
        self._extra = extra
        self._registry_service = registry_service

    @property
    def uid(self) -> str:
        """Unique identifier assigned to this worker at construction."""
        return self._uid

    @property
    def info(self) -> WorkerInfo | None:
        """Network address and metadata, or None before :meth:`start`."""
        return self._info

    @property
    def tags(self) -> set[str]:
        """Capability tags associated with this worker."""
        return self._tags

    @property
    def extra(self) -> dict[str, Any]:
        """Arbitrary extra metadata supplied at construction."""
        return self._extra

    @property
    @abstractmethod
    def address(self) -> str | None: ...

    @property
    @abstractmethod
    def host(self) -> str | None: ...

    @property
    @abstractmethod
    def port(self) -> int | None: ...

    @final
    async def start(self):
        """Start the worker and make it discoverable.

        Boots the registry service first, then delegates to the
        subclass-specific :meth:`_start` hook.

        :raises RuntimeError:
            If the worker was already started.
        """
        if self._started:
            raise RuntimeError("Worker has already been started")
        if self._registry_service:
            await self._registry_service.start()
        await self._start()
        self._started = True

    @final
    async def stop(self):
        """Stop the worker and withdraw it from discovery.

        Delegates shutdown to the subclass-specific :meth:`_stop` hook,
        then stops the registry service.

        :raises RuntimeError:
            If the worker was never started.
        """
        if not self._started:
            raise RuntimeError("Worker has not been started")
        await self._stop()
        if self._registry_service:
            await self._registry_service.stop()

    @abstractmethod
    async def _start(self):
        """Subclass hook: boot the worker process and its gRPC server."""
        ...

    @abstractmethod
    async def _stop(self):
        """Subclass hook: shut down the worker process and release its
        resources."""
        ...
210
+
211
+
212
# public
class WorkerFactory(Generic[_T_RegistryService], Protocol):
    """Structural protocol for callables that build :py:class:`Worker`s.

    A factory receives capability tags (extra keyword arguments are
    accepted and ignored) and returns a fully configured
    :py:class:`Worker` bound to a registry service.

    :py:class:`WorkerPool` relies on factories to spawn several worker
    processes with a consistent configuration.
    """

    def __call__(self, *tags: str, **_) -> Worker[_T_RegistryService]:
        """Build a new worker.

        :param tags:
            Tags to attach to the worker for discovery and filtering.
        :returns:
            A configured :py:class:`Worker` instance.
        """
        ...
235
+
236
+
237
# public
class LocalWorker(Worker[_T_RegistryService]):
    """Worker that executes tasks in a dedicated local subprocess.

    Spawns a :py:class:`WorkerProcess` hosting a gRPC server, then
    registers the process's actual network address with the registry
    service so client sessions can discover it. The subprocess runs its
    own asyncio event loop, providing process-level isolation for task
    execution.

    :param tags:
        Capability tags for filtering and selection by client sessions.
    :param host:
        Interface for the subprocess's gRPC server to bind.
    :param port:
        Port to bind; 0 selects a random free port.
    :param registry_service:
        Service instance for worker registration and discovery.
    :param extra:
        Additional arbitrary metadata as key-value pairs.
    """

    _worker_process: WorkerProcess

    def __init__(
        self,
        *tags: str,
        host: str = "127.0.0.1",
        port: int = 0,
        registry_service: _T_RegistryService,
        **extra: Any,
    ):
        super().__init__(*tags, registry_service=registry_service, **extra)
        self._worker_process = WorkerProcess(host=host, port=port)

    @property
    def address(self) -> str | None:
        """The worker's "host:port" address, or None if not started."""
        return self._worker_process.address

    @property
    def host(self) -> str | None:
        """Host the worker is listening on, or None if not started."""
        return self._info.host if self._info else None

    @property
    def port(self) -> int | None:
        """Port the worker is listening on, or None if not started."""
        return self._info.port if self._info else None

    async def _start(self):
        """Launch the subprocess and register it for discovery.

        ``Process.start`` blocks until the child reports its bound port,
        so it runs in the default executor to keep the event loop
        responsive.

        :raises RuntimeError:
            If the subprocess reports no address or PID after startup.
        """
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, self._worker_process.start)
        if not self._worker_process.address:
            raise RuntimeError("Worker process failed to start - no address")
        if not self._worker_process.pid:
            raise RuntimeError("Worker process failed to start - no PID")

        # rsplit on the last ":" keeps IPv6 hosts (which themselves
        # contain colons) intact; only the trailing port is separated.
        host, port_str = self._worker_process.address.rsplit(":", 1)
        port = int(port_str)

        # Publish the actual bound host, port, and pid for discovery.
        self._info = WorkerInfo(
            uid=self._uid,
            host=host,
            port=port,
            pid=self._worker_process.pid,
            version=wool.__version__,
            tags=self._tags,
            extra=self._extra,
        )
        await self._registry_service.register(self._info)

    async def _stop(self):
        """Unregister the worker and terminate its subprocess.

        Attempts a graceful shutdown via SIGINT and falls back to a hard
        kill if signaling fails while the process is still alive.

        :raises RuntimeError:
            If the worker has no registration info.
        """
        if not self._info:
            raise RuntimeError("Cannot unregister - worker has no info")

        await self._registry_service.unregister(self._info)

        if not self._worker_process.is_alive():
            return
        try:
            if self._worker_process.pid:
                os.kill(self._worker_process.pid, signal.SIGINT)
            # NOTE(review): join() has no timeout — a hung child blocks
            # shutdown here; confirm whether an upper bound is needed.
            self._worker_process.join()
        except OSError:
            if self._worker_process.is_alive():
                self._worker_process.kill()
351
+
352
+
353
class WorkerProcess(Process):
    """A :py:class:`multiprocessing.Process` that runs a gRPC worker
    server.

    :py:class:`WorkerProcess` creates an isolated Python process that hosts a
    gRPC server for executing distributed tasks. Each process maintains
    its own event loop and serves as an independent worker node in the
    wool distributed runtime.

    :param host:
        Interface the gRPC server binds to (default ``127.0.0.1``).
    :param port:
        Port number where the gRPC server will listen. If 0, a random
        available port is chosen by the server at bind time.
    :param args:
        Additional positional arguments passed to the parent
        :py:class:`multiprocessing.Process` class.
    :param kwargs:
        Additional keyword arguments passed to the parent
        :py:class:`multiprocessing.Process` class.

    .. attribute:: address
        The network address where the gRPC server is listening.
    """

    # Requested port (0 = ephemeral) until start() completes; thereafter
    # the port the child actually bound, as reported over the pipe below.
    _port: int | None
    # One-way pipe: the child sends its bound port, the parent receives it.
    _get_port: Connection
    _set_port: Connection

    def __init__(self, *args, host: str = "127.0.0.1", port: int = 0, **kwargs):
        super().__init__(*args, **kwargs)
        if not host:
            raise ValueError("Host must be a non-blank string")
        self._host = host
        if port < 0:
            raise ValueError("Port must be a positive integer")
        self._port = port
        # Created before fork/spawn so both parent and child hold ends.
        self._get_port, self._set_port = Pipe(duplex=False)

    @property
    def address(self) -> str | None:
        """The network address where the gRPC server is listening.

        :returns:
            The address in "host:port" format, or None if not started
            (a requested port of 0 stays falsy until the child reports
            the real bound port).
        """
        return self._address(self._host, self._port) if self._port else None

    @property
    def host(self) -> str | None:
        """The host where the gRPC server is listening.

        :returns:
            The host address, or None if not started.
        """
        return self._host

    @property
    def port(self) -> int | None:
        """The port where the gRPC server is listening.

        :returns:
            The port number, or None if not started.
        """
        return self._port or None

    def start(self):
        """Start the worker process.

        Launches the worker process and blocks until the child reports
        the port it bound. After starting, the :attr:`address` property
        contains the actual network address.

        :raises RuntimeError:
            If the worker process fails to start within 10 seconds.
        """
        super().start()
        # Add timeout to prevent hanging if child process fails to start
        if self._get_port.poll(timeout=10):  # 10 second timeout
            self._port = self._get_port.recv()
        else:
            self.terminate()
            self.join()
            raise RuntimeError("Worker process failed to start within 10 seconds")
        self._get_port.close()

    def run(self) -> None:
        """Run the worker process.

        Child-process entry point: installs a fresh proxy resource pool
        for this process, then runs the gRPC server until stopped.
        """

        async def proxy_factory(proxy: WorkerProxy):
            """Factory function for WorkerProxy instances in ResourcePool.

            Starts the proxy if not already started and returns it.
            The proxy object itself is used as the cache key.

            :param proxy:
                The WorkerProxy instance to start (passed as key from
                ResourcePool).
            :returns:
                The started WorkerProxy instance.
            """
            if not proxy.started:
                await proxy.start()
            return proxy

        async def proxy_finalizer(proxy: WorkerProxy):
            """Finalizer function for WorkerProxy instances in ResourcePool.

            Stops the proxy when it's being cleaned up from the resource
            pool. Best-effort: failures during cleanup are ignored.

            :param proxy:
                The WorkerProxy instance to clean up.
            """
            try:
                await proxy.stop()
            except Exception:
                pass

        # Idle proxies are evicted after 60s; the pool lives in a
        # context variable scoped to this child process.
        wool.__proxy_pool__.set(
            ResourcePool(factory=proxy_factory, finalizer=proxy_finalizer, ttl=60)
        )
        asyncio.run(self._serve())

    async def _serve(self):
        """Start the gRPC server in this worker process.

        Creates a gRPC server, registers the worker service, and listens
        until the service signals shutdown.
        """
        server = grpc.aio.server()
        port = server.add_insecure_port(self._address(self._host, self._port))
        service = WorkerService()
        pb.add_to_server[pb.worker.WorkerServicer](service, server)

        with _signal_handlers(service):
            try:
                await server.start()
                # Report the bound port only after the server is actually
                # serving, so the parent's start() unblocks at a usable
                # moment; close the pipe end either way.
                try:
                    self._set_port.send(port)
                finally:
                    self._set_port.close()
                await service.stopped.wait()
            finally:
                # 60s grace lets in-flight RPCs finish before teardown.
                await server.stop(grace=60)

    def _address(self, host, port) -> str:
        """Format a network address.

        :param host:
            Host name or IP to include in the address.
        :param port:
            Port number to include in the address.
        :returns:
            Address string in "host:port" format.
        """
        return f"{host}:{port}"
512
+
513
+
514
class WorkerService(pb.worker.WorkerServicer):
    """gRPC service implementation for executing distributed wool tasks.

    Implements the WorkerServicer interface: remote procedure calls for
    task dispatch and worker lifecycle management. Tasks run as
    :py:class:`asyncio.Task` objects in the same event loop as the gRPC
    server, and the service tracks them so shutdown can drain or cancel
    outstanding work.

    Once shutdown begins (via the :meth:`stop` RPC or a signal handler
    calling :meth:`_stop`), new :meth:`dispatch` requests are rejected
    with UNAVAILABLE while in-flight tasks complete according to the
    requested timeout. :attr:`stopping` and :attr:`stopped` expose the
    shutdown-state events.
    """

    _tasks: set[asyncio.Task]       # tasks currently executing
    _stopped: asyncio.Event         # set once shutdown has completed
    _stopping: asyncio.Event        # set once shutdown has begun
    _task_completed: asyncio.Event  # NOTE(review): never set/awaited here — confirm use

    def __init__(self):
        self._stopped = asyncio.Event()
        self._stopping = asyncio.Event()
        self._task_completed = asyncio.Event()
        self._tasks = set()

    @property
    def stopping(self) -> asyncio.Event:
        """Event set when the service begins shutdown.

        :returns:
            An :py:class:`asyncio.Event` set at the start of shutdown.
        """
        return self._stopping

    @property
    def stopped(self) -> asyncio.Event:
        """Event set when the service has completed shutdown.

        :returns:
            An :py:class:`asyncio.Event` set at the end of shutdown.
        """
        return self._stopped

    @contextmanager
    def _running(self, wool_task: WoolTask):
        """Schedule *wool_task* and track it while it runs.

        Emits a "task-scheduled" :py:class:`WoolTaskEvent`, creates the
        asyncio task, and guarantees removal from the active set on exit
        so shutdown never waits on finished work.

        :param wool_task:
            The :py:class:`WoolTask` instance to execute and track.
        :yields:
            The :py:class:`asyncio.Task` created for the wool task.
        """
        WoolTaskEvent("task-scheduled", task=wool_task).emit()
        task = asyncio.create_task(wool_task.run())
        self._tasks.add(task)
        try:
            yield task
        finally:
            self._tasks.remove(task)

    async def dispatch(
        self, request: pb.task.Task, context: ServicerContext
    ) -> AsyncIterator[pb.worker.Response]:
        """Execute a task and stream back an ack followed by its outcome.

        :param request:
            The protobuf task message containing the serialized task data.
        :param context:
            The :py:class:`grpc.aio.ServicerContext` for this request.
        :yields:
            An Ack response once processing begins, then a Response
            carrying either the pickled result or the pickled exception.
        """
        if self._stopping.is_set():
            # abort() raises, so execution never continues past this call.
            await context.abort(
                StatusCode.UNAVAILABLE, "Worker service is shutting down"
            )

        with self._running(WoolTask.from_protobuf(request)) as task:
            # Acknowledge receipt before awaiting the result so the client
            # can distinguish "accepted" from "still in flight".
            yield pb.worker.Response(ack=pb.worker.Ack())

            try:
                result = pb.task.Result(dump=cloudpickle.dumps(await task))
                yield pb.worker.Response(result=result)
            except Exception as e:
                exception = pb.task.Exception(dump=cloudpickle.dumps(e))
                yield pb.worker.Response(exception=exception)

    async def stop(
        self, request: pb.worker.StopRequest, context: ServicerContext
    ) -> pb.worker.Void:
        """Stop the worker service.

        Idempotent: once shutdown has begun, repeat calls return
        immediately.

        :param request:
            The protobuf stop request; ``request.wait`` is the drain
            timeout forwarded to :meth:`_stop`.
        :param context:
            The :py:class:`grpc.aio.ServicerContext` for this request.
        :returns:
            An empty protobuf response indicating completion.
        """
        if self._stopping.is_set():
            return pb.worker.Void()
        await self._stop(timeout=request.wait)
        return pb.worker.Void()

    async def _stop(self, *, timeout: float | None = 0) -> None:
        """Drain or cancel tasks, clear the proxy pool, and mark stopped."""
        self._stopping.set()
        await self._await_or_cancel_tasks(timeout=timeout)

        # Clean up the session cache to prevent issues during shutdown.
        # Explicit None check (rather than `assert`, which is stripped
        # under `python -O`) so a missing pool cannot crash or abort the
        # shutdown path; _stopped is always set regardless.
        try:
            proxy_pool = wool.__proxy_pool__.get()
            if proxy_pool is not None:
                await proxy_pool.clear()
        finally:
            self._stopped.set()

    async def _await_or_cancel_tasks(self, *, timeout: float | None = 0) -> None:
        """Wait for running tasks to finish, or cancel them.

        :param timeout:
            Maximum time to wait for tasks to complete. If 0 (default),
            tasks are canceled immediately. If None, waits indefinitely.
            If a positive number, waits that many seconds before
            canceling the stragglers.
        """
        if self._tasks and timeout == 0:
            await self._cancel(*self._tasks)
        elif self._tasks:
            try:
                await asyncio.wait_for(
                    asyncio.gather(*self._tasks, return_exceptions=True),
                    timeout=timeout,
                )
            except asyncio.TimeoutError:
                # Grace period elapsed: cancel whatever is still running.
                return await self._await_or_cancel_tasks(timeout=0)

    async def _cancel(self, *tasks: asyncio.Task):
        """Cancel *tasks* safely and wait for them to unwind.

        Skips tasks that are already done as well as the current task
        (canceling ourselves would deadlock the shutdown path).

        :param tasks:
            The :py:class:`asyncio.Task` instances to cancel.
        """
        current = asyncio.current_task()
        to_cancel = [task for task in tasks if not task.done() and task != current]

        for task in to_cancel:
            task.cancel()

        # Await cancellations in parallel; return_exceptions swallows the
        # resulting CancelledErrors.
        if to_cancel:
            await asyncio.gather(*to_cancel, return_exceptions=True)
722
+
723
+
724
# Type of the gRPC call returned by the worker's dispatch RPC: a
# unary-stream call sending one Task message and receiving a stream of
# Response messages back.
DispatchCall: TypeAlias = grpc.aio.UnaryStreamCall[pb.task.Task, pb.worker.Response]
725
+
726
+
727
@asynccontextmanager
async def with_timeout(context, timeout):
    """Async context manager wrapper that adds a timeout to context entry.

    Only entry (``__aenter__``) is bounded by the timeout; the wrapped
    body and ``__aexit__`` run without a deadline.

    :param context:
        The async context manager to wrap.
    :param timeout:
        Timeout in seconds for context entry (``None`` waits forever).
    :yields:
        Control to the calling context.
    :raises asyncio.TimeoutError:
        If context entry exceeds the timeout.
    """
    await asyncio.wait_for(context.__aenter__(), timeout=timeout)
    try:
        yield
    except BaseException as exception:
        # Honor the async context-manager protocol: a truthy return from
        # __aexit__ suppresses the in-flight exception. (Previously the
        # return value was ignored and the exception always re-raised.)
        suppressed = await context.__aexit__(
            type(exception), exception, exception.__traceback__
        )
        if not suppressed:
            raise
    else:
        await context.__aexit__(None, None, None)
751
+
752
+
753
# Element type yielded by DispatchStream's async iteration.
T = TypeVar("T")
754
+
755
+
756
class DispatchStream(Generic[T]):
    """Async iterator adapting a gRPC dispatch call into task results.

    For each streamed message it either returns the deserialized result
    payload or raises the deserialized remote exception. Any failure
    during iteration cancels the underlying stream before propagating.
    Channel lifecycle is owned by the WorkerClient, not this class.
    """

    def __init__(self, stream: DispatchCall):
        """Wrap *stream*, the underlying gRPC unary-stream call.

        :param stream:
            The underlying gRPC response stream.
        """
        self._stream = stream
        self._iter = aiter(stream)

    def __aiter__(self) -> AsyncIterator[T]:
        """Return self as the async iterator."""
        return self

    async def __anext__(self) -> T:
        """Produce the next deserialized result from the stream.

        :returns:
            The next task result from the worker.
        :raises StopAsyncIteration:
            When the stream is exhausted.
        """
        try:
            response = await anext(self._iter)
            if response.HasField("result"):
                return cloudpickle.loads(response.result.dump)
            if response.HasField("exception"):
                # Remote-side failure: re-raise it locally.
                raise cloudpickle.loads(response.exception.dump)
            raise RuntimeError(f"Received unexpected response: {response}")
        except Exception as exception:
            # Covers stream exhaustion, transport errors, and the remote
            # exception raised above; always cancels before re-raising.
            await self._handle_exception(exception)

    async def _handle_exception(self, exception):
        # Best-effort stream cancellation before propagating. If the
        # cancel itself fails, surface that failure chained to the
        # original error; otherwise re-raise the original unchanged.
        try:
            self._stream.cancel()
        except Exception as cancel_exception:
            raise cancel_exception from exception
        raise exception
802
+
803
+
804
class WorkerClient:
    """Client for dispatching tasks to a specific worker.

    Maintains one persistent gRPC channel to a single worker for the
    lifetime of the client and bounds concurrent dispatches with a
    semaphore of 100. Call :meth:`stop` to release the channel.

    :param address:
        The network address of the target worker in "host:port" format.
    """

    def __init__(self, address: str) -> None:
        # One long-lived insecure (plaintext) channel per client; it is
        # only closed in stop().
        self._channel = grpc.aio.insecure_channel(
            address,
            # NOTE(review): these tuning options were disabled; confirm
            # whether keepalive/message-size limits are still required
            # before deleting them for good.
            # options=[
            # ("grpc.keepalive_time_ms", 10000),
            # ("grpc.keepalive_timeout_ms", 5000),
            # ("grpc.http2.max_pings_without_data", 0),
            # ("grpc.http2.min_time_between_pings_ms", 10000),
            # ("grpc.max_receive_message_length", 100 * 1024 * 1024),
            # ("grpc.max_send_message_length", 100 * 1024 * 1024),
            # ],
        )
        self._stub = pb.worker.WorkerStub(self._channel)
        # Caps in-flight dispatches; each slot is held for the whole
        # lifetime of its result stream, not just the initial RPC.
        self._semaphore = asyncio.Semaphore(100)

    async def dispatch(self, task: WoolTask) -> AsyncIterator[pb.task.Result]:
        """Dispatch a task to the worker and stream back its results.

        Waits up to 60 seconds for a concurrency slot, starts the
        dispatch RPC, and requires the first response to be an Ack
        before yielding deserialized results from the rest of the
        stream. On any failure before the Ack, the RPC is cancelled
        best-effort and the error is re-raised.

        :param task:
            The WoolTask to dispatch to the worker.
        :yields:
            Deserialized task results as they arrive from the worker.
        :raises UnexpectedResponse:
            If the worker's first response is not an Ack.
        :raises asyncio.TimeoutError:
            If acquiring a slot or receiving the initial Ack exceeds
            60 seconds.
        """
        # with_timeout bounds only semaphore acquisition; the wait for
        # the initial Ack gets its own independent 60s deadline below.
        async with with_timeout(self._semaphore, timeout=60):
            call: DispatchCall = self._stub.dispatch(task.to_protobuf())

            try:
                first_response = await asyncio.wait_for(anext(aiter(call)), timeout=60)
                if not first_response.HasField("ack"):
                    raise UnexpectedResponse("Expected Ack response")
            except (
                asyncio.CancelledError,
                asyncio.TimeoutError,
                grpc.aio.AioRpcError,
                UnexpectedResponse,
            ):
                # Best-effort cancel before propagating; cancel() may
                # itself fail if the call has already terminated.
                try:
                    call.cancel()
                except Exception:
                    pass
                raise

            # The Ack has been consumed; DispatchStream deserializes
            # and yields the remaining responses.
            async for result in DispatchStream(call):
                yield result

    async def stop(self):
        """Stop the client and close the gRPC channel.

        Gracefully closes the underlying gRPC channel and cleans up
        any resources associated with this client.
        """
        await self._channel.close()
873
+
874
+
875
class UnexpectedResponse(Exception):
    """Raised when a worker replies with something other than the expected message."""