skypilot-nightly 1.0.0.dev20250609__py3-none-any.whl → 1.0.0.dev20250611__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.
Files changed (117)
  1. sky/__init__.py +2 -2
  2. sky/admin_policy.py +134 -5
  3. sky/authentication.py +1 -7
  4. sky/backends/cloud_vm_ray_backend.py +9 -20
  5. sky/benchmark/benchmark_state.py +39 -1
  6. sky/cli.py +3 -5
  7. sky/client/cli.py +3 -5
  8. sky/client/sdk.py +49 -4
  9. sky/clouds/kubernetes.py +15 -24
  10. sky/dashboard/out/404.html +1 -1
  11. sky/dashboard/out/_next/static/chunks/211.692afc57e812ae1a.js +1 -0
  12. sky/dashboard/out/_next/static/chunks/350.9e123a4551f68b0d.js +1 -0
  13. sky/dashboard/out/_next/static/chunks/37-d8aebf1683522a0b.js +6 -0
  14. sky/dashboard/out/_next/static/chunks/42.d39e24467181b06b.js +6 -0
  15. sky/dashboard/out/_next/static/chunks/443.b2242d0efcdf5f47.js +1 -0
  16. sky/dashboard/out/_next/static/chunks/470-4d1a5dbe58a8a2b9.js +1 -0
  17. sky/dashboard/out/_next/static/chunks/{121-865d2bf8a3b84c6a.js → 491.b3d264269613fe09.js} +3 -3
  18. sky/dashboard/out/_next/static/chunks/513.211357a2914a34b2.js +1 -0
  19. sky/dashboard/out/_next/static/chunks/600.15a0009177e86b86.js +16 -0
  20. sky/dashboard/out/_next/static/chunks/616-d6128fa9e7cae6e6.js +39 -0
  21. sky/dashboard/out/_next/static/chunks/664-047bc03493fda379.js +1 -0
  22. sky/dashboard/out/_next/static/chunks/682.4dd5dc116f740b5f.js +6 -0
  23. sky/dashboard/out/_next/static/chunks/760-a89d354797ce7af5.js +1 -0
  24. sky/dashboard/out/_next/static/chunks/799-3625946b2ec2eb30.js +8 -0
  25. sky/dashboard/out/_next/static/chunks/804-4c9fc53aa74bc191.js +21 -0
  26. sky/dashboard/out/_next/static/chunks/843-6fcc4bf91ac45b39.js +11 -0
  27. sky/dashboard/out/_next/static/chunks/856-0776dc6ed6000c39.js +1 -0
  28. sky/dashboard/out/_next/static/chunks/901-b424d293275e1fd7.js +1 -0
  29. sky/dashboard/out/_next/static/chunks/938-ab185187a63f9cdb.js +1 -0
  30. sky/dashboard/out/_next/static/chunks/947-6620842ef80ae879.js +35 -0
  31. sky/dashboard/out/_next/static/chunks/969-20d54a9d998dc102.js +1 -0
  32. sky/dashboard/out/_next/static/chunks/973-c807fc34f09c7df3.js +1 -0
  33. sky/dashboard/out/_next/static/chunks/pages/_app-7bbd9d39d6f9a98a.js +20 -0
  34. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]/[job]-89216c616dbaa9c5.js +6 -0
  35. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]-451a14e7e755ebbc.js +6 -0
  36. sky/dashboard/out/_next/static/chunks/pages/clusters-e56b17fd85d0ba58.js +1 -0
  37. sky/dashboard/out/_next/static/chunks/pages/config-497a35a7ed49734a.js +1 -0
  38. sky/dashboard/out/_next/static/chunks/pages/infra/[context]-d2910be98e9227cb.js +1 -0
  39. sky/dashboard/out/_next/static/chunks/pages/infra-780860bcc1103945.js +1 -0
  40. sky/dashboard/out/_next/static/chunks/pages/jobs/[job]-b3dbf38b51cb29be.js +16 -0
  41. sky/dashboard/out/_next/static/chunks/pages/jobs-fe233baf3d073491.js +1 -0
  42. sky/dashboard/out/_next/static/chunks/pages/users-c69ffcab9d6e5269.js +1 -0
  43. sky/dashboard/out/_next/static/chunks/pages/workspace/new-31aa8bdcb7592635.js +1 -0
  44. sky/dashboard/out/_next/static/chunks/pages/workspaces/[name]-c8c2191328532b7d.js +1 -0
  45. sky/dashboard/out/_next/static/chunks/pages/workspaces-82e6601baa5dd280.js +1 -0
  46. sky/dashboard/out/_next/static/chunks/webpack-208a9812ab4f61c9.js +1 -0
  47. sky/dashboard/out/_next/static/css/{8b1c8321d4c02372.css → 5d71bfc09f184bab.css} +1 -1
  48. sky/dashboard/out/_next/static/zJqasksBQ3HcqMpA2wTUZ/_buildManifest.js +1 -0
  49. sky/dashboard/out/clusters/[cluster]/[job].html +1 -1
  50. sky/dashboard/out/clusters/[cluster].html +1 -1
  51. sky/dashboard/out/clusters.html +1 -1
  52. sky/dashboard/out/config.html +1 -1
  53. sky/dashboard/out/index.html +1 -1
  54. sky/dashboard/out/infra/[context].html +1 -1
  55. sky/dashboard/out/infra.html +1 -1
  56. sky/dashboard/out/jobs/[job].html +1 -1
  57. sky/dashboard/out/jobs.html +1 -1
  58. sky/dashboard/out/users.html +1 -1
  59. sky/dashboard/out/workspace/new.html +1 -1
  60. sky/dashboard/out/workspaces/[name].html +1 -1
  61. sky/dashboard/out/workspaces.html +1 -1
  62. sky/exceptions.py +18 -0
  63. sky/global_user_state.py +181 -74
  64. sky/jobs/client/sdk.py +29 -21
  65. sky/jobs/scheduler.py +4 -5
  66. sky/jobs/state.py +104 -11
  67. sky/jobs/utils.py +5 -5
  68. sky/provision/kubernetes/constants.py +9 -0
  69. sky/provision/kubernetes/utils.py +106 -7
  70. sky/serve/client/sdk.py +56 -45
  71. sky/server/common.py +1 -5
  72. sky/server/requests/executor.py +50 -20
  73. sky/server/requests/payloads.py +3 -0
  74. sky/server/requests/process.py +69 -29
  75. sky/server/server.py +1 -0
  76. sky/server/stream_utils.py +111 -55
  77. sky/skylet/constants.py +1 -2
  78. sky/skylet/job_lib.py +95 -40
  79. sky/skypilot_config.py +99 -25
  80. sky/users/permission.py +34 -17
  81. sky/utils/admin_policy_utils.py +41 -16
  82. sky/utils/context.py +21 -1
  83. sky/utils/controller_utils.py +16 -1
  84. sky/utils/kubernetes/exec_kubeconfig_converter.py +19 -47
  85. sky/utils/schemas.py +11 -3
  86. {skypilot_nightly-1.0.0.dev20250609.dist-info → skypilot_nightly-1.0.0.dev20250611.dist-info}/METADATA +1 -1
  87. {skypilot_nightly-1.0.0.dev20250609.dist-info → skypilot_nightly-1.0.0.dev20250611.dist-info}/RECORD +92 -81
  88. sky/dashboard/out/_next/static/chunks/236-619ed0248fb6fdd9.js +0 -6
  89. sky/dashboard/out/_next/static/chunks/293-351268365226d251.js +0 -1
  90. sky/dashboard/out/_next/static/chunks/37-600191c5804dcae2.js +0 -6
  91. sky/dashboard/out/_next/static/chunks/470-680c19413b8f808b.js +0 -1
  92. sky/dashboard/out/_next/static/chunks/63-e2d7b1e75e67c713.js +0 -66
  93. sky/dashboard/out/_next/static/chunks/682-b60cfdacc15202e8.js +0 -6
  94. sky/dashboard/out/_next/static/chunks/843-16c7194621b2b512.js +0 -11
  95. sky/dashboard/out/_next/static/chunks/856-affc52adf5403a3a.js +0 -1
  96. sky/dashboard/out/_next/static/chunks/969-2c584e28e6b4b106.js +0 -1
  97. sky/dashboard/out/_next/static/chunks/973-aed916d5b02d2d63.js +0 -1
  98. sky/dashboard/out/_next/static/chunks/pages/_app-5f16aba5794ee8e7.js +0 -1
  99. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]/[job]-d31688d3e52736dd.js +0 -6
  100. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]-e7d8710a9b0491e5.js +0 -6
  101. sky/dashboard/out/_next/static/chunks/pages/clusters-3c674e5d970e05cb.js +0 -1
  102. sky/dashboard/out/_next/static/chunks/pages/config-3aac7a015c6eede1.js +0 -6
  103. sky/dashboard/out/_next/static/chunks/pages/infra/[context]-46d2e4ad6c487260.js +0 -1
  104. sky/dashboard/out/_next/static/chunks/pages/infra-7013d816a2a0e76c.js +0 -1
  105. sky/dashboard/out/_next/static/chunks/pages/jobs/[job]-f7f0c9e156d328bc.js +0 -16
  106. sky/dashboard/out/_next/static/chunks/pages/jobs-87e60396c376292f.js +0 -1
  107. sky/dashboard/out/_next/static/chunks/pages/users-9355a0f13d1db61d.js +0 -16
  108. sky/dashboard/out/_next/static/chunks/pages/workspace/new-9a749cca1813bd27.js +0 -1
  109. sky/dashboard/out/_next/static/chunks/pages/workspaces/[name]-8eeb628e03902f1b.js +0 -1
  110. sky/dashboard/out/_next/static/chunks/pages/workspaces-8fbcc5ab4af316d0.js +0 -1
  111. sky/dashboard/out/_next/static/chunks/webpack-65d465f948974c0d.js +0 -1
  112. sky/dashboard/out/_next/static/xos0euNCptbGAM7_Q3Acl/_buildManifest.js +0 -1
  113. /sky/dashboard/out/_next/static/{xos0euNCptbGAM7_Q3Acl → zJqasksBQ3HcqMpA2wTUZ}/_ssgManifest.js +0 -0
  114. {skypilot_nightly-1.0.0.dev20250609.dist-info → skypilot_nightly-1.0.0.dev20250611.dist-info}/WHEEL +0 -0
  115. {skypilot_nightly-1.0.0.dev20250609.dist-info → skypilot_nightly-1.0.0.dev20250611.dist-info}/entry_points.txt +0 -0
  116. {skypilot_nightly-1.0.0.dev20250609.dist-info → skypilot_nightly-1.0.0.dev20250611.dist-info}/licenses/LICENSE +0 -0
  117. {skypilot_nightly-1.0.0.dev20250609.dist-info → skypilot_nightly-1.0.0.dev20250611.dist-info}/top_level.txt +0 -0
sky/server/stream_utils.py CHANGED
@@ -3,7 +3,7 @@
 import asyncio
 import collections
 import pathlib
-from typing import AsyncGenerator, Deque, Optional
+from typing import AsyncGenerator, Deque, List, Optional
 
 import aiofiles
 import fastapi
@@ -15,6 +15,12 @@ from sky.utils import rich_utils
 
 logger = sky_logging.init_logger(__name__)
 
+# When streaming log lines, buffer the lines in memory and flush them in chunks
+# to improve log tailing throughput. Buffer size is the max size bytes of each
+# chunk and the timeout threshold for flushing the buffer to ensure
+# responsiveness.
+_BUFFER_SIZE = 8 * 1024  # 8KB
+_BUFFER_TIMEOUT = 0.02  # 20ms
 _HEARTBEAT_INTERVAL = 30
 
 
@@ -36,7 +42,16 @@ async def log_streamer(request_id: Optional[str],
                        plain_logs: bool = False,
                        tail: Optional[int] = None,
                        follow: bool = True) -> AsyncGenerator[str, None]:
-    """Streams the logs of a request."""
+    """Streams the logs of a request.
+
+    Args:
+        request_id: The request ID to check whether the log tailing process
+            should be stopped.
+        log_path: The path to the log file.
+        plain_logs: Whether to show plain logs.
+        tail: The number of lines to tail. If None, tail the whole file.
+        follow: Whether to follow the log file.
+    """
 
     if request_id is not None:
         status_msg = rich_utils.EncodedStatusMessage(
@@ -80,65 +95,106 @@ async def log_streamer(request_id: Optional[str],
     if show_request_waiting_spinner:
         yield status_msg.stop()
 
-    # Find last n lines of the log file. Do not read the whole file into memory.
     async with aiofiles.open(log_path, 'rb') as f:
-        if tail is not None:
-            # TODO(zhwu): this will include the control lines for rich status,
-            # which may not lead to exact tail lines when showing on the client
-            # side.
-            lines: Deque[str] = collections.deque(maxlen=tail)
-            async for line_str in _yield_log_file_with_payloads_skipped(f):
-                lines.append(line_str)
-            for line_str in lines:
-                yield line_str
-
-        last_heartbeat_time = asyncio.get_event_loop().time()
+        async for chunk in _tail_log_file(f, request_id, plain_logs, tail,
+                                          follow):
+            yield chunk
+
+
+async def _tail_log_file(f: aiofiles.threadpool.binary.AsyncBufferedReader,
+                         request_id: Optional[str] = None,
+                         plain_logs: bool = False,
+                         tail: Optional[int] = None,
+                         follow: bool = True) -> AsyncGenerator[str, None]:
+    """Tail the opened log file, buffer the lines and flush in chunks."""
+
+    if tail is not None:
+        # Find last n lines of the log file. Do not read the whole file into
+        # memory.
+        # TODO(zhwu): this will include the control lines for rich status,
+        # which may not lead to exact tail lines when showing on the client
+        # side.
+        lines: Deque[str] = collections.deque(maxlen=tail)
+        async for line_str in _yield_log_file_with_payloads_skipped(f):
+            lines.append(line_str)
+        for line_str in lines:
+            yield line_str
 
-        while True:
-            # Sleep 0 to yield control to allow other coroutines to run,
-            # while keeps the loop tight to make log stream responsive.
-            await asyncio.sleep(0)
-            line: Optional[bytes] = await f.readline()
-            if not line:
-                if request_id is not None:
-                    request_task = requests_lib.get_request(request_id)
-                    if request_task.status > requests_lib.RequestStatus.RUNNING:
-                        if (request_task.status ==
-                                requests_lib.RequestStatus.CANCELLED):
-                            yield (f'{request_task.name!r} request {request_id}'
                                   ' cancelled\n')
-                        break
-                if not follow:
+    last_heartbeat_time = asyncio.get_event_loop().time()
+
+    # Buffer the lines in memory and flush them in chunks to improve log
+    # tailing throughput.
+    buffer: List[str] = []
+    buffer_bytes = 0
+    last_flush_time = asyncio.get_event_loop().time()
+
+    async def flush_buffer() -> AsyncGenerator[str, None]:
+        nonlocal buffer, buffer_bytes, last_flush_time
+        if buffer:
+            yield ''.join(buffer)
+            buffer.clear()
+            buffer_bytes = 0
+            last_flush_time = asyncio.get_event_loop().time()
+
+    while True:
+        # Sleep 0 to yield control to allow other coroutines to run,
+        # while keeps the loop tight to make log stream responsive.
+        await asyncio.sleep(0)
+        current_time = asyncio.get_event_loop().time()
+        # Flush the buffer when it is not empty and the buffer is full or the
+        # flush timeout is reached.
+        if buffer and (buffer_bytes >= _BUFFER_SIZE or
+                       (current_time - last_flush_time) >= _BUFFER_TIMEOUT):
+            async for chunk in flush_buffer():
+                yield chunk
+
+        line: Optional[bytes] = await f.readline()
+        if not line:
+            if request_id is not None:
+                request_task = requests_lib.get_request(request_id)
+                if request_task.status > requests_lib.RequestStatus.RUNNING:
+                    if (request_task.status ==
+                            requests_lib.RequestStatus.CANCELLED):
+                        buffer.append(
+                            f'{request_task.name!r} request {request_id}'
+                            ' cancelled\n')
                     break
+            if not follow:
+                break
+
+            if current_time - last_heartbeat_time >= _HEARTBEAT_INTERVAL:
+                # Currently just used to keep the connection busy, refer to
+                # https://github.com/skypilot-org/skypilot/issues/5750 for
+                # more details.
+                buffer.append(
+                    message_utils.encode_payload(
+                        rich_utils.Control.HEARTBEAT.encode('')))
+                last_heartbeat_time = current_time
+
+            # Sleep shortly to avoid storming the DB and CPU, this has
+            # little impact on the responsivness here since we are waiting
+            # for a new line to come in.
+            await asyncio.sleep(0.1)
+            continue
 
-                current_time = asyncio.get_event_loop().time()
-                if current_time - last_heartbeat_time >= _HEARTBEAT_INTERVAL:
-                    # Currently just used to keep the connection busy, refer to
-                    # https://github.com/skypilot-org/skypilot/issues/5750 for
-                    # more details.
-                    yield message_utils.encode_payload(
-                        rich_utils.Control.HEARTBEAT.encode(''))
-                    last_heartbeat_time = current_time
-
-                # Sleep shortly to avoid storming the DB and CPU, this has
-                # little impact on the responsivness here since we are waiting
-                # for a new line to come in.
-                await asyncio.sleep(0.1)
+        # Refresh the heartbeat time, this is a trivial optimization for
+        # performance but it helps avoid unnecessary heartbeat strings
+        # being printed when the client runs in an old version.
+        last_heartbeat_time = asyncio.get_event_loop().time()
+        line_str = line.decode('utf-8')
+        if plain_logs:
+            is_payload, line_str = message_utils.decode_payload(
+                line_str, raise_for_mismatch=False)
+            # TODO(aylei): implement heartbeat mechanism for plain logs,
+            # sending invisible characters might be okay.
+            if is_payload:
                 continue
+        buffer.append(line_str)
+        buffer_bytes += len(line_str.encode('utf-8'))
 
-            # Refresh the heartbeat time, this is a trivial optimization for
-            # performance but it helps avoid unnecessary heartbeat strings
-            # being printed when the client runs in an old version.
-            last_heartbeat_time = asyncio.get_event_loop().time()
-            line_str = line.decode('utf-8')
-            if plain_logs:
-                is_payload, line_str = message_utils.decode_payload(
-                    line_str, raise_for_mismatch=False)
-                # TODO(aylei): implement heartbeat mechanism for plain logs,
-                # sending invisible characters might be okay.
-                if is_payload:
-                    continue
-            yield line_str
+    # Flush remaining lines in the buffer.
+    async for chunk in flush_buffer():
+        yield chunk
 
 
 def stream_response(
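The core of this change: instead of yielding each decoded log line to the response individually, the new _tail_log_file coalesces lines in memory and flushes them as a single chunk once either 8KB has accumulated or 20ms has elapsed since the last flush, trading a small bounded delay for much higher tailing throughput. Below is a minimal, standalone sketch of that size-or-timeout pattern; it is illustrative only (a plain iterator stands in for the aiofiles reader, and the heartbeat and cancellation handling are omitted):

import asyncio
import time
from typing import AsyncGenerator, Iterable, List

_BUFFER_SIZE = 8 * 1024  # flush once 8KB of lines has accumulated...
_BUFFER_TIMEOUT = 0.02  # ...or once 20ms has passed since the last flush


async def buffered_stream(lines: Iterable[str]) -> AsyncGenerator[str, None]:
    """Coalesce lines into chunks, flushing on size or timeout."""
    buffer: List[str] = []
    buffer_bytes = 0
    last_flush = time.monotonic()
    for line in lines:
        buffer.append(line)
        buffer_bytes += len(line.encode('utf-8'))
        now = time.monotonic()
        if buffer_bytes >= _BUFFER_SIZE or now - last_flush >= _BUFFER_TIMEOUT:
            # One yield per chunk instead of one yield per line.
            yield ''.join(buffer)
            buffer.clear()
            buffer_bytes = 0
            last_flush = now
        await asyncio.sleep(0)  # let other coroutines run
    if buffer:
        # Flush whatever is left at EOF.
        yield ''.join(buffer)


async def main() -> None:
    chunks = 0
    async for _ in buffered_stream(f'line {i}\n' for i in range(10000)):
        chunks += 1
    print(f'10000 lines delivered in {chunks} chunks')


asyncio.run(main())

The 20ms timeout bounds the extra latency a client can observe, while the 8KB cap bounds the memory held per flush.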
sky/skylet/constants.py CHANGED
@@ -377,8 +377,7 @@ OVERRIDEABLE_CONFIG_KEYS_IN_TASK: List[Tuple[str, ...]] = [
 ]
 # When overriding the SkyPilot configs on the API server with the client one,
 # we skip the following keys because they are meant to be client-side configs.
-SKIPPED_CLIENT_OVERRIDE_KEYS: List[Tuple[str, ...]] = [('admin_policy',),
-                                                       ('api_server',),
+SKIPPED_CLIENT_OVERRIDE_KEYS: List[Tuple[str, ...]] = [('api_server',),
                                                        ('allowed_clouds',),
                                                        ('workspaces',), ('db',)]
 
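The only semantic change here is dropping ('admin_policy',) from the skip list, so a client-supplied admin_policy section is no longer filtered out when the client config overrides the server's (in line with the admin_policy.py and admin_policy_utils.py changes listed above). A hypothetical sketch of how such a skip list is typically applied, assuming a simple nested-dict config; drop_skipped is illustrative and not SkyPilot's actual merge code:

import copy
from typing import Any, Dict, List, Tuple

SKIPPED_CLIENT_OVERRIDE_KEYS: List[Tuple[str, ...]] = [('api_server',),
                                                       ('allowed_clouds',),
                                                       ('workspaces',), ('db',)]


def drop_skipped(client_config: Dict[str, Any]) -> Dict[str, Any]:
    """Remove client-side-only keys before overriding the server config."""
    config = copy.deepcopy(client_config)
    for key_path in SKIPPED_CLIENT_OVERRIDE_KEYS:
        node = config
        for part in key_path[:-1]:
            node = node.get(part, {})
        node.pop(key_path[-1], None)
    return config


# With ('admin_policy',) gone from the skip list, a client's admin_policy
# now survives the filter and reaches the server:
print(drop_skipped({'admin_policy': 'mypolicy.Policy', 'db': 'postgres://...'}))
# -> {'admin_policy': 'mypolicy.Policy'}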
sky/skylet/job_lib.py CHANGED
@@ -3,6 +3,7 @@
 This is a remote utility module that provides job queue functionality.
 """
 import enum
+import functools
 import getpass
 import json
 import os
@@ -10,6 +11,7 @@ import pathlib
 import shlex
 import signal
 import sqlite3
+import threading
 import time
 import typing
 from typing import Any, Dict, List, Optional, Sequence
@@ -119,9 +121,25 @@ def create_table(cursor, conn):
     conn.commit()
 
 
-_DB = db_utils.SQLiteConn(_DB_PATH, create_table)
-_CURSOR = _DB.cursor
-_CONN = _DB.conn
+_DB = None
+_db_init_lock = threading.Lock()
+
+
+def init_db(func):
+    """Initialize the database."""
+
+    @functools.wraps(func)
+    def wrapper(*args, **kwargs):
+        global _DB
+        if _DB is not None:
+            return func(*args, **kwargs)
+
+        with _db_init_lock:
+            if _DB is None:
+                _DB = db_utils.SQLiteConn(_DB_PATH, create_table)
+        return func(*args, **kwargs)
+
+    return wrapper
 
 
 class JobStatus(enum.Enum):
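This init_db decorator is the recurring pattern in the rest of this file's diff: the module-level _DB/_CURSOR/_CONN created at import time are replaced by a connection created lazily on first use, with double-checked locking so concurrent first calls initialize it exactly once, and every DB-touching function below gains @init_db plus an assert. A standalone sketch of the pattern, using plain sqlite3 in place of db_utils.SQLiteConn (names and schema here are illustrative):

import functools
import sqlite3
import threading

_DB = None
_db_init_lock = threading.Lock()


def init_db(func):
    """Decorator: lazily create the DB connection before running func."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        global _DB
        if _DB is not None:  # fast path: no locking once initialized
            return func(*args, **kwargs)
        with _db_init_lock:
            if _DB is None:  # double-check under the lock
                _DB = sqlite3.connect(':memory:', check_same_thread=False)
                _DB.execute('CREATE TABLE IF NOT EXISTS jobs (job_id INTEGER)')
        return func(*args, **kwargs)

    return wrapper


@init_db
def count_jobs() -> int:
    assert _DB is not None  # mirrors the asserts added throughout job_lib.py
    return _DB.execute('SELECT COUNT(*) FROM jobs').fetchone()[0]


print(count_jobs())  # the connection is created here, not at import time

Deferring the connection to first use avoids opening (and possibly creating) the SQLite database as an import side effect, while the fast path keeps the per-call overhead to a single None check once initialization has happened.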
@@ -210,30 +228,37 @@ _PRE_RESOURCE_STATUSES = [JobStatus.PENDING]
 class JobScheduler:
     """Base class for job scheduler"""
 
+    @init_db
     def queue(self, job_id: int, cmd: str) -> None:
-        _CURSOR.execute('INSERT INTO pending_jobs VALUES (?,?,?,?)',
-                        (job_id, cmd, 0, int(time.time())))
-        _CONN.commit()
+        assert _DB is not None
+        _DB.cursor.execute('INSERT INTO pending_jobs VALUES (?,?,?,?)',
+                           (job_id, cmd, 0, int(time.time())))
+        _DB.conn.commit()
         set_status(job_id, JobStatus.PENDING)
         self.schedule_step()
 
+    @init_db
     def remove_job_no_lock(self, job_id: int) -> None:
-        _CURSOR.execute(f'DELETE FROM pending_jobs WHERE job_id={job_id!r}')
-        _CONN.commit()
+        assert _DB is not None
+        _DB.cursor.execute(f'DELETE FROM pending_jobs WHERE job_id={job_id!r}')
+        _DB.conn.commit()
 
+    @init_db
     def _run_job(self, job_id: int, run_cmd: str):
-        _CURSOR.execute((f'UPDATE pending_jobs SET submit={int(time.time())} '
-                         f'WHERE job_id={job_id!r}'))
-        _CONN.commit()
+        assert _DB is not None
+        _DB.cursor.execute(
+            (f'UPDATE pending_jobs SET submit={int(time.time())} '
+             f'WHERE job_id={job_id!r}'))
+        _DB.conn.commit()
         pid = subprocess_utils.launch_new_process_tree(run_cmd)
         # TODO(zhwu): Backward compatibility, remove this check after 0.10.0.
         # This is for the case where the job is submitted with SkyPilot older
         # than #4318, using ray job submit.
         if 'job submit' in run_cmd:
             pid = -1
-        _CURSOR.execute((f'UPDATE jobs SET pid={pid} '
-                         f'WHERE job_id={job_id!r}'))
-        _CONN.commit()
+        _DB.cursor.execute((f'UPDATE jobs SET pid={pid} '
+                            f'WHERE job_id={job_id!r}'))
+        _DB.conn.commit()
 
     def schedule_step(self, force_update_jobs: bool = False) -> None:
         if force_update_jobs:
@@ -282,8 +307,10 @@ class JobScheduler:
 class FIFOScheduler(JobScheduler):
     """First in first out job scheduler"""
 
+    @init_db
     def _get_pending_job_ids(self) -> List[int]:
-        rows = _CURSOR.execute(
+        assert _DB is not None
+        rows = _DB.cursor.execute(
             'SELECT job_id FROM pending_jobs ORDER BY job_id').fetchall()
         return [row[0] for row in rows]
 
@@ -308,26 +335,30 @@ def make_job_command_with_user_switching(username: str,
     return ['sudo', '-H', 'su', '--login', username, '-c', command]
 
 
+@init_db
 def add_job(job_name: str, username: str, run_timestamp: str,
             resources_str: str) -> int:
     """Atomically reserve the next available job id for the user."""
+    assert _DB is not None
     job_submitted_at = time.time()
     # job_id will autoincrement with the null value
-    _CURSOR.execute(
+    _DB.cursor.execute(
         'INSERT INTO jobs VALUES (null, ?, ?, ?, ?, ?, ?, null, ?, 0)',
         (job_name, username, job_submitted_at, JobStatus.INIT.value,
         run_timestamp, None, resources_str))
-    _CONN.commit()
-    rows = _CURSOR.execute('SELECT job_id FROM jobs WHERE run_timestamp=(?)',
-                           (run_timestamp,))
+    _DB.conn.commit()
+    rows = _DB.cursor.execute('SELECT job_id FROM jobs WHERE run_timestamp=(?)',
+                              (run_timestamp,))
     for row in rows:
         job_id = row[0]
     assert job_id is not None
     return job_id
 
 
+@init_db
 def _set_status_no_lock(job_id: int, status: JobStatus) -> None:
     """Setting the status of the job in the database."""
+    assert _DB is not None
     assert status != JobStatus.RUNNING, (
         'Please use set_job_started() to set job status to RUNNING')
     if status.is_terminal():
@@ -339,15 +370,15 @@ def _set_status_no_lock(job_id: int, status: JobStatus) -> None:
         check_end_at_str = ' AND end_at IS NULL'
         if status != JobStatus.FAILED_SETUP:
             check_end_at_str = ''
-        _CURSOR.execute(
+        _DB.cursor.execute(
            'UPDATE jobs SET status=(?), end_at=(?) '
            f'WHERE job_id=(?) {check_end_at_str}',
            (status.value, end_at, job_id))
     else:
-        _CURSOR.execute(
+        _DB.cursor.execute(
            'UPDATE jobs SET status=(?), end_at=NULL '
            'WHERE job_id=(?)', (status.value, job_id))
-    _CONN.commit()
+    _DB.conn.commit()
 
 
 def set_status(job_id: int, status: JobStatus) -> None:
@@ -357,16 +388,19 @@ def set_status(job_id: int, status: JobStatus) -> None:
         _set_status_no_lock(job_id, status)
 
 
+@init_db
 def set_job_started(job_id: int) -> None:
     # TODO(mraheja): remove pylint disabling when filelock version updated.
     # pylint: disable=abstract-class-instantiated
+    assert _DB is not None
     with filelock.FileLock(_get_lock_path(job_id)):
-        _CURSOR.execute(
+        _DB.cursor.execute(
            'UPDATE jobs SET status=(?), start_at=(?), end_at=NULL '
            'WHERE job_id=(?)', (JobStatus.RUNNING.value, time.time(), job_id))
-        _CONN.commit()
+        _DB.conn.commit()
 
 
+@init_db
 def get_status_no_lock(job_id: int) -> Optional[JobStatus]:
     """Get the status of the job with the given id.
 
@@ -375,8 +409,9 @@ def get_status_no_lock(job_id: int) -> Optional[JobStatus]:
     the status in a while loop as in `log_lib._follow_job_logs`. Otherwise, use
     `get_status`.
     """
-    rows = _CURSOR.execute('SELECT status FROM jobs WHERE job_id=(?)',
-                           (job_id,))
+    assert _DB is not None
+    rows = _DB.cursor.execute('SELECT status FROM jobs WHERE job_id=(?)',
+                              (job_id,))
     for (status,) in rows:
         if status is None:
             return None
@@ -391,11 +426,13 @@ def get_status(job_id: int) -> Optional[JobStatus]:
         return get_status_no_lock(job_id)
 
 
+@init_db
 def get_statuses_payload(job_ids: List[Optional[int]]) -> str:
+    assert _DB is not None
     # Per-job lock is not required here, since the staled job status will not
     # affect the caller.
     query_str = ','.join(['?'] * len(job_ids))
-    rows = _CURSOR.execute(
+    rows = _DB.cursor.execute(
        f'SELECT job_id, status FROM jobs WHERE job_id IN ({query_str})',
        job_ids)
     statuses = {job_id: None for job_id in job_ids}
@@ -419,14 +456,17 @@ def load_statuses_payload(
     return statuses
 
 
+@init_db
 def get_latest_job_id() -> Optional[int]:
-    rows = _CURSOR.execute(
+    assert _DB is not None
+    rows = _DB.cursor.execute(
        'SELECT job_id FROM jobs ORDER BY job_id DESC LIMIT 1')
     for (job_id,) in rows:
         return job_id
     return None
 
 
+@init_db
 def get_job_submitted_or_ended_timestamp_payload(job_id: int,
                                                  get_ended_time: bool) -> str:
     """Get the job submitted/ended timestamp.
@@ -440,9 +480,10 @@ def get_job_submitted_or_ended_timestamp_payload(job_id: int,
     `format_job_queue()`), because the job may stay in PENDING if the cluster is
     busy.
     """
+    assert _DB is not None
     field = 'end_at' if get_ended_time else 'submitted_at'
-    rows = _CURSOR.execute(f'SELECT {field} FROM jobs WHERE job_id=(?)',
-                           (job_id,))
+    rows = _DB.cursor.execute(f'SELECT {field} FROM jobs WHERE job_id=(?)',
+                              (job_id,))
     for (timestamp,) in rows:
         return message_utils.encode_payload(timestamp)
     return message_utils.encode_payload(None)
@@ -496,10 +537,12 @@ def _get_records_from_rows(rows) -> List[Dict[str, Any]]:
     return records
 
 
+@init_db
 def _get_jobs(
         user_hash: Optional[str],
         status_list: Optional[List[JobStatus]] = None) -> List[Dict[str, Any]]:
     """Returns jobs with the given fields, sorted by job_id, descending."""
+    assert _DB is not None
     if status_list is None:
         status_list = list(JobStatus)
     status_str_list = [repr(status.value) for status in status_list]
@@ -509,14 +552,16 @@ def _get_jobs(
         # We use the old username field for compatibility.
         filter_str += ' AND username=(?)'
         params.append(user_hash)
-    rows = _CURSOR.execute(
+    rows = _DB.cursor.execute(
        f'SELECT * FROM jobs {filter_str} ORDER BY job_id DESC', params)
     records = _get_records_from_rows(rows)
     return records
 
 
+@init_db
 def _get_jobs_by_ids(job_ids: List[int]) -> List[Dict[str, Any]]:
-    rows = _CURSOR.execute(
+    assert _DB is not None
+    rows = _DB.cursor.execute(
        f"""\
        SELECT * FROM jobs
        WHERE job_id IN ({','.join(['?'] * len(job_ids))})
@@ -527,8 +572,10 @@ def _get_jobs_by_ids(job_ids: List[int]) -> List[Dict[str, Any]]:
     return records
 
 
+@init_db
 def _get_pending_job(job_id: int) -> Optional[Dict[str, Any]]:
-    rows = _CURSOR.execute(
+    assert _DB is not None
+    rows = _DB.cursor.execute(
        'SELECT created_time, submit, run_cmd FROM pending_jobs '
        f'WHERE job_id={job_id!r}')
     for row in rows:
@@ -698,16 +745,18 @@ def update_job_status(job_ids: List[int],
     return statuses
 
 
+@init_db
 def fail_all_jobs_in_progress() -> None:
+    assert _DB is not None
     in_progress_status = [
         status.value for status in JobStatus.nonterminal_statuses()
     ]
-    _CURSOR.execute(
+    _DB.cursor.execute(
        f"""\
        UPDATE jobs SET status=(?)
        WHERE status IN ({','.join(['?'] * len(in_progress_status))})
        """, (JobStatus.FAILED_DRIVER.value, *in_progress_status))
-    _CONN.commit()
+    _DB.conn.commit()
 
 
 def update_status() -> None:
@@ -720,12 +769,14 @@ def update_status() -> None:
     update_job_status(nonterminal_job_ids)
 
 
+@init_db
 def is_cluster_idle() -> bool:
     """Returns if the cluster is idle (no in-flight jobs)."""
+    assert _DB is not None
     in_progress_status = [
         status.value for status in JobStatus.nonterminal_statuses()
     ]
-    rows = _CURSOR.execute(
+    rows = _DB.cursor.execute(
        f"""\
        SELECT COUNT(*) FROM jobs
        WHERE status IN ({','.join(['?'] * len(in_progress_status))})
@@ -905,27 +956,31 @@ def cancel_jobs_encoded_results(jobs: Optional[List[int]],
     return message_utils.encode_payload(cancelled_ids)
 
 
+@init_db
 def get_run_timestamp(job_id: Optional[int]) -> Optional[str]:
     """Returns the relative path to the log file for a job."""
-    _CURSOR.execute(
+    assert _DB is not None
+    _DB.cursor.execute(
        """\
        SELECT * FROM jobs
        WHERE job_id=(?)""", (job_id,))
-    row = _CURSOR.fetchone()
+    row = _DB.cursor.fetchone()
     if row is None:
         return None
     run_timestamp = row[JobInfoLoc.RUN_TIMESTAMP.value]
     return run_timestamp
 
 
+@init_db
 def run_timestamp_with_globbing_payload(job_ids: List[Optional[str]]) -> str:
     """Returns the relative paths to the log files for job with globbing."""
+    assert _DB is not None
     query_str = ' OR '.join(['job_id GLOB (?)'] * len(job_ids))
-    _CURSOR.execute(
+    _DB.cursor.execute(
        f"""\
        SELECT * FROM jobs
        WHERE {query_str}""", job_ids)
-    rows = _CURSOR.fetchall()
+    rows = _DB.cursor.fetchall()
     run_timestamps = {}
     for row in rows:
         job_id = row[JobInfoLoc.JOB_ID.value]