pydocket 0.15.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
docket/docket.py ADDED
@@ -0,0 +1,1062 @@
1
+ import asyncio
2
+ import importlib
3
+ import logging
4
+ from contextlib import asynccontextmanager
5
+ from dataclasses import dataclass
6
+ from datetime import datetime, timedelta, timezone
7
+ from types import TracebackType
8
+ from typing import (
9
+ AsyncGenerator,
10
+ Awaitable,
11
+ Callable,
12
+ Collection,
13
+ Hashable,
14
+ Iterable,
15
+ Mapping,
16
+ NoReturn,
17
+ ParamSpec,
18
+ Protocol,
19
+ Sequence,
20
+ TypedDict,
21
+ TypeVar,
22
+ cast,
23
+ overload,
24
+ )
25
+
26
+ from key_value.aio.stores.base import BaseContextManagerStore
27
+ from typing_extensions import Self
28
+
29
+ import redis.exceptions
30
+ from opentelemetry import trace
31
+ from redis.asyncio import ConnectionPool, Redis
32
+ from ._uuid7 import uuid7
33
+
34
+ from .execution import (
35
+ Execution,
36
+ ExecutionState,
37
+ LiteralOperator,
38
+ Operator,
39
+ Restore,
40
+ Strike,
41
+ StrikeInstruction,
42
+ StrikeList,
43
+ TaskFunction,
44
+ )
45
+ from key_value.aio.protocols.key_value import AsyncKeyValue
46
+ from key_value.aio.stores.redis import RedisStore
47
+ from key_value.aio.stores.memory import MemoryStore
48
+
49
+ from .instrumentation import (
50
+ REDIS_DISRUPTIONS,
51
+ STRIKES_IN_EFFECT,
52
+ TASKS_ADDED,
53
+ TASKS_CANCELLED,
54
+ TASKS_REPLACED,
55
+ TASKS_SCHEDULED,
56
+ TASKS_STRICKEN,
57
+ )
58
+
59
# Module-level logger, named after this module.
logger: logging.Logger = logging.getLogger(__name__)
# OpenTelemetry tracer used for the "docket.*" spans started in this module.
tracer: trace.Tracer = trace.get_tracer(__name__)
61
+
62
+
63
class _cancel_task(Protocol):
    """Call signature of the registered Redis script used to cancel a task.

    The callable is awaited with the script's ``keys`` and ``args`` and
    resolves to the script's string reply.
    """

    async def __call__(
        self, keys: list[str], args: list[str]
    ) -> str: ...  # pragma: no cover
67
+
68
+
69
# Generic placeholders for task callables: P captures a task's parameters and
# R the type its awaitable resolves to.
P = ParamSpec("P")
R = TypeVar("R")

# Any iterable of task functions.
TaskCollection = Iterable[TaskFunction]

# Aliases describing the raw, bytes-based structures returned by Redis stream
# commands (IDs, field maps, and the nested per-stream response).
RedisStreamID = bytes
RedisMessageID = bytes
RedisMessage = dict[bytes, bytes]
RedisMessages = Sequence[tuple[RedisMessageID, RedisMessage]]
RedisStream = tuple[RedisStreamID, RedisMessages]
RedisReadGroupResponse = Sequence[RedisStream]
80
+
81
+
82
class RedisStreamPendingMessage(TypedDict):
    """One pending-entry record of a stream consumer group.

    ``Docket.snapshot`` reads ``message_id``, ``consumer`` and
    ``time_since_delivered`` (which it treats as milliseconds) from these.
    """

    message_id: bytes
    consumer: bytes
    time_since_delivered: int
    times_delivered: int
87
+
88
+
89
@dataclass
class WorkerInfo:
    """A worker known to the docket, as returned by ``Docket.workers()``."""

    # Name of the worker.
    name: str
    # UTC time derived from the worker's score in the docket's worker set.
    last_seen: datetime
    # Names of the tasks recorded for this worker.
    tasks: set[str]
94
+
95
+
96
class RunningExecution(Execution):
    """An ``Execution`` annotated with the worker running it and its start time."""

    worker: str
    started: datetime

    def __init__(
        self,
        execution: Execution,
        worker: str,
        started: datetime,
    ) -> None:
        """Build a running view of an existing execution.

        Args:
            execution: The execution whose definition and state are copied.
            worker: Name of the worker the execution is attributed to.
            started: When the execution started running.
        """
        # The definition fields go through the parent constructor so they are
        # initialized exactly as they were on the source execution.
        definition = dict(
            docket=execution.docket,
            function=execution.function,
            args=execution.args,
            kwargs=execution.kwargs,
            key=execution.key,
            when=execution.when,
            attempt=execution.attempt,
            trace_context=execution.trace_context,
            redelivered=execution.redelivered,
        )
        super().__init__(**definition)

        # Mutable state is not a constructor argument, so mirror it by hand.
        self.state: ExecutionState = execution.state
        self.started_at: datetime | None = execution.started_at
        self.completed_at: datetime | None = execution.completed_at
        self.error: str | None = execution.error
        self.result_key: str | None = execution.result_key

        # Fields that only exist on a running execution.
        self.worker = worker
        self.started = started
127
+
128
+
129
@dataclass
class DocketSnapshot:
    """Point-in-time view of a docket, as built by ``Docket.snapshot()``."""

    # When the snapshot was taken (UTC).
    taken: datetime
    # Stream length plus scheduled-queue size at the time of the snapshot.
    total_tasks: int
    # Executions that are not currently pending on a worker.
    future: Sequence[Execution]
    # Executions currently pending on a worker.
    running: Sequence[RunningExecution]
    # Workers returned by ``Docket.workers()`` at snapshot time.
    workers: Collection[WorkerInfo]
136
+
137
+
138
class Docket:
    """A Docket represents a collection of tasks that may be scheduled for later
    execution.  With a Docket, you can add, replace, and cancel tasks.

    Example:

    ```python
    @task
    async def my_task(greeting: str, recipient: str) -> None:
        print(f"{greeting}, {recipient}!")

    async with Docket() as docket:
        await docket.add(my_task)("Hello", recipient="world")
    ```
    """

    # Registered task functions, keyed by function name.
    tasks: dict[str, TaskFunction]
    # Strike/restore rules known to this docket; created in __aenter__.
    strike_list: StrikeList

    # Background task that follows the strike stream; created in __aenter__.
    _monitor_strikes_task: asyncio.Task[None]
    # Connection pool behind every client handed out by redis().
    _connection_pool: ConnectionPool
    # Cancellation script, registered lazily on the first _cancel() call.
    _cancel_task_script: _cancel_task | None
159
+
160
+ def __init__(
161
+ self,
162
+ name: str = "docket",
163
+ url: str = "redis://localhost:6379/0",
164
+ heartbeat_interval: timedelta = timedelta(seconds=2),
165
+ missed_heartbeats: int = 5,
166
+ execution_ttl: timedelta = timedelta(minutes=15),
167
+ result_storage: AsyncKeyValue | None = None,
168
+ ) -> None:
169
+ """
170
+ Args:
171
+ name: The name of the docket.
172
+ url: The URL of the Redis server or in-memory backend. For example:
173
+ - "redis://localhost:6379/0"
174
+ - "redis://user:password@localhost:6379/0"
175
+ - "redis://user:password@localhost:6379/0?ssl=true"
176
+ - "rediss://localhost:6379/0"
177
+ - "unix:///path/to/redis.sock"
178
+ - "memory://" (in-memory backend for testing)
179
+ heartbeat_interval: How often workers send heartbeat messages to the docket.
180
+ missed_heartbeats: How many heartbeats a worker can miss before it is
181
+ considered dead.
182
+ execution_ttl: How long to keep completed or failed execution state records
183
+ in Redis before they expire. Defaults to 15 minutes.
184
+ """
185
+ self.name = name
186
+ self.url = url
187
+ self.heartbeat_interval = heartbeat_interval
188
+ self.missed_heartbeats = missed_heartbeats
189
+ self.execution_ttl = execution_ttl
190
+ self._cancel_task_script = None
191
+
192
+ self.result_storage: AsyncKeyValue
193
+ if url.startswith("memory://"):
194
+ self.result_storage = MemoryStore()
195
+ else:
196
+ self.result_storage = RedisStore(
197
+ url=url, default_collection=f"{name}:results"
198
+ )
199
+
200
+ from .tasks import standard_tasks
201
+
202
+ self.tasks: dict[str, TaskFunction] = {fn.__name__: fn for fn in standard_tasks}
203
+
204
+ @property
205
+ def worker_group_name(self) -> str:
206
+ return "docket-workers"
207
+
208
+ async def __aenter__(self) -> Self:
209
+ self.strike_list = StrikeList()
210
+
211
+ # Check if we should use in-memory backend (fakeredis)
212
+ # Support memory:// URLs for in-memory dockets
213
+ if self.url.startswith("memory://"):
214
+ try:
215
+ from fakeredis.aioredis import FakeConnection, FakeServer
216
+
217
+ # All memory:// URLs share a single FakeServer instance
218
+ # Multiple dockets with different names are isolated by Redis key prefixes
219
+ # (e.g., docket1:stream vs docket2:stream)
220
+ if not hasattr(Docket, "_memory_server"):
221
+ Docket._memory_server = FakeServer() # type: ignore
222
+
223
+ server = Docket._memory_server # type: ignore
224
+ self._connection_pool = ConnectionPool(
225
+ connection_class=FakeConnection, server=server
226
+ )
227
+ except ImportError as e:
228
+ raise ImportError(
229
+ "fakeredis is required for memory:// URLs. "
230
+ "Install with: pip install pydocket[memory]"
231
+ ) from e
232
+ else:
233
+ self._connection_pool = ConnectionPool.from_url(self.url) # type: ignore
234
+
235
+ self._monitor_strikes_task = asyncio.create_task(self._monitor_strikes())
236
+
237
+ if isinstance(self.result_storage, BaseContextManagerStore):
238
+ await self.result_storage.__aenter__()
239
+ else:
240
+ await self.result_storage.setup()
241
+ return self
242
+
243
+ async def __aexit__(
244
+ self,
245
+ exc_type: type[BaseException] | None,
246
+ exc_value: BaseException | None,
247
+ traceback: TracebackType | None,
248
+ ) -> None:
249
+ if isinstance(self.result_storage, BaseContextManagerStore):
250
+ await self.result_storage.__aexit__(exc_type, exc_value, traceback)
251
+
252
+ del self.strike_list
253
+
254
+ self._monitor_strikes_task.cancel()
255
+ try:
256
+ await self._monitor_strikes_task
257
+ except asyncio.CancelledError:
258
+ pass
259
+
260
+ await asyncio.shield(self._connection_pool.disconnect())
261
+ del self._connection_pool
262
+
263
+ @asynccontextmanager
264
+ async def redis(self) -> AsyncGenerator[Redis, None]:
265
+ r = Redis(connection_pool=self._connection_pool)
266
+ await r.__aenter__()
267
+ try:
268
+ yield r
269
+ finally:
270
+ await asyncio.shield(r.__aexit__(None, None, None))
271
+
272
+ def register(self, function: TaskFunction) -> None:
273
+ """Register a task with the Docket.
274
+
275
+ Args:
276
+ function: The task to register.
277
+ """
278
+ from .dependencies import validate_dependencies
279
+
280
+ validate_dependencies(function)
281
+
282
+ self.tasks[function.__name__] = function
283
+
284
+ def register_collection(self, collection_path: str) -> None:
285
+ """
286
+ Register a collection of tasks.
287
+
288
+ Args:
289
+ collection_path: A path in the format "module:collection".
290
+ """
291
+ module_name, _, member_name = collection_path.rpartition(":")
292
+ module = importlib.import_module(module_name)
293
+ collection = getattr(module, member_name)
294
+ for function in collection:
295
+ self.register(function)
296
+
297
+ def labels(self) -> Mapping[str, str]:
298
+ return {
299
+ "docket.name": self.name,
300
+ }
301
+
302
+ @overload
303
+ def add(
304
+ self,
305
+ function: Callable[P, Awaitable[R]],
306
+ when: datetime | None = None,
307
+ key: str | None = None,
308
+ ) -> Callable[P, Awaitable[Execution]]:
309
+ """Add a task to the Docket.
310
+
311
+ Args:
312
+ function: The task function to add.
313
+ when: The time to schedule the task.
314
+ key: The key to schedule the task under.
315
+ """
316
+
317
+ @overload
318
+ def add(
319
+ self,
320
+ function: str,
321
+ when: datetime | None = None,
322
+ key: str | None = None,
323
+ ) -> Callable[..., Awaitable[Execution]]:
324
+ """Add a task to the Docket.
325
+
326
+ Args:
327
+ function: The name of a task to add.
328
+ when: The time to schedule the task.
329
+ key: The key to schedule the task under.
330
+ """
331
+
332
+ def add(
333
+ self,
334
+ function: Callable[P, Awaitable[R]] | str,
335
+ when: datetime | None = None,
336
+ key: str | None = None,
337
+ ) -> Callable[..., Awaitable[Execution]]:
338
+ """Add a task to the Docket.
339
+
340
+ Args:
341
+ function: The task to add.
342
+ when: The time to schedule the task.
343
+ key: The key to schedule the task under.
344
+ """
345
+ if isinstance(function, str):
346
+ function = self.tasks[function]
347
+ else:
348
+ self.register(function)
349
+
350
+ if when is None:
351
+ when = datetime.now(timezone.utc)
352
+
353
+ if key is None:
354
+ key = str(uuid7())
355
+
356
+ async def scheduler(*args: P.args, **kwargs: P.kwargs) -> Execution:
357
+ execution = Execution(self, function, args, kwargs, key, when, attempt=1)
358
+
359
+ # Check if task is stricken before scheduling
360
+ if self.strike_list.is_stricken(execution):
361
+ logger.warning(
362
+ "%r is stricken, skipping schedule of %r",
363
+ execution.function.__name__,
364
+ execution.key,
365
+ )
366
+ TASKS_STRICKEN.add(
367
+ 1,
368
+ {
369
+ **self.labels(),
370
+ **execution.general_labels(),
371
+ "docket.where": "docket",
372
+ },
373
+ )
374
+ return execution
375
+
376
+ # Schedule atomically (includes state record write)
377
+ await execution.schedule(replace=False)
378
+
379
+ TASKS_ADDED.add(1, {**self.labels(), **execution.general_labels()})
380
+ TASKS_SCHEDULED.add(1, {**self.labels(), **execution.general_labels()})
381
+
382
+ return execution
383
+
384
+ return scheduler
385
+
386
+ @overload
387
+ def replace(
388
+ self,
389
+ function: Callable[P, Awaitable[R]],
390
+ when: datetime,
391
+ key: str,
392
+ ) -> Callable[P, Awaitable[Execution]]:
393
+ """Replace a previously scheduled task on the Docket.
394
+
395
+ Args:
396
+ function: The task function to replace.
397
+ when: The time to schedule the task.
398
+ key: The key to schedule the task under.
399
+ """
400
+
401
+ @overload
402
+ def replace(
403
+ self,
404
+ function: str,
405
+ when: datetime,
406
+ key: str,
407
+ ) -> Callable[..., Awaitable[Execution]]:
408
+ """Replace a previously scheduled task on the Docket.
409
+
410
+ Args:
411
+ function: The name of a task to replace.
412
+ when: The time to schedule the task.
413
+ key: The key to schedule the task under.
414
+ """
415
+
416
+ def replace(
417
+ self,
418
+ function: Callable[P, Awaitable[R]] | str,
419
+ when: datetime,
420
+ key: str,
421
+ ) -> Callable[..., Awaitable[Execution]]:
422
+ """Replace a previously scheduled task on the Docket.
423
+
424
+ Args:
425
+ function: The task to replace.
426
+ when: The time to schedule the task.
427
+ key: The key to schedule the task under.
428
+ """
429
+ if isinstance(function, str):
430
+ function = self.tasks[function]
431
+ else:
432
+ self.register(function)
433
+
434
+ async def scheduler(*args: P.args, **kwargs: P.kwargs) -> Execution:
435
+ execution = Execution(self, function, args, kwargs, key, when, attempt=1)
436
+
437
+ # Check if task is stricken before scheduling
438
+ if self.strike_list.is_stricken(execution):
439
+ logger.warning(
440
+ "%r is stricken, skipping schedule of %r",
441
+ execution.function.__name__,
442
+ execution.key,
443
+ )
444
+ TASKS_STRICKEN.add(
445
+ 1,
446
+ {
447
+ **self.labels(),
448
+ **execution.general_labels(),
449
+ "docket.where": "docket",
450
+ },
451
+ )
452
+ return execution
453
+
454
+ # Schedule atomically (includes state record write)
455
+ await execution.schedule(replace=True)
456
+
457
+ TASKS_REPLACED.add(1, {**self.labels(), **execution.general_labels()})
458
+ TASKS_CANCELLED.add(1, {**self.labels(), **execution.general_labels()})
459
+ TASKS_SCHEDULED.add(1, {**self.labels(), **execution.general_labels()})
460
+
461
+ return execution
462
+
463
+ return scheduler
464
+
465
+ async def schedule(self, execution: Execution) -> None:
466
+ with tracer.start_as_current_span(
467
+ "docket.schedule",
468
+ attributes={
469
+ **self.labels(),
470
+ **execution.specific_labels(),
471
+ "code.function.name": execution.function.__name__,
472
+ },
473
+ ):
474
+ # Check if task is stricken before scheduling
475
+ if self.strike_list.is_stricken(execution):
476
+ logger.warning(
477
+ "%r is stricken, skipping schedule of %r",
478
+ execution.function.__name__,
479
+ execution.key,
480
+ )
481
+ TASKS_STRICKEN.add(
482
+ 1,
483
+ {
484
+ **self.labels(),
485
+ **execution.general_labels(),
486
+ "docket.where": "docket",
487
+ },
488
+ )
489
+ return
490
+
491
+ # Schedule atomically (includes state record write)
492
+ await execution.schedule(replace=False)
493
+
494
+ TASKS_SCHEDULED.add(1, {**self.labels(), **execution.general_labels()})
495
+
496
+ async def cancel(self, key: str) -> None:
497
+ """Cancel a previously scheduled task on the Docket.
498
+
499
+ Args:
500
+ key: The key of the task to cancel.
501
+ """
502
+ with tracer.start_as_current_span(
503
+ "docket.cancel",
504
+ attributes={**self.labels(), "docket.key": key},
505
+ ):
506
+ async with self.redis() as redis:
507
+ await self._cancel(redis, key)
508
+
509
+ TASKS_CANCELLED.add(1, self.labels())
510
+
511
+ async def get_execution(self, key: str) -> Execution | None:
512
+ """Get a task Execution from the Docket by its key.
513
+
514
+ Args:
515
+ key: The task key.
516
+
517
+ Returns:
518
+ The Execution if found, None if the key doesn't exist.
519
+
520
+ Example:
521
+ # Claim check pattern: schedule a task, save the key,
522
+ # then retrieve the execution later to check status or get results
523
+ execution = await docket.add(my_task, key="important-task")(args)
524
+ task_key = execution.key
525
+
526
+ # Later, retrieve the execution by key
527
+ execution = await docket.get_execution(task_key)
528
+ if execution:
529
+ await execution.get_result()
530
+ """
531
+ import cloudpickle
532
+
533
+ async with self.redis() as redis:
534
+ runs_key = f"{self.name}:runs:{key}"
535
+ data = await redis.hgetall(runs_key)
536
+
537
+ if not data:
538
+ return None
539
+
540
+ # Extract task definition from runs hash
541
+ function_name = data.get(b"function")
542
+ args_data = data.get(b"args")
543
+ kwargs_data = data.get(b"kwargs")
544
+
545
+ # TODO: Remove in next breaking release (v0.14.0) - fallback for 0.13.0 compatibility
546
+ # Check parked hash if runs hash incomplete (0.13.0 didn't store task data in runs hash)
547
+ if not function_name or not args_data or not kwargs_data:
548
+ parked_key = self.parked_task_key(key)
549
+ parked_data = await redis.hgetall(parked_key)
550
+ if parked_data:
551
+ function_name = parked_data.get(b"function")
552
+ args_data = parked_data.get(b"args")
553
+ kwargs_data = parked_data.get(b"kwargs")
554
+
555
+ if not function_name or not args_data or not kwargs_data:
556
+ return None
557
+
558
+ # Look up function in registry, or create a placeholder if not found
559
+ function_name_str = function_name.decode()
560
+ function = self.tasks.get(function_name_str)
561
+ if not function:
562
+ # Create a placeholder function for display purposes (e.g., CLI watch)
563
+ # This allows viewing task state even if function isn't registered
564
+ async def placeholder() -> None:
565
+ pass # pragma: no cover
566
+
567
+ placeholder.__name__ = function_name_str
568
+ function = placeholder
569
+
570
+ # Deserialize args and kwargs
571
+ args = cloudpickle.loads(args_data)
572
+ kwargs = cloudpickle.loads(kwargs_data)
573
+
574
+ # Extract scheduling metadata
575
+ when_str = data.get(b"when")
576
+ if not when_str:
577
+ return None
578
+ when = datetime.fromtimestamp(float(when_str.decode()), tz=timezone.utc)
579
+
580
+ # Build execution (attempt defaults to 1 for initial scheduling)
581
+ from docket.execution import Execution
582
+
583
+ execution = Execution(
584
+ docket=self,
585
+ function=function,
586
+ args=args,
587
+ kwargs=kwargs,
588
+ key=key,
589
+ when=when,
590
+ attempt=1,
591
+ )
592
+
593
+ # Sync with current state from Redis
594
+ await execution.sync()
595
+
596
+ return execution
597
+
598
+ @property
599
+ def queue_key(self) -> str:
600
+ return f"{self.name}:queue"
601
+
602
+ @property
603
+ def stream_key(self) -> str:
604
+ return f"{self.name}:stream"
605
+
606
+ def known_task_key(self, key: str) -> str:
607
+ return f"{self.name}:known:{key}"
608
+
609
+ def parked_task_key(self, key: str) -> str:
610
+ return f"{self.name}:{key}"
611
+
612
+ def stream_id_key(self, key: str) -> str:
613
+ return f"{self.name}:stream-id:{key}"
614
+
615
+ async def _ensure_stream_and_group(self) -> None:
616
+ """Create stream and consumer group if they don't exist (idempotent).
617
+
618
+ This is safe to call from multiple workers racing to initialize - the
619
+ BUSYGROUP error is silently ignored since it just means another worker
620
+ created the group first.
621
+ """
622
+ try:
623
+ async with self.redis() as r:
624
+ await r.xgroup_create(
625
+ groupname=self.worker_group_name,
626
+ name=self.stream_key,
627
+ id="0-0",
628
+ mkstream=True,
629
+ )
630
+ except redis.exceptions.ResponseError as e:
631
+ if "BUSYGROUP" not in str(e):
632
+ raise # pragma: no cover
633
+
634
+ async def _cancel(self, redis: Redis, key: str) -> None:
635
+ """Cancel a task atomically.
636
+
637
+ Handles cancellation regardless of task location:
638
+ - From the stream (using stored message ID)
639
+ - From the queue (scheduled tasks)
640
+ - Cleans up all associated metadata keys
641
+ """
642
+ if self._cancel_task_script is None:
643
+ self._cancel_task_script = cast(
644
+ _cancel_task,
645
+ redis.register_script(
646
+ # KEYS: stream_key, known_key, parked_key, queue_key, stream_id_key, runs_key
647
+ # ARGV: task_key, completed_at
648
+ """
649
+ local stream_key = KEYS[1]
650
+ -- TODO: Remove in next breaking release (v0.14.0) - legacy key locations
651
+ local known_key = KEYS[2]
652
+ local parked_key = KEYS[3]
653
+ local queue_key = KEYS[4]
654
+ local stream_id_key = KEYS[5]
655
+ local runs_key = KEYS[6]
656
+ local task_key = ARGV[1]
657
+ local completed_at = ARGV[2]
658
+
659
+ -- Get stream ID (check new location first, then legacy)
660
+ local message_id = redis.call('HGET', runs_key, 'stream_id')
661
+
662
+ -- TODO: Remove in next breaking release (v0.14.0) - check legacy location
663
+ if not message_id then
664
+ message_id = redis.call('GET', stream_id_key)
665
+ end
666
+
667
+ -- Delete from stream if message ID exists
668
+ if message_id then
669
+ redis.call('XDEL', stream_key, message_id)
670
+ end
671
+
672
+ -- Clean up legacy keys and parked data
673
+ redis.call('DEL', known_key, parked_key, stream_id_key)
674
+ redis.call('ZREM', queue_key, task_key)
675
+
676
+ -- Create tombstone: set CANCELLED state with completed_at timestamp
677
+ redis.call('HSET', runs_key, 'state', 'cancelled', 'completed_at', completed_at)
678
+
679
+ return 'OK'
680
+ """
681
+ ),
682
+ )
683
+ cancel_task = self._cancel_task_script
684
+
685
+ # Create tombstone with CANCELLED state
686
+ completed_at = datetime.now(timezone.utc).isoformat()
687
+ runs_key = f"{self.name}:runs:{key}"
688
+
689
+ # Execute the cancellation script
690
+ await cancel_task(
691
+ keys=[
692
+ self.stream_key,
693
+ self.known_task_key(key),
694
+ self.parked_task_key(key),
695
+ self.queue_key,
696
+ self.stream_id_key(key),
697
+ runs_key,
698
+ ],
699
+ args=[key, completed_at],
700
+ )
701
+
702
+ # Apply TTL or delete tombstone based on execution_ttl
703
+ if self.execution_ttl:
704
+ ttl_seconds = int(self.execution_ttl.total_seconds())
705
+ await redis.expire(runs_key, ttl_seconds)
706
+ else:
707
+ # execution_ttl=0 means no observability - delete tombstone immediately
708
+ await redis.delete(runs_key)
709
+
710
+ @property
711
+ def strike_key(self) -> str:
712
+ return f"{self.name}:strikes"
713
+
714
+ async def strike(
715
+ self,
716
+ function: Callable[P, Awaitable[R]] | str | None = None,
717
+ parameter: str | None = None,
718
+ operator: Operator | LiteralOperator = "==",
719
+ value: Hashable | None = None,
720
+ ) -> None:
721
+ """Strike a task from the Docket.
722
+
723
+ Args:
724
+ function: The task to strike.
725
+ parameter: The parameter to strike on.
726
+ operator: The operator to use.
727
+ value: The value to strike on.
728
+ """
729
+ if not isinstance(function, (str, type(None))):
730
+ function = function.__name__
731
+
732
+ operator = Operator(operator)
733
+
734
+ strike = Strike(function, parameter, operator, value)
735
+ return await self._send_strike_instruction(strike)
736
+
737
+ async def restore(
738
+ self,
739
+ function: Callable[P, Awaitable[R]] | str | None = None,
740
+ parameter: str | None = None,
741
+ operator: Operator | LiteralOperator = "==",
742
+ value: Hashable | None = None,
743
+ ) -> None:
744
+ """Restore a previously stricken task to the Docket.
745
+
746
+ Args:
747
+ function: The task to restore.
748
+ parameter: The parameter to restore on.
749
+ operator: The operator to use.
750
+ value: The value to restore on.
751
+ """
752
+ if not isinstance(function, (str, type(None))):
753
+ function = function.__name__
754
+
755
+ operator = Operator(operator)
756
+
757
+ restore = Restore(function, parameter, operator, value)
758
+ return await self._send_strike_instruction(restore)
759
+
760
+ async def _send_strike_instruction(self, instruction: StrikeInstruction) -> None:
761
+ with tracer.start_as_current_span(
762
+ f"docket.{instruction.direction}",
763
+ attributes={
764
+ **self.labels(),
765
+ **instruction.labels(),
766
+ },
767
+ ):
768
+ async with self.redis() as redis:
769
+ message = instruction.as_message()
770
+ await redis.xadd(self.strike_key, message) # type: ignore[arg-type]
771
+ self.strike_list.update(instruction)
772
+
773
+ async def _monitor_strikes(self) -> NoReturn:
774
+ last_id = "0-0"
775
+ while True:
776
+ try:
777
+ async with self.redis() as r:
778
+ while True:
779
+ streams: RedisReadGroupResponse = await r.xread(
780
+ {self.strike_key: last_id},
781
+ count=100,
782
+ block=60_000,
783
+ )
784
+ for _, messages in streams:
785
+ for message_id, message in messages:
786
+ last_id = message_id
787
+ instruction = StrikeInstruction.from_message(message)
788
+ self.strike_list.update(instruction)
789
+ logger.info(
790
+ "%s %r",
791
+ (
792
+ "Striking"
793
+ if instruction.direction == "strike"
794
+ else "Restoring"
795
+ ),
796
+ instruction.call_repr(),
797
+ extra=self.labels(),
798
+ )
799
+
800
+ STRIKES_IN_EFFECT.add(
801
+ 1 if instruction.direction == "strike" else -1,
802
+ {
803
+ **self.labels(),
804
+ **instruction.labels(),
805
+ },
806
+ )
807
+
808
+ except redis.exceptions.ConnectionError: # pragma: no cover
809
+ REDIS_DISRUPTIONS.add(1, {"docket": self.name})
810
+ logger.warning("Connection error, sleeping for 1 second...")
811
+ await asyncio.sleep(1)
812
+ except Exception: # pragma: no cover
813
+ logger.exception("Error monitoring strikes")
814
+ await asyncio.sleep(1)
815
+
816
+ async def snapshot(self) -> DocketSnapshot:
817
+ """Get a snapshot of the Docket, including which tasks are scheduled or currently
818
+ running, as well as which workers are active.
819
+
820
+ Returns:
821
+ A snapshot of the Docket.
822
+ """
823
+ # For memory:// URLs (fakeredis), ensure the group exists upfront. This
824
+ # avoids a fakeredis bug where xpending_range raises TypeError instead
825
+ # of NOGROUP when the consumer group doesn't exist.
826
+ if self.url.startswith("memory://"):
827
+ await self._ensure_stream_and_group()
828
+
829
+ running: list[RunningExecution] = []
830
+ future: list[Execution] = []
831
+
832
+ async with self.redis() as r:
833
+ async with r.pipeline() as pipeline:
834
+ pipeline.xlen(self.stream_key)
835
+
836
+ pipeline.zcard(self.queue_key)
837
+
838
+ pipeline.xpending_range(
839
+ self.stream_key,
840
+ self.worker_group_name,
841
+ min="-",
842
+ max="+",
843
+ count=1000,
844
+ )
845
+
846
+ pipeline.xrange(self.stream_key, "-", "+", count=1000)
847
+
848
+ pipeline.zrange(self.queue_key, 0, -1)
849
+
850
+ total_stream_messages: int
851
+ total_schedule_messages: int
852
+ pending_messages: list[RedisStreamPendingMessage]
853
+ stream_messages: list[tuple[RedisMessageID, RedisMessage]]
854
+ scheduled_task_keys: list[bytes]
855
+
856
+ now = datetime.now(timezone.utc)
857
+ try:
858
+ (
859
+ total_stream_messages,
860
+ total_schedule_messages,
861
+ pending_messages,
862
+ stream_messages,
863
+ scheduled_task_keys,
864
+ ) = await pipeline.execute()
865
+ except redis.exceptions.ResponseError as e:
866
+ # Check for NOGROUP error. Also check for XPENDING because
867
+ # redis-py 7.0 has a bug where pipeline errors lose the
868
+ # original NOGROUP message (shows "{exception.args}" instead).
869
+ error_str = str(e)
870
+ if "NOGROUP" in error_str or "XPENDING" in error_str:
871
+ await self._ensure_stream_and_group()
872
+ return await self.snapshot()
873
+ raise # pragma: no cover
874
+
875
+ for task_key in scheduled_task_keys:
876
+ pipeline.hgetall(self.parked_task_key(task_key.decode()))
877
+
878
+ # Because these are two separate pipeline commands, it's possible that
879
+ # a message has been moved from the schedule to the stream in the
880
+ # meantime, which would end up being an empty `{}` message
881
+ queued_messages: list[RedisMessage] = [
882
+ m for m in await pipeline.execute() if m
883
+ ]
884
+
885
+ total_tasks = total_stream_messages + total_schedule_messages
886
+
887
+ pending_lookup: dict[RedisMessageID, RedisStreamPendingMessage] = {
888
+ pending["message_id"]: pending for pending in pending_messages
889
+ }
890
+
891
+ for message_id, message in stream_messages:
892
+ execution = await Execution.from_message(self, message)
893
+ if message_id in pending_lookup:
894
+ worker_name = pending_lookup[message_id]["consumer"].decode()
895
+ started = now - timedelta(
896
+ milliseconds=pending_lookup[message_id]["time_since_delivered"]
897
+ )
898
+ running.append(RunningExecution(execution, worker_name, started))
899
+ else:
900
+ future.append(execution) # pragma: no cover
901
+
902
+ for message in queued_messages:
903
+ execution = await Execution.from_message(self, message)
904
+ future.append(execution)
905
+
906
+ workers = await self.workers()
907
+
908
+ return DocketSnapshot(now, total_tasks, future, running, workers)
909
+
910
+ @property
911
+ def workers_set(self) -> str:
912
+ return f"{self.name}:workers"
913
+
914
+ def worker_tasks_set(self, worker_name: str) -> str:
915
+ return f"{self.name}:worker-tasks:{worker_name}"
916
+
917
+ def task_workers_set(self, task_name: str) -> str:
918
+ return f"{self.name}:task-workers:{task_name}"
919
+
920
+ async def workers(self) -> Collection[WorkerInfo]:
921
+ """Get a list of all workers that have sent heartbeats to the Docket.
922
+
923
+ Returns:
924
+ A list of all workers that have sent heartbeats to the Docket.
925
+ """
926
+ workers: list[WorkerInfo] = []
927
+
928
+ oldest = datetime.now(timezone.utc).timestamp() - (
929
+ self.heartbeat_interval.total_seconds() * self.missed_heartbeats
930
+ )
931
+
932
+ async with self.redis() as r:
933
+ await r.zremrangebyscore(self.workers_set, 0, oldest)
934
+
935
+ worker_name_bytes: bytes
936
+ last_seen_timestamp: float
937
+
938
+ for worker_name_bytes, last_seen_timestamp in await r.zrange(
939
+ self.workers_set, 0, -1, withscores=True
940
+ ):
941
+ worker_name = worker_name_bytes.decode()
942
+ last_seen = datetime.fromtimestamp(last_seen_timestamp, timezone.utc)
943
+
944
+ task_names: set[str] = {
945
+ task_name_bytes.decode()
946
+ for task_name_bytes in cast(
947
+ set[bytes], await r.smembers(self.worker_tasks_set(worker_name))
948
+ )
949
+ }
950
+
951
+ workers.append(WorkerInfo(worker_name, last_seen, task_names))
952
+
953
+ return workers
954
+
955
+ async def task_workers(self, task_name: str) -> Collection[WorkerInfo]:
956
+ """Get a list of all workers that are able to execute a given task.
957
+
958
+ Args:
959
+ task_name: The name of the task.
960
+
961
+ Returns:
962
+ A list of all workers that are able to execute the given task.
963
+ """
964
+ workers: list[WorkerInfo] = []
965
+ oldest = datetime.now(timezone.utc).timestamp() - (
966
+ self.heartbeat_interval.total_seconds() * self.missed_heartbeats
967
+ )
968
+
969
+ async with self.redis() as r:
970
+ await r.zremrangebyscore(self.task_workers_set(task_name), 0, oldest)
971
+
972
+ worker_name_bytes: bytes
973
+ last_seen_timestamp: float
974
+
975
+ for worker_name_bytes, last_seen_timestamp in await r.zrange(
976
+ self.task_workers_set(task_name), 0, -1, withscores=True
977
+ ):
978
+ worker_name = worker_name_bytes.decode()
979
+ last_seen = datetime.fromtimestamp(last_seen_timestamp, timezone.utc)
980
+
981
+ task_names: set[str] = {
982
+ task_name_bytes.decode()
983
+ for task_name_bytes in cast(
984
+ set[bytes], await r.smembers(self.worker_tasks_set(worker_name))
985
+ )
986
+ }
987
+
988
+ workers.append(WorkerInfo(worker_name, last_seen, task_names))
989
+
990
+ return workers
991
+
992
+ async def clear(self) -> int:
993
+ """Clear all queued and scheduled tasks from the docket.
994
+
995
+ This removes all tasks from the stream (immediate tasks) and queue
996
+ (scheduled tasks), along with their associated parked data. Running
997
+ tasks are not affected.
998
+
999
+ Returns:
1000
+ The total number of tasks that were cleared.
1001
+ """
1002
+ with tracer.start_as_current_span(
1003
+ "docket.clear",
1004
+ attributes=self.labels(),
1005
+ ):
1006
+ async with self.redis() as redis:
1007
+ async with redis.pipeline() as pipeline:
1008
+ # Get counts before clearing
1009
+ pipeline.xlen(self.stream_key)
1010
+ pipeline.zcard(self.queue_key)
1011
+ pipeline.zrange(self.queue_key, 0, -1)
1012
+
1013
+ stream_count: int
1014
+ queue_count: int
1015
+ scheduled_keys: list[bytes]
1016
+ stream_count, queue_count, scheduled_keys = await pipeline.execute()
1017
+
1018
+ # Get keys from stream messages before trimming
1019
+ stream_keys: list[str] = []
1020
+ if stream_count > 0:
1021
+ # Read all messages from the stream
1022
+ messages = await redis.xrange(self.stream_key, "-", "+")
1023
+ for message_id, fields in messages:
1024
+ # Extract the key field from the message
1025
+ if b"key" in fields: # pragma: no branch
1026
+ stream_keys.append(fields[b"key"].decode())
1027
+
1028
+ async with redis.pipeline() as pipeline:
1029
+ # Clear all data
1030
+ # Trim stream to 0 messages instead of deleting it to preserve consumer group
1031
+ if stream_count > 0:
1032
+ pipeline.xtrim(self.stream_key, maxlen=0, approximate=False)
1033
+ pipeline.delete(self.queue_key)
1034
+
1035
+ # Clear parked task data and known task keys for scheduled tasks
1036
+ for key_bytes in scheduled_keys:
1037
+ key = key_bytes.decode()
1038
+ pipeline.delete(self.parked_task_key(key))
1039
+ pipeline.delete(self.known_task_key(key))
1040
+ pipeline.delete(self.stream_id_key(key))
1041
+
1042
+ # Handle runs hash: set TTL or delete based on execution_ttl
1043
+ runs_key = f"{self.name}:runs:{key}"
1044
+ if self.execution_ttl:
1045
+ ttl_seconds = int(self.execution_ttl.total_seconds())
1046
+ pipeline.expire(runs_key, ttl_seconds)
1047
+ else:
1048
+ pipeline.delete(runs_key)
1049
+
1050
+ # Handle runs hash for immediate tasks from stream
1051
+ for key in stream_keys:
1052
+ runs_key = f"{self.name}:runs:{key}"
1053
+ if self.execution_ttl:
1054
+ ttl_seconds = int(self.execution_ttl.total_seconds())
1055
+ pipeline.expire(runs_key, ttl_seconds)
1056
+ else:
1057
+ pipeline.delete(runs_key)
1058
+
1059
+ await pipeline.execute()
1060
+
1061
+ total_cleared = stream_count + queue_count
1062
+ return total_cleared