skypilot-nightly 1.0.0.dev20241227__py3-none-any.whl → 1.0.0.dev20250124__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (92) hide show
  1. sky/__init__.py +2 -2
  2. sky/adaptors/common.py +15 -9
  3. sky/adaptors/do.py +20 -0
  4. sky/adaptors/oci.py +32 -1
  5. sky/authentication.py +20 -8
  6. sky/backends/backend_utils.py +44 -0
  7. sky/backends/cloud_vm_ray_backend.py +202 -41
  8. sky/backends/wheel_utils.py +4 -1
  9. sky/check.py +31 -1
  10. sky/cli.py +39 -43
  11. sky/cloud_stores.py +71 -2
  12. sky/clouds/__init__.py +2 -0
  13. sky/clouds/aws.py +137 -50
  14. sky/clouds/cloud.py +4 -0
  15. sky/clouds/do.py +303 -0
  16. sky/clouds/gcp.py +9 -0
  17. sky/clouds/kubernetes.py +3 -3
  18. sky/clouds/oci.py +20 -9
  19. sky/clouds/service_catalog/__init__.py +7 -3
  20. sky/clouds/service_catalog/constants.py +1 -1
  21. sky/clouds/service_catalog/data_fetchers/fetch_gcp.py +10 -51
  22. sky/clouds/service_catalog/do_catalog.py +111 -0
  23. sky/clouds/service_catalog/kubernetes_catalog.py +14 -0
  24. sky/clouds/utils/oci_utils.py +15 -2
  25. sky/core.py +8 -5
  26. sky/data/data_transfer.py +37 -0
  27. sky/data/data_utils.py +19 -4
  28. sky/data/mounting_utils.py +99 -15
  29. sky/data/storage.py +961 -130
  30. sky/global_user_state.py +1 -1
  31. sky/jobs/__init__.py +2 -0
  32. sky/jobs/constants.py +8 -7
  33. sky/jobs/controller.py +19 -22
  34. sky/jobs/core.py +46 -2
  35. sky/jobs/recovery_strategy.py +114 -143
  36. sky/jobs/scheduler.py +283 -0
  37. sky/jobs/state.py +290 -21
  38. sky/jobs/utils.py +346 -95
  39. sky/optimizer.py +6 -3
  40. sky/provision/aws/config.py +59 -29
  41. sky/provision/azure/instance.py +1 -1
  42. sky/provision/do/__init__.py +11 -0
  43. sky/provision/do/config.py +14 -0
  44. sky/provision/do/constants.py +10 -0
  45. sky/provision/do/instance.py +287 -0
  46. sky/provision/do/utils.py +306 -0
  47. sky/provision/docker_utils.py +22 -11
  48. sky/provision/gcp/instance_utils.py +15 -9
  49. sky/provision/kubernetes/instance.py +3 -2
  50. sky/provision/kubernetes/utils.py +125 -20
  51. sky/provision/oci/query_utils.py +17 -14
  52. sky/provision/provisioner.py +0 -1
  53. sky/provision/runpod/instance.py +10 -1
  54. sky/provision/runpod/utils.py +170 -13
  55. sky/resources.py +1 -1
  56. sky/serve/autoscalers.py +359 -301
  57. sky/serve/controller.py +10 -8
  58. sky/serve/core.py +84 -7
  59. sky/serve/load_balancer.py +27 -10
  60. sky/serve/replica_managers.py +1 -3
  61. sky/serve/serve_state.py +10 -5
  62. sky/serve/serve_utils.py +28 -1
  63. sky/serve/service.py +4 -3
  64. sky/serve/service_spec.py +31 -0
  65. sky/setup_files/dependencies.py +4 -1
  66. sky/skylet/constants.py +8 -4
  67. sky/skylet/events.py +7 -3
  68. sky/skylet/job_lib.py +10 -30
  69. sky/skylet/log_lib.py +8 -8
  70. sky/skylet/log_lib.pyi +3 -0
  71. sky/skylet/providers/command_runner.py +5 -7
  72. sky/skylet/skylet.py +1 -1
  73. sky/task.py +28 -1
  74. sky/templates/do-ray.yml.j2 +98 -0
  75. sky/templates/jobs-controller.yaml.j2 +41 -7
  76. sky/templates/runpod-ray.yml.j2 +13 -0
  77. sky/templates/sky-serve-controller.yaml.j2 +4 -0
  78. sky/usage/usage_lib.py +10 -2
  79. sky/utils/accelerator_registry.py +12 -8
  80. sky/utils/controller_utils.py +114 -39
  81. sky/utils/db_utils.py +18 -4
  82. sky/utils/kubernetes/deploy_remote_cluster.sh +5 -5
  83. sky/utils/log_utils.py +2 -0
  84. sky/utils/resources_utils.py +25 -21
  85. sky/utils/schemas.py +27 -0
  86. sky/utils/subprocess_utils.py +54 -10
  87. {skypilot_nightly-1.0.0.dev20241227.dist-info → skypilot_nightly-1.0.0.dev20250124.dist-info}/METADATA +23 -4
  88. {skypilot_nightly-1.0.0.dev20241227.dist-info → skypilot_nightly-1.0.0.dev20250124.dist-info}/RECORD +92 -82
  89. {skypilot_nightly-1.0.0.dev20241227.dist-info → skypilot_nightly-1.0.0.dev20250124.dist-info}/WHEEL +1 -1
  90. {skypilot_nightly-1.0.0.dev20241227.dist-info → skypilot_nightly-1.0.0.dev20250124.dist-info}/LICENSE +0 -0
  91. {skypilot_nightly-1.0.0.dev20241227.dist-info → skypilot_nightly-1.0.0.dev20250124.dist-info}/entry_points.txt +0 -0
  92. {skypilot_nightly-1.0.0.dev20241227.dist-info → skypilot_nightly-1.0.0.dev20250124.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,306 @@
1
+ """DigitalOcean API client wrapper for SkyPilot.
2
+
3
+ Example usage of `pydo` client library was mostly taken from here:
4
+ https://github.com/digitalocean/pydo/blob/main/examples/poc_droplets_volumes_sshkeys.py
5
+ """
6
+
7
+ import copy
8
+ import os
9
+ import typing
10
+ from typing import Any, Dict, List, Optional
11
+ import urllib
12
+ import uuid
13
+
14
+ from sky import sky_logging
15
+ from sky.adaptors import do
16
+ from sky.provision import common
17
+ from sky.provision import constants as provision_constants
18
+ from sky.provision.do import constants
19
+ from sky.utils import common_utils
20
+
21
+ if typing.TYPE_CHECKING:
22
+ from sky import resources
23
+ from sky import status_lib
24
+
25
+ logger = sky_logging.init_logger(__name__)
26
+
27
+ POSSIBLE_CREDENTIALS_PATHS = [
28
+ os.path.expanduser(
29
+ '~/Library/Application Support/doctl/config.yaml'), # OS X
30
+ os.path.expanduser(
31
+ os.path.join(os.getenv('XDG_CONFIG_HOME', '~/.config/'),
32
+ 'doctl/config.yaml')), # Linux
33
+ ]
34
+ INITIAL_BACKOFF_SECONDS = 10
35
+ MAX_BACKOFF_FACTOR = 10
36
+ MAX_ATTEMPTS = 6
37
+ SSH_KEY_NAME_ON_DO = f'sky-key-{common_utils.get_user_hash()}'
38
+
39
+ CREDENTIALS_PATH = '~/.config/doctl/config.yaml'
40
+ _client = None
41
+ _ssh_key_id = None
42
+
43
+
44
class DigitalOceanError(Exception):
    """Raised when a DigitalOcean API call fails or credentials are invalid."""
46
+
47
+
48
def _init_client():
    """Initialize the module-level pydo client from doctl credentials.

    Searches the known doctl ``config.yaml`` locations, then tries the
    default ``access-token`` followed by every token in ``auth-contexts``
    until one authenticates (verified with a cheap ``droplets.list()`` call).

    Returns:
        The authenticated ``pydo.Client``.

    Raises:
        DigitalOceanError: If no credentials file exists or no token works.
    """
    global _client, CREDENTIALS_PATH
    assert _client is None
    CREDENTIALS_PATH = None
    credentials_found = 0
    for path in POSSIBLE_CREDENTIALS_PATHS:
        if os.path.exists(path):
            CREDENTIALS_PATH = path
            credentials_found += 1
            logger.debug(f'Digital Ocean credential path found at {path}')
    # Bug fix: original condition was inverted (`if not credentials_found > 1`),
    # which emitted the "more than 1" message when at most one file was found.
    if credentials_found > 1:
        logger.debug('more than 1 credential file found')
    if CREDENTIALS_PATH is None:
        raise DigitalOceanError(
            'no credentials file found from '
            f'the following paths {POSSIBLE_CREDENTIALS_PATHS}')

    # attempt default context first
    credentials = common_utils.read_yaml(CREDENTIALS_PATH)
    default_token = credentials.get('access-token', None)
    if default_token is not None:
        try:
            test_client = do.pydo.Client(token=default_token)
            # Cheap call to verify the token actually authenticates.
            test_client.droplets.list()
            logger.debug('trying `default` context')
            _client = test_client
            return _client
        except do.exceptions().HttpResponseError:
            pass

    auth_contexts = credentials.get('auth-contexts', None)
    if auth_contexts is not None:
        for context, api_token in auth_contexts.items():
            try:
                test_client = do.pydo.Client(token=api_token)
                test_client.droplets.list()
                logger.debug(f'using {context} context')
                _client = test_client
                break
            except do.exceptions().HttpResponseError:
                continue
        else:
            raise DigitalOceanError(
                'no valid api tokens found try '
                'setting a new API token with `doctl auth init`')
    # Bug fix: previously, when the default token failed and no
    # `auth-contexts` section existed, this fell through and returned None,
    # which would crash callers of `client()` later with AttributeError.
    if _client is None:
        raise DigitalOceanError(
            'no valid api tokens found try '
            'setting a new API token with `doctl auth init`')
    return _client
94
+
95
+
96
def client():
    """Return the memoized pydo client, constructing it on first use."""
    global _client
    if _client is not None:
        return _client
    _client = _init_client()
    return _client
101
+
102
+
103
def ssh_key_id(public_key: str):
    """Return the DO ssh-key record matching `public_key`, creating it if absent.

    Pages through the account's ssh keys looking for an exact public-key
    match; if none exists, registers the key under SSH_KEY_NAME_ON_DO.
    The result is cached in the module-level `_ssh_key_id`.
    """
    global _ssh_key_id
    if _ssh_key_id is not None:
        return _ssh_key_id

    page = 1
    while True:
        try:
            resp = client().ssh_keys.list(per_page=50, page=page)
        except do.exceptions().HttpResponseError as err:
            raise DigitalOceanError(
                f'Error: {err.status_code} {err.reason}: '
                f'{err.error.message}') from err
        for ssh_key in resp['ssh_keys']:
            if ssh_key['public_key'] == public_key:
                _ssh_key_id = ssh_key
                return _ssh_key_id
        # Follow the paginated `links` section to the next page, if any.
        links = resp['links']
        if 'pages' in links and 'next' in links['pages']:
            parsed_url = urllib.parse.urlparse(links['pages']['next'])
            page = int(urllib.parse.parse_qs(parsed_url.query)['page'][0])
        else:
            break

    # Key not found anywhere: register it.
    request = {
        'public_key': public_key,
        'name': SSH_KEY_NAME_ON_DO,
    }
    _ssh_key_id = client().ssh_keys.create(body=request)['ssh_key']
    return _ssh_key_id
134
+
135
+
136
def _create_volume(request: Dict[str, Any]) -> Dict[str, Any]:
    """Create a block-storage volume and return its metadata dict."""
    try:
        created = client().volumes.create(body=request)
        volume = created['volume']
    except do.exceptions().HttpResponseError as err:
        raise DigitalOceanError(
            f'Error: {err.status_code} {err.reason}: {err.error.message}'
        ) from err
    return volume
146
+
147
+
148
def _create_droplet(request: Dict[str, Any]) -> Dict[str, Any]:
    """Create a droplet and return its freshly fetched metadata dict."""
    try:
        create_resp = client().droplets.create(body=request)
        new_droplet_id = create_resp['droplet']['id']
        # Re-fetch so callers receive the full, up-to-date droplet record.
        droplet = client().droplets.get(new_droplet_id)['droplet']
    except do.exceptions().HttpResponseError as err:
        raise DigitalOceanError(
            f'Error: {err.status_code} {err.reason}: {err.error.message}'
        ) from err
    return droplet
160
+
161
+
162
def create_instance(region: str, cluster_name_on_cloud: str, instance_type: str,
                    config: common.ProvisionConfig) -> Dict[str, Any]:
    """Creates a instance and mounts the requested block storage

    Args:
        region (str): instance region
        cluster_name_on_cloud (str): cluster name used to tag and name the
            instance
        instance_type (str): droplet size / instance type, appended to the
            generated instance name
        config (common.ProvisionConfig): provisioner configuration

    Returns:
        Dict[str, Any]: instance metadata
    """
    # sort tags by key to support deterministic unit test stubbing
    tags = dict(sorted(copy.deepcopy(config.tags).items()))
    tags = {
        'Name': cluster_name_on_cloud,
        provision_constants.TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud,
        provision_constants.TAG_SKYPILOT_CLUSTER_NAME: cluster_name_on_cloud,
        **tags
    }
    # DO tags are flat strings; encode key/value pairs as 'key:value'.
    tags = [f'{key}:{value}' for key, value in tags.items()]
    default_image = constants.GPU_IMAGES.get(
        config.node_config['InstanceType'],
        'gpu-h100x1-base',
    )
    image_id = config.node_config['ImageId']
    image_id = image_id if image_id is not None else default_image
    # Short uuid suffix keeps names unique across nodes of the same cluster.
    instance_name = (f'{cluster_name_on_cloud}-'
                     f'{uuid.uuid4().hex[:4]}-{instance_type}')
    instance_request = {
        'name': instance_name,
        'region': region,
        'size': config.node_config['InstanceType'],
        'image': image_id,
        'ssh_keys': [
            ssh_key_id(
                config.authentication_config['ssh_public_key'])['fingerprint']
        ],
        'tags': tags,
    }
    instance = _create_droplet(instance_request)

    volume_request = {
        'size_gigabytes': config.node_config['DiskSize'],
        'name': instance_name,
        'region': region,
        'filesystem_type': 'ext4',
        'tags': tags
    }
    volume = _create_volume(volume_request)

    attach_request = {'type': 'attach', 'droplet_id': instance['id']}
    try:
        client().volume_actions.post_by_id(volume['id'], attach_request)
    except do.exceptions().HttpResponseError as err:
        raise DigitalOceanError(
            f'Error: {err.status_code} {err.reason}: {err.error.message}'
        ) from err
    logger.debug(f'{instance_name} created')
    return instance
222
+
223
+
224
def start_instance(instance: Dict[str, Any]):
    """Power on a stopped droplet."""
    power_on_body = {'type': 'power_on'}
    try:
        client().droplet_actions.post(droplet_id=instance['id'],
                                      body=power_on_body)
    except do.exceptions().HttpResponseError as err:
        raise DigitalOceanError(
            f'Error: {err.status_code} {err.reason}: {err.error.message}'
        ) from err
232
+
233
+
234
def stop_instance(instance: Dict[str, Any]):
    """Gracefully shut down a running droplet."""
    shutdown_body = {'type': 'shutdown'}
    try:
        client().droplet_actions.post(droplet_id=instance['id'],
                                      body=shutdown_body)
    except do.exceptions().HttpResponseError as err:
        raise DigitalOceanError(
            f'Error: {err.status_code} {err.reason}: {err.error.message}'
        ) from err
244
+
245
+
246
def down_instance(instance: Dict[str, Any]):
    """Destroy a droplet together with its attached resources.

    We use dangerous destroy to atomically delete
    block storage and instance for autodown.
    """
    try:
        client().droplets.destroy_with_associated_resources_dangerous(
            droplet_id=instance['id'], x_dangerous=True)
    except do.exceptions().HttpResponseError as err:
        # A destroy already running for this droplet is not an error.
        if 'a destroy is already in progress' in err.error.message:
            return
        raise DigitalOceanError(
            f'Error: {err.status_code} {err.reason}: {err.error.message}'
        ) from err
258
+
259
+
260
def rename_instance(instance: Dict[str, Any], new_name: str):
    """Rename a droplet to `new_name`."""
    rename_body = {'type': 'rename', 'name': new_name}
    try:
        client().droplet_actions.rename(droplet=instance['id'],
                                        body=rename_body)
    except do.exceptions().HttpResponseError as err:
        raise DigitalOceanError(
            f'Error: {err.status_code} {err.reason}: {err.error.message}'
        ) from err
271
+
272
+
273
def filter_instances(
    cluster_name_on_cloud: str,
    status_filters: Optional[List[str]] = None) -> Dict[str, Any]:
    """Returns Dict mapping instance name
    to instance metadata filtered by status
    """

    matches: Dict[str, Any] = {}
    # Droplets belonging to the cluster carry this flat 'key:value' tag.
    cluster_tag = (f'{provision_constants.TAG_SKYPILOT_CLUSTER_NAME}:'
                   f'{cluster_name_on_cloud}')
    page = 1
    while True:
        try:
            resp = client().droplets.list(tag_name=cluster_tag,
                                          per_page=50,
                                          page=page)
        except do.exceptions().HttpResponseError as err:
            raise DigitalOceanError(
                f'Error: {err.status_code} {err.reason}: {err.error.message}'
            ) from err
        for droplet in resp['droplets']:
            if status_filters is None or droplet['status'] in status_filters:
                matches[droplet['name']] = droplet
        # Follow the paginated `links` section to the next page, if any.
        links = resp['links']
        if 'pages' in links and 'next' in links['pages']:
            parsed_url = urllib.parse.urlparse(links['pages']['next'])
            page = int(urllib.parse.parse_qs(parsed_url.query)['page'][0])
        else:
            break
    return matches
@@ -38,6 +38,13 @@ class DockerLoginConfig:
38
38
  password: str
39
39
  server: str
40
40
 
41
+ def format_image(self, image: str) -> str:
42
+ """Format the image name with the server prefix."""
43
+ server_prefix = f'{self.server}/'
44
+ if not image.startswith(server_prefix):
45
+ return f'{server_prefix}{image}'
46
+ return image
47
+
41
48
  @classmethod
42
49
  def from_env_vars(cls, d: Dict[str, str]) -> 'DockerLoginConfig':
43
50
  return cls(
@@ -220,9 +227,7 @@ class DockerInitializer:
220
227
  wait_for_docker_daemon=True)
221
228
  # We automatically add the server prefix to the image name if
222
229
  # the user did not add it.
223
- server_prefix = f'{docker_login_config.server}/'
224
- if not specific_image.startswith(server_prefix):
225
- specific_image = f'{server_prefix}{specific_image}'
230
+ specific_image = docker_login_config.format_image(specific_image)
226
231
 
227
232
  if self.docker_config.get('pull_before_run', True):
228
233
  assert specific_image, ('Image must be included in config if ' +
@@ -338,14 +343,20 @@ class DockerInitializer:
338
343
  no_exist = 'NoExist'
339
344
  # SkyPilot: Add the current user to the docker group first (if needed),
340
345
  # before checking if docker is installed to avoid permission issues.
341
- cleaned_output = self._run(
342
- 'id -nG $USER | grep -qw docker || '
343
- 'sudo usermod -aG docker $USER > /dev/null 2>&1;'
344
- f'command -v {self.docker_cmd} || echo {no_exist!r}')
345
- if no_exist in cleaned_output or 'docker' not in cleaned_output:
346
- logger.error(
347
- f'{self.docker_cmd.capitalize()} not installed. Please use an '
348
- f'image with {self.docker_cmd.capitalize()} installed.')
346
+ docker_cmd = ('id -nG $USER | grep -qw docker || '
347
+ 'sudo usermod -aG docker $USER > /dev/null 2>&1;'
348
+ f'command -v {self.docker_cmd} || echo {no_exist!r}')
349
+ cleaned_output = self._run(docker_cmd)
350
+ timeout = 60 * 10 # 10 minute timeout
351
+ start = time.time()
352
+ while no_exist in cleaned_output or 'docker' not in cleaned_output:
353
+ if time.time() - start > timeout:
354
+ logger.error(
355
+ f'{self.docker_cmd.capitalize()} not installed. Please use '
356
+ f'an image with {self.docker_cmd.capitalize()} installed.')
357
+ return
358
+ time.sleep(5)
359
+ cleaned_output = self._run(docker_cmd)
349
360
 
350
361
  def _check_container_status(self):
351
362
  if self.initialized:
@@ -38,7 +38,7 @@ _FIREWALL_RESOURCE_NOT_FOUND_PATTERN = re.compile(
38
38
  r'The resource \'projects/.*/global/firewalls/.*\' was not found')
39
39
 
40
40
 
41
- def _retry_on_http_exception(
41
+ def _retry_on_gcp_http_exception(
42
42
  regex: Optional[str] = None,
43
43
  max_retries: int = GCP_MAX_RETRIES,
44
44
  retry_interval_s: int = GCP_RETRY_INTERVAL_SECONDS,
@@ -49,17 +49,18 @@ def _retry_on_http_exception(
49
49
 
50
50
  @functools.wraps(func)
51
51
  def wrapper(*args, **kwargs):
52
- exception_type = gcp.http_error_exception()
53
52
 
54
53
  def try_catch_exc():
55
54
  try:
56
55
  value = func(*args, **kwargs)
57
56
  return value
58
57
  except Exception as e: # pylint: disable=broad-except
59
- if not isinstance(e, exception_type) or (
60
- regex and not re.search(regex, str(e))):
61
- raise
62
- return e
58
+ if (isinstance(e, gcp.http_error_exception()) and
59
+ (regex is None or re.search(regex, str(e)))):
60
+ logger.error(
61
+ f'Retrying for gcp.http_error_exception: {e}')
62
+ return e
63
+ raise
63
64
 
64
65
  for _ in range(max_retries):
65
66
  ret = try_catch_exc()
@@ -431,7 +432,7 @@ class GCPComputeInstance(GCPInstance):
431
432
  logger.debug(
432
433
  f'Waiting GCP operation {operation["name"]} to be ready ...')
433
434
 
434
- @_retry_on_http_exception(
435
+ @_retry_on_gcp_http_exception(
435
436
  f'Failed to wait for operation {operation["name"]}')
436
437
  def call_operation(fn, timeout: int):
437
438
  request = fn(
@@ -613,6 +614,11 @@ class GCPComputeInstance(GCPInstance):
613
614
  return operation
614
615
 
615
616
  @classmethod
617
+ # When there is a cloud function running in parallel to set labels for
618
+ # newly created instances, it may fail with the following error:
619
+ # "Labels fingerprint either invalid or resource labels have changed"
620
+ # We should retry until the labels are set successfully.
621
+ @_retry_on_gcp_http_exception('Labels fingerprint either invalid')
616
622
  def set_labels(cls, project_id: str, availability_zone: str, node_id: str,
617
623
  labels: dict) -> None:
618
624
  node = cls.load_resource().instances().get(
@@ -1211,7 +1217,7 @@ class GCPTPUVMInstance(GCPInstance):
1211
1217
  """Poll for TPU operation until finished."""
1212
1218
  del project_id, region, zone # unused
1213
1219
 
1214
- @_retry_on_http_exception(
1220
+ @_retry_on_gcp_http_exception(
1215
1221
  f'Failed to wait for operation {operation["name"]}')
1216
1222
  def call_operation(fn, timeout: int):
1217
1223
  request = fn(name=operation['name'])
@@ -1379,7 +1385,7 @@ class GCPTPUVMInstance(GCPInstance):
1379
1385
  f'Failed to get VPC name for instance {instance}') from e
1380
1386
 
1381
1387
  @classmethod
1382
- @_retry_on_http_exception('unable to queue the operation')
1388
+ @_retry_on_gcp_http_exception('unable to queue the operation')
1383
1389
  def set_labels(cls, project_id: str, availability_zone: str, node_id: str,
1384
1390
  labels: dict) -> None:
1385
1391
  while True:
@@ -804,7 +804,8 @@ def _create_pods(region: str, cluster_name_on_cloud: str,
804
804
 
805
805
  # Create pods in parallel
806
806
  pods = subprocess_utils.run_in_parallel(_create_pod_thread,
807
- range(to_start_count), _NUM_THREADS)
807
+ list(range(to_start_count)),
808
+ _NUM_THREADS)
808
809
 
809
810
  # Process created pods
810
811
  for pod in pods:
@@ -975,7 +976,7 @@ def terminate_instances(
975
976
  _terminate_node(namespace, context, pod_name)
976
977
 
977
978
  # Run pod termination in parallel
978
- subprocess_utils.run_in_parallel(_terminate_pod_thread, pods.items(),
979
+ subprocess_utils.run_in_parallel(_terminate_pod_thread, list(pods.items()),
979
980
  _NUM_THREADS)
980
981
 
981
982
 
@@ -7,6 +7,7 @@ import os
7
7
  import re
8
8
  import shutil
9
9
  import subprocess
10
+ import time
10
11
  import typing
11
12
  from typing import Any, Dict, List, Optional, Set, Tuple, Union
12
13
  from urllib.parse import urlparse
@@ -105,6 +106,75 @@ ANNOTATIONS_POD_NOT_FOUND_ERROR_MSG = ('Pod {pod_name} not found in namespace '
105
106
 
106
107
  logger = sky_logging.init_logger(__name__)
107
108
 
109
# Default retry settings for Kubernetes API calls
DEFAULT_MAX_RETRIES = 3
DEFAULT_RETRY_INTERVAL_SECONDS = 1


def _retry_on_error(max_retries=DEFAULT_MAX_RETRIES,
                    retry_interval=DEFAULT_RETRY_INTERVAL_SECONDS,
                    resource_type: Optional[str] = None):
    """Decorator to retry Kubernetes API calls on transient failures.

    Args:
        max_retries: Maximum number of retry attempts
        retry_interval: Initial seconds to wait between retries
        resource_type: Type of resource being accessed (e.g. 'node', 'pod').
            Used to provide more specific error messages.
    """

    def decorator(func):

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            backoff = common_utils.Backoff(initial_backoff=retry_interval,
                                           max_backoff_factor=3)
            last_exception = None

            for attempt in range(max_retries):
                try:
                    return func(*args, **kwargs)
                except (kubernetes.max_retry_error(),
                        kubernetes.api_exception(),
                        kubernetes.config_exception()) as e:
                    # Don't retry on permanent errors like 401 (Unauthorized)
                    # or 403 (Forbidden)
                    if (isinstance(e, kubernetes.api_exception()) and
                            e.status in (401, 403)):
                        raise
                    last_exception = e
                    if attempt < max_retries - 1:
                        sleep_time = backoff.current_backoff()
                        logger.debug(f'Kubernetes API call {func.__name__} '
                                     f'failed with {str(e)}. Retrying in '
                                     f'{sleep_time:.1f}s...')
                        time.sleep(sleep_time)

            # All attempts exhausted: build an error message tailored to the
            # exception type and (optionally) the resource being queried.
            resource_msg = f' when trying to get {resource_type} info' \
                if resource_type else ''
            debug_cmd = f' To debug, run: kubectl get {resource_type}s' \
                if resource_type else ''

            if isinstance(last_exception, kubernetes.max_retry_error()):
                error_msg = f'Timed out{resource_msg} from Kubernetes cluster.'
            elif isinstance(last_exception, kubernetes.api_exception()):
                error_msg = (f'Kubernetes API error{resource_msg}: '
                             f'{str(last_exception)}')
            else:
                error_msg = (f'Kubernetes configuration error{resource_msg}: '
                             f'{str(last_exception)}')

            raise exceptions.ResourcesUnavailableError(
                f'{error_msg}'
                f' Please check if the cluster is healthy and retry.'
                f'{debug_cmd}') from last_exception

        return wrapper

    return decorator
177
+
108
178
 
109
179
  class GPULabelFormatter:
110
180
  """Base class to define a GPU label formatter for a Kubernetes cluster
@@ -340,14 +410,15 @@ class GFDLabelFormatter(GPULabelFormatter):
340
410
  """
341
411
  canonical_gpu_names = [
342
412
  'A100-80GB', 'A100', 'A10G', 'H100', 'K80', 'M60', 'T4g', 'T4',
343
- 'V100', 'A10', 'P4000', 'P100', 'P40', 'P4', 'L4'
413
+ 'V100', 'A10', 'P4000', 'P100', 'P40', 'P4', 'L40', 'L4'
344
414
  ]
345
415
  for canonical_name in canonical_gpu_names:
346
416
  # A100-80G accelerator is A100-SXM-80GB or A100-PCIE-80GB
347
417
  if canonical_name == 'A100-80GB' and re.search(
348
418
  r'A100.*-80GB', value):
349
419
  return canonical_name
350
- elif canonical_name in value:
420
+ # Use word boundary matching to prevent substring matches
421
+ elif re.search(rf'\b{re.escape(canonical_name)}\b', value):
351
422
  return canonical_name
352
423
 
353
424
  # If we didn't find a canonical name:
@@ -445,6 +516,7 @@ def detect_accelerator_resource(
445
516
 
446
517
 
447
518
  @functools.lru_cache(maxsize=10)
519
+ @_retry_on_error(resource_type='node')
448
520
  def get_kubernetes_nodes(context: Optional[str] = None) -> List[Any]:
449
521
  """Gets the kubernetes nodes in the context.
450
522
 
@@ -453,17 +525,12 @@ def get_kubernetes_nodes(context: Optional[str] = None) -> List[Any]:
453
525
  if context is None:
454
526
  context = get_current_kube_config_context_name()
455
527
 
456
- try:
457
- nodes = kubernetes.core_api(context).list_node(
458
- _request_timeout=kubernetes.API_TIMEOUT).items
459
- except kubernetes.max_retry_error():
460
- raise exceptions.ResourcesUnavailableError(
461
- 'Timed out when trying to get node info from Kubernetes cluster. '
462
- 'Please check if the cluster is healthy and retry. To debug, run: '
463
- 'kubectl get nodes') from None
528
+ nodes = kubernetes.core_api(context).list_node(
529
+ _request_timeout=kubernetes.API_TIMEOUT).items
464
530
  return nodes
465
531
 
466
532
 
533
+ @_retry_on_error(resource_type='pod')
467
534
  def get_all_pods_in_kubernetes_cluster(
468
535
  context: Optional[str] = None) -> List[Any]:
469
536
  """Gets pods in all namespaces in kubernetes cluster indicated by context.
@@ -473,14 +540,8 @@ def get_all_pods_in_kubernetes_cluster(
473
540
  if context is None:
474
541
  context = get_current_kube_config_context_name()
475
542
 
476
- try:
477
- pods = kubernetes.core_api(context).list_pod_for_all_namespaces(
478
- _request_timeout=kubernetes.API_TIMEOUT).items
479
- except kubernetes.max_retry_error():
480
- raise exceptions.ResourcesUnavailableError(
481
- 'Timed out when trying to get pod info from Kubernetes cluster. '
482
- 'Please check if the cluster is healthy and retry. To debug, run: '
483
- 'kubectl get pods') from None
543
+ pods = kubernetes.core_api(context).list_pod_for_all_namespaces(
544
+ _request_timeout=kubernetes.API_TIMEOUT).items
484
545
  return pods
485
546
 
486
547
 
@@ -892,6 +953,52 @@ def check_credentials(context: Optional[str],
892
953
  return True, None
893
954
 
894
955
 
956
+ def check_pod_config(pod_config: dict) \
957
+ -> Tuple[bool, Optional[str]]:
958
+ """Check if the pod_config is a valid pod config
959
+
960
+ Using deserialize api to check the pod_config is valid or not.
961
+
962
+ Returns:
963
+ bool: True if pod_config is valid.
964
+ str: Error message about why the pod_config is invalid, None otherwise.
965
+ """
966
+ errors = []
967
+ # This api_client won't be used to send any requests, so there is no need to
968
+ # load kubeconfig
969
+ api_client = kubernetes.kubernetes.client.ApiClient()
970
+
971
+ # Used for kubernetes api_client deserialize function, the function will use
972
+ # data attr, the detail ref:
973
+ # https://github.com/kubernetes-client/python/blob/master/kubernetes/client/api_client.py#L244
974
+ class InnerResponse():
975
+
976
+ def __init__(self, data: dict):
977
+ self.data = json.dumps(data)
978
+
979
+ try:
980
+ # Validate metadata if present
981
+ if 'metadata' in pod_config:
982
+ try:
983
+ value = InnerResponse(pod_config['metadata'])
984
+ api_client.deserialize(
985
+ value, kubernetes.kubernetes.client.V1ObjectMeta)
986
+ except ValueError as e:
987
+ errors.append(f'Invalid metadata: {str(e)}')
988
+ # Validate spec if present
989
+ if 'spec' in pod_config:
990
+ try:
991
+ value = InnerResponse(pod_config['spec'])
992
+ api_client.deserialize(value,
993
+ kubernetes.kubernetes.client.V1PodSpec)
994
+ except ValueError as e:
995
+ errors.append(f'Invalid spec: {str(e)}')
996
+ return len(errors) == 0, '.'.join(errors)
997
+ except Exception as e: # pylint: disable=broad-except
998
+ errors.append(f'Validation error: {str(e)}')
999
+ return False, '.'.join(errors)
1000
+
1001
+
895
1002
  def is_kubeconfig_exec_auth(
896
1003
  context: Optional[str] = None) -> Tuple[bool, Optional[str]]:
897
1004
  """Checks if the kubeconfig file uses exec-based authentication
@@ -1711,8 +1818,6 @@ def merge_dicts(source: Dict[Any, Any], destination: Dict[Any, Any]):
1711
1818
  else:
1712
1819
  destination[key].extend(value)
1713
1820
  else:
1714
- if destination is None:
1715
- destination = {}
1716
1821
  destination[key] = value
1717
1822
 
1718
1823