pydocket 0.7.1__py3-none-any.whl → 0.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pydocket might be problematic. Click here for more details.

docket/__init__.py CHANGED
@@ -10,6 +10,7 @@ __version__ = version("pydocket")
10
10
 
11
11
  from .annotations import Logged
12
12
  from .dependencies import (
13
+ ConcurrencyLimit,
13
14
  CurrentDocket,
14
15
  CurrentExecution,
15
16
  CurrentWorker,
@@ -28,6 +29,7 @@ from .worker import Worker
28
29
 
29
30
  __all__ = [
30
31
  "__version__",
32
+ "ConcurrencyLimit",
31
33
  "CurrentDocket",
32
34
  "CurrentExecution",
33
35
  "CurrentWorker",
docket/cli.py CHANGED
@@ -594,6 +594,73 @@ def relative_time(now: datetime, when: datetime) -> str:
594
594
  return f"at {local_time(when)}"
595
595
 
596
596
 
597
+ def get_task_stats(
598
+ snapshot: DocketSnapshot,
599
+ ) -> dict[str, dict[str, int | datetime | None]]:
600
+ """Get task count statistics by function name with timestamp data."""
601
+ stats: dict[str, dict[str, int | datetime | None]] = {}
602
+
603
+ # Count running tasks by function
604
+ for execution in snapshot.running:
605
+ func_name = execution.function.__name__
606
+ if func_name not in stats:
607
+ stats[func_name] = {
608
+ "running": 0,
609
+ "queued": 0,
610
+ "total": 0,
611
+ "oldest_queued": None,
612
+ "latest_queued": None,
613
+ "oldest_started": None,
614
+ "latest_started": None,
615
+ }
616
+ stats[func_name]["running"] += 1
617
+ stats[func_name]["total"] += 1
618
+
619
+ # Track oldest/latest started times for running tasks
620
+ started = execution.started
621
+ if (
622
+ stats[func_name]["oldest_started"] is None
623
+ or started < stats[func_name]["oldest_started"]
624
+ ):
625
+ stats[func_name]["oldest_started"] = started
626
+ if (
627
+ stats[func_name]["latest_started"] is None
628
+ or started > stats[func_name]["latest_started"]
629
+ ):
630
+ stats[func_name]["latest_started"] = started
631
+
632
+ # Count future tasks by function
633
+ for execution in snapshot.future:
634
+ func_name = execution.function.__name__
635
+ if func_name not in stats:
636
+ stats[func_name] = {
637
+ "running": 0,
638
+ "queued": 0,
639
+ "total": 0,
640
+ "oldest_queued": None,
641
+ "latest_queued": None,
642
+ "oldest_started": None,
643
+ "latest_started": None,
644
+ }
645
+ stats[func_name]["queued"] += 1
646
+ stats[func_name]["total"] += 1
647
+
648
+ # Track oldest/latest queued times for future tasks
649
+ when = execution.when
650
+ if (
651
+ stats[func_name]["oldest_queued"] is None
652
+ or when < stats[func_name]["oldest_queued"]
653
+ ):
654
+ stats[func_name]["oldest_queued"] = when
655
+ if (
656
+ stats[func_name]["latest_queued"] is None
657
+ or when > stats[func_name]["latest_queued"]
658
+ ):
659
+ stats[func_name]["latest_queued"] = when
660
+
661
+ return stats
662
+
663
+
597
664
  @app.command(help="Shows a snapshot of what's on the docket right now")
598
665
  def snapshot(
599
666
  tasks: Annotated[
@@ -623,6 +690,13 @@ def snapshot(
623
690
  envvar="DOCKET_URL",
624
691
  ),
625
692
  ] = "redis://localhost:6379/0",
693
+ stats: Annotated[
694
+ bool,
695
+ typer.Option(
696
+ "--stats",
697
+ help="Show task count statistics by function name",
698
+ ),
699
+ ] = False,
626
700
  ) -> None:
627
701
  async def run() -> DocketSnapshot:
628
702
  async with Docket(name=docket_, url=url) as docket:
@@ -672,6 +746,44 @@ def snapshot(
672
746
 
673
747
  console.print(table)
674
748
 
749
+ # Display task statistics if requested
750
+ if stats:
751
+ task_stats = get_task_stats(snapshot)
752
+ if task_stats:
753
+ console.print() # Add spacing between tables
754
+ stats_table = Table(title="Task Count Statistics by Function")
755
+ stats_table.add_column("Function", style="cyan")
756
+ stats_table.add_column("Total", style="bold magenta", justify="right")
757
+ stats_table.add_column("Running", style="green", justify="right")
758
+ stats_table.add_column("Queued", style="yellow", justify="right")
759
+ stats_table.add_column("Oldest Queued", style="dim yellow", justify="right")
760
+ stats_table.add_column("Latest Queued", style="dim yellow", justify="right")
761
+
762
+ # Sort by total count descending to highlight potential runaway tasks
763
+ for func_name in sorted(
764
+ task_stats.keys(), key=lambda x: task_stats[x]["total"], reverse=True
765
+ ):
766
+ counts = task_stats[func_name]
767
+
768
+ # Format timestamp columns
769
+ oldest_queued = ""
770
+ latest_queued = ""
771
+ if counts["oldest_queued"] is not None:
772
+ oldest_queued = relative(counts["oldest_queued"])
773
+ if counts["latest_queued"] is not None:
774
+ latest_queued = relative(counts["latest_queued"])
775
+
776
+ stats_table.add_row(
777
+ func_name,
778
+ str(counts["total"]),
779
+ str(counts["running"]),
780
+ str(counts["queued"]),
781
+ oldest_queued,
782
+ latest_queued,
783
+ )
784
+
785
+ console.print(stats_table)
786
+
675
787
 
676
788
  workers_app: typer.Typer = typer.Typer(
677
789
  help="Look at the workers on a docket", no_args_is_help=True
docket/dependencies.py CHANGED
@@ -39,9 +39,9 @@ class Dependency(abc.ABC):
39
39
 
40
40
  async def __aexit__(
41
41
  self,
42
- exc_type: type[BaseException] | None,
43
- exc_value: BaseException | None,
44
- traceback: TracebackType | None,
42
+ _exc_type: type[BaseException] | None,
43
+ _exc_value: BaseException | None,
44
+ _traceback: TracebackType | None,
45
45
  ) -> bool: ... # pragma: no cover
46
46
 
47
47
 
@@ -505,6 +505,92 @@ def Depends(dependency: DependencyFunction[R]) -> R:
505
505
  return cast(R, _Depends(dependency))
506
506
 
507
507
 
508
class ConcurrencyLimit(Dependency):
    """Configures concurrency limits for a task based on specific argument values.

    This allows fine-grained control over task execution by limiting concurrent
    tasks based on the value of specific arguments.

    Example:

    ```python
    async def process_customer(
        customer_id: int,
        concurrency: ConcurrencyLimit = ConcurrencyLimit("customer_id", max_concurrent=1)
    ) -> None:
        # Only one task per customer_id will run at a time
        ...

    async def backup_db(
        db_name: str,
        concurrency: ConcurrencyLimit = ConcurrencyLimit("db_name", max_concurrent=3)
    ) -> None:
        # Only 3 backup tasks per database name will run at a time
        ...
    ```
    """

    single: bool = True

    def __init__(
        self, argument_name: str, max_concurrent: int = 1, scope: str | None = None
    ) -> None:
        """
        Args:
            argument_name: The name of the task argument to use for concurrency grouping
            max_concurrent: Maximum number of concurrent tasks per unique argument value
            scope: Optional scope prefix for Redis keys (defaults to docket name)
        """
        self.argument_name = argument_name
        self.max_concurrent = max_concurrent
        self.scope = scope
        self._concurrency_key: str | None = None
        self._initialized: bool = False

    async def __aenter__(self) -> "ConcurrencyLimit":
        """Resolve this declaration into a per-execution limit.

        Returns:
            A new, initialized ``ConcurrencyLimit`` carrying the concurrency
            key for the current execution's argument value, or a bypassed
            limit (key ``None``) when the execution has no argument named
            ``argument_name``.
        """
        execution = self.execution.get()
        docket = self.docket.get()

        # Build the result on a fresh instance and leave ``self`` untouched:
        # ``self`` is the object written as the parameter default (see the
        # class example), so per-execution state must not be stored on it.
        limit = ConcurrencyLimit(self.argument_name, self.max_concurrent, self.scope)

        try:
            # Get the argument value to group by
            argument_value = execution.get_argument(self.argument_name)
        except KeyError:
            # Argument not found: bypass concurrency control for this execution
            limit._concurrency_key = None
        else:
            # Create a concurrency key for this specific argument value
            scope = self.scope or docket.name
            limit._concurrency_key = (
                f"{scope}:concurrency:{self.argument_name}:{argument_value}"
            )

        limit._initialized = True
        return limit

    @property
    def concurrency_key(self) -> str | None:
        """Redis key used for tracking concurrency for this specific argument value.
        Returns None when concurrency control is bypassed due to missing arguments.
        Raises RuntimeError if accessed before initialization."""
        if not self._initialized:
            raise RuntimeError(
                "ConcurrencyLimit not initialized - use within task context"
            )
        return self._concurrency_key

    @property
    def is_bypassed(self) -> bool:
        """Returns True if concurrency control is bypassed due to missing arguments."""
        return self._initialized and self._concurrency_key is None
592
+
593
+
508
594
  D = TypeVar("D", bound=Dependency)
509
595
 
510
596
 
docket/docket.py CHANGED
@@ -16,6 +16,7 @@ from typing import (
16
16
  Mapping,
17
17
  NoReturn,
18
18
  ParamSpec,
19
+ Protocol,
19
20
  Self,
20
21
  Sequence,
21
22
  TypedDict,
@@ -27,7 +28,6 @@ from typing import (
27
28
  import redis.exceptions
28
29
  from opentelemetry import propagate, trace
29
30
  from redis.asyncio import ConnectionPool, Redis
30
- from redis.asyncio.client import Pipeline
31
31
  from uuid_extensions import uuid7
32
32
 
33
33
  from .execution import (
@@ -55,6 +55,18 @@ logger: logging.Logger = logging.getLogger(__name__)
55
55
  tracer: trace.Tracer = trace.get_tracer(__name__)
56
56
 
57
57
 
58
+ class _schedule_task(Protocol):
59
+ async def __call__(
60
+ self, keys: list[str], args: list[str | float | bytes]
61
+ ) -> str: ... # pragma: no cover
62
+
63
+
64
+ class _cancel_task(Protocol):
65
+ async def __call__(
66
+ self, keys: list[str], args: list[str]
67
+ ) -> str: ... # pragma: no cover
68
+
69
+
58
70
  P = ParamSpec("P")
59
71
  R = TypeVar("R")
60
72
 
@@ -131,6 +143,8 @@ class Docket:
131
143
 
132
144
  _monitor_strikes_task: asyncio.Task[None]
133
145
  _connection_pool: ConnectionPool
146
+ _schedule_task_script: _schedule_task | None
147
+ _cancel_task_script: _cancel_task | None
134
148
 
135
149
  def __init__(
136
150
  self,
@@ -156,6 +170,8 @@ class Docket:
156
170
  self.url = url
157
171
  self.heartbeat_interval = heartbeat_interval
158
172
  self.missed_heartbeats = missed_heartbeats
173
+ self._schedule_task_script = None
174
+ self._cancel_task_script = None
159
175
 
160
176
  @property
161
177
  def worker_group_name(self) -> str:
@@ -300,9 +316,7 @@ class Docket:
300
316
  execution = Execution(function, args, kwargs, when, key, attempt=1)
301
317
 
302
318
  async with self.redis() as redis:
303
- async with redis.pipeline() as pipeline:
304
- await self._schedule(redis, pipeline, execution, replace=False)
305
- await pipeline.execute()
319
+ await self._schedule(redis, execution, replace=False)
306
320
 
307
321
  TASKS_ADDED.add(1, {**self.labels(), **execution.general_labels()})
308
322
  TASKS_SCHEDULED.add(1, {**self.labels(), **execution.general_labels()})
@@ -361,9 +375,7 @@ class Docket:
361
375
  execution = Execution(function, args, kwargs, when, key, attempt=1)
362
376
 
363
377
  async with self.redis() as redis:
364
- async with redis.pipeline() as pipeline:
365
- await self._schedule(redis, pipeline, execution, replace=True)
366
- await pipeline.execute()
378
+ await self._schedule(redis, execution, replace=True)
367
379
 
368
380
  TASKS_REPLACED.add(1, {**self.labels(), **execution.general_labels()})
369
381
  TASKS_CANCELLED.add(1, {**self.labels(), **execution.general_labels()})
@@ -383,9 +395,7 @@ class Docket:
383
395
  },
384
396
  ):
385
397
  async with self.redis() as redis:
386
- async with redis.pipeline() as pipeline:
387
- await self._schedule(redis, pipeline, execution, replace=False)
388
- await pipeline.execute()
398
+ await self._schedule(redis, execution, replace=False)
389
399
 
390
400
  TASKS_SCHEDULED.add(1, {**self.labels(), **execution.general_labels()})
391
401
 
@@ -400,9 +410,7 @@ class Docket:
400
410
  attributes={**self.labels(), "docket.key": key},
401
411
  ):
402
412
  async with self.redis() as redis:
403
- async with redis.pipeline() as pipeline:
404
- await self._cancel(pipeline, key)
405
- await pipeline.execute()
413
+ await self._cancel(redis, key)
406
414
 
407
415
  TASKS_CANCELLED.add(1, self.labels())
408
416
 
@@ -423,10 +431,17 @@ class Docket:
423
431
  async def _schedule(
424
432
  self,
425
433
  redis: Redis,
426
- pipeline: Pipeline,
427
434
  execution: Execution,
428
435
  replace: bool = False,
429
436
  ) -> None:
437
+ """Schedule a task atomically.
438
+
439
+ Handles:
440
+ - Checking for task existence
441
+ - Cancelling existing tasks when replacing
442
+ - Adding tasks to stream (immediate) or queue (future)
443
+ - Tracking stream message IDs for later cancellation
444
+ """
430
445
  if self.strike_list.is_stricken(execution):
431
446
  logger.warning(
432
447
  "%r is stricken, skipping schedule of %r",
@@ -449,32 +464,133 @@ class Docket:
449
464
  key = execution.key
450
465
  when = execution.when
451
466
  known_task_key = self.known_task_key(key)
467
+ is_immediate = when <= datetime.now(timezone.utc)
452
468
 
469
+ # Lock per task key to prevent race conditions between concurrent operations
453
470
  async with redis.lock(f"{known_task_key}:lock", timeout=10):
454
- if replace:
455
- await self._cancel(pipeline, key)
456
- else:
457
- # if the task is already in the queue or stream, retain it
458
- if await redis.exists(known_task_key):
459
- logger.debug(
460
- "Task %r is already in the queue or stream, not scheduling",
461
- key,
462
- extra=self.labels(),
463
- )
464
- return
471
+ if self._schedule_task_script is None:
472
+ self._schedule_task_script = cast(
473
+ _schedule_task,
474
+ redis.register_script(
475
+ # KEYS: stream_key, known_key, parked_key, queue_key
476
+ # ARGV: task_key, when_timestamp, is_immediate, replace, ...message_fields
477
+ """
478
+ local stream_key = KEYS[1]
479
+ local known_key = KEYS[2]
480
+ local parked_key = KEYS[3]
481
+ local queue_key = KEYS[4]
482
+
483
+ local task_key = ARGV[1]
484
+ local when_timestamp = ARGV[2]
485
+ local is_immediate = ARGV[3] == '1'
486
+ local replace = ARGV[4] == '1'
487
+
488
+ -- Extract message fields from ARGV[5] onwards
489
+ local message = {}
490
+ for i = 5, #ARGV, 2 do
491
+ message[#message + 1] = ARGV[i] -- field name
492
+ message[#message + 1] = ARGV[i + 1] -- field value
493
+ end
494
+
495
+ -- Handle replacement: cancel existing task if needed
496
+ if replace then
497
+ local existing_message_id = redis.call('HGET', known_key, 'stream_message_id')
498
+ if existing_message_id then
499
+ redis.call('XDEL', stream_key, existing_message_id)
500
+ end
501
+ redis.call('DEL', known_key, parked_key)
502
+ redis.call('ZREM', queue_key, task_key)
503
+ else
504
+ -- Check if task already exists
505
+ if redis.call('EXISTS', known_key) == 1 then
506
+ return 'EXISTS'
507
+ end
508
+ end
509
+
510
+ if is_immediate then
511
+ -- Add to stream and store message ID for later cancellation
512
+ local message_id = redis.call('XADD', stream_key, '*', unpack(message))
513
+ redis.call('HSET', known_key, 'when', when_timestamp, 'stream_message_id', message_id)
514
+ return message_id
515
+ else
516
+ -- Add to queue with task data in parked hash
517
+ redis.call('HSET', known_key, 'when', when_timestamp)
518
+ redis.call('HSET', parked_key, unpack(message))
519
+ redis.call('ZADD', queue_key, when_timestamp, task_key)
520
+ return 'QUEUED'
521
+ end
522
+ """
523
+ ),
524
+ )
525
+ schedule_task = self._schedule_task_script
465
526
 
466
- pipeline.set(known_task_key, when.timestamp())
527
+ await schedule_task(
528
+ keys=[
529
+ self.stream_key,
530
+ known_task_key,
531
+ self.parked_task_key(key),
532
+ self.queue_key,
533
+ ],
534
+ args=[
535
+ key,
536
+ str(when.timestamp()),
537
+ "1" if is_immediate else "0",
538
+ "1" if replace else "0",
539
+ *[
540
+ item
541
+ for field, value in message.items()
542
+ for item in (field, value)
543
+ ],
544
+ ],
545
+ )
467
546
 
468
- if when <= datetime.now(timezone.utc):
469
- pipeline.xadd(self.stream_key, message) # type: ignore[arg-type]
470
- else:
471
- pipeline.hset(self.parked_task_key(key), mapping=message) # type: ignore[arg-type]
472
- pipeline.zadd(self.queue_key, {key: when.timestamp()})
547
async def _cancel(self, redis: Redis, key: str) -> None:
    """Cancel the task identified by ``key`` with one Lua script call.

    The script removes the task wherever it currently lives:
    - its stream entry, when a ``stream_message_id`` was recorded for it
    - its entry in the scheduled queue
    - its known-task and parked-task keys

    Args:
        redis: Client used to register (on first use) and run the script.
        key: The task key to cancel.
    """
    script = self._cancel_task_script
    if script is None:
        # Register on first use and keep the handle on the docket for reuse
        script = self._cancel_task_script = cast(
            _cancel_task,
            redis.register_script(
                # KEYS: stream_key, known_key, parked_key, queue_key
                # ARGV: task_key
                """
                local stream_key = KEYS[1]
                local known_key = KEYS[2]
                local parked_key = KEYS[3]
                local queue_key = KEYS[4]
                local task_key = ARGV[1]

                -- Delete from stream if message ID exists
                local message_id = redis.call('HGET', known_key, 'stream_message_id')
                if message_id then
                    redis.call('XDEL', stream_key, message_id)
                end

                -- Clean up all task-related keys
                redis.call('DEL', known_key, parked_key)
                redis.call('ZREM', queue_key, task_key)

                return 'OK'
                """
            ),
        )

    # Key order must match the KEYS[...] indexes used by the script
    script_keys = [
        self.stream_key,
        self.known_task_key(key),
        self.parked_task_key(key),
        self.queue_key,
    ]
    await script(keys=script_keys, args=[key])
478
594
 
479
595
  @property
480
596
  def strike_key(self) -> str:
docket/worker.py CHANGED
@@ -20,6 +20,7 @@ from redis.asyncio import Redis
20
20
  from redis.exceptions import ConnectionError, LockError
21
21
 
22
22
  from .dependencies import (
23
+ ConcurrencyLimit,
23
24
  Dependency,
24
25
  FailedDependency,
25
26
  Perpetual,
@@ -248,6 +249,7 @@ class Worker:
248
249
  )
249
250
 
250
251
  active_tasks: dict[asyncio.Task[None], RedisMessageID] = {}
252
+ task_executions: dict[asyncio.Task[None], Execution] = {}
251
253
  available_slots = self.concurrency
252
254
 
253
255
  log_context = self._log_context()
@@ -298,6 +300,7 @@ class Worker:
298
300
 
299
301
  task = asyncio.create_task(self._execute(execution), name=execution.key)
300
302
  active_tasks[task] = message_id
303
+ task_executions[task] = execution
301
304
 
302
305
  nonlocal available_slots
303
306
  available_slots -= 1
@@ -308,6 +311,7 @@ class Worker:
308
311
  completed_tasks = {task for task in active_tasks if task.done()}
309
312
  for task in completed_tasks:
310
313
  message_id = active_tasks.pop(task)
314
+ task_executions.pop(task)
311
315
  await task
312
316
  await ack_message(redis, message_id)
313
317
 
@@ -343,7 +347,9 @@ class Worker:
343
347
  if not message: # pragma: no cover
344
348
  continue
345
349
 
346
- if not start_task(message_id, message):
350
+ task_started = start_task(message_id, message)
351
+ if not task_started:
352
+ # Other errors - delete and ack
347
353
  await self._delete_known_task(redis, message)
348
354
  await ack_message(redis, message_id)
349
355
 
@@ -400,7 +406,7 @@ class Worker:
400
406
  task[task_data[j]] = task_data[j+1]
401
407
  end
402
408
 
403
- redis.call('XADD', KEYS[2], '*',
409
+ local message_id = redis.call('XADD', KEYS[2], '*',
404
410
  'key', task['key'],
405
411
  'when', task['when'],
406
412
  'function', task['function'],
@@ -408,6 +414,9 @@ class Worker:
408
414
  'kwargs', task['kwargs'],
409
415
  'attempt', task['attempt']
410
416
  )
417
+ -- Store the message ID in the known task key
418
+ local known_key = ARGV[2] .. ":known:" .. key
419
+ redis.call('HSET', known_key, 'stream_message_id', message_id)
411
420
  redis.call('DEL', hash_key)
412
421
  due_work = due_work + 1
413
422
  end
@@ -534,6 +543,32 @@ class Worker:
534
543
  ) as span:
535
544
  try:
536
545
  async with resolved_dependencies(self, execution) as dependencies:
546
+ # Check concurrency limits after dependency resolution
547
+ concurrency_limit = get_single_dependency_of_type(
548
+ dependencies, ConcurrencyLimit
549
+ )
550
+ if concurrency_limit and not concurrency_limit.is_bypassed:
551
+ async with self.docket.redis() as redis:
552
+ # Check if we can acquire a concurrency slot
553
+ if not await self._can_start_task(redis, execution):
554
+ # Task cannot start due to concurrency limits - reschedule
555
+ logger.debug(
556
+ "🔒 Task %s blocked by concurrency limit, rescheduling",
557
+ execution.key,
558
+ extra=log_context,
559
+ )
560
+ # Reschedule for a few milliseconds in the future
561
+ when = datetime.now(timezone.utc) + timedelta(
562
+ milliseconds=50
563
+ )
564
+ await self.docket.add(execution.function, when=when)(
565
+ *execution.args, **execution.kwargs
566
+ )
567
+ return
568
+ else:
569
+ # Successfully acquired slot
570
+ pass
571
+
537
572
  # Preemptively reschedule the perpetual task for the future, or clear
538
573
  # the known task key for this task
539
574
  rescheduled = await self._perpetuate_if_requested(
@@ -560,17 +595,34 @@ class Worker:
560
595
  ],
561
596
  )
562
597
 
563
- if timeout := get_single_dependency_of_type(dependencies, Timeout):
564
- await self._run_function_with_timeout(
565
- execution, dependencies, timeout
566
- )
598
+ # Apply timeout logic - either user's timeout or redelivery timeout
599
+ user_timeout = get_single_dependency_of_type(dependencies, Timeout)
600
+ if user_timeout:
601
+ # If user timeout is longer than redelivery timeout, limit it
602
+ if user_timeout.base > self.redelivery_timeout:
603
+ # Create a new timeout limited by redelivery timeout
604
+ # Remove the user timeout from dependencies to avoid conflicts
605
+ limited_dependencies = {
606
+ k: v
607
+ for k, v in dependencies.items()
608
+ if not isinstance(v, Timeout)
609
+ }
610
+ limited_timeout = Timeout(self.redelivery_timeout)
611
+ limited_timeout.start()
612
+ await self._run_function_with_timeout(
613
+ execution, limited_dependencies, limited_timeout
614
+ )
615
+ else:
616
+ # User timeout is within redelivery timeout, use as-is
617
+ await self._run_function_with_timeout(
618
+ execution, dependencies, user_timeout
619
+ )
567
620
  else:
568
- await execution.function(
569
- *execution.args,
570
- **{
571
- **execution.kwargs,
572
- **dependencies,
573
- },
621
+ # No user timeout - apply redelivery timeout as hard limit
622
+ redelivery_timeout = Timeout(self.redelivery_timeout)
623
+ redelivery_timeout.start()
624
+ await self._run_function_with_timeout(
625
+ execution, dependencies, redelivery_timeout
574
626
  )
575
627
 
576
628
  duration = log_context["duration"] = time.time() - start
@@ -604,6 +656,15 @@ class Worker:
604
656
  "%s [%s] %s", arrow, ms(duration), call, extra=log_context
605
657
  )
606
658
  finally:
659
+ # Release concurrency slot if we acquired one
660
+ if dependencies:
661
+ concurrency_limit = get_single_dependency_of_type(
662
+ dependencies, ConcurrencyLimit
663
+ )
664
+ if concurrency_limit and not concurrency_limit.is_bypassed:
665
+ async with self.docket.redis() as redis:
666
+ await self._release_concurrency_slot(redis, execution)
667
+
607
668
  TASKS_RUNNING.add(-1, counter_labels)
608
669
  TASKS_COMPLETED.add(1, counter_labels)
609
670
  TASK_DURATION.record(duration, counter_labels)
@@ -616,7 +677,13 @@ class Worker:
616
677
  ) -> None:
617
678
  task_coro = cast(
618
679
  Coroutine[None, None, None],
619
- execution.function(*execution.args, **execution.kwargs, **dependencies),
680
+ execution.function(
681
+ *execution.args,
682
+ **{
683
+ **execution.kwargs,
684
+ **dependencies,
685
+ },
686
+ ),
620
687
  )
621
688
  task = asyncio.create_task(task_coro)
622
689
  try:
@@ -762,6 +829,95 @@ class Worker:
762
829
  extra=self._log_context(),
763
830
  )
764
831
 
832
async def _can_start_task(self, redis: Redis, execution: Execution) -> bool:
    """Try to claim a concurrency slot for ``execution``.

    Despite the name this is not a pure check: when a slot is free the
    script records ``<worker name>:<task key>`` in the sorted set for the
    task's concurrency key, scored with the current timestamp.

    Args:
        redis: Client used to run the slot-claiming Lua script.
        execution: The execution about to run.

    Returns:
        True when the task declares no ``ConcurrencyLimit``, when the limited
        argument is missing from the execution, or when a slot was claimed;
        False when ``max_concurrent`` slots are already taken.
    """
    limit = get_single_dependency_parameter_of_type(
        execution.function, ConcurrencyLimit
    )
    if not limit:
        # No concurrency limit declared, the task can always start
        return True

    try:
        grouping_value = execution.get_argument(limit.argument_name)
    except KeyError:
        # Argument not found: let the task fail naturally in execution
        return True

    key_scope = limit.scope or self.docket.name
    slot_key = f"{key_scope}:concurrency:{limit.argument_name}:{grouping_value}"

    # Sorted set scored by claim time, so stale claims can be expired
    claim_script = """
    local key = KEYS[1]
    local max_concurrent = tonumber(ARGV[1])
    local worker_id = ARGV[2]
    local task_key = ARGV[3]
    local current_time = tonumber(ARGV[4])
    local expiration_time = tonumber(ARGV[5])

    -- Remove expired entries
    local expired_cutoff = current_time - expiration_time
    redis.call('ZREMRANGEBYSCORE', key, 0, expired_cutoff)

    -- Get current count
    local current = redis.call('ZCARD', key)

    if current < max_concurrent then
        -- Add this worker's task to the sorted set with current timestamp
        redis.call('ZADD', key, current_time, worker_id .. ':' .. task_key)
        return 1
    else
        return 0
    end
    """

    now = datetime.now(timezone.utc).timestamp()
    # Claims older than the redelivery timeout are treated as expired
    max_age = self.redelivery_timeout.total_seconds()

    claimed = await redis.eval(  # type: ignore
        claim_script,
        1,
        slot_key,
        str(limit.max_concurrent),
        self.name,
        execution.key,
        now,
        max_age,
    )
    return bool(claimed)
894
+
895
async def _release_concurrency_slot(
    self, redis: Redis, execution: Execution
) -> None:
    """Give back the concurrency slot held for ``execution``.

    Removes the ``<worker name>:<task key>`` member from the sorted set for
    the task's concurrency key. Does nothing when the task declares no
    ``ConcurrencyLimit`` or the limited argument is missing.

    Args:
        redis: Client used to remove the slot entry.
        execution: The execution whose slot is released.
    """
    limit = get_single_dependency_parameter_of_type(
        execution.function, ConcurrencyLimit
    )
    if not limit:
        # No concurrency limit declared, nothing to release
        return

    try:
        grouping_value = execution.get_argument(limit.argument_name)
    except KeyError:
        # Argument not found, so no slot was ever keyed on it
        return

    key_scope = limit.scope or self.docket.name
    slot_key = f"{key_scope}:concurrency:{limit.argument_name}:{grouping_value}"
    slot_member = f"{self.name}:{execution.key}"

    # Remove this worker's task from the sorted set
    await redis.zrem(slot_key, slot_member)  # type: ignore
920
+
765
921
 
766
922
  def ms(seconds: float) -> str:
767
923
  if seconds < 100:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: pydocket
3
- Version: 0.7.1
3
+ Version: 0.9.0
4
4
  Summary: A distributed background task system for Python functions
5
5
  Project-URL: Homepage, https://github.com/chrisguidry/docket
6
6
  Project-URL: Bug Tracker, https://github.com/chrisguidry/docket/issues
@@ -93,6 +93,8 @@ reference](https://chrisguidry.github.io/docket/api-reference/).
93
93
 
94
94
  🧩 Fully type-complete and type-aware for your background task functions
95
95
 
96
+ 💉 Dependency injection like FastAPI, Typer, and FastMCP for reusable resources
97
+
96
98
  ## Installing `docket`
97
99
 
98
100
  Docket is [available on PyPI](https://pypi.org/project/pydocket/) under the package name
@@ -0,0 +1,16 @@
1
+ docket/__init__.py,sha256=onwZzh73tESWoFBukbcW-7gjxoXb-yI7dutRD7tPN6g,915
2
+ docket/__main__.py,sha256=wcCrL4PjG51r5wVKqJhcoJPTLfHW0wNbD31DrUN0MWI,28
3
+ docket/annotations.py,sha256=wttix9UOeMFMAWXAIJUfUw5GjESJZsACb4YXJCozP7Q,2348
4
+ docket/cli.py,sha256=rTfri2--u4Q5PlXyh7Ub_F5uh3-TtZOWLUp9WY_TvAE,25750
5
+ docket/dependencies.py,sha256=BC0bnt10cr9_S1p5JAP_bnC9RwZkTr9ulPBrxC7eZnA,20247
6
+ docket/docket.py,sha256=0nQCHDDHy7trv2a0eYygGgIKiA7fWq5GcOXye3_CPWM,30847
7
+ docket/execution.py,sha256=r_2RGC1qhtAcBUg7E6wewLEgftrf3hIxNbH0HnYPbek,14961
8
+ docket/instrumentation.py,sha256=ogvzrfKbWsdPGfdg4hByH3_r5d3b5AwwQkSrmXw0hRg,5492
9
+ docket/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
10
+ docket/tasks.py,sha256=RIlSM2omh-YDwVnCz6M5MtmK8T_m_s1w2OlRRxDUs6A,1437
11
+ docket/worker.py,sha256=jqVYqtQyxbk-BIy3shY8haX-amVT9Np97VhJuaQTfpM,35174
12
+ pydocket-0.9.0.dist-info/METADATA,sha256=kymp9PKG7UwMj0i0qGSSCHKu-g-tS__qydr6mYuMLtg,5418
13
+ pydocket-0.9.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
14
+ pydocket-0.9.0.dist-info/entry_points.txt,sha256=4WOk1nUlBsUT5O3RyMci2ImuC5XFswuopElYcLHtD5k,47
15
+ pydocket-0.9.0.dist-info/licenses/LICENSE,sha256=YuVWU_ZXO0K_k2FG8xWKe5RGxV24AhJKTvQmKfqXuyk,1087
16
+ pydocket-0.9.0.dist-info/RECORD,,
@@ -1,16 +0,0 @@
1
- docket/__init__.py,sha256=sY1T_NVsXQNOmOhOnfYmZ95dcE_52Ov6DSIVIMZp-1w,869
2
- docket/__main__.py,sha256=wcCrL4PjG51r5wVKqJhcoJPTLfHW0wNbD31DrUN0MWI,28
3
- docket/annotations.py,sha256=wttix9UOeMFMAWXAIJUfUw5GjESJZsACb4YXJCozP7Q,2348
4
- docket/cli.py,sha256=XG_mbjcqNRO0F0hh6l3AwH9bIZv9xJofZaeaAj9nChc,21608
5
- docket/dependencies.py,sha256=GBwyEY198JFrfm7z5GkLbd84hv7sJktKBMJXv4veWig,17007
6
- docket/docket.py,sha256=Cw7QB1d0eDwSgwn0Rj26WjFsXSe7MJtfsUBBHGalL7A,26262
7
- docket/execution.py,sha256=r_2RGC1qhtAcBUg7E6wewLEgftrf3hIxNbH0HnYPbek,14961
8
- docket/instrumentation.py,sha256=ogvzrfKbWsdPGfdg4hByH3_r5d3b5AwwQkSrmXw0hRg,5492
9
- docket/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
10
- docket/tasks.py,sha256=RIlSM2omh-YDwVnCz6M5MtmK8T_m_s1w2OlRRxDUs6A,1437
11
- docket/worker.py,sha256=CY5Z9p8FZw-6WUwp7Ws4A0V7IFTmonSnBmYP-Cp8Fdw,28079
12
- pydocket-0.7.1.dist-info/METADATA,sha256=00KHm5Er2R6dmjHLTYBUF13kKAeCRPHmDTdAcv5oRcQ,5335
13
- pydocket-0.7.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
14
- pydocket-0.7.1.dist-info/entry_points.txt,sha256=4WOk1nUlBsUT5O3RyMci2ImuC5XFswuopElYcLHtD5k,47
15
- pydocket-0.7.1.dist-info/licenses/LICENSE,sha256=YuVWU_ZXO0K_k2FG8xWKe5RGxV24AhJKTvQmKfqXuyk,1087
16
- pydocket-0.7.1.dist-info/RECORD,,