pydocket 0.15.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docket/__init__.py +55 -0
- docket/__main__.py +3 -0
- docket/_uuid7.py +99 -0
- docket/agenda.py +202 -0
- docket/annotations.py +81 -0
- docket/cli.py +1185 -0
- docket/dependencies.py +808 -0
- docket/docket.py +1062 -0
- docket/execution.py +1370 -0
- docket/instrumentation.py +225 -0
- docket/py.typed +0 -0
- docket/tasks.py +59 -0
- docket/testing.py +235 -0
- docket/worker.py +1071 -0
- pydocket-0.15.3.dist-info/METADATA +160 -0
- pydocket-0.15.3.dist-info/RECORD +19 -0
- pydocket-0.15.3.dist-info/WHEEL +4 -0
- pydocket-0.15.3.dist-info/entry_points.txt +2 -0
- pydocket-0.15.3.dist-info/licenses/LICENSE +9 -0
docket/execution.py
ADDED
@@ -0,0 +1,1370 @@
import abc
import asyncio
import base64
import enum
import inspect
import json
import logging
from datetime import datetime, timedelta, timezone
from typing import (
    TYPE_CHECKING,
    Any,
    AsyncGenerator,
    Awaitable,
    Callable,
    Hashable,
    Literal,
    Mapping,
    Protocol,
    TypedDict,
    cast,
)

import cloudpickle  # type: ignore[import]
import opentelemetry.context
from opentelemetry import propagate, trace
from typing_extensions import Self

from .annotations import Logged
from .instrumentation import CACHE_SIZE, message_getter, message_setter

if TYPE_CHECKING:
    from .docket import Docket, RedisMessageID

logger: logging.Logger = logging.getLogger(__name__)

TaskFunction = Callable[..., Awaitable[Any]]
Message = dict[bytes, bytes]


class _schedule_task(Protocol):
    async def __call__(
        self, keys: list[str], args: list[str | float | bytes]
    ) -> str: ...  # pragma: no cover


_signature_cache: dict[Callable[..., Any], inspect.Signature] = {}


def get_signature(function: Callable[..., Any]) -> inspect.Signature:
    if function in _signature_cache:
        CACHE_SIZE.set(len(_signature_cache), {"cache": "signature"})
        return _signature_cache[function]

    signature_attr = getattr(function, "__signature__", None)
    if isinstance(signature_attr, inspect.Signature):
        _signature_cache[function] = signature_attr
        CACHE_SIZE.set(len(_signature_cache), {"cache": "signature"})
        return signature_attr

    signature = inspect.signature(function)
    _signature_cache[function] = signature
    CACHE_SIZE.set(len(_signature_cache), {"cache": "signature"})
    return signature
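
# A minimal sketch (not part of the original file) of the cache behavior above:
# repeated lookups for the same callable are served from `_signature_cache`
# rather than re-running `inspect.signature`, which matters because signatures
# are resolved on every strike check and call_repr. `example_task` is hypothetical.
#
#   async def example_task(customer_id: int) -> None: ...
#
#   first = get_signature(example_task)   # computed and cached
#   second = get_signature(example_task)  # cache hit
#   assert first is second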


class ExecutionState(enum.Enum):
    """Lifecycle states for task execution."""

    SCHEDULED = "scheduled"
    """Task is scheduled and waiting in the queue for its execution time."""

    QUEUED = "queued"
    """Task has been moved to the stream and is ready to be claimed by a worker."""

    RUNNING = "running"
    """Task is currently being executed by a worker."""

    COMPLETED = "completed"
    """Task execution finished successfully."""

    FAILED = "failed"
    """Task execution failed."""

    CANCELLED = "cancelled"
    """Task was explicitly cancelled before completion."""


class ProgressEvent(TypedDict):
    type: Literal["progress"]
    key: str
    current: int | None
    total: int
    message: str | None
    updated_at: str | None


class StateEvent(TypedDict):
    type: Literal["state"]
    key: str
    state: ExecutionState
    when: str
    worker: str | None
    started_at: str | None
    completed_at: str | None
    error: str | None
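
# Illustration only (not part of the original file): the JSON payloads published
# for these events by ExecutionProgress._publish and Execution._publish_state
# below look roughly like this (all values hypothetical):
#
#   {"type": "progress", "key": "my-task", "current": 3, "total": 10,
#    "message": "step 3 of 10", "updated_at": "2025-01-01T00:00:00+00:00"}
#
#   {"type": "state", "key": "my-task", "state": "running",
#    "worker": "worker-1", "started_at": "2025-01-01T00:00:00+00:00"}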


class ExecutionProgress:
    """Manages user-reported progress for a task execution.

    Progress data is stored in the Redis hash {docket}:progress:{key} and includes:
    - current: Current progress value (integer)
    - total: Total/target value (integer)
    - message: User-provided status message (string)
    - updated_at: Timestamp of last update (ISO 8601 string)

    This data is ephemeral and deleted when the task completes.
    """

    def __init__(self, docket: "Docket", key: str) -> None:
        """Initialize progress tracker for a specific task.

        Args:
            docket: The docket instance
            key: The task execution key
        """
        self.docket = docket
        self.key = key
        self._redis_key = f"{docket.name}:progress:{key}"
        self.current: int | None = None
        self.total: int = 1
        self.message: str | None = None
        self.updated_at: datetime | None = None

    @classmethod
    async def create(cls, docket: "Docket", key: str) -> Self:
        """Create and initialize progress tracker by reading from Redis.

        Args:
            docket: The docket instance
            key: The task execution key

        Returns:
            ExecutionProgress instance with attributes populated from Redis
        """
        instance = cls(docket, key)
        await instance.sync()
        return instance

    async def set_total(self, total: int) -> None:
        """Set the total/target value for progress tracking.

        Args:
            total: The total number of units to complete. Must be at least 1.
        """
        if total < 1:
            raise ValueError("Total must be at least 1")

        updated_at_dt = datetime.now(timezone.utc)
        updated_at = updated_at_dt.isoformat()
        async with self.docket.redis() as redis:
            await redis.hset(
                self._redis_key,
                mapping={
                    "total": str(total),
                    "updated_at": updated_at,
                },
            )
        # Update instance attributes
        self.total = total
        self.updated_at = updated_at_dt
        # Publish update event
        await self._publish({"total": total, "updated_at": updated_at})

    async def increment(self, amount: int = 1) -> None:
        """Atomically increment the current progress value.

        Args:
            amount: Amount to increment by. Must be at least 1.
        """
        if amount < 1:
            raise ValueError("Amount must be at least 1")

        updated_at_dt = datetime.now(timezone.utc)
        updated_at = updated_at_dt.isoformat()
        async with self.docket.redis() as redis:
            new_current = await redis.hincrby(self._redis_key, "current", amount)
            await redis.hset(
                self._redis_key,
                "updated_at",
                updated_at,
            )
        # Update instance attributes using Redis return value
        self.current = new_current
        self.updated_at = updated_at_dt
        # Publish update event with new current value
        await self._publish({"current": new_current, "updated_at": updated_at})

    async def set_message(self, message: str | None) -> None:
        """Update the progress status message.

        Args:
            message: Status message describing current progress
        """
        updated_at_dt = datetime.now(timezone.utc)
        updated_at = updated_at_dt.isoformat()
        async with self.docket.redis() as redis:
            await redis.hset(
                self._redis_key,
                mapping={
                    "message": message,
                    "updated_at": updated_at,
                },
            )
        # Update instance attributes
        self.message = message
        self.updated_at = updated_at_dt
        # Publish update event
        await self._publish({"message": message, "updated_at": updated_at})

    async def sync(self) -> None:
        """Synchronize instance attributes with current progress data from Redis.

        Updates self.current, self.total, self.message, and self.updated_at
        with values from Redis. Resets attributes to their defaults if no
        data exists.
        """
        async with self.docket.redis() as redis:
            data = await redis.hgetall(self._redis_key)
            if data:
                self.current = int(data.get(b"current", b"0"))
                self.total = int(data.get(b"total", b"100"))
                self.message = data[b"message"].decode() if b"message" in data else None
                self.updated_at = (
                    datetime.fromisoformat(data[b"updated_at"].decode())
                    if b"updated_at" in data
                    else None
                )
            else:
                self.current = None
                self.total = 100
                self.message = None
                self.updated_at = None

    async def delete(self) -> None:
        """Delete the progress data from Redis.

        Called internally when task execution completes.
        """
        async with self.docket.redis() as redis:
            await redis.delete(self._redis_key)
        # Reset instance attributes
        self.current = None
        self.total = 100
        self.message = None
        self.updated_at = None

    async def _publish(self, data: dict[str, Any]) -> None:
        """Publish progress update to Redis pub/sub channel.

        Args:
            data: Progress data to publish (partial update)
        """
        channel = f"{self.docket.name}:progress:{self.key}"
        # Create ephemeral Redis client for publishing
        async with self.docket.redis() as redis:
            # Use instance attributes for current state
            payload: ProgressEvent = {
                "type": "progress",
                "key": self.key,
                "current": self.current if self.current is not None else 0,
                "total": self.total,
                "message": self.message,
                "updated_at": data.get("updated_at"),
            }

            # Publish JSON payload
            await redis.publish(channel, json.dumps(payload))

    async def subscribe(self) -> AsyncGenerator[ProgressEvent, None]:
        """Subscribe to progress updates for this task.

        Yields:
            Dict containing progress update events with fields:
            - type: "progress"
            - key: task key
            - current: current progress value (or None)
            - total: total/target value
            - message: status message (or None)
            - updated_at: ISO 8601 timestamp
        """
        channel = f"{self.docket.name}:progress:{self.key}"
        async with self.docket.redis() as redis:
            async with redis.pubsub() as pubsub:
                await pubsub.subscribe(channel)
                try:
                    async for message in pubsub.listen():  # pragma: no cover
                        if message["type"] == "message":
                            yield json.loads(message["data"])
                finally:
                    # Explicitly unsubscribe to ensure clean shutdown
                    await pubsub.unsubscribe(channel)
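
# A hypothetical usage sketch (not part of the original file), assuming a
# connected Docket instance `docket` and a known task key; it uses only the
# ExecutionProgress API defined above:
#
#   progress = await ExecutionProgress.create(docket, "my-task-key")
#   await progress.set_total(10)
#   for step in range(10):
#       ...  # one unit of work
#       await progress.increment()
#       await progress.set_message(f"finished step {step + 1} of 10")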


class Execution:
    """Represents a task execution with state management and progress tracking.

    Combines task invocation metadata (function, args, when, etc.) with
    Redis-backed lifecycle state tracking and user-reported progress.
    """

    def __init__(
        self,
        docket: "Docket",
        function: TaskFunction,
        args: tuple[Any, ...],
        kwargs: dict[str, Any],
        key: str,
        when: datetime,
        attempt: int,
        trace_context: opentelemetry.context.Context | None = None,
        redelivered: bool = False,
    ) -> None:
        # Task definition (immutable)
        self._docket = docket
        self._function = function
        self._args = args
        self._kwargs = kwargs
        self._key = key

        # Scheduling metadata
        self.when = when
        self.attempt = attempt
        self._trace_context = trace_context
        self._redelivered = redelivered

        # Lifecycle state (mutable)
        self.state: ExecutionState = ExecutionState.SCHEDULED
        self.worker: str | None = None
        self.started_at: datetime | None = None
        self.completed_at: datetime | None = None
        self.error: str | None = None
        self.result_key: str | None = None

        # Progress tracking
        self.progress: ExecutionProgress = ExecutionProgress(docket, key)

        # Redis key
        self._redis_key = f"{docket.name}:runs:{key}"

    # Task definition properties (immutable)
    @property
    def docket(self) -> "Docket":
        """Parent docket instance."""
        return self._docket

    @property
    def function(self) -> TaskFunction:
        """Task function to execute."""
        return self._function

    @property
    def args(self) -> tuple[Any, ...]:
        """Positional arguments for the task."""
        return self._args

    @property
    def kwargs(self) -> dict[str, Any]:
        """Keyword arguments for the task."""
        return self._kwargs

    @property
    def key(self) -> str:
        """Unique task identifier."""
        return self._key

    # Scheduling metadata properties
    @property
    def trace_context(self) -> opentelemetry.context.Context | None:
        """OpenTelemetry trace context."""
        return self._trace_context

    @property
    def redelivered(self) -> bool:
        """Whether this message was redelivered."""
        return self._redelivered

    def as_message(self) -> Message:
        return {
            b"key": self.key.encode(),
            b"when": self.when.isoformat().encode(),
            b"function": self.function.__name__.encode(),
            b"args": cloudpickle.dumps(self.args),  # type: ignore[arg-type]
            b"kwargs": cloudpickle.dumps(self.kwargs),  # type: ignore[arg-type]
            b"attempt": str(self.attempt).encode(),
        }

    @classmethod
    async def from_message(
        cls, docket: "Docket", message: Message, redelivered: bool = False
    ) -> Self:
        function_name = message[b"function"].decode()
        if not (function := docket.tasks.get(function_name)):
            raise ValueError(
                f"Task function {function_name!r} is not registered with the current docket"
            )

        instance = cls(
            docket=docket,
            function=function,
            args=cloudpickle.loads(message[b"args"]),
            kwargs=cloudpickle.loads(message[b"kwargs"]),
            key=message[b"key"].decode(),
            when=datetime.fromisoformat(message[b"when"].decode()),
            attempt=int(message[b"attempt"].decode()),
            trace_context=propagate.extract(message, getter=message_getter),
            redelivered=redelivered,
        )
        await instance.sync()
        return instance

    def general_labels(self) -> Mapping[str, str]:
        return {"docket.task": self.function.__name__}

    def specific_labels(self) -> Mapping[str, str | int]:
        return {
            "docket.task": self.function.__name__,
            "docket.key": self.key,
            "docket.when": self.when.isoformat(),
            "docket.attempt": self.attempt,
        }

    def get_argument(self, parameter: str) -> Any:
        signature = get_signature(self.function)
        bound_args = signature.bind(*self.args, **self.kwargs)
        return bound_args.arguments[parameter]

    def call_repr(self) -> str:
        arguments: list[str] = []
        function_name = self.function.__name__

        signature = get_signature(self.function)
        logged_parameters = Logged.annotated_parameters(signature)
        parameter_names = list(signature.parameters.keys())

        for i, argument in enumerate(self.args[: len(parameter_names)]):
            parameter_name = parameter_names[i]
            if logged := logged_parameters.get(parameter_name):
                arguments.append(logged.format(argument))
            else:
                arguments.append("...")

        for parameter_name, argument in self.kwargs.items():
            if logged := logged_parameters.get(parameter_name):
                arguments.append(f"{parameter_name}={logged.format(argument)}")
            else:
                arguments.append(f"{parameter_name}=...")

        return f"{function_name}({', '.join(arguments)}){{{self.key}}}"

    def incoming_span_links(self) -> list[trace.Link]:
        initiating_span = trace.get_current_span(self.trace_context)
        initiating_context = initiating_span.get_span_context()
        return [trace.Link(initiating_context)] if initiating_context.is_valid else []
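
    # Illustration only (not part of the original file): as_message/from_message
    # round-trip a task through a stream entry shaped like this (values
    # hypothetical), plus any trace headers injected by propagate.inject in
    # schedule() below:
    #
    #   {b"key": b"my-task", b"when": b"2025-01-01T00:00:00+00:00",
    #    b"function": b"example_task", b"args": <cloudpickle bytes>,
    #    b"kwargs": <cloudpickle bytes>, b"attempt": b"1"}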

    async def schedule(
        self, replace: bool = False, reschedule_message: "RedisMessageID | None" = None
    ) -> None:
        """Schedule this task atomically in Redis.

        This performs an atomic operation that:
        - Adds the task to the stream (immediate) or queue (future)
        - Writes the execution state record
        - Tracks metadata for later cancellation

        Usage patterns:
        - Normal add: schedule(replace=False)
        - Replace existing: schedule(replace=True)
        - Reschedule from stream: schedule(reschedule_message=message_id)
          This atomically acknowledges and deletes the stream message, then
          reschedules the task to the queue. Prevents both task loss and
          duplicate execution when rescheduling tasks (e.g., due to concurrency limits).

        Args:
            replace: If True, replaces any existing task with the same key.
                If False, raises an error if the task already exists.
            reschedule_message: If provided, atomically acknowledges and deletes
                this stream message ID before rescheduling the task to the queue.
                Used when a task needs to be rescheduled from an active stream message.
        """
        message: dict[bytes, bytes] = self.as_message()
        propagate.inject(message, setter=message_setter)

        key = self.key
        when = self.when
        known_task_key = self.docket.known_task_key(key)
        is_immediate = when <= datetime.now(timezone.utc)

        async with self.docket.redis() as redis:
            # Lock per task key to prevent race conditions between concurrent operations
            async with redis.lock(f"{known_task_key}:lock", timeout=10):
                # Register script for this connection (not cached to avoid event loop issues)
                schedule_script = cast(
                    _schedule_task,
                    redis.register_script(
                        # KEYS: stream_key, known_key, parked_key, queue_key, stream_id_key, runs_key, worker_group_key
                        # ARGV: task_key, when_timestamp, is_immediate, replace, reschedule_message_id, ...message_fields
                        """
                        local stream_key = KEYS[1]
                        -- TODO: Remove in next breaking release (v0.14.0) - legacy key locations
                        local known_key = KEYS[2]
                        local parked_key = KEYS[3]
                        local queue_key = KEYS[4]
                        local stream_id_key = KEYS[5]
                        local runs_key = KEYS[6]
                        local worker_group_name = KEYS[7]

                        local task_key = ARGV[1]
                        local when_timestamp = ARGV[2]
                        local is_immediate = ARGV[3] == '1'
                        local replace = ARGV[4] == '1'
                        local reschedule_message_id = ARGV[5]

                        -- Extract message fields from ARGV[6] onwards
                        local message = {}
                        local function_name = nil
                        local args_data = nil
                        local kwargs_data = nil

                        for i = 6, #ARGV, 2 do
                            local field_name = ARGV[i]
                            local field_value = ARGV[i + 1]
                            message[#message + 1] = field_name
                            message[#message + 1] = field_value

                            -- Extract task data fields for runs hash
                            if field_name == 'function' then
                                function_name = field_value
                            elseif field_name == 'args' then
                                args_data = field_value
                            elseif field_name == 'kwargs' then
                                kwargs_data = field_value
                            end
                        end

                        -- Handle rescheduling from stream: atomically ACK message and reschedule to queue
                        -- This prevents both task loss (ACK before reschedule) and duplicate execution
                        -- (reschedule before ACK with slow reschedule causing redelivery)
                        if reschedule_message_id ~= '' then
                            -- Acknowledge and delete the message from the stream
                            redis.call('XACK', stream_key, worker_group_name, reschedule_message_id)
                            redis.call('XDEL', stream_key, reschedule_message_id)

                            -- Park task data for future execution
                            redis.call('HSET', parked_key, unpack(message))

                            -- Add to sorted set queue
                            redis.call('ZADD', queue_key, when_timestamp, task_key)

                            -- Update state in runs hash (clear stream_id since task is no longer in stream)
                            redis.call('HSET', runs_key,
                                'state', 'scheduled',
                                'when', when_timestamp,
                                'function', function_name,
                                'args', args_data,
                                'kwargs', kwargs_data
                            )
                            redis.call('HDEL', runs_key, 'stream_id')

                            return 'OK'
                        end

                        -- Handle replacement: cancel existing task if needed
                        if replace then
                            -- Get stream ID from runs hash (check new location first)
                            local existing_message_id = redis.call('HGET', runs_key, 'stream_id')

                            -- TODO: Remove in next breaking release (v0.14.0) - check legacy location
                            if not existing_message_id then
                                existing_message_id = redis.call('GET', stream_id_key)
                            end

                            if existing_message_id then
                                redis.call('XDEL', stream_key, existing_message_id)
                            end

                            redis.call('ZREM', queue_key, task_key)
                            redis.call('DEL', parked_key)

                            -- TODO: Remove in next breaking release (v0.14.0) - clean up legacy keys
                            redis.call('DEL', known_key, stream_id_key)

                            -- Note: runs_key is updated below, not deleted
                        else
                            -- Check if task already exists (check new location first, then legacy)
                            local known_exists = redis.call('HEXISTS', runs_key, 'known') == 1
                            if not known_exists then
                                -- Check if task is currently running (known field deleted at claim time)
                                local state = redis.call('HGET', runs_key, 'state')
                                if state == 'running' then
                                    return 'EXISTS'
                                end
                                -- TODO: Remove in next breaking release (v0.14.0) - check legacy location
                                known_exists = redis.call('EXISTS', known_key) == 1
                            end
                            if known_exists then
                                return 'EXISTS'
                            end
                        end

                        if is_immediate then
                            -- Add to stream for immediate execution
                            local message_id = redis.call('XADD', stream_key, '*', unpack(message))

                            -- Store state and metadata in runs hash
                            redis.call('HSET', runs_key,
                                'state', 'queued',
                                'when', when_timestamp,
                                'known', when_timestamp,
                                'stream_id', message_id,
                                'function', function_name,
                                'args', args_data,
                                'kwargs', kwargs_data
                            )
                        else
                            -- Park task data for future execution
                            redis.call('HSET', parked_key, unpack(message))

                            -- Add to sorted set queue
                            redis.call('ZADD', queue_key, when_timestamp, task_key)

                            -- Store state and metadata in runs hash
                            redis.call('HSET', runs_key,
                                'state', 'scheduled',
                                'when', when_timestamp,
                                'known', when_timestamp,
                                'function', function_name,
                                'args', args_data,
                                'kwargs', kwargs_data
                            )
                        end

                        return 'OK'
                        """
                    ),
                )

                await schedule_script(
                    keys=[
                        self.docket.stream_key,
                        known_task_key,
                        self.docket.parked_task_key(key),
                        self.docket.queue_key,
                        self.docket.stream_id_key(key),
                        self._redis_key,
                        self.docket.worker_group_name,
                    ],
                    args=[
                        key,
                        str(when.timestamp()),
                        "1" if is_immediate else "0",
                        "1" if replace else "0",
                        reschedule_message or b"",
                        *[
                            item
                            for field, value in message.items()
                            for item in (field, value)
                        ],
                    ],
                )

        # Update local state based on whether task is immediate, scheduled, or being rescheduled
        if reschedule_message:
            # When rescheduling from stream, task is always parked and queued (never immediate)
            self.state = ExecutionState.SCHEDULED
            await self._publish_state(
                {"state": ExecutionState.SCHEDULED.value, "when": when.isoformat()}
            )
        elif is_immediate:
            self.state = ExecutionState.QUEUED
            await self._publish_state(
                {"state": ExecutionState.QUEUED.value, "when": when.isoformat()}
            )
        else:
            self.state = ExecutionState.SCHEDULED
            await self._publish_state(
                {"state": ExecutionState.SCHEDULED.value, "when": when.isoformat()}
            )
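
    # Hypothetical usage sketch (not part of the original file), mirroring the
    # three patterns in the docstring above; `execution` is assumed to be a
    # fully constructed Execution and `message_id` a claimed stream message ID:
    #
    #   await execution.schedule()               # plain add (the Lua script
    #                                            # returns 'EXISTS' for a known key)
    #   await execution.schedule(replace=True)   # replace any same-keyed task
    #   await execution.schedule(reschedule_message=message_id)  # ACK + re-park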

    async def claim(self, worker: str) -> None:
        """Atomically claim task and transition to RUNNING state.

        This consolidates worker operations when claiming a task into a single
        atomic Lua script that:
        - Sets state to RUNNING with worker name and timestamp
        - Initializes progress tracking (current=0, total=100)
        - Deletes known/stream_id fields to allow task rescheduling
        - Cleans up legacy keys for backwards compatibility

        Args:
            worker: Name of the worker claiming the task
        """
        started_at = datetime.now(timezone.utc)
        started_at_iso = started_at.isoformat()

        async with self.docket.redis() as redis:
            claim_script = redis.register_script(
                # KEYS: runs_key, progress_key, known_key, stream_id_key
                # ARGV: worker, started_at_iso
                """
                local runs_key = KEYS[1]
                local progress_key = KEYS[2]
                -- TODO: Remove in next breaking release (v0.14.0) - legacy key locations
                local known_key = KEYS[3]
                local stream_id_key = KEYS[4]

                local worker = ARGV[1]
                local started_at = ARGV[2]

                -- Update execution state to running
                redis.call('HSET', runs_key,
                    'state', 'running',
                    'worker', worker,
                    'started_at', started_at
                )

                -- Initialize progress tracking
                redis.call('HSET', progress_key,
                    'current', '0',
                    'total', '100'
                )

                -- Delete known/stream_id fields to allow task rescheduling
                redis.call('HDEL', runs_key, 'known', 'stream_id')

                -- TODO: Remove in next breaking release (v0.14.0) - legacy key cleanup
                redis.call('DEL', known_key, stream_id_key)

                return 'OK'
                """
            )

            await claim_script(
                keys=[
                    self._redis_key,  # runs_key
                    self.progress._redis_key,  # progress_key
                    f"{self.docket.name}:known:{self.key}",  # legacy known_key
                    f"{self.docket.name}:stream-id:{self.key}",  # legacy stream_id_key
                ],
                args=[worker, started_at_iso],
            )

        # Update local state
        self.state = ExecutionState.RUNNING
        self.worker = worker
        self.started_at = started_at
        self.progress.current = 0
        self.progress.total = 100

        # Publish state change event
        await self._publish_state(
            {
                "state": ExecutionState.RUNNING.value,
                "worker": worker,
                "started_at": started_at_iso,
            }
        )

    async def mark_as_completed(self, result_key: str | None = None) -> None:
        """Mark task as completed successfully.

        Args:
            result_key: Optional key where the task result is stored

        Sets TTL on state data (from docket.execution_ttl), or deletes state
        immediately if execution_ttl is 0. Also deletes progress data.
        """
        completed_at = datetime.now(timezone.utc).isoformat()
        async with self.docket.redis() as redis:
            mapping: dict[str, str] = {
                "state": ExecutionState.COMPLETED.value,
                "completed_at": completed_at,
            }
            if result_key is not None:
                mapping["result_key"] = result_key
            await redis.hset(
                self._redis_key,
                mapping=mapping,
            )
            # Set TTL from docket configuration, or delete if TTL=0
            if self.docket.execution_ttl:
                ttl_seconds = int(self.docket.execution_ttl.total_seconds())
                await redis.expire(self._redis_key, ttl_seconds)
            else:
                await redis.delete(self._redis_key)
        self.state = ExecutionState.COMPLETED
        self.result_key = result_key
        # Delete progress data
        await self.progress.delete()
        # Publish state change event
        await self._publish_state(
            {"state": ExecutionState.COMPLETED.value, "completed_at": completed_at}
        )

    async def mark_as_failed(
        self, error: str | None = None, result_key: str | None = None
    ) -> None:
        """Mark task as failed.

        Args:
            error: Optional error message describing the failure
            result_key: Optional key where the exception is stored

        Sets TTL on state data (from docket.execution_ttl), or deletes state
        immediately if execution_ttl is 0. Also deletes progress data.
        """
        completed_at = datetime.now(timezone.utc).isoformat()
        async with self.docket.redis() as redis:
            mapping = {
                "state": ExecutionState.FAILED.value,
                "completed_at": completed_at,
            }
            if error:
                mapping["error"] = error
            if result_key is not None:
                mapping["result_key"] = result_key
            await redis.hset(self._redis_key, mapping=mapping)
            # Set TTL from docket configuration, or delete if TTL=0
            if self.docket.execution_ttl:
                ttl_seconds = int(self.docket.execution_ttl.total_seconds())
                await redis.expire(self._redis_key, ttl_seconds)
            else:
                await redis.delete(self._redis_key)
        self.state = ExecutionState.FAILED
        self.result_key = result_key
        # Delete progress data
        await self.progress.delete()
        # Publish state change event
        state_data = {
            "state": ExecutionState.FAILED.value,
            "completed_at": completed_at,
        }
        if error:
            state_data["error"] = error
        await self._publish_state(state_data)
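
    # Illustration only (not part of the original file): how execution_ttl
    # shapes the two completion paths above, assuming docket.execution_ttl is
    # the timedelta referenced in this file:
    #
    #   execution_ttl = timedelta(minutes=5)  ->  runs hash expires 300s after
    #                                             completion or failure
    #   execution_ttl = timedelta(0)          ->  runs hash deleted immediately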

    async def get_result(
        self,
        *,
        timeout: timedelta | None = None,
        deadline: datetime | None = None,
    ) -> Any:
        """Retrieve the result of this task execution.

        If the execution is not yet complete, this method will wait using
        pub/sub for state updates until completion.

        Args:
            timeout: Optional duration to wait before giving up.
                If None and deadline is None, waits indefinitely.
            deadline: Optional absolute datetime when to stop waiting.
                If None and timeout is None, waits indefinitely.

        Returns:
            The result of the task execution, or None if the task returned None.

        Raises:
            ValueError: If both timeout and deadline are provided
            Exception: If the task failed, raises the stored exception
            TimeoutError: If timeout/deadline is reached before execution completes
        """
        # Validate that only one time limit is provided
        if timeout is not None and deadline is not None:
            raise ValueError("Cannot specify both timeout and deadline")

        # Convert timeout to deadline if provided
        if timeout is not None:
            deadline = datetime.now(timezone.utc) + timeout

        # Wait for execution to complete if not already done
        if self.state not in (ExecutionState.COMPLETED, ExecutionState.FAILED):
            # Calculate timeout duration if absolute deadline provided
            timeout_seconds = None
            if deadline is not None:
                timeout_seconds = (
                    deadline - datetime.now(timezone.utc)
                ).total_seconds()
                if timeout_seconds <= 0:
                    raise TimeoutError(
                        f"Timeout waiting for execution {self.key} to complete"
                    )

            try:

                async def wait_for_completion():
                    async for event in self.subscribe():  # pragma: no branch
                        if event["type"] == "state":
                            state = ExecutionState(event["state"])
                            if state in (
                                ExecutionState.COMPLETED,
                                ExecutionState.FAILED,
                            ):
                                # Sync to get latest data including result key
                                await self.sync()
                                break

                # Use asyncio.wait_for to enforce timeout
                await asyncio.wait_for(wait_for_completion(), timeout=timeout_seconds)
            except asyncio.TimeoutError:
                raise TimeoutError(
                    f"Timeout waiting for execution {self.key} to complete"
                )

        # If failed, retrieve and raise the exception
        if self.state == ExecutionState.FAILED:
            if self.result_key:
                # Retrieve serialized exception from result_storage
                result_data = await self.docket.result_storage.get(self.result_key)
                if result_data and "data" in result_data:
                    # Base64-decode and unpickle
                    pickled_exception = base64.b64decode(result_data["data"])
                    exception = cloudpickle.loads(pickled_exception)  # type: ignore[arg-type]
                    raise exception
            # If no stored exception, raise a generic error with the error message
            error_msg = self.error or "Task execution failed"
            raise Exception(error_msg)

        # If completed successfully, retrieve result if available
        if self.result_key:
            result_data = await self.docket.result_storage.get(self.result_key)
            if result_data is not None and "data" in result_data:
                # Base64-decode and unpickle
                pickled_result = base64.b64decode(result_data["data"])
                return cloudpickle.loads(pickled_result)  # type: ignore[arg-type]

        # No result stored - task returned None
        return None

    async def sync(self) -> None:
        """Synchronize instance attributes with current execution data from Redis.

        Updates self.state, execution metadata, and progress data from Redis.
        Resets attributes to their defaults if no data exists.
        """
        async with self.docket.redis() as redis:
            data = await redis.hgetall(self._redis_key)
            if data:
                # Update state
                state_value = data.get(b"state")
                if state_value:
                    if isinstance(state_value, bytes):
                        state_value = state_value.decode()
                    self.state = ExecutionState(state_value)

                # Update metadata
                self.worker = data[b"worker"].decode() if b"worker" in data else None
                self.started_at = (
                    datetime.fromisoformat(data[b"started_at"].decode())
                    if b"started_at" in data
                    else None
                )
                self.completed_at = (
                    datetime.fromisoformat(data[b"completed_at"].decode())
                    if b"completed_at" in data
                    else None
                )
                self.error = data[b"error"].decode() if b"error" in data else None
                self.result_key = (
                    data[b"result_key"].decode() if b"result_key" in data else None
                )
            else:
                # No data exists - reset to defaults
                self.state = ExecutionState.SCHEDULED
                self.worker = None
                self.started_at = None
                self.completed_at = None
                self.error = None
                self.result_key = None

        # Sync progress data
        await self.progress.sync()

    async def _publish_state(self, data: dict) -> None:
        """Publish state change to Redis pub/sub channel.

        Args:
            data: State data to publish
        """
        channel = f"{self.docket.name}:state:{self.key}"
        # Create ephemeral Redis client for publishing
        async with self.docket.redis() as redis:
            # Build payload with all relevant state information
            payload = {
                "type": "state",
                "key": self.key,
                **data,
            }
            await redis.publish(channel, json.dumps(payload))

    async def subscribe(self) -> AsyncGenerator[StateEvent | ProgressEvent, None]:
        """Subscribe to both state and progress updates for this task.

        Emits the current state as the first event, then subscribes to real-time
        state and progress updates via Redis pub/sub.

        Yields:
            Dict containing state or progress update events with a 'type' field:
            - For state events: type="state", state, worker, timestamps, error
            - For progress events: type="progress", current, total, message, updated_at
        """
        # First, emit the current state
        await self.sync()

        # Build initial state event from current attributes
        initial_state: StateEvent = {
            "type": "state",
            "key": self.key,
            "state": self.state,
            "when": self.when.isoformat(),
            "worker": self.worker,
            "started_at": self.started_at.isoformat() if self.started_at else None,
            "completed_at": self.completed_at.isoformat()
            if self.completed_at
            else None,
            "error": self.error,
        }

        yield initial_state

        progress_event: ProgressEvent = {
            "type": "progress",
            "key": self.key,
            "current": self.progress.current,
            "total": self.progress.total,
            "message": self.progress.message,
            "updated_at": self.progress.updated_at.isoformat()
            if self.progress.updated_at
            else None,
        }

        yield progress_event

        # Then subscribe to real-time updates
        state_channel = f"{self.docket.name}:state:{self.key}"
        progress_channel = f"{self.docket.name}:progress:{self.key}"
        async with self.docket.redis() as redis:
            async with redis.pubsub() as pubsub:
                await pubsub.subscribe(state_channel, progress_channel)
                try:
                    async for message in pubsub.listen():  # pragma: no cover
                        if message["type"] == "message":
                            message_data = json.loads(message["data"])
                            if message_data["type"] == "state":
                                message_data["state"] = ExecutionState(
                                    message_data["state"]
                                )
                            yield message_data
                finally:
                    # Explicitly unsubscribe to ensure clean shutdown
                    await pubsub.unsubscribe(state_channel, progress_channel)
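
# A hypothetical usage sketch (not part of the original file) of the waiting
# APIs above, assuming `execution` is an Execution whose task is in flight:
#
#   result = await execution.get_result(timeout=timedelta(seconds=30))
#
#   async for event in execution.subscribe():
#       if event["type"] == "state" and event["state"] is ExecutionState.FAILED:
#           break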


def compact_signature(signature: inspect.Signature) -> str:
    from .dependencies import Dependency

    parameters: list[str] = []
    dependencies: int = 0

    for parameter in signature.parameters.values():
        if isinstance(parameter.default, Dependency):
            dependencies += 1
            continue

        parameter_definition = parameter.name
        if parameter.annotation is not parameter.empty:
            annotation = parameter.annotation
            if hasattr(annotation, "__origin__"):
                annotation = annotation.__args__[0]

            type_name = getattr(annotation, "__name__", str(annotation))
            parameter_definition = f"{parameter.name}: {type_name}"

        if parameter.default is not parameter.empty:
            parameter_definition = f"{parameter_definition} = {parameter.default!r}"

        parameters.append(parameter_definition)

    if dependencies > 0:
        parameters.append("...")

    return ", ".join(parameters)
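
# Illustration only (not part of the original file): for a hypothetical task
#
#   async def example_task(customer_id: int, retries: int = 3) -> None: ...
#
# compact_signature(get_signature(example_task)) yields roughly
# "customer_id: int, retries: int = 3"; any parameter whose default is a
# Dependency instance is skipped and summarized as a trailing "...".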


class Operator(str, enum.Enum):
    EQUAL = "=="
    NOT_EQUAL = "!="
    GREATER_THAN = ">"
    GREATER_THAN_OR_EQUAL = ">="
    LESS_THAN = "<"
    LESS_THAN_OR_EQUAL = "<="
    BETWEEN = "between"


LiteralOperator = Literal["==", "!=", ">", ">=", "<", "<=", "between"]


class StrikeInstruction(abc.ABC):
    direction: Literal["strike", "restore"]
    operator: Operator

    def __init__(
        self,
        function: str | None,
        parameter: str | None,
        operator: Operator,
        value: Hashable,
    ) -> None:
        self.function = function
        self.parameter = parameter
        self.operator = operator
        self.value = value

    def as_message(self) -> Message:
        message: dict[bytes, bytes] = {b"direction": self.direction.encode()}
        if self.function:
            message[b"function"] = self.function.encode()
        if self.parameter:
            message[b"parameter"] = self.parameter.encode()
        message[b"operator"] = self.operator.encode()
        message[b"value"] = cloudpickle.dumps(self.value)  # type: ignore[arg-type]
        return message

    @classmethod
    def from_message(cls, message: Message) -> "StrikeInstruction":
        direction = cast(Literal["strike", "restore"], message[b"direction"].decode())
        function = message[b"function"].decode() if b"function" in message else None
        parameter = message[b"parameter"].decode() if b"parameter" in message else None
        operator = cast(Operator, message[b"operator"].decode())
        value = cloudpickle.loads(message[b"value"])
        if direction == "strike":
            return Strike(function, parameter, operator, value)
        else:
            return Restore(function, parameter, operator, value)

    def labels(self) -> Mapping[str, str]:
        labels: dict[str, str] = {}
        if self.function:
            labels["docket.task"] = self.function

        if self.parameter:
            labels["docket.parameter"] = self.parameter
            labels["docket.operator"] = self.operator
            labels["docket.value"] = repr(self.value)

        return labels

    def call_repr(self) -> str:
        return (
            f"{self.function or '*'}"
            "("
            f"{self.parameter or '*'}"
            " "
            f"{self.operator}"
            " "
            f"{repr(self.value) if self.parameter else '*'}"
            ")"
        )


class Strike(StrikeInstruction):
    direction: Literal["strike", "restore"] = "strike"


class Restore(StrikeInstruction):
    direction: Literal["strike", "restore"] = "restore"


MinimalStrike = tuple[Operator, Hashable]
ParameterStrikes = dict[str, set[MinimalStrike]]
TaskStrikes = dict[str, ParameterStrikes]
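
# Illustration only (not part of the original file): a strike that blocks one
# customer's tasks would travel as a message like (names and values hypothetical)
#
#   Strike("example_task", "customer_id", Operator.EQUAL, 123).as_message()
#
# and a later Restore with the same fields lifts it again via StrikeList.update
# below.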


class StrikeList:
    task_strikes: TaskStrikes
    parameter_strikes: ParameterStrikes
    _conditions: list[Callable[[Execution], bool]]

    def __init__(self) -> None:
        self.task_strikes = {}
        self.parameter_strikes = {}
        self._conditions = [self._matches_task_or_parameter_strike]

    def add_condition(self, condition: Callable[[Execution], bool]) -> None:
        """Adds a temporary condition that indicates an execution is stricken."""
        self._conditions.insert(0, condition)

    def remove_condition(self, condition: Callable[[Execution], bool]) -> None:
        """Removes a previously added temporary condition."""
        assert condition is not self._matches_task_or_parameter_strike
        self._conditions.remove(condition)

    def is_stricken(self, execution: Execution) -> bool:
        """
        Checks if an execution is stricken based on task, parameter, or temporary
        conditions.

        Returns:
            bool: True if the execution is stricken, False otherwise.
        """
        return any(condition(execution) for condition in self._conditions)

    def _matches_task_or_parameter_strike(self, execution: Execution) -> bool:
        function_name = execution.function.__name__

        # Check if the entire task is stricken (without parameter conditions)
        task_strikes = self.task_strikes.get(function_name, {})
        if function_name in self.task_strikes and not task_strikes:
            return True

        signature = get_signature(execution.function)

        try:
            bound_args = signature.bind(*execution.args, **execution.kwargs)
            bound_args.apply_defaults()
        except TypeError:
            # If we can't make sense of the arguments, just assume the task is fine
            return False

        all_arguments = {
            **bound_args.arguments,
            **{
                k: v
                for k, v in execution.kwargs.items()
                if k not in bound_args.arguments
            },
        }

        for parameter, argument in all_arguments.items():
            for strike_source in [task_strikes, self.parameter_strikes]:
                if parameter not in strike_source:
                    continue

                for operator, strike_value in strike_source[parameter]:
                    if self._is_match(argument, operator, strike_value):
                        return True

        return False

    def _is_match(self, value: Any, operator: Operator, strike_value: Any) -> bool:
        """Determines if a value matches a strike condition."""
        try:
            match operator:
                case "==":
                    return value == strike_value
                case "!=":
                    return value != strike_value
                case ">":
                    return value > strike_value
                case ">=":
                    return value >= strike_value
                case "<":
                    return value < strike_value
                case "<=":
                    return value <= strike_value
                case "between":  # pragma: no branch
                    lower, upper = strike_value
                    return lower <= value <= upper
                case _:  # pragma: no cover
                    raise ValueError(f"Unknown operator: {operator}")
        except (ValueError, TypeError):
            # If we can't make the comparison due to incompatible types, just log the
            # error and assume the task is not stricken
            logger.warning(
                "Incompatible type for strike condition: %r %s %r",
                strike_value,
                operator,
                value,
                exc_info=True,
            )
            return False

    def update(self, instruction: StrikeInstruction) -> None:
        try:
            hash(instruction.value)
        except TypeError:
            logger.warning(
                "Incompatible type for strike condition: %s %r",
                instruction.operator,
                instruction.value,
            )
            return

        if isinstance(instruction, Strike):
            self._strike(instruction)
        elif isinstance(instruction, Restore):  # pragma: no branch
            self._restore(instruction)

    def _strike(self, strike: Strike) -> None:
        if strike.function and strike.parameter:
            try:
                task_strikes = self.task_strikes[strike.function]
            except KeyError:
                task_strikes = self.task_strikes[strike.function] = {}

            try:
                parameter_strikes = task_strikes[strike.parameter]
            except KeyError:
                parameter_strikes = task_strikes[strike.parameter] = set()

            parameter_strikes.add((strike.operator, strike.value))

        elif strike.function:
            try:
                task_strikes = self.task_strikes[strike.function]
            except KeyError:
                task_strikes = self.task_strikes[strike.function] = {}

        elif strike.parameter:  # pragma: no branch
            try:
                parameter_strikes = self.parameter_strikes[strike.parameter]
            except KeyError:
                parameter_strikes = self.parameter_strikes[strike.parameter] = set()

            parameter_strikes.add((strike.operator, strike.value))

    def _restore(self, restore: Restore) -> None:
        if restore.function and restore.parameter:
            try:
                task_strikes = self.task_strikes[restore.function]
            except KeyError:
                return

            try:
                parameter_strikes = task_strikes[restore.parameter]
            except KeyError:
                task_strikes.pop(restore.parameter, None)
                return

            try:
                parameter_strikes.remove((restore.operator, restore.value))
            except KeyError:
                pass

            if not parameter_strikes:
                task_strikes.pop(restore.parameter, None)
            if not task_strikes:
                self.task_strikes.pop(restore.function, None)

        elif restore.function:
            try:
                task_strikes = self.task_strikes[restore.function]
            except KeyError:
                return

            # If there are no parameter strikes, this was a full task strike
            if not task_strikes:
                self.task_strikes.pop(restore.function, None)

        elif restore.parameter:  # pragma: no branch
            try:
                parameter_strikes = self.parameter_strikes[restore.parameter]
            except KeyError:
                return

            try:
                parameter_strikes.remove((restore.operator, restore.value))
            except KeyError:
                pass

            if not parameter_strikes:
                self.parameter_strikes.pop(restore.parameter, None)