pydocket 0.7.1__py3-none-any.whl → 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydocket might be problematic. Click here for more details.
- docket/__init__.py +2 -0
- docket/cli.py +112 -0
- docket/dependencies.py +89 -3
- docket/docket.py +151 -35
- docket/worker.py +169 -13
- {pydocket-0.7.1.dist-info → pydocket-0.9.0.dist-info}/METADATA +3 -1
- pydocket-0.9.0.dist-info/RECORD +16 -0
- pydocket-0.7.1.dist-info/RECORD +0 -16
- {pydocket-0.7.1.dist-info → pydocket-0.9.0.dist-info}/WHEEL +0 -0
- {pydocket-0.7.1.dist-info → pydocket-0.9.0.dist-info}/entry_points.txt +0 -0
- {pydocket-0.7.1.dist-info → pydocket-0.9.0.dist-info}/licenses/LICENSE +0 -0
docket/__init__.py
CHANGED
|
@@ -10,6 +10,7 @@ __version__ = version("pydocket")
|
|
|
10
10
|
|
|
11
11
|
from .annotations import Logged
|
|
12
12
|
from .dependencies import (
|
|
13
|
+
ConcurrencyLimit,
|
|
13
14
|
CurrentDocket,
|
|
14
15
|
CurrentExecution,
|
|
15
16
|
CurrentWorker,
|
|
@@ -28,6 +29,7 @@ from .worker import Worker
|
|
|
28
29
|
|
|
29
30
|
__all__ = [
|
|
30
31
|
"__version__",
|
|
32
|
+
"ConcurrencyLimit",
|
|
31
33
|
"CurrentDocket",
|
|
32
34
|
"CurrentExecution",
|
|
33
35
|
"CurrentWorker",
|
docket/cli.py
CHANGED
|
@@ -594,6 +594,73 @@ def relative_time(now: datetime, when: datetime) -> str:
|
|
|
594
594
|
return f"at {local_time(when)}"
|
|
595
595
|
|
|
596
596
|
|
|
597
|
+
def get_task_stats(
|
|
598
|
+
snapshot: DocketSnapshot,
|
|
599
|
+
) -> dict[str, dict[str, int | datetime | None]]:
|
|
600
|
+
"""Get task count statistics by function name with timestamp data."""
|
|
601
|
+
stats: dict[str, dict[str, int | datetime | None]] = {}
|
|
602
|
+
|
|
603
|
+
# Count running tasks by function
|
|
604
|
+
for execution in snapshot.running:
|
|
605
|
+
func_name = execution.function.__name__
|
|
606
|
+
if func_name not in stats:
|
|
607
|
+
stats[func_name] = {
|
|
608
|
+
"running": 0,
|
|
609
|
+
"queued": 0,
|
|
610
|
+
"total": 0,
|
|
611
|
+
"oldest_queued": None,
|
|
612
|
+
"latest_queued": None,
|
|
613
|
+
"oldest_started": None,
|
|
614
|
+
"latest_started": None,
|
|
615
|
+
}
|
|
616
|
+
stats[func_name]["running"] += 1
|
|
617
|
+
stats[func_name]["total"] += 1
|
|
618
|
+
|
|
619
|
+
# Track oldest/latest started times for running tasks
|
|
620
|
+
started = execution.started
|
|
621
|
+
if (
|
|
622
|
+
stats[func_name]["oldest_started"] is None
|
|
623
|
+
or started < stats[func_name]["oldest_started"]
|
|
624
|
+
):
|
|
625
|
+
stats[func_name]["oldest_started"] = started
|
|
626
|
+
if (
|
|
627
|
+
stats[func_name]["latest_started"] is None
|
|
628
|
+
or started > stats[func_name]["latest_started"]
|
|
629
|
+
):
|
|
630
|
+
stats[func_name]["latest_started"] = started
|
|
631
|
+
|
|
632
|
+
# Count future tasks by function
|
|
633
|
+
for execution in snapshot.future:
|
|
634
|
+
func_name = execution.function.__name__
|
|
635
|
+
if func_name not in stats:
|
|
636
|
+
stats[func_name] = {
|
|
637
|
+
"running": 0,
|
|
638
|
+
"queued": 0,
|
|
639
|
+
"total": 0,
|
|
640
|
+
"oldest_queued": None,
|
|
641
|
+
"latest_queued": None,
|
|
642
|
+
"oldest_started": None,
|
|
643
|
+
"latest_started": None,
|
|
644
|
+
}
|
|
645
|
+
stats[func_name]["queued"] += 1
|
|
646
|
+
stats[func_name]["total"] += 1
|
|
647
|
+
|
|
648
|
+
# Track oldest/latest queued times for future tasks
|
|
649
|
+
when = execution.when
|
|
650
|
+
if (
|
|
651
|
+
stats[func_name]["oldest_queued"] is None
|
|
652
|
+
or when < stats[func_name]["oldest_queued"]
|
|
653
|
+
):
|
|
654
|
+
stats[func_name]["oldest_queued"] = when
|
|
655
|
+
if (
|
|
656
|
+
stats[func_name]["latest_queued"] is None
|
|
657
|
+
or when > stats[func_name]["latest_queued"]
|
|
658
|
+
):
|
|
659
|
+
stats[func_name]["latest_queued"] = when
|
|
660
|
+
|
|
661
|
+
return stats
|
|
662
|
+
|
|
663
|
+
|
|
597
664
|
@app.command(help="Shows a snapshot of what's on the docket right now")
|
|
598
665
|
def snapshot(
|
|
599
666
|
tasks: Annotated[
|
|
@@ -623,6 +690,13 @@ def snapshot(
|
|
|
623
690
|
envvar="DOCKET_URL",
|
|
624
691
|
),
|
|
625
692
|
] = "redis://localhost:6379/0",
|
|
693
|
+
stats: Annotated[
|
|
694
|
+
bool,
|
|
695
|
+
typer.Option(
|
|
696
|
+
"--stats",
|
|
697
|
+
help="Show task count statistics by function name",
|
|
698
|
+
),
|
|
699
|
+
] = False,
|
|
626
700
|
) -> None:
|
|
627
701
|
async def run() -> DocketSnapshot:
|
|
628
702
|
async with Docket(name=docket_, url=url) as docket:
|
|
@@ -672,6 +746,44 @@ def snapshot(
|
|
|
672
746
|
|
|
673
747
|
console.print(table)
|
|
674
748
|
|
|
749
|
+
# Display task statistics if requested
|
|
750
|
+
if stats:
|
|
751
|
+
task_stats = get_task_stats(snapshot)
|
|
752
|
+
if task_stats:
|
|
753
|
+
console.print() # Add spacing between tables
|
|
754
|
+
stats_table = Table(title="Task Count Statistics by Function")
|
|
755
|
+
stats_table.add_column("Function", style="cyan")
|
|
756
|
+
stats_table.add_column("Total", style="bold magenta", justify="right")
|
|
757
|
+
stats_table.add_column("Running", style="green", justify="right")
|
|
758
|
+
stats_table.add_column("Queued", style="yellow", justify="right")
|
|
759
|
+
stats_table.add_column("Oldest Queued", style="dim yellow", justify="right")
|
|
760
|
+
stats_table.add_column("Latest Queued", style="dim yellow", justify="right")
|
|
761
|
+
|
|
762
|
+
# Sort by total count descending to highlight potential runaway tasks
|
|
763
|
+
for func_name in sorted(
|
|
764
|
+
task_stats.keys(), key=lambda x: task_stats[x]["total"], reverse=True
|
|
765
|
+
):
|
|
766
|
+
counts = task_stats[func_name]
|
|
767
|
+
|
|
768
|
+
# Format timestamp columns
|
|
769
|
+
oldest_queued = ""
|
|
770
|
+
latest_queued = ""
|
|
771
|
+
if counts["oldest_queued"] is not None:
|
|
772
|
+
oldest_queued = relative(counts["oldest_queued"])
|
|
773
|
+
if counts["latest_queued"] is not None:
|
|
774
|
+
latest_queued = relative(counts["latest_queued"])
|
|
775
|
+
|
|
776
|
+
stats_table.add_row(
|
|
777
|
+
func_name,
|
|
778
|
+
str(counts["total"]),
|
|
779
|
+
str(counts["running"]),
|
|
780
|
+
str(counts["queued"]),
|
|
781
|
+
oldest_queued,
|
|
782
|
+
latest_queued,
|
|
783
|
+
)
|
|
784
|
+
|
|
785
|
+
console.print(stats_table)
|
|
786
|
+
|
|
675
787
|
|
|
676
788
|
workers_app: typer.Typer = typer.Typer(
|
|
677
789
|
help="Look at the workers on a docket", no_args_is_help=True
|
docket/dependencies.py
CHANGED
|
@@ -39,9 +39,9 @@ class Dependency(abc.ABC):
|
|
|
39
39
|
|
|
40
40
|
async def __aexit__(
|
|
41
41
|
self,
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
42
|
+
_exc_type: type[BaseException] | None,
|
|
43
|
+
_exc_value: BaseException | None,
|
|
44
|
+
_traceback: TracebackType | None,
|
|
45
45
|
) -> bool: ... # pragma: no cover
|
|
46
46
|
|
|
47
47
|
|
|
@@ -505,6 +505,92 @@ def Depends(dependency: DependencyFunction[R]) -> R:
|
|
|
505
505
|
return cast(R, _Depends(dependency))
|
|
506
506
|
|
|
507
507
|
|
|
508
|
+
class ConcurrencyLimit(Dependency):
|
|
509
|
+
"""Configures concurrency limits for a task based on specific argument values.
|
|
510
|
+
|
|
511
|
+
This allows fine-grained control over task execution by limiting concurrent
|
|
512
|
+
tasks based on the value of specific arguments.
|
|
513
|
+
|
|
514
|
+
Example:
|
|
515
|
+
|
|
516
|
+
```python
|
|
517
|
+
async def process_customer(
|
|
518
|
+
customer_id: int,
|
|
519
|
+
concurrency: ConcurrencyLimit = ConcurrencyLimit("customer_id", max_concurrent=1)
|
|
520
|
+
) -> None:
|
|
521
|
+
# Only one task per customer_id will run at a time
|
|
522
|
+
...
|
|
523
|
+
|
|
524
|
+
async def backup_db(
|
|
525
|
+
db_name: str,
|
|
526
|
+
concurrency: ConcurrencyLimit = ConcurrencyLimit("db_name", max_concurrent=3)
|
|
527
|
+
) -> None:
|
|
528
|
+
# Only 3 backup tasks per database name will run at a time
|
|
529
|
+
...
|
|
530
|
+
```
|
|
531
|
+
"""
|
|
532
|
+
|
|
533
|
+
single: bool = True
|
|
534
|
+
|
|
535
|
+
def __init__(
|
|
536
|
+
self, argument_name: str, max_concurrent: int = 1, scope: str | None = None
|
|
537
|
+
) -> None:
|
|
538
|
+
"""
|
|
539
|
+
Args:
|
|
540
|
+
argument_name: The name of the task argument to use for concurrency grouping
|
|
541
|
+
max_concurrent: Maximum number of concurrent tasks per unique argument value
|
|
542
|
+
scope: Optional scope prefix for Redis keys (defaults to docket name)
|
|
543
|
+
"""
|
|
544
|
+
self.argument_name = argument_name
|
|
545
|
+
self.max_concurrent = max_concurrent
|
|
546
|
+
self.scope = scope
|
|
547
|
+
self._concurrency_key: str | None = None
|
|
548
|
+
self._initialized: bool = False
|
|
549
|
+
|
|
550
|
+
async def __aenter__(self) -> "ConcurrencyLimit":
|
|
551
|
+
execution = self.execution.get()
|
|
552
|
+
docket = self.docket.get()
|
|
553
|
+
|
|
554
|
+
# Get the argument value to group by
|
|
555
|
+
try:
|
|
556
|
+
argument_value = execution.get_argument(self.argument_name)
|
|
557
|
+
except KeyError:
|
|
558
|
+
# If argument not found, create a bypass limit that doesn't apply concurrency control
|
|
559
|
+
limit = ConcurrencyLimit(
|
|
560
|
+
self.argument_name, self.max_concurrent, self.scope
|
|
561
|
+
)
|
|
562
|
+
limit._concurrency_key = None # Special marker for bypassed concurrency
|
|
563
|
+
limit._initialized = True # Mark as initialized but bypassed
|
|
564
|
+
return limit
|
|
565
|
+
|
|
566
|
+
# Create a concurrency key for this specific argument value
|
|
567
|
+
scope = self.scope or docket.name
|
|
568
|
+
self._concurrency_key = (
|
|
569
|
+
f"{scope}:concurrency:{self.argument_name}:{argument_value}"
|
|
570
|
+
)
|
|
571
|
+
|
|
572
|
+
limit = ConcurrencyLimit(self.argument_name, self.max_concurrent, self.scope)
|
|
573
|
+
limit._concurrency_key = self._concurrency_key
|
|
574
|
+
limit._initialized = True # Mark as initialized
|
|
575
|
+
return limit
|
|
576
|
+
|
|
577
|
+
@property
|
|
578
|
+
def concurrency_key(self) -> str | None:
|
|
579
|
+
"""Redis key used for tracking concurrency for this specific argument value.
|
|
580
|
+
Returns None when concurrency control is bypassed due to missing arguments.
|
|
581
|
+
Raises RuntimeError if accessed before initialization."""
|
|
582
|
+
if not self._initialized:
|
|
583
|
+
raise RuntimeError(
|
|
584
|
+
"ConcurrencyLimit not initialized - use within task context"
|
|
585
|
+
)
|
|
586
|
+
return self._concurrency_key
|
|
587
|
+
|
|
588
|
+
@property
|
|
589
|
+
def is_bypassed(self) -> bool:
|
|
590
|
+
"""Returns True if concurrency control is bypassed due to missing arguments."""
|
|
591
|
+
return self._initialized and self._concurrency_key is None
|
|
592
|
+
|
|
593
|
+
|
|
508
594
|
D = TypeVar("D", bound=Dependency)
|
|
509
595
|
|
|
510
596
|
|
docket/docket.py
CHANGED
|
@@ -16,6 +16,7 @@ from typing import (
|
|
|
16
16
|
Mapping,
|
|
17
17
|
NoReturn,
|
|
18
18
|
ParamSpec,
|
|
19
|
+
Protocol,
|
|
19
20
|
Self,
|
|
20
21
|
Sequence,
|
|
21
22
|
TypedDict,
|
|
@@ -27,7 +28,6 @@ from typing import (
|
|
|
27
28
|
import redis.exceptions
|
|
28
29
|
from opentelemetry import propagate, trace
|
|
29
30
|
from redis.asyncio import ConnectionPool, Redis
|
|
30
|
-
from redis.asyncio.client import Pipeline
|
|
31
31
|
from uuid_extensions import uuid7
|
|
32
32
|
|
|
33
33
|
from .execution import (
|
|
@@ -55,6 +55,18 @@ logger: logging.Logger = logging.getLogger(__name__)
|
|
|
55
55
|
tracer: trace.Tracer = trace.get_tracer(__name__)
|
|
56
56
|
|
|
57
57
|
|
|
58
|
+
class _schedule_task(Protocol):
|
|
59
|
+
async def __call__(
|
|
60
|
+
self, keys: list[str], args: list[str | float | bytes]
|
|
61
|
+
) -> str: ... # pragma: no cover
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
class _cancel_task(Protocol):
|
|
65
|
+
async def __call__(
|
|
66
|
+
self, keys: list[str], args: list[str]
|
|
67
|
+
) -> str: ... # pragma: no cover
|
|
68
|
+
|
|
69
|
+
|
|
58
70
|
P = ParamSpec("P")
|
|
59
71
|
R = TypeVar("R")
|
|
60
72
|
|
|
@@ -131,6 +143,8 @@ class Docket:
|
|
|
131
143
|
|
|
132
144
|
_monitor_strikes_task: asyncio.Task[None]
|
|
133
145
|
_connection_pool: ConnectionPool
|
|
146
|
+
_schedule_task_script: _schedule_task | None
|
|
147
|
+
_cancel_task_script: _cancel_task | None
|
|
134
148
|
|
|
135
149
|
def __init__(
|
|
136
150
|
self,
|
|
@@ -156,6 +170,8 @@ class Docket:
|
|
|
156
170
|
self.url = url
|
|
157
171
|
self.heartbeat_interval = heartbeat_interval
|
|
158
172
|
self.missed_heartbeats = missed_heartbeats
|
|
173
|
+
self._schedule_task_script = None
|
|
174
|
+
self._cancel_task_script = None
|
|
159
175
|
|
|
160
176
|
@property
|
|
161
177
|
def worker_group_name(self) -> str:
|
|
@@ -300,9 +316,7 @@ class Docket:
|
|
|
300
316
|
execution = Execution(function, args, kwargs, when, key, attempt=1)
|
|
301
317
|
|
|
302
318
|
async with self.redis() as redis:
|
|
303
|
-
|
|
304
|
-
await self._schedule(redis, pipeline, execution, replace=False)
|
|
305
|
-
await pipeline.execute()
|
|
319
|
+
await self._schedule(redis, execution, replace=False)
|
|
306
320
|
|
|
307
321
|
TASKS_ADDED.add(1, {**self.labels(), **execution.general_labels()})
|
|
308
322
|
TASKS_SCHEDULED.add(1, {**self.labels(), **execution.general_labels()})
|
|
@@ -361,9 +375,7 @@ class Docket:
|
|
|
361
375
|
execution = Execution(function, args, kwargs, when, key, attempt=1)
|
|
362
376
|
|
|
363
377
|
async with self.redis() as redis:
|
|
364
|
-
|
|
365
|
-
await self._schedule(redis, pipeline, execution, replace=True)
|
|
366
|
-
await pipeline.execute()
|
|
378
|
+
await self._schedule(redis, execution, replace=True)
|
|
367
379
|
|
|
368
380
|
TASKS_REPLACED.add(1, {**self.labels(), **execution.general_labels()})
|
|
369
381
|
TASKS_CANCELLED.add(1, {**self.labels(), **execution.general_labels()})
|
|
@@ -383,9 +395,7 @@ class Docket:
|
|
|
383
395
|
},
|
|
384
396
|
):
|
|
385
397
|
async with self.redis() as redis:
|
|
386
|
-
|
|
387
|
-
await self._schedule(redis, pipeline, execution, replace=False)
|
|
388
|
-
await pipeline.execute()
|
|
398
|
+
await self._schedule(redis, execution, replace=False)
|
|
389
399
|
|
|
390
400
|
TASKS_SCHEDULED.add(1, {**self.labels(), **execution.general_labels()})
|
|
391
401
|
|
|
@@ -400,9 +410,7 @@ class Docket:
|
|
|
400
410
|
attributes={**self.labels(), "docket.key": key},
|
|
401
411
|
):
|
|
402
412
|
async with self.redis() as redis:
|
|
403
|
-
|
|
404
|
-
await self._cancel(pipeline, key)
|
|
405
|
-
await pipeline.execute()
|
|
413
|
+
await self._cancel(redis, key)
|
|
406
414
|
|
|
407
415
|
TASKS_CANCELLED.add(1, self.labels())
|
|
408
416
|
|
|
@@ -423,10 +431,17 @@ class Docket:
|
|
|
423
431
|
async def _schedule(
|
|
424
432
|
self,
|
|
425
433
|
redis: Redis,
|
|
426
|
-
pipeline: Pipeline,
|
|
427
434
|
execution: Execution,
|
|
428
435
|
replace: bool = False,
|
|
429
436
|
) -> None:
|
|
437
|
+
"""Schedule a task atomically.
|
|
438
|
+
|
|
439
|
+
Handles:
|
|
440
|
+
- Checking for task existence
|
|
441
|
+
- Cancelling existing tasks when replacing
|
|
442
|
+
- Adding tasks to stream (immediate) or queue (future)
|
|
443
|
+
- Tracking stream message IDs for later cancellation
|
|
444
|
+
"""
|
|
430
445
|
if self.strike_list.is_stricken(execution):
|
|
431
446
|
logger.warning(
|
|
432
447
|
"%r is stricken, skipping schedule of %r",
|
|
@@ -449,32 +464,133 @@ class Docket:
|
|
|
449
464
|
key = execution.key
|
|
450
465
|
when = execution.when
|
|
451
466
|
known_task_key = self.known_task_key(key)
|
|
467
|
+
is_immediate = when <= datetime.now(timezone.utc)
|
|
452
468
|
|
|
469
|
+
# Lock per task key to prevent race conditions between concurrent operations
|
|
453
470
|
async with redis.lock(f"{known_task_key}:lock", timeout=10):
|
|
454
|
-
if
|
|
455
|
-
|
|
456
|
-
|
|
457
|
-
|
|
458
|
-
|
|
459
|
-
|
|
460
|
-
"
|
|
461
|
-
|
|
462
|
-
|
|
463
|
-
|
|
464
|
-
|
|
471
|
+
if self._schedule_task_script is None:
|
|
472
|
+
self._schedule_task_script = cast(
|
|
473
|
+
_schedule_task,
|
|
474
|
+
redis.register_script(
|
|
475
|
+
# KEYS: stream_key, known_key, parked_key, queue_key
|
|
476
|
+
# ARGV: task_key, when_timestamp, is_immediate, replace, ...message_fields
|
|
477
|
+
"""
|
|
478
|
+
local stream_key = KEYS[1]
|
|
479
|
+
local known_key = KEYS[2]
|
|
480
|
+
local parked_key = KEYS[3]
|
|
481
|
+
local queue_key = KEYS[4]
|
|
482
|
+
|
|
483
|
+
local task_key = ARGV[1]
|
|
484
|
+
local when_timestamp = ARGV[2]
|
|
485
|
+
local is_immediate = ARGV[3] == '1'
|
|
486
|
+
local replace = ARGV[4] == '1'
|
|
487
|
+
|
|
488
|
+
-- Extract message fields from ARGV[5] onwards
|
|
489
|
+
local message = {}
|
|
490
|
+
for i = 5, #ARGV, 2 do
|
|
491
|
+
message[#message + 1] = ARGV[i] -- field name
|
|
492
|
+
message[#message + 1] = ARGV[i + 1] -- field value
|
|
493
|
+
end
|
|
494
|
+
|
|
495
|
+
-- Handle replacement: cancel existing task if needed
|
|
496
|
+
if replace then
|
|
497
|
+
local existing_message_id = redis.call('HGET', known_key, 'stream_message_id')
|
|
498
|
+
if existing_message_id then
|
|
499
|
+
redis.call('XDEL', stream_key, existing_message_id)
|
|
500
|
+
end
|
|
501
|
+
redis.call('DEL', known_key, parked_key)
|
|
502
|
+
redis.call('ZREM', queue_key, task_key)
|
|
503
|
+
else
|
|
504
|
+
-- Check if task already exists
|
|
505
|
+
if redis.call('EXISTS', known_key) == 1 then
|
|
506
|
+
return 'EXISTS'
|
|
507
|
+
end
|
|
508
|
+
end
|
|
509
|
+
|
|
510
|
+
if is_immediate then
|
|
511
|
+
-- Add to stream and store message ID for later cancellation
|
|
512
|
+
local message_id = redis.call('XADD', stream_key, '*', unpack(message))
|
|
513
|
+
redis.call('HSET', known_key, 'when', when_timestamp, 'stream_message_id', message_id)
|
|
514
|
+
return message_id
|
|
515
|
+
else
|
|
516
|
+
-- Add to queue with task data in parked hash
|
|
517
|
+
redis.call('HSET', known_key, 'when', when_timestamp)
|
|
518
|
+
redis.call('HSET', parked_key, unpack(message))
|
|
519
|
+
redis.call('ZADD', queue_key, when_timestamp, task_key)
|
|
520
|
+
return 'QUEUED'
|
|
521
|
+
end
|
|
522
|
+
"""
|
|
523
|
+
),
|
|
524
|
+
)
|
|
525
|
+
schedule_task = self._schedule_task_script
|
|
465
526
|
|
|
466
|
-
|
|
527
|
+
await schedule_task(
|
|
528
|
+
keys=[
|
|
529
|
+
self.stream_key,
|
|
530
|
+
known_task_key,
|
|
531
|
+
self.parked_task_key(key),
|
|
532
|
+
self.queue_key,
|
|
533
|
+
],
|
|
534
|
+
args=[
|
|
535
|
+
key,
|
|
536
|
+
str(when.timestamp()),
|
|
537
|
+
"1" if is_immediate else "0",
|
|
538
|
+
"1" if replace else "0",
|
|
539
|
+
*[
|
|
540
|
+
item
|
|
541
|
+
for field, value in message.items()
|
|
542
|
+
for item in (field, value)
|
|
543
|
+
],
|
|
544
|
+
],
|
|
545
|
+
)
|
|
467
546
|
|
|
468
|
-
|
|
469
|
-
|
|
470
|
-
else:
|
|
471
|
-
pipeline.hset(self.parked_task_key(key), mapping=message) # type: ignore[arg-type]
|
|
472
|
-
pipeline.zadd(self.queue_key, {key: when.timestamp()})
|
|
547
|
+
async def _cancel(self, redis: Redis, key: str) -> None:
|
|
548
|
+
"""Cancel a task atomically.
|
|
473
549
|
|
|
474
|
-
|
|
475
|
-
|
|
476
|
-
|
|
477
|
-
|
|
550
|
+
Handles cancellation regardless of task location:
|
|
551
|
+
- From the stream (using stored message ID)
|
|
552
|
+
- From the queue (scheduled tasks)
|
|
553
|
+
- Cleans up all associated metadata keys
|
|
554
|
+
"""
|
|
555
|
+
if self._cancel_task_script is None:
|
|
556
|
+
self._cancel_task_script = cast(
|
|
557
|
+
_cancel_task,
|
|
558
|
+
redis.register_script(
|
|
559
|
+
# KEYS: stream_key, known_key, parked_key, queue_key
|
|
560
|
+
# ARGV: task_key
|
|
561
|
+
"""
|
|
562
|
+
local stream_key = KEYS[1]
|
|
563
|
+
local known_key = KEYS[2]
|
|
564
|
+
local parked_key = KEYS[3]
|
|
565
|
+
local queue_key = KEYS[4]
|
|
566
|
+
local task_key = ARGV[1]
|
|
567
|
+
|
|
568
|
+
-- Delete from stream if message ID exists
|
|
569
|
+
local message_id = redis.call('HGET', known_key, 'stream_message_id')
|
|
570
|
+
if message_id then
|
|
571
|
+
redis.call('XDEL', stream_key, message_id)
|
|
572
|
+
end
|
|
573
|
+
|
|
574
|
+
-- Clean up all task-related keys
|
|
575
|
+
redis.call('DEL', known_key, parked_key)
|
|
576
|
+
redis.call('ZREM', queue_key, task_key)
|
|
577
|
+
|
|
578
|
+
return 'OK'
|
|
579
|
+
"""
|
|
580
|
+
),
|
|
581
|
+
)
|
|
582
|
+
cancel_task = self._cancel_task_script
|
|
583
|
+
|
|
584
|
+
# Execute the cancellation script
|
|
585
|
+
await cancel_task(
|
|
586
|
+
keys=[
|
|
587
|
+
self.stream_key,
|
|
588
|
+
self.known_task_key(key),
|
|
589
|
+
self.parked_task_key(key),
|
|
590
|
+
self.queue_key,
|
|
591
|
+
],
|
|
592
|
+
args=[key],
|
|
593
|
+
)
|
|
478
594
|
|
|
479
595
|
@property
|
|
480
596
|
def strike_key(self) -> str:
|
docket/worker.py
CHANGED
|
@@ -20,6 +20,7 @@ from redis.asyncio import Redis
|
|
|
20
20
|
from redis.exceptions import ConnectionError, LockError
|
|
21
21
|
|
|
22
22
|
from .dependencies import (
|
|
23
|
+
ConcurrencyLimit,
|
|
23
24
|
Dependency,
|
|
24
25
|
FailedDependency,
|
|
25
26
|
Perpetual,
|
|
@@ -248,6 +249,7 @@ class Worker:
|
|
|
248
249
|
)
|
|
249
250
|
|
|
250
251
|
active_tasks: dict[asyncio.Task[None], RedisMessageID] = {}
|
|
252
|
+
task_executions: dict[asyncio.Task[None], Execution] = {}
|
|
251
253
|
available_slots = self.concurrency
|
|
252
254
|
|
|
253
255
|
log_context = self._log_context()
|
|
@@ -298,6 +300,7 @@ class Worker:
|
|
|
298
300
|
|
|
299
301
|
task = asyncio.create_task(self._execute(execution), name=execution.key)
|
|
300
302
|
active_tasks[task] = message_id
|
|
303
|
+
task_executions[task] = execution
|
|
301
304
|
|
|
302
305
|
nonlocal available_slots
|
|
303
306
|
available_slots -= 1
|
|
@@ -308,6 +311,7 @@ class Worker:
|
|
|
308
311
|
completed_tasks = {task for task in active_tasks if task.done()}
|
|
309
312
|
for task in completed_tasks:
|
|
310
313
|
message_id = active_tasks.pop(task)
|
|
314
|
+
task_executions.pop(task)
|
|
311
315
|
await task
|
|
312
316
|
await ack_message(redis, message_id)
|
|
313
317
|
|
|
@@ -343,7 +347,9 @@ class Worker:
|
|
|
343
347
|
if not message: # pragma: no cover
|
|
344
348
|
continue
|
|
345
349
|
|
|
346
|
-
|
|
350
|
+
task_started = start_task(message_id, message)
|
|
351
|
+
if not task_started:
|
|
352
|
+
# Other errors - delete and ack
|
|
347
353
|
await self._delete_known_task(redis, message)
|
|
348
354
|
await ack_message(redis, message_id)
|
|
349
355
|
|
|
@@ -400,7 +406,7 @@ class Worker:
|
|
|
400
406
|
task[task_data[j]] = task_data[j+1]
|
|
401
407
|
end
|
|
402
408
|
|
|
403
|
-
redis.call('XADD', KEYS[2], '*',
|
|
409
|
+
local message_id = redis.call('XADD', KEYS[2], '*',
|
|
404
410
|
'key', task['key'],
|
|
405
411
|
'when', task['when'],
|
|
406
412
|
'function', task['function'],
|
|
@@ -408,6 +414,9 @@ class Worker:
|
|
|
408
414
|
'kwargs', task['kwargs'],
|
|
409
415
|
'attempt', task['attempt']
|
|
410
416
|
)
|
|
417
|
+
-- Store the message ID in the known task key
|
|
418
|
+
local known_key = ARGV[2] .. ":known:" .. key
|
|
419
|
+
redis.call('HSET', known_key, 'stream_message_id', message_id)
|
|
411
420
|
redis.call('DEL', hash_key)
|
|
412
421
|
due_work = due_work + 1
|
|
413
422
|
end
|
|
@@ -534,6 +543,32 @@ class Worker:
|
|
|
534
543
|
) as span:
|
|
535
544
|
try:
|
|
536
545
|
async with resolved_dependencies(self, execution) as dependencies:
|
|
546
|
+
# Check concurrency limits after dependency resolution
|
|
547
|
+
concurrency_limit = get_single_dependency_of_type(
|
|
548
|
+
dependencies, ConcurrencyLimit
|
|
549
|
+
)
|
|
550
|
+
if concurrency_limit and not concurrency_limit.is_bypassed:
|
|
551
|
+
async with self.docket.redis() as redis:
|
|
552
|
+
# Check if we can acquire a concurrency slot
|
|
553
|
+
if not await self._can_start_task(redis, execution):
|
|
554
|
+
# Task cannot start due to concurrency limits - reschedule
|
|
555
|
+
logger.debug(
|
|
556
|
+
"🔒 Task %s blocked by concurrency limit, rescheduling",
|
|
557
|
+
execution.key,
|
|
558
|
+
extra=log_context,
|
|
559
|
+
)
|
|
560
|
+
# Reschedule for a few milliseconds in the future
|
|
561
|
+
when = datetime.now(timezone.utc) + timedelta(
|
|
562
|
+
milliseconds=50
|
|
563
|
+
)
|
|
564
|
+
await self.docket.add(execution.function, when=when)(
|
|
565
|
+
*execution.args, **execution.kwargs
|
|
566
|
+
)
|
|
567
|
+
return
|
|
568
|
+
else:
|
|
569
|
+
# Successfully acquired slot
|
|
570
|
+
pass
|
|
571
|
+
|
|
537
572
|
# Preemptively reschedule the perpetual task for the future, or clear
|
|
538
573
|
# the known task key for this task
|
|
539
574
|
rescheduled = await self._perpetuate_if_requested(
|
|
@@ -560,17 +595,34 @@ class Worker:
|
|
|
560
595
|
],
|
|
561
596
|
)
|
|
562
597
|
|
|
563
|
-
|
|
564
|
-
|
|
565
|
-
|
|
566
|
-
|
|
598
|
+
# Apply timeout logic - either user's timeout or redelivery timeout
|
|
599
|
+
user_timeout = get_single_dependency_of_type(dependencies, Timeout)
|
|
600
|
+
if user_timeout:
|
|
601
|
+
# If user timeout is longer than redelivery timeout, limit it
|
|
602
|
+
if user_timeout.base > self.redelivery_timeout:
|
|
603
|
+
# Create a new timeout limited by redelivery timeout
|
|
604
|
+
# Remove the user timeout from dependencies to avoid conflicts
|
|
605
|
+
limited_dependencies = {
|
|
606
|
+
k: v
|
|
607
|
+
for k, v in dependencies.items()
|
|
608
|
+
if not isinstance(v, Timeout)
|
|
609
|
+
}
|
|
610
|
+
limited_timeout = Timeout(self.redelivery_timeout)
|
|
611
|
+
limited_timeout.start()
|
|
612
|
+
await self._run_function_with_timeout(
|
|
613
|
+
execution, limited_dependencies, limited_timeout
|
|
614
|
+
)
|
|
615
|
+
else:
|
|
616
|
+
# User timeout is within redelivery timeout, use as-is
|
|
617
|
+
await self._run_function_with_timeout(
|
|
618
|
+
execution, dependencies, user_timeout
|
|
619
|
+
)
|
|
567
620
|
else:
|
|
568
|
-
|
|
569
|
-
|
|
570
|
-
|
|
571
|
-
|
|
572
|
-
|
|
573
|
-
},
|
|
621
|
+
# No user timeout - apply redelivery timeout as hard limit
|
|
622
|
+
redelivery_timeout = Timeout(self.redelivery_timeout)
|
|
623
|
+
redelivery_timeout.start()
|
|
624
|
+
await self._run_function_with_timeout(
|
|
625
|
+
execution, dependencies, redelivery_timeout
|
|
574
626
|
)
|
|
575
627
|
|
|
576
628
|
duration = log_context["duration"] = time.time() - start
|
|
@@ -604,6 +656,15 @@ class Worker:
|
|
|
604
656
|
"%s [%s] %s", arrow, ms(duration), call, extra=log_context
|
|
605
657
|
)
|
|
606
658
|
finally:
|
|
659
|
+
# Release concurrency slot if we acquired one
|
|
660
|
+
if dependencies:
|
|
661
|
+
concurrency_limit = get_single_dependency_of_type(
|
|
662
|
+
dependencies, ConcurrencyLimit
|
|
663
|
+
)
|
|
664
|
+
if concurrency_limit and not concurrency_limit.is_bypassed:
|
|
665
|
+
async with self.docket.redis() as redis:
|
|
666
|
+
await self._release_concurrency_slot(redis, execution)
|
|
667
|
+
|
|
607
668
|
TASKS_RUNNING.add(-1, counter_labels)
|
|
608
669
|
TASKS_COMPLETED.add(1, counter_labels)
|
|
609
670
|
TASK_DURATION.record(duration, counter_labels)
|
|
@@ -616,7 +677,13 @@ class Worker:
|
|
|
616
677
|
) -> None:
|
|
617
678
|
task_coro = cast(
|
|
618
679
|
Coroutine[None, None, None],
|
|
619
|
-
execution.function(
|
|
680
|
+
execution.function(
|
|
681
|
+
*execution.args,
|
|
682
|
+
**{
|
|
683
|
+
**execution.kwargs,
|
|
684
|
+
**dependencies,
|
|
685
|
+
},
|
|
686
|
+
),
|
|
620
687
|
)
|
|
621
688
|
task = asyncio.create_task(task_coro)
|
|
622
689
|
try:
|
|
@@ -762,6 +829,95 @@ class Worker:
|
|
|
762
829
|
extra=self._log_context(),
|
|
763
830
|
)
|
|
764
831
|
|
|
832
|
+
async def _can_start_task(self, redis: Redis, execution: Execution) -> bool:
|
|
833
|
+
"""Check if a task can start based on concurrency limits."""
|
|
834
|
+
# Check if task has a concurrency limit dependency
|
|
835
|
+
concurrency_limit = get_single_dependency_parameter_of_type(
|
|
836
|
+
execution.function, ConcurrencyLimit
|
|
837
|
+
)
|
|
838
|
+
|
|
839
|
+
if not concurrency_limit:
|
|
840
|
+
return True # No concurrency limit, can always start
|
|
841
|
+
|
|
842
|
+
# Get the concurrency key for this task
|
|
843
|
+
try:
|
|
844
|
+
argument_value = execution.get_argument(concurrency_limit.argument_name)
|
|
845
|
+
except KeyError:
|
|
846
|
+
# If argument not found, let the task fail naturally in execution
|
|
847
|
+
return True
|
|
848
|
+
|
|
849
|
+
scope = concurrency_limit.scope or self.docket.name
|
|
850
|
+
concurrency_key = (
|
|
851
|
+
f"{scope}:concurrency:{concurrency_limit.argument_name}:{argument_value}"
|
|
852
|
+
)
|
|
853
|
+
|
|
854
|
+
# Use Redis sorted set with timestamps to track concurrency and handle expiration
|
|
855
|
+
lua_script = """
|
|
856
|
+
local key = KEYS[1]
|
|
857
|
+
local max_concurrent = tonumber(ARGV[1])
|
|
858
|
+
local worker_id = ARGV[2]
|
|
859
|
+
local task_key = ARGV[3]
|
|
860
|
+
local current_time = tonumber(ARGV[4])
|
|
861
|
+
local expiration_time = tonumber(ARGV[5])
|
|
862
|
+
|
|
863
|
+
-- Remove expired entries
|
|
864
|
+
local expired_cutoff = current_time - expiration_time
|
|
865
|
+
redis.call('ZREMRANGEBYSCORE', key, 0, expired_cutoff)
|
|
866
|
+
|
|
867
|
+
-- Get current count
|
|
868
|
+
local current = redis.call('ZCARD', key)
|
|
869
|
+
|
|
870
|
+
if current < max_concurrent then
|
|
871
|
+
-- Add this worker's task to the sorted set with current timestamp
|
|
872
|
+
redis.call('ZADD', key, current_time, worker_id .. ':' .. task_key)
|
|
873
|
+
return 1
|
|
874
|
+
else
|
|
875
|
+
return 0
|
|
876
|
+
end
|
|
877
|
+
"""
|
|
878
|
+
|
|
879
|
+
current_time = datetime.now(timezone.utc).timestamp()
|
|
880
|
+
expiration_seconds = self.redelivery_timeout.total_seconds()
|
|
881
|
+
|
|
882
|
+
result = await redis.eval( # type: ignore
|
|
883
|
+
lua_script,
|
|
884
|
+
1,
|
|
885
|
+
concurrency_key,
|
|
886
|
+
str(concurrency_limit.max_concurrent),
|
|
887
|
+
self.name,
|
|
888
|
+
execution.key,
|
|
889
|
+
current_time,
|
|
890
|
+
expiration_seconds,
|
|
891
|
+
)
|
|
892
|
+
|
|
893
|
+
return bool(result)
|
|
894
|
+
|
|
895
|
+
async def _release_concurrency_slot(
|
|
896
|
+
self, redis: Redis, execution: Execution
|
|
897
|
+
) -> None:
|
|
898
|
+
"""Release a concurrency slot when task completes."""
|
|
899
|
+
# Check if task has a concurrency limit dependency
|
|
900
|
+
concurrency_limit = get_single_dependency_parameter_of_type(
|
|
901
|
+
execution.function, ConcurrencyLimit
|
|
902
|
+
)
|
|
903
|
+
|
|
904
|
+
if not concurrency_limit:
|
|
905
|
+
return # No concurrency limit to release
|
|
906
|
+
|
|
907
|
+
# Get the concurrency key for this task
|
|
908
|
+
try:
|
|
909
|
+
argument_value = execution.get_argument(concurrency_limit.argument_name)
|
|
910
|
+
except KeyError:
|
|
911
|
+
return # If argument not found, nothing to release
|
|
912
|
+
|
|
913
|
+
scope = concurrency_limit.scope or self.docket.name
|
|
914
|
+
concurrency_key = (
|
|
915
|
+
f"{scope}:concurrency:{concurrency_limit.argument_name}:{argument_value}"
|
|
916
|
+
)
|
|
917
|
+
|
|
918
|
+
# Remove this worker's task from the sorted set
|
|
919
|
+
await redis.zrem(concurrency_key, f"{self.name}:{execution.key}") # type: ignore
|
|
920
|
+
|
|
765
921
|
|
|
766
922
|
def ms(seconds: float) -> str:
|
|
767
923
|
if seconds < 100:
|
{pydocket-0.7.1.dist-info → pydocket-0.9.0.dist-info}/METADATA
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: pydocket
|
|
3
|
-
Version: 0.7.1
|
|
3
|
+
Version: 0.9.0
|
|
4
4
|
Summary: A distributed background task system for Python functions
|
|
5
5
|
Project-URL: Homepage, https://github.com/chrisguidry/docket
|
|
6
6
|
Project-URL: Bug Tracker, https://github.com/chrisguidry/docket/issues
|
|
@@ -93,6 +93,8 @@ reference](https://chrisguidry.github.io/docket/api-reference/).
|
|
|
93
93
|
|
|
94
94
|
🧩 Fully type-complete and type-aware for your background task functions
|
|
95
95
|
|
|
96
|
+
💉 Dependency injection like FastAPI, Typer, and FastMCP for reusable resources
|
|
97
|
+
|
|
96
98
|
## Installing `docket`
|
|
97
99
|
|
|
98
100
|
Docket is [available on PyPI](https://pypi.org/project/pydocket/) under the package name
|
|
pydocket-0.9.0.dist-info/RECORD
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
docket/__init__.py,sha256=onwZzh73tESWoFBukbcW-7gjxoXb-yI7dutRD7tPN6g,915
|
|
2
|
+
docket/__main__.py,sha256=wcCrL4PjG51r5wVKqJhcoJPTLfHW0wNbD31DrUN0MWI,28
|
|
3
|
+
docket/annotations.py,sha256=wttix9UOeMFMAWXAIJUfUw5GjESJZsACb4YXJCozP7Q,2348
|
|
4
|
+
docket/cli.py,sha256=rTfri2--u4Q5PlXyh7Ub_F5uh3-TtZOWLUp9WY_TvAE,25750
|
|
5
|
+
docket/dependencies.py,sha256=BC0bnt10cr9_S1p5JAP_bnC9RwZkTr9ulPBrxC7eZnA,20247
|
|
6
|
+
docket/docket.py,sha256=0nQCHDDHy7trv2a0eYygGgIKiA7fWq5GcOXye3_CPWM,30847
|
|
7
|
+
docket/execution.py,sha256=r_2RGC1qhtAcBUg7E6wewLEgftrf3hIxNbH0HnYPbek,14961
|
|
8
|
+
docket/instrumentation.py,sha256=ogvzrfKbWsdPGfdg4hByH3_r5d3b5AwwQkSrmXw0hRg,5492
|
|
9
|
+
docket/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
10
|
+
docket/tasks.py,sha256=RIlSM2omh-YDwVnCz6M5MtmK8T_m_s1w2OlRRxDUs6A,1437
|
|
11
|
+
docket/worker.py,sha256=jqVYqtQyxbk-BIy3shY8haX-amVT9Np97VhJuaQTfpM,35174
|
|
12
|
+
pydocket-0.9.0.dist-info/METADATA,sha256=kymp9PKG7UwMj0i0qGSSCHKu-g-tS__qydr6mYuMLtg,5418
|
|
13
|
+
pydocket-0.9.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
14
|
+
pydocket-0.9.0.dist-info/entry_points.txt,sha256=4WOk1nUlBsUT5O3RyMci2ImuC5XFswuopElYcLHtD5k,47
|
|
15
|
+
pydocket-0.9.0.dist-info/licenses/LICENSE,sha256=YuVWU_ZXO0K_k2FG8xWKe5RGxV24AhJKTvQmKfqXuyk,1087
|
|
16
|
+
pydocket-0.9.0.dist-info/RECORD,,
|
pydocket-0.7.1.dist-info/RECORD
DELETED
|
@@ -1,16 +0,0 @@
|
|
|
1
|
-
docket/__init__.py,sha256=sY1T_NVsXQNOmOhOnfYmZ95dcE_52Ov6DSIVIMZp-1w,869
|
|
2
|
-
docket/__main__.py,sha256=wcCrL4PjG51r5wVKqJhcoJPTLfHW0wNbD31DrUN0MWI,28
|
|
3
|
-
docket/annotations.py,sha256=wttix9UOeMFMAWXAIJUfUw5GjESJZsACb4YXJCozP7Q,2348
|
|
4
|
-
docket/cli.py,sha256=XG_mbjcqNRO0F0hh6l3AwH9bIZv9xJofZaeaAj9nChc,21608
|
|
5
|
-
docket/dependencies.py,sha256=GBwyEY198JFrfm7z5GkLbd84hv7sJktKBMJXv4veWig,17007
|
|
6
|
-
docket/docket.py,sha256=Cw7QB1d0eDwSgwn0Rj26WjFsXSe7MJtfsUBBHGalL7A,26262
|
|
7
|
-
docket/execution.py,sha256=r_2RGC1qhtAcBUg7E6wewLEgftrf3hIxNbH0HnYPbek,14961
|
|
8
|
-
docket/instrumentation.py,sha256=ogvzrfKbWsdPGfdg4hByH3_r5d3b5AwwQkSrmXw0hRg,5492
|
|
9
|
-
docket/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
10
|
-
docket/tasks.py,sha256=RIlSM2omh-YDwVnCz6M5MtmK8T_m_s1w2OlRRxDUs6A,1437
|
|
11
|
-
docket/worker.py,sha256=CY5Z9p8FZw-6WUwp7Ws4A0V7IFTmonSnBmYP-Cp8Fdw,28079
|
|
12
|
-
pydocket-0.7.1.dist-info/METADATA,sha256=00KHm5Er2R6dmjHLTYBUF13kKAeCRPHmDTdAcv5oRcQ,5335
|
|
13
|
-
pydocket-0.7.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
14
|
-
pydocket-0.7.1.dist-info/entry_points.txt,sha256=4WOk1nUlBsUT5O3RyMci2ImuC5XFswuopElYcLHtD5k,47
|
|
15
|
-
pydocket-0.7.1.dist-info/licenses/LICENSE,sha256=YuVWU_ZXO0K_k2FG8xWKe5RGxV24AhJKTvQmKfqXuyk,1087
|
|
16
|
-
pydocket-0.7.1.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|