skypilot-nightly 1.0.0.dev20250611__py3-none-any.whl → 1.0.0.dev20250613__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. sky/__init__.py +2 -2
  2. sky/adaptors/kubernetes.py +3 -2
  3. sky/backends/backend_utils.py +8 -2
  4. sky/benchmark/benchmark_state.py +2 -1
  5. sky/catalog/data_fetchers/fetch_aws.py +1 -1
  6. sky/catalog/data_fetchers/fetch_vast.py +1 -1
  7. sky/check.py +43 -3
  8. sky/cli.py +1 -1
  9. sky/client/cli.py +1 -1
  10. sky/clouds/cloud.py +1 -1
  11. sky/clouds/gcp.py +1 -1
  12. sky/clouds/kubernetes.py +9 -3
  13. sky/clouds/ssh.py +7 -3
  14. sky/dashboard/out/404.html +1 -1
  15. sky/dashboard/out/_next/static/chunks/{webpack-208a9812ab4f61c9.js → webpack-5c3e6471d04780c6.js} +1 -1
  16. sky/dashboard/out/clusters/[cluster]/[job].html +1 -1
  17. sky/dashboard/out/clusters/[cluster].html +1 -1
  18. sky/dashboard/out/clusters.html +1 -1
  19. sky/dashboard/out/config.html +1 -1
  20. sky/dashboard/out/index.html +1 -1
  21. sky/dashboard/out/infra/[context].html +1 -1
  22. sky/dashboard/out/infra.html +1 -1
  23. sky/dashboard/out/jobs/[job].html +1 -1
  24. sky/dashboard/out/jobs.html +1 -1
  25. sky/dashboard/out/users.html +1 -1
  26. sky/dashboard/out/workspace/new.html +1 -1
  27. sky/dashboard/out/workspaces/[name].html +1 -1
  28. sky/dashboard/out/workspaces.html +1 -1
  29. sky/data/storage.py +2 -2
  30. sky/global_user_state.py +38 -0
  31. sky/jobs/server/core.py +1 -68
  32. sky/jobs/state.py +43 -44
  33. sky/provision/common.py +1 -1
  34. sky/provision/gcp/config.py +1 -1
  35. sky/provision/kubernetes/instance.py +2 -1
  36. sky/provision/kubernetes/utils.py +60 -13
  37. sky/resources.py +2 -2
  38. sky/serve/serve_state.py +81 -15
  39. sky/server/requests/preconditions.py +1 -1
  40. sky/server/requests/requests.py +11 -6
  41. sky/skylet/configs.py +26 -19
  42. sky/skylet/job_lib.py +3 -5
  43. sky/task.py +1 -1
  44. sky/templates/jobs-controller.yaml.j2 +0 -23
  45. sky/templates/kubernetes-ray.yml.j2 +1 -1
  46. sky/utils/common_utils.py +6 -0
  47. sky/utils/context.py +1 -1
  48. sky/utils/controller_utils.py +10 -0
  49. sky/utils/infra_utils.py +1 -1
  50. sky/utils/kubernetes/generate_kubeconfig.sh +1 -1
  51. {skypilot_nightly-1.0.0.dev20250611.dist-info → skypilot_nightly-1.0.0.dev20250613.dist-info}/METADATA +1 -1
  52. {skypilot_nightly-1.0.0.dev20250611.dist-info → skypilot_nightly-1.0.0.dev20250613.dist-info}/RECORD +58 -62
  53. sky/jobs/dashboard/dashboard.py +0 -223
  54. sky/jobs/dashboard/static/favicon.ico +0 -0
  55. sky/jobs/dashboard/templates/index.html +0 -831
  56. sky/jobs/server/dashboard_utils.py +0 -69
  57. /sky/dashboard/out/_next/static/{zJqasksBQ3HcqMpA2wTUZ → UdgJCk2sZFLJgFJW_qiWG}/_buildManifest.js +0 -0
  58. /sky/dashboard/out/_next/static/{zJqasksBQ3HcqMpA2wTUZ → UdgJCk2sZFLJgFJW_qiWG}/_ssgManifest.js +0 -0
  59. {skypilot_nightly-1.0.0.dev20250611.dist-info → skypilot_nightly-1.0.0.dev20250613.dist-info}/WHEEL +0 -0
  60. {skypilot_nightly-1.0.0.dev20250611.dist-info → skypilot_nightly-1.0.0.dev20250613.dist-info}/entry_points.txt +0 -0
  61. {skypilot_nightly-1.0.0.dev20250611.dist-info → skypilot_nightly-1.0.0.dev20250613.dist-info}/licenses/LICENSE +0 -0
  62. {skypilot_nightly-1.0.0.dev20250611.dist-info → skypilot_nightly-1.0.0.dev20250613.dist-info}/top_level.txt +0 -0
sky/global_user_state.py CHANGED
@@ -45,6 +45,7 @@ if typing.TYPE_CHECKING:
 logger = sky_logging.init_logger(__name__)
 
 _ENABLED_CLOUDS_KEY_PREFIX = 'enabled_clouds_'
+_ALLOWED_CLOUDS_KEY_PREFIX = 'allowed_clouds_'
 
 _SQLALCHEMY_ENGINE: Optional[sqlalchemy.engine.Engine] = None
 _DB_INIT_LOCK = threading.Lock()
@@ -1087,6 +1088,43 @@ def _get_enabled_clouds_key(cloud_capability: 'cloud.CloudCapability',
     return _ENABLED_CLOUDS_KEY_PREFIX + workspace + '_' + cloud_capability.value
 
 
+@_init_db
+def get_allowed_clouds(workspace: str) -> List[str]:
+    assert _SQLALCHEMY_ENGINE is not None
+    with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        row = session.query(config_table).filter_by(
+            key=_get_allowed_clouds_key(workspace)).first()
+        if row:
+            return json.loads(row.value)
+    return []
+
+
+@_init_db
+def set_allowed_clouds(allowed_clouds: List[str], workspace: str) -> None:
+    assert _SQLALCHEMY_ENGINE is not None
+    with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        if (_SQLALCHEMY_ENGINE.dialect.name ==
+                db_utils.SQLAlchemyDialect.SQLITE.value):
+            insert_func = sqlite.insert
+        elif (_SQLALCHEMY_ENGINE.dialect.name ==
+              db_utils.SQLAlchemyDialect.POSTGRESQL.value):
+            insert_func = postgresql.insert
+        else:
+            raise ValueError('Unsupported database dialect')
+        insert_stmnt = insert_func(config_table).values(
+            key=_get_allowed_clouds_key(workspace),
+            value=json.dumps(allowed_clouds))
+        do_update_stmt = insert_stmnt.on_conflict_do_update(
+            index_elements=[config_table.c.key],
+            set_={config_table.c.value: json.dumps(allowed_clouds)})
+        session.execute(do_update_stmt)
+        session.commit()
+
+
+def _get_allowed_clouds_key(workspace: str) -> str:
+    return _ALLOWED_CLOUDS_KEY_PREFIX + workspace
+
+
 @_init_db
 def add_or_update_storage(storage_name: str,
                           storage_handle: 'Storage.StorageMetadata',
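Note: the two new helpers above give each workspace a persisted cloud allow-list alongside the existing enabled-clouds records. For illustration only (the workspace name and cloud list below are invented, not taken from the diff), a minimal sketch of how they might be called:

    from sky import global_user_state

    # Hypothetical caller: persist the allow-list for a workspace, then read it back.
    global_user_state.set_allowed_clouds(['aws', 'gcp'], workspace='default')
    allowed = global_user_state.get_allowed_clouds('default')  # ['aws', 'gcp']

The upsert in set_allowed_clouds goes through the dialect-specific insert ... on_conflict_do_update, which is why the function branches on the SQLAlchemy engine dialect.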
sky/jobs/server/core.py CHANGED
@@ -1,9 +1,6 @@
 """SDK functions for managed jobs."""
 import os
-import signal
-import subprocess
 import tempfile
-import time
 import typing
 from typing import Any, Dict, List, Optional, Tuple, Union
 import uuid
@@ -213,8 +210,6 @@ def launch(
         'remote_env_file_path': remote_env_file_path,
         'modified_catalogs':
             service_catalog_common.get_modified_catalog_file_mounts(),
-        'dashboard_setup_cmd': managed_job_constants.DASHBOARD_SETUP_CMD,
-        'dashboard_user_id': common.SERVER_ID,
         'priority': priority,
         **controller_utils.shared_controller_vars_to_fill(
             controller,
@@ -368,20 +363,7 @@ def _maybe_restart_controller(
                 skylet_constants.SKYPILOT_DEFAULT_WORKSPACE):
             handle = core.start(
                 cluster_name=jobs_controller_type.value.cluster_name)
-    # Make sure the dashboard is running when the controller is restarted.
-    # We should not directly use execution.launch() and have the dashboard cmd
-    # in the task setup because since we are using detached_setup, it will
-    # become a job on controller which messes up the job IDs (we assume the
-    # job ID in controller's job queue is consistent with managed job IDs).
-    with rich_utils.safe_status(
-            ux_utils.spinner_message('Starting dashboard...')):
-        runner = handle.get_command_runners()[0]
-        runner.run(
-            f'export '
-            f'{skylet_constants.USER_ID_ENV_VAR}={common.SERVER_ID!r}; '
-            f'{managed_job_constants.DASHBOARD_SETUP_CMD}',
-            stream_logs=True,
-        )
+
     controller_status = status_lib.ClusterStatus.UP
     rich_utils.force_update_status(ux_utils.spinner_message(spinner_message))
 
@@ -598,55 +580,6 @@ def tail_logs(name: Optional[str],
                   tail=tail)
 
 
-def start_dashboard_forwarding(refresh: bool = False) -> Tuple[int, int]:
-    """Opens a dashboard for managed jobs (needs controller to be UP)."""
-    # TODO(SKY-1212): ideally, the controller/dashboard server should expose the
-    # API perhaps via REST. Then here we would (1) not have to use SSH to try to
-    # see if the controller is UP first, which is slow; (2) not have to run SSH
-    # port forwarding first (we'd just launch a local dashboard which would make
-    # REST API calls to the controller dashboard server).
-    logger.info('Starting dashboard')
-    hint = ('Dashboard is not available if jobs controller is not up. Run '
-            'a managed job first or run: sky jobs queue --refresh')
-    handle = _maybe_restart_controller(
-        refresh=refresh,
-        stopped_message=hint,
-        spinner_message='Checking jobs controller')
-
-    # SSH forward a free local port to remote's dashboard port.
-    remote_port = skylet_constants.SPOT_DASHBOARD_REMOTE_PORT
-    free_port = common_utils.find_free_port(remote_port)
-    runner = handle.get_command_runners()[0]
-    port_forward_command = ' '.join(
-        runner.port_forward_command(port_forward=[(free_port, remote_port)],
-                                    connect_timeout=1))
-    port_forward_command = (
-        f'{port_forward_command} '
-        f'> ~/sky_logs/api_server/dashboard-{common_utils.get_user_hash()}.log '
-        '2>&1')
-    logger.info(f'Forwarding port: {colorama.Style.DIM}{port_forward_command}'
-                f'{colorama.Style.RESET_ALL}')
-
-    ssh_process = subprocess.Popen(port_forward_command,
-                                   shell=True,
-                                   start_new_session=True)
-    time.sleep(3)  # Added delay for ssh_command to initialize.
-    logger.info(f'{colorama.Fore.GREEN}Dashboard is now available at: '
-                f'http://127.0.0.1:{free_port}{colorama.Style.RESET_ALL}')
-
-    return free_port, ssh_process.pid
-
-
-def stop_dashboard_forwarding(pid: int) -> None:
-    # Exit the ssh command when the context manager is closed.
-    try:
-        os.killpg(os.getpgid(pid), signal.SIGTERM)
-    except ProcessLookupError:
-        # This happens if jobs controller is auto-stopped.
-        pass
-    logger.info('Forwarding port closed. Exiting.')
-
-
 @usage_lib.entrypoint
 def download_logs(
     name: Optional[str],
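The removed forwarding code relied on launching the SSH command in its own session (start_new_session=True) so that stop_dashboard_forwarding could later terminate the whole process group with os.killpg. A generic sketch of that pattern, independent of SkyPilot (the 'sleep 30' command is just a stand-in for a long-lived ssh process):

    import os
    import signal
    import subprocess
    import time

    # Run a long-lived command in a new session so the whole process group
    # (the shell plus any children, e.g. ssh) can be torn down together.
    proc = subprocess.Popen('sleep 30', shell=True, start_new_session=True)
    time.sleep(1)
    try:
        os.killpg(os.getpgid(proc.pid), signal.SIGTERM)
    except ProcessLookupError:
        pass  # the group already exited, e.g. the controller auto-stopped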
sky/jobs/state.py CHANGED
@@ -161,8 +161,8 @@ def create_table(cursor, conn):
     conn.commit()
 
 
-# Module-level connection/cursor; thread-safe as the module is only imported
-# once.
+# Module-level connection/cursor; thread-safe as the db is initialized once
+# across all threads.
 def _get_db_path() -> str:
     """Workaround to collapse multi-step Path ops for type checker.
     Ensures _DB_PATH is str, avoiding Union[Path, str] inference.
@@ -173,8 +173,7 @@ def _get_db_path() -> str:
     return str(path)
 
 
-_DB_PATH = _get_db_path()
-_db_initialized = False
+_DB_PATH = None
 _db_init_lock = threading.Lock()
 
 
@@ -183,13 +182,13 @@ def _init_db(func):
 
     @functools.wraps(func)
    def wrapper(*args, **kwargs):
-        global _db_initialized
-        if _db_initialized:
+        global _DB_PATH
+        if _DB_PATH is not None:
             return func(*args, **kwargs)
         with _db_init_lock:
-            if not _db_initialized:
+            if _DB_PATH is None:
+                _DB_PATH = _get_db_path()
                 db_utils.SQLiteConn(_DB_PATH, create_table)
-                _db_initialized = True
         return func(*args, **kwargs)
 
     return wrapper
@@ -442,7 +441,7 @@ class ManagedJobScheduleState(enum.Enum):
 # === Status transition functions ===
 @_init_db
 def set_job_info(job_id: int, name: str, workspace: str, entrypoint: str):
-    assert _db_initialized
+    assert _DB_PATH is not None
     with db_utils.safe_cursor(_DB_PATH) as cursor:
         cursor.execute(
             """\
@@ -456,7 +455,7 @@ def set_job_info(job_id: int, name: str, workspace: str, entrypoint: str):
 @_init_db
 def set_pending(job_id: int, task_id: int, task_name: str, resources_str: str):
     """Set the task to pending state."""
-    assert _db_initialized
+    assert _DB_PATH is not None
     with db_utils.safe_cursor(_DB_PATH) as cursor:
         cursor.execute(
             """\
@@ -484,7 +483,7 @@ def set_starting(job_id: int, task_id: int, run_timestamp: str,
         specs: The specs of the managed task.
         callback_func: The callback function.
     """
-    assert _db_initialized
+    assert _DB_PATH is not None
     # Use the timestamp in the `run_timestamp` ('sky-2022-10...'), to make
     # the log directory and submission time align with each other, so as to
     # make it easier to find them based on one of the values.
@@ -524,7 +523,7 @@ def set_backoff_pending(job_id: int, task_id: int):
     This should only be used to transition from STARTING or RECOVERING back to
     PENDING.
     """
-    assert _db_initialized
+    assert _DB_PATH is not None
     with db_utils.safe_cursor(_DB_PATH) as cursor:
         cursor.execute(
             """\
@@ -552,7 +551,7 @@ def set_restarting(job_id: int, task_id: int, recovering: bool):
     after using set_backoff_pending to transition back to PENDING during
     launch retry backoff.
     """
-    assert _db_initialized
+    assert _DB_PATH is not None
     target_status = ManagedJobStatus.STARTING.value
     if recovering:
         target_status = ManagedJobStatus.RECOVERING.value
@@ -578,7 +577,7 @@ def set_restarting(job_id: int, task_id: int, recovering: bool):
 def set_started(job_id: int, task_id: int, start_time: float,
                 callback_func: CallbackType):
     """Set the task to started state."""
-    assert _db_initialized
+    assert _DB_PATH is not None
     logger.info('Job started.')
     with db_utils.safe_cursor(_DB_PATH) as cursor:
         cursor.execute(
@@ -610,7 +609,7 @@ def set_started(job_id: int, task_id: int, start_time: float,
 @_init_db
 def set_recovering(job_id: int, task_id: int, callback_func: CallbackType):
     """Set the task to recovering state, and update the job duration."""
-    assert _db_initialized
+    assert _DB_PATH is not None
     logger.info('=== Recovering... ===')
     with db_utils.safe_cursor(_DB_PATH) as cursor:
         cursor.execute(
@@ -634,7 +633,7 @@ def set_recovering(job_id: int, task_id: int, callback_func: CallbackType):
 def set_recovered(job_id: int, task_id: int, recovered_time: float,
                   callback_func: CallbackType):
     """Set the task to recovered."""
-    assert _db_initialized
+    assert _DB_PATH is not None
     with db_utils.safe_cursor(_DB_PATH) as cursor:
         cursor.execute(
             """\
@@ -658,7 +657,7 @@ def set_recovered(job_id: int, task_id: int, recovered_time: float,
 def set_succeeded(job_id: int, task_id: int, end_time: float,
                   callback_func: CallbackType):
     """Set the task to succeeded, if it is in a non-terminal state."""
-    assert _db_initialized
+    assert _DB_PATH is not None
     with db_utils.safe_cursor(_DB_PATH) as cursor:
         cursor.execute(
             """\
@@ -703,7 +702,7 @@ def set_failed(
         override_terminal: If True, override the current status even if end_at
             is already set.
     """
-    assert _db_initialized
+    assert _DB_PATH is not None
     assert failure_type.is_failed(), failure_type
     end_time = time.time() if end_time is None else end_time
 
@@ -761,7 +760,7 @@ def set_cancelling(job_id: int, callback_func: CallbackType):
     task_id is not needed, because we expect the job should be cancelled
     as a whole, and we should not cancel a single task.
     """
-    assert _db_initialized
+    assert _DB_PATH is not None
     with db_utils.safe_cursor(_DB_PATH) as cursor:
         rows = cursor.execute(
             """\
@@ -783,7 +782,7 @@ def set_cancelled(job_id: int, callback_func: CallbackType):
 
     The set_cancelling should be called before this function.
     """
-    assert _db_initialized
+    assert _DB_PATH is not None
     with db_utils.safe_cursor(_DB_PATH) as cursor:
         rows = cursor.execute(
             """\
@@ -804,7 +803,7 @@ def set_cancelled(job_id: int, callback_func: CallbackType):
 def set_local_log_file(job_id: int, task_id: Optional[int],
                        local_log_file: str):
     """Set the local log file for a job."""
-    assert _db_initialized
+    assert _DB_PATH is not None
     filter_str = 'spot_job_id=(?)'
     filter_args = [local_log_file, job_id]
 
@@ -822,7 +821,7 @@ def set_local_log_file(job_id: int, task_id: Optional[int],
 def get_nonterminal_job_ids_by_name(name: Optional[str],
                                     all_users: bool = False) -> List[int]:
     """Get non-terminal job ids by name."""
-    assert _db_initialized
+    assert _DB_PATH is not None
     statuses = ', '.join(['?'] * len(ManagedJobStatus.terminal_statuses()))
     field_values = [
         status.value for status in ManagedJobStatus.terminal_statuses()
@@ -866,7 +865,7 @@ def get_schedule_live_jobs(job_id: Optional[int]) -> List[Dict[str, Any]]:
     exception: the job may have just transitioned from WAITING to LAUNCHING, but
     the controller process has not yet started.
     """
-    assert _db_initialized
+    assert _DB_PATH is not None
     job_filter = '' if job_id is None else 'AND spot_job_id=(?)'
     job_value = (job_id,) if job_id is not None else ()
 
@@ -909,7 +908,7 @@ def get_jobs_to_check_status(job_id: Optional[int] = None) -> List[int]:
     - Jobs have schedule_state DONE but are in a non-terminal status
     - Legacy jobs (that is, no schedule state) that are in non-terminal status
     """
-    assert _db_initialized
+    assert _DB_PATH is not None
     job_filter = '' if job_id is None else 'AND spot.spot_job_id=(?)'
     job_value = () if job_id is None else (job_id,)
 
@@ -958,7 +957,7 @@ def get_jobs_to_check_status(job_id: Optional[int] = None) -> List[int]:
 @_init_db
 def get_all_job_ids_by_name(name: Optional[str]) -> List[int]:
     """Get all job ids by name."""
-    assert _db_initialized
+    assert _DB_PATH is not None
     name_filter = ''
     field_values = []
     if name is not None:
@@ -987,7 +986,7 @@ def get_all_job_ids_by_name(name: Optional[str]) -> List[int]:
 @_init_db
 def _get_all_task_ids_statuses(
         job_id: int) -> List[Tuple[int, ManagedJobStatus]]:
-    assert _db_initialized
+    assert _DB_PATH is not None
     with db_utils.safe_cursor(_DB_PATH) as cursor:
         id_statuses = cursor.execute(
             """\
@@ -1035,7 +1034,7 @@ def get_failure_reason(job_id: int) -> Optional[str]:
 
     If the job has multiple tasks, we return the first failure reason.
     """
-    assert _db_initialized
+    assert _DB_PATH is not None
     with db_utils.safe_cursor(_DB_PATH) as cursor:
         reason = cursor.execute(
             """\
@@ -1051,7 +1050,7 @@ def get_failure_reason(job_id: int) -> Optional[str]:
 @_init_db
 def get_managed_jobs(job_id: Optional[int] = None) -> List[Dict[str, Any]]:
     """Get managed jobs from the database."""
-    assert _db_initialized
+    assert _DB_PATH is not None
     job_filter = '' if job_id is None else f'WHERE spot.spot_job_id={job_id}'
 
     # Join spot and job_info tables to get the job name for each task.
@@ -1097,7 +1096,7 @@ def get_managed_jobs(job_id: Optional[int] = None) -> List[Dict[str, Any]]:
 @_init_db
 def get_task_name(job_id: int, task_id: int) -> str:
     """Get the task name of a job."""
-    assert _db_initialized
+    assert _DB_PATH is not None
     with db_utils.safe_cursor(_DB_PATH) as cursor:
         task_name = cursor.execute(
             """\
@@ -1110,7 +1109,7 @@ def get_task_name(job_id: int, task_id: int) -> str:
 @_init_db
 def get_latest_job_id() -> Optional[int]:
     """Get the latest job id."""
-    assert _db_initialized
+    assert _DB_PATH is not None
     with db_utils.safe_cursor(_DB_PATH) as cursor:
         rows = cursor.execute("""\
             SELECT spot_job_id FROM spot
@@ -1123,7 +1122,7 @@ def get_latest_job_id() -> Optional[int]:
 
 @_init_db
 def get_task_specs(job_id: int, task_id: int) -> Dict[str, Any]:
-    assert _db_initialized
+    assert _DB_PATH is not None
     with db_utils.safe_cursor(_DB_PATH) as cursor:
         task_specs = cursor.execute(
             """\
@@ -1136,7 +1135,7 @@ def get_task_specs(job_id: int, task_id: int) -> Dict[str, Any]:
 @_init_db
 def get_local_log_file(job_id: int, task_id: Optional[int]) -> Optional[str]:
     """Get the local log directory for a job."""
-    assert _db_initialized
+    assert _DB_PATH is not None
     filter_str = 'spot_job_id=(?)'
     filter_args = [job_id]
     if task_id is not None:
@@ -1159,7 +1158,7 @@ def scheduler_set_waiting(job_id: int, dag_yaml_path: str,
                           original_user_yaml_path: str, env_file_path: str,
                           user_hash: str, priority: int) -> None:
     """Do not call without holding the scheduler lock."""
-    assert _db_initialized
+    assert _DB_PATH is not None
     with db_utils.safe_cursor(_DB_PATH) as cursor:
         updated_count = cursor.execute(
             'UPDATE job_info SET '
@@ -1177,7 +1176,7 @@
 def scheduler_set_launching(job_id: int,
                             current_state: ManagedJobScheduleState) -> None:
     """Do not call without holding the scheduler lock."""
-    assert _db_initialized
+    assert _DB_PATH is not None
     with db_utils.safe_cursor(_DB_PATH) as cursor:
         updated_count = cursor.execute(
             'UPDATE job_info SET '
@@ -1191,7 +1190,7 @@ def scheduler_set_launching(job_id: int,
 @_init_db
 def scheduler_set_alive(job_id: int) -> None:
     """Do not call without holding the scheduler lock."""
-    assert _db_initialized
+    assert _DB_PATH is not None
     with db_utils.safe_cursor(_DB_PATH) as cursor:
         updated_count = cursor.execute(
             'UPDATE job_info SET '
@@ -1205,7 +1204,7 @@ def scheduler_set_alive(job_id: int) -> None:
 @_init_db
 def scheduler_set_alive_backoff(job_id: int) -> None:
     """Do not call without holding the scheduler lock."""
-    assert _db_initialized
+    assert _DB_PATH is not None
     with db_utils.safe_cursor(_DB_PATH) as cursor:
         updated_count = cursor.execute(
             'UPDATE job_info SET '
@@ -1219,7 +1218,7 @@ def scheduler_set_alive_backoff(job_id: int) -> None:
 @_init_db
 def scheduler_set_alive_waiting(job_id: int) -> None:
     """Do not call without holding the scheduler lock."""
-    assert _db_initialized
+    assert _DB_PATH is not None
     with db_utils.safe_cursor(_DB_PATH) as cursor:
         updated_count = cursor.execute(
             'UPDATE job_info SET '
@@ -1234,7 +1233,7 @@ def scheduler_set_alive_waiting(job_id: int) -> None:
 @_init_db
 def scheduler_set_done(job_id: int, idempotent: bool = False) -> None:
     """Do not call without holding the scheduler lock."""
-    assert _db_initialized
+    assert _DB_PATH is not None
     with db_utils.safe_cursor(_DB_PATH) as cursor:
         updated_count = cursor.execute(
             'UPDATE job_info SET '
@@ -1248,7 +1247,7 @@ def scheduler_set_done(job_id: int, idempotent: bool = False) -> None:
 
 @_init_db
 def set_job_controller_pid(job_id: int, pid: int):
-    assert _db_initialized
+    assert _DB_PATH is not None
     with db_utils.safe_cursor(_DB_PATH) as cursor:
         updated_count = cursor.execute(
             'UPDATE job_info SET '
@@ -1259,7 +1258,7 @@ def set_job_controller_pid(job_id: int, pid: int):
 
 @_init_db
 def get_job_schedule_state(job_id: int) -> ManagedJobScheduleState:
-    assert _db_initialized
+    assert _DB_PATH is not None
     with db_utils.safe_cursor(_DB_PATH) as cursor:
         state = cursor.execute(
             'SELECT schedule_state FROM job_info WHERE spot_job_id = (?)',
@@ -1269,7 +1268,7 @@ def get_job_schedule_state(job_id: int) -> ManagedJobScheduleState:
 
 @_init_db
 def get_num_launching_jobs() -> int:
-    assert _db_initialized
+    assert _DB_PATH is not None
     with db_utils.safe_cursor(_DB_PATH) as cursor:
         return cursor.execute(
             'SELECT COUNT(*) '
@@ -1280,7 +1279,7 @@ def get_num_launching_jobs() -> int:
 
 @_init_db
 def get_num_alive_jobs() -> int:
-    assert _db_initialized
+    assert _DB_PATH is not None
     with db_utils.safe_cursor(_DB_PATH) as cursor:
         return cursor.execute(
             'SELECT COUNT(*) '
@@ -1303,7 +1302,7 @@ def get_waiting_job() -> Optional[Dict[str, Any]]:
     Backwards compatibility note: jobs submitted before #4485 will have no
     schedule_state and will be ignored by this SQL query.
     """
-    assert _db_initialized
+    assert _DB_PATH is not None
     with db_utils.safe_cursor(_DB_PATH) as cursor:
         # Get the highest-priority WAITING or ALIVE_WAITING job whose priority
         # is greater than or equal to the highest priority LAUNCHING or
@@ -1338,7 +1337,7 @@ def get_waiting_job() -> Optional[Dict[str, Any]]:
 @_init_db
 def get_workspace(job_id: int) -> str:
     """Get the workspace of a job."""
-    assert _db_initialized
+    assert _DB_PATH is not None
     with db_utils.safe_cursor(_DB_PATH) as cursor:
         workspace = cursor.execute(
             'SELECT workspace FROM job_info WHERE spot_job_id = (?)',
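The state-module refactor above folds the old _db_initialized flag into _DB_PATH itself: a None path now doubles as the "not yet initialized" sentinel, and every assert changes accordingly. A standalone sketch of the double-checked lazy-init pattern (the path literal and the trailing comment are placeholders for _get_db_path() and the real SQLiteConn call):

    import functools
    import threading

    _DB_PATH = None  # None doubles as the "not yet initialized" sentinel
    _db_init_lock = threading.Lock()

    def _init_db(func):
        """Initialize the DB once, lazily, before the first wrapped call."""

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            global _DB_PATH
            if _DB_PATH is not None:
                return func(*args, **kwargs)
            with _db_init_lock:
                if _DB_PATH is None:
                    _DB_PATH = '/tmp/example.db'  # stand-in for _get_db_path()
                    # db_utils.SQLiteConn(_DB_PATH, create_table) would run here.
            return func(*args, **kwargs)

        return wrapper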
sky/provision/common.py CHANGED
@@ -238,7 +238,7 @@ class Endpoint:
 
 @dataclasses.dataclass
 class SocketEndpoint(Endpoint):
-    """Socket endpoint accesible via a host and a port."""
+    """Socket endpoint accessible via a host and a port."""
     port: Optional[int]
     host: str = ''
 
sky/provision/gcp/config.py CHANGED
@@ -274,7 +274,7 @@ def _is_permission_satisfied(service_account, crm, iam, required_permissions,
     # For example, `roles/iam.serviceAccountUser` can be granted at the
     # skypilot-v1 service account level, which can be checked with
     # service_account_policy = iam.projects().serviceAccounts().getIamPolicy(
-    #     resource=f'projects/{project_id}/serviceAcccounts/{email}').execute()
+    #     resource=f'projects/{project_id}/serviceAccounts/{email}').execute()
     # We now skip the check for `iam.serviceAccounts.actAs` permission for
     # simplicity as it can be granted at the service account level.
     def check_permissions(policy, required_permissions):
sky/provision/kubernetes/instance.py CHANGED
@@ -1277,7 +1277,8 @@ def query_instances(
     except kubernetes.max_retry_error():
         with ux_utils.print_exception_no_traceback():
             if is_ssh:
-                node_pool = context.lstrip('ssh-') if context else ''
+                node_pool = common_utils.removeprefix(context,
+                                                      'ssh-') if context else ''
                 msg = (
                     f'Cannot connect to SSH Node Pool {node_pool}. '
                     'Please check if the SSH Node Pool is up and accessible. '
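Both this hunk and the matching one in utils.py below replace str.lstrip with a prefix-removal helper. str.lstrip('ssh-') strips any leading run of the characters 's', 'h' and '-', not the literal prefix, so context names beginning with those characters get mangled. A plain-Python illustration (common_utils.removeprefix is assumed to behave like str.removeprefix):

    context = 'ssh-shared-pool'

    # lstrip treats its argument as a character set, not a prefix:
    print(context.lstrip('ssh-'))        # 'ared-pool'  -- wrong
    # Removing the literal prefix keeps the rest of the name intact
    # (str.removeprefix requires Python 3.9+):
    print(context.removeprefix('ssh-'))  # 'shared-pool'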
sky/provision/kubernetes/utils.py CHANGED
@@ -133,6 +133,30 @@ DEFAULT_MAX_RETRIES = 3
 DEFAULT_RETRY_INTERVAL_SECONDS = 1
 
 
+def normalize_tpu_accelerator_name(accelerator: str) -> Tuple[str, int]:
+    """Normalize TPU names to the k8s-compatible name and extract count."""
+    # Examples:
+    # 'tpu-v6e-8' -> ('tpu-v6e-slice', 8)
+    # 'tpu-v5litepod-4' -> ('tpu-v5-lite-podslice', 4)
+
+    gcp_to_k8s_patterns = [
+        (r'^tpu-v6e-(\d+)$', 'tpu-v6e-slice'),
+        (r'^tpu-v5p-(\d+)$', 'tpu-v5p-slice'),
+        (r'^tpu-v5litepod-(\d+)$', 'tpu-v5-lite-podslice'),
+        (r'^tpu-v5lite-(\d+)$', 'tpu-v5-lite-device'),
+        (r'^tpu-v4-(\d+)$', 'tpu-v4-podslice'),
+    ]
+
+    for pattern, replacement in gcp_to_k8s_patterns:
+        match = re.match(pattern, accelerator)
+        if match:
+            count = int(match.group(1))
+            return replacement, count
+
+    # Default fallback
+    return accelerator, 1
+
+
 def _retry_on_error(max_retries=DEFAULT_MAX_RETRIES,
                     retry_interval=DEFAULT_RETRY_INTERVAL_SECONDS,
                     resource_type: Optional[str] = None):
@@ -427,6 +451,7 @@ class GKELabelFormatter(GPULabelFormatter):
 
         e.g. tpu-v5-lite-podslice:8 -> '2x4'
         """
+        acc_type, acc_count = normalize_tpu_accelerator_name(acc_type)
         count_to_topology = cls.GKE_TPU_TOPOLOGIES.get(acc_type,
                                                        {}).get(acc_count, None)
         if count_to_topology is None:
@@ -461,6 +486,14 @@ class GKELabelFormatter(GPULabelFormatter):
             raise ValueError(
                 f'Invalid accelerator name in GKE cluster: {value}')
 
+    @classmethod
+    def validate_label_value(cls, value: str) -> Tuple[bool, str]:
+        try:
+            _ = cls.get_accelerator_from_label_value(value)
+            return True, ''
+        except ValueError as e:
+            return False, str(e)
+
 
 class GFDLabelFormatter(GPULabelFormatter):
     """GPU Feature Discovery label formatter
@@ -565,17 +598,29 @@ def detect_gpu_label_formatter(
         for label, value in node.metadata.labels.items():
             node_labels[node.metadata.name].append((label, value))
 
-    label_formatter = None
-
     # Check if the node labels contain any of the GPU label prefixes
     for lf in LABEL_FORMATTER_REGISTRY:
+        skip = False
         for _, label_list in node_labels.items():
-            for label, _ in label_list:
+            for label, value in label_list:
                 if lf.match_label_key(label):
-                    label_formatter = lf()
-                    return label_formatter, node_labels
+                    valid, reason = lf.validate_label_value(value)
+                    if valid:
+                        return lf(), node_labels
+                    else:
+                        logger.warning(f'GPU label {label} matched for label '
+                                       f'formatter {lf.__class__.__name__}, '
+                                       f'but has invalid value {value}. '
+                                       f'Reason: {reason}. '
+                                       'Skipping...')
+                        skip = True
+                        break
+            if skip:
+                break
+        if skip:
+            continue
 
-    return label_formatter, node_labels
+    return None, node_labels
 
 
 class Autoscaler:
@@ -754,6 +799,8 @@ class GKEAutoscaler(Autoscaler):
                 f'checking {node_pool_name} for TPU {requested_acc_type}:'
                 f'{requested_acc_count}')
             if 'resourceLabels' in node_config:
+                requested_acc_type, requested_acc_count = normalize_tpu_accelerator_name(
+                    requested_acc_type)
                 accelerator_exists = cls._node_pool_has_tpu_capacity(
                     node_config['resourceLabels'], machine_type,
                     requested_acc_type, requested_acc_count)
@@ -993,7 +1040,7 @@ def check_instance_fits(context: Optional[str],
             'Maximum resources found on a single node: '
             f'{max_cpu} CPUs, {common_utils.format_float(max_mem)}G Memory')
 
-    def check_tpu_fits(candidate_instance_type: 'KubernetesInstanceType',
+    def check_tpu_fits(acc_type: str, acc_count: int,
                        node_list: List[Any]) -> Tuple[bool, Optional[str]]:
         """Checks if the instance fits on the cluster based on requested TPU.
 
@@ -1003,8 +1050,6 @@ def check_instance_fits(context: Optional[str],
         node (node_tpu_chip_count) and the total TPU chips across the entire
        podslice (topology_chip_count) are correctly handled.
         """
-        acc_type = candidate_instance_type.accelerator_type
-        acc_count = candidate_instance_type.accelerator_count
         tpu_list_in_cluster = []
         for node in node_list:
             if acc_type == node.metadata.labels[
@@ -1055,7 +1100,8 @@ def check_instance_fits(context: Optional[str],
     if is_tpu_on_gke(acc_type):
         # If requested accelerator is a TPU type, check if the cluster
         # has sufficient TPU resource to meet the requirement.
-        fits, reason = check_tpu_fits(k8s_instance_type, gpu_nodes)
+        acc_type, acc_count = normalize_tpu_accelerator_name(acc_type)
+        fits, reason = check_tpu_fits(acc_type, acc_count, gpu_nodes)
         if reason is not None:
             return fits, reason
     else:
@@ -1141,8 +1187,8 @@ def get_accelerator_label_key_values(
 
     is_ssh_node_pool = context.startswith('ssh-') if context else False
     cloud_name = 'SSH Node Pool' if is_ssh_node_pool else 'Kubernetes cluster'
-    context_display_name = context.lstrip('ssh-') if (
-        context and is_ssh_node_pool) else context
+    context_display_name = common_utils.removeprefix(
+        context, 'ssh-') if (context and is_ssh_node_pool) else context
 
     autoscaler_type = get_autoscaler_type()
     if autoscaler_type is not None:
@@ -2911,7 +2957,8 @@ def get_skypilot_pods(context: Optional[str] = None) -> List[Any]:
 
 def is_tpu_on_gke(accelerator: str) -> bool:
     """Determines if the given accelerator is a TPU supported on GKE."""
-    return accelerator in GKE_TPU_ACCELERATOR_TO_GENERATION
+    normalized, _ = normalize_tpu_accelerator_name(accelerator)
+    return normalized in GKE_TPU_ACCELERATOR_TO_GENERATION
 
 
 def get_node_accelerator_count(attribute_dict: dict) -> int:
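To make the TPU-name normalization concrete, here is a small illustrative check of normalize_tpu_accelerator_name based only on the patterns added in the diff (the import path is an assumption that follows the module layout listed above):

    from sky.provision.kubernetes import utils as k8s_utils

    # GCP-style names map to the k8s-compatible slice type plus a chip count.
    assert k8s_utils.normalize_tpu_accelerator_name('tpu-v6e-8') == ('tpu-v6e-slice', 8)
    assert k8s_utils.normalize_tpu_accelerator_name('tpu-v5litepod-4') == (
        'tpu-v5-lite-podslice', 4)
    # Names that match no pattern fall through unchanged with a count of 1.
    assert k8s_utils.normalize_tpu_accelerator_name('tpu-v5p-slice') == ('tpu-v5p-slice', 1)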