pydocket 0.0.2__py3-none-any.whl → 0.1.0__py3-none-any.whl

docket/docket.py CHANGED
@@ -1,32 +1,54 @@
+import asyncio
 import importlib
+import logging
 from contextlib import asynccontextmanager
-from datetime import datetime, timezone
+from dataclasses import dataclass
+from datetime import datetime, timedelta, timezone
 from types import TracebackType
 from typing import (
     Any,
     AsyncGenerator,
     Awaitable,
     Callable,
+    Collection,
+    Hashable,
     Iterable,
+    NoReturn,
     ParamSpec,
     Self,
+    Sequence,
+    TypedDict,
     TypeVar,
+    cast,
     overload,
 )
 from uuid import uuid4
 
+import redis.exceptions
 from opentelemetry import propagate, trace
 from redis.asyncio import Redis
 
-from .execution import Execution
+from .execution import (
+    Execution,
+    LiteralOperator,
+    Operator,
+    Restore,
+    Strike,
+    StrikeInstruction,
+    StrikeList,
+)
 from .instrumentation import (
+    REDIS_DISRUPTIONS,
+    STRIKES_IN_EFFECT,
     TASKS_ADDED,
     TASKS_CANCELLED,
     TASKS_REPLACED,
     TASKS_SCHEDULED,
+    TASKS_STRICKEN,
     message_setter,
 )
 
+logger: logging.Logger = logging.getLogger(__name__)
 tracer: trace.Tracer = trace.get_tracer(__name__)
 
 
@@ -35,14 +57,67 @@ R = TypeVar("R")
 
 TaskCollection = Iterable[Callable[..., Awaitable[Any]]]
 
+RedisStreamID = bytes
+RedisMessageID = bytes
+RedisMessage = dict[bytes, bytes]
+RedisMessages = Sequence[tuple[RedisMessageID, RedisMessage]]
+RedisStream = tuple[RedisStreamID, RedisMessages]
+RedisReadGroupResponse = Sequence[RedisStream]
+
+
+class RedisStreamPendingMessage(TypedDict):
+    message_id: bytes
+    consumer: bytes
+    time_since_delivered: int
+    times_delivered: int
+
+
+@dataclass
+class WorkerInfo:
+    name: str
+    last_seen: datetime
+    tasks: set[str]
+
+
+class RunningExecution(Execution):
+    worker: str
+    started: datetime
+
+    def __init__(
+        self,
+        execution: Execution,
+        worker: str,
+        started: datetime,
+    ) -> None:
+        self.function: Callable[..., Awaitable[Any]] = execution.function
+        self.args: tuple[Any, ...] = execution.args
+        self.kwargs: dict[str, Any] = execution.kwargs
+        self.when: datetime = execution.when
+        self.key: str = execution.key
+        self.attempt: int = execution.attempt
+        self.worker = worker
+        self.started = started
+
+
+@dataclass
+class DocketSnapshot:
+    taken: datetime
+    total_tasks: int
+    future: Sequence[Execution]
+    running: Sequence[RunningExecution]
+    workers: Collection[WorkerInfo]
+
 
 class Docket:
     tasks: dict[str, Callable[..., Awaitable[Any]]]
+    strike_list: StrikeList
 
     def __init__(
         self,
         name: str = "docket",
         url: str = "redis://localhost:6379/0",
+        heartbeat_interval: timedelta = timedelta(seconds=1),
+        missed_heartbeats: int = 5,
     ) -> None:
         """
         Args:
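
The new `Redis*` aliases describe the nested reply shape that redis-py returns from XREAD/XREADGROUP. A rough illustration of how they line up (the stream name is made up, and this assumes a local Redis reachable with redis-py's asyncio client):

import asyncio
from redis.asyncio import Redis

async def main() -> None:
    async with Redis.from_url("redis://localhost:6379/0") as r:
        await r.xadd("demo", {b"function": b"send_email"})
        # The reply is a sequence of (stream id, messages) pairs, where each
        # message is an (id, fields) pair -- i.e. RedisReadGroupResponse above.
        for stream_name, messages in await r.xread({"demo": "0-0"}, count=10):
            for message_id, fields in messages:
                print(stream_name, message_id, fields)

asyncio.run(main())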
@@ -56,11 +131,33 @@ class Docket:
         """
         self.name = name
         self.url = url
+        self.heartbeat_interval = heartbeat_interval
+        self.missed_heartbeats = missed_heartbeats
+
+    @property
+    def worker_group_name(self) -> str:
+        return "docket-workers"
 
     async def __aenter__(self) -> Self:
         from .tasks import standard_tasks
 
         self.tasks = {fn.__name__: fn for fn in standard_tasks}
+        self.strike_list = StrikeList()
+
+        self._monitor_strikes_task = asyncio.create_task(self._monitor_strikes())
+
+        # Ensure that the stream and worker group exist
+        async with self.redis() as r:
+            try:
+                await r.xgroup_create(
+                    groupname=self.worker_group_name,
+                    name=self.stream_key,
+                    id="0-0",
+                    mkstream=True,
+                )
+            except redis.exceptions.RedisError as e:
+                if "BUSYGROUP" not in repr(e):
+                    raise
 
         return self
 
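
The new `__aenter__` makes consumer-group creation idempotent: XGROUP CREATE fails with an error containing BUSYGROUP when the group already exists, and only that case is swallowed. The same pattern in isolation (stream and group names here are illustrative):

import redis.exceptions
from redis.asyncio import Redis

async def ensure_group(r: Redis, stream: str, group: str) -> None:
    try:
        # mkstream=True also creates the stream itself if it is missing
        await r.xgroup_create(groupname=group, name=stream, id="0-0", mkstream=True)
    except redis.exceptions.ResponseError as e:
        # a second create of the same group fails with BUSYGROUP; that's fine
        if "BUSYGROUP" not in repr(e):
            raise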
@@ -70,11 +167,18 @@ class Docket:
         exc_value: BaseException | None,
         traceback: TracebackType | None,
     ) -> None:
-        pass
+        del self.tasks
+        del self.strike_list
+
+        self._monitor_strikes_task.cancel()
+        try:
+            await self._monitor_strikes_task
+        except asyncio.CancelledError:
+            pass
 
     @asynccontextmanager
     async def redis(self) -> AsyncGenerator[Redis, None]:
-        async with Redis.from_url(self.url) as redis:
+        async with Redis.from_url(self.url) as redis:  # type: ignore
             yield redis
 
     def register(self, function: Callable[..., Awaitable[Any]]) -> None:
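
Teardown now follows the standard cancel-and-await pattern: `cancel()` only requests cancellation, and awaiting the task afterwards both lets it unwind and consumes the resulting `CancelledError`. A minimal standalone sketch:

import asyncio

async def stop(task: asyncio.Task) -> None:
    task.cancel()  # request cancellation; the task is not stopped yet
    try:
        await task  # let it unwind, and swallow the CancelledError it raises
    except asyncio.CancelledError:
        pass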
@@ -188,6 +292,22 @@ class Docket:
         return f"{self.name}:{key}"
 
     async def schedule(self, execution: Execution) -> None:
+        if self.strike_list.is_stricken(execution):
+            logger.warning(
+                "%r is stricken, skipping schedule of %r",
+                execution.function.__name__,
+                execution.key,
+            )
+            TASKS_STRICKEN.add(
+                1,
+                {
+                    "docket": self.name,
+                    "task": execution.function.__name__,
+                    "where": "docket",
+                },
+            )
+            return
+
         message: dict[bytes, bytes] = execution.as_message()
         propagate.inject(message, setter=message_setter)
 
@@ -210,10 +330,10 @@ class Docket:
             return
 
         if when <= datetime.now(timezone.utc):
-            await redis.xadd(self.stream_key, message)
+            await redis.xadd(self.stream_key, message)  # type: ignore[arg-type]
         else:
             async with redis.pipeline() as pipe:
-                pipe.hset(self.parked_task_key(key), mapping=message)
+                pipe.hset(self.parked_task_key(key), mapping=message)  # type: ignore[arg-type]
                 pipe.zadd(self.queue_key, {key: when.timestamp()})
                 await pipe.execute()
 
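
Scheduling stays two-pathed: tasks that are already due go straight onto the stream with XADD, while future tasks are "parked" as a hash and indexed by due time in a sorted set, so they can later be fetched by score. A sketch of that parking pattern under made-up key names (the follow-up scan is an assumption about how parked tasks get promoted, not code from this package):

import time
from redis.asyncio import Redis

async def park(r: Redis, key: str, message: dict[bytes, bytes], due_at: float) -> None:
    async with r.pipeline() as pipe:
        pipe.hset(f"parked:{key}", mapping=message)  # full message, keyed by task key
        pipe.zadd("queue", {key: due_at})            # due time as the sort score
        await pipe.execute()

async def ready(r: Redis) -> list[bytes]:
    # every key whose score (due time) is at or before now is ready to run
    return await r.zrangebyscore("queue", 0, time.time())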
@@ -236,3 +356,240 @@ class Docket:
             await pipe.execute()
 
         TASKS_CANCELLED.add(1, {"docket": self.name})
+
+    @property
+    def strike_key(self) -> str:
+        return f"{self.name}:strikes"
+
+    async def strike(
+        self,
+        function: Callable[P, Awaitable[R]] | str | None = None,
+        parameter: str | None = None,
+        operator: Operator | LiteralOperator = "==",
+        value: Hashable | None = None,
+    ) -> None:
+        if not isinstance(function, (str, type(None))):
+            function = function.__name__
+
+        operator = Operator(operator)
+
+        strike = Strike(function, parameter, operator, value)
+        return await self._send_strike_instruction(strike)
+
+    async def restore(
+        self,
+        function: Callable[P, Awaitable[R]] | str | None = None,
+        parameter: str | None = None,
+        operator: Operator | LiteralOperator = "==",
+        value: Hashable | None = None,
+    ) -> None:
+        if not isinstance(function, (str, type(None))):
+            function = function.__name__
+
+        operator = Operator(operator)
+
+        restore = Restore(function, parameter, operator, value)
+        return await self._send_strike_instruction(restore)
+
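
`strike()` and `restore()` accept either the task function or its name, plus an optional parameter/operator/value condition whose exact semantics live in `docket.execution`. A hypothetical usage sketch (the `send_email` task and its `customer_id` parameter are illustrative, not part of the package):

async with Docket(name="orders") as docket:
    # block every send_email task...
    await docket.strike("send_email")
    # ...or only those whose customer_id argument equals 12345
    await docket.strike("send_email", "customer_id", "==", 12345)

    # lift the narrower strike again
    await docket.restore("send_email", "customer_id", "==", 12345)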
+    async def _send_strike_instruction(self, instruction: StrikeInstruction) -> None:
+        with tracer.start_as_current_span(
+            f"docket.{instruction.direction}",
+            attributes={
+                "docket.name": self.name,
+                **instruction.as_span_attributes(),
+            },
+        ):
+            async with self.redis() as redis:
+                message = instruction.as_message()
+                await redis.xadd(self.strike_key, message)  # type: ignore[arg-type]
+                self.strike_list.update(instruction)
+
+    async def _monitor_strikes(self) -> NoReturn:
+        last_id = "0-0"
+        while True:
+            try:
+                async with self.redis() as r:
+                    while True:
+                        streams: RedisReadGroupResponse = await r.xread(
+                            {self.strike_key: last_id},
+                            count=100,
+                            block=60_000,
+                        )
+                        for _, messages in streams:
+                            for message_id, message in messages:
+                                last_id = message_id
+                                instruction = StrikeInstruction.from_message(message)
+                                self.strike_list.update(instruction)
+                                logger.info(
+                                    "%s %r",
+                                    (
+                                        "Striking"
+                                        if instruction.direction == "strike"
+                                        else "Restoring"
+                                    ),
+                                    instruction.call_repr(),
+                                    extra={"docket": self.name},
+                                )
+
+                                counter_labels = {"docket": self.name}
+                                if instruction.function:
+                                    counter_labels["task"] = instruction.function
+                                if instruction.parameter:
+                                    counter_labels["parameter"] = instruction.parameter
+
+                                STRIKES_IN_EFFECT.add(
+                                    1 if instruction.direction == "strike" else -1,
+                                    counter_labels,
+                                )
+
+            except redis.exceptions.ConnectionError:  # pragma: no cover
+                REDIS_DISRUPTIONS.add(1, {"docket": self.name})
+                logger.warning("Connection error, sleeping for 1 second...")
+                await asyncio.sleep(1)
+            except Exception:  # pragma: no cover
+                logger.exception("Error monitoring strikes")
+                await asyncio.sleep(1)
+
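
`_monitor_strikes` is a long-lived follower of the strike stream: a blocking XREAD tail loop that remembers the last message id so each iteration resumes where the previous one left off, and reconnects after connection errors. The core of the pattern, reduced to a sketch (the stream name is illustrative):

from redis.asyncio import Redis

async def tail(r: Redis, stream: str) -> None:
    last_id = "0-0"
    while True:
        # block for up to 60s waiting for entries newer than last_id
        for _, messages in await r.xread({stream: last_id}, count=100, block=60_000):
            for message_id, message in messages:
                last_id = message_id  # resume point for the next xread
                print(message)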
+    async def snapshot(self) -> DocketSnapshot:
+        running: list[RunningExecution] = []
+        future: list[Execution] = []
+
+        async with self.redis() as r:
+            async with r.pipeline() as pipeline:
+                pipeline.xlen(self.stream_key)
+
+                pipeline.zcard(self.queue_key)
+
+                pipeline.xpending_range(
+                    self.stream_key,
+                    self.worker_group_name,
+                    min="-",
+                    max="+",
+                    count=1000,
+                )
+
+                pipeline.xrange(self.stream_key, "-", "+", count=1000)
+
+                pipeline.zrange(self.queue_key, 0, -1)
+
+                total_stream_messages: int
+                total_schedule_messages: int
+                pending_messages: list[RedisStreamPendingMessage]
+                stream_messages: list[tuple[RedisMessageID, RedisMessage]]
+                scheduled_task_keys: list[bytes]
+
+                now = datetime.now(timezone.utc)
+                (
+                    total_stream_messages,
+                    total_schedule_messages,
+                    pending_messages,
+                    stream_messages,
+                    scheduled_task_keys,
+                ) = await pipeline.execute()
+
+                for task_key in scheduled_task_keys:
+                    pipeline.hgetall(self.parked_task_key(task_key.decode()))
+
+                # Because these are two separate pipeline commands, it's possible that
+                # a message has been moved from the schedule to the stream in the
+                # meantime, which would end up being an empty `{}` message
+                queued_messages: list[RedisMessage] = [
+                    m for m in await pipeline.execute() if m
+                ]
+
+        total_tasks = total_stream_messages + total_schedule_messages
+
+        pending_lookup: dict[RedisMessageID, RedisStreamPendingMessage] = {
+            pending["message_id"]: pending for pending in pending_messages
+        }
+
+        for message_id, message in stream_messages:
+            function = self.tasks[message[b"function"].decode()]
+            execution = Execution.from_message(function, message)
+            if message_id in pending_lookup:
+                worker_name = pending_lookup[message_id]["consumer"].decode()
+                started = now - timedelta(
+                    milliseconds=pending_lookup[message_id]["time_since_delivered"]
+                )
+                running.append(RunningExecution(execution, worker_name, started))
+            else:
+                future.append(execution)
+
+        for message in queued_messages:
+            function = self.tasks[message[b"function"].decode()]
+            execution = Execution.from_message(function, message)
+            future.append(execution)
+
+        workers = await self.workers()
+
+        return DocketSnapshot(now, total_tasks, future, running, workers)
+
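
`snapshot()` gathers the counts, pending entries, stream contents, and scheduled keys in one pipelined round-trip, fetches the parked hashes in a second, then classifies each stream message as running (it appears in the consumer group's pending list) or future. A hypothetical way to consume the result:

async with Docket(name="orders") as docket:
    snap = await docket.snapshot()
    print(f"{snap.total_tasks} tasks as of {snap.taken:%H:%M:%S}")
    for execution in snap.running:
        print("running:", execution.key, "on", execution.worker, "since", execution.started)
    for execution in snap.future:
        print("future: ", execution.key, "due", execution.when)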
+    @property
+    def workers_set(self) -> str:
+        return f"{self.name}:workers"
+
+    def worker_tasks_set(self, worker_name: str) -> str:
+        return f"{self.name}:worker-tasks:{worker_name}"
+
+    def task_workers_set(self, task_name: str) -> str:
+        return f"{self.name}:task-workers:{task_name}"
+
+    async def workers(self) -> Collection[WorkerInfo]:
+        workers: list[WorkerInfo] = []
+
+        oldest = datetime.now(timezone.utc).timestamp() - (
+            self.heartbeat_interval.total_seconds() * self.missed_heartbeats
+        )
+
+        async with self.redis() as r:
+            await r.zremrangebyscore(self.workers_set, 0, oldest)
+
+            worker_name_bytes: bytes
+            last_seen_timestamp: float
+
+            for worker_name_bytes, last_seen_timestamp in await r.zrange(
+                self.workers_set, 0, -1, withscores=True
+            ):
+                worker_name = worker_name_bytes.decode()
+                last_seen = datetime.fromtimestamp(last_seen_timestamp, timezone.utc)
+
+                task_names: set[str] = {
+                    task_name_bytes.decode()
+                    for task_name_bytes in cast(
+                        set[bytes], await r.smembers(self.worker_tasks_set(worker_name))
+                    )
+                }
+
+                workers.append(WorkerInfo(worker_name, last_seen, task_names))
+
+        return workers
+
+    async def task_workers(self, task_name: str) -> Collection[WorkerInfo]:
+        workers: list[WorkerInfo] = []
+
+        oldest = datetime.now(timezone.utc).timestamp() - (
+            self.heartbeat_interval.total_seconds() * self.missed_heartbeats
+        )
+
+        async with self.redis() as r:
+            await r.zremrangebyscore(self.task_workers_set(task_name), 0, oldest)
+
+            worker_name_bytes: bytes
+            last_seen_timestamp: float
+
+            for worker_name_bytes, last_seen_timestamp in await r.zrange(
+                self.task_workers_set(task_name), 0, -1, withscores=True
+            ):
+                worker_name = worker_name_bytes.decode()
+                last_seen = datetime.fromtimestamp(last_seen_timestamp, timezone.utc)
+
+                task_names: set[str] = {
+                    task_name_bytes.decode()
+                    for task_name_bytes in cast(
+                        set[bytes], await r.smembers(self.worker_tasks_set(worker_name))
+                    )
+                }
+
+                workers.append(WorkerInfo(worker_name, last_seen, task_names))
+
+        return workers
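
Worker liveness falls out of the heartbeat arithmetic: each worker's sorted-set score is its last-seen timestamp, and anything older than heartbeat_interval × missed_heartbeats is pruned before reading. With the defaults that is a five-second cutoff:

from datetime import datetime, timedelta, timezone

heartbeat_interval = timedelta(seconds=1)  # default
missed_heartbeats = 5                      # default

# matches the `oldest` computation in workers()/task_workers() above:
oldest = datetime.now(timezone.utc).timestamp() - (
    heartbeat_interval.total_seconds() * missed_heartbeats
)
# zremrangebyscore(workers_set, 0, oldest) then drops every worker whose
# last heartbeat is five or more seconds old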