skypilot-nightly 1.0.0.dev20251012__py3-none-any.whl → 1.0.0.dev20251014__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of skypilot-nightly might be problematic.
Files changed (63)
  1. sky/__init__.py +4 -2
  2. sky/adaptors/shadeform.py +89 -0
  3. sky/authentication.py +52 -2
  4. sky/backends/backend_utils.py +35 -25
  5. sky/backends/cloud_vm_ray_backend.py +5 -5
  6. sky/catalog/data_fetchers/fetch_shadeform.py +142 -0
  7. sky/catalog/kubernetes_catalog.py +19 -25
  8. sky/catalog/shadeform_catalog.py +165 -0
  9. sky/client/cli/command.py +53 -19
  10. sky/client/sdk.py +13 -1
  11. sky/clouds/__init__.py +2 -0
  12. sky/clouds/shadeform.py +393 -0
  13. sky/dashboard/out/404.html +1 -1
  14. sky/dashboard/out/clusters/[cluster]/[job].html +1 -1
  15. sky/dashboard/out/clusters/[cluster].html +1 -1
  16. sky/dashboard/out/clusters.html +1 -1
  17. sky/dashboard/out/config.html +1 -1
  18. sky/dashboard/out/index.html +1 -1
  19. sky/dashboard/out/infra/[context].html +1 -1
  20. sky/dashboard/out/infra.html +1 -1
  21. sky/dashboard/out/jobs/[job].html +1 -1
  22. sky/dashboard/out/jobs/pools/[pool].html +1 -1
  23. sky/dashboard/out/jobs.html +1 -1
  24. sky/dashboard/out/users.html +1 -1
  25. sky/dashboard/out/volumes.html +1 -1
  26. sky/dashboard/out/workspace/new.html +1 -1
  27. sky/dashboard/out/workspaces/[name].html +1 -1
  28. sky/dashboard/out/workspaces.html +1 -1
  29. sky/jobs/controller.py +122 -145
  30. sky/jobs/recovery_strategy.py +59 -82
  31. sky/jobs/scheduler.py +5 -5
  32. sky/jobs/state.py +65 -21
  33. sky/jobs/utils.py +58 -22
  34. sky/metrics/utils.py +27 -6
  35. sky/provision/__init__.py +1 -0
  36. sky/provision/kubernetes/utils.py +44 -39
  37. sky/provision/shadeform/__init__.py +11 -0
  38. sky/provision/shadeform/config.py +12 -0
  39. sky/provision/shadeform/instance.py +351 -0
  40. sky/provision/shadeform/shadeform_utils.py +83 -0
  41. sky/server/common.py +4 -2
  42. sky/server/requests/executor.py +25 -3
  43. sky/server/server.py +9 -3
  44. sky/setup_files/dependencies.py +1 -0
  45. sky/sky_logging.py +0 -2
  46. sky/skylet/constants.py +23 -6
  47. sky/skylet/log_lib.py +0 -1
  48. sky/skylet/log_lib.pyi +1 -1
  49. sky/templates/shadeform-ray.yml.j2 +72 -0
  50. sky/utils/common.py +2 -0
  51. sky/utils/context.py +57 -51
  52. sky/utils/context_utils.py +15 -11
  53. sky/utils/controller_utils.py +35 -8
  54. sky/utils/locks.py +20 -5
  55. sky/utils/subprocess_utils.py +4 -3
  56. {skypilot_nightly-1.0.0.dev20251012.dist-info → skypilot_nightly-1.0.0.dev20251014.dist-info}/METADATA +39 -38
  57. {skypilot_nightly-1.0.0.dev20251012.dist-info → skypilot_nightly-1.0.0.dev20251014.dist-info}/RECORD +63 -54
  58. /sky/dashboard/out/_next/static/{yOfMelBaFp8uL5F9atyAK → 9Fek73R28lDp1A5J4N7g7}/_buildManifest.js +0 -0
  59. /sky/dashboard/out/_next/static/{yOfMelBaFp8uL5F9atyAK → 9Fek73R28lDp1A5J4N7g7}/_ssgManifest.js +0 -0
  60. {skypilot_nightly-1.0.0.dev20251012.dist-info → skypilot_nightly-1.0.0.dev20251014.dist-info}/WHEEL +0 -0
  61. {skypilot_nightly-1.0.0.dev20251012.dist-info → skypilot_nightly-1.0.0.dev20251014.dist-info}/entry_points.txt +0 -0
  62. {skypilot_nightly-1.0.0.dev20251012.dist-info → skypilot_nightly-1.0.0.dev20251014.dist-info}/licenses/LICENSE +0 -0
  63. {skypilot_nightly-1.0.0.dev20251012.dist-info → skypilot_nightly-1.0.0.dev20251014.dist-info}/top_level.txt +0 -0
sky/__init__.py CHANGED
@@ -7,7 +7,7 @@ import urllib.request
 from sky.utils import directory_utils
 
 # Replaced with the current commit when building the wheels.
-_SKYPILOT_COMMIT_SHA = '7d5d1a2925fc7192af10061ca395d329364e7405'
+_SKYPILOT_COMMIT_SHA = '7bcf0fcb2073aac435139d1d85fc0e66acca26a5'
 
 
 def _get_git_commit():
@@ -37,7 +37,7 @@ def _get_git_commit():
 
 
 __commit__ = _get_git_commit()
-__version__ = '1.0.0.dev20251012'
+__version__ = '1.0.0.dev20251014'
 __root_dir__ = directory_utils.get_sky_dir()
 
 
@@ -150,6 +150,7 @@ Vsphere = clouds.Vsphere
 Fluidstack = clouds.Fluidstack
 Nebius = clouds.Nebius
 Hyperbolic = clouds.Hyperbolic
+Shadeform = clouds.Shadeform
 Seeweb = clouds.Seeweb
 
 __all__ = [
@@ -172,6 +173,7 @@ __all__ = [
     'Fluidstack',
     'Nebius',
     'Hyperbolic',
+    'Shadeform',
     'Seeweb',
     'Optimizer',
     'OptimizeTarget',
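
With the export in place, Shadeform can be targeted like any other cloud. A hedged sketch of what this enables (assuming the long-standing `sky.Resources(cloud=...)` API; the accelerator name is only illustrative):

import sky

# Pin a task to Shadeform; whether 'A100:1' is offered depends on the catalog.
task = sky.Task(run='nvidia-smi')
task.set_resources(sky.Resources(cloud=sky.Shadeform(), accelerators='A100:1'))
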
sky/adaptors/shadeform.py ADDED
@@ -0,0 +1,89 @@
+"""Shadeform cloud adaptor."""
+
+import functools
+import socket
+from typing import Any, Dict, List, Optional
+
+import requests
+
+from sky import sky_logging
+from sky.provision.shadeform import shadeform_utils
+from sky.utils import common_utils
+
+logger = sky_logging.init_logger(__name__)
+
+_shadeform_sdk = None
+
+
+def import_package(func):
+
+    @functools.wraps(func)
+    def wrapper(*args, **kwargs):
+        global _shadeform_sdk
+        if _shadeform_sdk is None:
+            try:
+                import shadeform as _shadeform  # pylint: disable=import-outside-toplevel
+                _shadeform_sdk = _shadeform
+            except ImportError:
+                raise ImportError(
+                    'Failed to import dependencies for Shadeform. '
+                    'Try pip install "skypilot[shadeform]"') from None
+        return func(*args, **kwargs)
+
+    return wrapper
+
+
+@import_package
+def shadeform():
+    """Return the shadeform package."""
+    return _shadeform_sdk
+
+
+def list_ssh_keys() -> List[Dict[str, Any]]:
+    """List all SSH keys in Shadeform account."""
+    try:
+        response = shadeform_utils.get_ssh_keys()
+        return response.get('ssh_keys', [])
+    except (ValueError, KeyError, requests.exceptions.RequestException) as e:
+        logger.warning(f'Failed to list SSH keys from Shadeform: {e}')
+        return []
+
+
+def add_ssh_key_to_shadeform(public_key: str) -> Optional[str]:
+    """Add SSH key to Shadeform if it doesn't already exist.
+
+    Args:
+        public_key: The SSH public key string.
+
+    Returns:
+        The ID of the key if added successfully, None otherwise.
+    """
+    try:
+        # Check if key already exists
+        existing_keys = list_ssh_keys()
+        key_exists = False
+        key_id = None
+        for key in existing_keys:
+            if key.get('public_key') == public_key:
+                key_exists = True
+                key_id = key.get('id')
+                break
+
+        if key_exists:
+            logger.info('SSH key already exists in Shadeform account')
+            return key_id
+
+        # Generate a unique key name
+        hostname = socket.gethostname()
+        key_name = f'skypilot-{hostname}-{common_utils.get_user_hash()[:8]}'
+
+        # Add the key
+        response = shadeform_utils.add_ssh_key(name=key_name,
+                                               public_key=public_key)
+        key_id = response['id']
+        logger.info(f'Added SSH key to Shadeform: {key_name} ({key_id})')
+        return key_id
+
+    except (ValueError, KeyError, requests.exceptions.RequestException) as e:
+        logger.warning(f'Failed to add SSH key to Shadeform: {e}')
+        return None
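
The adaptor defers importing the `shadeform` SDK until the first call that needs it, so the base `skypilot` install does not require the optional dependency. A minimal, self-contained sketch of the same lazy-import pattern, with a hypothetical `heavy_sdk` package standing in for the real dependency:

import functools

_sdk = None  # cached module, populated on first use


def lazy_import(func):
    """Import the optional dependency the first time a wrapped call runs."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        global _sdk
        if _sdk is None:
            try:
                import heavy_sdk  # hypothetical optional dependency
                _sdk = heavy_sdk
            except ImportError:
                raise ImportError(
                    'Missing optional dependency. '
                    'Try: pip install "mypkg[heavy]"') from None
        return func(*args, **kwargs)

    return wrapper


@lazy_import
def sdk():
    """Return the lazily imported module."""
    return _sdk
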
sky/authentication.py CHANGED
@@ -39,6 +39,7 @@ from sky.adaptors import gcp
 from sky.adaptors import ibm
 from sky.adaptors import runpod
 from sky.adaptors import seeweb as seeweb_adaptor
+from sky.adaptors import shadeform as shadeform_adaptor
 from sky.adaptors import vast
 from sky.provision.fluidstack import fluidstack_utils
 from sky.provision.kubernetes import utils as kubernetes_utils
@@ -152,7 +153,12 @@ def get_or_generate_keys() -> Tuple[str, str]:
     return private_key_path, public_key_path
 
 
-def create_ssh_key_files_from_db(private_key_path: str):
+def create_ssh_key_files_from_db(private_key_path: str) -> bool:
+    """Creates the ssh key files from the database.
+
+    Returns:
+        True if the ssh key files are created successfully, False otherwise.
+    """
     # Assume private key path is in the format of
     # ~/.sky/clients/<user_hash>/ssh/sky-key
     separated_path = os.path.normpath(private_key_path).split(os.path.sep)
@@ -180,12 +186,14 @@ def create_ssh_key_files_from_db(private_key_path: str):
     ssh_public_key, ssh_private_key, exists = (
         global_user_state.get_ssh_keys(user_hash))
     if not exists:
-        raise RuntimeError(f'SSH keys not found for user {user_hash}')
+        logger.debug(f'SSH keys not found for user {user_hash}')
+        return False
     _save_key_pair(private_key_path, public_key_path, ssh_private_key,
                    ssh_public_key)
     assert os.path.exists(public_key_path), (
         'Private key found, but associated public key '
         f'{public_key_path} does not exist.')
+    return True
 
 
 def configure_ssh_info(config: Dict[str, Any]) -> Dict[str, Any]:
@@ -511,6 +519,48 @@ def setup_hyperbolic_authentication(config: Dict[str, Any]) -> Dict[str, Any]:
     return configure_ssh_info(config)
 
 
+def setup_shadeform_authentication(config: Dict[str, Any]) -> Dict[str, Any]:
+    """Sets up SSH authentication for Shadeform.
+    - Generates a new SSH key pair if one does not exist.
+    - Adds the public SSH key to the user's Shadeform account.
+
+    Note: This assumes there is a Shadeform Python SDK available.
+    If no official SDK exists, this function would need to use direct API calls.
+    """
+
+    _, public_key_path = get_or_generate_keys()
+    ssh_key_id = None
+
+    with open(public_key_path, 'r', encoding='utf-8') as f:
+        public_key = f.read().strip()
+
+    try:
+        # Add SSH key to Shadeform using our utility functions
+        ssh_key_id = shadeform_adaptor.add_ssh_key_to_shadeform(public_key)
+
+    except ImportError as e:
+        # If required dependencies are missing
+        logger.warning(
+            f'Failed to add Shadeform SSH key due to missing dependencies: '
+            f'{e}. Manually configure SSH keys in your Shadeform account.')
+
+    except Exception as e:
+        logger.warning(f'Failed to set up Shadeform authentication: {e}')
+        raise exceptions.CloudUserIdentityError(
+            'Failed to set up SSH authentication for Shadeform. '
+            f'Please ensure your Shadeform credentials are configured: {e}'
+        ) from e
+
+    if ssh_key_id is None:
+        raise Exception('Failed to add SSH key to Shadeform')
+
+    # Configure SSH info in the config
+    config['auth']['ssh_public_key'] = public_key_path
+    config['auth']['ssh_key_id'] = ssh_key_id
+
+    return configure_ssh_info(config)
+
+
 def setup_primeintellect_authentication(
         config: Dict[str, Any]) -> Dict[str, Any]:
     """Sets up SSH authentication for Prime Intellect.
sky/backends/backend_utils.py CHANGED
@@ -1124,6 +1124,8 @@ def _add_auth_to_cluster_config(cloud: clouds.Cloud, tmp_yaml_path: str):
         config = auth.setup_fluidstack_authentication(config)
     elif isinstance(cloud, clouds.Hyperbolic):
         config = auth.setup_hyperbolic_authentication(config)
+    elif isinstance(cloud, clouds.Shadeform):
+        config = auth.setup_shadeform_authentication(config)
     elif isinstance(cloud, clouds.PrimeIntellect):
         config = auth.setup_primeintellect_authentication(config)
     elif isinstance(cloud, clouds.Seeweb):
@@ -1855,6 +1857,13 @@ def check_owner_identity(cluster_name: str) -> None:
                                                     summary_response=True)
     if record is None:
         return
+    _check_owner_identity_with_record(cluster_name, record)
+
+
+def _check_owner_identity_with_record(cluster_name: str,
+                                      record: Dict[str, Any]) -> None:
+    if env_options.Options.SKIP_CLOUD_IDENTITY_CHECK.get():
+        return
     handle = record['handle']
     if not isinstance(handle, backends.CloudVmRayResourceHandle):
         return
@@ -2149,6 +2158,7 @@ def check_can_clone_disk_and_override_task(
 
 def _update_cluster_status(
         cluster_name: str,
+        record: Dict[str, Any],
         include_user_info: bool = True,
         summary_response: bool = False) -> Optional[Dict[str, Any]]:
     """Update the cluster status.
@@ -2177,12 +2187,6 @@ def _update_cluster_status(
         fetched from the cloud provider or there are leaked nodes causing
         the node number larger than expected.
     """
-    record = global_user_state.get_cluster_from_name(
-        cluster_name,
-        include_user_info=include_user_info,
-        summary_response=summary_response)
-    if record is None:
-        return None
     handle = record['handle']
     if handle.cluster_yaml is None:
         # Remove cluster from db since this cluster does not have a config file
@@ -2675,10 +2679,9 @@ def refresh_cluster_record(
     # using the correct cloud credentials.
     workspace = record.get('workspace', constants.SKYPILOT_DEFAULT_WORKSPACE)
     with skypilot_config.local_active_workspace_ctx(workspace):
-        check_owner_identity(cluster_name)
-
-        if not isinstance(record['handle'], backends.CloudVmRayResourceHandle):
-            return record
+        # check_owner_identity returns if the record handle is
+        # not a CloudVmRayResourceHandle
+        _check_owner_identity_with_record(cluster_name, record)
 
         # The loop logic allows us to notice if the status was updated in the
         # global_user_state by another process and stop trying to get the lock.
@@ -2695,7 +2698,8 @@ def refresh_cluster_record(
                 return record
 
             if cluster_lock_already_held:
-                return _update_cluster_status(cluster_name, include_user_info,
+                return _update_cluster_status(cluster_name, record,
+                                              include_user_info,
                                               summary_response)
 
             # Try to acquire the lock so we can fetch the status.
@@ -2711,7 +2715,7 @@ def refresh_cluster_record(
                         record, force_refresh_statuses):
                     return record
                 # Update and return the cluster status.
-                return _update_cluster_status(cluster_name,
+                return _update_cluster_status(cluster_name, record,
                                               include_user_info,
                                               summary_response)
 
@@ -3115,25 +3119,23 @@ def refresh_cluster_records() -> None:
     exclude_managed_clusters = True
     if env_options.Options.SHOW_DEBUG_INFO.get():
         exclude_managed_clusters = False
-    cluster_names = global_user_state.get_cluster_names(
-        exclude_managed_clusters=exclude_managed_clusters,)
+    cluster_names = set(
+        global_user_state.get_cluster_names(
+            exclude_managed_clusters=exclude_managed_clusters,))
 
     # TODO(syang): we should try not to leak
     # request info in backend_utils.py.
     # Refactor this to use some other info to
     # determine if a launch is in progress.
-    request = requests_lib.get_request_tasks(
+    requests = requests_lib.get_request_tasks(
         req_filter=requests_lib.RequestTaskFilter(
            status=[requests_lib.RequestStatus.RUNNING],
-            cluster_names=cluster_names,
            include_request_names=['sky.launch']))
     cluster_names_with_launch_request = {
-        request.cluster_name for request in request
+        request.cluster_name for request in requests
     }
-    cluster_names_without_launch_request = [
-        cluster_name for cluster_name in cluster_names
-        if cluster_name not in cluster_names_with_launch_request
-    ]
+    cluster_names_without_launch_request = (cluster_names -
+                                            cluster_names_with_launch_request)
 
     def _refresh_cluster_record(cluster_name):
         return _refresh_cluster(cluster_name,
@@ -3142,7 +3144,7 @@ def refresh_cluster_records() -> None:
                                 include_user_info=False,
                                 summary_response=True)
 
-    if len(cluster_names) > 0:
+    if len(cluster_names_without_launch_request) > 0:
         # Do not refresh the clusters that have an active launch request.
         subprocess_utils.run_in_parallel(_refresh_cluster_record,
                                          cluster_names_without_launch_request)
@@ -3268,7 +3270,15 @@ def get_clusters(
                 expanded_private_key_path = os.path.expanduser(
                     ssh_private_key_path)
                 if not os.path.exists(expanded_private_key_path):
-                    auth.create_ssh_key_files_from_db(ssh_private_key_path)
+                    success = auth.create_ssh_key_files_from_db(
+                        ssh_private_key_path)
+                    if not success:
+                        # If the ssh key files are not found, we do not
+                        # update the record with credentials.
+                        logger.debug(
+                            f'SSH keys not found for cluster {record["name"]} '
+                            f'at key path {ssh_private_key_path}')
+                        continue
             else:
                 private_key_path, _ = auth.get_or_generate_keys()
                 expanded_private_key_path = os.path.expanduser(private_key_path)
@@ -3342,13 +3352,13 @@ def get_clusters(
     # request info in backend_utils.py.
     # Refactor this to use some other info to
     # determine if a launch is in progress.
-    request = requests_lib.get_request_tasks(
+    requests = requests_lib.get_request_tasks(
        req_filter=requests_lib.RequestTaskFilter(
            status=[requests_lib.RequestStatus.RUNNING],
            cluster_names=cluster_names,
            include_request_names=['sky.launch']))
     cluster_names_with_launch_request = {
-        request.cluster_name for request in request
+        request.cluster_name for request in requests
     }
     cluster_names_without_launch_request = [
         cluster_name for cluster_name in cluster_names
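
A note on the `refresh_cluster_records` refactor above: collecting the names into a set lets the exclusion be written as a set difference, which deduplicates and avoids a linear membership scan per name. The equivalence in miniature:

cluster_names = {'a', 'b', 'c'}
with_launch_request = {'b'}

# Old shape: list comprehension with per-element membership tests.
without_old = [n for n in cluster_names if n not in with_launch_request]

# New shape: set difference.
without_new = cluster_names - with_launch_request

assert set(without_old) == without_new == {'a', 'c'}
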
sky/backends/cloud_vm_ray_backend.py CHANGED
@@ -141,6 +141,7 @@ _NODES_LAUNCHING_PROGRESS_TIMEOUT = {
    clouds.OCI: 300,
    clouds.Paperspace: 600,
    clouds.Kubernetes: 300,
+    clouds.Shadeform: 300,
    clouds.Vsphere: 240,
 }
 
@@ -304,6 +305,7 @@ def _get_cluster_config_template(cloud):
        clouds.RunPod: 'runpod-ray.yml.j2',
        clouds.Kubernetes: 'kubernetes-ray.yml.j2',
        clouds.SSH: 'kubernetes-ray.yml.j2',
+        clouds.Shadeform: 'shadeform-ray.yml.j2',
        clouds.Vsphere: 'vsphere-ray.yml.j2',
        clouds.Vast: 'vast-ray.yml.j2',
        clouds.Fluidstack: 'fluidstack-ray.yml.j2',
@@ -3718,7 +3720,7 @@ class CloudVmRayBackend(backends.Backend['CloudVmRayResourceHandle']):
 
                 self._update_after_cluster_provisioned(
                     handle, to_provision_config.prev_handle, task,
-                    prev_cluster_status, lock_id, config_hash)
+                    prev_cluster_status, config_hash)
                 return handle, False
 
             cluster_config_file = config_dict['ray']
@@ -3790,7 +3792,7 @@ class CloudVmRayBackend(backends.Backend['CloudVmRayResourceHandle']):
 
             self._update_after_cluster_provisioned(
                 handle, to_provision_config.prev_handle, task,
-                prev_cluster_status, lock_id, config_hash)
+                prev_cluster_status, config_hash)
             return handle, False
 
     def _open_ports(self, handle: CloudVmRayResourceHandle) -> None:
@@ -3808,7 +3810,7 @@ class CloudVmRayBackend(backends.Backend['CloudVmRayResourceHandle']):
            prev_handle: Optional[CloudVmRayResourceHandle],
            task: task_lib.Task,
            prev_cluster_status: Optional[status_lib.ClusterStatus],
-            lock_id: str, config_hash: str) -> None:
+            config_hash: str) -> None:
        usage_lib.messages.usage.update_cluster_resources(
            handle.launched_nodes, handle.launched_resources)
        usage_lib.messages.usage.update_final_cluster_status(
@@ -3920,8 +3922,6 @@ class CloudVmRayBackend(backends.Backend['CloudVmRayResourceHandle']):
            handle.cached_external_ssh_ports, handle.docker_user,
            handle.ssh_user)
 
-        locks.get_lock(lock_id).force_unlock()
-
    def _sync_workdir(self, handle: CloudVmRayResourceHandle,
                      workdir: Union[Path, Dict[str, Any]],
                      envs_and_secrets: Dict[str, str]) -> None:
sky/catalog/data_fetchers/fetch_shadeform.py ADDED
@@ -0,0 +1,142 @@
+"""A script that generates the Shadeform catalog.
+
+Usage:
+    python fetch_shadeform.py [-h] [--api-key API_KEY]
+                              [--api-key-path API_KEY_PATH]
+
+If neither --api-key nor --api-key-path are provided, this script will parse
+`~/.shadeform/api_key` to look for Shadeform API key.
+"""
+import argparse
+import csv
+import json
+import os
+from typing import Dict
+
+import requests
+
+ENDPOINT = 'https://api.shadeform.ai/v1/instances/types'
+DEFAULT_SHADEFORM_API_KEY_PATH = os.path.expanduser('~/.shadeform/api_key')
+
+
+def parse_gpu_info(gpu_type: str, num_gpus: int, ram_per_gpu: int) -> Dict:
+    """Parse GPU information for the catalog."""
+
+    manufacturer = 'NVIDIA'
+    if gpu_type == 'MI300X':
+        manufacturer = 'AMD'
+    elif gpu_type == 'GAUDI2':
+        manufacturer = 'Intel'
+
+    return {
+        'Gpus': [{
+            'Name': gpu_type,
+            'Manufacturer': manufacturer,
+            'Count': float(num_gpus),
+            'MemoryInfo': {
+                'SizeInMiB': ram_per_gpu
+            },
+            'TotalGpuMemoryInMiB': ram_per_gpu * num_gpus
+        }]
+    }
+
+
+def create_catalog(api_key: str, output_path: str) -> None:
+    """Create Shadeform catalog by fetching from API."""
+    headers = {'X-API-KEY': api_key}
+
+    params = {'available': 'true'}
+
+    response = requests.get(ENDPOINT,
+                            headers=headers,
+                            params=params,
+                            timeout=30)
+    response.raise_for_status()
+
+    data = response.json()
+    instance_types = data.get('instance_types', [])
+
+    with open(output_path, mode='w', encoding='utf-8') as f:
+        writer = csv.writer(f, delimiter=',', quotechar='"')
+        writer.writerow([
+            'InstanceType', 'AcceleratorName', 'AcceleratorCount', 'vCPUs',
+            'MemoryGiB', 'Price', 'Region', 'GpuInfo', 'SpotPrice'
+        ])
+
+        for instance in instance_types:
+            config = instance['configuration']
+
+            cloud = instance['cloud']
+            shade_instance_type = instance['shade_instance_type']
+            instance_type = f'{cloud}_{shade_instance_type.replace("_", "-")}'
+            gpu_type = config['gpu_type'].replace('_', '-')
+            gpu_count = float(config['num_gpus'])
+            vcpus = float(config['vcpus'])
+            memory_gb = int(config['memory_in_gb'])
+
+            # Append "B" to instance_type and gpu_type if they end with "G"
+            if instance_type.endswith('G'):
+                instance_type += 'B'
+            if gpu_type.endswith('G'):
+                gpu_type += 'B'
+
+            # Replace "Gx" with "GBx" (case sensitive)
+            if 'Gx' in instance_type:
+                instance_type = instance_type.replace('Gx', 'GBx')
+
+            # Price is in cents per hour, convert to dollars
+            price = float(instance['hourly_price']) / 100
+
+            # Create GPU info
+            gpuinfo = None
+            if gpu_count > 0:
+                gpuinfo_dict = parse_gpu_info(gpu_type, int(gpu_count),
+                                              int(config['vram_per_gpu_in_gb']))
+                gpuinfo = json.dumps(gpuinfo_dict).replace('"', '\'')
+
+            # Write entry for each available region
+            for availability in instance.get('availability', []):
+                if availability['available'] and gpu_count > 0:
+                    region = availability['region']
+                    writer.writerow([
+                        instance_type,
+                        gpu_type,
+                        gpu_count,
+                        vcpus,
+                        memory_gb,
+                        price,
+                        region,
+                        gpuinfo,
+                        ''  # No spot pricing info available
+                    ])
+
+
+def get_api_key(cmdline_args: argparse.Namespace) -> str:
+    """Get Shadeform API key from cmdline or default path."""
+    api_key = cmdline_args.api_key
+    if api_key is None:
+        if cmdline_args.api_key_path is not None:
+            with open(cmdline_args.api_key_path, mode='r',
+                      encoding='utf-8') as f:
+                api_key = f.read().strip()
+        else:
+            # Read from ~/.shadeform/api_key
+            with open(DEFAULT_SHADEFORM_API_KEY_PATH,
+                      mode='r',
+                      encoding='utf-8') as f:
+                api_key = f.read().strip()
+    assert api_key is not None, (
+        f'API key not found. Please provide via --api-key or place in '
+        f'{DEFAULT_SHADEFORM_API_KEY_PATH}')
+    return api_key
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--api-key', help='Shadeform API key.')
+    parser.add_argument('--api-key-path',
+                        help='path of file containing Shadeform API key.')
+    args = parser.parse_args()
+    os.makedirs('shadeform', exist_ok=True)
+    create_catalog(get_api_key(args), 'shadeform/vms.csv')
+    print('Shadeform catalog saved to shadeform/vms.csv')
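
The suffix rewriting above presumably exists because Shadeform reports memory sizes with a bare `G` (e.g. `A100_80G`) while SkyPilot catalogs use `GB`. A standalone sketch of the same normalization, with made-up type names:

def normalize(name: str) -> str:
    """Fetcher's renames: '_' -> '-', trailing 'G' -> 'GB', 'Gx' -> 'GBx'."""
    name = name.replace('_', '-')
    if name.endswith('G'):
        name += 'B'
    if 'Gx' in name:
        name = name.replace('Gx', 'GBx')
    return name


assert normalize('A100_80G') == 'A100-80GB'
assert normalize('H100_80Gx8') == 'H100-80GBx8'

Note also that `hourly_price` arrives in cents, so the fetcher divides by 100 before writing the dollar-denominated `Price` column.
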
sky/catalog/kubernetes_catalog.py CHANGED
@@ -3,6 +3,7 @@
 Kubernetes does not require a catalog of instances, but we need an image catalog
 mapping SkyPilot image tags to corresponding container image tags.
 """
+import collections
 import re
 import typing
 from typing import Dict, List, Optional, Set, Tuple
@@ -167,12 +168,25 @@ def _list_accelerators(
     accelerators_qtys: Set[Tuple[str, int]] = set()
     keys = lf.get_label_keys()
     nodes = kubernetes_utils.get_kubernetes_nodes(context=context)
+
+    # Check if any nodes have accelerators before fetching pods
+    has_accelerator_nodes = False
+    for node in nodes:
+        for key in keys:
+            if key in node.metadata.labels:
+                has_accelerator_nodes = True
+                break
+        if has_accelerator_nodes:
+            break
+
+    # Only fetch pods if we have accelerator nodes and realtime is requested
     pods = None
-    if realtime:
-        # Get the pods to get the real-time GPU usage
+    allocated_qty_by_node: Dict[str, int] = collections.defaultdict(int)
+    if realtime and has_accelerator_nodes:
+        # Get the allocated GPU quantity by each node
         try:
-            pods = kubernetes_utils.get_all_pods_in_kubernetes_cluster(
-                context=context)
+            allocated_qty_by_node = (
+                kubernetes_utils.get_allocated_gpu_qty_by_node(context=context))
        except kubernetes.api_exception() as e:
            if e.status == 403:
                logger.warning(
@@ -191,7 +205,6 @@ def _list_accelerators(
     for node in nodes:
         for key in keys:
             if key in node.metadata.labels:
-                allocated_qty = 0
                 accelerator_name = lf.get_accelerator_from_label_value(
                     node.metadata.labels.get(key))
@@ -251,26 +264,7 @@ def _list_accelerators(
                         total_accelerators_available[accelerator_name] = -1
                     continue
 
-                for pod in pods:
-                    # Get all the pods running on the node
-                    if (pod.spec.node_name == node.metadata.name and
-                            pod.status.phase in ['Running', 'Pending']):
-                        # Skip pods that should not count against GPU count
-                        if (kubernetes_utils.
-                                should_exclude_pod_from_gpu_allocation(pod)):
-                            logger.debug(
-                                f'Excluding pod '
-                                f'{pod.metadata.name} from GPU count '
-                                f'calculations on node {node.metadata.name}')
-                            continue
-                        # Iterate over all the containers in the pod and sum
-                        # the GPU requests
-                        for container in pod.spec.containers:
-                            if container.resources.requests:
-                                allocated_qty += (
-                                    kubernetes_utils.get_node_accelerator_count(
-                                        context, container.resources.requests))
-
+                allocated_qty = allocated_qty_by_node[node.metadata.name]
                 accelerators_available = accelerator_count - allocated_qty
                 # Initialize the total_accelerators_available to make sure the
                 # key exists in the dictionary.
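
The pod scan that used to run once per accelerator node is now a single pass: `get_allocated_gpu_qty_by_node` (added to `sky/provision/kubernetes/utils.py` in this release) returns a per-node map of allocated GPUs. A hedged sketch of the aggregation idea, using simplified dict-shaped pods; the real helper also honors `should_exclude_pod_from_gpu_allocation` and non-NVIDIA resource names:

import collections
from typing import Dict, List


def allocated_gpu_qty_by_node(pods: List[dict]) -> Dict[str, int]:
    """Sum GPU requests of Running/Pending pods, grouped by node name."""
    allocated: Dict[str, int] = collections.defaultdict(int)
    for pod in pods:
        if pod['phase'] not in ('Running', 'Pending'):
            continue
        for container in pod['containers']:
            allocated[pod['node']] += int(
                container.get('requests', {}).get('nvidia.com/gpu', 0))
    return allocated


pods = [
    {'node': 'n1', 'phase': 'Running',
     'containers': [{'requests': {'nvidia.com/gpu': 2}}]},
    {'node': 'n1', 'phase': 'Succeeded',
     'containers': [{'requests': {'nvidia.com/gpu': 4}}]},
]
assert allocated_gpu_qty_by_node(pods) == {'n1': 2}
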