prefect-client 3.1.12__py3-none-any.whl → 3.1.14__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (111)
  1. prefect/_experimental/lineage.py +63 -0
  2. prefect/_experimental/sla/client.py +53 -27
  3. prefect/_experimental/sla/objects.py +10 -2
  4. prefect/_internal/concurrency/services.py +2 -2
  5. prefect/_internal/concurrency/threads.py +6 -0
  6. prefect/_internal/retries.py +6 -3
  7. prefect/_internal/schemas/validators.py +6 -4
  8. prefect/_version.py +3 -3
  9. prefect/artifacts.py +4 -1
  10. prefect/automations.py +1 -1
  11. prefect/blocks/abstract.py +5 -2
  12. prefect/blocks/notifications.py +1 -0
  13. prefect/cache_policies.py +70 -22
  14. prefect/client/orchestration/_automations/client.py +4 -0
  15. prefect/client/orchestration/_deployments/client.py +3 -3
  16. prefect/client/utilities.py +3 -3
  17. prefect/context.py +16 -6
  18. prefect/deployments/base.py +7 -4
  19. prefect/deployments/flow_runs.py +5 -1
  20. prefect/deployments/runner.py +6 -11
  21. prefect/deployments/steps/core.py +1 -1
  22. prefect/deployments/steps/pull.py +8 -3
  23. prefect/deployments/steps/utility.py +2 -2
  24. prefect/docker/docker_image.py +13 -9
  25. prefect/engine.py +19 -10
  26. prefect/events/cli/automations.py +4 -4
  27. prefect/events/clients.py +17 -14
  28. prefect/events/filters.py +34 -34
  29. prefect/events/schemas/automations.py +12 -8
  30. prefect/events/schemas/events.py +5 -1
  31. prefect/events/worker.py +1 -1
  32. prefect/filesystems.py +1 -1
  33. prefect/flow_engine.py +172 -123
  34. prefect/flows.py +119 -74
  35. prefect/futures.py +14 -7
  36. prefect/infrastructure/provisioners/__init__.py +2 -0
  37. prefect/infrastructure/provisioners/cloud_run.py +4 -4
  38. prefect/infrastructure/provisioners/coiled.py +249 -0
  39. prefect/infrastructure/provisioners/container_instance.py +4 -3
  40. prefect/infrastructure/provisioners/ecs.py +55 -43
  41. prefect/infrastructure/provisioners/modal.py +5 -4
  42. prefect/input/actions.py +5 -1
  43. prefect/input/run_input.py +157 -43
  44. prefect/logging/configuration.py +5 -8
  45. prefect/logging/filters.py +2 -2
  46. prefect/logging/formatters.py +15 -11
  47. prefect/logging/handlers.py +24 -14
  48. prefect/logging/highlighters.py +5 -5
  49. prefect/logging/loggers.py +29 -20
  50. prefect/main.py +3 -1
  51. prefect/results.py +166 -86
  52. prefect/runner/runner.py +112 -84
  53. prefect/runner/server.py +3 -1
  54. prefect/runner/storage.py +18 -18
  55. prefect/runner/submit.py +19 -12
  56. prefect/runtime/deployment.py +15 -8
  57. prefect/runtime/flow_run.py +19 -6
  58. prefect/runtime/task_run.py +7 -3
  59. prefect/settings/base.py +17 -7
  60. prefect/settings/legacy.py +4 -4
  61. prefect/settings/models/api.py +4 -3
  62. prefect/settings/models/cli.py +4 -3
  63. prefect/settings/models/client.py +7 -4
  64. prefect/settings/models/cloud.py +4 -3
  65. prefect/settings/models/deployments.py +4 -3
  66. prefect/settings/models/experiments.py +4 -3
  67. prefect/settings/models/flows.py +4 -3
  68. prefect/settings/models/internal.py +4 -3
  69. prefect/settings/models/logging.py +8 -6
  70. prefect/settings/models/results.py +4 -3
  71. prefect/settings/models/root.py +11 -16
  72. prefect/settings/models/runner.py +8 -5
  73. prefect/settings/models/server/api.py +6 -3
  74. prefect/settings/models/server/database.py +120 -25
  75. prefect/settings/models/server/deployments.py +4 -3
  76. prefect/settings/models/server/ephemeral.py +7 -4
  77. prefect/settings/models/server/events.py +6 -3
  78. prefect/settings/models/server/flow_run_graph.py +4 -3
  79. prefect/settings/models/server/root.py +4 -3
  80. prefect/settings/models/server/services.py +15 -12
  81. prefect/settings/models/server/tasks.py +7 -4
  82. prefect/settings/models/server/ui.py +4 -3
  83. prefect/settings/models/tasks.py +10 -5
  84. prefect/settings/models/testing.py +4 -3
  85. prefect/settings/models/worker.py +7 -4
  86. prefect/settings/profiles.py +13 -12
  87. prefect/settings/sources.py +20 -19
  88. prefect/states.py +17 -13
  89. prefect/task_engine.py +43 -33
  90. prefect/task_runners.py +35 -23
  91. prefect/task_runs.py +20 -11
  92. prefect/task_worker.py +12 -7
  93. prefect/tasks.py +67 -25
  94. prefect/telemetry/bootstrap.py +4 -1
  95. prefect/telemetry/run_telemetry.py +15 -13
  96. prefect/transactions.py +3 -3
  97. prefect/types/__init__.py +9 -6
  98. prefect/types/_datetime.py +19 -0
  99. prefect/utilities/_deprecated.py +38 -0
  100. prefect/utilities/engine.py +11 -4
  101. prefect/utilities/filesystem.py +2 -2
  102. prefect/utilities/generics.py +1 -1
  103. prefect/utilities/pydantic.py +21 -36
  104. prefect/workers/base.py +52 -30
  105. prefect/workers/process.py +20 -15
  106. prefect/workers/server.py +4 -5
  107. {prefect_client-3.1.12.dist-info → prefect_client-3.1.14.dist-info}/METADATA +2 -2
  108. {prefect_client-3.1.12.dist-info → prefect_client-3.1.14.dist-info}/RECORD +111 -108
  109. {prefect_client-3.1.12.dist-info → prefect_client-3.1.14.dist-info}/LICENSE +0 -0
  110. {prefect_client-3.1.12.dist-info → prefect_client-3.1.14.dist-info}/WHEEL +0 -0
  111. {prefect_client-3.1.12.dist-info → prefect_client-3.1.14.dist-info}/top_level.txt +0 -0
@@ -1,12 +1,13 @@
1
1
  from typing import Any, ClassVar, Optional
2
2
 
3
- from pydantic import AliasChoices, AliasPath, ConfigDict, Field
3
+ from pydantic import AliasChoices, AliasPath, Field
4
+ from pydantic_settings import SettingsConfigDict
4
5
 
5
- from prefect.settings.base import PrefectBaseSettings, _build_settings_config
6
+ from prefect.settings.base import PrefectBaseSettings, build_settings_config
6
7
 
7
8
 
8
9
  class TestingSettings(PrefectBaseSettings):
9
- model_config: ClassVar[ConfigDict] = _build_settings_config(("testing",))
10
+ model_config: ClassVar[SettingsConfigDict] = build_settings_config(("testing",))
10
11
 
11
12
  test_mode: bool = Field(
12
13
  default=False,
@@ -1,12 +1,15 @@
1
1
  from typing import ClassVar
2
2
 
3
- from pydantic import ConfigDict, Field
3
+ from pydantic import Field
4
+ from pydantic_settings import SettingsConfigDict
4
5
 
5
- from prefect.settings.base import PrefectBaseSettings, _build_settings_config
6
+ from prefect.settings.base import PrefectBaseSettings, build_settings_config
6
7
 
7
8
 
8
9
  class WorkerWebserverSettings(PrefectBaseSettings):
9
- model_config: ClassVar[ConfigDict] = _build_settings_config(("worker", "webserver"))
10
+ model_config: ClassVar[SettingsConfigDict] = build_settings_config(
11
+ ("worker", "webserver")
12
+ )
10
13
 
11
14
  host: str = Field(
12
15
  default="0.0.0.0",
@@ -20,7 +23,7 @@ class WorkerWebserverSettings(PrefectBaseSettings):
20
23
 
21
24
 
22
25
  class WorkerSettings(PrefectBaseSettings):
23
- model_config: ClassVar[ConfigDict] = _build_settings_config(("worker",))
26
+ model_config: ClassVar[SettingsConfigDict] = build_settings_config(("worker",))
24
27
 
25
28
  heartbeat_seconds: float = Field(
26
29
  default=30,
@@ -7,10 +7,9 @@ from typing import (
7
7
  ClassVar,
8
8
  Dict,
9
9
  Iterable,
10
- List,
10
+ Iterator,
11
11
  Optional,
12
12
  Set,
13
- Tuple,
14
13
  Union,
15
14
  )
16
15
 
@@ -77,8 +76,8 @@ class Profile(BaseModel):
77
76
  if value is not None
78
77
  }
79
78
 
80
- def validate_settings(self):
81
- errors: List[Tuple[Setting, ValidationError]] = []
79
+ def validate_settings(self) -> None:
80
+ errors: list[tuple[Setting, ValidationError]] = []
82
81
  for setting, value in self.settings.items():
83
82
  try:
84
83
  model_fields = Settings.model_fields
@@ -109,7 +108,9 @@ class ProfilesCollection:
109
108
  def __init__(
110
109
  self, profiles: Iterable[Profile], active: Optional[str] = None
111
110
  ) -> None:
112
- self.profiles_by_name = {profile.name: profile for profile in profiles}
111
+ self.profiles_by_name: dict[str, Profile] = {
112
+ profile.name: profile for profile in profiles
113
+ }
113
114
  self.active_name = active
114
115
 
115
116
  @property
@@ -128,7 +129,7 @@ class ProfilesCollection:
128
129
  return None
129
130
  return self[self.active_name]
130
131
 
131
- def set_active(self, name: Optional[str], check: bool = True):
132
+ def set_active(self, name: Optional[str], check: bool = True) -> None:
132
133
  """
133
134
  Set the active profile name in the collection.
134
135
 
@@ -142,7 +143,7 @@ class ProfilesCollection:
142
143
  def update_profile(
143
144
  self,
144
145
  name: str,
145
- settings: Dict[Setting, Any],
146
+ settings: dict[Setting, Any],
146
147
  source: Optional[Path] = None,
147
148
  ) -> Profile:
148
149
  """
@@ -214,7 +215,7 @@ class ProfilesCollection:
214
215
  active=self.active_name,
215
216
  )
216
217
 
217
- def to_dict(self):
218
+ def to_dict(self) -> dict[str, Any]:
218
219
  """
219
220
  Convert to a dictionary suitable for writing to disk.
220
221
  """
@@ -229,11 +230,11 @@ class ProfilesCollection:
229
230
  def __getitem__(self, name: str) -> Profile:
230
231
  return self.profiles_by_name[name]
231
232
 
232
- def __iter__(self):
233
+ def __iter__(self) -> Iterator[str]:
233
234
  return self.profiles_by_name.__iter__()
234
235
 
235
- def items(self):
236
- return self.profiles_by_name.items()
236
+ def items(self) -> list[tuple[str, Profile]]:
237
+ return list(self.profiles_by_name.items())
237
238
 
238
239
  def __eq__(self, __o: object) -> bool:
239
240
  if not isinstance(__o, ProfilesCollection):
@@ -325,7 +326,7 @@ def load_profiles(include_defaults: bool = True) -> ProfilesCollection:
325
326
  return profiles
326
327
 
327
328
 
328
- def load_current_profile():
329
+ def load_current_profile() -> Profile:
329
330
  """
330
331
  Load the current profile from the default and current profile paths.
331
332
 
@@ -2,7 +2,7 @@ import os
2
2
  import sys
3
3
  import warnings
4
4
  from pathlib import Path
5
- from typing import Any, Dict, List, Optional, Tuple, Type
5
+ from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Type
6
6
 
7
7
  import dotenv
8
8
  import toml
@@ -54,6 +54,7 @@ class EnvFilterSettingsSource(EnvSettingsSource):
54
54
  env_parse_none_str,
55
55
  env_parse_enums,
56
56
  )
57
+ self.env_vars: Mapping[str, str | None]
57
58
  if env_filter:
58
59
  if isinstance(self.env_vars, dict):
59
60
  for key in env_filter:
@@ -97,7 +98,7 @@ class FilteredDotEnvSettingsSource(DotEnvSettingsSource):
97
98
  for key in self.env_blacklist:
98
99
  self.env_vars.pop(key, None)
99
100
  else:
100
- self.env_vars = {
101
+ self.env_vars: dict[str, str | None] = {
101
102
  key: value
102
103
  for key, value in self.env_vars.items() # type: ignore
103
104
  if key.lower() not in env_blacklist
@@ -114,8 +115,8 @@ class ProfileSettingsTomlLoader(PydanticBaseSettingsSource):
114
115
  def __init__(self, settings_cls: Type[BaseSettings]):
115
116
  super().__init__(settings_cls)
116
117
  self.settings_cls = settings_cls
117
- self.profiles_path = _get_profiles_path()
118
- self.profile_settings = self._load_profile_settings()
118
+ self.profiles_path: Path = _get_profiles_path()
119
+ self.profile_settings: dict[str, Any] = self._load_profile_settings()
119
120
 
120
121
  def _load_profile_settings(self) -> Dict[str, Any]:
121
122
  """Helper method to load the profile settings from the profiles.toml file"""
@@ -213,14 +214,14 @@ class TomlConfigSettingsSourceBase(PydanticBaseSettingsSource, ConfigFileSourceM
213
214
  def __init__(self, settings_cls: Type[BaseSettings]):
214
215
  super().__init__(settings_cls)
215
216
  self.settings_cls = settings_cls
216
- self.toml_data = {}
217
+ self.toml_data: dict[str, Any] = {}
217
218
 
218
- def _read_file(self, path: Path) -> Dict[str, Any]:
219
+ def _read_file(self, path: Path) -> dict[str, Any]:
219
220
  return toml.load(path)
220
221
 
221
222
  def get_field_value(
222
223
  self, field: FieldInfo, field_name: str
223
- ) -> Tuple[Any, str, bool]:
224
+ ) -> tuple[Any, str, bool]:
224
225
  """Concrete implementation to get the field value from toml data"""
225
226
  value = self.toml_data.get(field_name)
226
227
  if isinstance(value, dict):
@@ -244,9 +245,9 @@ class TomlConfigSettingsSourceBase(PydanticBaseSettingsSource, ConfigFileSourceM
244
245
  break
245
246
  return value, name, self.field_is_complex(field)
246
247
 
247
- def __call__(self) -> Dict[str, Any]:
248
+ def __call__(self) -> dict[str, Any]:
248
249
  """Called by pydantic to get the settings from our custom source"""
249
- toml_setings: Dict[str, Any] = {}
250
+ toml_setings: dict[str, Any] = {}
250
251
  for field_name, field in self.settings_cls.model_fields.items():
251
252
  value, key, is_complex = self.get_field_value(field, field_name)
252
253
  if value is not None:
@@ -265,15 +266,15 @@ class PrefectTomlConfigSettingsSource(TomlConfigSettingsSourceBase):
265
266
  settings_cls: Type[BaseSettings],
266
267
  ):
267
268
  super().__init__(settings_cls)
268
- self.toml_file_path = settings_cls.model_config.get(
269
- "toml_file", DEFAULT_PREFECT_TOML_PATH
270
- )
271
- self.toml_data = self._read_files(self.toml_file_path)
272
- self.toml_table_header = settings_cls.model_config.get(
269
+ self.toml_file_path: Path | str | Sequence[
270
+ Path | str
271
+ ] | None = settings_cls.model_config.get("toml_file", DEFAULT_PREFECT_TOML_PATH)
272
+ self.toml_data: dict[str, Any] = self._read_files(self.toml_file_path)
273
+ self.toml_table_header: tuple[str, ...] = settings_cls.model_config.get(
273
274
  "prefect_toml_table_header", tuple()
274
275
  )
275
276
  for key in self.toml_table_header:
276
- self.toml_data = self.toml_data.get(key, {})
277
+ self.toml_data: dict[str, Any] = self.toml_data.get(key, {})
277
278
 
278
279
 
279
280
  class PyprojectTomlConfigSettingsSource(TomlConfigSettingsSourceBase):
@@ -284,13 +285,13 @@ class PyprojectTomlConfigSettingsSource(TomlConfigSettingsSourceBase):
284
285
  settings_cls: Type[BaseSettings],
285
286
  ):
286
287
  super().__init__(settings_cls)
287
- self.toml_file_path = Path("pyproject.toml")
288
- self.toml_data = self._read_files(self.toml_file_path)
289
- self.toml_table_header = settings_cls.model_config.get(
288
+ self.toml_file_path: Path = Path("pyproject.toml")
289
+ self.toml_data: dict[str, Any] = self._read_files(self.toml_file_path)
290
+ self.toml_table_header: tuple[str, ...] = settings_cls.model_config.get(
290
291
  "pyproject_toml_table_header", ("tool", "prefect")
291
292
  )
292
293
  for key in self.toml_table_header:
293
- self.toml_data = self.toml_data.get(key, {})
294
+ self.toml_data: dict[str, Any] = self.toml_data.get(key, {})
294
295
 
295
296
 
296
297
  def _is_test_mode() -> bool:
prefect/states.py CHANGED
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import asyncio
2
4
  import datetime
3
5
  import sys
@@ -33,12 +35,14 @@ from prefect.utilities.asyncutils import in_async_main_thread, sync_compatible
33
35
  from prefect.utilities.collections import ensure_iterable
34
36
 
35
37
  if TYPE_CHECKING:
38
+ import logging
39
+
36
40
  from prefect.results import (
37
41
  R,
38
42
  ResultStore,
39
43
  )
40
44
 
41
- logger = get_logger("states")
45
+ logger: "logging.Logger" = get_logger("states")
42
46
 
43
47
 
44
48
  @deprecated.deprecated_parameter(
@@ -245,8 +249,8 @@ async def exception_to_failed_state(
245
249
  exc: Optional[BaseException] = None,
246
250
  result_store: Optional["ResultStore"] = None,
247
251
  write_result: bool = False,
248
- **kwargs,
249
- ) -> State:
252
+ **kwargs: Any,
253
+ ) -> State[BaseException]:
250
254
  """
251
255
  Convenience function for creating `Failed` states from exceptions
252
256
  """
@@ -553,17 +557,17 @@ def is_state_iterable(obj: Any) -> TypeGuard[Iterable[State]]:
553
557
 
554
558
 
555
559
  class StateGroup:
556
- def __init__(self, states: Iterable[State]) -> None:
557
- self.states = states
558
- self.type_counts = self._get_type_counts(states)
559
- self.total_count = len(states)
560
- self.cancelled_count = self.type_counts[StateType.CANCELLED]
561
- self.final_count = sum(state.is_final() for state in states)
562
- self.not_final_count = self.total_count - self.final_count
563
- self.paused_count = self.type_counts[StateType.PAUSED]
560
+ def __init__(self, states: list[State]) -> None:
561
+ self.states: list[State] = states
562
+ self.type_counts: dict[StateType, int] = self._get_type_counts(states)
563
+ self.total_count: int = len(states)
564
+ self.cancelled_count: int = self.type_counts[StateType.CANCELLED]
565
+ self.final_count: int = sum(state.is_final() for state in states)
566
+ self.not_final_count: int = self.total_count - self.final_count
567
+ self.paused_count: int = self.type_counts[StateType.PAUSED]
564
568
 
565
569
  @property
566
- def fail_count(self):
570
+ def fail_count(self) -> int:
567
571
  return self.type_counts[StateType.FAILED] + self.type_counts[StateType.CRASHED]
568
572
 
569
573
  def all_completed(self) -> bool:
@@ -741,7 +745,7 @@ def Suspended(
741
745
  pause_expiration_time: Optional[datetime.datetime] = None,
742
746
  pause_key: Optional[str] = None,
743
747
  **kwargs: Any,
744
- ):
748
+ ) -> "State[R]":
745
749
  """Convenience function for creating `Suspended` states.
746
750
 
747
751
  Returns:
prefect/task_engine.py CHANGED
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import asyncio
2
4
  import inspect
3
5
  import logging
@@ -28,8 +30,9 @@ from uuid import UUID
28
30
  import anyio
29
31
  import pendulum
30
32
  from opentelemetry import trace
31
- from typing_extensions import ParamSpec
33
+ from typing_extensions import ParamSpec, Self
32
34
 
35
+ from prefect.cache_policies import CachePolicy
33
36
  from prefect.client.orchestration import PrefectClient, SyncPrefectClient, get_client
34
37
  from prefect.client.schemas import TaskRun
35
38
  from prefect.client.schemas.objects import State, TaskRunInput
@@ -55,7 +58,7 @@ from prefect.exceptions import (
55
58
  from prefect.logging.loggers import get_logger, patch_print, task_run_logger
56
59
  from prefect.results import (
57
60
  ResultRecord,
58
- _format_user_supplied_storage_key,
61
+ _format_user_supplied_storage_key, # type: ignore[reportPrivateUsage]
59
62
  get_result_store,
60
63
  should_persist_result,
61
64
  )
@@ -115,20 +118,20 @@ class BaseTaskRunEngine(Generic[P, R]):
115
118
  # holds the return value from the user code
116
119
  _return_value: Union[R, Type[NotSet]] = NotSet
117
120
  # holds the exception raised by the user code, if any
118
- _raised: Union[Exception, Type[NotSet]] = NotSet
121
+ _raised: Union[Exception, BaseException, Type[NotSet]] = NotSet
119
122
  _initial_run_context: Optional[TaskRunContext] = None
120
123
  _is_started: bool = False
121
124
  _task_name_set: bool = False
122
125
  _last_event: Optional[PrefectEvent] = None
123
126
  _telemetry: RunTelemetry = field(default_factory=RunTelemetry)
124
127
 
125
- def __post_init__(self):
128
+ def __post_init__(self) -> None:
126
129
  if self.parameters is None:
127
130
  self.parameters = {}
128
131
 
129
132
  @property
130
133
  def state(self) -> State:
131
- if not self.task_run:
134
+ if not self.task_run or not self.task_run.state:
132
135
  raise ValueError("Task run is not set")
133
136
  return self.task_run.state
134
137
 
@@ -142,8 +145,8 @@ class BaseTaskRunEngine(Generic[P, R]):
142
145
  return False
143
146
 
144
147
  def compute_transaction_key(self) -> Optional[str]:
145
- key = None
146
- if self.task.cache_policy:
148
+ key: Optional[str] = None
149
+ if self.task.cache_policy and isinstance(self.task.cache_policy, CachePolicy):
147
150
  flow_run_context = FlowRunContext.get()
148
151
  task_run_context = TaskRunContext.get()
149
152
 
@@ -153,10 +156,12 @@ class BaseTaskRunEngine(Generic[P, R]):
153
156
  parameters = None
154
157
 
155
158
  try:
159
+ if not task_run_context:
160
+ raise ValueError("Task run context is not set")
156
161
  key = self.task.cache_policy.compute_key(
157
162
  task_ctx=task_run_context,
158
- inputs=self.parameters,
159
- flow_parameters=parameters,
163
+ inputs=self.parameters or {},
164
+ flow_parameters=parameters or {},
160
165
  )
161
166
  except Exception:
162
167
  self.logger.exception(
@@ -169,7 +174,7 @@ class BaseTaskRunEngine(Generic[P, R]):
169
174
 
170
175
  def _resolve_parameters(self):
171
176
  if not self.parameters:
172
- return {}
177
+ return None
173
178
 
174
179
  resolved_parameters = {}
175
180
  for parameter, value in self.parameters.items():
@@ -227,10 +232,8 @@ class BaseTaskRunEngine(Generic[P, R]):
227
232
  if self.task_run and self.task_run.start_time and not self.task_run.end_time:
228
233
  self.task_run.end_time = state.timestamp
229
234
 
230
- if self.task_run.state.is_running():
231
- self.task_run.total_run_time += (
232
- state.timestamp - self.task_run.state.timestamp
233
- )
235
+ if self.state.is_running():
236
+ self.task_run.total_run_time += state.timestamp - self.state.timestamp
234
237
 
235
238
  def is_running(self) -> bool:
236
239
  """Whether or not the engine is currently running a task."""
@@ -238,7 +241,7 @@ class BaseTaskRunEngine(Generic[P, R]):
238
241
  return False
239
242
  return task_run.state.is_running() or task_run.state.is_scheduled()
240
243
 
241
- def log_finished_message(self):
244
+ def log_finished_message(self) -> None:
242
245
  if not self.task_run:
243
246
  return
244
247
 
@@ -294,6 +297,7 @@ class BaseTaskRunEngine(Generic[P, R]):
294
297
 
295
298
  @dataclass
296
299
  class SyncTaskRunEngine(BaseTaskRunEngine[P, R]):
300
+ task_run: Optional[TaskRun] = None
297
301
  _client: Optional[SyncPrefectClient] = None
298
302
 
299
303
  @property
@@ -336,7 +340,7 @@ class SyncTaskRunEngine(BaseTaskRunEngine[P, R]):
336
340
  )
337
341
  return False
338
342
 
339
- def call_hooks(self, state: Optional[State] = None):
343
+ def call_hooks(self, state: Optional[State] = None) -> None:
340
344
  if state is None:
341
345
  state = self.state
342
346
  task = self.task
@@ -371,7 +375,7 @@ class SyncTaskRunEngine(BaseTaskRunEngine[P, R]):
371
375
  else:
372
376
  self.logger.info(f"Hook {hook_name!r} finished running successfully")
373
377
 
374
- def begin_run(self):
378
+ def begin_run(self) -> None:
375
379
  try:
376
380
  self._resolve_parameters()
377
381
  self._set_custom_task_run_name()
@@ -390,6 +394,7 @@ class SyncTaskRunEngine(BaseTaskRunEngine[P, R]):
390
394
 
391
395
  new_state = Running()
392
396
 
397
+ assert self.task_run is not None, "Task run is not set"
393
398
  self.task_run.start_time = new_state.timestamp
394
399
 
395
400
  flow_run_context = FlowRunContext.get()
@@ -406,7 +411,7 @@ class SyncTaskRunEngine(BaseTaskRunEngine[P, R]):
406
411
  # result reference that no longer exists
407
412
  if state.is_completed():
408
413
  try:
409
- state.result(retry_result_failure=False, _sync=True)
414
+ state.result(retry_result_failure=False, _sync=True) # type: ignore[reportCallIssue]
410
415
  except Exception:
411
416
  state = self.set_state(new_state, force=True)
412
417
 
@@ -422,7 +427,7 @@ class SyncTaskRunEngine(BaseTaskRunEngine[P, R]):
422
427
  time.sleep(interval)
423
428
  state = self.set_state(new_state)
424
429
 
425
- def set_state(self, state: State, force: bool = False) -> State:
430
+ def set_state(self, state: State[R], force: bool = False) -> State[R]:
426
431
  last_state = self.state
427
432
  if not self.task_run:
428
433
  raise ValueError("Task run is not set")
@@ -537,7 +542,7 @@ class SyncTaskRunEngine(BaseTaskRunEngine[P, R]):
537
542
  new_state = Retrying()
538
543
 
539
544
  self.logger.info(
540
- "Task run failed with exception: %r - " "Retry %s/%s will start %s",
545
+ "Task run failed with exception: %r - Retry %s/%s will start %s",
541
546
  exc,
542
547
  self.retries + 1,
543
548
  self.task.retries,
@@ -545,7 +550,7 @@ class SyncTaskRunEngine(BaseTaskRunEngine[P, R]):
545
550
  )
546
551
 
547
552
  self.set_state(new_state, force=True)
548
- self.retries = self.retries + 1
553
+ self.retries: int = self.retries + 1
549
554
  return True
550
555
  elif self.retries >= self.task.retries:
551
556
  self.logger.error(
@@ -639,7 +644,9 @@ class SyncTaskRunEngine(BaseTaskRunEngine[P, R]):
639
644
  stack.enter_context(ConcurrencyContextV1())
640
645
  stack.enter_context(ConcurrencyContext())
641
646
 
642
- self.logger = task_run_logger(task_run=self.task_run, task=self.task) # type: ignore
647
+ self.logger: "logging.Logger" = task_run_logger(
648
+ task_run=self.task_run, task=self.task
649
+ ) # type: ignore
643
650
 
644
651
  yield
645
652
 
@@ -648,7 +655,7 @@ class SyncTaskRunEngine(BaseTaskRunEngine[P, R]):
648
655
  self,
649
656
  task_run_id: Optional[UUID] = None,
650
657
  dependencies: Optional[dict[str, set[TaskRunInput]]] = None,
651
- ) -> Generator["SyncTaskRunEngine", Any, Any]:
658
+ ) -> Generator[Self, Any, Any]:
652
659
  """
653
660
  Enters a client context and creates a task run if needed.
654
661
  """
@@ -718,7 +725,7 @@ class SyncTaskRunEngine(BaseTaskRunEngine[P, R]):
718
725
  self._is_started = False
719
726
  self._client = None
720
727
 
721
- async def wait_until_ready(self):
728
+ async def wait_until_ready(self) -> None:
722
729
  """Waits until the scheduled time (if its the future), then enters Running."""
723
730
  if scheduled_time := self.state.state_details.scheduled_time:
724
731
  sleep_time = (scheduled_time - pendulum.now("utc")).total_seconds()
@@ -825,6 +832,7 @@ class SyncTaskRunEngine(BaseTaskRunEngine[P, R]):
825
832
 
826
833
  @dataclass
827
834
  class AsyncTaskRunEngine(BaseTaskRunEngine[P, R]):
835
+ task_run: TaskRun | None = None
828
836
  _client: Optional[PrefectClient] = None
829
837
 
830
838
  @property
@@ -866,7 +874,7 @@ class AsyncTaskRunEngine(BaseTaskRunEngine[P, R]):
866
874
  )
867
875
  return False
868
876
 
869
- async def call_hooks(self, state: Optional[State] = None):
877
+ async def call_hooks(self, state: Optional[State] = None) -> None:
870
878
  if state is None:
871
879
  state = self.state
872
880
  task = self.task
@@ -901,7 +909,7 @@ class AsyncTaskRunEngine(BaseTaskRunEngine[P, R]):
901
909
  else:
902
910
  self.logger.info(f"Hook {hook_name!r} finished running successfully")
903
911
 
904
- async def begin_run(self):
912
+ async def begin_run(self) -> None:
905
913
  try:
906
914
  self._resolve_parameters()
907
915
  self._set_custom_task_run_name()
@@ -1067,7 +1075,7 @@ class AsyncTaskRunEngine(BaseTaskRunEngine[P, R]):
1067
1075
  new_state = Retrying()
1068
1076
 
1069
1077
  self.logger.info(
1070
- "Task run failed with exception: %r - " "Retry %s/%s will start %s",
1078
+ "Task run failed with exception: %r - Retry %s/%s will start %s",
1071
1079
  exc,
1072
1080
  self.retries + 1,
1073
1081
  self.task.retries,
@@ -1075,7 +1083,7 @@ class AsyncTaskRunEngine(BaseTaskRunEngine[P, R]):
1075
1083
  )
1076
1084
 
1077
1085
  await self.set_state(new_state, force=True)
1078
- self.retries = self.retries + 1
1086
+ self.retries: int = self.retries + 1
1079
1087
  return True
1080
1088
  elif self.retries >= self.task.retries:
1081
1089
  self.logger.error(
@@ -1169,7 +1177,9 @@ class AsyncTaskRunEngine(BaseTaskRunEngine[P, R]):
1169
1177
  )
1170
1178
  stack.enter_context(ConcurrencyContext())
1171
1179
 
1172
- self.logger = task_run_logger(task_run=self.task_run, task=self.task) # type: ignore
1180
+ self.logger: "logging.Logger" = task_run_logger(
1181
+ task_run=self.task_run, task=self.task
1182
+ ) # type: ignore
1173
1183
 
1174
1184
  yield
1175
1185
 
@@ -1178,7 +1188,7 @@ class AsyncTaskRunEngine(BaseTaskRunEngine[P, R]):
1178
1188
  self,
1179
1189
  task_run_id: Optional[UUID] = None,
1180
1190
  dependencies: Optional[dict[str, set[TaskRunInput]]] = None,
1181
- ) -> AsyncGenerator["AsyncTaskRunEngine", Any]:
1191
+ ) -> AsyncGenerator[Self, Any]:
1182
1192
  """
1183
1193
  Enters a client context and creates a task run if needed.
1184
1194
  """
@@ -1246,7 +1256,7 @@ class AsyncTaskRunEngine(BaseTaskRunEngine[P, R]):
1246
1256
  self._is_started = False
1247
1257
  self._client = None
1248
1258
 
1249
- async def wait_until_ready(self):
1259
+ async def wait_until_ready(self) -> None:
1250
1260
  """Waits until the scheduled time (if its the future), then enters Running."""
1251
1261
  if scheduled_time := self.state.state_details.scheduled_time:
1252
1262
  sleep_time = (scheduled_time - pendulum.now("utc")).total_seconds()
@@ -1341,7 +1351,7 @@ class AsyncTaskRunEngine(BaseTaskRunEngine[P, R]):
1341
1351
  if transaction.is_committed():
1342
1352
  result = transaction.read()
1343
1353
  else:
1344
- if self.task_run.tags:
1354
+ if self.task_run and self.task_run.tags:
1345
1355
  # Acquire a concurrency slot for each tag, but only if a limit
1346
1356
  # matching the tag already exists.
1347
1357
  async with aconcurrency(list(self.task_run.tags), self.task_run.id):
@@ -1546,7 +1556,7 @@ def run_task(
1546
1556
  Returns:
1547
1557
  The result of the task run
1548
1558
  """
1549
- kwargs = dict(
1559
+ kwargs: dict[str, Any] = dict(
1550
1560
  task=task,
1551
1561
  task_run_id=task_run_id,
1552
1562
  task_run=task_run,