pydocket 0.5.2__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

docket/__init__.py CHANGED
@@ -13,11 +13,13 @@ from .dependencies import (
     CurrentDocket,
     CurrentExecution,
     CurrentWorker,
+    Depends,
     ExponentialRetry,
     Perpetual,
     Retry,
     TaskKey,
     TaskLogger,
+    Timeout,
 )
 from .docket import Docket
 from .execution import Execution
@@ -36,5 +38,7 @@ __all__ = [
     "ExponentialRetry",
     "Logged",
     "Perpetual",
+    "Timeout",
+    "Depends",
     "__version__",
 ]
docket/annotations.py CHANGED
@@ -28,3 +28,19 @@ class Annotation(abc.ABC):
 
 class Logged(Annotation):
     """Instructs docket to include arguments to this parameter in the log."""
+
+    length_only: bool = False
+
+    def __init__(self, length_only: bool = False) -> None:
+        self.length_only = length_only
+
+    def format(self, argument: Any) -> str:
+        if self.length_only:
+            if isinstance(argument, (dict, set)):
+                return f"{{len {len(argument)}}}"
+            elif isinstance(argument, tuple):
+                return f"(len {len(argument)})"
+            elif hasattr(argument, "__len__"):
+                return f"[len {len(argument)}]"
+
+        return repr(argument)
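
For context on the hunk above: `Logged` is applied to task parameters as `Annotated` metadata, and the new `length_only` flag makes `call_repr()` log just an argument's size instead of its full `repr`. A minimal sketch, assuming a task registered with docket (the task and parameter names are illustrative, not from the package):

    from typing import Annotated

    from docket import Logged


    # With length_only=True, the argument renders as "[len 3]" for a list,
    # "{len 3}" for a dict or set, and "(len 3)" for a tuple.
    async def ingest(items: Annotated[list[str], Logged(length_only=True)]) -> None:
        ...
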
docket/dependencies.py CHANGED
@@ -1,35 +1,60 @@
 import abc
-import inspect
 import logging
+import time
+from contextlib import AsyncExitStack, asynccontextmanager
+from contextvars import ContextVar
 from datetime import timedelta
-from typing import Any, Awaitable, Callable, Counter, TypeVar, cast
+from types import TracebackType
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    AsyncContextManager,
+    AsyncGenerator,
+    Awaitable,
+    Callable,
+    Counter,
+    Generic,
+    TypeVar,
+    cast,
+)
 
 from .docket import Docket
-from .execution import Execution
-from .worker import Worker
+from .execution import Execution, TaskFunction, get_signature
+
+if TYPE_CHECKING:  # pragma: no cover
+    from .worker import Worker
 
 
 class Dependency(abc.ABC):
     single: bool = False
 
+    docket: ContextVar[Docket] = ContextVar("docket")
+    worker: ContextVar["Worker"] = ContextVar("worker")
+    execution: ContextVar[Execution] = ContextVar("execution")
+
     @abc.abstractmethod
-    def __call__(
-        self, docket: Docket, worker: Worker, execution: Execution
-    ) -> Any: ...  # pragma: no cover
+    async def __aenter__(self) -> Any: ...  # pragma: no cover
+
+    async def __aexit__(
+        self,
+        exc_type: type[BaseException] | None,
+        exc_value: BaseException | None,
+        traceback: TracebackType | None,
+    ) -> bool: ...  # pragma: no cover
 
 
 class _CurrentWorker(Dependency):
-    def __call__(self, docket: Docket, worker: Worker, execution: Execution) -> Worker:
-        return worker
+    async def __aenter__(self) -> "Worker":
+        return self.worker.get()
 
 
-def CurrentWorker() -> Worker:
-    return cast(Worker, _CurrentWorker())
+def CurrentWorker() -> "Worker":
+    return cast("Worker", _CurrentWorker())
 
 
 class _CurrentDocket(Dependency):
-    def __call__(self, docket: Docket, worker: Worker, execution: Execution) -> Docket:
-        return docket
+    async def __aenter__(self) -> Docket:
+        return self.docket.get()
 
 
 def CurrentDocket() -> Docket:
@@ -37,10 +62,8 @@ def CurrentDocket() -> Docket:
 
 
 class _CurrentExecution(Dependency):
-    def __call__(
-        self, docket: Docket, worker: Worker, execution: Execution
-    ) -> Execution:
-        return execution
+    async def __aenter__(self) -> Execution:
+        return self.execution.get()
 
 
 def CurrentExecution() -> Execution:
@@ -48,8 +71,8 @@ def CurrentExecution() -> Execution:
 
 
 class _TaskKey(Dependency):
-    def __call__(self, docket: Docket, worker: Worker, execution: Execution) -> str:
-        return execution.key
+    async def __aenter__(self) -> str:
+        return self.execution.get().key
 
 
 def TaskKey() -> str:
@@ -57,15 +80,14 @@
 
 
 class _TaskLogger(Dependency):
-    def __call__(
-        self, docket: Docket, worker: Worker, execution: Execution
-    ) -> logging.LoggerAdapter[logging.Logger]:
+    async def __aenter__(self) -> logging.LoggerAdapter[logging.Logger]:
+        execution = self.execution.get()
         logger = logging.getLogger(f"docket.task.{execution.function.__name__}")
         return logging.LoggerAdapter(
             logger,
             {
-                **docket.labels(),
-                **worker.labels(),
+                **self.docket.get().labels(),
+                **self.worker.get().labels(),
                 **execution.specific_labels(),
             },
         )
@@ -85,7 +107,8 @@ class Retry(Dependency):
         self.delay = delay
         self.attempt = 1
 
-    def __call__(self, docket: Docket, worker: Worker, execution: Execution) -> "Retry":
+    async def __aenter__(self) -> "Retry":
+        execution = self.execution.get()
         retry = Retry(attempts=self.attempts, delay=self.delay)
         retry.attempt = execution.attempt
         return retry
@@ -104,9 +127,9 @@ class ExponentialRetry(Retry):
         self.minimum_delay = minimum_delay
         self.maximum_delay = maximum_delay
 
-    def __call__(
-        self, docket: Docket, worker: Worker, execution: Execution
-    ) -> "ExponentialRetry":
+    async def __aenter__(self) -> "ExponentialRetry":
+        execution = self.execution.get()
+
         retry = ExponentialRetry(
             attempts=self.attempts,
             minimum_delay=self.minimum_delay,
@@ -155,9 +178,8 @@ class Perpetual(Dependency):
         self.automatic = automatic
         self.cancelled = False
 
-    def __call__(
-        self, docket: Docket, worker: Worker, execution: Execution
-    ) -> "Perpetual":
+    async def __aenter__(self) -> "Perpetual":
+        execution = self.execution.get()
         perpetual = Perpetual(every=self.every)
         perpetual.args = execution.args
         perpetual.kwargs = execution.kwargs
@@ -171,27 +193,121 @@ class Perpetual(Dependency):
         self.kwargs = kwargs
 
 
+class Timeout(Dependency):
+    single = True
+
+    base: timedelta
+
+    _deadline: float
+
+    def __init__(self, base: timedelta) -> None:
+        self.base = base
+
+    async def __aenter__(self) -> "Timeout":
+        timeout = Timeout(base=self.base)
+        timeout.start()
+        return timeout
+
+    def start(self) -> None:
+        self._deadline = time.monotonic() + self.base.total_seconds()
+
+    def expired(self) -> bool:
+        return time.monotonic() >= self._deadline
+
+    def remaining(self) -> timedelta:
+        return timedelta(seconds=self._deadline - time.monotonic())
+
+    def extend(self, by: timedelta | None = None) -> None:
+        if by is None:
+            by = self.base
+        self._deadline += by.total_seconds()
+
+
+R = TypeVar("R")
+
+DependencyFunction = Callable[..., Awaitable[R] | AsyncContextManager[R]]
+
+
+_parameter_cache: dict[
+    TaskFunction | DependencyFunction[Any],
+    dict[str, Dependency],
+] = {}
+
+
 def get_dependency_parameters(
-    function: Callable[..., Awaitable[Any]],
+    function: TaskFunction | DependencyFunction[Any],
 ) -> dict[str, Dependency]:
-    dependencies: dict[str, Any] = {}
+    if function in _parameter_cache:
+        return _parameter_cache[function]
+
+    dependencies: dict[str, Dependency] = {}
 
-    signature = inspect.signature(function)
+    signature = get_signature(function)
 
-    for param_name, param in signature.parameters.items():
+    for parameter, param in signature.parameters.items():
         if not isinstance(param.default, Dependency):
             continue
 
-        dependencies[param_name] = param.default
+        dependencies[parameter] = param.default
 
+    _parameter_cache[function] = dependencies
     return dependencies
 
 
+class _Depends(Dependency, Generic[R]):
+    dependency: DependencyFunction[R]
+
+    cache: ContextVar[dict[DependencyFunction[Any], Any]] = ContextVar("cache")
+    stack: ContextVar[AsyncExitStack] = ContextVar("stack")
+
+    def __init__(
+        self, dependency: Callable[[], Awaitable[R] | AsyncContextManager[R]]
+    ) -> None:
+        self.dependency = dependency
+
+    async def _resolve_parameters(
+        self,
+        function: TaskFunction | DependencyFunction[Any],
+    ) -> dict[str, Any]:
+        stack = self.stack.get()
+
+        arguments: dict[str, Any] = {}
+        parameters = get_dependency_parameters(function)
+
+        for parameter, dependency in parameters.items():
+            arguments[parameter] = await stack.enter_async_context(dependency)
+
+        return arguments
+
+    async def __aenter__(self) -> R:
+        cache = self.cache.get()
+
+        if self.dependency in cache:
+            return cache[self.dependency]
+
+        stack = self.stack.get()
+        arguments = await self._resolve_parameters(self.dependency)
+
+        value = self.dependency(**arguments)
+
+        if isinstance(value, AsyncContextManager):
+            value = await stack.enter_async_context(value)
+        else:
+            value = await value
+
+        cache[self.dependency] = value
+        return value
+
+
+def Depends(dependency: DependencyFunction[R]) -> R:
+    return cast(R, _Depends(dependency))
+
+
 D = TypeVar("D", bound=Dependency)
 
 
 def get_single_dependency_parameter_of_type(
-    function: Callable[..., Awaitable[Any]], dependency_type: type[D]
+    function: TaskFunction, dependency_type: type[D]
 ) -> D | None:
     assert dependency_type.single, "Dependency must be single"
     for _, dependency in get_dependency_parameters(function).items():
@@ -210,7 +326,7 @@ def get_single_dependency_of_type(
     return None
 
 
-def validate_dependencies(function: Callable[..., Awaitable[Any]]) -> None:
+def validate_dependencies(function: TaskFunction) -> None:
     parameters = get_dependency_parameters(function)
 
     counts = Counter(type(dependency) for dependency in parameters.values())
@@ -220,3 +336,31 @@ def validate_dependencies(function: Callable[..., Awaitable[Any]]) -> None:
         raise ValueError(
             f"Only one {dependency_type.__name__} dependency is allowed per task"
         )
+
+
+@asynccontextmanager
+async def resolved_dependencies(
+    worker: "Worker", execution: Execution
+) -> AsyncGenerator[dict[str, Any], None]:
+    # Set context variables once at the beginning
+    Dependency.docket.set(worker.docket)
+    Dependency.worker.set(worker)
+    Dependency.execution.set(execution)
+
+    _Depends.cache.set({})
+
+    async with AsyncExitStack() as stack:
+        _Depends.stack.set(stack)
+
+        arguments: dict[str, Any] = {}
+
+        parameters = get_dependency_parameters(execution.function)
+        for parameter, dependency in parameters.items():
+            kwargs = execution.kwargs
+            if parameter in kwargs:
+                arguments[parameter] = kwargs[parameter]
+                continue
+
+            arguments[parameter] = await stack.enter_async_context(dependency)
+
+        yield arguments
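
The hunks above rework dependencies from plain callables into async context managers and add two new ones: `Timeout`, a per-task deadline, and `Depends`, which resolves arbitrary async dependency functions and caches each one per execution. A minimal sketch of how a task might declare them, assuming the exports added to `docket/__init__.py` above (the task body and dependency function are illustrative):

    from datetime import timedelta

    from docket import Depends, Timeout


    async def get_client() -> str:
        # Resolved once per execution and cached, so every parameter bound to
        # Depends(get_client) within one execution shares this value.
        return "client"


    async def fetch(
        client: str = Depends(get_client),
        timeout: Timeout = Timeout(timedelta(seconds=30)),
    ) -> None:
        # The deadline starts when the task starts; the worker cancels the
        # task once it expires (see docket/worker.py below).
        assert not timeout.expired()
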
docket/docket.py CHANGED
@@ -38,6 +38,7 @@ from .execution import (
     Strike,
     StrikeInstruction,
     StrikeList,
+    TaskFunction,
 )
 from .instrumentation import (
     REDIS_DISRUPTIONS,
@@ -57,7 +58,7 @@ tracer: trace.Tracer = trace.get_tracer(__name__)
 P = ParamSpec("P")
 R = TypeVar("R")
 
-TaskCollection = Iterable[Callable[..., Awaitable[Any]]]
+TaskCollection = Iterable[TaskFunction]
 
 RedisStreamID = bytes
 RedisMessageID = bytes
@@ -91,7 +92,7 @@ class RunningExecution(Execution):
         worker: str,
         started: datetime,
     ) -> None:
-        self.function: Callable[..., Awaitable[Any]] = execution.function
+        self.function: TaskFunction = execution.function
         self.args: tuple[Any, ...] = execution.args
         self.kwargs: dict[str, Any] = execution.kwargs
         self.when: datetime = execution.when
@@ -111,7 +112,7 @@ class DocketSnapshot:
 
 
 class Docket:
-    tasks: dict[str, Callable[..., Awaitable[Any]]]
+    tasks: dict[str, TaskFunction]
     strike_list: StrikeList
 
     _monitor_strikes_task: asyncio.Task[None]
@@ -197,7 +198,7 @@ class Docket:
         finally:
             await asyncio.shield(r.__aexit__(None, None, None))
 
-    def register(self, function: Callable[..., Awaitable[Any]]) -> None:
+    def register(self, function: TaskFunction) -> None:
         from .dependencies import validate_dependencies
 
         validate_dependencies(function)
docket/execution.py CHANGED
@@ -7,23 +7,40 @@ from typing import Any, Awaitable, Callable, Hashable, Literal, Mapping, Self, c
 
 import cloudpickle  # type: ignore[import]
 
+from opentelemetry import propagate
+import opentelemetry.context
 
 from .annotations import Logged
+from docket.instrumentation import message_getter
 
 logger: logging.Logger = logging.getLogger(__name__)
 
+TaskFunction = Callable[..., Awaitable[Any]]
 Message = dict[bytes, bytes]
 
 
+_signature_cache: dict[Callable[..., Any], inspect.Signature] = {}
+
+
+def get_signature(function: Callable[..., Any]) -> inspect.Signature:
+    if function in _signature_cache:
+        return _signature_cache[function]
+
+    signature = inspect.signature(function)
+    _signature_cache[function] = signature
+    return signature
+
+
 class Execution:
     def __init__(
         self,
-        function: Callable[..., Awaitable[Any]],
+        function: TaskFunction,
         args: tuple[Any, ...],
         kwargs: dict[str, Any],
         when: datetime,
         key: str,
         attempt: int,
+        trace_context: opentelemetry.context.Context | None = None,
     ) -> None:
         self.function = function
         self.args = args
@@ -31,6 +48,7 @@ class Execution:
         self.when = when
         self.key = key
         self.attempt = attempt
+        self.trace_context = trace_context
 
     def as_message(self) -> Message:
         return {
@@ -43,9 +61,7 @@ class Execution:
         }
 
     @classmethod
-    def from_message(
-        cls, function: Callable[..., Awaitable[Any]], message: Message
-    ) -> Self:
+    def from_message(cls, function: TaskFunction, message: Message) -> Self:
         return cls(
             function=function,
             args=cloudpickle.loads(message[b"args"]),
@@ -53,6 +69,7 @@ class Execution:
             when=datetime.fromisoformat(message[b"when"].decode()),
             key=message[b"key"].decode(),
             attempt=int(message[b"attempt"].decode()),
+            trace_context=propagate.extract(message, getter=message_getter),
         )
 
     def general_labels(self) -> Mapping[str, str]:
@@ -68,7 +85,7 @@ class Execution:
 
     def call_repr(self) -> str:
         arguments: list[str] = []
-        signature = inspect.signature(self.function)
+        signature = get_signature(self.function)
         function_name = self.function.__name__
 
         logged_parameters = Logged.annotated_parameters(signature)
@@ -77,14 +94,14 @@ class Execution:
 
         for i, argument in enumerate(self.args[: len(parameter_names)]):
             parameter_name = parameter_names[i]
-            if parameter_name in logged_parameters:
-                arguments.append(repr(argument))
+            if logged := logged_parameters.get(parameter_name):
+                arguments.append(logged.format(argument))
             else:
                 arguments.append("...")
 
         for parameter_name, argument in self.kwargs.items():
-            if parameter_name in logged_parameters:
-                arguments.append(f"{parameter_name}={repr(argument)}")
+            if logged := logged_parameters.get(parameter_name):
+                arguments.append(f"{parameter_name}={logged.format(argument)}")
             else:
                 arguments.append(f"{parameter_name}=...")
 
@@ -217,10 +234,10 @@ class StrikeList:
         if function_name in self.task_strikes and not task_strikes:
             return True
 
-        sig = inspect.signature(execution.function)
+        signature = get_signature(execution.function)
 
         try:
-            bound_args = sig.bind(*execution.args, **execution.kwargs)
+            bound_args = signature.bind(*execution.args, **execution.kwargs)
             bound_args.apply_defaults()
         except TypeError:
             # If we can't make sense of the arguments, just assume the task is fine
@@ -265,6 +282,8 @@ class StrikeList:
                 case "between":  # pragma: no branch
                     lower, upper = strike_value
                     return lower <= value <= upper
+                case _:  # pragma: no cover
+                    raise ValueError(f"Unknown operator: {operator}")
         except (ValueError, TypeError):
             # If we can't make the comparison due to incompatible types, just log the
             # error and assume the task is not stricken
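
The new `trace_context` field carries the sender's OpenTelemetry context across the Redis message: `from_message` rebuilds it with `propagate.extract`, and the worker links its consumer span back to it. A sketch of that round trip using the standard propagation API (the inject side is not part of this diff, so the carrier handling here is illustrative):

    from opentelemetry import propagate, trace

    carrier: dict[str, str] = {}
    propagate.inject(carrier)  # sender: writes e.g. a traceparent entry

    context = propagate.extract(carrier)  # receiver: rebuilds the remote context
    initiating_span = trace.get_current_span(context)
    is_linkable = initiating_span.get_span_context().is_valid
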
docket/worker.py CHANGED
@@ -1,11 +1,10 @@
 import asyncio
-import inspect
 import logging
 import sys
 from datetime import datetime, timedelta, timezone
 from types import TracebackType
 from typing import (
-    TYPE_CHECKING,
+    Coroutine,
     Mapping,
     Protocol,
     Self,
@@ -13,18 +12,27 @@ from typing import (
 )
 from uuid import uuid4
 
-import redis.exceptions
-from opentelemetry import propagate, trace
+from opentelemetry import trace
 from opentelemetry.trace import Tracer
 from redis.asyncio import Redis
-from redis.exceptions import LockError
-
+from redis.exceptions import ConnectionError, LockError
+
+from docket.execution import get_signature
+
+from .dependencies import (
+    Dependency,
+    Perpetual,
+    Retry,
+    Timeout,
+    get_single_dependency_of_type,
+    get_single_dependency_parameter_of_type,
+    resolved_dependencies,
+)
 from .docket import (
     Docket,
     Execution,
     RedisMessage,
     RedisMessageID,
-    RedisMessages,
     RedisReadGroupResponse,
 )
 from .instrumentation import (
@@ -41,7 +49,6 @@ from .instrumentation import (
     TASKS_STARTED,
     TASKS_STRICKEN,
     TASKS_SUCCEEDED,
-    message_getter,
     metrics_server,
 )
 
@@ -49,10 +56,6 @@ logger: logging.Logger = logging.getLogger(__name__)
 tracer: Tracer = trace.get_tracer(__name__)
 
 
-if TYPE_CHECKING:  # pragma: no cover
-    from .dependencies import Dependency
-
-
 class _stream_due_tasks(Protocol):
     async def __call__(
         self, keys: list[str], args: list[str | float]
@@ -75,7 +78,7 @@ class Worker:
         concurrency: int = 10,
         redelivery_timeout: timedelta = timedelta(minutes=5),
         reconnection_delay: timedelta = timedelta(seconds=5),
-        minimum_check_interval: timedelta = timedelta(milliseconds=100),
+        minimum_check_interval: timedelta = timedelta(milliseconds=250),
         scheduling_resolution: timedelta = timedelta(milliseconds=250),
     ) -> None:
         self.docket = docket
@@ -197,13 +200,14 @@
     async def _run(self, forever: bool = False) -> None:
         logger.info("Starting worker %r with the following tasks:", self.name)
         for task_name, task in self.docket.tasks.items():
-            signature = inspect.signature(task)
+            signature = get_signature(task)
            logger.info("* %s%s", task_name, signature)
 
         while True:
             try:
-                return await self._worker_loop(forever=forever)
-            except redis.exceptions.ConnectionError:
+                async with self.docket.redis() as redis:
+                    return await self._worker_loop(redis, forever=forever)
+            except ConnectionError:
                 REDIS_DISRUPTIONS.add(1, self.labels())
                 logger.warning(
                     "Error connecting to redis, retrying in %s...",
@@ -212,123 +216,133 @@
                 )
                 await asyncio.sleep(self.reconnection_delay.total_seconds())
 
-    async def _worker_loop(self, forever: bool = False):
+    async def _worker_loop(self, redis: Redis, forever: bool = False):
         worker_stopping = asyncio.Event()
 
         await self._schedule_all_automatic_perpetual_tasks()
 
-        async with self.docket.redis() as redis:
-            scheduler_task = asyncio.create_task(
-                self._scheduler_loop(redis, worker_stopping)
+        scheduler_task = asyncio.create_task(
+            self._scheduler_loop(redis, worker_stopping)
+        )
+
+        active_tasks: dict[asyncio.Task[None], RedisMessageID] = {}
+        available_slots = self.concurrency
+
+        async def check_for_work() -> bool:
+            logger.debug("Checking for work", extra=self._log_context())
+            async with redis.pipeline() as pipeline:
+                pipeline.xlen(self.docket.stream_key)
+                pipeline.zcard(self.docket.queue_key)
+                results: list[int] = await pipeline.execute()
+                stream_len = results[0]
+                queue_len = results[1]
+                return stream_len > 0 or queue_len > 0
+
+        async def get_redeliveries(redis: Redis) -> RedisReadGroupResponse:
+            logger.debug("Getting redeliveries", extra=self._log_context())
+            _, redeliveries, *_ = await redis.xautoclaim(
+                name=self.docket.stream_key,
+                groupname=self.docket.worker_group_name,
+                consumername=self.name,
+                min_idle_time=int(self.redelivery_timeout.total_seconds() * 1000),
+                start_id="0-0",
+                count=available_slots,
+            )
+            return [(b"__redelivery__", redeliveries)]
+
+        async def get_new_deliveries(redis: Redis) -> RedisReadGroupResponse:
+            logger.debug("Getting new deliveries", extra=self._log_context())
+            return await redis.xreadgroup(
+                groupname=self.docket.worker_group_name,
+                consumername=self.name,
+                streams={self.docket.stream_key: ">"},
+                block=int(self.minimum_check_interval.total_seconds() * 1000),
+                count=available_slots,
             )
-            active_tasks: dict[asyncio.Task[None], RedisMessageID] = {}
-
-            async def check_for_work() -> bool:
-                async with redis.pipeline() as pipeline:
-                    pipeline.xlen(self.docket.stream_key)
-                    pipeline.zcard(self.docket.queue_key)
-                    results: list[int] = await pipeline.execute()
-                    stream_len = results[0]
-                    queue_len = results[1]
-                    return stream_len > 0 or queue_len > 0
-
-            async def process_completed_tasks() -> None:
-                completed_tasks = {task for task in active_tasks if task.done()}
-                for task in completed_tasks:
-                    message_id = active_tasks.pop(task)
-
-                    await task
-
-                    async with redis.pipeline() as pipeline:
-                        pipeline.xack(
-                            self.docket.stream_key,
-                            self.docket.worker_group_name,
-                            message_id,
-                        )
-                        pipeline.xdel(
-                            self.docket.stream_key,
-                            message_id,
-                        )
-                        await pipeline.execute()
 
-            has_work: bool = True
+        def start_task(message_id: RedisMessageID, message: RedisMessage) -> bool:
+            if not message:  # pragma: no cover
+                return False
 
-            if not forever:  # pragma: no branch
-                has_work = await check_for_work()
+            function_name = message[b"function"].decode()
+            if not (function := self.docket.tasks.get(function_name)):
+                logger.warning(
+                    "Task function %r not found",
+                    function_name,
+                    extra=self._log_context(),
+                )
+                return False
 
-            try:
-                while forever or has_work or active_tasks:
-                    await process_completed_tasks()
+            execution = Execution.from_message(function, message)
 
-                    available_slots = self.concurrency - len(active_tasks)
+            task = asyncio.create_task(self._execute(execution))
+            active_tasks[task] = message_id
 
-                    def start_task(
-                        message_id: RedisMessageID, message: RedisMessage
-                    ) -> None:
-                        if not message:  # pragma: no cover
-                            return
+            nonlocal available_slots
+            available_slots -= 1
 
-                        task = asyncio.create_task(self._execute(message))
-                        active_tasks[task] = message_id
+            return True
 
-                        nonlocal available_slots
-                        available_slots -= 1
+        async def ack_message(redis: Redis, message_id: RedisMessageID) -> None:
+            logger.debug("Acknowledging message", extra=self._log_context())
+            async with redis.pipeline() as pipeline:
+                pipeline.xack(
+                    self.docket.stream_key,
+                    self.docket.worker_group_name,
+                    message_id,
+                )
+                pipeline.xdel(
+                    self.docket.stream_key,
+                    message_id,
+                )
+                await pipeline.execute()
 
-                    if available_slots <= 0:
-                        await asyncio.sleep(self.minimum_check_interval.total_seconds())
-                        continue
-
-                    redeliveries: RedisMessages
-                    _, redeliveries, *_ = await redis.xautoclaim(
-                        name=self.docket.stream_key,
-                        groupname=self.docket.worker_group_name,
-                        consumername=self.name,
-                        min_idle_time=int(
-                            self.redelivery_timeout.total_seconds() * 1000
-                        ),
-                        start_id="0-0",
-                        count=available_slots,
-                    )
+        async def process_completed_tasks() -> None:
+            completed_tasks = {task for task in active_tasks if task.done()}
+            for task in completed_tasks:
+                message_id = active_tasks.pop(task)
+                await task
+                await ack_message(redis, message_id)
 
-                    for message_id, message in redeliveries:
-                        start_task(message_id, message)
+        has_work: bool = True
 
-                    if available_slots <= 0:
-                        continue
-
-                    new_deliveries: RedisReadGroupResponse = await redis.xreadgroup(
-                        groupname=self.docket.worker_group_name,
-                        consumername=self.name,
-                        streams={self.docket.stream_key: ">"},
-                        block=(
-                            int(self.minimum_check_interval.total_seconds() * 1000)
-                            if forever or active_tasks
-                            else None
-                        ),
-                        count=available_slots,
-                    )
+        try:
+            while forever or has_work or active_tasks:
+                await process_completed_tasks()
+
+                available_slots = self.concurrency - len(active_tasks)
 
-                    for _, messages in new_deliveries:
+                if available_slots <= 0:
+                    await asyncio.sleep(self.minimum_check_interval.total_seconds())
+                    continue
+
+                for source in [get_redeliveries, get_new_deliveries]:
+                    for _, messages in await source(redis):
                         for message_id, message in messages:
-                            start_task(message_id, message)
+                            if not start_task(message_id, message):
+                                await self._delete_known_task(redis, message)
+                                await ack_message(redis, message_id)
 
-                    if not forever and not active_tasks and not new_deliveries:
-                        has_work = await check_for_work()
+                    if available_slots <= 0:
+                        break
 
-            except asyncio.CancelledError:
-                if active_tasks:  # pragma: no cover
-                    logger.info(
-                        "Shutdown requested, finishing %d active tasks...",
-                        len(active_tasks),
-                        extra=self._log_context(),
-                    )
-            finally:
-                if active_tasks:
-                    await asyncio.gather(*active_tasks, return_exceptions=True)
-                    await process_completed_tasks()
+                if not forever and not active_tasks:
+                    has_work = await check_for_work()
 
-                worker_stopping.set()
-                await scheduler_task
+        except asyncio.CancelledError:
+            if active_tasks:  # pragma: no cover
+                logger.info(
+                    "Shutdown requested, finishing %d active tasks...",
+                    len(active_tasks),
+                    extra=self._log_context(),
+                )
+        finally:
+            if active_tasks:
+                await asyncio.gather(*active_tasks, return_exceptions=True)
                await process_completed_tasks()
 
            worker_stopping.set()
            await scheduler_task
@@ -389,6 +403,7 @@
 
         while not worker_stopping.is_set() or total_work:
             try:
+                logger.debug("Scheduling due tasks", extra=self._log_context())
                 total_work, due_work = await stream_due_tasks(
                     keys=[self.docket.queue_key, self.docket.stream_key],
                     args=[datetime.now(timezone.utc).timestamp(), self.docket.name],
@@ -415,8 +430,6 @@
         logger.debug("Scheduler loop finished", extra=self._log_context())
 
     async def _schedule_all_automatic_perpetual_tasks(self) -> None:
-        from .dependencies import Perpetual, get_single_dependency_parameter_of_type
-
         async with self.docket.redis() as redis:
             try:
                 async with redis.lock(
@@ -438,25 +451,22 @@
             except LockError:  # pragma: no cover
                 return
 
-    async def _execute(self, message: RedisMessage) -> None:
-        key = message[b"key"].decode()
-
-        log_context: Mapping[str, str | float] = self._log_context()
-
-        function_name = message[b"function"].decode()
-        function = self.docket.tasks.get(function_name)
-        if function is None:
-            async with self.docket.redis() as redis:
-                await redis.delete(self.docket.known_task_key(key))
-            logger.warning(
-                "Task function %r not found", function_name, extra=log_context
-            )
+    async def _delete_known_task(
+        self, redis: Redis, execution_or_message: Execution | RedisMessage
+    ) -> None:
+        if isinstance(execution_or_message, Execution):
+            key = execution_or_message.key
+        elif bytes_key := execution_or_message.get(b"key"):
+            key = bytes_key.decode()
+        else:  # pragma: no cover
             return
 
-        execution = Execution.from_message(function, message)
-        dependencies = self._get_dependencies(execution)
+        logger.debug("Deleting known task", extra=self._log_context())
+        known_task_key = self.docket.known_task_key(key)
+        await redis.delete(known_task_key)
 
-        log_context = {**log_context, **execution.specific_labels()}
+    async def _execute(self, execution: Execution) -> None:
+        log_context = {**self._log_context(), **execution.specific_labels()}
         counter_labels = {**self.labels(), **execution.general_labels()}
 
         arrow = "↬" if execution.attempt > 1 else "↪"
@@ -464,7 +474,7 @@
 
         if self.docket.strike_list.is_stricken(execution):
             async with self.docket.redis() as redis:
-                await redis.delete(self.docket.known_task_key(key))
+                await self._delete_known_task(redis, execution)
 
             arrow = "🗙"
             logger.warning("%s %s", arrow, call, extra=log_context)
@@ -474,20 +484,16 @@
         if execution.key in self._execution_counts:
             self._execution_counts[execution.key] += 1
 
-        # Preemptively reschedule the perpetual task for the future, or clear the
-        # known task key for this task
-        rescheduled = await self._perpetuate_if_requested(execution, dependencies)
-        if not rescheduled:
-            async with self.docket.redis() as redis:
-                await redis.delete(self.docket.known_task_key(key))
-
-        context = propagate.extract(message, getter=message_getter)
-        initiating_context = trace.get_current_span(context).get_span_context()
+        initiating_span = trace.get_current_span(execution.trace_context)
+        initiating_context = initiating_span.get_span_context()
         links = [trace.Link(initiating_context)] if initiating_context.is_valid else []
 
         start = datetime.now(timezone.utc)
         punctuality = start - execution.when
-        log_context = {**log_context, "punctuality": punctuality.total_seconds()}
+        log_context = {
+            **log_context,
+            "punctuality": punctuality.total_seconds(),
+        }
         duration = timedelta(0)
 
         TASKS_STARTED.add(1, counter_labels)
@@ -496,77 +502,103 @@
 
         logger.info("%s [%s] %s", arrow, punctuality, call, extra=log_context)
 
-        try:
-            with tracer.start_as_current_span(
-                execution.function.__name__,
-                kind=trace.SpanKind.CONSUMER,
-                attributes={
-                    **self.labels(),
-                    **execution.specific_labels(),
-                    "code.function.name": execution.function.__name__,
-                },
-                links=links,
-            ):
-                await execution.function(
-                    *execution.args,
-                    **{
-                        **execution.kwargs,
-                        **dependencies,
-                    },
+        with tracer.start_as_current_span(
+            execution.function.__name__,
+            kind=trace.SpanKind.CONSUMER,
+            attributes={
+                **self.labels(),
+                **execution.specific_labels(),
+                "code.function.name": execution.function.__name__,
+            },
+            links=links,
+        ):
+            async with resolved_dependencies(self, execution) as dependencies:
+                # Preemptively reschedule the perpetual task for the future, or clear
+                # the known task key for this task
+                rescheduled = await self._perpetuate_if_requested(
+                    execution, dependencies
                 )
+                if not rescheduled:
+                    async with self.docket.redis() as redis:
+                        await self._delete_known_task(redis, execution)
+
+                try:
+                    if timeout := get_single_dependency_of_type(dependencies, Timeout):
+                        await self._run_function_with_timeout(
+                            execution, dependencies, timeout
+                        )
+                    else:
+                        await execution.function(
+                            *execution.args,
+                            **{
+                                **execution.kwargs,
+                                **dependencies,
+                            },
+                        )
 
-            TASKS_SUCCEEDED.add(1, counter_labels)
-            duration = datetime.now(timezone.utc) - start
-            log_context["duration"] = duration.total_seconds()
-            rescheduled = await self._perpetuate_if_requested(
-                execution, dependencies, duration
-            )
-            arrow = "↫" if rescheduled else "↩"
-            logger.info("%s [%s] %s", arrow, duration, call, extra=log_context)
-        except Exception:
-            TASKS_FAILED.add(1, counter_labels)
-            duration = datetime.now(timezone.utc) - start
-            log_context["duration"] = duration.total_seconds()
-            retried = await self._retry_if_requested(execution, dependencies)
-            if not retried:
-                retried = await self._perpetuate_if_requested(
-                    execution, dependencies, duration
-                )
-            arrow = "↫" if retried else "↩"
-            logger.exception("%s [%s] %s", arrow, duration, call, extra=log_context)
-        finally:
-            TASKS_RUNNING.add(-1, counter_labels)
-            TASKS_COMPLETED.add(1, counter_labels)
-            TASK_DURATION.record(duration.total_seconds(), counter_labels)
+                    TASKS_SUCCEEDED.add(1, counter_labels)
+                    duration = datetime.now(timezone.utc) - start
+                    log_context["duration"] = duration.total_seconds()
+                    rescheduled = await self._perpetuate_if_requested(
+                        execution, dependencies, duration
+                    )
+                    arrow = "↫" if rescheduled else "↩"
+                    logger.info("%s [%s] %s", arrow, duration, call, extra=log_context)
+                except Exception:
+                    TASKS_FAILED.add(1, counter_labels)
+                    duration = datetime.now(timezone.utc) - start
+                    log_context["duration"] = duration.total_seconds()
+                    retried = await self._retry_if_requested(execution, dependencies)
+                    if not retried:
+                        retried = await self._perpetuate_if_requested(
+                            execution, dependencies, duration
+                        )
+                    arrow = "↫" if retried else "↩"
+                    logger.exception(
+                        "%s [%s] %s", arrow, duration, call, extra=log_context
+                    )
                finally:
+                    TASKS_RUNNING.add(-1, counter_labels)
+                    TASKS_COMPLETED.add(1, counter_labels)
+                    TASK_DURATION.record(duration.total_seconds(), counter_labels)
 
-    def _get_dependencies(
+    async def _run_function_with_timeout(
         self,
         execution: Execution,
-    ) -> dict[str, "Dependency"]:
-        from .dependencies import get_dependency_parameters
-
-        parameters = get_dependency_parameters(execution.function)
-
-        dependencies: dict[str, "Dependency"] = {}
-
-        for parameter_name, dependency in parameters.items():
-            # If the argument is already provided, skip it, which allows users to call
-            # the function directly with the arguments they want.
-            if parameter_name in execution.kwargs:
-                dependencies[parameter_name] = execution.kwargs[parameter_name]
-                continue
-
-            dependencies[parameter_name] = dependency(self.docket, self, execution)
+        dependencies: dict[str, Dependency],
+        timeout: Timeout,
+    ) -> None:
+        task_coro = cast(
+            Coroutine[None, None, None],
+            execution.function(*execution.args, **execution.kwargs, **dependencies),
+        )
+        task = asyncio.create_task(task_coro)
+        try:
+            while not task.done():  # pragma: no branch
+                remaining = timeout.remaining().total_seconds()
+                if timeout.expired():
+                    task.cancel()
+                    break
+
+                try:
+                    await asyncio.wait_for(asyncio.shield(task), timeout=remaining)
+                    return
+                except asyncio.TimeoutError:
+                    continue
+        finally:
+            if not task.done():
+                task.cancel()
 
-        return dependencies
+        try:
+            await task
+        except asyncio.CancelledError:
+            raise asyncio.TimeoutError
 
     async def _retry_if_requested(
         self,
         execution: Execution,
-        dependencies: dict[str, "Dependency"],
+        dependencies: dict[str, Dependency],
     ) -> bool:
-        from .dependencies import Retry, get_single_dependency_of_type
-
         retry = get_single_dependency_of_type(dependencies, Retry)
         if not retry:
             return False
@@ -584,11 +616,9 @@
     async def _perpetuate_if_requested(
         self,
         execution: Execution,
-        dependencies: dict[str, "Dependency"],
+        dependencies: dict[str, Dependency],
         duration: timedelta | None = None,
     ) -> bool:
-        from .dependencies import Perpetual, get_single_dependency_of_type
-
        perpetual = get_single_dependency_of_type(dependencies, Perpetual)
         if not perpetual:
             return False
@@ -667,7 +697,7 @@
 
             except asyncio.CancelledError:  # pragma: no cover
                 return
-            except redis.exceptions.ConnectionError:
+            except ConnectionError:
                 REDIS_DISRUPTIONS.add(1, self.labels())
                 logger.exception(
                     "Error sending worker heartbeat",
pydocket-0.5.2.dist-info/METADATA → pydocket-0.6.0.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: pydocket
-Version: 0.5.2
+Version: 0.6.0
 Summary: A distributed background task system for Python functions
 Project-URL: Homepage, https://github.com/chrisguidry/docket
 Project-URL: Bug Tracker, https://github.com/chrisguidry/docket/issues
pydocket-0.6.0.dist-info/RECORD ADDED
@@ -0,0 +1,16 @@
+docket/__init__.py,sha256=124XWbyQQHO1lhCoLQ-oheZnu4vNDHIaq4Whb7z3ogI,831
+docket/__main__.py,sha256=Vkuh7aJ-Bl7QVpVbbkUksAd_hn05FiLmWbc-8kbhZQ4,34
+docket/annotations.py,sha256=I00zB32BYWOQSNEjjCkc5n5DwTnT277I_BRYUJPS7w4,1474
+docket/cli.py,sha256=OWql6QFthSbvRCGkIg-ufo26F48z0eCmzRXJYOdyAEc,20309
+docket/dependencies.py,sha256=pkjseBZjdSpgW9g2H4cZ_RXIRZ2ZfdngBCXJGUcbmao,10052
+docket/docket.py,sha256=KJxgiyOskEHsRQOmfgLpJCYDNNleHI-vEKK3uBPL_K8,21420
+docket/execution.py,sha256=da1uYxSNAfz5FuNyCzX4I_PglHiMaf1oEv--K5TkjXc,13297
+docket/instrumentation.py,sha256=bZlGA02JoJcY0J1WGm5_qXDfY0AXKr0ZLAYu67wkeKY,4611
+docket/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+docket/tasks.py,sha256=RIlSM2omh-YDwVnCz6M5MtmK8T_m_s1w2OlRRxDUs6A,1437
+docket/worker.py,sha256=3sMcwGfSJ0Q4y5AuaqdgiGniDhJ21nM2PQmroJi_Q-A,26430
+pydocket-0.6.0.dist-info/METADATA,sha256=ktk1hqLmP_VSqYmdRtHFDPbEeRQD1J66ZAHEqaDXejk,13092
+pydocket-0.6.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+pydocket-0.6.0.dist-info/entry_points.txt,sha256=4WOk1nUlBsUT5O3RyMci2ImuC5XFswuopElYcLHtD5k,47
+pydocket-0.6.0.dist-info/licenses/LICENSE,sha256=YuVWU_ZXO0K_k2FG8xWKe5RGxV24AhJKTvQmKfqXuyk,1087
+pydocket-0.6.0.dist-info/RECORD,,
pydocket-0.5.2.dist-info/RECORD DELETED
@@ -1,16 +0,0 @@
-docket/__init__.py,sha256=7oruGALDoU6W_ntF-mMxxv3FFtO970DVzj3lUgoVIiM,775
-docket/__main__.py,sha256=Vkuh7aJ-Bl7QVpVbbkUksAd_hn05FiLmWbc-8kbhZQ4,34
-docket/annotations.py,sha256=GZwOPtPXyeIhnsLh3TQMBnXrjtTtSmF4Ratv4vjPx8U,950
-docket/cli.py,sha256=OWql6QFthSbvRCGkIg-ufo26F48z0eCmzRXJYOdyAEc,20309
-docket/dependencies.py,sha256=0P8GJTMWrzm9uZkQejCiRfT6IBisY7Hp1-4HAGTWv6w,6326
-docket/docket.py,sha256=p2G7QNn4H0sUhDlAI5BO5C6cRTy1ZWUZmFEuohX3RM8,21470
-docket/execution.py,sha256=PDrlAr8VzmB6JvqKO71YhXUcTcGQW7eyXrSKiTcAexE,12508
-docket/instrumentation.py,sha256=bZlGA02JoJcY0J1WGm5_qXDfY0AXKr0ZLAYu67wkeKY,4611
-docket/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-docket/tasks.py,sha256=RIlSM2omh-YDwVnCz6M5MtmK8T_m_s1w2OlRRxDUs6A,1437
-docket/worker.py,sha256=gqY_N7H9Jxh_0YIYQk0mucj_UrZNKItkT1xkuhwYmlY,25301
-pydocket-0.5.2.dist-info/METADATA,sha256=VbNbGmDdseQkzH64LFmsPNtw6kwbIc8cL73jlhS0vck,13092
-pydocket-0.5.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-pydocket-0.5.2.dist-info/entry_points.txt,sha256=4WOk1nUlBsUT5O3RyMci2ImuC5XFswuopElYcLHtD5k,47
-pydocket-0.5.2.dist-info/licenses/LICENSE,sha256=YuVWU_ZXO0K_k2FG8xWKe5RGxV24AhJKTvQmKfqXuyk,1087
-pydocket-0.5.2.dist-info/RECORD,,