pydocket 0.5.2__tar.gz → 0.6.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pydocket might be problematic. See the package registry's advisory page for more details.

Files changed (55)
  1. {pydocket-0.5.2 → pydocket-0.6.1}/PKG-INFO +1 -1
  2. {pydocket-0.5.2 → pydocket-0.6.1}/chaos/tasks.py +11 -1
  3. {pydocket-0.5.2 → pydocket-0.6.1}/examples/find_and_flood.py +1 -1
  4. {pydocket-0.5.2 → pydocket-0.6.1}/pyproject.toml +1 -1
  5. {pydocket-0.5.2 → pydocket-0.6.1}/src/docket/__init__.py +4 -0
  6. {pydocket-0.5.2 → pydocket-0.6.1}/src/docket/annotations.py +23 -0
  7. pydocket-0.6.1/src/docket/dependencies.py +366 -0
  8. {pydocket-0.5.2 → pydocket-0.6.1}/src/docket/docket.py +5 -4
  9. {pydocket-0.5.2 → pydocket-0.6.1}/src/docket/execution.py +35 -12
  10. {pydocket-0.5.2 → pydocket-0.6.1}/src/docket/worker.py +264 -226
  11. {pydocket-0.5.2 → pydocket-0.6.1}/tests/conftest.py +2 -2
  12. {pydocket-0.5.2 → pydocket-0.6.1}/tests/test_fundamentals.py +320 -2
  13. {pydocket-0.5.2 → pydocket-0.6.1}/tests/test_worker.py +69 -15
  14. {pydocket-0.5.2 → pydocket-0.6.1}/uv.lock +4 -4
  15. pydocket-0.5.2/src/docket/dependencies.py +0 -222
  16. {pydocket-0.5.2 → pydocket-0.6.1}/.cursor/rules/general.mdc +0 -0
  17. {pydocket-0.5.2 → pydocket-0.6.1}/.cursor/rules/python-style.mdc +0 -0
  18. {pydocket-0.5.2 → pydocket-0.6.1}/.github/codecov.yml +0 -0
  19. {pydocket-0.5.2 → pydocket-0.6.1}/.github/workflows/chaos.yml +0 -0
  20. {pydocket-0.5.2 → pydocket-0.6.1}/.github/workflows/ci.yml +0 -0
  21. {pydocket-0.5.2 → pydocket-0.6.1}/.github/workflows/publish.yml +0 -0
  22. {pydocket-0.5.2 → pydocket-0.6.1}/.gitignore +0 -0
  23. {pydocket-0.5.2 → pydocket-0.6.1}/.pre-commit-config.yaml +0 -0
  24. {pydocket-0.5.2 → pydocket-0.6.1}/LICENSE +0 -0
  25. {pydocket-0.5.2 → pydocket-0.6.1}/README.md +0 -0
  26. {pydocket-0.5.2 → pydocket-0.6.1}/chaos/README.md +0 -0
  27. {pydocket-0.5.2 → pydocket-0.6.1}/chaos/__init__.py +0 -0
  28. {pydocket-0.5.2 → pydocket-0.6.1}/chaos/driver.py +0 -0
  29. {pydocket-0.5.2 → pydocket-0.6.1}/chaos/producer.py +0 -0
  30. {pydocket-0.5.2 → pydocket-0.6.1}/chaos/run +0 -0
  31. {pydocket-0.5.2 → pydocket-0.6.1}/examples/__init__.py +0 -0
  32. {pydocket-0.5.2 → pydocket-0.6.1}/examples/common.py +0 -0
  33. {pydocket-0.5.2 → pydocket-0.6.1}/src/docket/__main__.py +0 -0
  34. {pydocket-0.5.2 → pydocket-0.6.1}/src/docket/cli.py +0 -0
  35. {pydocket-0.5.2 → pydocket-0.6.1}/src/docket/instrumentation.py +0 -0
  36. {pydocket-0.5.2 → pydocket-0.6.1}/src/docket/py.typed +0 -0
  37. {pydocket-0.5.2 → pydocket-0.6.1}/src/docket/tasks.py +0 -0
  38. {pydocket-0.5.2 → pydocket-0.6.1}/telemetry/.gitignore +0 -0
  39. {pydocket-0.5.2 → pydocket-0.6.1}/telemetry/start +0 -0
  40. {pydocket-0.5.2 → pydocket-0.6.1}/telemetry/stop +0 -0
  41. {pydocket-0.5.2 → pydocket-0.6.1}/tests/__init__.py +0 -0
  42. {pydocket-0.5.2 → pydocket-0.6.1}/tests/cli/__init__.py +0 -0
  43. {pydocket-0.5.2 → pydocket-0.6.1}/tests/cli/conftest.py +0 -0
  44. {pydocket-0.5.2 → pydocket-0.6.1}/tests/cli/test_module.py +0 -0
  45. {pydocket-0.5.2 → pydocket-0.6.1}/tests/cli/test_parsing.py +0 -0
  46. {pydocket-0.5.2 → pydocket-0.6.1}/tests/cli/test_snapshot.py +0 -0
  47. {pydocket-0.5.2 → pydocket-0.6.1}/tests/cli/test_striking.py +0 -0
  48. {pydocket-0.5.2 → pydocket-0.6.1}/tests/cli/test_tasks.py +0 -0
  49. {pydocket-0.5.2 → pydocket-0.6.1}/tests/cli/test_version.py +0 -0
  50. {pydocket-0.5.2 → pydocket-0.6.1}/tests/cli/test_worker.py +0 -0
  51. {pydocket-0.5.2 → pydocket-0.6.1}/tests/cli/test_workers.py +0 -0
  52. {pydocket-0.5.2 → pydocket-0.6.1}/tests/test_dependencies.py +0 -0
  53. {pydocket-0.5.2 → pydocket-0.6.1}/tests/test_docket.py +0 -0
  54. {pydocket-0.5.2 → pydocket-0.6.1}/tests/test_instrumentation.py +0 -0
  55. {pydocket-0.5.2 → pydocket-0.6.1}/tests/test_striking.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: pydocket
3
- Version: 0.5.2
3
+ Version: 0.6.1
4
4
  Summary: A distributed background task system for Python functions
5
5
  Project-URL: Homepage, https://github.com/chrisguidry/docket
6
6
  Project-URL: Bug Tracker, https://github.com/chrisguidry/docket/issues
@@ -2,17 +2,27 @@ import logging
2
2
  import sys
3
3
  import time
4
4
 
5
- from docket import CurrentDocket, Docket, Retry, TaskKey
5
+ from docket import CurrentDocket, Depends, Docket, Retry, TaskKey
6
6
 
7
7
  logger = logging.getLogger(__name__)
8
8
 
9
9
 
10
+ async def greeting() -> str:
11
+ return "Hello, world"
12
+
13
+
14
+ async def emphatic_greeting(greeting: str = Depends(greeting)) -> str:
15
+ return greeting + "!"
16
+
17
+
10
18
  async def hello(
19
+ greeting: str = Depends(emphatic_greeting),
11
20
  key: str = TaskKey(),
12
21
  docket: Docket = CurrentDocket(),
13
22
  retry: Retry = Retry(attempts=sys.maxsize),
14
23
  ):
15
24
  logger.info("Starting task %s", key)
25
+ logger.info("Greeting: %s", greeting)
16
26
  async with docket.redis() as redis:
17
27
  await redis.zadd("hello:received", {key: time.time()})
18
28
  logger.info("Finished task %s", key)
@@ -16,7 +16,7 @@ async def find(
16
16
  perpetual: Perpetual = Perpetual(every=timedelta(seconds=3), automatic=True),
17
17
  ) -> None:
18
18
  for i in range(1, 10 + 1):
19
- await docket.add(flood, key=str(i))(i)
19
+ await docket.add(flood, key=f"item-{i}")(i)
20
20
 
21
21
 
22
22
  async def flood(
@@ -43,7 +43,7 @@ dev = [
43
43
  "opentelemetry-instrumentation-redis>=0.51b0",
44
44
  "opentelemetry-sdk>=1.30.0",
45
45
  "pre-commit>=4.1.0",
46
- "pyright>=1.1.396",
46
+ "pyright>=1.1.398",
47
47
  "pytest>=8.3.4",
48
48
  "pytest-aio>=1.9.0",
49
49
  "pytest-cov>=6.0.0",
@@ -13,11 +13,13 @@ from .dependencies import (
13
13
  CurrentDocket,
14
14
  CurrentExecution,
15
15
  CurrentWorker,
16
+ Depends,
16
17
  ExponentialRetry,
17
18
  Perpetual,
18
19
  Retry,
19
20
  TaskKey,
20
21
  TaskLogger,
22
+ Timeout,
21
23
  )
22
24
  from .docket import Docket
23
25
  from .execution import Execution
@@ -36,5 +38,7 @@ __all__ = [
36
38
  "ExponentialRetry",
37
39
  "Logged",
38
40
  "Perpetual",
41
+ "Timeout",
42
+ "Depends",
39
43
  "__version__",
40
44
  ]
@@ -4,8 +4,14 @@ from typing import Any, Iterable, Mapping, Self
4
4
 
5
5
 
6
6
  class Annotation(abc.ABC):
7
+ _cache: dict[tuple[type[Self], inspect.Signature], Mapping[str, Self]] = {}
8
+
7
9
  @classmethod
8
10
  def annotated_parameters(cls, signature: inspect.Signature) -> Mapping[str, Self]:
11
+ key = (cls, signature)
12
+ if key in cls._cache:
13
+ return cls._cache[key]
14
+
9
15
  annotated: dict[str, Self] = {}
10
16
 
11
17
  for param_name, param in signature.parameters.items():
@@ -23,8 +29,25 @@ class Annotation(abc.ABC):
23
29
  elif isinstance(arg_type, type) and issubclass(arg_type, cls):
24
30
  annotated[param_name] = arg_type()
25
31
 
32
+ cls._cache[key] = annotated
26
33
  return annotated
27
34
 
28
35
 
29
36
  class Logged(Annotation):
30
37
  """Instructs docket to include arguments to this parameter in the log."""
38
+
39
+ length_only: bool = False
40
+
41
+ def __init__(self, length_only: bool = False) -> None:
42
+ self.length_only = length_only
43
+
44
+ def format(self, argument: Any) -> str:
45
+ if self.length_only:
46
+ if isinstance(argument, (dict, set)):
47
+ return f"{{len {len(argument)}}}"
48
+ elif isinstance(argument, tuple):
49
+ return f"(len {len(argument)})"
50
+ elif hasattr(argument, "__len__"):
51
+ return f"[len {len(argument)}]"
52
+
53
+ return repr(argument)
@@ -0,0 +1,366 @@
1
+ import abc
2
+ import logging
3
+ import time
4
+ from contextlib import AsyncExitStack, asynccontextmanager
5
+ from contextvars import ContextVar
6
+ from datetime import timedelta
7
+ from types import TracebackType
8
+ from typing import (
9
+ TYPE_CHECKING,
10
+ Any,
11
+ AsyncContextManager,
12
+ AsyncGenerator,
13
+ Awaitable,
14
+ Callable,
15
+ Counter,
16
+ Generic,
17
+ TypeVar,
18
+ cast,
19
+ )
20
+
21
+ from .docket import Docket
22
+ from .execution import Execution, TaskFunction, get_signature
23
+
24
+ if TYPE_CHECKING: # pragma: no cover
25
+ from .worker import Worker
26
+
27
+
28
+ class Dependency(abc.ABC):
29
+ single: bool = False
30
+
31
+ docket: ContextVar[Docket] = ContextVar("docket")
32
+ worker: ContextVar["Worker"] = ContextVar("worker")
33
+ execution: ContextVar[Execution] = ContextVar("execution")
34
+
35
+ @abc.abstractmethod
36
+ async def __aenter__(self) -> Any: ... # pragma: no cover
37
+
38
+ async def __aexit__(
39
+ self,
40
+ exc_type: type[BaseException] | None,
41
+ exc_value: BaseException | None,
42
+ traceback: TracebackType | None,
43
+ ) -> bool: ... # pragma: no cover
44
+
45
+
46
+ class _CurrentWorker(Dependency):
47
+ async def __aenter__(self) -> "Worker":
48
+ return self.worker.get()
49
+
50
+
51
+ def CurrentWorker() -> "Worker":
52
+ return cast("Worker", _CurrentWorker())
53
+
54
+
55
+ class _CurrentDocket(Dependency):
56
+ async def __aenter__(self) -> Docket:
57
+ return self.docket.get()
58
+
59
+
60
+ def CurrentDocket() -> Docket:
61
+ return cast(Docket, _CurrentDocket())
62
+
63
+
64
+ class _CurrentExecution(Dependency):
65
+ async def __aenter__(self) -> Execution:
66
+ return self.execution.get()
67
+
68
+
69
+ def CurrentExecution() -> Execution:
70
+ return cast(Execution, _CurrentExecution())
71
+
72
+
73
+ class _TaskKey(Dependency):
74
+ async def __aenter__(self) -> str:
75
+ return self.execution.get().key
76
+
77
+
78
+ def TaskKey() -> str:
79
+ return cast(str, _TaskKey())
80
+
81
+
82
+ class _TaskLogger(Dependency):
83
+ async def __aenter__(self) -> logging.LoggerAdapter[logging.Logger]:
84
+ execution = self.execution.get()
85
+ logger = logging.getLogger(f"docket.task.{execution.function.__name__}")
86
+ return logging.LoggerAdapter(
87
+ logger,
88
+ {
89
+ **self.docket.get().labels(),
90
+ **self.worker.get().labels(),
91
+ **execution.specific_labels(),
92
+ },
93
+ )
94
+
95
+
96
+ def TaskLogger() -> logging.LoggerAdapter[logging.Logger]:
97
+ return cast(logging.LoggerAdapter[logging.Logger], _TaskLogger())
98
+
99
+
100
+ class Retry(Dependency):
101
+ single: bool = True
102
+
103
+ def __init__(
104
+ self, attempts: int | None = 1, delay: timedelta = timedelta(0)
105
+ ) -> None:
106
+ self.attempts = attempts
107
+ self.delay = delay
108
+ self.attempt = 1
109
+
110
+ async def __aenter__(self) -> "Retry":
111
+ execution = self.execution.get()
112
+ retry = Retry(attempts=self.attempts, delay=self.delay)
113
+ retry.attempt = execution.attempt
114
+ return retry
115
+
116
+
117
+ class ExponentialRetry(Retry):
118
+ attempts: int
119
+
120
+ def __init__(
121
+ self,
122
+ attempts: int = 1,
123
+ minimum_delay: timedelta = timedelta(seconds=1),
124
+ maximum_delay: timedelta = timedelta(seconds=64),
125
+ ) -> None:
126
+ super().__init__(attempts=attempts, delay=minimum_delay)
127
+ self.minimum_delay = minimum_delay
128
+ self.maximum_delay = maximum_delay
129
+
130
+ async def __aenter__(self) -> "ExponentialRetry":
131
+ execution = self.execution.get()
132
+
133
+ retry = ExponentialRetry(
134
+ attempts=self.attempts,
135
+ minimum_delay=self.minimum_delay,
136
+ maximum_delay=self.maximum_delay,
137
+ )
138
+ retry.attempt = execution.attempt
139
+
140
+ if execution.attempt > 1:
141
+ backoff_factor = 2 ** (execution.attempt - 1)
142
+ calculated_delay = self.minimum_delay * backoff_factor
143
+
144
+ if calculated_delay > self.maximum_delay:
145
+ retry.delay = self.maximum_delay
146
+ else:
147
+ retry.delay = calculated_delay
148
+
149
+ return retry
150
+
151
+
152
+ class Perpetual(Dependency):
153
+ single = True
154
+
155
+ every: timedelta
156
+ automatic: bool
157
+
158
+ args: tuple[Any, ...]
159
+ kwargs: dict[str, Any]
160
+
161
+ cancelled: bool
162
+
163
+ def __init__(
164
+ self,
165
+ every: timedelta = timedelta(0),
166
+ automatic: bool = False,
167
+ ) -> None:
168
+ """Declare a task that should be run perpetually.
169
+
170
+ Args:
171
+ every: The target interval between task executions.
172
+ automatic: If set, this task will be automatically scheduled during worker
173
+ startup and continually through the worker's lifespan. This ensures
174
+ that the task will always be scheduled despite crashes and other
175
+ adverse conditions. Automatic tasks must not require any arguments.
176
+ """
177
+ self.every = every
178
+ self.automatic = automatic
179
+ self.cancelled = False
180
+
181
+ async def __aenter__(self) -> "Perpetual":
182
+ execution = self.execution.get()
183
+ perpetual = Perpetual(every=self.every)
184
+ perpetual.args = execution.args
185
+ perpetual.kwargs = execution.kwargs
186
+ return perpetual
187
+
188
+ def cancel(self) -> None:
189
+ self.cancelled = True
190
+
191
+ def perpetuate(self, *args: Any, **kwargs: Any) -> None:
192
+ self.args = args
193
+ self.kwargs = kwargs
194
+
195
+
196
+ class Timeout(Dependency):
197
+ single = True
198
+
199
+ base: timedelta
200
+
201
+ _deadline: float
202
+
203
+ def __init__(self, base: timedelta) -> None:
204
+ self.base = base
205
+
206
+ async def __aenter__(self) -> "Timeout":
207
+ timeout = Timeout(base=self.base)
208
+ timeout.start()
209
+ return timeout
210
+
211
+ def start(self) -> None:
212
+ self._deadline = time.monotonic() + self.base.total_seconds()
213
+
214
+ def expired(self) -> bool:
215
+ return time.monotonic() >= self._deadline
216
+
217
+ def remaining(self) -> timedelta:
218
+ return timedelta(seconds=self._deadline - time.monotonic())
219
+
220
+ def extend(self, by: timedelta | None = None) -> None:
221
+ if by is None:
222
+ by = self.base
223
+ self._deadline += by.total_seconds()
224
+
225
+
226
+ R = TypeVar("R")
227
+
228
+ DependencyFunction = Callable[..., Awaitable[R] | AsyncContextManager[R]]
229
+
230
+
231
+ _parameter_cache: dict[
232
+ TaskFunction | DependencyFunction[Any],
233
+ dict[str, Dependency],
234
+ ] = {}
235
+
236
+
237
+ def get_dependency_parameters(
238
+ function: TaskFunction | DependencyFunction[Any],
239
+ ) -> dict[str, Dependency]:
240
+ if function in _parameter_cache:
241
+ return _parameter_cache[function]
242
+
243
+ dependencies: dict[str, Dependency] = {}
244
+
245
+ signature = get_signature(function)
246
+
247
+ for parameter, param in signature.parameters.items():
248
+ if not isinstance(param.default, Dependency):
249
+ continue
250
+
251
+ dependencies[parameter] = param.default
252
+
253
+ _parameter_cache[function] = dependencies
254
+ return dependencies
255
+
256
+
257
+ class _Depends(Dependency, Generic[R]):
258
+ dependency: DependencyFunction[R]
259
+
260
+ cache: ContextVar[dict[DependencyFunction[Any], Any]] = ContextVar("cache")
261
+ stack: ContextVar[AsyncExitStack] = ContextVar("stack")
262
+
263
+ def __init__(
264
+ self, dependency: Callable[[], Awaitable[R] | AsyncContextManager[R]]
265
+ ) -> None:
266
+ self.dependency = dependency
267
+
268
+ async def _resolve_parameters(
269
+ self,
270
+ function: TaskFunction | DependencyFunction[Any],
271
+ ) -> dict[str, Any]:
272
+ stack = self.stack.get()
273
+
274
+ arguments: dict[str, Any] = {}
275
+ parameters = get_dependency_parameters(function)
276
+
277
+ for parameter, dependency in parameters.items():
278
+ arguments[parameter] = await stack.enter_async_context(dependency)
279
+
280
+ return arguments
281
+
282
+ async def __aenter__(self) -> R:
283
+ cache = self.cache.get()
284
+
285
+ if self.dependency in cache:
286
+ return cache[self.dependency]
287
+
288
+ stack = self.stack.get()
289
+ arguments = await self._resolve_parameters(self.dependency)
290
+
291
+ value = self.dependency(**arguments)
292
+
293
+ if isinstance(value, AsyncContextManager):
294
+ value = await stack.enter_async_context(value)
295
+ else:
296
+ value = await value
297
+
298
+ cache[self.dependency] = value
299
+ return value
300
+
301
+
302
+ def Depends(dependency: DependencyFunction[R]) -> R:
303
+ return cast(R, _Depends(dependency))
304
+
305
+
306
+ D = TypeVar("D", bound=Dependency)
307
+
308
+
309
+ def get_single_dependency_parameter_of_type(
310
+ function: TaskFunction, dependency_type: type[D]
311
+ ) -> D | None:
312
+ assert dependency_type.single, "Dependency must be single"
313
+ for _, dependency in get_dependency_parameters(function).items():
314
+ if isinstance(dependency, dependency_type):
315
+ return dependency
316
+ return None
317
+
318
+
319
+ def get_single_dependency_of_type(
320
+ dependencies: dict[str, Dependency], dependency_type: type[D]
321
+ ) -> D | None:
322
+ assert dependency_type.single, "Dependency must be single"
323
+ for _, dependency in dependencies.items():
324
+ if isinstance(dependency, dependency_type):
325
+ return dependency
326
+ return None
327
+
328
+
329
+ def validate_dependencies(function: TaskFunction) -> None:
330
+ parameters = get_dependency_parameters(function)
331
+
332
+ counts = Counter(type(dependency) for dependency in parameters.values())
333
+
334
+ for dependency_type, count in counts.items():
335
+ if dependency_type.single and count > 1:
336
+ raise ValueError(
337
+ f"Only one {dependency_type.__name__} dependency is allowed per task"
338
+ )
339
+
340
+
341
+ @asynccontextmanager
342
+ async def resolved_dependencies(
343
+ worker: "Worker", execution: Execution
344
+ ) -> AsyncGenerator[dict[str, Any], None]:
345
+ # Set context variables once at the beginning
346
+ Dependency.docket.set(worker.docket)
347
+ Dependency.worker.set(worker)
348
+ Dependency.execution.set(execution)
349
+
350
+ _Depends.cache.set({})
351
+
352
+ async with AsyncExitStack() as stack:
353
+ _Depends.stack.set(stack)
354
+
355
+ arguments: dict[str, Any] = {}
356
+
357
+ parameters = get_dependency_parameters(execution.function)
358
+ for parameter, dependency in parameters.items():
359
+ kwargs = execution.kwargs
360
+ if parameter in kwargs:
361
+ arguments[parameter] = kwargs[parameter]
362
+ continue
363
+
364
+ arguments[parameter] = await stack.enter_async_context(dependency)
365
+
366
+ yield arguments
@@ -38,6 +38,7 @@ from .execution import (
38
38
  Strike,
39
39
  StrikeInstruction,
40
40
  StrikeList,
41
+ TaskFunction,
41
42
  )
42
43
  from .instrumentation import (
43
44
  REDIS_DISRUPTIONS,
@@ -57,7 +58,7 @@ tracer: trace.Tracer = trace.get_tracer(__name__)
57
58
  P = ParamSpec("P")
58
59
  R = TypeVar("R")
59
60
 
60
- TaskCollection = Iterable[Callable[..., Awaitable[Any]]]
61
+ TaskCollection = Iterable[TaskFunction]
61
62
 
62
63
  RedisStreamID = bytes
63
64
  RedisMessageID = bytes
@@ -91,7 +92,7 @@ class RunningExecution(Execution):
91
92
  worker: str,
92
93
  started: datetime,
93
94
  ) -> None:
94
- self.function: Callable[..., Awaitable[Any]] = execution.function
95
+ self.function: TaskFunction = execution.function
95
96
  self.args: tuple[Any, ...] = execution.args
96
97
  self.kwargs: dict[str, Any] = execution.kwargs
97
98
  self.when: datetime = execution.when
@@ -111,7 +112,7 @@ class DocketSnapshot:
111
112
 
112
113
 
113
114
  class Docket:
114
- tasks: dict[str, Callable[..., Awaitable[Any]]]
115
+ tasks: dict[str, TaskFunction]
115
116
  strike_list: StrikeList
116
117
 
117
118
  _monitor_strikes_task: asyncio.Task[None]
@@ -197,7 +198,7 @@ class Docket:
197
198
  finally:
198
199
  await asyncio.shield(r.__aexit__(None, None, None))
199
200
 
200
- def register(self, function: Callable[..., Awaitable[Any]]) -> None:
201
+ def register(self, function: TaskFunction) -> None:
201
202
  from .dependencies import validate_dependencies
202
203
 
203
204
  validate_dependencies(function)
@@ -7,23 +7,40 @@ from typing import Any, Awaitable, Callable, Hashable, Literal, Mapping, Self, c
7
7
 
8
8
  import cloudpickle # type: ignore[import]
9
9
 
10
+ from opentelemetry import trace, propagate
11
+ import opentelemetry.context
10
12
 
11
13
  from .annotations import Logged
14
+ from docket.instrumentation import message_getter
12
15
 
13
16
  logger: logging.Logger = logging.getLogger(__name__)
14
17
 
18
+ TaskFunction = Callable[..., Awaitable[Any]]
15
19
  Message = dict[bytes, bytes]
16
20
 
17
21
 
22
+ _signature_cache: dict[Callable[..., Any], inspect.Signature] = {}
23
+
24
+
25
+ def get_signature(function: Callable[..., Any]) -> inspect.Signature:
26
+ if function in _signature_cache:
27
+ return _signature_cache[function]
28
+
29
+ signature = inspect.signature(function)
30
+ _signature_cache[function] = signature
31
+ return signature
32
+
33
+
18
34
  class Execution:
19
35
  def __init__(
20
36
  self,
21
- function: Callable[..., Awaitable[Any]],
37
+ function: TaskFunction,
22
38
  args: tuple[Any, ...],
23
39
  kwargs: dict[str, Any],
24
40
  when: datetime,
25
41
  key: str,
26
42
  attempt: int,
43
+ trace_context: opentelemetry.context.Context | None = None,
27
44
  ) -> None:
28
45
  self.function = function
29
46
  self.args = args
@@ -31,6 +48,7 @@ class Execution:
31
48
  self.when = when
32
49
  self.key = key
33
50
  self.attempt = attempt
51
+ self.trace_context = trace_context
34
52
 
35
53
  def as_message(self) -> Message:
36
54
  return {
@@ -43,9 +61,7 @@ class Execution:
43
61
  }
44
62
 
45
63
  @classmethod
46
- def from_message(
47
- cls, function: Callable[..., Awaitable[Any]], message: Message
48
- ) -> Self:
64
+ def from_message(cls, function: TaskFunction, message: Message) -> Self:
49
65
  return cls(
50
66
  function=function,
51
67
  args=cloudpickle.loads(message[b"args"]),
@@ -53,6 +69,7 @@ class Execution:
53
69
  when=datetime.fromisoformat(message[b"when"].decode()),
54
70
  key=message[b"key"].decode(),
55
71
  attempt=int(message[b"attempt"].decode()),
72
+ trace_context=propagate.extract(message, getter=message_getter),
56
73
  )
57
74
 
58
75
  def general_labels(self) -> Mapping[str, str]:
@@ -68,28 +85,32 @@ class Execution:
68
85
 
69
86
  def call_repr(self) -> str:
70
87
  arguments: list[str] = []
71
- signature = inspect.signature(self.function)
72
88
  function_name = self.function.__name__
73
89
 
90
+ signature = get_signature(self.function)
74
91
  logged_parameters = Logged.annotated_parameters(signature)
75
-
76
92
  parameter_names = list(signature.parameters.keys())
77
93
 
78
94
  for i, argument in enumerate(self.args[: len(parameter_names)]):
79
95
  parameter_name = parameter_names[i]
80
- if parameter_name in logged_parameters:
81
- arguments.append(repr(argument))
96
+ if logged := logged_parameters.get(parameter_name):
97
+ arguments.append(logged.format(argument))
82
98
  else:
83
99
  arguments.append("...")
84
100
 
85
101
  for parameter_name, argument in self.kwargs.items():
86
- if parameter_name in logged_parameters:
87
- arguments.append(f"{parameter_name}={repr(argument)}")
102
+ if logged := logged_parameters.get(parameter_name):
103
+ arguments.append(f"{parameter_name}={logged.format(argument)}")
88
104
  else:
89
105
  arguments.append(f"{parameter_name}=...")
90
106
 
91
107
  return f"{function_name}({', '.join(arguments)}){{{self.key}}}"
92
108
 
109
+ def incoming_span_links(self) -> list[trace.Link]:
110
+ initiating_span = trace.get_current_span(self.trace_context)
111
+ initiating_context = initiating_span.get_span_context()
112
+ return [trace.Link(initiating_context)] if initiating_context.is_valid else []
113
+
93
114
 
94
115
  class Operator(enum.StrEnum):
95
116
  EQUAL = "=="
@@ -217,10 +238,10 @@ class StrikeList:
217
238
  if function_name in self.task_strikes and not task_strikes:
218
239
  return True
219
240
 
220
- sig = inspect.signature(execution.function)
241
+ signature = get_signature(execution.function)
221
242
 
222
243
  try:
223
- bound_args = sig.bind(*execution.args, **execution.kwargs)
244
+ bound_args = signature.bind(*execution.args, **execution.kwargs)
224
245
  bound_args.apply_defaults()
225
246
  except TypeError:
226
247
  # If we can't make sense of the arguments, just assume the task is fine
@@ -265,6 +286,8 @@ class StrikeList:
265
286
  case "between": # pragma: no branch
266
287
  lower, upper = strike_value
267
288
  return lower <= value <= upper
289
+ case _: # pragma: no cover
290
+ raise ValueError(f"Unknown operator: {operator}")
268
291
  except (ValueError, TypeError):
269
292
  # If we can't make the comparison due to incompatible types, just log the
270
293
  # error and assume the task is not stricken