pydocket 0.0.2__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


docket/worker.py CHANGED
@@ -9,7 +9,6 @@ from typing import (
     Any,
     Protocol,
     Self,
-    Sequence,
     TypeVar,
     cast,
 )
@@ -18,10 +17,17 @@ from uuid import uuid4
 import redis.exceptions
 from opentelemetry import propagate, trace
 from opentelemetry.trace import Tracer
-from redis import RedisError
 
-from .docket import Docket, Execution
+from .docket import (
+    Docket,
+    Execution,
+    RedisMessage,
+    RedisMessageID,
+    RedisMessages,
+    RedisReadGroupResponse,
+)
 from .instrumentation import (
+    REDIS_DISRUPTIONS,
     TASK_DURATION,
     TASK_PUNCTUALITY,
     TASKS_COMPLETED,
@@ -29,6 +35,7 @@ from .instrumentation import (
     TASKS_RETRIED,
     TASKS_RUNNING,
     TASKS_STARTED,
+    TASKS_STRICKEN,
     TASKS_SUCCEEDED,
     message_getter,
 )
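The import block now pulls RedisMessage, RedisMessageID, RedisMessages, and RedisReadGroupResponse from .docket; the module-level aliases they replace are deleted in the next hunk. As a hedged sketch, the aliases might now read roughly like this in docket/docket.py, inferred from the 0.0.2 definitions removed below plus the new RedisMessages name (the actual 0.1.0 definitions are not part of this diff):

# Hypothetical docket/docket.py aliases, inferred from the removed 0.0.2 definitions.
from typing import Sequence

RedisStreamID = bytes
RedisMessageID = bytes
RedisMessage = dict[bytes, bytes]
RedisMessages = Sequence[tuple[RedisMessageID, RedisMessage]]
RedisStream = tuple[RedisStreamID, RedisMessages]
RedisReadGroupResponse = Sequence[RedisStream]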
@@ -37,12 +44,6 @@ logger: logging.Logger = logging.getLogger(__name__)
 tracer: Tracer = trace.get_tracer(__name__)
 
 
-RedisStreamID = bytes
-RedisMessageID = bytes
-RedisMessage = dict[bytes, bytes]
-RedisStream = tuple[RedisStreamID, Sequence[tuple[RedisMessageID, RedisMessage]]]
-RedisReadGroupResponse = Sequence[RedisStream]
-
 if TYPE_CHECKING:  # pragma: no cover
     from .dependencies import Dependency
 
@@ -63,28 +64,20 @@ class Worker:
         self,
         docket: Docket,
         name: str | None = None,
-        prefetch_count: int = 10,
+        concurrency: int = 10,
         redelivery_timeout: timedelta = timedelta(minutes=5),
         reconnection_delay: timedelta = timedelta(seconds=5),
+        minimum_check_interval: timedelta = timedelta(milliseconds=10),
     ) -> None:
         self.docket = docket
         self.name = name or f"worker:{uuid4()}"
-        self.prefetch_count = prefetch_count
+        self.concurrency = concurrency
         self.redelivery_timeout = redelivery_timeout
         self.reconnection_delay = reconnection_delay
+        self.minimum_check_interval = minimum_check_interval
 
     async def __aenter__(self) -> Self:
-        async with self.docket.redis() as redis:
-            try:
-                await redis.xgroup_create(
-                    groupname=self.consumer_group_name,
-                    name=self.docket.stream_key,
-                    id="0-0",
-                    mkstream=True,
-                )
-            except RedisError as e:
-                if "BUSYGROUP" not in repr(e):
-                    raise
+        self._heartbeat_task = asyncio.create_task(self._heartbeat())
 
         return self
 
@@ -94,11 +87,12 @@ class Worker:
         exc_value: BaseException | None,
         traceback: TracebackType | None,
     ) -> None:
-        pass
-
-    @property
-    def consumer_group_name(self) -> str:
-        return "docket"
+        self._heartbeat_task.cancel()
+        try:
+            await self._heartbeat_task
+        except asyncio.CancelledError:
+            pass
+        del self._heartbeat_task
 
     @property
     def _log_context(self) -> dict[str, str]:
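Taken together, __aenter__ now only starts the heartbeat task and __aexit__ cancels and awaits it; consumer-group setup has moved out of the worker. A minimal usage sketch with the renamed concurrency parameter and the new minimum_check_interval; the Docket(name=..., url=...) constructor, its context-manager behavior, and the run_until_finished() method are assumptions not shown in this diff:

# Sketch only; Docket's API and run_until_finished() are assumed, not confirmed here.
import asyncio
from datetime import timedelta

from docket import Docket, Worker  # assumed public exports


async def main() -> None:
    async with Docket(name="docket", url="redis://localhost:6379/0") as docket:
        async with Worker(
            docket=docket,
            concurrency=10,  # replaces prefetch_count from 0.0.2
            redelivery_timeout=timedelta(minutes=5),
            reconnection_delay=timedelta(seconds=5),
            minimum_check_interval=timedelta(milliseconds=10),
        ) as worker:
            await worker.run_until_finished()  # assumed method name


asyncio.run(main())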
@@ -113,7 +107,7 @@ class Worker:
         docket_name: str = "docket",
         url: str = "redis://localhost:6379/0",
         name: str | None = None,
-        prefetch_count: int = 10,
+        concurrency: int = 10,
         redelivery_timeout: timedelta = timedelta(minutes=5),
         reconnection_delay: timedelta = timedelta(seconds=5),
         until_finished: bool = False,
@@ -126,7 +120,7 @@ class Worker:
            async with Worker(
                docket=docket,
                name=name,
-                prefetch_count=prefetch_count,
+                concurrency=concurrency,
                redelivery_timeout=redelivery_timeout,
                reconnection_delay=reconnection_delay,
            ) as worker:
@@ -153,6 +147,9 @@ class Worker:
         try:
             return await self._worker_loop(forever=forever)
         except redis.exceptions.ConnectionError:
+            REDIS_DISRUPTIONS.add(
+                1, {"docket": self.docket.name, "worker": self.name}
+            )
             logger.warning(
                 "Error connecting to redis, retrying in %s...",
                 self.reconnection_delay,
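REDIS_DISRUPTIONS is imported from .instrumentation, whose definition is not part of this diff. As a hedged sketch, an OpenTelemetry counter with this shape could be declared roughly as follows (the metric name and description are assumptions):

# Hypothetical sketch; docket's real instrumentation module is not shown in this diff.
from opentelemetry import metrics

meter = metrics.get_meter("docket")

REDIS_DISRUPTIONS = meter.create_counter(
    "docket.redis.disruptions",  # assumed metric name
    description="Connection errors observed while talking to Redis",
)

# Incremented the same way worker.py does above:
# REDIS_DISRUPTIONS.add(1, {"docket": docket_name, "worker": worker_name})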
@@ -207,67 +204,121 @@ class Worker:
             ),
         )
 
-        total_work, due_work = sys.maxsize, 0
-        while forever or total_work:
-            now = datetime.now(timezone.utc)
-            total_work, due_work = await stream_due_tasks(
-                keys=[self.docket.queue_key, self.docket.stream_key],
-                args=[now.timestamp(), self.docket.name],
-            )
-            if due_work > 0:
-                logger.debug(
-                    "Moved %d/%d due tasks from %s to %s",
-                    due_work,
-                    total_work,
-                    self.docket.queue_key,
-                    self.docket.stream_key,
-                    extra=self._log_context,
-                )
+        active_tasks: dict[asyncio.Task[None], RedisMessageID] = {}
 
-            _, redeliveries, _ = await redis.xautoclaim(
-                name=self.docket.stream_key,
-                groupname=self.consumer_group_name,
-                consumername=self.name,
-                min_idle_time=int(self.redelivery_timeout.total_seconds() * 1000),
-                start_id="0-0",
-                count=self.prefetch_count,
-            )
+        async def process_completed_tasks() -> None:
+            completed_tasks = {task for task in active_tasks if task.done()}
+            for task in completed_tasks:
+                message_id = active_tasks.pop(task)
 
-            new_deliveries: RedisReadGroupResponse = await redis.xreadgroup(
-                groupname=self.consumer_group_name,
-                consumername=self.name,
-                streams={self.docket.stream_key: ">"},
-                count=self.prefetch_count,
-                block=10,
-            )
+                await task
+
+                async with redis.pipeline() as pipeline:
+                    pipeline.xack(
+                        self.docket.stream_key,
+                        self.docket.worker_group_name,
+                        message_id,
+                    )
+                    pipeline.xdel(
+                        self.docket.stream_key,
+                        message_id,
+                    )
+                    await pipeline.execute()
+
+        future_work, due_work = sys.maxsize, 0
+
+        try:
+            while forever or future_work or active_tasks:
+                await process_completed_tasks()
+
+                available_slots = self.concurrency - len(active_tasks)
+
+                def start_task(
+                    message_id: RedisMessageID, message: RedisMessage
+                ) -> None:
+                    task = asyncio.create_task(self._execute(message))
+                    active_tasks[task] = message_id
+
+                    nonlocal available_slots, future_work
+                    available_slots -= 1
+                    future_work += 1
+
+                if available_slots <= 0:
+                    await asyncio.sleep(self.minimum_check_interval.total_seconds())
+                    continue
+
+                future_work, due_work = await stream_due_tasks(
+                    keys=[self.docket.queue_key, self.docket.stream_key],
+                    args=[datetime.now(timezone.utc).timestamp(), self.docket.name],
+                )
+                if due_work > 0:
+                    logger.debug(
+                        "Moved %d/%d due tasks from %s to %s",
+                        due_work,
+                        future_work,
+                        self.docket.queue_key,
+                        self.docket.stream_key,
+                        extra=self._log_context,
+                    )
+
+                redeliveries: RedisMessages
+                _, redeliveries, _ = await redis.xautoclaim(
+                    name=self.docket.stream_key,
+                    groupname=self.docket.worker_group_name,
+                    consumername=self.name,
+                    min_idle_time=int(
+                        self.redelivery_timeout.total_seconds() * 1000
+                    ),
+                    start_id="0-0",
+                    count=available_slots,
+                )
 
-            for source in [[(b"redeliveries", redeliveries)], new_deliveries]:
-                for _, messages in source:
+                for message_id, message in redeliveries:
+                    start_task(message_id, message)
+                    if available_slots <= 0:
+                        break
+
+                if available_slots <= 0:
+                    continue
+
+                new_deliveries: RedisReadGroupResponse = await redis.xreadgroup(
+                    groupname=self.docket.worker_group_name,
+                    consumername=self.name,
+                    streams={self.docket.stream_key: ">"},
+                    block=(
+                        int(self.minimum_check_interval.total_seconds() * 1000)
+                        if forever or active_tasks
+                        else None
+                    ),
+                    count=available_slots,
+                )
+                for _, messages in new_deliveries:
                     for message_id, message in messages:
-                        await self._execute(message)
-
-                        async with redis.pipeline() as pipeline:
-                            pipeline.xack(
-                                self.docket.stream_key,
-                                self.consumer_group_name,
-                                message_id,
-                            )
-                            pipeline.xdel(
-                                self.docket.stream_key,
-                                message_id,
-                            )
-                            await pipeline.execute()
-
-                        # When executing a task, there's always a chance that it was
-                        # either retried or it scheduled another task, so let's give
-                        # ourselves one more iteration of the loop to handle that.
-                        total_work += 1
+                        start_task(message_id, message)
+                        if available_slots <= 0:
+                            break
+        except asyncio.CancelledError:
+            if active_tasks:  # pragma: no cover
+                logger.info(
+                    "Shutdown requested, finishing %d active tasks...",
+                    len(active_tasks),
+                    extra=self._log_context,
+                )
+        finally:
+            if active_tasks:
+                await asyncio.gather(*active_tasks, return_exceptions=True)
+                await process_completed_tasks()
 
     async def _execute(self, message: RedisMessage) -> None:
-        execution = Execution.from_message(
-            self.docket.tasks[message[b"function"].decode()],
-            message,
-        )
+        function_name = message[b"function"].decode()
+        function = self.docket.tasks.get(function_name)
+        if function is None:
+            logger.warning(
+                "Task function %r not found", function_name, extra=self._log_context
+            )
+            return
+
+        execution = Execution.from_message(function, message)
         name = execution.function.__name__
         key = execution.key
 
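The rewritten loop replaces "prefetch a batch and process it serially" with a slot-bounded scheduler: it reaps finished asyncio tasks, computes available_slots = concurrency - len(active_tasks), and only claims or reads that many messages per pass. A standalone sketch of the same pattern follows (generic asyncio, not docket's API; fetch_batch() is a hypothetical stand-in for xautoclaim/xreadgroup with count=available_slots):

# Slot-bounded polling loop: keep at most `concurrency` tasks in flight,
# reap finished ones, then refill only the free slots.
import asyncio
import random


async def handle(item: int) -> None:
    # stand-in for Worker._execute
    await asyncio.sleep(random.random() / 10)


async def fetch_batch(limit: int) -> list[int]:
    # stand-in for claiming/reading at most `limit` messages
    return [random.randint(0, 999) for _ in range(min(limit, 3))]


async def bounded_loop(concurrency: int = 10, rounds: int = 20) -> None:
    active: dict[asyncio.Task[None], int] = {}
    for _ in range(rounds):
        # reap completed tasks, acknowledging each (mirrors process_completed_tasks)
        for task in [t for t in active if t.done()]:
            item = active.pop(task)
            await task  # surfaces exceptions from the handler
            print("acknowledged", item)

        slots = concurrency - len(active)
        if slots <= 0:
            await asyncio.sleep(0.01)  # plays the role of minimum_check_interval
            continue

        for item in await fetch_batch(slots):
            active[asyncio.create_task(handle(item))] = item

    if active:
        await asyncio.gather(*active, return_exceptions=True)


asyncio.run(bounded_loop())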
@@ -282,6 +333,15 @@ class Worker:
             "task": name,
         }
 
+        arrow = "↬" if execution.attempt > 1 else "↪"
+        call = execution.call_repr()
+
+        if self.docket.strike_list.is_stricken(execution):
+            arrow = "🗙"
+            logger.warning("%s %s", arrow, call, extra=log_context)
+            TASKS_STRICKEN.add(1, counter_labels | {"where": "worker"})
+            return
+
         dependencies = self._get_dependencies(execution)
 
         context = propagate.extract(message, getter=message_getter)
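Each task's span is linked to its producer via propagate.extract(message, getter=message_getter). The message_getter object comes from docket.instrumentation and is not shown here; as an illustration of the general pattern only, a getter over dict[bytes, bytes] stream messages could look like this (class and variable names are hypothetical):

# Illustration only; docket's actual message_getter lives in docket/instrumentation.py.
from typing import Optional

from opentelemetry.propagators.textmap import Getter


class BytesMessageGetter(Getter[dict[bytes, bytes]]):
    """Reads W3C trace-context headers out of a Redis stream message."""

    def get(self, carrier: dict[bytes, bytes], key: str) -> Optional[list[str]]:
        value = carrier.get(key.encode())
        return [value.decode()] if value is not None else None

    def keys(self, carrier: dict[bytes, bytes]) -> list[str]:
        return [k.decode() for k in carrier]


# Usage would mirror the context line above:
# context = propagate.extract(message, getter=BytesMessageGetter())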
@@ -297,8 +357,6 @@ class Worker:
             TASKS_RUNNING.add(1, counter_labels)
             TASK_PUNCTUALITY.record(punctuality.total_seconds(), counter_labels)
 
-            arrow = "↬" if execution.attempt > 1 else "↪"
-            call = execution.call_repr()
             logger.info("%s [%s] %s", arrow, punctuality, call, extra=log_context)
 
             try:
@@ -387,3 +445,52 @@ class Worker:
             return True
 
         return False
+
+    @property
+    def workers_set(self) -> str:
+        return self.docket.workers_set
+
+    def worker_tasks_set(self, worker_name: str) -> str:
+        return self.docket.worker_tasks_set(worker_name)
+
+    def task_workers_set(self, task_name: str) -> str:
+        return self.docket.task_workers_set(task_name)
+
+    async def _heartbeat(self) -> None:
+        while True:
+            await asyncio.sleep(self.docket.heartbeat_interval.total_seconds())
+            try:
+                now = datetime.now(timezone.utc).timestamp()
+                maximum_age = (
+                    self.docket.heartbeat_interval * self.docket.missed_heartbeats
+                )
+                oldest = now - maximum_age.total_seconds()
+
+                task_names = list(self.docket.tasks)
+
+                async with self.docket.redis() as r:
+                    async with r.pipeline() as pipeline:
+                        pipeline.zremrangebyscore(self.workers_set, 0, oldest)
+                        pipeline.zadd(self.workers_set, {self.name: now})
+
+                        for task_name in task_names:
+                            task_workers_set = self.task_workers_set(task_name)
+                            pipeline.zremrangebyscore(task_workers_set, 0, oldest)
+                            pipeline.zadd(task_workers_set, {self.name: now})
+
+                        pipeline.sadd(self.worker_tasks_set(self.name), *task_names)
+                        pipeline.expire(
+                            self.worker_tasks_set(self.name),
+                            max(maximum_age, timedelta(seconds=1)),
+                        )
+
+                        await pipeline.execute()
+            except asyncio.CancelledError:  # pragma: no cover
+                return
+            except redis.exceptions.ConnectionError:
+                REDIS_DISRUPTIONS.add(
+                    1, {"docket": self.docket.name, "worker": self.name}
+                )
+                logger.exception("Error sending worker heartbeat", exc_info=True)
+            except Exception:
+                logger.exception("Error sending worker heartbeat", exc_info=True)
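The heartbeat keeps sorted sets keyed by worker name and scored by the last-beat timestamp, pruning entries older than heartbeat_interval * missed_heartbeats. A sketch of how another process could read liveness back out of that scheme; the key name, interval, and missed-heartbeat count below are assumptions, since the real values come from Docket and are not shown in this diff:

# Sketch with assumed values: "docket:workers" key name, 2s interval, 5 missed beats.
import asyncio
from datetime import datetime, timedelta, timezone

from redis.asyncio import Redis


async def live_workers(url: str = "redis://localhost:6379/0") -> list[str]:
    heartbeat_interval = timedelta(seconds=2)  # assumed Docket default
    missed_heartbeats = 5                      # assumed Docket default
    cutoff = (
        datetime.now(timezone.utc) - heartbeat_interval * missed_heartbeats
    ).timestamp()
    async with Redis.from_url(url) as redis:
        # members whose last heartbeat score is newer than the cutoff are alive
        members = await redis.zrangebyscore("docket:workers", cutoff, "+inf")
    return [member.decode() for member in members]


print(asyncio.run(live_workers()))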