skypilot-nightly 1.0.0.dev20250502__py3-none-any.whl → 1.0.0.dev20251203__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (546) hide show
  1. sky/__init__.py +22 -6
  2. sky/adaptors/aws.py +81 -16
  3. sky/adaptors/common.py +25 -2
  4. sky/adaptors/coreweave.py +278 -0
  5. sky/adaptors/do.py +8 -2
  6. sky/adaptors/gcp.py +11 -0
  7. sky/adaptors/hyperbolic.py +8 -0
  8. sky/adaptors/ibm.py +5 -2
  9. sky/adaptors/kubernetes.py +149 -18
  10. sky/adaptors/nebius.py +173 -30
  11. sky/adaptors/primeintellect.py +1 -0
  12. sky/adaptors/runpod.py +68 -0
  13. sky/adaptors/seeweb.py +183 -0
  14. sky/adaptors/shadeform.py +89 -0
  15. sky/admin_policy.py +187 -4
  16. sky/authentication.py +179 -225
  17. sky/backends/__init__.py +4 -2
  18. sky/backends/backend.py +22 -9
  19. sky/backends/backend_utils.py +1323 -397
  20. sky/backends/cloud_vm_ray_backend.py +1749 -1029
  21. sky/backends/docker_utils.py +1 -1
  22. sky/backends/local_docker_backend.py +11 -6
  23. sky/backends/task_codegen.py +633 -0
  24. sky/backends/wheel_utils.py +55 -9
  25. sky/{clouds/service_catalog → catalog}/__init__.py +21 -19
  26. sky/{clouds/service_catalog → catalog}/aws_catalog.py +27 -8
  27. sky/{clouds/service_catalog → catalog}/azure_catalog.py +10 -7
  28. sky/{clouds/service_catalog → catalog}/common.py +90 -49
  29. sky/{clouds/service_catalog → catalog}/cudo_catalog.py +8 -5
  30. sky/{clouds/service_catalog → catalog}/data_fetchers/analyze.py +1 -1
  31. sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_aws.py +116 -80
  32. sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_cudo.py +38 -38
  33. sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_gcp.py +70 -16
  34. sky/catalog/data_fetchers/fetch_hyperbolic.py +136 -0
  35. sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_lambda_cloud.py +1 -0
  36. sky/catalog/data_fetchers/fetch_nebius.py +338 -0
  37. sky/catalog/data_fetchers/fetch_runpod.py +698 -0
  38. sky/catalog/data_fetchers/fetch_seeweb.py +329 -0
  39. sky/catalog/data_fetchers/fetch_shadeform.py +142 -0
  40. sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_vast.py +1 -1
  41. sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_vsphere.py +1 -1
  42. sky/{clouds/service_catalog → catalog}/do_catalog.py +5 -2
  43. sky/{clouds/service_catalog → catalog}/fluidstack_catalog.py +6 -3
  44. sky/{clouds/service_catalog → catalog}/gcp_catalog.py +41 -15
  45. sky/catalog/hyperbolic_catalog.py +136 -0
  46. sky/{clouds/service_catalog → catalog}/ibm_catalog.py +9 -6
  47. sky/{clouds/service_catalog → catalog}/kubernetes_catalog.py +36 -24
  48. sky/{clouds/service_catalog → catalog}/lambda_catalog.py +9 -6
  49. sky/{clouds/service_catalog → catalog}/nebius_catalog.py +9 -7
  50. sky/{clouds/service_catalog → catalog}/oci_catalog.py +9 -6
  51. sky/{clouds/service_catalog → catalog}/paperspace_catalog.py +5 -2
  52. sky/catalog/primeintellect_catalog.py +95 -0
  53. sky/{clouds/service_catalog → catalog}/runpod_catalog.py +11 -4
  54. sky/{clouds/service_catalog → catalog}/scp_catalog.py +9 -6
  55. sky/catalog/seeweb_catalog.py +184 -0
  56. sky/catalog/shadeform_catalog.py +165 -0
  57. sky/catalog/ssh_catalog.py +167 -0
  58. sky/{clouds/service_catalog → catalog}/vast_catalog.py +6 -3
  59. sky/{clouds/service_catalog → catalog}/vsphere_catalog.py +5 -2
  60. sky/check.py +533 -185
  61. sky/cli.py +5 -5975
  62. sky/client/{cli.py → cli/command.py} +2591 -1956
  63. sky/client/cli/deprecation_utils.py +99 -0
  64. sky/client/cli/flags.py +359 -0
  65. sky/client/cli/table_utils.py +322 -0
  66. sky/client/cli/utils.py +79 -0
  67. sky/client/common.py +78 -32
  68. sky/client/oauth.py +82 -0
  69. sky/client/sdk.py +1219 -319
  70. sky/client/sdk_async.py +827 -0
  71. sky/client/service_account_auth.py +47 -0
  72. sky/cloud_stores.py +82 -3
  73. sky/clouds/__init__.py +13 -0
  74. sky/clouds/aws.py +564 -164
  75. sky/clouds/azure.py +105 -83
  76. sky/clouds/cloud.py +140 -40
  77. sky/clouds/cudo.py +68 -50
  78. sky/clouds/do.py +66 -48
  79. sky/clouds/fluidstack.py +63 -44
  80. sky/clouds/gcp.py +339 -110
  81. sky/clouds/hyperbolic.py +293 -0
  82. sky/clouds/ibm.py +70 -49
  83. sky/clouds/kubernetes.py +570 -162
  84. sky/clouds/lambda_cloud.py +74 -54
  85. sky/clouds/nebius.py +210 -81
  86. sky/clouds/oci.py +88 -66
  87. sky/clouds/paperspace.py +61 -44
  88. sky/clouds/primeintellect.py +317 -0
  89. sky/clouds/runpod.py +164 -74
  90. sky/clouds/scp.py +89 -86
  91. sky/clouds/seeweb.py +477 -0
  92. sky/clouds/shadeform.py +400 -0
  93. sky/clouds/ssh.py +263 -0
  94. sky/clouds/utils/aws_utils.py +10 -4
  95. sky/clouds/utils/gcp_utils.py +87 -11
  96. sky/clouds/utils/oci_utils.py +38 -14
  97. sky/clouds/utils/scp_utils.py +231 -167
  98. sky/clouds/vast.py +99 -77
  99. sky/clouds/vsphere.py +51 -40
  100. sky/core.py +375 -173
  101. sky/dag.py +15 -0
  102. sky/dashboard/out/404.html +1 -1
  103. sky/dashboard/out/_next/static/96_E2yl3QAiIJGOYCkSpB/_buildManifest.js +1 -0
  104. sky/dashboard/out/_next/static/chunks/1141-e6aa9ab418717c59.js +11 -0
  105. sky/dashboard/out/_next/static/chunks/1272-1ef0bf0237faccdb.js +1 -0
  106. sky/dashboard/out/_next/static/chunks/1871-7e202677c42f43fe.js +6 -0
  107. sky/dashboard/out/_next/static/chunks/2260-7703229c33c5ebd5.js +1 -0
  108. sky/dashboard/out/_next/static/chunks/2350.fab69e61bac57b23.js +1 -0
  109. sky/dashboard/out/_next/static/chunks/2369.fc20f0c2c8ed9fe7.js +15 -0
  110. sky/dashboard/out/_next/static/chunks/2755.edd818326d489a1d.js +26 -0
  111. sky/dashboard/out/_next/static/chunks/3294.20a8540fe697d5ee.js +1 -0
  112. sky/dashboard/out/_next/static/chunks/3785.7e245f318f9d1121.js +1 -0
  113. sky/dashboard/out/_next/static/chunks/3800-7b45f9fbb6308557.js +1 -0
  114. sky/dashboard/out/_next/static/chunks/3850-ff4a9a69d978632b.js +1 -0
  115. sky/dashboard/out/_next/static/chunks/3937.210053269f121201.js +1 -0
  116. sky/dashboard/out/_next/static/chunks/4725.172ede95d1b21022.js +1 -0
  117. sky/dashboard/out/_next/static/chunks/4937.a2baa2df5572a276.js +15 -0
  118. sky/dashboard/out/_next/static/chunks/5739-d67458fcb1386c92.js +8 -0
  119. sky/dashboard/out/_next/static/chunks/6130-2be46d70a38f1e82.js +1 -0
  120. sky/dashboard/out/_next/static/chunks/616-3d59f75e2ccf9321.js +39 -0
  121. sky/dashboard/out/_next/static/chunks/6212-7bd06f60ba693125.js +13 -0
  122. sky/dashboard/out/_next/static/chunks/6856-8f27d1c10c98def8.js +1 -0
  123. sky/dashboard/out/_next/static/chunks/6989-01359c57e018caa4.js +1 -0
  124. sky/dashboard/out/_next/static/chunks/6990-9146207c4567fdfd.js +1 -0
  125. sky/dashboard/out/_next/static/chunks/7359-c8d04e06886000b3.js +30 -0
  126. sky/dashboard/out/_next/static/chunks/7411-b15471acd2cba716.js +41 -0
  127. sky/dashboard/out/_next/static/chunks/7615-019513abc55b3b47.js +1 -0
  128. sky/dashboard/out/_next/static/chunks/8640.5b9475a2d18c5416.js +16 -0
  129. sky/dashboard/out/_next/static/chunks/8969-452f9d5cbdd2dc73.js +1 -0
  130. sky/dashboard/out/_next/static/chunks/9025.fa408f3242e9028d.js +6 -0
  131. sky/dashboard/out/_next/static/chunks/9353-cff34f7e773b2e2b.js +1 -0
  132. sky/dashboard/out/_next/static/chunks/9360.a536cf6b1fa42355.js +31 -0
  133. sky/dashboard/out/_next/static/chunks/9847.3aaca6bb33455140.js +30 -0
  134. sky/dashboard/out/_next/static/chunks/fd9d1056-86323a29a8f7e46a.js +1 -0
  135. sky/dashboard/out/_next/static/chunks/framework-cf60a09ccd051a10.js +33 -0
  136. sky/dashboard/out/_next/static/chunks/main-app-587214043926b3cc.js +1 -0
  137. sky/dashboard/out/_next/static/chunks/main-f15ccb73239a3bf1.js +1 -0
  138. sky/dashboard/out/_next/static/chunks/pages/_app-bde01e4a2beec258.js +34 -0
  139. sky/dashboard/out/_next/static/chunks/pages/_error-c66a4e8afc46f17b.js +1 -0
  140. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]/[job]-792db96d918c98c9.js +16 -0
  141. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]-abfcac9c137aa543.js +1 -0
  142. sky/dashboard/out/_next/static/chunks/pages/clusters-ee39056f9851a3ff.js +1 -0
  143. sky/dashboard/out/_next/static/chunks/pages/config-dfb9bf07b13045f4.js +1 -0
  144. sky/dashboard/out/_next/static/chunks/pages/index-444f1804401f04ea.js +1 -0
  145. sky/dashboard/out/_next/static/chunks/pages/infra/[context]-c0b5935149902e6f.js +1 -0
  146. sky/dashboard/out/_next/static/chunks/pages/infra-aed0ea19df7cf961.js +1 -0
  147. sky/dashboard/out/_next/static/chunks/pages/jobs/[job]-d66997e2bfc837cf.js +16 -0
  148. sky/dashboard/out/_next/static/chunks/pages/jobs/pools/[pool]-9faf940b253e3e06.js +21 -0
  149. sky/dashboard/out/_next/static/chunks/pages/jobs-2072b48b617989c9.js +1 -0
  150. sky/dashboard/out/_next/static/chunks/pages/users-f42674164aa73423.js +1 -0
  151. sky/dashboard/out/_next/static/chunks/pages/volumes-b84b948ff357c43e.js +1 -0
  152. sky/dashboard/out/_next/static/chunks/pages/workspace/new-3f88a1c7e86a3f86.js +1 -0
  153. sky/dashboard/out/_next/static/chunks/pages/workspaces/[name]-84a40f8c7c627fe4.js +1 -0
  154. sky/dashboard/out/_next/static/chunks/pages/workspaces-531b2f8c4bf89f82.js +1 -0
  155. sky/dashboard/out/_next/static/chunks/webpack-64e05f17bf2cf8ce.js +1 -0
  156. sky/dashboard/out/_next/static/css/0748ce22df867032.css +3 -0
  157. sky/dashboard/out/clusters/[cluster]/[job].html +1 -1
  158. sky/dashboard/out/clusters/[cluster].html +1 -1
  159. sky/dashboard/out/clusters.html +1 -1
  160. sky/dashboard/out/config.html +1 -0
  161. sky/dashboard/out/index.html +1 -1
  162. sky/dashboard/out/infra/[context].html +1 -0
  163. sky/dashboard/out/infra.html +1 -0
  164. sky/dashboard/out/jobs/[job].html +1 -1
  165. sky/dashboard/out/jobs/pools/[pool].html +1 -0
  166. sky/dashboard/out/jobs.html +1 -1
  167. sky/dashboard/out/users.html +1 -0
  168. sky/dashboard/out/volumes.html +1 -0
  169. sky/dashboard/out/workspace/new.html +1 -0
  170. sky/dashboard/out/workspaces/[name].html +1 -0
  171. sky/dashboard/out/workspaces.html +1 -0
  172. sky/data/data_utils.py +137 -1
  173. sky/data/mounting_utils.py +269 -84
  174. sky/data/storage.py +1460 -1807
  175. sky/data/storage_utils.py +43 -57
  176. sky/exceptions.py +126 -2
  177. sky/execution.py +216 -63
  178. sky/global_user_state.py +2390 -586
  179. sky/jobs/__init__.py +7 -0
  180. sky/jobs/client/sdk.py +300 -58
  181. sky/jobs/client/sdk_async.py +161 -0
  182. sky/jobs/constants.py +15 -8
  183. sky/jobs/controller.py +848 -275
  184. sky/jobs/file_content_utils.py +128 -0
  185. sky/jobs/log_gc.py +193 -0
  186. sky/jobs/recovery_strategy.py +402 -152
  187. sky/jobs/scheduler.py +314 -189
  188. sky/jobs/server/core.py +836 -255
  189. sky/jobs/server/server.py +156 -115
  190. sky/jobs/server/utils.py +136 -0
  191. sky/jobs/state.py +2109 -706
  192. sky/jobs/utils.py +1306 -215
  193. sky/logs/__init__.py +21 -0
  194. sky/logs/agent.py +108 -0
  195. sky/logs/aws.py +243 -0
  196. sky/logs/gcp.py +91 -0
  197. sky/metrics/__init__.py +0 -0
  198. sky/metrics/utils.py +453 -0
  199. sky/models.py +78 -1
  200. sky/optimizer.py +164 -70
  201. sky/provision/__init__.py +90 -4
  202. sky/provision/aws/config.py +147 -26
  203. sky/provision/aws/instance.py +136 -50
  204. sky/provision/azure/instance.py +11 -6
  205. sky/provision/common.py +13 -1
  206. sky/provision/cudo/cudo_machine_type.py +1 -1
  207. sky/provision/cudo/cudo_utils.py +14 -8
  208. sky/provision/cudo/cudo_wrapper.py +72 -71
  209. sky/provision/cudo/instance.py +10 -6
  210. sky/provision/do/instance.py +10 -6
  211. sky/provision/do/utils.py +4 -3
  212. sky/provision/docker_utils.py +140 -33
  213. sky/provision/fluidstack/instance.py +13 -8
  214. sky/provision/gcp/__init__.py +1 -0
  215. sky/provision/gcp/config.py +301 -19
  216. sky/provision/gcp/constants.py +218 -0
  217. sky/provision/gcp/instance.py +36 -8
  218. sky/provision/gcp/instance_utils.py +18 -4
  219. sky/provision/gcp/volume_utils.py +247 -0
  220. sky/provision/hyperbolic/__init__.py +12 -0
  221. sky/provision/hyperbolic/config.py +10 -0
  222. sky/provision/hyperbolic/instance.py +437 -0
  223. sky/provision/hyperbolic/utils.py +373 -0
  224. sky/provision/instance_setup.py +101 -20
  225. sky/provision/kubernetes/__init__.py +5 -0
  226. sky/provision/kubernetes/config.py +9 -52
  227. sky/provision/kubernetes/constants.py +17 -0
  228. sky/provision/kubernetes/instance.py +919 -280
  229. sky/provision/kubernetes/manifests/fusermount-server-daemonset.yaml +1 -2
  230. sky/provision/kubernetes/network.py +27 -17
  231. sky/provision/kubernetes/network_utils.py +44 -43
  232. sky/provision/kubernetes/utils.py +1221 -534
  233. sky/provision/kubernetes/volume.py +343 -0
  234. sky/provision/lambda_cloud/instance.py +22 -16
  235. sky/provision/nebius/constants.py +50 -0
  236. sky/provision/nebius/instance.py +19 -6
  237. sky/provision/nebius/utils.py +237 -137
  238. sky/provision/oci/instance.py +10 -5
  239. sky/provision/paperspace/instance.py +10 -7
  240. sky/provision/paperspace/utils.py +1 -1
  241. sky/provision/primeintellect/__init__.py +10 -0
  242. sky/provision/primeintellect/config.py +11 -0
  243. sky/provision/primeintellect/instance.py +454 -0
  244. sky/provision/primeintellect/utils.py +398 -0
  245. sky/provision/provisioner.py +117 -36
  246. sky/provision/runpod/__init__.py +5 -0
  247. sky/provision/runpod/instance.py +27 -6
  248. sky/provision/runpod/utils.py +51 -18
  249. sky/provision/runpod/volume.py +214 -0
  250. sky/provision/scp/__init__.py +15 -0
  251. sky/provision/scp/config.py +93 -0
  252. sky/provision/scp/instance.py +707 -0
  253. sky/provision/seeweb/__init__.py +11 -0
  254. sky/provision/seeweb/config.py +13 -0
  255. sky/provision/seeweb/instance.py +812 -0
  256. sky/provision/shadeform/__init__.py +11 -0
  257. sky/provision/shadeform/config.py +12 -0
  258. sky/provision/shadeform/instance.py +351 -0
  259. sky/provision/shadeform/shadeform_utils.py +83 -0
  260. sky/provision/ssh/__init__.py +18 -0
  261. sky/provision/vast/instance.py +13 -8
  262. sky/provision/vast/utils.py +10 -7
  263. sky/provision/volume.py +164 -0
  264. sky/provision/vsphere/common/ssl_helper.py +1 -1
  265. sky/provision/vsphere/common/vapiconnect.py +2 -1
  266. sky/provision/vsphere/common/vim_utils.py +4 -4
  267. sky/provision/vsphere/instance.py +15 -10
  268. sky/provision/vsphere/vsphere_utils.py +17 -20
  269. sky/py.typed +0 -0
  270. sky/resources.py +845 -119
  271. sky/schemas/__init__.py +0 -0
  272. sky/schemas/api/__init__.py +0 -0
  273. sky/schemas/api/responses.py +227 -0
  274. sky/schemas/db/README +4 -0
  275. sky/schemas/db/env.py +90 -0
  276. sky/schemas/db/global_user_state/001_initial_schema.py +124 -0
  277. sky/schemas/db/global_user_state/002_add_workspace_to_cluster_history.py +35 -0
  278. sky/schemas/db/global_user_state/003_fix_initial_revision.py +61 -0
  279. sky/schemas/db/global_user_state/004_is_managed.py +34 -0
  280. sky/schemas/db/global_user_state/005_cluster_event.py +32 -0
  281. sky/schemas/db/global_user_state/006_provision_log.py +41 -0
  282. sky/schemas/db/global_user_state/007_cluster_event_request_id.py +34 -0
  283. sky/schemas/db/global_user_state/008_skylet_ssh_tunnel_metadata.py +34 -0
  284. sky/schemas/db/global_user_state/009_last_activity_and_launched_at.py +89 -0
  285. sky/schemas/db/global_user_state/010_save_ssh_key.py +66 -0
  286. sky/schemas/db/global_user_state/011_is_ephemeral.py +34 -0
  287. sky/schemas/db/kv_cache/001_initial_schema.py +29 -0
  288. sky/schemas/db/script.py.mako +28 -0
  289. sky/schemas/db/serve_state/001_initial_schema.py +67 -0
  290. sky/schemas/db/serve_state/002_yaml_content.py +34 -0
  291. sky/schemas/db/skypilot_config/001_initial_schema.py +30 -0
  292. sky/schemas/db/spot_jobs/001_initial_schema.py +97 -0
  293. sky/schemas/db/spot_jobs/002_cluster_pool.py +42 -0
  294. sky/schemas/db/spot_jobs/003_pool_hash.py +34 -0
  295. sky/schemas/db/spot_jobs/004_job_file_contents.py +42 -0
  296. sky/schemas/db/spot_jobs/005_logs_gc.py +38 -0
  297. sky/schemas/db/spot_jobs/006_controller_pid_started_at.py +34 -0
  298. sky/schemas/db/spot_jobs/007_config_file_content.py +34 -0
  299. sky/schemas/generated/__init__.py +0 -0
  300. sky/schemas/generated/autostopv1_pb2.py +36 -0
  301. sky/schemas/generated/autostopv1_pb2.pyi +43 -0
  302. sky/schemas/generated/autostopv1_pb2_grpc.py +146 -0
  303. sky/schemas/generated/jobsv1_pb2.py +86 -0
  304. sky/schemas/generated/jobsv1_pb2.pyi +254 -0
  305. sky/schemas/generated/jobsv1_pb2_grpc.py +542 -0
  306. sky/schemas/generated/managed_jobsv1_pb2.py +76 -0
  307. sky/schemas/generated/managed_jobsv1_pb2.pyi +278 -0
  308. sky/schemas/generated/managed_jobsv1_pb2_grpc.py +278 -0
  309. sky/schemas/generated/servev1_pb2.py +58 -0
  310. sky/schemas/generated/servev1_pb2.pyi +115 -0
  311. sky/schemas/generated/servev1_pb2_grpc.py +322 -0
  312. sky/serve/autoscalers.py +357 -5
  313. sky/serve/client/impl.py +310 -0
  314. sky/serve/client/sdk.py +47 -139
  315. sky/serve/client/sdk_async.py +130 -0
  316. sky/serve/constants.py +12 -9
  317. sky/serve/controller.py +68 -17
  318. sky/serve/load_balancer.py +106 -60
  319. sky/serve/load_balancing_policies.py +116 -2
  320. sky/serve/replica_managers.py +434 -249
  321. sky/serve/serve_rpc_utils.py +179 -0
  322. sky/serve/serve_state.py +569 -257
  323. sky/serve/serve_utils.py +775 -265
  324. sky/serve/server/core.py +66 -711
  325. sky/serve/server/impl.py +1093 -0
  326. sky/serve/server/server.py +21 -18
  327. sky/serve/service.py +192 -89
  328. sky/serve/service_spec.py +144 -20
  329. sky/serve/spot_placer.py +3 -0
  330. sky/server/auth/__init__.py +0 -0
  331. sky/server/auth/authn.py +50 -0
  332. sky/server/auth/loopback.py +38 -0
  333. sky/server/auth/oauth2_proxy.py +202 -0
  334. sky/server/common.py +478 -182
  335. sky/server/config.py +85 -23
  336. sky/server/constants.py +44 -6
  337. sky/server/daemons.py +295 -0
  338. sky/server/html/token_page.html +185 -0
  339. sky/server/metrics.py +160 -0
  340. sky/server/middleware_utils.py +166 -0
  341. sky/server/requests/executor.py +558 -138
  342. sky/server/requests/payloads.py +364 -24
  343. sky/server/requests/preconditions.py +21 -17
  344. sky/server/requests/process.py +112 -29
  345. sky/server/requests/request_names.py +121 -0
  346. sky/server/requests/requests.py +822 -226
  347. sky/server/requests/serializers/decoders.py +82 -31
  348. sky/server/requests/serializers/encoders.py +140 -22
  349. sky/server/requests/threads.py +117 -0
  350. sky/server/rest.py +455 -0
  351. sky/server/server.py +1309 -285
  352. sky/server/state.py +20 -0
  353. sky/server/stream_utils.py +327 -61
  354. sky/server/uvicorn.py +217 -3
  355. sky/server/versions.py +270 -0
  356. sky/setup_files/MANIFEST.in +11 -1
  357. sky/setup_files/alembic.ini +160 -0
  358. sky/setup_files/dependencies.py +139 -31
  359. sky/setup_files/setup.py +44 -42
  360. sky/sky_logging.py +114 -7
  361. sky/skylet/attempt_skylet.py +106 -24
  362. sky/skylet/autostop_lib.py +129 -8
  363. sky/skylet/configs.py +29 -20
  364. sky/skylet/constants.py +216 -25
  365. sky/skylet/events.py +101 -21
  366. sky/skylet/job_lib.py +345 -164
  367. sky/skylet/log_lib.py +297 -18
  368. sky/skylet/log_lib.pyi +44 -1
  369. sky/skylet/providers/ibm/node_provider.py +12 -8
  370. sky/skylet/providers/ibm/vpc_provider.py +13 -12
  371. sky/skylet/ray_patches/__init__.py +17 -3
  372. sky/skylet/ray_patches/autoscaler.py.diff +18 -0
  373. sky/skylet/ray_patches/cli.py.diff +19 -0
  374. sky/skylet/ray_patches/command_runner.py.diff +17 -0
  375. sky/skylet/ray_patches/log_monitor.py.diff +20 -0
  376. sky/skylet/ray_patches/resource_demand_scheduler.py.diff +32 -0
  377. sky/skylet/ray_patches/updater.py.diff +18 -0
  378. sky/skylet/ray_patches/worker.py.diff +41 -0
  379. sky/skylet/runtime_utils.py +21 -0
  380. sky/skylet/services.py +568 -0
  381. sky/skylet/skylet.py +72 -4
  382. sky/skylet/subprocess_daemon.py +104 -29
  383. sky/skypilot_config.py +506 -99
  384. sky/ssh_node_pools/__init__.py +1 -0
  385. sky/ssh_node_pools/core.py +135 -0
  386. sky/ssh_node_pools/server.py +233 -0
  387. sky/task.py +685 -163
  388. sky/templates/aws-ray.yml.j2 +11 -3
  389. sky/templates/azure-ray.yml.j2 +2 -1
  390. sky/templates/cudo-ray.yml.j2 +1 -0
  391. sky/templates/do-ray.yml.j2 +2 -1
  392. sky/templates/fluidstack-ray.yml.j2 +1 -0
  393. sky/templates/gcp-ray.yml.j2 +62 -1
  394. sky/templates/hyperbolic-ray.yml.j2 +68 -0
  395. sky/templates/ibm-ray.yml.j2 +2 -1
  396. sky/templates/jobs-controller.yaml.j2 +27 -24
  397. sky/templates/kubernetes-loadbalancer.yml.j2 +2 -0
  398. sky/templates/kubernetes-ray.yml.j2 +611 -50
  399. sky/templates/lambda-ray.yml.j2 +2 -1
  400. sky/templates/nebius-ray.yml.j2 +34 -12
  401. sky/templates/oci-ray.yml.j2 +1 -0
  402. sky/templates/paperspace-ray.yml.j2 +2 -1
  403. sky/templates/primeintellect-ray.yml.j2 +72 -0
  404. sky/templates/runpod-ray.yml.j2 +10 -1
  405. sky/templates/scp-ray.yml.j2 +4 -50
  406. sky/templates/seeweb-ray.yml.j2 +171 -0
  407. sky/templates/shadeform-ray.yml.j2 +73 -0
  408. sky/templates/sky-serve-controller.yaml.j2 +22 -2
  409. sky/templates/vast-ray.yml.j2 +1 -0
  410. sky/templates/vsphere-ray.yml.j2 +1 -0
  411. sky/templates/websocket_proxy.py +212 -37
  412. sky/usage/usage_lib.py +31 -15
  413. sky/users/__init__.py +0 -0
  414. sky/users/model.conf +15 -0
  415. sky/users/permission.py +397 -0
  416. sky/users/rbac.py +121 -0
  417. sky/users/server.py +720 -0
  418. sky/users/token_service.py +218 -0
  419. sky/utils/accelerator_registry.py +35 -5
  420. sky/utils/admin_policy_utils.py +84 -38
  421. sky/utils/annotations.py +38 -5
  422. sky/utils/asyncio_utils.py +78 -0
  423. sky/utils/atomic.py +1 -1
  424. sky/utils/auth_utils.py +153 -0
  425. sky/utils/benchmark_utils.py +60 -0
  426. sky/utils/cli_utils/status_utils.py +159 -86
  427. sky/utils/cluster_utils.py +31 -9
  428. sky/utils/command_runner.py +354 -68
  429. sky/utils/command_runner.pyi +93 -3
  430. sky/utils/common.py +35 -8
  431. sky/utils/common_utils.py +314 -91
  432. sky/utils/config_utils.py +74 -5
  433. sky/utils/context.py +403 -0
  434. sky/utils/context_utils.py +242 -0
  435. sky/utils/controller_utils.py +383 -89
  436. sky/utils/dag_utils.py +31 -12
  437. sky/utils/db/__init__.py +0 -0
  438. sky/utils/db/db_utils.py +485 -0
  439. sky/utils/db/kv_cache.py +149 -0
  440. sky/utils/db/migration_utils.py +137 -0
  441. sky/utils/directory_utils.py +12 -0
  442. sky/utils/env_options.py +13 -0
  443. sky/utils/git.py +567 -0
  444. sky/utils/git_clone.sh +460 -0
  445. sky/utils/infra_utils.py +195 -0
  446. sky/utils/kubernetes/cleanup-tunnel.sh +62 -0
  447. sky/utils/kubernetes/config_map_utils.py +133 -0
  448. sky/utils/kubernetes/create_cluster.sh +15 -29
  449. sky/utils/kubernetes/delete_cluster.sh +10 -7
  450. sky/utils/kubernetes/deploy_ssh_node_pools.py +1177 -0
  451. sky/utils/kubernetes/exec_kubeconfig_converter.py +22 -31
  452. sky/utils/kubernetes/generate_kind_config.py +6 -66
  453. sky/utils/kubernetes/generate_kubeconfig.sh +4 -1
  454. sky/utils/kubernetes/gpu_labeler.py +18 -8
  455. sky/utils/kubernetes/k8s_gpu_labeler_job.yaml +2 -1
  456. sky/utils/kubernetes/k8s_gpu_labeler_setup.yaml +16 -16
  457. sky/utils/kubernetes/kubernetes_deploy_utils.py +284 -114
  458. sky/utils/kubernetes/rsync_helper.sh +11 -3
  459. sky/utils/kubernetes/ssh-tunnel.sh +379 -0
  460. sky/utils/kubernetes/ssh_utils.py +221 -0
  461. sky/utils/kubernetes_enums.py +8 -15
  462. sky/utils/lock_events.py +94 -0
  463. sky/utils/locks.py +416 -0
  464. sky/utils/log_utils.py +82 -107
  465. sky/utils/perf_utils.py +22 -0
  466. sky/utils/resource_checker.py +298 -0
  467. sky/utils/resources_utils.py +249 -32
  468. sky/utils/rich_utils.py +217 -39
  469. sky/utils/schemas.py +955 -160
  470. sky/utils/serialize_utils.py +16 -0
  471. sky/utils/status_lib.py +10 -0
  472. sky/utils/subprocess_utils.py +29 -15
  473. sky/utils/tempstore.py +70 -0
  474. sky/utils/thread_utils.py +91 -0
  475. sky/utils/timeline.py +26 -53
  476. sky/utils/ux_utils.py +84 -15
  477. sky/utils/validator.py +11 -1
  478. sky/utils/volume.py +165 -0
  479. sky/utils/yaml_utils.py +111 -0
  480. sky/volumes/__init__.py +13 -0
  481. sky/volumes/client/__init__.py +0 -0
  482. sky/volumes/client/sdk.py +150 -0
  483. sky/volumes/server/__init__.py +0 -0
  484. sky/volumes/server/core.py +270 -0
  485. sky/volumes/server/server.py +124 -0
  486. sky/volumes/volume.py +215 -0
  487. sky/workspaces/__init__.py +0 -0
  488. sky/workspaces/core.py +655 -0
  489. sky/workspaces/server.py +101 -0
  490. sky/workspaces/utils.py +56 -0
  491. sky_templates/README.md +3 -0
  492. sky_templates/__init__.py +3 -0
  493. sky_templates/ray/__init__.py +0 -0
  494. sky_templates/ray/start_cluster +183 -0
  495. sky_templates/ray/stop_cluster +75 -0
  496. skypilot_nightly-1.0.0.dev20251203.dist-info/METADATA +676 -0
  497. skypilot_nightly-1.0.0.dev20251203.dist-info/RECORD +611 -0
  498. {skypilot_nightly-1.0.0.dev20250502.dist-info → skypilot_nightly-1.0.0.dev20251203.dist-info}/WHEEL +1 -1
  499. skypilot_nightly-1.0.0.dev20251203.dist-info/top_level.txt +2 -0
  500. sky/benchmark/benchmark_state.py +0 -256
  501. sky/benchmark/benchmark_utils.py +0 -641
  502. sky/clouds/service_catalog/constants.py +0 -7
  503. sky/dashboard/out/_next/static/GWvVBSCS7FmUiVmjaL1a7/_buildManifest.js +0 -1
  504. sky/dashboard/out/_next/static/chunks/236-2db3ee3fba33dd9e.js +0 -6
  505. sky/dashboard/out/_next/static/chunks/312-c3c8845990db8ffc.js +0 -15
  506. sky/dashboard/out/_next/static/chunks/37-0a572fe0dbb89c4d.js +0 -6
  507. sky/dashboard/out/_next/static/chunks/678-206dddca808e6d16.js +0 -59
  508. sky/dashboard/out/_next/static/chunks/845-9e60713e0c441abc.js +0 -1
  509. sky/dashboard/out/_next/static/chunks/979-7bf73a4c7cea0f5c.js +0 -1
  510. sky/dashboard/out/_next/static/chunks/fd9d1056-2821b0f0cabcd8bd.js +0 -1
  511. sky/dashboard/out/_next/static/chunks/framework-87d061ee6ed71b28.js +0 -33
  512. sky/dashboard/out/_next/static/chunks/main-app-241eb28595532291.js +0 -1
  513. sky/dashboard/out/_next/static/chunks/main-e0e2335212e72357.js +0 -1
  514. sky/dashboard/out/_next/static/chunks/pages/_app-e6b013bc3f77ad60.js +0 -1
  515. sky/dashboard/out/_next/static/chunks/pages/_error-1be831200e60c5c0.js +0 -1
  516. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]/[job]-6ac338bc2239cb45.js +0 -1
  517. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]-f383db7389368ea7.js +0 -1
  518. sky/dashboard/out/_next/static/chunks/pages/clusters-a93b93e10b8b074e.js +0 -1
  519. sky/dashboard/out/_next/static/chunks/pages/index-f9f039532ca8cbc4.js +0 -1
  520. sky/dashboard/out/_next/static/chunks/pages/jobs/[job]-1c519e1afc523dc9.js +0 -1
  521. sky/dashboard/out/_next/static/chunks/pages/jobs-a75029b67aab6a2e.js +0 -1
  522. sky/dashboard/out/_next/static/chunks/webpack-830f59b8404e96b8.js +0 -1
  523. sky/dashboard/out/_next/static/css/c6933bbb2ce7f4dd.css +0 -3
  524. sky/jobs/dashboard/dashboard.py +0 -223
  525. sky/jobs/dashboard/static/favicon.ico +0 -0
  526. sky/jobs/dashboard/templates/index.html +0 -831
  527. sky/jobs/server/dashboard_utils.py +0 -69
  528. sky/skylet/providers/scp/__init__.py +0 -2
  529. sky/skylet/providers/scp/config.py +0 -149
  530. sky/skylet/providers/scp/node_provider.py +0 -578
  531. sky/templates/kubernetes-ssh-jump.yml.j2 +0 -94
  532. sky/utils/db_utils.py +0 -100
  533. sky/utils/kubernetes/deploy_remote_cluster.sh +0 -308
  534. sky/utils/kubernetes/ssh_jump_lifecycle_manager.py +0 -191
  535. skypilot_nightly-1.0.0.dev20250502.dist-info/METADATA +0 -361
  536. skypilot_nightly-1.0.0.dev20250502.dist-info/RECORD +0 -396
  537. skypilot_nightly-1.0.0.dev20250502.dist-info/top_level.txt +0 -1
  538. /sky/{clouds/service_catalog → catalog}/config.py +0 -0
  539. /sky/{benchmark → catalog/data_fetchers}/__init__.py +0 -0
  540. /sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_azure.py +0 -0
  541. /sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_fluidstack.py +0 -0
  542. /sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_ibm.py +0 -0
  543. /sky/{clouds/service_catalog/data_fetchers → client/cli}/__init__.py +0 -0
  544. /sky/dashboard/out/_next/static/{GWvVBSCS7FmUiVmjaL1a7 → 96_E2yl3QAiIJGOYCkSpB}/_ssgManifest.js +0 -0
  545. {skypilot_nightly-1.0.0.dev20250502.dist-info → skypilot_nightly-1.0.0.dev20251203.dist-info}/entry_points.txt +0 -0
  546. {skypilot_nightly-1.0.0.dev20250502.dist-info → skypilot_nightly-1.0.0.dev20251203.dist-info}/licenses/LICENSE +0 -0
sky/jobs/state.py CHANGED
@@ -1,28 +1,60 @@
1
1
  """The database for managed jobs status."""
2
2
  # TODO(zhwu): maybe use file based status instead of database, so
3
3
  # that we can easily switch to a s3-based storage.
4
+ import asyncio
5
+ import collections
4
6
  import enum
7
+ import functools
8
+ import ipaddress
5
9
  import json
6
- import pathlib
7
10
  import sqlite3
11
+ import threading
8
12
  import time
9
13
  import typing
10
- from typing import Any, Callable, Dict, List, Optional, Tuple, Union
14
+ from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Union
15
+ import urllib.parse
11
16
 
12
17
  import colorama
18
+ import sqlalchemy
19
+ from sqlalchemy import exc as sqlalchemy_exc
20
+ from sqlalchemy import orm
21
+ from sqlalchemy.dialects import postgresql
22
+ from sqlalchemy.dialects import sqlite
23
+ from sqlalchemy.ext import asyncio as sql_async
24
+ from sqlalchemy.ext import declarative
13
25
 
14
26
  from sky import exceptions
15
27
  from sky import sky_logging
28
+ from sky import skypilot_config
29
+ from sky.adaptors import common as adaptors_common
30
+ from sky.skylet import constants
16
31
  from sky.utils import common_utils
17
- from sky.utils import db_utils
32
+ from sky.utils import context_utils
33
+ from sky.utils.db import db_utils
34
+ from sky.utils.db import migration_utils
18
35
 
19
36
  if typing.TYPE_CHECKING:
20
- import sky
37
+ from sqlalchemy.engine import row
21
38
 
22
- CallbackType = Callable[[str], None]
39
+ from sky.schemas.generated import managed_jobsv1_pb2
40
+ else:
41
+ managed_jobsv1_pb2 = adaptors_common.LazyImport(
42
+ 'sky.schemas.generated.managed_jobsv1_pb2')
43
+
44
+ # Separate callback types for sync and async contexts
45
+ SyncCallbackType = Callable[[str], None]
46
+ AsyncCallbackType = Callable[[str], Awaitable[Any]]
47
+ CallbackType = Union[SyncCallbackType, AsyncCallbackType]
23
48
 
24
49
  logger = sky_logging.init_logger(__name__)
25
50
 
51
+ _SQLALCHEMY_ENGINE: Optional[sqlalchemy.engine.Engine] = None
52
+ _SQLALCHEMY_ENGINE_ASYNC: Optional[sql_async.AsyncEngine] = None
53
+ _SQLALCHEMY_ENGINE_LOCK = threading.Lock()
54
+
55
+ _DB_RETRY_TIMES = 30
56
+
57
+ Base = declarative.declarative_base()
26
58
 
27
59
  # === Database schema ===
28
60
  # `spot` table contains all the finest-grained tasks, including all the
@@ -35,122 +67,255 @@ logger = sky_logging.init_logger(__name__)
35
67
  # identifier/primary key for all the tasks. We will use `spot_job_id`
36
68
  # to identify the job.
37
69
  # TODO(zhwu): schema migration may be needed.
38
- def create_table(cursor, conn):
70
+
71
+ spot_table = sqlalchemy.Table(
72
+ 'spot',
73
+ Base.metadata,
74
+ sqlalchemy.Column('job_id',
75
+ sqlalchemy.Integer,
76
+ primary_key=True,
77
+ autoincrement=True),
78
+ sqlalchemy.Column('job_name', sqlalchemy.Text),
79
+ sqlalchemy.Column('resources', sqlalchemy.Text),
80
+ sqlalchemy.Column('submitted_at', sqlalchemy.Float),
81
+ sqlalchemy.Column('status', sqlalchemy.Text),
82
+ sqlalchemy.Column('run_timestamp', sqlalchemy.Text),
83
+ sqlalchemy.Column('start_at', sqlalchemy.Float, server_default=None),
84
+ sqlalchemy.Column('end_at', sqlalchemy.Float, server_default=None),
85
+ sqlalchemy.Column('last_recovered_at',
86
+ sqlalchemy.Float,
87
+ server_default='-1'),
88
+ sqlalchemy.Column('recovery_count', sqlalchemy.Integer, server_default='0'),
89
+ sqlalchemy.Column('job_duration', sqlalchemy.Float, server_default='0'),
90
+ sqlalchemy.Column('failure_reason', sqlalchemy.Text),
91
+ sqlalchemy.Column('spot_job_id', sqlalchemy.Integer, index=True),
92
+ sqlalchemy.Column('task_id', sqlalchemy.Integer, server_default='0'),
93
+ sqlalchemy.Column('task_name', sqlalchemy.Text),
94
+ sqlalchemy.Column('specs', sqlalchemy.Text),
95
+ sqlalchemy.Column('local_log_file', sqlalchemy.Text, server_default=None),
96
+ sqlalchemy.Column('metadata', sqlalchemy.Text, server_default='{}'),
97
+ sqlalchemy.Column('logs_cleaned_at', sqlalchemy.Float, server_default=None),
98
+ )
99
+
100
+ job_info_table = sqlalchemy.Table(
101
+ 'job_info',
102
+ Base.metadata,
103
+ sqlalchemy.Column('spot_job_id',
104
+ sqlalchemy.Integer,
105
+ primary_key=True,
106
+ autoincrement=True),
107
+ sqlalchemy.Column('name', sqlalchemy.Text),
108
+ sqlalchemy.Column('schedule_state', sqlalchemy.Text),
109
+ sqlalchemy.Column('controller_pid', sqlalchemy.Integer,
110
+ server_default=None),
111
+ sqlalchemy.Column('controller_pid_started_at',
112
+ sqlalchemy.Float,
113
+ server_default=None),
114
+ sqlalchemy.Column('dag_yaml_path', sqlalchemy.Text),
115
+ sqlalchemy.Column('env_file_path', sqlalchemy.Text),
116
+ sqlalchemy.Column('dag_yaml_content', sqlalchemy.Text, server_default=None),
117
+ sqlalchemy.Column('env_file_content', sqlalchemy.Text, server_default=None),
118
+ sqlalchemy.Column('config_file_content',
119
+ sqlalchemy.Text,
120
+ server_default=None),
121
+ sqlalchemy.Column('user_hash', sqlalchemy.Text),
122
+ sqlalchemy.Column('workspace', sqlalchemy.Text, server_default=None),
123
+ sqlalchemy.Column('priority',
124
+ sqlalchemy.Integer,
125
+ server_default=str(constants.DEFAULT_PRIORITY)),
126
+ sqlalchemy.Column('entrypoint', sqlalchemy.Text, server_default=None),
127
+ sqlalchemy.Column('original_user_yaml_path',
128
+ sqlalchemy.Text,
129
+ server_default=None),
130
+ sqlalchemy.Column('original_user_yaml_content',
131
+ sqlalchemy.Text,
132
+ server_default=None),
133
+ sqlalchemy.Column('pool', sqlalchemy.Text, server_default=None),
134
+ sqlalchemy.Column('current_cluster_name',
135
+ sqlalchemy.Text,
136
+ server_default=None),
137
+ sqlalchemy.Column('job_id_on_pool_cluster',
138
+ sqlalchemy.Integer,
139
+ server_default=None),
140
+ sqlalchemy.Column('pool_hash', sqlalchemy.Text, server_default=None),
141
+ sqlalchemy.Column('controller_logs_cleaned_at',
142
+ sqlalchemy.Float,
143
+ server_default=None),
144
+ )
145
+
146
+ # TODO(cooperc): drop the table in a migration
147
+ ha_recovery_script_table = sqlalchemy.Table(
148
+ 'ha_recovery_script',
149
+ Base.metadata,
150
+ sqlalchemy.Column('job_id', sqlalchemy.Integer, primary_key=True),
151
+ sqlalchemy.Column('script', sqlalchemy.Text),
152
+ )
153
+
154
+
155
+ def create_table(engine: sqlalchemy.engine.Engine):
39
156
  # Enable WAL mode to avoid locking issues.
40
157
  # See: issue #3863, #1441 and PR #1509
41
158
  # https://github.com/microsoft/WSL/issues/2395
42
159
  # TODO(romilb): We do not enable WAL for WSL because of known issue in WSL.
43
160
  # This may cause the database locked problem from WSL issue #1441.
44
- if not common_utils.is_wsl():
161
+ if (engine.dialect.name == db_utils.SQLAlchemyDialect.SQLITE.value and
162
+ not common_utils.is_wsl()):
45
163
  try:
46
- cursor.execute('PRAGMA journal_mode=WAL')
47
- except sqlite3.OperationalError as e:
164
+ with orm.Session(engine) as session:
165
+ session.execute(sqlalchemy.text('PRAGMA journal_mode=WAL'))
166
+ session.execute(sqlalchemy.text('PRAGMA synchronous=1'))
167
+ session.commit()
168
+ except sqlalchemy_exc.OperationalError as e:
48
169
  if 'database is locked' not in str(e):
49
170
  raise
50
171
  # If the database is locked, it is OK to continue, as the WAL mode
51
172
  # is not critical and is likely to be enabled by other processes.
52
173
 
53
- cursor.execute("""\
54
- CREATE TABLE IF NOT EXISTS spot (
55
- job_id INTEGER PRIMARY KEY AUTOINCREMENT,
56
- job_name TEXT,
57
- resources TEXT,
58
- submitted_at FLOAT,
59
- status TEXT,
60
- run_timestamp TEXT CANDIDATE KEY,
61
- start_at FLOAT DEFAULT NULL,
62
- end_at FLOAT DEFAULT NULL,
63
- last_recovered_at FLOAT DEFAULT -1,
64
- recovery_count INTEGER DEFAULT 0,
65
- job_duration FLOAT DEFAULT 0,
66
- failure_reason TEXT,
67
- spot_job_id INTEGER,
68
- task_id INTEGER DEFAULT 0,
69
- task_name TEXT,
70
- specs TEXT,
71
- local_log_file TEXT DEFAULT NULL)""")
72
- conn.commit()
73
-
74
- db_utils.add_column_to_table(cursor, conn, 'spot', 'failure_reason', 'TEXT')
75
- # Create a new column `spot_job_id`, which is the same for tasks of the
76
- # same managed job.
77
- # The original `job_id` no longer has an actual meaning, but only a legacy
78
- # identifier for all tasks in database.
79
- db_utils.add_column_to_table(cursor,
80
- conn,
81
- 'spot',
82
- 'spot_job_id',
83
- 'INTEGER',
84
- copy_from='job_id')
85
- db_utils.add_column_to_table(cursor,
86
- conn,
87
- 'spot',
88
- 'task_id',
89
- 'INTEGER DEFAULT 0',
90
- value_to_replace_existing_entries=0)
91
- db_utils.add_column_to_table(cursor,
92
- conn,
93
- 'spot',
94
- 'task_name',
95
- 'TEXT',
96
- copy_from='job_name')
97
-
98
- # Specs is some useful information about the task, e.g., the
99
- # max_restarts_on_errors value. It is stored in JSON format.
100
- db_utils.add_column_to_table(cursor,
101
- conn,
102
- 'spot',
103
- 'specs',
104
- 'TEXT',
105
- value_to_replace_existing_entries=json.dumps({
106
- 'max_restarts_on_errors': 0,
107
- }))
108
- db_utils.add_column_to_table(cursor, conn, 'spot', 'local_log_file',
109
- 'TEXT DEFAULT NULL')
110
-
111
- # `job_info` contains the mapping from job_id to the job_name, as well as
112
- # information used by the scheduler.
113
- cursor.execute("""\
114
- CREATE TABLE IF NOT EXISTS job_info (
115
- spot_job_id INTEGER PRIMARY KEY AUTOINCREMENT,
116
- name TEXT,
117
- schedule_state TEXT,
118
- controller_pid INTEGER DEFAULT NULL,
119
- dag_yaml_path TEXT,
120
- env_file_path TEXT,
121
- user_hash TEXT)""")
122
-
123
- db_utils.add_column_to_table(cursor, conn, 'job_info', 'schedule_state',
124
- 'TEXT')
125
-
126
- db_utils.add_column_to_table(cursor, conn, 'job_info', 'controller_pid',
127
- 'INTEGER DEFAULT NULL')
128
-
129
- db_utils.add_column_to_table(cursor, conn, 'job_info', 'dag_yaml_path',
130
- 'TEXT')
131
-
132
- db_utils.add_column_to_table(cursor, conn, 'job_info', 'env_file_path',
133
- 'TEXT')
134
-
135
- db_utils.add_column_to_table(cursor, conn, 'job_info', 'user_hash', 'TEXT')
136
-
137
- conn.commit()
138
-
139
-
140
- # Module-level connection/cursor; thread-safe as the module is only imported
141
- # once.
142
- def _get_db_path() -> str:
143
- """Workaround to collapse multi-step Path ops for type checker.
144
- Ensures _DB_PATH is str, avoiding Union[Path, str] inference.
145
- """
146
- path = pathlib.Path('~/.sky/spot_jobs.db')
147
- path = path.expanduser().absolute()
148
- path.parents[0].mkdir(parents=True, exist_ok=True)
149
- return str(path)
174
+ migration_utils.safe_alembic_upgrade(engine,
175
+ migration_utils.SPOT_JOBS_DB_NAME,
176
+ migration_utils.SPOT_JOBS_VERSION)
177
+
150
178
 
179
+ def force_no_postgres() -> bool:
180
+ """Force no postgres.
181
+
182
+ If the db is localhost on the api server, and we are not in consolidation
183
+ mode, we must force using sqlite and not using the api server on the jobs
184
+ controller.
185
+ """
186
+ conn_string = skypilot_config.get_nested(('db',), None)
187
+
188
+ if conn_string:
189
+ parsed = urllib.parse.urlparse(conn_string)
190
+ # it freezes if we use the normal get_consolidation_mode function
191
+ consolidation_mode = skypilot_config.get_nested(
192
+ ('jobs', 'controller', 'consolidation_mode'), default_value=False)
193
+ if ((parsed.hostname == 'localhost' or
194
+ ipaddress.ip_address(parsed.hostname).is_loopback) and
195
+ not consolidation_mode):
196
+ return True
197
+ return False
198
+
199
+
200
+ def initialize_and_get_db_async() -> sql_async.AsyncEngine:
201
+ global _SQLALCHEMY_ENGINE_ASYNC
202
+ if _SQLALCHEMY_ENGINE_ASYNC is not None:
203
+ return _SQLALCHEMY_ENGINE_ASYNC
204
+ with _SQLALCHEMY_ENGINE_LOCK:
205
+ if _SQLALCHEMY_ENGINE_ASYNC is not None:
206
+ return _SQLALCHEMY_ENGINE_ASYNC
207
+
208
+ _SQLALCHEMY_ENGINE_ASYNC = db_utils.get_engine('spot_jobs',
209
+ async_engine=True)
210
+
211
+ # to create the table in case an async function gets called first
212
+ initialize_and_get_db()
213
+ return _SQLALCHEMY_ENGINE_ASYNC
214
+
215
+
216
+ # We wrap the sqlalchemy engine initialization in a thread
217
+ # lock to ensure that multiple threads do not initialize the
218
+ # engine which could result in a rare race condition where
219
+ # a session has already been created with _SQLALCHEMY_ENGINE = e1,
220
+ # and then another thread overwrites _SQLALCHEMY_ENGINE = e2
221
+ # which could result in e1 being garbage collected unexpectedly.
222
+ def initialize_and_get_db() -> sqlalchemy.engine.Engine:
223
+ global _SQLALCHEMY_ENGINE
224
+ if _SQLALCHEMY_ENGINE is not None:
225
+ return _SQLALCHEMY_ENGINE
226
+
227
+ with _SQLALCHEMY_ENGINE_LOCK:
228
+ if _SQLALCHEMY_ENGINE is not None:
229
+ return _SQLALCHEMY_ENGINE
230
+ # get an engine to the db
231
+ engine = db_utils.get_engine('spot_jobs')
232
+
233
+ # run migrations if needed
234
+ create_table(engine)
235
+
236
+ # return engine
237
+ _SQLALCHEMY_ENGINE = engine
238
+ return _SQLALCHEMY_ENGINE
239
+
240
+
241
+ def _init_db_async(func):
242
+ """Initialize the async database. Add backoff to the function call."""
243
+
244
+ @functools.wraps(func)
245
+ async def wrapper(*args, **kwargs):
246
+ if _SQLALCHEMY_ENGINE_ASYNC is None:
247
+ # this may happen multiple times since there is no locking
248
+ # here but that's fine, this is just a short circuit for the
249
+ # common case.
250
+ await context_utils.to_thread(initialize_and_get_db_async)
251
+
252
+ backoff = common_utils.Backoff(initial_backoff=1, max_backoff_factor=5)
253
+ last_exc = None
254
+ for _ in range(_DB_RETRY_TIMES):
255
+ try:
256
+ return await func(*args, **kwargs)
257
+ except (sqlalchemy_exc.OperationalError,
258
+ asyncio.exceptions.TimeoutError, OSError,
259
+ sqlalchemy_exc.TimeoutError, sqlite3.OperationalError,
260
+ sqlalchemy_exc.InterfaceError, sqlite3.InterfaceError) as e:
261
+ last_exc = e
262
+ logger.debug(f'DB error: {last_exc}')
263
+ await asyncio.sleep(backoff.current_backoff())
264
+ assert last_exc is not None
265
+ raise last_exc
266
+
267
+ return wrapper
268
+
269
+
270
+ def _init_db(func):
271
+ """Initialize the database. Add backoff to the function call."""
272
+
273
+ @functools.wraps(func)
274
+ def wrapper(*args, **kwargs):
275
+ if _SQLALCHEMY_ENGINE is None:
276
+ # this may happen multiple times since there is no locking
277
+ # here but that's fine, this is just a short circuit for the
278
+ # common case.
279
+ initialize_and_get_db()
280
+
281
+ backoff = common_utils.Backoff(initial_backoff=1, max_backoff_factor=10)
282
+ last_exc = None
283
+ for _ in range(_DB_RETRY_TIMES):
284
+ try:
285
+ return func(*args, **kwargs)
286
+ except (sqlalchemy_exc.OperationalError,
287
+ asyncio.exceptions.TimeoutError, OSError,
288
+ sqlalchemy_exc.TimeoutError, sqlite3.OperationalError,
289
+ sqlalchemy_exc.InterfaceError, sqlite3.InterfaceError) as e:
290
+ last_exc = e
291
+ logger.debug(f'DB error: {last_exc}')
292
+ time.sleep(backoff.current_backoff())
293
+ assert last_exc is not None
294
+ raise last_exc
295
+
296
+ return wrapper
297
+
298
+
299
+ async def _describe_task_transition_failure(session: sql_async.AsyncSession,
300
+ job_id: int, task_id: int) -> str:
301
+ """Return a human-readable description when a task transition fails."""
302
+ details = 'Couldn\'t fetch the task details.'
303
+ try:
304
+ debug_result = await session.execute(
305
+ sqlalchemy.select(spot_table.c.status, spot_table.c.end_at).where(
306
+ sqlalchemy.and_(spot_table.c.spot_job_id == job_id,
307
+ spot_table.c.task_id == task_id)))
308
+ rows = debug_result.mappings().all()
309
+ details = (f'{len(rows)} rows matched job {job_id} and task '
310
+ f'{task_id}.')
311
+ for row in rows:
312
+ status = row['status']
313
+ end_at = row['end_at']
314
+ details += f' Status: {status}, End time: {end_at}.'
315
+ except Exception as exc: # pylint: disable=broad-except
316
+ details += f' Error fetching task details: {exc}'
317
+ return details
151
318
 
152
- _DB_PATH = _get_db_path()
153
- db_utils.SQLiteConn(_DB_PATH, create_table)
154
319
 
155
320
  # job_duration is the time a job actually runs (including the
156
321
  # setup duration) before last_recover, excluding the provision
@@ -164,33 +329,54 @@ db_utils.SQLiteConn(_DB_PATH, create_table)
164
329
  # e.g., via sky jobs queue. These may not correspond to actual
165
330
  # column names in the DB and it corresponds to the combined view
166
331
  # by joining the spot and job_info tables.
167
- columns = [
168
- '_job_id',
169
- '_task_name',
170
- 'resources',
171
- 'submitted_at',
172
- 'status',
173
- 'run_timestamp',
174
- 'start_at',
175
- 'end_at',
176
- 'last_recovered_at',
177
- 'recovery_count',
178
- 'job_duration',
179
- 'failure_reason',
180
- 'job_id',
181
- 'task_id',
182
- 'task_name',
183
- 'specs',
184
- 'local_log_file',
185
- # columns from the job_info table
186
- '_job_info_job_id', # This should be the same as job_id
187
- 'job_name',
188
- 'schedule_state',
189
- 'controller_pid',
190
- 'dag_yaml_path',
191
- 'env_file_path',
192
- 'user_hash',
193
- ]
332
+ def _get_jobs_dict(r: 'row.RowMapping') -> Dict[str, Any]:
333
+ # WARNING: If you update these you may also need to update GetJobTable in
334
+ # the skylet ManagedJobsServiceImpl.
335
+ return {
336
+ '_job_id': r.get('job_id'), # from spot table
337
+ '_task_name': r.get('job_name'), # deprecated, from spot table
338
+ 'resources': r.get('resources'),
339
+ 'submitted_at': r.get('submitted_at'),
340
+ 'status': r.get('status'),
341
+ 'run_timestamp': r.get('run_timestamp'),
342
+ 'start_at': r.get('start_at'),
343
+ 'end_at': r.get('end_at'),
344
+ 'last_recovered_at': r.get('last_recovered_at'),
345
+ 'recovery_count': r.get('recovery_count'),
346
+ 'job_duration': r.get('job_duration'),
347
+ 'failure_reason': r.get('failure_reason'),
348
+ 'job_id': r.get(spot_table.c.spot_job_id
349
+ ), # ambiguous, use table.column
350
+ 'task_id': r.get('task_id'),
351
+ 'task_name': r.get('task_name'),
352
+ 'specs': r.get('specs'),
353
+ 'local_log_file': r.get('local_log_file'),
354
+ 'metadata': r.get('metadata'),
355
+ # columns from job_info table (some may be None for legacy jobs)
356
+ '_job_info_job_id': r.get(job_info_table.c.spot_job_id
357
+ ), # ambiguous, use table.column
358
+ 'job_name': r.get('name'), # from job_info table
359
+ 'schedule_state': r.get('schedule_state'),
360
+ 'controller_pid': r.get('controller_pid'),
361
+ 'controller_pid_started_at': r.get('controller_pid_started_at'),
362
+ # the _path columns are for backwards compatibility, use the _content
363
+ # columns instead
364
+ 'dag_yaml_path': r.get('dag_yaml_path'),
365
+ 'env_file_path': r.get('env_file_path'),
366
+ 'dag_yaml_content': r.get('dag_yaml_content'),
367
+ 'env_file_content': r.get('env_file_content'),
368
+ 'config_file_content': r.get('config_file_content'),
369
+ 'user_hash': r.get('user_hash'),
370
+ 'workspace': r.get('workspace'),
371
+ 'priority': r.get('priority'),
372
+ 'entrypoint': r.get('entrypoint'),
373
+ 'original_user_yaml_path': r.get('original_user_yaml_path'),
374
+ 'original_user_yaml_content': r.get('original_user_yaml_content'),
375
+ 'pool': r.get('pool'),
376
+ 'current_cluster_name': r.get('current_cluster_name'),
377
+ 'job_id_on_pool_cluster': r.get('job_id_on_pool_cluster'),
378
+ 'pool_hash': r.get('pool_hash'),
379
+ }
194
380
 
195
381
 
196
382
  class ManagedJobStatus(enum.Enum):
@@ -206,7 +392,7 @@ class ManagedJobStatus(enum.Enum):
206
392
  reset to INIT or SETTING_UP multiple times (depending on the preemptions).
207
393
 
208
394
  However, a managed job only has one ManagedJobStatus on the jobs controller.
209
- ManagedJobStatus = [PENDING, SUBMITTED, STARTING, RUNNING, ...]
395
+ ManagedJobStatus = [PENDING, STARTING, RUNNING, ...]
210
396
  Mapping from JobStatus to ManagedJobStatus:
211
397
  INIT -> STARTING/RECOVERING
212
398
  SETTING_UP -> RUNNING
@@ -226,10 +412,14 @@ class ManagedJobStatus(enum.Enum):
226
412
  # PENDING: Waiting for the jobs controller to have a slot to run the
227
413
  # controller process.
228
414
  PENDING = 'PENDING'
415
+ # SUBMITTED: This state used to be briefly set before immediately changing
416
+ # to STARTING. Its use was removed in #5682. We keep it for backwards
417
+ # compatibility, so we can still parse old jobs databases that may have jobs
418
+ # in this state.
419
+ # TODO(cooperc): remove this in v0.12.0
420
+ DEPRECATED_SUBMITTED = 'SUBMITTED'
229
421
  # The submitted_at timestamp of the managed job in the 'spot' table will be
230
422
  # set to the time when the job controller begins running.
231
- # SUBMITTED: The jobs controller starts the controller process.
232
- SUBMITTED = 'SUBMITTED'
233
423
  # STARTING: The controller process is launching the cluster for the managed
234
424
  # job.
235
425
  STARTING = 'STARTING'
@@ -302,10 +492,88 @@ class ManagedJobStatus(enum.Enum):
302
492
  cls.FAILED_NO_RESOURCE, cls.FAILED_CONTROLLER
303
493
  ]
304
494
 
495
+ @classmethod
496
+ def processing_statuses(cls) -> List['ManagedJobStatus']:
497
+ # Any status that is not terminal and is not CANCELLING.
498
+ return [
499
+ cls.PENDING,
500
+ cls.STARTING,
501
+ cls.RUNNING,
502
+ cls.RECOVERING,
503
+ ]
504
+
505
+ @classmethod
506
+ def from_protobuf(
507
+ cls, protobuf_value: 'managed_jobsv1_pb2.ManagedJobStatus'
508
+ ) -> Optional['ManagedJobStatus']:
509
+ """Convert protobuf ManagedJobStatus enum to Python enum value."""
510
+ protobuf_to_enum = {
511
+ managed_jobsv1_pb2.MANAGED_JOB_STATUS_UNSPECIFIED: None,
512
+ managed_jobsv1_pb2.MANAGED_JOB_STATUS_PENDING: cls.PENDING,
513
+ managed_jobsv1_pb2.MANAGED_JOB_STATUS_SUBMITTED:
514
+ cls.DEPRECATED_SUBMITTED,
515
+ managed_jobsv1_pb2.MANAGED_JOB_STATUS_STARTING: cls.STARTING,
516
+ managed_jobsv1_pb2.MANAGED_JOB_STATUS_RUNNING: cls.RUNNING,
517
+ managed_jobsv1_pb2.MANAGED_JOB_STATUS_SUCCEEDED: cls.SUCCEEDED,
518
+ managed_jobsv1_pb2.MANAGED_JOB_STATUS_FAILED: cls.FAILED,
519
+ managed_jobsv1_pb2.MANAGED_JOB_STATUS_FAILED_CONTROLLER:
520
+ cls.FAILED_CONTROLLER,
521
+ managed_jobsv1_pb2.MANAGED_JOB_STATUS_FAILED_SETUP:
522
+ cls.FAILED_SETUP,
523
+ managed_jobsv1_pb2.MANAGED_JOB_STATUS_CANCELLED: cls.CANCELLED,
524
+ managed_jobsv1_pb2.MANAGED_JOB_STATUS_RECOVERING: cls.RECOVERING,
525
+ managed_jobsv1_pb2.MANAGED_JOB_STATUS_CANCELLING: cls.CANCELLING,
526
+ managed_jobsv1_pb2.MANAGED_JOB_STATUS_FAILED_PRECHECKS:
527
+ cls.FAILED_PRECHECKS,
528
+ managed_jobsv1_pb2.MANAGED_JOB_STATUS_FAILED_NO_RESOURCE:
529
+ cls.FAILED_NO_RESOURCE,
530
+ }
531
+
532
+ if protobuf_value not in protobuf_to_enum:
533
+ raise ValueError(
534
+ f'Unknown protobuf ManagedJobStatus value: {protobuf_value}')
535
+
536
+ return protobuf_to_enum[protobuf_value]
537
+
538
+ def to_protobuf(self) -> 'managed_jobsv1_pb2.ManagedJobStatus':
539
+ """Convert this Python enum value to protobuf enum value."""
540
+ enum_to_protobuf = {
541
+ ManagedJobStatus.PENDING:
542
+ managed_jobsv1_pb2.MANAGED_JOB_STATUS_PENDING,
543
+ ManagedJobStatus.DEPRECATED_SUBMITTED:
544
+ managed_jobsv1_pb2.MANAGED_JOB_STATUS_SUBMITTED,
545
+ ManagedJobStatus.STARTING:
546
+ managed_jobsv1_pb2.MANAGED_JOB_STATUS_STARTING,
547
+ ManagedJobStatus.RUNNING:
548
+ managed_jobsv1_pb2.MANAGED_JOB_STATUS_RUNNING,
549
+ ManagedJobStatus.SUCCEEDED:
550
+ managed_jobsv1_pb2.MANAGED_JOB_STATUS_SUCCEEDED,
551
+ ManagedJobStatus.FAILED:
552
+ managed_jobsv1_pb2.MANAGED_JOB_STATUS_FAILED,
553
+ ManagedJobStatus.FAILED_CONTROLLER:
554
+ managed_jobsv1_pb2.MANAGED_JOB_STATUS_FAILED_CONTROLLER,
555
+ ManagedJobStatus.FAILED_SETUP:
556
+ managed_jobsv1_pb2.MANAGED_JOB_STATUS_FAILED_SETUP,
557
+ ManagedJobStatus.CANCELLED:
558
+ managed_jobsv1_pb2.MANAGED_JOB_STATUS_CANCELLED,
559
+ ManagedJobStatus.RECOVERING:
560
+ managed_jobsv1_pb2.MANAGED_JOB_STATUS_RECOVERING,
561
+ ManagedJobStatus.CANCELLING:
562
+ managed_jobsv1_pb2.MANAGED_JOB_STATUS_CANCELLING,
563
+ ManagedJobStatus.FAILED_PRECHECKS:
564
+ managed_jobsv1_pb2.MANAGED_JOB_STATUS_FAILED_PRECHECKS,
565
+ ManagedJobStatus.FAILED_NO_RESOURCE:
566
+ managed_jobsv1_pb2.MANAGED_JOB_STATUS_FAILED_NO_RESOURCE,
567
+ }
568
+
569
+ if self not in enum_to_protobuf:
570
+ raise ValueError(f'Unknown ManagedJobStatus value: {self}')
571
+
572
+ return enum_to_protobuf[self]
573
+
305
574
 
306
575
  _SPOT_STATUS_TO_COLOR = {
307
576
  ManagedJobStatus.PENDING: colorama.Fore.BLUE,
308
- ManagedJobStatus.SUBMITTED: colorama.Fore.BLUE,
309
577
  ManagedJobStatus.STARTING: colorama.Fore.BLUE,
310
578
  ManagedJobStatus.RUNNING: colorama.Fore.GREEN,
311
579
  ManagedJobStatus.RECOVERING: colorama.Fore.CYAN,
@@ -317,14 +585,14 @@ _SPOT_STATUS_TO_COLOR = {
317
585
  ManagedJobStatus.FAILED_CONTROLLER: colorama.Fore.RED,
318
586
  ManagedJobStatus.CANCELLING: colorama.Fore.YELLOW,
319
587
  ManagedJobStatus.CANCELLED: colorama.Fore.YELLOW,
588
+ # TODO(cooperc): backwards compatibility, remove this in v0.12.0
589
+ ManagedJobStatus.DEPRECATED_SUBMITTED: colorama.Fore.BLUE,
320
590
  }
321
591
 
322
592
 
323
593
  class ManagedJobScheduleState(enum.Enum):
324
594
  """Captures the state of the job from the scheduler's perspective.
325
595
 
326
- A job that predates the introduction of the scheduler will be INVALID.
327
-
328
596
  A newly created job will be INACTIVE. The following transitions are valid:
329
597
  - INACTIVE -> WAITING: The job is "submitted" to the scheduler, and its job
330
598
  controller can be started.
@@ -333,8 +601,12 @@ class ManagedJobScheduleState(enum.Enum):
333
601
  - LAUNCHING -> ALIVE: The launch attempt was completed. It may have
334
602
  succeeded or failed. The job controller is not allowed to sky.launch again
335
603
  without transitioning to ALIVE_WAITING and then LAUNCHING.
604
+ - LAUNCHING -> ALIVE_BACKOFF: The launch failed to find resources, and is
605
+ in backoff waiting for resources.
336
606
  - ALIVE -> ALIVE_WAITING: The job controller wants to sky.launch again,
337
607
  either for recovery or to launch a subsequent task.
608
+ - ALIVE_BACKOFF -> ALIVE_WAITING: The backoff period has ended, and the job
609
+ controller wants to try to launch again.
338
610
  - ALIVE_WAITING -> LAUNCHING: The scheduler has determined that the job
339
611
  controller may launch again.
340
612
  - LAUNCHING, ALIVE, or ALIVE_WAITING -> DONE: The job controller is exiting
@@ -348,6 +620,7 @@ class ManagedJobScheduleState(enum.Enum):
348
620
  state or vice versa. (In fact, schedule state is defined on the job and
349
621
  status on the task.)
350
622
  - INACTIVE or WAITING should only be seen when a job is PENDING.
623
+ - ALIVE_BACKOFF should only be seen when a job is STARTING.
351
624
  - ALIVE_WAITING should only be seen when a job is RECOVERING, has multiple
352
625
  tasks, or needs to retry launching.
353
626
  - LAUNCHING and ALIVE can be seen in many different statuses.
@@ -356,10 +629,10 @@ class ManagedJobScheduleState(enum.Enum):
356
629
  briefly observe inconsistent states, like a job that just finished but
357
630
  hasn't yet transitioned to DONE.
358
631
  """
359
- # This job may have been created before scheduler was introduced in #4458.
360
- # This state is not used by scheduler but just for backward compatibility.
361
- # TODO(cooperc): remove this in v0.11.0
362
- INVALID = None
632
+ # TODO(luca): the only states we need are INACTIVE, WAITING, ALIVE, and
633
+ # DONE. ALIVE = old LAUNCHING + ALIVE + ALIVE_BACKOFF + ALIVE_WAITING and
634
+ # will represent jobs that are claimed by a controller. Delete the rest
635
+ # in v0.13.0
363
636
  # The job should be ignored by the scheduler.
364
637
  INACTIVE = 'INACTIVE'
365
638
  # The job is waiting to transition to LAUNCHING for the first time. The
@@ -373,194 +646,211 @@ class ManagedJobScheduleState(enum.Enum):
373
646
  # The job is running sky.launch, or soon will, using a limited number of
374
647
  # allowed launch slots.
375
648
  LAUNCHING = 'LAUNCHING'
649
+ # The job is alive, but is in backoff waiting for resources - a special case
650
+ # of ALIVE.
651
+ ALIVE_BACKOFF = 'ALIVE_BACKOFF'
376
652
  # The controller for the job is running, but it's not currently launching.
377
653
  ALIVE = 'ALIVE'
378
654
  # The job is in a terminal state. (Not necessarily SUCCEEDED.)
379
655
  DONE = 'DONE'
380
656
 
657
+ @classmethod
658
+ def from_protobuf(
659
+ cls, protobuf_value: 'managed_jobsv1_pb2.ManagedJobScheduleState'
660
+ ) -> Optional['ManagedJobScheduleState']:
661
+ """Convert protobuf ManagedJobScheduleState enum to Python enum value.
662
+ """
663
+ protobuf_to_enum = {
664
+ managed_jobsv1_pb2.MANAGED_JOB_SCHEDULE_STATE_UNSPECIFIED: None,
665
+ # TODO(cooperc): remove this in v0.13.0. See #8105.
666
+ managed_jobsv1_pb2.DEPRECATED_MANAGED_JOB_SCHEDULE_STATE_INVALID:
667
+ (None),
668
+ managed_jobsv1_pb2.MANAGED_JOB_SCHEDULE_STATE_INACTIVE:
669
+ cls.INACTIVE,
670
+ managed_jobsv1_pb2.MANAGED_JOB_SCHEDULE_STATE_WAITING: cls.WAITING,
671
+ managed_jobsv1_pb2.MANAGED_JOB_SCHEDULE_STATE_ALIVE_WAITING:
672
+ cls.ALIVE_WAITING,
673
+ managed_jobsv1_pb2.MANAGED_JOB_SCHEDULE_STATE_LAUNCHING:
674
+ cls.LAUNCHING,
675
+ managed_jobsv1_pb2.MANAGED_JOB_SCHEDULE_STATE_ALIVE_BACKOFF:
676
+ cls.ALIVE_BACKOFF,
677
+ managed_jobsv1_pb2.MANAGED_JOB_SCHEDULE_STATE_ALIVE: cls.ALIVE,
678
+ managed_jobsv1_pb2.MANAGED_JOB_SCHEDULE_STATE_DONE: cls.DONE,
679
+ }
680
+
681
+ if protobuf_value not in protobuf_to_enum:
682
+ raise ValueError('Unknown protobuf ManagedJobScheduleState value: '
683
+ f'{protobuf_value}')
684
+
685
+ return protobuf_to_enum[protobuf_value]
686
+
687
+ def to_protobuf(self) -> 'managed_jobsv1_pb2.ManagedJobScheduleState':
688
+ """Convert this Python enum value to protobuf enum value."""
689
+ enum_to_protobuf = {
690
+ ManagedJobScheduleState.INACTIVE:
691
+ managed_jobsv1_pb2.MANAGED_JOB_SCHEDULE_STATE_INACTIVE,
692
+ ManagedJobScheduleState.WAITING:
693
+ managed_jobsv1_pb2.MANAGED_JOB_SCHEDULE_STATE_WAITING,
694
+ ManagedJobScheduleState.ALIVE_WAITING:
695
+ managed_jobsv1_pb2.MANAGED_JOB_SCHEDULE_STATE_ALIVE_WAITING,
696
+ ManagedJobScheduleState.LAUNCHING:
697
+ managed_jobsv1_pb2.MANAGED_JOB_SCHEDULE_STATE_LAUNCHING,
698
+ ManagedJobScheduleState.ALIVE_BACKOFF:
699
+ managed_jobsv1_pb2.MANAGED_JOB_SCHEDULE_STATE_ALIVE_BACKOFF,
700
+ ManagedJobScheduleState.ALIVE:
701
+ managed_jobsv1_pb2.MANAGED_JOB_SCHEDULE_STATE_ALIVE,
702
+ ManagedJobScheduleState.DONE:
703
+ managed_jobsv1_pb2.MANAGED_JOB_SCHEDULE_STATE_DONE,
704
+ }
705
+
706
+ if self not in enum_to_protobuf:
707
+ raise ValueError(f'Unknown ManagedJobScheduleState value: {self}')
708
+
709
+ return enum_to_protobuf[self]
710
+
711
+
712
+ ControllerPidRecord = collections.namedtuple('ControllerPidRecord', [
713
+ 'pid',
714
+ 'started_at',
715
+ ])
716
+
381
717
 
382
718
  # === Status transition functions ===
383
- def set_job_info(job_id: int, name: str):
384
- with db_utils.safe_cursor(_DB_PATH) as cursor:
385
- cursor.execute(
386
- """\
387
- INSERT INTO job_info
388
- (spot_job_id, name, schedule_state)
389
- VALUES (?, ?, ?)""",
390
- (job_id, name, ManagedJobScheduleState.INACTIVE.value))
719
+ @_init_db
720
+ def set_job_info_without_job_id(name: str, workspace: str, entrypoint: str,
721
+ pool: Optional[str], pool_hash: Optional[str],
722
+ user_hash: Optional[str]) -> int:
723
+ assert _SQLALCHEMY_ENGINE is not None
724
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
725
+ if (_SQLALCHEMY_ENGINE.dialect.name ==
726
+ db_utils.SQLAlchemyDialect.SQLITE.value):
727
+ insert_func = sqlite.insert
728
+ elif (_SQLALCHEMY_ENGINE.dialect.name ==
729
+ db_utils.SQLAlchemyDialect.POSTGRESQL.value):
730
+ insert_func = postgresql.insert
731
+ else:
732
+ raise ValueError('Unsupported database dialect')
733
+
734
+ insert_stmt = insert_func(job_info_table).values(
735
+ name=name,
736
+ schedule_state=ManagedJobScheduleState.INACTIVE.value,
737
+ workspace=workspace,
738
+ entrypoint=entrypoint,
739
+ pool=pool,
740
+ pool_hash=pool_hash,
741
+ user_hash=user_hash,
742
+ )
743
+
744
+ if (_SQLALCHEMY_ENGINE.dialect.name ==
745
+ db_utils.SQLAlchemyDialect.SQLITE.value):
746
+ result = session.execute(insert_stmt)
747
+ ret = result.lastrowid
748
+ session.commit()
749
+ return ret
750
+ elif (_SQLALCHEMY_ENGINE.dialect.name ==
751
+ db_utils.SQLAlchemyDialect.POSTGRESQL.value):
752
+ result = session.execute(
753
+ insert_stmt.returning(job_info_table.c.spot_job_id))
754
+ ret = result.scalar()
755
+ session.commit()
756
+ return ret
757
+ else:
758
+ raise ValueError('Unsupported database dialect')
391
759
 
392
760
 
393
- def set_pending(job_id: int, task_id: int, task_name: str, resources_str: str):
761
+ @_init_db
762
+ def set_pending(
763
+ job_id: int,
764
+ task_id: int,
765
+ task_name: str,
766
+ resources_str: str,
767
+ metadata: str,
768
+ ):
394
769
  """Set the task to pending state."""
395
- with db_utils.safe_cursor(_DB_PATH) as cursor:
396
- cursor.execute(
397
- """\
398
- INSERT INTO spot
399
- (spot_job_id, task_id, task_name, resources, status)
400
- VALUES (?, ?, ?, ?, ?)""",
401
- (job_id, task_id, task_name, resources_str,
402
- ManagedJobStatus.PENDING.value))
403
-
404
-
405
- def set_submitted(job_id: int, task_id: int, run_timestamp: str,
406
- submit_time: float, resources_str: str,
407
- specs: Dict[str, Union[str,
408
- int]], callback_func: CallbackType):
409
- """Set the task to submitted.
410
-
411
- Args:
412
- job_id: The managed job ID.
413
- task_id: The task ID.
414
- run_timestamp: The run_timestamp of the run. This will be used to
415
- determine the log directory of the managed task.
416
- submit_time: The time when the managed task is submitted.
417
- resources_str: The resources string of the managed task.
418
- specs: The specs of the managed task.
419
- callback_func: The callback function.
770
+ assert _SQLALCHEMY_ENGINE is not None
771
+
772
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
773
+ session.execute(
774
+ sqlalchemy.insert(spot_table).values(
775
+ spot_job_id=job_id,
776
+ task_id=task_id,
777
+ task_name=task_name,
778
+ resources=resources_str,
779
+ metadata=metadata,
780
+ status=ManagedJobStatus.PENDING.value,
781
+ ))
782
+ session.commit()
783
+
784
+
785
+ @_init_db_async
786
+ async def set_backoff_pending_async(job_id: int, task_id: int):
787
+ """Set the task to PENDING state if it is in backoff.
788
+
789
+ This should only be used to transition from STARTING or RECOVERING back to
790
+ PENDING.
420
791
  """
421
- # Use the timestamp in the `run_timestamp` ('sky-2022-10...'), to make
422
- # the log directory and submission time align with each other, so as to
423
- # make it easier to find them based on one of the values.
424
- # Also, using the earlier timestamp should be closer to the term
425
- # `submit_at`, which represents the time the managed task is submitted.
426
- with db_utils.safe_cursor(_DB_PATH) as cursor:
427
- cursor.execute(
428
- """\
429
- UPDATE spot SET
430
- resources=(?),
431
- submitted_at=(?),
432
- status=(?),
433
- run_timestamp=(?),
434
- specs=(?)
435
- WHERE spot_job_id=(?) AND
436
- task_id=(?) AND
437
- status=(?) AND
438
- end_at IS null""",
439
- (resources_str, submit_time, ManagedJobStatus.SUBMITTED.value,
440
- run_timestamp, json.dumps(specs), job_id, task_id,
441
- ManagedJobStatus.PENDING.value))
442
- if cursor.rowcount != 1:
443
- raise exceptions.ManagedJobStatusError(
444
- f'Failed to set the task to submitted. '
445
- f'({cursor.rowcount} rows updated)')
446
- callback_func('SUBMITTED')
447
-
448
-
449
- def set_starting(job_id: int, task_id: int, callback_func: CallbackType):
450
- """Set the task to starting state."""
451
- logger.info('Launching the spot cluster...')
452
- with db_utils.safe_cursor(_DB_PATH) as cursor:
453
- cursor.execute(
454
- """\
455
- UPDATE spot SET status=(?)
456
- WHERE spot_job_id=(?) AND
457
- task_id=(?) AND
458
- status=(?) AND
459
- end_at IS null""", (ManagedJobStatus.STARTING.value, job_id,
460
- task_id, ManagedJobStatus.SUBMITTED.value))
461
- if cursor.rowcount != 1:
462
- raise exceptions.ManagedJobStatusError(
463
- f'Failed to set the task to starting. '
464
- f'({cursor.rowcount} rows updated)')
465
- callback_func('STARTING')
466
-
467
-
468
- def set_started(job_id: int, task_id: int, start_time: float,
469
- callback_func: CallbackType):
470
- """Set the task to started state."""
471
- logger.info('Job started.')
472
- with db_utils.safe_cursor(_DB_PATH) as cursor:
473
- cursor.execute(
474
- """\
475
- UPDATE spot SET status=(?), start_at=(?), last_recovered_at=(?)
476
- WHERE spot_job_id=(?) AND
477
- task_id=(?) AND
478
- status IN (?, ?) AND
479
- end_at IS null""",
480
- (
481
- ManagedJobStatus.RUNNING.value,
482
- start_time,
483
- start_time,
484
- job_id,
485
- task_id,
486
- ManagedJobStatus.STARTING.value,
487
- # If the task is empty, we will jump straight from PENDING to
488
- # RUNNING
489
- ManagedJobStatus.PENDING.value,
490
- ),
792
+ assert _SQLALCHEMY_ENGINE_ASYNC is not None
793
+ async with sql_async.AsyncSession(_SQLALCHEMY_ENGINE_ASYNC) as session:
794
+ result = await session.execute(
795
+ sqlalchemy.update(spot_table).where(
796
+ sqlalchemy.and_(
797
+ spot_table.c.spot_job_id == job_id,
798
+ spot_table.c.task_id == task_id,
799
+ spot_table.c.status.in_([
800
+ ManagedJobStatus.STARTING.value,
801
+ ManagedJobStatus.RECOVERING.value
802
+ ]),
803
+ spot_table.c.end_at.is_(None),
804
+ )).values({spot_table.c.status: ManagedJobStatus.PENDING.value})
491
805
  )
492
- if cursor.rowcount != 1:
493
- raise exceptions.ManagedJobStatusError(
494
- f'Failed to set the task to started. '
495
- f'({cursor.rowcount} rows updated)')
496
- callback_func('STARTED')
497
-
498
-
499
- def set_recovering(job_id: int, task_id: int, callback_func: CallbackType):
500
- """Set the task to recovering state, and update the job duration."""
501
- logger.info('=== Recovering... ===')
502
- with db_utils.safe_cursor(_DB_PATH) as cursor:
503
- cursor.execute(
504
- """\
505
- UPDATE spot SET
506
- status=(?), job_duration=job_duration+(?)-last_recovered_at
507
- WHERE spot_job_id=(?) AND
508
- task_id=(?) AND
509
- status=(?) AND
510
- end_at IS null""",
511
- (ManagedJobStatus.RECOVERING.value, time.time(), job_id, task_id,
512
- ManagedJobStatus.RUNNING.value))
513
- if cursor.rowcount != 1:
514
- raise exceptions.ManagedJobStatusError(
515
- f'Failed to set the task to recovering. '
516
- f'({cursor.rowcount} rows updated)')
517
- callback_func('RECOVERING')
518
-
519
-
520
- def set_recovered(job_id: int, task_id: int, recovered_time: float,
521
- callback_func: CallbackType):
522
- """Set the task to recovered."""
523
- with db_utils.safe_cursor(_DB_PATH) as cursor:
524
- cursor.execute(
525
- """\
526
- UPDATE spot SET
527
- status=(?), last_recovered_at=(?), recovery_count=recovery_count+1
528
- WHERE spot_job_id=(?) AND
529
- task_id=(?) AND
530
- status=(?) AND
531
- end_at IS null""",
532
- (ManagedJobStatus.RUNNING.value, recovered_time, job_id, task_id,
533
- ManagedJobStatus.RECOVERING.value))
534
- if cursor.rowcount != 1:
535
- raise exceptions.ManagedJobStatusError(
536
- f'Failed to set the task to recovered. '
537
- f'({cursor.rowcount} rows updated)')
538
- logger.info('==== Recovered. ====')
539
- callback_func('RECOVERED')
540
-
541
-
542
- def set_succeeded(job_id: int, task_id: int, end_time: float,
543
- callback_func: CallbackType):
544
- """Set the task to succeeded, if it is in a non-terminal state."""
545
- with db_utils.safe_cursor(_DB_PATH) as cursor:
546
- cursor.execute(
547
- """\
548
- UPDATE spot SET
549
- status=(?), end_at=(?)
550
- WHERE spot_job_id=(?) AND
551
- task_id=(?) AND
552
- status=(?) AND
553
- end_at IS null""",
554
- (ManagedJobStatus.SUCCEEDED.value, end_time, job_id, task_id,
555
- ManagedJobStatus.RUNNING.value))
556
- if cursor.rowcount != 1:
557
- raise exceptions.ManagedJobStatusError(
558
- f'Failed to set the task to succeeded. '
559
- f'({cursor.rowcount} rows updated)')
560
- callback_func('SUCCEEDED')
561
- logger.info('Job succeeded.')
562
-
563
-
806
+ count = result.rowcount
807
+ await session.commit()
808
+ if count != 1:
809
+ details = await _describe_task_transition_failure(
810
+ session, job_id, task_id)
811
+ message = ('Failed to set the task back to pending. '
812
+ f'({count} rows updated. {details})')
813
+ logger.error(message)
814
+ raise exceptions.ManagedJobStatusError(message)
815
+ # Do not call callback_func here, as we don't use the callback for PENDING.
816
+
817
+
818
+ @_init_db
819
+ async def set_restarting_async(job_id: int, task_id: int, recovering: bool):
820
+ """Set the task back to STARTING or RECOVERING from PENDING.
821
+
822
+ This should not be used for the initial transition from PENDING to STARTING.
823
+ In that case, use set_starting instead. This function should only be used
824
+ after using set_backoff_pending to transition back to PENDING during
825
+ launch retry backoff.
826
+ """
827
+ assert _SQLALCHEMY_ENGINE_ASYNC is not None
828
+ target_status = ManagedJobStatus.STARTING.value
829
+ if recovering:
830
+ target_status = ManagedJobStatus.RECOVERING.value
831
+ async with sql_async.AsyncSession(_SQLALCHEMY_ENGINE_ASYNC) as session:
832
+ result = await session.execute(
833
+ sqlalchemy.update(spot_table).where(
834
+ sqlalchemy.and_(
835
+ spot_table.c.spot_job_id == job_id,
836
+ spot_table.c.task_id == task_id,
837
+ spot_table.c.end_at.is_(None),
838
+ )).values({spot_table.c.status: target_status}))
839
+ count = result.rowcount
840
+ await session.commit()
841
+ logger.debug(f'back to {target_status}')
842
+ if count != 1:
843
+ details = await _describe_task_transition_failure(
844
+ session, job_id, task_id)
845
+ message = (f'Failed to set the task back to {target_status}. '
846
+ f'({count} rows updated. {details})')
847
+ logger.error(message)
848
+ raise exceptions.ManagedJobStatusError(message)
849
+ # Do not call callback_func here, as it should only be invoked for the
850
+ # initial (pre-`set_backoff_pending`) transition to STARTING or RECOVERING.
851
+
852
+
853
+ @_init_db
564
854
  def set_failed(
565
855
  job_id: int,
566
856
  task_id: Optional[int],
@@ -585,188 +875,165 @@ def set_failed(
585
875
  override_terminal: If True, override the current status even if end_at
586
876
  is already set.
587
877
  """
878
+ assert _SQLALCHEMY_ENGINE is not None
588
879
  assert failure_type.is_failed(), failure_type
589
880
  end_time = time.time() if end_time is None else end_time
590
881
 
591
882
  fields_to_set: Dict[str, Any] = {
592
- 'status': failure_type.value,
593
- 'failure_reason': failure_reason,
883
+ spot_table.c.status: failure_type.value,
884
+ spot_table.c.failure_reason: failure_reason,
594
885
  }
595
- with db_utils.safe_cursor(_DB_PATH) as cursor:
596
- previous_status = cursor.execute(
597
- 'SELECT status FROM spot WHERE spot_job_id=(?)',
598
- (job_id,)).fetchone()[0]
886
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
887
+ # Get previous status
888
+ previous_status = session.execute(
889
+ sqlalchemy.select(spot_table.c.status).where(
890
+ spot_table.c.spot_job_id == job_id)).fetchone()[0]
599
891
  previous_status = ManagedJobStatus(previous_status)
600
892
  if previous_status == ManagedJobStatus.RECOVERING:
601
893
  # If the job is recovering, we should set the last_recovered_at to
602
894
  # the end_time, so that the end_at - last_recovered_at will not be
603
895
  # affect the job duration calculation.
604
- fields_to_set['last_recovered_at'] = end_time
605
- set_str = ', '.join(f'{k}=(?)' for k in fields_to_set)
606
- task_query_str = '' if task_id is None else 'AND task_id=(?)'
607
- task_value = [] if task_id is None else [
608
- task_id,
609
- ]
896
+ fields_to_set[spot_table.c.last_recovered_at] = end_time
897
+ where_conditions = [spot_table.c.spot_job_id == job_id]
898
+ if task_id is not None:
899
+ where_conditions.append(spot_table.c.task_id == task_id)
610
900
 
901
+ # Handle failure_reason prepending when override_terminal is True
611
902
  if override_terminal:
903
+ # Get existing failure_reason with row lock to prevent race
904
+ # conditions
905
+ existing_reason_result = session.execute(
906
+ sqlalchemy.select(spot_table.c.failure_reason).where(
907
+ sqlalchemy.and_(*where_conditions)).with_for_update())
908
+ existing_reason_row = existing_reason_result.fetchone()
909
+ if existing_reason_row and existing_reason_row[0]:
910
+ # Prepend new failure reason to existing one
911
+ fields_to_set[spot_table.c.failure_reason] = (
912
+ failure_reason + '. Previously: ' + existing_reason_row[0])
612
913
  # Use COALESCE for end_at to avoid overriding the existing end_at if
613
914
  # it's already set.
614
- cursor.execute(
615
- f"""\
616
- UPDATE spot SET
617
- end_at = COALESCE(end_at, ?),
618
- {set_str}
619
- WHERE spot_job_id=(?) {task_query_str}""",
620
- (end_time, *list(fields_to_set.values()), job_id, *task_value))
915
+ fields_to_set[spot_table.c.end_at] = sqlalchemy.func.coalesce(
916
+ spot_table.c.end_at, end_time)
621
917
  else:
622
- # Only set if end_at is null, i.e. the previous status is not
623
- # terminal.
624
- cursor.execute(
625
- f"""\
626
- UPDATE spot SET
627
- end_at = (?),
628
- {set_str}
629
- WHERE spot_job_id=(?) {task_query_str} AND end_at IS null""",
630
- (end_time, *list(fields_to_set.values()), job_id, *task_value))
631
-
632
- updated = cursor.rowcount > 0
918
+ fields_to_set[spot_table.c.end_at] = end_time
919
+ where_conditions.append(spot_table.c.end_at.is_(None))
920
+ count = session.query(spot_table).filter(
921
+ sqlalchemy.and_(*where_conditions)).update(fields_to_set)
922
+ session.commit()
923
+ updated = count > 0
633
924
  if callback_func and updated:
634
925
  callback_func('FAILED')
635
926
  logger.info(failure_reason)
636
927
 
637
928
 
638
- def set_cancelling(job_id: int, callback_func: CallbackType):
639
- """Set tasks in the job as cancelling, if they are in non-terminal states.
640
-
641
- task_id is not needed, because we expect the job should be cancelled
642
- as a whole, and we should not cancel a single task.
643
- """
644
- with db_utils.safe_cursor(_DB_PATH) as cursor:
645
- rows = cursor.execute(
646
- """\
647
- UPDATE spot SET
648
- status=(?)
649
- WHERE spot_job_id=(?) AND end_at IS null""",
650
- (ManagedJobStatus.CANCELLING.value, job_id))
651
- updated = rows.rowcount > 0
652
- if updated:
653
- logger.info('Cancelling the job...')
654
- callback_func('CANCELLING')
655
- else:
656
- logger.info('Cancellation skipped, job is already terminal')
657
-
929
+ @_init_db
930
+ def set_pending_cancelled(job_id: int):
931
+ """Set the job as cancelled, if it is PENDING and WAITING/INACTIVE.
658
932
 
659
- def set_cancelled(job_id: int, callback_func: CallbackType):
660
- """Set tasks in the job as cancelled, if they are in CANCELLING state.
933
+ This may fail if the job is not PENDING, e.g. another process has changed
934
+ its state in the meantime.
661
935
 
662
- The set_cancelling should be called before this function.
936
+ Returns:
937
+ True if the job was cancelled, False otherwise.
663
938
  """
664
- with db_utils.safe_cursor(_DB_PATH) as cursor:
665
- rows = cursor.execute(
666
- """\
667
- UPDATE spot SET
668
- status=(?), end_at=(?)
669
- WHERE spot_job_id=(?) AND status=(?)""",
670
- (ManagedJobStatus.CANCELLED.value, time.time(), job_id,
671
- ManagedJobStatus.CANCELLING.value))
672
- updated = rows.rowcount > 0
673
- if updated:
674
- logger.info('Job cancelled.')
675
- callback_func('CANCELLED')
676
- else:
677
- logger.info('Cancellation skipped, job is not CANCELLING')
939
+ assert _SQLALCHEMY_ENGINE is not None
940
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
941
+ # Subquery to get the spot_job_ids that match the joined condition
942
+ subquery = session.query(spot_table.c.job_id).join(
943
+ job_info_table,
944
+ spot_table.c.spot_job_id == job_info_table.c.spot_job_id
945
+ ).filter(
946
+ spot_table.c.spot_job_id == job_id,
947
+ spot_table.c.status == ManagedJobStatus.PENDING.value,
948
+ # Note: it's possible that a WAITING job actually needs to be
949
+ # cleaned up, if we are in the middle of an upgrade/recovery and
950
+ # the job is waiting to be reclaimed by a new controller. But,
951
+ # in this case the status will not be PENDING.
952
+ sqlalchemy.or_(
953
+ job_info_table.c.schedule_state ==
954
+ ManagedJobScheduleState.WAITING.value,
955
+ job_info_table.c.schedule_state ==
956
+ ManagedJobScheduleState.INACTIVE.value,
957
+ ),
958
+ ).subquery()
959
+
960
+ count = session.query(spot_table).filter(
961
+ spot_table.c.job_id.in_(subquery)).update(
962
+ {spot_table.c.status: ManagedJobStatus.CANCELLED.value},
963
+ synchronize_session=False)
964
+ session.commit()
965
+ return count > 0
678
966
 
679
967
 
968
+ @_init_db
680
969
  def set_local_log_file(job_id: int, task_id: Optional[int],
681
970
  local_log_file: str):
682
971
  """Set the local log file for a job."""
683
- filter_str = 'spot_job_id=(?)'
684
- filter_args = [local_log_file, job_id]
685
- if task_id is not None:
686
- filter_str += ' AND task_id=(?)'
687
- filter_args.append(task_id)
688
- with db_utils.safe_cursor(_DB_PATH) as cursor:
689
- cursor.execute(
690
- 'UPDATE spot SET local_log_file=(?) '
691
- f'WHERE {filter_str}', filter_args)
972
+ assert _SQLALCHEMY_ENGINE is not None
973
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
974
+ where_conditions = [spot_table.c.spot_job_id == job_id]
975
+ if task_id is not None:
976
+ where_conditions.append(spot_table.c.task_id == task_id)
977
+ session.query(spot_table).filter(
978
+ sqlalchemy.and_(*where_conditions)).update(
979
+ {spot_table.c.local_log_file: local_log_file})
980
+ session.commit()
692
981
 
693
982
 
694
983
  # ======== utility functions ========
984
+ @_init_db
695
985
  def get_nonterminal_job_ids_by_name(name: Optional[str],
986
+ user_hash: Optional[str] = None,
696
987
  all_users: bool = False) -> List[int]:
697
- """Get non-terminal job ids by name."""
698
- statuses = ', '.join(['?'] * len(ManagedJobStatus.terminal_statuses()))
699
- field_values = [
700
- status.value for status in ManagedJobStatus.terminal_statuses()
701
- ]
702
-
703
- job_filter = ''
704
- if name is None and not all_users:
705
- job_filter += 'AND (job_info.user_hash=(?)) '
706
- field_values.append(common_utils.get_user_hash())
707
- if name is not None:
708
- # We match the job name from `job_info` for the jobs submitted after
709
- # #1982, and from `spot` for the jobs submitted before #1982, whose
710
- # job_info is not available.
711
- job_filter += ('AND (job_info.name=(?) OR '
712
- '(job_info.name IS NULL AND spot.task_name=(?))) ')
713
- field_values.extend([name, name])
714
-
715
- # Left outer join is used here instead of join, because the job_info does
716
- # not contain the managed jobs submitted before #1982.
717
- with db_utils.safe_cursor(_DB_PATH) as cursor:
718
- rows = cursor.execute(
719
- f"""\
720
- SELECT DISTINCT spot.spot_job_id
721
- FROM spot
722
- LEFT OUTER JOIN job_info
723
- ON spot.spot_job_id=job_info.spot_job_id
724
- WHERE status NOT IN
725
- ({statuses})
726
- {job_filter}
727
- ORDER BY spot.spot_job_id DESC""", field_values).fetchall()
728
- job_ids = [row[0] for row in rows if row[0] is not None]
729
- return job_ids
988
+ """Get non-terminal job ids by name.
730
989
 
731
-
732
- def get_schedule_live_jobs(job_id: Optional[int]) -> List[Dict[str, Any]]:
733
- """Get jobs from the database that have a live schedule_state.
734
-
735
- This should return job(s) that are not INACTIVE, WAITING, or DONE. So a
736
- returned job should correspond to a live job controller process, with one
737
- exception: the job may have just transitioned from WAITING to LAUNCHING, but
738
- the controller process has not yet started.
990
+ If name is None:
991
+ 1. if all_users is False, get for the given user_hash
992
+ 2. otherwise, get for all users
739
993
  """
740
- job_filter = '' if job_id is None else 'AND spot_job_id=(?)'
741
- job_value = (job_id,) if job_id is not None else ()
742
-
743
- # Join spot and job_info tables to get the job name for each task.
744
- # We use LEFT OUTER JOIN mainly for backward compatibility, as for an
745
- # existing controller before #1982, the job_info table may not exist,
746
- # and all the managed jobs created before will not present in the
747
- # job_info.
748
- with db_utils.safe_cursor(_DB_PATH) as cursor:
749
- rows = cursor.execute(
750
- f"""\
751
- SELECT spot_job_id, schedule_state, controller_pid
752
- FROM job_info
753
- WHERE schedule_state not in (?, ?, ?)
754
- {job_filter}
755
- ORDER BY spot_job_id DESC""",
756
- (ManagedJobScheduleState.INACTIVE.value,
757
- ManagedJobScheduleState.WAITING.value,
758
- ManagedJobScheduleState.DONE.value, *job_value)).fetchall()
759
- jobs = []
760
- for row in rows:
761
- job_dict = {
762
- 'job_id': row[0],
763
- 'schedule_state': ManagedJobScheduleState(row[1]),
764
- 'controller_pid': row[2],
765
- }
766
- jobs.append(job_dict)
767
- return jobs
994
+ assert _SQLALCHEMY_ENGINE is not None
995
+
996
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
997
+ # Build the query using SQLAlchemy core
998
+ query = sqlalchemy.select(
999
+ spot_table.c.spot_job_id.distinct()).select_from(
1000
+ spot_table.outerjoin(
1001
+ job_info_table,
1002
+ spot_table.c.spot_job_id == job_info_table.c.spot_job_id,
1003
+ ))
1004
+ where_conditions = [
1005
+ ~spot_table.c.status.in_([
1006
+ status.value for status in ManagedJobStatus.terminal_statuses()
1007
+ ])
1008
+ ]
1009
+ if name is None and not all_users:
1010
+ if user_hash is None:
1011
+ # For backwards compatibility. With codegen, USER_ID_ENV_VAR
1012
+ # was set to the correct value by the jobs controller, as
1013
+ # part of ManagedJobCodeGen._build(). This is no longer the
1014
+ # case for the Skylet gRPC server, which is why we need to
1015
+ # pass it explicitly through the request body.
1016
+ logger.debug('user_hash is None, using current user hash')
1017
+ user_hash = common_utils.get_user_hash()
1018
+ where_conditions.append(job_info_table.c.user_hash == user_hash)
1019
+ if name is not None:
1020
+ # We match the job name from `job_info` for the jobs submitted after
1021
+ # #1982, and from `spot` for the jobs submitted before #1982, whose
1022
+ # job_info is not available.
1023
+ where_conditions.append(
1024
+ sqlalchemy.or_(
1025
+ job_info_table.c.name == name,
1026
+ sqlalchemy.and_(job_info_table.c.name.is_(None),
1027
+ spot_table.c.task_name == name),
1028
+ ))
1029
+ query = query.where(sqlalchemy.and_(*where_conditions)).order_by(
1030
+ spot_table.c.spot_job_id.desc())
1031
+ rows = session.execute(query).fetchall()
1032
+ job_ids = [row[0] for row in rows if row[0] is not None]
1033
+ return job_ids
768
1034
 
769
1035
 
1036
+ @_init_db
770
1037
  def get_jobs_to_check_status(job_id: Optional[int] = None) -> List[int]:
771
1038
  """Get jobs that need controller process checking.
772
1039
 
@@ -778,89 +1045,87 @@ def get_jobs_to_check_status(job_id: Optional[int] = None) -> List[int]:
778
1045
  - Jobs have schedule_state DONE but are in a non-terminal status
779
1046
  - Legacy jobs (that is, no schedule state) that are in non-terminal status
780
1047
  """
781
- job_filter = '' if job_id is None else 'AND spot.spot_job_id=(?)'
782
- job_value = () if job_id is None else (job_id,)
783
-
784
- status_filter_str = ', '.join(['?'] *
785
- len(ManagedJobStatus.terminal_statuses()))
786
- terminal_status_values = [
787
- status.value for status in ManagedJobStatus.terminal_statuses()
788
- ]
789
-
790
- # Get jobs that are either:
791
- # 1. Have schedule state that is not DONE, or
792
- # 2. Have schedule state DONE AND are in non-terminal status (unexpected
793
- # inconsistent state), or
794
- # 3. Have no schedule state (legacy) AND are in non-terminal status
795
- with db_utils.safe_cursor(_DB_PATH) as cursor:
796
- rows = cursor.execute(
797
- f"""\
798
- SELECT DISTINCT spot.spot_job_id
799
- FROM spot
800
- LEFT OUTER JOIN job_info
801
- ON spot.spot_job_id=job_info.spot_job_id
802
- WHERE (
803
- -- non-legacy jobs that are not DONE
804
- (job_info.schedule_state IS NOT NULL AND
805
- job_info.schedule_state IS NOT ?)
806
- OR
807
- -- legacy or that are in non-terminal status or
808
- -- DONE jobs that are in non-terminal status
809
- ((-- legacy jobs
810
- job_info.schedule_state IS NULL OR
811
- -- non-legacy DONE jobs
812
- job_info.schedule_state IS ?
813
- ) AND
814
- -- non-terminal
815
- status NOT IN ({status_filter_str}))
816
- )
817
- {job_filter}
818
- ORDER BY spot.spot_job_id DESC""", [
819
- ManagedJobScheduleState.DONE.value,
820
- ManagedJobScheduleState.DONE.value, *terminal_status_values,
821
- *job_value
822
- ]).fetchall()
823
- return [row[0] for row in rows if row[0] is not None]
1048
+ assert _SQLALCHEMY_ENGINE is not None
824
1049
 
1050
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
1051
+ terminal_status_values = [
1052
+ status.value for status in ManagedJobStatus.terminal_statuses()
1053
+ ]
825
1054
 
826
- def get_all_job_ids_by_name(name: Optional[str]) -> List[int]:
827
- """Get all job ids by name."""
828
- name_filter = ''
829
- field_values = []
830
- if name is not None:
831
- # We match the job name from `job_info` for the jobs submitted after
832
- # #1982, and from `spot` for the jobs submitted before #1982, whose
833
- # job_info is not available.
834
- name_filter = ('WHERE (job_info.name=(?) OR '
835
- '(job_info.name IS NULL AND spot.task_name=(?)))')
836
- field_values = [name, name]
837
-
838
- # Left outer join is used here instead of join, because the job_info does
839
- # not contain the managed jobs submitted before #1982.
840
- with db_utils.safe_cursor(_DB_PATH) as cursor:
841
- rows = cursor.execute(
842
- f"""\
843
- SELECT DISTINCT spot.spot_job_id
844
- FROM spot
845
- LEFT OUTER JOIN job_info
846
- ON spot.spot_job_id=job_info.spot_job_id
847
- {name_filter}
848
- ORDER BY spot.spot_job_id DESC""", field_values).fetchall()
849
- job_ids = [row[0] for row in rows if row[0] is not None]
850
- return job_ids
1055
+ query = sqlalchemy.select(
1056
+ spot_table.c.spot_job_id.distinct()).select_from(
1057
+ spot_table.outerjoin(
1058
+ job_info_table,
1059
+ spot_table.c.spot_job_id == job_info_table.c.spot_job_id))
1060
+
1061
+ # Get jobs that are either:
1062
+ # 1. Have schedule state that is not DONE, or
1063
+ # 2. Have schedule state DONE AND are in non-terminal status (unexpected
1064
+ # inconsistent state), or
1065
+ # 3. Have no schedule state (legacy) AND are in non-terminal status
1066
+
1067
+ # non-legacy jobs that are not DONE
1068
+ condition1 = sqlalchemy.and_(
1069
+ job_info_table.c.schedule_state.is_not(None),
1070
+ job_info_table.c.schedule_state !=
1071
+ ManagedJobScheduleState.DONE.value)
1072
+ # legacy or that are in non-terminal status or
1073
+ # DONE jobs that are in non-terminal status
1074
+ condition2 = sqlalchemy.and_(
1075
+ sqlalchemy.or_(
1076
+ # legacy jobs
1077
+ job_info_table.c.schedule_state.is_(None),
1078
+ # non-legacy DONE jobs
1079
+ job_info_table.c.schedule_state ==
1080
+ ManagedJobScheduleState.DONE.value),
1081
+ # non-terminal
1082
+ ~spot_table.c.status.in_(terminal_status_values),
1083
+ )
1084
+ where_condition = sqlalchemy.or_(condition1, condition2)
1085
+ if job_id is not None:
1086
+ where_condition = sqlalchemy.and_(
1087
+ where_condition, spot_table.c.spot_job_id == job_id)
851
1088
 
1089
+ query = query.where(where_condition).order_by(
1090
+ spot_table.c.spot_job_id.desc())
1091
+
1092
+ rows = session.execute(query).fetchall()
1093
+ return [row[0] for row in rows if row[0] is not None]
852
1094
 
1095
+
1096
+ @_init_db
853
1097
  def _get_all_task_ids_statuses(
854
1098
  job_id: int) -> List[Tuple[int, ManagedJobStatus]]:
855
- with db_utils.safe_cursor(_DB_PATH) as cursor:
856
- id_statuses = cursor.execute(
857
- """\
858
- SELECT task_id, status FROM spot
859
- WHERE spot_job_id=(?)
860
- ORDER BY task_id ASC""", (job_id,)).fetchall()
1099
+ assert _SQLALCHEMY_ENGINE is not None
1100
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
1101
+ id_statuses = session.execute(
1102
+ sqlalchemy.select(
1103
+ spot_table.c.task_id,
1104
+ spot_table.c.status,
1105
+ ).where(spot_table.c.spot_job_id == job_id).order_by(
1106
+ spot_table.c.task_id.asc())).fetchall()
861
1107
  return [(row[0], ManagedJobStatus(row[1])) for row in id_statuses]
862
1108
 
863
1109
 
1110
+ @_init_db
1111
+ def get_all_task_ids_names_statuses_logs(
1112
+ job_id: int
1113
+ ) -> List[Tuple[int, str, ManagedJobStatus, str, Optional[float]]]:
1114
+ assert _SQLALCHEMY_ENGINE is not None
1115
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
1116
+ id_names = session.execute(
1117
+ sqlalchemy.select(
1118
+ spot_table.c.task_id,
1119
+ spot_table.c.task_name,
1120
+ spot_table.c.status,
1121
+ spot_table.c.local_log_file,
1122
+ spot_table.c.logs_cleaned_at,
1123
+ ).where(spot_table.c.spot_job_id == job_id).order_by(
1124
+ spot_table.c.task_id.asc())).fetchall()
1125
+ return [(row[0], row[1], ManagedJobStatus(row[2]), row[3], row[4])
1126
+ for row in id_names]
1127
+
1128
+
864
1129
  def get_num_tasks(job_id: int) -> int:
865
1130
  return len(_get_all_task_ids_statuses(job_id))
866
1131
 
@@ -888,31 +1153,88 @@ def get_latest_task_id_status(
888
1153
  return task_id, status
889
1154
 
890
1155
 
1156
+ @_init_db
1157
+ def get_job_controller_process(job_id: int) -> Optional[ControllerPidRecord]:
1158
+ assert _SQLALCHEMY_ENGINE is not None
1159
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
1160
+ row = session.execute(
1161
+ sqlalchemy.select(
1162
+ job_info_table.c.controller_pid,
1163
+ job_info_table.c.controller_pid_started_at).where(
1164
+ job_info_table.c.spot_job_id == job_id)).fetchone()
1165
+ if row is None or row[0] is None:
1166
+ return None
1167
+ pid = row[0]
1168
+ if pid < 0:
1169
+ # Between #7051 and #7847, the controller pid was negative to
1170
+ # indicate a controller process that can handle multiple jobs.
1171
+ pid = -pid
1172
+ return ControllerPidRecord(pid=pid, started_at=row[1])
1173
+
1174
+
1175
+ @_init_db
1176
+ def is_legacy_controller_process(job_id: int) -> bool:
1177
+ """Check if the controller process is a legacy single-job controller process
1178
+
1179
+ After #7051, the controller process pid is negative to indicate a new
1180
+ multi-job controller process.
1181
+ After #7847, the controller process pid is changed back to positive, but
1182
+ controller_pid_started_at will also be set.
1183
+ """
1184
+ # TODO(cooperc): Remove this function for 0.13.0
1185
+ assert _SQLALCHEMY_ENGINE is not None
1186
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
1187
+ row = session.execute(
1188
+ sqlalchemy.select(
1189
+ job_info_table.c.controller_pid,
1190
+ job_info_table.c.controller_pid_started_at).where(
1191
+ job_info_table.c.spot_job_id == job_id)).fetchone()
1192
+ if row is None:
1193
+ raise ValueError(f'Job {job_id} not found')
1194
+ if row[0] is None:
1195
+ # Job is from before #4485, so controller_pid is not set
1196
+ # This is a legacy single-job controller process (running in ray!)
1197
+ return True
1198
+ started_at = row[1]
1199
+ if started_at is not None:
1200
+ # controller_pid_started_at is only set after #7847, so we know this
1201
+ # must be a non-legacy multi-job controller process.
1202
+ return False
1203
+ pid = row[0]
1204
+ if pid < 0:
1205
+ # Between #7051 and #7847, the controller pid was negative to
1206
+ # indicate a non-legacy multi-job controller process.
1207
+ return False
1208
+ return True
1209
+
1210
+
891
1211
  def get_status(job_id: int) -> Optional[ManagedJobStatus]:
892
1212
  _, status = get_latest_task_id_status(job_id)
893
1213
  return status
894
1214
 
895
1215
 
1216
+ @_init_db
896
1217
  def get_failure_reason(job_id: int) -> Optional[str]:
897
1218
  """Get the failure reason of a job.
898
1219
 
899
1220
  If the job has multiple tasks, we return the first failure reason.
900
1221
  """
901
- with db_utils.safe_cursor(_DB_PATH) as cursor:
902
- reason = cursor.execute(
903
- """\
904
- SELECT failure_reason FROM spot
905
- WHERE spot_job_id=(?)
906
- ORDER BY task_id ASC""", (job_id,)).fetchall()
1222
+ assert _SQLALCHEMY_ENGINE is not None
1223
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
1224
+ reason = session.execute(
1225
+ sqlalchemy.select(spot_table.c.failure_reason).where(
1226
+ spot_table.c.spot_job_id == job_id).order_by(
1227
+ spot_table.c.task_id.asc())).fetchall()
907
1228
  reason = [r[0] for r in reason if r[0] is not None]
908
1229
  if not reason:
909
1230
  return None
910
1231
  return reason[0]
911
1232
 
912
1233
 
913
- def get_managed_jobs(job_id: Optional[int] = None) -> List[Dict[str, Any]]:
914
- """Get managed jobs from the database."""
915
- job_filter = '' if job_id is None else f'WHERE spot.spot_job_id={job_id}'
1234
+ @_init_db
1235
+ def get_managed_job_tasks(job_id: int) -> List[Dict[str, Any]]:
1236
+ """Get managed job tasks for a specific managed job id from the database."""
1237
+ assert _SQLALCHEMY_ENGINE is not None
916
1238
 
917
1239
  # Join spot and job_info tables to get the job name for each task.
918
1240
  # We use LEFT OUTER JOIN mainly for backward compatibility, as for an
@@ -922,197 +1244,1278 @@ def get_managed_jobs(job_id: Optional[int] = None) -> List[Dict[str, Any]]:
922
1244
  # Note: we will get the user_hash here, but don't try to call
923
1245
  # global_user_state.get_user() on it. This runs on the controller, which may
924
1246
  # not have the user info. Prefer to do it on the API server side.
925
- with db_utils.safe_cursor(_DB_PATH) as cursor:
926
- rows = cursor.execute(f"""\
927
- SELECT *
928
- FROM spot
929
- LEFT OUTER JOIN job_info
930
- ON spot.spot_job_id=job_info.spot_job_id
931
- {job_filter}
932
- ORDER BY spot.spot_job_id DESC, spot.task_id ASC""").fetchall()
933
- jobs = []
934
- for row in rows:
935
- job_dict = dict(zip(columns, row))
1247
+ query = sqlalchemy.select(spot_table, job_info_table).select_from(
1248
+ spot_table.outerjoin(
1249
+ job_info_table,
1250
+ spot_table.c.spot_job_id == job_info_table.c.spot_job_id))
1251
+ query = query.where(spot_table.c.spot_job_id == job_id)
1252
+ query = query.order_by(spot_table.c.task_id.asc())
1253
+ rows = None
1254
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
1255
+ rows = session.execute(query).fetchall()
1256
+ jobs = []
1257
+ for row in rows:
1258
+ job_dict = _get_jobs_dict(row._mapping) # pylint: disable=protected-access
1259
+ job_dict['status'] = ManagedJobStatus(job_dict['status'])
1260
+ job_dict['schedule_state'] = ManagedJobScheduleState(
1261
+ job_dict['schedule_state'])
1262
+ if job_dict['job_name'] is None:
1263
+ job_dict['job_name'] = job_dict['task_name']
1264
+ job_dict['metadata'] = json.loads(job_dict['metadata'])
1265
+
1266
+ # Add user YAML content for managed jobs.
1267
+ job_dict['user_yaml'] = job_dict.get('original_user_yaml_content')
1268
+ if job_dict['user_yaml'] is None:
1269
+ # Backwards compatibility - try to read from file path
1270
+ yaml_path = job_dict.get('original_user_yaml_path')
1271
+ if yaml_path:
1272
+ try:
1273
+ with open(yaml_path, 'r', encoding='utf-8') as f:
1274
+ job_dict['user_yaml'] = f.read()
1275
+ except (FileNotFoundError, IOError, OSError) as e:
1276
+ logger.debug('Failed to read original user YAML for job '
1277
+ f'{job_id} from {yaml_path}: {e}')
1278
+
1279
+ jobs.append(job_dict)
1280
+ return jobs
1281
+
1282
+
1283
+ def _map_response_field_to_db_column(field: str):
1284
+ """Map the response field name to an actual SQLAlchemy ColumnElement.
1285
+
1286
+ This ensures we never pass plain strings to SQLAlchemy 2.0 APIs like
1287
+ Select.with_only_columns().
1288
+ """
1289
+ # Explicit aliases differing from actual DB column names
1290
+ alias_mapping = {
1291
+ '_job_id': spot_table.c.job_id, # spot.job_id
1292
+ '_task_name': spot_table.c.job_name, # deprecated, from spot table
1293
+ 'job_id': spot_table.c.spot_job_id, # public job id -> spot.spot_job_id
1294
+ '_job_info_job_id': job_info_table.c.spot_job_id,
1295
+ 'job_name': job_info_table.c.name, # public job name -> job_info.name
1296
+ }
1297
+ if field in alias_mapping:
1298
+ return alias_mapping[field]
1299
+
1300
+ # Try direct match on the `spot` table columns
1301
+ if field in spot_table.c:
1302
+ return spot_table.c[field]
1303
+
1304
+ # Try direct match on the `job_info` table columns
1305
+ if field in job_info_table.c:
1306
+ return job_info_table.c[field]
1307
+
1308
+ raise ValueError(f'Unknown field: {field}')
1309
+
1310
+
1311
+ @_init_db
1312
+ def get_managed_jobs_total() -> int:
1313
+ """Get the total number of managed jobs."""
1314
+ assert _SQLALCHEMY_ENGINE is not None
1315
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
1316
+ result = session.execute(
1317
+ sqlalchemy.select(sqlalchemy.func.count() # pylint: disable=not-callable
1318
+ ).select_from(spot_table)).fetchone()
1319
+ return result[0] if result else 0
1320
+
1321
+
1322
+ @_init_db
1323
+ def get_managed_jobs_highest_priority() -> int:
1324
+ """Get the highest priority of the managed jobs."""
1325
+ assert _SQLALCHEMY_ENGINE is not None
1326
+ query = sqlalchemy.select(sqlalchemy.func.max(
1327
+ job_info_table.c.priority)).where(
1328
+ sqlalchemy.and_(
1329
+ job_info_table.c.schedule_state.in_([
1330
+ ManagedJobScheduleState.LAUNCHING.value,
1331
+ ManagedJobScheduleState.ALIVE_BACKOFF.value,
1332
+ ManagedJobScheduleState.WAITING.value,
1333
+ ManagedJobScheduleState.ALIVE_WAITING.value,
1334
+ ]),
1335
+ job_info_table.c.priority.is_not(None),
1336
+ ))
1337
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
1338
+ priority = session.execute(query).fetchone()
1339
+ return priority[0] if priority and priority[
1340
+ 0] is not None else constants.MIN_PRIORITY
1341
+
1342
+
1343
+ def build_managed_jobs_with_filters_no_status_query(
1344
+ fields: Optional[List[str]] = None,
1345
+ job_ids: Optional[List[int]] = None,
1346
+ accessible_workspaces: Optional[List[str]] = None,
1347
+ workspace_match: Optional[str] = None,
1348
+ name_match: Optional[str] = None,
1349
+ pool_match: Optional[str] = None,
1350
+ user_hashes: Optional[List[Optional[str]]] = None,
1351
+ skip_finished: bool = False,
1352
+ count_only: bool = False,
1353
+ status_count: bool = False,
1354
+ ) -> sqlalchemy.Select:
1355
+ """Build a query to get managed jobs from the database with filters."""
1356
+ # Join spot and job_info tables to get the job name for each task.
1357
+ # We use LEFT OUTER JOIN mainly for backward compatibility, as for an
1358
+ # existing controller before #1982, the job_info table may not exist,
1359
+ # and all the managed jobs created before will not present in the
1360
+ # job_info.
1361
+ # Note: we will get the user_hash here, but don't try to call
1362
+ # global_user_state.get_user() on it. This runs on the controller, which may
1363
+ # not have the user info. Prefer to do it on the API server side.
1364
+ if count_only:
1365
+ query = sqlalchemy.select(sqlalchemy.func.count().label('count')) # pylint: disable=not-callable
1366
+ elif status_count:
1367
+ query = sqlalchemy.select(spot_table.c.status,
1368
+ sqlalchemy.func.count().label('count')) # pylint: disable=not-callable
1369
+ else:
1370
+ query = sqlalchemy.select(spot_table, job_info_table)
1371
+ query = query.select_from(
1372
+ spot_table.outerjoin(
1373
+ job_info_table,
1374
+ spot_table.c.spot_job_id == job_info_table.c.spot_job_id))
1375
+ if skip_finished:
1376
+ # Filter out finished jobs at the DB level. If a multi-task job is
1377
+ # partially finished, include all its tasks. We do this by first
1378
+ # selecting job_ids that have at least one non-terminal task, then
1379
+ # restricting the main query to those job_ids.
1380
+ terminal_status_values = [
1381
+ s.value for s in ManagedJobStatus.terminal_statuses()
1382
+ ]
1383
+ non_terminal_job_ids_subquery = (sqlalchemy.select(
1384
+ spot_table.c.spot_job_id).where(
1385
+ sqlalchemy.or_(
1386
+ spot_table.c.status.is_(None),
1387
+ sqlalchemy.not_(
1388
+ spot_table.c.status.in_(terminal_status_values)),
1389
+ )).distinct())
1390
+ query = query.where(
1391
+ spot_table.c.spot_job_id.in_(non_terminal_job_ids_subquery))
1392
+ if not count_only and not status_count and fields:
1393
+ # Resolve requested field names to explicit ColumnElements from
1394
+ # the joined tables.
1395
+ selected_columns = [_map_response_field_to_db_column(f) for f in fields]
1396
+ query = query.with_only_columns(*selected_columns)
1397
+ if job_ids is not None:
1398
+ query = query.where(spot_table.c.spot_job_id.in_(job_ids))
1399
+ if accessible_workspaces is not None:
1400
+ query = query.where(
1401
+ job_info_table.c.workspace.in_(accessible_workspaces))
1402
+ if workspace_match is not None:
1403
+ query = query.where(
1404
+ job_info_table.c.workspace.like(f'%{workspace_match}%'))
1405
+ if name_match is not None:
1406
+ query = query.where(job_info_table.c.name.like(f'%{name_match}%'))
1407
+ if pool_match is not None:
1408
+ query = query.where(job_info_table.c.pool.like(f'%{pool_match}%'))
1409
+ if user_hashes is not None:
1410
+ query = query.where(job_info_table.c.user_hash.in_(user_hashes))
1411
+ return query
1412
+
1413
+
1414
+ def build_managed_jobs_with_filters_query(
1415
+ fields: Optional[List[str]] = None,
1416
+ job_ids: Optional[List[int]] = None,
1417
+ accessible_workspaces: Optional[List[str]] = None,
1418
+ workspace_match: Optional[str] = None,
1419
+ name_match: Optional[str] = None,
1420
+ pool_match: Optional[str] = None,
1421
+ user_hashes: Optional[List[Optional[str]]] = None,
1422
+ statuses: Optional[List[str]] = None,
1423
+ skip_finished: bool = False,
1424
+ count_only: bool = False,
1425
+ ) -> sqlalchemy.Select:
1426
+ """Build a query to get managed jobs from the database with filters."""
1427
+ query = build_managed_jobs_with_filters_no_status_query(
1428
+ fields=fields,
1429
+ job_ids=job_ids,
1430
+ accessible_workspaces=accessible_workspaces,
1431
+ workspace_match=workspace_match,
1432
+ name_match=name_match,
1433
+ pool_match=pool_match,
1434
+ user_hashes=user_hashes,
1435
+ skip_finished=skip_finished,
1436
+ count_only=count_only,
1437
+ )
1438
+ if statuses is not None:
1439
+ query = query.where(spot_table.c.status.in_(statuses))
1440
+ return query
1441
+
1442
+
1443
+ @_init_db
1444
+ def get_status_count_with_filters(
1445
+ fields: Optional[List[str]] = None,
1446
+ job_ids: Optional[List[int]] = None,
1447
+ accessible_workspaces: Optional[List[str]] = None,
1448
+ workspace_match: Optional[str] = None,
1449
+ name_match: Optional[str] = None,
1450
+ pool_match: Optional[str] = None,
1451
+ user_hashes: Optional[List[Optional[str]]] = None,
1452
+ skip_finished: bool = False,
1453
+ ) -> Dict[str, int]:
1454
+ """Get the status count of the managed jobs with filters."""
1455
+ query = build_managed_jobs_with_filters_no_status_query(
1456
+ fields=fields,
1457
+ job_ids=job_ids,
1458
+ accessible_workspaces=accessible_workspaces,
1459
+ workspace_match=workspace_match,
1460
+ name_match=name_match,
1461
+ pool_match=pool_match,
1462
+ user_hashes=user_hashes,
1463
+ skip_finished=skip_finished,
1464
+ status_count=True,
1465
+ )
1466
+ query = query.group_by(spot_table.c.status)
1467
+ results: Dict[str, int] = {}
1468
+ assert _SQLALCHEMY_ENGINE is not None
1469
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
1470
+ rows = session.execute(query).fetchall()
1471
+ for status_value, count in rows:
1472
+ # status_value is already a string (enum value)
1473
+ results[str(status_value)] = int(count)
1474
+ return results
1475
+
1476
+
1477
+ @_init_db
1478
+ def get_managed_jobs_with_filters(
1479
+ fields: Optional[List[str]] = None,
1480
+ job_ids: Optional[List[int]] = None,
1481
+ accessible_workspaces: Optional[List[str]] = None,
1482
+ workspace_match: Optional[str] = None,
1483
+ name_match: Optional[str] = None,
1484
+ pool_match: Optional[str] = None,
1485
+ user_hashes: Optional[List[Optional[str]]] = None,
1486
+ statuses: Optional[List[str]] = None,
1487
+ skip_finished: bool = False,
1488
+ page: Optional[int] = None,
1489
+ limit: Optional[int] = None,
1490
+ ) -> Tuple[List[Dict[str, Any]], int]:
1491
+ """Get managed jobs from the database with filters.
1492
+
1493
+ Returns:
1494
+ A tuple containing
1495
+ - the list of managed jobs
1496
+ - the total number of managed jobs
1497
+ """
1498
+ assert _SQLALCHEMY_ENGINE is not None
1499
+
1500
+ count_query = build_managed_jobs_with_filters_query(
1501
+ fields=None,
1502
+ job_ids=job_ids,
1503
+ accessible_workspaces=accessible_workspaces,
1504
+ workspace_match=workspace_match,
1505
+ name_match=name_match,
1506
+ pool_match=pool_match,
1507
+ user_hashes=user_hashes,
1508
+ statuses=statuses,
1509
+ skip_finished=skip_finished,
1510
+ count_only=True,
1511
+ )
1512
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
1513
+ total = session.execute(count_query).fetchone()[0]
1514
+
1515
+ query = build_managed_jobs_with_filters_query(
1516
+ fields=fields,
1517
+ job_ids=job_ids,
1518
+ accessible_workspaces=accessible_workspaces,
1519
+ workspace_match=workspace_match,
1520
+ name_match=name_match,
1521
+ pool_match=pool_match,
1522
+ user_hashes=user_hashes,
1523
+ statuses=statuses,
1524
+ skip_finished=skip_finished,
1525
+ )
1526
+ query = query.order_by(spot_table.c.spot_job_id.desc(),
1527
+ spot_table.c.task_id.asc())
1528
+ if page is not None and limit is not None:
1529
+ query = query.offset((page - 1) * limit).limit(limit)
1530
+ rows = None
1531
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
1532
+ rows = session.execute(query).fetchall()
1533
+ jobs = []
1534
+ for row in rows:
1535
+ job_dict = _get_jobs_dict(row._mapping) # pylint: disable=protected-access
1536
+ if job_dict.get('status') is not None:
936
1537
  job_dict['status'] = ManagedJobStatus(job_dict['status'])
1538
+ if job_dict.get('schedule_state') is not None:
937
1539
  job_dict['schedule_state'] = ManagedJobScheduleState(
938
1540
  job_dict['schedule_state'])
939
- if job_dict['job_name'] is None:
940
- job_dict['job_name'] = job_dict['task_name']
941
- jobs.append(job_dict)
942
- return jobs
943
-
944
-
1541
+ if job_dict.get('job_name') is None:
1542
+ job_dict['job_name'] = job_dict.get('task_name')
1543
+ if job_dict.get('metadata') is not None:
1544
+ job_dict['metadata'] = json.loads(job_dict['metadata'])
1545
+
1546
+ # Add user YAML content for managed jobs.
1547
+ job_dict['user_yaml'] = job_dict.get('original_user_yaml_content')
1548
+ if job_dict['user_yaml'] is None:
1549
+ # Backwards compatibility - try to read from file path
1550
+ yaml_path = job_dict.get('original_user_yaml_path')
1551
+ if yaml_path:
1552
+ try:
1553
+ with open(yaml_path, 'r', encoding='utf-8') as f:
1554
+ job_dict['user_yaml'] = f.read()
1555
+ except (FileNotFoundError, IOError, OSError) as e:
1556
+ job_id = job_dict.get('job_id')
1557
+ if job_id is not None:
1558
+ logger.debug('Failed to read original user YAML for '
1559
+ f'job {job_id} from {yaml_path}: {e}')
1560
+ else:
1561
+ logger.debug('Failed to read original user YAML from '
1562
+ f'{yaml_path}: {e}')
1563
+
1564
+ jobs.append(job_dict)
1565
+ return jobs, total
1566
+
1567
+
1568
+ @_init_db
945
1569
  def get_task_name(job_id: int, task_id: int) -> str:
946
1570
  """Get the task name of a job."""
947
- with db_utils.safe_cursor(_DB_PATH) as cursor:
948
- task_name = cursor.execute(
949
- """\
950
- SELECT task_name FROM spot
951
- WHERE spot_job_id=(?)
952
- AND task_id=(?)""", (job_id, task_id)).fetchone()
1571
+ assert _SQLALCHEMY_ENGINE is not None
1572
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
1573
+ task_name = session.execute(
1574
+ sqlalchemy.select(spot_table.c.task_name).where(
1575
+ sqlalchemy.and_(
1576
+ spot_table.c.spot_job_id == job_id,
1577
+ spot_table.c.task_id == task_id,
1578
+ ))).fetchone()
953
1579
  return task_name[0]
954
1580
 
955
1581
 
1582
+ @_init_db
956
1583
  def get_latest_job_id() -> Optional[int]:
957
1584
  """Get the latest job id."""
958
- with db_utils.safe_cursor(_DB_PATH) as cursor:
959
- rows = cursor.execute("""\
960
- SELECT spot_job_id FROM spot
961
- WHERE task_id=0
962
- ORDER BY submitted_at DESC LIMIT 1""")
963
- for (job_id,) in rows:
964
- return job_id
965
- return None
1585
+ assert _SQLALCHEMY_ENGINE is not None
1586
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
1587
+ job_id = session.execute(
1588
+ sqlalchemy.select(spot_table.c.spot_job_id).where(
1589
+ spot_table.c.task_id == 0).order_by(
1590
+ spot_table.c.submitted_at.desc()).limit(1)).fetchone()
1591
+ return job_id[0] if job_id else None
966
1592
 
967
1593
 
1594
+ @_init_db
968
1595
  def get_task_specs(job_id: int, task_id: int) -> Dict[str, Any]:
969
- with db_utils.safe_cursor(_DB_PATH) as cursor:
970
- task_specs = cursor.execute(
971
- """\
972
- SELECT specs FROM spot
973
- WHERE spot_job_id=(?) AND task_id=(?)""",
974
- (job_id, task_id)).fetchone()
1596
+ assert _SQLALCHEMY_ENGINE is not None
1597
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
1598
+ task_specs = session.execute(
1599
+ sqlalchemy.select(spot_table.c.specs).where(
1600
+ sqlalchemy.and_(
1601
+ spot_table.c.spot_job_id == job_id,
1602
+ spot_table.c.task_id == task_id,
1603
+ ))).fetchone()
975
1604
  return json.loads(task_specs[0])
976
1605
 
977
1606
 
978
- def get_local_log_file(job_id: int, task_id: Optional[int]) -> Optional[str]:
979
- """Get the local log directory for a job."""
980
- filter_str = 'spot_job_id=(?)'
981
- filter_args = [job_id]
982
- if task_id is not None:
983
- filter_str += ' AND task_id=(?)'
984
- filter_args.append(task_id)
985
- with db_utils.safe_cursor(_DB_PATH) as cursor:
986
- local_log_file = cursor.execute(
987
- f'SELECT local_log_file FROM spot '
988
- f'WHERE {filter_str}', filter_args).fetchone()
989
- return local_log_file[-1] if local_log_file else None
990
-
991
-
992
- # === Scheduler state functions ===
993
- # Only the scheduler should call these functions. They may require holding the
994
- # scheduler lock to work correctly.
995
-
996
-
997
- def scheduler_set_waiting(job_id: int, dag_yaml_path: str, env_file_path: str,
998
- user_hash: str) -> None:
999
- """Do not call without holding the scheduler lock."""
1000
- with db_utils.safe_cursor(_DB_PATH) as cursor:
1001
- updated_count = cursor.execute(
1002
- 'UPDATE job_info SET '
1003
- 'schedule_state = (?), dag_yaml_path = (?), env_file_path = (?), '
1004
- ' user_hash = (?) '
1005
- 'WHERE spot_job_id = (?) AND schedule_state = (?)',
1006
- (ManagedJobScheduleState.WAITING.value, dag_yaml_path,
1007
- env_file_path, user_hash, job_id,
1008
- ManagedJobScheduleState.INACTIVE.value)).rowcount
1009
- assert updated_count == 1, (job_id, updated_count)
1010
-
1011
-
1012
- def scheduler_set_launching(job_id: int,
1013
- current_state: ManagedJobScheduleState) -> None:
1014
- """Do not call without holding the scheduler lock."""
1015
- with db_utils.safe_cursor(_DB_PATH) as cursor:
1016
- updated_count = cursor.execute(
1017
- 'UPDATE job_info SET '
1018
- 'schedule_state = (?) '
1019
- 'WHERE spot_job_id = (?) AND schedule_state = (?)',
1020
- (ManagedJobScheduleState.LAUNCHING.value, job_id,
1021
- current_state.value)).rowcount
1022
- assert updated_count == 1, (job_id, updated_count)
1023
-
1024
-
1025
- def scheduler_set_alive(job_id: int) -> None:
1026
- """Do not call without holding the scheduler lock."""
1027
- with db_utils.safe_cursor(_DB_PATH) as cursor:
1028
- updated_count = cursor.execute(
1029
- 'UPDATE job_info SET '
1030
- 'schedule_state = (?) '
1031
- 'WHERE spot_job_id = (?) AND schedule_state = (?)',
1032
- (ManagedJobScheduleState.ALIVE.value, job_id,
1033
- ManagedJobScheduleState.LAUNCHING.value)).rowcount
1034
- assert updated_count == 1, (job_id, updated_count)
1607
+ @_init_db
1608
+ def scheduler_set_waiting(job_id: int, dag_yaml_content: str,
1609
+ original_user_yaml_content: str,
1610
+ env_file_content: str,
1611
+ config_file_content: Optional[str],
1612
+ priority: int) -> None:
1613
+ assert _SQLALCHEMY_ENGINE is not None
1614
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
1615
+ updated_count = session.query(job_info_table).filter(
1616
+ sqlalchemy.and_(job_info_table.c.spot_job_id == job_id,)).update({
1617
+ job_info_table.c.schedule_state:
1618
+ ManagedJobScheduleState.WAITING.value,
1619
+ job_info_table.c.dag_yaml_content: dag_yaml_content,
1620
+ job_info_table.c.original_user_yaml_content:
1621
+ (original_user_yaml_content),
1622
+ job_info_table.c.env_file_content: env_file_content,
1623
+ job_info_table.c.config_file_content: config_file_content,
1624
+ job_info_table.c.priority: priority,
1625
+ })
1626
+ session.commit()
1627
+ assert updated_count <= 1, (job_id, updated_count)
1628
+
1629
+
1630
+ @_init_db
1631
+ def get_job_file_contents(job_id: int) -> Dict[str, Optional[str]]:
1632
+ """Return file information and stored contents for a managed job."""
1633
+ assert _SQLALCHEMY_ENGINE is not None
1634
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
1635
+ row = session.execute(
1636
+ sqlalchemy.select(
1637
+ job_info_table.c.dag_yaml_path,
1638
+ job_info_table.c.env_file_path,
1639
+ job_info_table.c.dag_yaml_content,
1640
+ job_info_table.c.env_file_content,
1641
+ job_info_table.c.config_file_content,
1642
+ ).where(job_info_table.c.spot_job_id == job_id)).fetchone()
1643
+
1644
+ if row is None:
1645
+ return {
1646
+ 'dag_yaml_path': None,
1647
+ 'env_file_path': None,
1648
+ 'dag_yaml_content': None,
1649
+ 'env_file_content': None,
1650
+ 'config_file_content': None,
1651
+ }
1652
+
1653
+ return {
1654
+ 'dag_yaml_path': row[0],
1655
+ 'env_file_path': row[1],
1656
+ 'dag_yaml_content': row[2],
1657
+ 'env_file_content': row[3],
1658
+ 'config_file_content': row[4],
1659
+ }
1035
1660
 
1036
1661
 
1037
- def scheduler_set_alive_waiting(job_id: int) -> None:
1662
+ @_init_db
1663
+ def get_pool_from_job_id(job_id: int) -> Optional[str]:
1664
+ """Get the pool from the job id."""
1665
+ assert _SQLALCHEMY_ENGINE is not None
1666
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
1667
+ pool = session.execute(
1668
+ sqlalchemy.select(job_info_table.c.pool).where(
1669
+ job_info_table.c.spot_job_id == job_id)).fetchone()
1670
+ return pool[0] if pool else None
1671
+
1672
+
1673
+ @_init_db
1674
+ def set_current_cluster_name(job_id: int, current_cluster_name: str) -> None:
1675
+ """Set the current cluster name for a job."""
1676
+ assert _SQLALCHEMY_ENGINE is not None
1677
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
1678
+ session.query(job_info_table).filter(
1679
+ job_info_table.c.spot_job_id == job_id).update(
1680
+ {job_info_table.c.current_cluster_name: current_cluster_name})
1681
+ session.commit()
1682
+
1683
+
1684
+ @_init_db_async
1685
+ async def set_job_id_on_pool_cluster_async(job_id: int,
1686
+ job_id_on_pool_cluster: int) -> None:
1687
+ """Set the job id on the pool cluster for a job."""
1688
+ assert _SQLALCHEMY_ENGINE_ASYNC is not None
1689
+ async with sql_async.AsyncSession(_SQLALCHEMY_ENGINE_ASYNC) as session:
1690
+ await session.execute(
1691
+ sqlalchemy.update(job_info_table).
1692
+ where(job_info_table.c.spot_job_id == job_id).values({
1693
+ job_info_table.c.job_id_on_pool_cluster: job_id_on_pool_cluster
1694
+ }))
1695
+ await session.commit()
1696
+
1697
+
1698
+ @_init_db
1699
+ def get_pool_submit_info(job_id: int) -> Tuple[Optional[str], Optional[int]]:
1700
+ """Get the cluster name and job id on the pool from the managed job id."""
1701
+ assert _SQLALCHEMY_ENGINE is not None
1702
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
1703
+ info = session.execute(
1704
+ sqlalchemy.select(
1705
+ job_info_table.c.current_cluster_name,
1706
+ job_info_table.c.job_id_on_pool_cluster).where(
1707
+ job_info_table.c.spot_job_id == job_id)).fetchone()
1708
+ if info is None:
1709
+ return None, None
1710
+ return info[0], info[1]
1711
+
1712
+
1713
+ @_init_db_async
1714
+ async def get_pool_submit_info_async(
1715
+ job_id: int) -> Tuple[Optional[str], Optional[int]]:
1716
+ """Get the cluster name and job id on the pool from the managed job id."""
1717
+ assert _SQLALCHEMY_ENGINE_ASYNC is not None
1718
+ async with sql_async.AsyncSession(_SQLALCHEMY_ENGINE_ASYNC) as session:
1719
+ result = await session.execute(
1720
+ sqlalchemy.select(job_info_table.c.current_cluster_name,
1721
+ job_info_table.c.job_id_on_pool_cluster).where(
1722
+ job_info_table.c.spot_job_id == job_id))
1723
+ info = result.fetchone()
1724
+ if info is None:
1725
+ return None, None
1726
+ return info[0], info[1]
1727
+
1728
+
1729
+ @_init_db_async
1730
+ async def scheduler_set_launching_async(job_id: int):
1731
+ assert _SQLALCHEMY_ENGINE_ASYNC is not None
1732
+ async with sql_async.AsyncSession(_SQLALCHEMY_ENGINE_ASYNC) as session:
1733
+ await session.execute(
1734
+ sqlalchemy.update(job_info_table).where(
1735
+ sqlalchemy.and_(job_info_table.c.spot_job_id == job_id)).values(
1736
+ {
1737
+ job_info_table.c.schedule_state:
1738
+ ManagedJobScheduleState.LAUNCHING.value
1739
+ }))
1740
+ await session.commit()
1741
+
1742
+
1743
+ @_init_db_async
1744
+ async def scheduler_set_alive_async(job_id: int) -> None:
1038
1745
  """Do not call without holding the scheduler lock."""
1039
- with db_utils.safe_cursor(_DB_PATH) as cursor:
1040
- updated_count = cursor.execute(
1041
- 'UPDATE job_info SET '
1042
- 'schedule_state = (?) '
1043
- 'WHERE spot_job_id = (?) AND schedule_state = (?)',
1044
- (ManagedJobScheduleState.ALIVE_WAITING.value, job_id,
1045
- ManagedJobScheduleState.ALIVE.value)).rowcount
1046
- assert updated_count == 1, (job_id, updated_count)
1047
-
1048
-
1746
+ assert _SQLALCHEMY_ENGINE_ASYNC is not None
1747
+ async with sql_async.AsyncSession(_SQLALCHEMY_ENGINE_ASYNC) as session:
1748
+ result = await session.execute(
1749
+ sqlalchemy.update(job_info_table).where(
1750
+ sqlalchemy.and_(
1751
+ job_info_table.c.spot_job_id == job_id,
1752
+ job_info_table.c.schedule_state ==
1753
+ ManagedJobScheduleState.LAUNCHING.value,
1754
+ )).values({
1755
+ job_info_table.c.schedule_state:
1756
+ ManagedJobScheduleState.ALIVE.value
1757
+ }))
1758
+ changes = result.rowcount
1759
+ await session.commit()
1760
+ assert changes == 1, (job_id, changes)
1761
+
1762
+
1763
+ @_init_db
1049
1764
  def scheduler_set_done(job_id: int, idempotent: bool = False) -> None:
1050
1765
  """Do not call without holding the scheduler lock."""
1051
- with db_utils.safe_cursor(_DB_PATH) as cursor:
1052
- updated_count = cursor.execute(
1053
- 'UPDATE job_info SET '
1054
- 'schedule_state = (?) '
1055
- 'WHERE spot_job_id = (?) AND schedule_state != (?)',
1056
- (ManagedJobScheduleState.DONE.value, job_id,
1057
- ManagedJobScheduleState.DONE.value)).rowcount
1766
+ assert _SQLALCHEMY_ENGINE is not None
1767
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
1768
+ updated_count = session.query(job_info_table).filter(
1769
+ sqlalchemy.and_(
1770
+ job_info_table.c.spot_job_id == job_id,
1771
+ job_info_table.c.schedule_state !=
1772
+ ManagedJobScheduleState.DONE.value,
1773
+ )).update({
1774
+ job_info_table.c.schedule_state:
1775
+ ManagedJobScheduleState.DONE.value
1776
+ })
1777
+ session.commit()
1058
1778
  if not idempotent:
1059
1779
  assert updated_count == 1, (job_id, updated_count)
1060
1780
 
1061
1781
 
1062
- def set_job_controller_pid(job_id: int, pid: int):
1063
- with db_utils.safe_cursor(_DB_PATH) as cursor:
1064
- updated_count = cursor.execute(
1065
- 'UPDATE job_info SET '
1066
- 'controller_pid = (?) '
1067
- 'WHERE spot_job_id = (?)', (pid, job_id)).rowcount
1068
- assert updated_count == 1, (job_id, updated_count)
1069
-
1070
-
1782
+ @_init_db
1071
1783
  def get_job_schedule_state(job_id: int) -> ManagedJobScheduleState:
1072
- with db_utils.safe_cursor(_DB_PATH) as cursor:
1073
- state = cursor.execute(
1074
- 'SELECT schedule_state FROM job_info WHERE spot_job_id = (?)',
1075
- (job_id,)).fetchone()[0]
1784
+ assert _SQLALCHEMY_ENGINE is not None
1785
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
1786
+ state = session.execute(
1787
+ sqlalchemy.select(job_info_table.c.schedule_state).where(
1788
+ job_info_table.c.spot_job_id == job_id)).fetchone()[0]
1076
1789
  return ManagedJobScheduleState(state)
1077
1790
 
1078
1791
 
1792
+ @_init_db
1079
1793
  def get_num_launching_jobs() -> int:
1080
- with db_utils.safe_cursor(_DB_PATH) as cursor:
1081
- return cursor.execute(
1082
- 'SELECT COUNT(*) '
1083
- 'FROM job_info '
1084
- 'WHERE schedule_state = (?)',
1085
- (ManagedJobScheduleState.LAUNCHING.value,)).fetchone()[0]
1086
-
1087
-
1088
- def get_num_alive_jobs() -> int:
1089
- with db_utils.safe_cursor(_DB_PATH) as cursor:
1090
- return cursor.execute(
1091
- 'SELECT COUNT(*) '
1092
- 'FROM job_info '
1093
- 'WHERE schedule_state IN (?, ?, ?)',
1094
- (ManagedJobScheduleState.ALIVE_WAITING.value,
1095
- ManagedJobScheduleState.LAUNCHING.value,
1096
- ManagedJobScheduleState.ALIVE.value)).fetchone()[0]
1097
-
1098
-
1099
- def get_waiting_job() -> Optional[Dict[str, Any]]:
1794
+ assert _SQLALCHEMY_ENGINE is not None
1795
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
1796
+ return session.execute(
1797
+ sqlalchemy.select(
1798
+ sqlalchemy.func.count() # pylint: disable=not-callable
1799
+ ).select_from(job_info_table).where(
1800
+ sqlalchemy.and_(
1801
+ job_info_table.c.schedule_state ==
1802
+ ManagedJobScheduleState.LAUNCHING.value,
1803
+ # We only count jobs that are not in the pool, because the
1804
+ # job in the pool does not actually calling the sky.launch.
1805
+ job_info_table.c.pool.is_(None)))).fetchone()[0]
1806
+
1807
+
1808
+ @_init_db
1809
+ def get_num_alive_jobs(pool: Optional[str] = None) -> int:
1810
+ assert _SQLALCHEMY_ENGINE is not None
1811
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
1812
+ where_conditions = [
1813
+ job_info_table.c.schedule_state.in_([
1814
+ ManagedJobScheduleState.ALIVE_WAITING.value,
1815
+ ManagedJobScheduleState.LAUNCHING.value,
1816
+ ManagedJobScheduleState.ALIVE.value,
1817
+ ManagedJobScheduleState.ALIVE_BACKOFF.value,
1818
+ ])
1819
+ ]
1820
+
1821
+ if pool is not None:
1822
+ where_conditions.append(job_info_table.c.pool == pool)
1823
+
1824
+ return session.execute(
1825
+ sqlalchemy.select(
1826
+ sqlalchemy.func.count() # pylint: disable=not-callable
1827
+ ).select_from(job_info_table).where(
1828
+ sqlalchemy.and_(*where_conditions))).fetchone()[0]
1829
+
1830
+
1831
+ @_init_db
1832
+ def get_nonterminal_job_ids_by_pool(pool: str,
1833
+ cluster_name: Optional[str] = None
1834
+ ) -> List[int]:
1835
+ """Get nonterminal job ids in a pool."""
1836
+ assert _SQLALCHEMY_ENGINE is not None
1837
+
1838
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
1839
+ query = sqlalchemy.select(
1840
+ spot_table.c.spot_job_id.distinct()).select_from(
1841
+ spot_table.outerjoin(
1842
+ job_info_table,
1843
+ spot_table.c.spot_job_id == job_info_table.c.spot_job_id))
1844
+ and_conditions = [
1845
+ ~spot_table.c.status.in_([
1846
+ status.value for status in ManagedJobStatus.terminal_statuses()
1847
+ ]),
1848
+ job_info_table.c.pool == pool,
1849
+ ]
1850
+ if cluster_name is not None:
1851
+ and_conditions.append(
1852
+ job_info_table.c.current_cluster_name == cluster_name)
1853
+ query = query.where(sqlalchemy.and_(*and_conditions)).order_by(
1854
+ spot_table.c.spot_job_id.asc())
1855
+ rows = session.execute(query).fetchall()
1856
+ job_ids = [row[0] for row in rows if row[0] is not None]
1857
+ return job_ids
1858
+
1859
+
1860
@_init_db_async
async def get_waiting_job_async(
        pid: int, pid_started_at: float) -> Optional[Dict[str, Any]]:
    """Get the next job that should transition to LAUNCHING.

    Selects the highest-priority WAITING or ALIVE_WAITING job and atomically
    transitions it to LAUNCHING state to prevent race conditions.

    Args:
        pid: PID of the controller process claiming the job.
        pid_started_at: Start time of that process (used to detect PID
            reuse).

    Returns the job information if a job was successfully transitioned to
    LAUNCHING, or None if no suitable job was found.

    Backwards compatibility note: jobs submitted before #4485 will have no
    schedule_state and will be ignored by this SQL query.
    """
    assert _SQLALCHEMY_ENGINE_ASYNC is not None
    async with sql_async.AsyncSession(_SQLALCHEMY_ENGINE_ASYNC) as session:
        # Select the highest priority waiting job for update (locks the row).
        # Fix: include ALIVE_WAITING as documented above — previously only
        # WAITING was selected, so ALIVE_WAITING jobs could never be picked
        # up by this query (the pre-async implementation selected both).
        select_query = sqlalchemy.select(
            job_info_table.c.spot_job_id,
            job_info_table.c.schedule_state,
            job_info_table.c.pool,
        ).where(
            job_info_table.c.schedule_state.in_([
                ManagedJobScheduleState.WAITING.value,
                ManagedJobScheduleState.ALIVE_WAITING.value,
            ])).order_by(
                job_info_table.c.priority.desc(),
                job_info_table.c.spot_job_id.asc(),
            ).limit(1).with_for_update()

        # Execute the select with row locking.
        result = await session.execute(select_query)
        waiting_job_row = result.fetchone()

        if waiting_job_row is None:
            return None

        job_id = waiting_job_row[0]
        current_state = ManagedJobScheduleState(waiting_job_row[1])
        pool = waiting_job_row[2]

        # Update the job state to LAUNCHING. The schedule_state guard makes
        # the transition atomic: if another scheduler raced us, rowcount
        # will be 0 and we back off.
        update_result = await session.execute(
            sqlalchemy.update(job_info_table).where(
                sqlalchemy.and_(
                    job_info_table.c.spot_job_id == job_id,
                    job_info_table.c.schedule_state == current_state.value,
                )).values({
                    job_info_table.c.schedule_state:
                        ManagedJobScheduleState.LAUNCHING.value,
                    job_info_table.c.controller_pid: pid,
                    job_info_table.c.controller_pid_started_at: pid_started_at,
                }))

        if update_result.rowcount != 1:
            # Update failed, rollback and return None.
            await session.rollback()
            return None

        # Commit the transaction.
        await session.commit()

        return {
            'job_id': job_id,
            'pool': pool,
        }
1925
+
1926
+
1927
@_init_db
def get_workspace(job_id: int) -> str:
    """Return the workspace a job belongs to.

    Falls back to the default SkyPilot workspace when the job has no
    workspace recorded (or the job does not exist).
    """
    assert _SQLALCHEMY_ENGINE is not None
    with orm.Session(_SQLALCHEMY_ENGINE) as session:
        row = session.execute(
            sqlalchemy.select(job_info_table.c.workspace).where(
                job_info_table.c.spot_job_id == job_id)).fetchone()
    if row is None or row[0] is None:
        return constants.SKYPILOT_DEFAULT_WORKSPACE
    return row[0]
1939
+
1940
+
1941
@_init_db_async
async def get_latest_task_id_status_async(
    job_id: int) -> Union[Tuple[int, ManagedJobStatus], Tuple[None, None]]:
    """Returns the (task id, status) of the latest task of a job."""
    assert _SQLALCHEMY_ENGINE_ASYNC is not None
    async with sql_async.AsyncSession(_SQLALCHEMY_ENGINE_ASYNC) as session:
        rows = (await session.execute(
            sqlalchemy.select(
                spot_table.c.task_id,
                spot_table.c.status,
            ).where(spot_table.c.spot_job_id == job_id).order_by(
                spot_table.c.task_id.asc()))).fetchall()
    if not rows:
        return None, None
    # Prefer the first (lowest task_id) non-terminal task; when every task
    # is terminal, report the last task of the chain.
    for tid, raw_status in rows:
        status = ManagedJobStatus(raw_status)
        if not status.is_terminal():
            return tid, status
    last_tid, last_raw = rows[-1]
    return last_tid, ManagedJobStatus(last_raw)
1964
+
1965
+
1966
@_init_db_async
async def set_starting_async(job_id: int, task_id: int, run_timestamp: str,
                             submit_time: float, resources_str: str,
                             specs: Dict[str, Union[str, int]],
                             callback_func: AsyncCallbackType):
    """Set the task to starting state."""
    assert _SQLALCHEMY_ENGINE_ASYNC is not None
    logger.info('Launching the spot cluster...')
    # Only a PENDING, not-yet-ended row may move to STARTING.
    eligible = sqlalchemy.and_(
        spot_table.c.spot_job_id == job_id,
        spot_table.c.task_id == task_id,
        spot_table.c.status == ManagedJobStatus.PENDING.value,
        spot_table.c.end_at.is_(None),
    )
    new_values = {
        spot_table.c.resources: resources_str,
        spot_table.c.submitted_at: submit_time,
        spot_table.c.status: ManagedJobStatus.STARTING.value,
        spot_table.c.run_timestamp: run_timestamp,
        spot_table.c.specs: json.dumps(specs),
    }
    async with sql_async.AsyncSession(_SQLALCHEMY_ENGINE_ASYNC) as session:
        exec_result = await session.execute(
            sqlalchemy.update(spot_table).where(eligible).values(new_values))
        updated_rows = exec_result.rowcount
        await session.commit()
        if updated_rows != 1:
            details = await _describe_task_transition_failure(
                session, job_id, task_id)
            message = ('Failed to set the task to starting. '
                       f'({updated_rows} rows updated. {details})')
            logger.error(message)
            raise exceptions.ManagedJobStatusError(message)
        await callback_func('SUBMITTED')
        await callback_func('STARTING')
2000
+
2001
+
2002
@_init_db_async
async def set_started_async(job_id: int, task_id: int, start_time: float,
                            callback_func: AsyncCallbackType):
    """Set the task to started state."""
    assert _SQLALCHEMY_ENGINE_ASYNC is not None
    logger.info('Job started.')
    # Rows in STARTING or PENDING (and not yet ended) are eligible.
    eligible = sqlalchemy.and_(
        spot_table.c.spot_job_id == job_id,
        spot_table.c.task_id == task_id,
        spot_table.c.status.in_([
            ManagedJobStatus.STARTING.value, ManagedJobStatus.PENDING.value
        ]),
        spot_table.c.end_at.is_(None),
    )
    new_values = {
        spot_table.c.status: ManagedJobStatus.RUNNING.value,
        spot_table.c.start_at: start_time,
        spot_table.c.last_recovered_at: start_time,
    }
    async with sql_async.AsyncSession(_SQLALCHEMY_ENGINE_ASYNC) as session:
        exec_result = await session.execute(
            sqlalchemy.update(spot_table).where(eligible).values(new_values))
        updated_rows = exec_result.rowcount
        await session.commit()
        if updated_rows != 1:
            details = await _describe_task_transition_failure(
                session, job_id, task_id)
            message = (f'Failed to set the task to started. '
                       f'({updated_rows} rows updated. {details})')
            logger.error(message)
            raise exceptions.ManagedJobStatusError(message)
        await callback_func('STARTED')
2034
+
2035
+
2036
@_init_db_async
async def get_job_status_with_task_id_async(
        job_id: int, task_id: int) -> Optional[ManagedJobStatus]:
    """Return the status of one task of a job, or None if it is absent."""
    assert _SQLALCHEMY_ENGINE_ASYNC is not None
    async with sql_async.AsyncSession(_SQLALCHEMY_ENGINE_ASYNC) as session:
        row = (await session.execute(
            sqlalchemy.select(spot_table.c.status).where(
                sqlalchemy.and_(
                    spot_table.c.spot_job_id == job_id,
                    spot_table.c.task_id == task_id)))).fetchone()
    if row is None:
        return None
    return ManagedJobStatus(row[0])
2047
+
2048
+
2049
@_init_db_async
async def set_recovering_async(job_id: int, task_id: int,
                               force_transit_to_recovering: bool,
                               callback_func: AsyncCallbackType):
    """Set the task to recovering state, and update the job duration.

    Args:
        job_id: The managed job id.
        task_id: The task within the job to transition.
        force_transit_to_recovering: When True, allow the transition from
            any processing status; when False, only from RUNNING.
        callback_func: Async callback invoked with 'RECOVERING' on success.

    Raises:
        exceptions.ManagedJobStatusError: If exactly one row was not
            updated (e.g. the task is not in an eligible state).
    """
    assert _SQLALCHEMY_ENGINE_ASYNC is not None
    logger.info('=== Recovering... ===')
    current_time = time.time()

    async with sql_async.AsyncSession(_SQLALCHEMY_ENGINE_ASYNC) as session:
        if force_transit_to_recovering:
            # Accept any in-flight (processing) status, not just RUNNING.
            status_condition = spot_table.c.status.in_(
                [s.value for s in ManagedJobStatus.processing_statuses()])
        else:
            status_condition = (
                spot_table.c.status == ManagedJobStatus.RUNNING.value)

        result = await session.execute(
            sqlalchemy.update(spot_table).where(
                sqlalchemy.and_(
                    spot_table.c.spot_job_id == job_id,
                    spot_table.c.task_id == task_id,
                    status_condition,
                    spot_table.c.end_at.is_(None),
                )).values({
                    spot_table.c.status: ManagedJobStatus.RECOVERING.value,
                    # Accumulate elapsed run time into job_duration only when
                    # last_recovered_at is non-negative (i.e. the task has a
                    # valid last-recovery timestamp); otherwise keep it as-is.
                    spot_table.c.job_duration: sqlalchemy.case(
                        (spot_table.c.last_recovered_at >= 0,
                         spot_table.c.job_duration + current_time -
                         spot_table.c.last_recovered_at),
                        else_=spot_table.c.job_duration),
                    # A negative last_recovered_at is replaced with the
                    # current time; a valid one is preserved.
                    spot_table.c.last_recovered_at: sqlalchemy.case(
                        (spot_table.c.last_recovered_at < 0, current_time),
                        else_=spot_table.c.last_recovered_at),
                }))
        count = result.rowcount
        await session.commit()
        if count != 1:
            details = await _describe_task_transition_failure(
                session, job_id, task_id)
            message = ('Failed to set the task to recovering with '
                       'force_transit_to_recovering='
                       f'{force_transit_to_recovering}. '
                       f'({count} rows updated. {details})')
            logger.error(message)
            raise exceptions.ManagedJobStatusError(message)
        await callback_func('RECOVERING')
2096
+
2097
+
2098
@_init_db_async
async def set_recovered_async(job_id: int, task_id: int, recovered_time: float,
                              callback_func: AsyncCallbackType):
    """Set the task to recovered."""
    assert _SQLALCHEMY_ENGINE_ASYNC is not None
    # Only a RECOVERING, not-yet-ended row may transition back to RUNNING.
    update_stmt = sqlalchemy.update(spot_table).where(
        sqlalchemy.and_(
            spot_table.c.spot_job_id == job_id,
            spot_table.c.task_id == task_id,
            spot_table.c.status == ManagedJobStatus.RECOVERING.value,
            spot_table.c.end_at.is_(None),
        )).values({
            spot_table.c.status: ManagedJobStatus.RUNNING.value,
            spot_table.c.last_recovered_at: recovered_time,
            spot_table.c.recovery_count: spot_table.c.recovery_count + 1,
        })
    async with sql_async.AsyncSession(_SQLALCHEMY_ENGINE_ASYNC) as session:
        updated_rows = (await session.execute(update_stmt)).rowcount
        await session.commit()
        if updated_rows != 1:
            details = await _describe_task_transition_failure(
                session, job_id, task_id)
            message = (f'Failed to set the task to recovered. '
                       f'({updated_rows} rows updated. {details})')
            logger.error(message)
            raise exceptions.ManagedJobStatusError(message)
        logger.info('==== Recovered. ====')
        await callback_func('RECOVERED')
2128
+
2129
+
2130
@_init_db_async
async def set_succeeded_async(job_id: int, task_id: int, end_time: float,
                              callback_func: AsyncCallbackType):
    """Set the task to succeeded, if it is in a non-terminal state."""
    assert _SQLALCHEMY_ENGINE_ASYNC is not None
    # Only a RUNNING, not-yet-ended row may transition to SUCCEEDED.
    update_stmt = sqlalchemy.update(spot_table).where(
        sqlalchemy.and_(
            spot_table.c.spot_job_id == job_id,
            spot_table.c.task_id == task_id,
            spot_table.c.status == ManagedJobStatus.RUNNING.value,
            spot_table.c.end_at.is_(None),
        )).values({
            spot_table.c.status: ManagedJobStatus.SUCCEEDED.value,
            spot_table.c.end_at: end_time,
        })
    async with sql_async.AsyncSession(_SQLALCHEMY_ENGINE_ASYNC) as session:
        updated_rows = (await session.execute(update_stmt)).rowcount
        await session.commit()
        if updated_rows != 1:
            details = await _describe_task_transition_failure(
                session, job_id, task_id)
            message = (f'Failed to set the task to succeeded. '
                       f'({updated_rows} rows updated. {details})')
            logger.error(message)
            raise exceptions.ManagedJobStatusError(message)
        await callback_func('SUCCEEDED')
        logger.info('Job succeeded.')
2158
+
2159
+
2160
@_init_db_async
async def set_failed_async(
    job_id: int,
    task_id: Optional[int],
    failure_type: ManagedJobStatus,
    failure_reason: str,
    callback_func: Optional[AsyncCallbackType] = None,
    end_time: Optional[float] = None,
    override_terminal: bool = False,
):
    """Set an entire job or task to failed.

    Args:
        job_id: The managed job id.
        task_id: The task to fail; when None, the update applies to all
            tasks of the job.
        failure_type: The failed status to record (must satisfy
            ``is_failed()``).
        failure_reason: Human-readable reason stored on the row(s).
        callback_func: Optional async callback invoked with 'FAILED' when
            at least one row was updated.
        end_time: The end timestamp; defaults to the current time.
        override_terminal: When True, rows already in a terminal state are
            also overridden (the new reason is prepended to any existing
            one, and an existing end_at is preserved).
    """
    assert _SQLALCHEMY_ENGINE_ASYNC is not None
    assert failure_type.is_failed(), failure_type
    end_time = time.time() if end_time is None else end_time

    fields_to_set: Dict[str, Any] = {
        spot_table.c.status: failure_type.value,
        spot_table.c.failure_reason: failure_reason,
    }
    async with sql_async.AsyncSession(_SQLALCHEMY_ENGINE_ASYNC) as session:
        # Get previous status
        result = await session.execute(
            sqlalchemy.select(
                spot_table.c.status).where(spot_table.c.spot_job_id == job_id))
        previous_status_row = result.fetchone()
        previous_status = ManagedJobStatus(previous_status_row[0])
        if previous_status == ManagedJobStatus.RECOVERING:
            # A job that fails mid-recovery also gets last_recovered_at set
            # to the end time.
            fields_to_set[spot_table.c.last_recovered_at] = end_time
        where_conditions = [spot_table.c.spot_job_id == job_id]
        if task_id is not None:
            where_conditions.append(spot_table.c.task_id == task_id)

        # Handle failure_reason prepending when override_terminal is True
        if override_terminal:
            # Get existing failure_reason with row lock to prevent race
            # conditions
            existing_reason_result = await session.execute(
                sqlalchemy.select(spot_table.c.failure_reason).where(
                    sqlalchemy.and_(*where_conditions)).with_for_update())
            existing_reason_row = existing_reason_result.fetchone()
            if existing_reason_row and existing_reason_row[0]:
                # Prepend new failure reason to existing one
                fields_to_set[spot_table.c.failure_reason] = (
                    failure_reason + '. Previously: ' + existing_reason_row[0])
            # Keep an already-recorded end_at if present.
            fields_to_set[spot_table.c.end_at] = sqlalchemy.func.coalesce(
                spot_table.c.end_at, end_time)
        else:
            # Without override, only rows that have not ended are touched.
            fields_to_set[spot_table.c.end_at] = end_time
            where_conditions.append(spot_table.c.end_at.is_(None))
        result = await session.execute(
            sqlalchemy.update(spot_table).where(
                sqlalchemy.and_(*where_conditions)).values(fields_to_set))
        count = result.rowcount
        await session.commit()
        updated = count > 0
        if callback_func and updated:
            await callback_func('FAILED')
        logger.info(failure_reason)
2218
+
2219
+
2220
@_init_db_async
async def set_cancelling_async(job_id: int, callback_func: AsyncCallbackType):
    """Set tasks in the job as cancelling, if they are in non-terminal
    states."""
    assert _SQLALCHEMY_ENGINE_ASYNC is not None
    # Every row of the job that has not yet ended becomes CANCELLING.
    update_stmt = sqlalchemy.update(spot_table).where(
        sqlalchemy.and_(
            spot_table.c.spot_job_id == job_id,
            spot_table.c.end_at.is_(None),
        )).values({spot_table.c.status: ManagedJobStatus.CANCELLING.value})
    async with sql_async.AsyncSession(_SQLALCHEMY_ENGINE_ASYNC) as session:
        updated_rows = (await session.execute(update_stmt)).rowcount
        await session.commit()
        if updated_rows > 0:
            logger.info('Cancelling the job...')
            await callback_func('CANCELLING')
        else:
            logger.info('Cancellation skipped, job is already terminal')
2241
+
2242
+
2243
@_init_db_async
async def set_cancelled_async(job_id: int, callback_func: AsyncCallbackType):
    """Set tasks in the job as cancelled, if they are in CANCELLING state."""
    assert _SQLALCHEMY_ENGINE_ASYNC is not None
    update_stmt = sqlalchemy.update(spot_table).where(
        sqlalchemy.and_(
            spot_table.c.spot_job_id == job_id,
            spot_table.c.status == ManagedJobStatus.CANCELLING.value,
        )).values({
            spot_table.c.status: ManagedJobStatus.CANCELLED.value,
            spot_table.c.end_at: time.time(),
        })
    async with sql_async.AsyncSession(_SQLALCHEMY_ENGINE_ASYNC) as session:
        updated_rows = (await session.execute(update_stmt)).rowcount
        await session.commit()
        if updated_rows > 0:
            logger.info('Job cancelled.')
            await callback_func('CANCELLED')
        else:
            logger.info('Cancellation skipped, job is not CANCELLING')
2265
+
2266
+
2267
@_init_db_async
async def remove_ha_recovery_script_async(job_id: int) -> None:
    """Remove the HA recovery script for a job."""
    assert _SQLALCHEMY_ENGINE_ASYNC is not None
    delete_stmt = sqlalchemy.delete(ha_recovery_script_table).where(
        ha_recovery_script_table.c.job_id == job_id)
    async with sql_async.AsyncSession(_SQLALCHEMY_ENGINE_ASYNC) as session:
        await session.execute(delete_stmt)
        await session.commit()
2276
+
2277
+
2278
async def get_status_async(job_id: int) -> Optional[ManagedJobStatus]:
    """Return the status of the job's latest task (None if no tasks)."""
    task_id_status = await get_latest_task_id_status_async(job_id)
    return task_id_status[1]
2281
+
2282
+
2283
@_init_db_async
async def get_job_schedule_state_async(job_id: int) -> ManagedJobScheduleState:
    """Return the schedule state of a job (the job must exist)."""
    assert _SQLALCHEMY_ENGINE_ASYNC is not None
    async with sql_async.AsyncSession(_SQLALCHEMY_ENGINE_ASYNC) as session:
        row = (await session.execute(
            sqlalchemy.select(job_info_table.c.schedule_state).where(
                job_info_table.c.spot_job_id == job_id))).fetchone()
    return ManagedJobScheduleState(row[0])
2292
+
2293
+
2294
@_init_db_async
async def scheduler_set_done_async(job_id: int,
                                   idempotent: bool = False) -> None:
    """Mark a job's schedule state as DONE.

    Do not call without holding the scheduler lock.
    """
    assert _SQLALCHEMY_ENGINE_ASYNC is not None
    # Only rows not already DONE are touched, making the rowcount a reliable
    # signal of whether the transition actually happened.
    update_stmt = sqlalchemy.update(job_info_table).where(
        sqlalchemy.and_(
            job_info_table.c.spot_job_id == job_id,
            job_info_table.c.schedule_state !=
            ManagedJobScheduleState.DONE.value,
        )).values({
            job_info_table.c.schedule_state:
                ManagedJobScheduleState.DONE.value
        })
    async with sql_async.AsyncSession(_SQLALCHEMY_ENGINE_ASYNC) as session:
        updated_count = (await session.execute(update_stmt)).rowcount
        await session.commit()
    if not idempotent:
        assert updated_count == 1, (job_id, updated_count)
2314
+
2315
+
2316
+ # ==== needed for codegen ====
2317
+ # functions have no use outside of codegen, remove at your own peril
2318
+
2319
+
2320
@_init_db
def set_job_info(job_id: int,
                 name: str,
                 workspace: str,
                 entrypoint: str,
                 pool: Optional[str],
                 pool_hash: Optional[str],
                 user_hash: Optional[str] = None):
    """Insert the initial job_info row (INACTIVE) for a submitted job."""
    assert _SQLALCHEMY_ENGINE is not None
    # Pick the dialect-specific insert construct.
    dialect = _SQLALCHEMY_ENGINE.dialect.name
    if dialect == db_utils.SQLAlchemyDialect.SQLITE.value:
        insert_func = sqlite.insert
    elif dialect == db_utils.SQLAlchemyDialect.POSTGRESQL.value:
        insert_func = postgresql.insert
    else:
        raise ValueError('Unsupported database dialect')
    insert_stmt = insert_func(job_info_table).values(
        spot_job_id=job_id,
        name=name,
        schedule_state=ManagedJobScheduleState.INACTIVE.value,
        workspace=workspace,
        entrypoint=entrypoint,
        pool=pool,
        pool_hash=pool_hash,
        user_hash=user_hash,
    )
    with orm.Session(_SQLALCHEMY_ENGINE) as session:
        session.execute(insert_stmt)
        session.commit()
2350
+
2351
+
2352
@_init_db
def reset_jobs_for_recovery() -> None:
    """Remove controller PIDs for live jobs, allowing them to be recovered."""
    assert _SQLALCHEMY_ENGINE is not None
    # A "live" job has a controller PID and a schedule state that is set but
    # neither WAITING nor DONE.
    live_job_filter = sqlalchemy.and_(
        job_info_table.c.controller_pid.isnot(None),
        job_info_table.c.schedule_state.isnot(None),
        job_info_table.c.schedule_state !=
        ManagedJobScheduleState.WAITING.value,
        job_info_table.c.schedule_state != ManagedJobScheduleState.DONE.value,
    )
    with orm.Session(_SQLALCHEMY_ENGINE) as session:
        session.execute(
            sqlalchemy.update(job_info_table).where(live_job_filter).values({
                job_info_table.c.controller_pid: None,
                job_info_table.c.controller_pid_started_at: None,
                job_info_table.c.schedule_state:
                    ManagedJobScheduleState.WAITING.value,
            }))
        session.commit()
2373
+
2374
+
2375
@_init_db
def reset_job_for_recovery(job_id: int) -> None:
    """Set a job to WAITING and remove PID, allowing it to be recovered."""
    assert _SQLALCHEMY_ENGINE is not None
    with orm.Session(_SQLALCHEMY_ENGINE) as session:
        session.execute(
            sqlalchemy.update(job_info_table).where(
                job_info_table.c.spot_job_id == job_id).values({
                    job_info_table.c.controller_pid: None,
                    job_info_table.c.controller_pid_started_at: None,
                    job_info_table.c.schedule_state:
                        ManagedJobScheduleState.WAITING.value,
                }))
        session.commit()
2388
+
2389
+
2390
@_init_db
def get_all_job_ids_by_name(name: Optional[str]) -> List[int]:
    """Get all job ids by name."""
    assert _SQLALCHEMY_ENGINE is not None
    joined = spot_table.outerjoin(
        job_info_table,
        spot_table.c.spot_job_id == job_info_table.c.spot_job_id)
    query = sqlalchemy.select(
        spot_table.c.spot_job_id.distinct()).select_from(joined)
    if name is not None:
        # Match the job name from `job_info` for jobs submitted after #1982,
        # falling back to `spot.task_name` for older jobs whose job_info is
        # not available.
        query = query.where(
            sqlalchemy.or_(
                job_info_table.c.name == name,
                sqlalchemy.and_(job_info_table.c.name.is_(None),
                                spot_table.c.task_name == name)))
    query = query.order_by(spot_table.c.spot_job_id.desc())
    with orm.Session(_SQLALCHEMY_ENGINE) as session:
        rows = session.execute(query).fetchall()
    return [row[0] for row in rows if row[0] is not None]
2414
+
2415
+
2416
@_init_db
def get_task_logs_to_clean(retention_seconds: int,
                           batch_size: int) -> List[Dict[str, Any]]:
    """Get the logs of job tasks to clean.

    The logs of a task will only cleaned when:
    - the job schedule state is DONE
    - AND the end time of the task is older than the retention period

    Args:
        retention_seconds: Minimum age in seconds of a task's end time for
            its logs to be eligible for cleaning.
        batch_size: Maximum number of tasks to return.

    Returns:
        A list of dicts with keys 'job_id', 'task_id', 'local_log_file'.
    """
    assert _SQLALCHEMY_ENGINE is not None
    with orm.Session(_SQLALCHEMY_ENGINE) as session:
        now = time.time()
        result = session.execute(
            sqlalchemy.select(
                spot_table.c.spot_job_id,
                spot_table.c.task_id,
                spot_table.c.local_log_file,
            ).select_from(
                spot_table.join(
                    job_info_table,
                    spot_table.c.spot_job_id == job_info_table.c.spot_job_id,
                )).where(
                    sqlalchemy.and_(
                        # Fix: use == rather than .is_() here — `.is_()`
                        # emits SQL `IS 'DONE'`, which is only valid against
                        # NULL/booleans on PostgreSQL, a dialect this module
                        # supports alongside SQLite.
                        job_info_table.c.schedule_state ==
                        ManagedJobScheduleState.DONE.value,
                        spot_table.c.end_at.isnot(None),
                        spot_table.c.end_at < (now - retention_seconds),
                        spot_table.c.logs_cleaned_at.is_(None),
                        # The local log file is set AFTER the task is
                        # finished, add this condition to ensure the entire
                        # log file has been written.
                        spot_table.c.local_log_file.isnot(None),
                    )).limit(batch_size))
        rows = result.fetchall()
        return [{
            'job_id': row[0],
            'task_id': row[1],
            'local_log_file': row[2]
        } for row in rows]
2456
+
2457
+
2458
@_init_db
def get_controller_logs_to_clean(retention_seconds: int,
                                 batch_size: int) -> List[Dict[str, Any]]:
    """Get the controller logs to clean.

    The controller logs will only cleaned when:
    - the job schedule state is DONE
    - AND the end time of the latest task is older than the retention period

    Args:
        retention_seconds: Minimum age in seconds of the latest task's end
            time for the job's controller logs to be eligible.
        batch_size: Maximum number of jobs to return.

    Returns:
        A list of dicts with a single 'job_id' key.
    """
    assert _SQLALCHEMY_ENGINE is not None
    with orm.Session(_SQLALCHEMY_ENGINE) as session:
        now = time.time()
        result = session.execute(
            sqlalchemy.select(job_info_table.c.spot_job_id,).select_from(
                job_info_table.join(
                    spot_table,
                    job_info_table.c.spot_job_id == spot_table.c.spot_job_id,
                )).where(
                    sqlalchemy.and_(
                        # Fix: use == rather than .is_() — `.is_()` emits
                        # SQL `IS 'DONE'`, which is only valid against
                        # NULL/booleans on PostgreSQL, a dialect this module
                        # supports alongside SQLite.
                        job_info_table.c.schedule_state ==
                        ManagedJobScheduleState.DONE.value,
                        spot_table.c.local_log_file.isnot(None),
                        job_info_table.c.controller_logs_cleaned_at.is_(None),
                    )).group_by(
                        job_info_table.c.spot_job_id,
                        job_info_table.c.current_cluster_name,
                    ).having(
                        sqlalchemy.func.max(
                            spot_table.c.end_at).isnot(None),).having(
                                sqlalchemy.func.max(spot_table.c.end_at) <
                                (now - retention_seconds)).limit(batch_size))
        rows = result.fetchall()
        return [{'job_id': row[0]} for row in rows]
2491
+
2492
+
2493
@_init_db
def set_task_logs_cleaned(tasks: List[Tuple[int, int]], logs_cleaned_at: float):
    """Set the task logs cleaned at."""
    if not tasks:
        return
    # Deduplicate (job_id, task_id) pairs while preserving order.
    unique_keys = list(dict.fromkeys(tasks))
    assert _SQLALCHEMY_ENGINE is not None
    update_stmt = sqlalchemy.update(spot_table).where(
        sqlalchemy.tuple_(spot_table.c.spot_job_id,
                          spot_table.c.task_id).in_(unique_keys)).values(
                              logs_cleaned_at=logs_cleaned_at)
    with orm.Session(_SQLALCHEMY_ENGINE) as session:
        session.execute(update_stmt)
        session.commit()
2507
+
2508
+
2509
@_init_db
def set_controller_logs_cleaned(job_ids: List[int], logs_cleaned_at: float):
    """Set the controller logs cleaned at."""
    if not job_ids:
        return
    # Deduplicate job ids while preserving order.
    unique_ids = list(dict.fromkeys(job_ids))
    assert _SQLALCHEMY_ENGINE is not None
    update_stmt = sqlalchemy.update(job_info_table).where(
        job_info_table.c.spot_job_id.in_(unique_ids)).values(
            controller_logs_cleaned_at=logs_cleaned_at)
    with orm.Session(_SQLALCHEMY_ENGINE) as session:
        session.execute(update_stmt)
        session.commit()