prefect-client 3.0.0rc1__py3-none-any.whl → 3.0.0rc2__py3-none-any.whl

This diff shows the contents of publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
prefect/task_runs.py ADDED
@@ -0,0 +1,203 @@
+import asyncio
+import atexit
+import threading
+import uuid
+from typing import Dict, Optional
+
+import anyio
+from cachetools import TTLCache
+from typing_extensions import Self
+
+from prefect._internal.concurrency.api import create_call, from_async, from_sync
+from prefect._internal.concurrency.threads import get_global_loop
+from prefect.client.schemas.objects import TERMINAL_STATES
+from prefect.events.clients import get_events_subscriber
+from prefect.events.filters import EventFilter, EventNameFilter
+from prefect.logging.loggers import get_logger
+
+
+class TaskRunWaiter:
+    """
+    A service used for waiting for a task run to finish.
+
+    This service listens for task run events and provides a way to wait for a specific
+    task run to finish. This is useful for waiting for a task run to finish before
+    continuing execution.
+
+    The service is a singleton and must be started before use. The service will
+    automatically start when the first instance is created. A single websocket
+    connection is used to listen for task run events.
+
+    The service can be used to wait for a task run to finish by calling
+    `TaskRunWaiter.wait_for_task_run` with the task run ID to wait for. The method
+    will return when the task run has finished or the timeout has elapsed.
+
+    The service will automatically stop when the Python process exits or when the
+    global loop thread is stopped.
+
+    Example:
+    ```python
+    import asyncio
+    from uuid import uuid4
+
+    from prefect import task
+    from prefect.task_engine import run_task_async
+    from prefect.task_runs import TaskRunWaiter
+
+
+    @task
+    async def test_task():
+        await asyncio.sleep(5)
+        print("Done!")
+
+
+    async def main():
+        task_run_id = uuid4()
+        asyncio.create_task(run_task_async(task=test_task, task_run_id=task_run_id))
+
+        await TaskRunWaiter.wait_for_task_run(task_run_id)
+        print("Task run finished")
+
+
+    if __name__ == "__main__":
+        asyncio.run(main())
+    ```
+    """
+
+    _instance: Optional[Self] = None
+    _instance_lock = threading.Lock()
+
+    def __init__(self):
+        self.logger = get_logger("TaskRunWaiter")
+        self._consumer_task: Optional[asyncio.Task] = None
+        self._observed_completed_task_runs: TTLCache[uuid.UUID, bool] = TTLCache(
+            maxsize=100, ttl=60
+        )
+        self._completion_events: Dict[uuid.UUID, asyncio.Event] = {}
+        self._loop: Optional[asyncio.AbstractEventLoop] = None
+        self._observed_completed_task_runs_lock = threading.Lock()
+        self._completion_events_lock = threading.Lock()
+        self._started = False
+
+    def start(self):
+        """
+        Start the TaskRunWaiter service.
+        """
+        if self._started:
+            return
+        self.logger.info("Starting TaskRunWaiter")
+        loop_thread = get_global_loop()
+
+        if not asyncio.get_running_loop() == loop_thread._loop:
+            raise RuntimeError("TaskRunWaiter must run on the global loop thread.")
+
+        self._loop = loop_thread._loop
+        self._consumer_task = self._loop.create_task(self._consume_events())
+
+        loop_thread.add_shutdown_call(create_call(self.stop))
+        atexit.register(self.stop)
+        self._started = True
+
+    async def _consume_events(self):
+        async with get_events_subscriber(
+            filter=EventFilter(
+                event=EventNameFilter(
+                    name=[
+                        f"prefect.task-run.{state.name.title()}"
+                        for state in TERMINAL_STATES
+                    ],
+                )
+            )
+        ) as subscriber:
+            async for event in subscriber:
+                try:
+                    self.logger.info(
+                        f"Received event: {event.resource['prefect.resource.id']}"
+                    )
+                    task_run_id = uuid.UUID(
+                        event.resource["prefect.resource.id"].replace(
+                            "prefect.task-run.", ""
+                        )
+                    )
+                    with self._observed_completed_task_runs_lock:
+                        # Cache the task run ID for a short period of time to avoid
+                        # unnecessary waits
+                        self._observed_completed_task_runs[task_run_id] = True
+                    with self._completion_events_lock:
+                        # Set the event for the task run ID if it is in the cache
+                        # so the waiter can wake up the waiting coroutine
+                        if task_run_id in self._completion_events:
+                            self._completion_events[task_run_id].set()
+                except Exception as exc:
+                    self.logger.error(f"Error processing event: {exc}")
+
+    def stop(self):
+        """
+        Stop the TaskRunWaiter service.
+        """
+        self.logger.debug("Stopping TaskRunWaiter")
+        if self._consumer_task:
+            self._consumer_task.cancel()
+            self._consumer_task = None
+        self.__class__._instance = None
+        self._started = False
+
+    @classmethod
+    async def wait_for_task_run(
+        cls, task_run_id: uuid.UUID, timeout: Optional[float] = None
+    ):
+        """
+        Wait for a task run to finish.
+
+        Note this relies on a websocket connection to receive events from the server
+        and will not work with an ephemeral server.
+
+        Args:
+            task_run_id: The ID of the task run to wait for.
+            timeout: The maximum time to wait for the task run to
+                finish. Defaults to None.
+        """
+        instance = cls.instance()
+        with instance._observed_completed_task_runs_lock:
+            if task_run_id in instance._observed_completed_task_runs:
+                return
+
+        # Need to create event in loop thread to ensure it can be set
+        # from the loop thread
+        finished_event = await from_async.wait_for_call_in_loop_thread(
+            create_call(asyncio.Event)
+        )
+        with instance._completion_events_lock:
+            # Cache the event for the task run ID so the consumer can set it
+            # when the event is received
+            instance._completion_events[task_run_id] = finished_event
+
+        with anyio.move_on_after(delay=timeout):
+            await from_async.wait_for_call_in_loop_thread(
+                create_call(finished_event.wait)
+            )
+
+        with instance._completion_events_lock:
+            # Remove the event from the cache after it has been waited on
+            instance._completion_events.pop(task_run_id, None)
+
+    @classmethod
+    def instance(cls):
+        """
+        Get the singleton instance of TaskRunWaiter.
+        """
+        with cls._instance_lock:
+            if cls._instance is None:
+                cls._instance = cls._new_instance()
+            return cls._instance
+
+    @classmethod
+    def _new_instance(cls):
+        instance = cls()
+
+        if threading.get_ident() == get_global_loop().thread.ident:
+            instance.start()
+        else:
+            from_sync.call_soon_in_loop_thread(create_call(instance.start)).result()
+
+        return instance
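The docstring example above covers in-process task runs. As a rough sketch of how the waiter could pair with the deferred-execution changes later in this diff, assuming a reachable Prefect server and a running task worker (the `my_task` definition and the future's `task_run_id` attribute are illustrative, not part of this module):

```python
import asyncio

from prefect import task
from prefect.task_runs import TaskRunWaiter


@task
def my_task(n: int) -> int:
    return n * 2


async def main():
    # `.delay()` creates a pending task run for a task worker to execute
    # (see the prefect/tasks.py changes below) and returns a future that
    # carries the run's ID.
    future = my_task.delay(21)

    # Wait at most 30 seconds for a terminal task-run event. This relies on
    # a websocket connection to a real Prefect server, not the ephemeral one.
    await TaskRunWaiter.wait_for_task_run(future.task_run_id, timeout=30)


if __name__ == "__main__":
    asyncio.run(main())
```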
prefect/task_server.py → prefect/task_worker.py RENAMED
@@ -11,10 +11,12 @@ from typing import List, Optional
 
 import anyio
 import anyio.abc
+from exceptiongroup import BaseExceptionGroup  # novermin
 from websockets.exceptions import InvalidStatusCode
 
-from prefect import Task, get_client
+from prefect import Task
 from prefect._internal.concurrency.api import create_call, from_sync
+from prefect.client.orchestration import get_client
 from prefect.client.schemas.objects import TaskRun
 from prefect.client.subscriptions import Subscription
 from prefect.exceptions import Abort, PrefectHTTPStatusError
@@ -30,11 +32,11 @@ from prefect.utilities.asyncutils import asyncnullcontext, sync_compatible
 from prefect.utilities.engine import emit_task_run_state_change_event, propose_state
 from prefect.utilities.processutils import _register_signal
 
-logger = get_logger("task_server")
+logger = get_logger("task_worker")
 
 
-class StopTaskServer(Exception):
-    """Raised when the task server is stopped."""
+class StopTaskWorker(Exception):
+    """Raised when the task worker is stopped."""
 
     pass
 
@@ -49,11 +51,11 @@ def should_try_to_read_parameters(task: Task, task_run: TaskRun) -> bool:
     return new_enough_state_details and task_accepts_parameters
 
 
-class TaskServer:
+class TaskWorker:
     """This class is responsible for serving tasks that may be executed in the background
     by a task runner via the traditional engine machinery.
 
-    When `start()` is called, the task server will open a websocket connection to a
+    When `start()` is called, the task worker will open a websocket connection to a
     server-side queue of scheduled task runs. When a scheduled task run is found, the
     scheduled task run is submitted to the engine for execution with a minimal `EngineContext`
     so that the task run can be governed by orchestration rules.
@@ -70,7 +72,7 @@ class TaskServer:
         *tasks: Task,
         limit: Optional[int] = 10,
     ):
-        self.tasks: List[Task] = tasks
+        self.tasks: List[Task] = list(tasks)
 
         self.started: bool = False
         self.stopping: bool = False
@@ -80,7 +82,7 @@ class TaskServer:
 
         if not asyncio.get_event_loop().is_running():
             raise RuntimeError(
-                "TaskServer must be initialized within an async context."
+                "TaskWorker must be initialized within an async context."
             )
 
         self._runs_task_group: anyio.abc.TaskGroup = anyio.create_task_group()
@@ -93,7 +95,7 @@ class TaskServer:
 
     def handle_sigterm(self, signum, frame):
         """
-        Shuts down the task server when a SIGTERM is received.
+        Shuts down the task worker when a SIGTERM is received.
         """
         logger.info("SIGTERM received, initiating graceful shutdown...")
         from_sync.call_in_loop_thread(create_call(self.stop))
@@ -103,12 +105,12 @@ class TaskServer:
     @sync_compatible
     async def start(self) -> None:
         """
-        Starts a task server, which runs the tasks provided in the constructor.
+        Starts a task worker, which runs the tasks provided in the constructor.
         """
         _register_signal(signal.SIGTERM, self.handle_sigterm)
 
         async with asyncnullcontext() if self.started else self:
-            logger.info("Starting task server...")
+            logger.info("Starting task worker...")
             try:
                 await self._subscribe_to_task_scheduling()
             except InvalidStatusCode as exc:
@@ -124,17 +126,17 @@ class TaskServer:
 
     @sync_compatible
     async def stop(self):
-        """Stops the task server's polling cycle."""
+        """Stops the task worker's polling cycle."""
         if not self.started:
             raise RuntimeError(
-                "Task server has not yet started. Please start the task server by"
+                "Task worker has not yet started. Please start the task worker by"
                 " calling .start()"
             )
 
         self.started = False
         self.stopping = True
 
-        raise StopTaskServer
+        raise StopTaskWorker
 
     async def _subscribe_to_task_scheduling(self):
         logger.info(
@@ -159,11 +161,11 @@ class TaskServer:
         task = next((t for t in self.tasks if t.task_key == task_run.task_key), None)
 
         if not task:
-            if PREFECT_TASK_SCHEDULING_DELETE_FAILED_SUBMISSIONS.value():
+            if PREFECT_TASK_SCHEDULING_DELETE_FAILED_SUBMISSIONS:
                 logger.warning(
-                    f"Task {task_run.name!r} not found in task server registry."
+                    f"Task {task_run.name!r} not found in task worker registry."
                 )
-                await self._client._client.delete(f"/task_runs/{task_run.id}")
+                await self._client._client.delete(f"/task_runs/{task_run.id}")  # type: ignore
 
             return
 
@@ -260,14 +262,14 @@ class TaskServer:
             self._limiter.release_on_behalf_of(task_run.id)
 
     async def execute_task_run(self, task_run: TaskRun):
-        """Execute a task run in the task server."""
+        """Execute a task run in the task worker."""
         async with self if not self.started else asyncnullcontext():
            if self._limiter:
                 await self._limiter.acquire_on_behalf_of(task_run.id)
             await self._submit_scheduled_task_run(task_run)
 
     async def __aenter__(self):
-        logger.debug("Starting task server...")
+        logger.debug("Starting task worker...")
 
         if self._client._closed:
             self._client = get_client()
@@ -280,7 +282,7 @@ class TaskServer:
         return self
 
     async def __aexit__(self, *exc_info):
-        logger.debug("Stopping task server...")
+        logger.debug("Stopping task worker...")
         self.started = False
         await self._exit_stack.__aexit__(*exc_info)
 
@@ -300,7 +302,7 @@ async def serve(*tasks: Task, limit: Optional[int] = 10):
     Example:
         ```python
         from prefect import task
-        from prefect.task_server import serve
+        from prefect.task_worker import serve
 
         @task(log_prints=True)
         def say(message: str):
@@ -315,13 +317,21 @@ async def serve(*tasks: Task, limit: Optional[int] = 10):
         serve(say, yell)
         ```
     """
-    task_server = TaskServer(*tasks, limit=limit)
+    task_worker = TaskWorker(*tasks, limit=limit)
 
     try:
-        await task_server.start()
+        await task_worker.start()
+
+    except BaseExceptionGroup as exc:  # novermin
+        exceptions = exc.exceptions
+        n_exceptions = len(exceptions)
+        logger.error(
+            f"Task worker stopped with {n_exceptions} exception{'s' if n_exceptions != 1 else ''}:"
+            f"\n" + "\n".join(str(e) for e in exceptions)
+        )
 
-    except StopTaskServer:
-        logger.info("Task server stopped.")
+    except StopTaskWorker:
+        logger.info("Task worker stopped.")
 
-    except asyncio.CancelledError:
-        logger.info("Task server interrupted, stopping...")
+    except (asyncio.CancelledError, KeyboardInterrupt):
+        logger.info("Task worker interrupted, stopping...")
prefect/tasks.py CHANGED
@@ -43,6 +43,7 @@ from prefect.context import (
 )
 from prefect.futures import PrefectDistributedFuture, PrefectFuture
 from prefect.logging.loggers import get_logger
+from prefect.records.cache_policies import DEFAULT, CachePolicy
 from prefect.results import ResultFactory, ResultSerializer, ResultStorage
 from prefect.settings import (
     PREFECT_TASK_DEFAULT_RETRIES,
@@ -62,7 +63,6 @@ from prefect.utilities.importtools import to_qualified_name
 if TYPE_CHECKING:
     from prefect.client.orchestration import PrefectClient
     from prefect.context import TaskRunContext
-    from prefect.task_runners import BaseTaskRunner
     from prefect.transactions import Transaction
 
 T = TypeVar("T")  # Generic type var for capturing the inner return type of async funcs
@@ -145,6 +145,7 @@ class Task(Generic[P, R]):
             tags are combined with any tags defined by a `prefect.tags` context at
             task runtime.
         version: An optional string specifying the version of this task definition
+        cache_policy: A cache policy that determines the level of caching for this task
         cache_key_fn: An optional callable that, given the task run context and call
             parameters, generates a string key; if the key matches a previous completed
             state, that state result will be restored instead of running the task again.
@@ -204,6 +205,7 @@ class Task(Generic[P, R]):
         description: Optional[str] = None,
         tags: Optional[Iterable[str]] = None,
         version: Optional[str] = None,
+        cache_policy: Optional[CachePolicy] = NotSet,
         cache_key_fn: Optional[
             Callable[["TaskRunContext", Dict[str, Any]], Optional[str]]
         ] = None,
@@ -303,10 +305,23 @@ class Task(Generic[P, R]):
 
         self.task_key = f"{self.fn.__qualname__}-{task_origin_hash}"
 
+        # TODO: warn of precedence of cache policies and cache key fn if both provided?
+        if cache_key_fn:
+            cache_policy = CachePolicy.from_cache_key_fn(cache_key_fn)
+
+        # TODO: manage expiration and cache refresh
         self.cache_key_fn = cache_key_fn
         self.cache_expiration = cache_expiration
         self.refresh_cache = refresh_cache
 
+        if cache_policy is NotSet and result_storage_key is None:
+            self.cache_policy = DEFAULT
+        elif result_storage_key:
+            # TODO: handle this situation with double storage
+            self.cache_policy = None
+        else:
+            self.cache_policy = cache_policy
+
         # TaskRunPolicy settings
         # TODO: We can instantiate a `TaskRunPolicy` and add Pydantic bound checks to
         # validate that the user passes positive numbers here
@@ -358,6 +373,7 @@ class Task(Generic[P, R]):
         name: str = None,
         description: str = None,
         tags: Iterable[str] = None,
+        cache_policy: CachePolicy = NotSet,
         cache_key_fn: Callable[
             ["TaskRunContext", Dict[str, Any]], Optional[str]
         ] = None,
@@ -469,6 +485,9 @@ class Task(Generic[P, R]):
             name=name or self.name,
             description=description or self.description,
             tags=tags or copy(self.tags),
+            cache_policy=cache_policy
+            if cache_policy is not NotSet
+            else self.cache_policy,
             cache_key_fn=cache_key_fn or self.cache_key_fn,
             cache_expiration=cache_expiration or self.cache_expiration,
             task_run_name=task_run_name,
@@ -582,7 +601,7 @@ class Task(Generic[P, R]):
         else:
             state = Pending()
 
-        # store parameters for background tasks so that task servers
+        # store parameters for background tasks so that task worker
         # can retrieve them at runtime
         if deferred and (parameters or wait_for):
             parameters_id = uuid4()
@@ -755,8 +774,6 @@ class Task(Generic[P, R]):
         """
         Submit a run of the task to the engine.
 
-        If writing an async task, this call must be awaited.
-
         Will create a new task run in the backing API and submit the task to the flow's
         task runner. This call only blocks execution while the task is being submitted,
         once it is submitted, the flow function will continue executing.
@@ -849,7 +866,11 @@ class Task(Generic[P, R]):
         flow_run_context = FlowRunContext.get()
 
         if not flow_run_context:
-            raise ValueError("Task.submit() must be called within a flow")
+            raise RuntimeError(
+                "Unable to determine task runner to use for submission. If you are"
+                " submitting a task outside of a flow, please use `.delay`"
+                " to submit the task run for deferred execution."
+            )
 
         task_viz_tracker = get_task_viz_tracker()
         if task_viz_tracker:
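The new error message draws a line between the two submission paths. A hedged sketch of that split, with illustrative task and flow names; `.delay()` requires a Prefect API and a task worker serving the task:

```python
from prefect import flow, task


@task
def add(x: int, y: int) -> int:
    return x + y


@flow
def my_flow() -> int:
    # Inside a flow run: submit to the flow's task runner.
    future = add.submit(1, 2)
    return future.result()


# Outside a flow run, .submit() now raises the RuntimeError above; .delay()
# instead creates a pending task run for a task worker to execute.
deferred_future = add.delay(1, 2)
```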
@@ -897,6 +918,7 @@ class Task(Generic[P, R]):
         *args: Any,
         return_state: bool = False,
         wait_for: Optional[Iterable[PrefectFuture]] = None,
+        deferred: bool = False,
         **kwargs: Any,
     ):
         """
@@ -1010,6 +1032,7 @@ class Task(Generic[P, R]):
             [[11, 21], [12, 22], [13, 23]]
         """
 
+        from prefect.task_runners import TaskRunner
         from prefect.utilities.visualization import (
             VisualizationUnsupportedError,
             get_task_viz_tracker,
@@ -1026,22 +1049,22 @@ class Task(Generic[P, R]):
                 "`task.map()` is not currently supported by `flow.visualize()`"
             )
 
-        if not flow_run_context:
-            # TODO: Should we split out background task mapping into a separate method
-            # like we do for the `submit`/`apply_async` split?
+        if deferred:
             parameters_list = expand_mapping_parameters(self.fn, parameters)
-            # TODO: Make this non-blocking once we can return a list of futures
-            # instead of a list of task runs
-            return [
-                run_coro_as_sync(self.create_run(parameters=parameters, deferred=True))
+            futures = [
+                self.apply_async(kwargs=parameters, wait_for=wait_for)
                 for parameters in parameters_list
             ]
-
-        from prefect.task_runners import TaskRunner
-
-        task_runner = flow_run_context.task_runner
-        assert isinstance(task_runner, TaskRunner)
-        futures = task_runner.map(self, parameters, wait_for)
+        elif task_runner := getattr(flow_run_context, "task_runner", None):
+            assert isinstance(task_runner, TaskRunner)
+            futures = task_runner.map(self, parameters, wait_for)
+        else:
+            raise RuntimeError(
+                "Unable to determine task runner to use for mapped task runs. If"
+                " you are mapping a task outside of a flow, please provide"
+                " `deferred=True` to submit the mapped task runs for deferred"
+                " execution."
+            )
         if return_state:
             states = []
             for future in futures:
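A brief sketch of the new `deferred=True` branch above, with an illustrative task; calling `.map()` outside a flow run without `deferred=True` now raises the `RuntimeError` shown:

```python
from prefect import task


@task
def double(n: int) -> int:
    return n * 2


# With deferred=True, each mapped call goes through apply_async(), yielding
# one PrefectDistributedFuture per input; a task worker serving `double`
# must be running for the runs to execute.
futures = double.map([1, 2, 3], deferred=True)
```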
@@ -1059,7 +1082,7 @@ class Task(Generic[P, R]):
         dependencies: Optional[Dict[str, Set[TaskRunInput]]] = None,
     ) -> PrefectDistributedFuture:
         """
-        Create a pending task run for a task server to execute.
+        Create a pending task run for a task worker to execute.
 
         Args:
             args: Arguments to run the task with
@@ -1181,7 +1204,7 @@ class Task(Generic[P, R]):
         """
         return self.apply_async(args=args, kwargs=kwargs)
 
-    def serve(self, task_runner: Optional["BaseTaskRunner"] = None) -> "Task":
+    def serve(self) -> "Task":
         """Serve the task using the provided task runner. This method is used to
         establish a websocket connection with the Prefect server and listen for
         submitted task runs to execute.
@@ -1198,9 +1221,9 @@ class Task(Generic[P, R]):
 
         >>> my_task.serve()
         """
-        from prefect.task_server import serve
+        from prefect.task_worker import serve
 
-        serve(self, task_runner=task_runner)
+        serve(self)
 
 
 @overload
@@ -1215,6 +1238,7 @@ def task(
     description: str = None,
     tags: Iterable[str] = None,
    version: str = None,
+    cache_policy: CachePolicy = NotSet,
     cache_key_fn: Callable[["TaskRunContext", Dict[str, Any]], Optional[str]] = None,
     cache_expiration: datetime.timedelta = None,
     task_run_name: Optional[Union[Callable[[], str], str]] = None,
@@ -1249,6 +1273,7 @@ def task(
     description: str = None,
     tags: Iterable[str] = None,
     version: str = None,
+    cache_policy: CachePolicy = NotSet,
     cache_key_fn: Callable[["TaskRunContext", Dict[str, Any]], Optional[str]] = None,
     cache_expiration: datetime.timedelta = None,
     task_run_name: Optional[Union[Callable[[], str], str]] = None,
@@ -1391,6 +1416,7 @@ def task(
                 description=description,
                 tags=tags,
                 version=version,
+                cache_policy=cache_policy,
                 cache_key_fn=cache_key_fn,
                 cache_expiration=cache_expiration,
                 task_run_name=task_run_name,
@@ -1420,6 +1446,7 @@ def task(
             description=description,
             tags=tags,
             version=version,
+            cache_policy=cache_policy,
             cache_key_fn=cache_key_fn,
             cache_expiration=cache_expiration,
             task_run_name=task_run_name,
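A hedged sketch of the `cache_policy` parameter threaded through the decorator above; only `DEFAULT` is shown because it is the policy this diff imports, and the `cache_key_fn` example mirrors the `CachePolicy.from_cache_key_fn` conversion in `Task.__init__` (task names and bodies are illustrative):

```python
from prefect import task
from prefect.records.cache_policies import DEFAULT


# Explicit opt-in to the default policy (also applied when neither
# cache_policy nor result_storage_key is provided).
@task(cache_policy=DEFAULT)
def expensive(x: int) -> int:
    return x**2


# A cache_key_fn still works: per Task.__init__ above, it is converted into
# a policy via CachePolicy.from_cache_key_fn and takes precedence.
@task(cache_key_fn=lambda ctx, params: str(params["x"]))
def keyed(x: int) -> int:
    return x + 1
```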
prefect/transactions.py CHANGED
@@ -52,6 +52,7 @@ class Transaction(ContextModel):
     on_rollback_hooks: List[Callable[["Transaction"], None]] = Field(
         default_factory=list
     )
+    overwrite: bool = False
     _staged_value: Any = None
     __var__ = ContextVar("transaction")
 
@@ -122,7 +123,7 @@ class Transaction(ContextModel):
     def begin(self):
         # currently we only support READ_COMMITTED isolation
         # i.e., no locking behavior
-        if self.store and self.store.exists(key=self.key):
+        if not self.overwrite and self.store and self.store.exists(key=self.key):
             self.state = TransactionState.COMMITTED
 
     def read(self) -> dict:
@@ -215,6 +216,9 @@ def transaction(
     key: Optional[str] = None,
     store: Optional[RecordStore] = None,
     commit_mode: CommitMode = CommitMode.LAZY,
+    overwrite: bool = False,
 ) -> Generator[Transaction, None, None]:
-    with Transaction(key=key, store=store, commit_mode=commit_mode) as txn:
+    with Transaction(
+        key=key, store=store, commit_mode=commit_mode, overwrite=overwrite
+    ) as txn:
         yield txn
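Finally, a hedged sketch of the new `overwrite` flag on `transaction()`; the key is illustrative, and the behavior note assumes a `RecordStore` is configured as described in `Transaction.begin()` above:

```python
from prefect.transactions import transaction

# With a RecordStore configured, overwrite=True keeps Transaction.begin()
# from short-circuiting to COMMITTED when a record already exists for the
# key, so the transaction body runs again. Without a store (as here) the
# flag has no effect; `store=` would take a RecordStore instance.
with transaction(key="daily-report", overwrite=True) as txn:
    print(txn.key, txn.overwrite)
```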