pydocket-0.15.3-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docket/__init__.py +55 -0
- docket/__main__.py +3 -0
- docket/_uuid7.py +99 -0
- docket/agenda.py +202 -0
- docket/annotations.py +81 -0
- docket/cli.py +1185 -0
- docket/dependencies.py +808 -0
- docket/docket.py +1062 -0
- docket/execution.py +1370 -0
- docket/instrumentation.py +225 -0
- docket/py.typed +0 -0
- docket/tasks.py +59 -0
- docket/testing.py +235 -0
- docket/worker.py +1071 -0
- pydocket-0.15.3.dist-info/METADATA +160 -0
- pydocket-0.15.3.dist-info/RECORD +19 -0
- pydocket-0.15.3.dist-info/WHEEL +4 -0
- pydocket-0.15.3.dist-info/entry_points.txt +2 -0
- pydocket-0.15.3.dist-info/licenses/LICENSE +9 -0
docket/worker.py
ADDED
@@ -0,0 +1,1071 @@
+import asyncio
+import base64
+import logging
+import os
+import signal
+import socket
+import sys
+import time
+from datetime import datetime, timedelta, timezone
+from types import TracebackType
+from typing import Any, Coroutine, Mapping, Protocol, cast
+
+import cloudpickle  # type: ignore[import]
+
+if sys.version_info < (3, 11):  # pragma: no cover
+    from exceptiongroup import ExceptionGroup
+
+from opentelemetry import trace
+from opentelemetry.trace import Status, StatusCode, Tracer
+from redis.asyncio import Redis
+from redis.exceptions import ConnectionError, LockError, ResponseError
+from typing_extensions import Self
+
+from .dependencies import (
+    ConcurrencyLimit,
+    Dependency,
+    FailedDependency,
+    Perpetual,
+    Retry,
+    Timeout,
+    get_single_dependency_of_type,
+    get_single_dependency_parameter_of_type,
+    resolved_dependencies,
+)
+from .docket import (
+    Docket,
+    Execution,
+    RedisMessage,
+    RedisMessageID,
+    RedisReadGroupResponse,
+)
+from .execution import compact_signature, get_signature
+
+# Run class has been consolidated into Execution
+from .instrumentation import (
+    QUEUE_DEPTH,
+    REDIS_DISRUPTIONS,
+    SCHEDULE_DEPTH,
+    TASK_DURATION,
+    TASK_PUNCTUALITY,
+    TASKS_COMPLETED,
+    TASKS_FAILED,
+    TASKS_PERPETUATED,
+    TASKS_REDELIVERED,
+    TASKS_RETRIED,
+    TASKS_RUNNING,
+    TASKS_STARTED,
+    TASKS_STRICKEN,
+    TASKS_SUCCEEDED,
+    healthcheck_server,
+    metrics_server,
+)
+
+# Delay before retrying a task blocked by concurrency limits
+# Must be larger than redelivery_timeout to ensure atomic reschedule+ACK completes
+# before Redis would consider redelivering the message
+CONCURRENCY_BLOCKED_RETRY_DELAY = timedelta(milliseconds=100)
+
+
+class ConcurrencyBlocked(Exception):
+    """Raised when a task cannot start due to concurrency limits."""
+
+    def __init__(self, execution: Execution):
+        self.execution = execution
+        super().__init__(f"Task {execution.key} blocked by concurrency limits")
+
+
+logger: logging.Logger = logging.getLogger(__name__)
+tracer: Tracer = trace.get_tracer(__name__)
+
+
+class _stream_due_tasks(Protocol):
+    async def __call__(
+        self, keys: list[str], args: list[str | float]
+    ) -> tuple[int, int]: ...  # pragma: no cover
+
+
+class Worker:
+    """A Worker executes tasks on a Docket. You may run as many workers as you like
+    to work a single Docket.
+
+    Example:
+
+    ```python
+    async with Docket() as docket:
+        async with Worker(docket) as worker:
+            await worker.run_forever()
+    ```
+    """
+
+    docket: Docket
+    name: str
+    concurrency: int
+    redelivery_timeout: timedelta
+    reconnection_delay: timedelta
+    minimum_check_interval: timedelta
+    scheduling_resolution: timedelta
+    schedule_automatic_tasks: bool
+
+    def __init__(
+        self,
+        docket: Docket,
+        name: str | None = None,
+        concurrency: int = 10,
+        redelivery_timeout: timedelta = timedelta(minutes=5),
+        reconnection_delay: timedelta = timedelta(seconds=5),
+        minimum_check_interval: timedelta = timedelta(milliseconds=250),
+        scheduling_resolution: timedelta = timedelta(milliseconds=250),
+        schedule_automatic_tasks: bool = True,
+    ) -> None:
+        self.docket = docket
+        self.name = name or f"{socket.gethostname()}#{os.getpid()}"
+        self.concurrency = concurrency
+        self.redelivery_timeout = redelivery_timeout
+        self.reconnection_delay = reconnection_delay
+        self.minimum_check_interval = minimum_check_interval
+        self.scheduling_resolution = scheduling_resolution
+        self.schedule_automatic_tasks = schedule_automatic_tasks
+
+    async def __aenter__(self) -> Self:
+        self._heartbeat_task = asyncio.create_task(self._heartbeat())
+        self._execution_counts = {}
+        return self
+
+    async def __aexit__(
+        self,
+        exc_type: type[BaseException] | None,
+        exc_value: BaseException | None,
+        traceback: TracebackType | None,
+    ) -> None:
+        del self._execution_counts
+
+        self._heartbeat_task.cancel()
+        try:
+            await self._heartbeat_task
+        except asyncio.CancelledError:
+            pass
+        del self._heartbeat_task
+
+    def labels(self) -> Mapping[str, str]:
+        return {
+            **self.docket.labels(),
+            "docket.worker": self.name,
+        }
+
+    def _log_context(self) -> Mapping[str, str]:
+        return {
+            **self.labels(),
+            "docket.queue_key": self.docket.queue_key,
+            "docket.stream_key": self.docket.stream_key,
+        }
+
+    @classmethod
+    async def run(
+        cls,
+        docket_name: str = "docket",
+        url: str = "redis://localhost:6379/0",
+        name: str | None = None,
+        concurrency: int = 10,
+        redelivery_timeout: timedelta = timedelta(minutes=5),
+        reconnection_delay: timedelta = timedelta(seconds=5),
+        minimum_check_interval: timedelta = timedelta(milliseconds=100),
+        scheduling_resolution: timedelta = timedelta(milliseconds=250),
+        schedule_automatic_tasks: bool = True,
+        until_finished: bool = False,
+        healthcheck_port: int | None = None,
+        metrics_port: int | None = None,
+        tasks: list[str] = ["docket.tasks:standard_tasks"],
+    ) -> None:
+        """Run a worker as the main entry point (CLI).
+
+        This method installs signal handlers for graceful shutdown since it
+        assumes ownership of the event loop. When embedding Docket in another
+        framework (e.g., FastAPI with uvicorn), use Worker.run_forever() or
+        Worker.run_until_finished() directly - those methods do not install
+        signal handlers and rely on the framework to handle shutdown signals.
+        """
+        with (
+            healthcheck_server(port=healthcheck_port),
+            metrics_server(port=metrics_port),
+        ):
+            async with Docket(name=docket_name, url=url) as docket:
+                for task_path in tasks:
+                    docket.register_collection(task_path)
+
+                async with (
+                    Worker(  # pragma: no branch - context manager exit varies across interpreters
+                        docket=docket,
+                        name=name,
+                        concurrency=concurrency,
+                        redelivery_timeout=redelivery_timeout,
+                        reconnection_delay=reconnection_delay,
+                        minimum_check_interval=minimum_check_interval,
+                        scheduling_resolution=scheduling_resolution,
+                        schedule_automatic_tasks=schedule_automatic_tasks,
+                    ) as worker
+                ):
+                    # Install signal handlers for graceful shutdown.
+                    # This is only appropriate when we own the event loop (CLI entry point).
+                    # Embedded usage should let the framework handle signals.
+                    loop = asyncio.get_running_loop()
+                    run_task: asyncio.Task[None] | None = None
+
+                    def handle_shutdown(sig_name: str) -> None:  # pragma: no cover
+                        logger.info(
+                            "Received %s, initiating graceful shutdown...", sig_name
+                        )
+                        if run_task and not run_task.done():
+                            run_task.cancel()
+
+                    if hasattr(signal, "SIGTERM"):  # pragma: no cover
+                        loop.add_signal_handler(
+                            signal.SIGTERM, lambda: handle_shutdown("SIGTERM")
+                        )
+                        loop.add_signal_handler(
+                            signal.SIGINT, lambda: handle_shutdown("SIGINT")
+                        )
+
+                    try:
+                        if until_finished:
+                            run_task = asyncio.create_task(worker.run_until_finished())
+                        else:
+                            run_task = asyncio.create_task(
+                                worker.run_forever()
+                            )  # pragma: no cover
+                        await run_task
+                    except asyncio.CancelledError:  # pragma: no cover
+                        pass
+                    finally:
+                        if hasattr(signal, "SIGTERM"):  # pragma: no cover
+                            loop.remove_signal_handler(signal.SIGTERM)
+                            loop.remove_signal_handler(signal.SIGINT)
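
Editor's aside, not part of the diff: the `run()` docstring above distinguishes the CLI entry point from embedded use, where the host framework owns the event loop and shutdown signals. A minimal sketch of the embedded pattern, assuming the package re-exports `Docket` and `Worker` from `docket/__init__.py`; the task collection path `myapp.tasks:all_tasks` is hypothetical.

```python
# Sketch of embedding a worker in an existing asyncio application, per the
# run() docstring. The host application handles shutdown signals; we only
# cancel the worker task when the host decides to stop.
import asyncio

from docket import Docket, Worker


async def main() -> None:
    async with Docket(name="docket", url="redis://localhost:6379/0") as docket:
        # "myapp.tasks:all_tasks" is a hypothetical task collection path.
        docket.register_collection("myapp.tasks:all_tasks")
        async with Worker(docket) as worker:
            run_task = asyncio.create_task(worker.run_forever())
            try:
                await run_task
            except asyncio.CancelledError:
                pass  # graceful shutdown initiated by the host application


if __name__ == "__main__":
    asyncio.run(main())
```
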
+
+    async def run_until_finished(self) -> None:
+        """Run the worker until there are no more tasks to process."""
+        return await self._run(forever=False)
+
+    async def run_forever(self) -> None:
+        """Run the worker indefinitely."""
+        return await self._run(forever=True)  # pragma: no cover
+
+    _execution_counts: dict[str, int]
+
+    async def run_at_most(self, iterations_by_key: Mapping[str, int]) -> None:
+        """
+        Run the worker until there are no more tasks to process, but limit specified
+        task keys to a maximum number of iterations.
+
+        This is particularly useful for testing self-perpetuating tasks that would
+        otherwise run indefinitely.
+
+        Args:
+            iterations_by_key: Maps task keys to their maximum allowed executions
+        """
+        self._execution_counts = {key: 0 for key in iterations_by_key}
+
+        def has_reached_max_iterations(execution: Execution) -> bool:
+            key = execution.key
+
+            if key not in iterations_by_key:
+                return False
+
+            if self._execution_counts[key] >= iterations_by_key[key]:
+                return True
+
+            return False
+
+        self.docket.strike_list.add_condition(has_reached_max_iterations)
+        try:
+            await self.run_until_finished()
+        finally:
+            self.docket.strike_list.remove_condition(has_reached_max_iterations)
+            self._execution_counts = {}
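
Editor's aside, not part of the diff: a sketch of how `run_at_most()` bounds a self-perpetuating task in a test, per its docstring. This assumes a pytest-style async test with a `docket` fixture; the task key `"heartbeat-task"` is hypothetical.

```python
# Sketch: bound a perpetual task to three executions in a test, then drain.
from docket import Docket, Worker


async def test_perpetual_task_runs_three_times(docket: Docket) -> None:
    async with Worker(docket) as worker:
        # The worker strikes the task key after three executions and then
        # behaves like run_until_finished().
        await worker.run_at_most({"heartbeat-task": 3})
```
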
+
+    async def _run(self, forever: bool = False) -> None:
+        self._startup_log()
+
+        while True:
+            try:
+                async with self.docket.redis() as redis:
+                    return await self._worker_loop(redis, forever=forever)
+            except ConnectionError:
+                REDIS_DISRUPTIONS.add(1, self.labels())
+                logger.warning(
+                    "Error connecting to redis, retrying in %s...",
+                    self.reconnection_delay,
+                    exc_info=True,
+                )
+                await asyncio.sleep(self.reconnection_delay.total_seconds())
+
+    async def _worker_loop(self, redis: Redis, forever: bool = False):
+        worker_stopping = asyncio.Event()
+
+        if self.schedule_automatic_tasks:
+            await self._schedule_all_automatic_perpetual_tasks()
+
+        scheduler_task = asyncio.create_task(
+            self._scheduler_loop(redis, worker_stopping)
+        )
+
+        active_tasks: dict[asyncio.Task[None], RedisMessageID] = {}
+        task_executions: dict[asyncio.Task[None], Execution] = {}
+        available_slots = self.concurrency
+
+        log_context = self._log_context()
+
+        async def check_for_work() -> bool:
+            logger.debug("Checking for work", extra=log_context)
+            async with redis.pipeline() as pipeline:
+                pipeline.xlen(self.docket.stream_key)
+                pipeline.zcard(self.docket.queue_key)
+                results: list[int] = await pipeline.execute()
+                stream_len = results[0]
+                queue_len = results[1]
+            return stream_len > 0 or queue_len > 0
+
+        async def get_redeliveries(redis: Redis) -> RedisReadGroupResponse:
+            logger.debug("Getting redeliveries", extra=log_context)
+            try:
+                _, redeliveries, *_ = await redis.xautoclaim(
+                    name=self.docket.stream_key,
+                    groupname=self.docket.worker_group_name,
+                    consumername=self.name,
+                    min_idle_time=int(self.redelivery_timeout.total_seconds() * 1000),
+                    start_id="0-0",
+                    count=available_slots,
+                )
+            except ResponseError as e:
+                if "NOGROUP" in str(e):
+                    await self.docket._ensure_stream_and_group()
+                    return await get_redeliveries(redis)
+                raise  # pragma: no cover
+            return [(b"__redelivery__", redeliveries)]
+
+        async def get_new_deliveries(redis: Redis) -> RedisReadGroupResponse:
+            logger.debug("Getting new deliveries", extra=log_context)
+            # Use non-blocking read with in-memory backend + manual sleep
+            # This is necessary because fakeredis's async blocking operations don't
+            # properly yield control to the asyncio event loop
+            is_memory = self.docket.url.startswith("memory://")
+            try:
+                result = await redis.xreadgroup(
+                    groupname=self.docket.worker_group_name,
+                    consumername=self.name,
+                    streams={self.docket.stream_key: ">"},
+                    block=0
+                    if is_memory
+                    else int(self.minimum_check_interval.total_seconds() * 1000),
+                    count=available_slots,
+                )
+            except ResponseError as e:
+                if "NOGROUP" in str(e):
+                    await self.docket._ensure_stream_and_group()
+                    return await get_new_deliveries(redis)
+                raise  # pragma: no cover
+            if is_memory and not result:
+                await asyncio.sleep(self.minimum_check_interval.total_seconds())
+            return result
+
+        async def start_task(
+            message_id: RedisMessageID,
+            message: RedisMessage,
+            is_redelivery: bool = False,
+        ) -> bool:
+            try:
+                execution = await Execution.from_message(
+                    self.docket, message, redelivered=is_redelivery
+                )
+            except ValueError as e:
+                logger.error(
+                    "Unable to start task: %s",
+                    e,
+                    extra=log_context,
+                )
+                return False
+
+            task = asyncio.create_task(self._execute(execution), name=execution.key)
+            active_tasks[task] = message_id
+            task_executions[task] = execution
+
+            nonlocal available_slots
+            available_slots -= 1
+
+            return True
+
+        async def process_completed_tasks() -> None:
+            completed_tasks = {task for task in active_tasks if task.done()}
+            for task in completed_tasks:
+                message_id = active_tasks.pop(task)
+                task_executions.pop(task)
+                try:
+                    await task
+                    # Task succeeded - acknowledge the message
+                    await ack_message(redis, message_id)
+                except ConcurrencyBlocked as e:
+                    # Task was blocked by concurrency limits, reschedule atomically
+                    logger.debug(
+                        "🔒 Task %s blocked by concurrency limit, rescheduling",
+                        e.execution.key,
+                        extra=log_context,
+                    )
+                    # Use atomic schedule(reschedule_message=...) to prevent both task loss and duplicate execution
+                    e.execution.when = (
+                        datetime.now(timezone.utc) + CONCURRENCY_BLOCKED_RETRY_DELAY
+                    )
+                    await e.execution.schedule(reschedule_message=message_id)
+
+        async def ack_message(redis: Redis, message_id: RedisMessageID) -> None:
+            logger.debug("Acknowledging message", extra=log_context)
+            async with redis.pipeline() as pipeline:
+                pipeline.xack(
+                    self.docket.stream_key,
+                    self.docket.worker_group_name,
+                    message_id,
+                )
+                pipeline.xdel(
+                    self.docket.stream_key,
+                    message_id,
+                )
+                await pipeline.execute()
+
+        has_work: bool = True
+
+        try:
+            while forever or has_work or active_tasks:
+                await process_completed_tasks()
+
+                available_slots = self.concurrency - len(active_tasks)
+
+                if available_slots <= 0:
+                    await asyncio.sleep(self.minimum_check_interval.total_seconds())
+                    continue
+
+                for source in [get_redeliveries, get_new_deliveries]:
+                    for stream_key, messages in await source(redis):
+                        is_redelivery = stream_key == b"__redelivery__"
+                        for message_id, message in messages:
+                            if not message:  # pragma: no cover
+                                continue
+
+                            task_started = await start_task(
+                                message_id, message, is_redelivery
+                            )
+                            if not task_started:
+                                await self._delete_known_task(redis, message)
+                                await ack_message(redis, message_id)
+
+                            if available_slots <= 0:
+                                break
+
+                if not forever and not active_tasks:
+                    has_work = await check_for_work()
+
+        except asyncio.CancelledError:
+            if active_tasks:  # pragma: no cover
+                logger.info(
+                    "Shutdown requested, finishing %d active tasks...",
+                    len(active_tasks),
+                    extra=log_context,
+                )
+        finally:
+            if active_tasks:
+                await asyncio.gather(*active_tasks, return_exceptions=True)
+                await process_completed_tasks()
+
+            worker_stopping.set()
+            await scheduler_task
+
+    async def _scheduler_loop(
+        self,
+        redis: Redis,
+        worker_stopping: asyncio.Event,
+    ) -> None:
+        """Loop that moves due tasks from the queue to the stream."""
+
+        stream_due_tasks: _stream_due_tasks = cast(
+            _stream_due_tasks,
+            redis.register_script(
+                # Lua script to atomically move scheduled tasks to the stream
+                # KEYS[1]: queue key (sorted set)
+                # KEYS[2]: stream key
+                # ARGV[1]: current timestamp
+                # ARGV[2]: docket name prefix
+                """
+                local total_work = redis.call('ZCARD', KEYS[1])
+                local due_work = 0
+
+                if total_work > 0 then
+                    local tasks = redis.call('ZRANGEBYSCORE', KEYS[1], 0, ARGV[1])
+
+                    for i, key in ipairs(tasks) do
+                        local hash_key = ARGV[2] .. ":" .. key
+                        local task_data = redis.call('HGETALL', hash_key)
+
+                        if #task_data > 0 then
+                            local task = {}
+                            for j = 1, #task_data, 2 do
+                                task[task_data[j]] = task_data[j+1]
+                            end
+
+                            redis.call('XADD', KEYS[2], '*',
+                                'key', task['key'],
+                                'when', task['when'],
+                                'function', task['function'],
+                                'args', task['args'],
+                                'kwargs', task['kwargs'],
+                                'attempt', task['attempt']
+                            )
+                            redis.call('DEL', hash_key)
+
+                            -- Set run state to queued
+                            local run_key = ARGV[2] .. ":runs:" .. task['key']
+                            redis.call('HSET', run_key, 'state', 'queued')
+
+                            -- Publish state change event to pub/sub
+                            local channel = ARGV[2] .. ":state:" .. task['key']
+                            local payload = '{"type":"state","key":"' .. task['key'] .. '","state":"queued","when":"' .. task['when'] .. '"}'
+                            redis.call('PUBLISH', channel, payload)
+
+                            due_work = due_work + 1
+                        end
+                    end
+                end
+
+                if due_work > 0 then
+                    redis.call('ZREMRANGEBYSCORE', KEYS[1], 0, ARGV[1])
+                end
+
+                return {total_work, due_work}
+                """
+            ),
+        )
+
+        total_work: int = sys.maxsize
+
+        log_context = self._log_context()
+
+        while not worker_stopping.is_set() or total_work:
+            try:
+                logger.debug("Scheduling due tasks", extra=log_context)
+                total_work, due_work = await stream_due_tasks(
+                    keys=[self.docket.queue_key, self.docket.stream_key],
+                    args=[datetime.now(timezone.utc).timestamp(), self.docket.name],
+                )
+
+                if due_work > 0:
+                    logger.debug(
+                        "Moved %d/%d due tasks from %s to %s",
+                        due_work,
+                        total_work,
+                        self.docket.queue_key,
+                        self.docket.stream_key,
+                        extra=log_context,
+                    )
+            except Exception:  # pragma: no cover
+                logger.exception(
+                    "Error in scheduler loop",
+                    exc_info=True,
+                    extra=log_context,
+                )
+            finally:
+                await asyncio.sleep(self.scheduling_resolution.total_seconds())
+
+        logger.debug("Scheduler loop finished", extra=log_context)
+
+    async def _schedule_all_automatic_perpetual_tasks(self) -> None:
+        async with self.docket.redis() as redis:
+            try:
+                async with redis.lock(
+                    f"{self.docket.name}:perpetual:lock", timeout=10, blocking=False
+                ):
+                    for task_function in self.docket.tasks.values():
+                        perpetual = get_single_dependency_parameter_of_type(
+                            task_function, Perpetual
+                        )
+                        if perpetual is None:
+                            continue
+
+                        if not perpetual.automatic:
+                            continue
+
+                        key = task_function.__name__
+
+                        await self.docket.add(task_function, key=key)()
+            except LockError:  # pragma: no cover
+                return
+
+    async def _delete_known_task(
+        self, redis: Redis, execution_or_message: Execution | RedisMessage
+    ) -> None:
+        if isinstance(execution_or_message, Execution):
+            key = execution_or_message.key
+        elif bytes_key := execution_or_message.get(b"key"):
+            key = bytes_key.decode()
+        else:  # pragma: no cover
+            return
+
+        logger.debug("Deleting known task", extra=self._log_context())
+        # Delete known/stream_id from runs hash to allow task rescheduling
+        runs_key = f"{self.docket.name}:runs:{key}"
+        await redis.hdel(runs_key, "known", "stream_id")
+
+        # TODO: Remove in next breaking release (v0.14.0) - legacy key cleanup
+        known_task_key = self.docket.known_task_key(key)
+        stream_id_key = self.docket.stream_id_key(key)
+        await redis.delete(known_task_key, stream_id_key)
+
+    async def _execute(self, execution: Execution) -> None:
+        log_context = {**self._log_context(), **execution.specific_labels()}
+        counter_labels = {**self.labels(), **execution.general_labels()}
+
+        call = execution.call_repr()
+
+        if self.docket.strike_list.is_stricken(execution):
+            async with self.docket.redis() as redis:
+                await self._delete_known_task(redis, execution)
+
+            logger.warning("🗙 %s", call, extra=log_context)
+            TASKS_STRICKEN.add(1, counter_labels | {"docket.where": "worker"})
+            return
+
+        if execution.key in self._execution_counts:
+            self._execution_counts[execution.key] += 1
+
+        start = time.time()
+        punctuality = start - execution.when.timestamp()
+        log_context = {**log_context, "punctuality": punctuality}
+        duration = 0.0
+
+        TASKS_STARTED.add(1, counter_labels)
+        if execution.redelivered:
+            TASKS_REDELIVERED.add(1, counter_labels)
+        TASKS_RUNNING.add(1, counter_labels)
+        TASK_PUNCTUALITY.record(punctuality, counter_labels)
+
+        arrow = "↬" if execution.attempt > 1 else "↪"
+        logger.info("%s [%s] %s", arrow, ms(punctuality), call, extra=log_context)
+
+        # Atomically claim task and transition to running state
+        # This also initializes progress and cleans up known/stream_id to allow rescheduling
+        await execution.claim(self.name)
+
+        dependencies: dict[str, Dependency] = {}
+
+        with tracer.start_as_current_span(
+            execution.function.__name__,
+            kind=trace.SpanKind.CONSUMER,
+            attributes={
+                **self.labels(),
+                **execution.specific_labels(),
+                "code.function.name": execution.function.__name__,
+            },
+            links=execution.incoming_span_links(),
+        ) as span:
+            try:
+                async with resolved_dependencies(self, execution) as dependencies:
+                    # Check concurrency limits after dependency resolution
+                    concurrency_limit = get_single_dependency_of_type(
+                        dependencies, ConcurrencyLimit
+                    )
+                    if (
+                        concurrency_limit and not concurrency_limit.is_bypassed
+                    ):  # pragma: no branch - coverage.py on Python 3.10 struggles with this
+                        async with self.docket.redis() as redis:
+                            # Check if we can acquire a concurrency slot
+                            can_start = await self._can_start_task(redis, execution)
+                            if not can_start:  # pragma: no branch - 3.10 failure
+                                # Task cannot start due to concurrency limits
+                                raise ConcurrencyBlocked(execution)
+
+                    dependency_failures = {
+                        k: v
+                        for k, v in dependencies.items()
+                        if isinstance(v, FailedDependency)
+                    }
+                    if dependency_failures:
+                        raise ExceptionGroup(
+                            (
+                                "Failed to resolve dependencies for parameter(s): "
+                                + ", ".join(dependency_failures.keys())
+                            ),
+                            [
+                                dependency.error
+                                for dependency in dependency_failures.values()
+                            ],
+                        )
+
+                    # Apply timeout logic - either user's timeout or redelivery timeout
+                    user_timeout = get_single_dependency_of_type(dependencies, Timeout)
+                    if user_timeout:
+                        # If user timeout is longer than redelivery timeout, limit it
+                        if user_timeout.base > self.redelivery_timeout:
+                            # Create a new timeout limited by redelivery timeout
+                            # Remove the user timeout from dependencies to avoid conflicts
+                            limited_dependencies = {
+                                k: v
+                                for k, v in dependencies.items()
+                                if not isinstance(v, Timeout)
+                            }
+                            limited_timeout = Timeout(self.redelivery_timeout)
+                            limited_timeout.start()
+                            result = await self._run_function_with_timeout(
+                                execution, limited_dependencies, limited_timeout
+                            )
+                        else:
+                            # User timeout is within redelivery timeout, use as-is
+                            result = await self._run_function_with_timeout(
+                                execution, dependencies, user_timeout
+                            )
+                    else:
+                        # No user timeout - apply redelivery timeout as hard limit
+                        redelivery_timeout = Timeout(self.redelivery_timeout)
+                        redelivery_timeout.start()
+                        result = await self._run_function_with_timeout(
+                            execution, dependencies, redelivery_timeout
+                        )
+
+                duration = log_context["duration"] = time.time() - start
+                TASKS_SUCCEEDED.add(1, counter_labels)
+
+                span.set_status(Status(StatusCode.OK))
+
+                rescheduled = await self._perpetuate_if_requested(
+                    execution, dependencies, timedelta(seconds=duration)
+                )
+
+                if rescheduled:
+                    # Task was rescheduled - still mark this execution as completed
+                    # to set TTL on the runs hash (the new execution has its own entry)
+                    await execution.mark_as_completed(result_key=None)
+                else:
+                    # Store result if appropriate
+                    result_key = None
+                    if result is not None and self.docket.execution_ttl:
+                        # Serialize and store result
+                        pickled_result = cloudpickle.dumps(result)  # type: ignore[arg-type]
+                        # Base64-encode for JSON serialization
+                        encoded_result = base64.b64encode(pickled_result).decode(
+                            "ascii"
+                        )
+                        result_key = execution.key
+                        ttl_seconds = int(self.docket.execution_ttl.total_seconds())
+                        await self.docket.result_storage.put(
+                            result_key, {"data": encoded_result}, ttl=ttl_seconds
+                        )
+                    # Mark execution as completed
+                    await execution.mark_as_completed(result_key=result_key)
+
+                arrow = "↫" if rescheduled else "↩"
+                logger.info(
+                    "%s [%s] %s", arrow, ms(duration), call, extra=log_context
+                )
+            except ConcurrencyBlocked:
+                # Re-raise to be handled by process_completed_tasks
+                raise
+            except Exception as e:
+                duration = log_context["duration"] = time.time() - start
+                TASKS_FAILED.add(1, counter_labels)
+
+                span.record_exception(e)
+                span.set_status(Status(StatusCode.ERROR, str(e)))
+
+                retried = await self._retry_if_requested(execution, dependencies)
+                if not retried:
+                    retried = await self._perpetuate_if_requested(
+                        execution, dependencies, timedelta(seconds=duration)
+                    )
+
+                # Store exception in result_storage
+                result_key = None
+                if self.docket.execution_ttl:
+                    pickled_exception = cloudpickle.dumps(e)  # type: ignore[arg-type]
+                    # Base64-encode for JSON serialization
+                    encoded_exception = base64.b64encode(pickled_exception).decode(
+                        "ascii"
+                    )
+                    result_key = execution.key
+                    ttl_seconds = int(self.docket.execution_ttl.total_seconds())
+                    await self.docket.result_storage.put(
+                        result_key, {"data": encoded_exception}, ttl=ttl_seconds
+                    )
+
+                # Mark execution as failed with error message
+                error_msg = f"{type(e).__name__}: {str(e)}"
+                await execution.mark_as_failed(error_msg, result_key=result_key)
+
+                arrow = "↫" if retried else "↩"
+                logger.exception(
+                    "%s [%s] %s", arrow, ms(duration), call, extra=log_context
+                )
+            finally:
+                # Release concurrency slot if we acquired one
+                if dependencies:
+                    concurrency_limit = get_single_dependency_of_type(
+                        dependencies, ConcurrencyLimit
+                    )
+                    if concurrency_limit and not concurrency_limit.is_bypassed:
+                        async with self.docket.redis() as redis:
+                            await self._release_concurrency_slot(redis, execution)
+
+                TASKS_RUNNING.add(-1, counter_labels)
+                TASKS_COMPLETED.add(1, counter_labels)
+                TASK_DURATION.record(duration, counter_labels)
+
+    async def _run_function_with_timeout(
+        self,
+        execution: Execution,
+        dependencies: dict[str, Dependency],
+        timeout: Timeout,
+    ) -> Any:
+        task_coro = cast(
+            Coroutine[None, None, Any],
+            execution.function(
+                *execution.args,
+                **{
+                    **execution.kwargs,
+                    **dependencies,
+                },
+            ),
+        )
+        task = asyncio.create_task(task_coro)
+        try:
+            while not task.done():  # pragma: no branch
+                remaining = timeout.remaining().total_seconds()
+                if timeout.expired():
+                    task.cancel()
+                    break
+
+                try:
+                    result = await asyncio.wait_for(
+                        asyncio.shield(task), timeout=remaining
+                    )
+                    return result
+                except asyncio.TimeoutError:
+                    continue
+        finally:
+            if not task.done():  # pragma: no branch
+                task.cancel()
+
+        try:
+            return await task
+        except asyncio.CancelledError:
+            raise asyncio.TimeoutError
+
+    async def _retry_if_requested(
+        self,
+        execution: Execution,
+        dependencies: dict[str, Dependency],
+    ) -> bool:
+        retry = get_single_dependency_of_type(dependencies, Retry)
+        if not retry:
+            return False
+
+        if retry.attempts is not None and execution.attempt >= retry.attempts:
+            return False
+
+        execution.when = datetime.now(timezone.utc) + retry.delay
+        execution.attempt += 1
+        # Use replace=True since the task is being rescheduled after failure
+        await execution.schedule(replace=True)
+
+        TASKS_RETRIED.add(1, {**self.labels(), **execution.general_labels()})
+        return True
+
+    async def _perpetuate_if_requested(
+        self,
+        execution: Execution,
+        dependencies: dict[str, Dependency],
+        duration: timedelta,
+    ) -> bool:
+        perpetual = get_single_dependency_of_type(dependencies, Perpetual)
+        if not perpetual:
+            return False
+
+        if perpetual.cancelled:
+            await self.docket.cancel(execution.key)
+            return False
+
+        now = datetime.now(timezone.utc)
+        when = max(now, now + perpetual.every - duration)
+
+        await self.docket.replace(execution.function, when, execution.key)(
+            *perpetual.args,
+            **perpetual.kwargs,
+        )
+
+        TASKS_PERPETUATED.add(1, {**self.labels(), **execution.general_labels()})
+
+        return True
+
+    def _startup_log(self) -> None:
+        logger.info("Starting worker %r with the following tasks:", self.name)
+        for task_name, task in self.docket.tasks.items():
+            logger.info("* %s(%s)", task_name, compact_signature(get_signature(task)))
+
+    @property
+    def workers_set(self) -> str:
+        return self.docket.workers_set
+
+    def worker_tasks_set(self, worker_name: str) -> str:
+        return self.docket.worker_tasks_set(worker_name)
+
+    def task_workers_set(self, task_name: str) -> str:
+        return self.docket.task_workers_set(task_name)
+
+    async def _heartbeat(self) -> None:
+        while True:
+            await asyncio.sleep(self.docket.heartbeat_interval.total_seconds())
+            try:
+                now = datetime.now(timezone.utc).timestamp()
+                maximum_age = (
+                    self.docket.heartbeat_interval * self.docket.missed_heartbeats
+                )
+                oldest = now - maximum_age.total_seconds()
+
+                task_names = list(self.docket.tasks)
+
+                async with self.docket.redis() as r:
+                    async with r.pipeline() as pipeline:
+                        pipeline.zremrangebyscore(self.workers_set, 0, oldest)
+                        pipeline.zadd(self.workers_set, {self.name: now})
+
+                        for task_name in task_names:
+                            task_workers_set = self.task_workers_set(task_name)
+                            pipeline.zremrangebyscore(task_workers_set, 0, oldest)
+                            pipeline.zadd(task_workers_set, {self.name: now})
+
+                        pipeline.sadd(self.worker_tasks_set(self.name), *task_names)
+                        pipeline.expire(
+                            self.worker_tasks_set(self.name),
+                            max(maximum_age, timedelta(seconds=1)),
+                        )
+
+                        await pipeline.execute()
+
+                    async with r.pipeline() as pipeline:
+                        pipeline.xlen(self.docket.stream_key)
+                        pipeline.zcount(self.docket.queue_key, 0, now)
+                        pipeline.zcount(self.docket.queue_key, now, "+inf")
+
+                        results: list[int] = await pipeline.execute()
+                        stream_depth = results[0]
+                        overdue_depth = results[1]
+                        schedule_depth = results[2]
+
+                        QUEUE_DEPTH.set(
+                            stream_depth + overdue_depth, self.docket.labels()
+                        )
+                        SCHEDULE_DEPTH.set(schedule_depth, self.docket.labels())
+
+            except asyncio.CancelledError:  # pragma: no cover
+                return
+            except ConnectionError:
+                REDIS_DISRUPTIONS.add(1, self.labels())
+                logger.exception(
+                    "Error sending worker heartbeat",
+                    exc_info=True,
+                    extra=self._log_context(),
+                )
+            except Exception:
+                logger.exception(
+                    "Error sending worker heartbeat",
+                    exc_info=True,
+                    extra=self._log_context(),
+                )
+
+    async def _can_start_task(self, redis: Redis, execution: Execution) -> bool:
+        """Check if a task can start based on concurrency limits."""
+        # Check if task has a concurrency limit dependency
+        concurrency_limit = get_single_dependency_parameter_of_type(
+            execution.function, ConcurrencyLimit
+        )
+
+        if not concurrency_limit:
+            return True  # No concurrency limit, can always start
+
+        # Get the concurrency key for this task
+        try:
+            argument_value = execution.get_argument(concurrency_limit.argument_name)
+        except KeyError:
+            # If argument not found, let the task fail naturally in execution
+            return True
+
+        scope = concurrency_limit.scope or self.docket.name
+        concurrency_key = (
+            f"{scope}:concurrency:{concurrency_limit.argument_name}:{argument_value}"
+        )
+
+        # Use Redis sorted set with timestamps to track concurrency and handle expiration
+        lua_script = """
+        local key = KEYS[1]
+        local max_concurrent = tonumber(ARGV[1])
+        local worker_id = ARGV[2]
+        local task_key = ARGV[3]
+        local current_time = tonumber(ARGV[4])
+        local expiration_time = tonumber(ARGV[5])
+
+        -- Remove expired entries
+        local expired_cutoff = current_time - expiration_time
+        redis.call('ZREMRANGEBYSCORE', key, 0, expired_cutoff)
+
+        -- Get current count
+        local current = redis.call('ZCARD', key)
+
+        if current < max_concurrent then
+            -- Add this worker's task to the sorted set with current timestamp
+            redis.call('ZADD', key, current_time, worker_id .. ':' .. task_key)
+            return 1
+        else
+            return 0
+        end
+        """
+
+        current_time = datetime.now(timezone.utc).timestamp()
+        expiration_seconds = self.redelivery_timeout.total_seconds()
+
+        result = await redis.eval(  # type: ignore
+            lua_script,
+            1,
+            concurrency_key,
+            str(concurrency_limit.max_concurrent),
+            self.name,
+            execution.key,
+            current_time,
+            expiration_seconds,
+        )
+
+        return bool(result)
+
+    async def _release_concurrency_slot(
+        self, redis: Redis, execution: Execution
+    ) -> None:
+        """Release a concurrency slot when task completes."""
+        # Check if task has a concurrency limit dependency
+        concurrency_limit = get_single_dependency_parameter_of_type(
+            execution.function, ConcurrencyLimit
+        )
+
+        if not concurrency_limit:
+            return  # No concurrency limit to release
+
+        # Get the concurrency key for this task
+        try:
+            argument_value = execution.get_argument(concurrency_limit.argument_name)
+        except KeyError:
+            return  # If argument not found, nothing to release
+
+        scope = concurrency_limit.scope or self.docket.name
+        concurrency_key = (
+            f"{scope}:concurrency:{concurrency_limit.argument_name}:{argument_value}"
+        )
+
+        # Remove this worker's task from the sorted set
+        await redis.zrem(concurrency_key, f"{self.name}:{execution.key}")  # type: ignore
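
Editor's aside, not part of the diff: `_can_start_task()` and `_release_concurrency_slot()` above implement bounded concurrency with a Redis sorted set whose scores are timestamps, so slots held by a crashed worker age out after `redelivery_timeout`. A non-atomic sketch of the same pattern in plain redis-py calls; the file does the acquire step in a Lua script precisely because these separate calls would race between the count check and the ZADD.

```python
# Sketch of the timestamp-scored sorted-set slot accounting, for illustration.
import time

from redis.asyncio import Redis


async def try_acquire_slot(
    redis: Redis, key: str, member: str, max_concurrent: int, ttl: float
) -> bool:
    now = time.time()
    await redis.zremrangebyscore(key, 0, now - ttl)  # drop expired slot holders
    if await redis.zcard(key) < max_concurrent:
        await redis.zadd(key, {member: now})  # claim a slot, scored by time
        return True
    return False  # at capacity; caller reschedules and retries later


async def release_slot(redis: Redis, key: str, member: str) -> None:
    await redis.zrem(key, member)  # give the slot back on completion
```
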
+
+
+def ms(seconds: float) -> str:
+    if seconds < 100:
+        return f"{seconds * 1000:6.0f}ms"
+    else:
+        return f"{seconds:6.0f}s "