skypilot-nightly 1.0.0.dev20251012__py3-none-any.whl → 1.0.0.dev20251014__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of skypilot-nightly has been flagged as possibly problematic by the registry scanner.
- sky/__init__.py +4 -2
- sky/adaptors/shadeform.py +89 -0
- sky/authentication.py +52 -2
- sky/backends/backend_utils.py +35 -25
- sky/backends/cloud_vm_ray_backend.py +5 -5
- sky/catalog/data_fetchers/fetch_shadeform.py +142 -0
- sky/catalog/kubernetes_catalog.py +19 -25
- sky/catalog/shadeform_catalog.py +165 -0
- sky/client/cli/command.py +53 -19
- sky/client/sdk.py +13 -1
- sky/clouds/__init__.py +2 -0
- sky/clouds/shadeform.py +393 -0
- sky/dashboard/out/404.html +1 -1
- sky/dashboard/out/clusters/[cluster]/[job].html +1 -1
- sky/dashboard/out/clusters/[cluster].html +1 -1
- sky/dashboard/out/clusters.html +1 -1
- sky/dashboard/out/config.html +1 -1
- sky/dashboard/out/index.html +1 -1
- sky/dashboard/out/infra/[context].html +1 -1
- sky/dashboard/out/infra.html +1 -1
- sky/dashboard/out/jobs/[job].html +1 -1
- sky/dashboard/out/jobs/pools/[pool].html +1 -1
- sky/dashboard/out/jobs.html +1 -1
- sky/dashboard/out/users.html +1 -1
- sky/dashboard/out/volumes.html +1 -1
- sky/dashboard/out/workspace/new.html +1 -1
- sky/dashboard/out/workspaces/[name].html +1 -1
- sky/dashboard/out/workspaces.html +1 -1
- sky/jobs/controller.py +122 -145
- sky/jobs/recovery_strategy.py +59 -82
- sky/jobs/scheduler.py +5 -5
- sky/jobs/state.py +65 -21
- sky/jobs/utils.py +58 -22
- sky/metrics/utils.py +27 -6
- sky/provision/__init__.py +1 -0
- sky/provision/kubernetes/utils.py +44 -39
- sky/provision/shadeform/__init__.py +11 -0
- sky/provision/shadeform/config.py +12 -0
- sky/provision/shadeform/instance.py +351 -0
- sky/provision/shadeform/shadeform_utils.py +83 -0
- sky/server/common.py +4 -2
- sky/server/requests/executor.py +25 -3
- sky/server/server.py +9 -3
- sky/setup_files/dependencies.py +1 -0
- sky/sky_logging.py +0 -2
- sky/skylet/constants.py +23 -6
- sky/skylet/log_lib.py +0 -1
- sky/skylet/log_lib.pyi +1 -1
- sky/templates/shadeform-ray.yml.j2 +72 -0
- sky/utils/common.py +2 -0
- sky/utils/context.py +57 -51
- sky/utils/context_utils.py +15 -11
- sky/utils/controller_utils.py +35 -8
- sky/utils/locks.py +20 -5
- sky/utils/subprocess_utils.py +4 -3
- {skypilot_nightly-1.0.0.dev20251012.dist-info → skypilot_nightly-1.0.0.dev20251014.dist-info}/METADATA +39 -38
- {skypilot_nightly-1.0.0.dev20251012.dist-info → skypilot_nightly-1.0.0.dev20251014.dist-info}/RECORD +63 -54
- /sky/dashboard/out/_next/static/{yOfMelBaFp8uL5F9atyAK → 9Fek73R28lDp1A5J4N7g7}/_buildManifest.js +0 -0
- /sky/dashboard/out/_next/static/{yOfMelBaFp8uL5F9atyAK → 9Fek73R28lDp1A5J4N7g7}/_ssgManifest.js +0 -0
- {skypilot_nightly-1.0.0.dev20251012.dist-info → skypilot_nightly-1.0.0.dev20251014.dist-info}/WHEEL +0 -0
- {skypilot_nightly-1.0.0.dev20251012.dist-info → skypilot_nightly-1.0.0.dev20251014.dist-info}/entry_points.txt +0 -0
- {skypilot_nightly-1.0.0.dev20251012.dist-info → skypilot_nightly-1.0.0.dev20251014.dist-info}/licenses/LICENSE +0 -0
- {skypilot_nightly-1.0.0.dev20251012.dist-info → skypilot_nightly-1.0.0.dev20251014.dist-info}/top_level.txt +0 -0
sky/jobs/utils.py
CHANGED
@@ -8,7 +8,6 @@ import asyncio
 import collections
 import datetime
 import enum
-import logging
 import os
 import pathlib
 import re
@@ -84,6 +83,7 @@ _LOG_STREAM_CHECK_CONTROLLER_GAP_SECONDS = 5
 
 _JOB_STATUS_FETCH_MAX_RETRIES = 3
 _JOB_K8S_TRANSIENT_NW_MSG = 'Unable to connect to the server: dial tcp'
+_JOB_STATUS_FETCH_TIMEOUT_SECONDS = 30
 
 _JOB_WAITING_STATUS_MESSAGE = ux_utils.spinner_message(
     'Waiting for task to start[/]'
@@ -101,6 +101,13 @@ _JOB_CANCELLED_MESSAGE = (
 # update the state.
 _FINAL_JOB_STATUS_WAIT_TIMEOUT_SECONDS = 120
 
+# After enabling consolidation mode, we need to restart the API server to get
+# the jobs refresh deamon and correct number of executors. We use this file to
+# indicate that the API server has been restarted after enabling consolidation
+# mode.
+_JOBS_CONSOLIDATION_RELOADED_SIGNAL_FILE = (
+    '~/.sky/.jobs_controller_consolidation_reloaded_signal')
+
 
 class ManagedJobQueueResultType(enum.Enum):
     """The type of the managed job queue result."""
@@ -117,9 +124,8 @@ class UserSignal(enum.Enum):
 
 # ====== internal functions ======
 def terminate_cluster(
-
-
-    _logger: logging.Logger = logger,  # pylint: disable=invalid-name
+    cluster_name: str,
+    max_retry: int = 6,
 ) -> None:
     """Terminate the cluster."""
     from sky import core  # pylint: disable=import-outside-toplevel
@@ -143,18 +149,18 @@ def terminate_cluster(
             return
         except exceptions.ClusterDoesNotExist:
             # The cluster is already down.
-
+            logger.debug(f'The cluster {cluster_name} is already down.')
             return
         except Exception as e:  # pylint: disable=broad-except
             retry_cnt += 1
             if retry_cnt >= max_retry:
                 raise RuntimeError(
                     f'Failed to terminate the cluster {cluster_name}.') from e
-
+            logger.error(
                 f'Failed to terminate the cluster {cluster_name}. Retrying.'
                 f'Details: {common_utils.format_exception(e)}')
             with ux_utils.enable_traceback():
-
+                logger.error(f'  Traceback: {traceback.format_exc()}')
             time.sleep(backoff.current_backoff())
 
 
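
For context, the retry shape used by terminate_cluster above is a bounded loop that sleeps for an increasing delay between attempts and re-raises once the budget is exhausted. Below is a generic, self-contained sketch of that pattern; the Backoff class is a simplified stand-in for the backoff helper the real code uses, not SkyPilot's implementation.

import time
from typing import Callable

class Backoff:
    """Simplified stand-in for the exponential-backoff helper used above."""

    def __init__(self, initial: float = 1.0, factor: float = 2.0, cap: float = 30.0):
        self._delay = initial
        self._factor = factor
        self._cap = cap

    def current_backoff(self) -> float:
        delay = self._delay
        self._delay = min(self._delay * self._factor, self._cap)
        return delay

def run_with_retries(operation: Callable[[], None], max_retry: int = 6) -> None:
    backoff = Backoff()
    retry_cnt = 0
    while True:
        try:
            operation()
            return
        except Exception as e:  # pylint: disable=broad-except
            retry_cnt += 1
            if retry_cnt >= max_retry:
                raise RuntimeError('Operation failed after retries.') from e
            # Sleep longer after each failed attempt before retrying.
            time.sleep(backoff.current_backoff())
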
@@ -202,13 +208,39 @@ def _validate_consolidation_mode_config(
 # API Server. Under the hood, we submit the job monitoring logic as processes
 # directly in the API Server.
 # Use LRU Cache so that the check is only done once.
-@annotations.lru_cache(scope='request', maxsize=
-def is_consolidation_mode() -> bool:
+@annotations.lru_cache(scope='request', maxsize=2)
+def is_consolidation_mode(on_api_restart: bool = False) -> bool:
     if os.environ.get(constants.OVERRIDE_CONSOLIDATION_MODE) is not None:
         return True
 
-
+    config_consolidation_mode = skypilot_config.get_nested(
         ('jobs', 'controller', 'consolidation_mode'), default_value=False)
+
+    signal_file = pathlib.Path(
+        _JOBS_CONSOLIDATION_RELOADED_SIGNAL_FILE).expanduser()
+
+    restart_signal_file_exists = signal_file.exists()
+    consolidation_mode = (config_consolidation_mode and
+                          restart_signal_file_exists)
+
+    if on_api_restart:
+        if config_consolidation_mode:
+            signal_file.touch()
+    else:
+        if not restart_signal_file_exists:
+            if config_consolidation_mode:
+                logger.warning(f'{colorama.Fore.YELLOW}Consolidation mode for '
+                               'managed jobs is enabled in the server config, '
+                               'but the API server has not been restarted yet. '
+                               'Please restart the API server to enable it.'
+                               f'{colorama.Style.RESET_ALL}')
+            return False
+        elif not config_consolidation_mode:
+            # Cleanup the signal file if the consolidation mode is disabled in
+            # the config. This allow the user to disable the consolidation mode
+            # without restarting the API server.
+            signal_file.unlink()
+
     # We should only do this check on API server, as the controller will not
     # have related config and will always seemingly disabled for consolidation
     # mode. Check #6611 for more details.
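
The signal-file logic above makes the consolidation_mode config flag take effect only after an API server restart: a restart with the flag set touches a marker file, and later checks require both the flag and the marker. A minimal standalone sketch of that gating follows, with a hypothetical consolidation_enabled helper in place of is_consolidation_mode.

import pathlib

# Same path as the signal file introduced above.
SIGNAL_FILE = pathlib.Path(
    '~/.sky/.jobs_controller_consolidation_reloaded_signal').expanduser()

def consolidation_enabled(config_flag: bool, on_api_restart: bool = False) -> bool:
    restarted = SIGNAL_FILE.exists()
    enabled = config_flag and restarted
    if on_api_restart:
        if config_flag:
            # Record that the server has been restarted with the flag set.
            SIGNAL_FILE.parent.mkdir(parents=True, exist_ok=True)
            SIGNAL_FILE.touch()
    elif not restarted:
        # Flag may be set in config, but the server has not restarted yet.
        return False
    elif not config_flag:
        # Flag was cleared: drop the marker so disabling needs no restart.
        SIGNAL_FILE.unlink()
    return enabled
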
@@ -269,8 +301,7 @@ def ha_recovery_for_consolidation_mode():
 
 async def get_job_status(
         backend: 'backends.CloudVmRayBackend', cluster_name: str,
-        job_id: Optional[int],
-        job_logger: logging.Logger) -> Optional['job_lib.JobStatus']:
+        job_id: Optional[int]) -> Optional['job_lib.JobStatus']:
     """Check the status of the job running on a managed job cluster.
 
     It can be None, INIT, RUNNING, SUCCEEDED, FAILED, FAILED_DRIVER,
@@ -282,26 +313,28 @@ async def get_job_status(
     if handle is None:
         # This can happen if the cluster was preempted and background status
        # refresh already noticed and cleaned it up.
-
+        logger.info(f'Cluster {cluster_name} not found.')
         return None
     assert isinstance(handle, backends.CloudVmRayResourceHandle), handle
     job_ids = None if job_id is None else [job_id]
     for i in range(_JOB_STATUS_FETCH_MAX_RETRIES):
         try:
-
-            statuses = await
-
-
-
+            logger.info('=== Checking the job status... ===')
+            statuses = await asyncio.wait_for(
+                context_utils.to_thread(backend.get_job_status,
+                                        handle,
+                                        job_ids=job_ids,
+                                        stream_logs=False),
+                timeout=_JOB_STATUS_FETCH_TIMEOUT_SECONDS)
             status = list(statuses.values())[0]
             if status is None:
-
+                logger.info('No job found.')
             else:
-
-
+                logger.info(f'Job status: {status}')
+                logger.info('=' * 34)
             return status
         except (exceptions.CommandError, grpc.RpcError, grpc.FutureTimeoutError,
-                ValueError, TypeError) as e:
+                ValueError, TypeError, asyncio.TimeoutError) as e:
             # Note: Each of these exceptions has some additional conditions to
             # limit how we handle it and whether or not we catch it.
             # Retry on k8s transient network errors. This is useful when using
@@ -322,6 +355,9 @@ async def get_job_status(
             is_transient_error = True
         elif isinstance(e, grpc.FutureTimeoutError):
             detailed_reason = 'Timeout'
+        elif isinstance(e, asyncio.TimeoutError):
+            detailed_reason = ('Job status check timed out after '
+                               f'{_JOB_STATUS_FETCH_TIMEOUT_SECONDS}s')
         # TODO(cooperc): Gracefully handle these exceptions in the backend.
         elif isinstance(e, ValueError):
             # If the cluster yaml is deleted in the middle of getting the
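
The new _JOB_STATUS_FETCH_TIMEOUT_SECONDS bound works because asyncio.wait_for raises asyncio.TimeoutError when the awaited call does not finish in time, and the blocking status fetch is first moved onto a worker thread. A self-contained sketch of the same pattern, using the stdlib asyncio.to_thread instead of SkyPilot's context_utils.to_thread and a made-up slow_status_fetch stand-in:

import asyncio
import time

def slow_status_fetch() -> str:
    # Stand-in for a blocking call such as backend.get_job_status(...).
    time.sleep(5)
    return 'RUNNING'

async def fetch_with_timeout(timeout: float = 1.0) -> str:
    try:
        # Run the blocking call in a worker thread and bound how long we wait.
        return await asyncio.wait_for(asyncio.to_thread(slow_status_fetch),
                                      timeout=timeout)
    except asyncio.TimeoutError:
        # Treated as a transient failure and retried in the code above.
        return 'UNKNOWN'

print(asyncio.run(fetch_with_timeout()))

Note that cancelling the wait does not stop the underlying call; the worker thread keeps running until the blocking function returns on its own.
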
sky/metrics/utils.py
CHANGED
@@ -48,8 +48,15 @@ SKY_APISERVER_CODE_DURATION_SECONDS = prom.Histogram(
     'sky_apiserver_code_duration_seconds',
     'Time spent processing code',
     ['name', 'group'],
-    buckets=(0.
-
+    buckets=(0.001, 0.005, 0.01, 0.025, 0.05, 0.075, 0.1, 0.125, 0.15, 0.25,
+             0.35, 0.5, 0.75, 1, 1.25, 1.5, 1.75, 2, 2.5, 2.75, 3, 3.5, 4, 4.5,
+             5, 7.5, 10.0, 12.5, 15.0, 17.5, 20.0, 25.0, 30.0, 35.0, 40.0, 45.0,
+             50.0, 55.0, 60.0, 80.0, 120.0, 140.0, 160.0, 180.0, 200.0, 220.0,
+             240.0, 260.0, 280.0, 300.0, 320.0, 340.0, 360.0, 380.0, 400.0,
+             420.0, 440.0, 460.0, 480.0, 500.0, 520.0, 540.0, 560.0, 580.0,
+             600.0, 620.0, 640.0, 660.0, 680.0, 700.0, 720.0, 740.0, 760.0,
+             780.0, 800.0, 820.0, 840.0, 860.0, 880.0, 900.0, 920.0, 940.0,
+             960.0, 980.0, 1000.0, float('inf')),
 )
 
 # Total number of API server requests, grouped by path, method, and status.
@@ -65,16 +72,30 @@ SKY_APISERVER_REQUEST_DURATION_SECONDS = prom.Histogram(
     'sky_apiserver_request_duration_seconds',
     'Time spent processing API server requests',
     ['path', 'method', 'status'],
-    buckets=(0.
-
+    buckets=(0.001, 0.005, 0.01, 0.025, 0.05, 0.075, 0.1, 0.125, 0.15, 0.25,
+             0.35, 0.5, 0.75, 1, 1.25, 1.5, 1.75, 2, 2.5, 2.75, 3, 3.5, 4, 4.5,
+             5, 7.5, 10.0, 12.5, 15.0, 17.5, 20.0, 25.0, 30.0, 35.0, 40.0, 45.0,
+             50.0, 55.0, 60.0, 80.0, 120.0, 140.0, 160.0, 180.0, 200.0, 220.0,
+             240.0, 260.0, 280.0, 300.0, 320.0, 340.0, 360.0, 380.0, 400.0,
+             420.0, 440.0, 460.0, 480.0, 500.0, 520.0, 540.0, 560.0, 580.0,
+             600.0, 620.0, 640.0, 660.0, 680.0, 700.0, 720.0, 740.0, 760.0,
+             780.0, 800.0, 820.0, 840.0, 860.0, 880.0, 900.0, 920.0, 940.0,
+             960.0, 980.0, 1000.0, float('inf')),
 )
 
 SKY_APISERVER_EVENT_LOOP_LAG_SECONDS = prom.Histogram(
     'sky_apiserver_event_loop_lag_seconds',
     'Scheduling delay of the server event loop',
     ['pid'],
-    buckets=(0.001, 0.005, 0.01, 0.025, 0.05, 0.
-
+    buckets=(0.001, 0.005, 0.01, 0.025, 0.05, 0.075, 0.1, 0.125, 0.15, 0.25,
+             0.35, 0.5, 0.75, 1, 1.25, 1.5, 1.75, 2, 2.5, 2.75, 3, 3.5, 4, 4.5,
+             5, 7.5, 10.0, 12.5, 15.0, 17.5, 20.0, 25.0, 30.0, 35.0, 40.0, 45.0,
+             50.0, 55.0, 60.0, 80.0, 120.0, 140.0, 160.0, 180.0, 200.0, 220.0,
+             240.0, 260.0, 280.0, 300.0, 320.0, 340.0, 360.0, 380.0, 400.0,
+             420.0, 440.0, 460.0, 480.0, 500.0, 520.0, 540.0, 560.0, 580.0,
+             600.0, 620.0, 640.0, 660.0, 680.0, 700.0, 720.0, 740.0, 760.0,
+             780.0, 800.0, 820.0, 840.0, 860.0, 880.0, 900.0, 920.0, 940.0,
+             960.0, 980.0, 1000.0, float('inf')),
 )
 
 SKY_APISERVER_WEBSOCKET_CONNECTIONS = prom.Gauge(
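
For reference, prometheus_client histograms take an explicit tuple of increasing upper bounds; buckets are cumulative, so each observation increments every bucket whose bound is at least the observed value, and the client appends a +Inf bucket if one is not supplied. A small sketch with a shorter bucket list and a made-up metric name:

import prometheus_client as prom

EXAMPLE_DURATION_SECONDS = prom.Histogram(
    'example_duration_seconds',
    'Time spent in an example operation',
    ['operation'],
    buckets=(0.005, 0.05, 0.5, 5.0, 60.0, 600.0, float('inf')),
)

# A 0.42s observation increments the 0.5s bucket and every larger one.
EXAMPLE_DURATION_SECONDS.labels(operation='demo').observe(0.42)
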
sky/provision/__init__.py
CHANGED
@@ -28,6 +28,7 @@ from sky.provision import primeintellect
 from sky.provision import runpod
 from sky.provision import scp
 from sky.provision import seeweb
+from sky.provision import shadeform
 from sky.provision import ssh
 from sky.provision import vast
 from sky.provision import vsphere

sky/provision/kubernetes/utils.py
CHANGED
@@ -1299,30 +1299,52 @@ class V1Pod:
 
 
 @_retry_on_error(resource_type='pod')
-def
-
-
-
-
-
+def get_allocated_gpu_qty_by_node(
+    *,
+    context: Optional[str] = None,
+) -> Dict[str, int]:
+    """Gets allocated GPU quantity by each node by fetching pods in
+    all namespaces in kubernetes cluster indicated by context.
     """
     if context is None:
         context = get_current_kube_config_context_name()
+    non_included_pod_statuses = POD_STATUSES.copy()
+    status_filters = ['Running', 'Pending']
+    if status_filters is not None:
+        non_included_pod_statuses -= set(status_filters)
+    field_selector = ','.join(
+        [f'status.phase!={status}' for status in non_included_pod_statuses])
 
     # Return raw urllib3.HTTPResponse object so that we can parse the json
     # more efficiently.
     response = kubernetes.core_api(context).list_pod_for_all_namespaces(
-        _request_timeout=kubernetes.API_TIMEOUT,
+        _request_timeout=kubernetes.API_TIMEOUT,
+        _preload_content=False,
+        field_selector=field_selector)
     try:
-
-
-
-
+        allocated_qty_by_node: Dict[str, int] = collections.defaultdict(int)
+        for item_dict in ijson.items(response,
+                                     'items.item',
+                                     buf_size=IJSON_BUFFER_SIZE):
+            pod = V1Pod.from_dict(item_dict)
+            if should_exclude_pod_from_gpu_allocation(pod):
+                logger.debug(
+                    f'Excluding pod {pod.metadata.name} from GPU count '
+                    f'calculations on node {pod.spec.node_name}')
+                continue
+            # Iterate over all the containers in the pod and sum the
+            # GPU requests
+            pod_allocated_qty = 0
+            for container in pod.spec.containers:
+                if container.resources.requests:
+                    pod_allocated_qty += get_node_accelerator_count(
+                        context, container.resources.requests)
+            if pod_allocated_qty > 0 and pod.spec.node_name:
+                allocated_qty_by_node[pod.spec.node_name] += pod_allocated_qty
+        return allocated_qty_by_node
     finally:
         response.release_conn()
 
-    return pods
-
 
 def check_instance_fits(context: Optional[str],
                         instance: str) -> Tuple[bool, Optional[str]]:
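
The new get_allocated_gpu_qty_by_node builds a Kubernetes field selector that filters out every pod phase other than Running and Pending, and keeps the response unparsed (_preload_content=False) so ijson can stream pods one at a time instead of materializing the whole list. A rough standalone sketch of the selector construction; POD_STATUSES here is a stand-in for the module's constant:

# Stand-in for the module-level set of all Kubernetes pod phases.
POD_STATUSES = {'Pending', 'Running', 'Succeeded', 'Failed', 'Unknown'}

def build_field_selector(keep=('Running', 'Pending')) -> str:
    excluded = POD_STATUSES - set(keep)
    # e.g. 'status.phase!=Failed,status.phase!=Succeeded,status.phase!=Unknown'
    return ','.join(f'status.phase!={phase}' for phase in sorted(excluded))

print(build_field_selector())
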
@@ -3006,41 +3028,24 @@ def get_kubernetes_node_info(
     label_keys = lf.get_label_keys()
 
     # Check if all nodes have no accelerators to avoid fetching pods
-
+    has_accelerator_nodes = False
     for node in nodes:
         accelerator_count = get_node_accelerator_count(context,
                                                        node.status.allocatable)
         if accelerator_count > 0:
-
+            has_accelerator_nodes = True
             break
 
-    # Get the
-    pods = None
+    # Get the allocated GPU quantity by each node
     allocated_qty_by_node: Dict[str, int] = collections.defaultdict(int)
-
+    error_on_get_allocated_gpu_qty_by_node = False
+    if has_accelerator_nodes:
         try:
-
-
-            for pod in pods:
-                if pod.status.phase in ['Running', 'Pending']:
-                    # Skip pods that should not count against GPU count
-                    if should_exclude_pod_from_gpu_allocation(pod):
-                        logger.debug(f'Excluding low priority pod '
-                                     f'{pod.metadata.name} from GPU allocation '
-                                     f'calculations')
-                        continue
-                    # Iterate over all the containers in the pod and sum the
-                    # GPU requests
-                    pod_allocated_qty = 0
-                    for container in pod.spec.containers:
-                        if container.resources.requests:
-                            pod_allocated_qty += get_node_accelerator_count(
-                                context, container.resources.requests)
-                    if pod_allocated_qty > 0:
-                        allocated_qty_by_node[
-                            pod.spec.node_name] += pod_allocated_qty
+            allocated_qty_by_node = get_allocated_gpu_qty_by_node(
+                context=context)
         except kubernetes.api_exception() as e:
             if e.status == 403:
+                error_on_get_allocated_gpu_qty_by_node = True
                 pass
             else:
                 raise
@@ -3085,7 +3090,7 @@ def get_kubernetes_node_info(
                 ip_address=node_ip)
             continue
 
-        if
+        if not has_accelerator_nodes or error_on_get_allocated_gpu_qty_by_node:
             accelerators_available = -1
         else:
             allocated_qty = allocated_qty_by_node[node.metadata.name]

sky/provision/shadeform/__init__.py
ADDED
@@ -0,0 +1,11 @@
+"""Shadeform provisioner."""
+
+from sky.provision.shadeform.config import bootstrap_instances
+from sky.provision.shadeform.instance import cleanup_ports
+from sky.provision.shadeform.instance import get_cluster_info
+from sky.provision.shadeform.instance import open_ports
+from sky.provision.shadeform.instance import query_instances
+from sky.provision.shadeform.instance import run_instances
+from sky.provision.shadeform.instance import stop_instances
+from sky.provision.shadeform.instance import terminate_instances
+from sky.provision.shadeform.instance import wait_instances
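
The new shadeform package follows the convention visible in the other sky/provision entries: each provider exposes a fixed set of module-level hooks (bootstrap_instances, run_instances, wait_instances, stop_instances, terminate_instances, open_ports, cleanup_ports, query_instances, get_cluster_info), and the framework dispatches to them by provider name. A simplified, illustrative sketch of that kind of dispatch, not the actual SkyPilot routing code:

import importlib
from typing import Any

def call_provider(provider: str, fn_name: str, *args: Any, **kwargs: Any) -> Any:
    # Resolve e.g. sky.provision.shadeform and look up the requested hook.
    module = importlib.import_module(f'sky.provision.{provider.lower()}')
    fn = getattr(module, fn_name)
    return fn(*args, **kwargs)

# Hypothetical call site:
# call_provider('shadeform', 'bootstrap_instances', region, cluster_name, config)
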

sky/provision/shadeform/config.py
ADDED
@@ -0,0 +1,12 @@
+"""Shadeform configuration bootstrapping."""
+
+from sky.provision import common
+
+
+def bootstrap_instances(
+        region: str, cluster_name: str,
+        config: common.ProvisionConfig) -> common.ProvisionConfig:
+    """Bootstraps instances for the given cluster."""
+    del region, cluster_name  # unused
+
+    return config