wool 0.1rc20__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wool/__init__.py +122 -0
- wool/_context.py +29 -0
- wool/_protobuf/worker.py +26 -0
- wool/_resource_pool.py +376 -0
- wool/_typing.py +7 -0
- wool/_undefined.py +11 -0
- wool/_work.py +554 -0
- wool/core/__init__.py +0 -0
- wool/core/discovery/__init__.py +0 -0
- wool/core/discovery/base.py +249 -0
- wool/core/discovery/lan.py +534 -0
- wool/core/discovery/local.py +822 -0
- wool/core/loadbalancer/__init__.py +0 -0
- wool/core/loadbalancer/base.py +125 -0
- wool/core/loadbalancer/roundrobin.py +101 -0
- wool/core/protobuf/__init__.py +18 -0
- wool/core/protobuf/exception.py +3 -0
- wool/core/protobuf/task.py +11 -0
- wool/core/protobuf/task_pb2.py +42 -0
- wool/core/protobuf/task_pb2.pyi +43 -0
- wool/core/protobuf/task_pb2_grpc.py +24 -0
- wool/core/protobuf/worker.py +26 -0
- wool/core/protobuf/worker_pb2.py +53 -0
- wool/core/protobuf/worker_pb2.pyi +65 -0
- wool/core/protobuf/worker_pb2_grpc.py +141 -0
- wool/core/typing.py +22 -0
- wool/core/worker/__init__.py +0 -0
- wool/core/worker/base.py +300 -0
- wool/core/worker/connection.py +250 -0
- wool/core/worker/local.py +148 -0
- wool/core/worker/pool.py +386 -0
- wool/core/worker/process.py +249 -0
- wool/core/worker/proxy.py +427 -0
- wool/core/worker/service.py +231 -0
- wool-0.1rc20.dist-info/METADATA +463 -0
- wool-0.1rc20.dist-info/RECORD +38 -0
- wool-0.1rc20.dist-info/WHEEL +4 -0
- wool-0.1rc20.dist-info/entry_points.txt +2 -0
|
@@ -0,0 +1,148 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
from types import MappingProxyType
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
import grpc.aio
|
|
8
|
+
|
|
9
|
+
import wool
|
|
10
|
+
from wool.core import protobuf as pb
|
|
11
|
+
from wool.core.discovery.base import WorkerInfo
|
|
12
|
+
from wool.core.worker.base import Worker
|
|
13
|
+
from wool.core.worker.process import WorkerProcess
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
# public
|
|
17
|
+
class LocalWorker(Worker):
    """Worker running in a local subprocess.

    Spawns a dedicated process hosting a gRPC server for task execution.
    Handles multiple concurrent tasks in an isolated asyncio event loop.

    **Basic usage:**

    .. code-block:: python

        worker = LocalWorker("gpu-capable")
        await worker.start()
        # Worker is now accepting tasks
        await worker.stop()

    **Custom configuration:**

    .. code-block:: python

        worker = LocalWorker(
            "production",
            "high-memory",
            host="0.0.0.0",  # Listen on all interfaces
            port=50051,  # Fixed port
            shutdown_grace_period=30.0,
        )

    :param tags:
        Capability tags for filtering and selection.
    :param host:
        Host address to bind. Defaults to localhost.
    :param port:
        Port to bind. 0 for random available port.
    :param shutdown_grace_period:
        Graceful shutdown timeout in seconds.
    :param proxy_pool_ttl:
        Proxy pool TTL in seconds.
    :param extra:
        Additional metadata as key-value pairs.
    """

    # Handle on the subprocess that hosts the worker's gRPC server.
    _worker_process: WorkerProcess

    def __init__(
        self,
        *tags: str,
        host: str = "127.0.0.1",
        port: int = 0,
        shutdown_grace_period: float = 60.0,
        proxy_pool_ttl: float = 60.0,
        **extra: Any,
    ):
        # Tags and extra metadata are handled by the base class; the network
        # and lifecycle settings are forwarded to the process wrapper.
        super().__init__(*tags, **extra)
        self._worker_process = WorkerProcess(
            host=host,
            port=port,
            shutdown_grace_period=shutdown_grace_period,
            proxy_pool_ttl=proxy_pool_ttl,
        )

    @property
    def address(self) -> str | None:
        """The network address where the worker is listening.

        :returns:
            The address reported by the worker process, in "host:port"
            format, or None if not started.
        """
        return self._worker_process.address

    @property
    def host(self) -> str | None:
        """The host where the worker is listening.

        :returns:
            The host recorded in the worker info, or None if no info
            has been recorded yet.
        """
        return self._info.host if self._info else None

    @property
    def port(self) -> int | None:
        """The port where the worker is listening.

        :returns:
            The port recorded in the worker info, or None if no info
            has been recorded yet.
        """
        return self._info.port if self._info else None

    async def _start(self, timeout: float | None):
        """Start the worker process and record its connection info.

        Runs the blocking ``WorkerProcess.start`` call in the event loop's
        default executor, then builds a :class:`WorkerInfo` from the
        address and PID the process reports.

        :param timeout:
            Maximum time in seconds to wait for worker process startup.
        :raises RuntimeError:
            If the process reports no address or no PID after starting.
        """
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(
            None, lambda t: self._worker_process.start(timeout=t), timeout
        )

        address = self._worker_process.address
        if not address:
            raise RuntimeError("Worker process failed to start - no address")
        pid = self._worker_process.pid
        if not pid:
            raise RuntimeError("Worker process failed to start - no PID")

        # Split on the LAST colon only: a host part may itself contain
        # colons (e.g. a bracketed IPv6 literal such as "[::1]:50051").
        host, port_str = address.rsplit(":", 1)
        port = int(port_str)

        # NOTE(review): _uid, _tags and _extra are presumably set by the
        # Worker base initializer - confirm against wool.core.worker.base.
        self._info = WorkerInfo(
            uid=self._uid,
            host=host,
            port=port,
            pid=pid,
            version=wool.__version__,
            tags=frozenset(self._tags),
            extra=MappingProxyType(self._extra),
        )

    async def _stop(self, timeout: float | None):
        """Ask the worker process to stop via its gRPC ``stop`` endpoint.

        Does nothing when the worker process is not alive. Otherwise opens
        an insecure channel to the worker's address, sends a ``StopRequest``
        and closes the channel again once the call returns.

        :param timeout:
            Timeout value forwarded to the worker in the ``StopRequest``.
        """
        if self._worker_process.is_alive():
            assert self.address
            # The channel is an async context manager; leaving the block
            # closes it so the connection is not leaked.
            async with grpc.aio.insecure_channel(self.address) as channel:
                stub = pb.worker.WorkerStub(channel)
                await stub.stop(pb.worker.StopRequest(timeout=timeout))
|
wool/core/worker/pool.py
ADDED
|
@@ -0,0 +1,386 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import os
|
|
5
|
+
import sys
|
|
6
|
+
import uuid
|
|
7
|
+
from contextlib import asynccontextmanager
|
|
8
|
+
from typing import AsyncContextManager
|
|
9
|
+
from typing import Awaitable
|
|
10
|
+
from typing import ContextManager
|
|
11
|
+
from typing import Coroutine
|
|
12
|
+
from typing import Final
|
|
13
|
+
from typing import overload
|
|
14
|
+
|
|
15
|
+
from wool.core.discovery.base import DiscoveryLike
|
|
16
|
+
from wool.core.discovery.base import DiscoveryPublisherLike
|
|
17
|
+
from wool.core.discovery.local import LocalDiscovery
|
|
18
|
+
from wool.core.typing import Factory
|
|
19
|
+
from wool.core.worker.base import WorkerFactory
|
|
20
|
+
from wool.core.worker.base import WorkerLike
|
|
21
|
+
from wool.core.worker.local import LocalWorker
|
|
22
|
+
from wool.core.worker.proxy import LoadBalancerLike
|
|
23
|
+
from wool.core.worker.proxy import RoundRobinLoadBalancer
|
|
24
|
+
from wool.core.worker.proxy import WorkerProxy
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
# public
|
|
28
|
+
class WorkerPool:
    """Orchestrates distributed workers for task execution.

    The core of wool's distributed runtime. Manages worker lifecycle,
    discovery, and load balancing across two modes:

    **Ephemeral pools** spawn local workers automatically managed within the
    pool's lifecycle. Perfect for development and single-machine deployments.

    **Durable pools** connect to existing remote workers through discovery
    services. Workers run independently, serving multiple clients across
    distributed deployments.

    **Basic ephemeral pool:**

    .. code-block:: python

        import wool


        @wool.work
        async def fibonacci(n: int) -> int:
            if n <= 1:
                return n
            a = await fibonacci(n - 1)
            b = await fibonacci(n - 2)
            return a + b


        async with wool.WorkerPool() as pool:
            result = await fibonacci(10)

    **Ephemeral with tags:**

    .. code-block:: python

        async with WorkerPool("gpu-capable", size=4) as pool:
            result = await gpu_task()

    **Custom worker factory:**

    .. code-block:: python

        from functools import partial

        worker_factory = partial(LocalWorker, host="0.0.0.0")

        async with WorkerPool(size=8, worker=worker_factory) as pool:
            result = await task()

    **Durable pool:**

    .. code-block:: python

        from wool.core.discovery.lan import LanDiscovery

        async with WorkerPool(discovery=LanDiscovery()) as pool:
            result = await task()

    **Filtered discovery:**

    .. code-block:: python

        discovery = LanDiscovery().subscribe(filter=lambda w: "production" in w.tags)
        async with WorkerPool(discovery=discovery) as pool:
            result = await task()

    **Hybrid pool:**

    .. code-block:: python

        # Spawn local workers AND discover remote workers
        async with WorkerPool(size=4, discovery=LanDiscovery()) as pool:
            result = await task()

    **Custom load balancer:**

    .. code-block:: python

        from wool.core.loadbalancer.roundrobin import RoundRobinLoadBalancer


        class PriorityBalancer(RoundRobinLoadBalancer):
            async def dispatch(self, task, context, timeout=None):
                # Custom routing logic
                ...


        async with WorkerPool(loadbalancer=PriorityBalancer()) as pool:
            result = await task()

    **Custom discovery:**

    .. code-block:: python

        from contextlib import asynccontextmanager


        @asynccontextmanager
        async def custom_discovery():
            svc = await DatabaseDiscovery.connect()
            try:
                yield svc.subscribe()
            finally:
                await svc.close()


        async with WorkerPool(discovery=custom_discovery) as pool:
            result = await task()

    :param tags:
        Capability tags for spawned workers.
    :param size:
        Number of workers to spawn (0 = CPU count).
    :param worker:
        Worker factory callable. Defaults to :class:`LocalWorker`.
    :param loadbalancer:
        Load balancer instance, factory, or context manager.
    :param discovery:
        Discovery service instance, factory, or context manager.
    :raises ValueError:
        If configuration is invalid or CPU count unavailable.
    """

    # Maps each spawned worker to the coroutine that will drop and stop it.
    _workers: Final[dict[WorkerLike, Coroutine]]

    @overload
    def __init__(
        self,
        *tags: str,
        size: int = 0,
        worker: WorkerFactory = LocalWorker,
        discovery: DiscoveryLike | Factory[DiscoveryLike] | None = None,
        loadbalancer: (
            LoadBalancerLike | Factory[LoadBalancerLike]
        ) = RoundRobinLoadBalancer,
    ):
        """
        Create an ephemeral pool of workers, spawning the specified quantity of workers
        using the specified worker factory.
        """
        ...

    @overload
    def __init__(
        self,
        *,
        discovery: DiscoveryLike | Factory[DiscoveryLike],
        loadbalancer: (
            LoadBalancerLike | Factory[LoadBalancerLike]
        ) = RoundRobinLoadBalancer,
    ):
        """
        Connect to an existing pool of workers discovered by the specified discovery
        protocol.
        """
        ...

    def __init__(
        self,
        *tags: str,
        size: int | None = None,
        worker: WorkerFactory | None = None,
        discovery: DiscoveryLike | Factory[DiscoveryLike] | None = None,
        loadbalancer: (
            LoadBalancerLike | Factory[LoadBalancerLike]
        ) = RoundRobinLoadBalancer,
    ):
        # Nothing is started here: the constructor only validates ``size``
        # and selects the async context-manager factory that ``__aenter__``
        # will run. The four branches cover every (size, discovery) combo.
        self._workers = {}

        if size is not None:
            size = self._resolve_size(size)

        if size is not None and discovery is not None:
            # Hybrid: spawn local workers, publish them through the caller's
            # discovery service and subscribe to it filtered by tags.

            @asynccontextmanager
            async def create_proxy():
                discovery_svc, discovery_ctx = await self._enter_context(discovery)
                try:
                    # Checked inside the try so the entered discovery
                    # context is still exited when the type is wrong.
                    if not isinstance(discovery_svc, DiscoveryLike):
                        raise TypeError(
                            f"Expected DiscoveryLike, got: {type(discovery_svc)}"
                        )
                    async with self._worker_context(
                        *tags,
                        size=size,
                        factory=worker,
                        publisher=discovery_svc.publisher,
                    ):
                        async with WorkerProxy(
                            discovery=discovery_svc.subscribe(_predicate(tags)),
                            loadbalancer=loadbalancer,
                        ):
                            yield
                finally:
                    await self._exit_context(discovery_ctx)

        elif size is not None:
            # Ephemeral with an explicit size: private local discovery in a
            # unique namespace, subscribed with the tag filter.
            namespace = f"pool-{uuid.uuid4().hex}"

            @asynccontextmanager
            async def create_proxy():
                local_discovery = LocalDiscovery(namespace)
                async with self._worker_context(
                    *tags,
                    size=size,
                    factory=worker,
                    publisher=local_discovery.publisher,
                ):
                    async with WorkerProxy(
                        discovery=local_discovery.subscribe(_predicate(tags)),
                        loadbalancer=loadbalancer,
                    ):
                        yield

        elif discovery is not None:
            # Durable: spawn nothing, only consume the discovery service.

            @asynccontextmanager
            async def create_proxy():
                discovery_svc, discovery_ctx = await self._enter_context(discovery)
                try:
                    if not isinstance(discovery_svc, DiscoveryLike):
                        raise ValueError(
                            f"Expected DiscoveryLike, got: {type(discovery_svc)}"
                        )
                    async with WorkerProxy(
                        discovery=discovery_svc.subscriber,
                        loadbalancer=loadbalancer,
                    ):
                        yield
                finally:
                    await self._exit_context(discovery_ctx)

        else:
            # Default: one local worker per CPU, private local discovery,
            # unfiltered subscriber.
            size = self._resolve_size(0)
            namespace = f"pool-{uuid.uuid4().hex}"

            @asynccontextmanager
            async def create_proxy():
                local_discovery = LocalDiscovery(namespace)
                async with self._worker_context(
                    *tags,
                    size=size,
                    factory=worker,
                    publisher=local_discovery.publisher,
                ):
                    async with WorkerProxy(
                        discovery=local_discovery.subscriber,
                        loadbalancer=loadbalancer,
                    ):
                        yield

        self._proxy_factory = create_proxy

    @staticmethod
    def _resolve_size(size: int) -> int:
        """Turn the requested pool size into a concrete worker count.

        :param size:
            Requested number of workers; 0 means "use the CPU count".
        :returns:
            ``size`` itself when positive, otherwise ``os.cpu_count()``.
        :raises ValueError:
            If ``size`` is negative, or it is 0 and the CPU count cannot
            be determined.
        """
        if size < 0:
            raise ValueError("Size must be non-negative")
        if size == 0:
            cpu_count = os.cpu_count()
            if cpu_count is None:
                raise ValueError("Unable to determine CPU count")
            return cpu_count
        return size

    async def __aenter__(self) -> WorkerPool:
        """Start the pool by entering the context chosen in ``__init__``.

        Depending on the configuration this spawns local workers, enters
        the discovery service and opens the worker proxy.

        :returns:
            The :class:`WorkerPool` instance itself.
        """
        self._proxy_context = self._proxy_factory()
        await self._proxy_context.__aenter__()
        return self

    async def __aexit__(self, *args):
        """Stops all workers and tears down the pool and its services."""
        await self._proxy_context.__aexit__(*args)

    @asynccontextmanager
    async def _worker_context(
        self,
        *tags: str,
        size: int,
        factory: WorkerFactory | None,
        publisher: DiscoveryPublisherLike,
    ):
        """Spawn ``size`` workers for the duration of the context.

        Each worker is started concurrently and announced with a
        ``"worker-added"`` event; on exit every worker is announced with
        ``"worker-dropped"`` and stopped, then the publisher context (if
        any) is exited.

        :param tags:
            Tags passed to the worker factory for every worker.
        :param size:
            Number of workers to create.
        :param factory:
            Worker factory; ``None`` selects the default
            :class:`LocalWorker` factory.
        :param publisher:
            Discovery publisher, or a factory / context manager / awaitable
            producing one (resolved through ``_enter_context``).
        :raises ValueError:
            If the resolved publisher is not a ``DiscoveryPublisherLike``.
        :yields:
            The ``info`` of every worker that has one after startup.
        """
        if factory is None:
            factory = self._default_worker_factory()
        publisher_svc, publisher_ctx = await self._enter_context(publisher)

        async def start(worker):
            await worker.start()
            # Publish through the ENTERED service: ``publisher`` itself may
            # be a factory or context manager without a ``publish`` method.
            await publisher_svc.publish("worker-added", worker.info)

        async def stop(worker):
            try:
                await publisher_svc.publish("worker-dropped", worker.info)
            finally:
                # Stop the worker even when the drop event cannot be
                # published, so it is not left running.
                await worker.stop()

        start_tasks = []
        try:
            if not isinstance(publisher_svc, DiscoveryPublisherLike):
                raise ValueError(
                    f"Expected DiscoveryPublisherLike, got: {type(publisher_svc)}"
                )

            for _ in range(size):
                worker = factory(*tags)
                self._workers[worker] = stop(worker)
                start_tasks.append(asyncio.create_task(start(worker)))

            # Start-up failures are collected rather than raised; workers
            # without ``info`` are simply left out of the yielded list.
            await asyncio.gather(*start_tasks, return_exceptions=True)
            yield [w.info for w in self._workers if w.info]
        finally:
            try:
                # Let any still-running start-up settle before tearing
                # down (a no-op on the normal path).
                await asyncio.gather(*start_tasks, return_exceptions=True)
                stop_tasks = [
                    asyncio.create_task(stop_coro)
                    for stop_coro in self._workers.values()
                ]
                # A coroutine can be awaited only once; forget them so the
                # pool can be entered again.
                self._workers.clear()
                await asyncio.gather(*stop_tasks, return_exceptions=True)
            finally:
                await self._exit_context(publisher_ctx)

    def _default_worker_factory(self):
        """Return a factory that builds a :class:`LocalWorker` from tags.

        Keyword arguments passed to the returned factory are ignored.
        """

        def factory(*tags, **_):
            return LocalWorker(*tags)

        return factory

    async def _enter_context(self, factory):
        """Resolve ``factory`` into a usable object and its context, if any.

        The checks run in this order:

        1. sync context manager - entered with ``__enter__``;
        2. async context manager - entered with ``__aenter__``;
        3. callable - called, and the result resolved recursively;
        4. awaitable - awaited;
        5. anything else - used as-is.

        :returns:
            ``(obj, ctx)`` where ``ctx`` is the context manager that was
            entered, or ``None`` when there is nothing to exit later.
        """
        ctx = None
        if isinstance(factory, ContextManager):
            ctx = factory
            obj = ctx.__enter__()
        elif isinstance(factory, AsyncContextManager):
            ctx = factory
            obj = await ctx.__aenter__()
        elif callable(factory):
            return await self._enter_context(factory())
        elif isinstance(factory, Awaitable):
            obj = await factory
        else:
            obj = factory
        return obj, ctx

    async def _exit_context(self, ctx: AsyncContextManager | ContextManager | None):
        """Exit a context previously entered by ``_enter_context``.

        The current exception info is forwarded to the context's exit
        method. ``None`` is accepted and ignored.
        """
        # Same precedence as ``_enter_context`` (sync before async), so an
        # object implementing both protocols is exited the way it was
        # entered.
        if isinstance(ctx, ContextManager):
            ctx.__exit__(*sys.exc_info())
        elif isinstance(ctx, AsyncContextManager):
            await ctx.__aexit__(*sys.exc_info())
|
|
383
|
+
|
|
384
|
+
|
|
385
|
+
def _predicate(tags):
|
|
386
|
+
return lambda w: bool(w.tags & set(tags)) if tags else True
|