skypilot-nightly 1.0.0.dev20250630__py3-none-any.whl → 1.0.0.dev20250701__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34) hide show
  1. sky/__init__.py +2 -2
  2. sky/backends/cloud_vm_ray_backend.py +3 -3
  3. sky/dashboard/out/404.html +1 -1
  4. sky/dashboard/out/clusters/[cluster]/[job].html +1 -1
  5. sky/dashboard/out/clusters/[cluster].html +1 -1
  6. sky/dashboard/out/clusters.html +1 -1
  7. sky/dashboard/out/config.html +1 -1
  8. sky/dashboard/out/index.html +1 -1
  9. sky/dashboard/out/infra/[context].html +1 -1
  10. sky/dashboard/out/infra.html +1 -1
  11. sky/dashboard/out/jobs/[job].html +1 -1
  12. sky/dashboard/out/jobs.html +1 -1
  13. sky/dashboard/out/users.html +1 -1
  14. sky/dashboard/out/volumes.html +1 -1
  15. sky/dashboard/out/workspace/new.html +1 -1
  16. sky/dashboard/out/workspaces/[name].html +1 -1
  17. sky/dashboard/out/workspaces.html +1 -1
  18. sky/jobs/controller.py +4 -0
  19. sky/jobs/server/core.py +5 -9
  20. sky/jobs/state.py +820 -670
  21. sky/jobs/utils.py +7 -15
  22. sky/server/common.py +1 -0
  23. sky/server/server.py +37 -15
  24. sky/setup_files/dependencies.py +2 -0
  25. sky/task.py +1 -1
  26. sky/utils/dag_utils.py +4 -2
  27. {skypilot_nightly-1.0.0.dev20250630.dist-info → skypilot_nightly-1.0.0.dev20250701.dist-info}/METADATA +4 -1
  28. {skypilot_nightly-1.0.0.dev20250630.dist-info → skypilot_nightly-1.0.0.dev20250701.dist-info}/RECORD +34 -34
  29. /sky/dashboard/out/_next/static/{NdypbqMxaYucRGfopkKXa → Md3rlE87jmL5uv7gSo8mR}/_buildManifest.js +0 -0
  30. /sky/dashboard/out/_next/static/{NdypbqMxaYucRGfopkKXa → Md3rlE87jmL5uv7gSo8mR}/_ssgManifest.js +0 -0
  31. {skypilot_nightly-1.0.0.dev20250630.dist-info → skypilot_nightly-1.0.0.dev20250701.dist-info}/WHEEL +0 -0
  32. {skypilot_nightly-1.0.0.dev20250630.dist-info → skypilot_nightly-1.0.0.dev20250701.dist-info}/entry_points.txt +0 -0
  33. {skypilot_nightly-1.0.0.dev20250630.dist-info → skypilot_nightly-1.0.0.dev20250701.dist-info}/licenses/LICENSE +0 -0
  34. {skypilot_nightly-1.0.0.dev20250630.dist-info → skypilot_nightly-1.0.0.dev20250701.dist-info}/top_level.txt +0 -0
sky/jobs/state.py CHANGED
@@ -4,28 +4,41 @@
4
4
  import enum
5
5
  import functools
6
6
  import json
7
+ import os
7
8
  import pathlib
8
- import sqlite3
9
9
  import threading
10
10
  import time
11
11
  import typing
12
12
  from typing import Any, Callable, Dict, List, Optional, Tuple, Union
13
13
 
14
14
  import colorama
15
+ import sqlalchemy
16
+ from sqlalchemy import exc as sqlalchemy_exc
17
+ from sqlalchemy import orm
18
+ from sqlalchemy.dialects import postgresql
19
+ from sqlalchemy.dialects import sqlite
20
+ from sqlalchemy.ext import declarative
15
21
 
16
22
  from sky import exceptions
17
23
  from sky import sky_logging
24
+ from sky import skypilot_config
18
25
  from sky.skylet import constants
19
26
  from sky.utils import common_utils
20
27
  from sky.utils import db_utils
21
28
 
22
29
  if typing.TYPE_CHECKING:
30
+ from sqlalchemy.engine import row
31
+
23
32
  import sky
24
33
 
25
34
  CallbackType = Callable[[str], None]
26
35
 
27
36
  logger = sky_logging.init_logger(__name__)
28
37
 
38
+ _SQLALCHEMY_ENGINE: Optional[sqlalchemy.engine.Engine] = None
39
+ _DB_INIT_LOCK = threading.Lock()
40
+
41
+ Base = declarative.declarative_base()
29
42
 
30
43
  # === Database schema ===
31
44
  # `spot` table contains all the finest-grained tasks, including all the
@@ -38,144 +51,183 @@ logger = sky_logging.init_logger(__name__)
38
51
  # identifier/primary key for all the tasks. We will use `spot_job_id`
39
52
  # to identify the job.
40
53
  # TODO(zhwu): schema migration may be needed.
41
- def create_table(cursor, conn):
54
+
55
+ spot_table = sqlalchemy.Table(
56
+ 'spot',
57
+ Base.metadata,
58
+ sqlalchemy.Column('job_id',
59
+ sqlalchemy.Integer,
60
+ primary_key=True,
61
+ autoincrement=True),
62
+ sqlalchemy.Column('job_name', sqlalchemy.Text),
63
+ sqlalchemy.Column('resources', sqlalchemy.Text),
64
+ sqlalchemy.Column('submitted_at', sqlalchemy.Float),
65
+ sqlalchemy.Column('status', sqlalchemy.Text),
66
+ sqlalchemy.Column('run_timestamp', sqlalchemy.Text),
67
+ sqlalchemy.Column('start_at', sqlalchemy.Float, server_default=None),
68
+ sqlalchemy.Column('end_at', sqlalchemy.Float, server_default=None),
69
+ sqlalchemy.Column('last_recovered_at',
70
+ sqlalchemy.Float,
71
+ server_default='-1'),
72
+ sqlalchemy.Column('recovery_count', sqlalchemy.Integer, server_default='0'),
73
+ sqlalchemy.Column('job_duration', sqlalchemy.Float, server_default='0'),
74
+ sqlalchemy.Column('failure_reason', sqlalchemy.Text),
75
+ sqlalchemy.Column('spot_job_id', sqlalchemy.Integer),
76
+ sqlalchemy.Column('task_id', sqlalchemy.Integer, server_default='0'),
77
+ sqlalchemy.Column('task_name', sqlalchemy.Text),
78
+ sqlalchemy.Column('specs', sqlalchemy.Text),
79
+ sqlalchemy.Column('local_log_file', sqlalchemy.Text, server_default=None),
80
+ )
81
+
82
+ job_info_table = sqlalchemy.Table(
83
+ 'job_info',
84
+ Base.metadata,
85
+ sqlalchemy.Column('spot_job_id',
86
+ sqlalchemy.Integer,
87
+ primary_key=True,
88
+ autoincrement=True),
89
+ sqlalchemy.Column('name', sqlalchemy.Text),
90
+ sqlalchemy.Column('schedule_state', sqlalchemy.Text),
91
+ sqlalchemy.Column('controller_pid', sqlalchemy.Integer,
92
+ server_default=None),
93
+ sqlalchemy.Column('dag_yaml_path', sqlalchemy.Text),
94
+ sqlalchemy.Column('env_file_path', sqlalchemy.Text),
95
+ sqlalchemy.Column('user_hash', sqlalchemy.Text),
96
+ sqlalchemy.Column('workspace', sqlalchemy.Text, server_default=None),
97
+ sqlalchemy.Column('priority',
98
+ sqlalchemy.Integer,
99
+ server_default=str(constants.DEFAULT_PRIORITY)),
100
+ sqlalchemy.Column('entrypoint', sqlalchemy.Text, server_default=None),
101
+ sqlalchemy.Column('original_user_yaml_path',
102
+ sqlalchemy.Text,
103
+ server_default=None),
104
+ )
105
+
106
+ ha_recovery_script_table = sqlalchemy.Table(
107
+ 'ha_recovery_script',
108
+ Base.metadata,
109
+ sqlalchemy.Column('job_id', sqlalchemy.Integer, primary_key=True),
110
+ sqlalchemy.Column('script', sqlalchemy.Text),
111
+ )
112
+
113
+
114
+ def create_table():
42
115
  # Enable WAL mode to avoid locking issues.
43
116
  # See: issue #3863, #1441 and PR #1509
44
117
  # https://github.com/microsoft/WSL/issues/2395
45
118
  # TODO(romilb): We do not enable WAL for WSL because of known issue in WSL.
46
119
  # This may cause the database locked problem from WSL issue #1441.
47
- if not common_utils.is_wsl():
120
+ if (_SQLALCHEMY_ENGINE.dialect.name
121
+ == db_utils.SQLAlchemyDialect.SQLITE.value and
122
+ not common_utils.is_wsl()):
48
123
  try:
49
- cursor.execute('PRAGMA journal_mode=WAL')
50
- except sqlite3.OperationalError as e:
124
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
125
+ session.execute(sqlalchemy.text('PRAGMA journal_mode=WAL'))
126
+ session.commit()
127
+ except sqlalchemy_exc.OperationalError as e:
51
128
  if 'database is locked' not in str(e):
52
129
  raise
53
130
  # If the database is locked, it is OK to continue, as the WAL mode
54
131
  # is not critical and is likely to be enabled by other processes.
55
132
 
56
- cursor.execute("""\
57
- CREATE TABLE IF NOT EXISTS spot (
58
- job_id INTEGER PRIMARY KEY AUTOINCREMENT,
59
- job_name TEXT,
60
- resources TEXT,
61
- submitted_at FLOAT,
62
- status TEXT,
63
- run_timestamp TEXT CANDIDATE KEY,
64
- start_at FLOAT DEFAULT NULL,
65
- end_at FLOAT DEFAULT NULL,
66
- last_recovered_at FLOAT DEFAULT -1,
67
- recovery_count INTEGER DEFAULT 0,
68
- job_duration FLOAT DEFAULT 0,
69
- failure_reason TEXT,
70
- spot_job_id INTEGER,
71
- task_id INTEGER DEFAULT 0,
72
- task_name TEXT,
73
- specs TEXT,
74
- local_log_file TEXT DEFAULT NULL)""")
75
- conn.commit()
76
-
77
- db_utils.add_column_to_table(cursor, conn, 'spot', 'failure_reason', 'TEXT')
78
- # Create a new column `spot_job_id`, which is the same for tasks of the
79
- # same managed job.
80
- # The original `job_id` no longer has an actual meaning, but only a legacy
81
- # identifier for all tasks in database.
82
- db_utils.add_column_to_table(cursor,
83
- conn,
84
- 'spot',
85
- 'spot_job_id',
86
- 'INTEGER',
87
- copy_from='job_id')
88
- db_utils.add_column_to_table(cursor,
89
- conn,
90
- 'spot',
91
- 'task_id',
92
- 'INTEGER DEFAULT 0',
93
- value_to_replace_existing_entries=0)
94
- db_utils.add_column_to_table(cursor,
95
- conn,
96
- 'spot',
97
- 'task_name',
98
- 'TEXT',
99
- copy_from='job_name')
100
-
101
- # Specs is some useful information about the task, e.g., the
102
- # max_restarts_on_errors value. It is stored in JSON format.
103
- db_utils.add_column_to_table(cursor,
104
- conn,
105
- 'spot',
106
- 'specs',
107
- 'TEXT',
108
- value_to_replace_existing_entries=json.dumps({
109
- 'max_restarts_on_errors': 0,
110
- }))
111
- db_utils.add_column_to_table(cursor, conn, 'spot', 'local_log_file',
112
- 'TEXT DEFAULT NULL')
113
-
114
- # `job_info` contains the mapping from job_id to the job_name, as well as
115
- # information used by the scheduler.
116
- cursor.execute(f"""\
117
- CREATE TABLE IF NOT EXISTS job_info (
118
- spot_job_id INTEGER PRIMARY KEY AUTOINCREMENT,
119
- name TEXT,
120
- schedule_state TEXT,
121
- controller_pid INTEGER DEFAULT NULL,
122
- dag_yaml_path TEXT,
123
- env_file_path TEXT,
124
- user_hash TEXT,
125
- workspace TEXT DEFAULT NULL,
126
- priority INTEGER DEFAULT {constants.DEFAULT_PRIORITY},
127
- entrypoint TEXT DEFAULT NULL,
128
- original_user_yaml_path TEXT DEFAULT NULL)""")
129
-
130
- db_utils.add_column_to_table(cursor, conn, 'job_info', 'schedule_state',
131
- 'TEXT')
132
-
133
- db_utils.add_column_to_table(cursor, conn, 'job_info', 'controller_pid',
134
- 'INTEGER DEFAULT NULL')
135
-
136
- db_utils.add_column_to_table(cursor, conn, 'job_info', 'dag_yaml_path',
137
- 'TEXT')
138
-
139
- db_utils.add_column_to_table(cursor, conn, 'job_info', 'env_file_path',
140
- 'TEXT')
141
-
142
- db_utils.add_column_to_table(cursor, conn, 'job_info', 'user_hash', 'TEXT')
143
-
144
- db_utils.add_column_to_table(cursor,
145
- conn,
146
- 'job_info',
147
- 'workspace',
148
- 'TEXT DEFAULT NULL',
149
- value_to_replace_existing_entries='default')
150
-
151
- db_utils.add_column_to_table(
152
- cursor,
153
- conn,
154
- 'job_info',
155
- 'priority',
156
- 'INTEGER',
157
- value_to_replace_existing_entries=constants.DEFAULT_PRIORITY)
158
-
159
- db_utils.add_column_to_table(cursor, conn, 'job_info', 'entrypoint', 'TEXT')
160
- db_utils.add_column_to_table(cursor, conn, 'job_info',
161
- 'original_user_yaml_path', 'TEXT')
162
- conn.commit()
163
-
164
-
165
- # Module-level connection/cursor; thread-safe as the db is initialized once
166
- # across all threads.
167
- def _get_db_path() -> str:
168
- """Workaround to collapse multi-step Path ops for type checker.
169
- Ensures _DB_PATH is str, avoiding Union[Path, str] inference.
170
- """
171
- path = pathlib.Path('~/.sky/spot_jobs.db')
172
- path = path.expanduser().absolute()
173
- path.parents[0].mkdir(parents=True, exist_ok=True)
174
- return str(path)
175
-
176
-
177
- _DB_PATH = None
178
- _db_init_lock = threading.Lock()
133
+ # Create tables if they don't exist
134
+ Base.metadata.create_all(bind=_SQLALCHEMY_ENGINE)
135
+
136
+ # Backward compatibility: add columns that do not exist in older databases
137
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
138
+ db_utils.add_column_to_table_sqlalchemy(session, 'spot',
139
+ 'failure_reason',
140
+ sqlalchemy.Text())
141
+ db_utils.add_column_to_table_sqlalchemy(session,
142
+ 'spot',
143
+ 'spot_job_id',
144
+ sqlalchemy.Integer(),
145
+ copy_from='job_id')
146
+ db_utils.add_column_to_table_sqlalchemy(
147
+ session,
148
+ 'spot',
149
+ 'task_id',
150
+ sqlalchemy.Integer(),
151
+ default_statement='DEFAULT 0',
152
+ value_to_replace_existing_entries=0)
153
+ db_utils.add_column_to_table_sqlalchemy(session,
154
+ 'spot',
155
+ 'task_name',
156
+ sqlalchemy.Text(),
157
+ copy_from='job_name')
158
+ db_utils.add_column_to_table_sqlalchemy(
159
+ session,
160
+ 'spot',
161
+ 'specs',
162
+ sqlalchemy.Text(),
163
+ value_to_replace_existing_entries=json.dumps({
164
+ 'max_restarts_on_errors': 0,
165
+ }))
166
+ db_utils.add_column_to_table_sqlalchemy(
167
+ session,
168
+ 'spot',
169
+ 'local_log_file',
170
+ sqlalchemy.Text(),
171
+ default_statement='DEFAULT NULL')
172
+
173
+ db_utils.add_column_to_table_sqlalchemy(session, 'job_info',
174
+ 'schedule_state',
175
+ sqlalchemy.Text())
176
+ db_utils.add_column_to_table_sqlalchemy(
177
+ session,
178
+ 'job_info',
179
+ 'controller_pid',
180
+ sqlalchemy.Integer(),
181
+ default_statement='DEFAULT NULL')
182
+ db_utils.add_column_to_table_sqlalchemy(session, 'job_info',
183
+ 'dag_yaml_path',
184
+ sqlalchemy.Text())
185
+ db_utils.add_column_to_table_sqlalchemy(session, 'job_info',
186
+ 'env_file_path',
187
+ sqlalchemy.Text())
188
+ db_utils.add_column_to_table_sqlalchemy(session, 'job_info',
189
+ 'user_hash', sqlalchemy.Text())
190
+ db_utils.add_column_to_table_sqlalchemy(
191
+ session,
192
+ 'job_info',
193
+ 'workspace',
194
+ sqlalchemy.Text(),
195
+ default_statement='DEFAULT NULL',
196
+ value_to_replace_existing_entries='default')
197
+ db_utils.add_column_to_table_sqlalchemy(
198
+ session,
199
+ 'job_info',
200
+ 'priority',
201
+ sqlalchemy.Integer(),
202
+ value_to_replace_existing_entries=constants.DEFAULT_PRIORITY)
203
+ db_utils.add_column_to_table_sqlalchemy(session, 'job_info',
204
+ 'entrypoint', sqlalchemy.Text())
205
+ db_utils.add_column_to_table_sqlalchemy(session, 'job_info',
206
+ 'original_user_yaml_path',
207
+ sqlalchemy.Text())
208
+ session.commit()
209
+
210
+
211
+ def initialize_and_get_db() -> sqlalchemy.engine.Engine:
212
+ global _SQLALCHEMY_ENGINE
213
+ if _SQLALCHEMY_ENGINE is not None:
214
+ return _SQLALCHEMY_ENGINE
215
+ with _DB_INIT_LOCK:
216
+ if _SQLALCHEMY_ENGINE is None:
217
+ conn_string = None
218
+ if os.environ.get(constants.ENV_VAR_IS_SKYPILOT_SERVER) is not None:
219
+ conn_string = skypilot_config.get_nested(('db',), None)
220
+ if conn_string:
221
+ logger.debug(f'using db URI from {conn_string}')
222
+ _SQLALCHEMY_ENGINE = sqlalchemy.create_engine(conn_string)
223
+ else:
224
+ db_path = os.path.expanduser('~/.sky/spot_jobs.db')
225
+ pathlib.Path(db_path).parents[0].mkdir(parents=True,
226
+ exist_ok=True)
227
+ _SQLALCHEMY_ENGINE = sqlalchemy.create_engine('sqlite:///' +
228
+ db_path)
229
+ create_table()
230
+ return _SQLALCHEMY_ENGINE
179
231
 
180
232
 
181
233
  def _init_db(func):
@@ -183,13 +235,7 @@ def _init_db(func):
183
235
 
184
236
  @functools.wraps(func)
185
237
  def wrapper(*args, **kwargs):
186
- global _DB_PATH
187
- if _DB_PATH is not None:
188
- return func(*args, **kwargs)
189
- with _db_init_lock:
190
- if _DB_PATH is None:
191
- _DB_PATH = _get_db_path()
192
- db_utils.SQLiteConn(_DB_PATH, create_table)
238
+ initialize_and_get_db()
193
239
  return func(*args, **kwargs)
194
240
 
195
241
  return wrapper
@@ -207,37 +253,39 @@ def _init_db(func):
207
253
  # e.g., via sky jobs queue. These may not correspond to actual
208
254
  # column names in the DB and it corresponds to the combined view
209
255
  # by joining the spot and job_info tables.
210
- columns = [
211
- '_job_id',
212
- '_task_name',
213
- 'resources',
214
- 'submitted_at',
215
- 'status',
216
- 'run_timestamp',
217
- 'start_at',
218
- 'end_at',
219
- 'last_recovered_at',
220
- 'recovery_count',
221
- 'job_duration',
222
- 'failure_reason',
223
- 'job_id',
224
- 'task_id',
225
- 'task_name',
226
- 'specs',
227
- 'local_log_file',
228
- # columns from the job_info table
229
- '_job_info_job_id', # This should be the same as job_id
230
- 'job_name',
231
- 'schedule_state',
232
- 'controller_pid',
233
- 'dag_yaml_path',
234
- 'env_file_path',
235
- 'user_hash',
236
- 'workspace',
237
- 'priority',
238
- 'entrypoint',
239
- 'original_user_yaml_path',
240
- ]
256
+ def _get_jobs_dict(r: 'row.RowMapping') -> Dict[str, Any]:
257
+ return {
258
+ '_job_id': r['job_id'], # from spot table
259
+ '_task_name': r['job_name'], # deprecated, from spot table
260
+ 'resources': r['resources'],
261
+ 'submitted_at': r['submitted_at'],
262
+ 'status': r['status'],
263
+ 'run_timestamp': r['run_timestamp'],
264
+ 'start_at': r['start_at'],
265
+ 'end_at': r['end_at'],
266
+ 'last_recovered_at': r['last_recovered_at'],
267
+ 'recovery_count': r['recovery_count'],
268
+ 'job_duration': r['job_duration'],
269
+ 'failure_reason': r['failure_reason'],
270
+ 'job_id': r[spot_table.c.spot_job_id], # ambiguous, use table.column
271
+ 'task_id': r['task_id'],
272
+ 'task_name': r['task_name'],
273
+ 'specs': r['specs'],
274
+ 'local_log_file': r['local_log_file'],
275
+ # columns from job_info table (some may be None for legacy jobs)
276
+ '_job_info_job_id': r[job_info_table.c.spot_job_id
277
+ ], # ambiguous, use table.column
278
+ 'job_name': r['name'], # from job_info table
279
+ 'schedule_state': r['schedule_state'],
280
+ 'controller_pid': r['controller_pid'],
281
+ 'dag_yaml_path': r['dag_yaml_path'],
282
+ 'env_file_path': r['env_file_path'],
283
+ 'user_hash': r['user_hash'],
284
+ 'workspace': r['workspace'],
285
+ 'priority': r['priority'],
286
+ 'entrypoint': r['entrypoint'],
287
+ 'original_user_yaml_path': r['original_user_yaml_path'],
288
+ }
241
289
 
242
290
 
243
291
  class ManagedJobStatus(enum.Enum):
@@ -452,44 +500,76 @@ class ManagedJobScheduleState(enum.Enum):
452
500
  # === Status transition functions ===
453
501
  @_init_db
454
502
  def set_job_info(job_id: int, name: str, workspace: str, entrypoint: str):
455
- assert _DB_PATH is not None
456
- with db_utils.safe_cursor(_DB_PATH) as cursor:
457
- cursor.execute(
458
- """\
459
- INSERT INTO job_info
460
- (spot_job_id, name, schedule_state, workspace, entrypoint)
461
- VALUES (?, ?, ?, ?, ?)""",
462
- (job_id, name, ManagedJobScheduleState.INACTIVE.value, workspace,
463
- entrypoint))
503
+ assert _SQLALCHEMY_ENGINE is not None
504
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
505
+ if (_SQLALCHEMY_ENGINE.dialect.name ==
506
+ db_utils.SQLAlchemyDialect.SQLITE.value):
507
+ insert_func = sqlite.insert
508
+ elif (_SQLALCHEMY_ENGINE.dialect.name ==
509
+ db_utils.SQLAlchemyDialect.POSTGRESQL.value):
510
+ insert_func = postgresql.insert
511
+ else:
512
+ raise ValueError('Unsupported database dialect')
513
+ insert_stmt = insert_func(job_info_table).values(
514
+ spot_job_id=job_id,
515
+ name=name,
516
+ schedule_state=ManagedJobScheduleState.INACTIVE.value,
517
+ workspace=workspace,
518
+ entrypoint=entrypoint)
519
+ session.execute(insert_stmt)
520
+ session.commit()
464
521
 
465
522
 
466
523
  @_init_db
467
524
  def set_job_info_without_job_id(name: str, workspace: str,
468
525
  entrypoint: str) -> int:
469
- assert _DB_PATH is not None
470
- with db_utils.safe_cursor(_DB_PATH) as cursor:
471
- cursor.execute(
472
- """\
473
- INSERT INTO job_info
474
- (name, schedule_state, workspace, entrypoint)
475
- VALUES (?, ?, ?, ?)""",
476
- (name, ManagedJobScheduleState.INACTIVE.value, workspace,
477
- entrypoint))
478
- return cursor.lastrowid
526
+ assert _SQLALCHEMY_ENGINE is not None
527
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
528
+ if (_SQLALCHEMY_ENGINE.dialect.name ==
529
+ db_utils.SQLAlchemyDialect.SQLITE.value):
530
+ insert_func = sqlite.insert
531
+ elif (_SQLALCHEMY_ENGINE.dialect.name ==
532
+ db_utils.SQLAlchemyDialect.POSTGRESQL.value):
533
+ insert_func = postgresql.insert
534
+ else:
535
+ raise ValueError('Unsupported database dialect')
536
+
537
+ insert_stmt = insert_func(job_info_table).values(
538
+ name=name,
539
+ schedule_state=ManagedJobScheduleState.INACTIVE.value,
540
+ workspace=workspace,
541
+ entrypoint=entrypoint,
542
+ )
543
+
544
+ if (_SQLALCHEMY_ENGINE.dialect.name ==
545
+ db_utils.SQLAlchemyDialect.SQLITE.value):
546
+ result = session.execute(insert_stmt)
547
+ session.commit()
548
+ return result.lastrowid
549
+ elif (_SQLALCHEMY_ENGINE.dialect.name ==
550
+ db_utils.SQLAlchemyDialect.POSTGRESQL.value):
551
+ result = session.execute(
552
+ insert_stmt.returning(job_info_table.c.spot_job_id))
553
+ session.commit()
554
+ return result.scalar()
555
+ else:
556
+ raise ValueError('Unsupported database dialect')
479
557
 
480
558
 
481
559
  @_init_db
482
560
  def set_pending(job_id: int, task_id: int, task_name: str, resources_str: str):
483
561
  """Set the task to pending state."""
484
- assert _DB_PATH is not None
485
- with db_utils.safe_cursor(_DB_PATH) as cursor:
486
- cursor.execute(
487
- """\
488
- INSERT INTO spot
489
- (spot_job_id, task_id, task_name, resources, status)
490
- VALUES (?, ?, ?, ?, ?)""",
491
- (job_id, task_id, task_name, resources_str,
492
- ManagedJobStatus.PENDING.value))
562
+ assert _SQLALCHEMY_ENGINE is not None
563
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
564
+ session.execute(
565
+ sqlalchemy.insert(spot_table).values(
566
+ spot_job_id=job_id,
567
+ task_id=task_id,
568
+ task_name=task_name,
569
+ resources=resources_str,
570
+ status=ManagedJobStatus.PENDING.value,
571
+ ))
572
+ session.commit()
493
573
 
494
574
 
495
575
  @_init_db
@@ -509,33 +589,32 @@ def set_starting(job_id: int, task_id: int, run_timestamp: str,
509
589
  specs: The specs of the managed task.
510
590
  callback_func: The callback function.
511
591
  """
512
- assert _DB_PATH is not None
592
+ assert _SQLALCHEMY_ENGINE is not None
513
593
  # Use the timestamp in the `run_timestamp` ('sky-2022-10...'), to make
514
594
  # the log directory and submission time align with each other, so as to
515
595
  # make it easier to find them based on one of the values.
516
596
  # Also, using the earlier timestamp should be closer to the term
517
597
  # `submit_at`, which represents the time the managed task is submitted.
518
598
  logger.info('Launching the spot cluster...')
519
- with db_utils.safe_cursor(_DB_PATH) as cursor:
520
- cursor.execute(
521
- """\
522
- UPDATE spot SET
523
- resources=(?),
524
- submitted_at=(?),
525
- status=(?),
526
- run_timestamp=(?),
527
- specs=(?)
528
- WHERE spot_job_id=(?) AND
529
- task_id=(?) AND
530
- status=(?) AND
531
- end_at IS null""",
532
- (resources_str, submit_time, ManagedJobStatus.STARTING.value,
533
- run_timestamp, json.dumps(specs), job_id, task_id,
534
- ManagedJobStatus.PENDING.value))
535
- if cursor.rowcount != 1:
599
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
600
+ count = session.query(spot_table).filter(
601
+ sqlalchemy.and_(
602
+ spot_table.c.spot_job_id == job_id,
603
+ spot_table.c.task_id == task_id,
604
+ spot_table.c.status == ManagedJobStatus.PENDING.value,
605
+ spot_table.c.end_at.is_(None),
606
+ )).update({
607
+ spot_table.c.resources: resources_str,
608
+ spot_table.c.submitted_at: submit_time,
609
+ spot_table.c.status: ManagedJobStatus.STARTING.value,
610
+ spot_table.c.run_timestamp: run_timestamp,
611
+ spot_table.c.specs: json.dumps(specs),
612
+ })
613
+ session.commit()
614
+ if count != 1:
536
615
  raise exceptions.ManagedJobStatusError(
537
616
  'Failed to set the task to starting. '
538
- f'({cursor.rowcount} rows updated)')
617
+ f'({count} rows updated)')
539
618
  # SUBMITTED is no longer used, but we keep it for backward compatibility.
540
619
  # TODO(cooperc): remove this in v0.12.0
541
620
  callback_func('SUBMITTED')
@@ -549,22 +628,24 @@ def set_backoff_pending(job_id: int, task_id: int):
549
628
  This should only be used to transition from STARTING or RECOVERING back to
550
629
  PENDING.
551
630
  """
552
- assert _DB_PATH is not None
553
- with db_utils.safe_cursor(_DB_PATH) as cursor:
554
- cursor.execute(
555
- """\
556
- UPDATE spot SET status=(?)
557
- WHERE spot_job_id=(?) AND
558
- task_id=(?) AND
559
- status IN (?, ?) AND
560
- end_at IS null""", (ManagedJobStatus.PENDING.value, job_id, task_id,
561
- ManagedJobStatus.STARTING.value,
562
- ManagedJobStatus.RECOVERING.value))
631
+ assert _SQLALCHEMY_ENGINE is not None
632
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
633
+ count = session.query(spot_table).filter(
634
+ sqlalchemy.and_(
635
+ spot_table.c.spot_job_id == job_id,
636
+ spot_table.c.task_id == task_id,
637
+ spot_table.c.status.in_([
638
+ ManagedJobStatus.STARTING.value,
639
+ ManagedJobStatus.RECOVERING.value
640
+ ]),
641
+ spot_table.c.end_at.is_(None),
642
+ )).update({spot_table.c.status: ManagedJobStatus.PENDING.value})
643
+ session.commit()
563
644
  logger.debug('back to PENDING')
564
- if cursor.rowcount != 1:
645
+ if count != 1:
565
646
  raise exceptions.ManagedJobStatusError(
566
647
  'Failed to set the task back to pending. '
567
- f'({cursor.rowcount} rows updated)')
648
+ f'({count} rows updated)')
568
649
  # Do not call callback_func here, as we don't use the callback for PENDING.
569
650
 
570
651
 
@@ -577,24 +658,24 @@ def set_restarting(job_id: int, task_id: int, recovering: bool):
577
658
  after using set_backoff_pending to transition back to PENDING during
578
659
  launch retry backoff.
579
660
  """
580
- assert _DB_PATH is not None
661
+ assert _SQLALCHEMY_ENGINE is not None
581
662
  target_status = ManagedJobStatus.STARTING.value
582
663
  if recovering:
583
664
  target_status = ManagedJobStatus.RECOVERING.value
584
- with db_utils.safe_cursor(_DB_PATH) as cursor:
585
- cursor.execute(
586
- """\
587
- UPDATE spot SET status=(?)
588
- WHERE spot_job_id=(?) AND
589
- task_id=(?) AND
590
- status=(?) AND
591
- end_at IS null""",
592
- (target_status, job_id, task_id, ManagedJobStatus.PENDING.value))
665
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
666
+ count = session.query(spot_table).filter(
667
+ sqlalchemy.and_(
668
+ spot_table.c.spot_job_id == job_id,
669
+ spot_table.c.task_id == task_id,
670
+ spot_table.c.status == ManagedJobStatus.PENDING.value,
671
+ spot_table.c.end_at.is_(None),
672
+ )).update({spot_table.c.status: target_status})
673
+ session.commit()
593
674
  logger.debug(f'back to {target_status}')
594
- if cursor.rowcount != 1:
675
+ if count != 1:
595
676
  raise exceptions.ManagedJobStatusError(
596
677
  f'Failed to set the task back to {target_status}. '
597
- f'({cursor.rowcount} rows updated)')
678
+ f'({count} rows updated)')
598
679
  # Do not call callback_func here, as it should only be invoked for the
599
680
  # initial (pre-`set_backoff_pending`) transition to STARTING or RECOVERING.
600
681
 
@@ -603,32 +684,30 @@ def set_restarting(job_id: int, task_id: int, recovering: bool):
603
684
  def set_started(job_id: int, task_id: int, start_time: float,
604
685
  callback_func: CallbackType):
605
686
  """Set the task to started state."""
606
- assert _DB_PATH is not None
687
+ assert _SQLALCHEMY_ENGINE is not None
607
688
  logger.info('Job started.')
608
- with db_utils.safe_cursor(_DB_PATH) as cursor:
609
- cursor.execute(
610
- """\
611
- UPDATE spot SET status=(?), start_at=(?), last_recovered_at=(?)
612
- WHERE spot_job_id=(?) AND
613
- task_id=(?) AND
614
- status IN (?, ?) AND
615
- end_at IS null""",
616
- (
617
- ManagedJobStatus.RUNNING.value,
618
- start_time,
619
- start_time,
620
- job_id,
621
- task_id,
622
- ManagedJobStatus.STARTING.value,
623
- # If the task is empty, we will jump straight from PENDING to
624
- # RUNNING
625
- ManagedJobStatus.PENDING.value,
626
- ),
627
- )
628
- if cursor.rowcount != 1:
689
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
690
+ count = session.query(spot_table).filter(
691
+ sqlalchemy.and_(
692
+ spot_table.c.spot_job_id == job_id,
693
+ spot_table.c.task_id == task_id,
694
+ spot_table.c.status.in_([
695
+ ManagedJobStatus.STARTING.value,
696
+ # If the task is empty, we will jump straight
697
+ # from PENDING to RUNNING
698
+ ManagedJobStatus.PENDING.value
699
+ ]),
700
+ spot_table.c.end_at.is_(None),
701
+ )).update({
702
+ spot_table.c.status: ManagedJobStatus.RUNNING.value,
703
+ spot_table.c.start_at: start_time,
704
+ spot_table.c.last_recovered_at: start_time,
705
+ })
706
+ session.commit()
707
+ if count != 1:
629
708
  raise exceptions.ManagedJobStatusError(
630
709
  f'Failed to set the task to started. '
631
- f'({cursor.rowcount} rows updated)')
710
+ f'({count} rows updated)')
632
711
  callback_func('STARTED')
633
712
 
634
713
 
@@ -636,50 +715,48 @@ def set_started(job_id: int, task_id: int, start_time: float,
636
715
  def set_recovering(job_id: int, task_id: int, force_transit_to_recovering: bool,
637
716
  callback_func: CallbackType):
638
717
  """Set the task to recovering state, and update the job duration."""
639
- assert _DB_PATH is not None
718
+ assert _SQLALCHEMY_ENGINE is not None
640
719
  logger.info('=== Recovering... ===')
641
- expected_status: List[str] = [ManagedJobStatus.RUNNING.value]
642
- status_str = 'status=(?)'
643
- if force_transit_to_recovering:
644
- # For the HA job controller, it is possible that the jobs came from any
645
- # processing status to recovering. But it should not be any terminal
646
- # status as such jobs will not be recovered; and it should not be
647
- # CANCELLING as we will directly trigger a cleanup.
648
- expected_status = [
649
- s.value for s in ManagedJobStatus.processing_statuses()
650
- ]
651
- question_mark_str = ', '.join(['?'] * len(expected_status))
652
- status_str = f'status IN ({question_mark_str})'
653
720
  # NOTE: if we are resuming from a controller failure and the previous status
654
721
  # is STARTING, the initial value of `last_recovered_at` might not be set
655
722
  # yet (default value -1). In this case, we should not add current timestamp.
656
723
  # Otherwise, the job duration will be incorrect (~55 years from 1970).
657
724
  current_time = time.time()
658
- with db_utils.safe_cursor(_DB_PATH) as cursor:
659
- cursor.execute(
660
- f"""\
661
- UPDATE spot SET
662
- status=(?),
663
- job_duration=CASE
664
- WHEN last_recovered_at >= 0
665
- THEN job_duration+(?)-last_recovered_at
666
- ELSE job_duration
667
- END,
668
- last_recovered_at=CASE
669
- WHEN last_recovered_at < 0
670
- THEN (?)
671
- ELSE last_recovered_at
672
- END
673
- WHERE spot_job_id=(?) AND
674
- task_id=(?) AND
675
- {status_str} AND
676
- end_at IS null""",
677
- (ManagedJobStatus.RECOVERING.value, current_time, current_time,
678
- job_id, task_id, *expected_status))
679
- if cursor.rowcount != 1:
725
+
726
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
727
+ if force_transit_to_recovering:
728
+ # For the HA job controller, it is possible that the jobs came from
729
+ # any processing status to recovering. But it should not be any
730
+ # terminal status as such jobs will not be recovered; and it should
731
+ # not be CANCELLING as we will directly trigger a cleanup.
732
+ status_condition = spot_table.c.status.in_(
733
+ [s.value for s in ManagedJobStatus.processing_statuses()])
734
+ else:
735
+ status_condition = (
736
+ spot_table.c.status == ManagedJobStatus.RUNNING.value)
737
+
738
+ count = session.query(spot_table).filter(
739
+ sqlalchemy.and_(
740
+ spot_table.c.spot_job_id == job_id,
741
+ spot_table.c.task_id == task_id,
742
+ status_condition,
743
+ spot_table.c.end_at.is_(None),
744
+ )).update({
745
+ spot_table.c.status: ManagedJobStatus.RECOVERING.value,
746
+ spot_table.c.job_duration: sqlalchemy.case(
747
+ (spot_table.c.last_recovered_at >= 0,
748
+ spot_table.c.job_duration + current_time -
749
+ spot_table.c.last_recovered_at),
750
+ else_=spot_table.c.job_duration),
751
+ spot_table.c.last_recovered_at: sqlalchemy.case(
752
+ (spot_table.c.last_recovered_at < 0, current_time),
753
+ else_=spot_table.c.last_recovered_at),
754
+ })
755
+ session.commit()
756
+ if count != 1:
680
757
  raise exceptions.ManagedJobStatusError(
681
758
  f'Failed to set the task to recovering. '
682
- f'({cursor.rowcount} rows updated)')
759
+ f'({count} rows updated)')
683
760
  callback_func('RECOVERING')
684
761
 
685
762
 
@@ -687,22 +764,24 @@ def set_recovering(job_id: int, task_id: int, force_transit_to_recovering: bool,
687
764
def set_recovered(job_id: int, task_id: int, recovered_time: float,
                  callback_func: CallbackType):
    """Set the task to recovered.

    Transitions the (job, task) row from RECOVERING back to RUNNING,
    records the recovery timestamp and bumps the recovery counter.

    Raises:
        exceptions.ManagedJobStatusError: if exactly one row was not
            updated (e.g. the task is not RECOVERING or already ended).
    """
    assert _SQLALCHEMY_ENGINE is not None
    with orm.Session(_SQLALCHEMY_ENGINE) as session:
        # Only a live task (end_at is NULL) that is currently RECOVERING
        # may be marked as recovered.
        count = session.query(spot_table).filter(
            spot_table.c.spot_job_id == job_id).filter(
                spot_table.c.task_id == task_id).filter(
                    spot_table.c.status ==
                    ManagedJobStatus.RECOVERING.value).filter(
                        spot_table.c.end_at.is_(None)).update({
                            spot_table.c.status:
                                ManagedJobStatus.RUNNING.value,
                            spot_table.c.last_recovered_at: recovered_time,
                            spot_table.c.recovery_count:
                                spot_table.c.recovery_count + 1,
                        })
        session.commit()
        if count != 1:
            raise exceptions.ManagedJobStatusError(
                f'Failed to set the task to recovered. '
                f'({count} rows updated)')
    logger.info('==== Recovered. ====')
    callback_func('RECOVERED')
708
787
 
@@ -711,22 +790,23 @@ def set_recovered(job_id: int, task_id: int, recovered_time: float,
711
790
def set_succeeded(job_id: int, task_id: int, end_time: float,
                  callback_func: CallbackType):
    """Set the task to succeeded, if it is in a non-terminal state."""
    assert _SQLALCHEMY_ENGINE is not None
    with orm.Session(_SQLALCHEMY_ENGINE) as session:
        # The task must still be RUNNING and not ended; otherwise no row
        # matches and we raise below.
        count = session.query(spot_table).filter(
            spot_table.c.spot_job_id == job_id).filter(
                spot_table.c.task_id == task_id).filter(
                    spot_table.c.status ==
                    ManagedJobStatus.RUNNING.value).filter(
                        spot_table.c.end_at.is_(None)).update({
                            spot_table.c.status:
                                ManagedJobStatus.SUCCEEDED.value,
                            spot_table.c.end_at: end_time,
                        })
        session.commit()
        if count != 1:
            raise exceptions.ManagedJobStatusError(
                f'Failed to set the task to succeeded. '
                f'({count} rows updated)')
    callback_func('SUCCEEDED')
    logger.info('Job succeeded.')
732
812
 
@@ -756,52 +836,40 @@ def set_failed(
756
836
  override_terminal: If True, override the current status even if end_at
757
837
  is already set.
758
838
  """
759
- assert _DB_PATH is not None
839
+ assert _SQLALCHEMY_ENGINE is not None
760
840
  assert failure_type.is_failed(), failure_type
761
841
  end_time = time.time() if end_time is None else end_time
762
842
 
763
843
  fields_to_set: Dict[str, Any] = {
764
- 'status': failure_type.value,
765
- 'failure_reason': failure_reason,
844
+ spot_table.c.status: failure_type.value,
845
+ spot_table.c.failure_reason: failure_reason,
766
846
  }
767
- with db_utils.safe_cursor(_DB_PATH) as cursor:
768
- previous_status = cursor.execute(
769
- 'SELECT status FROM spot WHERE spot_job_id=(?)',
770
- (job_id,)).fetchone()[0]
847
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
848
+ # Get previous status
849
+ previous_status = session.execute(
850
+ sqlalchemy.select(spot_table.c.status).where(
851
+ spot_table.c.spot_job_id == job_id)).fetchone()[0]
771
852
  previous_status = ManagedJobStatus(previous_status)
772
853
  if previous_status == ManagedJobStatus.RECOVERING:
773
854
  # If the job is recovering, we should set the last_recovered_at to
774
855
  # the end_time, so that the end_at - last_recovered_at will not be
775
856
  # affect the job duration calculation.
776
- fields_to_set['last_recovered_at'] = end_time
777
- set_str = ', '.join(f'{k}=(?)' for k in fields_to_set)
778
- task_query_str = '' if task_id is None else 'AND task_id=(?)'
779
- task_value = [] if task_id is None else [
780
- task_id,
781
- ]
782
-
857
+ fields_to_set[spot_table.c.last_recovered_at] = end_time
858
+ where_conditions = [spot_table.c.spot_job_id == job_id]
859
+ if task_id is not None:
860
+ where_conditions.append(spot_table.c.task_id == task_id)
783
861
  if override_terminal:
784
862
  # Use COALESCE for end_at to avoid overriding the existing end_at if
785
863
  # it's already set.
786
- cursor.execute(
787
- f"""\
788
- UPDATE spot SET
789
- end_at = COALESCE(end_at, ?),
790
- {set_str}
791
- WHERE spot_job_id=(?) {task_query_str}""",
792
- (end_time, *list(fields_to_set.values()), job_id, *task_value))
864
+ fields_to_set[spot_table.c.end_at] = sqlalchemy.func.coalesce(
865
+ spot_table.c.end_at, end_time)
793
866
  else:
794
- # Only set if end_at is null, i.e. the previous status is not
795
- # terminal.
796
- cursor.execute(
797
- f"""\
798
- UPDATE spot SET
799
- end_at = (?),
800
- {set_str}
801
- WHERE spot_job_id=(?) {task_query_str} AND end_at IS null""",
802
- (end_time, *list(fields_to_set.values()), job_id, *task_value))
803
-
804
- updated = cursor.rowcount > 0
867
+ fields_to_set[spot_table.c.end_at] = end_time
868
+ where_conditions.append(spot_table.c.end_at.is_(None))
869
+ count = session.query(spot_table).filter(
870
+ sqlalchemy.and_(*where_conditions)).update(fields_to_set)
871
+ session.commit()
872
+ updated = count > 0
805
873
  if callback_func and updated:
806
874
  callback_func('FAILED')
807
875
  logger.info(failure_reason)
@@ -814,15 +882,15 @@ def set_cancelling(job_id: int, callback_func: CallbackType):
814
882
  task_id is not needed, because we expect the job should be cancelled
815
883
  as a whole, and we should not cancel a single task.
816
884
  """
817
- assert _DB_PATH is not None
818
- with db_utils.safe_cursor(_DB_PATH) as cursor:
819
- rows = cursor.execute(
820
- """\
821
- UPDATE spot SET
822
- status=(?)
823
- WHERE spot_job_id=(?) AND end_at IS null""",
824
- (ManagedJobStatus.CANCELLING.value, job_id))
825
- updated = rows.rowcount > 0
885
+ assert _SQLALCHEMY_ENGINE is not None
886
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
887
+ count = session.query(spot_table).filter(
888
+ sqlalchemy.and_(
889
+ spot_table.c.spot_job_id == job_id,
890
+ spot_table.c.end_at.is_(None),
891
+ )).update({spot_table.c.status: ManagedJobStatus.CANCELLING.value})
892
+ session.commit()
893
+ updated = count > 0
826
894
  if updated:
827
895
  logger.info('Cancelling the job...')
828
896
  callback_func('CANCELLING')
@@ -836,16 +904,18 @@ def set_cancelled(job_id: int, callback_func: CallbackType):
836
904
 
837
905
  The set_cancelling should be called before this function.
838
906
  """
839
- assert _DB_PATH is not None
840
- with db_utils.safe_cursor(_DB_PATH) as cursor:
841
- rows = cursor.execute(
842
- """\
843
- UPDATE spot SET
844
- status=(?), end_at=(?)
845
- WHERE spot_job_id=(?) AND status=(?)""",
846
- (ManagedJobStatus.CANCELLED.value, time.time(), job_id,
847
- ManagedJobStatus.CANCELLING.value))
848
- updated = rows.rowcount > 0
907
+ assert _SQLALCHEMY_ENGINE is not None
908
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
909
+ count = session.query(spot_table).filter(
910
+ sqlalchemy.and_(
911
+ spot_table.c.spot_job_id == job_id,
912
+ spot_table.c.status == ManagedJobStatus.CANCELLING.value,
913
+ )).update({
914
+ spot_table.c.status: ManagedJobStatus.CANCELLED.value,
915
+ spot_table.c.end_at: time.time(),
916
+ })
917
+ session.commit()
918
+ updated = count > 0
849
919
  if updated:
850
920
  logger.info('Job cancelled.')
851
921
  callback_func('CANCELLED')
@@ -857,17 +927,15 @@ def set_cancelled(job_id: int, callback_func: CallbackType):
857
927
def set_local_log_file(job_id: int, task_id: Optional[int],
                       local_log_file: str):
    """Set the local log file for a job.

    If ``task_id`` is None, the log file is set for every task of the job.
    """
    assert _SQLALCHEMY_ENGINE is not None
    with orm.Session(_SQLALCHEMY_ENGINE) as session:
        filters = [spot_table.c.spot_job_id == job_id]
        if task_id is not None:
            filters.append(spot_table.c.task_id == task_id)
        session.query(spot_table).filter(sqlalchemy.and_(*filters)).update(
            {spot_table.c.local_log_file: local_log_file})
        session.commit()
871
939
 
872
940
 
873
941
  # ======== utility functions ========
@@ -875,37 +943,37 @@ def set_local_log_file(job_id: int, task_id: Optional[int],
875
943
def get_nonterminal_job_ids_by_name(name: Optional[str],
                                    all_users: bool = False) -> List[int]:
    """Get non-terminal job ids by name."""
    assert _SQLALCHEMY_ENGINE is not None
    terminal_values = [
        status.value for status in ManagedJobStatus.terminal_statuses()
    ]
    with orm.Session(_SQLALCHEMY_ENGINE) as session:
        # Left outer join is used instead of join, because job_info does not
        # contain the managed jobs submitted before #1982.
        query = sqlalchemy.select(
            spot_table.c.spot_job_id.distinct()).select_from(
                spot_table.outerjoin(
                    job_info_table,
                    spot_table.c.spot_job_id == job_info_table.c.spot_job_id))
        conditions = [~spot_table.c.status.in_(terminal_values)]
        if name is None and not all_users:
            # No explicit name filter: restrict to the current user's jobs.
            conditions.append(
                job_info_table.c.user_hash == common_utils.get_user_hash())
        if name is not None:
            # Match the job name from `job_info` for jobs submitted after
            # #1982, and from `spot` for jobs submitted before #1982 whose
            # job_info is not available.
            conditions.append(
                sqlalchemy.or_(
                    job_info_table.c.name == name,
                    sqlalchemy.and_(job_info_table.c.name.is_(None),
                                    spot_table.c.task_name == name)))
        query = query.where(sqlalchemy.and_(*conditions)).order_by(
            spot_table.c.spot_job_id.desc())
        rows = session.execute(query).fetchall()
    return [row[0] for row in rows if row[0] is not None]
911
979
 
@@ -919,26 +987,25 @@ def get_schedule_live_jobs(job_id: Optional[int]) -> List[Dict[str, Any]]:
919
987
  exception: the job may have just transitioned from WAITING to LAUNCHING, but
920
988
  the controller process has not yet started.
921
989
  """
922
- assert _DB_PATH is not None
923
- job_filter = '' if job_id is None else 'AND spot_job_id=(?)'
924
- job_value = (job_id,) if job_id is not None else ()
990
+ assert _SQLALCHEMY_ENGINE is not None
925
991
 
926
- # Join spot and job_info tables to get the job name for each task.
927
- # We use LEFT OUTER JOIN mainly for backward compatibility, as for an
928
- # existing controller before #1982, the job_info table may not exist,
929
- # and all the managed jobs created before will not present in the
930
- # job_info.
931
- with db_utils.safe_cursor(_DB_PATH) as cursor:
932
- rows = cursor.execute(
933
- f"""\
934
- SELECT spot_job_id, schedule_state, controller_pid
935
- FROM job_info
936
- WHERE schedule_state not in (?, ?, ?)
937
- {job_filter}
938
- ORDER BY spot_job_id DESC""",
939
- (ManagedJobScheduleState.INACTIVE.value,
940
- ManagedJobScheduleState.WAITING.value,
941
- ManagedJobScheduleState.DONE.value, *job_value)).fetchall()
992
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
993
+ query = sqlalchemy.select(
994
+ job_info_table.c.spot_job_id,
995
+ job_info_table.c.schedule_state,
996
+ job_info_table.c.controller_pid,
997
+ ).where(~job_info_table.c.schedule_state.in_([
998
+ ManagedJobScheduleState.INACTIVE.value,
999
+ ManagedJobScheduleState.WAITING.value,
1000
+ ManagedJobScheduleState.DONE.value,
1001
+ ]))
1002
+
1003
+ if job_id is not None:
1004
+ query = query.where(job_info_table.c.spot_job_id == job_id)
1005
+
1006
+ query = query.order_by(job_info_table.c.spot_job_id.desc())
1007
+
1008
+ rows = session.execute(query).fetchall()
942
1009
  jobs = []
943
1010
  for row in rows:
944
1011
  job_dict = {
@@ -962,77 +1029,76 @@ def get_jobs_to_check_status(job_id: Optional[int] = None) -> List[int]:
962
1029
  - Jobs have schedule_state DONE but are in a non-terminal status
963
1030
  - Legacy jobs (that is, no schedule state) that are in non-terminal status
964
1031
  """
965
- assert _DB_PATH is not None
966
- job_filter = '' if job_id is None else 'AND spot.spot_job_id=(?)'
967
- job_value = () if job_id is None else (job_id,)
968
-
969
- status_filter_str = ', '.join(['?'] *
970
- len(ManagedJobStatus.terminal_statuses()))
971
- terminal_status_values = [
972
- status.value for status in ManagedJobStatus.terminal_statuses()
973
- ]
974
-
975
- # Get jobs that are either:
976
- # 1. Have schedule state that is not DONE, or
977
- # 2. Have schedule state DONE AND are in non-terminal status (unexpected
978
- # inconsistent state), or
979
- # 3. Have no schedule state (legacy) AND are in non-terminal status
980
- with db_utils.safe_cursor(_DB_PATH) as cursor:
981
- rows = cursor.execute(
982
- f"""\
983
- SELECT DISTINCT spot.spot_job_id
984
- FROM spot
985
- LEFT OUTER JOIN job_info
986
- ON spot.spot_job_id=job_info.spot_job_id
987
- WHERE (
988
- -- non-legacy jobs that are not DONE
989
- (job_info.schedule_state IS NOT NULL AND
990
- job_info.schedule_state IS NOT ?)
991
- OR
992
- -- legacy or that are in non-terminal status or
993
- -- DONE jobs that are in non-terminal status
994
- ((-- legacy jobs
995
- job_info.schedule_state IS NULL OR
996
- -- non-legacy DONE jobs
997
- job_info.schedule_state IS ?
998
- ) AND
999
- -- non-terminal
1000
- status NOT IN ({status_filter_str}))
1001
- )
1002
- {job_filter}
1003
- ORDER BY spot.spot_job_id DESC""", [
1004
- ManagedJobScheduleState.DONE.value,
1005
- ManagedJobScheduleState.DONE.value, *terminal_status_values,
1006
- *job_value
1007
- ]).fetchall()
1032
+ assert _SQLALCHEMY_ENGINE is not None
1033
+
1034
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
1035
+ terminal_status_values = [
1036
+ status.value for status in ManagedJobStatus.terminal_statuses()
1037
+ ]
1038
+
1039
+ query = sqlalchemy.select(
1040
+ spot_table.c.spot_job_id.distinct()).select_from(
1041
+ spot_table.outerjoin(
1042
+ job_info_table,
1043
+ spot_table.c.spot_job_id == job_info_table.c.spot_job_id))
1044
+
1045
+ # Get jobs that are either:
1046
+ # 1. Have schedule state that is not DONE, or
1047
+ # 2. Have schedule state DONE AND are in non-terminal status (unexpected
1048
+ # inconsistent state), or
1049
+ # 3. Have no schedule state (legacy) AND are in non-terminal status
1050
+
1051
+ # non-legacy jobs that are not DONE
1052
+ condition1 = sqlalchemy.and_(
1053
+ job_info_table.c.schedule_state.is_not(None),
1054
+ job_info_table.c.schedule_state !=
1055
+ ManagedJobScheduleState.DONE.value)
1056
+ # legacy or that are in non-terminal status or
1057
+ # DONE jobs that are in non-terminal status
1058
+ condition2 = sqlalchemy.and_(
1059
+ sqlalchemy.or_(
1060
+ # legacy jobs
1061
+ job_info_table.c.schedule_state.is_(None),
1062
+ # non-legacy DONE jobs
1063
+ job_info_table.c.schedule_state ==
1064
+ ManagedJobScheduleState.DONE.value),
1065
+ # non-terminal
1066
+ ~spot_table.c.status.in_(terminal_status_values),
1067
+ )
1068
+ where_condition = sqlalchemy.or_(condition1, condition2)
1069
+ if job_id is not None:
1070
+ where_condition = sqlalchemy.and_(
1071
+ where_condition, spot_table.c.spot_job_id == job_id)
1072
+
1073
+ query = query.where(where_condition).order_by(
1074
+ spot_table.c.spot_job_id.desc())
1075
+
1076
+ rows = session.execute(query).fetchall()
1008
1077
  return [row[0] for row in rows if row[0] is not None]
1009
1078
 
1010
1079
 
1011
1080
  @_init_db
1012
1081
def get_all_job_ids_by_name(name: Optional[str]) -> List[int]:
    """Get all job ids by name."""
    assert _SQLALCHEMY_ENGINE is not None
    with orm.Session(_SQLALCHEMY_ENGINE) as session:
        # Left outer join for backward compatibility: jobs submitted before
        # #1982 have no row in job_info.
        query = sqlalchemy.select(
            spot_table.c.spot_job_id.distinct()).select_from(
                spot_table.outerjoin(
                    job_info_table,
                    spot_table.c.spot_job_id == job_info_table.c.spot_job_id))
        if name is not None:
            # Match the job name from `job_info` for jobs submitted after
            # #1982, and from `spot` for earlier jobs whose job_info is not
            # available.
            query = query.where(
                sqlalchemy.or_(
                    job_info_table.c.name == name,
                    sqlalchemy.and_(job_info_table.c.name.is_(None),
                                    spot_table.c.task_name == name)))
        rows = session.execute(
            query.order_by(spot_table.c.spot_job_id.desc())).fetchall()
    return [row[0] for row in rows if row[0] is not None]
1038
1104
 
@@ -1040,26 +1106,26 @@ def get_all_job_ids_by_name(name: Optional[str]) -> List[int]:
1040
1106
  @_init_db
1041
1107
def _get_all_task_ids_statuses(
        job_id: int) -> List[Tuple[int, ManagedJobStatus]]:
    """Return (task_id, status) pairs for a job, ordered by task id."""
    assert _SQLALCHEMY_ENGINE is not None
    with orm.Session(_SQLALCHEMY_ENGINE) as session:
        rows = session.execute(
            sqlalchemy.select(spot_table.c.task_id,
                              spot_table.c.status).where(
                                  spot_table.c.spot_job_id == job_id).order_by(
                                      spot_table.c.task_id.asc())).fetchall()
    return [(task_id, ManagedJobStatus(status)) for task_id, status in rows]
1051
1118
 
1052
1119
 
1053
1120
  @_init_db
1054
1121
def get_job_status_with_task_id(job_id: int,
                                task_id: int) -> Optional[ManagedJobStatus]:
    """Return the status of one (job, task), or None if no such row."""
    assert _SQLALCHEMY_ENGINE is not None
    with orm.Session(_SQLALCHEMY_ENGINE) as session:
        row = session.execute(
            sqlalchemy.select(spot_table.c.status).where(
                sqlalchemy.and_(spot_table.c.spot_job_id == job_id,
                                spot_table.c.task_id == task_id))).fetchone()
    return ManagedJobStatus(row[0]) if row else None
1064
1130
 
1065
1131
 
@@ -1101,13 +1167,12 @@ def get_failure_reason(job_id: int) -> Optional[str]:
1101
1167
 
1102
1168
  If the job has multiple tasks, we return the first failure reason.
1103
1169
  """
1104
- assert _DB_PATH is not None
1105
- with db_utils.safe_cursor(_DB_PATH) as cursor:
1106
- reason = cursor.execute(
1107
- """\
1108
- SELECT failure_reason FROM spot
1109
- WHERE spot_job_id=(?)
1110
- ORDER BY task_id ASC""", (job_id,)).fetchall()
1170
+ assert _SQLALCHEMY_ENGINE is not None
1171
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
1172
+ reason = session.execute(
1173
+ sqlalchemy.select(spot_table.c.failure_reason).where(
1174
+ spot_table.c.spot_job_id == job_id).order_by(
1175
+ spot_table.c.task_id.asc())).fetchall()
1111
1176
  reason = [r[0] for r in reason if r[0] is not None]
1112
1177
  if not reason:
1113
1178
  return None
@@ -1117,8 +1182,7 @@ def get_failure_reason(job_id: int) -> Optional[str]:
1117
1182
  @_init_db
1118
1183
  def get_managed_jobs(job_id: Optional[int] = None) -> List[Dict[str, Any]]:
1119
1184
  """Get managed jobs from the database."""
1120
- assert _DB_PATH is not None
1121
- job_filter = '' if job_id is None else f'WHERE spot.spot_job_id={job_id}'
1185
+ assert _SQLALCHEMY_ENGINE is not None
1122
1186
 
1123
1187
  # Join spot and job_info tables to get the job name for each task.
1124
1188
  # We use LEFT OUTER JOIN mainly for backward compatibility, as for an
@@ -1128,17 +1192,19 @@ def get_managed_jobs(job_id: Optional[int] = None) -> List[Dict[str, Any]]:
1128
1192
  # Note: we will get the user_hash here, but don't try to call
1129
1193
  # global_user_state.get_user() on it. This runs on the controller, which may
1130
1194
  # not have the user info. Prefer to do it on the API server side.
1131
- with db_utils.safe_cursor(_DB_PATH) as cursor:
1132
- rows = cursor.execute(f"""\
1133
- SELECT *
1134
- FROM spot
1135
- LEFT OUTER JOIN job_info
1136
- ON spot.spot_job_id=job_info.spot_job_id
1137
- {job_filter}
1138
- ORDER BY spot.spot_job_id DESC, spot.task_id ASC""").fetchall()
1195
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
1196
+ query = sqlalchemy.select(spot_table, job_info_table).select_from(
1197
+ spot_table.outerjoin(
1198
+ job_info_table,
1199
+ spot_table.c.spot_job_id == job_info_table.c.spot_job_id))
1200
+ if job_id is not None:
1201
+ query = query.where(spot_table.c.spot_job_id == job_id)
1202
+ query = query.order_by(spot_table.c.spot_job_id.desc(),
1203
+ spot_table.c.task_id.asc())
1204
+ rows = session.execute(query).fetchall()
1139
1205
  jobs = []
1140
1206
  for row in rows:
1141
- job_dict = dict(zip(columns, row))
1207
+ job_dict = _get_jobs_dict(row._mapping) # pylint: disable=protected-access
1142
1208
  job_dict['status'] = ManagedJobStatus(job_dict['status'])
1143
1209
  job_dict['schedule_state'] = ManagedJobScheduleState(
1144
1210
  job_dict['schedule_state'])
@@ -1163,55 +1229,54 @@ def get_managed_jobs(job_id: Optional[int] = None) -> List[Dict[str, Any]]:
1163
1229
  @_init_db
1164
1230
def get_task_name(job_id: int, task_id: int) -> str:
    """Get the task name of a job."""
    assert _SQLALCHEMY_ENGINE is not None
    with orm.Session(_SQLALCHEMY_ENGINE) as session:
        # NOTE(review): like the original, this assumes the row exists;
        # a missing (job, task) would raise a TypeError on indexing.
        row = session.execute(
            sqlalchemy.select(spot_table.c.task_name).where(
                sqlalchemy.and_(spot_table.c.spot_job_id == job_id,
                                spot_table.c.task_id == task_id))).fetchone()
    return row[0]
1174
1241
 
1175
1242
 
1176
1243
  @_init_db
1177
1244
def get_latest_job_id() -> Optional[int]:
    """Get the latest job id."""
    assert _SQLALCHEMY_ENGINE is not None
    with orm.Session(_SQLALCHEMY_ENGINE) as session:
        # Task 0 exists exactly once per job, so ordering those rows by
        # submission time and taking the first yields the newest job.
        row = session.execute(
            sqlalchemy.select(spot_table.c.spot_job_id).where(
                spot_table.c.task_id == 0).order_by(
                    spot_table.c.submitted_at.desc()).limit(1)).fetchone()
    return row[0] if row else None
1188
1253
 
1189
1254
 
1190
1255
  @_init_db
1191
1256
def get_task_specs(job_id: int, task_id: int) -> Dict[str, Any]:
    """Return the JSON-decoded `specs` column for a (job, task)."""
    assert _SQLALCHEMY_ENGINE is not None
    with orm.Session(_SQLALCHEMY_ENGINE) as session:
        row = session.execute(
            sqlalchemy.select(spot_table.c.specs).where(
                sqlalchemy.and_(spot_table.c.spot_job_id == job_id,
                                spot_table.c.task_id == task_id))).fetchone()
    return json.loads(row[0])
1200
1266
 
1201
1267
 
1202
1268
  @_init_db
1203
1269
  def get_local_log_file(job_id: int, task_id: Optional[int]) -> Optional[str]:
1204
1270
  """Get the local log directory for a job."""
1205
- assert _DB_PATH is not None
1206
- filter_str = 'spot_job_id=(?)'
1207
- filter_args = [job_id]
1208
- if task_id is not None:
1209
- filter_str += ' AND task_id=(?)'
1210
- filter_args.append(task_id)
1211
- with db_utils.safe_cursor(_DB_PATH) as cursor:
1212
- local_log_file = cursor.execute(
1213
- f'SELECT local_log_file FROM spot '
1214
- f'WHERE {filter_str}', filter_args).fetchone()
1271
+ assert _SQLALCHEMY_ENGINE is not None
1272
+
1273
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
1274
+ where_conditions = [spot_table.c.spot_job_id == job_id]
1275
+ if task_id is not None:
1276
+ where_conditions.append(spot_table.c.task_id == task_id)
1277
+ local_log_file = session.execute(
1278
+ sqlalchemy.select(spot_table.c.local_log_file).where(
1279
+ sqlalchemy.and_(*where_conditions))).fetchone()
1215
1280
  return local_log_file[-1] if local_log_file else None
1216
1281
 
1217
1282
 
@@ -1232,17 +1297,24 @@ def scheduler_set_waiting(job_id: int, dag_yaml_path: str,
1232
1297
  updated_count will be 0). In this case, we return True.
1233
1298
  Otherwise, we return False.
1234
1299
  """
1235
- assert _DB_PATH is not None
1236
- with db_utils.safe_cursor(_DB_PATH) as cursor:
1237
- updated_count = cursor.execute(
1238
- 'UPDATE job_info SET '
1239
- 'schedule_state = (?), dag_yaml_path = (?), '
1240
- 'original_user_yaml_path = (?), env_file_path = (?), '
1241
- ' user_hash = (?), priority = (?) '
1242
- 'WHERE spot_job_id = (?) AND schedule_state = (?)',
1243
- (ManagedJobScheduleState.WAITING.value, dag_yaml_path,
1244
- original_user_yaml_path, env_file_path, user_hash, priority,
1245
- job_id, ManagedJobScheduleState.INACTIVE.value)).rowcount
1300
+ assert _SQLALCHEMY_ENGINE is not None
1301
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
1302
+ updated_count = session.query(job_info_table).filter(
1303
+ sqlalchemy.and_(
1304
+ job_info_table.c.spot_job_id == job_id,
1305
+ job_info_table.c.schedule_state ==
1306
+ ManagedJobScheduleState.INACTIVE.value,
1307
+ )
1308
+ ).update({
1309
+ job_info_table.c.schedule_state:
1310
+ ManagedJobScheduleState.WAITING.value,
1311
+ job_info_table.c.dag_yaml_path: dag_yaml_path,
1312
+ job_info_table.c.original_user_yaml_path: original_user_yaml_path,
1313
+ job_info_table.c.env_file_path: env_file_path,
1314
+ job_info_table.c.user_hash: user_hash,
1315
+ job_info_table.c.priority: priority,
1316
+ })
1317
+ session.commit()
1246
1318
  # For a recovery run, the job may already be in the WAITING state.
1247
1319
  assert updated_count <= 1, (job_id, updated_count)
1248
1320
  return updated_count == 0
@@ -1252,119 +1324,140 @@ def scheduler_set_waiting(job_id: int, dag_yaml_path: str,
1252
1324
def scheduler_set_launching(job_id: int,
                            current_state: ManagedJobScheduleState) -> None:
    """Transition a job from ``current_state`` to LAUNCHING.

    Do not call without holding the scheduler lock.
    """
    assert _SQLALCHEMY_ENGINE is not None
    transition = sqlalchemy.update(job_info_table).where(
        sqlalchemy.and_(
            job_info_table.c.spot_job_id == job_id,
            job_info_table.c.schedule_state == current_state.value,
        )).values(schedule_state=ManagedJobScheduleState.LAUNCHING.value)
    with orm.Session(_SQLALCHEMY_ENGINE) as session:
        updated_count = session.execute(transition).rowcount
        session.commit()
    # Exactly one row must have been in the expected state; anything else
    # indicates a scheduler state-machine violation.
    assert updated_count == 1, (job_id, updated_count)
1264
1339
 
1265
1340
 
1266
1341
@_init_db
def scheduler_set_alive(job_id: int) -> None:
    """Transition a job from LAUNCHING to ALIVE.

    Do not call without holding the scheduler lock.
    """
    assert _SQLALCHEMY_ENGINE is not None
    transition = sqlalchemy.update(job_info_table).where(
        sqlalchemy.and_(
            job_info_table.c.spot_job_id == job_id,
            job_info_table.c.schedule_state ==
            ManagedJobScheduleState.LAUNCHING.value,
        )).values(schedule_state=ManagedJobScheduleState.ALIVE.value)
    with orm.Session(_SQLALCHEMY_ENGINE) as session:
        updated_count = session.execute(transition).rowcount
        session.commit()
    # The job must have been LAUNCHING; otherwise the scheduler invariants
    # are broken.
    assert updated_count == 1, (job_id, updated_count)
1278
1357
 
1279
1358
 
1280
1359
@_init_db
def scheduler_set_alive_backoff(job_id: int) -> None:
    """Transition a job from LAUNCHING to ALIVE_BACKOFF.

    Do not call without holding the scheduler lock.
    """
    assert _SQLALCHEMY_ENGINE is not None
    transition = sqlalchemy.update(job_info_table).where(
        sqlalchemy.and_(
            job_info_table.c.spot_job_id == job_id,
            job_info_table.c.schedule_state ==
            ManagedJobScheduleState.LAUNCHING.value,
        )).values(schedule_state=ManagedJobScheduleState.ALIVE_BACKOFF.value)
    with orm.Session(_SQLALCHEMY_ENGINE) as session:
        updated_count = session.execute(transition).rowcount
        session.commit()
    # The job must have been LAUNCHING; otherwise the scheduler invariants
    # are broken.
    assert updated_count == 1, (job_id, updated_count)
1292
1375
 
1293
1376
 
1294
1377
@_init_db
def scheduler_set_alive_waiting(job_id: int) -> None:
    """Transition a job from ALIVE or ALIVE_BACKOFF to ALIVE_WAITING.

    Do not call without holding the scheduler lock.
    """
    assert _SQLALCHEMY_ENGINE is not None
    # Either of these two source states may legally move to ALIVE_WAITING.
    source_states = [
        ManagedJobScheduleState.ALIVE.value,
        ManagedJobScheduleState.ALIVE_BACKOFF.value,
    ]
    transition = sqlalchemy.update(job_info_table).where(
        sqlalchemy.and_(
            job_info_table.c.spot_job_id == job_id,
            job_info_table.c.schedule_state.in_(source_states),
        )).values(schedule_state=ManagedJobScheduleState.ALIVE_WAITING.value)
    with orm.Session(_SQLALCHEMY_ENGINE) as session:
        updated_count = session.execute(transition).rowcount
        session.commit()
    assert updated_count == 1, (job_id, updated_count)
1307
1394
 
1308
1395
 
1309
1396
@_init_db
def scheduler_set_done(job_id: int, idempotent: bool = False) -> None:
    """Transition a job to DONE from any non-DONE state.

    Do not call without holding the scheduler lock.
    """
    assert _SQLALCHEMY_ENGINE is not None
    transition = sqlalchemy.update(job_info_table).where(
        sqlalchemy.and_(
            job_info_table.c.spot_job_id == job_id,
            job_info_table.c.schedule_state !=
            ManagedJobScheduleState.DONE.value,
        )).values(schedule_state=ManagedJobScheduleState.DONE.value)
    with orm.Session(_SQLALCHEMY_ENGINE) as session:
        updated_count = session.execute(transition).rowcount
        session.commit()
    # With idempotent=True, the job being already DONE (0 rows updated) is
    # acceptable.
    if not idempotent:
        assert updated_count == 1, (job_id, updated_count)
1322
1413
 
1323
1414
 
1324
1415
@_init_db
def set_job_controller_pid(job_id: int, pid: int):
    """Record the controller process PID for a managed job."""
    assert _SQLALCHEMY_ENGINE is not None
    stmt = sqlalchemy.update(job_info_table).where(
        job_info_table.c.spot_job_id == job_id).values(controller_pid=pid)
    with orm.Session(_SQLALCHEMY_ENGINE) as session:
        updated_count = session.execute(stmt).rowcount
        session.commit()
    # The job row must already exist.
    assert updated_count == 1, (job_id, updated_count)
1333
1423
 
1334
1424
 
1335
1425
@_init_db
def get_job_schedule_state(job_id: int) -> ManagedJobScheduleState:
    """Return the scheduler state of the given managed job."""
    assert _SQLALCHEMY_ENGINE is not None
    query = sqlalchemy.select(job_info_table.c.schedule_state).where(
        job_info_table.c.spot_job_id == job_id)
    with orm.Session(_SQLALCHEMY_ENGINE) as session:
        row = session.execute(query).fetchone()
    # Indexing a missing row raises, matching the prior behavior for an
    # unknown job_id.
    return ManagedJobScheduleState(row[0])
1343
1433
 
1344
1434
 
1345
1435
@_init_db
def get_num_launching_jobs() -> int:
    """Count managed jobs currently in the LAUNCHING state."""
    assert _SQLALCHEMY_ENGINE is not None
    count_query = sqlalchemy.select(
        sqlalchemy.func.count()  # pylint: disable=not-callable
    ).select_from(job_info_table).where(
        job_info_table.c.schedule_state ==
        ManagedJobScheduleState.LAUNCHING.value)
    with orm.Session(_SQLALCHEMY_ENGINE) as session:
        return session.execute(count_query).fetchone()[0]
1354
1445
 
1355
1446
 
1356
1447
@_init_db
def get_num_alive_jobs() -> int:
    """Count managed jobs in any scheduler-active (alive) state."""
    assert _SQLALCHEMY_ENGINE is not None
    # States that count as "alive" from the scheduler's perspective.
    alive_states = [
        ManagedJobScheduleState.ALIVE_WAITING.value,
        ManagedJobScheduleState.LAUNCHING.value,
        ManagedJobScheduleState.ALIVE.value,
        ManagedJobScheduleState.ALIVE_BACKOFF.value,
    ]
    count_query = sqlalchemy.select(
        sqlalchemy.func.count()  # pylint: disable=not-callable
    ).select_from(job_info_table).where(
        job_info_table.c.schedule_state.in_(alive_states))
    with orm.Session(_SQLALCHEMY_ENGINE) as session:
        return session.execute(count_query).fetchone()[0]
1368
1461
 
1369
1462
 
1370
1463
  @_init_db
@@ -1378,27 +1471,37 @@ def get_waiting_job() -> Optional[Dict[str, Any]]:
1378
1471
  Backwards compatibility note: jobs submitted before #4485 will have no
1379
1472
  schedule_state and will be ignored by this SQL query.
1380
1473
  """
1381
- assert _DB_PATH is not None
1382
- with db_utils.safe_cursor(_DB_PATH) as cursor:
1474
+ assert _SQLALCHEMY_ENGINE is not None
1475
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
1383
1476
  # Get the highest-priority WAITING or ALIVE_WAITING job whose priority
1384
1477
  # is greater than or equal to the highest priority LAUNCHING or
1385
1478
  # ALIVE_BACKOFF job's priority.
1386
- waiting_job_row = cursor.execute(
1387
- 'SELECT spot_job_id, schedule_state, dag_yaml_path, env_file_path '
1388
- 'FROM job_info '
1389
- 'WHERE schedule_state IN (?, ?) '
1390
- 'AND priority >= COALESCE('
1391
- ' (SELECT MAX(priority) '
1392
- ' FROM job_info '
1393
- ' WHERE schedule_state IN (?, ?)), '
1394
- ' 0'
1395
- ')'
1396
- 'ORDER BY priority DESC, spot_job_id ASC LIMIT 1',
1397
- (ManagedJobScheduleState.WAITING.value,
1398
- ManagedJobScheduleState.ALIVE_WAITING.value,
1399
- ManagedJobScheduleState.LAUNCHING.value,
1400
- ManagedJobScheduleState.ALIVE_BACKOFF.value)).fetchone()
1401
-
1479
+ # First, get the max priority of LAUNCHING or ALIVE_BACKOFF jobs
1480
+ max_priority_subquery = sqlalchemy.select(
1481
+ sqlalchemy.func.max(job_info_table.c.priority)).where(
1482
+ job_info_table.c.schedule_state.in_([
1483
+ ManagedJobScheduleState.LAUNCHING.value,
1484
+ ManagedJobScheduleState.ALIVE_BACKOFF.value,
1485
+ ])).scalar_subquery()
1486
+ # Main query for waiting jobs
1487
+ query = sqlalchemy.select(
1488
+ job_info_table.c.spot_job_id,
1489
+ job_info_table.c.schedule_state,
1490
+ job_info_table.c.dag_yaml_path,
1491
+ job_info_table.c.env_file_path,
1492
+ ).where(
1493
+ sqlalchemy.and_(
1494
+ job_info_table.c.schedule_state.in_([
1495
+ ManagedJobScheduleState.WAITING.value,
1496
+ ManagedJobScheduleState.ALIVE_WAITING.value,
1497
+ ]),
1498
+ job_info_table.c.priority >= sqlalchemy.func.coalesce(
1499
+ max_priority_subquery, 0),
1500
+ )).order_by(
1501
+ job_info_table.c.priority.desc(),
1502
+ job_info_table.c.spot_job_id.asc(),
1503
+ ).limit(1)
1504
+ waiting_job_row = session.execute(query).fetchone()
1402
1505
  if waiting_job_row is None:
1403
1506
  return None
1404
1507
 
@@ -1413,12 +1516,59 @@ def get_waiting_job() -> Optional[Dict[str, Any]]:
1413
1516
@_init_db
def get_workspace(job_id: int) -> str:
    """Get the workspace of a job, falling back to the default workspace."""
    assert _SQLALCHEMY_ENGINE is not None
    query = sqlalchemy.select(job_info_table.c.workspace).where(
        job_info_table.c.spot_job_id == job_id)
    with orm.Session(_SQLALCHEMY_ENGINE) as session:
        row = session.execute(query).fetchone()
    # Missing row or NULL workspace both fall back to the default.
    if row is None or row[0] is None:
        return constants.SKYPILOT_DEFAULT_WORKSPACE
    return row[0]
1528
+
1529
+
1530
+ # === HA Recovery Script functions ===
1531
+
1532
+
1533
@_init_db
def get_ha_recovery_script(job_id: int) -> Optional[str]:
    """Return the HA recovery script for a job, or None if none is stored."""
    assert _SQLALCHEMY_ENGINE is not None
    query = sqlalchemy.select(ha_recovery_script_table.c.script).where(
        ha_recovery_script_table.c.job_id == job_id)
    with orm.Session(_SQLALCHEMY_ENGINE) as session:
        row = session.execute(query).fetchone()
    return None if row is None else row[0]
1543
+
1544
+
1545
@_init_db
def set_ha_recovery_script(job_id: int, script: str) -> None:
    """Insert or update the HA recovery script for a job (upsert)."""
    assert _SQLALCHEMY_ENGINE is not None
    # Pick the dialect-specific insert construct so that
    # on_conflict_do_update is available.
    dialect = _SQLALCHEMY_ENGINE.dialect.name
    if dialect == db_utils.SQLAlchemyDialect.SQLITE.value:
        insert_func = sqlite.insert
    elif dialect == db_utils.SQLAlchemyDialect.POSTGRESQL.value:
        insert_func = postgresql.insert
    else:
        raise ValueError('Unsupported database dialect')
    upsert_stmt = insert_func(ha_recovery_script_table).values(
        job_id=job_id, script=script).on_conflict_do_update(
            index_elements=[ha_recovery_script_table.c.job_id],
            set_={ha_recovery_script_table.c.script: script})
    with orm.Session(_SQLALCHEMY_ENGINE) as session:
        session.execute(upsert_stmt)
        session.commit()
1565
+
1566
+
1567
@_init_db
def remove_ha_recovery_script(job_id: int) -> None:
    """Delete the HA recovery script row for a job, if one exists."""
    assert _SQLALCHEMY_ENGINE is not None
    with orm.Session(_SQLALCHEMY_ENGINE) as session:
        session.execute(
            sqlalchemy.delete(ha_recovery_script_table).where(
                ha_recovery_script_table.c.job_id == job_id))
        session.commit()