pydocket 0.5.2__tar.gz → 0.6.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydocket might be problematic. Click here for more details.
- {pydocket-0.5.2 → pydocket-0.6.1}/PKG-INFO +1 -1
- {pydocket-0.5.2 → pydocket-0.6.1}/chaos/tasks.py +11 -1
- {pydocket-0.5.2 → pydocket-0.6.1}/examples/find_and_flood.py +1 -1
- {pydocket-0.5.2 → pydocket-0.6.1}/pyproject.toml +1 -1
- {pydocket-0.5.2 → pydocket-0.6.1}/src/docket/__init__.py +4 -0
- {pydocket-0.5.2 → pydocket-0.6.1}/src/docket/annotations.py +23 -0
- pydocket-0.6.1/src/docket/dependencies.py +366 -0
- {pydocket-0.5.2 → pydocket-0.6.1}/src/docket/docket.py +5 -4
- {pydocket-0.5.2 → pydocket-0.6.1}/src/docket/execution.py +35 -12
- {pydocket-0.5.2 → pydocket-0.6.1}/src/docket/worker.py +264 -226
- {pydocket-0.5.2 → pydocket-0.6.1}/tests/conftest.py +2 -2
- {pydocket-0.5.2 → pydocket-0.6.1}/tests/test_fundamentals.py +320 -2
- {pydocket-0.5.2 → pydocket-0.6.1}/tests/test_worker.py +69 -15
- {pydocket-0.5.2 → pydocket-0.6.1}/uv.lock +4 -4
- pydocket-0.5.2/src/docket/dependencies.py +0 -222
- {pydocket-0.5.2 → pydocket-0.6.1}/.cursor/rules/general.mdc +0 -0
- {pydocket-0.5.2 → pydocket-0.6.1}/.cursor/rules/python-style.mdc +0 -0
- {pydocket-0.5.2 → pydocket-0.6.1}/.github/codecov.yml +0 -0
- {pydocket-0.5.2 → pydocket-0.6.1}/.github/workflows/chaos.yml +0 -0
- {pydocket-0.5.2 → pydocket-0.6.1}/.github/workflows/ci.yml +0 -0
- {pydocket-0.5.2 → pydocket-0.6.1}/.github/workflows/publish.yml +0 -0
- {pydocket-0.5.2 → pydocket-0.6.1}/.gitignore +0 -0
- {pydocket-0.5.2 → pydocket-0.6.1}/.pre-commit-config.yaml +0 -0
- {pydocket-0.5.2 → pydocket-0.6.1}/LICENSE +0 -0
- {pydocket-0.5.2 → pydocket-0.6.1}/README.md +0 -0
- {pydocket-0.5.2 → pydocket-0.6.1}/chaos/README.md +0 -0
- {pydocket-0.5.2 → pydocket-0.6.1}/chaos/__init__.py +0 -0
- {pydocket-0.5.2 → pydocket-0.6.1}/chaos/driver.py +0 -0
- {pydocket-0.5.2 → pydocket-0.6.1}/chaos/producer.py +0 -0
- {pydocket-0.5.2 → pydocket-0.6.1}/chaos/run +0 -0
- {pydocket-0.5.2 → pydocket-0.6.1}/examples/__init__.py +0 -0
- {pydocket-0.5.2 → pydocket-0.6.1}/examples/common.py +0 -0
- {pydocket-0.5.2 → pydocket-0.6.1}/src/docket/__main__.py +0 -0
- {pydocket-0.5.2 → pydocket-0.6.1}/src/docket/cli.py +0 -0
- {pydocket-0.5.2 → pydocket-0.6.1}/src/docket/instrumentation.py +0 -0
- {pydocket-0.5.2 → pydocket-0.6.1}/src/docket/py.typed +0 -0
- {pydocket-0.5.2 → pydocket-0.6.1}/src/docket/tasks.py +0 -0
- {pydocket-0.5.2 → pydocket-0.6.1}/telemetry/.gitignore +0 -0
- {pydocket-0.5.2 → pydocket-0.6.1}/telemetry/start +0 -0
- {pydocket-0.5.2 → pydocket-0.6.1}/telemetry/stop +0 -0
- {pydocket-0.5.2 → pydocket-0.6.1}/tests/__init__.py +0 -0
- {pydocket-0.5.2 → pydocket-0.6.1}/tests/cli/__init__.py +0 -0
- {pydocket-0.5.2 → pydocket-0.6.1}/tests/cli/conftest.py +0 -0
- {pydocket-0.5.2 → pydocket-0.6.1}/tests/cli/test_module.py +0 -0
- {pydocket-0.5.2 → pydocket-0.6.1}/tests/cli/test_parsing.py +0 -0
- {pydocket-0.5.2 → pydocket-0.6.1}/tests/cli/test_snapshot.py +0 -0
- {pydocket-0.5.2 → pydocket-0.6.1}/tests/cli/test_striking.py +0 -0
- {pydocket-0.5.2 → pydocket-0.6.1}/tests/cli/test_tasks.py +0 -0
- {pydocket-0.5.2 → pydocket-0.6.1}/tests/cli/test_version.py +0 -0
- {pydocket-0.5.2 → pydocket-0.6.1}/tests/cli/test_worker.py +0 -0
- {pydocket-0.5.2 → pydocket-0.6.1}/tests/cli/test_workers.py +0 -0
- {pydocket-0.5.2 → pydocket-0.6.1}/tests/test_dependencies.py +0 -0
- {pydocket-0.5.2 → pydocket-0.6.1}/tests/test_docket.py +0 -0
- {pydocket-0.5.2 → pydocket-0.6.1}/tests/test_instrumentation.py +0 -0
- {pydocket-0.5.2 → pydocket-0.6.1}/tests/test_striking.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: pydocket
|
|
3
|
-
Version: 0.5.2
|
|
3
|
+
Version: 0.6.1
|
|
4
4
|
Summary: A distributed background task system for Python functions
|
|
5
5
|
Project-URL: Homepage, https://github.com/chrisguidry/docket
|
|
6
6
|
Project-URL: Bug Tracker, https://github.com/chrisguidry/docket/issues
|
|
@@ -2,17 +2,27 @@ import logging
|
|
|
2
2
|
import sys
|
|
3
3
|
import time
|
|
4
4
|
|
|
5
|
-
from docket import CurrentDocket, Docket, Retry, TaskKey
|
|
5
|
+
from docket import CurrentDocket, Depends, Docket, Retry, TaskKey
|
|
6
6
|
|
|
7
7
|
logger = logging.getLogger(__name__)
|
|
8
8
|
|
|
9
9
|
|
|
10
|
+
async def greeting() -> str:
|
|
11
|
+
return "Hello, world"
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
async def emphatic_greeting(greeting: str = Depends(greeting)) -> str:
|
|
15
|
+
return greeting + "!"
|
|
16
|
+
|
|
17
|
+
|
|
10
18
|
async def hello(
|
|
19
|
+
greeting: str = Depends(emphatic_greeting),
|
|
11
20
|
key: str = TaskKey(),
|
|
12
21
|
docket: Docket = CurrentDocket(),
|
|
13
22
|
retry: Retry = Retry(attempts=sys.maxsize),
|
|
14
23
|
):
|
|
15
24
|
logger.info("Starting task %s", key)
|
|
25
|
+
logger.info("Greeting: %s", greeting)
|
|
16
26
|
async with docket.redis() as redis:
|
|
17
27
|
await redis.zadd("hello:received", {key: time.time()})
|
|
18
28
|
logger.info("Finished task %s", key)
|
|
@@ -13,11 +13,13 @@ from .dependencies import (
|
|
|
13
13
|
CurrentDocket,
|
|
14
14
|
CurrentExecution,
|
|
15
15
|
CurrentWorker,
|
|
16
|
+
Depends,
|
|
16
17
|
ExponentialRetry,
|
|
17
18
|
Perpetual,
|
|
18
19
|
Retry,
|
|
19
20
|
TaskKey,
|
|
20
21
|
TaskLogger,
|
|
22
|
+
Timeout,
|
|
21
23
|
)
|
|
22
24
|
from .docket import Docket
|
|
23
25
|
from .execution import Execution
|
|
@@ -36,5 +38,7 @@ __all__ = [
|
|
|
36
38
|
"ExponentialRetry",
|
|
37
39
|
"Logged",
|
|
38
40
|
"Perpetual",
|
|
41
|
+
"Timeout",
|
|
42
|
+
"Depends",
|
|
39
43
|
"__version__",
|
|
40
44
|
]
|
|
@@ -4,8 +4,14 @@ from typing import Any, Iterable, Mapping, Self
|
|
|
4
4
|
|
|
5
5
|
|
|
6
6
|
class Annotation(abc.ABC):
|
|
7
|
+
_cache: dict[tuple[type[Self], inspect.Signature], Mapping[str, Self]] = {}
|
|
8
|
+
|
|
7
9
|
@classmethod
|
|
8
10
|
def annotated_parameters(cls, signature: inspect.Signature) -> Mapping[str, Self]:
|
|
11
|
+
key = (cls, signature)
|
|
12
|
+
if key in cls._cache:
|
|
13
|
+
return cls._cache[key]
|
|
14
|
+
|
|
9
15
|
annotated: dict[str, Self] = {}
|
|
10
16
|
|
|
11
17
|
for param_name, param in signature.parameters.items():
|
|
@@ -23,8 +29,25 @@ class Annotation(abc.ABC):
|
|
|
23
29
|
elif isinstance(arg_type, type) and issubclass(arg_type, cls):
|
|
24
30
|
annotated[param_name] = arg_type()
|
|
25
31
|
|
|
32
|
+
cls._cache[key] = annotated
|
|
26
33
|
return annotated
|
|
27
34
|
|
|
28
35
|
|
|
29
36
|
class Logged(Annotation):
|
|
30
37
|
"""Instructs docket to include arguments to this parameter in the log."""
|
|
38
|
+
|
|
39
|
+
length_only: bool = False
|
|
40
|
+
|
|
41
|
+
def __init__(self, length_only: bool = False) -> None:
|
|
42
|
+
self.length_only = length_only
|
|
43
|
+
|
|
44
|
+
def format(self, argument: Any) -> str:
|
|
45
|
+
if self.length_only:
|
|
46
|
+
if isinstance(argument, (dict, set)):
|
|
47
|
+
return f"{{len {len(argument)}}}"
|
|
48
|
+
elif isinstance(argument, tuple):
|
|
49
|
+
return f"(len {len(argument)})"
|
|
50
|
+
elif hasattr(argument, "__len__"):
|
|
51
|
+
return f"[len {len(argument)}]"
|
|
52
|
+
|
|
53
|
+
return repr(argument)
|
|
@@ -0,0 +1,366 @@
|
|
|
1
|
+
import abc
|
|
2
|
+
import logging
|
|
3
|
+
import time
|
|
4
|
+
from contextlib import AsyncExitStack, asynccontextmanager
|
|
5
|
+
from contextvars import ContextVar
|
|
6
|
+
from datetime import timedelta
|
|
7
|
+
from types import TracebackType
|
|
8
|
+
from typing import (
|
|
9
|
+
TYPE_CHECKING,
|
|
10
|
+
Any,
|
|
11
|
+
AsyncContextManager,
|
|
12
|
+
AsyncGenerator,
|
|
13
|
+
Awaitable,
|
|
14
|
+
Callable,
|
|
15
|
+
Counter,
|
|
16
|
+
Generic,
|
|
17
|
+
TypeVar,
|
|
18
|
+
cast,
|
|
19
|
+
)
|
|
20
|
+
|
|
21
|
+
from .docket import Docket
|
|
22
|
+
from .execution import Execution, TaskFunction, get_signature
|
|
23
|
+
|
|
24
|
+
if TYPE_CHECKING: # pragma: no cover
|
|
25
|
+
from .worker import Worker
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class Dependency(abc.ABC):
|
|
29
|
+
single: bool = False
|
|
30
|
+
|
|
31
|
+
docket: ContextVar[Docket] = ContextVar("docket")
|
|
32
|
+
worker: ContextVar["Worker"] = ContextVar("worker")
|
|
33
|
+
execution: ContextVar[Execution] = ContextVar("execution")
|
|
34
|
+
|
|
35
|
+
@abc.abstractmethod
|
|
36
|
+
async def __aenter__(self) -> Any: ... # pragma: no cover
|
|
37
|
+
|
|
38
|
+
async def __aexit__(
|
|
39
|
+
self,
|
|
40
|
+
exc_type: type[BaseException] | None,
|
|
41
|
+
exc_value: BaseException | None,
|
|
42
|
+
traceback: TracebackType | None,
|
|
43
|
+
) -> bool: ... # pragma: no cover
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class _CurrentWorker(Dependency):
|
|
47
|
+
async def __aenter__(self) -> "Worker":
|
|
48
|
+
return self.worker.get()
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def CurrentWorker() -> "Worker":
|
|
52
|
+
return cast("Worker", _CurrentWorker())
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class _CurrentDocket(Dependency):
|
|
56
|
+
async def __aenter__(self) -> Docket:
|
|
57
|
+
return self.docket.get()
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def CurrentDocket() -> Docket:
|
|
61
|
+
return cast(Docket, _CurrentDocket())
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
class _CurrentExecution(Dependency):
|
|
65
|
+
async def __aenter__(self) -> Execution:
|
|
66
|
+
return self.execution.get()
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def CurrentExecution() -> Execution:
|
|
70
|
+
return cast(Execution, _CurrentExecution())
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
class _TaskKey(Dependency):
|
|
74
|
+
async def __aenter__(self) -> str:
|
|
75
|
+
return self.execution.get().key
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def TaskKey() -> str:
|
|
79
|
+
return cast(str, _TaskKey())
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
class _TaskLogger(Dependency):
|
|
83
|
+
async def __aenter__(self) -> logging.LoggerAdapter[logging.Logger]:
|
|
84
|
+
execution = self.execution.get()
|
|
85
|
+
logger = logging.getLogger(f"docket.task.{execution.function.__name__}")
|
|
86
|
+
return logging.LoggerAdapter(
|
|
87
|
+
logger,
|
|
88
|
+
{
|
|
89
|
+
**self.docket.get().labels(),
|
|
90
|
+
**self.worker.get().labels(),
|
|
91
|
+
**execution.specific_labels(),
|
|
92
|
+
},
|
|
93
|
+
)
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def TaskLogger() -> logging.LoggerAdapter[logging.Logger]:
|
|
97
|
+
return cast(logging.LoggerAdapter[logging.Logger], _TaskLogger())
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
class Retry(Dependency):
|
|
101
|
+
single: bool = True
|
|
102
|
+
|
|
103
|
+
def __init__(
|
|
104
|
+
self, attempts: int | None = 1, delay: timedelta = timedelta(0)
|
|
105
|
+
) -> None:
|
|
106
|
+
self.attempts = attempts
|
|
107
|
+
self.delay = delay
|
|
108
|
+
self.attempt = 1
|
|
109
|
+
|
|
110
|
+
async def __aenter__(self) -> "Retry":
|
|
111
|
+
execution = self.execution.get()
|
|
112
|
+
retry = Retry(attempts=self.attempts, delay=self.delay)
|
|
113
|
+
retry.attempt = execution.attempt
|
|
114
|
+
return retry
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
class ExponentialRetry(Retry):
|
|
118
|
+
attempts: int
|
|
119
|
+
|
|
120
|
+
def __init__(
|
|
121
|
+
self,
|
|
122
|
+
attempts: int = 1,
|
|
123
|
+
minimum_delay: timedelta = timedelta(seconds=1),
|
|
124
|
+
maximum_delay: timedelta = timedelta(seconds=64),
|
|
125
|
+
) -> None:
|
|
126
|
+
super().__init__(attempts=attempts, delay=minimum_delay)
|
|
127
|
+
self.minimum_delay = minimum_delay
|
|
128
|
+
self.maximum_delay = maximum_delay
|
|
129
|
+
|
|
130
|
+
async def __aenter__(self) -> "ExponentialRetry":
|
|
131
|
+
execution = self.execution.get()
|
|
132
|
+
|
|
133
|
+
retry = ExponentialRetry(
|
|
134
|
+
attempts=self.attempts,
|
|
135
|
+
minimum_delay=self.minimum_delay,
|
|
136
|
+
maximum_delay=self.maximum_delay,
|
|
137
|
+
)
|
|
138
|
+
retry.attempt = execution.attempt
|
|
139
|
+
|
|
140
|
+
if execution.attempt > 1:
|
|
141
|
+
backoff_factor = 2 ** (execution.attempt - 1)
|
|
142
|
+
calculated_delay = self.minimum_delay * backoff_factor
|
|
143
|
+
|
|
144
|
+
if calculated_delay > self.maximum_delay:
|
|
145
|
+
retry.delay = self.maximum_delay
|
|
146
|
+
else:
|
|
147
|
+
retry.delay = calculated_delay
|
|
148
|
+
|
|
149
|
+
return retry
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
class Perpetual(Dependency):
|
|
153
|
+
single = True
|
|
154
|
+
|
|
155
|
+
every: timedelta
|
|
156
|
+
automatic: bool
|
|
157
|
+
|
|
158
|
+
args: tuple[Any, ...]
|
|
159
|
+
kwargs: dict[str, Any]
|
|
160
|
+
|
|
161
|
+
cancelled: bool
|
|
162
|
+
|
|
163
|
+
def __init__(
|
|
164
|
+
self,
|
|
165
|
+
every: timedelta = timedelta(0),
|
|
166
|
+
automatic: bool = False,
|
|
167
|
+
) -> None:
|
|
168
|
+
"""Declare a task that should be run perpetually.
|
|
169
|
+
|
|
170
|
+
Args:
|
|
171
|
+
every: The target interval between task executions.
|
|
172
|
+
automatic: If set, this task will be automatically scheduled during worker
|
|
173
|
+
startup and continually through the worker's lifespan. This ensures
|
|
174
|
+
that the task will always be scheduled despite crashes and other
|
|
175
|
+
adverse conditions. Automatic tasks must not require any arguments.
|
|
176
|
+
"""
|
|
177
|
+
self.every = every
|
|
178
|
+
self.automatic = automatic
|
|
179
|
+
self.cancelled = False
|
|
180
|
+
|
|
181
|
+
async def __aenter__(self) -> "Perpetual":
|
|
182
|
+
execution = self.execution.get()
|
|
183
|
+
perpetual = Perpetual(every=self.every)
|
|
184
|
+
perpetual.args = execution.args
|
|
185
|
+
perpetual.kwargs = execution.kwargs
|
|
186
|
+
return perpetual
|
|
187
|
+
|
|
188
|
+
def cancel(self) -> None:
|
|
189
|
+
self.cancelled = True
|
|
190
|
+
|
|
191
|
+
def perpetuate(self, *args: Any, **kwargs: Any) -> None:
|
|
192
|
+
self.args = args
|
|
193
|
+
self.kwargs = kwargs
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
class Timeout(Dependency):
|
|
197
|
+
single = True
|
|
198
|
+
|
|
199
|
+
base: timedelta
|
|
200
|
+
|
|
201
|
+
_deadline: float
|
|
202
|
+
|
|
203
|
+
def __init__(self, base: timedelta) -> None:
|
|
204
|
+
self.base = base
|
|
205
|
+
|
|
206
|
+
async def __aenter__(self) -> "Timeout":
|
|
207
|
+
timeout = Timeout(base=self.base)
|
|
208
|
+
timeout.start()
|
|
209
|
+
return timeout
|
|
210
|
+
|
|
211
|
+
def start(self) -> None:
|
|
212
|
+
self._deadline = time.monotonic() + self.base.total_seconds()
|
|
213
|
+
|
|
214
|
+
def expired(self) -> bool:
|
|
215
|
+
return time.monotonic() >= self._deadline
|
|
216
|
+
|
|
217
|
+
def remaining(self) -> timedelta:
|
|
218
|
+
return timedelta(seconds=self._deadline - time.monotonic())
|
|
219
|
+
|
|
220
|
+
def extend(self, by: timedelta | None = None) -> None:
|
|
221
|
+
if by is None:
|
|
222
|
+
by = self.base
|
|
223
|
+
self._deadline += by.total_seconds()
|
|
224
|
+
|
|
225
|
+
|
|
226
|
+
R = TypeVar("R")
|
|
227
|
+
|
|
228
|
+
DependencyFunction = Callable[..., Awaitable[R] | AsyncContextManager[R]]
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
_parameter_cache: dict[
|
|
232
|
+
TaskFunction | DependencyFunction[Any],
|
|
233
|
+
dict[str, Dependency],
|
|
234
|
+
] = {}
|
|
235
|
+
|
|
236
|
+
|
|
237
|
+
def get_dependency_parameters(
|
|
238
|
+
function: TaskFunction | DependencyFunction[Any],
|
|
239
|
+
) -> dict[str, Dependency]:
|
|
240
|
+
if function in _parameter_cache:
|
|
241
|
+
return _parameter_cache[function]
|
|
242
|
+
|
|
243
|
+
dependencies: dict[str, Dependency] = {}
|
|
244
|
+
|
|
245
|
+
signature = get_signature(function)
|
|
246
|
+
|
|
247
|
+
for parameter, param in signature.parameters.items():
|
|
248
|
+
if not isinstance(param.default, Dependency):
|
|
249
|
+
continue
|
|
250
|
+
|
|
251
|
+
dependencies[parameter] = param.default
|
|
252
|
+
|
|
253
|
+
_parameter_cache[function] = dependencies
|
|
254
|
+
return dependencies
|
|
255
|
+
|
|
256
|
+
|
|
257
|
+
class _Depends(Dependency, Generic[R]):
|
|
258
|
+
dependency: DependencyFunction[R]
|
|
259
|
+
|
|
260
|
+
cache: ContextVar[dict[DependencyFunction[Any], Any]] = ContextVar("cache")
|
|
261
|
+
stack: ContextVar[AsyncExitStack] = ContextVar("stack")
|
|
262
|
+
|
|
263
|
+
def __init__(
|
|
264
|
+
self, dependency: Callable[[], Awaitable[R] | AsyncContextManager[R]]
|
|
265
|
+
) -> None:
|
|
266
|
+
self.dependency = dependency
|
|
267
|
+
|
|
268
|
+
async def _resolve_parameters(
|
|
269
|
+
self,
|
|
270
|
+
function: TaskFunction | DependencyFunction[Any],
|
|
271
|
+
) -> dict[str, Any]:
|
|
272
|
+
stack = self.stack.get()
|
|
273
|
+
|
|
274
|
+
arguments: dict[str, Any] = {}
|
|
275
|
+
parameters = get_dependency_parameters(function)
|
|
276
|
+
|
|
277
|
+
for parameter, dependency in parameters.items():
|
|
278
|
+
arguments[parameter] = await stack.enter_async_context(dependency)
|
|
279
|
+
|
|
280
|
+
return arguments
|
|
281
|
+
|
|
282
|
+
async def __aenter__(self) -> R:
|
|
283
|
+
cache = self.cache.get()
|
|
284
|
+
|
|
285
|
+
if self.dependency in cache:
|
|
286
|
+
return cache[self.dependency]
|
|
287
|
+
|
|
288
|
+
stack = self.stack.get()
|
|
289
|
+
arguments = await self._resolve_parameters(self.dependency)
|
|
290
|
+
|
|
291
|
+
value = self.dependency(**arguments)
|
|
292
|
+
|
|
293
|
+
if isinstance(value, AsyncContextManager):
|
|
294
|
+
value = await stack.enter_async_context(value)
|
|
295
|
+
else:
|
|
296
|
+
value = await value
|
|
297
|
+
|
|
298
|
+
cache[self.dependency] = value
|
|
299
|
+
return value
|
|
300
|
+
|
|
301
|
+
|
|
302
|
+
def Depends(dependency: DependencyFunction[R]) -> R:
|
|
303
|
+
return cast(R, _Depends(dependency))
|
|
304
|
+
|
|
305
|
+
|
|
306
|
+
D = TypeVar("D", bound=Dependency)
|
|
307
|
+
|
|
308
|
+
|
|
309
|
+
def get_single_dependency_parameter_of_type(
|
|
310
|
+
function: TaskFunction, dependency_type: type[D]
|
|
311
|
+
) -> D | None:
|
|
312
|
+
assert dependency_type.single, "Dependency must be single"
|
|
313
|
+
for _, dependency in get_dependency_parameters(function).items():
|
|
314
|
+
if isinstance(dependency, dependency_type):
|
|
315
|
+
return dependency
|
|
316
|
+
return None
|
|
317
|
+
|
|
318
|
+
|
|
319
|
+
def get_single_dependency_of_type(
|
|
320
|
+
dependencies: dict[str, Dependency], dependency_type: type[D]
|
|
321
|
+
) -> D | None:
|
|
322
|
+
assert dependency_type.single, "Dependency must be single"
|
|
323
|
+
for _, dependency in dependencies.items():
|
|
324
|
+
if isinstance(dependency, dependency_type):
|
|
325
|
+
return dependency
|
|
326
|
+
return None
|
|
327
|
+
|
|
328
|
+
|
|
329
|
+
def validate_dependencies(function: TaskFunction) -> None:
|
|
330
|
+
parameters = get_dependency_parameters(function)
|
|
331
|
+
|
|
332
|
+
counts = Counter(type(dependency) for dependency in parameters.values())
|
|
333
|
+
|
|
334
|
+
for dependency_type, count in counts.items():
|
|
335
|
+
if dependency_type.single and count > 1:
|
|
336
|
+
raise ValueError(
|
|
337
|
+
f"Only one {dependency_type.__name__} dependency is allowed per task"
|
|
338
|
+
)
|
|
339
|
+
|
|
340
|
+
|
|
341
|
+
@asynccontextmanager
|
|
342
|
+
async def resolved_dependencies(
|
|
343
|
+
worker: "Worker", execution: Execution
|
|
344
|
+
) -> AsyncGenerator[dict[str, Any], None]:
|
|
345
|
+
# Set context variables once at the beginning
|
|
346
|
+
Dependency.docket.set(worker.docket)
|
|
347
|
+
Dependency.worker.set(worker)
|
|
348
|
+
Dependency.execution.set(execution)
|
|
349
|
+
|
|
350
|
+
_Depends.cache.set({})
|
|
351
|
+
|
|
352
|
+
async with AsyncExitStack() as stack:
|
|
353
|
+
_Depends.stack.set(stack)
|
|
354
|
+
|
|
355
|
+
arguments: dict[str, Any] = {}
|
|
356
|
+
|
|
357
|
+
parameters = get_dependency_parameters(execution.function)
|
|
358
|
+
for parameter, dependency in parameters.items():
|
|
359
|
+
kwargs = execution.kwargs
|
|
360
|
+
if parameter in kwargs:
|
|
361
|
+
arguments[parameter] = kwargs[parameter]
|
|
362
|
+
continue
|
|
363
|
+
|
|
364
|
+
arguments[parameter] = await stack.enter_async_context(dependency)
|
|
365
|
+
|
|
366
|
+
yield arguments
|
|
@@ -38,6 +38,7 @@ from .execution import (
|
|
|
38
38
|
Strike,
|
|
39
39
|
StrikeInstruction,
|
|
40
40
|
StrikeList,
|
|
41
|
+
TaskFunction,
|
|
41
42
|
)
|
|
42
43
|
from .instrumentation import (
|
|
43
44
|
REDIS_DISRUPTIONS,
|
|
@@ -57,7 +58,7 @@ tracer: trace.Tracer = trace.get_tracer(__name__)
|
|
|
57
58
|
P = ParamSpec("P")
|
|
58
59
|
R = TypeVar("R")
|
|
59
60
|
|
|
60
|
-
TaskCollection = Iterable[Callable[..., Awaitable[Any]]]
|
|
61
|
+
TaskCollection = Iterable[TaskFunction]
|
|
61
62
|
|
|
62
63
|
RedisStreamID = bytes
|
|
63
64
|
RedisMessageID = bytes
|
|
@@ -91,7 +92,7 @@ class RunningExecution(Execution):
|
|
|
91
92
|
worker: str,
|
|
92
93
|
started: datetime,
|
|
93
94
|
) -> None:
|
|
94
|
-
self.function: Callable[..., Awaitable[Any]] = execution.function
|
|
95
|
+
self.function: TaskFunction = execution.function
|
|
95
96
|
self.args: tuple[Any, ...] = execution.args
|
|
96
97
|
self.kwargs: dict[str, Any] = execution.kwargs
|
|
97
98
|
self.when: datetime = execution.when
|
|
@@ -111,7 +112,7 @@ class DocketSnapshot:
|
|
|
111
112
|
|
|
112
113
|
|
|
113
114
|
class Docket:
|
|
114
|
-
tasks: dict[str, Callable[..., Awaitable[Any]]]
|
|
115
|
+
tasks: dict[str, TaskFunction]
|
|
115
116
|
strike_list: StrikeList
|
|
116
117
|
|
|
117
118
|
_monitor_strikes_task: asyncio.Task[None]
|
|
@@ -197,7 +198,7 @@ class Docket:
|
|
|
197
198
|
finally:
|
|
198
199
|
await asyncio.shield(r.__aexit__(None, None, None))
|
|
199
200
|
|
|
200
|
-
def register(self, function: Callable[..., Awaitable[Any]]) -> None:
|
|
201
|
+
def register(self, function: TaskFunction) -> None:
|
|
201
202
|
from .dependencies import validate_dependencies
|
|
202
203
|
|
|
203
204
|
validate_dependencies(function)
|
|
@@ -7,23 +7,40 @@ from typing import Any, Awaitable, Callable, Hashable, Literal, Mapping, Self, c
|
|
|
7
7
|
|
|
8
8
|
import cloudpickle # type: ignore[import]
|
|
9
9
|
|
|
10
|
+
from opentelemetry import trace, propagate
|
|
11
|
+
import opentelemetry.context
|
|
10
12
|
|
|
11
13
|
from .annotations import Logged
|
|
14
|
+
from docket.instrumentation import message_getter
|
|
12
15
|
|
|
13
16
|
logger: logging.Logger = logging.getLogger(__name__)
|
|
14
17
|
|
|
18
|
+
TaskFunction = Callable[..., Awaitable[Any]]
|
|
15
19
|
Message = dict[bytes, bytes]
|
|
16
20
|
|
|
17
21
|
|
|
22
|
+
_signature_cache: dict[Callable[..., Any], inspect.Signature] = {}
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def get_signature(function: Callable[..., Any]) -> inspect.Signature:
|
|
26
|
+
if function in _signature_cache:
|
|
27
|
+
return _signature_cache[function]
|
|
28
|
+
|
|
29
|
+
signature = inspect.signature(function)
|
|
30
|
+
_signature_cache[function] = signature
|
|
31
|
+
return signature
|
|
32
|
+
|
|
33
|
+
|
|
18
34
|
class Execution:
|
|
19
35
|
def __init__(
|
|
20
36
|
self,
|
|
21
|
-
function: Callable[..., Awaitable[Any]],
|
|
37
|
+
function: TaskFunction,
|
|
22
38
|
args: tuple[Any, ...],
|
|
23
39
|
kwargs: dict[str, Any],
|
|
24
40
|
when: datetime,
|
|
25
41
|
key: str,
|
|
26
42
|
attempt: int,
|
|
43
|
+
trace_context: opentelemetry.context.Context | None = None,
|
|
27
44
|
) -> None:
|
|
28
45
|
self.function = function
|
|
29
46
|
self.args = args
|
|
@@ -31,6 +48,7 @@ class Execution:
|
|
|
31
48
|
self.when = when
|
|
32
49
|
self.key = key
|
|
33
50
|
self.attempt = attempt
|
|
51
|
+
self.trace_context = trace_context
|
|
34
52
|
|
|
35
53
|
def as_message(self) -> Message:
|
|
36
54
|
return {
|
|
@@ -43,9 +61,7 @@ class Execution:
|
|
|
43
61
|
}
|
|
44
62
|
|
|
45
63
|
@classmethod
|
|
46
|
-
def from_message(
|
|
47
|
-
cls, function: Callable[..., Awaitable[Any]], message: Message
|
|
48
|
-
) -> Self:
|
|
64
|
+
def from_message(cls, function: TaskFunction, message: Message) -> Self:
|
|
49
65
|
return cls(
|
|
50
66
|
function=function,
|
|
51
67
|
args=cloudpickle.loads(message[b"args"]),
|
|
@@ -53,6 +69,7 @@ class Execution:
|
|
|
53
69
|
when=datetime.fromisoformat(message[b"when"].decode()),
|
|
54
70
|
key=message[b"key"].decode(),
|
|
55
71
|
attempt=int(message[b"attempt"].decode()),
|
|
72
|
+
trace_context=propagate.extract(message, getter=message_getter),
|
|
56
73
|
)
|
|
57
74
|
|
|
58
75
|
def general_labels(self) -> Mapping[str, str]:
|
|
@@ -68,28 +85,32 @@ class Execution:
|
|
|
68
85
|
|
|
69
86
|
def call_repr(self) -> str:
|
|
70
87
|
arguments: list[str] = []
|
|
71
|
-
signature = inspect.signature(self.function)
|
|
72
88
|
function_name = self.function.__name__
|
|
73
89
|
|
|
90
|
+
signature = get_signature(self.function)
|
|
74
91
|
logged_parameters = Logged.annotated_parameters(signature)
|
|
75
|
-
|
|
76
92
|
parameter_names = list(signature.parameters.keys())
|
|
77
93
|
|
|
78
94
|
for i, argument in enumerate(self.args[: len(parameter_names)]):
|
|
79
95
|
parameter_name = parameter_names[i]
|
|
80
|
-
if parameter_name in logged_parameters:
|
|
81
|
-
arguments.append(repr(argument))
|
|
96
|
+
if logged := logged_parameters.get(parameter_name):
|
|
97
|
+
arguments.append(logged.format(argument))
|
|
82
98
|
else:
|
|
83
99
|
arguments.append("...")
|
|
84
100
|
|
|
85
101
|
for parameter_name, argument in self.kwargs.items():
|
|
86
|
-
if parameter_name in logged_parameters:
|
|
87
|
-
arguments.append(f"{parameter_name}={repr(argument)}")
|
|
102
|
+
if logged := logged_parameters.get(parameter_name):
|
|
103
|
+
arguments.append(f"{parameter_name}={logged.format(argument)}")
|
|
88
104
|
else:
|
|
89
105
|
arguments.append(f"{parameter_name}=...")
|
|
90
106
|
|
|
91
107
|
return f"{function_name}({', '.join(arguments)}){{{self.key}}}"
|
|
92
108
|
|
|
109
|
+
def incoming_span_links(self) -> list[trace.Link]:
|
|
110
|
+
initiating_span = trace.get_current_span(self.trace_context)
|
|
111
|
+
initiating_context = initiating_span.get_span_context()
|
|
112
|
+
return [trace.Link(initiating_context)] if initiating_context.is_valid else []
|
|
113
|
+
|
|
93
114
|
|
|
94
115
|
class Operator(enum.StrEnum):
|
|
95
116
|
EQUAL = "=="
|
|
@@ -217,10 +238,10 @@ class StrikeList:
|
|
|
217
238
|
if function_name in self.task_strikes and not task_strikes:
|
|
218
239
|
return True
|
|
219
240
|
|
|
220
|
-
|
|
241
|
+
signature = get_signature(execution.function)
|
|
221
242
|
|
|
222
243
|
try:
|
|
223
|
-
bound_args =
|
|
244
|
+
bound_args = signature.bind(*execution.args, **execution.kwargs)
|
|
224
245
|
bound_args.apply_defaults()
|
|
225
246
|
except TypeError:
|
|
226
247
|
# If we can't make sense of the arguments, just assume the task is fine
|
|
@@ -265,6 +286,8 @@ class StrikeList:
|
|
|
265
286
|
case "between": # pragma: no branch
|
|
266
287
|
lower, upper = strike_value
|
|
267
288
|
return lower <= value <= upper
|
|
289
|
+
case _: # pragma: no cover
|
|
290
|
+
raise ValueError(f"Unknown operator: {operator}")
|
|
268
291
|
except (ValueError, TypeError):
|
|
269
292
|
# If we can't make the comparison due to incompatible types, just log the
|
|
270
293
|
# error and assume the task is not stricken
|