pydocket 0.7.0-py3-none-any.whl → 0.8.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of pydocket might be problematic.

docket/__init__.py CHANGED
@@ -10,6 +10,7 @@ __version__ = version("pydocket")
 
 from .annotations import Logged
 from .dependencies import (
+    ConcurrencyLimit,
     CurrentDocket,
     CurrentExecution,
     CurrentWorker,
@@ -28,6 +29,7 @@ from .worker import Worker
 
 __all__ = [
     "__version__",
+    "ConcurrencyLimit",
     "CurrentDocket",
     "CurrentExecution",
     "CurrentWorker",

docket/annotations.py CHANGED
@@ -2,6 +2,8 @@ import abc
 import inspect
 from typing import Any, Iterable, Mapping, Self
 
+from .instrumentation import CACHE_SIZE
+
 
 class Annotation(abc.ABC):
     _cache: dict[tuple[type[Self], inspect.Signature], Mapping[str, Self]] = {}
@@ -10,6 +12,7 @@ class Annotation(abc.ABC):
     def annotated_parameters(cls, signature: inspect.Signature) -> Mapping[str, Self]:
         key = (cls, signature)
         if key in cls._cache:
+            CACHE_SIZE.set(len(cls._cache), {"cache": "annotation"})
             return cls._cache[key]
 
         annotated: dict[str, Self] = {}
@@ -30,6 +33,7 @@ class Annotation(abc.ABC):
             annotated[param_name] = arg_type()
 
         cls._cache[key] = annotated
+        CACHE_SIZE.set(len(cls._cache), {"cache": "annotation"})
         return annotated
 
 

docket/cli.py CHANGED
@@ -358,6 +358,32 @@ def strike(
     asyncio.run(run())
 
 
+@app.command(help="Clear all pending and scheduled tasks from the docket")
+def clear(
+    docket_: Annotated[
+        str,
+        typer.Option(
+            "--docket",
+            help="The name of the docket",
+            envvar="DOCKET_NAME",
+        ),
+    ] = "docket",
+    url: Annotated[
+        str,
+        typer.Option(
+            help="The URL of the Redis server",
+            envvar="DOCKET_URL",
+        ),
+    ] = "redis://localhost:6379/0",
+) -> None:
+    async def run() -> None:
+        async with Docket(name=docket_, url=url) as docket:
+            cleared_count = await docket.clear()
+            print(f"Cleared {cleared_count} tasks from docket '{docket_}'")
+
+    asyncio.run(run())
+
+
 @app.command(help="Restores a task or parameters to the Docket")
 def restore(
     function: Annotated[
@@ -568,6 +594,73 @@ def relative_time(now: datetime, when: datetime) -> str:
     return f"at {local_time(when)}"
 
 
+def get_task_stats(
+    snapshot: DocketSnapshot,
+) -> dict[str, dict[str, int | datetime | None]]:
+    """Get task count statistics by function name with timestamp data."""
+    stats: dict[str, dict[str, int | datetime | None]] = {}
+
+    # Count running tasks by function
+    for execution in snapshot.running:
+        func_name = execution.function.__name__
+        if func_name not in stats:
+            stats[func_name] = {
+                "running": 0,
+                "queued": 0,
+                "total": 0,
+                "oldest_queued": None,
+                "latest_queued": None,
+                "oldest_started": None,
+                "latest_started": None,
+            }
+        stats[func_name]["running"] += 1
+        stats[func_name]["total"] += 1
+
+        # Track oldest/latest started times for running tasks
+        started = execution.started
+        if (
+            stats[func_name]["oldest_started"] is None
+            or started < stats[func_name]["oldest_started"]
+        ):
+            stats[func_name]["oldest_started"] = started
+        if (
+            stats[func_name]["latest_started"] is None
+            or started > stats[func_name]["latest_started"]
+        ):
+            stats[func_name]["latest_started"] = started
+
+    # Count future tasks by function
+    for execution in snapshot.future:
+        func_name = execution.function.__name__
+        if func_name not in stats:
+            stats[func_name] = {
+                "running": 0,
+                "queued": 0,
+                "total": 0,
+                "oldest_queued": None,
+                "latest_queued": None,
+                "oldest_started": None,
+                "latest_started": None,
+            }
+        stats[func_name]["queued"] += 1
+        stats[func_name]["total"] += 1
+
+        # Track oldest/latest queued times for future tasks
+        when = execution.when
+        if (
+            stats[func_name]["oldest_queued"] is None
+            or when < stats[func_name]["oldest_queued"]
+        ):
+            stats[func_name]["oldest_queued"] = when
+        if (
+            stats[func_name]["latest_queued"] is None
+            or when > stats[func_name]["latest_queued"]
+        ):
+            stats[func_name]["latest_queued"] = when
+
+    return stats
+
+
 @app.command(help="Shows a snapshot of what's on the docket right now")
 def snapshot(
     tasks: Annotated[
@@ -597,6 +690,13 @@ def snapshot(
             envvar="DOCKET_URL",
         ),
     ] = "redis://localhost:6379/0",
+    stats: Annotated[
+        bool,
+        typer.Option(
+            "--stats",
+            help="Show task count statistics by function name",
+        ),
+    ] = False,
 ) -> None:
     async def run() -> DocketSnapshot:
         async with Docket(name=docket_, url=url) as docket:
@@ -646,6 +746,44 @@ def snapshot(
 
     console.print(table)
 
+    # Display task statistics if requested
+    if stats:
+        task_stats = get_task_stats(snapshot)
+        if task_stats:
+            console.print()  # Add spacing between tables
+            stats_table = Table(title="Task Count Statistics by Function")
+            stats_table.add_column("Function", style="cyan")
+            stats_table.add_column("Total", style="bold magenta", justify="right")
+            stats_table.add_column("Running", style="green", justify="right")
+            stats_table.add_column("Queued", style="yellow", justify="right")
+            stats_table.add_column("Oldest Queued", style="dim yellow", justify="right")
+            stats_table.add_column("Latest Queued", style="dim yellow", justify="right")
+
+            # Sort by total count descending to highlight potential runaway tasks
+            for func_name in sorted(
+                task_stats.keys(), key=lambda x: task_stats[x]["total"], reverse=True
+            ):
+                counts = task_stats[func_name]
+
+                # Format timestamp columns
+                oldest_queued = ""
+                latest_queued = ""
+                if counts["oldest_queued"] is not None:
+                    oldest_queued = relative(counts["oldest_queued"])
+                if counts["latest_queued"] is not None:
+                    latest_queued = relative(counts["latest_queued"])
+
+                stats_table.add_row(
+                    func_name,
+                    str(counts["total"]),
+                    str(counts["running"]),
+                    str(counts["queued"]),
+                    oldest_queued,
+                    latest_queued,
+                )
+
+            console.print(stats_table)
+
 
 workers_app: typer.Typer = typer.Typer(
     help="Look at the workers on a docket", no_args_is_help=True

docket/dependencies.py CHANGED
@@ -21,6 +21,7 @@ from typing import (
 
 from .docket import Docket
 from .execution import Execution, TaskFunction, get_signature
+from .instrumentation import CACHE_SIZE
 
 if TYPE_CHECKING:  # pragma: no cover
     from .worker import Worker
@@ -38,9 +39,9 @@ class Dependency(abc.ABC):
 
     async def __aexit__(
         self,
-        exc_type: type[BaseException] | None,
-        exc_value: BaseException | None,
-        traceback: TracebackType | None,
+        _exc_type: type[BaseException] | None,
+        _exc_value: BaseException | None,
+        _traceback: TracebackType | None,
     ) -> bool: ...  # pragma: no cover
 
 
@@ -415,6 +416,7 @@ def get_dependency_parameters(
     function: TaskFunction | DependencyFunction[Any],
 ) -> dict[str, Dependency]:
     if function in _parameter_cache:
+        CACHE_SIZE.set(len(_parameter_cache), {"cache": "parameter"})
        return _parameter_cache[function]
 
    dependencies: dict[str, Dependency] = {}
@@ -428,6 +430,7 @@
        dependencies[parameter] = param.default
 
    _parameter_cache[function] = dependencies
+    CACHE_SIZE.set(len(_parameter_cache), {"cache": "parameter"})
    return dependencies
 
 
@@ -502,6 +505,92 @@ def Depends(dependency: DependencyFunction[R]) -> R:
     return cast(R, _Depends(dependency))
 
 
+class ConcurrencyLimit(Dependency):
+    """Configures concurrency limits for a task based on specific argument values.
+
+    This allows fine-grained control over task execution by limiting concurrent
+    tasks based on the value of specific arguments.
+
+    Example:
+
+    ```python
+    async def process_customer(
+        customer_id: int,
+        concurrency: ConcurrencyLimit = ConcurrencyLimit("customer_id", max_concurrent=1)
+    ) -> None:
+        # Only one task per customer_id will run at a time
+        ...
+
+    async def backup_db(
+        db_name: str,
+        concurrency: ConcurrencyLimit = ConcurrencyLimit("db_name", max_concurrent=3)
+    ) -> None:
+        # Only 3 backup tasks per database name will run at a time
+        ...
+    ```
+    """
+
+    single: bool = True
+
+    def __init__(
+        self, argument_name: str, max_concurrent: int = 1, scope: str | None = None
+    ) -> None:
+        """
+        Args:
+            argument_name: The name of the task argument to use for concurrency grouping
+            max_concurrent: Maximum number of concurrent tasks per unique argument value
+            scope: Optional scope prefix for Redis keys (defaults to docket name)
+        """
+        self.argument_name = argument_name
+        self.max_concurrent = max_concurrent
+        self.scope = scope
+        self._concurrency_key: str | None = None
+        self._initialized: bool = False
+
+    async def __aenter__(self) -> "ConcurrencyLimit":
+        execution = self.execution.get()
+        docket = self.docket.get()
+
+        # Get the argument value to group by
+        try:
+            argument_value = execution.get_argument(self.argument_name)
+        except KeyError:
+            # If argument not found, create a bypass limit that doesn't apply concurrency control
+            limit = ConcurrencyLimit(
+                self.argument_name, self.max_concurrent, self.scope
+            )
+            limit._concurrency_key = None  # Special marker for bypassed concurrency
+            limit._initialized = True  # Mark as initialized but bypassed
+            return limit
+
+        # Create a concurrency key for this specific argument value
+        scope = self.scope or docket.name
+        self._concurrency_key = (
+            f"{scope}:concurrency:{self.argument_name}:{argument_value}"
+        )
+
+        limit = ConcurrencyLimit(self.argument_name, self.max_concurrent, self.scope)
+        limit._concurrency_key = self._concurrency_key
+        limit._initialized = True  # Mark as initialized
+        return limit
+
+    @property
+    def concurrency_key(self) -> str | None:
+        """Redis key used for tracking concurrency for this specific argument value.
+        Returns None when concurrency control is bypassed due to missing arguments.
+        Raises RuntimeError if accessed before initialization."""
+        if not self._initialized:
+            raise RuntimeError(
+                "ConcurrencyLimit not initialized - use within task context"
+            )
+        return self._concurrency_key
+
+    @property
+    def is_bypassed(self) -> bool:
+        """Returns True if concurrency control is bypassed due to missing arguments."""
+        return self._initialized and self._concurrency_key is None
+
+
 D = TypeVar("D", bound=Dependency)
 
 

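Taken together with the worker changes below, `ConcurrencyLimit` throttles execution per argument value. A minimal scheduling sketch, assuming a worker is running elsewhere against the same docket; the task body, docket name, and Redis URL are placeholders:

```python
# A minimal sketch, assuming a worker is running elsewhere against the same
# docket; the task body, docket name, and Redis URL are placeholders.
import asyncio

from docket import ConcurrencyLimit, Docket


async def process_customer(
    customer_id: int,
    concurrency: ConcurrencyLimit = ConcurrencyLimit("customer_id", max_concurrent=1),
) -> None:
    ...  # at most one run per customer_id at a time


async def main() -> None:
    async with Docket(name="example", url="redis://localhost:6379/0") as docket:
        # Tasks sharing a customer_id are serialized by the worker; tasks for
        # different customer_ids may still run concurrently.
        for customer_id in (1, 1, 2):
            await docket.add(process_customer)(customer_id=customer_id)


asyncio.run(main())
```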
docket/docket.py CHANGED
@@ -743,3 +743,46 @@ class Docket:
             workers.append(WorkerInfo(worker_name, last_seen, task_names))
 
         return workers
+
+    async def clear(self) -> int:
+        """Clear all pending and scheduled tasks from the docket.
+
+        This removes all tasks from the stream (immediate tasks) and queue
+        (scheduled tasks), along with their associated parked data. Running
+        tasks are not affected.
+
+        Returns:
+            The total number of tasks that were cleared.
+        """
+        with tracer.start_as_current_span(
+            "docket.clear",
+            attributes=self.labels(),
+        ):
+            async with self.redis() as redis:
+                async with redis.pipeline() as pipeline:
+                    # Get counts before clearing
+                    pipeline.xlen(self.stream_key)
+                    pipeline.zcard(self.queue_key)
+                    pipeline.zrange(self.queue_key, 0, -1)
+
+                    stream_count: int
+                    queue_count: int
+                    scheduled_keys: list[bytes]
+                    stream_count, queue_count, scheduled_keys = await pipeline.execute()
+
+                    # Clear all data
+                    # Trim stream to 0 messages instead of deleting it to preserve consumer group
+                    if stream_count > 0:
+                        pipeline.xtrim(self.stream_key, maxlen=0, approximate=False)
+                    pipeline.delete(self.queue_key)
+
+                    # Clear parked task data and known task keys
+                    for key_bytes in scheduled_keys:
+                        key = key_bytes.decode()
+                        pipeline.delete(self.parked_task_key(key))
+                        pipeline.delete(self.known_task_key(key))
+
+                    await pipeline.execute()
+
+            total_cleared = stream_count + queue_count
+            return total_cleared

@@ -19,7 +19,7 @@ import opentelemetry.context
19
19
  from opentelemetry import propagate, trace
20
20
 
21
21
  from .annotations import Logged
22
- from .instrumentation import message_getter
22
+ from .instrumentation import CACHE_SIZE, message_getter
23
23
 
24
24
  logger: logging.Logger = logging.getLogger(__name__)
25
25
 
@@ -32,10 +32,12 @@ _signature_cache: dict[Callable[..., Any], inspect.Signature] = {}
32
32
 
33
33
  def get_signature(function: Callable[..., Any]) -> inspect.Signature:
34
34
  if function in _signature_cache:
35
+ CACHE_SIZE.set(len(_signature_cache), {"cache": "signature"})
35
36
  return _signature_cache[function]
36
37
 
37
38
  signature = inspect.signature(function)
38
39
  _signature_cache[function] = signature
40
+ CACHE_SIZE.set(len(_signature_cache), {"cache": "signature"})
39
41
  return signature
40
42
 
41
43
 
docket/instrumentation.py CHANGED
@@ -117,6 +117,12 @@ SCHEDULE_DEPTH = meter.create_gauge(
     unit="1",
 )
 
+CACHE_SIZE = meter.create_gauge(
+    "docket_cache_size",
+    description="Size of internal docket caches",
+    unit="1",
+)
+
 Message = dict[bytes, bytes]
 
 

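The new gauge follows the labeled-gauge pattern already used by the other metrics in this module: each internal cache reports its size under a distinct `cache` attribute. A minimal sketch of that pattern using the OpenTelemetry metrics API directly (the meter name and values are illustrative only):

```python
# A minimal sketch of the labeled-gauge pattern used above; the meter name and
# values are illustrative only.
from opentelemetry import metrics

meter = metrics.get_meter("example")

cache_size = meter.create_gauge(
    "docket_cache_size",
    description="Size of internal docket caches",
    unit="1",
)

# One metric, filtered per cache in a dashboard via the "cache" attribute.
cache_size.set(3, {"cache": "signature"})
cache_size.set(5, {"cache": "parameter"})
cache_size.set(2, {"cache": "annotation"})
```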
docket/worker.py CHANGED
@@ -15,11 +15,12 @@ from typing import (
 )
 
 from opentelemetry import trace
-from opentelemetry.trace import Tracer
+from opentelemetry.trace import Status, StatusCode, Tracer
 from redis.asyncio import Redis
 from redis.exceptions import ConnectionError, LockError
 
 from .dependencies import (
+    ConcurrencyLimit,
     Dependency,
     FailedDependency,
     Perpetual,
@@ -248,6 +249,7 @@ class Worker:
         )
 
         active_tasks: dict[asyncio.Task[None], RedisMessageID] = {}
+        task_executions: dict[asyncio.Task[None], Execution] = {}
         available_slots = self.concurrency
 
         log_context = self._log_context()
@@ -284,7 +286,7 @@
                 count=available_slots,
             )
 
-        def start_task(message_id: RedisMessageID, message: RedisMessage) -> bool:
+        async def start_task(message_id: RedisMessageID, message: RedisMessage) -> bool:
             function_name = message[b"function"].decode()
             if not (function := self.docket.tasks.get(function_name)):
                 logger.warning(
@@ -298,6 +300,7 @@ class Worker:
 
             task = asyncio.create_task(self._execute(execution), name=execution.key)
             active_tasks[task] = message_id
+            task_executions[task] = execution
 
             nonlocal available_slots
             available_slots -= 1
@@ -308,6 +311,7 @@ class Worker:
             completed_tasks = {task for task in active_tasks if task.done()}
             for task in completed_tasks:
                 message_id = active_tasks.pop(task)
+                task_executions.pop(task)
                 await task
                 await ack_message(redis, message_id)
 
@@ -343,7 +347,9 @@ class Worker:
                 if not message:  # pragma: no cover
                     continue
 
-                if not start_task(message_id, message):
+                task_started = await start_task(message_id, message)
+                if not task_started:
+                    # Other errors - delete and ack
                     await self._delete_known_task(redis, message)
                     await ack_message(redis, message_id)
 
@@ -531,9 +537,35 @@ class Worker:
                "code.function.name": execution.function.__name__,
            },
            links=execution.incoming_span_links(),
-        ):
+        ) as span:
            try:
                async with resolved_dependencies(self, execution) as dependencies:
+                    # Check concurrency limits after dependency resolution
+                    concurrency_limit = get_single_dependency_of_type(
+                        dependencies, ConcurrencyLimit
+                    )
+                    if concurrency_limit and not concurrency_limit.is_bypassed:
+                        async with self.docket.redis() as redis:
+                            # Check if we can acquire a concurrency slot
+                            if not await self._can_start_task(redis, execution):
+                                # Task cannot start due to concurrency limits - reschedule
+                                logger.debug(
+                                    "🔒 Task %s blocked by concurrency limit, rescheduling",
+                                    execution.key,
+                                    extra=log_context,
+                                )
+                                # Reschedule for a few milliseconds in the future
+                                when = datetime.now(timezone.utc) + timedelta(
+                                    milliseconds=50
+                                )
+                                await self.docket.add(execution.function, when=when)(
+                                    *execution.args, **execution.kwargs
+                                )
+                                return
+                            else:
+                                # Successfully acquired slot
+                                pass
+
                    # Preemptively reschedule the perpetual task for the future, or clear
                    # the known task key for this task
                    rescheduled = await self._perpetuate_if_requested(
@@ -560,22 +592,41 @@ class Worker:
                        ],
                    )
 
-                    if timeout := get_single_dependency_of_type(dependencies, Timeout):
-                        await self._run_function_with_timeout(
-                            execution, dependencies, timeout
-                        )
+                    # Apply timeout logic - either user's timeout or redelivery timeout
+                    user_timeout = get_single_dependency_of_type(dependencies, Timeout)
+                    if user_timeout:
+                        # If user timeout is longer than redelivery timeout, limit it
+                        if user_timeout.base > self.redelivery_timeout:
+                            # Create a new timeout limited by redelivery timeout
+                            # Remove the user timeout from dependencies to avoid conflicts
+                            limited_dependencies = {
+                                k: v
+                                for k, v in dependencies.items()
+                                if not isinstance(v, Timeout)
+                            }
+                            limited_timeout = Timeout(self.redelivery_timeout)
+                            limited_timeout.start()
+                            await self._run_function_with_timeout(
+                                execution, limited_dependencies, limited_timeout
+                            )
+                        else:
+                            # User timeout is within redelivery timeout, use as-is
+                            await self._run_function_with_timeout(
+                                execution, dependencies, user_timeout
+                            )
                    else:
-                        await execution.function(
-                            *execution.args,
-                            **{
-                                **execution.kwargs,
-                                **dependencies,
-                            },
+                        # No user timeout - apply redelivery timeout as hard limit
+                        redelivery_timeout = Timeout(self.redelivery_timeout)
+                        redelivery_timeout.start()
+                        await self._run_function_with_timeout(
+                            execution, dependencies, redelivery_timeout
                        )
 
                    duration = log_context["duration"] = time.time() - start
                    TASKS_SUCCEEDED.add(1, counter_labels)
 
+                    span.set_status(Status(StatusCode.OK))
+
                    rescheduled = await self._perpetuate_if_requested(
                        execution, dependencies, timedelta(seconds=duration)
                    )
@@ -584,10 +635,13 @@ class Worker:
                    logger.info(
                        "%s [%s] %s", arrow, ms(duration), call, extra=log_context
                    )
-            except Exception:
+            except Exception as e:
                duration = log_context["duration"] = time.time() - start
                TASKS_FAILED.add(1, counter_labels)
 
+                span.record_exception(e)
+                span.set_status(Status(StatusCode.ERROR, str(e)))
+
                retried = await self._retry_if_requested(execution, dependencies)
                if not retried:
                    retried = await self._perpetuate_if_requested(
@@ -599,6 +653,15 @@ class Worker:
                        "%s [%s] %s", arrow, ms(duration), call, extra=log_context
                    )
            finally:
+                # Release concurrency slot if we acquired one
+                if dependencies:
+                    concurrency_limit = get_single_dependency_of_type(
+                        dependencies, ConcurrencyLimit
+                    )
+                    if concurrency_limit and not concurrency_limit.is_bypassed:
+                        async with self.docket.redis() as redis:
+                            await self._release_concurrency_slot(redis, execution)
+
                TASKS_RUNNING.add(-1, counter_labels)
                TASKS_COMPLETED.add(1, counter_labels)
                TASK_DURATION.record(duration, counter_labels)
@@ -611,7 +674,13 @@ class Worker:
    ) -> None:
        task_coro = cast(
            Coroutine[None, None, None],
-            execution.function(*execution.args, **execution.kwargs, **dependencies),
+            execution.function(
+                *execution.args,
+                **{
+                    **execution.kwargs,
+                    **dependencies,
+                },
+            ),
        )
        task = asyncio.create_task(task_coro)
        try:
@@ -757,6 +826,95 @@ class Worker:
                extra=self._log_context(),
            )
 
+    async def _can_start_task(self, redis: Redis, execution: Execution) -> bool:
+        """Check if a task can start based on concurrency limits."""
+        # Check if task has a concurrency limit dependency
+        concurrency_limit = get_single_dependency_parameter_of_type(
+            execution.function, ConcurrencyLimit
+        )
+
+        if not concurrency_limit:
+            return True  # No concurrency limit, can always start
+
+        # Get the concurrency key for this task
+        try:
+            argument_value = execution.get_argument(concurrency_limit.argument_name)
+        except KeyError:
+            # If argument not found, let the task fail naturally in execution
+            return True
+
+        scope = concurrency_limit.scope or self.docket.name
+        concurrency_key = (
+            f"{scope}:concurrency:{concurrency_limit.argument_name}:{argument_value}"
+        )
+
+        # Use Redis sorted set with timestamps to track concurrency and handle expiration
+        lua_script = """
+        local key = KEYS[1]
+        local max_concurrent = tonumber(ARGV[1])
+        local worker_id = ARGV[2]
+        local task_key = ARGV[3]
+        local current_time = tonumber(ARGV[4])
+        local expiration_time = tonumber(ARGV[5])
+
+        -- Remove expired entries
+        local expired_cutoff = current_time - expiration_time
+        redis.call('ZREMRANGEBYSCORE', key, 0, expired_cutoff)
+
+        -- Get current count
+        local current = redis.call('ZCARD', key)
+
+        if current < max_concurrent then
+            -- Add this worker's task to the sorted set with current timestamp
+            redis.call('ZADD', key, current_time, worker_id .. ':' .. task_key)
+            return 1
+        else
+            return 0
+        end
+        """
+
+        current_time = datetime.now(timezone.utc).timestamp()
+        expiration_seconds = self.redelivery_timeout.total_seconds()
+
+        result = await redis.eval(  # type: ignore
+            lua_script,
+            1,
+            concurrency_key,
+            str(concurrency_limit.max_concurrent),
+            self.name,
+            execution.key,
+            current_time,
+            expiration_seconds,
+        )
+
+        return bool(result)
+
+    async def _release_concurrency_slot(
+        self, redis: Redis, execution: Execution
+    ) -> None:
+        """Release a concurrency slot when task completes."""
+        # Check if task has a concurrency limit dependency
+        concurrency_limit = get_single_dependency_parameter_of_type(
+            execution.function, ConcurrencyLimit
+        )
+
+        if not concurrency_limit:
+            return  # No concurrency limit to release
+
+        # Get the concurrency key for this task
+        try:
+            argument_value = execution.get_argument(concurrency_limit.argument_name)
+        except KeyError:
+            return  # If argument not found, nothing to release
+
+        scope = concurrency_limit.scope or self.docket.name
+        concurrency_key = (
+            f"{scope}:concurrency:{concurrency_limit.argument_name}:{argument_value}"
+        )
+
+        # Remove this worker's task from the sorted set
+        await redis.zrem(concurrency_key, f"{self.name}:{execution.key}")  # type: ignore
+
 
 
 def ms(seconds: float) -> str:
     if seconds < 100:

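Conceptually, `_can_start_task` and `_release_concurrency_slot` implement a per-argument-value semaphore on a Redis sorted set: members are `worker:task` entries scored by acquisition time, stale holders age out after the redelivery timeout, and a slot is granted only while the set has fewer than `max_concurrent` members. A minimal standalone sketch of that pattern (the key name, member format, and timeout here are illustrative, not docket's internals):

```python
# A minimal standalone sketch of the sorted-set "semaphore" pattern used by the
# worker above; the key name, member format, and timeout are illustrative.
import asyncio
import time

from redis.asyncio import Redis

ACQUIRE_SLOT = """
local key = KEYS[1]
local max_concurrent = tonumber(ARGV[1])
local member = ARGV[2]
local now = tonumber(ARGV[3])
local ttl = tonumber(ARGV[4])

redis.call('ZREMRANGEBYSCORE', key, 0, now - ttl)  -- drop stale holders
if redis.call('ZCARD', key) < max_concurrent then
    redis.call('ZADD', key, now, member)
    return 1
end
return 0
"""


async def main() -> None:
    redis = Redis.from_url("redis://localhost:6379/0")
    key = "example:concurrency:customer_id:42"
    member = "worker-1:task-abc"

    acquired = await redis.eval(ACQUIRE_SLOT, 1, key, "1", member, time.time(), 300)
    if acquired:
        try:
            ...  # do the work while holding the slot
        finally:
            await redis.zrem(key, member)  # release the slot
    await redis.aclose()


asyncio.run(main())
```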
pydocket-0.7.0.dist-info/METADATA → pydocket-0.8.0.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: pydocket
-Version: 0.7.0
+Version: 0.8.0
 Summary: A distributed background task system for Python functions
 Project-URL: Homepage, https://github.com/chrisguidry/docket
 Project-URL: Bug Tracker, https://github.com/chrisguidry/docket/issues
@@ -93,6 +93,8 @@ reference](https://chrisguidry.github.io/docket/api-reference/).
 
 🧩 Fully type-complete and type-aware for your background task functions
 
+💉 Dependency injection like FastAPI, Typer, and FastMCP for reusable resources
+
 ## Installing `docket`
 
 Docket is [available on PyPI](https://pypi.org/project/pydocket/) under the package name

pydocket-0.8.0.dist-info/RECORD
@@ -0,0 +1,16 @@
+docket/__init__.py,sha256=onwZzh73tESWoFBukbcW-7gjxoXb-yI7dutRD7tPN6g,915
+docket/__main__.py,sha256=wcCrL4PjG51r5wVKqJhcoJPTLfHW0wNbD31DrUN0MWI,28
+docket/annotations.py,sha256=wttix9UOeMFMAWXAIJUfUw5GjESJZsACb4YXJCozP7Q,2348
+docket/cli.py,sha256=rTfri2--u4Q5PlXyh7Ub_F5uh3-TtZOWLUp9WY_TvAE,25750
+docket/dependencies.py,sha256=BC0bnt10cr9_S1p5JAP_bnC9RwZkTr9ulPBrxC7eZnA,20247
+docket/docket.py,sha256=Cw7QB1d0eDwSgwn0Rj26WjFsXSe7MJtfsUBBHGalL7A,26262
+docket/execution.py,sha256=r_2RGC1qhtAcBUg7E6wewLEgftrf3hIxNbH0HnYPbek,14961
+docket/instrumentation.py,sha256=ogvzrfKbWsdPGfdg4hByH3_r5d3b5AwwQkSrmXw0hRg,5492
+docket/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+docket/tasks.py,sha256=RIlSM2omh-YDwVnCz6M5MtmK8T_m_s1w2OlRRxDUs6A,1437
+docket/worker.py,sha256=RNfwvNC1zpGfTYq8-HOjicociuVvOJWMgj8w8DqmN3Y,34940
+pydocket-0.8.0.dist-info/METADATA,sha256=Y2tUVZmlDqPWmBCmmjRUTs56Rd6thT3_wSd9JMKb79s,5418
+pydocket-0.8.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+pydocket-0.8.0.dist-info/entry_points.txt,sha256=4WOk1nUlBsUT5O3RyMci2ImuC5XFswuopElYcLHtD5k,47
+pydocket-0.8.0.dist-info/licenses/LICENSE,sha256=YuVWU_ZXO0K_k2FG8xWKe5RGxV24AhJKTvQmKfqXuyk,1087
+pydocket-0.8.0.dist-info/RECORD,,

pydocket-0.7.0.dist-info/RECORD
@@ -1,16 +0,0 @@
-docket/__init__.py,sha256=sY1T_NVsXQNOmOhOnfYmZ95dcE_52Ov6DSIVIMZp-1w,869
-docket/__main__.py,sha256=wcCrL4PjG51r5wVKqJhcoJPTLfHW0wNbD31DrUN0MWI,28
-docket/annotations.py,sha256=SFBrOMbpAh7P67u8fRTH-u3MVvJQxe0qYi92WAShAsw,2173
-docket/cli.py,sha256=WPm_URZ54h8gHjrsHKP8SXpRzdeepmyH_FhQHai-Qus,20899
-docket/dependencies.py,sha256=fX4vafGjQf7s4x0YROaw7fzQPlYW7TZtCqNhu7Kxj40,16831
-docket/docket.py,sha256=5e101CGLZ2tWNcADo4cdewapmXab47ieMCeQr6d92YQ,24478
-docket/execution.py,sha256=6KozjnS96byvyCMTQ2-IkcIrPsqaPIVu2HZU0U4Be9E,14813
-docket/instrumentation.py,sha256=f-GG5VS6EdS2It30qxjVpzWUBOZQcTnat-3KzPwwDgQ,5367
-docket/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-docket/tasks.py,sha256=RIlSM2omh-YDwVnCz6M5MtmK8T_m_s1w2OlRRxDUs6A,1437
-docket/worker.py,sha256=tJfk2rlHODzHaWBzpBXT8h-Lo7RDQ6gb6HU8b3T9gFA,27878
-pydocket-0.7.0.dist-info/METADATA,sha256=soXf7ybhgvSykxRDH56pMJX2DaXf3SJfDFUFLbebAvM,5335
-pydocket-0.7.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-pydocket-0.7.0.dist-info/entry_points.txt,sha256=4WOk1nUlBsUT5O3RyMci2ImuC5XFswuopElYcLHtD5k,47
-pydocket-0.7.0.dist-info/licenses/LICENSE,sha256=YuVWU_ZXO0K_k2FG8xWKe5RGxV24AhJKTvQmKfqXuyk,1087
-pydocket-0.7.0.dist-info/RECORD,,