skypilot-nightly 1.0.0.dev20250612__py3-none-any.whl → 1.0.0.dev20250614__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (82) hide show
  1. sky/__init__.py +4 -2
  2. sky/adaptors/hyperbolic.py +8 -0
  3. sky/adaptors/kubernetes.py +3 -2
  4. sky/authentication.py +20 -2
  5. sky/backends/backend_utils.py +11 -3
  6. sky/backends/cloud_vm_ray_backend.py +2 -1
  7. sky/benchmark/benchmark_state.py +2 -1
  8. sky/catalog/data_fetchers/fetch_aws.py +1 -1
  9. sky/catalog/data_fetchers/fetch_hyperbolic.py +136 -0
  10. sky/catalog/data_fetchers/fetch_vast.py +1 -1
  11. sky/catalog/hyperbolic_catalog.py +133 -0
  12. sky/check.py +2 -1
  13. sky/cli.py +1 -1
  14. sky/client/cli.py +1 -1
  15. sky/clouds/__init__.py +2 -0
  16. sky/clouds/cloud.py +1 -1
  17. sky/clouds/gcp.py +1 -1
  18. sky/clouds/hyperbolic.py +276 -0
  19. sky/clouds/kubernetes.py +8 -2
  20. sky/clouds/ssh.py +7 -3
  21. sky/dashboard/out/404.html +1 -1
  22. sky/dashboard/out/_next/static/chunks/37-7754056a4b503e1d.js +6 -0
  23. sky/dashboard/out/_next/static/chunks/600.bd2ed8c076b720ec.js +16 -0
  24. sky/dashboard/out/_next/static/chunks/{856-0776dc6ed6000c39.js → 856-c2c39c0912285e54.js} +1 -1
  25. sky/dashboard/out/_next/static/chunks/938-245c9ac4c9e8bf15.js +1 -0
  26. sky/dashboard/out/_next/static/chunks/{webpack-208a9812ab4f61c9.js → webpack-27de3d9d450d81c6.js} +1 -1
  27. sky/dashboard/out/_next/static/css/{5d71bfc09f184bab.css → 6f84444b8f3c656c.css} +1 -1
  28. sky/dashboard/out/_next/static/{G3DXdMFu2Jzd-Dody9iq1 → nm5jrKpUZh2W0SxzyDKhz}/_buildManifest.js +1 -1
  29. sky/dashboard/out/clusters/[cluster]/[job].html +1 -1
  30. sky/dashboard/out/clusters/[cluster].html +1 -1
  31. sky/dashboard/out/clusters.html +1 -1
  32. sky/dashboard/out/config.html +1 -1
  33. sky/dashboard/out/index.html +1 -1
  34. sky/dashboard/out/infra/[context].html +1 -1
  35. sky/dashboard/out/infra.html +1 -1
  36. sky/dashboard/out/jobs/[job].html +1 -1
  37. sky/dashboard/out/jobs.html +1 -1
  38. sky/dashboard/out/users.html +1 -1
  39. sky/dashboard/out/workspace/new.html +1 -1
  40. sky/dashboard/out/workspaces/[name].html +1 -1
  41. sky/dashboard/out/workspaces.html +1 -1
  42. sky/data/storage.py +2 -2
  43. sky/jobs/state.py +43 -44
  44. sky/provision/__init__.py +1 -0
  45. sky/provision/common.py +1 -1
  46. sky/provision/gcp/config.py +1 -1
  47. sky/provision/hyperbolic/__init__.py +11 -0
  48. sky/provision/hyperbolic/config.py +10 -0
  49. sky/provision/hyperbolic/instance.py +423 -0
  50. sky/provision/hyperbolic/utils.py +373 -0
  51. sky/provision/kubernetes/instance.py +2 -1
  52. sky/provision/kubernetes/utils.py +60 -13
  53. sky/resources.py +2 -2
  54. sky/serve/serve_state.py +81 -15
  55. sky/server/requests/preconditions.py +1 -1
  56. sky/server/requests/requests.py +11 -6
  57. sky/setup_files/dependencies.py +2 -1
  58. sky/skylet/configs.py +26 -19
  59. sky/skylet/constants.py +1 -1
  60. sky/skylet/job_lib.py +3 -5
  61. sky/task.py +1 -1
  62. sky/templates/hyperbolic-ray.yml.j2 +67 -0
  63. sky/templates/kubernetes-ray.yml.j2 +1 -1
  64. sky/users/permission.py +2 -0
  65. sky/utils/common_utils.py +6 -0
  66. sky/utils/context.py +1 -1
  67. sky/utils/infra_utils.py +1 -1
  68. sky/utils/kubernetes/generate_kubeconfig.sh +1 -1
  69. {skypilot_nightly-1.0.0.dev20250612.dist-info → skypilot_nightly-1.0.0.dev20250614.dist-info}/METADATA +2 -1
  70. {skypilot_nightly-1.0.0.dev20250612.dist-info → skypilot_nightly-1.0.0.dev20250614.dist-info}/RECORD +79 -70
  71. sky/dashboard/out/_next/static/chunks/37-d8aebf1683522a0b.js +0 -6
  72. sky/dashboard/out/_next/static/chunks/600.15a0009177e86b86.js +0 -16
  73. sky/dashboard/out/_next/static/chunks/938-ab185187a63f9cdb.js +0 -1
  74. /sky/dashboard/out/_next/static/chunks/{843-6fcc4bf91ac45b39.js → 843-5011affc9540757f.js} +0 -0
  75. /sky/dashboard/out/_next/static/chunks/pages/{_app-7bbd9d39d6f9a98a.js → _app-664031f6ae737f80.js} +0 -0
  76. /sky/dashboard/out/_next/static/chunks/pages/clusters/{[cluster]-451a14e7e755ebbc.js → [cluster]-20210f8cd809063d.js} +0 -0
  77. /sky/dashboard/out/_next/static/chunks/pages/{jobs-fe233baf3d073491.js → jobs-ae7a5e9fa5a5b5f0.js} +0 -0
  78. /sky/dashboard/out/_next/static/{G3DXdMFu2Jzd-Dody9iq1 → nm5jrKpUZh2W0SxzyDKhz}/_ssgManifest.js +0 -0
  79. {skypilot_nightly-1.0.0.dev20250612.dist-info → skypilot_nightly-1.0.0.dev20250614.dist-info}/WHEEL +0 -0
  80. {skypilot_nightly-1.0.0.dev20250612.dist-info → skypilot_nightly-1.0.0.dev20250614.dist-info}/entry_points.txt +0 -0
  81. {skypilot_nightly-1.0.0.dev20250612.dist-info → skypilot_nightly-1.0.0.dev20250614.dist-info}/licenses/LICENSE +0 -0
  82. {skypilot_nightly-1.0.0.dev20250612.dist-info → skypilot_nightly-1.0.0.dev20250614.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,373 @@
1
+ """Hyperbolic API utilities."""
2
+ import enum
3
+ import json
4
+ import os
5
+ import time
6
+ from typing import Any, Dict, Optional, Tuple
7
+
8
+ import requests
9
+
10
+ from sky import authentication
11
+ from sky import sky_logging
12
+ from sky.utils import status_lib
13
+
14
+ #TODO update to prod endpoint
15
+ BASE_URL = 'https://api.hyperbolic.xyz'
16
+ API_KEY_PATH = '~/.hyperbolic/api_key'
17
+
18
+ MAX_RETRIES = 3
19
+ RETRY_DELAY = 2 # seconds
20
+ TIMEOUT = 120
21
+
22
+ logger = sky_logging.init_logger(__name__)
23
+
24
+
25
class HyperbolicError(Exception):
    """Base exception for Hyperbolic API errors.

    All failures surfaced by this module (HTTP errors, unknown statuses,
    launch/terminate failures) are raised as this type so callers can
    catch a single exception class.
    """
    # The docstring alone is a valid class body; the redundant `pass`
    # statement has been removed.
28
+
29
+
30
class HyperbolicInstanceStatus(enum.Enum):
    """Statuses enum for Hyperbolic instances.

    Values mirror the lowercase status strings returned by the
    Hyperbolic marketplace API.
    """
    UNKNOWN = 'unknown'
    ONLINE = 'online'
    OFFLINE = 'offline'
    STARTING = 'starting'
    STOPPING = 'stopping'
    BUSY = 'busy'
    RESTARTING = 'restarting'
    CREATING = 'creating'
    FAILED = 'failed'
    ERROR = 'error'
    TERMINATED = 'terminated'

    @classmethod
    def cluster_status_map(
        cls
    ) -> Dict['HyperbolicInstanceStatus', Optional[status_lib.ClusterStatus]]:
        """Return the mapping from provider statuses to SkyPilot statuses."""
        # Everything that is not ONLINE or TERMINATED is reported as INIT;
        # only a running instance maps to UP, and TERMINATED maps to None
        # (cluster gone).
        init = status_lib.ClusterStatus.INIT
        return {
            cls.CREATING: init,
            cls.STARTING: init,
            cls.ONLINE: status_lib.ClusterStatus.UP,
            cls.FAILED: init,
            cls.ERROR: init,
            cls.RESTARTING: init,
            cls.STOPPING: init,
            cls.UNKNOWN: init,
            cls.BUSY: init,
            cls.OFFLINE: init,
            cls.TERMINATED: None,
        }

    @classmethod
    def from_raw_status(cls, status: str) -> 'HyperbolicInstanceStatus':
        """Parse a raw API status string into an enum member.

        Raises:
            HyperbolicError: if the string is not a known status.
        """
        normalized = status.lower()
        try:
            return cls(normalized)
        except ValueError as e:
            raise HyperbolicError(f'Unknown instance status: {status}') from e

    def to_cluster_status(self) -> Optional[status_lib.ClusterStatus]:
        """Translate this provider status to a SkyPilot cluster status."""
        return self.cluster_status_map().get(self)
73
+
74
+
75
class HyperbolicClient:
    """Client for interacting with the Hyperbolic API.

    Wraps the marketplace HTTP endpoints used by the SkyPilot
    provisioner: launching, listing, terminating and polling instances.
    """

    def __init__(self):
        """Initialize the Hyperbolic client with API credentials.

        Reads the API key from API_KEY_PATH.

        Raises:
            RuntimeError: if the API key file does not exist.
        """
        cred_path = os.path.expanduser(API_KEY_PATH)
        if not os.path.exists(cred_path):
            raise RuntimeError(f'API key not found at {cred_path}')
        with open(cred_path, 'r', encoding='utf-8') as f:
            self.api_key = f.read().strip()
        self.headers = {'Authorization': f'Bearer {self.api_key}'}
        self.api_url = BASE_URL

    def _make_request(
            self,
            method: str,
            endpoint: str,
            payload: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """Make an API request to Hyperbolic.

        Args:
            method: HTTP method, 'GET' or 'POST'.
            endpoint: API path appended to BASE_URL.
            payload: Optional JSON body (POST only).

        Returns:
            The decoded JSON response, or {} for a non-JSON 2xx response.

        Raises:
            HyperbolicError: on unsupported method, non-2xx response, or
                any network/unexpected failure.
        """
        url = f'{BASE_URL}{endpoint}'
        headers = {
            'Authorization': f'Bearer {self.api_key}',
            'Content-Type': 'application/json'
        }

        # Debug logging for request
        logger.debug(f'Making {method} request to {url}')
        if payload:
            logger.debug(f'Request payload: {json.dumps(payload, indent=2)}')

        try:
            if method == 'GET':
                # Fix: use the module-level TIMEOUT constant instead of a
                # hard-coded 120 that could silently drift from it.
                response = requests.get(url, headers=headers, timeout=TIMEOUT)
            elif method == 'POST':
                response = requests.post(url,
                                         headers=headers,
                                         json=payload,
                                         timeout=TIMEOUT)
            else:
                raise HyperbolicError(f'Unsupported HTTP method: {method}')

            # Debug logging for response
            logger.debug(f'Response status code: {response.status_code}')
            logger.debug(f'Response headers: {dict(response.headers)}')

            # Try to parse response as JSON
            try:
                response_data = response.json()
                logger.debug(
                    f'Response body: {json.dumps(response_data, indent=2)}')
            except json.JSONDecodeError as exc:
                # If response is not JSON, use the raw text
                response_text = response.text
                logger.debug(f'Response body (raw): {response_text}')
                if not response.ok:
                    raise HyperbolicError(f'API request failed with status '
                                          f'{response.status_code}: '
                                          f'{response_text}') from exc
                # If response is OK but not JSON, return empty dict
                return {}

            if not response.ok:
                error_msg = response_data.get(
                    'error', response_data.get('message', response.text))
                raise HyperbolicError(
                    f'API request failed with status {response.status_code}: '
                    f'{error_msg}')

            return response_data
        except HyperbolicError:
            # Fix: HyperbolicErrors raised above are already descriptive;
            # previously they fell through to the broad handler below and
            # were wrapped a second time as 'Unexpected error ...'.
            raise
        except requests.exceptions.RequestException as e:
            raise HyperbolicError(f'Request failed: {str(e)}') from e
        except Exception as e:
            raise HyperbolicError(
                f'Unexpected error during API request: {str(e)}') from e

    def launch_instance(self, gpu_model: str, gpu_count: int,
                        name: str) -> Tuple[str, str]:
        """Launch a new instance with the specified configuration.

        Args:
            gpu_model: GPU model name to request.
            gpu_count: Number of GPUs to request.
            name: SkyPilot cluster name, stored in instance metadata.

        Returns:
            Tuple of (instance id, ssh command).

        Raises:
            HyperbolicError: if the launch, readiness wait, or SSH command
                lookup fails.
        """
        # Initialize config with basic instance info
        config = {
            'gpuModel': gpu_model,
            'gpuCount': str(gpu_count),
            'userMetadata': {
                'skypilot': {
                    'cluster_name': name,
                    'launch_time': str(int(time.time()))
                }
            }
        }

        config = authentication.setup_hyperbolic_authentication(config)

        endpoint = '/v2/marketplace/instances/create-cheapest'
        try:
            response = self._make_request('POST', endpoint, payload=config)
            logger.debug(f'Launch response: {json.dumps(response, indent=2)}')

            instance_id = response.get('instanceName')
            if not instance_id:
                logger.error(f'No instance ID in response: {response}')
                raise HyperbolicError('No instance ID returned from API')

            logger.info(f'Successfully launched instance {instance_id}, '
                        f'waiting for it to be ready...')

            # Wait for instance to be ready
            if not self.wait_for_instance(
                    instance_id, HyperbolicInstanceStatus.ONLINE.value):
                raise HyperbolicError(
                    f'Instance {instance_id} failed to reach ONLINE state')

            # Get instance details to get SSH command.
            # NOTE(review): list_instances() keys results by the 'id' field,
            # while instance_id here comes from 'instanceName' — confirm the
            # API returns the same value for both.
            instances = self.list_instances(
                metadata={'skypilot': {
                    'cluster_name': name
                }})
            instance = instances.get(instance_id)
            if not instance:
                raise HyperbolicError(
                    f'Instance {instance_id} not found after launch')

            ssh_command = instance.get('sshCommand')
            if not ssh_command:
                logger.error(
                    f'No SSH command available for instance {instance_id}')
                raise HyperbolicError('No SSH command available for instance')

            logger.info(f'Instance {instance_id} is ready with SSH command')
            return instance_id, ssh_command

        except HyperbolicError as e:
            # Fix: re-raise as-is instead of wrapping a HyperbolicError in
            # another HyperbolicError (duplicated 'Failed to launch' prefix).
            logger.error(f'Failed to launch instance: {str(e)}')
            raise
        except Exception as e:
            logger.error(f'Failed to launch instance: {str(e)}')
            raise HyperbolicError(f'Failed to launch instance: {str(e)}') from e

    def list_instances(
        self,
        status: Optional[str] = None,
        metadata: Optional[Dict[str, Dict[str, str]]] = None
    ) -> Dict[str, Dict[str, Any]]:
        """List all instances, optionally filtered by status and metadata.

        Args:
            status: If given, only include instances whose status equals
                this value (case-insensitive).
            metadata: If given, only include instances whose skypilot
                cluster_name metadata starts with the requested
                cluster_name.

        Returns:
            Mapping of instance id -> flattened instance info dict.

        Raises:
            HyperbolicError: if the API call fails.
        """
        endpoint = '/v1/marketplace/instances'
        try:
            response = self._make_request('GET', endpoint)
            logger.debug(f'Raw API response: {json.dumps(response, indent=2)}')
            instances = {}
            for instance in response.get('instances', []):
                instance_info = instance.get('instance', {})
                current_status = instance_info.get('status')
                logger.debug(
                    f'Instance {instance.get("id")} status: {current_status}')

                # Convert raw status to enum; instances with unparseable
                # statuses are skipped rather than failing the whole listing.
                try:
                    instance_status = HyperbolicInstanceStatus.from_raw_status(
                        current_status)
                except HyperbolicError as e:
                    logger.warning(f'Failed to parse status for instance '
                                   f'{instance.get("id")}: {e}')
                    continue

                if status and instance_status.value != status.lower():
                    continue

                if metadata:
                    skypilot_metadata: Dict[str,
                                            str] = metadata.get('skypilot', {})
                    cluster_name = skypilot_metadata.get('cluster_name', '')
                    instance_skypilot = instance.get('userMetadata',
                                                     {}).get('skypilot', {})
                    # Prefix match, so per-node names derived from the
                    # cluster name still match the requested cluster.
                    if not instance_skypilot.get('cluster_name',
                                                 '').startswith(cluster_name):
                        logger.debug(
                            f'Skipping instance {instance.get("id")} - '
                            f'skypilot metadata {instance_skypilot} '
                            f'does not match {skypilot_metadata}')
                        continue
                    logger.debug(f'Including instance {instance.get("id")} '
                                 f'- skypilot metadata matches')

                hardware = instance_info.get('hardware', {})
                instances[instance.get('id')] = {
                    'id': instance.get('id'),
                    'created': instance.get('created'),
                    'sshCommand': instance.get('sshCommand'),
                    'status': instance_status.value,
                    'gpu_count': instance_info.get('gpu_count'),
                    'gpus_total': instance_info.get('gpus_total'),
                    'owner': instance_info.get('owner'),
                    'cpus': hardware.get('cpus'),
                    'gpus': hardware.get('gpus'),
                    'ram': hardware.get('ram'),
                    'storage': hardware.get('storage'),
                    'pricing': instance_info.get('pricing'),
                    'metadata': instance.get('userMetadata', {})
                }
            return instances
        except Exception as e:
            raise HyperbolicError(f'Failed to list instances: {str(e)}') from e

    def terminate_instance(self, instance_id: str) -> None:
        """Terminate an instance by ID.

        Raises:
            HyperbolicError: if the API call fails.
        """
        endpoint = '/v1/marketplace/instances/terminate'
        data = {'id': instance_id}
        try:
            self._make_request('POST', endpoint, payload=data)
        except Exception as e:
            raise HyperbolicError(
                f'Failed to terminate instance {instance_id}: {str(e)}') from e

    def wait_for_instance(self,
                          instance_id: str,
                          target_status: str,
                          timeout: int = TIMEOUT) -> bool:
        """Wait for an instance to reach a specific status.

        Polls list_instances() every 5 seconds until the instance reports
        `target_status` AND exposes an SSH command, reaches a terminal
        status, or `timeout` seconds elapse.

        Returns:
            True if the target status (with SSH command) was reached;
            False on timeout or a terminal (failed/error/terminated)
            status.
        """
        start_time = time.time()
        target_status_enum = HyperbolicInstanceStatus.from_raw_status(
            target_status)
        logger.info(
            f'Waiting for instance {instance_id} '
            f'to reach status {target_status_enum.value} and have SSH command')

        while True:
            elapsed = time.time() - start_time
            if elapsed >= timeout:
                logger.error(f'Timeout after {int(elapsed)}s '
                             f'waiting for instance {instance_id}')
                return False

            try:
                instances = self.list_instances()
                instance = instances.get(instance_id)

                if not instance:
                    logger.warning(f'Instance {instance_id} not found')
                    time.sleep(5)
                    continue

                current_status = instance.get('status', '').lower()
                ssh_command = instance.get('sshCommand')
                logger.debug(f'Current status: {current_status}, '
                             f'Target status: {target_status_enum.value}, '
                             f'SSH command: {ssh_command}')

                if current_status == target_status_enum.value and ssh_command:
                    logger.info(f'Instance {instance_id} reached '
                                f'target status {target_status_enum.value} '
                                f'and has SSH command after {int(elapsed)}s')
                    return True

                if current_status in ['failed', 'error', 'terminated']:
                    logger.error(f'Instance {instance_id} reached '
                                 f'terminal status: {current_status} '
                                 f'after {int(elapsed)}s')
                    return False

                time.sleep(5)
            except Exception as e:  # pylint: disable=broad-except
                # Best-effort polling: a transient listing error should not
                # abort the wait loop; keep retrying until the timeout.
                logger.warning(
                    f'Error while waiting for instance {instance_id}: {str(e)}')
                time.sleep(5)
335
+
336
+
337
# Module-level singleton client
_client = None


def get_client() -> HyperbolicClient:
    """Get or create the Hyperbolic client singleton."""
    global _client
    if _client is not None:
        return _client
    _client = HyperbolicClient()
    return _client
347
+
348
+
349
# Backward-compatible wrapper functions
def launch_instance(gpu_model: str, gpu_count: int,
                    name: str) -> Tuple[str, str]:
    """Launch a new instance with the specified configuration."""
    client = get_client()
    return client.launch_instance(gpu_model, gpu_count, name)
354
+
355
+
356
def list_instances(
    status: Optional[str] = None,
    metadata: Optional[Dict[str, Dict[str, str]]] = None
) -> Dict[str, Dict[str, Any]]:
    """List all instances, optionally filtered by status and metadata."""
    client = get_client()
    return client.list_instances(status=status, metadata=metadata)
362
+
363
+
364
def terminate_instance(instance_id: str) -> None:
    """Terminate an instance by ID."""
    client = get_client()
    return client.terminate_instance(instance_id)
367
+
368
+
369
def wait_for_instance(instance_id: str,
                      target_status: str,
                      timeout: int = TIMEOUT) -> bool:
    """Wait for an instance to reach a specific status."""
    client = get_client()
    return client.wait_for_instance(instance_id, target_status, timeout)
@@ -1277,7 +1277,8 @@ def query_instances(
1277
1277
  except kubernetes.max_retry_error():
1278
1278
  with ux_utils.print_exception_no_traceback():
1279
1279
  if is_ssh:
1280
- node_pool = context.lstrip('ssh-') if context else ''
1280
+ node_pool = common_utils.removeprefix(context,
1281
+ 'ssh-') if context else ''
1281
1282
  msg = (
1282
1283
  f'Cannot connect to SSH Node Pool {node_pool}. '
1283
1284
  'Please check if the SSH Node Pool is up and accessible. '
@@ -133,6 +133,30 @@ DEFAULT_MAX_RETRIES = 3
133
133
  DEFAULT_RETRY_INTERVAL_SECONDS = 1
134
134
 
135
135
 
136
def normalize_tpu_accelerator_name(accelerator: str) -> Tuple[str, int]:
    """Normalize TPU names to the k8s-compatible name and extract count.

    Examples:
        'tpu-v6e-8' -> ('tpu-v6e-slice', 8)
        'tpu-v5litepod-4' -> ('tpu-v5-lite-podslice', 4)

    Non-TPU (or already-normalized) names are returned unchanged with a
    count of 1.
    """
    # (GKE label value, GCP-style name pattern with trailing chip count)
    k8s_name_patterns = (
        ('tpu-v6e-slice', r'^tpu-v6e-(\d+)$'),
        ('tpu-v5p-slice', r'^tpu-v5p-(\d+)$'),
        ('tpu-v5-lite-podslice', r'^tpu-v5litepod-(\d+)$'),
        ('tpu-v5-lite-device', r'^tpu-v5lite-(\d+)$'),
        ('tpu-v4-podslice', r'^tpu-v4-(\d+)$'),
    )

    for k8s_name, pattern in k8s_name_patterns:
        matched = re.match(pattern, accelerator)
        if matched is not None:
            return k8s_name, int(matched.group(1))

    # Default fallback: pass the name through untouched.
    return accelerator, 1
158
+
159
+
136
160
  def _retry_on_error(max_retries=DEFAULT_MAX_RETRIES,
137
161
  retry_interval=DEFAULT_RETRY_INTERVAL_SECONDS,
138
162
  resource_type: Optional[str] = None):
@@ -427,6 +451,7 @@ class GKELabelFormatter(GPULabelFormatter):
427
451
 
428
452
  e.g. tpu-v5-lite-podslice:8 -> '2x4'
429
453
  """
454
+ acc_type, acc_count = normalize_tpu_accelerator_name(acc_type)
430
455
  count_to_topology = cls.GKE_TPU_TOPOLOGIES.get(acc_type,
431
456
  {}).get(acc_count, None)
432
457
  if count_to_topology is None:
@@ -461,6 +486,14 @@ class GKELabelFormatter(GPULabelFormatter):
461
486
  raise ValueError(
462
487
  f'Invalid accelerator name in GKE cluster: {value}')
463
488
 
489
+ @classmethod
490
+ def validate_label_value(cls, value: str) -> Tuple[bool, str]:
491
+ try:
492
+ _ = cls.get_accelerator_from_label_value(value)
493
+ return True, ''
494
+ except ValueError as e:
495
+ return False, str(e)
496
+
464
497
 
465
498
  class GFDLabelFormatter(GPULabelFormatter):
466
499
  """GPU Feature Discovery label formatter
@@ -565,17 +598,29 @@ def detect_gpu_label_formatter(
565
598
  for label, value in node.metadata.labels.items():
566
599
  node_labels[node.metadata.name].append((label, value))
567
600
 
568
- label_formatter = None
569
-
570
601
  # Check if the node labels contain any of the GPU label prefixes
571
602
  for lf in LABEL_FORMATTER_REGISTRY:
603
+ skip = False
572
604
  for _, label_list in node_labels.items():
573
- for label, _ in label_list:
605
+ for label, value in label_list:
574
606
  if lf.match_label_key(label):
575
- label_formatter = lf()
576
- return label_formatter, node_labels
607
+ valid, reason = lf.validate_label_value(value)
608
+ if valid:
609
+ return lf(), node_labels
610
+ else:
611
+ logger.warning(f'GPU label {label} matched for label '
612
+ f'formatter {lf.__class__.__name__}, '
613
+ f'but has invalid value {value}. '
614
+ f'Reason: {reason}. '
615
+ 'Skipping...')
616
+ skip = True
617
+ break
618
+ if skip:
619
+ break
620
+ if skip:
621
+ continue
577
622
 
578
- return label_formatter, node_labels
623
+ return None, node_labels
579
624
 
580
625
 
581
626
  class Autoscaler:
@@ -754,6 +799,8 @@ class GKEAutoscaler(Autoscaler):
754
799
  f'checking {node_pool_name} for TPU {requested_acc_type}:'
755
800
  f'{requested_acc_count}')
756
801
  if 'resourceLabels' in node_config:
802
+ requested_acc_type, requested_acc_count = normalize_tpu_accelerator_name(
803
+ requested_acc_type)
757
804
  accelerator_exists = cls._node_pool_has_tpu_capacity(
758
805
  node_config['resourceLabels'], machine_type,
759
806
  requested_acc_type, requested_acc_count)
@@ -993,7 +1040,7 @@ def check_instance_fits(context: Optional[str],
993
1040
  'Maximum resources found on a single node: '
994
1041
  f'{max_cpu} CPUs, {common_utils.format_float(max_mem)}G Memory')
995
1042
 
996
- def check_tpu_fits(candidate_instance_type: 'KubernetesInstanceType',
1043
+ def check_tpu_fits(acc_type: str, acc_count: int,
997
1044
  node_list: List[Any]) -> Tuple[bool, Optional[str]]:
998
1045
  """Checks if the instance fits on the cluster based on requested TPU.
999
1046
 
@@ -1003,8 +1050,6 @@ def check_instance_fits(context: Optional[str],
1003
1050
  node (node_tpu_chip_count) and the total TPU chips across the entire
1004
1051
  podslice (topology_chip_count) are correctly handled.
1005
1052
  """
1006
- acc_type = candidate_instance_type.accelerator_type
1007
- acc_count = candidate_instance_type.accelerator_count
1008
1053
  tpu_list_in_cluster = []
1009
1054
  for node in node_list:
1010
1055
  if acc_type == node.metadata.labels[
@@ -1055,7 +1100,8 @@ def check_instance_fits(context: Optional[str],
1055
1100
  if is_tpu_on_gke(acc_type):
1056
1101
  # If requested accelerator is a TPU type, check if the cluster
1057
1102
  # has sufficient TPU resource to meet the requirement.
1058
- fits, reason = check_tpu_fits(k8s_instance_type, gpu_nodes)
1103
+ acc_type, acc_count = normalize_tpu_accelerator_name(acc_type)
1104
+ fits, reason = check_tpu_fits(acc_type, acc_count, gpu_nodes)
1059
1105
  if reason is not None:
1060
1106
  return fits, reason
1061
1107
  else:
@@ -1141,8 +1187,8 @@ def get_accelerator_label_key_values(
1141
1187
 
1142
1188
  is_ssh_node_pool = context.startswith('ssh-') if context else False
1143
1189
  cloud_name = 'SSH Node Pool' if is_ssh_node_pool else 'Kubernetes cluster'
1144
- context_display_name = context.lstrip('ssh-') if (
1145
- context and is_ssh_node_pool) else context
1190
+ context_display_name = common_utils.removeprefix(
1191
+ context, 'ssh-') if (context and is_ssh_node_pool) else context
1146
1192
 
1147
1193
  autoscaler_type = get_autoscaler_type()
1148
1194
  if autoscaler_type is not None:
@@ -2911,7 +2957,8 @@ def get_skypilot_pods(context: Optional[str] = None) -> List[Any]:
2911
2957
 
2912
2958
  def is_tpu_on_gke(accelerator: str) -> bool:
2913
2959
  """Determines if the given accelerator is a TPU supported on GKE."""
2914
- return accelerator in GKE_TPU_ACCELERATOR_TO_GENERATION
2960
+ normalized, _ = normalize_tpu_accelerator_name(accelerator)
2961
+ return normalized in GKE_TPU_ACCELERATOR_TO_GENERATION
2915
2962
 
2916
2963
 
2917
2964
  def get_node_accelerator_count(attribute_dict: dict) -> int:
sky/resources.py CHANGED
@@ -480,7 +480,7 @@ class Resources:
480
480
  if self.region is not None:
481
481
  region_name = self.region
482
482
  if self.region.startswith('ssh-'):
483
- region_name = self.region.lstrip('ssh-')
483
+ region_name = common_utils.removeprefix(self.region, 'ssh-')
484
484
  region_str = f', region={region_name}'
485
485
  zone_str = ''
486
486
  if self.zone is not None:
@@ -1868,7 +1868,7 @@ class Resources:
1868
1868
  not isinstance(accelerators, set)):
1869
1869
  with ux_utils.print_exception_no_traceback():
1870
1870
  raise ValueError(
1871
- 'Cannot specify multiple "accelerators" with prefered '
1871
+ 'Cannot specify multiple "accelerators" with preferred '
1872
1872
  'order (i.e., list of accelerators) with "any_of" '
1873
1873
  'in resources.')
1874
1874