wool 0.1rc8__py3-none-any.whl → 0.1rc10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of wool has been flagged as possibly problematic. Click here for more details.
- wool/__init__.py +71 -50
- wool/_protobuf/__init__.py +14 -0
- wool/_protobuf/exception.py +3 -0
- wool/_protobuf/task.py +11 -0
- wool/_protobuf/task_pb2.py +42 -0
- wool/_protobuf/task_pb2.pyi +43 -0
- wool/_protobuf/{mempool/metadata/metadata_pb2_grpc.py → task_pb2_grpc.py} +2 -2
- wool/_protobuf/worker.py +24 -0
- wool/_protobuf/worker_pb2.py +47 -0
- wool/_protobuf/worker_pb2.pyi +39 -0
- wool/_protobuf/worker_pb2_grpc.py +141 -0
- wool/_resource_pool.py +376 -0
- wool/_typing.py +0 -10
- wool/_work.py +553 -0
- wool/_worker.py +843 -169
- wool/_worker_discovery.py +1223 -0
- wool/_worker_pool.py +331 -0
- wool/_worker_proxy.py +515 -0
- {wool-0.1rc8.dist-info → wool-0.1rc10.dist-info}/METADATA +8 -7
- wool-0.1rc10.dist-info/RECORD +22 -0
- wool-0.1rc10.dist-info/entry_points.txt +2 -0
- wool/_cli.py +0 -262
- wool/_event.py +0 -109
- wool/_future.py +0 -171
- wool/_logging.py +0 -44
- wool/_manager.py +0 -181
- wool/_mempool/__init__.py +0 -4
- wool/_mempool/_mempool.py +0 -311
- wool/_mempool/_metadata.py +0 -39
- wool/_mempool/_service.py +0 -225
- wool/_pool.py +0 -524
- wool/_protobuf/mempool/mempool_pb2.py +0 -66
- wool/_protobuf/mempool/mempool_pb2.pyi +0 -108
- wool/_protobuf/mempool/mempool_pb2_grpc.py +0 -312
- wool/_protobuf/mempool/metadata/metadata_pb2.py +0 -36
- wool/_protobuf/mempool/metadata/metadata_pb2.pyi +0 -17
- wool/_queue.py +0 -32
- wool/_session.py +0 -429
- wool/_task.py +0 -366
- wool/_utils.py +0 -63
- wool-0.1rc8.dist-info/RECORD +0 -28
- wool-0.1rc8.dist-info/entry_points.txt +0 -2
- {wool-0.1rc8.dist-info → wool-0.1rc10.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,1223 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import hashlib
|
|
5
|
+
import json
|
|
6
|
+
import multiprocessing.shared_memory
|
|
7
|
+
import socket
|
|
8
|
+
import struct
|
|
9
|
+
from abc import ABC
|
|
10
|
+
from abc import abstractmethod
|
|
11
|
+
from collections import deque
|
|
12
|
+
from dataclasses import dataclass
|
|
13
|
+
from dataclasses import field
|
|
14
|
+
from typing import TYPE_CHECKING
|
|
15
|
+
from typing import Any
|
|
16
|
+
from typing import AsyncContextManager
|
|
17
|
+
from typing import AsyncIterator
|
|
18
|
+
from typing import Awaitable
|
|
19
|
+
from typing import Callable
|
|
20
|
+
from typing import ContextManager
|
|
21
|
+
from typing import Deque
|
|
22
|
+
from typing import Dict
|
|
23
|
+
from typing import Generic
|
|
24
|
+
from typing import Literal
|
|
25
|
+
from typing import Protocol
|
|
26
|
+
from typing import Tuple
|
|
27
|
+
from typing import TypeAlias
|
|
28
|
+
from typing import TypeVar
|
|
29
|
+
from typing import final
|
|
30
|
+
from typing import runtime_checkable
|
|
31
|
+
|
|
32
|
+
if TYPE_CHECKING:
|
|
33
|
+
pass
|
|
34
|
+
|
|
35
|
+
from zeroconf import IPVersion
|
|
36
|
+
from zeroconf import ServiceInfo
|
|
37
|
+
from zeroconf import ServiceListener
|
|
38
|
+
from zeroconf import Zeroconf
|
|
39
|
+
from zeroconf.asyncio import AsyncServiceBrowser
|
|
40
|
+
from zeroconf.asyncio import AsyncZeroconf
|
|
41
|
+
|
|
42
|
+
if TYPE_CHECKING:
|
|
43
|
+
pass
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
# public
@dataclass
class WorkerInfo:
    """Properties and metadata for a worker instance.

    Contains identifying information and capabilities of a worker that
    can be used for discovery, filtering, and routing decisions.

    :param uid:
        Unique identifier for the worker instance.
    :param host:
        Network host address where the worker is accessible.
    :param port:
        Network port number where the worker is listening, or ``None``
        when no port has been assigned.
    :param pid:
        Process ID of the worker.
    :param version:
        Version string of the worker software.
    :param tags:
        Set of capability tags for worker filtering and selection.
    :param extra:
        Additional arbitrary metadata as key-value pairs.
    """

    uid: str
    host: str
    # None is permitted so that a worker can be described before binding.
    port: int | None
    pid: int
    version: str
    tags: set[str] = field(default_factory=set)
    extra: dict[str, Any] = field(default_factory=dict)

    def __hash__(self) -> int:
        # Hash on uid only: tags/extra are mutable containers, so identity
        # for set/dict membership is the unique worker id alone.
        return hash(self.uid)
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
# public
# The kind of lifecycle transition reported by a discovery service.
DiscoveryEventType: TypeAlias = Literal[
    "worker_added", "worker_removed", "worker_updated"
]

_T = TypeVar("_T")
# A predicate over items of type _T, used to filter queue gets and to
# select which discovered workers are tracked.
PredicateFunction: TypeAlias = Callable[[_T], bool]
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
class PredicatedQueue(Generic[_T]):
    """An asyncio queue that supports predicated gets.

    Items can be retrieved only if they match a predicate function.
    Non-matching items remain in the queue for future gets. This allows
    selective consumption from a shared queue based on item properties.

    Like :class:`asyncio.Queue`, instances are intended for use from a
    single event loop.

    :param maxsize:
        Maximum number of items in the queue (0 for unlimited).
    """

    _maxsize: int
    _queue: Deque[_T]
    # Waiting consumers as (future, predicate) pairs, FIFO order.
    _getters: Deque[Tuple[asyncio.Future[_T], PredicateFunction[_T] | None]]
    # Waiting producers as (future, pending item) pairs, FIFO order.
    _putters: Deque[Tuple[asyncio.Future[None], _T]]
    # join()/task_done() accounting, mirroring asyncio.Queue semantics.
    _unfinished_tasks: int
    _finished: asyncio.Event

    def __init__(self, maxsize: int = 0):
        self._maxsize = maxsize
        self._queue: Deque[_T] = deque()
        self._getters: Deque[Tuple[asyncio.Future[_T], PredicateFunction[_T] | None]] = (
            deque()
        )
        self._putters: Deque[Tuple[asyncio.Future[None], _T]] = deque()
        self._unfinished_tasks = 0
        # Starts set: with zero unfinished tasks, join() must not block.
        self._finished = asyncio.Event()
        self._finished.set()

    def qsize(self) -> int:
        """Number of items in the queue."""
        return len(self._queue)

    def empty(self) -> bool:
        """Return True if the queue is empty, False otherwise."""
        return not self._queue

    def full(self) -> bool:
        """Return True if there are maxsize items in the queue."""
        # maxsize <= 0 means unbounded capacity.
        if self._maxsize <= 0:
            return False
        return self.qsize() >= self._maxsize

    async def put(self, item: _T) -> None:
        """Put an item into the queue.

        If the queue is full, wait until a free slot is available.
        """
        while self.full():
            putter_future: asyncio.Future[None] = asyncio.Future()
            self._putters.append((putter_future, item))
            try:
                await putter_future
                # On wakeup, _wakeup_next_putter() has already enqueued our
                # item on our behalf, so the put is complete.
                return
            except asyncio.CancelledError:
                putter_future.cancel()
                try:
                    self._putters.remove((putter_future, item))
                except ValueError:
                    # Entry was already removed by a concurrent wakeup.
                    pass
                raise

        self._put_nowait(item)

    def put_nowait(self, item: _T) -> None:
        """Put an item into the queue without blocking.

        :raises QueueFull:
            If no free slot is immediately available.
        """
        if self.full():
            raise asyncio.QueueFull
        self._put_nowait(item)

    def _put_nowait(self, item: _T) -> None:
        """Internal method to put item without capacity checks."""
        self._queue.append(item)
        self._unfinished_tasks += 1
        self._finished.clear()
        # The new item may immediately satisfy a waiting getter.
        self._wakeup_next_getter(item)

    async def get(self, predicate: PredicateFunction[_T] | None = None) -> _T:
        """Remove and return an item from the queue that matches the predicate.

        If predicate is None, return the first available item. If predicate is
        provided, return the first item that makes predicate(item) True. Items
        that don't match the predicate remain in the queue.

        If no matching item is available, wait until one becomes available.

        :param predicate:
            Optional function to filter items.
        :returns:
            Item that matches the predicate.
        """
        while True:
            # Try to find a matching item in the current queue
            item = self._get_matching_item(predicate)
            if item is not None:
                # Capacity was freed; give a waiting producer its turn.
                self._wakeup_next_getter
                self._wakeup_next_putter()
                return item

            # No matching item found, wait for new items
            getter_future: asyncio.Future[_T] = asyncio.Future()
            self._getters.append((getter_future, predicate))
            try:
                return await getter_future
            except asyncio.CancelledError:
                getter_future.cancel()
                try:
                    self._getters.remove((getter_future, predicate))
                except ValueError:
                    # Entry was already removed by a concurrent wakeup.
                    pass
                raise

    def get_nowait(self, predicate: PredicateFunction[_T] | None = None) -> _T:
        """Remove and return an item immediately that matches the predicate.

        :param predicate:
            Optional function to filter items.
        :returns:
            Item that matches the predicate.
        :raises QueueEmpty:
            If no matching item is immediately available.
        """
        item = self._get_matching_item(predicate)
        if item is None:
            raise asyncio.QueueEmpty
        self._wakeup_next_putter()
        return item

    def _get_matching_item(self, predicate: PredicateFunction[_T] | None) -> _T | None:
        """Find and remove the first item that matches the predicate."""
        if not self._queue:
            return None

        if predicate is None:
            # No predicate, return first available item
            return self._queue.popleft()

        # Search for matching item
        for i, item in enumerate(self._queue):
            if predicate(item):
                # Found matching item, remove it from queue
                del self._queue[i]
                return item

        return None

    def _wakeup_next_getter(self, item: _T) -> None:
        """Try to satisfy waiting getters with the new item."""
        remaining_getters = deque()
        item_consumed = False

        while self._getters and not item_consumed:
            getter_future, getter_predicate = self._getters.popleft()

            # Skip futures already resolved/cancelled elsewhere.
            if getter_future.done():
                continue

            # Check if this getter's predicate matches the item
            if getter_predicate is None or getter_predicate(item):
                # This getter can take the item
                # Try to remove the item from queue and satisfy the getter
                try:
                    # NOTE: deque.remove matches by equality; with duplicate
                    # equal items the first occurrence is removed.
                    self._queue.remove(item)
                    getter_future.set_result(item)
                    item_consumed = True
                    # Item was successfully given to a getter, we're done
                    break
                except ValueError:
                    # Item was already taken by another operation
                    # Continue to next getter, but don't try to give them this item
                    remaining_getters.append((getter_future, getter_predicate))
            else:
                # This getter's predicate doesn't match, keep waiting
                remaining_getters.append((getter_future, getter_predicate))

        # Restore getters that couldn't be satisfied; extendleft(reversed(...))
        # keeps them at the front of the deque in their original FIFO order.
        self._getters.extendleft(reversed(remaining_getters))

    def _wakeup_next_putter(self) -> None:
        """Wake up the next putter if there's space."""
        while self._putters and not self.full():
            putter_future, item = self._putters.popleft()
            if not putter_future.done():
                # Enqueue the putter's item for it, then resolve its wait;
                # only one putter is served per freed slot.
                self._put_nowait(item)
                putter_future.set_result(None)
                break

    def task_done(self) -> None:
        """Indicate that a formerly enqueued task is complete."""
        if self._unfinished_tasks <= 0:
            raise ValueError("task_done() called too many times")
        self._unfinished_tasks -= 1
        if self._unfinished_tasks == 0:
            self._finished.set()

    async def join(self) -> None:
        """Wait until all items in the queue have been gotten and completed."""
        await self._finished.wait()
|
|
292
|
+
|
|
293
|
+
|
|
294
|
+
# public
@dataclass
class DiscoveryEvent:
    """Represents a worker service discovery event.

    Contains information about worker service lifecycle events (added,
    updated, removed), carrying the worker's properties as observed at
    the time of the event.

    :param type:
        Type of discovery event (added, updated, or removed).
    :param worker_info:
        The :class:`~wool._worker_discovery.WorkerInfo` instance associated with this
        event.
    """

    type: DiscoveryEventType
    worker_info: WorkerInfo
|
|
312
|
+
|
|
313
|
+
|
|
314
|
+
# Covariant value type produced by the factory/iterator protocols below.
_T_co = TypeVar("_T_co", covariant=True)
|
|
315
|
+
|
|
316
|
+
|
|
317
|
+
# public
class Reducible(Protocol):
    """Protocol for objects that support pickling via __reduce__.

    Matched structurally: any object defining ``__reduce__`` returning the
    standard pickle reduction tuple satisfies this protocol.
    """

    def __reduce__(self) -> tuple: ...
|
|
322
|
+
|
|
323
|
+
|
|
324
|
+
# public
class ReducibleAsyncIteratorLike(Reducible, Protocol, Generic[_T_co]):
    """Protocol for async iterators that yield discovery events.

    Implementations must be pickleable via __reduce__ to support
    task-specific session contexts in distributed environments.
    """

    # Returns itself so the object can be used directly in ``async for``.
    def __aiter__(self) -> ReducibleAsyncIteratorLike[_T_co]: ...

    # Declared as a plain method returning an awaitable (rather than an
    # ``async def``) so any awaitable-producing implementation conforms.
    def __anext__(self) -> Awaitable[_T_co]: ...
|
|
335
|
+
|
|
336
|
+
|
|
337
|
+
# public
@runtime_checkable
class Factory(Protocol, Generic[_T_co]):
    """Protocol for zero-argument factories producing a resource.

    A factory may return the resource directly, an awaitable resolving to
    it, or a synchronous or asynchronous context manager that yields it.
    Runtime-checkable, so ``isinstance(obj, Factory)`` verifies callability.
    """

    def __call__(
        self,
    ) -> (
        _T_co | Awaitable[_T_co] | AsyncContextManager[_T_co] | ContextManager[_T_co]
    ): ...
|
|
345
|
+
|
|
346
|
+
|
|
347
|
+
# public
|
|
348
|
+
class DiscoveryService(ABC):
|
|
349
|
+
"""Abstract base class for discovering worker services.
|
|
350
|
+
|
|
351
|
+
When started, implementations should discover all existing services that
|
|
352
|
+
satisfy the specified filter and deliver worker-added events for each.
|
|
353
|
+
Subsequently, they should monitor for newly added, updated, or removed
|
|
354
|
+
workers and deliver appropriate events via the :meth:`events` method.
|
|
355
|
+
|
|
356
|
+
Service tracking behavior:
|
|
357
|
+
- Only workers satisfying the filter should be tracked
|
|
358
|
+
- Workers updated to satisfy the filter should trigger worker-added events
|
|
359
|
+
- Workers updated to no longer satisfy the filter should trigger
|
|
360
|
+
worker-removed events
|
|
361
|
+
- Tracked workers removed from the registry entirely should always
|
|
362
|
+
trigger worker-removed
|
|
363
|
+
|
|
364
|
+
:param filter:
|
|
365
|
+
Optional filter function to select which discovery events to yield.
|
|
366
|
+
Only events matching the filter will be delivered.
|
|
367
|
+
|
|
368
|
+
.. warning::
|
|
369
|
+
The discovery procedure should not block continuously, as it will be
|
|
370
|
+
executed in the current event loop.
|
|
371
|
+
|
|
372
|
+
.. note::
|
|
373
|
+
Implementations must be pickleable and provide an unstarted copy
|
|
374
|
+
when unpickled to support task-specific session contexts.
|
|
375
|
+
"""
|
|
376
|
+
|
|
377
|
+
_started: bool
|
|
378
|
+
_service_cache: Dict[str, WorkerInfo]
|
|
379
|
+
|
|
380
|
+
def __init__(
|
|
381
|
+
self,
|
|
382
|
+
filter: PredicateFunction[WorkerInfo] | None = None,
|
|
383
|
+
):
|
|
384
|
+
self._filter = filter
|
|
385
|
+
self._started = False
|
|
386
|
+
self._service_cache = {}
|
|
387
|
+
|
|
388
|
+
def __reduce__(self) -> tuple:
|
|
389
|
+
"""Return constructor args for unpickling an unstarted service copy."""
|
|
390
|
+
return (self.__class__, (self._filter,))
|
|
391
|
+
|
|
392
|
+
def __aiter__(self) -> AsyncIterator[DiscoveryEvent]:
|
|
393
|
+
"""Returns self as an async iterator."""
|
|
394
|
+
return self.events()
|
|
395
|
+
|
|
396
|
+
async def __anext__(self) -> DiscoveryEvent:
|
|
397
|
+
"""Delegate to the events async iterator."""
|
|
398
|
+
return await anext(self.events())
|
|
399
|
+
|
|
400
|
+
@final
|
|
401
|
+
@property
|
|
402
|
+
def started(self) -> bool:
|
|
403
|
+
return self._started
|
|
404
|
+
|
|
405
|
+
@final
|
|
406
|
+
async def start(self) -> None:
|
|
407
|
+
"""Starts the worker discovery procedure.
|
|
408
|
+
|
|
409
|
+
This method should initiate the discovery process, which may involve
|
|
410
|
+
network operations or other asynchronous tasks.
|
|
411
|
+
|
|
412
|
+
:raises RuntimeError:
|
|
413
|
+
If the service has already been started.
|
|
414
|
+
"""
|
|
415
|
+
if self._started:
|
|
416
|
+
raise RuntimeError("Discovery service already started")
|
|
417
|
+
await self._start()
|
|
418
|
+
self._started = True
|
|
419
|
+
|
|
420
|
+
@final
|
|
421
|
+
async def stop(self) -> None:
|
|
422
|
+
"""Stops the worker discovery procedure.
|
|
423
|
+
|
|
424
|
+
This method should clean up any resources used for discovery, such as
|
|
425
|
+
network connections or event listeners.
|
|
426
|
+
|
|
427
|
+
:raises RuntimeError:
|
|
428
|
+
If the service has not been started.
|
|
429
|
+
"""
|
|
430
|
+
if not self._started:
|
|
431
|
+
raise RuntimeError("Discovery service not started")
|
|
432
|
+
await self._stop()
|
|
433
|
+
|
|
434
|
+
@abstractmethod
|
|
435
|
+
def events(self) -> AsyncIterator[DiscoveryEvent]:
|
|
436
|
+
"""Yields discovery events as they occur.
|
|
437
|
+
|
|
438
|
+
Returns an asynchronous iterator that yields discovery events for
|
|
439
|
+
workers being added, updated, or removed from the registry. Events
|
|
440
|
+
are filtered according to the filter function provided during
|
|
441
|
+
initialization.
|
|
442
|
+
|
|
443
|
+
:yields:
|
|
444
|
+
Instances of :class:`DiscoveryEvent` representing worker
|
|
445
|
+
additions, updates, and removals.
|
|
446
|
+
"""
|
|
447
|
+
...
|
|
448
|
+
|
|
449
|
+
@abstractmethod
|
|
450
|
+
async def _start(self) -> None:
|
|
451
|
+
"""Starts the worker discovery procedure."""
|
|
452
|
+
...
|
|
453
|
+
|
|
454
|
+
@abstractmethod
|
|
455
|
+
async def _stop(self) -> None:
|
|
456
|
+
"""Stops the worker discovery procedure."""
|
|
457
|
+
...
|
|
458
|
+
|
|
459
|
+
|
|
460
|
+
# Links a RegistryService subclass to its corresponding discovery-service type.
_T_DiscoveryServiceLike = TypeVar("_T_DiscoveryServiceLike", bound=DiscoveryService)
|
|
461
|
+
|
|
462
|
+
|
|
463
|
+
# public
class RegistryServiceLike(Protocol):
    """Structural protocol for a service where workers can register themselves.

    Describes the interface for worker registration, unregistration, and
    property updates within a distributed worker pool system; any object
    with these methods conforms.
    """

    async def start(self) -> None: ...

    async def stop(self) -> None: ...

    async def register(self, worker_info: WorkerInfo) -> None: ...

    async def unregister(self, worker_info: WorkerInfo) -> None: ...

    async def update(self, worker_info: WorkerInfo) -> None: ...
|
|
480
|
+
|
|
481
|
+
|
|
482
|
+
# public
class RegistryService(Generic[_T_DiscoveryServiceLike], ABC):
    """Abstract base class for a service where workers can register themselves.

    Provides the interface for worker registration, unregistration, and
    property updates within a distributed worker pool system. Concrete
    subclasses implement the ``_``-prefixed hooks; the public methods
    enforce the started/stopped lifecycle before delegating.
    """

    _started: bool
    _stopped: bool

    def __init__(self):
        self._started = False
        self._stopped = False

    def _ensure_active(self) -> None:
        # Common lifecycle guard shared by register/unregister/update: the
        # service must have been started and must not yet be stopped.
        if not self._started:
            raise RuntimeError("Registry service not started - call start() first")
        if self._stopped:
            raise RuntimeError("Registry service already stopped")

    async def start(self) -> None:
        """Starts the registry service, making it ready to accept registrations.

        :raises RuntimeError:
            If the service has already been started.
        """
        if self._started:
            raise RuntimeError("Registry service already started")
        # Bound the startup handshake so a wedged implementation cannot
        # block the caller indefinitely.
        await asyncio.wait_for(self._start(), timeout=60)
        self._started = True

    async def stop(self) -> None:
        """Stops the registry service and cleans up any resources.

        Stopping an already-stopped service is a no-op.

        :raises RuntimeError:
            If the service has not been started.
        """
        if self._stopped:
            return
        if not self._started:
            raise RuntimeError("Registry service not started")
        await self._stop()
        self._stopped = True

    @abstractmethod
    async def _start(self) -> None:
        """Starts the registry service, making it ready to accept registrations."""
        ...

    @abstractmethod
    async def _stop(self) -> None:
        """Stops the registry service and cleans up any resources."""
        ...

    async def register(
        self,
        worker_info: WorkerInfo,
    ) -> None:
        """Registers a worker by publishing its service information.

        :param worker_info:
            The :class:`~wool._worker_discovery.WorkerInfo` instance containing all
            worker details.
        :raises RuntimeError:
            If the registry service is not running.
        """
        self._ensure_active()
        await self._register(worker_info)

    @abstractmethod
    async def _register(
        self,
        worker_info: WorkerInfo,
    ) -> None:
        """Implementation-specific worker registration.

        :param worker_info:
            The :class:`~wool._worker_discovery.WorkerInfo` instance containing all
            worker details.
        """
        ...

    async def unregister(self, worker_info: WorkerInfo) -> None:
        """Unregisters a worker by removing its service record.

        :param worker_info:
            The :class:`~wool._worker_discovery.WorkerInfo` instance of the worker to
            unregister.
        :raises RuntimeError:
            If the registry service is not running.
        """
        self._ensure_active()
        await self._unregister(worker_info)

    @abstractmethod
    async def _unregister(self, worker_info: WorkerInfo) -> None:
        """Implementation-specific worker unregistration.

        :param worker_info:
            The :class:`~wool._worker_discovery.WorkerInfo` instance of the worker to
            unregister.
        """
        ...

    async def update(self, worker_info: WorkerInfo) -> None:
        """Updates a worker's properties if they have changed.

        :param worker_info:
            The updated :class:`~wool._worker_discovery.WorkerInfo` instance.
        :raises RuntimeError:
            If the registry service is not running.
        """
        self._ensure_active()
        await self._update(worker_info)

    @abstractmethod
    async def _update(self, worker_info: WorkerInfo) -> None:
        """Implementation-specific worker property updates.

        :param worker_info:
            The updated :class:`~wool._worker_discovery.WorkerInfo` instance.
        """
        ...
|
|
609
|
+
|
|
610
|
+
|
|
611
|
+
# public
class LanDiscoveryService(DiscoveryService):
    """Implements worker discovery on the local network using Zeroconf.

    This service browses the local network for DNS-SD services and delivers
    all worker service events to clients. Uses Zeroconf/Bonjour protocol
    for automatic service discovery without requiring central coordination.

    :param filter:
        Optional predicate function to filter discovered workers.
    """

    # Bound in _start(); unset until the service has been started.
    aiozc: AsyncZeroconf
    browser: AsyncServiceBrowser
    # DNS-SD service type browsed for wool workers.
    service_type: Literal["_wool._tcp.local."] = "_wool._tcp.local."
    _event_queue: PredicatedQueue[DiscoveryEvent]

    def __init__(
        self,
        filter: PredicateFunction[WorkerInfo] | None = None,
    ) -> None:
        super().__init__(filter)  # type: ignore[arg-type]
        self._event_queue = PredicatedQueue()

    async def events(self) -> AsyncIterator[DiscoveryEvent]:
        """Returns an async iterator over discovery events.

        Starts the service on first iteration and stops it when the
        generator is closed.
        """
        await self.start()
        try:
            while True:
                yield await self._event_queue.get()
        finally:
            await self.stop()

    async def _start(self) -> None:
        """Starts the Zeroconf service browser.

        :raises RuntimeError:
            If the service has already been started.
        """
        # Configure zeroconf to use localhost only to avoid network warnings
        self.aiozc = AsyncZeroconf(interfaces=["127.0.0.1"])
        self.browser = AsyncServiceBrowser(
            self.aiozc.zeroconf,
            self.service_type,
            listener=self._Listener(
                aiozc=self.aiozc,
                event_queue=self._event_queue,
                service_cache=self._service_cache,
                # With no user filter, track every discovered worker.
                predicate=self._filter or (lambda _: True),
            ),
        )

    async def _stop(self) -> None:
        """Stops the Zeroconf service browser and closes the connection.

        :raises RuntimeError:
            If the service has not been started.
        """
        # NOTE(review): `browser`/`aiozc` are only bound in _start(); if
        # _stop() could ever run before a successful _start(), these
        # attribute reads would raise AttributeError — confirm lifecycle.
        if self.browser:
            await self.browser.async_cancel()
        if self.aiozc:
            await self.aiozc.async_close()

    class _Listener(ServiceListener):
        """A Zeroconf listener that delivers all worker service events.

        :param aiozc:
            The :class:`~zeroconf.asyncio.AsyncZeroconf` instance to use
            for async service info retrieval.
        :param event_queue:
            Queue to deliver discovery events to.
        :param service_cache:
            Cache to track service properties for pre/post event states.
        :param predicate:
            Function to filter which workers to track.
        """

        aiozc: AsyncZeroconf
        _event_queue: PredicatedQueue[DiscoveryEvent]
        _service_addresses: Dict[str, str]
        # Maps zeroconf service name -> last WorkerInfo that passed the filter.
        _service_cache: Dict[str, WorkerInfo]

        def __init__(
            self,
            aiozc: AsyncZeroconf,
            event_queue: PredicatedQueue[DiscoveryEvent],
            predicate: PredicateFunction[WorkerInfo],
            service_cache: Dict[str, WorkerInfo],
        ) -> None:
            self.aiozc = aiozc
            self._event_queue = event_queue
            self._predicate = predicate
            self._service_addresses = {}
            self._service_cache = service_cache

        def add_service(self, zc: Zeroconf, type_: str, name: str):
            """Called by Zeroconf when a service is added."""
            if type_ == LanRegistryService.service_type:
                # NOTE(review): the created task reference is not retained;
                # the event loop holds only a weak set of tasks, so it could
                # be garbage-collected before completion — consider keeping
                # task references.
                asyncio.create_task(self._handle_add_service(type_, name))

        def remove_service(self, zc: Zeroconf, type_: str, name: str):
            """Called by Zeroconf when a service is removed."""
            if type_ == LanRegistryService.service_type:
                # Only emit a removal for workers we were actually tracking.
                if worker := self._service_cache.pop(name, None):
                    asyncio.create_task(
                        self._event_queue.put(
                            DiscoveryEvent(type="worker_removed", worker_info=worker)
                        )
                    )

        def update_service(self, zc: Zeroconf, type_: str, name: str):
            """Called by Zeroconf when a service is updated."""
            if type_ == LanRegistryService.service_type:
                asyncio.create_task(self._handle_update_service(type_, name))

        async def _handle_add_service(self, type_: str, name: str):
            """Async handler for service addition."""
            try:
                if not (
                    service_info := await self.aiozc.async_get_service_info(type_, name)
                ):
                    return

                try:
                    worker_info = _deserialize_worker_info(service_info)
                except ValueError:
                    # Malformed or non-wool service record; ignore it.
                    return

                if self._predicate(worker_info):
                    self._service_cache[name] = worker_info
                    event = DiscoveryEvent(type="worker_added", worker_info=worker_info)
                    await self._event_queue.put(event)
            except Exception:
                pass  # Service may have disappeared before we could query it

        async def _handle_update_service(self, type_: str, name: str):
            """Async handler for service update.

            Applies the filter-transition rules: an update can surface a
            previously untracked worker (worker_added), refresh a tracked one
            (worker_updated), or evict one that no longer passes the filter
            (worker_removed).
            """
            try:
                if not (
                    service_info := await self.aiozc.async_get_service_info(type_, name)
                ):
                    return

                try:
                    worker_info = _deserialize_worker_info(service_info)
                except ValueError:
                    # Malformed or non-wool service record; ignore it.
                    return

                if name not in self._service_cache:
                    # New worker that wasn't tracked before
                    if self._predicate(worker_info):
                        self._service_cache[name] = worker_info
                        event = DiscoveryEvent(
                            type="worker_added", worker_info=worker_info
                        )
                        await self._event_queue.put(event)
                else:
                    # Existing tracked worker
                    old_worker = self._service_cache[name]
                    if self._predicate(worker_info):
                        # Still satisfies filter, update cache and emit update
                        self._service_cache[name] = worker_info
                        event = DiscoveryEvent(
                            type="worker_updated", worker_info=worker_info
                        )
                        await self._event_queue.put(event)
                    else:
                        # No longer satisfies filter, remove and emit removal
                        del self._service_cache[name]
                        removal_event = DiscoveryEvent(
                            type="worker_removed", worker_info=old_worker
                        )
                        await self._event_queue.put(removal_event)

            except Exception:
                # Best-effort: zeroconf records can vanish mid-query.
                pass
|
|
787
|
+
|
|
788
|
+
|
|
789
|
+
# public
|
|
790
|
+
class LanRegistryService(RegistryService[LanDiscoveryService]):
    """Implements a worker registry using Zeroconf to advertise on the LAN.

    This service registers workers by publishing a DNS-SD service record on
    the local network, allowing :class:`LanDiscoveryService` to find them.
    """

    # Created in `_start`; reset to None in `_stop`.
    aiozc: AsyncZeroconf | None
    # Published service records, keyed by worker uid.
    services: Dict[str, ServiceInfo]
    service_type: Literal["_wool._tcp.local."] = "_wool._tcp.local."

    def __init__(self):
        super().__init__()
        self.aiozc = None
        self.services = {}

    async def _start(self) -> None:
        """Initializes and starts the Zeroconf instance for advertising."""
        # Configure zeroconf to use localhost only to avoid network warnings
        self.aiozc = AsyncZeroconf(interfaces=["127.0.0.1"])

    async def _stop(self) -> None:
        """Stops the Zeroconf instance and cleans up all registered services."""
        if self.aiozc:
            await self.aiozc.async_close()
            self.aiozc = None

    async def _register(
        self,
        worker_info: WorkerInfo,
    ) -> None:
        """Registers a worker by publishing its service information via Zeroconf.

        :param worker_info:
            The :class:`~wool._worker_discovery.WorkerInfo` instance containing all
            worker details.
        :raises RuntimeError:
            If the registry service is not properly initialized.
        :raises ValueError:
            If the worker address cannot be parsed or the port is invalid.
        """
        if self.aiozc is None:
            raise RuntimeError("Registry service not properly initialized")
        address = f"{worker_info.host}:{worker_info.port}"
        ip_address, port = self._resolve_address(address)
        service_name = f"{worker_info.uid}.{self.service_type}"
        service_info = ServiceInfo(
            self.service_type,
            service_name,
            addresses=[ip_address],
            port=port,
            properties=_serialize_worker_info(worker_info),
        )
        self.services[worker_info.uid] = service_info
        await self.aiozc.async_register_service(service_info)

    async def _unregister(self, worker_info: WorkerInfo) -> None:
        """Unregisters a worker by removing its Zeroconf service record.

        :param worker_info:
            The :class:`~wool._worker_discovery.WorkerInfo` instance of the worker to
            unregister.
        :raises RuntimeError:
            If the registry service is not properly initialized.
        :raises KeyError:
            If the worker was never registered by this service.
        """
        if self.aiozc is None:
            raise RuntimeError("Registry service not properly initialized")
        service = self.services[worker_info.uid]
        await self.aiozc.async_unregister_service(service)
        del self.services[worker_info.uid]

    async def _update(self, worker_info: WorkerInfo) -> None:
        """Updates a worker's properties if they have changed.

        Updates both the Zeroconf service and local cache atomically.
        If the Zeroconf update fails, the local cache remains unchanged
        to maintain consistency.

        :param worker_info:
            The updated :class:`~wool._worker_discovery.WorkerInfo` instance.
        :raises RuntimeError:
            If the registry service is not properly initialized.
        :raises KeyError:
            If the worker was never registered by this service.
        :raises Exception:
            If the Zeroconf service update fails.
        """
        if self.aiozc is None:
            raise RuntimeError("Registry service not properly initialized")

        service = self.services[worker_info.uid]
        new_properties = _serialize_worker_info(worker_info)

        if service.decoded_properties != new_properties:
            updated_service = ServiceInfo(
                service.type,
                service.name,
                addresses=service.addresses,
                port=service.port,
                properties=new_properties,
                server=service.server,
            )
            await self.aiozc.async_update_service(updated_service)
            # Swap the cache entry only after the network update succeeds.
            self.services[worker_info.uid] = updated_service

    def _resolve_address(self, address: str) -> Tuple[bytes, int]:
        """Resolve an address string to packed bytes and validate the port.

        :param address:
            Address in format "host:port". The host may be an IPv4 literal,
            an IPv6 literal (optionally bracketed, e.g. "[::1]:8080"), or a
            hostname to be resolved via DNS.
        :returns:
            Tuple of (IPv4/IPv6 address as bytes, port as int).
        :raises ValueError:
            If address format is invalid or port is out of range.
        """
        # Split on the LAST colon so IPv6 literals (which themselves contain
        # colons, e.g. "::1") keep their host part intact.
        host, sep, port_text = address.rpartition(":")
        if not sep or not host:
            raise ValueError(f"Invalid address: {address!r}")
        # Tolerate bracketed IPv6 literals such as "[::1]".
        host = host.strip("[]")
        port = int(port_text)
        if not 0 < port < 65536:
            raise ValueError(f"Port out of range: {port}")

        try:
            return socket.inet_pton(socket.AF_INET, host), port
        except OSError:
            pass

        try:
            return socket.inet_pton(socket.AF_INET6, host), port
        except OSError:
            pass

        # Not a literal address: fall back to DNS resolution (IPv4 result).
        return socket.inet_aton(socket.gethostbyname(host)), port
|
|
915
|
+
|
|
916
|
+
|
|
917
|
+
def _serialize_worker_info(
    info: WorkerInfo,
) -> dict[str, str | None]:
    """Serialize WorkerInfo to a flat dict for ServiceInfo.properties.

    :param info:
        :class:`~wool._worker_discovery.WorkerInfo` instance to serialize.
    :returns:
        Flat dict with pid, version, tags (JSON), extra (JSON). Empty
        collections serialize to ``None`` rather than empty JSON.
    """
    tags_json = json.dumps(list(info.tags)) if info.tags else None
    extra_json = json.dumps(info.extra) if info.extra else None
    return {
        "pid": str(info.pid),
        "version": info.version,
        "tags": tags_json,
        "extra": extra_json,
    }
|
|
934
|
+
|
|
935
|
+
|
|
936
|
+
def _deserialize_worker_info(info: ServiceInfo) -> WorkerInfo:
    """Deserialize ServiceInfo.decoded_properties to a WorkerInfo.

    :param info:
        ServiceInfo with decoded properties dict (str keys/values).
    :returns:
        :class:`~wool._worker_discovery.WorkerInfo` instance.
    :raises ValueError:
        If required fields are missing, JSON fields are invalid, or the
        record carries no IPv4 address.
    """
    properties = info.decoded_properties
    # Required fields must be present AND non-empty.
    present = {key for key, value in properties.items() if value}
    if missing := {"pid", "version"} - present:
        raise ValueError(f"Missing required properties: {', '.join(missing)}")

    # Guaranteed non-empty by the check above; no `assert` needed (asserts
    # are stripped under -O and must not carry validation).
    pid = int(properties["pid"])
    version = properties["version"]

    raw_tags = properties.get("tags")
    tags = set(json.loads(raw_tags)) if raw_tags else set()

    raw_extra = properties.get("extra")
    extra = json.loads(raw_extra) if raw_extra else {}

    # Surface a missing IPv4 address as ValueError (the documented failure
    # mode) rather than letting an IndexError escape to callers that only
    # handle ValueError.
    v4_addresses = info.ip_addresses_by_version(IPVersion.V4Only)
    if not v4_addresses:
        raise ValueError(f"No IPv4 address for service {info.name!r}")

    return WorkerInfo(
        uid=info.name,
        pid=pid,
        host=str(v4_addresses[0]),
        port=info.port,
        version=version,
        tags=tags,
        extra=extra,
    )
|
|
971
|
+
|
|
972
|
+
|
|
973
|
+
# public
|
|
974
|
+
class LocalRegistryService(RegistryService):
    """Implements a worker registry using shared memory for local pools.

    This service registers workers by writing their information to a shared memory
    block, allowing LocalDiscoveryService instances to find them efficiently.
    The registry stores worker ports as integers in a simple array format,
    providing fast local discovery without network overhead.

    :param uri:
        Unique identifier for the shared memory segment.
    """

    _shared_memory: multiprocessing.shared_memory.SharedMemory | None = None
    _uri: str
    # True only when this instance created (and therefore owns) the segment.
    _created_shared_memory: bool = False

    def __init__(self, uri: str):
        super().__init__()
        self._uri = uri
        self._created_shared_memory = False

    async def _start(self) -> None:
        """Initialize shared memory for worker registration."""
        if self._shared_memory is not None:
            return
        # Segment name is derived from the URI so independent processes
        # using the same URI attach to the same segment.
        shared_memory_name = hashlib.sha256(self._uri.encode()).hexdigest()[:12]
        try:
            # Attach to an existing segment if one is already published.
            self._shared_memory = multiprocessing.shared_memory.SharedMemory(
                name=shared_memory_name
            )
        except FileNotFoundError:
            # Create a new segment if it doesn't exist.
            self._shared_memory = multiprocessing.shared_memory.SharedMemory(
                name=shared_memory_name,
                create=True,
                size=1024,  # 1024 bytes = 256 worker slots (4 bytes per port)
            )
            self._created_shared_memory = True
            # Zero every byte so all slots read as "empty" (port == 0).
            self._shared_memory.buf[:] = bytes(len(self._shared_memory.buf))

    async def _stop(self) -> None:
        """Clean up shared memory resources."""
        if self._shared_memory:
            try:
                self._shared_memory.close()
                # Unlink the segment only if this registry created it.
                if self._created_shared_memory:
                    self._shared_memory.unlink()
            except Exception:
                pass  # best-effort cleanup during shutdown
            self._shared_memory = None
            self._created_shared_memory = False

    async def _register(self, worker_info: WorkerInfo) -> None:
        """Register a worker by writing its port to shared memory.

        Idempotent: if the worker's port is already present in the registry
        the call returns without claiming a second slot. (This also makes
        `_update` safe — previously a re-registration duplicated the port
        into a fresh slot, leaking registry capacity.)

        :param worker_info:
            The :class:`~wool._worker_discovery.WorkerInfo` instance containing all
            worker details. Only the port is stored in shared memory.
        :raises RuntimeError:
            If the registry service is not properly initialized, or no free
            slot remains.
        :raises ValueError:
            If the worker has no port.
        """
        if self._shared_memory is None:
            raise RuntimeError("Registry service not properly initialized")

        if worker_info.port is None:
            raise ValueError("Worker port must be specified")

        buf = self._shared_memory.buf
        free_offset = None
        for offset in range(0, len(buf), 4):
            slot_port = struct.unpack_from("I", buf, offset)[0]
            if slot_port == worker_info.port:
                return  # Already registered; don't duplicate the slot.
            if slot_port == 0 and free_offset is None:
                free_offset = offset  # Remember the first empty slot.
        if free_offset is None:
            raise RuntimeError("No available slots in shared memory registry")
        struct.pack_into("I", buf, free_offset, worker_info.port)

    async def _unregister(self, worker_info: WorkerInfo) -> None:
        """Unregister a worker by removing its port from shared memory.

        :param worker_info:
            The :class:`~wool._worker_discovery.WorkerInfo` instance of the worker to
            unregister. Unknown ports are ignored.
        :raises RuntimeError:
            If the registry service is not properly initialized.
        """
        if self._shared_memory is None:
            raise RuntimeError("Registry service not properly initialized")

        if worker_info.port is None:
            return

        buf = self._shared_memory.buf
        for offset in range(0, len(buf), 4):
            if struct.unpack_from("I", buf, offset)[0] == worker_info.port:
                struct.pack_into("I", buf, offset, 0)  # Clear slot
                break

    async def _update(self, worker_info: WorkerInfo) -> None:
        """Update a worker's properties in shared memory.

        For the simple port-based registry only the port is stored, so an
        update is an idempotent re-registration (a no-op when the port is
        already present).

        :param worker_info:
            The updated :class:`~wool._worker_discovery.WorkerInfo` instance.
        """
        await self._register(worker_info)
|
|
1084
|
+
|
|
1085
|
+
|
|
1086
|
+
# public
|
|
1087
|
+
class LocalDiscoveryService(DiscoveryService):
|
|
1088
|
+
"""Implements worker discovery using shared memory for local pools.
|
|
1089
|
+
|
|
1090
|
+
This service reads worker ports from a shared memory block and
|
|
1091
|
+
constructs WorkerInfo instances with localhost as the implied host.
|
|
1092
|
+
Provides efficient local discovery for single-machine worker pools.
|
|
1093
|
+
|
|
1094
|
+
:param uri:
|
|
1095
|
+
Unique identifier for the shared memory segment.
|
|
1096
|
+
:param filter:
|
|
1097
|
+
Optional predicate function to filter discovered workers.
|
|
1098
|
+
"""
|
|
1099
|
+
|
|
1100
|
+
_shared_memory: multiprocessing.shared_memory.SharedMemory | None = None
|
|
1101
|
+
_uri: str
|
|
1102
|
+
_event_queue: asyncio.Queue[DiscoveryEvent]
|
|
1103
|
+
_monitor_task: asyncio.Task | None
|
|
1104
|
+
_stop_event: asyncio.Event
|
|
1105
|
+
|
|
1106
|
+
def __init__(
    self,
    uri: str,
    filter: PredicateFunction[WorkerInfo] | None = None,
) -> None:
    """Initialize discovery over the shared-memory segment keyed by *uri*.

    :param uri:
        Unique identifier for the shared memory segment.
    :param filter:
        Optional predicate function used to filter discovered workers.
    """
    super().__init__(filter)  # type: ignore[arg-type]
    self._uri = uri
    self._stop_event = asyncio.Event()
    self._monitor_task = None
    self._event_queue = asyncio.Queue()
|
|
1116
|
+
|
|
1117
|
+
def __reduce__(self) -> tuple:
    """Return constructor args for unpickling an unstarted service copy."""
    constructor, base_args = super().__reduce__()
    # Prepend our uri to whatever constructor args the base class reports.
    return constructor, (self._uri, *base_args)
|
|
1121
|
+
|
|
1122
|
+
async def events(self) -> AsyncIterator[DiscoveryEvent]:
    """Returns an async iterator over discovery events.

    Starts the service on first iteration and stops it when the iterator
    is closed.
    """
    await self.start()
    try:
        while True:
            yield await self._event_queue.get()
    finally:
        await self.stop()
|
|
1131
|
+
|
|
1132
|
+
async def _start(self) -> None:
    """Attach to the shared-memory registry and begin monitoring it."""
    if self._shared_memory is None:
        # Segment name is derived from the URI the same way the registry
        # derives it; attaching fails if no registry has published yet.
        segment_name = hashlib.sha256(self._uri.encode()).hexdigest()[:12]
        self._shared_memory = multiprocessing.shared_memory.SharedMemory(
            name=segment_name
        )

    # Begin polling the segment in the background.
    self._monitor_task = asyncio.create_task(self._monitor_shared_memory())
|
|
1142
|
+
|
|
1143
|
+
async def _stop(self) -> None:
    """Stop the monitor task and detach from shared memory."""
    self._stop_event.set()

    monitor = self._monitor_task
    if monitor:
        try:
            await monitor
        except asyncio.CancelledError:
            pass

    memory = self._shared_memory
    if memory:
        try:
            memory.close()
        except Exception:
            pass  # best-effort detach
        self._shared_memory = None
|
|
1157
|
+
|
|
1158
|
+
async def _monitor_shared_memory(self) -> None:
    """Poll shared memory for port changes and emit discovery events.

    Loops until the stop event is set, reading every 4-byte slot each
    pass and diffing the result against the service cache via
    `_detect_changes`.
    """
    poll_interval = 0.1

    while not self._stop_event.is_set():
        try:
            current_workers = {}

            # Read current state from shared memory.
            if self._shared_memory:
                buf = self._shared_memory.buf
                for i in range(0, len(buf), 4):
                    port = struct.unpack_from("I", buf, i)[0]
                    if port > 0:  # Active worker
                        worker_info = WorkerInfo(
                            uid=f"worker-{port}",
                            host="localhost",
                            port=port,
                            pid=0,  # Not available in simple registry
                            version="unknown",  # Not available in simple registry
                            tags=set(),
                            extra={},
                        )
                        if self._filter is None or self._filter(worker_info):
                            current_workers[worker_info.uid] = worker_info

            # Diff against the cache and emit events.
            await self._detect_changes(current_workers)
        except Exception:
            # Best-effort: tolerate transient read errors (e.g. the segment
            # being closed mid-poll). Fall through to the wait below so a
            # persistent error doesn't busy-spin the event loop — the
            # original `continue` here skipped the sleep entirely.
            pass

        # Wait before the next poll; wake immediately if stop is requested.
        try:
            await asyncio.wait_for(self._stop_event.wait(), timeout=poll_interval)
            break  # Stop event was set
        except asyncio.TimeoutError:
            continue  # Continue polling
|
|
1197
|
+
|
|
1198
|
+
async def _detect_changes(self, current_workers: Dict[str, WorkerInfo]) -> None:
    """Diff *current_workers* against the cache and emit add/remove/update events."""
    cache = self._service_cache

    # Workers that appeared since the last poll.
    for uid, info in current_workers.items():
        if uid not in cache:
            cache[uid] = info
            await self._event_queue.put(
                DiscoveryEvent(type="worker_added", worker_info=info)
            )

    # Workers that vanished since the last poll.
    for uid in [u for u in cache if u not in current_workers]:
        departed = cache.pop(uid)
        await self._event_queue.put(
            DiscoveryEvent(type="worker_removed", worker_info=departed)
        )

    # Port changes for tracked workers. NOTE(review): uids encode the port
    # ("worker-<port>"), so in practice a changed port surfaces as a
    # remove+add above and this branch never fires — kept for parity.
    for uid, info in current_workers.items():
        if uid in cache and cache[uid].port != info.port:
            cache[uid] = info
            await self._event_queue.put(
                DiscoveryEvent(type="worker_updated", worker_info=info)
            )