prefect-client 3.0.0rc8__py3-none-any.whl → 3.0.0rc10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- prefect/_internal/compatibility/deprecated.py +53 -0
- prefect/_internal/compatibility/migration.py +53 -11
- prefect/_internal/integrations.py +7 -0
- prefect/agent.py +6 -0
- prefect/blocks/core.py +1 -1
- prefect/client/__init__.py +4 -0
- prefect/client/schemas/objects.py +6 -3
- prefect/client/utilities.py +4 -4
- prefect/context.py +6 -0
- prefect/deployments/schedules.py +5 -2
- prefect/deployments/steps/core.py +6 -0
- prefect/engine.py +4 -4
- prefect/events/schemas/automations.py +3 -3
- prefect/exceptions.py +4 -1
- prefect/filesystems.py +4 -3
- prefect/flow_engine.py +102 -15
- prefect/flow_runs.py +1 -1
- prefect/flows.py +65 -15
- prefect/futures.py +5 -0
- prefect/infrastructure/__init__.py +6 -0
- prefect/infrastructure/base.py +6 -0
- prefect/logging/loggers.py +1 -1
- prefect/results.py +85 -68
- prefect/serializers.py +3 -3
- prefect/settings.py +7 -33
- prefect/task_engine.py +78 -21
- prefect/task_runners.py +28 -16
- prefect/task_worker.py +19 -6
- prefect/tasks.py +39 -7
- prefect/transactions.py +41 -3
- prefect/utilities/asyncutils.py +37 -8
- prefect/utilities/collections.py +1 -1
- prefect/utilities/importtools.py +1 -1
- prefect/utilities/timeout.py +20 -5
- prefect/workers/block.py +6 -0
- prefect/workers/cloud.py +6 -0
- {prefect_client-3.0.0rc8.dist-info → prefect_client-3.0.0rc10.dist-info}/METADATA +3 -2
- {prefect_client-3.0.0rc8.dist-info → prefect_client-3.0.0rc10.dist-info}/RECORD +41 -36
- {prefect_client-3.0.0rc8.dist-info → prefect_client-3.0.0rc10.dist-info}/LICENSE +0 -0
- {prefect_client-3.0.0rc8.dist-info → prefect_client-3.0.0rc10.dist-info}/WHEEL +0 -0
- {prefect_client-3.0.0rc8.dist-info → prefect_client-3.0.0rc10.dist-info}/top_level.txt +0 -0
prefect/task_engine.py
CHANGED
@@ -1,6 +1,8 @@
|
|
1
1
|
import inspect
|
2
2
|
import logging
|
3
|
+
import threading
|
3
4
|
import time
|
5
|
+
from asyncio import CancelledError
|
4
6
|
from contextlib import ExitStack, contextmanager
|
5
7
|
from dataclasses import dataclass, field
|
6
8
|
from textwrap import dedent
|
@@ -17,6 +19,7 @@ from typing import (
|
|
17
19
|
Optional,
|
18
20
|
Sequence,
|
19
21
|
Set,
|
22
|
+
Type,
|
20
23
|
TypeVar,
|
21
24
|
Union,
|
22
25
|
)
|
@@ -36,17 +39,18 @@ from prefect.context import (
|
|
36
39
|
TaskRunContext,
|
37
40
|
hydrated_context,
|
38
41
|
)
|
39
|
-
from prefect.events.schemas.events import Event
|
42
|
+
from prefect.events.schemas.events import Event as PrefectEvent
|
40
43
|
from prefect.exceptions import (
|
41
44
|
Abort,
|
42
45
|
Pause,
|
43
46
|
PrefectException,
|
47
|
+
TerminationSignal,
|
44
48
|
UpstreamTaskError,
|
45
49
|
)
|
46
50
|
from prefect.futures import PrefectFuture
|
47
51
|
from prefect.logging.loggers import get_logger, patch_print, task_run_logger
|
48
52
|
from prefect.records.result_store import ResultFactoryStore
|
49
|
-
from prefect.results import ResultFactory, _format_user_supplied_storage_key
|
53
|
+
from prefect.results import BaseResult, ResultFactory, _format_user_supplied_storage_key
|
50
54
|
from prefect.settings import (
|
51
55
|
PREFECT_DEBUG_MODE,
|
52
56
|
PREFECT_TASKS_REFRESH_CACHE,
|
@@ -63,6 +67,7 @@ from prefect.states import (
|
|
63
67
|
return_value_to_state,
|
64
68
|
)
|
65
69
|
from prefect.transactions import Transaction, transaction
|
70
|
+
from prefect.utilities.annotations import NotSet
|
66
71
|
from prefect.utilities.asyncutils import run_coro_as_sync
|
67
72
|
from prefect.utilities.callables import call_with_parameters, parameters_to_args_kwargs
|
68
73
|
from prefect.utilities.collections import visit_collection
|
@@ -80,6 +85,10 @@ P = ParamSpec("P")
|
|
80
85
|
R = TypeVar("R")
|
81
86
|
|
82
87
|
|
88
|
+
class TaskRunTimeoutError(TimeoutError):
|
89
|
+
"""Raised when a task run exceeds its timeout."""
|
90
|
+
|
91
|
+
|
83
92
|
@dataclass
|
84
93
|
class TaskRunEngine(Generic[P, R]):
|
85
94
|
task: Union[Task[P, R], Task[P, Coroutine[Any, Any, R]]]
|
@@ -89,11 +98,15 @@ class TaskRunEngine(Generic[P, R]):
|
|
89
98
|
retries: int = 0
|
90
99
|
wait_for: Optional[Iterable[PrefectFuture]] = None
|
91
100
|
context: Optional[Dict[str, Any]] = None
|
101
|
+
# holds the return value from the user code
|
102
|
+
_return_value: Union[R, Type[NotSet]] = NotSet
|
103
|
+
# holds the exception raised by the user code, if any
|
104
|
+
_raised: Union[Exception, Type[NotSet]] = NotSet
|
92
105
|
_initial_run_context: Optional[TaskRunContext] = None
|
93
106
|
_is_started: bool = False
|
94
107
|
_client: Optional[SyncPrefectClient] = None
|
95
108
|
_task_name_set: bool = False
|
96
|
-
_last_event: Optional[
|
109
|
+
_last_event: Optional[PrefectEvent] = None
|
97
110
|
|
98
111
|
def __post_init__(self):
|
99
112
|
if self.parameters is None:
|
@@ -136,7 +149,16 @@ class TaskRunEngine(Generic[P, R]):
|
|
136
149
|
)
|
137
150
|
return False
|
138
151
|
|
139
|
-
def
|
152
|
+
def is_cancelled(self) -> bool:
|
153
|
+
if (
|
154
|
+
self.context
|
155
|
+
and "cancel_event" in self.context
|
156
|
+
and isinstance(self.context["cancel_event"], threading.Event)
|
157
|
+
):
|
158
|
+
return self.context["cancel_event"].is_set()
|
159
|
+
return False
|
160
|
+
|
161
|
+
def call_hooks(self, state: Optional[State] = None):
|
140
162
|
if state is None:
|
141
163
|
state = self.state
|
142
164
|
task = self.task
|
@@ -171,7 +193,7 @@ class TaskRunEngine(Generic[P, R]):
|
|
171
193
|
else:
|
172
194
|
self.logger.info(f"Hook {hook_name!r} finished running successfully")
|
173
195
|
|
174
|
-
def compute_transaction_key(self) -> str:
|
196
|
+
def compute_transaction_key(self) -> Optional[str]:
|
175
197
|
key = None
|
176
198
|
if self.task.cache_policy:
|
177
199
|
flow_run_context = FlowRunContext.get()
|
@@ -304,12 +326,24 @@ class TaskRunEngine(Generic[P, R]):
|
|
304
326
|
return new_state
|
305
327
|
|
306
328
|
def result(self, raise_on_failure: bool = True) -> "Union[R, State, None]":
|
307
|
-
|
308
|
-
|
309
|
-
|
310
|
-
|
311
|
-
|
312
|
-
|
329
|
+
if self._return_value is not NotSet:
|
330
|
+
# if the return value is a BaseResult, we need to fetch it
|
331
|
+
if isinstance(self._return_value, BaseResult):
|
332
|
+
_result = self._return_value.get()
|
333
|
+
if inspect.isawaitable(_result):
|
334
|
+
_result = run_coro_as_sync(_result)
|
335
|
+
return _result
|
336
|
+
|
337
|
+
# otherwise, return the value as is
|
338
|
+
return self._return_value
|
339
|
+
|
340
|
+
if self._raised is not NotSet:
|
341
|
+
# if the task raised an exception, raise it
|
342
|
+
if raise_on_failure:
|
343
|
+
raise self._raised
|
344
|
+
|
345
|
+
# otherwise, return the exception
|
346
|
+
return self._raised
|
313
347
|
|
314
348
|
def handle_success(self, result: R, transaction: Transaction) -> R:
|
315
349
|
result_factory = getattr(TaskRunContext.get(), "result_factory", None)
|
@@ -339,6 +373,7 @@ class TaskRunEngine(Generic[P, R]):
|
|
339
373
|
if transaction.is_committed():
|
340
374
|
terminal_state.name = "Cached"
|
341
375
|
self.set_state(terminal_state)
|
376
|
+
self._return_value = result
|
342
377
|
return result
|
343
378
|
|
344
379
|
def handle_retry(self, exc: Exception) -> bool:
|
@@ -365,9 +400,11 @@ class TaskRunEngine(Generic[P, R]):
|
|
365
400
|
new_state = Retrying()
|
366
401
|
|
367
402
|
self.logger.info(
|
368
|
-
|
369
|
-
|
370
|
-
|
403
|
+
"Task run failed with exception: %r - " "Retry %s/%s will start %s",
|
404
|
+
exc,
|
405
|
+
self.retries + 1,
|
406
|
+
self.task.retries,
|
407
|
+
str(delay) + " second(s) from now" if delay else "immediately",
|
371
408
|
)
|
372
409
|
|
373
410
|
self.set_state(new_state, force=True)
|
@@ -375,7 +412,9 @@ class TaskRunEngine(Generic[P, R]):
|
|
375
412
|
return True
|
376
413
|
elif self.retries >= self.task.retries:
|
377
414
|
self.logger.error(
|
378
|
-
|
415
|
+
"Task run failed with exception: %r - Retries are exhausted",
|
416
|
+
exc,
|
417
|
+
exc_info=True,
|
379
418
|
)
|
380
419
|
return False
|
381
420
|
|
@@ -394,12 +433,14 @@ class TaskRunEngine(Generic[P, R]):
|
|
394
433
|
)
|
395
434
|
)
|
396
435
|
self.set_state(state)
|
436
|
+
self._raised = exc
|
397
437
|
|
398
438
|
def handle_timeout(self, exc: TimeoutError) -> None:
|
399
439
|
if not self.handle_retry(exc):
|
400
|
-
|
401
|
-
f"Task run exceeded timeout of {self.task.timeout_seconds}
|
402
|
-
|
440
|
+
if isinstance(exc, TaskRunTimeoutError):
|
441
|
+
message = f"Task run exceeded timeout of {self.task.timeout_seconds} second(s)"
|
442
|
+
else:
|
443
|
+
message = f"Task run failed due to timeout: {exc!r}"
|
403
444
|
self.logger.error(message)
|
404
445
|
state = Failed(
|
405
446
|
data=exc,
|
@@ -407,12 +448,14 @@ class TaskRunEngine(Generic[P, R]):
|
|
407
448
|
name="TimedOut",
|
408
449
|
)
|
409
450
|
self.set_state(state)
|
451
|
+
self._raised = exc
|
410
452
|
|
411
453
|
def handle_crash(self, exc: BaseException) -> None:
|
412
454
|
state = run_coro_as_sync(exception_to_crashed_state(exc))
|
413
455
|
self.logger.error(f"Crash detected! {state.message}")
|
414
456
|
self.logger.debug("Crash details:", exc_info=exc)
|
415
457
|
self.set_state(state, force=True)
|
458
|
+
self._raised = exc
|
416
459
|
|
417
460
|
@contextmanager
|
418
461
|
def setup_run_context(self, client: Optional[SyncPrefectClient] = None):
|
@@ -498,6 +541,11 @@ class TaskRunEngine(Generic[P, R]):
|
|
498
541
|
)
|
499
542
|
yield self
|
500
543
|
|
544
|
+
except TerminationSignal as exc:
|
545
|
+
# TerminationSignals are caught and handled as crashes
|
546
|
+
self.handle_crash(exc)
|
547
|
+
raise exc
|
548
|
+
|
501
549
|
except Exception:
|
502
550
|
# regular exceptions are caught and re-raised to the user
|
503
551
|
raise
|
@@ -539,8 +587,8 @@ class TaskRunEngine(Generic[P, R]):
|
|
539
587
|
|
540
588
|
@flow
|
541
589
|
def example_flow():
|
542
|
-
say_hello.submit(name="Marvin)
|
543
|
-
|
590
|
+
future = say_hello.submit(name="Marvin)
|
591
|
+
future.wait()
|
544
592
|
|
545
593
|
example_flow()
|
546
594
|
"""
|
@@ -602,6 +650,7 @@ class TaskRunEngine(Generic[P, R]):
|
|
602
650
|
key=self.compute_transaction_key(),
|
603
651
|
store=ResultFactoryStore(result_factory=result_factory),
|
604
652
|
overwrite=overwrite,
|
653
|
+
logger=self.logger,
|
605
654
|
) as txn:
|
606
655
|
yield txn
|
607
656
|
|
@@ -611,10 +660,16 @@ class TaskRunEngine(Generic[P, R]):
|
|
611
660
|
# reenter the run context to ensure it is up to date for every run
|
612
661
|
with self.setup_run_context():
|
613
662
|
try:
|
614
|
-
with timeout_context(
|
663
|
+
with timeout_context(
|
664
|
+
seconds=self.task.timeout_seconds,
|
665
|
+
timeout_exc_type=TaskRunTimeoutError,
|
666
|
+
):
|
615
667
|
self.logger.debug(
|
616
668
|
f"Executing task {self.task.name!r} for task run {self.task_run.name!r}..."
|
617
669
|
)
|
670
|
+
if self.is_cancelled():
|
671
|
+
raise CancelledError("Task run cancelled by the task runner")
|
672
|
+
|
618
673
|
yield self
|
619
674
|
except TimeoutError as exc:
|
620
675
|
self.handle_timeout(exc)
|
@@ -637,6 +692,7 @@ class TaskRunEngine(Generic[P, R]):
|
|
637
692
|
else:
|
638
693
|
result = await call_with_parameters(self.task.fn, parameters)
|
639
694
|
self.handle_success(result, transaction=transaction)
|
695
|
+
return result
|
640
696
|
|
641
697
|
return _call_task_fn()
|
642
698
|
else:
|
@@ -645,6 +701,7 @@ class TaskRunEngine(Generic[P, R]):
|
|
645
701
|
else:
|
646
702
|
result = call_with_parameters(self.task.fn, parameters)
|
647
703
|
self.handle_success(result, transaction=transaction)
|
704
|
+
return result
|
648
705
|
|
649
706
|
|
650
707
|
def run_task_sync(
|
prefect/task_runners.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1
1
|
import abc
|
2
2
|
import asyncio
|
3
3
|
import sys
|
4
|
+
import threading
|
4
5
|
import uuid
|
5
6
|
from concurrent.futures import ThreadPoolExecutor
|
6
7
|
from contextvars import copy_context
|
@@ -41,8 +42,8 @@ if TYPE_CHECKING:
|
|
41
42
|
|
42
43
|
P = ParamSpec("P")
|
43
44
|
T = TypeVar("T")
|
44
|
-
F = TypeVar("F", bound=PrefectFuture)
|
45
45
|
R = TypeVar("R")
|
46
|
+
F = TypeVar("F", bound=PrefectFuture, default=PrefectConcurrentFuture)
|
46
47
|
|
47
48
|
|
48
49
|
class TaskRunner(abc.ABC, Generic[F]):
|
@@ -220,6 +221,7 @@ class ThreadPoolTaskRunner(TaskRunner[PrefectConcurrentFuture]):
|
|
220
221
|
super().__init__()
|
221
222
|
self._executor: Optional[ThreadPoolExecutor] = None
|
222
223
|
self._max_workers = sys.maxsize if max_workers is None else max_workers
|
224
|
+
self._cancel_events: Dict[uuid.UUID, threading.Event] = {}
|
223
225
|
|
224
226
|
def duplicate(self) -> "ThreadPoolTaskRunner":
|
225
227
|
return type(self)(max_workers=self._max_workers)
|
@@ -270,6 +272,8 @@ class ThreadPoolTaskRunner(TaskRunner[PrefectConcurrentFuture]):
|
|
270
272
|
from prefect.task_engine import run_task_async, run_task_sync
|
271
273
|
|
272
274
|
task_run_id = uuid.uuid4()
|
275
|
+
cancel_event = threading.Event()
|
276
|
+
self._cancel_events[task_run_id] = cancel_event
|
273
277
|
context = copy_context()
|
274
278
|
|
275
279
|
flow_run_ctx = FlowRunContext.get()
|
@@ -280,31 +284,29 @@ class ThreadPoolTaskRunner(TaskRunner[PrefectConcurrentFuture]):
|
|
280
284
|
else:
|
281
285
|
self.logger.info(f"Submitting task {task.name} to thread pool executor...")
|
282
286
|
|
287
|
+
submit_kwargs = dict(
|
288
|
+
task=task,
|
289
|
+
task_run_id=task_run_id,
|
290
|
+
parameters=parameters,
|
291
|
+
wait_for=wait_for,
|
292
|
+
return_type="state",
|
293
|
+
dependencies=dependencies,
|
294
|
+
context=dict(cancel_event=cancel_event),
|
295
|
+
)
|
296
|
+
|
283
297
|
if task.isasync:
|
284
298
|
# TODO: Explore possibly using a long-lived thread with an event loop
|
285
299
|
# for better performance
|
286
300
|
future = self._executor.submit(
|
287
301
|
context.run,
|
288
302
|
asyncio.run,
|
289
|
-
run_task_async(
|
290
|
-
task=task,
|
291
|
-
task_run_id=task_run_id,
|
292
|
-
parameters=parameters,
|
293
|
-
wait_for=wait_for,
|
294
|
-
return_type="state",
|
295
|
-
dependencies=dependencies,
|
296
|
-
),
|
303
|
+
run_task_async(**submit_kwargs),
|
297
304
|
)
|
298
305
|
else:
|
299
306
|
future = self._executor.submit(
|
300
307
|
context.run,
|
301
308
|
run_task_sync,
|
302
|
-
|
303
|
-
task_run_id=task_run_id,
|
304
|
-
parameters=parameters,
|
305
|
-
wait_for=wait_for,
|
306
|
-
return_type="state",
|
307
|
-
dependencies=dependencies,
|
309
|
+
**submit_kwargs,
|
308
310
|
)
|
309
311
|
prefect_future = PrefectConcurrentFuture(
|
310
312
|
task_run_id=task_run_id, wrapped_future=future
|
@@ -337,14 +339,24 @@ class ThreadPoolTaskRunner(TaskRunner[PrefectConcurrentFuture]):
|
|
337
339
|
):
|
338
340
|
return super().map(task, parameters, wait_for)
|
339
341
|
|
342
|
+
def cancel_all(self):
|
343
|
+
for event in self._cancel_events.values():
|
344
|
+
event.set()
|
345
|
+
self.logger.debug("Set cancel event")
|
346
|
+
|
347
|
+
if self._executor is not None:
|
348
|
+
self._executor.shutdown(cancel_futures=True)
|
349
|
+
self._executor = None
|
350
|
+
|
340
351
|
def __enter__(self):
|
341
352
|
super().__enter__()
|
342
353
|
self._executor = ThreadPoolExecutor(max_workers=self._max_workers)
|
343
354
|
return self
|
344
355
|
|
345
356
|
def __exit__(self, exc_type, exc_value, traceback):
|
357
|
+
self.cancel_all()
|
346
358
|
if self._executor is not None:
|
347
|
-
self._executor.shutdown()
|
359
|
+
self._executor.shutdown(cancel_futures=True)
|
348
360
|
self._executor = None
|
349
361
|
super().__exit__(exc_type, exc_value, traceback)
|
350
362
|
|
prefect/task_worker.py
CHANGED
@@ -7,7 +7,7 @@ import sys
|
|
7
7
|
from concurrent.futures import ThreadPoolExecutor
|
8
8
|
from contextlib import AsyncExitStack
|
9
9
|
from contextvars import copy_context
|
10
|
-
from typing import
|
10
|
+
from typing import Optional
|
11
11
|
from uuid import UUID
|
12
12
|
|
13
13
|
import anyio
|
@@ -20,6 +20,7 @@ from websockets.exceptions import InvalidStatusCode
|
|
20
20
|
|
21
21
|
from prefect import Task
|
22
22
|
from prefect._internal.concurrency.api import create_call, from_sync
|
23
|
+
from prefect.cache_policies import DEFAULT, NONE
|
23
24
|
from prefect.client.orchestration import get_client
|
24
25
|
from prefect.client.schemas.objects import TaskRun
|
25
26
|
from prefect.client.subscriptions import Subscription
|
@@ -32,9 +33,11 @@ from prefect.settings import (
|
|
32
33
|
)
|
33
34
|
from prefect.states import Pending
|
34
35
|
from prefect.task_engine import run_task_async, run_task_sync
|
36
|
+
from prefect.utilities.annotations import NotSet
|
35
37
|
from prefect.utilities.asyncutils import asyncnullcontext, sync_compatible
|
36
38
|
from prefect.utilities.engine import emit_task_run_state_change_event, propose_state
|
37
39
|
from prefect.utilities.processutils import _register_signal
|
40
|
+
from prefect.utilities.urls import url_for
|
38
41
|
|
39
42
|
logger = get_logger("task_worker")
|
40
43
|
|
@@ -76,7 +79,16 @@ class TaskWorker:
|
|
76
79
|
*tasks: Task,
|
77
80
|
limit: Optional[int] = 10,
|
78
81
|
):
|
79
|
-
self.tasks
|
82
|
+
self.tasks = []
|
83
|
+
for t in tasks:
|
84
|
+
if isinstance(t, Task):
|
85
|
+
if t.cache_policy in [None, NONE, NotSet]:
|
86
|
+
self.tasks.append(
|
87
|
+
t.with_options(persist_result=True, cache_policy=DEFAULT)
|
88
|
+
)
|
89
|
+
else:
|
90
|
+
self.tasks.append(t.with_options(persist_result=True))
|
91
|
+
|
80
92
|
self.task_keys = set(t.task_key for t in tasks if isinstance(t, Task))
|
81
93
|
|
82
94
|
self._started_at: Optional[pendulum.DateTime] = None
|
@@ -277,10 +289,6 @@ class TaskWorker:
|
|
277
289
|
await self._client._client.delete(f"/task_runs/{task_run.id}")
|
278
290
|
return
|
279
291
|
|
280
|
-
logger.debug(
|
281
|
-
f"Submitting run {task_run.name!r} of task {task.name!r} to engine"
|
282
|
-
)
|
283
|
-
|
284
292
|
try:
|
285
293
|
new_state = Pending()
|
286
294
|
new_state.state_details.deferred = True
|
@@ -315,6 +323,11 @@ class TaskWorker:
|
|
315
323
|
validated_state=state,
|
316
324
|
)
|
317
325
|
|
326
|
+
if task_run_url := url_for(task_run):
|
327
|
+
logger.info(
|
328
|
+
f"Submitting task run {task_run.name!r} to engine. View run in the UI at {task_run_url!r}"
|
329
|
+
)
|
330
|
+
|
318
331
|
if task.isasync:
|
319
332
|
await run_task_async(
|
320
333
|
task=task,
|
prefect/tasks.py
CHANGED
@@ -33,6 +33,9 @@ from uuid import UUID, uuid4
|
|
33
33
|
|
34
34
|
from typing_extensions import Literal, ParamSpec
|
35
35
|
|
36
|
+
from prefect._internal.compatibility.deprecated import (
|
37
|
+
deprecated_async_method,
|
38
|
+
)
|
36
39
|
from prefect.cache_policies import DEFAULT, NONE, CachePolicy
|
37
40
|
from prefect.client.orchestration import get_client
|
38
41
|
from prefect.client.schemas import TaskRun
|
@@ -47,7 +50,6 @@ from prefect.futures import PrefectDistributedFuture, PrefectFuture, PrefectFutu
|
|
47
50
|
from prefect.logging.loggers import get_logger
|
48
51
|
from prefect.results import ResultFactory, ResultSerializer, ResultStorage
|
49
52
|
from prefect.settings import (
|
50
|
-
PREFECT_RESULTS_PERSIST_BY_DEFAULT,
|
51
53
|
PREFECT_TASK_DEFAULT_RETRIES,
|
52
54
|
PREFECT_TASK_DEFAULT_RETRY_DELAY_SECONDS,
|
53
55
|
)
|
@@ -64,6 +66,7 @@ from prefect.utilities.callables import (
|
|
64
66
|
)
|
65
67
|
from prefect.utilities.hashing import hash_objects
|
66
68
|
from prefect.utilities.importtools import to_qualified_name
|
69
|
+
from prefect.utilities.urls import url_for
|
67
70
|
|
68
71
|
if TYPE_CHECKING:
|
69
72
|
from prefect.client.orchestration import PrefectClient
|
@@ -384,9 +387,20 @@ class Task(Generic[P, R]):
|
|
384
387
|
self.cache_expiration = cache_expiration
|
385
388
|
self.refresh_cache = refresh_cache
|
386
389
|
|
390
|
+
# result persistence settings
|
387
391
|
if persist_result is None:
|
388
|
-
|
389
|
-
|
392
|
+
if any(
|
393
|
+
[
|
394
|
+
cache_policy and cache_policy != NONE and cache_policy != NotSet,
|
395
|
+
cache_key_fn is not None,
|
396
|
+
result_storage_key is not None,
|
397
|
+
result_storage is not None,
|
398
|
+
result_serializer is not None,
|
399
|
+
]
|
400
|
+
):
|
401
|
+
persist_result = True
|
402
|
+
|
403
|
+
if persist_result is False:
|
390
404
|
self.cache_policy = None if cache_policy is None else NONE
|
391
405
|
if cache_policy and cache_policy is not NotSet and cache_policy != NONE:
|
392
406
|
logger.warning(
|
@@ -425,6 +439,14 @@ class Task(Generic[P, R]):
|
|
425
439
|
|
426
440
|
self.retry_jitter_factor = retry_jitter_factor
|
427
441
|
self.persist_result = persist_result
|
442
|
+
|
443
|
+
if result_storage and not isinstance(result_storage, str):
|
444
|
+
if getattr(result_storage, "_block_document_id", None) is None:
|
445
|
+
raise TypeError(
|
446
|
+
"Result storage configuration must be persisted server-side."
|
447
|
+
" Please call `.save()` on your block before passing it in."
|
448
|
+
)
|
449
|
+
|
428
450
|
self.result_storage = result_storage
|
429
451
|
self.result_serializer = result_serializer
|
430
452
|
self.result_storage_key = result_storage_key
|
@@ -477,7 +499,7 @@ class Task(Generic[P, R]):
|
|
477
499
|
cache_key_fn: Optional[
|
478
500
|
Callable[["TaskRunContext", Dict[str, Any]], Optional[str]]
|
479
501
|
] = None,
|
480
|
-
task_run_name: Optional[Union[Callable[[], str], str]] =
|
502
|
+
task_run_name: Optional[Union[Callable[[], str], str, Type[NotSet]]] = NotSet,
|
481
503
|
cache_expiration: Optional[datetime.timedelta] = None,
|
482
504
|
retries: Union[int, Type[NotSet]] = NotSet,
|
483
505
|
retry_delay_seconds: Union[
|
@@ -591,7 +613,9 @@ class Task(Generic[P, R]):
|
|
591
613
|
else self.cache_policy,
|
592
614
|
cache_key_fn=cache_key_fn or self.cache_key_fn,
|
593
615
|
cache_expiration=cache_expiration or self.cache_expiration,
|
594
|
-
task_run_name=task_run_name
|
616
|
+
task_run_name=task_run_name
|
617
|
+
if task_run_name is not NotSet
|
618
|
+
else self.task_run_name,
|
595
619
|
retries=retries if retries is not NotSet else self.retries,
|
596
620
|
retry_delay_seconds=(
|
597
621
|
retry_delay_seconds
|
@@ -869,6 +893,7 @@ class Task(Generic[P, R]):
|
|
869
893
|
) -> State[T]:
|
870
894
|
...
|
871
895
|
|
896
|
+
@deprecated_async_method
|
872
897
|
def submit(
|
873
898
|
self,
|
874
899
|
*args: Any,
|
@@ -1033,6 +1058,7 @@ class Task(Generic[P, R]):
|
|
1033
1058
|
) -> PrefectFutureList[State[T]]:
|
1034
1059
|
...
|
1035
1060
|
|
1061
|
+
@deprecated_async_method
|
1036
1062
|
def map(
|
1037
1063
|
self,
|
1038
1064
|
*args: Any,
|
@@ -1275,14 +1301,20 @@ class Task(Generic[P, R]):
|
|
1275
1301
|
# Convert the call args/kwargs to a parameter dict
|
1276
1302
|
parameters = get_call_parameters(self.fn, args, kwargs)
|
1277
1303
|
|
1278
|
-
task_run = run_coro_as_sync(
|
1304
|
+
task_run: TaskRun = run_coro_as_sync(
|
1279
1305
|
self.create_run(
|
1280
1306
|
parameters=parameters,
|
1281
1307
|
deferred=True,
|
1282
1308
|
wait_for=wait_for,
|
1283
1309
|
extra_task_inputs=dependencies,
|
1284
1310
|
)
|
1285
|
-
)
|
1311
|
+
) # type: ignore
|
1312
|
+
|
1313
|
+
if task_run_url := url_for(task_run):
|
1314
|
+
logger.info(
|
1315
|
+
f"Created task run {task_run.name!r}. View it in the UI at {task_run_url!r}"
|
1316
|
+
)
|
1317
|
+
|
1286
1318
|
return PrefectDistributedFuture(task_run_id=task_run.id)
|
1287
1319
|
|
1288
1320
|
def delay(self, *args: P.args, **kwargs: P.kwargs) -> PrefectDistributedFuture:
|
prefect/transactions.py
CHANGED
@@ -1,3 +1,4 @@
|
|
1
|
+
import logging
|
1
2
|
from contextlib import contextmanager
|
2
3
|
from contextvars import ContextVar, Token
|
3
4
|
from typing import (
|
@@ -7,21 +8,25 @@ from typing import (
|
|
7
8
|
List,
|
8
9
|
Optional,
|
9
10
|
Type,
|
11
|
+
Union,
|
10
12
|
)
|
11
13
|
|
12
14
|
from pydantic import Field
|
13
15
|
from typing_extensions import Self
|
14
16
|
|
15
17
|
from prefect.context import ContextModel, FlowRunContext, TaskRunContext
|
18
|
+
from prefect.exceptions import MissingContextError
|
19
|
+
from prefect.logging.loggers import PrefectLogAdapter, get_logger, get_run_logger
|
16
20
|
from prefect.records import RecordStore
|
17
21
|
from prefect.records.result_store import ResultFactoryStore
|
18
22
|
from prefect.results import (
|
19
23
|
BaseResult,
|
20
24
|
ResultFactory,
|
21
|
-
|
25
|
+
get_default_result_storage,
|
22
26
|
)
|
23
27
|
from prefect.utilities.asyncutils import run_coro_as_sync
|
24
28
|
from prefect.utilities.collections import AutoEnum
|
29
|
+
from prefect.utilities.engine import _get_hook_name
|
25
30
|
|
26
31
|
|
27
32
|
class IsolationLevel(AutoEnum):
|
@@ -58,6 +63,7 @@ class Transaction(ContextModel):
|
|
58
63
|
default_factory=list
|
59
64
|
)
|
60
65
|
overwrite: bool = False
|
66
|
+
logger: Union[logging.Logger, logging.LoggerAdapter, None] = None
|
61
67
|
_staged_value: Any = None
|
62
68
|
__var__: ContextVar = ContextVar("transaction")
|
63
69
|
|
@@ -174,10 +180,13 @@ class Transaction(ContextModel):
|
|
174
180
|
return False
|
175
181
|
|
176
182
|
try:
|
183
|
+
hook_name = None
|
184
|
+
|
177
185
|
for child in self.children:
|
178
186
|
child.commit()
|
179
187
|
|
180
188
|
for hook in self.on_commit_hooks:
|
189
|
+
hook_name = _get_hook_name(hook)
|
181
190
|
hook(self)
|
182
191
|
|
183
192
|
if self.store and self.key:
|
@@ -185,6 +194,19 @@ class Transaction(ContextModel):
|
|
185
194
|
self.state = TransactionState.COMMITTED
|
186
195
|
return True
|
187
196
|
except Exception:
|
197
|
+
if self.logger:
|
198
|
+
if hook_name:
|
199
|
+
msg = (
|
200
|
+
f"An error was encountered while running commit hook {hook_name!r}",
|
201
|
+
)
|
202
|
+
else:
|
203
|
+
msg = (
|
204
|
+
f"An error was encountered while committing transaction {self.key!r}",
|
205
|
+
)
|
206
|
+
self.logger.exception(
|
207
|
+
msg,
|
208
|
+
exc_info=True,
|
209
|
+
)
|
188
210
|
self.rollback()
|
189
211
|
return False
|
190
212
|
|
@@ -212,6 +234,7 @@ class Transaction(ContextModel):
|
|
212
234
|
|
213
235
|
try:
|
214
236
|
for hook in reversed(self.on_rollback_hooks):
|
237
|
+
hook_name = _get_hook_name(hook)
|
215
238
|
hook(self)
|
216
239
|
|
217
240
|
self.state = TransactionState.ROLLED_BACK
|
@@ -221,6 +244,11 @@ class Transaction(ContextModel):
|
|
221
244
|
|
222
245
|
return True
|
223
246
|
except Exception:
|
247
|
+
if self.logger:
|
248
|
+
self.logger.exception(
|
249
|
+
f"An error was encountered while running rollback hook {hook_name!r}",
|
250
|
+
exc_info=True,
|
251
|
+
)
|
224
252
|
return False
|
225
253
|
|
226
254
|
@classmethod
|
@@ -238,6 +266,7 @@ def transaction(
|
|
238
266
|
store: Optional[RecordStore] = None,
|
239
267
|
commit_mode: Optional[CommitMode] = None,
|
240
268
|
overwrite: bool = False,
|
269
|
+
logger: Optional[PrefectLogAdapter] = None,
|
241
270
|
) -> Generator[Transaction, None, None]:
|
242
271
|
"""
|
243
272
|
A context manager for opening and managing a transaction.
|
@@ -268,7 +297,7 @@ def transaction(
|
|
268
297
|
}
|
269
298
|
)
|
270
299
|
else:
|
271
|
-
default_storage =
|
300
|
+
default_storage = get_default_result_storage(_sync=True)
|
272
301
|
if existing_factory:
|
273
302
|
new_factory = existing_factory.model_copy(
|
274
303
|
update={
|
@@ -288,7 +317,16 @@ def transaction(
|
|
288
317
|
result_factory=new_factory,
|
289
318
|
)
|
290
319
|
|
320
|
+
try:
|
321
|
+
logger = logger or get_run_logger()
|
322
|
+
except MissingContextError:
|
323
|
+
logger = get_logger("transactions")
|
324
|
+
|
291
325
|
with Transaction(
|
292
|
-
key=key,
|
326
|
+
key=key,
|
327
|
+
store=store,
|
328
|
+
commit_mode=commit_mode,
|
329
|
+
overwrite=overwrite,
|
330
|
+
logger=logger,
|
293
331
|
) as txn:
|
294
332
|
yield txn
|