prefect-client 3.4.5.dev4__py3-none-any.whl → 3.4.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,17 +1,14 @@
1
1
  from __future__ import annotations
2
2
 
3
- import inspect
4
3
  import warnings
5
4
  from pathlib import Path
6
5
  from typing import (
7
6
  Annotated,
8
7
  Any,
9
8
  ClassVar,
10
- Dict,
11
9
  Iterable,
12
10
  Iterator,
13
11
  Optional,
14
- Union,
15
12
  )
16
13
 
17
14
  import toml
@@ -20,16 +17,15 @@ from pydantic import (
20
17
  BeforeValidator,
21
18
  ConfigDict,
22
19
  Field,
23
- TypeAdapter,
24
20
  ValidationError,
25
21
  )
26
- from pydantic_settings import BaseSettings
27
22
 
28
23
  from prefect.exceptions import ProfileSettingsValidationError
29
24
  from prefect.settings.constants import DEFAULT_PROFILES_PATH
30
25
  from prefect.settings.context import get_current_settings
31
26
  from prefect.settings.legacy import Setting, _get_settings_fields
32
27
  from prefect.settings.models.root import Settings
28
+ from prefect.utilities.collections import set_in_dict
33
29
 
34
30
 
35
31
  def _cast_settings(
@@ -69,7 +65,7 @@ class Profile(BaseModel):
69
65
  )
70
66
  source: Optional[Path] = None
71
67
 
72
- def to_environment_variables(self) -> Dict[str, str]:
68
+ def to_environment_variables(self) -> dict[str, str]:
73
69
  """Convert the profile settings to a dictionary of environment variables."""
74
70
  return {
75
71
  setting.name: str(value)
@@ -78,23 +74,40 @@ class Profile(BaseModel):
78
74
  }
79
75
 
80
76
  def validate_settings(self) -> None:
81
- errors: list[tuple[Setting, ValidationError]] = []
77
+ """
78
+ Validate all settings in this profile by creating a partial Settings object
79
+ with the nested structure properly constructed using accessor paths.
80
+ """
81
+ if not self.settings:
82
+ return
83
+
84
+ nested_settings: dict[str, Any] = {}
85
+
82
86
  for setting, value in self.settings.items():
83
- try:
84
- model_fields = Settings.model_fields
85
- annotation = None
86
- for section in setting.accessor.split("."):
87
- annotation = model_fields[section].annotation
88
- if inspect.isclass(annotation) and issubclass(
89
- annotation, BaseSettings
90
- ):
91
- model_fields = annotation.model_fields
92
-
93
- TypeAdapter(annotation).validate_python(value)
94
- except ValidationError as e:
95
- errors.append((setting, e))
96
- if errors:
97
- raise ProfileSettingsValidationError(errors)
87
+ set_in_dict(nested_settings, setting.accessor, value)
88
+
89
+ try:
90
+ Settings.model_validate(nested_settings)
91
+ except ValidationError as e:
92
+ errors: list[tuple[Setting, ValidationError]] = []
93
+
94
+ for error in e.errors():
95
+ error_path = ".".join(str(loc) for loc in error["loc"])
96
+
97
+ for setting in self.settings.keys():
98
+ if setting.accessor == error_path:
99
+ errors.append(
100
+ (
101
+ setting,
102
+ ValidationError.from_exception_data(
103
+ "ValidationError", [error]
104
+ ),
105
+ )
106
+ )
107
+ break
108
+
109
+ if errors:
110
+ raise ProfileSettingsValidationError(errors)
98
111
 
99
112
 
100
113
  class ProfilesCollection:
@@ -106,9 +119,7 @@ class ProfilesCollection:
106
119
  The collection may store the name of the active profile.
107
120
  """
108
121
 
109
- def __init__(
110
- self, profiles: Iterable[Profile], active: Optional[str] = None
111
- ) -> None:
122
+ def __init__(self, profiles: Iterable[Profile], active: str | None = None) -> None:
112
123
  self.profiles_by_name: dict[str, Profile] = {
113
124
  profile.name: profile for profile in profiles
114
125
  }
@@ -122,7 +133,7 @@ class ProfilesCollection:
122
133
  return set(self.profiles_by_name.keys())
123
134
 
124
135
  @property
125
- def active_profile(self) -> Optional[Profile]:
136
+ def active_profile(self) -> Profile | None:
126
137
  """
127
138
  Retrieve the active profile in this collection.
128
139
  """
@@ -130,7 +141,7 @@ class ProfilesCollection:
130
141
  return None
131
142
  return self[self.active_name]
132
143
 
133
- def set_active(self, name: Optional[str], check: bool = True) -> None:
144
+ def set_active(self, name: str | None, check: bool = True) -> None:
134
145
  """
135
146
  Set the active profile name in the collection.
136
147
 
@@ -145,7 +156,7 @@ class ProfilesCollection:
145
156
  self,
146
157
  name: str,
147
158
  settings: dict[Setting, Any],
148
- source: Optional[Path] = None,
159
+ source: Path | None = None,
149
160
  ) -> Profile:
150
161
  """
151
162
  Add a profile to the collection or update the existing on if the name is already
@@ -201,7 +212,7 @@ class ProfilesCollection:
201
212
  """
202
213
  self.profiles_by_name.pop(name)
203
214
 
204
- def without_profile_source(self, path: Optional[Path]) -> "ProfilesCollection":
215
+ def without_profile_source(self, path: Path | None) -> "ProfilesCollection":
205
216
  """
206
217
  Remove profiles that were loaded from a given path.
207
218
 
@@ -367,7 +378,7 @@ def load_profile(name: str) -> Profile:
367
378
 
368
379
 
369
380
  def update_current_profile(
370
- settings: Dict[Union[str, Setting], Any],
381
+ settings: dict[str | Setting, Any],
371
382
  ) -> Profile:
372
383
  """
373
384
  Update the persisted data for the profile currently in-use.
prefect/task_engine.py CHANGED
@@ -43,6 +43,7 @@ from prefect.concurrency.v1.asyncio import concurrency as aconcurrency
43
43
  from prefect.concurrency.v1.context import ConcurrencyContext as ConcurrencyContextV1
44
44
  from prefect.concurrency.v1.sync import concurrency
45
45
  from prefect.context import (
46
+ AssetContext,
46
47
  AsyncClientContext,
47
48
  FlowRunContext,
48
49
  SyncClientContext,
@@ -314,10 +315,13 @@ class SyncTaskRunEngine(BaseTaskRunEngine[P, R]):
314
315
  raise RuntimeError("Engine has not started.")
315
316
  return self._client
316
317
 
317
- def can_retry(self, exc: Exception) -> bool:
318
+ def can_retry(self, exc_or_state: Exception | State[R]) -> bool:
318
319
  retry_condition: Optional[
319
- Callable[["Task[P, Coroutine[Any, Any, R]]", TaskRun, State], bool]
320
+ Callable[["Task[P, Coroutine[Any, Any, R]]", TaskRun, State[R]], bool]
320
321
  ] = self.task.retry_condition_fn
322
+
323
+ failure_type = "exception" if isinstance(exc_or_state, Exception) else "state"
324
+
321
325
  if not self.task_run:
322
326
  raise ValueError("Task run is not set")
323
327
  try:
@@ -326,8 +330,8 @@ class SyncTaskRunEngine(BaseTaskRunEngine[P, R]):
326
330
  f" {self.task.name!r}"
327
331
  )
328
332
  state = Failed(
329
- data=exc,
330
- message=f"Task run encountered unexpected exception: {repr(exc)}",
333
+ data=exc_or_state,
334
+ message=f"Task run encountered unexpected {failure_type}: {repr(exc_or_state)}",
331
335
  )
332
336
  if asyncio.iscoroutinefunction(retry_condition):
333
337
  should_retry = run_coro_as_sync(
@@ -449,7 +453,7 @@ class SyncTaskRunEngine(BaseTaskRunEngine[P, R]):
449
453
  else:
450
454
  result = state.data
451
455
 
452
- link_state_to_result(state, result)
456
+ link_state_to_result(new_state, result)
453
457
 
454
458
  # emit a state change event
455
459
  self._last_event = emit_task_run_state_change_event(
@@ -476,7 +480,15 @@ class SyncTaskRunEngine(BaseTaskRunEngine[P, R]):
476
480
  # otherwise, return the exception
477
481
  return self._raised
478
482
 
479
- def handle_success(self, result: R, transaction: Transaction) -> R:
483
+ def handle_success(
484
+ self, result: R, transaction: Transaction
485
+ ) -> Union[ResultRecord[R], None, Coroutine[Any, Any, R], R]:
486
+ # Handle the case where the task explicitly returns a failed state, in
487
+ # which case we should retry the task if it has retries left.
488
+ if isinstance(result, State) and result.is_failed():
489
+ if self.handle_retry(result):
490
+ return None
491
+
480
492
  if self.task.cache_expiration is not None:
481
493
  expiration = prefect.types._datetime.now("UTC") + self.task.cache_expiration
482
494
  else:
@@ -508,16 +520,16 @@ class SyncTaskRunEngine(BaseTaskRunEngine[P, R]):
508
520
  self._return_value = result
509
521
 
510
522
  self._telemetry.end_span_on_success()
511
- return result
512
523
 
513
- def handle_retry(self, exc: Exception) -> bool:
524
+ def handle_retry(self, exc_or_state: Exception | State[R]) -> bool:
514
525
  """Handle any task run retries.
515
526
 
516
527
  - If the task has retries left, and the retry condition is met, set the task to retrying and return True.
517
528
  - If the task has a retry delay, place in AwaitingRetry state with a delayed scheduled time.
518
529
  - If the task has no retries left, or the retry condition is not met, return False.
519
530
  """
520
- if self.retries < self.task.retries and self.can_retry(exc):
531
+ failure_type = "exception" if isinstance(exc_or_state, Exception) else "state"
532
+ if self.retries < self.task.retries and self.can_retry(exc_or_state):
521
533
  if self.task.retry_delay_seconds:
522
534
  delay = (
523
535
  self.task.retry_delay_seconds[
@@ -535,8 +547,9 @@ class SyncTaskRunEngine(BaseTaskRunEngine[P, R]):
535
547
  new_state = Retrying()
536
548
 
537
549
  self.logger.info(
538
- "Task run failed with exception: %r - Retry %s/%s will start %s",
539
- exc,
550
+ "Task run failed with %s: %r - Retry %s/%s will start %s",
551
+ failure_type,
552
+ exc_or_state,
540
553
  self.retries + 1,
541
554
  self.task.retries,
542
555
  str(delay) + " second(s) from now" if delay else "immediately",
@@ -552,7 +565,7 @@ class SyncTaskRunEngine(BaseTaskRunEngine[P, R]):
552
565
  else "No retries configured for this task."
553
566
  )
554
567
  self.logger.error(
555
- f"Task run failed with exception: {exc!r} - {retry_message_suffix}",
568
+ f"Task run failed with {failure_type}: {exc_or_state!r} - {retry_message_suffix}",
556
569
  exc_info=True,
557
570
  )
558
571
  return False
@@ -625,6 +638,7 @@ class SyncTaskRunEngine(BaseTaskRunEngine[P, R]):
625
638
  persist_result = settings.tasks.default_persist_result
626
639
  else:
627
640
  persist_result = should_persist_result()
641
+
628
642
  stack.enter_context(
629
643
  TaskRunContext(
630
644
  task=self.task,
@@ -647,6 +661,24 @@ class SyncTaskRunEngine(BaseTaskRunEngine[P, R]):
647
661
 
648
662
  yield
649
663
 
664
+ @contextmanager
665
+ def asset_context(self):
666
+ parent_asset_ctx = AssetContext.get()
667
+
668
+ if parent_asset_ctx and parent_asset_ctx.copy_to_child_ctx:
669
+ asset_ctx = parent_asset_ctx.model_copy()
670
+ asset_ctx.copy_to_child_ctx = False
671
+ else:
672
+ asset_ctx = AssetContext.from_task_and_inputs(
673
+ self.task, self.task_run.id, self.task_run.task_inputs
674
+ )
675
+
676
+ with asset_ctx as ctx:
677
+ try:
678
+ yield
679
+ finally:
680
+ ctx.emit_events(self.state)
681
+
650
682
  @contextmanager
651
683
  def initialize_run(
652
684
  self,
@@ -830,7 +862,7 @@ class SyncTaskRunEngine(BaseTaskRunEngine[P, R]):
830
862
 
831
863
  def call_task_fn(
832
864
  self, transaction: Transaction
833
- ) -> Union[R, Coroutine[Any, Any, R]]:
865
+ ) -> Union[ResultRecord[Any], None, Coroutine[Any, Any, R], R]:
834
866
  """
835
867
  Convenience method to call the task function. Returns a coroutine if the
836
868
  task is async.
@@ -855,10 +887,13 @@ class AsyncTaskRunEngine(BaseTaskRunEngine[P, R]):
855
887
  raise RuntimeError("Engine has not started.")
856
888
  return self._client
857
889
 
858
- async def can_retry(self, exc: Exception) -> bool:
890
+ async def can_retry(self, exc_or_state: Exception | State[R]) -> bool:
859
891
  retry_condition: Optional[
860
- Callable[["Task[P, Coroutine[Any, Any, R]]", TaskRun, State], bool]
892
+ Callable[["Task[P, Coroutine[Any, Any, R]]", TaskRun, State[R]], bool]
861
893
  ] = self.task.retry_condition_fn
894
+
895
+ failure_type = "exception" if isinstance(exc_or_state, Exception) else "state"
896
+
862
897
  if not self.task_run:
863
898
  raise ValueError("Task run is not set")
864
899
  try:
@@ -867,8 +902,8 @@ class AsyncTaskRunEngine(BaseTaskRunEngine[P, R]):
867
902
  f" {self.task.name!r}"
868
903
  )
869
904
  state = Failed(
870
- data=exc,
871
- message=f"Task run encountered unexpected exception: {repr(exc)}",
905
+ data=exc_or_state,
906
+ message=f"Task run encountered unexpected {failure_type}: {repr(exc_or_state)}",
872
907
  )
873
908
  if asyncio.iscoroutinefunction(retry_condition):
874
909
  should_retry = await retry_condition(self.task, self.task_run, state)
@@ -1031,7 +1066,13 @@ class AsyncTaskRunEngine(BaseTaskRunEngine[P, R]):
1031
1066
  # otherwise, return the exception
1032
1067
  return self._raised
1033
1068
 
1034
- async def handle_success(self, result: R, transaction: AsyncTransaction) -> R:
1069
+ async def handle_success(
1070
+ self, result: R, transaction: AsyncTransaction
1071
+ ) -> Union[ResultRecord[R], None, Coroutine[Any, Any, R], R]:
1072
+ if isinstance(result, State) and result.is_failed():
1073
+ if await self.handle_retry(result):
1074
+ return None
1075
+
1035
1076
  if self.task.cache_expiration is not None:
1036
1077
  expiration = prefect.types._datetime.now("UTC") + self.task.cache_expiration
1037
1078
  else:
@@ -1059,19 +1100,20 @@ class AsyncTaskRunEngine(BaseTaskRunEngine[P, R]):
1059
1100
  self.record_terminal_state_timing(terminal_state)
1060
1101
  await self.set_state(terminal_state)
1061
1102
  self._return_value = result
1062
-
1063
1103
  self._telemetry.end_span_on_success()
1064
1104
 
1065
1105
  return result
1066
1106
 
1067
- async def handle_retry(self, exc: Exception) -> bool:
1107
+ async def handle_retry(self, exc_or_state: Exception | State[R]) -> bool:
1068
1108
  """Handle any task run retries.
1069
1109
 
1070
1110
  - If the task has retries left, and the retry condition is met, set the task to retrying and return True.
1071
1111
  - If the task has a retry delay, place in AwaitingRetry state with a delayed scheduled time.
1072
1112
  - If the task has no retries left, or the retry condition is not met, return False.
1073
1113
  """
1074
- if self.retries < self.task.retries and await self.can_retry(exc):
1114
+ failure_type = "exception" if isinstance(exc_or_state, Exception) else "state"
1115
+
1116
+ if self.retries < self.task.retries and await self.can_retry(exc_or_state):
1075
1117
  if self.task.retry_delay_seconds:
1076
1118
  delay = (
1077
1119
  self.task.retry_delay_seconds[
@@ -1089,8 +1131,9 @@ class AsyncTaskRunEngine(BaseTaskRunEngine[P, R]):
1089
1131
  new_state = Retrying()
1090
1132
 
1091
1133
  self.logger.info(
1092
- "Task run failed with exception: %r - Retry %s/%s will start %s",
1093
- exc,
1134
+ "Task run failed with %s: %r - Retry %s/%s will start %s",
1135
+ failure_type,
1136
+ exc_or_state,
1094
1137
  self.retries + 1,
1095
1138
  self.task.retries,
1096
1139
  str(delay) + " second(s) from now" if delay else "immediately",
@@ -1106,7 +1149,7 @@ class AsyncTaskRunEngine(BaseTaskRunEngine[P, R]):
1106
1149
  else "No retries configured for this task."
1107
1150
  )
1108
1151
  self.logger.error(
1109
- f"Task run failed with exception: {exc!r} - {retry_message_suffix}",
1152
+ f"Task run failed with {failure_type}: {exc_or_state!r} - {retry_message_suffix}",
1110
1153
  exc_info=True,
1111
1154
  )
1112
1155
  return False
@@ -1180,6 +1223,7 @@ class AsyncTaskRunEngine(BaseTaskRunEngine[P, R]):
1180
1223
  persist_result = settings.tasks.default_persist_result
1181
1224
  else:
1182
1225
  persist_result = should_persist_result()
1226
+
1183
1227
  stack.enter_context(
1184
1228
  TaskRunContext(
1185
1229
  task=self.task,
@@ -1201,6 +1245,24 @@ class AsyncTaskRunEngine(BaseTaskRunEngine[P, R]):
1201
1245
 
1202
1246
  yield
1203
1247
 
1248
+ @asynccontextmanager
1249
+ async def asset_context(self):
1250
+ parent_asset_ctx = AssetContext.get()
1251
+
1252
+ if parent_asset_ctx and parent_asset_ctx.copy_to_child_ctx:
1253
+ asset_ctx = parent_asset_ctx.model_copy()
1254
+ asset_ctx.copy_to_child_ctx = False
1255
+ else:
1256
+ asset_ctx = AssetContext.from_task_and_inputs(
1257
+ self.task, self.task_run.id, self.task_run.task_inputs
1258
+ )
1259
+
1260
+ with asset_ctx as ctx:
1261
+ try:
1262
+ yield
1263
+ finally:
1264
+ ctx.emit_events(self.state)
1265
+
1204
1266
  @asynccontextmanager
1205
1267
  async def initialize_run(
1206
1268
  self,
@@ -1382,7 +1444,7 @@ class AsyncTaskRunEngine(BaseTaskRunEngine[P, R]):
1382
1444
 
1383
1445
  async def call_task_fn(
1384
1446
  self, transaction: AsyncTransaction
1385
- ) -> Union[R, Coroutine[Any, Any, R]]:
1447
+ ) -> Union[ResultRecord[Any], None, Coroutine[Any, Any, R], R]:
1386
1448
  """
1387
1449
  Convenience method to call the task function. Returns a coroutine if the
1388
1450
  task is async.
@@ -1417,7 +1479,11 @@ def run_task_sync(
1417
1479
  with engine.start(task_run_id=task_run_id, dependencies=dependencies):
1418
1480
  while engine.is_running():
1419
1481
  run_coro_as_sync(engine.wait_until_ready())
1420
- with engine.run_context(), engine.transaction_context() as txn:
1482
+ with (
1483
+ engine.asset_context(),
1484
+ engine.run_context(),
1485
+ engine.transaction_context() as txn,
1486
+ ):
1421
1487
  engine.call_task_fn(txn)
1422
1488
 
1423
1489
  return engine.state if return_type == "state" else engine.result()
@@ -1444,7 +1510,11 @@ async def run_task_async(
1444
1510
  async with engine.start(task_run_id=task_run_id, dependencies=dependencies):
1445
1511
  while engine.is_running():
1446
1512
  await engine.wait_until_ready()
1447
- async with engine.run_context(), engine.transaction_context() as txn:
1513
+ async with (
1514
+ engine.asset_context(),
1515
+ engine.run_context(),
1516
+ engine.transaction_context() as txn,
1517
+ ):
1448
1518
  await engine.call_task_fn(txn)
1449
1519
 
1450
1520
  return engine.state if return_type == "state" else await engine.result()
@@ -1474,7 +1544,11 @@ def run_generator_task_sync(
1474
1544
  with engine.start(task_run_id=task_run_id, dependencies=dependencies):
1475
1545
  while engine.is_running():
1476
1546
  run_coro_as_sync(engine.wait_until_ready())
1477
- with engine.run_context(), engine.transaction_context() as txn:
1547
+ with (
1548
+ engine.asset_context(),
1549
+ engine.run_context(),
1550
+ engine.transaction_context() as txn,
1551
+ ):
1478
1552
  # TODO: generators should default to commit_mode=OFF
1479
1553
  # because they are dynamic by definition
1480
1554
  # for now we just prevent this branch explicitly
@@ -1528,7 +1602,11 @@ async def run_generator_task_async(
1528
1602
  async with engine.start(task_run_id=task_run_id, dependencies=dependencies):
1529
1603
  while engine.is_running():
1530
1604
  await engine.wait_until_ready()
1531
- async with engine.run_context(), engine.transaction_context() as txn:
1605
+ async with (
1606
+ engine.asset_context(),
1607
+ engine.run_context(),
1608
+ engine.transaction_context() as txn,
1609
+ ):
1532
1610
  # TODO: generators should default to commit_mode=OFF
1533
1611
  # because they are dynamic by definition
1534
1612
  # for now we just prevent this branch explicitly