prefect-client 3.3.4.dev2__py3-none-any.whl → 3.3.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
prefect/tasks.py CHANGED
@@ -764,11 +764,6 @@ class Task(Generic[P, R]):
764
764
  def on_rollback(
765
765
  self, fn: Callable[["Transaction"], None]
766
766
  ) -> Callable[["Transaction"], None]:
767
- if asyncio.iscoroutinefunction(fn):
768
- raise ValueError(
769
- "Asynchronous rollback hooks are not yet supported. Rollback hooks must be synchronous functions."
770
- )
771
-
772
767
  self.on_rollback_hooks.append(fn)
773
768
  return fn
774
769
 
@@ -1,6 +1,6 @@
1
1
  import logging
2
- import os
3
2
  import re
3
+ import socket
4
4
  from typing import TYPE_CHECKING
5
5
  from urllib.parse import urljoin
6
6
  from uuid import UUID
@@ -74,7 +74,7 @@ def setup_exporters(
74
74
  resource = Resource.create(
75
75
  {
76
76
  "service.name": "prefect",
77
- "service.instance.id": os.uname().nodename,
77
+ "service.instance.id": socket.gethostname(),
78
78
  "prefect.account": str(account_id),
79
79
  "prefect.workspace": str(workspace_id),
80
80
  }
prefect/transactions.py CHANGED
@@ -1,19 +1,26 @@
1
+ from __future__ import annotations
2
+
3
+ import abc
4
+ import asyncio
1
5
  import copy
6
+ import inspect
2
7
  import logging
3
- from contextlib import contextmanager
8
+ from contextlib import asynccontextmanager, contextmanager
4
9
  from contextvars import ContextVar, Token
5
10
  from functools import partial
6
11
  from typing import (
7
12
  Any,
13
+ AsyncGenerator,
8
14
  Callable,
9
- Dict,
15
+ ClassVar,
10
16
  Generator,
11
- List,
17
+ NoReturn,
12
18
  Optional,
13
19
  Type,
14
20
  Union,
15
21
  )
16
22
 
23
+ import anyio.to_thread
17
24
  from pydantic import Field, PrivateAttr
18
25
  from typing_extensions import Self
19
26
 
@@ -32,8 +39,11 @@ from prefect.results import (
32
39
  )
33
40
  from prefect.utilities._engine import get_hook_name
34
41
  from prefect.utilities.annotations import NotSet
42
+ from prefect.utilities.asyncutils import run_coro_as_sync
35
43
  from prefect.utilities.collections import AutoEnum
36
44
 
45
+ logger: logging.Logger = get_logger("transactions")
46
+
37
47
 
38
48
  class IsolationLevel(AutoEnum):
39
49
  READ_COMMITTED = AutoEnum.auto()
@@ -54,29 +64,27 @@ class TransactionState(AutoEnum):
54
64
  ROLLED_BACK = AutoEnum.auto()
55
65
 
56
66
 
57
- class Transaction(ContextModel):
67
+ class BaseTransaction(ContextModel, abc.ABC):
58
68
  """
59
69
  A base model for transaction state.
60
70
  """
61
71
 
62
72
  store: Optional[ResultStore] = None
63
73
  key: Optional[str] = None
64
- children: List["Transaction"] = Field(default_factory=list)
74
+ children: list[Self] = Field(default_factory=list)
65
75
  commit_mode: Optional[CommitMode] = None
66
76
  isolation_level: Optional[IsolationLevel] = IsolationLevel.READ_COMMITTED
67
77
  state: TransactionState = TransactionState.PENDING
68
- on_commit_hooks: List[Callable[["Transaction"], None]] = Field(default_factory=list)
69
- on_rollback_hooks: List[Callable[["Transaction"], None]] = Field(
70
- default_factory=list
71
- )
78
+ on_commit_hooks: list[Callable[[Self], None]] = Field(default_factory=list)
79
+ on_rollback_hooks: list[Callable[[Self], None]] = Field(default_factory=list)
72
80
  overwrite: bool = False
73
81
  logger: Union[logging.Logger, LoggingAdapter] = Field(
74
82
  default_factory=partial(get_logger, "transactions")
75
83
  )
76
84
  write_on_commit: bool = True
77
- _stored_values: Dict[str, Any] = PrivateAttr(default_factory=dict)
78
- _staged_value: Any = None
79
- __var__: ContextVar[Self] = ContextVar("transaction")
85
+ _stored_values: dict[str, Any] = PrivateAttr(default_factory=dict)
86
+ _staged_value: ResultRecord[Any] | Any = None
87
+ __var__: ClassVar[ContextVar[Self]] = ContextVar("transaction")
80
88
 
81
89
  def set(self, name: str, value: Any) -> None:
82
90
  """
@@ -174,7 +182,8 @@ class Transaction(ContextModel):
174
182
  def is_active(self) -> bool:
175
183
  return self.state == TransactionState.ACTIVE
176
184
 
177
- def __enter__(self) -> Self:
185
+ def prepare_transaction(self) -> None:
186
+ """Helper method to prepare transaction state and validate configuration."""
178
187
  if self._token is not None:
179
188
  raise RuntimeError(
180
189
  "Context already entered. Context enter calls cannot be nested."
@@ -203,6 +212,51 @@ class Transaction(ContextModel):
203
212
 
204
213
  # this needs to go before begin, which could set the state to committed
205
214
  self.state = TransactionState.ACTIVE
215
+
216
+ def add_child(self, transaction: Self) -> None:
217
+ self.children.append(transaction)
218
+
219
+ def get_parent(self) -> Self | None:
220
+ parent = None
221
+ if self._token:
222
+ prev_var = self._token.old_value
223
+ if prev_var != Token.MISSING:
224
+ parent = prev_var
225
+ else:
226
+ # `_token` has been reset so we need to get the active transaction from the context var
227
+ parent = self.get_active()
228
+ return parent
229
+
230
+ def stage(
231
+ self,
232
+ value: Any,
233
+ on_rollback_hooks: Optional[list[Callable[..., Any]]] = None,
234
+ on_commit_hooks: Optional[list[Callable[..., Any]]] = None,
235
+ ) -> None:
236
+ """
237
+ Stage a value to be committed later.
238
+ """
239
+ on_commit_hooks = on_commit_hooks or []
240
+ on_rollback_hooks = on_rollback_hooks or []
241
+
242
+ if self.state != TransactionState.COMMITTED:
243
+ self._staged_value = value
244
+ self.on_rollback_hooks += on_rollback_hooks
245
+ self.on_commit_hooks += on_commit_hooks
246
+ self.state = TransactionState.STAGED
247
+
248
+ @classmethod
249
+ def get_active(cls: Type[Self]) -> Optional[Self]:
250
+ return cls.__var__.get(None)
251
+
252
+
253
+ class Transaction(BaseTransaction):
254
+ """
255
+ A model representing the state of a transaction.
256
+ """
257
+
258
+ def __enter__(self) -> Self:
259
+ self.prepare_transaction()
206
260
  self.begin()
207
261
  self._token = self.__var__.set(self)
208
262
  return self
@@ -252,11 +306,9 @@ class Transaction(ContextModel):
252
306
  ):
253
307
  self.state = TransactionState.COMMITTED
254
308
 
255
- def read(self) -> Optional[ResultRecord[Any]]:
309
+ def read(self) -> ResultRecord[Any] | None:
256
310
  if self.store and self.key:
257
- record = self.store.read(key=self.key)
258
- if isinstance(record, ResultRecord):
259
- return record
311
+ return self.store.read(key=self.key)
260
312
  return None
261
313
 
262
314
  def reset(self) -> None:
@@ -274,20 +326,6 @@ class Transaction(ContextModel):
274
326
  if parent and self.state == TransactionState.ROLLED_BACK:
275
327
  parent.rollback()
276
328
 
277
- def add_child(self, transaction: "Transaction") -> None:
278
- self.children.append(transaction)
279
-
280
- def get_parent(self) -> Optional["Transaction"]:
281
- parent = None
282
- if self._token:
283
- prev_var = getattr(self._token, "old_value")
284
- if prev_var != Token.MISSING:
285
- parent = prev_var
286
- else:
287
- # `_token` has been reset so we need to get the active transaction from the context var
288
- parent = self.get_active()
289
- return parent
290
-
291
329
  def commit(self) -> bool:
292
330
  if self.state in [TransactionState.ROLLED_BACK, TransactionState.COMMITTED]:
293
331
  if (
@@ -302,21 +340,19 @@ class Transaction(ContextModel):
302
340
 
303
341
  try:
304
342
  for child in self.children:
305
- child.commit()
343
+ if inspect.iscoroutinefunction(child.commit):
344
+ run_coro_as_sync(child.commit())
345
+ else:
346
+ child.commit()
306
347
 
307
348
  for hook in self.on_commit_hooks:
308
349
  self.run_hook(hook, "commit")
309
350
 
310
351
  if self.store and self.key and self.write_on_commit:
311
- if isinstance(self.store, ResultStore):
312
- if isinstance(self._staged_value, ResultRecord):
313
- self.store.persist_result_record(
314
- result_record=self._staged_value
315
- )
316
- else:
317
- self.store.write(key=self.key, obj=self._staged_value)
352
+ if isinstance(self._staged_value, ResultRecord):
353
+ self.store.persist_result_record(result_record=self._staged_value)
318
354
  else:
319
- self.store.write(key=self.key, result=self._staged_value)
355
+ self.store.write(key=self.key, obj=self._staged_value)
320
356
 
321
357
  self.state = TransactionState.COMMITTED
322
358
  if (
@@ -353,7 +389,10 @@ class Transaction(ContextModel):
353
389
  self.logger.info(f"Running {hook_type} hook {hook_name!r}")
354
390
 
355
391
  try:
356
- hook(self)
392
+ if asyncio.iscoroutinefunction(hook):
393
+ run_coro_as_sync(hook(self))
394
+ else:
395
+ hook(self)
357
396
  except Exception as exc:
358
397
  if should_log:
359
398
  self.logger.error(
@@ -366,24 +405,6 @@ class Transaction(ContextModel):
366
405
  f"{hook_type.capitalize()} hook {hook_name!r} finished running successfully"
367
406
  )
368
407
 
369
- def stage(
370
- self,
371
- value: Any,
372
- on_rollback_hooks: Optional[list[Callable[..., Any]]] = None,
373
- on_commit_hooks: Optional[list[Callable[..., Any]]] = None,
374
- ) -> None:
375
- """
376
- Stage a value to be committed later.
377
- """
378
- on_commit_hooks = on_commit_hooks or []
379
- on_rollback_hooks = on_rollback_hooks or []
380
-
381
- if self.state != TransactionState.COMMITTED:
382
- self._staged_value = value
383
- self.on_rollback_hooks += on_rollback_hooks
384
- self.on_commit_hooks += on_commit_hooks
385
- self.state = TransactionState.STAGED
386
-
387
408
  def rollback(self) -> bool:
388
409
  if self.state in [TransactionState.ROLLED_BACK, TransactionState.COMMITTED]:
389
410
  return False
@@ -392,10 +413,13 @@ class Transaction(ContextModel):
392
413
  for hook in reversed(self.on_rollback_hooks):
393
414
  self.run_hook(hook, "rollback")
394
415
 
395
- self.state = TransactionState.ROLLED_BACK
416
+ self.state: TransactionState = TransactionState.ROLLED_BACK
396
417
 
397
418
  for child in reversed(self.children):
398
- child.rollback()
419
+ if inspect.iscoroutinefunction(child.rollback):
420
+ run_coro_as_sync(child.rollback())
421
+ else:
422
+ child.rollback()
399
423
 
400
424
  return True
401
425
  except Exception:
@@ -414,24 +438,221 @@ class Transaction(ContextModel):
414
438
  self.logger.debug(f"Releasing lock for transaction {self.key!r}")
415
439
  self.store.release_lock(self.key)
416
440
 
417
- @classmethod
418
- def get_active(cls: Type[Self]) -> Optional[Self]:
419
- return cls.__var__.get(None)
420
441
 
442
class AsyncTransaction(BaseTransaction):
    """
    A model representing the state of an asynchronous transaction.

    The lifecycle mirrors the synchronous `Transaction`:
    1. Enter (`begin`).
    2. Optionally `stage` a value.
    3. `commit` or `rollback`.
    4. `reset`, which hands responsibility to a parent and restores the
       context variable.

    Store access and hooks are awaited, so the transaction must be entered
    with `async with`. The active-transaction context variable is shared with
    `Transaction`, so parents and children may be either sync or async
    transactions. Every interaction with a parent or child therefore
    dispatches on its type.
    """

    async def begin(self) -> None:
        """
        Start the transaction.

        Under `IsolationLevel.SERIALIZABLE` (with both a store and a key), the
        store lock for `key` is acquired first. If a record for `key` already
        exists and `overwrite` is false, the transaction is marked `COMMITTED`.
        `commit()` then becomes a no-op that only releases the lock.
        """
        if (
            self.store
            and self.key
            and self.isolation_level == IsolationLevel.SERIALIZABLE
        ):
            self.logger.debug(f"Acquiring lock for transaction {self.key!r}")
            await self.store.aacquire_lock(self.key)
        if (
            not self.overwrite
            and self.store
            and self.key
            and await self.store.aexists(key=self.key)
        ):
            self.state = TransactionState.COMMITTED

    async def read(self) -> ResultRecord[Any] | None:
        """Read the record stored under `key`, or return `None` if `store`/`key` is unset."""
        if self.store and self.key:
            return await self.store.aread(key=self.key)
        return None

    async def reset(self) -> None:
        """
        Leave the transaction context.

        Registers this transaction as a child of its parent (if any) and
        restores the context variable to its previous value. If this
        transaction was rolled back, the rollback is propagated to the parent.
        """
        parent = self.get_parent()

        if parent:
            # parent takes responsibility
            parent.add_child(self)

        if self._token:
            self.__var__.reset(self._token)
            self._token = None

        # do this below reset so that get_transaction() returns the relevant txn
        if parent and self.state == TransactionState.ROLLED_BACK:
            # The parent may be a synchronous `Transaction` whose `rollback()`
            # returns a plain bool; awaiting that would raise TypeError.
            if isinstance(parent, AsyncTransaction):
                await parent.rollback()
            else:
                parent.rollback()

    async def commit(self) -> bool:
        """
        Commit children, run commit hooks, and persist the staged value.

        Returns:
            True if the commit succeeded. False if the transaction was already
            rolled back or committed, or if committing failed. A failed commit
            is rolled back before returning.
        """
        if self.state in [TransactionState.ROLLED_BACK, TransactionState.COMMITTED]:
            if (
                self.store
                and self.key
                and self.isolation_level == IsolationLevel.SERIALIZABLE
            ):
                self.logger.debug(f"Releasing lock for transaction {self.key!r}")
                self.store.release_lock(self.key)

            return False

        try:
            for child in self.children:
                # children may be sync or async transactions
                if isinstance(child, AsyncTransaction):
                    await child.commit()
                else:
                    child.commit()

            for hook in self.on_commit_hooks:
                await self.run_hook(hook, "commit")

            if self.store and self.key and self.write_on_commit:
                if isinstance(self._staged_value, ResultRecord):
                    await self.store.apersist_result_record(
                        result_record=self._staged_value
                    )
                else:
                    await self.store.awrite(key=self.key, obj=self._staged_value)

            self.state = TransactionState.COMMITTED
            if (
                self.store
                and self.key
                and self.isolation_level == IsolationLevel.SERIALIZABLE
            ):
                self.logger.debug(f"Releasing lock for transaction {self.key!r}")
                self.store.release_lock(self.key)
            return True
        except SerializationError as exc:
            if self.logger:
                self.logger.warning(
                    f"Encountered an error while serializing result for transaction {self.key!r}: {exc}"
                    " Code execution will continue, but the transaction will not be committed.",
                )
            await self.rollback()
            return False
        except Exception:
            if self.logger:
                self.logger.exception(
                    f"An error was encountered while committing transaction {self.key!r}",
                    exc_info=True,
                )
            await self.rollback()
            return False

    async def run_hook(self, hook: Callable[..., Any], hook_type: str) -> None:
        """
        Run a single commit or rollback hook with this transaction as its argument.

        Coroutine functions are awaited directly. Plain callables are run in a
        worker thread via `anyio.to_thread.run_sync` so they do not block the
        event loop. Failures are logged (unless the hook opts out via
        `log_on_run`) and re-raised.
        """
        hook_name = get_hook_name(hook)
        # Undocumented way to disable logging for a hook. Subject to change.
        should_log = getattr(hook, "log_on_run", True)

        if should_log:
            self.logger.info(f"Running {hook_type} hook {hook_name!r}")

        try:
            if asyncio.iscoroutinefunction(hook):
                await hook(self)
            else:
                await anyio.to_thread.run_sync(hook, self)
        except Exception as exc:
            if should_log:
                self.logger.error(
                    f"An error was encountered while running {hook_type} hook {hook_name!r}",
                )
            raise exc
        else:
            if should_log:
                self.logger.info(
                    f"{hook_type.capitalize()} hook {hook_name!r} finished running successfully"
                )

    async def rollback(self) -> bool:
        """
        Run rollback hooks (in reverse registration order) and roll back children.

        Returns:
            True if the rollback ran. False if the transaction was already
            finalized, or if a hook or child rollback raised.

        The SERIALIZABLE store lock is released in all cases.
        """
        if self.state in [TransactionState.ROLLED_BACK, TransactionState.COMMITTED]:
            return False

        try:
            for hook in reversed(self.on_rollback_hooks):
                await self.run_hook(hook, "rollback")

            self.state: TransactionState = TransactionState.ROLLED_BACK

            for child in reversed(self.children):
                # children may be sync or async transactions
                if isinstance(child, AsyncTransaction):
                    await child.rollback()
                else:
                    child.rollback()

            return True
        except Exception:
            if self.logger:
                self.logger.exception(
                    f"An error was encountered while rolling back transaction {self.key!r}",
                    exc_info=True,
                )
            return False
        finally:
            if (
                self.store
                and self.key
                and self.isolation_level == IsolationLevel.SERIALIZABLE
            ):
                self.logger.debug(f"Releasing lock for transaction {self.key!r}")
                self.store.release_lock(self.key)

    async def __aenter__(self) -> Self:
        """Prepare and begin the transaction, then publish it as the active transaction."""
        self.prepare_transaction()
        await self.begin()
        self._token = self.__var__.set(self)
        return self

    async def __aexit__(self, *exc_info: Any) -> None:
        """
        Finalize the transaction on context exit.

        - On an exception: roll back, reset, and re-raise the exception.
        - With `CommitMode.EAGER`: commit immediately.
        - If a parent exists: reset so the parent takes responsibility.
        - Otherwise: `OFF` rolls back, `LAZY` commits, and then the context
          is reset.
        """
        exc_type, exc_val, _ = exc_info
        if not self._token:
            raise RuntimeError(
                "Asymmetric use of context. Context exit called without an enter."
            )
        if exc_type:
            await self.rollback()
            await self.reset()
            raise exc_val

        if self.commit_mode == CommitMode.EAGER:
            await self.commit()

        # if parent, let them take responsibility
        if self.get_parent():
            await self.reset()
            return

        if self.commit_mode == CommitMode.OFF:
            # if no one took responsibility to commit, rolling back
            # note that rollback returns if already committed
            await self.rollback()
        elif self.commit_mode == CommitMode.LAZY:
            # no one left to take responsibility for committing
            await self.commit()

        await self.reset()

    def __enter__(self) -> NoReturn:
        """Reject synchronous `with` usage; this transaction must be entered with `async with`."""
        raise NotImplementedError(
            "AsyncTransaction does not support the `with` statement. Use the `async with` statement instead."
        )

    def __exit__(self, *exc_info: Any) -> NoReturn:
        """Reject synchronous `with` usage; this transaction must be exited via `async with`."""
        raise NotImplementedError(
            "AsyncTransaction does not support the `with` statement. Use the `async with` statement instead."
        )
641
+
642
+
643
def get_transaction() -> BaseTransaction | None:
    """
    Return the transaction active in the current context, if any.

    Sync and async transactions publish themselves through the same context
    variable, so the result may be either a `Transaction` or an
    `AsyncTransaction`. `None` means no transaction is active.
    """
    active_transaction = BaseTransaction.get_active()
    return active_transaction
424
645
 
425
646
 
426
647
  @contextmanager
427
648
  def transaction(
428
- key: Optional[str] = None,
429
- store: Optional[ResultStore] = None,
430
- commit_mode: Optional[CommitMode] = None,
431
- isolation_level: Optional[IsolationLevel] = None,
649
+ key: str | None = None,
650
+ store: ResultStore | None = None,
651
+ commit_mode: CommitMode | None = None,
652
+ isolation_level: IsolationLevel | None = None,
432
653
  overwrite: bool = False,
433
654
  write_on_commit: bool = True,
434
- logger: Optional[Union[logging.Logger, LoggingAdapter]] = None,
655
+ logger: logging.Logger | LoggingAdapter | None = None,
435
656
  ) -> Generator[Transaction, None, None]:
436
657
  """
437
658
  A context manager for opening and managing a transaction.
@@ -473,3 +694,56 @@ def transaction(
473
694
  logger=_logger,
474
695
  ) as txn:
475
696
  yield txn
697
+
698
+
699
@asynccontextmanager
async def atransaction(
    key: str | None = None,
    store: ResultStore | None = None,
    commit_mode: CommitMode | None = None,
    isolation_level: IsolationLevel | None = None,
    overwrite: bool = False,
    write_on_commit: bool = True,
    logger: logging.Logger | LoggingAdapter | None = None,
) -> AsyncGenerator[AsyncTransaction, None]:
    """
    An asynchronous context manager for opening and managing an asynchronous transaction.

    Args:
        - key: An identifier to use for the transaction. If no key is given, no
            record is persisted.
        - store: The store to use for persisting the transaction result. If not provided
            and `key` is set, the store is obtained from `get_result_store()`. A store
            whose `metadata_storage` is a `NullFileSystem` is copied with
            `metadata_storage` cleared.
        - commit_mode: The commit mode controlling when the transaction and
            child transactions are committed
        - isolation_level: The isolation level of the transaction. With
            `IsolationLevel.SERIALIZABLE`, a lock on `key` is acquired from the store
            when the transaction begins.
        - overwrite: Whether to overwrite an existing transaction record in the store
        - write_on_commit: Whether to write the staged value to the store on commit.
            Defaults to `True`.
        - logger: The logger used by the transaction. If not provided, the current run
            logger is used, falling back to the "transactions" logger when no run
            context is available.

    Yields:
        - AsyncTransaction: An object representing the transaction state
    """

    # if there is no key, we won't persist a record
    if key and not store:
        store = get_result_store()

    # Avoid inheriting a NullFileSystem for metadata_storage from a flow's result store
    if store and isinstance(store.metadata_storage, NullFileSystem):
        store = store.model_copy(update={"metadata_storage": None})

    try:
        _logger: Union[logging.Logger, LoggingAdapter] = logger or get_run_logger()
    except MissingContextError:
        _logger = get_logger("transactions")

    async with AsyncTransaction(
        key=key,
        store=store,
        commit_mode=commit_mode,
        isolation_level=isolation_level,
        overwrite=overwrite,
        write_on_commit=write_on_commit,
        logger=_logger,
    ) as txn:
        yield txn