prefect-client 3.2.2__py3-none-any.whl → 3.2.4__py3-none-any.whl

This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as published in their public registry.
Files changed (50)
  1. prefect/__init__.py +15 -8
  2. prefect/_build_info.py +5 -0
  3. prefect/client/orchestration/__init__.py +16 -5
  4. prefect/main.py +0 -2
  5. prefect/server/api/__init__.py +34 -0
  6. prefect/server/api/admin.py +85 -0
  7. prefect/server/api/artifacts.py +224 -0
  8. prefect/server/api/automations.py +239 -0
  9. prefect/server/api/block_capabilities.py +25 -0
  10. prefect/server/api/block_documents.py +164 -0
  11. prefect/server/api/block_schemas.py +153 -0
  12. prefect/server/api/block_types.py +211 -0
  13. prefect/server/api/clients.py +246 -0
  14. prefect/server/api/collections.py +75 -0
  15. prefect/server/api/concurrency_limits.py +286 -0
  16. prefect/server/api/concurrency_limits_v2.py +269 -0
  17. prefect/server/api/csrf_token.py +38 -0
  18. prefect/server/api/dependencies.py +196 -0
  19. prefect/server/api/deployments.py +941 -0
  20. prefect/server/api/events.py +300 -0
  21. prefect/server/api/flow_run_notification_policies.py +120 -0
  22. prefect/server/api/flow_run_states.py +52 -0
  23. prefect/server/api/flow_runs.py +867 -0
  24. prefect/server/api/flows.py +210 -0
  25. prefect/server/api/logs.py +43 -0
  26. prefect/server/api/middleware.py +73 -0
  27. prefect/server/api/root.py +35 -0
  28. prefect/server/api/run_history.py +170 -0
  29. prefect/server/api/saved_searches.py +99 -0
  30. prefect/server/api/server.py +891 -0
  31. prefect/server/api/task_run_states.py +52 -0
  32. prefect/server/api/task_runs.py +342 -0
  33. prefect/server/api/task_workers.py +31 -0
  34. prefect/server/api/templates.py +35 -0
  35. prefect/server/api/ui/__init__.py +3 -0
  36. prefect/server/api/ui/flow_runs.py +128 -0
  37. prefect/server/api/ui/flows.py +173 -0
  38. prefect/server/api/ui/schemas.py +63 -0
  39. prefect/server/api/ui/task_runs.py +175 -0
  40. prefect/server/api/validation.py +382 -0
  41. prefect/server/api/variables.py +181 -0
  42. prefect/server/api/work_queues.py +230 -0
  43. prefect/server/api/workers.py +656 -0
  44. prefect/settings/sources.py +18 -5
  45. {prefect_client-3.2.2.dist-info → prefect_client-3.2.4.dist-info}/METADATA +10 -15
  46. {prefect_client-3.2.2.dist-info → prefect_client-3.2.4.dist-info}/RECORD +48 -10
  47. {prefect_client-3.2.2.dist-info → prefect_client-3.2.4.dist-info}/WHEEL +1 -2
  48. prefect/_version.py +0 -21
  49. prefect_client-3.2.2.dist-info/top_level.txt +0 -1
  50. {prefect_client-3.2.2.dist-info → prefect_client-3.2.4.dist-info/licenses}/LICENSE +0 -0
prefect/server/api/task_run_states.py
@@ -0,0 +1,52 @@
+ """
+ Routes for interacting with task run state objects.
+ """
+
+ from typing import List
+ from uuid import UUID
+
+ from fastapi import Depends, HTTPException, Path, status
+
+ import prefect.server.models as models
+ import prefect.server.schemas as schemas
+ from prefect.server.database import PrefectDBInterface, provide_database_interface
+ from prefect.server.utilities.server import PrefectRouter
+
+ router: PrefectRouter = PrefectRouter(
+     prefix="/task_run_states", tags=["Task Run States"]
+ )
+
+
+ @router.get("/{id}")
+ async def read_task_run_state(
+     task_run_state_id: UUID = Path(
+         ..., description="The task run state id", alias="id"
+     ),
+     db: PrefectDBInterface = Depends(provide_database_interface),
+ ) -> schemas.states.State:
+     """
+     Get a task run state by id.
+     """
+     async with db.session_context() as session:
+         task_run_state = await models.task_run_states.read_task_run_state(
+             session=session, task_run_state_id=task_run_state_id
+         )
+     if not task_run_state:
+         raise HTTPException(
+             status_code=status.HTTP_404_NOT_FOUND, detail="Flow run state not found"
+         )
+     return task_run_state
+
+
+ @router.get("/")
+ async def read_task_run_states(
+     task_run_id: UUID,
+     db: PrefectDBInterface = Depends(provide_database_interface),
+ ) -> List[schemas.states.State]:
+     """
+     Get states associated with a task run.
+     """
+     async with db.session_context() as session:
+         return await models.task_run_states.read_task_run_states(
+             session=session, task_run_id=task_run_id
+         )
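For orientation, here is a minimal client-side sketch of how these two routes could be called over HTTP. The base URL is an assumption (a locally running Prefect API is used for illustration); the paths themselves come from the router above.

# Hypothetical usage sketch; assumes a Prefect API is reachable at API_URL.
import httpx
from uuid import UUID

API_URL = "http://127.0.0.1:4200/api"  # assumed local server; adjust as needed

def read_state(task_run_state_id: UUID) -> dict:
    # GET /task_run_states/{id} -> a single State object, 404 if missing
    response = httpx.get(f"{API_URL}/task_run_states/{task_run_state_id}")
    response.raise_for_status()
    return response.json()

def read_states_for_task_run(task_run_id: UUID) -> list[dict]:
    # GET /task_run_states/?task_run_id=... -> all states for the given task run
    response = httpx.get(
        f"{API_URL}/task_run_states/", params={"task_run_id": str(task_run_id)}
    )
    response.raise_for_status()
    return response.json()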
prefect/server/api/task_runs.py
@@ -0,0 +1,342 @@
+ """
+ Routes for interacting with task run objects.
+ """
+
+ import asyncio
+ import datetime
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional
+ from uuid import UUID
+
+ from fastapi import (
+     Body,
+     Depends,
+     HTTPException,
+     Path,
+     Response,
+     WebSocket,
+     status,
+ )
+ from starlette.websockets import WebSocketDisconnect
+
+ import prefect.server.api.dependencies as dependencies
+ import prefect.server.models as models
+ import prefect.server.schemas as schemas
+ from prefect.logging import get_logger
+ from prefect.server.api.run_history import run_history
+ from prefect.server.database import PrefectDBInterface, provide_database_interface
+ from prefect.server.orchestration import dependencies as orchestration_dependencies
+ from prefect.server.orchestration.core_policy import CoreTaskPolicy
+ from prefect.server.orchestration.policies import TaskRunOrchestrationPolicy
+ from prefect.server.schemas.responses import OrchestrationResult
+ from prefect.server.task_queue import MultiQueue, TaskQueue
+ from prefect.server.utilities import subscriptions
+ from prefect.server.utilities.server import PrefectRouter
+ from prefect.types import DateTime
+ from prefect.types._datetime import now
+
+ if TYPE_CHECKING:
+     import logging
+
+ logger: "logging.Logger" = get_logger("server.api")
+
+ router: PrefectRouter = PrefectRouter(prefix="/task_runs", tags=["Task Runs"])
+
+
+ @router.post("/")
+ async def create_task_run(
+     task_run: schemas.actions.TaskRunCreate,
+     response: Response,
+     db: PrefectDBInterface = Depends(provide_database_interface),
+     orchestration_parameters: Dict[str, Any] = Depends(
+         orchestration_dependencies.provide_task_orchestration_parameters
+     ),
+ ) -> schemas.core.TaskRun:
+     """
+     Create a task run. If a task run with the same flow_run_id,
+     task_key, and dynamic_key already exists, the existing task
+     run will be returned.
+
+     If no state is provided, the task run will be created in a PENDING state.
+     """
+     # hydrate the input model into a full task run / state model
+     task_run_dict = task_run.model_dump()
+     if not task_run_dict.get("id"):
+         task_run_dict.pop("id", None)
+     task_run = schemas.core.TaskRun(**task_run_dict)
+
+     if not task_run.state:
+         task_run.state = schemas.states.Pending()
+
+     right_now = now("UTC")
+
+     async with db.session_context(begin_transaction=True) as session:
+         model = await models.task_runs.create_task_run(
+             session=session,
+             task_run=task_run,
+             orchestration_parameters=orchestration_parameters,
+         )
+
+     if model.created >= right_now:
+         response.status_code = status.HTTP_201_CREATED
+
+     new_task_run: schemas.core.TaskRun = schemas.core.TaskRun.model_validate(model)
+
+     return new_task_run
+
+
+ @router.patch("/{id}", status_code=status.HTTP_204_NO_CONTENT)
+ async def update_task_run(
+     task_run: schemas.actions.TaskRunUpdate,
+     task_run_id: UUID = Path(..., description="The task run id", alias="id"),
+     db: PrefectDBInterface = Depends(provide_database_interface),
+ ) -> None:
+     """
+     Updates a task run.
+     """
+     async with db.session_context(begin_transaction=True) as session:
+         result = await models.task_runs.update_task_run(
+             session=session, task_run=task_run, task_run_id=task_run_id
+         )
+     if not result:
+         raise HTTPException(status.HTTP_404_NOT_FOUND, detail="Task run not found")
+
+
+ @router.post("/count")
+ async def count_task_runs(
+     db: PrefectDBInterface = Depends(provide_database_interface),
+     flows: schemas.filters.FlowFilter = None,
+     flow_runs: schemas.filters.FlowRunFilter = None,
+     task_runs: schemas.filters.TaskRunFilter = None,
+     deployments: schemas.filters.DeploymentFilter = None,
+ ) -> int:
+     """
+     Count task runs.
+     """
+     async with db.session_context() as session:
+         return await models.task_runs.count_task_runs(
+             session=session,
+             flow_filter=flows,
+             flow_run_filter=flow_runs,
+             task_run_filter=task_runs,
+             deployment_filter=deployments,
+         )
+
+
+ @router.post("/history")
+ async def task_run_history(
+     history_start: DateTime = Body(..., description="The history's start time."),
+     history_end: DateTime = Body(..., description="The history's end time."),
+     # Workaround for the fact that FastAPI does not let us configure ser_json_timedelta
+     # to represent timedeltas as floats in JSON.
+     history_interval: float = Body(
+         ...,
+         description=(
+             "The size of each history interval, in seconds. Must be at least 1 second."
+         ),
+         json_schema_extra={"format": "time-delta"},
+         alias="history_interval_seconds",
+     ),
+     flows: schemas.filters.FlowFilter = None,
+     flow_runs: schemas.filters.FlowRunFilter = None,
+     task_runs: schemas.filters.TaskRunFilter = None,
+     deployments: schemas.filters.DeploymentFilter = None,
+     db: PrefectDBInterface = Depends(provide_database_interface),
+ ) -> List[schemas.responses.HistoryResponse]:
+     """
+     Query for task run history data across a given range and interval.
+     """
+     if isinstance(history_interval, float):
+         history_interval = datetime.timedelta(seconds=history_interval)
+
+     if history_interval < datetime.timedelta(seconds=1):
+         raise HTTPException(
+             status.HTTP_422_UNPROCESSABLE_ENTITY,
+             detail="History interval must not be less than 1 second.",
+         )
+
+     async with db.session_context() as session:
+         return await run_history(
+             session=session,
+             run_type="task_run",
+             history_start=history_start,
+             history_end=history_end,
+             history_interval=history_interval,
+             flows=flows,
+             flow_runs=flow_runs,
+             task_runs=task_runs,
+             deployments=deployments,
+         )
+
+
+ @router.get("/{id}")
+ async def read_task_run(
+     task_run_id: UUID = Path(..., description="The task run id", alias="id"),
+     db: PrefectDBInterface = Depends(provide_database_interface),
+ ) -> schemas.core.TaskRun:
+     """
+     Get a task run by id.
+     """
+     async with db.session_context() as session:
+         task_run = await models.task_runs.read_task_run(
+             session=session, task_run_id=task_run_id
+         )
+     if not task_run:
+         raise HTTPException(status.HTTP_404_NOT_FOUND, detail="Task not found")
+     return task_run
+
+
+ @router.post("/filter")
+ async def read_task_runs(
+     sort: schemas.sorting.TaskRunSort = Body(schemas.sorting.TaskRunSort.ID_DESC),
+     limit: int = dependencies.LimitBody(),
+     offset: int = Body(0, ge=0),
+     flows: Optional[schemas.filters.FlowFilter] = None,
+     flow_runs: Optional[schemas.filters.FlowRunFilter] = None,
+     task_runs: Optional[schemas.filters.TaskRunFilter] = None,
+     deployments: Optional[schemas.filters.DeploymentFilter] = None,
+     db: PrefectDBInterface = Depends(provide_database_interface),
+ ) -> List[schemas.core.TaskRun]:
+     """
+     Query for task runs.
+     """
+     async with db.session_context() as session:
+         return await models.task_runs.read_task_runs(
+             session=session,
+             flow_filter=flows,
+             flow_run_filter=flow_runs,
+             task_run_filter=task_runs,
+             deployment_filter=deployments,
+             offset=offset,
+             limit=limit,
+             sort=sort,
+         )
+
+
+ @router.delete("/{id}", status_code=status.HTTP_204_NO_CONTENT)
+ async def delete_task_run(
+     task_run_id: UUID = Path(..., description="The task run id", alias="id"),
+     db: PrefectDBInterface = Depends(provide_database_interface),
+ ) -> None:
+     """
+     Delete a task run by id.
+     """
+     async with db.session_context(begin_transaction=True) as session:
+         result = await models.task_runs.delete_task_run(
+             session=session, task_run_id=task_run_id
+         )
+     if not result:
+         raise HTTPException(status.HTTP_404_NOT_FOUND, detail="Task not found")
+
+
+ @router.post("/{id}/set_state")
+ async def set_task_run_state(
+     task_run_id: UUID = Path(..., description="The task run id", alias="id"),
+     state: schemas.actions.StateCreate = Body(..., description="The intended state."),
+     force: bool = Body(
+         False,
+         description=(
+             "If false, orchestration rules will be applied that may alter or prevent"
+             " the state transition. If True, orchestration rules are not applied."
+         ),
+     ),
+     db: PrefectDBInterface = Depends(provide_database_interface),
+     response: Response = None,
+     task_policy: TaskRunOrchestrationPolicy = Depends(
+         orchestration_dependencies.provide_task_policy
+     ),
+     orchestration_parameters: Dict[str, Any] = Depends(
+         orchestration_dependencies.provide_task_orchestration_parameters
+     ),
+ ) -> OrchestrationResult:
+     """Set a task run state, invoking any orchestration rules."""
+
+     right_now = now("UTC")
+
+     # create the state
+     async with db.session_context(
+         begin_transaction=True, with_for_update=True
+     ) as session:
+         orchestration_result = await models.task_runs.set_task_run_state(
+             session=session,
+             task_run_id=task_run_id,
+             state=schemas.states.State.model_validate(
+                 state
+             ),  # convert to a full State object
+             force=force,
+             task_policy=CoreTaskPolicy,
+             orchestration_parameters=orchestration_parameters,
+         )
+
+     # set the 201 if a new state was created
+     if orchestration_result.state and orchestration_result.state.timestamp >= right_now:
+         response.status_code = status.HTTP_201_CREATED
+     else:
+         response.status_code = status.HTTP_200_OK
+
+     return orchestration_result
+
+
+ @router.websocket("/subscriptions/scheduled")
+ async def scheduled_task_subscription(websocket: WebSocket) -> None:
+     websocket = await subscriptions.accept_prefect_socket(websocket)
+     if not websocket:
+         return
+
+     try:
+         subscription = await websocket.receive_json()
+     except subscriptions.NORMAL_DISCONNECT_EXCEPTIONS:
+         return
+
+     if subscription.get("type") != "subscribe":
+         return await websocket.close(
+             code=4001, reason="Protocol violation: expected 'subscribe' message"
+         )
+
+     task_keys = subscription.get("keys", [])
+     if not task_keys:
+         return await websocket.close(
+             code=4001, reason="Protocol violation: expected 'keys' in subscribe message"
+         )
+
+     if not (client_id := subscription.get("client_id")):
+         return await websocket.close(
+             code=4001,
+             reason="Protocol violation: expected 'client_id' in subscribe message",
+         )
+
+     subscribed_queue = MultiQueue(task_keys)
+
+     logger.info(f"Task worker {client_id!r} subscribed to task keys {task_keys!r}")
+
+     while True:
+         try:
+             # observe here so that all workers with active websockets are tracked
+             await models.task_workers.observe_worker(task_keys, client_id)
+             task_run = await asyncio.wait_for(subscribed_queue.get(), timeout=1)
+         except asyncio.TimeoutError:
+             if not await subscriptions.still_connected(websocket):
+                 await models.task_workers.forget_worker(client_id)
+                 return
+             continue
+
+         try:
+             await websocket.send_json(task_run.model_dump(mode="json"))
+
+             acknowledgement = await websocket.receive_json()
+             ack_type = acknowledgement.get("type")
+             if ack_type != "ack":
+                 if ack_type == "quit":
+                     return await websocket.close()
+
+                 raise WebSocketDisconnect(
+                     code=4001, reason="Protocol violation: expected 'ack' message"
+                 )
+
+             await models.task_workers.observe_worker([task_run.task_key], client_id)
+
+         except subscriptions.NORMAL_DISCONNECT_EXCEPTIONS:
+             # If sending fails or pong fails, put the task back into the retry queue
+             await asyncio.shield(TaskQueue.for_key(task_run.task_key).retry(task_run))
+             return
+         finally:
+             await models.task_workers.forget_worker(client_id)
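The scheduled-task websocket above expects a small protocol: a "subscribe" message carrying "keys" and a "client_id", then an "ack" (or "quit") for every task run delivered. Below is a minimal worker-side sketch of that exchange using the third-party websockets library; the URL is an assumption, and any initial handshake performed by subscriptions.accept_prefect_socket (for example authentication) is deliberately omitted.

# Hypothetical worker-side sketch of /task_runs/subscriptions/scheduled.
# Assumes a Prefect API websocket at WS_URL; the accept_prefect_socket
# handshake is not reproduced here.
import asyncio
import json
import websockets  # third-party client; pip install websockets

WS_URL = "ws://127.0.0.1:4200/api"  # assumed local server

async def consume_scheduled_tasks(task_keys: list[str], client_id: str) -> None:
    async with websockets.connect(f"{WS_URL}/task_runs/subscriptions/scheduled") as ws:
        # subscribe with the task keys this worker can execute
        await ws.send(
            json.dumps({"type": "subscribe", "keys": task_keys, "client_id": client_id})
        )
        while True:
            task_run = json.loads(await ws.recv())  # one scheduled task run per message
            print("received task run", task_run.get("id"))
            await ws.send(json.dumps({"type": "ack"}))  # unacked runs are put back on the queue

asyncio.run(consume_scheduled_tasks(["my_task"], "worker-1"))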
prefect/server/api/task_workers.py
@@ -0,0 +1,31 @@
+ from typing import List, Optional
+
+ from fastapi import Body
+ from pydantic import BaseModel
+
+ from prefect.server import models
+ from prefect.server.models.task_workers import TaskWorkerResponse
+ from prefect.server.utilities.server import PrefectRouter
+
+ router: PrefectRouter = PrefectRouter(prefix="/task_workers", tags=["Task Workers"])
+
+
+ class TaskWorkerFilter(BaseModel):
+     task_keys: List[str]
+
+
+ @router.post("/filter")
+ async def read_task_workers(
+     task_worker_filter: Optional[TaskWorkerFilter] = Body(
+         default=None, description="The task worker filter", embed=True
+     ),
+ ) -> List[TaskWorkerResponse]:
+     """Read active task workers. Optionally filter by task keys."""
+
+     if task_worker_filter and task_worker_filter.task_keys:
+         return await models.task_workers.get_workers_for_task_keys(
+             task_keys=task_worker_filter.task_keys,
+         )
+
+     else:
+         return await models.task_workers.get_all_workers()
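Because the filter parameter is declared with embed=True, a filtered request nests it under the "task_worker_filter" key. A hedged sketch of both calls; the base URL is an assumption.

# Hypothetical calls to the route above; API_URL is an assumption.
import httpx

API_URL = "http://127.0.0.1:4200/api"

# An empty body leaves the filter at its None default, which should return all workers.
all_workers = httpx.post(f"{API_URL}/task_workers/filter", json={}).json()

# Filtered query: the body is nested under the embedded parameter name.
some_workers = httpx.post(
    f"{API_URL}/task_workers/filter",
    json={"task_worker_filter": {"task_keys": ["my_task"]}},
).json()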
prefect/server/api/templates.py
@@ -0,0 +1,35 @@
+ import orjson
+ from fastapi import Body, Response
+ from jinja2.exceptions import TemplateSyntaxError
+
+ from prefect.server.utilities.server import PrefectRouter
+ from prefect.server.utilities.user_templates import (
+     TemplateSecurityError,
+     validate_user_template,
+ )
+
+ router: PrefectRouter = PrefectRouter(prefix="/templates", tags=["Automations"])
+
+
+ @router.post(
+     "/validate",
+     response_class=Response,
+ )
+ def validate_template(template: str = Body(default="")) -> Response:
+     try:
+         validate_user_template(template)
+         return Response(content="", status_code=204)
+     except (TemplateSyntaxError, TemplateSecurityError) as e:
+         return Response(
+             status_code=422,
+             media_type="application/json",
+             content=orjson.dumps(
+                 {
+                     "error": {
+                         "line": e.lineno,
+                         "message": e.message,
+                         "source": template,
+                     },
+                 }
+             ),
+         )
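Since the route takes a single non-embedded Body parameter, the request body is the raw template string itself. A hedged usage sketch follows; the base URL is an assumption, and the expected status codes are read directly from the route above.

# Hypothetical calls to POST /templates/validate; API_URL is an assumption.
import httpx

API_URL = "http://127.0.0.1:4200/api"

# A well-formed template: the route returns an empty 204 response.
ok = httpx.post(f"{API_URL}/templates/validate", json="Hello {{ flow_run.name }}!")
print(ok.status_code)  # expected: 204

# A malformed template: 422 with {"error": {"line": ..., "message": ..., "source": ...}}.
bad = httpx.post(f"{API_URL}/templates/validate", json="Hello {{ flow_run.name }!")
print(bad.status_code, bad.json()["error"]["message"])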
prefect/server/api/ui/__init__.py
@@ -0,0 +1,3 @@
+ """Routes primarily for use by the UI"""
+
+ from . import flows, flow_runs, schemas, task_runs
prefect/server/api/ui/flow_runs.py
@@ -0,0 +1,128 @@
+ from __future__ import annotations
+
+ import datetime
+ from typing import TYPE_CHECKING, List
+ from uuid import UUID
+
+ import sqlalchemy as sa
+ from fastapi import Body, Depends
+ from pydantic import Field
+
+ import prefect.server.schemas as schemas
+ from prefect.logging import get_logger
+ from prefect.server import models
+ from prefect.server.database import PrefectDBInterface, provide_database_interface
+ from prefect.server.utilities.schemas.bases import PrefectBaseModel
+ from prefect.server.utilities.server import PrefectRouter
+ from prefect.types import DateTime
+
+ if TYPE_CHECKING:
+     import logging
+
+ logger: "logging.Logger" = get_logger("server.api.ui.flow_runs")
+
+ router: PrefectRouter = PrefectRouter(prefix="/ui/flow_runs", tags=["Flow Runs", "UI"])
+
+
+ class SimpleFlowRun(PrefectBaseModel):
+     id: UUID = Field(default=..., description="The flow run id.")
+     state_type: schemas.states.StateType = Field(
+         default=..., description="The state type."
+     )
+     timestamp: DateTime = Field(
+         default=...,
+         description=(
+             "The start time of the run, or the expected start time "
+             "if it hasn't run yet."
+         ),
+     )
+     duration: datetime.timedelta = Field(
+         default=..., description="The total run time of the run."
+     )
+     lateness: datetime.timedelta = Field(
+         default=..., description="The delay between the expected and actual start time."
+     )
+
+
+ @router.post("/history")
+ async def read_flow_run_history(
+     sort: schemas.sorting.FlowRunSort = Body(
+         schemas.sorting.FlowRunSort.EXPECTED_START_TIME_DESC
+     ),
+     limit: int = Body(1000, le=1000),
+     offset: int = Body(0, ge=0),
+     flows: schemas.filters.FlowFilter = None,
+     flow_runs: schemas.filters.FlowRunFilter = None,
+     task_runs: schemas.filters.TaskRunFilter = None,
+     deployments: schemas.filters.DeploymentFilter = None,
+     work_pools: schemas.filters.WorkPoolFilter = None,
+     db: PrefectDBInterface = Depends(provide_database_interface),
+ ) -> List[SimpleFlowRun]:
+     columns = [
+         db.FlowRun.id,
+         db.FlowRun.state_type,
+         db.FlowRun.start_time,
+         db.FlowRun.expected_start_time,
+         db.FlowRun.total_run_time,
+         # Although it isn't returned, we need to select
+         # this field in order to compute `estimated_run_time`
+         db.FlowRun.state_timestamp,
+     ]
+     async with db.session_context() as session:
+         result = await models.flow_runs.read_flow_runs(
+             columns=columns,
+             flow_filter=flows,
+             flow_run_filter=flow_runs,
+             task_run_filter=task_runs,
+             deployment_filter=deployments,
+             work_pool_filter=work_pools,
+             sort=sort,
+             limit=limit,
+             offset=offset,
+             session=session,
+         )
+     return [
+         SimpleFlowRun(
+             id=r.id,
+             state_type=r.state_type,
+             timestamp=r.start_time or r.expected_start_time,
+             duration=r.estimated_run_time,
+             lateness=r.estimated_start_time_delta,
+         )
+         for r in result
+     ]
+
+
+ @router.post("/count-task-runs")
+ async def count_task_runs_by_flow_run(
+     flow_run_ids: list[UUID] = Body(default=..., embed=True, max_items=200),
+     db: PrefectDBInterface = Depends(provide_database_interface),
+ ) -> dict[UUID, int]:
+     """
+     Get task run counts by flow run id.
+     """
+     async with db.session_context() as session:
+         query = (
+             sa.select(
+                 db.TaskRun.flow_run_id,
+                 sa.func.count(db.TaskRun.id).label("task_run_count"),
+             )
+             .where(
+                 sa.and_(
+                     db.TaskRun.flow_run_id.in_(flow_run_ids),
+                     sa.not_(db.TaskRun.subflow_run.has()),
+                 )
+             )
+             .group_by(db.TaskRun.flow_run_id)
+         )
+
+         results = await session.execute(query)
+
+         task_run_counts_by_flow_run = {
+             flow_run_id: task_run_count for flow_run_id, task_run_count in results.t
+         }
+
+         return {
+             flow_run_id: task_run_counts_by_flow_run.get(flow_run_id, 0)
+             for flow_run_id in flow_run_ids
+         }
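The count-task-runs route embeds its flow_run_ids parameter and returns a count for every requested id, defaulting to 0 when a flow run has no (non-subflow) task runs. A hedged usage sketch; the base URL and the placeholder ids are assumptions.

# Hypothetical call to POST /ui/flow_runs/count-task-runs; API_URL is an assumption.
import httpx

API_URL = "http://127.0.0.1:4200/api"

flow_run_ids = [
    "00000000-0000-0000-0000-000000000001",  # placeholder ids for illustration
    "00000000-0000-0000-0000-000000000002",
]

# flow_run_ids is embedded, so the body nests it under the parameter name;
# the response maps each requested id to its task run count.
counts = httpx.post(
    f"{API_URL}/ui/flow_runs/count-task-runs",
    json={"flow_run_ids": flow_run_ids},
).json()
print(counts)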