prefect-client 3.0.0rc9__py3-none-any.whl → 3.0.0rc10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- prefect/_internal/compatibility/migration.py +48 -8
- prefect/agent.py +6 -0
- prefect/client/schemas/objects.py +2 -3
- prefect/context.py +6 -0
- prefect/deployments/schedules.py +5 -2
- prefect/events/schemas/automations.py +3 -3
- prefect/exceptions.py +4 -1
- prefect/filesystems.py +4 -3
- prefect/flow_engine.py +72 -8
- prefect/flows.py +48 -4
- prefect/infrastructure/__init__.py +6 -0
- prefect/infrastructure/base.py +6 -0
- prefect/results.py +50 -67
- prefect/serializers.py +3 -3
- prefect/settings.py +6 -32
- prefect/task_engine.py +77 -21
- prefect/task_runners.py +28 -16
- prefect/task_worker.py +6 -4
- prefect/tasks.py +30 -5
- prefect/transactions.py +2 -2
- prefect/utilities/asyncutils.py +8 -3
- prefect/utilities/importtools.py +1 -1
- prefect/utilities/timeout.py +20 -5
- prefect/workers/block.py +6 -0
- prefect/workers/cloud.py +6 -0
- {prefect_client-3.0.0rc9.dist-info → prefect_client-3.0.0rc10.dist-info}/METADATA +2 -2
- {prefect_client-3.0.0rc9.dist-info → prefect_client-3.0.0rc10.dist-info}/RECORD +30 -26
- {prefect_client-3.0.0rc9.dist-info → prefect_client-3.0.0rc10.dist-info}/LICENSE +0 -0
- {prefect_client-3.0.0rc9.dist-info → prefect_client-3.0.0rc10.dist-info}/WHEEL +0 -0
- {prefect_client-3.0.0rc9.dist-info → prefect_client-3.0.0rc10.dist-info}/top_level.txt +0 -0
prefect/results.py
CHANGED
@@ -28,7 +28,6 @@ from prefect.client.utilities import inject_client
|
|
28
28
|
from prefect.exceptions import MissingResult, ObjectAlreadyExists
|
29
29
|
from prefect.filesystems import (
|
30
30
|
LocalFileSystem,
|
31
|
-
ReadableFileSystem,
|
32
31
|
WritableFileSystem,
|
33
32
|
)
|
34
33
|
from prefect.logging import get_logger
|
@@ -111,22 +110,32 @@ async def _get_or_create_default_storage(block_document_slug: str) -> ResultStor
|
|
111
110
|
|
112
111
|
|
113
112
|
@sync_compatible
|
114
|
-
async def
|
113
|
+
async def get_default_result_storage() -> ResultStorage:
|
115
114
|
"""
|
116
115
|
Generate a default file system for result storage.
|
117
116
|
"""
|
118
|
-
|
119
|
-
|
120
|
-
|
117
|
+
default_block = PREFECT_DEFAULT_RESULT_STORAGE_BLOCK.value()
|
118
|
+
|
119
|
+
if default_block is not None:
|
120
|
+
return await Block.load(default_block)
|
121
|
+
|
122
|
+
# otherwise, use the local file system
|
123
|
+
basepath = PREFECT_LOCAL_STORAGE_PATH.value()
|
124
|
+
return LocalFileSystem(basepath=basepath)
|
121
125
|
|
122
126
|
|
123
127
|
async def get_or_create_default_task_scheduling_storage() -> ResultStorage:
|
124
128
|
"""
|
125
129
|
Generate a default file system for background task parameter/result storage.
|
126
130
|
"""
|
127
|
-
|
128
|
-
|
129
|
-
|
131
|
+
default_block = PREFECT_TASK_SCHEDULING_DEFAULT_STORAGE_BLOCK.value()
|
132
|
+
|
133
|
+
if default_block is not None:
|
134
|
+
return await Block.load(default_block)
|
135
|
+
|
136
|
+
# otherwise, use the local file system
|
137
|
+
basepath = PREFECT_LOCAL_STORAGE_PATH.value()
|
138
|
+
return LocalFileSystem(basepath=basepath)
|
130
139
|
|
131
140
|
|
132
141
|
def get_default_result_serializer() -> ResultSerializer:
|
@@ -177,9 +186,7 @@ class ResultFactory(BaseModel):
|
|
177
186
|
kwargs.pop(key)
|
178
187
|
|
179
188
|
# Apply defaults
|
180
|
-
kwargs.setdefault(
|
181
|
-
"result_storage", await get_or_create_default_result_storage()
|
182
|
-
)
|
189
|
+
kwargs.setdefault("result_storage", await get_default_result_storage())
|
183
190
|
kwargs.setdefault("result_serializer", get_default_result_serializer())
|
184
191
|
kwargs.setdefault("persist_result", get_default_persist_setting())
|
185
192
|
kwargs.setdefault("cache_result_in_memory", True)
|
@@ -230,9 +237,7 @@ class ResultFactory(BaseModel):
|
|
230
237
|
"""
|
231
238
|
Create a new result factory for a task.
|
232
239
|
"""
|
233
|
-
return await cls._from_task(
|
234
|
-
task, get_or_create_default_result_storage, client=client
|
235
|
-
)
|
240
|
+
return await cls._from_task(task, get_default_result_storage, client=client)
|
236
241
|
|
237
242
|
@classmethod
|
238
243
|
@inject_client
|
@@ -268,7 +273,14 @@ class ResultFactory(BaseModel):
|
|
268
273
|
if ctx and ctx.result_factory
|
269
274
|
else get_default_result_serializer()
|
270
275
|
)
|
271
|
-
|
276
|
+
if task.persist_result is None:
|
277
|
+
persist_result = (
|
278
|
+
ctx.result_factory.persist_result
|
279
|
+
if ctx and ctx.result_factory
|
280
|
+
else get_default_persist_setting()
|
281
|
+
)
|
282
|
+
else:
|
283
|
+
persist_result = task.persist_result
|
272
284
|
|
273
285
|
cache_result_in_memory = task.cache_result_in_memory
|
274
286
|
|
@@ -330,16 +342,7 @@ class ResultFactory(BaseModel):
|
|
330
342
|
# Avoid saving the block if it already has an identifier assigned
|
331
343
|
storage_block_id = storage_block._block_document_id
|
332
344
|
else:
|
333
|
-
|
334
|
-
# TODO: Overwrite is true to avoid issues where the save collides with
|
335
|
-
# a previously saved document with a matching hash
|
336
|
-
storage_block_id = await storage_block._save(
|
337
|
-
is_anonymous=True, overwrite=True, client=client
|
338
|
-
)
|
339
|
-
else:
|
340
|
-
# a None-type UUID on unpersisted storage should not matter
|
341
|
-
# since the ID is generated on the server
|
342
|
-
storage_block_id = None
|
345
|
+
storage_block_id = None
|
343
346
|
elif isinstance(result_storage, str):
|
344
347
|
storage_block = await Block.load(result_storage, client=client)
|
345
348
|
storage_block_id = storage_block._block_document_id
|
@@ -412,9 +415,6 @@ class ResultFactory(BaseModel):
|
|
412
415
|
|
413
416
|
@sync_compatible
|
414
417
|
async def store_parameters(self, identifier: UUID, parameters: Dict[str, Any]):
|
415
|
-
assert (
|
416
|
-
self.storage_block_id is not None
|
417
|
-
), "Unexpected storage block ID. Was it persisted?"
|
418
418
|
data = self.serializer.dumps(parameters)
|
419
419
|
blob = PersistedResultBlob(serializer=self.serializer, data=data)
|
420
420
|
await self.storage_block.write_path(
|
@@ -423,9 +423,6 @@ class ResultFactory(BaseModel):
|
|
423
423
|
|
424
424
|
@sync_compatible
|
425
425
|
async def read_parameters(self, identifier: UUID) -> Dict[str, Any]:
|
426
|
-
assert (
|
427
|
-
self.storage_block_id is not None
|
428
|
-
), "Unexpected storage block ID. Was it persisted?"
|
429
426
|
blob = PersistedResultBlob.model_validate_json(
|
430
427
|
await self.storage_block.read_path(f"parameters/{identifier}")
|
431
428
|
)
|
@@ -435,10 +432,7 @@ class ResultFactory(BaseModel):
|
|
435
432
|
@register_base_type
|
436
433
|
class BaseResult(BaseModel, abc.ABC, Generic[R]):
|
437
434
|
model_config = ConfigDict(extra="forbid")
|
438
|
-
|
439
435
|
type: str
|
440
|
-
artifact_type: Optional[str] = None
|
441
|
-
artifact_description: Optional[str] = None
|
442
436
|
|
443
437
|
def __init__(self, **data: Any) -> None:
|
444
438
|
type_string = get_dispatch_key(self) if type(self) != BaseResult else "__base__"
|
@@ -504,11 +498,7 @@ class UnpersistedResult(BaseResult):
|
|
504
498
|
obj: R,
|
505
499
|
cache_object: bool = True,
|
506
500
|
) -> "UnpersistedResult[R]":
|
507
|
-
|
508
|
-
result = cls(
|
509
|
-
artifact_type="result",
|
510
|
-
artifact_description=description,
|
511
|
-
)
|
501
|
+
result = cls()
|
512
502
|
# Only store the object in local memory, it will not be sent to the API
|
513
503
|
if cache_object:
|
514
504
|
result._cache_object(obj)
|
@@ -528,8 +518,8 @@ class PersistedResult(BaseResult):
|
|
528
518
|
type: str = "reference"
|
529
519
|
|
530
520
|
serializer_type: str
|
531
|
-
storage_block_id: uuid.UUID
|
532
521
|
storage_key: str
|
522
|
+
storage_block_id: Optional[uuid.UUID] = None
|
533
523
|
expiration: Optional[DateTime] = None
|
534
524
|
|
535
525
|
_should_cache_object: bool = PrivateAttr(default=True)
|
@@ -547,6 +537,17 @@ class PersistedResult(BaseResult):
|
|
547
537
|
self._storage_block = storage_block
|
548
538
|
self._serializer = serializer
|
549
539
|
|
540
|
+
@inject_client
|
541
|
+
async def _get_storage_block(self, client: "PrefectClient") -> WritableFileSystem:
|
542
|
+
if self._storage_block is not None:
|
543
|
+
return self._storage_block
|
544
|
+
elif self.storage_block_id is not None:
|
545
|
+
block_document = await client.read_block_document(self.storage_block_id)
|
546
|
+
self._storage_block = Block._from_block_document(block_document)
|
547
|
+
else:
|
548
|
+
self._storage_block = await get_default_result_storage()
|
549
|
+
return self._storage_block
|
550
|
+
|
550
551
|
@sync_compatible
|
551
552
|
@inject_client
|
552
553
|
async def get(self, client: "PrefectClient") -> R:
|
@@ -567,12 +568,8 @@ class PersistedResult(BaseResult):
|
|
567
568
|
|
568
569
|
@inject_client
|
569
570
|
async def _read_blob(self, client: "PrefectClient") -> "PersistedResultBlob":
|
570
|
-
|
571
|
-
|
572
|
-
), "Unexpected storage block ID. Was it persisted?"
|
573
|
-
block_document = await client.read_block_document(self.storage_block_id)
|
574
|
-
storage_block: ReadableFileSystem = Block._from_block_document(block_document)
|
575
|
-
content = await storage_block.read_path(self.storage_key)
|
571
|
+
block = await self._get_storage_block(client=client)
|
572
|
+
content = await block.read_path(self.storage_key)
|
576
573
|
blob = PersistedResultBlob.model_validate_json(content)
|
577
574
|
return blob
|
578
575
|
|
@@ -607,10 +604,7 @@ class PersistedResult(BaseResult):
|
|
607
604
|
obj = obj if obj is not NotSet else self._cache
|
608
605
|
|
609
606
|
# next, the storage block
|
610
|
-
storage_block = self.
|
611
|
-
if storage_block is None:
|
612
|
-
block_document = await client.read_block_document(self.storage_block_id)
|
613
|
-
storage_block = Block._from_block_document(block_document)
|
607
|
+
storage_block = await self._get_storage_block(client=client)
|
614
608
|
|
615
609
|
# finally, the serializer
|
616
610
|
serializer = self._serializer
|
@@ -673,9 +667,9 @@ class PersistedResult(BaseResult):
|
|
673
667
|
cls: "Type[PersistedResult]",
|
674
668
|
obj: R,
|
675
669
|
storage_block: WritableFileSystem,
|
676
|
-
storage_block_id: uuid.UUID,
|
677
670
|
storage_key_fn: Callable[[], str],
|
678
671
|
serializer: Serializer,
|
672
|
+
storage_block_id: Optional[uuid.UUID] = None,
|
679
673
|
cache_object: bool = True,
|
680
674
|
expiration: Optional[DateTime] = None,
|
681
675
|
defer_persistence: bool = False,
|
@@ -686,31 +680,21 @@ class PersistedResult(BaseResult):
|
|
686
680
|
The object will be serialized and written to the storage block under a unique
|
687
681
|
key. It will then be cached on the returned result.
|
688
682
|
"""
|
689
|
-
assert (
|
690
|
-
storage_block_id is not None
|
691
|
-
), "Unexpected storage block ID. Was it saved?"
|
692
|
-
|
693
683
|
key = storage_key_fn()
|
694
684
|
if not isinstance(key, str):
|
695
685
|
raise TypeError(
|
696
686
|
f"Expected type 'str' for result storage key; got value {key!r}"
|
697
687
|
)
|
698
|
-
description = f"Result of type `{type(obj).__name__}`"
|
699
688
|
uri = cls._infer_path(storage_block, key)
|
700
|
-
|
701
|
-
|
702
|
-
|
703
|
-
|
704
|
-
description += f" persisted to [{uri}]({uri})."
|
705
|
-
else:
|
706
|
-
description += f" persisted with storage block `{storage_block_id}`."
|
689
|
+
|
690
|
+
# in this case we store an absolute path
|
691
|
+
if storage_block_id is None and uri is not None:
|
692
|
+
key = str(uri)
|
707
693
|
|
708
694
|
result = cls(
|
709
695
|
serializer_type=serializer.type,
|
710
696
|
storage_block_id=storage_block_id,
|
711
697
|
storage_key=key,
|
712
|
-
artifact_type="result",
|
713
|
-
artifact_description=description,
|
714
698
|
expiration=expiration,
|
715
699
|
)
|
716
700
|
|
@@ -787,5 +771,4 @@ class UnknownResult(BaseResult):
|
|
787
771
|
"Only None is supported."
|
788
772
|
)
|
789
773
|
|
790
|
-
|
791
|
-
return cls(value=obj, artifact_type="result", artifact_description=description)
|
774
|
+
return cls(value=obj)
|
prefect/serializers.py
CHANGED
@@ -13,7 +13,7 @@ bytes to an object respectively.
|
|
13
13
|
|
14
14
|
import abc
|
15
15
|
import base64
|
16
|
-
from typing import Any, Dict, Generic, Optional, Type
|
16
|
+
from typing import Any, Dict, Generic, Optional, Type
|
17
17
|
|
18
18
|
from pydantic import (
|
19
19
|
BaseModel,
|
@@ -23,7 +23,7 @@ from pydantic import (
|
|
23
23
|
ValidationError,
|
24
24
|
field_validator,
|
25
25
|
)
|
26
|
-
from typing_extensions import Literal, Self
|
26
|
+
from typing_extensions import Literal, Self, TypeVar
|
27
27
|
|
28
28
|
from prefect._internal.schemas.validators import (
|
29
29
|
cast_type_names_to_serializers,
|
@@ -36,7 +36,7 @@ from prefect.utilities.dispatch import get_dispatch_key, lookup_type, register_b
|
|
36
36
|
from prefect.utilities.importtools import from_qualified_name, to_qualified_name
|
37
37
|
from prefect.utilities.pydantic import custom_pydantic_encoder
|
38
38
|
|
39
|
-
D = TypeVar("D")
|
39
|
+
D = TypeVar("D", default=Any)
|
40
40
|
|
41
41
|
|
42
42
|
def prefect_json_object_encoder(obj: Any) -> Any:
|
prefect/settings.py
CHANGED
@@ -42,7 +42,7 @@ dependent on the value of other settings or perform other dynamic effects.
|
|
42
42
|
|
43
43
|
import logging
|
44
44
|
import os
|
45
|
-
import
|
45
|
+
import re
|
46
46
|
import string
|
47
47
|
import warnings
|
48
48
|
from contextlib import contextmanager
|
@@ -85,7 +85,6 @@ from prefect._internal.schemas.validators import validate_settings
|
|
85
85
|
from prefect.exceptions import MissingProfileError
|
86
86
|
from prefect.utilities.names import OBFUSCATED_PREFIX, obfuscate
|
87
87
|
from prefect.utilities.pydantic import add_cloudpickle_reduction
|
88
|
-
from prefect.utilities.slugify import slugify
|
89
88
|
|
90
89
|
T = TypeVar("T")
|
91
90
|
|
@@ -404,18 +403,6 @@ def warn_on_misconfigured_api_url(values):
|
|
404
403
|
return values
|
405
404
|
|
406
405
|
|
407
|
-
def default_result_storage_block_name(
|
408
|
-
settings: Optional["Settings"] = None, value: Optional[str] = None
|
409
|
-
):
|
410
|
-
"""
|
411
|
-
`value_callback` for `PREFECT_DEFAULT_RESULT_STORAGE_BLOCK` that sets the default
|
412
|
-
value to the hostname of the machine.
|
413
|
-
"""
|
414
|
-
if value is None:
|
415
|
-
return f"local-file-system/{slugify(socket.gethostname())}-storage"
|
416
|
-
return value
|
417
|
-
|
418
|
-
|
419
406
|
def default_database_connection_url(settings, value):
|
420
407
|
templater = template_with_settings(PREFECT_HOME, PREFECT_API_DATABASE_PASSWORD)
|
421
408
|
|
@@ -474,10 +461,8 @@ def default_cloud_ui_url(settings, value):
|
|
474
461
|
# Otherwise, infer a value from the API URL
|
475
462
|
ui_url = api_url = PREFECT_CLOUD_API_URL.value_from(settings)
|
476
463
|
|
477
|
-
if
|
478
|
-
ui_url = ui_url.replace(
|
479
|
-
"https://api.prefect.cloud", "https://app.prefect.cloud", 1
|
480
|
-
)
|
464
|
+
if re.match(r"^https://api[\.\w]*.prefect.[^\.]+/", api_url):
|
465
|
+
ui_url = ui_url.replace("https://api", "https://app", 1)
|
481
466
|
|
482
467
|
if ui_url.endswith("/api"):
|
483
468
|
ui_url = ui_url[:-4]
|
@@ -1323,15 +1308,6 @@ PREFECT_API_MAX_FLOW_RUN_GRAPH_ARTIFACTS = Setting(int, default=10000)
|
|
1323
1308
|
The maximum number of artifacts to show on a flow run graph on the v2 API
|
1324
1309
|
"""
|
1325
1310
|
|
1326
|
-
PREFECT_EXPERIMENTAL_ENABLE_WORKERS = Setting(bool, default=True)
|
1327
|
-
"""
|
1328
|
-
Whether or not to enable experimental Prefect workers.
|
1329
|
-
"""
|
1330
|
-
|
1331
|
-
PREFECT_EXPERIMENTAL_WARN_WORKERS = Setting(bool, default=False)
|
1332
|
-
"""
|
1333
|
-
Whether or not to warn when experimental Prefect workers are used.
|
1334
|
-
"""
|
1335
1311
|
|
1336
1312
|
PREFECT_EXPERIMENTAL_ENABLE_ENHANCED_CANCELLATION = Setting(bool, default=True)
|
1337
1313
|
"""
|
@@ -1423,10 +1399,7 @@ PREFECT_API_SERVICES_TASK_SCHEDULING_ENABLED = Setting(bool, default=True)
|
|
1423
1399
|
Whether or not to start the task scheduling service in the server application.
|
1424
1400
|
"""
|
1425
1401
|
|
1426
|
-
PREFECT_TASK_SCHEDULING_DEFAULT_STORAGE_BLOCK = Setting(
|
1427
|
-
str,
|
1428
|
-
default="local-file-system/prefect-task-scheduling",
|
1429
|
-
)
|
1402
|
+
PREFECT_TASK_SCHEDULING_DEFAULT_STORAGE_BLOCK = Setting(Optional[str], default=None)
|
1430
1403
|
"""The `block-type/block-document` slug of a block to use as the default storage
|
1431
1404
|
for autonomous tasks."""
|
1432
1405
|
|
@@ -1479,7 +1452,8 @@ PREFECT_EXPERIMENTAL_ENABLE_SCHEDULE_CONCURRENCY = Setting(bool, default=False)
|
|
1479
1452
|
# Defaults -----------------------------------------------------------------------------
|
1480
1453
|
|
1481
1454
|
PREFECT_DEFAULT_RESULT_STORAGE_BLOCK = Setting(
|
1482
|
-
Optional[str],
|
1455
|
+
Optional[str],
|
1456
|
+
default=None,
|
1483
1457
|
)
|
1484
1458
|
"""The `block-type/block-document` slug of a block to use as the default result storage."""
|
1485
1459
|
|
prefect/task_engine.py
CHANGED
@@ -1,6 +1,8 @@
|
|
1
1
|
import inspect
|
2
2
|
import logging
|
3
|
+
import threading
|
3
4
|
import time
|
5
|
+
from asyncio import CancelledError
|
4
6
|
from contextlib import ExitStack, contextmanager
|
5
7
|
from dataclasses import dataclass, field
|
6
8
|
from textwrap import dedent
|
@@ -17,6 +19,7 @@ from typing import (
|
|
17
19
|
Optional,
|
18
20
|
Sequence,
|
19
21
|
Set,
|
22
|
+
Type,
|
20
23
|
TypeVar,
|
21
24
|
Union,
|
22
25
|
)
|
@@ -36,17 +39,18 @@ from prefect.context import (
|
|
36
39
|
TaskRunContext,
|
37
40
|
hydrated_context,
|
38
41
|
)
|
39
|
-
from prefect.events.schemas.events import Event
|
42
|
+
from prefect.events.schemas.events import Event as PrefectEvent
|
40
43
|
from prefect.exceptions import (
|
41
44
|
Abort,
|
42
45
|
Pause,
|
43
46
|
PrefectException,
|
47
|
+
TerminationSignal,
|
44
48
|
UpstreamTaskError,
|
45
49
|
)
|
46
50
|
from prefect.futures import PrefectFuture
|
47
51
|
from prefect.logging.loggers import get_logger, patch_print, task_run_logger
|
48
52
|
from prefect.records.result_store import ResultFactoryStore
|
49
|
-
from prefect.results import ResultFactory, _format_user_supplied_storage_key
|
53
|
+
from prefect.results import BaseResult, ResultFactory, _format_user_supplied_storage_key
|
50
54
|
from prefect.settings import (
|
51
55
|
PREFECT_DEBUG_MODE,
|
52
56
|
PREFECT_TASKS_REFRESH_CACHE,
|
@@ -63,6 +67,7 @@ from prefect.states import (
|
|
63
67
|
return_value_to_state,
|
64
68
|
)
|
65
69
|
from prefect.transactions import Transaction, transaction
|
70
|
+
from prefect.utilities.annotations import NotSet
|
66
71
|
from prefect.utilities.asyncutils import run_coro_as_sync
|
67
72
|
from prefect.utilities.callables import call_with_parameters, parameters_to_args_kwargs
|
68
73
|
from prefect.utilities.collections import visit_collection
|
@@ -80,6 +85,10 @@ P = ParamSpec("P")
|
|
80
85
|
R = TypeVar("R")
|
81
86
|
|
82
87
|
|
88
|
+
class TaskRunTimeoutError(TimeoutError):
|
89
|
+
"""Raised when a task run exceeds its timeout."""
|
90
|
+
|
91
|
+
|
83
92
|
@dataclass
|
84
93
|
class TaskRunEngine(Generic[P, R]):
|
85
94
|
task: Union[Task[P, R], Task[P, Coroutine[Any, Any, R]]]
|
@@ -89,11 +98,15 @@ class TaskRunEngine(Generic[P, R]):
|
|
89
98
|
retries: int = 0
|
90
99
|
wait_for: Optional[Iterable[PrefectFuture]] = None
|
91
100
|
context: Optional[Dict[str, Any]] = None
|
101
|
+
# holds the return value from the user code
|
102
|
+
_return_value: Union[R, Type[NotSet]] = NotSet
|
103
|
+
# holds the exception raised by the user code, if any
|
104
|
+
_raised: Union[Exception, Type[NotSet]] = NotSet
|
92
105
|
_initial_run_context: Optional[TaskRunContext] = None
|
93
106
|
_is_started: bool = False
|
94
107
|
_client: Optional[SyncPrefectClient] = None
|
95
108
|
_task_name_set: bool = False
|
96
|
-
_last_event: Optional[
|
109
|
+
_last_event: Optional[PrefectEvent] = None
|
97
110
|
|
98
111
|
def __post_init__(self):
|
99
112
|
if self.parameters is None:
|
@@ -136,7 +149,16 @@ class TaskRunEngine(Generic[P, R]):
|
|
136
149
|
)
|
137
150
|
return False
|
138
151
|
|
139
|
-
def
|
152
|
+
def is_cancelled(self) -> bool:
|
153
|
+
if (
|
154
|
+
self.context
|
155
|
+
and "cancel_event" in self.context
|
156
|
+
and isinstance(self.context["cancel_event"], threading.Event)
|
157
|
+
):
|
158
|
+
return self.context["cancel_event"].is_set()
|
159
|
+
return False
|
160
|
+
|
161
|
+
def call_hooks(self, state: Optional[State] = None):
|
140
162
|
if state is None:
|
141
163
|
state = self.state
|
142
164
|
task = self.task
|
@@ -171,7 +193,7 @@ class TaskRunEngine(Generic[P, R]):
|
|
171
193
|
else:
|
172
194
|
self.logger.info(f"Hook {hook_name!r} finished running successfully")
|
173
195
|
|
174
|
-
def compute_transaction_key(self) -> str:
|
196
|
+
def compute_transaction_key(self) -> Optional[str]:
|
175
197
|
key = None
|
176
198
|
if self.task.cache_policy:
|
177
199
|
flow_run_context = FlowRunContext.get()
|
@@ -304,12 +326,24 @@ class TaskRunEngine(Generic[P, R]):
|
|
304
326
|
return new_state
|
305
327
|
|
306
328
|
def result(self, raise_on_failure: bool = True) -> "Union[R, State, None]":
|
307
|
-
|
308
|
-
|
309
|
-
|
310
|
-
|
311
|
-
|
312
|
-
|
329
|
+
if self._return_value is not NotSet:
|
330
|
+
# if the return value is a BaseResult, we need to fetch it
|
331
|
+
if isinstance(self._return_value, BaseResult):
|
332
|
+
_result = self._return_value.get()
|
333
|
+
if inspect.isawaitable(_result):
|
334
|
+
_result = run_coro_as_sync(_result)
|
335
|
+
return _result
|
336
|
+
|
337
|
+
# otherwise, return the value as is
|
338
|
+
return self._return_value
|
339
|
+
|
340
|
+
if self._raised is not NotSet:
|
341
|
+
# if the task raised an exception, raise it
|
342
|
+
if raise_on_failure:
|
343
|
+
raise self._raised
|
344
|
+
|
345
|
+
# otherwise, return the exception
|
346
|
+
return self._raised
|
313
347
|
|
314
348
|
def handle_success(self, result: R, transaction: Transaction) -> R:
|
315
349
|
result_factory = getattr(TaskRunContext.get(), "result_factory", None)
|
@@ -339,6 +373,7 @@ class TaskRunEngine(Generic[P, R]):
|
|
339
373
|
if transaction.is_committed():
|
340
374
|
terminal_state.name = "Cached"
|
341
375
|
self.set_state(terminal_state)
|
376
|
+
self._return_value = result
|
342
377
|
return result
|
343
378
|
|
344
379
|
def handle_retry(self, exc: Exception) -> bool:
|
@@ -365,9 +400,11 @@ class TaskRunEngine(Generic[P, R]):
|
|
365
400
|
new_state = Retrying()
|
366
401
|
|
367
402
|
self.logger.info(
|
368
|
-
|
369
|
-
|
370
|
-
|
403
|
+
"Task run failed with exception: %r - " "Retry %s/%s will start %s",
|
404
|
+
exc,
|
405
|
+
self.retries + 1,
|
406
|
+
self.task.retries,
|
407
|
+
str(delay) + " second(s) from now" if delay else "immediately",
|
371
408
|
)
|
372
409
|
|
373
410
|
self.set_state(new_state, force=True)
|
@@ -375,7 +412,9 @@ class TaskRunEngine(Generic[P, R]):
|
|
375
412
|
return True
|
376
413
|
elif self.retries >= self.task.retries:
|
377
414
|
self.logger.error(
|
378
|
-
|
415
|
+
"Task run failed with exception: %r - Retries are exhausted",
|
416
|
+
exc,
|
417
|
+
exc_info=True,
|
379
418
|
)
|
380
419
|
return False
|
381
420
|
|
@@ -394,12 +433,14 @@ class TaskRunEngine(Generic[P, R]):
|
|
394
433
|
)
|
395
434
|
)
|
396
435
|
self.set_state(state)
|
436
|
+
self._raised = exc
|
397
437
|
|
398
438
|
def handle_timeout(self, exc: TimeoutError) -> None:
|
399
439
|
if not self.handle_retry(exc):
|
400
|
-
|
401
|
-
f"Task run exceeded timeout of {self.task.timeout_seconds}
|
402
|
-
|
440
|
+
if isinstance(exc, TaskRunTimeoutError):
|
441
|
+
message = f"Task run exceeded timeout of {self.task.timeout_seconds} second(s)"
|
442
|
+
else:
|
443
|
+
message = f"Task run failed due to timeout: {exc!r}"
|
403
444
|
self.logger.error(message)
|
404
445
|
state = Failed(
|
405
446
|
data=exc,
|
@@ -407,12 +448,14 @@ class TaskRunEngine(Generic[P, R]):
|
|
407
448
|
name="TimedOut",
|
408
449
|
)
|
409
450
|
self.set_state(state)
|
451
|
+
self._raised = exc
|
410
452
|
|
411
453
|
def handle_crash(self, exc: BaseException) -> None:
|
412
454
|
state = run_coro_as_sync(exception_to_crashed_state(exc))
|
413
455
|
self.logger.error(f"Crash detected! {state.message}")
|
414
456
|
self.logger.debug("Crash details:", exc_info=exc)
|
415
457
|
self.set_state(state, force=True)
|
458
|
+
self._raised = exc
|
416
459
|
|
417
460
|
@contextmanager
|
418
461
|
def setup_run_context(self, client: Optional[SyncPrefectClient] = None):
|
@@ -498,6 +541,11 @@ class TaskRunEngine(Generic[P, R]):
|
|
498
541
|
)
|
499
542
|
yield self
|
500
543
|
|
544
|
+
except TerminationSignal as exc:
|
545
|
+
# TerminationSignals are caught and handled as crashes
|
546
|
+
self.handle_crash(exc)
|
547
|
+
raise exc
|
548
|
+
|
501
549
|
except Exception:
|
502
550
|
# regular exceptions are caught and re-raised to the user
|
503
551
|
raise
|
@@ -539,8 +587,8 @@ class TaskRunEngine(Generic[P, R]):
|
|
539
587
|
|
540
588
|
@flow
|
541
589
|
def example_flow():
|
542
|
-
say_hello.submit(name="Marvin)
|
543
|
-
|
590
|
+
future = say_hello.submit(name="Marvin)
|
591
|
+
future.wait()
|
544
592
|
|
545
593
|
example_flow()
|
546
594
|
"""
|
@@ -612,10 +660,16 @@ class TaskRunEngine(Generic[P, R]):
|
|
612
660
|
# reenter the run context to ensure it is up to date for every run
|
613
661
|
with self.setup_run_context():
|
614
662
|
try:
|
615
|
-
with timeout_context(
|
663
|
+
with timeout_context(
|
664
|
+
seconds=self.task.timeout_seconds,
|
665
|
+
timeout_exc_type=TaskRunTimeoutError,
|
666
|
+
):
|
616
667
|
self.logger.debug(
|
617
668
|
f"Executing task {self.task.name!r} for task run {self.task_run.name!r}..."
|
618
669
|
)
|
670
|
+
if self.is_cancelled():
|
671
|
+
raise CancelledError("Task run cancelled by the task runner")
|
672
|
+
|
619
673
|
yield self
|
620
674
|
except TimeoutError as exc:
|
621
675
|
self.handle_timeout(exc)
|
@@ -638,6 +692,7 @@ class TaskRunEngine(Generic[P, R]):
|
|
638
692
|
else:
|
639
693
|
result = await call_with_parameters(self.task.fn, parameters)
|
640
694
|
self.handle_success(result, transaction=transaction)
|
695
|
+
return result
|
641
696
|
|
642
697
|
return _call_task_fn()
|
643
698
|
else:
|
@@ -646,6 +701,7 @@ class TaskRunEngine(Generic[P, R]):
|
|
646
701
|
else:
|
647
702
|
result = call_with_parameters(self.task.fn, parameters)
|
648
703
|
self.handle_success(result, transaction=transaction)
|
704
|
+
return result
|
649
705
|
|
650
706
|
|
651
707
|
def run_task_sync(
|