pydocket 0.5.1__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff shows the contents of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these versions as they appear in their public registry.

Potentially problematic release.



docket/worker.py CHANGED
@@ -1,11 +1,10 @@
 import asyncio
-import inspect
 import logging
 import sys
 from datetime import datetime, timedelta, timezone
 from types import TracebackType
 from typing import (
-    TYPE_CHECKING,
+    Coroutine,
     Mapping,
     Protocol,
     Self,
@@ -13,17 +12,27 @@ from typing import (
 )
 from uuid import uuid4
 
-import redis.exceptions
-from opentelemetry import propagate, trace
+from opentelemetry import trace
 from opentelemetry.trace import Tracer
 from redis.asyncio import Redis
-
+from redis.exceptions import ConnectionError, LockError
+
+from docket.execution import get_signature
+
+from .dependencies import (
+    Dependency,
+    Perpetual,
+    Retry,
+    Timeout,
+    get_single_dependency_of_type,
+    get_single_dependency_parameter_of_type,
+    resolved_dependencies,
+)
 from .docket import (
     Docket,
     Execution,
     RedisMessage,
     RedisMessageID,
-    RedisMessages,
     RedisReadGroupResponse,
 )
 from .instrumentation import (
@@ -40,7 +49,6 @@ from .instrumentation import (
     TASKS_STARTED,
     TASKS_STRICKEN,
     TASKS_SUCCEEDED,
-    message_getter,
     metrics_server,
 )
 
@@ -48,10 +56,6 @@ logger: logging.Logger = logging.getLogger(__name__)
 tracer: Tracer = trace.get_tracer(__name__)
 
 
-if TYPE_CHECKING:  # pragma: no cover
-    from .dependencies import Dependency
-
-
 class _stream_due_tasks(Protocol):
     async def __call__(
         self, keys: list[str], args: list[str | float]
@@ -74,7 +78,7 @@ class Worker:
         concurrency: int = 10,
         redelivery_timeout: timedelta = timedelta(minutes=5),
         reconnection_delay: timedelta = timedelta(seconds=5),
-        minimum_check_interval: timedelta = timedelta(milliseconds=100),
+        minimum_check_interval: timedelta = timedelta(milliseconds=250),
         scheduling_resolution: timedelta = timedelta(milliseconds=250),
     ) -> None:
         self.docket = docket
@@ -196,13 +200,14 @@ class Worker:
     async def _run(self, forever: bool = False) -> None:
         logger.info("Starting worker %r with the following tasks:", self.name)
         for task_name, task in self.docket.tasks.items():
-            signature = inspect.signature(task)
+            signature = get_signature(task)
             logger.info("* %s%s", task_name, signature)
 
         while True:
             try:
-                return await self._worker_loop(forever=forever)
-            except redis.exceptions.ConnectionError:
+                async with self.docket.redis() as redis:
+                    return await self._worker_loop(redis, forever=forever)
+            except ConnectionError:
                 REDIS_DISRUPTIONS.add(1, self.labels())
                 logger.warning(
                     "Error connecting to redis, retrying in %s...",
@@ -211,127 +216,133 @@ class Worker:
             )
             await asyncio.sleep(self.reconnection_delay.total_seconds())
 
-    async def _worker_loop(self, forever: bool = False):
+    async def _worker_loop(self, redis: Redis, forever: bool = False):
         worker_stopping = asyncio.Event()
 
         await self._schedule_all_automatic_perpetual_tasks()
-        perpetual_scheduling_task = asyncio.create_task(
-            self._perpetual_scheduling_loop(worker_stopping)
+
+        scheduler_task = asyncio.create_task(
+            self._scheduler_loop(redis, worker_stopping)
         )
 
-        async with self.docket.redis() as redis:
-            scheduler_task = asyncio.create_task(
-                self._scheduler_loop(redis, worker_stopping)
+        active_tasks: dict[asyncio.Task[None], RedisMessageID] = {}
+        available_slots = self.concurrency
+
+        async def check_for_work() -> bool:
+            logger.debug("Checking for work", extra=self._log_context())
+            async with redis.pipeline() as pipeline:
+                pipeline.xlen(self.docket.stream_key)
+                pipeline.zcard(self.docket.queue_key)
+                results: list[int] = await pipeline.execute()
+                stream_len = results[0]
+                queue_len = results[1]
+                return stream_len > 0 or queue_len > 0
+
+        async def get_redeliveries(redis: Redis) -> RedisReadGroupResponse:
+            logger.debug("Getting redeliveries", extra=self._log_context())
+            _, redeliveries, *_ = await redis.xautoclaim(
+                name=self.docket.stream_key,
+                groupname=self.docket.worker_group_name,
+                consumername=self.name,
+                min_idle_time=int(self.redelivery_timeout.total_seconds() * 1000),
+                start_id="0-0",
+                count=available_slots,
+            )
+            return [(b"__redelivery__", redeliveries)]
+
+        async def get_new_deliveries(redis: Redis) -> RedisReadGroupResponse:
+            logger.debug("Getting new deliveries", extra=self._log_context())
+            return await redis.xreadgroup(
+                groupname=self.docket.worker_group_name,
+                consumername=self.name,
+                streams={self.docket.stream_key: ">"},
+                block=int(self.minimum_check_interval.total_seconds() * 1000),
+                count=available_slots,
             )
-            active_tasks: dict[asyncio.Task[None], RedisMessageID] = {}
-
-            async def check_for_work() -> bool:
-                async with redis.pipeline() as pipeline:
-                    pipeline.xlen(self.docket.stream_key)
-                    pipeline.zcard(self.docket.queue_key)
-                    results: list[int] = await pipeline.execute()
-                    stream_len = results[0]
-                    queue_len = results[1]
-                    return stream_len > 0 or queue_len > 0
-
-            async def process_completed_tasks() -> None:
-                completed_tasks = {task for task in active_tasks if task.done()}
-                for task in completed_tasks:
-                    message_id = active_tasks.pop(task)
-
-                    await task
-
-                    async with redis.pipeline() as pipeline:
-                        pipeline.xack(
-                            self.docket.stream_key,
-                            self.docket.worker_group_name,
-                            message_id,
-                        )
-                        pipeline.xdel(
-                            self.docket.stream_key,
-                            message_id,
-                        )
-                        await pipeline.execute()
 
-            has_work: bool = True
+        def start_task(message_id: RedisMessageID, message: RedisMessage) -> bool:
+            if not message:  # pragma: no cover
+                return False
 
-            if not forever:  # pragma: no branch
-                has_work = await check_for_work()
+            function_name = message[b"function"].decode()
+            if not (function := self.docket.tasks.get(function_name)):
+                logger.warning(
+                    "Task function %r not found",
+                    function_name,
+                    extra=self._log_context(),
+                )
+                return False
 
-            try:
-                while forever or has_work or active_tasks:
-                    await process_completed_tasks()
+            execution = Execution.from_message(function, message)
 
-                    available_slots = self.concurrency - len(active_tasks)
+            task = asyncio.create_task(self._execute(execution))
+            active_tasks[task] = message_id
 
-                    def start_task(
-                        message_id: RedisMessageID, message: RedisMessage
-                    ) -> None:
-                        if not message:  # pragma: no cover
-                            return
+            nonlocal available_slots
+            available_slots -= 1
 
-                        task = asyncio.create_task(self._execute(message))
-                        active_tasks[task] = message_id
+            return True
 
-                        nonlocal available_slots
-                        available_slots -= 1
+        async def ack_message(redis: Redis, message_id: RedisMessageID) -> None:
+            logger.debug("Acknowledging message", extra=self._log_context())
+            async with redis.pipeline() as pipeline:
+                pipeline.xack(
+                    self.docket.stream_key,
+                    self.docket.worker_group_name,
+                    message_id,
+                )
+                pipeline.xdel(
+                    self.docket.stream_key,
+                    message_id,
+                )
+                await pipeline.execute()
 
-                    if available_slots <= 0:
-                        await asyncio.sleep(self.minimum_check_interval.total_seconds())
-                        continue
-
-                    redeliveries: RedisMessages
-                    _, redeliveries, *_ = await redis.xautoclaim(
-                        name=self.docket.stream_key,
-                        groupname=self.docket.worker_group_name,
-                        consumername=self.name,
-                        min_idle_time=int(
-                            self.redelivery_timeout.total_seconds() * 1000
-                        ),
-                        start_id="0-0",
-                        count=available_slots,
-                    )
+        async def process_completed_tasks() -> None:
+            completed_tasks = {task for task in active_tasks if task.done()}
+            for task in completed_tasks:
+                message_id = active_tasks.pop(task)
+                await task
+                await ack_message(redis, message_id)
 
-                    for message_id, message in redeliveries:
-                        start_task(message_id, message)
+        has_work: bool = True
 
-                    if available_slots <= 0:
-                        continue
-
-                    new_deliveries: RedisReadGroupResponse = await redis.xreadgroup(
-                        groupname=self.docket.worker_group_name,
-                        consumername=self.name,
-                        streams={self.docket.stream_key: ">"},
-                        block=(
-                            int(self.minimum_check_interval.total_seconds() * 1000)
-                            if forever or active_tasks
-                            else None
-                        ),
-                        count=available_slots,
-                    )
+        try:
+            while forever or has_work or active_tasks:
+                await process_completed_tasks()
+
+                available_slots = self.concurrency - len(active_tasks)
+
+                if available_slots <= 0:
+                    await asyncio.sleep(self.minimum_check_interval.total_seconds())
+                    continue
 
-                    for _, messages in new_deliveries:
+                for source in [get_redeliveries, get_new_deliveries]:
+                    for _, messages in await source(redis):
                         for message_id, message in messages:
-                            start_task(message_id, message)
+                            if not start_task(message_id, message):
+                                await self._delete_known_task(redis, message)
+                                await ack_message(redis, message_id)
 
-                    if not forever and not active_tasks and not new_deliveries:
-                        has_work = await check_for_work()
+                    if available_slots <= 0:
+                        break
 
-            except asyncio.CancelledError:
-                if active_tasks:  # pragma: no cover
-                    logger.info(
-                        "Shutdown requested, finishing %d active tasks...",
-                        len(active_tasks),
-                        extra=self._log_context(),
-                    )
-            finally:
-                if active_tasks:
-                    await asyncio.gather(*active_tasks, return_exceptions=True)
-                    await process_completed_tasks()
+                if not forever and not active_tasks:
+                    has_work = await check_for_work()
 
-        worker_stopping.set()
-        await scheduler_task
-        await perpetual_scheduling_task
+        except asyncio.CancelledError:
+            if active_tasks:  # pragma: no cover
+                logger.info(
+                    "Shutdown requested, finishing %d active tasks...",
+                    len(active_tasks),
+                    extra=self._log_context(),
+                )
+        finally:
+            if active_tasks:
+                await asyncio.gather(*active_tasks, return_exceptions=True)
+                await process_completed_tasks()
+
+            worker_stopping.set()
+            await scheduler_task
 
     async def _scheduler_loop(
         self,
@@ -392,6 +403,7 @@ class Worker:
 
         while not worker_stopping.is_set() or total_work:
             try:
+                logger.debug("Scheduling due tasks", extra=self._log_context())
                 total_work, due_work = await stream_due_tasks(
                     keys=[self.docket.queue_key, self.docket.stream_key],
                     args=[datetime.now(timezone.utc).timestamp(), self.docket.name],
@@ -417,73 +429,53 @@
 
         logger.debug("Scheduler loop finished", extra=self._log_context())
 
-    async def _perpetual_scheduling_loop(self, worker_stopping: asyncio.Event) -> None:
-        """Loop that ensures that automatic perpetual tasks are always scheduled."""
-
-        while not worker_stopping.is_set():
-            minimum_interval = self.scheduling_resolution
+    async def _schedule_all_automatic_perpetual_tasks(self) -> None:
+        async with self.docket.redis() as redis:
             try:
-                minimum_interval = await self._schedule_all_automatic_perpetual_tasks()
-            except Exception:  # pragma: no cover
-                logger.exception(
-                    "Error in perpetual scheduling loop",
-                    exc_info=True,
-                    extra=self._log_context(),
-                )
-            finally:
-                # Wait until just before the next time any task would need to be
-                # scheduled (one scheduling_resolution before the lowest interval)
-                interval = max(
-                    minimum_interval - self.scheduling_resolution,
-                    self.scheduling_resolution,
-                )
-                assert interval <= self.scheduling_resolution
-                await asyncio.sleep(interval.total_seconds())
-
-    async def _schedule_all_automatic_perpetual_tasks(self) -> timedelta:
-        from .dependencies import Perpetual, get_single_dependency_parameter_of_type
-
-        minimum_interval = self.scheduling_resolution
-        for task_function in self.docket.tasks.values():
-            perpetual = get_single_dependency_parameter_of_type(
-                task_function, Perpetual
-            )
-            if perpetual is None:
-                continue
-
-            if not perpetual.automatic:
-                continue
-
-            key = task_function.__name__
-            await self.docket.add(task_function, key=key)()
-            minimum_interval = min(minimum_interval, perpetual.every)
+                async with redis.lock(
+                    f"{self.docket.name}:perpetual:lock", timeout=10, blocking=False
+                ):
+                    for task_function in self.docket.tasks.values():
+                        perpetual = get_single_dependency_parameter_of_type(
+                            task_function, Perpetual
+                        )
+                        if perpetual is None:
+                            continue
 
-        return minimum_interval
+                        if not perpetual.automatic:
+                            continue
 
-    async def _execute(self, message: RedisMessage) -> None:
-        key = message[b"key"].decode()
-        async with self.docket.redis() as redis:
-            await redis.delete(self.docket.known_task_key(key))
+                        key = task_function.__name__
 
-        log_context: Mapping[str, str | float] = self._log_context()
+                        await self.docket.add(task_function, key=key)()
+            except LockError:  # pragma: no cover
+                return
 
-        function_name = message[b"function"].decode()
-        function = self.docket.tasks.get(function_name)
-        if function is None:
-            logger.warning(
-                "Task function %r not found", function_name, extra=log_context
-            )
+    async def _delete_known_task(
+        self, redis: Redis, execution_or_message: Execution | RedisMessage
+    ) -> None:
+        if isinstance(execution_or_message, Execution):
+            key = execution_or_message.key
+        elif bytes_key := execution_or_message.get(b"key"):
+            key = bytes_key.decode()
+        else:  # pragma: no cover
             return
 
-        execution = Execution.from_message(function, message)
+        logger.debug("Deleting known task", extra=self._log_context())
+        known_task_key = self.docket.known_task_key(key)
+        await redis.delete(known_task_key)
 
-        log_context = {**log_context, **execution.specific_labels()}
+    async def _execute(self, execution: Execution) -> None:
+        log_context = {**self._log_context(), **execution.specific_labels()}
         counter_labels = {**self.labels(), **execution.general_labels()}
 
         arrow = "↬" if execution.attempt > 1 else "↪"
         call = execution.call_repr()
 
         if self.docket.strike_list.is_stricken(execution):
+            async with self.docket.redis() as redis:
+                await self._delete_known_task(redis, execution)
+
             arrow = "🗙"
             logger.warning("%s %s", arrow, call, extra=log_context)
             TASKS_STRICKEN.add(1, counter_labels | {"docket.where": "worker"})
@@ -492,15 +484,16 @@
         if execution.key in self._execution_counts:
             self._execution_counts[execution.key] += 1
 
-        dependencies = self._get_dependencies(execution)
-
-        context = propagate.extract(message, getter=message_getter)
-        initiating_context = trace.get_current_span(context).get_span_context()
+        initiating_span = trace.get_current_span(execution.trace_context)
+        initiating_context = initiating_span.get_span_context()
         links = [trace.Link(initiating_context)] if initiating_context.is_valid else []
 
         start = datetime.now(timezone.utc)
         punctuality = start - execution.when
-        log_context = {**log_context, "punctuality": punctuality.total_seconds()}
+        log_context = {
+            **log_context,
+            "punctuality": punctuality.total_seconds(),
+        }
         duration = timedelta(0)
 
         TASKS_STARTED.add(1, counter_labels)
@@ -509,77 +502,103 @@
 
         logger.info("%s [%s] %s", arrow, punctuality, call, extra=log_context)
 
-        try:
-            with tracer.start_as_current_span(
-                execution.function.__name__,
-                kind=trace.SpanKind.CONSUMER,
-                attributes={
-                    **self.labels(),
-                    **execution.specific_labels(),
-                    "code.function.name": execution.function.__name__,
-                },
-                links=links,
-            ):
-                await execution.function(
-                    *execution.args,
-                    **{
-                        **execution.kwargs,
-                        **dependencies,
-                    },
+        with tracer.start_as_current_span(
+            execution.function.__name__,
+            kind=trace.SpanKind.CONSUMER,
+            attributes={
+                **self.labels(),
+                **execution.specific_labels(),
+                "code.function.name": execution.function.__name__,
+            },
+            links=links,
+        ):
+            async with resolved_dependencies(self, execution) as dependencies:
+                # Preemptively reschedule the perpetual task for the future, or clear
+                # the known task key for this task
+                rescheduled = await self._perpetuate_if_requested(
+                    execution, dependencies
                 )
+                if not rescheduled:
+                    async with self.docket.redis() as redis:
+                        await self._delete_known_task(redis, execution)
+
+                try:
+                    if timeout := get_single_dependency_of_type(dependencies, Timeout):
+                        await self._run_function_with_timeout(
+                            execution, dependencies, timeout
+                        )
+                    else:
+                        await execution.function(
+                            *execution.args,
+                            **{
+                                **execution.kwargs,
+                                **dependencies,
+                            },
+                        )
 
-            TASKS_SUCCEEDED.add(1, counter_labels)
-            duration = datetime.now(timezone.utc) - start
-            log_context["duration"] = duration.total_seconds()
-            rescheduled = await self._perpetuate_if_requested(
-                execution, dependencies, duration
-            )
-            arrow = "↫" if rescheduled else "↩"
-            logger.info("%s [%s] %s", arrow, duration, call, extra=log_context)
-        except Exception:
-            TASKS_FAILED.add(1, counter_labels)
-            duration = datetime.now(timezone.utc) - start
-            log_context["duration"] = duration.total_seconds()
-            retried = await self._retry_if_requested(execution, dependencies)
-            if not retried:
-                retried = await self._perpetuate_if_requested(
-                    execution, dependencies, duration
-                )
-            arrow = "↫" if retried else "↩"
-            logger.exception("%s [%s] %s", arrow, duration, call, extra=log_context)
-        finally:
-            TASKS_RUNNING.add(-1, counter_labels)
-            TASKS_COMPLETED.add(1, counter_labels)
-            TASK_DURATION.record(duration.total_seconds(), counter_labels)
+                    TASKS_SUCCEEDED.add(1, counter_labels)
+                    duration = datetime.now(timezone.utc) - start
+                    log_context["duration"] = duration.total_seconds()
+                    rescheduled = await self._perpetuate_if_requested(
+                        execution, dependencies, duration
+                    )
+                    arrow = "↫" if rescheduled else "↩"
+                    logger.info("%s [%s] %s", arrow, duration, call, extra=log_context)
+                except Exception:
+                    TASKS_FAILED.add(1, counter_labels)
+                    duration = datetime.now(timezone.utc) - start
+                    log_context["duration"] = duration.total_seconds()
+                    retried = await self._retry_if_requested(execution, dependencies)
+                    if not retried:
+                        retried = await self._perpetuate_if_requested(
+                            execution, dependencies, duration
+                        )
+                    arrow = "↫" if retried else "↩"
+                    logger.exception(
+                        "%s [%s] %s", arrow, duration, call, extra=log_context
+                    )
+                finally:
+                    TASKS_RUNNING.add(-1, counter_labels)
+                    TASKS_COMPLETED.add(1, counter_labels)
+                    TASK_DURATION.record(duration.total_seconds(), counter_labels)
 
-    def _get_dependencies(
+    async def _run_function_with_timeout(
         self,
         execution: Execution,
-    ) -> dict[str, "Dependency"]:
-        from .dependencies import get_dependency_parameters
-
-        parameters = get_dependency_parameters(execution.function)
-
-        dependencies: dict[str, "Dependency"] = {}
-
-        for parameter_name, dependency in parameters.items():
-            # If the argument is already provided, skip it, which allows users to call
-            # the function directly with the arguments they want.
-            if parameter_name in execution.kwargs:
-                dependencies[parameter_name] = execution.kwargs[parameter_name]
-                continue
-
-            dependencies[parameter_name] = dependency(self.docket, self, execution)
+        dependencies: dict[str, Dependency],
+        timeout: Timeout,
+    ) -> None:
+        task_coro = cast(
+            Coroutine[None, None, None],
+            execution.function(*execution.args, **execution.kwargs, **dependencies),
+        )
+        task = asyncio.create_task(task_coro)
+        try:
+            while not task.done():  # pragma: no branch
+                remaining = timeout.remaining().total_seconds()
+                if timeout.expired():
+                    task.cancel()
+                    break
+
+                try:
+                    await asyncio.wait_for(asyncio.shield(task), timeout=remaining)
+                    return
+                except asyncio.TimeoutError:
+                    continue
+        finally:
+            if not task.done():
+                task.cancel()
 
-        return dependencies
+        try:
+            await task
+        except asyncio.CancelledError:
+            raise asyncio.TimeoutError
 
     async def _retry_if_requested(
         self,
         execution: Execution,
-        dependencies: dict[str, "Dependency"],
+        dependencies: dict[str, Dependency],
     ) -> bool:
-        from .dependencies import Retry, get_single_dependency_of_type
-
         retry = get_single_dependency_of_type(dependencies, Retry)
         if not retry:
             return False
@@ -597,26 +616,28 @@ class Worker:
     async def _perpetuate_if_requested(
         self,
         execution: Execution,
-        dependencies: dict[str, "Dependency"],
-        duration: timedelta,
+        dependencies: dict[str, Dependency],
+        duration: timedelta | None = None,
     ) -> bool:
-        from .dependencies import Perpetual, get_single_dependency_of_type
-
         perpetual = get_single_dependency_of_type(dependencies, Perpetual)
         if not perpetual:
             return False
 
         if perpetual.cancelled:
+            await self.docket.cancel(execution.key)
             return False
 
         now = datetime.now(timezone.utc)
-        execution.when = max(now, now + perpetual.every - duration)
-        execution.args = perpetual.args
-        execution.kwargs = perpetual.kwargs
+        when = max(now, now + perpetual.every - (duration or timedelta(0)))
+
+        await self.docket.replace(execution.function, when, execution.key)(
+            *perpetual.args,
+            **perpetual.kwargs,
+        )
 
-        await self.docket.schedule(execution)
+        if duration is not None:
+            TASKS_PERPETUATED.add(1, {**self.labels(), **execution.specific_labels()})
 
-        TASKS_PERPETUATED.add(1, {**self.labels(), **execution.specific_labels()})
         return True
 
     @property
@@ -676,7 +697,7 @@ class Worker:
 
         except asyncio.CancelledError:  # pragma: no cover
             return
-        except redis.exceptions.ConnectionError:
+        except ConnectionError:
            REDIS_DISRUPTIONS.add(1, self.labels())
            logger.exception(
                "Error sending worker heartbeat",