skypilot-nightly 1.0.0.dev20250909__py3-none-any.whl → 1.0.0.dev20250910__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sky/__init__.py +2 -2
- sky/authentication.py +19 -4
- sky/backends/backend_utils.py +35 -1
- sky/backends/cloud_vm_ray_backend.py +2 -2
- sky/client/sdk.py +20 -0
- sky/client/sdk_async.py +18 -16
- sky/clouds/aws.py +3 -1
- sky/dashboard/out/404.html +1 -1
- sky/dashboard/out/_next/static/chunks/{webpack-d4fabc08788e14af.js → webpack-1d7e11230da3ca89.js} +1 -1
- sky/dashboard/out/clusters/[cluster]/[job].html +1 -1
- sky/dashboard/out/clusters/[cluster].html +1 -1
- sky/dashboard/out/clusters.html +1 -1
- sky/dashboard/out/config.html +1 -1
- sky/dashboard/out/index.html +1 -1
- sky/dashboard/out/infra/[context].html +1 -1
- sky/dashboard/out/infra.html +1 -1
- sky/dashboard/out/jobs/[job].html +1 -1
- sky/dashboard/out/jobs/pools/[pool].html +1 -1
- sky/dashboard/out/jobs.html +1 -1
- sky/dashboard/out/users.html +1 -1
- sky/dashboard/out/volumes.html +1 -1
- sky/dashboard/out/workspace/new.html +1 -1
- sky/dashboard/out/workspaces/[name].html +1 -1
- sky/dashboard/out/workspaces.html +1 -1
- sky/data/storage.py +5 -1
- sky/execution.py +21 -14
- sky/jobs/constants.py +3 -0
- sky/jobs/controller.py +732 -310
- sky/jobs/recovery_strategy.py +251 -129
- sky/jobs/scheduler.py +247 -174
- sky/jobs/server/core.py +20 -4
- sky/jobs/server/utils.py +2 -2
- sky/jobs/state.py +702 -511
- sky/jobs/utils.py +94 -39
- sky/provision/aws/config.py +4 -1
- sky/provision/gcp/config.py +6 -1
- sky/provision/kubernetes/utils.py +17 -8
- sky/provision/provisioner.py +1 -0
- sky/serve/replica_managers.py +0 -7
- sky/serve/serve_utils.py +5 -0
- sky/serve/server/impl.py +1 -2
- sky/serve/service.py +0 -2
- sky/server/common.py +8 -3
- sky/server/config.py +43 -24
- sky/server/constants.py +1 -0
- sky/server/daemons.py +7 -11
- sky/server/requests/serializers/encoders.py +1 -1
- sky/server/server.py +8 -1
- sky/setup_files/dependencies.py +4 -2
- sky/skylet/attempt_skylet.py +1 -0
- sky/skylet/constants.py +3 -1
- sky/skylet/events.py +2 -10
- sky/utils/command_runner.pyi +3 -3
- sky/utils/common_utils.py +11 -1
- sky/utils/controller_utils.py +5 -0
- sky/utils/db/db_utils.py +31 -2
- sky/utils/rich_utils.py +3 -1
- sky/utils/subprocess_utils.py +9 -0
- sky/volumes/volume.py +2 -0
- {skypilot_nightly-1.0.0.dev20250909.dist-info → skypilot_nightly-1.0.0.dev20250910.dist-info}/METADATA +39 -37
- {skypilot_nightly-1.0.0.dev20250909.dist-info → skypilot_nightly-1.0.0.dev20250910.dist-info}/RECORD +67 -67
- /sky/dashboard/out/_next/static/{eWytLgin5zvayQw3Xk46m → 3SYxqNGnvvPS8h3gdD2T7}/_buildManifest.js +0 -0
- /sky/dashboard/out/_next/static/{eWytLgin5zvayQw3Xk46m → 3SYxqNGnvvPS8h3gdD2T7}/_ssgManifest.js +0 -0
- {skypilot_nightly-1.0.0.dev20250909.dist-info → skypilot_nightly-1.0.0.dev20250910.dist-info}/WHEEL +0 -0
- {skypilot_nightly-1.0.0.dev20250909.dist-info → skypilot_nightly-1.0.0.dev20250910.dist-info}/entry_points.txt +0 -0
- {skypilot_nightly-1.0.0.dev20250909.dist-info → skypilot_nightly-1.0.0.dev20250910.dist-info}/licenses/LICENSE +0 -0
- {skypilot_nightly-1.0.0.dev20250909.dist-info → skypilot_nightly-1.0.0.dev20250910.dist-info}/top_level.txt +0 -0
sky/jobs/utils.py
CHANGED

@@ -4,9 +4,11 @@ NOTE: whenever an API change is made in this file, we need to bump the
 jobs.constants.MANAGED_JOBS_VERSION and handle the API change in the
 ManagedJobCodeGen.
 """
+import asyncio
 import collections
 import datetime
 import enum
+import logging
 import os
 import pathlib
 import shlex
@@ -14,11 +16,11 @@ import textwrap
 import time
 import traceback
 import typing
-from typing import Any, Deque, Dict, List, Optional, Set, TextIO, Tuple, Union
+from typing import (Any, Deque, Dict, List, Literal, Optional, Set, TextIO,
+                    Tuple, Union)
 
 import colorama
 import filelock
-from typing_extensions import Literal
 
 from sky import backends
 from sky import exceptions
@@ -37,6 +39,7 @@ from sky.usage import usage_lib
 from sky.utils import annotations
 from sky.utils import command_runner
 from sky.utils import common_utils
+from sky.utils import context_utils
 from sky.utils import controller_utils
 from sky.utils import infra_utils
 from sky.utils import log_utils
@@ -56,9 +59,9 @@ else:
 
 logger = sky_logging.init_logger(__name__)
 
-SIGNAL_FILE_PREFIX = '/tmp/sky_jobs_controller_signal_{}'
 # Controller checks its job's status every this many seconds.
-JOB_STATUS_CHECK_GAP_SECONDS =
+# This is a tradeoff between the latency and the resource usage.
+JOB_STATUS_CHECK_GAP_SECONDS = 15
 
 # Controller checks if its job has started every this many seconds.
 JOB_STARTED_STATUS_CHECK_GAP_SECONDS = 5
@@ -82,7 +85,7 @@ _JOB_CANCELLED_MESSAGE = (
 # blocking for a long time. This should be significantly longer than the
 # JOB_STATUS_CHECK_GAP_SECONDS to avoid timing out before the controller can
 # update the state.
-_FINAL_JOB_STATUS_WAIT_TIMEOUT_SECONDS =
+_FINAL_JOB_STATUS_WAIT_TIMEOUT_SECONDS = 120
 
 
 class ManagedJobQueueResultType(enum.Enum):
@@ -99,7 +102,11 @@ class UserSignal(enum.Enum):
 
 
 # ====== internal functions ======
-def terminate_cluster(cluster_name: str, max_retry: int = 6) -> None:
+def terminate_cluster(
+    cluster_name: str,
+    max_retry: int = 6,
+    _logger: logging.Logger = logger,  # pylint: disable=invalid-name
+) -> None:
     """Terminate the cluster."""
     from sky import core  # pylint: disable=import-outside-toplevel
     retry_cnt = 0
@@ -122,18 +129,18 @@ def terminate_cluster(cluster_name: str, max_retry: int = 6) -> None:
             return
         except exceptions.ClusterDoesNotExist:
             # The cluster is already down.
-            logger.debug(f'The cluster {cluster_name} is already down.')
+            _logger.debug(f'The cluster {cluster_name} is already down.')
             return
         except Exception as e:  # pylint: disable=broad-except
             retry_cnt += 1
             if retry_cnt >= max_retry:
                 raise RuntimeError(
                     f'Failed to terminate the cluster {cluster_name}.') from e
-            logger.error(
+            _logger.error(
                 f'Failed to terminate the cluster {cluster_name}. Retrying.'
                 f'Details: {common_utils.format_exception(e)}')
             with ux_utils.enable_traceback():
-                logger.error(f'  Traceback: {traceback.format_exc()}')
+                _logger.error(f'  Traceback: {traceback.format_exc()}')
             time.sleep(backoff.current_backoff())
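The new `_logger` parameter lets each job controller thread its own logger through cluster termination, so retry noise lands in that job's log rather than in the shared module log. A minimal sketch of the injectable-logger pattern; `make_job_logger` and the log path are hypothetical, not taken from this diff:

```python
# Sketch: build a per-job logger and pass it to helpers that default to the
# module-level logger. Names here are illustrative only.
import logging


def make_job_logger(job_id: int, log_path: str) -> logging.Logger:
    job_logger = logging.getLogger(f'sky.jobs.controller.{job_id}')
    handler = logging.FileHandler(log_path)
    handler.setFormatter(logging.Formatter('%(asctime)s %(message)s'))
    job_logger.addHandler(handler)
    job_logger.setLevel(logging.INFO)
    return job_logger


def do_work(_logger: logging.Logger = logging.getLogger(__name__)) -> None:
    # Defaults to the module logger, but callers can inject their own.
    _logger.info('working...')


do_work(make_job_logger(1, '/tmp/job-1.log'))
```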
@@ -183,6 +190,9 @@ def _validate_consolidation_mode_config(
 # Use LRU Cache so that the check is only done once.
 @annotations.lru_cache(scope='request', maxsize=1)
 def is_consolidation_mode() -> bool:
+    if os.environ.get(constants.OVERRIDE_CONSOLIDATION_MODE) is not None:
+        return True
+
     consolidation_mode = skypilot_config.get_nested(
         ('jobs', 'controller', 'consolidation_mode'), default_value=False)
     # We should only do this check on API server, as the controller will not
@@ -199,6 +209,7 @@ def ha_recovery_for_consolidation_mode():
     # already has all runtime installed. Directly start jobs recovery here.
     # Refers to sky/templates/kubernetes-ray.yml.j2 for more details.
     runner = command_runner.LocalProcessCommandRunner()
+    scheduler.maybe_start_controllers()
     with open(constants.HA_PERSISTENT_RECOVERY_LOG_PATH.format('jobs_'),
               'w',
               encoding='utf-8') as f:
@@ -214,7 +225,7 @@ def ha_recovery_for_consolidation_mode():
             # just keep running.
             if controller_pid is not None:
                 try:
-                    if _controller_process_alive(controller_pid, job_id):
+                    if controller_process_alive(controller_pid, job_id):
                         f.write(f'Controller pid {controller_pid} for '
                                 f'job {job_id} is still running. '
                                 'Skipping recovery.\n')
@@ -227,7 +238,7 @@ def ha_recovery_for_consolidation_mode():
 
             if job['schedule_state'] not in [
                     managed_job_state.ManagedJobScheduleState.DONE,
-                    managed_job_state.ManagedJobScheduleState.WAITING
+                    managed_job_state.ManagedJobScheduleState.WAITING,
             ]:
                 script = managed_job_state.get_ha_recovery_script(job_id)
                 if script is None:
@@ -242,56 +253,66 @@ def ha_recovery_for_consolidation_mode():
         f.write(f'Total recovery time: {time.time() - start} seconds\n')
 
 
-def get_job_status(backend: 'backends.CloudVmRayBackend', cluster_name: str,
-                   job_id: Optional[int]) -> Optional['job_lib.JobStatus']:
+async def get_job_status(
+        backend: 'backends.CloudVmRayBackend', cluster_name: str,
+        job_id: Optional[int],
+        job_logger: logging.Logger) -> Optional['job_lib.JobStatus']:
     """Check the status of the job running on a managed job cluster.
 
     It can be None, INIT, RUNNING, SUCCEEDED, FAILED, FAILED_DRIVER,
     FAILED_SETUP or CANCELLED.
     """
-    handle = global_user_state.get_handle_from_cluster_name(cluster_name)
+    # TODO(luca) make this async
+    handle = await context_utils.to_thread(
+        global_user_state.get_handle_from_cluster_name, cluster_name)
     if handle is None:
         # This can happen if the cluster was preempted and background status
         # refresh already noticed and cleaned it up.
-        logger.info(f'Cluster {cluster_name} not found.')
+        job_logger.info(f'Cluster {cluster_name} not found.')
         return None
     assert isinstance(handle, backends.CloudVmRayResourceHandle), handle
     job_ids = None if job_id is None else [job_id]
     for i in range(_JOB_STATUS_FETCH_MAX_RETRIES):
         try:
-            logger.info('=== Checking the job status... ===')
-            statuses = backend.get_job_status(handle,
-                                              job_ids=job_ids,
-                                              stream_logs=False)
+            job_logger.info('=== Checking the job status... ===')
+            statuses = await context_utils.to_thread(backend.get_job_status,
+                                                     handle,
+                                                     job_ids=job_ids,
+                                                     stream_logs=False)
             status = list(statuses.values())[0]
             if status is None:
-                logger.info('No job found.')
+                job_logger.info('No job found.')
             else:
-                logger.info(f'Job status: {status}')
-                logger.info('=' * 34)
+                job_logger.info(f'Job status: {status}')
+                job_logger.info('=' * 34)
             return status
         except exceptions.CommandError as e:
             # Retry on k8s transient network errors. This is useful when using
             # coreweave which may have transient network issue sometimes.
             if (e.detailed_reason is not None and
                     _JOB_K8S_TRANSIENT_NW_MSG in e.detailed_reason):
-                logger.info('Failed to connect to the cluster. Retrying '
-                            f'({i + 1}/{_JOB_STATUS_FETCH_MAX_RETRIES})...')
-                logger.info('=' * 34)
-                time.sleep(1)
+                job_logger.info('Failed to connect to the cluster. Retrying '
+                                f'({i + 1}/{_JOB_STATUS_FETCH_MAX_RETRIES})...')
+                job_logger.info('=' * 34)
+                await asyncio.sleep(1)
             else:
-                logger.info(f'Failed to get job status: {e.detailed_reason}')
-                logger.info('=' * 34)
+                job_logger.info(
+                    f'Failed to get job status: {e.detailed_reason}')
+                job_logger.info('=' * 34)
                 return None
     return None
 
 
-def _controller_process_alive(pid: int, job_id: int) -> bool:
+def controller_process_alive(pid: int, job_id: int) -> bool:
     """Check if the controller process is alive."""
     try:
+        if pid < 0:
+            # new job controller process will always be negative
+            pid = -pid
         process = psutil.Process(pid)
         cmd_str = ' '.join(process.cmdline())
-        return process.is_running() and f'--job-id {job_id}' in cmd_str
+        return process.is_running() and ((f'--job-id {job_id}' in cmd_str) or
+                                         ('controller' in cmd_str))
     except psutil.NoSuchProcess:
         return False
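`get_job_status` is now a coroutine: the blocking state lookup and the `backend.get_job_status` call are pushed onto worker threads via `context_utils.to_thread`, and the retry sleep becomes `await asyncio.sleep(1)`. A sketch of the offloading idea, assuming `context_utils.to_thread` behaves like the standard `asyncio.to_thread`:

```python
# Sketch: blocking status queries run in worker threads so a single event
# loop can poll many jobs concurrently.
import asyncio
import time


def blocking_status_query(cluster_name: str) -> str:
    time.sleep(1)  # stand-in for a slow RPC against the cluster
    return 'RUNNING'


async def poll(cluster_name: str) -> str:
    # The event loop stays responsive while the query runs in a thread.
    return await asyncio.to_thread(blocking_status_query, cluster_name)


async def main() -> None:
    statuses = await asyncio.gather(*(poll(f'cluster-{i}') for i in range(10)))
    print(statuses)  # ten results in about one second, not ten


asyncio.run(main())
```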
@@ -466,7 +487,7 @@ def update_managed_jobs_statuses(job_id: Optional[int] = None):
             failure_reason = f'No controller pid set for {schedule_state.value}'
         else:
             logger.debug(f'Checking controller pid {pid}')
-            if _controller_process_alive(pid, job_id):
+            if controller_process_alive(pid, job_id):
                 # The controller is still running, so this job is fine.
                 continue
@@ -604,7 +625,17 @@ def event_callback_func(job_id: int, task_id: int, task: 'sky.Task'):
                     f'Bash:{event_callback},log_path:{log_path},result:{result}')
         logger.info(f'=== END: event callback for {status!r} ===')
 
-    return callback_func
+    try:
+        asyncio.get_running_loop()
+
+        # In async context
+        async def async_callback_func(status: str):
+            return await context_utils.to_thread(callback_func, status)
+
+        return async_callback_func
+    except RuntimeError:
+        # Not in async context
+        return callback_func
 
 
 # ======== user functions ========
@@ -651,16 +682,41 @@ def cancel_jobs_by_id(job_ids: Optional[List[int]],
             logger.info(f'Job {job_id} is already in terminal state '
                         f'{job_status.value}. Skipped.')
             continue
+        elif job_status == managed_job_state.ManagedJobStatus.PENDING:
+            # the if is a short circuit, this will be atomic.
+            cancelled = managed_job_state.set_pending_cancelled(job_id)
+            if cancelled:
+                cancelled_job_ids.append(job_id)
+                continue
 
         update_managed_jobs_statuses(job_id)
 
+        job_controller_pid = managed_job_state.get_job_controller_pid(job_id)
+        if job_controller_pid is not None and job_controller_pid < 0:
+            # This is a consolidated job controller, so we need to cancel the
+            # job with the controller server API.
+            try:
+                # we create a file as a signal to the controller server
+                signal_file = pathlib.Path(
+                    managed_job_constants.CONSOLIDATED_SIGNAL_PATH, f'{job_id}')
+                signal_file.touch()
+                cancelled_job_ids.append(job_id)
+            except OSError as e:
+                logger.error(f'Failed to cancel job {job_id} '
+                             f'with controller server: {e}')
+                # don't add it to the to be cancelled job ids, since we don't
+                # know for sure yet.
+                continue
+            continue
+
         job_workspace = managed_job_state.get_workspace(job_id)
         if current_workspace is not None and job_workspace != current_workspace:
             wrong_workspace_job_ids.append(job_id)
             continue
 
         # Send the signal to the jobs controller.
-        signal_file = pathlib.Path(SIGNAL_FILE_PREFIX.format(job_id))
+        signal_file = (pathlib.Path(
+            managed_job_constants.SIGNAL_FILE_PREFIX.format(job_id)))
         # Filelock is needed to prevent race condition between signal
         # check/removal and signal writing.
         with filelock.FileLock(str(signal_file) + '.lock'):
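Consolidated controllers are recorded with a negative pid, and cancelling such a job is done by touching a per-job signal file that the controller server watches, rather than signalling a process. A hypothetical sketch of both halves of that protocol (the directory name is made up; the real path is `managed_job_constants.CONSOLIDATED_SIGNAL_PATH`):

```python
# Sketch: the client touches <signal_dir>/<job_id>; the controller polls for
# the file and consumes it. The path below is illustrative only.
import pathlib

SIGNAL_DIR = pathlib.Path('/tmp/managed_jobs_signals')  # made-up path


def request_cancel(job_id: int) -> None:
    SIGNAL_DIR.mkdir(parents=True, exist_ok=True)
    (SIGNAL_DIR / str(job_id)).touch()


def check_cancelled(job_id: int) -> bool:
    signal_file = SIGNAL_DIR / str(job_id)
    if signal_file.exists():
        signal_file.unlink()  # consume the signal exactly once
        return True
    return False
```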
@@ -1159,8 +1215,7 @@ def dump_managed_job_queue(
                 # It's possible for a WAITING/ALIVE_WAITING job to be ready to
                 # launch, but the scheduler just hasn't run yet.
                 managed_job_state.ManagedJobScheduleState.WAITING,
-                managed_job_state.ManagedJobScheduleState.ALIVE_WAITING
-        ):
+                managed_job_state.ManagedJobScheduleState.ALIVE_WAITING):
             # This job will not block others.
             continue
@@ -1370,12 +1425,12 @@ def load_managed_job_queue(
     """Load job queue from json string."""
     result = message_utils.decode_payload(payload)
     result_type = ManagedJobQueueResultType.DICT
-    status_counts = {}
+    status_counts: Dict[str, int] = {}
     if isinstance(result, dict):
-        jobs = result['jobs']
-        total = result['total']
+        jobs: List[Dict[str, Any]] = result['jobs']
+        total: int = result['total']
         status_counts = result.get('status_counts', {})
-        total_no_filter = result.get('total_no_filter', total)
+        total_no_filter: int = result.get('total_no_filter', total)
     else:
         jobs = result
         total = len(jobs)
sky/provision/aws/config.py
CHANGED

@@ -305,7 +305,10 @@ def _get_route_tables(ec2: 'mypy_boto3_ec2.ServiceResource',
     Returns:
         A list of route tables associated with the options VPC and region
     """
-    filters = [{'Name': 'association.main', 'Values': [str(main).lower()]}]
+    filters: List['ec2_type_defs.FilterTypeDef'] = [{
+        'Name': 'association.main',
+        'Values': [str(main).lower()],
+    }]
     if vpc_id is not None:
         filters.append({'Name': 'vpc-id', 'Values': [vpc_id]})
     logger.debug(
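The `filters` list gains a `FilterTypeDef` annotation from the mypy-boto3-ec2 stubs, so type checkers can validate the filter shape before it ever reaches EC2. A sketch of how such typed filters are used, assuming boto3 and the stubs are installed; at runtime the annotation is just a list of plain dicts:

```python
# Sketch: typed EC2 filters with the mypy-boto3-ec2 stubs.
from typing import TYPE_CHECKING, List

import boto3

if TYPE_CHECKING:
    from mypy_boto3_ec2.type_defs import FilterTypeDef


def main_route_tables(region: str, vpc_id: str):
    ec2 = boto3.client('ec2', region_name=region)
    filters: List['FilterTypeDef'] = [
        {'Name': 'association.main', 'Values': ['true']},
        {'Name': 'vpc-id', 'Values': [vpc_id]},
    ]
    # describe_route_tables takes the same Filters shape as most EC2 calls.
    return ec2.describe_route_tables(Filters=filters)['RouteTables']
```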
sky/provision/gcp/config.py
CHANGED

@@ -5,6 +5,8 @@ import time
 import typing
 from typing import Any, Dict, List, Set, Tuple
 
+from typing_extensions import TypedDict
+
 from sky.adaptors import gcp
 from sky.clouds.utils import gcp_utils
 from sky.provision import common

@@ -415,6 +417,9 @@ def _configure_iam_role(config: common.ProvisionConfig, crm, iam) -> dict:
     return iam_role
 
 
+AllowedList = TypedDict('AllowedList', {'IPProtocol': str, 'ports': List[str]})
+
+
 def _check_firewall_rules(cluster_name: str, vpc_name: str, project_id: str,
                           compute):
     """Check if the firewall rules in the VPC are sufficient."""

@@ -466,7 +471,7 @@ def _check_firewall_rules(cluster_name: str, vpc_name: str, project_id: str,
     }
     """
     source2rules: Dict[Tuple[str, str], Dict[str, Set[int]]] = {}
-    source2allowed_list: Dict[Tuple[str, str], List[
+    source2allowed_list: Dict[Tuple[str, str], List[AllowedList]] = {}
     for rule in rules:
         # Rules applied to specific VM (targetTags) may not work for the
         # current VM, so should be skipped.
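`AllowedList` uses the functional `TypedDict` form to describe one entry of a GCP firewall rule's `allowed` list. A small sketch of what that declaration buys; the class-based form would work too since both keys are valid identifiers, but the functional form keeps the declaration to one line:

```python
# Sketch: the checker can now verify dict literals shaped like a GCP
# firewall rule's `allowed` entries.
from typing import List

from typing_extensions import TypedDict

AllowedList = TypedDict('AllowedList', {'IPProtocol': str, 'ports': List[str]})

rule: AllowedList = {'IPProtocol': 'tcp', 'ports': ['22', '443']}
# rule_bad: AllowedList = {'IPProtocol': 'tcp', 'ports': 22}   # mypy error
```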
sky/provision/kubernetes/utils.py
CHANGED

@@ -451,6 +451,9 @@ class CoreWeaveLabelFormatter(GPULabelFormatter):
 
     LABEL_KEY = 'gpu.nvidia.com/class'
 
+    # TODO (kyuds): fill in more label values for different accelerators.
+    ACC_VALUE_MAPPINGS = {'H100_NVLINK_80GB': 'H100'}
+
     @classmethod
     def get_label_key(cls, accelerator: Optional[str] = None) -> str:
         return cls.LABEL_KEY

@@ -469,7 +472,8 @@ class CoreWeaveLabelFormatter(GPULabelFormatter):
 
     @classmethod
     def get_accelerator_from_label_value(cls, value: str) -> str:
-        return value
+        # return original label value if not found in mappings.
+        return cls.ACC_VALUE_MAPPINGS.get(value, value)
 
 
 class GKELabelFormatter(GPULabelFormatter):

@@ -1012,15 +1016,16 @@ class GKEAutoscaler(Autoscaler):
         to fit the instance type.
         """
         for accelerator in node_pool_accelerators:
+            raw_value = accelerator['acceleratorType']
             node_accelerator_type = (
-                GKELabelFormatter.get_accelerator_from_label_value(
-                    accelerator['acceleratorType']))
+                GKELabelFormatter.get_accelerator_from_label_value(raw_value))
             # handle heterogenous nodes.
             if not node_accelerator_type:
                 continue
             node_accelerator_count = accelerator['acceleratorCount']
-            if (requested_gpu_type.lower() == node_accelerator_type.lower() and
-                    int(node_accelerator_count) >= requested_gpu_count):
+            viable_names = [node_accelerator_type.lower(), raw_value.lower()]
+            if (requested_gpu_type.lower() in viable_names and
+                    int(node_accelerator_count) >= requested_gpu_count):
                 return True
         return False

@@ -1448,9 +1453,13 @@ def get_accelerator_label_key_values(
                 if is_multi_host_tpu(node_metadata_labels):
                     continue
                 for label, value in label_list:
-                    if (label_formatter.match_label_key(label) and
-                            label_formatter.get_accelerator_from_label_value(
-                                value) == acc_type):
+                    if label_formatter.match_label_key(label):
+                        # match either canonicalized name or raw name
+                        accelerator = (label_formatter.
+                                       get_accelerator_from_label_value(value))
+                        viable = [value.lower(), accelerator.lower()]
+                        if acc_type.lower() not in viable:
+                            continue
                         if is_tpu_on_gke(acc_type):
                             assert isinstance(label_formatter,
                                               GKELabelFormatter)
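CoreWeave publishes raw label values like `H100_NVLINK_80GB`, which `ACC_VALUE_MAPPINGS` canonicalizes to `H100`; the matching code then accepts either spelling. A standalone sketch of that two-way matching:

```python
# Sketch: a request matches a node if it names either the canonical
# accelerator or the raw label value.
ACC_VALUE_MAPPINGS = {'H100_NVLINK_80GB': 'H100'}


def canonicalize(value: str) -> str:
    # Fall back to the original label value when no mapping exists.
    return ACC_VALUE_MAPPINGS.get(value, value)


def matches(requested: str, raw_label_value: str) -> bool:
    viable = {raw_label_value.lower(), canonicalize(raw_label_value).lower()}
    return requested.lower() in viable


assert matches('H100', 'H100_NVLINK_80GB')
assert matches('h100_nvlink_80gb', 'H100_NVLINK_80GB')
```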
sky/provision/provisioner.py
CHANGED
sky/serve/replica_managers.py
CHANGED

@@ -22,7 +22,6 @@ from sky import global_user_state
 from sky import sky_logging
 from sky import task as task_lib
 from sky.backends import backend_utils
-from sky.jobs import scheduler as jobs_scheduler
 from sky.serve import constants as serve_constants
 from sky.serve import serve_state
 from sky.serve import serve_utils

@@ -1052,7 +1051,6 @@ class SkyPilotReplicaManager(ReplicaManager):
                 self._service_name, replica_id)
             assert info is not None, replica_id
             error_in_sky_launch = False
-            schedule_next_jobs = False
             if info.status == serve_state.ReplicaStatus.PENDING:
                 # sky.launch not started yet
                 if controller_utils.can_provision():

@@ -1080,7 +1078,6 @@ class SkyPilotReplicaManager(ReplicaManager):
                 else:
                     info.status_property.sky_launch_status = (
                         common_utils.ProcessStatus.SUCCEEDED)
-                    schedule_next_jobs = True
             if self._spot_placer is not None and info.is_spot:
                 # TODO(tian): Currently, we set the location to
                 # preemptive if the launch process failed. This is

@@ -1100,16 +1097,12 @@ class SkyPilotReplicaManager(ReplicaManager):
                     self._spot_placer.set_active(location)
             serve_state.add_or_update_replica(self._service_name,
                                               replica_id, info)
-            if schedule_next_jobs and self._is_pool:
-                jobs_scheduler.maybe_schedule_next_jobs()
             if error_in_sky_launch:
                 # Teardown after update replica info since
                 # _terminate_replica will update the replica info too.
                 self._terminate_replica(replica_id,
                                         sync_down_logs=True,
                                         replica_drain_delay_seconds=0)
-            # Try schedule next job after acquiring the lock.
-            jobs_scheduler.maybe_schedule_next_jobs()
         down_process_pool_snapshot = list(self._down_process_pool.items())
         for replica_id, p in down_process_pool_snapshot:
             if p.is_alive():
sky/serve/serve_utils.py
CHANGED

@@ -294,6 +294,11 @@ def is_consolidation_mode(pool: bool = False) -> bool:
     # We should only do this check on API server, as the controller will not
     # have related config and will always seemingly disabled for consolidation
     # mode. Check #6611 for more details.
+    if (os.environ.get(skylet_constants.OVERRIDE_CONSOLIDATION_MODE) is not None
+            and controller.controller_type == 'jobs'):
+        # if we are in the job controller, we must always be in consolidation
+        # mode.
+        return True
     if os.environ.get(skylet_constants.ENV_VAR_IS_SKYPILOT_SERVER) is not None:
         _validate_consolidation_mode_config(consolidation_mode, pool)
     return consolidation_mode
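Consolidation mode can now be forced on the jobs controller through the `OVERRIDE_CONSOLIDATION_MODE` environment variable, a presence-based flag: only whether it is set matters, not its value. A sketch of that pattern with an illustrative variable name:

```python
# Sketch: presence-based env flag. The variable name here is illustrative,
# not the real constant's value.
import os

OVERRIDE_FLAG = 'SKYPILOT_OVERRIDE_CONSOLIDATION_MODE'  # illustrative


def is_consolidation_forced() -> bool:
    return os.environ.get(OVERRIDE_FLAG) is not None


os.environ[OVERRIDE_FLAG] = ''  # even an empty value counts as "set"
assert is_consolidation_forced()
```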
sky/serve/server/impl.py
CHANGED

@@ -280,8 +280,7 @@ def up(
         ]
         run_script = '\n'.join(env_cmds + [run_script])
         # Dump script for high availability recovery.
-        if controller_utils.high_availability_specified(controller_name):
-            serve_state.set_ha_recovery_script(service_name, run_script)
+        serve_state.set_ha_recovery_script(service_name, run_script)
         backend.run_on_head(controller_handle, run_script)
 
         style = colorama.Style
sky/serve/service.py
CHANGED

@@ -21,7 +21,6 @@ from sky import task as task_lib
 from sky.backends import backend_utils
 from sky.backends import cloud_vm_ray_backend
 from sky.data import data_utils
-from sky.jobs import scheduler as jobs_scheduler
 from sky.serve import constants
 from sky.serve import controller
 from sky.serve import load_balancer

@@ -278,7 +277,6 @@ def _start(service_name: str, tmp_task_yaml: str, job_id: int, entrypoint: str):
             pool=service_spec.pool,
             controller_pid=os.getpid(),
             entrypoint=entrypoint)
-        jobs_scheduler.maybe_schedule_next_jobs()
         # Directly throw an error here. See sky/serve/api.py::up
         # for more details.
         if not success:
sky/server/common.py
CHANGED

@@ -538,12 +538,17 @@ def _start_api_server(deploy: bool = False,
 
     # Check available memory before starting the server.
     avail_mem_size_gb: float = common_utils.get_mem_size_gb()
-    if avail_mem_size_gb <= server_constants.MIN_AVAIL_MEM_GB:
+    # pylint: disable=import-outside-toplevel
+    import sky.jobs.utils as job_utils
+    max_memory = (server_constants.MIN_AVAIL_MEM_GB_CONSOLIDATION_MODE
+                  if job_utils.is_consolidation_mode() else
+                  server_constants.MIN_AVAIL_MEM_GB)
+    if avail_mem_size_gb <= max_memory:
         logger.warning(
             f'{colorama.Fore.YELLOW}Your SkyPilot API server machine only '
             f'has {avail_mem_size_gb:.1f}GB memory available. '
-            f'At least {server_constants.MIN_AVAIL_MEM_GB}GB is recommended '
-            'to support higher load with better performance.'
+            f'At least {max_memory}GB is recommended to support higher '
+            'load with better performance.'
             f'{colorama.Style.RESET_ALL}')
 
     args = [sys.executable, *API_SERVER_CMD.split()]
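The memory floor now depends on consolidation mode (4GB instead of 2GB), with `sky.jobs.utils` imported inside the function to avoid an import cycle. A sketch of the tiered check, assuming an availability probe in the spirit of `common_utils.get_mem_size_gb` (psutil-based here):

```python
# Sketch: warn when available memory falls below a mode-dependent floor.
import psutil

MIN_AVAIL_MEM_GB = 2
MIN_AVAIL_MEM_GB_CONSOLIDATION_MODE = 4


def warn_if_low_memory(consolidation_mode: bool) -> None:
    avail_gb = psutil.virtual_memory().available / (1024 ** 3)
    floor = (MIN_AVAIL_MEM_GB_CONSOLIDATION_MODE
             if consolidation_mode else MIN_AVAIL_MEM_GB)
    if avail_gb <= floor:
        print(f'Only {avail_gb:.1f}GB available; at least {floor}GB is '
              'recommended for better performance.')
```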
sky/server/config.py
CHANGED

@@ -19,8 +19,9 @@ from sky.utils import common_utils
 # TODO(aylei): maintaining these constants is error-prone, we may need to
 # automatically tune parallelism at runtime according to system usage stats
 # in the future.
-_LONG_WORKER_MEM_GB = 0.4
-_SHORT_WORKER_MEM_GB = 0.25
+# TODO(luca): The future is now! ^^^
+LONG_WORKER_MEM_GB = 0.4
+SHORT_WORKER_MEM_GB = 0.25
 # To control the number of long workers.
 _CPU_MULTIPLIER_FOR_LONG_WORKERS = 2
 # Limit the number of long workers of local API server, since local server is

@@ -75,8 +76,8 @@ class ServerConfig:
 
 
 def compute_server_config(deploy: bool,
-                          max_db_connections: Optional[int] = None
-                         ) -> ServerConfig:
+                          max_db_connections: Optional[int] = None,
+                          quiet: bool = False) -> ServerConfig:
     """Compute the server config based on environment.
 
     We have different assumptions for the resources in different deployment

@@ -140,7 +141,12 @@ def compute_server_config(deploy: bool,
         burstable_parallel_for_short = _BURSTABLE_WORKERS_FOR_LOCAL
     # Runs in low resource mode if the available memory is less than
     # server_constants.MIN_AVAIL_MEM_GB.
-    if not deploy and mem_size_gb < server_constants.MIN_AVAIL_MEM_GB:
+    # pylint: disable=import-outside-toplevel
+    import sky.jobs.utils as job_utils
+    max_memory = (server_constants.MIN_AVAIL_MEM_GB_CONSOLIDATION_MODE
+                  if job_utils.is_consolidation_mode() else
+                  server_constants.MIN_AVAIL_MEM_GB)
+    if not deploy and mem_size_gb < max_memory:
         # Permanent worker process may have significant memory consumption
         # (~350MB per worker) after running commands like `sky check`, so we
         # don't start any permanent workers in low resource local mode. This

@@ -151,25 +157,29 @@ def compute_server_config(deploy: bool,
         # permanently because it never exits.
         max_parallel_for_long = 0
         max_parallel_for_short = 0
-        logger.warning(
-            'SkyPilot API server will run in low resource mode because '
-            'the available memory is less than '
-            f'{server_constants.MIN_AVAIL_MEM_GB}GB.')
+        if not quiet:
+            logger.warning(
+                'SkyPilot API server will run in low resource mode because '
+                'the available memory is less than '
+                f'{server_constants.MIN_AVAIL_MEM_GB}GB.')
     elif max_db_connections is not None:
         if max_parallel_all_workers > max_db_connections:
-            logger.warning(
-                f'Max parallel all workers ({max_parallel_all_workers}) '
-                'is greater than max db connections '
-                f'({max_db_connections}). Increase the number of max db '
-                'connections for optimal performance.')
+            if not quiet:
+                logger.warning(
+                    f'Max parallel all workers ({max_parallel_all_workers}) '
+                    'is greater than max db connections '
+                    f'({max_db_connections}). Increase the number of max db '
+                    f'connections to at least {max_parallel_all_workers} for '
+                    'optimal performance.')
     else:
         num_db_connections_per_worker = 1
 
-    logger.info(
-        f'SkyPilot API server will start {num_server_workers} server '
-        f'processes with {max_parallel_for_long} background workers for '
-        f'long requests and will allow at max {max_parallel_for_short} '
-        'short requests in parallel.')
+    if not quiet:
+        logger.info(
+            f'SkyPilot API server will start {num_server_workers} server '
+            f'processes with {max_parallel_for_long} background workers for '
+            f'long requests and will allow at max {max_parallel_for_short} '
+            'short requests in parallel.')
     return ServerConfig(
         num_server_workers=num_server_workers,
         queue_backend=queue_backend,

@@ -190,10 +200,15 @@ def _max_long_worker_parallism(cpu_count: int,
                                local=False) -> int:
     """Max parallelism for long workers."""
     # Reserve min available memory to avoid OOM.
-    available_mem = max(0, mem_size_gb - server_constants.MIN_AVAIL_MEM_GB)
+    # pylint: disable=import-outside-toplevel
+    import sky.jobs.utils as job_utils
+    max_memory = (server_constants.MIN_AVAIL_MEM_GB_CONSOLIDATION_MODE
+                  if job_utils.is_consolidation_mode() else
+                  server_constants.MIN_AVAIL_MEM_GB)
+    available_mem = max(0, mem_size_gb - max_memory)
     cpu_based_max_parallel = cpu_count * _CPU_MULTIPLIER_FOR_LONG_WORKERS
     mem_based_max_parallel = int(available_mem * _MAX_MEM_PERCENT_FOR_BLOCKING /
-                                 _LONG_WORKER_MEM_GB)
+                                 LONG_WORKER_MEM_GB)
     n = max(_MIN_LONG_WORKERS,
             min(cpu_based_max_parallel, mem_based_max_parallel))
     if local:

@@ -205,8 +220,12 @@ def _max_short_worker_parallism(mem_size_gb: float,
                                 long_worker_parallism: int) -> int:
     """Max parallelism for short workers."""
     # Reserve memory for long workers and min available memory.
-    reserved_mem = (server_constants.MIN_AVAIL_MEM_GB +
-                    long_worker_parallism * _LONG_WORKER_MEM_GB)
+    # pylint: disable=import-outside-toplevel
+    import sky.jobs.utils as job_utils
+    max_memory = (server_constants.MIN_AVAIL_MEM_GB_CONSOLIDATION_MODE
+                  if job_utils.is_consolidation_mode() else
+                  server_constants.MIN_AVAIL_MEM_GB)
+    reserved_mem = max_memory + (long_worker_parallism * LONG_WORKER_MEM_GB)
     available_mem = max(0, mem_size_gb - reserved_mem)
-    n = max(_MIN_SHORT_WORKERS, int(available_mem / _SHORT_WORKER_MEM_GB))
+    n = max(_MIN_SHORT_WORKERS, int(available_mem / SHORT_WORKER_MEM_GB))
     return n
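The renamed `LONG_WORKER_MEM_GB`/`SHORT_WORKER_MEM_GB` constants drive the sizing formulas above: long workers are capped by both `cpu_count * 2` and a memory budget divided by 0.4GB, and short workers get whatever memory remains divided by 0.25GB. A worked sketch; the minima and the blocking-memory percentage are placeholders, since this diff does not show their values:

```python
# Sketch of the worker-sizing math. Placeholders are marked; the per-worker
# memory figures and the CPU multiplier come from the diff itself.
MIN_LONG_WORKERS = 1                 # placeholder
MIN_SHORT_WORKERS = 2                # placeholder
MAX_MEM_PERCENT_FOR_BLOCKING = 0.5   # placeholder
CPU_MULTIPLIER_FOR_LONG_WORKERS = 2  # from the diff
LONG_WORKER_MEM_GB = 0.4             # from the diff
SHORT_WORKER_MEM_GB = 0.25           # from the diff
MIN_AVAIL_MEM_GB = 2                 # from sky/server/constants.py


def worker_counts(cpu_count: int, mem_size_gb: float):
    available = max(0, mem_size_gb - MIN_AVAIL_MEM_GB)
    long_n = max(MIN_LONG_WORKERS,
                 min(cpu_count * CPU_MULTIPLIER_FOR_LONG_WORKERS,
                     int(available * MAX_MEM_PERCENT_FOR_BLOCKING /
                         LONG_WORKER_MEM_GB)))
    # Short workers divide whatever memory the long workers leave behind.
    reserved = MIN_AVAIL_MEM_GB + long_n * LONG_WORKER_MEM_GB
    short_n = max(MIN_SHORT_WORKERS,
                  int(max(0, mem_size_gb - reserved) / SHORT_WORKER_MEM_GB))
    return long_n, short_n


# e.g. 8 CPUs and 16GB: 14GB budget -> long capped at min(16, 17) = 16
print(worker_counts(8, 16.0))  # (16, 30)
```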
sky/server/constants.py
CHANGED

@@ -34,6 +34,7 @@ VERSION_HEADER = 'X-SkyPilot-Version'
 REQUEST_NAME_PREFIX = 'sky.'
 # The memory (GB) that SkyPilot tries to not use to prevent OOM.
 MIN_AVAIL_MEM_GB = 2
+MIN_AVAIL_MEM_GB_CONSOLIDATION_MODE = 4
 # Default encoder/decoder handler name.
 DEFAULT_HANDLER_NAME = 'default'
 # The path to the API request database.