prefect-client 2.14.8__py3-none-any.whl → 2.14.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,907 @@
1
+ import contextlib
2
+ import contextvars
3
+ import importlib
4
+ import ipaddress
5
+ import json
6
+ import shlex
7
+ import sys
8
+ from copy import deepcopy
9
+ from functools import partial
10
+ from textwrap import dedent
11
+ from typing import Any, Callable, Dict, List, Optional
12
+
13
+ import anyio
14
+ from anyio import run_process
15
+ from rich.console import Console
16
+ from rich.panel import Panel
17
+ from rich.progress import Progress, SpinnerColumn, TextColumn
18
+ from rich.prompt import Confirm
19
+
20
+ from prefect.client.orchestration import PrefectClient
21
+ from prefect.client.schemas.actions import BlockDocumentCreate
22
+ from prefect.client.utilities import inject_client
23
+ from prefect.exceptions import ObjectNotFound
24
+ from prefect.utilities.collections import get_from_dict
25
+ from prefect.utilities.importtools import lazy_import
26
+
27
# boto3 is resolved lazily (via `lazy_import`) so importing this module does not
# require boto3 up front; `ElasticContainerServicePushProvisioner` checks for it
# at runtime and can offer to install it.
boto3 = lazy_import("boto3")
# NOTE(review): `docker` is bound to the very same lazy boto3 module and is not
# referenced anywhere else in this module; it looks like a copy-paste leftover.
# It is kept only so the module's public names stay unchanged.
docker = boto3

# Console that resources print status messages to; `console_context` swaps it
# for the duration of a provisioning step.
current_console = contextvars.ContextVar("console", default=Console())
30
+
31
+
32
@contextlib.contextmanager
def console_context(value: Console):
    """
    Make `value` the console returned by `current_console.get()` inside the block.

    Whatever console was active beforehand is restored on exit, including when
    the body raises.

    Args:
        value: The console to expose while the context is active.
    """
    previous_console_token = current_console.set(value)
    try:
        yield
    finally:
        current_console.reset(previous_console_token)
39
+
40
+
41
class IamPolicyResource:
    """
    Represents an IAM policy resource for managing ECS tasks.

    Args:
        policy_name: The name of the IAM policy. Defaults to "prefect-ecs-policy".
    """

    def __init__(
        self,
        policy_name: str = "prefect-ecs-policy",
    ):
        self._iam_client = boto3.client("iam")
        self._policy_name = policy_name
        self._policy_document = json.dumps(
            {
                "Version": "2012-10-17",
                "Statement": [
                    {
                        "Sid": "PrefectEcsPolicy",
                        "Effect": "Allow",
                        "Action": [
                            "ec2:AuthorizeSecurityGroupIngress",
                            "ec2:CreateSecurityGroup",
                            "ec2:CreateTags",
                            "ec2:DescribeNetworkInterfaces",
                            "ec2:DescribeSecurityGroups",
                            "ec2:DescribeSubnets",
                            "ec2:DescribeVpcs",
                            "ecs:CreateCluster",
                            "ecs:DeregisterTaskDefinition",
                            "ecs:DescribeClusters",
                            "ecs:DescribeTaskDefinition",
                            "ecs:DescribeTasks",
                            "ecs:ListAccountSettings",
                            "ecs:ListClusters",
                            "ecs:ListTaskDefinitions",
                            "ecs:RegisterTaskDefinition",
                            "ecs:RunTask",
                            "ecs:StopTask",
                            "logs:CreateLogStream",
                            "logs:PutLogEvents",
                            "logs:DescribeLogGroups",
                            "logs:GetLogEvents",
                        ],
                        "Resource": "*",
                    }
                ],
            }
        )

        # Cached result of `requires_provisioning`; None means "not checked yet".
        self._requires_provisioning = None

    async def get_task_count(self) -> int:
        """
        Returns the number of tasks that will be executed to provision this resource.

        Returns:
            int: The number of tasks to be provisioned.
        """
        return 1 if await self.requires_provisioning() else 0

    def _get_policy_by_name(self, name):
        """
        Look up a customer-managed (`Scope="Local"`) IAM policy by name.

        Blocking; callers run it in a worker thread.

        Returns:
            The matching policy entry from `list_policies`, or None if no local
            policy has that name.
        """
        paginator = self._iam_client.get_paginator("list_policies")
        page_iterator = paginator.paginate(Scope="Local")

        for page in page_iterator:
            for policy in page["Policies"]:
                if policy["PolicyName"] == name:
                    return policy
        return None

    async def requires_provisioning(self) -> bool:
        """
        Check if this resource requires provisioning.

        The result is looked up once and cached on the instance.

        Returns:
            bool: True if provisioning is required, False otherwise.
        """
        if self._requires_provisioning is not None:
            return self._requires_provisioning
        policy = await anyio.to_thread.run_sync(
            partial(self._get_policy_by_name, self._policy_name)
        )
        if policy is not None:
            self._requires_provisioning = False
            return False

        self._requires_provisioning = True
        return True

    async def get_planned_actions(self) -> List[str]:
        """
        Returns a description of the planned actions for provisioning this resource.

        Returns:
            List[str]: Descriptions of the planned actions; empty if provisioning
                is not required.
        """
        if await self.requires_provisioning():
            return [
                "Creating and attaching an IAM policy for managing ECS tasks:"
                f" [blue]{self._policy_name}[/]"
            ]
        return []

    async def provision(
        self,
        advance: Callable[[], None],
    ) -> str:
        """
        Provisions an IAM policy, or looks up the existing one.

        Args:
            advance: A callback function to indicate progress. It is only called
                when a new policy is created.

        Returns:
            str: The ARN (Amazon Resource Name) of the created or pre-existing
                IAM policy.

        Raises:
            RuntimeError: If the policy was reported as existing but can no
                longer be found.
        """
        if await self.requires_provisioning():
            console = current_console.get()
            console.print("Creating IAM policy")
            policy = await anyio.to_thread.run_sync(
                partial(
                    self._iam_client.create_policy,
                    PolicyName=self._policy_name,
                    PolicyDocument=self._policy_document,
                )
            )
            policy_arn = policy["Policy"]["Arn"]
            advance()
            return policy_arn

        # The policy already exists (e.g. created while provisioning another
        # work pool). Still return its ARN so the caller attaches it to the IAM
        # user; returning None here would leave a freshly created user without
        # any ECS permissions.
        policy = await anyio.to_thread.run_sync(
            partial(self._get_policy_by_name, self._policy_name)
        )
        if policy is None:
            raise RuntimeError(
                f"Unable to find IAM policy {self._policy_name!r} even though it"
                " was reported as existing."
            )
        return policy["Arn"]
174
+
175
+
176
class IamUserResource:
    """
    Represents an IAM user resource for managing ECS tasks.

    Args:
        user_name: The desired name of the IAM user.
    """

    def __init__(self, user_name: str):
        self._iam_client = boto3.client("iam")
        self._user_name = user_name
        # None until `requires_provisioning` has queried IAM once.
        self._requires_provisioning = None

    async def get_task_count(self) -> int:
        """
        Count the progress steps `provision` will report through `advance`.

        Returns:
            int: 1 if the user still has to be created, otherwise 0.
        """
        if await self.requires_provisioning():
            return 1
        return 0

    async def requires_provisioning(self) -> bool:
        """
        Decide whether the IAM user has to be created.

        IAM is queried on the first call only; the answer is cached afterwards.

        Returns:
            bool: True if no IAM user with the configured name exists.
        """
        if self._requires_provisioning is not None:
            return self._requires_provisioning

        try:
            await anyio.to_thread.run_sync(
                partial(self._iam_client.get_user, UserName=self._user_name)
            )
        except self._iam_client.exceptions.NoSuchEntityException:
            self._requires_provisioning = True
        else:
            self._requires_provisioning = False
        return self._requires_provisioning

    async def get_planned_actions(self) -> List[str]:
        """
        Describe what `provision` would do.

        Returns:
            List[str]: One entry naming the user to create, or an empty list if
                the user already exists.
        """
        if not await self.requires_provisioning():
            return []
        return [
            "Creating an IAM user for managing ECS tasks:"
            f" [blue]{self._user_name}[/]"
        ]

    async def provision(
        self,
        advance: Callable[[], None],
    ):
        """
        Create the IAM user if it does not exist yet.

        Args:
            advance: Progress callback, invoked once after the user is created.
        """
        console = current_console.get()
        if not await self.requires_provisioning():
            return
        console.print("Provisioning IAM user")
        await anyio.to_thread.run_sync(
            partial(self._iam_client.create_user, UserName=self._user_name)
        )
        advance()
248
+
249
+
250
class CredentialsBlockResource:
    """
    Represents a Prefect `aws-credentials` block holding an access key for the
    IAM user that manages ECS tasks.

    Args:
        user_name: Name of the IAM user to generate an access key for.
        block_document_name: Name of the `aws-credentials` block document to
            reuse or create.
    """

    def __init__(self, user_name: str, block_document_name: str):
        self._block_document_name = block_document_name
        self._user_name = user_name
        # Cached result of `requires_provisioning`; None means "not checked yet".
        self._requires_provisioning = None

    async def get_task_count(self):
        """
        Returns the number of tasks that will be executed to provision this resource.

        Returns:
            int: 2 (generate an access key, create the block) when provisioning
                is required, otherwise 0.
        """
        return 2 if await self.requires_provisioning() else 0

    @inject_client
    async def requires_provisioning(
        self, client: Optional[PrefectClient] = None
    ) -> bool:
        """
        Check whether the credentials block still has to be created.

        The result is looked up once and cached on the instance.

        Args:
            client: A Prefect client; expected to be supplied by `inject_client`
                when omitted.

        Returns:
            bool: True if no `aws-credentials` block document with the configured
                name exists, False otherwise.
        """
        if self._requires_provisioning is None:
            assert client is not None, "Client injection failed"
            try:
                await client.read_block_document_by_name(
                    self._block_document_name, "aws-credentials"
                )
                self._requires_provisioning = False
            except ObjectNotFound:
                self._requires_provisioning = True
        return self._requires_provisioning

    async def get_planned_actions(self) -> List[str]:
        """
        Returns a description of the planned actions for provisioning this resource.

        Returns:
            List[str]: Descriptions of the planned actions; empty if provisioning
                is not required.
        """
        if await self.requires_provisioning():
            return ["Storing generated AWS credentials in a block"]
        return []

    @staticmethod
    async def _read_credentials_block_metadata(client: PrefectClient):
        """
        Fetch the `aws-credentials` block type and its most recent schema.

        Returns:
            A `(block_type, block_schema)` tuple.

        Raises:
            RuntimeError: If the block type is not registered or has no schema.
        """
        try:
            credentials_block_type = await client.read_block_type_by_slug(
                "aws-credentials"
            )
        except ObjectNotFound as exc:
            raise RuntimeError(
                dedent(
                    """\
                    Unable to find block type "aws-credentials".
                    To register the `aws-credentials` block type, run:

                            pip install prefect-aws
                            prefect blocks register -m prefect_aws

                    """
                )
            ) from exc

        credentials_block_schema = (
            await client.get_most_recent_block_schema_for_block_type(
                block_type_id=credentials_block_type.id
            )
        )
        if credentials_block_schema is None:
            raise RuntimeError(
                f"Unable to find schema for block type {credentials_block_type.slug}"
            )
        return credentials_block_type, credentials_block_schema

    @inject_client
    async def provision(
        self,
        base_job_template: Dict[str, Any],
        advance: Callable[[], None],
        client: Optional[PrefectClient] = None,
    ):
        """
        Provisions an AWS credentials block.

        Will generate new credentials if the block does not already exist. Updates
        the `aws_credentials` variable in the job template to reference the block.

        Args:
            base_job_template: The base job template; modified in place.
            advance: A callback function to indicate progress.
            client: A Prefect client to use for interacting with the Prefect API.

        Raises:
            RuntimeError: If the `aws-credentials` block type (or its schema) is
                not registered with the Prefect API.
        """
        assert client is not None, "Client injection failed"
        if not await self.requires_provisioning():
            block_doc = await client.read_block_document_by_name(
                self._block_document_name, "aws-credentials"
            )
        else:
            # Resolve the block type and schema *before* touching IAM so that a
            # missing `aws-credentials` registration fails fast instead of
            # leaving behind an access key that is never stored anywhere. AWS
            # caps the number of access keys per IAM user, so orphaned keys
            # would eventually make retries fail.
            (
                credentials_block_type,
                credentials_block_schema,
            ) = await self._read_credentials_block_metadata(client)

            console = current_console.get()
            console.print("Generating AWS credentials")
            iam_client = boto3.client("iam")
            access_key_data = await anyio.to_thread.run_sync(
                partial(iam_client.create_access_key, UserName=self._user_name)
            )
            access_key = access_key_data["AccessKey"]
            advance()

            console.print("Creating AWS credentials block")
            block_doc = await client.create_block_document(
                block_document=BlockDocumentCreate(
                    name=self._block_document_name,
                    data={
                        "aws_access_key_id": access_key["AccessKeyId"],
                        "aws_secret_access_key": access_key["SecretAccessKey"],
                        "region_name": boto3.session.Session().region_name,
                    },
                    block_type_id=credentials_block_type.id,
                    block_schema_id=credentials_block_schema.id,
                )
            )
            advance()

        base_job_template["variables"]["properties"]["aws_credentials"]["default"] = {
            "$ref": {"block_document_id": str(block_doc.id)}
        }
370
+
371
+
372
class AuthenticationResource:
    """
    Bundles the IAM user, the IAM policy attached to it, and the Prefect
    `aws-credentials` block that together let a work pool authenticate with AWS.

    Args:
        work_pool_name: Name of the work pool; the credentials block is named
            `<work_pool_name>-aws-credentials`.
        user_name: Name of the IAM user to reuse or create.
        policy_name: Name of the IAM policy to reuse or create.
    """

    def __init__(
        self,
        work_pool_name: str,
        user_name: str = "prefect-ecs-user",
        policy_name: str = "prefect-ecs-policy",
    ):
        self._user_name = user_name
        self._policy_name = policy_name
        self._iam_user_resource = IamUserResource(user_name=user_name)
        self._iam_policy_resource = IamPolicyResource(policy_name=policy_name)
        credentials_block_name = f"{work_pool_name}-aws-credentials"
        self._credentials_block_resource = CredentialsBlockResource(
            user_name=user_name, block_document_name=credentials_block_name
        )

    @property
    def resources(self):
        """The sub-resources (user, policy, credentials block), in that order."""
        return [
            self._iam_user_resource,
            self._iam_policy_resource,
            self._credentials_block_resource,
        ]

    async def get_task_count(self):
        """
        Total number of progress steps across all sub-resources.

        Returns:
            int: The number of tasks to be provisioned.
        """
        total = 0
        for resource in self.resources:
            total += await resource.get_task_count()
        return total

    async def requires_provisioning(self) -> bool:
        """
        Report whether any sub-resource still needs provisioning.

        Every sub-resource is checked (no short-circuit), so each one caches its
        own answer.

        Returns:
            bool: True if at least one sub-resource requires provisioning.
        """
        needs_provisioning = []
        for resource in self.resources:
            needs_provisioning.append(await resource.requires_provisioning())
        return any(needs_provisioning)

    async def get_planned_actions(self) -> List[str]:
        """
        Collect the planned actions of every sub-resource, in order.

        Returns:
            List[str]: Descriptions of the planned actions; empty if nothing needs
                to be provisioned.
        """
        planned = []
        for resource in self.resources:
            planned.extend(await resource.get_planned_actions())
        return planned

    async def provision(
        self,
        base_job_template: Dict[str, Any],
        advance: Callable[[], None],
    ):
        """
        Provision the user, then the policy, then the credentials block.

        The policy is attached to the user only when the policy resource returns
        an ARN.

        Args:
            base_job_template: The base job template of the work pool to provision
                infrastructure for; updated in place by the credentials block.
            advance: A callback function to indicate progress.
        """
        await self._iam_user_resource.provision(advance=advance)
        policy_arn = await self._iam_policy_resource.provision(advance=advance)
        if policy_arn:
            await self._attach_policy_to_user(policy_arn)
        await self._credentials_block_resource.provision(
            base_job_template=base_job_template,
            advance=advance,
        )

    async def _attach_policy_to_user(self, policy_arn: str):
        """Attach the IAM policy identified by `policy_arn` to the IAM user."""
        iam_client = boto3.client("iam")
        await anyio.to_thread.run_sync(
            partial(
                iam_client.attach_user_policy,
                UserName=self._user_name,
                PolicyArn=policy_arn,
            )
        )
460
+
461
+
462
class ClusterResource:
    """
    Represents the ECS cluster that flow runs are launched into.

    Args:
        cluster_name: Name of the ECS cluster to reuse or create.
    """

    def __init__(self, cluster_name: str = "prefect-ecs-cluster"):
        self._ecs_client = boto3.client("ecs")
        self._cluster_name = cluster_name
        # None until `requires_provisioning` has queried ECS once.
        self._requires_provisioning = None

    async def get_task_count(self):
        """
        Count the progress steps `provision` will report through `advance`.

        Returns:
            int: 1 if the cluster has to be created, otherwise 0.
        """
        if await self.requires_provisioning():
            return 1
        return 0

    async def requires_provisioning(self) -> bool:
        """
        Decide whether the cluster has to be created.

        A cluster counts as present only if ECS reports it with status "ACTIVE".
        The answer is cached after the first call.

        Returns:
            bool: True if provisioning is required, False otherwise.
        """
        if self._requires_provisioning is not None:
            return self._requires_provisioning

        response = await anyio.to_thread.run_sync(
            partial(self._ecs_client.describe_clusters, clusters=[self._cluster_name])
        )
        found_clusters = response["clusters"]
        has_active_cluster = (
            bool(found_clusters) and found_clusters[0]["status"] == "ACTIVE"
        )
        self._requires_provisioning = not has_active_cluster
        return self._requires_provisioning

    async def get_planned_actions(self) -> List[str]:
        """
        Describe what `provision` would do.

        Returns:
            List[str]: One entry naming the cluster to create, or an empty list if
                an active cluster already exists.
        """
        if not await self.requires_provisioning():
            return []
        return [
            "Creating an ECS cluster for running Prefect flows:"
            f" [blue]{self._cluster_name}[/]"
        ]

    async def provision(
        self,
        base_job_template: Dict[str, Any],
        advance: Callable[[], None],
    ):
        """
        Create the ECS cluster if needed and point the job template at it.

        The `cluster` variable default is set even when no cluster is created.

        Args:
            base_job_template: The base job template of the work pool to provision
                infrastructure for; updated in place.
            advance: Progress callback, invoked once after the cluster is created.
        """
        if await self.requires_provisioning():
            console = current_console.get()
            console.print("Provisioning ECS cluster")
            await anyio.to_thread.run_sync(
                partial(self._ecs_client.create_cluster, clusterName=self._cluster_name)
            )
            advance()

        cluster_variable = base_job_template["variables"]["properties"]["cluster"]
        cluster_variable["default"] = self._cluster_name
537
+
538
+
539
class VpcResource:
    """
    Represents the VPC that ECS tasks run in.

    Nothing is created when the account already has an available default VPC or
    a VPC tagged `Name=<vpc_name>` from an earlier run.

    Args:
        vpc_name: Value of the `Name` tag used to find or label the VPC.
    """

    def __init__(self, vpc_name: str = "prefect-ecs-vpc"):
        self._ec2_client = boto3.client("ec2")
        self._ec2_resource = boto3.resource("ec2")
        self._vpc_name = vpc_name
        # Cached result of `requires_provisioning`; None means "not checked yet".
        self._requires_provisioning = None

    async def get_task_count(self):
        """
        Returns the number of tasks that will be executed to provision this resource.

        Returns:
            int: 4 (VPC, internet gateway, subnets/routing, security group) when
                provisioning is required, otherwise 0.
        """
        return 4 if await self.requires_provisioning() else 0

    async def _default_vpc_exists(self):
        """Return True if the account has a default VPC in the "available" state."""
        response = await anyio.to_thread.run_sync(self._ec2_client.describe_vpcs)
        default_vpc = next(
            (
                vpc
                for vpc in response["Vpcs"]
                if vpc["IsDefault"] and vpc["State"] == "available"
            ),
            None,
        )
        return default_vpc is not None

    async def _get_prefect_created_vpc(self):
        """Return the first VPC tagged `Name=<vpc_name>`, or None if there is none."""

        def _first_matching_vpc():
            # boto3 resource collections are lazy: `filter` only builds the
            # query and the API call happens on iteration, so both must run in
            # the worker thread to keep the event loop unblocked.
            vpcs = self._ec2_resource.vpcs.filter(
                Filters=[{"Name": "tag:Name", "Values": [self._vpc_name]}],
            )
            return next(iter(vpcs), None)

        return await anyio.to_thread.run_sync(_first_matching_vpc)

    async def _get_existing_vpc_cidrs(self):
        """Return the primary `CidrBlock` of every VPC reported by `describe_vpcs`."""
        response = await anyio.to_thread.run_sync(self._ec2_client.describe_vpcs)
        return [vpc["CidrBlock"] for vpc in response["Vpcs"]]

    @staticmethod
    def _next_non_overlapping_cidr(existing_cidrs, default_cidr="172.31.0.0/16"):
        """
        Find the first block, starting at `default_cidr` and stepping by its own
        size, that overlaps none of `existing_cidrs`.

        Args:
            existing_cidrs: CIDR strings already in use.
            default_cidr: IPv4 CIDR to start the search from; its prefix length
                is kept for every candidate.

        Returns:
            str: The first non-overlapping CIDR block.

        Raises:
            Exception: If the search runs past the end of the IPv4 address space.
        """
        existing_networks = [ipaddress.ip_network(cidr) for cidr in existing_cidrs]
        candidate = ipaddress.ip_network(default_cidr)
        while any(candidate.overlaps(network) for network in existing_networks):
            # Move to the next block of the same size.
            next_address = int(candidate.network_address) + 2 ** (
                32 - candidate.prefixlen
            )
            try:
                candidate = ipaddress.ip_network(
                    f"{ipaddress.IPv4Address(next_address)}/{candidate.prefixlen}"
                )
            except ValueError:
                raise Exception(
                    "Unable to find a non-overlapping CIDR block in the default"
                    " range"
                )
        return str(candidate)

    @staticmethod
    def _plan_subnets(vpc_cidr, zones, count=3):
        """
        Split `vpc_cidr` into quarters and pair the first `count` with zones.

        Zones are reused round-robin, so regions with fewer than `count`
        availability zones still get `count` subnets.

        Args:
            vpc_cidr: The VPC's IPv4 CIDR block.
            zones: Availability zone names to spread the subnets over.
            count: Number of subnets to plan.

        Returns:
            List[Tuple[str, str]]: `(subnet_cidr, zone_name)` pairs.

        Raises:
            RuntimeError: If `zones` is empty.
        """
        if not zones:
            raise RuntimeError("No availability zones were found in this region")
        vpc_network = ipaddress.ip_network(vpc_cidr)
        subnet_cidrs = list(vpc_network.subnets(new_prefix=vpc_network.prefixlen + 2))
        return [
            (str(subnet_cidr), zones[index % len(zones)])
            for index, subnet_cidr in enumerate(subnet_cidrs[:count])
        ]

    async def _find_non_overlapping_cidr(self, default_cidr="172.31.0.0/16"):
        """Find a CIDR block the size of `default_cidr` that no existing VPC overlaps."""
        existing_cidrs = await self._get_existing_vpc_cidrs()
        return self._next_non_overlapping_cidr(existing_cidrs, default_cidr)

    async def requires_provisioning(self) -> bool:
        """
        Check if this resource requires provisioning.

        The result is looked up once and cached on the instance.

        Returns:
            bool: True if provisioning is required, False otherwise.
        """
        if self._requires_provisioning is not None:
            return self._requires_provisioning

        if await self._default_vpc_exists():
            self._requires_provisioning = False
            return False

        if await self._get_prefect_created_vpc() is not None:
            self._requires_provisioning = False
            return False

        self._requires_provisioning = True
        return True

    async def get_planned_actions(self) -> List[str]:
        """
        Returns a description of the planned actions for provisioning this resource.

        Returns:
            List[str]: Descriptions of the planned actions; empty if provisioning
                is not required.
        """
        if await self.requires_provisioning():
            new_vpc_cidr = await self._find_non_overlapping_cidr()
            return [
                f"Creating a VPC with CIDR [blue]{new_vpc_cidr}[/] for running"
                f" ECS tasks: [blue]{self._vpc_name}[/]"
            ]
        return []

    async def provision(
        self,
        base_job_template: Dict[str, Any],
        advance: Callable[[], None],
    ):
        """
        Provisions a VPC.

        Chooses a CIDR block to avoid conflicting with any existing VPCs. Will update
        the `vpc_id` variable in the job template to reference the VPC.

        Args:
            base_job_template: The base job template of the work pool to provision
                infrastructure for; updated in place.
            advance: A callback function to indicate progress.
        """
        if await self.requires_provisioning():
            console = current_console.get()
            console.print("Provisioning VPC")
            new_vpc_cidr = await self._find_non_overlapping_cidr()
            vpc = await anyio.to_thread.run_sync(
                partial(self._ec2_resource.create_vpc, CidrBlock=new_vpc_cidr)
            )
            await anyio.to_thread.run_sync(vpc.wait_until_available)
            await anyio.to_thread.run_sync(
                partial(
                    vpc.create_tags,
                    Resources=[vpc.id],
                    Tags=[
                        {
                            "Key": "Name",
                            "Value": self._vpc_name,
                        },
                    ],
                )
            )
            advance()

            console.print("Creating internet gateway")
            internet_gateway = await anyio.to_thread.run_sync(
                self._ec2_resource.create_internet_gateway
            )
            await anyio.to_thread.run_sync(
                partial(
                    vpc.attach_internet_gateway, InternetGatewayId=internet_gateway.id
                )
            )
            advance()

            console.print("Setting up subnets")
            availability_zones = (
                await anyio.to_thread.run_sync(
                    self._ec2_client.describe_availability_zones
                )
            )["AvailabilityZones"]
            zone_names = [az["ZoneName"] for az in availability_zones]
            subnets = []
            for subnet_cidr, zone_name in self._plan_subnets(new_vpc_cidr, zone_names):
                subnets.append(
                    await anyio.to_thread.run_sync(
                        partial(
                            vpc.create_subnet,
                            CidrBlock=subnet_cidr,
                            AvailabilityZone=zone_name,
                        )
                    )
                )

            # Route all outbound traffic from every subnet through the
            # internet gateway.
            public_route_table = await anyio.to_thread.run_sync(vpc.create_route_table)
            await anyio.to_thread.run_sync(
                partial(
                    public_route_table.create_route,
                    DestinationCidrBlock="0.0.0.0/0",
                    GatewayId=internet_gateway.id,
                )
            )
            for subnet in subnets:
                await anyio.to_thread.run_sync(
                    partial(
                        public_route_table.associate_with_subnet, SubnetId=subnet.id
                    )
                )
            advance()

            console.print("Setting up security group")
            # Create a security group to block all inbound traffic
            await anyio.to_thread.run_sync(
                partial(
                    self._ec2_resource.create_security_group,
                    GroupName="prefect-ecs-security-group",
                    Description=(
                        "Block all inbound traffic and allow all outbound traffic"
                    ),
                    VpcId=vpc.id,
                )
            )
            advance()
        else:
            # None when the account's default VPC is used instead.
            vpc = await self._get_prefect_created_vpc()

        if vpc is not None:
            base_job_template["variables"]["properties"]["vpc_id"]["default"] = str(
                vpc.id
            )
764
+
765
+
766
class ElasticContainerServicePushProvisioner:
    """
    An infrastructure provisioner for ECS push work pools.
    """

    def __init__(self):
        self._console = Console()

    @property
    def console(self):
        """The rich console used for prompts, progress bars and messages."""
        return self._console

    @console.setter
    def console(self, value):
        self._console = value

    @staticmethod
    def _boto3_install_command() -> List[str]:
        """
        Build the argv used to install boto3 into the running interpreter.

        The command is executed without a shell, so `sys.executable` is passed
        as-is: shell-quoting it would embed literal quote characters and break
        interpreter paths that contain spaces.
        """
        return [sys.executable, "-m", "pip", "install", "boto3"]

    async def _prompt_boto3_installation(self):
        """Install boto3 with pip and rebind the module-level `boto3` name to it."""
        global boto3
        await run_process(self._boto3_install_command())
        # boto3 was installed after this interpreter started; clear the import
        # system's finder caches so the new package is guaranteed to be found.
        importlib.invalidate_caches()
        boto3 = importlib.import_module("boto3")

    @staticmethod
    def is_boto3_installed():
        """
        Check if boto3 is installed.
        """
        try:
            importlib.import_module("boto3")
            return True
        except ModuleNotFoundError:
            return False

    def _generate_resources(self, work_pool_name: str):
        """Create the resources to inspect and provision, in provisioning order."""
        return [
            AuthenticationResource(work_pool_name=work_pool_name),
            ClusterResource(),
            VpcResource(),
        ]

    async def provision(
        self,
        work_pool_name: str,
        base_job_template: dict,
    ) -> Dict[str, Any]:
        """
        Provisions the infrastructure for an ECS push work pool.

        Args:
            work_pool_name: The name of the work pool to provision infrastructure for.
            base_job_template: The base job template of the work pool to provision
                infrastructure for. It is not modified.

        Returns:
            dict: An updated copy of the base job template, or the original
                template unchanged if the user declines provisioning.

        Raises:
            RuntimeError: If boto3 is missing and not installed, or if any step
                of inspecting or provisioning fails.
        """
        if not self.is_boto3_installed():
            if self.console.is_interactive and Confirm.ask(
                "boto3 is required to configure your AWS account. Would you like to"
                " install it?",
                console=self.console,
            ):
                await self._prompt_boto3_installation()
            else:
                raise RuntimeError(
                    "boto3 is required to configure your AWS account. Please install it"
                    " and try again."
                )

        try:
            resources = self._generate_resources(work_pool_name=work_pool_name)

            with Progress(
                SpinnerColumn(),
                TextColumn(
                    "Checking your AWS account for infrastructure that needs to be"
                    " provisioned..."
                ),
                transient=True,
                console=self.console,
            ) as progress:
                inspect_aws_account_task = progress.add_task(
                    "inspect_aws_account", total=1
                )
                num_tasks = sum(
                    [await resource.get_task_count() for resource in resources]
                )
                progress.update(inspect_aws_account_task, completed=1)

            if num_tasks > 0:
                message = (
                    "Provisioning infrastructure for your work pool"
                    f" [blue]{work_pool_name}[/] will require: \n"
                )
                for resource in resources:
                    planned_actions = await resource.get_planned_actions()
                    for action in planned_actions:
                        message += f"\n\t - {action}"

                self.console.print(Panel(message))

                if self.console.is_interactive and not Confirm.ask(
                    "Proceed with infrastructure provisioning?",
                    console=self.console,
                ):
                    return base_job_template
            else:
                self.console.print(
                    "No additional infrastructure required for work pool"
                    f" [blue]{work_pool_name}[/]"
                )
                # No early return: the provision calls below are no-ops, but
                # they still fill in the job template (e.g. cluster name and
                # credentials block reference).

            base_job_template_copy = deepcopy(base_job_template)
            with Progress(console=self.console, disable=num_tasks == 0) as progress:
                task = progress.add_task(
                    "Provisioning Infrastructure",
                    total=num_tasks,
                )
                for resource in resources:
                    with console_context(progress.console):
                        await resource.provision(
                            advance=partial(progress.advance, task),
                            base_job_template=base_job_template_copy,
                        )

            if num_tasks > 0:
                self.console.print(
                    "Infrastructure successfully provisioned!", style="green"
                )

            return base_job_template_copy
        except Exception as exc:
            if hasattr(exc, "response"):
                # Catching boto3 ClientError: surface AWS's own error message.
                response = getattr(exc, "response", {})
                error_message = get_from_dict(response, "Error.Message") or str(exc)
                raise RuntimeError(error_message) from exc
            # Catching any other exception
            raise RuntimeError(str(exc)) from exc