skypilot-nightly 1.0.0.dev20250502__py3-none-any.whl → 1.0.0.dev20251203__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as published to one of the supported registries. It is provided for informational purposes only and reflects the package contents exactly as they appear in the public registries.
Files changed (546)
  1. sky/__init__.py +22 -6
  2. sky/adaptors/aws.py +81 -16
  3. sky/adaptors/common.py +25 -2
  4. sky/adaptors/coreweave.py +278 -0
  5. sky/adaptors/do.py +8 -2
  6. sky/adaptors/gcp.py +11 -0
  7. sky/adaptors/hyperbolic.py +8 -0
  8. sky/adaptors/ibm.py +5 -2
  9. sky/adaptors/kubernetes.py +149 -18
  10. sky/adaptors/nebius.py +173 -30
  11. sky/adaptors/primeintellect.py +1 -0
  12. sky/adaptors/runpod.py +68 -0
  13. sky/adaptors/seeweb.py +183 -0
  14. sky/adaptors/shadeform.py +89 -0
  15. sky/admin_policy.py +187 -4
  16. sky/authentication.py +179 -225
  17. sky/backends/__init__.py +4 -2
  18. sky/backends/backend.py +22 -9
  19. sky/backends/backend_utils.py +1323 -397
  20. sky/backends/cloud_vm_ray_backend.py +1749 -1029
  21. sky/backends/docker_utils.py +1 -1
  22. sky/backends/local_docker_backend.py +11 -6
  23. sky/backends/task_codegen.py +633 -0
  24. sky/backends/wheel_utils.py +55 -9
  25. sky/{clouds/service_catalog → catalog}/__init__.py +21 -19
  26. sky/{clouds/service_catalog → catalog}/aws_catalog.py +27 -8
  27. sky/{clouds/service_catalog → catalog}/azure_catalog.py +10 -7
  28. sky/{clouds/service_catalog → catalog}/common.py +90 -49
  29. sky/{clouds/service_catalog → catalog}/cudo_catalog.py +8 -5
  30. sky/{clouds/service_catalog → catalog}/data_fetchers/analyze.py +1 -1
  31. sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_aws.py +116 -80
  32. sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_cudo.py +38 -38
  33. sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_gcp.py +70 -16
  34. sky/catalog/data_fetchers/fetch_hyperbolic.py +136 -0
  35. sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_lambda_cloud.py +1 -0
  36. sky/catalog/data_fetchers/fetch_nebius.py +338 -0
  37. sky/catalog/data_fetchers/fetch_runpod.py +698 -0
  38. sky/catalog/data_fetchers/fetch_seeweb.py +329 -0
  39. sky/catalog/data_fetchers/fetch_shadeform.py +142 -0
  40. sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_vast.py +1 -1
  41. sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_vsphere.py +1 -1
  42. sky/{clouds/service_catalog → catalog}/do_catalog.py +5 -2
  43. sky/{clouds/service_catalog → catalog}/fluidstack_catalog.py +6 -3
  44. sky/{clouds/service_catalog → catalog}/gcp_catalog.py +41 -15
  45. sky/catalog/hyperbolic_catalog.py +136 -0
  46. sky/{clouds/service_catalog → catalog}/ibm_catalog.py +9 -6
  47. sky/{clouds/service_catalog → catalog}/kubernetes_catalog.py +36 -24
  48. sky/{clouds/service_catalog → catalog}/lambda_catalog.py +9 -6
  49. sky/{clouds/service_catalog → catalog}/nebius_catalog.py +9 -7
  50. sky/{clouds/service_catalog → catalog}/oci_catalog.py +9 -6
  51. sky/{clouds/service_catalog → catalog}/paperspace_catalog.py +5 -2
  52. sky/catalog/primeintellect_catalog.py +95 -0
  53. sky/{clouds/service_catalog → catalog}/runpod_catalog.py +11 -4
  54. sky/{clouds/service_catalog → catalog}/scp_catalog.py +9 -6
  55. sky/catalog/seeweb_catalog.py +184 -0
  56. sky/catalog/shadeform_catalog.py +165 -0
  57. sky/catalog/ssh_catalog.py +167 -0
  58. sky/{clouds/service_catalog → catalog}/vast_catalog.py +6 -3
  59. sky/{clouds/service_catalog → catalog}/vsphere_catalog.py +5 -2
  60. sky/check.py +533 -185
  61. sky/cli.py +5 -5975
  62. sky/client/{cli.py → cli/command.py} +2591 -1956
  63. sky/client/cli/deprecation_utils.py +99 -0
  64. sky/client/cli/flags.py +359 -0
  65. sky/client/cli/table_utils.py +322 -0
  66. sky/client/cli/utils.py +79 -0
  67. sky/client/common.py +78 -32
  68. sky/client/oauth.py +82 -0
  69. sky/client/sdk.py +1219 -319
  70. sky/client/sdk_async.py +827 -0
  71. sky/client/service_account_auth.py +47 -0
  72. sky/cloud_stores.py +82 -3
  73. sky/clouds/__init__.py +13 -0
  74. sky/clouds/aws.py +564 -164
  75. sky/clouds/azure.py +105 -83
  76. sky/clouds/cloud.py +140 -40
  77. sky/clouds/cudo.py +68 -50
  78. sky/clouds/do.py +66 -48
  79. sky/clouds/fluidstack.py +63 -44
  80. sky/clouds/gcp.py +339 -110
  81. sky/clouds/hyperbolic.py +293 -0
  82. sky/clouds/ibm.py +70 -49
  83. sky/clouds/kubernetes.py +570 -162
  84. sky/clouds/lambda_cloud.py +74 -54
  85. sky/clouds/nebius.py +210 -81
  86. sky/clouds/oci.py +88 -66
  87. sky/clouds/paperspace.py +61 -44
  88. sky/clouds/primeintellect.py +317 -0
  89. sky/clouds/runpod.py +164 -74
  90. sky/clouds/scp.py +89 -86
  91. sky/clouds/seeweb.py +477 -0
  92. sky/clouds/shadeform.py +400 -0
  93. sky/clouds/ssh.py +263 -0
  94. sky/clouds/utils/aws_utils.py +10 -4
  95. sky/clouds/utils/gcp_utils.py +87 -11
  96. sky/clouds/utils/oci_utils.py +38 -14
  97. sky/clouds/utils/scp_utils.py +231 -167
  98. sky/clouds/vast.py +99 -77
  99. sky/clouds/vsphere.py +51 -40
  100. sky/core.py +375 -173
  101. sky/dag.py +15 -0
  102. sky/dashboard/out/404.html +1 -1
  103. sky/dashboard/out/_next/static/96_E2yl3QAiIJGOYCkSpB/_buildManifest.js +1 -0
  104. sky/dashboard/out/_next/static/chunks/1141-e6aa9ab418717c59.js +11 -0
  105. sky/dashboard/out/_next/static/chunks/1272-1ef0bf0237faccdb.js +1 -0
  106. sky/dashboard/out/_next/static/chunks/1871-7e202677c42f43fe.js +6 -0
  107. sky/dashboard/out/_next/static/chunks/2260-7703229c33c5ebd5.js +1 -0
  108. sky/dashboard/out/_next/static/chunks/2350.fab69e61bac57b23.js +1 -0
  109. sky/dashboard/out/_next/static/chunks/2369.fc20f0c2c8ed9fe7.js +15 -0
  110. sky/dashboard/out/_next/static/chunks/2755.edd818326d489a1d.js +26 -0
  111. sky/dashboard/out/_next/static/chunks/3294.20a8540fe697d5ee.js +1 -0
  112. sky/dashboard/out/_next/static/chunks/3785.7e245f318f9d1121.js +1 -0
  113. sky/dashboard/out/_next/static/chunks/3800-7b45f9fbb6308557.js +1 -0
  114. sky/dashboard/out/_next/static/chunks/3850-ff4a9a69d978632b.js +1 -0
  115. sky/dashboard/out/_next/static/chunks/3937.210053269f121201.js +1 -0
  116. sky/dashboard/out/_next/static/chunks/4725.172ede95d1b21022.js +1 -0
  117. sky/dashboard/out/_next/static/chunks/4937.a2baa2df5572a276.js +15 -0
  118. sky/dashboard/out/_next/static/chunks/5739-d67458fcb1386c92.js +8 -0
  119. sky/dashboard/out/_next/static/chunks/6130-2be46d70a38f1e82.js +1 -0
  120. sky/dashboard/out/_next/static/chunks/616-3d59f75e2ccf9321.js +39 -0
  121. sky/dashboard/out/_next/static/chunks/6212-7bd06f60ba693125.js +13 -0
  122. sky/dashboard/out/_next/static/chunks/6856-8f27d1c10c98def8.js +1 -0
  123. sky/dashboard/out/_next/static/chunks/6989-01359c57e018caa4.js +1 -0
  124. sky/dashboard/out/_next/static/chunks/6990-9146207c4567fdfd.js +1 -0
  125. sky/dashboard/out/_next/static/chunks/7359-c8d04e06886000b3.js +30 -0
  126. sky/dashboard/out/_next/static/chunks/7411-b15471acd2cba716.js +41 -0
  127. sky/dashboard/out/_next/static/chunks/7615-019513abc55b3b47.js +1 -0
  128. sky/dashboard/out/_next/static/chunks/8640.5b9475a2d18c5416.js +16 -0
  129. sky/dashboard/out/_next/static/chunks/8969-452f9d5cbdd2dc73.js +1 -0
  130. sky/dashboard/out/_next/static/chunks/9025.fa408f3242e9028d.js +6 -0
  131. sky/dashboard/out/_next/static/chunks/9353-cff34f7e773b2e2b.js +1 -0
  132. sky/dashboard/out/_next/static/chunks/9360.a536cf6b1fa42355.js +31 -0
  133. sky/dashboard/out/_next/static/chunks/9847.3aaca6bb33455140.js +30 -0
  134. sky/dashboard/out/_next/static/chunks/fd9d1056-86323a29a8f7e46a.js +1 -0
  135. sky/dashboard/out/_next/static/chunks/framework-cf60a09ccd051a10.js +33 -0
  136. sky/dashboard/out/_next/static/chunks/main-app-587214043926b3cc.js +1 -0
  137. sky/dashboard/out/_next/static/chunks/main-f15ccb73239a3bf1.js +1 -0
  138. sky/dashboard/out/_next/static/chunks/pages/_app-bde01e4a2beec258.js +34 -0
  139. sky/dashboard/out/_next/static/chunks/pages/_error-c66a4e8afc46f17b.js +1 -0
  140. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]/[job]-792db96d918c98c9.js +16 -0
  141. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]-abfcac9c137aa543.js +1 -0
  142. sky/dashboard/out/_next/static/chunks/pages/clusters-ee39056f9851a3ff.js +1 -0
  143. sky/dashboard/out/_next/static/chunks/pages/config-dfb9bf07b13045f4.js +1 -0
  144. sky/dashboard/out/_next/static/chunks/pages/index-444f1804401f04ea.js +1 -0
  145. sky/dashboard/out/_next/static/chunks/pages/infra/[context]-c0b5935149902e6f.js +1 -0
  146. sky/dashboard/out/_next/static/chunks/pages/infra-aed0ea19df7cf961.js +1 -0
  147. sky/dashboard/out/_next/static/chunks/pages/jobs/[job]-d66997e2bfc837cf.js +16 -0
  148. sky/dashboard/out/_next/static/chunks/pages/jobs/pools/[pool]-9faf940b253e3e06.js +21 -0
  149. sky/dashboard/out/_next/static/chunks/pages/jobs-2072b48b617989c9.js +1 -0
  150. sky/dashboard/out/_next/static/chunks/pages/users-f42674164aa73423.js +1 -0
  151. sky/dashboard/out/_next/static/chunks/pages/volumes-b84b948ff357c43e.js +1 -0
  152. sky/dashboard/out/_next/static/chunks/pages/workspace/new-3f88a1c7e86a3f86.js +1 -0
  153. sky/dashboard/out/_next/static/chunks/pages/workspaces/[name]-84a40f8c7c627fe4.js +1 -0
  154. sky/dashboard/out/_next/static/chunks/pages/workspaces-531b2f8c4bf89f82.js +1 -0
  155. sky/dashboard/out/_next/static/chunks/webpack-64e05f17bf2cf8ce.js +1 -0
  156. sky/dashboard/out/_next/static/css/0748ce22df867032.css +3 -0
  157. sky/dashboard/out/clusters/[cluster]/[job].html +1 -1
  158. sky/dashboard/out/clusters/[cluster].html +1 -1
  159. sky/dashboard/out/clusters.html +1 -1
  160. sky/dashboard/out/config.html +1 -0
  161. sky/dashboard/out/index.html +1 -1
  162. sky/dashboard/out/infra/[context].html +1 -0
  163. sky/dashboard/out/infra.html +1 -0
  164. sky/dashboard/out/jobs/[job].html +1 -1
  165. sky/dashboard/out/jobs/pools/[pool].html +1 -0
  166. sky/dashboard/out/jobs.html +1 -1
  167. sky/dashboard/out/users.html +1 -0
  168. sky/dashboard/out/volumes.html +1 -0
  169. sky/dashboard/out/workspace/new.html +1 -0
  170. sky/dashboard/out/workspaces/[name].html +1 -0
  171. sky/dashboard/out/workspaces.html +1 -0
  172. sky/data/data_utils.py +137 -1
  173. sky/data/mounting_utils.py +269 -84
  174. sky/data/storage.py +1460 -1807
  175. sky/data/storage_utils.py +43 -57
  176. sky/exceptions.py +126 -2
  177. sky/execution.py +216 -63
  178. sky/global_user_state.py +2390 -586
  179. sky/jobs/__init__.py +7 -0
  180. sky/jobs/client/sdk.py +300 -58
  181. sky/jobs/client/sdk_async.py +161 -0
  182. sky/jobs/constants.py +15 -8
  183. sky/jobs/controller.py +848 -275
  184. sky/jobs/file_content_utils.py +128 -0
  185. sky/jobs/log_gc.py +193 -0
  186. sky/jobs/recovery_strategy.py +402 -152
  187. sky/jobs/scheduler.py +314 -189
  188. sky/jobs/server/core.py +836 -255
  189. sky/jobs/server/server.py +156 -115
  190. sky/jobs/server/utils.py +136 -0
  191. sky/jobs/state.py +2109 -706
  192. sky/jobs/utils.py +1306 -215
  193. sky/logs/__init__.py +21 -0
  194. sky/logs/agent.py +108 -0
  195. sky/logs/aws.py +243 -0
  196. sky/logs/gcp.py +91 -0
  197. sky/metrics/__init__.py +0 -0
  198. sky/metrics/utils.py +453 -0
  199. sky/models.py +78 -1
  200. sky/optimizer.py +164 -70
  201. sky/provision/__init__.py +90 -4
  202. sky/provision/aws/config.py +147 -26
  203. sky/provision/aws/instance.py +136 -50
  204. sky/provision/azure/instance.py +11 -6
  205. sky/provision/common.py +13 -1
  206. sky/provision/cudo/cudo_machine_type.py +1 -1
  207. sky/provision/cudo/cudo_utils.py +14 -8
  208. sky/provision/cudo/cudo_wrapper.py +72 -71
  209. sky/provision/cudo/instance.py +10 -6
  210. sky/provision/do/instance.py +10 -6
  211. sky/provision/do/utils.py +4 -3
  212. sky/provision/docker_utils.py +140 -33
  213. sky/provision/fluidstack/instance.py +13 -8
  214. sky/provision/gcp/__init__.py +1 -0
  215. sky/provision/gcp/config.py +301 -19
  216. sky/provision/gcp/constants.py +218 -0
  217. sky/provision/gcp/instance.py +36 -8
  218. sky/provision/gcp/instance_utils.py +18 -4
  219. sky/provision/gcp/volume_utils.py +247 -0
  220. sky/provision/hyperbolic/__init__.py +12 -0
  221. sky/provision/hyperbolic/config.py +10 -0
  222. sky/provision/hyperbolic/instance.py +437 -0
  223. sky/provision/hyperbolic/utils.py +373 -0
  224. sky/provision/instance_setup.py +101 -20
  225. sky/provision/kubernetes/__init__.py +5 -0
  226. sky/provision/kubernetes/config.py +9 -52
  227. sky/provision/kubernetes/constants.py +17 -0
  228. sky/provision/kubernetes/instance.py +919 -280
  229. sky/provision/kubernetes/manifests/fusermount-server-daemonset.yaml +1 -2
  230. sky/provision/kubernetes/network.py +27 -17
  231. sky/provision/kubernetes/network_utils.py +44 -43
  232. sky/provision/kubernetes/utils.py +1221 -534
  233. sky/provision/kubernetes/volume.py +343 -0
  234. sky/provision/lambda_cloud/instance.py +22 -16
  235. sky/provision/nebius/constants.py +50 -0
  236. sky/provision/nebius/instance.py +19 -6
  237. sky/provision/nebius/utils.py +237 -137
  238. sky/provision/oci/instance.py +10 -5
  239. sky/provision/paperspace/instance.py +10 -7
  240. sky/provision/paperspace/utils.py +1 -1
  241. sky/provision/primeintellect/__init__.py +10 -0
  242. sky/provision/primeintellect/config.py +11 -0
  243. sky/provision/primeintellect/instance.py +454 -0
  244. sky/provision/primeintellect/utils.py +398 -0
  245. sky/provision/provisioner.py +117 -36
  246. sky/provision/runpod/__init__.py +5 -0
  247. sky/provision/runpod/instance.py +27 -6
  248. sky/provision/runpod/utils.py +51 -18
  249. sky/provision/runpod/volume.py +214 -0
  250. sky/provision/scp/__init__.py +15 -0
  251. sky/provision/scp/config.py +93 -0
  252. sky/provision/scp/instance.py +707 -0
  253. sky/provision/seeweb/__init__.py +11 -0
  254. sky/provision/seeweb/config.py +13 -0
  255. sky/provision/seeweb/instance.py +812 -0
  256. sky/provision/shadeform/__init__.py +11 -0
  257. sky/provision/shadeform/config.py +12 -0
  258. sky/provision/shadeform/instance.py +351 -0
  259. sky/provision/shadeform/shadeform_utils.py +83 -0
  260. sky/provision/ssh/__init__.py +18 -0
  261. sky/provision/vast/instance.py +13 -8
  262. sky/provision/vast/utils.py +10 -7
  263. sky/provision/volume.py +164 -0
  264. sky/provision/vsphere/common/ssl_helper.py +1 -1
  265. sky/provision/vsphere/common/vapiconnect.py +2 -1
  266. sky/provision/vsphere/common/vim_utils.py +4 -4
  267. sky/provision/vsphere/instance.py +15 -10
  268. sky/provision/vsphere/vsphere_utils.py +17 -20
  269. sky/py.typed +0 -0
  270. sky/resources.py +845 -119
  271. sky/schemas/__init__.py +0 -0
  272. sky/schemas/api/__init__.py +0 -0
  273. sky/schemas/api/responses.py +227 -0
  274. sky/schemas/db/README +4 -0
  275. sky/schemas/db/env.py +90 -0
  276. sky/schemas/db/global_user_state/001_initial_schema.py +124 -0
  277. sky/schemas/db/global_user_state/002_add_workspace_to_cluster_history.py +35 -0
  278. sky/schemas/db/global_user_state/003_fix_initial_revision.py +61 -0
  279. sky/schemas/db/global_user_state/004_is_managed.py +34 -0
  280. sky/schemas/db/global_user_state/005_cluster_event.py +32 -0
  281. sky/schemas/db/global_user_state/006_provision_log.py +41 -0
  282. sky/schemas/db/global_user_state/007_cluster_event_request_id.py +34 -0
  283. sky/schemas/db/global_user_state/008_skylet_ssh_tunnel_metadata.py +34 -0
  284. sky/schemas/db/global_user_state/009_last_activity_and_launched_at.py +89 -0
  285. sky/schemas/db/global_user_state/010_save_ssh_key.py +66 -0
  286. sky/schemas/db/global_user_state/011_is_ephemeral.py +34 -0
  287. sky/schemas/db/kv_cache/001_initial_schema.py +29 -0
  288. sky/schemas/db/script.py.mako +28 -0
  289. sky/schemas/db/serve_state/001_initial_schema.py +67 -0
  290. sky/schemas/db/serve_state/002_yaml_content.py +34 -0
  291. sky/schemas/db/skypilot_config/001_initial_schema.py +30 -0
  292. sky/schemas/db/spot_jobs/001_initial_schema.py +97 -0
  293. sky/schemas/db/spot_jobs/002_cluster_pool.py +42 -0
  294. sky/schemas/db/spot_jobs/003_pool_hash.py +34 -0
  295. sky/schemas/db/spot_jobs/004_job_file_contents.py +42 -0
  296. sky/schemas/db/spot_jobs/005_logs_gc.py +38 -0
  297. sky/schemas/db/spot_jobs/006_controller_pid_started_at.py +34 -0
  298. sky/schemas/db/spot_jobs/007_config_file_content.py +34 -0
  299. sky/schemas/generated/__init__.py +0 -0
  300. sky/schemas/generated/autostopv1_pb2.py +36 -0
  301. sky/schemas/generated/autostopv1_pb2.pyi +43 -0
  302. sky/schemas/generated/autostopv1_pb2_grpc.py +146 -0
  303. sky/schemas/generated/jobsv1_pb2.py +86 -0
  304. sky/schemas/generated/jobsv1_pb2.pyi +254 -0
  305. sky/schemas/generated/jobsv1_pb2_grpc.py +542 -0
  306. sky/schemas/generated/managed_jobsv1_pb2.py +76 -0
  307. sky/schemas/generated/managed_jobsv1_pb2.pyi +278 -0
  308. sky/schemas/generated/managed_jobsv1_pb2_grpc.py +278 -0
  309. sky/schemas/generated/servev1_pb2.py +58 -0
  310. sky/schemas/generated/servev1_pb2.pyi +115 -0
  311. sky/schemas/generated/servev1_pb2_grpc.py +322 -0
  312. sky/serve/autoscalers.py +357 -5
  313. sky/serve/client/impl.py +310 -0
  314. sky/serve/client/sdk.py +47 -139
  315. sky/serve/client/sdk_async.py +130 -0
  316. sky/serve/constants.py +12 -9
  317. sky/serve/controller.py +68 -17
  318. sky/serve/load_balancer.py +106 -60
  319. sky/serve/load_balancing_policies.py +116 -2
  320. sky/serve/replica_managers.py +434 -249
  321. sky/serve/serve_rpc_utils.py +179 -0
  322. sky/serve/serve_state.py +569 -257
  323. sky/serve/serve_utils.py +775 -265
  324. sky/serve/server/core.py +66 -711
  325. sky/serve/server/impl.py +1093 -0
  326. sky/serve/server/server.py +21 -18
  327. sky/serve/service.py +192 -89
  328. sky/serve/service_spec.py +144 -20
  329. sky/serve/spot_placer.py +3 -0
  330. sky/server/auth/__init__.py +0 -0
  331. sky/server/auth/authn.py +50 -0
  332. sky/server/auth/loopback.py +38 -0
  333. sky/server/auth/oauth2_proxy.py +202 -0
  334. sky/server/common.py +478 -182
  335. sky/server/config.py +85 -23
  336. sky/server/constants.py +44 -6
  337. sky/server/daemons.py +295 -0
  338. sky/server/html/token_page.html +185 -0
  339. sky/server/metrics.py +160 -0
  340. sky/server/middleware_utils.py +166 -0
  341. sky/server/requests/executor.py +558 -138
  342. sky/server/requests/payloads.py +364 -24
  343. sky/server/requests/preconditions.py +21 -17
  344. sky/server/requests/process.py +112 -29
  345. sky/server/requests/request_names.py +121 -0
  346. sky/server/requests/requests.py +822 -226
  347. sky/server/requests/serializers/decoders.py +82 -31
  348. sky/server/requests/serializers/encoders.py +140 -22
  349. sky/server/requests/threads.py +117 -0
  350. sky/server/rest.py +455 -0
  351. sky/server/server.py +1309 -285
  352. sky/server/state.py +20 -0
  353. sky/server/stream_utils.py +327 -61
  354. sky/server/uvicorn.py +217 -3
  355. sky/server/versions.py +270 -0
  356. sky/setup_files/MANIFEST.in +11 -1
  357. sky/setup_files/alembic.ini +160 -0
  358. sky/setup_files/dependencies.py +139 -31
  359. sky/setup_files/setup.py +44 -42
  360. sky/sky_logging.py +114 -7
  361. sky/skylet/attempt_skylet.py +106 -24
  362. sky/skylet/autostop_lib.py +129 -8
  363. sky/skylet/configs.py +29 -20
  364. sky/skylet/constants.py +216 -25
  365. sky/skylet/events.py +101 -21
  366. sky/skylet/job_lib.py +345 -164
  367. sky/skylet/log_lib.py +297 -18
  368. sky/skylet/log_lib.pyi +44 -1
  369. sky/skylet/providers/ibm/node_provider.py +12 -8
  370. sky/skylet/providers/ibm/vpc_provider.py +13 -12
  371. sky/skylet/ray_patches/__init__.py +17 -3
  372. sky/skylet/ray_patches/autoscaler.py.diff +18 -0
  373. sky/skylet/ray_patches/cli.py.diff +19 -0
  374. sky/skylet/ray_patches/command_runner.py.diff +17 -0
  375. sky/skylet/ray_patches/log_monitor.py.diff +20 -0
  376. sky/skylet/ray_patches/resource_demand_scheduler.py.diff +32 -0
  377. sky/skylet/ray_patches/updater.py.diff +18 -0
  378. sky/skylet/ray_patches/worker.py.diff +41 -0
  379. sky/skylet/runtime_utils.py +21 -0
  380. sky/skylet/services.py +568 -0
  381. sky/skylet/skylet.py +72 -4
  382. sky/skylet/subprocess_daemon.py +104 -29
  383. sky/skypilot_config.py +506 -99
  384. sky/ssh_node_pools/__init__.py +1 -0
  385. sky/ssh_node_pools/core.py +135 -0
  386. sky/ssh_node_pools/server.py +233 -0
  387. sky/task.py +685 -163
  388. sky/templates/aws-ray.yml.j2 +11 -3
  389. sky/templates/azure-ray.yml.j2 +2 -1
  390. sky/templates/cudo-ray.yml.j2 +1 -0
  391. sky/templates/do-ray.yml.j2 +2 -1
  392. sky/templates/fluidstack-ray.yml.j2 +1 -0
  393. sky/templates/gcp-ray.yml.j2 +62 -1
  394. sky/templates/hyperbolic-ray.yml.j2 +68 -0
  395. sky/templates/ibm-ray.yml.j2 +2 -1
  396. sky/templates/jobs-controller.yaml.j2 +27 -24
  397. sky/templates/kubernetes-loadbalancer.yml.j2 +2 -0
  398. sky/templates/kubernetes-ray.yml.j2 +611 -50
  399. sky/templates/lambda-ray.yml.j2 +2 -1
  400. sky/templates/nebius-ray.yml.j2 +34 -12
  401. sky/templates/oci-ray.yml.j2 +1 -0
  402. sky/templates/paperspace-ray.yml.j2 +2 -1
  403. sky/templates/primeintellect-ray.yml.j2 +72 -0
  404. sky/templates/runpod-ray.yml.j2 +10 -1
  405. sky/templates/scp-ray.yml.j2 +4 -50
  406. sky/templates/seeweb-ray.yml.j2 +171 -0
  407. sky/templates/shadeform-ray.yml.j2 +73 -0
  408. sky/templates/sky-serve-controller.yaml.j2 +22 -2
  409. sky/templates/vast-ray.yml.j2 +1 -0
  410. sky/templates/vsphere-ray.yml.j2 +1 -0
  411. sky/templates/websocket_proxy.py +212 -37
  412. sky/usage/usage_lib.py +31 -15
  413. sky/users/__init__.py +0 -0
  414. sky/users/model.conf +15 -0
  415. sky/users/permission.py +397 -0
  416. sky/users/rbac.py +121 -0
  417. sky/users/server.py +720 -0
  418. sky/users/token_service.py +218 -0
  419. sky/utils/accelerator_registry.py +35 -5
  420. sky/utils/admin_policy_utils.py +84 -38
  421. sky/utils/annotations.py +38 -5
  422. sky/utils/asyncio_utils.py +78 -0
  423. sky/utils/atomic.py +1 -1
  424. sky/utils/auth_utils.py +153 -0
  425. sky/utils/benchmark_utils.py +60 -0
  426. sky/utils/cli_utils/status_utils.py +159 -86
  427. sky/utils/cluster_utils.py +31 -9
  428. sky/utils/command_runner.py +354 -68
  429. sky/utils/command_runner.pyi +93 -3
  430. sky/utils/common.py +35 -8
  431. sky/utils/common_utils.py +314 -91
  432. sky/utils/config_utils.py +74 -5
  433. sky/utils/context.py +403 -0
  434. sky/utils/context_utils.py +242 -0
  435. sky/utils/controller_utils.py +383 -89
  436. sky/utils/dag_utils.py +31 -12
  437. sky/utils/db/__init__.py +0 -0
  438. sky/utils/db/db_utils.py +485 -0
  439. sky/utils/db/kv_cache.py +149 -0
  440. sky/utils/db/migration_utils.py +137 -0
  441. sky/utils/directory_utils.py +12 -0
  442. sky/utils/env_options.py +13 -0
  443. sky/utils/git.py +567 -0
  444. sky/utils/git_clone.sh +460 -0
  445. sky/utils/infra_utils.py +195 -0
  446. sky/utils/kubernetes/cleanup-tunnel.sh +62 -0
  447. sky/utils/kubernetes/config_map_utils.py +133 -0
  448. sky/utils/kubernetes/create_cluster.sh +15 -29
  449. sky/utils/kubernetes/delete_cluster.sh +10 -7
  450. sky/utils/kubernetes/deploy_ssh_node_pools.py +1177 -0
  451. sky/utils/kubernetes/exec_kubeconfig_converter.py +22 -31
  452. sky/utils/kubernetes/generate_kind_config.py +6 -66
  453. sky/utils/kubernetes/generate_kubeconfig.sh +4 -1
  454. sky/utils/kubernetes/gpu_labeler.py +18 -8
  455. sky/utils/kubernetes/k8s_gpu_labeler_job.yaml +2 -1
  456. sky/utils/kubernetes/k8s_gpu_labeler_setup.yaml +16 -16
  457. sky/utils/kubernetes/kubernetes_deploy_utils.py +284 -114
  458. sky/utils/kubernetes/rsync_helper.sh +11 -3
  459. sky/utils/kubernetes/ssh-tunnel.sh +379 -0
  460. sky/utils/kubernetes/ssh_utils.py +221 -0
  461. sky/utils/kubernetes_enums.py +8 -15
  462. sky/utils/lock_events.py +94 -0
  463. sky/utils/locks.py +416 -0
  464. sky/utils/log_utils.py +82 -107
  465. sky/utils/perf_utils.py +22 -0
  466. sky/utils/resource_checker.py +298 -0
  467. sky/utils/resources_utils.py +249 -32
  468. sky/utils/rich_utils.py +217 -39
  469. sky/utils/schemas.py +955 -160
  470. sky/utils/serialize_utils.py +16 -0
  471. sky/utils/status_lib.py +10 -0
  472. sky/utils/subprocess_utils.py +29 -15
  473. sky/utils/tempstore.py +70 -0
  474. sky/utils/thread_utils.py +91 -0
  475. sky/utils/timeline.py +26 -53
  476. sky/utils/ux_utils.py +84 -15
  477. sky/utils/validator.py +11 -1
  478. sky/utils/volume.py +165 -0
  479. sky/utils/yaml_utils.py +111 -0
  480. sky/volumes/__init__.py +13 -0
  481. sky/volumes/client/__init__.py +0 -0
  482. sky/volumes/client/sdk.py +150 -0
  483. sky/volumes/server/__init__.py +0 -0
  484. sky/volumes/server/core.py +270 -0
  485. sky/volumes/server/server.py +124 -0
  486. sky/volumes/volume.py +215 -0
  487. sky/workspaces/__init__.py +0 -0
  488. sky/workspaces/core.py +655 -0
  489. sky/workspaces/server.py +101 -0
  490. sky/workspaces/utils.py +56 -0
  491. sky_templates/README.md +3 -0
  492. sky_templates/__init__.py +3 -0
  493. sky_templates/ray/__init__.py +0 -0
  494. sky_templates/ray/start_cluster +183 -0
  495. sky_templates/ray/stop_cluster +75 -0
  496. skypilot_nightly-1.0.0.dev20251203.dist-info/METADATA +676 -0
  497. skypilot_nightly-1.0.0.dev20251203.dist-info/RECORD +611 -0
  498. {skypilot_nightly-1.0.0.dev20250502.dist-info → skypilot_nightly-1.0.0.dev20251203.dist-info}/WHEEL +1 -1
  499. skypilot_nightly-1.0.0.dev20251203.dist-info/top_level.txt +2 -0
  500. sky/benchmark/benchmark_state.py +0 -256
  501. sky/benchmark/benchmark_utils.py +0 -641
  502. sky/clouds/service_catalog/constants.py +0 -7
  503. sky/dashboard/out/_next/static/GWvVBSCS7FmUiVmjaL1a7/_buildManifest.js +0 -1
  504. sky/dashboard/out/_next/static/chunks/236-2db3ee3fba33dd9e.js +0 -6
  505. sky/dashboard/out/_next/static/chunks/312-c3c8845990db8ffc.js +0 -15
  506. sky/dashboard/out/_next/static/chunks/37-0a572fe0dbb89c4d.js +0 -6
  507. sky/dashboard/out/_next/static/chunks/678-206dddca808e6d16.js +0 -59
  508. sky/dashboard/out/_next/static/chunks/845-9e60713e0c441abc.js +0 -1
  509. sky/dashboard/out/_next/static/chunks/979-7bf73a4c7cea0f5c.js +0 -1
  510. sky/dashboard/out/_next/static/chunks/fd9d1056-2821b0f0cabcd8bd.js +0 -1
  511. sky/dashboard/out/_next/static/chunks/framework-87d061ee6ed71b28.js +0 -33
  512. sky/dashboard/out/_next/static/chunks/main-app-241eb28595532291.js +0 -1
  513. sky/dashboard/out/_next/static/chunks/main-e0e2335212e72357.js +0 -1
  514. sky/dashboard/out/_next/static/chunks/pages/_app-e6b013bc3f77ad60.js +0 -1
  515. sky/dashboard/out/_next/static/chunks/pages/_error-1be831200e60c5c0.js +0 -1
  516. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]/[job]-6ac338bc2239cb45.js +0 -1
  517. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]-f383db7389368ea7.js +0 -1
  518. sky/dashboard/out/_next/static/chunks/pages/clusters-a93b93e10b8b074e.js +0 -1
  519. sky/dashboard/out/_next/static/chunks/pages/index-f9f039532ca8cbc4.js +0 -1
  520. sky/dashboard/out/_next/static/chunks/pages/jobs/[job]-1c519e1afc523dc9.js +0 -1
  521. sky/dashboard/out/_next/static/chunks/pages/jobs-a75029b67aab6a2e.js +0 -1
  522. sky/dashboard/out/_next/static/chunks/webpack-830f59b8404e96b8.js +0 -1
  523. sky/dashboard/out/_next/static/css/c6933bbb2ce7f4dd.css +0 -3
  524. sky/jobs/dashboard/dashboard.py +0 -223
  525. sky/jobs/dashboard/static/favicon.ico +0 -0
  526. sky/jobs/dashboard/templates/index.html +0 -831
  527. sky/jobs/server/dashboard_utils.py +0 -69
  528. sky/skylet/providers/scp/__init__.py +0 -2
  529. sky/skylet/providers/scp/config.py +0 -149
  530. sky/skylet/providers/scp/node_provider.py +0 -578
  531. sky/templates/kubernetes-ssh-jump.yml.j2 +0 -94
  532. sky/utils/db_utils.py +0 -100
  533. sky/utils/kubernetes/deploy_remote_cluster.sh +0 -308
  534. sky/utils/kubernetes/ssh_jump_lifecycle_manager.py +0 -191
  535. skypilot_nightly-1.0.0.dev20250502.dist-info/METADATA +0 -361
  536. skypilot_nightly-1.0.0.dev20250502.dist-info/RECORD +0 -396
  537. skypilot_nightly-1.0.0.dev20250502.dist-info/top_level.txt +0 -1
  538. /sky/{clouds/service_catalog → catalog}/config.py +0 -0
  539. /sky/{benchmark → catalog/data_fetchers}/__init__.py +0 -0
  540. /sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_azure.py +0 -0
  541. /sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_fluidstack.py +0 -0
  542. /sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_ibm.py +0 -0
  543. /sky/{clouds/service_catalog/data_fetchers → client/cli}/__init__.py +0 -0
  544. /sky/dashboard/out/_next/static/{GWvVBSCS7FmUiVmjaL1a7 → 96_E2yl3QAiIJGOYCkSpB}/_ssgManifest.js +0 -0
  545. {skypilot_nightly-1.0.0.dev20250502.dist-info → skypilot_nightly-1.0.0.dev20251203.dist-info}/entry_points.txt +0 -0
  546. {skypilot_nightly-1.0.0.dev20250502.dist-info → skypilot_nightly-1.0.0.dev20251203.dist-info}/licenses/LICENSE +0 -0
sky/global_user_state.py CHANGED
@@ -6,22 +6,40 @@ Concepts:
 - Cluster handle: (non-user facing) an opaque backend handle for us to
   interact with a cluster.
 """
+import asyncio
+import enum
+import functools
 import json
 import os
-import pathlib
 import pickle
-import sqlite3
+import re
+import threading
 import time
 import typing
 from typing import Any, Dict, List, Optional, Set, Tuple
 import uuid
 
+import sqlalchemy
+from sqlalchemy import exc as sqlalchemy_exc
+from sqlalchemy import orm
+from sqlalchemy.dialects import postgresql
+from sqlalchemy.dialects import sqlite
+from sqlalchemy.ext import asyncio as sql_async
+from sqlalchemy.ext import declarative
+
 from sky import models
 from sky import sky_logging
+from sky import skypilot_config
+from sky.metrics import utils as metrics_lib
+from sky.skylet import constants
+from sky.utils import annotations
 from sky.utils import common_utils
-from sky.utils import db_utils
+from sky.utils import context_utils
 from sky.utils import registry
 from sky.utils import status_lib
+from sky.utils import yaml_utils
+from sky.utils.db import db_utils
+from sky.utils.db import migration_utils
 
 if typing.TYPE_CHECKING:
     from sky import backends
@@ -32,171 +50,594 @@ if typing.TYPE_CHECKING:
32
50
  logger = sky_logging.init_logger(__name__)
33
51
 
34
52
  _ENABLED_CLOUDS_KEY_PREFIX = 'enabled_clouds_'
35
-
36
- _DB_PATH = os.path.expanduser('~/.sky/state.db')
37
- pathlib.Path(_DB_PATH).parents[0].mkdir(parents=True, exist_ok=True)
38
-
39
-
40
- def create_table(cursor, conn):
53
+ _ALLOWED_CLOUDS_KEY_PREFIX = 'allowed_clouds_'
54
+
55
+ _SQLALCHEMY_ENGINE: Optional[sqlalchemy.engine.Engine] = None
56
+ _SQLALCHEMY_ENGINE_ASYNC: Optional[sql_async.AsyncEngine] = None
57
+ _SQLALCHEMY_ENGINE_LOCK = threading.Lock()
58
+
59
+ DEFAULT_CLUSTER_EVENT_RETENTION_HOURS = 24.0
60
+ DEBUG_CLUSTER_EVENT_RETENTION_HOURS = 30 * 24.0
61
+ MIN_CLUSTER_EVENT_DAEMON_INTERVAL_SECONDS = 3600
62
+
63
+ _UNIQUE_CONSTRAINT_FAILED_ERROR_MSGS = [
64
+ # sqlite
65
+ 'UNIQUE constraint failed',
66
+ # postgres
67
+ 'duplicate key value violates unique constraint',
68
+ ]
69
+
70
+ Base = declarative.declarative_base()
71
+
72
+ config_table = sqlalchemy.Table(
73
+ 'config',
74
+ Base.metadata,
75
+ sqlalchemy.Column('key', sqlalchemy.Text, primary_key=True),
76
+ sqlalchemy.Column('value', sqlalchemy.Text),
77
+ )
78
+
79
+ user_table = sqlalchemy.Table(
80
+ 'users',
81
+ Base.metadata,
82
+ sqlalchemy.Column('id', sqlalchemy.Text, primary_key=True),
83
+ sqlalchemy.Column('name', sqlalchemy.Text),
84
+ sqlalchemy.Column('password', sqlalchemy.Text),
85
+ sqlalchemy.Column('created_at', sqlalchemy.Integer),
86
+ )
87
+
88
+ cluster_table = sqlalchemy.Table(
89
+ 'clusters',
90
+ Base.metadata,
91
+ sqlalchemy.Column('name', sqlalchemy.Text, primary_key=True),
92
+ sqlalchemy.Column('launched_at', sqlalchemy.Integer),
93
+ sqlalchemy.Column('handle', sqlalchemy.LargeBinary),
94
+ sqlalchemy.Column('last_use', sqlalchemy.Text),
95
+ sqlalchemy.Column('status', sqlalchemy.Text),
96
+ sqlalchemy.Column('autostop', sqlalchemy.Integer, server_default='-1'),
97
+ sqlalchemy.Column('to_down', sqlalchemy.Integer, server_default='0'),
98
+ sqlalchemy.Column('metadata', sqlalchemy.Text, server_default='{}'),
99
+ sqlalchemy.Column('owner', sqlalchemy.Text, server_default=None),
100
+ sqlalchemy.Column('cluster_hash', sqlalchemy.Text, server_default=None),
101
+ sqlalchemy.Column('storage_mounts_metadata',
102
+ sqlalchemy.LargeBinary,
103
+ server_default=None),
104
+ sqlalchemy.Column('cluster_ever_up', sqlalchemy.Integer,
105
+ server_default='0'),
106
+ sqlalchemy.Column('status_updated_at',
107
+ sqlalchemy.Integer,
108
+ server_default=None),
109
+ sqlalchemy.Column('config_hash', sqlalchemy.Text, server_default=None),
110
+ sqlalchemy.Column('user_hash', sqlalchemy.Text, server_default=None),
111
+ sqlalchemy.Column('workspace',
112
+ sqlalchemy.Text,
113
+ server_default=constants.SKYPILOT_DEFAULT_WORKSPACE),
114
+ sqlalchemy.Column('last_creation_yaml',
115
+ sqlalchemy.Text,
116
+ server_default=None),
117
+ sqlalchemy.Column('last_creation_command',
118
+ sqlalchemy.Text,
119
+ server_default=None),
120
+ sqlalchemy.Column('is_managed', sqlalchemy.Integer, server_default='0'),
121
+ sqlalchemy.Column('provision_log_path',
122
+ sqlalchemy.Text,
123
+ server_default=None),
124
+ sqlalchemy.Column('skylet_ssh_tunnel_metadata',
125
+ sqlalchemy.LargeBinary,
126
+ server_default=None),
127
+ )
128
+
129
+ storage_table = sqlalchemy.Table(
130
+ 'storage',
131
+ Base.metadata,
132
+ sqlalchemy.Column('name', sqlalchemy.Text, primary_key=True),
133
+ sqlalchemy.Column('launched_at', sqlalchemy.Integer),
134
+ sqlalchemy.Column('handle', sqlalchemy.LargeBinary),
135
+ sqlalchemy.Column('last_use', sqlalchemy.Text),
136
+ sqlalchemy.Column('status', sqlalchemy.Text),
137
+ )
138
+
139
+ volume_table = sqlalchemy.Table(
140
+ 'volumes',
141
+ Base.metadata,
142
+ sqlalchemy.Column('name', sqlalchemy.Text, primary_key=True),
143
+ sqlalchemy.Column('launched_at', sqlalchemy.Integer),
144
+ sqlalchemy.Column('handle', sqlalchemy.LargeBinary),
145
+ sqlalchemy.Column('user_hash', sqlalchemy.Text, server_default=None),
146
+ sqlalchemy.Column('workspace',
147
+ sqlalchemy.Text,
148
+ server_default=constants.SKYPILOT_DEFAULT_WORKSPACE),
149
+ sqlalchemy.Column('last_attached_at',
150
+ sqlalchemy.Integer,
151
+ server_default=None),
152
+ sqlalchemy.Column('last_use', sqlalchemy.Text),
153
+ sqlalchemy.Column('status', sqlalchemy.Text),
154
+ sqlalchemy.Column('is_ephemeral', sqlalchemy.Integer, server_default='0'),
155
+ )
156
+
157
+ # Table for Cluster History
158
+ # usage_intervals: List[Tuple[int, int]]
159
+ # Specifies start and end timestamps of cluster.
160
+ # When the last end time is None, the cluster is still UP.
161
+ # Example: [(start1, end1), (start2, end2), (start3, None)]
162
+
163
+ # requested_resources: Set[resource_lib.Resource]
164
+ # Requested resources fetched from task that user specifies.
165
+
166
+ # launched_resources: Optional[resources_lib.Resources]
167
+ # Actual launched resources fetched from handle for cluster.
168
+
169
+ # num_nodes: Optional[int] number of nodes launched.
170
+ cluster_history_table = sqlalchemy.Table(
171
+ 'cluster_history',
172
+ Base.metadata,
173
+ sqlalchemy.Column('cluster_hash', sqlalchemy.Text, primary_key=True),
174
+ sqlalchemy.Column('name', sqlalchemy.Text),
175
+ sqlalchemy.Column('num_nodes', sqlalchemy.Integer),
176
+ sqlalchemy.Column('requested_resources', sqlalchemy.LargeBinary),
177
+ sqlalchemy.Column('launched_resources', sqlalchemy.LargeBinary),
178
+ sqlalchemy.Column('usage_intervals', sqlalchemy.LargeBinary),
179
+ sqlalchemy.Column('user_hash', sqlalchemy.Text),
180
+ sqlalchemy.Column('last_creation_yaml',
181
+ sqlalchemy.Text,
182
+ server_default=None),
183
+ sqlalchemy.Column('last_creation_command',
184
+ sqlalchemy.Text,
185
+ server_default=None),
186
+ sqlalchemy.Column('workspace', sqlalchemy.Text, server_default=None),
187
+ sqlalchemy.Column('provision_log_path',
188
+ sqlalchemy.Text,
189
+ server_default=None),
190
+ sqlalchemy.Column('last_activity_time',
191
+ sqlalchemy.Integer,
192
+ server_default=None,
193
+ index=True),
194
+ sqlalchemy.Column('launched_at',
195
+ sqlalchemy.Integer,
196
+ server_default=None,
197
+ index=True),
198
+ )
199
+
200
+
201
+ class ClusterEventType(enum.Enum):
202
+ """Type of cluster event."""
203
+ DEBUG = 'DEBUG'
204
+ """Used to denote events that are not related to cluster status."""
205
+
206
+ STATUS_CHANGE = 'STATUS_CHANGE'
207
+ """Used to denote events that modify cluster status."""
208
+
209
+
210
+ # Table for cluster status change events.
211
+ # starting_status: Status of the cluster at the start of the event.
212
+ # ending_status: Status of the cluster at the end of the event.
213
+ # reason: Reason for the transition.
214
+ # transitioned_at: Timestamp of the transition.
215
+ cluster_event_table = sqlalchemy.Table(
216
+ 'cluster_events',
217
+ Base.metadata,
218
+ sqlalchemy.Column('cluster_hash', sqlalchemy.Text, primary_key=True),
219
+ sqlalchemy.Column('name', sqlalchemy.Text),
220
+ sqlalchemy.Column('starting_status', sqlalchemy.Text),
221
+ sqlalchemy.Column('ending_status', sqlalchemy.Text),
222
+ sqlalchemy.Column('reason', sqlalchemy.Text, primary_key=True),
223
+ sqlalchemy.Column('transitioned_at', sqlalchemy.Integer, primary_key=True),
224
+ sqlalchemy.Column('type', sqlalchemy.Text),
225
+ sqlalchemy.Column('request_id', sqlalchemy.Text, server_default=None),
226
+ )
227
+
228
+ ssh_key_table = sqlalchemy.Table(
229
+ 'ssh_key',
230
+ Base.metadata,
231
+ sqlalchemy.Column('user_hash', sqlalchemy.Text, primary_key=True),
232
+ sqlalchemy.Column('ssh_public_key', sqlalchemy.Text),
233
+ sqlalchemy.Column('ssh_private_key', sqlalchemy.Text),
234
+ )
235
+
236
+ service_account_token_table = sqlalchemy.Table(
237
+ 'service_account_tokens',
238
+ Base.metadata,
239
+ sqlalchemy.Column('token_id', sqlalchemy.Text, primary_key=True),
240
+ sqlalchemy.Column('token_name', sqlalchemy.Text),
241
+ sqlalchemy.Column('token_hash', sqlalchemy.Text),
242
+ sqlalchemy.Column('created_at', sqlalchemy.Integer),
243
+ sqlalchemy.Column('last_used_at', sqlalchemy.Integer, server_default=None),
244
+ sqlalchemy.Column('expires_at', sqlalchemy.Integer, server_default=None),
245
+ sqlalchemy.Column('creator_user_hash',
246
+ sqlalchemy.Text), # Who created this token
247
+ sqlalchemy.Column('service_account_user_id',
248
+ sqlalchemy.Text), # Service account's own user ID
249
+ )
250
+
251
+ cluster_yaml_table = sqlalchemy.Table(
252
+ 'cluster_yaml',
253
+ Base.metadata,
254
+ sqlalchemy.Column('cluster_name', sqlalchemy.Text, primary_key=True),
255
+ sqlalchemy.Column('yaml', sqlalchemy.Text),
256
+ )
257
+
258
+ system_config_table = sqlalchemy.Table(
259
+ 'system_config',
260
+ Base.metadata,
261
+ sqlalchemy.Column('config_key', sqlalchemy.Text, primary_key=True),
262
+ sqlalchemy.Column('config_value', sqlalchemy.Text),
263
+ sqlalchemy.Column('created_at', sqlalchemy.Integer),
264
+ sqlalchemy.Column('updated_at', sqlalchemy.Integer),
265
+ )
266
+
267
+
268
+ def _glob_to_similar(glob_pattern):
269
+ """Converts a glob pattern to a PostgreSQL LIKE pattern."""
270
+
271
+ # Escape special LIKE characters that are not special in glob
272
+ glob_pattern = glob_pattern.replace('%', '\\%').replace('_', '\\_')
273
+
274
+ # Convert glob wildcards to LIKE wildcards
275
+ like_pattern = glob_pattern.replace('*', '%').replace('?', '_')
276
+
277
+ # Handle character classes, including negation
278
+ def replace_char_class(match):
279
+ group = match.group(0)
280
+ if group.startswith('[!'):
281
+ return '[^' + group[2:-1] + ']'
282
+ return group
283
+
284
+ like_pattern = re.sub(r'\[(!)?.*?\]', replace_char_class, like_pattern)
285
+ return like_pattern
286
+
287
+
288
+ def create_table(engine: sqlalchemy.engine.Engine):
41
289
  # Enable WAL mode to avoid locking issues.
42
290
  # See: issue #1441 and PR #1509
43
291
  # https://github.com/microsoft/WSL/issues/2395
44
292
  # TODO(romilb): We do not enable WAL for WSL because of known issue in WSL.
45
293
  # This may cause the database locked problem from WSL issue #1441.
46
- if not common_utils.is_wsl():
294
+ if (engine.dialect.name == db_utils.SQLAlchemyDialect.SQLITE.value and
295
+ not common_utils.is_wsl()):
47
296
  try:
48
- cursor.execute('PRAGMA journal_mode=WAL')
49
- except sqlite3.OperationalError as e:
297
+ with orm.Session(engine) as session:
298
+ session.execute(sqlalchemy.text('PRAGMA journal_mode=WAL'))
299
+ session.commit()
300
+ except sqlalchemy_exc.OperationalError as e:
50
301
  if 'database is locked' not in str(e):
51
302
  raise
52
303
  # If the database is locked, it is OK to continue, as the WAL mode
53
304
  # is not critical and is likely to be enabled by other processes.
54
305
 
55
- # Table for Clusters
56
- cursor.execute("""\
57
- CREATE TABLE IF NOT EXISTS clusters (
58
- name TEXT PRIMARY KEY,
59
- launched_at INTEGER,
60
- handle BLOB,
61
- last_use TEXT,
62
- status TEXT,
63
- autostop INTEGER DEFAULT -1,
64
- metadata TEXT DEFAULT '{}',
65
- to_down INTEGER DEFAULT 0,
66
- owner TEXT DEFAULT null,
67
- cluster_hash TEXT DEFAULT null,
68
- storage_mounts_metadata BLOB DEFAULT null,
69
- cluster_ever_up INTEGER DEFAULT 0,
70
- status_updated_at INTEGER DEFAULT null,
71
- config_hash TEXT DEFAULT null,
72
- user_hash TEXT DEFAULT null)""")
73
-
74
- # Table for Cluster History
75
- # usage_intervals: List[Tuple[int, int]]
76
- # Specifies start and end timestamps of cluster.
77
- # When the last end time is None, the cluster is still UP.
78
- # Example: [(start1, end1), (start2, end2), (start3, None)]
79
-
80
- # requested_resources: Set[resource_lib.Resource]
81
- # Requested resources fetched from task that user specifies.
82
-
83
- # launched_resources: Optional[resources_lib.Resources]
84
- # Actual launched resources fetched from handle for cluster.
85
-
86
- # num_nodes: Optional[int] number of nodes launched.
87
-
88
- cursor.execute("""\
89
- CREATE TABLE IF NOT EXISTS cluster_history (
90
- cluster_hash TEXT PRIMARY KEY,
91
- name TEXT,
92
- num_nodes int,
93
- requested_resources BLOB,
94
- launched_resources BLOB,
95
- usage_intervals BLOB,
96
- user_hash TEXT)""")
97
- # Table for configs (e.g. enabled clouds)
98
- cursor.execute("""\
99
- CREATE TABLE IF NOT EXISTS config (
100
- key TEXT PRIMARY KEY, value TEXT)""")
101
- # Table for Storage
102
- cursor.execute("""\
103
- CREATE TABLE IF NOT EXISTS storage (
104
- name TEXT PRIMARY KEY,
105
- launched_at INTEGER,
106
- handle BLOB,
107
- last_use TEXT,
108
- status TEXT)""")
109
- # Table for User
110
- cursor.execute("""\
111
- CREATE TABLE IF NOT EXISTS users (
112
- id TEXT PRIMARY KEY,
113
- name TEXT)""")
114
- # For backward compatibility.
115
- # TODO(zhwu): Remove this function after all users have migrated to
116
- # the latest version of SkyPilot.
117
- # Add autostop column to clusters table
118
- db_utils.add_column_to_table(cursor, conn, 'clusters', 'autostop',
119
- 'INTEGER DEFAULT -1')
120
-
121
- db_utils.add_column_to_table(cursor, conn, 'clusters', 'metadata',
122
- 'TEXT DEFAULT \'{}\'')
123
-
124
- db_utils.add_column_to_table(cursor, conn, 'clusters', 'to_down',
125
- 'INTEGER DEFAULT 0')
126
-
127
- # The cloud identity that created the cluster.
128
- db_utils.add_column_to_table(cursor, conn, 'clusters', 'owner', 'TEXT')
129
-
130
- db_utils.add_column_to_table(cursor, conn, 'clusters', 'cluster_hash',
131
- 'TEXT DEFAULT null')
132
-
133
- db_utils.add_column_to_table(cursor, conn, 'clusters',
134
- 'storage_mounts_metadata', 'BLOB DEFAULT null')
135
- db_utils.add_column_to_table(
136
- cursor,
137
- conn,
138
- 'clusters',
139
- 'cluster_ever_up',
140
- 'INTEGER DEFAULT 0',
141
- # Set the value to 1 so that all the existing clusters before #2977
142
- # are considered as ever up, i.e:
143
- # existing cluster's default (null) -> 1;
144
- # new cluster's default -> 0;
145
- # This is conservative for the existing clusters: even if some INIT
146
- # clusters were never really UP, setting it to 1 means they won't be
147
- # auto-deleted during any failover.
148
- value_to_replace_existing_entries=1)
149
- db_utils.add_column_to_table(cursor, conn, 'clusters', 'status_updated_at',
150
- 'INTEGER DEFAULT null')
151
- db_utils.add_column_to_table(
152
- cursor,
153
- conn,
154
- 'clusters',
155
- 'user_hash',
156
- 'TEXT DEFAULT null',
157
- value_to_replace_existing_entries=common_utils.get_user_hash())
158
- db_utils.add_column_to_table(cursor, conn, 'clusters', 'config_hash',
159
- 'TEXT DEFAULT null')
160
-
161
- db_utils.add_column_to_table(cursor, conn, 'clusters', 'config_hash',
162
- 'TEXT DEFAULT null')
163
-
164
- db_utils.add_column_to_table(cursor, conn, 'cluster_history', 'user_hash',
165
- 'TEXT DEFAULT null')
166
- conn.commit()
167
-
168
-
169
- _DB = db_utils.SQLiteConn(_DB_PATH, create_table)
170
-
171
-
172
- def add_or_update_user(user: models.User):
173
- """Store the mapping from user hash to user name for display purposes."""
174
- if user.name is None:
175
- return
176
- _DB.cursor.execute('INSERT OR REPLACE INTO users (id, name) VALUES (?, ?)',
177
- (user.id, user.name))
178
- _DB.conn.commit()
306
+ migration_utils.safe_alembic_upgrade(
307
+ engine, migration_utils.GLOBAL_USER_STATE_DB_NAME,
308
+ migration_utils.GLOBAL_USER_STATE_VERSION)
179
309
 
180
310
 
181
- def get_user(user_id: str) -> models.User:
182
- row = _DB.cursor.execute('SELECT id, name FROM users WHERE id=?',
183
- (user_id,)).fetchone()
184
- if row is None:
185
- return models.User(id=user_id)
186
- return models.User(id=row[0], name=row[1])
311
+ def initialize_and_get_db_async() -> sql_async.AsyncEngine:
312
+ global _SQLALCHEMY_ENGINE_ASYNC
313
+ if _SQLALCHEMY_ENGINE_ASYNC is not None:
314
+ return _SQLALCHEMY_ENGINE_ASYNC
315
+ with _SQLALCHEMY_ENGINE_LOCK:
316
+ if _SQLALCHEMY_ENGINE_ASYNC is not None:
317
+ return _SQLALCHEMY_ENGINE_ASYNC
187
318
 
319
+ _SQLALCHEMY_ENGINE_ASYNC = db_utils.get_engine('state',
320
+ async_engine=True)
321
+ initialize_and_get_db()
322
+ return _SQLALCHEMY_ENGINE_ASYNC
188
323
 
189
- def get_all_users() -> List[models.User]:
190
- rows = _DB.cursor.execute('SELECT id, name FROM users').fetchall()
191
- return [models.User(id=row[0], name=row[1]) for row in rows]
192
324
 
325
+ # We wrap the sqlalchemy engine initialization in a thread
326
+ # lock to ensure that multiple threads do not initialize the
327
+ # engine which could result in a rare race condition where
328
+ # a session has already been created with _SQLALCHEMY_ENGINE = e1,
329
+ # and then another thread overwrites _SQLALCHEMY_ENGINE = e2
330
+ # which could result in e1 being garbage collected unexpectedly.
331
+ def initialize_and_get_db() -> sqlalchemy.engine.Engine:
332
+ global _SQLALCHEMY_ENGINE
333
+
334
+ if _SQLALCHEMY_ENGINE is not None:
335
+ return _SQLALCHEMY_ENGINE
336
+ with _SQLALCHEMY_ENGINE_LOCK:
337
+ if _SQLALCHEMY_ENGINE is not None:
338
+ return _SQLALCHEMY_ENGINE
339
+ # get an engine to the db
340
+ engine = db_utils.get_engine('state')
341
+
342
+ # run migrations if needed
343
+ create_table(engine)
344
+
345
+ # return engine
346
+ _SQLALCHEMY_ENGINE = engine
347
+ # Cache the result of _sqlite_supports_returning()
348
+ # ahead of time, as it won't change throughout
349
+ # the lifetime of the engine.
350
+ _sqlite_supports_returning()
351
+ return _SQLALCHEMY_ENGINE
352
+
353
+
354
+ def _init_db_async(func):
355
+ """Initialize the async database."""
356
+
357
+ @functools.wraps(func)
358
+ async def wrapper(*args, **kwargs):
359
+ if _SQLALCHEMY_ENGINE_ASYNC is None:
360
+ # this may happen multiple times since there is no locking
361
+ # here but thats fine, this is just a short circuit for the
362
+ # common case.
363
+ await context_utils.to_thread(initialize_and_get_db_async)
193
364
 
365
+ return await func(*args, **kwargs)
366
+
367
+ return wrapper
368
+
369
+
370
+ def _init_db(func):
371
+ """Initialize the database."""
372
+
373
+ @functools.wraps(func)
374
+ def wrapper(*args, **kwargs):
375
+ initialize_and_get_db()
376
+ return func(*args, **kwargs)
377
+
378
+ return wrapper
379
+
380
+
381
+ @annotations.lru_cache(scope='global', maxsize=1)
382
+ def _sqlite_supports_returning() -> bool:
383
+ """Check if SQLite (3.35.0+) and SQLAlchemy (2.0+) support RETURNING.
384
+
385
+ See https://sqlite.org/lang_returning.html and
386
+ https://docs.sqlalchemy.org/en/20/dialects/sqlite.html#insert-update-delete-returning # pylint: disable=line-too-long
387
+ """
388
+ sqlalchemy_version_parts = sqlalchemy.__version__.split('.')
389
+ assert len(sqlalchemy_version_parts) >= 1, \
390
+ f'Invalid SQLAlchemy version: {sqlalchemy.__version__}'
391
+ sqlalchemy_major = int(sqlalchemy_version_parts[0])
392
+ if sqlalchemy_major < 2:
393
+ return False
394
+
395
+ assert _SQLALCHEMY_ENGINE is not None
396
+ if (_SQLALCHEMY_ENGINE.dialect.name !=
397
+ db_utils.SQLAlchemyDialect.SQLITE.value):
398
+ return False
399
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
400
+ result = session.execute(sqlalchemy.text('SELECT sqlite_version()'))
401
+ version_str = result.scalar()
402
+ version_parts = version_str.split('.')
403
+ assert len(version_parts) >= 2, \
404
+ f'Invalid version string: {version_str}'
405
+ major, minor = int(version_parts[0]), int(version_parts[1])
406
+ return (major > 3) or (major == 3 and minor >= 35)
407
+
408
+
409
+ @_init_db
410
+ @metrics_lib.time_me
411
+ def add_or_update_user(
412
+ user: models.User,
413
+ allow_duplicate_name: bool = True,
414
+ return_user: bool = False
415
+ ) -> typing.Union[bool, typing.Tuple[bool, models.User]]:
416
+ """Store the mapping from user hash to user name for display purposes.
417
+
418
+ Returns:
419
+ If return_user=False: bool (whether the user is newly added)
420
+ If return_user=True: Tuple[bool, models.User]
421
+ """
422
+ assert _SQLALCHEMY_ENGINE is not None
423
+
424
+ if user.name is None:
425
+ return (False, user) if return_user else False
426
+
427
+ # Set created_at if not already set
428
+ created_at = user.created_at
429
+ if created_at is None:
430
+ created_at = int(time.time())
431
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
432
+ # Check for duplicate names if not allowed (within the same transaction)
433
+ if not allow_duplicate_name:
434
+ existing_user = session.query(user_table).filter(
435
+ user_table.c.name == user.name).first()
436
+ if existing_user is not None:
437
+ return (False, user) if return_user else False
438
+
439
+ if (_SQLALCHEMY_ENGINE.dialect.name ==
440
+ db_utils.SQLAlchemyDialect.SQLITE.value):
441
+ # For SQLite, use INSERT OR IGNORE followed by UPDATE to detect new
442
+ # vs existing
443
+ insert_func = sqlite.insert
444
+
445
+ # First try INSERT OR IGNORE - this won't fail if user exists
446
+ insert_stmnt = insert_func(user_table).prefix_with(
447
+ 'OR IGNORE').values(id=user.id,
448
+ name=user.name,
449
+ password=user.password,
450
+ created_at=created_at)
451
+ use_returning = return_user and _sqlite_supports_returning()
452
+ if use_returning:
453
+ insert_stmnt = insert_stmnt.returning(
454
+ user_table.c.id,
455
+ user_table.c.name,
456
+ user_table.c.password,
457
+ user_table.c.created_at,
458
+ )
459
+ result = session.execute(insert_stmnt)
460
+
461
+ row = None
462
+ if use_returning:
463
+ # With RETURNING, check if we got a row back.
464
+ row = result.fetchone()
465
+ was_inserted = row is not None
466
+ else:
467
+ # Without RETURNING, use rowcount.
468
+ was_inserted = result.rowcount > 0
469
+
470
+ if not was_inserted:
471
+ # User existed, so update it (but don't update created_at)
472
+ update_values = {user_table.c.name: user.name}
473
+ if user.password:
474
+ update_values[user_table.c.password] = user.password
475
+
476
+ update_stmnt = sqlalchemy.update(user_table).where(
477
+ user_table.c.id == user.id).values(update_values)
478
+ if use_returning:
479
+ update_stmnt = update_stmnt.returning(
480
+ user_table.c.id, user_table.c.name,
481
+ user_table.c.password, user_table.c.created_at)
482
+
483
+ result = session.execute(update_stmnt)
484
+ if use_returning:
485
+ row = result.fetchone()
486
+
487
+ session.commit()
488
+
489
+ if return_user:
490
+ if row is None:
491
+ # row=None means the sqlite used has no RETURNING support,
492
+ # so we need to do a separate query
493
+ row = session.query(user_table).filter_by(
494
+ id=user.id).first()
495
+ updated_user = models.User(id=row.id,
496
+ name=row.name,
497
+ password=row.password,
498
+ created_at=row.created_at)
499
+ return was_inserted, updated_user
500
+ else:
501
+ return was_inserted
502
+
503
+ elif (_SQLALCHEMY_ENGINE.dialect.name ==
504
+ db_utils.SQLAlchemyDialect.POSTGRESQL.value):
505
+ # For PostgreSQL, use INSERT ... ON CONFLICT with RETURNING to
506
+ # detect insert vs update
507
+ insert_func = postgresql.insert
508
+
509
+ insert_stmnt = insert_func(user_table).values(
510
+ id=user.id,
511
+ name=user.name,
512
+ password=user.password,
513
+ created_at=created_at)
514
+
515
+ # Use a sentinel in the RETURNING clause to detect insert vs update
516
+ if user.password:
517
+ set_ = {
518
+ user_table.c.name: user.name,
519
+ user_table.c.password: user.password
520
+ }
521
+ else:
522
+ set_ = {user_table.c.name: user.name}
523
+ upsert_stmnt = insert_stmnt.on_conflict_do_update(
524
+ index_elements=[user_table.c.id], set_=set_).returning(
525
+ user_table.c.id,
526
+ user_table.c.name,
527
+ user_table.c.password,
528
+ user_table.c.created_at,
529
+ # This will be True for INSERT, False for UPDATE
530
+ sqlalchemy.literal_column('(xmax = 0)').label('was_inserted'
531
+ ))
532
+
533
+ result = session.execute(upsert_stmnt)
534
+ row = result.fetchone()
535
+
536
+ was_inserted = bool(row.was_inserted) if row else False
537
+ session.commit()
538
+
539
+ if return_user:
540
+ updated_user = models.User(id=row.id,
541
+ name=row.name,
542
+ password=row.password,
543
+ created_at=row.created_at)
544
+ return was_inserted, updated_user
545
+ else:
546
+ return was_inserted
547
+ else:
548
+ raise ValueError('Unsupported database dialect')
549
+
550
+
+@_init_db
+@metrics_lib.time_me
+def get_user(user_id: str) -> Optional[models.User]:
+    assert _SQLALCHEMY_ENGINE is not None
+    with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        row = session.query(user_table).filter_by(id=user_id).first()
+        if row is None:
+            return None
+        return models.User(id=row.id,
+                           name=row.name,
+                           password=row.password,
+                           created_at=row.created_at)
+
+
+@_init_db
+@metrics_lib.time_me
+def get_users(user_ids: Set[str]) -> Dict[str, models.User]:
+    assert _SQLALCHEMY_ENGINE is not None
+    with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        rows = session.query(user_table).filter(
+            user_table.c.id.in_(user_ids)).all()
+        return {
+            row.id: models.User(id=row.id,
+                                name=row.name,
+                                password=row.password,
+                                created_at=row.created_at) for row in rows
+        }
+
+
+@_init_db
+@metrics_lib.time_me
+def get_user_by_name(username: str) -> List[models.User]:
+    with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        rows = session.query(user_table).filter_by(name=username).all()
+        if len(rows) == 0:
+            return []
+        return [
+            models.User(id=row.id,
+                        name=row.name,
+                        password=row.password,
+                        created_at=row.created_at) for row in rows
+        ]
+
+
+@_init_db
+@metrics_lib.time_me
+def get_user_by_name_match(username_match: str) -> List[models.User]:
+    with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        rows = session.query(user_table).filter(
+            user_table.c.name.like(f'%{username_match}%')).all()
+        return [
+            models.User(id=row.id, name=row.name, created_at=row.created_at)
+            for row in rows
+        ]
+
+
+@_init_db
+@metrics_lib.time_me
+def delete_user(user_id: str) -> None:
+    with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        session.query(user_table).filter_by(id=user_id).delete()
+        session.commit()
+
+
+@_init_db
+@metrics_lib.time_me
+def get_all_users() -> List[models.User]:
+    assert _SQLALCHEMY_ENGINE is not None
+    with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        rows = session.query(user_table).all()
+        return [
+            models.User(id=row.id,
+                        name=row.name,
+                        password=row.password,
+                        created_at=row.created_at) for row in rows
+        ]
+
+
+@_init_db
+@metrics_lib.time_me
 def add_or_update_cluster(cluster_name: str,
                           cluster_handle: 'backends.ResourceHandle',
                           requested_resources: Optional[Set[Any]],
                           ready: bool,
                           is_launch: bool = True,
-                          config_hash: Optional[str] = None):
+                          config_hash: Optional[str] = None,
+                          task_config: Optional[Dict[str, Any]] = None,
+                          is_managed: bool = False,
+                          provision_log_path: Optional[str] = None,
+                          existing_cluster_hash: Optional[str] = None):
     """Adds or updates cluster_name -> cluster_handle mapping.
 
     Args:
@@ -207,7 +648,17 @@ def add_or_update_cluster(cluster_name: str,
             be marked as INIT, otherwise it will be marked as UP.
         is_launch: if the cluster is firstly launched. If True, the launched_at
             and last_use will be updated. Otherwise, use the old value.
+        config_hash: Configuration hash for the cluster.
+        task_config: The config of the task being launched.
+        is_managed: Whether the cluster is launched by the controller.
+        provision_log_path: Absolute path to provision.log, if available.
+        existing_cluster_hash: If specified, the cluster will be updated
+            only if the cluster_hash matches. If a cluster does not exist,
+            it will not be inserted and an error will be raised.
     """
+    assert _SQLALCHEMY_ENGINE is not None
+
     # FIXME: launched_at will be changed when `sky launch -c` is called.
     handle = pickle.dumps(cluster_handle)
     cluster_launched_at = int(time.time()) if is_launch else None
@@ -240,143 +691,362 @@ def add_or_update_cluster(cluster_name: str,
         cluster_launched_at = int(time.time())
         usage_intervals.append((cluster_launched_at, None))
 
-    user_hash = common_utils.get_user_hash()
-
-    _DB.cursor.execute(
-        'INSERT or REPLACE INTO clusters'
-        # All the fields need to exist here, even if they don't need
-        # be changed, as the INSERT OR REPLACE statement will replace
-        # the field of the existing row with the default value if not
-        # specified.
-        '(name, launched_at, handle, last_use, status, '
-        'autostop, to_down, metadata, owner, cluster_hash, '
-        'storage_mounts_metadata, cluster_ever_up, status_updated_at, '
-        'config_hash, user_hash) '
-        'VALUES ('
-        # name
-        '?, '
-        # launched_at
-        'COALESCE('
-        '?, (SELECT launched_at FROM clusters WHERE name=?)), '
-        # handle
-        '?, '
-        # last_use
-        'COALESCE('
-        '?, (SELECT last_use FROM clusters WHERE name=?)), '
-        # status
-        '?, '
-        # autostop
-        # Keep the old autostop value if it exists, otherwise set it to
-        # default -1.
-        'COALESCE('
-        '(SELECT autostop FROM clusters WHERE name=? AND status!=?), -1), '
-        # Keep the old to_down value if it exists, otherwise set it to
-        # default 0.
-        'COALESCE('
-        '(SELECT to_down FROM clusters WHERE name=? AND status!=?), 0),'
-        # Keep the old metadata value if it exists, otherwise set it to
-        # default {}.
-        'COALESCE('
-        '(SELECT metadata FROM clusters WHERE name=?), \'{}\'),'
-        # Keep the old owner value if it exists, otherwise set it to
-        # default null.
-        'COALESCE('
-        '(SELECT owner FROM clusters WHERE name=?), null),'
-        # cluster_hash
-        '?,'
-        # storage_mounts_metadata
-        'COALESCE('
-        '(SELECT storage_mounts_metadata FROM clusters WHERE name=?), null), '
-        # cluster_ever_up
-        '((SELECT cluster_ever_up FROM clusters WHERE name=?) OR ?), '
-        # status_updated_at
-        '?,'
-        # config_hash
-        'COALESCE(?, (SELECT config_hash FROM clusters WHERE name=?)),'
-        # user_hash: keep original user_hash if it exists
-        'COALESCE('
-        '(SELECT user_hash FROM clusters WHERE name=?), ?)'
-        ')',
-        (
-            # name
-            cluster_name,
-            # launched_at
-            cluster_launched_at,
-            cluster_name,
-            # handle
-            handle,
-            # last_use
-            last_use,
-            cluster_name,
-            # status
-            status.value,
-            # autostop
-            cluster_name,
-            status_lib.ClusterStatus.STOPPED.value,
-            # to_down
-            cluster_name,
-            status_lib.ClusterStatus.STOPPED.value,
-            # metadata
-            cluster_name,
-            # owner
-            cluster_name,
-            # cluster_hash
-            cluster_hash,
-            # storage_mounts_metadata
-            cluster_name,
-            # cluster_ever_up
-            cluster_name,
-            int(ready),
-            # status_updated_at
-            status_updated_at,
-            # config_hash
-            config_hash,
-            cluster_name,
-            # user_hash
-            cluster_name,
-            user_hash,
-        ))
-
-    launched_nodes = getattr(cluster_handle, 'launched_nodes', None)
-    launched_resources = getattr(cluster_handle, 'launched_resources', None)
-    _DB.cursor.execute(
-        'INSERT or REPLACE INTO cluster_history'
-        '(cluster_hash, name, num_nodes, requested_resources, '
-        'launched_resources, usage_intervals, user_hash) '
-        'VALUES ('
-        # hash
-        '?, '
-        # name
-        '?, '
-        # requested resources
-        '?, '
-        # launched resources
-        '?, '
-        # number of nodes
-        '?, '
-        # usage intervals
-        '?, '
-        # user_hash
-        '?'
-        ')',
-        (
-            # hash
-            cluster_hash,
-            # name
-            cluster_name,
-            # number of nodes
-            launched_nodes,
-            # requested resources
-            pickle.dumps(requested_resources),
-            # launched resources
-            pickle.dumps(launched_resources),
-            # usage intervals
-            pickle.dumps(usage_intervals),
-            # user_hash
-            user_hash,
-        ))
-
-    _DB.conn.commit()
+    user_hash = common_utils.get_current_user().id
+    active_workspace = skypilot_config.get_active_workspace()
+    history_workspace = active_workspace
+    history_hash = user_hash
+
+    conditional_values = {}
+    if is_launch:
+        conditional_values.update({
+            'launched_at': cluster_launched_at,
+            'last_use': last_use
+        })
+
+    if int(ready) == 1:
+        conditional_values.update({
+            'cluster_ever_up': 1,
+        })
+
+    if config_hash is not None:
+        conditional_values.update({
+            'config_hash': config_hash,
+        })
+
+    with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        # with_for_update() locks the row until commit() or rollback()
+        # is called, or until the code escapes the with block.
+        cluster_row = session.query(cluster_table).filter_by(
+            name=cluster_name).with_for_update().first()
+        if (not cluster_row or
+                cluster_row.status == status_lib.ClusterStatus.STOPPED.value):
+            conditional_values.update({
+                'autostop': -1,
+                'to_down': 0,
+            })
+        if not cluster_row or not cluster_row.user_hash:
+            conditional_values.update({
+                'user_hash': user_hash,
+            })
+        if not cluster_row or not cluster_row.workspace:
+            conditional_values.update({
+                'workspace': active_workspace,
+            })
+        if is_launch and (not cluster_row or cluster_row.status !=
+                          status_lib.ClusterStatus.UP.value):
+            conditional_values.update({
+                'last_creation_yaml': yaml_utils.dump_yaml_str(task_config)
+                                      if task_config else None,
+                'last_creation_command': last_use,
+            })
+        if provision_log_path is not None:
+            conditional_values.update({
+                'provision_log_path': provision_log_path,
+            })
+
+        if (_SQLALCHEMY_ENGINE.dialect.name ==
+                db_utils.SQLAlchemyDialect.SQLITE.value):
+            insert_func = sqlite.insert
+        elif (_SQLALCHEMY_ENGINE.dialect.name ==
+              db_utils.SQLAlchemyDialect.POSTGRESQL.value):
+            insert_func = postgresql.insert
+        else:
+            session.rollback()
+            raise ValueError('Unsupported database dialect')
+
+        if existing_cluster_hash is not None:
+            count = session.query(cluster_table).filter_by(
+                name=cluster_name, cluster_hash=existing_cluster_hash).update({
+                    **conditional_values, cluster_table.c.handle: handle,
+                    cluster_table.c.status: status.value,
+                    cluster_table.c.status_updated_at: status_updated_at
+                })
+            assert count <= 1
+            if count == 0:
+                raise ValueError(f'Cluster {cluster_name} with hash '
+                                 f'{existing_cluster_hash} not found.')
+        else:
+            insert_stmnt = insert_func(cluster_table).values(
+                name=cluster_name,
+                **conditional_values,
+                handle=handle,
+                status=status.value,
+                # set metadata to server default ('{}')
+                # set owner to server default (null)
+                cluster_hash=cluster_hash,
+                # set storage_mounts_metadata to server default (null)
+                status_updated_at=status_updated_at,
+                is_managed=int(is_managed),
+            )
+            insert_or_update_stmt = insert_stmnt.on_conflict_do_update(
+                index_elements=[cluster_table.c.name],
+                set_={
+                    **conditional_values,
+                    cluster_table.c.handle: handle,
+                    cluster_table.c.status: status.value,
+                    # do not update metadata value
+                    # do not update owner value
+                    cluster_table.c.cluster_hash: cluster_hash,
+                    # do not update storage_mounts_metadata
+                    cluster_table.c.status_updated_at: status_updated_at,
+                    # do not update user_hash
+                })
+            session.execute(insert_or_update_stmt)
+
+        # Modify cluster history table
+        launched_nodes = getattr(cluster_handle, 'launched_nodes', None)
+        launched_resources = getattr(cluster_handle, 'launched_resources',
+                                     None)
+        if cluster_row and cluster_row.workspace:
+            history_workspace = cluster_row.workspace
+        if cluster_row and cluster_row.user_hash:
+            history_hash = cluster_row.user_hash
+        creation_info = {}
+        if conditional_values.get('last_creation_yaml') is not None:
+            creation_info = {
+                'last_creation_yaml':
+                    conditional_values.get('last_creation_yaml'),
+                'last_creation_command':
+                    conditional_values.get('last_creation_command'),
+            }
+
+        # Calculate last_activity_time and launched_at from usage_intervals
+        last_activity_time = _get_cluster_last_activity_time(usage_intervals)
+        launched_at = _get_cluster_launch_time(usage_intervals)
+
+        insert_stmnt = insert_func(cluster_history_table).values(
+            cluster_hash=cluster_hash,
+            name=cluster_name,
+            num_nodes=launched_nodes,
+            requested_resources=pickle.dumps(requested_resources),
+            launched_resources=pickle.dumps(launched_resources),
+            usage_intervals=pickle.dumps(usage_intervals),
+            user_hash=user_hash,
+            workspace=history_workspace,
+            provision_log_path=provision_log_path,
+            last_activity_time=last_activity_time,
+            launched_at=launched_at,
+            **creation_info,
+        )
+        do_update_stmt = insert_stmnt.on_conflict_do_update(
+            index_elements=[cluster_history_table.c.cluster_hash],
+            set_={
+                cluster_history_table.c.name: cluster_name,
+                cluster_history_table.c.num_nodes: launched_nodes,
+                cluster_history_table.c.requested_resources:
+                    pickle.dumps(requested_resources),
+                cluster_history_table.c.launched_resources:
+                    pickle.dumps(launched_resources),
+                cluster_history_table.c.usage_intervals:
+                    pickle.dumps(usage_intervals),
+                cluster_history_table.c.user_hash: history_hash,
+                cluster_history_table.c.workspace: history_workspace,
+                cluster_history_table.c.provision_log_path: provision_log_path,
+                cluster_history_table.c.last_activity_time: last_activity_time,
+                cluster_history_table.c.launched_at: launched_at,
+                **creation_info,
+            })
+        session.execute(do_update_stmt)
+
+        session.commit()
+
+
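The function above leans on two patterns that recur throughout this module: row locking via `with_for_update()`, and a dialect-dispatched upsert. Both the `sqlite` and `postgresql` dialect modules expose an `insert()` whose statements support `on_conflict_do_update()`, so the same upsert body serves either backend. A hedged, generic sketch of that dispatch (table and engine are placeholders, not SkyPilot's):

    import sqlalchemy
    from sqlalchemy.dialects import postgresql, sqlite


    def upsert_row(engine: sqlalchemy.engine.Engine,
                   table: sqlalchemy.Table, values: dict) -> None:
        """Insert-or-update keyed on the table's primary key column."""
        if engine.dialect.name == 'sqlite':
            insert_func = sqlite.insert
        elif engine.dialect.name == 'postgresql':
            insert_func = postgresql.insert
        else:
            raise ValueError('Unsupported database dialect')

        pk = list(table.primary_key.columns)[0]
        stmt = insert_func(table).values(**values)
        stmt = stmt.on_conflict_do_update(
            index_elements=[pk],
            # Update every non-key column on conflict.
            set_={k: v for k, v in values.items() if k != pk.name})
        with engine.begin() as conn:
            conn.execute(stmt)

One caveat worth knowing when reading the `with_for_update()` comment: SQLite has no `SELECT ... FOR UPDATE`, so the lock is only effective on PostgreSQL; on SQLite the call is a no-op and isolation comes from the database-level write lock instead.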
+@_init_db
+@metrics_lib.time_me
+def add_cluster_event(cluster_name: str,
+                      new_status: Optional[status_lib.ClusterStatus],
+                      reason: str,
+                      event_type: ClusterEventType,
+                      nop_if_duplicate: bool = False,
+                      duplicate_regex: Optional[str] = None,
+                      expose_duplicate_error: bool = False,
+                      transitioned_at: Optional[int] = None) -> None:
+    """Add a cluster event.
+
+    Args:
+        cluster_name: Name of the cluster.
+        new_status: New status of the cluster.
+        reason: Reason for the event.
+        event_type: Type of the event.
+        nop_if_duplicate: If True, do not add the event if it is a duplicate.
+        duplicate_regex: If provided, do not add the event if it matches the
+            regex. Only used if nop_if_duplicate is True.
+        expose_duplicate_error: If True, raise an error if the event is a
+            duplicate. Only used if nop_if_duplicate is True.
+        transitioned_at: If provided, use this timestamp for the event.
+    """
+    assert _SQLALCHEMY_ENGINE is not None
+    cluster_hash = _get_hash_for_existing_cluster(cluster_name)
+    if cluster_hash is None:
+        logger.debug(f'Hash for cluster {cluster_name} not found. '
+                     'Skipping event.')
+        return
+    if transitioned_at is None:
+        transitioned_at = int(time.time())
+    with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        if (_SQLALCHEMY_ENGINE.dialect.name ==
+                db_utils.SQLAlchemyDialect.SQLITE.value):
+            insert_func = sqlite.insert
+        elif (_SQLALCHEMY_ENGINE.dialect.name ==
+              db_utils.SQLAlchemyDialect.POSTGRESQL.value):
+            insert_func = postgresql.insert
+        else:
+            session.rollback()
+            raise ValueError('Unsupported database dialect')
+
+        cluster_row = session.query(cluster_table).filter_by(
+            name=cluster_name).first()
+        last_status = cluster_row.status if cluster_row is not None else None
+        if nop_if_duplicate:
+            last_event = get_last_cluster_event(cluster_hash,
+                                                event_type=event_type)
+            if duplicate_regex is not None and last_event is not None:
+                if re.search(duplicate_regex, last_event):
+                    return
+            elif last_event == reason:
+                return
+        try:
+            request_id = common_utils.get_current_request_id()
+            session.execute(
+                insert_func(cluster_event_table).values(
+                    cluster_hash=cluster_hash,
+                    name=cluster_name,
+                    starting_status=last_status,
+                    ending_status=new_status.value if new_status else None,
+                    reason=reason,
+                    transitioned_at=transitioned_at,
+                    type=event_type.value,
+                    request_id=request_id,
+                ))
+            session.commit()
+        except sqlalchemy.exc.IntegrityError as e:
+            for msg in _UNIQUE_CONSTRAINT_FAILED_ERROR_MSGS:
+                if msg in str(e):
+                    # This can happen if the cluster event is added twice.
+                    # We can ignore this error unless the caller requests
+                    # to expose the error.
+                    if expose_duplicate_error:
+                        raise db_utils.UniqueConstraintViolationError(
+                            value=reason, message=str(e))
+                    else:
+                        return
+            raise e
+
+
+def get_last_cluster_event(cluster_hash: str,
+                           event_type: ClusterEventType) -> Optional[str]:
+    assert _SQLALCHEMY_ENGINE is not None
+    with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        row = session.query(cluster_event_table).filter_by(
+            cluster_hash=cluster_hash, type=event_type.value).order_by(
+                cluster_event_table.c.transitioned_at.desc()).first()
+        if row is None:
+            return None
+        return row.reason
+
+
+def _get_last_cluster_event_multiple(
+        cluster_hashes: Set[str],
+        event_type: ClusterEventType) -> Dict[str, str]:
+    assert _SQLALCHEMY_ENGINE is not None
+    with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        # Use a subquery to get the latest event for each cluster_hash
+        latest_events = session.query(
+            cluster_event_table.c.cluster_hash,
+            sqlalchemy.func.max(cluster_event_table.c.transitioned_at).label(
+                'max_time')).filter(
+                    cluster_event_table.c.cluster_hash.in_(cluster_hashes),
+                    cluster_event_table.c.type == event_type.value).group_by(
                        cluster_event_table.c.cluster_hash).subquery()
+
+        # Join with original table to get the full event details
+        rows = session.query(cluster_event_table).join(
+            latest_events,
+            sqlalchemy.and_(
+                cluster_event_table.c.cluster_hash ==
+                latest_events.c.cluster_hash,
+                cluster_event_table.c.transitioned_at ==
+                latest_events.c.max_time)).all()
+
+        return {row.cluster_hash: row.reason for row in rows}
+
+
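`_get_last_cluster_event_multiple` is the classic "greatest-n-per-group" query: a grouped `MAX()` subquery, joined back against the base table to recover the full row at each group's maximum timestamp. A compact Core-level sketch of the same shape (generic column names, illustrative only; note that ties on the timestamp would return multiple rows per group):

    import sqlalchemy


    def latest_value_per_group(conn, group_col, time_col, value_col):
        """Returns {group: value at max(time)} via a grouped subquery."""
        latest = (sqlalchemy.select(
            group_col,
            sqlalchemy.func.max(time_col).label('max_time'))
            .group_by(group_col).subquery())
        stmt = sqlalchemy.select(group_col, value_col).join(
            latest,
            sqlalchemy.and_(group_col == latest.c[group_col.name],
                            time_col == latest.c.max_time))
        return {g: v for g, v in conn.execute(stmt)}

Batching the lookup this way keeps `get_clusters` at two queries total instead of one `get_last_cluster_event` round trip per cluster.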
+def cleanup_cluster_events_with_retention(retention_hours: float,
+                                          event_type: ClusterEventType) -> None:
+    assert _SQLALCHEMY_ENGINE is not None
+    # Delete events of the given type that fall outside the retention window.
+    with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        query = session.query(cluster_event_table).filter(
+            cluster_event_table.c.transitioned_at <
+            time.time() - retention_hours * 3600,
+            cluster_event_table.c.type == event_type.value)
+        logger.debug(f'Deleting {query.count()} cluster events.')
+        query.delete()
+        session.commit()
+
+
+async def cluster_event_retention_daemon():
+    """Garbage collect cluster events periodically."""
+    while True:
+        logger.info('Running cluster event retention daemon...')
+        # Use the latest config.
+        skypilot_config.reload_config()
+        retention_hours = skypilot_config.get_nested(
+            ('api_server', 'cluster_event_retention_hours'),
+            DEFAULT_CLUSTER_EVENT_RETENTION_HOURS)
+        debug_retention_hours = skypilot_config.get_nested(
+            ('api_server', 'cluster_debug_event_retention_hours'),
+            DEBUG_CLUSTER_EVENT_RETENTION_HOURS)
+        try:
+            if retention_hours >= 0:
+                logger.debug('Cleaning up cluster events with retention '
+                             f'{retention_hours} hours.')
+                cleanup_cluster_events_with_retention(
+                    retention_hours, ClusterEventType.STATUS_CHANGE)
+            if debug_retention_hours >= 0:
+                logger.debug('Cleaning up debug cluster events with retention '
+                             f'{debug_retention_hours} hours.')
+                cleanup_cluster_events_with_retention(debug_retention_hours,
+                                                      ClusterEventType.DEBUG)
+        except asyncio.CancelledError:
+            logger.info('Cluster event retention daemon cancelled')
+            break
+        except Exception as e:  # pylint: disable=broad-except
+            logger.error(f'Error running cluster event retention daemon: {e}')
+
+        # Run the daemon at most once per minimum interval to avoid
+        # too-frequent cleanup.
+        sleep_amount = max(
+            min(retention_hours * 3600, debug_retention_hours * 3600),
+            MIN_CLUSTER_EVENT_DAEMON_INTERVAL_SECONDS)
+        await asyncio.sleep(sleep_amount)
+
+
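The retention daemon follows a common asyncio shape: an infinite loop that does one pass of work, swallows unexpected errors so the loop survives transient failures, and exits cleanly on cancellation. A stripped-down sketch of the same pattern (generic names, not SkyPilot code):

    import asyncio


    async def periodic(job, interval_seconds: float) -> None:
        """Awaits `job()` forever, `interval_seconds` apart, until cancelled."""
        while True:
            try:
                await job()
            except asyncio.CancelledError:
                break  # cancelled (e.g. at shutdown): exit cleanly
            except Exception as e:  # pylint: disable=broad-except
                # Keep the daemon alive across transient failures.
                print(f'periodic job failed: {e}')
            await asyncio.sleep(interval_seconds)

A caller would typically start it with `asyncio.create_task(periodic(cleanup, 3600))` and cancel the task at shutdown.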
+def get_cluster_events(cluster_name: Optional[str], cluster_hash: Optional[str],
+                       event_type: ClusterEventType) -> List[str]:
+    """Returns the cluster events for the cluster.
+
+    Args:
+        cluster_name: Name of the cluster. Cannot be specified if cluster_hash
+            is specified.
+        cluster_hash: Hash of the cluster. Cannot be specified if cluster_name
+            is specified.
+        event_type: Type of the event.
+    """
+    assert _SQLALCHEMY_ENGINE is not None
+
+    if cluster_name is not None and cluster_hash is not None:
+        raise ValueError('Cannot specify both cluster_name and cluster_hash')
+    if cluster_name is None and cluster_hash is None:
+        raise ValueError('Must specify either cluster_name or cluster_hash')
+    if cluster_name is not None:
+        cluster_hash = _get_hash_for_existing_cluster(cluster_name)
+        if cluster_hash is None:
+            raise ValueError(f'Hash for cluster {cluster_name} not found.')
+
+    with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        rows = session.query(cluster_event_table).filter_by(
+            cluster_hash=cluster_hash, type=event_type.value).order_by(
+                cluster_event_table.c.transitioned_at.asc()).all()
+        return [row.reason for row in rows]
 
 
 def _get_user_hash_or_current_user(user_hash: Optional[str]) -> str:
@@ -391,186 +1061,402 @@ def _get_user_hash_or_current_user(user_hash: Optional[str]) -> str:
     return common_utils.get_user_hash()
 
 
+@_init_db
+@metrics_lib.time_me
 def update_cluster_handle(cluster_name: str,
                           cluster_handle: 'backends.ResourceHandle'):
+    assert _SQLALCHEMY_ENGINE is not None
     handle = pickle.dumps(cluster_handle)
-    _DB.cursor.execute('UPDATE clusters SET handle=(?) WHERE name=(?)',
-                       (handle, cluster_name))
-    _DB.conn.commit()
+    with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        session.query(cluster_table).filter_by(name=cluster_name).update(
+            {cluster_table.c.handle: handle})
+        session.commit()
 
 
+@_init_db
+@metrics_lib.time_me
 def update_last_use(cluster_name: str):
     """Updates the last used command for the cluster."""
-    _DB.cursor.execute('UPDATE clusters SET last_use=(?) WHERE name=(?)',
-                       (common_utils.get_current_command(), cluster_name))
-    _DB.conn.commit()
+    with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        session.query(cluster_table).filter_by(name=cluster_name).update(
+            {cluster_table.c.last_use: common_utils.get_current_command()})
+        session.commit()
 
 
+@_init_db
+@metrics_lib.time_me
 def remove_cluster(cluster_name: str, terminate: bool) -> None:
     """Removes cluster_name mapping."""
+    assert _SQLALCHEMY_ENGINE is not None
     cluster_hash = _get_hash_for_existing_cluster(cluster_name)
     usage_intervals = _get_cluster_usage_intervals(cluster_hash)
+    provision_log_path = get_cluster_provision_log_path(cluster_name)
 
-    # usage_intervals is not None and not empty
-    if usage_intervals:
-        assert cluster_hash is not None, cluster_name
-        start_time = usage_intervals.pop()[0]
-        end_time = int(time.time())
-        usage_intervals.append((start_time, end_time))
-        _set_cluster_usage_intervals(cluster_hash, usage_intervals)
-
-    if terminate:
-        _DB.cursor.execute('DELETE FROM clusters WHERE name=(?)',
-                           (cluster_name,))
-    else:
-        handle = get_handle_from_cluster_name(cluster_name)
-        if handle is None:
-            return
-        # Must invalidate IP list to avoid directly trying to ssh into a
-        # stopped VM, which leads to timeout.
-        if hasattr(handle, 'stable_internal_external_ips'):
-            handle = typing.cast('backends.CloudVmRayResourceHandle', handle)
-            handle.stable_internal_external_ips = None
-        current_time = int(time.time())
-        _DB.cursor.execute(
-            'UPDATE clusters SET handle=(?), status=(?), '
-            'status_updated_at=(?) WHERE name=(?)', (
-                pickle.dumps(handle),
-                status_lib.ClusterStatus.STOPPED.value,
-                current_time,
-                cluster_name,
-            ))
-    _DB.conn.commit()
-
-
+    with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        # usage_intervals is not None and not empty
+        if usage_intervals:
+            assert cluster_hash is not None, cluster_name
+            start_time = usage_intervals.pop()[0]
+            end_time = int(time.time())
+            usage_intervals.append((start_time, end_time))
+            _set_cluster_usage_intervals(cluster_hash, usage_intervals)
+
+        if provision_log_path:
+            assert cluster_hash is not None, cluster_name
+            session.query(cluster_history_table).filter_by(
+                cluster_hash=cluster_hash).filter(
                    cluster_history_table.c.provision_log_path.is_(None)
+                ).update({
+                    cluster_history_table.c.provision_log_path:
+                        provision_log_path
+                })
+
+        if terminate:
+            session.query(cluster_table).filter_by(name=cluster_name).delete()
+        else:
+            handle = get_handle_from_cluster_name(cluster_name)
+            if handle is None:
+                return
+            # Must invalidate IP list to avoid directly trying to ssh into a
+            # stopped VM, which leads to timeout.
+            if hasattr(handle, 'stable_internal_external_ips'):
+                handle = typing.cast('backends.CloudVmRayResourceHandle',
+                                     handle)
+                handle.stable_internal_external_ips = None
+            current_time = int(time.time())
+            session.query(cluster_table).filter_by(name=cluster_name).update({
+                cluster_table.c.handle: pickle.dumps(handle),
+                cluster_table.c.status: status_lib.ClusterStatus.STOPPED.value,
+                cluster_table.c.status_updated_at: current_time
+            })
+        session.commit()
+
+
+@_init_db
+@metrics_lib.time_me
 def get_handle_from_cluster_name(
         cluster_name: str) -> Optional['backends.ResourceHandle']:
+    assert _SQLALCHEMY_ENGINE is not None
     assert cluster_name is not None, 'cluster_name cannot be None'
-    rows = _DB.cursor.execute('SELECT handle FROM clusters WHERE name=(?)',
-                              (cluster_name,))
-    for (handle,) in rows:
-        return pickle.loads(handle)
-    return None
+    with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        row = (session.query(
+            cluster_table.c.handle).filter_by(name=cluster_name).first())
+        if row is None:
+            return None
+        return pickle.loads(row.handle)
+
+
+@_init_db
+@metrics_lib.time_me
+def get_handles_from_cluster_names(
+    cluster_names: Set[str]
+) -> Dict[str, Optional['backends.ResourceHandle']]:
+    assert _SQLALCHEMY_ENGINE is not None
+    with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        rows = session.query(cluster_table.c.name,
+                             cluster_table.c.handle).filter(
+                                 cluster_table.c.name.in_(cluster_names)).all()
+        return {
+            row.name: pickle.loads(row.handle)
+                      if row.handle is not None else None for row in rows
+        }
 
 
-def get_glob_cluster_names(cluster_name: str) -> List[str]:
+@_init_db
+@metrics_lib.time_me
+def get_cluster_name_to_handle_map(
+    is_managed: Optional[bool] = None,
+) -> Dict[str, Optional['backends.ResourceHandle']]:
+    assert _SQLALCHEMY_ENGINE is not None
+    with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        query = session.query(cluster_table.c.name, cluster_table.c.handle)
+        if is_managed is not None:
+            query = query.filter(cluster_table.c.is_managed == int(is_managed))
+        rows = query.all()
+        name_to_handle = {}
+        for row in rows:
+            if row.handle and len(row.handle) > 0:
+                name_to_handle[row.name] = pickle.loads(row.handle)
+            else:
+                name_to_handle[row.name] = None
+        return name_to_handle
+
+
+@_init_db_async
+@metrics_lib.time_me
+async def get_status_from_cluster_name_async(
+        cluster_name: str) -> Optional[status_lib.ClusterStatus]:
+    """Get the status of a cluster."""
+    assert _SQLALCHEMY_ENGINE_ASYNC is not None
     assert cluster_name is not None, 'cluster_name cannot be None'
-    rows = _DB.cursor.execute('SELECT name FROM clusters WHERE name GLOB (?)',
-                              (cluster_name,))
-    return [row[0] for row in rows]
+    async with sql_async.AsyncSession(_SQLALCHEMY_ENGINE_ASYNC) as session:
+        result = await session.execute(
+            sqlalchemy.select(cluster_table.c.status).where(
+                cluster_table.c.name == cluster_name))
+        row = result.first()
+
+    if row is None:
+        return None
+    return status_lib.ClusterStatus(row[0])
 
 
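The async variant above goes through a separate async engine and `AsyncSession`; the query itself is plain Core `select()`, and only the execution is awaited. A self-contained sketch using SQLAlchemy's asyncio extension with an in-memory aiosqlite engine (engine URL and table are illustrative; requires the `aiosqlite` package):

    import asyncio

    import sqlalchemy
    from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine

    metadata = sqlalchemy.MetaData()
    clusters = sqlalchemy.Table(
        'clusters', metadata,
        sqlalchemy.Column('name', sqlalchemy.Text, primary_key=True),
        sqlalchemy.Column('status', sqlalchemy.Text))


    async def get_status(engine, name: str):
        async with AsyncSession(engine) as session:
            result = await session.execute(
                sqlalchemy.select(clusters.c.status).where(
                    clusters.c.name == name))
            row = result.first()
        return None if row is None else row[0]


    async def main():
        engine = create_async_engine('sqlite+aiosqlite://')  # in-memory
        async with engine.begin() as conn:
            await conn.run_sync(metadata.create_all)
        print(await get_status(engine, 'missing'))  # -> None

    asyncio.run(main())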
+@_init_db
+@metrics_lib.time_me
+def get_status_from_cluster_name(
+        cluster_name: str) -> Optional[status_lib.ClusterStatus]:
+    assert _SQLALCHEMY_ENGINE is not None
+    assert cluster_name is not None, 'cluster_name cannot be None'
+    with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        row = session.query(
+            cluster_table.c.status).filter_by(name=cluster_name).first()
+        if row is None:
+            return None
+        return status_lib.ClusterStatus[row.status]
+
+
+@_init_db
+@metrics_lib.time_me
+def get_glob_cluster_names(
+        cluster_name: str,
+        workspaces_filter: Optional[Set[str]] = None) -> List[str]:
+    assert _SQLALCHEMY_ENGINE is not None
+    assert cluster_name is not None, 'cluster_name cannot be None'
+    with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        if (_SQLALCHEMY_ENGINE.dialect.name ==
+                db_utils.SQLAlchemyDialect.SQLITE.value):
+            query = session.query(cluster_table.c.name).filter(
+                cluster_table.c.name.op('GLOB')(cluster_name))
+        elif (_SQLALCHEMY_ENGINE.dialect.name ==
+              db_utils.SQLAlchemyDialect.POSTGRESQL.value):
+            query = session.query(cluster_table.c.name).filter(
+                cluster_table.c.name.op('SIMILAR TO')(
+                    _glob_to_similar(cluster_name)))
+        else:
+            raise ValueError('Unsupported database dialect')
+        if workspaces_filter is not None:
+            query = query.filter(
+                cluster_table.c.workspace.in_(workspaces_filter))
+        rows = query.all()
+        return [row.name for row in rows]
+
+
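SQLite evaluates `GLOB` natively, while PostgreSQL gets the pattern rewritten by `_glob_to_similar` (defined elsewhere in this module) into `SIMILAR TO` syntax. Purely to illustrate the idea, a plausible minimal rewriter for simple patterns might look like the sketch below; this is a hypothetical stand-in, not the actual `_glob_to_similar` implementation, and it ignores character classes and the other regex metacharacters `SIMILAR TO` recognizes:

    def glob_to_similar(pattern: str) -> str:
        """Hypothetical GLOB -> SIMILAR TO rewriter for simple patterns."""
        out = []
        for ch in pattern:
            if ch == '*':
                out.append('%')  # any run of characters
            elif ch == '?':
                out.append('_')  # exactly one character
            elif ch in ('%', '_'):
                out.append('\\' + ch)  # escape SIMILAR TO metacharacters
            else:
                out.append(ch)
        return ''.join(out)


    assert glob_to_similar('sky-*') == 'sky-%'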
+@_init_db
+@metrics_lib.time_me
 def set_cluster_status(cluster_name: str,
                        status: status_lib.ClusterStatus) -> None:
+    assert _SQLALCHEMY_ENGINE is not None
     current_time = int(time.time())
-    _DB.cursor.execute(
-        'UPDATE clusters SET status=(?), status_updated_at=(?) WHERE name=(?)',
-        (status.value, current_time, cluster_name))
-    count = _DB.cursor.rowcount
-    _DB.conn.commit()
+    with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        count = session.query(cluster_table).filter_by(
+            name=cluster_name).update({
+                cluster_table.c.status: status.value,
+                cluster_table.c.status_updated_at: current_time
+            })
+        session.commit()
     assert count <= 1, count
     if count == 0:
         raise ValueError(f'Cluster {cluster_name} not found.')
 
 
+@_init_db
+@metrics_lib.time_me
 def set_cluster_autostop_value(cluster_name: str, idle_minutes: int,
                                to_down: bool) -> None:
-    _DB.cursor.execute(
-        'UPDATE clusters SET autostop=(?), to_down=(?) WHERE name=(?)', (
-            idle_minutes,
-            int(to_down),
-            cluster_name,
-        ))
-    count = _DB.cursor.rowcount
-    _DB.conn.commit()
+    assert _SQLALCHEMY_ENGINE is not None
+    with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        count = session.query(cluster_table).filter_by(
+            name=cluster_name).update({
+                cluster_table.c.autostop: idle_minutes,
+                cluster_table.c.to_down: int(to_down)
+            })
+        session.commit()
     assert count <= 1, count
     if count == 0:
         raise ValueError(f'Cluster {cluster_name} not found.')
 
 
+@_init_db
+@metrics_lib.time_me
 def get_cluster_launch_time(cluster_name: str) -> Optional[int]:
-    rows = _DB.cursor.execute('SELECT launched_at FROM clusters WHERE name=(?)',
-                              (cluster_name,))
-    for (launch_time,) in rows:
-        if launch_time is None:
-            return None
-        return int(launch_time)
-    return None
+    assert _SQLALCHEMY_ENGINE is not None
+    with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        row = session.query(
+            cluster_table.c.launched_at).filter_by(name=cluster_name).first()
+        if row is None or row.launched_at is None:
+            return None
+        return int(row.launched_at)
 
 
+@_init_db
+@metrics_lib.time_me
 def get_cluster_info(cluster_name: str) -> Optional[Dict[str, Any]]:
-    rows = _DB.cursor.execute('SELECT metadata FROM clusters WHERE name=(?)',
-                              (cluster_name,))
-    for (metadata,) in rows:
-        if metadata is None:
+    assert _SQLALCHEMY_ENGINE is not None
+    with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        row = session.query(
+            cluster_table.c.metadata).filter_by(name=cluster_name).first()
+        if row is None or row.metadata is None:
+            return None
+        return json.loads(row.metadata)
+
+
+@_init_db
+@metrics_lib.time_me
+def get_cluster_provision_log_path(cluster_name: str) -> Optional[str]:
+    """Returns provision_log_path from clusters table, if recorded."""
+    assert _SQLALCHEMY_ENGINE is not None
+    with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        row = session.query(cluster_table).filter_by(name=cluster_name).first()
+        if row is None:
+            return None
+        return getattr(row, 'provision_log_path', None)
+
+
+@_init_db
+@metrics_lib.time_me
+def get_cluster_history_provision_log_path(cluster_name: str) -> Optional[str]:
+    """Returns provision_log_path from cluster_history for this name.
+
+    If the cluster currently exists, we use its hash. Otherwise, we look up
+    historical rows by name and choose the most recent one based on
+    usage_intervals.
+    """
+    assert _SQLALCHEMY_ENGINE is not None
+    with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        # Try current cluster first (fast path)
+        cluster_hash = _get_hash_for_existing_cluster(cluster_name)
+        if cluster_hash is not None:
+            row = session.query(cluster_history_table).filter_by(
+                cluster_hash=cluster_hash).first()
+            if row is not None:
+                return getattr(row, 'provision_log_path', None)
+
+        # Fallback: search history by name and pick the latest by
+        # usage_intervals
+        rows = session.query(cluster_history_table).filter_by(
+            name=cluster_name).all()
+        if not rows:
             return None
-        return json.loads(metadata)
-    return None
 
+        def latest_timestamp(usages_bin) -> int:
+            try:
+                intervals = pickle.loads(usages_bin)
+                # intervals: List[Tuple[int, Optional[int]]]
+                if not intervals:
+                    return -1
+                _, end = intervals[-1]
+                return end if end is not None else int(time.time())
+            except Exception:  # pylint: disable=broad-except
+                return -1
+
+        latest_row = max(rows,
+                         key=lambda r: latest_timestamp(r.usage_intervals))
+        return getattr(latest_row, 'provision_log_path', None)
 
+
+@_init_db
+@metrics_lib.time_me
 def set_cluster_info(cluster_name: str, metadata: Dict[str, Any]) -> None:
-    _DB.cursor.execute('UPDATE clusters SET metadata=(?) WHERE name=(?)', (
-        json.dumps(metadata),
-        cluster_name,
-    ))
-    count = _DB.cursor.rowcount
-    _DB.conn.commit()
+    assert _SQLALCHEMY_ENGINE is not None
+    with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        count = session.query(cluster_table).filter_by(
+            name=cluster_name).update(
+                {cluster_table.c.metadata: json.dumps(metadata)})
+        session.commit()
     assert count <= 1, count
     if count == 0:
         raise ValueError(f'Cluster {cluster_name} not found.')
 
 
+@_init_db
+@metrics_lib.time_me
 def get_cluster_storage_mounts_metadata(
         cluster_name: str) -> Optional[Dict[str, Any]]:
-    rows = _DB.cursor.execute(
-        'SELECT storage_mounts_metadata FROM clusters WHERE name=(?)',
-        (cluster_name,))
-    for (storage_mounts_metadata,) in rows:
-        if storage_mounts_metadata is None:
-            return None
-        return pickle.loads(storage_mounts_metadata)
-    return None
+    assert _SQLALCHEMY_ENGINE is not None
+    with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        row = (session.query(cluster_table.c.storage_mounts_metadata).filter_by(
+            name=cluster_name).first())
+        if row is None or row.storage_mounts_metadata is None:
+            return None
+        return pickle.loads(row.storage_mounts_metadata)
 
 
+@_init_db
+@metrics_lib.time_me
 def set_cluster_storage_mounts_metadata(
         cluster_name: str, storage_mounts_metadata: Dict[str, Any]) -> None:
-    _DB.cursor.execute(
-        'UPDATE clusters SET storage_mounts_metadata=(?) WHERE name=(?)', (
-            pickle.dumps(storage_mounts_metadata),
-            cluster_name,
-        ))
-    count = _DB.cursor.rowcount
-    _DB.conn.commit()
+    assert _SQLALCHEMY_ENGINE is not None
+    with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        count = session.query(cluster_table).filter_by(
+            name=cluster_name).update({
+                cluster_table.c.storage_mounts_metadata:
+                    pickle.dumps(storage_mounts_metadata)
+            })
+        session.commit()
+    assert count <= 1, count
+    if count == 0:
+        raise ValueError(f'Cluster {cluster_name} not found.')
+
+
+@_init_db
+@metrics_lib.time_me
+def get_cluster_skylet_ssh_tunnel_metadata(
+        cluster_name: str) -> Optional[Tuple[int, int]]:
+    assert _SQLALCHEMY_ENGINE is not None
+    with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        row = session.query(
+            cluster_table.c.skylet_ssh_tunnel_metadata).filter_by(
+                name=cluster_name).first()
+        if row is None or row.skylet_ssh_tunnel_metadata is None:
+            return None
+        return pickle.loads(row.skylet_ssh_tunnel_metadata)
+
+
+@_init_db
+@metrics_lib.time_me
+def set_cluster_skylet_ssh_tunnel_metadata(
+        cluster_name: str,
+        skylet_ssh_tunnel_metadata: Optional[Tuple[int, int]]) -> None:
+    assert _SQLALCHEMY_ENGINE is not None
+    with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        value = pickle.dumps(
+            skylet_ssh_tunnel_metadata
+        ) if skylet_ssh_tunnel_metadata is not None else None
+        count = session.query(cluster_table).filter_by(
+            name=cluster_name).update(
+                {cluster_table.c.skylet_ssh_tunnel_metadata: value})
+        session.commit()
     assert count <= 1, count
     if count == 0:
         raise ValueError(f'Cluster {cluster_name} not found.')
 
 
+@_init_db
+@metrics_lib.time_me
 def _get_cluster_usage_intervals(
     cluster_hash: Optional[str]
 ) -> Optional[List[Tuple[int, Optional[int]]]]:
+    assert _SQLALCHEMY_ENGINE is not None
     if cluster_hash is None:
         return None
-    rows = _DB.cursor.execute(
-        'SELECT usage_intervals FROM cluster_history WHERE cluster_hash=(?)',
-        (cluster_hash,))
-    for (usage_intervals,) in rows:
-        if usage_intervals is None:
-            return None
-        return pickle.loads(usage_intervals)
-    return None
+    with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        row = session.query(cluster_history_table.c.usage_intervals).filter_by(
+            cluster_hash=cluster_hash).first()
+        if row is None or row.usage_intervals is None:
+            return None
+        return pickle.loads(row.usage_intervals)
 
 
-def _get_cluster_launch_time(cluster_hash: str) -> Optional[int]:
-    usage_intervals = _get_cluster_usage_intervals(cluster_hash)
+def _get_cluster_launch_time(
+        usage_intervals: Optional[List[Tuple[int,
+                                             Optional[int]]]]) -> Optional[int]:
     if usage_intervals is None:
         return None
     return usage_intervals[0][0]
 
 
-def _get_cluster_duration(cluster_hash: str) -> int:
+def _get_cluster_duration(
+        usage_intervals: Optional[List[Tuple[int, Optional[int]]]]) -> int:
     total_duration = 0
-    usage_intervals = _get_cluster_usage_intervals(cluster_hash)
 
     if usage_intervals is None:
         return total_duration
@@ -587,60 +1473,89 @@ def _get_cluster_duration(cluster_hash: str) -> int:
     return total_duration
 
 
+def _get_cluster_last_activity_time(
+        usage_intervals: Optional[List[Tuple[int,
+                                             Optional[int]]]]) -> Optional[int]:
+    last_activity_time = None
+    if usage_intervals:
+        last_interval = usage_intervals[-1]
+        last_activity_time = (last_interval[1] if last_interval[1] is not None
+                              else last_interval[0])
+    return last_activity_time
+
+
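All three helpers derive from the same `usage_intervals` encoding: a list of `(start, end)` pairs of Unix timestamps, where `end is None` marks the interval that is still open. A worked example with illustrative timestamps (the duration sum reflects the documented semantics; `_get_cluster_duration`'s body is elided in this hunk, so treat it as a sketch):

    import time

    # Two sessions: one closed (1000 seconds long), one open since t=5000.
    usage_intervals = [(1000, 2000), (5000, None)]

    launch_time = usage_intervals[0][0]  # 1000: the first-ever start

    last = usage_intervals[-1]
    # Open interval: fall back to its start time, matching
    # _get_cluster_last_activity_time above.
    last_activity = last[1] if last[1] is not None else last[0]  # 5000

    # Presumed duration semantics: sum every interval, counting an open
    # one up to "now".
    duration = sum(
        (end if end is not None else int(time.time())) - start
        for start, end in usage_intervals)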
+@_init_db
+@metrics_lib.time_me
 def _set_cluster_usage_intervals(
         cluster_hash: str, usage_intervals: List[Tuple[int,
                                                        Optional[int]]]) -> None:
-    _DB.cursor.execute(
-        'UPDATE cluster_history SET usage_intervals=(?) WHERE cluster_hash=(?)',
-        (
-            pickle.dumps(usage_intervals),
-            cluster_hash,
-        ))
-
-    count = _DB.cursor.rowcount
-    _DB.conn.commit()
+    assert _SQLALCHEMY_ENGINE is not None
+
+    # Calculate last_activity_time from usage_intervals
+    last_activity_time = _get_cluster_last_activity_time(usage_intervals)
+
+    with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        count = session.query(cluster_history_table).filter_by(
+            cluster_hash=cluster_hash).update({
+                cluster_history_table.c.usage_intervals:
+                    pickle.dumps(usage_intervals),
+                cluster_history_table.c.last_activity_time: last_activity_time,
+            })
+        session.commit()
     assert count <= 1, count
     if count == 0:
         raise ValueError(f'Cluster hash {cluster_hash} not found.')
 
 
+@_init_db
+@metrics_lib.time_me
 def set_owner_identity_for_cluster(cluster_name: str,
                                    owner_identity: Optional[List[str]]) -> None:
+    assert _SQLALCHEMY_ENGINE is not None
     if owner_identity is None:
         return
     owner_identity_str = json.dumps(owner_identity)
-    _DB.cursor.execute('UPDATE clusters SET owner=(?) WHERE name=(?)',
-                       (owner_identity_str, cluster_name))
-
-    count = _DB.cursor.rowcount
-    _DB.conn.commit()
+    with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        count = session.query(cluster_table).filter_by(
+            name=cluster_name).update(
+                {cluster_table.c.owner: owner_identity_str})
+        session.commit()
     assert count <= 1, count
     if count == 0:
         raise ValueError(f'Cluster {cluster_name} not found.')
 
 
+@_init_db
+@metrics_lib.time_me
 def _get_hash_for_existing_cluster(cluster_name: str) -> Optional[str]:
-    rows = _DB.cursor.execute(
-        'SELECT cluster_hash FROM clusters WHERE name=(?)', (cluster_name,))
-    for (cluster_hash,) in rows:
-        if cluster_hash is None:
-            return None
-        return cluster_hash
-    return None
+    assert _SQLALCHEMY_ENGINE is not None
+    with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        row = (session.query(
+            cluster_table.c.cluster_hash).filter_by(name=cluster_name).first())
+        if row is None or row.cluster_hash is None:
+            return None
+        return row.cluster_hash
 
 
+@_init_db
+@metrics_lib.time_me
 def get_launched_resources_from_cluster_hash(
         cluster_hash: str) -> Optional[Tuple[int, Any]]:
+    assert _SQLALCHEMY_ENGINE is not None
+    with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        row = session.query(
+            cluster_history_table.c.num_nodes,
+            cluster_history_table.c.launched_resources).filter_by(
+                cluster_hash=cluster_hash).first()
+        if row is None:
+            return None
+        num_nodes = row.num_nodes
+        launched_resources = row.launched_resources
 
-    rows = _DB.cursor.execute(
-        'SELECT num_nodes, launched_resources '
-        'FROM cluster_history WHERE cluster_hash=(?)', (cluster_hash,))
-    for (num_nodes, launched_resources) in rows:
-        if num_nodes is None or launched_resources is None:
-            return None
-        launched_resources = pickle.loads(launched_resources)
-        return num_nodes, launched_resources
-    return None
+    if num_nodes is None or launched_resources is None:
+        return None
+    launched_resources = pickle.loads(launched_resources)
+    return num_nodes, launched_resources
 
 
 def _load_owner(record_owner: Optional[str]) -> Optional[List[str]]:
@@ -671,176 +1586,491 @@ def _load_storage_mounts_metadata(
     return pickle.loads(record_storage_mounts_metadata)
 
 
+@_init_db
+@metrics_lib.time_me
+@context_utils.cancellation_guard
 def get_cluster_from_name(
-        cluster_name: Optional[str]) -> Optional[Dict[str, Any]]:
-    rows = _DB.cursor.execute(
-        'SELECT name, launched_at, handle, last_use, status, autostop, '
-        'metadata, to_down, owner, cluster_hash, storage_mounts_metadata, '
-        'cluster_ever_up, status_updated_at, config_hash, user_hash '
-        'FROM clusters WHERE name=(?)', (cluster_name,)).fetchall()
-    for row in rows:
-        # Explicitly specify the number of fields to unpack, so that
-        # we can add new fields to the database in the future without
-        # breaking the previous code.
-        (name, launched_at, handle, last_use, status, autostop, metadata,
-         to_down, owner, cluster_hash, storage_mounts_metadata, cluster_ever_up,
-         status_updated_at, config_hash, user_hash) = row
-        user_hash = _get_user_hash_or_current_user(user_hash)
-        # TODO: use namedtuple instead of dict
-        record = {
-            'name': name,
-            'launched_at': launched_at,
-            'handle': pickle.loads(handle),
-            'last_use': last_use,
-            'status': status_lib.ClusterStatus[status],
-            'autostop': autostop,
-            'to_down': bool(to_down),
-            'owner': _load_owner(owner),
-            'metadata': json.loads(metadata),
-            'cluster_hash': cluster_hash,
-            'storage_mounts_metadata':
-                _load_storage_mounts_metadata(storage_mounts_metadata),
-            'cluster_ever_up': bool(cluster_ever_up),
-            'status_updated_at': status_updated_at,
-            'user_hash': user_hash,
-            'user_name': get_user(user_hash).name,
-            'config_hash': config_hash,
-        }
-        return record
-    return None
-
+        cluster_name: Optional[str],
+        *,
+        include_user_info: bool = True,
+        summary_response: bool = False) -> Optional[Dict[str, Any]]:
+    assert _SQLALCHEMY_ENGINE is not None
+    query_fields = [
+        cluster_table.c.name,
+        cluster_table.c.launched_at,
+        cluster_table.c.handle,
+        cluster_table.c.last_use,
+        cluster_table.c.status,
+        cluster_table.c.autostop,
+        cluster_table.c.to_down,
+        cluster_table.c.owner,
+        cluster_table.c.metadata,
+        cluster_table.c.cluster_hash,
+        cluster_table.c.cluster_ever_up,
+        cluster_table.c.status_updated_at,
+        cluster_table.c.user_hash,
+        cluster_table.c.config_hash,
+        cluster_table.c.workspace,
+        cluster_table.c.is_managed,
+    ]
+    if not summary_response:
+        query_fields.extend([
+            cluster_table.c.last_creation_yaml,
+            cluster_table.c.last_creation_command,
+        ])
+    with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        query = session.query(*query_fields)
+        row = query.filter_by(name=cluster_name).first()
+    if row is None:
+        return None
+    if include_user_info:
+        user_hash = _get_user_hash_or_current_user(row.user_hash)
+        user = get_user(user_hash)
+        user_name = user.name if user is not None else None
+    if not summary_response:
+        last_event = get_last_cluster_event(
+            row.cluster_hash, event_type=ClusterEventType.STATUS_CHANGE)
+    # TODO: use namedtuple instead of dict
+    record = {
+        'name': row.name,
+        'launched_at': row.launched_at,
+        'handle': pickle.loads(row.handle),
+        'last_use': row.last_use,
+        'status': status_lib.ClusterStatus[row.status],
+        'autostop': row.autostop,
+        'to_down': bool(row.to_down),
+        'owner': _load_owner(row.owner),
+        'metadata': json.loads(row.metadata),
+        'cluster_hash': row.cluster_hash,
+        'cluster_ever_up': bool(row.cluster_ever_up),
+        'status_updated_at': row.status_updated_at,
+        'workspace': row.workspace,
+        'is_managed': bool(row.is_managed),
+        'config_hash': row.config_hash,
+    }
+    if not summary_response:
+        record['last_creation_yaml'] = row.last_creation_yaml
+        record['last_creation_command'] = row.last_creation_command
+        record['last_event'] = last_event
+    if include_user_info:
+        record['user_hash'] = user_hash
+        record['user_name'] = user_name
+
+    return record
+
+
+@_init_db
+@metrics_lib.time_me
+@context_utils.cancellation_guard
+def cluster_with_name_exists(cluster_name: str) -> bool:
+    assert _SQLALCHEMY_ENGINE is not None
+    with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        row = session.query(
+            cluster_table.c.name).filter_by(name=cluster_name).first()
+        if row is None:
+            return False
+        return True
+
+
+@_init_db
+@metrics_lib.time_me
+def get_clusters(
+    *,  # keyword only separator
+    exclude_managed_clusters: bool = False,
+    workspaces_filter: Optional[Dict[str, Any]] = None,
+    user_hashes_filter: Optional[Set[str]] = None,
+    cluster_names: Optional[List[str]] = None,
+    summary_response: bool = False,
+) -> List[Dict[str, Any]]:
+    """Get clusters from the database.
 
-def get_clusters() -> List[Dict[str, Any]]:
-    rows = _DB.cursor.execute(
-        'select name, launched_at, handle, last_use, status, autostop, '
-        'metadata, to_down, owner, cluster_hash, storage_mounts_metadata, '
-        'cluster_ever_up, status_updated_at, config_hash, user_hash '
-        'from clusters order by launched_at desc').fetchall()
+    Args:
+        exclude_managed_clusters: If True, exclude clusters whose
+            is_managed field is set to True.
+        workspaces_filter: If specified, only include clusters whose
+            workspace field is set to one of the given values.
+        user_hashes_filter: If specified, only include clusters whose
+            user_hash field is set to one of the given values.
+        cluster_names: If specified, only include clusters whose
+            name field is set to one of the given values.
+    """
+    # If a cluster has a null user_hash,
+    # we treat it as belonging to the current user.
+    current_user_hash = common_utils.get_user_hash()
+    assert _SQLALCHEMY_ENGINE is not None
+    query_fields = [
+        cluster_table.c.name,
+        cluster_table.c.launched_at,
+        cluster_table.c.handle,
+        cluster_table.c.status,
+        cluster_table.c.autostop,
+        cluster_table.c.to_down,
+        cluster_table.c.cluster_hash,
+        cluster_table.c.cluster_ever_up,
+        cluster_table.c.user_hash,
+        cluster_table.c.workspace,
+        user_table.c.name.label('user_name'),
+    ]
+    if not summary_response:
+        query_fields.extend([
+            cluster_table.c.last_creation_yaml,
+            cluster_table.c.last_creation_command,
+            cluster_table.c.config_hash,
+            cluster_table.c.owner,
+            cluster_table.c.metadata,
+            cluster_table.c.last_use,
+            cluster_table.c.status_updated_at,
+        ])
+    if not exclude_managed_clusters:
+        query_fields.append(cluster_table.c.is_managed)
+    with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        query = session.query(*query_fields).outerjoin(
+            user_table, cluster_table.c.user_hash == user_table.c.id)
+        if exclude_managed_clusters:
+            query = query.filter(cluster_table.c.is_managed == int(False))
+        if workspaces_filter is not None:
+            query = query.filter(
+                cluster_table.c.workspace.in_(workspaces_filter))
+        if user_hashes_filter is not None:
+            if current_user_hash in user_hashes_filter:
+                # backwards compatibility for old clusters.
+                # If current_user_hash is in user_hashes_filter, we include
+                # clusters that have a null user_hash.
+                query = query.filter(
+                    (cluster_table.c.user_hash.in_(user_hashes_filter) |
+                     (cluster_table.c.user_hash.is_(None))))
+            else:
+                query = query.filter(
+                    cluster_table.c.user_hash.in_(user_hashes_filter))
+        if cluster_names is not None:
+            query = query.filter(cluster_table.c.name.in_(cluster_names))
+        query = query.order_by(sqlalchemy.desc(cluster_table.c.launched_at))
+        rows = query.all()
     records = []
+
+    # Check if we need to fetch the current user's name,
+    # for backwards compatibility, if user_hash is None.
+    current_user_name = None
+    needs_current_user = any(row.user_hash is None for row in rows)
+    if needs_current_user:
+        current_user = get_user(current_user_hash)
+        current_user_name = (current_user.name
+                             if current_user is not None else None)
+
+    # get last cluster event for each row
+    if not summary_response:
+        cluster_hashes = {row.cluster_hash for row in rows}
+        last_cluster_event_dict = _get_last_cluster_event_multiple(
+            cluster_hashes, ClusterEventType.STATUS_CHANGE)
+
     for row in rows:
-        (name, launched_at, handle, last_use, status, autostop, metadata,
-         to_down, owner, cluster_hash, storage_mounts_metadata, cluster_ever_up,
-         status_updated_at, config_hash, user_hash) = row
-        user_hash = _get_user_hash_or_current_user(user_hash)
         # TODO: use namedtuple instead of dict
         record = {
-            'name': name,
-            'launched_at': launched_at,
-            'handle': pickle.loads(handle),
-            'last_use': last_use,
-            'status': status_lib.ClusterStatus[status],
-            'autostop': autostop,
-            'to_down': bool(to_down),
-            'owner': _load_owner(owner),
-            'metadata': json.loads(metadata),
-            'cluster_hash': cluster_hash,
-            'storage_mounts_metadata':
-                _load_storage_mounts_metadata(storage_mounts_metadata),
-            'cluster_ever_up': bool(cluster_ever_up),
-            'status_updated_at': status_updated_at,
-            'user_hash': user_hash,
-            'user_name': get_user(user_hash).name,
-            'config_hash': config_hash,
+            'name': row.name,
+            'launched_at': row.launched_at,
+            'handle': pickle.loads(row.handle),
+            'status': status_lib.ClusterStatus[row.status],
+            'autostop': row.autostop,
+            'to_down': bool(row.to_down),
+            'cluster_hash': row.cluster_hash,
+            'cluster_ever_up': bool(row.cluster_ever_up),
+            'user_hash': (row.user_hash
+                          if row.user_hash is not None else current_user_hash),
+            'user_name': (row.user_name
+                          if row.user_name is not None else current_user_name),
+            'workspace': row.workspace,
+            'is_managed': False
+                          if exclude_managed_clusters else bool(row.is_managed),
         }
+        if not summary_response:
+            record['last_creation_yaml'] = row.last_creation_yaml
+            record['last_creation_command'] = row.last_creation_command
+            record['last_event'] = last_cluster_event_dict.get(
+                row.cluster_hash, None)
+            record['config_hash'] = row.config_hash
+            record['owner'] = _load_owner(row.owner)
+            record['metadata'] = json.loads(row.metadata)
+            record['last_use'] = row.last_use
+            record['status_updated_at'] = row.status_updated_at
 
         records.append(record)
     return records
 
 
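`get_clusters` resolves user names in SQL rather than per-row in Python: an outer join against the users table with `.label('user_name')`, so clusters with an unknown or null `user_hash` still come back, just with `user_name` as NULL. A compact sketch of the same join shape (tables are illustrative, not SkyPilot's schema):

    import sqlalchemy

    metadata = sqlalchemy.MetaData()
    users = sqlalchemy.Table(
        'users', metadata,
        sqlalchemy.Column('id', sqlalchemy.Text, primary_key=True),
        sqlalchemy.Column('name', sqlalchemy.Text))
    clusters = sqlalchemy.Table(
        'clusters', metadata,
        sqlalchemy.Column('name', sqlalchemy.Text, primary_key=True),
        sqlalchemy.Column('user_hash', sqlalchemy.Text))

    # LEFT OUTER JOIN: clusters without a matching user row still appear,
    # with user_name coming back as None.
    stmt = sqlalchemy.select(
        clusters.c.name,
        users.c.name.label('user_name')).select_from(
            clusters.outerjoin(users, clusters.c.user_hash == users.c.id))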
750
- def get_clusters_from_history() -> List[Dict[str, Any]]:
751
- rows = _DB.cursor.execute(
752
- 'SELECT ch.cluster_hash, ch.name, ch.num_nodes, '
753
- 'ch.launched_resources, ch.usage_intervals, clusters.status, '
754
- 'ch.user_hash '
755
- 'FROM cluster_history ch '
756
- 'LEFT OUTER JOIN clusters '
757
- 'ON ch.cluster_hash=clusters.cluster_hash ').fetchall()
1800
+ @_init_db
1801
+ @metrics_lib.time_me
1802
+ def get_cluster_names(exclude_managed_clusters: bool = False,) -> List[str]:
1803
+ assert _SQLALCHEMY_ENGINE is not None
1804
+ with orm.Session(_SQLALCHEMY_ENGINE) as session:
1805
+ query = session.query(cluster_table.c.name)
1806
+ if exclude_managed_clusters:
1807
+ query = query.filter(cluster_table.c.is_managed == int(False))
1808
+ rows = query.all()
1809
+ return [row[0] for row in rows]
758
1810
 
759
-    # '(cluster_hash, name, num_nodes, requested_resources, '
-    # 'launched_resources, usage_intervals) '
-    records = []
 
-    for row in rows:
-        # TODO: use namedtuple instead of dict
+@_init_db
+@metrics_lib.time_me
+def get_clusters_from_history(
+        days: Optional[int] = None,
+        abbreviate_response: bool = False,
+        cluster_hashes: Optional[List[str]] = None) -> List[Dict[str, Any]]:
+    """Get cluster reports from history.
 
-        (
-            cluster_hash,
-            name,
-            num_nodes,
-            launched_resources,
-            usage_intervals,
-            status,
-            user_hash,
-        ) = row[:7]
-        user_hash = _get_user_hash_or_current_user(user_hash)
+    Args:
+        days: If specified, only include historical clusters (those not
+            currently active) that were last used within the past 'days'
+            days. Active clusters are always included regardless of this
+            parameter.
 
-        if status is not None:
-            status = status_lib.ClusterStatus[status]
+    Returns:
+        List of cluster records with history information.
+    """
+    assert _SQLALCHEMY_ENGINE is not None
+
+    current_user_hash = common_utils.get_user_hash()
+
+    # Prepare filtering parameters
+    cutoff_time = 0
+    if days is not None:
+        cutoff_time = int(time.time()) - (days * 24 * 60 * 60)
+
+    with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        # Explicitly select columns from both tables to avoid ambiguity
+        if abbreviate_response:
+            query = session.query(
+                cluster_history_table.c.cluster_hash,
+                cluster_history_table.c.name, cluster_history_table.c.num_nodes,
+                cluster_history_table.c.launched_resources,
+                cluster_history_table.c.usage_intervals,
+                cluster_history_table.c.user_hash,
+                cluster_history_table.c.workspace.label('history_workspace'),
+                cluster_history_table.c.last_activity_time,
+                cluster_history_table.c.launched_at, cluster_table.c.status,
+                cluster_table.c.workspace)
+        else:
+            query = session.query(
+                cluster_history_table.c.cluster_hash,
+                cluster_history_table.c.name, cluster_history_table.c.num_nodes,
+                cluster_history_table.c.launched_resources,
+                cluster_history_table.c.usage_intervals,
+                cluster_history_table.c.user_hash,
+                cluster_history_table.c.last_creation_yaml,
+                cluster_history_table.c.last_creation_command,
+                cluster_history_table.c.workspace.label('history_workspace'),
+                cluster_history_table.c.last_activity_time,
+                cluster_history_table.c.launched_at, cluster_table.c.status,
+                cluster_table.c.workspace)
+
+        query = query.select_from(
+            cluster_history_table.join(cluster_table,
+                                       cluster_history_table.c.cluster_hash ==
+                                       cluster_table.c.cluster_hash,
+                                       isouter=True))
+
+        # Only include clusters that are either active (status is not None)
+        # or are within the cutoff time (cutoff_time <= last_activity_time).
+        # If days is not specified, we include all clusters by setting
+        # cutoff_time to 0.
+        query = query.filter(
+            (cluster_table.c.status.isnot(None) |
+             (cluster_history_table.c.last_activity_time >= cutoff_time)))
+
+        # Order by launched_at descending (most recent first)
+        query = query.order_by(
+            sqlalchemy.desc(cluster_history_table.c.launched_at))
+
+        if cluster_hashes is not None:
+            query = query.filter(
+                cluster_history_table.c.cluster_hash.in_(cluster_hashes))
+        rows = query.all()
+
+    usage_intervals_dict = {}
+    row_to_user_hash = {}
+    for row in rows:
+        row_usage_intervals: List[Tuple[int, Optional[int]]] = []
+        if row.usage_intervals:
+            try:
+                row_usage_intervals = pickle.loads(row.usage_intervals)
+            except (pickle.PickleError, AttributeError):
+                pass
+        usage_intervals_dict[row.cluster_hash] = row_usage_intervals
+        user_hash = (row.user_hash
+                     if row.user_hash is not None else current_user_hash)
+        row_to_user_hash[row.cluster_hash] = user_hash
+
+    user_hashes = set(row_to_user_hash.values())
+    user_hash_to_user = get_users(user_hashes)
+    cluster_hashes = set(row_to_user_hash.keys())
+    if not abbreviate_response:
+        last_cluster_event_dict = _get_last_cluster_event_multiple(
+            cluster_hashes, ClusterEventType.STATUS_CHANGE)
+
+    records = []
+    for row in rows:
+        user_hash = row_to_user_hash[row.cluster_hash]
+        user = user_hash_to_user.get(user_hash, None)
+        user_name = user.name if user is not None else None
+        if not abbreviate_response:
+            last_event = last_cluster_event_dict.get(row.cluster_hash, None)
+        launched_at = row.launched_at
+        usage_intervals: Optional[List[Tuple[
+            int,
+            Optional[int]]]] = usage_intervals_dict.get(row.cluster_hash, None)
+        duration = _get_cluster_duration(usage_intervals)
+
+        # Parse status
+        status = None
+        if row.status:
+            status = status_lib.ClusterStatus[row.status]
+
+        # Parse launched resources safely
+        launched_resources = None
+        if row.launched_resources:
+            try:
+                launched_resources = pickle.loads(row.launched_resources)
+            except (pickle.PickleError, AttributeError):
+                launched_resources = None
+
+        workspace = (row.history_workspace
+                     if row.history_workspace else row.workspace)
 
         record = {
-            'name': name,
-            'launched_at': _get_cluster_launch_time(cluster_hash),
-            'duration': _get_cluster_duration(cluster_hash),
-            'num_nodes': num_nodes,
-            'resources': pickle.loads(launched_resources),
-            'cluster_hash': cluster_hash,
-            'usage_intervals': pickle.loads(usage_intervals),
+            'name': row.name,
+            'launched_at': launched_at,
+            'duration': duration,
+            'num_nodes': row.num_nodes,
+            'resources': launched_resources,
+            'cluster_hash': row.cluster_hash,
+            'usage_intervals': usage_intervals,
             'status': status,
             'user_hash': user_hash,
+            'user_name': user_name,
+            'workspace': workspace,
         }
+        if not abbreviate_response:
+            record['last_creation_yaml'] = row.last_creation_yaml
+            record['last_creation_command'] = row.last_creation_command
+            record['last_event'] = last_event
 
         records.append(record)
 
     # sort by launch time, descending in recency
-    records = sorted(records, key=lambda record: -record['launched_at'])
+    records = sorted(records, key=lambda record: -(record['launched_at'] or 0))
     return records
 
 
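`usage_intervals` is a pickled list of `(start, end)` pairs, with `end` left as `None` for a still-running interval; `_get_cluster_duration` (defined elsewhere in this module) presumably sums them. A minimal re-derivation of that sum, for illustration only:

```python
import time
from typing import List, Optional, Tuple


def total_duration(
        usage_intervals: Optional[List[Tuple[int, Optional[int]]]]) -> int:
    """Sum interval lengths; an open interval runs until now."""
    if not usage_intervals:
        return 0
    total = 0
    for start, end in usage_intervals:
        # end=None means the cluster is still up.
        end = int(time.time()) if end is None else end
        total += end - start
    return total
```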
+@_init_db
+@metrics_lib.time_me
 def get_cluster_names_start_with(starts_with: str) -> List[str]:
-    rows = _DB.cursor.execute('SELECT name FROM clusters WHERE name LIKE (?)',
-                              (f'{starts_with}%',))
+    assert _SQLALCHEMY_ENGINE is not None
+    with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        rows = session.query(cluster_table.c.name).filter(
+            cluster_table.c.name.like(f'{starts_with}%')).all()
     return [row[0] for row in rows]
 
 
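Note that `LIKE f'{starts_with}%'` treats `%` and `_` inside the prefix as wildcards. Cluster names are normally sanitized, but if literal matching ever mattered, a caller could escape the prefix first; a hedged sketch (not something this diff does):

```python
def escape_like_prefix(prefix: str, escape_char: str = '\\') -> str:
    # Escape SQL LIKE metacharacters so the prefix matches literally.
    return (prefix.replace(escape_char, escape_char * 2)
            .replace('%', escape_char + '%')
            .replace('_', escape_char + '_'))


# e.g. cluster_table.c.name.like(escape_like_prefix(starts_with) + '%',
#                                escape='\\')
```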
-def get_cached_enabled_clouds(
-        cloud_capability: 'cloud.CloudCapability') -> List['clouds.Cloud']:
-
-    rows = _DB.cursor.execute('SELECT value FROM config WHERE key = ?',
-                              (_get_capability_key(cloud_capability),))
+@_init_db
+@metrics_lib.time_me
+def get_cached_enabled_clouds(cloud_capability: 'cloud.CloudCapability',
+                              workspace: str) -> List['clouds.Cloud']:
+    assert _SQLALCHEMY_ENGINE is not None
+    with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        row = session.query(config_table).filter_by(
+            key=_get_enabled_clouds_key(cloud_capability, workspace)).first()
     ret = []
-    for (value,) in rows:
-        ret = json.loads(value)
-        break
+    if row:
+        ret = json.loads(row.value)
     enabled_clouds: List['clouds.Cloud'] = []
     for c in ret:
         try:
             cloud = registry.CLOUD_REGISTRY.from_str(c)
         except ValueError:
-            # Handle the case for the clouds whose support has been removed from
-            # SkyPilot, e.g., 'local' was a cloud in the past and may be stored
-            # in the database for users before #3037. We should ignore removed
-            # clouds and continue.
+            # Handle the case for the clouds whose support has been
+            # removed from SkyPilot, e.g., 'local' was a cloud in the past
+            # and may be stored in the database for users before #3037.
+            # We should ignore removed clouds and continue.
             continue
         if cloud is not None:
             enabled_clouds.append(cloud)
     return enabled_clouds
 
 
+@_init_db
+@metrics_lib.time_me
 def set_enabled_clouds(enabled_clouds: List[str],
-                       cloud_capability: 'cloud.CloudCapability') -> None:
-    _DB.cursor.execute(
-        'INSERT OR REPLACE INTO config VALUES (?, ?)',
-        (_get_capability_key(cloud_capability), json.dumps(enabled_clouds)))
-    _DB.conn.commit()
-
-
-def _get_capability_key(cloud_capability: 'cloud.CloudCapability') -> str:
-    return _ENABLED_CLOUDS_KEY_PREFIX + cloud_capability.value
-
-
+                       cloud_capability: 'cloud.CloudCapability',
+                       workspace: str) -> None:
+    assert _SQLALCHEMY_ENGINE is not None
+    with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        if (_SQLALCHEMY_ENGINE.dialect.name ==
+                db_utils.SQLAlchemyDialect.SQLITE.value):
+            insert_func = sqlite.insert
+        elif (_SQLALCHEMY_ENGINE.dialect.name ==
+              db_utils.SQLAlchemyDialect.POSTGRESQL.value):
+            insert_func = postgresql.insert
+        else:
+            raise ValueError('Unsupported database dialect')
+        insert_stmnt = insert_func(config_table).values(
+            key=_get_enabled_clouds_key(cloud_capability, workspace),
+            value=json.dumps(enabled_clouds))
+        do_update_stmt = insert_stmnt.on_conflict_do_update(
+            index_elements=[config_table.c.key],
+            set_={config_table.c.value: json.dumps(enabled_clouds)})
+        session.execute(do_update_stmt)
+        session.commit()
+
+
+def _get_enabled_clouds_key(cloud_capability: 'cloud.CloudCapability',
+                            workspace: str) -> str:
+    return _ENABLED_CLOUDS_KEY_PREFIX + workspace + '_' + cloud_capability.value
+
+
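The `insert_func` branch above is the upsert idiom this file now uses throughout: pick the dialect-specific `insert` construct, then attach `on_conflict_do_update`. A condensed, standalone sketch of the same pattern (the table and engine names here are placeholders, not the module's):

```python
import sqlalchemy
from sqlalchemy.dialects import postgresql, sqlite


def upsert_config(engine: sqlalchemy.engine.Engine,
                  config_table: sqlalchemy.Table, key: str,
                  value: str) -> None:
    # Both dialects expose INSERT ... ON CONFLICT, but through different
    # dialect modules, hence the branch.
    if engine.dialect.name == 'sqlite':
        insert_func = sqlite.insert
    elif engine.dialect.name == 'postgresql':
        insert_func = postgresql.insert
    else:
        raise ValueError('Unsupported database dialect')
    stmt = insert_func(config_table).values(key=key, value=value)
    stmt = stmt.on_conflict_do_update(index_elements=[config_table.c.key],
                                      set_={config_table.c.value: value})
    with engine.begin() as conn:
        conn.execute(stmt)
```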
+@_init_db
+@metrics_lib.time_me
+def get_allowed_clouds(workspace: str) -> List[str]:
+    assert _SQLALCHEMY_ENGINE is not None
+    with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        row = session.query(config_table).filter_by(
+            key=_get_allowed_clouds_key(workspace)).first()
+        if row:
+            return json.loads(row.value)
+    return []
+
+
+@_init_db
+@metrics_lib.time_me
+def set_allowed_clouds(allowed_clouds: List[str], workspace: str) -> None:
+    assert _SQLALCHEMY_ENGINE is not None
+    with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        if (_SQLALCHEMY_ENGINE.dialect.name ==
+                db_utils.SQLAlchemyDialect.SQLITE.value):
+            insert_func = sqlite.insert
+        elif (_SQLALCHEMY_ENGINE.dialect.name ==
+              db_utils.SQLAlchemyDialect.POSTGRESQL.value):
+            insert_func = postgresql.insert
+        else:
+            raise ValueError('Unsupported database dialect')
+        insert_stmnt = insert_func(config_table).values(
+            key=_get_allowed_clouds_key(workspace),
+            value=json.dumps(allowed_clouds))
+        do_update_stmt = insert_stmnt.on_conflict_do_update(
+            index_elements=[config_table.c.key],
+            set_={config_table.c.value: json.dumps(allowed_clouds)})
+        session.execute(do_update_stmt)
+        session.commit()
+
+
+def _get_allowed_clouds_key(workspace: str) -> str:
+    return _ALLOWED_CLOUDS_KEY_PREFIX + workspace
+
+
+@_init_db
+@metrics_lib.time_me
 def add_or_update_storage(storage_name: str,
                           storage_handle: 'Storage.StorageMetadata',
                           storage_status: status_lib.StorageStatus):
+    assert _SQLALCHEMY_ENGINE is not None
     storage_launched_at = int(time.time())
     handle = pickle.dumps(storage_handle)
     last_use = common_utils.get_current_command()
@@ -851,89 +2081,663 @@ def add_or_update_storage(storage_name: str,
     if not status_check(storage_status):
         raise ValueError(f'Error in updating global state. Storage Status '
                          f'{storage_status} is passed in incorrectly')
-    _DB.cursor.execute('INSERT OR REPLACE INTO storage VALUES (?, ?, ?, ?, ?)',
-                       (storage_name, storage_launched_at, handle, last_use,
-                        storage_status.value))
-    _DB.conn.commit()
-
-
+    with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        if (_SQLALCHEMY_ENGINE.dialect.name ==
+                db_utils.SQLAlchemyDialect.SQLITE.value):
+            insert_func = sqlite.insert
+        elif (_SQLALCHEMY_ENGINE.dialect.name ==
+              db_utils.SQLAlchemyDialect.POSTGRESQL.value):
+            insert_func = postgresql.insert
+        else:
+            raise ValueError('Unsupported database dialect')
+        insert_stmnt = insert_func(storage_table).values(
+            name=storage_name,
+            handle=handle,
+            last_use=last_use,
+            launched_at=storage_launched_at,
+            status=storage_status.value)
+        do_update_stmt = insert_stmnt.on_conflict_do_update(
+            index_elements=[storage_table.c.name],
+            set_={
+                storage_table.c.handle: handle,
+                storage_table.c.last_use: last_use,
+                storage_table.c.launched_at: storage_launched_at,
+                storage_table.c.status: storage_status.value
+            })
+        session.execute(do_update_stmt)
+        session.commit()
+
+
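Together with the accessors below, this upsert forms a small storage lifecycle API. A hedged usage sketch (assumes the module is initialized and `meta` is a picklable `Storage.StorageMetadata` instance):

```python
# Sketch: record a storage object, then read its status back.
add_or_update_storage('my-bucket', meta, status_lib.StorageStatus.INIT)
status = get_storage_status('my-bucket')
assert status == status_lib.StorageStatus.INIT
```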
+@_init_db
+@metrics_lib.time_me
 def remove_storage(storage_name: str):
     """Removes Storage from Database"""
-    _DB.cursor.execute('DELETE FROM storage WHERE name=(?)', (storage_name,))
-    _DB.conn.commit()
+    assert _SQLALCHEMY_ENGINE is not None
+    with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        session.query(storage_table).filter_by(name=storage_name).delete()
+        session.commit()
 
 
+@_init_db
+@metrics_lib.time_me
 def set_storage_status(storage_name: str,
                        status: status_lib.StorageStatus) -> None:
-    _DB.cursor.execute('UPDATE storage SET status=(?) WHERE name=(?)', (
-        status.value,
-        storage_name,
-    ))
-    count = _DB.cursor.rowcount
-    _DB.conn.commit()
+    assert _SQLALCHEMY_ENGINE is not None
+    with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        count = session.query(storage_table).filter_by(
+            name=storage_name).update({storage_table.c.status: status.value})
+        session.commit()
     assert count <= 1, count
     if count == 0:
         raise ValueError(f'Storage {storage_name} not found.')
 
 
+@_init_db
+@metrics_lib.time_me
 def get_storage_status(storage_name: str) -> Optional[status_lib.StorageStatus]:
+    assert _SQLALCHEMY_ENGINE is not None
     assert storage_name is not None, 'storage_name cannot be None'
-    rows = _DB.cursor.execute('SELECT status FROM storage WHERE name=(?)',
-                              (storage_name,))
-    for (status,) in rows:
-        return status_lib.StorageStatus[status]
+    with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        row = session.query(storage_table).filter_by(name=storage_name).first()
+        if row:
+            return status_lib.StorageStatus[row.status]
     return None
 
 
+@_init_db
+@metrics_lib.time_me
 def set_storage_handle(storage_name: str,
                        handle: 'Storage.StorageMetadata') -> None:
-    _DB.cursor.execute('UPDATE storage SET handle=(?) WHERE name=(?)', (
-        pickle.dumps(handle),
-        storage_name,
-    ))
-    count = _DB.cursor.rowcount
-    _DB.conn.commit()
+    assert _SQLALCHEMY_ENGINE is not None
+    with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        count = session.query(storage_table).filter_by(
+            name=storage_name).update(
+                {storage_table.c.handle: pickle.dumps(handle)})
+        session.commit()
     assert count <= 1, count
     if count == 0:
         raise ValueError(f'Storage{storage_name} not found.')
 
 
+@_init_db
+@metrics_lib.time_me
 def get_handle_from_storage_name(
         storage_name: Optional[str]) -> Optional['Storage.StorageMetadata']:
+    assert _SQLALCHEMY_ENGINE is not None
     if storage_name is None:
         return None
-    rows = _DB.cursor.execute('SELECT handle FROM storage WHERE name=(?)',
-                              (storage_name,))
-    for (handle,) in rows:
-        if handle is None:
-            return None
-        return pickle.loads(handle)
+    with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        row = session.query(storage_table).filter_by(name=storage_name).first()
+        if row:
+            return pickle.loads(row.handle)
     return None
 
 
+@_init_db
+@metrics_lib.time_me
 def get_glob_storage_name(storage_name: str) -> List[str]:
+    assert _SQLALCHEMY_ENGINE is not None
     assert storage_name is not None, 'storage_name cannot be None'
-    rows = _DB.cursor.execute('SELECT name FROM storage WHERE name GLOB (?)',
-                              (storage_name,))
-    return [row[0] for row in rows]
-
-
+    with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        if (_SQLALCHEMY_ENGINE.dialect.name ==
+                db_utils.SQLAlchemyDialect.SQLITE.value):
+            rows = session.query(storage_table).filter(
+                storage_table.c.name.op('GLOB')(storage_name)).all()
+        elif (_SQLALCHEMY_ENGINE.dialect.name ==
+              db_utils.SQLAlchemyDialect.POSTGRESQL.value):
+            rows = session.query(storage_table).filter(
+                storage_table.c.name.op('SIMILAR TO')(
+                    _glob_to_similar(storage_name))).all()
+        else:
+            raise ValueError('Unsupported database dialect')
+    return [row.name for row in rows]
+
+
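PostgreSQL has no `GLOB` operator, so the query falls back to `SIMILAR TO`, with `_glob_to_similar` (defined elsewhere in this module) translating the pattern. A rough idea of what such a translation involves; this sketch is an assumption, not the module's actual helper:

```python
def glob_to_similar_sketch(glob: str) -> str:
    """Very rough glob -> SQL SIMILAR TO translation (illustrative only)."""
    out = []
    for ch in glob:
        if ch == '*':
            out.append('%')  # glob * matches any run of characters
        elif ch == '?':
            out.append('_')  # glob ? matches exactly one character
        elif ch in '%_|()[]{}+':
            out.append('\\' + ch)  # escape SIMILAR TO metacharacters
        else:
            out.append(ch)
    return ''.join(out)
```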
+@_init_db
+@metrics_lib.time_me
 def get_storage_names_start_with(starts_with: str) -> List[str]:
-    rows = _DB.cursor.execute('SELECT name FROM storage WHERE name LIKE (?)',
-                              (f'{starts_with}%',))
-    return [row[0] for row in rows]
+    assert _SQLALCHEMY_ENGINE is not None
+    with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        rows = session.query(storage_table).filter(
+            storage_table.c.name.like(f'{starts_with}%')).all()
+    return [row.name for row in rows]
 
 
+@_init_db
+@metrics_lib.time_me
 def get_storage() -> List[Dict[str, Any]]:
-    rows = _DB.cursor.execute('SELECT * FROM storage')
+    assert _SQLALCHEMY_ENGINE is not None
+    with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        rows = session.query(storage_table).all()
     records = []
-    for name, launched_at, handle, last_use, status in rows:
+    for row in rows:
         # TODO: use namedtuple instead of dict
         records.append({
-            'name': name,
-            'launched_at': launched_at,
-            'handle': pickle.loads(handle),
-            'last_use': last_use,
-            'status': status_lib.StorageStatus[status],
+            'name': row.name,
+            'launched_at': row.launched_at,
+            'handle': pickle.loads(row.handle),
+            'last_use': row.last_use,
+            'status': status_lib.StorageStatus[row.status],
         })
     return records
+
+
+@_init_db
+@metrics_lib.time_me
+def get_volume_names_start_with(starts_with: str) -> List[str]:
+    assert _SQLALCHEMY_ENGINE is not None
+    with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        rows = session.query(volume_table).filter(
+            volume_table.c.name.like(f'{starts_with}%')).all()
+    return [row.name for row in rows]
+
+
+@_init_db
+@metrics_lib.time_me
+def get_volumes(is_ephemeral: Optional[bool] = None) -> List[Dict[str, Any]]:
+    assert _SQLALCHEMY_ENGINE is not None
+    with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        if is_ephemeral is None:
+            rows = session.query(volume_table).all()
+        else:
+            rows = session.query(volume_table).filter_by(
+                is_ephemeral=is_ephemeral).all()
+    records = []
+    for row in rows:
+        records.append({
+            'name': row.name,
+            'launched_at': row.launched_at,
+            'handle': pickle.loads(row.handle),
+            'user_hash': row.user_hash,
+            'workspace': row.workspace,
+            'last_attached_at': row.last_attached_at,
+            'last_use': row.last_use,
+            'status': status_lib.VolumeStatus[row.status],
+            'is_ephemeral': row.is_ephemeral,
+        })
+    return records
+
+
+@_init_db
+@metrics_lib.time_me
+def get_volume_by_name(name: str) -> Optional[Dict[str, Any]]:
+    assert _SQLALCHEMY_ENGINE is not None
+    with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        row = session.query(volume_table).filter_by(name=name).first()
+        if row:
+            return {
+                'name': row.name,
+                'launched_at': row.launched_at,
+                'handle': pickle.loads(row.handle),
+                'user_hash': row.user_hash,
+                'workspace': row.workspace,
+                'last_attached_at': row.last_attached_at,
+                'last_use': row.last_use,
+                'status': status_lib.VolumeStatus[row.status],
+            }
+    return None
+
+
+@_init_db
+@metrics_lib.time_me
+def add_volume(
+    name: str,
+    config: models.VolumeConfig,
+    status: status_lib.VolumeStatus,
+    is_ephemeral: bool = False,
+) -> None:
+    assert _SQLALCHEMY_ENGINE is not None
+    volume_launched_at = int(time.time())
+    handle = pickle.dumps(config)
+    last_use = common_utils.get_current_command()
+    user_hash = common_utils.get_current_user().id
+    active_workspace = skypilot_config.get_active_workspace()
+    if is_ephemeral:
+        last_attached_at = int(time.time())
+        status = status_lib.VolumeStatus.IN_USE
+    else:
+        last_attached_at = None
+
+    with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        if (_SQLALCHEMY_ENGINE.dialect.name ==
+                db_utils.SQLAlchemyDialect.SQLITE.value):
+            insert_func = sqlite.insert
+        elif (_SQLALCHEMY_ENGINE.dialect.name ==
+              db_utils.SQLAlchemyDialect.POSTGRESQL.value):
+            insert_func = postgresql.insert
+        else:
+            raise ValueError('Unsupported database dialect')
+        insert_stmnt = insert_func(volume_table).values(
+            name=name,
+            launched_at=volume_launched_at,
+            handle=handle,
+            user_hash=user_hash,
+            workspace=active_workspace,
+            last_attached_at=last_attached_at,
+            last_use=last_use,
+            status=status.value,
+            is_ephemeral=is_ephemeral,
+        )
+        do_update_stmt = insert_stmnt.on_conflict_do_nothing()
+        session.execute(do_update_stmt)
+        session.commit()
+
+
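`add_volume` special-cases ephemeral volumes: they are born attached (`last_attached_at` set, status forced to `IN_USE`), and `on_conflict_do_nothing` makes re-registration idempotent. A hedged lifecycle sketch (assumes `cfg` is a `models.VolumeConfig` and the module is initialized):

```python
import time

# Sketch: register a volume, mark it attached, then drop it.
add_volume('scratch-vol', cfg, status_lib.VolumeStatus.READY)
update_volume('scratch-vol', last_attached_at=int(time.time()),
              status=status_lib.VolumeStatus.IN_USE)
delete_volume('scratch-vol')
```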
+@_init_db
+@metrics_lib.time_me
+def update_volume(name: str, last_attached_at: int,
+                  status: status_lib.VolumeStatus) -> None:
+    assert _SQLALCHEMY_ENGINE is not None
+    with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        session.query(volume_table).filter_by(name=name).update({
+            volume_table.c.last_attached_at: last_attached_at,
+            volume_table.c.status: status.value,
+        })
+        session.commit()
+
+
+@_init_db
+@metrics_lib.time_me
+def update_volume_status(name: str, status: status_lib.VolumeStatus) -> None:
+    assert _SQLALCHEMY_ENGINE is not None
+    with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        session.query(volume_table).filter_by(name=name).update({
+            volume_table.c.status: status.value,
+        })
+        session.commit()
+
+
+@_init_db
+@metrics_lib.time_me
+def delete_volume(name: str) -> None:
+    assert _SQLALCHEMY_ENGINE is not None
+    with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        session.query(volume_table).filter_by(name=name).delete()
+        session.commit()
+
+
+@_init_db
+@metrics_lib.time_me
+def get_ssh_keys(user_hash: str) -> Tuple[str, str, bool]:
+    assert _SQLALCHEMY_ENGINE is not None
+    with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        row = session.query(ssh_key_table).filter_by(
+            user_hash=user_hash).first()
+        if row:
+            return row.ssh_public_key, row.ssh_private_key, True
+    return '', '', False
+
+
+@_init_db
+@metrics_lib.time_me
+def set_ssh_keys(user_hash: str, ssh_public_key: str, ssh_private_key: str):
+    assert _SQLALCHEMY_ENGINE is not None
+    with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        if (_SQLALCHEMY_ENGINE.dialect.name ==
+                db_utils.SQLAlchemyDialect.SQLITE.value):
+            insert_func = sqlite.insert
+        elif (_SQLALCHEMY_ENGINE.dialect.name ==
+              db_utils.SQLAlchemyDialect.POSTGRESQL.value):
+            insert_func = postgresql.insert
+        else:
+            raise ValueError('Unsupported database dialect')
+        insert_stmnt = insert_func(ssh_key_table).values(
+            user_hash=user_hash,
+            ssh_public_key=ssh_public_key,
+            ssh_private_key=ssh_private_key)
+        do_update_stmt = insert_stmnt.on_conflict_do_update(
+            index_elements=[ssh_key_table.c.user_hash],
+            set_={
+                ssh_key_table.c.ssh_public_key: ssh_public_key,
+                ssh_key_table.c.ssh_private_key: ssh_private_key
+            })
+        session.execute(do_update_stmt)
+        session.commit()
+
+
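`set_ssh_keys` upserts one keypair per `user_hash`, and `get_ssh_keys` returns a `(public, private, exists)` triple. A hedged roundtrip sketch (the key strings below are placeholders; real callers would pass actual OpenSSH/PEM material):

```python
set_ssh_keys('user-abc123', 'ssh-ed25519 AAAA... placeholder',
             '-----BEGIN OPENSSH PRIVATE KEY----- placeholder')
public, private, exists = get_ssh_keys('user-abc123')
assert exists and public.startswith('ssh-ed25519')
```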
+@_init_db
+@metrics_lib.time_me
+def add_service_account_token(token_id: str,
+                              token_name: str,
+                              token_hash: str,
+                              creator_user_hash: str,
+                              service_account_user_id: str,
+                              expires_at: Optional[int] = None) -> None:
+    """Add a service account token to the database."""
+    assert _SQLALCHEMY_ENGINE is not None
+    created_at = int(time.time())
+
+    with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        if (_SQLALCHEMY_ENGINE.dialect.name ==
+                db_utils.SQLAlchemyDialect.SQLITE.value):
+            insert_func = sqlite.insert
+        elif (_SQLALCHEMY_ENGINE.dialect.name ==
+              db_utils.SQLAlchemyDialect.POSTGRESQL.value):
+            insert_func = postgresql.insert
+        else:
+            raise ValueError('Unsupported database dialect')
+
+        insert_stmnt = insert_func(service_account_token_table).values(
+            token_id=token_id,
+            token_name=token_name,
+            token_hash=token_hash,
+            created_at=created_at,
+            expires_at=expires_at,
+            creator_user_hash=creator_user_hash,
+            service_account_user_id=service_account_user_id)
+        session.execute(insert_stmnt)
+        session.commit()
+
+
+@_init_db
+@metrics_lib.time_me
+def get_service_account_token(token_id: str) -> Optional[Dict[str, Any]]:
+    """Get a service account token by token_id."""
+    assert _SQLALCHEMY_ENGINE is not None
+    with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        row = session.query(service_account_token_table).filter_by(
+            token_id=token_id).first()
+    if row is None:
+        return None
+    return {
+        'token_id': row.token_id,
+        'token_name': row.token_name,
+        'token_hash': row.token_hash,
+        'created_at': row.created_at,
+        'last_used_at': row.last_used_at,
+        'expires_at': row.expires_at,
+        'creator_user_hash': row.creator_user_hash,
+        'service_account_user_id': row.service_account_user_id,
+    }
+
+
+@_init_db
+@metrics_lib.time_me
+def get_user_service_account_tokens(user_hash: str) -> List[Dict[str, Any]]:
+    """Get all service account tokens for a user (as creator)."""
+    assert _SQLALCHEMY_ENGINE is not None
+    with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        rows = session.query(service_account_token_table).filter_by(
+            creator_user_hash=user_hash).all()
+    return [{
+        'token_id': row.token_id,
+        'token_name': row.token_name,
+        'token_hash': row.token_hash,
+        'created_at': row.created_at,
+        'last_used_at': row.last_used_at,
+        'expires_at': row.expires_at,
+        'creator_user_hash': row.creator_user_hash,
+        'service_account_user_id': row.service_account_user_id,
+    } for row in rows]
+
+
+@_init_db
+@metrics_lib.time_me
+def update_service_account_token_last_used(token_id: str) -> None:
+    """Update the last_used_at timestamp for a service account token."""
+    assert _SQLALCHEMY_ENGINE is not None
+    last_used_at = int(time.time())
+
+    with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        session.query(service_account_token_table).filter_by(
+            token_id=token_id).update(
+                {service_account_token_table.c.last_used_at: last_used_at})
+        session.commit()
+
+
+@_init_db
+@metrics_lib.time_me
+def delete_service_account_token(token_id: str) -> bool:
+    """Delete a service account token.
+
+    Returns:
+        True if token was found and deleted.
+    """
+    assert _SQLALCHEMY_ENGINE is not None
+    with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        result = session.query(service_account_token_table).filter_by(
+            token_id=token_id).delete()
+        session.commit()
+    return result > 0
+
+
+@_init_db
+@metrics_lib.time_me
+def rotate_service_account_token(token_id: str,
+                                 new_token_hash: str,
+                                 new_expires_at: Optional[int] = None) -> None:
+    """Rotate a service account token by updating its hash and expiration.
+
+    Args:
+        token_id: The token ID to rotate.
+        new_token_hash: The new hashed token value.
+        new_expires_at: New expiration timestamp, or None for no expiration.
+    """
+    assert _SQLALCHEMY_ENGINE is not None
+    current_time = int(time.time())
+
+    with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        count = session.query(service_account_token_table).filter_by(
+            token_id=token_id
+        ).update({
+            service_account_token_table.c.token_hash: new_token_hash,
+            service_account_token_table.c.expires_at: new_expires_at,
+            service_account_token_table.c.last_used_at: None,  # Reset last used
+            # Update creation time
+            service_account_token_table.c.created_at: current_time,
+        })
+        session.commit()
+
+    if count == 0:
+        raise ValueError(f'Service account token {token_id} not found.')
+
+
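The token table stores only `token_hash`, so rotation just swaps the hash and resets `last_used_at`/`created_at`. How the hash is derived is not shown in this diff; a plausible sketch using a SHA-256 digest (an assumption, not necessarily what SkyPilot does):

```python
import hashlib
import secrets
from typing import Tuple


def mint_token_sketch() -> Tuple[str, str]:
    """Illustrative only: derive a storable hash from a fresh token."""
    plaintext = 'sky_' + secrets.token_urlsafe(32)
    # Store only the digest; the plaintext is shown to the caller once.
    digest = hashlib.sha256(plaintext.encode('utf-8')).hexdigest()
    return plaintext, digest
```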
+@_init_db
+@metrics_lib.time_me
+def get_cluster_yaml_str(cluster_yaml_path: Optional[str]) -> Optional[str]:
+    """Get the cluster yaml from the database or the local file system.
+    If the cluster yaml is not in the database, check if it exists on the
+    local file system and migrate it to the database.
+
+    It is assumed that the cluster yaml file is named as <cluster_name>.yml.
+    """
+    assert _SQLALCHEMY_ENGINE is not None
+    if cluster_yaml_path is None:
+        raise ValueError('Attempted to read a None YAML.')
+    cluster_file_name = os.path.basename(cluster_yaml_path)
+    cluster_name, _ = os.path.splitext(cluster_file_name)
+    with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        row = session.query(cluster_yaml_table).filter_by(
+            cluster_name=cluster_name).first()
+    if row is None:
+        return _set_cluster_yaml_from_file(cluster_yaml_path, cluster_name)
+    return row.yaml
+
+
+def get_cluster_yaml_str_multiple(cluster_yaml_paths: List[str]) -> List[str]:
+    """Get the cluster yaml from the database or the local file system.
+    """
+    assert _SQLALCHEMY_ENGINE is not None
+    cluster_names_to_yaml_paths = {}
+    for cluster_yaml_path in cluster_yaml_paths:
+        cluster_name, _ = os.path.splitext(os.path.basename(cluster_yaml_path))
+        cluster_names_to_yaml_paths[cluster_name] = cluster_yaml_path
+
+    cluster_names = list(cluster_names_to_yaml_paths.keys())
+    with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        rows = session.query(cluster_yaml_table).filter(
+            cluster_yaml_table.c.cluster_name.in_(cluster_names)).all()
+    row_cluster_names_to_yaml = {row.cluster_name: row.yaml for row in rows}
+
+    yaml_strs = []
+    for cluster_name in cluster_names:
+        if cluster_name in row_cluster_names_to_yaml:
+            yaml_strs.append(row_cluster_names_to_yaml[cluster_name])
+        else:
+            yaml_str = _set_cluster_yaml_from_file(
+                cluster_names_to_yaml_paths[cluster_name], cluster_name)
+            yaml_strs.append(yaml_str)
+    return yaml_strs
+
+
+def _set_cluster_yaml_from_file(cluster_yaml_path: str,
+                                cluster_name: str) -> Optional[str]:
+    """Set the cluster yaml in the database from a file."""
+    # If the cluster yaml is not in the database, check if it exists
+    # on the local file system and migrate it to the database.
+    # TODO(syang): remove this check once we have a way to migrate the
+    # cluster from file to database. Remove on v0.12.0.
+    if cluster_yaml_path is not None:
+        # First try the exact path
+        path_to_read = None
+        if os.path.exists(cluster_yaml_path):
+            path_to_read = cluster_yaml_path
+        # Fallback: try with .debug suffix (when debug logging was enabled)
+        # Debug logging causes YAML files to be saved with .debug suffix
+        # but the path stored in the handle doesn't include it
+        debug_path = cluster_yaml_path + '.debug'
+        if os.path.exists(debug_path):
+            path_to_read = debug_path
+        if path_to_read is not None:
+            with open(path_to_read, 'r', encoding='utf-8') as f:
+                yaml_str = f.read()
+            set_cluster_yaml(cluster_name, yaml_str)
+            return yaml_str
+    return None
+
+
+def get_cluster_yaml_dict(cluster_yaml_path: Optional[str]) -> Dict[str, Any]:
+    """Get the cluster yaml as a dictionary from the database.
+
+    It is assumed that the cluster yaml file is named as <cluster_name>.yml.
+    """
+    yaml_str = get_cluster_yaml_str(cluster_yaml_path)
+    if yaml_str is None:
+        raise ValueError(f'Cluster yaml {cluster_yaml_path} not found.')
+    return yaml_utils.safe_load(yaml_str)
+
+
+def get_cluster_yaml_dict_multiple(
+        cluster_yaml_paths: List[str]) -> List[Dict[str, Any]]:
+    """Get the cluster yaml as a dictionary from the database."""
+    yaml_strs = get_cluster_yaml_str_multiple(cluster_yaml_paths)
+    yaml_dicts = []
+    for idx, yaml_str in enumerate(yaml_strs):
+        if yaml_str is None:
+            raise ValueError(
+                f'Cluster yaml {cluster_yaml_paths[idx]} not found.')
+        yaml_dicts.append(yaml_utils.safe_load(yaml_str))
+    return yaml_dicts
+
+
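The YAML accessors implement a read-through migration: look in the database first and, on a miss, fall back to the legacy on-disk file (including the `.debug` variant) and write it back via `set_cluster_yaml`. A condensed restatement of that control flow, with the database calls abstracted as parameters (a sketch of the logic above, not new behavior):

```python
import os
from typing import Callable, Optional


def read_cluster_yaml_sketch(
        path: str, db_lookup: Callable[[str], Optional[str]],
        db_store: Callable[[str, str], None]) -> Optional[str]:
    """Read-through: DB first, then legacy file, migrating on first read."""
    name, _ = os.path.splitext(os.path.basename(path))
    yaml_str = db_lookup(name)
    if yaml_str is None:
        for candidate in (path, path + '.debug'):
            if os.path.exists(candidate):
                with open(candidate, 'r', encoding='utf-8') as f:
                    yaml_str = f.read()
                db_store(name, yaml_str)  # migrate into the database
                break
    return yaml_str
```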
+@_init_db
+@metrics_lib.time_me
+def set_cluster_yaml(cluster_name: str, yaml_str: str) -> None:
+    """Set the cluster yaml in the database."""
+    assert _SQLALCHEMY_ENGINE is not None
+    with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        if (_SQLALCHEMY_ENGINE.dialect.name ==
+                db_utils.SQLAlchemyDialect.SQLITE.value):
+            insert_func = sqlite.insert
+        elif (_SQLALCHEMY_ENGINE.dialect.name ==
+              db_utils.SQLAlchemyDialect.POSTGRESQL.value):
+            insert_func = postgresql.insert
+        else:
+            raise ValueError('Unsupported database dialect')
+        insert_stmnt = insert_func(cluster_yaml_table).values(
+            cluster_name=cluster_name, yaml=yaml_str)
+        do_update_stmt = insert_stmnt.on_conflict_do_update(
+            index_elements=[cluster_yaml_table.c.cluster_name],
+            set_={cluster_yaml_table.c.yaml: yaml_str})
+        session.execute(do_update_stmt)
+        session.commit()
+
+
+@_init_db
+@metrics_lib.time_me
+def remove_cluster_yaml(cluster_name: str):
+    assert _SQLALCHEMY_ENGINE is not None
+    with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        session.query(cluster_yaml_table).filter_by(
+            cluster_name=cluster_name).delete()
+        session.commit()
+
+
+@_init_db
+@metrics_lib.time_me
+def get_all_service_account_tokens() -> List[Dict[str, Any]]:
+    """Get all service account tokens across all users (for admin access)."""
+    assert _SQLALCHEMY_ENGINE is not None
+    with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        rows = session.query(service_account_token_table).all()
+    return [{
+        'token_id': row.token_id,
+        'token_name': row.token_name,
+        'token_hash': row.token_hash,
+        'created_at': row.created_at,
+        'last_used_at': row.last_used_at,
+        'expires_at': row.expires_at,
+        'creator_user_hash': row.creator_user_hash,
+        'service_account_user_id': row.service_account_user_id,
+    } for row in rows]
+
+
+@_init_db
+@metrics_lib.time_me
+def get_system_config(config_key: str) -> Optional[str]:
+    """Get a system configuration value by key."""
+    assert _SQLALCHEMY_ENGINE is not None
+    with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        row = session.query(system_config_table).filter_by(
+            config_key=config_key).first()
+    if row is None:
+        return None
+    return row.config_value
+
+
+@_init_db
+@metrics_lib.time_me
+def set_system_config(config_key: str, config_value: str) -> None:
+    """Set a system configuration value."""
+    assert _SQLALCHEMY_ENGINE is not None
+    current_time = int(time.time())
+
+    with orm.Session(_SQLALCHEMY_ENGINE) as session:
+        if (_SQLALCHEMY_ENGINE.dialect.name ==
+                db_utils.SQLAlchemyDialect.SQLITE.value):
+            insert_func = sqlite.insert
+        elif (_SQLALCHEMY_ENGINE.dialect.name ==
+              db_utils.SQLAlchemyDialect.POSTGRESQL.value):
+            insert_func = postgresql.insert
+        else:
+            raise ValueError('Unsupported database dialect')
+
+        insert_stmnt = insert_func(system_config_table).values(
+            config_key=config_key,
+            config_value=config_value,
+            created_at=current_time,
+            updated_at=current_time)
+
+        upsert_stmnt = insert_stmnt.on_conflict_do_update(
+            index_elements=[system_config_table.c.config_key],
+            set_={
+                system_config_table.c.config_value: config_value,
+                system_config_table.c.updated_at: current_time,
+            })
+        session.execute(upsert_stmnt)
+        session.commit()
+
+
+@_init_db
+def get_max_db_connections() -> Optional[int]:
+    """Get the maximum number of connections for the engine."""
+    assert _SQLALCHEMY_ENGINE is not None
+    if (_SQLALCHEMY_ENGINE.dialect.name ==
+            db_utils.SQLAlchemyDialect.SQLITE.value):
+        return None
+    with sqlalchemy.orm.Session(_SQLALCHEMY_ENGINE) as session:
+        max_connections = session.execute(
+            sqlalchemy.text('SHOW max_connections')).scalar()
+    if max_connections is None:
+        return None
+    return int(max_connections)
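`SHOW max_connections` is PostgreSQL-specific, which is why the SQLite branch returns None. One plausible use for the value is sizing a client-side connection pool against the server limit; a hedged sketch (the sizing rule is an arbitrary illustration, not the package's policy):

```python
from typing import Optional


def suggest_pool_size(max_connections: Optional[int], default: int = 5) -> int:
    # Keep the pool well under the server limit, leaving headroom for
    # other clients. The divisor below is only illustrative.
    if max_connections is None:  # e.g. SQLite: no server-side limit
        return default
    return max(1, min(default * 4, max_connections // 4))
```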