pydocket-0.0.1-py3-none-any.whl → pydocket-0.0.2-py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as they appear in their public registry. It is provided for informational purposes only.

docket/__init__.py CHANGED
@@ -8,7 +8,16 @@ from importlib.metadata import version
 
 __version__ = version("pydocket")
 
-from .dependencies import CurrentDocket, CurrentWorker, Retry
+from .annotations import Logged
+from .dependencies import (
+    CurrentDocket,
+    CurrentExecution,
+    CurrentWorker,
+    ExponentialRetry,
+    Retry,
+    TaskKey,
+    TaskLogger,
+)
 from .docket import Docket
 from .execution import Execution
 from .worker import Worker
@@ -19,6 +28,11 @@ __all__ = [
     "Execution",
     "CurrentDocket",
     "CurrentWorker",
+    "CurrentExecution",
+    "TaskKey",
+    "TaskLogger",
     "Retry",
+    "ExponentialRetry",
+    "Logged",
     "__version__",
 ]
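
The expanded export list is easiest to see in a task signature. A minimal sketch (the send_welcome_email task and its parameters are hypothetical; the imported names are the ones exported above):

```python
# Hypothetical task using the 0.0.2 exports; not part of the package.
import logging
from datetime import timedelta

from docket import CurrentDocket, Docket, ExponentialRetry, TaskKey, TaskLogger


async def send_welcome_email(
    customer_id: int,
    key: str = TaskKey(),
    docket: Docket = CurrentDocket(),
    logger: logging.LoggerAdapter = TaskLogger(),
    retry: ExponentialRetry = ExponentialRetry(
        attempts=5, minimum_delay=timedelta(seconds=2)
    ),
) -> None:
    logger.info("Welcoming %s (task %s on docket %s)", customer_id, key, docket.name)
```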
docket/annotations.py ADDED
@@ -0,0 +1,30 @@
+import abc
+import inspect
+from typing import Any, Iterable, Mapping, Self
+
+
+class Annotation(abc.ABC):
+    @classmethod
+    def annotated_parameters(cls, signature: inspect.Signature) -> Mapping[str, Self]:
+        annotated: dict[str, Self] = {}
+
+        for param_name, param in signature.parameters.items():
+            if param.annotation == inspect.Parameter.empty:
+                continue
+
+            try:
+                metadata: Iterable[Any] = param.annotation.__metadata__
+            except AttributeError:
+                continue
+
+            for arg_type in metadata:
+                if isinstance(arg_type, cls):
+                    annotated[param_name] = arg_type
+                elif isinstance(arg_type, type) and issubclass(arg_type, cls):
+                    annotated[param_name] = arg_type()
+
+        return annotated
+
+
+class Logged(Annotation):
+    """Instructs docket to include arguments to this parameter in the log."""
docket/cli.py CHANGED
@@ -1,6 +1,16 @@
+import asyncio
+import enum
+import logging
+import socket
+import sys
+from datetime import timedelta
+from typing import Annotated
+
 import typer
 
-from docket import __version__
+from . import __version__, tasks
+from .docket import Docket
+from .worker import Worker
 
 app: typer.Typer = typer.Typer(
     help="Docket - A distributed background task system for Python functions",
@@ -9,11 +19,232 @@ app: typer.Typer = typer.Typer(
 )
 
 
+class LogLevel(enum.StrEnum):
+    DEBUG = "DEBUG"
+    INFO = "INFO"
+    WARNING = "WARNING"
+    ERROR = "ERROR"
+    CRITICAL = "CRITICAL"
+
+
+class LogFormat(enum.StrEnum):
+    RICH = "rich"
+    PLAIN = "plain"
+    JSON = "json"
+
+
+def duration(duration_str: str | timedelta) -> timedelta:
+    """
+    Parse a duration string into a timedelta.
+
+    Supported formats:
+    - 123 = 123 seconds
+    - 123s = 123 seconds
+    - 123m = 123 minutes
+    - 123h = 123 hours
+    - 00:00 = mm:ss
+    - 00:00:00 = hh:mm:ss
+    """
+    if isinstance(duration_str, timedelta):
+        return duration_str
+
+    if ":" in duration_str:
+        parts = duration_str.split(":")
+        if len(parts) == 2:  # mm:ss
+            minutes, seconds = map(int, parts)
+            return timedelta(minutes=minutes, seconds=seconds)
+        elif len(parts) == 3:  # hh:mm:ss
+            hours, minutes, seconds = map(int, parts)
+            return timedelta(hours=hours, minutes=minutes, seconds=seconds)
+        else:
+            raise ValueError(f"Invalid duration string: {duration_str}")
+    elif duration_str.endswith("s"):
+        return timedelta(seconds=int(duration_str[:-1]))
+    elif duration_str.endswith("m"):
+        return timedelta(minutes=int(duration_str[:-1]))
+    elif duration_str.endswith("h"):
+        return timedelta(hours=int(duration_str[:-1]))
+    else:
+        return timedelta(seconds=int(duration_str))
+
+
+def set_logging_format(format: LogFormat) -> None:
+    root_logger = logging.getLogger()
+    if format == LogFormat.JSON:
+        from pythonjsonlogger.json import JsonFormatter
+
+        formatter = JsonFormatter(
+            "{name}{asctime}{levelname}{message}{exc_info}", style="{"
+        )
+        handler = logging.StreamHandler(stream=sys.stdout)
+        handler.setFormatter(formatter)
+        root_logger.addHandler(handler)
+    elif format == LogFormat.PLAIN:
+        handler = logging.StreamHandler(stream=sys.stdout)
+        formatter = logging.Formatter(
+            "[%(asctime)s] %(levelname)s - %(name)s - %(message)s",
+            datefmt="%Y-%m-%d %H:%M:%S",
+        )
+        handler.setFormatter(formatter)
+        root_logger.addHandler(handler)
+    else:
+        from rich.logging import RichHandler
+
+        handler = RichHandler()
+        formatter = logging.Formatter("%(message)s", datefmt="[%X]")
+        handler.setFormatter(formatter)
+        root_logger.addHandler(handler)
+
+
+def set_logging_level(level: LogLevel) -> None:
+    logging.getLogger().setLevel(level)
+
+
 @app.command(
     help="Start a worker to process tasks",
 )
-def worker() -> None:
-    print("TODO: start the worker")
+def worker(
+    tasks: Annotated[
+        list[str],
+        typer.Option(
+            "--tasks",
+            help=(
+                "The dotted path of a task collection to register with the docket. "
+                "This can be specified multiple times. A task collection is any "
+                "iterable of async functions."
+            ),
+        ),
+    ] = ["docket.tasks:standard_tasks"],
+    docket_: Annotated[
+        str,
+        typer.Option(
+            "--docket",
+            help="The name of the docket",
+            envvar="DOCKET_NAME",
+        ),
+    ] = "docket",
+    url: Annotated[
+        str,
+        typer.Option(
+            help="The URL of the Redis server",
+            envvar="DOCKET_URL",
+        ),
+    ] = "redis://localhost:6379/0",
+    name: Annotated[
+        str | None,
+        typer.Option(
+            help="The name of the worker",
+            envvar="DOCKET_WORKER_NAME",
+        ),
+    ] = socket.gethostname(),
+    logging_level: Annotated[
+        LogLevel,
+        typer.Option(
+            help="The logging level",
+            envvar="DOCKET_LOGGING_LEVEL",
+            callback=set_logging_level,
+        ),
+    ] = LogLevel.INFO,
+    logging_format: Annotated[
+        LogFormat,
+        typer.Option(
+            help="The logging format",
+            envvar="DOCKET_LOGGING_FORMAT",
+            callback=set_logging_format,
+        ),
+    ] = LogFormat.RICH if sys.stdout.isatty() else LogFormat.PLAIN,
+    prefetch_count: Annotated[
+        int,
+        typer.Option(
+            help="The number of tasks to request from the docket at a time",
+            envvar="DOCKET_WORKER_PREFETCH_COUNT",
+        ),
+    ] = 10,
+    redelivery_timeout: Annotated[
+        timedelta,
+        typer.Option(
+            parser=duration,
+            help="How long to wait before redelivering a task to another worker",
+            envvar="DOCKET_WORKER_REDELIVERY_TIMEOUT",
+        ),
+    ] = timedelta(minutes=5),
+    reconnection_delay: Annotated[
+        timedelta,
+        typer.Option(
+            parser=duration,
+            help=(
+                "How long to wait before reconnecting to the Redis server after "
+                "a connection error"
+            ),
+            envvar="DOCKET_WORKER_RECONNECTION_DELAY",
+        ),
+    ] = timedelta(seconds=5),
+    until_finished: Annotated[
+        bool,
+        typer.Option(
+            "--until-finished",
+            help="Exit after the current docket is finished",
+        ),
+    ] = False,
+) -> None:
+    asyncio.run(
+        Worker.run(
+            docket_name=docket_,
+            url=url,
+            name=name,
+            prefetch_count=prefetch_count,
+            redelivery_timeout=redelivery_timeout,
+            reconnection_delay=reconnection_delay,
+            until_finished=until_finished,
+            tasks=tasks,
+        )
+    )
+
+
+@app.command(help="Adds a trace task to the Docket")
+def trace(
+    docket_: Annotated[
+        str,
+        typer.Option(
+            "--docket",
+            help="The name of the docket",
+            envvar="DOCKET_NAME",
+        ),
+    ] = "docket",
+    url: Annotated[
+        str,
+        typer.Option(
+            help="The URL of the Redis server",
+            envvar="DOCKET_URL",
+        ),
+    ] = "redis://localhost:6379/0",
+    message: Annotated[
+        str,
+        typer.Argument(
+            help="The message to print",
+        ),
+    ] = "Howdy!",
+    error: Annotated[
+        bool,
+        typer.Option(
+            "--error",
+            help="Intentionally raise an error",
+        ),
+    ] = False,
+) -> None:
+    async def run() -> None:
+        async with Docket(name=docket_, url=url) as docket:
+            if error:
+                execution = await docket.add(tasks.fail)(message)
+            else:
+                execution = await docket.add(tasks.trace)(message)
+
+            print(
+                f"Added {execution.function.__name__} task {execution.key!r} to "
+                f"the docket {docket.name!r}"
+            )
+
+    asyncio.run(run())
 
 
 @app.command(
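
The duration helper above backs the --redelivery-timeout and --reconnection-delay options. A sketch of what it accepts, using the function exactly as defined in this release:

```python
from datetime import timedelta

from docket.cli import duration

assert duration("90") == timedelta(seconds=90)
assert duration("90s") == timedelta(seconds=90)
assert duration("15m") == timedelta(minutes=15)
assert duration("2h") == timedelta(hours=2)
assert duration("01:30") == timedelta(minutes=1, seconds=30)
assert duration("01:02:03") == timedelta(hours=1, minutes=2, seconds=3)
```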
docket/dependencies.py CHANGED
@@ -1,5 +1,6 @@
 import abc
 import inspect
+import logging
 from datetime import timedelta
 from typing import Any, Awaitable, Callable, Counter, cast
 
@@ -35,10 +36,52 @@ def CurrentDocket() -> Docket:
     return cast(Docket, _CurrentDocket())
 
 
+class _CurrentExecution(Dependency):
+    def __call__(
+        self, docket: Docket, worker: Worker, execution: Execution
+    ) -> Execution:
+        return execution
+
+
+def CurrentExecution() -> Execution:
+    return cast(Execution, _CurrentExecution())
+
+
+class _TaskKey(Dependency):
+    def __call__(self, docket: Docket, worker: Worker, execution: Execution) -> str:
+        return execution.key
+
+
+def TaskKey() -> str:
+    return cast(str, _TaskKey())
+
+
+class _TaskLogger(Dependency):
+    def __call__(
+        self, docket: Docket, worker: Worker, execution: Execution
+    ) -> logging.LoggerAdapter:
+        logger = logging.getLogger(f"docket.task.{execution.function.__name__}")
+
+        extra = {
+            "execution.key": execution.key,
+            "execution.attempt": execution.attempt,
+            "worker.name": worker.name,
+            "docket.name": docket.name,
+        }
+
+        return logging.LoggerAdapter(logger, extra)
+
+
+def TaskLogger() -> logging.LoggerAdapter[logging.Logger]:
+    return cast(logging.LoggerAdapter[logging.Logger], _TaskLogger())
+
+
 class Retry(Dependency):
     single: bool = True
 
-    def __init__(self, attempts: int = 1, delay: timedelta = timedelta(0)) -> None:
+    def __init__(
+        self, attempts: int | None = 1, delay: timedelta = timedelta(0)
+    ) -> None:
         self.attempts = attempts
         self.delay = delay
         self.attempt = 1
@@ -49,6 +92,41 @@ class Retry(Dependency):
         return retry
 
 
+class ExponentialRetry(Retry):
+    attempts: int
+
+    def __init__(
+        self,
+        attempts: int = 1,
+        minimum_delay: timedelta = timedelta(seconds=1),
+        maximum_delay: timedelta = timedelta(seconds=64),
+    ) -> None:
+        super().__init__(attempts=attempts, delay=minimum_delay)
+        self.minimum_delay = minimum_delay
+        self.maximum_delay = maximum_delay
+
+    def __call__(
+        self, docket: Docket, worker: Worker, execution: Execution
+    ) -> "ExponentialRetry":
+        retry = ExponentialRetry(
+            attempts=self.attempts,
+            minimum_delay=self.minimum_delay,
+            maximum_delay=self.maximum_delay,
+        )
+        retry.attempt = execution.attempt
+
+        if execution.attempt > 1:
+            backoff_factor = 2 ** (execution.attempt - 1)
+            calculated_delay = self.minimum_delay * backoff_factor
+
+            if calculated_delay > self.maximum_delay:
+                retry.delay = self.maximum_delay
+            else:
+                retry.delay = calculated_delay
+
+        return retry
+
+
 def get_dependency_parameters(
     function: Callable[..., Awaitable[Any]],
 ) -> dict[str, Dependency]:
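
ExponentialRetry doubles the delay on each attempt and caps it at maximum_delay. The arithmetic from __call__ above, worked out for the defaults (a standalone sketch, not part of the package):

```python
from datetime import timedelta

minimum_delay = timedelta(seconds=1)
maximum_delay = timedelta(seconds=64)

for attempt in range(2, 10):
    # backoff_factor = 2 ** (attempt - 1), as in ExponentialRetry.__call__
    delay = min(minimum_delay * 2 ** (attempt - 1), maximum_delay)
    print(attempt, delay)
# attempt 2 -> 2s, 3 -> 4s, 4 -> 8s, ... 7 -> 64s, capped at 64s thereafter
```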
docket/docket.py CHANGED
@@ -1,3 +1,4 @@
+import importlib
 from contextlib import asynccontextmanager
 from datetime import datetime, timezone
 from types import TracebackType
@@ -6,6 +7,7 @@ from typing import (
     AsyncGenerator,
     Awaitable,
     Callable,
+    Iterable,
     ParamSpec,
     Self,
     TypeVar,
@@ -13,13 +15,26 @@ from typing import (
 )
 from uuid import uuid4
 
+from opentelemetry import propagate, trace
 from redis.asyncio import Redis
 
 from .execution import Execution
+from .instrumentation import (
+    TASKS_ADDED,
+    TASKS_CANCELLED,
+    TASKS_REPLACED,
+    TASKS_SCHEDULED,
+    message_setter,
+)
+
+tracer: trace.Tracer = trace.get_tracer(__name__)
+
 
 P = ParamSpec("P")
 R = TypeVar("R")
 
+TaskCollection = Iterable[Callable[..., Awaitable[Any]]]
+
 
 class Docket:
     tasks: dict[str, Callable[..., Awaitable[Any]]]
@@ -27,19 +42,26 @@ class Docket:
     def __init__(
         self,
         name: str = "docket",
-        host: str = "localhost",
-        port: int = 6379,
-        db: int = 0,
-        password: str | None = None,
+        url: str = "redis://localhost:6379/0",
     ) -> None:
+        """
+        Args:
+            name: The name of the docket.
+            url: The URL of the Redis server. For example:
+                - "redis://localhost:6379/0"
+                - "redis://user:password@localhost:6379/0"
+                - "redis://user:password@localhost:6379/0?ssl=true"
+                - "rediss://localhost:6379/0"
+                - "unix:///path/to/redis.sock"
+        """
         self.name = name
-        self.host = host
-        self.port = port
-        self.db = db
-        self.password = password
+        self.url = url
 
     async def __aenter__(self) -> Self:
-        self.tasks = {}
+        from .tasks import standard_tasks
+
+        self.tasks = {fn.__name__: fn for fn in standard_tasks}
+
         return self
 
     async def __aexit__(
@@ -52,13 +74,7 @@ class Docket:
 
     @asynccontextmanager
     async def redis(self) -> AsyncGenerator[Redis, None]:
-        async with Redis(
-            host=self.host,
-            port=self.port,
-            db=self.db,
-            password=self.password,
-            single_connection_client=True,
-        ) as redis:
+        async with Redis.from_url(self.url) as redis:
             yield redis
 
     def register(self, function: Callable[..., Awaitable[Any]]) -> None:
@@ -68,6 +84,19 @@ class Docket:
 
         self.tasks[function.__name__] = function
 
+    def register_collection(self, collection_path: str) -> None:
+        """
+        Register a collection of tasks.
+
+        Args:
+            collection_path: A path in the format "module:collection".
+        """
+        module_name, _, member_name = collection_path.rpartition(":")
+        module = importlib.import_module(module_name)
+        collection = getattr(module, member_name)
+        for function in collection:
+            self.register(function)
+
     @overload
     def add(
         self,
@@ -104,6 +133,9 @@ class Docket:
         async def scheduler(*args: P.args, **kwargs: P.kwargs) -> Execution:
             execution = Execution(function, args, kwargs, when, key, attempt=1)
             await self.schedule(execution)
+
+            TASKS_ADDED.add(1, {"docket": self.name, "task": function.__name__})
+
             return execution
 
         return scheduler
@@ -137,6 +169,9 @@ class Docket:
             execution = Execution(function, args, kwargs, when, key, attempt=1)
             await self.cancel(key)
             await self.schedule(execution)
+
+            TASKS_REPLACED.add(1, {"docket": self.name, "task": function.__name__})
+
             return execution
 
         return scheduler
@@ -154,25 +189,50 @@ class Docket:
 
     async def schedule(self, execution: Execution) -> None:
         message: dict[bytes, bytes] = execution.as_message()
-        key = execution.key
-        when = execution.when
+        propagate.inject(message, setter=message_setter)
+
+        with tracer.start_as_current_span(
+            "docket.schedule",
+            attributes={
+                "docket.name": self.name,
+                "docket.execution.when": execution.when.isoformat(),
+                "docket.execution.key": execution.key,
+                "docket.execution.attempt": execution.attempt,
+                "code.function.name": execution.function.__name__,
+            },
+        ):
+            key = execution.key
+            when = execution.when
+
+            async with self.redis() as redis:
+                # if the task is already in the queue, retain it
+                if await redis.zscore(self.queue_key, key) is not None:
+                    return
+
+                if when <= datetime.now(timezone.utc):
+                    await redis.xadd(self.stream_key, message)
+                else:
+                    async with redis.pipeline() as pipe:
+                        pipe.hset(self.parked_task_key(key), mapping=message)
+                        pipe.zadd(self.queue_key, {key: when.timestamp()})
+                        await pipe.execute()
+
+            TASKS_SCHEDULED.add(
+                1, {"docket": self.name, "task": execution.function.__name__}
+            )
 
-        async with self.redis() as redis:
-            # if the task is already in the queue, retain it
-            if await redis.zscore(self.queue_key, key) is not None:
-                return
-
-            if when <= datetime.now(timezone.utc):
-                await redis.xadd(self.stream_key, message)
-            else:
+    async def cancel(self, key: str) -> None:
+        with tracer.start_as_current_span(
+            "docket.cancel",
+            attributes={
+                "docket.name": self.name,
+                "docket.execution.key": key,
+            },
+        ):
+            async with self.redis() as redis:
                 async with redis.pipeline() as pipe:
-                    pipe.hset(self.parked_task_key(key), mapping=message)
-                    pipe.zadd(self.queue_key, {key: when.timestamp()})
+                    pipe.delete(self.parked_task_key(key))
+                    pipe.zrem(self.queue_key, key)
                     await pipe.execute()
 
-    async def cancel(self, key: str) -> None:
-        async with self.redis() as redis:
-            async with redis.pipeline() as pipe:
-                pipe.delete(self.parked_task_key(key))
-                pipe.zrem(self.queue_key, key)
-                await pipe.execute()
+        TASKS_CANCELLED.add(1, {"docket": self.name})
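
Putting the new URL-based constructor and the registration methods together (a sketch; my_task is hypothetical, and the one-argument docket.add(...) form matches how the CLI's trace command uses it above):

```python
import asyncio

from docket import Docket


async def my_task(greeting: str) -> None:
    print(greeting)


async def main() -> None:
    async with Docket(name="docket", url="redis://localhost:6379/0") as docket:
        docket.register(my_task)
        # or register a whole "module:collection" path, as the worker does:
        docket.register_collection("docket.tasks:standard_tasks")

        execution = await docket.add(my_task)("hello")  # due immediately
        await docket.cancel(execution.key)  # cancellation is by key


asyncio.run(main())
```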
docket/execution.py CHANGED
@@ -1,8 +1,11 @@
+import inspect
 from datetime import datetime
 from typing import Any, Awaitable, Callable, Self
 
 import cloudpickle
 
+from docket.annotations import Logged
+
 Message = dict[bytes, bytes]
 
 
@@ -45,3 +48,27 @@ class Execution:
             key=message[b"key"].decode(),
             attempt=int(message[b"attempt"].decode()),
         )
+
+    def call_repr(self) -> str:
+        arguments: list[str] = []
+        signature = inspect.signature(self.function)
+        function_name = self.function.__name__
+
+        logged_parameters = Logged.annotated_parameters(signature)
+
+        parameter_names = list(signature.parameters.keys())
+
+        for i, argument in enumerate(self.args[: len(parameter_names)]):
+            parameter_name = parameter_names[i]
+            if parameter_name in logged_parameters:
+                arguments.append(repr(argument))
+            else:
+                arguments.append("...")
+
+        for parameter_name, argument in self.kwargs.items():
+            if parameter_name in logged_parameters:
+                arguments.append(f"{parameter_name}={repr(argument)}")
+            else:
+                arguments.append(f"{parameter_name}=...")
+
+        return f"{function_name}({', '.join(arguments)}){{{self.key}}}"
docket/instrumentation.py ADDED
@@ -0,0 +1,103 @@
+from opentelemetry import metrics
+from opentelemetry.propagators.textmap import Getter, Setter
+
+meter: metrics.Meter = metrics.get_meter("docket")
+
+TASKS_ADDED = meter.create_counter(
+    "docket_tasks_added",
+    description="How many tasks added to the docket",
+    unit="1",
+)
+
+TASKS_REPLACED = meter.create_counter(
+    "docket_tasks_replaced",
+    description="How many tasks replaced on the docket",
+    unit="1",
+)
+
+TASKS_SCHEDULED = meter.create_counter(
+    "docket_tasks_scheduled",
+    description="How many tasks added or replaced on the docket",
+    unit="1",
+)
+
+TASKS_CANCELLED = meter.create_counter(
+    "docket_tasks_cancelled",
+    description="How many tasks cancelled from the docket",
+    unit="1",
+)
+
+TASKS_STARTED = meter.create_counter(
+    "docket_tasks_started",
+    description="How many tasks started",
+    unit="1",
+)
+
+TASKS_COMPLETED = meter.create_counter(
+    "docket_tasks_completed",
+    description="How many tasks that have completed in any state",
+    unit="1",
+)
+
+TASKS_FAILED = meter.create_counter(
+    "docket_tasks_failed",
+    description="How many tasks that have failed",
+    unit="1",
+)
+
+TASKS_SUCCEEDED = meter.create_counter(
+    "docket_tasks_succeeded",
+    description="How many tasks that have succeeded",
+    unit="1",
+)
+
+TASKS_RETRIED = meter.create_counter(
+    "docket_tasks_retried",
+    description="How many tasks that have been retried",
+    unit="1",
+)
+
+TASK_DURATION = meter.create_histogram(
+    "docket_task_duration",
+    description="How long tasks take to complete",
+    unit="s",
+)
+
+TASK_PUNCTUALITY = meter.create_histogram(
+    "docket_task_punctuality",
+    description="How close a task was to its scheduled time",
+    unit="s",
+)
+
+TASKS_RUNNING = meter.create_up_down_counter(
+    "docket_tasks_running",
+    description="How many tasks that are currently running",
+    unit="1",
+)
+
+Message = dict[bytes, bytes]
+
+
+class MessageGetter(Getter[Message]):
+    def get(self, carrier: Message, key: str) -> list[str] | None:
+        val = carrier.get(key.encode(), None)
+        if val is None:
+            return None
+        return [val.decode()]
+
+    def keys(self, carrier: Message) -> list[str]:
+        return [key.decode() for key in carrier.keys()]
+
+
+class MessageSetter(Setter[Message]):
+    def set(
+        self,
+        carrier: Message,
+        key: str,
+        value: str,
+    ) -> None:
+        carrier[key.encode()] = value.encode()
+
+
+message_getter: MessageGetter = MessageGetter()
+message_setter: MessageSetter = MessageSetter()
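
The getter/setter pair adapts OpenTelemetry's text-map propagation to the bytes-keyed Redis stream messages: Docket.schedule injects the current trace context and Worker._execute extracts it. A round-trip sketch:

```python
from opentelemetry import propagate

from docket.instrumentation import message_getter, message_setter

message: dict[bytes, bytes] = {}
propagate.inject(message, setter=message_setter)  # may add e.g. b"traceparent"
context = propagate.extract(message, getter=message_getter)
```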
docket/tasks.py ADDED
@@ -0,0 +1,50 @@
+import logging
+from datetime import datetime, timezone
+
+from .dependencies import CurrentDocket, CurrentExecution, CurrentWorker, Retry
+from .docket import Docket, TaskCollection
+from .execution import Execution
+from .worker import Worker
+
+logger: logging.Logger = logging.getLogger(__name__)
+
+
+async def trace(
+    message: str,
+    docket: Docket = CurrentDocket(),
+    worker: Worker = CurrentWorker(),
+    execution: Execution = CurrentExecution(),
+) -> None:
+    logger.info(
+        "%s: %r added to docket %r %s ago now running on worker %r",
+        message,
+        execution.key,
+        docket.name,
+        (datetime.now(timezone.utc) - execution.when),
+        worker.name,
+        extra={
+            "docket.name": docket.name,
+            "worker.name": worker.name,
+            "execution.key": execution.key,
+        },
+    )
+
+
+async def fail(
+    message: str,
+    docket: Docket = CurrentDocket(),
+    worker: Worker = CurrentWorker(),
+    execution: Execution = CurrentExecution(),
+    retry: Retry = Retry(attempts=2),
+) -> None:
+    raise Exception(
+        f"{message}: {execution.key} added to docket "
+        f"{docket.name} {datetime.now(timezone.utc) - execution.when} "
+        f"ago now running on worker {worker.name}"
+    )
+
+
+standard_tasks: TaskCollection = [
+    trace,
+    fail,
+]
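
standard_tasks is the default collection the worker registers. User code can expose its own iterable of async functions the same way and point the worker at it with --tasks myapp.tasks:my_tasks (a sketch; myapp/tasks.py and say_hello are hypothetical):

```python
# myapp/tasks.py (hypothetical)
from docket.docket import TaskCollection


async def say_hello(name: str) -> None:
    print(f"Hello, {name}!")


my_tasks: TaskCollection = [say_hello]
```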
docket/worker.py CHANGED
@@ -1,15 +1,41 @@
+import asyncio
+import inspect
 import logging
 import sys
-from datetime import datetime, timezone
+from datetime import datetime, timedelta, timezone
 from types import TracebackType
-from typing import TYPE_CHECKING, Any, Protocol, Self, Sequence, TypeVar, cast
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Protocol,
+    Self,
+    Sequence,
+    TypeVar,
+    cast,
+)
 from uuid import uuid4
 
+import redis.exceptions
+from opentelemetry import propagate, trace
+from opentelemetry.trace import Tracer
 from redis import RedisError
 
 from .docket import Docket, Execution
+from .instrumentation import (
+    TASK_DURATION,
+    TASK_PUNCTUALITY,
+    TASKS_COMPLETED,
+    TASKS_FAILED,
+    TASKS_RETRIED,
+    TASKS_RUNNING,
+    TASKS_STARTED,
+    TASKS_SUCCEEDED,
+    message_getter,
+)
 
 logger: logging.Logger = logging.getLogger(__name__)
+tracer: Tracer = trace.get_tracer(__name__)
+
 
 RedisStreamID = bytes
 RedisMessageID = bytes
@@ -30,14 +56,22 @@ class _stream_due_tasks(Protocol):
 
 
 class Worker:
-    name: str
     docket: Docket
+    name: str
 
-    prefetch_count: int = 10
-
-    def __init__(self, docket: Docket) -> None:
-        self.name = f"worker:{uuid4()}"
+    def __init__(
+        self,
+        docket: Docket,
+        name: str | None = None,
+        prefetch_count: int = 10,
+        redelivery_timeout: timedelta = timedelta(minutes=5),
+        reconnection_delay: timedelta = timedelta(seconds=5),
+    ) -> None:
         self.docket = docket
+        self.name = name or f"worker:{uuid4()}"
+        self.prefetch_count = prefetch_count
+        self.redelivery_timeout = redelivery_timeout
+        self.reconnection_delay = reconnection_delay
 
     async def __aenter__(self) -> Self:
         async with self.docket.redis() as redis:
@@ -49,7 +83,8 @@ class Worker:
                 mkstream=True,
             )
         except RedisError as e:
-            assert "BUSYGROUP" in repr(e)
+            if "BUSYGROUP" not in repr(e):
+                raise
 
         return self
 
@@ -72,7 +107,60 @@ class Worker:
             "stream_key": self.docket.stream_key,
         }
 
-    async def run_until_current(self) -> None:
+    @classmethod
+    async def run(
+        cls,
+        docket_name: str = "docket",
+        url: str = "redis://localhost:6379/0",
+        name: str | None = None,
+        prefetch_count: int = 10,
+        redelivery_timeout: timedelta = timedelta(minutes=5),
+        reconnection_delay: timedelta = timedelta(seconds=5),
+        until_finished: bool = False,
+        tasks: list[str] = ["docket.tasks:standard_tasks"],
+    ) -> None:
+        async with Docket(name=docket_name, url=url) as docket:
+            for task_path in tasks:
+                docket.register_collection(task_path)
+
+            async with Worker(
+                docket=docket,
+                name=name,
+                prefetch_count=prefetch_count,
+                redelivery_timeout=redelivery_timeout,
+                reconnection_delay=reconnection_delay,
+            ) as worker:
+                if until_finished:
+                    await worker.run_until_finished()
+                else:
+                    await worker.run_forever()  # pragma: no cover
+
+    async def run_until_finished(self) -> None:
+        """Run the worker until there are no more tasks to process."""
+        return await self._run(forever=False)
+
+    async def run_forever(self) -> None:
+        """Run the worker indefinitely."""
+        return await self._run(forever=True)  # pragma: no cover
+
+    async def _run(self, forever: bool = False) -> None:
+        logger.info("Starting worker %r with the following tasks:", self.name)
+        for task_name, task in self.docket.tasks.items():
+            signature = inspect.signature(task)
+            logger.info("* %s%s", task_name, signature)
+
+        while True:
+            try:
+                return await self._worker_loop(forever=forever)
+            except redis.exceptions.ConnectionError:
+                logger.warning(
+                    "Error connecting to redis, retrying in %s...",
+                    self.reconnection_delay,
+                    exc_info=True,
+                )
+                await asyncio.sleep(self.reconnection_delay.total_seconds())
+
+    async def _worker_loop(self, forever: bool = False):
         async with self.docket.redis() as redis:
             stream_due_tasks: _stream_due_tasks = cast(
                 _stream_due_tasks,
@@ -120,83 +208,136 @@ class Worker:
             )
 
             total_work, due_work = sys.maxsize, 0
-            while total_work:
+            while forever or total_work:
                 now = datetime.now(timezone.utc)
                 total_work, due_work = await stream_due_tasks(
                     keys=[self.docket.queue_key, self.docket.stream_key],
                     args=[now.timestamp(), self.docket.name],
                 )
-                logger.info(
-                    "Moved %d/%d due tasks from %s to %s",
-                    due_work,
-                    total_work,
-                    self.docket.queue_key,
-                    self.docket.stream_key,
-                    extra=self._log_context,
+                if due_work > 0:
+                    logger.debug(
+                        "Moved %d/%d due tasks from %s to %s",
+                        due_work,
+                        total_work,
+                        self.docket.queue_key,
+                        self.docket.stream_key,
+                        extra=self._log_context,
+                    )
+
+                _, redeliveries, _ = await redis.xautoclaim(
+                    name=self.docket.stream_key,
+                    groupname=self.consumer_group_name,
+                    consumername=self.name,
+                    min_idle_time=int(self.redelivery_timeout.total_seconds() * 1000),
+                    start_id="0-0",
+                    count=self.prefetch_count,
                 )
 
-                response: RedisReadGroupResponse = await redis.xreadgroup(
+                new_deliveries: RedisReadGroupResponse = await redis.xreadgroup(
                     groupname=self.consumer_group_name,
                     consumername=self.name,
                     streams={self.docket.stream_key: ">"},
                     count=self.prefetch_count,
                     block=10,
                 )
-                for _, messages in response:
-                    for message_id, message in messages:
-                        await self._execute(message)
-
-                        # When executing a task, there's always a chance that it was
-                        # either retried or it scheduled another task, so let's give
-                        # ourselves one more iteration of the loop to handle that.
-                        total_work += 1
-
-                        async with redis.pipeline() as pipe:
-                            pipe.xack(
-                                self.docket.stream_key,
-                                self.consumer_group_name,
-                                message_id,
-                            )
-                            pipe.xdel(
-                                self.docket.stream_key,
-                                message_id,
-                            )
-                            await pipe.execute()
+
+                for source in [[(b"redeliveries", redeliveries)], new_deliveries]:
+                    for _, messages in source:
+                        for message_id, message in messages:
+                            await self._execute(message)
+
+                            async with redis.pipeline() as pipeline:
+                                pipeline.xack(
+                                    self.docket.stream_key,
+                                    self.consumer_group_name,
+                                    message_id,
+                                )
+                                pipeline.xdel(
+                                    self.docket.stream_key,
+                                    message_id,
+                                )
+                                await pipeline.execute()
+
+                            # When executing a task, there's always a chance that it was
+                            # either retried or it scheduled another task, so let's give
+                            # ourselves one more iteration of the loop to handle that.
+                            total_work += 1
 
     async def _execute(self, message: RedisMessage) -> None:
         execution = Execution.from_message(
             self.docket.tasks[message[b"function"].decode()],
             message,
         )
+        name = execution.function.__name__
+        key = execution.key
 
-        logger.info(
-            "Executing task %s with args %s and kwargs %s",
-            execution.key,
-            execution.args,
-            execution.kwargs,
-            extra={
-                **self._log_context,
-                "function": execution.function.__name__,
-            },
-        )
+        log_context: dict[str, str | float] = {
+            **self._log_context,
+            "task": name,
+            "key": key,
+        }
+        counter_labels = {
+            "docket": self.docket.name,
+            "worker": self.name,
+            "task": name,
+        }
 
         dependencies = self._get_dependencies(execution)
 
+        context = propagate.extract(message, getter=message_getter)
+        initiating_context = trace.get_current_span(context).get_span_context()
+        links = [trace.Link(initiating_context)] if initiating_context.is_valid else []
+
+        start = datetime.now(timezone.utc)
+        punctuality = start - execution.when
+        log_context["punctuality"] = punctuality.total_seconds()
+        duration = timedelta(0)
+
+        TASKS_STARTED.add(1, counter_labels)
+        TASKS_RUNNING.add(1, counter_labels)
+        TASK_PUNCTUALITY.record(punctuality.total_seconds(), counter_labels)
+
+        arrow = "↬" if execution.attempt > 1 else "↪"
+        call = execution.call_repr()
+        logger.info("%s [%s] %s", arrow, punctuality, call, extra=log_context)
+
         try:
-            await execution.function(
-                *execution.args,
-                **{
-                    **execution.kwargs,
-                    **dependencies,
+            with tracer.start_as_current_span(
+                execution.function.__name__,
+                kind=trace.SpanKind.CONSUMER,
+                attributes={
+                    "docket.name": self.docket.name,
+                    "docket.execution.when": execution.when.isoformat(),
+                    "docket.execution.key": execution.key,
+                    "docket.execution.attempt": execution.attempt,
+                    "docket.execution.punctuality": punctuality.total_seconds(),
+                    "code.function.name": execution.function.__name__,
                 },
-            )
+                links=links,
+            ):
+                await execution.function(
+                    *execution.args,
+                    **{
+                        **execution.kwargs,
+                        **dependencies,
+                    },
+                )
+
+            TASKS_SUCCEEDED.add(1, counter_labels)
+            duration = datetime.now(timezone.utc) - start
+            log_context["duration"] = duration.total_seconds()
+            logger.info("%s [%s] %s", "↩", duration, call, extra=log_context)
         except Exception:
-            logger.exception(
-                "Error executing task %s",
-                execution.key,
-                extra=self._log_context,
-            )
-            await self._retry_if_requested(execution, dependencies)
+            TASKS_FAILED.add(1, counter_labels)
+            duration = datetime.now(timezone.utc) - start
+            log_context["duration"] = duration.total_seconds()
+            retried = await self._retry_if_requested(execution, dependencies)
+            arrow = "↫" if retried else "↩"
+            logger.exception("%s [%s] %s", arrow, duration, call, extra=log_context)
+        finally:
+            TASKS_RUNNING.add(-1, counter_labels)
+            TASKS_COMPLETED.add(1, counter_labels)
+            TASK_DURATION.record(duration.total_seconds(), counter_labels)
 
     def _get_dependencies(
         self,
@@ -208,14 +349,14 @@ class Worker:
 
         dependencies: dict[str, Any] = {}
 
-        for param_name, dependency in parameters.items():
+        for parameter_name, dependency in parameters.items():
             # If the argument is already provided, skip it, which allows users to call
             # the function directly with the arguments they want.
-            if param_name in execution.kwargs:
-                dependencies[param_name] = execution.kwargs[param_name]
+            if parameter_name in execution.kwargs:
+                dependencies[parameter_name] = execution.kwargs[parameter_name]
                 continue
 
-            dependencies[param_name] = dependency(self.docket, self, execution)
+            dependencies[parameter_name] = dependency(self.docket, self, execution)
 
         return dependencies
 
@@ -223,22 +364,26 @@ class Worker:
         self,
         execution: Execution,
         dependencies: dict[str, Any],
-    ) -> None:
+    ) -> bool:
        from .dependencies import Retry
 
        retries = [retry for retry in dependencies.values() if isinstance(retry, Retry)]
        if not retries:
-            return
+            return False
 
        retry = retries[0]
 
-        if execution.attempt < retry.attempts:
+        if retry.attempts is None or execution.attempt < retry.attempts:
            execution.when = datetime.now(timezone.utc) + retry.delay
            execution.attempt += 1
            await self.docket.schedule(execution)
-        else:
-            logger.error(
-                "Task %s failed after %d attempts",
-                execution.key,
-                retry.attempts,
-            )
+
+            counter_labels = {
+                "docket": self.docket.name,
+                "worker": self.name,
+                "task": execution.function.__name__,
+            }
+            TASKS_RETRIED.add(1, counter_labels)
+            return True
+
+        return False
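
The new Worker.run classmethod mirrors the CLI's worker command, and the context-manager form composes the pieces directly. A minimal sketch using only names defined above:

```python
import asyncio

from docket import Docket, Worker


async def main() -> None:
    # One-call form, as the CLI does:
    await Worker.run(
        docket_name="docket",
        url="redis://localhost:6379/0",
        until_finished=True,  # drain due work, then exit
    )

    # Or compose the pieces yourself:
    async with Docket(name="docket") as docket:
        async with Worker(docket, prefetch_count=20) as worker:
            await worker.run_until_finished()


asyncio.run(main())
```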
pydocket-0.0.1.dist-info/METADATA → pydocket-0.0.2.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: pydocket
-Version: 0.0.1
+Version: 0.0.2
 Summary: A distributed background task system for Python functions
 Project-URL: Homepage, https://github.com/chrisguidry/docket
 Project-URL: Bug Tracker, https://github.com/chrisguidry/docket/issues
@@ -23,7 +23,12 @@ Classifier: Programming Language :: Python :: 3.12
 Classifier: Programming Language :: Python :: 3.13
 Requires-Python: >=3.12
 Requires-Dist: cloudpickle>=3.1.1
+Requires-Dist: opentelemetry-api>=1.30.0
+Requires-Dist: opentelemetry-exporter-prometheus>=0.51b0
+Requires-Dist: prometheus-client>=0.21.1
+Requires-Dist: python-json-logger>=3.2.1
 Requires-Dist: redis>=5.2.1
+Requires-Dist: rich>=13.9.4
 Requires-Dist: typer>=0.15.1
 Description-Content-Type: text/markdown
 
pydocket-0.0.2.dist-info/RECORD ADDED
@@ -0,0 +1,16 @@
+docket/__init__.py,sha256=GoJYpyuO6QFeBB8GNaxGGvMMuai55Eaw_8u-o1PM3hk,743
+docket/__main__.py,sha256=Vkuh7aJ-Bl7QVpVbbkUksAd_hn05FiLmWbc-8kbhZQ4,34
+docket/annotations.py,sha256=GZwOPtPXyeIhnsLh3TQMBnXrjtTtSmF4Ratv4vjPx8U,950
+docket/cli.py,sha256=Qj2wzc3WjPDFsRiYJDVaqXKdTCUrFHDESwT8eebhvUk,7101
+docket/dependencies.py,sha256=iJ9RdrdgB7jVclxFUd2mR9OXh7GjdpyR6k8fIHn4Sz4,4309
+docket/docket.py,sha256=Z9gnS-vVFMqIThnUNEDvukN3InFz4T80SmPYxhbF9HQ,7006
+docket/execution.py,sha256=AEfNwmpBFJEf9ZFnbA37dqjWexGzmMcmvdxGnDsUQdY,2340
+docket/instrumentation.py,sha256=ZgXqjjPPspeQnfGakcBUS29FcRa-SKeqC0INTBzvIow,2515
+docket/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+docket/tasks.py,sha256=rFAO9z0ESpH1F-P02DJ3yj5E54TGMtNc2pUkpaiTLVY,1317
+docket/worker.py,sha256=DGvaSyiFKkV0k2YG9jhmRhStFm1gDhWZaxDXDDky2yo,13987
+pydocket-0.0.2.dist-info/METADATA,sha256=WWNNg9YXTL8zyPUzHFUvkS0ARvFw_Wl2w_nBf-KqFNE,2292
+pydocket-0.0.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+pydocket-0.0.2.dist-info/entry_points.txt,sha256=4WOk1nUlBsUT5O3RyMci2ImuC5XFswuopElYcLHtD5k,47
+pydocket-0.0.2.dist-info/licenses/LICENSE,sha256=YuVWU_ZXO0K_k2FG8xWKe5RGxV24AhJKTvQmKfqXuyk,1087
+pydocket-0.0.2.dist-info/RECORD,,
pydocket-0.0.1.dist-info/RECORD DELETED
@@ -1,13 +0,0 @@
-docket/__init__.py,sha256=0gzwWxJcDqU7nEZ4rHrWYk7WqEgW4Mz77E7rbjpyCHw,526
-docket/__main__.py,sha256=Vkuh7aJ-Bl7QVpVbbkUksAd_hn05FiLmWbc-8kbhZQ4,34
-docket/cli.py,sha256=ty8CirvLDvvOuOs7MHW0SYMjmbQq9rqUgPCq-HqW_6g,434
-docket/dependencies.py,sha256=yd_sIv3y69Czo_DyHh3aNtLJGFOMjg8jjJ701QN04-Q,2124
-docket/docket.py,sha256=sIgSvPUX4HG4EN_OxPSN7xCGu7DZkk5oZsuBVDp5XcU,4942
-docket/execution.py,sha256=3HT1GOeg76RMILiq06bpDZITHrqjVQ_j9FU6fuY4jqw,1375
-docket/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-docket/worker.py,sha256=N-nftCveJELKnICp1lgmzG5-r1QCYkGRxi8FcTjrLSc,8204
-pydocket-0.0.1.dist-info/METADATA,sha256=IYrB8lUbawu5I9Buh6uVJ68kJh47GUfph3F1wxxN78w,2084
-pydocket-0.0.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-pydocket-0.0.1.dist-info/entry_points.txt,sha256=4WOk1nUlBsUT5O3RyMci2ImuC5XFswuopElYcLHtD5k,47
-pydocket-0.0.1.dist-info/licenses/LICENSE,sha256=YuVWU_ZXO0K_k2FG8xWKe5RGxV24AhJKTvQmKfqXuyk,1087
-pydocket-0.0.1.dist-info/RECORD,,