skypilot-nightly 1.0.0.dev20250905__py3-none-any.whl → 1.0.0.dev20251203__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (397) hide show
  1. sky/__init__.py +10 -2
  2. sky/adaptors/aws.py +81 -16
  3. sky/adaptors/common.py +25 -2
  4. sky/adaptors/coreweave.py +278 -0
  5. sky/adaptors/do.py +8 -2
  6. sky/adaptors/gcp.py +11 -0
  7. sky/adaptors/ibm.py +5 -2
  8. sky/adaptors/kubernetes.py +64 -0
  9. sky/adaptors/nebius.py +3 -1
  10. sky/adaptors/primeintellect.py +1 -0
  11. sky/adaptors/seeweb.py +183 -0
  12. sky/adaptors/shadeform.py +89 -0
  13. sky/admin_policy.py +20 -0
  14. sky/authentication.py +157 -263
  15. sky/backends/__init__.py +3 -2
  16. sky/backends/backend.py +11 -3
  17. sky/backends/backend_utils.py +588 -184
  18. sky/backends/cloud_vm_ray_backend.py +1088 -904
  19. sky/backends/local_docker_backend.py +9 -5
  20. sky/backends/task_codegen.py +633 -0
  21. sky/backends/wheel_utils.py +18 -0
  22. sky/catalog/__init__.py +8 -0
  23. sky/catalog/aws_catalog.py +4 -0
  24. sky/catalog/common.py +19 -1
  25. sky/catalog/data_fetchers/fetch_aws.py +102 -80
  26. sky/catalog/data_fetchers/fetch_gcp.py +30 -3
  27. sky/catalog/data_fetchers/fetch_nebius.py +9 -6
  28. sky/catalog/data_fetchers/fetch_runpod.py +698 -0
  29. sky/catalog/data_fetchers/fetch_seeweb.py +329 -0
  30. sky/catalog/data_fetchers/fetch_shadeform.py +142 -0
  31. sky/catalog/kubernetes_catalog.py +24 -28
  32. sky/catalog/primeintellect_catalog.py +95 -0
  33. sky/catalog/runpod_catalog.py +5 -1
  34. sky/catalog/seeweb_catalog.py +184 -0
  35. sky/catalog/shadeform_catalog.py +165 -0
  36. sky/check.py +73 -43
  37. sky/client/cli/command.py +675 -412
  38. sky/client/cli/flags.py +4 -2
  39. sky/{volumes/utils.py → client/cli/table_utils.py} +111 -13
  40. sky/client/cli/utils.py +79 -0
  41. sky/client/common.py +12 -2
  42. sky/client/sdk.py +132 -63
  43. sky/client/sdk_async.py +34 -33
  44. sky/cloud_stores.py +82 -3
  45. sky/clouds/__init__.py +6 -0
  46. sky/clouds/aws.py +337 -129
  47. sky/clouds/azure.py +24 -18
  48. sky/clouds/cloud.py +40 -13
  49. sky/clouds/cudo.py +16 -13
  50. sky/clouds/do.py +9 -7
  51. sky/clouds/fluidstack.py +12 -5
  52. sky/clouds/gcp.py +14 -7
  53. sky/clouds/hyperbolic.py +12 -5
  54. sky/clouds/ibm.py +12 -5
  55. sky/clouds/kubernetes.py +80 -45
  56. sky/clouds/lambda_cloud.py +12 -5
  57. sky/clouds/nebius.py +23 -9
  58. sky/clouds/oci.py +19 -12
  59. sky/clouds/paperspace.py +4 -1
  60. sky/clouds/primeintellect.py +317 -0
  61. sky/clouds/runpod.py +85 -24
  62. sky/clouds/scp.py +12 -8
  63. sky/clouds/seeweb.py +477 -0
  64. sky/clouds/shadeform.py +400 -0
  65. sky/clouds/ssh.py +4 -2
  66. sky/clouds/utils/scp_utils.py +61 -50
  67. sky/clouds/vast.py +33 -27
  68. sky/clouds/vsphere.py +14 -16
  69. sky/core.py +174 -165
  70. sky/dashboard/out/404.html +1 -1
  71. sky/dashboard/out/_next/static/96_E2yl3QAiIJGOYCkSpB/_buildManifest.js +1 -0
  72. sky/dashboard/out/_next/static/chunks/1141-e6aa9ab418717c59.js +11 -0
  73. sky/dashboard/out/_next/static/chunks/1871-7e202677c42f43fe.js +6 -0
  74. sky/dashboard/out/_next/static/chunks/2260-7703229c33c5ebd5.js +1 -0
  75. sky/dashboard/out/_next/static/chunks/2369.fc20f0c2c8ed9fe7.js +15 -0
  76. sky/dashboard/out/_next/static/chunks/2755.edd818326d489a1d.js +26 -0
  77. sky/dashboard/out/_next/static/chunks/3294.20a8540fe697d5ee.js +1 -0
  78. sky/dashboard/out/_next/static/chunks/3785.7e245f318f9d1121.js +1 -0
  79. sky/dashboard/out/_next/static/chunks/{6601-06114c982db410b6.js → 3800-7b45f9fbb6308557.js} +1 -1
  80. sky/dashboard/out/_next/static/chunks/4725.172ede95d1b21022.js +1 -0
  81. sky/dashboard/out/_next/static/chunks/4937.a2baa2df5572a276.js +15 -0
  82. sky/dashboard/out/_next/static/chunks/6212-7bd06f60ba693125.js +13 -0
  83. sky/dashboard/out/_next/static/chunks/6856-8f27d1c10c98def8.js +1 -0
  84. sky/dashboard/out/_next/static/chunks/6990-9146207c4567fdfd.js +1 -0
  85. sky/dashboard/out/_next/static/chunks/7359-c8d04e06886000b3.js +30 -0
  86. sky/dashboard/out/_next/static/chunks/7615-019513abc55b3b47.js +1 -0
  87. sky/dashboard/out/_next/static/chunks/8640.5b9475a2d18c5416.js +16 -0
  88. sky/dashboard/out/_next/static/chunks/8969-452f9d5cbdd2dc73.js +1 -0
  89. sky/dashboard/out/_next/static/chunks/9025.fa408f3242e9028d.js +6 -0
  90. sky/dashboard/out/_next/static/chunks/9353-cff34f7e773b2e2b.js +1 -0
  91. sky/dashboard/out/_next/static/chunks/9360.a536cf6b1fa42355.js +31 -0
  92. sky/dashboard/out/_next/static/chunks/9847.3aaca6bb33455140.js +30 -0
  93. sky/dashboard/out/_next/static/chunks/pages/{_app-ce361c6959bc2001.js → _app-bde01e4a2beec258.js} +1 -1
  94. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]/[job]-792db96d918c98c9.js +16 -0
  95. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]-abfcac9c137aa543.js +1 -0
  96. sky/dashboard/out/_next/static/chunks/pages/clusters-ee39056f9851a3ff.js +1 -0
  97. sky/dashboard/out/_next/static/chunks/pages/infra/{[context]-6563820e094f68ca.js → [context]-c0b5935149902e6f.js} +1 -1
  98. sky/dashboard/out/_next/static/chunks/pages/{infra-aabba60d57826e0f.js → infra-aed0ea19df7cf961.js} +1 -1
  99. sky/dashboard/out/_next/static/chunks/pages/jobs/[job]-d66997e2bfc837cf.js +16 -0
  100. sky/dashboard/out/_next/static/chunks/pages/jobs/pools/[pool]-9faf940b253e3e06.js +21 -0
  101. sky/dashboard/out/_next/static/chunks/pages/jobs-2072b48b617989c9.js +1 -0
  102. sky/dashboard/out/_next/static/chunks/pages/users-f42674164aa73423.js +1 -0
  103. sky/dashboard/out/_next/static/chunks/pages/volumes-b84b948ff357c43e.js +1 -0
  104. sky/dashboard/out/_next/static/chunks/pages/workspaces/{[name]-af76bb06dbb3954f.js → [name]-84a40f8c7c627fe4.js} +1 -1
  105. sky/dashboard/out/_next/static/chunks/pages/{workspaces-7598c33a746cdc91.js → workspaces-531b2f8c4bf89f82.js} +1 -1
  106. sky/dashboard/out/_next/static/chunks/webpack-64e05f17bf2cf8ce.js +1 -0
  107. sky/dashboard/out/_next/static/css/0748ce22df867032.css +3 -0
  108. sky/dashboard/out/clusters/[cluster]/[job].html +1 -1
  109. sky/dashboard/out/clusters/[cluster].html +1 -1
  110. sky/dashboard/out/clusters.html +1 -1
  111. sky/dashboard/out/config.html +1 -1
  112. sky/dashboard/out/index.html +1 -1
  113. sky/dashboard/out/infra/[context].html +1 -1
  114. sky/dashboard/out/infra.html +1 -1
  115. sky/dashboard/out/jobs/[job].html +1 -1
  116. sky/dashboard/out/jobs/pools/[pool].html +1 -1
  117. sky/dashboard/out/jobs.html +1 -1
  118. sky/dashboard/out/users.html +1 -1
  119. sky/dashboard/out/volumes.html +1 -1
  120. sky/dashboard/out/workspace/new.html +1 -1
  121. sky/dashboard/out/workspaces/[name].html +1 -1
  122. sky/dashboard/out/workspaces.html +1 -1
  123. sky/data/data_utils.py +92 -1
  124. sky/data/mounting_utils.py +162 -29
  125. sky/data/storage.py +200 -19
  126. sky/data/storage_utils.py +10 -45
  127. sky/exceptions.py +18 -7
  128. sky/execution.py +74 -31
  129. sky/global_user_state.py +605 -191
  130. sky/jobs/__init__.py +2 -0
  131. sky/jobs/client/sdk.py +101 -4
  132. sky/jobs/client/sdk_async.py +31 -5
  133. sky/jobs/constants.py +15 -8
  134. sky/jobs/controller.py +726 -284
  135. sky/jobs/file_content_utils.py +128 -0
  136. sky/jobs/log_gc.py +193 -0
  137. sky/jobs/recovery_strategy.py +250 -100
  138. sky/jobs/scheduler.py +271 -173
  139. sky/jobs/server/core.py +367 -114
  140. sky/jobs/server/server.py +81 -35
  141. sky/jobs/server/utils.py +89 -35
  142. sky/jobs/state.py +1498 -620
  143. sky/jobs/utils.py +771 -306
  144. sky/logs/agent.py +40 -5
  145. sky/logs/aws.py +9 -19
  146. sky/metrics/utils.py +282 -39
  147. sky/optimizer.py +1 -1
  148. sky/provision/__init__.py +37 -1
  149. sky/provision/aws/config.py +34 -13
  150. sky/provision/aws/instance.py +5 -2
  151. sky/provision/azure/instance.py +5 -3
  152. sky/provision/common.py +2 -0
  153. sky/provision/cudo/instance.py +4 -3
  154. sky/provision/do/instance.py +4 -3
  155. sky/provision/docker_utils.py +97 -26
  156. sky/provision/fluidstack/instance.py +6 -5
  157. sky/provision/gcp/config.py +6 -1
  158. sky/provision/gcp/instance.py +4 -2
  159. sky/provision/hyperbolic/instance.py +4 -2
  160. sky/provision/instance_setup.py +66 -20
  161. sky/provision/kubernetes/__init__.py +2 -0
  162. sky/provision/kubernetes/config.py +7 -44
  163. sky/provision/kubernetes/constants.py +0 -1
  164. sky/provision/kubernetes/instance.py +609 -213
  165. sky/provision/kubernetes/manifests/fusermount-server-daemonset.yaml +1 -2
  166. sky/provision/kubernetes/network.py +12 -8
  167. sky/provision/kubernetes/network_utils.py +8 -25
  168. sky/provision/kubernetes/utils.py +382 -418
  169. sky/provision/kubernetes/volume.py +150 -18
  170. sky/provision/lambda_cloud/instance.py +16 -13
  171. sky/provision/nebius/instance.py +6 -2
  172. sky/provision/nebius/utils.py +103 -86
  173. sky/provision/oci/instance.py +4 -2
  174. sky/provision/paperspace/instance.py +4 -3
  175. sky/provision/primeintellect/__init__.py +10 -0
  176. sky/provision/primeintellect/config.py +11 -0
  177. sky/provision/primeintellect/instance.py +454 -0
  178. sky/provision/primeintellect/utils.py +398 -0
  179. sky/provision/provisioner.py +30 -9
  180. sky/provision/runpod/__init__.py +2 -0
  181. sky/provision/runpod/instance.py +4 -3
  182. sky/provision/runpod/volume.py +69 -13
  183. sky/provision/scp/instance.py +307 -130
  184. sky/provision/seeweb/__init__.py +11 -0
  185. sky/provision/seeweb/config.py +13 -0
  186. sky/provision/seeweb/instance.py +812 -0
  187. sky/provision/shadeform/__init__.py +11 -0
  188. sky/provision/shadeform/config.py +12 -0
  189. sky/provision/shadeform/instance.py +351 -0
  190. sky/provision/shadeform/shadeform_utils.py +83 -0
  191. sky/provision/vast/instance.py +5 -3
  192. sky/provision/volume.py +164 -0
  193. sky/provision/vsphere/common/ssl_helper.py +1 -1
  194. sky/provision/vsphere/common/vapiconnect.py +2 -1
  195. sky/provision/vsphere/common/vim_utils.py +3 -2
  196. sky/provision/vsphere/instance.py +8 -6
  197. sky/provision/vsphere/vsphere_utils.py +8 -1
  198. sky/resources.py +11 -3
  199. sky/schemas/api/responses.py +107 -6
  200. sky/schemas/db/global_user_state/008_skylet_ssh_tunnel_metadata.py +34 -0
  201. sky/schemas/db/global_user_state/009_last_activity_and_launched_at.py +89 -0
  202. sky/schemas/db/global_user_state/010_save_ssh_key.py +66 -0
  203. sky/schemas/db/global_user_state/011_is_ephemeral.py +34 -0
  204. sky/schemas/db/kv_cache/001_initial_schema.py +29 -0
  205. sky/schemas/db/serve_state/002_yaml_content.py +34 -0
  206. sky/schemas/db/skypilot_config/001_initial_schema.py +30 -0
  207. sky/schemas/db/spot_jobs/002_cluster_pool.py +3 -3
  208. sky/schemas/db/spot_jobs/004_job_file_contents.py +42 -0
  209. sky/schemas/db/spot_jobs/005_logs_gc.py +38 -0
  210. sky/schemas/db/spot_jobs/006_controller_pid_started_at.py +34 -0
  211. sky/schemas/db/spot_jobs/007_config_file_content.py +34 -0
  212. sky/schemas/generated/jobsv1_pb2.py +86 -0
  213. sky/schemas/generated/jobsv1_pb2.pyi +254 -0
  214. sky/schemas/generated/jobsv1_pb2_grpc.py +542 -0
  215. sky/schemas/generated/managed_jobsv1_pb2.py +76 -0
  216. sky/schemas/generated/managed_jobsv1_pb2.pyi +278 -0
  217. sky/schemas/generated/managed_jobsv1_pb2_grpc.py +278 -0
  218. sky/schemas/generated/servev1_pb2.py +58 -0
  219. sky/schemas/generated/servev1_pb2.pyi +115 -0
  220. sky/schemas/generated/servev1_pb2_grpc.py +322 -0
  221. sky/serve/autoscalers.py +2 -0
  222. sky/serve/client/impl.py +55 -21
  223. sky/serve/constants.py +4 -3
  224. sky/serve/controller.py +17 -11
  225. sky/serve/load_balancing_policies.py +1 -1
  226. sky/serve/replica_managers.py +219 -142
  227. sky/serve/serve_rpc_utils.py +179 -0
  228. sky/serve/serve_state.py +63 -54
  229. sky/serve/serve_utils.py +145 -109
  230. sky/serve/server/core.py +46 -25
  231. sky/serve/server/impl.py +311 -162
  232. sky/serve/server/server.py +21 -19
  233. sky/serve/service.py +84 -68
  234. sky/serve/service_spec.py +45 -7
  235. sky/server/auth/loopback.py +38 -0
  236. sky/server/auth/oauth2_proxy.py +12 -7
  237. sky/server/common.py +47 -24
  238. sky/server/config.py +62 -28
  239. sky/server/constants.py +9 -1
  240. sky/server/daemons.py +109 -38
  241. sky/server/metrics.py +76 -96
  242. sky/server/middleware_utils.py +166 -0
  243. sky/server/requests/executor.py +381 -145
  244. sky/server/requests/payloads.py +71 -18
  245. sky/server/requests/preconditions.py +15 -13
  246. sky/server/requests/request_names.py +121 -0
  247. sky/server/requests/requests.py +507 -157
  248. sky/server/requests/serializers/decoders.py +48 -17
  249. sky/server/requests/serializers/encoders.py +85 -20
  250. sky/server/requests/threads.py +117 -0
  251. sky/server/rest.py +116 -24
  252. sky/server/server.py +420 -172
  253. sky/server/stream_utils.py +219 -45
  254. sky/server/uvicorn.py +30 -19
  255. sky/setup_files/MANIFEST.in +6 -1
  256. sky/setup_files/alembic.ini +8 -0
  257. sky/setup_files/dependencies.py +62 -19
  258. sky/setup_files/setup.py +44 -44
  259. sky/sky_logging.py +13 -5
  260. sky/skylet/attempt_skylet.py +106 -24
  261. sky/skylet/configs.py +3 -1
  262. sky/skylet/constants.py +111 -26
  263. sky/skylet/events.py +64 -10
  264. sky/skylet/job_lib.py +141 -104
  265. sky/skylet/log_lib.py +233 -5
  266. sky/skylet/log_lib.pyi +40 -2
  267. sky/skylet/providers/ibm/node_provider.py +12 -8
  268. sky/skylet/providers/ibm/vpc_provider.py +13 -12
  269. sky/skylet/runtime_utils.py +21 -0
  270. sky/skylet/services.py +524 -0
  271. sky/skylet/skylet.py +22 -1
  272. sky/skylet/subprocess_daemon.py +104 -29
  273. sky/skypilot_config.py +99 -79
  274. sky/ssh_node_pools/server.py +9 -8
  275. sky/task.py +221 -104
  276. sky/templates/aws-ray.yml.j2 +1 -0
  277. sky/templates/azure-ray.yml.j2 +1 -0
  278. sky/templates/cudo-ray.yml.j2 +1 -0
  279. sky/templates/do-ray.yml.j2 +1 -0
  280. sky/templates/fluidstack-ray.yml.j2 +1 -0
  281. sky/templates/gcp-ray.yml.j2 +1 -0
  282. sky/templates/hyperbolic-ray.yml.j2 +1 -0
  283. sky/templates/ibm-ray.yml.j2 +2 -1
  284. sky/templates/jobs-controller.yaml.j2 +3 -0
  285. sky/templates/kubernetes-ray.yml.j2 +196 -55
  286. sky/templates/lambda-ray.yml.j2 +1 -0
  287. sky/templates/nebius-ray.yml.j2 +3 -0
  288. sky/templates/oci-ray.yml.j2 +1 -0
  289. sky/templates/paperspace-ray.yml.j2 +1 -0
  290. sky/templates/primeintellect-ray.yml.j2 +72 -0
  291. sky/templates/runpod-ray.yml.j2 +1 -0
  292. sky/templates/scp-ray.yml.j2 +1 -0
  293. sky/templates/seeweb-ray.yml.j2 +171 -0
  294. sky/templates/shadeform-ray.yml.j2 +73 -0
  295. sky/templates/vast-ray.yml.j2 +1 -0
  296. sky/templates/vsphere-ray.yml.j2 +1 -0
  297. sky/templates/websocket_proxy.py +188 -43
  298. sky/usage/usage_lib.py +16 -4
  299. sky/users/permission.py +60 -43
  300. sky/utils/accelerator_registry.py +6 -3
  301. sky/utils/admin_policy_utils.py +18 -5
  302. sky/utils/annotations.py +22 -0
  303. sky/utils/asyncio_utils.py +78 -0
  304. sky/utils/atomic.py +1 -1
  305. sky/utils/auth_utils.py +153 -0
  306. sky/utils/cli_utils/status_utils.py +12 -7
  307. sky/utils/cluster_utils.py +28 -6
  308. sky/utils/command_runner.py +88 -27
  309. sky/utils/command_runner.pyi +36 -3
  310. sky/utils/common.py +3 -1
  311. sky/utils/common_utils.py +37 -4
  312. sky/utils/config_utils.py +1 -14
  313. sky/utils/context.py +127 -40
  314. sky/utils/context_utils.py +73 -18
  315. sky/utils/controller_utils.py +229 -70
  316. sky/utils/db/db_utils.py +95 -18
  317. sky/utils/db/kv_cache.py +149 -0
  318. sky/utils/db/migration_utils.py +24 -7
  319. sky/utils/env_options.py +4 -0
  320. sky/utils/git.py +559 -1
  321. sky/utils/kubernetes/create_cluster.sh +15 -30
  322. sky/utils/kubernetes/delete_cluster.sh +10 -7
  323. sky/utils/kubernetes/{deploy_remote_cluster.py → deploy_ssh_node_pools.py} +258 -380
  324. sky/utils/kubernetes/generate_kind_config.py +6 -66
  325. sky/utils/kubernetes/gpu_labeler.py +13 -3
  326. sky/utils/kubernetes/k8s_gpu_labeler_job.yaml +2 -1
  327. sky/utils/kubernetes/k8s_gpu_labeler_setup.yaml +16 -16
  328. sky/utils/kubernetes/kubernetes_deploy_utils.py +213 -194
  329. sky/utils/kubernetes/rsync_helper.sh +11 -3
  330. sky/utils/kubernetes_enums.py +7 -15
  331. sky/utils/lock_events.py +4 -4
  332. sky/utils/locks.py +128 -31
  333. sky/utils/log_utils.py +0 -319
  334. sky/utils/resource_checker.py +13 -10
  335. sky/utils/resources_utils.py +53 -29
  336. sky/utils/rich_utils.py +8 -4
  337. sky/utils/schemas.py +107 -52
  338. sky/utils/subprocess_utils.py +17 -4
  339. sky/utils/thread_utils.py +91 -0
  340. sky/utils/timeline.py +2 -1
  341. sky/utils/ux_utils.py +35 -1
  342. sky/utils/volume.py +88 -4
  343. sky/utils/yaml_utils.py +9 -0
  344. sky/volumes/client/sdk.py +48 -10
  345. sky/volumes/server/core.py +59 -22
  346. sky/volumes/server/server.py +46 -17
  347. sky/volumes/volume.py +54 -42
  348. sky/workspaces/core.py +57 -21
  349. sky/workspaces/server.py +13 -12
  350. sky_templates/README.md +3 -0
  351. sky_templates/__init__.py +3 -0
  352. sky_templates/ray/__init__.py +0 -0
  353. sky_templates/ray/start_cluster +183 -0
  354. sky_templates/ray/stop_cluster +75 -0
  355. {skypilot_nightly-1.0.0.dev20250905.dist-info → skypilot_nightly-1.0.0.dev20251203.dist-info}/METADATA +331 -65
  356. skypilot_nightly-1.0.0.dev20251203.dist-info/RECORD +611 -0
  357. skypilot_nightly-1.0.0.dev20251203.dist-info/top_level.txt +2 -0
  358. sky/client/cli/git.py +0 -549
  359. sky/dashboard/out/_next/static/chunks/1121-408ed10b2f9fce17.js +0 -1
  360. sky/dashboard/out/_next/static/chunks/1141-943efc7aff0f0c06.js +0 -1
  361. sky/dashboard/out/_next/static/chunks/1836-37fede578e2da5f8.js +0 -40
  362. sky/dashboard/out/_next/static/chunks/3015-86cabed5d4669ad0.js +0 -1
  363. sky/dashboard/out/_next/static/chunks/3294.c80326aec9bfed40.js +0 -6
  364. sky/dashboard/out/_next/static/chunks/3785.4872a2f3aa489880.js +0 -1
  365. sky/dashboard/out/_next/static/chunks/4045.b30465273dc5e468.js +0 -21
  366. sky/dashboard/out/_next/static/chunks/4676-9da7fdbde90b5549.js +0 -10
  367. sky/dashboard/out/_next/static/chunks/4725.10f7a9a5d3ea8208.js +0 -1
  368. sky/dashboard/out/_next/static/chunks/5339.3fda4a4010ff4e06.js +0 -51
  369. sky/dashboard/out/_next/static/chunks/6135-4b4d5e824b7f9d3c.js +0 -1
  370. sky/dashboard/out/_next/static/chunks/649.b9d7f7d10c1b8c53.js +0 -45
  371. sky/dashboard/out/_next/static/chunks/6856-dca7962af4814e1b.js +0 -1
  372. sky/dashboard/out/_next/static/chunks/6990-08b2a1cae076a943.js +0 -1
  373. sky/dashboard/out/_next/static/chunks/7325.b4bc99ce0892dcd5.js +0 -6
  374. sky/dashboard/out/_next/static/chunks/754-d0da8ab45f9509e9.js +0 -18
  375. sky/dashboard/out/_next/static/chunks/7669.1f5d9a402bf5cc42.js +0 -36
  376. sky/dashboard/out/_next/static/chunks/8969-0be3036bf86f8256.js +0 -1
  377. sky/dashboard/out/_next/static/chunks/9025.c12318fb6a1a9093.js +0 -6
  378. sky/dashboard/out/_next/static/chunks/9037-fa1737818d0a0969.js +0 -6
  379. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]/[job]-1cbba24bd1bd35f8.js +0 -16
  380. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]-0b4b35dc1dfe046c.js +0 -16
  381. sky/dashboard/out/_next/static/chunks/pages/clusters-469814d711d63b1b.js +0 -1
  382. sky/dashboard/out/_next/static/chunks/pages/jobs/[job]-dd64309c3fe67ed2.js +0 -11
  383. sky/dashboard/out/_next/static/chunks/pages/jobs/pools/[pool]-07349868f7905d37.js +0 -16
  384. sky/dashboard/out/_next/static/chunks/pages/jobs-1f70d9faa564804f.js +0 -1
  385. sky/dashboard/out/_next/static/chunks/pages/users-018bf31cda52e11b.js +0 -1
  386. sky/dashboard/out/_next/static/chunks/pages/volumes-739726d6b823f532.js +0 -1
  387. sky/dashboard/out/_next/static/chunks/webpack-4fe903277b57b523.js +0 -1
  388. sky/dashboard/out/_next/static/css/4614e06482d7309e.css +0 -3
  389. sky/dashboard/out/_next/static/mS-4qZPSkRuA1u-g2wQhg/_buildManifest.js +0 -1
  390. sky/templates/kubernetes-ssh-jump.yml.j2 +0 -94
  391. sky/utils/kubernetes/ssh_jump_lifecycle_manager.py +0 -191
  392. skypilot_nightly-1.0.0.dev20250905.dist-info/RECORD +0 -547
  393. skypilot_nightly-1.0.0.dev20250905.dist-info/top_level.txt +0 -1
  394. /sky/dashboard/out/_next/static/{mS-4qZPSkRuA1u-g2wQhg → 96_E2yl3QAiIJGOYCkSpB}/_ssgManifest.js +0 -0
  395. {skypilot_nightly-1.0.0.dev20250905.dist-info → skypilot_nightly-1.0.0.dev20251203.dist-info}/WHEEL +0 -0
  396. {skypilot_nightly-1.0.0.dev20250905.dist-info → skypilot_nightly-1.0.0.dev20251203.dist-info}/entry_points.txt +0 -0
  397. {skypilot_nightly-1.0.0.dev20250905.dist-info → skypilot_nightly-1.0.0.dev20251203.dist-info}/licenses/LICENSE +0 -0
sky/jobs/state.py CHANGED
@@ -1,13 +1,18 @@
1
1
  """The database for managed jobs status."""
2
2
  # TODO(zhwu): maybe use file based status instead of database, so
3
3
  # that we can easily switch to a s3-based storage.
4
+ import asyncio
5
+ import collections
4
6
  import enum
5
7
  import functools
8
+ import ipaddress
6
9
  import json
10
+ import sqlite3
7
11
  import threading
8
12
  import time
9
13
  import typing
10
- from typing import Any, Callable, Dict, List, Optional, Tuple, Union
14
+ from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Union
15
+ import urllib.parse
11
16
 
12
17
  import colorama
13
18
  import sqlalchemy
@@ -15,27 +20,40 @@ from sqlalchemy import exc as sqlalchemy_exc
15
20
  from sqlalchemy import orm
16
21
  from sqlalchemy.dialects import postgresql
17
22
  from sqlalchemy.dialects import sqlite
23
+ from sqlalchemy.ext import asyncio as sql_async
18
24
  from sqlalchemy.ext import declarative
19
25
 
20
26
  from sky import exceptions
21
27
  from sky import sky_logging
28
+ from sky import skypilot_config
29
+ from sky.adaptors import common as adaptors_common
22
30
  from sky.skylet import constants
23
31
  from sky.utils import common_utils
32
+ from sky.utils import context_utils
24
33
  from sky.utils.db import db_utils
25
34
  from sky.utils.db import migration_utils
26
35
 
27
36
  if typing.TYPE_CHECKING:
28
37
  from sqlalchemy.engine import row
29
38
 
30
- import sky
39
+ from sky.schemas.generated import managed_jobsv1_pb2
40
+ else:
41
+ managed_jobsv1_pb2 = adaptors_common.LazyImport(
42
+ 'sky.schemas.generated.managed_jobsv1_pb2')
31
43
 
32
- CallbackType = Callable[[str], None]
44
+ # Separate callback types for sync and async contexts
45
+ SyncCallbackType = Callable[[str], None]
46
+ AsyncCallbackType = Callable[[str], Awaitable[Any]]
47
+ CallbackType = Union[SyncCallbackType, AsyncCallbackType]
33
48
 
34
49
  logger = sky_logging.init_logger(__name__)
35
50
 
36
51
  _SQLALCHEMY_ENGINE: Optional[sqlalchemy.engine.Engine] = None
52
+ _SQLALCHEMY_ENGINE_ASYNC: Optional[sql_async.AsyncEngine] = None
37
53
  _SQLALCHEMY_ENGINE_LOCK = threading.Lock()
38
54
 
55
+ _DB_RETRY_TIMES = 30
56
+
39
57
  Base = declarative.declarative_base()
40
58
 
41
59
  # === Database schema ===
@@ -70,12 +88,13 @@ spot_table = sqlalchemy.Table(
70
88
  sqlalchemy.Column('recovery_count', sqlalchemy.Integer, server_default='0'),
71
89
  sqlalchemy.Column('job_duration', sqlalchemy.Float, server_default='0'),
72
90
  sqlalchemy.Column('failure_reason', sqlalchemy.Text),
73
- sqlalchemy.Column('spot_job_id', sqlalchemy.Integer),
91
+ sqlalchemy.Column('spot_job_id', sqlalchemy.Integer, index=True),
74
92
  sqlalchemy.Column('task_id', sqlalchemy.Integer, server_default='0'),
75
93
  sqlalchemy.Column('task_name', sqlalchemy.Text),
76
94
  sqlalchemy.Column('specs', sqlalchemy.Text),
77
95
  sqlalchemy.Column('local_log_file', sqlalchemy.Text, server_default=None),
78
96
  sqlalchemy.Column('metadata', sqlalchemy.Text, server_default='{}'),
97
+ sqlalchemy.Column('logs_cleaned_at', sqlalchemy.Float, server_default=None),
79
98
  )
80
99
 
81
100
  job_info_table = sqlalchemy.Table(
@@ -89,8 +108,16 @@ job_info_table = sqlalchemy.Table(
89
108
  sqlalchemy.Column('schedule_state', sqlalchemy.Text),
90
109
  sqlalchemy.Column('controller_pid', sqlalchemy.Integer,
91
110
  server_default=None),
111
+ sqlalchemy.Column('controller_pid_started_at',
112
+ sqlalchemy.Float,
113
+ server_default=None),
92
114
  sqlalchemy.Column('dag_yaml_path', sqlalchemy.Text),
93
115
  sqlalchemy.Column('env_file_path', sqlalchemy.Text),
116
+ sqlalchemy.Column('dag_yaml_content', sqlalchemy.Text, server_default=None),
117
+ sqlalchemy.Column('env_file_content', sqlalchemy.Text, server_default=None),
118
+ sqlalchemy.Column('config_file_content',
119
+ sqlalchemy.Text,
120
+ server_default=None),
94
121
  sqlalchemy.Column('user_hash', sqlalchemy.Text),
95
122
  sqlalchemy.Column('workspace', sqlalchemy.Text, server_default=None),
96
123
  sqlalchemy.Column('priority',
@@ -100,6 +127,9 @@ job_info_table = sqlalchemy.Table(
100
127
  sqlalchemy.Column('original_user_yaml_path',
101
128
  sqlalchemy.Text,
102
129
  server_default=None),
130
+ sqlalchemy.Column('original_user_yaml_content',
131
+ sqlalchemy.Text,
132
+ server_default=None),
103
133
  sqlalchemy.Column('pool', sqlalchemy.Text, server_default=None),
104
134
  sqlalchemy.Column('current_cluster_name',
105
135
  sqlalchemy.Text,
@@ -108,8 +138,12 @@ job_info_table = sqlalchemy.Table(
108
138
  sqlalchemy.Integer,
109
139
  server_default=None),
110
140
  sqlalchemy.Column('pool_hash', sqlalchemy.Text, server_default=None),
141
+ sqlalchemy.Column('controller_logs_cleaned_at',
142
+ sqlalchemy.Float,
143
+ server_default=None),
111
144
  )
112
145
 
146
+ # TODO(cooperc): drop the table in a migration
113
147
  ha_recovery_script_table = sqlalchemy.Table(
114
148
  'ha_recovery_script',
115
149
  Base.metadata,
@@ -129,6 +163,7 @@ def create_table(engine: sqlalchemy.engine.Engine):
129
163
  try:
130
164
  with orm.Session(engine) as session:
131
165
  session.execute(sqlalchemy.text('PRAGMA journal_mode=WAL'))
166
+ session.execute(sqlalchemy.text('PRAGMA synchronous=1'))
132
167
  session.commit()
133
168
  except sqlalchemy_exc.OperationalError as e:
134
169
  if 'database is locked' not in str(e):
@@ -141,6 +176,43 @@ def create_table(engine: sqlalchemy.engine.Engine):
141
176
  migration_utils.SPOT_JOBS_VERSION)
142
177
 
143
178
 
179
+ def force_no_postgres() -> bool:
180
+ """Force no postgres.
181
+
182
+ If the db is localhost on the api server, and we are not in consolidation
183
+ mode, we must force using sqlite and not using the api server on the jobs
184
+ controller.
185
+ """
186
+ conn_string = skypilot_config.get_nested(('db',), None)
187
+
188
+ if conn_string:
189
+ parsed = urllib.parse.urlparse(conn_string)
190
+ # It freezes if we use the normal get_consolidation_mode function.
191
+ consolidation_mode = skypilot_config.get_nested(
192
+ ('jobs', 'controller', 'consolidation_mode'), default_value=False)
193
+ if ((parsed.hostname == 'localhost' or
194
+ ipaddress.ip_address(parsed.hostname).is_loopback) and
195
+ not consolidation_mode):
196
+ return True
197
+ return False
198
+
199
+
200
+ def initialize_and_get_db_async() -> sql_async.AsyncEngine:
201
+ global _SQLALCHEMY_ENGINE_ASYNC
202
+ if _SQLALCHEMY_ENGINE_ASYNC is not None:
203
+ return _SQLALCHEMY_ENGINE_ASYNC
204
+ with _SQLALCHEMY_ENGINE_LOCK:
205
+ if _SQLALCHEMY_ENGINE_ASYNC is not None:
206
+ return _SQLALCHEMY_ENGINE_ASYNC
207
+
208
+ _SQLALCHEMY_ENGINE_ASYNC = db_utils.get_engine('spot_jobs',
209
+ async_engine=True)
210
+
211
+ # to create the table in case an async function gets called first
212
+ initialize_and_get_db()
213
+ return _SQLALCHEMY_ENGINE_ASYNC
214
+
215
+
144
216
  # We wrap the sqlalchemy engine initialization in a thread
145
217
  # lock to ensure that multiple threads do not initialize the
146
218
  # engine which could result in a rare race condition where
@@ -149,7 +221,6 @@ def create_table(engine: sqlalchemy.engine.Engine):
149
221
  # which could result in e1 being garbage collected unexpectedly.
150
222
  def initialize_and_get_db() -> sqlalchemy.engine.Engine:
151
223
  global _SQLALCHEMY_ENGINE
152
-
153
224
  if _SQLALCHEMY_ENGINE is not None:
154
225
  return _SQLALCHEMY_ENGINE
155
226
 
@@ -167,17 +238,85 @@ def initialize_and_get_db() -> sqlalchemy.engine.Engine:
167
238
  return _SQLALCHEMY_ENGINE
168
239
 
169
240
 
241
+ def _init_db_async(func):
242
+ """Initialize the async database. Add backoff to the function call."""
243
+
244
+ @functools.wraps(func)
245
+ async def wrapper(*args, **kwargs):
246
+ if _SQLALCHEMY_ENGINE_ASYNC is None:
247
+ # this may happen multiple times since there is no locking
248
+ here but that's fine; this is just a short circuit for the
249
+ # common case.
250
+ await context_utils.to_thread(initialize_and_get_db_async)
251
+
252
+ backoff = common_utils.Backoff(initial_backoff=1, max_backoff_factor=5)
253
+ last_exc = None
254
+ for _ in range(_DB_RETRY_TIMES):
255
+ try:
256
+ return await func(*args, **kwargs)
257
+ except (sqlalchemy_exc.OperationalError,
258
+ asyncio.exceptions.TimeoutError, OSError,
259
+ sqlalchemy_exc.TimeoutError, sqlite3.OperationalError,
260
+ sqlalchemy_exc.InterfaceError, sqlite3.InterfaceError) as e:
261
+ last_exc = e
262
+ logger.debug(f'DB error: {last_exc}')
263
+ await asyncio.sleep(backoff.current_backoff())
264
+ assert last_exc is not None
265
+ raise last_exc
266
+
267
+ return wrapper
268
+
269
+
170
270
  def _init_db(func):
171
- """Initialize the database."""
271
+ """Initialize the database. Add backoff to the function call."""
172
272
 
173
273
  @functools.wraps(func)
174
274
  def wrapper(*args, **kwargs):
175
- initialize_and_get_db()
176
- return func(*args, **kwargs)
275
+ if _SQLALCHEMY_ENGINE is None:
276
+ # this may happen multiple times since there is no locking
277
+ here but that's fine; this is just a short circuit for the
278
+ # common case.
279
+ initialize_and_get_db()
280
+
281
+ backoff = common_utils.Backoff(initial_backoff=1, max_backoff_factor=10)
282
+ last_exc = None
283
+ for _ in range(_DB_RETRY_TIMES):
284
+ try:
285
+ return func(*args, **kwargs)
286
+ except (sqlalchemy_exc.OperationalError,
287
+ asyncio.exceptions.TimeoutError, OSError,
288
+ sqlalchemy_exc.TimeoutError, sqlite3.OperationalError,
289
+ sqlalchemy_exc.InterfaceError, sqlite3.InterfaceError) as e:
290
+ last_exc = e
291
+ logger.debug(f'DB error: {last_exc}')
292
+ time.sleep(backoff.current_backoff())
293
+ assert last_exc is not None
294
+ raise last_exc
177
295
 
178
296
  return wrapper
179
297
 
180
298
 
299
+ async def _describe_task_transition_failure(session: sql_async.AsyncSession,
300
+ job_id: int, task_id: int) -> str:
301
+ """Return a human-readable description when a task transition fails."""
302
+ details = 'Couldn\'t fetch the task details.'
303
+ try:
304
+ debug_result = await session.execute(
305
+ sqlalchemy.select(spot_table.c.status, spot_table.c.end_at).where(
306
+ sqlalchemy.and_(spot_table.c.spot_job_id == job_id,
307
+ spot_table.c.task_id == task_id)))
308
+ rows = debug_result.mappings().all()
309
+ details = (f'{len(rows)} rows matched job {job_id} and task '
310
+ f'{task_id}.')
311
+ for row in rows:
312
+ status = row['status']
313
+ end_at = row['end_at']
314
+ details += f' Status: {status}, End time: {end_at}.'
315
+ except Exception as exc: # pylint: disable=broad-except
316
+ details += f' Error fetching task details: {exc}'
317
+ return details
318
+
319
+
181
320
  # job_duration is the time a job actually runs (including the
182
321
  # setup duration) before last_recover, excluding the provision
183
322
  # and recovery time.
@@ -191,42 +330,52 @@ def _init_db(func):
191
330
  # column names in the DB and it corresponds to the combined view
192
331
  # by joining the spot and job_info tables.
193
332
  def _get_jobs_dict(r: 'row.RowMapping') -> Dict[str, Any]:
333
+ # WARNING: If you update these you may also need to update GetJobTable in
334
+ # the skylet ManagedJobsServiceImpl.
194
335
  return {
195
- '_job_id': r['job_id'], # from spot table
196
- '_task_name': r['job_name'], # deprecated, from spot table
197
- 'resources': r['resources'],
198
- 'submitted_at': r['submitted_at'],
199
- 'status': r['status'],
200
- 'run_timestamp': r['run_timestamp'],
201
- 'start_at': r['start_at'],
202
- 'end_at': r['end_at'],
203
- 'last_recovered_at': r['last_recovered_at'],
204
- 'recovery_count': r['recovery_count'],
205
- 'job_duration': r['job_duration'],
206
- 'failure_reason': r['failure_reason'],
207
- 'job_id': r[spot_table.c.spot_job_id], # ambiguous, use table.column
208
- 'task_id': r['task_id'],
209
- 'task_name': r['task_name'],
210
- 'specs': r['specs'],
211
- 'local_log_file': r['local_log_file'],
212
- 'metadata': r['metadata'],
336
+ '_job_id': r.get('job_id'), # from spot table
337
+ '_task_name': r.get('job_name'), # deprecated, from spot table
338
+ 'resources': r.get('resources'),
339
+ 'submitted_at': r.get('submitted_at'),
340
+ 'status': r.get('status'),
341
+ 'run_timestamp': r.get('run_timestamp'),
342
+ 'start_at': r.get('start_at'),
343
+ 'end_at': r.get('end_at'),
344
+ 'last_recovered_at': r.get('last_recovered_at'),
345
+ 'recovery_count': r.get('recovery_count'),
346
+ 'job_duration': r.get('job_duration'),
347
+ 'failure_reason': r.get('failure_reason'),
348
+ 'job_id': r.get(spot_table.c.spot_job_id
349
+ ), # ambiguous, use table.column
350
+ 'task_id': r.get('task_id'),
351
+ 'task_name': r.get('task_name'),
352
+ 'specs': r.get('specs'),
353
+ 'local_log_file': r.get('local_log_file'),
354
+ 'metadata': r.get('metadata'),
213
355
  # columns from job_info table (some may be None for legacy jobs)
214
- '_job_info_job_id': r[job_info_table.c.spot_job_id
215
- ], # ambiguous, use table.column
216
- 'job_name': r['name'], # from job_info table
217
- 'schedule_state': r['schedule_state'],
218
- 'controller_pid': r['controller_pid'],
219
- 'dag_yaml_path': r['dag_yaml_path'],
220
- 'env_file_path': r['env_file_path'],
221
- 'user_hash': r['user_hash'],
222
- 'workspace': r['workspace'],
223
- 'priority': r['priority'],
224
- 'entrypoint': r['entrypoint'],
225
- 'original_user_yaml_path': r['original_user_yaml_path'],
226
- 'pool': r['pool'],
227
- 'current_cluster_name': r['current_cluster_name'],
228
- 'job_id_on_pool_cluster': r['job_id_on_pool_cluster'],
229
- 'pool_hash': r['pool_hash'],
356
+ '_job_info_job_id': r.get(job_info_table.c.spot_job_id
357
+ ), # ambiguous, use table.column
358
+ 'job_name': r.get('name'), # from job_info table
359
+ 'schedule_state': r.get('schedule_state'),
360
+ 'controller_pid': r.get('controller_pid'),
361
+ 'controller_pid_started_at': r.get('controller_pid_started_at'),
362
+ # the _path columns are for backwards compatibility, use the _content
363
+ # columns instead
364
+ 'dag_yaml_path': r.get('dag_yaml_path'),
365
+ 'env_file_path': r.get('env_file_path'),
366
+ 'dag_yaml_content': r.get('dag_yaml_content'),
367
+ 'env_file_content': r.get('env_file_content'),
368
+ 'config_file_content': r.get('config_file_content'),
369
+ 'user_hash': r.get('user_hash'),
370
+ 'workspace': r.get('workspace'),
371
+ 'priority': r.get('priority'),
372
+ 'entrypoint': r.get('entrypoint'),
373
+ 'original_user_yaml_path': r.get('original_user_yaml_path'),
374
+ 'original_user_yaml_content': r.get('original_user_yaml_content'),
375
+ 'pool': r.get('pool'),
376
+ 'current_cluster_name': r.get('current_cluster_name'),
377
+ 'job_id_on_pool_cluster': r.get('job_id_on_pool_cluster'),
378
+ 'pool_hash': r.get('pool_hash'),
230
379
  }
231
380
 
232
381
 
@@ -353,6 +502,75 @@ class ManagedJobStatus(enum.Enum):
353
502
  cls.RECOVERING,
354
503
  ]
355
504
 
505
+ @classmethod
506
+ def from_protobuf(
507
+ cls, protobuf_value: 'managed_jobsv1_pb2.ManagedJobStatus'
508
+ ) -> Optional['ManagedJobStatus']:
509
+ """Convert protobuf ManagedJobStatus enum to Python enum value."""
510
+ protobuf_to_enum = {
511
+ managed_jobsv1_pb2.MANAGED_JOB_STATUS_UNSPECIFIED: None,
512
+ managed_jobsv1_pb2.MANAGED_JOB_STATUS_PENDING: cls.PENDING,
513
+ managed_jobsv1_pb2.MANAGED_JOB_STATUS_SUBMITTED:
514
+ cls.DEPRECATED_SUBMITTED,
515
+ managed_jobsv1_pb2.MANAGED_JOB_STATUS_STARTING: cls.STARTING,
516
+ managed_jobsv1_pb2.MANAGED_JOB_STATUS_RUNNING: cls.RUNNING,
517
+ managed_jobsv1_pb2.MANAGED_JOB_STATUS_SUCCEEDED: cls.SUCCEEDED,
518
+ managed_jobsv1_pb2.MANAGED_JOB_STATUS_FAILED: cls.FAILED,
519
+ managed_jobsv1_pb2.MANAGED_JOB_STATUS_FAILED_CONTROLLER:
520
+ cls.FAILED_CONTROLLER,
521
+ managed_jobsv1_pb2.MANAGED_JOB_STATUS_FAILED_SETUP:
522
+ cls.FAILED_SETUP,
523
+ managed_jobsv1_pb2.MANAGED_JOB_STATUS_CANCELLED: cls.CANCELLED,
524
+ managed_jobsv1_pb2.MANAGED_JOB_STATUS_RECOVERING: cls.RECOVERING,
525
+ managed_jobsv1_pb2.MANAGED_JOB_STATUS_CANCELLING: cls.CANCELLING,
526
+ managed_jobsv1_pb2.MANAGED_JOB_STATUS_FAILED_PRECHECKS:
527
+ cls.FAILED_PRECHECKS,
528
+ managed_jobsv1_pb2.MANAGED_JOB_STATUS_FAILED_NO_RESOURCE:
529
+ cls.FAILED_NO_RESOURCE,
530
+ }
531
+
532
+ if protobuf_value not in protobuf_to_enum:
533
+ raise ValueError(
534
+ f'Unknown protobuf ManagedJobStatus value: {protobuf_value}')
535
+
536
+ return protobuf_to_enum[protobuf_value]
537
+
538
+ def to_protobuf(self) -> 'managed_jobsv1_pb2.ManagedJobStatus':
539
+ """Convert this Python enum value to protobuf enum value."""
540
+ enum_to_protobuf = {
541
+ ManagedJobStatus.PENDING:
542
+ managed_jobsv1_pb2.MANAGED_JOB_STATUS_PENDING,
543
+ ManagedJobStatus.DEPRECATED_SUBMITTED:
544
+ managed_jobsv1_pb2.MANAGED_JOB_STATUS_SUBMITTED,
545
+ ManagedJobStatus.STARTING:
546
+ managed_jobsv1_pb2.MANAGED_JOB_STATUS_STARTING,
547
+ ManagedJobStatus.RUNNING:
548
+ managed_jobsv1_pb2.MANAGED_JOB_STATUS_RUNNING,
549
+ ManagedJobStatus.SUCCEEDED:
550
+ managed_jobsv1_pb2.MANAGED_JOB_STATUS_SUCCEEDED,
551
+ ManagedJobStatus.FAILED:
552
+ managed_jobsv1_pb2.MANAGED_JOB_STATUS_FAILED,
553
+ ManagedJobStatus.FAILED_CONTROLLER:
554
+ managed_jobsv1_pb2.MANAGED_JOB_STATUS_FAILED_CONTROLLER,
555
+ ManagedJobStatus.FAILED_SETUP:
556
+ managed_jobsv1_pb2.MANAGED_JOB_STATUS_FAILED_SETUP,
557
+ ManagedJobStatus.CANCELLED:
558
+ managed_jobsv1_pb2.MANAGED_JOB_STATUS_CANCELLED,
559
+ ManagedJobStatus.RECOVERING:
560
+ managed_jobsv1_pb2.MANAGED_JOB_STATUS_RECOVERING,
561
+ ManagedJobStatus.CANCELLING:
562
+ managed_jobsv1_pb2.MANAGED_JOB_STATUS_CANCELLING,
563
+ ManagedJobStatus.FAILED_PRECHECKS:
564
+ managed_jobsv1_pb2.MANAGED_JOB_STATUS_FAILED_PRECHECKS,
565
+ ManagedJobStatus.FAILED_NO_RESOURCE:
566
+ managed_jobsv1_pb2.MANAGED_JOB_STATUS_FAILED_NO_RESOURCE,
567
+ }
568
+
569
+ if self not in enum_to_protobuf:
570
+ raise ValueError(f'Unknown ManagedJobStatus value: {self}')
571
+
572
+ return enum_to_protobuf[self]
573
+
356
574
 
357
575
  _SPOT_STATUS_TO_COLOR = {
358
576
  ManagedJobStatus.PENDING: colorama.Fore.BLUE,
@@ -375,8 +593,6 @@ _SPOT_STATUS_TO_COLOR = {
375
593
  class ManagedJobScheduleState(enum.Enum):
376
594
  """Captures the state of the job from the scheduler's perspective.
377
595
 
378
- A job that predates the introduction of the scheduler will be INVALID.
379
-
380
596
  A newly created job will be INACTIVE. The following transitions are valid:
381
597
  - INACTIVE -> WAITING: The job is "submitted" to the scheduler, and its job
382
598
  controller can be started.
@@ -413,10 +629,10 @@ class ManagedJobScheduleState(enum.Enum):
413
629
  briefly observe inconsistent states, like a job that just finished but
414
630
  hasn't yet transitioned to DONE.
415
631
  """
416
- # This job may have been created before scheduler was introduced in #4458.
417
- # This state is not used by scheduler but just for backward compatibility.
418
- # TODO(cooperc): remove this in v0.11.0
419
- INVALID = None
632
+ # TODO(luca): the only states we need are INACTIVE, WAITING, ALIVE, and
633
+ # DONE. ALIVE = old LAUNCHING + ALIVE + ALIVE_BACKOFF + ALIVE_WAITING and
634
+ # will represent jobs that are claimed by a controller. Delete the rest
635
+ # in v0.13.0
420
636
  # The job should be ignored by the scheduler.
421
637
  INACTIVE = 'INACTIVE'
422
638
  # The job is waiting to transition to LAUNCHING for the first time. The
@@ -438,38 +654,72 @@ class ManagedJobScheduleState(enum.Enum):
438
654
  # The job is in a terminal state. (Not necessarily SUCCEEDED.)
439
655
  DONE = 'DONE'
440
656
 
657
+ @classmethod
658
+ def from_protobuf(
659
+ cls, protobuf_value: 'managed_jobsv1_pb2.ManagedJobScheduleState'
660
+ ) -> Optional['ManagedJobScheduleState']:
661
+ """Convert protobuf ManagedJobScheduleState enum to Python enum value.
662
+ """
663
+ protobuf_to_enum = {
664
+ managed_jobsv1_pb2.MANAGED_JOB_SCHEDULE_STATE_UNSPECIFIED: None,
665
+ # TODO(cooperc): remove this in v0.13.0. See #8105.
666
+ managed_jobsv1_pb2.DEPRECATED_MANAGED_JOB_SCHEDULE_STATE_INVALID:
667
+ (None),
668
+ managed_jobsv1_pb2.MANAGED_JOB_SCHEDULE_STATE_INACTIVE:
669
+ cls.INACTIVE,
670
+ managed_jobsv1_pb2.MANAGED_JOB_SCHEDULE_STATE_WAITING: cls.WAITING,
671
+ managed_jobsv1_pb2.MANAGED_JOB_SCHEDULE_STATE_ALIVE_WAITING:
672
+ cls.ALIVE_WAITING,
673
+ managed_jobsv1_pb2.MANAGED_JOB_SCHEDULE_STATE_LAUNCHING:
674
+ cls.LAUNCHING,
675
+ managed_jobsv1_pb2.MANAGED_JOB_SCHEDULE_STATE_ALIVE_BACKOFF:
676
+ cls.ALIVE_BACKOFF,
677
+ managed_jobsv1_pb2.MANAGED_JOB_SCHEDULE_STATE_ALIVE: cls.ALIVE,
678
+ managed_jobsv1_pb2.MANAGED_JOB_SCHEDULE_STATE_DONE: cls.DONE,
679
+ }
441
680
 
442
- # === Status transition functions ===
443
- @_init_db
444
- def set_job_info(job_id: int, name: str, workspace: str, entrypoint: str,
445
- pool: Optional[str], pool_hash: Optional[str]):
446
- assert _SQLALCHEMY_ENGINE is not None
447
- with orm.Session(_SQLALCHEMY_ENGINE) as session:
448
- if (_SQLALCHEMY_ENGINE.dialect.name ==
449
- db_utils.SQLAlchemyDialect.SQLITE.value):
450
- insert_func = sqlite.insert
451
- elif (_SQLALCHEMY_ENGINE.dialect.name ==
452
- db_utils.SQLAlchemyDialect.POSTGRESQL.value):
453
- insert_func = postgresql.insert
454
- else:
455
- raise ValueError('Unsupported database dialect')
456
- insert_stmt = insert_func(job_info_table).values(
457
- spot_job_id=job_id,
458
- name=name,
459
- schedule_state=ManagedJobScheduleState.INACTIVE.value,
460
- workspace=workspace,
461
- entrypoint=entrypoint,
462
- pool=pool,
463
- pool_hash=pool_hash,
464
- )
465
- session.execute(insert_stmt)
466
- session.commit()
681
+ if protobuf_value not in protobuf_to_enum:
682
+ raise ValueError('Unknown protobuf ManagedJobScheduleState value: '
683
+ f'{protobuf_value}')
684
+
685
+ return protobuf_to_enum[protobuf_value]
686
+
687
+ def to_protobuf(self) -> 'managed_jobsv1_pb2.ManagedJobScheduleState':
688
+ """Convert this Python enum value to protobuf enum value."""
689
+ enum_to_protobuf = {
690
+ ManagedJobScheduleState.INACTIVE:
691
+ managed_jobsv1_pb2.MANAGED_JOB_SCHEDULE_STATE_INACTIVE,
692
+ ManagedJobScheduleState.WAITING:
693
+ managed_jobsv1_pb2.MANAGED_JOB_SCHEDULE_STATE_WAITING,
694
+ ManagedJobScheduleState.ALIVE_WAITING:
695
+ managed_jobsv1_pb2.MANAGED_JOB_SCHEDULE_STATE_ALIVE_WAITING,
696
+ ManagedJobScheduleState.LAUNCHING:
697
+ managed_jobsv1_pb2.MANAGED_JOB_SCHEDULE_STATE_LAUNCHING,
698
+ ManagedJobScheduleState.ALIVE_BACKOFF:
699
+ managed_jobsv1_pb2.MANAGED_JOB_SCHEDULE_STATE_ALIVE_BACKOFF,
700
+ ManagedJobScheduleState.ALIVE:
701
+ managed_jobsv1_pb2.MANAGED_JOB_SCHEDULE_STATE_ALIVE,
702
+ ManagedJobScheduleState.DONE:
703
+ managed_jobsv1_pb2.MANAGED_JOB_SCHEDULE_STATE_DONE,
704
+ }
705
+
706
+ if self not in enum_to_protobuf:
707
+ raise ValueError(f'Unknown ManagedJobScheduleState value: {self}')
467
708
 
709
+ return enum_to_protobuf[self]
468
710
 
711
+
712
+ ControllerPidRecord = collections.namedtuple('ControllerPidRecord', [
713
+ 'pid',
714
+ 'started_at',
715
+ ])
716
+
717
+
718
+ # === Status transition functions ===
469
719
  @_init_db
470
720
  def set_job_info_without_job_id(name: str, workspace: str, entrypoint: str,
471
- pool: Optional[str],
472
- pool_hash: Optional[str]) -> int:
721
+ pool: Optional[str], pool_hash: Optional[str],
722
+ user_hash: Optional[str]) -> int:
473
723
  assert _SQLALCHEMY_ENGINE is not None
474
724
  with orm.Session(_SQLALCHEMY_ENGINE) as session:
475
725
  if (_SQLALCHEMY_ENGINE.dialect.name ==
@@ -488,6 +738,7 @@ def set_job_info_without_job_id(name: str, workspace: str, entrypoint: str,
488
738
  entrypoint=entrypoint,
489
739
  pool=pool,
490
740
  pool_hash=pool_hash,
741
+ user_hash=user_hash,
491
742
  )
492
743
 
493
744
  if (_SQLALCHEMY_ENGINE.dialect.name ==
@@ -517,6 +768,7 @@ def set_pending(
517
768
  ):
518
769
  """Set the task to pending state."""
519
770
  assert _SQLALCHEMY_ENGINE is not None
771
+
520
772
  with orm.Session(_SQLALCHEMY_ENGINE) as session:
521
773
  session.execute(
522
774
  sqlalchemy.insert(spot_table).values(
@@ -530,85 +782,41 @@ def set_pending(
530
782
  session.commit()
531
783
 
532
784
 
533
- @_init_db
534
- def set_starting(job_id: int, task_id: int, run_timestamp: str,
535
- submit_time: float, resources_str: str,
536
- specs: Dict[str, Union[str,
537
- int]], callback_func: CallbackType):
538
- """Set the task to starting state.
539
-
540
- Args:
541
- job_id: The managed job ID.
542
- task_id: The task ID.
543
- run_timestamp: The run_timestamp of the run. This will be used to
544
- determine the log directory of the managed task.
545
- submit_time: The time when the managed task is submitted.
546
- resources_str: The resources string of the managed task.
547
- specs: The specs of the managed task.
548
- callback_func: The callback function.
549
- """
550
- assert _SQLALCHEMY_ENGINE is not None
551
- # Use the timestamp in the `run_timestamp` ('sky-2022-10...'), to make
552
- # the log directory and submission time align with each other, so as to
553
- # make it easier to find them based on one of the values.
554
- # Also, using the earlier timestamp should be closer to the term
555
- # `submit_at`, which represents the time the managed task is submitted.
556
- logger.info('Launching the spot cluster...')
557
- with orm.Session(_SQLALCHEMY_ENGINE) as session:
558
- count = session.query(spot_table).filter(
559
- sqlalchemy.and_(
560
- spot_table.c.spot_job_id == job_id,
561
- spot_table.c.task_id == task_id,
562
- spot_table.c.status == ManagedJobStatus.PENDING.value,
563
- spot_table.c.end_at.is_(None),
564
- )).update({
565
- spot_table.c.resources: resources_str,
566
- spot_table.c.submitted_at: submit_time,
567
- spot_table.c.status: ManagedJobStatus.STARTING.value,
568
- spot_table.c.run_timestamp: run_timestamp,
569
- spot_table.c.specs: json.dumps(specs),
570
- })
571
- session.commit()
572
- if count != 1:
573
- raise exceptions.ManagedJobStatusError(
574
- 'Failed to set the task to starting. '
575
- f'({count} rows updated)')
576
- # SUBMITTED is no longer used, but we keep it for backward compatibility.
577
- # TODO(cooperc): remove this in v0.12.0
578
- callback_func('SUBMITTED')
579
- callback_func('STARTING')
580
-
581
-
582
- @_init_db
583
- def set_backoff_pending(job_id: int, task_id: int):
785
+ @_init_db_async
786
+ async def set_backoff_pending_async(job_id: int, task_id: int):
584
787
  """Set the task to PENDING state if it is in backoff.
585
788
 
586
789
  This should only be used to transition from STARTING or RECOVERING back to
587
790
  PENDING.
588
791
  """
589
- assert _SQLALCHEMY_ENGINE is not None
590
- with orm.Session(_SQLALCHEMY_ENGINE) as session:
591
- count = session.query(spot_table).filter(
592
- sqlalchemy.and_(
593
- spot_table.c.spot_job_id == job_id,
594
- spot_table.c.task_id == task_id,
595
- spot_table.c.status.in_([
596
- ManagedJobStatus.STARTING.value,
597
- ManagedJobStatus.RECOVERING.value
598
- ]),
599
- spot_table.c.end_at.is_(None),
600
- )).update({spot_table.c.status: ManagedJobStatus.PENDING.value})
601
- session.commit()
602
- logger.debug('back to PENDING')
792
+ assert _SQLALCHEMY_ENGINE_ASYNC is not None
793
+ async with sql_async.AsyncSession(_SQLALCHEMY_ENGINE_ASYNC) as session:
794
+ result = await session.execute(
795
+ sqlalchemy.update(spot_table).where(
796
+ sqlalchemy.and_(
797
+ spot_table.c.spot_job_id == job_id,
798
+ spot_table.c.task_id == task_id,
799
+ spot_table.c.status.in_([
800
+ ManagedJobStatus.STARTING.value,
801
+ ManagedJobStatus.RECOVERING.value
802
+ ]),
803
+ spot_table.c.end_at.is_(None),
804
+ )).values({spot_table.c.status: ManagedJobStatus.PENDING.value})
805
+ )
806
+ count = result.rowcount
807
+ await session.commit()
603
808
  if count != 1:
604
- raise exceptions.ManagedJobStatusError(
605
- 'Failed to set the task back to pending. '
606
- f'({count} rows updated)')
809
+ details = await _describe_task_transition_failure(
810
+ session, job_id, task_id)
811
+ message = ('Failed to set the task back to pending. '
812
+ f'({count} rows updated. {details})')
813
+ logger.error(message)
814
+ raise exceptions.ManagedJobStatusError(message)
607
815
  # Do not call callback_func here, as we don't use the callback for PENDING.
608
816
 
609
817
 
610
818
  @_init_db
611
- def set_restarting(job_id: int, task_id: int, recovering: bool):
819
+ async def set_restarting_async(job_id: int, task_id: int, recovering: bool):
612
820
  """Set the task back to STARTING or RECOVERING from PENDING.
613
821
 
614
822
  This should not be used for the initial transition from PENDING to STARTING.
@@ -616,159 +824,32 @@ def set_restarting(job_id: int, task_id: int, recovering: bool):
616
824
  after using set_backoff_pending to transition back to PENDING during
617
825
  launch retry backoff.
618
826
  """
619
- assert _SQLALCHEMY_ENGINE is not None
827
+ assert _SQLALCHEMY_ENGINE_ASYNC is not None
620
828
  target_status = ManagedJobStatus.STARTING.value
621
829
  if recovering:
622
830
  target_status = ManagedJobStatus.RECOVERING.value
623
- with orm.Session(_SQLALCHEMY_ENGINE) as session:
624
- count = session.query(spot_table).filter(
625
- sqlalchemy.and_(
626
- spot_table.c.spot_job_id == job_id,
627
- spot_table.c.task_id == task_id,
628
- spot_table.c.status == ManagedJobStatus.PENDING.value,
629
- spot_table.c.end_at.is_(None),
630
- )).update({spot_table.c.status: target_status})
631
- session.commit()
831
+ async with sql_async.AsyncSession(_SQLALCHEMY_ENGINE_ASYNC) as session:
832
+ result = await session.execute(
833
+ sqlalchemy.update(spot_table).where(
834
+ sqlalchemy.and_(
835
+ spot_table.c.spot_job_id == job_id,
836
+ spot_table.c.task_id == task_id,
837
+ spot_table.c.end_at.is_(None),
838
+ )).values({spot_table.c.status: target_status}))
839
+ count = result.rowcount
840
+ await session.commit()
632
841
  logger.debug(f'back to {target_status}')
633
842
  if count != 1:
634
- raise exceptions.ManagedJobStatusError(
635
- f'Failed to set the task back to {target_status}. '
636
- f'({count} rows updated)')
843
+ details = await _describe_task_transition_failure(
844
+ session, job_id, task_id)
845
+ message = (f'Failed to set the task back to {target_status}. '
846
+ f'({count} rows updated. {details})')
847
+ logger.error(message)
848
+ raise exceptions.ManagedJobStatusError(message)
637
849
  # Do not call callback_func here, as it should only be invoked for the
638
850
  # initial (pre-`set_backoff_pending`) transition to STARTING or RECOVERING.
639
851
 
640
852
 
641
- @_init_db
642
- def set_started(job_id: int, task_id: int, start_time: float,
643
- callback_func: CallbackType):
644
- """Set the task to started state."""
645
- assert _SQLALCHEMY_ENGINE is not None
646
- logger.info('Job started.')
647
- with orm.Session(_SQLALCHEMY_ENGINE) as session:
648
- count = session.query(spot_table).filter(
649
- sqlalchemy.and_(
650
- spot_table.c.spot_job_id == job_id,
651
- spot_table.c.task_id == task_id,
652
- spot_table.c.status.in_([
653
- ManagedJobStatus.STARTING.value,
654
- # If the task is empty, we will jump straight
655
- # from PENDING to RUNNING
656
- ManagedJobStatus.PENDING.value
657
- ]),
658
- spot_table.c.end_at.is_(None),
659
- )).update({
660
- spot_table.c.status: ManagedJobStatus.RUNNING.value,
661
- spot_table.c.start_at: start_time,
662
- spot_table.c.last_recovered_at: start_time,
663
- })
664
- session.commit()
665
- if count != 1:
666
- raise exceptions.ManagedJobStatusError(
667
- f'Failed to set the task to started. '
668
- f'({count} rows updated)')
669
- callback_func('STARTED')
670
-
671
-
672
- @_init_db
673
- def set_recovering(job_id: int, task_id: int, force_transit_to_recovering: bool,
674
- callback_func: CallbackType):
675
- """Set the task to recovering state, and update the job duration."""
676
- assert _SQLALCHEMY_ENGINE is not None
677
- logger.info('=== Recovering... ===')
678
- # NOTE: if we are resuming from a controller failure and the previous status
679
- # is STARTING, the initial value of `last_recovered_at` might not be set
680
- # yet (default value -1). In this case, we should not add current timestamp.
681
- # Otherwise, the job duration will be incorrect (~55 years from 1970).
682
- current_time = time.time()
683
-
684
- with orm.Session(_SQLALCHEMY_ENGINE) as session:
685
- if force_transit_to_recovering:
686
- # For the HA job controller, it is possible that the jobs came from
687
- # any processing status to recovering. But it should not be any
688
- # terminal status as such jobs will not be recovered; and it should
689
- # not be CANCELLING as we will directly trigger a cleanup.
690
- status_condition = spot_table.c.status.in_(
691
- [s.value for s in ManagedJobStatus.processing_statuses()])
692
- else:
693
- status_condition = (
694
- spot_table.c.status == ManagedJobStatus.RUNNING.value)
695
-
696
- count = session.query(spot_table).filter(
697
- sqlalchemy.and_(
698
- spot_table.c.spot_job_id == job_id,
699
- spot_table.c.task_id == task_id,
700
- status_condition,
701
- spot_table.c.end_at.is_(None),
702
- )).update({
703
- spot_table.c.status: ManagedJobStatus.RECOVERING.value,
704
- spot_table.c.job_duration: sqlalchemy.case(
705
- (spot_table.c.last_recovered_at >= 0,
706
- spot_table.c.job_duration + current_time -
707
- spot_table.c.last_recovered_at),
708
- else_=spot_table.c.job_duration),
709
- spot_table.c.last_recovered_at: sqlalchemy.case(
710
- (spot_table.c.last_recovered_at < 0, current_time),
711
- else_=spot_table.c.last_recovered_at),
712
- })
713
- session.commit()
714
- if count != 1:
715
- raise exceptions.ManagedJobStatusError(
716
- f'Failed to set the task to recovering. '
717
- f'({count} rows updated)')
718
- callback_func('RECOVERING')
719
-
720
-
721
- @_init_db
722
- def set_recovered(job_id: int, task_id: int, recovered_time: float,
723
- callback_func: CallbackType):
724
- """Set the task to recovered."""
725
- assert _SQLALCHEMY_ENGINE is not None
726
- with orm.Session(_SQLALCHEMY_ENGINE) as session:
727
- count = session.query(spot_table).filter(
728
- sqlalchemy.and_(
729
- spot_table.c.spot_job_id == job_id,
730
- spot_table.c.task_id == task_id,
731
- spot_table.c.status == ManagedJobStatus.RECOVERING.value,
732
- spot_table.c.end_at.is_(None),
733
- )).update({
734
- spot_table.c.status: ManagedJobStatus.RUNNING.value,
735
- spot_table.c.last_recovered_at: recovered_time,
736
- spot_table.c.recovery_count: spot_table.c.recovery_count + 1,
737
- })
738
- session.commit()
739
- if count != 1:
740
- raise exceptions.ManagedJobStatusError(
741
- f'Failed to set the task to recovered. '
742
- f'({count} rows updated)')
743
- logger.info('==== Recovered. ====')
744
- callback_func('RECOVERED')
745
-
746
-
747
- @_init_db
748
- def set_succeeded(job_id: int, task_id: int, end_time: float,
749
- callback_func: CallbackType):
750
- """Set the task to succeeded, if it is in a non-terminal state."""
751
- assert _SQLALCHEMY_ENGINE is not None
752
- with orm.Session(_SQLALCHEMY_ENGINE) as session:
753
- count = session.query(spot_table).filter(
754
- sqlalchemy.and_(
755
- spot_table.c.spot_job_id == job_id,
756
- spot_table.c.task_id == task_id,
757
- spot_table.c.status == ManagedJobStatus.RUNNING.value,
758
- spot_table.c.end_at.is_(None),
759
- )).update({
760
- spot_table.c.status: ManagedJobStatus.SUCCEEDED.value,
761
- spot_table.c.end_at: end_time,
762
- })
763
- session.commit()
764
- if count != 1:
765
- raise exceptions.ManagedJobStatusError(
766
- f'Failed to set the task to succeeded. '
767
- f'({count} rows updated)')
768
- callback_func('SUCCEEDED')
769
- logger.info('Job succeeded.')
770
-
771
-
772
853
  @_init_db
773
854
  def set_failed(
774
855
  job_id: int,
@@ -816,7 +897,19 @@ def set_failed(
816
897
  where_conditions = [spot_table.c.spot_job_id == job_id]
817
898
  if task_id is not None:
818
899
  where_conditions.append(spot_table.c.task_id == task_id)
900
+
901
+ # Handle failure_reason prepending when override_terminal is True
819
902
  if override_terminal:
903
+ # Get existing failure_reason with row lock to prevent race
904
+ # conditions
905
+ existing_reason_result = session.execute(
906
+ sqlalchemy.select(spot_table.c.failure_reason).where(
907
+ sqlalchemy.and_(*where_conditions)).with_for_update())
908
+ existing_reason_row = existing_reason_result.fetchone()
909
+ if existing_reason_row and existing_reason_row[0]:
910
+ # Prepend new failure reason to existing one
911
+ fields_to_set[spot_table.c.failure_reason] = (
912
+ failure_reason + '. Previously: ' + existing_reason_row[0])
820
913
  # Use COALESCE for end_at to avoid overriding the existing end_at if
821
914
  # it's already set.
822
915
  fields_to_set[spot_table.c.end_at] = sqlalchemy.func.coalesce(
@@ -834,51 +927,42 @@ def set_failed(
834
927
 
835
928
 
836
929
  @_init_db
837
- def set_cancelling(job_id: int, callback_func: CallbackType):
838
- """Set tasks in the job as cancelling, if they are in non-terminal states.
839
-
840
- task_id is not needed, because we expect the job should be cancelled
841
- as a whole, and we should not cancel a single task.
842
- """
843
- assert _SQLALCHEMY_ENGINE is not None
844
- with orm.Session(_SQLALCHEMY_ENGINE) as session:
845
- count = session.query(spot_table).filter(
846
- sqlalchemy.and_(
847
- spot_table.c.spot_job_id == job_id,
848
- spot_table.c.end_at.is_(None),
849
- )).update({spot_table.c.status: ManagedJobStatus.CANCELLING.value})
850
- session.commit()
851
- updated = count > 0
852
- if updated:
853
- logger.info('Cancelling the job...')
854
- callback_func('CANCELLING')
855
- else:
856
- logger.info('Cancellation skipped, job is already terminal')
857
-
930
+ def set_pending_cancelled(job_id: int):
931
+ """Set the job as cancelled, if it is PENDING and WAITING/INACTIVE.
858
932
 
859
- @_init_db
860
- def set_cancelled(job_id: int, callback_func: CallbackType):
861
- """Set tasks in the job as cancelled, if they are in CANCELLING state.
933
+ This may fail if the job is not PENDING, e.g. another process has changed
934
+ its state in the meantime.
862
935
 
863
- The set_cancelling should be called before this function.
936
+ Returns:
937
+ True if the job was cancelled, False otherwise.
864
938
  """
865
939
  assert _SQLALCHEMY_ENGINE is not None
866
940
  with orm.Session(_SQLALCHEMY_ENGINE) as session:
941
+ # Subquery to get the spot_job_ids that match the joined condition
942
+ subquery = session.query(spot_table.c.job_id).join(
943
+ job_info_table,
944
+ spot_table.c.spot_job_id == job_info_table.c.spot_job_id
945
+ ).filter(
946
+ spot_table.c.spot_job_id == job_id,
947
+ spot_table.c.status == ManagedJobStatus.PENDING.value,
948
+ # Note: it's possible that a WAITING job actually needs to be
949
+ # cleaned up, if we are in the middle of an upgrade/recovery and
950
+ # the job is waiting to be reclaimed by a new controller. But,
951
+ # in this case the status will not be PENDING.
952
+ sqlalchemy.or_(
953
+ job_info_table.c.schedule_state ==
954
+ ManagedJobScheduleState.WAITING.value,
955
+ job_info_table.c.schedule_state ==
956
+ ManagedJobScheduleState.INACTIVE.value,
957
+ ),
958
+ ).subquery()
959
+
867
960
  count = session.query(spot_table).filter(
868
- sqlalchemy.and_(
869
- spot_table.c.spot_job_id == job_id,
870
- spot_table.c.status == ManagedJobStatus.CANCELLING.value,
871
- )).update({
872
- spot_table.c.status: ManagedJobStatus.CANCELLED.value,
873
- spot_table.c.end_at: time.time(),
874
- })
961
+ spot_table.c.job_id.in_(subquery)).update(
962
+ {spot_table.c.status: ManagedJobStatus.CANCELLED.value},
963
+ synchronize_session=False)
875
964
  session.commit()
876
- updated = count > 0
877
- if updated:
878
- logger.info('Job cancelled.')
879
- callback_func('CANCELLED')
880
- else:
881
- logger.info('Cancellation skipped, job is not CANCELLING')
965
+ return count > 0
882
966
 
883
967
 
884
968
  @_init_db
@@ -899,8 +983,14 @@ def set_local_log_file(job_id: int, task_id: Optional[int],
899
983
  # ======== utility functions ========
900
984
  @_init_db
901
985
  def get_nonterminal_job_ids_by_name(name: Optional[str],
986
+ user_hash: Optional[str] = None,
902
987
  all_users: bool = False) -> List[int]:
903
- """Get non-terminal job ids by name."""
988
+ """Get non-terminal job ids by name.
989
+
990
+ If name is None:
991
+ 1. if all_users is False, get for the given user_hash
992
+ 2. otherwise, get for all users
993
+ """
904
994
  assert _SQLALCHEMY_ENGINE is not None
905
995
 
906
996
  with orm.Session(_SQLALCHEMY_ENGINE) as session:
@@ -917,8 +1007,15 @@ def get_nonterminal_job_ids_by_name(name: Optional[str],
917
1007
  ])
918
1008
  ]
919
1009
  if name is None and not all_users:
920
- where_conditions.append(
921
- job_info_table.c.user_hash == common_utils.get_user_hash())
1010
+ if user_hash is None:
1011
+ # For backwards compatibility. With codegen, USER_ID_ENV_VAR
1012
+ # was set to the correct value by the jobs controller, as
1013
+ # part of ManagedJobCodeGen._build(). This is no longer the
1014
+ # case for the Skylet gRPC server, which is why we need to
1015
+ # pass it explicitly through the request body.
1016
+ logger.debug('user_hash is None, using current user hash')
1017
+ user_hash = common_utils.get_user_hash()
1018
+ where_conditions.append(job_info_table.c.user_hash == user_hash)
922
1019
  if name is not None:
923
1020
  # We match the job name from `job_info` for the jobs submitted after
924
1021
  # #1982, and from `spot` for the jobs submitted before #1982, whose
@@ -936,45 +1033,6 @@ def get_nonterminal_job_ids_by_name(name: Optional[str],
936
1033
  return job_ids
937
1034
 
938
1035
 
939
- @_init_db
940
- def get_schedule_live_jobs(job_id: Optional[int]) -> List[Dict[str, Any]]:
941
- """Get jobs from the database that have a live schedule_state.
942
-
943
- This should return job(s) that are not INACTIVE, WAITING, or DONE. So a
944
- returned job should correspond to a live job controller process, with one
945
- exception: the job may have just transitioned from WAITING to LAUNCHING, but
946
- the controller process has not yet started.
947
- """
948
- assert _SQLALCHEMY_ENGINE is not None
949
-
950
- with orm.Session(_SQLALCHEMY_ENGINE) as session:
951
- query = sqlalchemy.select(
952
- job_info_table.c.spot_job_id,
953
- job_info_table.c.schedule_state,
954
- job_info_table.c.controller_pid,
955
- ).where(~job_info_table.c.schedule_state.in_([
956
- ManagedJobScheduleState.INACTIVE.value,
957
- ManagedJobScheduleState.WAITING.value,
958
- ManagedJobScheduleState.DONE.value,
959
- ]))
960
-
961
- if job_id is not None:
962
- query = query.where(job_info_table.c.spot_job_id == job_id)
963
-
964
- query = query.order_by(job_info_table.c.spot_job_id.desc())
965
-
966
- rows = session.execute(query).fetchall()
967
- jobs = []
968
- for row in rows:
969
- job_dict = {
970
- 'job_id': row[0],
971
- 'schedule_state': ManagedJobScheduleState(row[1]),
972
- 'controller_pid': row[2],
973
- }
974
- jobs.append(job_dict)
975
- return jobs
976
-
977
-
978
1036
  @_init_db
979
1037
  def get_jobs_to_check_status(job_id: Optional[int] = None) -> List[int]:
980
1038
  """Get jobs that need controller process checking.
@@ -1035,32 +1093,6 @@ def get_jobs_to_check_status(job_id: Optional[int] = None) -> List[int]:
1035
1093
  return [row[0] for row in rows if row[0] is not None]
1036
1094
 
1037
1095
 
1038
- @_init_db
1039
- def get_all_job_ids_by_name(name: Optional[str]) -> List[int]:
1040
- """Get all job ids by name."""
1041
- assert _SQLALCHEMY_ENGINE is not None
1042
-
1043
- with orm.Session(_SQLALCHEMY_ENGINE) as session:
1044
- query = sqlalchemy.select(
1045
- spot_table.c.spot_job_id.distinct()).select_from(
1046
- spot_table.outerjoin(
1047
- job_info_table,
1048
- spot_table.c.spot_job_id == job_info_table.c.spot_job_id))
1049
- if name is not None:
1050
- # We match the job name from `job_info` for the jobs submitted after
1051
- # #1982, and from `spot` for the jobs submitted before #1982, whose
1052
- # job_info is not available.
1053
- name_condition = sqlalchemy.or_(
1054
- job_info_table.c.name == name,
1055
- sqlalchemy.and_(job_info_table.c.name.is_(None),
1056
- spot_table.c.task_name == name))
1057
- query = query.where(name_condition)
1058
- query = query.order_by(spot_table.c.spot_job_id.desc())
1059
- rows = session.execute(query).fetchall()
1060
- job_ids = [row[0] for row in rows if row[0] is not None]
1061
- return job_ids
1062
-
1063
-
1064
1096
  @_init_db
1065
1097
  def _get_all_task_ids_statuses(
1066
1098
  job_id: int) -> List[Tuple[int, ManagedJobStatus]]:
@@ -1077,7 +1109,8 @@ def _get_all_task_ids_statuses(
1077
1109
 
1078
1110
  @_init_db
1079
1111
  def get_all_task_ids_names_statuses_logs(
1080
- job_id: int) -> List[Tuple[int, str, ManagedJobStatus, str]]:
1112
+ job_id: int
1113
+ ) -> List[Tuple[int, str, ManagedJobStatus, str, Optional[float]]]:
1081
1114
  assert _SQLALCHEMY_ENGINE is not None
1082
1115
  with orm.Session(_SQLALCHEMY_ENGINE) as session:
1083
1116
  id_names = session.execute(
@@ -1086,24 +1119,13 @@ def get_all_task_ids_names_statuses_logs(
1086
1119
  spot_table.c.task_name,
1087
1120
  spot_table.c.status,
1088
1121
  spot_table.c.local_log_file,
1122
+ spot_table.c.logs_cleaned_at,
1089
1123
  ).where(spot_table.c.spot_job_id == job_id).order_by(
1090
1124
  spot_table.c.task_id.asc())).fetchall()
1091
- return [(row[0], row[1], ManagedJobStatus(row[2]), row[3])
1125
+ return [(row[0], row[1], ManagedJobStatus(row[2]), row[3], row[4])
1092
1126
  for row in id_names]
1093
1127
 
1094
1128
 
1095
- @_init_db
1096
- def get_job_status_with_task_id(job_id: int,
1097
- task_id: int) -> Optional[ManagedJobStatus]:
1098
- assert _SQLALCHEMY_ENGINE is not None
1099
- with orm.Session(_SQLALCHEMY_ENGINE) as session:
1100
- status = session.execute(
1101
- sqlalchemy.select(spot_table.c.status).where(
1102
- sqlalchemy.and_(spot_table.c.spot_job_id == job_id,
1103
- spot_table.c.task_id == task_id))).fetchone()
1104
- return ManagedJobStatus(status[0]) if status else None
1105
-
1106
-
1107
1129
  def get_num_tasks(job_id: int) -> int:
1108
1130
  return len(_get_all_task_ids_statuses(job_id))
1109
1131
 
@@ -1131,13 +1153,68 @@ def get_latest_task_id_status(
1131
1153
  return task_id, status
1132
1154
 
1133
1155
 
1134
- def get_status(job_id: int) -> Optional[ManagedJobStatus]:
1135
- _, status = get_latest_task_id_status(job_id)
1136
- return status
1137
-
1138
-
1139
1156
  @_init_db
1140
- def get_failure_reason(job_id: int) -> Optional[str]:
1157
+ def get_job_controller_process(job_id: int) -> Optional[ControllerPidRecord]:
1158
+ assert _SQLALCHEMY_ENGINE is not None
1159
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
1160
+ row = session.execute(
1161
+ sqlalchemy.select(
1162
+ job_info_table.c.controller_pid,
1163
+ job_info_table.c.controller_pid_started_at).where(
1164
+ job_info_table.c.spot_job_id == job_id)).fetchone()
1165
+ if row is None or row[0] is None:
1166
+ return None
1167
+ pid = row[0]
1168
+ if pid < 0:
1169
+ # Between #7051 and #7847, the controller pid was negative to
1170
+ # indicate a controller process that can handle multiple jobs.
1171
+ pid = -pid
1172
+ return ControllerPidRecord(pid=pid, started_at=row[1])
1173
+
1174
+
1175
+ @_init_db
1176
+ def is_legacy_controller_process(job_id: int) -> bool:
1177
+ """Check if the controller process is a legacy single-job controller process
1178
+
1179
+ After #7051, the controller process pid is negative to indicate a new
1180
+ multi-job controller process.
1181
+ After #7847, the controller process pid is changed back to positive, but
1182
+ controller_pid_started_at will also be set.
1183
+ """
1184
+ # TODO(cooperc): Remove this function for 0.13.0
1185
+ assert _SQLALCHEMY_ENGINE is not None
1186
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
1187
+ row = session.execute(
1188
+ sqlalchemy.select(
1189
+ job_info_table.c.controller_pid,
1190
+ job_info_table.c.controller_pid_started_at).where(
1191
+ job_info_table.c.spot_job_id == job_id)).fetchone()
1192
+ if row is None:
1193
+ raise ValueError(f'Job {job_id} not found')
1194
+ if row[0] is None:
1195
+ # Job is from before #4485, so controller_pid is not set
1196
+ # This is a legacy single-job controller process (running in ray!)
1197
+ return True
1198
+ started_at = row[1]
1199
+ if started_at is not None:
1200
+ # controller_pid_started_at is only set after #7847, so we know this
1201
+ # must be a non-legacy multi-job controller process.
1202
+ return False
1203
+ pid = row[0]
1204
+ if pid < 0:
1205
+ # Between #7051 and #7847, the controller pid was negative to
1206
+ # indicate a non-legacy multi-job controller process.
1207
+ return False
1208
+ return True
1209
+
1210
+
1211
+ def get_status(job_id: int) -> Optional[ManagedJobStatus]:
1212
+ _, status = get_latest_task_id_status(job_id)
1213
+ return status
1214
+
1215
+
1216
+ @_init_db
1217
+ def get_failure_reason(job_id: int) -> Optional[str]:
1141
1218
  """Get the failure reason of a job.
1142
1219
 
1143
1220
  If the job has multiple tasks, we return the first failure reason.
@@ -1155,8 +1232,8 @@ def get_failure_reason(job_id: int) -> Optional[str]:
1155
1232
 
1156
1233
 
1157
1234
  @_init_db
1158
- def get_managed_jobs(job_id: Optional[int] = None) -> List[Dict[str, Any]]:
1159
- """Get managed jobs from the database."""
1235
+ def get_managed_job_tasks(job_id: int) -> List[Dict[str, Any]]:
1236
+ """Get managed job tasks for a specific managed job id from the database."""
1160
1237
  assert _SQLALCHEMY_ENGINE is not None
1161
1238
 
1162
1239
  # Join spot and job_info tables to get the job name for each task.
@@ -1171,10 +1248,8 @@ def get_managed_jobs(job_id: Optional[int] = None) -> List[Dict[str, Any]]:
1171
1248
  spot_table.outerjoin(
1172
1249
  job_info_table,
1173
1250
  spot_table.c.spot_job_id == job_info_table.c.spot_job_id))
1174
- if job_id is not None:
1175
- query = query.where(spot_table.c.spot_job_id == job_id)
1176
- query = query.order_by(spot_table.c.spot_job_id.desc(),
1177
- spot_table.c.task_id.asc())
1251
+ query = query.where(spot_table.c.spot_job_id == job_id)
1252
+ query = query.order_by(spot_table.c.task_id.asc())
1178
1253
  rows = None
1179
1254
  with orm.Session(_SQLALCHEMY_ENGINE) as session:
1180
1255
  rows = session.execute(query).fetchall()
@@ -1189,20 +1264,307 @@ def get_managed_jobs(job_id: Optional[int] = None) -> List[Dict[str, Any]]:
1189
1264
  job_dict['metadata'] = json.loads(job_dict['metadata'])
1190
1265
 
1191
1266
  # Add user YAML content for managed jobs.
1192
- yaml_path = job_dict.get('original_user_yaml_path')
1193
- if yaml_path:
1194
- try:
1195
- with open(yaml_path, 'r', encoding='utf-8') as f:
1196
- job_dict['user_yaml'] = f.read()
1197
- except (FileNotFoundError, IOError, OSError):
1198
- job_dict['user_yaml'] = None
1199
- else:
1200
- job_dict['user_yaml'] = None
1267
+ job_dict['user_yaml'] = job_dict.get('original_user_yaml_content')
1268
+ if job_dict['user_yaml'] is None:
1269
+ # Backwards compatibility - try to read from file path
1270
+ yaml_path = job_dict.get('original_user_yaml_path')
1271
+ if yaml_path:
1272
+ try:
1273
+ with open(yaml_path, 'r', encoding='utf-8') as f:
1274
+ job_dict['user_yaml'] = f.read()
1275
+ except (FileNotFoundError, IOError, OSError) as e:
1276
+ logger.debug('Failed to read original user YAML for job '
1277
+ f'{job_id} from {yaml_path}: {e}')
1201
1278
 
1202
1279
  jobs.append(job_dict)
1203
1280
  return jobs
1204
1281
 
1205
1282
 
1283
+ def _map_response_field_to_db_column(field: str):
1284
+ """Map the response field name to an actual SQLAlchemy ColumnElement.
1285
+
1286
+ This ensures we never pass plain strings to SQLAlchemy 2.0 APIs like
1287
+ Select.with_only_columns().
1288
+ """
1289
+ # Explicit aliases differing from actual DB column names
1290
+ alias_mapping = {
1291
+ '_job_id': spot_table.c.job_id, # spot.job_id
1292
+ '_task_name': spot_table.c.job_name, # deprecated, from spot table
1293
+ 'job_id': spot_table.c.spot_job_id, # public job id -> spot.spot_job_id
1294
+ '_job_info_job_id': job_info_table.c.spot_job_id,
1295
+ 'job_name': job_info_table.c.name, # public job name -> job_info.name
1296
+ }
1297
+ if field in alias_mapping:
1298
+ return alias_mapping[field]
1299
+
1300
+ # Try direct match on the `spot` table columns
1301
+ if field in spot_table.c:
1302
+ return spot_table.c[field]
1303
+
1304
+ # Try direct match on the `job_info` table columns
1305
+ if field in job_info_table.c:
1306
+ return job_info_table.c[field]
1307
+
1308
+ raise ValueError(f'Unknown field: {field}')
1309
+
1310
+
1311
+ @_init_db
1312
+ def get_managed_jobs_total() -> int:
1313
+ """Get the total number of managed jobs."""
1314
+ assert _SQLALCHEMY_ENGINE is not None
1315
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
1316
+ result = session.execute(
1317
+ sqlalchemy.select(sqlalchemy.func.count() # pylint: disable=not-callable
1318
+ ).select_from(spot_table)).fetchone()
1319
+ return result[0] if result else 0
1320
+
1321
+
1322
+ @_init_db
1323
+ def get_managed_jobs_highest_priority() -> int:
1324
+ """Get the highest priority of the managed jobs."""
1325
+ assert _SQLALCHEMY_ENGINE is not None
1326
+ query = sqlalchemy.select(sqlalchemy.func.max(
1327
+ job_info_table.c.priority)).where(
1328
+ sqlalchemy.and_(
1329
+ job_info_table.c.schedule_state.in_([
1330
+ ManagedJobScheduleState.LAUNCHING.value,
1331
+ ManagedJobScheduleState.ALIVE_BACKOFF.value,
1332
+ ManagedJobScheduleState.WAITING.value,
1333
+ ManagedJobScheduleState.ALIVE_WAITING.value,
1334
+ ]),
1335
+ job_info_table.c.priority.is_not(None),
1336
+ ))
1337
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
1338
+ priority = session.execute(query).fetchone()
1339
+ return priority[0] if priority and priority[
1340
+ 0] is not None else constants.MIN_PRIORITY
1341
+
1342
+
1343
+ def build_managed_jobs_with_filters_no_status_query(
1344
+ fields: Optional[List[str]] = None,
1345
+ job_ids: Optional[List[int]] = None,
1346
+ accessible_workspaces: Optional[List[str]] = None,
1347
+ workspace_match: Optional[str] = None,
1348
+ name_match: Optional[str] = None,
1349
+ pool_match: Optional[str] = None,
1350
+ user_hashes: Optional[List[Optional[str]]] = None,
1351
+ skip_finished: bool = False,
1352
+ count_only: bool = False,
1353
+ status_count: bool = False,
1354
+ ) -> sqlalchemy.Select:
1355
+ """Build a query to get managed jobs from the database with filters."""
1356
+ # Join spot and job_info tables to get the job name for each task.
1357
+ # We use LEFT OUTER JOIN mainly for backward compatibility, as for an
1358
+ # existing controller before #1982, the job_info table may not exist,
1359
+ # and all the managed jobs created before will not present in the
1360
+ # job_info.
1361
+ # Note: we will get the user_hash here, but don't try to call
1362
+ # global_user_state.get_user() on it. This runs on the controller, which may
1363
+ # not have the user info. Prefer to do it on the API server side.
1364
+ if count_only:
1365
+ query = sqlalchemy.select(sqlalchemy.func.count().label('count')) # pylint: disable=not-callable
1366
+ elif status_count:
1367
+ query = sqlalchemy.select(spot_table.c.status,
1368
+ sqlalchemy.func.count().label('count')) # pylint: disable=not-callable
1369
+ else:
1370
+ query = sqlalchemy.select(spot_table, job_info_table)
1371
+ query = query.select_from(
1372
+ spot_table.outerjoin(
1373
+ job_info_table,
1374
+ spot_table.c.spot_job_id == job_info_table.c.spot_job_id))
1375
+ if skip_finished:
1376
+ # Filter out finished jobs at the DB level. If a multi-task job is
1377
+ # partially finished, include all its tasks. We do this by first
1378
+ # selecting job_ids that have at least one non-terminal task, then
1379
+ # restricting the main query to those job_ids.
1380
+ terminal_status_values = [
1381
+ s.value for s in ManagedJobStatus.terminal_statuses()
1382
+ ]
1383
+ non_terminal_job_ids_subquery = (sqlalchemy.select(
1384
+ spot_table.c.spot_job_id).where(
1385
+ sqlalchemy.or_(
1386
+ spot_table.c.status.is_(None),
1387
+ sqlalchemy.not_(
1388
+ spot_table.c.status.in_(terminal_status_values)),
1389
+ )).distinct())
1390
+ query = query.where(
1391
+ spot_table.c.spot_job_id.in_(non_terminal_job_ids_subquery))
1392
+ if not count_only and not status_count and fields:
1393
+ # Resolve requested field names to explicit ColumnElements from
1394
+ # the joined tables.
1395
+ selected_columns = [_map_response_field_to_db_column(f) for f in fields]
1396
+ query = query.with_only_columns(*selected_columns)
1397
+ if job_ids is not None:
1398
+ query = query.where(spot_table.c.spot_job_id.in_(job_ids))
1399
+ if accessible_workspaces is not None:
1400
+ query = query.where(
1401
+ job_info_table.c.workspace.in_(accessible_workspaces))
1402
+ if workspace_match is not None:
1403
+ query = query.where(
1404
+ job_info_table.c.workspace.like(f'%{workspace_match}%'))
1405
+ if name_match is not None:
1406
+ query = query.where(job_info_table.c.name.like(f'%{name_match}%'))
1407
+ if pool_match is not None:
1408
+ query = query.where(job_info_table.c.pool.like(f'%{pool_match}%'))
1409
+ if user_hashes is not None:
1410
+ query = query.where(job_info_table.c.user_hash.in_(user_hashes))
1411
+ return query
1412
+
1413
+
1414
+ def build_managed_jobs_with_filters_query(
1415
+ fields: Optional[List[str]] = None,
1416
+ job_ids: Optional[List[int]] = None,
1417
+ accessible_workspaces: Optional[List[str]] = None,
1418
+ workspace_match: Optional[str] = None,
1419
+ name_match: Optional[str] = None,
1420
+ pool_match: Optional[str] = None,
1421
+ user_hashes: Optional[List[Optional[str]]] = None,
1422
+ statuses: Optional[List[str]] = None,
1423
+ skip_finished: bool = False,
1424
+ count_only: bool = False,
1425
+ ) -> sqlalchemy.Select:
1426
+ """Build a query to get managed jobs from the database with filters."""
1427
+ query = build_managed_jobs_with_filters_no_status_query(
1428
+ fields=fields,
1429
+ job_ids=job_ids,
1430
+ accessible_workspaces=accessible_workspaces,
1431
+ workspace_match=workspace_match,
1432
+ name_match=name_match,
1433
+ pool_match=pool_match,
1434
+ user_hashes=user_hashes,
1435
+ skip_finished=skip_finished,
1436
+ count_only=count_only,
1437
+ )
1438
+ if statuses is not None:
1439
+ query = query.where(spot_table.c.status.in_(statuses))
1440
+ return query
1441
+
1442
+
1443
+ @_init_db
1444
+ def get_status_count_with_filters(
1445
+ fields: Optional[List[str]] = None,
1446
+ job_ids: Optional[List[int]] = None,
1447
+ accessible_workspaces: Optional[List[str]] = None,
1448
+ workspace_match: Optional[str] = None,
1449
+ name_match: Optional[str] = None,
1450
+ pool_match: Optional[str] = None,
1451
+ user_hashes: Optional[List[Optional[str]]] = None,
1452
+ skip_finished: bool = False,
1453
+ ) -> Dict[str, int]:
1454
+ """Get the status count of the managed jobs with filters."""
1455
+ query = build_managed_jobs_with_filters_no_status_query(
1456
+ fields=fields,
1457
+ job_ids=job_ids,
1458
+ accessible_workspaces=accessible_workspaces,
1459
+ workspace_match=workspace_match,
1460
+ name_match=name_match,
1461
+ pool_match=pool_match,
1462
+ user_hashes=user_hashes,
1463
+ skip_finished=skip_finished,
1464
+ status_count=True,
1465
+ )
1466
+ query = query.group_by(spot_table.c.status)
1467
+ results: Dict[str, int] = {}
1468
+ assert _SQLALCHEMY_ENGINE is not None
1469
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
1470
+ rows = session.execute(query).fetchall()
1471
+ for status_value, count in rows:
1472
+ # status_value is already a string (enum value)
1473
+ results[str(status_value)] = int(count)
1474
+ return results
1475
+
1476
+
1477
+ @_init_db
1478
+ def get_managed_jobs_with_filters(
1479
+ fields: Optional[List[str]] = None,
1480
+ job_ids: Optional[List[int]] = None,
1481
+ accessible_workspaces: Optional[List[str]] = None,
1482
+ workspace_match: Optional[str] = None,
1483
+ name_match: Optional[str] = None,
1484
+ pool_match: Optional[str] = None,
1485
+ user_hashes: Optional[List[Optional[str]]] = None,
1486
+ statuses: Optional[List[str]] = None,
1487
+ skip_finished: bool = False,
1488
+ page: Optional[int] = None,
1489
+ limit: Optional[int] = None,
1490
+ ) -> Tuple[List[Dict[str, Any]], int]:
1491
+ """Get managed jobs from the database with filters.
1492
+
1493
+ Returns:
1494
+ A tuple containing
1495
+ - the list of managed jobs
1496
+ - the total number of managed jobs
1497
+ """
1498
+ assert _SQLALCHEMY_ENGINE is not None
1499
+
1500
+ count_query = build_managed_jobs_with_filters_query(
1501
+ fields=None,
1502
+ job_ids=job_ids,
1503
+ accessible_workspaces=accessible_workspaces,
1504
+ workspace_match=workspace_match,
1505
+ name_match=name_match,
1506
+ pool_match=pool_match,
1507
+ user_hashes=user_hashes,
1508
+ statuses=statuses,
1509
+ skip_finished=skip_finished,
1510
+ count_only=True,
1511
+ )
1512
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
1513
+ total = session.execute(count_query).fetchone()[0]
1514
+
1515
+ query = build_managed_jobs_with_filters_query(
1516
+ fields=fields,
1517
+ job_ids=job_ids,
1518
+ accessible_workspaces=accessible_workspaces,
1519
+ workspace_match=workspace_match,
1520
+ name_match=name_match,
1521
+ pool_match=pool_match,
1522
+ user_hashes=user_hashes,
1523
+ statuses=statuses,
1524
+ skip_finished=skip_finished,
1525
+ )
1526
+ query = query.order_by(spot_table.c.spot_job_id.desc(),
1527
+ spot_table.c.task_id.asc())
1528
+ if page is not None and limit is not None:
1529
+ query = query.offset((page - 1) * limit).limit(limit)
1530
+ rows = None
1531
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
1532
+ rows = session.execute(query).fetchall()
1533
+ jobs = []
1534
+ for row in rows:
1535
+ job_dict = _get_jobs_dict(row._mapping) # pylint: disable=protected-access
1536
+ if job_dict.get('status') is not None:
1537
+ job_dict['status'] = ManagedJobStatus(job_dict['status'])
1538
+ if job_dict.get('schedule_state') is not None:
1539
+ job_dict['schedule_state'] = ManagedJobScheduleState(
1540
+ job_dict['schedule_state'])
1541
+ if job_dict.get('job_name') is None:
1542
+ job_dict['job_name'] = job_dict.get('task_name')
1543
+ if job_dict.get('metadata') is not None:
1544
+ job_dict['metadata'] = json.loads(job_dict['metadata'])
1545
+
1546
+ # Add user YAML content for managed jobs.
1547
+ job_dict['user_yaml'] = job_dict.get('original_user_yaml_content')
1548
+ if job_dict['user_yaml'] is None:
1549
+ # Backwards compatibility - try to read from file path
1550
+ yaml_path = job_dict.get('original_user_yaml_path')
1551
+ if yaml_path:
1552
+ try:
1553
+ with open(yaml_path, 'r', encoding='utf-8') as f:
1554
+ job_dict['user_yaml'] = f.read()
1555
+ except (FileNotFoundError, IOError, OSError) as e:
1556
+ job_id = job_dict.get('job_id')
1557
+ if job_id is not None:
1558
+ logger.debug('Failed to read original user YAML for '
1559
+ f'job {job_id} from {yaml_path}: {e}')
1560
+ else:
1561
+ logger.debug('Failed to read original user YAML from '
1562
+ f'{yaml_path}: {e}')
1563
+
1564
+ jobs.append(job_dict)
1565
+ return jobs, total
1566
+
1567
+
1206
1568
  @_init_db
1207
1569
  def get_task_name(job_id: int, task_id: int) -> str:
1208
1570
  """Get the task name of a job."""
@@ -1243,58 +1605,58 @@ def get_task_specs(job_id: int, task_id: int) -> Dict[str, Any]:
1243
1605
 
1244
1606
 
1245
1607
  @_init_db
1246
- def get_local_log_file(job_id: int, task_id: Optional[int]) -> Optional[str]:
1247
- """Get the local log directory for a job."""
1608
+ def scheduler_set_waiting(job_id: int, dag_yaml_content: str,
1609
+ original_user_yaml_content: str,
1610
+ env_file_content: str,
1611
+ config_file_content: Optional[str],
1612
+ priority: int) -> None:
1248
1613
  assert _SQLALCHEMY_ENGINE is not None
1249
-
1250
1614
  with orm.Session(_SQLALCHEMY_ENGINE) as session:
1251
- where_conditions = [spot_table.c.spot_job_id == job_id]
1252
- if task_id is not None:
1253
- where_conditions.append(spot_table.c.task_id == task_id)
1254
- local_log_file = session.execute(
1255
- sqlalchemy.select(spot_table.c.local_log_file).where(
1256
- sqlalchemy.and_(*where_conditions))).fetchone()
1257
- return local_log_file[-1] if local_log_file else None
1258
-
1259
-
1260
- # === Scheduler state functions ===
1261
- # Only the scheduler should call these functions. They may require holding the
1262
- # scheduler lock to work correctly.
1615
+ updated_count = session.query(job_info_table).filter(
1616
+ sqlalchemy.and_(job_info_table.c.spot_job_id == job_id,)).update({
1617
+ job_info_table.c.schedule_state:
1618
+ ManagedJobScheduleState.WAITING.value,
1619
+ job_info_table.c.dag_yaml_content: dag_yaml_content,
1620
+ job_info_table.c.original_user_yaml_content:
1621
+ (original_user_yaml_content),
1622
+ job_info_table.c.env_file_content: env_file_content,
1623
+ job_info_table.c.config_file_content: config_file_content,
1624
+ job_info_table.c.priority: priority,
1625
+ })
1626
+ session.commit()
1627
+ assert updated_count <= 1, (job_id, updated_count)
1263
1628
 
1264
1629
 
1265
1630
  @_init_db
1266
- def scheduler_set_waiting(job_id: int, dag_yaml_path: str,
1267
- original_user_yaml_path: str, env_file_path: str,
1268
- user_hash: str, priority: int) -> bool:
1269
- """Do not call without holding the scheduler lock.
1270
-
1271
- Returns: Whether this is a recovery run or not.
1272
- If this is a recovery run, the job may already be in the WAITING
1273
- state and the update will not change the schedule_state (hence the
1274
- updated_count will be 0). In this case, we return True.
1275
- Otherwise, we return False.
1276
- """
1631
+ def get_job_file_contents(job_id: int) -> Dict[str, Optional[str]]:
1632
+ """Return file information and stored contents for a managed job."""
1277
1633
  assert _SQLALCHEMY_ENGINE is not None
1278
1634
  with orm.Session(_SQLALCHEMY_ENGINE) as session:
1279
- updated_count = session.query(job_info_table).filter(
1280
- sqlalchemy.and_(
1281
- job_info_table.c.spot_job_id == job_id,
1282
- job_info_table.c.schedule_state ==
1283
- ManagedJobScheduleState.INACTIVE.value,
1284
- )
1285
- ).update({
1286
- job_info_table.c.schedule_state:
1287
- ManagedJobScheduleState.WAITING.value,
1288
- job_info_table.c.dag_yaml_path: dag_yaml_path,
1289
- job_info_table.c.original_user_yaml_path: original_user_yaml_path,
1290
- job_info_table.c.env_file_path: env_file_path,
1291
- job_info_table.c.user_hash: user_hash,
1292
- job_info_table.c.priority: priority,
1293
- })
1294
- session.commit()
1295
- # For a recovery run, the job may already be in the WAITING state.
1296
- assert updated_count <= 1, (job_id, updated_count)
1297
- return updated_count == 0
1635
+ row = session.execute(
1636
+ sqlalchemy.select(
1637
+ job_info_table.c.dag_yaml_path,
1638
+ job_info_table.c.env_file_path,
1639
+ job_info_table.c.dag_yaml_content,
1640
+ job_info_table.c.env_file_content,
1641
+ job_info_table.c.config_file_content,
1642
+ ).where(job_info_table.c.spot_job_id == job_id)).fetchone()
1643
+
1644
+ if row is None:
1645
+ return {
1646
+ 'dag_yaml_path': None,
1647
+ 'env_file_path': None,
1648
+ 'dag_yaml_content': None,
1649
+ 'env_file_content': None,
1650
+ 'config_file_content': None,
1651
+ }
1652
+
1653
+ return {
1654
+ 'dag_yaml_path': row[0],
1655
+ 'env_file_path': row[1],
1656
+ 'dag_yaml_content': row[2],
1657
+ 'env_file_content': row[3],
1658
+ 'config_file_content': row[4],
1659
+ }
1298
1660
 
1299
1661
 
1300
1662
  @_init_db
@@ -1319,17 +1681,18 @@ def set_current_cluster_name(job_id: int, current_cluster_name: str) -> None:
1319
1681
  session.commit()
1320
1682
 
1321
1683
 
1322
- @_init_db
1323
- def set_job_id_on_pool_cluster(job_id: int,
1324
- job_id_on_pool_cluster: int) -> None:
1684
+ @_init_db_async
1685
+ async def set_job_id_on_pool_cluster_async(job_id: int,
1686
+ job_id_on_pool_cluster: int) -> None:
1325
1687
  """Set the job id on the pool cluster for a job."""
1326
- assert _SQLALCHEMY_ENGINE is not None
1327
- with orm.Session(_SQLALCHEMY_ENGINE) as session:
1328
- session.query(job_info_table).filter(
1329
- job_info_table.c.spot_job_id == job_id).update({
1688
+ assert _SQLALCHEMY_ENGINE_ASYNC is not None
1689
+ async with sql_async.AsyncSession(_SQLALCHEMY_ENGINE_ASYNC) as session:
1690
+ await session.execute(
1691
+ sqlalchemy.update(job_info_table).
1692
+ where(job_info_table.c.spot_job_id == job_id).values({
1330
1693
  job_info_table.c.job_id_on_pool_cluster: job_id_on_pool_cluster
1331
- })
1332
- session.commit()
1694
+ }))
1695
+ await session.commit()
1333
1696
 
1334
1697
 
1335
1698
  @_init_db
@@ -1347,77 +1710,54 @@ def get_pool_submit_info(job_id: int) -> Tuple[Optional[str], Optional[int]]:
1347
1710
  return info[0], info[1]
1348
1711
 
1349
1712
 
1350
- @_init_db
1351
- def scheduler_set_launching(job_id: int,
1352
- current_state: ManagedJobScheduleState) -> None:
1353
- """Do not call without holding the scheduler lock."""
1354
- assert _SQLALCHEMY_ENGINE is not None
1355
- with orm.Session(_SQLALCHEMY_ENGINE) as session:
1356
- updated_count = session.query(job_info_table).filter(
1357
- sqlalchemy.and_(
1358
- job_info_table.c.spot_job_id == job_id,
1359
- job_info_table.c.schedule_state == current_state.value,
1360
- )).update({
1361
- job_info_table.c.schedule_state:
1362
- ManagedJobScheduleState.LAUNCHING.value
1363
- })
1364
- session.commit()
1365
- assert updated_count == 1, (job_id, updated_count)
1713
+ @_init_db_async
1714
+ async def get_pool_submit_info_async(
1715
+ job_id: int) -> Tuple[Optional[str], Optional[int]]:
1716
+ """Get the cluster name and job id on the pool from the managed job id."""
1717
+ assert _SQLALCHEMY_ENGINE_ASYNC is not None
1718
+ async with sql_async.AsyncSession(_SQLALCHEMY_ENGINE_ASYNC) as session:
1719
+ result = await session.execute(
1720
+ sqlalchemy.select(job_info_table.c.current_cluster_name,
1721
+ job_info_table.c.job_id_on_pool_cluster).where(
1722
+ job_info_table.c.spot_job_id == job_id))
1723
+ info = result.fetchone()
1724
+ if info is None:
1725
+ return None, None
1726
+ return info[0], info[1]
1366
1727
 
1367
1728
 
1368
- @_init_db
1369
- def scheduler_set_alive(job_id: int) -> None:
1370
- """Do not call without holding the scheduler lock."""
1371
- assert _SQLALCHEMY_ENGINE is not None
1372
- with orm.Session(_SQLALCHEMY_ENGINE) as session:
1373
- updated_count = session.query(job_info_table).filter(
1374
- sqlalchemy.and_(
1375
- job_info_table.c.spot_job_id == job_id,
1376
- job_info_table.c.schedule_state ==
1377
- ManagedJobScheduleState.LAUNCHING.value,
1378
- )).update({
1379
- job_info_table.c.schedule_state:
1380
- ManagedJobScheduleState.ALIVE.value
1381
- })
1382
- session.commit()
1383
- assert updated_count == 1, (job_id, updated_count)
1729
+ @_init_db_async
1730
+ async def scheduler_set_launching_async(job_id: int):
1731
+ assert _SQLALCHEMY_ENGINE_ASYNC is not None
1732
+ async with sql_async.AsyncSession(_SQLALCHEMY_ENGINE_ASYNC) as session:
1733
+ await session.execute(
1734
+ sqlalchemy.update(job_info_table).where(
1735
+ sqlalchemy.and_(job_info_table.c.spot_job_id == job_id)).values(
1736
+ {
1737
+ job_info_table.c.schedule_state:
1738
+ ManagedJobScheduleState.LAUNCHING.value
1739
+ }))
1740
+ await session.commit()
1384
1741
 
1385
1742
 
1386
- @_init_db
1387
- def scheduler_set_alive_backoff(job_id: int) -> None:
1743
+ @_init_db_async
1744
+ async def scheduler_set_alive_async(job_id: int) -> None:
1388
1745
  """Do not call without holding the scheduler lock."""
1389
- assert _SQLALCHEMY_ENGINE is not None
1390
- with orm.Session(_SQLALCHEMY_ENGINE) as session:
1391
- updated_count = session.query(job_info_table).filter(
1392
- sqlalchemy.and_(
1393
- job_info_table.c.spot_job_id == job_id,
1394
- job_info_table.c.schedule_state ==
1395
- ManagedJobScheduleState.LAUNCHING.value,
1396
- )).update({
1397
- job_info_table.c.schedule_state:
1398
- ManagedJobScheduleState.ALIVE_BACKOFF.value
1399
- })
1400
- session.commit()
1401
- assert updated_count == 1, (job_id, updated_count)
1402
-
1403
-
1404
- @_init_db
1405
- def scheduler_set_alive_waiting(job_id: int) -> None:
1406
- """Do not call without holding the scheduler lock."""
1407
- assert _SQLALCHEMY_ENGINE is not None
1408
- with orm.Session(_SQLALCHEMY_ENGINE) as session:
1409
- updated_count = session.query(job_info_table).filter(
1410
- sqlalchemy.and_(
1411
- job_info_table.c.spot_job_id == job_id,
1412
- job_info_table.c.schedule_state.in_([
1413
- ManagedJobScheduleState.ALIVE.value,
1414
- ManagedJobScheduleState.ALIVE_BACKOFF.value,
1415
- ]))).update({
1746
+ assert _SQLALCHEMY_ENGINE_ASYNC is not None
1747
+ async with sql_async.AsyncSession(_SQLALCHEMY_ENGINE_ASYNC) as session:
1748
+ result = await session.execute(
1749
+ sqlalchemy.update(job_info_table).where(
1750
+ sqlalchemy.and_(
1751
+ job_info_table.c.spot_job_id == job_id,
1752
+ job_info_table.c.schedule_state ==
1753
+ ManagedJobScheduleState.LAUNCHING.value,
1754
+ )).values({
1416
1755
  job_info_table.c.schedule_state:
1417
- ManagedJobScheduleState.ALIVE_WAITING.value
1418
- })
1419
- session.commit()
1420
- assert updated_count == 1, (job_id, updated_count)
1756
+ ManagedJobScheduleState.ALIVE.value
1757
+ }))
1758
+ changes = result.rowcount
1759
+ await session.commit()
1760
+ assert changes == 1, (job_id, changes)
1421
1761
 
1422
1762
 
1423
1763
  @_init_db
@@ -1439,16 +1779,6 @@ def scheduler_set_done(job_id: int, idempotent: bool = False) -> None:
1439
1779
  assert updated_count == 1, (job_id, updated_count)
1440
1780
 
1441
1781
 
1442
- @_init_db
1443
- def set_job_controller_pid(job_id: int, pid: int):
1444
- assert _SQLALCHEMY_ENGINE is not None
1445
- with orm.Session(_SQLALCHEMY_ENGINE) as session:
1446
- updated_count = session.query(job_info_table).filter_by(
1447
- spot_job_id=job_id).update({job_info_table.c.controller_pid: pid})
1448
- session.commit()
1449
- assert updated_count == 1, (job_id, updated_count)
1450
-
1451
-
1452
1782
  @_init_db
1453
1783
  def get_job_schedule_state(job_id: int) -> ManagedJobScheduleState:
1454
1784
  assert _SQLALCHEMY_ENGINE is not None
@@ -1527,58 +1857,70 @@ def get_nonterminal_job_ids_by_pool(pool: str,
1527
1857
  return job_ids
1528
1858
 
1529
1859
 
1530
- @_init_db
1531
- def get_waiting_job() -> Optional[Dict[str, Any]]:
1860
+ @_init_db_async
1861
+ async def get_waiting_job_async(
1862
+ pid: int, pid_started_at: float) -> Optional[Dict[str, Any]]:
1532
1863
  """Get the next job that should transition to LAUNCHING.
1533
1864
 
1534
- Selects the highest-priority WAITING or ALIVE_WAITING job, provided its
1535
- priority is greater than or equal to any currently LAUNCHING or
1536
- ALIVE_BACKOFF job.
1865
+ Selects the highest-priority WAITING or ALIVE_WAITING job and atomically
1866
+ transitions it to LAUNCHING state to prevent race conditions.
1867
+
1868
+ Returns the job information if a job was successfully transitioned to
1869
+ LAUNCHING, or None if no suitable job was found.
1537
1870
 
1538
1871
  Backwards compatibility note: jobs submitted before #4485 will have no
1539
1872
  schedule_state and will be ignored by this SQL query.
1540
1873
  """
1541
- assert _SQLALCHEMY_ENGINE is not None
1542
- with orm.Session(_SQLALCHEMY_ENGINE) as session:
1543
- # Get the highest-priority WAITING or ALIVE_WAITING job whose priority
1544
- # is greater than or equal to the highest priority LAUNCHING or
1545
- # ALIVE_BACKOFF job's priority.
1546
- # First, get the max priority of LAUNCHING or ALIVE_BACKOFF jobs
1547
- max_priority_subquery = sqlalchemy.select(
1548
- sqlalchemy.func.max(job_info_table.c.priority)).where(
1549
- job_info_table.c.schedule_state.in_([
1550
- ManagedJobScheduleState.LAUNCHING.value,
1551
- ManagedJobScheduleState.ALIVE_BACKOFF.value,
1552
- ])).scalar_subquery()
1553
- # Main query for waiting jobs
1554
- select_conds = [
1555
- job_info_table.c.schedule_state.in_([
1556
- ManagedJobScheduleState.WAITING.value,
1557
- ManagedJobScheduleState.ALIVE_WAITING.value,
1558
- ]),
1559
- job_info_table.c.priority >= sqlalchemy.func.coalesce(
1560
- max_priority_subquery, 0),
1561
- ]
1562
- query = sqlalchemy.select(
1874
+ assert _SQLALCHEMY_ENGINE_ASYNC is not None
1875
+ async with sql_async.AsyncSession(_SQLALCHEMY_ENGINE_ASYNC) as session:
1876
+ # Select the highest priority waiting job for update (locks the row)
1877
+ select_query = sqlalchemy.select(
1563
1878
  job_info_table.c.spot_job_id,
1564
1879
  job_info_table.c.schedule_state,
1565
- job_info_table.c.dag_yaml_path,
1566
- job_info_table.c.env_file_path,
1567
1880
  job_info_table.c.pool,
1568
- ).where(sqlalchemy.and_(*select_conds)).order_by(
1569
- job_info_table.c.priority.desc(),
1570
- job_info_table.c.spot_job_id.asc(),
1571
- ).limit(1)
1572
- waiting_job_row = session.execute(query).fetchone()
1881
+ ).where(
1882
+ job_info_table.c.schedule_state.in_([
1883
+ ManagedJobScheduleState.WAITING.value,
1884
+ ])).order_by(
1885
+ job_info_table.c.priority.desc(),
1886
+ job_info_table.c.spot_job_id.asc(),
1887
+ ).limit(1).with_for_update()
1888
+
1889
+ # Execute the select with row locking
1890
+ result = await session.execute(select_query)
1891
+ waiting_job_row = result.fetchone()
1892
+
1573
1893
  if waiting_job_row is None:
1574
1894
  return None
1575
1895
 
1896
+ job_id = waiting_job_row[0]
1897
+ current_state = ManagedJobScheduleState(waiting_job_row[1])
1898
+ pool = waiting_job_row[2]
1899
+
1900
+ # Update the job state to LAUNCHING
1901
+ update_result = await session.execute(
1902
+ sqlalchemy.update(job_info_table).where(
1903
+ sqlalchemy.and_(
1904
+ job_info_table.c.spot_job_id == job_id,
1905
+ job_info_table.c.schedule_state == current_state.value,
1906
+ )).values({
1907
+ job_info_table.c.schedule_state:
1908
+ ManagedJobScheduleState.LAUNCHING.value,
1909
+ job_info_table.c.controller_pid: pid,
1910
+ job_info_table.c.controller_pid_started_at: pid_started_at,
1911
+ }))
1912
+
1913
+ if update_result.rowcount != 1:
1914
+ # Update failed, rollback and return None
1915
+ await session.rollback()
1916
+ return None
1917
+
1918
+ # Commit the transaction
1919
+ await session.commit()
1920
+
1576
1921
  return {
1577
- 'job_id': waiting_job_row[0],
1578
- 'schedule_state': ManagedJobScheduleState(waiting_job_row[1]),
1579
- 'dag_yaml_path': waiting_job_row[2],
1580
- 'env_file_path': waiting_job_row[3],
1581
- 'pool': waiting_job_row[4],
1922
+ 'job_id': job_id,
1923
+ 'pool': pool,
1582
1924
  }
1583
1925
 
1584
1926
 
@@ -1596,24 +1938,393 @@ def get_workspace(job_id: int) -> str:
1596
1938
  return job_workspace
1597
1939
 
1598
1940
 
1599
- # === HA Recovery Script functions ===
1941
+ @_init_db_async
1942
+ async def get_latest_task_id_status_async(
1943
+ job_id: int) -> Union[Tuple[int, ManagedJobStatus], Tuple[None, None]]:
1944
+ """Returns the (task id, status) of the latest task of a job."""
1945
+ assert _SQLALCHEMY_ENGINE_ASYNC is not None
1946
+ async with sql_async.AsyncSession(_SQLALCHEMY_ENGINE_ASYNC) as session:
1947
+ result = await session.execute(
1948
+ sqlalchemy.select(
1949
+ spot_table.c.task_id,
1950
+ spot_table.c.status,
1951
+ ).where(spot_table.c.spot_job_id == job_id).order_by(
1952
+ spot_table.c.task_id.asc()))
1953
+ id_statuses = [
1954
+ (row[0], ManagedJobStatus(row[1])) for row in result.fetchall()
1955
+ ]
1600
1956
 
1957
+ if not id_statuses:
1958
+ return None, None
1959
+ task_id, status = next(
1960
+ ((tid, st) for tid, st in id_statuses if not st.is_terminal()),
1961
+ id_statuses[-1],
1962
+ )
1963
+ return task_id, status
1601
1964
 
1602
- @_init_db
1603
- def get_ha_recovery_script(job_id: int) -> Optional[str]:
1604
- """Get the HA recovery script for a job."""
1605
- assert _SQLALCHEMY_ENGINE is not None
1606
- with orm.Session(_SQLALCHEMY_ENGINE) as session:
1607
- row = session.query(ha_recovery_script_table).filter_by(
1608
- job_id=job_id).first()
1609
- if row is None:
1610
- return None
1611
- return row.script
1965
+
1966
+ @_init_db_async
1967
+ async def set_starting_async(job_id: int, task_id: int, run_timestamp: str,
1968
+ submit_time: float, resources_str: str,
1969
+ specs: Dict[str, Union[str, int]],
1970
+ callback_func: AsyncCallbackType):
1971
+ """Set the task to starting state."""
1972
+ assert _SQLALCHEMY_ENGINE_ASYNC is not None
1973
+ logger.info('Launching the spot cluster...')
1974
+ async with sql_async.AsyncSession(_SQLALCHEMY_ENGINE_ASYNC) as session:
1975
+ result = await session.execute(
1976
+ sqlalchemy.update(spot_table).where(
1977
+ sqlalchemy.and_(
1978
+ spot_table.c.spot_job_id == job_id,
1979
+ spot_table.c.task_id == task_id,
1980
+ spot_table.c.status == ManagedJobStatus.PENDING.value,
1981
+ spot_table.c.end_at.is_(None),
1982
+ )).values({
1983
+ spot_table.c.resources: resources_str,
1984
+ spot_table.c.submitted_at: submit_time,
1985
+ spot_table.c.status: ManagedJobStatus.STARTING.value,
1986
+ spot_table.c.run_timestamp: run_timestamp,
1987
+ spot_table.c.specs: json.dumps(specs),
1988
+ }))
1989
+ count = result.rowcount
1990
+ await session.commit()
1991
+ if count != 1:
1992
+ details = await _describe_task_transition_failure(
1993
+ session, job_id, task_id)
1994
+ message = ('Failed to set the task to starting. '
1995
+ f'({count} rows updated. {details})')
1996
+ logger.error(message)
1997
+ raise exceptions.ManagedJobStatusError(message)
1998
+ await callback_func('SUBMITTED')
1999
+ await callback_func('STARTING')
2000
+
2001
+
2002
+ @_init_db_async
2003
+ async def set_started_async(job_id: int, task_id: int, start_time: float,
2004
+ callback_func: AsyncCallbackType):
2005
+ """Set the task to started state."""
2006
+ assert _SQLALCHEMY_ENGINE_ASYNC is not None
2007
+ logger.info('Job started.')
2008
+ async with sql_async.AsyncSession(_SQLALCHEMY_ENGINE_ASYNC) as session:
2009
+ result = await session.execute(
2010
+ sqlalchemy.update(spot_table).where(
2011
+ sqlalchemy.and_(
2012
+ spot_table.c.spot_job_id == job_id,
2013
+ spot_table.c.task_id == task_id,
2014
+ spot_table.c.status.in_([
2015
+ ManagedJobStatus.STARTING.value,
2016
+ ManagedJobStatus.PENDING.value
2017
+ ]),
2018
+ spot_table.c.end_at.is_(None),
2019
+ )).values({
2020
+ spot_table.c.status: ManagedJobStatus.RUNNING.value,
2021
+ spot_table.c.start_at: start_time,
2022
+ spot_table.c.last_recovered_at: start_time,
2023
+ }))
2024
+ count = result.rowcount
2025
+ await session.commit()
2026
+ if count != 1:
2027
+ details = await _describe_task_transition_failure(
2028
+ session, job_id, task_id)
2029
+ message = (f'Failed to set the task to started. '
2030
+ f'({count} rows updated. {details})')
2031
+ logger.error(message)
2032
+ raise exceptions.ManagedJobStatusError(message)
2033
+ await callback_func('STARTED')
2034
+
2035
+
2036
+ @_init_db_async
2037
+ async def get_job_status_with_task_id_async(
2038
+ job_id: int, task_id: int) -> Optional[ManagedJobStatus]:
2039
+ assert _SQLALCHEMY_ENGINE_ASYNC is not None
2040
+ async with sql_async.AsyncSession(_SQLALCHEMY_ENGINE_ASYNC) as session:
2041
+ result = await session.execute(
2042
+ sqlalchemy.select(spot_table.c.status).where(
2043
+ sqlalchemy.and_(spot_table.c.spot_job_id == job_id,
2044
+ spot_table.c.task_id == task_id)))
2045
+ status = result.fetchone()
2046
+ return ManagedJobStatus(status[0]) if status else None
2047
+
2048
+
2049
+ @_init_db_async
2050
+ async def set_recovering_async(job_id: int, task_id: int,
2051
+ force_transit_to_recovering: bool,
2052
+ callback_func: AsyncCallbackType):
2053
+ """Set the task to recovering state, and update the job duration."""
2054
+ assert _SQLALCHEMY_ENGINE_ASYNC is not None
2055
+ logger.info('=== Recovering... ===')
2056
+ current_time = time.time()
2057
+
2058
+ async with sql_async.AsyncSession(_SQLALCHEMY_ENGINE_ASYNC) as session:
2059
+ if force_transit_to_recovering:
2060
+ status_condition = spot_table.c.status.in_(
2061
+ [s.value for s in ManagedJobStatus.processing_statuses()])
2062
+ else:
2063
+ status_condition = (
2064
+ spot_table.c.status == ManagedJobStatus.RUNNING.value)
2065
+
2066
+ result = await session.execute(
2067
+ sqlalchemy.update(spot_table).where(
2068
+ sqlalchemy.and_(
2069
+ spot_table.c.spot_job_id == job_id,
2070
+ spot_table.c.task_id == task_id,
2071
+ status_condition,
2072
+ spot_table.c.end_at.is_(None),
2073
+ )).values({
2074
+ spot_table.c.status: ManagedJobStatus.RECOVERING.value,
2075
+ spot_table.c.job_duration: sqlalchemy.case(
2076
+ (spot_table.c.last_recovered_at >= 0,
2077
+ spot_table.c.job_duration + current_time -
2078
+ spot_table.c.last_recovered_at),
2079
+ else_=spot_table.c.job_duration),
2080
+ spot_table.c.last_recovered_at: sqlalchemy.case(
2081
+ (spot_table.c.last_recovered_at < 0, current_time),
2082
+ else_=spot_table.c.last_recovered_at),
2083
+ }))
2084
+ count = result.rowcount
2085
+ await session.commit()
2086
+ if count != 1:
2087
+ details = await _describe_task_transition_failure(
2088
+ session, job_id, task_id)
2089
+ message = ('Failed to set the task to recovering with '
2090
+ 'force_transit_to_recovering='
2091
+ f'{force_transit_to_recovering}. '
2092
+ f'({count} rows updated. {details})')
2093
+ logger.error(message)
2094
+ raise exceptions.ManagedJobStatusError(message)
2095
+ await callback_func('RECOVERING')
2096
+
2097
+
2098
+ @_init_db_async
2099
+ async def set_recovered_async(job_id: int, task_id: int, recovered_time: float,
2100
+ callback_func: AsyncCallbackType):
2101
+ """Set the task to recovered."""
2102
+ assert _SQLALCHEMY_ENGINE_ASYNC is not None
2103
+ async with sql_async.AsyncSession(_SQLALCHEMY_ENGINE_ASYNC) as session:
2104
+ result = await session.execute(
2105
+ sqlalchemy.update(spot_table).where(
2106
+ sqlalchemy.and_(
2107
+ spot_table.c.spot_job_id == job_id,
2108
+ spot_table.c.task_id == task_id,
2109
+ spot_table.c.status == ManagedJobStatus.RECOVERING.value,
2110
+ spot_table.c.end_at.is_(None),
2111
+ )).values({
2112
+ spot_table.c.status: ManagedJobStatus.RUNNING.value,
2113
+ spot_table.c.last_recovered_at: recovered_time,
2114
+ spot_table.c.recovery_count: spot_table.c.recovery_count +
2115
+ 1,
2116
+ }))
2117
+ count = result.rowcount
2118
+ await session.commit()
2119
+ if count != 1:
2120
+ details = await _describe_task_transition_failure(
2121
+ session, job_id, task_id)
2122
+ message = (f'Failed to set the task to recovered. '
2123
+ f'({count} rows updated. {details})')
2124
+ logger.error(message)
2125
+ raise exceptions.ManagedJobStatusError(message)
2126
+ logger.info('==== Recovered. ====')
2127
+ await callback_func('RECOVERED')
2128
+
2129
+
2130
+ @_init_db_async
2131
+ async def set_succeeded_async(job_id: int, task_id: int, end_time: float,
2132
+ callback_func: AsyncCallbackType):
2133
+ """Set the task to succeeded, if it is in a non-terminal state."""
2134
+ assert _SQLALCHEMY_ENGINE_ASYNC is not None
2135
+ async with sql_async.AsyncSession(_SQLALCHEMY_ENGINE_ASYNC) as session:
2136
+ result = await session.execute(
2137
+ sqlalchemy.update(spot_table).where(
2138
+ sqlalchemy.and_(
2139
+ spot_table.c.spot_job_id == job_id,
2140
+ spot_table.c.task_id == task_id,
2141
+ spot_table.c.status == ManagedJobStatus.RUNNING.value,
2142
+ spot_table.c.end_at.is_(None),
2143
+ )).values({
2144
+ spot_table.c.status: ManagedJobStatus.SUCCEEDED.value,
2145
+ spot_table.c.end_at: end_time,
2146
+ }))
2147
+ count = result.rowcount
2148
+ await session.commit()
2149
+ if count != 1:
2150
+ details = await _describe_task_transition_failure(
2151
+ session, job_id, task_id)
2152
+ message = (f'Failed to set the task to succeeded. '
2153
+ f'({count} rows updated. {details})')
2154
+ logger.error(message)
2155
+ raise exceptions.ManagedJobStatusError(message)
2156
+ await callback_func('SUCCEEDED')
2157
+ logger.info('Job succeeded.')
2158
+
2159
+
2160
+ @_init_db_async
2161
+ async def set_failed_async(
2162
+ job_id: int,
2163
+ task_id: Optional[int],
2164
+ failure_type: ManagedJobStatus,
2165
+ failure_reason: str,
2166
+ callback_func: Optional[AsyncCallbackType] = None,
2167
+ end_time: Optional[float] = None,
2168
+ override_terminal: bool = False,
2169
+ ):
2170
+ """Set an entire job or task to failed."""
2171
+ assert _SQLALCHEMY_ENGINE_ASYNC is not None
2172
+ assert failure_type.is_failed(), failure_type
2173
+ end_time = time.time() if end_time is None else end_time
2174
+
2175
+ fields_to_set: Dict[str, Any] = {
2176
+ spot_table.c.status: failure_type.value,
2177
+ spot_table.c.failure_reason: failure_reason,
2178
+ }
2179
+ async with sql_async.AsyncSession(_SQLALCHEMY_ENGINE_ASYNC) as session:
2180
+ # Get previous status
2181
+ result = await session.execute(
2182
+ sqlalchemy.select(
2183
+ spot_table.c.status).where(spot_table.c.spot_job_id == job_id))
2184
+ previous_status_row = result.fetchone()
2185
+ previous_status = ManagedJobStatus(previous_status_row[0])
2186
+ if previous_status == ManagedJobStatus.RECOVERING:
2187
+ fields_to_set[spot_table.c.last_recovered_at] = end_time
2188
+ where_conditions = [spot_table.c.spot_job_id == job_id]
2189
+ if task_id is not None:
2190
+ where_conditions.append(spot_table.c.task_id == task_id)
2191
+
2192
+ # Handle failure_reason prepending when override_terminal is True
2193
+ if override_terminal:
2194
+ # Get existing failure_reason with row lock to prevent race
2195
+ # conditions
2196
+ existing_reason_result = await session.execute(
2197
+ sqlalchemy.select(spot_table.c.failure_reason).where(
2198
+ sqlalchemy.and_(*where_conditions)).with_for_update())
2199
+ existing_reason_row = existing_reason_result.fetchone()
2200
+ if existing_reason_row and existing_reason_row[0]:
2201
+ # Prepend new failure reason to existing one
2202
+ fields_to_set[spot_table.c.failure_reason] = (
2203
+ failure_reason + '. Previously: ' + existing_reason_row[0])
2204
+ fields_to_set[spot_table.c.end_at] = sqlalchemy.func.coalesce(
2205
+ spot_table.c.end_at, end_time)
2206
+ else:
2207
+ fields_to_set[spot_table.c.end_at] = end_time
2208
+ where_conditions.append(spot_table.c.end_at.is_(None))
2209
+ result = await session.execute(
2210
+ sqlalchemy.update(spot_table).where(
2211
+ sqlalchemy.and_(*where_conditions)).values(fields_to_set))
2212
+ count = result.rowcount
2213
+ await session.commit()
2214
+ updated = count > 0
2215
+ if callback_func and updated:
2216
+ await callback_func('FAILED')
2217
+ logger.info(failure_reason)
2218
+
2219
+
2220
+ @_init_db_async
2221
+ async def set_cancelling_async(job_id: int, callback_func: AsyncCallbackType):
2222
+ """Set tasks in the job as cancelling, if they are in non-terminal
2223
+ states."""
2224
+ assert _SQLALCHEMY_ENGINE_ASYNC is not None
2225
+ async with sql_async.AsyncSession(_SQLALCHEMY_ENGINE_ASYNC) as session:
2226
+ result = await session.execute(
2227
+ sqlalchemy.update(spot_table).where(
2228
+ sqlalchemy.and_(
2229
+ spot_table.c.spot_job_id == job_id,
2230
+ spot_table.c.end_at.is_(None),
2231
+ )).values(
2232
+ {spot_table.c.status: ManagedJobStatus.CANCELLING.value}))
2233
+ count = result.rowcount
2234
+ await session.commit()
2235
+ updated = count > 0
2236
+ if updated:
2237
+ logger.info('Cancelling the job...')
2238
+ await callback_func('CANCELLING')
2239
+ else:
2240
+ logger.info('Cancellation skipped, job is already terminal')
2241
+
2242
+
2243
+ @_init_db_async
2244
+ async def set_cancelled_async(job_id: int, callback_func: AsyncCallbackType):
2245
+ """Set tasks in the job as cancelled, if they are in CANCELLING state."""
2246
+ assert _SQLALCHEMY_ENGINE_ASYNC is not None
2247
+ async with sql_async.AsyncSession(_SQLALCHEMY_ENGINE_ASYNC) as session:
2248
+ result = await session.execute(
2249
+ sqlalchemy.update(spot_table).where(
2250
+ sqlalchemy.and_(
2251
+ spot_table.c.spot_job_id == job_id,
2252
+ spot_table.c.status == ManagedJobStatus.CANCELLING.value,
2253
+ )).values({
2254
+ spot_table.c.status: ManagedJobStatus.CANCELLED.value,
2255
+ spot_table.c.end_at: time.time(),
2256
+ }))
2257
+ count = result.rowcount
2258
+ await session.commit()
2259
+ updated = count > 0
2260
+ if updated:
2261
+ logger.info('Job cancelled.')
2262
+ await callback_func('CANCELLED')
2263
+ else:
2264
+ logger.info('Cancellation skipped, job is not CANCELLING')
2265
+
2266
+
2267
+ @_init_db_async
2268
+ async def remove_ha_recovery_script_async(job_id: int) -> None:
2269
+ """Remove the HA recovery script for a job."""
2270
+ assert _SQLALCHEMY_ENGINE_ASYNC is not None
2271
+ async with sql_async.AsyncSession(_SQLALCHEMY_ENGINE_ASYNC) as session:
2272
+ await session.execute(
2273
+ sqlalchemy.delete(ha_recovery_script_table).where(
2274
+ ha_recovery_script_table.c.job_id == job_id))
2275
+ await session.commit()
2276
+
2277
+
2278
+ async def get_status_async(job_id: int) -> Optional[ManagedJobStatus]:
2279
+ _, status = await get_latest_task_id_status_async(job_id)
2280
+ return status
2281
+
2282
+
2283
+ @_init_db_async
2284
+ async def get_job_schedule_state_async(job_id: int) -> ManagedJobScheduleState:
2285
+ assert _SQLALCHEMY_ENGINE_ASYNC is not None
2286
+ async with sql_async.AsyncSession(_SQLALCHEMY_ENGINE_ASYNC) as session:
2287
+ result = await session.execute(
2288
+ sqlalchemy.select(job_info_table.c.schedule_state).where(
2289
+ job_info_table.c.spot_job_id == job_id))
2290
+ state = result.fetchone()[0]
2291
+ return ManagedJobScheduleState(state)
2292
+
2293
+
2294
+ @_init_db_async
2295
+ async def scheduler_set_done_async(job_id: int,
2296
+ idempotent: bool = False) -> None:
2297
+ """Do not call without holding the scheduler lock."""
2298
+ assert _SQLALCHEMY_ENGINE_ASYNC is not None
2299
+ async with sql_async.AsyncSession(_SQLALCHEMY_ENGINE_ASYNC) as session:
2300
+ result = await session.execute(
2301
+ sqlalchemy.update(job_info_table).where(
2302
+ sqlalchemy.and_(
2303
+ job_info_table.c.spot_job_id == job_id,
2304
+ job_info_table.c.schedule_state !=
2305
+ ManagedJobScheduleState.DONE.value,
2306
+ )).values({
2307
+ job_info_table.c.schedule_state:
2308
+ ManagedJobScheduleState.DONE.value
2309
+ }))
2310
+ updated_count = result.rowcount
2311
+ await session.commit()
2312
+ if not idempotent:
2313
+ assert updated_count == 1, (job_id, updated_count)
2314
+
2315
+
2316
+ # ==== needed for codegen ====
2317
+ # functions have no use outside of codegen, remove at your own peril
1612
2318
 
1613
2319
 
1614
2320
  @_init_db
1615
- def set_ha_recovery_script(job_id: int, script: str) -> None:
1616
- """Set the HA recovery script for a job."""
2321
+ def set_job_info(job_id: int,
2322
+ name: str,
2323
+ workspace: str,
2324
+ entrypoint: str,
2325
+ pool: Optional[str],
2326
+ pool_hash: Optional[str],
2327
+ user_hash: Optional[str] = None):
1617
2328
  assert _SQLALCHEMY_ENGINE is not None
1618
2329
  with orm.Session(_SQLALCHEMY_ENGINE) as session:
1619
2330
  if (_SQLALCHEMY_ENGINE.dialect.name ==
@@ -1624,20 +2335,187 @@ def set_ha_recovery_script(job_id: int, script: str) -> None:
1624
2335
  insert_func = postgresql.insert
1625
2336
  else:
1626
2337
  raise ValueError('Unsupported database dialect')
1627
- insert_stmt = insert_func(ha_recovery_script_table).values(
1628
- job_id=job_id, script=script)
1629
- do_update_stmt = insert_stmt.on_conflict_do_update(
1630
- index_elements=[ha_recovery_script_table.c.job_id],
1631
- set_={ha_recovery_script_table.c.script: script})
1632
- session.execute(do_update_stmt)
2338
+ insert_stmt = insert_func(job_info_table).values(
2339
+ spot_job_id=job_id,
2340
+ name=name,
2341
+ schedule_state=ManagedJobScheduleState.INACTIVE.value,
2342
+ workspace=workspace,
2343
+ entrypoint=entrypoint,
2344
+ pool=pool,
2345
+ pool_hash=pool_hash,
2346
+ user_hash=user_hash,
2347
+ )
2348
+ session.execute(insert_stmt)
1633
2349
  session.commit()
1634
2350
 
1635
2351
 
1636
2352
  @_init_db
1637
- def remove_ha_recovery_script(job_id: int) -> None:
1638
- """Remove the HA recovery script for a job."""
2353
+ def reset_jobs_for_recovery() -> None:
2354
+ """Remove controller PIDs for live jobs, allowing them to be recovered."""
2355
+ assert _SQLALCHEMY_ENGINE is not None
2356
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
2357
+ session.query(job_info_table).filter(
2358
+ # PID should be set.
2359
+ job_info_table.c.controller_pid.isnot(None),
2360
+ # Schedule state should be alive.
2361
+ job_info_table.c.schedule_state.isnot(None),
2362
+ (job_info_table.c.schedule_state !=
2363
+ ManagedJobScheduleState.WAITING.value),
2364
+ (job_info_table.c.schedule_state !=
2365
+ ManagedJobScheduleState.DONE.value),
2366
+ ).update({
2367
+ job_info_table.c.controller_pid: None,
2368
+ job_info_table.c.controller_pid_started_at: None,
2369
+ job_info_table.c.schedule_state:
2370
+ (ManagedJobScheduleState.WAITING.value)
2371
+ })
2372
+ session.commit()
2373
+
2374
+
2375
+ @_init_db
2376
+ def reset_job_for_recovery(job_id: int) -> None:
2377
+ """Set a job to WAITING and remove PID, allowing it to be recovered."""
2378
+ assert _SQLALCHEMY_ENGINE is not None
2379
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
2380
+ session.query(job_info_table).filter(
2381
+ job_info_table.c.spot_job_id == job_id).update({
2382
+ job_info_table.c.controller_pid: None,
2383
+ job_info_table.c.controller_pid_started_at: None,
2384
+ job_info_table.c.schedule_state:
2385
+ ManagedJobScheduleState.WAITING.value,
2386
+ })
2387
+ session.commit()
2388
+
2389
+
2390
+ @_init_db
2391
+ def get_all_job_ids_by_name(name: Optional[str]) -> List[int]:
2392
+ """Get all job ids by name."""
2393
+ assert _SQLALCHEMY_ENGINE is not None
2394
+
2395
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
2396
+ query = sqlalchemy.select(
2397
+ spot_table.c.spot_job_id.distinct()).select_from(
2398
+ spot_table.outerjoin(
2399
+ job_info_table,
2400
+ spot_table.c.spot_job_id == job_info_table.c.spot_job_id))
2401
+ if name is not None:
2402
+ # We match the job name from `job_info` for the jobs submitted after
2403
+ # #1982, and from `spot` for the jobs submitted before #1982, whose
2404
+ # job_info is not available.
2405
+ name_condition = sqlalchemy.or_(
2406
+ job_info_table.c.name == name,
2407
+ sqlalchemy.and_(job_info_table.c.name.is_(None),
2408
+ spot_table.c.task_name == name))
2409
+ query = query.where(name_condition)
2410
+ query = query.order_by(spot_table.c.spot_job_id.desc())
2411
+ rows = session.execute(query).fetchall()
2412
+ job_ids = [row[0] for row in rows if row[0] is not None]
2413
+ return job_ids
2414
+
2415
+
2416
+ @_init_db
2417
+ def get_task_logs_to_clean(retention_seconds: int,
2418
+ batch_size: int) -> List[Dict[str, Any]]:
2419
+ """Get the logs of job tasks to clean.
2420
+
2421
+ The logs of a task will only cleaned when:
2422
+ - the job schedule state is DONE
2423
+ - AND the end time of the task is older than the retention period
2424
+ """
2425
+ assert _SQLALCHEMY_ENGINE is not None
2426
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
2427
+ now = time.time()
2428
+ result = session.execute(
2429
+ sqlalchemy.select(
2430
+ spot_table.c.spot_job_id,
2431
+ spot_table.c.task_id,
2432
+ spot_table.c.local_log_file,
2433
+ ).select_from(
2434
+ spot_table.join(
2435
+ job_info_table,
2436
+ spot_table.c.spot_job_id == job_info_table.c.spot_job_id,
2437
+ )).
2438
+ where(
2439
+ sqlalchemy.and_(
2440
+ job_info_table.c.schedule_state.is_(
2441
+ ManagedJobScheduleState.DONE.value),
2442
+ spot_table.c.end_at.isnot(None),
2443
+ spot_table.c.end_at < (now - retention_seconds),
2444
+ spot_table.c.logs_cleaned_at.is_(None),
2445
+ # The local log file is set AFTER the task is finished,
2446
+ # add this condition to ensure the entire log file has
2447
+ # been written.
2448
+ spot_table.c.local_log_file.isnot(None),
2449
+ )).limit(batch_size))
2450
+ rows = result.fetchall()
2451
+ return [{
2452
+ 'job_id': row[0],
2453
+ 'task_id': row[1],
2454
+ 'local_log_file': row[2]
2455
+ } for row in rows]
2456
+
2457
+
2458
+ @_init_db
2459
+ def get_controller_logs_to_clean(retention_seconds: int,
2460
+ batch_size: int) -> List[Dict[str, Any]]:
2461
+ """Get the controller logs to clean.
2462
+
2463
+ The controller logs will only cleaned when:
2464
+ - the job schedule state is DONE
2465
+ - AND the end time of the latest task is older than the retention period
2466
+ """
1639
2467
  assert _SQLALCHEMY_ENGINE is not None
1640
2468
  with orm.Session(_SQLALCHEMY_ENGINE) as session:
1641
- session.query(ha_recovery_script_table).filter_by(
1642
- job_id=job_id).delete()
2469
+ now = time.time()
2470
+ result = session.execute(
2471
+ sqlalchemy.select(job_info_table.c.spot_job_id,).select_from(
2472
+ job_info_table.join(
2473
+ spot_table,
2474
+ job_info_table.c.spot_job_id == spot_table.c.spot_job_id,
2475
+ )).where(
2476
+ sqlalchemy.and_(
2477
+ job_info_table.c.schedule_state.is_(
2478
+ ManagedJobScheduleState.DONE.value),
2479
+ spot_table.c.local_log_file.isnot(None),
2480
+ job_info_table.c.controller_logs_cleaned_at.is_(None),
2481
+ )).group_by(
2482
+ job_info_table.c.spot_job_id,
2483
+ job_info_table.c.current_cluster_name,
2484
+ ).having(
2485
+ sqlalchemy.func.max(
2486
+ spot_table.c.end_at).isnot(None),).having(
2487
+ sqlalchemy.func.max(spot_table.c.end_at) < (
2488
+ now - retention_seconds)).limit(batch_size))
2489
+ rows = result.fetchall()
2490
+ return [{'job_id': row[0]} for row in rows]
2491
+
2492
+
2493
+ @_init_db
2494
+ def set_task_logs_cleaned(tasks: List[Tuple[int, int]], logs_cleaned_at: float):
2495
+ """Set the task logs cleaned at."""
2496
+ if not tasks:
2497
+ return
2498
+ task_keys = list(dict.fromkeys(tasks))
2499
+ assert _SQLALCHEMY_ENGINE is not None
2500
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
2501
+ session.execute(
2502
+ sqlalchemy.update(spot_table).where(
2503
+ sqlalchemy.tuple_(spot_table.c.spot_job_id,
2504
+ spot_table.c.task_id).in_(task_keys)).values(
2505
+ logs_cleaned_at=logs_cleaned_at))
2506
+ session.commit()
2507
+
2508
+
2509
+ @_init_db
2510
+ def set_controller_logs_cleaned(job_ids: List[int], logs_cleaned_at: float):
2511
+ """Set the controller logs cleaned at."""
2512
+ if not job_ids:
2513
+ return
2514
+ job_ids = list(dict.fromkeys(job_ids))
2515
+ assert _SQLALCHEMY_ENGINE is not None
2516
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
2517
+ session.execute(
2518
+ sqlalchemy.update(job_info_table).where(
2519
+ job_info_table.c.spot_job_id.in_(job_ids)).values(
2520
+ controller_logs_cleaned_at=logs_cleaned_at))
1643
2521
  session.commit()