wool 0.1rc20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,534 @@
from __future__ import annotations

import asyncio
import json
import logging
import socket
from asyncio import Queue
from types import MappingProxyType
from typing import Any
from typing import AsyncIterator
from typing import Coroutine
from typing import Dict
from typing import Final
from typing import Literal
from typing import Tuple
from uuid import UUID

from zeroconf import IPVersion
from zeroconf import ServiceInfo
from zeroconf import ServiceListener
from zeroconf import Zeroconf
from zeroconf.asyncio import AsyncServiceBrowser
from zeroconf.asyncio import AsyncZeroconf

from wool.core.discovery.base import Discovery
from wool.core.discovery.base import DiscoveryEvent
from wool.core.discovery.base import DiscoveryEventType
from wool.core.discovery.base import DiscoveryPublisherLike
from wool.core.discovery.base import DiscoverySubscriberLike
from wool.core.discovery.base import PredicateFunction
from wool.core.discovery.base import WorkerInfo
29
+
30
+
31
+ # public
32
+ class LanDiscovery(Discovery):
33
+ """Worker discovery on the local network using Zeroconf/Bonjour.
34
+
35
+ Provides network-wide worker discovery using DNS Service Discovery
36
+ (DNS-SD) via the Zeroconf protocol. Workers are automatically
37
+ discovered as they join or leave the network without requiring
38
+ central coordination.
39
+
40
+ The service type "_wool._tcp.local." is used for all Wool worker
41
+ services on the LAN. Publishers advertise workers by registering
42
+ DNS-SD service records, and subscribers browse for these services.
43
+
44
+ Example usage:
45
+
46
+ Publish workers
47
+ .. code-block:: python
48
+ publisher = LanDiscovery.Publisher()
49
+ async with publisher:
50
+ await publisher.publish("worker-added", worker_info)
51
+
52
+ Subscribe to workers
53
+ .. code-block:: python
54
+ discovery = LanDiscovery()
55
+ async for event in discovery.subscriber:
56
+ print(f"Discovered worker: {event.worker_info}")
57
+ """
58
+
59
+ service_type: Literal["_wool._tcp.local."] = "_wool._tcp.local."
60
+
61
+ @property
62
+ def publisher(self) -> DiscoveryPublisherLike:
63
+ """A new publisher instance for this discovery service.
64
+
65
+ :returns:
66
+ A publisher instance for broadcasting worker events.
67
+ """
68
+ return self.Publisher()
69
+
70
+ @property
71
+ def subscriber(self) -> DiscoverySubscriberLike:
72
+ """The default subscriber that receives all worker events.
73
+
74
+ :returns:
75
+ A subscriber instance that receives all worker discovery
76
+ events.
77
+ """
78
+ return self.subscribe()
79
+
80
+ def subscribe(
81
+ self, filter: PredicateFunction | None = None
82
+ ) -> DiscoverySubscriberLike:
83
+ """Create a new subscriber with optional filtering.
84
+
85
+ :param filter:
86
+ Optional predicate function to filter workers. Only workers
87
+ for which the predicate returns True will be included in
88
+ events.
89
+ :returns:
90
+ A subscriber instance that receives filtered worker
91
+ discovery events.
92
+ """
93
+ return self.Subscriber(filter)
94
+
95
+ class Publisher:
96
+ """Publisher for broadcasting worker discovery events.
97
+
98
+ Publishes worker :class:`discovery events <~wool.DiscoveryEvent>`
99
+ by registering and managing DNS-SD service records on the local
100
+ network. Multiple publishers can safely operate on the same
101
+ network, each advertising their own set of workers.
102
+
103
+ Uses AsyncZeroconf for non-blocking service registration and
104
+ management. Services are advertised on localhost (127.0.0.1) to
105
+ avoid network warnings during development.
106
+ """
107
+
108
+ aiozc: AsyncZeroconf | None
109
+ services: Dict[str, ServiceInfo]
110
+ service_type: Literal["_wool._tcp.local."] = "_wool._tcp.local."
111
+
112
+ def __init__(self):
113
+ self.aiozc = None
114
+ self.services = {}
115
+
116
+ async def __aenter__(self):
117
+ """Initialize and start the Zeroconf instance.
118
+
119
+ Configures AsyncZeroconf to use localhost only to avoid
120
+ network warnings during development.
121
+
122
+ :returns:
123
+ Self, for context manager usage.
124
+ """
125
+ # Configure zeroconf to use localhost only
126
+ self.aiozc = AsyncZeroconf(interfaces=["127.0.0.1"])
127
+ return self
128
+
129
+ async def __aexit__(self, *_args):
130
+ """Stop Zeroconf and clean up registered services.
131
+
132
+ Closes the AsyncZeroconf instance and releases all registered
133
+ service records.
134
+ """
135
+ if self.aiozc:
136
+ await self.aiozc.async_close()
137
+ self.aiozc = None
138
+
139
+ async def publish(self, type: DiscoveryEventType, worker_info: WorkerInfo):
140
+ """Publish a worker discovery event.
141
+
142
+ Manages Zeroconf service records based on the event type:
143
+
144
+ - worker-added: Registers a new service record
145
+ - worker-dropped: Unregisters an existing service record
146
+ - worker-updated: Updates an existing service record
147
+
148
+ :param type:
149
+ The type of discovery event.
150
+ :param worker_info:
151
+ Worker information to publish.
152
+ :raises RuntimeError:
153
+ If the publisher is not properly initialized or if an
154
+ unexpected event type is provided.
155
+ """
156
+ if self.aiozc is None:
157
+ raise RuntimeError("Publisher not properly initialized")
158
+
159
+ match type:
160
+ case "worker-added":
161
+ await self._add(worker_info)
162
+ case "worker-dropped":
163
+ await self._drop(worker_info)
164
+ case "worker-updated":
165
+ await self._update(worker_info)
166
+ case _:
167
+ raise RuntimeError(f"Unexpected discovery event type: {type}")
168
+
169
+ async def _add(self, worker_info: WorkerInfo) -> None:
170
+ """Register a worker by publishing its service info.
171
+
172
+ :param worker_info:
173
+ The worker details to publish.
174
+ :raises RuntimeError:
175
+ If the publisher is not properly initialized.
176
+ :raises ValueError:
177
+ If worker port is not specified.
178
+ """
179
+ assert self.aiozc
180
+
181
+ if worker_info.port is None:
182
+ raise ValueError("Worker port must be specified for LAN discovery")
183
+
184
+ address = f"{worker_info.host}:{worker_info.port}"
185
+ ip_address, port = self._resolve_address(address)
186
+ service_name = f"{worker_info.uid}.{self.service_type}"
187
+ service_info = ServiceInfo(
188
+ self.service_type,
189
+ service_name,
190
+ addresses=[ip_address],
191
+ port=port,
192
+ properties=_serialize_worker_info(worker_info),
193
+ )
194
+ self.services[str(worker_info.uid)] = service_info
195
+ await self.aiozc.async_register_service(service_info)
196
+
197
+ async def _drop(self, worker_info: WorkerInfo) -> None:
198
+ """Unregister a worker by removing its service record.
199
+
200
+ :param worker_info:
201
+ The worker to unregister.
202
+ :raises RuntimeError:
203
+ If the publisher is not properly initialized.
204
+ """
205
+ assert self.aiozc
206
+
207
+ uid_str = str(worker_info.uid)
208
+ if uid_str in self.services:
209
+ service = self.services[uid_str]
210
+ await self.aiozc.async_unregister_service(service)
211
+ del self.services[uid_str]
212
+
213
+ async def _update(self, worker_info: WorkerInfo) -> None:
214
+ """Update a worker's properties if they have changed.
215
+
216
+ Updates both the Zeroconf service and local cache
217
+ atomically. If the Zeroconf update fails, the local cache
218
+ remains unchanged to maintain consistency.
219
+
220
+ :param worker_info:
221
+ The updated worker information.
222
+ :raises RuntimeError:
223
+ If the publisher is not properly initialized.
224
+ :raises Exception:
225
+ If the Zeroconf service update fails.
226
+ """
227
+ assert self.aiozc
228
+
229
+ uid_str = str(worker_info.uid)
230
+ if uid_str not in self.services:
231
+ # Worker not found, treat as registration
232
+ await self._add(worker_info)
233
+ return
234
+
235
+ service = self.services[uid_str]
236
+ new_properties = _serialize_worker_info(worker_info)
237
+
238
+ if service.decoded_properties != new_properties:
239
+ updated_service = ServiceInfo(
240
+ service.type,
241
+ service.name,
242
+ addresses=service.addresses,
243
+ port=service.port,
244
+ properties=new_properties,
245
+ server=service.server,
246
+ )
247
+ await self.aiozc.async_update_service(updated_service)
248
+ self.services[uid_str] = updated_service
249
+
250
+ def _resolve_address(self, address: str) -> Tuple[bytes, int]:
251
+ """Resolve an address string to bytes and validate port.
252
+
253
+ :param address:
254
+ Address in format "host:port".
255
+ :returns:
256
+ Tuple of (IPv4/IPv6 address as bytes, port as int).
257
+ :raises ValueError:
258
+ If address format is invalid or port is out of range.
259
+ :raises OSError:
260
+ If hostname cannot be resolved.
261
+ """
262
+ host, port_str = address.split(":")
263
+ port = int(port_str)
264
+
265
+ try:
266
+ return socket.inet_pton(socket.AF_INET, host), port
267
+ except OSError:
268
+ pass
269
+
270
+ try:
271
+ return socket.inet_pton(socket.AF_INET6, host), port
272
+ except OSError:
273
+ pass
274
+
275
+ return socket.inet_aton(socket.gethostbyname(host)), port
276
+
277
+ class Subscriber:
278
+ """Subscriber for receiving worker discovery events.
279
+
280
+ Subscribes to worker :class:`discovery events
281
+ <~wool.DiscoveryEvent>` by browsing for DNS-SD services on the
282
+ local network. As workers register and unregister their
283
+ services, the subscriber yields corresponding events.
284
+
285
+ Each call to ``__aiter__`` creates an isolated iterator with its
286
+ own state. Multiple concurrent iterations from the same
287
+ subscriber instance are fully independent.
288
+
289
+ Uses AsyncZeroconf's service browser to monitor for service
290
+ changes and converts Zeroconf events into Wool discovery
291
+ events.
292
+
293
+ :param filter:
294
+ Optional predicate function to filter workers. Only workers
295
+ for which the predicate returns True will be included in
296
+ events.
297
+ """
298
+
299
+ _filter: Final[PredicateFunction[WorkerInfo] | None]
300
+ service_type: Literal["_wool._tcp.local."] = "_wool._tcp.local."
301
+
302
+ def __init__(
303
+ self,
304
+ filter: PredicateFunction[WorkerInfo] | None = None,
305
+ ) -> None:
306
+ self._filter = filter
307
+
308
+ def __aiter__(self) -> AsyncIterator[DiscoveryEvent]:
309
+ return self._event_stream()
310
+
311
+ async def _event_stream(self) -> AsyncIterator[DiscoveryEvent]:
312
+ """Stream discovery events from the network.
313
+
314
+ Creates isolated state for this iteration including its own
315
+ Zeroconf instance, service browser, event queue, and service
316
+ cache. Automatically cleans up all resources when iteration
317
+ completes or is interrupted.
318
+
319
+ :yields:
320
+ Discovery events as workers are added, updated, or removed.
321
+ """
322
+ # Create isolated state for this iterator
323
+ event_queue: Queue[DiscoveryEvent] = Queue()
324
+ service_cache: Dict[str, WorkerInfo] = {}
325
+
326
+ # Configure zeroconf to use localhost only to avoid network warnings
327
+ aiozc = AsyncZeroconf(interfaces=["127.0.0.1"])
328
+
329
+ try:
330
+ browser = AsyncServiceBrowser(
331
+ aiozc.zeroconf,
332
+ self.service_type,
333
+ listener=self._Listener(
334
+ aiozc=aiozc,
335
+ event_queue=event_queue,
336
+ service_cache=service_cache,
337
+ predicate=self._filter or (lambda _: True),
338
+ ),
339
+ )
340
+
341
+ try:
342
+ while True:
343
+ event = await event_queue.get()
344
+ yield event
345
+ finally:
346
+ await browser.async_cancel()
347
+ finally:
348
+ await aiozc.async_close()
349
+
350
+ class _Listener(ServiceListener):
351
+ """Zeroconf listener that delivers worker service events.
352
+
353
+ :param aiozc:
354
+ The AsyncZeroconf instance to use for async service
355
+ info retrieval.
356
+ :param event_queue:
357
+ Queue to deliver discovery events to.
358
+ :param service_cache:
359
+ Cache to track service properties for pre/post event
360
+ states.
361
+ :param predicate:
362
+ Function to filter which workers to track.
363
+ """
364
+
365
+ aiozc: AsyncZeroconf
366
+ _event_queue: Queue[DiscoveryEvent]
367
+ _service_addresses: Dict[str, str]
368
+ _service_cache: Dict[str, WorkerInfo]
369
+
370
+ def __init__(
371
+ self,
372
+ aiozc: AsyncZeroconf,
373
+ event_queue: Queue[DiscoveryEvent],
374
+ predicate: PredicateFunction[WorkerInfo],
375
+ service_cache: Dict[str, WorkerInfo],
376
+ ) -> None:
377
+ self.aiozc = aiozc
378
+ self._event_queue = event_queue
379
+ self._predicate = predicate
380
+ self._service_addresses = {}
381
+ self._service_cache = service_cache
382
+
383
+ def add_service(self, zc: Zeroconf, type_: str, name: str): # noqa: ARG002
384
+ """Called by Zeroconf when a service is added."""
385
+ if type_ == LanDiscovery.service_type:
386
+ asyncio.create_task(self._handle_add_service(type_, name))
387
+
388
+ def remove_service(self, zc: Zeroconf, type_: str, name: str): # noqa: ARG002
389
+ """Called by Zeroconf when a service is removed."""
390
+ if type_ == LanDiscovery.service_type:
391
+ if worker := self._service_cache.pop(name, None):
392
+ asyncio.create_task(
393
+ self._event_queue.put(
394
+ DiscoveryEvent(type="worker-dropped", worker_info=worker)
395
+ )
396
+ )
397
+
398
+ def update_service(self, zc: Zeroconf, type_, name): # noqa: ARG002
399
+ """Called by Zeroconf when a service is updated."""
400
+ if type_ == LanDiscovery.service_type:
401
+ asyncio.create_task(self._handle_update_service(type_, name))
402
+
403
+ async def _handle_add_service(self, type_: str, name: str):
404
+ """Async handler for service addition."""
405
+ try:
406
+ if not (
407
+ service_info := await self.aiozc.async_get_service_info(
408
+ type_, name
409
+ )
410
+ ):
411
+ return
412
+
413
+ try:
414
+ worker_info = _deserialize_worker_info(service_info)
415
+ except ValueError:
416
+ return
417
+
418
+ if self._predicate(worker_info):
419
+ self._service_cache[name] = worker_info
420
+ event = DiscoveryEvent(
421
+ type="worker-added", worker_info=worker_info
422
+ )
423
+ await self._event_queue.put(event)
424
+ except Exception: # pragma: no cover
425
+ pass
426
+
427
+ async def _handle_update_service(self, type_: str, name: str):
428
+ """Async handler for service update."""
429
+ try:
430
+ if not (
431
+ service_info := await self.aiozc.async_get_service_info(
432
+ type_, name
433
+ )
434
+ ):
435
+ return
436
+
437
+ try:
438
+ worker_info = _deserialize_worker_info(service_info)
439
+ except ValueError:
440
+ return
441
+
442
+ if name not in self._service_cache:
443
+ # New worker that wasn't tracked before
444
+ if self._predicate(worker_info):
445
+ self._service_cache[name] = worker_info
446
+ event = DiscoveryEvent(
447
+ type="worker-added", worker_info=worker_info
448
+ )
449
+ await self._event_queue.put(event)
450
+ else:
451
+ # Existing tracked worker
452
+ old_worker = self._service_cache[name]
453
+ if self._predicate(worker_info):
454
+ # Still satisfies filter, update cache and emit update
455
+ self._service_cache[name] = worker_info
456
+ event = DiscoveryEvent(
457
+ type="worker-updated", worker_info=worker_info
458
+ )
459
+ await self._event_queue.put(event)
460
+ else:
461
+ # No longer satisfies filter, remove and emit removal
462
+ del self._service_cache[name]
463
+ removal_event = DiscoveryEvent(
464
+ type="worker-dropped", worker_info=old_worker
465
+ )
466
+ await self._event_queue.put(removal_event)
467
+
468
+ except Exception: # pragma: no cover
469
+ pass
470
+
471
+
472
+ def _serialize_worker_info(
473
+ info: WorkerInfo,
474
+ ) -> dict[str, str | None]:
475
+ """Serialize WorkerInfo to a flat dict for service properties.
476
+
477
+ :param info:
478
+ WorkerInfo instance to serialize.
479
+ :returns:
480
+ Flat dict with pid, version, tags (JSON), extra (JSON).
481
+ """
482
+ properties = {
483
+ "pid": str(info.pid),
484
+ "version": info.version,
485
+ "tags": (json.dumps(list(info.tags)) if info.tags else None),
486
+ "extra": (json.dumps(dict(info.extra)) if info.extra else None),
487
+ }
488
+ return properties
489
+
490
+
def _deserialize_worker_info(info: ServiceInfo) -> WorkerInfo:
    """Deserialize ServiceInfo.decoded_properties to WorkerInfo.

    :param info:
        ServiceInfo with decoded properties dict (str keys/values) as
        produced by ``_serialize_worker_info``, and a service name of
        the form ``"<uuid>._wool._tcp.local."``.
    :returns:
        WorkerInfo instance. The host is the first advertised IPv4
        address, or the first IPv6 address when the service advertises
        no IPv4 address.
    :raises ValueError:
        If required fields are missing; if the pid, tags, extra or UUID
        are malformed (including invalid JSON); or if the service
        advertises no IP address.
    """
    properties = info.decoded_properties
    # Sorted so the error message is deterministic.
    if missing := sorted(key for key in ("pid", "version") if not properties.get(key)):
        raise ValueError(f"Missing required properties: {', '.join(missing)}")

    pid_str = properties["pid"]
    version = properties["version"]
    # Narrowing for type checkers; guaranteed by the check above.
    assert pid_str and version

    pid = int(pid_str)

    tags_json = properties.get("tags")
    try:
        tags = frozenset(json.loads(tags_json)) if tags_json else frozenset()
    except TypeError as e:
        # e.g. a JSON number (not iterable) or nested lists (unhashable).
        raise ValueError(f"Invalid tags property: {tags_json!r}") from e

    extra_json = properties.get("extra")
    extra = json.loads(extra_json) if extra_json else {}
    if not isinstance(extra, dict):
        # MappingProxyType requires a mapping; report it as bad data.
        raise ValueError(f"Invalid extra property: {extra_json!r}")

    # Extract UID from service name (format: "<uuid>._wool._tcp.local.")
    uid = UUID(info.name.partition(".")[0])

    # Prefer IPv4, but accept IPv6-only services (the publisher can
    # advertise IPv6 addresses).
    addresses = info.ip_addresses_by_version(
        IPVersion.V4Only
    ) or info.ip_addresses_by_version(IPVersion.V6Only)
    if not addresses:
        raise ValueError(f"Service {info.name!r} advertises no IP address")

    return WorkerInfo(
        uid=uid,
        pid=pid,
        host=str(addresses[0]),
        port=info.port,
        version=version,
        tags=tags,
        extra=MappingProxyType(extra),
    )