prefect-client 2.14.8__py3-none-any.whl → 2.14.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- prefect/__init__.py +4 -1
- prefect/client/orchestration.py +1 -2
- prefect/deployments/runner.py +5 -1
- prefect/deployments/steps/core.py +2 -4
- prefect/engine.py +176 -11
- prefect/events/clients.py +216 -5
- prefect/events/filters.py +214 -0
- prefect/exceptions.py +4 -0
- prefect/infrastructure/base.py +106 -1
- prefect/infrastructure/container.py +52 -0
- prefect/infrastructure/process.py +38 -0
- prefect/infrastructure/provisioners/__init__.py +6 -2
- prefect/infrastructure/provisioners/cloud_run.py +7 -1
- prefect/infrastructure/provisioners/container_instance.py +797 -0
- prefect/infrastructure/provisioners/ecs.py +907 -0
- prefect/runner/runner.py +14 -0
- prefect/runner/storage.py +12 -2
- prefect/states.py +26 -3
- prefect/utilities/services.py +10 -0
- prefect/workers/__init__.py +1 -0
- prefect/workers/block.py +226 -0
- prefect/workers/utilities.py +2 -1
- {prefect_client-2.14.8.dist-info → prefect_client-2.14.10.dist-info}/METADATA +2 -1
- {prefect_client-2.14.8.dist-info → prefect_client-2.14.10.dist-info}/RECORD +27 -23
- {prefect_client-2.14.8.dist-info → prefect_client-2.14.10.dist-info}/LICENSE +0 -0
- {prefect_client-2.14.8.dist-info → prefect_client-2.14.10.dist-info}/WHEEL +0 -0
- {prefect_client-2.14.8.dist-info → prefect_client-2.14.10.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,907 @@
|
|
1
|
+
import contextlib
|
2
|
+
import contextvars
|
3
|
+
import importlib
|
4
|
+
import ipaddress
|
5
|
+
import json
|
6
|
+
import shlex
|
7
|
+
import sys
|
8
|
+
from copy import deepcopy
|
9
|
+
from functools import partial
|
10
|
+
from textwrap import dedent
|
11
|
+
from typing import Any, Callable, Dict, List, Optional
|
12
|
+
|
13
|
+
import anyio
|
14
|
+
from anyio import run_process
|
15
|
+
from rich.console import Console
|
16
|
+
from rich.panel import Panel
|
17
|
+
from rich.progress import Progress, SpinnerColumn, TextColumn
|
18
|
+
from rich.prompt import Confirm
|
19
|
+
|
20
|
+
from prefect.client.orchestration import PrefectClient
|
21
|
+
from prefect.client.schemas.actions import BlockDocumentCreate
|
22
|
+
from prefect.client.utilities import inject_client
|
23
|
+
from prefect.exceptions import ObjectNotFound
|
24
|
+
from prefect.utilities.collections import get_from_dict
|
25
|
+
from prefect.utilities.importtools import lazy_import
|
26
|
+
|
27
|
+
# boto3 is only needed when provisioning ECS infrastructure, so it is imported
# lazily. `ElasticContainerServicePushProvisioner._prompt_boto3_installation`
# rebinds this global after installing boto3.
boto3 = lazy_import("boto3")

# Console used by the resources' `provision` methods for status output. It is
# swapped via `console_context` so messages render through an active progress bar.
current_console = contextvars.ContextVar("console", default=Console())
|
30
|
+
|
31
|
+
|
32
|
+
@contextlib.contextmanager
def console_context(value: Console):
    """
    Make `value` the console returned by `current_console.get()` inside a
    `with` block.

    The previously active console is restored on exit, whether the block
    finishes normally or raises.

    Args:
        value: The console to activate for the duration of the context.
    """
    previous_console_token = current_console.set(value)
    try:
        yield
    finally:
        # Undo exactly the `set` performed above.
        current_console.reset(previous_console_token)
|
39
|
+
|
40
|
+
|
41
|
+
class IamPolicyResource:
    """
    Represents an IAM policy resource for managing ECS tasks.

    Args:
        policy_name: The name of the IAM policy. Defaults to "prefect-ecs-policy".
    """

    def __init__(
        self,
        policy_name: str = "prefect-ecs-policy",
    ):
        self._iam_client = boto3.client("iam")
        self._policy_name = policy_name
        self._policy_document = json.dumps(
            {
                "Version": "2012-10-17",
                "Statement": [
                    {
                        "Sid": "PrefectEcsPolicy",
                        "Effect": "Allow",
                        "Action": [
                            "ec2:AuthorizeSecurityGroupIngress",
                            "ec2:CreateSecurityGroup",
                            "ec2:CreateTags",
                            "ec2:DescribeNetworkInterfaces",
                            "ec2:DescribeSecurityGroups",
                            "ec2:DescribeSubnets",
                            "ec2:DescribeVpcs",
                            "ecs:CreateCluster",
                            "ecs:DeregisterTaskDefinition",
                            "ecs:DescribeClusters",
                            "ecs:DescribeTaskDefinition",
                            "ecs:DescribeTasks",
                            "ecs:ListAccountSettings",
                            "ecs:ListClusters",
                            "ecs:ListTaskDefinitions",
                            "ecs:RegisterTaskDefinition",
                            "ecs:RunTask",
                            "ecs:StopTask",
                            "logs:CreateLogStream",
                            "logs:PutLogEvents",
                            "logs:DescribeLogGroups",
                            "logs:GetLogEvents",
                        ],
                        "Resource": "*",
                    }
                ],
            }
        )

        # Cached answer of `requires_provisioning`; None means "not checked yet".
        self._requires_provisioning = None
        # ARN of the policy once it is known to exist (found or created).
        self._policy_arn: Optional[str] = None

    async def get_task_count(self) -> int:
        """
        Returns the number of tasks that will be executed to provision this resource.

        Returns:
            int: The number of tasks to be provisioned.
        """
        return 1 if await self.requires_provisioning() else 0

    def _get_policy_by_name(self, name):
        """
        Find a customer-managed (`Scope="Local"`) IAM policy by name.

        Walks every page of `list_policies`.

        Returns:
            The matching policy entry from the `Policies` list, or None if no
            policy has that name.
        """
        paginator = self._iam_client.get_paginator("list_policies")
        page_iterator = paginator.paginate(Scope="Local")

        for page in page_iterator:
            for policy in page["Policies"]:
                if policy["PolicyName"] == name:
                    return policy
        return None

    async def requires_provisioning(self) -> bool:
        """
        Check if this resource requires provisioning.

        Provisioning is required when no customer-managed policy with the configured
        name exists. The answer is cached after the first check. When the policy is
        found, its ARN is remembered for `provision`.

        Returns:
            bool: True if provisioning is required, False otherwise.
        """
        if self._requires_provisioning is not None:
            return self._requires_provisioning
        policy = await anyio.to_thread.run_sync(
            partial(self._get_policy_by_name, self._policy_name)
        )
        if policy is not None:
            self._policy_arn = policy["Arn"]
            self._requires_provisioning = False
        else:
            self._requires_provisioning = True
        return self._requires_provisioning

    async def get_planned_actions(self) -> List[str]:
        """
        Returns a description of the planned actions for provisioning this resource.

        Returns:
            List[str]: Descriptions of the planned actions for provisioning the
                resource, or an empty list if provisioning is not required.
        """
        if await self.requires_provisioning():
            return [
                "Creating and attaching an IAM policy for managing ECS tasks:"
                f" [blue]{self._policy_name}[/]"
            ]
        return []

    async def provision(
        self,
        advance: Callable[[], None],
    ) -> str:
        """
        Provisions an IAM policy.

        Creates the policy if it does not exist yet. Either way, the policy's ARN is
        returned, so callers can attach an already existing policy to a new user.

        Args:
            advance: A callback function to indicate progress. It is only called
                when a policy is actually created.

        Returns:
            str: The ARN (Amazon Resource Name) of the created or existing IAM policy.

        Raises:
            RuntimeError: If the policy is reported as existing but cannot be found.
        """
        if await self.requires_provisioning():
            console = current_console.get()
            console.print("Creating IAM policy")
            policy = await anyio.to_thread.run_sync(
                partial(
                    self._iam_client.create_policy,
                    PolicyName=self._policy_name,
                    PolicyDocument=self._policy_document,
                )
            )
            self._policy_arn = policy["Policy"]["Arn"]
            # The policy now exists; a repeated call must not try to create it again.
            self._requires_provisioning = False
            advance()
            return self._policy_arn

        if self._policy_arn is None:
            # The "no provisioning needed" answer was obtained without a lookup
            # (e.g. state set externally), so look the existing policy up now.
            policy = await anyio.to_thread.run_sync(
                partial(self._get_policy_by_name, self._policy_name)
            )
            if policy is None:
                raise RuntimeError(
                    f"Unable to find expected IAM policy {self._policy_name!r}"
                )
            self._policy_arn = policy["Arn"]
        return self._policy_arn
|
174
|
+
|
175
|
+
|
176
|
+
class IamUserResource:
    """
    Represents an IAM user resource for managing ECS tasks.

    Args:
        user_name: The desired name of the IAM user.
    """

    def __init__(self, user_name: str):
        self._iam_client = boto3.client("iam")
        self._user_name = user_name
        # Cached answer of `requires_provisioning`; None means "not checked yet".
        self._requires_provisioning = None

    async def get_task_count(self) -> int:
        """
        Returns the number of tasks that will be executed to provision this resource.

        Returns:
            int: The number of tasks to be provisioned.
        """
        return 1 if await self.requires_provisioning() else 0

    async def requires_provisioning(self) -> bool:
        """
        Check if this resource requires provisioning.

        Provisioning is required when `get_user` raises `NoSuchEntityException` for
        the configured user name. The answer is cached after the first check.

        Returns:
            bool: True if provisioning is required, False otherwise.
        """
        if self._requires_provisioning is None:
            try:
                await anyio.to_thread.run_sync(
                    partial(self._iam_client.get_user, UserName=self._user_name)
                )
                self._requires_provisioning = False
            except self._iam_client.exceptions.NoSuchEntityException:
                self._requires_provisioning = True

        return self._requires_provisioning

    async def get_planned_actions(self) -> List[str]:
        """
        Returns a description of the planned actions for provisioning this resource.

        Returns:
            List[str]: Descriptions of the planned actions for provisioning the
                resource, or an empty list if provisioning is not required.
        """
        if await self.requires_provisioning():
            return [
                "Creating an IAM user for managing ECS tasks:"
                f" [blue]{self._user_name}[/]"
            ]
        return []

    async def provision(
        self,
        advance: Callable[[], None],
    ):
        """
        Provisions an IAM user.

        Creates the user only if it does not exist yet.

        Args:
            advance: A callback function to indicate progress. It is only called
                when the user is actually created.
        """
        if await self.requires_provisioning():
            console = current_console.get()
            console.print("Provisioning IAM user")
            await anyio.to_thread.run_sync(
                partial(self._iam_client.create_user, UserName=self._user_name)
            )
            # The user now exists; a repeated call must not try to create it again.
            self._requires_provisioning = False
            advance()
|
248
|
+
|
249
|
+
|
250
|
+
class CredentialsBlockResource:
    """
    Represents a Prefect `aws-credentials` block holding an access key for the
    IAM user that manages ECS tasks.

    Args:
        user_name: The name of the IAM user to generate an access key for.
        block_document_name: The name of the `aws-credentials` block document to
            create, or to reuse if it already exists.
    """

    def __init__(self, user_name: str, block_document_name: str):
        self._block_document_name = block_document_name
        self._user_name = user_name
        # Cached answer of `requires_provisioning`; None means "not checked yet".
        self._requires_provisioning = None

    async def get_task_count(self):
        """
        Returns the number of tasks that will be executed to provision this resource.

        Returns:
            int: The number of tasks to be provisioned.
        """
        return 2 if await self.requires_provisioning() else 0

    @inject_client
    async def requires_provisioning(
        self, client: Optional[PrefectClient] = None
    ) -> bool:
        """
        Check if this resource requires provisioning.

        Provisioning is required when no `aws-credentials` block document with the
        configured name exists. The answer is cached after the first check.

        Args:
            client: A Prefect client; injected by `inject_client` when not given.

        Returns:
            bool: True if provisioning is required, False otherwise.
        """
        if self._requires_provisioning is None:
            assert client is not None, "Client injection failed"
            try:
                await client.read_block_document_by_name(
                    self._block_document_name, "aws-credentials"
                )
                self._requires_provisioning = False
            except ObjectNotFound:
                self._requires_provisioning = True
        return self._requires_provisioning

    async def get_planned_actions(self) -> List[str]:
        """
        Returns a description of the planned actions for provisioning this resource.

        Returns:
            List[str]: Descriptions of the planned actions for provisioning the
                resource, or an empty list if provisioning is not required.
        """
        if await self.requires_provisioning():
            return ["Storing generated AWS credentials in a block"]
        return []

    @inject_client
    async def provision(
        self,
        base_job_template: Dict[str, Any],
        advance: Callable[[], None],
        client: Optional[PrefectClient] = None,
    ):
        """
        Provisions an AWS credentials block.

        Will generate new credentials if the block does not already exist. Updates
        the `aws_credentials` variable in the job template to reference the block.

        Args:
            base_job_template: The base job template; modified in place.
            advance: A callback function to indicate progress.
            client: A Prefect client to use for interacting with the Prefect API.

        Raises:
            RuntimeError: If the `aws-credentials` block type is not registered.
                This is checked before any access key is generated.
        """
        assert client is not None, "Client injection failed"
        if not await self.requires_provisioning(client=client):
            block_doc = await client.read_block_document_by_name(
                self._block_document_name, "aws-credentials"
            )
        else:
            console = current_console.get()

            # Resolve the block type and schema *before* creating an access key.
            # Otherwise a missing `prefect-aws` registration would leave an orphaned
            # IAM access key behind.
            try:
                credentials_block_type = await client.read_block_type_by_slug(
                    "aws-credentials"
                )
            except ObjectNotFound as exc:
                raise RuntimeError(
                    dedent(
                        """\
                    Unable to find block type "aws-credentials".
                    To register the `aws-credentials` block type, run:

                        pip install prefect-aws
                        prefect blocks register -m prefect_aws

                    """
                    )
                ) from exc

            credentials_block_schema = (
                await client.get_most_recent_block_schema_for_block_type(
                    block_type_id=credentials_block_type.id
                )
            )
            assert (
                credentials_block_schema is not None
            ), f"Unable to find schema for block type {credentials_block_type.slug}"

            console.print("Generating AWS credentials")
            iam_client = boto3.client("iam")
            access_key_data = await anyio.to_thread.run_sync(
                partial(iam_client.create_access_key, UserName=self._user_name)
            )
            access_key = access_key_data["AccessKey"]
            advance()

            console.print("Creating AWS credentials block")
            block_doc = await client.create_block_document(
                block_document=BlockDocumentCreate(
                    name=self._block_document_name,
                    data={
                        "aws_access_key_id": access_key["AccessKeyId"],
                        "aws_secret_access_key": access_key["SecretAccessKey"],
                        "region_name": boto3.session.Session().region_name,
                    },
                    block_type_id=credentials_block_type.id,
                    block_schema_id=credentials_block_schema.id,
                )
            )
            # The block now exists; a repeated call must not generate new keys.
            self._requires_provisioning = False
            advance()

        base_job_template["variables"]["properties"]["aws_credentials"]["default"] = {
            "$ref": {"block_document_id": str(block_doc.id)}
        }
|
370
|
+
|
371
|
+
|
372
|
+
class AuthenticationResource:
    """
    Groups the IAM user, IAM policy and AWS credentials block that let a work
    pool authenticate against AWS.

    Args:
        work_pool_name: Name of the work pool. The credentials block document is
            named `<work_pool_name>-aws-credentials`.
        user_name: Name of the IAM user. Defaults to "prefect-ecs-user".
        policy_name: Name of the IAM policy. Defaults to "prefect-ecs-policy".
    """

    def __init__(
        self,
        work_pool_name: str,
        user_name: str = "prefect-ecs-user",
        policy_name: str = "prefect-ecs-policy",
    ):
        self._user_name = user_name
        self._policy_name = policy_name
        self._iam_user_resource = IamUserResource(user_name=user_name)
        self._iam_policy_resource = IamPolicyResource(policy_name=policy_name)
        block_document_name = f"{work_pool_name}-aws-credentials"
        self._credentials_block_resource = CredentialsBlockResource(
            user_name=user_name, block_document_name=block_document_name
        )

    @property
    def resources(self):
        """The sub-resources: IAM user, IAM policy, then credentials block."""
        return [
            self._iam_user_resource,
            self._iam_policy_resource,
            self._credentials_block_resource,
        ]

    async def get_task_count(self):
        """
        Count the provisioning tasks across all sub-resources.

        Returns:
            int: The total number of tasks to be provisioned.
        """
        total = 0
        for sub_resource in self.resources:
            total += await sub_resource.get_task_count()
        return total

    async def requires_provisioning(self) -> bool:
        """
        Check whether any sub-resource requires provisioning.

        Every sub-resource is checked (no short-circuit) before the answers are
        combined.

        Returns:
            bool: True if at least one sub-resource needs provisioning.
        """
        answers = []
        for sub_resource in self.resources:
            answers.append(await sub_resource.requires_provisioning())
        return any(answers)

    async def get_planned_actions(self) -> List[str]:
        """
        Collect the planned actions of all sub-resources, in resource order.

        Returns:
            List[str]: The planned action descriptions; empty if nothing needs
                provisioning.
        """
        planned: List[str] = []
        for sub_resource in self.resources:
            planned.extend(await sub_resource.get_planned_actions())
        return planned

    async def provision(
        self,
        base_job_template: Dict[str, Any],
        advance: Callable[[], None],
    ):
        """
        Provision the IAM user, the IAM policy and the credentials block, in that
        order.

        The policy is attached to the user only when the policy resource's
        `provision` returns a truthy ARN.

        Args:
            base_job_template: The base job template of the work pool to provision
                infrastructure for; updated in place by the credentials block.
            advance: A callback function to indicate progress.
        """
        await self._iam_user_resource.provision(advance=advance)

        policy_arn = await self._iam_policy_resource.provision(advance=advance)
        if policy_arn:
            iam_client = boto3.client("iam")
            attach_policy = partial(
                iam_client.attach_user_policy,
                UserName=self._user_name,
                PolicyArn=policy_arn,
            )
            await anyio.to_thread.run_sync(attach_policy)

        await self._credentials_block_resource.provision(
            base_job_template=base_job_template,
            advance=advance,
        )
|
460
|
+
|
461
|
+
|
462
|
+
class ClusterResource:
    """
    Represents the ECS cluster used to run Prefect flows.

    Args:
        cluster_name: Name of the ECS cluster. Defaults to "prefect-ecs-cluster".
    """

    def __init__(self, cluster_name: str = "prefect-ecs-cluster"):
        self._ecs_client = boto3.client("ecs")
        self._cluster_name = cluster_name
        # Cached answer of `requires_provisioning`; None means "not checked yet".
        self._requires_provisioning = None

    async def get_task_count(self):
        """
        Count the tasks needed to provision the cluster.

        Returns:
            int: 1 if the cluster must be created, otherwise 0.
        """
        needs_cluster = await self.requires_provisioning()
        return 1 if needs_cluster else 0

    async def requires_provisioning(self) -> bool:
        """
        Check whether the cluster must be created.

        A cluster is considered present only when `describe_clusters` reports it
        with status "ACTIVE". The answer is cached after the first check.

        Returns:
            bool: True if provisioning is required, False otherwise.
        """
        if self._requires_provisioning is not None:
            return self._requires_provisioning

        response = await anyio.to_thread.run_sync(
            partial(self._ecs_client.describe_clusters, clusters=[self._cluster_name])
        )
        found_clusters = response["clusters"]
        is_active = bool(found_clusters) and found_clusters[0]["status"] == "ACTIVE"
        self._requires_provisioning = not is_active
        return self._requires_provisioning

    async def get_planned_actions(self) -> List[str]:
        """
        Describe what provisioning the cluster would do.

        Returns:
            List[str]: One description if the cluster will be created, otherwise an
                empty list.
        """
        if not await self.requires_provisioning():
            return []
        return [
            "Creating an ECS cluster for running Prefect flows:"
            f" [blue]{self._cluster_name}[/]"
        ]

    async def provision(
        self,
        base_job_template: Dict[str, Any],
        advance: Callable[[], None],
    ):
        """
        Create the ECS cluster if needed and point the job template at it.

        The `cluster` variable's default in the job template is always set to the
        cluster name, even when no cluster had to be created.

        Args:
            base_job_template: The base job template of the work pool to provision
                infrastructure for; updated in place.
            advance: A callback function to indicate progress.
        """
        if await self.requires_provisioning():
            console = current_console.get()
            console.print("Provisioning ECS cluster")
            create_cluster = partial(
                self._ecs_client.create_cluster, clusterName=self._cluster_name
            )
            await anyio.to_thread.run_sync(create_cluster)
            advance()

        template_properties = base_job_template["variables"]["properties"]
        template_properties["cluster"]["default"] = self._cluster_name
|
537
|
+
|
538
|
+
|
539
|
+
class VpcResource:
    """
    Represents a VPC in which ECS tasks can run.

    No VPC is created when the account already has an available default VPC, or
    a VPC tagged `Name=<vpc_name>`.

    Args:
        vpc_name: Value of the `Name` tag used to find or create the VPC.
            Defaults to "prefect-ecs-vpc".
    """

    def __init__(self, vpc_name: str = "prefect-ecs-vpc"):
        self._ec2_client = boto3.client("ec2")
        self._ec2_resource = boto3.resource("ec2")
        self._vpc_name = vpc_name
        # Cached answer of `requires_provisioning`; None means "not checked yet".
        self._requires_provisioning = None

    async def get_task_count(self):
        """
        Returns the number of tasks that will be executed to provision this resource.

        Returns:
            int: The number of tasks to be provisioned.
        """
        return 4 if await self.requires_provisioning() else 0

    async def _default_vpc_exists(self):
        """Return True if the account has a default VPC in the "available" state."""
        response = await anyio.to_thread.run_sync(self._ec2_client.describe_vpcs)
        default_vpc = next(
            (
                vpc
                for vpc in response["Vpcs"]
                if vpc["IsDefault"] and vpc["State"] == "available"
            ),
            None,
        )
        return default_vpc is not None

    async def _get_prefect_created_vpc(self):
        """Return the first VPC tagged `Name=<vpc_name>`, or None."""
        vpcs = await anyio.to_thread.run_sync(
            partial(
                self._ec2_resource.vpcs.filter,
                Filters=[{"Name": "tag:Name", "Values": [self._vpc_name]}],
            )
        )
        return next(iter(vpcs), None)

    @staticmethod
    def _cidrs_from_vpcs(vpcs: List[Dict[str, Any]]) -> List[str]:
        """
        Collect the IPv4 CIDR blocks used by `describe_vpcs` entries.

        Includes each VPC's primary `CidrBlock` and any secondary blocks in
        `CidrBlockAssociationSet` that are associated or being associated.
        Duplicates are dropped, keeping first-seen order.
        """
        cidrs: List[str] = []
        for vpc in vpcs:
            if vpc["CidrBlock"] not in cidrs:
                cidrs.append(vpc["CidrBlock"])
            for association in vpc.get("CidrBlockAssociationSet", []):
                state = association.get("CidrBlockState", {}).get("State")
                cidr = association.get("CidrBlock")
                if (
                    cidr
                    and state in ("associating", "associated")
                    and cidr not in cidrs
                ):
                    cidrs.append(cidr)
        return cidrs

    async def _get_existing_vpc_cidrs(self):
        """Return the CIDR blocks used by all VPCs visible to `describe_vpcs`."""
        response = await anyio.to_thread.run_sync(self._ec2_client.describe_vpcs)
        return self._cidrs_from_vpcs(response["Vpcs"])

    @staticmethod
    def _first_free_cidr(
        existing_cidrs: List[str], default_cidr: str = "172.31.0.0/16"
    ) -> str:
        """
        Return the first block, starting at `default_cidr`, that overlaps none of
        `existing_cidrs`.

        Candidates keep the prefix length of `default_cidr` and advance by one
        network size at a time.

        Raises:
            Exception: If the search runs past the end of the IPv4 address space.
        """
        existing_networks = [ipaddress.ip_network(cidr) for cidr in existing_cidrs]
        candidate = ipaddress.ip_network(default_cidr)
        while any(candidate.overlaps(network) for network in existing_networks):
            next_address = int(candidate.network_address) + candidate.num_addresses
            try:
                candidate = ipaddress.ip_network(
                    f"{ipaddress.IPv4Address(next_address)}/{candidate.prefixlen}"
                )
            except ValueError as exc:
                raise Exception(
                    "Unable to find a non-overlapping CIDR block in the default"
                    " range"
                ) from exc
        return str(candidate)

    async def _find_non_overlapping_cidr(self, default_cidr="172.31.0.0/16"):
        """Find a non-overlapping CIDR block"""
        existing_cidrs = await self._get_existing_vpc_cidrs()
        return self._first_free_cidr(existing_cidrs, default_cidr=default_cidr)

    @staticmethod
    def _pick_subnet_zones(zones: List[str], count: int) -> List[str]:
        """
        Assign an availability zone to each of `count` subnets.

        Cycles through `zones` when the region has fewer zones than subnets.

        Raises:
            RuntimeError: If `zones` is empty.
        """
        if not zones:
            raise RuntimeError(
                "No availability zones are available in the current region"
            )
        return [zones[i % len(zones)] for i in range(count)]

    async def _get_subnet_zones(self, count: int) -> List[str]:
        """Choose availability zones (not Local/Wavelength Zones) for `count` subnets."""
        response = await anyio.to_thread.run_sync(
            self._ec2_client.describe_availability_zones
        )
        zones = [
            az["ZoneName"]
            for az in response["AvailabilityZones"]
            if az.get("ZoneType", "availability-zone") == "availability-zone"
        ]
        return self._pick_subnet_zones(zones, count)

    async def requires_provisioning(self) -> bool:
        """
        Check if this resource requires provisioning.

        Provisioning is required only when there is neither an available default
        VPC nor a VPC tagged with the configured name. The answer is cached.

        Returns:
            bool: True if provisioning is required, False otherwise.
        """
        if self._requires_provisioning is not None:
            return self._requires_provisioning

        if await self._default_vpc_exists():
            self._requires_provisioning = False
            return False

        if await self._get_prefect_created_vpc() is not None:
            self._requires_provisioning = False
            return False

        self._requires_provisioning = True
        return True

    async def get_planned_actions(self) -> List[str]:
        """
        Returns a description of the planned actions for provisioning this resource.

        Returns:
            List[str]: Descriptions of the planned actions for provisioning the
                resource, or an empty list if provisioning is not required.
        """
        if await self.requires_provisioning():
            new_vpc_cidr = await self._find_non_overlapping_cidr()
            return [
                f"Creating a VPC with CIDR [blue]{new_vpc_cidr}[/] for running"
                f" ECS tasks: [blue]{self._vpc_name}[/]"
            ]
        return []

    async def provision(
        self,
        base_job_template: Dict[str, Any],
        advance: Callable[[], None],
    ):
        """
        Provisions a VPC.

        Chooses a CIDR block to avoid conflicting with any existing VPCs. The
        provisioning steps are:
          1. create and tag the VPC;
          2. attach an internet gateway;
          3. create three subnets routed to the gateway;
          4. create a security group.

        Will update the `vpc_id` variable in the job template to reference the VPC
        (left untouched when an existing default VPC is used).

        Args:
            base_job_template: The base job template of the work pool to provision
                infrastructure for; updated in place.
            advance: A callback function to indicate progress.

        Raises:
            RuntimeError: If the region has no usable availability zone. This is
                checked before any AWS resource is created.
        """
        if await self.requires_provisioning():
            console = current_console.get()
            console.print("Provisioning VPC")
            # Resolve subnet zones first so an unusable region fails before any
            # AWS resources are created.
            subnet_zones = await self._get_subnet_zones(count=3)
            new_vpc_cidr = await self._find_non_overlapping_cidr()
            vpc = await anyio.to_thread.run_sync(
                partial(self._ec2_resource.create_vpc, CidrBlock=new_vpc_cidr)
            )
            await anyio.to_thread.run_sync(vpc.wait_until_available)
            await anyio.to_thread.run_sync(
                partial(
                    vpc.create_tags,
                    Resources=[vpc.id],
                    Tags=[
                        {
                            "Key": "Name",
                            "Value": self._vpc_name,
                        },
                    ],
                )
            )
            advance()

            console.print("Creating internet gateway")
            internet_gateway = await anyio.to_thread.run_sync(
                self._ec2_resource.create_internet_gateway
            )
            await anyio.to_thread.run_sync(
                partial(
                    vpc.attach_internet_gateway, InternetGatewayId=internet_gateway.id
                )
            )
            advance()

            console.print("Setting up subnets")
            vpc_network = ipaddress.ip_network(new_vpc_cidr)
            # Quarter the VPC range; the first three quarters become subnets.
            subnet_cidrs = list(
                vpc_network.subnets(new_prefix=vpc_network.prefixlen + 2)
            )
            subnets = []
            for subnet_cidr, zone in zip(subnet_cidrs[:3], subnet_zones):
                subnets.append(
                    await anyio.to_thread.run_sync(
                        partial(
                            vpc.create_subnet,
                            CidrBlock=str(subnet_cidr),
                            AvailabilityZone=zone,
                        )
                    )
                )

            # Route all outbound traffic from every subnet through the gateway.
            public_route_table = await anyio.to_thread.run_sync(vpc.create_route_table)
            await anyio.to_thread.run_sync(
                partial(
                    public_route_table.create_route,
                    DestinationCidrBlock="0.0.0.0/0",
                    GatewayId=internet_gateway.id,
                )
            )
            for subnet in subnets:
                await anyio.to_thread.run_sync(
                    partial(
                        public_route_table.associate_with_subnet, SubnetId=subnet.id
                    )
                )
            advance()

            console.print("Setting up security group")
            # Create a security group to block all inbound traffic
            await anyio.to_thread.run_sync(
                partial(
                    self._ec2_resource.create_security_group,
                    GroupName="prefect-ecs-security-group",
                    Description=(
                        "Block all inbound traffic and allow all outbound traffic"
                    ),
                    VpcId=vpc.id,
                )
            )
            # The VPC now exists; a repeated call must not try to create it again.
            self._requires_provisioning = False
            advance()
        else:
            vpc = await self._get_prefect_created_vpc()

        if vpc is not None:
            base_job_template["variables"]["properties"]["vpc_id"]["default"] = str(
                vpc.id
            )
|
764
|
+
|
765
|
+
|
766
|
+
class ElasticContainerServicePushProvisioner:
    """
    An infrastructure provisioner for ECS push work pools.
    """

    def __init__(self):
        self._console = Console()

    @property
    def console(self):
        return self._console

    @console.setter
    def console(self, value):
        self._console = value

    async def _prompt_boto3_installation(self):
        """
        Install boto3 with pip into the running interpreter's environment and
        import it.

        Rebinds the module-level `boto3` name, so resources created afterwards use
        the real module.
        """
        global boto3
        # The command is an argument list (no shell), so the interpreter path must
        # not be shell-quoted; quoting would break paths that contain spaces.
        await run_process([sys.executable, "-m", "pip", "install", "boto3"])
        # Finder caches may not know about a package installed after startup.
        importlib.invalidate_caches()
        boto3 = importlib.import_module("boto3")

    @staticmethod
    def is_boto3_installed():
        """
        Check if boto3 is installed.

        Returns:
            bool: True if `boto3` can be imported, False otherwise.
        """
        try:
            importlib.import_module("boto3")
            return True
        except ModuleNotFoundError:
            return False

    def _generate_resources(self, work_pool_name: str):
        """Build the resources to inspect and provision, in provisioning order."""
        return [
            AuthenticationResource(work_pool_name=work_pool_name),
            ClusterResource(),
            VpcResource(),
        ]

    async def provision(
        self,
        work_pool_name: str,
        base_job_template: dict,
    ) -> Dict[str, Any]:
        """
        Provisions the infrastructure for an ECS push work pool.

        Args:
            work_pool_name: The name of the work pool to provision infrastructure for.
            base_job_template: The base job template of the work pool to provision
                infrastructure for. It is not modified.

        Returns:
            dict: An updated copy of the base job template. If the user declines
                the interactive confirmation, the original template is returned
                unchanged.

        Raises:
            RuntimeError: If boto3 is missing and cannot be installed. Also raised
                (wrapping the original error, e.g. a boto3 `ClientError` message)
                if provisioning fails.
        """
        if not self.is_boto3_installed():
            if self.console.is_interactive and Confirm.ask(
                "boto3 is required to configure your AWS account. Would you like to"
                " install it?",
                console=self.console,
            ):
                await self._prompt_boto3_installation()
            else:
                raise RuntimeError(
                    "boto3 is required to configure your AWS account. Please install it"
                    " and try again."
                )

        try:
            resources = self._generate_resources(work_pool_name=work_pool_name)

            with Progress(
                SpinnerColumn(),
                TextColumn(
                    "Checking your AWS account for infrastructure that needs to be"
                    " provisioned..."
                ),
                transient=True,
                console=self.console,
            ) as progress:
                inspect_aws_account_task = progress.add_task(
                    "inspect_aws_account", total=1
                )
                num_tasks = sum(
                    [await resource.get_task_count() for resource in resources]
                )
                progress.update(inspect_aws_account_task, completed=1)

            if num_tasks > 0:
                message = (
                    "Provisioning infrastructure for your work pool"
                    f" [blue]{work_pool_name}[/] will require: \n"
                )
                for resource in resources:
                    planned_actions = await resource.get_planned_actions()
                    for action in planned_actions:
                        message += f"\n\t - {action}"

                self.console.print(Panel(message))

                if self.console.is_interactive:
                    if not Confirm.ask(
                        "Proceed with infrastructure provisioning?",
                        console=self.console,
                    ):
                        return base_job_template
            else:
                self.console.print(
                    "No additional infrastructure required for work pool"
                    f" [blue]{work_pool_name}[/]"
                )
                # Don't return early: the provision calls below are no-ops, but they
                # still fill in the job template defaults.

            base_job_template_copy = deepcopy(base_job_template)
            with Progress(console=self.console, disable=num_tasks == 0) as progress:
                task = progress.add_task(
                    "Provisioning Infrastructure",
                    total=num_tasks,
                )
                for resource in resources:
                    # Route resource output through the progress console so it
                    # renders above the progress bar.
                    with console_context(progress.console):
                        await resource.provision(
                            advance=partial(progress.advance, task),
                            base_job_template=base_job_template_copy,
                        )

            if num_tasks > 0:
                self.console.print(
                    "Infrastructure successfully provisioned!", style="green"
                )

            return base_job_template_copy
        except Exception as exc:
            if hasattr(exc, "response"):
                # Catching boto3 ClientError
                response = getattr(exc, "response", {})
                error_message = get_from_dict(response, "Error.Message") or str(exc)
                raise RuntimeError(error_message) from exc
            # Catching any other exception
            raise RuntimeError(str(exc)) from exc
|