wool 0.1rc8__py3-none-any.whl → 0.1rc10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of wool might be problematic. Click here for more details.

Files changed (43) hide show
  1. wool/__init__.py +71 -50
  2. wool/_protobuf/__init__.py +14 -0
  3. wool/_protobuf/exception.py +3 -0
  4. wool/_protobuf/task.py +11 -0
  5. wool/_protobuf/task_pb2.py +42 -0
  6. wool/_protobuf/task_pb2.pyi +43 -0
  7. wool/_protobuf/{mempool/metadata/metadata_pb2_grpc.py → task_pb2_grpc.py} +2 -2
  8. wool/_protobuf/worker.py +24 -0
  9. wool/_protobuf/worker_pb2.py +47 -0
  10. wool/_protobuf/worker_pb2.pyi +39 -0
  11. wool/_protobuf/worker_pb2_grpc.py +141 -0
  12. wool/_resource_pool.py +376 -0
  13. wool/_typing.py +0 -10
  14. wool/_work.py +553 -0
  15. wool/_worker.py +843 -169
  16. wool/_worker_discovery.py +1223 -0
  17. wool/_worker_pool.py +331 -0
  18. wool/_worker_proxy.py +515 -0
  19. {wool-0.1rc8.dist-info → wool-0.1rc10.dist-info}/METADATA +8 -7
  20. wool-0.1rc10.dist-info/RECORD +22 -0
  21. wool-0.1rc10.dist-info/entry_points.txt +2 -0
  22. wool/_cli.py +0 -262
  23. wool/_event.py +0 -109
  24. wool/_future.py +0 -171
  25. wool/_logging.py +0 -44
  26. wool/_manager.py +0 -181
  27. wool/_mempool/__init__.py +0 -4
  28. wool/_mempool/_mempool.py +0 -311
  29. wool/_mempool/_metadata.py +0 -39
  30. wool/_mempool/_service.py +0 -225
  31. wool/_pool.py +0 -524
  32. wool/_protobuf/mempool/mempool_pb2.py +0 -66
  33. wool/_protobuf/mempool/mempool_pb2.pyi +0 -108
  34. wool/_protobuf/mempool/mempool_pb2_grpc.py +0 -312
  35. wool/_protobuf/mempool/metadata/metadata_pb2.py +0 -36
  36. wool/_protobuf/mempool/metadata/metadata_pb2.pyi +0 -17
  37. wool/_queue.py +0 -32
  38. wool/_session.py +0 -429
  39. wool/_task.py +0 -366
  40. wool/_utils.py +0 -63
  41. wool-0.1rc8.dist-info/RECORD +0 -28
  42. wool-0.1rc8.dist-info/entry_points.txt +0 -2
  43. {wool-0.1rc8.dist-info → wool-0.1rc10.dist-info}/WHEEL +0 -0
@@ -0,0 +1,1223 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import hashlib
5
+ import json
6
+ import multiprocessing.shared_memory
7
+ import socket
8
+ import struct
9
+ from abc import ABC
10
+ from abc import abstractmethod
11
+ from collections import deque
12
+ from dataclasses import dataclass
13
+ from dataclasses import field
14
+ from typing import TYPE_CHECKING
15
+ from typing import Any
16
+ from typing import AsyncContextManager
17
+ from typing import AsyncIterator
18
+ from typing import Awaitable
19
+ from typing import Callable
20
+ from typing import ContextManager
21
+ from typing import Deque
22
+ from typing import Dict
23
+ from typing import Generic
24
+ from typing import Literal
25
+ from typing import Protocol
26
+ from typing import Tuple
27
+ from typing import TypeAlias
28
+ from typing import TypeVar
29
+ from typing import final
30
+ from typing import runtime_checkable
31
+
32
+ if TYPE_CHECKING:
33
+ pass
34
+
35
+ from zeroconf import IPVersion
36
+ from zeroconf import ServiceInfo
37
+ from zeroconf import ServiceListener
38
+ from zeroconf import Zeroconf
39
+ from zeroconf.asyncio import AsyncServiceBrowser
40
+ from zeroconf.asyncio import AsyncZeroconf
41
+
42
+ if TYPE_CHECKING:
43
+ pass
44
+
45
+
46
# public
@dataclass
class WorkerInfo:
    """Properties and metadata for a worker instance.

    Contains identifying information and capabilities of a worker that
    can be used for discovery, filtering, and routing decisions.

    :param uid:
        Unique identifier for the worker instance.
    :param host:
        Network host address where the worker is accessible.
    :param port:
        Network port number where the worker is listening.
    :param pid:
        Process ID of the worker.
    :param version:
        Version string of the worker software.
    :param tags:
        Set of capability tags for worker filtering and selection.
    :param extra:
        Additional arbitrary metadata as key-value pairs.
    """

    uid: str
    host: str
    port: int | None
    pid: int
    version: str
    tags: set[str] = field(default_factory=set)
    extra: dict[str, Any] = field(default_factory=dict)

    def __hash__(self) -> int:
        # Hash by uid only so instances stay usable as dict keys / set
        # members even though tags/extra are mutable.  The dataclass-
        # generated __eq__ still compares all fields; equal instances share
        # a uid, so the hash/eq contract holds.
        return hash(self.uid)
80
+
81
+
82
+ # public
83
+ DiscoveryEventType: TypeAlias = Literal[
84
+ "worker_added", "worker_removed", "worker_updated"
85
+ ]
86
+
87
+ _T = TypeVar("_T")
88
+ PredicateFunction: TypeAlias = Callable[[_T], bool]
89
+
90
+
91
class PredicatedQueue(Generic[_T]):
    """An asyncio queue that supports predicated gets.

    Items can be retrieved only if they match a predicate function.
    Non-matching items remain in the queue for future gets. This allows
    selective consumption from a shared queue based on item properties.

    .. note::
        Like :class:`asyncio.Queue`, this class is not thread-safe and is
        intended for use from a single event loop.

    .. warning::
        ``None`` is used internally as the "no matching item" sentinel
        (see :meth:`_get_matching_item`), so ``None`` itself cannot be
        stored as a queue item.

    :param maxsize:
        Maximum number of items in the queue (0 for unlimited).
    """

    _maxsize: int
    # FIFO buffer of enqueued items.
    _queue: Deque[_T]
    # Consumers awaiting an item: (future to resolve, optional predicate).
    _getters: Deque[Tuple[asyncio.Future[_T], PredicateFunction[_T] | None]]
    # Producers blocked on a full queue: (future, item pending insertion).
    _putters: Deque[Tuple[asyncio.Future[None], _T]]
    # Items put but not yet acknowledged via task_done().
    _unfinished_tasks: int
    # Set whenever _unfinished_tasks hits zero; awaited by join().
    _finished: asyncio.Event

    def __init__(self, maxsize: int = 0):
        self._maxsize = maxsize
        self._queue: Deque[_T] = deque()
        self._getters: Deque[Tuple[asyncio.Future[_T], PredicateFunction[_T] | None]] = (
            deque()
        )
        self._putters: Deque[Tuple[asyncio.Future[None], _T]] = deque()
        self._unfinished_tasks = 0
        self._finished = asyncio.Event()
        # Starts "finished": no unacknowledged items yet.
        self._finished.set()

    def qsize(self) -> int:
        """Number of items in the queue."""
        return len(self._queue)

    def empty(self) -> bool:
        """Return True if the queue is empty, False otherwise."""
        return not self._queue

    def full(self) -> bool:
        """Return True if there are maxsize items in the queue."""
        if self._maxsize <= 0:
            # maxsize 0 (or negative) means unbounded.
            return False
        return self.qsize() >= self._maxsize

    async def put(self, item: _T) -> None:
        """Put an item into the queue.

        If the queue is full, wait until a free slot is available.
        """
        while self.full():
            putter_future: asyncio.Future[None] = asyncio.Future()
            self._putters.append((putter_future, item))
            try:
                await putter_future
                # When the future resolves, _wakeup_next_putter has already
                # inserted our item on our behalf — nothing more to do.
                return
            except asyncio.CancelledError:
                putter_future.cancel()
                try:
                    # Withdraw our pending entry so it is never inserted.
                    self._putters.remove((putter_future, item))
                except ValueError:
                    pass
                raise

        self._put_nowait(item)

    def put_nowait(self, item: _T) -> None:
        """Put an item into the queue without blocking.

        :raises QueueFull:
            If no free slot is immediately available.
        """
        if self.full():
            raise asyncio.QueueFull
        self._put_nowait(item)

    def _put_nowait(self, item: _T) -> None:
        """Internal method to put item without capacity checks."""
        self._queue.append(item)
        self._unfinished_tasks += 1
        self._finished.clear()
        # Offer the freshly enqueued item to any waiting getter whose
        # predicate accepts it.
        self._wakeup_next_getter(item)

    async def get(self, predicate: PredicateFunction[_T] | None = None) -> _T:
        """Remove and return an item from the queue that matches the predicate.

        If predicate is None, return the first available item. If predicate is
        provided, return the first item that makes predicate(item) True. Items
        that don't match the predicate remain in the queue.

        If no matching item is available, wait until one becomes available.

        :param predicate:
            Optional function to filter items.
        :returns:
            Item that matches the predicate.
        """
        while True:
            # Try to find a matching item in the current queue
            item = self._get_matching_item(predicate)
            if item is not None:
                # A slot was freed; let a blocked putter proceed.
                self._wakeup_next_putter()
                return item

            # No matching item found, wait for new items
            getter_future: asyncio.Future[_T] = asyncio.Future()
            self._getters.append((getter_future, predicate))
            try:
                # NOTE(review): when satisfied via _wakeup_next_getter, no
                # putter is woken on this path; blocked putters are only
                # released by later get()/get_nowait() calls — confirm
                # intended.
                return await getter_future
            except asyncio.CancelledError:
                getter_future.cancel()
                try:
                    self._getters.remove((getter_future, predicate))
                except ValueError:
                    pass
                raise

    def get_nowait(self, predicate: PredicateFunction[_T] | None = None) -> _T:
        """Remove and return an item immediately that matches the predicate.

        :param predicate:
            Optional function to filter items.
        :returns:
            Item that matches the predicate.
        :raises QueueEmpty:
            If no matching item is immediately available.
        """
        item = self._get_matching_item(predicate)
        if item is None:
            raise asyncio.QueueEmpty
        self._wakeup_next_putter()
        return item

    def _get_matching_item(self, predicate: PredicateFunction[_T] | None) -> _T | None:
        """Find and remove the first item that matches the predicate."""
        if not self._queue:
            return None

        if predicate is None:
            # No predicate, return first available item
            return self._queue.popleft()

        # Search for matching item
        for i, item in enumerate(self._queue):
            if predicate(item):
                # Found matching item, remove it from queue
                del self._queue[i]
                return item

        return None

    def _wakeup_next_getter(self, item: _T) -> None:
        """Try to satisfy waiting getters with the new item."""
        remaining_getters = deque()
        item_consumed = False

        while self._getters and not item_consumed:
            getter_future, getter_predicate = self._getters.popleft()

            if getter_future.done():
                # Getter was cancelled/abandoned; drop it entirely.
                continue

            # Check if this getter's predicate matches the item
            if getter_predicate is None or getter_predicate(item):
                # This getter can take the item
                # Try to remove the item from queue and satisfy the getter
                try:
                    # deque.remove drops the first equal element; for
                    # duplicate items any equal instance is equivalent.
                    self._queue.remove(item)
                    getter_future.set_result(item)
                    item_consumed = True
                    # Item was successfully given to a getter, we're done
                    break
                except ValueError:
                    # Item was already taken by another operation
                    # Continue to next getter, but don't try to give them this item
                    remaining_getters.append((getter_future, getter_predicate))
            else:
                # This getter's predicate doesn't match, keep waiting
                remaining_getters.append((getter_future, getter_predicate))

        # Restore getters that couldn't be satisfied, preserving their
        # original FIFO order at the front of the deque.
        self._getters.extendleft(reversed(remaining_getters))

    def _wakeup_next_putter(self) -> None:
        """Wake up the next putter if there's space."""
        while self._putters and not self.full():
            putter_future, item = self._putters.popleft()
            if not putter_future.done():
                # Insert the putter's item for it, then release it.
                self._put_nowait(item)
                putter_future.set_result(None)
                break

    def task_done(self) -> None:
        """Indicate that a formerly enqueued task is complete."""
        if self._unfinished_tasks <= 0:
            raise ValueError("task_done() called too many times")
        self._unfinished_tasks -= 1
        if self._unfinished_tasks == 0:
            self._finished.set()

    async def join(self) -> None:
        """Wait until all items in the queue have been gotten and completed."""
        await self._finished.wait()
292
+
293
+
294
# public
@dataclass
class DiscoveryEvent:
    """Represents a worker service discovery event.

    Contains information about worker service lifecycle events (added,
    updated, removed) including both pre- and post-event property states to
    enable comprehensive event handling.

    :param type:
        Type of discovery event (added, updated, or removed).
    :param worker_info:
        The :class:`~wool._worker_discovery.WorkerInfo` instance associated with this
        event.
    """

    type: DiscoveryEventType
    worker_info: WorkerInfo
312
+
313
+
314
# Covariant: the protocols below only ever produce _T_co values.
_T_co = TypeVar("_T_co", covariant=True)


# public
class Reducible(Protocol):
    """Protocol for objects that support pickling via __reduce__."""

    def __reduce__(self) -> tuple: ...
322
+
323
+
324
# public
class ReducibleAsyncIteratorLike(Reducible, Protocol, Generic[_T_co]):
    """Protocol for async iterators that yield discovery events.

    Implementations must be pickleable via __reduce__ to support
    task-specific session contexts in distributed environments.
    """

    # Returns itself, per the async-iterator protocol.
    def __aiter__(self) -> ReducibleAsyncIteratorLike[_T_co]: ...

    # Resolves to the next yielded value (or raises StopAsyncIteration).
    def __anext__(self) -> Awaitable[_T_co]: ...
335
+
336
+
337
# public
@runtime_checkable
class Factory(Protocol, Generic[_T_co]):
    """Zero-argument factory protocol.

    A factory may produce its value directly, as an awaitable, or wrapped
    in a (sync or async) context manager; callers are expected to unwrap
    whichever form they receive.
    """

    def __call__(
        self,
    ) -> (
        _T_co | Awaitable[_T_co] | AsyncContextManager[_T_co] | ContextManager[_T_co]
    ): ...
345
+
346
+
347
+ # public
348
+ class DiscoveryService(ABC):
349
+ """Abstract base class for discovering worker services.
350
+
351
+ When started, implementations should discover all existing services that
352
+ satisfy the specified filter and deliver worker-added events for each.
353
+ Subsequently, they should monitor for newly added, updated, or removed
354
+ workers and deliver appropriate events via the :meth:`events` method.
355
+
356
+ Service tracking behavior:
357
+ - Only workers satisfying the filter should be tracked
358
+ - Workers updated to satisfy the filter should trigger worker-added events
359
+ - Workers updated to no longer satisfy the filter should trigger
360
+ worker-removed events
361
+ - Tracked workers removed from the registry entirely should always
362
+ trigger worker-removed
363
+
364
+ :param filter:
365
+ Optional filter function to select which discovery events to yield.
366
+ Only events matching the filter will be delivered.
367
+
368
+ .. warning::
369
+ The discovery procedure should not block continuously, as it will be
370
+ executed in the current event loop.
371
+
372
+ .. note::
373
+ Implementations must be pickleable and provide an unstarted copy
374
+ when unpickled to support task-specific session contexts.
375
+ """
376
+
377
+ _started: bool
378
+ _service_cache: Dict[str, WorkerInfo]
379
+
380
+ def __init__(
381
+ self,
382
+ filter: PredicateFunction[WorkerInfo] | None = None,
383
+ ):
384
+ self._filter = filter
385
+ self._started = False
386
+ self._service_cache = {}
387
+
388
+ def __reduce__(self) -> tuple:
389
+ """Return constructor args for unpickling an unstarted service copy."""
390
+ return (self.__class__, (self._filter,))
391
+
392
+ def __aiter__(self) -> AsyncIterator[DiscoveryEvent]:
393
+ """Returns self as an async iterator."""
394
+ return self.events()
395
+
396
+ async def __anext__(self) -> DiscoveryEvent:
397
+ """Delegate to the events async iterator."""
398
+ return await anext(self.events())
399
+
400
+ @final
401
+ @property
402
+ def started(self) -> bool:
403
+ return self._started
404
+
405
+ @final
406
+ async def start(self) -> None:
407
+ """Starts the worker discovery procedure.
408
+
409
+ This method should initiate the discovery process, which may involve
410
+ network operations or other asynchronous tasks.
411
+
412
+ :raises RuntimeError:
413
+ If the service has already been started.
414
+ """
415
+ if self._started:
416
+ raise RuntimeError("Discovery service already started")
417
+ await self._start()
418
+ self._started = True
419
+
420
+ @final
421
+ async def stop(self) -> None:
422
+ """Stops the worker discovery procedure.
423
+
424
+ This method should clean up any resources used for discovery, such as
425
+ network connections or event listeners.
426
+
427
+ :raises RuntimeError:
428
+ If the service has not been started.
429
+ """
430
+ if not self._started:
431
+ raise RuntimeError("Discovery service not started")
432
+ await self._stop()
433
+
434
+ @abstractmethod
435
+ def events(self) -> AsyncIterator[DiscoveryEvent]:
436
+ """Yields discovery events as they occur.
437
+
438
+ Returns an asynchronous iterator that yields discovery events for
439
+ workers being added, updated, or removed from the registry. Events
440
+ are filtered according to the filter function provided during
441
+ initialization.
442
+
443
+ :yields:
444
+ Instances of :class:`DiscoveryEvent` representing worker
445
+ additions, updates, and removals.
446
+ """
447
+ ...
448
+
449
+ @abstractmethod
450
+ async def _start(self) -> None:
451
+ """Starts the worker discovery procedure."""
452
+ ...
453
+
454
+ @abstractmethod
455
+ async def _stop(self) -> None:
456
+ """Stops the worker discovery procedure."""
457
+ ...
458
+
459
+
460
# TypeVar linking a registry to the discovery-service flavor that finds it.
_T_DiscoveryServiceLike = TypeVar("_T_DiscoveryServiceLike", bound=DiscoveryService)


# public
class RegistryServiceLike(Protocol):
    """Structural protocol for a service where workers can register themselves.

    Describes the interface for worker registration, unregistration, and
    property updates within a distributed worker pool system.
    """

    async def start(self) -> None: ...

    async def stop(self) -> None: ...

    async def register(self, worker_info: WorkerInfo) -> None: ...

    async def unregister(self, worker_info: WorkerInfo) -> None: ...

    async def update(self, worker_info: WorkerInfo) -> None: ...
480
+
481
+
482
# public
class RegistryService(Generic[_T_DiscoveryServiceLike], ABC):
    """Abstract base class for a service where workers can register themselves.

    Provides the interface for worker registration, unregistration, and
    property updates within a distributed worker pool system.
    """

    _started: bool
    _stopped: bool

    def __init__(self):
        self._started = False
        self._stopped = False

    def _ensure_active(self) -> None:
        """Raise unless the service has been started and not yet stopped.

        Shared guard for :meth:`register`, :meth:`unregister`, and
        :meth:`update` (previously triplicated inline).

        :raises RuntimeError:
            If the registry service is not running.
        """
        if not self._started:
            raise RuntimeError("Registry service not started - call start() first")
        if self._stopped:
            raise RuntimeError("Registry service already stopped")

    async def start(self) -> None:
        """Starts the registry service, making it ready to accept registrations.

        :raises RuntimeError:
            If the service has already been started.
        """
        if self._started:
            raise RuntimeError("Registry service already started")
        # Bound startup time so a hung implementation cannot block forever.
        await asyncio.wait_for(self._start(), timeout=60)
        self._started = True

    async def stop(self) -> None:
        """Stops the registry service and cleans up any resources.

        Stopping twice is a no-op; stopping before starting raises.

        :raises RuntimeError:
            If the service has not been started.
        """
        if self._stopped:
            return
        if not self._started:
            raise RuntimeError("Registry service not started")
        await self._stop()
        self._stopped = True

    @abstractmethod
    async def _start(self) -> None:
        """Starts the registry service, making it ready to accept registrations."""
        ...

    @abstractmethod
    async def _stop(self) -> None:
        """Stops the registry service and cleans up any resources."""
        ...

    async def register(
        self,
        worker_info: WorkerInfo,
    ) -> None:
        """Registers a worker by publishing its service information.

        :param worker_info:
            The :class:`~wool._worker_discovery.WorkerInfo` instance containing all
            worker details.
        :raises RuntimeError:
            If the registry service is not running.
        """
        self._ensure_active()
        await self._register(worker_info)

    @abstractmethod
    async def _register(
        self,
        worker_info: WorkerInfo,
    ) -> None:
        """Implementation-specific worker registration.

        :param worker_info:
            The :class:`~wool._worker_discovery.WorkerInfo` instance containing all
            worker details.
        """
        ...

    async def unregister(self, worker_info: WorkerInfo) -> None:
        """Unregisters a worker by removing its service record.

        :param worker_info:
            The :class:`~wool._worker_discovery.WorkerInfo` instance of the worker to
            unregister.
        :raises RuntimeError:
            If the registry service is not running.
        """
        self._ensure_active()
        await self._unregister(worker_info)

    @abstractmethod
    async def _unregister(self, worker_info: WorkerInfo) -> None:
        """Implementation-specific worker unregistration.

        :param worker_info:
            The :class:`~wool._worker_discovery.WorkerInfo` instance of the worker to
            unregister.
        """
        ...

    async def update(self, worker_info: WorkerInfo) -> None:
        """Updates a worker's properties if they have changed.

        :param worker_info:
            The updated :class:`~wool._worker_discovery.WorkerInfo` instance.
        :raises RuntimeError:
            If the registry service is not running.
        """
        self._ensure_active()
        await self._update(worker_info)

    @abstractmethod
    async def _update(self, worker_info: WorkerInfo) -> None:
        """Implementation-specific worker property updates.

        :param worker_info:
            The updated :class:`~wool._worker_discovery.WorkerInfo` instance.
        """
        ...
609
+
610
+
611
# public
class LanDiscoveryService(DiscoveryService):
    """Implements worker discovery on the local network using Zeroconf.

    This service browses the local network for DNS-SD services and delivers
    all worker service events to clients. Uses Zeroconf/Bonjour protocol
    for automatic service discovery without requiring central coordination.

    :param filter:
        Optional predicate function to filter discovered workers.
    """

    aiozc: AsyncZeroconf
    browser: AsyncServiceBrowser
    service_type: Literal["_wool._tcp.local."] = "_wool._tcp.local."
    _event_queue: PredicatedQueue[DiscoveryEvent]

    def __init__(
        self,
        filter: PredicateFunction[WorkerInfo] | None = None,
    ) -> None:
        super().__init__(filter)  # type: ignore[arg-type]
        self._event_queue = PredicatedQueue()

    async def events(self) -> AsyncIterator[DiscoveryEvent]:
        """Returns an async iterator over discovery events.

        Starts the service on first iteration and stops it when the
        iterator is closed or garbage-collected.
        """
        await self.start()
        try:
            while True:
                yield await self._event_queue.get()
        finally:
            await self.stop()

    async def _start(self) -> None:
        """Starts the Zeroconf service browser.

        :raises RuntimeError:
            If the service has already been started.
        """
        # Configure zeroconf to use localhost only to avoid network warnings
        self.aiozc = AsyncZeroconf(interfaces=["127.0.0.1"])
        self.browser = AsyncServiceBrowser(
            self.aiozc.zeroconf,
            self.service_type,
            listener=self._Listener(
                aiozc=self.aiozc,
                event_queue=self._event_queue,
                service_cache=self._service_cache,
                predicate=self._filter or (lambda _: True),
            ),
        )

    async def _stop(self) -> None:
        """Stops the Zeroconf service browser and closes the connection.

        :raises RuntimeError:
            If the service has not been started.
        """
        if self.browser:
            await self.browser.async_cancel()
        if self.aiozc:
            await self.aiozc.async_close()

    class _Listener(ServiceListener):
        """A Zeroconf listener that delivers all worker service events.

        :param aiozc:
            The :class:`~zeroconf.asyncio.AsyncZeroconf` instance to use
            for async service info retrieval.
        :param event_queue:
            Queue to deliver discovery events to.
        :param service_cache:
            Cache to track service properties for pre/post event states.
        :param predicate:
            Function to filter which workers to track.
        """

        aiozc: AsyncZeroconf
        _event_queue: PredicatedQueue[DiscoveryEvent]
        _service_addresses: Dict[str, str]
        _service_cache: Dict[str, WorkerInfo]
        # Strong references to in-flight handler tasks.  The event loop
        # keeps only weak references to tasks, so without this set a
        # handler task could be garbage-collected before completing
        # (per the asyncio.create_task documentation).
        _tasks: set[asyncio.Task]

        def __init__(
            self,
            aiozc: AsyncZeroconf,
            event_queue: PredicatedQueue[DiscoveryEvent],
            predicate: PredicateFunction[WorkerInfo],
            service_cache: Dict[str, WorkerInfo],
        ) -> None:
            self.aiozc = aiozc
            self._event_queue = event_queue
            self._predicate = predicate
            self._service_addresses = {}
            self._service_cache = service_cache
            self._tasks = set()

        def _spawn(self, coro) -> None:
            """Schedule *coro*, retaining a strong reference until it finishes."""
            task = asyncio.create_task(coro)
            self._tasks.add(task)
            task.add_done_callback(self._tasks.discard)

        def add_service(self, zc: Zeroconf, type_: str, name: str):
            """Called by Zeroconf when a service is added."""
            if type_ == LanRegistryService.service_type:
                self._spawn(self._handle_add_service(type_, name))

        def remove_service(self, zc: Zeroconf, type_: str, name: str):
            """Called by Zeroconf when a service is removed."""
            if type_ == LanRegistryService.service_type:
                # Only tracked (filter-satisfying) workers produce removal
                # events; unknown names are ignored.
                if worker := self._service_cache.pop(name, None):
                    self._spawn(
                        self._event_queue.put(
                            DiscoveryEvent(type="worker_removed", worker_info=worker)
                        )
                    )

        def update_service(self, zc: Zeroconf, type_, name):
            """Called by Zeroconf when a service is updated."""
            if type_ == LanRegistryService.service_type:
                self._spawn(self._handle_update_service(type_, name))

        async def _handle_add_service(self, type_: str, name: str):
            """Async handler for service addition."""
            try:
                if not (
                    service_info := await self.aiozc.async_get_service_info(type_, name)
                ):
                    return

                try:
                    worker_info = _deserialize_worker_info(service_info)
                except ValueError:
                    # Record lacks required worker properties; not ours.
                    return

                if self._predicate(worker_info):
                    self._service_cache[name] = worker_info
                    event = DiscoveryEvent(type="worker_added", worker_info=worker_info)
                    await self._event_queue.put(event)
            except Exception:
                pass  # Service may have disappeared before we could query it

        async def _handle_update_service(self, type_: str, name: str):
            """Async handler for service update.

            Emits worker_added when an update brings an untracked worker
            into the filter, worker_updated while it stays in, and
            worker_removed when it drops out.
            """
            try:
                if not (
                    service_info := await self.aiozc.async_get_service_info(type_, name)
                ):
                    return

                try:
                    worker_info = _deserialize_worker_info(service_info)
                except ValueError:
                    return

                if name not in self._service_cache:
                    # New worker that wasn't tracked before
                    if self._predicate(worker_info):
                        self._service_cache[name] = worker_info
                        event = DiscoveryEvent(
                            type="worker_added", worker_info=worker_info
                        )
                        await self._event_queue.put(event)
                else:
                    # Existing tracked worker
                    old_worker = self._service_cache[name]
                    if self._predicate(worker_info):
                        # Still satisfies filter, update cache and emit update
                        self._service_cache[name] = worker_info
                        event = DiscoveryEvent(
                            type="worker_updated", worker_info=worker_info
                        )
                        await self._event_queue.put(event)
                    else:
                        # No longer satisfies filter, remove and emit removal
                        del self._service_cache[name]
                        removal_event = DiscoveryEvent(
                            type="worker_removed", worker_info=old_worker
                        )
                        await self._event_queue.put(removal_event)

            except Exception:
                # Best-effort: the service may vanish mid-query.
                pass
787
+
788
+
789
# public
class LanRegistryService(RegistryService[LanDiscoveryService]):
    """Implements a worker registry using Zeroconf to advertise on the LAN.

    This service registers workers by publishing a DNS-SD service record on
    the local network, allowing :class:`LanDiscoveryService` to find them.
    """

    aiozc: AsyncZeroconf | None
    # Registered service records, keyed by worker uid.
    services: Dict[str, ServiceInfo]
    service_type: Literal["_wool._tcp.local."] = "_wool._tcp.local."

    def __init__(self):
        super().__init__()
        self.aiozc = None
        self.services = {}

    async def _start(self) -> None:
        """Initializes and starts the Zeroconf instance for advertising."""
        # Configure zeroconf to use localhost only to avoid network warnings
        self.aiozc = AsyncZeroconf(interfaces=["127.0.0.1"])

    async def _stop(self) -> None:
        """Stops the Zeroconf instance and cleans up all registered services."""
        if self.aiozc:
            await self.aiozc.async_close()
            self.aiozc = None

    async def _register(
        self,
        worker_info: WorkerInfo,
    ) -> None:
        """Registers a worker by publishing its service information via Zeroconf.

        :param worker_info:
            The :class:`~wool._worker_discovery.WorkerInfo` instance containing all
            worker details.
        :raises RuntimeError:
            If the registry service is not properly initialized.
        """
        if self.aiozc is None:
            raise RuntimeError("Registry service not properly initialized")
        # NOTE(review): worker_info.port may be None per its annotation;
        # that would make int(port) fail in _resolve_address — confirm
        # callers always register with a concrete port.
        address = f"{worker_info.host}:{worker_info.port}"
        ip_address, port = self._resolve_address(address)
        service_name = f"{worker_info.uid}.{self.service_type}"
        service_info = ServiceInfo(
            self.service_type,
            service_name,
            addresses=[ip_address],
            port=port,
            properties=_serialize_worker_info(worker_info),
        )
        self.services[worker_info.uid] = service_info
        await self.aiozc.async_register_service(service_info)

    async def _unregister(self, worker_info: WorkerInfo) -> None:
        """Unregisters a worker by removing its Zeroconf service record.

        :param worker_info:
            The :class:`~wool._worker_discovery.WorkerInfo` instance of the worker to
            unregister.
        :raises RuntimeError:
            If the registry service is not properly initialized.
        """
        if self.aiozc is None:
            raise RuntimeError("Registry service not properly initialized")
        service = self.services[worker_info.uid]
        await self.aiozc.async_unregister_service(service)
        del self.services[worker_info.uid]

    async def _update(self, worker_info: WorkerInfo) -> None:
        """Updates a worker's properties if they have changed.

        Updates both the Zeroconf service and local cache atomically.
        If the Zeroconf update fails, the local cache remains unchanged
        to maintain consistency.

        :param worker_info:
            The updated :class:`~wool._worker_discovery.WorkerInfo` instance.
        :raises RuntimeError:
            If the registry service is not properly initialized.
        :raises Exception:
            If the Zeroconf service update fails.
        """
        if self.aiozc is None:
            raise RuntimeError("Registry service not properly initialized")

        service = self.services[worker_info.uid]
        new_properties = _serialize_worker_info(worker_info)

        # Skip the network round-trip when nothing actually changed.
        if service.decoded_properties != new_properties:
            updated_service = ServiceInfo(
                service.type,
                service.name,
                addresses=service.addresses,
                port=service.port,
                properties=new_properties,
                server=service.server,
            )
            await self.aiozc.async_update_service(updated_service)
            # Only cache after the Zeroconf update succeeds.
            self.services[worker_info.uid] = updated_service

    def _resolve_address(self, address: str) -> Tuple[bytes, int]:
        """Resolve an address string to bytes and validate port.

        :param address:
            Address in format "host:port".
        :returns:
            Tuple of (IPv4/IPv6 address as bytes, port as int).
        :raises ValueError:
            If address format is invalid or port is out of range.
        """
        # Fix: split on the LAST colon so IPv6 literals (which contain
        # colons, e.g. "::1:8080") parse correctly; the previous
        # `address.split(":")` raised for any IPv6 host despite the
        # AF_INET6 branch below.
        host, sep, port_str = address.rpartition(":")
        if not sep or not host:
            raise ValueError(f"Invalid address format: {address!r}")
        port = int(port_str)

        try:
            return socket.inet_pton(socket.AF_INET, host), port
        except OSError:
            pass

        try:
            return socket.inet_pton(socket.AF_INET6, host), port
        except OSError:
            pass

        # Fall back to DNS resolution (IPv4 only).
        return socket.inet_aton(socket.gethostbyname(host)), port
915
+
916
+
917
+ def _serialize_worker_info(
918
+ info: WorkerInfo,
919
+ ) -> dict[str, str | None]:
920
+ """Serialize WorkerInfo to a flat dict for ServiceInfo.properties.
921
+
922
+ :param info:
923
+ :class:`~wool._worker_discovery.WorkerInfo` instance to serialize.
924
+ :returns:
925
+ Flat dict with pid, version, tags (JSON), extra (JSON).
926
+ """
927
+ properties = {
928
+ "pid": str(info.pid),
929
+ "version": info.version,
930
+ "tags": (json.dumps(list(info.tags)) if info.tags else None),
931
+ "extra": (json.dumps(info.extra) if info.extra else None),
932
+ }
933
+ return properties
934
+
935
+
936
def _deserialize_worker_info(info: ServiceInfo) -> WorkerInfo:
    """Rebuild a WorkerInfo from a ServiceInfo's decoded properties.

    :param info:
        ServiceInfo with decoded properties dict (str keys/values).
    :returns:
        :class:`~wool._worker_discovery.WorkerInfo` instance.
    :raises ValueError:
        If required fields are missing or invalid JSON.
    """
    properties = info.decoded_properties
    # A required field counts as missing when absent OR empty/None-valued.
    present = {key for key, value in properties.items() if value}
    if absent := {"pid", "version"} - present:
        raise ValueError(f"Missing required properties: {', '.join(absent)}")
    raw_tags = properties.get("tags")
    raw_extra = properties.get("extra")
    return WorkerInfo(
        uid=info.name,
        pid=int(properties["pid"]),
        # Advertised over IPv4 only; take the first V4 address as the host.
        host=str(info.ip_addresses_by_version(IPVersion.V4Only)[0]),
        port=info.port,
        version=properties["version"],
        tags=set(json.loads(raw_tags)) if raw_tags else set(),
        extra=json.loads(raw_extra) if raw_extra else {},
    )
971
+
972
+
973
+ # public
974
class LocalRegistryService(RegistryService):
    """Implements a worker registry using shared memory for local pools.

    This service registers workers by writing their information to a shared
    memory block, allowing LocalDiscoveryService instances to find them
    efficiently. The registry stores worker ports as 4-byte unsigned integers
    in a simple array of slots, providing fast local discovery without network
    overhead. Only the port is stored; richer worker metadata is not available
    to consumers of this registry.

    :param uri:
        Unique identifier for the shared memory segment.
    """

    # Shared memory handle; None until `_start` attaches/creates the segment.
    _shared_memory: multiprocessing.shared_memory.SharedMemory | None = None
    # URI from which the segment name is deterministically derived.
    _uri: str
    # True when this instance created (and therefore owns/unlinks) the segment.
    _created_shared_memory: bool = False

    def __init__(self, uri: str):
        super().__init__()
        self._uri = uri
        self._created_shared_memory = False

    async def _start(self) -> None:
        """Attach to (or create) the shared memory registration segment."""
        if self._shared_memory is None:
            # The segment name is derived deterministically from the URI so
            # registries and discovery services agree on it across processes.
            shared_memory_name = hashlib.sha256(self._uri.encode()).hexdigest()[:12]
            try:
                # Prefer attaching to an existing segment...
                self._shared_memory = multiprocessing.shared_memory.SharedMemory(
                    name=shared_memory_name
                )
            except FileNotFoundError:
                # ...and create it if no other process has done so yet.
                self._shared_memory = multiprocessing.shared_memory.SharedMemory(
                    name=shared_memory_name,
                    create=True,
                    size=1024,  # 1024 bytes = 256 worker slots (4 bytes per port)
                )
                self._created_shared_memory = True
                # Zero the whole buffer so every slot reads as empty.
                self._shared_memory.buf[:] = bytes(len(self._shared_memory.buf))

    async def _stop(self) -> None:
        """Release the shared memory, unlinking it if this registry owns it."""
        if self._shared_memory:
            try:
                self._shared_memory.close()
                # Only the creator unlinks, so segments attached by other
                # processes are not destroyed out from under them.
                if self._created_shared_memory:
                    self._shared_memory.unlink()
            except Exception:
                # Best-effort cleanup; the OS reclaims the segment eventually.
                pass
            self._shared_memory = None
            self._created_shared_memory = False

    async def _register(self, worker_info: WorkerInfo) -> None:
        """Register a worker by writing its port into a free slot.

        Registration is idempotent: if the worker's port is already present,
        no additional slot is consumed. This prevents repeated `_update`
        calls from duplicating entries and leaking slots.

        :param worker_info:
            The :class:`~wool._worker_discovery.WorkerInfo` instance containing
            all worker details. Only the port is stored in shared memory.
        :raises RuntimeError:
            If the registry service is not properly initialized, or if no
            free slot remains.
        :raises ValueError:
            If the worker has no port.
        """
        if self._shared_memory is None:
            raise RuntimeError("Registry service not properly initialized")

        if worker_info.port is None:
            raise ValueError("Worker port must be specified")

        buf = self._shared_memory.buf
        free_slot = None
        for offset in range(0, len(buf), 4):
            (slot_port,) = struct.unpack_from("I", buf, offset)
            if slot_port == worker_info.port:
                # Already registered; keep registration idempotent.
                return
            if slot_port == 0 and free_slot is None:
                free_slot = offset  # remember the first empty slot
        if free_slot is None:
            raise RuntimeError("No available slots in shared memory registry")
        struct.pack_into("I", buf, free_slot, worker_info.port)

    async def _unregister(self, worker_info: WorkerInfo) -> None:
        """Unregister a worker by clearing its port slot in shared memory.

        :param worker_info:
            The :class:`~wool._worker_discovery.WorkerInfo` instance of the
            worker to unregister.
        :raises RuntimeError:
            If the registry service is not properly initialized.
        """
        if self._shared_memory is None:
            raise RuntimeError("Registry service not properly initialized")

        if worker_info.port is None:
            # A portless worker was never registered; nothing to clear.
            return

        buf = self._shared_memory.buf
        for offset in range(0, len(buf), 4):
            (slot_port,) = struct.unpack_from("I", buf, offset)
            if slot_port == worker_info.port:
                struct.pack_into("I", buf, offset, 0)  # Clear slot
                break

    async def _update(self, worker_info: WorkerInfo) -> None:
        """Update a worker's registration in shared memory.

        Only the port is stored, so an update is simply an idempotent
        re-registration of the (possibly changed) port.

        :param worker_info:
            The updated :class:`~wool._worker_discovery.WorkerInfo` instance.
        """
        await self._register(worker_info)
1084
+
1085
+
1086
+ # public
1087
+ class LocalDiscoveryService(DiscoveryService):
1088
+ """Implements worker discovery using shared memory for local pools.
1089
+
1090
+ This service reads worker ports from a shared memory block and
1091
+ constructs WorkerInfo instances with localhost as the implied host.
1092
+ Provides efficient local discovery for single-machine worker pools.
1093
+
1094
+ :param uri:
1095
+ Unique identifier for the shared memory segment.
1096
+ :param filter:
1097
+ Optional predicate function to filter discovered workers.
1098
+ """
1099
+
1100
+ _shared_memory: multiprocessing.shared_memory.SharedMemory | None = None
1101
+ _uri: str
1102
+ _event_queue: asyncio.Queue[DiscoveryEvent]
1103
+ _monitor_task: asyncio.Task | None
1104
+ _stop_event: asyncio.Event
1105
+
1106
+ def __init__(
1107
+ self,
1108
+ uri: str,
1109
+ filter: PredicateFunction[WorkerInfo] | None = None,
1110
+ ) -> None:
1111
+ super().__init__(filter) # type: ignore[arg-type]
1112
+ self._uri = uri
1113
+ self._event_queue = asyncio.Queue()
1114
+ self._monitor_task = None
1115
+ self._stop_event = asyncio.Event()
1116
+
1117
+ def __reduce__(self) -> tuple:
1118
+ """Return constructor args for unpickling an unstarted service copy."""
1119
+ func, args = super().__reduce__()
1120
+ return (func, (self._uri, *args))
1121
+
1122
+ async def events(self) -> AsyncIterator[DiscoveryEvent]:
1123
+ """Returns an async iterator over discovery events."""
1124
+ await self.start()
1125
+ try:
1126
+ while True:
1127
+ event = await self._event_queue.get()
1128
+ yield event
1129
+ finally:
1130
+ await self.stop()
1131
+
1132
+ async def _start(self) -> None:
1133
+ """Starts monitoring shared memory for worker registrations."""
1134
+ if self._shared_memory is None:
1135
+ # Try to connect to existing shared memory first
1136
+ self._shared_memory = multiprocessing.shared_memory.SharedMemory(
1137
+ name=hashlib.sha256(self._uri.encode()).hexdigest()[:12]
1138
+ )
1139
+
1140
+ # Start monitoring task
1141
+ self._monitor_task = asyncio.create_task(self._monitor_shared_memory())
1142
+
1143
+ async def _stop(self) -> None:
1144
+ """Stops monitoring shared memory."""
1145
+ self._stop_event.set()
1146
+ if self._monitor_task:
1147
+ try:
1148
+ await self._monitor_task
1149
+ except asyncio.CancelledError:
1150
+ pass
1151
+ if self._shared_memory:
1152
+ try:
1153
+ self._shared_memory.close()
1154
+ except Exception:
1155
+ pass
1156
+ self._shared_memory = None
1157
+
1158
    async def _monitor_shared_memory(self) -> None:
        """Monitor shared memory for changes and emit events.

        Polls the segment every ``poll_interval`` seconds, reconstructs the
        set of active workers from the registered ports, and delegates to
        :meth:`_detect_changes` to emit add/remove/update events. Runs until
        ``_stop_event`` is set.
        """
        poll_interval = 0.1  # seconds between shared-memory scans

        while not self._stop_event.is_set():
            try:
                current_workers = {}

                # Read current state from shared memory: each 4-byte slot
                # holds one registered worker port; 0 means an empty slot.
                if self._shared_memory:
                    for i in range(0, len(self._shared_memory.buf), 4):
                        port = struct.unpack("I", self._shared_memory.buf[i : i + 4])[0]
                        if port > 0:  # Active worker
                            # Only the port is stored in this registry, so the
                            # host is implied to be localhost and the remaining
                            # WorkerInfo fields are filled with placeholders.
                            worker_info = WorkerInfo(
                                uid=f"worker-{port}",
                                host="localhost",
                                port=port,
                                pid=0,  # Not available in simple registry
                                version="unknown",  # Not available in simple registry
                                tags=set(),
                                extra={},
                            )
                            if self._filter is None or self._filter(worker_info):
                                current_workers[worker_info.uid] = worker_info

                # Detect changes against the cached view and emit events.
                await self._detect_changes(current_workers)

                # Wait before next poll; waiting on the stop event doubles as
                # a responsive sleep that ends immediately on shutdown.
                try:
                    await asyncio.wait_for(
                        self._stop_event.wait(), timeout=poll_interval
                    )
                    break  # Stop event was set
                except asyncio.TimeoutError:
                    continue  # Continue polling

            except Exception:
                # Best-effort polling: swallow transient errors (e.g. the
                # segment briefly unavailable) and retry on the next cycle.
                continue
1197
+
1198
    async def _detect_changes(self, current_workers: Dict[str, WorkerInfo]) -> None:
        """Detect and emit events for worker changes.

        Compares ``current_workers`` (the freshly scanned shared-memory state)
        against ``self._service_cache``, pushing a DiscoveryEvent for each
        worker that was added, removed, or whose port changed, and keeping
        the cache in sync as it goes.

        :param current_workers:
            Mapping of worker uid to the WorkerInfo observed this poll cycle.
        """
        # Find added workers
        for uid, worker_info in current_workers.items():
            if uid not in self._service_cache:
                self._service_cache[uid] = worker_info
                event = DiscoveryEvent(type="worker_added", worker_info=worker_info)
                await self._event_queue.put(event)

        # Find removed workers; iterate a snapshot of the keys because the
        # cache is mutated (pop) inside the loop.
        for uid in list(self._service_cache.keys()):
            if uid not in current_workers:
                worker_info = self._service_cache.pop(uid)
                event = DiscoveryEvent(type="worker_removed", worker_info=worker_info)
                await self._event_queue.put(event)

        # Find updated workers (minimal for port-only registry: the port is
        # the only observable property that can change).
        for uid, worker_info in current_workers.items():
            if uid in self._service_cache:
                old_worker = self._service_cache[uid]
                if worker_info.port != old_worker.port:
                    self._service_cache[uid] = worker_info
                    event = DiscoveryEvent(
                        type="worker_updated", worker_info=worker_info
                    )
                    await self._event_queue.put(event)