prefect-client 3.0.0rc1__py3-none-any.whl → 3.0.0rc3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. prefect/_internal/compatibility/migration.py +124 -0
  2. prefect/_internal/concurrency/__init__.py +2 -2
  3. prefect/_internal/concurrency/primitives.py +1 -0
  4. prefect/_internal/pydantic/annotations/pendulum.py +2 -2
  5. prefect/_internal/pytz.py +1 -1
  6. prefect/blocks/core.py +1 -1
  7. prefect/blocks/redis.py +168 -0
  8. prefect/client/orchestration.py +113 -23
  9. prefect/client/schemas/actions.py +1 -1
  10. prefect/client/schemas/filters.py +6 -0
  11. prefect/client/schemas/objects.py +22 -11
  12. prefect/client/subscriptions.py +3 -2
  13. prefect/concurrency/asyncio.py +1 -1
  14. prefect/concurrency/services.py +1 -1
  15. prefect/context.py +1 -27
  16. prefect/deployments/__init__.py +3 -0
  17. prefect/deployments/base.py +11 -3
  18. prefect/deployments/deployments.py +3 -0
  19. prefect/deployments/steps/pull.py +1 -0
  20. prefect/deployments/steps/utility.py +2 -1
  21. prefect/engine.py +3 -0
  22. prefect/events/cli/automations.py +1 -1
  23. prefect/events/clients.py +7 -1
  24. prefect/events/schemas/events.py +2 -0
  25. prefect/exceptions.py +9 -0
  26. prefect/filesystems.py +22 -11
  27. prefect/flow_engine.py +118 -156
  28. prefect/flow_runs.py +2 -2
  29. prefect/flows.py +91 -35
  30. prefect/futures.py +44 -43
  31. prefect/infrastructure/provisioners/container_instance.py +1 -0
  32. prefect/infrastructure/provisioners/ecs.py +2 -2
  33. prefect/input/__init__.py +4 -0
  34. prefect/input/run_input.py +4 -2
  35. prefect/logging/formatters.py +2 -2
  36. prefect/logging/handlers.py +2 -2
  37. prefect/logging/loggers.py +1 -1
  38. prefect/plugins.py +1 -0
  39. prefect/records/cache_policies.py +179 -0
  40. prefect/records/result_store.py +10 -3
  41. prefect/results.py +27 -55
  42. prefect/runner/runner.py +1 -1
  43. prefect/runner/server.py +1 -1
  44. prefect/runtime/__init__.py +1 -0
  45. prefect/runtime/deployment.py +1 -0
  46. prefect/runtime/flow_run.py +1 -0
  47. prefect/runtime/task_run.py +1 -0
  48. prefect/settings.py +21 -5
  49. prefect/states.py +17 -4
  50. prefect/task_engine.py +337 -209
  51. prefect/task_runners.py +15 -5
  52. prefect/task_runs.py +203 -0
  53. prefect/{task_server.py → task_worker.py} +66 -36
  54. prefect/tasks.py +180 -77
  55. prefect/transactions.py +92 -16
  56. prefect/types/__init__.py +1 -1
  57. prefect/utilities/asyncutils.py +3 -3
  58. prefect/utilities/callables.py +90 -7
  59. prefect/utilities/dockerutils.py +5 -3
  60. prefect/utilities/engine.py +11 -0
  61. prefect/utilities/filesystem.py +4 -5
  62. prefect/utilities/importtools.py +34 -5
  63. prefect/utilities/services.py +2 -2
  64. prefect/utilities/urls.py +195 -0
  65. prefect/utilities/visualization.py +1 -0
  66. prefect/variables.py +19 -10
  67. prefect/workers/base.py +46 -1
  68. {prefect_client-3.0.0rc1.dist-info → prefect_client-3.0.0rc3.dist-info}/METADATA +3 -2
  69. {prefect_client-3.0.0rc1.dist-info → prefect_client-3.0.0rc3.dist-info}/RECORD +72 -66
  70. {prefect_client-3.0.0rc1.dist-info → prefect_client-3.0.0rc3.dist-info}/LICENSE +0 -0
  71. {prefect_client-3.0.0rc1.dist-info → prefect_client-3.0.0rc3.dist-info}/WHEEL +0 -0
  72. {prefect_client-3.0.0rc1.dist-info → prefect_client-3.0.0rc3.dist-info}/top_level.txt +0 -0
prefect/tasks.py CHANGED
@@ -22,6 +22,7 @@ from typing import (
22
22
  Optional,
23
23
  Set,
24
24
  Tuple,
25
+ Type,
25
26
  TypeVar,
26
27
  Union,
27
28
  cast,
@@ -43,6 +44,7 @@ from prefect.context import (
43
44
  )
44
45
  from prefect.futures import PrefectDistributedFuture, PrefectFuture
45
46
  from prefect.logging.loggers import get_logger
47
+ from prefect.records.cache_policies import DEFAULT, CachePolicy
46
48
  from prefect.results import ResultFactory, ResultSerializer, ResultStorage
47
49
  from prefect.settings import (
48
50
  PREFECT_TASK_DEFAULT_RETRIES,
@@ -62,7 +64,6 @@ from prefect.utilities.importtools import to_qualified_name
62
64
  if TYPE_CHECKING:
63
65
  from prefect.client.orchestration import PrefectClient
64
66
  from prefect.context import TaskRunContext
65
- from prefect.task_runners import BaseTaskRunner
66
67
  from prefect.transactions import Transaction
67
68
 
68
69
  T = TypeVar("T") # Generic type var for capturing the inner return type of async funcs
@@ -122,6 +123,57 @@ def exponential_backoff(backoff_factor: float) -> Callable[[int], List[float]]:
122
123
  return retry_backoff_callable
123
124
 
124
125
 
126
+ def _infer_parent_task_runs(
127
+ flow_run_context: Optional[FlowRunContext],
128
+ task_run_context: Optional[TaskRunContext],
129
+ parameters: Dict[str, Any],
130
+ ):
131
+ """
132
+ Attempt to infer the parent task runs for this task run based on the
133
+ provided flow run and task run contexts, as well as any parameters. It is
134
+ assumed that the task run is running within those contexts.
135
+ If any parameter comes from a running task run, that task run is considered
136
+ a parent. This is expected to happen when task inputs are yielded from
137
+ generator tasks.
138
+ """
139
+ parents = []
140
+
141
+ # check if this task has a parent task run based on running in another
142
+ # task run's existing context. A task run is only considered a parent if
143
+ # it is in the same flow run (because otherwise presumably the child is
144
+ # in a subflow, so the subflow serves as the parent) or if there is no
145
+ # flow run
146
+ if task_run_context:
147
+ # there is no flow run
148
+ if not flow_run_context:
149
+ parents.append(TaskRunResult(id=task_run_context.task_run.id))
150
+ # there is a flow run and the task run is in the same flow run
151
+ elif flow_run_context and task_run_context.task_run.flow_run_id == getattr(
152
+ flow_run_context.flow_run, "id", None
153
+ ):
154
+ parents.append(TaskRunResult(id=task_run_context.task_run.id))
155
+
156
+ # parent dependency tracking: for every provided parameter value, try to
157
+ # load the corresponding task run state. If the task run state is still
158
+ # running, we consider it a parent task run. Note this is only done if
159
+ # there is an active flow run context because dependencies are only
160
+ # tracked within the same flow run.
161
+ if flow_run_context:
162
+ for v in parameters.values():
163
+ if isinstance(v, State):
164
+ upstream_state = v
165
+ elif isinstance(v, PrefectFuture):
166
+ upstream_state = v.state
167
+ else:
168
+ upstream_state = flow_run_context.task_run_results.get(id(v))
169
+ if upstream_state and upstream_state.is_running():
170
+ parents.append(
171
+ TaskRunResult(id=upstream_state.state_details.task_run_id)
172
+ )
173
+
174
+ return parents
175
+
176
+
125
177
  @PrefectObjectRegistry.register_instances
126
178
  class Task(Generic[P, R]):
127
179
  """
@@ -145,6 +197,7 @@ class Task(Generic[P, R]):
145
197
  tags are combined with any tags defined by a `prefect.tags` context at
146
198
  task runtime.
147
199
  version: An optional string specifying the version of this task definition
200
+ cache_policy: A cache policy that determines the level of caching for this task
148
201
  cache_key_fn: An optional callable that, given the task run context and call
149
202
  parameters, generates a string key; if the key matches a previous completed
150
203
  state, that state result will be restored instead of running the task again.
@@ -204,6 +257,7 @@ class Task(Generic[P, R]):
204
257
  description: Optional[str] = None,
205
258
  tags: Optional[Iterable[str]] = None,
206
259
  version: Optional[str] = None,
260
+ cache_policy: Optional[CachePolicy] = NotSet,
207
261
  cache_key_fn: Optional[
208
262
  Callable[["TaskRunContext", Dict[str, Any]], Optional[str]]
209
263
  ] = None,
@@ -266,7 +320,18 @@ class Task(Generic[P, R]):
266
320
  self.description = description or inspect.getdoc(fn)
267
321
  update_wrapper(self, fn)
268
322
  self.fn = fn
269
- self.isasync = inspect.iscoroutinefunction(self.fn)
323
+
324
+ # the task is considered async if its function is async or an async
325
+ # generator
326
+ self.isasync = inspect.iscoroutinefunction(
327
+ self.fn
328
+ ) or inspect.isasyncgenfunction(self.fn)
329
+
330
+ # the task is considered a generator if its function is a generator or
331
+ # an async generator
332
+ self.isgenerator = inspect.isgeneratorfunction(
333
+ self.fn
334
+ ) or inspect.isasyncgenfunction(self.fn)
270
335
 
271
336
  if not name:
272
337
  if not hasattr(self.fn, "__name__"):
@@ -303,10 +368,23 @@ class Task(Generic[P, R]):
303
368
 
304
369
  self.task_key = f"{self.fn.__qualname__}-{task_origin_hash}"
305
370
 
371
+ # TODO: warn of precedence of cache policies and cache key fn if both provided?
372
+ if cache_key_fn:
373
+ cache_policy = CachePolicy.from_cache_key_fn(cache_key_fn)
374
+
375
+ # TODO: manage expiration and cache refresh
306
376
  self.cache_key_fn = cache_key_fn
307
377
  self.cache_expiration = cache_expiration
308
378
  self.refresh_cache = refresh_cache
309
379
 
380
+ if cache_policy is NotSet and result_storage_key is None:
381
+ self.cache_policy = DEFAULT
382
+ elif result_storage_key:
383
+ # TODO: handle this situation with double storage
384
+ self.cache_policy = None
385
+ else:
386
+ self.cache_policy = cache_policy
387
+
310
388
  # TaskRunPolicy settings
311
389
  # TODO: We can instantiate a `TaskRunPolicy` and add Pydantic bound checks to
312
390
  # validate that the user passes positive numbers here
@@ -352,33 +430,57 @@ class Task(Generic[P, R]):
352
430
  self.retry_condition_fn = retry_condition_fn
353
431
  self.viz_return_value = viz_return_value
354
432
 
433
+ @property
434
+ def ismethod(self) -> bool:
435
+ return hasattr(self.fn, "__prefect_self__")
436
+
437
+ def __get__(self, instance, owner):
438
+ """
439
+ Implement the descriptor protocol so that the task can be used as an instance method.
440
+ When an instance method is loaded, this method is called with the "self" instance as
441
+ an argument. We return a copy of the task with that instance bound to the task's function.
442
+ """
443
+
444
+ # if no instance is provided, it's being accessed on the class
445
+ if instance is None:
446
+ return self
447
+
448
+ # if the task is being accessed on an instance, bind the instance to the __prefect_self__ attribute
449
+ # of the task's function. This will allow it to be automatically added to the task's parameters
450
+ else:
451
+ bound_task = copy(self)
452
+ bound_task.fn.__prefect_self__ = instance
453
+ return bound_task
454
+
355
455
  def with_options(
356
456
  self,
357
457
  *,
358
- name: str = None,
359
- description: str = None,
360
- tags: Iterable[str] = None,
361
- cache_key_fn: Callable[
362
- ["TaskRunContext", Dict[str, Any]], Optional[str]
458
+ name: Optional[str] = None,
459
+ description: Optional[str] = None,
460
+ tags: Optional[Iterable[str]] = None,
461
+ cache_policy: Union[CachePolicy, Type[NotSet]] = NotSet,
462
+ cache_key_fn: Optional[
463
+ Callable[["TaskRunContext", Dict[str, Any]], Optional[str]]
363
464
  ] = None,
364
465
  task_run_name: Optional[Union[Callable[[], str], str]] = None,
365
- cache_expiration: datetime.timedelta = None,
366
- retries: Optional[int] = NotSet,
466
+ cache_expiration: Optional[datetime.timedelta] = None,
467
+ retries: Union[int, Type[NotSet]] = NotSet,
367
468
  retry_delay_seconds: Union[
368
469
  float,
369
470
  int,
370
471
  List[float],
371
472
  Callable[[int], List[float]],
473
+ Type[NotSet],
372
474
  ] = NotSet,
373
- retry_jitter_factor: Optional[float] = NotSet,
374
- persist_result: Optional[bool] = NotSet,
375
- result_storage: Optional[ResultStorage] = NotSet,
376
- result_serializer: Optional[ResultSerializer] = NotSet,
377
- result_storage_key: Optional[str] = NotSet,
475
+ retry_jitter_factor: Union[float, Type[NotSet]] = NotSet,
476
+ persist_result: Union[bool, Type[NotSet]] = NotSet,
477
+ result_storage: Union[ResultStorage, Type[NotSet]] = NotSet,
478
+ result_serializer: Union[ResultSerializer, Type[NotSet]] = NotSet,
479
+ result_storage_key: Union[str, Type[NotSet]] = NotSet,
378
480
  cache_result_in_memory: Optional[bool] = None,
379
- timeout_seconds: Union[int, float] = None,
380
- log_prints: Optional[bool] = NotSet,
381
- refresh_cache: Optional[bool] = NotSet,
481
+ timeout_seconds: Union[int, float, None] = None,
482
+ log_prints: Union[bool, Type[NotSet]] = NotSet,
483
+ refresh_cache: Union[bool, Type[NotSet]] = NotSet,
382
484
  on_completion: Optional[
383
485
  List[Callable[["Task", TaskRun, State], Union[Awaitable[None], None]]]
384
486
  ] = None,
@@ -469,6 +571,9 @@ class Task(Generic[P, R]):
469
571
  name=name or self.name,
470
572
  description=description or self.description,
471
573
  tags=tags or copy(self.tags),
574
+ cache_policy=cache_policy
575
+ if cache_policy is not NotSet
576
+ else self.cache_policy,
472
577
  cache_key_fn=cache_key_fn or self.cache_key_fn,
473
578
  cache_expiration=cache_expiration or self.cache_expiration,
474
579
  task_run_name=task_run_name,
@@ -569,7 +674,7 @@ class Task(Generic[P, R]):
569
674
  async with client:
570
675
  if not flow_run_context:
571
676
  dynamic_key = f"{self.task_key}-{str(uuid4().hex)}"
572
- task_run_name = f"{self.name}-{dynamic_key[:NUM_CHARS_DYNAMIC_KEY]}"
677
+ task_run_name = self.name
573
678
  else:
574
679
  dynamic_key = _dynamic_key_for_task_run(
575
680
  context=flow_run_context, task=self
@@ -582,7 +687,7 @@ class Task(Generic[P, R]):
582
687
  else:
583
688
  state = Pending()
584
689
 
585
- # store parameters for background tasks so that task servers
690
+ # store parameters for background tasks so that task worker
586
691
  # can retrieve them at runtime
587
692
  if deferred and (parameters or wait_for):
588
693
  parameters_id = uuid4()
@@ -605,27 +710,15 @@ class Task(Generic[P, R]):
605
710
  k: collect_task_run_inputs_sync(v) for k, v in parameters.items()
606
711
  }
607
712
 
608
- # check if this task has a parent task run based on running in another
609
- # task run's existing context. A task run is only considered a parent if
610
- # it is in the same flow run (because otherwise presumably the child is
611
- # in a subflow, so the subflow serves as the parent) or if there is no
612
- # flow run
613
- if parent_task_run_context:
614
- # there is no flow run
615
- if not flow_run_context:
616
- task_inputs["__parents__"] = [
617
- TaskRunResult(id=parent_task_run_context.task_run.id)
618
- ]
619
- # there is a flow run and the task run is in the same flow run
620
- elif (
621
- flow_run_context
622
- and parent_task_run_context.task_run.flow_run_id
623
- == getattr(flow_run_context.flow_run, "id", None)
624
- ):
625
- task_inputs["__parents__"] = [
626
- TaskRunResult(id=parent_task_run_context.task_run.id)
627
- ]
713
+ # collect all parent dependencies
714
+ if task_parents := _infer_parent_task_runs(
715
+ flow_run_context=flow_run_context,
716
+ task_run_context=parent_task_run_context,
717
+ parameters=parameters,
718
+ ):
719
+ task_inputs["__parents__"] = task_parents
628
720
 
721
+ # check wait for dependencies
629
722
  if wait_for:
630
723
  task_inputs["wait_for"] = collect_task_run_inputs_sync(wait_for)
631
724
 
@@ -755,8 +848,6 @@ class Task(Generic[P, R]):
755
848
  """
756
849
  Submit a run of the task to the engine.
757
850
 
758
- If writing an async task, this call must be awaited.
759
-
760
851
  Will create a new task run in the backing API and submit the task to the flow's
761
852
  task runner. This call only blocks execution while the task is being submitted,
762
853
  once it is submitted, the flow function will continue executing.
@@ -849,7 +940,11 @@ class Task(Generic[P, R]):
849
940
  flow_run_context = FlowRunContext.get()
850
941
 
851
942
  if not flow_run_context:
852
- raise ValueError("Task.submit() must be called within a flow")
943
+ raise RuntimeError(
944
+ "Unable to determine task runner to use for submission. If you are"
945
+ " submitting a task outside of a flow, please use `.delay`"
946
+ " to submit the task run for deferred execution."
947
+ )
853
948
 
854
949
  task_viz_tracker = get_task_viz_tracker()
855
950
  if task_viz_tracker:
@@ -897,6 +992,7 @@ class Task(Generic[P, R]):
897
992
  *args: Any,
898
993
  return_state: bool = False,
899
994
  wait_for: Optional[Iterable[PrefectFuture]] = None,
995
+ deferred: bool = False,
900
996
  **kwargs: Any,
901
997
  ):
902
998
  """
@@ -1010,6 +1106,7 @@ class Task(Generic[P, R]):
1010
1106
  [[11, 21], [12, 22], [13, 23]]
1011
1107
  """
1012
1108
 
1109
+ from prefect.task_runners import TaskRunner
1013
1110
  from prefect.utilities.visualization import (
1014
1111
  VisualizationUnsupportedError,
1015
1112
  get_task_viz_tracker,
@@ -1026,22 +1123,22 @@ class Task(Generic[P, R]):
1026
1123
  "`task.map()` is not currently supported by `flow.visualize()`"
1027
1124
  )
1028
1125
 
1029
- if not flow_run_context:
1030
- # TODO: Should we split out background task mapping into a separate method
1031
- # like we do for the `submit`/`apply_async` split?
1126
+ if deferred:
1032
1127
  parameters_list = expand_mapping_parameters(self.fn, parameters)
1033
- # TODO: Make this non-blocking once we can return a list of futures
1034
- # instead of a list of task runs
1035
- return [
1036
- run_coro_as_sync(self.create_run(parameters=parameters, deferred=True))
1128
+ futures = [
1129
+ self.apply_async(kwargs=parameters, wait_for=wait_for)
1037
1130
  for parameters in parameters_list
1038
1131
  ]
1039
-
1040
- from prefect.task_runners import TaskRunner
1041
-
1042
- task_runner = flow_run_context.task_runner
1043
- assert isinstance(task_runner, TaskRunner)
1044
- futures = task_runner.map(self, parameters, wait_for)
1132
+ elif task_runner := getattr(flow_run_context, "task_runner", None):
1133
+ assert isinstance(task_runner, TaskRunner)
1134
+ futures = task_runner.map(self, parameters, wait_for)
1135
+ else:
1136
+ raise RuntimeError(
1137
+ "Unable to determine task runner to use for mapped task runs. If"
1138
+ " you are mapping a task outside of a flow, please provide"
1139
+ " `deferred=True` to submit the mapped task runs for deferred"
1140
+ " execution."
1141
+ )
1045
1142
  if return_state:
1046
1143
  states = []
1047
1144
  for future in futures:
@@ -1059,7 +1156,7 @@ class Task(Generic[P, R]):
1059
1156
  dependencies: Optional[Dict[str, Set[TaskRunInput]]] = None,
1060
1157
  ) -> PrefectDistributedFuture:
1061
1158
  """
1062
- Create a pending task run for a task server to execute.
1159
+ Create a pending task run for a task worker to execute.
1063
1160
 
1064
1161
  Args:
1065
1162
  args: Arguments to run the task with
@@ -1181,7 +1278,7 @@ class Task(Generic[P, R]):
1181
1278
  """
1182
1279
  return self.apply_async(args=args, kwargs=kwargs)
1183
1280
 
1184
- def serve(self, task_runner: Optional["BaseTaskRunner"] = None) -> "Task":
1281
+ def serve(self) -> "Task":
1185
1282
  """Serve the task using the provided task runner. This method is used to
1186
1283
  establish a websocket connection with the Prefect server and listen for
1187
1284
  submitted task runs to execute.
@@ -1198,9 +1295,9 @@ class Task(Generic[P, R]):
1198
1295
 
1199
1296
  >>> my_task.serve()
1200
1297
  """
1201
- from prefect.task_server import serve
1298
+ from prefect.task_worker import serve
1202
1299
 
1203
- serve(self, task_runner=task_runner)
1300
+ serve(self)
1204
1301
 
1205
1302
 
1206
1303
  @overload
@@ -1211,12 +1308,15 @@ def task(__fn: Callable[P, R]) -> Task[P, R]:
1211
1308
  @overload
1212
1309
  def task(
1213
1310
  *,
1214
- name: str = None,
1215
- description: str = None,
1216
- tags: Iterable[str] = None,
1217
- version: str = None,
1218
- cache_key_fn: Callable[["TaskRunContext", Dict[str, Any]], Optional[str]] = None,
1219
- cache_expiration: datetime.timedelta = None,
1311
+ name: Optional[str] = None,
1312
+ description: Optional[str] = None,
1313
+ tags: Optional[Iterable[str]] = None,
1314
+ version: Optional[str] = None,
1315
+ cache_policy: CachePolicy = NotSet,
1316
+ cache_key_fn: Optional[
1317
+ Callable[["TaskRunContext", Dict[str, Any]], Optional[str]]
1318
+ ] = None,
1319
+ cache_expiration: Optional[datetime.timedelta] = None,
1220
1320
  task_run_name: Optional[Union[Callable[[], str], str]] = None,
1221
1321
  retries: int = 0,
1222
1322
  retry_delay_seconds: Union[
@@ -1231,7 +1331,7 @@ def task(
1231
1331
  result_storage_key: Optional[str] = None,
1232
1332
  result_serializer: Optional[ResultSerializer] = None,
1233
1333
  cache_result_in_memory: bool = True,
1234
- timeout_seconds: Union[int, float] = None,
1334
+ timeout_seconds: Union[int, float, None] = None,
1235
1335
  log_prints: Optional[bool] = None,
1236
1336
  refresh_cache: Optional[bool] = None,
1237
1337
  on_completion: Optional[List[Callable[["Task", TaskRun, State], None]]] = None,
@@ -1245,19 +1345,17 @@ def task(
1245
1345
  def task(
1246
1346
  __fn=None,
1247
1347
  *,
1248
- name: str = None,
1249
- description: str = None,
1250
- tags: Iterable[str] = None,
1251
- version: str = None,
1348
+ name: Optional[str] = None,
1349
+ description: Optional[str] = None,
1350
+ tags: Optional[Iterable[str]] = None,
1351
+ version: Optional[str] = None,
1352
+ cache_policy: Union[CachePolicy, Type[NotSet]] = NotSet,
1252
1353
  cache_key_fn: Callable[["TaskRunContext", Dict[str, Any]], Optional[str]] = None,
1253
- cache_expiration: datetime.timedelta = None,
1354
+ cache_expiration: Optional[datetime.timedelta] = None,
1254
1355
  task_run_name: Optional[Union[Callable[[], str], str]] = None,
1255
- retries: int = None,
1356
+ retries: Optional[int] = None,
1256
1357
  retry_delay_seconds: Union[
1257
- float,
1258
- int,
1259
- List[float],
1260
- Callable[[int], List[float]],
1358
+ float, int, List[float], Callable[[int], List[float]], None
1261
1359
  ] = None,
1262
1360
  retry_jitter_factor: Optional[float] = None,
1263
1361
  persist_result: Optional[bool] = None,
@@ -1265,7 +1363,7 @@ def task(
1265
1363
  result_storage_key: Optional[str] = None,
1266
1364
  result_serializer: Optional[ResultSerializer] = None,
1267
1365
  cache_result_in_memory: bool = True,
1268
- timeout_seconds: Union[int, float] = None,
1366
+ timeout_seconds: Union[int, float, None] = None,
1269
1367
  log_prints: Optional[bool] = None,
1270
1368
  refresh_cache: Optional[bool] = None,
1271
1369
  on_completion: Optional[List[Callable[["Task", TaskRun, State], None]]] = None,
@@ -1383,6 +1481,9 @@ def task(
1383
1481
  """
1384
1482
 
1385
1483
  if __fn:
1484
+ if isinstance(__fn, (classmethod, staticmethod)):
1485
+ method_decorator = type(__fn).__name__
1486
+ raise TypeError(f"@{method_decorator} should be applied on top of @task")
1386
1487
  return cast(
1387
1488
  Task[P, R],
1388
1489
  Task(
@@ -1391,6 +1492,7 @@ def task(
1391
1492
  description=description,
1392
1493
  tags=tags,
1393
1494
  version=version,
1495
+ cache_policy=cache_policy,
1394
1496
  cache_key_fn=cache_key_fn,
1395
1497
  cache_expiration=cache_expiration,
1396
1498
  task_run_name=task_run_name,
@@ -1420,6 +1522,7 @@ def task(
1420
1522
  description=description,
1421
1523
  tags=tags,
1422
1524
  version=version,
1525
+ cache_policy=cache_policy,
1423
1526
  cache_key_fn=cache_key_fn,
1424
1527
  cache_expiration=cache_expiration,
1425
1528
  task_run_name=task_run_name,
prefect/transactions.py CHANGED
@@ -7,17 +7,19 @@ from typing import (
7
7
  List,
8
8
  Optional,
9
9
  Type,
10
- TypeVar,
11
10
  )
12
11
 
13
12
  from pydantic import Field
13
+ from typing_extensions import Self
14
14
 
15
- from prefect.context import ContextModel
15
+ from prefect.context import ContextModel, FlowRunContext, TaskRunContext
16
16
  from prefect.records import RecordStore
17
+ from prefect.records.result_store import ResultFactoryStore
18
+ from prefect.results import BaseResult, ResultFactory, get_default_result_storage
19
+ from prefect.settings import PREFECT_DEFAULT_RESULT_STORAGE_BLOCK
20
+ from prefect.utilities.asyncutils import run_coro_as_sync
17
21
  from prefect.utilities.collections import AutoEnum
18
22
 
19
- T = TypeVar("T")
20
-
21
23
 
22
24
  class IsolationLevel(AutoEnum):
23
25
  READ_COMMITTED = AutoEnum.auto()
@@ -52,8 +54,9 @@ class Transaction(ContextModel):
52
54
  on_rollback_hooks: List[Callable[["Transaction"], None]] = Field(
53
55
  default_factory=list
54
56
  )
57
+ overwrite: bool = False
55
58
  _staged_value: Any = None
56
- __var__ = ContextVar("transaction")
59
+ __var__: ContextVar = ContextVar("transaction")
57
60
 
58
61
  def is_committed(self) -> bool:
59
62
  return self.state == TransactionState.COMMITTED
@@ -91,7 +94,8 @@ class Transaction(ContextModel):
91
94
  self._token = self.__var__.set(self)
92
95
  return self
93
96
 
94
- def __exit__(self, exc_type, exc_val, exc_tb):
97
+ def __exit__(self, *exc_info):
98
+ exc_type, exc_val, _ = exc_info
95
99
  if not self._token:
96
100
  raise RuntimeError(
97
101
  "Asymmetric use of context. Context exit called without an enter."
@@ -122,11 +126,19 @@ class Transaction(ContextModel):
122
126
  def begin(self):
123
127
  # currently we only support READ_COMMITTED isolation
124
128
  # i.e., no locking behavior
125
- if self.store and self.store.exists(key=self.key):
129
+ if (
130
+ not self.overwrite
131
+ and self.store
132
+ and self.key
133
+ and self.store.exists(key=self.key)
134
+ ):
126
135
  self.state = TransactionState.COMMITTED
127
136
 
128
- def read(self) -> dict:
129
- return self.store.read(key=self.key)
137
+ def read(self) -> BaseResult:
138
+ if self.store and self.key:
139
+ return self.store.read(key=self.key)
140
+ else:
141
+ return {} # TODO: Determine what this should be
130
142
 
131
143
  def reset(self) -> None:
132
144
  parent = self.get_parent()
@@ -135,8 +147,9 @@ class Transaction(ContextModel):
135
147
  # parent takes responsibility
136
148
  parent.add_child(self)
137
149
 
138
- self.__var__.reset(self._token)
139
- self._token = None
150
+ if self._token:
151
+ self.__var__.reset(self._token)
152
+ self._token = None
140
153
 
141
154
  # do this below reset so that get_transaction() returns the relevant txn
142
155
  if parent and self.state == TransactionState.ROLLED_BACK:
@@ -164,7 +177,7 @@ class Transaction(ContextModel):
164
177
  for hook in self.on_commit_hooks:
165
178
  hook(self)
166
179
 
167
- if self.store:
180
+ if self.store and self.key:
168
181
  self.store.write(key=self.key, value=self._staged_value)
169
182
  self.state = TransactionState.COMMITTED
170
183
  return True
@@ -173,11 +186,17 @@ class Transaction(ContextModel):
173
186
  return False
174
187
 
175
188
  def stage(
176
- self, value: dict, on_rollback_hooks: list, on_commit_hooks: list
189
+ self,
190
+ value: BaseResult,
191
+ on_rollback_hooks: Optional[List] = None,
192
+ on_commit_hooks: Optional[List] = None,
177
193
  ) -> None:
178
194
  """
179
195
  Stage a value to be committed later.
180
196
  """
197
+ on_commit_hooks = on_commit_hooks or []
198
+ on_rollback_hooks = on_rollback_hooks or []
199
+
181
200
  if self.state != TransactionState.COMMITTED:
182
201
  self._staged_value = value
183
202
  self.on_rollback_hooks += on_rollback_hooks
@@ -202,11 +221,11 @@ class Transaction(ContextModel):
202
221
  return False
203
222
 
204
223
  @classmethod
205
- def get_active(cls: Type[T]) -> Optional[T]:
224
+ def get_active(cls: Type[Self]) -> Optional[Self]:
206
225
  return cls.__var__.get(None)
207
226
 
208
227
 
209
- def get_transaction() -> Transaction:
228
+ def get_transaction() -> Optional[Transaction]:
210
229
  return Transaction.get_active()
211
230
 
212
231
 
@@ -215,6 +234,63 @@ def transaction(
215
234
  key: Optional[str] = None,
216
235
  store: Optional[RecordStore] = None,
217
236
  commit_mode: CommitMode = CommitMode.LAZY,
237
+ overwrite: bool = False,
218
238
  ) -> Generator[Transaction, None, None]:
219
- with Transaction(key=key, store=store, commit_mode=commit_mode) as txn:
239
+ """
240
+ A context manager for opening and managing a transaction.
241
+
242
+ Args:
243
+ - key: An identifier to use for the transaction
244
+ - store: The store to use for persisting the transaction result. If not provided,
245
+ a default store will be used based on the current run context.
246
+ - commit_mode: The commit mode controlling when the transaction and
247
+ child transactions are committed
248
+ - overwrite: Whether to overwrite an existing transaction record in the store
249
+
250
+ Yields:
251
+ - Transaction: An object representing the transaction state
252
+ """
253
+ # if there is no key, we won't persist a record
254
+ if key and not store:
255
+ flow_run_context = FlowRunContext.get()
256
+ task_run_context = TaskRunContext.get()
257
+ existing_factory = getattr(task_run_context, "result_factory", None) or getattr(
258
+ flow_run_context, "result_factory", None
259
+ )
260
+
261
+ if existing_factory and existing_factory.storage_block_id:
262
+ new_factory = existing_factory.model_copy(
263
+ update={
264
+ "persist_result": True,
265
+ }
266
+ )
267
+ else:
268
+ default_storage = get_default_result_storage(_sync=True)
269
+ if not default_storage._block_document_id:
270
+ default_name = PREFECT_DEFAULT_RESULT_STORAGE_BLOCK.value().split("/")[
271
+ -1
272
+ ]
273
+ default_storage.save(default_name, overwrite=True, _sync=True)
274
+ if existing_factory:
275
+ new_factory = existing_factory.model_copy(
276
+ update={
277
+ "persist_result": True,
278
+ "storage_block": default_storage,
279
+ "storage_block_id": default_storage._block_document_id,
280
+ }
281
+ )
282
+ else:
283
+ new_factory = run_coro_as_sync(
284
+ ResultFactory.default_factory(
285
+ persist_result=True,
286
+ result_storage=default_storage,
287
+ )
288
+ )
289
+ store = ResultFactoryStore(
290
+ result_factory=new_factory,
291
+ )
292
+
293
+ with Transaction(
294
+ key=key, store=store, commit_mode=commit_mode, overwrite=overwrite
295
+ ) as txn:
220
296
  yield txn
prefect/types/__init__.py CHANGED
@@ -20,7 +20,7 @@ timezone_set = available_timezones()
20
20
  NonNegativeInteger = Annotated[int, Field(ge=0)]
21
21
  PositiveInteger = Annotated[int, Field(gt=0)]
22
22
  NonNegativeFloat = Annotated[float, Field(ge=0.0)]
23
- TimeZone = Annotated[str, Field(default="UTC", pattern="|".join(timezone_set))]
23
+ TimeZone = Annotated[str, Field(default="UTC", pattern="|".join(sorted(timezone_set)))]
24
24
 
25
25
 
26
26
  BANNED_CHARACTERS = ["/", "%", "&", ">", "<"]
prefect/utilities/asyncutils.py CHANGED
@@ -314,7 +314,7 @@ def sync_compatible(async_fn: T, force_sync: bool = False) -> T:
314
314
  """
315
315
 
316
316
  @wraps(async_fn)
317
- def coroutine_wrapper(*args, _sync: bool = None, **kwargs):
317
+ def coroutine_wrapper(*args, _sync: Optional[bool] = None, **kwargs):
318
318
  from prefect.context import MissingContextError, get_run_context
319
319
  from prefect.settings import (
320
320
  PREFECT_EXPERIMENTAL_DISABLE_SYNC_COMPAT,
@@ -376,8 +376,8 @@ def sync_compatible(async_fn: T, force_sync: bool = False) -> T:
376
376
 
377
377
 
378
378
  @asynccontextmanager
379
- async def asyncnullcontext():
380
- yield
379
+ async def asyncnullcontext(value=None):
380
+ yield value
381
381
 
382
382
 
383
383
  def sync(__async_fn: Callable[P, Awaitable[T]], *args: P.args, **kwargs: P.kwargs) -> T: