skypilot-nightly 1.0.0.dev20241023__py3-none-any.whl → 1.0.0.dev20241025__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31) hide show
  1. sky/__init__.py +4 -2
  2. sky/adaptors/azure.py +11 -0
  3. sky/check.py +11 -4
  4. sky/cli.py +24 -16
  5. sky/clouds/azure.py +86 -50
  6. sky/clouds/cloud.py +4 -0
  7. sky/clouds/cloud_registry.py +55 -10
  8. sky/clouds/kubernetes.py +1 -1
  9. sky/clouds/oci.py +1 -1
  10. sky/clouds/service_catalog/azure_catalog.py +15 -0
  11. sky/clouds/service_catalog/kubernetes_catalog.py +7 -1
  12. sky/clouds/utils/azure_utils.py +91 -0
  13. sky/exceptions.py +4 -4
  14. sky/jobs/recovery_strategy.py +3 -3
  15. sky/provision/azure/azure-config-template.json +7 -1
  16. sky/provision/azure/config.py +24 -8
  17. sky/provision/azure/instance.py +251 -137
  18. sky/provision/kubernetes/instance.py +4 -2
  19. sky/provision/provisioner.py +16 -8
  20. sky/resources.py +1 -0
  21. sky/templates/azure-ray.yml.j2 +2 -0
  22. sky/usage/usage_lib.py +3 -2
  23. sky/utils/common_utils.py +3 -2
  24. sky/utils/controller_utils.py +69 -18
  25. {skypilot_nightly-1.0.0.dev20241023.dist-info → skypilot_nightly-1.0.0.dev20241025.dist-info}/METADATA +1 -1
  26. {skypilot_nightly-1.0.0.dev20241023.dist-info → skypilot_nightly-1.0.0.dev20241025.dist-info}/RECORD +30 -30
  27. sky/provision/azure/azure-vm-template.json +0 -301
  28. {skypilot_nightly-1.0.0.dev20241023.dist-info → skypilot_nightly-1.0.0.dev20241025.dist-info}/LICENSE +0 -0
  29. {skypilot_nightly-1.0.0.dev20241023.dist-info → skypilot_nightly-1.0.0.dev20241025.dist-info}/WHEEL +0 -0
  30. {skypilot_nightly-1.0.0.dev20241023.dist-info → skypilot_nightly-1.0.0.dev20241025.dist-info}/entry_points.txt +0 -0
  31. {skypilot_nightly-1.0.0.dev20241023.dist-info → skypilot_nightly-1.0.0.dev20241025.dist-info}/top_level.txt +0 -0
sky/__init__.py CHANGED
@@ -5,7 +5,7 @@ from typing import Optional
5
5
  import urllib.request
6
6
 
7
7
  # Replaced with the current commit when building the wheels.
8
- _SKYPILOT_COMMIT_SHA = 'f2991b144d4b15eac55dd7f759f361b6146033b3'
8
+ _SKYPILOT_COMMIT_SHA = '057bc4b44755ac1e9dadc680e022c369e8ddff52'
9
9
 
10
10
 
11
11
  def _get_git_commit():
@@ -35,7 +35,7 @@ def _get_git_commit():
35
35
 
36
36
 
37
37
  __commit__ = _get_git_commit()
38
- __version__ = '1.0.0.dev20241023'
38
+ __version__ = '1.0.0.dev20241025'
39
39
  __root_dir__ = os.path.dirname(os.path.abspath(__file__))
40
40
 
41
41
 
@@ -128,6 +128,7 @@ GCP = clouds.GCP
128
128
  Lambda = clouds.Lambda
129
129
  SCP = clouds.SCP
130
130
  Kubernetes = clouds.Kubernetes
131
+ K8s = Kubernetes
131
132
  OCI = clouds.OCI
132
133
  Paperspace = clouds.Paperspace
133
134
  RunPod = clouds.RunPod
@@ -143,6 +144,7 @@ __all__ = [
143
144
  'GCP',
144
145
  'IBM',
145
146
  'Kubernetes',
147
+ 'K8s',
146
148
  'Lambda',
147
149
  'OCI',
148
150
  'Paperspace',
sky/adaptors/azure.py CHANGED
@@ -69,6 +69,17 @@ def exceptions():
69
69
  return azure_exceptions
70
70
 
71
71
 
72
@functools.lru_cache()
@common.load_lazy_modules(modules=_LAZY_MODULES)
def azure_mgmt_models(name: str):
    """Return the Azure management SDK ``models`` module for a service.

    The SDK modules are imported inside the function body rather than at
    module import time.

    Args:
        name: Which service's models to return: 'compute' or 'network'.

    Returns:
        ``azure.mgmt.compute.models`` for 'compute', or
        ``azure.mgmt.network.models`` for 'network'.

    Raises:
        ValueError: If ``name`` is not one of the supported services.
    """
    if name == 'compute':
        from azure.mgmt.compute import models
        return models
    if name == 'network':
        from azure.mgmt.network import models
        return models
    # Fail loudly instead of implicitly returning None, which would only
    # surface later as an AttributeError at the call site.
    raise ValueError(f'Unsupported Azure management models module: {name!r}. '
                     'Expected \'compute\' or \'network\'.')
81
+
82
+
72
83
  # We should keep the order of the decorators having 'lru_cache' followed
73
84
  # by 'load_lazy_modules' as we need to make sure a caller can call
74
85
  # 'get_client.cache_clear', which is a function provided by 'lru_cache'
sky/check.py CHANGED
@@ -1,4 +1,5 @@
1
1
  """Credential checks: check cloud credentials and enable clouds."""
2
+ import os
2
3
  import traceback
3
4
  from types import ModuleType
4
5
  from typing import Dict, Iterable, List, Optional, Tuple, Union
@@ -194,19 +195,25 @@ def get_cached_enabled_clouds_or_refresh(
194
195
  def get_cloud_credential_file_mounts(
195
196
  excluded_clouds: Optional[Iterable[sky_clouds.Cloud]]
196
197
  ) -> Dict[str, str]:
197
- """Returns the files necessary to access all enabled clouds.
198
+ """Returns the files necessary to access all clouds.
198
199
 
199
200
  Returns a dictionary that will be added to a task's file mounts
200
201
  and a list of patterns that will be excluded (used as rsync_exclude).
201
202
  """
202
- enabled_clouds = get_cached_enabled_clouds_or_refresh()
203
+ # Uploading credentials for all clouds instead of only sky check
204
+ # enabled clouds because users may have partial credentials for some
205
+ # clouds to access their specific resources (e.g. cloud storage) but
206
+ # not have the complete credentials to pass sky check.
207
+ clouds = sky_clouds.CLOUD_REGISTRY.values()
203
208
  file_mounts = {}
204
- for cloud in enabled_clouds:
209
+ for cloud in clouds:
205
210
  if (excluded_clouds is not None and
206
211
  sky_clouds.cloud_in_iterable(cloud, excluded_clouds)):
207
212
  continue
208
213
  cloud_file_mounts = cloud.get_credential_file_mounts()
209
- file_mounts.update(cloud_file_mounts)
214
+ for remote_path, local_path in cloud_file_mounts.items():
215
+ if os.path.exists(os.path.expanduser(local_path)):
216
+ file_mounts[remote_path] = local_path
210
217
  # Currently, get_cached_enabled_clouds_or_refresh() does not support r2 as
211
218
  # only clouds with computing instances are marked as enabled by skypilot.
212
219
  # This will be removed when cloudflare/r2 is added as a 'cloud'.
sky/cli.py CHANGED
@@ -339,7 +339,6 @@ def _get_shell_complete_args(complete_fn):
339
339
 
340
340
 
341
341
  _RELOAD_ZSH_CMD = 'source ~/.zshrc'
342
- _RELOAD_FISH_CMD = 'source ~/.config/fish/config.fish'
343
342
  _RELOAD_BASH_CMD = 'source ~/.bashrc'
344
343
 
345
344
 
@@ -378,7 +377,9 @@ def _install_shell_completion(ctx: click.Context, param: click.Parameter,
378
377
  cmd = '_SKY_COMPLETE=fish_source sky > \
379
378
  ~/.config/fish/completions/sky.fish'
380
379
 
381
- reload_cmd = _RELOAD_FISH_CMD
380
+ # Fish does not need to be reloaded and will automatically pick up
381
+ # completions.
382
+ reload_cmd = None
382
383
 
383
384
  elif value == 'zsh':
384
385
  install_cmd = f'_SKY_COMPLETE=zsh_source sky > \
@@ -398,9 +399,10 @@ def _install_shell_completion(ctx: click.Context, param: click.Parameter,
398
399
  check=True,
399
400
  executable=shutil.which('bash'))
400
401
  click.secho(f'Shell completion installed for {value}', fg='green')
401
- click.echo(
402
- 'Completion will take effect once you restart the terminal: ' +
403
- click.style(f'{reload_cmd}', bold=True))
402
+ if reload_cmd is not None:
403
+ click.echo(
404
+ 'Completion will take effect once you restart the terminal: ' +
405
+ click.style(f'{reload_cmd}', bold=True))
404
406
  except subprocess.CalledProcessError as e:
405
407
  click.secho(f'> Installation failed with code {e.returncode}', fg='red')
406
408
  ctx.exit()
@@ -431,7 +433,9 @@ def _uninstall_shell_completion(ctx: click.Context, param: click.Parameter,
431
433
 
432
434
  elif value == 'fish':
433
435
  cmd = 'rm -f ~/.config/fish/completions/sky.fish'
434
- reload_cmd = _RELOAD_FISH_CMD
436
+ # Fish does not need to be reloaded and will automatically pick up
437
+ # completions.
438
+ reload_cmd = None
435
439
 
436
440
  elif value == 'zsh':
437
441
  cmd = 'sed -i"" -e "/# For SkyPilot shell completion/d" ~/.zshrc && \
@@ -447,8 +451,10 @@ def _uninstall_shell_completion(ctx: click.Context, param: click.Parameter,
447
451
  try:
448
452
  subprocess.run(cmd, shell=True, check=True)
449
453
  click.secho(f'Shell completion uninstalled for {value}', fg='green')
450
- click.echo('Changes will take effect once you restart the terminal: ' +
451
- click.style(f'{reload_cmd}', bold=True))
454
+ if reload_cmd is not None:
455
+ click.echo(
456
+ 'Changes will take effect once you restart the terminal: ' +
457
+ click.style(f'{reload_cmd}', bold=True))
452
458
  except subprocess.CalledProcessError as e:
453
459
  click.secho(f'> Uninstallation failed with code {e.returncode}',
454
460
  fg='red')
@@ -3056,7 +3062,8 @@ def show_gpus(
3056
3062
 
3057
3063
  # This will validate 'cloud' and raise if not found.
3058
3064
  cloud_obj = sky_clouds.CLOUD_REGISTRY.from_str(cloud)
3059
- service_catalog.validate_region_zone(region, None, clouds=cloud)
3065
+ cloud_name = cloud_obj.canonical_name() if cloud_obj is not None else None
3066
+ service_catalog.validate_region_zone(region, None, clouds=cloud_name)
3060
3067
  show_all = all
3061
3068
  if show_all and accelerator_str is not None:
3062
3069
  raise click.UsageError('--all is only allowed without a GPU name.')
@@ -3078,7 +3085,7 @@ def show_gpus(
3078
3085
  qty_header = 'QTY_FILTER'
3079
3086
  free_header = 'FILTERED_FREE_GPUS'
3080
3087
  else:
3081
- qty_header = 'QTY_PER_NODE'
3088
+ qty_header = 'REQUESTABLE_QTY_PER_NODE'
3082
3089
  free_header = 'TOTAL_FREE_GPUS'
3083
3090
  realtime_gpu_table = log_utils.create_table(
3084
3091
  ['GPU', qty_header, 'TOTAL_GPUS', free_header])
@@ -3142,8 +3149,8 @@ def show_gpus(
3142
3149
  # Optimization - do not poll for Kubernetes API for fetching
3143
3150
  # common GPUs because that will be fetched later for the table after
3144
3151
  # common GPUs.
3145
- clouds_to_list = cloud
3146
- if cloud is None:
3152
+ clouds_to_list = cloud_name
3153
+ if cloud_name is None:
3147
3154
  clouds_to_list = [
3148
3155
  c for c in service_catalog.ALL_CLOUDS if c != 'kubernetes'
3149
3156
  ]
@@ -3153,7 +3160,8 @@ def show_gpus(
3153
3160
  # Collect k8s related messages in k8s_messages and print them at end
3154
3161
  print_section_titles = False
3155
3162
  # If cloud is kubernetes, we want to show real-time capacity
3156
- if kubernetes_is_enabled and (cloud is None or cloud_is_kubernetes):
3163
+ if kubernetes_is_enabled and (cloud_name is None or
3164
+ cloud_is_kubernetes):
3157
3165
  if region:
3158
3166
  context = region
3159
3167
  else:
@@ -3263,8 +3271,8 @@ def show_gpus(
3263
3271
  name, quantity = accelerator_str, None
3264
3272
 
3265
3273
  print_section_titles = False
3266
- if (kubernetes_is_enabled and (cloud is None or cloud_is_kubernetes) and
3267
- not show_all):
3274
+ if (kubernetes_is_enabled and
3275
+ (cloud_name is None or cloud_is_kubernetes) and not show_all):
3268
3276
  # Print section title if not showing all and instead a specific
3269
3277
  # accelerator is requested
3270
3278
  print_section_titles = True
@@ -3336,7 +3344,7 @@ def show_gpus(
3336
3344
  if len(result) == 0:
3337
3345
  quantity_str = (f' with requested quantity {quantity}'
3338
3346
  if quantity else '')
3339
- cloud_str = f' on {cloud_obj}.' if cloud else ' in cloud catalogs.'
3347
+ cloud_str = f' on {cloud_obj}.' if cloud_name else ' in cloud catalogs.'
3340
3348
  yield f'Resources \'{name}\'{quantity_str} not found{cloud_str} '
3341
3349
  yield 'To show available accelerators, run: sky show-gpus --all'
3342
3350
  return
sky/clouds/azure.py CHANGED
@@ -15,6 +15,7 @@ from sky import exceptions
15
15
  from sky import sky_logging
16
16
  from sky.adaptors import azure
17
17
  from sky.clouds import service_catalog
18
+ from sky.clouds.utils import azure_utils
18
19
  from sky.utils import common_utils
19
20
  from sky.utils import resources_utils
20
21
  from sky.utils import ux_utils
@@ -36,6 +37,15 @@ _MAX_IDENTITY_FETCH_RETRY = 10
36
37
 
37
38
  _DEFAULT_AZURE_UBUNTU_HPC_IMAGE_GB = 30
38
39
  _DEFAULT_AZURE_UBUNTU_2004_IMAGE_GB = 150
40
+ _DEFAULT_SKYPILOT_IMAGE_GB = 30
41
+
42
+ _DEFAULT_CPU_IMAGE_ID = 'skypilot:gpu-ubuntu-2204'
43
+ _DEFAULT_GPU_IMAGE_ID = 'skypilot:gpu-ubuntu-2204'
44
+ _DEFAULT_V1_IMAGE_ID = 'skypilot:v1-ubuntu-2004'
45
+ _DEFAULT_GPU_K80_IMAGE_ID = 'skypilot:k80-ubuntu-2004'
46
+ _FALLBACK_IMAGE_ID = 'skypilot:gpu-ubuntu-2204'
47
+
48
+ _COMMUNITY_IMAGE_PREFIX = '/CommunityGalleries'
39
49
 
40
50
 
41
51
  def _run_output(cmd):
@@ -132,29 +142,56 @@ class Azure(clouds.Cloud):
132
142
  cost += 0.0
133
143
  return cost
134
144
 
145
@classmethod
def get_default_instance_type(
        cls,
        cpus: Optional[str] = None,
        memory: Optional[str] = None,
        disk_tier: Optional[resources_utils.DiskTier] = None
) -> Optional[str]:
    """Look up the default Azure instance type in the service catalog.

    Args:
        cpus: CPU requirement, forwarded unchanged to the catalog.
        memory: Memory requirement, forwarded unchanged to the catalog.
        disk_tier: Disk tier requirement, forwarded unchanged to the catalog.

    Returns:
        Whatever the catalog lookup returns for the 'azure' cloud.
    """
    catalog_query = {
        'cpus': cpus,
        'memory': memory,
        'disk_tier': disk_tier,
        'clouds': 'azure',
    }
    return service_catalog.get_default_instance_type(**catalog_query)
156
+
135
157
  @classmethod
136
158
  def get_image_size(cls, image_id: str, region: Optional[str]) -> float:
137
- if region is None:
138
- # The region used here is only for where to send the query,
139
- # not the image location. Azure's image is globally available.
140
- region = 'eastus'
141
- is_skypilot_image_tag = False
159
+ # Process skypilot images.
142
160
  if image_id.startswith('skypilot:'):
143
- is_skypilot_image_tag = True
144
161
  image_id = service_catalog.get_image_id_from_tag(image_id,
145
162
  clouds='azure')
146
- image_id_splitted = image_id.split(':')
147
- if len(image_id_splitted) != 4:
148
- with ux_utils.print_exception_no_traceback():
149
- raise ValueError(f'Invalid image id: {image_id}. Expected '
150
- 'format: <publisher>:<offer>:<sku>:<version>')
151
- publisher, offer, sku, version = image_id_splitted
152
- if is_skypilot_image_tag:
153
- if offer == 'ubuntu-hpc':
154
- return _DEFAULT_AZURE_UBUNTU_HPC_IMAGE_GB
163
+ if image_id.startswith(_COMMUNITY_IMAGE_PREFIX):
164
+ # Avoid querying the image size from Azure as
165
+ # all skypilot custom images have the same size.
166
+ return _DEFAULT_SKYPILOT_IMAGE_GB
155
167
  else:
156
- return _DEFAULT_AZURE_UBUNTU_2004_IMAGE_GB
168
+ publisher, offer, sku, version = image_id.split(':')
169
+ if offer == 'ubuntu-hpc':
170
+ return _DEFAULT_AZURE_UBUNTU_HPC_IMAGE_GB
171
+ else:
172
+ return _DEFAULT_AZURE_UBUNTU_2004_IMAGE_GB
173
+
174
+ # Process user-specified images.
175
+ azure_utils.validate_image_id(image_id)
157
176
  compute_client = azure.get_client('compute', cls.get_project_id())
177
+
178
+ # Community gallery image.
179
+ if image_id.startswith(_COMMUNITY_IMAGE_PREFIX):
180
+ if region is None:
181
+ return 0.0
182
+ _, _, gallery_name, _, image_name = image_id.split('/')
183
+ try:
184
+ return azure_utils.get_community_image_size(
185
+ compute_client, gallery_name, image_name, region)
186
+ except exceptions.ResourcesUnavailableError:
187
+ return 0.0
188
+
189
+ # Marketplace image
190
+ if region is None:
191
+ # The region used here is only for where to send the query,
192
+ # not the image location. Marketplace image is globally available.
193
+ region = 'eastus'
194
+ publisher, offer, sku, version = image_id.split(':')
158
195
  try:
159
196
  image = compute_client.virtual_machine_images.get(
160
197
  region, publisher, offer, sku, version)
@@ -176,40 +213,23 @@ class Azure(clouds.Cloud):
176
213
  size_in_gb = size_in_bytes / (1024**3)
177
214
  return size_in_gb
178
215
 
179
- @classmethod
180
- def get_default_instance_type(
181
- cls,
182
- cpus: Optional[str] = None,
183
- memory: Optional[str] = None,
184
- disk_tier: Optional[resources_utils.DiskTier] = None
185
- ) -> Optional[str]:
186
- return service_catalog.get_default_instance_type(cpus=cpus,
187
- memory=memory,
188
- disk_tier=disk_tier,
189
- clouds='azure')
190
-
191
216
  def _get_default_image_tag(self, gen_version, instance_type) -> str:
192
217
  # ubuntu-2004 v21.08.30, K80 requires image with old NVIDIA driver version
193
218
  acc = self.get_accelerators_from_instance_type(instance_type)
194
219
  if acc is not None:
195
220
  acc_name = list(acc.keys())[0]
196
221
  if acc_name == 'K80':
197
- return 'skypilot:k80-ubuntu-2004'
198
-
199
- # ubuntu-2004 v21.11.04, the previous image we used in the past for
200
- # V1 HyperV instance before we change default image to ubuntu-hpc.
222
+ return _DEFAULT_GPU_K80_IMAGE_ID
223
+ # About Gen V1 vs V2:
201
224
  # In Azure, all instances with K80 (Standard_NC series), some
202
225
  # instances with M60 (Standard_NV series) and some cpu instances
203
- # (Basic_A, Standard_D, ...) are V1 instance. For these instances,
204
- # we use the previous image.
226
+ # (Basic_A, Standard_D, ...) are V1 instance.
227
+ # All A100 instances are V2.
205
228
  if gen_version == 'V1':
206
- return 'skypilot:v1-ubuntu-2004'
207
-
208
- # nvidia-driver: 535.54.03, cuda: 12.2
209
- # see: https://github.com/Azure/azhpc-images/releases/tag/ubuntu-hpc-20230803
210
- # All A100 instances is of gen2, so it will always use
211
- # the latest ubuntu-hpc:2204 image.
212
- return 'skypilot:gpu-ubuntu-2204'
229
+ return _DEFAULT_V1_IMAGE_ID
230
+ if acc is None:
231
+ return _DEFAULT_CPU_IMAGE_ID
232
+ return _DEFAULT_GPU_IMAGE_ID
213
233
 
214
234
  @classmethod
215
235
  def regions_with_offering(cls, instance_type: str,
@@ -302,17 +322,34 @@ class Azure(clouds.Cloud):
302
322
  else:
303
323
  assert region_name in resources.image_id, resources.image_id
304
324
  image_id = resources.image_id[region_name]
325
+
326
+ # Checked basic image syntax in resources.py
305
327
  if image_id.startswith('skypilot:'):
306
328
  image_id = service_catalog.get_image_id_from_tag(image_id,
307
329
  clouds='azure')
308
- # Already checked in resources.py
309
- publisher, offer, sku, version = image_id.split(':')
310
- image_config = {
311
- 'image_publisher': publisher,
312
- 'image_offer': offer,
313
- 'image_sku': sku,
314
- 'image_version': version,
315
- }
330
+ # Fallback if image does not exist in the specified region.
331
+ # Putting fallback here instead of at image validation
332
+ # when creating the resource because community images are
333
+ # regional so we need the correct region when we check whether
334
+ # the image exists.
335
+ if image_id.startswith(
336
+ _COMMUNITY_IMAGE_PREFIX
337
+ ) and region_name not in azure_catalog.COMMUNITY_IMAGE_AVAILABLE_REGIONS:
338
+ logger.info(f'Azure image {image_id} does not exist in region '
339
+ f'{region_name} so use the fallback image instead.')
340
+ image_id = service_catalog.get_image_id_from_tag(
341
+ _FALLBACK_IMAGE_ID, clouds='azure')
342
+
343
+ if image_id.startswith(_COMMUNITY_IMAGE_PREFIX):
344
+ image_config = {'community_gallery_image_id': image_id}
345
+ else:
346
+ publisher, offer, sku, version = image_id.split(':')
347
+ image_config = {
348
+ 'image_publisher': publisher,
349
+ 'image_offer': offer,
350
+ 'image_sku': sku,
351
+ 'image_version': version,
352
+ }
316
353
 
317
354
  # Setup the A10 nvidia driver.
318
355
  need_nvidia_driver_extension = (acc_dict is not None and
@@ -380,7 +417,6 @@ class Azure(clouds.Cloud):
380
417
  # Setting disk performance tier for high disk tier.
381
418
  if disk_tier == resources_utils.DiskTier.HIGH:
382
419
  resources_vars['disk_performance_tier'] = 'P50'
383
-
384
420
  return resources_vars
385
421
 
386
422
  def _get_feasible_launchable_resources(
sky/clouds/cloud.py CHANGED
@@ -819,6 +819,10 @@ class Cloud:
819
819
 
820
820
  # === End of image related methods ===
821
821
 
822
@classmethod
def canonical_name(cls) -> str:
    """Return the canonical name of this cloud: its class name, lowercased."""
    class_name = cls.__name__
    return class_name.lower()
825
+
822
826
  def __repr__(self):
823
827
  return self._REPR
824
828
 
@@ -1,7 +1,7 @@
1
1
  """Clouds need to be registered in CLOUD_REGISTRY to be discovered"""
2
2
 
3
3
  import typing
4
- from typing import Optional, Type
4
+ from typing import Callable, Dict, List, Optional, overload, Type, Union
5
5
 
6
6
  from sky.utils import ux_utils
7
7
 
@@ -12,20 +12,65 @@ if typing.TYPE_CHECKING:
12
12
  class _CloudRegistry(dict):
13
13
  """Registry of clouds."""
14
14
 
15
+ def __init__(self) -> None:
16
+ super().__init__()
17
+ self.aliases: Dict[str, str] = {}
18
+
15
19
  def from_str(self, name: Optional[str]) -> Optional['cloud.Cloud']:
20
+ """Returns the cloud instance from the canonical name or alias."""
16
21
  if name is None:
17
22
  return None
18
- if name.lower() not in self:
19
- with ux_utils.print_exception_no_traceback():
20
- raise ValueError(f'Cloud {name!r} is not a valid cloud among '
21
- f'{list(self.keys())}')
22
- return self.get(name.lower())
23
23
 
24
+ search_name = name.lower()
25
+
26
+ if search_name in self:
27
+ return self[search_name]
28
+
29
+ if search_name in self.aliases:
30
+ return self[self.aliases[search_name]]
31
+
32
+ with ux_utils.print_exception_no_traceback():
33
+ raise ValueError(f'Cloud {name!r} is not a valid cloud among '
34
+ f'{[*self.keys(), *self.aliases.keys()]}')
35
+
36
+ @overload
24
37
  def register(self, cloud_cls: Type['cloud.Cloud']) -> Type['cloud.Cloud']:
25
- name = cloud_cls.__name__.lower()
26
- assert name not in self, f'{name} already registered'
27
- self[name] = cloud_cls()
28
- return cloud_cls
38
+ ...
39
+
40
+ @overload
41
+ def register(
42
+ self,
43
+ cloud_cls: None = None,
44
+ aliases: Optional[List[str]] = None,
45
+ ) -> Callable[[Type['cloud.Cloud']], Type['cloud.Cloud']]:
46
+ ...
47
+
48
+ def register(
49
+ self,
50
+ cloud_cls: Optional[Type['cloud.Cloud']] = None,
51
+ aliases: Optional[List[str]] = None,
52
+ ) -> Union[Type['cloud.Cloud'], Callable[[Type['cloud.Cloud']],
53
+ Type['cloud.Cloud']]]:
54
+
55
+ def _register(cloud_cls: Type['cloud.Cloud']) -> Type['cloud.Cloud']:
56
+ name = cloud_cls.canonical_name()
57
+ assert name not in self, f'{name} already registered'
58
+ self[name] = cloud_cls()
59
+
60
+ for alias in aliases or []:
61
+ alias = alias.lower()
62
+ assert alias not in self.aliases, (
63
+ f'alias {alias} already registered')
64
+ self.aliases[alias] = name
65
+
66
+ return cloud_cls
67
+
68
+ if cloud_cls is not None:
69
+ # invocation without parens (e.g. just `@register`)
70
+ return _register(cloud_cls)
71
+
72
+ # Invocation with parens (e.g. `@register(aliases=['alias'])`)
73
+ return _register
29
74
 
30
75
 
31
76
  CLOUD_REGISTRY: _CloudRegistry = _CloudRegistry()
sky/clouds/kubernetes.py CHANGED
@@ -33,7 +33,7 @@ CREDENTIAL_PATH = os.environ.get('KUBECONFIG', DEFAULT_KUBECONFIG_PATH)
33
33
  _SKYPILOT_SYSTEM_NAMESPACE = 'skypilot-system'
34
34
 
35
35
 
36
- @clouds.CLOUD_REGISTRY.register
36
+ @clouds.CLOUD_REGISTRY.register(aliases=['k8s'])
37
37
  class Kubernetes(clouds.Cloud):
38
38
  """Kubernetes."""
39
39
 
sky/clouds/oci.py CHANGED
@@ -468,7 +468,7 @@ class OCI(clouds.Cloud):
468
468
  api_key_file = oci_cfg[
469
469
  'key_file'] if 'key_file' in oci_cfg else 'BadConf'
470
470
  sky_cfg_file = oci_utils.oci_config.get_sky_user_config_file()
471
- except ImportError:
471
+ except (ImportError, oci_adaptor.oci.exceptions.ConfigFileNotFound):
472
472
  return {}
473
473
 
474
474
  # OCI config and API key file are mandatory
@@ -12,6 +12,21 @@ from sky.clouds.service_catalog import common
12
12
  from sky.utils import resources_utils
13
13
  from sky.utils import ux_utils
14
14
 
15
+ # This list should match the list of regions in
16
+ # skypilot image generation Packer script's replication_regions
17
+ # sky/clouds/service_catalog/images/skypilot-azure-cpu-ubuntu.pkr.hcl
18
+ COMMUNITY_IMAGE_AVAILABLE_REGIONS = {
19
+ 'centralus',
20
+ 'eastus',
21
+ 'eastus2',
22
+ 'northcentralus',
23
+ 'southcentralus',
24
+ 'westcentralus',
25
+ 'westus',
26
+ 'westus2',
27
+ 'westus3',
28
+ }
29
+
15
30
  # The frequency of pulling the latest catalog from the cloud provider.
16
31
  # Though the catalog update is manual in our skypilot-catalog repo, we
17
32
  # still want to pull the latest catalog periodically to make sure the
@@ -120,8 +120,14 @@ def list_accelerators_realtime(
120
120
 
121
121
  # Generate the GPU quantities for the accelerators
122
122
  if accelerator_name and accelerator_count > 0:
123
- for count in range(1, accelerator_count + 1):
123
+ count = 1
124
+ while count <= accelerator_count:
124
125
  accelerators_qtys.add((accelerator_name, count))
126
+ count *= 2
127
+ # Add the accelerator count if it's not already in the set
128
+ # (e.g., if there's 12 GPUs, we should have qtys 1, 2, 4, 8, 12)
129
+ if accelerator_count not in accelerators_qtys:
130
+ accelerators_qtys.add((accelerator_name, accelerator_count))
125
131
 
126
132
  for pod in pods:
127
133
  # Get all the pods running on the node
@@ -0,0 +1,91 @@
1
+ """Utilities for Azure"""
2
+
3
+ import typing
4
+
5
+ from sky import exceptions
6
+ from sky.adaptors import azure
7
+ from sky.utils import ux_utils
8
+
9
+ if typing.TYPE_CHECKING:
10
+ from azure.mgmt import compute as azure_compute
11
+ from azure.mgmt.compute import models as azure_compute_models
12
+
13
+
14
+ def validate_image_id(image_id: str):
15
+ """Check if the image ID has a valid format.
16
+
17
+ Raises:
18
+ ValueError: If the image ID is invalid.
19
+ """
20
+ image_id_colon_splitted = image_id.split(':')
21
+ image_id_slash_splitted = image_id.split('/')
22
+ if len(image_id_slash_splitted) != 5 and len(image_id_colon_splitted) != 4:
23
+ with ux_utils.print_exception_no_traceback():
24
+ raise ValueError(
25
+ f'Invalid image id for Azure: {image_id}. Expected format: \n'
26
+ '* Marketplace image ID: <publisher>:<offer>:<sku>:<version>\n'
27
+ '* Community image ID: '
28
+ '/CommunityGalleries/<gallery-name>/Images/<image-name>')
29
+ if len(image_id_slash_splitted) == 5:
30
+ _, gallery_type, _, image_type, _ = image_id.split('/')
31
+ if gallery_type != 'CommunityGalleries' or image_type != 'Images':
32
+ with ux_utils.print_exception_no_traceback():
33
+ raise ValueError(
34
+ f'Invalid community image id for Azure: {image_id}.\n'
35
+ 'Expected format: '
36
+ '/CommunityGalleries/<gallery-name>/Images/<image-name>')
37
+
38
+
39
def get_community_image(
        compute_client: 'azure_compute.ComputeManagementClient', image_id: str,
        region: str) -> 'azure_compute_models.CommunityGalleryImage':
    """Fetch a community gallery image description from Azure.

    Args:
        compute_client: Client whose ``community_gallery_images.get`` is
            called to fetch the image.
        image_id: /CommunityGalleries/<gallery-name>/Images/<image-name>
        region: Passed as the ``location`` of the lookup.

    Returns:
        The object returned by ``community_gallery_images.get``.

    Raises:
        ResourcesUnavailableError: If the lookup raises an Azure error.
    """
    try:
        # Segments 2 and 4 of the slash-separated ID carry the names.
        _, _, gallery, _, image = image_id.split('/')
        images_api = compute_client.community_gallery_images
        found = images_api.get(location=region,
                               public_gallery_name=gallery,
                               gallery_image_name=image)
        return found
    except azure.exceptions().AzureError as e:
        raise exceptions.ResourcesUnavailableError(
            f'Community image {image_id} does not exist in region {region}.'
        ) from e
59
+
60
+
61
def get_community_image_size(
        compute_client: 'azure_compute.ComputeManagementClient',
        gallery_name: str, image_name: str, region: str) -> float:
    """Get the OS disk size (in GB) of a community gallery image.

    Args:
        compute_client: Client whose ``community_gallery_image_versions``
            operations are used for the lookups.
        gallery_name: <gallery-name> part of
            /CommunityGalleries/<gallery-name>/Images/<image-name>
        image_name: <image-name> part of the same ID.
        region: Passed as the ``location`` of the lookups.

    Returns:
        ``storage_profile.os_disk_image.disk_size_gb`` of the selected
        image version.

    Raises:
        ResourcesUnavailableError: If the image has no versions, or if a
            lookup raises an Azure error.
    """
    try:
        versions_api = compute_client.community_gallery_image_versions
        available = list(
            versions_api.list(
                location=region,
                public_gallery_name=gallery_name,
                gallery_image_name=image_name,
            ))
        if not available:
            raise exceptions.ResourcesUnavailableError(
                f'No versions available for Azure community image {image_name}')
        # NOTE(review): the last listed entry is treated as the latest
        # version; this assumes the listing is ordered oldest to newest --
        # TODO confirm against the Azure API.
        newest = available[-1].name

        details = versions_api.get(location=region,
                                   public_gallery_name=gallery_name,
                                   gallery_image_name=image_name,
                                   gallery_image_version_name=newest)
        return details.storage_profile.os_disk_image.disk_size_gb
    except azure.exceptions().AzureError as e:
        raise exceptions.ResourcesUnavailableError(
            f'Failed to get community image size: {e}.') from e
sky/exceptions.py CHANGED
@@ -1,7 +1,7 @@
1
1
  """Exceptions."""
2
2
  import enum
3
3
  import typing
4
- from typing import List, Optional
4
+ from typing import List, Optional, Sequence
5
5
 
6
6
  if typing.TYPE_CHECKING:
7
7
  from sky import status_lib
@@ -61,12 +61,12 @@ class ProvisionPrechecksError(Exception):
61
61
  the error will be raised.
62
62
 
63
63
  Args:
64
- reasons: (List[Exception]) The reasons why the prechecks failed.
64
+ reasons: (Sequence[Exception]) The reasons why the prechecks failed.
65
65
  """
66
66
 
67
- def __init__(self, reasons: List[Exception]) -> None:
67
+ def __init__(self, reasons: Sequence[Exception]) -> None:
68
68
  super().__init__()
69
- self.reasons = list(reasons)
69
+ self.reasons = reasons
70
70
 
71
71
 
72
72
  class ManagedJobReachedMaxRetriesError(Exception):