pydocket 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pydocket might be problematic. See the package registry's advisory page for more details.

docket/worker.py CHANGED
@@ -1,21 +1,48 @@
1
+ import asyncio
2
+ import inspect
1
3
  import logging
2
4
  import sys
3
- from datetime import datetime, timezone
5
+ from datetime import datetime, timedelta, timezone
4
6
  from types import TracebackType
5
- from typing import TYPE_CHECKING, Any, Protocol, Self, Sequence, TypeVar, cast
7
+ from typing import (
8
+ TYPE_CHECKING,
9
+ Any,
10
+ Protocol,
11
+ Self,
12
+ TypeVar,
13
+ cast,
14
+ )
6
15
  from uuid import uuid4
7
16
 
8
- from redis import RedisError
9
-
10
- from .docket import Docket, Execution
17
+ import redis.exceptions
18
+ from opentelemetry import propagate, trace
19
+ from opentelemetry.trace import Tracer
20
+
21
+ from .docket import (
22
+ Docket,
23
+ Execution,
24
+ RedisMessage,
25
+ RedisMessageID,
26
+ RedisMessages,
27
+ RedisReadGroupResponse,
28
+ )
29
+ from .instrumentation import (
30
+ REDIS_DISRUPTIONS,
31
+ TASK_DURATION,
32
+ TASK_PUNCTUALITY,
33
+ TASKS_COMPLETED,
34
+ TASKS_FAILED,
35
+ TASKS_RETRIED,
36
+ TASKS_RUNNING,
37
+ TASKS_STARTED,
38
+ TASKS_STRICKEN,
39
+ TASKS_SUCCEEDED,
40
+ message_getter,
41
+ )
11
42
 
12
43
  logger: logging.Logger = logging.getLogger(__name__)
44
+ tracer: Tracer = trace.get_tracer(__name__)
13
45
 
14
- RedisStreamID = bytes
15
- RedisMessageID = bytes
16
- RedisMessage = dict[bytes, bytes]
17
- RedisStream = tuple[RedisStreamID, Sequence[tuple[RedisMessageID, RedisMessage]]]
18
- RedisReadGroupResponse = Sequence[RedisStream]
19
46
 
20
47
  if TYPE_CHECKING: # pragma: no cover
21
48
  from .dependencies import Dependency
@@ -30,26 +57,27 @@ class _stream_due_tasks(Protocol):
30
57
 
31
58
 
32
59
  class Worker:
33
- name: str
34
60
  docket: Docket
61
+ name: str
35
62
 
36
- prefetch_count: int = 10
37
-
38
- def __init__(self, docket: Docket) -> None:
39
- self.name = f"worker:{uuid4()}"
63
+ def __init__(
64
+ self,
65
+ docket: Docket,
66
+ name: str | None = None,
67
+ concurrency: int = 10,
68
+ redelivery_timeout: timedelta = timedelta(minutes=5),
69
+ reconnection_delay: timedelta = timedelta(seconds=5),
70
+ minimum_check_interval: timedelta = timedelta(milliseconds=10),
71
+ ) -> None:
40
72
  self.docket = docket
73
+ self.name = name or f"worker:{uuid4()}"
74
+ self.concurrency = concurrency
75
+ self.redelivery_timeout = redelivery_timeout
76
+ self.reconnection_delay = reconnection_delay
77
+ self.minimum_check_interval = minimum_check_interval
41
78
 
42
79
  async def __aenter__(self) -> Self:
43
- async with self.docket.redis() as redis:
44
- try:
45
- await redis.xgroup_create(
46
- groupname=self.consumer_group_name,
47
- name=self.docket.stream_key,
48
- id="0-0",
49
- mkstream=True,
50
- )
51
- except RedisError as e:
52
- assert "BUSYGROUP" in repr(e)
80
+ self._heartbeat_task = asyncio.create_task(self._heartbeat())
53
81
 
54
82
  return self
55
83
 
@@ -59,11 +87,12 @@ class Worker:
59
87
  exc_value: BaseException | None,
60
88
  traceback: TracebackType | None,
61
89
  ) -> None:
62
- pass
63
-
64
- @property
65
- def consumer_group_name(self) -> str:
66
- return "docket"
90
+ self._heartbeat_task.cancel()
91
+ try:
92
+ await self._heartbeat_task
93
+ except asyncio.CancelledError:
94
+ pass
95
+ del self._heartbeat_task
67
96
 
68
97
  @property
69
98
  def _log_context(self) -> dict[str, str]:
@@ -72,7 +101,63 @@ class Worker:
72
101
  "stream_key": self.docket.stream_key,
73
102
  }
74
103
 
75
- async def run_until_current(self) -> None:
104
+ @classmethod
105
+ async def run(
106
+ cls,
107
+ docket_name: str = "docket",
108
+ url: str = "redis://localhost:6379/0",
109
+ name: str | None = None,
110
+ concurrency: int = 10,
111
+ redelivery_timeout: timedelta = timedelta(minutes=5),
112
+ reconnection_delay: timedelta = timedelta(seconds=5),
113
+ until_finished: bool = False,
114
+ tasks: list[str] = ["docket.tasks:standard_tasks"],
115
+ ) -> None:
116
+ async with Docket(name=docket_name, url=url) as docket:
117
+ for task_path in tasks:
118
+ docket.register_collection(task_path)
119
+
120
+ async with Worker(
121
+ docket=docket,
122
+ name=name,
123
+ concurrency=concurrency,
124
+ redelivery_timeout=redelivery_timeout,
125
+ reconnection_delay=reconnection_delay,
126
+ ) as worker:
127
+ if until_finished:
128
+ await worker.run_until_finished()
129
+ else:
130
+ await worker.run_forever() # pragma: no cover
131
+
132
+ async def run_until_finished(self) -> None:
133
+ """Run the worker until there are no more tasks to process."""
134
+ return await self._run(forever=False)
135
+
136
+ async def run_forever(self) -> None:
137
+ """Run the worker indefinitely."""
138
+ return await self._run(forever=True) # pragma: no cover
139
+
140
+ async def _run(self, forever: bool = False) -> None:
141
+ logger.info("Starting worker %r with the following tasks:", self.name)
142
+ for task_name, task in self.docket.tasks.items():
143
+ signature = inspect.signature(task)
144
+ logger.info("* %s%s", task_name, signature)
145
+
146
+ while True:
147
+ try:
148
+ return await self._worker_loop(forever=forever)
149
+ except redis.exceptions.ConnectionError:
150
+ REDIS_DISRUPTIONS.add(
151
+ 1, {"docket": self.docket.name, "worker": self.name}
152
+ )
153
+ logger.warning(
154
+ "Error connecting to redis, retrying in %s...",
155
+ self.reconnection_delay,
156
+ exc_info=True,
157
+ )
158
+ await asyncio.sleep(self.reconnection_delay.total_seconds())
159
+
160
+ async def _worker_loop(self, forever: bool = False):
76
161
  async with self.docket.redis() as redis:
77
162
  stream_due_tasks: _stream_due_tasks = cast(
78
163
  _stream_due_tasks,
@@ -119,84 +204,198 @@ class Worker:
119
204
  ),
120
205
  )
121
206
 
122
- total_work, due_work = sys.maxsize, 0
123
- while total_work:
124
- now = datetime.now(timezone.utc)
125
- total_work, due_work = await stream_due_tasks(
126
- keys=[self.docket.queue_key, self.docket.stream_key],
127
- args=[now.timestamp(), self.docket.name],
128
- )
129
- logger.info(
130
- "Moved %d/%d due tasks from %s to %s",
131
- due_work,
132
- total_work,
133
- self.docket.queue_key,
134
- self.docket.stream_key,
135
- extra=self._log_context,
136
- )
207
+ active_tasks: dict[asyncio.Task[None], RedisMessageID] = {}
137
208
 
138
- response: RedisReadGroupResponse = await redis.xreadgroup(
139
- groupname=self.consumer_group_name,
140
- consumername=self.name,
141
- streams={self.docket.stream_key: ">"},
142
- count=self.prefetch_count,
143
- block=10,
144
- )
145
- for _, messages in response:
146
- for message_id, message in messages:
147
- await self._execute(message)
148
-
149
- # When executing a task, there's always a chance that it was
150
- # either retried or it scheduled another task, so let's give
151
- # ourselves one more iteration of the loop to handle that.
152
- total_work += 1
153
-
154
- async with redis.pipeline() as pipe:
155
- pipe.xack(
156
- self.docket.stream_key,
157
- self.consumer_group_name,
158
- message_id,
159
- )
160
- pipe.xdel(
161
- self.docket.stream_key,
162
- message_id,
163
- )
164
- await pipe.execute()
209
+ async def process_completed_tasks() -> None:
210
+ completed_tasks = {task for task in active_tasks if task.done()}
211
+ for task in completed_tasks:
212
+ message_id = active_tasks.pop(task)
213
+
214
+ await task
215
+
216
+ async with redis.pipeline() as pipeline:
217
+ pipeline.xack(
218
+ self.docket.stream_key,
219
+ self.docket.worker_group_name,
220
+ message_id,
221
+ )
222
+ pipeline.xdel(
223
+ self.docket.stream_key,
224
+ message_id,
225
+ )
226
+ await pipeline.execute()
227
+
228
+ future_work, due_work = sys.maxsize, 0
229
+
230
+ try:
231
+ while forever or future_work or active_tasks:
232
+ await process_completed_tasks()
233
+
234
+ available_slots = self.concurrency - len(active_tasks)
235
+
236
+ def start_task(
237
+ message_id: RedisMessageID, message: RedisMessage
238
+ ) -> None:
239
+ task = asyncio.create_task(self._execute(message))
240
+ active_tasks[task] = message_id
241
+
242
+ nonlocal available_slots, future_work
243
+ available_slots -= 1
244
+ future_work += 1
245
+
246
+ if available_slots <= 0:
247
+ await asyncio.sleep(self.minimum_check_interval.total_seconds())
248
+ continue
249
+
250
+ future_work, due_work = await stream_due_tasks(
251
+ keys=[self.docket.queue_key, self.docket.stream_key],
252
+ args=[datetime.now(timezone.utc).timestamp(), self.docket.name],
253
+ )
254
+ if due_work > 0:
255
+ logger.debug(
256
+ "Moved %d/%d due tasks from %s to %s",
257
+ due_work,
258
+ future_work,
259
+ self.docket.queue_key,
260
+ self.docket.stream_key,
261
+ extra=self._log_context,
262
+ )
263
+
264
+ redeliveries: RedisMessages
265
+ _, redeliveries, _ = await redis.xautoclaim(
266
+ name=self.docket.stream_key,
267
+ groupname=self.docket.worker_group_name,
268
+ consumername=self.name,
269
+ min_idle_time=int(
270
+ self.redelivery_timeout.total_seconds() * 1000
271
+ ),
272
+ start_id="0-0",
273
+ count=available_slots,
274
+ )
275
+
276
+ for message_id, message in redeliveries:
277
+ start_task(message_id, message)
278
+ if available_slots <= 0:
279
+ break
280
+
281
+ if available_slots <= 0:
282
+ continue
283
+
284
+ new_deliveries: RedisReadGroupResponse = await redis.xreadgroup(
285
+ groupname=self.docket.worker_group_name,
286
+ consumername=self.name,
287
+ streams={self.docket.stream_key: ">"},
288
+ block=(
289
+ int(self.minimum_check_interval.total_seconds() * 1000)
290
+ if forever or active_tasks
291
+ else None
292
+ ),
293
+ count=available_slots,
294
+ )
295
+ for _, messages in new_deliveries:
296
+ for message_id, message in messages:
297
+ start_task(message_id, message)
298
+ if available_slots <= 0:
299
+ break
300
+ except asyncio.CancelledError:
301
+ if active_tasks: # pragma: no cover
302
+ logger.info(
303
+ "Shutdown requested, finishing %d active tasks...",
304
+ len(active_tasks),
305
+ extra=self._log_context,
306
+ )
307
+ finally:
308
+ if active_tasks:
309
+ await asyncio.gather(*active_tasks, return_exceptions=True)
310
+ await process_completed_tasks()
165
311
 
166
312
  async def _execute(self, message: RedisMessage) -> None:
167
- execution = Execution.from_message(
168
- self.docket.tasks[message[b"function"].decode()],
169
- message,
170
- )
171
-
172
- logger.info(
173
- "Executing task %s with args %s and kwargs %s",
174
- execution.key,
175
- execution.args,
176
- execution.kwargs,
177
- extra={
178
- **self._log_context,
179
- "function": execution.function.__name__,
180
- },
181
- )
313
+ function_name = message[b"function"].decode()
314
+ function = self.docket.tasks.get(function_name)
315
+ if function is None:
316
+ logger.warning(
317
+ "Task function %r not found", function_name, extra=self._log_context
318
+ )
319
+ return
320
+
321
+ execution = Execution.from_message(function, message)
322
+ name = execution.function.__name__
323
+ key = execution.key
324
+
325
+ log_context: dict[str, str | float] = {
326
+ **self._log_context,
327
+ "task": name,
328
+ "key": key,
329
+ }
330
+ counter_labels = {
331
+ "docket": self.docket.name,
332
+ "worker": self.name,
333
+ "task": name,
334
+ }
335
+
336
+ arrow = "↬" if execution.attempt > 1 else "↪"
337
+ call = execution.call_repr()
338
+
339
+ if self.docket.strike_list.is_stricken(execution):
340
+ arrow = "🗙"
341
+ logger.warning("%s %s", arrow, call, extra=log_context)
342
+ TASKS_STRICKEN.add(1, counter_labels | {"where": "worker"})
343
+ return
182
344
 
183
345
  dependencies = self._get_dependencies(execution)
184
346
 
347
+ context = propagate.extract(message, getter=message_getter)
348
+ initiating_context = trace.get_current_span(context).get_span_context()
349
+ links = [trace.Link(initiating_context)] if initiating_context.is_valid else []
350
+
351
+ start = datetime.now(timezone.utc)
352
+ punctuality = start - execution.when
353
+ log_context["punctuality"] = punctuality.total_seconds()
354
+ duration = timedelta(0)
355
+
356
+ TASKS_STARTED.add(1, counter_labels)
357
+ TASKS_RUNNING.add(1, counter_labels)
358
+ TASK_PUNCTUALITY.record(punctuality.total_seconds(), counter_labels)
359
+
360
+ logger.info("%s [%s] %s", arrow, punctuality, call, extra=log_context)
361
+
185
362
  try:
186
- await execution.function(
187
- *execution.args,
188
- **{
189
- **execution.kwargs,
190
- **dependencies,
363
+ with tracer.start_as_current_span(
364
+ execution.function.__name__,
365
+ kind=trace.SpanKind.CONSUMER,
366
+ attributes={
367
+ "docket.name": self.docket.name,
368
+ "docket.execution.when": execution.when.isoformat(),
369
+ "docket.execution.key": execution.key,
370
+ "docket.execution.attempt": execution.attempt,
371
+ "docket.execution.punctuality": punctuality.total_seconds(),
372
+ "code.function.name": execution.function.__name__,
191
373
  },
192
- )
374
+ links=links,
375
+ ):
376
+ await execution.function(
377
+ *execution.args,
378
+ **{
379
+ **execution.kwargs,
380
+ **dependencies,
381
+ },
382
+ )
383
+
384
+ TASKS_SUCCEEDED.add(1, counter_labels)
385
+ duration = datetime.now(timezone.utc) - start
386
+ log_context["duration"] = duration.total_seconds()
387
+ logger.info("%s [%s] %s", "↩", duration, call, extra=log_context)
193
388
  except Exception:
194
- logger.exception(
195
- "Error executing task %s",
196
- execution.key,
197
- extra=self._log_context,
198
- )
199
- await self._retry_if_requested(execution, dependencies)
389
+ TASKS_FAILED.add(1, counter_labels)
390
+ duration = datetime.now(timezone.utc) - start
391
+ log_context["duration"] = duration.total_seconds()
392
+ retried = await self._retry_if_requested(execution, dependencies)
393
+ arrow = "↫" if retried else "↩"
394
+ logger.exception("%s [%s] %s", arrow, duration, call, extra=log_context)
395
+ finally:
396
+ TASKS_RUNNING.add(-1, counter_labels)
397
+ TASKS_COMPLETED.add(1, counter_labels)
398
+ TASK_DURATION.record(duration.total_seconds(), counter_labels)
200
399
 
201
400
  def _get_dependencies(
202
401
  self,
@@ -208,14 +407,14 @@ class Worker:
208
407
 
209
408
  dependencies: dict[str, Any] = {}
210
409
 
211
- for param_name, dependency in parameters.items():
410
+ for parameter_name, dependency in parameters.items():
212
411
  # If the argument is already provided, skip it, which allows users to call
213
412
  # the function directly with the arguments they want.
214
- if param_name in execution.kwargs:
215
- dependencies[param_name] = execution.kwargs[param_name]
413
+ if parameter_name in execution.kwargs:
414
+ dependencies[parameter_name] = execution.kwargs[parameter_name]
216
415
  continue
217
416
 
218
- dependencies[param_name] = dependency(self.docket, self, execution)
417
+ dependencies[parameter_name] = dependency(self.docket, self, execution)
219
418
 
220
419
  return dependencies
221
420
 
@@ -223,22 +422,75 @@ class Worker:
223
422
  self,
224
423
  execution: Execution,
225
424
  dependencies: dict[str, Any],
226
- ) -> None:
425
+ ) -> bool:
227
426
  from .dependencies import Retry
228
427
 
229
428
  retries = [retry for retry in dependencies.values() if isinstance(retry, Retry)]
230
429
  if not retries:
231
- return
430
+ return False
232
431
 
233
432
  retry = retries[0]
234
433
 
235
- if execution.attempt < retry.attempts:
434
+ if retry.attempts is None or execution.attempt < retry.attempts:
236
435
  execution.when = datetime.now(timezone.utc) + retry.delay
237
436
  execution.attempt += 1
238
437
  await self.docket.schedule(execution)
239
- else:
240
- logger.error(
241
- "Task %s failed after %d attempts",
242
- execution.key,
243
- retry.attempts,
244
- )
438
+
439
+ counter_labels = {
440
+ "docket": self.docket.name,
441
+ "worker": self.name,
442
+ "task": execution.function.__name__,
443
+ }
444
+ TASKS_RETRIED.add(1, counter_labels)
445
+ return True
446
+
447
+ return False
448
+
449
+ @property
450
+ def workers_set(self) -> str:
451
+ return self.docket.workers_set
452
+
453
+ def worker_tasks_set(self, worker_name: str) -> str:
454
+ return self.docket.worker_tasks_set(worker_name)
455
+
456
+ def task_workers_set(self, task_name: str) -> str:
457
+ return self.docket.task_workers_set(task_name)
458
+
459
+ async def _heartbeat(self) -> None:
460
+ while True:
461
+ await asyncio.sleep(self.docket.heartbeat_interval.total_seconds())
462
+ try:
463
+ now = datetime.now(timezone.utc).timestamp()
464
+ maximum_age = (
465
+ self.docket.heartbeat_interval * self.docket.missed_heartbeats
466
+ )
467
+ oldest = now - maximum_age.total_seconds()
468
+
469
+ task_names = list(self.docket.tasks)
470
+
471
+ async with self.docket.redis() as r:
472
+ async with r.pipeline() as pipeline:
473
+ pipeline.zremrangebyscore(self.workers_set, 0, oldest)
474
+ pipeline.zadd(self.workers_set, {self.name: now})
475
+
476
+ for task_name in task_names:
477
+ task_workers_set = self.task_workers_set(task_name)
478
+ pipeline.zremrangebyscore(task_workers_set, 0, oldest)
479
+ pipeline.zadd(task_workers_set, {self.name: now})
480
+
481
+ pipeline.sadd(self.worker_tasks_set(self.name), *task_names)
482
+ pipeline.expire(
483
+ self.worker_tasks_set(self.name),
484
+ max(maximum_age, timedelta(seconds=1)),
485
+ )
486
+
487
+ await pipeline.execute()
488
+ except asyncio.CancelledError: # pragma: no cover
489
+ return
490
+ except redis.exceptions.ConnectionError:
491
+ REDIS_DISRUPTIONS.add(
492
+ 1, {"docket": self.docket.name, "worker": self.name}
493
+ )
494
+ logger.exception("Error sending worker heartbeat", exc_info=True)
495
+ except Exception:
496
+ logger.exception("Error sending worker heartbeat", exc_info=True)