wool-0.1rc13-py3-none-any.whl → wool-0.1rc15-py3-none-any.whl
This diff shows the content of publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
Potentially problematic release.
- wool/__init__.py +1 -26
- wool/_connection.py +247 -0
- wool/_context.py +29 -0
- wool/_loadbalancer.py +213 -0
- wool/_protobuf/task_pb2_grpc.py +2 -2
- wool/_protobuf/worker_pb2.py +6 -6
- wool/_protobuf/worker_pb2.pyi +4 -4
- wool/_protobuf/worker_pb2_grpc.py +2 -2
- wool/_resource_pool.py +3 -3
- wool/_undefined.py +11 -0
- wool/_work.py +5 -4
- wool/_worker.py +115 -426
- wool/_worker_discovery.py +24 -36
- wool/_worker_pool.py +46 -31
- wool/_worker_proxy.py +54 -186
- wool/_worker_service.py +243 -0
- {wool-0.1rc13.dist-info → wool-0.1rc15.dist-info}/METADATA +153 -41
- wool-0.1rc15.dist-info/RECORD +27 -0
- wool-0.1rc13.dist-info/RECORD +0 -22
- {wool-0.1rc13.dist-info → wool-0.1rc15.dist-info}/WHEEL +0 -0
- {wool-0.1rc13.dist-info → wool-0.1rc15.dist-info}/entry_points.txt +0 -0
wool/_worker_discovery.py
CHANGED
@@ -11,7 +11,7 @@ from abc import abstractmethod
 from collections import deque
 from dataclasses import dataclass
 from dataclasses import field
-from
+from types import MappingProxyType
 from typing import AsyncContextManager
 from typing import AsyncIterator
 from typing import Awaitable
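The newly imported MappingProxyType backs the immutable `extra` field introduced further down. A minimal sketch of its behaviour (the mapping contents here are made up):

from types import MappingProxyType

extra = MappingProxyType({"zone": "us-east"})
assert extra["zone"] == "us-east"  # reads behave like a normal dict
try:
    extra["zone"] = "eu-west"  # writes through the proxy are rejected
except TypeError:
    pass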
@@ -38,7 +38,7 @@ from zeroconf.asyncio import AsyncZeroconf
 
 
 # public
-@dataclass
+@dataclass(frozen=True)
 class WorkerInfo:
     """Properties and metadata for a worker instance.
 
@@ -56,21 +56,20 @@ class WorkerInfo:
     :param version:
         Version string of the worker software.
     :param tags:
-
+        Frozenset of capability tags for worker filtering and selection.
     :param extra:
-        Additional arbitrary metadata as key-value pairs.
+        Additional arbitrary metadata as immutable key-value pairs.
     """
 
     uid: str
-    host: str
-    port: int | None
-    pid: int
-    version: str
-    tags:
-    extra:
-
-    def __hash__(self):
-        return hash(self.uid)
+    host: str = field(hash=False)
+    port: int | None = field(hash=False)
+    pid: int = field(hash=False)
+    version: str = field(hash=False)
+    tags: frozenset[str] = field(default_factory=frozenset, hash=False)
+    extra: MappingProxyType = field(
+        default_factory=lambda: MappingProxyType({}), hash=False
+    )
 
 
 # public
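The combination of `frozen=True` with `field(hash=False)` on every field except `uid` means instances now hash on `uid` alone while equality still compares all fields, replacing the explicit `__hash__` the class defined before. A minimal sketch of that behaviour, independent of the wool codebase:

from dataclasses import dataclass, field

@dataclass(frozen=True)
class Info:
    uid: str
    host: str = field(hash=False)

a = Info("worker-1", "10.0.0.1")
b = Info("worker-1", "10.0.0.2")

assert hash(a) == hash(b)  # host is excluded from the generated __hash__
assert a != b              # __eq__ still compares host
assert len({a, b}) == 2    # equal hashes, but still two distinct set members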
@@ -376,6 +375,7 @@ class Discovery(ABC):
     when unpickled to support task-specific session contexts.
     """
 
+    _filter: PredicateFunction[WorkerInfo] | None
     _started: bool
     _service_cache: Dict[str, WorkerInfo]
 
@@ -399,6 +399,11 @@ class Discovery(ABC):
         """Delegate to the events async iterator."""
         return await anext(self.events())
 
+    @final
+    @property
+    def filter(self) -> PredicateFunction[WorkerInfo] | None:
+        return self._filter
+
     @final
     @property
     def started(self) -> bool:
@@ -991,46 +996,28 @@ class LocalRegistrar(Registrar):
 
     _shared_memory: multiprocessing.shared_memory.SharedMemory | None = None
     _uri: str
-    _created_shared_memory: bool = False
 
     def __init__(self, uri: str):
         super().__init__()
         self._uri = uri
-        self._created_shared_memory = False
 
     async def _start(self) -> None:
         """Initialize shared memory for worker registration."""
        if self._shared_memory is None:
            # Try to connect to existing shared memory first, create if it doesn't exist
            shared_memory_name = hashlib.sha256(self._uri.encode()).hexdigest()[:12]
-            try:
-                self._shared_memory = multiprocessing.shared_memory.SharedMemory(
-                    name=shared_memory_name
-                )
-            except FileNotFoundError:
-                # Create new shared memory if it doesn't exist
-                self._shared_memory = multiprocessing.shared_memory.SharedMemory(
-                    name=shared_memory_name,
-                    create=True,
-                    size=1024,  # 1024 bytes = 256 worker slots (4 bytes per port)
-                )
-                self._created_shared_memory = True
-                # Initialize all slots to 0 (empty)
-                for i in range(len(self._shared_memory.buf)):
-                    self._shared_memory.buf[i] = 0
+            self._shared_memory = multiprocessing.shared_memory.SharedMemory(
+                name=shared_memory_name
+            )
 
     async def _stop(self) -> None:
         """Clean up shared memory resources."""
         if self._shared_memory:
             try:
                 self._shared_memory.close()
-                # Unlink the shared memory if this registrar created it
-                if self._created_shared_memory:
-                    self._shared_memory.unlink()
             except Exception:
                 pass
             self._shared_memory = None
-            self._created_shared_memory = False
 
     async def _register(self, worker_info: WorkerInfo) -> None:
         """Register a worker by writing its port to shared memory.
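With the try/except removed, the registrar now only attaches to a segment the worker pool is expected to have created already; a missing segment surfaces as FileNotFoundError rather than being created here. A rough sketch of that attach-only behaviour (the uri is illustrative):

import hashlib
from multiprocessing import shared_memory

uri = "pool-example"
name = hashlib.sha256(uri.encode()).hexdigest()[:12]
try:
    segment = shared_memory.SharedMemory(name=name)  # attach only, no create=True
    segment.close()
except FileNotFoundError:
    print("the worker pool has not created the shared segment yet")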
@@ -1138,7 +1125,6 @@ class LocalDiscovery(Discovery):
     async def _start(self) -> None:
         """Starts monitoring shared memory for worker registrations."""
         if self._shared_memory is None:
-            # Try to connect to existing shared memory first
             self._shared_memory = multiprocessing.shared_memory.SharedMemory(
                 name=hashlib.sha256(self._uri.encode()).hexdigest()[:12]
             )
@@ -1172,7 +1158,9 @@ class LocalDiscovery(Discovery):
         # Read current state from shared memory
         if self._shared_memory:
             for i in range(0, len(self._shared_memory.buf), 4):
-                port = struct.unpack("I", bytes(self._shared_memory.buf[i : i + 4]))[0]
+                port = struct.unpack(
+                    "I", bytes(self._shared_memory.buf[i : i + 4])
+                )[0]
                 if port > 0:  # Active worker
                     worker_info = WorkerInfo(
                         uid=f"worker-{port}",
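For reference, the 4-byte slots read above round-trip with struct as shown below; each slot holds an unsigned 32-bit port number and 0 marks an empty slot (buffer and ports are made up):

import struct

buf = bytearray(12)                   # three 4-byte slots
struct.pack_into("I", buf, 0, 50051)  # a worker registered on port 50051
struct.pack_into("I", buf, 4, 50052)

ports = [
    struct.unpack("I", bytes(buf[i : i + 4]))[0]
    for i in range(0, len(buf), 4)
]
assert ports == [50051, 50052, 0]     # the last slot is still empty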
wool/_worker_pool.py
CHANGED
@@ -1,9 +1,11 @@
 from __future__ import annotations
 
 import asyncio
+import atexit
 import hashlib
 import os
 import uuid
+from contextlib import asynccontextmanager
 from multiprocessing.shared_memory import SharedMemory
 from typing import Final
 from typing import overload
@@ -172,7 +174,6 @@ class WorkerPool:
     """
 
     _workers: Final[list[Worker]]
-    _shared_memory = None
 
     @overload
     def __init__(
@@ -226,19 +227,27 @@ class WorkerPool:
 
                uri = f"pool-{uuid.uuid4().hex}"
 
+                @asynccontextmanager
                async def create_proxy():
-                    self._shared_memory = SharedMemory(
+                    shared_memory_size = (size + 1) * 4
+                    shared_memory = SharedMemory(
                        name=hashlib.sha256(uri.encode()).hexdigest()[:12],
                        create=True,
-                        size=1024,
-                    )
-                    for i in range(1024):
-                        self._shared_memory.buf[i] = 0
-                    await self._spawn_workers(uri, *tags, size=size, factory=worker)
-                    return WorkerProxy(
-                        discovery=LocalDiscovery(uri),
-                        loadbalancer=loadbalancer,
+                        size=shared_memory_size,
                    )
+                    cleanup = atexit.register(lambda: shared_memory.unlink())
+                    try:
+                        for i in range(shared_memory_size):
+                            shared_memory.buf[i] = 0
+                        await self._spawn_workers(uri, *tags, size=size, factory=worker)
+                        async with WorkerProxy(
+                            discovery=LocalDiscovery(uri),
+                            loadbalancer=loadbalancer,
+                        ):
+                            yield
+                    finally:
+                        shared_memory.unlink()
+                        atexit.unregister(cleanup)
 
        case (size, None) if size is not None:
            if size == 0:
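The segment is now sized from the pool size instead of a fixed 1024 bytes: `(size + 1) * 4` allocates one 4-byte slot per worker plus one extra slot. Worked example of the arithmetic:

size = 8                             # workers in the pool
shared_memory_size = (size + 1) * 4  # bytes
assert shared_memory_size == 36      # nine 4-byte slots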
@@ -251,27 +260,37 @@ class WorkerPool:
 
                uri = f"pool-{uuid.uuid4().hex}"
 
+                @asynccontextmanager
                async def create_proxy():
-                    self._shared_memory = SharedMemory(
+                    shared_memory_size = (size + 1) * 4
+                    shared_memory = SharedMemory(
                        name=hashlib.sha256(uri.encode()).hexdigest()[:12],
                        create=True,
-                        size=1024,
-                    )
-                    for i in range(1024):
-                        self._shared_memory.buf[i] = 0
-                    await self._spawn_workers(uri, *tags, size=size, factory=worker)
-                    return WorkerProxy(
-                        discovery=LocalDiscovery(uri),
-                        loadbalancer=loadbalancer,
+                        size=shared_memory_size,
                    )
+                    cleanup = atexit.register(lambda: shared_memory.unlink())
+                    try:
+                        for i in range(shared_memory_size):
+                            shared_memory.buf[i] = 0
+                        await self._spawn_workers(uri, *tags, size=size, factory=worker)
+                        async with WorkerProxy(
+                            discovery=LocalDiscovery(uri),
+                            loadbalancer=loadbalancer,
+                        ):
+                            yield
+                    finally:
+                        shared_memory.unlink()
+                        atexit.unregister(cleanup)
 
        case (None, discovery) if discovery is not None:
 
+            @asynccontextmanager
            async def create_proxy():
-                return WorkerProxy(
+                async with WorkerProxy(
                    discovery=discovery,
                    loadbalancer=loadbalancer,
-                )
+                ):
+                    yield
 
        case _:
            raise RuntimeError
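Both branches above use the same cleanup pattern: an async context manager owns the shared-memory segment, an atexit hook is registered as a fallback in case the process exits before `__aexit__` runs, and the hook is unregistered once the normal path has unlinked the segment. A reduced, generic sketch of that pattern (the names are illustrative, not wool APIs):

import atexit
from contextlib import asynccontextmanager

@asynccontextmanager
async def scoped(acquire, release):
    resource = acquire()
    fallback = atexit.register(lambda: release(resource))  # fallback cleanup
    try:
        yield resource
    finally:
        release(resource)            # normal cleanup path
        atexit.unregister(fallback)  # fallback no longer needed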
@@ -281,24 +300,20 @@ class WorkerPool:
    async def __aenter__(self) -> WorkerPool:
        """Starts the worker pool and its services, returning a session.
 
-        This method starts the worker registrar, creates a
+        This method starts the worker registrar, creates a connection,
        launches all worker processes, and registers them.
 
        :returns:
-            The :
+            The :class:`WorkerPool` instance itself for method chaining.
        """
-        self.
-        await self.
+        self._proxy_context = self._proxy_factory()
+        await self._proxy_context.__aenter__()
        return self
 
    async def __aexit__(self, *args):
        """Stops all workers and tears down the pool and its services."""
-
-        try:
-            await self._proxy.__aexit__(*args)
-        finally:
-            if self._shared_memory is not None:
-                self._shared_memory.unlink()
+        await self._stop_workers()
+        await self._proxy_context.__aexit__(*args)
 
    async def _spawn_workers(
        self, uri, *tags: str, size: int, factory: WorkerFactory | None
wool/_worker_proxy.py
CHANGED
@@ -1,31 +1,24 @@
 from __future__ import annotations
 
 import asyncio
-import itertools
 import uuid
 from typing import TYPE_CHECKING
 from typing import AsyncContextManager
 from typing import AsyncIterator
 from typing import Awaitable
-from typing import Callable
 from typing import ContextManager
-from typing import Final
 from typing import Generic
-from typing import Protocol
 from typing import Sequence
 from typing import TypeAlias
 from typing import TypeVar
 from typing import overload
-from typing import runtime_checkable
-
-import grpc
-import grpc.aio
 
 import wool
-from wool import
-from wool.
+from wool._connection import Connection
+from wool._loadbalancer import LoadBalancerContext
+from wool._loadbalancer import LoadBalancerLike
+from wool._loadbalancer import RoundRobinLoadBalancer
 from wool._resource_pool import ResourcePool
-from wool._worker import WorkerClient
 from wool._worker_discovery import DiscoveryEvent
 from wool._worker_discovery import DiscoveryLike
 from wool._worker_discovery import Factory
@@ -67,30 +60,30 @@ class ReducibleAsyncIterator(Generic[T]):
         return (self.__class__, (self._items,))
 
 
-async def
-    """Factory function for creating
+async def connection_factory(target: str) -> Connection:
+    """Factory function for creating worker connections.
 
-    Creates
-    The
+    Creates a connection to the specified worker target.
+    The target is passed as the key from ResourcePool.
 
-    :param
-        The network
+    :param target:
+        The network target (host:port) to create a channel for.
     :returns:
-        A new
+        A new connection to the target.
     """
-    return
+    return Connection(target)
 
 
-async def
+async def connection_finalizer(connection: Connection) -> None:
     """Finalizer function for gRPC channels.
 
-    Closes the gRPC
+    Closes the gRPC connection when it's being cleaned up from the resource pool.
 
-    :param
-        The gRPC
+    :param connection:
+        The gRPC connection to close.
     """
     try:
-        await
+        await connection.close()
     except Exception:
         pass
 
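These two functions are the factory/finalizer pair handed to ResourcePool later in this file. A rough, hypothetical analogue of the contract they satisfy is sketched below; it is not wool's actual ResourcePool implementation:

from typing import Awaitable, Callable, Generic, TypeVar

R = TypeVar("R")

class KeyedPool(Generic[R]):
    """Caches one resource per key; the finalizer runs when the pool is cleared."""

    def __init__(
        self,
        factory: Callable[[str], Awaitable[R]],
        finalizer: Callable[[R], Awaitable[None]],
    ) -> None:
        self._factory = factory
        self._finalizer = finalizer
        self._resources: dict[str, R] = {}

    async def get(self, key: str) -> R:
        if key not in self._resources:
            self._resources[key] = await self._factory(key)
        return self._resources[key]

    async def clear(self) -> None:
        for resource in self._resources.values():
            await self._finalizer(resource)
        self._resources.clear()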
@@ -98,134 +91,6 @@ async def client_finalizer(client: WorkerClient) -> None:
 WorkerUri: TypeAlias = str
 
 
-class NoWorkersAvailable(Exception):
-    """Raised when no workers are available for task dispatch.
-
-    This exception indicates that either no workers exist in the worker pool
-    or all available workers have been tried and failed with transient errors.
-    """
-
-
-# public
-@runtime_checkable
-class LoadBalancerLike(Protocol):
-    """Protocol for load balancer v2 that directly dispatches tasks.
-
-    This simplified protocol does not manage discovery services and instead
-    operates on a dynamic list of (worker_uri, WorkerInfo) tuples sorted by
-    worker_uri. It only defines a dispatch method that accepts a WoolTask and
-    returns a task result.
-
-    Expected constructor signature (see LoadBalancerV2Factory):
-        __init__(self, workers: list[tuple[str, WorkerInfo]])
-    """
-
-    def dispatch(self, task: WoolTask) -> AsyncIterator: ...
-
-    def worker_added_callback(
-        self, client: Callable[[], Resource[WorkerClient]], info: WorkerInfo
-    ): ...
-
-    def worker_updated_callback(
-        self, client: Callable[[], Resource[WorkerClient]], info: WorkerInfo
-    ): ...
-
-    def worker_removed_callback(self, info: WorkerInfo): ...
-
-
-DispatchCall: TypeAlias = grpc.aio.UnaryStreamCall[pb.task.Task, pb.worker.Response]
-
-
-class RoundRobinLoadBalancer:
-    """Round-robin load balancer for distributing tasks across workers.
-
-    Distributes tasks evenly across available workers using a simple round-robin
-    algorithm. Automatically handles worker failures by trying the next worker
-    when transient errors occur. Workers are dynamically managed through
-    callback methods for addition, updates, and removal.
-    """
-
-    TRANSIENT_ERRORS: Final = {
-        grpc.StatusCode.UNAVAILABLE,
-        grpc.StatusCode.DEADLINE_EXCEEDED,
-        grpc.StatusCode.RESOURCE_EXHAUSTED,
-    }
-
-    _current_index: int
-    _workers: dict[WorkerInfo, Callable[[], Resource[WorkerClient]]]
-
-    def __init__(self):
-        """Initialize the round-robin load balancer.
-
-        Sets up internal state for tracking workers and round-robin index.
-        Workers are managed dynamically through callback methods.
-        """
-        self._current_index = 0
-        self._workers = {}
-
-    async def dispatch(self, task: WoolTask) -> AsyncIterator:
-        """Dispatch a task to the next available worker using round-robin.
-
-        Tries all workers in one round-robin cycle. If a worker fails with a
-        transient error, continues to the next worker. Returns a streaming
-        result that automatically manages channel cleanup.
-
-        :param task:
-            The WoolTask to dispatch.
-        :returns:
-            A streaming dispatch result that yields worker responses.
-        :raises NoWorkersAvailable:
-            If no workers are available or all workers fail with transient errors.
-        """
-        # Track the first worker URI we try to detect when we've looped back
-        checkpoint = None
-
-        while self._workers:
-            self._current_index = self._current_index + 1
-            if self._current_index >= len(self._workers):
-                # Reset index if it's out of bounds
-                self._current_index = 0
-
-            worker_info, worker_resource = next(
-                itertools.islice(
-                    self._workers.items(), self._current_index, self._current_index + 1
-                )
-            )
-
-            # Check if we've looped back to the first worker we tried
-            if checkpoint is None:
-                checkpoint = worker_info.uid
-            elif worker_info.uid == checkpoint:
-                # We've tried all workers and looped back around
-                break
-
-            async with worker_resource() as worker:
-                async for result in worker.dispatch(task):
-                    yield result
-                return
-        else:
-            raise NoWorkersAvailable("No workers available for dispatch")
-
-        # If we get here, all workers failed with transient errors
-        raise NoWorkersAvailable(
-            f"All {len(self._workers)} workers failed with transient errors"
-        )
-
-    def worker_added_callback(
-        self, client: Callable[[], Resource[WorkerClient]], info: WorkerInfo
-    ):
-        self._workers[info] = client
-
-    def worker_updated_callback(
-        self, client: Callable[[], Resource[WorkerClient]], info: WorkerInfo
-    ):
-        self._workers[info] = client
-
-    def worker_removed_callback(self, info: WorkerInfo):
-        if info in self._workers:
-            del self._workers[info]
-
-
 # public
 class WorkerProxy:
     """Client-side interface for task dispatch to distributed workers.
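One subtlety worth noting in the removed dispatcher: it relied on Python's while/else, where the else branch runs only when the loop condition becomes false without a break. A tiny illustration:

workers = []          # no workers registered
while workers:
    workers.pop()
    break             # a successful dispatch leaves the loop here
else:
    raise RuntimeError("no workers available")  # the loop never ran, so else fires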
@@ -306,9 +171,8 @@ class WorkerProxy:
                "sequence of workers"
            )
 
-        self._id:
+        self._id: uuid.UUID = uuid.uuid4()
        self._started = False
-        self._workers: dict[WorkerInfo, Resource[WorkerClient]] = {}
        self._loadbalancer = loadbalancer
 
        match (pool_uri, discovery, workers):
@@ -328,6 +192,7 @@ class WorkerProxy:
                "pool_uri, discovery_event_stream, or workers"
            )
        self._sentinel_task: asyncio.Task[None] | None = None
+        self._loadbalancer_context: LoadBalancerContext | None = None
 
    async def __aenter__(self):
        """Starts the proxy and sets it as the active context."""
@@ -374,9 +239,12 @@ class WorkerProxy:
        return self._started
 
    @property
-    def workers(self) ->
+    def workers(self) -> list[WorkerInfo]:
        """A list of the currently discovered worker gRPC stubs."""
-
+        if self._loadbalancer_context:
+            return list(self._loadbalancer_context.workers.keys())
+        else:
+            return []
 
    async def start(self) -> None:
        """Starts the proxy by initiating the worker discovery process.
@@ -387,22 +255,25 @@ class WorkerProxy:
        if self._started:
            raise RuntimeError("Proxy already started")
 
-        (
-            self.
-
+        (
+            self._loadbalancer_service,
+            self._loadbalancer_context_manager,
+        ) = await self._enter_context(self._loadbalancer)
        if not isinstance(self._loadbalancer_service, LoadBalancerLike):
            raise ValueError
 
-
-        self.
-
+        (
+            self._discovery_service,
+            self._discovery_context_manager,
+        ) = await self._enter_context(self._discovery)
        if not isinstance(self._discovery_service, DiscoveryLike):
            raise ValueError
 
        self._proxy_token = wool.__proxy__.set(self)
-        self.
-            factory=
+        self._connection_pool = ResourcePool(
+            factory=connection_factory, finalizer=connection_finalizer, ttl=60
        )
+        self._loadbalancer_context = LoadBalancerContext()
        self._sentinel_task = asyncio.create_task(self._worker_sentinel())
        self._started = True
 
@@ -415,8 +286,8 @@ class WorkerProxy:
        if not self._started:
            raise RuntimeError("Proxy not started - call start() first")
 
-        await self._exit_context(self.
-        await self._exit_context(self.
+        await self._exit_context(self._discovery_context_manager, *args)
+        await self._exit_context(self._loadbalancer_context_manager, *args)
 
        wool.__proxy__.reset(self._proxy_token)
        if self._sentinel_task:
@@ -426,12 +297,11 @@ class WorkerProxy:
            except asyncio.CancelledError:
                pass
            self._sentinel_task = None
-        await self.
-
-        self._workers.clear()
+        await self._connection_pool.clear()
+        self._loadbalancer_context = None
        self._started = False
 
-    async def dispatch(self, task: WoolTask):
+    async def dispatch(self, task: WoolTask, *, timeout: float | None = None):
        """Dispatches a task to an available worker in the pool.
 
        This method selects a worker using a round-robin strategy. If no
@@ -439,7 +309,7 @@ class WorkerProxy:
        exception.
 
        :param task:
-            The :
+            The :class:`WoolTask` object to be dispatched.
        :param timeout:
            Timeout in seconds for getting a worker.
        :returns:
@@ -455,8 +325,10 @@ class WorkerProxy:
        await asyncio.wait_for(self._await_workers(), 60)
 
        assert isinstance(self._loadbalancer_service, LoadBalancerLike)
-
-
+        assert self._loadbalancer_context
+        return await self._loadbalancer_service.dispatch(
+            task, context=self._loadbalancer_context, timeout=timeout
+        )
 
    async def _enter_context(self, factory):
        ctx = None
@@ -483,30 +355,26 @@ class WorkerProxy:
            ctx.__exit__(*args)
 
    async def _await_workers(self):
-        while not self.
+        while not self._loadbalancer_context or not self._loadbalancer_context.workers:
            await asyncio.sleep(0)
 
    async def _worker_sentinel(self):
-        assert
-        assert isinstance(self._loadbalancer_service, LoadBalancerLike)
+        assert self._loadbalancer_context
        async for event in self._discovery_service:
            match event.type:
                case "worker_added":
-                    self.
-
+                    self._loadbalancer_context.add_worker(
+                        event.worker_info,
+                        lambda: self._connection_pool.get(
                            f"{event.worker_info.host}:{event.worker_info.port}",
                        ),
-                        event.worker_info,
                    )
                case "worker_updated":
-                    self.
-
+                    self._loadbalancer_context.update_worker(
+                        event.worker_info,
+                        lambda: self._connection_pool.get(
                            f"{event.worker_info.host}:{event.worker_info.port}",
                        ),
-                        event.worker_info,
                    )
                case "worker_removed":
-
-                    self._loadbalancer_service.worker_removed_callback(
-                        event.worker_info
-                    )
+                    self._loadbalancer_context.remove_worker(event.worker_info)
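The sentinel now feeds discovery events into the LoadBalancerContext introduced in wool/_loadbalancer.py instead of calling back into the load balancer itself. A hypothetical sketch of the minimal shape such a context object needs for the calls above (wool's real implementation may differ):

from typing import Any, Callable

class WorkerRegistry:
    """Maps worker metadata to a connection getter, mutated by discovery events."""

    def __init__(self) -> None:
        self.workers: dict[Any, Callable[[], Any]] = {}

    def add_worker(self, info: Any, connect: Callable[[], Any]) -> None:
        self.workers[info] = connect

    def update_worker(self, info: Any, connect: Callable[[], Any]) -> None:
        self.workers[info] = connect

    def remove_worker(self, info: Any) -> None:
        self.workers.pop(info, None)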