skypilot-nightly 1.0.0.dev20241023__py3-none-any.whl → 1.0.0.dev20241025__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sky/__init__.py +4 -2
- sky/adaptors/azure.py +11 -0
- sky/check.py +11 -4
- sky/cli.py +24 -16
- sky/clouds/azure.py +86 -50
- sky/clouds/cloud.py +4 -0
- sky/clouds/cloud_registry.py +55 -10
- sky/clouds/kubernetes.py +1 -1
- sky/clouds/oci.py +1 -1
- sky/clouds/service_catalog/azure_catalog.py +15 -0
- sky/clouds/service_catalog/kubernetes_catalog.py +7 -1
- sky/clouds/utils/azure_utils.py +91 -0
- sky/exceptions.py +4 -4
- sky/jobs/recovery_strategy.py +3 -3
- sky/provision/azure/azure-config-template.json +7 -1
- sky/provision/azure/config.py +24 -8
- sky/provision/azure/instance.py +251 -137
- sky/provision/kubernetes/instance.py +4 -2
- sky/provision/provisioner.py +16 -8
- sky/resources.py +1 -0
- sky/templates/azure-ray.yml.j2 +2 -0
- sky/usage/usage_lib.py +3 -2
- sky/utils/common_utils.py +3 -2
- sky/utils/controller_utils.py +69 -18
- {skypilot_nightly-1.0.0.dev20241023.dist-info → skypilot_nightly-1.0.0.dev20241025.dist-info}/METADATA +1 -1
- {skypilot_nightly-1.0.0.dev20241023.dist-info → skypilot_nightly-1.0.0.dev20241025.dist-info}/RECORD +30 -30
- sky/provision/azure/azure-vm-template.json +0 -301
- {skypilot_nightly-1.0.0.dev20241023.dist-info → skypilot_nightly-1.0.0.dev20241025.dist-info}/LICENSE +0 -0
- {skypilot_nightly-1.0.0.dev20241023.dist-info → skypilot_nightly-1.0.0.dev20241025.dist-info}/WHEEL +0 -0
- {skypilot_nightly-1.0.0.dev20241023.dist-info → skypilot_nightly-1.0.0.dev20241025.dist-info}/entry_points.txt +0 -0
- {skypilot_nightly-1.0.0.dev20241023.dist-info → skypilot_nightly-1.0.0.dev20241025.dist-info}/top_level.txt +0 -0
sky/jobs/recovery_strategy.py
CHANGED
@@ -24,6 +24,7 @@ from sky.utils import common_utils
|
|
24
24
|
from sky.utils import ux_utils
|
25
25
|
|
26
26
|
if typing.TYPE_CHECKING:
|
27
|
+
from sky import resources
|
27
28
|
from sky import task as task_lib
|
28
29
|
|
29
30
|
logger = sky_logging.init_logger(__name__)
|
@@ -327,8 +328,7 @@ class StrategyExecutor:
|
|
327
328
|
'Failure happened before provisioning. Failover '
|
328
329
|
f'reasons: {reasons_str}')
|
329
330
|
if raise_on_failure:
|
330
|
-
raise exceptions.ProvisionPrechecksError(
|
331
|
-
reasons=reasons)
|
331
|
+
raise exceptions.ProvisionPrechecksError(reasons)
|
332
332
|
return None
|
333
333
|
logger.info('Failed to launch a cluster with error: '
|
334
334
|
f'{common_utils.format_exception(e)})')
|
@@ -382,7 +382,7 @@ class FailoverStrategyExecutor(StrategyExecutor, name='FAILOVER',
|
|
382
382
|
# first retry in the same cloud/region. (Inside recover() we may not
|
383
383
|
# rely on cluster handle, as it can be None if the cluster is
|
384
384
|
# preempted.)
|
385
|
-
self._launched_resources: Optional['
|
385
|
+
self._launched_resources: Optional['resources.Resources'] = None
|
386
386
|
|
387
387
|
def _launch(self,
|
388
388
|
max_retry: Optional[int] = 3,
|
@@ -13,6 +13,12 @@
|
|
13
13
|
"metadata": {
|
14
14
|
"description": "Subnet parameters."
|
15
15
|
}
|
16
|
+
},
|
17
|
+
"nsgName": {
|
18
|
+
"type": "string",
|
19
|
+
"metadata": {
|
20
|
+
"description": "Name of the Network Security Group associated with the SkyPilot cluster."
|
21
|
+
}
|
16
22
|
}
|
17
23
|
},
|
18
24
|
"variables": {
|
@@ -20,7 +26,7 @@
|
|
20
26
|
"location": "[resourceGroup().location]",
|
21
27
|
"msiName": "[concat('sky-', parameters('clusterId'), '-msi')]",
|
22
28
|
"roleAssignmentName": "[concat('sky-', parameters('clusterId'), '-ra')]",
|
23
|
-
"nsgName": "[
|
29
|
+
"nsgName": "[parameters('nsgName')]",
|
24
30
|
"nsg": "[resourceId('Microsoft.Network/networkSecurityGroups', variables('nsgName'))]",
|
25
31
|
"vnetName": "[concat('sky-', parameters('clusterId'), '-vnet')]",
|
26
32
|
"subnetName": "[concat('sky-', parameters('clusterId'), '-subnet')]"
|
sky/provision/azure/config.py
CHANGED
@@ -8,7 +8,7 @@ import json
|
|
8
8
|
from pathlib import Path
|
9
9
|
import random
|
10
10
|
import time
|
11
|
-
from typing import Any, Callable
|
11
|
+
from typing import Any, Callable, Tuple
|
12
12
|
|
13
13
|
from sky import exceptions
|
14
14
|
from sky import sky_logging
|
@@ -22,6 +22,7 @@ UNIQUE_ID_LEN = 4
|
|
22
22
|
_DEPLOYMENT_NAME = 'skypilot-config'
|
23
23
|
_LEGACY_DEPLOYMENT_NAME = 'ray-config'
|
24
24
|
_RESOURCE_GROUP_WAIT_FOR_DELETION_TIMEOUT = 480 # 8 minutes
|
25
|
+
_CLUSTER_ID = '{cluster_name_on_cloud}-{unique_id}'
|
25
26
|
|
26
27
|
|
27
28
|
def get_azure_sdk_function(client: Any, function_name: str) -> Callable:
|
@@ -41,11 +42,25 @@ def get_azure_sdk_function(client: Any, function_name: str) -> Callable:
|
|
41
42
|
return func
|
42
43
|
|
43
44
|
|
45
|
+
def get_cluster_id_and_nsg_name(resource_group: str,
|
46
|
+
cluster_name_on_cloud: str) -> Tuple[str, str]:
|
47
|
+
hasher = hashlib.md5(resource_group.encode('utf-8'))
|
48
|
+
unique_id = hasher.hexdigest()[:UNIQUE_ID_LEN]
|
49
|
+
# We use the cluster name + resource group hash as the
|
50
|
+
# unique ID for the cluster, as we need to make sure that
|
51
|
+
# the deployments have unique names during failover.
|
52
|
+
cluster_id = _CLUSTER_ID.format(cluster_name_on_cloud=cluster_name_on_cloud,
|
53
|
+
unique_id=unique_id)
|
54
|
+
nsg_name = f'sky-{cluster_id}-nsg'
|
55
|
+
return cluster_id, nsg_name
|
56
|
+
|
57
|
+
|
44
58
|
@common.log_function_start_end
|
45
59
|
def bootstrap_instances(
|
46
60
|
region: str, cluster_name_on_cloud: str,
|
47
61
|
config: common.ProvisionConfig) -> common.ProvisionConfig:
|
48
62
|
"""See sky/provision/__init__.py"""
|
63
|
+
# TODO: use new azure sdk instead of ARM deployment.
|
49
64
|
del region # unused
|
50
65
|
provider_config = config.provider_config
|
51
66
|
subscription_id = provider_config.get('subscription_id')
|
@@ -116,12 +131,13 @@ def bootstrap_instances(
|
|
116
131
|
|
117
132
|
logger.info(f'Using cluster name: {cluster_name_on_cloud}')
|
118
133
|
|
119
|
-
|
120
|
-
|
134
|
+
cluster_id, nsg_name = get_cluster_id_and_nsg_name(
|
135
|
+
resource_group=provider_config['resource_group'],
|
136
|
+
cluster_name_on_cloud=cluster_name_on_cloud)
|
121
137
|
subnet_mask = provider_config.get('subnet_mask')
|
122
138
|
if subnet_mask is None:
|
123
139
|
# choose a random subnet, skipping most common value of 0
|
124
|
-
random.seed(
|
140
|
+
random.seed(cluster_id)
|
125
141
|
subnet_mask = f'10.{random.randint(1, 254)}.0.0/16'
|
126
142
|
logger.info(f'Using subnet mask: {subnet_mask}')
|
127
143
|
|
@@ -134,10 +150,10 @@ def bootstrap_instances(
|
|
134
150
|
'value': subnet_mask
|
135
151
|
},
|
136
152
|
'clusterId': {
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
'value':
|
153
|
+
'value': cluster_id
|
154
|
+
},
|
155
|
+
'nsgName': {
|
156
|
+
'value': nsg_name
|
141
157
|
},
|
142
158
|
},
|
143
159
|
}
|
sky/provision/azure/instance.py
CHANGED
@@ -2,10 +2,8 @@
|
|
2
2
|
import base64
|
3
3
|
import copy
|
4
4
|
import enum
|
5
|
-
import json
|
6
5
|
import logging
|
7
6
|
from multiprocessing import pool
|
8
|
-
import pathlib
|
9
7
|
import time
|
10
8
|
import typing
|
11
9
|
from typing import Any, Callable, Dict, List, Optional, Tuple
|
@@ -17,13 +15,16 @@ from sky import status_lib
|
|
17
15
|
from sky.adaptors import azure
|
18
16
|
from sky.provision import common
|
19
17
|
from sky.provision import constants
|
18
|
+
from sky.provision.azure import config as config_lib
|
20
19
|
from sky.utils import common_utils
|
21
20
|
from sky.utils import subprocess_utils
|
22
21
|
from sky.utils import ux_utils
|
23
22
|
|
24
23
|
if typing.TYPE_CHECKING:
|
25
24
|
from azure.mgmt import compute as azure_compute
|
26
|
-
from azure.mgmt import
|
25
|
+
from azure.mgmt import network as azure_network
|
26
|
+
from azure.mgmt.compute import models as azure_compute_models
|
27
|
+
from azure.mgmt.network import models as azure_network_models
|
27
28
|
|
28
29
|
logger = sky_logging.init_logger(__name__)
|
29
30
|
|
@@ -31,6 +32,8 @@ logger = sky_logging.init_logger(__name__)
|
|
31
32
|
# https://github.com/Azure/azure-sdk-for-python/issues/9422
|
32
33
|
azure_logger = logging.getLogger('azure')
|
33
34
|
azure_logger.setLevel(logging.WARNING)
|
35
|
+
Client = Any
|
36
|
+
NetworkSecurityGroup = Any
|
34
37
|
|
35
38
|
_RESUME_INSTANCE_TIMEOUT = 480 # 8 minutes
|
36
39
|
_RESUME_PER_INSTANCE_TIMEOUT = 120 # 2 minutes
|
@@ -40,6 +43,10 @@ _WAIT_CREATION_TIMEOUT_SECONDS = 600
|
|
40
43
|
|
41
44
|
_RESOURCE_GROUP_NOT_FOUND_ERROR_MESSAGE = 'ResourceGroupNotFound'
|
42
45
|
_POLL_INTERVAL = 1
|
46
|
+
# TODO(Doyoung): _LEGACY_NSG_NAME can be remove this after 0.8.0 to ignore
|
47
|
+
# legacy nsg names.
|
48
|
+
_LEGACY_NSG_NAME = 'ray-{cluster_name_on_cloud}-nsg'
|
49
|
+
_SECOND_LEGACY_NSG_NAME = 'sky-{cluster_name_on_cloud}-nsg'
|
43
50
|
|
44
51
|
|
45
52
|
class AzureInstanceStatus(enum.Enum):
|
@@ -184,14 +191,150 @@ def _get_head_instance_id(instances: List) -> Optional[str]:
|
|
184
191
|
return head_instance_id
|
185
192
|
|
186
193
|
|
187
|
-
def
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
|
194
|
+
def _create_network_interface(
|
195
|
+
network_client: 'azure_network.NetworkManagementClient', vm_name: str,
|
196
|
+
provider_config: Dict[str,
|
197
|
+
Any]) -> 'azure_network_models.NetworkInterface':
|
198
|
+
network = azure.azure_mgmt_models('network')
|
199
|
+
compute = azure.azure_mgmt_models('compute')
|
200
|
+
logger.info(f'Start creating network interface for {vm_name}...')
|
201
|
+
if provider_config.get('use_internal_ips', False):
|
202
|
+
name = f'{vm_name}-nic-private'
|
203
|
+
ip_config = network.IPConfiguration(
|
204
|
+
name=f'ip-config-private-{vm_name}',
|
205
|
+
subnet=compute.SubResource(id=provider_config['subnet']),
|
206
|
+
private_ip_allocation_method=network.IPAllocationMethod.DYNAMIC)
|
207
|
+
else:
|
208
|
+
name = f'{vm_name}-nic-public'
|
209
|
+
public_ip_address = network.PublicIPAddress(
|
210
|
+
location=provider_config['location'],
|
211
|
+
public_ip_allocation_method='Static',
|
212
|
+
public_ip_address_version='IPv4',
|
213
|
+
sku=network.PublicIPAddressSku(name='Basic', tier='Regional'))
|
214
|
+
ip_poller = network_client.public_ip_addresses.begin_create_or_update(
|
215
|
+
resource_group_name=provider_config['resource_group'],
|
216
|
+
public_ip_address_name=f'{vm_name}-ip',
|
217
|
+
parameters=public_ip_address)
|
218
|
+
logger.info(f'Created public IP address {ip_poller.result().name} '
|
219
|
+
f'with address {ip_poller.result().ip_address}.')
|
220
|
+
ip_config = network.IPConfiguration(
|
221
|
+
name=f'ip-config-public-{vm_name}',
|
222
|
+
subnet=compute.SubResource(id=provider_config['subnet']),
|
223
|
+
private_ip_allocation_method=network.IPAllocationMethod.DYNAMIC,
|
224
|
+
public_ip_address=network.PublicIPAddress(id=ip_poller.result().id))
|
225
|
+
|
226
|
+
ni_poller = network_client.network_interfaces.begin_create_or_update(
|
227
|
+
resource_group_name=provider_config['resource_group'],
|
228
|
+
network_interface_name=name,
|
229
|
+
parameters=network.NetworkInterface(
|
230
|
+
location=provider_config['location'],
|
231
|
+
ip_configurations=[ip_config],
|
232
|
+
network_security_group=network.NetworkSecurityGroup(
|
233
|
+
id=provider_config['nsg'])))
|
234
|
+
logger.info(f'Created network interface {ni_poller.result().name}.')
|
235
|
+
return ni_poller.result()
|
236
|
+
|
237
|
+
|
238
|
+
def _create_vm(
|
239
|
+
compute_client: 'azure_compute.ComputeManagementClient', vm_name: str,
|
240
|
+
node_tags: Dict[str, str], provider_config: Dict[str, Any],
|
241
|
+
node_config: Dict[str, Any],
|
242
|
+
network_interface_id: str) -> 'azure_compute_models.VirtualMachine':
|
243
|
+
compute = azure.azure_mgmt_models('compute')
|
244
|
+
logger.info(f'Start creating VM {vm_name}...')
|
245
|
+
hardware_profile = compute.HardwareProfile(
|
246
|
+
vm_size=node_config['azure_arm_parameters']['vmSize'])
|
247
|
+
network_profile = compute.NetworkProfile(network_interfaces=[
|
248
|
+
compute.NetworkInterfaceReference(id=network_interface_id, primary=True)
|
249
|
+
])
|
250
|
+
public_key = node_config['azure_arm_parameters']['publicKey']
|
251
|
+
username = node_config['azure_arm_parameters']['adminUsername']
|
252
|
+
os_linux_custom_data = base64.b64encode(
|
253
|
+
node_config['azure_arm_parameters']['cloudInitSetupCommands'].encode(
|
254
|
+
'utf-8')).decode('utf-8')
|
255
|
+
os_profile = compute.OSProfile(
|
256
|
+
admin_username=username,
|
257
|
+
computer_name=vm_name,
|
258
|
+
admin_password=public_key,
|
259
|
+
linux_configuration=compute.LinuxConfiguration(
|
260
|
+
disable_password_authentication=True,
|
261
|
+
ssh=compute.SshConfiguration(public_keys=[
|
262
|
+
compute.SshPublicKey(
|
263
|
+
path=f'/home/{username}/.ssh/authorized_keys',
|
264
|
+
key_data=public_key)
|
265
|
+
])),
|
266
|
+
custom_data=os_linux_custom_data)
|
267
|
+
community_image_id = node_config['azure_arm_parameters'].get(
|
268
|
+
'communityGalleryImageId', None)
|
269
|
+
if community_image_id is not None:
|
270
|
+
# Prioritize using community gallery image if specified.
|
271
|
+
image_reference = compute.ImageReference(
|
272
|
+
community_gallery_image_id=community_image_id)
|
273
|
+
logger.info(
|
274
|
+
f'Used community_image_id: {community_image_id} for VM {vm_name}.')
|
275
|
+
else:
|
276
|
+
image_reference = compute.ImageReference(
|
277
|
+
publisher=node_config['azure_arm_parameters']['imagePublisher'],
|
278
|
+
offer=node_config['azure_arm_parameters']['imageOffer'],
|
279
|
+
sku=node_config['azure_arm_parameters']['imageSku'],
|
280
|
+
version=node_config['azure_arm_parameters']['imageVersion'])
|
281
|
+
storage_profile = compute.StorageProfile(
|
282
|
+
image_reference=image_reference,
|
283
|
+
os_disk=compute.OSDisk(
|
284
|
+
create_option=compute.DiskCreateOptionTypes.FROM_IMAGE,
|
285
|
+
managed_disk=compute.ManagedDiskParameters(
|
286
|
+
storage_account_type=node_config['azure_arm_parameters']
|
287
|
+
['osDiskTier']),
|
288
|
+
disk_size_gb=node_config['azure_arm_parameters']['osDiskSizeGB']))
|
289
|
+
vm_instance = compute.VirtualMachine(
|
290
|
+
location=provider_config['location'],
|
291
|
+
tags=node_tags,
|
292
|
+
hardware_profile=hardware_profile,
|
293
|
+
os_profile=os_profile,
|
294
|
+
storage_profile=storage_profile,
|
295
|
+
network_profile=network_profile,
|
296
|
+
identity=compute.VirtualMachineIdentity(
|
297
|
+
type='UserAssigned',
|
298
|
+
user_assigned_identities={provider_config['msi']: {}}))
|
299
|
+
vm_poller = compute_client.virtual_machines.begin_create_or_update(
|
300
|
+
resource_group_name=provider_config['resource_group'],
|
301
|
+
vm_name=vm_name,
|
302
|
+
parameters=vm_instance,
|
303
|
+
)
|
304
|
+
# poller.result() will block on async operation until it's done.
|
305
|
+
logger.info(f'Created VM {vm_poller.result().name}.')
|
306
|
+
# Configure driver extension for A10 GPUs. A10 GPUs requires a
|
307
|
+
# special type of drivers which is available at Microsoft HPC
|
308
|
+
# extension. Reference:
|
309
|
+
# https://forums.developer.nvidia.com/t/ubuntu-22-04-installation-driver-error-nvidia-a10/285195/2
|
310
|
+
# This can take more than 20mins for setting up the A10 GPUs
|
311
|
+
if node_config.get('need_nvidia_driver_extension', False):
|
312
|
+
ext_poller = compute_client.virtual_machine_extensions.\
|
313
|
+
begin_create_or_update(
|
314
|
+
resource_group_name=provider_config['resource_group'],
|
315
|
+
vm_name=vm_name,
|
316
|
+
vm_extension_name='NvidiaGpuDriverLinux',
|
317
|
+
extension_parameters=compute.VirtualMachineExtension(
|
318
|
+
location=provider_config['location'],
|
319
|
+
publisher='Microsoft.HpcCompute',
|
320
|
+
type_properties_type='NvidiaGpuDriverLinux',
|
321
|
+
type_handler_version='1.9',
|
322
|
+
auto_upgrade_minor_version=True,
|
323
|
+
settings='{}'))
|
324
|
+
logger.info(
|
325
|
+
f'Created VM extension {ext_poller.result().name} for VM {vm_name}.'
|
326
|
+
)
|
327
|
+
return vm_poller.result()
|
328
|
+
|
329
|
+
|
330
|
+
def _create_instances(compute_client: 'azure_compute.ComputeManagementClient',
|
331
|
+
network_client: 'azure_network.NetworkManagementClient',
|
332
|
+
cluster_name_on_cloud: str, resource_group: str,
|
333
|
+
provider_config: Dict[str, Any], node_config: Dict[str,
|
334
|
+
Any],
|
335
|
+
tags: Dict[str, str], count: int) -> List:
|
193
336
|
vm_id = uuid4().hex[:UNIQUE_ID_LEN]
|
194
|
-
|
337
|
+
all_tags = {
|
195
338
|
constants.TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud,
|
196
339
|
constants.TAG_SKYPILOT_CLUSTER_NAME: cluster_name_on_cloud,
|
197
340
|
**constants.WORKER_NODE_TAGS,
|
@@ -199,83 +342,19 @@ def _create_instances(
|
|
199
342
|
**tags,
|
200
343
|
}
|
201
344
|
node_tags = node_config['tags'].copy()
|
202
|
-
node_tags.update(
|
345
|
+
node_tags.update(all_tags)
|
203
346
|
|
204
|
-
#
|
205
|
-
|
206
|
-
|
207
|
-
|
208
|
-
|
347
|
+
# Create VM instances in parallel.
|
348
|
+
def create_single_instance(vm_i):
|
349
|
+
vm_name = f'{cluster_name_on_cloud}-{vm_id}-{vm_i}'
|
350
|
+
network_interface = _create_network_interface(network_client, vm_name,
|
351
|
+
provider_config)
|
352
|
+
_create_vm(compute_client, vm_name, node_tags, provider_config,
|
353
|
+
node_config, network_interface.id)
|
209
354
|
|
210
|
-
|
211
|
-
use_internal_ips = provider_config.get('use_internal_ips', False)
|
212
|
-
|
213
|
-
template_params = node_config['azure_arm_parameters'].copy()
|
214
|
-
# We don't include 'head' or 'worker' in the VM name as on Azure the VM
|
215
|
-
# name is immutable and we may change the node type for existing VM in the
|
216
|
-
# multi-node cluster, due to manual termination of the head node.
|
217
|
-
template_params['vmName'] = vm_name
|
218
|
-
template_params['provisionPublicIp'] = not use_internal_ips
|
219
|
-
template_params['vmTags'] = node_tags
|
220
|
-
template_params['vmCount'] = count
|
221
|
-
template_params['msi'] = provider_config['msi']
|
222
|
-
template_params['nsg'] = provider_config['nsg']
|
223
|
-
template_params['subnet'] = provider_config['subnet']
|
224
|
-
# In Azure, cloud-init script must be encoded in base64. For more
|
225
|
-
# information, see:
|
226
|
-
# https://learn.microsoft.com/en-us/azure/virtual-machines/custom-data
|
227
|
-
template_params['cloudInitSetupCommands'] = (base64.b64encode(
|
228
|
-
template_params['cloudInitSetupCommands'].encode('utf-8')).decode(
|
229
|
-
'utf-8'))
|
230
|
-
|
231
|
-
if node_config.get('need_nvidia_driver_extension', False):
|
232
|
-
# pylint: disable=line-too-long
|
233
|
-
# Configure driver extension for A10 GPUs. A10 GPUs requires a
|
234
|
-
# special type of drivers which is available at Microsoft HPC
|
235
|
-
# extension. Reference: https://forums.developer.nvidia.com/t/ubuntu-22-04-installation-driver-error-nvidia-a10/285195/2
|
236
|
-
for r in template['resources']:
|
237
|
-
if r['type'] == 'Microsoft.Compute/virtualMachines':
|
238
|
-
# Add a nested extension resource for A10 GPUs
|
239
|
-
r['resources'] = [
|
240
|
-
{
|
241
|
-
'type': 'extensions',
|
242
|
-
'apiVersion': '2015-06-15',
|
243
|
-
'location': '[variables(\'location\')]',
|
244
|
-
'dependsOn': [
|
245
|
-
'[concat(\'Microsoft.Compute/virtualMachines/\', parameters(\'vmName\'), copyIndex())]'
|
246
|
-
],
|
247
|
-
'name': 'NvidiaGpuDriverLinux',
|
248
|
-
'properties': {
|
249
|
-
'publisher': 'Microsoft.HpcCompute',
|
250
|
-
'type': 'NvidiaGpuDriverLinux',
|
251
|
-
'typeHandlerVersion': '1.9',
|
252
|
-
'autoUpgradeMinorVersion': True,
|
253
|
-
'settings': {},
|
254
|
-
},
|
255
|
-
},
|
256
|
-
]
|
257
|
-
break
|
258
|
-
|
259
|
-
parameters = {
|
260
|
-
'properties': {
|
261
|
-
'mode': azure.deployment_mode().incremental,
|
262
|
-
'template': template,
|
263
|
-
'parameters': {
|
264
|
-
key: {
|
265
|
-
'value': value
|
266
|
-
} for key, value in template_params.items()
|
267
|
-
},
|
268
|
-
}
|
269
|
-
}
|
270
|
-
|
271
|
-
create_or_update = _get_azure_sdk_function(
|
272
|
-
client=resource_client.deployments, function_name='create_or_update')
|
273
|
-
create_or_update(
|
274
|
-
resource_group_name=resource_group,
|
275
|
-
deployment_name=vm_name,
|
276
|
-
parameters=parameters,
|
277
|
-
).wait()
|
355
|
+
subprocess_utils.run_in_parallel(create_single_instance, range(count))
|
278
356
|
|
357
|
+
# Update disk performance tier
|
279
358
|
performance_tier = node_config.get('disk_performance_tier', None)
|
280
359
|
if performance_tier is not None:
|
281
360
|
disks = compute_client.disks.list_by_resource_group(resource_group)
|
@@ -286,12 +365,14 @@ def _create_instances(
|
|
286
365
|
f'az disk update -n {name} -g {resource_group} '
|
287
366
|
f'--set tier={performance_tier}')
|
288
367
|
|
368
|
+
# Validation
|
289
369
|
filters = {
|
290
370
|
constants.TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud,
|
291
371
|
_TAG_SKYPILOT_VM_ID: vm_id
|
292
372
|
}
|
293
373
|
instances = _filter_instances(compute_client, resource_group, filters)
|
294
374
|
assert len(instances) == count, (len(instances), count)
|
375
|
+
|
295
376
|
return instances
|
296
377
|
|
297
378
|
|
@@ -303,7 +384,7 @@ def run_instances(region: str, cluster_name_on_cloud: str,
|
|
303
384
|
resource_group = provider_config['resource_group']
|
304
385
|
subscription_id = provider_config['subscription_id']
|
305
386
|
compute_client = azure.get_client('compute', subscription_id)
|
306
|
-
|
387
|
+
network_client = azure.get_client('network', subscription_id)
|
307
388
|
instances_to_resume = []
|
308
389
|
resumed_instance_ids: List[str] = []
|
309
390
|
created_instance_ids: List[str] = []
|
@@ -439,12 +520,11 @@ def run_instances(region: str, cluster_name_on_cloud: str,
|
|
439
520
|
to_start_count -= len(resumed_instance_ids)
|
440
521
|
|
441
522
|
if to_start_count > 0:
|
442
|
-
resource_client = azure.get_client('resource', subscription_id)
|
443
523
|
logger.debug(f'run_instances: Creating {to_start_count} instances.')
|
444
524
|
try:
|
445
525
|
created_instances = _create_instances(
|
446
526
|
compute_client=compute_client,
|
447
|
-
|
527
|
+
network_client=network_client,
|
448
528
|
cluster_name_on_cloud=cluster_name_on_cloud,
|
449
529
|
resource_group=resource_group,
|
450
530
|
provider_config=provider_config,
|
@@ -722,6 +802,32 @@ def query_instances(
|
|
722
802
|
return statuses
|
723
803
|
|
724
804
|
|
805
|
+
# TODO(Doyoung): _get_cluster_nsg can be remove this after 0.8.0 to ignore
|
806
|
+
# legacy nsg names.
|
807
|
+
def _get_cluster_nsg(network_client: Client, resource_group: str,
|
808
|
+
cluster_name_on_cloud: str) -> NetworkSecurityGroup:
|
809
|
+
"""Retrieve the NSG associated with the given name of the cluster."""
|
810
|
+
list_network_security_groups = _get_azure_sdk_function(
|
811
|
+
client=network_client.network_security_groups, function_name='list')
|
812
|
+
legacy_nsg_name = _LEGACY_NSG_NAME.format(
|
813
|
+
cluster_name_on_cloud=cluster_name_on_cloud)
|
814
|
+
second_legacy_nsg_name = _SECOND_LEGACY_NSG_NAME.format(
|
815
|
+
cluster_name_on_cloud=cluster_name_on_cloud)
|
816
|
+
_, nsg_name = config_lib.get_cluster_id_and_nsg_name(
|
817
|
+
resource_group=resource_group,
|
818
|
+
cluster_name_on_cloud=cluster_name_on_cloud)
|
819
|
+
possible_nsg_names = [nsg_name, legacy_nsg_name, second_legacy_nsg_name]
|
820
|
+
for nsg in list_network_security_groups(resource_group):
|
821
|
+
if nsg.name in possible_nsg_names:
|
822
|
+
return nsg
|
823
|
+
|
824
|
+
# Raise an error if no matching NSG is found
|
825
|
+
raise ValueError('Failed to find a matching NSG for cluster '
|
826
|
+
f'{cluster_name_on_cloud!r} in resource group '
|
827
|
+
f'{resource_group!r}. Expected NSG names were: '
|
828
|
+
f'{possible_nsg_names}.')
|
829
|
+
|
830
|
+
|
725
831
|
def open_ports(
|
726
832
|
cluster_name_on_cloud: str,
|
727
833
|
ports: List[str],
|
@@ -736,58 +842,66 @@ def open_ports(
|
|
736
842
|
update_network_security_groups = _get_azure_sdk_function(
|
737
843
|
client=network_client.network_security_groups,
|
738
844
|
function_name='create_or_update')
|
739
|
-
|
740
|
-
|
741
|
-
|
742
|
-
|
743
|
-
|
744
|
-
|
745
|
-
|
746
|
-
|
747
|
-
|
748
|
-
|
749
|
-
|
750
|
-
|
751
|
-
|
752
|
-
logger.warning(
|
753
|
-
f'Fails to wait for the creation of NSG {nsg.name} in '
|
754
|
-
f'{resource_group} within '
|
755
|
-
f'{_WAIT_CREATION_TIMEOUT_SECONDS} seconds. '
|
756
|
-
'Skip this NSG.')
|
757
|
-
backoff_time = backoff.current_backoff()
|
758
|
-
logger.info(f'NSG {nsg.name} is not created yet. Waiting for '
|
759
|
-
f'{backoff_time} seconds before checking again.')
|
760
|
-
time.sleep(backoff_time)
|
761
|
-
|
762
|
-
# Azure NSG rules have a priority field that determines the order
|
763
|
-
# in which they are applied. The priority must be unique across
|
764
|
-
# all inbound rules in one NSG.
|
765
|
-
priority = max(rule.priority
|
766
|
-
for rule in nsg.security_rules
|
767
|
-
if rule.direction == 'Inbound') + 1
|
768
|
-
nsg.security_rules.append(
|
769
|
-
azure.create_security_rule(
|
770
|
-
name=f'sky-ports-{cluster_name_on_cloud}-{priority}',
|
771
|
-
priority=priority,
|
772
|
-
protocol='Tcp',
|
773
|
-
access='Allow',
|
774
|
-
direction='Inbound',
|
775
|
-
source_address_prefix='*',
|
776
|
-
source_port_range='*',
|
777
|
-
destination_address_prefix='*',
|
778
|
-
destination_port_ranges=ports,
|
779
|
-
))
|
780
|
-
poller = update_network_security_groups(resource_group, nsg.name,
|
781
|
-
nsg)
|
782
|
-
poller.wait()
|
783
|
-
if poller.status() != 'Succeeded':
|
845
|
+
|
846
|
+
try:
|
847
|
+
# Wait for the NSG creation to be finished before opening a port. The
|
848
|
+
# cluster provisioning triggers the NSG creation, but it may not be
|
849
|
+
# finished yet.
|
850
|
+
backoff = common_utils.Backoff(max_backoff_factor=1)
|
851
|
+
start_time = time.time()
|
852
|
+
while True:
|
853
|
+
nsg = _get_cluster_nsg(network_client, resource_group,
|
854
|
+
cluster_name_on_cloud)
|
855
|
+
if nsg.provisioning_state not in ['Creating', 'Updating']:
|
856
|
+
break
|
857
|
+
if time.time() - start_time > _WAIT_CREATION_TIMEOUT_SECONDS:
|
784
858
|
with ux_utils.print_exception_no_traceback():
|
785
|
-
raise
|
786
|
-
|
787
|
-
|
859
|
+
raise TimeoutError(
|
860
|
+
f'Timed out while waiting for the Network '
|
861
|
+
f'Security Group {nsg.name!r} to be ready for '
|
862
|
+
f'cluster {cluster_name_on_cloud!r} in '
|
863
|
+
f'resource group {resource_group!r}. The NSG '
|
864
|
+
f'did not reach a stable state '
|
865
|
+
'(Creating/Updating) within the allocated '
|
866
|
+
f'{_WAIT_CREATION_TIMEOUT_SECONDS} seconds. '
|
867
|
+
'Consequently, the operation to open ports '
|
868
|
+
f'{ports} failed.')
|
869
|
+
|
870
|
+
backoff_time = backoff.current_backoff()
|
871
|
+
logger.info(f'NSG {nsg.name} is not created yet. Waiting for '
|
872
|
+
f'{backoff_time} seconds before checking again.')
|
873
|
+
time.sleep(backoff_time)
|
874
|
+
|
875
|
+
# Azure NSG rules have a priority field that determines the order
|
876
|
+
# in which they are applied. The priority must be unique across
|
877
|
+
# all inbound rules in one NSG.
|
878
|
+
priority = max(rule.priority
|
879
|
+
for rule in nsg.security_rules
|
880
|
+
if rule.direction == 'Inbound') + 1
|
881
|
+
nsg.security_rules.append(
|
882
|
+
azure.create_security_rule(
|
883
|
+
name=f'sky-ports-{cluster_name_on_cloud}-{priority}',
|
884
|
+
priority=priority,
|
885
|
+
protocol='Tcp',
|
886
|
+
access='Allow',
|
887
|
+
direction='Inbound',
|
888
|
+
source_address_prefix='*',
|
889
|
+
source_port_range='*',
|
890
|
+
destination_address_prefix='*',
|
891
|
+
destination_port_ranges=ports,
|
892
|
+
))
|
893
|
+
poller = update_network_security_groups(resource_group, nsg.name, nsg)
|
894
|
+
poller.wait()
|
895
|
+
if poller.status() != 'Succeeded':
|
788
896
|
with ux_utils.print_exception_no_traceback():
|
789
|
-
raise ValueError(
|
790
|
-
|
897
|
+
raise ValueError(f'Failed to open ports {ports} in NSG '
|
898
|
+
f'{nsg.name}: {poller.status()}')
|
899
|
+
|
900
|
+
except azure.exceptions().HttpResponseError as e:
|
901
|
+
with ux_utils.print_exception_no_traceback():
|
902
|
+
raise ValueError(f'Failed to open ports {ports} in NSG for cluster '
|
903
|
+
f'{cluster_name_on_cloud!r} within resource group '
|
904
|
+
f'{resource_group!r}.') from e
|
791
905
|
|
792
906
|
|
793
907
|
def cleanup_ports(
|
@@ -18,6 +18,7 @@ from sky.provision.kubernetes import utils as kubernetes_utils
|
|
18
18
|
from sky.utils import command_runner
|
19
19
|
from sky.utils import common_utils
|
20
20
|
from sky.utils import kubernetes_enums
|
21
|
+
from sky.utils import subprocess_utils
|
21
22
|
from sky.utils import ux_utils
|
22
23
|
|
23
24
|
POLL_INTERVAL = 2
|
@@ -398,8 +399,7 @@ def _setup_ssh_in_pods(namespace: str, context: Optional[str],
|
|
398
399
|
# See https://www.educative.io/answers/error-mesg-ttyname-failed-inappropriate-ioctl-for-device # pylint: disable=line-too-long
|
399
400
|
'$(prefix_cmd) sed -i "s/mesg n/tty -s \\&\\& mesg n/" ~/.profile;')
|
400
401
|
|
401
|
-
|
402
|
-
for new_node in new_nodes:
|
402
|
+
def _setup_ssh_thread(new_node):
|
403
403
|
pod_name = new_node.metadata.name
|
404
404
|
runner = command_runner.KubernetesCommandRunner(
|
405
405
|
((namespace, context), pod_name))
|
@@ -411,6 +411,8 @@ def _setup_ssh_in_pods(namespace: str, context: Optional[str],
|
|
411
411
|
stdout)
|
412
412
|
logger.info(f'{"-"*20}End: Set up SSH in pod {pod_name!r} {"-"*20}')
|
413
413
|
|
414
|
+
subprocess_utils.run_in_parallel(_setup_ssh_thread, new_nodes)
|
415
|
+
|
414
416
|
|
415
417
|
def _label_pod(namespace: str, context: Optional[str], pod_name: str,
|
416
418
|
label: Dict[str, str]) -> None:
|