pydocket 0.4.0__tar.gz → 0.5.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pydocket might be problematic. (The original page linked to more details here; the link was not preserved.)

Files changed (51):
  1. {pydocket-0.4.0 → pydocket-0.5.1}/PKG-INFO +1 -1
  2. {pydocket-0.4.0 → pydocket-0.5.1}/chaos/driver.py +7 -11
  3. {pydocket-0.4.0 → pydocket-0.5.1}/chaos/producer.py +10 -1
  4. {pydocket-0.4.0 → pydocket-0.5.1}/src/docket/cli.py +25 -0
  5. {pydocket-0.4.0 → pydocket-0.5.1}/src/docket/dependencies.py +42 -2
  6. {pydocket-0.4.0 → pydocket-0.5.1}/src/docket/worker.py +186 -95
  7. {pydocket-0.4.0 → pydocket-0.5.1}/tests/conftest.py +3 -1
  8. {pydocket-0.4.0 → pydocket-0.5.1}/tests/test_fundamentals.py +28 -0
  9. {pydocket-0.4.0 → pydocket-0.5.1}/.cursor/rules/general.mdc +0 -0
  10. {pydocket-0.4.0 → pydocket-0.5.1}/.cursor/rules/python-style.mdc +0 -0
  11. {pydocket-0.4.0 → pydocket-0.5.1}/.github/codecov.yml +0 -0
  12. {pydocket-0.4.0 → pydocket-0.5.1}/.github/workflows/chaos.yml +0 -0
  13. {pydocket-0.4.0 → pydocket-0.5.1}/.github/workflows/ci.yml +0 -0
  14. {pydocket-0.4.0 → pydocket-0.5.1}/.github/workflows/publish.yml +0 -0
  15. {pydocket-0.4.0 → pydocket-0.5.1}/.gitignore +0 -0
  16. {pydocket-0.4.0 → pydocket-0.5.1}/.pre-commit-config.yaml +0 -0
  17. {pydocket-0.4.0 → pydocket-0.5.1}/LICENSE +0 -0
  18. {pydocket-0.4.0 → pydocket-0.5.1}/README.md +0 -0
  19. {pydocket-0.4.0 → pydocket-0.5.1}/chaos/README.md +0 -0
  20. {pydocket-0.4.0 → pydocket-0.5.1}/chaos/__init__.py +0 -0
  21. {pydocket-0.4.0 → pydocket-0.5.1}/chaos/run +0 -0
  22. {pydocket-0.4.0 → pydocket-0.5.1}/chaos/tasks.py +0 -0
  23. {pydocket-0.4.0 → pydocket-0.5.1}/pyproject.toml +0 -0
  24. {pydocket-0.4.0 → pydocket-0.5.1}/src/docket/__init__.py +0 -0
  25. {pydocket-0.4.0 → pydocket-0.5.1}/src/docket/__main__.py +0 -0
  26. {pydocket-0.4.0 → pydocket-0.5.1}/src/docket/annotations.py +0 -0
  27. {pydocket-0.4.0 → pydocket-0.5.1}/src/docket/docket.py +0 -0
  28. {pydocket-0.4.0 → pydocket-0.5.1}/src/docket/execution.py +0 -0
  29. {pydocket-0.4.0 → pydocket-0.5.1}/src/docket/instrumentation.py +0 -0
  30. {pydocket-0.4.0 → pydocket-0.5.1}/src/docket/py.typed +0 -0
  31. {pydocket-0.4.0 → pydocket-0.5.1}/src/docket/tasks.py +0 -0
  32. {pydocket-0.4.0 → pydocket-0.5.1}/telemetry/.gitignore +0 -0
  33. {pydocket-0.4.0 → pydocket-0.5.1}/telemetry/start +0 -0
  34. {pydocket-0.4.0 → pydocket-0.5.1}/telemetry/stop +0 -0
  35. {pydocket-0.4.0 → pydocket-0.5.1}/tests/__init__.py +0 -0
  36. {pydocket-0.4.0 → pydocket-0.5.1}/tests/cli/__init__.py +0 -0
  37. {pydocket-0.4.0 → pydocket-0.5.1}/tests/cli/conftest.py +0 -0
  38. {pydocket-0.4.0 → pydocket-0.5.1}/tests/cli/test_module.py +0 -0
  39. {pydocket-0.4.0 → pydocket-0.5.1}/tests/cli/test_parsing.py +0 -0
  40. {pydocket-0.4.0 → pydocket-0.5.1}/tests/cli/test_snapshot.py +0 -0
  41. {pydocket-0.4.0 → pydocket-0.5.1}/tests/cli/test_striking.py +0 -0
  42. {pydocket-0.4.0 → pydocket-0.5.1}/tests/cli/test_tasks.py +0 -0
  43. {pydocket-0.4.0 → pydocket-0.5.1}/tests/cli/test_version.py +0 -0
  44. {pydocket-0.4.0 → pydocket-0.5.1}/tests/cli/test_worker.py +0 -0
  45. {pydocket-0.4.0 → pydocket-0.5.1}/tests/cli/test_workers.py +0 -0
  46. {pydocket-0.4.0 → pydocket-0.5.1}/tests/test_dependencies.py +0 -0
  47. {pydocket-0.4.0 → pydocket-0.5.1}/tests/test_docket.py +0 -0
  48. {pydocket-0.4.0 → pydocket-0.5.1}/tests/test_instrumentation.py +0 -0
  49. {pydocket-0.4.0 → pydocket-0.5.1}/tests/test_striking.py +0 -0
  50. {pydocket-0.4.0 → pydocket-0.5.1}/tests/test_worker.py +0 -0
  51. {pydocket-0.4.0 → pydocket-0.5.1}/uv.lock +0 -0
File: PKG-INFO
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: pydocket
3
- Version: 0.4.0
3
+ Version: 0.5.1
4
4
  Summary: A distributed background task system for Python functions
5
5
  Project-URL: Homepage, https://github.com/chrisguidry/docket
6
6
  Project-URL: Bug Tracker, https://github.com/chrisguidry/docket/issues
File: chaos/driver.py
@@ -76,9 +76,9 @@ async def run_redis(version: str) -> AsyncGenerator[tuple[str, Container], None]
76
76
 
77
77
  async def main(
78
78
  mode: Literal["performance", "chaos"] = "chaos",
79
- tasks: int = 5000,
80
- producers: int = 4,
81
- workers: int = 7,
79
+ tasks: int = 20000,
80
+ producers: int = 5,
81
+ workers: int = 10,
82
82
  ):
83
83
  async with (
84
84
  run_redis("7.4.2") as (redis_url, redis_container),
@@ -97,9 +97,7 @@ async def main(
97
97
  # Add in some random strikes to performance test
98
98
  for _ in range(100):
99
99
  parameter = f"param_{random.randint(1, 100)}"
100
- operator: Operator = random.choice(
101
- ["==", "!=", ">", ">=", "<", "<=", "between"]
102
- )
100
+ operator = random.choice(list(Operator))
103
101
  value = f"val_{random.randint(1, 1000)}"
104
102
  await docket.strike("rando", parameter, operator, value)
105
103
 
@@ -141,11 +139,9 @@ async def main(
141
139
  redis_url,
142
140
  "--tasks",
143
141
  "chaos.tasks:chaos_tasks",
144
- env=environment
145
- | {
146
- "OTEL_SERVICE_NAME": "chaos-worker",
147
- "DOCKET_WORKER_REDELIVERY_TIMEOUT": "5s",
148
- },
142
+ "--redelivery-timeout",
143
+ "5s",
144
+ env=environment | {"OTEL_SERVICE_NAME": "chaos-worker"},
149
145
  stdout=subprocess.DEVNULL,
150
146
  stderr=subprocess.DEVNULL,
151
147
  )
File: chaos/producer.py
@@ -1,8 +1,11 @@
1
1
  import asyncio
2
+ import datetime
2
3
  import logging
3
4
  import os
5
+ import random
4
6
  import sys
5
7
  import time
8
+ from datetime import timedelta
6
9
 
7
10
  import redis.exceptions
8
11
 
@@ -14,6 +17,10 @@ logging.getLogger().setLevel(logging.INFO)
14
17
  logger = logging.getLogger("chaos.producer")
15
18
 
16
19
 
20
+ def now() -> datetime.datetime:
21
+ return datetime.datetime.now(datetime.timezone.utc)
22
+
23
+
17
24
  async def main(tasks_to_produce: int):
18
25
  docket = Docket(
19
26
  name=os.environ["DOCKET_NAME"],
@@ -25,7 +32,9 @@ async def main(tasks_to_produce: int):
25
32
  async with docket:
26
33
  async with docket.redis() as r:
27
34
  for _ in range(tasks_sent, tasks_to_produce):
28
- execution = await docket.add(hello)()
35
+ jitter = 5 * ((random.random() * 2) - 1)
36
+ when = now() + timedelta(seconds=jitter)
37
+ execution = await docket.add(hello, when=when)()
29
38
  await r.zadd("hello:sent", {execution.key: time.time()})
30
39
  logger.info("Added task %s", execution.key)
31
40
  tasks_sent += 1
File: src/docket/cli.py
@@ -162,6 +162,7 @@ def worker(
162
162
  "This can be specified multiple times. A task collection is any "
163
163
  "iterable of async functions."
164
164
  ),
165
+ envvar="DOCKET_TASKS",
165
166
  ),
166
167
  ] = ["docket.tasks:standard_tasks"],
167
168
  docket_: Annotated[
@@ -236,6 +237,14 @@ def worker(
236
237
  envvar="DOCKET_WORKER_MINIMUM_CHECK_INTERVAL",
237
238
  ),
238
239
  ] = timedelta(milliseconds=100),
240
+ scheduling_resolution: Annotated[
241
+ timedelta,
242
+ typer.Option(
243
+ parser=duration,
244
+ help="How frequently to check for future tasks to be scheduled",
245
+ envvar="DOCKET_WORKER_SCHEDULING_RESOLUTION",
246
+ ),
247
+ ] = timedelta(milliseconds=250),
239
248
  until_finished: Annotated[
240
249
  bool,
241
250
  typer.Option(
@@ -260,6 +269,7 @@ def worker(
260
269
  redelivery_timeout=redelivery_timeout,
261
270
  reconnection_delay=reconnection_delay,
262
271
  minimum_check_interval=minimum_check_interval,
272
+ scheduling_resolution=scheduling_resolution,
263
273
  until_finished=until_finished,
264
274
  metrics_port=metrics_port,
265
275
  tasks=tasks,
@@ -542,6 +552,18 @@ def relative_time(now: datetime, when: datetime) -> str:
542
552
 
543
553
  @app.command(help="Shows a snapshot of what's on the docket right now")
544
554
  def snapshot(
555
+ tasks: Annotated[
556
+ list[str],
557
+ typer.Option(
558
+ "--tasks",
559
+ help=(
560
+ "The dotted path of a task collection to register with the docket. "
561
+ "This can be specified multiple times. A task collection is any "
562
+ "iterable of async functions."
563
+ ),
564
+ envvar="DOCKET_TASKS",
565
+ ),
566
+ ] = ["docket.tasks:standard_tasks"],
545
567
  docket_: Annotated[
546
568
  str,
547
569
  typer.Option(
@@ -560,6 +582,9 @@ def snapshot(
560
582
  ) -> None:
561
583
  async def run() -> DocketSnapshot:
562
584
  async with Docket(name=docket_, url=url) as docket:
585
+ for task_path in tasks:
586
+ docket.register_collection(task_path)
587
+
563
588
  return await docket.snapshot()
564
589
 
565
590
  snapshot = asyncio.run(run())
File: src/docket/dependencies.py
@@ -2,7 +2,7 @@ import abc
2
2
  import inspect
3
3
  import logging
4
4
  from datetime import timedelta
5
- from typing import Any, Awaitable, Callable, Counter, cast
5
+ from typing import Any, Awaitable, Callable, Counter, TypeVar, cast
6
6
 
7
7
  from .docket import Docket
8
8
  from .execution import Execution
@@ -130,12 +130,29 @@ class Perpetual(Dependency):
130
130
  single = True
131
131
 
132
132
  every: timedelta
133
+ automatic: bool
134
+
133
135
  args: tuple[Any, ...]
134
136
  kwargs: dict[str, Any]
137
+
135
138
  cancelled: bool
136
139
 
137
- def __init__(self, every: timedelta = timedelta(0)) -> None:
140
+ def __init__(
141
+ self,
142
+ every: timedelta = timedelta(0),
143
+ automatic: bool = False,
144
+ ) -> None:
145
+ """Declare a task that should be run perpetually.
146
+
147
+ Args:
148
+ every: The target interval between task executions.
149
+ automatic: If set, this task will be automatically scheduled during worker
150
+ startup and continually through the worker's lifespan. This ensures
151
+ that the task will always be scheduled despite crashes and other
152
+ adverse conditions. Automatic tasks must not require any arguments.
153
+ """
138
154
  self.every = every
155
+ self.automatic = automatic
139
156
  self.cancelled = False
140
157
 
141
158
  def __call__(
@@ -170,6 +187,29 @@ def get_dependency_parameters(
170
187
  return dependencies
171
188
 
172
189
 
190
+ D = TypeVar("D", bound=Dependency)
191
+
192
+
193
+ def get_single_dependency_parameter_of_type(
194
+ function: Callable[..., Awaitable[Any]], dependency_type: type[D]
195
+ ) -> D | None:
196
+ assert dependency_type.single, "Dependency must be single"
197
+ for _, dependency in get_dependency_parameters(function).items():
198
+ if isinstance(dependency, dependency_type):
199
+ return dependency
200
+ return None
201
+
202
+
203
+ def get_single_dependency_of_type(
204
+ dependencies: dict[str, Dependency], dependency_type: type[D]
205
+ ) -> D | None:
206
+ assert dependency_type.single, "Dependency must be single"
207
+ for _, dependency in dependencies.items():
208
+ if isinstance(dependency, dependency_type):
209
+ return dependency
210
+ return None
211
+
212
+
173
213
  def validate_dependencies(function: Callable[..., Awaitable[Any]]) -> None:
174
214
  parameters = get_dependency_parameters(function)
175
215
 
File: src/docket/worker.py
@@ -6,11 +6,9 @@ from datetime import datetime, timedelta, timezone
6
6
  from types import TracebackType
7
7
  from typing import (
8
8
  TYPE_CHECKING,
9
- Any,
10
9
  Mapping,
11
10
  Protocol,
12
11
  Self,
13
- TypeVar,
14
12
  cast,
15
13
  )
16
14
  from uuid import uuid4
@@ -18,6 +16,7 @@ from uuid import uuid4
18
16
  import redis.exceptions
19
17
  from opentelemetry import propagate, trace
20
18
  from opentelemetry.trace import Tracer
19
+ from redis.asyncio import Redis
21
20
 
22
21
  from .docket import (
23
22
  Docket,
@@ -52,8 +51,6 @@ tracer: Tracer = trace.get_tracer(__name__)
52
51
  if TYPE_CHECKING: # pragma: no cover
53
52
  from .dependencies import Dependency
54
53
 
55
- D = TypeVar("D", bound="Dependency")
56
-
57
54
 
58
55
  class _stream_due_tasks(Protocol):
59
56
  async def __call__(
@@ -68,6 +65,7 @@ class Worker:
68
65
  redelivery_timeout: timedelta
69
66
  reconnection_delay: timedelta
70
67
  minimum_check_interval: timedelta
68
+ scheduling_resolution: timedelta
71
69
 
72
70
  def __init__(
73
71
  self,
@@ -77,6 +75,7 @@ class Worker:
77
75
  redelivery_timeout: timedelta = timedelta(minutes=5),
78
76
  reconnection_delay: timedelta = timedelta(seconds=5),
79
77
  minimum_check_interval: timedelta = timedelta(milliseconds=100),
78
+ scheduling_resolution: timedelta = timedelta(milliseconds=250),
80
79
  ) -> None:
81
80
  self.docket = docket
82
81
  self.name = name or f"worker:{uuid4()}"
@@ -84,6 +83,7 @@ class Worker:
84
83
  self.redelivery_timeout = redelivery_timeout
85
84
  self.reconnection_delay = reconnection_delay
86
85
  self.minimum_check_interval = minimum_check_interval
86
+ self.scheduling_resolution = scheduling_resolution
87
87
 
88
88
  async def __aenter__(self) -> Self:
89
89
  self._heartbeat_task = asyncio.create_task(self._heartbeat())
@@ -128,6 +128,7 @@ class Worker:
128
128
  redelivery_timeout: timedelta = timedelta(minutes=5),
129
129
  reconnection_delay: timedelta = timedelta(seconds=5),
130
130
  minimum_check_interval: timedelta = timedelta(milliseconds=100),
131
+ scheduling_resolution: timedelta = timedelta(milliseconds=250),
131
132
  until_finished: bool = False,
132
133
  metrics_port: int | None = None,
133
134
  tasks: list[str] = ["docket.tasks:standard_tasks"],
@@ -144,6 +145,7 @@ class Worker:
144
145
  redelivery_timeout=redelivery_timeout,
145
146
  reconnection_delay=reconnection_delay,
146
147
  minimum_check_interval=minimum_check_interval,
148
+ scheduling_resolution=scheduling_resolution,
147
149
  ) as worker:
148
150
  if until_finished:
149
151
  await worker.run_until_finished()
@@ -210,57 +212,28 @@ class Worker:
210
212
  await asyncio.sleep(self.reconnection_delay.total_seconds())
211
213
 
212
214
  async def _worker_loop(self, forever: bool = False):
213
- async with self.docket.redis() as redis:
214
- stream_due_tasks: _stream_due_tasks = cast(
215
- _stream_due_tasks,
216
- redis.register_script(
217
- # Lua script to atomically move scheduled tasks to the stream
218
- # KEYS[1]: queue key (sorted set)
219
- # KEYS[2]: stream key
220
- # ARGV[1]: current timestamp
221
- # ARGV[2]: docket name prefix
222
- """
223
- local total_work = redis.call('ZCARD', KEYS[1])
224
- local due_work = 0
225
-
226
- if total_work > 0 then
227
- local tasks = redis.call('ZRANGEBYSCORE', KEYS[1], 0, ARGV[1])
228
-
229
- for i, key in ipairs(tasks) do
230
- local hash_key = ARGV[2] .. ":" .. key
231
- local task_data = redis.call('HGETALL', hash_key)
232
-
233
- if #task_data > 0 then
234
- local task = {}
235
- for j = 1, #task_data, 2 do
236
- task[task_data[j]] = task_data[j+1]
237
- end
238
-
239
- redis.call('XADD', KEYS[2], '*',
240
- 'key', task['key'],
241
- 'when', task['when'],
242
- 'function', task['function'],
243
- 'args', task['args'],
244
- 'kwargs', task['kwargs'],
245
- 'attempt', task['attempt']
246
- )
247
- redis.call('DEL', hash_key)
248
- due_work = due_work + 1
249
- end
250
- end
251
- end
215
+ worker_stopping = asyncio.Event()
252
216
 
253
- if due_work > 0 then
254
- redis.call('ZREMRANGEBYSCORE', KEYS[1], 0, ARGV[1])
255
- end
217
+ await self._schedule_all_automatic_perpetual_tasks()
218
+ perpetual_scheduling_task = asyncio.create_task(
219
+ self._perpetual_scheduling_loop(worker_stopping)
220
+ )
256
221
 
257
- return {total_work, due_work}
258
- """
259
- ),
222
+ async with self.docket.redis() as redis:
223
+ scheduler_task = asyncio.create_task(
224
+ self._scheduler_loop(redis, worker_stopping)
260
225
  )
261
-
262
226
  active_tasks: dict[asyncio.Task[None], RedisMessageID] = {}
263
227
 
228
+ async def check_for_work() -> bool:
229
+ async with redis.pipeline() as pipeline:
230
+ pipeline.xlen(self.docket.stream_key)
231
+ pipeline.zcard(self.docket.queue_key)
232
+ results: list[int] = await pipeline.execute()
233
+ stream_len = results[0]
234
+ queue_len = results[1]
235
+ return stream_len > 0 or queue_len > 0
236
+
264
237
  async def process_completed_tasks() -> None:
265
238
  completed_tasks = {task for task in active_tasks if task.done()}
266
239
  for task in completed_tasks:
@@ -280,10 +253,13 @@ class Worker:
280
253
  )
281
254
  await pipeline.execute()
282
255
 
283
- future_work, due_work = sys.maxsize, 0
256
+ has_work: bool = True
257
+
258
+ if not forever: # pragma: no branch
259
+ has_work = await check_for_work()
284
260
 
285
261
  try:
286
- while forever or future_work or active_tasks:
262
+ while forever or has_work or active_tasks:
287
263
  await process_completed_tasks()
288
264
 
289
265
  available_slots = self.concurrency - len(active_tasks)
@@ -297,28 +273,13 @@ class Worker:
297
273
  task = asyncio.create_task(self._execute(message))
298
274
  active_tasks[task] = message_id
299
275
 
300
- nonlocal available_slots, future_work
276
+ nonlocal available_slots
301
277
  available_slots -= 1
302
- future_work += 1
303
278
 
304
279
  if available_slots <= 0:
305
280
  await asyncio.sleep(self.minimum_check_interval.total_seconds())
306
281
  continue
307
282
 
308
- future_work, due_work = await stream_due_tasks(
309
- keys=[self.docket.queue_key, self.docket.stream_key],
310
- args=[datetime.now(timezone.utc).timestamp(), self.docket.name],
311
- )
312
- if due_work > 0:
313
- logger.debug(
314
- "Moved %d/%d due tasks from %s to %s",
315
- due_work,
316
- future_work,
317
- self.docket.queue_key,
318
- self.docket.stream_key,
319
- extra=self._log_context(),
320
- )
321
-
322
283
  redeliveries: RedisMessages
323
284
  _, redeliveries, *_ = await redis.xautoclaim(
324
285
  name=self.docket.stream_key,
@@ -348,10 +309,14 @@ class Worker:
348
309
  ),
349
310
  count=available_slots,
350
311
  )
312
+
351
313
  for _, messages in new_deliveries:
352
314
  for message_id, message in messages:
353
315
  start_task(message_id, message)
354
316
 
317
+ if not forever and not active_tasks and not new_deliveries:
318
+ has_work = await check_for_work()
319
+
355
320
  except asyncio.CancelledError:
356
321
  if active_tasks: # pragma: no cover
357
322
  logger.info(
@@ -364,7 +329,142 @@ class Worker:
364
329
  await asyncio.gather(*active_tasks, return_exceptions=True)
365
330
  await process_completed_tasks()
366
331
 
332
+ worker_stopping.set()
333
+ await scheduler_task
334
+ await perpetual_scheduling_task
335
+
336
+ async def _scheduler_loop(
337
+ self,
338
+ redis: Redis,
339
+ worker_stopping: asyncio.Event,
340
+ ) -> None:
341
+ """Loop that moves due tasks from the queue to the stream."""
342
+
343
+ stream_due_tasks: _stream_due_tasks = cast(
344
+ _stream_due_tasks,
345
+ redis.register_script(
346
+ # Lua script to atomically move scheduled tasks to the stream
347
+ # KEYS[1]: queue key (sorted set)
348
+ # KEYS[2]: stream key
349
+ # ARGV[1]: current timestamp
350
+ # ARGV[2]: docket name prefix
351
+ """
352
+ local total_work = redis.call('ZCARD', KEYS[1])
353
+ local due_work = 0
354
+
355
+ if total_work > 0 then
356
+ local tasks = redis.call('ZRANGEBYSCORE', KEYS[1], 0, ARGV[1])
357
+
358
+ for i, key in ipairs(tasks) do
359
+ local hash_key = ARGV[2] .. ":" .. key
360
+ local task_data = redis.call('HGETALL', hash_key)
361
+
362
+ if #task_data > 0 then
363
+ local task = {}
364
+ for j = 1, #task_data, 2 do
365
+ task[task_data[j]] = task_data[j+1]
366
+ end
367
+
368
+ redis.call('XADD', KEYS[2], '*',
369
+ 'key', task['key'],
370
+ 'when', task['when'],
371
+ 'function', task['function'],
372
+ 'args', task['args'],
373
+ 'kwargs', task['kwargs'],
374
+ 'attempt', task['attempt']
375
+ )
376
+ redis.call('DEL', hash_key)
377
+ due_work = due_work + 1
378
+ end
379
+ end
380
+ end
381
+
382
+ if due_work > 0 then
383
+ redis.call('ZREMRANGEBYSCORE', KEYS[1], 0, ARGV[1])
384
+ end
385
+
386
+ return {total_work, due_work}
387
+ """
388
+ ),
389
+ )
390
+
391
+ total_work: int = sys.maxsize
392
+
393
+ while not worker_stopping.is_set() or total_work:
394
+ try:
395
+ total_work, due_work = await stream_due_tasks(
396
+ keys=[self.docket.queue_key, self.docket.stream_key],
397
+ args=[datetime.now(timezone.utc).timestamp(), self.docket.name],
398
+ )
399
+
400
+ if due_work > 0:
401
+ logger.debug(
402
+ "Moved %d/%d due tasks from %s to %s",
403
+ due_work,
404
+ total_work,
405
+ self.docket.queue_key,
406
+ self.docket.stream_key,
407
+ extra=self._log_context(),
408
+ )
409
+ except Exception: # pragma: no cover
410
+ logger.exception(
411
+ "Error in scheduler loop",
412
+ exc_info=True,
413
+ extra=self._log_context(),
414
+ )
415
+ finally:
416
+ await asyncio.sleep(self.scheduling_resolution.total_seconds())
417
+
418
+ logger.debug("Scheduler loop finished", extra=self._log_context())
419
+
420
+ async def _perpetual_scheduling_loop(self, worker_stopping: asyncio.Event) -> None:
421
+ """Loop that ensures that automatic perpetual tasks are always scheduled."""
422
+
423
+ while not worker_stopping.is_set():
424
+ minimum_interval = self.scheduling_resolution
425
+ try:
426
+ minimum_interval = await self._schedule_all_automatic_perpetual_tasks()
427
+ except Exception: # pragma: no cover
428
+ logger.exception(
429
+ "Error in perpetual scheduling loop",
430
+ exc_info=True,
431
+ extra=self._log_context(),
432
+ )
433
+ finally:
434
+ # Wait until just before the next time any task would need to be
435
+ # scheduled (one scheduling_resolution before the lowest interval)
436
+ interval = max(
437
+ minimum_interval - self.scheduling_resolution,
438
+ self.scheduling_resolution,
439
+ )
440
+ assert interval <= self.scheduling_resolution
441
+ await asyncio.sleep(interval.total_seconds())
442
+
443
+ async def _schedule_all_automatic_perpetual_tasks(self) -> timedelta:
444
+ from .dependencies import Perpetual, get_single_dependency_parameter_of_type
445
+
446
+ minimum_interval = self.scheduling_resolution
447
+ for task_function in self.docket.tasks.values():
448
+ perpetual = get_single_dependency_parameter_of_type(
449
+ task_function, Perpetual
450
+ )
451
+ if perpetual is None:
452
+ continue
453
+
454
+ if not perpetual.automatic:
455
+ continue
456
+
457
+ key = task_function.__name__
458
+ await self.docket.add(task_function, key=key)()
459
+ minimum_interval = min(minimum_interval, perpetual.every)
460
+
461
+ return minimum_interval
462
+
367
463
  async def _execute(self, message: RedisMessage) -> None:
464
+ key = message[b"key"].decode()
465
+ async with self.docket.redis() as redis:
466
+ await redis.delete(self.docket.known_task_key(key))
467
+
368
468
  log_context: Mapping[str, str | float] = self._log_context()
369
469
 
370
470
  function_name = message[b"function"].decode()
@@ -377,9 +477,6 @@ class Worker:
377
477
 
378
478
  execution = Execution.from_message(function, message)
379
479
 
380
- async with self.docket.redis() as redis:
381
- await redis.delete(self.docket.known_task_key(execution.key))
382
-
383
480
  log_context = {**log_context, **execution.specific_labels()}
384
481
  counter_labels = {**self.labels(), **execution.general_labels()}
385
482
 
@@ -458,12 +555,12 @@ class Worker:
458
555
  def _get_dependencies(
459
556
  self,
460
557
  execution: Execution,
461
- ) -> dict[str, Any]:
558
+ ) -> dict[str, "Dependency"]:
462
559
  from .dependencies import get_dependency_parameters
463
560
 
464
561
  parameters = get_dependency_parameters(execution.function)
465
562
 
466
- dependencies: dict[str, Any] = {}
563
+ dependencies: dict[str, "Dependency"] = {}
467
564
 
468
565
  for parameter_name, dependency in parameters.items():
469
566
  # If the argument is already provided, skip it, which allows users to call
@@ -479,16 +576,14 @@ class Worker:
479
576
  async def _retry_if_requested(
480
577
  self,
481
578
  execution: Execution,
482
- dependencies: dict[str, Any],
579
+ dependencies: dict[str, "Dependency"],
483
580
  ) -> bool:
484
- from .dependencies import Retry
581
+ from .dependencies import Retry, get_single_dependency_of_type
485
582
 
486
- retries = [retry for retry in dependencies.values() if isinstance(retry, Retry)]
487
- if not retries:
583
+ retry = get_single_dependency_of_type(dependencies, Retry)
584
+ if not retry:
488
585
  return False
489
586
 
490
- retry = retries[0]
491
-
492
587
  if retry.attempts is None or execution.attempt < retry.attempts:
493
588
  execution.when = datetime.now(timezone.utc) + retry.delay
494
589
  execution.attempt += 1
@@ -500,19 +595,16 @@ class Worker:
500
595
  return False
501
596
 
502
597
  async def _perpetuate_if_requested(
503
- self, execution: Execution, dependencies: dict[str, Any], duration: timedelta
598
+ self,
599
+ execution: Execution,
600
+ dependencies: dict[str, "Dependency"],
601
+ duration: timedelta,
504
602
  ) -> bool:
505
- from .dependencies import Perpetual
506
-
507
- perpetuals = [
508
- perpetual
509
- for perpetual in dependencies.values()
510
- if isinstance(perpetual, Perpetual)
511
- ]
512
- if not perpetuals:
513
- return False
603
+ from .dependencies import Perpetual, get_single_dependency_of_type
514
604
 
515
- perpetual = perpetuals[0]
605
+ perpetual = get_single_dependency_of_type(dependencies, Perpetual)
606
+ if not perpetual:
607
+ return False
516
608
 
517
609
  if perpetual.cancelled:
518
610
  return False
@@ -572,11 +664,10 @@ class Worker:
572
664
  pipeline.zcount(self.docket.queue_key, 0, now)
573
665
  pipeline.zcount(self.docket.queue_key, now, "+inf")
574
666
 
575
- (
576
- stream_depth,
577
- overdue_depth,
578
- schedule_depth,
579
- ) = await pipeline.execute()
667
+ results: list[int] = await pipeline.execute()
668
+ stream_depth = results[0]
669
+ overdue_depth = results[1]
670
+ schedule_depth = results[2]
580
671
 
581
672
  QUEUE_DEPTH.set(
582
673
  stream_depth + overdue_depth, self.docket.labels()
File: tests/conftest.py
@@ -155,7 +155,9 @@ async def docket(redis_url: str, aiolib: str) -> AsyncGenerator[Docket, None]:
155
155
  @pytest.fixture
156
156
  async def worker(docket: Docket) -> AsyncGenerator[Worker, None]:
157
157
  async with Worker(
158
- docket, minimum_check_interval=timedelta(milliseconds=10)
158
+ docket,
159
+ minimum_check_interval=timedelta(milliseconds=10),
160
+ scheduling_resolution=timedelta(milliseconds=10),
159
161
  ) as worker:
160
162
  yield worker
161
163
 
File: tests/test_fundamentals.py
@@ -1037,3 +1037,31 @@ async def test_perpetual_tasks_perpetuate_even_after_errors(
1037
1037
  await worker.run_at_most({execution.key: 3})
1038
1038
 
1039
1039
  assert calls == 3
1040
+
1041
+
1042
+ async def test_perpetual_tasks_can_be_automatically_scheduled(
1043
+ docket: Docket, worker: Worker
1044
+ ):
1045
+ """Perpetual tasks can be automatically scheduled"""
1046
+
1047
+ calls = 0
1048
+
1049
+ async def my_automatic_task(
1050
+ perpetual: Perpetual = Perpetual(
1051
+ every=timedelta(milliseconds=50), automatic=True
1052
+ ),
1053
+ ):
1054
+ assert isinstance(perpetual, Perpetual)
1055
+
1056
+ assert perpetual.every == timedelta(milliseconds=50)
1057
+
1058
+ nonlocal calls
1059
+ calls += 1
1060
+
1061
+ # Note we never add this task to the docket, we just register it.
1062
+ docket.register(my_automatic_task)
1063
+
1064
+ # The automatic key will be the task function's name
1065
+ await worker.run_at_most({"my_automatic_task": 3})
1066
+
1067
+ assert calls == 3
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes