prefect-client 3.1.12__py3-none-any.whl → 3.1.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (105)
  1. prefect/_experimental/sla/client.py +53 -27
  2. prefect/_experimental/sla/objects.py +10 -2
  3. prefect/_internal/concurrency/services.py +2 -2
  4. prefect/_internal/concurrency/threads.py +6 -0
  5. prefect/_internal/retries.py +6 -3
  6. prefect/_internal/schemas/validators.py +6 -4
  7. prefect/_version.py +3 -3
  8. prefect/artifacts.py +4 -1
  9. prefect/automations.py +1 -1
  10. prefect/blocks/abstract.py +5 -2
  11. prefect/blocks/notifications.py +1 -0
  12. prefect/cache_policies.py +20 -20
  13. prefect/client/utilities.py +3 -3
  14. prefect/deployments/base.py +7 -4
  15. prefect/deployments/flow_runs.py +5 -1
  16. prefect/deployments/runner.py +6 -11
  17. prefect/deployments/steps/core.py +1 -1
  18. prefect/deployments/steps/pull.py +8 -3
  19. prefect/deployments/steps/utility.py +2 -2
  20. prefect/docker/docker_image.py +13 -9
  21. prefect/engine.py +19 -10
  22. prefect/events/cli/automations.py +4 -4
  23. prefect/events/clients.py +17 -14
  24. prefect/events/schemas/automations.py +12 -8
  25. prefect/events/schemas/events.py +5 -1
  26. prefect/events/worker.py +1 -1
  27. prefect/filesystems.py +1 -1
  28. prefect/flow_engine.py +17 -9
  29. prefect/flows.py +118 -73
  30. prefect/futures.py +14 -7
  31. prefect/infrastructure/provisioners/__init__.py +2 -0
  32. prefect/infrastructure/provisioners/cloud_run.py +4 -4
  33. prefect/infrastructure/provisioners/coiled.py +249 -0
  34. prefect/infrastructure/provisioners/container_instance.py +4 -3
  35. prefect/infrastructure/provisioners/ecs.py +55 -43
  36. prefect/infrastructure/provisioners/modal.py +5 -4
  37. prefect/input/actions.py +5 -1
  38. prefect/input/run_input.py +157 -43
  39. prefect/logging/configuration.py +3 -3
  40. prefect/logging/filters.py +2 -2
  41. prefect/logging/formatters.py +15 -11
  42. prefect/logging/handlers.py +24 -14
  43. prefect/logging/highlighters.py +5 -5
  44. prefect/logging/loggers.py +28 -18
  45. prefect/main.py +3 -1
  46. prefect/results.py +166 -86
  47. prefect/runner/runner.py +34 -27
  48. prefect/runner/server.py +3 -1
  49. prefect/runner/storage.py +18 -18
  50. prefect/runner/submit.py +19 -12
  51. prefect/runtime/deployment.py +15 -8
  52. prefect/runtime/flow_run.py +19 -6
  53. prefect/runtime/task_run.py +7 -3
  54. prefect/settings/base.py +17 -7
  55. prefect/settings/legacy.py +4 -4
  56. prefect/settings/models/api.py +4 -3
  57. prefect/settings/models/cli.py +4 -3
  58. prefect/settings/models/client.py +7 -4
  59. prefect/settings/models/cloud.py +4 -3
  60. prefect/settings/models/deployments.py +4 -3
  61. prefect/settings/models/experiments.py +4 -3
  62. prefect/settings/models/flows.py +4 -3
  63. prefect/settings/models/internal.py +4 -3
  64. prefect/settings/models/logging.py +8 -6
  65. prefect/settings/models/results.py +4 -3
  66. prefect/settings/models/root.py +11 -16
  67. prefect/settings/models/runner.py +8 -5
  68. prefect/settings/models/server/api.py +6 -3
  69. prefect/settings/models/server/database.py +120 -25
  70. prefect/settings/models/server/deployments.py +4 -3
  71. prefect/settings/models/server/ephemeral.py +7 -4
  72. prefect/settings/models/server/events.py +6 -3
  73. prefect/settings/models/server/flow_run_graph.py +4 -3
  74. prefect/settings/models/server/root.py +4 -3
  75. prefect/settings/models/server/services.py +15 -12
  76. prefect/settings/models/server/tasks.py +7 -4
  77. prefect/settings/models/server/ui.py +4 -3
  78. prefect/settings/models/tasks.py +10 -5
  79. prefect/settings/models/testing.py +4 -3
  80. prefect/settings/models/worker.py +7 -4
  81. prefect/settings/profiles.py +13 -12
  82. prefect/settings/sources.py +20 -19
  83. prefect/states.py +17 -13
  84. prefect/task_engine.py +43 -33
  85. prefect/task_runners.py +35 -23
  86. prefect/task_runs.py +20 -11
  87. prefect/task_worker.py +12 -7
  88. prefect/tasks.py +30 -24
  89. prefect/telemetry/bootstrap.py +4 -1
  90. prefect/telemetry/run_telemetry.py +15 -13
  91. prefect/transactions.py +3 -3
  92. prefect/types/__init__.py +3 -1
  93. prefect/utilities/_deprecated.py +38 -0
  94. prefect/utilities/engine.py +11 -4
  95. prefect/utilities/filesystem.py +2 -2
  96. prefect/utilities/generics.py +1 -1
  97. prefect/utilities/pydantic.py +21 -36
  98. prefect/workers/base.py +52 -30
  99. prefect/workers/process.py +20 -15
  100. prefect/workers/server.py +4 -5
  101. {prefect_client-3.1.12.dist-info → prefect_client-3.1.13.dist-info}/METADATA +2 -2
  102. {prefect_client-3.1.12.dist-info → prefect_client-3.1.13.dist-info}/RECORD +105 -103
  103. {prefect_client-3.1.12.dist-info → prefect_client-3.1.13.dist-info}/LICENSE +0 -0
  104. {prefect_client-3.1.12.dist-info → prefect_client-3.1.13.dist-info}/WHEEL +0 -0
  105. {prefect_client-3.1.12.dist-info → prefect_client-3.1.13.dist-info}/top_level.txt +0 -0
prefect/task_engine.py CHANGED
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import asyncio
2
4
  import inspect
3
5
  import logging
@@ -28,8 +30,9 @@ from uuid import UUID
28
30
  import anyio
29
31
  import pendulum
30
32
  from opentelemetry import trace
31
- from typing_extensions import ParamSpec
33
+ from typing_extensions import ParamSpec, Self
32
34
 
35
+ from prefect.cache_policies import CachePolicy
33
36
  from prefect.client.orchestration import PrefectClient, SyncPrefectClient, get_client
34
37
  from prefect.client.schemas import TaskRun
35
38
  from prefect.client.schemas.objects import State, TaskRunInput
@@ -55,7 +58,7 @@ from prefect.exceptions import (
55
58
  from prefect.logging.loggers import get_logger, patch_print, task_run_logger
56
59
  from prefect.results import (
57
60
  ResultRecord,
58
- _format_user_supplied_storage_key,
61
+ _format_user_supplied_storage_key, # type: ignore[reportPrivateUsage]
59
62
  get_result_store,
60
63
  should_persist_result,
61
64
  )
@@ -115,20 +118,20 @@ class BaseTaskRunEngine(Generic[P, R]):
115
118
  # holds the return value from the user code
116
119
  _return_value: Union[R, Type[NotSet]] = NotSet
117
120
  # holds the exception raised by the user code, if any
118
- _raised: Union[Exception, Type[NotSet]] = NotSet
121
+ _raised: Union[Exception, BaseException, Type[NotSet]] = NotSet
119
122
  _initial_run_context: Optional[TaskRunContext] = None
120
123
  _is_started: bool = False
121
124
  _task_name_set: bool = False
122
125
  _last_event: Optional[PrefectEvent] = None
123
126
  _telemetry: RunTelemetry = field(default_factory=RunTelemetry)
124
127
 
125
- def __post_init__(self):
128
+ def __post_init__(self) -> None:
126
129
  if self.parameters is None:
127
130
  self.parameters = {}
128
131
 
129
132
  @property
130
133
  def state(self) -> State:
131
- if not self.task_run:
134
+ if not self.task_run or not self.task_run.state:
132
135
  raise ValueError("Task run is not set")
133
136
  return self.task_run.state
134
137
 
@@ -142,8 +145,8 @@ class BaseTaskRunEngine(Generic[P, R]):
142
145
  return False
143
146
 
144
147
  def compute_transaction_key(self) -> Optional[str]:
145
- key = None
146
- if self.task.cache_policy:
148
+ key: Optional[str] = None
149
+ if self.task.cache_policy and isinstance(self.task.cache_policy, CachePolicy):
147
150
  flow_run_context = FlowRunContext.get()
148
151
  task_run_context = TaskRunContext.get()
149
152
 
@@ -153,10 +156,12 @@ class BaseTaskRunEngine(Generic[P, R]):
153
156
  parameters = None
154
157
 
155
158
  try:
159
+ if not task_run_context:
160
+ raise ValueError("Task run context is not set")
156
161
  key = self.task.cache_policy.compute_key(
157
162
  task_ctx=task_run_context,
158
- inputs=self.parameters,
159
- flow_parameters=parameters,
163
+ inputs=self.parameters or {},
164
+ flow_parameters=parameters or {},
160
165
  )
161
166
  except Exception:
162
167
  self.logger.exception(
@@ -169,7 +174,7 @@ class BaseTaskRunEngine(Generic[P, R]):
169
174
 
170
175
  def _resolve_parameters(self):
171
176
  if not self.parameters:
172
- return {}
177
+ return None
173
178
 
174
179
  resolved_parameters = {}
175
180
  for parameter, value in self.parameters.items():
@@ -227,10 +232,8 @@ class BaseTaskRunEngine(Generic[P, R]):
227
232
  if self.task_run and self.task_run.start_time and not self.task_run.end_time:
228
233
  self.task_run.end_time = state.timestamp
229
234
 
230
- if self.task_run.state.is_running():
231
- self.task_run.total_run_time += (
232
- state.timestamp - self.task_run.state.timestamp
233
- )
235
+ if self.state.is_running():
236
+ self.task_run.total_run_time += state.timestamp - self.state.timestamp
234
237
 
235
238
  def is_running(self) -> bool:
236
239
  """Whether or not the engine is currently running a task."""
@@ -238,7 +241,7 @@ class BaseTaskRunEngine(Generic[P, R]):
238
241
  return False
239
242
  return task_run.state.is_running() or task_run.state.is_scheduled()
240
243
 
241
- def log_finished_message(self):
244
+ def log_finished_message(self) -> None:
242
245
  if not self.task_run:
243
246
  return
244
247
 
@@ -294,6 +297,7 @@ class BaseTaskRunEngine(Generic[P, R]):
294
297
 
295
298
  @dataclass
296
299
  class SyncTaskRunEngine(BaseTaskRunEngine[P, R]):
300
+ task_run: Optional[TaskRun] = None
297
301
  _client: Optional[SyncPrefectClient] = None
298
302
 
299
303
  @property
@@ -336,7 +340,7 @@ class SyncTaskRunEngine(BaseTaskRunEngine[P, R]):
336
340
  )
337
341
  return False
338
342
 
339
- def call_hooks(self, state: Optional[State] = None):
343
+ def call_hooks(self, state: Optional[State] = None) -> None:
340
344
  if state is None:
341
345
  state = self.state
342
346
  task = self.task
@@ -371,7 +375,7 @@ class SyncTaskRunEngine(BaseTaskRunEngine[P, R]):
371
375
  else:
372
376
  self.logger.info(f"Hook {hook_name!r} finished running successfully")
373
377
 
374
- def begin_run(self):
378
+ def begin_run(self) -> None:
375
379
  try:
376
380
  self._resolve_parameters()
377
381
  self._set_custom_task_run_name()
@@ -390,6 +394,7 @@ class SyncTaskRunEngine(BaseTaskRunEngine[P, R]):
390
394
 
391
395
  new_state = Running()
392
396
 
397
+ assert self.task_run is not None, "Task run is not set"
393
398
  self.task_run.start_time = new_state.timestamp
394
399
 
395
400
  flow_run_context = FlowRunContext.get()
@@ -406,7 +411,7 @@ class SyncTaskRunEngine(BaseTaskRunEngine[P, R]):
406
411
  # result reference that no longer exists
407
412
  if state.is_completed():
408
413
  try:
409
- state.result(retry_result_failure=False, _sync=True)
414
+ state.result(retry_result_failure=False, _sync=True) # type: ignore[reportCallIssue]
410
415
  except Exception:
411
416
  state = self.set_state(new_state, force=True)
412
417
 
@@ -422,7 +427,7 @@ class SyncTaskRunEngine(BaseTaskRunEngine[P, R]):
422
427
  time.sleep(interval)
423
428
  state = self.set_state(new_state)
424
429
 
425
- def set_state(self, state: State, force: bool = False) -> State:
430
+ def set_state(self, state: State[R], force: bool = False) -> State[R]:
426
431
  last_state = self.state
427
432
  if not self.task_run:
428
433
  raise ValueError("Task run is not set")
@@ -537,7 +542,7 @@ class SyncTaskRunEngine(BaseTaskRunEngine[P, R]):
537
542
  new_state = Retrying()
538
543
 
539
544
  self.logger.info(
540
- "Task run failed with exception: %r - " "Retry %s/%s will start %s",
545
+ "Task run failed with exception: %r - Retry %s/%s will start %s",
541
546
  exc,
542
547
  self.retries + 1,
543
548
  self.task.retries,
@@ -545,7 +550,7 @@ class SyncTaskRunEngine(BaseTaskRunEngine[P, R]):
545
550
  )
546
551
 
547
552
  self.set_state(new_state, force=True)
548
- self.retries = self.retries + 1
553
+ self.retries: int = self.retries + 1
549
554
  return True
550
555
  elif self.retries >= self.task.retries:
551
556
  self.logger.error(
@@ -639,7 +644,9 @@ class SyncTaskRunEngine(BaseTaskRunEngine[P, R]):
639
644
  stack.enter_context(ConcurrencyContextV1())
640
645
  stack.enter_context(ConcurrencyContext())
641
646
 
642
- self.logger = task_run_logger(task_run=self.task_run, task=self.task) # type: ignore
647
+ self.logger: "logging.Logger" = task_run_logger(
648
+ task_run=self.task_run, task=self.task
649
+ ) # type: ignore
643
650
 
644
651
  yield
645
652
 
@@ -648,7 +655,7 @@ class SyncTaskRunEngine(BaseTaskRunEngine[P, R]):
648
655
  self,
649
656
  task_run_id: Optional[UUID] = None,
650
657
  dependencies: Optional[dict[str, set[TaskRunInput]]] = None,
651
- ) -> Generator["SyncTaskRunEngine", Any, Any]:
658
+ ) -> Generator[Self, Any, Any]:
652
659
  """
653
660
  Enters a client context and creates a task run if needed.
654
661
  """
@@ -718,7 +725,7 @@ class SyncTaskRunEngine(BaseTaskRunEngine[P, R]):
718
725
  self._is_started = False
719
726
  self._client = None
720
727
 
721
- async def wait_until_ready(self):
728
+ async def wait_until_ready(self) -> None:
722
729
  """Waits until the scheduled time (if its the future), then enters Running."""
723
730
  if scheduled_time := self.state.state_details.scheduled_time:
724
731
  sleep_time = (scheduled_time - pendulum.now("utc")).total_seconds()
@@ -825,6 +832,7 @@ class SyncTaskRunEngine(BaseTaskRunEngine[P, R]):
825
832
 
826
833
  @dataclass
827
834
  class AsyncTaskRunEngine(BaseTaskRunEngine[P, R]):
835
+ task_run: TaskRun | None = None
828
836
  _client: Optional[PrefectClient] = None
829
837
 
830
838
  @property
@@ -866,7 +874,7 @@ class AsyncTaskRunEngine(BaseTaskRunEngine[P, R]):
866
874
  )
867
875
  return False
868
876
 
869
- async def call_hooks(self, state: Optional[State] = None):
877
+ async def call_hooks(self, state: Optional[State] = None) -> None:
870
878
  if state is None:
871
879
  state = self.state
872
880
  task = self.task
@@ -901,7 +909,7 @@ class AsyncTaskRunEngine(BaseTaskRunEngine[P, R]):
901
909
  else:
902
910
  self.logger.info(f"Hook {hook_name!r} finished running successfully")
903
911
 
904
- async def begin_run(self):
912
+ async def begin_run(self) -> None:
905
913
  try:
906
914
  self._resolve_parameters()
907
915
  self._set_custom_task_run_name()
@@ -1067,7 +1075,7 @@ class AsyncTaskRunEngine(BaseTaskRunEngine[P, R]):
1067
1075
  new_state = Retrying()
1068
1076
 
1069
1077
  self.logger.info(
1070
- "Task run failed with exception: %r - " "Retry %s/%s will start %s",
1078
+ "Task run failed with exception: %r - Retry %s/%s will start %s",
1071
1079
  exc,
1072
1080
  self.retries + 1,
1073
1081
  self.task.retries,
@@ -1075,7 +1083,7 @@ class AsyncTaskRunEngine(BaseTaskRunEngine[P, R]):
1075
1083
  )
1076
1084
 
1077
1085
  await self.set_state(new_state, force=True)
1078
- self.retries = self.retries + 1
1086
+ self.retries: int = self.retries + 1
1079
1087
  return True
1080
1088
  elif self.retries >= self.task.retries:
1081
1089
  self.logger.error(
@@ -1169,7 +1177,9 @@ class AsyncTaskRunEngine(BaseTaskRunEngine[P, R]):
1169
1177
  )
1170
1178
  stack.enter_context(ConcurrencyContext())
1171
1179
 
1172
- self.logger = task_run_logger(task_run=self.task_run, task=self.task) # type: ignore
1180
+ self.logger: "logging.Logger" = task_run_logger(
1181
+ task_run=self.task_run, task=self.task
1182
+ ) # type: ignore
1173
1183
 
1174
1184
  yield
1175
1185
 
@@ -1178,7 +1188,7 @@ class AsyncTaskRunEngine(BaseTaskRunEngine[P, R]):
1178
1188
  self,
1179
1189
  task_run_id: Optional[UUID] = None,
1180
1190
  dependencies: Optional[dict[str, set[TaskRunInput]]] = None,
1181
- ) -> AsyncGenerator["AsyncTaskRunEngine", Any]:
1191
+ ) -> AsyncGenerator[Self, Any]:
1182
1192
  """
1183
1193
  Enters a client context and creates a task run if needed.
1184
1194
  """
@@ -1246,7 +1256,7 @@ class AsyncTaskRunEngine(BaseTaskRunEngine[P, R]):
1246
1256
  self._is_started = False
1247
1257
  self._client = None
1248
1258
 
1249
- async def wait_until_ready(self):
1259
+ async def wait_until_ready(self) -> None:
1250
1260
  """Waits until the scheduled time (if its the future), then enters Running."""
1251
1261
  if scheduled_time := self.state.state_details.scheduled_time:
1252
1262
  sleep_time = (scheduled_time - pendulum.now("utc")).total_seconds()
@@ -1341,7 +1351,7 @@ class AsyncTaskRunEngine(BaseTaskRunEngine[P, R]):
1341
1351
  if transaction.is_committed():
1342
1352
  result = transaction.read()
1343
1353
  else:
1344
- if self.task_run.tags:
1354
+ if self.task_run and self.task_run.tags:
1345
1355
  # Acquire a concurrency slot for each tag, but only if a limit
1346
1356
  # matching the tag already exists.
1347
1357
  async with aconcurrency(list(self.task_run.tags), self.task_run.id):
@@ -1546,7 +1556,7 @@ def run_task(
1546
1556
  Returns:
1547
1557
  The result of the task run
1548
1558
  """
1549
- kwargs = dict(
1559
+ kwargs: dict[str, Any] = dict(
1550
1560
  task=task,
1551
1561
  task_run_id=task_run_id,
1552
1562
  task_run=task_run,
prefect/task_runners.py CHANGED
@@ -40,6 +40,8 @@ from prefect.utilities.callables import (
40
40
  from prefect.utilities.collections import isiterable
41
41
 
42
42
  if TYPE_CHECKING:
43
+ import logging
44
+
43
45
  from prefect.tasks import Task
44
46
 
45
47
  P = ParamSpec("P")
@@ -61,11 +63,11 @@ class TaskRunner(abc.ABC, Generic[F]):
61
63
  """
62
64
 
63
65
  def __init__(self):
64
- self.logger = get_logger(f"task_runner.{self.name}")
66
+ self.logger: "logging.Logger" = get_logger(f"task_runner.{self.name}")
65
67
  self._started = False
66
68
 
67
69
  @property
68
- def name(self):
70
+ def name(self) -> str:
69
71
  """The name of this task runner"""
70
72
  return type(self).__name__.lower().replace("taskrunner", "")
71
73
 
@@ -74,32 +76,42 @@ class TaskRunner(abc.ABC, Generic[F]):
74
76
  """Return a new instance of this task runner with the same configuration."""
75
77
  ...
76
78
 
79
+ @overload
77
80
  @abc.abstractmethod
78
81
  def submit(
79
82
  self,
80
- task: "Task[P, R]",
83
+ task: "Task[P, Coroutine[Any, Any, R]]",
81
84
  parameters: dict[str, Any],
82
85
  wait_for: Iterable[PrefectFuture[Any]] | None = None,
83
86
  dependencies: dict[str, set[TaskRunInput]] | None = None,
84
87
  ) -> F:
85
- """
86
- Submit a task to the task run engine.
88
+ ...
87
89
 
88
- Args:
89
- task: The task to submit.
90
- parameters: The parameters to use when running the task.
91
- wait_for: A list of futures that the task depends on.
90
+ @overload
91
+ @abc.abstractmethod
92
+ def submit(
93
+ self,
94
+ task: "Task[Any, R]",
95
+ parameters: dict[str, Any],
96
+ wait_for: Iterable[PrefectFuture[Any]] | None = None,
97
+ dependencies: dict[str, set[TaskRunInput]] | None = None,
98
+ ) -> F:
99
+ ...
92
100
 
93
- Returns:
94
- A future object that can be used to wait for the task to complete and
95
- retrieve the result.
96
- """
101
+ @abc.abstractmethod
102
+ def submit(
103
+ self,
104
+ task: "Task[P, R]",
105
+ parameters: dict[str, Any],
106
+ wait_for: Iterable[PrefectFuture[Any]] | None = None,
107
+ dependencies: dict[str, set[TaskRunInput]] | None = None,
108
+ ) -> F:
97
109
  ...
98
110
 
99
111
  def map(
100
112
  self,
101
113
  task: "Task[P, R]",
102
- parameters: dict[str, Any],
114
+ parameters: dict[str, Any | unmapped[Any] | allow_failure[Any]],
103
115
  wait_for: Optional[Iterable[PrefectFuture[R]]] = None,
104
116
  ) -> PrefectFutureList[F]:
105
117
  """
@@ -205,7 +217,7 @@ class TaskRunner(abc.ABC, Generic[F]):
205
217
 
206
218
  return PrefectFutureList(futures)
207
219
 
208
- def __enter__(self):
220
+ def __enter__(self) -> Self:
209
221
  if self._started:
210
222
  raise RuntimeError("This task runner is already started")
211
223
 
@@ -213,12 +225,12 @@ class TaskRunner(abc.ABC, Generic[F]):
213
225
  self._started = True
214
226
  return self
215
227
 
216
- def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any):
228
+ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
217
229
  self.logger.debug("Stopping task runner")
218
230
  self._started = False
219
231
 
220
232
 
221
- class ThreadPoolTaskRunner(TaskRunner[PrefectConcurrentFuture[Any]]):
233
+ class ThreadPoolTaskRunner(TaskRunner[PrefectConcurrentFuture[R]]):
222
234
  def __init__(self, max_workers: Optional[int] = None):
223
235
  super().__init__()
224
236
  self._executor: Optional[ThreadPoolExecutor] = None
@@ -229,7 +241,7 @@ class ThreadPoolTaskRunner(TaskRunner[PrefectConcurrentFuture[Any]]):
229
241
  )
230
242
  self._cancel_events: Dict[uuid.UUID, threading.Event] = {}
231
243
 
232
- def duplicate(self) -> "ThreadPoolTaskRunner":
244
+ def duplicate(self) -> "ThreadPoolTaskRunner[R]":
233
245
  return type(self)(max_workers=self._max_workers)
234
246
 
235
247
  @overload
@@ -254,7 +266,7 @@ class ThreadPoolTaskRunner(TaskRunner[PrefectConcurrentFuture[Any]]):
254
266
 
255
267
  def submit(
256
268
  self,
257
- task: "Task[P, R]",
269
+ task: "Task[P, R | Coroutine[Any, Any, R]]",
258
270
  parameters: dict[str, Any],
259
271
  wait_for: Iterable[PrefectFuture[Any]] | None = None,
260
272
  dependencies: dict[str, set[TaskRunInput]] | None = None,
@@ -345,7 +357,7 @@ class ThreadPoolTaskRunner(TaskRunner[PrefectConcurrentFuture[Any]]):
345
357
  ) -> PrefectFutureList[PrefectConcurrentFuture[R]]:
346
358
  return super().map(task, parameters, wait_for)
347
359
 
348
- def cancel_all(self):
360
+ def cancel_all(self) -> None:
349
361
  for event in self._cancel_events.values():
350
362
  event.set()
351
363
  self.logger.debug("Set cancel event")
@@ -354,12 +366,12 @@ class ThreadPoolTaskRunner(TaskRunner[PrefectConcurrentFuture[Any]]):
354
366
  self._executor.shutdown(cancel_futures=True)
355
367
  self._executor = None
356
368
 
357
- def __enter__(self):
369
+ def __enter__(self) -> Self:
358
370
  super().__enter__()
359
371
  self._executor = ThreadPoolExecutor(max_workers=self._max_workers)
360
372
  return self
361
373
 
362
- def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any):
374
+ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
363
375
  self.cancel_all()
364
376
  if self._executor is not None:
365
377
  self._executor.shutdown(cancel_futures=True)
@@ -380,7 +392,7 @@ class PrefectTaskRunner(TaskRunner[PrefectDistributedFuture[R]]):
380
392
  def __init__(self):
381
393
  super().__init__()
382
394
 
383
- def duplicate(self) -> "PrefectTaskRunner":
395
+ def duplicate(self) -> "PrefectTaskRunner[R]":
384
396
  return type(self)()
385
397
 
386
398
  @overload
prefect/task_runs.py CHANGED
@@ -1,8 +1,10 @@
1
+ from __future__ import annotations
2
+
1
3
  import asyncio
2
4
  import atexit
3
5
  import threading
4
6
  import uuid
5
- from typing import Callable, Dict, Optional
7
+ from typing import TYPE_CHECKING, Callable, Dict, Optional
6
8
 
7
9
  import anyio
8
10
  from cachetools import TTLCache
@@ -15,6 +17,9 @@ from prefect.events.clients import get_events_subscriber
15
17
  from prefect.events.filters import EventFilter, EventNameFilter
16
18
  from prefect.logging.loggers import get_logger
17
19
 
20
+ if TYPE_CHECKING:
21
+ import logging
22
+
18
23
 
19
24
  class TaskRunWaiter:
20
25
  """
@@ -68,19 +73,19 @@ class TaskRunWaiter:
68
73
  _instance_lock = threading.Lock()
69
74
 
70
75
  def __init__(self):
71
- self.logger = get_logger("TaskRunWaiter")
72
- self._consumer_task: Optional[asyncio.Task] = None
76
+ self.logger: "logging.Logger" = get_logger("TaskRunWaiter")
77
+ self._consumer_task: "asyncio.Task[None] | None" = None
73
78
  self._observed_completed_task_runs: TTLCache[uuid.UUID, bool] = TTLCache(
74
79
  maxsize=10000, ttl=600
75
80
  )
76
81
  self._completion_events: Dict[uuid.UUID, asyncio.Event] = {}
77
- self._completion_callbacks: Dict[uuid.UUID, Callable] = {}
82
+ self._completion_callbacks: Dict[uuid.UUID, Callable[[], None]] = {}
78
83
  self._loop: Optional[asyncio.AbstractEventLoop] = None
79
84
  self._observed_completed_task_runs_lock = threading.Lock()
80
85
  self._completion_events_lock = threading.Lock()
81
86
  self._started = False
82
87
 
83
- def start(self):
88
+ def start(self) -> None:
84
89
  """
85
90
  Start the TaskRunWaiter service.
86
91
  """
@@ -89,10 +94,12 @@ class TaskRunWaiter:
89
94
  self.logger.debug("Starting TaskRunWaiter")
90
95
  loop_thread = get_global_loop()
91
96
 
92
- if not asyncio.get_running_loop() == loop_thread._loop:
97
+ if not asyncio.get_running_loop() == loop_thread.loop:
93
98
  raise RuntimeError("TaskRunWaiter must run on the global loop thread.")
94
99
 
95
- self._loop = loop_thread._loop
100
+ self._loop = loop_thread.loop
101
+ if TYPE_CHECKING:
102
+ assert self._loop is not None
96
103
 
97
104
  consumer_started = asyncio.Event()
98
105
  self._consumer_task = self._loop.create_task(
@@ -141,7 +148,7 @@ class TaskRunWaiter:
141
148
  except Exception as exc:
142
149
  self.logger.error(f"Error processing event: {exc}")
143
150
 
144
- def stop(self):
151
+ def stop(self) -> None:
145
152
  """
146
153
  Stop the TaskRunWaiter service.
147
154
  """
@@ -155,7 +162,7 @@ class TaskRunWaiter:
155
162
  @classmethod
156
163
  async def wait_for_task_run(
157
164
  cls, task_run_id: uuid.UUID, timeout: Optional[float] = None
158
- ):
165
+ ) -> None:
159
166
  """
160
167
  Wait for a task run to finish.
161
168
 
@@ -199,7 +206,9 @@ class TaskRunWaiter:
199
206
  instance._completion_events.pop(task_run_id, None)
200
207
 
201
208
  @classmethod
202
- def add_done_callback(cls, task_run_id: uuid.UUID, callback):
209
+ def add_done_callback(
210
+ cls, task_run_id: uuid.UUID, callback: Callable[[], None]
211
+ ) -> None:
203
212
  """
204
213
  Add a callback to be called when a task run finishes.
205
214
 
@@ -219,7 +228,7 @@ class TaskRunWaiter:
219
228
  instance._completion_callbacks[task_run_id] = callback
220
229
 
221
230
  @classmethod
222
- def instance(cls):
231
+ def instance(cls) -> Self:
223
232
  """
224
233
  Get the singleton instance of TaskRunWaiter.
225
234
  """
prefect/task_worker.py CHANGED
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import asyncio
2
4
  import inspect
3
5
  import os
@@ -16,7 +18,7 @@ import pendulum
16
18
  import uvicorn
17
19
  from exceptiongroup import BaseExceptionGroup # novermin
18
20
  from fastapi import FastAPI
19
- from typing_extensions import ParamSpec, TypeVar
21
+ from typing_extensions import ParamSpec, Self, TypeVar
20
22
  from websockets.exceptions import InvalidStatusCode
21
23
 
22
24
  from prefect import Task
@@ -42,7 +44,10 @@ from prefect.utilities.processutils import (
42
44
  from prefect.utilities.services import start_client_metrics_server
43
45
  from prefect.utilities.urls import url_for
44
46
 
45
- logger = get_logger("task_worker")
47
+ if TYPE_CHECKING:
48
+ import logging
49
+
50
+ logger: "logging.Logger" = get_logger("task_worker")
46
51
 
47
52
  P = ParamSpec("P")
48
53
  R = TypeVar("R", infer_variance=True)
@@ -85,7 +90,7 @@ class TaskWorker:
85
90
  def __init__(
86
91
  self,
87
92
  *tasks: Task[P, R],
88
- limit: Optional[int] = 10,
93
+ limit: int | None = 10,
89
94
  ):
90
95
  self.tasks: list["Task[..., Any]"] = []
91
96
  for t in tasks:
@@ -100,7 +105,7 @@ class TaskWorker:
100
105
  else:
101
106
  self.tasks.append(t.with_options(persist_result=True))
102
107
 
103
- self.task_keys = set(t.task_key for t in tasks if isinstance(t, Task)) # pyright: ignore[reportUnnecessaryIsInstance]
108
+ self.task_keys: set[str] = set(t.task_key for t in tasks if isinstance(t, Task)) # pyright: ignore[reportUnnecessaryIsInstance]
104
109
 
105
110
  self._started_at: Optional[pendulum.DateTime] = None
106
111
  self.stopping: bool = False
@@ -154,7 +159,7 @@ class TaskWorker:
154
159
  def available_tasks(self) -> Optional[int]:
155
160
  return int(self._limiter.available_tokens) if self._limiter else None
156
161
 
157
- def handle_sigterm(self, signum: int, frame: object):
162
+ def handle_sigterm(self, signum: int, frame: object) -> None:
158
163
  """
159
164
  Shuts down the task worker when a SIGTERM is received.
160
165
  """
@@ -355,14 +360,14 @@ class TaskWorker:
355
360
  )
356
361
  await asyncio.wrap_future(future)
357
362
 
358
- async def execute_task_run(self, task_run: TaskRun):
363
+ async def execute_task_run(self, task_run: TaskRun) -> None:
359
364
  """Execute a task run in the task worker."""
360
365
  async with self if not self.started else asyncnullcontext():
361
366
  token_acquired = await self._acquire_token(task_run.id)
362
367
  if token_acquired:
363
368
  await self._safe_submit_scheduled_task_run(task_run)
364
369
 
365
- async def __aenter__(self):
370
+ async def __aenter__(self) -> Self:
366
371
  logger.debug("Starting task worker...")
367
372
 
368
373
  if self._client._closed: # pyright: ignore[reportPrivateUsage]