skypilot-nightly 1.0.0.dev20250611__py3-none-any.whl → 1.0.0.dev20250613__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sky/__init__.py +2 -2
- sky/adaptors/kubernetes.py +3 -2
- sky/backends/backend_utils.py +8 -2
- sky/benchmark/benchmark_state.py +2 -1
- sky/catalog/data_fetchers/fetch_aws.py +1 -1
- sky/catalog/data_fetchers/fetch_vast.py +1 -1
- sky/check.py +43 -3
- sky/cli.py +1 -1
- sky/client/cli.py +1 -1
- sky/clouds/cloud.py +1 -1
- sky/clouds/gcp.py +1 -1
- sky/clouds/kubernetes.py +9 -3
- sky/clouds/ssh.py +7 -3
- sky/dashboard/out/404.html +1 -1
- sky/dashboard/out/_next/static/chunks/{webpack-208a9812ab4f61c9.js → webpack-5c3e6471d04780c6.js} +1 -1
- sky/dashboard/out/clusters/[cluster]/[job].html +1 -1
- sky/dashboard/out/clusters/[cluster].html +1 -1
- sky/dashboard/out/clusters.html +1 -1
- sky/dashboard/out/config.html +1 -1
- sky/dashboard/out/index.html +1 -1
- sky/dashboard/out/infra/[context].html +1 -1
- sky/dashboard/out/infra.html +1 -1
- sky/dashboard/out/jobs/[job].html +1 -1
- sky/dashboard/out/jobs.html +1 -1
- sky/dashboard/out/users.html +1 -1
- sky/dashboard/out/workspace/new.html +1 -1
- sky/dashboard/out/workspaces/[name].html +1 -1
- sky/dashboard/out/workspaces.html +1 -1
- sky/data/storage.py +2 -2
- sky/global_user_state.py +38 -0
- sky/jobs/server/core.py +1 -68
- sky/jobs/state.py +43 -44
- sky/provision/common.py +1 -1
- sky/provision/gcp/config.py +1 -1
- sky/provision/kubernetes/instance.py +2 -1
- sky/provision/kubernetes/utils.py +60 -13
- sky/resources.py +2 -2
- sky/serve/serve_state.py +81 -15
- sky/server/requests/preconditions.py +1 -1
- sky/server/requests/requests.py +11 -6
- sky/skylet/configs.py +26 -19
- sky/skylet/job_lib.py +3 -5
- sky/task.py +1 -1
- sky/templates/jobs-controller.yaml.j2 +0 -23
- sky/templates/kubernetes-ray.yml.j2 +1 -1
- sky/utils/common_utils.py +6 -0
- sky/utils/context.py +1 -1
- sky/utils/controller_utils.py +10 -0
- sky/utils/infra_utils.py +1 -1
- sky/utils/kubernetes/generate_kubeconfig.sh +1 -1
- {skypilot_nightly-1.0.0.dev20250611.dist-info → skypilot_nightly-1.0.0.dev20250613.dist-info}/METADATA +1 -1
- {skypilot_nightly-1.0.0.dev20250611.dist-info → skypilot_nightly-1.0.0.dev20250613.dist-info}/RECORD +58 -62
- sky/jobs/dashboard/dashboard.py +0 -223
- sky/jobs/dashboard/static/favicon.ico +0 -0
- sky/jobs/dashboard/templates/index.html +0 -831
- sky/jobs/server/dashboard_utils.py +0 -69
- /sky/dashboard/out/_next/static/{zJqasksBQ3HcqMpA2wTUZ → UdgJCk2sZFLJgFJW_qiWG}/_buildManifest.js +0 -0
- /sky/dashboard/out/_next/static/{zJqasksBQ3HcqMpA2wTUZ → UdgJCk2sZFLJgFJW_qiWG}/_ssgManifest.js +0 -0
- {skypilot_nightly-1.0.0.dev20250611.dist-info → skypilot_nightly-1.0.0.dev20250613.dist-info}/WHEEL +0 -0
- {skypilot_nightly-1.0.0.dev20250611.dist-info → skypilot_nightly-1.0.0.dev20250613.dist-info}/entry_points.txt +0 -0
- {skypilot_nightly-1.0.0.dev20250611.dist-info → skypilot_nightly-1.0.0.dev20250613.dist-info}/licenses/LICENSE +0 -0
- {skypilot_nightly-1.0.0.dev20250611.dist-info → skypilot_nightly-1.0.0.dev20250613.dist-info}/top_level.txt +0 -0
sky/global_user_state.py
CHANGED
@@ -45,6 +45,7 @@ if typing.TYPE_CHECKING:
 logger = sky_logging.init_logger(__name__)
 
 _ENABLED_CLOUDS_KEY_PREFIX = 'enabled_clouds_'
+_ALLOWED_CLOUDS_KEY_PREFIX = 'allowed_clouds_'
 
 _SQLALCHEMY_ENGINE: Optional[sqlalchemy.engine.Engine] = None
 _DB_INIT_LOCK = threading.Lock()
@@ -1087,6 +1088,43 @@ def _get_enabled_clouds_key(cloud_capability: 'cloud.CloudCapability',
     return _ENABLED_CLOUDS_KEY_PREFIX + workspace + '_' + cloud_capability.value
 
 
+@_init_db
+def get_allowed_clouds(workspace: str) -> List[str]:
+    assert _SQLALCHEMY_ENGINE is not None
+    with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        row = session.query(config_table).filter_by(
+            key=_get_allowed_clouds_key(workspace)).first()
+        if row:
+            return json.loads(row.value)
+        return []
+
+
+@_init_db
+def set_allowed_clouds(allowed_clouds: List[str], workspace: str) -> None:
+    assert _SQLALCHEMY_ENGINE is not None
+    with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        if (_SQLALCHEMY_ENGINE.dialect.name ==
+                db_utils.SQLAlchemyDialect.SQLITE.value):
+            insert_func = sqlite.insert
+        elif (_SQLALCHEMY_ENGINE.dialect.name ==
+              db_utils.SQLAlchemyDialect.POSTGRESQL.value):
+            insert_func = postgresql.insert
+        else:
+            raise ValueError('Unsupported database dialect')
+        insert_stmnt = insert_func(config_table).values(
+            key=_get_allowed_clouds_key(workspace),
+            value=json.dumps(allowed_clouds))
+        do_update_stmt = insert_stmnt.on_conflict_do_update(
+            index_elements=[config_table.c.key],
+            set_={config_table.c.value: json.dumps(allowed_clouds)})
+        session.execute(do_update_stmt)
+        session.commit()
+
+
+def _get_allowed_clouds_key(workspace: str) -> str:
+    return _ALLOWED_CLOUDS_KEY_PREFIX + workspace
+
+
 @_init_db
 def add_or_update_storage(storage_name: str,
                           storage_handle: 'Storage.StorageMetadata',
sky/jobs/server/core.py
CHANGED
@@ -1,9 +1,6 @@
 """SDK functions for managed jobs."""
 import os
-import signal
-import subprocess
 import tempfile
-import time
 import typing
 from typing import Any, Dict, List, Optional, Tuple, Union
 import uuid
@@ -213,8 +210,6 @@ def launch(
         'remote_env_file_path': remote_env_file_path,
         'modified_catalogs':
             service_catalog_common.get_modified_catalog_file_mounts(),
-        'dashboard_setup_cmd': managed_job_constants.DASHBOARD_SETUP_CMD,
-        'dashboard_user_id': common.SERVER_ID,
         'priority': priority,
         **controller_utils.shared_controller_vars_to_fill(
             controller,
@@ -368,20 +363,7 @@ def _maybe_restart_controller(
             skylet_constants.SKYPILOT_DEFAULT_WORKSPACE):
         handle = core.start(
             cluster_name=jobs_controller_type.value.cluster_name)
-
-    # We should not directly use execution.launch() and have the dashboard cmd
-    # in the task setup because since we are using detached_setup, it will
-    # become a job on controller which messes up the job IDs (we assume the
-    # job ID in controller's job queue is consistent with managed job IDs).
-    with rich_utils.safe_status(
-            ux_utils.spinner_message('Starting dashboard...')):
-        runner = handle.get_command_runners()[0]
-        runner.run(
-            f'export '
-            f'{skylet_constants.USER_ID_ENV_VAR}={common.SERVER_ID!r}; '
-            f'{managed_job_constants.DASHBOARD_SETUP_CMD}',
-            stream_logs=True,
-        )
+
     controller_status = status_lib.ClusterStatus.UP
     rich_utils.force_update_status(ux_utils.spinner_message(spinner_message))
 
@@ -598,55 +580,6 @@ def tail_logs(name: Optional[str],
                       tail=tail)
 
 
-def start_dashboard_forwarding(refresh: bool = False) -> Tuple[int, int]:
-    """Opens a dashboard for managed jobs (needs controller to be UP)."""
-    # TODO(SKY-1212): ideally, the controller/dashboard server should expose the
-    # API perhaps via REST. Then here we would (1) not have to use SSH to try to
-    # see if the controller is UP first, which is slow; (2) not have to run SSH
-    # port forwarding first (we'd just launch a local dashboard which would make
-    # REST API calls to the controller dashboard server).
-    logger.info('Starting dashboard')
-    hint = ('Dashboard is not available if jobs controller is not up. Run '
-            'a managed job first or run: sky jobs queue --refresh')
-    handle = _maybe_restart_controller(
-        refresh=refresh,
-        stopped_message=hint,
-        spinner_message='Checking jobs controller')
-
-    # SSH forward a free local port to remote's dashboard port.
-    remote_port = skylet_constants.SPOT_DASHBOARD_REMOTE_PORT
-    free_port = common_utils.find_free_port(remote_port)
-    runner = handle.get_command_runners()[0]
-    port_forward_command = ' '.join(
-        runner.port_forward_command(port_forward=[(free_port, remote_port)],
-                                    connect_timeout=1))
-    port_forward_command = (
-        f'{port_forward_command} '
-        f'> ~/sky_logs/api_server/dashboard-{common_utils.get_user_hash()}.log '
-        '2>&1')
-    logger.info(f'Forwarding port: {colorama.Style.DIM}{port_forward_command}'
-                f'{colorama.Style.RESET_ALL}')
-
-    ssh_process = subprocess.Popen(port_forward_command,
-                                   shell=True,
-                                   start_new_session=True)
-    time.sleep(3)  # Added delay for ssh_command to initialize.
-    logger.info(f'{colorama.Fore.GREEN}Dashboard is now available at: '
-                f'http://127.0.0.1:{free_port}{colorama.Style.RESET_ALL}')
-
-    return free_port, ssh_process.pid
-
-
-def stop_dashboard_forwarding(pid: int) -> None:
-    # Exit the ssh command when the context manager is closed.
-    try:
-        os.killpg(os.getpgid(pid), signal.SIGTERM)
-    except ProcessLookupError:
-        # This happens if jobs controller is auto-stopped.
-        pass
-    logger.info('Forwarding port closed. Exiting.')
-
-
 @usage_lib.entrypoint
 def download_logs(
     name: Optional[str],
sky/jobs/state.py
CHANGED
@@ -161,8 +161,8 @@ def create_table(cursor, conn):
     conn.commit()
 
 
-# Module-level connection/cursor; thread-safe as the
-#
+# Module-level connection/cursor; thread-safe as the db is initialized once
+# across all threads.
 def _get_db_path() -> str:
     """Workaround to collapse multi-step Path ops for type checker.
     Ensures _DB_PATH is str, avoiding Union[Path, str] inference.
@@ -173,8 +173,7 @@ def _get_db_path() -> str:
     return str(path)
 
 
-_DB_PATH =
-_db_initialized = False
+_DB_PATH = None
 _db_init_lock = threading.Lock()
 
 
@@ -183,13 +182,13 @@ def _init_db(func):
 
     @functools.wraps(func)
     def wrapper(*args, **kwargs):
-        global
-        if
+        global _DB_PATH
+        if _DB_PATH is not None:
            return func(*args, **kwargs)
        with _db_init_lock:
-            if
+            if _DB_PATH is None:
+                _DB_PATH = _get_db_path()
                db_utils.SQLiteConn(_DB_PATH, create_table)
-            _db_initialized = True
        return func(*args, **kwargs)
 
    return wrapper
@@ -442,7 +441,7 @@ class ManagedJobScheduleState(enum.Enum):
 # === Status transition functions ===
 @_init_db
 def set_job_info(job_id: int, name: str, workspace: str, entrypoint: str):
-    assert
+    assert _DB_PATH is not None
     with db_utils.safe_cursor(_DB_PATH) as cursor:
         cursor.execute(
             """\
@@ -456,7 +455,7 @@ def set_job_info(job_id: int, name: str, workspace: str, entrypoint: str):
 @_init_db
 def set_pending(job_id: int, task_id: int, task_name: str, resources_str: str):
     """Set the task to pending state."""
-    assert
+    assert _DB_PATH is not None
     with db_utils.safe_cursor(_DB_PATH) as cursor:
         cursor.execute(
             """\
@@ -484,7 +483,7 @@ def set_starting(job_id: int, task_id: int, run_timestamp: str,
         specs: The specs of the managed task.
         callback_func: The callback function.
     """
-    assert
+    assert _DB_PATH is not None
     # Use the timestamp in the `run_timestamp` ('sky-2022-10...'), to make
     # the log directory and submission time align with each other, so as to
     # make it easier to find them based on one of the values.
@@ -524,7 +523,7 @@ def set_backoff_pending(job_id: int, task_id: int):
     This should only be used to transition from STARTING or RECOVERING back to
     PENDING.
     """
-    assert
+    assert _DB_PATH is not None
     with db_utils.safe_cursor(_DB_PATH) as cursor:
         cursor.execute(
             """\
@@ -552,7 +551,7 @@ def set_restarting(job_id: int, task_id: int, recovering: bool):
     after using set_backoff_pending to transition back to PENDING during
     launch retry backoff.
     """
-    assert
+    assert _DB_PATH is not None
     target_status = ManagedJobStatus.STARTING.value
     if recovering:
         target_status = ManagedJobStatus.RECOVERING.value
@@ -578,7 +577,7 @@ def set_restarting(job_id: int, task_id: int, recovering: bool):
 def set_started(job_id: int, task_id: int, start_time: float,
                 callback_func: CallbackType):
     """Set the task to started state."""
-    assert
+    assert _DB_PATH is not None
     logger.info('Job started.')
     with db_utils.safe_cursor(_DB_PATH) as cursor:
         cursor.execute(
@@ -610,7 +609,7 @@ def set_started(job_id: int, task_id: int, start_time: float,
 @_init_db
 def set_recovering(job_id: int, task_id: int, callback_func: CallbackType):
     """Set the task to recovering state, and update the job duration."""
-    assert
+    assert _DB_PATH is not None
     logger.info('=== Recovering... ===')
     with db_utils.safe_cursor(_DB_PATH) as cursor:
         cursor.execute(
@@ -634,7 +633,7 @@ def set_recovering(job_id: int, task_id: int, callback_func: CallbackType):
 def set_recovered(job_id: int, task_id: int, recovered_time: float,
                   callback_func: CallbackType):
     """Set the task to recovered."""
-    assert
+    assert _DB_PATH is not None
     with db_utils.safe_cursor(_DB_PATH) as cursor:
         cursor.execute(
             """\
@@ -658,7 +657,7 @@ def set_recovered(job_id: int, task_id: int, recovered_time: float,
 def set_succeeded(job_id: int, task_id: int, end_time: float,
                   callback_func: CallbackType):
     """Set the task to succeeded, if it is in a non-terminal state."""
-    assert
+    assert _DB_PATH is not None
     with db_utils.safe_cursor(_DB_PATH) as cursor:
         cursor.execute(
             """\
@@ -703,7 +702,7 @@ def set_failed(
         override_terminal: If True, override the current status even if end_at
             is already set.
     """
-    assert
+    assert _DB_PATH is not None
     assert failure_type.is_failed(), failure_type
     end_time = time.time() if end_time is None else end_time
 
@@ -761,7 +760,7 @@ def set_cancelling(job_id: int, callback_func: CallbackType):
     task_id is not needed, because we expect the job should be cancelled
     as a whole, and we should not cancel a single task.
     """
-    assert
+    assert _DB_PATH is not None
     with db_utils.safe_cursor(_DB_PATH) as cursor:
         rows = cursor.execute(
             """\
@@ -783,7 +782,7 @@ def set_cancelled(job_id: int, callback_func: CallbackType):
 
     The set_cancelling should be called before this function.
     """
-    assert
+    assert _DB_PATH is not None
     with db_utils.safe_cursor(_DB_PATH) as cursor:
         rows = cursor.execute(
             """\
@@ -804,7 +803,7 @@ def set_cancelled(job_id: int, callback_func: CallbackType):
 def set_local_log_file(job_id: int, task_id: Optional[int],
                        local_log_file: str):
     """Set the local log file for a job."""
-    assert
+    assert _DB_PATH is not None
     filter_str = 'spot_job_id=(?)'
     filter_args = [local_log_file, job_id]
 
@@ -822,7 +821,7 @@ def set_local_log_file(job_id: int, task_id: Optional[int],
 def get_nonterminal_job_ids_by_name(name: Optional[str],
                                     all_users: bool = False) -> List[int]:
     """Get non-terminal job ids by name."""
-    assert
+    assert _DB_PATH is not None
     statuses = ', '.join(['?'] * len(ManagedJobStatus.terminal_statuses()))
     field_values = [
         status.value for status in ManagedJobStatus.terminal_statuses()
@@ -866,7 +865,7 @@ def get_schedule_live_jobs(job_id: Optional[int]) -> List[Dict[str, Any]]:
     exception: the job may have just transitioned from WAITING to LAUNCHING, but
     the controller process has not yet started.
     """
-    assert
+    assert _DB_PATH is not None
     job_filter = '' if job_id is None else 'AND spot_job_id=(?)'
     job_value = (job_id,) if job_id is not None else ()
 
@@ -909,7 +908,7 @@ def get_jobs_to_check_status(job_id: Optional[int] = None) -> List[int]:
     - Jobs have schedule_state DONE but are in a non-terminal status
     - Legacy jobs (that is, no schedule state) that are in non-terminal status
     """
-    assert
+    assert _DB_PATH is not None
     job_filter = '' if job_id is None else 'AND spot.spot_job_id=(?)'
     job_value = () if job_id is None else (job_id,)
 
@@ -958,7 +957,7 @@ def get_jobs_to_check_status(job_id: Optional[int] = None) -> List[int]:
 @_init_db
 def get_all_job_ids_by_name(name: Optional[str]) -> List[int]:
     """Get all job ids by name."""
-    assert
+    assert _DB_PATH is not None
     name_filter = ''
     field_values = []
     if name is not None:
@@ -987,7 +986,7 @@ def get_all_job_ids_by_name(name: Optional[str]) -> List[int]:
 @_init_db
 def _get_all_task_ids_statuses(
         job_id: int) -> List[Tuple[int, ManagedJobStatus]]:
-    assert
+    assert _DB_PATH is not None
     with db_utils.safe_cursor(_DB_PATH) as cursor:
         id_statuses = cursor.execute(
             """\
@@ -1035,7 +1034,7 @@ def get_failure_reason(job_id: int) -> Optional[str]:
 
     If the job has multiple tasks, we return the first failure reason.
     """
-    assert
+    assert _DB_PATH is not None
     with db_utils.safe_cursor(_DB_PATH) as cursor:
         reason = cursor.execute(
             """\
@@ -1051,7 +1050,7 @@ def get_failure_reason(job_id: int) -> Optional[str]:
 @_init_db
 def get_managed_jobs(job_id: Optional[int] = None) -> List[Dict[str, Any]]:
     """Get managed jobs from the database."""
-    assert
+    assert _DB_PATH is not None
     job_filter = '' if job_id is None else f'WHERE spot.spot_job_id={job_id}'
 
     # Join spot and job_info tables to get the job name for each task.
@@ -1097,7 +1096,7 @@ def get_managed_jobs(job_id: Optional[int] = None) -> List[Dict[str, Any]]:
 @_init_db
 def get_task_name(job_id: int, task_id: int) -> str:
     """Get the task name of a job."""
-    assert
+    assert _DB_PATH is not None
     with db_utils.safe_cursor(_DB_PATH) as cursor:
         task_name = cursor.execute(
             """\
@@ -1110,7 +1109,7 @@ def get_task_name(job_id: int, task_id: int) -> str:
 @_init_db
 def get_latest_job_id() -> Optional[int]:
     """Get the latest job id."""
-    assert
+    assert _DB_PATH is not None
     with db_utils.safe_cursor(_DB_PATH) as cursor:
         rows = cursor.execute("""\
             SELECT spot_job_id FROM spot
@@ -1123,7 +1122,7 @@ def get_latest_job_id() -> Optional[int]:
 
 @_init_db
 def get_task_specs(job_id: int, task_id: int) -> Dict[str, Any]:
-    assert
+    assert _DB_PATH is not None
     with db_utils.safe_cursor(_DB_PATH) as cursor:
         task_specs = cursor.execute(
             """\
@@ -1136,7 +1135,7 @@ def get_task_specs(job_id: int, task_id: int) -> Dict[str, Any]:
 @_init_db
 def get_local_log_file(job_id: int, task_id: Optional[int]) -> Optional[str]:
     """Get the local log directory for a job."""
-    assert
+    assert _DB_PATH is not None
     filter_str = 'spot_job_id=(?)'
     filter_args = [job_id]
     if task_id is not None:
@@ -1159,7 +1158,7 @@ def scheduler_set_waiting(job_id: int, dag_yaml_path: str,
                           original_user_yaml_path: str, env_file_path: str,
                           user_hash: str, priority: int) -> None:
     """Do not call without holding the scheduler lock."""
-    assert
+    assert _DB_PATH is not None
     with db_utils.safe_cursor(_DB_PATH) as cursor:
         updated_count = cursor.execute(
             'UPDATE job_info SET '
@@ -1177,7 +1176,7 @@ def scheduler_set_waiting(job_id: int, dag_yaml_path: str,
 def scheduler_set_launching(job_id: int,
                             current_state: ManagedJobScheduleState) -> None:
     """Do not call without holding the scheduler lock."""
-    assert
+    assert _DB_PATH is not None
     with db_utils.safe_cursor(_DB_PATH) as cursor:
         updated_count = cursor.execute(
             'UPDATE job_info SET '
@@ -1191,7 +1190,7 @@ def scheduler_set_launching(job_id: int,
 @_init_db
 def scheduler_set_alive(job_id: int) -> None:
     """Do not call without holding the scheduler lock."""
-    assert
+    assert _DB_PATH is not None
     with db_utils.safe_cursor(_DB_PATH) as cursor:
         updated_count = cursor.execute(
             'UPDATE job_info SET '
@@ -1205,7 +1204,7 @@ def scheduler_set_alive(job_id: int) -> None:
 @_init_db
 def scheduler_set_alive_backoff(job_id: int) -> None:
     """Do not call without holding the scheduler lock."""
-    assert
+    assert _DB_PATH is not None
     with db_utils.safe_cursor(_DB_PATH) as cursor:
         updated_count = cursor.execute(
             'UPDATE job_info SET '
@@ -1219,7 +1218,7 @@ def scheduler_set_alive_backoff(job_id: int) -> None:
 @_init_db
 def scheduler_set_alive_waiting(job_id: int) -> None:
     """Do not call without holding the scheduler lock."""
-    assert
+    assert _DB_PATH is not None
     with db_utils.safe_cursor(_DB_PATH) as cursor:
         updated_count = cursor.execute(
             'UPDATE job_info SET '
@@ -1234,7 +1233,7 @@ def scheduler_set_alive_waiting(job_id: int) -> None:
 @_init_db
 def scheduler_set_done(job_id: int, idempotent: bool = False) -> None:
     """Do not call without holding the scheduler lock."""
-    assert
+    assert _DB_PATH is not None
     with db_utils.safe_cursor(_DB_PATH) as cursor:
         updated_count = cursor.execute(
             'UPDATE job_info SET '
@@ -1248,7 +1247,7 @@ def scheduler_set_done(job_id: int, idempotent: bool = False) -> None:
 
 @_init_db
 def set_job_controller_pid(job_id: int, pid: int):
-    assert
+    assert _DB_PATH is not None
     with db_utils.safe_cursor(_DB_PATH) as cursor:
         updated_count = cursor.execute(
             'UPDATE job_info SET '
@@ -1259,7 +1258,7 @@ def set_job_controller_pid(job_id: int, pid: int):
 
 @_init_db
 def get_job_schedule_state(job_id: int) -> ManagedJobScheduleState:
-    assert
+    assert _DB_PATH is not None
     with db_utils.safe_cursor(_DB_PATH) as cursor:
         state = cursor.execute(
             'SELECT schedule_state FROM job_info WHERE spot_job_id = (?)',
@@ -1269,7 +1268,7 @@ def get_job_schedule_state(job_id: int) -> ManagedJobScheduleState:
 
 @_init_db
 def get_num_launching_jobs() -> int:
-    assert
+    assert _DB_PATH is not None
     with db_utils.safe_cursor(_DB_PATH) as cursor:
         return cursor.execute(
             'SELECT COUNT(*) '
@@ -1280,7 +1279,7 @@ def get_num_launching_jobs() -> int:
 
 @_init_db
 def get_num_alive_jobs() -> int:
-    assert
+    assert _DB_PATH is not None
     with db_utils.safe_cursor(_DB_PATH) as cursor:
         return cursor.execute(
             'SELECT COUNT(*) '
@@ -1303,7 +1302,7 @@ def get_waiting_job() -> Optional[Dict[str, Any]]:
     Backwards compatibility note: jobs submitted before #4485 will have no
     schedule_state and will be ignored by this SQL query.
    """
-    assert
+    assert _DB_PATH is not None
     with db_utils.safe_cursor(_DB_PATH) as cursor:
         # Get the highest-priority WAITING or ALIVE_WAITING job whose priority
         # is greater than or equal to the highest priority LAUNCHING or
@@ -1338,7 +1337,7 @@ def get_waiting_job() -> Optional[Dict[str, Any]]:
 @_init_db
 def get_workspace(job_id: int) -> str:
     """Get the workspace of a job."""
-    assert
+    assert _DB_PATH is not None
     with db_utils.safe_cursor(_DB_PATH) as cursor:
         workspace = cursor.execute(
             'SELECT workspace FROM job_info WHERE spot_job_id = (?)',
sky/provision/common.py
CHANGED
sky/provision/gcp/config.py
CHANGED
@@ -274,7 +274,7 @@ def _is_permission_satisfied(service_account, crm, iam, required_permissions,
     # For example, `roles/iam.serviceAccountUser` can be granted at the
     # skypilot-v1 service account level, which can be checked with
     # service_account_policy = iam.projects().serviceAccounts().getIamPolicy(
-    #     resource=f'projects/{project_id}/
+    #     resource=f'projects/{project_id}/serviceAccounts/{email}').execute()
     # We now skip the check for `iam.serviceAccounts.actAs` permission for
     # simplicity as it can be granted at the service account level.
     def check_permissions(policy, required_permissions):
sky/provision/kubernetes/instance.py
CHANGED
@@ -1277,7 +1277,8 @@ def query_instances(
     except kubernetes.max_retry_error():
         with ux_utils.print_exception_no_traceback():
             if is_ssh:
-                node_pool =
+                node_pool = common_utils.removeprefix(context,
+                                                      'ssh-') if context else ''
                 msg = (
                     f'Cannot connect to SSH Node Pool {node_pool}. '
                     'Please check if the SSH Node Pool is up and accessible. '
sky/provision/kubernetes/utils.py
CHANGED
@@ -133,6 +133,30 @@ DEFAULT_MAX_RETRIES = 3
 DEFAULT_RETRY_INTERVAL_SECONDS = 1
 
 
+def normalize_tpu_accelerator_name(accelerator: str) -> Tuple[str, int]:
+    """Normalize TPU names to the k8s-compatible name and extract count."""
+    # Examples:
+    # 'tpu-v6e-8' -> ('tpu-v6e-slice', 8)
+    # 'tpu-v5litepod-4' -> ('tpu-v5-lite-podslice', 4)
+
+    gcp_to_k8s_patterns = [
+        (r'^tpu-v6e-(\d+)$', 'tpu-v6e-slice'),
+        (r'^tpu-v5p-(\d+)$', 'tpu-v5p-slice'),
+        (r'^tpu-v5litepod-(\d+)$', 'tpu-v5-lite-podslice'),
+        (r'^tpu-v5lite-(\d+)$', 'tpu-v5-lite-device'),
+        (r'^tpu-v4-(\d+)$', 'tpu-v4-podslice'),
+    ]
+
+    for pattern, replacement in gcp_to_k8s_patterns:
+        match = re.match(pattern, accelerator)
+        if match:
+            count = int(match.group(1))
+            return replacement, count
+
+    # Default fallback
+    return accelerator, 1
+
+
 def _retry_on_error(max_retries=DEFAULT_MAX_RETRIES,
                     retry_interval=DEFAULT_RETRY_INTERVAL_SECONDS,
                     resource_type: Optional[str] = None):
@@ -427,6 +451,7 @@ class GKELabelFormatter(GPULabelFormatter):
 
         e.g. tpu-v5-lite-podslice:8 -> '2x4'
         """
+        acc_type, acc_count = normalize_tpu_accelerator_name(acc_type)
         count_to_topology = cls.GKE_TPU_TOPOLOGIES.get(acc_type,
                                                        {}).get(acc_count, None)
         if count_to_topology is None:
@@ -461,6 +486,14 @@ class GKELabelFormatter(GPULabelFormatter):
                 raise ValueError(
                     f'Invalid accelerator name in GKE cluster: {value}')
 
+    @classmethod
+    def validate_label_value(cls, value: str) -> Tuple[bool, str]:
+        try:
+            _ = cls.get_accelerator_from_label_value(value)
+            return True, ''
+        except ValueError as e:
+            return False, str(e)
+
 
 class GFDLabelFormatter(GPULabelFormatter):
     """GPU Feature Discovery label formatter
@@ -565,17 +598,29 @@ def detect_gpu_label_formatter(
         for label, value in node.metadata.labels.items():
             node_labels[node.metadata.name].append((label, value))
 
-    label_formatter = None
-
     # Check if the node labels contain any of the GPU label prefixes
     for lf in LABEL_FORMATTER_REGISTRY:
+        skip = False
         for _, label_list in node_labels.items():
-            for label,
+            for label, value in label_list:
                 if lf.match_label_key(label):
-
-
+                    valid, reason = lf.validate_label_value(value)
+                    if valid:
+                        return lf(), node_labels
+                    else:
+                        logger.warning(f'GPU label {label} matched for label '
+                                       f'formatter {lf.__class__.__name__}, '
+                                       f'but has invalid value {value}. '
+                                       f'Reason: {reason}. '
+                                       'Skipping...')
+                        skip = True
+                        break
+            if skip:
+                break
+        if skip:
+            continue
 
-    return
+    return None, node_labels
 
 
 class Autoscaler:
@@ -754,6 +799,8 @@ class GKEAutoscaler(Autoscaler):
                 f'checking {node_pool_name} for TPU {requested_acc_type}:'
                 f'{requested_acc_count}')
             if 'resourceLabels' in node_config:
+                requested_acc_type, requested_acc_count = normalize_tpu_accelerator_name(
+                    requested_acc_type)
                 accelerator_exists = cls._node_pool_has_tpu_capacity(
                     node_config['resourceLabels'], machine_type,
                     requested_acc_type, requested_acc_count)
@@ -993,7 +1040,7 @@ def check_instance_fits(context: Optional[str],
             'Maximum resources found on a single node: '
             f'{max_cpu} CPUs, {common_utils.format_float(max_mem)}G Memory')
 
-    def check_tpu_fits(
+    def check_tpu_fits(acc_type: str, acc_count: int,
                        node_list: List[Any]) -> Tuple[bool, Optional[str]]:
         """Checks if the instance fits on the cluster based on requested TPU.
 
@@ -1003,8 +1050,6 @@ def check_instance_fits(context: Optional[str],
         node (node_tpu_chip_count) and the total TPU chips across the entire
         podslice (topology_chip_count) are correctly handled.
         """
-        acc_type = candidate_instance_type.accelerator_type
-        acc_count = candidate_instance_type.accelerator_count
         tpu_list_in_cluster = []
         for node in node_list:
             if acc_type == node.metadata.labels[
@@ -1055,7 +1100,8 @@ def check_instance_fits(context: Optional[str],
     if is_tpu_on_gke(acc_type):
         # If requested accelerator is a TPU type, check if the cluster
        # has sufficient TPU resource to meet the requirement.
-
+        acc_type, acc_count = normalize_tpu_accelerator_name(acc_type)
+        fits, reason = check_tpu_fits(acc_type, acc_count, gpu_nodes)
         if reason is not None:
             return fits, reason
     else:
@@ -1141,8 +1187,8 @@ def get_accelerator_label_key_values(
 
     is_ssh_node_pool = context.startswith('ssh-') if context else False
     cloud_name = 'SSH Node Pool' if is_ssh_node_pool else 'Kubernetes cluster'
-    context_display_name =
-    context and is_ssh_node_pool) else context
+    context_display_name = common_utils.removeprefix(
+        context, 'ssh-') if (context and is_ssh_node_pool) else context
 
     autoscaler_type = get_autoscaler_type()
     if autoscaler_type is not None:
@@ -2911,7 +2957,8 @@ def get_skypilot_pods(context: Optional[str] = None) -> List[Any]:
 
 def is_tpu_on_gke(accelerator: str) -> bool:
     """Determines if the given accelerator is a TPU supported on GKE."""
-
+    normalized, _ = normalize_tpu_accelerator_name(accelerator)
+    return normalized in GKE_TPU_ACCELERATOR_TO_GENERATION
 
 
 def get_node_accelerator_count(attribute_dict: dict) -> int:
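The new normalize_tpu_accelerator_name helper maps GCP-style TPU accelerator names onto the GKE node-label form and splits off the chip count, so the rest of the Kubernetes utilities (topology lookup, autoscaler capacity checks, is_tpu_on_gke) can keep working with label-style names. Below is a standalone copy of that logic with a few usage checks, for illustration; it mirrors the regex table shown in the diff rather than importing from SkyPilot.

# Standalone copy of the normalization logic from the diff, for illustration.
import re
from typing import Tuple

_GCP_TO_K8S = [
    (r'^tpu-v6e-(\d+)$', 'tpu-v6e-slice'),
    (r'^tpu-v5p-(\d+)$', 'tpu-v5p-slice'),
    (r'^tpu-v5litepod-(\d+)$', 'tpu-v5-lite-podslice'),
    (r'^tpu-v5lite-(\d+)$', 'tpu-v5-lite-device'),
    (r'^tpu-v4-(\d+)$', 'tpu-v4-podslice'),
]

def normalize_tpu_accelerator_name(accelerator: str) -> Tuple[str, int]:
    for pattern, replacement in _GCP_TO_K8S:
        match = re.match(pattern, accelerator)
        if match:
            return replacement, int(match.group(1))
    return accelerator, 1  # Unrecognized names pass through with count 1.

assert normalize_tpu_accelerator_name('tpu-v6e-8') == ('tpu-v6e-slice', 8)
assert normalize_tpu_accelerator_name('tpu-v5litepod-4') == ('tpu-v5-lite-podslice', 4)
assert normalize_tpu_accelerator_name('H100') == ('H100', 1)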
|