prefect-client 3.4.5.dev4__py3-none-any.whl → 3.4.6.dev1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
prefect/task_engine.py CHANGED
@@ -43,6 +43,7 @@ from prefect.concurrency.v1.asyncio import concurrency as aconcurrency
43
43
  from prefect.concurrency.v1.context import ConcurrencyContext as ConcurrencyContextV1
44
44
  from prefect.concurrency.v1.sync import concurrency
45
45
  from prefect.context import (
46
+ AssetContext,
46
47
  AsyncClientContext,
47
48
  FlowRunContext,
48
49
  SyncClientContext,
@@ -314,10 +315,13 @@ class SyncTaskRunEngine(BaseTaskRunEngine[P, R]):
314
315
  raise RuntimeError("Engine has not started.")
315
316
  return self._client
316
317
 
317
- def can_retry(self, exc: Exception) -> bool:
318
+ def can_retry(self, exc_or_state: Exception | State[R]) -> bool:
318
319
  retry_condition: Optional[
319
- Callable[["Task[P, Coroutine[Any, Any, R]]", TaskRun, State], bool]
320
+ Callable[["Task[P, Coroutine[Any, Any, R]]", TaskRun, State[R]], bool]
320
321
  ] = self.task.retry_condition_fn
322
+
323
+ failure_type = "exception" if isinstance(exc_or_state, Exception) else "state"
324
+
321
325
  if not self.task_run:
322
326
  raise ValueError("Task run is not set")
323
327
  try:
@@ -326,8 +330,8 @@ class SyncTaskRunEngine(BaseTaskRunEngine[P, R]):
326
330
  f" {self.task.name!r}"
327
331
  )
328
332
  state = Failed(
329
- data=exc,
330
- message=f"Task run encountered unexpected exception: {repr(exc)}",
333
+ data=exc_or_state,
334
+ message=f"Task run encountered unexpected {failure_type}: {repr(exc_or_state)}",
331
335
  )
332
336
  if asyncio.iscoroutinefunction(retry_condition):
333
337
  should_retry = run_coro_as_sync(
@@ -449,7 +453,9 @@ class SyncTaskRunEngine(BaseTaskRunEngine[P, R]):
449
453
  else:
450
454
  result = state.data
451
455
 
452
- link_state_to_result(state, result)
456
+ link_state_to_result(new_state, result)
457
+ if asset_context := AssetContext.get():
458
+ asset_context.emit_events(new_state)
453
459
 
454
460
  # emit a state change event
455
461
  self._last_event = emit_task_run_state_change_event(
@@ -476,7 +482,15 @@ class SyncTaskRunEngine(BaseTaskRunEngine[P, R]):
476
482
  # otherwise, return the exception
477
483
  return self._raised
478
484
 
479
- def handle_success(self, result: R, transaction: Transaction) -> R:
485
+ def handle_success(
486
+ self, result: R, transaction: Transaction
487
+ ) -> Union[ResultRecord[R], None, Coroutine[Any, Any, R], R]:
488
+ # Handle the case where the task explicitly returns a failed state, in
489
+ # which case we should retry the task if it has retries left.
490
+ if isinstance(result, State) and result.is_failed():
491
+ if self.handle_retry(result):
492
+ return None
493
+
480
494
  if self.task.cache_expiration is not None:
481
495
  expiration = prefect.types._datetime.now("UTC") + self.task.cache_expiration
482
496
  else:
@@ -508,16 +522,16 @@ class SyncTaskRunEngine(BaseTaskRunEngine[P, R]):
508
522
  self._return_value = result
509
523
 
510
524
  self._telemetry.end_span_on_success()
511
- return result
512
525
 
513
- def handle_retry(self, exc: Exception) -> bool:
526
+ def handle_retry(self, exc_or_state: Exception | State[R]) -> bool:
514
527
  """Handle any task run retries.
515
528
 
516
529
  - If the task has retries left, and the retry condition is met, set the task to retrying and return True.
517
530
  - If the task has a retry delay, place in AwaitingRetry state with a delayed scheduled time.
518
531
  - If the task has no retries left, or the retry condition is not met, return False.
519
532
  """
520
- if self.retries < self.task.retries and self.can_retry(exc):
533
+ failure_type = "exception" if isinstance(exc_or_state, Exception) else "state"
534
+ if self.retries < self.task.retries and self.can_retry(exc_or_state):
521
535
  if self.task.retry_delay_seconds:
522
536
  delay = (
523
537
  self.task.retry_delay_seconds[
@@ -535,8 +549,9 @@ class SyncTaskRunEngine(BaseTaskRunEngine[P, R]):
535
549
  new_state = Retrying()
536
550
 
537
551
  self.logger.info(
538
- "Task run failed with exception: %r - Retry %s/%s will start %s",
539
- exc,
552
+ "Task run failed with %s: %r - Retry %s/%s will start %s",
553
+ failure_type,
554
+ exc_or_state,
540
555
  self.retries + 1,
541
556
  self.task.retries,
542
557
  str(delay) + " second(s) from now" if delay else "immediately",
@@ -552,7 +567,7 @@ class SyncTaskRunEngine(BaseTaskRunEngine[P, R]):
552
567
  else "No retries configured for this task."
553
568
  )
554
569
  self.logger.error(
555
- f"Task run failed with exception: {exc!r} - {retry_message_suffix}",
570
+ f"Task run failed with {failure_type}: {exc_or_state!r} - {retry_message_suffix}",
556
571
  exc_info=True,
557
572
  )
558
573
  return False
@@ -625,6 +640,16 @@ class SyncTaskRunEngine(BaseTaskRunEngine[P, R]):
625
640
  persist_result = settings.tasks.default_persist_result
626
641
  else:
627
642
  persist_result = should_persist_result()
643
+
644
+ asset_context = AssetContext.get()
645
+ if not asset_context:
646
+ asset_context = AssetContext.from_task_and_inputs(
647
+ task=self.task,
648
+ task_run_id=self.task_run.id,
649
+ task_inputs=self.task_run.task_inputs,
650
+ )
651
+ stack.enter_context(asset_context)
652
+
628
653
  stack.enter_context(
629
654
  TaskRunContext(
630
655
  task=self.task,
@@ -830,7 +855,7 @@ class SyncTaskRunEngine(BaseTaskRunEngine[P, R]):
830
855
 
831
856
  def call_task_fn(
832
857
  self, transaction: Transaction
833
- ) -> Union[R, Coroutine[Any, Any, R]]:
858
+ ) -> Union[ResultRecord[Any], None, Coroutine[Any, Any, R], R]:
834
859
  """
835
860
  Convenience method to call the task function. Returns a coroutine if the
836
861
  task is async.
@@ -855,10 +880,13 @@ class AsyncTaskRunEngine(BaseTaskRunEngine[P, R]):
855
880
  raise RuntimeError("Engine has not started.")
856
881
  return self._client
857
882
 
858
- async def can_retry(self, exc: Exception) -> bool:
883
+ async def can_retry(self, exc_or_state: Exception | State[R]) -> bool:
859
884
  retry_condition: Optional[
860
- Callable[["Task[P, Coroutine[Any, Any, R]]", TaskRun, State], bool]
885
+ Callable[["Task[P, Coroutine[Any, Any, R]]", TaskRun, State[R]], bool]
861
886
  ] = self.task.retry_condition_fn
887
+
888
+ failure_type = "exception" if isinstance(exc_or_state, Exception) else "state"
889
+
862
890
  if not self.task_run:
863
891
  raise ValueError("Task run is not set")
864
892
  try:
@@ -867,8 +895,8 @@ class AsyncTaskRunEngine(BaseTaskRunEngine[P, R]):
867
895
  f" {self.task.name!r}"
868
896
  )
869
897
  state = Failed(
870
- data=exc,
871
- message=f"Task run encountered unexpected exception: {repr(exc)}",
898
+ data=exc_or_state,
899
+ message=f"Task run encountered unexpected {failure_type}: {repr(exc_or_state)}",
872
900
  )
873
901
  if asyncio.iscoroutinefunction(retry_condition):
874
902
  should_retry = await retry_condition(self.task, self.task_run, state)
@@ -1004,6 +1032,8 @@ class AsyncTaskRunEngine(BaseTaskRunEngine[P, R]):
1004
1032
  result = new_state.data
1005
1033
 
1006
1034
  link_state_to_result(new_state, result)
1035
+ if asset_context := AssetContext.get():
1036
+ asset_context.emit_events(new_state)
1007
1037
 
1008
1038
  # emit a state change event
1009
1039
  self._last_event = emit_task_run_state_change_event(
@@ -1031,7 +1061,13 @@ class AsyncTaskRunEngine(BaseTaskRunEngine[P, R]):
1031
1061
  # otherwise, return the exception
1032
1062
  return self._raised
1033
1063
 
1034
- async def handle_success(self, result: R, transaction: AsyncTransaction) -> R:
1064
+ async def handle_success(
1065
+ self, result: R, transaction: AsyncTransaction
1066
+ ) -> Union[ResultRecord[R], None, Coroutine[Any, Any, R], R]:
1067
+ if isinstance(result, State) and result.is_failed():
1068
+ if await self.handle_retry(result):
1069
+ return None
1070
+
1035
1071
  if self.task.cache_expiration is not None:
1036
1072
  expiration = prefect.types._datetime.now("UTC") + self.task.cache_expiration
1037
1073
  else:
@@ -1059,19 +1095,20 @@ class AsyncTaskRunEngine(BaseTaskRunEngine[P, R]):
1059
1095
  self.record_terminal_state_timing(terminal_state)
1060
1096
  await self.set_state(terminal_state)
1061
1097
  self._return_value = result
1062
-
1063
1098
  self._telemetry.end_span_on_success()
1064
1099
 
1065
1100
  return result
1066
1101
 
1067
- async def handle_retry(self, exc: Exception) -> bool:
1102
+ async def handle_retry(self, exc_or_state: Exception | State[R]) -> bool:
1068
1103
  """Handle any task run retries.
1069
1104
 
1070
1105
  - If the task has retries left, and the retry condition is met, set the task to retrying and return True.
1071
1106
  - If the task has a retry delay, place in AwaitingRetry state with a delayed scheduled time.
1072
1107
  - If the task has no retries left, or the retry condition is not met, return False.
1073
1108
  """
1074
- if self.retries < self.task.retries and await self.can_retry(exc):
1109
+ failure_type = "exception" if isinstance(exc_or_state, Exception) else "state"
1110
+
1111
+ if self.retries < self.task.retries and await self.can_retry(exc_or_state):
1075
1112
  if self.task.retry_delay_seconds:
1076
1113
  delay = (
1077
1114
  self.task.retry_delay_seconds[
@@ -1089,8 +1126,9 @@ class AsyncTaskRunEngine(BaseTaskRunEngine[P, R]):
1089
1126
  new_state = Retrying()
1090
1127
 
1091
1128
  self.logger.info(
1092
- "Task run failed with exception: %r - Retry %s/%s will start %s",
1093
- exc,
1129
+ "Task run failed with %s: %r - Retry %s/%s will start %s",
1130
+ failure_type,
1131
+ exc_or_state,
1094
1132
  self.retries + 1,
1095
1133
  self.task.retries,
1096
1134
  str(delay) + " second(s) from now" if delay else "immediately",
@@ -1106,7 +1144,7 @@ class AsyncTaskRunEngine(BaseTaskRunEngine[P, R]):
1106
1144
  else "No retries configured for this task."
1107
1145
  )
1108
1146
  self.logger.error(
1109
- f"Task run failed with exception: {exc!r} - {retry_message_suffix}",
1147
+ f"Task run failed with {failure_type}: {exc_or_state!r} - {retry_message_suffix}",
1110
1148
  exc_info=True,
1111
1149
  )
1112
1150
  return False
@@ -1180,6 +1218,16 @@ class AsyncTaskRunEngine(BaseTaskRunEngine[P, R]):
1180
1218
  persist_result = settings.tasks.default_persist_result
1181
1219
  else:
1182
1220
  persist_result = should_persist_result()
1221
+
1222
+ asset_context = AssetContext.get()
1223
+ if not asset_context:
1224
+ asset_context = AssetContext.from_task_and_inputs(
1225
+ task=self.task,
1226
+ task_run_id=self.task_run.id,
1227
+ task_inputs=self.task_run.task_inputs,
1228
+ )
1229
+ stack.enter_context(asset_context)
1230
+
1183
1231
  stack.enter_context(
1184
1232
  TaskRunContext(
1185
1233
  task=self.task,
@@ -1382,7 +1430,7 @@ class AsyncTaskRunEngine(BaseTaskRunEngine[P, R]):
1382
1430
 
1383
1431
  async def call_task_fn(
1384
1432
  self, transaction: AsyncTransaction
1385
- ) -> Union[R, Coroutine[Any, Any, R]]:
1433
+ ) -> Union[ResultRecord[Any], None, Coroutine[Any, Any, R], R]:
1386
1434
  """
1387
1435
  Convenience method to call the task function. Returns a coroutine if the
1388
1436
  task is async.
prefect/tasks.py CHANGED
@@ -29,10 +29,20 @@ from typing import (
29
29
  )
30
30
  from uuid import UUID, uuid4
31
31
 
32
- from typing_extensions import Literal, ParamSpec, Self, TypeAlias, TypeIs
32
+ from typing_extensions import (
33
+ Literal,
34
+ ParamSpec,
35
+ Self,
36
+ Sequence,
37
+ TypeAlias,
38
+ TypedDict,
39
+ TypeIs,
40
+ Unpack,
41
+ )
33
42
 
34
43
  import prefect.states
35
44
  from prefect._internal.uuid7 import uuid7
45
+ from prefect.assets import Asset
36
46
  from prefect.cache_policies import DEFAULT, NO_CACHE, CachePolicy
37
47
  from prefect.client.orchestration import get_client
38
48
  from prefect.client.schemas import TaskRun
@@ -90,6 +100,65 @@ OneOrManyFutureOrResult: TypeAlias = Union[
90
100
  ]
91
101
 
92
102
 
103
+ class TaskRunNameCallbackWithParameters(Protocol):
104
+ @classmethod
105
+ def is_callback_with_parameters(cls, callable: Callable[..., str]) -> TypeIs[Self]:
106
+ sig = inspect.signature(callable)
107
+ return "parameters" in sig.parameters
108
+
109
+ def __call__(self, parameters: dict[str, Any]) -> str: ...
110
+
111
+
112
+ StateHookCallable: TypeAlias = Callable[
113
+ ["Task[..., Any]", TaskRun, State], Union[Awaitable[None], None]
114
+ ]
115
+ RetryConditionCallable: TypeAlias = Callable[
116
+ ["Task[..., Any]", TaskRun, State], Union[Awaitable[bool], bool]
117
+ ]
118
+ TaskRunNameValueOrCallable: TypeAlias = Union[
119
+ Callable[[], str], TaskRunNameCallbackWithParameters, str
120
+ ]
121
+
122
+
123
+ class TaskOptions(TypedDict, total=False):
124
+ """
125
+ A TypedDict representing all available task configuration options.
126
+
127
+ This can be used with `Unpack` to provide type hints for **kwargs.
128
+ """
129
+
130
+ name: Optional[str]
131
+ description: Optional[str]
132
+ tags: Optional[Iterable[str]]
133
+ version: Optional[str]
134
+ cache_policy: Union[CachePolicy, type[NotSet]]
135
+ cache_key_fn: Union[
136
+ Callable[["TaskRunContext", dict[str, Any]], Optional[str]], None
137
+ ]
138
+ cache_expiration: Optional[datetime.timedelta]
139
+ task_run_name: Optional[TaskRunNameValueOrCallable]
140
+ retries: Optional[int]
141
+ retry_delay_seconds: Union[
142
+ float, int, list[float], Callable[[int], list[float]], None
143
+ ]
144
+ retry_jitter_factor: Optional[float]
145
+ persist_result: Optional[bool]
146
+ result_storage: Optional[ResultStorage]
147
+ result_serializer: Optional[ResultSerializer]
148
+ result_storage_key: Optional[str]
149
+ cache_result_in_memory: bool
150
+ timeout_seconds: Union[int, float, None]
151
+ log_prints: Optional[bool]
152
+ refresh_cache: Optional[bool]
153
+ on_completion: Optional[list[StateHookCallable]]
154
+ on_failure: Optional[list[StateHookCallable]]
155
+ on_rollback: Optional[list[Callable[["Transaction"], None]]]
156
+ on_commit: Optional[list[Callable[["Transaction"], None]]]
157
+ retry_condition_fn: Optional[RetryConditionCallable]
158
+ viz_return_value: Any
159
+ asset_deps: Optional[list[Union[Asset, str]]]
160
+
161
+
93
162
  def task_input_hash(
94
163
  context: "TaskRunContext", arguments: dict[str, Any]
95
164
  ) -> Optional[str]:
@@ -223,23 +292,6 @@ def _generate_task_key(fn: Callable[..., Any]) -> str:
223
292
  return f"{qualname}-{code_hash}"
224
293
 
225
294
 
226
- class TaskRunNameCallbackWithParameters(Protocol):
227
- @classmethod
228
- def is_callback_with_parameters(cls, callable: Callable[..., str]) -> TypeIs[Self]:
229
- sig = inspect.signature(callable)
230
- return "parameters" in sig.parameters
231
-
232
- def __call__(self, parameters: dict[str, Any]) -> str: ...
233
-
234
-
235
- StateHookCallable: TypeAlias = Callable[
236
- ["Task[..., Any]", TaskRun, State], Union[Awaitable[None], None]
237
- ]
238
- TaskRunNameValueOrCallable: TypeAlias = Union[
239
- Callable[[], str], TaskRunNameCallbackWithParameters, str
240
- ]
241
-
242
-
243
295
  class Task(Generic[P, R]):
244
296
  """
245
297
  A Prefect task definition.
@@ -311,6 +363,7 @@ class Task(Generic[P, R]):
311
363
  should end as failed. Defaults to `None`, indicating the task should always continue
312
364
  to its retry policy.
313
365
  viz_return_value: An optional value to return when the task dependency tree is visualized.
366
+ asset_deps: An optional list of upstream assets that this task depends on.
314
367
  """
315
368
 
316
369
  # NOTE: These parameters (types, defaults, and docstrings) should be duplicated
@@ -350,10 +403,9 @@ class Task(Generic[P, R]):
350
403
  on_failure: Optional[list[StateHookCallable]] = None,
351
404
  on_rollback: Optional[list[Callable[["Transaction"], None]]] = None,
352
405
  on_commit: Optional[list[Callable[["Transaction"], None]]] = None,
353
- retry_condition_fn: Optional[
354
- Callable[["Task[..., Any]", TaskRun, State], bool]
355
- ] = None,
406
+ retry_condition_fn: Optional[RetryConditionCallable] = None,
356
407
  viz_return_value: Optional[Any] = None,
408
+ asset_deps: Optional[list[Union[str, Asset]]] = None,
357
409
  ):
358
410
  # Validate if hook passed is list and contains callables
359
411
  hook_categories = [on_completion, on_failure]
@@ -547,6 +599,14 @@ class Task(Generic[P, R]):
547
599
  self.retry_condition_fn = retry_condition_fn
548
600
  self.viz_return_value = viz_return_value
549
601
 
602
+ from prefect.assets import Asset
603
+
604
+ self.asset_deps: list[Asset] = (
605
+ [Asset(key=a) if isinstance(a, str) else a for a in asset_deps]
606
+ if asset_deps
607
+ else []
608
+ )
609
+
550
610
  @property
551
611
  def ismethod(self) -> bool:
552
612
  return hasattr(self.fn, "__prefect_self__")
@@ -613,10 +673,9 @@ class Task(Generic[P, R]):
613
673
  refresh_cache: Union[bool, type[NotSet]] = NotSet,
614
674
  on_completion: Optional[list[StateHookCallable]] = None,
615
675
  on_failure: Optional[list[StateHookCallable]] = None,
616
- retry_condition_fn: Optional[
617
- Callable[["Task[..., Any]", TaskRun, State], bool]
618
- ] = None,
676
+ retry_condition_fn: Optional[RetryConditionCallable] = None,
619
677
  viz_return_value: Optional[Any] = None,
678
+ asset_deps: Optional[list[Union[str, Asset]]] = None,
620
679
  ) -> "Task[P, R]":
621
680
  """
622
681
  Create a new task from the current object, updating provided options.
@@ -750,6 +809,7 @@ class Task(Generic[P, R]):
750
809
  on_failure=on_failure or self.on_failure_hooks,
751
810
  retry_condition_fn=retry_condition_fn or self.retry_condition_fn,
752
811
  viz_return_value=viz_return_value or self.viz_return_value,
812
+ asset_deps=asset_deps or self.asset_deps,
753
813
  )
754
814
 
755
815
  def on_completion(self, fn: StateHookCallable) -> StateHookCallable:
@@ -887,7 +947,9 @@ class Task(Generic[P, R]):
887
947
  deferred: bool = False,
888
948
  ) -> TaskRun:
889
949
  from prefect.utilities._engine import dynamic_key_for_task_run
890
- from prefect.utilities.engine import collect_task_run_inputs_sync
950
+ from prefect.utilities.engine import (
951
+ collect_task_run_inputs_sync,
952
+ )
891
953
 
892
954
  if flow_run_context is None:
893
955
  flow_run_context = FlowRunContext.get()
@@ -927,7 +989,7 @@ class Task(Generic[P, R]):
927
989
 
928
990
  store = await ResultStore(
929
991
  result_storage=await get_or_create_default_task_scheduling_storage()
930
- ).update_for_task(task)
992
+ ).update_for_task(self)
931
993
  context = serialize_context()
932
994
  data: dict[str, Any] = {"context": context}
933
995
  if parameters:
@@ -963,6 +1025,7 @@ class Task(Generic[P, R]):
963
1025
  else None
964
1026
  )
965
1027
  task_run_id = id or uuid7()
1028
+
966
1029
  state = prefect.states.Pending(
967
1030
  state_details=StateDetails(
968
1031
  task_run_id=task_run_id,
@@ -1664,8 +1727,9 @@ def task(
1664
1727
  refresh_cache: Optional[bool] = None,
1665
1728
  on_completion: Optional[list[StateHookCallable]] = None,
1666
1729
  on_failure: Optional[list[StateHookCallable]] = None,
1667
- retry_condition_fn: Literal[None] = None,
1730
+ retry_condition_fn: Optional[RetryConditionCallable] = None,
1668
1731
  viz_return_value: Any = None,
1732
+ asset_deps: Optional[list[Union[str, Asset]]] = None,
1669
1733
  ) -> Callable[[Callable[P, R]], Task[P, R]]: ...
1670
1734
 
1671
1735
 
@@ -1699,8 +1763,9 @@ def task(
1699
1763
  refresh_cache: Optional[bool] = None,
1700
1764
  on_completion: Optional[list[StateHookCallable]] = None,
1701
1765
  on_failure: Optional[list[StateHookCallable]] = None,
1702
- retry_condition_fn: Optional[Callable[[Task[P, R], TaskRun, State], bool]] = None,
1766
+ retry_condition_fn: Optional[RetryConditionCallable] = None,
1703
1767
  viz_return_value: Any = None,
1768
+ asset_deps: Optional[list[Union[str, Asset]]] = None,
1704
1769
  ) -> Callable[[Callable[P, R]], Task[P, R]]: ...
1705
1770
 
1706
1771
 
@@ -1735,8 +1800,9 @@ def task(
1735
1800
  refresh_cache: Optional[bool] = None,
1736
1801
  on_completion: Optional[list[StateHookCallable]] = None,
1737
1802
  on_failure: Optional[list[StateHookCallable]] = None,
1738
- retry_condition_fn: Optional[Callable[[Task[P, Any], TaskRun, State], bool]] = None,
1803
+ retry_condition_fn: Optional[RetryConditionCallable] = None,
1739
1804
  viz_return_value: Any = None,
1805
+ asset_deps: Optional[list[Union[str, Asset]]] = None,
1740
1806
  ) -> Callable[[Callable[P, R]], Task[P, R]]: ...
1741
1807
 
1742
1808
 
@@ -1768,8 +1834,9 @@ def task(
1768
1834
  refresh_cache: Optional[bool] = None,
1769
1835
  on_completion: Optional[list[StateHookCallable]] = None,
1770
1836
  on_failure: Optional[list[StateHookCallable]] = None,
1771
- retry_condition_fn: Optional[Callable[[Task[P, Any], TaskRun, State], bool]] = None,
1837
+ retry_condition_fn: Optional[RetryConditionCallable] = None,
1772
1838
  viz_return_value: Any = None,
1839
+ asset_deps: Optional[list[Union[str, Asset]]] = None,
1773
1840
  ):
1774
1841
  """
1775
1842
  Decorator to designate a function as a task in a Prefect workflow.
@@ -1830,6 +1897,7 @@ def task(
1830
1897
  should end as failed. Defaults to `None`, indicating the task should always continue
1831
1898
  to its retry policy.
1832
1899
  viz_return_value: An optional value to return when the task dependency tree is visualized.
1900
+ asset_deps: An optional list of upstream assets that this task depends on.
1833
1901
 
1834
1902
  Returns:
1835
1903
  A callable `Task` object which, when called, will submit the task for execution.
@@ -1906,6 +1974,7 @@ def task(
1906
1974
  on_failure=on_failure,
1907
1975
  retry_condition_fn=retry_condition_fn,
1908
1976
  viz_return_value=viz_return_value,
1977
+ asset_deps=asset_deps,
1909
1978
  )
1910
1979
  else:
1911
1980
  return cast(
@@ -1935,5 +2004,32 @@ def task(
1935
2004
  on_failure=on_failure,
1936
2005
  retry_condition_fn=retry_condition_fn,
1937
2006
  viz_return_value=viz_return_value,
2007
+ asset_deps=asset_deps,
1938
2008
  ),
1939
2009
  )
2010
+
2011
+
2012
+ class MaterializingTask(Task[P, R]):
2013
+ """
2014
+ A task that materializes Assets.
2015
+
2016
+ Args:
2017
+ assets: List of Assets that this task materializes (can be str or Asset)
2018
+ materialized_by: An optional tool that materialized the asset e.g. "dbt" or "spark"
2019
+ **task_kwargs: All other Task arguments
2020
+ """
2021
+
2022
+ def __init__(
2023
+ self,
2024
+ fn: Callable[P, R],
2025
+ *,
2026
+ assets: Sequence[Union[str, Asset]],
2027
+ materialized_by: str | None = None,
2028
+ **task_kwargs: Unpack[TaskOptions],
2029
+ ):
2030
+ super().__init__(fn=fn, **task_kwargs)
2031
+
2032
+ self.assets: list[Asset] = [
2033
+ Asset(key=a) if isinstance(a, str) else a for a in assets
2034
+ ]
2035
+ self.materialized_by = materialized_by
prefect/types/__init__.py CHANGED
@@ -14,6 +14,7 @@ from .names import (
14
14
  BANNED_CHARACTERS,
15
15
  WITHOUT_BANNED_CHARACTERS,
16
16
  MAX_VARIABLE_NAME_LENGTH,
17
+ URILike,
17
18
  )
18
19
  from pydantic import (
19
20
  BeforeValidator,
@@ -219,4 +220,5 @@ __all__ = [
219
220
  "StatusCode",
220
221
  "StrictVariableValue",
221
222
  "TaskRetryDelaySeconds",
223
+ "URILike",
222
224
  ]
@@ -223,10 +223,10 @@ def travel_to(dt: Any):
223
223
 
224
224
  def in_local_tz(dt: datetime.datetime) -> datetime.datetime:
225
225
  if sys.version_info >= (3, 13):
226
- from whenever import LocalDateTime, ZonedDateTime
226
+ from whenever import PlainDateTime, ZonedDateTime
227
227
 
228
228
  if dt.tzinfo is None:
229
- wdt = LocalDateTime.from_py_datetime(dt)
229
+ wdt = PlainDateTime.from_py_datetime(dt)
230
230
  else:
231
231
  if not isinstance(dt.tzinfo, ZoneInfo):
232
232
  if key := getattr(dt.tzinfo, "key", None):
prefect/types/names.py CHANGED
@@ -137,3 +137,26 @@ VariableName = Annotated[
137
137
  examples=["my_variable"],
138
138
  ),
139
139
  ]
140
+
141
+
142
+ # URI validation
143
+ URI_REGEX = re.compile(r"^[a-z0-9]+://")
144
+
145
+
146
+ def validate_uri(value: str) -> str:
147
+ """Validate that a string is a valid URI with lowercase protocol."""
148
+ if not URI_REGEX.match(value):
149
+ raise ValueError(
150
+ "Key must be a valid URI, e.g. storage://bucket/folder/asset.csv"
151
+ )
152
+ return value
153
+
154
+
155
+ URILike = Annotated[
156
+ str,
157
+ AfterValidator(validate_uri),
158
+ Field(
159
+ description="A URI-like string with a lowercase protocol",
160
+ examples=["s3://bucket/folder/data.csv", "postgres://dbtable"],
161
+ ),
162
+ ]
@@ -286,8 +286,10 @@ def process_v1_params(
286
286
  docstrings: dict[str, str],
287
287
  aliases: dict[str, str],
288
288
  ) -> tuple[str, Any, Any]:
289
+ import pydantic.v1 as pydantic_v1
290
+
289
291
  # Pydantic model creation will fail if names collide with the BaseModel type
290
- if hasattr(pydantic.BaseModel, param.name):
292
+ if hasattr(pydantic_v1.BaseModel, param.name):
291
293
  name = param.name + "__"
292
294
  aliases[name] = param.name
293
295
  else:
@@ -296,10 +298,9 @@ def process_v1_params(
296
298
  type_ = Any if param.annotation is inspect.Parameter.empty else param.annotation
297
299
 
298
300
  with warnings.catch_warnings():
299
- warnings.filterwarnings(
300
- "ignore", category=pydantic.warnings.PydanticDeprecatedSince20
301
- )
302
- field: Any = pydantic.Field( # type: ignore # this uses the v1 signature, not v2
301
+ # Note: pydantic.v1 doesn't have the warnings module, so we can't suppress them
302
+ warnings.filterwarnings("ignore", category=DeprecationWarning)
303
+ field: Any = pydantic_v1.Field(
303
304
  default=... if param.default is param.empty else param.default,
304
305
  title=param.name,
305
306
  description=docstrings.get(param.name, None),
@@ -312,18 +313,19 @@ def process_v1_params(
312
313
  def create_v1_schema(
313
314
  name_: str, model_cfg: type[Any], model_fields: Optional[dict[str, Any]] = None
314
315
  ) -> dict[str, Any]:
316
+ import pydantic.v1 as pydantic_v1
317
+
315
318
  with warnings.catch_warnings():
316
- warnings.filterwarnings(
317
- "ignore", category=pydantic.warnings.PydanticDeprecatedSince20
318
- )
319
+ # Note: pydantic.v1 doesn't have the warnings module, so we can't suppress them
320
+ warnings.filterwarnings("ignore", category=DeprecationWarning)
319
321
 
320
322
  model_fields = model_fields or {}
321
- model: type[pydantic.BaseModel] = pydantic.create_model( # type: ignore # this uses the v1 signature, not v2
323
+ model: type[pydantic_v1.BaseModel] = pydantic_v1.create_model(
322
324
  name_,
323
- __config__=model_cfg, # type: ignore # this uses the v1 signature, not v2
325
+ __config__=model_cfg,
324
326
  **model_fields,
325
327
  )
326
- return model.schema(by_alias=True) # type: ignore # this uses the v1 signature, not v2
328
+ return model.schema(by_alias=True)
327
329
 
328
330
 
329
331
  def parameter_schema(fn: Callable[..., Any]) -> ParameterSchema:
@@ -80,7 +80,11 @@ async def collect_task_run_inputs(expr: Any, max_depth: int = -1) -> set[TaskRun
80
80
  inputs.add(TaskRunResult(id=obj.task_run_id))
81
81
  elif isinstance(obj, State):
82
82
  if obj.state_details.task_run_id:
83
- inputs.add(TaskRunResult(id=obj.state_details.task_run_id))
83
+ inputs.add(
84
+ TaskRunResult(
85
+ id=obj.state_details.task_run_id,
86
+ )
87
+ )
84
88
  # Expressions inside quotes should not be traversed
85
89
  elif isinstance(obj, quote):
86
90
  raise StopVisiting
@@ -118,10 +122,18 @@ def collect_task_run_inputs_sync(
118
122
 
119
123
  def add_futures_and_states_to_inputs(obj: Any) -> None:
120
124
  if isinstance(obj, future_cls) and hasattr(obj, "task_run_id"):
121
- inputs.add(TaskRunResult(id=obj.task_run_id))
125
+ inputs.add(
126
+ TaskRunResult(
127
+ id=obj.task_run_id,
128
+ )
129
+ )
122
130
  elif isinstance(obj, State):
123
131
  if obj.state_details.task_run_id:
124
- inputs.add(TaskRunResult(id=obj.state_details.task_run_id))
132
+ inputs.add(
133
+ TaskRunResult(
134
+ id=obj.state_details.task_run_id,
135
+ )
136
+ )
125
137
  # Expressions inside quotes should not be traversed
126
138
  elif isinstance(obj, quote):
127
139
  raise StopVisiting