pydocket 0.15.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
docket/dependencies.py ADDED
@@ -0,0 +1,808 @@
1
+ import abc
2
+ import inspect
3
+ import logging
4
+ import time
5
+ from contextlib import AsyncExitStack, asynccontextmanager
6
+ from contextvars import ContextVar
7
+ from datetime import datetime, timedelta, timezone
8
+ from types import TracebackType
9
+ from typing import (
10
+ TYPE_CHECKING,
11
+ Any,
12
+ AsyncContextManager,
13
+ AsyncGenerator,
14
+ Awaitable,
15
+ Callable,
16
+ ContextManager,
17
+ Counter,
18
+ Generic,
19
+ NoReturn,
20
+ TypeVar,
21
+ cast,
22
+ )
23
+
24
+ from .docket import Docket
25
+ from .execution import Execution, ExecutionProgress, TaskFunction, get_signature
26
+ from .instrumentation import CACHE_SIZE
27
+ # Run and RunProgress have been consolidated into Execution
28
+
29
+ if TYPE_CHECKING: # pragma: no cover
30
+ from .worker import Worker
31
+
32
+
33
class Dependency(abc.ABC):
    """Base class for every object that can be injected as a task parameter default.

    Subclasses implement ``__aenter__`` to produce the value the task receives.
    The class-level context variables expose the docket, worker and execution
    that are active while dependencies are being resolved.
    """

    docket: ContextVar[Docket] = ContextVar("docket")
    worker: ContextVar["Worker"] = ContextVar("worker")
    execution: ContextVar[Execution] = ContextVar("execution")

    # When True, a task may declare at most one dependency of this type
    # (enforced by validate_dependencies).
    single: bool = False

    @abc.abstractmethod
    async def __aenter__(self) -> Any: ...  # pragma: no cover

    async def __aexit__(
        self,
        _exc_type: type[BaseException] | None,
        _exc_value: BaseException | None,
        _traceback: TracebackType | None,
    ) -> bool: ...  # pragma: no cover
49
+
50
+
51
class _CurrentWorker(Dependency):
    """Resolves to the Worker that is running the current task."""

    async def __aenter__(self) -> "Worker":
        current = self.worker.get()
        return current
54
+
55
+
56
def CurrentWorker() -> "Worker":
    """Declare a parameter that receives the currently running Worker.

    Example:

    ```python
    @task
    async def my_task(worker: Worker = CurrentWorker()) -> None:
        assert isinstance(worker, Worker)
    ```
    """
    marker = _CurrentWorker()
    # The cast only satisfies the type checker: at runtime the default is the
    # dependency object, which resolution replaces with the actual Worker.
    return cast("Worker", marker)
68
+
69
+
70
class _CurrentDocket(Dependency):
    """Resolves to the Docket the current task belongs to."""

    async def __aenter__(self) -> Docket:
        current = self.docket.get()
        return current
73
+
74
+
75
def CurrentDocket() -> Docket:
    """Declare a parameter that receives the current Docket.

    Example:

    ```python
    @task
    async def my_task(docket: Docket = CurrentDocket()) -> None:
        assert isinstance(docket, Docket)
    ```
    """
    marker = _CurrentDocket()
    # Typed as Docket for callers; resolution swaps in the real Docket.
    return cast(Docket, marker)
87
+
88
+
89
class _CurrentExecution(Dependency):
    """Resolves to the Execution currently being run."""

    async def __aenter__(self) -> Execution:
        current = self.execution.get()
        return current
92
+
93
+
94
def CurrentExecution() -> Execution:
    """Declare a parameter that receives the current Execution.

    Example:

    ```python
    @task
    async def my_task(execution: Execution = CurrentExecution()) -> None:
        assert isinstance(execution, Execution)
    ```
    """
    marker = _CurrentExecution()
    # Typed as Execution for callers; resolution swaps in the real Execution.
    return cast(Execution, marker)
106
+
107
+
108
class _TaskKey(Dependency):
    """Resolves to the key of the execution currently being run."""

    async def __aenter__(self) -> str:
        current = self.execution.get()
        return current.key
111
+
112
+
113
def TaskKey() -> str:
    """Declare a parameter that receives the key of the running task.

    Example:

    ```python
    @task
    async def my_task(key: str = TaskKey()) -> None:
        assert isinstance(key, str)
    ```
    """
    marker = _TaskKey()
    # Typed as str for callers; resolution swaps in the execution's key.
    return cast(str, marker)
125
+
126
+
127
class _TaskArgument(Dependency):
    """Resolves to one of the running task's arguments, looked up by name.

    When ``optional`` is True, a missing argument resolves to ``None`` instead
    of propagating the ``KeyError`` from ``Execution.get_argument``.
    """

    parameter: str | None
    optional: bool

    def __init__(self, parameter: str | None = None, optional: bool = False) -> None:
        self.parameter = parameter
        self.optional = optional

    async def __aenter__(self) -> Any:
        # The name must be given explicitly or inferred by the resolver beforehand.
        assert self.parameter is not None
        current = self.execution.get()
        try:
            value = current.get_argument(self.parameter)
        except KeyError:
            if not self.optional:
                raise
            value = None
        return value
144
+
145
+
146
def TaskArgument(parameter: str | None = None, optional: bool = False) -> Any:
    """Declare a parameter that receives an argument of the running task.

    This is mostly useful inside dependency functions, letting them read the
    arguments of the task they are injected into. When ``parameter`` is omitted,
    the name of the parameter this default is attached to is used. With
    ``optional=True``, a missing argument resolves to ``None``.

    Example:

    ```python
    async def customer_name(customer_id: int = TaskArgument()) -> str:
        ...look up the customer's name by ID...
        return "John Doe"

    @task
    async def greet_customer(customer_id: int, name: str = Depends(customer_name)) -> None:
        print(f"Hello, {name}!")
    ```
    """
    marker = _TaskArgument(parameter, optional)
    return cast(Any, marker)
164
+
165
+
166
class _TaskLogger(Dependency):
    """Resolves to a LoggerAdapter named after the running task's function.

    The adapter's ``extra`` merges the docket's, worker's and execution's labels,
    with later sources overriding earlier ones on key collisions.
    """

    async def __aenter__(self) -> "logging.LoggerAdapter[logging.Logger]":
        current = self.execution.get()
        task_logger = logging.getLogger(f"docket.task.{current.function.__name__}")
        docket_labels = self.docket.get().labels()
        worker_labels = self.worker.get().labels()
        execution_labels = current.specific_labels()
        merged = {**docket_labels, **worker_labels, **execution_labels}
        return logging.LoggerAdapter(task_logger, merged)
178
+
179
+
180
def TaskLogger() -> "logging.LoggerAdapter[logging.Logger]":
    """Declare a parameter that receives a logger for the running task.

    The logger is named ``docket.task.<function name>``. It is wrapped in a
    ``LoggerAdapter`` whose extra context combines the labels reported by the
    current docket, worker and execution.

    Example:

    ```python
    @task
    async def my_task(logger: "LoggerAdapter[Logger]" = TaskLogger()) -> None:
        logger.info("Hello, world!")
    ```
    """
    marker = _TaskLogger()
    return cast("logging.LoggerAdapter[logging.Logger]", marker)
194
+
195
+
196
class Progress(Dependency):
    """A dependency to report progress updates for the currently executing task.

    Tasks can use this to report their current progress (current/total values) and
    status messages to external observers.

    Example:

    ```python
    @task
    async def process_records(records: list, progress: Progress = Progress()) -> None:
        await progress.set_total(len(records))
        for i, record in enumerate(records):
            await process(record)
            await progress.increment()
            await progress.set_message(f"Processed {record.id}")
    ```
    """

    def __init__(self) -> None:
        # Bound to the execution's progress tracker when resolved; None on the
        # template instance that sits in the task's signature.
        self._progress: ExecutionProgress | None = None

    async def __aenter__(self) -> "Progress":
        execution = self.execution.get()
        # Return a fresh instance rather than binding `self`. `self` is the default
        # value in the task signature, shared by every execution of the task.
        # Mutating it would let concurrent executions overwrite each other's
        # progress tracker.
        progress = Progress()
        progress._progress = execution.progress
        return progress

    @property
    def current(self) -> int | None:
        """Current progress value."""
        assert self._progress is not None, "Progress must be used as a dependency"
        return self._progress.current

    @property
    def total(self) -> int:
        """Total/target value for progress tracking."""
        assert self._progress is not None, "Progress must be used as a dependency"
        return self._progress.total

    @property
    def message(self) -> str | None:
        """User-provided status message."""
        assert self._progress is not None, "Progress must be used as a dependency"
        return self._progress.message

    async def set_total(self, total: int) -> None:
        """Set the total/target value for progress tracking."""
        assert self._progress is not None, "Progress must be used as a dependency"
        await self._progress.set_total(total)

    async def increment(self, amount: int = 1) -> None:
        """Increment the current progress value by ``amount``."""
        assert self._progress is not None, "Progress must be used as a dependency"
        await self._progress.increment(amount)

    async def set_message(self, message: str | None) -> None:
        """Update the progress status message."""
        assert self._progress is not None, "Progress must be used as a dependency"
        await self._progress.set_message(message)
255
+
256
+
257
class ForcedRetry(Exception):
    """Signals that a task explicitly asked to be retried via `Retry.in_` or `Retry.at`."""
259
+
260
+
261
class Retry(Dependency):
    """Configures linear retries for a task: a fixed delay between attempts and an
    overall attempt budget (``None`` means retry without limit).

    Example:

    ```python
    @task
    async def my_task(retry: Retry = Retry(attempts=3)) -> None:
        ...
    ```
    """

    single: bool = True

    def __init__(
        self, attempts: int | None = 1, delay: timedelta = timedelta(0)
    ) -> None:
        """
        Args:
            attempts: The total number of attempts to make. If `None`, the task will
                be retried indefinitely.
            delay: The delay between attempts.
        """
        self.attempts = attempts
        self.delay = delay
        self.attempt = 1

    async def __aenter__(self) -> "Retry":
        current = self.execution.get()
        # Per-execution copy carrying the execution's attempt number.
        bound = Retry(attempts=self.attempts, delay=self.delay)
        bound.attempt = current.attempt
        return bound

    def at(self, when: datetime) -> NoReturn:
        """Request a retry at ``when`` (timezone-aware); past times retry immediately."""
        until_then = when - datetime.now(timezone.utc)
        self.in_(max(until_then, timedelta(0)))

    def in_(self, when: timedelta) -> NoReturn:
        """Request a retry after ``when`` by recording it as the delay and raising."""
        self.delay: timedelta = when
        raise ForcedRetry()
305
+
306
+
307
class ExponentialRetry(Retry):
    """Configures exponential retries for a task. You can specify the total number
    of attempts (or `None` to retry indefinitely), and the minimum and maximum delays
    between attempts. The delay doubles with each attempt after the first, starting
    from the minimum and capped at the maximum.

    Example:

    ```python
    @task
    async def my_task(retry: ExponentialRetry = ExponentialRetry(attempts=3)) -> None:
        ...
    ```
    """

    def __init__(
        self,
        attempts: int | None = 1,
        minimum_delay: timedelta = timedelta(seconds=1),
        maximum_delay: timedelta = timedelta(seconds=64),
    ) -> None:
        """
        Args:
            attempts: The total number of attempts to make. If `None`, the task will
                be retried indefinitely.
            minimum_delay: The minimum delay between attempts.
            maximum_delay: The maximum delay between attempts.
        """
        super().__init__(attempts=attempts, delay=minimum_delay)
        self.maximum_delay = maximum_delay

    async def __aenter__(self) -> "ExponentialRetry":
        execution = self.execution.get()

        retry = ExponentialRetry(
            attempts=self.attempts,
            minimum_delay=self.delay,
            maximum_delay=self.maximum_delay,
        )
        retry.attempt = execution.attempt

        if execution.attempt > 1:
            backoff_factor = 2 ** (execution.attempt - 1)
            try:
                calculated_delay = self.delay * backoff_factor
            except OverflowError:
                # timedelta cannot represent the product. With the default 1 s
                # minimum this happens from attempt 48, which is reachable with
                # attempts=None. For a positive minimum delay the true value is
                # far past any cap, so clamp instead of failing the dependency.
                calculated_delay = self.maximum_delay
            retry.delay = min(calculated_delay, self.maximum_delay)

        return retry
357
+
358
+
359
class Perpetual(Dependency):
    """Declare a task that should be run perpetually. Perpetual tasks are automatically
    rescheduled for the future after they finish (whether they succeed or fail). A
    perpetual task can be scheduled at worker startup with `automatic=True`.

    Example:

    ```python
    @task
    async def my_task(perpetual: Perpetual = Perpetual()) -> None:
        ...
    ```
    """

    single = True

    every: timedelta
    automatic: bool

    # Arguments for the next run; initialised from the current execution when
    # resolved and replaceable via perpetuate().
    args: tuple[Any, ...]
    kwargs: dict[str, Any]

    cancelled: bool

    def __init__(
        self,
        every: timedelta = timedelta(0),
        automatic: bool = False,
    ) -> None:
        """
        Args:
            every: The target interval between task executions.
            automatic: If set, this task will be automatically scheduled during worker
                startup and continually through the worker's lifespan. This ensures
                that the task will always be scheduled despite crashes and other
                adverse conditions. Automatic tasks must not require any arguments.
        """
        self.every = every
        self.automatic = automatic
        self.cancelled = False

    async def __aenter__(self) -> "Perpetual":
        execution = self.execution.get()
        # Per-execution copy. Carry over the full configuration (including
        # `automatic`), as the other single dependencies do for their settings.
        perpetual = Perpetual(every=self.every, automatic=self.automatic)
        perpetual.args = execution.args
        perpetual.kwargs = execution.kwargs
        return perpetual

    def cancel(self) -> None:
        """Set the `cancelled` flag (presumably read by the worker to stop
        rescheduling; confirm in the worker)."""
        self.cancelled = True

    def perpetuate(self, *args: Any, **kwargs: Any) -> None:
        """Replace the arguments stored for the next run of this task."""
        self.args = args
        self.kwargs = kwargs
413
+
414
+
415
class Timeout(Dependency):
    """Configures a timeout for a task. The task is cancelled if it runs longer
    than the base duration; a running task may push its own deadline back.

    Example:

    ```python
    @task
    async def my_task(timeout: Timeout = Timeout(timedelta(seconds=10))) -> None:
        ...
    ```
    """

    single: bool = True

    base: timedelta
    _deadline: float

    def __init__(self, base: timedelta) -> None:
        """
        Args:
            base: The base timeout duration.
        """
        self.base = base

    async def __aenter__(self) -> "Timeout":
        # Each execution gets its own armed instance; the signature default is a template.
        running = Timeout(base=self.base)
        running.start()
        return running

    def start(self) -> None:
        """Arm the deadline `base` from now, measured on the monotonic clock."""
        self._deadline = time.monotonic() + self.base.total_seconds()

    def expired(self) -> bool:
        """Return True once the monotonic clock has reached the deadline."""
        return self._deadline <= time.monotonic()

    def remaining(self) -> timedelta:
        """Get the remaining time until the timeout expires (negative once past)."""
        return timedelta(seconds=self._deadline - time.monotonic())

    def extend(self, by: timedelta | None = None) -> None:
        """Push the deadline back by `by`, or by the base duration when omitted.

        Args:
            by: The duration to extend the timeout by.
        """
        increment = self.base if by is None else by
        self._deadline += increment.total_seconds()
466
+
467
+
468
+ R = TypeVar("R")
469
+
470
+ DependencyFunction = Callable[
471
+ ..., R | Awaitable[R] | ContextManager[R] | AsyncContextManager[R]
472
+ ]
473
+
474
+
475
# Memo for get_dependency_parameters: maps each task or dependency function to its
# parameters whose defaults are Dependency objects. Nothing in this module removes
# entries, so it grows with the number of distinct functions inspected.
_parameter_cache: dict[
    TaskFunction | DependencyFunction[Any],
    dict[str, Dependency],
] = {}
479
+
480
+
481
def get_dependency_parameters(
    function: TaskFunction | DependencyFunction[Any],
) -> dict[str, Dependency]:
    """Return the parameters of `function` whose default values are Dependency objects.

    The mapping (parameter name -> dependency default) is memoized per function in
    `_parameter_cache`. The cache-size gauge is refreshed on every call, hit or miss.
    """
    found = _parameter_cache.get(function)
    if found is None:
        parameters = get_signature(function).parameters
        found = {
            name: param.default
            for name, param in parameters.items()
            if isinstance(param.default, Dependency)
        }
        _parameter_cache[function] = found
    CACHE_SIZE.set(len(_parameter_cache), {"cache": "parameter"})
    return found
501
+
502
+
503
class _Depends(Dependency, Generic[R]):
    """Dependency that resolves a user-supplied function; created by `Depends`.

    The function's own Dependency-defaulted parameters are resolved first. The
    function is then called and its result is unwrapped according to its type
    (async/sync context manager, awaitable, or plain value).
    """

    # The user-supplied callable whose (unwrapped) result is injected.
    dependency: DependencyFunction[R]

    # Per-resolution memo keyed by dependency function. A function that resolved
    # successfully is not called again while the same cache dict is active.
    cache: ContextVar[dict[DependencyFunction[Any], Any]] = ContextVar("cache")
    # Exit stack that owns any context managers entered during resolution.
    stack: ContextVar[AsyncExitStack] = ContextVar("stack")

    def __init__(
        self,
        dependency: DependencyFunction[R],
    ) -> None:
        self.dependency = dependency

    async def _resolve_parameters(
        self,
        function: TaskFunction | DependencyFunction[Any],
    ) -> dict[str, Any]:
        """Resolve the Dependency-defaulted parameters of `function`.

        Returns a mapping of parameter name to resolved value. Each dependency is
        entered on the shared exit stack, so its cleanup runs when that stack closes.
        Errors raised while entering a dependency propagate to the caller.
        """
        stack = self.stack.get()

        arguments: dict[str, Any] = {}
        parameters = get_dependency_parameters(function)

        for parameter, dependency in parameters.items():
            # Special case for TaskArguments, they are "magical" and infer the parameter
            # they refer to from the parameter name (unless otherwise specified)
            # NOTE(review): this writes the inferred name onto the default object
            # stored in the function's signature, so it persists across calls.
            if isinstance(dependency, _TaskArgument) and not dependency.parameter:
                dependency.parameter = parameter

            arguments[parameter] = await stack.enter_async_context(dependency)

        return arguments

    async def __aenter__(self) -> R:
        """Resolve the wrapped function, reusing a cached result when available."""
        cache = self.cache.get()

        if self.dependency in cache:
            return cache[self.dependency]

        stack = self.stack.get()
        arguments = await self._resolve_parameters(self.dependency)

        raw_value: R | Awaitable[R] | ContextManager[R] | AsyncContextManager[R] = (
            self.dependency(**arguments)
        )

        # Handle different return types from the dependency function.
        # Order matters when an object satisfies several protocols: async context
        # managers win over sync ones, and both win over plain awaitables.
        resolved_value: R
        if isinstance(raw_value, AsyncContextManager):
            # Async context manager: await enter_async_context
            resolved_value = await stack.enter_async_context(raw_value)
        elif isinstance(raw_value, ContextManager):
            # Sync context manager: use enter_context (no await needed)
            resolved_value = stack.enter_context(raw_value)
        elif inspect.iscoroutine(raw_value) or isinstance(raw_value, Awaitable):
            # Async function returning awaitable: await it
            resolved_value = await cast(Awaitable[R], raw_value)
        else:
            # Sync function returning a value directly, use as-is
            resolved_value = cast(R, raw_value)

        cache[self.dependency] = resolved_value
        return resolved_value
566
+
567
+
568
def Depends(dependency: DependencyFunction[R]) -> R:
    """Inject the result of a user-defined function into a task or another dependency.

    The function may be any of:
    - a synchronous function returning a value
    - an asynchronous function (its awaitable is awaited)
    - a synchronous context manager factory (e.g. decorated with @contextmanager)
    - an asynchronous context manager factory (e.g. decorated with @asynccontextmanager)

    When the function produces a context manager, it is entered and exited around
    the task, which makes it a natural way to manage a resource's lifetime.

    **Important**: keep blocking I/O (file access, network calls, database queries,
    etc.) out of synchronous dependencies; use an async dependency for any I/O.
    Synchronous dependencies are best for:
    - Pure computations
    - In-memory data structure access
    - Configuration lookups from memory
    - Non-blocking transformations

    Examples:

    ```python
    # Sync dependency - pure computation, no I/O
    def get_config() -> dict:
        # Access in-memory config, no I/O
        return {"api_url": "https://api.example.com", "timeout": 30}

    # Sync dependency - compute value from arguments
    def build_query_params(
        user_id: int = TaskArgument(),
        config: dict = Depends(get_config)
    ) -> dict:
        # Pure computation, no I/O
        return {"user_id": user_id, "timeout": config["timeout"]}

    # Async dependency - I/O operations
    async def get_user(user_id: int = TaskArgument()) -> User:
        # Network I/O - must be async
        return await fetch_user_from_api(user_id)

    # Async context manager - I/O resource management
    from contextlib import asynccontextmanager

    @asynccontextmanager
    async def get_db_connection():
        # I/O operations - must be async
        conn = await db.connect()
        try:
            yield conn
        finally:
            await conn.close()

    @task
    async def my_task(
        params: dict = Depends(build_query_params),
        user: User = Depends(get_user),
        db: Connection = Depends(get_db_connection),
    ) -> None:
        await db.execute("UPDATE users SET ...", params)
    ```
    """
    wrapped = _Depends(dependency)
    # Typed as R for callers; the resolver replaces it with the function's result.
    return cast(R, wrapped)
629
+
630
+
631
class ConcurrencyLimit(Dependency):
    """Configures concurrency limits for a task based on specific argument values.

    This allows fine-grained control over task execution by limiting concurrent
    tasks based on the value of specific arguments.

    Example:

    ```python
    async def process_customer(
        customer_id: int,
        concurrency: ConcurrencyLimit = ConcurrencyLimit("customer_id", max_concurrent=1)
    ) -> None:
        # Only one task per customer_id will run at a time
        ...

    async def backup_db(
        db_name: str,
        concurrency: ConcurrencyLimit = ConcurrencyLimit("db_name", max_concurrent=3)
    ) -> None:
        # Only 3 backup tasks per database name will run at a time
        ...
    ```
    """

    single: bool = True

    def __init__(
        self, argument_name: str, max_concurrent: int = 1, scope: str | None = None
    ) -> None:
        """
        Args:
            argument_name: The name of the task argument to use for concurrency grouping
            max_concurrent: Maximum number of concurrent tasks per unique argument value
            scope: Optional scope prefix for Redis keys (defaults to docket name)
        """
        self.argument_name = argument_name
        self.max_concurrent = max_concurrent
        self.scope = scope
        self._concurrency_key: str | None = None
        self._initialized: bool = False

    async def __aenter__(self) -> "ConcurrencyLimit":
        execution = self.execution.get()
        docket = self.docket.get()

        # Build the per-execution limit on a fresh instance. `self` is the default
        # value in the task signature, shared by every execution of the task, so
        # it must not carry any one execution's key.
        limit = ConcurrencyLimit(self.argument_name, self.max_concurrent, self.scope)
        limit._initialized = True

        # Get the argument value to group by
        try:
            argument_value = execution.get_argument(self.argument_name)
        except KeyError:
            # Argument not found: leave the key as None, which marks the limit as
            # bypassed (no concurrency control applied).
            return limit

        # Create a concurrency key for this specific argument value
        scope = self.scope or docket.name
        limit._concurrency_key = (
            f"{scope}:concurrency:{self.argument_name}:{argument_value}"
        )
        return limit

    @property
    def concurrency_key(self) -> str | None:
        """Redis key used for tracking concurrency for this specific argument value.
        Returns None when concurrency control is bypassed due to missing arguments.
        Raises RuntimeError if accessed before initialization."""
        if not self._initialized:
            raise RuntimeError(
                "ConcurrencyLimit not initialized - use within task context"
            )
        return self._concurrency_key

    @property
    def is_bypassed(self) -> bool:
        """Returns True if concurrency control is bypassed due to missing arguments."""
        return self._initialized and self._concurrency_key is None
715
+
716
+
717
# The concrete Dependency subclass looked up by the get_single_dependency_* helpers.
D = TypeVar("D", bound=Dependency)
718
+
719
+
720
def get_single_dependency_parameter_of_type(
    function: TaskFunction, dependency_type: type[D]
) -> D | None:
    """Return the first dependency default declared on `function` that is an
    instance of `dependency_type` (subclasses included), or None if there is none.
    """
    assert dependency_type.single, "Dependency must be single"
    matches = (
        candidate
        for candidate in get_dependency_parameters(function).values()
        if isinstance(candidate, dependency_type)
    )
    return next(matches, None)
728
+
729
+
730
def get_single_dependency_of_type(
    dependencies: dict[str, Dependency], dependency_type: type[D]
) -> D | None:
    """Return the first value in `dependencies` that is an instance of
    `dependency_type` (subclasses included), or None if there is none.
    """
    assert dependency_type.single, "Dependency must be single"
    matches = (
        candidate
        for candidate in dependencies.values()
        if isinstance(candidate, dependency_type)
    )
    return next(matches, None)
738
+
739
+
740
def validate_dependencies(function: TaskFunction) -> None:
    """Check that `function` declares each `single` dependency at most once.

    Besides exact duplicates (e.g. two `Retry` defaults), this also rejects two
    different `single` types from the same class family, such as `Retry` together
    with its subclass `ExponentialRetry`. Lookups like `get_single_dependency_of_type`
    match with `isinstance`, so such a task would otherwise silently use whichever
    parameter happens to come first.

    Raises:
        ValueError: if a single dependency type, or family, appears more than once.
    """
    parameters = get_dependency_parameters(function)

    counts = Counter(type(dependency) for dependency in parameters.values())

    for dependency_type, count in counts.items():
        if dependency_type.single and count > 1:
            raise ValueError(
                f"Only one {dependency_type.__name__} dependency is allowed per task"
            )

    def single_family(dependency_type: type[Dependency]) -> type[Dependency]:
        # The family is named by the most basic `single` Dependency class found
        # further up the MRO (or the type itself if there is none).
        family = dependency_type
        for ancestor in dependency_type.__mro__[1:]:
            if issubclass(ancestor, Dependency) and ancestor.single:
                family = ancestor
        return family

    family_counts = Counter(
        single_family(dependency_type)
        for dependency_type in counts
        if dependency_type.single
    )
    for family, count in family_counts.items():
        if count > 1:
            raise ValueError(
                f"Only one {family.__name__} dependency is allowed per task"
            )
750
+
751
+
752
class FailedDependency:
    """Placeholder for a dependency that could not be resolved.

    `resolved_dependencies` puts one of these in place of the value when resolving
    a dependency raised, keeping the parameter name and the original exception.
    """

    def __init__(self, parameter: str, error: Exception) -> None:
        self.parameter = parameter
        self.error = error

    def __repr__(self) -> str:
        # Readable in logs and tracebacks, unlike the default object repr.
        return f"{type(self).__name__}(parameter={self.parameter!r}, error={self.error!r})"
756
+
757
+
758
@asynccontextmanager
async def resolved_dependencies(
    worker: "Worker", execution: Execution
) -> AsyncGenerator[dict[str, Any], None]:
    """Resolve every dependency parameter of ``execution.function``.

    Yields a dict mapping parameter names to resolved values. Parameters passed
    explicitly in ``execution.kwargs`` are used as-is. A top-level
    ``TaskArgument()`` with no parameter name, or a dependency whose resolution
    raises an ``Exception``, is represented by a ``FailedDependency`` instead of
    propagating the error.

    Context managers entered during resolution stay open until the caller's
    ``async with`` block exits. The docket/worker/execution context variables and
    the ``_Depends`` cache are set for that same span and restored afterwards.
    """
    # Capture tokens for all contextvar sets to ensure proper cleanup
    docket_token = Dependency.docket.set(worker.docket)
    worker_token = Dependency.worker.set(worker)
    execution_token = Dependency.execution.set(execution)
    # Fresh cache per resolution, so Depends() results are shared only within
    # this one execution.
    cache_token = _Depends.cache.set({})

    try:
        async with AsyncExitStack() as stack:
            stack_token = _Depends.stack.set(stack)
            try:
                arguments: dict[str, Any] = {}

                parameters = get_dependency_parameters(execution.function)
                for parameter, dependency in parameters.items():
                    # An explicitly passed keyword argument overrides the dependency.
                    kwargs = execution.kwargs
                    if parameter in kwargs:
                        arguments[parameter] = kwargs[parameter]
                        continue

                    # Special case for TaskArguments, they are "magical" and infer the parameter
                    # they refer to from the parameter name (unless otherwise specified). At
                    # the top-level task function call, it doesn't make sense to specify one
                    # _without_ a parameter name, so we'll call that a failed dependency.
                    if (
                        isinstance(dependency, _TaskArgument)
                        and not dependency.parameter
                    ):
                        arguments[parameter] = FailedDependency(
                            parameter, ValueError("No parameter name specified")
                        )
                        continue

                    # Errors are recorded rather than raised so the caller can
                    # decide how to report them. Only Exception subclasses are
                    # caught; BaseException (e.g. cancellation) still propagates.
                    try:
                        arguments[parameter] = await stack.enter_async_context(
                            dependency
                        )
                    except Exception as error:
                        arguments[parameter] = FailedDependency(parameter, error)

                yield arguments
            finally:
                # Reset before the exit stack unwinds, i.e. before entered
                # dependencies' __aexit__ callbacks run.
                _Depends.stack.reset(stack_token)
    finally:
        _Depends.cache.reset(cache_token)
        Dependency.execution.reset(execution_token)
        Dependency.worker.reset(worker_token)
        Dependency.docket.reset(docket_token)
+ Dependency.docket.reset(docket_token)