wool 0.1rc13__py3-none-any.whl → 0.1rc15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of wool might be problematic. Click here for more details.
- wool/__init__.py +1 -26
- wool/_connection.py +247 -0
- wool/_context.py +29 -0
- wool/_loadbalancer.py +213 -0
- wool/_protobuf/task_pb2_grpc.py +2 -2
- wool/_protobuf/worker_pb2.py +6 -6
- wool/_protobuf/worker_pb2.pyi +4 -4
- wool/_protobuf/worker_pb2_grpc.py +2 -2
- wool/_resource_pool.py +3 -3
- wool/_undefined.py +11 -0
- wool/_work.py +5 -4
- wool/_worker.py +115 -426
- wool/_worker_discovery.py +24 -36
- wool/_worker_pool.py +46 -31
- wool/_worker_proxy.py +54 -186
- wool/_worker_service.py +243 -0
- {wool-0.1rc13.dist-info → wool-0.1rc15.dist-info}/METADATA +153 -41
- wool-0.1rc15.dist-info/RECORD +27 -0
- wool-0.1rc13.dist-info/RECORD +0 -22
- {wool-0.1rc13.dist-info → wool-0.1rc15.dist-info}/WHEEL +0 -0
- {wool-0.1rc13.dist-info → wool-0.1rc15.dist-info}/entry_points.txt +0 -0
wool/_worker.py
CHANGED
|
@@ -1,75 +1,66 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
import asyncio
|
|
4
|
-
import os
|
|
5
4
|
import signal
|
|
6
5
|
import uuid
|
|
7
6
|
from abc import ABC
|
|
8
7
|
from abc import abstractmethod
|
|
9
|
-
from contextlib import asynccontextmanager
|
|
10
8
|
from contextlib import contextmanager
|
|
11
9
|
from multiprocessing import Pipe
|
|
12
10
|
from multiprocessing import Process
|
|
13
11
|
from multiprocessing.connection import Connection
|
|
12
|
+
from types import MappingProxyType
|
|
14
13
|
from typing import TYPE_CHECKING
|
|
15
14
|
from typing import Any
|
|
16
15
|
from typing import AsyncContextManager
|
|
17
|
-
from typing import AsyncIterator
|
|
18
16
|
from typing import Awaitable
|
|
19
17
|
from typing import ContextManager
|
|
20
18
|
from typing import Final
|
|
21
|
-
from typing import Generic
|
|
22
19
|
from typing import Protocol
|
|
23
|
-
from typing import TypeAlias
|
|
24
|
-
from typing import TypeVar
|
|
25
20
|
from typing import final
|
|
26
21
|
|
|
27
|
-
import cloudpickle
|
|
28
22
|
import grpc.aio
|
|
29
|
-
from grpc import StatusCode
|
|
30
|
-
from grpc.aio import ServicerContext
|
|
31
23
|
|
|
32
24
|
import wool
|
|
33
25
|
from wool import _protobuf as pb
|
|
34
26
|
from wool._resource_pool import ResourcePool
|
|
35
|
-
from wool._work import WoolTask
|
|
36
|
-
from wool._work import WoolTaskEvent
|
|
37
27
|
from wool._worker_discovery import Factory
|
|
38
28
|
from wool._worker_discovery import RegistrarLike
|
|
39
29
|
from wool._worker_discovery import WorkerInfo
|
|
30
|
+
from wool._worker_service import WorkerService
|
|
40
31
|
|
|
41
32
|
if TYPE_CHECKING:
|
|
42
33
|
from wool._worker_proxy import WorkerProxy
|
|
43
34
|
|
|
44
35
|
|
|
45
36
|
@contextmanager
|
|
46
|
-
def _signal_handlers(service:
|
|
37
|
+
def _signal_handlers(service: WorkerService):
|
|
47
38
|
"""Context manager for setting up signal handlers for graceful shutdown.
|
|
48
39
|
|
|
49
40
|
Installs SIGTERM and SIGINT handlers that gracefully shut down the worker
|
|
50
41
|
service when the process receives termination signals.
|
|
51
42
|
|
|
52
43
|
:param service:
|
|
53
|
-
The :
|
|
44
|
+
The :class:`WorkerService` instance to shut down on signal receipt.
|
|
54
45
|
:yields:
|
|
55
46
|
Control to the calling context with signal handlers installed.
|
|
56
47
|
"""
|
|
57
|
-
|
|
58
|
-
loop = asyncio.get_running_loop()
|
|
59
|
-
except RuntimeError:
|
|
60
|
-
loop = asyncio.new_event_loop()
|
|
61
|
-
asyncio.set_event_loop(loop)
|
|
48
|
+
loop = asyncio.get_running_loop()
|
|
62
49
|
|
|
63
50
|
def sigterm_handler(signum, frame):
|
|
64
51
|
if loop.is_running():
|
|
65
52
|
loop.call_soon_threadsafe(
|
|
66
|
-
lambda: asyncio.create_task(
|
|
53
|
+
lambda: asyncio.create_task(
|
|
54
|
+
service.stop(pb.worker.StopRequest(timeout=0), None)
|
|
55
|
+
)
|
|
67
56
|
)
|
|
68
57
|
|
|
69
58
|
def sigint_handler(signum, frame):
|
|
70
59
|
if loop.is_running():
|
|
71
60
|
loop.call_soon_threadsafe(
|
|
72
|
-
lambda: asyncio.create_task(
|
|
61
|
+
lambda: asyncio.create_task(
|
|
62
|
+
service.stop(pb.worker.StopRequest(timeout=None), None)
|
|
63
|
+
)
|
|
73
64
|
)
|
|
74
65
|
|
|
75
66
|
old_sigterm = signal.signal(signal.SIGTERM, sigterm_handler)
|
|
@@ -162,13 +153,24 @@ class Worker(ABC):
|
|
|
162
153
|
def port(self) -> int | None: ...
|
|
163
154
|
|
|
164
155
|
@final
|
|
165
|
-
async def start(self):
|
|
156
|
+
async def start(self, *, timeout: float | None = None):
|
|
166
157
|
"""Start the worker and register it with the pool.
|
|
167
158
|
|
|
168
159
|
This method is a final implementation that calls the abstract
|
|
169
160
|
`_start` method to initialize the worker process and register
|
|
170
161
|
it with the registrar service.
|
|
162
|
+
|
|
163
|
+
:param timeout:
|
|
164
|
+
Maximum time in seconds to wait for worker startup.
|
|
165
|
+
:raises TimeoutError:
|
|
166
|
+
If startup takes longer than the specified timeout.
|
|
167
|
+
:raises RuntimeError:
|
|
168
|
+
If the worker has already been started.
|
|
169
|
+
:raises ValueError:
|
|
170
|
+
If the timeout is not positive.
|
|
171
171
|
"""
|
|
172
|
+
if timeout is not None and timeout <= 0:
|
|
173
|
+
raise ValueError("Timeout must be positive")
|
|
172
174
|
if self._started:
|
|
173
175
|
raise RuntimeError("Worker has already been started")
|
|
174
176
|
|
|
@@ -178,13 +180,13 @@ class Worker(ABC):
|
|
|
178
180
|
if not isinstance(self._registrar_service, RegistrarLike):
|
|
179
181
|
raise ValueError("Registrar factory must return a RegistrarLike instance")
|
|
180
182
|
|
|
181
|
-
await self._start()
|
|
183
|
+
await self._start(timeout=timeout)
|
|
182
184
|
self._started = True
|
|
183
185
|
assert self._info
|
|
184
186
|
await self._registrar_service.register(self._info)
|
|
185
187
|
|
|
186
188
|
@final
|
|
187
|
-
async def stop(self):
|
|
189
|
+
async def stop(self, *, timeout: float | None = None):
|
|
188
190
|
"""Stop the worker and unregister it from the pool.
|
|
189
191
|
|
|
190
192
|
This method is a final implementation that calls the abstract
|
|
@@ -200,7 +202,7 @@ class Worker(ABC):
|
|
|
200
202
|
await self._registrar_service.unregister(self._info)
|
|
201
203
|
finally:
|
|
202
204
|
try:
|
|
203
|
-
await self._stop()
|
|
205
|
+
await self._stop(timeout)
|
|
204
206
|
finally:
|
|
205
207
|
await self._exit_context(self._registrar_context)
|
|
206
208
|
self._registrar_service = None
|
|
@@ -208,16 +210,19 @@ class Worker(ABC):
|
|
|
208
210
|
self._started = False
|
|
209
211
|
|
|
210
212
|
@abstractmethod
|
|
211
|
-
async def _start(self):
|
|
213
|
+
async def _start(self, timeout: float | None):
|
|
212
214
|
"""Implementation-specific worker startup logic.
|
|
213
215
|
|
|
214
216
|
Subclasses must implement this method to handle the actual
|
|
215
217
|
startup of their worker process and gRPC server.
|
|
218
|
+
|
|
219
|
+
:param timeout:
|
|
220
|
+
Maximum time in seconds to wait for worker startup.
|
|
216
221
|
"""
|
|
217
222
|
...
|
|
218
223
|
|
|
219
224
|
@abstractmethod
|
|
220
|
-
async def _stop(self):
|
|
225
|
+
async def _stop(self, timeout: float | None):
|
|
221
226
|
"""Implementation-specific worker shutdown logic.
|
|
222
227
|
|
|
223
228
|
Subclasses must implement this method to handle the graceful
|
|
@@ -246,6 +251,8 @@ class Worker(ABC):
|
|
|
246
251
|
self, ctx: AsyncContextManager | ContextManager | None, *args
|
|
247
252
|
):
|
|
248
253
|
"""Exit context for context managers."""
|
|
254
|
+
if not args:
|
|
255
|
+
args = (None, None, None)
|
|
249
256
|
if isinstance(ctx, AsyncContextManager):
|
|
250
257
|
await ctx.__aexit__(*args)
|
|
251
258
|
elif isinstance(ctx, ContextManager):
|
|
@@ -257,10 +264,10 @@ class WorkerFactory(Protocol):
|
|
|
257
264
|
"""Protocol for creating worker instances with registrar integration.
|
|
258
265
|
|
|
259
266
|
Defines the callable interface for worker factory implementations
|
|
260
|
-
that can create :
|
|
267
|
+
that can create :class:`Worker` instances configured with specific
|
|
261
268
|
capability tags and metadata.
|
|
262
269
|
|
|
263
|
-
Worker factories are used by :
|
|
270
|
+
Worker factories are used by :class:`WorkerPool` to spawn multiple
|
|
264
271
|
worker processes with consistent configuration.
|
|
265
272
|
"""
|
|
266
273
|
|
|
@@ -271,7 +278,7 @@ class WorkerFactory(Protocol):
|
|
|
271
278
|
Additional tags to associate with this worker for discovery
|
|
272
279
|
and filtering purposes.
|
|
273
280
|
:returns:
|
|
274
|
-
A new :
|
|
281
|
+
A new :class:`Worker` instance configured with the
|
|
275
282
|
specified tags and metadata.
|
|
276
283
|
"""
|
|
277
284
|
...
|
|
@@ -281,7 +288,7 @@ class WorkerFactory(Protocol):
|
|
|
281
288
|
class LocalWorker(Worker):
|
|
282
289
|
"""Local worker implementation that runs tasks in a separate process.
|
|
283
290
|
|
|
284
|
-
:
|
|
291
|
+
:class:`LocalWorker` creates and manages a dedicated worker process
|
|
285
292
|
that hosts a gRPC server for executing distributed wool tasks. Each
|
|
286
293
|
worker automatically registers itself with the provided registrar service
|
|
287
294
|
for discovery by client sessions.
|
|
@@ -293,8 +300,17 @@ class LocalWorker(Worker):
|
|
|
293
300
|
:param tags:
|
|
294
301
|
Capability tags to associate with this worker for filtering
|
|
295
302
|
and selection by client sessions.
|
|
303
|
+
:param host:
|
|
304
|
+
Host address where the worker will listen.
|
|
305
|
+
:param port:
|
|
306
|
+
Port number where the worker will listen. If 0, a random
|
|
307
|
+
available port will be selected.
|
|
296
308
|
:param registrar:
|
|
297
309
|
Service instance or factory for worker registration and discovery.
|
|
310
|
+
:param shutdown_grace_period:
|
|
311
|
+
Graceful shutdown timeout for the gRPC server in seconds.
|
|
312
|
+
:param proxy_pool_ttl:
|
|
313
|
+
Time-to-live for the proxy resource pool in seconds.
|
|
298
314
|
:param extra:
|
|
299
315
|
Additional arbitrary metadata as key-value pairs.
|
|
300
316
|
"""
|
|
@@ -307,10 +323,17 @@ class LocalWorker(Worker):
|
|
|
307
323
|
host: str = "127.0.0.1",
|
|
308
324
|
port: int = 0,
|
|
309
325
|
registrar: RegistrarLike | Factory[RegistrarLike],
|
|
326
|
+
shutdown_grace_period: float = 60.0,
|
|
327
|
+
proxy_pool_ttl: float = 60.0,
|
|
310
328
|
**extra: Any,
|
|
311
329
|
):
|
|
312
330
|
super().__init__(*tags, registrar=registrar, **extra)
|
|
313
|
-
self._worker_process = WorkerProcess(
|
|
331
|
+
self._worker_process = WorkerProcess(
|
|
332
|
+
host=host,
|
|
333
|
+
port=port,
|
|
334
|
+
shutdown_grace_period=shutdown_grace_period,
|
|
335
|
+
proxy_pool_ttl=proxy_pool_ttl,
|
|
336
|
+
)
|
|
314
337
|
|
|
315
338
|
@property
|
|
316
339
|
def address(self) -> str | None:
|
|
@@ -339,36 +362,39 @@ class LocalWorker(Worker):
|
|
|
339
362
|
"""
|
|
340
363
|
return self._info.port if self._info else None
|
|
341
364
|
|
|
342
|
-
async def _start(self):
|
|
365
|
+
async def _start(self, timeout: float | None):
|
|
343
366
|
"""Start the worker process and register it with the pool.
|
|
344
367
|
|
|
345
368
|
Initializes the registrar service, starts the worker process
|
|
346
369
|
with its gRPC server, and registers the worker's network
|
|
347
370
|
address with the registrar for discovery by client sessions.
|
|
371
|
+
|
|
372
|
+
:param timeout:
|
|
373
|
+
Maximum time in seconds to wait for worker process startup.
|
|
348
374
|
"""
|
|
349
375
|
loop = asyncio.get_running_loop()
|
|
350
|
-
await loop.run_in_executor(
|
|
376
|
+
await loop.run_in_executor(
|
|
377
|
+
None, lambda t: self._worker_process.start(timeout=t), timeout
|
|
378
|
+
)
|
|
351
379
|
if not self._worker_process.address:
|
|
352
380
|
raise RuntimeError("Worker process failed to start - no address")
|
|
353
381
|
if not self._worker_process.pid:
|
|
354
382
|
raise RuntimeError("Worker process failed to start - no PID")
|
|
355
383
|
|
|
356
|
-
# Parse host and port from address
|
|
357
384
|
host, port_str = self._worker_process.address.split(":")
|
|
358
385
|
port = int(port_str)
|
|
359
386
|
|
|
360
|
-
# Create the WorkerInfo with the actual host, port, and pid
|
|
361
387
|
self._info = WorkerInfo(
|
|
362
388
|
uid=self._uid,
|
|
363
389
|
host=host,
|
|
364
390
|
port=port,
|
|
365
391
|
pid=self._worker_process.pid,
|
|
366
392
|
version=wool.__version__,
|
|
367
|
-
tags=self._tags,
|
|
368
|
-
extra=self._extra,
|
|
393
|
+
tags=frozenset(self._tags),
|
|
394
|
+
extra=MappingProxyType(self._extra),
|
|
369
395
|
)
|
|
370
396
|
|
|
371
|
-
async def _stop(self):
|
|
397
|
+
async def _stop(self, timeout: float | None):
|
|
372
398
|
"""Stop the worker process and unregister it from the pool.
|
|
373
399
|
|
|
374
400
|
Unregisters the worker from the registrar service, gracefully
|
|
@@ -376,35 +402,37 @@ class LocalWorker(Worker):
|
|
|
376
402
|
the registrar service. If graceful shutdown fails, the process
|
|
377
403
|
is forcefully terminated.
|
|
378
404
|
"""
|
|
379
|
-
if
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
self._worker_process.join()
|
|
385
|
-
except OSError:
|
|
386
|
-
if self._worker_process.is_alive():
|
|
387
|
-
self._worker_process.kill()
|
|
405
|
+
if self._worker_process.is_alive():
|
|
406
|
+
assert self.address
|
|
407
|
+
channel = grpc.aio.insecure_channel(self.address)
|
|
408
|
+
stub = pb.worker.WorkerStub(channel)
|
|
409
|
+
await stub.stop(pb.worker.StopRequest(timeout=timeout))
|
|
388
410
|
|
|
389
411
|
|
|
390
412
|
class WorkerProcess(Process):
|
|
391
|
-
"""A :
|
|
413
|
+
"""A :class:`multiprocessing.Process` that runs a gRPC worker
|
|
392
414
|
server.
|
|
393
415
|
|
|
394
|
-
:
|
|
416
|
+
:class:`WorkerProcess` creates an isolated Python process that hosts a
|
|
395
417
|
gRPC server for executing distributed tasks. Each process maintains
|
|
396
418
|
its own event loop and serves as an independent worker node in the
|
|
397
419
|
wool distributed runtime.
|
|
398
420
|
|
|
421
|
+
:param host:
|
|
422
|
+
Host address where the gRPC server will listen.
|
|
399
423
|
:param port:
|
|
400
|
-
|
|
401
|
-
|
|
424
|
+
Port number where the gRPC server will listen. If 0, a random
|
|
425
|
+
available port will be selected.
|
|
426
|
+
:param shutdown_grace_period:
|
|
427
|
+
Graceful shutdown timeout for the gRPC server in seconds.
|
|
428
|
+
:param proxy_pool_ttl:
|
|
429
|
+
Time-to-live for the proxy resource pool in seconds.
|
|
402
430
|
:param args:
|
|
403
431
|
Additional positional arguments passed to the parent
|
|
404
|
-
:
|
|
432
|
+
:class:`multiprocessing.Process` class.
|
|
405
433
|
:param kwargs:
|
|
406
434
|
Additional keyword arguments passed to the parent
|
|
407
|
-
:
|
|
435
|
+
:class:`multiprocessing.Process` class.
|
|
408
436
|
|
|
409
437
|
.. attribute:: address
|
|
410
438
|
The network address where the gRPC server is listening.
|
|
@@ -413,8 +441,18 @@ class WorkerProcess(Process):
|
|
|
413
441
|
_port: int | None
|
|
414
442
|
_get_port: Connection
|
|
415
443
|
_set_port: Connection
|
|
444
|
+
_shutdown_grace_period: float
|
|
445
|
+
_proxy_pool_ttl: float
|
|
416
446
|
|
|
417
|
-
def __init__(
|
|
447
|
+
def __init__(
|
|
448
|
+
self,
|
|
449
|
+
*args,
|
|
450
|
+
host: str = "127.0.0.1",
|
|
451
|
+
port: int = 0,
|
|
452
|
+
shutdown_grace_period: float = 60.0,
|
|
453
|
+
proxy_pool_ttl: float = 60.0,
|
|
454
|
+
**kwargs,
|
|
455
|
+
):
|
|
418
456
|
super().__init__(*args, **kwargs)
|
|
419
457
|
if not host:
|
|
420
458
|
raise ValueError("Host must be a non-blank string")
|
|
@@ -422,6 +460,12 @@ class WorkerProcess(Process):
|
|
|
422
460
|
if port < 0:
|
|
423
461
|
raise ValueError("Port must be a positive integer")
|
|
424
462
|
self._port = port
|
|
463
|
+
if shutdown_grace_period <= 0:
|
|
464
|
+
raise ValueError("Shutdown grace period must be positive")
|
|
465
|
+
self._shutdown_grace_period = shutdown_grace_period
|
|
466
|
+
if proxy_pool_ttl <= 0:
|
|
467
|
+
raise ValueError("Proxy pool TTL must be positive")
|
|
468
|
+
self._proxy_pool_ttl = proxy_pool_ttl
|
|
425
469
|
self._get_port, self._set_port = Pipe(duplex=False)
|
|
426
470
|
|
|
427
471
|
@property
|
|
@@ -451,26 +495,31 @@ class WorkerProcess(Process):
|
|
|
451
495
|
"""
|
|
452
496
|
return self._port or None
|
|
453
497
|
|
|
454
|
-
def start(self):
|
|
498
|
+
def start(self, *, timeout: float | None = None):
|
|
455
499
|
"""Start the worker process.
|
|
456
500
|
|
|
457
501
|
Launches the worker process and waits until it has started
|
|
458
502
|
listening on a port. After starting, the :attr:`address`
|
|
459
503
|
property will contain the actual network address.
|
|
460
504
|
|
|
505
|
+
:param timeout:
|
|
506
|
+
Maximum time in seconds to wait for worker process startup.
|
|
461
507
|
:raises RuntimeError:
|
|
462
|
-
If the worker process fails to start within
|
|
508
|
+
If the worker process fails to start within the timeout.
|
|
463
509
|
:raises ValueError:
|
|
464
|
-
If the
|
|
510
|
+
If the timeout is not positive.
|
|
465
511
|
"""
|
|
512
|
+
if timeout is not None and timeout <= 0:
|
|
513
|
+
raise ValueError("Timeout must be positive")
|
|
466
514
|
super().start()
|
|
467
|
-
|
|
468
|
-
if self._get_port.poll(timeout=10): # 10 second timeout
|
|
515
|
+
if self._get_port.poll(timeout=timeout):
|
|
469
516
|
self._port = self._get_port.recv()
|
|
470
517
|
else:
|
|
471
518
|
self.terminate()
|
|
472
519
|
self.join()
|
|
473
|
-
raise RuntimeError(
|
|
520
|
+
raise RuntimeError(
|
|
521
|
+
f"Worker process failed to start within {timeout} seconds"
|
|
522
|
+
)
|
|
474
523
|
self._get_port.close()
|
|
475
524
|
|
|
476
525
|
def run(self) -> None:
|
|
@@ -510,7 +559,11 @@ class WorkerProcess(Process):
|
|
|
510
559
|
pass
|
|
511
560
|
|
|
512
561
|
wool.__proxy_pool__.set(
|
|
513
|
-
ResourcePool(
|
|
562
|
+
ResourcePool(
|
|
563
|
+
factory=proxy_factory,
|
|
564
|
+
finalizer=proxy_finalizer,
|
|
565
|
+
ttl=self._proxy_pool_ttl,
|
|
566
|
+
)
|
|
514
567
|
)
|
|
515
568
|
asyncio.run(self._serve())
|
|
516
569
|
|
|
@@ -535,7 +588,7 @@ class WorkerProcess(Process):
|
|
|
535
588
|
self._set_port.close()
|
|
536
589
|
await service.stopped.wait()
|
|
537
590
|
finally:
|
|
538
|
-
await server.stop(grace=
|
|
591
|
+
await server.stop(grace=self._shutdown_grace_period)
|
|
539
592
|
|
|
540
593
|
def _address(self, host, port) -> str:
|
|
541
594
|
"""Format network address for the given port.
|
|
@@ -546,367 +599,3 @@ class WorkerProcess(Process):
|
|
|
546
599
|
Address string in "host:port" format.
|
|
547
600
|
"""
|
|
548
601
|
return f"{host}:{port}"
|
|
549
|
-
|
|
550
|
-
|
|
551
|
-
class WorkerService(pb.worker.WorkerServicer):
|
|
552
|
-
"""gRPC service implementation for executing distributed wool tasks.
|
|
553
|
-
|
|
554
|
-
:py:class:`WorkerService` implements the gRPC WorkerServicer
|
|
555
|
-
interface, providing remote procedure calls for task scheduling
|
|
556
|
-
and worker lifecycle management. Tasks are executed in the same
|
|
557
|
-
asyncio event loop as the gRPC server.
|
|
558
|
-
|
|
559
|
-
.. note::
|
|
560
|
-
Tasks are executed asynchronously in the current event loop
|
|
561
|
-
and results are serialized for transport back to the client.
|
|
562
|
-
The service maintains a set of running tasks for proper
|
|
563
|
-
lifecycle management during shutdown.
|
|
564
|
-
|
|
565
|
-
During shutdown, the service stops accepting new requests
|
|
566
|
-
immediately when the :meth:`stop` RPC is called, returning
|
|
567
|
-
UNAVAILABLE errors to new :meth:`dispatch` requests while
|
|
568
|
-
allowing existing tasks to complete gracefully.
|
|
569
|
-
|
|
570
|
-
The service provides :attr:`stopping` and
|
|
571
|
-
:attr:`stopped` properties to access the internal shutdown
|
|
572
|
-
state events.
|
|
573
|
-
"""
|
|
574
|
-
|
|
575
|
-
_tasks: set[asyncio.Task]
|
|
576
|
-
_stopped: asyncio.Event
|
|
577
|
-
_stopping: asyncio.Event
|
|
578
|
-
_task_completed: asyncio.Event
|
|
579
|
-
|
|
580
|
-
def __init__(self):
|
|
581
|
-
self._stopped = asyncio.Event()
|
|
582
|
-
self._stopping = asyncio.Event()
|
|
583
|
-
self._task_completed = asyncio.Event()
|
|
584
|
-
self._tasks = set()
|
|
585
|
-
|
|
586
|
-
@property
|
|
587
|
-
def stopping(self) -> asyncio.Event:
|
|
588
|
-
"""Event signaling that the service is stopping.
|
|
589
|
-
|
|
590
|
-
:returns:
|
|
591
|
-
An :py:class:`asyncio.Event` that is set when the service
|
|
592
|
-
begins shutdown.
|
|
593
|
-
"""
|
|
594
|
-
return self._stopping
|
|
595
|
-
|
|
596
|
-
@property
|
|
597
|
-
def stopped(self) -> asyncio.Event:
|
|
598
|
-
"""Event signaling that the service has stopped.
|
|
599
|
-
|
|
600
|
-
:returns:
|
|
601
|
-
An :py:class:`asyncio.Event` that is set when the service
|
|
602
|
-
has completed shutdown.
|
|
603
|
-
"""
|
|
604
|
-
return self._stopped
|
|
605
|
-
|
|
606
|
-
@contextmanager
|
|
607
|
-
def _running(self, wool_task: WoolTask):
|
|
608
|
-
"""Context manager for tracking running tasks.
|
|
609
|
-
|
|
610
|
-
Manages the lifecycle of a task execution, adding it to the
|
|
611
|
-
active tasks set and emitting appropriate events. Ensures
|
|
612
|
-
proper cleanup when the task completes or fails.
|
|
613
|
-
|
|
614
|
-
:param wool_task:
|
|
615
|
-
The :py:class:`WoolTask` instance to execute and track.
|
|
616
|
-
:yields:
|
|
617
|
-
The :py:class:`asyncio.Task` created for the wool task.
|
|
618
|
-
|
|
619
|
-
.. note::
|
|
620
|
-
Emits a :py:class:`WoolTaskEvent` with type "task-scheduled"
|
|
621
|
-
when the task begins execution.
|
|
622
|
-
"""
|
|
623
|
-
WoolTaskEvent("task-scheduled", task=wool_task).emit()
|
|
624
|
-
task = asyncio.create_task(wool_task.run())
|
|
625
|
-
self._tasks.add(task)
|
|
626
|
-
try:
|
|
627
|
-
yield task
|
|
628
|
-
finally:
|
|
629
|
-
self._tasks.remove(task)
|
|
630
|
-
|
|
631
|
-
async def dispatch(
|
|
632
|
-
self, request: pb.task.Task, context: ServicerContext
|
|
633
|
-
) -> AsyncIterator[pb.worker.Response]:
|
|
634
|
-
"""Execute a task in the current event loop.
|
|
635
|
-
|
|
636
|
-
Deserializes the incoming task into a :py:class:`WoolTask`
|
|
637
|
-
instance, schedules it for execution in the current asyncio
|
|
638
|
-
event loop, and yields responses for acknowledgment and result.
|
|
639
|
-
|
|
640
|
-
:param request:
|
|
641
|
-
The protobuf task message containing the serialized task
|
|
642
|
-
data.
|
|
643
|
-
:param context:
|
|
644
|
-
The :py:class:`grpc.aio.ServicerContext` for this request.
|
|
645
|
-
:yields:
|
|
646
|
-
First yields an Ack Response when task processing begins,
|
|
647
|
-
then yields a Response containing the task result.
|
|
648
|
-
|
|
649
|
-
.. note::
|
|
650
|
-
Emits a :py:class:`WoolTaskEvent` when the task is
|
|
651
|
-
scheduled for execution.
|
|
652
|
-
"""
|
|
653
|
-
if self._stopping.is_set():
|
|
654
|
-
await context.abort(
|
|
655
|
-
StatusCode.UNAVAILABLE, "Worker service is shutting down"
|
|
656
|
-
)
|
|
657
|
-
|
|
658
|
-
with self._running(WoolTask.from_protobuf(request)) as task:
|
|
659
|
-
# Yield acknowledgment that task was received and processing is starting
|
|
660
|
-
yield pb.worker.Response(ack=pb.worker.Ack())
|
|
661
|
-
|
|
662
|
-
try:
|
|
663
|
-
result = pb.task.Result(dump=cloudpickle.dumps(await task))
|
|
664
|
-
yield pb.worker.Response(result=result)
|
|
665
|
-
except Exception as e:
|
|
666
|
-
exception = pb.task.Exception(dump=cloudpickle.dumps(e))
|
|
667
|
-
yield pb.worker.Response(exception=exception)
|
|
668
|
-
|
|
669
|
-
async def stop(
|
|
670
|
-
self, request: pb.worker.StopRequest, context: ServicerContext
|
|
671
|
-
) -> pb.worker.Void:
|
|
672
|
-
"""Stop the worker service and its thread.
|
|
673
|
-
|
|
674
|
-
Gracefully shuts down the worker thread and signals the server
|
|
675
|
-
to stop accepting new requests. This method is idempotent and
|
|
676
|
-
can be called multiple times safely.
|
|
677
|
-
|
|
678
|
-
:param request:
|
|
679
|
-
The protobuf stop request containing the wait timeout.
|
|
680
|
-
:param context:
|
|
681
|
-
The :py:class:`grpc.aio.ServicerContext` for this request.
|
|
682
|
-
:returns:
|
|
683
|
-
An empty protobuf response indicating completion.
|
|
684
|
-
"""
|
|
685
|
-
if self._stopping.is_set():
|
|
686
|
-
return pb.worker.Void()
|
|
687
|
-
await self._stop(timeout=request.wait)
|
|
688
|
-
return pb.worker.Void()
|
|
689
|
-
|
|
690
|
-
async def _stop(self, *, timeout: float | None = 0) -> None:
|
|
691
|
-
self._stopping.set()
|
|
692
|
-
await self._await_or_cancel_tasks(timeout=timeout)
|
|
693
|
-
|
|
694
|
-
# Clean up the session cache to prevent issues during shutdown
|
|
695
|
-
try:
|
|
696
|
-
proxy_pool = wool.__proxy_pool__.get()
|
|
697
|
-
assert proxy_pool
|
|
698
|
-
await proxy_pool.clear()
|
|
699
|
-
finally:
|
|
700
|
-
self._stopped.set()
|
|
701
|
-
|
|
702
|
-
async def _await_or_cancel_tasks(self, *, timeout: float | None = 0) -> None:
|
|
703
|
-
"""Stop the worker service gracefully.
|
|
704
|
-
|
|
705
|
-
Gracefully shuts down the worker service by canceling or waiting
|
|
706
|
-
for running tasks. This method is idempotent and can be called
|
|
707
|
-
multiple times safely.
|
|
708
|
-
|
|
709
|
-
:param timeout:
|
|
710
|
-
Maximum time to wait for tasks to complete. If 0 (default),
|
|
711
|
-
tasks are canceled immediately. If None, waits indefinitely.
|
|
712
|
-
If a positive number, waits for that many seconds before
|
|
713
|
-
canceling tasks.
|
|
714
|
-
|
|
715
|
-
.. note::
|
|
716
|
-
If a timeout occurs while waiting for tasks to complete,
|
|
717
|
-
the method recursively calls itself with a timeout of 0
|
|
718
|
-
to cancel all remaining tasks immediately.
|
|
719
|
-
"""
|
|
720
|
-
if self._tasks and timeout == 0:
|
|
721
|
-
await self._cancel(*self._tasks)
|
|
722
|
-
elif self._tasks:
|
|
723
|
-
try:
|
|
724
|
-
await asyncio.wait_for(
|
|
725
|
-
asyncio.gather(*self._tasks, return_exceptions=True),
|
|
726
|
-
timeout=timeout,
|
|
727
|
-
)
|
|
728
|
-
except asyncio.TimeoutError:
|
|
729
|
-
return await self._await_or_cancel_tasks(timeout=0)
|
|
730
|
-
|
|
731
|
-
async def _cancel(self, *tasks: asyncio.Task):
|
|
732
|
-
"""Cancel multiple tasks safely.
|
|
733
|
-
|
|
734
|
-
Cancels the provided tasks while performing safety checks to
|
|
735
|
-
avoid canceling the current task or already completed tasks.
|
|
736
|
-
Waits for all cancelled tasks to complete in parallel and handles
|
|
737
|
-
cancellation exceptions.
|
|
738
|
-
|
|
739
|
-
:param tasks:
|
|
740
|
-
The :py:class:`asyncio.Task` instances to cancel.
|
|
741
|
-
|
|
742
|
-
.. note::
|
|
743
|
-
This method performs the following safety checks:
|
|
744
|
-
- Avoids canceling the current task (would cause deadlock)
|
|
745
|
-
- Only cancels tasks that are not already done
|
|
746
|
-
- Properly handles :py:exc:`asyncio.CancelledError`
|
|
747
|
-
exceptions.
|
|
748
|
-
"""
|
|
749
|
-
current = asyncio.current_task()
|
|
750
|
-
to_cancel = [task for task in tasks if not task.done() and task != current]
|
|
751
|
-
|
|
752
|
-
# Cancel all tasks first
|
|
753
|
-
for task in to_cancel:
|
|
754
|
-
task.cancel()
|
|
755
|
-
|
|
756
|
-
# Wait for all cancelled tasks in parallel
|
|
757
|
-
if to_cancel:
|
|
758
|
-
await asyncio.gather(*to_cancel, return_exceptions=True)
|
|
759
|
-
|
|
760
|
-
|
|
761
|
-
DispatchCall: TypeAlias = grpc.aio.UnaryStreamCall[pb.task.Task, pb.worker.Response]
|
|
762
|
-
|
|
763
|
-
|
|
764
|
-
@asynccontextmanager
|
|
765
|
-
async def with_timeout(context, timeout):
|
|
766
|
-
"""Async context manager wrapper that adds timeout to context entry.
|
|
767
|
-
|
|
768
|
-
:param context:
|
|
769
|
-
The async context manager to wrap.
|
|
770
|
-
:param timeout:
|
|
771
|
-
Timeout in seconds for context entry.
|
|
772
|
-
:yields:
|
|
773
|
-
Control to the calling context.
|
|
774
|
-
:raises asyncio.TimeoutError:
|
|
775
|
-
If context entry exceeds the timeout.
|
|
776
|
-
"""
|
|
777
|
-
await asyncio.wait_for(context.__aenter__(), timeout=timeout)
|
|
778
|
-
exception_type = exception_value = exception_traceback = None
|
|
779
|
-
try:
|
|
780
|
-
yield
|
|
781
|
-
except BaseException as exception:
|
|
782
|
-
exception_type = type(exception)
|
|
783
|
-
exception_value = exception
|
|
784
|
-
exception_traceback = exception.__traceback__
|
|
785
|
-
raise
|
|
786
|
-
finally:
|
|
787
|
-
await context.__aexit__(exception_type, exception_value, exception_traceback)
|
|
788
|
-
|
|
789
|
-
|
|
790
|
-
T = TypeVar("T")
|
|
791
|
-
|
|
792
|
-
|
|
793
|
-
class DispatchStream(Generic[T]):
|
|
794
|
-
"""Async iterator wrapper for streaming dispatch results.
|
|
795
|
-
|
|
796
|
-
Simplified wrapper that focuses solely on stream iteration and response handling.
|
|
797
|
-
Channel management is now handled by the WorkerClient.
|
|
798
|
-
"""
|
|
799
|
-
|
|
800
|
-
def __init__(self, stream: DispatchCall):
|
|
801
|
-
"""Initialize the streaming dispatch result wrapper.
|
|
802
|
-
|
|
803
|
-
:param stream:
|
|
804
|
-
The underlying gRPC response stream.
|
|
805
|
-
"""
|
|
806
|
-
self._stream = stream
|
|
807
|
-
self._iter = aiter(stream)
|
|
808
|
-
|
|
809
|
-
def __aiter__(self) -> AsyncIterator[T]:
|
|
810
|
-
"""Return self as the async iterator."""
|
|
811
|
-
return self
|
|
812
|
-
|
|
813
|
-
async def __anext__(self) -> T:
|
|
814
|
-
"""Get the next response from the stream.
|
|
815
|
-
|
|
816
|
-
:returns:
|
|
817
|
-
The next task result from the worker.
|
|
818
|
-
:raises StopAsyncIteration:
|
|
819
|
-
When the stream is exhausted.
|
|
820
|
-
"""
|
|
821
|
-
try:
|
|
822
|
-
response = await anext(self._iter)
|
|
823
|
-
if response.HasField("result"):
|
|
824
|
-
return cloudpickle.loads(response.result.dump)
|
|
825
|
-
elif response.HasField("exception"):
|
|
826
|
-
raise cloudpickle.loads(response.exception.dump)
|
|
827
|
-
else:
|
|
828
|
-
raise RuntimeError(f"Received unexpected response: {response}")
|
|
829
|
-
except Exception as exception:
|
|
830
|
-
await self._handle_exception(exception)
|
|
831
|
-
|
|
832
|
-
async def _handle_exception(self, exception):
|
|
833
|
-
try:
|
|
834
|
-
self._stream.cancel()
|
|
835
|
-
except Exception as cancel_exception:
|
|
836
|
-
raise cancel_exception from exception
|
|
837
|
-
else:
|
|
838
|
-
raise exception
|
|
839
|
-
|
|
840
|
-
|
|
841
|
-
class WorkerClient:
|
|
842
|
-
"""Client for dispatching tasks to a specific worker.
|
|
843
|
-
|
|
844
|
-
Simplified client that maintains a persistent gRPC channel to a single
|
|
845
|
-
worker. The client manages the channel lifecycle and provides task
|
|
846
|
-
dispatch functionality with proper error handling.
|
|
847
|
-
|
|
848
|
-
:param address:
|
|
849
|
-
The network address of the target worker in "host:port" format.
|
|
850
|
-
"""
|
|
851
|
-
|
|
852
|
-
def __init__(self, address: str):
|
|
853
|
-
self._channel = grpc.aio.insecure_channel(
|
|
854
|
-
address,
|
|
855
|
-
# options=[
|
|
856
|
-
# ("grpc.keepalive_time_ms", 10000),
|
|
857
|
-
# ("grpc.keepalive_timeout_ms", 5000),
|
|
858
|
-
# ("grpc.http2.max_pings_without_data", 0),
|
|
859
|
-
# ("grpc.http2.min_time_between_pings_ms", 10000),
|
|
860
|
-
# ("grpc.max_receive_message_length", 100 * 1024 * 1024),
|
|
861
|
-
# ("grpc.max_send_message_length", 100 * 1024 * 1024),
|
|
862
|
-
# ],
|
|
863
|
-
)
|
|
864
|
-
self._stub = pb.worker.WorkerStub(self._channel)
|
|
865
|
-
self._semaphore = asyncio.Semaphore(100)
|
|
866
|
-
|
|
867
|
-
async def dispatch(self, task: WoolTask) -> AsyncIterator[pb.task.Result]:
|
|
868
|
-
"""Dispatch task to worker with on-demand channel acquisition.
|
|
869
|
-
|
|
870
|
-
Acquires a channel from the global channel pool, creates a WorkerStub,
|
|
871
|
-
dispatches the task, and verifies the first response is an Ack.
|
|
872
|
-
The channel is automatically managed by the underlying infrastructure.
|
|
873
|
-
|
|
874
|
-
:param task:
|
|
875
|
-
The WoolTask to dispatch to the worker.
|
|
876
|
-
:returns:
|
|
877
|
-
A DispatchStream for reading task results.
|
|
878
|
-
:raises RuntimeError:
|
|
879
|
-
If the worker doesn't acknowledge the task.
|
|
880
|
-
"""
|
|
881
|
-
async with with_timeout(self._semaphore, timeout=60):
|
|
882
|
-
call: DispatchCall = self._stub.dispatch(task.to_protobuf())
|
|
883
|
-
|
|
884
|
-
try:
|
|
885
|
-
first_response = await asyncio.wait_for(anext(aiter(call)), timeout=60)
|
|
886
|
-
if not first_response.HasField("ack"):
|
|
887
|
-
raise UnexpectedResponse("Expected Ack response")
|
|
888
|
-
except (
|
|
889
|
-
asyncio.CancelledError,
|
|
890
|
-
asyncio.TimeoutError,
|
|
891
|
-
grpc.aio.AioRpcError,
|
|
892
|
-
UnexpectedResponse,
|
|
893
|
-
):
|
|
894
|
-
try:
|
|
895
|
-
call.cancel()
|
|
896
|
-
except Exception:
|
|
897
|
-
pass
|
|
898
|
-
raise
|
|
899
|
-
|
|
900
|
-
async for result in DispatchStream(call):
|
|
901
|
-
yield result
|
|
902
|
-
|
|
903
|
-
async def stop(self):
|
|
904
|
-
"""Stop the client and close the gRPC channel.
|
|
905
|
-
|
|
906
|
-
Gracefully closes the underlying gRPC channel and cleans up
|
|
907
|
-
any resources associated with this client.
|
|
908
|
-
"""
|
|
909
|
-
await self._channel.close()
|
|
910
|
-
|
|
911
|
-
|
|
912
|
-
class UnexpectedResponse(Exception): ...
|