prefect-client 2.14.9__py3-none-any.whl → 2.14.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. prefect/__init__.py +4 -1
  2. prefect/_internal/pydantic/v2_schema.py +9 -2
  3. prefect/client/orchestration.py +51 -4
  4. prefect/client/schemas/objects.py +16 -1
  5. prefect/deployments/runner.py +34 -3
  6. prefect/engine.py +302 -25
  7. prefect/events/clients.py +216 -5
  8. prefect/events/filters.py +214 -0
  9. prefect/exceptions.py +4 -0
  10. prefect/flows.py +16 -0
  11. prefect/infrastructure/base.py +106 -1
  12. prefect/infrastructure/container.py +52 -0
  13. prefect/infrastructure/kubernetes.py +64 -0
  14. prefect/infrastructure/process.py +38 -0
  15. prefect/infrastructure/provisioners/__init__.py +2 -0
  16. prefect/infrastructure/provisioners/cloud_run.py +206 -34
  17. prefect/infrastructure/provisioners/container_instance.py +1080 -0
  18. prefect/infrastructure/provisioners/ecs.py +483 -48
  19. prefect/input/__init__.py +11 -0
  20. prefect/input/actions.py +88 -0
  21. prefect/input/run_input.py +107 -0
  22. prefect/runner/runner.py +5 -0
  23. prefect/runner/server.py +92 -8
  24. prefect/runner/utils.py +92 -0
  25. prefect/settings.py +34 -9
  26. prefect/states.py +26 -3
  27. prefect/utilities/dockerutils.py +31 -0
  28. prefect/utilities/processutils.py +5 -2
  29. prefect/utilities/services.py +10 -0
  30. prefect/utilities/validation.py +63 -0
  31. prefect/workers/__init__.py +1 -0
  32. prefect/workers/block.py +226 -0
  33. prefect/workers/utilities.py +2 -2
  34. {prefect_client-2.14.9.dist-info → prefect_client-2.14.11.dist-info}/METADATA +2 -1
  35. {prefect_client-2.14.9.dist-info → prefect_client-2.14.11.dist-info}/RECORD +38 -30
  36. {prefect_client-2.14.9.dist-info → prefect_client-2.14.11.dist-info}/LICENSE +0 -0
  37. {prefect_client-2.14.9.dist-info → prefect_client-2.14.11.dist-info}/WHEEL +0 -0
  38. {prefect_client-2.14.9.dist-info → prefect_client-2.14.11.dist-info}/top_level.txt +0 -0
prefect/events/clients.py CHANGED
@@ -1,12 +1,42 @@
1
1
  import abc
2
2
  import asyncio
3
3
  from types import TracebackType
4
- from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type
5
-
4
+ from typing import (
5
+ TYPE_CHECKING,
6
+ Any,
7
+ ClassVar,
8
+ Dict,
9
+ List,
10
+ Mapping,
11
+ Optional,
12
+ Tuple,
13
+ Type,
14
+ )
15
+ from uuid import UUID
16
+
17
+ import orjson
18
+ import pendulum
19
+
20
+ try:
21
+ from cachetools import TTLCache
22
+ except ImportError:
23
+ pass
24
+ from starlette.status import WS_1008_POLICY_VIOLATION
6
25
  from websockets.client import WebSocketClientProtocol, connect
7
- from websockets.exceptions import ConnectionClosed
26
+ from websockets.exceptions import (
27
+ ConnectionClosed,
28
+ ConnectionClosedError,
29
+ ConnectionClosedOK,
30
+ )
8
31
 
9
32
  from prefect.events import Event
33
+ from prefect.logging import get_logger
34
+ from prefect.settings import PREFECT_API_KEY, PREFECT_API_URL
35
+
36
+ if TYPE_CHECKING:
37
+ from prefect.events.filters import EventFilter
38
+
39
+ logger = get_logger(__name__)
10
40
 
11
41
 
12
42
  class EventsClient(abc.ABC):
@@ -79,6 +109,20 @@ class AssertingEventsClient(EventsClient):
79
109
  return self
80
110
 
81
111
 
112
+ def _get_api_url_and_key(
113
+ api_url: Optional[str], api_key: Optional[str]
114
+ ) -> Tuple[str, str]:
115
+ api_url = api_url or PREFECT_API_URL.value()
116
+ api_key = api_key or PREFECT_API_KEY.value()
117
+
118
+ if not api_url or not api_key:
119
+ raise ValueError(
120
+ "api_url and api_key must be provided or set in the Prefect configuration"
121
+ )
122
+
123
+ return api_url, api_key
124
+
125
+
82
126
  class PrefectCloudEventsClient(EventsClient):
83
127
  """A Prefect Events client that streams Events to a Prefect Cloud Workspace"""
84
128
 
@@ -87,8 +131,8 @@ class PrefectCloudEventsClient(EventsClient):
87
131
 
88
132
  def __init__(
89
133
  self,
90
- api_url: str,
91
- api_key: str,
134
+ api_url: str = None,
135
+ api_key: str = None,
92
136
  reconnection_attempts: int = 10,
93
137
  checkpoint_every: int = 20,
94
138
  ):
@@ -101,6 +145,8 @@ class PrefectCloudEventsClient(EventsClient):
101
145
  checkpoint_every: How often the client should sync with the server to
102
146
  confirm receipt of all previously sent events
103
147
  """
148
+ api_url, api_key = _get_api_url_and_key(api_url, api_key)
149
+
104
150
  socket_url = (
105
151
  api_url.replace("https://", "wss://")
106
152
  .replace("http://", "ws://")
@@ -195,3 +241,168 @@ class PrefectCloudEventsClient(EventsClient):
195
241
  # a standard load balancer timeout, but after that, just take a
196
242
  # beat to let things come back around.
197
243
  await asyncio.sleep(1)
244
+
245
+
246
# Maximum number of recently yielded event IDs that PrefectCloudEventSubscriber
# remembers for de-duplication (the ``maxsize`` of its TTLCache).
SEEN_EVENTS_SIZE = 500000
# How long a remembered event ID is kept (the ``ttl`` of that TTLCache; seconds
# with cachetools' default monotonic timer).
SEEN_EVENTS_TTL = 120
+
249
+
250
class PrefectCloudEventSubscriber:
    """
    Subscribes to a Prefect Cloud event stream, yielding events as they occur.

    Events are de-duplicated by ID, so an event that the server replays after a
    reconnection is only yielded once.

    Example:

        from prefect.events.clients import PrefectCloudEventSubscriber
        from prefect.events.filters import EventFilter, EventNameFilter

        filter = EventFilter(event=EventNameFilter(prefix=["prefect.flow-run."]))

        async with PrefectCloudEventSubscriber(api_url, api_key, filter) as subscriber:
            async for event in subscriber:
                print(event.occurred, event.resource.id, event.event)

    """

    _websocket: Optional[WebSocketClientProtocol]
    _filter: "EventFilter"
    # A TTLCache at runtime; it is written to in __anext__ and used as a set of
    # the IDs of recently yielded events.
    _seen_events: Mapping[UUID, bool]

    def __init__(
        self,
        api_url: str = None,
        api_key: str = None,
        filter: "EventFilter" = None,
        reconnection_attempts: int = 10,
    ):
        """
        Args:
            api_url: The base URL for a Prefect Cloud workspace; defaults to the
                PREFECT_API_URL setting
            api_key: The API key of an actor with the manage_events scope;
                defaults to the PREFECT_API_KEY setting
            filter: Which events to receive; defaults to an empty EventFilter.
                Its `occurred` criteria are overwritten on every (re)connection.
            reconnection_attempts: When the client is disconnected, how many times
                the client should attempt to reconnect

        Raises:
            ValueError: if no API URL or API key is given or configured
            ImportError: if the optional `cachetools` package is not installed
        """
        api_url, api_key = _get_api_url_and_key(api_url, api_key)

        from prefect.events.filters import EventFilter

        # `cachetools` is an optional dependency (its module-level import is
        # allowed to fail), so fail here with a clear message rather than a
        # NameError on `TTLCache`.
        try:
            from cachetools import TTLCache
        except ImportError as exc:
            raise ImportError(
                "PrefectCloudEventSubscriber requires the `cachetools` package."
            ) from exc

        self._filter = filter or EventFilter()
        self._seen_events = TTLCache(maxsize=SEEN_EVENTS_SIZE, ttl=SEEN_EVENTS_TTL)

        # http(s)://<api_url> -> ws(s)://<api_url>/events/out
        socket_url = (
            api_url.replace("https://", "wss://")
            .replace("http://", "ws://")
            .rstrip("/")
        ) + "/events/out"

        logger.debug("Connecting to %s", socket_url)

        self._api_key = api_key
        # The socket itself is opened in `_reconnect` via `__aenter__`.
        self._connect = connect(
            socket_url,
            subprotocols=["prefect"],
        )
        self._websocket = None
        self._reconnection_attempts = reconnection_attempts

    async def __aenter__(self) -> "PrefectCloudEventSubscriber":
        # Don't handle any errors in the initial connection, because these are most
        # likely a permission or configuration issue that should propagate
        await self._reconnect()
        return self

    async def _reconnect(self) -> None:
        """
        (Re)open the websocket, authenticate, and send the event filter.

        Any existing connection is closed first. The filter's `occurred` window
        is reset to start one minute in the past, so events emitted while
        disconnected are replayed; `__anext__` drops the duplicates.

        Raises:
            Exception: if the server rejects the API key
        """
        auth_failure_message = (
            "Unable to authenticate to the event stream. Please ensure the "
            "provided api_key you are using is valid for this environment."
        )

        logger.debug("Reconnecting...")
        if self._websocket:
            self._websocket = None
            await self._connect.__aexit__(None, None, None)

        self._websocket = await self._connect.__aenter__()

        # make sure we have actually connected
        logger.debug(" pinging...")
        pong = await self._websocket.ping()
        await pong

        logger.debug(" authenticating...")
        await self._websocket.send(
            orjson.dumps({"type": "auth", "token": self._api_key}).decode()
        )

        try:
            message = orjson.loads(await self._websocket.recv())
        except ConnectionClosedError as e:
            # The server closes with a policy violation when the key is rejected
            if e.code == WS_1008_POLICY_VIOLATION:
                raise Exception(auth_failure_message) from e
            raise

        logger.debug(" auth result %s", message)
        # An explicit check rather than `assert`, which is stripped under `-O`
        # and would let an unauthenticated connection proceed.
        if message["type"] != "auth_success":
            raise Exception(auth_failure_message)

        from prefect.events.filters import EventOccurredFilter

        # Rewind one minute to catch events emitted while disconnected; the
        # far-future `until` (one year out) keeps the subscription live.
        self._filter.occurred = EventOccurredFilter(
            since=pendulum.now("UTC").subtract(minutes=1),
            until=pendulum.now("UTC").add(years=1),
        )

        logger.debug(" filtering events since %s...", self._filter.occurred.since)
        filter_message = {
            "type": "filter",
            "filter": self._filter.dict(json_compatible=True),
        }
        await self._websocket.send(orjson.dumps(filter_message).decode())

    async def __aexit__(
        self,
        exc_type: Optional[Type[Exception]],
        exc_val: Optional[Exception],
        exc_tb: Optional[TracebackType],
    ) -> None:
        self._websocket = None
        await self._connect.__aexit__(exc_type, exc_val, exc_tb)

    def __aiter__(self) -> "PrefectCloudEventSubscriber":
        return self

    async def __anext__(self) -> Event:
        """
        Return the next not-yet-seen event, reconnecting on dropped connections.

        Raises:
            StopAsyncIteration: when the server closes the connection normally
            ConnectionClosed: when all reconnection attempts are exhausted
        """
        for i in range(self._reconnection_attempts + 1):
            try:
                # If we're here and the websocket is None, then we've had a failure in a
                # previous reconnection attempt.
                #
                # Otherwise, after the first time through this loop, we're recovering
                # from a ConnectionClosed, so reconnect now.
                if not self._websocket or i > 0:
                    await self._reconnect()
                assert self._websocket

                while True:
                    message = orjson.loads(await self._websocket.recv())
                    event = Event.parse_obj(message["event"])

                    # Skip events already yielded (e.g. replayed after a reconnect)
                    if event.id in self._seen_events:
                        continue
                    self._seen_events[event.id] = True

                    return event
            except ConnectionClosedOK:
                logger.debug('Connection closed with "OK" status')
                raise StopAsyncIteration
            except ConnectionClosed:
                logger.debug(
                    "Connection closed with %s/%s attempts",
                    i + 1,
                    self._reconnection_attempts,
                )
                if i == self._reconnection_attempts:
                    # this was our final chance, raise the most recent error
                    raise

                if i > 2:
                    # let the first two attempts happen quickly in case this is just
                    # a standard load balancer timeout, but after that, just take a
                    # beat to let things come back around.
                    await asyncio.sleep(1)
prefect/events/filters.py ADDED
@@ -0,0 +1,214 @@
1
+ from typing import List, Optional, Tuple, cast
2
+ from uuid import UUID
3
+
4
+ import pendulum
5
+
6
+ from prefect._internal.pydantic import HAS_PYDANTIC_V2
7
+ from prefect.events.schemas import Event, Resource, ResourceSpecification
8
+ from prefect.server.utilities.schemas import DateTimeTZ, PrefectBaseModel
9
+
10
+ if HAS_PYDANTIC_V2:
11
+ from pydantic.v1 import Field
12
+ else:
13
+ from pydantic import Field
14
+
15
+
16
class EventDataFilter(PrefectBaseModel):
    """A base class for filtering event data."""

    class Config:
        # Reject unknown keys so misspelled filter criteria raise a validation
        # error instead of being ignored.
        extra = "forbid"

    def get_filters(self) -> List["EventDataFilter"]:
        """
        Return the nested filters that are set on this filter.

        Only fields whose (inner) type is an `EventDataFilter` subclass are
        considered, and fields whose value is falsy (e.g. None) are skipped.
        """
        nested_filters = []
        for name, field in self.__fields__.items():
            field_type = field.type_
            # `field.type_` is not always a class: `Optional[List[Tuple[str, str]]]`
            # yields the generic alias `Tuple[str, str]`, for which `issubclass`
            # raises TypeError. Only real classes can be nested filters.
            if not (
                isinstance(field_type, type) and issubclass(field_type, EventDataFilter)
            ):
                continue
            value = getattr(self, name)
            if value:
                nested_filters.append(value)
        return nested_filters

    def includes(self, event: Event) -> bool:
        """Does the given event match the criteria of this filter?"""
        return all(filter.includes(event) for filter in self.get_filters())

    def excludes(self, event: Event) -> bool:
        """Would the given filter exclude this event?"""
        return not self.includes(event)
40
+
41
+
42
class EventOccurredFilter(EventDataFilter):
    """Matches events whose `occurred` timestamp lies within [since, until]."""

    # Defaults to the window from the start of the day 180 days ago (UTC) to now.
    since: DateTimeTZ = Field(
        default_factory=lambda: cast(
            DateTimeTZ,
            pendulum.now("UTC").start_of("day").subtract(days=180),
        ),
        description="Only include events after this time (inclusive)",
    )
    until: DateTimeTZ = Field(
        default_factory=lambda: cast(DateTimeTZ, pendulum.now("UTC")),
        description="Only include events prior to this time (inclusive)",
    )

    def includes(self, event: Event) -> bool:
        occurred_at = event.occurred
        if not self.since <= occurred_at:
            return False
        return occurred_at <= self.until
57
+
58
+
59
class EventNameFilter(EventDataFilter):
    """Matches events by name (`event.event`), by prefix and/or exact name."""

    prefix: Optional[List[str]] = Field(
        default=None, description="Only include events matching one of these prefixes"
    )
    exclude_prefix: Optional[List[str]] = Field(
        default=None, description="Exclude events matching one of these prefixes"
    )

    name: Optional[List[str]] = Field(
        default=None,
        description="Only include events matching one of these names exactly",
    )
    exclude_name: Optional[List[str]] = Field(
        default=None, description="Exclude events matching one of these names exactly"
    )

    def includes(self, event: Event) -> bool:
        event_name = event.event

        # Every criterion that is set must pass; unset or empty ones are ignored.
        if self.prefix and not event_name.startswith(tuple(self.prefix)):
            return False
        if self.exclude_prefix and event_name.startswith(tuple(self.exclude_prefix)):
            return False
        if self.name and event_name not in self.name:
            return False
        if self.exclude_name and event_name in self.exclude_name:
            return False
        return True
92
+
93
+
94
class EventResourceFilter(EventDataFilter):
    """Matches events by their primary resource (`event.resource`)."""

    id: Optional[List[str]] = Field(
        default=None, description="Only include events for resources with these IDs"
    )
    id_prefix: Optional[List[str]] = Field(
        default=None,
        description=(
            "Only include events for resources with IDs starting with these prefixes."
        ),
    )
    labels: Optional[ResourceSpecification] = Field(
        default=None, description="Only include events for resources with these labels"
    )

    def includes(self, event: Event) -> bool:
        resource = event.resource

        # Every criterion that is set must pass; unset or empty ones are ignored.
        if self.id and resource.id not in self.id:
            return False
        if self.id_prefix and not resource.id.startswith(tuple(self.id_prefix)):
            return False
        if self.labels and not self.labels.matches(resource):
            return False
        return True
124
+
125
+
126
class EventRelatedFilter(EventDataFilter):
    """Criteria on an event's related resources."""

    # NOTE(review): this class does not override `includes`, so these criteria
    # are not applied by client-side matching; presumably they are evaluated by
    # the server — confirm.

    id: Optional[List[str]] = Field(
        default=None,
        description="Only include events for related resources with these IDs",
    )
    role: Optional[List[str]] = Field(
        default=None,
        description="Only include events for related resources in these roles",
    )
    # Presumably (resource id, role) pairs — confirm against the server schema.
    resources_in_roles: Optional[List[Tuple[str, str]]] = Field(
        default=None,
        description=(
            "Only include events with specific related resources in specific roles"
        ),
    )
    labels: Optional[ResourceSpecification] = Field(
        default=None,
        description="Only include events for related resources with these labels",
    )
142
+
143
+
144
class EventAnyResourceFilter(EventDataFilter):
    """
    Matches events where the primary resource OR any related resource satisfies
    every criterion that is set.
    """

    id: Optional[List[str]] = Field(
        None, description="Only include events for resources with these IDs"
    )
    id_prefix: Optional[List[str]] = Field(
        None,
        description=(
            "Only include events for resources with IDs starting with these prefixes"
        ),
    )
    # Applies to the primary resource as well as related ones, so the
    # description must not say "related resources" only.
    labels: Optional[ResourceSpecification] = Field(
        None, description="Only include events for resources with these labels"
    )

    def includes(self, event: Event) -> bool:
        """True if the primary resource or any related resource matches."""
        candidates = [event.resource] + event.related
        return any(self._includes(resource) for resource in candidates)

    def _includes(self, resource: Resource) -> bool:
        """Does a single resource satisfy every criterion that is set?"""
        if self.id and resource.id not in self.id:
            return False
        if self.id_prefix and not resource.id.startswith(tuple(self.id_prefix)):
            return False
        if self.labels and not self.labels.matches(resource):
            return False
        return True
178
+
179
+
180
class EventIDFilter(EventDataFilter):
    """Matches events by their ID; matches everything when no IDs are given."""

    id: Optional[List[UUID]] = Field(
        default=None, description="Only include events with one of these IDs"
    )

    def includes(self, event: Event) -> bool:
        return not self.id or event.id in self.id
191
+
192
+
193
class EventFilter(EventDataFilter):
    """
    The top-level event filter.

    Each nested filter that is set must include an event for the event to
    match. `occurred` and `id` are always present (built from their own
    defaults); the others are optional.
    """

    occurred: EventOccurredFilter = Field(
        default_factory=EventOccurredFilter,
        description="Filter criteria for when the events occurred",
    )
    event: Optional[EventNameFilter] = Field(
        default=None,
        description="Filter criteria for the event name",
    )
    any_resource: Optional[EventAnyResourceFilter] = Field(
        default=None,
        description="Filter criteria for any resource involved in the event",
    )
    resource: Optional[EventResourceFilter] = Field(
        default=None,
        description="Filter criteria for the resource of the event",
    )
    related: Optional[EventRelatedFilter] = Field(
        default=None,
        description="Filter criteria for the related resources of the event",
    )
    id: EventIDFilter = Field(
        default_factory=EventIDFilter,
        description="Filter criteria for the events' ID",
    )
prefect/exceptions.py CHANGED
@@ -296,6 +296,10 @@ class Pause(PrefectSignal):
296
296
  Raised when a flow run is PAUSED and needs to exit for resubmission.
297
297
  """
298
298
 
299
def __init__(self, *args, state=None, **kwargs):
    """
    Initialize the pause signal.

    Args:
        *args: Positional arguments forwarded to the parent initializer.
        state: Optional object attached to the signal as `self.state`
            (presumably the flow run's state at the time of pausing — confirm
            against callers).
        **kwargs: Keyword arguments forwarded to the parent initializer.
    """
    super().__init__(*args, **kwargs)
    self.state = state
302
+
299
303
 
300
304
  class ExternalSignal(BaseException):
301
305
  """
prefect/flows.py CHANGED
@@ -556,6 +556,7 @@ class Flow(Generic[P, R]):
556
556
  cron: Optional[str] = None,
557
557
  rrule: Optional[str] = None,
558
558
  schedule: Optional[SCHEDULE_TYPES] = None,
559
+ is_schedule_active: Optional[bool] = None,
559
560
  parameters: Optional[dict] = None,
560
561
  triggers: Optional[List[DeploymentTrigger]] = None,
561
562
  description: Optional[str] = None,
@@ -578,6 +579,9 @@ class Flow(Generic[P, R]):
578
579
  timezone: A timezone to use for the schedule. Defaults to UTC.
579
580
  triggers: A list of triggers that will kick off runs of this deployment.
580
581
  schedule: A schedule object defining when to execute runs of this deployment.
582
+ is_schedule_active: Whether or not to set the schedule for this deployment as active. If
583
+ not provided when creating a deployment, the schedule will be set as active. If not
584
+ provided when updating a deployment, the schedule's activation will not be changed.
581
585
  parameters: A dictionary of default parameter values to pass to runs of this deployment.
582
586
  description: A description for the created deployment. Defaults to the flow's
583
587
  description if not provided.
@@ -623,6 +627,7 @@ class Flow(Generic[P, R]):
623
627
  cron=cron,
624
628
  rrule=rrule,
625
629
  schedule=schedule,
630
+ is_schedule_active=is_schedule_active,
626
631
  tags=tags,
627
632
  triggers=triggers,
628
633
  parameters=parameters or {},
@@ -641,6 +646,7 @@ class Flow(Generic[P, R]):
641
646
  cron=cron,
642
647
  rrule=rrule,
643
648
  schedule=schedule,
649
+ is_schedule_active=is_schedule_active,
644
650
  tags=tags,
645
651
  triggers=triggers,
646
652
  parameters=parameters or {},
@@ -660,6 +666,7 @@ class Flow(Generic[P, R]):
660
666
  cron: Optional[str] = None,
661
667
  rrule: Optional[str] = None,
662
668
  schedule: Optional[SCHEDULE_TYPES] = None,
669
+ is_schedule_active: Optional[bool] = None,
663
670
  triggers: Optional[List[DeploymentTrigger]] = None,
664
671
  parameters: Optional[dict] = None,
665
672
  description: Optional[str] = None,
@@ -682,6 +689,9 @@ class Flow(Generic[P, R]):
682
689
  triggers: A list of triggers that will kick off runs of this deployment.
683
690
  schedule: A schedule object defining when to execute runs of this deployment. Used to
684
691
  define additional scheduling options like `timezone`.
692
+ is_schedule_active: Whether or not to set the schedule for this deployment as active. If
693
+ not provided when creating a deployment, the schedule will be set as active. If not
694
+ provided when updating a deployment, the schedule's activation will not be changed.
685
695
  parameters: A dictionary of default parameter values to pass to runs of this deployment.
686
696
  description: A description for the created deployment. Defaults to the flow's
687
697
  description if not provided.
@@ -738,6 +748,7 @@ class Flow(Generic[P, R]):
738
748
  cron=cron,
739
749
  rrule=rrule,
740
750
  schedule=schedule,
751
+ is_schedule_active=is_schedule_active,
741
752
  parameters=parameters,
742
753
  description=description,
743
754
  tags=tags,
@@ -852,6 +863,7 @@ class Flow(Generic[P, R]):
852
863
  cron: Optional[str] = None,
853
864
  rrule: Optional[str] = None,
854
865
  schedule: Optional[SCHEDULE_TYPES] = None,
866
+ is_schedule_active: Optional[bool] = None,
855
867
  triggers: Optional[List[DeploymentTrigger]] = None,
856
868
  parameters: Optional[dict] = None,
857
869
  description: Optional[str] = None,
@@ -891,6 +903,9 @@ class Flow(Generic[P, R]):
891
903
  triggers: A list of triggers that will kick off runs of this deployment.
892
904
  schedule: A schedule object defining when to execute runs of this deployment. Used to
893
905
  define additional scheduling options like `timezone`.
906
+ is_schedule_active: Whether or not to set the schedule for this deployment as active. If
907
+ not provided when creating a deployment, the schedule will be set as active. If not
908
+ provided when updating a deployment, the schedule's activation will not be changed.
894
909
  parameters: A dictionary of default parameter values to pass to runs of this deployment.
895
910
  description: A description for the created deployment. Defaults to the flow's
896
911
  description if not provided.
@@ -956,6 +971,7 @@ class Flow(Generic[P, R]):
956
971
  cron=cron,
957
972
  rrule=rrule,
958
973
  schedule=schedule,
974
+ is_schedule_active=is_schedule_active,
959
975
  triggers=triggers,
960
976
  parameters=parameters,
961
977
  description=description,
prefect/infrastructure/base.py CHANGED
@@ -11,22 +11,27 @@ from prefect._internal.compatibility.experimental import (
11
11
  experiment_enabled,
12
12
  )
13
13
  from prefect._internal.pydantic import HAS_PYDANTIC_V2
14
+ from prefect.client.schemas.actions import WorkPoolCreate
15
+ from prefect.exceptions import ObjectAlreadyExists
14
16
 
15
17
  if HAS_PYDANTIC_V2:
16
18
  import pydantic.v1 as pydantic
17
19
  else:
18
20
  import pydantic
19
21
 
22
+ from rich.console import Console
20
23
  from typing_extensions import Self
21
24
 
22
25
  import prefect
23
- from prefect.blocks.core import Block
26
+ from prefect.blocks.core import Block, BlockNotSavedError
24
27
  from prefect.logging import get_logger
25
28
  from prefect.settings import (
26
29
  PREFECT_EXPERIMENTAL_WARN,
27
30
  PREFECT_EXPERIMENTAL_WARN_ENHANCED_CANCELLATION,
31
+ PREFECT_UI_URL,
28
32
  get_current_settings,
29
33
  )
34
+ from prefect.utilities.asyncutils import sync_compatible
30
35
 
31
36
  MIN_COMPAT_PREFECT_VERSION = "2.0b12"
32
37
 
@@ -66,6 +71,106 @@ class Infrastructure(Block, abc.ABC):
66
71
  description="The command to run in the infrastructure.",
67
72
  )
68
73
 
74
async def generate_work_pool_base_job_template(self):
    """
    Build a work pool base job template that delegates job creation to this block.

    The template exposes a single required ``block`` variable. Its JSON schema
    is this block class's schema, and its default references this saved block
    document by ID.

    Returns:
        A dict with ``job_configuration`` and ``variables`` keys.

    Raises:
        BlockNotSavedError: if the block has not been saved yet.
    """
    document_id = self._block_document_id
    if document_id is None:
        raise BlockNotSavedError(
            "Cannot publish as work pool, block has not been saved. Please call"
            " `.save()` on your block before publishing."
        )

    block_class = self.__class__
    class_schema = block_class.schema()
    class_name = block_class.__name__

    block_variable = {
        "title": "Block",
        "description": "The infrastructure block to use for job creation.",
        "allOf": [{"$ref": f"#/definitions/{class_name}"}],
        "default": {"$ref": {"block_document_id": str(document_id)}},
    }
    variables = {
        "type": "object",
        "properties": {"block": block_variable},
        "required": ["block"],
        "definitions": {class_name: class_schema},
    }
    return {
        "job_configuration": {"block": "{{ block }}"},
        "variables": variables,
    }
102
+
103
def get_corresponding_worker_type(self):
    """Return the work pool type used when publishing this block as a work pool."""
    worker_type = "block"
    return worker_type
105
+
106
@sync_compatible
async def publish_as_work_pool(self, work_pool_name: Optional[str] = None):
    """
    Creates a work pool configured to use the given block as the job creator.

    Used to migrate from an agent-based setup to a worker-based setup. Progress
    and next steps are printed to the console. If a work pool with the chosen
    name already exists, a message is printed and nothing is created.

    Args:
        work_pool_name: The name to give to the created work pool. If not provided, the name of the current
            block will be used.

    Raises:
        BlockNotSavedError: if the block has not been saved
        ValueError: if no work pool name is given and the block has no name
    """

    base_job_template = await self.generate_work_pool_base_job_template()
    work_pool_name = work_pool_name or self._block_document_name

    if work_pool_name is None:
        raise ValueError(
            "`work_pool_name` must be provided if the block has not been saved."
        )

    console = Console()

    try:
        async with prefect.get_client() as client:
            work_pool = await client.create_work_pool(
                work_pool=WorkPoolCreate(
                    name=work_pool_name,
                    type=self.get_corresponding_worker_type(),
                    base_job_template=base_job_template,
                )
            )
    except ObjectAlreadyExists:
        console.print(
            (
                f"Work pool with name {work_pool_name!r} already exists, please use"
                " a different name."
            ),
            style="red",
        )
        return

    console.print(
        f"Work pool {work_pool.name} created!",
        style="green",
    )
    if PREFECT_UI_URL:
        console.print(
            "You can see your new work pool in the UI at"
            f" {PREFECT_UI_URL.value()}/work-pools/work-pool/{work_pool.name}"
        )

    # Blocks with an `image` attribute get an image-based deployment example;
    # others get a deploy-from-source example. Both must interpolate the pool
    # name, so both strings are f-strings.
    if hasattr(self, "image"):
        deploy_script = (
            f"my_flow.deploy(work_pool_name='{work_pool.name}', image='my_image:tag')"
        )
    else:
        deploy_script = (
            "my_flow.from_source(source='https://github.com/org/repo.git',"
            f" entrypoint='flow.py:my_flow').deploy(work_pool_name='{work_pool.name}')"
        )
    console.print(
        "\nYou can deploy a flow to this work pool by calling"
        f" [blue].deploy[/]:\n\n\t{deploy_script}\n"
    )
    console.print(
        "\nTo start a worker to execute flow runs in this work pool run:\n"
    )
    console.print(f"\t[blue]prefect worker start --pool {work_pool.name}[/]\n")
173
+
69
174
  @abc.abstractmethod
70
175
  async def run(
71
176
  self,