skypilot-nightly 1.0.0.dev20241023__py3-none-any.whl → 1.0.0.dev20241025__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31)
  1. sky/__init__.py +4 -2
  2. sky/adaptors/azure.py +11 -0
  3. sky/check.py +11 -4
  4. sky/cli.py +24 -16
  5. sky/clouds/azure.py +86 -50
  6. sky/clouds/cloud.py +4 -0
  7. sky/clouds/cloud_registry.py +55 -10
  8. sky/clouds/kubernetes.py +1 -1
  9. sky/clouds/oci.py +1 -1
  10. sky/clouds/service_catalog/azure_catalog.py +15 -0
  11. sky/clouds/service_catalog/kubernetes_catalog.py +7 -1
  12. sky/clouds/utils/azure_utils.py +91 -0
  13. sky/exceptions.py +4 -4
  14. sky/jobs/recovery_strategy.py +3 -3
  15. sky/provision/azure/azure-config-template.json +7 -1
  16. sky/provision/azure/config.py +24 -8
  17. sky/provision/azure/instance.py +251 -137
  18. sky/provision/kubernetes/instance.py +4 -2
  19. sky/provision/provisioner.py +16 -8
  20. sky/resources.py +1 -0
  21. sky/templates/azure-ray.yml.j2 +2 -0
  22. sky/usage/usage_lib.py +3 -2
  23. sky/utils/common_utils.py +3 -2
  24. sky/utils/controller_utils.py +69 -18
  25. {skypilot_nightly-1.0.0.dev20241023.dist-info → skypilot_nightly-1.0.0.dev20241025.dist-info}/METADATA +1 -1
  26. {skypilot_nightly-1.0.0.dev20241023.dist-info → skypilot_nightly-1.0.0.dev20241025.dist-info}/RECORD +30 -30
  27. sky/provision/azure/azure-vm-template.json +0 -301
  28. {skypilot_nightly-1.0.0.dev20241023.dist-info → skypilot_nightly-1.0.0.dev20241025.dist-info}/LICENSE +0 -0
  29. {skypilot_nightly-1.0.0.dev20241023.dist-info → skypilot_nightly-1.0.0.dev20241025.dist-info}/WHEEL +0 -0
  30. {skypilot_nightly-1.0.0.dev20241023.dist-info → skypilot_nightly-1.0.0.dev20241025.dist-info}/entry_points.txt +0 -0
  31. {skypilot_nightly-1.0.0.dev20241023.dist-info → skypilot_nightly-1.0.0.dev20241025.dist-info}/top_level.txt +0 -0
@@ -24,6 +24,7 @@ from sky.utils import common_utils
  from sky.utils import ux_utils

  if typing.TYPE_CHECKING:
+ from sky import resources
  from sky import task as task_lib

  logger = sky_logging.init_logger(__name__)
@@ -327,8 +328,7 @@ class StrategyExecutor:
  'Failure happened before provisioning. Failover '
  f'reasons: {reasons_str}')
  if raise_on_failure:
- raise exceptions.ProvisionPrechecksError(
- reasons=reasons)
+ raise exceptions.ProvisionPrechecksError(reasons)
  return None
  logger.info('Failed to launch a cluster with error: '
  f'{common_utils.format_exception(e)})')
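Note: the call above now passes reasons positionally rather than as a keyword argument (alongside the sky/exceptions.py change in this release). One plausible motivation, not confirmed by this diff, is that exceptions built from positional args survive pickling across processes, since default Exception pickling replays only self.args. A minimal standalone sketch of that behavior, using an illustrative stand-in rather than the real sky/exceptions.py class:

    import pickle

    # Illustrative stand-in; the actual class is defined in sky/exceptions.py.
    class ProvisionPrechecksError(Exception):
        def __init__(self, reasons):
            super().__init__(reasons)
            self.reasons = reasons

    err = ProvisionPrechecksError([ValueError('quota exceeded')])
    # Default Exception pickling re-invokes the constructor with self.args,
    # so a positional-only constructor round-trips cleanly.
    restored = pickle.loads(pickle.dumps(err))
    print(restored.reasons)  # [ValueError('quota exceeded')]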
@@ -382,7 +382,7 @@ class FailoverStrategyExecutor(StrategyExecutor, name='FAILOVER',
  # first retry in the same cloud/region. (Inside recover() we may not
  # rely on cluster handle, as it can be None if the cluster is
  # preempted.)
- self._launched_resources: Optional['sky.resources.Resources'] = None
+ self._launched_resources: Optional['resources.Resources'] = None

  def _launch(self,
  max_retry: Optional[int] = 3,
@@ -13,6 +13,12 @@
  "metadata": {
  "description": "Subnet parameters."
  }
+ },
+ "nsgName": {
+ "type": "string",
+ "metadata": {
+ "description": "Name of the Network Security Group associated with the SkyPilot cluster."
+ }
  }
  },
  "variables": {
@@ -20,7 +26,7 @@
  "location": "[resourceGroup().location]",
  "msiName": "[concat('sky-', parameters('clusterId'), '-msi')]",
  "roleAssignmentName": "[concat('sky-', parameters('clusterId'), '-ra')]",
- "nsgName": "[concat('sky-', parameters('clusterId'), '-nsg')]",
+ "nsgName": "[parameters('nsgName')]",
  "nsg": "[resourceId('Microsoft.Network/networkSecurityGroups', variables('nsgName'))]",
  "vnetName": "[concat('sky-', parameters('clusterId'), '-vnet')]",
  "subnetName": "[concat('sky-', parameters('clusterId'), '-subnet')]"
@@ -8,7 +8,7 @@ import json
  from pathlib import Path
  import random
  import time
- from typing import Any, Callable
+ from typing import Any, Callable, Tuple

  from sky import exceptions
  from sky import sky_logging
@@ -22,6 +22,7 @@ UNIQUE_ID_LEN = 4
  _DEPLOYMENT_NAME = 'skypilot-config'
  _LEGACY_DEPLOYMENT_NAME = 'ray-config'
  _RESOURCE_GROUP_WAIT_FOR_DELETION_TIMEOUT = 480 # 8 minutes
+ _CLUSTER_ID = '{cluster_name_on_cloud}-{unique_id}'


  def get_azure_sdk_function(client: Any, function_name: str) -> Callable:
@@ -41,11 +42,25 @@ def get_azure_sdk_function(client: Any, function_name: str) -> Callable:
  return func


+ def get_cluster_id_and_nsg_name(resource_group: str,
+ cluster_name_on_cloud: str) -> Tuple[str, str]:
+ hasher = hashlib.md5(resource_group.encode('utf-8'))
+ unique_id = hasher.hexdigest()[:UNIQUE_ID_LEN]
+ # We use the cluster name + resource group hash as the
+ # unique ID for the cluster, as we need to make sure that
+ # the deployments have unique names during failover.
+ cluster_id = _CLUSTER_ID.format(cluster_name_on_cloud=cluster_name_on_cloud,
+ unique_id=unique_id)
+ nsg_name = f'sky-{cluster_id}-nsg'
+ return cluster_id, nsg_name
+
+
  @common.log_function_start_end
  def bootstrap_instances(
  region: str, cluster_name_on_cloud: str,
  config: common.ProvisionConfig) -> common.ProvisionConfig:
  """See sky/provision/__init__.py"""
+ # TODO: use new azure sdk instead of ARM deployment.
  del region # unused
  provider_config = config.provider_config
  subscription_id = provider_config.get('subscription_id')
@@ -116,12 +131,13 @@ def bootstrap_instances(

  logger.info(f'Using cluster name: {cluster_name_on_cloud}')

- hasher = hashlib.md5(provider_config['resource_group'].encode('utf-8'))
- unique_id = hasher.hexdigest()[:UNIQUE_ID_LEN]
+ cluster_id, nsg_name = get_cluster_id_and_nsg_name(
+ resource_group=provider_config['resource_group'],
+ cluster_name_on_cloud=cluster_name_on_cloud)
  subnet_mask = provider_config.get('subnet_mask')
  if subnet_mask is None:
  # choose a random subnet, skipping most common value of 0
- random.seed(unique_id)
+ random.seed(cluster_id)
  subnet_mask = f'10.{random.randint(1, 254)}.0.0/16'
  logger.info(f'Using subnet mask: {subnet_mask}')

@@ -134,10 +150,10 @@ def bootstrap_instances(
  'value': subnet_mask
  },
  'clusterId': {
- # We use the cluster name + resource group hash as the
- # unique ID for the cluster, as we need to make sure that
- # the deployments have unique names during failover.
- 'value': f'{cluster_name_on_cloud}-{unique_id}'
+ 'value': cluster_id
+ },
+ 'nsgName': {
+ 'value': nsg_name
  },
  },
  }
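Note: the config.py hunks above derive a deterministic cluster ID and NSG name via the new get_cluster_id_and_nsg_name(), which is also reused later in instance.py. A standalone sketch of that derivation (UNIQUE_ID_LEN is 4 per the module constant shown above; the resource group and cluster name here are made up):

    import hashlib

    UNIQUE_ID_LEN = 4  # matches the constant in sky/provision/azure/config.py

    def cluster_id_and_nsg_name(resource_group: str, cluster_name_on_cloud: str):
        # Hash the resource group so the same cluster name in different
        # resource groups still yields distinct deployment/NSG names.
        unique_id = hashlib.md5(
            resource_group.encode('utf-8')).hexdigest()[:UNIQUE_ID_LEN]
        cluster_id = f'{cluster_name_on_cloud}-{unique_id}'
        return cluster_id, f'sky-{cluster_id}-nsg'

    print(cluster_id_and_nsg_name('sky-rg-eastus', 'mycluster-1a2b'))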
@@ -2,10 +2,8 @@
  import base64
  import copy
  import enum
- import json
  import logging
  from multiprocessing import pool
- import pathlib
  import time
  import typing
  from typing import Any, Callable, Dict, List, Optional, Tuple
@@ -17,13 +15,16 @@ from sky import status_lib
  from sky.adaptors import azure
  from sky.provision import common
  from sky.provision import constants
+ from sky.provision.azure import config as config_lib
  from sky.utils import common_utils
  from sky.utils import subprocess_utils
  from sky.utils import ux_utils

  if typing.TYPE_CHECKING:
  from azure.mgmt import compute as azure_compute
- from azure.mgmt import resource as azure_resource
+ from azure.mgmt import network as azure_network
+ from azure.mgmt.compute import models as azure_compute_models
+ from azure.mgmt.network import models as azure_network_models

  logger = sky_logging.init_logger(__name__)

@@ -31,6 +32,8 @@ logger = sky_logging.init_logger(__name__)
  # https://github.com/Azure/azure-sdk-for-python/issues/9422
  azure_logger = logging.getLogger('azure')
  azure_logger.setLevel(logging.WARNING)
+ Client = Any
+ NetworkSecurityGroup = Any

  _RESUME_INSTANCE_TIMEOUT = 480 # 8 minutes
  _RESUME_PER_INSTANCE_TIMEOUT = 120 # 2 minutes
@@ -40,6 +43,10 @@ _WAIT_CREATION_TIMEOUT_SECONDS = 600

  _RESOURCE_GROUP_NOT_FOUND_ERROR_MESSAGE = 'ResourceGroupNotFound'
  _POLL_INTERVAL = 1
+ # TODO(Doyoung): _LEGACY_NSG_NAME can be remove this after 0.8.0 to ignore
+ # legacy nsg names.
+ _LEGACY_NSG_NAME = 'ray-{cluster_name_on_cloud}-nsg'
+ _SECOND_LEGACY_NSG_NAME = 'sky-{cluster_name_on_cloud}-nsg'


  class AzureInstanceStatus(enum.Enum):
@@ -184,14 +191,150 @@ def _get_head_instance_id(instances: List) -> Optional[str]:
  return head_instance_id


- def _create_instances(
- compute_client: 'azure_compute.ComputeManagementClient',
- resource_client: 'azure_resource.ResourceManagementClient',
- cluster_name_on_cloud: str, resource_group: str,
- provider_config: Dict[str, Any], node_config: Dict[str, Any],
- tags: Dict[str, str], count: int) -> List:
+ def _create_network_interface(
+ network_client: 'azure_network.NetworkManagementClient', vm_name: str,
+ provider_config: Dict[str,
+ Any]) -> 'azure_network_models.NetworkInterface':
+ network = azure.azure_mgmt_models('network')
+ compute = azure.azure_mgmt_models('compute')
+ logger.info(f'Start creating network interface for {vm_name}...')
+ if provider_config.get('use_internal_ips', False):
+ name = f'{vm_name}-nic-private'
+ ip_config = network.IPConfiguration(
+ name=f'ip-config-private-{vm_name}',
+ subnet=compute.SubResource(id=provider_config['subnet']),
+ private_ip_allocation_method=network.IPAllocationMethod.DYNAMIC)
+ else:
+ name = f'{vm_name}-nic-public'
+ public_ip_address = network.PublicIPAddress(
+ location=provider_config['location'],
+ public_ip_allocation_method='Static',
+ public_ip_address_version='IPv4',
+ sku=network.PublicIPAddressSku(name='Basic', tier='Regional'))
+ ip_poller = network_client.public_ip_addresses.begin_create_or_update(
+ resource_group_name=provider_config['resource_group'],
+ public_ip_address_name=f'{vm_name}-ip',
+ parameters=public_ip_address)
+ logger.info(f'Created public IP address {ip_poller.result().name} '
+ f'with address {ip_poller.result().ip_address}.')
+ ip_config = network.IPConfiguration(
+ name=f'ip-config-public-{vm_name}',
+ subnet=compute.SubResource(id=provider_config['subnet']),
+ private_ip_allocation_method=network.IPAllocationMethod.DYNAMIC,
+ public_ip_address=network.PublicIPAddress(id=ip_poller.result().id))
+
+ ni_poller = network_client.network_interfaces.begin_create_or_update(
+ resource_group_name=provider_config['resource_group'],
+ network_interface_name=name,
+ parameters=network.NetworkInterface(
+ location=provider_config['location'],
+ ip_configurations=[ip_config],
+ network_security_group=network.NetworkSecurityGroup(
+ id=provider_config['nsg'])))
+ logger.info(f'Created network interface {ni_poller.result().name}.')
+ return ni_poller.result()
+
+
+ def _create_vm(
+ compute_client: 'azure_compute.ComputeManagementClient', vm_name: str,
+ node_tags: Dict[str, str], provider_config: Dict[str, Any],
+ node_config: Dict[str, Any],
+ network_interface_id: str) -> 'azure_compute_models.VirtualMachine':
+ compute = azure.azure_mgmt_models('compute')
+ logger.info(f'Start creating VM {vm_name}...')
+ hardware_profile = compute.HardwareProfile(
+ vm_size=node_config['azure_arm_parameters']['vmSize'])
+ network_profile = compute.NetworkProfile(network_interfaces=[
+ compute.NetworkInterfaceReference(id=network_interface_id, primary=True)
+ ])
+ public_key = node_config['azure_arm_parameters']['publicKey']
+ username = node_config['azure_arm_parameters']['adminUsername']
+ os_linux_custom_data = base64.b64encode(
+ node_config['azure_arm_parameters']['cloudInitSetupCommands'].encode(
+ 'utf-8')).decode('utf-8')
+ os_profile = compute.OSProfile(
+ admin_username=username,
+ computer_name=vm_name,
+ admin_password=public_key,
+ linux_configuration=compute.LinuxConfiguration(
+ disable_password_authentication=True,
+ ssh=compute.SshConfiguration(public_keys=[
+ compute.SshPublicKey(
+ path=f'/home/{username}/.ssh/authorized_keys',
+ key_data=public_key)
+ ])),
+ custom_data=os_linux_custom_data)
+ community_image_id = node_config['azure_arm_parameters'].get(
+ 'communityGalleryImageId', None)
+ if community_image_id is not None:
+ # Prioritize using community gallery image if specified.
+ image_reference = compute.ImageReference(
+ community_gallery_image_id=community_image_id)
+ logger.info(
+ f'Used community_image_id: {community_image_id} for VM {vm_name}.')
+ else:
+ image_reference = compute.ImageReference(
+ publisher=node_config['azure_arm_parameters']['imagePublisher'],
+ offer=node_config['azure_arm_parameters']['imageOffer'],
+ sku=node_config['azure_arm_parameters']['imageSku'],
+ version=node_config['azure_arm_parameters']['imageVersion'])
+ storage_profile = compute.StorageProfile(
+ image_reference=image_reference,
+ os_disk=compute.OSDisk(
+ create_option=compute.DiskCreateOptionTypes.FROM_IMAGE,
+ managed_disk=compute.ManagedDiskParameters(
+ storage_account_type=node_config['azure_arm_parameters']
+ ['osDiskTier']),
+ disk_size_gb=node_config['azure_arm_parameters']['osDiskSizeGB']))
+ vm_instance = compute.VirtualMachine(
+ location=provider_config['location'],
+ tags=node_tags,
+ hardware_profile=hardware_profile,
+ os_profile=os_profile,
+ storage_profile=storage_profile,
+ network_profile=network_profile,
+ identity=compute.VirtualMachineIdentity(
+ type='UserAssigned',
+ user_assigned_identities={provider_config['msi']: {}}))
+ vm_poller = compute_client.virtual_machines.begin_create_or_update(
+ resource_group_name=provider_config['resource_group'],
+ vm_name=vm_name,
+ parameters=vm_instance,
+ )
+ # poller.result() will block on async operation until it's done.
+ logger.info(f'Created VM {vm_poller.result().name}.')
+ # Configure driver extension for A10 GPUs. A10 GPUs requires a
+ # special type of drivers which is available at Microsoft HPC
+ # extension. Reference:
+ # https://forums.developer.nvidia.com/t/ubuntu-22-04-installation-driver-error-nvidia-a10/285195/2
+ # This can take more than 20mins for setting up the A10 GPUs
+ if node_config.get('need_nvidia_driver_extension', False):
+ ext_poller = compute_client.virtual_machine_extensions.\
+ begin_create_or_update(
+ resource_group_name=provider_config['resource_group'],
+ vm_name=vm_name,
+ vm_extension_name='NvidiaGpuDriverLinux',
+ extension_parameters=compute.VirtualMachineExtension(
+ location=provider_config['location'],
+ publisher='Microsoft.HpcCompute',
+ type_properties_type='NvidiaGpuDriverLinux',
+ type_handler_version='1.9',
+ auto_upgrade_minor_version=True,
+ settings='{}'))
+ logger.info(
+ f'Created VM extension {ext_poller.result().name} for VM {vm_name}.'
+ )
+ return vm_poller.result()
+
+
+ def _create_instances(compute_client: 'azure_compute.ComputeManagementClient',
+ network_client: 'azure_network.NetworkManagementClient',
+ cluster_name_on_cloud: str, resource_group: str,
+ provider_config: Dict[str, Any], node_config: Dict[str,
+ Any],
+ tags: Dict[str, str], count: int) -> List:
  vm_id = uuid4().hex[:UNIQUE_ID_LEN]
- tags = {
+ all_tags = {
  constants.TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud,
  constants.TAG_SKYPILOT_CLUSTER_NAME: cluster_name_on_cloud,
  **constants.WORKER_NODE_TAGS,
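Note: _create_vm above base64-encodes the cloud-init commands before passing them as custom_data, because Azure expects VM custom data to be base64-encoded (see https://learn.microsoft.com/en-us/azure/virtual-machines/custom-data, also cited in the removed template code below). A minimal standalone sketch of that round trip, with a made-up cloud-init payload:

    import base64

    cloud_init = '#cloud-config\nruncmd:\n  - echo provisioned\n'
    custom_data = base64.b64encode(cloud_init.encode('utf-8')).decode('utf-8')
    # Decoding recovers the original script byte-for-byte.
    assert base64.b64decode(custom_data).decode('utf-8') == cloud_init
    print(custom_data)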
@@ -199,83 +342,19 @@ def _create_instances(
  **tags,
  }
  node_tags = node_config['tags'].copy()
- node_tags.update(tags)
+ node_tags.update(all_tags)

- # load the template file
- current_path = pathlib.Path(__file__).parent
- template_path = current_path.joinpath('azure-vm-template.json')
- with open(template_path, 'r', encoding='utf-8') as template_fp:
- template = json.load(template_fp)
+ # Create VM instances in parallel.
+ def create_single_instance(vm_i):
+ vm_name = f'{cluster_name_on_cloud}-{vm_id}-{vm_i}'
+ network_interface = _create_network_interface(network_client, vm_name,
+ provider_config)
+ _create_vm(compute_client, vm_name, node_tags, provider_config,
+ node_config, network_interface.id)

- vm_name = f'{cluster_name_on_cloud}-{vm_id}'
- use_internal_ips = provider_config.get('use_internal_ips', False)
-
- template_params = node_config['azure_arm_parameters'].copy()
- # We don't include 'head' or 'worker' in the VM name as on Azure the VM
- # name is immutable and we may change the node type for existing VM in the
- # multi-node cluster, due to manual termination of the head node.
- template_params['vmName'] = vm_name
- template_params['provisionPublicIp'] = not use_internal_ips
- template_params['vmTags'] = node_tags
- template_params['vmCount'] = count
- template_params['msi'] = provider_config['msi']
- template_params['nsg'] = provider_config['nsg']
- template_params['subnet'] = provider_config['subnet']
- # In Azure, cloud-init script must be encoded in base64. For more
- # information, see:
- # https://learn.microsoft.com/en-us/azure/virtual-machines/custom-data
- template_params['cloudInitSetupCommands'] = (base64.b64encode(
- template_params['cloudInitSetupCommands'].encode('utf-8')).decode(
- 'utf-8'))
-
- if node_config.get('need_nvidia_driver_extension', False):
- # pylint: disable=line-too-long
- # Configure driver extension for A10 GPUs. A10 GPUs requires a
- # special type of drivers which is available at Microsoft HPC
- # extension. Reference: https://forums.developer.nvidia.com/t/ubuntu-22-04-installation-driver-error-nvidia-a10/285195/2
- for r in template['resources']:
- if r['type'] == 'Microsoft.Compute/virtualMachines':
- # Add a nested extension resource for A10 GPUs
- r['resources'] = [
- {
- 'type': 'extensions',
- 'apiVersion': '2015-06-15',
- 'location': '[variables(\'location\')]',
- 'dependsOn': [
- '[concat(\'Microsoft.Compute/virtualMachines/\', parameters(\'vmName\'), copyIndex())]'
- ],
- 'name': 'NvidiaGpuDriverLinux',
- 'properties': {
- 'publisher': 'Microsoft.HpcCompute',
- 'type': 'NvidiaGpuDriverLinux',
- 'typeHandlerVersion': '1.9',
- 'autoUpgradeMinorVersion': True,
- 'settings': {},
- },
- },
- ]
- break
-
- parameters = {
- 'properties': {
- 'mode': azure.deployment_mode().incremental,
- 'template': template,
- 'parameters': {
- key: {
- 'value': value
- } for key, value in template_params.items()
- },
- }
- }
-
- create_or_update = _get_azure_sdk_function(
- client=resource_client.deployments, function_name='create_or_update')
- create_or_update(
- resource_group_name=resource_group,
- deployment_name=vm_name,
- parameters=parameters,
- ).wait()
+ subprocess_utils.run_in_parallel(create_single_instance, range(count))

+ # Update disk performance tier
  performance_tier = node_config.get('disk_performance_tier', None)
  if performance_tier is not None:
  disks = compute_client.disks.list_by_resource_group(resource_group)
@@ -286,12 +365,14 @@ def _create_instances(
  f'az disk update -n {name} -g {resource_group} '
  f'--set tier={performance_tier}')

+ # Validation
  filters = {
  constants.TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud,
  _TAG_SKYPILOT_VM_ID: vm_id
  }
  instances = _filter_instances(compute_client, resource_group, filters)
  assert len(instances) == count, (len(instances), count)
+
  return instances

@@ -303,7 +384,7 @@ def run_instances(region: str, cluster_name_on_cloud: str,
  resource_group = provider_config['resource_group']
  subscription_id = provider_config['subscription_id']
  compute_client = azure.get_client('compute', subscription_id)
-
+ network_client = azure.get_client('network', subscription_id)
  instances_to_resume = []
  resumed_instance_ids: List[str] = []
  created_instance_ids: List[str] = []
@@ -439,12 +520,11 @@ def run_instances(region: str, cluster_name_on_cloud: str,
  to_start_count -= len(resumed_instance_ids)

  if to_start_count > 0:
- resource_client = azure.get_client('resource', subscription_id)
  logger.debug(f'run_instances: Creating {to_start_count} instances.')
  try:
  created_instances = _create_instances(
  compute_client=compute_client,
- resource_client=resource_client,
+ network_client=network_client,
  cluster_name_on_cloud=cluster_name_on_cloud,
  resource_group=resource_group,
  provider_config=provider_config,
@@ -722,6 +802,32 @@ def query_instances(
  return statuses


+ # TODO(Doyoung): _get_cluster_nsg can be remove this after 0.8.0 to ignore
+ # legacy nsg names.
+ def _get_cluster_nsg(network_client: Client, resource_group: str,
+ cluster_name_on_cloud: str) -> NetworkSecurityGroup:
+ """Retrieve the NSG associated with the given name of the cluster."""
+ list_network_security_groups = _get_azure_sdk_function(
+ client=network_client.network_security_groups, function_name='list')
+ legacy_nsg_name = _LEGACY_NSG_NAME.format(
+ cluster_name_on_cloud=cluster_name_on_cloud)
+ second_legacy_nsg_name = _SECOND_LEGACY_NSG_NAME.format(
+ cluster_name_on_cloud=cluster_name_on_cloud)
+ _, nsg_name = config_lib.get_cluster_id_and_nsg_name(
+ resource_group=resource_group,
+ cluster_name_on_cloud=cluster_name_on_cloud)
+ possible_nsg_names = [nsg_name, legacy_nsg_name, second_legacy_nsg_name]
+ for nsg in list_network_security_groups(resource_group):
+ if nsg.name in possible_nsg_names:
+ return nsg
+
+ # Raise an error if no matching NSG is found
+ raise ValueError('Failed to find a matching NSG for cluster '
+ f'{cluster_name_on_cloud!r} in resource group '
+ f'{resource_group!r}. Expected NSG names were: '
+ f'{possible_nsg_names}.')
+
+
  def open_ports(
  cluster_name_on_cloud: str,
  ports: List[str],
@@ -736,58 +842,66 @@ def open_ports(
  update_network_security_groups = _get_azure_sdk_function(
  client=network_client.network_security_groups,
  function_name='create_or_update')
- list_network_security_groups = _get_azure_sdk_function(
- client=network_client.network_security_groups, function_name='list')
- for nsg in list_network_security_groups(resource_group):
- try:
- # Wait the NSG creation to be finished before opening a port. The
- # cluster provisioning triggers the NSG creation, but it may not be
- # finished yet.
- backoff = common_utils.Backoff(max_backoff_factor=1)
- start_time = time.time()
- while True:
- if nsg.provisioning_state not in ['Creating', 'Updating']:
- break
- if time.time() - start_time > _WAIT_CREATION_TIMEOUT_SECONDS:
- logger.warning(
- f'Fails to wait for the creation of NSG {nsg.name} in '
- f'{resource_group} within '
- f'{_WAIT_CREATION_TIMEOUT_SECONDS} seconds. '
- 'Skip this NSG.')
- backoff_time = backoff.current_backoff()
- logger.info(f'NSG {nsg.name} is not created yet. Waiting for '
- f'{backoff_time} seconds before checking again.')
- time.sleep(backoff_time)
-
- # Azure NSG rules have a priority field that determines the order
- # in which they are applied. The priority must be unique across
- # all inbound rules in one NSG.
- priority = max(rule.priority
- for rule in nsg.security_rules
- if rule.direction == 'Inbound') + 1
- nsg.security_rules.append(
- azure.create_security_rule(
- name=f'sky-ports-{cluster_name_on_cloud}-{priority}',
- priority=priority,
- protocol='Tcp',
- access='Allow',
- direction='Inbound',
- source_address_prefix='*',
- source_port_range='*',
- destination_address_prefix='*',
- destination_port_ranges=ports,
- ))
- poller = update_network_security_groups(resource_group, nsg.name,
- nsg)
- poller.wait()
- if poller.status() != 'Succeeded':
+
+ try:
+ # Wait for the NSG creation to be finished before opening a port. The
+ # cluster provisioning triggers the NSG creation, but it may not be
+ # finished yet.
+ backoff = common_utils.Backoff(max_backoff_factor=1)
+ start_time = time.time()
+ while True:
+ nsg = _get_cluster_nsg(network_client, resource_group,
+ cluster_name_on_cloud)
+ if nsg.provisioning_state not in ['Creating', 'Updating']:
+ break
+ if time.time() - start_time > _WAIT_CREATION_TIMEOUT_SECONDS:
  with ux_utils.print_exception_no_traceback():
- raise ValueError(f'Failed to open ports {ports} in NSG '
- f'{nsg.name}: {poller.status()}')
- except azure.exceptions().HttpResponseError as e:
+ raise TimeoutError(
+ f'Timed out while waiting for the Network '
+ f'Security Group {nsg.name!r} to be ready for '
+ f'cluster {cluster_name_on_cloud!r} in '
+ f'resource group {resource_group!r}. The NSG '
+ f'did not reach a stable state '
+ '(Creating/Updating) within the allocated '
+ f'{_WAIT_CREATION_TIMEOUT_SECONDS} seconds. '
+ 'Consequently, the operation to open ports '
+ f'{ports} failed.')
+
+ backoff_time = backoff.current_backoff()
+ logger.info(f'NSG {nsg.name} is not created yet. Waiting for '
+ f'{backoff_time} seconds before checking again.')
+ time.sleep(backoff_time)
+
+ # Azure NSG rules have a priority field that determines the order
+ # in which they are applied. The priority must be unique across
+ # all inbound rules in one NSG.
+ priority = max(rule.priority
+ for rule in nsg.security_rules
+ if rule.direction == 'Inbound') + 1
+ nsg.security_rules.append(
+ azure.create_security_rule(
+ name=f'sky-ports-{cluster_name_on_cloud}-{priority}',
+ priority=priority,
+ protocol='Tcp',
+ access='Allow',
+ direction='Inbound',
+ source_address_prefix='*',
+ source_port_range='*',
+ destination_address_prefix='*',
+ destination_port_ranges=ports,
+ ))
+ poller = update_network_security_groups(resource_group, nsg.name, nsg)
+ poller.wait()
+ if poller.status() != 'Succeeded':
  with ux_utils.print_exception_no_traceback():
- raise ValueError(
- f'Failed to open ports {ports} in NSG {nsg.name}.') from e
+ raise ValueError(f'Failed to open ports {ports} in NSG '
+ f'{nsg.name}: {poller.status()}')
+
+ except azure.exceptions().HttpResponseError as e:
+ with ux_utils.print_exception_no_traceback():
+ raise ValueError(f'Failed to open ports {ports} in NSG for cluster '
+ f'{cluster_name_on_cloud!r} within resource group '
+ f'{resource_group!r}.') from e


  def cleanup_ports(
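Note: the rewritten open_ports above resolves the cluster's NSG via _get_cluster_nsg (including the legacy names) and then appends a rule whose priority is one above the highest existing inbound priority. A small standalone sketch of that priority selection, using simplified stand-in rule objects rather than the Azure SDK models:

    from dataclasses import dataclass

    @dataclass
    class Rule:
        priority: int
        direction: str

    existing_rules = [Rule(1000, 'Inbound'), Rule(1010, 'Inbound'),
                      Rule(100, 'Outbound')]
    # Next free inbound priority: max existing inbound priority + 1.
    next_priority = max(rule.priority for rule in existing_rules
                        if rule.direction == 'Inbound') + 1
    print(next_priority)  # 1011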
@@ -18,6 +18,7 @@ from sky.provision.kubernetes import utils as kubernetes_utils
  from sky.utils import command_runner
  from sky.utils import common_utils
  from sky.utils import kubernetes_enums
+ from sky.utils import subprocess_utils
  from sky.utils import ux_utils

  POLL_INTERVAL = 2
@@ -398,8 +399,7 @@ def _setup_ssh_in_pods(namespace: str, context: Optional[str],
  # See https://www.educative.io/answers/error-mesg-ttyname-failed-inappropriate-ioctl-for-device # pylint: disable=line-too-long
  '$(prefix_cmd) sed -i "s/mesg n/tty -s \\&\\& mesg n/" ~/.profile;')

- # TODO(romilb): Parallelize the setup of SSH in pods for multi-node clusters
- for new_node in new_nodes:
+ def _setup_ssh_thread(new_node):
  pod_name = new_node.metadata.name
  runner = command_runner.KubernetesCommandRunner(
  ((namespace, context), pod_name))
@@ -411,6 +411,8 @@ def _setup_ssh_in_pods(namespace: str, context: Optional[str],
  stdout)
  logger.info(f'{"-"*20}End: Set up SSH in pod {pod_name!r} {"-"*20}')

+ subprocess_utils.run_in_parallel(_setup_ssh_thread, new_nodes)
+

  def _label_pod(namespace: str, context: Optional[str], pod_name: str,
  label: Dict[str, str]) -> None:
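Note: both the Azure per-VM creation earlier and this per-pod SSH setup now fan work out through subprocess_utils.run_in_parallel(fn, iterable). A rough standalone sketch of the same fan-out pattern using only the standard library (the node names and setup body here are made up for illustration):

    from concurrent.futures import ThreadPoolExecutor

    def setup_one(node_name: str) -> str:
        # Stand-in for the per-pod SSH setup (or per-VM creation) body.
        return f'configured {node_name}'

    nodes = ['pod-0', 'pod-1', 'pod-2']
    # Threads are a reasonable fit here because the real work is I/O-bound
    # (SSH and cloud API calls), so the GIL is not the bottleneck.
    with ThreadPoolExecutor() as executor:
        results = list(executor.map(setup_one, nodes))
    print(results)  # ['configured pod-0', 'configured pod-1', 'configured pod-2']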