prefect-client 3.1.12__py3-none-any.whl → 3.1.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- prefect/_experimental/sla/client.py +53 -27
- prefect/_experimental/sla/objects.py +10 -2
- prefect/_internal/concurrency/services.py +2 -2
- prefect/_internal/concurrency/threads.py +6 -0
- prefect/_internal/retries.py +6 -3
- prefect/_internal/schemas/validators.py +6 -4
- prefect/_version.py +3 -3
- prefect/artifacts.py +4 -1
- prefect/automations.py +1 -1
- prefect/blocks/abstract.py +5 -2
- prefect/blocks/notifications.py +1 -0
- prefect/cache_policies.py +20 -20
- prefect/client/utilities.py +3 -3
- prefect/deployments/base.py +7 -4
- prefect/deployments/flow_runs.py +5 -1
- prefect/deployments/runner.py +6 -11
- prefect/deployments/steps/core.py +1 -1
- prefect/deployments/steps/pull.py +8 -3
- prefect/deployments/steps/utility.py +2 -2
- prefect/docker/docker_image.py +13 -9
- prefect/engine.py +19 -10
- prefect/events/cli/automations.py +4 -4
- prefect/events/clients.py +17 -14
- prefect/events/schemas/automations.py +12 -8
- prefect/events/schemas/events.py +5 -1
- prefect/events/worker.py +1 -1
- prefect/filesystems.py +1 -1
- prefect/flow_engine.py +17 -9
- prefect/flows.py +118 -73
- prefect/futures.py +14 -7
- prefect/infrastructure/provisioners/__init__.py +2 -0
- prefect/infrastructure/provisioners/cloud_run.py +4 -4
- prefect/infrastructure/provisioners/coiled.py +249 -0
- prefect/infrastructure/provisioners/container_instance.py +4 -3
- prefect/infrastructure/provisioners/ecs.py +55 -43
- prefect/infrastructure/provisioners/modal.py +5 -4
- prefect/input/actions.py +5 -1
- prefect/input/run_input.py +157 -43
- prefect/logging/configuration.py +3 -3
- prefect/logging/filters.py +2 -2
- prefect/logging/formatters.py +15 -11
- prefect/logging/handlers.py +24 -14
- prefect/logging/highlighters.py +5 -5
- prefect/logging/loggers.py +28 -18
- prefect/main.py +3 -1
- prefect/results.py +166 -86
- prefect/runner/runner.py +34 -27
- prefect/runner/server.py +3 -1
- prefect/runner/storage.py +18 -18
- prefect/runner/submit.py +19 -12
- prefect/runtime/deployment.py +15 -8
- prefect/runtime/flow_run.py +19 -6
- prefect/runtime/task_run.py +7 -3
- prefect/settings/base.py +17 -7
- prefect/settings/legacy.py +4 -4
- prefect/settings/models/api.py +4 -3
- prefect/settings/models/cli.py +4 -3
- prefect/settings/models/client.py +7 -4
- prefect/settings/models/cloud.py +4 -3
- prefect/settings/models/deployments.py +4 -3
- prefect/settings/models/experiments.py +4 -3
- prefect/settings/models/flows.py +4 -3
- prefect/settings/models/internal.py +4 -3
- prefect/settings/models/logging.py +8 -6
- prefect/settings/models/results.py +4 -3
- prefect/settings/models/root.py +11 -16
- prefect/settings/models/runner.py +8 -5
- prefect/settings/models/server/api.py +6 -3
- prefect/settings/models/server/database.py +120 -25
- prefect/settings/models/server/deployments.py +4 -3
- prefect/settings/models/server/ephemeral.py +7 -4
- prefect/settings/models/server/events.py +6 -3
- prefect/settings/models/server/flow_run_graph.py +4 -3
- prefect/settings/models/server/root.py +4 -3
- prefect/settings/models/server/services.py +15 -12
- prefect/settings/models/server/tasks.py +7 -4
- prefect/settings/models/server/ui.py +4 -3
- prefect/settings/models/tasks.py +10 -5
- prefect/settings/models/testing.py +4 -3
- prefect/settings/models/worker.py +7 -4
- prefect/settings/profiles.py +13 -12
- prefect/settings/sources.py +20 -19
- prefect/states.py +17 -13
- prefect/task_engine.py +43 -33
- prefect/task_runners.py +35 -23
- prefect/task_runs.py +20 -11
- prefect/task_worker.py +12 -7
- prefect/tasks.py +30 -24
- prefect/telemetry/bootstrap.py +4 -1
- prefect/telemetry/run_telemetry.py +15 -13
- prefect/transactions.py +3 -3
- prefect/types/__init__.py +3 -1
- prefect/utilities/_deprecated.py +38 -0
- prefect/utilities/engine.py +11 -4
- prefect/utilities/filesystem.py +2 -2
- prefect/utilities/generics.py +1 -1
- prefect/utilities/pydantic.py +21 -36
- prefect/workers/base.py +52 -30
- prefect/workers/process.py +20 -15
- prefect/workers/server.py +4 -5
- {prefect_client-3.1.12.dist-info → prefect_client-3.1.13.dist-info}/METADATA +2 -2
- {prefect_client-3.1.12.dist-info → prefect_client-3.1.13.dist-info}/RECORD +105 -103
- {prefect_client-3.1.12.dist-info → prefect_client-3.1.13.dist-info}/LICENSE +0 -0
- {prefect_client-3.1.12.dist-info → prefect_client-3.1.13.dist-info}/WHEEL +0 -0
- {prefect_client-3.1.12.dist-info → prefect_client-3.1.13.dist-info}/top_level.txt +0 -0
prefect/task_engine.py
CHANGED
@@ -1,3 +1,5 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
1
3
|
import asyncio
|
2
4
|
import inspect
|
3
5
|
import logging
|
@@ -28,8 +30,9 @@ from uuid import UUID
|
|
28
30
|
import anyio
|
29
31
|
import pendulum
|
30
32
|
from opentelemetry import trace
|
31
|
-
from typing_extensions import ParamSpec
|
33
|
+
from typing_extensions import ParamSpec, Self
|
32
34
|
|
35
|
+
from prefect.cache_policies import CachePolicy
|
33
36
|
from prefect.client.orchestration import PrefectClient, SyncPrefectClient, get_client
|
34
37
|
from prefect.client.schemas import TaskRun
|
35
38
|
from prefect.client.schemas.objects import State, TaskRunInput
|
@@ -55,7 +58,7 @@ from prefect.exceptions import (
|
|
55
58
|
from prefect.logging.loggers import get_logger, patch_print, task_run_logger
|
56
59
|
from prefect.results import (
|
57
60
|
ResultRecord,
|
58
|
-
_format_user_supplied_storage_key,
|
61
|
+
_format_user_supplied_storage_key, # type: ignore[reportPrivateUsage]
|
59
62
|
get_result_store,
|
60
63
|
should_persist_result,
|
61
64
|
)
|
@@ -115,20 +118,20 @@ class BaseTaskRunEngine(Generic[P, R]):
|
|
115
118
|
# holds the return value from the user code
|
116
119
|
_return_value: Union[R, Type[NotSet]] = NotSet
|
117
120
|
# holds the exception raised by the user code, if any
|
118
|
-
_raised: Union[Exception, Type[NotSet]] = NotSet
|
121
|
+
_raised: Union[Exception, BaseException, Type[NotSet]] = NotSet
|
119
122
|
_initial_run_context: Optional[TaskRunContext] = None
|
120
123
|
_is_started: bool = False
|
121
124
|
_task_name_set: bool = False
|
122
125
|
_last_event: Optional[PrefectEvent] = None
|
123
126
|
_telemetry: RunTelemetry = field(default_factory=RunTelemetry)
|
124
127
|
|
125
|
-
def __post_init__(self):
|
128
|
+
def __post_init__(self) -> None:
|
126
129
|
if self.parameters is None:
|
127
130
|
self.parameters = {}
|
128
131
|
|
129
132
|
@property
|
130
133
|
def state(self) -> State:
|
131
|
-
if not self.task_run:
|
134
|
+
if not self.task_run or not self.task_run.state:
|
132
135
|
raise ValueError("Task run is not set")
|
133
136
|
return self.task_run.state
|
134
137
|
|
@@ -142,8 +145,8 @@ class BaseTaskRunEngine(Generic[P, R]):
|
|
142
145
|
return False
|
143
146
|
|
144
147
|
def compute_transaction_key(self) -> Optional[str]:
|
145
|
-
key = None
|
146
|
-
if self.task.cache_policy:
|
148
|
+
key: Optional[str] = None
|
149
|
+
if self.task.cache_policy and isinstance(self.task.cache_policy, CachePolicy):
|
147
150
|
flow_run_context = FlowRunContext.get()
|
148
151
|
task_run_context = TaskRunContext.get()
|
149
152
|
|
@@ -153,10 +156,12 @@ class BaseTaskRunEngine(Generic[P, R]):
|
|
153
156
|
parameters = None
|
154
157
|
|
155
158
|
try:
|
159
|
+
if not task_run_context:
|
160
|
+
raise ValueError("Task run context is not set")
|
156
161
|
key = self.task.cache_policy.compute_key(
|
157
162
|
task_ctx=task_run_context,
|
158
|
-
inputs=self.parameters,
|
159
|
-
flow_parameters=parameters,
|
163
|
+
inputs=self.parameters or {},
|
164
|
+
flow_parameters=parameters or {},
|
160
165
|
)
|
161
166
|
except Exception:
|
162
167
|
self.logger.exception(
|
@@ -169,7 +174,7 @@ class BaseTaskRunEngine(Generic[P, R]):
|
|
169
174
|
|
170
175
|
def _resolve_parameters(self):
|
171
176
|
if not self.parameters:
|
172
|
-
return
|
177
|
+
return None
|
173
178
|
|
174
179
|
resolved_parameters = {}
|
175
180
|
for parameter, value in self.parameters.items():
|
@@ -227,10 +232,8 @@ class BaseTaskRunEngine(Generic[P, R]):
|
|
227
232
|
if self.task_run and self.task_run.start_time and not self.task_run.end_time:
|
228
233
|
self.task_run.end_time = state.timestamp
|
229
234
|
|
230
|
-
if self.
|
231
|
-
self.task_run.total_run_time +=
|
232
|
-
state.timestamp - self.task_run.state.timestamp
|
233
|
-
)
|
235
|
+
if self.state.is_running():
|
236
|
+
self.task_run.total_run_time += state.timestamp - self.state.timestamp
|
234
237
|
|
235
238
|
def is_running(self) -> bool:
|
236
239
|
"""Whether or not the engine is currently running a task."""
|
@@ -238,7 +241,7 @@ class BaseTaskRunEngine(Generic[P, R]):
|
|
238
241
|
return False
|
239
242
|
return task_run.state.is_running() or task_run.state.is_scheduled()
|
240
243
|
|
241
|
-
def log_finished_message(self):
|
244
|
+
def log_finished_message(self) -> None:
|
242
245
|
if not self.task_run:
|
243
246
|
return
|
244
247
|
|
@@ -294,6 +297,7 @@ class BaseTaskRunEngine(Generic[P, R]):
|
|
294
297
|
|
295
298
|
@dataclass
|
296
299
|
class SyncTaskRunEngine(BaseTaskRunEngine[P, R]):
|
300
|
+
task_run: Optional[TaskRun] = None
|
297
301
|
_client: Optional[SyncPrefectClient] = None
|
298
302
|
|
299
303
|
@property
|
@@ -336,7 +340,7 @@ class SyncTaskRunEngine(BaseTaskRunEngine[P, R]):
|
|
336
340
|
)
|
337
341
|
return False
|
338
342
|
|
339
|
-
def call_hooks(self, state: Optional[State] = None):
|
343
|
+
def call_hooks(self, state: Optional[State] = None) -> None:
|
340
344
|
if state is None:
|
341
345
|
state = self.state
|
342
346
|
task = self.task
|
@@ -371,7 +375,7 @@ class SyncTaskRunEngine(BaseTaskRunEngine[P, R]):
|
|
371
375
|
else:
|
372
376
|
self.logger.info(f"Hook {hook_name!r} finished running successfully")
|
373
377
|
|
374
|
-
def begin_run(self):
|
378
|
+
def begin_run(self) -> None:
|
375
379
|
try:
|
376
380
|
self._resolve_parameters()
|
377
381
|
self._set_custom_task_run_name()
|
@@ -390,6 +394,7 @@ class SyncTaskRunEngine(BaseTaskRunEngine[P, R]):
|
|
390
394
|
|
391
395
|
new_state = Running()
|
392
396
|
|
397
|
+
assert self.task_run is not None, "Task run is not set"
|
393
398
|
self.task_run.start_time = new_state.timestamp
|
394
399
|
|
395
400
|
flow_run_context = FlowRunContext.get()
|
@@ -406,7 +411,7 @@ class SyncTaskRunEngine(BaseTaskRunEngine[P, R]):
|
|
406
411
|
# result reference that no longer exists
|
407
412
|
if state.is_completed():
|
408
413
|
try:
|
409
|
-
state.result(retry_result_failure=False, _sync=True)
|
414
|
+
state.result(retry_result_failure=False, _sync=True) # type: ignore[reportCallIssue]
|
410
415
|
except Exception:
|
411
416
|
state = self.set_state(new_state, force=True)
|
412
417
|
|
@@ -422,7 +427,7 @@ class SyncTaskRunEngine(BaseTaskRunEngine[P, R]):
|
|
422
427
|
time.sleep(interval)
|
423
428
|
state = self.set_state(new_state)
|
424
429
|
|
425
|
-
def set_state(self, state: State, force: bool = False) -> State:
|
430
|
+
def set_state(self, state: State[R], force: bool = False) -> State[R]:
|
426
431
|
last_state = self.state
|
427
432
|
if not self.task_run:
|
428
433
|
raise ValueError("Task run is not set")
|
@@ -537,7 +542,7 @@ class SyncTaskRunEngine(BaseTaskRunEngine[P, R]):
|
|
537
542
|
new_state = Retrying()
|
538
543
|
|
539
544
|
self.logger.info(
|
540
|
-
"Task run failed with exception: %r -
|
545
|
+
"Task run failed with exception: %r - Retry %s/%s will start %s",
|
541
546
|
exc,
|
542
547
|
self.retries + 1,
|
543
548
|
self.task.retries,
|
@@ -545,7 +550,7 @@ class SyncTaskRunEngine(BaseTaskRunEngine[P, R]):
|
|
545
550
|
)
|
546
551
|
|
547
552
|
self.set_state(new_state, force=True)
|
548
|
-
self.retries = self.retries + 1
|
553
|
+
self.retries: int = self.retries + 1
|
549
554
|
return True
|
550
555
|
elif self.retries >= self.task.retries:
|
551
556
|
self.logger.error(
|
@@ -639,7 +644,9 @@ class SyncTaskRunEngine(BaseTaskRunEngine[P, R]):
|
|
639
644
|
stack.enter_context(ConcurrencyContextV1())
|
640
645
|
stack.enter_context(ConcurrencyContext())
|
641
646
|
|
642
|
-
self.logger = task_run_logger(
|
647
|
+
self.logger: "logging.Logger" = task_run_logger(
|
648
|
+
task_run=self.task_run, task=self.task
|
649
|
+
) # type: ignore
|
643
650
|
|
644
651
|
yield
|
645
652
|
|
@@ -648,7 +655,7 @@ class SyncTaskRunEngine(BaseTaskRunEngine[P, R]):
|
|
648
655
|
self,
|
649
656
|
task_run_id: Optional[UUID] = None,
|
650
657
|
dependencies: Optional[dict[str, set[TaskRunInput]]] = None,
|
651
|
-
) -> Generator[
|
658
|
+
) -> Generator[Self, Any, Any]:
|
652
659
|
"""
|
653
660
|
Enters a client context and creates a task run if needed.
|
654
661
|
"""
|
@@ -718,7 +725,7 @@ class SyncTaskRunEngine(BaseTaskRunEngine[P, R]):
|
|
718
725
|
self._is_started = False
|
719
726
|
self._client = None
|
720
727
|
|
721
|
-
async def wait_until_ready(self):
|
728
|
+
async def wait_until_ready(self) -> None:
|
722
729
|
"""Waits until the scheduled time (if its the future), then enters Running."""
|
723
730
|
if scheduled_time := self.state.state_details.scheduled_time:
|
724
731
|
sleep_time = (scheduled_time - pendulum.now("utc")).total_seconds()
|
@@ -825,6 +832,7 @@ class SyncTaskRunEngine(BaseTaskRunEngine[P, R]):
|
|
825
832
|
|
826
833
|
@dataclass
|
827
834
|
class AsyncTaskRunEngine(BaseTaskRunEngine[P, R]):
|
835
|
+
task_run: TaskRun | None = None
|
828
836
|
_client: Optional[PrefectClient] = None
|
829
837
|
|
830
838
|
@property
|
@@ -866,7 +874,7 @@ class AsyncTaskRunEngine(BaseTaskRunEngine[P, R]):
|
|
866
874
|
)
|
867
875
|
return False
|
868
876
|
|
869
|
-
async def call_hooks(self, state: Optional[State] = None):
|
877
|
+
async def call_hooks(self, state: Optional[State] = None) -> None:
|
870
878
|
if state is None:
|
871
879
|
state = self.state
|
872
880
|
task = self.task
|
@@ -901,7 +909,7 @@ class AsyncTaskRunEngine(BaseTaskRunEngine[P, R]):
|
|
901
909
|
else:
|
902
910
|
self.logger.info(f"Hook {hook_name!r} finished running successfully")
|
903
911
|
|
904
|
-
async def begin_run(self):
|
912
|
+
async def begin_run(self) -> None:
|
905
913
|
try:
|
906
914
|
self._resolve_parameters()
|
907
915
|
self._set_custom_task_run_name()
|
@@ -1067,7 +1075,7 @@ class AsyncTaskRunEngine(BaseTaskRunEngine[P, R]):
|
|
1067
1075
|
new_state = Retrying()
|
1068
1076
|
|
1069
1077
|
self.logger.info(
|
1070
|
-
"Task run failed with exception: %r -
|
1078
|
+
"Task run failed with exception: %r - Retry %s/%s will start %s",
|
1071
1079
|
exc,
|
1072
1080
|
self.retries + 1,
|
1073
1081
|
self.task.retries,
|
@@ -1075,7 +1083,7 @@ class AsyncTaskRunEngine(BaseTaskRunEngine[P, R]):
|
|
1075
1083
|
)
|
1076
1084
|
|
1077
1085
|
await self.set_state(new_state, force=True)
|
1078
|
-
self.retries = self.retries + 1
|
1086
|
+
self.retries: int = self.retries + 1
|
1079
1087
|
return True
|
1080
1088
|
elif self.retries >= self.task.retries:
|
1081
1089
|
self.logger.error(
|
@@ -1169,7 +1177,9 @@ class AsyncTaskRunEngine(BaseTaskRunEngine[P, R]):
|
|
1169
1177
|
)
|
1170
1178
|
stack.enter_context(ConcurrencyContext())
|
1171
1179
|
|
1172
|
-
self.logger = task_run_logger(
|
1180
|
+
self.logger: "logging.Logger" = task_run_logger(
|
1181
|
+
task_run=self.task_run, task=self.task
|
1182
|
+
) # type: ignore
|
1173
1183
|
|
1174
1184
|
yield
|
1175
1185
|
|
@@ -1178,7 +1188,7 @@ class AsyncTaskRunEngine(BaseTaskRunEngine[P, R]):
|
|
1178
1188
|
self,
|
1179
1189
|
task_run_id: Optional[UUID] = None,
|
1180
1190
|
dependencies: Optional[dict[str, set[TaskRunInput]]] = None,
|
1181
|
-
) -> AsyncGenerator[
|
1191
|
+
) -> AsyncGenerator[Self, Any]:
|
1182
1192
|
"""
|
1183
1193
|
Enters a client context and creates a task run if needed.
|
1184
1194
|
"""
|
@@ -1246,7 +1256,7 @@ class AsyncTaskRunEngine(BaseTaskRunEngine[P, R]):
|
|
1246
1256
|
self._is_started = False
|
1247
1257
|
self._client = None
|
1248
1258
|
|
1249
|
-
async def wait_until_ready(self):
|
1259
|
+
async def wait_until_ready(self) -> None:
|
1250
1260
|
"""Waits until the scheduled time (if its the future), then enters Running."""
|
1251
1261
|
if scheduled_time := self.state.state_details.scheduled_time:
|
1252
1262
|
sleep_time = (scheduled_time - pendulum.now("utc")).total_seconds()
|
@@ -1341,7 +1351,7 @@ class AsyncTaskRunEngine(BaseTaskRunEngine[P, R]):
|
|
1341
1351
|
if transaction.is_committed():
|
1342
1352
|
result = transaction.read()
|
1343
1353
|
else:
|
1344
|
-
if self.task_run.tags:
|
1354
|
+
if self.task_run and self.task_run.tags:
|
1345
1355
|
# Acquire a concurrency slot for each tag, but only if a limit
|
1346
1356
|
# matching the tag already exists.
|
1347
1357
|
async with aconcurrency(list(self.task_run.tags), self.task_run.id):
|
@@ -1546,7 +1556,7 @@ def run_task(
|
|
1546
1556
|
Returns:
|
1547
1557
|
The result of the task run
|
1548
1558
|
"""
|
1549
|
-
kwargs = dict(
|
1559
|
+
kwargs: dict[str, Any] = dict(
|
1550
1560
|
task=task,
|
1551
1561
|
task_run_id=task_run_id,
|
1552
1562
|
task_run=task_run,
|
prefect/task_runners.py
CHANGED
@@ -40,6 +40,8 @@ from prefect.utilities.callables import (
|
|
40
40
|
from prefect.utilities.collections import isiterable
|
41
41
|
|
42
42
|
if TYPE_CHECKING:
|
43
|
+
import logging
|
44
|
+
|
43
45
|
from prefect.tasks import Task
|
44
46
|
|
45
47
|
P = ParamSpec("P")
|
@@ -61,11 +63,11 @@ class TaskRunner(abc.ABC, Generic[F]):
|
|
61
63
|
"""
|
62
64
|
|
63
65
|
def __init__(self):
|
64
|
-
self.logger = get_logger(f"task_runner.{self.name}")
|
66
|
+
self.logger: "logging.Logger" = get_logger(f"task_runner.{self.name}")
|
65
67
|
self._started = False
|
66
68
|
|
67
69
|
@property
|
68
|
-
def name(self):
|
70
|
+
def name(self) -> str:
|
69
71
|
"""The name of this task runner"""
|
70
72
|
return type(self).__name__.lower().replace("taskrunner", "")
|
71
73
|
|
@@ -74,32 +76,42 @@ class TaskRunner(abc.ABC, Generic[F]):
|
|
74
76
|
"""Return a new instance of this task runner with the same configuration."""
|
75
77
|
...
|
76
78
|
|
79
|
+
@overload
|
77
80
|
@abc.abstractmethod
|
78
81
|
def submit(
|
79
82
|
self,
|
80
|
-
task: "Task[P, R]",
|
83
|
+
task: "Task[P, Coroutine[Any, Any, R]]",
|
81
84
|
parameters: dict[str, Any],
|
82
85
|
wait_for: Iterable[PrefectFuture[Any]] | None = None,
|
83
86
|
dependencies: dict[str, set[TaskRunInput]] | None = None,
|
84
87
|
) -> F:
|
85
|
-
|
86
|
-
Submit a task to the task run engine.
|
88
|
+
...
|
87
89
|
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
90
|
+
@overload
|
91
|
+
@abc.abstractmethod
|
92
|
+
def submit(
|
93
|
+
self,
|
94
|
+
task: "Task[Any, R]",
|
95
|
+
parameters: dict[str, Any],
|
96
|
+
wait_for: Iterable[PrefectFuture[Any]] | None = None,
|
97
|
+
dependencies: dict[str, set[TaskRunInput]] | None = None,
|
98
|
+
) -> F:
|
99
|
+
...
|
92
100
|
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
""
|
101
|
+
@abc.abstractmethod
|
102
|
+
def submit(
|
103
|
+
self,
|
104
|
+
task: "Task[P, R]",
|
105
|
+
parameters: dict[str, Any],
|
106
|
+
wait_for: Iterable[PrefectFuture[Any]] | None = None,
|
107
|
+
dependencies: dict[str, set[TaskRunInput]] | None = None,
|
108
|
+
) -> F:
|
97
109
|
...
|
98
110
|
|
99
111
|
def map(
|
100
112
|
self,
|
101
113
|
task: "Task[P, R]",
|
102
|
-
parameters: dict[str, Any],
|
114
|
+
parameters: dict[str, Any | unmapped[Any] | allow_failure[Any]],
|
103
115
|
wait_for: Optional[Iterable[PrefectFuture[R]]] = None,
|
104
116
|
) -> PrefectFutureList[F]:
|
105
117
|
"""
|
@@ -205,7 +217,7 @@ class TaskRunner(abc.ABC, Generic[F]):
|
|
205
217
|
|
206
218
|
return PrefectFutureList(futures)
|
207
219
|
|
208
|
-
def __enter__(self):
|
220
|
+
def __enter__(self) -> Self:
|
209
221
|
if self._started:
|
210
222
|
raise RuntimeError("This task runner is already started")
|
211
223
|
|
@@ -213,12 +225,12 @@ class TaskRunner(abc.ABC, Generic[F]):
|
|
213
225
|
self._started = True
|
214
226
|
return self
|
215
227
|
|
216
|
-
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any):
|
228
|
+
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
|
217
229
|
self.logger.debug("Stopping task runner")
|
218
230
|
self._started = False
|
219
231
|
|
220
232
|
|
221
|
-
class ThreadPoolTaskRunner(TaskRunner[PrefectConcurrentFuture[
|
233
|
+
class ThreadPoolTaskRunner(TaskRunner[PrefectConcurrentFuture[R]]):
|
222
234
|
def __init__(self, max_workers: Optional[int] = None):
|
223
235
|
super().__init__()
|
224
236
|
self._executor: Optional[ThreadPoolExecutor] = None
|
@@ -229,7 +241,7 @@ class ThreadPoolTaskRunner(TaskRunner[PrefectConcurrentFuture[Any]]):
|
|
229
241
|
)
|
230
242
|
self._cancel_events: Dict[uuid.UUID, threading.Event] = {}
|
231
243
|
|
232
|
-
def duplicate(self) -> "ThreadPoolTaskRunner":
|
244
|
+
def duplicate(self) -> "ThreadPoolTaskRunner[R]":
|
233
245
|
return type(self)(max_workers=self._max_workers)
|
234
246
|
|
235
247
|
@overload
|
@@ -254,7 +266,7 @@ class ThreadPoolTaskRunner(TaskRunner[PrefectConcurrentFuture[Any]]):
|
|
254
266
|
|
255
267
|
def submit(
|
256
268
|
self,
|
257
|
-
task: "Task[P, R]",
|
269
|
+
task: "Task[P, R | Coroutine[Any, Any, R]]",
|
258
270
|
parameters: dict[str, Any],
|
259
271
|
wait_for: Iterable[PrefectFuture[Any]] | None = None,
|
260
272
|
dependencies: dict[str, set[TaskRunInput]] | None = None,
|
@@ -345,7 +357,7 @@ class ThreadPoolTaskRunner(TaskRunner[PrefectConcurrentFuture[Any]]):
|
|
345
357
|
) -> PrefectFutureList[PrefectConcurrentFuture[R]]:
|
346
358
|
return super().map(task, parameters, wait_for)
|
347
359
|
|
348
|
-
def cancel_all(self):
|
360
|
+
def cancel_all(self) -> None:
|
349
361
|
for event in self._cancel_events.values():
|
350
362
|
event.set()
|
351
363
|
self.logger.debug("Set cancel event")
|
@@ -354,12 +366,12 @@ class ThreadPoolTaskRunner(TaskRunner[PrefectConcurrentFuture[Any]]):
|
|
354
366
|
self._executor.shutdown(cancel_futures=True)
|
355
367
|
self._executor = None
|
356
368
|
|
357
|
-
def __enter__(self):
|
369
|
+
def __enter__(self) -> Self:
|
358
370
|
super().__enter__()
|
359
371
|
self._executor = ThreadPoolExecutor(max_workers=self._max_workers)
|
360
372
|
return self
|
361
373
|
|
362
|
-
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any):
|
374
|
+
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
|
363
375
|
self.cancel_all()
|
364
376
|
if self._executor is not None:
|
365
377
|
self._executor.shutdown(cancel_futures=True)
|
@@ -380,7 +392,7 @@ class PrefectTaskRunner(TaskRunner[PrefectDistributedFuture[R]]):
|
|
380
392
|
def __init__(self):
|
381
393
|
super().__init__()
|
382
394
|
|
383
|
-
def duplicate(self) -> "PrefectTaskRunner":
|
395
|
+
def duplicate(self) -> "PrefectTaskRunner[R]":
|
384
396
|
return type(self)()
|
385
397
|
|
386
398
|
@overload
|
prefect/task_runs.py
CHANGED
@@ -1,8 +1,10 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
1
3
|
import asyncio
|
2
4
|
import atexit
|
3
5
|
import threading
|
4
6
|
import uuid
|
5
|
-
from typing import Callable, Dict, Optional
|
7
|
+
from typing import TYPE_CHECKING, Callable, Dict, Optional
|
6
8
|
|
7
9
|
import anyio
|
8
10
|
from cachetools import TTLCache
|
@@ -15,6 +17,9 @@ from prefect.events.clients import get_events_subscriber
|
|
15
17
|
from prefect.events.filters import EventFilter, EventNameFilter
|
16
18
|
from prefect.logging.loggers import get_logger
|
17
19
|
|
20
|
+
if TYPE_CHECKING:
|
21
|
+
import logging
|
22
|
+
|
18
23
|
|
19
24
|
class TaskRunWaiter:
|
20
25
|
"""
|
@@ -68,19 +73,19 @@ class TaskRunWaiter:
|
|
68
73
|
_instance_lock = threading.Lock()
|
69
74
|
|
70
75
|
def __init__(self):
|
71
|
-
self.logger = get_logger("TaskRunWaiter")
|
72
|
-
self._consumer_task:
|
76
|
+
self.logger: "logging.Logger" = get_logger("TaskRunWaiter")
|
77
|
+
self._consumer_task: "asyncio.Task[None] | None" = None
|
73
78
|
self._observed_completed_task_runs: TTLCache[uuid.UUID, bool] = TTLCache(
|
74
79
|
maxsize=10000, ttl=600
|
75
80
|
)
|
76
81
|
self._completion_events: Dict[uuid.UUID, asyncio.Event] = {}
|
77
|
-
self._completion_callbacks: Dict[uuid.UUID, Callable] = {}
|
82
|
+
self._completion_callbacks: Dict[uuid.UUID, Callable[[], None]] = {}
|
78
83
|
self._loop: Optional[asyncio.AbstractEventLoop] = None
|
79
84
|
self._observed_completed_task_runs_lock = threading.Lock()
|
80
85
|
self._completion_events_lock = threading.Lock()
|
81
86
|
self._started = False
|
82
87
|
|
83
|
-
def start(self):
|
88
|
+
def start(self) -> None:
|
84
89
|
"""
|
85
90
|
Start the TaskRunWaiter service.
|
86
91
|
"""
|
@@ -89,10 +94,12 @@ class TaskRunWaiter:
|
|
89
94
|
self.logger.debug("Starting TaskRunWaiter")
|
90
95
|
loop_thread = get_global_loop()
|
91
96
|
|
92
|
-
if not asyncio.get_running_loop() == loop_thread.
|
97
|
+
if not asyncio.get_running_loop() == loop_thread.loop:
|
93
98
|
raise RuntimeError("TaskRunWaiter must run on the global loop thread.")
|
94
99
|
|
95
|
-
self._loop = loop_thread.
|
100
|
+
self._loop = loop_thread.loop
|
101
|
+
if TYPE_CHECKING:
|
102
|
+
assert self._loop is not None
|
96
103
|
|
97
104
|
consumer_started = asyncio.Event()
|
98
105
|
self._consumer_task = self._loop.create_task(
|
@@ -141,7 +148,7 @@ class TaskRunWaiter:
|
|
141
148
|
except Exception as exc:
|
142
149
|
self.logger.error(f"Error processing event: {exc}")
|
143
150
|
|
144
|
-
def stop(self):
|
151
|
+
def stop(self) -> None:
|
145
152
|
"""
|
146
153
|
Stop the TaskRunWaiter service.
|
147
154
|
"""
|
@@ -155,7 +162,7 @@ class TaskRunWaiter:
|
|
155
162
|
@classmethod
|
156
163
|
async def wait_for_task_run(
|
157
164
|
cls, task_run_id: uuid.UUID, timeout: Optional[float] = None
|
158
|
-
):
|
165
|
+
) -> None:
|
159
166
|
"""
|
160
167
|
Wait for a task run to finish.
|
161
168
|
|
@@ -199,7 +206,9 @@ class TaskRunWaiter:
|
|
199
206
|
instance._completion_events.pop(task_run_id, None)
|
200
207
|
|
201
208
|
@classmethod
|
202
|
-
def add_done_callback(
|
209
|
+
def add_done_callback(
|
210
|
+
cls, task_run_id: uuid.UUID, callback: Callable[[], None]
|
211
|
+
) -> None:
|
203
212
|
"""
|
204
213
|
Add a callback to be called when a task run finishes.
|
205
214
|
|
@@ -219,7 +228,7 @@ class TaskRunWaiter:
|
|
219
228
|
instance._completion_callbacks[task_run_id] = callback
|
220
229
|
|
221
230
|
@classmethod
|
222
|
-
def instance(cls):
|
231
|
+
def instance(cls) -> Self:
|
223
232
|
"""
|
224
233
|
Get the singleton instance of TaskRunWaiter.
|
225
234
|
"""
|
prefect/task_worker.py
CHANGED
@@ -1,3 +1,5 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
1
3
|
import asyncio
|
2
4
|
import inspect
|
3
5
|
import os
|
@@ -16,7 +18,7 @@ import pendulum
|
|
16
18
|
import uvicorn
|
17
19
|
from exceptiongroup import BaseExceptionGroup # novermin
|
18
20
|
from fastapi import FastAPI
|
19
|
-
from typing_extensions import ParamSpec, TypeVar
|
21
|
+
from typing_extensions import ParamSpec, Self, TypeVar
|
20
22
|
from websockets.exceptions import InvalidStatusCode
|
21
23
|
|
22
24
|
from prefect import Task
|
@@ -42,7 +44,10 @@ from prefect.utilities.processutils import (
|
|
42
44
|
from prefect.utilities.services import start_client_metrics_server
|
43
45
|
from prefect.utilities.urls import url_for
|
44
46
|
|
45
|
-
|
47
|
+
if TYPE_CHECKING:
|
48
|
+
import logging
|
49
|
+
|
50
|
+
logger: "logging.Logger" = get_logger("task_worker")
|
46
51
|
|
47
52
|
P = ParamSpec("P")
|
48
53
|
R = TypeVar("R", infer_variance=True)
|
@@ -85,7 +90,7 @@ class TaskWorker:
|
|
85
90
|
def __init__(
|
86
91
|
self,
|
87
92
|
*tasks: Task[P, R],
|
88
|
-
limit:
|
93
|
+
limit: int | None = 10,
|
89
94
|
):
|
90
95
|
self.tasks: list["Task[..., Any]"] = []
|
91
96
|
for t in tasks:
|
@@ -100,7 +105,7 @@ class TaskWorker:
|
|
100
105
|
else:
|
101
106
|
self.tasks.append(t.with_options(persist_result=True))
|
102
107
|
|
103
|
-
self.task_keys = set(t.task_key for t in tasks if isinstance(t, Task)) # pyright: ignore[reportUnnecessaryIsInstance]
|
108
|
+
self.task_keys: set[str] = set(t.task_key for t in tasks if isinstance(t, Task)) # pyright: ignore[reportUnnecessaryIsInstance]
|
104
109
|
|
105
110
|
self._started_at: Optional[pendulum.DateTime] = None
|
106
111
|
self.stopping: bool = False
|
@@ -154,7 +159,7 @@ class TaskWorker:
|
|
154
159
|
def available_tasks(self) -> Optional[int]:
|
155
160
|
return int(self._limiter.available_tokens) if self._limiter else None
|
156
161
|
|
157
|
-
def handle_sigterm(self, signum: int, frame: object):
|
162
|
+
def handle_sigterm(self, signum: int, frame: object) -> None:
|
158
163
|
"""
|
159
164
|
Shuts down the task worker when a SIGTERM is received.
|
160
165
|
"""
|
@@ -355,14 +360,14 @@ class TaskWorker:
|
|
355
360
|
)
|
356
361
|
await asyncio.wrap_future(future)
|
357
362
|
|
358
|
-
async def execute_task_run(self, task_run: TaskRun):
|
363
|
+
async def execute_task_run(self, task_run: TaskRun) -> None:
|
359
364
|
"""Execute a task run in the task worker."""
|
360
365
|
async with self if not self.started else asyncnullcontext():
|
361
366
|
token_acquired = await self._acquire_token(task_run.id)
|
362
367
|
if token_acquired:
|
363
368
|
await self._safe_submit_scheduled_task_run(task_run)
|
364
369
|
|
365
|
-
async def __aenter__(self):
|
370
|
+
async def __aenter__(self) -> Self:
|
366
371
|
logger.debug("Starting task worker...")
|
367
372
|
|
368
373
|
if self._client._closed: # pyright: ignore[reportPrivateUsage]
|