skypilot-nightly 1.0.0.dev2024053101__py3-none-any.whl → 1.0.0.dev2025022801__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (299)
  1. sky/__init__.py +64 -32
  2. sky/adaptors/aws.py +23 -6
  3. sky/adaptors/azure.py +432 -15
  4. sky/adaptors/cloudflare.py +5 -5
  5. sky/adaptors/common.py +19 -9
  6. sky/adaptors/do.py +20 -0
  7. sky/adaptors/gcp.py +3 -2
  8. sky/adaptors/kubernetes.py +122 -88
  9. sky/adaptors/nebius.py +100 -0
  10. sky/adaptors/oci.py +39 -1
  11. sky/adaptors/vast.py +29 -0
  12. sky/admin_policy.py +101 -0
  13. sky/authentication.py +117 -98
  14. sky/backends/backend.py +52 -20
  15. sky/backends/backend_utils.py +669 -557
  16. sky/backends/cloud_vm_ray_backend.py +1099 -808
  17. sky/backends/local_docker_backend.py +14 -8
  18. sky/backends/wheel_utils.py +38 -20
  19. sky/benchmark/benchmark_utils.py +22 -23
  20. sky/check.py +76 -27
  21. sky/cli.py +1586 -1139
  22. sky/client/__init__.py +1 -0
  23. sky/client/cli.py +5683 -0
  24. sky/client/common.py +345 -0
  25. sky/client/sdk.py +1765 -0
  26. sky/cloud_stores.py +283 -19
  27. sky/clouds/__init__.py +7 -2
  28. sky/clouds/aws.py +303 -112
  29. sky/clouds/azure.py +185 -179
  30. sky/clouds/cloud.py +115 -37
  31. sky/clouds/cudo.py +29 -22
  32. sky/clouds/do.py +313 -0
  33. sky/clouds/fluidstack.py +44 -54
  34. sky/clouds/gcp.py +206 -65
  35. sky/clouds/ibm.py +26 -21
  36. sky/clouds/kubernetes.py +345 -91
  37. sky/clouds/lambda_cloud.py +40 -29
  38. sky/clouds/nebius.py +297 -0
  39. sky/clouds/oci.py +129 -90
  40. sky/clouds/paperspace.py +22 -18
  41. sky/clouds/runpod.py +53 -34
  42. sky/clouds/scp.py +28 -24
  43. sky/clouds/service_catalog/__init__.py +19 -13
  44. sky/clouds/service_catalog/aws_catalog.py +29 -12
  45. sky/clouds/service_catalog/azure_catalog.py +33 -6
  46. sky/clouds/service_catalog/common.py +95 -75
  47. sky/clouds/service_catalog/constants.py +3 -3
  48. sky/clouds/service_catalog/cudo_catalog.py +13 -3
  49. sky/clouds/service_catalog/data_fetchers/fetch_aws.py +36 -21
  50. sky/clouds/service_catalog/data_fetchers/fetch_azure.py +31 -4
  51. sky/clouds/service_catalog/data_fetchers/fetch_cudo.py +8 -117
  52. sky/clouds/service_catalog/data_fetchers/fetch_fluidstack.py +197 -44
  53. sky/clouds/service_catalog/data_fetchers/fetch_gcp.py +224 -36
  54. sky/clouds/service_catalog/data_fetchers/fetch_lambda_cloud.py +44 -24
  55. sky/clouds/service_catalog/data_fetchers/fetch_vast.py +147 -0
  56. sky/clouds/service_catalog/data_fetchers/fetch_vsphere.py +1 -1
  57. sky/clouds/service_catalog/do_catalog.py +111 -0
  58. sky/clouds/service_catalog/fluidstack_catalog.py +2 -2
  59. sky/clouds/service_catalog/gcp_catalog.py +16 -2
  60. sky/clouds/service_catalog/ibm_catalog.py +2 -2
  61. sky/clouds/service_catalog/kubernetes_catalog.py +192 -70
  62. sky/clouds/service_catalog/lambda_catalog.py +8 -3
  63. sky/clouds/service_catalog/nebius_catalog.py +116 -0
  64. sky/clouds/service_catalog/oci_catalog.py +31 -4
  65. sky/clouds/service_catalog/paperspace_catalog.py +2 -2
  66. sky/clouds/service_catalog/runpod_catalog.py +2 -2
  67. sky/clouds/service_catalog/scp_catalog.py +2 -2
  68. sky/clouds/service_catalog/vast_catalog.py +104 -0
  69. sky/clouds/service_catalog/vsphere_catalog.py +2 -2
  70. sky/clouds/utils/aws_utils.py +65 -0
  71. sky/clouds/utils/azure_utils.py +91 -0
  72. sky/clouds/utils/gcp_utils.py +5 -9
  73. sky/clouds/utils/oci_utils.py +47 -5
  74. sky/clouds/utils/scp_utils.py +4 -3
  75. sky/clouds/vast.py +280 -0
  76. sky/clouds/vsphere.py +22 -18
  77. sky/core.py +361 -107
  78. sky/dag.py +41 -28
  79. sky/data/data_transfer.py +37 -0
  80. sky/data/data_utils.py +211 -32
  81. sky/data/mounting_utils.py +182 -30
  82. sky/data/storage.py +2118 -270
  83. sky/data/storage_utils.py +126 -5
  84. sky/exceptions.py +179 -8
  85. sky/execution.py +158 -85
  86. sky/global_user_state.py +150 -34
  87. sky/jobs/__init__.py +12 -10
  88. sky/jobs/client/__init__.py +0 -0
  89. sky/jobs/client/sdk.py +302 -0
  90. sky/jobs/constants.py +49 -11
  91. sky/jobs/controller.py +161 -99
  92. sky/jobs/dashboard/dashboard.py +171 -25
  93. sky/jobs/dashboard/templates/index.html +572 -60
  94. sky/jobs/recovery_strategy.py +157 -156
  95. sky/jobs/scheduler.py +307 -0
  96. sky/jobs/server/__init__.py +1 -0
  97. sky/jobs/server/core.py +598 -0
  98. sky/jobs/server/dashboard_utils.py +69 -0
  99. sky/jobs/server/server.py +190 -0
  100. sky/jobs/state.py +627 -122
  101. sky/jobs/utils.py +615 -206
  102. sky/models.py +27 -0
  103. sky/optimizer.py +142 -83
  104. sky/provision/__init__.py +20 -5
  105. sky/provision/aws/config.py +124 -42
  106. sky/provision/aws/instance.py +130 -53
  107. sky/provision/azure/__init__.py +7 -0
  108. sky/{skylet/providers → provision}/azure/azure-config-template.json +19 -7
  109. sky/provision/azure/config.py +220 -0
  110. sky/provision/azure/instance.py +1012 -37
  111. sky/provision/common.py +31 -3
  112. sky/provision/constants.py +25 -0
  113. sky/provision/cudo/__init__.py +2 -1
  114. sky/provision/cudo/cudo_utils.py +112 -0
  115. sky/provision/cudo/cudo_wrapper.py +37 -16
  116. sky/provision/cudo/instance.py +28 -12
  117. sky/provision/do/__init__.py +11 -0
  118. sky/provision/do/config.py +14 -0
  119. sky/provision/do/constants.py +10 -0
  120. sky/provision/do/instance.py +287 -0
  121. sky/provision/do/utils.py +301 -0
  122. sky/provision/docker_utils.py +82 -46
  123. sky/provision/fluidstack/fluidstack_utils.py +57 -125
  124. sky/provision/fluidstack/instance.py +15 -43
  125. sky/provision/gcp/config.py +19 -9
  126. sky/provision/gcp/constants.py +7 -1
  127. sky/provision/gcp/instance.py +55 -34
  128. sky/provision/gcp/instance_utils.py +339 -80
  129. sky/provision/gcp/mig_utils.py +210 -0
  130. sky/provision/instance_setup.py +172 -133
  131. sky/provision/kubernetes/__init__.py +1 -0
  132. sky/provision/kubernetes/config.py +104 -90
  133. sky/provision/kubernetes/constants.py +8 -0
  134. sky/provision/kubernetes/instance.py +680 -325
  135. sky/provision/kubernetes/manifests/smarter-device-manager-daemonset.yaml +3 -0
  136. sky/provision/kubernetes/network.py +54 -20
  137. sky/provision/kubernetes/network_utils.py +70 -21
  138. sky/provision/kubernetes/utils.py +1370 -251
  139. sky/provision/lambda_cloud/__init__.py +11 -0
  140. sky/provision/lambda_cloud/config.py +10 -0
  141. sky/provision/lambda_cloud/instance.py +265 -0
  142. sky/{clouds/utils → provision/lambda_cloud}/lambda_utils.py +24 -23
  143. sky/provision/logging.py +1 -1
  144. sky/provision/nebius/__init__.py +11 -0
  145. sky/provision/nebius/config.py +11 -0
  146. sky/provision/nebius/instance.py +285 -0
  147. sky/provision/nebius/utils.py +318 -0
  148. sky/provision/oci/__init__.py +15 -0
  149. sky/provision/oci/config.py +51 -0
  150. sky/provision/oci/instance.py +436 -0
  151. sky/provision/oci/query_utils.py +681 -0
  152. sky/provision/paperspace/constants.py +6 -0
  153. sky/provision/paperspace/instance.py +4 -3
  154. sky/provision/paperspace/utils.py +2 -0
  155. sky/provision/provisioner.py +207 -130
  156. sky/provision/runpod/__init__.py +1 -0
  157. sky/provision/runpod/api/__init__.py +3 -0
  158. sky/provision/runpod/api/commands.py +119 -0
  159. sky/provision/runpod/api/pods.py +142 -0
  160. sky/provision/runpod/instance.py +64 -8
  161. sky/provision/runpod/utils.py +239 -23
  162. sky/provision/vast/__init__.py +10 -0
  163. sky/provision/vast/config.py +11 -0
  164. sky/provision/vast/instance.py +247 -0
  165. sky/provision/vast/utils.py +162 -0
  166. sky/provision/vsphere/common/vim_utils.py +1 -1
  167. sky/provision/vsphere/instance.py +8 -18
  168. sky/provision/vsphere/vsphere_utils.py +1 -1
  169. sky/resources.py +247 -102
  170. sky/serve/__init__.py +9 -9
  171. sky/serve/autoscalers.py +361 -299
  172. sky/serve/client/__init__.py +0 -0
  173. sky/serve/client/sdk.py +366 -0
  174. sky/serve/constants.py +12 -3
  175. sky/serve/controller.py +106 -36
  176. sky/serve/load_balancer.py +63 -12
  177. sky/serve/load_balancing_policies.py +84 -2
  178. sky/serve/replica_managers.py +42 -34
  179. sky/serve/serve_state.py +62 -32
  180. sky/serve/serve_utils.py +271 -160
  181. sky/serve/server/__init__.py +0 -0
  182. sky/serve/{core.py → server/core.py} +271 -90
  183. sky/serve/server/server.py +112 -0
  184. sky/serve/service.py +52 -16
  185. sky/serve/service_spec.py +95 -32
  186. sky/server/__init__.py +1 -0
  187. sky/server/common.py +430 -0
  188. sky/server/constants.py +21 -0
  189. sky/server/html/log.html +174 -0
  190. sky/server/requests/__init__.py +0 -0
  191. sky/server/requests/executor.py +472 -0
  192. sky/server/requests/payloads.py +487 -0
  193. sky/server/requests/queues/__init__.py +0 -0
  194. sky/server/requests/queues/mp_queue.py +76 -0
  195. sky/server/requests/requests.py +567 -0
  196. sky/server/requests/serializers/__init__.py +0 -0
  197. sky/server/requests/serializers/decoders.py +192 -0
  198. sky/server/requests/serializers/encoders.py +166 -0
  199. sky/server/server.py +1106 -0
  200. sky/server/stream_utils.py +141 -0
  201. sky/setup_files/MANIFEST.in +2 -5
  202. sky/setup_files/dependencies.py +159 -0
  203. sky/setup_files/setup.py +14 -125
  204. sky/sky_logging.py +59 -14
  205. sky/skylet/autostop_lib.py +2 -2
  206. sky/skylet/constants.py +183 -50
  207. sky/skylet/events.py +22 -10
  208. sky/skylet/job_lib.py +403 -258
  209. sky/skylet/log_lib.py +111 -71
  210. sky/skylet/log_lib.pyi +6 -0
  211. sky/skylet/providers/command_runner.py +6 -8
  212. sky/skylet/providers/ibm/node_provider.py +2 -2
  213. sky/skylet/providers/scp/config.py +11 -3
  214. sky/skylet/providers/scp/node_provider.py +8 -8
  215. sky/skylet/skylet.py +3 -1
  216. sky/skylet/subprocess_daemon.py +69 -17
  217. sky/skypilot_config.py +119 -57
  218. sky/task.py +205 -64
  219. sky/templates/aws-ray.yml.j2 +37 -7
  220. sky/templates/azure-ray.yml.j2 +27 -82
  221. sky/templates/cudo-ray.yml.j2 +7 -3
  222. sky/templates/do-ray.yml.j2 +98 -0
  223. sky/templates/fluidstack-ray.yml.j2 +7 -4
  224. sky/templates/gcp-ray.yml.j2 +26 -6
  225. sky/templates/ibm-ray.yml.j2 +3 -2
  226. sky/templates/jobs-controller.yaml.j2 +46 -11
  227. sky/templates/kubernetes-ingress.yml.j2 +7 -0
  228. sky/templates/kubernetes-loadbalancer.yml.j2 +7 -0
  229. sky/templates/{kubernetes-port-forward-proxy-command.sh.j2 → kubernetes-port-forward-proxy-command.sh} +51 -7
  230. sky/templates/kubernetes-ray.yml.j2 +292 -25
  231. sky/templates/lambda-ray.yml.j2 +30 -40
  232. sky/templates/nebius-ray.yml.j2 +79 -0
  233. sky/templates/oci-ray.yml.j2 +18 -57
  234. sky/templates/paperspace-ray.yml.j2 +10 -6
  235. sky/templates/runpod-ray.yml.j2 +26 -4
  236. sky/templates/scp-ray.yml.j2 +3 -2
  237. sky/templates/sky-serve-controller.yaml.j2 +12 -1
  238. sky/templates/skypilot-server-kubernetes-proxy.sh +36 -0
  239. sky/templates/vast-ray.yml.j2 +70 -0
  240. sky/templates/vsphere-ray.yml.j2 +8 -3
  241. sky/templates/websocket_proxy.py +64 -0
  242. sky/usage/constants.py +10 -1
  243. sky/usage/usage_lib.py +130 -37
  244. sky/utils/accelerator_registry.py +35 -51
  245. sky/utils/admin_policy_utils.py +147 -0
  246. sky/utils/annotations.py +51 -0
  247. sky/utils/cli_utils/status_utils.py +81 -23
  248. sky/utils/cluster_utils.py +356 -0
  249. sky/utils/command_runner.py +452 -89
  250. sky/utils/command_runner.pyi +77 -3
  251. sky/utils/common.py +54 -0
  252. sky/utils/common_utils.py +319 -108
  253. sky/utils/config_utils.py +204 -0
  254. sky/utils/control_master_utils.py +48 -0
  255. sky/utils/controller_utils.py +548 -266
  256. sky/utils/dag_utils.py +93 -32
  257. sky/utils/db_utils.py +18 -4
  258. sky/utils/env_options.py +29 -7
  259. sky/utils/kubernetes/create_cluster.sh +8 -60
  260. sky/utils/kubernetes/deploy_remote_cluster.sh +243 -0
  261. sky/utils/kubernetes/exec_kubeconfig_converter.py +73 -0
  262. sky/utils/kubernetes/generate_kubeconfig.sh +336 -0
  263. sky/utils/kubernetes/gpu_labeler.py +4 -4
  264. sky/utils/kubernetes/k8s_gpu_labeler_job.yaml +4 -3
  265. sky/utils/kubernetes/kubernetes_deploy_utils.py +228 -0
  266. sky/utils/kubernetes/rsync_helper.sh +24 -0
  267. sky/utils/kubernetes/ssh_jump_lifecycle_manager.py +1 -1
  268. sky/utils/log_utils.py +240 -33
  269. sky/utils/message_utils.py +81 -0
  270. sky/utils/registry.py +127 -0
  271. sky/utils/resources_utils.py +94 -22
  272. sky/utils/rich_utils.py +247 -18
  273. sky/utils/schemas.py +284 -64
  274. sky/{status_lib.py → utils/status_lib.py} +12 -7
  275. sky/utils/subprocess_utils.py +212 -46
  276. sky/utils/timeline.py +12 -7
  277. sky/utils/ux_utils.py +168 -15
  278. skypilot_nightly-1.0.0.dev2025022801.dist-info/METADATA +363 -0
  279. skypilot_nightly-1.0.0.dev2025022801.dist-info/RECORD +352 -0
  280. {skypilot_nightly-1.0.0.dev2024053101.dist-info → skypilot_nightly-1.0.0.dev2025022801.dist-info}/WHEEL +1 -1
  281. sky/clouds/cloud_registry.py +0 -31
  282. sky/jobs/core.py +0 -330
  283. sky/skylet/providers/azure/__init__.py +0 -2
  284. sky/skylet/providers/azure/azure-vm-template.json +0 -301
  285. sky/skylet/providers/azure/config.py +0 -170
  286. sky/skylet/providers/azure/node_provider.py +0 -466
  287. sky/skylet/providers/lambda_cloud/__init__.py +0 -2
  288. sky/skylet/providers/lambda_cloud/node_provider.py +0 -320
  289. sky/skylet/providers/oci/__init__.py +0 -2
  290. sky/skylet/providers/oci/node_provider.py +0 -488
  291. sky/skylet/providers/oci/query_helper.py +0 -383
  292. sky/skylet/providers/oci/utils.py +0 -21
  293. sky/utils/cluster_yaml_utils.py +0 -24
  294. sky/utils/kubernetes/generate_static_kubeconfig.sh +0 -137
  295. skypilot_nightly-1.0.0.dev2024053101.dist-info/METADATA +0 -315
  296. skypilot_nightly-1.0.0.dev2024053101.dist-info/RECORD +0 -275
  297. {skypilot_nightly-1.0.0.dev2024053101.dist-info → skypilot_nightly-1.0.0.dev2025022801.dist-info}/LICENSE +0 -0
  298. {skypilot_nightly-1.0.0.dev2024053101.dist-info → skypilot_nightly-1.0.0.dev2025022801.dist-info}/entry_points.txt +0 -0
  299. {skypilot_nightly-1.0.0.dev2024053101.dist-info → skypilot_nightly-1.0.0.dev2025022801.dist-info}/top_level.txt +0 -0
sky/jobs/state.py CHANGED
@@ -2,6 +2,7 @@
2
2
  # TODO(zhwu): maybe use file-based status instead of a database, so
3
3
  # that we can easily switch to an S3-based storage.
4
4
  import enum
5
+ import json
5
6
  import pathlib
6
7
  import sqlite3
7
8
  import time
@@ -10,7 +11,9 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
10
11
 
11
12
  import colorama
12
13
 
14
+ from sky import exceptions
13
15
  from sky import sky_logging
16
+ from sky.utils import common_utils
14
17
  from sky.utils import db_utils
15
18
 
16
19
  if typing.TYPE_CHECKING:
@@ -20,15 +23,6 @@ CallbackType = Callable[[str], None]
20
23
 
21
24
  logger = sky_logging.init_logger(__name__)
22
25
 
23
- _DB_PATH = pathlib.Path('~/.sky/spot_jobs.db')
24
- _DB_PATH = _DB_PATH.expanduser().absolute()
25
- _DB_PATH.parents[0].mkdir(parents=True, exist_ok=True)
26
- _DB_PATH = str(_DB_PATH)
27
-
28
- # Module-level connection/cursor; thread-safe as the module is only imported
29
- # once.
30
- _CONN = sqlite3.connect(_DB_PATH)
31
- _CURSOR = _CONN.cursor()
32
26
 
33
27
  # === Database schema ===
34
28
  # `spot` table contains all the finest-grained tasks, including all the
@@ -39,58 +33,124 @@ _CURSOR = _CONN.cursor()
39
33
  # the same content as the `task_name` column.
40
34
  # The `job_id` is now not really a job id, but only a unique
41
35
  # identifier/primary key for all the tasks. We will use `spot_job_id`
42
- # to identify the spot job.
36
+ # to identify the job.
43
37
  # TODO(zhwu): schema migration may be needed.
44
- _CURSOR.execute("""\
45
- CREATE TABLE IF NOT EXISTS spot (
46
- job_id INTEGER PRIMARY KEY AUTOINCREMENT,
47
- job_name TEXT,
48
- resources TEXT,
49
- submitted_at FLOAT,
50
- status TEXT,
51
- run_timestamp TEXT CANDIDATE KEY,
52
- start_at FLOAT DEFAULT NULL,
53
- end_at FLOAT DEFAULT NULL,
54
- last_recovered_at FLOAT DEFAULT -1,
55
- recovery_count INTEGER DEFAULT 0,
56
- job_duration FLOAT DEFAULT 0,
57
- failure_reason TEXT,
58
- spot_job_id INTEGER,
59
- task_id INTEGER DEFAULT 0,
60
- task_name TEXT)""")
61
- _CONN.commit()
62
-
63
- db_utils.add_column_to_table(_CURSOR, _CONN, 'spot', 'failure_reason', 'TEXT')
64
- # Create a new column `spot_job_id`, which is the same for tasks of the
65
- # same managed job.
66
- # The original `job_id` no longer has an actual meaning, but only a legacy
67
- # identifier for all tasks in database.
68
- db_utils.add_column_to_table(_CURSOR,
69
- _CONN,
70
- 'spot',
71
- 'spot_job_id',
72
- 'INTEGER',
73
- copy_from='job_id')
74
- db_utils.add_column_to_table(_CURSOR,
75
- _CONN,
76
- 'spot',
77
- 'task_id',
78
- 'INTEGER DEFAULT 0',
79
- value_to_replace_existing_entries=0)
80
- db_utils.add_column_to_table(_CURSOR,
81
- _CONN,
82
- 'spot',
83
- 'task_name',
84
- 'TEXT',
85
- copy_from='job_name')
86
-
87
- # `job_info` contains the mapping from job_id to the job_name.
88
- # In the future, it may contain more information about each job.
89
- _CURSOR.execute("""\
90
- CREATE TABLE IF NOT EXISTS job_info (
91
- spot_job_id INTEGER PRIMARY KEY AUTOINCREMENT,
92
- name TEXT)""")
93
- _CONN.commit()
38
+ def create_table(cursor, conn):
39
+ # Enable WAL mode to avoid locking issues.
40
+ # See: issue #3863, #1441 and PR #1509
41
+ # https://github.com/microsoft/WSL/issues/2395
42
+ # TODO(romilb): We do not enable WAL for WSL because of a known issue in WSL.
43
+ # This may cause the database locked problem from WSL issue #1441.
44
+ if not common_utils.is_wsl():
45
+ try:
46
+ cursor.execute('PRAGMA journal_mode=WAL')
47
+ except sqlite3.OperationalError as e:
48
+ if 'database is locked' not in str(e):
49
+ raise
50
+ # If the database is locked, it is OK to continue, as the WAL mode
51
+ # is not critical and is likely to be enabled by other processes.
52
+
53
+ cursor.execute("""\
54
+ CREATE TABLE IF NOT EXISTS spot (
55
+ job_id INTEGER PRIMARY KEY AUTOINCREMENT,
56
+ job_name TEXT,
57
+ resources TEXT,
58
+ submitted_at FLOAT,
59
+ status TEXT,
60
+ run_timestamp TEXT CANDIDATE KEY,
61
+ start_at FLOAT DEFAULT NULL,
62
+ end_at FLOAT DEFAULT NULL,
63
+ last_recovered_at FLOAT DEFAULT -1,
64
+ recovery_count INTEGER DEFAULT 0,
65
+ job_duration FLOAT DEFAULT 0,
66
+ failure_reason TEXT,
67
+ spot_job_id INTEGER,
68
+ task_id INTEGER DEFAULT 0,
69
+ task_name TEXT,
70
+ specs TEXT,
71
+ local_log_file TEXT DEFAULT NULL)""")
72
+ conn.commit()
73
+
74
+ db_utils.add_column_to_table(cursor, conn, 'spot', 'failure_reason', 'TEXT')
75
+ # Create a new column `spot_job_id`, which is the same for tasks of the
76
+ # same managed job.
77
+ # The original `job_id` no longer has an actual meaning, but only a legacy
78
+ # identifier for all tasks in database.
79
+ db_utils.add_column_to_table(cursor,
80
+ conn,
81
+ 'spot',
82
+ 'spot_job_id',
83
+ 'INTEGER',
84
+ copy_from='job_id')
85
+ db_utils.add_column_to_table(cursor,
86
+ conn,
87
+ 'spot',
88
+ 'task_id',
89
+ 'INTEGER DEFAULT 0',
90
+ value_to_replace_existing_entries=0)
91
+ db_utils.add_column_to_table(cursor,
92
+ conn,
93
+ 'spot',
94
+ 'task_name',
95
+ 'TEXT',
96
+ copy_from='job_name')
97
+
98
+ # Specs is some useful information about the task, e.g., the
99
+ # max_restarts_on_errors value. It is stored in JSON format.
100
+ db_utils.add_column_to_table(cursor,
101
+ conn,
102
+ 'spot',
103
+ 'specs',
104
+ 'TEXT',
105
+ value_to_replace_existing_entries=json.dumps({
106
+ 'max_restarts_on_errors': 0,
107
+ }))
108
+ db_utils.add_column_to_table(cursor, conn, 'spot', 'local_log_file',
109
+ 'TEXT DEFAULT NULL')
110
+
111
+ # `job_info` contains the mapping from job_id to the job_name, as well as
112
+ # information used by the scheduler.
113
+ cursor.execute("""\
114
+ CREATE TABLE IF NOT EXISTS job_info (
115
+ spot_job_id INTEGER PRIMARY KEY AUTOINCREMENT,
116
+ name TEXT,
117
+ schedule_state TEXT,
118
+ controller_pid INTEGER DEFAULT NULL,
119
+ dag_yaml_path TEXT,
120
+ env_file_path TEXT,
121
+ user_hash TEXT)""")
122
+
123
+ db_utils.add_column_to_table(cursor, conn, 'job_info', 'schedule_state',
124
+ 'TEXT')
125
+
126
+ db_utils.add_column_to_table(cursor, conn, 'job_info', 'controller_pid',
127
+ 'INTEGER DEFAULT NULL')
128
+
129
+ db_utils.add_column_to_table(cursor, conn, 'job_info', 'dag_yaml_path',
130
+ 'TEXT')
131
+
132
+ db_utils.add_column_to_table(cursor, conn, 'job_info', 'env_file_path',
133
+ 'TEXT')
134
+
135
+ db_utils.add_column_to_table(cursor, conn, 'job_info', 'user_hash', 'TEXT')
136
+
137
+ conn.commit()
138
+
139
+
140
+ # Module-level connection/cursor; thread-safe as the module is only imported
141
+ # once.
142
+ def _get_db_path() -> str:
143
+ """Workaround to collapse multi-step Path ops for type checker.
144
+ Ensures _DB_PATH is str, avoiding Union[Path, str] inference.
145
+ """
146
+ path = pathlib.Path('~/.sky/spot_jobs.db')
147
+ path = path.expanduser().absolute()
148
+ path.parents[0].mkdir(parents=True, exist_ok=True)
149
+ return str(path)
150
+
151
+
152
+ _DB_PATH = _get_db_path()
153
+ db_utils.SQLiteConn(_DB_PATH, create_table)
94
154
 
95
155
  # job_duration is the time a job actually runs (including the
96
156
  # setup duration) before last_recover, excluding the provision
@@ -120,9 +180,16 @@ columns = [
120
180
  'job_id',
121
181
  'task_id',
122
182
  'task_name',
183
+ 'specs',
184
+ 'local_log_file',
123
185
  # columns from the job_info table
124
186
  '_job_info_job_id', # This should be the same as job_id
125
- 'job_name'
187
+ 'job_name',
188
+ 'schedule_state',
189
+ 'controller_pid',
190
+ 'dag_yaml_path',
191
+ 'env_file_path',
192
+ 'user_hash',
126
193
  ]
127
194
 
128
195
 
@@ -148,16 +215,19 @@ class ManagedJobStatus(enum.Enum):
148
215
  SUCCEEDED -> SUCCEEDED
149
216
  FAILED -> FAILED
150
217
  FAILED_SETUP -> FAILED_SETUP
218
+ Not all statuses are in this list, since some ManagedJobStatuses are only
219
+ possible while the cluster is INIT/STOPPED/not yet UP.
151
220
  Note that the JobStatus will not be stuck in PENDING, because each cluster
152
221
  is dedicated to a managed job, i.e. there should always be enough resource
153
222
  to run the job and the job will be immediately transitioned to RUNNING.
223
+
224
+ You can see a state diagram for ManagedJobStatus in sky/jobs/README.md.
154
225
  """
155
226
  # PENDING: Waiting for the jobs controller to have a slot to run the
156
227
  # controller process.
157
- # The submitted_at timestamp of the managed job in the 'spot' table will be
158
- # set to the time when the job is firstly submitted by the user (set to
159
- # PENDING).
160
228
  PENDING = 'PENDING'
229
+ # The submitted_at timestamp of the managed job in the 'spot' table will be
230
+ # set to the time when the job controller begins running.
161
231
  # SUBMITTED: The jobs controller starts the controller process.
162
232
  SUBMITTED = 'SUBMITTED'
163
233
  # STARTING: The controller process is launching the cluster for the managed
@@ -171,12 +241,12 @@ class ManagedJobStatus(enum.Enum):
171
241
  # RECOVERING: The cluster is preempted, and the controller process is
172
242
  # recovering the cluster (relaunching/failover).
173
243
  RECOVERING = 'RECOVERING'
174
- # Terminal statuses
175
- # SUCCEEDED: The job is finished successfully.
176
- SUCCEEDED = 'SUCCEEDED'
177
244
  # CANCELLING: The job is requested to be cancelled by the user, and the
178
245
  # controller is cleaning up the cluster.
179
246
  CANCELLING = 'CANCELLING'
247
+ # Terminal statuses
248
+ # SUCCEEDED: The job is finished successfully.
249
+ SUCCEEDED = 'SUCCEEDED'
180
250
  # CANCELLED: The job is cancelled by the user. When the managed job is in
181
251
  # CANCELLED status, the cluster has been cleaned up.
182
252
  CANCELLED = 'CANCELLED'
@@ -222,7 +292,6 @@ class ManagedJobStatus(enum.Enum):
222
292
  cls.FAILED_PRECHECKS,
223
293
  cls.FAILED_NO_RESOURCE,
224
294
  cls.FAILED_CONTROLLER,
225
- cls.CANCELLING,
226
295
  cls.CANCELLED,
227
296
  ]
228
297
 
@@ -251,14 +320,74 @@ _SPOT_STATUS_TO_COLOR = {
251
320
  }
252
321
 
253
322
 
323
+ class ManagedJobScheduleState(enum.Enum):
324
+ """Captures the state of the job from the scheduler's perspective.
325
+
326
+ A job that predates the introduction of the scheduler will be INVALID.
327
+
328
+ A newly created job will be INACTIVE. The following transitions are valid:
329
+ - INACTIVE -> WAITING: The job is "submitted" to the scheduler, and its job
330
+ controller can be started.
331
+ - WAITING -> LAUNCHING: The job controller is being started by the scheduler and
332
+ may proceed to sky.launch.
333
+ - LAUNCHING -> ALIVE: The launch attempt was completed. It may have
334
+ succeeded or failed. The job controller is not allowed to sky.launch again
335
+ without transitioning to ALIVE_WAITING and then LAUNCHING.
336
+ - ALIVE -> ALIVE_WAITING: The job controller wants to sky.launch again,
337
+ either for recovery or to launch a subsequent task.
338
+ - ALIVE_WAITING -> LAUNCHING: The scheduler has determined that the job
339
+ controller may launch again.
340
+ - LAUNCHING, ALIVE, or ALIVE_WAITING -> DONE: The job controller is exiting
341
+ and the job is in some terminal status. In the future it may be possible
342
+ to transition directly from WAITING or even INACTIVE to DONE if the job is
343
+ cancelled.
344
+
345
+ You can see a state diagram in sky/jobs/README.md.
346
+
347
+ There is no well-defined mapping from the managed job status to schedule
348
+ state or vice versa. (In fact, schedule state is defined on the job and
349
+ status on the task.)
350
+ - INACTIVE or WAITING should only be seen when a job is PENDING.
351
+ - ALIVE_WAITING should only be seen when a job is RECOVERING, has multiple
352
+ tasks, or needs to retry launching.
353
+ - LAUNCHING and ALIVE can be seen in many different statuses.
354
+ - DONE should only be seen when a job is in a terminal status.
355
+ Since state and status transitions are not atomic, it may be possible to
356
+ briefly observe inconsistent states, like a job that just finished but
357
+ hasn't yet transitioned to DONE.
358
+ """
359
+ # This job may have been created before scheduler was introduced in #4458.
360
+ # This state is not used by scheduler but just for backward compatibility.
361
+ # TODO(cooperc): remove this in v0.11.0
362
+ INVALID = None
363
+ # The job should be ignored by the scheduler.
364
+ INACTIVE = 'INACTIVE'
365
+ # The job is waiting to transition to LAUNCHING for the first time. The
366
+ # scheduler should try to transition it, and when it does, it should start
367
+ # the job controller.
368
+ WAITING = 'WAITING'
369
+ # The job is already alive, but wants to transition back to LAUNCHING,
370
+ # e.g. for recovery, or launching later tasks in the DAG. The scheduler
371
+ # should try to transition it to LAUNCHING.
372
+ ALIVE_WAITING = 'ALIVE_WAITING'
373
+ # The job is running sky.launch, or soon will, using a limited number of
374
+ # allowed launch slots.
375
+ LAUNCHING = 'LAUNCHING'
376
+ # The controller for the job is running, but it's not currently launching.
377
+ ALIVE = 'ALIVE'
378
+ # The job is in a terminal state. (Not necessarily SUCCEEDED.)
379
+ DONE = 'DONE'
380
+
381
+
254
382
  # === Status transition functions ===
255
- def set_job_name(job_id: int, name: str):
383
+ def set_job_info(job_id: int, name: str):
256
384
  with db_utils.safe_cursor(_DB_PATH) as cursor:
257
385
  cursor.execute(
258
386
  """\
259
387
  INSERT INTO job_info
260
- (spot_job_id, name)
261
- VALUES (?, ?)""", (job_id, name))
388
+ (spot_job_id, name, schedule_state)
389
+ VALUES (?, ?, ?)""",
390
+ (job_id, name, ManagedJobScheduleState.INACTIVE.value))
262
391
 
263
392
 
264
393
  def set_pending(job_id: int, task_id: int, task_name: str, resources_str: str):
@@ -275,16 +404,19 @@ def set_pending(job_id: int, task_id: int, task_name: str, resources_str: str):
275
404
 
276
405
  def set_submitted(job_id: int, task_id: int, run_timestamp: str,
277
406
  submit_time: float, resources_str: str,
278
- callback_func: CallbackType):
407
+ specs: Dict[str, Union[str,
408
+ int]], callback_func: CallbackType):
279
409
  """Set the task to submitted.
280
410
 
281
411
  Args:
282
412
  job_id: The managed job ID.
283
413
  task_id: The task ID.
284
414
  run_timestamp: The run_timestamp of the run. This will be used to
285
- determine the log directory of the managed task.
415
+ determine the log directory of the managed task.
286
416
  submit_time: The time when the managed task is submitted.
287
417
  resources_str: The resources string of the managed task.
418
+ specs: The specs of the managed task.
419
+ callback_func: The callback function.
288
420
  """
289
421
  # Use the timestamp in the `run_timestamp` ('sky-2022-10...'), to make
290
422
  # the log directory and submission time align with each other, so as to
@@ -298,11 +430,19 @@ def set_submitted(job_id: int, task_id: int, run_timestamp: str,
298
430
  resources=(?),
299
431
  submitted_at=(?),
300
432
  status=(?),
301
- run_timestamp=(?)
433
+ run_timestamp=(?),
434
+ specs=(?)
302
435
  WHERE spot_job_id=(?) AND
303
- task_id=(?)""",
436
+ task_id=(?) AND
437
+ status=(?) AND
438
+ end_at IS null""",
304
439
  (resources_str, submit_time, ManagedJobStatus.SUBMITTED.value,
305
- run_timestamp, job_id, task_id))
440
+ run_timestamp, json.dumps(specs), job_id, task_id,
441
+ ManagedJobStatus.PENDING.value))
442
+ if cursor.rowcount != 1:
443
+ raise exceptions.ManagedJobStatusError(
444
+ f'Failed to set the task to submitted. '
445
+ f'({cursor.rowcount} rows updated)')
306
446
  callback_func('SUBMITTED')
307
447
 
308
448
 
@@ -314,7 +454,14 @@ def set_starting(job_id: int, task_id: int, callback_func: CallbackType):
314
454
  """\
315
455
  UPDATE spot SET status=(?)
316
456
  WHERE spot_job_id=(?) AND
317
- task_id=(?)""", (ManagedJobStatus.STARTING.value, job_id, task_id))
457
+ task_id=(?) AND
458
+ status=(?) AND
459
+ end_at IS null""", (ManagedJobStatus.STARTING.value, job_id,
460
+ task_id, ManagedJobStatus.SUBMITTED.value))
461
+ if cursor.rowcount != 1:
462
+ raise exceptions.ManagedJobStatusError(
463
+ f'Failed to set the task to starting. '
464
+ f'({cursor.rowcount} rows updated)')
318
465
  callback_func('STARTING')
319
466
 
320
467
 
@@ -327,15 +474,25 @@ def set_started(job_id: int, task_id: int, start_time: float,
327
474
  """\
328
475
  UPDATE spot SET status=(?), start_at=(?), last_recovered_at=(?)
329
476
  WHERE spot_job_id=(?) AND
330
- task_id=(?)""",
477
+ task_id=(?) AND
478
+ status IN (?, ?) AND
479
+ end_at IS null""",
331
480
  (
332
481
  ManagedJobStatus.RUNNING.value,
333
482
  start_time,
334
483
  start_time,
335
484
  job_id,
336
485
  task_id,
486
+ ManagedJobStatus.STARTING.value,
487
+ # If the task is empty, we will jump straight from PENDING to
488
+ # RUNNING
489
+ ManagedJobStatus.PENDING.value,
337
490
  ),
338
491
  )
492
+ if cursor.rowcount != 1:
493
+ raise exceptions.ManagedJobStatusError(
494
+ f'Failed to set the task to started. '
495
+ f'({cursor.rowcount} rows updated)')
339
496
  callback_func('STARTED')
340
497
 
341
498
 
@@ -348,8 +505,15 @@ def set_recovering(job_id: int, task_id: int, callback_func: CallbackType):
348
505
  UPDATE spot SET
349
506
  status=(?), job_duration=job_duration+(?)-last_recovered_at
350
507
  WHERE spot_job_id=(?) AND
351
- task_id=(?)""",
352
- (ManagedJobStatus.RECOVERING.value, time.time(), job_id, task_id))
508
+ task_id=(?) AND
509
+ status=(?) AND
510
+ end_at IS null""",
511
+ (ManagedJobStatus.RECOVERING.value, time.time(), job_id, task_id,
512
+ ManagedJobStatus.RUNNING.value))
513
+ if cursor.rowcount != 1:
514
+ raise exceptions.ManagedJobStatusError(
515
+ f'Failed to set the task to recovering. '
516
+ f'({cursor.rowcount} rows updated)')
353
517
  callback_func('RECOVERING')
354
518
 
355
519
 
@@ -362,8 +526,15 @@ def set_recovered(job_id: int, task_id: int, recovered_time: float,
362
526
  UPDATE spot SET
363
527
  status=(?), last_recovered_at=(?), recovery_count=recovery_count+1
364
528
  WHERE spot_job_id=(?) AND
365
- task_id=(?)""",
366
- (ManagedJobStatus.RUNNING.value, recovered_time, job_id, task_id))
529
+ task_id=(?) AND
530
+ status=(?) AND
531
+ end_at IS null""",
532
+ (ManagedJobStatus.RUNNING.value, recovered_time, job_id, task_id,
533
+ ManagedJobStatus.RECOVERING.value))
534
+ if cursor.rowcount != 1:
535
+ raise exceptions.ManagedJobStatusError(
536
+ f'Failed to set the task to recovered. '
537
+ f'({cursor.rowcount} rows updated)')
367
538
  logger.info('==== Recovered. ====')
368
539
  callback_func('RECOVERED')
369
540
 
@@ -376,10 +547,16 @@ def set_succeeded(job_id: int, task_id: int, end_time: float,
376
547
  """\
377
548
  UPDATE spot SET
378
549
  status=(?), end_at=(?)
379
- WHERE spot_job_id=(?) AND task_id=(?)
380
- AND end_at IS null""",
381
- (ManagedJobStatus.SUCCEEDED.value, end_time, job_id, task_id))
382
-
550
+ WHERE spot_job_id=(?) AND
551
+ task_id=(?) AND
552
+ status=(?) AND
553
+ end_at IS null""",
554
+ (ManagedJobStatus.SUCCEEDED.value, end_time, job_id, task_id,
555
+ ManagedJobStatus.RUNNING.value))
556
+ if cursor.rowcount != 1:
557
+ raise exceptions.ManagedJobStatusError(
558
+ f'Failed to set the task to succeeded. '
559
+ f'({cursor.rowcount} rows updated)')
383
560
  callback_func('SUCCEEDED')
384
561
  logger.info('Job succeeded.')
385
562
 
@@ -391,8 +568,12 @@ def set_failed(
391
568
  failure_reason: str,
392
569
  callback_func: Optional[CallbackType] = None,
393
570
  end_time: Optional[float] = None,
571
+ override_terminal: bool = False,
394
572
  ):
395
- """Set an entire job or task to failed, if they are in non-terminal states.
573
+ """Set an entire job or task to failed.
574
+
575
+ By default, don't override tasks that are already terminal (that is, for
576
+ which end_at is already set).
396
577
 
397
578
  Args:
398
579
  job_id: The job id.
@@ -401,36 +582,55 @@ def set_failed(
401
582
  failure_type: The failure type. One of ManagedJobStatus.FAILED_*.
402
583
  failure_reason: The failure reason.
403
584
  end_time: The end time. If None, the current time will be used.
585
+ override_terminal: If True, override the current status even if end_at
586
+ is already set.
404
587
  """
405
588
  assert failure_type.is_failed(), failure_type
406
589
  end_time = time.time() if end_time is None else end_time
407
590
 
408
- fields_to_set = {
409
- 'end_at': end_time,
591
+ fields_to_set: Dict[str, Any] = {
410
592
  'status': failure_type.value,
411
593
  'failure_reason': failure_reason,
412
594
  }
413
595
  with db_utils.safe_cursor(_DB_PATH) as cursor:
414
596
  previous_status = cursor.execute(
415
597
  'SELECT status FROM spot WHERE spot_job_id=(?)',
416
- (job_id,)).fetchone()
417
- previous_status = ManagedJobStatus(previous_status[0])
418
- if previous_status in [ManagedJobStatus.RECOVERING]:
419
- # If the job is recovering, we should set the
420
- # last_recovered_at to the end_time, so that the
421
- # end_at - last_recovered_at will not be affect the job duration
422
- # calculation.
598
+ (job_id,)).fetchone()[0]
599
+ previous_status = ManagedJobStatus(previous_status)
600
+ if previous_status == ManagedJobStatus.RECOVERING:
601
+ # If the job is recovering, we should set the last_recovered_at to
602
+ # the end_time, so that the end_at - last_recovered_at will not be
603
+ # affect the job duration calculation.
423
604
  fields_to_set['last_recovered_at'] = end_time
424
605
  set_str = ', '.join(f'{k}=(?)' for k in fields_to_set)
425
- task_str = '' if task_id is None else f' AND task_id={task_id}'
606
+ task_query_str = '' if task_id is None else 'AND task_id=(?)'
607
+ task_value = [] if task_id is None else [
608
+ task_id,
609
+ ]
426
610
 
427
- cursor.execute(
428
- f"""\
429
- UPDATE spot SET
430
- {set_str}
431
- WHERE spot_job_id=(?){task_str} AND end_at IS null""",
432
- (*list(fields_to_set.values()), job_id))
433
- if callback_func:
611
+ if override_terminal:
612
+ # Use COALESCE for end_at to avoid overriding the existing end_at if
613
+ # it's already set.
614
+ cursor.execute(
615
+ f"""\
616
+ UPDATE spot SET
617
+ end_at = COALESCE(end_at, ?),
618
+ {set_str}
619
+ WHERE spot_job_id=(?) {task_query_str}""",
620
+ (end_time, *list(fields_to_set.values()), job_id, *task_value))
621
+ else:
622
+ # Only set if end_at is null, i.e. the previous status is not
623
+ # terminal.
624
+ cursor.execute(
625
+ f"""\
626
+ UPDATE spot SET
627
+ end_at = (?),
628
+ {set_str}
629
+ WHERE spot_job_id=(?) {task_query_str} AND end_at IS null""",
630
+ (end_time, *list(fields_to_set.values()), job_id, *task_value))
631
+
632
+ updated = cursor.rowcount > 0
633
+ if callback_func and updated:
434
634
  callback_func('FAILED')
435
635
  logger.info(failure_reason)
436
636
 
@@ -445,12 +645,15 @@ def set_cancelling(job_id: int, callback_func: CallbackType):
445
645
  rows = cursor.execute(
446
646
  """\
447
647
  UPDATE spot SET
448
- status=(?), end_at=(?)
648
+ status=(?)
449
649
  WHERE spot_job_id=(?) AND end_at IS null""",
450
- (ManagedJobStatus.CANCELLING.value, time.time(), job_id))
451
- if rows.rowcount > 0:
452
- logger.info('Cancelling the job...')
453
- callback_func('CANCELLING')
650
+ (ManagedJobStatus.CANCELLING.value, job_id))
651
+ updated = rows.rowcount > 0
652
+ if updated:
653
+ logger.info('Cancelling the job...')
654
+ callback_func('CANCELLING')
655
+ else:
656
+ logger.info('Cancellation skipped, job is already terminal')
454
657
 
455
658
 
456
659
  def set_cancelled(job_id: int, callback_func: CallbackType):
@@ -466,26 +669,47 @@ def set_cancelled(job_id: int, callback_func: CallbackType):
466
669
  WHERE spot_job_id=(?) AND status=(?)""",
467
670
  (ManagedJobStatus.CANCELLED.value, time.time(), job_id,
468
671
  ManagedJobStatus.CANCELLING.value))
469
- if rows.rowcount > 0:
470
- logger.info('Job cancelled.')
471
- callback_func('CANCELLED')
672
+ updated = rows.rowcount > 0
673
+ if updated:
674
+ logger.info('Job cancelled.')
675
+ callback_func('CANCELLED')
676
+ else:
677
+ logger.info('Cancellation skipped, job is not CANCELLING')
678
+
679
+
680
+ def set_local_log_file(job_id: int, task_id: Optional[int],
681
+ local_log_file: str):
682
+ """Set the local log file for a job."""
683
+ filter_str = 'spot_job_id=(?)'
684
+ filter_args = [local_log_file, job_id]
685
+ if task_id is not None:
686
+ filter_str += ' AND task_id=(?)'
687
+ filter_args.append(task_id)
688
+ with db_utils.safe_cursor(_DB_PATH) as cursor:
689
+ cursor.execute(
690
+ 'UPDATE spot SET local_log_file=(?) '
691
+ f'WHERE {filter_str}', filter_args)
472
692
 
473
693
 
474
694
  # ======== utility functions ========
475
- def get_nonterminal_job_ids_by_name(name: Optional[str]) -> List[int]:
695
+ def get_nonterminal_job_ids_by_name(name: Optional[str],
696
+ all_users: bool = False) -> List[int]:
476
697
  """Get non-terminal job ids by name."""
477
698
  statuses = ', '.join(['?'] * len(ManagedJobStatus.terminal_statuses()))
478
699
  field_values = [
479
700
  status.value for status in ManagedJobStatus.terminal_statuses()
480
701
  ]
481
702
 
482
- name_filter = ''
703
+ job_filter = ''
704
+ if name is None and not all_users:
705
+ job_filter += 'AND (job_info.user_hash=(?)) '
706
+ field_values.append(common_utils.get_user_hash())
483
707
  if name is not None:
484
708
  # We match the job name from `job_info` for the jobs submitted after
485
709
  # #1982, and from `spot` for the jobs submitted before #1982, whose
486
710
  # job_info is not available.
487
- name_filter = ('AND (job_info.name=(?) OR '
488
- '(job_info.name IS NULL AND spot.task_name=(?)))')
711
+ job_filter += ('AND (job_info.name=(?) OR '
712
+ '(job_info.name IS NULL AND spot.task_name=(?))) ')
489
713
  field_values.extend([name, name])
490
714
 
491
715
  # Left outer join is used here instead of join, because the job_info does
@@ -499,6 +723,127 @@ def get_nonterminal_job_ids_by_name(name: Optional[str]) -> List[int]:
499
723
  ON spot.spot_job_id=job_info.spot_job_id
500
724
  WHERE status NOT IN
501
725
  ({statuses})
726
+ {job_filter}
727
+ ORDER BY spot.spot_job_id DESC""", field_values).fetchall()
728
+ job_ids = [row[0] for row in rows if row[0] is not None]
729
+ return job_ids
730
+
731
+
732
+ def get_schedule_live_jobs(job_id: Optional[int]) -> List[Dict[str, Any]]:
733
+ """Get jobs from the database that have a live schedule_state.
734
+
735
+ This should return job(s) that are not INACTIVE, WAITING, or DONE. So a
736
+ returned job should correspond to a live job controller process, with one
737
+ exception: the job may have just transitioned from WAITING to LAUNCHING, but
738
+ the controller process has not yet started.
739
+ """
740
+ job_filter = '' if job_id is None else 'AND spot_job_id=(?)'
741
+ job_value = (job_id,) if job_id is not None else ()
742
+
743
+ # Join spot and job_info tables to get the job name for each task.
744
+ # We use LEFT OUTER JOIN mainly for backward compatibility, as for an
745
+ # existing controller before #1982, the job_info table may not exist,
746
+ # and all the managed jobs created before will not present in the
747
+ # job_info.
748
+ with db_utils.safe_cursor(_DB_PATH) as cursor:
749
+ rows = cursor.execute(
750
+ f"""\
751
+ SELECT spot_job_id, schedule_state, controller_pid
752
+ FROM job_info
753
+ WHERE schedule_state not in (?, ?, ?)
754
+ {job_filter}
755
+ ORDER BY spot_job_id DESC""",
756
+ (ManagedJobScheduleState.INACTIVE.value,
757
+ ManagedJobScheduleState.WAITING.value,
758
+ ManagedJobScheduleState.DONE.value, *job_value)).fetchall()
759
+ jobs = []
760
+ for row in rows:
761
+ job_dict = {
762
+ 'job_id': row[0],
763
+ 'schedule_state': ManagedJobScheduleState(row[1]),
764
+ 'controller_pid': row[2],
765
+ }
766
+ jobs.append(job_dict)
767
+ return jobs
768
+
769
+
770
+ def get_jobs_to_check_status(job_id: Optional[int] = None) -> List[int]:
771
+ """Get jobs that need controller process checking.
772
+
773
+ Args:
774
+ job_id: Optional job ID to check. If None, checks all jobs.
775
+
776
+ Returns a list of job_ids, including the following:
777
+ - Jobs that have a schedule_state that is not DONE
778
+ - Jobs have schedule_state DONE but are in a non-terminal status
779
+ - Legacy jobs (that is, no schedule state) that are in non-terminal status
780
+ """
781
+ job_filter = '' if job_id is None else 'AND spot.spot_job_id=(?)'
782
+ job_value = () if job_id is None else (job_id,)
783
+
784
+ status_filter_str = ', '.join(['?'] *
785
+ len(ManagedJobStatus.terminal_statuses()))
786
+ terminal_status_values = [
787
+ status.value for status in ManagedJobStatus.terminal_statuses()
788
+ ]
789
+
790
+ # Get jobs that are either:
791
+ # 1. Have schedule state that is not DONE, or
792
+ # 2. Have schedule state DONE AND are in non-terminal status (unexpected
793
+ # inconsistent state), or
794
+ # 3. Have no schedule state (legacy) AND are in non-terminal status
795
+ with db_utils.safe_cursor(_DB_PATH) as cursor:
796
+ rows = cursor.execute(
797
+ f"""\
798
+ SELECT DISTINCT spot.spot_job_id
799
+ FROM spot
800
+ LEFT OUTER JOIN job_info
801
+ ON spot.spot_job_id=job_info.spot_job_id
802
+ WHERE (
803
+ -- non-legacy jobs that are not DONE
804
+ (job_info.schedule_state IS NOT NULL AND
805
+ job_info.schedule_state IS NOT ?)
806
+ OR
807
+ -- legacy or that are in non-terminal status or
808
+ -- DONE jobs that are in non-terminal status
809
+ ((-- legacy jobs
810
+ job_info.schedule_state IS NULL OR
811
+ -- non-legacy DONE jobs
812
+ job_info.schedule_state IS ?
813
+ ) AND
814
+ -- non-terminal
815
+ status NOT IN ({status_filter_str}))
816
+ )
817
+ {job_filter}
818
+ ORDER BY spot.spot_job_id DESC""", [
819
+ ManagedJobScheduleState.DONE.value,
820
+ ManagedJobScheduleState.DONE.value, *terminal_status_values,
821
+ *job_value
822
+ ]).fetchall()
823
+ return [row[0] for row in rows if row[0] is not None]
824
+
825
+
826
+ def get_all_job_ids_by_name(name: Optional[str]) -> List[int]:
827
+ """Get all job ids by name."""
828
+ name_filter = ''
829
+ field_values = []
830
+ if name is not None:
831
+ # We match the job name from `job_info` for the jobs submitted after
832
+ # #1982, and from `spot` for the jobs submitted before #1982, whose
833
+ # job_info is not available.
834
+ name_filter = ('WHERE (job_info.name=(?) OR '
835
+ '(job_info.name IS NULL AND spot.task_name=(?)))')
836
+ field_values = [name, name]
837
+
838
+ # Left outer join is used here instead of join, because the job_info does
839
+ # not contain the managed jobs submitted before #1982.
840
+ with db_utils.safe_cursor(_DB_PATH) as cursor:
841
+ rows = cursor.execute(
842
+ f"""\
843
+ SELECT DISTINCT spot.spot_job_id
844
+ FROM spot
845
+ LEFT OUTER JOIN job_info
846
+ ON spot.spot_job_id=job_info.spot_job_id
502
847
  {name_filter}
503
848
  ORDER BY spot.spot_job_id DESC""", field_values).fetchall()
504
849
  job_ids = [row[0] for row in rows if row[0] is not None]
@@ -532,12 +877,14 @@ def get_latest_task_id_status(
532
877
  If the job_id does not exist, (None, None) will be returned.
533
878
  """
534
879
  id_statuses = _get_all_task_ids_statuses(job_id)
535
- if len(id_statuses) == 0:
880
+ if not id_statuses:
536
881
  return None, None
537
- task_id, status = id_statuses[-1]
538
- for task_id, status in id_statuses:
539
- if not status.is_terminal():
540
- break
882
+ task_id, status = next(
883
+ ((tid, st) for tid, st in id_statuses if not st.is_terminal()),
884
+ id_statuses[-1],
885
+ )
886
+ # Unpack the tuple first, or it triggers a Pylint's bug on recognizing
887
+ # the return type.
541
888
  return task_id, status
542
889
 
543
890
 
@@ -558,7 +905,7 @@ def get_failure_reason(job_id: int) -> Optional[str]:
558
905
  WHERE spot_job_id=(?)
559
906
  ORDER BY task_id ASC""", (job_id,)).fetchall()
560
907
  reason = [r[0] for r in reason if r[0] is not None]
561
- if len(reason) == 0:
908
+ if not reason:
562
909
  return None
563
910
  return reason[0]
564
911
 
@@ -572,6 +919,9 @@ def get_managed_jobs(job_id: Optional[int] = None) -> List[Dict[str, Any]]:
572
919
  # existing controller before #1982, the job_info table may not exist,
573
920
  # and all the managed jobs created before will not present in the
574
921
  # job_info.
922
+ # Note: we will get the user_hash here, but don't try to call
923
+ # global_user_state.get_user() on it. This runs on the controller, which may
924
+ # not have the user info. Prefer to do it on the API server side.
575
925
  with db_utils.safe_cursor(_DB_PATH) as cursor:
576
926
  rows = cursor.execute(f"""\
577
927
  SELECT *
@@ -584,6 +934,8 @@ def get_managed_jobs(job_id: Optional[int] = None) -> List[Dict[str, Any]]:
584
934
  for row in rows:
585
935
  job_dict = dict(zip(columns, row))
586
936
  job_dict['status'] = ManagedJobStatus(job_dict['status'])
937
+ job_dict['schedule_state'] = ManagedJobScheduleState(
938
+ job_dict['schedule_state'])
587
939
  if job_dict['job_name'] is None:
588
940
  job_dict['job_name'] = job_dict['task_name']
589
941
  jobs.append(job_dict)
@@ -611,3 +963,156 @@ def get_latest_job_id() -> Optional[int]:
611
963
  for (job_id,) in rows:
612
964
  return job_id
613
965
  return None
966
+
967
+
968
+ def get_task_specs(job_id: int, task_id: int) -> Dict[str, Any]:
969
+ with db_utils.safe_cursor(_DB_PATH) as cursor:
970
+ task_specs = cursor.execute(
971
+ """\
972
+ SELECT specs FROM spot
973
+ WHERE spot_job_id=(?) AND task_id=(?)""",
974
+ (job_id, task_id)).fetchone()
975
+ return json.loads(task_specs[0])
976
+
977
+
978
+ def get_local_log_file(job_id: int, task_id: Optional[int]) -> Optional[str]:
979
+ """Get the local log directory for a job."""
980
+ filter_str = 'spot_job_id=(?)'
981
+ filter_args = [job_id]
982
+ if task_id is not None:
983
+ filter_str += ' AND task_id=(?)'
984
+ filter_args.append(task_id)
985
+ with db_utils.safe_cursor(_DB_PATH) as cursor:
986
+ local_log_file = cursor.execute(
987
+ f'SELECT local_log_file FROM spot '
988
+ f'WHERE {filter_str}', filter_args).fetchone()
989
+ return local_log_file[-1] if local_log_file else None
990
+
991
+
992
+ # === Scheduler state functions ===
993
+ # Only the scheduler should call these functions. They may require holding the
994
+ # scheduler lock to work correctly.
995
+
996
+
997
+ def scheduler_set_waiting(job_id: int, dag_yaml_path: str, env_file_path: str,
998
+ user_hash: str) -> None:
999
+ """Do not call without holding the scheduler lock."""
1000
+ with db_utils.safe_cursor(_DB_PATH) as cursor:
1001
+ updated_count = cursor.execute(
1002
+ 'UPDATE job_info SET '
1003
+ 'schedule_state = (?), dag_yaml_path = (?), env_file_path = (?), '
1004
+ ' user_hash = (?) '
1005
+ 'WHERE spot_job_id = (?) AND schedule_state = (?)',
1006
+ (ManagedJobScheduleState.WAITING.value, dag_yaml_path,
1007
+ env_file_path, user_hash, job_id,
1008
+ ManagedJobScheduleState.INACTIVE.value)).rowcount
1009
+ assert updated_count == 1, (job_id, updated_count)
1010
+
1011
+
1012
+ def scheduler_set_launching(job_id: int,
1013
+ current_state: ManagedJobScheduleState) -> None:
1014
+ """Do not call without holding the scheduler lock."""
1015
+ with db_utils.safe_cursor(_DB_PATH) as cursor:
1016
+ updated_count = cursor.execute(
1017
+ 'UPDATE job_info SET '
1018
+ 'schedule_state = (?) '
1019
+ 'WHERE spot_job_id = (?) AND schedule_state = (?)',
1020
+ (ManagedJobScheduleState.LAUNCHING.value, job_id,
1021
+ current_state.value)).rowcount
1022
+ assert updated_count == 1, (job_id, updated_count)
1023
+
1024
+
1025
+ def scheduler_set_alive(job_id: int) -> None:
1026
+ """Do not call without holding the scheduler lock."""
1027
+ with db_utils.safe_cursor(_DB_PATH) as cursor:
1028
+ updated_count = cursor.execute(
1029
+ 'UPDATE job_info SET '
1030
+ 'schedule_state = (?) '
1031
+ 'WHERE spot_job_id = (?) AND schedule_state = (?)',
1032
+ (ManagedJobScheduleState.ALIVE.value, job_id,
1033
+ ManagedJobScheduleState.LAUNCHING.value)).rowcount
1034
+ assert updated_count == 1, (job_id, updated_count)
1035
+
1036
+
1037
+ def scheduler_set_alive_waiting(job_id: int) -> None:
1038
+ """Do not call without holding the scheduler lock."""
1039
+ with db_utils.safe_cursor(_DB_PATH) as cursor:
1040
+ updated_count = cursor.execute(
1041
+ 'UPDATE job_info SET '
1042
+ 'schedule_state = (?) '
1043
+ 'WHERE spot_job_id = (?) AND schedule_state = (?)',
1044
+ (ManagedJobScheduleState.ALIVE_WAITING.value, job_id,
1045
+ ManagedJobScheduleState.ALIVE.value)).rowcount
1046
+ assert updated_count == 1, (job_id, updated_count)
1047
+
1048
+
1049
+ def scheduler_set_done(job_id: int, idempotent: bool = False) -> None:
1050
+ """Do not call without holding the scheduler lock."""
1051
+ with db_utils.safe_cursor(_DB_PATH) as cursor:
1052
+ updated_count = cursor.execute(
1053
+ 'UPDATE job_info SET '
1054
+ 'schedule_state = (?) '
1055
+ 'WHERE spot_job_id = (?) AND schedule_state != (?)',
1056
+ (ManagedJobScheduleState.DONE.value, job_id,
1057
+ ManagedJobScheduleState.DONE.value)).rowcount
1058
+ if not idempotent:
1059
+ assert updated_count == 1, (job_id, updated_count)
1060
+
1061
+
1062
+ def set_job_controller_pid(job_id: int, pid: int):
1063
+ with db_utils.safe_cursor(_DB_PATH) as cursor:
1064
+ updated_count = cursor.execute(
1065
+ 'UPDATE job_info SET '
1066
+ 'controller_pid = (?) '
1067
+ 'WHERE spot_job_id = (?)', (pid, job_id)).rowcount
1068
+ assert updated_count == 1, (job_id, updated_count)
1069
+
1070
+
1071
+ def get_job_schedule_state(job_id: int) -> ManagedJobScheduleState:
1072
+ with db_utils.safe_cursor(_DB_PATH) as cursor:
1073
+ state = cursor.execute(
1074
+ 'SELECT schedule_state FROM job_info WHERE spot_job_id = (?)',
1075
+ (job_id,)).fetchone()[0]
1076
+ return ManagedJobScheduleState(state)
1077
+
1078
+
1079
+ def get_num_launching_jobs() -> int:
1080
+ with db_utils.safe_cursor(_DB_PATH) as cursor:
1081
+ return cursor.execute(
1082
+ 'SELECT COUNT(*) '
1083
+ 'FROM job_info '
1084
+ 'WHERE schedule_state = (?)',
1085
+ (ManagedJobScheduleState.LAUNCHING.value,)).fetchone()[0]
1086
+
1087
+
1088
+ def get_num_alive_jobs() -> int:
1089
+ with db_utils.safe_cursor(_DB_PATH) as cursor:
1090
+ return cursor.execute(
1091
+ 'SELECT COUNT(*) '
1092
+ 'FROM job_info '
1093
+ 'WHERE schedule_state IN (?, ?, ?)',
1094
+ (ManagedJobScheduleState.ALIVE_WAITING.value,
1095
+ ManagedJobScheduleState.LAUNCHING.value,
1096
+ ManagedJobScheduleState.ALIVE.value)).fetchone()[0]
1097
+
1098
+
1099
+ def get_waiting_job() -> Optional[Dict[str, Any]]:
1100
+ """Get the next job that should transition to LAUNCHING.
1101
+
1102
+ Backwards compatibility note: jobs submitted before #4485 will have no
1103
+ schedule_state and will be ignored by this SQL query.
1104
+ """
1105
+ with db_utils.safe_cursor(_DB_PATH) as cursor:
1106
+ row = cursor.execute(
1107
+ 'SELECT spot_job_id, schedule_state, dag_yaml_path, env_file_path '
1108
+ 'FROM job_info '
1109
+ 'WHERE schedule_state in (?, ?) '
1110
+ 'ORDER BY spot_job_id LIMIT 1',
1111
+ (ManagedJobScheduleState.WAITING.value,
1112
+ ManagedJobScheduleState.ALIVE_WAITING.value)).fetchone()
1113
+ return {
1114
+ 'job_id': row[0],
1115
+ 'schedule_state': ManagedJobScheduleState(row[1]),
1116
+ 'dag_yaml_path': row[2],
1117
+ 'env_file_path': row[3],
1118
+ } if row is not None else None