prefect-client 3.0.0rc17__py3-none-any.whl → 3.0.0rc19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- prefect/_internal/concurrency/services.py +14 -0
- prefect/_internal/schemas/bases.py +1 -0
- prefect/blocks/core.py +36 -29
- prefect/client/orchestration.py +97 -2
- prefect/client/schemas/actions.py +14 -4
- prefect/client/schemas/filters.py +20 -0
- prefect/client/schemas/objects.py +3 -0
- prefect/client/schemas/responses.py +3 -0
- prefect/client/schemas/sorting.py +2 -0
- prefect/concurrency/v1/__init__.py +0 -0
- prefect/concurrency/v1/asyncio.py +143 -0
- prefect/concurrency/v1/context.py +27 -0
- prefect/concurrency/v1/events.py +61 -0
- prefect/concurrency/v1/services.py +116 -0
- prefect/concurrency/v1/sync.py +92 -0
- prefect/context.py +2 -2
- prefect/deployments/flow_runs.py +0 -7
- prefect/deployments/runner.py +11 -0
- prefect/events/clients.py +41 -0
- prefect/events/related.py +72 -73
- prefect/events/utilities.py +2 -0
- prefect/events/worker.py +12 -3
- prefect/flow_engine.py +2 -0
- prefect/flows.py +7 -0
- prefect/records/__init__.py +1 -1
- prefect/records/base.py +223 -0
- prefect/records/filesystem.py +207 -0
- prefect/records/memory.py +178 -0
- prefect/records/result_store.py +19 -14
- prefect/results.py +11 -0
- prefect/runner/runner.py +7 -4
- prefect/settings.py +0 -8
- prefect/task_engine.py +98 -209
- prefect/task_worker.py +7 -39
- prefect/tasks.py +2 -9
- prefect/transactions.py +67 -19
- prefect/utilities/asyncutils.py +3 -3
- prefect/utilities/callables.py +1 -3
- prefect/utilities/engine.py +7 -6
- {prefect_client-3.0.0rc17.dist-info → prefect_client-3.0.0rc19.dist-info}/METADATA +3 -4
- {prefect_client-3.0.0rc17.dist-info → prefect_client-3.0.0rc19.dist-info}/RECORD +44 -36
- prefect/records/store.py +0 -9
- {prefect_client-3.0.0rc17.dist-info → prefect_client-3.0.0rc19.dist-info}/LICENSE +0 -0
- {prefect_client-3.0.0rc17.dist-info → prefect_client-3.0.0rc19.dist-info}/WHEEL +0 -0
- {prefect_client-3.0.0rc17.dist-info → prefect_client-3.0.0rc19.dist-info}/top_level.txt +0 -0
prefect/_internal/concurrency/services.py
CHANGED
```diff
@@ -39,6 +39,7 @@ class QueueService(abc.ABC, Generic[T]):
             daemon=True,
             name=f"{type(self).__name__}Thread",
         )
+        self._logger = logging.getLogger(f"{type(self).__name__}")
 
     def start(self):
         logger.debug("Starting service %r", self)
@@ -144,11 +145,24 @@ class QueueService(abc.ABC, Generic[T]):
         self._done_event.set()
 
     async def _main_loop(self):
+        last_log_time = 0
+        log_interval = 4  # log every 4 seconds
+
         while True:
             item: T = await self._queue_get_thread.submit(
                 create_call(self._queue.get)
             ).aresult()
 
+            if self._stopped:
+                current_time = asyncio.get_event_loop().time()
+                queue_size = self._queue.qsize()
+
+                if current_time - last_log_time >= log_interval and queue_size > 0:
+                    self._logger.warning(
+                        f"Still processing items: {queue_size} items remaining..."
+                    )
+                    last_log_time = current_time
+
             if item is None:
                 logger.debug("Exiting service %r", self)
                 self._queue.task_done()
```
prefect/blocks/core.py
CHANGED
```diff
@@ -24,9 +24,7 @@ from typing import (
 )
 from uuid import UUID, uuid4
 
-from griffe.dataclasses import Docstring
-from griffe.docstrings.dataclasses import DocstringSection, DocstringSectionKind
-from griffe.docstrings.parsers import Parser, parse
+from griffe import Docstring, DocstringSection, DocstringSectionKind, Parser, parse
 from packaging.version import InvalidVersion, Version
 from pydantic import (
     BaseModel,
@@ -130,7 +128,9 @@ def _is_subclass(cls, parent_cls) -> bool:
     Checks if a given class is a subclass of another class. Unlike issubclass,
     this will not throw an exception if cls is an instance instead of a type.
     """
-    return inspect.isclass(cls) and issubclass(cls, parent_cls)
+    # For python<=3.11 inspect.isclass() will return True for parametrized types (e.g. list[str])
+    # so we need to check for get_origin() to avoid TypeError for issubclass.
+    return inspect.isclass(cls) and not get_origin(cls) and issubclass(cls, parent_cls)
 
 
 def _collect_secret_fields(
@@ -138,12 +138,12 @@ def _collect_secret_fields(
 ) -> None:
     """
     Recursively collects all secret fields from a given type and adds them to the
-    secrets list, supporting nested Union / List fields. Also, note, this function
-    mutates the input secrets list, thus does not return anything.
+    secrets list, supporting nested Union / Dict / Tuple / List / BaseModel fields.
+    Also, note, this function mutates the input secrets list, thus does not return anything.
     """
-    if get_origin(type_) is Union:
-        for union_type in get_args(type_):
-            _collect_secret_fields(name, union_type, secrets)
+    if get_origin(type_) in (Union, dict, list, tuple):
+        for nested_type in get_args(type_):
+            _collect_secret_fields(name, nested_type, secrets)
         return
     elif _is_subclass(type_, BaseModel):
         for field_name, field in type_.model_fields.items():
@@ -232,21 +232,25 @@ def schema_extra(schema: Dict[str, Any], model: Type["Block"]):
 
     # create block schema references
     refs = schema["block_schema_references"] = {}
+
+    def collect_block_schema_references(field_name: str, annotation: type) -> None:
+        """Walk through the annotation and collect block schemas for any nested blocks."""
+        if Block.is_block_class(annotation):
+            if isinstance(refs.get(field_name), list):
+                refs[field_name].append(annotation._to_block_schema_reference_dict())
+            elif isinstance(refs.get(field_name), dict):
+                refs[field_name] = [
+                    refs[field_name],
+                    annotation._to_block_schema_reference_dict(),
+                ]
+            else:
+                refs[field_name] = annotation._to_block_schema_reference_dict()
+        if get_origin(annotation) in (Union, list, tuple, dict):
+            for type_ in get_args(annotation):
+                collect_block_schema_references(field_name, type_)
+
     for name, field in model.model_fields.items():
-        if Block.is_block_class(field.annotation):
-            refs[name] = field.annotation._to_block_schema_reference_dict()
-        if get_origin(field.annotation) in [Union, list]:
-            for type_ in get_args(field.annotation):
-                if Block.is_block_class(type_):
-                    if isinstance(refs.get(name), list):
-                        refs[name].append(type_._to_block_schema_reference_dict())
-                    elif isinstance(refs.get(name), dict):
-                        refs[name] = [
-                            refs[name],
-                            type_._to_block_schema_reference_dict(),
-                        ]
-                    else:
-                        refs[name] = type_._to_block_schema_reference_dict()
+        collect_block_schema_references(name, field.annotation)
 
 
 @register_base_type
@@ -1067,13 +1071,16 @@ class Block(BaseModel, ABC):
                 "subclass and not on a Block interface class directly."
             )
 
+        async def register_blocks_in_annotation(annotation: type) -> None:
+            """Walk through the annotation and register any nested blocks."""
+            if Block.is_block_class(annotation):
+                await annotation.register_type_and_schema(client=client)
+            elif get_origin(annotation) in (Union, tuple, list, dict):
+                for inner_annotation in get_args(annotation):
+                    await register_blocks_in_annotation(inner_annotation)
+
         for field in cls.model_fields.values():
-            if Block.is_block_class(field.annotation):
-                await field.annotation.register_type_and_schema(client=client)
-            if get_origin(field.annotation) is Union:
-                for annotation in get_args(field.annotation):
-                    if Block.is_block_class(annotation):
-                        await annotation.register_type_and_schema(client=client)
+            await register_blocks_in_annotation(field.annotation)
 
         try:
             block_type = await client.read_block_type_by_slug(
```
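Net effect of the `blocks/core.py` changes: nested block annotations are now walked recursively, both for schema references and for registration, rather than only top-level `Union` members. A minimal sketch (block names hypothetical) of a layout that rc19 now handles in full:

```python
from typing import Dict, List, Optional

from prefect.blocks.core import Block


class Credentials(Block):
    token: str


class AppConfig(Block):
    # The nested `Credentials` annotations below are now discovered inside
    # Optional/Union, list, and dict annotations, so
    # `await AppConfig.register_type_and_schema()` also registers
    # `Credentials`, and `block_schema_references` picks up each occurrence.
    primary: Optional[Credentials] = None
    fallbacks: List[Credentials] = []
    per_region: Dict[str, Credentials] = {}
```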
prefect/client/orchestration.py
CHANGED
```diff
@@ -939,6 +939,57 @@ class PrefectClient:
         else:
             raise
 
+    async def increment_v1_concurrency_slots(
+        self,
+        names: List[str],
+        task_run_id: UUID,
+    ) -> httpx.Response:
+        """
+        Increment concurrency limit slots for the specified limits.
+
+        Args:
+            names (List[str]): A list of limit names for which to increment limits.
+            task_run_id (UUID): The task run ID incrementing the limits.
+        """
+        data = {
+            "names": names,
+            "task_run_id": str(task_run_id),
+        }
+
+        return await self._client.post(
+            "/concurrency_limits/increment",
+            json=data,
+        )
+
+    async def decrement_v1_concurrency_slots(
+        self,
+        names: List[str],
+        task_run_id: UUID,
+        occupancy_seconds: float,
+    ) -> httpx.Response:
+        """
+        Decrement concurrency limit slots for the specified limits.
+
+        Args:
+            names (List[str]): A list of limit names to decrement.
+            task_run_id (UUID): The task run ID that incremented the limits.
+            occupancy_seconds (float): The duration in seconds that the limits
+                were held.
+
+        Returns:
+            httpx.Response: The HTTP response from the server.
+        """
+        data = {
+            "names": names,
+            "task_run_id": str(task_run_id),
+            "occupancy_seconds": occupancy_seconds,
+        }
+
+        return await self._client.post(
+            "/concurrency_limits/decrement",
+            json=data,
+        )
+
     async def create_work_queue(
         self,
         name: str,
```
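For orientation, a hedged sketch of driving the new v1 endpoints directly (IDs fabricated; real code should prefer the `prefect.concurrency.v1` context managers, which manage acquisition through `ConcurrencySlotAcquisitionService`):

```python
from uuid import uuid4

from prefect.client.orchestration import get_client


async def main() -> None:
    task_run_id = uuid4()  # normally the real task run's ID
    async with get_client() as client:
        # Occupy slots on the "database" v1 limit for this task run...
        await client.increment_v1_concurrency_slots(
            names=["database"], task_run_id=task_run_id
        )
        try:
            ...  # do the rate-limited work
        finally:
            # ...then release them, reporting how long they were held.
            await client.decrement_v1_concurrency_slots(
                names=["database"], task_run_id=task_run_id, occupancy_seconds=1.5
            )
```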
```diff
@@ -1599,6 +1650,7 @@ class PrefectClient:
         name: str,
         version: Optional[str] = None,
         schedules: Optional[List[DeploymentScheduleCreate]] = None,
+        concurrency_limit: Optional[int] = None,
         parameters: Optional[Dict[str, Any]] = None,
         description: Optional[str] = None,
         work_queue_name: Optional[str] = None,
@@ -1656,6 +1708,7 @@ class PrefectClient:
             parameter_openapi_schema=parameter_openapi_schema,
             paused=paused,
             schedules=schedules or [],
+            concurrency_limit=concurrency_limit,
             pull_steps=pull_steps,
             enforce_parameter_schema=enforce_parameter_schema,
         )
```
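A short sketch of the new `create_deployment` parameter in use (flow ID and deployment name are placeholders):

```python
from uuid import UUID

from prefect.client.orchestration import get_client


async def deploy_with_limit(flow_id: UUID) -> UUID:
    async with get_client() as client:
        # Cap this deployment at two concurrent runs.
        return await client.create_deployment(
            flow_id=flow_id,
            name="nightly-etl",
            concurrency_limit=2,
        )
```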
```diff
@@ -2612,6 +2665,7 @@ class PrefectClient:
     async def create_work_pool(
         self,
         work_pool: WorkPoolCreate,
+        overwrite: bool = False,
     ) -> WorkPool:
         """
         Creates a work pool with the provided configuration.
@@ -2629,7 +2683,24 @@ class PrefectClient:
             )
         except httpx.HTTPStatusError as e:
             if e.response.status_code == status.HTTP_409_CONFLICT:
-                raise prefect.exceptions.ObjectAlreadyExists(http_exc=e) from e
+                if overwrite:
+                    existing_work_pool = await self.read_work_pool(
+                        work_pool_name=work_pool.name
+                    )
+                    if existing_work_pool.type != work_pool.type:
+                        warnings.warn(
+                            "Overwriting work pool type is not supported. Ignoring provided type.",
+                            category=UserWarning,
+                        )
+                    await self.update_work_pool(
+                        work_pool_name=work_pool.name,
+                        work_pool=WorkPoolUpdate.model_validate(
+                            work_pool.model_dump(exclude={"name", "type"})
+                        ),
+                    )
+                    response = await self._client.get(f"/work_pools/{work_pool.name}")
+                else:
+                    raise prefect.exceptions.ObjectAlreadyExists(http_exc=e) from e
             else:
                 raise
 
```
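A sketch of the new `overwrite` behavior: on a 409 conflict the client now updates the existing pool in place (never its type) instead of raising `ObjectAlreadyExists`:

```python
from prefect.client.orchestration import get_client
from prefect.client.schemas.actions import WorkPoolCreate


async def ensure_pool() -> None:
    async with get_client() as client:
        await client.create_work_pool(
            WorkPoolCreate(name="my-pool", type="process"),
            overwrite=True,  # on conflict, update the existing pool
        )
        # If "my-pool" already existed with a different type, a UserWarning
        # is emitted and the existing type is kept.
```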
```diff
@@ -3156,7 +3227,7 @@ class PrefectClient:
         return pydantic.TypeAdapter(List[Automation]).validate_python(response.json())
 
     async def find_automation(
-        self, id_or_name: Union[str, UUID], exit_if_not_found: bool = True
+        self, id_or_name: Union[str, UUID]
     ) -> Optional[Automation]:
         if isinstance(id_or_name, str):
             try:
```
```diff
@@ -4096,3 +4167,27 @@ class SyncPrefectClient:
                 "occupancy_seconds": occupancy_seconds,
             },
         )
+
+    def decrement_v1_concurrency_slots(
+        self, names: List[str], occupancy_seconds: float, task_run_id: UUID
+    ) -> httpx.Response:
+        """
+        Release the specified concurrency limits.
+
+        Args:
+            names (List[str]): A list of limit names to decrement.
+            occupancy_seconds (float): The duration in seconds that the slots
+                were held.
+            task_run_id (UUID): The task run ID that incremented the limits.
+
+        Returns:
+            httpx.Response: The HTTP response from the server.
+        """
+        return self._client.post(
+            "/concurrency_limits/decrement",
+            json={
+                "names": names,
+                "occupancy_seconds": occupancy_seconds,
+                "task_run_id": str(task_run_id),
+            },
+        )
```
prefect/client/schemas/actions.py
CHANGED
```diff
@@ -157,6 +157,10 @@ class DeploymentCreate(ActionBaseModel):
         default_factory=list,
         description="A list of schedules for the deployment.",
     )
+    concurrency_limit: Optional[int] = Field(
+        default=None,
+        description="The concurrency limit for the deployment.",
+    )
     enforce_parameter_schema: Optional[bool] = Field(
         default=None,
         description=(
@@ -229,6 +233,10 @@ class DeploymentUpdate(ActionBaseModel):
         default=None,
         description="A list of schedules for the deployment.",
     )
+    concurrency_limit: Optional[int] = Field(
+        default=None,
+        description="The concurrency limit for the deployment.",
+    )
     tags: List[str] = Field(default_factory=list)
     work_queue_name: Optional[str] = Field(None)
     work_pool_name: Optional[str] = Field(
@@ -607,11 +615,11 @@ class WorkQueueCreate(ActionBaseModel):
         default=False,
         description="Whether the work queue is paused.",
     )
-    concurrency_limit: Optional[int] = Field(
+    concurrency_limit: Optional[NonNegativeInteger] = Field(
         default=None,
         description="A concurrency limit for the work queue.",
     )
-    priority: Optional[int] = Field(
+    priority: Optional[PositiveInteger] = Field(
         default=None,
         description=(
             "The queue's priority. Lower values are higher priority (1 is the highest)."
@@ -635,8 +643,10 @@ class WorkQueueUpdate(ActionBaseModel):
     is_paused: bool = Field(
         default=False, description="Whether or not the work queue is paused."
     )
-    concurrency_limit: Optional[int] = Field(None)
-    priority: Optional[int] = Field(None, description="The queue's priority.")
+    concurrency_limit: Optional[NonNegativeInteger] = Field(None)
+    priority: Optional[PositiveInteger] = Field(
+        None, description="The queue's priority."
+    )
     last_polled: Optional[DateTime] = Field(None)
 
     # DEPRECATED
```
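The practical effect of swapping plain integer annotations for `NonNegativeInteger`/`PositiveInteger` is validation at model construction time; a sketch of what should now be rejected:

```python
import pydantic

from prefect.client.schemas.actions import WorkQueueCreate

WorkQueueCreate(name="q", concurrency_limit=0, priority=1)  # fine: limit may be 0

try:
    WorkQueueCreate(name="q", concurrency_limit=-1)  # negative limit
except pydantic.ValidationError:
    pass

try:
    WorkQueueCreate(name="q", priority=0)  # priority starts at 1
except pydantic.ValidationError:
    pass
```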
prefect/client/schemas/filters.py
CHANGED
```diff
@@ -505,6 +505,23 @@ class DeploymentFilterTags(PrefectBaseModel, OperatorMixin):
     )
 
 
+class DeploymentFilterConcurrencyLimit(PrefectBaseModel):
+    """Filter by `Deployment.concurrency_limit`."""
+
+    ge_: Optional[int] = Field(
+        default=None,
+        description="Only include deployments with a concurrency limit greater than or equal to this value",
+    )
+    le_: Optional[int] = Field(
+        default=None,
+        description="Only include deployments with a concurrency limit less than or equal to this value",
+    )
+    is_null_: Optional[bool] = Field(
+        default=None,
+        description="If true, only include deployments without a concurrency limit",
+    )
+
+
 class DeploymentFilter(PrefectBaseModel, OperatorMixin):
     """Filter for deployments. Only deployments matching all criteria will be returned."""
 
@@ -520,6 +537,9 @@ class DeploymentFilter(PrefectBaseModel, OperatorMixin):
     work_queue_name: Optional[DeploymentFilterWorkQueueName] = Field(
         default=None, description="Filter criteria for `Deployment.work_queue_name`"
     )
+    concurrency_limit: Optional[DeploymentFilterConcurrencyLimit] = Field(
+        default=None, description="Filter criteria for `Deployment.concurrency_limit`"
+    )
 
 
 class LogFilterName(PrefectBaseModel):
```
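A sketch of querying on the new filter (uses the client's existing `read_deployments`):

```python
from prefect.client.orchestration import get_client
from prefect.client.schemas.filters import (
    DeploymentFilter,
    DeploymentFilterConcurrencyLimit,
)


async def capped_deployments():
    async with get_client() as client:
        # Deployments whose concurrency limit is 5 or lower.
        return await client.read_deployments(
            deployment_filter=DeploymentFilter(
                concurrency_limit=DeploymentFilterConcurrencyLimit(le_=5)
            )
        )
```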
prefect/client/schemas/objects.py
CHANGED
```diff
@@ -996,6 +996,9 @@ class Deployment(ObjectBaseModel):
     paused: bool = Field(
         default=False, description="Whether or not the deployment is paused."
     )
+    concurrency_limit: Optional[int] = Field(
+        default=None, description="The concurrency limit for the deployment."
+    )
     schedules: List[DeploymentSchedule] = Field(
         default_factory=list, description="A list of schedules for the deployment."
     )
```
prefect/client/schemas/responses.py
CHANGED
```diff
@@ -313,6 +313,9 @@ class DeploymentResponse(ObjectBaseModel):
     flow_id: UUID = Field(
         default=..., description="The flow id associated with the deployment."
     )
+    concurrency_limit: Optional[int] = Field(
+        default=None, description="The concurrency limit for the deployment."
+    )
     paused: bool = Field(
         default=False, description="Whether or not the deployment is paused."
     )
```
prefect/concurrency/v1/__init__.py
File without changes
prefect/concurrency/v1/asyncio.py
ADDED
````diff
@@ -0,0 +1,143 @@
+import asyncio
+from contextlib import asynccontextmanager
+from typing import AsyncGenerator, List, Optional, Union, cast
+from uuid import UUID
+
+import anyio
+import httpx
+import pendulum
+
+from ...client.schemas.responses import MinimalConcurrencyLimitResponse
+
+try:
+    from pendulum import Interval
+except ImportError:
+    # pendulum < 3
+    from pendulum.period import Period as Interval  # type: ignore
+
+from prefect.client.orchestration import get_client
+
+from .context import ConcurrencyContext
+from .events import (
+    _emit_concurrency_acquisition_events,
+    _emit_concurrency_release_events,
+)
+from .services import ConcurrencySlotAcquisitionService
+
+
+class ConcurrencySlotAcquisitionError(Exception):
+    """Raised when an unhandlable occurs while acquiring concurrency slots."""
+
+
+class AcquireConcurrencySlotTimeoutError(TimeoutError):
+    """Raised when acquiring a concurrency slot times out."""
+
+
+@asynccontextmanager
+async def concurrency(
+    names: Union[str, List[str]],
+    task_run_id: UUID,
+    timeout_seconds: Optional[float] = None,
+) -> AsyncGenerator[None, None]:
+    """A context manager that acquires and releases concurrency slots from the
+    given concurrency limits.
+
+    Args:
+        names: The names of the concurrency limits to acquire slots from.
+        task_run_id: The name of the task_run_id that is incrementing the slots.
+        timeout_seconds: The number of seconds to wait for the slots to be acquired before
+            raising a `TimeoutError`. A timeout of `None` will wait indefinitely.
+
+    Raises:
+        TimeoutError: If the slots are not acquired within the given timeout.
+
+    Example:
+    A simple example of using the async `concurrency` context manager:
+    ```python
+    from prefect.concurrency.v1.asyncio import concurrency
+
+    async def resource_heavy():
+        async with concurrency("test", task_run_id):
+            print("Resource heavy task")
+
+    async def main():
+        await resource_heavy()
+    ```
+    """
+    if not names:
+        yield
+        return
+
+    names_normalized: List[str] = names if isinstance(names, list) else [names]
+
+    limits = await _acquire_concurrency_slots(
+        names_normalized,
+        task_run_id=task_run_id,
+        timeout_seconds=timeout_seconds,
+    )
+    acquisition_time = pendulum.now("UTC")
+    emitted_events = _emit_concurrency_acquisition_events(limits, task_run_id)
+
+    try:
+        yield
+    finally:
+        occupancy_period = cast(Interval, (pendulum.now("UTC") - acquisition_time))
+        try:
+            await _release_concurrency_slots(
+                names_normalized, task_run_id, occupancy_period.total_seconds()
+            )
+        except anyio.get_cancelled_exc_class():
+            # The task was cancelled before it could release the slots. Add the
+            # slots to the cleanup list so they can be released when the
+            # concurrency context is exited.
+            if ctx := ConcurrencyContext.get():
+                ctx.cleanup_slots.append(
+                    (names_normalized, occupancy_period.total_seconds(), task_run_id)
+                )
+
+        _emit_concurrency_release_events(limits, emitted_events, task_run_id)
+
+
+async def _acquire_concurrency_slots(
+    names: List[str],
+    task_run_id: UUID,
+    timeout_seconds: Optional[float] = None,
+) -> List[MinimalConcurrencyLimitResponse]:
+    service = ConcurrencySlotAcquisitionService.instance(frozenset(names))
+    future = service.send((task_run_id, timeout_seconds))
+    response_or_exception = await asyncio.wrap_future(future)
+
+    if isinstance(response_or_exception, Exception):
+        if isinstance(response_or_exception, TimeoutError):
+            raise AcquireConcurrencySlotTimeoutError(
+                f"Attempt to acquire concurrency limits timed out after {timeout_seconds} second(s)"
+            ) from response_or_exception
+
+        raise ConcurrencySlotAcquisitionError(
+            f"Unable to acquire concurrency limits {names!r}"
+        ) from response_or_exception
+
+    return _response_to_concurrency_limit_response(response_or_exception)
+
+
+async def _release_concurrency_slots(
+    names: List[str],
+    task_run_id: UUID,
+    occupancy_seconds: float,
+) -> List[MinimalConcurrencyLimitResponse]:
+    async with get_client() as client:
+        response = await client.decrement_v1_concurrency_slots(
+            names=names,
+            task_run_id=task_run_id,
+            occupancy_seconds=occupancy_seconds,
+        )
+        return _response_to_concurrency_limit_response(response)
+
+
+def _response_to_concurrency_limit_response(
+    response: httpx.Response,
+) -> List[MinimalConcurrencyLimitResponse]:
+    data = response.json() or []
+    return [
+        MinimalConcurrencyLimitResponse.model_validate(limit) for limit in data if data
+    ]
````
prefect/concurrency/v1/context.py
ADDED
```diff
@@ -0,0 +1,27 @@
+from contextvars import ContextVar
+from typing import List, Tuple
+from uuid import UUID
+
+from prefect.client.orchestration import get_client
+from prefect.context import ContextModel, Field
+
+
+class ConcurrencyContext(ContextModel):
+    __var__: ContextVar = ContextVar("concurrency_v1")
+
+    # Track the limits that have been acquired but were not able to be released
+    # due to cancellation or some other error. These limits are released when
+    # the context manager exits.
+    cleanup_slots: List[Tuple[List[str], float, UUID]] = Field(default_factory=list)
+
+    def __exit__(self, *exc_info):
+        if self.cleanup_slots:
+            with get_client(sync_client=True) as client:
+                for names, occupancy_seconds, task_run_id in self.cleanup_slots:
+                    client.decrement_v1_concurrency_slots(
+                        names=names,
+                        occupancy_seconds=occupancy_seconds,
+                        task_run_id=task_run_id,
+                    )
+
+        return super().__exit__(*exc_info)
```
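For illustration only (Prefect's run engine enters this context itself): slots that a cancelled `concurrency()` block could not release are parked in `cleanup_slots` and decremented synchronously on exit:

```python
from prefect.concurrency.v1.context import ConcurrencyContext

with ConcurrencyContext() as ctx:
    ...  # cancelled acquisitions append (names, occupancy_seconds, task_run_id)
    # on __exit__, each entry is posted to /concurrency_limits/decrement
```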
prefect/concurrency/v1/events.py
ADDED
```diff
@@ -0,0 +1,61 @@
+from typing import Dict, List, Literal, Optional, Union
+from uuid import UUID
+
+from prefect.client.schemas.responses import MinimalConcurrencyLimitResponse
+from prefect.events import Event, RelatedResource, emit_event
+
+
+def _emit_concurrency_event(
+    phase: Union[Literal["acquired"], Literal["released"]],
+    primary_limit: MinimalConcurrencyLimitResponse,
+    related_limits: List[MinimalConcurrencyLimitResponse],
+    task_run_id: UUID,
+    follows: Union[Event, None] = None,
+) -> Union[Event, None]:
+    resource: Dict[str, str] = {
+        "prefect.resource.id": f"prefect.concurrency-limit.v1.{primary_limit.id}",
+        "prefect.resource.name": primary_limit.name,
+        "limit": str(primary_limit.limit),
+        "task_run_id": str(task_run_id),
+    }
+
+    related = [
+        RelatedResource.model_validate(
+            {
+                "prefect.resource.id": f"prefect.concurrency-limit.v1.{limit.id}",
+                "prefect.resource.role": "concurrency-limit",
+            }
+        )
+        for limit in related_limits
+        if limit.id != primary_limit.id
+    ]
+
+    return emit_event(
+        f"prefect.concurrency-limit.v1.{phase}",
+        resource=resource,
+        related=related,
+        follows=follows,
+    )
+
+
+def _emit_concurrency_acquisition_events(
+    limits: List[MinimalConcurrencyLimitResponse],
+    task_run_id: UUID,
+) -> Dict[UUID, Optional[Event]]:
+    events = {}
+    for limit in limits:
+        event = _emit_concurrency_event("acquired", limit, limits, task_run_id)
+        events[limit.id] = event
+
+    return events
+
+
+def _emit_concurrency_release_events(
+    limits: List[MinimalConcurrencyLimitResponse],
+    events: Dict[UUID, Optional[Event]],
+    task_run_id: UUID,
+) -> None:
+    for limit in limits:
+        _emit_concurrency_event(
+            "released", limit, limits, task_run_id, events[limit.id]
+        )
```