prefect-client 3.0.0rc1__py3-none-any.whl → 3.0.0rc2__py3-none-any.whl
This diff compares publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
- prefect/blocks/redis.py +168 -0
- prefect/client/orchestration.py +17 -1
- prefect/client/schemas/objects.py +12 -8
- prefect/concurrency/asyncio.py +1 -1
- prefect/concurrency/services.py +1 -1
- prefect/deployments/base.py +7 -1
- prefect/events/schemas/events.py +2 -0
- prefect/flow_engine.py +2 -2
- prefect/flow_runs.py +2 -2
- prefect/flows.py +8 -1
- prefect/futures.py +44 -43
- prefect/input/run_input.py +4 -2
- prefect/records/cache_policies.py +179 -0
- prefect/settings.py +6 -3
- prefect/states.py +6 -4
- prefect/task_engine.py +169 -198
- prefect/task_runners.py +6 -2
- prefect/task_runs.py +203 -0
- prefect/{task_server.py → task_worker.py} +37 -27
- prefect/tasks.py +49 -22
- prefect/transactions.py +6 -2
- prefect/utilities/callables.py +74 -3
- prefect/utilities/importtools.py +5 -5
- prefect/variables.py +15 -10
- prefect/workers/base.py +11 -1
- {prefect_client-3.0.0rc1.dist-info → prefect_client-3.0.0rc2.dist-info}/METADATA +2 -1
- {prefect_client-3.0.0rc1.dist-info → prefect_client-3.0.0rc2.dist-info}/RECORD +30 -27
- {prefect_client-3.0.0rc1.dist-info → prefect_client-3.0.0rc2.dist-info}/LICENSE +0 -0
- {prefect_client-3.0.0rc1.dist-info → prefect_client-3.0.0rc2.dist-info}/WHEEL +0 -0
- {prefect_client-3.0.0rc1.dist-info → prefect_client-3.0.0rc2.dist-info}/top_level.txt +0 -0
prefect/task_runs.py
ADDED
@@ -0,0 +1,203 @@
import asyncio
import atexit
import threading
import uuid
from typing import Dict, Optional

import anyio
from cachetools import TTLCache
from typing_extensions import Self

from prefect._internal.concurrency.api import create_call, from_async, from_sync
from prefect._internal.concurrency.threads import get_global_loop
from prefect.client.schemas.objects import TERMINAL_STATES
from prefect.events.clients import get_events_subscriber
from prefect.events.filters import EventFilter, EventNameFilter
from prefect.logging.loggers import get_logger


class TaskRunWaiter:
    """
    A service used for waiting for a task run to finish.

    This service listens for task run events and provides a way to wait for a specific
    task run to finish. This is useful for waiting for a task run to finish before
    continuing execution.

    The service is a singleton and must be started before use. The service will
    automatically start when the first instance is created. A single websocket
    connection is used to listen for task run events.

    The service can be used to wait for a task run to finish by calling
    `TaskRunWaiter.wait_for_task_run` with the task run ID to wait for. The method
    will return when the task run has finished or the timeout has elapsed.

    The service will automatically stop when the Python process exits or when the
    global loop thread is stopped.

    Example:
    ```python
    import asyncio
    from uuid import uuid4

    from prefect import task
    from prefect.task_engine import run_task_async
    from prefect.task_runs import TaskRunWaiter


    @task
    async def test_task():
        await asyncio.sleep(5)
        print("Done!")


    async def main():
        task_run_id = uuid4()
        asyncio.create_task(run_task_async(task=test_task, task_run_id=task_run_id))

        await TaskRunWaiter.wait_for_task_run(task_run_id)
        print("Task run finished")


    if __name__ == "__main__":
        asyncio.run(main())
    ```
    """

    _instance: Optional[Self] = None
    _instance_lock = threading.Lock()

    def __init__(self):
        self.logger = get_logger("TaskRunWaiter")
        self._consumer_task: Optional[asyncio.Task] = None
        self._observed_completed_task_runs: TTLCache[uuid.UUID, bool] = TTLCache(
            maxsize=100, ttl=60
        )
        self._completion_events: Dict[uuid.UUID, asyncio.Event] = {}
        self._loop: Optional[asyncio.AbstractEventLoop] = None
        self._observed_completed_task_runs_lock = threading.Lock()
        self._completion_events_lock = threading.Lock()
        self._started = False

    def start(self):
        """
        Start the TaskRunWaiter service.
        """
        if self._started:
            return
        self.logger.info("Starting TaskRunWaiter")
        loop_thread = get_global_loop()

        if not asyncio.get_running_loop() == loop_thread._loop:
            raise RuntimeError("TaskRunWaiter must run on the global loop thread.")

        self._loop = loop_thread._loop
        self._consumer_task = self._loop.create_task(self._consume_events())

        loop_thread.add_shutdown_call(create_call(self.stop))
        atexit.register(self.stop)
        self._started = True

    async def _consume_events(self):
        async with get_events_subscriber(
            filter=EventFilter(
                event=EventNameFilter(
                    name=[
                        f"prefect.task-run.{state.name.title()}"
                        for state in TERMINAL_STATES
                    ],
                )
            )
        ) as subscriber:
            async for event in subscriber:
                try:
                    self.logger.info(
                        f"Received event: {event.resource['prefect.resource.id']}"
                    )
                    task_run_id = uuid.UUID(
                        event.resource["prefect.resource.id"].replace(
                            "prefect.task-run.", ""
                        )
                    )
                    with self._observed_completed_task_runs_lock:
                        # Cache the task run ID for a short period of time to avoid
                        # unnecessary waits
                        self._observed_completed_task_runs[task_run_id] = True
                    with self._completion_events_lock:
                        # Set the event for the task run ID if it is in the cache
                        # so the waiter can wake up the waiting coroutine
                        if task_run_id in self._completion_events:
                            self._completion_events[task_run_id].set()
                except Exception as exc:
                    self.logger.error(f"Error processing event: {exc}")

    def stop(self):
        """
        Stop the TaskRunWaiter service.
        """
        self.logger.debug("Stopping TaskRunWaiter")
        if self._consumer_task:
            self._consumer_task.cancel()
            self._consumer_task = None
        self.__class__._instance = None
        self._started = False

    @classmethod
    async def wait_for_task_run(
        cls, task_run_id: uuid.UUID, timeout: Optional[float] = None
    ):
        """
        Wait for a task run to finish.

        Note this relies on a websocket connection to receive events from the server
        and will not work with an ephemeral server.

        Args:
            task_run_id: The ID of the task run to wait for.
            timeout: The maximum time to wait for the task run to
                finish. Defaults to None.
        """
        instance = cls.instance()
        with instance._observed_completed_task_runs_lock:
            if task_run_id in instance._observed_completed_task_runs:
                return

        # Need to create event in loop thread to ensure it can be set
        # from the loop thread
        finished_event = await from_async.wait_for_call_in_loop_thread(
            create_call(asyncio.Event)
        )
        with instance._completion_events_lock:
            # Cache the event for the task run ID so the consumer can set it
            # when the event is received
            instance._completion_events[task_run_id] = finished_event

        with anyio.move_on_after(delay=timeout):
            await from_async.wait_for_call_in_loop_thread(
                create_call(finished_event.wait)
            )

        with instance._completion_events_lock:
            # Remove the event from the cache after it has been waited on
            instance._completion_events.pop(task_run_id, None)

    @classmethod
    def instance(cls):
        """
        Get the singleton instance of TaskRunWaiter.
        """
        with cls._instance_lock:
            if cls._instance is None:
                cls._instance = cls._new_instance()
            return cls._instance

    @classmethod
    def _new_instance(cls):
        instance = cls()

        if threading.get_ident() == get_global_loop().thread.ident:
            instance.start()
        else:
            from_sync.call_soon_in_loop_thread(create_call(instance.start)).result()

        return instance
prefect/{task_server.py → task_worker.py}
RENAMED
@@ -11,10 +11,12 @@ from typing import List, Optional
 
 import anyio
 import anyio.abc
+from exceptiongroup import BaseExceptionGroup  # novermin
 from websockets.exceptions import InvalidStatusCode
 
-from prefect import Task, get_client
+from prefect import Task
 from prefect._internal.concurrency.api import create_call, from_sync
+from prefect.client.orchestration import get_client
 from prefect.client.schemas.objects import TaskRun
 from prefect.client.subscriptions import Subscription
 from prefect.exceptions import Abort, PrefectHTTPStatusError
@@ -30,11 +32,11 @@ from prefect.utilities.asyncutils import asyncnullcontext, sync_compatible
 from prefect.utilities.engine import emit_task_run_state_change_event, propose_state
 from prefect.utilities.processutils import _register_signal
 
-logger = get_logger("task_server")
+logger = get_logger("task_worker")
 
 
-class StopTaskServer(Exception):
-    """Raised when the task server is stopped."""
+class StopTaskWorker(Exception):
+    """Raised when the task worker is stopped."""
 
     pass
 
@@ -49,11 +51,11 @@ def should_try_to_read_parameters(task: Task, task_run: TaskRun) -> bool:
     return new_enough_state_details and task_accepts_parameters
 
 
-class TaskServer:
+class TaskWorker:
     """This class is responsible for serving tasks that may be executed in the background
     by a task runner via the traditional engine machinery.
 
-    When `start()` is called, the task server will open a websocket connection to a
+    When `start()` is called, the task worker will open a websocket connection to a
     server-side queue of scheduled task runs. When a scheduled task run is found, the
     scheduled task run is submitted to the engine for execution with a minimal `EngineContext`
     so that the task run can be governed by orchestration rules.
@@ -70,7 +72,7 @@ class TaskServer:
         *tasks: Task,
         limit: Optional[int] = 10,
     ):
-        self.tasks: List[Task] = tasks
+        self.tasks: List[Task] = list(tasks)
 
         self.started: bool = False
         self.stopping: bool = False
@@ -80,7 +82,7 @@ class TaskServer:
 
         if not asyncio.get_event_loop().is_running():
             raise RuntimeError(
-                "TaskServer must be initialized within an async context."
+                "TaskWorker must be initialized within an async context."
             )
 
         self._runs_task_group: anyio.abc.TaskGroup = anyio.create_task_group()
@@ -93,7 +95,7 @@ class TaskServer:
 
     def handle_sigterm(self, signum, frame):
         """
-        Shuts down the task server when a SIGTERM is received.
+        Shuts down the task worker when a SIGTERM is received.
         """
         logger.info("SIGTERM received, initiating graceful shutdown...")
         from_sync.call_in_loop_thread(create_call(self.stop))
@@ -103,12 +105,12 @@ class TaskServer:
     @sync_compatible
     async def start(self) -> None:
        """
-        Starts a task server, which runs the tasks provided in the constructor.
+        Starts a task worker, which runs the tasks provided in the constructor.
        """
         _register_signal(signal.SIGTERM, self.handle_sigterm)
 
         async with asyncnullcontext() if self.started else self:
-            logger.info("Starting task server...")
+            logger.info("Starting task worker...")
             try:
                 await self._subscribe_to_task_scheduling()
             except InvalidStatusCode as exc:
@@ -124,17 +126,17 @@ class TaskServer:
 
     @sync_compatible
     async def stop(self):
-        """Stops the task server's polling cycle."""
+        """Stops the task worker's polling cycle."""
         if not self.started:
             raise RuntimeError(
-                "Task server has not yet started. Please start the task server by"
+                "Task worker has not yet started. Please start the task worker by"
                 " calling .start()"
             )
 
         self.started = False
         self.stopping = True
 
-        raise StopTaskServer
+        raise StopTaskWorker
 
     async def _subscribe_to_task_scheduling(self):
         logger.info(
@@ -159,11 +161,11 @@ class TaskServer:
         task = next((t for t in self.tasks if t.task_key == task_run.task_key), None)
 
         if not task:
-            if PREFECT_TASK_SCHEDULING_DELETE_FAILED_SUBMISSIONS
+            if PREFECT_TASK_SCHEDULING_DELETE_FAILED_SUBMISSIONS:
                 logger.warning(
-                    f"Task {task_run.name!r} not found in task server registry."
+                    f"Task {task_run.name!r} not found in task worker registry."
                 )
-                await self._client._client.delete(f"/task_runs/{task_run.id}")
+                await self._client._client.delete(f"/task_runs/{task_run.id}")  # type: ignore
 
             return
 
@@ -260,14 +262,14 @@ class TaskServer:
             self._limiter.release_on_behalf_of(task_run.id)
 
     async def execute_task_run(self, task_run: TaskRun):
-        """Execute a task run in the task server."""
+        """Execute a task run in the task worker."""
         async with self if not self.started else asyncnullcontext():
             if self._limiter:
                 await self._limiter.acquire_on_behalf_of(task_run.id)
             await self._submit_scheduled_task_run(task_run)
 
     async def __aenter__(self):
-        logger.debug("Starting task server...")
+        logger.debug("Starting task worker...")
 
         if self._client._closed:
             self._client = get_client()
@@ -280,7 +282,7 @@ class TaskServer:
         return self
 
     async def __aexit__(self, *exc_info):
-        logger.debug("Stopping task server...")
+        logger.debug("Stopping task worker...")
         self.started = False
         await self._exit_stack.__aexit__(*exc_info)
 
@@ -300,7 +302,7 @@ async def serve(*tasks: Task, limit: Optional[int] = 10):
     Example:
         ```python
         from prefect import task
-        from prefect.task_server import serve
+        from prefect.task_worker import serve
 
         @task(log_prints=True)
         def say(message: str):
@@ -315,13 +317,21 @@ async def serve(*tasks: Task, limit: Optional[int] = 10):
         serve(say, yell)
         ```
     """
-    task_server = TaskServer(*tasks, limit=limit)
+    task_worker = TaskWorker(*tasks, limit=limit)
 
     try:
-        await task_server.start()
+        await task_worker.start()
+
+    except BaseExceptionGroup as exc:  # novermin
+        exceptions = exc.exceptions
+        n_exceptions = len(exceptions)
+        logger.error(
+            f"Task worker stopped with {n_exceptions} exception{'s' if n_exceptions != 1 else ''}:"
+            f"\n" + "\n".join(str(e) for e in exceptions)
+        )
 
-    except StopTaskServer:
-        logger.info("Task server stopped.")
+    except StopTaskWorker:
+        logger.info("Task worker stopped.")
 
-    except asyncio.CancelledError:
-        logger.info("Task server interrupted, stopping...")
+    except (asyncio.CancelledError, KeyboardInterrupt):
+        logger.info("Task worker interrupted, stopping...")
prefect/tasks.py
CHANGED
@@ -43,6 +43,7 @@ from prefect.context import (
 )
 from prefect.futures import PrefectDistributedFuture, PrefectFuture
 from prefect.logging.loggers import get_logger
+from prefect.records.cache_policies import DEFAULT, CachePolicy
 from prefect.results import ResultFactory, ResultSerializer, ResultStorage
 from prefect.settings import (
     PREFECT_TASK_DEFAULT_RETRIES,
@@ -62,7 +63,6 @@ from prefect.utilities.importtools import to_qualified_name
 if TYPE_CHECKING:
     from prefect.client.orchestration import PrefectClient
     from prefect.context import TaskRunContext
-    from prefect.task_runners import BaseTaskRunner
     from prefect.transactions import Transaction
 
 T = TypeVar("T")  # Generic type var for capturing the inner return type of async funcs
@@ -145,6 +145,7 @@ class Task(Generic[P, R]):
             tags are combined with any tags defined by a `prefect.tags` context at
             task runtime.
         version: An optional string specifying the version of this task definition
+        cache_policy: A cache policy that determines the level of caching for this task
         cache_key_fn: An optional callable that, given the task run context and call
             parameters, generates a string key; if the key matches a previous completed
             state, that state result will be restored instead of running the task again.
@@ -204,6 +205,7 @@ class Task(Generic[P, R]):
         description: Optional[str] = None,
         tags: Optional[Iterable[str]] = None,
         version: Optional[str] = None,
+        cache_policy: Optional[CachePolicy] = NotSet,
         cache_key_fn: Optional[
             Callable[["TaskRunContext", Dict[str, Any]], Optional[str]]
         ] = None,
@@ -303,10 +305,23 @@ class Task(Generic[P, R]):
 
         self.task_key = f"{self.fn.__qualname__}-{task_origin_hash}"
 
+        # TODO: warn of precedence of cache policies and cache key fn if both provided?
+        if cache_key_fn:
+            cache_policy = CachePolicy.from_cache_key_fn(cache_key_fn)
+
+        # TODO: manage expiration and cache refresh
         self.cache_key_fn = cache_key_fn
         self.cache_expiration = cache_expiration
         self.refresh_cache = refresh_cache
 
+        if cache_policy is NotSet and result_storage_key is None:
+            self.cache_policy = DEFAULT
+        elif result_storage_key:
+            # TODO: handle this situation with double storage
+            self.cache_policy = None
+        else:
+            self.cache_policy = cache_policy
+
         # TaskRunPolicy settings
         # TODO: We can instantiate a `TaskRunPolicy` and add Pydantic bound checks to
         # validate that the user passes positive numbers here
@@ -358,6 +373,7 @@ class Task(Generic[P, R]):
         name: str = None,
         description: str = None,
         tags: Iterable[str] = None,
+        cache_policy: CachePolicy = NotSet,
         cache_key_fn: Callable[
             ["TaskRunContext", Dict[str, Any]], Optional[str]
         ] = None,
@@ -469,6 +485,9 @@ class Task(Generic[P, R]):
             name=name or self.name,
             description=description or self.description,
             tags=tags or copy(self.tags),
+            cache_policy=cache_policy
+            if cache_policy is not NotSet
+            else self.cache_policy,
             cache_key_fn=cache_key_fn or self.cache_key_fn,
             cache_expiration=cache_expiration or self.cache_expiration,
             task_run_name=task_run_name,
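The hunks above thread a new `cache_policy` argument through `Task.__init__` and `with_options`: an explicit `cache_key_fn` is converted via `CachePolicy.from_cache_key_fn` and takes precedence, a `result_storage_key` disables policy-based caching, and everything else falls back to the `DEFAULT` policy. A minimal sketch of the two spellings (only `DEFAULT` and `from_cache_key_fn` appear in this diff; any other policy names would be assumptions):

```python
from prefect import task
from prefect.records.cache_policies import DEFAULT


# Explicit opt-in to the default policy, which is also what an unconfigured
# task now gets automatically.
@task(cache_policy=DEFAULT)
def add(x: int, y: int) -> int:
    return x + y


# A legacy cache_key_fn still works; per the constructor above it is wrapped
# with CachePolicy.from_cache_key_fn and wins over any cache_policy passed.
@task(cache_key_fn=lambda ctx, params: f"add-{params['x']}-{params['y']}")
def add_with_key_fn(x: int, y: int) -> int:
    return x + y
```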
@@ -582,7 +601,7 @@ class Task(Generic[P, R]):
         else:
             state = Pending()
 
-        # store parameters for background tasks so that task server
+        # store parameters for background tasks so that task worker
         # can retrieve them at runtime
         if deferred and (parameters or wait_for):
             parameters_id = uuid4()
@@ -755,8 +774,6 @@ class Task(Generic[P, R]):
         """
         Submit a run of the task to the engine.
 
-        If writing an async task, this call must be awaited.
-
         Will create a new task run in the backing API and submit the task to the flow's
         task runner. This call only blocks execution while the task is being submitted,
         once it is submitted, the flow function will continue executing.
@@ -849,7 +866,11 @@ class Task(Generic[P, R]):
         flow_run_context = FlowRunContext.get()
 
         if not flow_run_context:
-            raise
+            raise RuntimeError(
+                "Unable to determine task runner to use for submission. If you are"
+                " submitting a task outside of a flow, please use `.delay`"
+                " to submit the task run for deferred execution."
+            )
 
         task_viz_tracker = get_task_viz_tracker()
         if task_viz_tracker:
@@ -897,6 +918,7 @@ class Task(Generic[P, R]):
         *args: Any,
         return_state: bool = False,
         wait_for: Optional[Iterable[PrefectFuture]] = None,
+        deferred: bool = False,
         **kwargs: Any,
     ):
         """
@@ -1010,6 +1032,7 @@ class Task(Generic[P, R]):
             [[11, 21], [12, 22], [13, 23]]
         """
 
+        from prefect.task_runners import TaskRunner
         from prefect.utilities.visualization import (
             VisualizationUnsupportedError,
             get_task_viz_tracker,
@@ -1026,22 +1049,22 @@ class Task(Generic[P, R]):
                 "`task.map()` is not currently supported by `flow.visualize()`"
             )
 
-        if
-            # TODO: Should we split out background task mapping into a separate method
-            # like we do for the `submit`/`apply_async` split?
+        if deferred:
             parameters_list = expand_mapping_parameters(self.fn, parameters)
-
-
-            return [
-                run_coro_as_sync(self.create_run(parameters=parameters, deferred=True))
+            futures = [
+                self.apply_async(kwargs=parameters, wait_for=wait_for)
                 for parameters in parameters_list
             ]
-
-
-
-
-
+        elif task_runner := getattr(flow_run_context, "task_runner", None):
+            assert isinstance(task_runner, TaskRunner)
+            futures = task_runner.map(self, parameters, wait_for)
+        else:
+            raise RuntimeError(
+                "Unable to determine task runner to use for mapped task runs. If"
+                " you are mapping a task outside of a flow, please provide"
+                " `deferred=True` to submit the mapped task runs for deferred"
+                " execution."
+            )
         if return_state:
             states = []
             for future in futures:
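`Task.map` now takes a `deferred` flag: inside a flow it still routes through the flow's task runner, with `deferred=True` each element becomes an `apply_async` submission for a task worker to execute, and mapping outside a flow without the flag raises the `RuntimeError` shown above. A minimal sketch of the deferred path (it assumes a task worker serving `double` is running elsewhere):

```python
from prefect import task


@task
def double(x: int) -> int:
    return x * 2


# Outside a flow, mapping must be deferred: each element is submitted as a
# scheduled task run for a task worker, and a list of futures comes back.
futures = double.map([1, 2, 3], deferred=True)
```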
@@ -1059,7 +1082,7 @@ class Task(Generic[P, R]):
         dependencies: Optional[Dict[str, Set[TaskRunInput]]] = None,
     ) -> PrefectDistributedFuture:
         """
-        Create a pending task run for a task server to execute.
+        Create a pending task run for a task worker to execute.
 
         Args:
             args: Arguments to run the task with
@@ -1181,7 +1204,7 @@ class Task(Generic[P, R]):
         """
         return self.apply_async(args=args, kwargs=kwargs)
 
-    def serve(self
+    def serve(self) -> "Task":
         """Serve the task using the provided task runner. This method is used to
         establish a websocket connection with the Prefect server and listen for
         submitted task runs to execute.
@@ -1198,9 +1221,9 @@ class Task(Generic[P, R]):
 
         >>> my_task.serve()
         """
-        from prefect.task_server import serve
+        from prefect.task_worker import serve
 
-        serve(self
+        serve(self)
 
 
 @overload
@@ -1215,6 +1238,7 @@ def task(
     description: str = None,
     tags: Iterable[str] = None,
     version: str = None,
+    cache_policy: CachePolicy = NotSet,
     cache_key_fn: Callable[["TaskRunContext", Dict[str, Any]], Optional[str]] = None,
     cache_expiration: datetime.timedelta = None,
     task_run_name: Optional[Union[Callable[[], str], str]] = None,
@@ -1249,6 +1273,7 @@ def task(
     description: str = None,
     tags: Iterable[str] = None,
     version: str = None,
+    cache_policy: CachePolicy = NotSet,
    cache_key_fn: Callable[["TaskRunContext", Dict[str, Any]], Optional[str]] = None,
     cache_expiration: datetime.timedelta = None,
     task_run_name: Optional[Union[Callable[[], str], str]] = None,
@@ -1391,6 +1416,7 @@ def task(
                 description=description,
                 tags=tags,
                 version=version,
+                cache_policy=cache_policy,
                 cache_key_fn=cache_key_fn,
                 cache_expiration=cache_expiration,
                 task_run_name=task_run_name,
@@ -1420,6 +1446,7 @@ def task(
                 description=description,
                 tags=tags,
                 version=version,
+                cache_policy=cache_policy,
                 cache_key_fn=cache_key_fn,
                 cache_expiration=cache_expiration,
                 task_run_name=task_run_name,
prefect/transactions.py
CHANGED
@@ -52,6 +52,7 @@ class Transaction(ContextModel):
     on_rollback_hooks: List[Callable[["Transaction"], None]] = Field(
         default_factory=list
     )
+    overwrite: bool = False
     _staged_value: Any = None
     __var__ = ContextVar("transaction")
 
@@ -122,7 +123,7 @@ class Transaction(ContextModel):
     def begin(self):
         # currently we only support READ_COMMITTED isolation
         # i.e., no locking behavior
-        if self.store and self.store.exists(key=self.key):
+        if not self.overwrite and self.store and self.store.exists(key=self.key):
             self.state = TransactionState.COMMITTED
 
     def read(self) -> dict:
@@ -215,6 +216,9 @@ def transaction(
     key: Optional[str] = None,
     store: Optional[RecordStore] = None,
     commit_mode: CommitMode = CommitMode.LAZY,
+    overwrite: bool = False,
 ) -> Generator[Transaction, None, None]:
-    with Transaction(
+    with Transaction(
+        key=key, store=store, commit_mode=commit_mode, overwrite=overwrite
+    ) as txn:
         yield txn