skypilot-nightly 1.0.0.dev20241227__py3-none-any.whl → 1.0.0.dev20250124__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sky/__init__.py +2 -2
- sky/adaptors/common.py +15 -9
- sky/adaptors/do.py +20 -0
- sky/adaptors/oci.py +32 -1
- sky/authentication.py +20 -8
- sky/backends/backend_utils.py +44 -0
- sky/backends/cloud_vm_ray_backend.py +202 -41
- sky/backends/wheel_utils.py +4 -1
- sky/check.py +31 -1
- sky/cli.py +39 -43
- sky/cloud_stores.py +71 -2
- sky/clouds/__init__.py +2 -0
- sky/clouds/aws.py +137 -50
- sky/clouds/cloud.py +4 -0
- sky/clouds/do.py +303 -0
- sky/clouds/gcp.py +9 -0
- sky/clouds/kubernetes.py +3 -3
- sky/clouds/oci.py +20 -9
- sky/clouds/service_catalog/__init__.py +7 -3
- sky/clouds/service_catalog/constants.py +1 -1
- sky/clouds/service_catalog/data_fetchers/fetch_gcp.py +10 -51
- sky/clouds/service_catalog/do_catalog.py +111 -0
- sky/clouds/service_catalog/kubernetes_catalog.py +14 -0
- sky/clouds/utils/oci_utils.py +15 -2
- sky/core.py +8 -5
- sky/data/data_transfer.py +37 -0
- sky/data/data_utils.py +19 -4
- sky/data/mounting_utils.py +99 -15
- sky/data/storage.py +961 -130
- sky/global_user_state.py +1 -1
- sky/jobs/__init__.py +2 -0
- sky/jobs/constants.py +8 -7
- sky/jobs/controller.py +19 -22
- sky/jobs/core.py +46 -2
- sky/jobs/recovery_strategy.py +114 -143
- sky/jobs/scheduler.py +283 -0
- sky/jobs/state.py +290 -21
- sky/jobs/utils.py +346 -95
- sky/optimizer.py +6 -3
- sky/provision/aws/config.py +59 -29
- sky/provision/azure/instance.py +1 -1
- sky/provision/do/__init__.py +11 -0
- sky/provision/do/config.py +14 -0
- sky/provision/do/constants.py +10 -0
- sky/provision/do/instance.py +287 -0
- sky/provision/do/utils.py +306 -0
- sky/provision/docker_utils.py +22 -11
- sky/provision/gcp/instance_utils.py +15 -9
- sky/provision/kubernetes/instance.py +3 -2
- sky/provision/kubernetes/utils.py +125 -20
- sky/provision/oci/query_utils.py +17 -14
- sky/provision/provisioner.py +0 -1
- sky/provision/runpod/instance.py +10 -1
- sky/provision/runpod/utils.py +170 -13
- sky/resources.py +1 -1
- sky/serve/autoscalers.py +359 -301
- sky/serve/controller.py +10 -8
- sky/serve/core.py +84 -7
- sky/serve/load_balancer.py +27 -10
- sky/serve/replica_managers.py +1 -3
- sky/serve/serve_state.py +10 -5
- sky/serve/serve_utils.py +28 -1
- sky/serve/service.py +4 -3
- sky/serve/service_spec.py +31 -0
- sky/setup_files/dependencies.py +4 -1
- sky/skylet/constants.py +8 -4
- sky/skylet/events.py +7 -3
- sky/skylet/job_lib.py +10 -30
- sky/skylet/log_lib.py +8 -8
- sky/skylet/log_lib.pyi +3 -0
- sky/skylet/providers/command_runner.py +5 -7
- sky/skylet/skylet.py +1 -1
- sky/task.py +28 -1
- sky/templates/do-ray.yml.j2 +98 -0
- sky/templates/jobs-controller.yaml.j2 +41 -7
- sky/templates/runpod-ray.yml.j2 +13 -0
- sky/templates/sky-serve-controller.yaml.j2 +4 -0
- sky/usage/usage_lib.py +10 -2
- sky/utils/accelerator_registry.py +12 -8
- sky/utils/controller_utils.py +114 -39
- sky/utils/db_utils.py +18 -4
- sky/utils/kubernetes/deploy_remote_cluster.sh +5 -5
- sky/utils/log_utils.py +2 -0
- sky/utils/resources_utils.py +25 -21
- sky/utils/schemas.py +27 -0
- sky/utils/subprocess_utils.py +54 -10
- {skypilot_nightly-1.0.0.dev20241227.dist-info → skypilot_nightly-1.0.0.dev20250124.dist-info}/METADATA +23 -4
- {skypilot_nightly-1.0.0.dev20241227.dist-info → skypilot_nightly-1.0.0.dev20250124.dist-info}/RECORD +92 -82
- {skypilot_nightly-1.0.0.dev20241227.dist-info → skypilot_nightly-1.0.0.dev20250124.dist-info}/WHEEL +1 -1
- {skypilot_nightly-1.0.0.dev20241227.dist-info → skypilot_nightly-1.0.0.dev20250124.dist-info}/LICENSE +0 -0
- {skypilot_nightly-1.0.0.dev20241227.dist-info → skypilot_nightly-1.0.0.dev20250124.dist-info}/entry_points.txt +0 -0
- {skypilot_nightly-1.0.0.dev20241227.dist-info → skypilot_nightly-1.0.0.dev20250124.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,306 @@
|
|
1
|
+
"""DigitalOcean API client wrapper for SkyPilot.
|
2
|
+
|
3
|
+
Example usage of `pydo` client library was mostly taken from here:
|
4
|
+
https://github.com/digitalocean/pydo/blob/main/examples/poc_droplets_volumes_sshkeys.py
|
5
|
+
"""
|
6
|
+
|
7
|
+
import copy
|
8
|
+
import os
|
9
|
+
import typing
|
10
|
+
from typing import Any, Dict, List, Optional
|
11
|
+
import urllib
|
12
|
+
import uuid
|
13
|
+
|
14
|
+
from sky import sky_logging
|
15
|
+
from sky.adaptors import do
|
16
|
+
from sky.provision import common
|
17
|
+
from sky.provision import constants as provision_constants
|
18
|
+
from sky.provision.do import constants
|
19
|
+
from sky.utils import common_utils
|
20
|
+
|
21
|
+
if typing.TYPE_CHECKING:
|
22
|
+
from sky import resources
|
23
|
+
from sky import status_lib
|
24
|
+
|
25
|
+
logger = sky_logging.init_logger(__name__)
|
26
|
+
|
27
|
+
POSSIBLE_CREDENTIALS_PATHS = [
|
28
|
+
os.path.expanduser(
|
29
|
+
'~/Library/Application Support/doctl/config.yaml'), # OS X
|
30
|
+
os.path.expanduser(
|
31
|
+
os.path.join(os.getenv('XDG_CONFIG_HOME', '~/.config/'),
|
32
|
+
'doctl/config.yaml')), # Linux
|
33
|
+
]
|
34
|
+
INITIAL_BACKOFF_SECONDS = 10
|
35
|
+
MAX_BACKOFF_FACTOR = 10
|
36
|
+
MAX_ATTEMPTS = 6
|
37
|
+
SSH_KEY_NAME_ON_DO = f'sky-key-{common_utils.get_user_hash()}'
|
38
|
+
|
39
|
+
CREDENTIALS_PATH = '~/.config/doctl/config.yaml'
|
40
|
+
_client = None
|
41
|
+
_ssh_key_id = None
|
42
|
+
|
43
|
+
|
44
|
+
class DigitalOceanError(Exception):
    """Base exception for errors raised by this DigitalOcean API wrapper."""
    pass
|
46
|
+
|
47
|
+
|
48
|
+
def _init_client():
    """Locate doctl credentials and build an authenticated pydo client.

    Scans POSSIBLE_CREDENTIALS_PATHS for a doctl config file, then tries the
    default `access-token` and, failing that, each token under
    `auth-contexts`, validating each candidate token with a cheap
    `droplets.list()` call.

    Returns:
        The authenticated pydo client (also stored in the module-level
        `_client`). NOTE(review): if the default token fails and the config
        has no `auth-contexts` section, this returns None — callers go
        through `client()`, which would then retry initialization; confirm
        whether that is intended.

    Raises:
        DigitalOceanError: if no credentials file exists, or no token in the
            file authenticates successfully.
    """
    global _client, CREDENTIALS_PATH
    assert _client is None
    CREDENTIALS_PATH = None
    credentials_found = 0
    for path in POSSIBLE_CREDENTIALS_PATHS:
        if os.path.exists(path):
            CREDENTIALS_PATH = path
            credentials_found += 1
            logger.debug(f'Digital Ocean credential path found at {path}')
    # Bug fix: the original condition was inverted (`if not
    # credentials_found > 1`), which logged 'more than 1 credential file
    # found' precisely when there was at most one file.
    if credentials_found > 1:
        logger.debug('more than 1 credential file found')
    if CREDENTIALS_PATH is None:
        raise DigitalOceanError(
            'no credentials file found from '
            f'the following paths {POSSIBLE_CREDENTIALS_PATHS}')

    # attempt default context
    credentials = common_utils.read_yaml(CREDENTIALS_PATH)
    default_token = credentials.get('access-token', None)
    if default_token is not None:
        try:
            test_client = do.pydo.Client(token=default_token)
            # Cheap authenticated call to verify the token works.
            test_client.droplets.list()
            logger.debug('trying `default` context')
            _client = test_client
            return _client
        except do.exceptions().HttpResponseError:
            # Default token invalid; fall through to the auth-contexts.
            pass

    auth_contexts = credentials.get('auth-contexts', None)
    if auth_contexts is not None:
        for context, api_token in auth_contexts.items():
            try:
                test_client = do.pydo.Client(token=api_token)
                test_client.droplets.list()
                logger.debug(f'using {context} context')
                _client = test_client
                break
            except do.exceptions().HttpResponseError:
                continue
        else:
            # No break: every context token failed authentication.
            raise DigitalOceanError(
                'no valid api tokens found try '
                'setting a new API token with `doctl auth init`')
    return _client
|
94
|
+
|
95
|
+
|
96
|
+
def client():
    """Return the module-wide pydo client, initializing it on first use."""
    global _client
    _client = _client if _client is not None else _init_client()
    return _client
|
101
|
+
|
102
|
+
|
103
|
+
def ssh_key_id(public_key: str) -> Dict[str, Any]:
    """Return the DO SSH-key record matching `public_key`, creating it if absent.

    Pages through the account's SSH keys (50 per page, following the
    `links.pages.next` URL) looking for an exact `public_key` match; if none
    is found, registers the key under SSH_KEY_NAME_ON_DO. The result is
    cached in the module-level `_ssh_key_id`.

    NOTE(review): despite the name, this returns the full ssh_key metadata
    dict (callers index `['fingerprint']`), not just an id.

    Args:
        public_key: the OpenSSH public key string to look up or register.

    Returns:
        The ssh_key metadata dict from the DigitalOcean API.

    Raises:
        DigitalOceanError: on any HTTP error from the API.
    """
    global _ssh_key_id
    if _ssh_key_id is None:
        page = 1
        paginated = True
        while paginated:
            try:
                resp = client().ssh_keys.list(per_page=50, page=page)
                for ssh_key in resp['ssh_keys']:
                    if ssh_key['public_key'] == public_key:
                        _ssh_key_id = ssh_key
                        return _ssh_key_id
            except do.exceptions().HttpResponseError as err:
                raise DigitalOceanError(
                    f'Error: {err.status_code} {err.reason}: '
                    f'{err.error.message}') from err

            # Advance to the next page by parsing the page number out of the
            # `next` link; stop when the API reports no further pages.
            pages = resp['links']
            if 'pages' in pages and 'next' in pages['pages']:
                pages = pages['pages']
                parsed_url = urllib.parse.urlparse(pages['next'])
                page = int(urllib.parse.parse_qs(parsed_url.query)['page'][0])
            else:
                paginated = False

        # Key not found on the account: register it.
        request = {
            'public_key': public_key,
            'name': SSH_KEY_NAME_ON_DO,
        }
        _ssh_key_id = client().ssh_keys.create(body=request)['ssh_key']
    return _ssh_key_id
|
134
|
+
|
135
|
+
|
136
|
+
def _create_volume(request: Dict[str, Any]) -> Dict[str, Any]:
    """Create a block-storage volume and return its metadata dict."""
    try:
        return client().volumes.create(body=request)['volume']
    except do.exceptions().HttpResponseError as err:
        raise DigitalOceanError(
            f'Error: {err.status_code} {err.reason}: {err.error.message}'
        ) from err
|
146
|
+
|
147
|
+
|
148
|
+
def _create_droplet(request: Dict[str, Any]) -> Dict[str, Any]:
    """Create a droplet and return its metadata from a follow-up GET.

    NOTE(review): the droplet is re-fetched after creation, presumably
    because the create response is not fully populated — confirm against
    the pydo API.
    """
    try:
        create_resp = client().droplets.create(body=request)
        new_id = create_resp['droplet']['id']
        droplet_resp = client().droplets.get(new_id)
    except do.exceptions().HttpResponseError as err:
        raise DigitalOceanError(
            f'Error: {err.status_code} {err.reason}: {err.error.message}'
        ) from err
    return droplet_resp['droplet']
|
160
|
+
|
161
|
+
|
162
|
+
def create_instance(region: str, cluster_name_on_cloud: str, instance_type: str,
                    config: common.ProvisionConfig) -> Dict[str, Any]:
    """Creates an instance and mounts the requested block storage.

    Creates a droplet tagged with the SkyPilot cluster tags, creates a
    separate ext4 block-storage volume of `DiskSize` GB, and attaches the
    volume to the droplet.

    Args:
        region (str): instance region.
        cluster_name_on_cloud (str): cluster name; used for tagging and as
            the prefix of the generated instance name.
        instance_type (str): instance size slug; appended to the generated
            instance name.
        config (common.ProvisionConfig): provisioner configuration; node
            config supplies InstanceType, ImageId and DiskSize.

    Returns:
        Dict[str, Any]: instance metadata

    Raises:
        DigitalOceanError: on any HTTP error while creating or attaching.
    """
    # sort tags by key to support deterministic unit test stubbing
    tags = dict(sorted(copy.deepcopy(config.tags).items()))
    # User tags are spread last so they can override the defaults above.
    tags = {
        'Name': cluster_name_on_cloud,
        provision_constants.TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud,
        provision_constants.TAG_SKYPILOT_CLUSTER_NAME: cluster_name_on_cloud,
        **tags
    }
    # DO expects tags as a flat list of 'key:value' strings.
    tags = [f'{key}:{value}' for key, value in tags.items()]
    default_image = constants.GPU_IMAGES.get(
        config.node_config['InstanceType'],
        'gpu-h100x1-base',
    )
    image_id = config.node_config['ImageId']
    image_id = image_id if image_id is not None else default_image
    # Random 4-hex suffix keeps names unique within the cluster.
    instance_name = (f'{cluster_name_on_cloud}-'
                     f'{uuid.uuid4().hex[:4]}-{instance_type}')
    instance_request = {
        'name': instance_name,
        'region': region,
        'size': config.node_config['InstanceType'],
        'image': image_id,
        'ssh_keys': [
            ssh_key_id(
                config.authentication_config['ssh_public_key'])['fingerprint']
        ],
        'tags': tags,
    }
    instance = _create_droplet(instance_request)

    volume_request = {
        'size_gigabytes': config.node_config['DiskSize'],
        'name': instance_name,
        'region': region,
        'filesystem_type': 'ext4',
        'tags': tags
    }
    volume = _create_volume(volume_request)

    # Attach the newly created volume to the droplet.
    attach_request = {'type': 'attach', 'droplet_id': instance['id']}
    try:
        client().volume_actions.post_by_id(volume['id'], attach_request)
    except do.exceptions().HttpResponseError as err:
        raise DigitalOceanError(
            f'Error: {err.status_code} {err.reason}: {err.error.message}'
        ) from err
    logger.debug(f'{instance_name} created')
    return instance
|
222
|
+
|
223
|
+
|
224
|
+
def start_instance(instance: Dict[str, Any]):
    """Power on the given droplet."""
    action = {'type': 'power_on'}
    try:
        client().droplet_actions.post(droplet_id=instance['id'], body=action)
    except do.exceptions().HttpResponseError as err:
        raise DigitalOceanError(
            f'Error: {err.status_code} {err.reason}: {err.error.message}'
        ) from err
|
232
|
+
|
233
|
+
|
234
|
+
def stop_instance(instance: Dict[str, Any]):
    """Gracefully shut down the given droplet."""
    action = {'type': 'shutdown'}
    try:
        client().droplet_actions.post(droplet_id=instance['id'], body=action)
    except do.exceptions().HttpResponseError as err:
        raise DigitalOceanError(
            f'Error: {err.status_code} {err.reason}: {err.error.message}'
        ) from err
|
244
|
+
|
245
|
+
|
246
|
+
def down_instance(instance: Dict[str, Any]):
    """Destroy a droplet together with its associated resources.

    A destroy that is already underway is treated as success so the call
    is effectively idempotent.

    Raises:
        DigitalOceanError: on any other HTTP error from the API.
    """
    # We use dangerous destroy to atomically delete
    # block storage and instance for autodown
    try:
        client().droplets.destroy_with_associated_resources_dangerous(
            droplet_id=instance['id'], x_dangerous=True)
    except do.exceptions().HttpResponseError as err:
        # Another destroy request already in flight: nothing to do.
        if 'a destroy is already in progress' in err.error.message:
            return
        raise DigitalOceanError(
            f'Error: {err.status_code} {err.reason}: {err.error.message}'
        ) from err
|
258
|
+
|
259
|
+
|
260
|
+
def rename_instance(instance: Dict[str, Any], new_name: str):
    """Rename the given droplet to `new_name`."""
    action = {'type': 'rename', 'name': new_name}
    try:
        client().droplet_actions.rename(droplet=instance['id'], body=action)
    except do.exceptions().HttpResponseError as err:
        raise DigitalOceanError(
            f'Error: {err.status_code} {err.reason}: {err.error.message}'
        ) from err
|
271
|
+
|
272
|
+
|
273
|
+
def filter_instances(
        cluster_name_on_cloud: str,
        status_filters: Optional[List[str]] = None) -> Dict[str, Any]:
    """Returns Dict mapping instance name
    to instance metadata filtered by status.

    Pages through all droplets carrying this cluster's SkyPilot tag
    (50 per page, following the `links.pages.next` URL). When
    `status_filters` is None, every droplet of the cluster is returned.

    Raises:
        DigitalOceanError: on any HTTP error from the API.
    """

    filtered_instances: Dict[str, Any] = {}
    page = 1
    paginated = True
    while paginated:
        try:
            resp = client().droplets.list(
                tag_name=f'{provision_constants.TAG_SKYPILOT_CLUSTER_NAME}:'
                f'{cluster_name_on_cloud}',
                per_page=50,
                page=page)
            for instance in resp['droplets']:
                if status_filters is None or instance[
                        'status'] in status_filters:
                    filtered_instances[instance['name']] = instance
        except do.exceptions().HttpResponseError as err:
            raise DigitalOceanError(
                f'Error: {err.status_code} {err.reason}: {err.error.message}'
            ) from err

        # Parse the page number out of the `next` link to advance; stop when
        # the API reports no further pages.
        pages = resp['links']
        if 'pages' in pages and 'next' in pages['pages']:
            pages = pages['pages']
            parsed_url = urllib.parse.urlparse(pages['next'])
            page = int(urllib.parse.parse_qs(parsed_url.query)['page'][0])
        else:
            paginated = False
    return filtered_instances
|
sky/provision/docker_utils.py
CHANGED
@@ -38,6 +38,13 @@ class DockerLoginConfig:
|
|
38
38
|
password: str
|
39
39
|
server: str
|
40
40
|
|
41
|
+
def format_image(self, image: str) -> str:
|
42
|
+
"""Format the image name with the server prefix."""
|
43
|
+
server_prefix = f'{self.server}/'
|
44
|
+
if not image.startswith(server_prefix):
|
45
|
+
return f'{server_prefix}{image}'
|
46
|
+
return image
|
47
|
+
|
41
48
|
@classmethod
|
42
49
|
def from_env_vars(cls, d: Dict[str, str]) -> 'DockerLoginConfig':
|
43
50
|
return cls(
|
@@ -220,9 +227,7 @@ class DockerInitializer:
|
|
220
227
|
wait_for_docker_daemon=True)
|
221
228
|
# We automatically add the server prefix to the image name if
|
222
229
|
# the user did not add it.
|
223
|
-
|
224
|
-
if not specific_image.startswith(server_prefix):
|
225
|
-
specific_image = f'{server_prefix}{specific_image}'
|
230
|
+
specific_image = docker_login_config.format_image(specific_image)
|
226
231
|
|
227
232
|
if self.docker_config.get('pull_before_run', True):
|
228
233
|
assert specific_image, ('Image must be included in config if ' +
|
@@ -338,14 +343,20 @@ class DockerInitializer:
|
|
338
343
|
no_exist = 'NoExist'
|
339
344
|
# SkyPilot: Add the current user to the docker group first (if needed),
|
340
345
|
# before checking if docker is installed to avoid permission issues.
|
341
|
-
|
342
|
-
|
343
|
-
|
344
|
-
|
345
|
-
|
346
|
-
|
347
|
-
|
348
|
-
|
346
|
+
docker_cmd = ('id -nG $USER | grep -qw docker || '
|
347
|
+
'sudo usermod -aG docker $USER > /dev/null 2>&1;'
|
348
|
+
f'command -v {self.docker_cmd} || echo {no_exist!r}')
|
349
|
+
cleaned_output = self._run(docker_cmd)
|
350
|
+
timeout = 60 * 10 # 10 minute timeout
|
351
|
+
start = time.time()
|
352
|
+
while no_exist in cleaned_output or 'docker' not in cleaned_output:
|
353
|
+
if time.time() - start > timeout:
|
354
|
+
logger.error(
|
355
|
+
f'{self.docker_cmd.capitalize()} not installed. Please use '
|
356
|
+
f'an image with {self.docker_cmd.capitalize()} installed.')
|
357
|
+
return
|
358
|
+
time.sleep(5)
|
359
|
+
cleaned_output = self._run(docker_cmd)
|
349
360
|
|
350
361
|
def _check_container_status(self):
|
351
362
|
if self.initialized:
|
@@ -38,7 +38,7 @@ _FIREWALL_RESOURCE_NOT_FOUND_PATTERN = re.compile(
|
|
38
38
|
r'The resource \'projects/.*/global/firewalls/.*\' was not found')
|
39
39
|
|
40
40
|
|
41
|
-
def
|
41
|
+
def _retry_on_gcp_http_exception(
|
42
42
|
regex: Optional[str] = None,
|
43
43
|
max_retries: int = GCP_MAX_RETRIES,
|
44
44
|
retry_interval_s: int = GCP_RETRY_INTERVAL_SECONDS,
|
@@ -49,17 +49,18 @@ def _retry_on_http_exception(
|
|
49
49
|
|
50
50
|
@functools.wraps(func)
|
51
51
|
def wrapper(*args, **kwargs):
|
52
|
-
exception_type = gcp.http_error_exception()
|
53
52
|
|
54
53
|
def try_catch_exc():
|
55
54
|
try:
|
56
55
|
value = func(*args, **kwargs)
|
57
56
|
return value
|
58
57
|
except Exception as e: # pylint: disable=broad-except
|
59
|
-
if
|
60
|
-
|
61
|
-
|
62
|
-
|
58
|
+
if (isinstance(e, gcp.http_error_exception()) and
|
59
|
+
(regex is None or re.search(regex, str(e)))):
|
60
|
+
logger.error(
|
61
|
+
f'Retrying for gcp.http_error_exception: {e}')
|
62
|
+
return e
|
63
|
+
raise
|
63
64
|
|
64
65
|
for _ in range(max_retries):
|
65
66
|
ret = try_catch_exc()
|
@@ -431,7 +432,7 @@ class GCPComputeInstance(GCPInstance):
|
|
431
432
|
logger.debug(
|
432
433
|
f'Waiting GCP operation {operation["name"]} to be ready ...')
|
433
434
|
|
434
|
-
@
|
435
|
+
@_retry_on_gcp_http_exception(
|
435
436
|
f'Failed to wait for operation {operation["name"]}')
|
436
437
|
def call_operation(fn, timeout: int):
|
437
438
|
request = fn(
|
@@ -613,6 +614,11 @@ class GCPComputeInstance(GCPInstance):
|
|
613
614
|
return operation
|
614
615
|
|
615
616
|
@classmethod
|
617
|
+
# When there is a cloud function running in parallel to set labels for
|
618
|
+
# newly created instances, it may fail with the following error:
|
619
|
+
# "Labels fingerprint either invalid or resource labels have changed"
|
620
|
+
# We should retry until the labels are set successfully.
|
621
|
+
@_retry_on_gcp_http_exception('Labels fingerprint either invalid')
|
616
622
|
def set_labels(cls, project_id: str, availability_zone: str, node_id: str,
|
617
623
|
labels: dict) -> None:
|
618
624
|
node = cls.load_resource().instances().get(
|
@@ -1211,7 +1217,7 @@ class GCPTPUVMInstance(GCPInstance):
|
|
1211
1217
|
"""Poll for TPU operation until finished."""
|
1212
1218
|
del project_id, region, zone # unused
|
1213
1219
|
|
1214
|
-
@
|
1220
|
+
@_retry_on_gcp_http_exception(
|
1215
1221
|
f'Failed to wait for operation {operation["name"]}')
|
1216
1222
|
def call_operation(fn, timeout: int):
|
1217
1223
|
request = fn(name=operation['name'])
|
@@ -1379,7 +1385,7 @@ class GCPTPUVMInstance(GCPInstance):
|
|
1379
1385
|
f'Failed to get VPC name for instance {instance}') from e
|
1380
1386
|
|
1381
1387
|
@classmethod
|
1382
|
-
@
|
1388
|
+
@_retry_on_gcp_http_exception('unable to queue the operation')
|
1383
1389
|
def set_labels(cls, project_id: str, availability_zone: str, node_id: str,
|
1384
1390
|
labels: dict) -> None:
|
1385
1391
|
while True:
|
@@ -804,7 +804,8 @@ def _create_pods(region: str, cluster_name_on_cloud: str,
|
|
804
804
|
|
805
805
|
# Create pods in parallel
|
806
806
|
pods = subprocess_utils.run_in_parallel(_create_pod_thread,
|
807
|
-
range(to_start_count),
|
807
|
+
list(range(to_start_count)),
|
808
|
+
_NUM_THREADS)
|
808
809
|
|
809
810
|
# Process created pods
|
810
811
|
for pod in pods:
|
@@ -975,7 +976,7 @@ def terminate_instances(
|
|
975
976
|
_terminate_node(namespace, context, pod_name)
|
976
977
|
|
977
978
|
# Run pod termination in parallel
|
978
|
-
subprocess_utils.run_in_parallel(_terminate_pod_thread, pods.items(),
|
979
|
+
subprocess_utils.run_in_parallel(_terminate_pod_thread, list(pods.items()),
|
979
980
|
_NUM_THREADS)
|
980
981
|
|
981
982
|
|
@@ -7,6 +7,7 @@ import os
|
|
7
7
|
import re
|
8
8
|
import shutil
|
9
9
|
import subprocess
|
10
|
+
import time
|
10
11
|
import typing
|
11
12
|
from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
12
13
|
from urllib.parse import urlparse
|
@@ -105,6 +106,75 @@ ANNOTATIONS_POD_NOT_FOUND_ERROR_MSG = ('Pod {pod_name} not found in namespace '
|
|
105
106
|
|
106
107
|
logger = sky_logging.init_logger(__name__)
|
107
108
|
|
109
|
+
# Default retry settings for Kubernetes API calls
|
110
|
+
DEFAULT_MAX_RETRIES = 3
|
111
|
+
DEFAULT_RETRY_INTERVAL_SECONDS = 1
|
112
|
+
|
113
|
+
|
114
|
+
def _retry_on_error(max_retries=DEFAULT_MAX_RETRIES,
|
115
|
+
retry_interval=DEFAULT_RETRY_INTERVAL_SECONDS,
|
116
|
+
resource_type: Optional[str] = None):
|
117
|
+
"""Decorator to retry Kubernetes API calls on transient failures.
|
118
|
+
|
119
|
+
Args:
|
120
|
+
max_retries: Maximum number of retry attempts
|
121
|
+
retry_interval: Initial seconds to wait between retries
|
122
|
+
resource_type: Type of resource being accessed (e.g. 'node', 'pod').
|
123
|
+
Used to provide more specific error messages.
|
124
|
+
"""
|
125
|
+
|
126
|
+
def decorator(func):
|
127
|
+
|
128
|
+
@functools.wraps(func)
|
129
|
+
def wrapper(*args, **kwargs):
|
130
|
+
last_exception = None
|
131
|
+
backoff = common_utils.Backoff(initial_backoff=retry_interval,
|
132
|
+
max_backoff_factor=3)
|
133
|
+
|
134
|
+
for attempt in range(max_retries):
|
135
|
+
try:
|
136
|
+
return func(*args, **kwargs)
|
137
|
+
except (kubernetes.max_retry_error(),
|
138
|
+
kubernetes.api_exception(),
|
139
|
+
kubernetes.config_exception()) as e:
|
140
|
+
last_exception = e
|
141
|
+
# Don't retry on permanent errors like 401 (Unauthorized)
|
142
|
+
# or 403 (Forbidden)
|
143
|
+
if (isinstance(e, kubernetes.api_exception()) and
|
144
|
+
e.status in (401, 403)):
|
145
|
+
raise
|
146
|
+
if attempt < max_retries - 1:
|
147
|
+
sleep_time = backoff.current_backoff()
|
148
|
+
logger.debug(f'Kubernetes API call {func.__name__} '
|
149
|
+
f'failed with {str(e)}. Retrying in '
|
150
|
+
f'{sleep_time:.1f}s...')
|
151
|
+
time.sleep(sleep_time)
|
152
|
+
continue
|
153
|
+
|
154
|
+
# Format error message based on the type of exception
|
155
|
+
resource_msg = f' when trying to get {resource_type} info' \
|
156
|
+
if resource_type else ''
|
157
|
+
debug_cmd = f' To debug, run: kubectl get {resource_type}s' \
|
158
|
+
if resource_type else ''
|
159
|
+
|
160
|
+
if isinstance(last_exception, kubernetes.max_retry_error()):
|
161
|
+
error_msg = f'Timed out{resource_msg} from Kubernetes cluster.'
|
162
|
+
elif isinstance(last_exception, kubernetes.api_exception()):
|
163
|
+
error_msg = (f'Kubernetes API error{resource_msg}: '
|
164
|
+
f'{str(last_exception)}')
|
165
|
+
else:
|
166
|
+
error_msg = (f'Kubernetes configuration error{resource_msg}: '
|
167
|
+
f'{str(last_exception)}')
|
168
|
+
|
169
|
+
raise exceptions.ResourcesUnavailableError(
|
170
|
+
f'{error_msg}'
|
171
|
+
f' Please check if the cluster is healthy and retry.'
|
172
|
+
f'{debug_cmd}') from last_exception
|
173
|
+
|
174
|
+
return wrapper
|
175
|
+
|
176
|
+
return decorator
|
177
|
+
|
108
178
|
|
109
179
|
class GPULabelFormatter:
|
110
180
|
"""Base class to define a GPU label formatter for a Kubernetes cluster
|
@@ -340,14 +410,15 @@ class GFDLabelFormatter(GPULabelFormatter):
|
|
340
410
|
"""
|
341
411
|
canonical_gpu_names = [
|
342
412
|
'A100-80GB', 'A100', 'A10G', 'H100', 'K80', 'M60', 'T4g', 'T4',
|
343
|
-
'V100', 'A10', 'P4000', 'P100', 'P40', 'P4', 'L4'
|
413
|
+
'V100', 'A10', 'P4000', 'P100', 'P40', 'P4', 'L40', 'L4'
|
344
414
|
]
|
345
415
|
for canonical_name in canonical_gpu_names:
|
346
416
|
# A100-80G accelerator is A100-SXM-80GB or A100-PCIE-80GB
|
347
417
|
if canonical_name == 'A100-80GB' and re.search(
|
348
418
|
r'A100.*-80GB', value):
|
349
419
|
return canonical_name
|
350
|
-
|
420
|
+
# Use word boundary matching to prevent substring matches
|
421
|
+
elif re.search(rf'\b{re.escape(canonical_name)}\b', value):
|
351
422
|
return canonical_name
|
352
423
|
|
353
424
|
# If we didn't find a canonical name:
|
@@ -445,6 +516,7 @@ def detect_accelerator_resource(
|
|
445
516
|
|
446
517
|
|
447
518
|
@functools.lru_cache(maxsize=10)
|
519
|
+
@_retry_on_error(resource_type='node')
|
448
520
|
def get_kubernetes_nodes(context: Optional[str] = None) -> List[Any]:
|
449
521
|
"""Gets the kubernetes nodes in the context.
|
450
522
|
|
@@ -453,17 +525,12 @@ def get_kubernetes_nodes(context: Optional[str] = None) -> List[Any]:
|
|
453
525
|
if context is None:
|
454
526
|
context = get_current_kube_config_context_name()
|
455
527
|
|
456
|
-
|
457
|
-
|
458
|
-
_request_timeout=kubernetes.API_TIMEOUT).items
|
459
|
-
except kubernetes.max_retry_error():
|
460
|
-
raise exceptions.ResourcesUnavailableError(
|
461
|
-
'Timed out when trying to get node info from Kubernetes cluster. '
|
462
|
-
'Please check if the cluster is healthy and retry. To debug, run: '
|
463
|
-
'kubectl get nodes') from None
|
528
|
+
nodes = kubernetes.core_api(context).list_node(
|
529
|
+
_request_timeout=kubernetes.API_TIMEOUT).items
|
464
530
|
return nodes
|
465
531
|
|
466
532
|
|
533
|
+
@_retry_on_error(resource_type='pod')
|
467
534
|
def get_all_pods_in_kubernetes_cluster(
|
468
535
|
context: Optional[str] = None) -> List[Any]:
|
469
536
|
"""Gets pods in all namespaces in kubernetes cluster indicated by context.
|
@@ -473,14 +540,8 @@ def get_all_pods_in_kubernetes_cluster(
|
|
473
540
|
if context is None:
|
474
541
|
context = get_current_kube_config_context_name()
|
475
542
|
|
476
|
-
|
477
|
-
|
478
|
-
_request_timeout=kubernetes.API_TIMEOUT).items
|
479
|
-
except kubernetes.max_retry_error():
|
480
|
-
raise exceptions.ResourcesUnavailableError(
|
481
|
-
'Timed out when trying to get pod info from Kubernetes cluster. '
|
482
|
-
'Please check if the cluster is healthy and retry. To debug, run: '
|
483
|
-
'kubectl get pods') from None
|
543
|
+
pods = kubernetes.core_api(context).list_pod_for_all_namespaces(
|
544
|
+
_request_timeout=kubernetes.API_TIMEOUT).items
|
484
545
|
return pods
|
485
546
|
|
486
547
|
|
@@ -892,6 +953,52 @@ def check_credentials(context: Optional[str],
|
|
892
953
|
return True, None
|
893
954
|
|
894
955
|
|
956
|
+
def check_pod_config(pod_config: dict) \
|
957
|
+
-> Tuple[bool, Optional[str]]:
|
958
|
+
"""Check if the pod_config is a valid pod config
|
959
|
+
|
960
|
+
Using deserialize api to check the pod_config is valid or not.
|
961
|
+
|
962
|
+
Returns:
|
963
|
+
bool: True if pod_config is valid.
|
964
|
+
str: Error message about why the pod_config is invalid, None otherwise.
|
965
|
+
"""
|
966
|
+
errors = []
|
967
|
+
# This api_client won't be used to send any requests, so there is no need to
|
968
|
+
# load kubeconfig
|
969
|
+
api_client = kubernetes.kubernetes.client.ApiClient()
|
970
|
+
|
971
|
+
# Used for kubernetes api_client deserialize function, the function will use
|
972
|
+
# data attr, the detail ref:
|
973
|
+
# https://github.com/kubernetes-client/python/blob/master/kubernetes/client/api_client.py#L244
|
974
|
+
class InnerResponse():
|
975
|
+
|
976
|
+
def __init__(self, data: dict):
|
977
|
+
self.data = json.dumps(data)
|
978
|
+
|
979
|
+
try:
|
980
|
+
# Validate metadata if present
|
981
|
+
if 'metadata' in pod_config:
|
982
|
+
try:
|
983
|
+
value = InnerResponse(pod_config['metadata'])
|
984
|
+
api_client.deserialize(
|
985
|
+
value, kubernetes.kubernetes.client.V1ObjectMeta)
|
986
|
+
except ValueError as e:
|
987
|
+
errors.append(f'Invalid metadata: {str(e)}')
|
988
|
+
# Validate spec if present
|
989
|
+
if 'spec' in pod_config:
|
990
|
+
try:
|
991
|
+
value = InnerResponse(pod_config['spec'])
|
992
|
+
api_client.deserialize(value,
|
993
|
+
kubernetes.kubernetes.client.V1PodSpec)
|
994
|
+
except ValueError as e:
|
995
|
+
errors.append(f'Invalid spec: {str(e)}')
|
996
|
+
return len(errors) == 0, '.'.join(errors)
|
997
|
+
except Exception as e: # pylint: disable=broad-except
|
998
|
+
errors.append(f'Validation error: {str(e)}')
|
999
|
+
return False, '.'.join(errors)
|
1000
|
+
|
1001
|
+
|
895
1002
|
def is_kubeconfig_exec_auth(
|
896
1003
|
context: Optional[str] = None) -> Tuple[bool, Optional[str]]:
|
897
1004
|
"""Checks if the kubeconfig file uses exec-based authentication
|
@@ -1711,8 +1818,6 @@ def merge_dicts(source: Dict[Any, Any], destination: Dict[Any, Any]):
|
|
1711
1818
|
else:
|
1712
1819
|
destination[key].extend(value)
|
1713
1820
|
else:
|
1714
|
-
if destination is None:
|
1715
|
-
destination = {}
|
1716
1821
|
destination[key] = value
|
1717
1822
|
|
1718
1823
|
|