wool 0.1rc3-py3-none-any.whl → 0.1rc7-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


wool/_pool.py CHANGED
@@ -3,51 +3,141 @@ from __future__ import annotations
 import logging
 import os
 from contextvars import ContextVar
-from copy import copy
-from functools import partial, wraps
-from multiprocessing import Process, current_process
+from contextvars import Token
+from functools import partial
+from functools import wraps
+from multiprocessing import Pipe
+from multiprocessing import Process
+from multiprocessing import current_process
 from multiprocessing.managers import Server
-from signal import Signals, signal
-from threading import Event, Lock, Thread, current_thread
-from time import sleep
-from typing import TYPE_CHECKING, Callable, Coroutine
-from weakref import WeakSet
+from signal import Signals
+from signal import signal
+from threading import Event
+from threading import Semaphore
+from threading import Thread
+from typing import TYPE_CHECKING
+from typing import Coroutine
 
 import wool
-from wool._client import WoolClient
 from wool._manager import Manager
+from wool._session import WorkerPoolSession
+from wool._worker import Scheduler
 from wool._worker import Worker
 
 if TYPE_CHECKING:
+    from wool._queue import TaskQueue
     from wool._task import AsyncCallable
 
 
+def _stop(pool: WorkerPool, wait: bool, *_):
+    pool.stop(wait=wait)
+
+
 # PUBLIC
 def pool(
-    address: tuple[str, int],
+    host: str = "localhost",
+    port: int = 48800,
     *,
     authkey: bytes | None = None,
     breadth: int = 0,
     log_level: int = logging.INFO,
-) -> Callable[[AsyncCallable], AsyncCallable]:
-    def _pool(fn: AsyncCallable) -> AsyncCallable:
-        @wraps(fn)
-        async def wrapper(*args, **kwargs) -> Coroutine:
-            with WoolPool(
-                breadth=breadth,
-                address=address,
-                authkey=authkey,
-                log_level=log_level,
+) -> WorkerPool:
+    """
+    Convenience function to declare a worker pool context.
+
+    :param host: The hostname of the worker pool. Defaults to "localhost".
+    :param port: The port of the worker pool. Defaults to 48800.
+    :param authkey: Optional authentication key for the worker pool.
+    :param breadth: Number of worker processes in the pool. Defaults to 0
+        (CPU count).
+    :param log_level: Logging level for the worker pool. Defaults to
+        logging.INFO.
+    :return: A decorator that wraps the function to execute within the session.
+
+    Usage:
+
+    .. code-block:: python
+
+        import wool
+
+
+        @wool.pool(
+            host="localhost",
+            port=48800,
+            authkey=b"deadbeef",
+            breadth=4,
+        )
+        async def foo(): ...
+
+    This is equivalent to:
+
+    .. code-block:: python
+
+        import wool
+
+
+        async def foo():
+            with wool.pool(
+                host="localhost", port=48800, authkey=b"deadbeef", breadth=4
             ):
-                return await fn(*args, **kwargs)
+                ...
 
-        return wrapper
+    This decorator can also be combined with the ``@wool.task`` decorator to
+    declare a task that is tightly coupled with the specified pool:
+
+    .. code-block:: python
+
+        import wool
 
-    return _pool
+
+        @wool.pool(
+            host="localhost",
+            port=48800,
+            authkey=b"deadbeef",
+            breadth=4,
+        )
+        @wool.task
+        async def foo(): ...
+
+    .. note::
+
+        The order of decorators matters. To ensure that invocations of the
+        declared task are dispatched to the pool specified by ``@wool.pool``,
+        the ``@wool.task`` decorator must be applied after ``@wool.pool``.
+    """
+    return WorkerPool(
+        address=(host, port),
+        authkey=authkey,
+        breadth=breadth,
+        log_level=log_level,
+    )
 
 
 # PUBLIC
-class WoolPool(Process):
+class WorkerPool(Process):
+    """
+    A multiprocessing-based worker pool for executing asynchronous tasks. A
+    pool consists of a single manager process and at least a single worker
+    process. The manager process orchestrates its workers and serves client
+    dispatch requests. The worker process(es) execute(s) dispatched tasks on a
+    first-come, first-served basis.
+
+    The worker pool class is implemented as a context manager and decorator,
+    allowing users to easily spawn ephemeral pools that live for the duration
+    of a client application's execution and tightly couple functions to a pool.
+
+    :param address: The address of the worker pool (host, port).
+    :param authkey: Optional authentication key for the pool. If not specified,
+        the manager will inherit the authkey from the current process.
+    :param breadth: Number of worker processes in the pool. Defaults to CPU
+        count.
+    :param log_level: Logging level for the pool.
+    """
+
+    _wait_event: Event | None = None
+    _stop_event: Event | None = None
+    _stopped: bool = False
+
     def __init__(
         self,
         address: tuple[str, int] = ("localhost", 5050),
@@ -68,41 +158,139 @@ class WoolPool(Process):
         self._breadth: int = breadth
         self._address: tuple[str, int] = address
         self._log_level: int = log_level
-        self._outer_client: WoolClient | None = None
-        self._client = WoolClient(address=self._address)
+        self._token: Token | None = None
+        self._session = self.session_type(
+            address=self._address, authkey=self.authkey
+        )
+        self._get_ready, self._set_ready = Pipe(duplex=False)
+
+    def __call__(self, fn: AsyncCallable) -> AsyncCallable:
+        """
+        Decorate a function to be executed within the pool.
+
+        :param fn: The function to be executed.
+        :return: The wrapped function.
+        """
+
+        @wraps(fn)
+        async def wrapper(*args, **kwargs) -> Coroutine:
+            with self:
+                return await fn(*args, **kwargs)
+
+        return wrapper
 
     def __enter__(self):
+        """
+        Enter the context of the pool, starting the pool and connecting the
+        session.
+        """
         self.start()
-        self._outer_client = self.client_context.get()
-        self.client_context.set(self._client)
+        self._session.connect()
+        self._token = self.session_context.set(self._session)
 
     def __exit__(self, *_) -> None:
-        assert self._outer_client
-        self.client_context.set(self._outer_client)
-        self._outer_client = None
+        """
+        Exit the context of the pool, stopping the pool and disconnecting the
+        session.
+        """
+        assert self._token
+        self.session_context.reset(self._token)
         assert self.pid
-        self.stop()
-        self.join()
+        try:
+            self.stop(wait=True)
+        except ConnectionRefusedError:
+            logging.warning(
+                f"Connection to manager at {self._address} refused."
+            )
+        finally:
+            self.join()
+
+    @property
+    def session_type(self) -> type[WorkerPoolSession]:
+        """
+        Get the session type for the pool.
+
+        :return: The session type.
+        """
+        return WorkerPoolSession
+
+    @property
+    def session_context(self) -> ContextVar[WorkerPoolSession]:
+        """
+        Get the session context variable for the pool.
+
+        :return: The session context variable.
+        """
+        return wool.__wool_session__
 
     @property
-    def client_context(self) -> ContextVar[WoolClient]:
-        return wool.__wool_client__
+    def scheduler_type(self) -> type[Scheduler]:
+        """
+        Get the scheduler type for the pool.
+
+        :return: The scheduler type.
+        """
+        return Scheduler
 
     @property
     def log_level(self) -> int:
+        """
+        Get the logging level for the pool.
+
+        :return: The logging level.
+        """
         return self._log_level
 
     @log_level.setter
     def log_level(self, value: int) -> None:
+        """
+        Set the logging level for the pool.
+
+        :param value: The new logging level.
+        """
         if value < 0:
             raise ValueError("Log level must be non-negative")
         self._log_level = value
 
     @property
     def breadth(self) -> int:
+        """
+        Get the number of worker processes in the pool.
+
+        :return: The number of worker processes.
+        """
         return self._breadth
 
+    @property
+    def waiting(self) -> bool | None:
+        """
+        Check if the pool is in a waiting state.
+
+        :return: True if waiting, False otherwise, or None if undefined.
+        """
+        return self._wait_event and self._wait_event.is_set()
+
+    @property
+    def stopping(self) -> bool | None:
+        """
+        Check if the pool is in a stopping state.
+
+        :return: True if stopping, False otherwise, or None if undefined.
+        """
+        return self._stop_event and self._stop_event.is_set()
+
+    def start(self) -> None:
+        """
+        Start the pool process and wait for it to be ready.
+        """
+        super().start()
+        self._get_ready.recv()
+        self._get_ready.close()
+
     def run(self) -> None:
+        """
+        Run the pool process, managing workers and the manager process.
+        """
         if self.log_level:
             wool.__log_level__ = self.log_level
             logging.basicConfig(format=wool.__log_format__)
@@ -111,95 +299,150 @@ class WoolPool(Process):
 
         logging.debug("Thread started")
 
-        signal(Signals.SIGINT, partial(stop, self, True))
-        signal(Signals.SIGTERM, partial(stop, self, False))
+        signal(Signals.SIGINT, partial(_stop, self, False))
+        signal(Signals.SIGTERM, partial(_stop, self, True))
 
-        self.lock = Lock()
-        self._worker_sentinels = WeakSet()
-
-        self._manager = Manager(address=self._address, authkey=self.authkey)
-
-        server = self._manager.get_server()
-        server_thread = Thread(
-            target=server.serve_forever, name="ServerThread", daemon=True
+        self.manager_sentinel = ManagerSentinel(
+            address=self._address, authkey=self.authkey
         )
-        server_thread.start()
+        self.manager_sentinel.start()
 
-        self._manager.connect()
+        self._wait_event = self.manager_sentinel.waiting
+        self._stop_event = self.manager_sentinel.stopping
 
-        self._stop_event = Event()
-        shutdown_sentinel = ShutdownSentinel(
-            self._stop_event, server, self._worker_sentinels
-        )
-        shutdown_sentinel.start()
-
-        with self.lock:
+        worker_sentinels = []
+        logging.info("Spawning workers...")
+        try:
             for i in range(1, self.breadth + 1):
-                if self._stop_event.is_set():
-                    break
-                logging.debug(f"Spawning worker {i}...")
-                worker_sentinel = WorkerSentinel(
-                    address=self._address,
-                    log_level=self.log_level,
-                    id=i,
-                    lock=self.lock,
-                )
-                worker_sentinel.start()
-                self._worker_sentinels.add(worker_sentinel)
-
-        current_thread().name = "IdleSentinel"
-        while not self.idle() and not self._stop_event.is_set():
-            self._stop_event.wait(1)
-        else:
-            self.stop()
-
-        server_thread.join()
-        for worker_sentinel in self._worker_sentinels:
-            if worker_sentinel.is_alive():
+                if not self._stop_event.is_set():
+                    worker_sentinel = WorkerSentinel(
+                        address=self._address,
+                        log_level=self.log_level,
+                        id=i,
+                        scheduler=self.scheduler_type,
+                    )
+                    worker_sentinel.start()
+                    worker_sentinels.append(worker_sentinel)
+            for worker_sentinel in worker_sentinels:
+                worker_sentinel.ready.wait()
+            self._set_ready.send(True)
+            self._set_ready.close()
+        except Exception:
+            logging.exception("Error in worker pool")
+            raise
+        finally:
+            while not self.idle and not self.stopping:
+                self._stop_event.wait(1)
+            else:
+                self.stop(wait=bool(self.idle or self.waiting))
+
+            logging.info("Stopping workers...")
+            for worker_sentinel in worker_sentinels:
+                if worker_sentinel.is_alive():
+                    worker_sentinel.stop(wait=self.waiting)
+            for worker_sentinel in worker_sentinels:
                 worker_sentinel.join()
-        if shutdown_sentinel.is_alive():
-            shutdown_sentinel.join()
 
+            logging.info("Stopping manager...")
+            if self.manager_sentinel.is_alive():
+                self.manager_sentinel.stop()
+            self.manager_sentinel.join()
+
+    @property
     def idle(self):
-        assert self._manager
+        """
+        Check if the pool is idle.
+
+        :return: True if idle, False otherwise.
+        """
+        assert self.manager_sentinel
         try:
-            return self._manager.queue().idle()
+            return self.manager_sentinel.idle
         except (ConnectionRefusedError, ConnectionResetError):
             return True
 
     def stop(self, *, wait: bool = True) -> None:
+        """
+        Stop the pool process.
+
+        :param wait: Whether to wait for the pool to stop gracefully.
+        """
         if self.pid == current_process().pid:
-            with self.lock:
-                if self._stop_event and not self._stop_event.is_set():
-                    self._stop_event.set()
-                    for worker_sentinel in self._worker_sentinels:
-                        if worker_sentinel.is_alive():
-                            worker_sentinel.stop(wait=wait)
+            if wait and self.waiting is False and self.stopping is False:
+                assert self._wait_event
+                self._wait_event.set()
+            if self.stopping is False:
+                assert self._stop_event
+                self._stop_event.set()
         elif self.pid:
-            if not self._client.connected:
-                self._client.connect()
-            self._client.stop(wait=wait)
+            self._session.stop(wait=wait)
+
+
+class ManagerSentinel(Thread):
+    _wait_event: Event | None = None
+    _stop_event: Event | None = None
+    _queue: TaskQueue | None = None
+
+    def __init__(
+        self, address: tuple[str, int], authkey: bytes, *args, **kwargs
+    ) -> None:
+        self._manager: Manager = Manager(address=address, authkey=authkey)
+        self._server: Server = self._manager.get_server()
+        super().__init__(*args, name=self.__class__.__name__, **kwargs)
+
+    @property
+    def waiting(self) -> Event:
+        if not self._wait_event:
+            self._manager.connect()
+            self._wait_event = self._manager.waiting()
+        return self._wait_event
+
+    @property
+    def stopping(self) -> Event:
+        if not self._stop_event:
+            self._manager.connect()
+            self._stop_event = self._manager.stopping()
+        return self._stop_event
+
+    @property
+    def idle(self) -> bool | None:
+        if not self._queue:
+            self._manager.connect()
+            self._queue = self._manager.queue()
+        return self._queue.idle()
+
+    def run(self) -> None:
+        self._server.serve_forever()
+
+    def stop(self) -> None:
+        stop_event = getattr(self._server, "stop_event")
+        assert isinstance(stop_event, Event)
+        logging.debug("Stopping manager...")
+        stop_event.set()
 
 
 class WorkerSentinel(Thread):
     _worker: Worker | None = None
+    _semaphore: Semaphore = Semaphore(8)
 
     def __init__(
         self,
         address: tuple[str, int],
         *args,
         id: int,
-        lock: Lock,
         cooldown: float = 1,
         log_level: int = logging.INFO,
+        scheduler: type[Scheduler] = Scheduler,
         **kwargs,
     ) -> None:
         self._address: tuple[str, int] = address
         self._id: int = id
-        self._lock: Lock = lock
         self._cooldown: float = cooldown
         self._log_level: int = log_level
+        self._scheduler_type = scheduler
         self._stop_event: Event = Event()
+        self._wait_event: Event = Event()
+        self._ready: Event = Event()
         super().__init__(
             *args, name=f"{self.__class__.__name__}-{self.id}", **kwargs
         )
@@ -213,8 +456,8 @@ class WorkerSentinel(Thread):
         return self._id
 
     @property
-    def lock(self) -> Lock:
-        return self._lock
+    def ready(self) -> Event:
+        return self._ready
 
     @property
     def cooldown(self) -> float:
@@ -230,12 +473,23 @@ class WorkerSentinel(Thread):
     def log_level(self) -> int:
         return self._log_level
 
+    @property
+    def waiting(self) -> bool:
+        return self._wait_event.is_set()
+
+    @property
+    def stopping(self) -> bool:
+        return self._stop_event.is_set()
+
     @log_level.setter
     def log_level(self, value: int) -> None:
         if value < 0:
             raise ValueError("Log level must be non-negative")
         self._log_level = value
 
+    def start(self) -> None:
+        super().start()
+
     def run(self) -> None:
         logging.debug("Thread started")
         while not self._stop_event.is_set():
@@ -243,15 +497,13 @@ class WorkerSentinel(Thread):
                 address=self._address,
                 name=f"Worker-{self.id}",
                 log_level=self.log_level,
+                scheduler=self._scheduler_type,
             )
-            with self.lock:
-                if self._stop_event.is_set():
-                    logging.debug("Worker interrupted before starting")
-                    break
+            with self._semaphore:
                 worker.start()
-                self._worker = worker
-                logging.info(f"Spawned worker process {worker.pid}")
-                sleep(0.05)
+            self._worker = worker
+            logging.info(f"Spawned worker process {worker.pid}")
+            self._ready.set()
             try:
                 worker.join()
             except Exception as e:
@@ -259,58 +511,14 @@ class WorkerSentinel(Thread):
             finally:
                 logging.info(f"Terminated worker process {worker.pid}")
                 self._worker = None
-                self._stop_event.wait(self.cooldown)
+            self._stop_event.wait(self.cooldown)
         logging.debug("Thread stopped")
 
     def stop(self, *, wait: bool = True) -> None:
-        if not self._stop_event.is_set():
+        logging.info(f"Stopping thread {self.name}...")
+        if wait and not self.waiting:
+            self._wait_event.set()
+        if not self.stopping:
             self._stop_event.set()
-            if self.worker:
-                self.worker.stop(wait=wait)
-
-
-class ShutdownSentinel(Thread):
-    def __init__(
-        self,
-        stop: Event,
-        server: Server,
-        worker_sentinels: WeakSet[WorkerSentinel],
-        *args,
-        **kwargs,
-    ) -> None:
-        self._stop_event: Event = stop
-        self._server: Server = server
-        self._worker_sentinels: WeakSet[WorkerSentinel] = worker_sentinels
-        super().__init__(*args, name="ShutdownSentinel", **kwargs)
-
-    @property
-    def stop(self) -> Event:
-        return self._stop_event
-
-    @property
-    def server(self) -> Server:
-        return self._server
-
-    @property
-    def worker_sentinels(self) -> WeakSet[WorkerSentinel]:
-        return self._worker_sentinels
-
-    def run(self) -> None:
-        logging.debug("Thread started")
-        try:
-            while not self._stop_event.is_set():
-                self._stop_event.wait(1)
-            else:
-                logging.debug("Stopping workers...")
-                for worker_sentinel in copy(self._worker_sentinels):
-                    worker_sentinel.join()
-                stop_event = getattr(self.server, "stop_event")
-                assert isinstance(stop_event, Event)
-                logging.debug("Stopping manager...")
-                stop_event.set()
-        finally:
-            logging.debug("Thread stopped")
-
-
-def stop(pool: WoolPool, wait: bool, *_):
-    pool.stop(wait=wait)
+        if self._worker:
+            self._worker.stop(wait=self._wait_event.is_set())
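
Net effect of the _pool.py changes: the 0.1rc3 decorator factory `pool()` and `WoolPool` class are replaced by a single `WorkerPool` that acts as both context manager and decorator, client objects are replaced by `WorkerPoolSession`, and the signal handlers are remapped so SIGINT triggers an immediate stop (wait=False) while SIGTERM triggers a graceful one (wait=True). A minimal sketch of the new call pattern, assembled from the docstrings above; host, port, and authkey values are illustrative placeholders, not defaults the package mandates:

import asyncio

import wool


async def main():
    # Entering the pool context spawns the manager and worker processes,
    # connects a WorkerPoolSession, and tears everything down on exit.
    with wool.pool(host="localhost", port=48800, authkey=b"deadbeef", breadth=4):
        ...  # tasks dispatched here run in the pool's worker processes


asyncio.run(main())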
wool/_mempool/_metadata/_metadata_pb2.py ADDED
@@ -0,0 +1,36 @@
+# -*- coding: utf-8 -*-
+# Generated by the protocol buffer compiler. DO NOT EDIT!
+# NO CHECKED-IN PROTOBUF GENCODE
+# source: _mempool/_metadata/_metadata.proto
+# Protobuf Python Version: 6.30.0
+"""Generated protocol buffer code."""
+from google.protobuf import descriptor as _descriptor
+from google.protobuf import descriptor_pool as _descriptor_pool
+from google.protobuf import runtime_version as _runtime_version
+from google.protobuf import symbol_database as _symbol_database
+from google.protobuf.internal import builder as _builder
+_runtime_version.ValidateProtobufRuntimeVersion(
+    _runtime_version.Domain.PUBLIC,
+    6,
+    30,
+    0,
+    '',
+    '_mempool/_metadata/_metadata.proto'
+)
+# @@protoc_insertion_point(imports)
+
+_sym_db = _symbol_database.Default()
+
+
+
+
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\"_mempool/_metadata/_metadata.proto\x12\x17wool._mempool._metadata\"K\n\x10_MetadataMessage\x12\x0b\n\x03ref\x18\x01 \x01(\t\x12\x0f\n\x07mutable\x18\x02 \x01(\x08\x12\x0c\n\x04size\x18\x03 \x01(\x03\x12\x0b\n\x03md5\x18\x04 \x01(\x0c\x62\x06proto3')
+
+_globals = globals()
+_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
+_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, '_mempool._metadata._metadata_pb2', _globals)
+if not _descriptor._USE_C_DESCRIPTORS:
+    DESCRIPTOR._loaded_options = None
+    _globals['__METADATAMESSAGE']._serialized_start=63
+    _globals['__METADATAMESSAGE']._serialized_end=138
+# @@protoc_insertion_point(module_scope)
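
The added module registers a single message type, `_MetadataMessage`, in proto package `wool._mempool._metadata`, with fields `ref` (string), `mutable` (bool), `size` (int64), and `md5` (bytes), all decoded from the serialized descriptor above. A hedged round-trip sketch; the file path in the header and the import path below are inferred from the module name passed to the builder and may differ in the installed wheel:

from wool._mempool._metadata import _metadata_pb2  # import path assumed

# Build, serialize, and re-parse the metadata record.
msg = _metadata_pb2._MetadataMessage(
    ref="blob-1", mutable=False, size=1024, md5=b"\x00" * 16
)
data = msg.SerializeToString()  # wire-format bytes

parsed = _metadata_pb2._MetadataMessage()
parsed.ParseFromString(data)
assert parsed.ref == "blob-1" and parsed.size == 1024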