prefect-client 3.2.2__py3-none-any.whl → 3.2.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- prefect/__init__.py +15 -8
- prefect/_build_info.py +5 -0
- prefect/client/orchestration/__init__.py +16 -5
- prefect/main.py +0 -2
- prefect/server/api/__init__.py +34 -0
- prefect/server/api/admin.py +85 -0
- prefect/server/api/artifacts.py +224 -0
- prefect/server/api/automations.py +239 -0
- prefect/server/api/block_capabilities.py +25 -0
- prefect/server/api/block_documents.py +164 -0
- prefect/server/api/block_schemas.py +153 -0
- prefect/server/api/block_types.py +211 -0
- prefect/server/api/clients.py +246 -0
- prefect/server/api/collections.py +75 -0
- prefect/server/api/concurrency_limits.py +286 -0
- prefect/server/api/concurrency_limits_v2.py +269 -0
- prefect/server/api/csrf_token.py +38 -0
- prefect/server/api/dependencies.py +196 -0
- prefect/server/api/deployments.py +941 -0
- prefect/server/api/events.py +300 -0
- prefect/server/api/flow_run_notification_policies.py +120 -0
- prefect/server/api/flow_run_states.py +52 -0
- prefect/server/api/flow_runs.py +867 -0
- prefect/server/api/flows.py +210 -0
- prefect/server/api/logs.py +43 -0
- prefect/server/api/middleware.py +73 -0
- prefect/server/api/root.py +35 -0
- prefect/server/api/run_history.py +170 -0
- prefect/server/api/saved_searches.py +99 -0
- prefect/server/api/server.py +891 -0
- prefect/server/api/task_run_states.py +52 -0
- prefect/server/api/task_runs.py +342 -0
- prefect/server/api/task_workers.py +31 -0
- prefect/server/api/templates.py +35 -0
- prefect/server/api/ui/__init__.py +3 -0
- prefect/server/api/ui/flow_runs.py +128 -0
- prefect/server/api/ui/flows.py +173 -0
- prefect/server/api/ui/schemas.py +63 -0
- prefect/server/api/ui/task_runs.py +175 -0
- prefect/server/api/validation.py +382 -0
- prefect/server/api/variables.py +181 -0
- prefect/server/api/work_queues.py +230 -0
- prefect/server/api/workers.py +656 -0
- prefect/settings/sources.py +18 -5
- {prefect_client-3.2.2.dist-info → prefect_client-3.2.4.dist-info}/METADATA +10 -15
- {prefect_client-3.2.2.dist-info → prefect_client-3.2.4.dist-info}/RECORD +48 -10
- {prefect_client-3.2.2.dist-info → prefect_client-3.2.4.dist-info}/WHEEL +1 -2
- prefect/_version.py +0 -21
- prefect_client-3.2.2.dist-info/top_level.txt +0 -1
- {prefect_client-3.2.2.dist-info → prefect_client-3.2.4.dist-info/licenses}/LICENSE +0 -0
@@ -0,0 +1,246 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import base64
|
4
|
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
5
|
+
from urllib.parse import quote
|
6
|
+
from uuid import UUID
|
7
|
+
|
8
|
+
import httpx
|
9
|
+
import pydantic
|
10
|
+
from httpx import Response
|
11
|
+
from starlette import status
|
12
|
+
from typing_extensions import Self
|
13
|
+
|
14
|
+
from prefect.client.base import PrefectHttpxAsyncClient
|
15
|
+
from prefect.exceptions import ObjectNotFound
|
16
|
+
from prefect.logging import get_logger
|
17
|
+
from prefect.server.schemas.actions import DeploymentFlowRunCreate, StateCreate
|
18
|
+
from prefect.server.schemas.core import WorkPool
|
19
|
+
from prefect.server.schemas.filters import VariableFilter, VariableFilterName
|
20
|
+
from prefect.server.schemas.responses import DeploymentResponse, OrchestrationResult
|
21
|
+
from prefect.settings import PREFECT_SERVER_API_AUTH_STRING
|
22
|
+
from prefect.types import StrictVariableValue
|
23
|
+
|
24
|
+
if TYPE_CHECKING:
|
25
|
+
import logging
|
26
|
+
|
27
|
+
logger: "logging.Logger" = get_logger(__name__)
|
28
|
+
|
29
|
+
|
30
|
+
class BaseClient:
    """Async client base that talks to the running Prefect API app in-process.

    Requests are routed through an `httpx.ASGITransport` directly into the
    server application, so no network round-trip is made.
    """

    _http_client: PrefectHttpxAsyncClient

    def __init__(self, additional_headers: dict[str, str] | None = None):
        from prefect.server.api.server import create_app

        headers: dict[str, str] = dict(additional_headers or {})

        # create_app caches application instances, and invoking it with no
        # arguments will point it to the currently running server instance
        app = create_app()

        # we pull the auth string from _server_ settings because this client
        # is run on the server
        auth_string = PREFECT_SERVER_API_AUTH_STRING.value()
        if auth_string:
            encoded = base64.b64encode(auth_string.encode("utf-8")).decode("utf-8")
            headers.setdefault("Authorization", f"Basic {encoded}")

        self._http_client = PrefectHttpxAsyncClient(
            transport=httpx.ASGITransport(app=app, raise_app_exceptions=False),
            headers=headers,
            base_url="http://prefect-in-memory/api",
            enable_csrf_support=False,
            raise_on_all_errors=False,
        )

    async def __aenter__(self) -> Self:
        await self._http_client.__aenter__()
        return self

    async def __aexit__(self, *args: Any) -> None:
        await self._http_client.__aexit__(*args)
|
63
|
+
|
64
|
+
|
65
|
+
class OrchestrationClient(BaseClient):
    """In-process client for orchestration endpoints of the Prefect API.

    Methods suffixed ``_raw`` return the raw ``httpx.Response`` without
    raising for error status codes; callers decide how to handle failures.
    """

    async def read_deployment_raw(self, deployment_id: UUID) -> Response:
        """GET a deployment by id; returns the raw response."""
        return await self._http_client.get(f"/deployments/{deployment_id}")

    async def read_deployment(
        self, deployment_id: UUID
    ) -> Optional[DeploymentResponse]:
        """Read a deployment, returning ``None`` when it does not exist.

        Raises:
            httpx.HTTPStatusError: for any error status other than 404.
        """
        try:
            response = await self.read_deployment_raw(deployment_id)
            response.raise_for_status()
        except httpx.HTTPStatusError as e:
            if e.response.status_code == status.HTTP_404_NOT_FOUND:
                return None
            raise
        return DeploymentResponse.model_validate(response.json())

    async def read_flow_raw(self, flow_id: UUID) -> Response:
        """GET a flow by id; returns the raw response."""
        return await self._http_client.get(f"/flows/{flow_id}")

    async def create_flow_run(
        self, deployment_id: UUID, flow_run_create: DeploymentFlowRunCreate
    ) -> Response:
        """POST a new flow run for the given deployment; returns the raw response."""
        return await self._http_client.post(
            f"/deployments/{deployment_id}/create_flow_run",
            json=flow_run_create.model_dump(mode="json"),
        )

    async def read_flow_run_raw(self, flow_run_id: UUID) -> Response:
        """GET a flow run by id; returns the raw response."""
        return await self._http_client.get(f"/flow_runs/{flow_run_id}")

    async def read_task_run_raw(self, task_run_id: UUID) -> Response:
        """GET a task run by id; returns the raw response."""
        return await self._http_client.get(f"/task_runs/{task_run_id}")

    async def resume_flow_run(self, flow_run_id: UUID) -> OrchestrationResult:
        """Resume a paused flow run and return the orchestration result.

        Raises:
            httpx.HTTPStatusError: if the server responds with an error status.
        """
        response = await self._http_client.post(
            f"/flow_runs/{flow_run_id}/resume",
        )
        response.raise_for_status()
        return OrchestrationResult.model_validate(response.json())

    async def pause_deployment(self, deployment_id: UUID) -> Response:
        """POST a pause request for a deployment; returns the raw response."""
        return await self._http_client.post(
            f"/deployments/{deployment_id}/pause_deployment",
        )

    async def resume_deployment(self, deployment_id: UUID) -> Response:
        """POST a resume request for a deployment; returns the raw response."""
        return await self._http_client.post(
            f"/deployments/{deployment_id}/resume_deployment",
        )

    async def set_flow_run_state(
        self, flow_run_id: UUID, state: StateCreate
    ) -> Response:
        """Propose a new state for a flow run (non-forced); returns the raw response."""
        return await self._http_client.post(
            f"/flow_runs/{flow_run_id}/set_state",
            json={
                "state": state.model_dump(mode="json"),
                # never force the transition; let orchestration rules apply
                "force": False,
            },
        )

    async def pause_work_pool(self, work_pool_name: str) -> Response:
        """PATCH a work pool to paused; the name is URL-quoted for safety."""
        return await self._http_client.patch(
            f"/work_pools/{quote(work_pool_name)}", json={"is_paused": True}
        )

    async def resume_work_pool(self, work_pool_name: str) -> Response:
        """PATCH a work pool to unpaused; the name is URL-quoted for safety."""
        return await self._http_client.patch(
            f"/work_pools/{quote(work_pool_name)}", json={"is_paused": False}
        )

    async def read_work_pool_raw(self, work_pool_id: UUID) -> Response:
        """Look up a work pool by id via the filter endpoint.

        NOTE: uses POST /work_pools/filter (not GET by name) because only the
        id is available here; the response body is a JSON list.
        """
        return await self._http_client.post(
            "/work_pools/filter",
            json={"work_pools": {"id": {"any_": [str(work_pool_id)]}}},
        )

    async def read_work_pool(self, work_pool_id: UUID) -> Optional[WorkPool]:
        """Read a work pool by id, returning ``None`` when no pool matches.

        Raises:
            httpx.HTTPStatusError: if the filter request fails.
        """
        response = await self.read_work_pool_raw(work_pool_id)
        response.raise_for_status()

        pools = pydantic.TypeAdapter(List[WorkPool]).validate_python(response.json())
        return pools[0] if pools else None

    async def read_work_queue_raw(self, work_queue_id: UUID) -> Response:
        """GET a work queue by id; returns the raw response."""
        return await self._http_client.get(f"/work_queues/{work_queue_id}")

    async def read_work_queue_status_raw(self, work_queue_id: UUID) -> Response:
        """GET the status of a work queue by id; returns the raw response."""
        return await self._http_client.get(f"/work_queues/{work_queue_id}/status")

    async def pause_work_queue(self, work_queue_id: UUID) -> Response:
        """PATCH a work queue to paused; returns the raw response."""
        return await self._http_client.patch(
            f"/work_queues/{work_queue_id}",
            json={"is_paused": True},
        )

    async def resume_work_queue(self, work_queue_id: UUID) -> Response:
        """PATCH a work queue to unpaused; returns the raw response."""
        return await self._http_client.patch(
            f"/work_queues/{work_queue_id}",
            json={"is_paused": False},
        )

    async def read_block_document_raw(
        self,
        block_document_id: UUID,
        include_secrets: bool = True,
    ) -> Response:
        """GET a block document by id; secrets are included by default."""
        return await self._http_client.get(
            f"/block_documents/{block_document_id}",
            params=dict(include_secrets=include_secrets),
        )

    # page size used when listing variables from the filter endpoint
    VARIABLE_PAGE_SIZE = 200
    # hard upper bound on how many variables will be fetched in total
    MAX_VARIABLES_PER_WORKSPACE = 1000

    async def read_workspace_variables(
        self, names: Optional[List[str]] = None
    ) -> Dict[str, StrictVariableValue]:
        """Read workspace variables as a name -> value mapping.

        Args:
            names: when given, restrict results to these names; an empty list
                short-circuits and returns an empty mapping.

        Returns:
            A dict of variable names to their values, paged from the API up to
            ``MAX_VARIABLES_PER_WORKSPACE`` entries.
        """
        variables: Dict[str, StrictVariableValue] = {}

        offset = 0

        filter = VariableFilter()

        if names is not None and not names:
            return variables
        elif names is not None:
            # de-duplicate requested names before building the filter
            filter.name = VariableFilterName(any_=list(set(names)))

        for offset in range(
            0, self.MAX_VARIABLES_PER_WORKSPACE, self.VARIABLE_PAGE_SIZE
        ):
            response = await self._http_client.post(
                "/variables/filter",
                json={
                    "variables": filter.model_dump(),
                    "limit": self.VARIABLE_PAGE_SIZE,
                    "offset": offset,
                },
            )
            if response.status_code >= 300:
                response.raise_for_status()

            results = response.json()
            for variable in results:
                variables[variable["name"]] = variable["value"]

            # a short page means there are no further results
            if len(results) < self.VARIABLE_PAGE_SIZE:
                break

        return variables

    async def read_concurrency_limit_v2_raw(
        self, concurrency_limit_id: UUID
    ) -> Response:
        """GET a v2 concurrency limit by id; returns the raw response."""
        return await self._http_client.get(
            f"/v2/concurrency_limits/{concurrency_limit_id}"
        )
|
223
|
+
|
224
|
+
|
225
|
+
class WorkPoolsOrchestrationClient(BaseClient):
    """Narrow in-process client exposing work pool reads."""

    async def __aenter__(self) -> Self:
        # no setup required beyond construction
        return self

    async def read_work_pool(self, work_pool_name: str) -> WorkPool:
        """Read a work pool by name.

        Args:
            work_pool_name: the name of the work pool to look up.

        Returns:
            The requested work pool.

        Raises:
            ObjectNotFound: when no work pool with that name exists.
            httpx.HTTPStatusError: for any other error response.
        """
        response = await self._http_client.get(f"/work_pools/{work_pool_name}")
        try:
            response.raise_for_status()
        except httpx.HTTPStatusError as exc:
            if exc.response.status_code == status.HTTP_404_NOT_FOUND:
                raise ObjectNotFound(http_exc=exc) from exc
            raise
        return WorkPool.model_validate(response.json())
|
@@ -0,0 +1,75 @@
|
|
1
|
+
import json
|
2
|
+
from typing import Any, Dict
|
3
|
+
|
4
|
+
import httpx
|
5
|
+
from anyio import Path
|
6
|
+
from cachetools import TTLCache
|
7
|
+
from fastapi import HTTPException, status
|
8
|
+
|
9
|
+
from prefect.server.utilities.server import PrefectRouter
|
10
|
+
|
11
|
+
router: PrefectRouter = PrefectRouter(prefix="/collections", tags=["Collections"])
|
12
|
+
|
13
|
+
GLOBAL_COLLECTIONS_VIEW_CACHE: TTLCache[str, dict[str, Any]] = TTLCache(
|
14
|
+
maxsize=200, ttl=60 * 10
|
15
|
+
)
|
16
|
+
|
17
|
+
REGISTRY_VIEWS = (
|
18
|
+
"https://raw.githubusercontent.com/PrefectHQ/prefect-collection-registry/main/views"
|
19
|
+
)
|
20
|
+
KNOWN_VIEWS = {
|
21
|
+
"aggregate-block-metadata": f"{REGISTRY_VIEWS}/aggregate-block-metadata.json",
|
22
|
+
"aggregate-flow-metadata": f"{REGISTRY_VIEWS}/aggregate-flow-metadata.json",
|
23
|
+
"aggregate-worker-metadata": f"{REGISTRY_VIEWS}/aggregate-worker-metadata.json",
|
24
|
+
"demo-flows": f"{REGISTRY_VIEWS}/demo-flows.json",
|
25
|
+
}
|
26
|
+
|
27
|
+
|
28
|
+
@router.get("/views/{view}")
async def read_view_content(view: str) -> Dict[str, Any]:
    """Reads the content of a view from the prefect-collection-registry."""
    try:
        return await get_collection_view(view)
    except KeyError:
        # unknown view name: not in KNOWN_VIEWS
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"View {view} not found in registry",
        )
    except httpx.HTTPStatusError as exc:
        # only a registry-side 404 is translated; other errors propagate
        if exc.response.status_code != 404:
            raise
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Requested content missing for view {view}",
        )
|
46
|
+
|
47
|
+
|
48
|
+
async def get_collection_view(view: str) -> dict[str, Any]:
    """Fetch a registry view: cache first, then the network, then bundled files.

    Raises:
        KeyError: when ``view`` is not in ``KNOWN_VIEWS`` (callers translate
            this to a 404).
    """
    # fast path: serve from the process-wide TTL cache
    try:
        return GLOBAL_COLLECTIONS_VIEW_CACHE[view]
    except KeyError:
        pass

    try:
        async with httpx.AsyncClient() as client:
            # KeyError here (unknown view) is re-raised by the handler below
            resp = await client.get(KNOWN_VIEWS[view])
            resp.raise_for_status()

            data = resp.json()
            if view == "aggregate-worker-metadata":
                # hide the deprecated prefect-agent worker type from consumers
                data.get("prefect", {}).pop("prefect-agent", None)

            GLOBAL_COLLECTIONS_VIEW_CACHE[view] = data
            return data
    except Exception:
        # best-effort fallback: any failure (network, HTTP error, bad JSON)
        # falls back to the copy of the view bundled with the package
        if view not in KNOWN_VIEWS:
            raise
        local_file = Path(__file__).parent / Path(f"collections_data/views/{view}.json")
        if await local_file.exists():
            raw_data = await local_file.read_text()
            data = json.loads(raw_data)
            GLOBAL_COLLECTIONS_VIEW_CACHE[view] = data
            return data
        else:
            # no bundled copy either; surface the original failure
            raise
|
@@ -0,0 +1,286 @@
|
|
1
|
+
"""
|
2
|
+
Routes for interacting with concurrency limit objects.
|
3
|
+
"""
|
4
|
+
|
5
|
+
from typing import List, Optional, Sequence
|
6
|
+
from uuid import UUID
|
7
|
+
|
8
|
+
from fastapi import Body, Depends, HTTPException, Path, Response, status
|
9
|
+
|
10
|
+
import prefect.server.api.dependencies as dependencies
|
11
|
+
import prefect.server.models as models
|
12
|
+
import prefect.server.schemas as schemas
|
13
|
+
from prefect.server.api.concurrency_limits_v2 import MinimalConcurrencyLimitResponse
|
14
|
+
from prefect.server.database import PrefectDBInterface, provide_database_interface
|
15
|
+
from prefect.server.models import concurrency_limits
|
16
|
+
from prefect.server.utilities.server import PrefectRouter
|
17
|
+
from prefect.settings import PREFECT_TASK_RUN_TAG_CONCURRENCY_SLOT_WAIT_SECONDS
|
18
|
+
from prefect.types._datetime import now
|
19
|
+
|
20
|
+
router: PrefectRouter = PrefectRouter(
|
21
|
+
prefix="/concurrency_limits", tags=["Concurrency Limits"]
|
22
|
+
)
|
23
|
+
|
24
|
+
|
25
|
+
@router.post("/")
async def create_concurrency_limit(
    concurrency_limit: schemas.actions.ConcurrencyLimitCreate,
    response: Response,
    db: PrefectDBInterface = Depends(provide_database_interface),
) -> schemas.core.ConcurrencyLimit:
    """Create (or upsert) a task-run concurrency limit for a tag."""
    # hydrate the input model into a full model
    limit_record = schemas.core.ConcurrencyLimit(**concurrency_limit.model_dump())

    async with db.session_context(begin_transaction=True) as session:
        created = await models.concurrency_limits.create_concurrency_limit(
            session=session, concurrency_limit=limit_record
        )

    # a freshly-inserted row has a creation timestamp at/after "now";
    # report 201 for creations (as opposed to updates of an existing limit)
    if created.created >= now("UTC"):
        response.status_code = status.HTTP_201_CREATED

    return created
|
45
|
+
|
46
|
+
|
47
|
+
@router.get("/{id}")
async def read_concurrency_limit(
    concurrency_limit_id: UUID = Path(
        ..., description="The concurrency limit id", alias="id"
    ),
    db: PrefectDBInterface = Depends(provide_database_interface),
) -> schemas.core.ConcurrencyLimit:
    """
    Get a concurrency limit by id.

    The `active slots` field contains a list of TaskRun IDs currently using a
    concurrency slot for the specified tag.
    """
    async with db.session_context() as session:
        found = await models.concurrency_limits.read_concurrency_limit(
            session=session, concurrency_limit_id=concurrency_limit_id
        )
        if not found:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="Concurrency limit not found",
            )
        return found
|
69
|
+
|
70
|
+
|
71
|
+
@router.get("/tag/{tag}")
async def read_concurrency_limit_by_tag(
    tag: str = Path(..., description="The tag name", alias="tag"),
    db: PrefectDBInterface = Depends(provide_database_interface),
) -> schemas.core.ConcurrencyLimit:
    """
    Get a concurrency limit by tag.

    The `active slots` field contains a list of TaskRun IDs currently using a
    concurrency slot for the specified tag.
    """
    async with db.session_context() as session:
        found = await models.concurrency_limits.read_concurrency_limit_by_tag(
            session=session, tag=tag
        )
        if not found:
            raise HTTPException(
                status.HTTP_404_NOT_FOUND, detail="Concurrency limit not found"
            )
        return found
|
93
|
+
|
94
|
+
|
95
|
+
@router.post("/filter")
async def read_concurrency_limits(
    limit: int = dependencies.LimitBody(),
    offset: int = Body(0, ge=0),
    db: PrefectDBInterface = Depends(provide_database_interface),
) -> Sequence[schemas.core.ConcurrencyLimit]:
    """
    Query for concurrency limits.

    For each concurrency limit the `active slots` field contains a list of TaskRun IDs
    currently using a concurrency slot for the specified tag.
    """
    async with db.session_context() as session:
        rows = await models.concurrency_limits.read_concurrency_limits(
            session=session, limit=limit, offset=offset
        )
        return rows
|
113
|
+
|
114
|
+
|
115
|
+
@router.post("/tag/{tag}/reset")
async def reset_concurrency_limit_by_tag(
    tag: str = Path(..., description="The tag name"),
    slot_override: Optional[List[UUID]] = Body(
        None,
        embed=True,
        description="Manual override for active concurrency limit slots.",
    ),
    db: PrefectDBInterface = Depends(provide_database_interface),
) -> None:
    """Reset a tag's concurrency limit, optionally forcing its active slots."""
    async with db.session_context(begin_transaction=True) as session:
        reset = await models.concurrency_limits.reset_concurrency_limit_by_tag(
            session=session, tag=tag, slot_override=slot_override
        )
        if not reset:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="Concurrency limit not found",
            )
|
133
|
+
|
134
|
+
|
135
|
+
@router.delete("/{id}")
async def delete_concurrency_limit(
    concurrency_limit_id: UUID = Path(
        ..., description="The concurrency limit id", alias="id"
    ),
    db: PrefectDBInterface = Depends(provide_database_interface),
) -> None:
    """Delete a concurrency limit by id; 404 if it does not exist."""
    async with db.session_context(begin_transaction=True) as session:
        deleted = await models.concurrency_limits.delete_concurrency_limit(
            session=session, concurrency_limit_id=concurrency_limit_id
        )
        if not deleted:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="Concurrency limit not found",
            )
|
150
|
+
|
151
|
+
|
152
|
+
@router.delete("/tag/{tag}")
async def delete_concurrency_limit_by_tag(
    tag: str = Path(..., description="The tag name"),
    db: PrefectDBInterface = Depends(provide_database_interface),
) -> None:
    """Delete a concurrency limit by its tag; 404 if it does not exist."""
    async with db.session_context(begin_transaction=True) as session:
        deleted = await models.concurrency_limits.delete_concurrency_limit_by_tag(
            session=session, tag=tag
        )
        if not deleted:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="Concurrency limit not found",
            )
|
165
|
+
|
166
|
+
|
167
|
+
class Abort(Exception):
    """Raised when a transition must be abandoned (e.g. a zero-slot limit)."""

    def __init__(self, reason: str):
        # human-readable explanation surfaced to the API caller
        self.reason: str = reason
|
170
|
+
|
171
|
+
|
172
|
+
class Delay(Exception):
    """Raised when a transition should be retried after a delay."""

    def __init__(self, delay_seconds: float, reason: str):
        # how long the caller should wait before retrying
        self.delay_seconds: float = delay_seconds
        # human-readable explanation surfaced to the API caller
        self.reason: str = reason
|
176
|
+
|
177
|
+
|
178
|
+
@router.post("/increment")
async def increment_concurrency_limits_v1(
    names: List[str] = Body(..., description="The tags to acquire a slot for"),
    task_run_id: UUID = Body(
        ..., description="The ID of the task run acquiring the slot"
    ),
    db: PrefectDBInterface = Depends(provide_database_interface),
) -> List[MinimalConcurrencyLimitResponse]:
    """Acquire a v1 concurrency slot on every tag in `names` for a task run.

    All-or-nothing: if any tag's limit is zero (would deadlock) or already
    full, the slots taken earlier in this request are released again and the
    request fails with 423 LOCKED (including a Retry-After header when the
    caller should retry later).

    Returns:
        The limits on which a slot was acquired for `task_run_id`.
    """
    # limits that have had a slot taken for task_run_id during this request;
    # used both for rollback on failure and to build the response
    # (fixed: this was previously initialized twice, once here and once
    # inside the try block)
    applied_limits: dict = {}

    async with db.session_context(begin_transaction=True) as session:
        try:
            filtered_limits = (
                await concurrency_limits.filter_concurrency_limits_for_orchestration(
                    session, tags=names
                )
            )
            run_limits = {limit.tag: limit for limit in filtered_limits}
            for tag, cl in run_limits.items():
                limit = cl.concurrency_limit
                if limit == 0:
                    # limits of 0 will deadlock, and the transition needs to abort
                    for stale_tag in applied_limits.keys():
                        stale_limit = run_limits.get(stale_tag, None)
                        active_slots = set(stale_limit.active_slots)
                        active_slots.discard(str(task_run_id))
                        stale_limit.active_slots = list(active_slots)

                    raise Abort(
                        reason=(
                            f'The concurrency limit on tag "{tag}" is 0 and will '
                            "deadlock if the task tries to run again."
                        ),
                    )
                elif len(cl.active_slots) >= limit:
                    # if the limit has already been reached, delay the transition
                    for stale_tag in applied_limits.keys():
                        stale_limit = run_limits.get(stale_tag, None)
                        active_slots = set(stale_limit.active_slots)
                        active_slots.discard(str(task_run_id))
                        stale_limit.active_slots = list(active_slots)

                    raise Delay(
                        delay_seconds=PREFECT_TASK_RUN_TAG_CONCURRENCY_SLOT_WAIT_SECONDS.value(),
                        reason=f"Concurrency limit for the {tag} tag has been reached",
                    )
                else:
                    # log the TaskRun ID to active_slots
                    applied_limits[tag] = cl
                    active_slots = set(cl.active_slots)
                    active_slots.add(str(task_run_id))
                    cl.active_slots = list(active_slots)
        except Exception as e:
            # roll back any slots taken before the failure, re-reading each
            # limit so the release is applied to the current row
            for tag in applied_limits.keys():
                cl = await concurrency_limits.read_concurrency_limit_by_tag(
                    session, tag
                )
                active_slots = set(cl.active_slots)
                active_slots.discard(str(task_run_id))
                cl.active_slots = list(active_slots)

            if isinstance(e, Delay):
                raise HTTPException(
                    status_code=status.HTTP_423_LOCKED,
                    detail=e.reason,
                    headers={"Retry-After": str(e.delay_seconds)},
                )
            elif isinstance(e, Abort):
                raise HTTPException(
                    status_code=status.HTTP_423_LOCKED,
                    detail=e.reason,
                )
            else:
                raise
    return [
        MinimalConcurrencyLimitResponse(
            name=limit.tag, limit=limit.concurrency_limit, id=limit.id
        )
        for limit in applied_limits.values()
    ]
|
259
|
+
|
260
|
+
|
261
|
+
@router.post("/decrement")
async def decrement_concurrency_limits_v1(
    names: List[str] = Body(..., description="The tags to release a slot for"),
    task_run_id: UUID = Body(
        ..., description="The ID of the task run releasing the slot"
    ),
    db: PrefectDBInterface = Depends(provide_database_interface),
) -> List[MinimalConcurrencyLimitResponse]:
    """Release any v1 concurrency slots held by `task_run_id` for the given tags.

    Releasing is a no-op for tags where the task run held no slot (the slot
    set uses `discard`, not `remove`).

    Returns:
        The limits matched by `names`, whether or not a slot was released.
        (fixed: the return annotation was `None`, which contradicted the
        actual list return value and suppressed the FastAPI response model)
    """
    async with db.session_context(begin_transaction=True) as session:
        filtered_limits = (
            await concurrency_limits.filter_concurrency_limits_for_orchestration(
                session, tags=names
            )
        )
        run_limits = {limit.tag: limit for limit in filtered_limits}
        for cl in run_limits.values():
            active_slots = set(cl.active_slots)
            active_slots.discard(str(task_run_id))
            cl.active_slots = list(active_slots)

    return [
        MinimalConcurrencyLimitResponse(
            name=limit.tag, limit=limit.concurrency_limit, id=limit.id
        )
        for limit in run_limits.values()
    ]
|