skypilot-nightly 1.0.0.dev2024053101__py3-none-any.whl → 1.0.0.dev2025022801__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (299) hide show
  1. sky/__init__.py +64 -32
  2. sky/adaptors/aws.py +23 -6
  3. sky/adaptors/azure.py +432 -15
  4. sky/adaptors/cloudflare.py +5 -5
  5. sky/adaptors/common.py +19 -9
  6. sky/adaptors/do.py +20 -0
  7. sky/adaptors/gcp.py +3 -2
  8. sky/adaptors/kubernetes.py +122 -88
  9. sky/adaptors/nebius.py +100 -0
  10. sky/adaptors/oci.py +39 -1
  11. sky/adaptors/vast.py +29 -0
  12. sky/admin_policy.py +101 -0
  13. sky/authentication.py +117 -98
  14. sky/backends/backend.py +52 -20
  15. sky/backends/backend_utils.py +669 -557
  16. sky/backends/cloud_vm_ray_backend.py +1099 -808
  17. sky/backends/local_docker_backend.py +14 -8
  18. sky/backends/wheel_utils.py +38 -20
  19. sky/benchmark/benchmark_utils.py +22 -23
  20. sky/check.py +76 -27
  21. sky/cli.py +1586 -1139
  22. sky/client/__init__.py +1 -0
  23. sky/client/cli.py +5683 -0
  24. sky/client/common.py +345 -0
  25. sky/client/sdk.py +1765 -0
  26. sky/cloud_stores.py +283 -19
  27. sky/clouds/__init__.py +7 -2
  28. sky/clouds/aws.py +303 -112
  29. sky/clouds/azure.py +185 -179
  30. sky/clouds/cloud.py +115 -37
  31. sky/clouds/cudo.py +29 -22
  32. sky/clouds/do.py +313 -0
  33. sky/clouds/fluidstack.py +44 -54
  34. sky/clouds/gcp.py +206 -65
  35. sky/clouds/ibm.py +26 -21
  36. sky/clouds/kubernetes.py +345 -91
  37. sky/clouds/lambda_cloud.py +40 -29
  38. sky/clouds/nebius.py +297 -0
  39. sky/clouds/oci.py +129 -90
  40. sky/clouds/paperspace.py +22 -18
  41. sky/clouds/runpod.py +53 -34
  42. sky/clouds/scp.py +28 -24
  43. sky/clouds/service_catalog/__init__.py +19 -13
  44. sky/clouds/service_catalog/aws_catalog.py +29 -12
  45. sky/clouds/service_catalog/azure_catalog.py +33 -6
  46. sky/clouds/service_catalog/common.py +95 -75
  47. sky/clouds/service_catalog/constants.py +3 -3
  48. sky/clouds/service_catalog/cudo_catalog.py +13 -3
  49. sky/clouds/service_catalog/data_fetchers/fetch_aws.py +36 -21
  50. sky/clouds/service_catalog/data_fetchers/fetch_azure.py +31 -4
  51. sky/clouds/service_catalog/data_fetchers/fetch_cudo.py +8 -117
  52. sky/clouds/service_catalog/data_fetchers/fetch_fluidstack.py +197 -44
  53. sky/clouds/service_catalog/data_fetchers/fetch_gcp.py +224 -36
  54. sky/clouds/service_catalog/data_fetchers/fetch_lambda_cloud.py +44 -24
  55. sky/clouds/service_catalog/data_fetchers/fetch_vast.py +147 -0
  56. sky/clouds/service_catalog/data_fetchers/fetch_vsphere.py +1 -1
  57. sky/clouds/service_catalog/do_catalog.py +111 -0
  58. sky/clouds/service_catalog/fluidstack_catalog.py +2 -2
  59. sky/clouds/service_catalog/gcp_catalog.py +16 -2
  60. sky/clouds/service_catalog/ibm_catalog.py +2 -2
  61. sky/clouds/service_catalog/kubernetes_catalog.py +192 -70
  62. sky/clouds/service_catalog/lambda_catalog.py +8 -3
  63. sky/clouds/service_catalog/nebius_catalog.py +116 -0
  64. sky/clouds/service_catalog/oci_catalog.py +31 -4
  65. sky/clouds/service_catalog/paperspace_catalog.py +2 -2
  66. sky/clouds/service_catalog/runpod_catalog.py +2 -2
  67. sky/clouds/service_catalog/scp_catalog.py +2 -2
  68. sky/clouds/service_catalog/vast_catalog.py +104 -0
  69. sky/clouds/service_catalog/vsphere_catalog.py +2 -2
  70. sky/clouds/utils/aws_utils.py +65 -0
  71. sky/clouds/utils/azure_utils.py +91 -0
  72. sky/clouds/utils/gcp_utils.py +5 -9
  73. sky/clouds/utils/oci_utils.py +47 -5
  74. sky/clouds/utils/scp_utils.py +4 -3
  75. sky/clouds/vast.py +280 -0
  76. sky/clouds/vsphere.py +22 -18
  77. sky/core.py +361 -107
  78. sky/dag.py +41 -28
  79. sky/data/data_transfer.py +37 -0
  80. sky/data/data_utils.py +211 -32
  81. sky/data/mounting_utils.py +182 -30
  82. sky/data/storage.py +2118 -270
  83. sky/data/storage_utils.py +126 -5
  84. sky/exceptions.py +179 -8
  85. sky/execution.py +158 -85
  86. sky/global_user_state.py +150 -34
  87. sky/jobs/__init__.py +12 -10
  88. sky/jobs/client/__init__.py +0 -0
  89. sky/jobs/client/sdk.py +302 -0
  90. sky/jobs/constants.py +49 -11
  91. sky/jobs/controller.py +161 -99
  92. sky/jobs/dashboard/dashboard.py +171 -25
  93. sky/jobs/dashboard/templates/index.html +572 -60
  94. sky/jobs/recovery_strategy.py +157 -156
  95. sky/jobs/scheduler.py +307 -0
  96. sky/jobs/server/__init__.py +1 -0
  97. sky/jobs/server/core.py +598 -0
  98. sky/jobs/server/dashboard_utils.py +69 -0
  99. sky/jobs/server/server.py +190 -0
  100. sky/jobs/state.py +627 -122
  101. sky/jobs/utils.py +615 -206
  102. sky/models.py +27 -0
  103. sky/optimizer.py +142 -83
  104. sky/provision/__init__.py +20 -5
  105. sky/provision/aws/config.py +124 -42
  106. sky/provision/aws/instance.py +130 -53
  107. sky/provision/azure/__init__.py +7 -0
  108. sky/{skylet/providers → provision}/azure/azure-config-template.json +19 -7
  109. sky/provision/azure/config.py +220 -0
  110. sky/provision/azure/instance.py +1012 -37
  111. sky/provision/common.py +31 -3
  112. sky/provision/constants.py +25 -0
  113. sky/provision/cudo/__init__.py +2 -1
  114. sky/provision/cudo/cudo_utils.py +112 -0
  115. sky/provision/cudo/cudo_wrapper.py +37 -16
  116. sky/provision/cudo/instance.py +28 -12
  117. sky/provision/do/__init__.py +11 -0
  118. sky/provision/do/config.py +14 -0
  119. sky/provision/do/constants.py +10 -0
  120. sky/provision/do/instance.py +287 -0
  121. sky/provision/do/utils.py +301 -0
  122. sky/provision/docker_utils.py +82 -46
  123. sky/provision/fluidstack/fluidstack_utils.py +57 -125
  124. sky/provision/fluidstack/instance.py +15 -43
  125. sky/provision/gcp/config.py +19 -9
  126. sky/provision/gcp/constants.py +7 -1
  127. sky/provision/gcp/instance.py +55 -34
  128. sky/provision/gcp/instance_utils.py +339 -80
  129. sky/provision/gcp/mig_utils.py +210 -0
  130. sky/provision/instance_setup.py +172 -133
  131. sky/provision/kubernetes/__init__.py +1 -0
  132. sky/provision/kubernetes/config.py +104 -90
  133. sky/provision/kubernetes/constants.py +8 -0
  134. sky/provision/kubernetes/instance.py +680 -325
  135. sky/provision/kubernetes/manifests/smarter-device-manager-daemonset.yaml +3 -0
  136. sky/provision/kubernetes/network.py +54 -20
  137. sky/provision/kubernetes/network_utils.py +70 -21
  138. sky/provision/kubernetes/utils.py +1370 -251
  139. sky/provision/lambda_cloud/__init__.py +11 -0
  140. sky/provision/lambda_cloud/config.py +10 -0
  141. sky/provision/lambda_cloud/instance.py +265 -0
  142. sky/{clouds/utils → provision/lambda_cloud}/lambda_utils.py +24 -23
  143. sky/provision/logging.py +1 -1
  144. sky/provision/nebius/__init__.py +11 -0
  145. sky/provision/nebius/config.py +11 -0
  146. sky/provision/nebius/instance.py +285 -0
  147. sky/provision/nebius/utils.py +318 -0
  148. sky/provision/oci/__init__.py +15 -0
  149. sky/provision/oci/config.py +51 -0
  150. sky/provision/oci/instance.py +436 -0
  151. sky/provision/oci/query_utils.py +681 -0
  152. sky/provision/paperspace/constants.py +6 -0
  153. sky/provision/paperspace/instance.py +4 -3
  154. sky/provision/paperspace/utils.py +2 -0
  155. sky/provision/provisioner.py +207 -130
  156. sky/provision/runpod/__init__.py +1 -0
  157. sky/provision/runpod/api/__init__.py +3 -0
  158. sky/provision/runpod/api/commands.py +119 -0
  159. sky/provision/runpod/api/pods.py +142 -0
  160. sky/provision/runpod/instance.py +64 -8
  161. sky/provision/runpod/utils.py +239 -23
  162. sky/provision/vast/__init__.py +10 -0
  163. sky/provision/vast/config.py +11 -0
  164. sky/provision/vast/instance.py +247 -0
  165. sky/provision/vast/utils.py +162 -0
  166. sky/provision/vsphere/common/vim_utils.py +1 -1
  167. sky/provision/vsphere/instance.py +8 -18
  168. sky/provision/vsphere/vsphere_utils.py +1 -1
  169. sky/resources.py +247 -102
  170. sky/serve/__init__.py +9 -9
  171. sky/serve/autoscalers.py +361 -299
  172. sky/serve/client/__init__.py +0 -0
  173. sky/serve/client/sdk.py +366 -0
  174. sky/serve/constants.py +12 -3
  175. sky/serve/controller.py +106 -36
  176. sky/serve/load_balancer.py +63 -12
  177. sky/serve/load_balancing_policies.py +84 -2
  178. sky/serve/replica_managers.py +42 -34
  179. sky/serve/serve_state.py +62 -32
  180. sky/serve/serve_utils.py +271 -160
  181. sky/serve/server/__init__.py +0 -0
  182. sky/serve/{core.py → server/core.py} +271 -90
  183. sky/serve/server/server.py +112 -0
  184. sky/serve/service.py +52 -16
  185. sky/serve/service_spec.py +95 -32
  186. sky/server/__init__.py +1 -0
  187. sky/server/common.py +430 -0
  188. sky/server/constants.py +21 -0
  189. sky/server/html/log.html +174 -0
  190. sky/server/requests/__init__.py +0 -0
  191. sky/server/requests/executor.py +472 -0
  192. sky/server/requests/payloads.py +487 -0
  193. sky/server/requests/queues/__init__.py +0 -0
  194. sky/server/requests/queues/mp_queue.py +76 -0
  195. sky/server/requests/requests.py +567 -0
  196. sky/server/requests/serializers/__init__.py +0 -0
  197. sky/server/requests/serializers/decoders.py +192 -0
  198. sky/server/requests/serializers/encoders.py +166 -0
  199. sky/server/server.py +1106 -0
  200. sky/server/stream_utils.py +141 -0
  201. sky/setup_files/MANIFEST.in +2 -5
  202. sky/setup_files/dependencies.py +159 -0
  203. sky/setup_files/setup.py +14 -125
  204. sky/sky_logging.py +59 -14
  205. sky/skylet/autostop_lib.py +2 -2
  206. sky/skylet/constants.py +183 -50
  207. sky/skylet/events.py +22 -10
  208. sky/skylet/job_lib.py +403 -258
  209. sky/skylet/log_lib.py +111 -71
  210. sky/skylet/log_lib.pyi +6 -0
  211. sky/skylet/providers/command_runner.py +6 -8
  212. sky/skylet/providers/ibm/node_provider.py +2 -2
  213. sky/skylet/providers/scp/config.py +11 -3
  214. sky/skylet/providers/scp/node_provider.py +8 -8
  215. sky/skylet/skylet.py +3 -1
  216. sky/skylet/subprocess_daemon.py +69 -17
  217. sky/skypilot_config.py +119 -57
  218. sky/task.py +205 -64
  219. sky/templates/aws-ray.yml.j2 +37 -7
  220. sky/templates/azure-ray.yml.j2 +27 -82
  221. sky/templates/cudo-ray.yml.j2 +7 -3
  222. sky/templates/do-ray.yml.j2 +98 -0
  223. sky/templates/fluidstack-ray.yml.j2 +7 -4
  224. sky/templates/gcp-ray.yml.j2 +26 -6
  225. sky/templates/ibm-ray.yml.j2 +3 -2
  226. sky/templates/jobs-controller.yaml.j2 +46 -11
  227. sky/templates/kubernetes-ingress.yml.j2 +7 -0
  228. sky/templates/kubernetes-loadbalancer.yml.j2 +7 -0
  229. sky/templates/{kubernetes-port-forward-proxy-command.sh.j2 → kubernetes-port-forward-proxy-command.sh} +51 -7
  230. sky/templates/kubernetes-ray.yml.j2 +292 -25
  231. sky/templates/lambda-ray.yml.j2 +30 -40
  232. sky/templates/nebius-ray.yml.j2 +79 -0
  233. sky/templates/oci-ray.yml.j2 +18 -57
  234. sky/templates/paperspace-ray.yml.j2 +10 -6
  235. sky/templates/runpod-ray.yml.j2 +26 -4
  236. sky/templates/scp-ray.yml.j2 +3 -2
  237. sky/templates/sky-serve-controller.yaml.j2 +12 -1
  238. sky/templates/skypilot-server-kubernetes-proxy.sh +36 -0
  239. sky/templates/vast-ray.yml.j2 +70 -0
  240. sky/templates/vsphere-ray.yml.j2 +8 -3
  241. sky/templates/websocket_proxy.py +64 -0
  242. sky/usage/constants.py +10 -1
  243. sky/usage/usage_lib.py +130 -37
  244. sky/utils/accelerator_registry.py +35 -51
  245. sky/utils/admin_policy_utils.py +147 -0
  246. sky/utils/annotations.py +51 -0
  247. sky/utils/cli_utils/status_utils.py +81 -23
  248. sky/utils/cluster_utils.py +356 -0
  249. sky/utils/command_runner.py +452 -89
  250. sky/utils/command_runner.pyi +77 -3
  251. sky/utils/common.py +54 -0
  252. sky/utils/common_utils.py +319 -108
  253. sky/utils/config_utils.py +204 -0
  254. sky/utils/control_master_utils.py +48 -0
  255. sky/utils/controller_utils.py +548 -266
  256. sky/utils/dag_utils.py +93 -32
  257. sky/utils/db_utils.py +18 -4
  258. sky/utils/env_options.py +29 -7
  259. sky/utils/kubernetes/create_cluster.sh +8 -60
  260. sky/utils/kubernetes/deploy_remote_cluster.sh +243 -0
  261. sky/utils/kubernetes/exec_kubeconfig_converter.py +73 -0
  262. sky/utils/kubernetes/generate_kubeconfig.sh +336 -0
  263. sky/utils/kubernetes/gpu_labeler.py +4 -4
  264. sky/utils/kubernetes/k8s_gpu_labeler_job.yaml +4 -3
  265. sky/utils/kubernetes/kubernetes_deploy_utils.py +228 -0
  266. sky/utils/kubernetes/rsync_helper.sh +24 -0
  267. sky/utils/kubernetes/ssh_jump_lifecycle_manager.py +1 -1
  268. sky/utils/log_utils.py +240 -33
  269. sky/utils/message_utils.py +81 -0
  270. sky/utils/registry.py +127 -0
  271. sky/utils/resources_utils.py +94 -22
  272. sky/utils/rich_utils.py +247 -18
  273. sky/utils/schemas.py +284 -64
  274. sky/{status_lib.py → utils/status_lib.py} +12 -7
  275. sky/utils/subprocess_utils.py +212 -46
  276. sky/utils/timeline.py +12 -7
  277. sky/utils/ux_utils.py +168 -15
  278. skypilot_nightly-1.0.0.dev2025022801.dist-info/METADATA +363 -0
  279. skypilot_nightly-1.0.0.dev2025022801.dist-info/RECORD +352 -0
  280. {skypilot_nightly-1.0.0.dev2024053101.dist-info → skypilot_nightly-1.0.0.dev2025022801.dist-info}/WHEEL +1 -1
  281. sky/clouds/cloud_registry.py +0 -31
  282. sky/jobs/core.py +0 -330
  283. sky/skylet/providers/azure/__init__.py +0 -2
  284. sky/skylet/providers/azure/azure-vm-template.json +0 -301
  285. sky/skylet/providers/azure/config.py +0 -170
  286. sky/skylet/providers/azure/node_provider.py +0 -466
  287. sky/skylet/providers/lambda_cloud/__init__.py +0 -2
  288. sky/skylet/providers/lambda_cloud/node_provider.py +0 -320
  289. sky/skylet/providers/oci/__init__.py +0 -2
  290. sky/skylet/providers/oci/node_provider.py +0 -488
  291. sky/skylet/providers/oci/query_helper.py +0 -383
  292. sky/skylet/providers/oci/utils.py +0 -21
  293. sky/utils/cluster_yaml_utils.py +0 -24
  294. sky/utils/kubernetes/generate_static_kubeconfig.sh +0 -137
  295. skypilot_nightly-1.0.0.dev2024053101.dist-info/METADATA +0 -315
  296. skypilot_nightly-1.0.0.dev2024053101.dist-info/RECORD +0 -275
  297. {skypilot_nightly-1.0.0.dev2024053101.dist-info → skypilot_nightly-1.0.0.dev2025022801.dist-info}/LICENSE +0 -0
  298. {skypilot_nightly-1.0.0.dev2024053101.dist-info → skypilot_nightly-1.0.0.dev2025022801.dist-info}/entry_points.txt +0 -0
  299. {skypilot_nightly-1.0.0.dev2024053101.dist-info → skypilot_nightly-1.0.0.dev2025022801.dist-info}/top_level.txt +0 -0
@@ -5,7 +5,7 @@ import hashlib
5
5
  import os
6
6
  import time
7
7
  import typing
8
- from typing import Dict, List, NamedTuple, Optional, Tuple
8
+ from typing import Callable, Dict, List, NamedTuple, Optional, Tuple, Union
9
9
 
10
10
  import filelock
11
11
  import requests
@@ -13,8 +13,9 @@ import requests
13
13
  from sky import sky_logging
14
14
  from sky.adaptors import common as adaptors_common
15
15
  from sky.clouds import cloud as cloud_lib
16
- from sky.clouds import cloud_registry
17
16
  from sky.clouds.service_catalog import constants
17
+ from sky.utils import common_utils
18
+ from sky.utils import registry
18
19
  from sky.utils import rich_utils
19
20
  from sky.utils import ux_utils
20
21
 
@@ -58,7 +59,9 @@ class InstanceTypeInfo(NamedTuple):
58
59
 
59
60
 
60
61
  def get_catalog_path(filename: str) -> str:
61
- return os.path.join(_ABSOLUTE_VERSIONED_CATALOG_DIR, filename)
62
+ catalog_path = os.path.join(_ABSOLUTE_VERSIONED_CATALOG_DIR, filename)
63
+ os.makedirs(os.path.dirname(catalog_path), exist_ok=True)
64
+ return catalog_path
62
65
 
63
66
 
64
67
  def is_catalog_modified(filename: str) -> bool:
@@ -67,8 +70,7 @@ def is_catalog_modified(filename: str) -> bool:
67
70
  meta_path = os.path.join(_ABSOLUTE_VERSIONED_CATALOG_DIR, '.meta', filename)
68
71
  md5_filepath = meta_path + '.md5'
69
72
  if os.path.exists(md5_filepath):
70
- with open(catalog_path, 'rb') as f:
71
- file_md5 = hashlib.md5(f.read()).hexdigest()
73
+ file_md5 = common_utils.hash_file(catalog_path, 'md5').hexdigest()
72
74
  with open(md5_filepath, 'r', encoding='utf-8') as f:
73
75
  last_md5 = f.read()
74
76
  return file_md5 != last_md5
@@ -118,19 +120,21 @@ def get_modified_catalog_file_mounts() -> Dict[str, str]:
118
120
 
119
121
 
120
122
  class LazyDataFrame:
121
- """A lazy data frame that reads the catalog on demand.
123
+ """A lazy data frame that updates and reads the catalog on demand.
122
124
 
123
125
  We don't need to load the catalog for every SkyPilot call, and this class
124
126
  allows us to load the catalog only when needed.
125
127
  """
126
128
 
127
- def __init__(self, filename: str):
129
+ def __init__(self, filename: str, update_func: Callable[[], None]):
128
130
  self._filename = filename
129
131
  self._df: Optional['pd.DataFrame'] = None
132
+ self._update_func = update_func
130
133
 
131
134
  def _load_df(self) -> 'pd.DataFrame':
132
135
  if self._df is None:
133
136
  try:
137
+ self._update_func()
134
138
  self._df = pd.read_csv(self._filename)
135
139
  except Exception as e: # pylint: disable=broad-except
136
140
  # As users can manually modify the catalog, read_csv can fail.
@@ -167,63 +171,69 @@ def read_catalog(filename: str,
167
171
  assert (pull_frequency_hours is None or
168
172
  pull_frequency_hours >= 0), pull_frequency_hours
169
173
  catalog_path = get_catalog_path(filename)
170
- cloud = cloud_registry.CLOUD_REGISTRY.from_str(os.path.dirname(filename))
174
+ cloud = os.path.dirname(filename)
175
+ if cloud != 'common':
176
+ cloud = str(registry.CLOUD_REGISTRY.from_str(cloud))
171
177
 
172
178
  meta_path = os.path.join(_ABSOLUTE_VERSIONED_CATALOG_DIR, '.meta', filename)
173
179
  os.makedirs(os.path.dirname(meta_path), exist_ok=True)
174
180
 
175
- # Atomic check, to avoid conflicts with other processes.
176
- # TODO(mraheja): remove pylint disabling when filelock version updated
177
- # pylint: disable=abstract-class-instantiated
178
- with filelock.FileLock(meta_path + '.lock'):
179
-
180
- def _need_update() -> bool:
181
- if not os.path.exists(catalog_path):
182
- return True
183
- if pull_frequency_hours is None:
184
- return False
185
- if is_catalog_modified(filename):
186
- # If the catalog is modified by a user manually, we should
187
- # avoid overwriting the catalog by fetching from GitHub.
188
- return False
189
-
190
- last_update = os.path.getmtime(catalog_path)
191
- return last_update + pull_frequency_hours * 3600 < time.time()
192
-
193
- if _need_update():
194
- url = f'{constants.HOSTED_CATALOG_DIR_URL}/{constants.CATALOG_SCHEMA_VERSION}/{filename}' # pylint: disable=line-too-long
195
- update_frequency_str = ''
196
- if pull_frequency_hours is not None:
197
- update_frequency_str = f' (every {pull_frequency_hours} hours)'
198
- with rich_utils.safe_status((f'Updating {cloud} catalog: '
199
- f'{filename}'
200
- f'{update_frequency_str}')):
201
- try:
202
- r = requests.get(url)
203
- r.raise_for_status()
204
- except requests.exceptions.RequestException as e:
205
- error_str = (f'Failed to fetch {cloud} catalog '
206
- f'{filename}. ')
207
- if os.path.exists(catalog_path):
208
- logger.warning(
209
- f'{error_str}Using cached catalog files.')
210
- # Update catalog file modification time.
211
- os.utime(catalog_path, None) # Sets to current time
181
+ def _need_update() -> bool:
182
+ if not os.path.exists(catalog_path):
183
+ return True
184
+ if pull_frequency_hours is None:
185
+ return False
186
+ if is_catalog_modified(filename):
187
+ # If the catalog is modified by a user manually, we should
188
+ # avoid overwriting the catalog by fetching from GitHub.
189
+ return False
190
+
191
+ last_update = os.path.getmtime(catalog_path)
192
+ return last_update + pull_frequency_hours * 3600 < time.time()
193
+
194
+ def _update_catalog():
195
+ # Atomic check, to avoid conflicts with other processes.
196
+ with filelock.FileLock(meta_path + '.lock'):
197
+ if _need_update():
198
+ url = f'{constants.HOSTED_CATALOG_DIR_URL}/{constants.CATALOG_SCHEMA_VERSION}/{filename}' # pylint: disable=line-too-long
199
+ update_frequency_str = ''
200
+ if pull_frequency_hours is not None:
201
+ update_frequency_str = (
202
+ f' (every {pull_frequency_hours} hours)')
203
+ with rich_utils.safe_status(
204
+ ux_utils.spinner_message(
205
+ f'Updating {cloud} catalog: {filename}') +
206
+ f'{update_frequency_str}'):
207
+ try:
208
+ r = requests.get(url=url,
209
+ headers={'User-Agent': 'SkyPilot/0.7'})
210
+ r.raise_for_status()
211
+ except requests.exceptions.RequestException as e:
212
+ error_str = (f'Failed to fetch {cloud} catalog '
213
+ f'{filename}. ')
214
+ if os.path.exists(catalog_path):
215
+ logger.warning(
216
+ f'{error_str}Using cached catalog files.')
217
+ # Update catalog file modification time.
218
+ os.utime(catalog_path, None) # Sets to current time
219
+ else:
220
+ logger.error(
221
+ f'{error_str}Please check your internet '
222
+ 'connection.')
223
+ with ux_utils.print_exception_no_traceback():
224
+ raise e
212
225
  else:
213
- logger.error(
214
- f'{error_str}Please check your internet connection.'
215
- )
216
- with ux_utils.print_exception_no_traceback():
217
- raise e
218
- else:
219
- # Download successful, save the catalog to a local file.
220
- os.makedirs(os.path.dirname(catalog_path), exist_ok=True)
221
- with open(catalog_path, 'w', encoding='utf-8') as f:
222
- f.write(r.text)
223
- with open(meta_path + '.md5', 'w', encoding='utf-8') as f:
224
- f.write(hashlib.md5(r.text.encode()).hexdigest())
225
-
226
- return LazyDataFrame(catalog_path)
226
+ # Download successful, save the catalog to a local file.
227
+ os.makedirs(os.path.dirname(catalog_path),
228
+ exist_ok=True)
229
+ with open(catalog_path, 'w', encoding='utf-8') as f:
230
+ f.write(r.text)
231
+ with open(meta_path + '.md5', 'w',
232
+ encoding='utf-8') as f:
233
+ f.write(hashlib.md5(r.text.encode()).hexdigest())
234
+ logger.debug(f'Updated {cloud} catalog {filename}.')
235
+
236
+ return LazyDataFrame(catalog_path, update_func=_update_catalog)
227
237
 
228
238
 
229
239
  def _get_instance_type(
@@ -262,9 +272,10 @@ def validate_region_zone_impl(
262
272
  candidate_loc = difflib.get_close_matches(loc, all_loc, n=5, cutoff=0.9)
263
273
  candidate_loc = sorted(candidate_loc)
264
274
  candidate_strs = ''
265
- if len(candidate_loc) > 0:
275
+ if candidate_loc:
266
276
  candidate_strs = ', '.join(candidate_loc)
267
277
  candidate_strs = f'\nDid you mean one of these: {candidate_strs!r}?'
278
+
268
279
  return candidate_strs
269
280
 
270
281
  def _get_all_supported_regions_str() -> str:
@@ -278,7 +289,7 @@ def validate_region_zone_impl(
278
289
  filter_df = df
279
290
  if region is not None:
280
291
  filter_df = _filter_region_zone(filter_df, region, zone=None)
281
- if len(filter_df) == 0:
292
+ if filter_df.empty:
282
293
  with ux_utils.print_exception_no_traceback():
283
294
  error_msg = (f'Invalid region {region!r}')
284
295
  candidate_strs = _get_candidate_str(
@@ -288,7 +299,7 @@ def validate_region_zone_impl(
288
299
  faq_msg = (
289
300
  '\nIf a region is not included in the following '
290
301
  'list, please check the FAQ docs for how to fetch '
291
- 'its catalog info.\nhttps://skypilot.readthedocs.io'
302
+ 'its catalog info.\nhttps://docs.skypilot.co'
292
303
  '/en/latest/reference/faq.html#advanced-how-to-'
293
304
  'make-skypilot-use-all-global-regions')
294
305
  error_msg += faq_msg + _get_all_supported_regions_str()
@@ -302,7 +313,7 @@ def validate_region_zone_impl(
302
313
  if zone is not None:
303
314
  maybe_region_df = filter_df
304
315
  filter_df = filter_df[filter_df['AvailabilityZone'] == zone]
305
- if len(filter_df) == 0:
316
+ if filter_df.empty:
306
317
  region_str = f' for region {region!r}' if region else ''
307
318
  df = maybe_region_df if region else df
308
319
  with ux_utils.print_exception_no_traceback():
@@ -370,7 +381,7 @@ def get_vcpus_mem_from_instance_type_impl(
370
381
  instance_type: str,
371
382
  ) -> Tuple[Optional[float], Optional[float]]:
372
383
  df = _get_instance_type(df, instance_type, None)
373
- if len(df) == 0:
384
+ if df.empty:
374
385
  with ux_utils.print_exception_no_traceback():
375
386
  raise ValueError(f'No instance type {instance_type} found.')
376
387
  assert len(set(df['vCPUs'])) == 1, ('Cannot determine the number of vCPUs '
@@ -474,22 +485,28 @@ def get_instance_type_for_cpus_mem_impl(
474
485
  def get_accelerators_from_instance_type_impl(
475
486
  df: 'pd.DataFrame',
476
487
  instance_type: str,
477
- ) -> Optional[Dict[str, int]]:
488
+ ) -> Optional[Dict[str, Union[int, float]]]:
478
489
  df = _get_instance_type(df, instance_type, None)
479
- if len(df) == 0:
490
+ if df.empty:
480
491
  with ux_utils.print_exception_no_traceback():
481
492
  raise ValueError(f'No instance type {instance_type} found.')
482
493
  row = df.iloc[0]
483
494
  acc_name, acc_count = row['AcceleratorName'], row['AcceleratorCount']
484
495
  if pd.isnull(acc_name):
485
496
  return None
486
- return {acc_name: int(acc_count)}
497
+
498
+ def _convert(value):
499
+ if int(value) == value:
500
+ return int(value)
501
+ return float(value)
502
+
503
+ return {acc_name: _convert(acc_count)}
487
504
 
488
505
 
489
506
  def get_instance_type_for_accelerator_impl(
490
507
  df: 'pd.DataFrame',
491
508
  acc_name: str,
492
- acc_count: int,
509
+ acc_count: Union[int, float],
493
510
  cpus: Optional[str] = None,
494
511
  memory: Optional[str] = None,
495
512
  use_spot: bool = False,
@@ -502,9 +519,9 @@ def get_instance_type_for_accelerator_impl(
502
519
  accelerators with sorted prices and a list of candidates with fuzzy search.
503
520
  """
504
521
  result = df[(df['AcceleratorName'].str.fullmatch(acc_name, case=False)) &
505
- (df['AcceleratorCount'] == acc_count)]
522
+ (abs(df['AcceleratorCount'] - acc_count) <= 0.01)]
506
523
  result = _filter_region_zone(result, region, zone)
507
- if len(result) == 0:
524
+ if result.empty:
508
525
  fuzzy_result = df[
509
526
  (df['AcceleratorName'].str.contains(acc_name, case=False)) &
510
527
  (df['AcceleratorCount'] >= acc_count)]
@@ -513,16 +530,19 @@ def get_instance_type_for_accelerator_impl(
513
530
  fuzzy_result = fuzzy_result[['AcceleratorName',
514
531
  'AcceleratorCount']].drop_duplicates()
515
532
  fuzzy_candidate_list = []
516
- if len(fuzzy_result) > 0:
533
+ if not fuzzy_result.empty:
517
534
  for _, row in fuzzy_result.iterrows():
535
+ acc_cnt = float(row['AcceleratorCount'])
536
+ acc_count_display = (int(acc_cnt) if acc_cnt.is_integer() else
537
+ f'{acc_cnt:.2f}')
518
538
  fuzzy_candidate_list.append(f'{row["AcceleratorName"]}:'
519
- f'{int(row["AcceleratorCount"])}')
539
+ f'{acc_count_display}')
520
540
  return (None, fuzzy_candidate_list)
521
541
 
522
542
  result = _filter_with_cpus(result, cpus)
523
543
  result = _filter_with_mem(result, memory)
524
544
  result = _filter_region_zone(result, region, zone)
525
- if len(result) == 0:
545
+ if result.empty:
526
546
  return ([], [])
527
547
 
528
548
  # Current strategy: choose the cheapest instance
@@ -663,7 +683,7 @@ def get_image_id_from_tag_impl(df: 'pd.DataFrame', tag: str,
663
683
  df = _filter_region_zone(df, region, zone=None)
664
684
  assert len(df) <= 1, ('Multiple images found for tag '
665
685
  f'{tag} in region {region}')
666
- if len(df) == 0:
686
+ if df.empty:
667
687
  return None
668
688
  image_id = df['ImageId'].iloc[0]
669
689
  if pd.isna(image_id):
@@ -677,4 +697,4 @@ def is_image_tag_valid_impl(df: 'pd.DataFrame', tag: str,
677
697
  df = df[df['Tag'] == tag]
678
698
  df = _filter_region_zone(df, region, zone=None)
679
699
  df = df.dropna(subset=['ImageId'])
680
- return len(df) > 0
700
+ return not df.empty
@@ -1,7 +1,7 @@
1
1
  """Constants used for service catalog."""
2
2
  HOSTED_CATALOG_DIR_URL = 'https://raw.githubusercontent.com/skypilot-org/skypilot-catalog/master/catalogs' # pylint: disable=line-too-long
3
- CATALOG_SCHEMA_VERSION = 'v5'
3
+ CATALOG_SCHEMA_VERSION = 'v6'
4
4
  CATALOG_DIR = '~/.sky/catalogs'
5
5
  ALL_CLOUDS = ('aws', 'azure', 'gcp', 'ibm', 'lambda', 'scp', 'oci',
6
- 'kubernetes', 'runpod', 'vsphere', 'cudo', 'fluidstack',
7
- 'paperspace')
6
+ 'kubernetes', 'runpod', 'vast', 'vsphere', 'cudo', 'fluidstack',
7
+ 'paperspace', 'do', 'nebius')
@@ -1,7 +1,7 @@
1
1
  """Cudo Compute Offerings Catalog."""
2
2
 
3
3
  import typing
4
- from typing import Dict, List, Optional, Tuple
4
+ from typing import Dict, List, Optional, Tuple, Union
5
5
 
6
6
  from sky.clouds.service_catalog import common
7
7
  import sky.provision.cudo.cudo_machine_type as cudo_mt
@@ -14,6 +14,9 @@ _PULL_FREQUENCY_HOURS = 1
14
14
  _df = common.read_catalog(cudo_mt.VMS_CSV,
15
15
  pull_frequency_hours=_PULL_FREQUENCY_HOURS)
16
16
 
17
+ _DEFAULT_NUM_VCPUS = 8
18
+ _DEFAULT_MEMORY_CPU_RATIO = 2
19
+
17
20
 
18
21
  def instance_type_exists(instance_type: str) -> bool:
19
22
  return common.instance_type_exists_impl(_df, instance_type)
@@ -52,11 +55,18 @@ def get_default_instance_type(cpus: Optional[str] = None,
52
55
  del disk_tier
53
56
  # NOTE: After expanding catalog to multiple entries, you may
54
57
  # want to specify a default instance type or family.
55
- return common.get_instance_type_for_cpus_mem_impl(_df, cpus, memory)
58
+ if cpus is None and memory is None:
59
+ cpus = f'{_DEFAULT_NUM_VCPUS}+'
60
+
61
+ memory_gb_or_ratio = memory
62
+ if memory is None:
63
+ memory_gb_or_ratio = f'{_DEFAULT_MEMORY_CPU_RATIO}x'
64
+ return common.get_instance_type_for_cpus_mem_impl(_df, cpus,
65
+ memory_gb_or_ratio)
56
66
 
57
67
 
58
68
  def get_accelerators_from_instance_type(
59
- instance_type: str) -> Optional[Dict[str, int]]:
69
+ instance_type: str) -> Optional[Dict[str, Union[int, float]]]:
60
70
  return common.get_accelerators_from_instance_type_impl(_df, instance_type)
61
71
 
62
72
 
@@ -306,6 +306,12 @@ def _get_instance_types_df(region: str) -> Union[str, 'pd.DataFrame']:
306
306
  assert find_num_in_name is not None, row['InstanceType']
307
307
  num_in_name = find_num_in_name.group(1)
308
308
  acc_count = int(num_in_name) // 2
309
+ if row['InstanceType'] == 'p5en.48xlarge':
310
+ # TODO(andyl): Check if this workaround still needed after
311
+ # v0.10.0 released. Currently, the acc_name returned by the
312
+ # AWS API is 'NVIDIA', which is incorrect. See #4652.
313
+ acc_name = 'H200'
314
+ acc_count = 8
309
315
  return pd.Series({
310
316
  'AcceleratorName': acc_name,
311
317
  'AcceleratorCount': acc_count,
@@ -379,26 +385,33 @@ def get_all_regions_instance_types_df(regions: Set[str]) -> 'pd.DataFrame':
379
385
  #
380
386
  # Deep Learning AMI GPU PyTorch 1.10.0 (Ubuntu 18.04) 20211208
381
387
  # Nvidia driver: 470.57.02, CUDA Version: 11.4
382
- _GPU_UBUNTU_DATE_PYTORCH = [
383
- ('gpu', '20.04', '20231103', '2.1.0'),
384
- ('gpu', '18.04', '20221114', '1.10.0'),
385
- ('k80', '20.04', '20211208', '1.10.0'),
386
- ('k80', '18.04', '20211208', '1.10.0'),
388
+ #
389
+ # Neuron (Inferentia / Trainium):
390
+ # https://aws.amazon.com/releasenotes/aws-deep-learning-ami-base-neuron-ubuntu-20-04/ # pylint: disable=line-too-long
391
+ # Deep Learning Base Neuron AMI (Ubuntu 20.04) 20240923
392
+ # TODO(tian): find out the driver version.
393
+ # Neuron driver:
394
+ _GPU_DESC_UBUNTU_DATE = [
395
+ ('gpu', 'AMI GPU PyTorch 2.1.0', '20.04', '20231103'),
396
+ ('gpu', 'AMI GPU PyTorch 1.10.0', '18.04', '20221114'),
397
+ ('k80', 'AMI GPU PyTorch 1.10.0', '20.04', '20211208'),
398
+ ('k80', 'AMI GPU PyTorch 1.10.0', '18.04', '20211208'),
399
+ ('neuron', 'Base Neuron AMI', '22.04', '20240923'),
387
400
  ]
388
401
 
389
402
 
390
- def _fetch_image_id(region: str, ubuntu_version: str, creation_date: str,
391
- pytorch_version: str) -> Optional[str]:
403
+ def _fetch_image_id(region: str, description: str, ubuntu_version: str,
404
+ creation_date: str) -> Optional[str]:
392
405
  try:
393
406
  image = subprocess.check_output(f"""\
394
407
  aws ec2 describe-images --region {region} --owners amazon \\
395
- --filters 'Name=name,Values="Deep Learning AMI GPU PyTorch {pytorch_version} (Ubuntu {ubuntu_version}) {creation_date}"' \\
408
+ --filters 'Name=name,Values="Deep Learning {description} (Ubuntu {ubuntu_version}) {creation_date}"' \\
396
409
  'Name=state,Values=available' --query 'Images[:1].ImageId' --output text
397
410
  """,
398
411
  shell=True)
399
412
  except subprocess.CalledProcessError as e:
400
- print(f'Failed {region}, {ubuntu_version}, {creation_date}. '
401
- 'Trying next date.')
413
+ print(f'Failed {region}, {description}, {ubuntu_version}, '
414
+ f'{creation_date}. Trying next date.')
402
415
  print(f'{type(e)}: {e}')
403
416
  image_id = None
404
417
  else:
@@ -407,21 +420,21 @@ def _fetch_image_id(region: str, ubuntu_version: str, creation_date: str,
407
420
  return image_id
408
421
 
409
422
 
410
- def _get_image_row(
411
- region: str, gpu: str, ubuntu_version: str, date: str,
412
- pytorch_version) -> Tuple[str, str, str, str, Optional[str], str]:
413
- print(f'Getting image for {region}, {ubuntu_version}, {gpu}')
414
- image_id = _fetch_image_id(region, ubuntu_version, date, pytorch_version)
423
+ def _get_image_row(region: str, gpu: str, description: str, ubuntu_version: str,
424
+ date: str) -> Tuple[str, str, str, str, Optional[str], str]:
425
+ print(f'Getting image for {region}, {description}, {ubuntu_version}, {gpu}')
426
+ image_id = _fetch_image_id(region, description, ubuntu_version, date)
415
427
  if image_id is None:
416
428
  # not found
417
- print(f'Failed to find image for {region}, {ubuntu_version}, {gpu}')
429
+ print(f'Failed to find image for {region}, {description}, '
430
+ f'{ubuntu_version}, {gpu}')
418
431
  tag = f'skypilot:{gpu}-ubuntu-{ubuntu_version.replace(".", "")}'
419
432
  return tag, region, 'ubuntu', ubuntu_version, image_id, date
420
433
 
421
434
 
422
435
  def get_all_regions_images_df(regions: Set[str]) -> 'pd.DataFrame':
423
436
  image_metas = [
424
- (r, *i) for r, i in itertools.product(regions, _GPU_UBUNTU_DATE_PYTORCH)
437
+ (r, *i) for r, i in itertools.product(regions, _GPU_DESC_UBUNTU_DATE)
425
438
  ]
426
439
  with mp_pool.Pool() as pool:
427
440
  results = pool.starmap(_get_image_row, image_metas)
@@ -531,11 +544,13 @@ if __name__ == '__main__':
531
544
  instance_df.to_csv('aws/vms.csv', index=False)
532
545
  print('AWS Service Catalog saved to aws/vms.csv')
533
546
 
534
- image_df = get_all_regions_images_df(user_regions)
535
- _check_regions_integrity(image_df, 'images')
547
+ # Disable refreshing images.csv as we are using skypilot custom AMIs
548
+ # See sky/clouds/service_catalog/images/README.md for more details.
549
+ # image_df = get_all_regions_images_df(user_regions)
550
+ # _check_regions_integrity(image_df, 'images')
536
551
 
537
- image_df.to_csv('aws/images.csv', index=False)
538
- print('AWS Images saved to aws/images.csv')
552
+ # image_df.to_csv('aws/images.csv', index=False)
553
+ # print('AWS Images saved to aws/images.csv')
539
554
 
540
555
  if args.az_mappings:
541
556
  az_mappings_df = fetch_availability_zone_mappings()
@@ -64,7 +64,7 @@ FAMILY_NAME_TO_SKYPILOT_GPU_NAME = {
64
64
  'standardNVSv2Family': 'M60',
65
65
  'standardNVSv3Family': 'M60',
66
66
  'standardNVPromoFamily': 'M60',
67
- 'standardNVSv4Family': 'Radeon MI25',
67
+ 'standardNVSv4Family': 'MI25',
68
68
  'standardNDSFamily': 'P40',
69
69
  'StandardNVADSA10v5Family': 'A10',
70
70
  'StandardNCadsH100v5Family': 'H100',
@@ -93,6 +93,16 @@ def get_regions() -> List[str]:
93
93
  # We have to manually remove it.
94
94
  DEPRECATED_FAMILIES = ['standardNVSv2Family']
95
95
 
96
+ # Azure has those fractional A10 instance types, which still shows has 1 A10 GPU
97
+ # in the API response. We manually changing the number of GPUs to a float here.
98
+ # Ref: https://learn.microsoft.com/en-us/azure/virtual-machines/nva10v5-series
99
+ # TODO(zhwu,tian): Support fractional GPUs on k8s as well.
100
+ # TODO(tian): Maybe we should support literally fractional count, i.e. A10:1/6
101
+ # instead of float point count (A10:0.167).
102
+ AZURE_FRACTIONAL_A10_INS_TYPE_TO_NUM_GPUS = {
103
+ f'Standard_NV{vcpu}ads_A10_v5': round(vcpu / 36, 3) for vcpu in [6, 12, 18]
104
+ }
105
+
96
106
  USEFUL_COLUMNS = [
97
107
  'InstanceType', 'AcceleratorName', 'AcceleratorCount', 'vCPUs', 'MemoryGiB',
98
108
  'GpuInfo', 'Price', 'SpotPrice', 'Region', 'Generation'
@@ -124,15 +134,19 @@ def get_pricing_df(region: Optional[str] = None) -> 'pd.DataFrame':
124
134
  content_str = r.content.decode('ascii')
125
135
  content = json.loads(content_str)
126
136
  items = content.get('Items', [])
127
- if len(items) == 0:
137
+ if not items:
128
138
  break
129
139
  all_items += items
130
140
  url = content.get('NextPageLink')
131
141
  print(f'Done fetching pricing {region}')
132
142
  df = pd.DataFrame(all_items)
133
143
  assert 'productName' in df.columns, (region, df.columns)
134
- return df[(~df['productName'].str.contains(' Windows')) &
135
- (df['unitPrice'] > 0)]
144
+ # Filter out the cloud services and windows products.
145
+ # Some H100 series use ' Win' instead of ' Windows', e.g.
146
+ # Virtual Machines NCCadsv5 Srs Win
147
+ return df[
148
+ (~df['productName'].str.contains(' Win| Cloud Services| CloudServices'))
149
+ & (df['unitPrice'] > 0)]
136
150
 
137
151
 
138
152
  def get_sku_df(region_set: Set[str]) -> 'pd.DataFrame':
@@ -261,6 +275,19 @@ def get_all_regions_instance_types_df(region_set: Set[str]):
261
275
  axis='columns',
262
276
  )
263
277
 
278
+ def _upd_a10_gpu_count(row):
279
+ new_gpu_cnt = AZURE_FRACTIONAL_A10_INS_TYPE_TO_NUM_GPUS.get(
280
+ row['InstanceType'])
281
+ if new_gpu_cnt is not None:
282
+ return new_gpu_cnt
283
+ return row['AcceleratorCount']
284
+
285
+ # Manually update the GPU count for fractional A10 instance types.
286
+ # Those instance types have fractional GPU count, but Azure API returns
287
+ # 1 GPU count for them. We manually update the GPU count here.
288
+ df_ret['AcceleratorCount'] = df_ret.apply(_upd_a10_gpu_count,
289
+ axis='columns')
290
+
264
291
  # As of Dec 2023, a few H100 instance types fetched from Azure APIs do not
265
292
  # have pricing:
266
293
  #