pydocket 0.15.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docket/__init__.py +55 -0
- docket/__main__.py +3 -0
- docket/_uuid7.py +99 -0
- docket/agenda.py +202 -0
- docket/annotations.py +81 -0
- docket/cli.py +1185 -0
- docket/dependencies.py +808 -0
- docket/docket.py +1062 -0
- docket/execution.py +1370 -0
- docket/instrumentation.py +225 -0
- docket/py.typed +0 -0
- docket/tasks.py +59 -0
- docket/testing.py +235 -0
- docket/worker.py +1071 -0
- pydocket-0.15.3.dist-info/METADATA +160 -0
- pydocket-0.15.3.dist-info/RECORD +19 -0
- pydocket-0.15.3.dist-info/WHEEL +4 -0
- pydocket-0.15.3.dist-info/entry_points.txt +2 -0
- pydocket-0.15.3.dist-info/licenses/LICENSE +9 -0
docket/dependencies.py
ADDED
|
@@ -0,0 +1,808 @@
|
|
|
1
|
+
import abc
|
|
2
|
+
import inspect
|
|
3
|
+
import logging
|
|
4
|
+
import time
|
|
5
|
+
from contextlib import AsyncExitStack, asynccontextmanager
|
|
6
|
+
from contextvars import ContextVar
|
|
7
|
+
from datetime import datetime, timedelta, timezone
|
|
8
|
+
from types import TracebackType
|
|
9
|
+
from typing import (
|
|
10
|
+
TYPE_CHECKING,
|
|
11
|
+
Any,
|
|
12
|
+
AsyncContextManager,
|
|
13
|
+
AsyncGenerator,
|
|
14
|
+
Awaitable,
|
|
15
|
+
Callable,
|
|
16
|
+
ContextManager,
|
|
17
|
+
Counter,
|
|
18
|
+
Generic,
|
|
19
|
+
NoReturn,
|
|
20
|
+
TypeVar,
|
|
21
|
+
cast,
|
|
22
|
+
)
|
|
23
|
+
|
|
24
|
+
from .docket import Docket
|
|
25
|
+
from .execution import Execution, ExecutionProgress, TaskFunction, get_signature
|
|
26
|
+
from .instrumentation import CACHE_SIZE
|
|
27
|
+
# Run and RunProgress have been consolidated into Execution
|
|
28
|
+
|
|
29
|
+
if TYPE_CHECKING: # pragma: no cover
|
|
30
|
+
from .worker import Worker
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class Dependency(abc.ABC):
    """Abstract base for every docket task dependency.

    A dependency instance is used as a parameter *default* on a task function;
    the worker enters it as an async context manager and injects the value
    returned from ``__aenter__`` in place of the default.
    """

    # When True, at most one dependency of this type may appear per task
    # (enforced by validate_dependencies below).
    single: bool = False

    # Ambient execution context, set by the worker around each task run
    # (see resolved_dependencies); shared across all Dependency subclasses.
    docket: ContextVar[Docket] = ContextVar("docket")
    worker: ContextVar["Worker"] = ContextVar("worker")
    execution: ContextVar[Execution] = ContextVar("execution")

    @abc.abstractmethod
    async def __aenter__(self) -> Any: ...  # pragma: no cover

    # NOTE(review): the body is `...`, which returns None at runtime even though
    # the annotation says bool; None is falsy, so exceptions propagate — confirm
    # the annotation is intentional.
    async def __aexit__(
        self,
        _exc_type: type[BaseException] | None,
        _exc_value: BaseException | None,
        _traceback: TracebackType | None,
    ) -> bool: ...  # pragma: no cover
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class _CurrentWorker(Dependency):
    """Internal dependency resolving to the Worker running this task."""

    async def __aenter__(self) -> "Worker":
        # The worker stored itself in this context variable before the task ran.
        active_worker = self.worker.get()
        return active_worker
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def CurrentWorker() -> "Worker":
    """A dependency to access the current Worker.

    Example:

    ```python
    @task
    async def my_task(worker: Worker = CurrentWorker()) -> None:
        assert isinstance(worker, Worker)
    ```
    """
    # The cast lets callers annotate the parameter as Worker directly.
    dependency = _CurrentWorker()
    return cast("Worker", dependency)
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
class _CurrentDocket(Dependency):
    """Internal dependency resolving to the Docket this task belongs to."""

    async def __aenter__(self) -> Docket:
        active_docket = self.docket.get()
        return active_docket
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def CurrentDocket() -> Docket:
    """A dependency to access the current Docket.

    Example:

    ```python
    @task
    async def my_task(docket: Docket = CurrentDocket()) -> None:
        assert isinstance(docket, Docket)
    ```
    """
    # The cast lets callers annotate the parameter as Docket directly.
    dependency = _CurrentDocket()
    return cast(Docket, dependency)
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
class _CurrentExecution(Dependency):
    """Internal dependency resolving to the Execution currently running."""

    async def __aenter__(self) -> Execution:
        active_execution = self.execution.get()
        return active_execution
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def CurrentExecution() -> Execution:
    """A dependency to access the current Execution.

    Example:

    ```python
    @task
    async def my_task(execution: Execution = CurrentExecution()) -> None:
        assert isinstance(execution, Execution)
    ```
    """
    # The cast lets callers annotate the parameter as Execution directly.
    dependency = _CurrentExecution()
    return cast(Execution, dependency)
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
class _TaskKey(Dependency):
    """Internal dependency resolving to the current task's key."""

    async def __aenter__(self) -> str:
        current = self.execution.get()
        return current.key
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def TaskKey() -> str:
    """A dependency to access the key of the currently executing task.

    Example:

    ```python
    @task
    async def my_task(key: str = TaskKey()) -> None:
        assert isinstance(key, str)
    ```
    """
    # The cast lets callers annotate the parameter as a plain str.
    dependency = _TaskKey()
    return cast(str, dependency)
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
class _TaskArgument(Dependency):
    """Internal dependency resolving one argument of the current task.

    When `parameter` is None at declaration time, the resolver infers it from
    the parameter name before entering this dependency.
    """

    parameter: str | None
    optional: bool

    def __init__(self, parameter: str | None = None, optional: bool = False) -> None:
        self.parameter = parameter
        self.optional = optional

    async def __aenter__(self) -> Any:
        # The resolver must have filled in the parameter name by now.
        assert self.parameter is not None
        current = self.execution.get()
        try:
            return current.get_argument(self.parameter)
        except KeyError:
            if not self.optional:
                raise
            return None
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
def TaskArgument(parameter: str | None = None, optional: bool = False) -> Any:
    """A dependency to access an argument of the currently executing task. This is
    often useful in dependency functions so they can access the arguments of the
    task they are injected into.

    Args:
        parameter: The name of the task argument to read; when omitted, it is
            inferred from the name of the parameter this dependency is bound to.
        optional: When True, a missing argument resolves to None instead of
            raising KeyError.

    Example:

    ```python
    async def customer_name(customer_id: int = TaskArgument()) -> str:
        ...look up the customer's name by ID...
        return "John Doe"

    @task
    async def greet_customer(customer_id: int, name: str = Depends(customer_name)) -> None:
        print(f"Hello, {name}!")
    ```
    """
    return cast(Any, _TaskArgument(parameter, optional))
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
class _TaskLogger(Dependency):
    """Internal dependency building a LoggerAdapter tagged with docket,
    worker, and execution labels."""

    async def __aenter__(self) -> "logging.LoggerAdapter[logging.Logger]":
        execution = self.execution.get()
        logger = logging.getLogger(f"docket.task.{execution.function.__name__}")
        # Merge labels in increasing specificity: docket, worker, execution.
        extra: dict[str, Any] = {}
        extra.update(self.docket.get().labels())
        extra.update(self.worker.get().labels())
        extra.update(execution.specific_labels())
        return logging.LoggerAdapter(logger, extra)
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
def TaskLogger() -> "logging.LoggerAdapter[logging.Logger]":
    """A dependency to access a logger for the currently executing task. The logger
    will automatically inject contextual information such as the worker and docket
    name, the task key, and the current execution attempt number.

    Example:

    ```python
    @task
    async def my_task(logger: "LoggerAdapter[Logger]" = TaskLogger()) -> None:
        logger.info("Hello, world!")
    ```
    """
    # The cast lets callers annotate the parameter as a LoggerAdapter.
    dependency = _TaskLogger()
    return cast("logging.LoggerAdapter[logging.Logger]", dependency)
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
class Progress(Dependency):
    """A dependency to report progress updates for the currently executing task.

    Tasks can use this to report their current progress (current/total values) and
    status messages to external observers.

    Example:

    ```python
    @task
    async def process_records(records: list, progress: Progress = Progress()) -> None:
        await progress.set_total(len(records))
        for i, record in enumerate(records):
            await process(record)
            await progress.increment()
            await progress.set_message(f"Processed {record.id}")
    ```
    """

    def __init__(self) -> None:
        # Bound to the live ExecutionProgress when entered as a dependency.
        self._progress: ExecutionProgress | None = None

    async def __aenter__(self) -> "Progress":
        # Unlike most dependencies, this binds state onto self and returns
        # itself rather than a fresh copy.
        execution = self.execution.get()
        self._progress = execution.progress
        return self

    def _bound_progress(self) -> ExecutionProgress:
        """Return the underlying ExecutionProgress, failing if this instance
        was never entered as a dependency."""
        assert self._progress is not None, "Progress must be used as a dependency"
        return self._progress

    @property
    def current(self) -> int | None:
        """Current progress value."""
        return self._bound_progress().current

    @property
    def total(self) -> int:
        """Total/target value for progress tracking."""
        return self._bound_progress().total

    @property
    def message(self) -> str | None:
        """User-provided status message."""
        return self._bound_progress().message

    async def set_total(self, total: int) -> None:
        """Set the total/target value for progress tracking."""
        await self._bound_progress().set_total(total)

    async def increment(self, amount: int = 1) -> None:
        """Atomically increment the current progress value."""
        await self._bound_progress().increment(amount)

    async def set_message(self, message: str | None) -> None:
        """Update the progress status message."""
        await self._bound_progress().set_message(message)
|
|
255
|
+
|
|
256
|
+
|
|
257
|
+
class ForcedRetry(Exception):
    """Raised when a task requests a retry via `in_` or `at` on a `Retry`
    dependency; signals that the failure is a deliberate reschedule request."""
|
|
259
|
+
|
|
260
|
+
|
|
261
|
+
class Retry(Dependency):
    """Configures linear retries for a task. You can specify the total number of
    attempts (or `None` to retry indefinitely), and the delay between attempts.

    Example:

    ```python
    @task
    async def my_task(retry: Retry = Retry(attempts=3)) -> None:
        ...
    ```
    """

    single: bool = True

    def __init__(
        self, attempts: int | None = 1, delay: timedelta = timedelta(0)
    ) -> None:
        """
        Args:
            attempts: The total number of attempts to make. If `None`, the task
                will be retried indefinitely.
            delay: The delay between attempts.
        """
        self.attempts = attempts
        self.delay = delay
        self.attempt = 1

    async def __aenter__(self) -> "Retry":
        # Hand each execution its own Retry so concurrent tasks never share
        # mutable retry state; only the attempt counter comes from the execution.
        current = self.execution.get()
        fresh = Retry(attempts=self.attempts, delay=self.delay)
        fresh.attempt = current.attempt
        return fresh

    def at(self, when: datetime) -> NoReturn:
        """Request a retry at an absolute time (clamped to now at the earliest)."""
        delay = max(when - datetime.now(timezone.utc), timedelta(0))
        self.in_(delay)

    def in_(self, when: timedelta) -> NoReturn:
        """Request a retry after the given delay by raising ForcedRetry."""
        self.delay: timedelta = when
        raise ForcedRetry()
|
|
305
|
+
|
|
306
|
+
|
|
307
|
+
class ExponentialRetry(Retry):
    """Configures exponential retries for a task. You can specify the total number
    of attempts (or `None` to retry indefinitely), and the minimum and maximum delays
    between attempts.

    Example:

    ```python
    @task
    async def my_task(retry: ExponentialRetry = ExponentialRetry(attempts=3)) -> None:
        ...
    ```
    """

    def __init__(
        self,
        attempts: int | None = 1,
        minimum_delay: timedelta = timedelta(seconds=1),
        maximum_delay: timedelta = timedelta(seconds=64),
    ) -> None:
        """
        Args:
            attempts: The total number of attempts to make. If `None`, the task
                will be retried indefinitely.
            minimum_delay: The minimum delay between attempts.
            maximum_delay: The maximum delay between attempts.
        """
        super().__init__(attempts=attempts, delay=minimum_delay)
        self.maximum_delay = maximum_delay

    async def __aenter__(self) -> "ExponentialRetry":
        current = self.execution.get()

        fresh = ExponentialRetry(
            attempts=self.attempts,
            minimum_delay=self.delay,
            maximum_delay=self.maximum_delay,
        )
        fresh.attempt = current.attempt

        if current.attempt > 1:
            # Double the minimum delay for each completed attempt, capped at
            # the configured maximum.
            scaled = self.delay * (2 ** (current.attempt - 1))
            fresh.delay = min(scaled, self.maximum_delay)

        return fresh
|
|
357
|
+
|
|
358
|
+
|
|
359
|
+
class Perpetual(Dependency):
    """Declare a task that should be run perpetually. Perpetual tasks are automatically
    rescheduled for the future after they finish (whether they succeed or fail). A
    perpetual task can be scheduled at worker startup with the `automatic=True`.

    Example:

    ```python
    @task
    async def my_task(perpetual: Perpetual = Perpetual()) -> None:
        ...
    ```
    """

    single = True

    # Target interval between runs.
    every: timedelta
    # Whether the worker should schedule this task automatically at startup.
    automatic: bool

    # Arguments the next scheduled run will be called with (seeded from the
    # current execution in __aenter__, or overridden via perpetuate()).
    args: tuple[Any, ...]
    kwargs: dict[str, Any]

    # Set by cancel() to stop future rescheduling.
    cancelled: bool

    def __init__(
        self,
        every: timedelta = timedelta(0),
        automatic: bool = False,
    ) -> None:
        """
        Args:
            every: The target interval between task executions.
            automatic: If set, this task will be automatically scheduled during worker
                startup and continually through the worker's lifespan. This ensures
                that the task will always be scheduled despite crashes and other
                adverse conditions. Automatic tasks must not require any arguments.
        """
        self.every = every
        self.automatic = automatic
        self.cancelled = False

    async def __aenter__(self) -> "Perpetual":
        # Each execution gets its own copy, seeded with the incoming arguments.
        # NOTE(review): the copy does not carry `automatic` over (it defaults to
        # False); presumably that flag only matters at registration time — confirm.
        execution = self.execution.get()
        perpetual = Perpetual(every=self.every)
        perpetual.args = execution.args
        perpetual.kwargs = execution.kwargs
        return perpetual

    def cancel(self) -> None:
        # Stop this task from being rescheduled after the current run.
        self.cancelled = True

    def perpetuate(self, *args: Any, **kwargs: Any) -> None:
        # Override the arguments used when the task is rescheduled.
        self.args = args
        self.kwargs = kwargs
|
|
413
|
+
|
|
414
|
+
|
|
415
|
+
class Timeout(Dependency):
    """Configures a timeout for a task. You can specify the base timeout, and the
    task will be cancelled if it exceeds this duration. The timeout may be extended
    within the context of a single running task.

    Example:

    ```python
    @task
    async def my_task(timeout: Timeout = Timeout(timedelta(seconds=10))) -> None:
        ...
    ```
    """

    single: bool = True

    base: timedelta
    _deadline: float

    def __init__(self, base: timedelta) -> None:
        """
        Args:
            base: The base timeout duration.
        """
        self.base = base

    async def __aenter__(self) -> "Timeout":
        # Give each execution a fresh Timeout whose clock starts now.
        fresh = Timeout(base=self.base)
        fresh.start()
        return fresh

    def start(self) -> None:
        """Begin (or restart) the countdown from the monotonic clock."""
        self._deadline = self.base.total_seconds() + time.monotonic()

    def expired(self) -> bool:
        """True once the deadline has passed."""
        return not time.monotonic() < self._deadline

    def remaining(self) -> timedelta:
        """Get the remaining time until the timeout expires."""
        seconds_left = self._deadline - time.monotonic()
        return timedelta(seconds=seconds_left)

    def extend(self, by: timedelta | None = None) -> None:
        """Extend the timeout by a given duration. If no duration is provided, the
        base timeout will be used.

        Args:
            by: The duration to extend the timeout by.
        """
        extension = self.base if by is None else by
        self._deadline += extension.total_seconds()
|
|
466
|
+
|
|
467
|
+
|
|
468
|
+
# Return type produced by a user-defined dependency function.
R = TypeVar("R")

# A user dependency may return the value directly, an awaitable of it, or a
# (sync or async) context manager that yields it; _Depends handles all four.
DependencyFunction = Callable[
    ..., R | Awaitable[R] | ContextManager[R] | AsyncContextManager[R]
]


# Memoizes get_dependency_parameters per function. NOTE(review): unbounded —
# presumably fine because task/dependency functions live for the process
# lifetime; confirm nothing registers functions dynamically per request.
_parameter_cache: dict[
    TaskFunction | DependencyFunction[Any],
    dict[str, Dependency],
] = {}
|
|
479
|
+
|
|
480
|
+
|
|
481
|
+
def get_dependency_parameters(
    function: TaskFunction | DependencyFunction[Any],
) -> dict[str, Dependency]:
    """Map each parameter of `function` whose default is a Dependency to that
    Dependency instance, memoized per function."""
    if function not in _parameter_cache:
        signature = get_signature(function)
        _parameter_cache[function] = {
            name: param.default
            for name, param in signature.parameters.items()
            if isinstance(param.default, Dependency)
        }

    # Report the cache size on every call, hit or miss.
    CACHE_SIZE.set(len(_parameter_cache), {"cache": "parameter"})
    return _parameter_cache[function]
|
|
501
|
+
|
|
502
|
+
|
|
503
|
+
class _Depends(Dependency, Generic[R]):
    """Internal dependency wrapping a user-supplied dependency function.

    Resolved values are cached per task execution so a dependency shared by
    several parameters is evaluated only once; context managers are kept open
    on a shared AsyncExitStack for the duration of the task.
    """

    # The user function whose result will be injected.
    dependency: DependencyFunction[R]

    # Per-execution memo of resolved dependency values (set in
    # resolved_dependencies).
    cache: ContextVar[dict[DependencyFunction[Any], Any]] = ContextVar("cache")
    # Per-execution exit stack holding any context managers open.
    stack: ContextVar[AsyncExitStack] = ContextVar("stack")

    def __init__(
        self,
        dependency: Callable[
            [], R | Awaitable[R] | ContextManager[R] | AsyncContextManager[R]
        ],
    ) -> None:
        self.dependency = dependency

    async def _resolve_parameters(
        self,
        function: TaskFunction | DependencyFunction[Any],
    ) -> dict[str, Any]:
        """Recursively resolve the dependency-typed parameters of `function`."""
        stack = self.stack.get()

        arguments: dict[str, Any] = {}
        parameters = get_dependency_parameters(function)

        for parameter, dependency in parameters.items():
            # Special case for TaskArguments, they are "magical" and infer the parameter
            # they refer to from the parameter name (unless otherwise specified)
            if isinstance(dependency, _TaskArgument) and not dependency.parameter:
                dependency.parameter = parameter

            arguments[parameter] = await stack.enter_async_context(dependency)

        return arguments

    async def __aenter__(self) -> R:
        cache = self.cache.get()

        # Reuse a value already resolved earlier in this same execution.
        if self.dependency in cache:
            return cache[self.dependency]

        stack = self.stack.get()
        arguments = await self._resolve_parameters(self.dependency)

        raw_value: R | Awaitable[R] | ContextManager[R] | AsyncContextManager[R] = (
            self.dependency(**arguments)
        )

        # Handle different return types from the dependency function
        resolved_value: R
        if isinstance(raw_value, AsyncContextManager):
            # Async context manager: await enter_async_context
            resolved_value = await stack.enter_async_context(raw_value)
        elif isinstance(raw_value, ContextManager):
            # Sync context manager: use enter_context (no await needed)
            resolved_value = stack.enter_context(raw_value)
        elif inspect.iscoroutine(raw_value) or isinstance(raw_value, Awaitable):
            # Async function returning awaitable: await it
            resolved_value = await cast(Awaitable[R], raw_value)
        else:
            # Sync function returning a value directly, use as-is
            resolved_value = cast(R, raw_value)

        cache[self.dependency] = resolved_value
        return resolved_value
|
|
566
|
+
|
|
567
|
+
|
|
568
|
+
def Depends(dependency: DependencyFunction[R]) -> R:
    """Include a user-defined function as a dependency. Dependencies may be:
    - Synchronous functions returning a value
    - Asynchronous functions returning a value (awaitable)
    - Synchronous context managers (using @contextmanager)
    - Asynchronous context managers (using @asynccontextmanager)

    If a dependency returns a context manager, it will be entered and exited around
    the task, giving an opportunity to control the lifetime of a resource.

    **Important**: Synchronous dependencies should NOT include blocking I/O operations
    (file access, network calls, database queries, etc.). Use async dependencies for
    any I/O. Sync dependencies are best for:
    - Pure computations
    - In-memory data structure access
    - Configuration lookups from memory
    - Non-blocking transformations

    Examples:

    ```python
    # Sync dependency - pure computation, no I/O
    def get_config() -> dict:
        # Access in-memory config, no I/O
        return {"api_url": "https://api.example.com", "timeout": 30}

    # Sync dependency - compute value from arguments
    def build_query_params(
        user_id: int = TaskArgument(),
        config: dict = Depends(get_config)
    ) -> dict:
        # Pure computation, no I/O
        return {"user_id": user_id, "timeout": config["timeout"]}

    # Async dependency - I/O operations
    async def get_user(user_id: int = TaskArgument()) -> User:
        # Network I/O - must be async
        return await fetch_user_from_api(user_id)

    # Async context manager - I/O resource management
    from contextlib import asynccontextmanager

    @asynccontextmanager
    async def get_db_connection():
        # I/O operations - must be async
        conn = await db.connect()
        try:
            yield conn
        finally:
            await conn.close()

    @task
    async def my_task(
        params: dict = Depends(build_query_params),
        user: User = Depends(get_user),
        db: Connection = Depends(get_db_connection),
    ) -> None:
        await db.execute("UPDATE users SET ...", params)
    ```
    """
    # The cast lets callers annotate the parameter with the resolved type R,
    # even though the default is actually a _Depends marker object.
    return cast(R, _Depends(dependency))
|
|
629
|
+
|
|
630
|
+
|
|
631
|
+
class ConcurrencyLimit(Dependency):
    """Configures concurrency limits for a task based on specific argument values.

    This allows fine-grained control over task execution by limiting concurrent
    tasks based on the value of specific arguments.

    Example:

    ```python
    async def process_customer(
        customer_id: int,
        concurrency: ConcurrencyLimit = ConcurrencyLimit("customer_id", max_concurrent=1)
    ) -> None:
        # Only one task per customer_id will run at a time
        ...

    async def backup_db(
        db_name: str,
        concurrency: ConcurrencyLimit = ConcurrencyLimit("db_name", max_concurrent=3)
    ) -> None:
        # Only 3 backup tasks per database name will run at a time
        ...
    ```
    """

    single: bool = True

    def __init__(
        self, argument_name: str, max_concurrent: int = 1, scope: str | None = None
    ) -> None:
        """
        Args:
            argument_name: The name of the task argument to use for concurrency grouping
            max_concurrent: Maximum number of concurrent tasks per unique argument value
            scope: Optional scope prefix for Redis keys (defaults to docket name)
        """
        self.argument_name = argument_name
        self.max_concurrent = max_concurrent
        self.scope = scope
        self._concurrency_key: str | None = None
        self._initialized: bool = False

    async def __aenter__(self) -> "ConcurrencyLimit":
        execution = self.execution.get()
        docket = self.docket.get()

        # Get the argument value to group by
        try:
            argument_value = execution.get_argument(self.argument_name)
        except KeyError:
            # If argument not found, create a bypass limit that doesn't apply concurrency control
            limit = ConcurrencyLimit(
                self.argument_name, self.max_concurrent, self.scope
            )
            limit._concurrency_key = None  # Special marker for bypassed concurrency
            limit._initialized = True  # Mark as initialized but bypassed
            return limit

        # Create a concurrency key for this specific argument value
        scope = self.scope or docket.name
        # NOTE(review): this writes the key onto the declaration instance (self)
        # before copying it to the per-execution instance — presumably so other
        # machinery can read it; confirm whether the self-mutation is required.
        self._concurrency_key = (
            f"{scope}:concurrency:{self.argument_name}:{argument_value}"
        )

        limit = ConcurrencyLimit(self.argument_name, self.max_concurrent, self.scope)
        limit._concurrency_key = self._concurrency_key
        limit._initialized = True  # Mark as initialized
        return limit

    @property
    def concurrency_key(self) -> str | None:
        """Redis key used for tracking concurrency for this specific argument value.
        Returns None when concurrency control is bypassed due to missing arguments.
        Raises RuntimeError if accessed before initialization."""
        if not self._initialized:
            raise RuntimeError(
                "ConcurrencyLimit not initialized - use within task context"
            )
        return self._concurrency_key

    @property
    def is_bypassed(self) -> bool:
        """Returns True if concurrency control is bypassed due to missing arguments."""
        return self._initialized and self._concurrency_key is None
|
|
715
|
+
|
|
716
|
+
|
|
717
|
+
# Type variable bound to Dependency, used by the single-dependency lookups.
D = TypeVar("D", bound=Dependency)
|
|
718
|
+
|
|
719
|
+
|
|
720
|
+
def get_single_dependency_parameter_of_type(
    function: TaskFunction, dependency_type: type[D]
) -> D | None:
    """Return the dependency of `dependency_type` declared as a parameter
    default on `function`, or None if there is none.

    Only valid for "single" dependency types, which validate_dependencies
    guarantees appear at most once per task.
    """
    assert dependency_type.single, "Dependency must be single"
    # Parameter names are irrelevant here; scan only the dependency values.
    for dependency in get_dependency_parameters(function).values():
        if isinstance(dependency, dependency_type):
            return dependency
    return None
|
|
728
|
+
|
|
729
|
+
|
|
730
|
+
def get_single_dependency_of_type(
    dependencies: dict[str, Dependency], dependency_type: type[D]
) -> D | None:
    """Return the dependency of `dependency_type` from an already-resolved
    parameter-to-dependency mapping, or None if there is none.

    Only valid for "single" dependency types, which validate_dependencies
    guarantees appear at most once per task.
    """
    assert dependency_type.single, "Dependency must be single"
    # Parameter names are irrelevant here; scan only the dependency values.
    for dependency in dependencies.values():
        if isinstance(dependency, dependency_type):
            return dependency
    return None
|
|
738
|
+
|
|
739
|
+
|
|
740
|
+
def validate_dependencies(function: TaskFunction) -> None:
    """Raise ValueError if `function` declares more than one dependency of any
    type marked `single`."""
    parameters = get_dependency_parameters(function)

    type_counts = Counter(type(dep) for dep in parameters.values())

    for dependency_type, count in type_counts.items():
        if count > 1 and dependency_type.single:
            raise ValueError(
                f"Only one {dependency_type.__name__} dependency is allowed per task"
            )
|
|
750
|
+
|
|
751
|
+
|
|
752
|
+
class FailedDependency:
    """Sentinel injected in place of a dependency that failed to resolve,
    carrying the parameter name and the error that occurred."""

    def __init__(self, parameter: str, error: Exception) -> None:
        self.parameter = parameter
        self.error = error
|
|
756
|
+
|
|
757
|
+
|
|
758
|
+
@asynccontextmanager
async def resolved_dependencies(
    worker: "Worker", execution: Execution
) -> AsyncGenerator[dict[str, Any], None]:
    """Resolve all dependency-typed parameters of `execution.function`,
    yielding the argument mapping while keeping any dependency context
    managers open; failures resolve to FailedDependency markers instead of
    raising.
    """
    # Capture tokens for all contextvar sets to ensure proper cleanup
    docket_token = Dependency.docket.set(worker.docket)
    worker_token = Dependency.worker.set(worker)
    execution_token = Dependency.execution.set(execution)
    cache_token = _Depends.cache.set({})

    try:
        async with AsyncExitStack() as stack:
            stack_token = _Depends.stack.set(stack)
            try:
                arguments: dict[str, Any] = {}

                parameters = get_dependency_parameters(execution.function)
                for parameter, dependency in parameters.items():
                    # An explicitly-passed argument wins over the dependency.
                    kwargs = execution.kwargs
                    if parameter in kwargs:
                        arguments[parameter] = kwargs[parameter]
                        continue

                    # Special case for TaskArguments, they are "magical" and infer the parameter
                    # they refer to from the parameter name (unless otherwise specified). At
                    # the top-level task function call, it doesn't make sense to specify one
                    # _without_ a parameter name, so we'll call that a failed dependency.
                    if (
                        isinstance(dependency, _TaskArgument)
                        and not dependency.parameter
                    ):
                        arguments[parameter] = FailedDependency(
                            parameter, ValueError("No parameter name specified")
                        )
                        continue

                    try:
                        arguments[parameter] = await stack.enter_async_context(
                            dependency
                        )
                    except Exception as error:
                        # Record the failure instead of aborting resolution, so
                        # the task body can decide how to react.
                        arguments[parameter] = FailedDependency(parameter, error)

                yield arguments
            finally:
                _Depends.stack.reset(stack_token)
    finally:
        # Reset in reverse order of the sets above.
        _Depends.cache.reset(cache_token)
        Dependency.execution.reset(execution_token)
        Dependency.worker.reset(worker_token)
        Dependency.docket.reset(docket_token)
|