pydocket-0.15.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
docket/execution.py ADDED
@@ -0,0 +1,1370 @@
+ import abc
+ import asyncio
+ import base64
+ import enum
+ import inspect
+ import json
+ import logging
+ from datetime import datetime, timedelta, timezone
+ from typing import (
+     TYPE_CHECKING,
+     Any,
+     AsyncGenerator,
+     Awaitable,
+     Callable,
+     Hashable,
+     Literal,
+     Mapping,
+     Protocol,
+     TypedDict,
+     cast,
+ )
+
+ import cloudpickle  # type: ignore[import]
+ import opentelemetry.context
+ from opentelemetry import propagate, trace
+ from typing_extensions import Self
+
+ from .annotations import Logged
+ from .instrumentation import CACHE_SIZE, message_getter, message_setter
+
+ if TYPE_CHECKING:
+     from .docket import Docket, RedisMessageID
+
+ logger: logging.Logger = logging.getLogger(__name__)
+
+ TaskFunction = Callable[..., Awaitable[Any]]
+ Message = dict[bytes, bytes]
+
+
+ class _schedule_task(Protocol):
+     async def __call__(
+         self, keys: list[str], args: list[str | float | bytes]
+     ) -> str: ...  # pragma: no cover
+
+
+ _signature_cache: dict[Callable[..., Any], inspect.Signature] = {}
+
+
+ def get_signature(function: Callable[..., Any]) -> inspect.Signature:
+     if function in _signature_cache:
+         CACHE_SIZE.set(len(_signature_cache), {"cache": "signature"})
+         return _signature_cache[function]
+
+     signature_attr = getattr(function, "__signature__", None)
+     if isinstance(signature_attr, inspect.Signature):
+         _signature_cache[function] = signature_attr
+         CACHE_SIZE.set(len(_signature_cache), {"cache": "signature"})
+         return signature_attr
+
+     signature = inspect.signature(function)
+     _signature_cache[function] = signature
+     CACHE_SIZE.set(len(_signature_cache), {"cache": "signature"})
+     return signature
+
+
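As a quick illustration of the cache above (editor's sketch; assumes pydocket is installed so `docket.execution` is importable, and `greet` is a stand-in task):

    import inspect

    from docket.execution import get_signature

    async def greet(name: str, greeting: str = "Hello") -> None: ...

    first = get_signature(greet)  # computed via inspect.signature(), then cached
    second = get_signature(greet)  # served from _signature_cache
    assert first is second and isinstance(first, inspect.Signature)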
+ class ExecutionState(enum.Enum):
+     """Lifecycle states for task execution."""
+
+     SCHEDULED = "scheduled"
+     """Task is scheduled and waiting in the queue for its execution time."""
+
+     QUEUED = "queued"
+     """Task has been moved to the stream and is ready to be claimed by a worker."""
+
+     RUNNING = "running"
+     """Task is currently being executed by a worker."""
+
+     COMPLETED = "completed"
+     """Task execution finished successfully."""
+
+     FAILED = "failed"
+     """Task execution failed."""
+
+     CANCELLED = "cancelled"
+     """Task was explicitly cancelled before completion."""
+
+
+ class ProgressEvent(TypedDict):
+     type: Literal["progress"]
+     key: str
+     current: int | None
+     total: int
+     message: str | None
+     updated_at: str | None
+
+
+ class StateEvent(TypedDict):
+     type: Literal["state"]
+     key: str
+     state: ExecutionState
+     when: str
+     worker: str | None
+     started_at: str | None
+     completed_at: str | None
+     error: str | None
+
+
+ class ExecutionProgress:
+     """Manages user-reported progress for a task execution.
+
+     Progress data is stored in the Redis hash {docket}:progress:{key} and includes:
+     - current: Current progress value (integer)
+     - total: Total/target value (integer)
+     - message: User-provided status message (string)
+     - updated_at: Timestamp of last update (ISO 8601 string)
+
+     This data is ephemeral and deleted when the task completes.
+     """
+
+     def __init__(self, docket: "Docket", key: str) -> None:
+         """Initialize progress tracker for a specific task.
+
+         Args:
+             docket: The docket instance
+             key: The task execution key
+         """
+         self.docket = docket
+         self.key = key
+         self._redis_key = f"{docket.name}:progress:{key}"
+         self.current: int | None = None
+         self.total: int = 1
+         self.message: str | None = None
+         self.updated_at: datetime | None = None
+
+     @classmethod
+     async def create(cls, docket: "Docket", key: str) -> Self:
+         """Create and initialize progress tracker by reading from Redis.
+
+         Args:
+             docket: The docket instance
+             key: The task execution key
+
+         Returns:
+             ExecutionProgress instance with attributes populated from Redis
+         """
+         instance = cls(docket, key)
+         await instance.sync()
+         return instance
+
+     async def set_total(self, total: int) -> None:
+         """Set the total/target value for progress tracking.
+
+         Args:
+             total: The total number of units to complete. Must be at least 1.
+         """
+         if total < 1:
+             raise ValueError("Total must be at least 1")
+
+         updated_at_dt = datetime.now(timezone.utc)
+         updated_at = updated_at_dt.isoformat()
+         async with self.docket.redis() as redis:
+             await redis.hset(
+                 self._redis_key,
+                 mapping={
+                     "total": str(total),
+                     "updated_at": updated_at,
+                 },
+             )
+         # Update instance attributes
+         self.total = total
+         self.updated_at = updated_at_dt
+         # Publish update event
+         await self._publish({"total": total, "updated_at": updated_at})
+
+     async def increment(self, amount: int = 1) -> None:
+         """Atomically increment the current progress value.
+
+         Args:
+             amount: Amount to increment by. Must be at least 1.
+         """
+         if amount < 1:
+             raise ValueError("Amount must be at least 1")
+
+         updated_at_dt = datetime.now(timezone.utc)
+         updated_at = updated_at_dt.isoformat()
+         async with self.docket.redis() as redis:
+             new_current = await redis.hincrby(self._redis_key, "current", amount)
+             await redis.hset(
+                 self._redis_key,
+                 "updated_at",
+                 updated_at,
+             )
+         # Update instance attributes using Redis return value
+         self.current = new_current
+         self.updated_at = updated_at_dt
+         # Publish update event with new current value
+         await self._publish({"current": new_current, "updated_at": updated_at})
+
+     async def set_message(self, message: str | None) -> None:
+         """Update the progress status message.
+
+         Args:
+             message: Status message describing current progress
+         """
+         updated_at_dt = datetime.now(timezone.utc)
+         updated_at = updated_at_dt.isoformat()
+         async with self.docket.redis() as redis:
+             await redis.hset(
+                 self._redis_key,
+                 mapping={
+                     "message": message,
+                     "updated_at": updated_at,
+                 },
+             )
+         # Update instance attributes
+         self.message = message
+         self.updated_at = updated_at_dt
+         # Publish update event
+         await self._publish({"message": message, "updated_at": updated_at})
+
+     async def sync(self) -> None:
+         """Synchronize instance attributes with current progress data from Redis.
+
+         Updates self.current, self.total, self.message, and self.updated_at
+         with values from Redis. Resets them to their defaults (total back to
+         100, the rest to None) if no data exists.
+         """
+         async with self.docket.redis() as redis:
+             data = await redis.hgetall(self._redis_key)
+             if data:
+                 self.current = int(data.get(b"current", b"0"))
+                 self.total = int(data.get(b"total", b"100"))
+                 self.message = data[b"message"].decode() if b"message" in data else None
+                 self.updated_at = (
+                     datetime.fromisoformat(data[b"updated_at"].decode())
+                     if b"updated_at" in data
+                     else None
+                 )
+             else:
+                 self.current = None
+                 self.total = 100
+                 self.message = None
+                 self.updated_at = None
+
+     async def delete(self) -> None:
+         """Delete the progress data from Redis.
+
+         Called internally when task execution completes.
+         """
+         async with self.docket.redis() as redis:
+             await redis.delete(self._redis_key)
+         # Reset instance attributes
+         self.current = None
+         self.total = 100
+         self.message = None
+         self.updated_at = None
+
+     async def _publish(self, data: dict[str, Any]) -> None:
+         """Publish progress update to Redis pub/sub channel.
+
+         Args:
+             data: Progress data to publish (partial update)
+         """
+         channel = f"{self.docket.name}:progress:{self.key}"
+         # Create ephemeral Redis client for publishing
+         async with self.docket.redis() as redis:
+             # Use instance attributes for current state
+             payload: ProgressEvent = {
+                 "type": "progress",
+                 "key": self.key,
+                 "current": self.current if self.current is not None else 0,
+                 "total": self.total,
+                 "message": self.message,
+                 "updated_at": data.get("updated_at"),
+             }
+
+             # Publish JSON payload
+             await redis.publish(channel, json.dumps(payload))
+
+     async def subscribe(self) -> AsyncGenerator[ProgressEvent, None]:
+         """Subscribe to progress updates for this task.
+
+         Yields:
+             Dict containing progress update events with fields:
+             - type: "progress"
+             - key: task key
+             - current: current progress value (or None)
+             - total: total/target value
+             - message: status message (or None)
+             - updated_at: ISO 8601 timestamp
+         """
+         channel = f"{self.docket.name}:progress:{self.key}"
+         async with self.docket.redis() as redis:
+             async with redis.pubsub() as pubsub:
+                 await pubsub.subscribe(channel)
+                 try:
+                     async for message in pubsub.listen():  # pragma: no cover
+                         if message["type"] == "message":
+                             yield json.loads(message["data"])
+                 finally:
+                     # Explicitly unsubscribe to ensure clean shutdown
+                     await pubsub.unsubscribe(channel)
+
+
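A rough usage sketch for the class above, assuming `docket` is an open Docket and the task key is "my-task" (the per-item `process` call and `items` iterable are hypothetical):

    progress = ExecutionProgress(docket, "my-task")
    await progress.set_total(50)
    for item in items:
        await process(item)  # hypothetical unit of work
        await progress.increment()
    await progress.set_message("finished processing items")

Each call updates the {docket}:progress:{key} hash and publishes a ProgressEvent, so another process can follow along with `async for event in progress.subscribe(): ...`.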
+ class Execution:
+     """Represents a task execution with state management and progress tracking.
+
+     Combines task invocation metadata (function, args, when, etc.) with
+     Redis-backed lifecycle state tracking and user-reported progress.
+     """
+
+     def __init__(
+         self,
+         docket: "Docket",
+         function: TaskFunction,
+         args: tuple[Any, ...],
+         kwargs: dict[str, Any],
+         key: str,
+         when: datetime,
+         attempt: int,
+         trace_context: opentelemetry.context.Context | None = None,
+         redelivered: bool = False,
+     ) -> None:
+         # Task definition (immutable)
+         self._docket = docket
+         self._function = function
+         self._args = args
+         self._kwargs = kwargs
+         self._key = key
+
+         # Scheduling metadata
+         self.when = when
+         self.attempt = attempt
+         self._trace_context = trace_context
+         self._redelivered = redelivered
+
+         # Lifecycle state (mutable)
+         self.state: ExecutionState = ExecutionState.SCHEDULED
+         self.worker: str | None = None
+         self.started_at: datetime | None = None
+         self.completed_at: datetime | None = None
+         self.error: str | None = None
+         self.result_key: str | None = None
+
+         # Progress tracking
+         self.progress: ExecutionProgress = ExecutionProgress(docket, key)
+
+         # Redis key
+         self._redis_key = f"{docket.name}:runs:{key}"
+
+     # Task definition properties (immutable)
+     @property
+     def docket(self) -> "Docket":
+         """Parent docket instance."""
+         return self._docket
+
+     @property
+     def function(self) -> TaskFunction:
+         """Task function to execute."""
+         return self._function
+
+     @property
+     def args(self) -> tuple[Any, ...]:
+         """Positional arguments for the task."""
+         return self._args
+
+     @property
+     def kwargs(self) -> dict[str, Any]:
+         """Keyword arguments for the task."""
+         return self._kwargs
+
+     @property
+     def key(self) -> str:
+         """Unique task identifier."""
+         return self._key
+
+     # Scheduling metadata properties
+     @property
+     def trace_context(self) -> opentelemetry.context.Context | None:
+         """OpenTelemetry trace context."""
+         return self._trace_context
+
+     @property
+     def redelivered(self) -> bool:
+         """Whether this message was redelivered."""
+         return self._redelivered
+
+     def as_message(self) -> Message:
+         return {
+             b"key": self.key.encode(),
+             b"when": self.when.isoformat().encode(),
+             b"function": self.function.__name__.encode(),
+             b"args": cloudpickle.dumps(self.args),  # type: ignore[arg-type]
+             b"kwargs": cloudpickle.dumps(self.kwargs),  # type: ignore[arg-type]
+             b"attempt": str(self.attempt).encode(),
+         }
+
+     @classmethod
+     async def from_message(
+         cls, docket: "Docket", message: Message, redelivered: bool = False
+     ) -> Self:
+         function_name = message[b"function"].decode()
+         if not (function := docket.tasks.get(function_name)):
+             raise ValueError(
+                 f"Task function {function_name!r} is not registered with the current docket"
+             )
+
+         instance = cls(
+             docket=docket,
+             function=function,
+             args=cloudpickle.loads(message[b"args"]),
+             kwargs=cloudpickle.loads(message[b"kwargs"]),
+             key=message[b"key"].decode(),
+             when=datetime.fromisoformat(message[b"when"].decode()),
+             attempt=int(message[b"attempt"].decode()),
+             trace_context=propagate.extract(message, getter=message_getter),
+             redelivered=redelivered,
+         )
+         await instance.sync()
+         return instance
+
+     def general_labels(self) -> Mapping[str, str]:
+         return {"docket.task": self.function.__name__}
+
+     def specific_labels(self) -> Mapping[str, str | int]:
+         return {
+             "docket.task": self.function.__name__,
+             "docket.key": self.key,
+             "docket.when": self.when.isoformat(),
+             "docket.attempt": self.attempt,
+         }
+
+     def get_argument(self, parameter: str) -> Any:
+         signature = get_signature(self.function)
+         bound_args = signature.bind(*self.args, **self.kwargs)
+         return bound_args.arguments[parameter]
+
+     def call_repr(self) -> str:
+         arguments: list[str] = []
+         function_name = self.function.__name__
+
+         signature = get_signature(self.function)
+         logged_parameters = Logged.annotated_parameters(signature)
+         parameter_names = list(signature.parameters.keys())
+
+         for i, argument in enumerate(self.args[: len(parameter_names)]):
+             parameter_name = parameter_names[i]
+             if logged := logged_parameters.get(parameter_name):
+                 arguments.append(logged.format(argument))
+             else:
+                 arguments.append("...")
+
+         for parameter_name, argument in self.kwargs.items():
+             if logged := logged_parameters.get(parameter_name):
+                 arguments.append(f"{parameter_name}={logged.format(argument)}")
+             else:
+                 arguments.append(f"{parameter_name}=...")
+
+         return f"{function_name}({', '.join(arguments)}){{{self.key}}}"
+
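For intuition about `call_repr`, a hedged sketch (the task is hypothetical; the exact spelling of the `Logged` annotation should be checked against docket.annotations):

    from typing import Annotated

    from docket.annotations import Logged

    async def send_email(customer_id: Annotated[int, Logged()], body: str) -> None: ...

    # For an Execution with args=(42,), kwargs={"body": "hi"}, and key "abc",
    # call_repr() renders roughly: send_email(42, body=...){abc}
    # Parameters annotated with Logged are formatted; all others are redacted to "...".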
+     def incoming_span_links(self) -> list[trace.Link]:
+         initiating_span = trace.get_current_span(self.trace_context)
+         initiating_context = initiating_span.get_span_context()
+         return [trace.Link(initiating_context)] if initiating_context.is_valid else []
+
+     async def schedule(
+         self, replace: bool = False, reschedule_message: "RedisMessageID | None" = None
+     ) -> None:
+         """Schedule this task atomically in Redis.
+
+         This performs an atomic operation that:
+         - Adds the task to the stream (immediate) or queue (future)
+         - Writes the execution state record
+         - Tracks metadata for later cancellation
+
+         Usage patterns:
+         - Normal add: schedule(replace=False)
+         - Replace existing: schedule(replace=True)
+         - Reschedule from stream: schedule(reschedule_message=message_id)
+           This atomically acknowledges and deletes the stream message, then
+           reschedules the task to the queue. Prevents both task loss and
+           duplicate execution when rescheduling tasks (e.g., due to concurrency limits).
+
+         Args:
+             replace: If True, replaces any existing task with the same key.
+                 If False, an existing task with the same key is left in place.
+             reschedule_message: If provided, atomically acknowledges and deletes
+                 this stream message ID before rescheduling the task to the queue.
+                 Used when a task needs to be rescheduled from an active stream message.
+         """
+         message: dict[bytes, bytes] = self.as_message()
+         propagate.inject(message, setter=message_setter)
+
+         key = self.key
+         when = self.when
+         known_task_key = self.docket.known_task_key(key)
+         is_immediate = when <= datetime.now(timezone.utc)
+
+         async with self.docket.redis() as redis:
+             # Lock per task key to prevent race conditions between concurrent operations
+             async with redis.lock(f"{known_task_key}:lock", timeout=10):
+                 # Register script for this connection (not cached to avoid event loop issues)
+                 schedule_script = cast(
+                     _schedule_task,
+                     redis.register_script(
+                         # KEYS: stream_key, known_key, parked_key, queue_key, stream_id_key, runs_key, worker_group_key
+                         # ARGV: task_key, when_timestamp, is_immediate, replace, reschedule_message_id, ...message_fields
+                         """
+                         local stream_key = KEYS[1]
+                         -- TODO: Remove in next breaking release (v0.14.0) - legacy key locations
+                         local known_key = KEYS[2]
+                         local parked_key = KEYS[3]
+                         local queue_key = KEYS[4]
+                         local stream_id_key = KEYS[5]
+                         local runs_key = KEYS[6]
+                         local worker_group_name = KEYS[7]
+
+                         local task_key = ARGV[1]
+                         local when_timestamp = ARGV[2]
+                         local is_immediate = ARGV[3] == '1'
+                         local replace = ARGV[4] == '1'
+                         local reschedule_message_id = ARGV[5]
+
+                         -- Extract message fields from ARGV[6] onwards
+                         local message = {}
+                         local function_name = nil
+                         local args_data = nil
+                         local kwargs_data = nil
+
+                         for i = 6, #ARGV, 2 do
+                             local field_name = ARGV[i]
+                             local field_value = ARGV[i + 1]
+                             message[#message + 1] = field_name
+                             message[#message + 1] = field_value
+
+                             -- Extract task data fields for runs hash
+                             if field_name == 'function' then
+                                 function_name = field_value
+                             elseif field_name == 'args' then
+                                 args_data = field_value
+                             elseif field_name == 'kwargs' then
+                                 kwargs_data = field_value
+                             end
+                         end
+
+                         -- Handle rescheduling from stream: atomically ACK message and reschedule to queue
+                         -- This prevents both task loss (ACK before reschedule) and duplicate execution
+                         -- (reschedule before ACK with slow reschedule causing redelivery)
+                         if reschedule_message_id ~= '' then
+                             -- Acknowledge and delete the message from the stream
+                             redis.call('XACK', stream_key, worker_group_name, reschedule_message_id)
+                             redis.call('XDEL', stream_key, reschedule_message_id)
+
+                             -- Park task data for future execution
+                             redis.call('HSET', parked_key, unpack(message))
+
+                             -- Add to sorted set queue
+                             redis.call('ZADD', queue_key, when_timestamp, task_key)
+
+                             -- Update state in runs hash (clear stream_id since task is no longer in stream)
+                             redis.call('HSET', runs_key,
+                                 'state', 'scheduled',
+                                 'when', when_timestamp,
+                                 'function', function_name,
+                                 'args', args_data,
+                                 'kwargs', kwargs_data
+                             )
+                             redis.call('HDEL', runs_key, 'stream_id')
+
+                             return 'OK'
+                         end
+
+                         -- Handle replacement: cancel existing task if needed
+                         if replace then
+                             -- Get stream ID from runs hash (check new location first)
+                             local existing_message_id = redis.call('HGET', runs_key, 'stream_id')
+
+                             -- TODO: Remove in next breaking release (v0.14.0) - check legacy location
+                             if not existing_message_id then
+                                 existing_message_id = redis.call('GET', stream_id_key)
+                             end
+
+                             if existing_message_id then
+                                 redis.call('XDEL', stream_key, existing_message_id)
+                             end
+
+                             redis.call('ZREM', queue_key, task_key)
+                             redis.call('DEL', parked_key)
+
+                             -- TODO: Remove in next breaking release (v0.14.0) - clean up legacy keys
+                             redis.call('DEL', known_key, stream_id_key)
+
+                             -- Note: runs_key is updated below, not deleted
+                         else
+                             -- Check if task already exists (check new location first, then legacy)
+                             local known_exists = redis.call('HEXISTS', runs_key, 'known') == 1
+                             if not known_exists then
+                                 -- Check if task is currently running (known field deleted at claim time)
+                                 local state = redis.call('HGET', runs_key, 'state')
+                                 if state == 'running' then
+                                     return 'EXISTS'
+                                 end
+                                 -- TODO: Remove in next breaking release (v0.14.0) - check legacy location
+                                 known_exists = redis.call('EXISTS', known_key) == 1
+                             end
+                             if known_exists then
+                                 return 'EXISTS'
+                             end
+                         end
+
+                         if is_immediate then
+                             -- Add to stream for immediate execution
+                             local message_id = redis.call('XADD', stream_key, '*', unpack(message))
+
+                             -- Store state and metadata in runs hash
+                             redis.call('HSET', runs_key,
+                                 'state', 'queued',
+                                 'when', when_timestamp,
+                                 'known', when_timestamp,
+                                 'stream_id', message_id,
+                                 'function', function_name,
+                                 'args', args_data,
+                                 'kwargs', kwargs_data
+                             )
+                         else
+                             -- Park task data for future execution
+                             redis.call('HSET', parked_key, unpack(message))
+
+                             -- Add to sorted set queue
+                             redis.call('ZADD', queue_key, when_timestamp, task_key)
+
+                             -- Store state and metadata in runs hash
+                             redis.call('HSET', runs_key,
+                                 'state', 'scheduled',
+                                 'when', when_timestamp,
+                                 'known', when_timestamp,
+                                 'function', function_name,
+                                 'args', args_data,
+                                 'kwargs', kwargs_data
+                             )
+                         end
+
+                         return 'OK'
+                         """
+                     ),
+                 )
+
+                 await schedule_script(
+                     keys=[
+                         self.docket.stream_key,
+                         known_task_key,
+                         self.docket.parked_task_key(key),
+                         self.docket.queue_key,
+                         self.docket.stream_id_key(key),
+                         self._redis_key,
+                         self.docket.worker_group_name,
+                     ],
+                     args=[
+                         key,
+                         str(when.timestamp()),
+                         "1" if is_immediate else "0",
+                         "1" if replace else "0",
+                         reschedule_message or b"",
+                         *[
+                             item
+                             for field, value in message.items()
+                             for item in (field, value)
+                         ],
+                     ],
+                 )
+
+         # Update local state based on whether task is immediate, scheduled, or being rescheduled
+         if reschedule_message:
+             # When rescheduling from stream, task is always parked and queued (never immediate)
+             self.state = ExecutionState.SCHEDULED
+             await self._publish_state(
+                 {"state": ExecutionState.SCHEDULED.value, "when": when.isoformat()}
+             )
+         elif is_immediate:
+             self.state = ExecutionState.QUEUED
+             await self._publish_state(
+                 {"state": ExecutionState.QUEUED.value, "when": when.isoformat()}
+             )
+         else:
+             self.state = ExecutionState.SCHEDULED
+             await self._publish_state(
+                 {"state": ExecutionState.SCHEDULED.value, "when": when.isoformat()}
+             )
+
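In practice this method is usually driven by the higher-level Docket API rather than called directly; a minimal sketch, assuming the `Docket.add` interface from the package's README (Redis URL, names, and the `hello` task are placeholders):

    from datetime import datetime, timedelta, timezone

    from docket import Docket

    async def hello(name: str) -> None:
        print(f"hello, {name}")

    async def main() -> None:
        async with Docket(name="example", url="redis://localhost:6379/0") as docket:
            docket.register(hello)
            soon = datetime.now(timezone.utc) + timedelta(minutes=5)
            # Docket.add constructs an Execution and ultimately calls schedule()
            execution = await docket.add(hello, when=soon, key="hello-once")("world")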
+     async def claim(self, worker: str) -> None:
+         """Atomically claim task and transition to RUNNING state.
+
+         This consolidates worker operations when claiming a task into a single
+         atomic Lua script that:
+         - Sets state to RUNNING with worker name and timestamp
+         - Initializes progress tracking (current=0, total=100)
+         - Deletes known/stream_id fields to allow task rescheduling
+         - Cleans up legacy keys for backwards compatibility
+
+         Args:
+             worker: Name of the worker claiming the task
+         """
+         started_at = datetime.now(timezone.utc)
+         started_at_iso = started_at.isoformat()
+
+         async with self.docket.redis() as redis:
+             claim_script = redis.register_script(
+                 # KEYS: runs_key, progress_key, known_key, stream_id_key
+                 # ARGV: worker, started_at_iso
+                 """
+                 local runs_key = KEYS[1]
+                 local progress_key = KEYS[2]
+                 -- TODO: Remove in next breaking release (v0.14.0) - legacy key locations
+                 local known_key = KEYS[3]
+                 local stream_id_key = KEYS[4]
+
+                 local worker = ARGV[1]
+                 local started_at = ARGV[2]
+
+                 -- Update execution state to running
+                 redis.call('HSET', runs_key,
+                     'state', 'running',
+                     'worker', worker,
+                     'started_at', started_at
+                 )
+
+                 -- Initialize progress tracking
+                 redis.call('HSET', progress_key,
+                     'current', '0',
+                     'total', '100'
+                 )
+
+                 -- Delete known/stream_id fields to allow task rescheduling
+                 redis.call('HDEL', runs_key, 'known', 'stream_id')
+
+                 -- TODO: Remove in next breaking release (v0.14.0) - legacy key cleanup
+                 redis.call('DEL', known_key, stream_id_key)
+
+                 return 'OK'
+                 """
+             )
+
+             await claim_script(
+                 keys=[
+                     self._redis_key,  # runs_key
+                     self.progress._redis_key,  # progress_key
+                     f"{self.docket.name}:known:{self.key}",  # legacy known_key
+                     f"{self.docket.name}:stream-id:{self.key}",  # legacy stream_id_key
+                 ],
+                 args=[worker, started_at_iso],
+             )
+
+         # Update local state
+         self.state = ExecutionState.RUNNING
+         self.worker = worker
+         self.started_at = started_at
+         self.progress.current = 0
+         self.progress.total = 100
+
+         # Publish state change event
+         await self._publish_state(
+             {
+                 "state": ExecutionState.RUNNING.value,
+                 "worker": worker,
+                 "started_at": started_at_iso,
+             }
+         )
+
+     async def mark_as_completed(self, result_key: str | None = None) -> None:
+         """Mark task as completed successfully.
+
+         Args:
+             result_key: Optional key where the task result is stored
+
+         Sets TTL on state data (from docket.execution_ttl), or deletes state
+         immediately if execution_ttl is 0. Also deletes progress data.
+         """
+         completed_at = datetime.now(timezone.utc).isoformat()
+         async with self.docket.redis() as redis:
+             mapping: dict[str, str] = {
+                 "state": ExecutionState.COMPLETED.value,
+                 "completed_at": completed_at,
+             }
+             if result_key is not None:
+                 mapping["result_key"] = result_key
+             await redis.hset(
+                 self._redis_key,
+                 mapping=mapping,
+             )
+             # Set TTL from docket configuration, or delete if TTL=0
+             if self.docket.execution_ttl:
+                 ttl_seconds = int(self.docket.execution_ttl.total_seconds())
+                 await redis.expire(self._redis_key, ttl_seconds)
+             else:
+                 await redis.delete(self._redis_key)
+         self.state = ExecutionState.COMPLETED
+         self.result_key = result_key
+         # Delete progress data
+         await self.progress.delete()
+         # Publish state change event
+         await self._publish_state(
+             {"state": ExecutionState.COMPLETED.value, "completed_at": completed_at}
+         )
+
+     async def mark_as_failed(
+         self, error: str | None = None, result_key: str | None = None
+     ) -> None:
+         """Mark task as failed.
+
+         Args:
+             error: Optional error message describing the failure
+             result_key: Optional key where the exception is stored
+
+         Sets TTL on state data (from docket.execution_ttl), or deletes state
+         immediately if execution_ttl is 0. Also deletes progress data.
+         """
+         completed_at = datetime.now(timezone.utc).isoformat()
+         async with self.docket.redis() as redis:
+             mapping = {
+                 "state": ExecutionState.FAILED.value,
+                 "completed_at": completed_at,
+             }
+             if error:
+                 mapping["error"] = error
+             if result_key is not None:
+                 mapping["result_key"] = result_key
+             await redis.hset(self._redis_key, mapping=mapping)
+             # Set TTL from docket configuration, or delete if TTL=0
+             if self.docket.execution_ttl:
+                 ttl_seconds = int(self.docket.execution_ttl.total_seconds())
+                 await redis.expire(self._redis_key, ttl_seconds)
+             else:
+                 await redis.delete(self._redis_key)
+         self.state = ExecutionState.FAILED
+         self.result_key = result_key
+         # Delete progress data
+         await self.progress.delete()
+         # Publish state change event
+         state_data = {
+             "state": ExecutionState.FAILED.value,
+             "completed_at": completed_at,
+         }
+         if error:
+             state_data["error"] = error
+         await self._publish_state(state_data)
+
+     async def get_result(
+         self,
+         *,
+         timeout: timedelta | None = None,
+         deadline: datetime | None = None,
+     ) -> Any:
+         """Retrieve the result of this task execution.
+
+         If the execution is not yet complete, this method will wait using
+         pub/sub for state updates until completion.
+
+         Args:
+             timeout: Optional duration to wait before giving up.
+                 If None and deadline is None, waits indefinitely.
+             deadline: Optional absolute datetime when to stop waiting.
+                 If None and timeout is None, waits indefinitely.
+
+         Returns:
+             The result of the task execution, or None if the task returned None.
+
+         Raises:
+             ValueError: If both timeout and deadline are provided
+             Exception: If the task failed, raises the stored exception
+             TimeoutError: If timeout/deadline is reached before execution completes
+         """
+         # Validate that only one time limit is provided
+         if timeout is not None and deadline is not None:
+             raise ValueError("Cannot specify both timeout and deadline")
+
+         # Convert timeout to deadline if provided
+         if timeout is not None:
+             deadline = datetime.now(timezone.utc) + timeout
+
+         # Wait for execution to complete if not already done
+         if self.state not in (ExecutionState.COMPLETED, ExecutionState.FAILED):
+             # Calculate timeout duration if absolute deadline provided
+             timeout_seconds = None
+             if deadline is not None:
+                 timeout_seconds = (
+                     deadline - datetime.now(timezone.utc)
+                 ).total_seconds()
+                 if timeout_seconds <= 0:
+                     raise TimeoutError(
+                         f"Timeout waiting for execution {self.key} to complete"
+                     )
+
+             try:
+
+                 async def wait_for_completion():
+                     async for event in self.subscribe():  # pragma: no branch
+                         if event["type"] == "state":
+                             state = ExecutionState(event["state"])
+                             if state in (
+                                 ExecutionState.COMPLETED,
+                                 ExecutionState.FAILED,
+                             ):
+                                 # Sync to get latest data including result key
+                                 await self.sync()
+                                 break
+
+                 # Use asyncio.wait_for to enforce timeout
+                 await asyncio.wait_for(wait_for_completion(), timeout=timeout_seconds)
+             except asyncio.TimeoutError:
+                 raise TimeoutError(
+                     f"Timeout waiting for execution {self.key} to complete"
+                 )
+
+         # If failed, retrieve and raise the exception
+         if self.state == ExecutionState.FAILED:
+             if self.result_key:
+                 # Retrieve serialized exception from result_storage
+                 result_data = await self.docket.result_storage.get(self.result_key)
+                 if result_data and "data" in result_data:
+                     # Base64-decode and unpickle
+                     pickled_exception = base64.b64decode(result_data["data"])
+                     exception = cloudpickle.loads(pickled_exception)  # type: ignore[arg-type]
+                     raise exception
+             # If no stored exception, raise a generic error with the error message
+             error_msg = self.error or "Task execution failed"
+             raise Exception(error_msg)
+
+         # If completed successfully, retrieve result if available
+         if self.result_key:
+             result_data = await self.docket.result_storage.get(self.result_key)
+             if result_data is not None and "data" in result_data:
+                 # Base64-decode and unpickle
+                 pickled_result = base64.b64decode(result_data["data"])
+                 return cloudpickle.loads(pickled_result)  # type: ignore[arg-type]
+
+         # No result stored - task returned None
+         return None
+
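A short sketch of waiting on a result with a bounded wait, assuming `execution` is an Execution returned by the Docket API:

    from datetime import timedelta

    try:
        result = await execution.get_result(timeout=timedelta(seconds=30))
    except TimeoutError:
        ...  # still not finished after 30 seconds
    except Exception:
        ...  # the task failed; its stored exception was re-raised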
+     async def sync(self) -> None:
+         """Synchronize instance attributes with current execution data from Redis.
+
+         Updates self.state, execution metadata, and progress data from Redis.
+         Resets attributes to their defaults if no data exists.
+         """
+         async with self.docket.redis() as redis:
+             data = await redis.hgetall(self._redis_key)
+             if data:
+                 # Update state
+                 state_value = data.get(b"state")
+                 if state_value:
+                     if isinstance(state_value, bytes):
+                         state_value = state_value.decode()
+                     self.state = ExecutionState(state_value)
+
+                 # Update metadata
+                 self.worker = data[b"worker"].decode() if b"worker" in data else None
+                 self.started_at = (
+                     datetime.fromisoformat(data[b"started_at"].decode())
+                     if b"started_at" in data
+                     else None
+                 )
+                 self.completed_at = (
+                     datetime.fromisoformat(data[b"completed_at"].decode())
+                     if b"completed_at" in data
+                     else None
+                 )
+                 self.error = data[b"error"].decode() if b"error" in data else None
+                 self.result_key = (
+                     data[b"result_key"].decode() if b"result_key" in data else None
+                 )
+             else:
+                 # No data exists - reset to defaults
+                 self.state = ExecutionState.SCHEDULED
+                 self.worker = None
+                 self.started_at = None
+                 self.completed_at = None
+                 self.error = None
+                 self.result_key = None
+
+         # Sync progress data
+         await self.progress.sync()
+
+     async def _publish_state(self, data: dict) -> None:
+         """Publish state change to Redis pub/sub channel.
+
+         Args:
+             data: State data to publish
+         """
+         channel = f"{self.docket.name}:state:{self.key}"
+         # Create ephemeral Redis client for publishing
+         async with self.docket.redis() as redis:
+             # Build payload with all relevant state information
+             payload = {
+                 "type": "state",
+                 "key": self.key,
+                 **data,
+             }
+             await redis.publish(channel, json.dumps(payload))
+
+     async def subscribe(self) -> AsyncGenerator[StateEvent | ProgressEvent, None]:
+         """Subscribe to both state and progress updates for this task.
+
+         Emits the current state as the first event, then subscribes to real-time
+         state and progress updates via Redis pub/sub.
+
+         Yields:
+             Dict containing state or progress update events with a 'type' field:
+             - For state events: type="state", state, worker, timestamps, error
+             - For progress events: type="progress", current, total, message, updated_at
+         """
+         # First, emit the current state
+         await self.sync()
+
+         # Build initial state event from current attributes
+         initial_state: StateEvent = {
+             "type": "state",
+             "key": self.key,
+             "state": self.state,
+             "when": self.when.isoformat(),
+             "worker": self.worker,
+             "started_at": self.started_at.isoformat() if self.started_at else None,
+             "completed_at": self.completed_at.isoformat()
+             if self.completed_at
+             else None,
+             "error": self.error,
+         }
+
+         yield initial_state
+
+         progress_event: ProgressEvent = {
+             "type": "progress",
+             "key": self.key,
+             "current": self.progress.current,
+             "total": self.progress.total,
+             "message": self.progress.message,
+             "updated_at": self.progress.updated_at.isoformat()
+             if self.progress.updated_at
+             else None,
+         }
+
+         yield progress_event
+
+         # Then subscribe to real-time updates
+         state_channel = f"{self.docket.name}:state:{self.key}"
+         progress_channel = f"{self.docket.name}:progress:{self.key}"
+         async with self.docket.redis() as redis:
+             async with redis.pubsub() as pubsub:
+                 await pubsub.subscribe(state_channel, progress_channel)
+                 try:
+                     async for message in pubsub.listen():  # pragma: no cover
+                         if message["type"] == "message":
+                             message_data = json.loads(message["data"])
+                             if message_data["type"] == "state":
+                                 message_data["state"] = ExecutionState(
+                                     message_data["state"]
+                                 )
+                             yield message_data
+                 finally:
+                     # Explicitly unsubscribe to ensure clean shutdown
+                     await pubsub.unsubscribe(state_channel, progress_channel)
+
+
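A sketch of consuming this stream of events, assuming `execution` as above; the generator runs until the consumer stops it, so break on a terminal state:

    async for event in execution.subscribe():
        if event["type"] == "state":
            print("state:", event["state"].value)
            if event["state"] in (ExecutionState.COMPLETED, ExecutionState.FAILED):
                break
        else:  # event["type"] == "progress"
            print(f"progress: {event['current']}/{event['total']}")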
+ def compact_signature(signature: inspect.Signature) -> str:
+     from .dependencies import Dependency
+
+     parameters: list[str] = []
+     dependencies: int = 0
+
+     for parameter in signature.parameters.values():
+         if isinstance(parameter.default, Dependency):
+             dependencies += 1
+             continue
+
+         parameter_definition = parameter.name
+         if parameter.annotation is not parameter.empty:
+             annotation = parameter.annotation
+             if hasattr(annotation, "__origin__"):
+                 annotation = annotation.__args__[0]
+
+             type_name = getattr(annotation, "__name__", str(annotation))
+             parameter_definition = f"{parameter.name}: {type_name}"
+
+         if parameter.default is not parameter.empty:
+             parameter_definition = f"{parameter_definition} = {parameter.default!r}"
+
+         parameters.append(parameter_definition)
+
+     if dependencies > 0:
+         parameters.append("...")
+
+     return ", ".join(parameters)
+
+
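A hedged sketch of the rendering (the `backfill` task is hypothetical, and the `CurrentDocket` import path is assumed; it stands in for any parameter whose default is a Dependency, which compact_signature collapses into "..."):

    from docket import Docket
    from docket.dependencies import CurrentDocket
    from docket.execution import compact_signature, get_signature

    async def backfill(
        customer_id: int, dry_run: bool = False, docket: Docket = CurrentDocket()
    ) -> None: ...

    print(compact_signature(get_signature(backfill)))
    # -> customer_id: int, dry_run: bool = False, ...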
+ class Operator(str, enum.Enum):
+     EQUAL = "=="
+     NOT_EQUAL = "!="
+     GREATER_THAN = ">"
+     GREATER_THAN_OR_EQUAL = ">="
+     LESS_THAN = "<"
+     LESS_THAN_OR_EQUAL = "<="
+     BETWEEN = "between"
+
+
+ LiteralOperator = Literal["==", "!=", ">", ">=", "<", "<=", "between"]
+
+
+ class StrikeInstruction(abc.ABC):
+     direction: Literal["strike", "restore"]
+     operator: Operator
+
+     def __init__(
+         self,
+         function: str | None,
+         parameter: str | None,
+         operator: Operator,
+         value: Hashable,
+     ) -> None:
+         self.function = function
+         self.parameter = parameter
+         self.operator = operator
+         self.value = value
+
+     def as_message(self) -> Message:
+         message: dict[bytes, bytes] = {b"direction": self.direction.encode()}
+         if self.function:
+             message[b"function"] = self.function.encode()
+         if self.parameter:
+             message[b"parameter"] = self.parameter.encode()
+         message[b"operator"] = self.operator.encode()
+         message[b"value"] = cloudpickle.dumps(self.value)  # type: ignore[arg-type]
+         return message
+
+     @classmethod
+     def from_message(cls, message: Message) -> "StrikeInstruction":
+         direction = cast(Literal["strike", "restore"], message[b"direction"].decode())
+         function = message[b"function"].decode() if b"function" in message else None
+         parameter = message[b"parameter"].decode() if b"parameter" in message else None
+         operator = cast(Operator, message[b"operator"].decode())
+         value = cloudpickle.loads(message[b"value"])
+         if direction == "strike":
+             return Strike(function, parameter, operator, value)
+         else:
+             return Restore(function, parameter, operator, value)
+
+     def labels(self) -> Mapping[str, str]:
+         labels: dict[str, str] = {}
+         if self.function:
+             labels["docket.task"] = self.function
+
+         if self.parameter:
+             labels["docket.parameter"] = self.parameter
+             labels["docket.operator"] = self.operator
+             labels["docket.value"] = repr(self.value)
+
+         return labels
+
+     def call_repr(self) -> str:
+         return (
+             f"{self.function or '*'}"
+             "("
+             f"{self.parameter or '*'}"
+             " "
+             f"{self.operator}"
+             " "
+             f"{repr(self.value) if self.parameter else '*'}"
+             ")"
+         )
+
+
+ class Strike(StrikeInstruction):
+     direction: Literal["strike", "restore"] = "strike"
+
+
+ class Restore(StrikeInstruction):
+     direction: Literal["strike", "restore"] = "restore"
+
+
+ MinimalStrike = tuple[Operator, Hashable]
+ ParameterStrikes = dict[str, set[MinimalStrike]]
+ TaskStrikes = dict[str, ParameterStrikes]
+
+
+ class StrikeList:
+     task_strikes: TaskStrikes
+     parameter_strikes: ParameterStrikes
+     _conditions: list[Callable[[Execution], bool]]
+
+     def __init__(self) -> None:
+         self.task_strikes = {}
+         self.parameter_strikes = {}
+         self._conditions = [self._matches_task_or_parameter_strike]
+
+     def add_condition(self, condition: Callable[[Execution], bool]) -> None:
+         """Adds a temporary condition that indicates an execution is stricken."""
+         self._conditions.insert(0, condition)
+
+     def remove_condition(self, condition: Callable[[Execution], bool]) -> None:
+         """Removes a previously added temporary condition."""
+         assert condition is not self._matches_task_or_parameter_strike
+         self._conditions.remove(condition)
+
+     def is_stricken(self, execution: Execution) -> bool:
+         """
+         Checks if an execution is stricken based on task, parameter, or temporary
+         conditions.
+
+         Returns:
+             bool: True if the execution is stricken, False otherwise.
+         """
+         return any(condition(execution) for condition in self._conditions)
+
+     def _matches_task_or_parameter_strike(self, execution: Execution) -> bool:
+         function_name = execution.function.__name__
+
+         # Check if the entire task is stricken (without parameter conditions)
+         task_strikes = self.task_strikes.get(function_name, {})
+         if function_name in self.task_strikes and not task_strikes:
+             return True
+
+         signature = get_signature(execution.function)
+
+         try:
+             bound_args = signature.bind(*execution.args, **execution.kwargs)
+             bound_args.apply_defaults()
+         except TypeError:
+             # If we can't make sense of the arguments, just assume the task is fine
+             return False
+
+         all_arguments = {
+             **bound_args.arguments,
+             **{
+                 k: v
+                 for k, v in execution.kwargs.items()
+                 if k not in bound_args.arguments
+             },
+         }
+
+         for parameter, argument in all_arguments.items():
+             for strike_source in [task_strikes, self.parameter_strikes]:
+                 if parameter not in strike_source:
+                     continue
+
+                 for operator, strike_value in strike_source[parameter]:
+                     if self._is_match(argument, operator, strike_value):
+                         return True
+
+         return False
+
+     def _is_match(self, value: Any, operator: Operator, strike_value: Any) -> bool:
+         """Determines if a value matches a strike condition."""
+         try:
+             match operator:
+                 case "==":
+                     return value == strike_value
+                 case "!=":
+                     return value != strike_value
+                 case ">":
+                     return value > strike_value
+                 case ">=":
+                     return value >= strike_value
+                 case "<":
+                     return value < strike_value
+                 case "<=":
+                     return value <= strike_value
+                 case "between":  # pragma: no branch
+                     lower, upper = strike_value
+                     return lower <= value <= upper
+                 case _:  # pragma: no cover
+                     raise ValueError(f"Unknown operator: {operator}")
+         except (ValueError, TypeError):
+             # If we can't make the comparison due to incompatible types, just log the
+             # error and assume the task is not stricken
+             logger.warning(
+                 "Incompatible type for strike condition: %r %s %r",
+                 strike_value,
+                 operator,
+                 value,
+                 exc_info=True,
+             )
+             return False
+
+     def update(self, instruction: StrikeInstruction) -> None:
+         try:
+             hash(instruction.value)
+         except TypeError:
+             logger.warning(
+                 "Incompatible type for strike condition: %s %r",
+                 instruction.operator,
+                 instruction.value,
+             )
+             return
+
+         if isinstance(instruction, Strike):
+             self._strike(instruction)
+         elif isinstance(instruction, Restore):  # pragma: no branch
+             self._restore(instruction)
+
+     def _strike(self, strike: Strike) -> None:
+         if strike.function and strike.parameter:
+             try:
+                 task_strikes = self.task_strikes[strike.function]
+             except KeyError:
+                 task_strikes = self.task_strikes[strike.function] = {}
+
+             try:
+                 parameter_strikes = task_strikes[strike.parameter]
+             except KeyError:
+                 parameter_strikes = task_strikes[strike.parameter] = set()
+
+             parameter_strikes.add((strike.operator, strike.value))
+
+         elif strike.function:
+             try:
+                 task_strikes = self.task_strikes[strike.function]
+             except KeyError:
+                 task_strikes = self.task_strikes[strike.function] = {}
+
+         elif strike.parameter:  # pragma: no branch
+             try:
+                 parameter_strikes = self.parameter_strikes[strike.parameter]
+             except KeyError:
+                 parameter_strikes = self.parameter_strikes[strike.parameter] = set()
+
+             parameter_strikes.add((strike.operator, strike.value))
+
+     def _restore(self, restore: Restore) -> None:
+         if restore.function and restore.parameter:
+             try:
+                 task_strikes = self.task_strikes[restore.function]
+             except KeyError:
+                 return
+
+             try:
+                 parameter_strikes = task_strikes[restore.parameter]
+             except KeyError:
+                 task_strikes.pop(restore.parameter, None)
+                 return
+
+             try:
+                 parameter_strikes.remove((restore.operator, restore.value))
+             except KeyError:
+                 pass
+
+             if not parameter_strikes:
+                 task_strikes.pop(restore.parameter, None)
+                 if not task_strikes:
+                     self.task_strikes.pop(restore.function, None)
+
+         elif restore.function:
+             try:
+                 task_strikes = self.task_strikes[restore.function]
+             except KeyError:
+                 return
+
+             # If there are no parameter strikes, this was a full task strike
+             if not task_strikes:
+                 self.task_strikes.pop(restore.function, None)
+
+         elif restore.parameter:  # pragma: no branch
+             try:
+                 parameter_strikes = self.parameter_strikes[restore.parameter]
+             except KeyError:
+                 return
+
+             try:
+                 parameter_strikes.remove((restore.operator, restore.value))
+             except KeyError:
+                 pass
+
+             if not parameter_strikes:
+                 self.parameter_strikes.pop(restore.parameter, None)
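Finally, a sketch of how Strike, Restore, and StrikeList fit together (the `charge` task, `docket`, and `when` are stand-ins for a registered task, an open Docket, and a datetime):

    async def charge(customer_id: int) -> None: ...

    strikes = StrikeList()
    execution = Execution(docket, charge, (42,), {}, "charge-42", when, attempt=1)

    # Strike every invocation of charge with customer_id == 42
    strikes.update(Strike("charge", "customer_id", Operator.EQUAL, 42))
    assert strikes.is_stricken(execution)

    # Restoring the same condition lifts the strike
    strikes.update(Restore("charge", "customer_id", Operator.EQUAL, 42))
    assert not strikes.is_stricken(execution)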