prefect-client 3.0.1__py3-none-any.whl → 3.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. prefect/_internal/compatibility/deprecated.py +1 -1
  2. prefect/blocks/notifications.py +21 -0
  3. prefect/blocks/webhook.py +8 -0
  4. prefect/client/orchestration.py +39 -20
  5. prefect/client/schemas/actions.py +2 -2
  6. prefect/client/schemas/objects.py +24 -6
  7. prefect/client/types/flexible_schedule_list.py +1 -1
  8. prefect/concurrency/asyncio.py +45 -6
  9. prefect/concurrency/services.py +1 -1
  10. prefect/concurrency/sync.py +21 -27
  11. prefect/concurrency/v1/asyncio.py +3 -0
  12. prefect/concurrency/v1/sync.py +4 -5
  13. prefect/context.py +5 -1
  14. prefect/deployments/runner.py +1 -0
  15. prefect/events/actions.py +6 -0
  16. prefect/flow_engine.py +12 -4
  17. prefect/locking/filesystem.py +243 -0
  18. prefect/logging/handlers.py +0 -2
  19. prefect/logging/loggers.py +0 -18
  20. prefect/logging/logging.yml +1 -0
  21. prefect/main.py +19 -5
  22. prefect/records/base.py +12 -0
  23. prefect/records/filesystem.py +6 -2
  24. prefect/records/memory.py +6 -0
  25. prefect/records/result_store.py +6 -0
  26. prefect/results.py +169 -25
  27. prefect/runner/runner.py +74 -5
  28. prefect/settings.py +1 -1
  29. prefect/states.py +34 -17
  30. prefect/task_engine.py +31 -37
  31. prefect/transactions.py +105 -50
  32. prefect/utilities/engine.py +16 -8
  33. prefect/utilities/importtools.py +1 -0
  34. prefect/utilities/urls.py +70 -12
  35. prefect/workers/base.py +14 -6
  36. {prefect_client-3.0.1.dist-info → prefect_client-3.0.2.dist-info}/METADATA +1 -1
  37. {prefect_client-3.0.1.dist-info → prefect_client-3.0.2.dist-info}/RECORD +40 -39
  38. {prefect_client-3.0.1.dist-info → prefect_client-3.0.2.dist-info}/LICENSE +0 -0
  39. {prefect_client-3.0.1.dist-info → prefect_client-3.0.2.dist-info}/WHEEL +0 -0
  40. {prefect_client-3.0.1.dist-info → prefect_client-3.0.2.dist-info}/top_level.txt +0 -0
prefect/states.py CHANGED
@@ -25,7 +25,13 @@ from prefect.exceptions import (
25
25
  UnfinishedRun,
26
26
  )
27
27
  from prefect.logging.loggers import get_logger, get_run_logger
28
- from prefect.results import BaseResult, R, ResultStore
28
+ from prefect.results import (
29
+ BaseResult,
30
+ R,
31
+ ResultRecord,
32
+ ResultRecordMetadata,
33
+ ResultStore,
34
+ )
29
35
  from prefect.settings import PREFECT_ASYNC_FETCH_STATE_RESULT
30
36
  from prefect.utilities.annotations import BaseAnnotation
31
37
  from prefect.utilities.asyncutils import in_async_main_thread, sync_compatible
@@ -92,7 +98,11 @@ async def _get_state_result_data_with_retries(
92
98
 
93
99
  for i in range(1, max_attempts + 1):
94
100
  try:
95
- return await state.data.get()
101
+ if isinstance(state.data, ResultRecordMetadata):
102
+ record = await ResultRecord._from_metadata(state.data)
103
+ return record.result
104
+ else:
105
+ return await state.data.get()
96
106
  except Exception as e:
97
107
  if i == max_attempts:
98
108
  raise
@@ -127,10 +137,12 @@ async def _get_state_result(
127
137
  ):
128
138
  raise await get_state_exception(state)
129
139
 
130
- if isinstance(state.data, BaseResult):
140
+ if isinstance(state.data, (BaseResult, ResultRecordMetadata)):
131
141
  result = await _get_state_result_data_with_retries(
132
142
  state, retry_result_failure=retry_result_failure
133
143
  )
144
+ elif isinstance(state.data, ResultRecord):
145
+ result = state.data.result
134
146
 
135
147
  elif state.data is None:
136
148
  if state.is_failed() or state.is_crashed() or state.is_cancelled():
@@ -207,7 +219,7 @@ async def exception_to_crashed_state(
207
219
  )
208
220
 
209
221
  if result_store:
210
- data = await result_store.create_result(exc)
222
+ data = result_store.create_result_record(exc)
211
223
  else:
212
224
  # Attach the exception for local usage, will not be available when retrieved
213
225
  # from the API
@@ -240,10 +252,10 @@ async def exception_to_failed_state(
240
252
  pass
241
253
 
242
254
  if result_store:
243
- data = await result_store.create_result(exc)
255
+ data = result_store.create_result_record(exc)
244
256
  if write_result:
245
257
  try:
246
- await data.write()
258
+ await result_store.apersist_result_record(data)
247
259
  except Exception as exc:
248
260
  local_logger.warning(
249
261
  "Failed to write result: %s Execution will continue, but the result has not been written",
@@ -309,21 +321,21 @@ async def return_value_to_state(
309
321
  state = retval
310
322
  # Unless the user has already constructed a result explicitly, use the store
311
323
  # to update the data to the correct type
312
- if not isinstance(state.data, BaseResult):
313
- result = await result_store.create_result(
324
+ if not isinstance(state.data, (BaseResult, ResultRecord, ResultRecordMetadata)):
325
+ result_record = result_store.create_result_record(
314
326
  state.data,
315
327
  key=key,
316
328
  expiration=expiration,
317
329
  )
318
330
  if write_result:
319
331
  try:
320
- await result.write()
332
+ await result_store.apersist_result_record(result_record)
321
333
  except Exception as exc:
322
334
  local_logger.warning(
323
335
  "Encountered an error while persisting result: %s Execution will continue, but the result has not been persisted",
324
336
  exc,
325
337
  )
326
- state.data = result
338
+ state.data = result_record
327
339
  return state
328
340
 
329
341
  # Determine a new state from the aggregate of contained states
@@ -359,14 +371,14 @@ async def return_value_to_state(
359
371
  # TODO: We may actually want to set the data to a `StateGroup` object and just
360
372
  # allow it to be unpacked into a tuple and such so users can interact with
361
373
  # it
362
- result = await result_store.create_result(
374
+ result_record = result_store.create_result_record(
363
375
  retval,
364
376
  key=key,
365
377
  expiration=expiration,
366
378
  )
367
379
  if write_result:
368
380
  try:
369
- await result.write()
381
+ await result_store.apersist_result_record(result_record)
370
382
  except Exception as exc:
371
383
  local_logger.warning(
372
384
  "Encountered an error while persisting result: %s Execution will continue, but the result has not been persisted",
@@ -375,7 +387,7 @@ async def return_value_to_state(
375
387
  return State(
376
388
  type=new_state_type,
377
389
  message=message,
378
- data=result,
390
+ data=result_record,
379
391
  )
380
392
 
381
393
  # Generators aren't portable, implicitly convert them to a list.
@@ -385,23 +397,23 @@ async def return_value_to_state(
385
397
  data = retval
386
398
 
387
399
  # Otherwise, they just gave data and this is a completed retval
388
- if isinstance(data, BaseResult):
400
+ if isinstance(data, (BaseResult, ResultRecord)):
389
401
  return Completed(data=data)
390
402
  else:
391
- result = await result_store.create_result(
403
+ result_record = result_store.create_result_record(
392
404
  data,
393
405
  key=key,
394
406
  expiration=expiration,
395
407
  )
396
408
  if write_result:
397
409
  try:
398
- await result.write()
410
+ await result_store.apersist_result_record(result_record)
399
411
  except Exception as exc:
400
412
  local_logger.warning(
401
413
  "Encountered an error while persisting result: %s Execution will continue, but the result has not been persisted",
402
414
  exc,
403
415
  )
404
- return Completed(data=result)
416
+ return Completed(data=result_record)
405
417
 
406
418
 
407
419
  @sync_compatible
@@ -442,6 +454,11 @@ async def get_state_exception(state: State) -> BaseException:
442
454
 
443
455
  if isinstance(state.data, BaseResult):
444
456
  result = await _get_state_result_data_with_retries(state)
457
+ elif isinstance(state.data, ResultRecord):
458
+ result = state.data.result
459
+ elif isinstance(state.data, ResultRecordMetadata):
460
+ record = await ResultRecord._from_metadata(state.data)
461
+ result = record.result
445
462
  elif state.data is None:
446
463
  result = None
447
464
  else:
prefect/task_engine.py CHANGED
@@ -55,11 +55,12 @@ from prefect.exceptions import (
55
55
  )
56
56
  from prefect.futures import PrefectFuture
57
57
  from prefect.logging.loggers import get_logger, patch_print, task_run_logger
58
- from prefect.records.result_store import ResultRecordStore
59
58
  from prefect.results import (
60
59
  BaseResult,
60
+ ResultRecord,
61
61
  _format_user_supplied_storage_key,
62
- get_current_result_store,
62
+ get_result_store,
63
+ should_persist_result,
63
64
  )
64
65
  from prefect.settings import (
65
66
  PREFECT_DEBUG_MODE,
@@ -418,6 +419,8 @@ class SyncTaskRunEngine(BaseTaskRunEngine[P, R]):
418
419
  result = state.result(raise_on_failure=False, fetch=True)
419
420
  if inspect.isawaitable(result):
420
421
  result = run_coro_as_sync(result)
422
+ elif isinstance(state.data, ResultRecord):
423
+ result = state.data.result
421
424
  else:
422
425
  result = state.data
423
426
 
@@ -441,7 +444,8 @@ class SyncTaskRunEngine(BaseTaskRunEngine[P, R]):
441
444
  if inspect.isawaitable(_result):
442
445
  _result = run_coro_as_sync(_result)
443
446
  return _result
444
-
447
+ elif isinstance(self._return_value, ResultRecord):
448
+ return self._return_value.result
445
449
  # otherwise, return the value as is
446
450
  return self._return_value
447
451
 
@@ -454,10 +458,6 @@ class SyncTaskRunEngine(BaseTaskRunEngine[P, R]):
454
458
  return self._raised
455
459
 
456
460
  def handle_success(self, result: R, transaction: Transaction) -> R:
457
- result_store = getattr(TaskRunContext.get(), "result_store", None)
458
- if result_store is None:
459
- raise ValueError("Result store is not set")
460
-
461
461
  if self.task.cache_expiration is not None:
462
462
  expiration = pendulum.now("utc") + self.task.cache_expiration
463
463
  else:
@@ -466,7 +466,7 @@ class SyncTaskRunEngine(BaseTaskRunEngine[P, R]):
466
466
  terminal_state = run_coro_as_sync(
467
467
  return_value_to_state(
468
468
  result,
469
- result_store=result_store,
469
+ result_store=get_result_store(),
470
470
  key=transaction.key,
471
471
  expiration=expiration,
472
472
  )
@@ -538,12 +538,11 @@ class SyncTaskRunEngine(BaseTaskRunEngine[P, R]):
538
538
  # If the task fails, and we have retries left, set the task to retrying.
539
539
  if not self.handle_retry(exc):
540
540
  # If the task has no retries left, or the retry condition is not met, set the task to failed.
541
- context = TaskRunContext.get()
542
541
  state = run_coro_as_sync(
543
542
  exception_to_failed_state(
544
543
  exc,
545
544
  message="Task run encountered an exception",
546
- result_store=getattr(context, "result_store", None),
545
+ result_store=get_result_store(),
547
546
  write_result=True,
548
547
  )
549
548
  )
@@ -595,10 +594,13 @@ class SyncTaskRunEngine(BaseTaskRunEngine[P, R]):
595
594
  log_prints=log_prints,
596
595
  task_run=self.task_run,
597
596
  parameters=self.parameters,
598
- result_store=get_current_result_store().update_for_task(
597
+ result_store=get_result_store().update_for_task(
599
598
  self.task, _sync=True
600
599
  ),
601
600
  client=client,
601
+ persist_result=self.task.persist_result
602
+ if self.task.persist_result is not None
603
+ else should_persist_result(),
602
604
  )
603
605
  )
604
606
  stack.enter_context(ConcurrencyContextV1())
@@ -723,17 +725,12 @@ class SyncTaskRunEngine(BaseTaskRunEngine[P, R]):
723
725
  else PREFECT_TASKS_REFRESH_CACHE.value()
724
726
  )
725
727
 
726
- result_store = getattr(TaskRunContext.get(), "result_store", None)
727
- if result_store and result_store.persist_result:
728
- store = ResultRecordStore(result_store=result_store)
729
- else:
730
- store = None
731
-
732
728
  with transaction(
733
729
  key=self.compute_transaction_key(),
734
- store=store,
730
+ store=get_result_store(),
735
731
  overwrite=overwrite,
736
732
  logger=self.logger,
733
+ write_on_commit=should_persist_result(),
737
734
  ) as txn:
738
735
  yield txn
739
736
 
@@ -769,10 +766,10 @@ class SyncTaskRunEngine(BaseTaskRunEngine[P, R]):
769
766
  if transaction.is_committed():
770
767
  result = transaction.read()
771
768
  else:
772
- if self.task.tags:
769
+ if self.task_run.tags:
773
770
  # Acquire a concurrency slot for each tag, but only if a limit
774
771
  # matching the tag already exists.
775
- with concurrency(list(self.task.tags), self.task_run.id):
772
+ with concurrency(list(self.task_run.tags), self.task_run.id):
776
773
  result = call_with_parameters(self.task.fn, parameters)
777
774
  else:
778
775
  result = call_with_parameters(self.task.fn, parameters)
@@ -933,6 +930,8 @@ class AsyncTaskRunEngine(BaseTaskRunEngine[P, R]):
933
930
  # Avoid fetching the result unless it is cached, otherwise we defeat
934
931
  # the purpose of disabling `cache_result_in_memory`
935
932
  result = await new_state.result(raise_on_failure=False, fetch=True)
933
+ elif isinstance(new_state.data, ResultRecord):
934
+ result = new_state.data.result
936
935
  else:
937
936
  result = new_state.data
938
937
 
@@ -953,7 +952,8 @@ class AsyncTaskRunEngine(BaseTaskRunEngine[P, R]):
953
952
  # if the return value is a BaseResult, we need to fetch it
954
953
  if isinstance(self._return_value, BaseResult):
955
954
  return await self._return_value.get()
956
-
955
+ elif isinstance(self._return_value, ResultRecord):
956
+ return self._return_value.result
957
957
  # otherwise, return the value as is
958
958
  return self._return_value
959
959
 
@@ -966,10 +966,6 @@ class AsyncTaskRunEngine(BaseTaskRunEngine[P, R]):
966
966
  return self._raised
967
967
 
968
968
  async def handle_success(self, result: R, transaction: Transaction) -> R:
969
- result_store = getattr(TaskRunContext.get(), "result_store", None)
970
- if result_store is None:
971
- raise ValueError("Result store is not set")
972
-
973
969
  if self.task.cache_expiration is not None:
974
970
  expiration = pendulum.now("utc") + self.task.cache_expiration
975
971
  else:
@@ -977,7 +973,7 @@ class AsyncTaskRunEngine(BaseTaskRunEngine[P, R]):
977
973
 
978
974
  terminal_state = await return_value_to_state(
979
975
  result,
980
- result_store=result_store,
976
+ result_store=get_result_store(),
981
977
  key=transaction.key,
982
978
  expiration=expiration,
983
979
  )
@@ -1048,11 +1044,10 @@ class AsyncTaskRunEngine(BaseTaskRunEngine[P, R]):
1048
1044
  # If the task fails, and we have retries left, set the task to retrying.
1049
1045
  if not await self.handle_retry(exc):
1050
1046
  # If the task has no retries left, or the retry condition is not met, set the task to failed.
1051
- context = TaskRunContext.get()
1052
1047
  state = await exception_to_failed_state(
1053
1048
  exc,
1054
1049
  message="Task run encountered an exception",
1055
- result_store=getattr(context, "result_store", None),
1050
+ result_store=get_result_store(),
1056
1051
  )
1057
1052
  self.record_terminal_state_timing(state)
1058
1053
  await self.set_state(state)
@@ -1102,10 +1097,13 @@ class AsyncTaskRunEngine(BaseTaskRunEngine[P, R]):
1102
1097
  log_prints=log_prints,
1103
1098
  task_run=self.task_run,
1104
1099
  parameters=self.parameters,
1105
- result_store=await get_current_result_store().update_for_task(
1100
+ result_store=await get_result_store().update_for_task(
1106
1101
  self.task, _sync=False
1107
1102
  ),
1108
1103
  client=client,
1104
+ persist_result=self.task.persist_result
1105
+ if self.task.persist_result is not None
1106
+ else should_persist_result(),
1109
1107
  )
1110
1108
  )
1111
1109
  stack.enter_context(ConcurrencyContext())
@@ -1226,17 +1224,13 @@ class AsyncTaskRunEngine(BaseTaskRunEngine[P, R]):
1226
1224
  if self.task.refresh_cache is not None
1227
1225
  else PREFECT_TASKS_REFRESH_CACHE.value()
1228
1226
  )
1229
- result_store = getattr(TaskRunContext.get(), "result_store", None)
1230
- if result_store and result_store.persist_result:
1231
- store = ResultRecordStore(result_store=result_store)
1232
- else:
1233
- store = None
1234
1227
 
1235
1228
  with transaction(
1236
1229
  key=self.compute_transaction_key(),
1237
- store=store,
1230
+ store=get_result_store(),
1238
1231
  overwrite=overwrite,
1239
1232
  logger=self.logger,
1233
+ write_on_commit=should_persist_result(),
1240
1234
  ) as txn:
1241
1235
  yield txn
1242
1236
 
@@ -1272,10 +1266,10 @@ class AsyncTaskRunEngine(BaseTaskRunEngine[P, R]):
1272
1266
  if transaction.is_committed():
1273
1267
  result = transaction.read()
1274
1268
  else:
1275
- if self.task.tags:
1269
+ if self.task_run.tags:
1276
1270
  # Acquire a concurrency slot for each tag, but only if a limit
1277
1271
  # matching the tag already exists.
1278
- async with aconcurrency(list(self.task.tags), self.task_run.id):
1272
+ async with aconcurrency(list(self.task_run.tags), self.task_run.id):
1279
1273
  result = await call_with_parameters(self.task.fn, parameters)
1280
1274
  else:
1281
1275
  result = await call_with_parameters(self.task.fn, parameters)
prefect/transactions.py CHANGED
@@ -22,7 +22,13 @@ from prefect.exceptions import MissingContextError, SerializationError
22
22
  from prefect.logging.loggers import get_logger, get_run_logger
23
23
  from prefect.records import RecordStore
24
24
  from prefect.records.base import TransactionRecord
25
- from prefect.results import BaseResult, ResultRecord, ResultStore
25
+ from prefect.results import (
26
+ BaseResult,
27
+ ResultRecord,
28
+ ResultStore,
29
+ get_result_store,
30
+ should_persist_result,
31
+ )
26
32
  from prefect.utilities.annotations import NotSet
27
33
  from prefect.utilities.collections import AutoEnum
28
34
  from prefect.utilities.engine import _get_hook_name
@@ -66,19 +72,91 @@ class Transaction(ContextModel):
66
72
  logger: Union[logging.Logger, logging.LoggerAdapter] = Field(
67
73
  default_factory=partial(get_logger, "transactions")
68
74
  )
75
+ write_on_commit: bool = True
69
76
  _stored_values: Dict[str, Any] = PrivateAttr(default_factory=dict)
70
77
  _staged_value: Any = None
71
78
  __var__: ContextVar = ContextVar("transaction")
72
79
 
73
80
  def set(self, name: str, value: Any) -> None:
81
+ """
82
+ Set a stored value in the transaction.
83
+
84
+ Args:
85
+ name: The name of the value to set
86
+ value: The value to set
87
+
88
+ Examples:
89
+ Set a value for use later in the transaction:
90
+ ```python
91
+ with transaction() as txn:
92
+ txn.set("key", "value")
93
+ ...
94
+ assert txn.get("key") == "value"
95
+ ```
96
+ """
74
97
  self._stored_values[name] = value
75
98
 
76
99
  def get(self, name: str, default: Any = NotSet) -> Any:
77
- if name not in self._stored_values:
78
- if default is not NotSet:
79
- return default
80
- raise ValueError(f"Could not retrieve value for unknown key: {name}")
81
- return self._stored_values.get(name)
100
+ """
101
+ Get a stored value from the transaction.
102
+
103
+ Child transactions will return values from their parents unless a value with
104
+ the same name is set in the child transaction.
105
+
106
+ Direct changes to returned values will not update the stored value. To update the
107
+ stored value, use the `set` method.
108
+
109
+ Args:
110
+ name: The name of the value to get
111
+ default: The default value to return if the value is not found
112
+
113
+ Returns:
114
+ The value from the transaction
115
+
116
+ Examples:
117
+ Get a value from the transaction:
118
+ ```python
119
+ with transaction() as txn:
120
+ txn.set("key", "value")
121
+ ...
122
+ assert txn.get("key") == "value"
123
+ ```
124
+
125
+ Get a value from a parent transaction:
126
+ ```python
127
+ with transaction() as parent:
128
+ parent.set("key", "parent_value")
129
+ with transaction() as child:
130
+ assert child.get("key") == "parent_value"
131
+ ```
132
+
133
+ Update a stored value:
134
+ ```python
135
+ with transaction() as txn:
136
+ txn.set("key", [1, 2, 3])
137
+ value = txn.get("key")
138
+ value.append(4)
139
+ # Stored value is not updated until `.set` is called
140
+ assert value == [1, 2, 3, 4]
141
+ assert txn.get("key") == [1, 2, 3]
142
+
143
+ txn.set("key", value)
144
+ assert txn.get("key") == [1, 2, 3, 4]
145
+ ```
146
+ """
147
+ # deepcopy to prevent mutation of stored values
148
+ value = copy.deepcopy(self._stored_values.get(name, NotSet))
149
+ if value is NotSet:
150
+ # if there's a parent transaction, get the value from the parent
151
+ parent = self.get_parent()
152
+ if parent is not None:
153
+ value = parent.get(name, default)
154
+ # if there's no parent transaction, use the default
155
+ elif default is not NotSet:
156
+ value = default
157
+ else:
158
+ raise ValueError(f"Could not retrieve value for unknown key: {name}")
159
+ return value
82
160
 
83
161
  def is_committed(self) -> bool:
84
162
  return self.state == TransactionState.COMMITTED
@@ -101,8 +179,6 @@ class Transaction(ContextModel):
101
179
  "Context already entered. Context enter calls cannot be nested."
102
180
  )
103
181
  parent = get_transaction()
104
- if parent:
105
- self._stored_values = copy.deepcopy(parent._stored_values)
106
182
  # set default commit behavior; either inherit from parent or set a default of eager
107
183
  if self.commit_mode is None:
108
184
  self.commit_mode = parent.commit_mode if parent else CommitMode.LAZY
@@ -119,7 +195,7 @@ class Transaction(ContextModel):
119
195
  and not self.store.supports_isolation_level(self.isolation_level)
120
196
  ):
121
197
  raise ValueError(
122
- f"Isolation level {self.isolation_level.name} is not supported by record store type {self.store.__class__.__name__}"
198
+ f"Isolation level {self.isolation_level.name} is not supported by provided result store."
123
199
  )
124
200
 
125
201
  # this needs to go before begin, which could set the state to committed
@@ -229,14 +305,21 @@ class Transaction(ContextModel):
229
305
  for hook in self.on_commit_hooks:
230
306
  self.run_hook(hook, "commit")
231
307
 
232
- if self.store and self.key:
308
+ if self.store and self.key and self.write_on_commit:
233
309
  if isinstance(self.store, ResultStore):
234
310
  if isinstance(self._staged_value, BaseResult):
235
- self.store.write(self.key, self._staged_value.get(_sync=True))
311
+ self.store.write(
312
+ key=self.key, obj=self._staged_value.get(_sync=True)
313
+ )
314
+ elif isinstance(self._staged_value, ResultRecord):
315
+ self.store.persist_result_record(
316
+ result_record=self._staged_value
317
+ )
236
318
  else:
237
- self.store.write(self.key, self._staged_value)
319
+ self.store.write(key=self.key, obj=self._staged_value)
238
320
  else:
239
- self.store.write(self.key, self._staged_value)
321
+ self.store.write(key=self.key, result=self._staged_value)
322
+
240
323
  self.state = TransactionState.COMMITTED
241
324
  if (
242
325
  self.store
@@ -287,7 +370,7 @@ class Transaction(ContextModel):
287
370
 
288
371
  def stage(
289
372
  self,
290
- value: Union["BaseResult", Any],
373
+ value: Any,
291
374
  on_rollback_hooks: Optional[List] = None,
292
375
  on_commit_hooks: Optional[List] = None,
293
376
  ) -> None:
@@ -349,6 +432,7 @@ def transaction(
349
432
  commit_mode: Optional[CommitMode] = None,
350
433
  isolation_level: Optional[IsolationLevel] = None,
351
434
  overwrite: bool = False,
435
+ write_on_commit: Optional[bool] = None,
352
436
  logger: Union[logging.Logger, logging.LoggerAdapter, None] = None,
353
437
  ) -> Generator[Transaction, None, None]:
354
438
  """
@@ -361,48 +445,16 @@ def transaction(
361
445
  - commit_mode: The commit mode controlling when the transaction and
362
446
  child transactions are committed
363
447
  - overwrite: Whether to overwrite an existing transaction record in the store
448
+ - write_on_commit: Whether to write the result to the store on commit. If not provided,
449
+ will default will be determined by the current run context. If no run context is
450
+ available, the value of `PREFECT_RESULTS_PERSIST_BY_DEFAULT` will be used.
364
451
 
365
452
  Yields:
366
453
  - Transaction: An object representing the transaction state
367
454
  """
368
455
  # if there is no key, we won't persist a record
369
456
  if key and not store:
370
- from prefect.context import FlowRunContext, TaskRunContext
371
- from prefect.results import ResultStore, get_default_result_storage
372
-
373
- flow_run_context = FlowRunContext.get()
374
- task_run_context = TaskRunContext.get()
375
- existing_store = getattr(task_run_context, "result_store", None) or getattr(
376
- flow_run_context, "result_store", None
377
- )
378
-
379
- new_store: ResultStore
380
- if existing_store and existing_store.result_storage_block_id:
381
- new_store = existing_store.model_copy(
382
- update={
383
- "persist_result": True,
384
- }
385
- )
386
- else:
387
- default_storage = get_default_result_storage(_sync=True)
388
- if existing_store:
389
- new_store = existing_store.model_copy(
390
- update={
391
- "persist_result": True,
392
- "storage_block": default_storage,
393
- "storage_block_id": default_storage._block_document_id,
394
- }
395
- )
396
- else:
397
- new_store = ResultStore(
398
- persist_result=True,
399
- result_storage=default_storage,
400
- )
401
- from prefect.records.result_store import ResultRecordStore
402
-
403
- store = ResultRecordStore(
404
- result_store=new_store,
405
- )
457
+ store = get_result_store()
406
458
 
407
459
  try:
408
460
  logger = logger or get_run_logger()
@@ -415,6 +467,9 @@ def transaction(
415
467
  commit_mode=commit_mode,
416
468
  isolation_level=isolation_level,
417
469
  overwrite=overwrite,
470
+ write_on_commit=write_on_commit
471
+ if write_on_commit is not None
472
+ else should_persist_result(),
418
473
  logger=logger,
419
474
  ) as txn:
420
475
  yield txn
prefect/utilities/engine.py CHANGED
@@ -44,12 +44,11 @@ from prefect.exceptions import (
44
44
  )
45
45
  from prefect.flows import Flow
46
46
  from prefect.futures import PrefectFuture
47
- from prefect.futures import PrefectFuture as NewPrefectFuture
48
47
  from prefect.logging.loggers import (
49
48
  get_logger,
50
49
  task_run_logger,
51
50
  )
52
- from prefect.results import BaseResult
51
+ from prefect.results import BaseResult, ResultRecord, should_persist_result
53
52
  from prefect.settings import (
54
53
  PREFECT_LOGGING_LOG_PRINTS,
55
54
  )
@@ -122,7 +121,7 @@ async def collect_task_run_inputs(expr: Any, max_depth: int = -1) -> Set[TaskRun
122
121
 
123
122
 
124
123
  def collect_task_run_inputs_sync(
125
- expr: Any, future_cls: Any = NewPrefectFuture, max_depth: int = -1
124
+ expr: Any, future_cls: Any = PrefectFuture, max_depth: int = -1
126
125
  ) -> Set[TaskRunInput]:
127
126
  """
128
127
  This function recurses through an expression to generate a set of any discernible
@@ -131,7 +130,7 @@ def collect_task_run_inputs_sync(
131
130
 
132
131
  Examples:
133
132
  >>> task_inputs = {
134
- >>> k: collect_task_run_inputs(v) for k, v in parameters.items()
133
+ >>> k: collect_task_run_inputs_sync(v) for k, v in parameters.items()
135
134
  >>> }
136
135
  """
137
136
  # TODO: This function needs to be updated to detect parameters and constants
@@ -401,6 +400,8 @@ async def propose_state(
401
400
  # Avoid fetching the result unless it is cached, otherwise we defeat
402
401
  # the purpose of disabling `cache_result_in_memory`
403
402
  result = await state.result(raise_on_failure=False, fetch=True)
403
+ elif isinstance(state.data, ResultRecord):
404
+ result = state.data.result
404
405
  else:
405
406
  result = state.data
406
407
 
@@ -504,6 +505,8 @@ def propose_state_sync(
504
505
  result = state.result(raise_on_failure=False, fetch=True)
505
506
  if inspect.isawaitable(result):
506
507
  result = run_coro_as_sync(result)
508
+ elif isinstance(state.data, ResultRecord):
509
+ result = state.data.result
507
510
  else:
508
511
  result = state.data
509
512
 
@@ -732,6 +735,13 @@ def emit_task_run_state_change_event(
732
735
  ) -> Event:
733
736
  state_message_truncation_length = 100_000
734
737
 
738
+ if isinstance(validated_state.data, ResultRecord) and should_persist_result():
739
+ data = validated_state.data.metadata.model_dump(mode="json")
740
+ elif isinstance(validated_state.data, BaseResult):
741
+ data = validated_state.data.model_dump(mode="json")
742
+ else:
743
+ data = None
744
+
735
745
  return emit_event(
736
746
  id=validated_state.id,
737
747
  occurred=validated_state.timestamp,
@@ -770,9 +780,7 @@ def emit_task_run_state_change_event(
770
780
  exclude_unset=True,
771
781
  exclude={"flow_run_id", "task_run_id"},
772
782
  ),
773
- "data": validated_state.data.model_dump(mode="json")
774
- if isinstance(validated_state.data, BaseResult)
775
- else None,
783
+ "data": data,
776
784
  },
777
785
  "task_run": task_run.model_dump(
778
786
  mode="json",
@@ -822,7 +830,7 @@ def resolve_to_final_result(expr, context):
822
830
  if isinstance(context.get("annotation"), quote):
823
831
  raise StopVisiting()
824
832
 
825
- if isinstance(expr, NewPrefectFuture):
833
+ if isinstance(expr, PrefectFuture):
826
834
  upstream_task_run = context.get("current_task_run")
827
835
  upstream_task = context.get("current_task")
828
836
  if (