pydocket 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their public registries. It is provided for informational purposes only.

Potentially problematic release.

docket/dependencies.py CHANGED
@@ -1,5 +1,6 @@
 import abc
 import inspect
+import logging
 from datetime import timedelta
 from typing import Any, Awaitable, Callable, Counter, cast
 
@@ -35,10 +36,52 @@ def CurrentDocket() -> Docket:
     return cast(Docket, _CurrentDocket())
 
 
+class _CurrentExecution(Dependency):
+    def __call__(
+        self, docket: Docket, worker: Worker, execution: Execution
+    ) -> Execution:
+        return execution
+
+
+def CurrentExecution() -> Execution:
+    return cast(Execution, _CurrentExecution())
+
+
+class _TaskKey(Dependency):
+    def __call__(self, docket: Docket, worker: Worker, execution: Execution) -> str:
+        return execution.key
+
+
+def TaskKey() -> str:
+    return cast(str, _TaskKey())
+
+
+class _TaskLogger(Dependency):
+    def __call__(
+        self, docket: Docket, worker: Worker, execution: Execution
+    ) -> logging.LoggerAdapter[logging.Logger]:
+        logger = logging.getLogger(f"docket.task.{execution.function.__name__}")
+
+        extra = {
+            "execution.key": execution.key,
+            "execution.attempt": execution.attempt,
+            "worker.name": worker.name,
+            "docket.name": docket.name,
+        }
+
+        return logging.LoggerAdapter(logger, extra)
+
+
+def TaskLogger() -> logging.LoggerAdapter[logging.Logger]:
+    return cast(logging.LoggerAdapter[logging.Logger], _TaskLogger())
+
+
 class Retry(Dependency):
     single: bool = True
 
-    def __init__(self, attempts: int = 1, delay: timedelta = timedelta(0)) -> None:
+    def __init__(
+        self, attempts: int | None = 1, delay: timedelta = timedelta(0)
+    ) -> None:
         self.attempts = attempts
         self.delay = delay
         self.attempt = 1
@@ -49,6 +92,41 @@ class Retry(Dependency):
         return retry
 
 
+class ExponentialRetry(Retry):
+    attempts: int
+
+    def __init__(
+        self,
+        attempts: int = 1,
+        minimum_delay: timedelta = timedelta(seconds=1),
+        maximum_delay: timedelta = timedelta(seconds=64),
+    ) -> None:
+        super().__init__(attempts=attempts, delay=minimum_delay)
+        self.minimum_delay = minimum_delay
+        self.maximum_delay = maximum_delay
+
+    def __call__(
+        self, docket: Docket, worker: Worker, execution: Execution
+    ) -> "ExponentialRetry":
+        retry = ExponentialRetry(
+            attempts=self.attempts,
+            minimum_delay=self.minimum_delay,
+            maximum_delay=self.maximum_delay,
+        )
+        retry.attempt = execution.attempt
+
+        if execution.attempt > 1:
+            backoff_factor = 2 ** (execution.attempt - 1)
+            calculated_delay = self.minimum_delay * backoff_factor
+
+            if calculated_delay > self.maximum_delay:
+                retry.delay = self.maximum_delay
+            else:
+                retry.delay = calculated_delay
+
+        return retry
+
+
 def get_dependency_parameters(
     function: Callable[..., Awaitable[Any]],
 ) -> dict[str, Dependency]:
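
Taken together, these dependencies are declared as default values on a task's parameters, which `get_dependency_parameters` discovers by inspecting the function signature. A minimal sketch of a task using the new helpers (the task `send_welcome_email` and its `customer_id` parameter are hypothetical, not part of this release):

import logging
from datetime import timedelta

from docket.dependencies import ExponentialRetry, TaskKey, TaskLogger


async def send_welcome_email(
    customer_id: int,
    key: str = TaskKey(),
    logger: logging.LoggerAdapter[logging.Logger] = TaskLogger(),
    retry: ExponentialRetry = ExponentialRetry(
        attempts=5,
        minimum_delay=timedelta(seconds=1),
        maximum_delay=timedelta(seconds=32),
    ),
) -> None:
    # TaskLogger tags every record with execution.key, execution.attempt,
    # worker.name, and docket.name; ExponentialRetry backs off by
    # minimum_delay * 2 ** (attempt - 1), capped at maximum_delay.
    logger.info("sending welcome email to customer %s (task %s)", customer_id, key)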
docket/docket.py CHANGED
@@ -1,45 +1,164 @@
+import asyncio
+import importlib
+import logging
 from contextlib import asynccontextmanager
-from datetime import datetime, timezone
+from dataclasses import dataclass
+from datetime import datetime, timedelta, timezone
 from types import TracebackType
 from typing import (
     Any,
     AsyncGenerator,
     Awaitable,
     Callable,
+    Collection,
+    Hashable,
+    Iterable,
+    NoReturn,
     ParamSpec,
     Self,
+    Sequence,
+    TypedDict,
     TypeVar,
+    cast,
     overload,
 )
 from uuid import uuid4
 
+import redis.exceptions
+from opentelemetry import propagate, trace
 from redis.asyncio import Redis
 
-from .execution import Execution
+from .execution import (
+    Execution,
+    LiteralOperator,
+    Operator,
+    Restore,
+    Strike,
+    StrikeInstruction,
+    StrikeList,
+)
+from .instrumentation import (
+    REDIS_DISRUPTIONS,
+    STRIKES_IN_EFFECT,
+    TASKS_ADDED,
+    TASKS_CANCELLED,
+    TASKS_REPLACED,
+    TASKS_SCHEDULED,
+    TASKS_STRICKEN,
+    message_setter,
+)
+
+logger: logging.Logger = logging.getLogger(__name__)
+tracer: trace.Tracer = trace.get_tracer(__name__)
+
 
 P = ParamSpec("P")
 R = TypeVar("R")
 
+TaskCollection = Iterable[Callable[..., Awaitable[Any]]]
+
+RedisStreamID = bytes
+RedisMessageID = bytes
+RedisMessage = dict[bytes, bytes]
+RedisMessages = Sequence[tuple[RedisMessageID, RedisMessage]]
+RedisStream = tuple[RedisStreamID, RedisMessages]
+RedisReadGroupResponse = Sequence[RedisStream]
+
+
+class RedisStreamPendingMessage(TypedDict):
+    message_id: bytes
+    consumer: bytes
+    time_since_delivered: int
+    times_delivered: int
+
+
+@dataclass
+class WorkerInfo:
+    name: str
+    last_seen: datetime
+    tasks: set[str]
+
+
+class RunningExecution(Execution):
+    worker: str
+    started: datetime
+
+    def __init__(
+        self,
+        execution: Execution,
+        worker: str,
+        started: datetime,
+    ) -> None:
+        self.function: Callable[..., Awaitable[Any]] = execution.function
+        self.args: tuple[Any, ...] = execution.args
+        self.kwargs: dict[str, Any] = execution.kwargs
+        self.when: datetime = execution.when
+        self.key: str = execution.key
+        self.attempt: int = execution.attempt
+        self.worker = worker
+        self.started = started
+
+
+@dataclass
+class DocketSnapshot:
+    taken: datetime
+    total_tasks: int
+    future: Sequence[Execution]
+    running: Sequence[RunningExecution]
+    workers: Collection[WorkerInfo]
+
 
 class Docket:
     tasks: dict[str, Callable[..., Awaitable[Any]]]
+    strike_list: StrikeList
 
     def __init__(
         self,
         name: str = "docket",
-        host: str = "localhost",
-        port: int = 6379,
-        db: int = 0,
-        password: str | None = None,
+        url: str = "redis://localhost:6379/0",
+        heartbeat_interval: timedelta = timedelta(seconds=1),
+        missed_heartbeats: int = 5,
     ) -> None:
+        """
+        Args:
+            name: The name of the docket.
+            url: The URL of the Redis server. For example:
+                - "redis://localhost:6379/0"
+                - "redis://user:password@localhost:6379/0"
+                - "redis://user:password@localhost:6379/0?ssl=true"
+                - "rediss://localhost:6379/0"
+                - "unix:///path/to/redis.sock"
+        """
         self.name = name
-        self.host = host
-        self.port = port
-        self.db = db
-        self.password = password
+        self.url = url
+        self.heartbeat_interval = heartbeat_interval
+        self.missed_heartbeats = missed_heartbeats
+
+    @property
+    def worker_group_name(self) -> str:
+        return "docket-workers"
 
     async def __aenter__(self) -> Self:
-        self.tasks = {}
+        from .tasks import standard_tasks
+
+        self.tasks = {fn.__name__: fn for fn in standard_tasks}
+        self.strike_list = StrikeList()
+
+        self._monitor_strikes_task = asyncio.create_task(self._monitor_strikes())
+
+        # Ensure that the stream and worker group exist
+        async with self.redis() as r:
+            try:
+                await r.xgroup_create(
+                    groupname=self.worker_group_name,
+                    name=self.stream_key,
+                    id="0-0",
+                    mkstream=True,
+                )
+            except redis.exceptions.RedisError as e:
+                if "BUSYGROUP" not in repr(e):
+                    raise
+
         return self
 
     async def __aexit__(
@@ -48,17 +167,18 @@ class Docket:
         exc_value: BaseException | None,
         traceback: TracebackType | None,
     ) -> None:
-        pass
+        del self.tasks
+        del self.strike_list
+
+        self._monitor_strikes_task.cancel()
+        try:
+            await self._monitor_strikes_task
+        except asyncio.CancelledError:
+            pass
 
     @asynccontextmanager
     async def redis(self) -> AsyncGenerator[Redis, None]:
-        async with Redis(
-            host=self.host,
-            port=self.port,
-            db=self.db,
-            password=self.password,
-            single_connection_client=True,
-        ) as redis:
+        async with Redis.from_url(self.url) as redis:  # type: ignore
             yield redis
 
     def register(self, function: Callable[..., Awaitable[Any]]) -> None:
@@ -68,6 +188,19 @@ class Docket:
 
         self.tasks[function.__name__] = function
 
+    def register_collection(self, collection_path: str) -> None:
+        """
+        Register a collection of tasks.
+
+        Args:
+            collection_path: A path in the format "module:collection".
+        """
+        module_name, _, member_name = collection_path.rpartition(":")
+        module = importlib.import_module(module_name)
+        collection = getattr(module, member_name)
+        for function in collection:
+            self.register(function)
+
     @overload
     def add(
         self,
@@ -104,6 +237,9 @@ class Docket:
         async def scheduler(*args: P.args, **kwargs: P.kwargs) -> Execution:
            execution = Execution(function, args, kwargs, when, key, attempt=1)
            await self.schedule(execution)
+
+            TASKS_ADDED.add(1, {"docket": self.name, "task": function.__name__})
+
            return execution
 
        return scheduler
@@ -137,6 +273,9 @@ class Docket:
            execution = Execution(function, args, kwargs, when, key, attempt=1)
            await self.cancel(key)
            await self.schedule(execution)
+
+            TASKS_REPLACED.add(1, {"docket": self.name, "task": function.__name__})
+
            return execution
 
        return scheduler
@@ -153,26 +292,304 @@ class Docket:
         return f"{self.name}:{key}"
 
     async def schedule(self, execution: Execution) -> None:
+        if self.strike_list.is_stricken(execution):
+            logger.warning(
+                "%r is stricken, skipping schedule of %r",
+                execution.function.__name__,
+                execution.key,
+            )
+            TASKS_STRICKEN.add(
+                1,
+                {
+                    "docket": self.name,
+                    "task": execution.function.__name__,
+                    "where": "docket",
+                },
+            )
+            return
+
         message: dict[bytes, bytes] = execution.as_message()
-        key = execution.key
-        when = execution.when
+        propagate.inject(message, setter=message_setter)
 
-        async with self.redis() as redis:
-            # if the task is already in the queue, retain it
-            if await redis.zscore(self.queue_key, key) is not None:
-                return
+        with tracer.start_as_current_span(
+            "docket.schedule",
+            attributes={
+                "docket.name": self.name,
+                "docket.execution.when": execution.when.isoformat(),
+                "docket.execution.key": execution.key,
+                "docket.execution.attempt": execution.attempt,
+                "code.function.name": execution.function.__name__,
+            },
+        ):
+            key = execution.key
+            when = execution.when
 
-            if when <= datetime.now(timezone.utc):
-                await redis.xadd(self.stream_key, message)
-            else:
+            async with self.redis() as redis:
+                # if the task is already in the queue, retain it
+                if await redis.zscore(self.queue_key, key) is not None:
+                    return
+
+                if when <= datetime.now(timezone.utc):
+                    await redis.xadd(self.stream_key, message)  # type: ignore[arg-type]
+                else:
+                    async with redis.pipeline() as pipe:
+                        pipe.hset(self.parked_task_key(key), mapping=message)  # type: ignore[arg-type]
+                        pipe.zadd(self.queue_key, {key: when.timestamp()})
+                        await pipe.execute()
+
+        TASKS_SCHEDULED.add(
+            1, {"docket": self.name, "task": execution.function.__name__}
+        )
+
+    async def cancel(self, key: str) -> None:
+        with tracer.start_as_current_span(
+            "docket.cancel",
+            attributes={
+                "docket.name": self.name,
+                "docket.execution.key": key,
+            },
+        ):
+            async with self.redis() as redis:
                 async with redis.pipeline() as pipe:
-                    pipe.hset(self.parked_task_key(key), mapping=message)
-                    pipe.zadd(self.queue_key, {key: when.timestamp()})
+                    pipe.delete(self.parked_task_key(key))
+                    pipe.zrem(self.queue_key, key)
                     await pipe.execute()
 
-    async def cancel(self, key: str) -> None:
-        async with self.redis() as redis:
-            async with redis.pipeline() as pipe:
-                pipe.delete(self.parked_task_key(key))
-                pipe.zrem(self.queue_key, key)
-                await pipe.execute()
+        TASKS_CANCELLED.add(1, {"docket": self.name})
+
+    @property
+    def strike_key(self) -> str:
+        return f"{self.name}:strikes"
+
+    async def strike(
+        self,
+        function: Callable[P, Awaitable[R]] | str | None = None,
+        parameter: str | None = None,
+        operator: Operator | LiteralOperator = "==",
+        value: Hashable | None = None,
+    ) -> None:
+        if not isinstance(function, (str, type(None))):
+            function = function.__name__
+
+        operator = Operator(operator)
+
+        strike = Strike(function, parameter, operator, value)
+        return await self._send_strike_instruction(strike)
+
+    async def restore(
+        self,
+        function: Callable[P, Awaitable[R]] | str | None = None,
+        parameter: str | None = None,
+        operator: Operator | LiteralOperator = "==",
+        value: Hashable | None = None,
+    ) -> None:
+        if not isinstance(function, (str, type(None))):
+            function = function.__name__
+
+        operator = Operator(operator)
+
+        restore = Restore(function, parameter, operator, value)
+        return await self._send_strike_instruction(restore)
+
+    async def _send_strike_instruction(self, instruction: StrikeInstruction) -> None:
+        with tracer.start_as_current_span(
+            f"docket.{instruction.direction}",
+            attributes={
+                "docket.name": self.name,
+                **instruction.as_span_attributes(),
+            },
+        ):
+            async with self.redis() as redis:
+                message = instruction.as_message()
+                await redis.xadd(self.strike_key, message)  # type: ignore[arg-type]
+        self.strike_list.update(instruction)
+
+    async def _monitor_strikes(self) -> NoReturn:
+        last_id = "0-0"
+        while True:
+            try:
+                async with self.redis() as r:
+                    while True:
+                        streams: RedisReadGroupResponse = await r.xread(
+                            {self.strike_key: last_id},
+                            count=100,
+                            block=60_000,
+                        )
+                        for _, messages in streams:
+                            for message_id, message in messages:
+                                last_id = message_id
+                                instruction = StrikeInstruction.from_message(message)
+                                self.strike_list.update(instruction)
+                                logger.info(
+                                    "%s %r",
+                                    (
+                                        "Striking"
+                                        if instruction.direction == "strike"
+                                        else "Restoring"
+                                    ),
+                                    instruction.call_repr(),
+                                    extra={"docket": self.name},
+                                )
+
+                                counter_labels = {"docket": self.name}
+                                if instruction.function:
+                                    counter_labels["task"] = instruction.function
+                                if instruction.parameter:
+                                    counter_labels["parameter"] = instruction.parameter
+
+                                STRIKES_IN_EFFECT.add(
+                                    1 if instruction.direction == "strike" else -1,
+                                    counter_labels,
+                                )
+
+            except redis.exceptions.ConnectionError:  # pragma: no cover
+                REDIS_DISRUPTIONS.add(1, {"docket": self.name})
+                logger.warning("Connection error, sleeping for 1 second...")
+                await asyncio.sleep(1)
+            except Exception:  # pragma: no cover
+                logger.exception("Error monitoring strikes")
+                await asyncio.sleep(1)
+
+    async def snapshot(self) -> DocketSnapshot:
+        running: list[RunningExecution] = []
+        future: list[Execution] = []
+
+        async with self.redis() as r:
+            async with r.pipeline() as pipeline:
+                pipeline.xlen(self.stream_key)
+
+                pipeline.zcard(self.queue_key)
+
+                pipeline.xpending_range(
+                    self.stream_key,
+                    self.worker_group_name,
+                    min="-",
+                    max="+",
+                    count=1000,
+                )
+
+                pipeline.xrange(self.stream_key, "-", "+", count=1000)
+
+                pipeline.zrange(self.queue_key, 0, -1)
+
+                total_stream_messages: int
+                total_schedule_messages: int
+                pending_messages: list[RedisStreamPendingMessage]
+                stream_messages: list[tuple[RedisMessageID, RedisMessage]]
+                scheduled_task_keys: list[bytes]
+
+                now = datetime.now(timezone.utc)
+                (
+                    total_stream_messages,
+                    total_schedule_messages,
+                    pending_messages,
+                    stream_messages,
+                    scheduled_task_keys,
+                ) = await pipeline.execute()
+
+                for task_key in scheduled_task_keys:
+                    pipeline.hgetall(self.parked_task_key(task_key.decode()))
+
+                # Because these are two separate pipeline commands, it's possible that
+                # a message has been moved from the schedule to the stream in the
+                # meantime, which would end up being an empty `{}` message
+                queued_messages: list[RedisMessage] = [
+                    m for m in await pipeline.execute() if m
+                ]
+
+        total_tasks = total_stream_messages + total_schedule_messages
+
+        pending_lookup: dict[RedisMessageID, RedisStreamPendingMessage] = {
+            pending["message_id"]: pending for pending in pending_messages
+        }
+
+        for message_id, message in stream_messages:
+            function = self.tasks[message[b"function"].decode()]
+            execution = Execution.from_message(function, message)
+            if message_id in pending_lookup:
+                worker_name = pending_lookup[message_id]["consumer"].decode()
+                started = now - timedelta(
+                    milliseconds=pending_lookup[message_id]["time_since_delivered"]
+                )
+                running.append(RunningExecution(execution, worker_name, started))
+            else:
+                future.append(execution)
+
+        for message in queued_messages:
+            function = self.tasks[message[b"function"].decode()]
+            execution = Execution.from_message(function, message)
+            future.append(execution)
+
+        workers = await self.workers()
+
+        return DocketSnapshot(now, total_tasks, future, running, workers)
+
+    @property
+    def workers_set(self) -> str:
+        return f"{self.name}:workers"
+
+    def worker_tasks_set(self, worker_name: str) -> str:
+        return f"{self.name}:worker-tasks:{worker_name}"
+
+    def task_workers_set(self, task_name: str) -> str:
+        return f"{self.name}:task-workers:{task_name}"
+
+    async def workers(self) -> Collection[WorkerInfo]:
+        workers: list[WorkerInfo] = []
+
+        oldest = datetime.now(timezone.utc).timestamp() - (
+            self.heartbeat_interval.total_seconds() * self.missed_heartbeats
+        )
+
+        async with self.redis() as r:
+            await r.zremrangebyscore(self.workers_set, 0, oldest)
+
+            worker_name_bytes: bytes
+            last_seen_timestamp: float
+
+            for worker_name_bytes, last_seen_timestamp in await r.zrange(
+                self.workers_set, 0, -1, withscores=True
+            ):
+                worker_name = worker_name_bytes.decode()
+                last_seen = datetime.fromtimestamp(last_seen_timestamp, timezone.utc)
+
+                task_names: set[str] = {
+                    task_name_bytes.decode()
+                    for task_name_bytes in cast(
+                        set[bytes], await r.smembers(self.worker_tasks_set(worker_name))
+                    )
+                }
+
+                workers.append(WorkerInfo(worker_name, last_seen, task_names))
+
+        return workers
+
+    async def task_workers(self, task_name: str) -> Collection[WorkerInfo]:
+        workers: list[WorkerInfo] = []
+
+        oldest = datetime.now(timezone.utc).timestamp() - (
+            self.heartbeat_interval.total_seconds() * self.missed_heartbeats
+        )
+
+        async with self.redis() as r:
+            await r.zremrangebyscore(self.task_workers_set(task_name), 0, oldest)
+
+            worker_name_bytes: bytes
+            last_seen_timestamp: float
+
+            for worker_name_bytes, last_seen_timestamp in await r.zrange(
+                self.task_workers_set(task_name), 0, -1, withscores=True
+            ):
+                worker_name = worker_name_bytes.decode()
+                last_seen = datetime.fromtimestamp(last_seen_timestamp, timezone.utc)
+
+                task_names: set[str] = {
+                    task_name_bytes.decode()
+                    for task_name_bytes in cast(
+                        set[bytes], await r.smembers(self.worker_tasks_set(worker_name))
+                    )
+                }
+
+                workers.append(WorkerInfo(worker_name, last_seen, task_names))
+
+        return workers
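
For orientation, here is a sketch of how the reworked surface fits together: one Redis URL replaces host/port/db/password, register_collection loads tasks from a "module:collection" path, and strike/restore fan out to every Docket watching the strikes stream. The module path myapp.tasks, its task_collection list, and the process_order task are hypothetical, not part of this package:

import asyncio

from docket.docket import Docket


async def main() -> None:
    # Entering the context creates the stream and consumer group if needed
    # and starts the background strike monitor.
    async with Docket(name="orders", url="redis://localhost:6379/0") as docket:
        # Registers every task in the `task_collection` list exported by
        # myapp/tasks.py (hypothetical module and attribute).
        docket.register_collection("myapp.tasks:task_collection")

        # Refuse to schedule process_order calls where customer_id == 123,
        # then lift the strike again.
        await docket.strike("process_order", "customer_id", "==", 123)
        await docket.restore("process_order", "customer_id", "==", 123)

        # Point-in-time view of queued work, running work, and live workers.
        snapshot = await docket.snapshot()
        print(snapshot.total_tasks, len(snapshot.running), len(snapshot.workers))


asyncio.run(main())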