wool 0.1rc20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,249 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import signal
5
+ from contextlib import contextmanager
6
+ from functools import partial
7
+ from multiprocessing import Pipe
8
+ from multiprocessing import Process
9
+ from multiprocessing.connection import Connection
10
+ from typing import TYPE_CHECKING
11
+
12
+ import grpc.aio
13
+
14
+ import wool
15
+ from wool._resource_pool import ResourcePool
16
+ from wool.core import protobuf as pb
17
+ from wool.core.worker.service import WorkerService
18
+
19
+ if TYPE_CHECKING:
20
+ from wool.core.worker.proxy import WorkerProxy
21
+
22
+
23
+ class WorkerProcess(Process):
24
+ """Subprocess hosting a gRPC worker server.
25
+
26
+ Isolated Python process running a gRPC server for task execution.
27
+ Maintains its own event loop and serves as an independent worker node.
28
+
29
+ Communicates the bound port back to the parent process via pipe after
30
+ startup. Handles SIGTERM and SIGINT for graceful shutdown.
31
+
32
+ :param host:
33
+ Host address to bind.
34
+ :param port:
35
+ Port to bind. 0 for random available port.
36
+ :param shutdown_grace_period:
37
+ Graceful shutdown timeout in seconds.
38
+ :param proxy_pool_ttl:
39
+ Proxy pool TTL in seconds.
40
+ :param args:
41
+ Additional args for :class:`multiprocessing.Process`.
42
+ :param kwargs:
43
+ Additional kwargs for :class:`multiprocessing.Process`.
44
+ """
45
+
46
+ _port: int | None
47
+ _get_port: Connection
48
+ _set_port: Connection
49
+ _shutdown_grace_period: float
50
+ _proxy_pool_ttl: float
51
+
52
+ def __init__(
53
+ self,
54
+ *args,
55
+ host: str = "127.0.0.1",
56
+ port: int = 0,
57
+ shutdown_grace_period: float = 60.0,
58
+ proxy_pool_ttl: float = 60.0,
59
+ **kwargs,
60
+ ):
61
+ super().__init__(*args, **kwargs)
62
+ if not host:
63
+ raise ValueError("Host must be a non-blank string")
64
+ self._host = host
65
+ if port < 0 or port > 65535:
66
+ raise ValueError("Port must be a positive integer")
67
+ self._port = port
68
+ if shutdown_grace_period <= 0:
69
+ raise ValueError("Shutdown grace period must be positive")
70
+ self._shutdown_grace_period = shutdown_grace_period
71
+ if proxy_pool_ttl <= 0:
72
+ raise ValueError("Proxy pool TTL must be positive")
73
+ self._proxy_pool_ttl = proxy_pool_ttl
74
+ self._get_port, self._set_port = Pipe(duplex=False)
75
+
76
+ @property
77
+ def address(self) -> str | None:
78
+ """The network address where the gRPC server is listening.
79
+
80
+ :returns:
81
+ The address in "host:port" format, or None if not started.
82
+ """
83
+ return self._address(self._host, self._port) if self._port else None
84
+
85
+ @property
86
+ def host(self) -> str | None:
87
+ """The host where the gRPC server is listening.
88
+
89
+ :returns:
90
+ The host address, or None if not started.
91
+ """
92
+ return self._host
93
+
94
+ @property
95
+ def port(self) -> int | None:
96
+ """The port where the gRPC server is listening.
97
+
98
+ :returns:
99
+ The port number, or None if not started.
100
+ """
101
+ return self._port or None
102
+
103
+ def start(self, *, timeout: float | None = None):
104
+ """Start the worker process.
105
+
106
+ Launches the worker process and waits until it has started
107
+ listening on a port. After starting, the :attr:`address`
108
+ property will contain the actual network address.
109
+
110
+ :param timeout:
111
+ Maximum time in seconds to wait for worker process startup.
112
+ :raises RuntimeError:
113
+ If the worker process fails to start within the timeout.
114
+ :raises ValueError:
115
+ If the timeout is not positive.
116
+ """
117
+ if timeout is not None and timeout <= 0:
118
+ raise ValueError("Timeout must be positive")
119
+ super().start()
120
+ if self._get_port.poll(timeout=timeout):
121
+ self._port = self._get_port.recv()
122
+ else:
123
+ self.terminate()
124
+ self.join()
125
+ raise RuntimeError(
126
+ f"Worker process failed to start within {timeout} seconds"
127
+ )
128
+ self._get_port.close()
129
+
130
+ def run(self) -> None:
131
+ """Run the worker process.
132
+
133
+ Sets the event loop for this process and starts the gRPC server,
134
+ blocking until the server is stopped.
135
+ """
136
+ wool.__proxy_pool__.set(
137
+ ResourcePool(
138
+ factory=_proxy_factory,
139
+ finalizer=_proxy_finalizer,
140
+ ttl=self._proxy_pool_ttl,
141
+ )
142
+ )
143
+ asyncio.run(self._serve())
144
+
145
+ async def _serve(self):
146
+ """Start the gRPC server in this worker process.
147
+
148
+ This method is called by the event loop to start serving
149
+ requests. It creates a gRPC server, adds the worker service, and
150
+ starts listening for incoming connections.
151
+ """
152
+ server = grpc.aio.server()
153
+ port = server.add_insecure_port(self._address(self._host, self._port))
154
+ service = WorkerService()
155
+ pb.add_to_server[pb.worker.WorkerServicer](service, server)
156
+
157
+ with _signal_handlers(service):
158
+ try:
159
+ await server.start()
160
+ try:
161
+ self._set_port.send(port)
162
+ finally:
163
+ self._set_port.close()
164
+ await service.stopped.wait()
165
+ finally:
166
+ await server.stop(grace=self._shutdown_grace_period)
167
+
168
+ def _address(self, host, port) -> str:
169
+ """Format network address for the given port.
170
+
171
+ :param port:
172
+ Port number to include in the address.
173
+ :returns:
174
+ Address string in "host:port" format.
175
+ """
176
+ return f"{host}:{port}"
177
+
178
+
179
@contextmanager
def _signal_handlers(service: WorkerService):
    """Temporarily install the worker's SIGTERM and SIGINT handlers.

    While the context is active, SIGTERM is routed to
    ``_sigterm_handler`` and SIGINT to ``_sigint_handler``, each bound to
    the running event loop and the given service. The handlers that were
    installed beforehand are put back on exit.

    :param service:
        The :class:`WorkerService` instance to shut down on signal receipt.
    :yields:
        Control to the calling context with signal handlers installed.
    """
    running_loop = asyncio.get_running_loop()

    # Install in a fixed order (SIGTERM, then SIGINT) and remember what each
    # signal was previously bound to; dicts preserve that order for restore.
    previous = {
        signum: signal.signal(signum, partial(handler, running_loop, service))
        for signum, handler in (
            (signal.SIGTERM, _sigterm_handler),
            (signal.SIGINT, _sigint_handler),
        )
    }
    try:
        yield
    finally:
        for signum, old_handler in previous.items():
            signal.signal(signum, old_handler)
200
+
201
+
202
# Strong references to in-flight stop tasks. The event loop only keeps weak
# references to tasks, so a task nothing else refers to may be garbage
# collected before it finishes.
_pending_stop_tasks: set = set()


def _schedule_stop(loop, service, timeout):
    """Schedule ``service.stop`` on *loop* from inside a signal handler.

    Does nothing when the loop is not running. Otherwise hands the work to
    the loop via ``call_soon_threadsafe`` and keeps a reference to the
    resulting task until it completes.

    :param loop:
        The event loop that runs the worker service.
    :param service:
        The service whose ``stop`` coroutine should be invoked.
    :param timeout:
        Value for the ``timeout`` field of the stop request.
    """
    if not loop.is_running():
        return

    def _spawn():
        task = asyncio.create_task(
            service.stop(pb.worker.StopRequest(timeout=timeout), None)
        )
        _pending_stop_tasks.add(task)
        task.add_done_callback(_pending_stop_tasks.discard)

    loop.call_soon_threadsafe(_spawn)


def _sigterm_handler(loop, service, signum, frame):
    """Handle SIGTERM by requesting a service stop with ``timeout=0``.

    NOTE(review): presumably a zero timeout means "stop without waiting for
    in-flight work" - confirm against ``WorkerService.stop``.
    """
    _schedule_stop(loop, service, 0)


def _sigint_handler(loop, service, signum, frame):
    """Handle SIGINT by requesting a service stop with ``timeout=None``.

    NOTE(review): presumably an unset timeout lets in-flight work finish -
    confirm against ``WorkerService.stop``.
    """
    _schedule_stop(loop, service, None)
218
+
219
+
220
async def _proxy_factory(proxy: WorkerProxy):
    """Hand back a started proxy for use as a ResourcePool entry.

    The pool passes the proxy itself as the key; it is started on first
    use and returned unchanged afterwards.

    :param proxy:
        The WorkerProxy instance to start (passed as key from
        ResourcePool).
    :returns:
        The same WorkerProxy instance, started.
    """
    if proxy.started:
        return proxy
    await proxy.start()
    return proxy
235
+
236
+
237
async def _proxy_finalizer(proxy: WorkerProxy):
    """Finalizer function for WorkerProxy instances in ResourcePool.

    Stops the proxy when it's being cleaned up from the resource pool.
    Cleanup is best-effort: a failure to stop is logged at debug level and
    otherwise ignored so that it cannot break pool eviction.

    :param proxy:
        The WorkerProxy instance to clean up.
    """
    import logging

    try:
        await proxy.stop()
    except Exception:
        # Keep a trace for debugging instead of discarding the error.
        logging.getLogger(__name__).debug(
            "Failed to stop worker proxy %r during pool cleanup",
            proxy,
            exc_info=True,
        )
@@ -0,0 +1,427 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import uuid
5
+ from typing import TYPE_CHECKING
6
+ from typing import AsyncContextManager
7
+ from typing import AsyncIterator
8
+ from typing import Awaitable
9
+ from typing import ContextManager
10
+ from typing import Generic
11
+ from typing import Sequence
12
+ from typing import TypeAlias
13
+ from typing import TypeVar
14
+ from typing import overload
15
+
16
+ import wool
17
+ from wool._resource_pool import ResourcePool
18
+ from wool.core.discovery.base import DiscoveryEvent
19
+ from wool.core.discovery.base import DiscoverySubscriberLike
20
+ from wool.core.discovery.base import WorkerInfo
21
+ from wool.core.discovery.local import LocalDiscovery
22
+ from wool.core.loadbalancer.base import LoadBalancerContext
23
+ from wool.core.loadbalancer.base import LoadBalancerLike
24
+ from wool.core.loadbalancer.roundrobin import RoundRobinLoadBalancer
25
+ from wool.core.typing import Factory
26
+ from wool.core.worker.connection import WorkerConnection
27
+
28
+ if TYPE_CHECKING:
29
+ from wool._work import WoolTask
30
+
31
T = TypeVar("T")


class ReducibleAsyncIterator(Generic[T]):
    """Async iterator over a sequence that stays picklable.

    Wraps an in-memory sequence so it can be consumed with ``async for``
    and still be serialized through ``__reduce__``.

    :param items:
        Sequence of items to convert to async iterator.
    """

    def __init__(self, items: Sequence[T]):
        self._items = items
        self._index = 0

    def __aiter__(self) -> AsyncIterator[T]:
        return self

    async def __anext__(self) -> T:
        position = self._index
        if position < len(self._items):
            # Read first, advance second, so a failing lookup leaves the
            # cursor where it was.
            current = self._items[position]
            self._index = position + 1
            return current
        raise StopAsyncIteration

    def __reduce__(self) -> tuple:
        """Return constructor args for unpickling.

        Only the items are carried over; the restored iterator starts
        again from the first item.
        """
        return (self.__class__, (self._items,))
61
+
62
+
63
async def connection_factory(target: str) -> WorkerConnection:
    """Build a worker connection for a ResourcePool entry.

    The pool passes the target as the key; a fresh
    :class:`WorkerConnection` is constructed for it on every call.

    :param target:
        The network target (host:port) to connect to.
    :returns:
        A new connection to the target.
    """
    connection = WorkerConnection(target)
    return connection
75
+
76
+
77
async def connection_finalizer(connection: WorkerConnection) -> None:
    """Finalizer function for worker connections in ResourcePool.

    Closes the connection when it's being cleaned up from the resource
    pool. Cleanup is best-effort: a failure to close is logged at debug
    level and otherwise ignored so that it cannot break pool eviction.

    :param connection:
        The worker connection to close.
    """
    import logging

    try:
        await connection.close()
    except Exception:
        # Keep a trace for debugging instead of discarding the error.
        logging.getLogger(__name__).debug(
            "Failed to close worker connection %r during pool cleanup",
            connection,
            exc_info=True,
        )
89
+
90
+
91
+ WorkerUri: TypeAlias = str
92
+
93
+
94
+ # public
95
+ class WorkerProxy:
96
+ """Client-side proxy for dispatching tasks to distributed workers.
97
+
98
+ Manages worker discovery, connection pooling, and load-balanced task
99
+ routing. The bridge between :func:`@wool.work <wool.work>` decorated
100
+ functions and the worker pool.
101
+
102
+ Connects to workers through discovery services, pool URIs, or static
103
+ worker lists. Handles connection lifecycle and fault tolerance
104
+ automatically.
105
+
106
+ **Connect via pool URI:**
107
+
108
+ .. code-block:: python
109
+
110
+ async with WorkerProxy("pool-abc123") as proxy:
111
+ result = await task()
112
+
113
+ **Connect via discovery:**
114
+
115
+ .. code-block:: python
116
+
117
+ from wool.core.discovery.lan import LanDiscovery
118
+
119
+ discovery = LanDiscovery().subscribe()
120
+ async with WorkerProxy(discovery=discovery) as proxy:
121
+ result = await task()
122
+
123
+ **Connect to static workers:**
124
+
125
+ .. code-block:: python
126
+
127
+ workers = [
128
+ WorkerInfo(host="10.0.0.1", port=50051, ...),
129
+ WorkerInfo(host="10.0.0.2", port=50051, ...),
130
+ ]
131
+ async with WorkerProxy(workers=workers) as proxy:
132
+ result = await task()
133
+
134
+ **Custom load balancer:**
135
+
136
+ .. code-block:: python
137
+
138
+ from wool.core.loadbalancer.roundrobin import RoundRobinLoadBalancer
139
+
140
+
141
+ class CustomBalancer(RoundRobinLoadBalancer):
142
+ async def dispatch(self, task, context, timeout=None):
143
+ # Custom routing strategy
144
+ ...
145
+
146
+
147
+ async with WorkerProxy(
148
+ discovery=discovery,
149
+ loadbalancer=CustomBalancer(),
150
+ ) as proxy:
151
+ result = await task()
152
+
153
+ :param pool_uri:
154
+ Pool identifier for discovery-based connection.
155
+ :param tags:
156
+ Additional tags for filtering discovered workers.
157
+ :param discovery:
158
+ Discovery service or event stream.
159
+ :param workers:
160
+ Static worker list for direct connection.
161
+ :param loadbalancer:
162
+ Load balancer instance, factory, or context manager.
163
+ """
164
+
165
+ _discovery: DiscoverySubscriberLike | Factory[DiscoverySubscriberLike]
166
+ _discovery_manager: (
167
+ AsyncContextManager[DiscoverySubscriberLike]
168
+ | ContextManager[DiscoverySubscriberLike]
169
+ )
170
+
171
+ _loadbalancer = LoadBalancerLike | Factory[LoadBalancerLike]
172
+ _loadbalancer_manager: (
173
+ AsyncContextManager[LoadBalancerLike] | ContextManager[LoadBalancerLike]
174
+ )
175
+
176
+ @overload
177
+ def __init__(
178
+ self,
179
+ *,
180
+ discovery: DiscoverySubscriberLike | Factory[DiscoverySubscriberLike],
181
+ loadbalancer: (
182
+ LoadBalancerLike | Factory[LoadBalancerLike]
183
+ ) = RoundRobinLoadBalancer,
184
+ ): ...
185
+
186
+ @overload
187
+ def __init__(
188
+ self,
189
+ *,
190
+ workers: Sequence[WorkerInfo],
191
+ loadbalancer: LoadBalancerLike
192
+ | Factory[LoadBalancerLike] = RoundRobinLoadBalancer,
193
+ ): ...
194
+
195
+ @overload
196
+ def __init__(
197
+ self,
198
+ pool_uri: str,
199
+ *tags: str,
200
+ loadbalancer: LoadBalancerLike
201
+ | Factory[LoadBalancerLike] = RoundRobinLoadBalancer,
202
+ ): ...
203
+
204
+ def __init__(
205
+ self,
206
+ pool_uri: str | None = None,
207
+ *tags: str,
208
+ discovery: (
209
+ DiscoverySubscriberLike | Factory[DiscoverySubscriberLike] | None
210
+ ) = None,
211
+ workers: Sequence[WorkerInfo] | None = None,
212
+ loadbalancer: LoadBalancerLike
213
+ | Factory[LoadBalancerLike] = RoundRobinLoadBalancer,
214
+ ):
215
+ if not (pool_uri or discovery or workers):
216
+ raise ValueError(
217
+ "Must specify either a workerpool URI, discovery event stream, or a "
218
+ "sequence of workers"
219
+ )
220
+
221
+ self._id: uuid.UUID = uuid.uuid4()
222
+ self._started = False
223
+ self._loadbalancer = loadbalancer
224
+
225
+ match (pool_uri, discovery, workers):
226
+ case (pool_uri, None, None) if pool_uri is not None:
227
+ self._discovery = LocalDiscovery(pool_uri).subscribe(
228
+ filter=lambda w: bool({pool_uri, *tags} & w.tags)
229
+ )
230
+ case (None, discovery, None) if discovery is not None:
231
+ self._discovery = discovery
232
+ case (None, None, workers) if workers is not None:
233
+ self._discovery = ReducibleAsyncIterator(
234
+ [DiscoveryEvent(type="worker-added", worker_info=w) for w in workers]
235
+ )
236
+ case _:
237
+ raise ValueError(
238
+ "Must specify exactly one of: "
239
+ "pool_uri, discovery_event_stream, or workers"
240
+ )
241
+ self._sentinel_task: asyncio.Task[None] | None = None
242
+ self._loadbalancer_context: LoadBalancerContext | None = None
243
+
244
+ async def __aenter__(self):
245
+ """Starts the proxy and sets it as the active context."""
246
+ await self.start()
247
+ return self
248
+
249
+ async def __aexit__(self, *args):
250
+ """Stops the proxy and resets the active context."""
251
+ await self.stop(*args)
252
+
253
+ def __hash__(self) -> int:
254
+ return hash(str(self.id))
255
+
256
+ def __eq__(self, value: object) -> bool:
257
+ return isinstance(value, WorkerProxy) and hash(self) == hash(value)
258
+
259
+ def __reduce__(self) -> tuple:
260
+ """Return constructor args for unpickling with proxy ID preserved.
261
+
262
+ Creates a new WorkerProxy instance with the same discovery stream and
263
+ load balancer type, then sets the preserved proxy ID on the new object.
264
+ Workers will be re-discovered on the new instance.
265
+
266
+ :returns:
267
+ Tuple of (callable, args, state) for unpickling.
268
+ """
269
+
270
+ def _restore_proxy(discovery, loadbalancer, proxy_id):
271
+ proxy = WorkerProxy(discovery=discovery, loadbalancer=loadbalancer)
272
+ proxy._id = proxy_id
273
+ return proxy
274
+
275
+ return (
276
+ _restore_proxy,
277
+ (self._discovery, self._loadbalancer, self._id),
278
+ )
279
+
280
+ @property
281
+ def id(self) -> uuid.UUID:
282
+ return self._id
283
+
284
+ @property
285
+ def started(self) -> bool:
286
+ return self._started
287
+
288
+ @property
289
+ def workers(self) -> list[WorkerInfo]:
290
+ """A list of the currently discovered worker gRPC stubs."""
291
+ if self._loadbalancer_context:
292
+ return list(self._loadbalancer_context.workers.keys())
293
+ else:
294
+ return []
295
+
296
+ async def start(self) -> None:
297
+ """Starts the proxy by initiating the worker discovery process.
298
+
299
+ :raises RuntimeError:
300
+ If the proxy has already been started.
301
+ """
302
+ if self._started:
303
+ raise RuntimeError("Proxy already started")
304
+
305
+ (
306
+ self._loadbalancer_service,
307
+ self._loadbalancer_context_manager,
308
+ ) = await self._enter_context(self._loadbalancer)
309
+ if not isinstance(self._loadbalancer_service, LoadBalancerLike):
310
+ raise ValueError
311
+
312
+ (
313
+ self._discovery_stream,
314
+ self._discovery_context_manager,
315
+ ) = await self._enter_context(self._discovery)
316
+ if not isinstance(self._discovery_stream, DiscoverySubscriberLike):
317
+ raise ValueError
318
+
319
+ self._proxy_token = wool.__proxy__.set(self)
320
+ self._connection_pool = ResourcePool(
321
+ factory=connection_factory, finalizer=connection_finalizer, ttl=60
322
+ )
323
+ self._loadbalancer_context = LoadBalancerContext()
324
+ self._sentinel_task = asyncio.create_task(self._worker_sentinel())
325
+ self._started = True
326
+
327
+ async def stop(self, *args) -> None:
328
+ """Stops the proxy, terminating discovery and clearing connections.
329
+
330
+ :raises RuntimeError:
331
+ If the proxy was not started first.
332
+ """
333
+ if not self._started:
334
+ raise RuntimeError("Proxy not started - call start() first")
335
+
336
+ await self._exit_context(self._discovery_context_manager, *args)
337
+ await self._exit_context(self._loadbalancer_context_manager, *args)
338
+
339
+ wool.__proxy__.reset(self._proxy_token)
340
+ if self._sentinel_task:
341
+ self._sentinel_task.cancel()
342
+ try:
343
+ await self._sentinel_task
344
+ except asyncio.CancelledError:
345
+ pass
346
+ self._sentinel_task = None
347
+ await self._connection_pool.clear()
348
+ self._loadbalancer_context = None
349
+ self._started = False
350
+
351
+ async def dispatch(self, task: WoolTask, *, timeout: float | None = None):
352
+ """Dispatches a task to an available worker in the pool.
353
+
354
+ This method selects a worker using a round-robin strategy. If no
355
+ workers are available within the timeout period, it raises an
356
+ exception.
357
+
358
+ :param task:
359
+ The :class:`WoolTask` object to be dispatched.
360
+ :param timeout:
361
+ Timeout in seconds for getting a worker.
362
+ :returns:
363
+ A protobuf result object from the worker.
364
+ :raises RuntimeError:
365
+ If the proxy is not started.
366
+ :raises asyncio.TimeoutError:
367
+ If no worker is available within the timeout period.
368
+ """
369
+ if not self._started:
370
+ raise RuntimeError("Proxy not started - call start() first")
371
+
372
+ await asyncio.wait_for(self._await_workers(), 60)
373
+
374
+ assert isinstance(self._loadbalancer_service, LoadBalancerLike)
375
+ assert self._loadbalancer_context
376
+ return await self._loadbalancer_service.dispatch(
377
+ task, context=self._loadbalancer_context, timeout=timeout
378
+ )
379
+
380
+ async def _enter_context(self, factory):
381
+ ctx = None
382
+ if isinstance(factory, ContextManager):
383
+ ctx = factory
384
+ obj = ctx.__enter__()
385
+ elif isinstance(factory, AsyncContextManager):
386
+ ctx = factory
387
+ obj = await ctx.__aenter__()
388
+ elif callable(factory):
389
+ return await self._enter_context(factory())
390
+ elif isinstance(factory, Awaitable):
391
+ obj = await factory
392
+ else:
393
+ obj = factory
394
+ return obj, ctx
395
+
396
+ async def _exit_context(
397
+ self, ctx: AsyncContextManager | ContextManager | None, *args
398
+ ):
399
+ if isinstance(ctx, AsyncContextManager):
400
+ await ctx.__aexit__(*args)
401
+ elif isinstance(ctx, ContextManager):
402
+ ctx.__exit__(*args)
403
+
404
+ async def _await_workers(self):
405
+ while not self._loadbalancer_context or not self._loadbalancer_context.workers:
406
+ await asyncio.sleep(0)
407
+
408
+ async def _worker_sentinel(self):
409
+ assert self._loadbalancer_context
410
+ async for event in self._discovery_stream:
411
+ match event.type:
412
+ case "worker-added":
413
+ self._loadbalancer_context.add_worker(
414
+ event.worker_info,
415
+ lambda: self._connection_pool.get(
416
+ f"{event.worker_info.host}:{event.worker_info.port}",
417
+ ),
418
+ )
419
+ case "worker-updated":
420
+ self._loadbalancer_context.update_worker(
421
+ event.worker_info,
422
+ lambda: self._connection_pool.get(
423
+ f"{event.worker_info.host}:{event.worker_info.port}",
424
+ ),
425
+ )
426
+ case "worker-dropped":
427
+ self._loadbalancer_context.remove_worker(event.worker_info)