prefect-client 3.0.0rc9__py3-none-any.whl → 3.0.0rc10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
prefect/results.py CHANGED
@@ -28,7 +28,6 @@ from prefect.client.utilities import inject_client
28
28
  from prefect.exceptions import MissingResult, ObjectAlreadyExists
29
29
  from prefect.filesystems import (
30
30
  LocalFileSystem,
31
- ReadableFileSystem,
32
31
  WritableFileSystem,
33
32
  )
34
33
  from prefect.logging import get_logger
@@ -111,22 +110,32 @@ async def _get_or_create_default_storage(block_document_slug: str) -> ResultStor
111
110
 
112
111
 
113
112
  @sync_compatible
114
- async def get_or_create_default_result_storage() -> ResultStorage:
113
+ async def get_default_result_storage() -> ResultStorage:
115
114
  """
116
115
  Generate a default file system for result storage.
117
116
  """
118
- return await _get_or_create_default_storage(
119
- PREFECT_DEFAULT_RESULT_STORAGE_BLOCK.value()
120
- )
117
+ default_block = PREFECT_DEFAULT_RESULT_STORAGE_BLOCK.value()
118
+
119
+ if default_block is not None:
120
+ return await Block.load(default_block)
121
+
122
+ # otherwise, use the local file system
123
+ basepath = PREFECT_LOCAL_STORAGE_PATH.value()
124
+ return LocalFileSystem(basepath=basepath)
121
125
 
122
126
 
123
127
  async def get_or_create_default_task_scheduling_storage() -> ResultStorage:
124
128
  """
125
129
  Generate a default file system for background task parameter/result storage.
126
130
  """
127
- return await _get_or_create_default_storage(
128
- PREFECT_TASK_SCHEDULING_DEFAULT_STORAGE_BLOCK.value()
129
- )
131
+ default_block = PREFECT_TASK_SCHEDULING_DEFAULT_STORAGE_BLOCK.value()
132
+
133
+ if default_block is not None:
134
+ return await Block.load(default_block)
135
+
136
+ # otherwise, use the local file system
137
+ basepath = PREFECT_LOCAL_STORAGE_PATH.value()
138
+ return LocalFileSystem(basepath=basepath)
130
139
 
131
140
 
132
141
  def get_default_result_serializer() -> ResultSerializer:
@@ -177,9 +186,7 @@ class ResultFactory(BaseModel):
177
186
  kwargs.pop(key)
178
187
 
179
188
  # Apply defaults
180
- kwargs.setdefault(
181
- "result_storage", await get_or_create_default_result_storage()
182
- )
189
+ kwargs.setdefault("result_storage", await get_default_result_storage())
183
190
  kwargs.setdefault("result_serializer", get_default_result_serializer())
184
191
  kwargs.setdefault("persist_result", get_default_persist_setting())
185
192
  kwargs.setdefault("cache_result_in_memory", True)
@@ -230,9 +237,7 @@ class ResultFactory(BaseModel):
230
237
  """
231
238
  Create a new result factory for a task.
232
239
  """
233
- return await cls._from_task(
234
- task, get_or_create_default_result_storage, client=client
235
- )
240
+ return await cls._from_task(task, get_default_result_storage, client=client)
236
241
 
237
242
  @classmethod
238
243
  @inject_client
@@ -268,7 +273,14 @@ class ResultFactory(BaseModel):
268
273
  if ctx and ctx.result_factory
269
274
  else get_default_result_serializer()
270
275
  )
271
- persist_result = task.persist_result
276
+ if task.persist_result is None:
277
+ persist_result = (
278
+ ctx.result_factory.persist_result
279
+ if ctx and ctx.result_factory
280
+ else get_default_persist_setting()
281
+ )
282
+ else:
283
+ persist_result = task.persist_result
272
284
 
273
285
  cache_result_in_memory = task.cache_result_in_memory
274
286
 
@@ -330,16 +342,7 @@ class ResultFactory(BaseModel):
330
342
  # Avoid saving the block if it already has an identifier assigned
331
343
  storage_block_id = storage_block._block_document_id
332
344
  else:
333
- if persist_result:
334
- # TODO: Overwrite is true to avoid issues where the save collides with
335
- # a previously saved document with a matching hash
336
- storage_block_id = await storage_block._save(
337
- is_anonymous=True, overwrite=True, client=client
338
- )
339
- else:
340
- # a None-type UUID on unpersisted storage should not matter
341
- # since the ID is generated on the server
342
- storage_block_id = None
345
+ storage_block_id = None
343
346
  elif isinstance(result_storage, str):
344
347
  storage_block = await Block.load(result_storage, client=client)
345
348
  storage_block_id = storage_block._block_document_id
@@ -412,9 +415,6 @@ class ResultFactory(BaseModel):
412
415
 
413
416
  @sync_compatible
414
417
  async def store_parameters(self, identifier: UUID, parameters: Dict[str, Any]):
415
- assert (
416
- self.storage_block_id is not None
417
- ), "Unexpected storage block ID. Was it persisted?"
418
418
  data = self.serializer.dumps(parameters)
419
419
  blob = PersistedResultBlob(serializer=self.serializer, data=data)
420
420
  await self.storage_block.write_path(
@@ -423,9 +423,6 @@ class ResultFactory(BaseModel):
423
423
 
424
424
  @sync_compatible
425
425
  async def read_parameters(self, identifier: UUID) -> Dict[str, Any]:
426
- assert (
427
- self.storage_block_id is not None
428
- ), "Unexpected storage block ID. Was it persisted?"
429
426
  blob = PersistedResultBlob.model_validate_json(
430
427
  await self.storage_block.read_path(f"parameters/{identifier}")
431
428
  )
@@ -435,10 +432,7 @@ class ResultFactory(BaseModel):
435
432
  @register_base_type
436
433
  class BaseResult(BaseModel, abc.ABC, Generic[R]):
437
434
  model_config = ConfigDict(extra="forbid")
438
-
439
435
  type: str
440
- artifact_type: Optional[str] = None
441
- artifact_description: Optional[str] = None
442
436
 
443
437
  def __init__(self, **data: Any) -> None:
444
438
  type_string = get_dispatch_key(self) if type(self) != BaseResult else "__base__"
@@ -504,11 +498,7 @@ class UnpersistedResult(BaseResult):
504
498
  obj: R,
505
499
  cache_object: bool = True,
506
500
  ) -> "UnpersistedResult[R]":
507
- description = f"Unpersisted result of type `{type(obj).__name__}`"
508
- result = cls(
509
- artifact_type="result",
510
- artifact_description=description,
511
- )
501
+ result = cls()
512
502
  # Only store the object in local memory, it will not be sent to the API
513
503
  if cache_object:
514
504
  result._cache_object(obj)
@@ -528,8 +518,8 @@ class PersistedResult(BaseResult):
528
518
  type: str = "reference"
529
519
 
530
520
  serializer_type: str
531
- storage_block_id: uuid.UUID
532
521
  storage_key: str
522
+ storage_block_id: Optional[uuid.UUID] = None
533
523
  expiration: Optional[DateTime] = None
534
524
 
535
525
  _should_cache_object: bool = PrivateAttr(default=True)
@@ -547,6 +537,17 @@ class PersistedResult(BaseResult):
547
537
  self._storage_block = storage_block
548
538
  self._serializer = serializer
549
539
 
540
+ @inject_client
541
+ async def _get_storage_block(self, client: "PrefectClient") -> WritableFileSystem:
542
+ if self._storage_block is not None:
543
+ return self._storage_block
544
+ elif self.storage_block_id is not None:
545
+ block_document = await client.read_block_document(self.storage_block_id)
546
+ self._storage_block = Block._from_block_document(block_document)
547
+ else:
548
+ self._storage_block = await get_default_result_storage()
549
+ return self._storage_block
550
+
550
551
  @sync_compatible
551
552
  @inject_client
552
553
  async def get(self, client: "PrefectClient") -> R:
@@ -567,12 +568,8 @@ class PersistedResult(BaseResult):
567
568
 
568
569
  @inject_client
569
570
  async def _read_blob(self, client: "PrefectClient") -> "PersistedResultBlob":
570
- assert (
571
- self.storage_block_id is not None
572
- ), "Unexpected storage block ID. Was it persisted?"
573
- block_document = await client.read_block_document(self.storage_block_id)
574
- storage_block: ReadableFileSystem = Block._from_block_document(block_document)
575
- content = await storage_block.read_path(self.storage_key)
571
+ block = await self._get_storage_block(client=client)
572
+ content = await block.read_path(self.storage_key)
576
573
  blob = PersistedResultBlob.model_validate_json(content)
577
574
  return blob
578
575
 
@@ -607,10 +604,7 @@ class PersistedResult(BaseResult):
607
604
  obj = obj if obj is not NotSet else self._cache
608
605
 
609
606
  # next, the storage block
610
- storage_block = self._storage_block
611
- if storage_block is None:
612
- block_document = await client.read_block_document(self.storage_block_id)
613
- storage_block = Block._from_block_document(block_document)
607
+ storage_block = await self._get_storage_block(client=client)
614
608
 
615
609
  # finally, the serializer
616
610
  serializer = self._serializer
@@ -673,9 +667,9 @@ class PersistedResult(BaseResult):
673
667
  cls: "Type[PersistedResult]",
674
668
  obj: R,
675
669
  storage_block: WritableFileSystem,
676
- storage_block_id: uuid.UUID,
677
670
  storage_key_fn: Callable[[], str],
678
671
  serializer: Serializer,
672
+ storage_block_id: Optional[uuid.UUID] = None,
679
673
  cache_object: bool = True,
680
674
  expiration: Optional[DateTime] = None,
681
675
  defer_persistence: bool = False,
@@ -686,31 +680,21 @@ class PersistedResult(BaseResult):
686
680
  The object will be serialized and written to the storage block under a unique
687
681
  key. It will then be cached on the returned result.
688
682
  """
689
- assert (
690
- storage_block_id is not None
691
- ), "Unexpected storage block ID. Was it saved?"
692
-
693
683
  key = storage_key_fn()
694
684
  if not isinstance(key, str):
695
685
  raise TypeError(
696
686
  f"Expected type 'str' for result storage key; got value {key!r}"
697
687
  )
698
- description = f"Result of type `{type(obj).__name__}`"
699
688
  uri = cls._infer_path(storage_block, key)
700
- if uri:
701
- if isinstance(storage_block, LocalFileSystem):
702
- description += f" persisted to: `{uri}`"
703
- else:
704
- description += f" persisted to [{uri}]({uri})."
705
- else:
706
- description += f" persisted with storage block `{storage_block_id}`."
689
+
690
+ # in this case we store an absolute path
691
+ if storage_block_id is None and uri is not None:
692
+ key = str(uri)
707
693
 
708
694
  result = cls(
709
695
  serializer_type=serializer.type,
710
696
  storage_block_id=storage_block_id,
711
697
  storage_key=key,
712
- artifact_type="result",
713
- artifact_description=description,
714
698
  expiration=expiration,
715
699
  )
716
700
 
@@ -787,5 +771,4 @@ class UnknownResult(BaseResult):
787
771
  "Only None is supported."
788
772
  )
789
773
 
790
- description = "Unknown result persisted to Prefect."
791
- return cls(value=obj, artifact_type="result", artifact_description=description)
774
+ return cls(value=obj)
prefect/serializers.py CHANGED
@@ -13,7 +13,7 @@ bytes to an object respectively.
13
13
 
14
14
  import abc
15
15
  import base64
16
- from typing import Any, Dict, Generic, Optional, Type, TypeVar
16
+ from typing import Any, Dict, Generic, Optional, Type
17
17
 
18
18
  from pydantic import (
19
19
  BaseModel,
@@ -23,7 +23,7 @@ from pydantic import (
23
23
  ValidationError,
24
24
  field_validator,
25
25
  )
26
- from typing_extensions import Literal, Self
26
+ from typing_extensions import Literal, Self, TypeVar
27
27
 
28
28
  from prefect._internal.schemas.validators import (
29
29
  cast_type_names_to_serializers,
@@ -36,7 +36,7 @@ from prefect.utilities.dispatch import get_dispatch_key, lookup_type, register_b
36
36
  from prefect.utilities.importtools import from_qualified_name, to_qualified_name
37
37
  from prefect.utilities.pydantic import custom_pydantic_encoder
38
38
 
39
- D = TypeVar("D")
39
+ D = TypeVar("D", default=Any)
40
40
 
41
41
 
42
42
  def prefect_json_object_encoder(obj: Any) -> Any:
prefect/settings.py CHANGED
@@ -42,7 +42,7 @@ dependent on the value of other settings or perform other dynamic effects.
42
42
 
43
43
  import logging
44
44
  import os
45
- import socket
45
+ import re
46
46
  import string
47
47
  import warnings
48
48
  from contextlib import contextmanager
@@ -85,7 +85,6 @@ from prefect._internal.schemas.validators import validate_settings
85
85
  from prefect.exceptions import MissingProfileError
86
86
  from prefect.utilities.names import OBFUSCATED_PREFIX, obfuscate
87
87
  from prefect.utilities.pydantic import add_cloudpickle_reduction
88
- from prefect.utilities.slugify import slugify
89
88
 
90
89
  T = TypeVar("T")
91
90
 
@@ -404,18 +403,6 @@ def warn_on_misconfigured_api_url(values):
404
403
  return values
405
404
 
406
405
 
407
- def default_result_storage_block_name(
408
- settings: Optional["Settings"] = None, value: Optional[str] = None
409
- ):
410
- """
411
- `value_callback` for `PREFECT_DEFAULT_RESULT_STORAGE_BLOCK` that sets the default
412
- value to the hostname of the machine.
413
- """
414
- if value is None:
415
- return f"local-file-system/{slugify(socket.gethostname())}-storage"
416
- return value
417
-
418
-
419
406
  def default_database_connection_url(settings, value):
420
407
  templater = template_with_settings(PREFECT_HOME, PREFECT_API_DATABASE_PASSWORD)
421
408
 
@@ -474,10 +461,8 @@ def default_cloud_ui_url(settings, value):
474
461
  # Otherwise, infer a value from the API URL
475
462
  ui_url = api_url = PREFECT_CLOUD_API_URL.value_from(settings)
476
463
 
477
- if api_url.startswith("https://api.prefect.cloud"):
478
- ui_url = ui_url.replace(
479
- "https://api.prefect.cloud", "https://app.prefect.cloud", 1
480
- )
464
+ if re.match(r"^https://api[\.\w]*.prefect.[^\.]+/", api_url):
465
+ ui_url = ui_url.replace("https://api", "https://app", 1)
481
466
 
482
467
  if ui_url.endswith("/api"):
483
468
  ui_url = ui_url[:-4]
@@ -1323,15 +1308,6 @@ PREFECT_API_MAX_FLOW_RUN_GRAPH_ARTIFACTS = Setting(int, default=10000)
1323
1308
  The maximum number of artifacts to show on a flow run graph on the v2 API
1324
1309
  """
1325
1310
 
1326
- PREFECT_EXPERIMENTAL_ENABLE_WORKERS = Setting(bool, default=True)
1327
- """
1328
- Whether or not to enable experimental Prefect workers.
1329
- """
1330
-
1331
- PREFECT_EXPERIMENTAL_WARN_WORKERS = Setting(bool, default=False)
1332
- """
1333
- Whether or not to warn when experimental Prefect workers are used.
1334
- """
1335
1311
 
1336
1312
  PREFECT_EXPERIMENTAL_ENABLE_ENHANCED_CANCELLATION = Setting(bool, default=True)
1337
1313
  """
@@ -1423,10 +1399,7 @@ PREFECT_API_SERVICES_TASK_SCHEDULING_ENABLED = Setting(bool, default=True)
1423
1399
  Whether or not to start the task scheduling service in the server application.
1424
1400
  """
1425
1401
 
1426
- PREFECT_TASK_SCHEDULING_DEFAULT_STORAGE_BLOCK = Setting(
1427
- str,
1428
- default="local-file-system/prefect-task-scheduling",
1429
- )
1402
+ PREFECT_TASK_SCHEDULING_DEFAULT_STORAGE_BLOCK = Setting(Optional[str], default=None)
1430
1403
  """The `block-type/block-document` slug of a block to use as the default storage
1431
1404
  for autonomous tasks."""
1432
1405
 
@@ -1479,7 +1452,8 @@ PREFECT_EXPERIMENTAL_ENABLE_SCHEDULE_CONCURRENCY = Setting(bool, default=False)
1479
1452
  # Defaults -----------------------------------------------------------------------------
1480
1453
 
1481
1454
  PREFECT_DEFAULT_RESULT_STORAGE_BLOCK = Setting(
1482
- Optional[str], default=None, value_callback=default_result_storage_block_name
1455
+ Optional[str],
1456
+ default=None,
1483
1457
  )
1484
1458
  """The `block-type/block-document` slug of a block to use as the default result storage."""
1485
1459
 
prefect/task_engine.py CHANGED
@@ -1,6 +1,8 @@
1
1
  import inspect
2
2
  import logging
3
+ import threading
3
4
  import time
5
+ from asyncio import CancelledError
4
6
  from contextlib import ExitStack, contextmanager
5
7
  from dataclasses import dataclass, field
6
8
  from textwrap import dedent
@@ -17,6 +19,7 @@ from typing import (
17
19
  Optional,
18
20
  Sequence,
19
21
  Set,
22
+ Type,
20
23
  TypeVar,
21
24
  Union,
22
25
  )
@@ -36,17 +39,18 @@ from prefect.context import (
36
39
  TaskRunContext,
37
40
  hydrated_context,
38
41
  )
39
- from prefect.events.schemas.events import Event
42
+ from prefect.events.schemas.events import Event as PrefectEvent
40
43
  from prefect.exceptions import (
41
44
  Abort,
42
45
  Pause,
43
46
  PrefectException,
47
+ TerminationSignal,
44
48
  UpstreamTaskError,
45
49
  )
46
50
  from prefect.futures import PrefectFuture
47
51
  from prefect.logging.loggers import get_logger, patch_print, task_run_logger
48
52
  from prefect.records.result_store import ResultFactoryStore
49
- from prefect.results import ResultFactory, _format_user_supplied_storage_key
53
+ from prefect.results import BaseResult, ResultFactory, _format_user_supplied_storage_key
50
54
  from prefect.settings import (
51
55
  PREFECT_DEBUG_MODE,
52
56
  PREFECT_TASKS_REFRESH_CACHE,
@@ -63,6 +67,7 @@ from prefect.states import (
63
67
  return_value_to_state,
64
68
  )
65
69
  from prefect.transactions import Transaction, transaction
70
+ from prefect.utilities.annotations import NotSet
66
71
  from prefect.utilities.asyncutils import run_coro_as_sync
67
72
  from prefect.utilities.callables import call_with_parameters, parameters_to_args_kwargs
68
73
  from prefect.utilities.collections import visit_collection
@@ -80,6 +85,10 @@ P = ParamSpec("P")
80
85
  R = TypeVar("R")
81
86
 
82
87
 
88
+ class TaskRunTimeoutError(TimeoutError):
89
+ """Raised when a task run exceeds its timeout."""
90
+
91
+
83
92
  @dataclass
84
93
  class TaskRunEngine(Generic[P, R]):
85
94
  task: Union[Task[P, R], Task[P, Coroutine[Any, Any, R]]]
@@ -89,11 +98,15 @@ class TaskRunEngine(Generic[P, R]):
89
98
  retries: int = 0
90
99
  wait_for: Optional[Iterable[PrefectFuture]] = None
91
100
  context: Optional[Dict[str, Any]] = None
101
+ # holds the return value from the user code
102
+ _return_value: Union[R, Type[NotSet]] = NotSet
103
+ # holds the exception raised by the user code, if any
104
+ _raised: Union[Exception, Type[NotSet]] = NotSet
92
105
  _initial_run_context: Optional[TaskRunContext] = None
93
106
  _is_started: bool = False
94
107
  _client: Optional[SyncPrefectClient] = None
95
108
  _task_name_set: bool = False
96
- _last_event: Optional[Event] = None
109
+ _last_event: Optional[PrefectEvent] = None
97
110
 
98
111
  def __post_init__(self):
99
112
  if self.parameters is None:
@@ -136,7 +149,16 @@ class TaskRunEngine(Generic[P, R]):
136
149
  )
137
150
  return False
138
151
 
139
- def call_hooks(self, state: State = None) -> Iterable[Callable]:
152
+ def is_cancelled(self) -> bool:
153
+ if (
154
+ self.context
155
+ and "cancel_event" in self.context
156
+ and isinstance(self.context["cancel_event"], threading.Event)
157
+ ):
158
+ return self.context["cancel_event"].is_set()
159
+ return False
160
+
161
+ def call_hooks(self, state: Optional[State] = None):
140
162
  if state is None:
141
163
  state = self.state
142
164
  task = self.task
@@ -171,7 +193,7 @@ class TaskRunEngine(Generic[P, R]):
171
193
  else:
172
194
  self.logger.info(f"Hook {hook_name!r} finished running successfully")
173
195
 
174
- def compute_transaction_key(self) -> str:
196
+ def compute_transaction_key(self) -> Optional[str]:
175
197
  key = None
176
198
  if self.task.cache_policy:
177
199
  flow_run_context = FlowRunContext.get()
@@ -304,12 +326,24 @@ class TaskRunEngine(Generic[P, R]):
304
326
  return new_state
305
327
 
306
328
  def result(self, raise_on_failure: bool = True) -> "Union[R, State, None]":
307
- _result = self.state.result(raise_on_failure=raise_on_failure, fetch=True)
308
- # state.result is a `sync_compatible` function that may or may not return an awaitable
309
- # depending on whether the parent frame is sync or not
310
- if inspect.isawaitable(_result):
311
- _result = run_coro_as_sync(_result)
312
- return _result
329
+ if self._return_value is not NotSet:
330
+ # if the return value is a BaseResult, we need to fetch it
331
+ if isinstance(self._return_value, BaseResult):
332
+ _result = self._return_value.get()
333
+ if inspect.isawaitable(_result):
334
+ _result = run_coro_as_sync(_result)
335
+ return _result
336
+
337
+ # otherwise, return the value as is
338
+ return self._return_value
339
+
340
+ if self._raised is not NotSet:
341
+ # if the task raised an exception, raise it
342
+ if raise_on_failure:
343
+ raise self._raised
344
+
345
+ # otherwise, return the exception
346
+ return self._raised
313
347
 
314
348
  def handle_success(self, result: R, transaction: Transaction) -> R:
315
349
  result_factory = getattr(TaskRunContext.get(), "result_factory", None)
@@ -339,6 +373,7 @@ class TaskRunEngine(Generic[P, R]):
339
373
  if transaction.is_committed():
340
374
  terminal_state.name = "Cached"
341
375
  self.set_state(terminal_state)
376
+ self._return_value = result
342
377
  return result
343
378
 
344
379
  def handle_retry(self, exc: Exception) -> bool:
@@ -365,9 +400,11 @@ class TaskRunEngine(Generic[P, R]):
365
400
  new_state = Retrying()
366
401
 
367
402
  self.logger.info(
368
- f"Task run failed with exception {exc!r} - "
369
- f"Retry {self.retries + 1}/{self.task.retries} will start "
370
- f"{str(delay) + ' second(s) from now' if delay else 'immediately'}"
403
+ "Task run failed with exception: %r - " "Retry %s/%s will start %s",
404
+ exc,
405
+ self.retries + 1,
406
+ self.task.retries,
407
+ str(delay) + " second(s) from now" if delay else "immediately",
371
408
  )
372
409
 
373
410
  self.set_state(new_state, force=True)
@@ -375,7 +412,9 @@ class TaskRunEngine(Generic[P, R]):
375
412
  return True
376
413
  elif self.retries >= self.task.retries:
377
414
  self.logger.error(
378
- f"Task run failed with exception {exc!r} - Retries are exhausted"
415
+ "Task run failed with exception: %r - Retries are exhausted",
416
+ exc,
417
+ exc_info=True,
379
418
  )
380
419
  return False
381
420
 
@@ -394,12 +433,14 @@ class TaskRunEngine(Generic[P, R]):
394
433
  )
395
434
  )
396
435
  self.set_state(state)
436
+ self._raised = exc
397
437
 
398
438
  def handle_timeout(self, exc: TimeoutError) -> None:
399
439
  if not self.handle_retry(exc):
400
- message = (
401
- f"Task run exceeded timeout of {self.task.timeout_seconds} seconds"
402
- )
440
+ if isinstance(exc, TaskRunTimeoutError):
441
+ message = f"Task run exceeded timeout of {self.task.timeout_seconds} second(s)"
442
+ else:
443
+ message = f"Task run failed due to timeout: {exc!r}"
403
444
  self.logger.error(message)
404
445
  state = Failed(
405
446
  data=exc,
@@ -407,12 +448,14 @@ class TaskRunEngine(Generic[P, R]):
407
448
  name="TimedOut",
408
449
  )
409
450
  self.set_state(state)
451
+ self._raised = exc
410
452
 
411
453
  def handle_crash(self, exc: BaseException) -> None:
412
454
  state = run_coro_as_sync(exception_to_crashed_state(exc))
413
455
  self.logger.error(f"Crash detected! {state.message}")
414
456
  self.logger.debug("Crash details:", exc_info=exc)
415
457
  self.set_state(state, force=True)
458
+ self._raised = exc
416
459
 
417
460
  @contextmanager
418
461
  def setup_run_context(self, client: Optional[SyncPrefectClient] = None):
@@ -498,6 +541,11 @@ class TaskRunEngine(Generic[P, R]):
498
541
  )
499
542
  yield self
500
543
 
544
+ except TerminationSignal as exc:
545
+ # TerminationSignals are caught and handled as crashes
546
+ self.handle_crash(exc)
547
+ raise exc
548
+
501
549
  except Exception:
502
550
  # regular exceptions are caught and re-raised to the user
503
551
  raise
@@ -539,8 +587,8 @@ class TaskRunEngine(Generic[P, R]):
539
587
 
540
588
  @flow
541
589
  def example_flow():
542
- say_hello.submit(name="Marvin)
543
- say_hello.wait()
590
+ future = say_hello.submit(name="Marvin)
591
+ future.wait()
544
592
 
545
593
  example_flow()
546
594
  """
@@ -612,10 +660,16 @@ class TaskRunEngine(Generic[P, R]):
612
660
  # reenter the run context to ensure it is up to date for every run
613
661
  with self.setup_run_context():
614
662
  try:
615
- with timeout_context(seconds=self.task.timeout_seconds):
663
+ with timeout_context(
664
+ seconds=self.task.timeout_seconds,
665
+ timeout_exc_type=TaskRunTimeoutError,
666
+ ):
616
667
  self.logger.debug(
617
668
  f"Executing task {self.task.name!r} for task run {self.task_run.name!r}..."
618
669
  )
670
+ if self.is_cancelled():
671
+ raise CancelledError("Task run cancelled by the task runner")
672
+
619
673
  yield self
620
674
  except TimeoutError as exc:
621
675
  self.handle_timeout(exc)
@@ -638,6 +692,7 @@ class TaskRunEngine(Generic[P, R]):
638
692
  else:
639
693
  result = await call_with_parameters(self.task.fn, parameters)
640
694
  self.handle_success(result, transaction=transaction)
695
+ return result
641
696
 
642
697
  return _call_task_fn()
643
698
  else:
@@ -646,6 +701,7 @@ class TaskRunEngine(Generic[P, R]):
646
701
  else:
647
702
  result = call_with_parameters(self.task.fn, parameters)
648
703
  self.handle_success(result, transaction=transaction)
704
+ return result
649
705
 
650
706
 
651
707
  def run_task_sync(