skypilot-nightly 1.0.0.dev2024053101__py3-none-any.whl → 1.0.0.dev2025022801__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (299) hide show
  1. sky/__init__.py +64 -32
  2. sky/adaptors/aws.py +23 -6
  3. sky/adaptors/azure.py +432 -15
  4. sky/adaptors/cloudflare.py +5 -5
  5. sky/adaptors/common.py +19 -9
  6. sky/adaptors/do.py +20 -0
  7. sky/adaptors/gcp.py +3 -2
  8. sky/adaptors/kubernetes.py +122 -88
  9. sky/adaptors/nebius.py +100 -0
  10. sky/adaptors/oci.py +39 -1
  11. sky/adaptors/vast.py +29 -0
  12. sky/admin_policy.py +101 -0
  13. sky/authentication.py +117 -98
  14. sky/backends/backend.py +52 -20
  15. sky/backends/backend_utils.py +669 -557
  16. sky/backends/cloud_vm_ray_backend.py +1099 -808
  17. sky/backends/local_docker_backend.py +14 -8
  18. sky/backends/wheel_utils.py +38 -20
  19. sky/benchmark/benchmark_utils.py +22 -23
  20. sky/check.py +76 -27
  21. sky/cli.py +1586 -1139
  22. sky/client/__init__.py +1 -0
  23. sky/client/cli.py +5683 -0
  24. sky/client/common.py +345 -0
  25. sky/client/sdk.py +1765 -0
  26. sky/cloud_stores.py +283 -19
  27. sky/clouds/__init__.py +7 -2
  28. sky/clouds/aws.py +303 -112
  29. sky/clouds/azure.py +185 -179
  30. sky/clouds/cloud.py +115 -37
  31. sky/clouds/cudo.py +29 -22
  32. sky/clouds/do.py +313 -0
  33. sky/clouds/fluidstack.py +44 -54
  34. sky/clouds/gcp.py +206 -65
  35. sky/clouds/ibm.py +26 -21
  36. sky/clouds/kubernetes.py +345 -91
  37. sky/clouds/lambda_cloud.py +40 -29
  38. sky/clouds/nebius.py +297 -0
  39. sky/clouds/oci.py +129 -90
  40. sky/clouds/paperspace.py +22 -18
  41. sky/clouds/runpod.py +53 -34
  42. sky/clouds/scp.py +28 -24
  43. sky/clouds/service_catalog/__init__.py +19 -13
  44. sky/clouds/service_catalog/aws_catalog.py +29 -12
  45. sky/clouds/service_catalog/azure_catalog.py +33 -6
  46. sky/clouds/service_catalog/common.py +95 -75
  47. sky/clouds/service_catalog/constants.py +3 -3
  48. sky/clouds/service_catalog/cudo_catalog.py +13 -3
  49. sky/clouds/service_catalog/data_fetchers/fetch_aws.py +36 -21
  50. sky/clouds/service_catalog/data_fetchers/fetch_azure.py +31 -4
  51. sky/clouds/service_catalog/data_fetchers/fetch_cudo.py +8 -117
  52. sky/clouds/service_catalog/data_fetchers/fetch_fluidstack.py +197 -44
  53. sky/clouds/service_catalog/data_fetchers/fetch_gcp.py +224 -36
  54. sky/clouds/service_catalog/data_fetchers/fetch_lambda_cloud.py +44 -24
  55. sky/clouds/service_catalog/data_fetchers/fetch_vast.py +147 -0
  56. sky/clouds/service_catalog/data_fetchers/fetch_vsphere.py +1 -1
  57. sky/clouds/service_catalog/do_catalog.py +111 -0
  58. sky/clouds/service_catalog/fluidstack_catalog.py +2 -2
  59. sky/clouds/service_catalog/gcp_catalog.py +16 -2
  60. sky/clouds/service_catalog/ibm_catalog.py +2 -2
  61. sky/clouds/service_catalog/kubernetes_catalog.py +192 -70
  62. sky/clouds/service_catalog/lambda_catalog.py +8 -3
  63. sky/clouds/service_catalog/nebius_catalog.py +116 -0
  64. sky/clouds/service_catalog/oci_catalog.py +31 -4
  65. sky/clouds/service_catalog/paperspace_catalog.py +2 -2
  66. sky/clouds/service_catalog/runpod_catalog.py +2 -2
  67. sky/clouds/service_catalog/scp_catalog.py +2 -2
  68. sky/clouds/service_catalog/vast_catalog.py +104 -0
  69. sky/clouds/service_catalog/vsphere_catalog.py +2 -2
  70. sky/clouds/utils/aws_utils.py +65 -0
  71. sky/clouds/utils/azure_utils.py +91 -0
  72. sky/clouds/utils/gcp_utils.py +5 -9
  73. sky/clouds/utils/oci_utils.py +47 -5
  74. sky/clouds/utils/scp_utils.py +4 -3
  75. sky/clouds/vast.py +280 -0
  76. sky/clouds/vsphere.py +22 -18
  77. sky/core.py +361 -107
  78. sky/dag.py +41 -28
  79. sky/data/data_transfer.py +37 -0
  80. sky/data/data_utils.py +211 -32
  81. sky/data/mounting_utils.py +182 -30
  82. sky/data/storage.py +2118 -270
  83. sky/data/storage_utils.py +126 -5
  84. sky/exceptions.py +179 -8
  85. sky/execution.py +158 -85
  86. sky/global_user_state.py +150 -34
  87. sky/jobs/__init__.py +12 -10
  88. sky/jobs/client/__init__.py +0 -0
  89. sky/jobs/client/sdk.py +302 -0
  90. sky/jobs/constants.py +49 -11
  91. sky/jobs/controller.py +161 -99
  92. sky/jobs/dashboard/dashboard.py +171 -25
  93. sky/jobs/dashboard/templates/index.html +572 -60
  94. sky/jobs/recovery_strategy.py +157 -156
  95. sky/jobs/scheduler.py +307 -0
  96. sky/jobs/server/__init__.py +1 -0
  97. sky/jobs/server/core.py +598 -0
  98. sky/jobs/server/dashboard_utils.py +69 -0
  99. sky/jobs/server/server.py +190 -0
  100. sky/jobs/state.py +627 -122
  101. sky/jobs/utils.py +615 -206
  102. sky/models.py +27 -0
  103. sky/optimizer.py +142 -83
  104. sky/provision/__init__.py +20 -5
  105. sky/provision/aws/config.py +124 -42
  106. sky/provision/aws/instance.py +130 -53
  107. sky/provision/azure/__init__.py +7 -0
  108. sky/{skylet/providers → provision}/azure/azure-config-template.json +19 -7
  109. sky/provision/azure/config.py +220 -0
  110. sky/provision/azure/instance.py +1012 -37
  111. sky/provision/common.py +31 -3
  112. sky/provision/constants.py +25 -0
  113. sky/provision/cudo/__init__.py +2 -1
  114. sky/provision/cudo/cudo_utils.py +112 -0
  115. sky/provision/cudo/cudo_wrapper.py +37 -16
  116. sky/provision/cudo/instance.py +28 -12
  117. sky/provision/do/__init__.py +11 -0
  118. sky/provision/do/config.py +14 -0
  119. sky/provision/do/constants.py +10 -0
  120. sky/provision/do/instance.py +287 -0
  121. sky/provision/do/utils.py +301 -0
  122. sky/provision/docker_utils.py +82 -46
  123. sky/provision/fluidstack/fluidstack_utils.py +57 -125
  124. sky/provision/fluidstack/instance.py +15 -43
  125. sky/provision/gcp/config.py +19 -9
  126. sky/provision/gcp/constants.py +7 -1
  127. sky/provision/gcp/instance.py +55 -34
  128. sky/provision/gcp/instance_utils.py +339 -80
  129. sky/provision/gcp/mig_utils.py +210 -0
  130. sky/provision/instance_setup.py +172 -133
  131. sky/provision/kubernetes/__init__.py +1 -0
  132. sky/provision/kubernetes/config.py +104 -90
  133. sky/provision/kubernetes/constants.py +8 -0
  134. sky/provision/kubernetes/instance.py +680 -325
  135. sky/provision/kubernetes/manifests/smarter-device-manager-daemonset.yaml +3 -0
  136. sky/provision/kubernetes/network.py +54 -20
  137. sky/provision/kubernetes/network_utils.py +70 -21
  138. sky/provision/kubernetes/utils.py +1370 -251
  139. sky/provision/lambda_cloud/__init__.py +11 -0
  140. sky/provision/lambda_cloud/config.py +10 -0
  141. sky/provision/lambda_cloud/instance.py +265 -0
  142. sky/{clouds/utils → provision/lambda_cloud}/lambda_utils.py +24 -23
  143. sky/provision/logging.py +1 -1
  144. sky/provision/nebius/__init__.py +11 -0
  145. sky/provision/nebius/config.py +11 -0
  146. sky/provision/nebius/instance.py +285 -0
  147. sky/provision/nebius/utils.py +318 -0
  148. sky/provision/oci/__init__.py +15 -0
  149. sky/provision/oci/config.py +51 -0
  150. sky/provision/oci/instance.py +436 -0
  151. sky/provision/oci/query_utils.py +681 -0
  152. sky/provision/paperspace/constants.py +6 -0
  153. sky/provision/paperspace/instance.py +4 -3
  154. sky/provision/paperspace/utils.py +2 -0
  155. sky/provision/provisioner.py +207 -130
  156. sky/provision/runpod/__init__.py +1 -0
  157. sky/provision/runpod/api/__init__.py +3 -0
  158. sky/provision/runpod/api/commands.py +119 -0
  159. sky/provision/runpod/api/pods.py +142 -0
  160. sky/provision/runpod/instance.py +64 -8
  161. sky/provision/runpod/utils.py +239 -23
  162. sky/provision/vast/__init__.py +10 -0
  163. sky/provision/vast/config.py +11 -0
  164. sky/provision/vast/instance.py +247 -0
  165. sky/provision/vast/utils.py +162 -0
  166. sky/provision/vsphere/common/vim_utils.py +1 -1
  167. sky/provision/vsphere/instance.py +8 -18
  168. sky/provision/vsphere/vsphere_utils.py +1 -1
  169. sky/resources.py +247 -102
  170. sky/serve/__init__.py +9 -9
  171. sky/serve/autoscalers.py +361 -299
  172. sky/serve/client/__init__.py +0 -0
  173. sky/serve/client/sdk.py +366 -0
  174. sky/serve/constants.py +12 -3
  175. sky/serve/controller.py +106 -36
  176. sky/serve/load_balancer.py +63 -12
  177. sky/serve/load_balancing_policies.py +84 -2
  178. sky/serve/replica_managers.py +42 -34
  179. sky/serve/serve_state.py +62 -32
  180. sky/serve/serve_utils.py +271 -160
  181. sky/serve/server/__init__.py +0 -0
  182. sky/serve/{core.py → server/core.py} +271 -90
  183. sky/serve/server/server.py +112 -0
  184. sky/serve/service.py +52 -16
  185. sky/serve/service_spec.py +95 -32
  186. sky/server/__init__.py +1 -0
  187. sky/server/common.py +430 -0
  188. sky/server/constants.py +21 -0
  189. sky/server/html/log.html +174 -0
  190. sky/server/requests/__init__.py +0 -0
  191. sky/server/requests/executor.py +472 -0
  192. sky/server/requests/payloads.py +487 -0
  193. sky/server/requests/queues/__init__.py +0 -0
  194. sky/server/requests/queues/mp_queue.py +76 -0
  195. sky/server/requests/requests.py +567 -0
  196. sky/server/requests/serializers/__init__.py +0 -0
  197. sky/server/requests/serializers/decoders.py +192 -0
  198. sky/server/requests/serializers/encoders.py +166 -0
  199. sky/server/server.py +1106 -0
  200. sky/server/stream_utils.py +141 -0
  201. sky/setup_files/MANIFEST.in +2 -5
  202. sky/setup_files/dependencies.py +159 -0
  203. sky/setup_files/setup.py +14 -125
  204. sky/sky_logging.py +59 -14
  205. sky/skylet/autostop_lib.py +2 -2
  206. sky/skylet/constants.py +183 -50
  207. sky/skylet/events.py +22 -10
  208. sky/skylet/job_lib.py +403 -258
  209. sky/skylet/log_lib.py +111 -71
  210. sky/skylet/log_lib.pyi +6 -0
  211. sky/skylet/providers/command_runner.py +6 -8
  212. sky/skylet/providers/ibm/node_provider.py +2 -2
  213. sky/skylet/providers/scp/config.py +11 -3
  214. sky/skylet/providers/scp/node_provider.py +8 -8
  215. sky/skylet/skylet.py +3 -1
  216. sky/skylet/subprocess_daemon.py +69 -17
  217. sky/skypilot_config.py +119 -57
  218. sky/task.py +205 -64
  219. sky/templates/aws-ray.yml.j2 +37 -7
  220. sky/templates/azure-ray.yml.j2 +27 -82
  221. sky/templates/cudo-ray.yml.j2 +7 -3
  222. sky/templates/do-ray.yml.j2 +98 -0
  223. sky/templates/fluidstack-ray.yml.j2 +7 -4
  224. sky/templates/gcp-ray.yml.j2 +26 -6
  225. sky/templates/ibm-ray.yml.j2 +3 -2
  226. sky/templates/jobs-controller.yaml.j2 +46 -11
  227. sky/templates/kubernetes-ingress.yml.j2 +7 -0
  228. sky/templates/kubernetes-loadbalancer.yml.j2 +7 -0
  229. sky/templates/{kubernetes-port-forward-proxy-command.sh.j2 → kubernetes-port-forward-proxy-command.sh} +51 -7
  230. sky/templates/kubernetes-ray.yml.j2 +292 -25
  231. sky/templates/lambda-ray.yml.j2 +30 -40
  232. sky/templates/nebius-ray.yml.j2 +79 -0
  233. sky/templates/oci-ray.yml.j2 +18 -57
  234. sky/templates/paperspace-ray.yml.j2 +10 -6
  235. sky/templates/runpod-ray.yml.j2 +26 -4
  236. sky/templates/scp-ray.yml.j2 +3 -2
  237. sky/templates/sky-serve-controller.yaml.j2 +12 -1
  238. sky/templates/skypilot-server-kubernetes-proxy.sh +36 -0
  239. sky/templates/vast-ray.yml.j2 +70 -0
  240. sky/templates/vsphere-ray.yml.j2 +8 -3
  241. sky/templates/websocket_proxy.py +64 -0
  242. sky/usage/constants.py +10 -1
  243. sky/usage/usage_lib.py +130 -37
  244. sky/utils/accelerator_registry.py +35 -51
  245. sky/utils/admin_policy_utils.py +147 -0
  246. sky/utils/annotations.py +51 -0
  247. sky/utils/cli_utils/status_utils.py +81 -23
  248. sky/utils/cluster_utils.py +356 -0
  249. sky/utils/command_runner.py +452 -89
  250. sky/utils/command_runner.pyi +77 -3
  251. sky/utils/common.py +54 -0
  252. sky/utils/common_utils.py +319 -108
  253. sky/utils/config_utils.py +204 -0
  254. sky/utils/control_master_utils.py +48 -0
  255. sky/utils/controller_utils.py +548 -266
  256. sky/utils/dag_utils.py +93 -32
  257. sky/utils/db_utils.py +18 -4
  258. sky/utils/env_options.py +29 -7
  259. sky/utils/kubernetes/create_cluster.sh +8 -60
  260. sky/utils/kubernetes/deploy_remote_cluster.sh +243 -0
  261. sky/utils/kubernetes/exec_kubeconfig_converter.py +73 -0
  262. sky/utils/kubernetes/generate_kubeconfig.sh +336 -0
  263. sky/utils/kubernetes/gpu_labeler.py +4 -4
  264. sky/utils/kubernetes/k8s_gpu_labeler_job.yaml +4 -3
  265. sky/utils/kubernetes/kubernetes_deploy_utils.py +228 -0
  266. sky/utils/kubernetes/rsync_helper.sh +24 -0
  267. sky/utils/kubernetes/ssh_jump_lifecycle_manager.py +1 -1
  268. sky/utils/log_utils.py +240 -33
  269. sky/utils/message_utils.py +81 -0
  270. sky/utils/registry.py +127 -0
  271. sky/utils/resources_utils.py +94 -22
  272. sky/utils/rich_utils.py +247 -18
  273. sky/utils/schemas.py +284 -64
  274. sky/{status_lib.py → utils/status_lib.py} +12 -7
  275. sky/utils/subprocess_utils.py +212 -46
  276. sky/utils/timeline.py +12 -7
  277. sky/utils/ux_utils.py +168 -15
  278. skypilot_nightly-1.0.0.dev2025022801.dist-info/METADATA +363 -0
  279. skypilot_nightly-1.0.0.dev2025022801.dist-info/RECORD +352 -0
  280. {skypilot_nightly-1.0.0.dev2024053101.dist-info → skypilot_nightly-1.0.0.dev2025022801.dist-info}/WHEEL +1 -1
  281. sky/clouds/cloud_registry.py +0 -31
  282. sky/jobs/core.py +0 -330
  283. sky/skylet/providers/azure/__init__.py +0 -2
  284. sky/skylet/providers/azure/azure-vm-template.json +0 -301
  285. sky/skylet/providers/azure/config.py +0 -170
  286. sky/skylet/providers/azure/node_provider.py +0 -466
  287. sky/skylet/providers/lambda_cloud/__init__.py +0 -2
  288. sky/skylet/providers/lambda_cloud/node_provider.py +0 -320
  289. sky/skylet/providers/oci/__init__.py +0 -2
  290. sky/skylet/providers/oci/node_provider.py +0 -488
  291. sky/skylet/providers/oci/query_helper.py +0 -383
  292. sky/skylet/providers/oci/utils.py +0 -21
  293. sky/utils/cluster_yaml_utils.py +0 -24
  294. sky/utils/kubernetes/generate_static_kubeconfig.sh +0 -137
  295. skypilot_nightly-1.0.0.dev2024053101.dist-info/METADATA +0 -315
  296. skypilot_nightly-1.0.0.dev2024053101.dist-info/RECORD +0 -275
  297. {skypilot_nightly-1.0.0.dev2024053101.dist-info → skypilot_nightly-1.0.0.dev2025022801.dist-info}/LICENSE +0 -0
  298. {skypilot_nightly-1.0.0.dev2024053101.dist-info → skypilot_nightly-1.0.0.dev2025022801.dist-info}/entry_points.txt +0 -0
  299. {skypilot_nightly-1.0.0.dev2024053101.dist-info → skypilot_nightly-1.0.0.dev2025022801.dist-info}/top_level.txt +0 -0
@@ -1,24 +1,142 @@
1
1
  """Azure instance provisioning."""
2
+ import base64
3
+ import copy
4
+ import enum
2
5
  import logging
3
- from typing import Any, Callable, Dict, List, Optional
6
+ from multiprocessing import pool
7
+ import time
8
+ import typing
9
+ from typing import Any, Callable, Dict, List, Optional, Tuple
10
+ from uuid import uuid4
4
11
 
12
+ from sky import exceptions
5
13
  from sky import sky_logging
6
14
  from sky.adaptors import azure
15
+ from sky.provision import common
16
+ from sky.provision import constants
17
+ from sky.utils import common_utils
18
+ from sky.utils import status_lib
19
+ from sky.utils import subprocess_utils
7
20
  from sky.utils import ux_utils
8
21
 
22
+ if typing.TYPE_CHECKING:
23
+ from azure.mgmt import compute as azure_compute
24
+ from azure.mgmt import network as azure_network
25
+ from azure.mgmt.compute import models as azure_compute_models
26
+ from azure.mgmt.network import models as azure_network_models
27
+
9
28
  logger = sky_logging.init_logger(__name__)
10
29
 
11
30
  # Suppress noisy logs from Azure SDK. Reference:
12
31
  # https://github.com/Azure/azure-sdk-for-python/issues/9422
13
32
  azure_logger = logging.getLogger('azure')
14
33
  azure_logger.setLevel(logging.WARNING)
34
+ Client = Any
35
+ NetworkSecurityGroup = Any
36
+
37
+ _RESUME_INSTANCE_TIMEOUT = 480 # 8 minutes
38
+ _RESUME_PER_INSTANCE_TIMEOUT = 120 # 2 minutes
39
+ UNIQUE_ID_LEN = 4
40
+ _TAG_SKYPILOT_VM_ID = 'skypilot-vm-id'
41
+ _WAIT_CREATION_TIMEOUT_SECONDS = 600
42
+
43
+ _RESOURCE_MANAGED_IDENTITY_TYPE = (
44
+ 'Microsoft.ManagedIdentity/userAssignedIdentities')
45
+ _RESOURCE_NETWORK_SECURITY_GROUP_TYPE = (
46
+ 'Microsoft.Network/networkSecurityGroups')
47
+ _RESOURCE_VIRTUAL_NETWORK_TYPE = 'Microsoft.Network/virtualNetworks'
48
+ _RESOURCE_PUBLIC_IP_ADDRESS_TYPE = 'Microsoft.Network/publicIPAddresses'
49
+ _RESOURCE_VIRTUAL_MACHINE_TYPE = 'Microsoft.Compute/virtualMachines'
50
+ _RESOURCE_NETWORK_INTERFACE_TYPE = 'Microsoft.Network/networkInterfaces'
51
+
52
+ _RESOURCE_GROUP_NOT_FOUND_ERROR_MESSAGE = 'ResourceGroupNotFound'
53
+ _POLL_INTERVAL = 1
54
+
15
55
 
16
- # Tag uniquely identifying all nodes of a cluster
17
- TAG_RAY_CLUSTER_NAME = 'ray-cluster-name'
18
- TAG_RAY_NODE_KIND = 'ray-node-type'
56
+ class AzureInstanceStatus(enum.Enum):
57
+ """Statuses enum for Azure instances with power and provisioning states."""
58
+ PENDING = 'pending'
59
+ RUNNING = 'running'
60
+ STOPPING = 'stopping'
61
+ STOPPED = 'stopped'
62
+ DELETING = 'deleting'
19
63
 
64
+ @classmethod
65
+ def power_state_map(cls) -> Dict[str, 'AzureInstanceStatus']:
66
+ return {
67
+ 'starting': cls.PENDING,
68
+ 'running': cls.RUNNING,
69
+ # 'stopped' in Azure means Stopped (Allocated), which still bills
70
+ # for the VM.
71
+ 'stopping': cls.STOPPING,
72
+ 'stopped': cls.STOPPED,
73
+ # 'VM deallocated' in Azure means Stopped (Deallocated), which does
74
+ # not bill for the VM.
75
+ 'deallocating': cls.STOPPING,
76
+ 'deallocated': cls.STOPPED,
77
+ }
20
78
 
21
- def get_azure_sdk_function(client: Any, function_name: str) -> Callable:
79
+ @classmethod
80
+ def provisioning_state_map(cls) -> Dict[str, 'AzureInstanceStatus']:
81
+ return {
82
+ 'Creating': cls.PENDING,
83
+ 'Updating': cls.PENDING,
84
+ 'Failed': cls.PENDING,
85
+ 'Migrating': cls.PENDING,
86
+ 'Deleting': cls.DELETING,
87
+ # Succeeded in provisioning state means the VM is provisioned but
88
+ # not necessarily running. The caller should further check the
89
+ # power state to determine the actual VM status.
90
+ 'Succeeded': cls.RUNNING,
91
+ }
92
+
93
+ @classmethod
94
+ def cluster_status_map(
95
+ cls
96
+ ) -> Dict['AzureInstanceStatus', Optional[status_lib.ClusterStatus]]:
97
+ return {
98
+ cls.PENDING: status_lib.ClusterStatus.INIT,
99
+ cls.RUNNING: status_lib.ClusterStatus.UP,
100
+ cls.STOPPING: status_lib.ClusterStatus.STOPPED,
101
+ cls.STOPPED: status_lib.ClusterStatus.STOPPED,
102
+ cls.DELETING: None,
103
+ }
104
+
105
+ @classmethod
106
+ def from_raw_states(cls, provisioning_state: str,
107
+ power_state: Optional[str]) -> 'AzureInstanceStatus':
108
+ provisioning_state_map = cls.provisioning_state_map()
109
+ power_state_map = cls.power_state_map()
110
+ status = None
111
+ if power_state is None:
112
+ if provisioning_state not in provisioning_state_map:
113
+ with ux_utils.print_exception_no_traceback():
114
+ raise exceptions.ClusterStatusFetchingError(
115
+ 'Failed to parse status from Azure response: '
116
+ f'{provisioning_state}')
117
+ status = provisioning_state_map[provisioning_state]
118
+ if status is None or status == cls.RUNNING:
119
+ # We should further check the power state to determine the actual
120
+ # VM status.
121
+ if power_state not in power_state_map:
122
+ with ux_utils.print_exception_no_traceback():
123
+ raise exceptions.ClusterStatusFetchingError(
124
+ 'Failed to parse status from Azure response: '
125
+ f'{power_state}.')
126
+ status = power_state_map[power_state]
127
+ if status is None:
128
+ with ux_utils.print_exception_no_traceback():
129
+ raise exceptions.ClusterStatusFetchingError(
130
+ 'Failed to parse status from Azure response: '
131
+ f'provisioning state ({provisioning_state}), '
132
+ f'power state ({power_state})')
133
+ return status
134
+
135
+ def to_cluster_status(self) -> Optional[status_lib.ClusterStatus]:
136
+ return self.cluster_status_map().get(self)
137
+
138
+
139
+ def _get_azure_sdk_function(client: Any, function_name: str) -> Callable:
22
140
  """Retrieve a callable function from Azure SDK client object.
23
141
 
24
142
  Newer versions of the various client SDKs renamed function names to
@@ -35,6 +153,835 @@ def get_azure_sdk_function(client: Any, function_name: str) -> Callable:
35
153
  return func
36
154
 
37
155
 
156
+ def _get_instance_ips(network_client, vm, resource_group: str,
157
+ use_internal_ips: bool) -> Tuple[str, Optional[str]]:
158
+ nic_id = vm.network_profile.network_interfaces[0].id
159
+ nic_name = nic_id.split('/')[-1]
160
+ nic = network_client.network_interfaces.get(
161
+ resource_group_name=resource_group,
162
+ network_interface_name=nic_name,
163
+ )
164
+ ip_config = nic.ip_configurations[0]
165
+
166
+ external_ip = None
167
+ if not use_internal_ips:
168
+ public_ip_id = ip_config.public_ip_address.id
169
+ public_ip_name = public_ip_id.split('/')[-1]
170
+ public_ip = network_client.public_ip_addresses.get(
171
+ resource_group_name=resource_group,
172
+ public_ip_address_name=public_ip_name,
173
+ )
174
+ external_ip = public_ip.ip_address
175
+
176
+ internal_ip = ip_config.private_ip_address
177
+
178
+ return (internal_ip, external_ip)
179
+
180
+
181
+ def _get_head_instance_id(instances: List) -> Optional[str]:
182
+ head_instance_id = None
183
+ head_node_tags = tuple(constants.HEAD_NODE_TAGS.items())
184
+ for inst in instances:
185
+ for k, v in inst.tags.items():
186
+ if (k, v) in head_node_tags:
187
+ if head_instance_id is not None:
188
+ logger.warning(
189
+ 'There are multiple head nodes in the cluster '
190
+ f'(current head instance id: {head_instance_id}, '
191
+ f'newly discovered id: {inst.name}). It is likely '
192
+ f'that something goes wrong.')
193
+ head_instance_id = inst.name
194
+ break
195
+ return head_instance_id
196
+
197
+
198
+ def _create_network_interface(
199
+ network_client: 'azure_network.NetworkManagementClient', vm_name: str,
200
+ provider_config: Dict[str,
201
+ Any]) -> 'azure_network_models.NetworkInterface':
202
+ network = azure.azure_mgmt_models('network')
203
+ compute = azure.azure_mgmt_models('compute')
204
+ logger.info(f'Start creating network interface for {vm_name}...')
205
+ if provider_config.get('use_internal_ips', False):
206
+ name = f'{vm_name}-nic-private'
207
+ ip_config = network.IPConfiguration(
208
+ name=f'ip-config-private-{vm_name}',
209
+ subnet=compute.SubResource(id=provider_config['subnet']),
210
+ private_ip_allocation_method=network.IPAllocationMethod.DYNAMIC)
211
+ else:
212
+ name = f'{vm_name}-nic-public'
213
+ public_ip_address = network.PublicIPAddress(
214
+ location=provider_config['location'],
215
+ public_ip_allocation_method='Static',
216
+ public_ip_address_version='IPv4',
217
+ sku=network.PublicIPAddressSku(name='Basic', tier='Regional'))
218
+ ip_poller = network_client.public_ip_addresses.begin_create_or_update(
219
+ resource_group_name=provider_config['resource_group'],
220
+ public_ip_address_name=f'{vm_name}-ip',
221
+ parameters=public_ip_address)
222
+ logger.info(f'Created public IP address {ip_poller.result().name} '
223
+ f'with address {ip_poller.result().ip_address}.')
224
+ ip_config = network.IPConfiguration(
225
+ name=f'ip-config-public-{vm_name}',
226
+ subnet=compute.SubResource(id=provider_config['subnet']),
227
+ private_ip_allocation_method=network.IPAllocationMethod.DYNAMIC,
228
+ public_ip_address=network.PublicIPAddress(id=ip_poller.result().id))
229
+
230
+ ni_poller = network_client.network_interfaces.begin_create_or_update(
231
+ resource_group_name=provider_config['resource_group'],
232
+ network_interface_name=name,
233
+ parameters=network.NetworkInterface(
234
+ location=provider_config['location'],
235
+ ip_configurations=[ip_config],
236
+ network_security_group=network.NetworkSecurityGroup(
237
+ id=provider_config['nsg'])))
238
+ logger.info(f'Created network interface {ni_poller.result().name}.')
239
+ return ni_poller.result()
240
+
241
+
242
+ def _create_vm(
243
+ compute_client: 'azure_compute.ComputeManagementClient', vm_name: str,
244
+ node_tags: Dict[str, str], provider_config: Dict[str, Any],
245
+ node_config: Dict[str, Any],
246
+ network_interface_id: str) -> 'azure_compute_models.VirtualMachine':
247
+ compute = azure.azure_mgmt_models('compute')
248
+ logger.info(f'Start creating VM {vm_name}...')
249
+ hardware_profile = compute.HardwareProfile(
250
+ vm_size=node_config['azure_arm_parameters']['vmSize'])
251
+ network_profile = compute.NetworkProfile(network_interfaces=[
252
+ compute.NetworkInterfaceReference(id=network_interface_id, primary=True)
253
+ ])
254
+ public_key = node_config['azure_arm_parameters']['publicKey']
255
+ username = node_config['azure_arm_parameters']['adminUsername']
256
+ os_linux_custom_data = base64.b64encode(
257
+ node_config['azure_arm_parameters']['cloudInitSetupCommands'].encode(
258
+ 'utf-8')).decode('utf-8')
259
+ os_profile = compute.OSProfile(
260
+ admin_username=username,
261
+ computer_name=vm_name,
262
+ admin_password=public_key,
263
+ linux_configuration=compute.LinuxConfiguration(
264
+ disable_password_authentication=True,
265
+ ssh=compute.SshConfiguration(public_keys=[
266
+ compute.SshPublicKey(
267
+ path=f'/home/{username}/.ssh/authorized_keys',
268
+ key_data=public_key)
269
+ ])),
270
+ custom_data=os_linux_custom_data)
271
+ community_image_id = node_config['azure_arm_parameters'].get(
272
+ 'communityGalleryImageId', None)
273
+ if community_image_id is not None:
274
+ # Prioritize using community gallery image if specified.
275
+ image_reference = compute.ImageReference(
276
+ community_gallery_image_id=community_image_id)
277
+ logger.info(
278
+ f'Used community_image_id: {community_image_id} for VM {vm_name}.')
279
+ else:
280
+ image_reference = compute.ImageReference(
281
+ publisher=node_config['azure_arm_parameters']['imagePublisher'],
282
+ offer=node_config['azure_arm_parameters']['imageOffer'],
283
+ sku=node_config['azure_arm_parameters']['imageSku'],
284
+ version=node_config['azure_arm_parameters']['imageVersion'])
285
+ storage_profile = compute.StorageProfile(
286
+ image_reference=image_reference,
287
+ os_disk=compute.OSDisk(
288
+ create_option=compute.DiskCreateOptionTypes.FROM_IMAGE,
289
+ delete_option=compute.DiskDeleteOptionTypes.DELETE,
290
+ managed_disk=compute.ManagedDiskParameters(
291
+ storage_account_type=node_config['azure_arm_parameters']
292
+ ['osDiskTier']),
293
+ disk_size_gb=node_config['azure_arm_parameters']['osDiskSizeGB']))
294
+ vm_instance = compute.VirtualMachine(
295
+ location=provider_config['location'],
296
+ tags=node_tags,
297
+ hardware_profile=hardware_profile,
298
+ os_profile=os_profile,
299
+ storage_profile=storage_profile,
300
+ network_profile=network_profile,
301
+ identity=compute.VirtualMachineIdentity(
302
+ type='UserAssigned',
303
+ user_assigned_identities={provider_config['msi']: {}}),
304
+ priority=node_config['azure_arm_parameters'].get('priority', None))
305
+ vm_poller = compute_client.virtual_machines.begin_create_or_update(
306
+ resource_group_name=provider_config['resource_group'],
307
+ vm_name=vm_name,
308
+ parameters=vm_instance,
309
+ )
310
+ # This line will block until the VM is created or the operation times out.
311
+ vm = vm_poller.result()
312
+ logger.info(f'Created VM {vm.name}.')
313
+ return vm
314
+
315
+
316
+ def _create_instances(compute_client: 'azure_compute.ComputeManagementClient',
317
+ network_client: 'azure_network.NetworkManagementClient',
318
+ cluster_name_on_cloud: str, resource_group: str,
319
+ provider_config: Dict[str, Any], node_config: Dict[str,
320
+ Any],
321
+ tags: Dict[str, str], count: int) -> List:
322
+ vm_id = uuid4().hex[:UNIQUE_ID_LEN]
323
+ all_tags = {
324
+ constants.TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud,
325
+ constants.TAG_SKYPILOT_CLUSTER_NAME: cluster_name_on_cloud,
326
+ **constants.WORKER_NODE_TAGS,
327
+ _TAG_SKYPILOT_VM_ID: vm_id,
328
+ **tags,
329
+ }
330
+ node_tags = node_config['tags'].copy()
331
+ node_tags.update(all_tags)
332
+
333
+ # Create VM instances in parallel.
334
+ def create_single_instance(vm_i):
335
+ vm_name = f'{cluster_name_on_cloud}-{vm_id}-{vm_i}'
336
+ network_interface = _create_network_interface(network_client, vm_name,
337
+ provider_config)
338
+ _create_vm(compute_client, vm_name, node_tags, provider_config,
339
+ node_config, network_interface.id)
340
+
341
+ subprocess_utils.run_in_parallel(create_single_instance, list(range(count)))
342
+
343
+ # Update disk performance tier
344
+ performance_tier = node_config.get('disk_performance_tier', None)
345
+ if performance_tier is not None:
346
+ disks = compute_client.disks.list_by_resource_group(resource_group)
347
+ for disk in disks:
348
+ name = disk.name
349
+ # TODO(tian): Investigate if we can use Python SDK to update this.
350
+ subprocess_utils.run_no_outputs(
351
+ f'az disk update -n {name} -g {resource_group} '
352
+ f'--set tier={performance_tier}')
353
+
354
+ # Validation
355
+ filters = {
356
+ constants.TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud,
357
+ _TAG_SKYPILOT_VM_ID: vm_id
358
+ }
359
+ instances = _filter_instances(compute_client, resource_group, filters)
360
+ assert len(instances) == count, (len(instances), count)
361
+
362
+ return instances
363
+
364
+
365
+ def run_instances(region: str, cluster_name_on_cloud: str,
366
+ config: common.ProvisionConfig) -> common.ProvisionRecord:
367
+ """See sky/provision/__init__.py"""
368
+ # TODO(zhwu): This function is too long. We should refactor it.
369
+ provider_config = config.provider_config
370
+ resource_group = provider_config['resource_group']
371
+ subscription_id = provider_config['subscription_id']
372
+ compute_client = azure.get_client('compute', subscription_id)
373
+ network_client = azure.get_client('network', subscription_id)
374
+ instances_to_resume = []
375
+ resumed_instance_ids: List[str] = []
376
+ created_instance_ids: List[str] = []
377
+
378
+ # sort tags by key to support deterministic unit test stubbing
379
+ tags = dict(sorted(copy.deepcopy(config.tags).items()))
380
+ filters = {constants.TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud}
381
+
382
+ non_deleting_states = (set(AzureInstanceStatus) -
383
+ {AzureInstanceStatus.DELETING})
384
+ existing_instances = _filter_instances(
385
+ compute_client,
386
+ tag_filters=filters,
387
+ resource_group=resource_group,
388
+ status_filters=list(non_deleting_states),
389
+ )
390
+ logger.debug(
391
+ f'run_instances: Found {[inst.name for inst in existing_instances]} '
392
+ 'existing instances in cluster.')
393
+ existing_instances.sort(key=lambda x: x.name)
394
+
395
+ pending_instances = []
396
+ running_instances = []
397
+ stopping_instances = []
398
+ stopped_instances = []
399
+
400
+ for instance in existing_instances:
401
+ status = _get_instance_status(compute_client, instance, resource_group)
402
+ logger.debug(
403
+ f'run_instances: Instance {instance.name} has status {status}.')
404
+
405
+ if status == AzureInstanceStatus.RUNNING:
406
+ running_instances.append(instance)
407
+ elif status == AzureInstanceStatus.STOPPED:
408
+ stopped_instances.append(instance)
409
+ elif status == AzureInstanceStatus.STOPPING:
410
+ stopping_instances.append(instance)
411
+ elif status == AzureInstanceStatus.PENDING:
412
+ pending_instances.append(instance)
413
+
414
+ def _create_instance_tag(target_instance, is_head: bool = True) -> str:
415
+ new_instance_tags = (constants.HEAD_NODE_TAGS
416
+ if is_head else constants.WORKER_NODE_TAGS)
417
+
418
+ tags = target_instance.tags
419
+ tags.update(new_instance_tags)
420
+
421
+ update = _get_azure_sdk_function(compute_client.virtual_machines,
422
+ 'update')
423
+ update(resource_group, target_instance.name, parameters={'tags': tags})
424
+ return target_instance.name
425
+
426
+ head_instance_id = _get_head_instance_id(existing_instances)
427
+ if head_instance_id is None:
428
+ if running_instances:
429
+ head_instance_id = _create_instance_tag(running_instances[0])
430
+ elif pending_instances:
431
+ head_instance_id = _create_instance_tag(pending_instances[0])
432
+
433
+ if config.resume_stopped_nodes and len(existing_instances) > config.count:
434
+ raise RuntimeError(
435
+ 'The number of pending/running/stopped/stopping '
436
+ f'instances combined ({len(existing_instances)}) in '
437
+ f'cluster "{cluster_name_on_cloud}" is greater than the '
438
+ f'number requested by the user ({config.count}). '
439
+ 'This is likely a resource leak. '
440
+ 'Use "sky down" to terminate the cluster.')
441
+
442
+ to_start_count = config.count - len(pending_instances) - len(
443
+ running_instances)
444
+
445
+ if to_start_count < 0:
446
+ raise RuntimeError(
447
+ 'The number of running+pending instances '
448
+ f'({config.count - to_start_count}) in cluster '
449
+ f'"{cluster_name_on_cloud}" is greater than the number '
450
+ f'requested by the user ({config.count}). '
451
+ 'This is likely a resource leak. '
452
+ 'Use "sky down" to terminate the cluster.')
453
+
454
+ if config.resume_stopped_nodes and to_start_count > 0 and (
455
+ stopping_instances or stopped_instances):
456
+ time_start = time.time()
457
+ if stopping_instances:
458
+ plural = 's' if len(stopping_instances) > 1 else ''
459
+ verb = 'are' if len(stopping_instances) > 1 else 'is'
460
+ # TODO(zhwu): double check the correctness of the following on Azure
461
+ logger.warning(
462
+ f'Instance{plural} {[inst.name for inst in stopping_instances]}'
463
+ f' {verb} still in STOPPING state on Azure. It can only be '
464
+ 'resumed after it is fully STOPPED. Waiting ...')
465
+ while (stopping_instances and
466
+ to_start_count > len(stopped_instances) and
467
+ time.time() - time_start < _RESUME_INSTANCE_TIMEOUT):
468
+ inst = stopping_instances.pop(0)
469
+ per_instance_time_start = time.time()
470
+ while (time.time() - per_instance_time_start <
471
+ _RESUME_PER_INSTANCE_TIMEOUT):
472
+ status = _get_instance_status(compute_client, inst,
473
+ resource_group)
474
+ if status == AzureInstanceStatus.STOPPED:
475
+ break
476
+ time.sleep(1)
477
+ else:
478
+ logger.warning(
479
+ f'Instance {inst.name} is still in stopping state '
480
+ f'(Timeout: {_RESUME_PER_INSTANCE_TIMEOUT}). '
481
+ 'Retrying ...')
482
+ stopping_instances.append(inst)
483
+ time.sleep(5)
484
+ continue
485
+ stopped_instances.append(inst)
486
+ if stopping_instances and to_start_count > len(stopped_instances):
487
+ msg = ('Timeout for waiting for existing instances '
488
+ f'{stopping_instances} in STOPPING state to '
489
+ 'be STOPPED before restarting them. Please try again later.')
490
+ logger.error(msg)
491
+ raise RuntimeError(msg)
492
+
493
+ instances_to_resume = stopped_instances[:to_start_count]
494
+ instances_to_resume.sort(key=lambda x: x.name)
495
+ instances_to_resume_ids = [t.name for t in instances_to_resume]
496
+ logger.debug('run_instances: Resuming stopped instances '
497
+ f'{instances_to_resume_ids}.')
498
+ start_virtual_machine = _get_azure_sdk_function(
499
+ compute_client.virtual_machines, 'start')
500
+ with pool.ThreadPool() as p:
501
+ p.starmap(
502
+ start_virtual_machine,
503
+ [(resource_group, inst.name) for inst in instances_to_resume])
504
+ resumed_instance_ids = instances_to_resume_ids
505
+
506
+ to_start_count -= len(resumed_instance_ids)
507
+
508
+ if to_start_count > 0:
509
+ logger.debug(f'run_instances: Creating {to_start_count} instances.')
510
+ try:
511
+ created_instances = _create_instances(
512
+ compute_client=compute_client,
513
+ network_client=network_client,
514
+ cluster_name_on_cloud=cluster_name_on_cloud,
515
+ resource_group=resource_group,
516
+ provider_config=provider_config,
517
+ node_config=config.node_config,
518
+ tags=tags,
519
+ count=to_start_count)
520
+ except Exception as e:
521
+ err_message = common_utils.format_exception(
522
+ e, use_bracket=True).replace('\n', ' ')
523
+ logger.error(f'Failed to create instances: {err_message}')
524
+ raise
525
+ created_instance_ids = [inst.name for inst in created_instances]
526
+
527
+ non_running_instance_statuses = list(
528
+ set(AzureInstanceStatus) - {AzureInstanceStatus.RUNNING})
529
+ start = time.time()
530
+ while True:
531
+ # Wait for all instances to be in running state
532
+ instances = _filter_instances(
533
+ compute_client,
534
+ resource_group,
535
+ filters,
536
+ status_filters=non_running_instance_statuses,
537
+ included_instances=created_instance_ids + resumed_instance_ids)
538
+ if not instances:
539
+ break
540
+ if time.time() - start > _WAIT_CREATION_TIMEOUT_SECONDS:
541
+ raise TimeoutError(
542
+ 'run_instances: Timed out waiting for Azure instances to be '
543
+ f'running: {instances}')
544
+ logger.debug(f'run_instances: Waiting for {len(instances)} instances '
545
+ 'in PENDING status.')
546
+ time.sleep(_POLL_INTERVAL)
547
+
548
+ running_instances = _filter_instances(
549
+ compute_client,
550
+ resource_group,
551
+ filters,
552
+ status_filters=[AzureInstanceStatus.RUNNING])
553
+ head_instance_id = _get_head_instance_id(running_instances)
554
+ instances_to_tag = copy.copy(running_instances)
555
+ if head_instance_id is None:
556
+ head_instance_id = _create_instance_tag(instances_to_tag[0])
557
+ instances_to_tag = instances_to_tag[1:]
558
+ else:
559
+ instances_to_tag = [
560
+ inst for inst in instances_to_tag if inst.name != head_instance_id
561
+ ]
562
+
563
+ if instances_to_tag:
564
+ # Tag the instances in case the old resumed instances are not correctly
565
+ # tagged.
566
+ with pool.ThreadPool() as p:
567
+ p.starmap(
568
+ _create_instance_tag,
569
+ # is_head=False for all wokers.
570
+ [(inst, False) for inst in instances_to_tag])
571
+
572
+ assert head_instance_id is not None, head_instance_id
573
+ return common.ProvisionRecord(
574
+ provider_name='azure',
575
+ region=region,
576
+ zone=None,
577
+ cluster_name=cluster_name_on_cloud,
578
+ head_instance_id=head_instance_id,
579
+ created_instance_ids=created_instance_ids,
580
+ resumed_instance_ids=resumed_instance_ids,
581
+ )
582
+
583
+
584
def wait_instances(region: str, cluster_name_on_cloud: str,
                   state: Optional[status_lib.ClusterStatus]) -> None:
    """See sky/provision/__init__.py

    No-op for Azure: run_instances already blocks until every instance
    reaches the RUNNING state, so there is nothing left to wait for here.
    """
    del region, cluster_name_on_cloud, state  # Unused.
590
+
591
+
592
def get_cluster_info(
        region: str,
        cluster_name_on_cloud: str,
        provider_config: Optional[Dict[str, Any]] = None) -> common.ClusterInfo:
    """See sky/provision/__init__.py"""
    del region  # Unused.
    assert provider_config is not None, (cluster_name_on_cloud, provider_config)
    resource_group = provider_config['resource_group']
    subscription_id = provider_config.get('subscription_id',
                                          azure.get_subscription_id())
    compute_client = azure.get_client('compute', subscription_id)
    network_client = azure.get_client('network', subscription_id)

    tag_filters = {constants.TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud}
    running = _filter_instances(compute_client,
                                resource_group,
                                tag_filters,
                                status_filters=[AzureInstanceStatus.RUNNING])
    use_internal_ips = provider_config.get('use_internal_ips', False)

    def _to_info(vm) -> common.InstanceInfo:
        # Resolve the VM's IPs through its network interface.
        internal_ip, external_ip = _get_instance_ips(network_client, vm,
                                                     resource_group,
                                                     use_internal_ips)
        return common.InstanceInfo(
            instance_id=vm.name,
            internal_ip=internal_ip,
            external_ip=external_ip,
            tags=vm.tags,
        )

    # Sort by instance name so the resulting mapping is deterministic.
    instances = {
        vm.name: [_to_info(vm)]
        for vm in sorted(running, key=lambda v: v.name)
    }
    return common.ClusterInfo(
        provider_name='azure',
        head_instance_id=_get_head_instance_id(running),
        instances=instances,
        provider_config=provider_config,
    )
634
+
635
+
636
def stop_instances(
    cluster_name_on_cloud: str,
    provider_config: Optional[Dict[str, Any]] = None,
    worker_only: bool = False,
) -> None:
    """See sky/provision/__init__.py"""
    assert provider_config is not None, (cluster_name_on_cloud, provider_config)

    subscription_id = provider_config['subscription_id']
    resource_group = provider_config['resource_group']
    compute_client = azure.get_client('compute', subscription_id)

    # Restrict to this cluster's VMs; optionally exclude the head node.
    tag_filters = {constants.TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud}
    if worker_only:
        tag_filters[constants.TAG_RAY_NODE_KIND] = 'worker'
    vms_to_stop = _filter_instances(compute_client, resource_group, tag_filters)

    # Deallocate each matched VM in parallel.
    deallocate = _get_azure_sdk_function(
        client=compute_client.virtual_machines, function_name='deallocate')
    with pool.ThreadPool() as p:
        p.starmap(deallocate,
                  [(resource_group, vm.name) for vm in vms_to_stop])
657
+
658
+
659
def terminate_instances(
    cluster_name_on_cloud: str,
    provider_config: Optional[Dict[str, Any]] = None,
    worker_only: bool = False,
) -> None:
    """See sky/provision/__init__.py

    With worker_only, only the worker VMs are deleted, keeping the head node
    and shared resources intact. Otherwise the whole cluster is terminated:
    either by cleaning up individual resources (user-provided external
    resource group) or by deleting the entire SkyPilot-managed resource group.
    """
    assert provider_config is not None, (cluster_name_on_cloud, provider_config)
    # TODO(zhwu): check the following. Also, seems we can directly force
    # delete a resource group.
    subscription_id = provider_config['subscription_id']
    resource_group = provider_config['resource_group']
    if worker_only:
        compute_client = azure.get_client('compute', subscription_id)
        delete_virtual_machine = _get_azure_sdk_function(
            client=compute_client.virtual_machines, function_name='delete')
        filters = {
            constants.TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud,
            constants.TAG_RAY_NODE_KIND: 'worker'
        }
        nodes = _filter_instances(compute_client, resource_group, filters)
        with pool.ThreadPool() as p:
            p.starmap(delete_virtual_machine,
                      [(resource_group, node.name) for node in nodes])
        return

    use_external_resource_group = provider_config.get(
        'use_external_resource_group', False)
    # When user specified resource group through config.yaml to create a VM, we
    # cannot remove the entire resource group as it may contain other resources
    # unrelated to this VM being removed.
    if use_external_resource_group:
        delete_vm_and_attached_resources(subscription_id, resource_group,
                                         cluster_name_on_cloud)
    else:
        # For SkyPilot default resource groups, delete entire resource group.
        # This automatically terminates all resources within, including VMs.
        resource_group_client = azure.get_client('resource', subscription_id)
        delete_resource_group = _get_azure_sdk_function(
            client=resource_group_client.resource_groups,
            function_name='delete')
        try:
            delete_resource_group(resource_group, force_deletion_types=None)
        except azure.exceptions().ResourceNotFoundError as e:
            if 'ResourceGroupNotFound' in str(e):
                # Already gone; nothing to terminate.
                logger.warning(
                    f'Resource group {resource_group} not found. Skip '
                    'terminating it.')
                return
            raise
710
+
711
+
712
def _get_instance_status(
        compute_client: 'azure_compute.ComputeManagementClient', vm,
        resource_group: str) -> Optional[AzureInstanceStatus]:
    """Map a VM's Azure instance view to an AzureInstanceStatus.

    Returns None when the VM no longer exists on the cloud.
    """
    try:
        view = compute_client.virtual_machines.instance_view(
            resource_group_name=resource_group, vm_name=vm.name)
    except azure.exceptions().ResourceNotFoundError as e:
        if 'ResourceNotFound' in str(e):
            return None
        raise
    provisioning_state = vm.provisioning_state
    for raw_status in view.as_dict()['statuses']:
        parts = raw_status['code'].split('/')
        # The 'code' field is occasionally an empty string; such entries
        # do not carry a '<kind>/<state>' pair and are skipped.
        if len(parts) != 2:
            continue
        kind, power_state = parts
        # Only the PowerState entry is used; provisioning-status entries
        # are ignored here (provisioning_state comes from the VM object).
        if kind == 'PowerState':
            return AzureInstanceStatus.from_raw_states(provisioning_state,
                                                       power_state)
    # No power state reported (e.g. the VM is still being provisioned).
    return AzureInstanceStatus.from_raw_states(provisioning_state, None)
736
+
737
+
738
def _filter_instances(
    compute_client: 'azure_compute.ComputeManagementClient',
    resource_group: str,
    tag_filters: Dict[str, str],
    status_filters: Optional[List[AzureInstanceStatus]] = None,
    included_instances: Optional[List[str]] = None,
) -> List['azure_compute.models.VirtualMachine']:
    """List the VMs in the resource group matching tag/status/name filters."""

    def _tags_match(vm) -> bool:
        return all(vm.tags.get(key) == value
                   for key, value in tag_filters.items())

    try:
        list_virtual_machines = _get_azure_sdk_function(
            client=compute_client.virtual_machines, function_name='list')
        candidates = [
            vm for vm in list_virtual_machines(
                resource_group_name=resource_group) if _tags_match(vm)
        ]
    except azure.exceptions().ResourceNotFoundError as e:
        # A missing resource group simply means there are no instances.
        if _RESOURCE_GROUP_NOT_FOUND_ERROR_MESSAGE in str(e):
            return []
        raise
    if status_filters is not None:
        candidates = [
            vm for vm in candidates if _get_instance_status(
                compute_client, vm, resource_group) in status_filters
        ]
    if included_instances:
        allowed = set(included_instances)
        candidates = [vm for vm in candidates if vm.name in allowed]
    return candidates
769
+
770
+
771
def _delete_nic_with_retries(network_client,
                             resource_group,
                             nic_name,
                             max_retries=15,
                             retry_interval=20):
    """Delete a NIC with retries.

    When a VM is created, its NIC is reserved for 180 seconds, preventing its
    immediate deletion. If the NIC is in this reserved state, we must retry
    deletion with intervals until the reservation expires. This situation
    commonly arises if a VM termination is followed by a failover to another
    region due to provisioning failures.
    """
    begin_delete = _get_azure_sdk_function(
        client=network_client.network_interfaces, function_name='begin_delete')
    attempt = 0
    while attempt < max_retries:
        attempt += 1
        try:
            begin_delete(resource_group_name=resource_group,
                         network_interface_name=nic_name).result()
            return
        except azure.exceptions().HttpResponseError as e:
            if 'NicReservedForAnotherVm' not in str(e):
                raise
            # The NIC is still reserved by the (deleted) VM; wait and retry.
            logger.warning(f'NIC {nic_name} is reserved. '
                           f'Retrying in {retry_interval} seconds...')
            time.sleep(retry_interval)
    # Best-effort: log (do not raise) after exhausting all retries.
    logger.error(
        f'Failed to delete NIC {nic_name} after {max_retries} attempts.')
801
+
802
+
803
def _delete_role_assignments_and_identity(auth_client, msi_client,
                                          subscription_id: str,
                                          resource_group: str,
                                          msi_name: str) -> None:
    """Delete one user-assigned managed identity and its role assignments.

    We use the principal_id to find the correct guid converted role
    assignment name because each managed identity has a unique principal_id,
    and role assignments are associated with security principals (like
    managed identities) via this principal_id.
    """
    delete_role_assignment = _get_azure_sdk_function(
        client=auth_client.role_assignments, function_name='delete')
    delete_managed_identity = _get_azure_sdk_function(
        client=msi_client.user_assigned_identities, function_name='delete')
    user_assigned_identities = (
        msi_client.user_assigned_identities.list_by_resource_group(
            resource_group_name=resource_group))
    for identity in user_assigned_identities:
        if msi_name != identity.name:
            continue
        target_principal_id = identity.principal_id
        scope = (f'/subscriptions/{subscription_id}'
                 f'/resourceGroups/{resource_group}')
        role_assignments = auth_client.role_assignments.list_for_scope(scope)
        for assignment in role_assignments:
            if target_principal_id == assignment.principal_id:
                try:
                    delete_role_assignment(
                        scope=scope, role_assignment_name=assignment.name)
                except Exception as e:  # pylint: disable=broad-except
                    logger.warning('Failed to delete role '
                                   'assignment: {}'.format(e))
                # Only the first matching assignment is deleted.
                break
    try:
        delete_managed_identity(resource_group_name=resource_group,
                                resource_name=msi_name)
    except Exception as e:  # pylint: disable=broad-except
        logger.warning('Failed to delete msi: {}'.format(e))


def _delete_deployments(resource_client, resource_group: str,
                        cluster_name_on_cloud: str) -> None:
    """Delete the bootstrap and VM deployments created for the cluster."""
    delete_deployment = _get_azure_sdk_function(
        client=resource_client.deployments, function_name='begin_delete')
    deployment_names = [
        constants.EXTERNAL_RG_BOOTSTRAP_DEPLOYMENT_NAME.format(
            cluster_name_on_cloud=cluster_name_on_cloud),
        constants.EXTERNAL_RG_VM_DEPLOYMENT_NAME.format(
            cluster_name_on_cloud=cluster_name_on_cloud)
    ]
    for deployment_name in deployment_names:
        try:
            delete_deployment(resource_group_name=resource_group,
                              deployment_name=deployment_name)
        except Exception as e:  # pylint: disable=broad-except
            logger.warning('Failed to delete deployment: {}'.format(e))


def delete_vm_and_attached_resources(subscription_id: str, resource_group: str,
                                     cluster_name_on_cloud: str) -> None:
    """Removes VM with attached resources and Deployments.

    This function deletes a virtual machine and its associated resources
    (public IP addresses, virtual networks, managed identities, network
    interface and network security groups) that match cluster_name_on_cloud.
    There is one attached resources that is not removed within this
    method: OS disk. It is configured to be deleted when VM is terminated while
    setting up storage profile from _create_vm.

    Args:
        subscription_id: The Azure subscription ID.
        resource_group: The name of the resource group.
        cluster_name_on_cloud: The name of the cluster to filter resources.
    """
    resource_client = azure.get_client('resource', subscription_id)
    try:
        list_resources = _get_azure_sdk_function(
            client=resource_client.resources,
            function_name='list_by_resource_group')
        resources = list(list_resources(resource_group))
    except azure.exceptions().ResourceNotFoundError as e:
        if _RESOURCE_GROUP_NOT_FOUND_ERROR_MESSAGE in str(e):
            # Nothing to clean up if the resource group is already gone.
            return
        raise

    # Bucket the resources we manage by type; resources not containing the
    # cluster name belong to other workloads and are left untouched.
    filtered_resources: Dict[str, List[str]] = {
        _RESOURCE_VIRTUAL_MACHINE_TYPE: [],
        _RESOURCE_MANAGED_IDENTITY_TYPE: [],
        _RESOURCE_NETWORK_SECURITY_GROUP_TYPE: [],
        _RESOURCE_VIRTUAL_NETWORK_TYPE: [],
        _RESOURCE_PUBLIC_IP_ADDRESS_TYPE: [],
        _RESOURCE_NETWORK_INTERFACE_TYPE: []
    }
    for resource in resources:
        if (resource.type in filtered_resources and
                cluster_name_on_cloud in resource.name):
            filtered_resources[resource.type].append(resource.name)

    network_client = azure.get_client('network', subscription_id)
    msi_client = azure.get_client('msi', subscription_id)
    compute_client = azure.get_client('compute', subscription_id)
    auth_client = azure.get_client('authorization', subscription_id)

    delete_virtual_machine = _get_azure_sdk_function(
        client=compute_client.virtual_machines, function_name='delete')
    delete_public_ip_addresses = _get_azure_sdk_function(
        client=network_client.public_ip_addresses, function_name='begin_delete')
    delete_virtual_networks = _get_azure_sdk_function(
        client=network_client.virtual_networks, function_name='begin_delete')
    delete_network_security_group = _get_azure_sdk_function(
        client=network_client.network_security_groups,
        function_name='begin_delete')

    # Deletion order matters: the VM must be gone before its NIC can be
    # removed, and the NIC before its public IP.
    for vm_name in filtered_resources[_RESOURCE_VIRTUAL_MACHINE_TYPE]:
        try:
            # Before removing Network Interface, we need to wait for the VM to
            # be completely removed with .result() so the dependency of VM on
            # Network Interface is disassociated. This takes about ~30s.
            delete_virtual_machine(resource_group_name=resource_group,
                                   vm_name=vm_name).result()
        except Exception as e:  # pylint: disable=broad-except
            logger.warning('Failed to delete VM: {}'.format(e))

    for nic_name in filtered_resources[_RESOURCE_NETWORK_INTERFACE_TYPE]:
        try:
            # Before removing Public IP Address, we need to wait for the
            # Network Interface to be completely removed with .result() so the
            # dependency of Network Interface on Public IP Address is
            # disassociated. This takes about ~1s.
            _delete_nic_with_retries(network_client, resource_group, nic_name)
        except Exception as e:  # pylint: disable=broad-except
            logger.warning('Failed to delete nic: {}'.format(e))

    for public_ip_name in filtered_resources[_RESOURCE_PUBLIC_IP_ADDRESS_TYPE]:
        try:
            delete_public_ip_addresses(resource_group_name=resource_group,
                                       public_ip_address_name=public_ip_name)
        except Exception as e:  # pylint: disable=broad-except
            logger.warning('Failed to delete public ip: {}'.format(e))

    for vnet_name in filtered_resources[_RESOURCE_VIRTUAL_NETWORK_TYPE]:
        try:
            delete_virtual_networks(resource_group_name=resource_group,
                                    virtual_network_name=vnet_name)
        except Exception as e:  # pylint: disable=broad-except
            logger.warning('Failed to delete vnet: {}'.format(e))

    for msi_name in filtered_resources[_RESOURCE_MANAGED_IDENTITY_TYPE]:
        _delete_role_assignments_and_identity(auth_client, msi_client,
                                              subscription_id, resource_group,
                                              msi_name)

    for nsg_name in filtered_resources[_RESOURCE_NETWORK_SECURITY_GROUP_TYPE]:
        try:
            delete_network_security_group(resource_group_name=resource_group,
                                          network_security_group_name=nsg_name)
        except Exception as e:  # pylint: disable=broad-except
            logger.warning('Failed to delete nsg: {}'.format(e))

    _delete_deployments(resource_client, resource_group, cluster_name_on_cloud)
951
+
952
+
953
@common_utils.retry
def query_instances(
    cluster_name_on_cloud: str,
    provider_config: Optional[Dict[str, Any]] = None,
    non_terminated_only: bool = True,
) -> Dict[str, Optional[status_lib.ClusterStatus]]:
    """See sky/provision/__init__.py"""
    assert provider_config is not None, cluster_name_on_cloud

    subscription_id = provider_config['subscription_id']
    resource_group = provider_config['resource_group']
    compute_client = azure.get_client('compute', subscription_id)
    tag_filters = {constants.TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud}
    nodes = _filter_instances(compute_client, resource_group, tag_filters)

    statuses: Dict[str, Optional[status_lib.ClusterStatus]] = {}

    def _record_status(node) -> None:
        # Create a client inside the worker thread rather than sharing the
        # outer one across threads.
        client = azure.get_client('compute', subscription_id)
        status = _get_instance_status(client, node, resource_group)
        if status is None and non_terminated_only:
            # A missing instance counts as terminated; omit it unless the
            # caller asked for terminated instances too.
            return
        statuses[node.name] = (None if status is None else
                               status.to_cluster_status())

    with pool.ThreadPool() as p:
        p.map(_record_status, nodes)

    return statuses
983
+
984
+
38
985
  def open_ports(
39
986
  cluster_name_on_cloud: str,
40
987
  ports: List[str],
@@ -45,43 +992,71 @@ def open_ports(
45
992
  subscription_id = provider_config['subscription_id']
46
993
  resource_group = provider_config['resource_group']
47
994
  network_client = azure.get_client('network', subscription_id)
48
- # The NSG should have been created by the cluster provisioning.
49
- update_network_security_groups = get_azure_sdk_function(
995
+
996
+ update_network_security_groups = _get_azure_sdk_function(
50
997
  client=network_client.network_security_groups,
51
998
  function_name='create_or_update')
52
- list_network_security_groups = get_azure_sdk_function(
999
+ list_network_security_groups = _get_azure_sdk_function(
53
1000
  client=network_client.network_security_groups, function_name='list')
1001
+
54
1002
  for nsg in list_network_security_groups(resource_group):
55
- try:
56
- # Azure NSG rules have a priority field that determines the order
57
- # in which they are applied. The priority must be unique across
58
- # all inbound rules in one NSG.
59
- priority = max(rule.priority
60
- for rule in nsg.security_rules
61
- if rule.direction == 'Inbound') + 1
62
- nsg.security_rules.append(
63
- azure.create_security_rule(
64
- name=f'sky-ports-{cluster_name_on_cloud}-{priority}',
65
- priority=priority,
66
- protocol='Tcp',
67
- access='Allow',
68
- direction='Inbound',
69
- source_address_prefix='*',
70
- source_port_range='*',
71
- destination_address_prefix='*',
72
- destination_port_ranges=ports,
73
- ))
74
- poller = update_network_security_groups(resource_group, nsg.name,
75
- nsg)
76
- poller.wait()
77
- if poller.status() != 'Succeeded':
1003
+ # Given resource group can contain network security groups that are
1004
+ # irrelevant to this provisioning especially with user specified
1005
+ # resource group at ~/.sky/config. So we make sure to check for the
1006
+ # completion of nsg relevant to the VM being provisioned.
1007
+ if cluster_name_on_cloud in nsg.name:
1008
+ try:
1009
+ # Wait the NSG creation to be finished before opening a port.
1010
+ # The cluster provisioning triggers the NSG creation, but it
1011
+ # may not be finished yet.
1012
+ backoff = common_utils.Backoff(max_backoff_factor=1)
1013
+ start_time = time.time()
1014
+ while True:
1015
+ if nsg.provisioning_state not in ['Creating', 'Updating']:
1016
+ break
1017
+ if time.time(
1018
+ ) - start_time > _WAIT_CREATION_TIMEOUT_SECONDS:
1019
+ logger.warning(
1020
+ f'Fails to wait for the creation of NSG {nsg.name}'
1021
+ f' in {resource_group} within '
1022
+ f'{_WAIT_CREATION_TIMEOUT_SECONDS} seconds. '
1023
+ 'Skip this NSG.')
1024
+ backoff_time = backoff.current_backoff()
1025
+ logger.info(
1026
+ f'NSG {nsg.name} is not created yet. Waiting for '
1027
+ f'{backoff_time} seconds before checking again.')
1028
+ time.sleep(backoff_time)
1029
+
1030
+ # Azure NSG rules have a priority field that determines the
1031
+ # order in which they are applied. The priority must be unique
1032
+ # across all inbound rules in one NSG.
1033
+ priority = max(rule.priority
1034
+ for rule in nsg.security_rules
1035
+ if rule.direction == 'Inbound') + 1
1036
+ nsg.security_rules.append(
1037
+ azure.create_security_rule(
1038
+ name=f'sky-ports-{cluster_name_on_cloud}-{priority}',
1039
+ priority=priority,
1040
+ protocol='Tcp',
1041
+ access='Allow',
1042
+ direction='Inbound',
1043
+ source_address_prefix='*',
1044
+ source_port_range='*',
1045
+ destination_address_prefix='*',
1046
+ destination_port_ranges=ports,
1047
+ ))
1048
+ poller = update_network_security_groups(resource_group,
1049
+ nsg.name, nsg)
1050
+ poller.wait()
1051
+ if poller.status() != 'Succeeded':
1052
+ with ux_utils.print_exception_no_traceback():
1053
+ raise ValueError(f'Failed to open ports {ports} in NSG '
1054
+ f'{nsg.name}: {poller.status()}')
1055
+ except azure.exceptions().HttpResponseError as e:
78
1056
  with ux_utils.print_exception_no_traceback():
79
- raise ValueError(f'Failed to open ports {ports} in NSG '
80
- f'{nsg.name}: {poller.status()}')
81
- except azure.exceptions().HttpResponseError as e:
82
- with ux_utils.print_exception_no_traceback():
83
- raise ValueError(
84
- f'Failed to open ports {ports} in NSG {nsg.name}.') from e
1057
+ raise ValueError(
1058
+ f'Failed to open ports {ports} in NSG {nsg.name}.'
1059
+ ) from e
85
1060
 
86
1061
 
87
1062
  def cleanup_ports(