skypilot-nightly 1.0.0.dev20250610__py3-none-any.whl → 1.0.0.dev20250611__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. sky/__init__.py +2 -2
  2. sky/admin_policy.py +132 -6
  3. sky/benchmark/benchmark_state.py +39 -1
  4. sky/cli.py +1 -1
  5. sky/client/cli.py +1 -1
  6. sky/dashboard/out/404.html +1 -1
  7. sky/dashboard/out/_next/static/chunks/600.15a0009177e86b86.js +16 -0
  8. sky/dashboard/out/_next/static/chunks/938-ab185187a63f9cdb.js +1 -0
  9. sky/dashboard/out/_next/static/chunks/{webpack-0574a5a4ba3cf0ac.js → webpack-208a9812ab4f61c9.js} +1 -1
  10. sky/dashboard/out/_next/static/css/{8b1c8321d4c02372.css → 5d71bfc09f184bab.css} +1 -1
  11. sky/dashboard/out/_next/static/{4lwUJxN6KwBqUxqO1VccB → zJqasksBQ3HcqMpA2wTUZ}/_buildManifest.js +1 -1
  12. sky/dashboard/out/clusters/[cluster]/[job].html +1 -1
  13. sky/dashboard/out/clusters/[cluster].html +1 -1
  14. sky/dashboard/out/clusters.html +1 -1
  15. sky/dashboard/out/config.html +1 -1
  16. sky/dashboard/out/index.html +1 -1
  17. sky/dashboard/out/infra/[context].html +1 -1
  18. sky/dashboard/out/infra.html +1 -1
  19. sky/dashboard/out/jobs/[job].html +1 -1
  20. sky/dashboard/out/jobs.html +1 -1
  21. sky/dashboard/out/users.html +1 -1
  22. sky/dashboard/out/workspace/new.html +1 -1
  23. sky/dashboard/out/workspaces/[name].html +1 -1
  24. sky/dashboard/out/workspaces.html +1 -1
  25. sky/jobs/scheduler.py +4 -5
  26. sky/jobs/state.py +104 -11
  27. sky/jobs/utils.py +5 -5
  28. sky/skylet/job_lib.py +95 -40
  29. sky/users/permission.py +34 -17
  30. sky/utils/admin_policy_utils.py +32 -13
  31. sky/utils/schemas.py +11 -3
  32. {skypilot_nightly-1.0.0.dev20250610.dist-info → skypilot_nightly-1.0.0.dev20250611.dist-info}/METADATA +1 -1
  33. {skypilot_nightly-1.0.0.dev20250610.dist-info → skypilot_nightly-1.0.0.dev20250611.dist-info}/RECORD +39 -39
  34. sky/dashboard/out/_next/static/chunks/600.9cc76ec442b22e10.js +0 -16
  35. sky/dashboard/out/_next/static/chunks/938-a75b7712639298b7.js +0 -1
  36. /sky/dashboard/out/_next/static/chunks/pages/{_app-4768de0aede04dc9.js → _app-7bbd9d39d6f9a98a.js} +0 -0
  37. /sky/dashboard/out/_next/static/{4lwUJxN6KwBqUxqO1VccB → zJqasksBQ3HcqMpA2wTUZ}/_ssgManifest.js +0 -0
  38. {skypilot_nightly-1.0.0.dev20250610.dist-info → skypilot_nightly-1.0.0.dev20250611.dist-info}/WHEEL +0 -0
  39. {skypilot_nightly-1.0.0.dev20250610.dist-info → skypilot_nightly-1.0.0.dev20250611.dist-info}/entry_points.txt +0 -0
  40. {skypilot_nightly-1.0.0.dev20250610.dist-info → skypilot_nightly-1.0.0.dev20250611.dist-info}/licenses/LICENSE +0 -0
  41. {skypilot_nightly-1.0.0.dev20250610.dist-info → skypilot_nightly-1.0.0.dev20250611.dist-info}/top_level.txt +0 -0
sky/jobs/state.py CHANGED
@@ -2,9 +2,11 @@
2
2
  # TODO(zhwu): maybe use file based status instead of database, so
3
3
  # that we can easily switch to a s3-based storage.
4
4
  import enum
5
+ import functools
5
6
  import json
6
7
  import pathlib
7
8
  import sqlite3
9
+ import threading
8
10
  import time
9
11
  import typing
10
12
  from typing import Any, Callable, Dict, List, Optional, Tuple, Union
@@ -172,7 +174,26 @@ def _get_db_path() -> str:
172
174
 
173
175
 
174
176
  _DB_PATH = _get_db_path()
175
- db_utils.SQLiteConn(_DB_PATH, create_table)
177
# Process-wide flag: becomes True once the one-time database setup below has
# completed successfully in this process.
_db_initialized = False
# Guards the one-time setup so that concurrent threads run it at most once.
_db_init_lock = threading.Lock()


def _ensure_db_initialized() -> None:
    """Run the one-time database setup if it has not happened yet.

    Calls ``db_utils.SQLiteConn(_DB_PATH, create_table)`` and then marks the
    module as initialized. If the setup call raises, the flag stays False and
    the exception propagates, so a later call will retry.
    """
    global _db_initialized
    # Fast path: once setup is done, skip the lock entirely.
    if _db_initialized:
        return
    with _db_init_lock:
        # Re-check under the lock: another thread may have completed setup
        # while this one was waiting to acquire it.
        if _db_initialized:
            return
        db_utils.SQLiteConn(_DB_PATH, create_table)
        _db_initialized = True


def _init_db(func):
    """Decorator: lazily set up the database before invoking ``func``.

    The first decorated call in the process performs the database setup
    under ``_db_init_lock``; every call then forwards its positional and
    keyword arguments to ``func`` and returns its result unchanged. The
    wrapper keeps ``func``'s metadata via ``functools.wraps``.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        _ensure_db_initialized()
        return func(*args, **kwargs)

    return wrapper
196
+
176
197
 
177
198
  # job_duration is the time a job actually runs (including the
178
199
  # setup duration) before last_recover, excluding the provision
@@ -419,7 +440,9 @@ class ManagedJobScheduleState(enum.Enum):
419
440
 
420
441
 
421
442
  # === Status transition functions ===
443
+ @_init_db
422
444
  def set_job_info(job_id: int, name: str, workspace: str, entrypoint: str):
445
+ assert _db_initialized
423
446
  with db_utils.safe_cursor(_DB_PATH) as cursor:
424
447
  cursor.execute(
425
448
  """\
@@ -430,8 +453,10 @@ def set_job_info(job_id: int, name: str, workspace: str, entrypoint: str):
430
453
  entrypoint))
431
454
 
432
455
 
456
+ @_init_db
433
457
  def set_pending(job_id: int, task_id: int, task_name: str, resources_str: str):
434
458
  """Set the task to pending state."""
459
+ assert _db_initialized
435
460
  with db_utils.safe_cursor(_DB_PATH) as cursor:
436
461
  cursor.execute(
437
462
  """\
@@ -442,6 +467,7 @@ def set_pending(job_id: int, task_id: int, task_name: str, resources_str: str):
442
467
  ManagedJobStatus.PENDING.value))
443
468
 
444
469
 
470
+ @_init_db
445
471
  def set_starting(job_id: int, task_id: int, run_timestamp: str,
446
472
  submit_time: float, resources_str: str,
447
473
  specs: Dict[str, Union[str,
@@ -458,6 +484,7 @@ def set_starting(job_id: int, task_id: int, run_timestamp: str,
458
484
  specs: The specs of the managed task.
459
485
  callback_func: The callback function.
460
486
  """
487
+ assert _db_initialized
461
488
  # Use the timestamp in the `run_timestamp` ('sky-2022-10...'), to make
462
489
  # the log directory and submission time align with each other, so as to
463
490
  # make it easier to find them based on one of the values.
@@ -490,12 +517,14 @@ def set_starting(job_id: int, task_id: int, run_timestamp: str,
490
517
  callback_func('STARTING')
491
518
 
492
519
 
520
+ @_init_db
493
521
  def set_backoff_pending(job_id: int, task_id: int):
494
522
  """Set the task to PENDING state if it is in backoff.
495
523
 
496
524
  This should only be used to transition from STARTING or RECOVERING back to
497
525
  PENDING.
498
526
  """
527
+ assert _db_initialized
499
528
  with db_utils.safe_cursor(_DB_PATH) as cursor:
500
529
  cursor.execute(
501
530
  """\
@@ -514,6 +543,7 @@ def set_backoff_pending(job_id: int, task_id: int):
514
543
  # Do not call callback_func here, as we don't use the callback for PENDING.
515
544
 
516
545
 
546
+ @_init_db
517
547
  def set_restarting(job_id: int, task_id: int, recovering: bool):
518
548
  """Set the task back to STARTING or RECOVERING from PENDING.
519
549
 
@@ -522,6 +552,7 @@ def set_restarting(job_id: int, task_id: int, recovering: bool):
522
552
  after using set_backoff_pending to transition back to PENDING during
523
553
  launch retry backoff.
524
554
  """
555
+ assert _db_initialized
525
556
  target_status = ManagedJobStatus.STARTING.value
526
557
  if recovering:
527
558
  target_status = ManagedJobStatus.RECOVERING.value
@@ -543,9 +574,11 @@ def set_restarting(job_id: int, task_id: int, recovering: bool):
543
574
  # initial (pre-`set_backoff_pending`) transition to STARTING or RECOVERING.
544
575
 
545
576
 
577
+ @_init_db
546
578
  def set_started(job_id: int, task_id: int, start_time: float,
547
579
  callback_func: CallbackType):
548
580
  """Set the task to started state."""
581
+ assert _db_initialized
549
582
  logger.info('Job started.')
550
583
  with db_utils.safe_cursor(_DB_PATH) as cursor:
551
584
  cursor.execute(
@@ -574,8 +607,10 @@ def set_started(job_id: int, task_id: int, start_time: float,
574
607
  callback_func('STARTED')
575
608
 
576
609
 
610
+ @_init_db
577
611
  def set_recovering(job_id: int, task_id: int, callback_func: CallbackType):
578
612
  """Set the task to recovering state, and update the job duration."""
613
+ assert _db_initialized
579
614
  logger.info('=== Recovering... ===')
580
615
  with db_utils.safe_cursor(_DB_PATH) as cursor:
581
616
  cursor.execute(
@@ -595,9 +630,11 @@ def set_recovering(job_id: int, task_id: int, callback_func: CallbackType):
595
630
  callback_func('RECOVERING')
596
631
 
597
632
 
633
+ @_init_db
598
634
  def set_recovered(job_id: int, task_id: int, recovered_time: float,
599
635
  callback_func: CallbackType):
600
636
  """Set the task to recovered."""
637
+ assert _db_initialized
601
638
  with db_utils.safe_cursor(_DB_PATH) as cursor:
602
639
  cursor.execute(
603
640
  """\
@@ -617,9 +654,11 @@ def set_recovered(job_id: int, task_id: int, recovered_time: float,
617
654
  callback_func('RECOVERED')
618
655
 
619
656
 
657
+ @_init_db
620
658
  def set_succeeded(job_id: int, task_id: int, end_time: float,
621
659
  callback_func: CallbackType):
622
660
  """Set the task to succeeded, if it is in a non-terminal state."""
661
+ assert _db_initialized
623
662
  with db_utils.safe_cursor(_DB_PATH) as cursor:
624
663
  cursor.execute(
625
664
  """\
@@ -639,6 +678,7 @@ def set_succeeded(job_id: int, task_id: int, end_time: float,
639
678
  logger.info('Job succeeded.')
640
679
 
641
680
 
681
+ @_init_db
642
682
  def set_failed(
643
683
  job_id: int,
644
684
  task_id: Optional[int],
@@ -663,6 +703,7 @@ def set_failed(
663
703
  override_terminal: If True, override the current status even if end_at
664
704
  is already set.
665
705
  """
706
+ assert _db_initialized
666
707
  assert failure_type.is_failed(), failure_type
667
708
  end_time = time.time() if end_time is None else end_time
668
709
 
@@ -713,12 +754,14 @@ def set_failed(
713
754
  logger.info(failure_reason)
714
755
 
715
756
 
757
+ @_init_db
716
758
  def set_cancelling(job_id: int, callback_func: CallbackType):
717
759
  """Set tasks in the job as cancelling, if they are in non-terminal states.
718
760
 
719
761
  task_id is not needed, because we expect the job should be cancelled
720
762
  as a whole, and we should not cancel a single task.
721
763
  """
764
+ assert _db_initialized
722
765
  with db_utils.safe_cursor(_DB_PATH) as cursor:
723
766
  rows = cursor.execute(
724
767
  """\
@@ -734,11 +777,13 @@ def set_cancelling(job_id: int, callback_func: CallbackType):
734
777
  logger.info('Cancellation skipped, job is already terminal')
735
778
 
736
779
 
780
+ @_init_db
737
781
  def set_cancelled(job_id: int, callback_func: CallbackType):
738
782
  """Set tasks in the job as cancelled, if they are in CANCELLING state.
739
783
 
740
784
  The set_cancelling should be called before this function.
741
785
  """
786
+ assert _db_initialized
742
787
  with db_utils.safe_cursor(_DB_PATH) as cursor:
743
788
  rows = cursor.execute(
744
789
  """\
@@ -755,11 +800,14 @@ def set_cancelled(job_id: int, callback_func: CallbackType):
755
800
  logger.info('Cancellation skipped, job is not CANCELLING')
756
801
 
757
802
 
803
+ @_init_db
758
804
  def set_local_log_file(job_id: int, task_id: Optional[int],
759
805
  local_log_file: str):
760
806
  """Set the local log file for a job."""
807
+ assert _db_initialized
761
808
  filter_str = 'spot_job_id=(?)'
762
809
  filter_args = [local_log_file, job_id]
810
+
763
811
  if task_id is not None:
764
812
  filter_str += ' AND task_id=(?)'
765
813
  filter_args.append(task_id)
@@ -770,9 +818,11 @@ def set_local_log_file(job_id: int, task_id: Optional[int],
770
818
 
771
819
 
772
820
  # ======== utility functions ========
821
+ @_init_db
773
822
  def get_nonterminal_job_ids_by_name(name: Optional[str],
774
823
  all_users: bool = False) -> List[int]:
775
824
  """Get non-terminal job ids by name."""
825
+ assert _db_initialized
776
826
  statuses = ', '.join(['?'] * len(ManagedJobStatus.terminal_statuses()))
777
827
  field_values = [
778
828
  status.value for status in ManagedJobStatus.terminal_statuses()
@@ -807,6 +857,7 @@ def get_nonterminal_job_ids_by_name(name: Optional[str],
807
857
  return job_ids
808
858
 
809
859
 
860
+ @_init_db
810
861
  def get_schedule_live_jobs(job_id: Optional[int]) -> List[Dict[str, Any]]:
811
862
  """Get jobs from the database that have a live schedule_state.
812
863
 
@@ -815,6 +866,7 @@ def get_schedule_live_jobs(job_id: Optional[int]) -> List[Dict[str, Any]]:
815
866
  exception: the job may have just transitioned from WAITING to LAUNCHING, but
816
867
  the controller process has not yet started.
817
868
  """
869
+ assert _db_initialized
818
870
  job_filter = '' if job_id is None else 'AND spot_job_id=(?)'
819
871
  job_value = (job_id,) if job_id is not None else ()
820
872
 
@@ -845,6 +897,7 @@ def get_schedule_live_jobs(job_id: Optional[int]) -> List[Dict[str, Any]]:
845
897
  return jobs
846
898
 
847
899
 
900
+ @_init_db
848
901
  def get_jobs_to_check_status(job_id: Optional[int] = None) -> List[int]:
849
902
  """Get jobs that need controller process checking.
850
903
 
@@ -856,6 +909,7 @@ def get_jobs_to_check_status(job_id: Optional[int] = None) -> List[int]:
856
909
  - Jobs have schedule_state DONE but are in a non-terminal status
857
910
  - Legacy jobs (that is, no schedule state) that are in non-terminal status
858
911
  """
912
+ assert _db_initialized
859
913
  job_filter = '' if job_id is None else 'AND spot.spot_job_id=(?)'
860
914
  job_value = () if job_id is None else (job_id,)
861
915
 
@@ -901,8 +955,10 @@ def get_jobs_to_check_status(job_id: Optional[int] = None) -> List[int]:
901
955
  return [row[0] for row in rows if row[0] is not None]
902
956
 
903
957
 
958
+ @_init_db
904
959
  def get_all_job_ids_by_name(name: Optional[str]) -> List[int]:
905
960
  """Get all job ids by name."""
961
+ assert _db_initialized
906
962
  name_filter = ''
907
963
  field_values = []
908
964
  if name is not None:
@@ -928,8 +984,10 @@ def get_all_job_ids_by_name(name: Optional[str]) -> List[int]:
928
984
  return job_ids
929
985
 
930
986
 
987
+ @_init_db
931
988
  def _get_all_task_ids_statuses(
932
989
  job_id: int) -> List[Tuple[int, ManagedJobStatus]]:
990
+ assert _db_initialized
933
991
  with db_utils.safe_cursor(_DB_PATH) as cursor:
934
992
  id_statuses = cursor.execute(
935
993
  """\
@@ -971,11 +1029,13 @@ def get_status(job_id: int) -> Optional[ManagedJobStatus]:
971
1029
  return status
972
1030
 
973
1031
 
1032
+ @_init_db
974
1033
  def get_failure_reason(job_id: int) -> Optional[str]:
975
1034
  """Get the failure reason of a job.
976
1035
 
977
1036
  If the job has multiple tasks, we return the first failure reason.
978
1037
  """
1038
+ assert _db_initialized
979
1039
  with db_utils.safe_cursor(_DB_PATH) as cursor:
980
1040
  reason = cursor.execute(
981
1041
  """\
@@ -988,8 +1048,10 @@ def get_failure_reason(job_id: int) -> Optional[str]:
988
1048
  return reason[0]
989
1049
 
990
1050
 
1051
+ @_init_db
991
1052
  def get_managed_jobs(job_id: Optional[int] = None) -> List[Dict[str, Any]]:
992
1053
  """Get managed jobs from the database."""
1054
+ assert _db_initialized
993
1055
  job_filter = '' if job_id is None else f'WHERE spot.spot_job_id={job_id}'
994
1056
 
995
1057
  # Join spot and job_info tables to get the job name for each task.
@@ -1032,8 +1094,10 @@ def get_managed_jobs(job_id: Optional[int] = None) -> List[Dict[str, Any]]:
1032
1094
  return jobs
1033
1095
 
1034
1096
 
1097
+ @_init_db
1035
1098
  def get_task_name(job_id: int, task_id: int) -> str:
1036
1099
  """Get the task name of a job."""
1100
+ assert _db_initialized
1037
1101
  with db_utils.safe_cursor(_DB_PATH) as cursor:
1038
1102
  task_name = cursor.execute(
1039
1103
  """\
@@ -1043,8 +1107,10 @@ def get_task_name(job_id: int, task_id: int) -> str:
1043
1107
  return task_name[0]
1044
1108
 
1045
1109
 
1110
+ @_init_db
1046
1111
  def get_latest_job_id() -> Optional[int]:
1047
1112
  """Get the latest job id."""
1113
+ assert _db_initialized
1048
1114
  with db_utils.safe_cursor(_DB_PATH) as cursor:
1049
1115
  rows = cursor.execute("""\
1050
1116
  SELECT spot_job_id FROM spot
@@ -1055,7 +1121,9 @@ def get_latest_job_id() -> Optional[int]:
1055
1121
  return None
1056
1122
 
1057
1123
 
1124
+ @_init_db
1058
1125
  def get_task_specs(job_id: int, task_id: int) -> Dict[str, Any]:
1126
+ assert _db_initialized
1059
1127
  with db_utils.safe_cursor(_DB_PATH) as cursor:
1060
1128
  task_specs = cursor.execute(
1061
1129
  """\
@@ -1065,8 +1133,10 @@ def get_task_specs(job_id: int, task_id: int) -> Dict[str, Any]:
1065
1133
  return json.loads(task_specs[0])
1066
1134
 
1067
1135
 
1136
+ @_init_db
1068
1137
  def get_local_log_file(job_id: int, task_id: Optional[int]) -> Optional[str]:
1069
1138
  """Get the local log directory for a job."""
1139
+ assert _db_initialized
1070
1140
  filter_str = 'spot_job_id=(?)'
1071
1141
  filter_args = [job_id]
1072
1142
  if task_id is not None:
@@ -1084,10 +1154,12 @@ def get_local_log_file(job_id: int, task_id: Optional[int]) -> Optional[str]:
1084
1154
  # scheduler lock to work correctly.
1085
1155
 
1086
1156
 
1157
+ @_init_db
1087
1158
  def scheduler_set_waiting(job_id: int, dag_yaml_path: str,
1088
1159
  original_user_yaml_path: str, env_file_path: str,
1089
1160
  user_hash: str, priority: int) -> None:
1090
1161
  """Do not call without holding the scheduler lock."""
1162
+ assert _db_initialized
1091
1163
  with db_utils.safe_cursor(_DB_PATH) as cursor:
1092
1164
  updated_count = cursor.execute(
1093
1165
  'UPDATE job_info SET '
@@ -1101,9 +1173,11 @@ def scheduler_set_waiting(job_id: int, dag_yaml_path: str,
1101
1173
  assert updated_count == 1, (job_id, updated_count)
1102
1174
 
1103
1175
 
1176
+ @_init_db
1104
1177
  def scheduler_set_launching(job_id: int,
1105
1178
  current_state: ManagedJobScheduleState) -> None:
1106
1179
  """Do not call without holding the scheduler lock."""
1180
+ assert _db_initialized
1107
1181
  with db_utils.safe_cursor(_DB_PATH) as cursor:
1108
1182
  updated_count = cursor.execute(
1109
1183
  'UPDATE job_info SET '
@@ -1114,8 +1188,10 @@ def scheduler_set_launching(job_id: int,
1114
1188
  assert updated_count == 1, (job_id, updated_count)
1115
1189
 
1116
1190
 
1191
+ @_init_db
1117
1192
  def scheduler_set_alive(job_id: int) -> None:
1118
1193
  """Do not call without holding the scheduler lock."""
1194
+ assert _db_initialized
1119
1195
  with db_utils.safe_cursor(_DB_PATH) as cursor:
1120
1196
  updated_count = cursor.execute(
1121
1197
  'UPDATE job_info SET '
@@ -1126,8 +1202,10 @@ def scheduler_set_alive(job_id: int) -> None:
1126
1202
  assert updated_count == 1, (job_id, updated_count)
1127
1203
 
1128
1204
 
1205
+ @_init_db
1129
1206
  def scheduler_set_alive_backoff(job_id: int) -> None:
1130
1207
  """Do not call without holding the scheduler lock."""
1208
+ assert _db_initialized
1131
1209
  with db_utils.safe_cursor(_DB_PATH) as cursor:
1132
1210
  updated_count = cursor.execute(
1133
1211
  'UPDATE job_info SET '
@@ -1138,8 +1216,10 @@ def scheduler_set_alive_backoff(job_id: int) -> None:
1138
1216
  assert updated_count == 1, (job_id, updated_count)
1139
1217
 
1140
1218
 
1219
+ @_init_db
1141
1220
  def scheduler_set_alive_waiting(job_id: int) -> None:
1142
1221
  """Do not call without holding the scheduler lock."""
1222
+ assert _db_initialized
1143
1223
  with db_utils.safe_cursor(_DB_PATH) as cursor:
1144
1224
  updated_count = cursor.execute(
1145
1225
  'UPDATE job_info SET '
@@ -1151,8 +1231,10 @@ def scheduler_set_alive_waiting(job_id: int) -> None:
1151
1231
  assert updated_count == 1, (job_id, updated_count)
1152
1232
 
1153
1233
 
1234
+ @_init_db
1154
1235
  def scheduler_set_done(job_id: int, idempotent: bool = False) -> None:
1155
1236
  """Do not call without holding the scheduler lock."""
1237
+ assert _db_initialized
1156
1238
  with db_utils.safe_cursor(_DB_PATH) as cursor:
1157
1239
  updated_count = cursor.execute(
1158
1240
  'UPDATE job_info SET '
@@ -1164,7 +1246,9 @@ def scheduler_set_done(job_id: int, idempotent: bool = False) -> None:
1164
1246
  assert updated_count == 1, (job_id, updated_count)
1165
1247
 
1166
1248
 
1249
+ @_init_db
1167
1250
  def set_job_controller_pid(job_id: int, pid: int):
1251
+ assert _db_initialized
1168
1252
  with db_utils.safe_cursor(_DB_PATH) as cursor:
1169
1253
  updated_count = cursor.execute(
1170
1254
  'UPDATE job_info SET '
@@ -1173,7 +1257,9 @@ def set_job_controller_pid(job_id: int, pid: int):
1173
1257
  assert updated_count == 1, (job_id, updated_count)
1174
1258
 
1175
1259
 
1260
+ @_init_db
1176
1261
  def get_job_schedule_state(job_id: int) -> ManagedJobScheduleState:
1262
+ assert _db_initialized
1177
1263
  with db_utils.safe_cursor(_DB_PATH) as cursor:
1178
1264
  state = cursor.execute(
1179
1265
  'SELECT schedule_state FROM job_info WHERE spot_job_id = (?)',
@@ -1181,7 +1267,9 @@ def get_job_schedule_state(job_id: int) -> ManagedJobScheduleState:
1181
1267
  return ManagedJobScheduleState(state)
1182
1268
 
1183
1269
 
1270
+ @_init_db
1184
1271
  def get_num_launching_jobs() -> int:
1272
+ assert _db_initialized
1185
1273
  with db_utils.safe_cursor(_DB_PATH) as cursor:
1186
1274
  return cursor.execute(
1187
1275
  'SELECT COUNT(*) '
@@ -1190,7 +1278,9 @@ def get_num_launching_jobs() -> int:
1190
1278
  (ManagedJobScheduleState.LAUNCHING.value,)).fetchone()[0]
1191
1279
 
1192
1280
 
1281
+ @_init_db
1193
1282
  def get_num_alive_jobs() -> int:
1283
+ assert _db_initialized
1194
1284
  with db_utils.safe_cursor(_DB_PATH) as cursor:
1195
1285
  return cursor.execute(
1196
1286
  'SELECT COUNT(*) '
@@ -1202,32 +1292,33 @@ def get_num_alive_jobs() -> int:
1202
1292
  ManagedJobScheduleState.ALIVE_BACKOFF.value)).fetchone()[0]
1203
1293
 
1204
1294
 
1295
+ @_init_db
1205
1296
  def get_waiting_job() -> Optional[Dict[str, Any]]:
1206
1297
  """Get the next job that should transition to LAUNCHING.
1207
1298
 
1208
- Selects the highest-priority (lowest numerical value) WAITING or
1209
- ALIVE_WAITING job, provided its priority value is less than or equal to any
1210
- currently LAUNCHING or ALIVE_BACKOFF job.
1299
+ Selects the highest-priority WAITING or ALIVE_WAITING job, provided its
1300
+ priority is greater than or equal to any currently LAUNCHING or
1301
+ ALIVE_BACKOFF job.
1211
1302
 
1212
1303
  Backwards compatibility note: jobs submitted before #4485 will have no
1213
1304
  schedule_state and will be ignored by this SQL query.
1214
1305
  """
1306
+ assert _db_initialized
1215
1307
  with db_utils.safe_cursor(_DB_PATH) as cursor:
1216
- # Get the highest-priority (lowest numerical value) WAITING or
1217
- # ALIVE_WAITING job whose priority value is less than or equal to
1218
- # the highest priority (numerically smallest) LAUNCHING or
1308
+ # Get the highest-priority WAITING or ALIVE_WAITING job whose priority
1309
+ # is greater than or equal to the highest priority LAUNCHING or
1219
1310
  # ALIVE_BACKOFF job's priority.
1220
1311
  waiting_job_row = cursor.execute(
1221
1312
  'SELECT spot_job_id, schedule_state, dag_yaml_path, env_file_path '
1222
1313
  'FROM job_info '
1223
1314
  'WHERE schedule_state IN (?, ?) '
1224
- 'AND priority <= COALESCE('
1225
- ' (SELECT MIN(priority) '
1315
+ 'AND priority >= COALESCE('
1316
+ ' (SELECT MAX(priority) '
1226
1317
  ' FROM job_info '
1227
1318
  ' WHERE schedule_state IN (?, ?)), '
1228
- ' 1000'
1319
+ ' 0'
1229
1320
  ')'
1230
- 'ORDER BY priority ASC, spot_job_id ASC LIMIT 1',
1321
+ 'ORDER BY priority DESC, spot_job_id ASC LIMIT 1',
1231
1322
  (ManagedJobScheduleState.WAITING.value,
1232
1323
  ManagedJobScheduleState.ALIVE_WAITING.value,
1233
1324
  ManagedJobScheduleState.LAUNCHING.value,
@@ -1244,8 +1335,10 @@ def get_waiting_job() -> Optional[Dict[str, Any]]:
1244
1335
  }
1245
1336
 
1246
1337
 
1338
+ @_init_db
1247
1339
  def get_workspace(job_id: int) -> str:
1248
1340
  """Get the workspace of a job."""
1341
+ assert _db_initialized
1249
1342
  with db_utils.safe_cursor(_DB_PATH) as cursor:
1250
1343
  workspace = cursor.execute(
1251
1344
  'SELECT workspace FROM job_info WHERE spot_job_id = (?)',
sky/jobs/utils.py CHANGED
@@ -245,6 +245,7 @@ def update_managed_jobs_statuses(job_id: Optional[int] = None):
245
245
  return
246
246
 
247
247
  for job_id in job_ids:
248
+ assert job_id is not None
248
249
  tasks = managed_job_state.get_managed_jobs(job_id)
249
250
  # Note: controller_pid and schedule_state are in the job_info table
250
251
  # which is joined to the spot table, so all tasks with the same job_id
@@ -933,7 +934,7 @@ def dump_managed_job_queue() -> str:
933
934
  # Figure out what the highest priority blocking job is. We need to know in
934
935
  # order to determine if other jobs are blocked by a higher priority job, or
935
936
  # just by the limited controller resources.
936
- lowest_blocking_priority_value = 1000
937
+ highest_blocking_priority = 0
937
938
  for job in jobs:
938
939
  if job['schedule_state'] not in (
939
940
  # LAUNCHING and ALIVE_BACKOFF jobs will block other jobs with
@@ -949,8 +950,8 @@ def dump_managed_job_queue() -> str:
949
950
  continue
950
951
 
951
952
  priority = job.get('priority')
952
- if priority is not None and priority < lowest_blocking_priority_value:
953
- lowest_blocking_priority_value = priority
953
+ if priority is not None and priority > highest_blocking_priority:
954
+ highest_blocking_priority = priority
954
955
 
955
956
  for job in jobs:
956
957
  end_at = job['end_at']
@@ -998,8 +999,7 @@ def dump_managed_job_queue() -> str:
998
999
  state_details = 'In backoff, waiting for resources'
999
1000
  elif job['schedule_state'] in ('WAITING', 'ALIVE_WAITING'):
1000
1001
  priority = job.get('priority')
1001
- if (priority is not None and
1002
- priority > lowest_blocking_priority_value):
1002
+ if (priority is not None and priority < highest_blocking_priority):
1003
1003
  # Job is lower priority than some other blocking job.
1004
1004
  state_details = 'Waiting for higher priority jobs to launch'
1005
1005
  else: