pydocket-0.15.3-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docket/__init__.py +55 -0
- docket/__main__.py +3 -0
- docket/_uuid7.py +99 -0
- docket/agenda.py +202 -0
- docket/annotations.py +81 -0
- docket/cli.py +1185 -0
- docket/dependencies.py +808 -0
- docket/docket.py +1062 -0
- docket/execution.py +1370 -0
- docket/instrumentation.py +225 -0
- docket/py.typed +0 -0
- docket/tasks.py +59 -0
- docket/testing.py +235 -0
- docket/worker.py +1071 -0
- pydocket-0.15.3.dist-info/METADATA +160 -0
- pydocket-0.15.3.dist-info/RECORD +19 -0
- pydocket-0.15.3.dist-info/WHEEL +4 -0
- pydocket-0.15.3.dist-info/entry_points.txt +2 -0
- pydocket-0.15.3.dist-info/licenses/LICENSE +9 -0
docket/docket.py
ADDED
@@ -0,0 +1,1062 @@
import asyncio
import importlib
import logging
from contextlib import asynccontextmanager
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from types import TracebackType
from typing import (
    AsyncGenerator,
    Awaitable,
    Callable,
    Collection,
    Hashable,
    Iterable,
    Mapping,
    NoReturn,
    ParamSpec,
    Protocol,
    Sequence,
    TypedDict,
    TypeVar,
    cast,
    overload,
)

from key_value.aio.stores.base import BaseContextManagerStore
from typing_extensions import Self

import redis.exceptions
from opentelemetry import trace
from redis.asyncio import ConnectionPool, Redis
from ._uuid7 import uuid7

from .execution import (
    Execution,
    ExecutionState,
    LiteralOperator,
    Operator,
    Restore,
    Strike,
    StrikeInstruction,
    StrikeList,
    TaskFunction,
)
from key_value.aio.protocols.key_value import AsyncKeyValue
from key_value.aio.stores.redis import RedisStore
from key_value.aio.stores.memory import MemoryStore

from .instrumentation import (
    REDIS_DISRUPTIONS,
    STRIKES_IN_EFFECT,
    TASKS_ADDED,
    TASKS_CANCELLED,
    TASKS_REPLACED,
    TASKS_SCHEDULED,
    TASKS_STRICKEN,
)

logger: logging.Logger = logging.getLogger(__name__)
tracer: trace.Tracer = trace.get_tracer(__name__)


class _cancel_task(Protocol):
    async def __call__(
        self, keys: list[str], args: list[str]
    ) -> str: ...  # pragma: no cover


P = ParamSpec("P")
R = TypeVar("R")

TaskCollection = Iterable[TaskFunction]

RedisStreamID = bytes
RedisMessageID = bytes
RedisMessage = dict[bytes, bytes]
RedisMessages = Sequence[tuple[RedisMessageID, RedisMessage]]
RedisStream = tuple[RedisStreamID, RedisMessages]
RedisReadGroupResponse = Sequence[RedisStream]


class RedisStreamPendingMessage(TypedDict):
    message_id: bytes
    consumer: bytes
    time_since_delivered: int
    times_delivered: int


@dataclass
class WorkerInfo:
    name: str
    last_seen: datetime
    tasks: set[str]


class RunningExecution(Execution):
    worker: str
    started: datetime

    def __init__(
        self,
        execution: Execution,
        worker: str,
        started: datetime,
    ) -> None:
        # Call parent constructor to properly initialize immutable fields
        super().__init__(
            docket=execution.docket,
            function=execution.function,
            args=execution.args,
            kwargs=execution.kwargs,
            key=execution.key,
            when=execution.when,
            attempt=execution.attempt,
            trace_context=execution.trace_context,
            redelivered=execution.redelivered,
        )
        # Copy over mutable state fields
        self.state: ExecutionState = execution.state
        self.started_at: datetime | None = execution.started_at
        self.completed_at: datetime | None = execution.completed_at
        self.error: str | None = execution.error
        self.result_key: str | None = execution.result_key
        # Set RunningExecution-specific fields
        self.worker = worker
        self.started = started


@dataclass
class DocketSnapshot:
    taken: datetime
    total_tasks: int
    future: Sequence[Execution]
    running: Sequence[RunningExecution]
    workers: Collection[WorkerInfo]


class Docket:
    """A Docket represents a collection of tasks that may be scheduled for later
    execution. With a Docket, you can add, replace, and cancel tasks.
    Example:

    ```python
    @task
    async def my_task(greeting: str, recipient: str) -> None:
        print(f"{greeting}, {recipient}!")

    async with Docket() as docket:
        docket.add(my_task)("Hello", recipient="world")
    ```
    """

    tasks: dict[str, TaskFunction]
    strike_list: StrikeList

    _monitor_strikes_task: asyncio.Task[None]
    _connection_pool: ConnectionPool
    _cancel_task_script: _cancel_task | None

    def __init__(
        self,
        name: str = "docket",
        url: str = "redis://localhost:6379/0",
        heartbeat_interval: timedelta = timedelta(seconds=2),
        missed_heartbeats: int = 5,
        execution_ttl: timedelta = timedelta(minutes=15),
        result_storage: AsyncKeyValue | None = None,
    ) -> None:
        """
        Args:
            name: The name of the docket.
            url: The URL of the Redis server or in-memory backend. For example:
                - "redis://localhost:6379/0"
                - "redis://user:password@localhost:6379/0"
                - "redis://user:password@localhost:6379/0?ssl=true"
                - "rediss://localhost:6379/0"
                - "unix:///path/to/redis.sock"
                - "memory://" (in-memory backend for testing)
            heartbeat_interval: How often workers send heartbeat messages to the docket.
            missed_heartbeats: How many heartbeats a worker can miss before it is
                considered dead.
            execution_ttl: How long to keep completed or failed execution state records
                in Redis before they expire. Defaults to 15 minutes.
        """
        self.name = name
        self.url = url
        self.heartbeat_interval = heartbeat_interval
        self.missed_heartbeats = missed_heartbeats
        self.execution_ttl = execution_ttl
        self._cancel_task_script = None

        self.result_storage: AsyncKeyValue
        if url.startswith("memory://"):
            self.result_storage = MemoryStore()
        else:
            self.result_storage = RedisStore(
                url=url, default_collection=f"{name}:results"
            )

        from .tasks import standard_tasks

        self.tasks: dict[str, TaskFunction] = {fn.__name__: fn for fn in standard_tasks}
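A minimal construction sketch based on the `__init__` signature above, assuming `Docket` is importable from this module (`docket.docket`). The docket names are illustrative; the `memory://` form needs the `fakeredis` extra and the Redis form needs a reachable server.

```python
import asyncio
from datetime import timedelta

from docket.docket import Docket  # the module shown above


async def main() -> None:
    # In-memory docket for tests; requires the fakeredis extra
    # (pip install pydocket[memory]), per __init__/__aenter__ above.
    async with Docket(name="test-docket", url="memory://") as docket:
        print(docket.name)

    # Against a real Redis (assumes a server on localhost:6379), keeping
    # completed execution records for only five minutes.
    async with Docket(
        name="orders",
        url="redis://localhost:6379/0",
        execution_ttl=timedelta(minutes=5),
    ) as docket:
        print(docket.worker_group_name)


asyncio.run(main())
```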

    @property
    def worker_group_name(self) -> str:
        return "docket-workers"

    async def __aenter__(self) -> Self:
        self.strike_list = StrikeList()

        # Check if we should use in-memory backend (fakeredis)
        # Support memory:// URLs for in-memory dockets
        if self.url.startswith("memory://"):
            try:
                from fakeredis.aioredis import FakeConnection, FakeServer

                # All memory:// URLs share a single FakeServer instance
                # Multiple dockets with different names are isolated by Redis key prefixes
                # (e.g., docket1:stream vs docket2:stream)
                if not hasattr(Docket, "_memory_server"):
                    Docket._memory_server = FakeServer()  # type: ignore

                server = Docket._memory_server  # type: ignore
                self._connection_pool = ConnectionPool(
                    connection_class=FakeConnection, server=server
                )
            except ImportError as e:
                raise ImportError(
                    "fakeredis is required for memory:// URLs. "
                    "Install with: pip install pydocket[memory]"
                ) from e
        else:
            self._connection_pool = ConnectionPool.from_url(self.url)  # type: ignore

        self._monitor_strikes_task = asyncio.create_task(self._monitor_strikes())

        if isinstance(self.result_storage, BaseContextManagerStore):
            await self.result_storage.__aenter__()
        else:
            await self.result_storage.setup()
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> None:
        if isinstance(self.result_storage, BaseContextManagerStore):
            await self.result_storage.__aexit__(exc_type, exc_value, traceback)

        del self.strike_list

        self._monitor_strikes_task.cancel()
        try:
            await self._monitor_strikes_task
        except asyncio.CancelledError:
            pass

        await asyncio.shield(self._connection_pool.disconnect())
        del self._connection_pool

    @asynccontextmanager
    async def redis(self) -> AsyncGenerator[Redis, None]:
        r = Redis(connection_pool=self._connection_pool)
        await r.__aenter__()
        try:
            yield r
        finally:
            await asyncio.shield(r.__aexit__(None, None, None))

    def register(self, function: TaskFunction) -> None:
        """Register a task with the Docket.

        Args:
            function: The task to register.
        """
        from .dependencies import validate_dependencies

        validate_dependencies(function)

        self.tasks[function.__name__] = function

    def register_collection(self, collection_path: str) -> None:
        """
        Register a collection of tasks.

        Args:
            collection_path: A path in the format "module:collection".
        """
        module_name, _, member_name = collection_path.rpartition(":")
        module = importlib.import_module(module_name)
        collection = getattr(module, member_name)
        for function in collection:
            self.register(function)
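A short sketch of the two registration paths above. `send_welcome_email` and the `myapp.tasks:all_tasks` collection path are hypothetical names used only for illustration, assuming an already-entered `Docket` as in the earlier sketch.

```python
from docket.docket import Docket


async def send_welcome_email(user_id: int) -> None:
    # Hypothetical task, for illustration only.
    print(f"welcoming user {user_id}")


async def register_tasks(docket: Docket) -> None:
    # Register a single callable; validate_dependencies() checks it first.
    docket.register(send_welcome_email)

    # Load an iterable of tasks exported by another module, e.g.
    # `all_tasks = [send_welcome_email, ...]` in a hypothetical myapp/tasks.py:
    docket.register_collection("myapp.tasks:all_tasks")
```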

    def labels(self) -> Mapping[str, str]:
        return {
            "docket.name": self.name,
        }

    @overload
    def add(
        self,
        function: Callable[P, Awaitable[R]],
        when: datetime | None = None,
        key: str | None = None,
    ) -> Callable[P, Awaitable[Execution]]:
        """Add a task to the Docket.

        Args:
            function: The task function to add.
            when: The time to schedule the task.
            key: The key to schedule the task under.
        """

    @overload
    def add(
        self,
        function: str,
        when: datetime | None = None,
        key: str | None = None,
    ) -> Callable[..., Awaitable[Execution]]:
        """Add a task to the Docket.

        Args:
            function: The name of a task to add.
            when: The time to schedule the task.
            key: The key to schedule the task under.
        """

    def add(
        self,
        function: Callable[P, Awaitable[R]] | str,
        when: datetime | None = None,
        key: str | None = None,
    ) -> Callable[..., Awaitable[Execution]]:
        """Add a task to the Docket.

        Args:
            function: The task to add.
            when: The time to schedule the task.
            key: The key to schedule the task under.
        """
        if isinstance(function, str):
            function = self.tasks[function]
        else:
            self.register(function)

        if when is None:
            when = datetime.now(timezone.utc)

        if key is None:
            key = str(uuid7())

        async def scheduler(*args: P.args, **kwargs: P.kwargs) -> Execution:
            execution = Execution(self, function, args, kwargs, key, when, attempt=1)

            # Check if task is stricken before scheduling
            if self.strike_list.is_stricken(execution):
                logger.warning(
                    "%r is stricken, skipping schedule of %r",
                    execution.function.__name__,
                    execution.key,
                )
                TASKS_STRICKEN.add(
                    1,
                    {
                        **self.labels(),
                        **execution.general_labels(),
                        "docket.where": "docket",
                    },
                )
                return execution

            # Schedule atomically (includes state record write)
            await execution.schedule(replace=False)

            TASKS_ADDED.add(1, {**self.labels(), **execution.general_labels()})
            TASKS_SCHEDULED.add(1, {**self.labels(), **execution.general_labels()})

            return execution

        return scheduler
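A usage sketch for `add()`: the call is curried, so `docket.add(fn, when=..., key=...)` returns a scheduler that takes the task's own arguments and resolves to an `Execution`. The `greet` task is hypothetical, and an already-entered `Docket` is assumed.

```python
from datetime import datetime, timedelta, timezone

from docket.docket import Docket


async def greet(greeting: str, recipient: str) -> None:
    # Hypothetical task, for illustration only.
    print(f"{greeting}, {recipient}!")


async def schedule_examples(docket: Docket) -> None:
    # Run as soon as a worker picks it up (when defaults to "now").
    execution = await docket.add(greet)("Hello", recipient="world")
    print(execution.key)  # auto-generated UUIDv7 key

    # Run in an hour, under a caller-chosen key so it can later be
    # replaced or cancelled.
    later = datetime.now(timezone.utc) + timedelta(hours=1)
    await docket.add(greet, when=later, key="greet-world")("Hello", recipient="world")
```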

    @overload
    def replace(
        self,
        function: Callable[P, Awaitable[R]],
        when: datetime,
        key: str,
    ) -> Callable[P, Awaitable[Execution]]:
        """Replace a previously scheduled task on the Docket.

        Args:
            function: The task function to replace.
            when: The time to schedule the task.
            key: The key to schedule the task under.
        """

    @overload
    def replace(
        self,
        function: str,
        when: datetime,
        key: str,
    ) -> Callable[..., Awaitable[Execution]]:
        """Replace a previously scheduled task on the Docket.

        Args:
            function: The name of a task to replace.
            when: The time to schedule the task.
            key: The key to schedule the task under.
        """

    def replace(
        self,
        function: Callable[P, Awaitable[R]] | str,
        when: datetime,
        key: str,
    ) -> Callable[..., Awaitable[Execution]]:
        """Replace a previously scheduled task on the Docket.

        Args:
            function: The task to replace.
            when: The time to schedule the task.
            key: The key to schedule the task under.
        """
        if isinstance(function, str):
            function = self.tasks[function]
        else:
            self.register(function)

        async def scheduler(*args: P.args, **kwargs: P.kwargs) -> Execution:
            execution = Execution(self, function, args, kwargs, key, when, attempt=1)

            # Check if task is stricken before scheduling
            if self.strike_list.is_stricken(execution):
                logger.warning(
                    "%r is stricken, skipping schedule of %r",
                    execution.function.__name__,
                    execution.key,
                )
                TASKS_STRICKEN.add(
                    1,
                    {
                        **self.labels(),
                        **execution.general_labels(),
                        "docket.where": "docket",
                    },
                )
                return execution

            # Schedule atomically (includes state record write)
            await execution.schedule(replace=True)

            TASKS_REPLACED.add(1, {**self.labels(), **execution.general_labels()})
            TASKS_CANCELLED.add(1, {**self.labels(), **execution.general_labels()})
            TASKS_SCHEDULED.add(1, {**self.labels(), **execution.general_labels()})

            return execution

        return scheduler
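`replace()` mirrors `add()` but always targets an existing key (and, per the counters above, is recorded as a cancel plus a new schedule). A sketch reusing the hypothetical `greet` task and key from the previous example:

```python
from datetime import datetime, timedelta, timezone

from docket.docket import Docket


async def push_back(docket: Docket) -> None:
    # Reschedule the task previously added under "greet-world" for tomorrow,
    # with fresh arguments. Passing the task by name assumes "greet" was
    # registered earlier (e.g. by a prior add()).
    tomorrow = datetime.now(timezone.utc) + timedelta(days=1)
    await docket.replace("greet", when=tomorrow, key="greet-world")(
        "Good morning", recipient="world"
    )
```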

    async def schedule(self, execution: Execution) -> None:
        with tracer.start_as_current_span(
            "docket.schedule",
            attributes={
                **self.labels(),
                **execution.specific_labels(),
                "code.function.name": execution.function.__name__,
            },
        ):
            # Check if task is stricken before scheduling
            if self.strike_list.is_stricken(execution):
                logger.warning(
                    "%r is stricken, skipping schedule of %r",
                    execution.function.__name__,
                    execution.key,
                )
                TASKS_STRICKEN.add(
                    1,
                    {
                        **self.labels(),
                        **execution.general_labels(),
                        "docket.where": "docket",
                    },
                )
                return

            # Schedule atomically (includes state record write)
            await execution.schedule(replace=False)

            TASKS_SCHEDULED.add(1, {**self.labels(), **execution.general_labels()})

    async def cancel(self, key: str) -> None:
        """Cancel a previously scheduled task on the Docket.

        Args:
            key: The key of the task to cancel.
        """
        with tracer.start_as_current_span(
            "docket.cancel",
            attributes={**self.labels(), "docket.key": key},
        ):
            async with self.redis() as redis:
                await self._cancel(redis, key)

            TASKS_CANCELLED.add(1, self.labels())
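Cancellation is by key, as a sketch of `cancel()` above; the atomic cleanup itself happens in the `_cancel()` helper and its Lua script further down, which also writes a `cancelled` tombstone that lives for `execution_ttl`:

```python
from docket.docket import Docket


async def cancel_greeting(docket: Docket) -> None:
    # Remove the task scheduled under this (hypothetical) key, whether it is
    # still parked in the queue or already in the stream; a "cancelled"
    # tombstone is left in the runs hash until execution_ttl expires.
    await docket.cancel("greet-world")
```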

    async def get_execution(self, key: str) -> Execution | None:
        """Get a task Execution from the Docket by its key.

        Args:
            key: The task key.

        Returns:
            The Execution if found, None if the key doesn't exist.

        Example:
            # Claim check pattern: schedule a task, save the key,
            # then retrieve the execution later to check status or get results
            execution = await docket.add(my_task, key="important-task")(args)
            task_key = execution.key

            # Later, retrieve the execution by key
            execution = await docket.get_execution(task_key)
            if execution:
                await execution.get_result()
        """
        import cloudpickle

        async with self.redis() as redis:
            runs_key = f"{self.name}:runs:{key}"
            data = await redis.hgetall(runs_key)

            if not data:
                return None

            # Extract task definition from runs hash
            function_name = data.get(b"function")
            args_data = data.get(b"args")
            kwargs_data = data.get(b"kwargs")

            # TODO: Remove in next breaking release (v0.14.0) - fallback for 0.13.0 compatibility
            # Check parked hash if runs hash incomplete (0.13.0 didn't store task data in runs hash)
            if not function_name or not args_data or not kwargs_data:
                parked_key = self.parked_task_key(key)
                parked_data = await redis.hgetall(parked_key)
                if parked_data:
                    function_name = parked_data.get(b"function")
                    args_data = parked_data.get(b"args")
                    kwargs_data = parked_data.get(b"kwargs")

            if not function_name or not args_data or not kwargs_data:
                return None

            # Look up function in registry, or create a placeholder if not found
            function_name_str = function_name.decode()
            function = self.tasks.get(function_name_str)
            if not function:
                # Create a placeholder function for display purposes (e.g., CLI watch)
                # This allows viewing task state even if function isn't registered
                async def placeholder() -> None:
                    pass  # pragma: no cover

                placeholder.__name__ = function_name_str
                function = placeholder

            # Deserialize args and kwargs
            args = cloudpickle.loads(args_data)
            kwargs = cloudpickle.loads(kwargs_data)

            # Extract scheduling metadata
            when_str = data.get(b"when")
            if not when_str:
                return None
            when = datetime.fromtimestamp(float(when_str.decode()), tz=timezone.utc)

            # Build execution (attempt defaults to 1 for initial scheduling)
            from docket.execution import Execution

            execution = Execution(
                docket=self,
                function=function,
                args=args,
                kwargs=kwargs,
                key=key,
                when=when,
                attempt=1,
            )

            # Sync with current state from Redis
            await execution.sync()

            return execution

    @property
    def queue_key(self) -> str:
        return f"{self.name}:queue"

    @property
    def stream_key(self) -> str:
        return f"{self.name}:stream"

    def known_task_key(self, key: str) -> str:
        return f"{self.name}:known:{key}"

    def parked_task_key(self, key: str) -> str:
        return f"{self.name}:{key}"

    def stream_id_key(self, key: str) -> str:
        return f"{self.name}:stream-id:{key}"

    async def _ensure_stream_and_group(self) -> None:
        """Create stream and consumer group if they don't exist (idempotent).

        This is safe to call from multiple workers racing to initialize - the
        BUSYGROUP error is silently ignored since it just means another worker
        created the group first.
        """
        try:
            async with self.redis() as r:
                await r.xgroup_create(
                    groupname=self.worker_group_name,
                    name=self.stream_key,
                    id="0-0",
                    mkstream=True,
                )
        except redis.exceptions.ResponseError as e:
            if "BUSYGROUP" not in str(e):
                raise  # pragma: no cover

    async def _cancel(self, redis: Redis, key: str) -> None:
        """Cancel a task atomically.

        Handles cancellation regardless of task location:
        - From the stream (using stored message ID)
        - From the queue (scheduled tasks)
        - Cleans up all associated metadata keys
        """
        if self._cancel_task_script is None:
            self._cancel_task_script = cast(
                _cancel_task,
                redis.register_script(
                    # KEYS: stream_key, known_key, parked_key, queue_key, stream_id_key, runs_key
                    # ARGV: task_key, completed_at
                    """
                    local stream_key = KEYS[1]
                    -- TODO: Remove in next breaking release (v0.14.0) - legacy key locations
                    local known_key = KEYS[2]
                    local parked_key = KEYS[3]
                    local queue_key = KEYS[4]
                    local stream_id_key = KEYS[5]
                    local runs_key = KEYS[6]
                    local task_key = ARGV[1]
                    local completed_at = ARGV[2]

                    -- Get stream ID (check new location first, then legacy)
                    local message_id = redis.call('HGET', runs_key, 'stream_id')

                    -- TODO: Remove in next breaking release (v0.14.0) - check legacy location
                    if not message_id then
                        message_id = redis.call('GET', stream_id_key)
                    end

                    -- Delete from stream if message ID exists
                    if message_id then
                        redis.call('XDEL', stream_key, message_id)
                    end

                    -- Clean up legacy keys and parked data
                    redis.call('DEL', known_key, parked_key, stream_id_key)
                    redis.call('ZREM', queue_key, task_key)

                    -- Create tombstone: set CANCELLED state with completed_at timestamp
                    redis.call('HSET', runs_key, 'state', 'cancelled', 'completed_at', completed_at)

                    return 'OK'
                    """
                ),
            )
        cancel_task = self._cancel_task_script

        # Create tombstone with CANCELLED state
        completed_at = datetime.now(timezone.utc).isoformat()
        runs_key = f"{self.name}:runs:{key}"

        # Execute the cancellation script
        await cancel_task(
            keys=[
                self.stream_key,
                self.known_task_key(key),
                self.parked_task_key(key),
                self.queue_key,
                self.stream_id_key(key),
                runs_key,
            ],
            args=[key, completed_at],
        )

        # Apply TTL or delete tombstone based on execution_ttl
        if self.execution_ttl:
            ttl_seconds = int(self.execution_ttl.total_seconds())
            await redis.expire(runs_key, ttl_seconds)
        else:
            # execution_ttl=0 means no observability - delete tombstone immediately
            await redis.delete(runs_key)

    @property
    def strike_key(self) -> str:
        return f"{self.name}:strikes"

    async def strike(
        self,
        function: Callable[P, Awaitable[R]] | str | None = None,
        parameter: str | None = None,
        operator: Operator | LiteralOperator = "==",
        value: Hashable | None = None,
    ) -> None:
        """Strike a task from the Docket.

        Args:
            function: The task to strike.
            parameter: The parameter to strike on.
            operator: The operator to use.
            value: The value to strike on.
        """
        if not isinstance(function, (str, type(None))):
            function = function.__name__

        operator = Operator(operator)

        strike = Strike(function, parameter, operator, value)
        return await self._send_strike_instruction(strike)

    async def restore(
        self,
        function: Callable[P, Awaitable[R]] | str | None = None,
        parameter: str | None = None,
        operator: Operator | LiteralOperator = "==",
        value: Hashable | None = None,
    ) -> None:
        """Restore a previously stricken task to the Docket.

        Args:
            function: The task to restore.
            parameter: The parameter to restore on.
            operator: The operator to use.
            value: The value to restore on.
        """
        if not isinstance(function, (str, type(None))):
            function = function.__name__

        operator = Operator(operator)

        restore = Restore(function, parameter, operator, value)
        return await self._send_strike_instruction(restore)
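A sketch of the strike/restore API above: strikes can suppress a whole task by name or only executions matching a parameter condition, and `restore()` lifts them. The task name and `customer_id` parameter are hypothetical:

```python
from docket.docket import Docket


async def pause_and_resume(docket: Docket) -> None:
    # Stop scheduling any "greet" tasks at all (hypothetical task name).
    await docket.strike("greet")

    # Or strike only executions whose customer_id argument equals 42.
    await docket.strike("greet", "customer_id", "==", 42)

    # Later, lift both strikes again.
    await docket.restore("greet", "customer_id", "==", 42)
    await docket.restore("greet")
```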

    async def _send_strike_instruction(self, instruction: StrikeInstruction) -> None:
        with tracer.start_as_current_span(
            f"docket.{instruction.direction}",
            attributes={
                **self.labels(),
                **instruction.labels(),
            },
        ):
            async with self.redis() as redis:
                message = instruction.as_message()
                await redis.xadd(self.strike_key, message)  # type: ignore[arg-type]
            self.strike_list.update(instruction)

    async def _monitor_strikes(self) -> NoReturn:
        last_id = "0-0"
        while True:
            try:
                async with self.redis() as r:
                    while True:
                        streams: RedisReadGroupResponse = await r.xread(
                            {self.strike_key: last_id},
                            count=100,
                            block=60_000,
                        )
                        for _, messages in streams:
                            for message_id, message in messages:
                                last_id = message_id
                                instruction = StrikeInstruction.from_message(message)
                                self.strike_list.update(instruction)
                                logger.info(
                                    "%s %r",
                                    (
                                        "Striking"
                                        if instruction.direction == "strike"
                                        else "Restoring"
                                    ),
                                    instruction.call_repr(),
                                    extra=self.labels(),
                                )

                                STRIKES_IN_EFFECT.add(
                                    1 if instruction.direction == "strike" else -1,
                                    {
                                        **self.labels(),
                                        **instruction.labels(),
                                    },
                                )

            except redis.exceptions.ConnectionError:  # pragma: no cover
                REDIS_DISRUPTIONS.add(1, {"docket": self.name})
                logger.warning("Connection error, sleeping for 1 second...")
                await asyncio.sleep(1)
            except Exception:  # pragma: no cover
                logger.exception("Error monitoring strikes")
                await asyncio.sleep(1)

    async def snapshot(self) -> DocketSnapshot:
        """Get a snapshot of the Docket, including which tasks are scheduled or currently
        running, as well as which workers are active.

        Returns:
            A snapshot of the Docket.
        """
        # For memory:// URLs (fakeredis), ensure the group exists upfront. This
        # avoids a fakeredis bug where xpending_range raises TypeError instead
        # of NOGROUP when the consumer group doesn't exist.
        if self.url.startswith("memory://"):
            await self._ensure_stream_and_group()

        running: list[RunningExecution] = []
        future: list[Execution] = []

        async with self.redis() as r:
            async with r.pipeline() as pipeline:
                pipeline.xlen(self.stream_key)

                pipeline.zcard(self.queue_key)

                pipeline.xpending_range(
                    self.stream_key,
                    self.worker_group_name,
                    min="-",
                    max="+",
                    count=1000,
                )

                pipeline.xrange(self.stream_key, "-", "+", count=1000)

                pipeline.zrange(self.queue_key, 0, -1)

                total_stream_messages: int
                total_schedule_messages: int
                pending_messages: list[RedisStreamPendingMessage]
                stream_messages: list[tuple[RedisMessageID, RedisMessage]]
                scheduled_task_keys: list[bytes]

                now = datetime.now(timezone.utc)
                try:
                    (
                        total_stream_messages,
                        total_schedule_messages,
                        pending_messages,
                        stream_messages,
                        scheduled_task_keys,
                    ) = await pipeline.execute()
                except redis.exceptions.ResponseError as e:
                    # Check for NOGROUP error. Also check for XPENDING because
                    # redis-py 7.0 has a bug where pipeline errors lose the
                    # original NOGROUP message (shows "{exception.args}" instead).
                    error_str = str(e)
                    if "NOGROUP" in error_str or "XPENDING" in error_str:
                        await self._ensure_stream_and_group()
                        return await self.snapshot()
                    raise  # pragma: no cover

                for task_key in scheduled_task_keys:
                    pipeline.hgetall(self.parked_task_key(task_key.decode()))

                # Because these are two separate pipeline commands, it's possible that
                # a message has been moved from the schedule to the stream in the
                # meantime, which would end up being an empty `{}` message
                queued_messages: list[RedisMessage] = [
                    m for m in await pipeline.execute() if m
                ]

            total_tasks = total_stream_messages + total_schedule_messages

            pending_lookup: dict[RedisMessageID, RedisStreamPendingMessage] = {
                pending["message_id"]: pending for pending in pending_messages
            }

            for message_id, message in stream_messages:
                execution = await Execution.from_message(self, message)
                if message_id in pending_lookup:
                    worker_name = pending_lookup[message_id]["consumer"].decode()
                    started = now - timedelta(
                        milliseconds=pending_lookup[message_id]["time_since_delivered"]
                    )
                    running.append(RunningExecution(execution, worker_name, started))
                else:
                    future.append(execution)  # pragma: no cover

            for message in queued_messages:
                execution = await Execution.from_message(self, message)
                future.append(execution)

        workers = await self.workers()

        return DocketSnapshot(now, total_tasks, future, running, workers)
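A sketch of consuming the `DocketSnapshot` returned by `snapshot()`, using only fields defined on `DocketSnapshot`, `RunningExecution`, and `WorkerInfo` earlier in this file:

```python
from docket.docket import Docket


async def report(docket: Docket) -> None:
    snapshot = await docket.snapshot()
    print(f"{snapshot.total_tasks} task(s) at {snapshot.taken:%H:%M:%S}")

    for execution in snapshot.running:
        print(f"  running {execution.key} on {execution.worker} since {execution.started}")

    for execution in snapshot.future:
        print(f"  scheduled {execution.key} for {execution.when}")

    for worker in snapshot.workers:
        print(f"  worker {worker.name}, last seen {worker.last_seen}")
```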

    @property
    def workers_set(self) -> str:
        return f"{self.name}:workers"

    def worker_tasks_set(self, worker_name: str) -> str:
        return f"{self.name}:worker-tasks:{worker_name}"

    def task_workers_set(self, task_name: str) -> str:
        return f"{self.name}:task-workers:{task_name}"

    async def workers(self) -> Collection[WorkerInfo]:
        """Get a list of all workers that have sent heartbeats to the Docket.

        Returns:
            A list of all workers that have sent heartbeats to the Docket.
        """
        workers: list[WorkerInfo] = []

        oldest = datetime.now(timezone.utc).timestamp() - (
            self.heartbeat_interval.total_seconds() * self.missed_heartbeats
        )

        async with self.redis() as r:
            await r.zremrangebyscore(self.workers_set, 0, oldest)

            worker_name_bytes: bytes
            last_seen_timestamp: float

            for worker_name_bytes, last_seen_timestamp in await r.zrange(
                self.workers_set, 0, -1, withscores=True
            ):
                worker_name = worker_name_bytes.decode()
                last_seen = datetime.fromtimestamp(last_seen_timestamp, timezone.utc)

                task_names: set[str] = {
                    task_name_bytes.decode()
                    for task_name_bytes in cast(
                        set[bytes], await r.smembers(self.worker_tasks_set(worker_name))
                    )
                }

                workers.append(WorkerInfo(worker_name, last_seen, task_names))

        return workers

    async def task_workers(self, task_name: str) -> Collection[WorkerInfo]:
        """Get a list of all workers that are able to execute a given task.

        Args:
            task_name: The name of the task.

        Returns:
            A list of all workers that are able to execute the given task.
        """
        workers: list[WorkerInfo] = []
        oldest = datetime.now(timezone.utc).timestamp() - (
            self.heartbeat_interval.total_seconds() * self.missed_heartbeats
        )

        async with self.redis() as r:
            await r.zremrangebyscore(self.task_workers_set(task_name), 0, oldest)

            worker_name_bytes: bytes
            last_seen_timestamp: float

            for worker_name_bytes, last_seen_timestamp in await r.zrange(
                self.task_workers_set(task_name), 0, -1, withscores=True
            ):
                worker_name = worker_name_bytes.decode()
                last_seen = datetime.fromtimestamp(last_seen_timestamp, timezone.utc)

                task_names: set[str] = {
                    task_name_bytes.decode()
                    for task_name_bytes in cast(
                        set[bytes], await r.smembers(self.worker_tasks_set(worker_name))
                    )
                }

                workers.append(WorkerInfo(worker_name, last_seen, task_names))

        return workers
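A sketch of the worker-introspection methods above, for example to verify that at least one live worker can run a given task; the `"greet"` task name carries over from the earlier hypothetical examples:

```python
from docket.docket import Docket


async def check_capacity(docket: Docket) -> None:
    everyone = await docket.workers()
    print(f"{len(everyone)} worker(s) have sent a recent heartbeat")

    greeters = await docket.task_workers("greet")  # hypothetical task name
    if not greeters:
        print("no live worker is registered for 'greet'")
    for worker in greeters:
        print(f"{worker.name} can run: {sorted(worker.tasks)}")
```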

    async def clear(self) -> int:
        """Clear all queued and scheduled tasks from the docket.

        This removes all tasks from the stream (immediate tasks) and queue
        (scheduled tasks), along with their associated parked data. Running
        tasks are not affected.

        Returns:
            The total number of tasks that were cleared.
        """
        with tracer.start_as_current_span(
            "docket.clear",
            attributes=self.labels(),
        ):
            async with self.redis() as redis:
                async with redis.pipeline() as pipeline:
                    # Get counts before clearing
                    pipeline.xlen(self.stream_key)
                    pipeline.zcard(self.queue_key)
                    pipeline.zrange(self.queue_key, 0, -1)

                    stream_count: int
                    queue_count: int
                    scheduled_keys: list[bytes]
                    stream_count, queue_count, scheduled_keys = await pipeline.execute()

                # Get keys from stream messages before trimming
                stream_keys: list[str] = []
                if stream_count > 0:
                    # Read all messages from the stream
                    messages = await redis.xrange(self.stream_key, "-", "+")
                    for message_id, fields in messages:
                        # Extract the key field from the message
                        if b"key" in fields:  # pragma: no branch
                            stream_keys.append(fields[b"key"].decode())

                async with redis.pipeline() as pipeline:
                    # Clear all data
                    # Trim stream to 0 messages instead of deleting it to preserve consumer group
                    if stream_count > 0:
                        pipeline.xtrim(self.stream_key, maxlen=0, approximate=False)
                    pipeline.delete(self.queue_key)

                    # Clear parked task data and known task keys for scheduled tasks
                    for key_bytes in scheduled_keys:
                        key = key_bytes.decode()
                        pipeline.delete(self.parked_task_key(key))
                        pipeline.delete(self.known_task_key(key))
                        pipeline.delete(self.stream_id_key(key))

                        # Handle runs hash: set TTL or delete based on execution_ttl
                        runs_key = f"{self.name}:runs:{key}"
                        if self.execution_ttl:
                            ttl_seconds = int(self.execution_ttl.total_seconds())
                            pipeline.expire(runs_key, ttl_seconds)
                        else:
                            pipeline.delete(runs_key)

                    # Handle runs hash for immediate tasks from stream
                    for key in stream_keys:
                        runs_key = f"{self.name}:runs:{key}"
                        if self.execution_ttl:
                            ttl_seconds = int(self.execution_ttl.total_seconds())
                            pipeline.expire(runs_key, ttl_seconds)
                        else:
                            pipeline.delete(runs_key)

                    await pipeline.execute()

            total_cleared = stream_count + queue_count
            return total_cleared
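Finally, a sketch of `clear()`, which is convenient between tests: it empties the stream and the scheduled-task queue (running tasks are untouched) and returns how many tasks were dropped:

```python
from docket.docket import Docket


async def reset(docket: Docket) -> None:
    cleared = await docket.clear()
    print(f"cleared {cleared} queued or scheduled task(s)")
```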