prefect-client 3.1.10__py3-none-any.whl → 3.1.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (141) hide show
  1. prefect/_experimental/lineage.py +7 -8
  2. prefect/_experimental/sla/__init__.py +0 -0
  3. prefect/_experimental/sla/client.py +66 -0
  4. prefect/_experimental/sla/objects.py +53 -0
  5. prefect/_internal/_logging.py +15 -3
  6. prefect/_internal/compatibility/async_dispatch.py +22 -16
  7. prefect/_internal/compatibility/deprecated.py +42 -18
  8. prefect/_internal/compatibility/migration.py +2 -2
  9. prefect/_internal/concurrency/inspection.py +12 -14
  10. prefect/_internal/concurrency/primitives.py +2 -2
  11. prefect/_internal/concurrency/services.py +154 -80
  12. prefect/_internal/concurrency/waiters.py +13 -9
  13. prefect/_internal/pydantic/annotations/pendulum.py +7 -7
  14. prefect/_internal/pytz.py +4 -3
  15. prefect/_internal/retries.py +10 -5
  16. prefect/_internal/schemas/bases.py +19 -10
  17. prefect/_internal/schemas/validators.py +227 -388
  18. prefect/_version.py +3 -3
  19. prefect/automations.py +236 -30
  20. prefect/blocks/__init__.py +3 -3
  21. prefect/blocks/abstract.py +53 -30
  22. prefect/blocks/core.py +183 -84
  23. prefect/blocks/notifications.py +133 -73
  24. prefect/blocks/redis.py +13 -9
  25. prefect/blocks/system.py +24 -11
  26. prefect/blocks/webhook.py +7 -5
  27. prefect/cache_policies.py +3 -2
  28. prefect/client/orchestration/__init__.py +1957 -0
  29. prefect/client/orchestration/_artifacts/__init__.py +0 -0
  30. prefect/client/orchestration/_artifacts/client.py +239 -0
  31. prefect/client/orchestration/_automations/__init__.py +0 -0
  32. prefect/client/orchestration/_automations/client.py +329 -0
  33. prefect/client/orchestration/_blocks_documents/__init__.py +0 -0
  34. prefect/client/orchestration/_blocks_documents/client.py +334 -0
  35. prefect/client/orchestration/_blocks_schemas/__init__.py +0 -0
  36. prefect/client/orchestration/_blocks_schemas/client.py +200 -0
  37. prefect/client/orchestration/_blocks_types/__init__.py +0 -0
  38. prefect/client/orchestration/_blocks_types/client.py +380 -0
  39. prefect/client/orchestration/_concurrency_limits/__init__.py +0 -0
  40. prefect/client/orchestration/_concurrency_limits/client.py +762 -0
  41. prefect/client/orchestration/_deployments/__init__.py +0 -0
  42. prefect/client/orchestration/_deployments/client.py +1128 -0
  43. prefect/client/orchestration/_flow_runs/__init__.py +0 -0
  44. prefect/client/orchestration/_flow_runs/client.py +903 -0
  45. prefect/client/orchestration/_flows/__init__.py +0 -0
  46. prefect/client/orchestration/_flows/client.py +343 -0
  47. prefect/client/orchestration/_logs/__init__.py +0 -0
  48. prefect/client/orchestration/_logs/client.py +97 -0
  49. prefect/client/orchestration/_variables/__init__.py +0 -0
  50. prefect/client/orchestration/_variables/client.py +157 -0
  51. prefect/client/orchestration/base.py +46 -0
  52. prefect/client/orchestration/routes.py +145 -0
  53. prefect/client/schemas/__init__.py +68 -28
  54. prefect/client/schemas/actions.py +2 -2
  55. prefect/client/schemas/filters.py +5 -0
  56. prefect/client/schemas/objects.py +8 -15
  57. prefect/client/schemas/schedules.py +22 -10
  58. prefect/concurrency/_asyncio.py +87 -0
  59. prefect/concurrency/{events.py → _events.py} +10 -10
  60. prefect/concurrency/asyncio.py +20 -104
  61. prefect/concurrency/context.py +6 -4
  62. prefect/concurrency/services.py +26 -74
  63. prefect/concurrency/sync.py +23 -44
  64. prefect/concurrency/v1/_asyncio.py +63 -0
  65. prefect/concurrency/v1/{events.py → _events.py} +13 -15
  66. prefect/concurrency/v1/asyncio.py +27 -80
  67. prefect/concurrency/v1/context.py +6 -4
  68. prefect/concurrency/v1/services.py +33 -79
  69. prefect/concurrency/v1/sync.py +18 -37
  70. prefect/context.py +66 -45
  71. prefect/deployments/base.py +10 -144
  72. prefect/deployments/flow_runs.py +12 -2
  73. prefect/deployments/runner.py +53 -4
  74. prefect/deployments/steps/pull.py +13 -0
  75. prefect/engine.py +17 -4
  76. prefect/events/clients.py +7 -1
  77. prefect/events/schemas/events.py +3 -2
  78. prefect/filesystems.py +6 -2
  79. prefect/flow_engine.py +101 -85
  80. prefect/flows.py +10 -1
  81. prefect/input/run_input.py +2 -1
  82. prefect/logging/logging.yml +1 -1
  83. prefect/main.py +1 -3
  84. prefect/results.py +2 -307
  85. prefect/runner/runner.py +4 -2
  86. prefect/runner/storage.py +87 -21
  87. prefect/serializers.py +32 -25
  88. prefect/settings/legacy.py +4 -4
  89. prefect/settings/models/api.py +3 -3
  90. prefect/settings/models/cli.py +3 -3
  91. prefect/settings/models/client.py +5 -3
  92. prefect/settings/models/cloud.py +8 -3
  93. prefect/settings/models/deployments.py +3 -3
  94. prefect/settings/models/experiments.py +4 -7
  95. prefect/settings/models/flows.py +3 -3
  96. prefect/settings/models/internal.py +4 -2
  97. prefect/settings/models/logging.py +4 -3
  98. prefect/settings/models/results.py +3 -3
  99. prefect/settings/models/root.py +3 -2
  100. prefect/settings/models/runner.py +4 -4
  101. prefect/settings/models/server/api.py +3 -3
  102. prefect/settings/models/server/database.py +11 -4
  103. prefect/settings/models/server/deployments.py +6 -2
  104. prefect/settings/models/server/ephemeral.py +4 -2
  105. prefect/settings/models/server/events.py +3 -2
  106. prefect/settings/models/server/flow_run_graph.py +6 -2
  107. prefect/settings/models/server/root.py +3 -3
  108. prefect/settings/models/server/services.py +26 -11
  109. prefect/settings/models/server/tasks.py +6 -3
  110. prefect/settings/models/server/ui.py +3 -3
  111. prefect/settings/models/tasks.py +5 -5
  112. prefect/settings/models/testing.py +3 -3
  113. prefect/settings/models/worker.py +5 -3
  114. prefect/settings/profiles.py +15 -2
  115. prefect/states.py +61 -45
  116. prefect/task_engine.py +54 -75
  117. prefect/task_runners.py +56 -55
  118. prefect/task_worker.py +2 -2
  119. prefect/tasks.py +90 -36
  120. prefect/telemetry/bootstrap.py +10 -9
  121. prefect/telemetry/run_telemetry.py +13 -8
  122. prefect/telemetry/services.py +4 -0
  123. prefect/transactions.py +4 -15
  124. prefect/utilities/_git.py +34 -0
  125. prefect/utilities/asyncutils.py +1 -1
  126. prefect/utilities/engine.py +3 -19
  127. prefect/utilities/generics.py +18 -0
  128. prefect/utilities/templating.py +25 -1
  129. prefect/workers/base.py +6 -3
  130. prefect/workers/process.py +1 -1
  131. {prefect_client-3.1.10.dist-info → prefect_client-3.1.12.dist-info}/METADATA +2 -2
  132. {prefect_client-3.1.10.dist-info → prefect_client-3.1.12.dist-info}/RECORD +135 -109
  133. prefect/client/orchestration.py +0 -4523
  134. prefect/records/__init__.py +0 -1
  135. prefect/records/base.py +0 -235
  136. prefect/records/filesystem.py +0 -213
  137. prefect/records/memory.py +0 -184
  138. prefect/records/result_store.py +0 -70
  139. {prefect_client-3.1.10.dist-info → prefect_client-3.1.12.dist-info}/LICENSE +0 -0
  140. {prefect_client-3.1.10.dist-info → prefect_client-3.1.12.dist-info}/WHEEL +0 -0
  141. {prefect_client-3.1.10.dist-info → prefect_client-3.1.12.dist-info}/top_level.txt +0 -0
@@ -1,42 +1,30 @@
1
- import asyncio
1
+ from collections.abc import AsyncGenerator
2
2
  from contextlib import asynccontextmanager
3
- from typing import AsyncGenerator, List, Optional, Union, cast
3
+ from typing import TYPE_CHECKING, Optional, Union
4
4
  from uuid import UUID
5
5
 
6
6
  import anyio
7
- import httpx
8
7
  import pendulum
9
8
 
10
- from ...client.schemas.responses import MinimalConcurrencyLimitResponse
11
-
12
- try:
13
- from pendulum import Interval
14
- except ImportError:
15
- # pendulum < 3
16
- from pendulum.period import Period as Interval # type: ignore
17
-
18
- from prefect.client.orchestration import get_client
19
- from prefect.utilities.asyncutils import sync_compatible
20
-
21
- from .context import ConcurrencyContext
22
- from .events import (
23
- _emit_concurrency_acquisition_events,
24
- _emit_concurrency_release_events,
9
+ from prefect.concurrency.v1._asyncio import (
10
+ acquire_concurrency_slots,
11
+ release_concurrency_slots,
25
12
  )
26
- from .services import ConcurrencySlotAcquisitionService
27
-
28
-
29
- class ConcurrencySlotAcquisitionError(Exception):
30
- """Raised when an unhandlable occurs while acquiring concurrency slots."""
31
-
13
+ from prefect.concurrency.v1._events import (
14
+ emit_concurrency_acquisition_events,
15
+ emit_concurrency_release_events,
16
+ )
17
+ from prefect.concurrency.v1.context import ConcurrencyContext
32
18
 
33
- class AcquireConcurrencySlotTimeoutError(TimeoutError):
34
- """Raised when acquiring a concurrency slot times out."""
19
+ from ._asyncio import (
20
+ AcquireConcurrencySlotTimeoutError as AcquireConcurrencySlotTimeoutError,
21
+ )
22
+ from ._asyncio import ConcurrencySlotAcquisitionError as ConcurrencySlotAcquisitionError
35
23
 
36
24
 
37
25
  @asynccontextmanager
38
26
  async def concurrency(
39
- names: Union[str, List[str]],
27
+ names: Union[str, list[str]],
40
28
  task_run_id: UUID,
41
29
  timeout_seconds: Optional[float] = None,
42
30
  ) -> AsyncGenerator[None, None]:
@@ -69,24 +57,30 @@ async def concurrency(
69
57
  yield
70
58
  return
71
59
 
72
- names_normalized: List[str] = names if isinstance(names, list) else [names]
60
+ names_normalized: list[str] = names if isinstance(names, list) else [names]
73
61
 
74
- limits = await _acquire_concurrency_slots(
62
+ acquire_slots = acquire_concurrency_slots(
75
63
  names_normalized,
76
64
  task_run_id=task_run_id,
77
65
  timeout_seconds=timeout_seconds,
78
66
  )
67
+ if TYPE_CHECKING:
68
+ assert not isinstance(acquire_slots, list)
69
+ limits = await acquire_slots
79
70
  acquisition_time = pendulum.now("UTC")
80
- emitted_events = _emit_concurrency_acquisition_events(limits, task_run_id)
71
+ emitted_events = emit_concurrency_acquisition_events(limits, task_run_id)
81
72
 
82
73
  try:
83
74
  yield
84
75
  finally:
85
- occupancy_period = cast(Interval, (pendulum.now("UTC") - acquisition_time))
76
+ occupancy_period = pendulum.now("UTC") - acquisition_time
86
77
  try:
87
- await _release_concurrency_slots(
78
+ release_slots = release_concurrency_slots(
88
79
  names_normalized, task_run_id, occupancy_period.total_seconds()
89
80
  )
81
+ if TYPE_CHECKING:
82
+ assert not isinstance(release_slots, list)
83
+ await release_slots
90
84
  except anyio.get_cancelled_exc_class():
91
85
  # The task was cancelled before it could release the slots. Add the
92
86
  # slots to the cleanup list so they can be released when the
@@ -96,51 +90,4 @@ async def concurrency(
96
90
  (names_normalized, occupancy_period.total_seconds(), task_run_id)
97
91
  )
98
92
 
99
- _emit_concurrency_release_events(limits, emitted_events, task_run_id)
100
-
101
-
102
- @sync_compatible
103
- async def _acquire_concurrency_slots(
104
- names: List[str],
105
- task_run_id: UUID,
106
- timeout_seconds: Optional[float] = None,
107
- ) -> List[MinimalConcurrencyLimitResponse]:
108
- service = ConcurrencySlotAcquisitionService.instance(frozenset(names))
109
- future = service.send((task_run_id, timeout_seconds))
110
- response_or_exception = await asyncio.wrap_future(future)
111
-
112
- if isinstance(response_or_exception, Exception):
113
- if isinstance(response_or_exception, TimeoutError):
114
- raise AcquireConcurrencySlotTimeoutError(
115
- f"Attempt to acquire concurrency limits timed out after {timeout_seconds} second(s)"
116
- ) from response_or_exception
117
-
118
- raise ConcurrencySlotAcquisitionError(
119
- f"Unable to acquire concurrency limits {names!r}"
120
- ) from response_or_exception
121
-
122
- return _response_to_concurrency_limit_response(response_or_exception)
123
-
124
-
125
- @sync_compatible
126
- async def _release_concurrency_slots(
127
- names: List[str],
128
- task_run_id: UUID,
129
- occupancy_seconds: float,
130
- ) -> List[MinimalConcurrencyLimitResponse]:
131
- async with get_client() as client:
132
- response = await client.decrement_v1_concurrency_slots(
133
- names=names,
134
- task_run_id=task_run_id,
135
- occupancy_seconds=occupancy_seconds,
136
- )
137
- return _response_to_concurrency_limit_response(response)
138
-
139
-
140
- def _response_to_concurrency_limit_response(
141
- response: httpx.Response,
142
- ) -> List[MinimalConcurrencyLimitResponse]:
143
- data = response.json() or []
144
- return [
145
- MinimalConcurrencyLimitResponse.model_validate(limit) for limit in data if data
146
- ]
93
+ emit_concurrency_release_events(limits, emitted_events, task_run_id)
@@ -1,20 +1,22 @@
1
1
  from contextvars import ContextVar
2
- from typing import List, Tuple
2
+ from typing import Any, ClassVar
3
3
  from uuid import UUID
4
4
 
5
+ from typing_extensions import Self
6
+
5
7
  from prefect.client.orchestration import get_client
6
8
  from prefect.context import ContextModel, Field
7
9
 
8
10
 
9
11
  class ConcurrencyContext(ContextModel):
10
- __var__: ContextVar = ContextVar("concurrency_v1")
12
+ __var__: ClassVar[ContextVar[Self]] = ContextVar("concurrency_v1")
11
13
 
12
14
  # Track the limits that have been acquired but were not able to be released
13
15
  # due to cancellation or some other error. These limits are released when
14
16
  # the context manager exits.
15
- cleanup_slots: List[Tuple[List[str], float, UUID]] = Field(default_factory=list)
17
+ cleanup_slots: list[tuple[list[str], float, UUID]] = Field(default_factory=list)
16
18
 
17
- def __exit__(self, *exc_info):
19
+ def __exit__(self, *exc_info: Any) -> None:
18
20
  if self.cleanup_slots:
19
21
  with get_client(sync_client=True) as client:
20
22
  for names, occupancy_seconds, task_run_id in self.cleanup_slots:
@@ -1,21 +1,16 @@
1
1
  import asyncio
2
- import concurrent.futures
2
+ from collections.abc import AsyncGenerator
3
3
  from contextlib import asynccontextmanager
4
4
  from json import JSONDecodeError
5
- from typing import (
6
- TYPE_CHECKING,
7
- AsyncGenerator,
8
- FrozenSet,
9
- Optional,
10
- Tuple,
11
- )
5
+ from typing import TYPE_CHECKING, Optional
12
6
  from uuid import UUID
13
7
 
14
8
  import httpx
15
9
  from starlette import status
10
+ from typing_extensions import Unpack
16
11
 
17
12
  from prefect._internal.concurrency import logger
18
- from prefect._internal.concurrency.services import QueueService
13
+ from prefect._internal.concurrency.services import FutureQueueService
19
14
  from prefect.client.orchestration import get_client
20
15
  from prefect.utilities.timeout import timeout_async
21
16
 
@@ -27,11 +22,13 @@ class ConcurrencySlotAcquisitionServiceError(Exception):
27
22
  """Raised when an error occurs while acquiring concurrency slots."""
28
23
 
29
24
 
30
- class ConcurrencySlotAcquisitionService(QueueService):
31
- def __init__(self, concurrency_limit_names: FrozenSet[str]):
25
+ class ConcurrencySlotAcquisitionService(
26
+ FutureQueueService[Unpack[tuple[UUID, Optional[float]]], httpx.Response]
27
+ ):
28
+ def __init__(self, concurrency_limit_names: frozenset[str]) -> None:
32
29
  super().__init__(concurrency_limit_names)
33
- self._client: "PrefectClient"
34
- self.concurrency_limit_names = sorted(list(concurrency_limit_names))
30
+ self._client: PrefectClient
31
+ self.concurrency_limit_names: list[str] = sorted(list(concurrency_limit_names))
35
32
 
36
33
  @asynccontextmanager
37
34
  async def _lifespan(self) -> AsyncGenerator[None, None]:
@@ -39,78 +36,35 @@ class ConcurrencySlotAcquisitionService(QueueService):
39
36
  self._client = client
40
37
  yield
41
38
 
42
- async def _handle(
43
- self,
44
- item: Tuple[
45
- UUID,
46
- concurrent.futures.Future,
47
- Optional[float],
48
- ],
49
- ) -> None:
50
- task_run_id, future, timeout_seconds = item
51
- try:
52
- response = await self.acquire_slots(task_run_id, timeout_seconds)
53
- except Exception as exc:
54
- # If the request to the increment endpoint fails in a non-standard
55
- # way, we need to set the future's result so that the caller can
56
- # handle the exception and then re-raise.
57
- future.set_result(exc)
58
- raise exc
59
- else:
60
- future.set_result(response)
61
-
62
- async def acquire_slots(
63
- self,
64
- task_run_id: UUID,
65
- timeout_seconds: Optional[float] = None,
39
+ async def acquire(
40
+ self, task_run_id: UUID, timeout_seconds: Optional[float] = None
66
41
  ) -> httpx.Response:
67
42
  with timeout_async(seconds=timeout_seconds):
68
43
  while True:
69
44
  try:
70
- response = await self._client.increment_v1_concurrency_slots(
45
+ return await self._client.increment_v1_concurrency_slots(
71
46
  task_run_id=task_run_id,
72
47
  names=self.concurrency_limit_names,
73
48
  )
74
- except Exception as exc:
75
- if (
76
- isinstance(exc, httpx.HTTPStatusError)
77
- and exc.response.status_code == status.HTTP_423_LOCKED
78
- ):
79
- retry_after = exc.response.headers.get("Retry-After")
80
- if retry_after:
81
- retry_after = float(retry_after)
82
- await asyncio.sleep(retry_after)
83
- else:
84
- # We received a 423 but no Retry-After header. This
85
- # should indicate that the server told us to abort
86
- # because the concurrency limit is set to 0, i.e.
87
- # effectively disabled.
88
- try:
89
- reason = exc.response.json()["detail"]
90
- except (JSONDecodeError, KeyError):
91
- logger.error(
92
- "Failed to parse response from concurrency limit 423 Locked response: %s",
93
- exc.response.content,
94
- )
95
- reason = "Concurrency limit is locked (server did not specify the reason)"
96
- raise ConcurrencySlotAcquisitionServiceError(
97
- reason
98
- ) from exc
49
+ except httpx.HTTPStatusError as exc:
50
+ if not exc.response.status_code == status.HTTP_423_LOCKED:
51
+ raise
99
52
 
53
+ retry_after = exc.response.headers.get("Retry-After")
54
+ if retry_after:
55
+ retry_after = float(retry_after)
56
+ await asyncio.sleep(retry_after)
100
57
  else:
101
- raise exc # type: ignore
102
- else:
103
- return response
104
-
105
- def send(self, item: Tuple[UUID, Optional[float]]) -> concurrent.futures.Future:
106
- with self._lock:
107
- if self._stopped:
108
- raise RuntimeError("Cannot put items in a stopped service instance.")
109
-
110
- logger.debug("Service %r enqueuing item %r", self, item)
111
- future: concurrent.futures.Future = concurrent.futures.Future()
112
-
113
- task_run_id, timeout_seconds = item
114
- self._queue.put_nowait((task_run_id, future, timeout_seconds))
115
-
116
- return future
58
+ # We received a 423 but no Retry-After header. This
59
+ # should indicate that the server told us to abort
60
+ # because the concurrency limit is set to 0, i.e.
61
+ # effectively disabled.
62
+ try:
63
+ reason = exc.response.json()["detail"]
64
+ except (JSONDecodeError, KeyError):
65
+ logger.error(
66
+ "Failed to parse response from concurrency limit 423 Locked response: %s",
67
+ exc.response.content,
68
+ )
69
+ reason = "Concurrency limit is locked (server did not specify the reason)"
70
+ raise ConcurrencySlotAcquisitionServiceError(reason) from exc
@@ -1,31 +1,15 @@
1
+ import asyncio
2
+ from collections.abc import Generator
1
3
  from contextlib import contextmanager
2
- from typing import (
3
- Generator,
4
- List,
5
- Optional,
6
- TypeVar,
7
- Union,
8
- cast,
9
- )
4
+ from typing import Optional, TypeVar, Union
10
5
  from uuid import UUID
11
6
 
12
7
  import pendulum
13
8
 
14
- from ...client.schemas.responses import MinimalConcurrencyLimitResponse
15
-
16
- try:
17
- from pendulum import Interval
18
- except ImportError:
19
- # pendulum < 3
20
- from pendulum.period import Period as Interval # type: ignore
21
-
22
- from .asyncio import (
23
- _acquire_concurrency_slots,
24
- _release_concurrency_slots,
25
- )
26
- from .events import (
27
- _emit_concurrency_acquisition_events,
28
- _emit_concurrency_release_events,
9
+ from ._asyncio import acquire_concurrency_slots, release_concurrency_slots
10
+ from ._events import (
11
+ emit_concurrency_acquisition_events,
12
+ emit_concurrency_release_events,
29
13
  )
30
14
 
31
15
  T = TypeVar("T")
@@ -33,7 +17,7 @@ T = TypeVar("T")
33
17
 
34
18
  @contextmanager
35
19
  def concurrency(
36
- names: Union[str, List[str]],
20
+ names: Union[str, list[str]],
37
21
  task_run_id: UUID,
38
22
  timeout_seconds: Optional[float] = None,
39
23
  ) -> Generator[None, None, None]:
@@ -69,23 +53,20 @@ def concurrency(
69
53
 
70
54
  names = names if isinstance(names, list) else [names]
71
55
 
72
- limits: List[MinimalConcurrencyLimitResponse] = _acquire_concurrency_slots(
73
- names,
74
- timeout_seconds=timeout_seconds,
75
- task_run_id=task_run_id,
76
- _sync=True,
56
+ force = {"_sync": True}
57
+ result = acquire_concurrency_slots(
58
+ names, timeout_seconds=timeout_seconds, task_run_id=task_run_id, **force
77
59
  )
60
+ assert not asyncio.iscoroutine(result)
61
+ limits = result
78
62
  acquisition_time = pendulum.now("UTC")
79
- emitted_events = _emit_concurrency_acquisition_events(limits, task_run_id)
63
+ emitted_events = emit_concurrency_acquisition_events(limits, task_run_id)
80
64
 
81
65
  try:
82
66
  yield
83
67
  finally:
84
- occupancy_period = cast(Interval, pendulum.now("UTC") - acquisition_time)
85
- _release_concurrency_slots(
86
- names,
87
- task_run_id,
88
- occupancy_period.total_seconds(),
89
- _sync=True,
68
+ occupancy_period = pendulum.now("UTC") - acquisition_time
69
+ release_concurrency_slots(
70
+ names, task_run_id, occupancy_period.total_seconds(), **force
90
71
  )
91
- _emit_concurrency_release_events(limits, emitted_events, task_run_id)
72
+ emit_concurrency_release_events(limits, emitted_events, task_run_id)
prefect/context.py CHANGED
@@ -12,7 +12,7 @@ import warnings
12
12
  from collections.abc import AsyncGenerator, Generator, Mapping
13
13
  from contextlib import ExitStack, asynccontextmanager, contextmanager
14
14
  from contextvars import ContextVar, Token
15
- from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union
15
+ from typing import TYPE_CHECKING, Any, Callable, ClassVar, Optional, TypeVar, Union
16
16
 
17
17
  from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
18
18
  from typing_extensions import Self
@@ -47,17 +47,11 @@ if TYPE_CHECKING:
47
47
  from prefect.flows import Flow
48
48
  from prefect.tasks import Task
49
49
 
50
- # Define the global settings context variable
51
- # This will be populated downstream but must be null here to facilitate loading the
52
- # default settings.
53
- GLOBAL_SETTINGS_CONTEXT = None # type: ignore
54
-
55
50
 
56
51
  def serialize_context() -> dict[str, Any]:
57
52
  """
58
53
  Serialize the current context for use in a remote execution environment.
59
54
  """
60
-
61
55
  flow_run_context = EngineContext.get()
62
56
  task_run_context = TaskRunContext.get()
63
57
  tags_context = TagsContext.get()
@@ -75,7 +69,22 @@ def serialize_context() -> dict[str, Any]:
75
69
  def hydrated_context(
76
70
  serialized_context: Optional[dict[str, Any]] = None,
77
71
  client: Union[PrefectClient, SyncPrefectClient, None] = None,
78
- ):
72
+ ) -> Generator[None, Any, None]:
73
+ # We need to rebuild the models because we might be hydrating in a remote
74
+ # environment where the models are not available.
75
+ # TODO: Remove this once we have fixed our circular imports and we don't need to rebuild models any more.
76
+ from prefect.flows import Flow
77
+ from prefect.results import ResultRecordMetadata
78
+ from prefect.tasks import Task
79
+
80
+ _types: dict[str, Any] = dict(
81
+ Flow=Flow,
82
+ Task=Task,
83
+ ResultRecordMetadata=ResultRecordMetadata,
84
+ )
85
+ FlowRunContext.model_rebuild(_types_namespace=_types)
86
+ TaskRunContext.model_rebuild(_types_namespace=_types)
87
+
79
88
  with ExitStack() as stack:
80
89
  if serialized_context:
81
90
  # Set up settings context
@@ -112,10 +121,15 @@ class ContextModel(BaseModel):
112
121
  a context manager
113
122
  """
114
123
 
124
+ if TYPE_CHECKING:
125
+ # subclasses can pass through keyword arguments to the pydantic base model
126
+ def __init__(self, **kwargs: Any) -> None:
127
+ ...
128
+
115
129
  # The context variable for storing data must be defined by the child class
116
- __var__: ContextVar[Self]
130
+ __var__: ClassVar[ContextVar[Self]]
117
131
  _token: Optional[Token[Self]] = PrivateAttr(None)
118
- model_config = ConfigDict(
132
+ model_config: ClassVar[ConfigDict] = ConfigDict(
119
133
  arbitrary_types_allowed=True,
120
134
  extra="forbid",
121
135
  )
@@ -128,7 +142,7 @@ class ContextModel(BaseModel):
128
142
  self._token = self.__var__.set(self)
129
143
  return self
130
144
 
131
- def __exit__(self, *_):
145
+ def __exit__(self, *_: Any) -> None:
132
146
  if not self._token:
133
147
  raise RuntimeError(
134
148
  "Asymmetric use of context. Context exit called without an enter."
@@ -143,7 +157,7 @@ class ContextModel(BaseModel):
143
157
 
144
158
  def model_copy(
145
159
  self: Self, *, update: Optional[Mapping[str, Any]] = None, deep: bool = False
146
- ):
160
+ ) -> Self:
147
161
  """
148
162
  Duplicate the context model, optionally choosing which fields to include, exclude, or change.
149
163
 
@@ -191,19 +205,19 @@ class SyncClientContext(ContextModel):
191
205
  assert c1 is ctx.client
192
206
  """
193
207
 
194
- __var__: ContextVar[Self] = ContextVar("sync-client-context")
208
+ __var__: ClassVar[ContextVar[Self]] = ContextVar("sync-client-context")
195
209
  client: SyncPrefectClient
196
210
  _httpx_settings: Optional[dict[str, Any]] = PrivateAttr(None)
197
211
  _context_stack: int = PrivateAttr(0)
198
212
 
199
- def __init__(self, httpx_settings: Optional[dict[str, Any]] = None):
213
+ def __init__(self, httpx_settings: Optional[dict[str, Any]] = None) -> None:
200
214
  super().__init__(
201
- client=get_client(sync_client=True, httpx_settings=httpx_settings), # type: ignore[reportCallIssue]
215
+ client=get_client(sync_client=True, httpx_settings=httpx_settings),
202
216
  )
203
217
  self._httpx_settings = httpx_settings
204
218
  self._context_stack = 0
205
219
 
206
- def __enter__(self):
220
+ def __enter__(self) -> Self:
207
221
  self._context_stack += 1
208
222
  if self._context_stack == 1:
209
223
  self.client.__enter__()
@@ -212,20 +226,20 @@ class SyncClientContext(ContextModel):
212
226
  else:
213
227
  return self
214
228
 
215
- def __exit__(self, *exc_info: Any):
229
+ def __exit__(self, *exc_info: Any) -> None:
216
230
  self._context_stack -= 1
217
231
  if self._context_stack == 0:
218
- self.client.__exit__(*exc_info) # type: ignore[reportUnknownMemberType]
219
- return super().__exit__(*exc_info) # type: ignore[reportUnknownMemberType]
232
+ self.client.__exit__(*exc_info)
233
+ return super().__exit__(*exc_info)
220
234
 
221
235
  @classmethod
222
236
  @contextmanager
223
- def get_or_create(cls) -> Generator["SyncClientContext", None, None]:
224
- ctx = SyncClientContext.get()
237
+ def get_or_create(cls) -> Generator[Self, None, None]:
238
+ ctx = cls.get()
225
239
  if ctx:
226
240
  yield ctx
227
241
  else:
228
- with SyncClientContext() as ctx:
242
+ with cls() as ctx:
229
243
  yield ctx
230
244
 
231
245
 
@@ -249,14 +263,14 @@ class AsyncClientContext(ContextModel):
249
263
  assert c1 is ctx.client
250
264
  """
251
265
 
252
- __var__ = ContextVar("async-client-context")
266
+ __var__: ClassVar[ContextVar[Self]] = ContextVar("async-client-context")
253
267
  client: PrefectClient
254
268
  _httpx_settings: Optional[dict[str, Any]] = PrivateAttr(None)
255
269
  _context_stack: int = PrivateAttr(0)
256
270
 
257
271
  def __init__(self, httpx_settings: Optional[dict[str, Any]] = None):
258
272
  super().__init__(
259
- client=get_client(sync_client=False, httpx_settings=httpx_settings), # type: ignore[reportCallIssue]
273
+ client=get_client(sync_client=False, httpx_settings=httpx_settings)
260
274
  )
261
275
  self._httpx_settings = httpx_settings
262
276
  self._context_stack = 0
@@ -273,8 +287,8 @@ class AsyncClientContext(ContextModel):
273
287
  async def __aexit__(self: Self, *exc_info: Any) -> None:
274
288
  self._context_stack -= 1
275
289
  if self._context_stack == 0:
276
- await self.client.__aexit__(*exc_info) # type: ignore[reportUnknownMemberType]
277
- return super().__exit__(*exc_info) # type: ignore[reportUnknownMemberType]
290
+ await self.client.__aexit__(*exc_info)
291
+ return super().__exit__(*exc_info)
278
292
 
279
293
  @classmethod
280
294
  @asynccontextmanager
@@ -297,7 +311,7 @@ class RunContext(ContextModel):
297
311
  client: The Prefect client instance being used for API communication
298
312
  """
299
313
 
300
- def __init__(self, *args: Any, **kwargs: Any):
314
+ def __init__(self, *args: Any, **kwargs: Any) -> None:
301
315
  super().__init__(*args, **kwargs)
302
316
 
303
317
  start_client_metrics_server()
@@ -356,7 +370,7 @@ class EngineContext(RunContext):
356
370
  # Events worker to emit events
357
371
  events: Optional[EventsWorker] = None
358
372
 
359
- __var__: ContextVar[Self] = ContextVar("flow_run")
373
+ __var__: ClassVar[ContextVar[Self]] = ContextVar("flow_run")
360
374
 
361
375
  def serialize(self: Self, include_secrets: bool = True) -> dict[str, Any]:
362
376
  return self.model_dump(
@@ -398,7 +412,7 @@ class TaskRunContext(RunContext):
398
412
  result_store: ResultStore
399
413
  persist_result: bool = Field(default_factory=get_default_persist_setting_for_tasks)
400
414
 
401
- __var__ = ContextVar("task_run")
415
+ __var__: ClassVar[ContextVar[Self]] = ContextVar("task_run")
402
416
 
403
417
  def serialize(self: Self, include_secrets: bool = True) -> dict[str, Any]:
404
418
  return self.model_dump(
@@ -429,11 +443,11 @@ class TagsContext(ContextModel):
429
443
  current_tags: set[str] = Field(default_factory=set)
430
444
 
431
445
  @classmethod
432
- def get(cls) -> "TagsContext":
446
+ def get(cls) -> Self:
433
447
  # Return an empty `TagsContext` instead of `None` if no context exists
434
- return cls.__var__.get(TagsContext())
448
+ return cls.__var__.get(cls())
435
449
 
436
- __var__: ContextVar[Self] = ContextVar("tags")
450
+ __var__: ClassVar[ContextVar[Self]] = ContextVar("tags")
437
451
 
438
452
 
439
453
  class SettingsContext(ContextModel):
@@ -450,15 +464,21 @@ class SettingsContext(ContextModel):
450
464
  profile: Profile
451
465
  settings: Settings
452
466
 
453
- __var__: ContextVar[Self] = ContextVar("settings")
467
+ __var__: ClassVar[ContextVar[Self]] = ContextVar("settings")
454
468
 
455
469
  def __hash__(self: Self) -> int:
456
470
  return hash(self.settings)
457
471
 
458
472
  @classmethod
459
- def get(cls) -> "SettingsContext":
473
+ def get(cls) -> Optional["SettingsContext"]:
460
474
  # Return the global context instead of `None` if no context exists
461
- return super().get() or GLOBAL_SETTINGS_CONTEXT
475
+ try:
476
+ return super().get() or GLOBAL_SETTINGS_CONTEXT
477
+ except NameError:
478
+ # GLOBAL_SETTINGS_CONTEXT has not yet been set; in order to create
479
+ # it profiles need to be loaded, and that process calls
480
+ # SettingsContext.get().
481
+ return None
462
482
 
463
483
 
464
484
  def get_run_context() -> Union[FlowRunContext, TaskRunContext]:
@@ -559,10 +579,10 @@ def tags(*new_tags: str) -> Generator[set[str], None, None]:
559
579
 
560
580
  @contextmanager
561
581
  def use_profile(
562
- profile: Union[Profile, str, Any],
582
+ profile: Union[Profile, str],
563
583
  override_environment_variables: bool = False,
564
584
  include_current_context: bool = True,
565
- ):
585
+ ) -> Generator[SettingsContext, Any, None]:
566
586
  """
567
587
  Switch to a profile for the duration of this context.
568
588
 
@@ -584,11 +604,12 @@ def use_profile(
584
604
  profiles = prefect.settings.load_profiles()
585
605
  profile = profiles[profile]
586
606
 
587
- if not isinstance(profile, Profile):
588
- raise TypeError(
589
- f"Unexpected type {type(profile).__name__!r} for `profile`. "
590
- "Expected 'str' or 'Profile'."
591
- )
607
+ if not TYPE_CHECKING:
608
+ if not isinstance(profile, Profile):
609
+ raise TypeError(
610
+ f"Unexpected type {type(profile).__name__!r} for `profile`. "
611
+ "Expected 'str' or 'Profile'."
612
+ )
592
613
 
593
614
  # Create a copy of the profiles settings as we will mutate it
594
615
  profile_settings = profile.settings.copy()
@@ -609,7 +630,7 @@ def use_profile(
609
630
  yield ctx
610
631
 
611
632
 
612
- def root_settings_context():
633
+ def root_settings_context() -> SettingsContext:
613
634
  """
614
635
  Return the settings context that will exist as the root context for the module.
615
636
 
@@ -659,9 +680,9 @@ def root_settings_context():
659
680
  # an override in the `SettingsContext.get` method.
660
681
 
661
682
 
662
- GLOBAL_SETTINGS_CONTEXT: SettingsContext = root_settings_context() # type: ignore[reportConstantRedefinition]
683
+ GLOBAL_SETTINGS_CONTEXT: SettingsContext = root_settings_context()
663
684
 
664
685
 
665
686
  # 2024-07-02: This surfaces an actionable error message for removed objects
666
687
  # in Prefect 3.0 upgrade.
667
- __getattr__ = getattr_migration(__name__)
688
+ __getattr__: Callable[[str], Any] = getattr_migration(__name__)