wool-0.1rc20-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wool/__init__.py +122 -0
- wool/_context.py +29 -0
- wool/_protobuf/worker.py +26 -0
- wool/_resource_pool.py +376 -0
- wool/_typing.py +7 -0
- wool/_undefined.py +11 -0
- wool/_work.py +554 -0
- wool/core/__init__.py +0 -0
- wool/core/discovery/__init__.py +0 -0
- wool/core/discovery/base.py +249 -0
- wool/core/discovery/lan.py +534 -0
- wool/core/discovery/local.py +822 -0
- wool/core/loadbalancer/__init__.py +0 -0
- wool/core/loadbalancer/base.py +125 -0
- wool/core/loadbalancer/roundrobin.py +101 -0
- wool/core/protobuf/__init__.py +18 -0
- wool/core/protobuf/exception.py +3 -0
- wool/core/protobuf/task.py +11 -0
- wool/core/protobuf/task_pb2.py +42 -0
- wool/core/protobuf/task_pb2.pyi +43 -0
- wool/core/protobuf/task_pb2_grpc.py +24 -0
- wool/core/protobuf/worker.py +26 -0
- wool/core/protobuf/worker_pb2.py +53 -0
- wool/core/protobuf/worker_pb2.pyi +65 -0
- wool/core/protobuf/worker_pb2_grpc.py +141 -0
- wool/core/typing.py +22 -0
- wool/core/worker/__init__.py +0 -0
- wool/core/worker/base.py +300 -0
- wool/core/worker/connection.py +250 -0
- wool/core/worker/local.py +148 -0
- wool/core/worker/pool.py +386 -0
- wool/core/worker/process.py +249 -0
- wool/core/worker/proxy.py +427 -0
- wool/core/worker/service.py +231 -0
- wool-0.1rc20.dist-info/METADATA +463 -0
- wool-0.1rc20.dist-info/RECORD +38 -0
- wool-0.1rc20.dist-info/WHEEL +4 -0
- wool-0.1rc20.dist-info/entry_points.txt +2 -0
wool/core/worker/process.py

@@ -0,0 +1,249 @@
from __future__ import annotations

import asyncio
import signal
from contextlib import contextmanager
from functools import partial
from multiprocessing import Pipe
from multiprocessing import Process
from multiprocessing.connection import Connection
from typing import TYPE_CHECKING

import grpc.aio

import wool
from wool._resource_pool import ResourcePool
from wool.core import protobuf as pb
from wool.core.worker.service import WorkerService

if TYPE_CHECKING:
    from wool.core.worker.proxy import WorkerProxy


class WorkerProcess(Process):
    """Subprocess hosting a gRPC worker server.

    Isolated Python process running a gRPC server for task execution.
    Maintains its own event loop and serves as an independent worker node.

    Communicates the bound port back to the parent process via pipe after
    startup. Handles SIGTERM and SIGINT for graceful shutdown.

    :param host:
        Host address to bind.
    :param port:
        Port to bind. 0 for random available port.
    :param shutdown_grace_period:
        Graceful shutdown timeout in seconds.
    :param proxy_pool_ttl:
        Proxy pool TTL in seconds.
    :param args:
        Additional args for :class:`multiprocessing.Process`.
    :param kwargs:
        Additional kwargs for :class:`multiprocessing.Process`.
    """

    _port: int | None
    _get_port: Connection
    _set_port: Connection
    _shutdown_grace_period: float
    _proxy_pool_ttl: float

    def __init__(
        self,
        *args,
        host: str = "127.0.0.1",
        port: int = 0,
        shutdown_grace_period: float = 60.0,
        proxy_pool_ttl: float = 60.0,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        if not host:
            raise ValueError("Host must be a non-blank string")
        self._host = host
        if port < 0 or port > 65535:
            raise ValueError("Port must be a positive integer")
        self._port = port
        if shutdown_grace_period <= 0:
            raise ValueError("Shutdown grace period must be positive")
        self._shutdown_grace_period = shutdown_grace_period
        if proxy_pool_ttl <= 0:
            raise ValueError("Proxy pool TTL must be positive")
        self._proxy_pool_ttl = proxy_pool_ttl
        self._get_port, self._set_port = Pipe(duplex=False)

    @property
    def address(self) -> str | None:
        """The network address where the gRPC server is listening.

        :returns:
            The address in "host:port" format, or None if not started.
        """
        return self._address(self._host, self._port) if self._port else None

    @property
    def host(self) -> str | None:
        """The host where the gRPC server is listening.

        :returns:
            The host address, or None if not started.
        """
        return self._host

    @property
    def port(self) -> int | None:
        """The port where the gRPC server is listening.

        :returns:
            The port number, or None if not started.
        """
        return self._port or None

    def start(self, *, timeout: float | None = None):
        """Start the worker process.

        Launches the worker process and waits until it has started
        listening on a port. After starting, the :attr:`address`
        property will contain the actual network address.

        :param timeout:
            Maximum time in seconds to wait for worker process startup.
        :raises RuntimeError:
            If the worker process fails to start within the timeout.
        :raises ValueError:
            If the timeout is not positive.
        """
        if timeout is not None and timeout <= 0:
            raise ValueError("Timeout must be positive")
        super().start()
        if self._get_port.poll(timeout=timeout):
            self._port = self._get_port.recv()
        else:
            self.terminate()
            self.join()
            raise RuntimeError(
                f"Worker process failed to start within {timeout} seconds"
            )
        self._get_port.close()

    def run(self) -> None:
        """Run the worker process.

        Sets the event loop for this process and starts the gRPC server,
        blocking until the server is stopped.
        """
        wool.__proxy_pool__.set(
            ResourcePool(
                factory=_proxy_factory,
                finalizer=_proxy_finalizer,
                ttl=self._proxy_pool_ttl,
            )
        )
        asyncio.run(self._serve())

    async def _serve(self):
        """Start the gRPC server in this worker process.

        This method is called by the event loop to start serving
        requests. It creates a gRPC server, adds the worker service, and
        starts listening for incoming connections.
        """
        server = grpc.aio.server()
        port = server.add_insecure_port(self._address(self._host, self._port))
        service = WorkerService()
        pb.add_to_server[pb.worker.WorkerServicer](service, server)

        with _signal_handlers(service):
            try:
                await server.start()
                try:
                    self._set_port.send(port)
                finally:
                    self._set_port.close()
                await service.stopped.wait()
            finally:
                await server.stop(grace=self._shutdown_grace_period)

    def _address(self, host, port) -> str:
        """Format network address for the given port.

        :param port:
            Port number to include in the address.
        :returns:
            Address string in "host:port" format.
        """
        return f"{host}:{port}"


@contextmanager
def _signal_handlers(service: WorkerService):
    """Context manager for setting up signal handlers for graceful shutdown.

    Installs SIGTERM and SIGINT handlers that gracefully shut down the worker
    service when the process receives termination signals.

    :param service:
        The :class:`WorkerService` instance to shut down on signal receipt.
    :yields:
        Control to the calling context with signal handlers installed.
    """
    loop = asyncio.get_running_loop()

    old_sigterm = signal.signal(signal.SIGTERM, partial(_sigterm_handler, loop, service))
    old_sigint = signal.signal(signal.SIGINT, partial(_sigint_handler, loop, service))
    try:
        yield
    finally:
        signal.signal(signal.SIGTERM, old_sigterm)
        signal.signal(signal.SIGINT, old_sigint)


def _sigterm_handler(loop, service, signum, frame):
    if loop.is_running():
        loop.call_soon_threadsafe(
            lambda: asyncio.create_task(
                service.stop(pb.worker.StopRequest(timeout=0), None)
            )
        )


def _sigint_handler(loop, service, signum, frame):
    if loop.is_running():
        loop.call_soon_threadsafe(
            lambda: asyncio.create_task(
                service.stop(pb.worker.StopRequest(timeout=None), None)
            )
        )


async def _proxy_factory(proxy: WorkerProxy):
    """Factory function for WorkerProxy instances in ResourcePool.

    Starts the proxy if not already started and returns it.
    The proxy object itself is used as the cache key.

    :param proxy:
        The WorkerProxy instance to start (passed as key from
        ResourcePool).
    :returns:
        The started WorkerProxy instance.
    """
    if not proxy.started:
        await proxy.start()
    return proxy


async def _proxy_finalizer(proxy: WorkerProxy):
    """Finalizer function for WorkerProxy instances in ResourcePool.

    Stops the proxy when it's being cleaned up from the resource pool.
    Based on the cleanup logic from WorkerProxyCache._delayed_cleanup.

    :param proxy:
        The WorkerProxy instance to clean up.
    """
    try:
        await proxy.stop()
    except Exception:
        pass
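For orientation, here is a minimal parent-process sketch of the WorkerProcess API shown above. The import path follows the file listing in this diff (wool/core/worker/process.py) and is an assumption, not documented usage from the package itself.

# Hypothetical usage sketch of WorkerProcess, based solely on the code above.
from wool.core.worker.process import WorkerProcess

if __name__ == "__main__":
    # port=0 asks the child process to bind an OS-assigned port.
    worker = WorkerProcess(host="127.0.0.1", port=0)
    # start() blocks until the child reports its bound port over the pipe,
    # or raises RuntimeError if nothing arrives within the timeout.
    worker.start(timeout=30)
    try:
        print(f"worker serving at {worker.address}")  # "host:port"
    finally:
        worker.terminate()  # SIGTERM triggers the graceful-shutdown handler
        worker.join()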
wool/core/worker/proxy.py

@@ -0,0 +1,427 @@
from __future__ import annotations

import asyncio
import uuid
from typing import TYPE_CHECKING
from typing import AsyncContextManager
from typing import AsyncIterator
from typing import Awaitable
from typing import ContextManager
from typing import Generic
from typing import Sequence
from typing import TypeAlias
from typing import TypeVar
from typing import overload

import wool
from wool._resource_pool import ResourcePool
from wool.core.discovery.base import DiscoveryEvent
from wool.core.discovery.base import DiscoverySubscriberLike
from wool.core.discovery.base import WorkerInfo
from wool.core.discovery.local import LocalDiscovery
from wool.core.loadbalancer.base import LoadBalancerContext
from wool.core.loadbalancer.base import LoadBalancerLike
from wool.core.loadbalancer.roundrobin import RoundRobinLoadBalancer
from wool.core.typing import Factory
from wool.core.worker.connection import WorkerConnection

if TYPE_CHECKING:
    from wool._work import WoolTask

T = TypeVar("T")


class ReducibleAsyncIterator(Generic[T]):
    """An async iterator that can be pickled via __reduce__.

    Converts a sequence into an async iterator while maintaining
    picklability for distributed task execution contexts.

    :param items:
        Sequence of items to convert to async iterator.
    """

    def __init__(self, items: Sequence[T]):
        self._items = items
        self._index = 0

    def __aiter__(self) -> AsyncIterator[T]:
        return self

    async def __anext__(self) -> T:
        if self._index >= len(self._items):
            raise StopAsyncIteration
        item = self._items[self._index]
        self._index += 1
        return item

    def __reduce__(self) -> tuple:
        """Return constructor args for unpickling."""
        return (self.__class__, (self._items,))


async def connection_factory(target: str) -> WorkerConnection:
    """Factory function for creating worker connections.

    Creates a connection to the specified worker target.
    The target is passed as the key from ResourcePool.

    :param target:
        The network target (host:port) to create a channel for.
    :returns:
        A new connection to the target.
    """
    return WorkerConnection(target)


async def connection_finalizer(connection: WorkerConnection) -> None:
    """Finalizer function for gRPC channels.

    Closes the gRPC connection when it's being cleaned up from the resource pool.

    :param connection:
        The gRPC connection to close.
    """
    try:
        await connection.close()
    except Exception:
        pass


WorkerUri: TypeAlias = str


# public
class WorkerProxy:
    """Client-side proxy for dispatching tasks to distributed workers.

    Manages worker discovery, connection pooling, and load-balanced task
    routing. The bridge between :func:`@wool.work <wool.work>` decorated
    functions and the worker pool.

    Connects to workers through discovery services, pool URIs, or static
    worker lists. Handles connection lifecycle and fault tolerance
    automatically.

    **Connect via pool URI:**

    .. code-block:: python

        async with WorkerProxy("pool-abc123") as proxy:
            result = await task()

    **Connect via discovery:**

    .. code-block:: python

        from wool.core.discovery.lan import LanDiscovery

        discovery = LanDiscovery().subscribe()
        async with WorkerProxy(discovery=discovery) as proxy:
            result = await task()

    **Connect to static workers:**

    .. code-block:: python

        workers = [
            WorkerInfo(host="10.0.0.1", port=50051, ...),
            WorkerInfo(host="10.0.0.2", port=50051, ...),
        ]
        async with WorkerProxy(workers=workers) as proxy:
            result = await task()

    **Custom load balancer:**

    .. code-block:: python

        from wool.core.loadbalancer.roundrobin import RoundRobinLoadBalancer


        class CustomBalancer(RoundRobinLoadBalancer):
            async def dispatch(self, task, context, timeout=None):
                # Custom routing strategy
                ...


        async with WorkerProxy(
            discovery=discovery,
            loadbalancer=CustomBalancer(),
        ) as proxy:
            result = await task()

    :param pool_uri:
        Pool identifier for discovery-based connection.
    :param tags:
        Additional tags for filtering discovered workers.
    :param discovery:
        Discovery service or event stream.
    :param workers:
        Static worker list for direct connection.
    :param loadbalancer:
        Load balancer instance, factory, or context manager.
    """

    _discovery: DiscoverySubscriberLike | Factory[DiscoverySubscriberLike]
    _discovery_manager: (
        AsyncContextManager[DiscoverySubscriberLike]
        | ContextManager[DiscoverySubscriberLike]
    )

    _loadbalancer = LoadBalancerLike | Factory[LoadBalancerLike]
    _loadbalancer_manager: (
        AsyncContextManager[LoadBalancerLike] | ContextManager[LoadBalancerLike]
    )

    @overload
    def __init__(
        self,
        *,
        discovery: DiscoverySubscriberLike | Factory[DiscoverySubscriberLike],
        loadbalancer: (
            LoadBalancerLike | Factory[LoadBalancerLike]
        ) = RoundRobinLoadBalancer,
    ): ...

    @overload
    def __init__(
        self,
        *,
        workers: Sequence[WorkerInfo],
        loadbalancer: LoadBalancerLike
        | Factory[LoadBalancerLike] = RoundRobinLoadBalancer,
    ): ...

    @overload
    def __init__(
        self,
        pool_uri: str,
        *tags: str,
        loadbalancer: LoadBalancerLike
        | Factory[LoadBalancerLike] = RoundRobinLoadBalancer,
    ): ...

    def __init__(
        self,
        pool_uri: str | None = None,
        *tags: str,
        discovery: (
            DiscoverySubscriberLike | Factory[DiscoverySubscriberLike] | None
        ) = None,
        workers: Sequence[WorkerInfo] | None = None,
        loadbalancer: LoadBalancerLike
        | Factory[LoadBalancerLike] = RoundRobinLoadBalancer,
    ):
        if not (pool_uri or discovery or workers):
            raise ValueError(
                "Must specify either a workerpool URI, discovery event stream, or a "
                "sequence of workers"
            )

        self._id: uuid.UUID = uuid.uuid4()
        self._started = False
        self._loadbalancer = loadbalancer

        match (pool_uri, discovery, workers):
            case (pool_uri, None, None) if pool_uri is not None:
                self._discovery = LocalDiscovery(pool_uri).subscribe(
                    filter=lambda w: bool({pool_uri, *tags} & w.tags)
                )
            case (None, discovery, None) if discovery is not None:
                self._discovery = discovery
            case (None, None, workers) if workers is not None:
                self._discovery = ReducibleAsyncIterator(
                    [DiscoveryEvent(type="worker-added", worker_info=w) for w in workers]
                )
            case _:
                raise ValueError(
                    "Must specify exactly one of: "
                    "pool_uri, discovery_event_stream, or workers"
                )
        self._sentinel_task: asyncio.Task[None] | None = None
        self._loadbalancer_context: LoadBalancerContext | None = None

    async def __aenter__(self):
        """Starts the proxy and sets it as the active context."""
        await self.start()
        return self

    async def __aexit__(self, *args):
        """Stops the proxy and resets the active context."""
        await self.stop(*args)

    def __hash__(self) -> int:
        return hash(str(self.id))

    def __eq__(self, value: object) -> bool:
        return isinstance(value, WorkerProxy) and hash(self) == hash(value)

    def __reduce__(self) -> tuple:
        """Return constructor args for unpickling with proxy ID preserved.

        Creates a new WorkerProxy instance with the same discovery stream and
        load balancer type, then sets the preserved proxy ID on the new object.
        Workers will be re-discovered on the new instance.

        :returns:
            Tuple of (callable, args, state) for unpickling.
        """

        def _restore_proxy(discovery, loadbalancer, proxy_id):
            proxy = WorkerProxy(discovery=discovery, loadbalancer=loadbalancer)
            proxy._id = proxy_id
            return proxy

        return (
            _restore_proxy,
            (self._discovery, self._loadbalancer, self._id),
        )

    @property
    def id(self) -> uuid.UUID:
        return self._id

    @property
    def started(self) -> bool:
        return self._started

    @property
    def workers(self) -> list[WorkerInfo]:
        """A list of the currently discovered worker gRPC stubs."""
        if self._loadbalancer_context:
            return list(self._loadbalancer_context.workers.keys())
        else:
            return []

    async def start(self) -> None:
        """Starts the proxy by initiating the worker discovery process.

        :raises RuntimeError:
            If the proxy has already been started.
        """
        if self._started:
            raise RuntimeError("Proxy already started")

        (
            self._loadbalancer_service,
            self._loadbalancer_context_manager,
        ) = await self._enter_context(self._loadbalancer)
        if not isinstance(self._loadbalancer_service, LoadBalancerLike):
            raise ValueError

        (
            self._discovery_stream,
            self._discovery_context_manager,
        ) = await self._enter_context(self._discovery)
        if not isinstance(self._discovery_stream, DiscoverySubscriberLike):
            raise ValueError

        self._proxy_token = wool.__proxy__.set(self)
        self._connection_pool = ResourcePool(
            factory=connection_factory, finalizer=connection_finalizer, ttl=60
        )
        self._loadbalancer_context = LoadBalancerContext()
        self._sentinel_task = asyncio.create_task(self._worker_sentinel())
        self._started = True

    async def stop(self, *args) -> None:
        """Stops the proxy, terminating discovery and clearing connections.

        :raises RuntimeError:
            If the proxy was not started first.
        """
        if not self._started:
            raise RuntimeError("Proxy not started - call start() first")

        await self._exit_context(self._discovery_context_manager, *args)
        await self._exit_context(self._loadbalancer_context_manager, *args)

        wool.__proxy__.reset(self._proxy_token)
        if self._sentinel_task:
            self._sentinel_task.cancel()
            try:
                await self._sentinel_task
            except asyncio.CancelledError:
                pass
            self._sentinel_task = None
        await self._connection_pool.clear()
        self._loadbalancer_context = None
        self._started = False

    async def dispatch(self, task: WoolTask, *, timeout: float | None = None):
        """Dispatches a task to an available worker in the pool.

        This method selects a worker using a round-robin strategy. If no
        workers are available within the timeout period, it raises an
        exception.

        :param task:
            The :class:`WoolTask` object to be dispatched.
        :param timeout:
            Timeout in seconds for getting a worker.
        :returns:
            A protobuf result object from the worker.
        :raises RuntimeError:
            If the proxy is not started.
        :raises asyncio.TimeoutError:
            If no worker is available within the timeout period.
        """
        if not self._started:
            raise RuntimeError("Proxy not started - call start() first")

        await asyncio.wait_for(self._await_workers(), 60)

        assert isinstance(self._loadbalancer_service, LoadBalancerLike)
        assert self._loadbalancer_context
        return await self._loadbalancer_service.dispatch(
            task, context=self._loadbalancer_context, timeout=timeout
        )

    async def _enter_context(self, factory):
        ctx = None
        if isinstance(factory, ContextManager):
            ctx = factory
            obj = ctx.__enter__()
        elif isinstance(factory, AsyncContextManager):
            ctx = factory
            obj = await ctx.__aenter__()
        elif callable(factory):
            return await self._enter_context(factory())
        elif isinstance(factory, Awaitable):
            obj = await factory
        else:
            obj = factory
        return obj, ctx

    async def _exit_context(
        self, ctx: AsyncContextManager | ContextManager | None, *args
    ):
        if isinstance(ctx, AsyncContextManager):
            await ctx.__aexit__(*args)
        elif isinstance(ctx, ContextManager):
            ctx.__exit__(*args)

    async def _await_workers(self):
        while not self._loadbalancer_context or not self._loadbalancer_context.workers:
            await asyncio.sleep(0)

    async def _worker_sentinel(self):
        assert self._loadbalancer_context
        async for event in self._discovery_stream:
            match event.type:
                case "worker-added":
                    self._loadbalancer_context.add_worker(
                        event.worker_info,
                        lambda: self._connection_pool.get(
                            f"{event.worker_info.host}:{event.worker_info.port}",
                        ),
                    )
                case "worker-updated":
                    self._loadbalancer_context.update_worker(
                        event.worker_info,
                        lambda: self._connection_pool.get(
                            f"{event.worker_info.host}:{event.worker_info.port}",
                        ),
                    )
                case "worker-dropped":
                    self._loadbalancer_context.remove_worker(event.worker_info)