pydocket-0.15.3-py3-none-any.whl

This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
docket/worker.py ADDED
@@ -0,0 +1,1071 @@
1
+ import asyncio
2
+ import base64
3
+ import logging
4
+ import os
5
+ import signal
6
+ import socket
7
+ import sys
8
+ import time
9
+ from datetime import datetime, timedelta, timezone
10
+ from types import TracebackType
11
+ from typing import Any, Coroutine, Mapping, Protocol, cast
12
+
13
+ import cloudpickle # type: ignore[import]
14
+
15
+ if sys.version_info < (3, 11): # pragma: no cover
16
+ from exceptiongroup import ExceptionGroup
17
+
18
+ from opentelemetry import trace
19
+ from opentelemetry.trace import Status, StatusCode, Tracer
20
+ from redis.asyncio import Redis
21
+ from redis.exceptions import ConnectionError, LockError, ResponseError
22
+ from typing_extensions import Self
23
+
24
+ from .dependencies import (
25
+ ConcurrencyLimit,
26
+ Dependency,
27
+ FailedDependency,
28
+ Perpetual,
29
+ Retry,
30
+ Timeout,
31
+ get_single_dependency_of_type,
32
+ get_single_dependency_parameter_of_type,
33
+ resolved_dependencies,
34
+ )
35
+ from .docket import (
36
+ Docket,
37
+ Execution,
38
+ RedisMessage,
39
+ RedisMessageID,
40
+ RedisReadGroupResponse,
41
+ )
42
+ from .execution import compact_signature, get_signature
43
+
44
+ # Run class has been consolidated into Execution
45
+ from .instrumentation import (
46
+ QUEUE_DEPTH,
47
+ REDIS_DISRUPTIONS,
48
+ SCHEDULE_DEPTH,
49
+ TASK_DURATION,
50
+ TASK_PUNCTUALITY,
51
+ TASKS_COMPLETED,
52
+ TASKS_FAILED,
53
+ TASKS_PERPETUATED,
54
+ TASKS_REDELIVERED,
55
+ TASKS_RETRIED,
56
+ TASKS_RUNNING,
57
+ TASKS_STARTED,
58
+ TASKS_STRICKEN,
59
+ TASKS_SUCCEEDED,
60
+ healthcheck_server,
61
+ metrics_server,
62
+ )
63
+
64
+ # Delay before retrying a task blocked by concurrency limits
65
+ # Must be larger than redelivery_timeout to ensure atomic reschedule+ACK completes
66
+ # before Redis would consider redelivering the message
67
+ CONCURRENCY_BLOCKED_RETRY_DELAY = timedelta(milliseconds=100)
68
+
69
+
70
+ class ConcurrencyBlocked(Exception):
71
+ """Raised when a task cannot start due to concurrency limits."""
72
+
73
+ def __init__(self, execution: Execution):
74
+ self.execution = execution
75
+ super().__init__(f"Task {execution.key} blocked by concurrency limits")
76
+
77
+
78
+ logger: logging.Logger = logging.getLogger(__name__)
79
+ tracer: Tracer = trace.get_tracer(__name__)
80
+
81
+
82
+ class _stream_due_tasks(Protocol):
83
+ async def __call__(
84
+ self, keys: list[str], args: list[str | float]
85
+ ) -> tuple[int, int]: ... # pragma: no cover
86
+
87
+
88
+ class Worker:
89
+ """A Worker executes tasks on a Docket. You may run as many workers as you like
90
+ to work a single Docket.
91
+
92
+ Example:
93
+
94
+ ```python
95
+ async with Docket() as docket:
96
+ async with Worker(docket) as worker:
97
+ await worker.run_forever()
98
+ ```
99
+ """
100
+
101
+ docket: Docket
102
+ name: str
103
+ concurrency: int
104
+ redelivery_timeout: timedelta
105
+ reconnection_delay: timedelta
106
+ minimum_check_interval: timedelta
107
+ scheduling_resolution: timedelta
108
+ schedule_automatic_tasks: bool
109
+
110
+ def __init__(
111
+ self,
112
+ docket: Docket,
113
+ name: str | None = None,
114
+ concurrency: int = 10,
115
+ redelivery_timeout: timedelta = timedelta(minutes=5),
116
+ reconnection_delay: timedelta = timedelta(seconds=5),
117
+ minimum_check_interval: timedelta = timedelta(milliseconds=250),
118
+ scheduling_resolution: timedelta = timedelta(milliseconds=250),
119
+ schedule_automatic_tasks: bool = True,
120
+ ) -> None:
121
+ self.docket = docket
122
+ self.name = name or f"{socket.gethostname()}#{os.getpid()}"
123
+ self.concurrency = concurrency
124
+ self.redelivery_timeout = redelivery_timeout
125
+ self.reconnection_delay = reconnection_delay
126
+ self.minimum_check_interval = minimum_check_interval
127
+ self.scheduling_resolution = scheduling_resolution
128
+ self.schedule_automatic_tasks = schedule_automatic_tasks
129
+
130
+ async def __aenter__(self) -> Self:
131
+ self._heartbeat_task = asyncio.create_task(self._heartbeat())
132
+ self._execution_counts = {}
133
+ return self
134
+
135
+ async def __aexit__(
136
+ self,
137
+ exc_type: type[BaseException] | None,
138
+ exc_value: BaseException | None,
139
+ traceback: TracebackType | None,
140
+ ) -> None:
141
+ del self._execution_counts
142
+
143
+ self._heartbeat_task.cancel()
144
+ try:
145
+ await self._heartbeat_task
146
+ except asyncio.CancelledError:
147
+ pass
148
+ del self._heartbeat_task
149
+
150
+ def labels(self) -> Mapping[str, str]:
151
+ return {
152
+ **self.docket.labels(),
153
+ "docket.worker": self.name,
154
+ }
155
+
156
+ def _log_context(self) -> Mapping[str, str]:
157
+ return {
158
+ **self.labels(),
159
+ "docket.queue_key": self.docket.queue_key,
160
+ "docket.stream_key": self.docket.stream_key,
161
+ }
162
+
163
+ @classmethod
164
+ async def run(
165
+ cls,
166
+ docket_name: str = "docket",
167
+ url: str = "redis://localhost:6379/0",
168
+ name: str | None = None,
169
+ concurrency: int = 10,
170
+ redelivery_timeout: timedelta = timedelta(minutes=5),
171
+ reconnection_delay: timedelta = timedelta(seconds=5),
172
+ minimum_check_interval: timedelta = timedelta(milliseconds=100),
173
+ scheduling_resolution: timedelta = timedelta(milliseconds=250),
174
+ schedule_automatic_tasks: bool = True,
175
+ until_finished: bool = False,
176
+ healthcheck_port: int | None = None,
177
+ metrics_port: int | None = None,
178
+ tasks: list[str] = ["docket.tasks:standard_tasks"],
179
+ ) -> None:
180
+ """Run a worker as the main entry point (CLI).
181
+
182
+ This method installs signal handlers for graceful shutdown since it
183
+ assumes ownership of the event loop. When embedding Docket in another
184
+ framework (e.g., FastAPI with uvicorn), use Worker.run_forever() or
185
+ Worker.run_until_finished() directly - those methods do not install
186
+ signal handlers and rely on the framework to handle shutdown signals.
187
+ """
188
+ with (
189
+ healthcheck_server(port=healthcheck_port),
190
+ metrics_server(port=metrics_port),
191
+ ):
192
+ async with Docket(name=docket_name, url=url) as docket:
193
+ for task_path in tasks:
194
+ docket.register_collection(task_path)
195
+
196
+ async with (
197
+ Worker( # pragma: no branch - context manager exit varies across interpreters
198
+ docket=docket,
199
+ name=name,
200
+ concurrency=concurrency,
201
+ redelivery_timeout=redelivery_timeout,
202
+ reconnection_delay=reconnection_delay,
203
+ minimum_check_interval=minimum_check_interval,
204
+ scheduling_resolution=scheduling_resolution,
205
+ schedule_automatic_tasks=schedule_automatic_tasks,
206
+ ) as worker
207
+ ):
208
+ # Install signal handlers for graceful shutdown.
209
+ # This is only appropriate when we own the event loop (CLI entry point).
210
+ # Embedded usage should let the framework handle signals.
211
+ loop = asyncio.get_running_loop()
212
+ run_task: asyncio.Task[None] | None = None
213
+
214
+ def handle_shutdown(sig_name: str) -> None: # pragma: no cover
215
+ logger.info(
216
+ "Received %s, initiating graceful shutdown...", sig_name
217
+ )
218
+ if run_task and not run_task.done():
219
+ run_task.cancel()
220
+
221
+ if hasattr(signal, "SIGTERM"): # pragma: no cover
222
+ loop.add_signal_handler(
223
+ signal.SIGTERM, lambda: handle_shutdown("SIGTERM")
224
+ )
225
+ loop.add_signal_handler(
226
+ signal.SIGINT, lambda: handle_shutdown("SIGINT")
227
+ )
228
+
229
+ try:
230
+ if until_finished:
231
+ run_task = asyncio.create_task(worker.run_until_finished())
232
+ else:
233
+ run_task = asyncio.create_task(
234
+ worker.run_forever()
235
+ ) # pragma: no cover
236
+ await run_task
237
+ except asyncio.CancelledError: # pragma: no cover
238
+ pass
239
+ finally:
240
+ if hasattr(signal, "SIGTERM"): # pragma: no cover
241
+ loop.remove_signal_handler(signal.SIGTERM)
242
+ loop.remove_signal_handler(signal.SIGINT)
243
+
244
+ async def run_until_finished(self) -> None:
245
+ """Run the worker until there are no more tasks to process."""
246
+ return await self._run(forever=False)
247
+
248
+ async def run_forever(self) -> None:
249
+ """Run the worker indefinitely."""
250
+ return await self._run(forever=True) # pragma: no cover
251
+
252
+ _execution_counts: dict[str, int]
253
+
254
+ async def run_at_most(self, iterations_by_key: Mapping[str, int]) -> None:
255
+ """
256
+ Run the worker until there are no more tasks to process, but limit specified
257
+ task keys to a maximum number of iterations.
258
+
259
+ This is particularly useful for testing self-perpetuating tasks that would
260
+ otherwise run indefinitely.
261
+
262
+ Args:
263
+ iterations_by_key: Maps task keys to their maximum allowed executions
264
+ """
265
+ self._execution_counts = {key: 0 for key in iterations_by_key}
266
+
267
+ def has_reached_max_iterations(execution: Execution) -> bool:
268
+ key = execution.key
269
+
270
+ if key not in iterations_by_key:
271
+ return False
272
+
273
+ if self._execution_counts[key] >= iterations_by_key[key]:
274
+ return True
275
+
276
+ return False
277
+
278
+ self.docket.strike_list.add_condition(has_reached_max_iterations)
279
+ try:
280
+ await self.run_until_finished()
281
+ finally:
282
+ self.docket.strike_list.remove_condition(has_reached_max_iterations)
283
+ self._execution_counts = {}
284
+
285
+ async def _run(self, forever: bool = False) -> None:
286
+ self._startup_log()
287
+
288
+ while True:
289
+ try:
290
+ async with self.docket.redis() as redis:
291
+ return await self._worker_loop(redis, forever=forever)
292
+ except ConnectionError:
293
+ REDIS_DISRUPTIONS.add(1, self.labels())
294
+ logger.warning(
295
+ "Error connecting to redis, retrying in %s...",
296
+ self.reconnection_delay,
297
+ exc_info=True,
298
+ )
299
+ await asyncio.sleep(self.reconnection_delay.total_seconds())
300
+
301
+ async def _worker_loop(self, redis: Redis, forever: bool = False):
302
+ worker_stopping = asyncio.Event()
303
+
304
+ if self.schedule_automatic_tasks:
305
+ await self._schedule_all_automatic_perpetual_tasks()
306
+
307
+ scheduler_task = asyncio.create_task(
308
+ self._scheduler_loop(redis, worker_stopping)
309
+ )
310
+
311
+ active_tasks: dict[asyncio.Task[None], RedisMessageID] = {}
312
+ task_executions: dict[asyncio.Task[None], Execution] = {}
313
+ available_slots = self.concurrency
314
+
315
+ log_context = self._log_context()
316
+
317
+ async def check_for_work() -> bool:
318
+ logger.debug("Checking for work", extra=log_context)
319
+ async with redis.pipeline() as pipeline:
320
+ pipeline.xlen(self.docket.stream_key)
321
+ pipeline.zcard(self.docket.queue_key)
322
+ results: list[int] = await pipeline.execute()
323
+ stream_len = results[0]
324
+ queue_len = results[1]
325
+ return stream_len > 0 or queue_len > 0
326
+
327
+ async def get_redeliveries(redis: Redis) -> RedisReadGroupResponse:
328
+ logger.debug("Getting redeliveries", extra=log_context)
329
+ try:
330
+ _, redeliveries, *_ = await redis.xautoclaim(
331
+ name=self.docket.stream_key,
332
+ groupname=self.docket.worker_group_name,
333
+ consumername=self.name,
334
+ min_idle_time=int(self.redelivery_timeout.total_seconds() * 1000),
335
+ start_id="0-0",
336
+ count=available_slots,
337
+ )
338
+ except ResponseError as e:
339
+ if "NOGROUP" in str(e):
340
+ await self.docket._ensure_stream_and_group()
341
+ return await get_redeliveries(redis)
342
+ raise # pragma: no cover
343
+ return [(b"__redelivery__", redeliveries)]
344
+
345
+ async def get_new_deliveries(redis: Redis) -> RedisReadGroupResponse:
346
+ logger.debug("Getting new deliveries", extra=log_context)
347
+ # Use non-blocking read with in-memory backend + manual sleep
348
+ # This is necessary because fakeredis's async blocking operations don't
349
+ # properly yield control to the asyncio event loop
350
+ is_memory = self.docket.url.startswith("memory://")
351
+ try:
352
+ result = await redis.xreadgroup(
353
+ groupname=self.docket.worker_group_name,
354
+ consumername=self.name,
355
+ streams={self.docket.stream_key: ">"},
356
+ block=0
357
+ if is_memory
358
+ else int(self.minimum_check_interval.total_seconds() * 1000),
359
+ count=available_slots,
360
+ )
361
+ except ResponseError as e:
362
+ if "NOGROUP" in str(e):
363
+ await self.docket._ensure_stream_and_group()
364
+ return await get_new_deliveries(redis)
365
+ raise # pragma: no cover
366
+ if is_memory and not result:
367
+ await asyncio.sleep(self.minimum_check_interval.total_seconds())
368
+ return result
369
+
370
+ async def start_task(
371
+ message_id: RedisMessageID,
372
+ message: RedisMessage,
373
+ is_redelivery: bool = False,
374
+ ) -> bool:
375
+ try:
376
+ execution = await Execution.from_message(
377
+ self.docket, message, redelivered=is_redelivery
378
+ )
379
+ except ValueError as e:
380
+ logger.error(
381
+ "Unable to start task: %s",
382
+ e,
383
+ extra=log_context,
384
+ )
385
+ return False
386
+
387
+ task = asyncio.create_task(self._execute(execution), name=execution.key)
388
+ active_tasks[task] = message_id
389
+ task_executions[task] = execution
390
+
391
+ nonlocal available_slots
392
+ available_slots -= 1
393
+
394
+ return True
395
+
396
+ async def process_completed_tasks() -> None:
397
+ completed_tasks = {task for task in active_tasks if task.done()}
398
+ for task in completed_tasks:
399
+ message_id = active_tasks.pop(task)
400
+ task_executions.pop(task)
401
+ try:
402
+ await task
403
+ # Task succeeded - acknowledge the message
404
+ await ack_message(redis, message_id)
405
+ except ConcurrencyBlocked as e:
406
+ # Task was blocked by concurrency limits, reschedule atomically
407
+ logger.debug(
408
+ "🔒 Task %s blocked by concurrency limit, rescheduling",
409
+ e.execution.key,
410
+ extra=log_context,
411
+ )
412
+ # Use atomic schedule(reschedule_message=...) to prevent both task loss and duplicate execution
413
+ e.execution.when = (
414
+ datetime.now(timezone.utc) + CONCURRENCY_BLOCKED_RETRY_DELAY
415
+ )
416
+ await e.execution.schedule(reschedule_message=message_id)
417
+
418
+ async def ack_message(redis: Redis, message_id: RedisMessageID) -> None:
419
+ logger.debug("Acknowledging message", extra=log_context)
420
+ async with redis.pipeline() as pipeline:
421
+ pipeline.xack(
422
+ self.docket.stream_key,
423
+ self.docket.worker_group_name,
424
+ message_id,
425
+ )
426
+ pipeline.xdel(
427
+ self.docket.stream_key,
428
+ message_id,
429
+ )
430
+ await pipeline.execute()
431
+
432
+ has_work: bool = True
433
+
434
+ try:
435
+ while forever or has_work or active_tasks:
436
+ await process_completed_tasks()
437
+
438
+ available_slots = self.concurrency - len(active_tasks)
439
+
440
+ if available_slots <= 0:
441
+ await asyncio.sleep(self.minimum_check_interval.total_seconds())
442
+ continue
443
+
444
+ for source in [get_redeliveries, get_new_deliveries]:
445
+ for stream_key, messages in await source(redis):
446
+ is_redelivery = stream_key == b"__redelivery__"
447
+ for message_id, message in messages:
448
+ if not message: # pragma: no cover
449
+ continue
450
+
451
+ task_started = await start_task(
452
+ message_id, message, is_redelivery
453
+ )
454
+ if not task_started:
455
+ await self._delete_known_task(redis, message)
456
+ await ack_message(redis, message_id)
457
+
458
+ if available_slots <= 0:
459
+ break
460
+
461
+ if not forever and not active_tasks:
462
+ has_work = await check_for_work()
463
+
464
+ except asyncio.CancelledError:
465
+ if active_tasks: # pragma: no cover
466
+ logger.info(
467
+ "Shutdown requested, finishing %d active tasks...",
468
+ len(active_tasks),
469
+ extra=log_context,
470
+ )
471
+ finally:
472
+ if active_tasks:
473
+ await asyncio.gather(*active_tasks, return_exceptions=True)
474
+ await process_completed_tasks()
475
+
476
+ worker_stopping.set()
477
+ await scheduler_task
478
+
479
+ async def _scheduler_loop(
480
+ self,
481
+ redis: Redis,
482
+ worker_stopping: asyncio.Event,
483
+ ) -> None:
484
+ """Loop that moves due tasks from the queue to the stream."""
485
+
486
+ stream_due_tasks: _stream_due_tasks = cast(
487
+ _stream_due_tasks,
488
+ redis.register_script(
489
+ # Lua script to atomically move scheduled tasks to the stream
490
+ # KEYS[1]: queue key (sorted set)
491
+ # KEYS[2]: stream key
492
+ # ARGV[1]: current timestamp
493
+ # ARGV[2]: docket name prefix
494
+ """
495
+ local total_work = redis.call('ZCARD', KEYS[1])
496
+ local due_work = 0
497
+
498
+ if total_work > 0 then
499
+ local tasks = redis.call('ZRANGEBYSCORE', KEYS[1], 0, ARGV[1])
500
+
501
+ for i, key in ipairs(tasks) do
502
+ local hash_key = ARGV[2] .. ":" .. key
503
+ local task_data = redis.call('HGETALL', hash_key)
504
+
505
+ if #task_data > 0 then
506
+ local task = {}
507
+ for j = 1, #task_data, 2 do
508
+ task[task_data[j]] = task_data[j+1]
509
+ end
510
+
511
+ redis.call('XADD', KEYS[2], '*',
512
+ 'key', task['key'],
513
+ 'when', task['when'],
514
+ 'function', task['function'],
515
+ 'args', task['args'],
516
+ 'kwargs', task['kwargs'],
517
+ 'attempt', task['attempt']
518
+ )
519
+ redis.call('DEL', hash_key)
520
+
521
+ -- Set run state to queued
522
+ local run_key = ARGV[2] .. ":runs:" .. task['key']
523
+ redis.call('HSET', run_key, 'state', 'queued')
524
+
525
+ -- Publish state change event to pub/sub
526
+ local channel = ARGV[2] .. ":state:" .. task['key']
527
+ local payload = '{"type":"state","key":"' .. task['key'] .. '","state":"queued","when":"' .. task['when'] .. '"}'
528
+ redis.call('PUBLISH', channel, payload)
529
+
530
+ due_work = due_work + 1
531
+ end
532
+ end
533
+ end
534
+
535
+ if due_work > 0 then
536
+ redis.call('ZREMRANGEBYSCORE', KEYS[1], 0, ARGV[1])
537
+ end
538
+
539
+ return {total_work, due_work}
540
+ """
541
+ ),
542
+ )
543
+
544
+ total_work: int = sys.maxsize
545
+
546
+ log_context = self._log_context()
547
+
548
+ while not worker_stopping.is_set() or total_work:
549
+ try:
550
+ logger.debug("Scheduling due tasks", extra=log_context)
551
+ total_work, due_work = await stream_due_tasks(
552
+ keys=[self.docket.queue_key, self.docket.stream_key],
553
+ args=[datetime.now(timezone.utc).timestamp(), self.docket.name],
554
+ )
555
+
556
+ if due_work > 0:
557
+ logger.debug(
558
+ "Moved %d/%d due tasks from %s to %s",
559
+ due_work,
560
+ total_work,
561
+ self.docket.queue_key,
562
+ self.docket.stream_key,
563
+ extra=log_context,
564
+ )
565
+ except Exception: # pragma: no cover
566
+ logger.exception(
567
+ "Error in scheduler loop",
568
+ exc_info=True,
569
+ extra=log_context,
570
+ )
571
+ finally:
572
+ await asyncio.sleep(self.scheduling_resolution.total_seconds())
573
+
574
+ logger.debug("Scheduler loop finished", extra=log_context)
575
+
576
+ async def _schedule_all_automatic_perpetual_tasks(self) -> None:
577
+ async with self.docket.redis() as redis:
578
+ try:
579
+ async with redis.lock(
580
+ f"{self.docket.name}:perpetual:lock", timeout=10, blocking=False
581
+ ):
582
+ for task_function in self.docket.tasks.values():
583
+ perpetual = get_single_dependency_parameter_of_type(
584
+ task_function, Perpetual
585
+ )
586
+ if perpetual is None:
587
+ continue
588
+
589
+ if not perpetual.automatic:
590
+ continue
591
+
592
+ key = task_function.__name__
593
+
594
+ await self.docket.add(task_function, key=key)()
595
+ except LockError: # pragma: no cover
596
+ return
597
+
598
+ async def _delete_known_task(
599
+ self, redis: Redis, execution_or_message: Execution | RedisMessage
600
+ ) -> None:
601
+ if isinstance(execution_or_message, Execution):
602
+ key = execution_or_message.key
603
+ elif bytes_key := execution_or_message.get(b"key"):
604
+ key = bytes_key.decode()
605
+ else: # pragma: no cover
606
+ return
607
+
608
+ logger.debug("Deleting known task", extra=self._log_context())
609
+ # Delete known/stream_id from runs hash to allow task rescheduling
610
+ runs_key = f"{self.docket.name}:runs:{key}"
611
+ await redis.hdel(runs_key, "known", "stream_id")
612
+
613
+ # TODO: Remove in next breaking release (v0.14.0) - legacy key cleanup
614
+ known_task_key = self.docket.known_task_key(key)
615
+ stream_id_key = self.docket.stream_id_key(key)
616
+ await redis.delete(known_task_key, stream_id_key)
617
+
618
+ async def _execute(self, execution: Execution) -> None:
619
+ log_context = {**self._log_context(), **execution.specific_labels()}
620
+ counter_labels = {**self.labels(), **execution.general_labels()}
621
+
622
+ call = execution.call_repr()
623
+
624
+ if self.docket.strike_list.is_stricken(execution):
625
+ async with self.docket.redis() as redis:
626
+ await self._delete_known_task(redis, execution)
627
+
628
+ logger.warning("🗙 %s", call, extra=log_context)
629
+ TASKS_STRICKEN.add(1, counter_labels | {"docket.where": "worker"})
630
+ return
631
+
632
+ if execution.key in self._execution_counts:
633
+ self._execution_counts[execution.key] += 1
634
+
635
+ start = time.time()
636
+ punctuality = start - execution.when.timestamp()
637
+ log_context = {**log_context, "punctuality": punctuality}
638
+ duration = 0.0
639
+
640
+ TASKS_STARTED.add(1, counter_labels)
641
+ if execution.redelivered:
642
+ TASKS_REDELIVERED.add(1, counter_labels)
643
+ TASKS_RUNNING.add(1, counter_labels)
644
+ TASK_PUNCTUALITY.record(punctuality, counter_labels)
645
+
646
+ arrow = "↬" if execution.attempt > 1 else "↪"
647
+ logger.info("%s [%s] %s", arrow, ms(punctuality), call, extra=log_context)
648
+
649
+ # Atomically claim task and transition to running state
650
+ # This also initializes progress and cleans up known/stream_id to allow rescheduling
651
+ await execution.claim(self.name)
652
+
653
+ dependencies: dict[str, Dependency] = {}
654
+
655
+ with tracer.start_as_current_span(
656
+ execution.function.__name__,
657
+ kind=trace.SpanKind.CONSUMER,
658
+ attributes={
659
+ **self.labels(),
660
+ **execution.specific_labels(),
661
+ "code.function.name": execution.function.__name__,
662
+ },
663
+ links=execution.incoming_span_links(),
664
+ ) as span:
665
+ try:
666
+ async with resolved_dependencies(self, execution) as dependencies:
667
+ # Check concurrency limits after dependency resolution
668
+ concurrency_limit = get_single_dependency_of_type(
669
+ dependencies, ConcurrencyLimit
670
+ )
671
+ if (
672
+ concurrency_limit and not concurrency_limit.is_bypassed
673
+ ): # pragma: no branch - coverage.py on Python 3.10 struggles with this
674
+ async with self.docket.redis() as redis:
675
+ # Check if we can acquire a concurrency slot
676
+ can_start = await self._can_start_task(redis, execution)
677
+ if not can_start: # pragma: no branch - 3.10 failure
678
+ # Task cannot start due to concurrency limits
679
+ raise ConcurrencyBlocked(execution)
680
+
681
+ dependency_failures = {
682
+ k: v
683
+ for k, v in dependencies.items()
684
+ if isinstance(v, FailedDependency)
685
+ }
686
+ if dependency_failures:
687
+ raise ExceptionGroup(
688
+ (
689
+ "Failed to resolve dependencies for parameter(s): "
690
+ + ", ".join(dependency_failures.keys())
691
+ ),
692
+ [
693
+ dependency.error
694
+ for dependency in dependency_failures.values()
695
+ ],
696
+ )
697
+
698
+ # Apply timeout logic - either user's timeout or redelivery timeout
699
+ user_timeout = get_single_dependency_of_type(dependencies, Timeout)
700
+ if user_timeout:
701
+ # If user timeout is longer than redelivery timeout, limit it
702
+ if user_timeout.base > self.redelivery_timeout:
703
+ # Create a new timeout limited by redelivery timeout
704
+ # Remove the user timeout from dependencies to avoid conflicts
705
+ limited_dependencies = {
706
+ k: v
707
+ for k, v in dependencies.items()
708
+ if not isinstance(v, Timeout)
709
+ }
710
+ limited_timeout = Timeout(self.redelivery_timeout)
711
+ limited_timeout.start()
712
+ result = await self._run_function_with_timeout(
713
+ execution, limited_dependencies, limited_timeout
714
+ )
715
+ else:
716
+ # User timeout is within redelivery timeout, use as-is
717
+ result = await self._run_function_with_timeout(
718
+ execution, dependencies, user_timeout
719
+ )
720
+ else:
721
+ # No user timeout - apply redelivery timeout as hard limit
722
+ redelivery_timeout = Timeout(self.redelivery_timeout)
723
+ redelivery_timeout.start()
724
+ result = await self._run_function_with_timeout(
725
+ execution, dependencies, redelivery_timeout
726
+ )
727
+
728
+ duration = log_context["duration"] = time.time() - start
729
+ TASKS_SUCCEEDED.add(1, counter_labels)
730
+
731
+ span.set_status(Status(StatusCode.OK))
732
+
733
+ rescheduled = await self._perpetuate_if_requested(
734
+ execution, dependencies, timedelta(seconds=duration)
735
+ )
736
+
737
+ if rescheduled:
738
+ # Task was rescheduled - still mark this execution as completed
739
+ # to set TTL on the runs hash (the new execution has its own entry)
740
+ await execution.mark_as_completed(result_key=None)
741
+ else:
742
+ # Store result if appropriate
743
+ result_key = None
744
+ if result is not None and self.docket.execution_ttl:
745
+ # Serialize and store result
746
+ pickled_result = cloudpickle.dumps(result) # type: ignore[arg-type]
747
+ # Base64-encode for JSON serialization
748
+ encoded_result = base64.b64encode(pickled_result).decode(
749
+ "ascii"
750
+ )
751
+ result_key = execution.key
752
+ ttl_seconds = int(self.docket.execution_ttl.total_seconds())
753
+ await self.docket.result_storage.put(
754
+ result_key, {"data": encoded_result}, ttl=ttl_seconds
755
+ )
756
+ # Mark execution as completed
757
+ await execution.mark_as_completed(result_key=result_key)
758
+
759
+ arrow = "↫" if rescheduled else "↩"
760
+ logger.info(
761
+ "%s [%s] %s", arrow, ms(duration), call, extra=log_context
762
+ )
763
+ except ConcurrencyBlocked:
764
+ # Re-raise to be handled by process_completed_tasks
765
+ raise
766
+ except Exception as e:
767
+ duration = log_context["duration"] = time.time() - start
768
+ TASKS_FAILED.add(1, counter_labels)
769
+
770
+ span.record_exception(e)
771
+ span.set_status(Status(StatusCode.ERROR, str(e)))
772
+
773
+ retried = await self._retry_if_requested(execution, dependencies)
774
+ if not retried:
775
+ retried = await self._perpetuate_if_requested(
776
+ execution, dependencies, timedelta(seconds=duration)
777
+ )
778
+
779
+ # Store exception in result_storage
780
+ result_key = None
781
+ if self.docket.execution_ttl:
782
+ pickled_exception = cloudpickle.dumps(e) # type: ignore[arg-type]
783
+ # Base64-encode for JSON serialization
784
+ encoded_exception = base64.b64encode(pickled_exception).decode(
785
+ "ascii"
786
+ )
787
+ result_key = execution.key
788
+ ttl_seconds = int(self.docket.execution_ttl.total_seconds())
789
+ await self.docket.result_storage.put(
790
+ result_key, {"data": encoded_exception}, ttl=ttl_seconds
791
+ )
792
+
793
+ # Mark execution as failed with error message
794
+ error_msg = f"{type(e).__name__}: {str(e)}"
795
+ await execution.mark_as_failed(error_msg, result_key=result_key)
796
+
797
+ arrow = "↫" if retried else "↩"
798
+ logger.exception(
799
+ "%s [%s] %s", arrow, ms(duration), call, extra=log_context
800
+ )
801
+ finally:
802
+ # Release concurrency slot if we acquired one
803
+ if dependencies:
804
+ concurrency_limit = get_single_dependency_of_type(
805
+ dependencies, ConcurrencyLimit
806
+ )
807
+ if concurrency_limit and not concurrency_limit.is_bypassed:
808
+ async with self.docket.redis() as redis:
809
+ await self._release_concurrency_slot(redis, execution)
810
+
811
+ TASKS_RUNNING.add(-1, counter_labels)
812
+ TASKS_COMPLETED.add(1, counter_labels)
813
+ TASK_DURATION.record(duration, counter_labels)
814
+
815
+ async def _run_function_with_timeout(
816
+ self,
817
+ execution: Execution,
818
+ dependencies: dict[str, Dependency],
819
+ timeout: Timeout,
820
+ ) -> Any:
821
+ task_coro = cast(
822
+ Coroutine[None, None, Any],
823
+ execution.function(
824
+ *execution.args,
825
+ **{
826
+ **execution.kwargs,
827
+ **dependencies,
828
+ },
829
+ ),
830
+ )
831
+ task = asyncio.create_task(task_coro)
832
+ try:
833
+ while not task.done(): # pragma: no branch
834
+ remaining = timeout.remaining().total_seconds()
835
+ if timeout.expired():
836
+ task.cancel()
837
+ break
838
+
839
+ try:
840
+ result = await asyncio.wait_for(
841
+ asyncio.shield(task), timeout=remaining
842
+ )
843
+ return result
844
+ except asyncio.TimeoutError:
845
+ continue
846
+ finally:
847
+ if not task.done(): # pragma: no branch
848
+ task.cancel()
849
+
850
+ try:
851
+ return await task
852
+ except asyncio.CancelledError:
853
+ raise asyncio.TimeoutError
854
+
855
+ async def _retry_if_requested(
856
+ self,
857
+ execution: Execution,
858
+ dependencies: dict[str, Dependency],
859
+ ) -> bool:
860
+ retry = get_single_dependency_of_type(dependencies, Retry)
861
+ if not retry:
862
+ return False
863
+
864
+ if retry.attempts is not None and execution.attempt >= retry.attempts:
865
+ return False
866
+
867
+ execution.when = datetime.now(timezone.utc) + retry.delay
868
+ execution.attempt += 1
869
+ # Use replace=True since the task is being rescheduled after failure
870
+ await execution.schedule(replace=True)
871
+
872
+ TASKS_RETRIED.add(1, {**self.labels(), **execution.general_labels()})
873
+ return True
874
+
875
+ async def _perpetuate_if_requested(
876
+ self,
877
+ execution: Execution,
878
+ dependencies: dict[str, Dependency],
879
+ duration: timedelta,
880
+ ) -> bool:
881
+ perpetual = get_single_dependency_of_type(dependencies, Perpetual)
882
+ if not perpetual:
883
+ return False
884
+
885
+ if perpetual.cancelled:
886
+ await self.docket.cancel(execution.key)
887
+ return False
888
+
889
+ now = datetime.now(timezone.utc)
890
+ when = max(now, now + perpetual.every - duration)
891
+
892
+ await self.docket.replace(execution.function, when, execution.key)(
893
+ *perpetual.args,
894
+ **perpetual.kwargs,
895
+ )
896
+
897
+ TASKS_PERPETUATED.add(1, {**self.labels(), **execution.general_labels()})
898
+
899
+ return True
900
+
901
+ def _startup_log(self) -> None:
902
+ logger.info("Starting worker %r with the following tasks:", self.name)
903
+ for task_name, task in self.docket.tasks.items():
904
+ logger.info("* %s(%s)", task_name, compact_signature(get_signature(task)))
905
+
906
+ @property
907
+ def workers_set(self) -> str:
908
+ return self.docket.workers_set
909
+
910
+ def worker_tasks_set(self, worker_name: str) -> str:
911
+ return self.docket.worker_tasks_set(worker_name)
912
+
913
+ def task_workers_set(self, task_name: str) -> str:
914
+ return self.docket.task_workers_set(task_name)
915
+
916
+ async def _heartbeat(self) -> None:
917
+ while True:
918
+ await asyncio.sleep(self.docket.heartbeat_interval.total_seconds())
919
+ try:
920
+ now = datetime.now(timezone.utc).timestamp()
921
+ maximum_age = (
922
+ self.docket.heartbeat_interval * self.docket.missed_heartbeats
923
+ )
924
+ oldest = now - maximum_age.total_seconds()
925
+
926
+ task_names = list(self.docket.tasks)
927
+
928
+ async with self.docket.redis() as r:
929
+ async with r.pipeline() as pipeline:
930
+ pipeline.zremrangebyscore(self.workers_set, 0, oldest)
931
+ pipeline.zadd(self.workers_set, {self.name: now})
932
+
933
+ for task_name in task_names:
934
+ task_workers_set = self.task_workers_set(task_name)
935
+ pipeline.zremrangebyscore(task_workers_set, 0, oldest)
936
+ pipeline.zadd(task_workers_set, {self.name: now})
937
+
938
+ pipeline.sadd(self.worker_tasks_set(self.name), *task_names)
939
+ pipeline.expire(
940
+ self.worker_tasks_set(self.name),
941
+ max(maximum_age, timedelta(seconds=1)),
942
+ )
943
+
944
+ await pipeline.execute()
945
+
946
+ async with r.pipeline() as pipeline:
947
+ pipeline.xlen(self.docket.stream_key)
948
+ pipeline.zcount(self.docket.queue_key, 0, now)
949
+ pipeline.zcount(self.docket.queue_key, now, "+inf")
950
+
951
+ results: list[int] = await pipeline.execute()
952
+ stream_depth = results[0]
953
+ overdue_depth = results[1]
954
+ schedule_depth = results[2]
955
+
956
+ QUEUE_DEPTH.set(
957
+ stream_depth + overdue_depth, self.docket.labels()
958
+ )
959
+ SCHEDULE_DEPTH.set(schedule_depth, self.docket.labels())
960
+
961
+ except asyncio.CancelledError: # pragma: no cover
962
+ return
963
+ except ConnectionError:
964
+ REDIS_DISRUPTIONS.add(1, self.labels())
965
+ logger.exception(
966
+ "Error sending worker heartbeat",
967
+ exc_info=True,
968
+ extra=self._log_context(),
969
+ )
970
+ except Exception:
971
+ logger.exception(
972
+ "Error sending worker heartbeat",
973
+ exc_info=True,
974
+ extra=self._log_context(),
975
+ )
976
+
977
+ async def _can_start_task(self, redis: Redis, execution: Execution) -> bool:
978
+ """Check if a task can start based on concurrency limits."""
979
+ # Check if task has a concurrency limit dependency
980
+ concurrency_limit = get_single_dependency_parameter_of_type(
981
+ execution.function, ConcurrencyLimit
982
+ )
983
+
984
+ if not concurrency_limit:
985
+ return True # No concurrency limit, can always start
986
+
987
+ # Get the concurrency key for this task
988
+ try:
989
+ argument_value = execution.get_argument(concurrency_limit.argument_name)
990
+ except KeyError:
991
+ # If argument not found, let the task fail naturally in execution
992
+ return True
993
+
994
+ scope = concurrency_limit.scope or self.docket.name
995
+ concurrency_key = (
996
+ f"{scope}:concurrency:{concurrency_limit.argument_name}:{argument_value}"
997
+ )
998
+
999
+ # Use Redis sorted set with timestamps to track concurrency and handle expiration
1000
+ lua_script = """
1001
+ local key = KEYS[1]
1002
+ local max_concurrent = tonumber(ARGV[1])
1003
+ local worker_id = ARGV[2]
1004
+ local task_key = ARGV[3]
1005
+ local current_time = tonumber(ARGV[4])
1006
+ local expiration_time = tonumber(ARGV[5])
1007
+
1008
+ -- Remove expired entries
1009
+ local expired_cutoff = current_time - expiration_time
1010
+ redis.call('ZREMRANGEBYSCORE', key, 0, expired_cutoff)
1011
+
1012
+ -- Get current count
1013
+ local current = redis.call('ZCARD', key)
1014
+
1015
+ if current < max_concurrent then
1016
+ -- Add this worker's task to the sorted set with current timestamp
1017
+ redis.call('ZADD', key, current_time, worker_id .. ':' .. task_key)
1018
+ return 1
1019
+ else
1020
+ return 0
1021
+ end
1022
+ """
1023
+
1024
+ current_time = datetime.now(timezone.utc).timestamp()
1025
+ expiration_seconds = self.redelivery_timeout.total_seconds()
1026
+
1027
+ result = await redis.eval( # type: ignore
1028
+ lua_script,
1029
+ 1,
1030
+ concurrency_key,
1031
+ str(concurrency_limit.max_concurrent),
1032
+ self.name,
1033
+ execution.key,
1034
+ current_time,
1035
+ expiration_seconds,
1036
+ )
1037
+
1038
+ return bool(result)
1039
+
1040
+ async def _release_concurrency_slot(
1041
+ self, redis: Redis, execution: Execution
1042
+ ) -> None:
1043
+ """Release a concurrency slot when task completes."""
1044
+ # Check if task has a concurrency limit dependency
1045
+ concurrency_limit = get_single_dependency_parameter_of_type(
1046
+ execution.function, ConcurrencyLimit
1047
+ )
1048
+
1049
+ if not concurrency_limit:
1050
+ return # No concurrency limit to release
1051
+
1052
+ # Get the concurrency key for this task
1053
+ try:
1054
+ argument_value = execution.get_argument(concurrency_limit.argument_name)
1055
+ except KeyError:
1056
+ return # If argument not found, nothing to release
1057
+
1058
+ scope = concurrency_limit.scope or self.docket.name
1059
+ concurrency_key = (
1060
+ f"{scope}:concurrency:{concurrency_limit.argument_name}:{argument_value}"
1061
+ )
1062
+
1063
+ # Remove this worker's task from the sorted set
1064
+ await redis.zrem(concurrency_key, f"{self.name}:{execution.key}") # type: ignore
1065
+
1066
+
1067
+ def ms(seconds: float) -> str:
1068
+ if seconds < 100:
1069
+ return f"{seconds * 1000:6.0f}ms"
1070
+ else:
1071
+ return f"{seconds:6.0f}s "