skypilot-nightly 1.0.0.dev2024053101__py3-none-any.whl → 1.0.0.dev2025022801__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (299) hide show
  1. sky/__init__.py +64 -32
  2. sky/adaptors/aws.py +23 -6
  3. sky/adaptors/azure.py +432 -15
  4. sky/adaptors/cloudflare.py +5 -5
  5. sky/adaptors/common.py +19 -9
  6. sky/adaptors/do.py +20 -0
  7. sky/adaptors/gcp.py +3 -2
  8. sky/adaptors/kubernetes.py +122 -88
  9. sky/adaptors/nebius.py +100 -0
  10. sky/adaptors/oci.py +39 -1
  11. sky/adaptors/vast.py +29 -0
  12. sky/admin_policy.py +101 -0
  13. sky/authentication.py +117 -98
  14. sky/backends/backend.py +52 -20
  15. sky/backends/backend_utils.py +669 -557
  16. sky/backends/cloud_vm_ray_backend.py +1099 -808
  17. sky/backends/local_docker_backend.py +14 -8
  18. sky/backends/wheel_utils.py +38 -20
  19. sky/benchmark/benchmark_utils.py +22 -23
  20. sky/check.py +76 -27
  21. sky/cli.py +1586 -1139
  22. sky/client/__init__.py +1 -0
  23. sky/client/cli.py +5683 -0
  24. sky/client/common.py +345 -0
  25. sky/client/sdk.py +1765 -0
  26. sky/cloud_stores.py +283 -19
  27. sky/clouds/__init__.py +7 -2
  28. sky/clouds/aws.py +303 -112
  29. sky/clouds/azure.py +185 -179
  30. sky/clouds/cloud.py +115 -37
  31. sky/clouds/cudo.py +29 -22
  32. sky/clouds/do.py +313 -0
  33. sky/clouds/fluidstack.py +44 -54
  34. sky/clouds/gcp.py +206 -65
  35. sky/clouds/ibm.py +26 -21
  36. sky/clouds/kubernetes.py +345 -91
  37. sky/clouds/lambda_cloud.py +40 -29
  38. sky/clouds/nebius.py +297 -0
  39. sky/clouds/oci.py +129 -90
  40. sky/clouds/paperspace.py +22 -18
  41. sky/clouds/runpod.py +53 -34
  42. sky/clouds/scp.py +28 -24
  43. sky/clouds/service_catalog/__init__.py +19 -13
  44. sky/clouds/service_catalog/aws_catalog.py +29 -12
  45. sky/clouds/service_catalog/azure_catalog.py +33 -6
  46. sky/clouds/service_catalog/common.py +95 -75
  47. sky/clouds/service_catalog/constants.py +3 -3
  48. sky/clouds/service_catalog/cudo_catalog.py +13 -3
  49. sky/clouds/service_catalog/data_fetchers/fetch_aws.py +36 -21
  50. sky/clouds/service_catalog/data_fetchers/fetch_azure.py +31 -4
  51. sky/clouds/service_catalog/data_fetchers/fetch_cudo.py +8 -117
  52. sky/clouds/service_catalog/data_fetchers/fetch_fluidstack.py +197 -44
  53. sky/clouds/service_catalog/data_fetchers/fetch_gcp.py +224 -36
  54. sky/clouds/service_catalog/data_fetchers/fetch_lambda_cloud.py +44 -24
  55. sky/clouds/service_catalog/data_fetchers/fetch_vast.py +147 -0
  56. sky/clouds/service_catalog/data_fetchers/fetch_vsphere.py +1 -1
  57. sky/clouds/service_catalog/do_catalog.py +111 -0
  58. sky/clouds/service_catalog/fluidstack_catalog.py +2 -2
  59. sky/clouds/service_catalog/gcp_catalog.py +16 -2
  60. sky/clouds/service_catalog/ibm_catalog.py +2 -2
  61. sky/clouds/service_catalog/kubernetes_catalog.py +192 -70
  62. sky/clouds/service_catalog/lambda_catalog.py +8 -3
  63. sky/clouds/service_catalog/nebius_catalog.py +116 -0
  64. sky/clouds/service_catalog/oci_catalog.py +31 -4
  65. sky/clouds/service_catalog/paperspace_catalog.py +2 -2
  66. sky/clouds/service_catalog/runpod_catalog.py +2 -2
  67. sky/clouds/service_catalog/scp_catalog.py +2 -2
  68. sky/clouds/service_catalog/vast_catalog.py +104 -0
  69. sky/clouds/service_catalog/vsphere_catalog.py +2 -2
  70. sky/clouds/utils/aws_utils.py +65 -0
  71. sky/clouds/utils/azure_utils.py +91 -0
  72. sky/clouds/utils/gcp_utils.py +5 -9
  73. sky/clouds/utils/oci_utils.py +47 -5
  74. sky/clouds/utils/scp_utils.py +4 -3
  75. sky/clouds/vast.py +280 -0
  76. sky/clouds/vsphere.py +22 -18
  77. sky/core.py +361 -107
  78. sky/dag.py +41 -28
  79. sky/data/data_transfer.py +37 -0
  80. sky/data/data_utils.py +211 -32
  81. sky/data/mounting_utils.py +182 -30
  82. sky/data/storage.py +2118 -270
  83. sky/data/storage_utils.py +126 -5
  84. sky/exceptions.py +179 -8
  85. sky/execution.py +158 -85
  86. sky/global_user_state.py +150 -34
  87. sky/jobs/__init__.py +12 -10
  88. sky/jobs/client/__init__.py +0 -0
  89. sky/jobs/client/sdk.py +302 -0
  90. sky/jobs/constants.py +49 -11
  91. sky/jobs/controller.py +161 -99
  92. sky/jobs/dashboard/dashboard.py +171 -25
  93. sky/jobs/dashboard/templates/index.html +572 -60
  94. sky/jobs/recovery_strategy.py +157 -156
  95. sky/jobs/scheduler.py +307 -0
  96. sky/jobs/server/__init__.py +1 -0
  97. sky/jobs/server/core.py +598 -0
  98. sky/jobs/server/dashboard_utils.py +69 -0
  99. sky/jobs/server/server.py +190 -0
  100. sky/jobs/state.py +627 -122
  101. sky/jobs/utils.py +615 -206
  102. sky/models.py +27 -0
  103. sky/optimizer.py +142 -83
  104. sky/provision/__init__.py +20 -5
  105. sky/provision/aws/config.py +124 -42
  106. sky/provision/aws/instance.py +130 -53
  107. sky/provision/azure/__init__.py +7 -0
  108. sky/{skylet/providers → provision}/azure/azure-config-template.json +19 -7
  109. sky/provision/azure/config.py +220 -0
  110. sky/provision/azure/instance.py +1012 -37
  111. sky/provision/common.py +31 -3
  112. sky/provision/constants.py +25 -0
  113. sky/provision/cudo/__init__.py +2 -1
  114. sky/provision/cudo/cudo_utils.py +112 -0
  115. sky/provision/cudo/cudo_wrapper.py +37 -16
  116. sky/provision/cudo/instance.py +28 -12
  117. sky/provision/do/__init__.py +11 -0
  118. sky/provision/do/config.py +14 -0
  119. sky/provision/do/constants.py +10 -0
  120. sky/provision/do/instance.py +287 -0
  121. sky/provision/do/utils.py +301 -0
  122. sky/provision/docker_utils.py +82 -46
  123. sky/provision/fluidstack/fluidstack_utils.py +57 -125
  124. sky/provision/fluidstack/instance.py +15 -43
  125. sky/provision/gcp/config.py +19 -9
  126. sky/provision/gcp/constants.py +7 -1
  127. sky/provision/gcp/instance.py +55 -34
  128. sky/provision/gcp/instance_utils.py +339 -80
  129. sky/provision/gcp/mig_utils.py +210 -0
  130. sky/provision/instance_setup.py +172 -133
  131. sky/provision/kubernetes/__init__.py +1 -0
  132. sky/provision/kubernetes/config.py +104 -90
  133. sky/provision/kubernetes/constants.py +8 -0
  134. sky/provision/kubernetes/instance.py +680 -325
  135. sky/provision/kubernetes/manifests/smarter-device-manager-daemonset.yaml +3 -0
  136. sky/provision/kubernetes/network.py +54 -20
  137. sky/provision/kubernetes/network_utils.py +70 -21
  138. sky/provision/kubernetes/utils.py +1370 -251
  139. sky/provision/lambda_cloud/__init__.py +11 -0
  140. sky/provision/lambda_cloud/config.py +10 -0
  141. sky/provision/lambda_cloud/instance.py +265 -0
  142. sky/{clouds/utils → provision/lambda_cloud}/lambda_utils.py +24 -23
  143. sky/provision/logging.py +1 -1
  144. sky/provision/nebius/__init__.py +11 -0
  145. sky/provision/nebius/config.py +11 -0
  146. sky/provision/nebius/instance.py +285 -0
  147. sky/provision/nebius/utils.py +318 -0
  148. sky/provision/oci/__init__.py +15 -0
  149. sky/provision/oci/config.py +51 -0
  150. sky/provision/oci/instance.py +436 -0
  151. sky/provision/oci/query_utils.py +681 -0
  152. sky/provision/paperspace/constants.py +6 -0
  153. sky/provision/paperspace/instance.py +4 -3
  154. sky/provision/paperspace/utils.py +2 -0
  155. sky/provision/provisioner.py +207 -130
  156. sky/provision/runpod/__init__.py +1 -0
  157. sky/provision/runpod/api/__init__.py +3 -0
  158. sky/provision/runpod/api/commands.py +119 -0
  159. sky/provision/runpod/api/pods.py +142 -0
  160. sky/provision/runpod/instance.py +64 -8
  161. sky/provision/runpod/utils.py +239 -23
  162. sky/provision/vast/__init__.py +10 -0
  163. sky/provision/vast/config.py +11 -0
  164. sky/provision/vast/instance.py +247 -0
  165. sky/provision/vast/utils.py +162 -0
  166. sky/provision/vsphere/common/vim_utils.py +1 -1
  167. sky/provision/vsphere/instance.py +8 -18
  168. sky/provision/vsphere/vsphere_utils.py +1 -1
  169. sky/resources.py +247 -102
  170. sky/serve/__init__.py +9 -9
  171. sky/serve/autoscalers.py +361 -299
  172. sky/serve/client/__init__.py +0 -0
  173. sky/serve/client/sdk.py +366 -0
  174. sky/serve/constants.py +12 -3
  175. sky/serve/controller.py +106 -36
  176. sky/serve/load_balancer.py +63 -12
  177. sky/serve/load_balancing_policies.py +84 -2
  178. sky/serve/replica_managers.py +42 -34
  179. sky/serve/serve_state.py +62 -32
  180. sky/serve/serve_utils.py +271 -160
  181. sky/serve/server/__init__.py +0 -0
  182. sky/serve/{core.py → server/core.py} +271 -90
  183. sky/serve/server/server.py +112 -0
  184. sky/serve/service.py +52 -16
  185. sky/serve/service_spec.py +95 -32
  186. sky/server/__init__.py +1 -0
  187. sky/server/common.py +430 -0
  188. sky/server/constants.py +21 -0
  189. sky/server/html/log.html +174 -0
  190. sky/server/requests/__init__.py +0 -0
  191. sky/server/requests/executor.py +472 -0
  192. sky/server/requests/payloads.py +487 -0
  193. sky/server/requests/queues/__init__.py +0 -0
  194. sky/server/requests/queues/mp_queue.py +76 -0
  195. sky/server/requests/requests.py +567 -0
  196. sky/server/requests/serializers/__init__.py +0 -0
  197. sky/server/requests/serializers/decoders.py +192 -0
  198. sky/server/requests/serializers/encoders.py +166 -0
  199. sky/server/server.py +1106 -0
  200. sky/server/stream_utils.py +141 -0
  201. sky/setup_files/MANIFEST.in +2 -5
  202. sky/setup_files/dependencies.py +159 -0
  203. sky/setup_files/setup.py +14 -125
  204. sky/sky_logging.py +59 -14
  205. sky/skylet/autostop_lib.py +2 -2
  206. sky/skylet/constants.py +183 -50
  207. sky/skylet/events.py +22 -10
  208. sky/skylet/job_lib.py +403 -258
  209. sky/skylet/log_lib.py +111 -71
  210. sky/skylet/log_lib.pyi +6 -0
  211. sky/skylet/providers/command_runner.py +6 -8
  212. sky/skylet/providers/ibm/node_provider.py +2 -2
  213. sky/skylet/providers/scp/config.py +11 -3
  214. sky/skylet/providers/scp/node_provider.py +8 -8
  215. sky/skylet/skylet.py +3 -1
  216. sky/skylet/subprocess_daemon.py +69 -17
  217. sky/skypilot_config.py +119 -57
  218. sky/task.py +205 -64
  219. sky/templates/aws-ray.yml.j2 +37 -7
  220. sky/templates/azure-ray.yml.j2 +27 -82
  221. sky/templates/cudo-ray.yml.j2 +7 -3
  222. sky/templates/do-ray.yml.j2 +98 -0
  223. sky/templates/fluidstack-ray.yml.j2 +7 -4
  224. sky/templates/gcp-ray.yml.j2 +26 -6
  225. sky/templates/ibm-ray.yml.j2 +3 -2
  226. sky/templates/jobs-controller.yaml.j2 +46 -11
  227. sky/templates/kubernetes-ingress.yml.j2 +7 -0
  228. sky/templates/kubernetes-loadbalancer.yml.j2 +7 -0
  229. sky/templates/{kubernetes-port-forward-proxy-command.sh.j2 → kubernetes-port-forward-proxy-command.sh} +51 -7
  230. sky/templates/kubernetes-ray.yml.j2 +292 -25
  231. sky/templates/lambda-ray.yml.j2 +30 -40
  232. sky/templates/nebius-ray.yml.j2 +79 -0
  233. sky/templates/oci-ray.yml.j2 +18 -57
  234. sky/templates/paperspace-ray.yml.j2 +10 -6
  235. sky/templates/runpod-ray.yml.j2 +26 -4
  236. sky/templates/scp-ray.yml.j2 +3 -2
  237. sky/templates/sky-serve-controller.yaml.j2 +12 -1
  238. sky/templates/skypilot-server-kubernetes-proxy.sh +36 -0
  239. sky/templates/vast-ray.yml.j2 +70 -0
  240. sky/templates/vsphere-ray.yml.j2 +8 -3
  241. sky/templates/websocket_proxy.py +64 -0
  242. sky/usage/constants.py +10 -1
  243. sky/usage/usage_lib.py +130 -37
  244. sky/utils/accelerator_registry.py +35 -51
  245. sky/utils/admin_policy_utils.py +147 -0
  246. sky/utils/annotations.py +51 -0
  247. sky/utils/cli_utils/status_utils.py +81 -23
  248. sky/utils/cluster_utils.py +356 -0
  249. sky/utils/command_runner.py +452 -89
  250. sky/utils/command_runner.pyi +77 -3
  251. sky/utils/common.py +54 -0
  252. sky/utils/common_utils.py +319 -108
  253. sky/utils/config_utils.py +204 -0
  254. sky/utils/control_master_utils.py +48 -0
  255. sky/utils/controller_utils.py +548 -266
  256. sky/utils/dag_utils.py +93 -32
  257. sky/utils/db_utils.py +18 -4
  258. sky/utils/env_options.py +29 -7
  259. sky/utils/kubernetes/create_cluster.sh +8 -60
  260. sky/utils/kubernetes/deploy_remote_cluster.sh +243 -0
  261. sky/utils/kubernetes/exec_kubeconfig_converter.py +73 -0
  262. sky/utils/kubernetes/generate_kubeconfig.sh +336 -0
  263. sky/utils/kubernetes/gpu_labeler.py +4 -4
  264. sky/utils/kubernetes/k8s_gpu_labeler_job.yaml +4 -3
  265. sky/utils/kubernetes/kubernetes_deploy_utils.py +228 -0
  266. sky/utils/kubernetes/rsync_helper.sh +24 -0
  267. sky/utils/kubernetes/ssh_jump_lifecycle_manager.py +1 -1
  268. sky/utils/log_utils.py +240 -33
  269. sky/utils/message_utils.py +81 -0
  270. sky/utils/registry.py +127 -0
  271. sky/utils/resources_utils.py +94 -22
  272. sky/utils/rich_utils.py +247 -18
  273. sky/utils/schemas.py +284 -64
  274. sky/{status_lib.py → utils/status_lib.py} +12 -7
  275. sky/utils/subprocess_utils.py +212 -46
  276. sky/utils/timeline.py +12 -7
  277. sky/utils/ux_utils.py +168 -15
  278. skypilot_nightly-1.0.0.dev2025022801.dist-info/METADATA +363 -0
  279. skypilot_nightly-1.0.0.dev2025022801.dist-info/RECORD +352 -0
  280. {skypilot_nightly-1.0.0.dev2024053101.dist-info → skypilot_nightly-1.0.0.dev2025022801.dist-info}/WHEEL +1 -1
  281. sky/clouds/cloud_registry.py +0 -31
  282. sky/jobs/core.py +0 -330
  283. sky/skylet/providers/azure/__init__.py +0 -2
  284. sky/skylet/providers/azure/azure-vm-template.json +0 -301
  285. sky/skylet/providers/azure/config.py +0 -170
  286. sky/skylet/providers/azure/node_provider.py +0 -466
  287. sky/skylet/providers/lambda_cloud/__init__.py +0 -2
  288. sky/skylet/providers/lambda_cloud/node_provider.py +0 -320
  289. sky/skylet/providers/oci/__init__.py +0 -2
  290. sky/skylet/providers/oci/node_provider.py +0 -488
  291. sky/skylet/providers/oci/query_helper.py +0 -383
  292. sky/skylet/providers/oci/utils.py +0 -21
  293. sky/utils/cluster_yaml_utils.py +0 -24
  294. sky/utils/kubernetes/generate_static_kubeconfig.sh +0 -137
  295. skypilot_nightly-1.0.0.dev2024053101.dist-info/METADATA +0 -315
  296. skypilot_nightly-1.0.0.dev2024053101.dist-info/RECORD +0 -275
  297. {skypilot_nightly-1.0.0.dev2024053101.dist-info → skypilot_nightly-1.0.0.dev2025022801.dist-info}/LICENSE +0 -0
  298. {skypilot_nightly-1.0.0.dev2024053101.dist-info → skypilot_nightly-1.0.0.dev2025022801.dist-info}/entry_points.txt +0 -0
  299. {skypilot_nightly-1.0.0.dev2024053101.dist-info → skypilot_nightly-1.0.0.dev2025022801.dist-info}/top_level.txt +0 -0
@@ -1,170 +0,0 @@
1
- import json
2
- import logging
3
- import random
4
- from hashlib import sha256
5
- from pathlib import Path
6
- import time
7
- from typing import Any, Callable
8
-
9
- from azure.common.credentials import get_cli_profile
10
- from azure.identity import AzureCliCredential
11
- from azure.mgmt.network import NetworkManagementClient
12
- from azure.mgmt.resource import ResourceManagementClient
13
- from azure.mgmt.resource.resources.models import DeploymentMode
14
-
15
- from sky.utils import common_utils
16
-
17
# Number of leading hex characters of the resource-group SHA-256 digest that
# are used as the default per-cluster unique id.
UNIQUE_ID_LEN = 4
# Upper bound, in seconds, on how long to poll for the network security group
# to reach the "Succeeded" provisioning state.
_WAIT_NSG_CREATION_NUM_TIMEOUT_SECONDS = 600

logger = logging.getLogger(__name__)
21
-
22
-
23
def get_azure_sdk_function(client: Any, function_name: str) -> Callable:
    """Retrieve a callable function from Azure SDK client object.

    Newer versions of the various client SDKs renamed function names to
    have a begin_ prefix. This function supports both the old and new
    versions of the SDK by first trying the old name and falling back to
    the prefixed new name.

    Args:
        client: The SDK operations object to look the function up on.
        function_name: The un-prefixed function name, e.g. "create_or_update".

    Returns:
        The attribute named ``function_name`` if present, otherwise the
        attribute named ``begin_<function_name>``.

    Raises:
        AttributeError: If neither attribute exists on ``client``.
    """
    func = getattr(
        client, function_name, getattr(client, f"begin_{function_name}", None)
    )
    if func is None:
        # Use the type name: SDK client *instances* have no ``__name__``, so
        # reading ``client.__name__`` here would raise an unrelated
        # AttributeError and hide the message below.
        raise AttributeError(
            "'{obj}' object has no {func} or begin_{func} attribute".format(
                obj=type(client).__name__, func=function_name
            )
        )
    return func
41
-
42
-
43
def bootstrap_azure(config):
    """Run the Azure bootstrap steps on ``config`` and return the result.

    The key-pair step runs first; its output is then passed to the
    resource-group step, whose return value is returned to the caller.
    """
    return _configure_resource_group(_configure_key_pair(config))
47
-
48
-
49
def _configure_resource_group(config):
    """Create/update the resource group and deploy the shared config template.

    Steps, in order:
      1. Resolve the subscription id (from the provider config, or from the
         Azure CLI profile when it is not set) and store it in the config.
      2. Create or update the resource group named in the provider config.
      3. Deploy ``azure-config-template.json`` (located next to this file)
         as the "ray-config" deployment, in incremental mode.
      4. Poll until the network security group reported by the deployment
         reaches the "Succeeded" provisioning state.
      5. Record the deployment outputs in the provider config.

    Args:
        config: Cluster config dict. Reads ``config["cluster_name"]`` and
            ``config["provider"]``; the provider dict is updated in place.

    Returns:
        The same ``config`` object, with ``subscription_id``, ``unique_id``,
        ``msi``, ``nsg`` and ``subnet`` set under ``config["provider"]``.

    Raises:
        AssertionError: If the provider config lacks ``resource_group`` or
            ``location``.
        RuntimeError: If the NSG does not reach "Succeeded" within
            ``_WAIT_NSG_CREATION_NUM_TIMEOUT_SECONDS`` seconds.
    """
    # TODO: look at availability sets
    # https://docs.microsoft.com/en-us/azure/virtual-machines/windows/tutorial-availability-sets
    provider = config["provider"]

    subscription_id = provider.get("subscription_id")
    if subscription_id is None:
        subscription_id = get_cli_profile().get_subscription_id()
    # Increase the timeout to fix the Azure get-access-token (used by ray azure
    # node_provider) timeout issue.
    # Tracked in https://github.com/Azure/azure-cli/issues/20404#issuecomment-1249575110
    credentials = AzureCliCredential(process_timeout=30)
    resource_client = ResourceManagementClient(credentials, subscription_id)
    provider["subscription_id"] = subscription_id
    logger.info("Using subscription id: %s", subscription_id)

    assert (
        "resource_group" in provider
    ), "Provider config must include resource_group field"
    resource_group = provider["resource_group"]

    assert (
        "location" in provider
    ), "Provider config must include location field"
    params = {"location": provider["location"]}

    if "tags" in provider:
        params["tags"] = provider["tags"]

    logger.info("Creating/Updating resource group: %s", resource_group)
    rg_create_or_update = get_azure_sdk_function(
        client=resource_client.resource_groups, function_name="create_or_update"
    )
    rg_create_or_update(resource_group_name=resource_group, parameters=params)

    # Load the deployment template that sits next to this module.
    template_path = Path(__file__).parent.joinpath("azure-config-template.json")
    with open(template_path, "r", encoding="utf-8") as template_fp:
        template = json.load(template_fp)

    logger.info("Using cluster name: %s", config["cluster_name"])

    # Set the unique id for resources in this cluster. When the user did not
    # provide one, derive it from the resource group name so it is stable.
    unique_id = provider.get("unique_id")
    if unique_id is None:
        hasher = sha256()
        hasher.update(resource_group.encode("utf-8"))
        unique_id = hasher.hexdigest()[:UNIQUE_ID_LEN]
    else:
        unique_id = str(unique_id)
    provider["unique_id"] = unique_id
    logger.info("Using unique id: %s", unique_id)
    cluster_id = "{}-{}".format(config["cluster_name"], unique_id)

    subnet_mask = provider.get("subnet_mask")
    if subnet_mask is None:
        # Choose a pseudo-random subnet, skipping the most common value of 0.
        # A private Random instance seeded with the unique id yields the same
        # value as seeding the module-level generator would, without
        # resetting the process-wide RNG state for every other user of it.
        subnet_rng = random.Random(unique_id)
        subnet_mask = "10.{}.0.0/16".format(subnet_rng.randint(1, 254))
    logger.info("Using subnet mask: %s", subnet_mask)

    parameters = {
        "properties": {
            "mode": DeploymentMode.incremental,
            "template": template,
            "parameters": {
                "subnet": {"value": subnet_mask},
                "clusterId": {"value": cluster_id},
            },
        }
    }

    create_or_update = get_azure_sdk_function(
        client=resource_client.deployments, function_name="create_or_update"
    )
    # TODO (skypilot): this takes a long time (> 40 seconds) for stopping an
    # azure VM, and this can be called twice during ray down.
    outputs = (
        create_or_update(
            resource_group_name=resource_group,
            deployment_name="ray-config",
            parameters=parameters,
        )
        .result()
        .properties.outputs
    )

    # We should wait for the NSG to be created before opening any ports
    # to avoid overriding the newly-added NSG rules.
    nsg_id = outputs["nsg"]["value"]
    nsg_name = nsg_id.split("/")[-1]
    network_client = NetworkManagementClient(credentials, subscription_id)
    backoff = common_utils.Backoff(max_backoff_factor=1)
    start_time = time.time()
    while True:
        nsg = network_client.network_security_groups.get(resource_group, nsg_name)
        if nsg.provisioning_state == "Succeeded":
            break
        if time.time() - start_time > _WAIT_NSG_CREATION_NUM_TIMEOUT_SECONDS:
            raise RuntimeError(
                f"Fails to create NSG {nsg_name} in {resource_group} within "
                f"{_WAIT_NSG_CREATION_NUM_TIMEOUT_SECONDS} seconds."
            )
        backoff_time = backoff.current_backoff()
        logger.info(
            "NSG %s is not created yet. Waiting for "
            "%s seconds before checking again.",
            nsg_name,
            backoff_time,
        )
        time.sleep(backoff_time)

    # Append output resource ids to be used with vm creation.
    provider["msi"] = outputs["msi"]["value"]
    provider["nsg"] = nsg_id
    provider["subnet"] = outputs["subnet"]["value"]

    return config
164
-
165
-
166
- def _configure_key_pair(config):
167
- # SkyPilot: The original checks and configurations are no longer
168
- # needed, since we have already set them up in the upper level
169
- # SkyPilot codes. See sky/templates/azure-ray.yml.j2
170
- return config
@@ -1,466 +0,0 @@
1
- import copy
2
- import json
3
- import logging
4
- from pathlib import Path
5
- from threading import RLock
6
- from uuid import uuid4
7
-
8
- from azure.identity import AzureCliCredential
9
- from azure.mgmt.compute import ComputeManagementClient
10
- from azure.mgmt.network import NetworkManagementClient
11
- from azure.mgmt.resource import ResourceManagementClient
12
- from azure.mgmt.resource.resources.models import DeploymentMode
13
-
14
- from sky.skylet.providers.azure.config import (
15
- bootstrap_azure,
16
- get_azure_sdk_function,
17
- )
18
- from sky.skylet import autostop_lib
19
- from sky.skylet.providers.command_runner import SkyDockerCommandRunner
20
- from sky.provision import docker_utils
21
-
22
- from ray.autoscaler._private.command_runner import SSHCommandRunner
23
- from ray.autoscaler.node_provider import NodeProvider
24
- from ray.autoscaler.tags import (
25
- TAG_RAY_CLUSTER_NAME,
26
- TAG_RAY_LAUNCH_CONFIG,
27
- TAG_RAY_NODE_KIND,
28
- TAG_RAY_NODE_NAME,
29
- TAG_RAY_USER_NODE_TYPE,
30
- )
31
-
32
# Generated VM names are truncated to this many characters.
VM_NAME_MAX_LEN = 64
# Number of hex characters taken from a fresh uuid4 for each VM name suffix.
UNIQUE_ID_LEN = 4

logger = logging.getLogger(__name__)
# Only let WARNING and above through from the Azure SDK HTTP logging policy.
azure_logger = logging.getLogger("azure.core.pipeline.policies.http_logging_policy")
azure_logger.setLevel(logging.WARNING)
38
-
39
-
40
def synchronized(f):
    """Decorate a method so that it runs while holding ``self.lock``.

    The lock is acquired before ``f`` is called and released afterwards,
    including when ``f`` raises. The wrapper keeps the wrapped function's
    name, docstring and other metadata.

    Args:
        f: A method whose first argument is an object exposing a ``lock``
            attribute that supports the context-manager protocol.

    Returns:
        The wrapping function.
    """
    # Function-scope import so the decorator does not depend on the file's
    # import block.
    import functools

    @functools.wraps(f)
    def wrapper(self, *args, **kwargs):
        with self.lock:
            return f(self, *args, **kwargs)

    return wrapper
49
-
50
-
51
- class AzureNodeProvider(NodeProvider):
52
- """Node Provider for Azure
53
-
54
- This provider assumes Azure credentials are set by running ``az login``
55
- and the default subscription is configured through ``az account``
56
- or set in the ``provider`` field of the autoscaler configuration.
57
-
58
- Nodes may be in one of three states: {pending, running, terminated}. Nodes
59
- appear immediately once started by ``create_node``, and transition
60
- immediately to terminated when ``terminate_node`` is called.
61
- """
62
-
63
def __init__(self, provider_config, cluster_name):
    """Set up Azure management clients and the node cache.

    Unless autostop is in progress, the resource group is created/updated
    first via ``_configure_resource_group``. ``provider_config`` must contain
    ``subscription_id`` by the time it is read below.

    Args:
        provider_config: The ``provider`` section of the cluster config.
        cluster_name: Name of the cluster this provider manages.
    """
    NodeProvider.__init__(self, provider_config, cluster_name)
    if not autostop_lib.get_is_autostopping():
        # TODO(suquark): This is a temporary patch for resource group.
        # By default, Ray autoscaler assumes the resource group still
        # exists even after the whole cluster is destroyed. However, we now
        # delete the resource group after tearing down the cluster. To
        # satisfy the autoscaler, we need to create/update it here, so the
        # resource group always exists.
        #
        # We should not re-configure the resource group when this is
        # running on the remote VM and autostopping is in progress,
        # because the VM is running, which guarantees the resource group
        # exists.
        from sky.skylet.providers.azure.config import _configure_resource_group

        _configure_resource_group(
            {"cluster_name": cluster_name, "provider": provider_config}
        )
    subscription_id = provider_config["subscription_id"]
    # Stopped-node reuse is on unless the provider config disables it.
    self.cache_stopped_nodes = provider_config.get("cache_stopped_nodes", True)
    # Sky only supports Azure CLI credential for now.
    # Increase the timeout to fix the Azure get-access-token (used by ray azure
    # node_provider) timeout issue.
    # Tracked in https://github.com/Azure/azure-cli/issues/20404#issuecomment-1249575110
    credential = AzureCliCredential(process_timeout=30)
    self.compute_client = ComputeManagementClient(credential, subscription_id)
    self.network_client = NetworkManagementClient(credential, subscription_id)
    self.resource_client = ResourceManagementClient(credential, subscription_id)

    # Lock used by methods decorated with ``synchronized``.
    self.lock = RLock()

    # cache node objects
    self.cached_nodes = {}
97
-
98
@synchronized
def _get_filtered_nodes(self, tag_filters):
    """List this cluster's VMs matching ``tag_filters`` and refresh the cache.

    Every VM in the provider's resource group is listed; a VM is kept when
    each filter tag, plus the cluster-name tag for this cluster, equals the
    VM's tag value. Metadata is extracted for each kept VM.

    Args:
        tag_filters: Dict of tag name to the required tag value.

    Returns:
        Dict mapping VM name to its metadata dict. The same dict replaces
        ``self.cached_nodes``.
    """
    # Add cluster name filter to only get nodes from this cluster.
    cluster_tag_filters = {**tag_filters, TAG_RAY_CLUSTER_NAME: self.cluster_name}

    def match_tags(vm):
        # A VM without any tags cannot match: the cluster-name tag is always
        # required. Treat missing tags as empty rather than failing on None.
        vm_tags = vm.tags or {}
        return all(vm_tags.get(k) == v for k, v in cluster_tag_filters.items())

    vms = self.compute_client.virtual_machines.list(
        resource_group_name=self.provider_config["resource_group"]
    )

    nodes = [self._extract_metadata(vm) for vm in vms if match_tags(vm)]
    self.cached_nodes = {node["name"]: node for node in nodes}
    return self.cached_nodes
116
-
117
- def _extract_metadata(self, vm):
118
- # get tags
119
- metadata = {"name": vm.name, "tags": vm.tags, "status": ""}
120
-
121
- # get status
122
- resource_group = self.provider_config["resource_group"]
123
- instance = self.compute_client.virtual_machines.instance_view(
124
- resource_group_name=resource_group, vm_name=vm.name
125
- ).as_dict()
126
- for status in instance["statuses"]:
127
- code_state = status["code"].split("/")
128
- # It is possible that sometimes the 'code' is empty string, and we
129
- # should skip them.
130
- if len(code_state) != 2:
131
- continue
132
- code, state = code_state
133
- # skip provisioning status
134
- if code == "PowerState":
135
- metadata["status"] = state
136
- break
137
-
138
- # get ip data
139
- nic_id = vm.network_profile.network_interfaces[0].id
140
- metadata["nic_name"] = nic_id.split("/")[-1]
141
- nic = self.network_client.network_interfaces.get(
142
- resource_group_name=resource_group,
143
- network_interface_name=metadata["nic_name"],
144
- )
145
- ip_config = nic.ip_configurations[0]
146
-
147
- if not self.provider_config.get("use_internal_ips", False):
148
- public_ip_id = ip_config.public_ip_address.id
149
- metadata["public_ip_name"] = public_ip_id.split("/")[-1]
150
- public_ip = self.network_client.public_ip_addresses.get(
151
- resource_group_name=resource_group,
152
- public_ip_address_name=metadata["public_ip_name"],
153
- )
154
- metadata["external_ip"] = public_ip.ip_address
155
-
156
- metadata["internal_ip"] = ip_config.private_ip_address
157
-
158
- return metadata
159
-
160
def stopped_nodes(self, tag_filters):
    """Return the ids of nodes matching ``tag_filters`` that are stopped.

    A node counts as stopped when its status starts with "deallocat".
    """
    stopped = []
    for node_id, meta in self._get_filtered_nodes(tag_filters=tag_filters).items():
        if meta["status"].startswith("deallocat"):
            stopped.append(node_id)
    return stopped
164
-
165
def non_terminated_nodes(self, tag_filters):
    """Return the ids of nodes matching ``tag_filters`` that are not stopped.

    Nodes whose status starts with "deallocat" are left out. Each call
    re-lists the nodes, which also refreshes the cached node data.

    Examples:
        >>> from ray.autoscaler.tags import TAG_RAY_NODE_KIND
        >>> provider = ... # doctest: +SKIP
        >>> provider.non_terminated_nodes( # doctest: +SKIP
        ...     {TAG_RAY_NODE_KIND: "worker"})
        ["node-1", "node-2"]
    """
    alive = []
    for node_id, meta in self._get_filtered_nodes(tag_filters=tag_filters).items():
        if meta["status"].startswith("deallocat"):
            continue
        alive.append(node_id)
    return alive
182
-
183
def is_running(self, node_id):
    """Return True when the node's status is exactly "running"."""
    # Goes through _get_node (not the cached lookup) for the status.
    status = self._get_node(node_id=node_id)["status"]
    return status == "running"
188
-
189
def is_terminated(self, node_id):
    """Return True when the node's status starts with "deallocat"."""
    # Goes through _get_node (not the cached lookup) for the status.
    status = self._get_node(node_id=node_id)["status"]
    return status.startswith("deallocat")
194
-
195
def node_tags(self, node_id):
    """Return the ``tags`` entry of the cached metadata for ``node_id``."""
    cached = self._get_cached_node(node_id=node_id)
    return cached["tags"]
198
-
199
def external_ip(self, node_id):
    """Return the external ip of the given node, or None if it has none.

    The cached metadata is consulted first; when it holds no (or an empty)
    external ip, the node is looked up again via ``_get_node``.

    Node metadata only carries an ``external_ip`` entry when public IPs are
    in use, so a missing entry is reported as None instead of raising
    KeyError.
    """
    cached_ip = self._get_cached_node(node_id=node_id).get("external_ip")
    if cached_ip:
        return cached_ip
    return self._get_node(node_id=node_id).get("external_ip")
206
-
207
def internal_ip(self, node_id):
    """Return the internal ip (Ray ip) of the given node.

    The cached value is used when it is truthy; otherwise the node is looked
    up again via ``_get_node``.
    """
    cached_ip = self._get_cached_node(node_id=node_id)["internal_ip"]
    if cached_ip:
        return cached_ip
    return self._get_node(node_id=node_id)["internal_ip"]
214
-
215
def create_node(self, node_config, tags, count):
    """Bring up ``count`` nodes, restarting stopped ones before creating new.

    When ``cache_stopped_nodes`` is enabled, stopped nodes whose cluster
    name, node kind and user node type tags match are restarted and
    re-tagged with ``tags``. Any remaining count is handed to
    ``_create_node``.
    """
    resource_group = self.provider_config["resource_group"]

    if self.cache_stopped_nodes:
        validity_tags = (
            TAG_RAY_CLUSTER_NAME,
            TAG_RAY_NODE_KIND,
            TAG_RAY_USER_NODE_TYPE,
        )
        filters = {tag: tags[tag] for tag in validity_tags if tag in tags}
        filters_with_launch_config = dict(filters)
        if TAG_RAY_LAUNCH_CONFIG in tags:
            filters_with_launch_config[TAG_RAY_LAUNCH_CONFIG] = tags[
                TAG_RAY_LAUNCH_CONFIG
            ]

        # SkyPilot: Prefer stopped instances whose launch_config matches. If
        # there are not enough of them, use all of those plus some instances
        # with a different launch_config.
        matching = sorted(
            self.stopped_nodes(filters_with_launch_config), reverse=True
        )
        if len(matching) >= count:
            reuse_nodes = matching[:count]
        else:
            # Sorting is for backward compatibility: a user may already have
            # leaked stopped nodes with a different launch config from before
            # #1671, and more of them than nodes to be created. Sorting by
            # str id makes the reuse order deterministic (launch time is not
            # available here; it would be more accurate).
            # This can be removed once all users have updated to #1671.
            non_matching = sorted(
                (n for n in self.stopped_nodes(filters) if n not in matching),
                reverse=True,
            )
            # There may be fewer reusable nodes than requested; the slice
            # then simply returns all of them.
            reuse_nodes = (matching + non_matching)[:count]

        logger.info(
            f"Reusing nodes {list(reuse_nodes)}. "
            "To disable reuse, set `cache_stopped_nodes: False` "
            "under `provider` in the cluster configuration.",
        )
        start = get_azure_sdk_function(
            client=self.compute_client.virtual_machines, function_name="start"
        )
        for node_id in reuse_nodes:
            # Wait for each start to finish before re-tagging the node.
            start(resource_group_name=resource_group, vm_name=node_id).wait()
            self.set_node_tags(node_id, tags)
        count -= len(reuse_nodes)

    if count:
        self._create_node(node_config, tags, count)
279
-
280
def _create_node(self, node_config, tags, count):
    """Deploy ``count`` new VMs for this cluster via an ARM template.

    The template ``azure-vm-template.json`` is read from the directory that
    contains this module, filled in with parameters derived from
    ``node_config`` and ``self.provider_config``, and submitted as an
    incremental deployment. The call blocks until the deployment finishes.

    Args:
        node_config: Node configuration; must contain
            ``azure_arm_parameters`` and may contain ``tags``.
        tags: Tags to apply on top of the ones in ``node_config``.
        count: Number of VMs requested (passed as ``vmCount``).
    """
    resource_group = self.provider_config["resource_group"]

    # The ARM template lives next to this source file.
    template_file = Path(__file__).parent / "azure-vm-template.json"
    with template_file.open("r") as template_fp:
        template = json.load(template_fp)

    # Tag precedence: node_config tags < caller-supplied tags < cluster name.
    vm_tags = node_config.get("tags", {}).copy()
    vm_tags.update(tags)
    vm_tags[TAG_RAY_CLUSTER_NAME] = self.cluster_name

    # The deployment/VM name gets a random suffix and is truncated to the
    # configured maximum length.
    node_label = vm_tags.get(TAG_RAY_NODE_NAME, "node")
    cluster_uid = self.provider_config["unique_id"]
    random_suffix = uuid4().hex[:UNIQUE_ID_LEN]
    vm_name = f"{node_label}-{cluster_uid}-{random_suffix}"[:VM_NAME_MAX_LEN]

    use_internal_ips = self.provider_config.get("use_internal_ips", False)

    template_params = node_config["azure_arm_parameters"].copy()
    template_params.update(
        {
            "vmName": vm_name,
            "provisionPublicIp": not use_internal_ips,
            "vmTags": vm_tags,
            "vmCount": count,
            "msi": self.provider_config["msi"],
            "nsg": self.provider_config["nsg"],
            "subnet": self.provider_config["subnet"],
        }
    )

    # Each template parameter is wrapped as {"value": ...} in the request.
    deployment_spec = {
        "properties": {
            "mode": DeploymentMode.incremental,
            "template": template,
            "parameters": {
                name: {"value": setting}
                for name, setting in template_params.items()
            },
        }
    }

    # TODO: we could get the private/public ips back directly
    deploy = get_azure_sdk_function(
        client=self.resource_client.deployments, function_name="create_or_update"
    )
    poller = deploy(
        resource_group_name=resource_group,
        deployment_name=vm_name,
        parameters=deployment_spec,
    )
    poller.wait()
330
-
331
@synchronized
def set_node_tags(self, node_id, tags):
    """Sets the tag values (string dict) for the specified node.

    The given ``tags`` are merged over the node's currently cached tags, the
    merged set is sent to Azure through the virtual machine ``update``
    operation, and the cached entry is then replaced with the merged set.

    Args:
        node_id: Name of the VM whose tags are updated.
        tags: Mapping of tag names to values to add or overwrite.
    """
    # Merge on a copy: if the update call below raises, the cached tags must
    # not already claim values that were never sent to Azure.
    node_tags = dict(self._get_cached_node(node_id)["tags"])
    node_tags.update(tags)

    update = get_azure_sdk_function(
        client=self.compute_client.virtual_machines, function_name="update"
    )
    # The result of the call is not waited on here.
    update(
        resource_group_name=self.provider_config["resource_group"],
        vm_name=node_id,
        parameters={"tags": node_tags},
    )
    self.cached_nodes[node_id]["tags"] = node_tags
345
-
346
def terminate_node(self, node_id):
    """Terminates the specified node.

    When ``self.cache_stopped_nodes`` is set, the VM is only deallocated and
    all of its resources are kept. Otherwise the VM is deleted together with
    its associated resources (NIC, public IP if any, OS and data disks).

    Every cloud call is best-effort: a failure is logged as a warning and the
    remaining cleanup steps still run. If the node is no longer known, the
    method returns without doing anything.

    Args:
        node_id: Name of the VM to stop or delete.
    """
    resource_group = self.provider_config["resource_group"]
    try:
        # get metadata for node
        metadata = self._get_node(node_id)
    except KeyError:
        # node no longer exists
        return

    if self.cache_stopped_nodes:
        try:
            # stop machine and leave all resources
            logger.info(
                f"Stopping instance {node_id} "
                "(to fully terminate instead, "
                "set `cache_stopped_nodes: False` "
                "under `provider` in the cluster configuration)"
            )
            stop = get_azure_sdk_function(
                client=self.compute_client.virtual_machines,
                function_name="deallocate",
            )
            stop(resource_group_name=resource_group, vm_name=node_id)
        except Exception as e:
            logger.warning(f"Failed to stop VM: {e}")
        return

    # Collect the disk names before the VM is deleted; they are read from the
    # VM's storage profile.
    vm = self.compute_client.virtual_machines.get(
        resource_group_name=resource_group, vm_name=node_id
    )
    disks = {d.name for d in vm.storage_profile.data_disks}
    disks.add(vm.storage_profile.os_disk.name)

    try:
        # delete machine, must wait for this to complete
        delete = get_azure_sdk_function(
            client=self.compute_client.virtual_machines, function_name="delete"
        )
        delete(resource_group_name=resource_group, vm_name=node_id).wait()
    except Exception as e:
        logger.warning(f"Failed to delete VM: {e}")

    try:
        # delete nic
        delete = get_azure_sdk_function(
            client=self.network_client.network_interfaces,
            function_name="delete",
        )
        delete(
            resource_group_name=resource_group,
            network_interface_name=metadata["nic_name"],
        )
    except Exception as e:
        logger.warning(f"Failed to delete nic: {e}")

    # delete ip address (only present when the node has a public IP recorded)
    if "public_ip_name" in metadata:
        try:
            delete = get_azure_sdk_function(
                client=self.network_client.public_ip_addresses,
                function_name="delete",
            )
            delete(
                resource_group_name=resource_group,
                public_ip_address_name=metadata["public_ip_name"],
            )
        except Exception as e:
            logger.warning(f"Failed to delete public ip: {e}")

    # delete disks; the lookup is loop-invariant, so resolve it once
    try:
        delete_disk = get_azure_sdk_function(
            client=self.compute_client.disks, function_name="delete"
        )
    except Exception as e:
        logger.warning(f"Failed to delete disk: {e}")
        return
    for disk in disks:
        try:
            delete_disk(resource_group_name=resource_group, disk_name=disk)
        except Exception as e:
            logger.warning(f"Failed to delete disk: {e}")
426
-
427
def _get_node(self, node_id):
    """Return the cached entry for ``node_id`` after refreshing the cache.

    Raises:
        KeyError: If ``node_id`` is not in the cache after the refresh.
    """
    # Called for its side effect only: listing with an empty filter updates
    # ``self.cached_nodes``.
    self._get_filtered_nodes({})
    refreshed_nodes = self.cached_nodes
    return refreshed_nodes[node_id]
430
-
431
def _get_cached_node(self, node_id):
    """Return the entry for ``node_id``, preferring the local cache.

    Falls back to ``self._get_node`` only when the node is not cached.
    """
    cache = self.cached_nodes
    return cache[node_id] if node_id in cache else self._get_node(node_id=node_id)
435
-
436
@staticmethod
def bootstrap_config(cluster_config):
    """Return ``cluster_config`` as processed by ``bootstrap_azure``."""
    bootstrapped = bootstrap_azure(cluster_config)
    return bootstrapped
439
-
440
def get_command_runner(
    self,
    log_prefix,
    node_id,
    auth_config,
    cluster_name,
    process_runner,
    use_internal_ip,
    docker_config=None,
):
    """Build the command runner used to execute commands on a node.

    Returns a ``SkyDockerCommandRunner`` when ``docker_config`` is given and
    names a non-empty container; otherwise returns an ``SSHCommandRunner``.
    Both receive the same set of keyword arguments, with this provider passed
    as ``provider``.

    Note: when ``self.provider_config`` has a ``docker_login_config`` entry,
    the caller's ``docker_config`` dict is modified in place to carry the
    corresponding ``DockerLoginConfig``.
    """
    runner_kwargs = dict(
        log_prefix=log_prefix,
        node_id=node_id,
        provider=self,
        auth_config=auth_config,
        cluster_name=cluster_name,
        process_runner=process_runner,
        use_internal_ip=use_internal_ip,
    )

    # No docker config, or an empty container name: plain SSH.
    if not docker_config or docker_config["container_name"] == "":
        return SSHCommandRunner(**runner_kwargs)

    if "docker_login_config" in self.provider_config:
        login_settings = self.provider_config["docker_login_config"]
        docker_config["docker_login_config"] = docker_utils.DockerLoginConfig(
            **login_settings
        )
    return SkyDockerCommandRunner(docker_config, **runner_kwargs)
@@ -1,2 +0,0 @@
1
- """Lambda Cloud node provider"""
2
- from sky.skylet.providers.lambda_cloud.node_provider import LambdaNodeProvider