anyscale 0.26.19__py3-none-any.whl → 0.26.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31) hide show
  1. anyscale/_private/docgen/models.md +1 -1
  2. anyscale/client/README.md +6 -10
  3. anyscale/client/openapi_client/__init__.py +3 -3
  4. anyscale/client/openapi_client/api/default_api.py +238 -671
  5. anyscale/client/openapi_client/models/__init__.py +3 -3
  6. anyscale/client/openapi_client/models/decorated_production_job_state_transition.py +2 -2
  7. anyscale/client/openapi_client/models/{organizationpublicidentifier_response.py → job_queue_sort_directive.py} +49 -22
  8. anyscale/client/openapi_client/models/{organization_response.py → job_queue_sort_field.py} +20 -34
  9. anyscale/client/openapi_client/models/job_queues_query.py +31 -3
  10. anyscale/client/openapi_client/models/production_job_state_transition.py +2 -2
  11. anyscale/client/openapi_client/models/{organization_public_identifier.py → update_job_queue_request.py} +51 -22
  12. anyscale/commands/cloud_commands.py +15 -4
  13. anyscale/commands/command_examples.py +58 -0
  14. anyscale/commands/job_commands.py +2 -2
  15. anyscale/commands/job_queue_commands.py +172 -0
  16. anyscale/controllers/cloud_controller.py +358 -49
  17. anyscale/controllers/job_controller.py +215 -3
  18. anyscale/scripts.py +3 -0
  19. anyscale/sdk/anyscale_client/models/production_job_state_transition.py +2 -2
  20. anyscale/util.py +3 -1
  21. anyscale/utils/connect_helpers.py +34 -0
  22. anyscale/utils/gcp_utils.py +20 -4
  23. anyscale/version.py +1 -1
  24. anyscale/workspace/_private/workspace_sdk.py +19 -6
  25. {anyscale-0.26.19.dist-info → anyscale-0.26.21.dist-info}/METADATA +1 -1
  26. {anyscale-0.26.19.dist-info → anyscale-0.26.21.dist-info}/RECORD +31 -30
  27. {anyscale-0.26.19.dist-info → anyscale-0.26.21.dist-info}/LICENSE +0 -0
  28. {anyscale-0.26.19.dist-info → anyscale-0.26.21.dist-info}/NOTICE +0 -0
  29. {anyscale-0.26.19.dist-info → anyscale-0.26.21.dist-info}/WHEEL +0 -0
  30. {anyscale-0.26.19.dist-info → anyscale-0.26.21.dist-info}/entry_points.txt +0 -0
  31. {anyscale-0.26.19.dist-info → anyscale-0.26.21.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,172 @@
1
+ from typing import List
2
+
3
+ import click
4
+
5
+ from anyscale.client.openapi_client.models.job_queue_sort_directive import (
6
+ JobQueueSortDirective,
7
+ )
8
+ from anyscale.client.openapi_client.models.job_queue_sort_field import JobQueueSortField
9
+ from anyscale.client.openapi_client.models.sort_order import SortOrder
10
+ from anyscale.commands import command_examples
11
+ from anyscale.commands.util import AnyscaleCommand
12
+ from anyscale.controllers.job_controller import JobController, JobQueueView
13
+ from anyscale.util import validate_non_negative_arg
14
+
15
+
16
@click.group(
    name="job-queues",
    help="Interact with production job queues running on Anyscale.",
)
def job_queue_cli() -> None:
    """Root ``job-queues`` command group; the commands below attach to it."""
21
+
22
+
23
def parse_sort_fields(
    param: str, sort_fields: List[str],
) -> List[JobQueueSortDirective]:
    """Translate raw sort specifiers into ``JobQueueSortDirective`` objects.

    Each specifier names a sortable column (matched case-insensitively). A
    leading ``-`` selects descending order; otherwise the order is ascending.

    Args:
        param: Option name, used only to build the error message.
        sort_fields: Raw column specifiers.

    Returns:
        One directive per specifier, in the order given.

    Raises:
        click.UsageError: If a specifier is not in
            ``JobQueueSortField.allowable_values``.
    """

    def to_directive(spec: str) -> JobQueueSortDirective:
        column = spec.lstrip("-").upper()
        if column not in JobQueueSortField.allowable_values:
            raise click.UsageError(
                f"{param} must be one of {', '.join([v.lower() for v in JobQueueSortField.allowable_values])}"
            )
        order = SortOrder.DESC if spec.startswith("-") else SortOrder.ASC
        return JobQueueSortDirective(sort_field=column, sort_order=order)

    return [to_directive(spec) for spec in sort_fields]
45
+
46
+
47
# `job-queues list`: parses the listing flags and delegates to
# JobController.list_job_queues. Documented with comments instead of a docstring
# because click would otherwise surface the docstring as the command's --help text.
@job_queue_cli.command(
    name="list",
    short_help="List job queues.",
    cls=AnyscaleCommand,
    example=command_examples.JOB_QUEUE_LIST,
)
@click.option(
    "--include-all-users",
    is_flag=True,
    default=False,
    help="Include job queues not created by current user.",
)
@click.option(
    "--view",
    type=click.Choice([v.name.lower() for v in JobQueueView], case_sensitive=False),
    default=JobQueueView.DEFAULT.name,
    help="Select which view to display.",
    # Convert the chosen name back into the enum member.
    callback=lambda _, __, value: JobQueueView[value.upper()],
)
@click.option(
    "--page",
    default=100,
    type=int,
    help="Page size (default 100).",
    callback=validate_non_negative_arg,
)
@click.option(
    "--max-items",
    required=False,
    type=int,
    # The command body rejects --max-items while interactive mode is on, so the
    # help text must say "non-interactive" (it previously said the opposite).
    help="Max items to show in list (only valid in non-interactive mode).",
    # Only validate when a truthy value was given; otherwise pass None through.
    callback=lambda ctx, param, value: validate_non_negative_arg(ctx, param, value)
    if value
    else None,
)
@click.option(
    "--sort",
    "sorting_directives",
    multiple=True,
    default=[JobQueueSortField.CREATED_AT],
    help=f"""
    Sort by column(s). Prefix column with - to sort in descending order.
    Supported columns: {', '.join([v.lower() for v in JobQueueSortField.allowable_values])}.
    """,
    callback=lambda _, __, value: parse_sort_fields("sort", list(value)),
)
@click.option(
    "--interactive/--no-interactive",
    default=True,
    help="--no-interactive disables the default interactive mode.",
)
def list_job_queues(
    include_all_users: bool,
    view: JobQueueView,
    page: int,
    max_items: int,
    sorting_directives: List[JobQueueSortDirective],
    interactive: bool,
):
    # max_items is None when --max-items was omitted (or given as 0, see callback).
    if max_items is not None and interactive:
        raise click.UsageError("--max-items can only be used in non interactive mode.")
    job_controller = JobController()
    job_controller.list_job_queues(
        max_items=max_items,
        page_size=page,
        include_all_users=include_all_users,
        view=view,
        sorting_directives=sorting_directives,
        interactive=interactive,
    )
117
+
118
+
119
# `job-queues update`: forwards the new settings to JobController.update_job_queue.
# Documented with comments instead of a docstring because click would otherwise
# surface the docstring as the command's --help text.
@job_queue_cli.command(
    name="update",
    short_help="Update job queue.",
    cls=AnyscaleCommand,
    example=command_examples.JOB_QUEUE_UPDATE,
)
@click.option(
    "--id", "job_queue_id", required=False, default=None, help="ID of the job queue."
)
@click.option(
    "--name",
    "job_queue_name",
    required=False,
    default=None,
    help="Name of the job queue.",
)
@click.option(
    "--max-concurrency",
    required=False,
    default=None,
    # Parse as an integer so the value matches the `int` annotation below and
    # non-numeric input is rejected by click instead of being forwarded as a string.
    type=int,
    help="Maximum concurrency of the job queue",
)
@click.option(
    "--idle-timeout-s",
    required=False,
    default=None,
    type=int,
    help="Idle timeout of the job queue",
)
def update_job_queue(
    job_queue_id: str, job_queue_name: str, max_concurrency: int, idle_timeout_s: int
):
    # At least one identifier is needed to select the queue.
    if job_queue_id is None and job_queue_name is None:
        raise click.ClickException("ID or name of job queue is required")
    job_controller = JobController()
    job_controller.update_job_queue(
        job_queue_id=job_queue_id,
        job_queue_name=job_queue_name,
        max_concurrency=max_concurrency,
        idle_timeout_s=idle_timeout_s,
    )
159
+
160
+
161
# `job-queues info`: delegates to JobController.get_job_queue for a single queue.
# Documented with comments instead of a docstring because click would otherwise
# surface the docstring as the command's --help text.
@job_queue_cli.command(
    name="info",
    short_help="Info of a job queue.",
    cls=AnyscaleCommand,
    example=command_examples.JOB_QUEUE_INFO,
)
@click.option(
    # The value identifies a job queue, not a job; the help text now says so,
    # matching the `update` command's --id option.
    "--id", "job_queue_id", required=True, default=None, help="ID of the job queue."
)
def get_job_queue(job_queue_id: str):
    job_controller = JobController()
    job_controller.get_job_queue(job_queue_id=job_queue_id)
@@ -4,6 +4,7 @@ Fetches data required and formats output for `anyscale cloud` commands.
4
4
 
5
5
  import copy
6
6
  from datetime import datetime, timedelta
7
+ import difflib
7
8
  import json
8
9
  from os import getenv
9
10
  import pathlib
@@ -14,9 +15,11 @@ from typing import Any, Dict, List, MutableSequence, Optional, Tuple
14
15
  import uuid
15
16
 
16
17
  import boto3
18
+ from boto3.resources.base import ServiceResource as Boto3Resource
17
19
  from botocore.exceptions import ClientError, NoCredentialsError
18
20
  import click
19
21
  from click import Abort, ClickException
22
+ import colorama
20
23
  from rich.progress import Progress, track
21
24
  import yaml
22
25
 
@@ -24,10 +27,12 @@ from anyscale import __version__ as anyscale_version
24
27
  from anyscale.aws_iam_policies import get_anyscale_iam_permissions_ec2_restricted
25
28
  from anyscale.cli_logger import CloudSetupLogger
26
29
  from anyscale.client.openapi_client.models import (
30
+ AWSConfig,
27
31
  AWSMemoryDBClusterConfig,
28
32
  CloudAnalyticsEventCloudResource,
29
33
  CloudAnalyticsEventCommandName,
30
34
  CloudAnalyticsEventName,
35
+ CloudDeployment,
31
36
  CloudDeploymentConfig,
32
37
  CloudProviders,
33
38
  CloudState,
@@ -39,6 +44,8 @@ from anyscale.client.openapi_client.models import (
39
44
  CreateCloudResourceGCP,
40
45
  EditableCloudResource,
41
46
  EditableCloudResourceGCP,
47
+ FileStorage,
48
+ GCPConfig,
42
49
  NFSMountTarget,
43
50
  SubnetIdWithAvailabilityZoneAWS,
44
51
  UpdateCloudWithCloudResource,
@@ -1388,7 +1395,7 @@ class CloudController(BaseController):
1388
1395
  cloud_id, CloudProviders.AWS, functions_to_verify, yes,
1389
1396
  )
1390
1397
 
1391
- def get_cloud_deployments(self, cloud_id: str, cloud_name: str) -> Dict[str, Any]:
1398
+ def get_cloud_deployments(self, cloud_id: str) -> Dict[str, Any]:
1392
1399
  cloud = self.api_client.get_cloud_api_v2_clouds_cloud_id_get(
1393
1400
  cloud_id=cloud_id,
1394
1401
  ).result
@@ -1404,7 +1411,7 @@ class CloudController(BaseController):
1404
1411
  ).results
1405
1412
  except Exception as e: # noqa: BLE001
1406
1413
  raise ClickException(
1407
- f"Failed to get cloud deployments for cloud {cloud_name} ({cloud_id}). Error: {e}"
1414
+ f"Failed to get cloud deployments for cloud {cloud.name} ({cloud_id}). Error: {e}"
1408
1415
  )
1409
1416
 
1410
1417
  # Avoid displaying fields with empty values (since the values for optional fields default to None).
@@ -1419,12 +1426,355 @@ class CloudController(BaseController):
1419
1426
 
1420
1427
  return {
1421
1428
  "id": cloud_id,
1422
- "name": cloud_name,
1429
+ "name": cloud.name,
1423
1430
  "deployments": [
1424
1431
  remove_empty_values(deployment.to_dict()) for deployment in deployments
1425
1432
  ],
1426
1433
  }
1427
1434
 
1435
def update_aws_anyscale_iam_role(
    self,
    cloud_id: str,
    region: str,
    anyscale_iam_role_id: Optional[str],
    external_id: Optional[str],
) -> Tuple[Optional[Boto3Resource], Optional[str]]:
    """
    Prepares the Anyscale IAM role's assume-role policy for the given cloud.

    Without an ``external_id``, the policy is rewritten through
    ``_update_external_ids_for_policy`` with ``cloud_id``. With an
    ``external_id``, the policy is only checked: every statement must already
    carry that external ID.

    Returns ``(role, original_policy_document)``, or ``(None, None)`` when no
    role ID is supplied.

    Raises ``ClickException`` if the external ID does not start with the
    organization ID, the role cannot be accessed, or the external ID is missing
    from the policy. A ``ClientError`` from the policy update is logged and
    re-raised.
    """
    # anyscale_iam_role_id is optional for k8s
    if not anyscale_iam_role_id:
        return None, None

    organization_id = get_organization_id(self.api_client)
    if external_id and not external_id.startswith(organization_id):
        raise ClickException(
            f"Invalid external ID: external ID must start with the organization ID: {organization_id}"
        )

    role_name = AwsRoleArn.from_string(anyscale_iam_role_id).to_role_name()
    role = _get_role(role_name, region)
    if role is None:
        self.log.log_resource_error(
            CloudAnalyticsEventCloudResource.AWS_IAM_ROLE,
            CloudSetupError.RESOURCE_NOT_FOUND,
        )
        raise ClickException(f"Failed to access IAM role {anyscale_iam_role_id}.")

    iam_role_original_policy = role.assume_role_policy_document  # type: ignore

    if external_id is not None:
        # Caller supplied the external ID: verify it, do not touch the policy.
        ids_per_statement = [
            statement.setdefault("Condition", {})
            .setdefault("StringEquals", {})
            .setdefault("sts:ExternalId", [])
            for statement in iam_role_original_policy.get("Statement", [])  # type: ignore
        ]

        def _statement_allows(allowed) -> bool:
            # "sts:ExternalId" may be a single string or a list of strings.
            if isinstance(allowed, str):
                return external_id == allowed
            return external_id in allowed

        if not all(_statement_allows(allowed) for allowed in ids_per_statement):
            raise ClickException(
                f"External ID {external_id} is not in the assume role policy of {anyscale_iam_role_id}."
            )
        return role, iam_role_original_policy

    # No external ID supplied: update the policy using the cloud ID.
    try:
        new_policy = _update_external_ids_for_policy(
            iam_role_original_policy, cloud_id
        )
        role.AssumeRolePolicy().update(PolicyDocument=json.dumps(new_policy))  # type: ignore
    except ClientError as e:
        self.log.log_resource_exception(
            CloudAnalyticsEventCloudResource.AWS_IAM_ROLE, e
        )
        raise e

    return role, iam_role_original_policy
1498
+
1499
def _generate_diff(self, existing: Dict[str, Any], new: Dict[str, Any]) -> str:
    """
    Generates a colorized unified diff between the YAML dumps of two dicts.

    Added lines are wrapped in green and removed lines in red; the ``---``,
    ``+++`` and ``@@`` control lines and context lines are left uncolored.
    Returns an empty string when the two dumps are identical.
    """

    # Split WITHOUT line endings and keep lineterm="" so that every yielded
    # line, control lines included, is newline-free. Mixing keepends=True
    # input with lineterm="" fused the "---", "+++" and "@@" lines onto the
    # line that followed them.
    diff_lines = difflib.unified_diff(
        yaml.dump(existing).splitlines(),
        yaml.dump(new).splitlines(),
        lineterm="",
    )

    def _colorize(line: str) -> str:
        if line.startswith("+") and not line.startswith("+++"):
            return f"{colorama.Fore.GREEN}{line}{colorama.Style.RESET_ALL}"
        if line.startswith("-") and not line.startswith("---"):
            return f"{colorama.Fore.RED}{line}{colorama.Style.RESET_ALL}"
        return line

    return "\n".join(_colorize(line) for line in diff_lines).strip()
1524
+
1525
def _compare_cloud_deployments(
    self,
    deployments: List[CloudDeployment],
    existing_deployments: Dict[str, CloudDeployment],
) -> List[CloudDeployment]:
    """
    Returns the deployments from ``deployments`` that are new (no deployment ID)
    or differ from the existing deployment with the same ID.

    Raises ``ClickException`` when an existing deployment is missing from the
    new list, when an ID is not among the existing deployments, or when a
    machine pool (PCP) deployment is being added or modified.
    """

    known_ids = existing_deployments.keys()
    submitted_ids = {
        candidate.cloud_deployment_id
        for candidate in deployments
        if candidate.cloud_deployment_id
    }

    # Every existing deployment must still be present in the new list.
    if known_ids - submitted_ids:
        raise ClickException("Deleting cloud deployments is not supported.")

    unknown_deployments = submitted_ids - known_ids
    if unknown_deployments:
        raise ClickException(
            f"Cloud deployment(s) {unknown_deployments} do not exist. Do not include a deployment ID when adding a new deployment."
        )

    changed: List[CloudDeployment] = []
    for candidate in deployments:
        deployment_id = candidate.cloud_deployment_id
        if not deployment_id:
            # New deployment: machine pools must go through their own command.
            if candidate.provider == CloudProviders.PCP:
                raise ClickException(
                    "Please use `anyscale machine-pool attach` to attach a machine pool to a cloud."
                )
        elif candidate == existing_deployments[deployment_id]:
            # Unchanged existing deployment: nothing to send.
            continue
        elif candidate.provider == CloudProviders.PCP:
            raise ClickException(
                "Updating machine pool deployments is not supported."
            )
        changed.append(candidate)

    return changed
1566
+
1567
def _preprocess_aws(self, cloud_id: str, deployment: CloudDeployment,) -> None:
    """
    Fills in derived AWS fields on ``deployment`` in place.

    Does nothing when the deployment has neither ``aws_config`` nor
    ``file_storage``. Otherwise, with valid local AWS credentials:
    - file storage: looks up the EFS mount target IP and stores it as the mount target;
    - aws config: updates/validates the Anyscale IAM role, defaults the
      external ID to ``cloud_id``, derives zones from the subnet IDs, and
      resolves the MemoryDB cluster ARN and endpoint.

    Raises ``ClickException`` for missing credentials, missing required fields
    or failed lookups. A ``ClientError`` from the EFS lookup is logged and
    re-raised.
    """
    if not deployment.aws_config and not deployment.file_storage:
        return

    if not validate_aws_credentials(self.log):
        raise ClickException(
            "Updating cloud deployments requires valid AWS credentials to be set locally. Learn more: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html"
        )

    # Get EFS mount target IP.
    if deployment.file_storage:
        file_storage = FileStorage(**deployment.file_storage)
        # Explicit check instead of `assert`: the value comes from the user's
        # spec file and asserts are stripped when Python runs with -O.
        if not file_storage.file_storage_id:
            raise ClickException(
                '"file_storage_id" is required to configure file storage'
            )

        try:
            boto3_session = boto3.Session(region_name=deployment.region)
            efs_mount_target_ip = _get_aws_efs_mount_target_ip(
                boto3_session, file_storage.file_storage_id,
            )
            if not efs_mount_target_ip:
                raise ClickException(
                    f"EFS mount target IP not found for {file_storage.file_storage_id}."
                )
            file_storage.mount_targets = [
                NFSMountTarget(address=efs_mount_target_ip)
            ]
        except ClientError as e:
            self.log.log_resource_exception(
                CloudAnalyticsEventCloudResource.AWS_EFS, e
            )
            raise

        deployment.file_storage = file_storage

    if deployment.aws_config:
        aws_config = AWSConfig(**deployment.aws_config)
        if not deployment.region:
            raise ClickException('"region" is required to configure an AWS deployment')

        # Update Anyscale IAM role's assume policy to include the cloud ID as the external ID.
        self.update_aws_anyscale_iam_role(
            cloud_id,
            deployment.region,
            aws_config.anyscale_iam_role_id,
            aws_config.external_id,
        )
        if aws_config.external_id is None:
            aws_config.external_id = cloud_id

        # Get zones corresponding to subnet IDs.
        if aws_config.subnet_ids:
            subnets_with_azs = associate_aws_subnets_with_azs(
                aws_config.subnet_ids, deployment.region, self.log
            )
            aws_config.zones = [s.availability_zone for s in subnets_with_azs]

        # Get memorydb config.
        if aws_config.memorydb_cluster_name:
            memorydb_cluster_config = _get_memorydb_cluster_config(
                aws_config.memorydb_cluster_name, deployment.region, self.log,
            )
            if not memorydb_cluster_config:
                raise ClickException(
                    f"MemoryDB cluster config not found for {aws_config.memorydb_cluster_name}."
                )
            aws_config.memorydb_cluster_arn = memorydb_cluster_config.id
            aws_config.memorydb_cluster_endpoint = memorydb_cluster_config.endpoint

        deployment.aws_config = aws_config
1632
+
1633
def _preprocess_gcp(
    self, deployment: CloudDeployment,
):
    """
    Fills in derived GCP fields on ``deployment`` in place.

    Does nothing when the deployment has no ``gcp_config``, or when it has
    neither file storage nor a memorystore instance name. Otherwise:
    - file storage with an ID: looks up the Filestore config and stores its
      root dir as the mount path and its IP as the mount target;
    - memorystore instance name: resolves the memorystore endpoint.

    Raises ``ClickException`` when ``project_id`` or ``vpc_name`` is missing
    where required, or when a Filestore/Memorystore lookup returns nothing.
    """
    if not deployment.gcp_config:
        return

    gcp_config = GCPConfig(**deployment.gcp_config)
    if not deployment.file_storage and not gcp_config.memorystore_instance_name:
        return

    if not gcp_config.project_id:
        raise ClickException(
            '"project_id" is required to configure filestore or memorystore'
        )

    gcp_utils = try_import_gcp_utils()
    factory = gcp_utils.get_google_cloud_client_factory(
        self.log, gcp_config.project_id
    )

    # Get Filestore mount target IP and root dir.
    if deployment.file_storage:
        fs = FileStorage(**deployment.file_storage)
        if fs.file_storage_id:
            if not gcp_config.vpc_name:
                raise ClickException(
                    '"vpc_name" is required to configure filestore'
                )
            filestore_config = gcp_utils.get_gcp_filestore_config_from_full_name(
                factory, gcp_config.vpc_name, fs.file_storage_id, self.log,
            )
            if not filestore_config:
                raise ClickException(
                    f"Filestore config not found for {fs.file_storage_id}."
                )
            fs.mount_path = filestore_config.root_dir
            fs.mount_targets = [
                NFSMountTarget(address=filestore_config.mount_target_ip)
            ]

        deployment.file_storage = fs

    # Get Memorystore config.
    if gcp_config.memorystore_instance_name:
        memorystore_config = gcp_utils.get_gcp_memorystore_config(
            factory, gcp_config.memorystore_instance_name
        )
        # Explicit check instead of `assert`, which is stripped under -O and
        # would otherwise surface as a bare AssertionError traceback.
        if not memorystore_config:
            raise ClickException(
                f"Memorystore config not found for {gcp_config.memorystore_instance_name}."
            )
        gcp_config.memorystore_endpoint = memorystore_config.endpoint

    deployment.gcp_config = gcp_config
1684
+
1685
def update_cloud_deployments(  # noqa: PLR0912
    self, spec_file: str, yes: bool = False,
):
    """
    Updates a cloud's deployments from a YAML spec file.

    The spec must be a mapping with ``id``, ``name`` and ``deployments``. It is
    diffed against the cloud's current deployments; if nothing changed the
    method logs that and returns. Otherwise the changed/new deployments are
    summarized, the user is asked to confirm (``yes`` is passed to ``confirm``),
    AWS/GCP deployments are preprocessed, and the update API is called.

    Raises ``ClickException`` for a missing/invalid spec file, a changed cloud
    name, unparsable deployments, or a failed update call.
    """
    # Read the spec file.
    path = pathlib.Path(spec_file)
    if not path.exists():
        raise ClickException(f"{spec_file} does not exist.")
    if not path.is_file():
        raise ClickException(f"{spec_file} is not a file.")

    try:
        spec = yaml.safe_load(path.read_text())
    except yaml.YAMLError as e:
        raise ClickException(f"Failed to parse {spec_file}: {e}") from e
    # An empty file loads as None and a scalar/list file is not a mapping;
    # reject both here instead of failing with a TypeError on `k in spec`.
    if not isinstance(spec, dict) or not all(
        k in spec for k in ["id", "name", "deployments"]
    ):
        raise ClickException(
            "Cloud ID, name, and deployments must be specified in the spec file."
        )

    # Get the existing spec.
    existing_spec = self.get_cloud_deployments(cloud_id=spec["id"],)
    if existing_spec["name"] != spec["name"]:
        raise ClickException("Changing the name of a cloud is not supported.")

    # Diff the existing and new specs
    diff = self._generate_diff(existing_spec, spec)
    if not diff:
        self.log.info("No changes detected.")
        return

    # Get updated/new deployments.
    try:
        deployments = [CloudDeployment(**d) for d in spec["deployments"]]
    except Exception as e:  # noqa: BLE001
        raise ClickException(f"Failed to parse deployments: {e}") from e

    existing_deployments = {
        deployment["cloud_deployment_id"]: CloudDeployment(**deployment)
        for deployment in existing_spec["deployments"]
    }

    # Figure out which deployments have been updated/added.
    updated_deployments = self._compare_cloud_deployments(
        deployments, existing_deployments,
    )

    # Log the diff and confirm.
    self.log.info(f"Detected the following changes:\n{diff}")

    existing_deployment_ids = {
        d.cloud_deployment_id for d in updated_deployments if d.cloud_deployment_id
    }
    new_deployment_count = len(updated_deployments) - len(existing_deployment_ids)
    if new_deployment_count:
        self.log.info(f"{new_deployment_count} new deployment(s) will be added.")
    if existing_deployment_ids:
        # Sorted so the message is stable across runs (set order is arbitrary).
        self.log.info(
            f"{len(existing_deployment_ids)} existing deployment(s) will be updated ({', '.join(sorted(existing_deployment_ids))})"
        )

    # Log an additional warning if a new deployment is being added but a deployment with the same AWS/GCP region already exists.
    existing_stack_provider_regions = {
        (d.compute_stack, d.provider, d.region)
        for d in existing_deployments.values()
        if d.provider in (CloudProviders.AWS, CloudProviders.GCP)
    }
    for d in updated_deployments:
        if (
            not d.cloud_deployment_id
            and (d.compute_stack, d.provider, d.region)
            in existing_stack_provider_regions
        ):
            self.log.warning(
                f"A {d.provider} {d.compute_stack} deployment in region {d.region} already exists."
            )

    confirm("Would you like to proceed with updating this cloud?", yes)

    # Preprocess the deployments if necessary.
    for deployment in updated_deployments:
        if deployment.provider == CloudProviders.AWS:
            self._preprocess_aws(cloud_id=spec["id"], deployment=deployment)
        elif deployment.provider == CloudProviders.GCP:
            self._preprocess_gcp(deployment=deployment)

    # Update the deployments.
    try:
        self.api_client.update_cloud_deployments_api_v2_clouds_cloud_id_deployments_put(
            cloud_id=spec["id"], cloud_deployment=updated_deployments,
        )
    except Exception as e:  # noqa: BLE001
        raise ClickException(f"Failed to update cloud deployments: {e}") from e

    self.log.info(f"Successfully updated cloud {spec['name']}!")
1777
+
1428
1778
  def get_cloud_config(
1429
1779
  self, cloud_name: Optional[str] = None, cloud_id: Optional[str] = None,
1430
1780
  ) -> CloudDeploymentConfig:
@@ -1982,12 +2332,6 @@ class CloudController(BaseController):
1982
2332
  if not cloud_storage_bucket_name.startswith(S3_STORAGE_PREFIX):
1983
2333
  cloud_storage_bucket_name = S3_STORAGE_PREFIX + cloud_storage_bucket_name
1984
2334
 
1985
- organization_id = get_organization_id(self.api_client)
1986
- if external_id and not external_id.startswith(organization_id):
1987
- raise ClickException(
1988
- f"Cloud registration failed! `--external-id` must start with the organization ID: {organization_id}"
1989
- )
1990
-
1991
2335
  self.cloud_event_producer.init_trace_context(
1992
2336
  CloudAnalyticsEventCommandName.REGISTER, CloudProviders.AWS
1993
2337
  )
@@ -2040,47 +2384,12 @@ class CloudController(BaseController):
2040
2384
  iam_role_original_policy = None
2041
2385
  if has_anyscale_iam_role:
2042
2386
  # Update anyscale IAM role's assume policy to include the cloud id as the external ID
2043
- role = _get_role(
2044
- AwsRoleArn.from_string(anyscale_iam_role_id).to_role_name(), region
2387
+ role, iam_role_original_policy = self.update_aws_anyscale_iam_role(
2388
+ cloud_id=cloud_id,
2389
+ region=region,
2390
+ anyscale_iam_role_id=anyscale_iam_role_id,
2391
+ external_id=external_id,
2045
2392
  )
2046
- if role is None:
2047
- self.log.log_resource_error(
2048
- CloudAnalyticsEventCloudResource.AWS_IAM_ROLE,
2049
- CloudSetupError.RESOURCE_NOT_FOUND,
2050
- )
2051
- raise ClickException(
2052
- f"Failed to access IAM role {anyscale_iam_role_id}."
2053
- )
2054
-
2055
- iam_role_original_policy = role.assume_role_policy_document # type: ignore
2056
- if external_id is None:
2057
- try:
2058
- new_policy = _update_external_ids_for_policy(
2059
- iam_role_original_policy, cloud_id
2060
- )
2061
- role.AssumeRolePolicy().update(PolicyDocument=json.dumps(new_policy)) # type: ignore
2062
- except ClientError as e:
2063
- self.log.log_resource_exception(
2064
- CloudAnalyticsEventCloudResource.AWS_IAM_ROLE, e
2065
- )
2066
- raise e
2067
- else:
2068
- fetched_external_ids = [
2069
- statement.setdefault("Condition", {})
2070
- .setdefault("StringEquals", {})
2071
- .setdefault("sts:ExternalId", [])
2072
- for statement in iam_role_original_policy.get("Statement", []) # type: ignore
2073
- ]
2074
- external_id_in_policy = all(
2075
- external_id == fetched_external_id
2076
- if isinstance(fetched_external_id, str)
2077
- else external_id in fetched_external_id
2078
- for fetched_external_id in fetched_external_ids
2079
- )
2080
- if not external_id_in_policy:
2081
- raise ClickException(
2082
- f"External ID {external_id} is not in the assume role policy of {anyscale_iam_role_id}."
2083
- )
2084
2393
 
2085
2394
  # When running on the VM compute stack, validate and retrieve the EFS mount target IP.
2086
2395
  # When running on the K8S compute stack, EFS is optional; if efs_id is provided, then