skypilot-nightly 1.0.0.dev20250502__py3-none-any.whl → 1.0.0.dev20251203__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (546)
  1. sky/__init__.py +22 -6
  2. sky/adaptors/aws.py +81 -16
  3. sky/adaptors/common.py +25 -2
  4. sky/adaptors/coreweave.py +278 -0
  5. sky/adaptors/do.py +8 -2
  6. sky/adaptors/gcp.py +11 -0
  7. sky/adaptors/hyperbolic.py +8 -0
  8. sky/adaptors/ibm.py +5 -2
  9. sky/adaptors/kubernetes.py +149 -18
  10. sky/adaptors/nebius.py +173 -30
  11. sky/adaptors/primeintellect.py +1 -0
  12. sky/adaptors/runpod.py +68 -0
  13. sky/adaptors/seeweb.py +183 -0
  14. sky/adaptors/shadeform.py +89 -0
  15. sky/admin_policy.py +187 -4
  16. sky/authentication.py +179 -225
  17. sky/backends/__init__.py +4 -2
  18. sky/backends/backend.py +22 -9
  19. sky/backends/backend_utils.py +1323 -397
  20. sky/backends/cloud_vm_ray_backend.py +1749 -1029
  21. sky/backends/docker_utils.py +1 -1
  22. sky/backends/local_docker_backend.py +11 -6
  23. sky/backends/task_codegen.py +633 -0
  24. sky/backends/wheel_utils.py +55 -9
  25. sky/{clouds/service_catalog → catalog}/__init__.py +21 -19
  26. sky/{clouds/service_catalog → catalog}/aws_catalog.py +27 -8
  27. sky/{clouds/service_catalog → catalog}/azure_catalog.py +10 -7
  28. sky/{clouds/service_catalog → catalog}/common.py +90 -49
  29. sky/{clouds/service_catalog → catalog}/cudo_catalog.py +8 -5
  30. sky/{clouds/service_catalog → catalog}/data_fetchers/analyze.py +1 -1
  31. sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_aws.py +116 -80
  32. sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_cudo.py +38 -38
  33. sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_gcp.py +70 -16
  34. sky/catalog/data_fetchers/fetch_hyperbolic.py +136 -0
  35. sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_lambda_cloud.py +1 -0
  36. sky/catalog/data_fetchers/fetch_nebius.py +338 -0
  37. sky/catalog/data_fetchers/fetch_runpod.py +698 -0
  38. sky/catalog/data_fetchers/fetch_seeweb.py +329 -0
  39. sky/catalog/data_fetchers/fetch_shadeform.py +142 -0
  40. sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_vast.py +1 -1
  41. sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_vsphere.py +1 -1
  42. sky/{clouds/service_catalog → catalog}/do_catalog.py +5 -2
  43. sky/{clouds/service_catalog → catalog}/fluidstack_catalog.py +6 -3
  44. sky/{clouds/service_catalog → catalog}/gcp_catalog.py +41 -15
  45. sky/catalog/hyperbolic_catalog.py +136 -0
  46. sky/{clouds/service_catalog → catalog}/ibm_catalog.py +9 -6
  47. sky/{clouds/service_catalog → catalog}/kubernetes_catalog.py +36 -24
  48. sky/{clouds/service_catalog → catalog}/lambda_catalog.py +9 -6
  49. sky/{clouds/service_catalog → catalog}/nebius_catalog.py +9 -7
  50. sky/{clouds/service_catalog → catalog}/oci_catalog.py +9 -6
  51. sky/{clouds/service_catalog → catalog}/paperspace_catalog.py +5 -2
  52. sky/catalog/primeintellect_catalog.py +95 -0
  53. sky/{clouds/service_catalog → catalog}/runpod_catalog.py +11 -4
  54. sky/{clouds/service_catalog → catalog}/scp_catalog.py +9 -6
  55. sky/catalog/seeweb_catalog.py +184 -0
  56. sky/catalog/shadeform_catalog.py +165 -0
  57. sky/catalog/ssh_catalog.py +167 -0
  58. sky/{clouds/service_catalog → catalog}/vast_catalog.py +6 -3
  59. sky/{clouds/service_catalog → catalog}/vsphere_catalog.py +5 -2
  60. sky/check.py +533 -185
  61. sky/cli.py +5 -5975
  62. sky/client/{cli.py → cli/command.py} +2591 -1956
  63. sky/client/cli/deprecation_utils.py +99 -0
  64. sky/client/cli/flags.py +359 -0
  65. sky/client/cli/table_utils.py +322 -0
  66. sky/client/cli/utils.py +79 -0
  67. sky/client/common.py +78 -32
  68. sky/client/oauth.py +82 -0
  69. sky/client/sdk.py +1219 -319
  70. sky/client/sdk_async.py +827 -0
  71. sky/client/service_account_auth.py +47 -0
  72. sky/cloud_stores.py +82 -3
  73. sky/clouds/__init__.py +13 -0
  74. sky/clouds/aws.py +564 -164
  75. sky/clouds/azure.py +105 -83
  76. sky/clouds/cloud.py +140 -40
  77. sky/clouds/cudo.py +68 -50
  78. sky/clouds/do.py +66 -48
  79. sky/clouds/fluidstack.py +63 -44
  80. sky/clouds/gcp.py +339 -110
  81. sky/clouds/hyperbolic.py +293 -0
  82. sky/clouds/ibm.py +70 -49
  83. sky/clouds/kubernetes.py +570 -162
  84. sky/clouds/lambda_cloud.py +74 -54
  85. sky/clouds/nebius.py +210 -81
  86. sky/clouds/oci.py +88 -66
  87. sky/clouds/paperspace.py +61 -44
  88. sky/clouds/primeintellect.py +317 -0
  89. sky/clouds/runpod.py +164 -74
  90. sky/clouds/scp.py +89 -86
  91. sky/clouds/seeweb.py +477 -0
  92. sky/clouds/shadeform.py +400 -0
  93. sky/clouds/ssh.py +263 -0
  94. sky/clouds/utils/aws_utils.py +10 -4
  95. sky/clouds/utils/gcp_utils.py +87 -11
  96. sky/clouds/utils/oci_utils.py +38 -14
  97. sky/clouds/utils/scp_utils.py +231 -167
  98. sky/clouds/vast.py +99 -77
  99. sky/clouds/vsphere.py +51 -40
  100. sky/core.py +375 -173
  101. sky/dag.py +15 -0
  102. sky/dashboard/out/404.html +1 -1
  103. sky/dashboard/out/_next/static/96_E2yl3QAiIJGOYCkSpB/_buildManifest.js +1 -0
  104. sky/dashboard/out/_next/static/chunks/1141-e6aa9ab418717c59.js +11 -0
  105. sky/dashboard/out/_next/static/chunks/1272-1ef0bf0237faccdb.js +1 -0
  106. sky/dashboard/out/_next/static/chunks/1871-7e202677c42f43fe.js +6 -0
  107. sky/dashboard/out/_next/static/chunks/2260-7703229c33c5ebd5.js +1 -0
  108. sky/dashboard/out/_next/static/chunks/2350.fab69e61bac57b23.js +1 -0
  109. sky/dashboard/out/_next/static/chunks/2369.fc20f0c2c8ed9fe7.js +15 -0
  110. sky/dashboard/out/_next/static/chunks/2755.edd818326d489a1d.js +26 -0
  111. sky/dashboard/out/_next/static/chunks/3294.20a8540fe697d5ee.js +1 -0
  112. sky/dashboard/out/_next/static/chunks/3785.7e245f318f9d1121.js +1 -0
  113. sky/dashboard/out/_next/static/chunks/3800-7b45f9fbb6308557.js +1 -0
  114. sky/dashboard/out/_next/static/chunks/3850-ff4a9a69d978632b.js +1 -0
  115. sky/dashboard/out/_next/static/chunks/3937.210053269f121201.js +1 -0
  116. sky/dashboard/out/_next/static/chunks/4725.172ede95d1b21022.js +1 -0
  117. sky/dashboard/out/_next/static/chunks/4937.a2baa2df5572a276.js +15 -0
  118. sky/dashboard/out/_next/static/chunks/5739-d67458fcb1386c92.js +8 -0
  119. sky/dashboard/out/_next/static/chunks/6130-2be46d70a38f1e82.js +1 -0
  120. sky/dashboard/out/_next/static/chunks/616-3d59f75e2ccf9321.js +39 -0
  121. sky/dashboard/out/_next/static/chunks/6212-7bd06f60ba693125.js +13 -0
  122. sky/dashboard/out/_next/static/chunks/6856-8f27d1c10c98def8.js +1 -0
  123. sky/dashboard/out/_next/static/chunks/6989-01359c57e018caa4.js +1 -0
  124. sky/dashboard/out/_next/static/chunks/6990-9146207c4567fdfd.js +1 -0
  125. sky/dashboard/out/_next/static/chunks/7359-c8d04e06886000b3.js +30 -0
  126. sky/dashboard/out/_next/static/chunks/7411-b15471acd2cba716.js +41 -0
  127. sky/dashboard/out/_next/static/chunks/7615-019513abc55b3b47.js +1 -0
  128. sky/dashboard/out/_next/static/chunks/8640.5b9475a2d18c5416.js +16 -0
  129. sky/dashboard/out/_next/static/chunks/8969-452f9d5cbdd2dc73.js +1 -0
  130. sky/dashboard/out/_next/static/chunks/9025.fa408f3242e9028d.js +6 -0
  131. sky/dashboard/out/_next/static/chunks/9353-cff34f7e773b2e2b.js +1 -0
  132. sky/dashboard/out/_next/static/chunks/9360.a536cf6b1fa42355.js +31 -0
  133. sky/dashboard/out/_next/static/chunks/9847.3aaca6bb33455140.js +30 -0
  134. sky/dashboard/out/_next/static/chunks/fd9d1056-86323a29a8f7e46a.js +1 -0
  135. sky/dashboard/out/_next/static/chunks/framework-cf60a09ccd051a10.js +33 -0
  136. sky/dashboard/out/_next/static/chunks/main-app-587214043926b3cc.js +1 -0
  137. sky/dashboard/out/_next/static/chunks/main-f15ccb73239a3bf1.js +1 -0
  138. sky/dashboard/out/_next/static/chunks/pages/_app-bde01e4a2beec258.js +34 -0
  139. sky/dashboard/out/_next/static/chunks/pages/_error-c66a4e8afc46f17b.js +1 -0
  140. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]/[job]-792db96d918c98c9.js +16 -0
  141. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]-abfcac9c137aa543.js +1 -0
  142. sky/dashboard/out/_next/static/chunks/pages/clusters-ee39056f9851a3ff.js +1 -0
  143. sky/dashboard/out/_next/static/chunks/pages/config-dfb9bf07b13045f4.js +1 -0
  144. sky/dashboard/out/_next/static/chunks/pages/index-444f1804401f04ea.js +1 -0
  145. sky/dashboard/out/_next/static/chunks/pages/infra/[context]-c0b5935149902e6f.js +1 -0
  146. sky/dashboard/out/_next/static/chunks/pages/infra-aed0ea19df7cf961.js +1 -0
  147. sky/dashboard/out/_next/static/chunks/pages/jobs/[job]-d66997e2bfc837cf.js +16 -0
  148. sky/dashboard/out/_next/static/chunks/pages/jobs/pools/[pool]-9faf940b253e3e06.js +21 -0
  149. sky/dashboard/out/_next/static/chunks/pages/jobs-2072b48b617989c9.js +1 -0
  150. sky/dashboard/out/_next/static/chunks/pages/users-f42674164aa73423.js +1 -0
  151. sky/dashboard/out/_next/static/chunks/pages/volumes-b84b948ff357c43e.js +1 -0
  152. sky/dashboard/out/_next/static/chunks/pages/workspace/new-3f88a1c7e86a3f86.js +1 -0
  153. sky/dashboard/out/_next/static/chunks/pages/workspaces/[name]-84a40f8c7c627fe4.js +1 -0
  154. sky/dashboard/out/_next/static/chunks/pages/workspaces-531b2f8c4bf89f82.js +1 -0
  155. sky/dashboard/out/_next/static/chunks/webpack-64e05f17bf2cf8ce.js +1 -0
  156. sky/dashboard/out/_next/static/css/0748ce22df867032.css +3 -0
  157. sky/dashboard/out/clusters/[cluster]/[job].html +1 -1
  158. sky/dashboard/out/clusters/[cluster].html +1 -1
  159. sky/dashboard/out/clusters.html +1 -1
  160. sky/dashboard/out/config.html +1 -0
  161. sky/dashboard/out/index.html +1 -1
  162. sky/dashboard/out/infra/[context].html +1 -0
  163. sky/dashboard/out/infra.html +1 -0
  164. sky/dashboard/out/jobs/[job].html +1 -1
  165. sky/dashboard/out/jobs/pools/[pool].html +1 -0
  166. sky/dashboard/out/jobs.html +1 -1
  167. sky/dashboard/out/users.html +1 -0
  168. sky/dashboard/out/volumes.html +1 -0
  169. sky/dashboard/out/workspace/new.html +1 -0
  170. sky/dashboard/out/workspaces/[name].html +1 -0
  171. sky/dashboard/out/workspaces.html +1 -0
  172. sky/data/data_utils.py +137 -1
  173. sky/data/mounting_utils.py +269 -84
  174. sky/data/storage.py +1460 -1807
  175. sky/data/storage_utils.py +43 -57
  176. sky/exceptions.py +126 -2
  177. sky/execution.py +216 -63
  178. sky/global_user_state.py +2390 -586
  179. sky/jobs/__init__.py +7 -0
  180. sky/jobs/client/sdk.py +300 -58
  181. sky/jobs/client/sdk_async.py +161 -0
  182. sky/jobs/constants.py +15 -8
  183. sky/jobs/controller.py +848 -275
  184. sky/jobs/file_content_utils.py +128 -0
  185. sky/jobs/log_gc.py +193 -0
  186. sky/jobs/recovery_strategy.py +402 -152
  187. sky/jobs/scheduler.py +314 -189
  188. sky/jobs/server/core.py +836 -255
  189. sky/jobs/server/server.py +156 -115
  190. sky/jobs/server/utils.py +136 -0
  191. sky/jobs/state.py +2109 -706
  192. sky/jobs/utils.py +1306 -215
  193. sky/logs/__init__.py +21 -0
  194. sky/logs/agent.py +108 -0
  195. sky/logs/aws.py +243 -0
  196. sky/logs/gcp.py +91 -0
  197. sky/metrics/__init__.py +0 -0
  198. sky/metrics/utils.py +453 -0
  199. sky/models.py +78 -1
  200. sky/optimizer.py +164 -70
  201. sky/provision/__init__.py +90 -4
  202. sky/provision/aws/config.py +147 -26
  203. sky/provision/aws/instance.py +136 -50
  204. sky/provision/azure/instance.py +11 -6
  205. sky/provision/common.py +13 -1
  206. sky/provision/cudo/cudo_machine_type.py +1 -1
  207. sky/provision/cudo/cudo_utils.py +14 -8
  208. sky/provision/cudo/cudo_wrapper.py +72 -71
  209. sky/provision/cudo/instance.py +10 -6
  210. sky/provision/do/instance.py +10 -6
  211. sky/provision/do/utils.py +4 -3
  212. sky/provision/docker_utils.py +140 -33
  213. sky/provision/fluidstack/instance.py +13 -8
  214. sky/provision/gcp/__init__.py +1 -0
  215. sky/provision/gcp/config.py +301 -19
  216. sky/provision/gcp/constants.py +218 -0
  217. sky/provision/gcp/instance.py +36 -8
  218. sky/provision/gcp/instance_utils.py +18 -4
  219. sky/provision/gcp/volume_utils.py +247 -0
  220. sky/provision/hyperbolic/__init__.py +12 -0
  221. sky/provision/hyperbolic/config.py +10 -0
  222. sky/provision/hyperbolic/instance.py +437 -0
  223. sky/provision/hyperbolic/utils.py +373 -0
  224. sky/provision/instance_setup.py +101 -20
  225. sky/provision/kubernetes/__init__.py +5 -0
  226. sky/provision/kubernetes/config.py +9 -52
  227. sky/provision/kubernetes/constants.py +17 -0
  228. sky/provision/kubernetes/instance.py +919 -280
  229. sky/provision/kubernetes/manifests/fusermount-server-daemonset.yaml +1 -2
  230. sky/provision/kubernetes/network.py +27 -17
  231. sky/provision/kubernetes/network_utils.py +44 -43
  232. sky/provision/kubernetes/utils.py +1221 -534
  233. sky/provision/kubernetes/volume.py +343 -0
  234. sky/provision/lambda_cloud/instance.py +22 -16
  235. sky/provision/nebius/constants.py +50 -0
  236. sky/provision/nebius/instance.py +19 -6
  237. sky/provision/nebius/utils.py +237 -137
  238. sky/provision/oci/instance.py +10 -5
  239. sky/provision/paperspace/instance.py +10 -7
  240. sky/provision/paperspace/utils.py +1 -1
  241. sky/provision/primeintellect/__init__.py +10 -0
  242. sky/provision/primeintellect/config.py +11 -0
  243. sky/provision/primeintellect/instance.py +454 -0
  244. sky/provision/primeintellect/utils.py +398 -0
  245. sky/provision/provisioner.py +117 -36
  246. sky/provision/runpod/__init__.py +5 -0
  247. sky/provision/runpod/instance.py +27 -6
  248. sky/provision/runpod/utils.py +51 -18
  249. sky/provision/runpod/volume.py +214 -0
  250. sky/provision/scp/__init__.py +15 -0
  251. sky/provision/scp/config.py +93 -0
  252. sky/provision/scp/instance.py +707 -0
  253. sky/provision/seeweb/__init__.py +11 -0
  254. sky/provision/seeweb/config.py +13 -0
  255. sky/provision/seeweb/instance.py +812 -0
  256. sky/provision/shadeform/__init__.py +11 -0
  257. sky/provision/shadeform/config.py +12 -0
  258. sky/provision/shadeform/instance.py +351 -0
  259. sky/provision/shadeform/shadeform_utils.py +83 -0
  260. sky/provision/ssh/__init__.py +18 -0
  261. sky/provision/vast/instance.py +13 -8
  262. sky/provision/vast/utils.py +10 -7
  263. sky/provision/volume.py +164 -0
  264. sky/provision/vsphere/common/ssl_helper.py +1 -1
  265. sky/provision/vsphere/common/vapiconnect.py +2 -1
  266. sky/provision/vsphere/common/vim_utils.py +4 -4
  267. sky/provision/vsphere/instance.py +15 -10
  268. sky/provision/vsphere/vsphere_utils.py +17 -20
  269. sky/py.typed +0 -0
  270. sky/resources.py +845 -119
  271. sky/schemas/__init__.py +0 -0
  272. sky/schemas/api/__init__.py +0 -0
  273. sky/schemas/api/responses.py +227 -0
  274. sky/schemas/db/README +4 -0
  275. sky/schemas/db/env.py +90 -0
  276. sky/schemas/db/global_user_state/001_initial_schema.py +124 -0
  277. sky/schemas/db/global_user_state/002_add_workspace_to_cluster_history.py +35 -0
  278. sky/schemas/db/global_user_state/003_fix_initial_revision.py +61 -0
  279. sky/schemas/db/global_user_state/004_is_managed.py +34 -0
  280. sky/schemas/db/global_user_state/005_cluster_event.py +32 -0
  281. sky/schemas/db/global_user_state/006_provision_log.py +41 -0
  282. sky/schemas/db/global_user_state/007_cluster_event_request_id.py +34 -0
  283. sky/schemas/db/global_user_state/008_skylet_ssh_tunnel_metadata.py +34 -0
  284. sky/schemas/db/global_user_state/009_last_activity_and_launched_at.py +89 -0
  285. sky/schemas/db/global_user_state/010_save_ssh_key.py +66 -0
  286. sky/schemas/db/global_user_state/011_is_ephemeral.py +34 -0
  287. sky/schemas/db/kv_cache/001_initial_schema.py +29 -0
  288. sky/schemas/db/script.py.mako +28 -0
  289. sky/schemas/db/serve_state/001_initial_schema.py +67 -0
  290. sky/schemas/db/serve_state/002_yaml_content.py +34 -0
  291. sky/schemas/db/skypilot_config/001_initial_schema.py +30 -0
  292. sky/schemas/db/spot_jobs/001_initial_schema.py +97 -0
  293. sky/schemas/db/spot_jobs/002_cluster_pool.py +42 -0
  294. sky/schemas/db/spot_jobs/003_pool_hash.py +34 -0
  295. sky/schemas/db/spot_jobs/004_job_file_contents.py +42 -0
  296. sky/schemas/db/spot_jobs/005_logs_gc.py +38 -0
  297. sky/schemas/db/spot_jobs/006_controller_pid_started_at.py +34 -0
  298. sky/schemas/db/spot_jobs/007_config_file_content.py +34 -0
  299. sky/schemas/generated/__init__.py +0 -0
  300. sky/schemas/generated/autostopv1_pb2.py +36 -0
  301. sky/schemas/generated/autostopv1_pb2.pyi +43 -0
  302. sky/schemas/generated/autostopv1_pb2_grpc.py +146 -0
  303. sky/schemas/generated/jobsv1_pb2.py +86 -0
  304. sky/schemas/generated/jobsv1_pb2.pyi +254 -0
  305. sky/schemas/generated/jobsv1_pb2_grpc.py +542 -0
  306. sky/schemas/generated/managed_jobsv1_pb2.py +76 -0
  307. sky/schemas/generated/managed_jobsv1_pb2.pyi +278 -0
  308. sky/schemas/generated/managed_jobsv1_pb2_grpc.py +278 -0
  309. sky/schemas/generated/servev1_pb2.py +58 -0
  310. sky/schemas/generated/servev1_pb2.pyi +115 -0
  311. sky/schemas/generated/servev1_pb2_grpc.py +322 -0
  312. sky/serve/autoscalers.py +357 -5
  313. sky/serve/client/impl.py +310 -0
  314. sky/serve/client/sdk.py +47 -139
  315. sky/serve/client/sdk_async.py +130 -0
  316. sky/serve/constants.py +12 -9
  317. sky/serve/controller.py +68 -17
  318. sky/serve/load_balancer.py +106 -60
  319. sky/serve/load_balancing_policies.py +116 -2
  320. sky/serve/replica_managers.py +434 -249
  321. sky/serve/serve_rpc_utils.py +179 -0
  322. sky/serve/serve_state.py +569 -257
  323. sky/serve/serve_utils.py +775 -265
  324. sky/serve/server/core.py +66 -711
  325. sky/serve/server/impl.py +1093 -0
  326. sky/serve/server/server.py +21 -18
  327. sky/serve/service.py +192 -89
  328. sky/serve/service_spec.py +144 -20
  329. sky/serve/spot_placer.py +3 -0
  330. sky/server/auth/__init__.py +0 -0
  331. sky/server/auth/authn.py +50 -0
  332. sky/server/auth/loopback.py +38 -0
  333. sky/server/auth/oauth2_proxy.py +202 -0
  334. sky/server/common.py +478 -182
  335. sky/server/config.py +85 -23
  336. sky/server/constants.py +44 -6
  337. sky/server/daemons.py +295 -0
  338. sky/server/html/token_page.html +185 -0
  339. sky/server/metrics.py +160 -0
  340. sky/server/middleware_utils.py +166 -0
  341. sky/server/requests/executor.py +558 -138
  342. sky/server/requests/payloads.py +364 -24
  343. sky/server/requests/preconditions.py +21 -17
  344. sky/server/requests/process.py +112 -29
  345. sky/server/requests/request_names.py +121 -0
  346. sky/server/requests/requests.py +822 -226
  347. sky/server/requests/serializers/decoders.py +82 -31
  348. sky/server/requests/serializers/encoders.py +140 -22
  349. sky/server/requests/threads.py +117 -0
  350. sky/server/rest.py +455 -0
  351. sky/server/server.py +1309 -285
  352. sky/server/state.py +20 -0
  353. sky/server/stream_utils.py +327 -61
  354. sky/server/uvicorn.py +217 -3
  355. sky/server/versions.py +270 -0
  356. sky/setup_files/MANIFEST.in +11 -1
  357. sky/setup_files/alembic.ini +160 -0
  358. sky/setup_files/dependencies.py +139 -31
  359. sky/setup_files/setup.py +44 -42
  360. sky/sky_logging.py +114 -7
  361. sky/skylet/attempt_skylet.py +106 -24
  362. sky/skylet/autostop_lib.py +129 -8
  363. sky/skylet/configs.py +29 -20
  364. sky/skylet/constants.py +216 -25
  365. sky/skylet/events.py +101 -21
  366. sky/skylet/job_lib.py +345 -164
  367. sky/skylet/log_lib.py +297 -18
  368. sky/skylet/log_lib.pyi +44 -1
  369. sky/skylet/providers/ibm/node_provider.py +12 -8
  370. sky/skylet/providers/ibm/vpc_provider.py +13 -12
  371. sky/skylet/ray_patches/__init__.py +17 -3
  372. sky/skylet/ray_patches/autoscaler.py.diff +18 -0
  373. sky/skylet/ray_patches/cli.py.diff +19 -0
  374. sky/skylet/ray_patches/command_runner.py.diff +17 -0
  375. sky/skylet/ray_patches/log_monitor.py.diff +20 -0
  376. sky/skylet/ray_patches/resource_demand_scheduler.py.diff +32 -0
  377. sky/skylet/ray_patches/updater.py.diff +18 -0
  378. sky/skylet/ray_patches/worker.py.diff +41 -0
  379. sky/skylet/runtime_utils.py +21 -0
  380. sky/skylet/services.py +568 -0
  381. sky/skylet/skylet.py +72 -4
  382. sky/skylet/subprocess_daemon.py +104 -29
  383. sky/skypilot_config.py +506 -99
  384. sky/ssh_node_pools/__init__.py +1 -0
  385. sky/ssh_node_pools/core.py +135 -0
  386. sky/ssh_node_pools/server.py +233 -0
  387. sky/task.py +685 -163
  388. sky/templates/aws-ray.yml.j2 +11 -3
  389. sky/templates/azure-ray.yml.j2 +2 -1
  390. sky/templates/cudo-ray.yml.j2 +1 -0
  391. sky/templates/do-ray.yml.j2 +2 -1
  392. sky/templates/fluidstack-ray.yml.j2 +1 -0
  393. sky/templates/gcp-ray.yml.j2 +62 -1
  394. sky/templates/hyperbolic-ray.yml.j2 +68 -0
  395. sky/templates/ibm-ray.yml.j2 +2 -1
  396. sky/templates/jobs-controller.yaml.j2 +27 -24
  397. sky/templates/kubernetes-loadbalancer.yml.j2 +2 -0
  398. sky/templates/kubernetes-ray.yml.j2 +611 -50
  399. sky/templates/lambda-ray.yml.j2 +2 -1
  400. sky/templates/nebius-ray.yml.j2 +34 -12
  401. sky/templates/oci-ray.yml.j2 +1 -0
  402. sky/templates/paperspace-ray.yml.j2 +2 -1
  403. sky/templates/primeintellect-ray.yml.j2 +72 -0
  404. sky/templates/runpod-ray.yml.j2 +10 -1
  405. sky/templates/scp-ray.yml.j2 +4 -50
  406. sky/templates/seeweb-ray.yml.j2 +171 -0
  407. sky/templates/shadeform-ray.yml.j2 +73 -0
  408. sky/templates/sky-serve-controller.yaml.j2 +22 -2
  409. sky/templates/vast-ray.yml.j2 +1 -0
  410. sky/templates/vsphere-ray.yml.j2 +1 -0
  411. sky/templates/websocket_proxy.py +212 -37
  412. sky/usage/usage_lib.py +31 -15
  413. sky/users/__init__.py +0 -0
  414. sky/users/model.conf +15 -0
  415. sky/users/permission.py +397 -0
  416. sky/users/rbac.py +121 -0
  417. sky/users/server.py +720 -0
  418. sky/users/token_service.py +218 -0
  419. sky/utils/accelerator_registry.py +35 -5
  420. sky/utils/admin_policy_utils.py +84 -38
  421. sky/utils/annotations.py +38 -5
  422. sky/utils/asyncio_utils.py +78 -0
  423. sky/utils/atomic.py +1 -1
  424. sky/utils/auth_utils.py +153 -0
  425. sky/utils/benchmark_utils.py +60 -0
  426. sky/utils/cli_utils/status_utils.py +159 -86
  427. sky/utils/cluster_utils.py +31 -9
  428. sky/utils/command_runner.py +354 -68
  429. sky/utils/command_runner.pyi +93 -3
  430. sky/utils/common.py +35 -8
  431. sky/utils/common_utils.py +314 -91
  432. sky/utils/config_utils.py +74 -5
  433. sky/utils/context.py +403 -0
  434. sky/utils/context_utils.py +242 -0
  435. sky/utils/controller_utils.py +383 -89
  436. sky/utils/dag_utils.py +31 -12
  437. sky/utils/db/__init__.py +0 -0
  438. sky/utils/db/db_utils.py +485 -0
  439. sky/utils/db/kv_cache.py +149 -0
  440. sky/utils/db/migration_utils.py +137 -0
  441. sky/utils/directory_utils.py +12 -0
  442. sky/utils/env_options.py +13 -0
  443. sky/utils/git.py +567 -0
  444. sky/utils/git_clone.sh +460 -0
  445. sky/utils/infra_utils.py +195 -0
  446. sky/utils/kubernetes/cleanup-tunnel.sh +62 -0
  447. sky/utils/kubernetes/config_map_utils.py +133 -0
  448. sky/utils/kubernetes/create_cluster.sh +15 -29
  449. sky/utils/kubernetes/delete_cluster.sh +10 -7
  450. sky/utils/kubernetes/deploy_ssh_node_pools.py +1177 -0
  451. sky/utils/kubernetes/exec_kubeconfig_converter.py +22 -31
  452. sky/utils/kubernetes/generate_kind_config.py +6 -66
  453. sky/utils/kubernetes/generate_kubeconfig.sh +4 -1
  454. sky/utils/kubernetes/gpu_labeler.py +18 -8
  455. sky/utils/kubernetes/k8s_gpu_labeler_job.yaml +2 -1
  456. sky/utils/kubernetes/k8s_gpu_labeler_setup.yaml +16 -16
  457. sky/utils/kubernetes/kubernetes_deploy_utils.py +284 -114
  458. sky/utils/kubernetes/rsync_helper.sh +11 -3
  459. sky/utils/kubernetes/ssh-tunnel.sh +379 -0
  460. sky/utils/kubernetes/ssh_utils.py +221 -0
  461. sky/utils/kubernetes_enums.py +8 -15
  462. sky/utils/lock_events.py +94 -0
  463. sky/utils/locks.py +416 -0
  464. sky/utils/log_utils.py +82 -107
  465. sky/utils/perf_utils.py +22 -0
  466. sky/utils/resource_checker.py +298 -0
  467. sky/utils/resources_utils.py +249 -32
  468. sky/utils/rich_utils.py +217 -39
  469. sky/utils/schemas.py +955 -160
  470. sky/utils/serialize_utils.py +16 -0
  471. sky/utils/status_lib.py +10 -0
  472. sky/utils/subprocess_utils.py +29 -15
  473. sky/utils/tempstore.py +70 -0
  474. sky/utils/thread_utils.py +91 -0
  475. sky/utils/timeline.py +26 -53
  476. sky/utils/ux_utils.py +84 -15
  477. sky/utils/validator.py +11 -1
  478. sky/utils/volume.py +165 -0
  479. sky/utils/yaml_utils.py +111 -0
  480. sky/volumes/__init__.py +13 -0
  481. sky/volumes/client/__init__.py +0 -0
  482. sky/volumes/client/sdk.py +150 -0
  483. sky/volumes/server/__init__.py +0 -0
  484. sky/volumes/server/core.py +270 -0
  485. sky/volumes/server/server.py +124 -0
  486. sky/volumes/volume.py +215 -0
  487. sky/workspaces/__init__.py +0 -0
  488. sky/workspaces/core.py +655 -0
  489. sky/workspaces/server.py +101 -0
  490. sky/workspaces/utils.py +56 -0
  491. sky_templates/README.md +3 -0
  492. sky_templates/__init__.py +3 -0
  493. sky_templates/ray/__init__.py +0 -0
  494. sky_templates/ray/start_cluster +183 -0
  495. sky_templates/ray/stop_cluster +75 -0
  496. skypilot_nightly-1.0.0.dev20251203.dist-info/METADATA +676 -0
  497. skypilot_nightly-1.0.0.dev20251203.dist-info/RECORD +611 -0
  498. {skypilot_nightly-1.0.0.dev20250502.dist-info → skypilot_nightly-1.0.0.dev20251203.dist-info}/WHEEL +1 -1
  499. skypilot_nightly-1.0.0.dev20251203.dist-info/top_level.txt +2 -0
  500. sky/benchmark/benchmark_state.py +0 -256
  501. sky/benchmark/benchmark_utils.py +0 -641
  502. sky/clouds/service_catalog/constants.py +0 -7
  503. sky/dashboard/out/_next/static/GWvVBSCS7FmUiVmjaL1a7/_buildManifest.js +0 -1
  504. sky/dashboard/out/_next/static/chunks/236-2db3ee3fba33dd9e.js +0 -6
  505. sky/dashboard/out/_next/static/chunks/312-c3c8845990db8ffc.js +0 -15
  506. sky/dashboard/out/_next/static/chunks/37-0a572fe0dbb89c4d.js +0 -6
  507. sky/dashboard/out/_next/static/chunks/678-206dddca808e6d16.js +0 -59
  508. sky/dashboard/out/_next/static/chunks/845-9e60713e0c441abc.js +0 -1
  509. sky/dashboard/out/_next/static/chunks/979-7bf73a4c7cea0f5c.js +0 -1
  510. sky/dashboard/out/_next/static/chunks/fd9d1056-2821b0f0cabcd8bd.js +0 -1
  511. sky/dashboard/out/_next/static/chunks/framework-87d061ee6ed71b28.js +0 -33
  512. sky/dashboard/out/_next/static/chunks/main-app-241eb28595532291.js +0 -1
  513. sky/dashboard/out/_next/static/chunks/main-e0e2335212e72357.js +0 -1
  514. sky/dashboard/out/_next/static/chunks/pages/_app-e6b013bc3f77ad60.js +0 -1
  515. sky/dashboard/out/_next/static/chunks/pages/_error-1be831200e60c5c0.js +0 -1
  516. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]/[job]-6ac338bc2239cb45.js +0 -1
  517. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]-f383db7389368ea7.js +0 -1
  518. sky/dashboard/out/_next/static/chunks/pages/clusters-a93b93e10b8b074e.js +0 -1
  519. sky/dashboard/out/_next/static/chunks/pages/index-f9f039532ca8cbc4.js +0 -1
  520. sky/dashboard/out/_next/static/chunks/pages/jobs/[job]-1c519e1afc523dc9.js +0 -1
  521. sky/dashboard/out/_next/static/chunks/pages/jobs-a75029b67aab6a2e.js +0 -1
  522. sky/dashboard/out/_next/static/chunks/webpack-830f59b8404e96b8.js +0 -1
  523. sky/dashboard/out/_next/static/css/c6933bbb2ce7f4dd.css +0 -3
  524. sky/jobs/dashboard/dashboard.py +0 -223
  525. sky/jobs/dashboard/static/favicon.ico +0 -0
  526. sky/jobs/dashboard/templates/index.html +0 -831
  527. sky/jobs/server/dashboard_utils.py +0 -69
  528. sky/skylet/providers/scp/__init__.py +0 -2
  529. sky/skylet/providers/scp/config.py +0 -149
  530. sky/skylet/providers/scp/node_provider.py +0 -578
  531. sky/templates/kubernetes-ssh-jump.yml.j2 +0 -94
  532. sky/utils/db_utils.py +0 -100
  533. sky/utils/kubernetes/deploy_remote_cluster.sh +0 -308
  534. sky/utils/kubernetes/ssh_jump_lifecycle_manager.py +0 -191
  535. skypilot_nightly-1.0.0.dev20250502.dist-info/METADATA +0 -361
  536. skypilot_nightly-1.0.0.dev20250502.dist-info/RECORD +0 -396
  537. skypilot_nightly-1.0.0.dev20250502.dist-info/top_level.txt +0 -1
  538. /sky/{clouds/service_catalog → catalog}/config.py +0 -0
  539. /sky/{benchmark → catalog/data_fetchers}/__init__.py +0 -0
  540. /sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_azure.py +0 -0
  541. /sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_fluidstack.py +0 -0
  542. /sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_ibm.py +0 -0
  543. /sky/{clouds/service_catalog/data_fetchers → client/cli}/__init__.py +0 -0
  544. /sky/dashboard/out/_next/static/{GWvVBSCS7FmUiVmjaL1a7 → 96_E2yl3QAiIJGOYCkSpB}/_ssgManifest.js +0 -0
  545. {skypilot_nightly-1.0.0.dev20250502.dist-info → skypilot_nightly-1.0.0.dev20251203.dist-info}/entry_points.txt +0 -0
  546. {skypilot_nightly-1.0.0.dev20250502.dist-info → skypilot_nightly-1.0.0.dev20251203.dist-info}/licenses/LICENSE +0 -0
sky/server/server.py CHANGED
@@ -2,55 +2,91 @@
2
2
 
3
3
  import argparse
4
4
  import asyncio
5
+ import base64
6
+ from concurrent.futures import ThreadPoolExecutor
5
7
  import contextlib
6
- import dataclasses
7
8
  import datetime
8
- import logging
9
+ from enum import IntEnum
10
+ import hashlib
11
+ import json
9
12
  import multiprocessing
10
13
  import os
11
14
  import pathlib
15
+ import posixpath
12
16
  import re
17
+ import resource
13
18
  import shutil
19
+ import struct
14
20
  import sys
15
- from typing import Any, Dict, List, Literal, Optional, Set, Tuple
21
+ import threading
22
+ import traceback
23
+ from typing import Dict, List, Literal, Optional, Set, Tuple
16
24
  import uuid
17
25
  import zipfile
18
26
 
19
27
  import aiofiles
28
+ import anyio
20
29
  import fastapi
30
+ from fastapi import responses as fastapi_responses
21
31
  from fastapi.middleware import cors
22
32
  import starlette.middleware.base
33
+ import uvloop
23
34
 
24
35
  import sky
36
+ from sky import catalog
25
37
  from sky import check as sky_check
26
38
  from sky import clouds
27
39
  from sky import core
28
40
  from sky import exceptions
29
41
  from sky import execution
30
42
  from sky import global_user_state
43
+ from sky import models
31
44
  from sky import sky_logging
32
- from sky.clouds import service_catalog
33
45
  from sky.data import storage_utils
46
+ from sky.jobs import utils as managed_job_utils
34
47
  from sky.jobs.server import server as jobs_rest
48
+ from sky.metrics import utils as metrics_utils
49
+ from sky.provision import metadata_utils
35
50
  from sky.provision.kubernetes import utils as kubernetes_utils
51
+ from sky.schemas.api import responses
36
52
  from sky.serve.server import server as serve_rest
37
53
  from sky.server import common
38
54
  from sky.server import config as server_config
39
55
  from sky.server import constants as server_constants
56
+ from sky.server import daemons
57
+ from sky.server import metrics
58
+ from sky.server import middleware_utils
59
+ from sky.server import state
40
60
  from sky.server import stream_utils
61
+ from sky.server import versions
62
+ from sky.server.auth import authn
63
+ from sky.server.auth import loopback
64
+ from sky.server.auth import oauth2_proxy
41
65
  from sky.server.requests import executor
42
66
  from sky.server.requests import payloads
43
67
  from sky.server.requests import preconditions
68
+ from sky.server.requests import request_names
44
69
  from sky.server.requests import requests as requests_lib
45
70
  from sky.skylet import constants
71
+ from sky.ssh_node_pools import server as ssh_node_pools_rest
46
72
  from sky.usage import usage_lib
73
+ from sky.users import permission
74
+ from sky.users import server as users_rest
47
75
  from sky.utils import admin_policy_utils
48
76
  from sky.utils import common as common_lib
49
77
  from sky.utils import common_utils
78
+ from sky.utils import context
79
+ from sky.utils import context_utils
80
+ from sky.utils import controller_utils
50
81
  from sky.utils import dag_utils
51
82
  from sky.utils import env_options
83
+ from sky.utils import perf_utils
52
84
  from sky.utils import status_lib
53
85
  from sky.utils import subprocess_utils
86
+ from sky.utils import ux_utils
87
+ from sky.utils.db import db_utils
88
+ from sky.volumes.server import server as volumes_rest
89
+ from sky.workspaces import server as workspaces_rest
54
90
 
55
91
  # pylint: disable=ungrouped-imports
56
92
  if sys.version_info >= (3, 10):
@@ -60,31 +96,8 @@ else:
60
96
 
61
97
  P = ParamSpec('P')
62
98
 
99
+ _SERVER_USER_HASH_KEY = 'server_user_hash'
63
100
 
64
- def _add_timestamp_prefix_for_server_logs() -> None:
65
- server_logger = sky_logging.init_logger('sky.server')
66
- # Clear existing handlers first to prevent duplicates
67
- server_logger.handlers.clear()
68
- # Disable propagation to avoid the root logger of SkyPilot being affected
69
- server_logger.propagate = False
70
- # Add date prefix to the log message printed by loggers under
71
- # server.
72
- stream_handler = logging.StreamHandler(sys.stdout)
73
- if env_options.Options.SHOW_DEBUG_INFO.get():
74
- stream_handler.setLevel(logging.DEBUG)
75
- else:
76
- stream_handler.setLevel(logging.INFO)
77
- stream_handler.flush = sys.stdout.flush # type: ignore
78
- stream_handler.setFormatter(sky_logging.FORMATTER)
79
- server_logger.addHandler(stream_handler)
80
- # Add date prefix to the log message printed by uvicorn.
81
- for name in ['uvicorn', 'uvicorn.access']:
82
- uvicorn_logger = logging.getLogger(name)
83
- uvicorn_logger.handlers.clear()
84
- uvicorn_logger.addHandler(stream_handler)
85
-
86
-
87
- _add_timestamp_prefix_for_server_logs()
88
101
  logger = sky_logging.init_logger(__name__)
89
102
 
90
103
  # TODO(zhwu): Streaming requests, such as log tailing after sky launch or sky logs,
@@ -92,17 +105,315 @@ logger = sky_logging.init_logger(__name__)
92
105
  # response will block other requests from being processed.
93
106
 
94
107
 
108
+ def _basic_auth_401_response(content: str):
109
+ """Return a 401 response with basic auth realm."""
110
+ return fastapi.responses.JSONResponse(
111
+ status_code=401,
112
+ headers={'WWW-Authenticate': 'Basic realm=\"SkyPilot\"'},
113
+ content=content)
114
+
115
+
116
+ def _try_set_basic_auth_user(request: fastapi.Request):
117
+ auth_header = request.headers.get('authorization')
118
+ if not auth_header or not auth_header.lower().startswith('basic '):
119
+ return
120
+
121
+ # Check username and password
122
+ encoded = auth_header.split(' ', 1)[1]
123
+ try:
124
+ decoded = base64.b64decode(encoded).decode()
125
+ username, password = decoded.split(':', 1)
126
+ except Exception: # pylint: disable=broad-except
127
+ return
128
+
129
+ users = global_user_state.get_user_by_name(username)
130
+ if not users:
131
+ return
132
+
133
+ for user in users:
134
+ if not user.name or not user.password:
135
+ continue
136
+ username_encoded = username.encode('utf8')
137
+ db_username_encoded = user.name.encode('utf8')
138
+ if (username_encoded == db_username_encoded and
139
+ common.crypt_ctx.verify(password, user.password)):
140
+ request.state.auth_user = user
141
+ break
142
+
143
+
144
+ @middleware_utils.websocket_aware
145
+ class RBACMiddleware(starlette.middleware.base.BaseHTTPMiddleware):
146
+ """Middleware to handle RBAC."""
147
+
148
+ async def dispatch(self, request: fastapi.Request, call_next):
149
+ # TODO(hailong): should have a list of paths
150
+ # that are not checked for RBAC
151
+ if (request.url.path.startswith('/dashboard/') or
152
+ request.url.path.startswith('/api/')):
153
+ return await call_next(request)
154
+
155
+ auth_user = request.state.auth_user
156
+ if auth_user is None:
157
+ return await call_next(request)
158
+
159
+ permission_service = permission.permission_service
160
+ # Check the role permission
161
+ if permission_service.check_endpoint_permission(auth_user.id,
162
+ request.url.path,
163
+ request.method):
164
+ return fastapi.responses.JSONResponse(
165
+ status_code=403, content={'detail': 'Forbidden'})
166
+
167
+ return await call_next(request)
168
+
169
+
95
170
  class RequestIDMiddleware(starlette.middleware.base.BaseHTTPMiddleware):
96
171
  """Middleware to add a request ID to each request."""
97
172
 
98
173
  async def dispatch(self, request: fastapi.Request, call_next):
99
- request_id = str(uuid.uuid4())
174
+ request_id = requests_lib.get_new_request_id()
100
175
  request.state.request_id = request_id
101
176
  response = await call_next(request)
102
- response.headers['X-Request-ID'] = request_id
177
+ response.headers['X-Skypilot-Request-ID'] = request_id
103
178
  return response
104
179
 
105
180
 
181
+ def _get_auth_user_header(request: fastapi.Request) -> Optional[models.User]:
182
+ header_name = os.environ.get(constants.ENV_VAR_SERVER_AUTH_USER_HEADER,
183
+ 'X-Auth-Request-Email')
184
+ if header_name not in request.headers:
185
+ return None
186
+ user_name = request.headers[header_name]
187
+ user_hash = hashlib.md5(
188
+ user_name.encode()).hexdigest()[:common_utils.USER_HASH_LENGTH]
189
+ return models.User(id=user_hash, name=user_name)
190
+
191
+
192
+ @middleware_utils.websocket_aware
193
+ class InitializeRequestAuthUserMiddleware(
194
+ starlette.middleware.base.BaseHTTPMiddleware):
195
+
196
+ async def dispatch(self, request: fastapi.Request, call_next):
197
+ # Make sure that request.state.auth_user is set. Otherwise, we may get an
198
+ # AttributeError while trying to read it.
199
+ request.state.auth_user = None
200
+ return await call_next(request)
201
+
202
+
203
+ @middleware_utils.websocket_aware
204
+ class BasicAuthMiddleware(starlette.middleware.base.BaseHTTPMiddleware):
205
+ """Middleware to handle HTTP Basic Auth."""
206
+
207
+ async def dispatch(self, request: fastapi.Request, call_next):
208
+ if managed_job_utils.is_consolidation_mode(
209
+ ) and loopback.is_loopback_request(request):
210
+ return await call_next(request)
211
+
212
+ if request.url.path.startswith('/api/health'):
213
+ # Try to set the auth user from basic auth
214
+ _try_set_basic_auth_user(request)
215
+ return await call_next(request)
216
+
217
+ auth_header = request.headers.get('authorization')
218
+ if not auth_header:
219
+ return _basic_auth_401_response('Authentication required')
220
+
221
+ # Only handle basic auth
222
+ if not auth_header.lower().startswith('basic '):
223
+ return _basic_auth_401_response('Invalid authentication method')
224
+
225
+ # Check username and password
226
+ encoded = auth_header.split(' ', 1)[1]
227
+ try:
228
+ decoded = base64.b64decode(encoded).decode()
229
+ username, password = decoded.split(':', 1)
230
+ except Exception: # pylint: disable=broad-except
231
+ return _basic_auth_401_response('Invalid basic auth')
232
+
233
+ users = global_user_state.get_user_by_name(username)
234
+ if not users:
235
+ return _basic_auth_401_response('Invalid credentials')
236
+
237
+ valid_user = False
238
+ for user in users:
239
+ if not user.name or not user.password:
240
+ continue
241
+ username_encoded = username.encode('utf8')
242
+ db_username_encoded = user.name.encode('utf8')
243
+ if (username_encoded == db_username_encoded and
244
+ common.crypt_ctx.verify(password, user.password)):
245
+ valid_user = True
246
+ request.state.auth_user = user
247
+ await authn.override_user_info_in_request_body(request, user)
248
+ break
249
+ if not valid_user:
250
+ return _basic_auth_401_response('Invalid credentials')
251
+
252
+ return await call_next(request)
253
+
254
+
255
+ @middleware_utils.websocket_aware
256
+ class BearerTokenMiddleware(starlette.middleware.base.BaseHTTPMiddleware):
257
+ """Middleware to handle Bearer Token Auth (Service Accounts)."""
258
+
259
+ async def dispatch(self, request: fastapi.Request, call_next):
260
+ """Make sure correct bearer token auth is present.
261
+
262
+ 1. If the request has the X-Skypilot-Auth-Mode: token header, it must
263
+ have a valid bearer token.
264
+ 2. For backwards compatibility, if the request has a Bearer token
265
+ beginning with "sky_" (even if X-Skypilot-Auth-Mode is not present),
266
+ it must be a valid token.
267
+ 3. If X-Skypilot-Auth-Mode is not set to "token", and there is no Bearer
268
+ token beginning with "sky_", allow the request to continue.
269
+
270
+ In conjunction with an auth proxy, the idea is to make the auth proxy
271
+ bypass requests with bearer tokens, instead setting the
272
+ X-Skypilot-Auth-Mode header. The auth proxy should either validate the
273
+ auth or set the header X-Skypilot-Auth-Mode: token.
274
+ """
275
+ has_skypilot_auth_header = (
276
+ request.headers.get('X-Skypilot-Auth-Mode') == 'token')
277
+ auth_header = request.headers.get('authorization')
278
+ has_bearer_token_starting_with_sky = (
279
+ auth_header and auth_header.lower().startswith('bearer ') and
280
+ auth_header.split(' ', 1)[1].startswith('sky_'))
281
+
282
+ if (not has_skypilot_auth_header and
283
+ not has_bearer_token_starting_with_sky):
284
+ # This is case #3 above. We do not need to validate the request.
285
+ # No Bearer token, continue with normal processing (OAuth2 cookies,
286
+ # etc.)
287
+ return await call_next(request)
288
+ # After this point, all requests must be validated.
289
+
290
+ if auth_header is None:
291
+ return fastapi.responses.JSONResponse(
292
+ status_code=401, content={'detail': 'Authentication required'})
293
+
294
+ # Extract token
295
+ split_header = auth_header.split(' ', 1)
296
+ if split_header[0].lower() != 'bearer':
297
+ return fastapi.responses.JSONResponse(
298
+ status_code=401,
299
+ content={'detail': 'Invalid authentication method'})
300
+ sa_token = split_header[1]
301
+
302
+ # Handle SkyPilot service account tokens
303
+ return await self._handle_service_account_token(request, sa_token,
304
+ call_next)
305
+
306
+ async def _handle_service_account_token(self, request: fastapi.Request,
307
+ sa_token: str, call_next):
308
+ """Handle SkyPilot service account tokens."""
309
+ # Check if service account tokens are enabled
310
+ sa_enabled = os.environ.get(constants.ENV_VAR_ENABLE_SERVICE_ACCOUNTS,
311
+ 'false').lower()
312
+ if sa_enabled != 'true':
313
+ return fastapi.responses.JSONResponse(
314
+ status_code=401,
315
+ content={'detail': 'Service account authentication disabled'})
316
+
317
+ try:
318
+ # Import here to avoid circular imports
319
+ # pylint: disable=import-outside-toplevel
320
+ from sky.users.token_service import token_service
321
+
322
+ # Verify and decode JWT token
323
+ payload = token_service.verify_token(sa_token)
324
+
325
+ if payload is None:
326
+ logger.warning('Service account token verification failed')
327
+ return fastapi.responses.JSONResponse(
328
+ status_code=401,
329
+ content={
330
+ 'detail': 'Invalid or expired service account token'
331
+ })
332
+
333
+ # Extract user information from JWT payload
334
+ user_id = payload.get('sub')
335
+ user_name = payload.get('name')
336
+ token_id = payload.get('token_id')
337
+
338
+ if not user_id or not token_id:
339
+ logger.warning(
340
+ 'Invalid token payload: missing user_id or token_id')
341
+ return fastapi.responses.JSONResponse(
342
+ status_code=401,
343
+ content={'detail': 'Invalid token payload'})
344
+
345
+ # Verify user still exists in database
346
+ user_info = global_user_state.get_user(user_id)
347
+ if user_info is None:
348
+ logger.warning(
349
+ f'Service account user {user_id} no longer exists')
350
+ return fastapi.responses.JSONResponse(
351
+ status_code=401,
352
+ content={'detail': 'Service account user no longer exists'})
353
+
354
+ # Update last used timestamp for token tracking
355
+ try:
356
+ global_user_state.update_service_account_token_last_used(
357
+ token_id)
358
+ except Exception as e: # pylint: disable=broad-except
359
+ logger.debug(f'Failed to update token last used time: {e}')
360
+
361
+ # Set the authenticated user
362
+ auth_user = models.User(id=user_id,
363
+ name=user_name or user_info.name)
364
+ request.state.auth_user = auth_user
365
+
366
+ # Override user info in request body for service account requests
367
+ await authn.override_user_info_in_request_body(request, auth_user)
368
+
369
+ logger.debug(f'Authenticated service account: {user_id}')
370
+
371
+ except Exception as e: # pylint: disable=broad-except
372
+ logger.error(f'Service account authentication failed: {e}',
373
+ exc_info=True)
374
+ return fastapi.responses.JSONResponse(
375
+ status_code=401,
376
+ content={
377
+ 'detail': f'Service account authentication failed: {str(e)}'
378
+ })
379
+
380
+ return await call_next(request)
381
+
382
+
383
+ @middleware_utils.websocket_aware
384
+ class AuthProxyMiddleware(starlette.middleware.base.BaseHTTPMiddleware):
385
+ """Middleware to handle auth proxy."""
386
+
387
+ async def dispatch(self, request: fastapi.Request, call_next):
388
+ auth_user = _get_auth_user_header(request)
389
+
390
+ if request.state.auth_user is not None:
391
+ # Previous middleware is trusted more than this middleware. For
392
+ # instance, a client could set the Authorization and the
393
+ # X-Auth-Request-Email header. In that case, the auth proxy will be
394
+ # skipped and we should rely on the Bearer token to authenticate the
395
+ # user - but that means the user could set X-Auth-Request-Email to
396
+ # whatever the user wants. We should thus ignore it.
397
+ if auth_user is not None:
398
+ logger.debug('Warning: ignoring auth proxy header since the '
399
+ 'auth user was already set.')
400
+ return await call_next(request)
401
+
402
+ # Add user to database if auth_user is present
403
+ if auth_user is not None:
404
+ newly_added = global_user_state.add_or_update_user(auth_user)
405
+ if newly_added:
406
+ permission.permission_service.add_user_if_not_exists(
407
+ auth_user.id)
408
+
409
+ # Store user info in request.state for access by GET endpoints
410
+ if auth_user is not None:
411
+ request.state.auth_user = auth_user
412
+
413
+ await authn.override_user_info_in_request_body(request, auth_user)
414
+ return await call_next(request)
415
+
416
+
106
417
  # Default expiration time for upload ids before cleanup.
107
418
  _DEFAULT_UPLOAD_EXPIRATION_TIME = datetime.timedelta(hours=1)
108
419
  # Key: (upload_id, user_hash), Value: the time when the upload id needs to be
@@ -132,21 +443,74 @@ async def cleanup_upload_ids():
132
443
  upload_ids_to_cleanup.pop((upload_id, user_hash))
133
444
 
134
445
 
446
+ async def loop_lag_monitor(loop: asyncio.AbstractEventLoop,
447
+ interval: float = 0.1) -> None:
448
+ target = loop.time() + interval
449
+
450
+ pid = str(os.getpid())
451
+ lag_threshold = perf_utils.get_loop_lag_threshold()
452
+
453
+ def tick():
454
+ nonlocal target
455
+ now = loop.time()
456
+ lag = max(0.0, now - target)
457
+ if lag_threshold is not None and lag > lag_threshold:
458
+ logger.warning(f'Event loop lag {lag} seconds exceeds threshold '
459
+ f'{lag_threshold} seconds.')
460
+ metrics_utils.SKY_APISERVER_EVENT_LOOP_LAG_SECONDS.labels(
461
+ pid=pid).observe(lag)
462
+ target = now + interval
463
+ loop.call_at(target, tick)
464
+
465
+ loop.call_at(target, tick)
466
+
467
+
468
+ async def schedule_on_boot_check_async():
469
+ try:
470
+ await executor.schedule_request_async(
471
+ request_id='skypilot-server-on-boot-check',
472
+ request_name=request_names.RequestName.CHECK,
473
+ request_body=payloads.CheckBody(),
474
+ func=sky_check.check,
475
+ schedule_type=requests_lib.ScheduleType.SHORT,
476
+ is_skypilot_system=True,
477
+ )
478
+ except exceptions.RequestAlreadyExistsError:
479
+ # Lifespan will be executed in each uvicorn worker process, we
480
+ # can safely ignore the error if the task is already scheduled.
481
+ logger.debug('Request skypilot-server-on-boot-check already exists.')
482
+
483
+
135
484
  @contextlib.asynccontextmanager
136
485
  async def lifespan(app: fastapi.FastAPI): # pylint: disable=redefined-outer-name
137
486
  """FastAPI lifespan context manager."""
138
487
  del app # unused
139
488
  # Startup: Run background tasks
140
- for event in requests_lib.INTERNAL_REQUEST_DAEMONS:
141
- executor.schedule_request(
142
- request_id=event.id,
143
- request_name=event.name,
144
- request_body=payloads.RequestBody(),
145
- func=event.event_fn,
146
- schedule_type=requests_lib.ScheduleType.SHORT,
147
- is_skypilot_system=True,
148
- )
489
+ for event in daemons.INTERNAL_REQUEST_DAEMONS:
490
+ if event.should_skip():
491
+ continue
492
+ try:
493
+ await executor.schedule_request_async(
494
+ request_id=event.id,
495
+ request_name=event.name,
496
+ request_body=payloads.RequestBody(),
497
+ func=event.run_event,
498
+ schedule_type=requests_lib.ScheduleType.SHORT,
499
+ is_skypilot_system=True,
500
+ # Request daemon should be retried if the process pool is
501
+ # broken.
502
+ retryable=True,
503
+ )
504
+ except exceptions.RequestAlreadyExistsError:
505
+ # Lifespan will be executed in each uvicorn worker process, we
506
+ # can safely ignore the error if the task is already scheduled.
507
+ logger.debug(f'Request {event.id} already exists.')
508
+ await schedule_on_boot_check_async()
149
509
  asyncio.create_task(cleanup_upload_ids())
510
+ if metrics_utils.METRICS_ENABLED:
511
+ # Start monitoring the event loop lag in each server worker
512
+ # event loop (process).
513
+ asyncio.create_task(loop_lag_monitor(asyncio.get_event_loop()))
150
514
  yield
151
515
  # Shutdown: Add any cleanup code here if needed
152
516
 
@@ -164,8 +528,104 @@ class InternalDashboardPrefixMiddleware(
164
528
  return await call_next(request)
165
529
 
166
530
 
531
+ class CacheControlStaticMiddleware(starlette.middleware.base.BaseHTTPMiddleware
532
+ ):
533
+ """Middleware to add cache control headers to static files."""
534
+
535
+ async def dispatch(self, request: fastapi.Request, call_next):
536
+ if request.url.path.startswith('/dashboard/_next'):
537
+ response = await call_next(request)
538
+ response.headers['Cache-Control'] = 'max-age=3600'
539
+ return response
540
+ return await call_next(request)
541
+
542
+
543
+ class PathCleanMiddleware(starlette.middleware.base.BaseHTTPMiddleware):
544
+ """Middleware to check the path of requests."""
545
+
546
+ async def dispatch(self, request: fastapi.Request, call_next):
547
+ if request.url.path.startswith('/dashboard/'):
548
+ # If the requested path is not relative to the expected directory,
549
+ # then the user is attempting path traversal, so deny the request.
550
+ parent = pathlib.Path('/dashboard')
551
+ request_path = pathlib.Path(posixpath.normpath(request.url.path))
552
+ if not _is_relative_to(request_path, parent):
553
+ return fastapi.responses.JSONResponse(
554
+ status_code=403, content={'detail': 'Forbidden'})
555
+ return await call_next(request)
556
+
557
+
558
+ @middleware_utils.websocket_aware
559
+ class GracefulShutdownMiddleware(starlette.middleware.base.BaseHTTPMiddleware):
560
+ """Middleware to control requests when server is shutting down."""
561
+
562
+ async def dispatch(self, request: fastapi.Request, call_next):
563
+ if state.get_block_requests():
564
+ # Allow /api/ paths to continue, which are critical to operate
565
+ # on-going requests but will not submit new requests.
566
+ if not request.url.path.startswith('/api/'):
567
+ # Client will retry on 503 error.
568
+ return fastapi.responses.JSONResponse(
569
+ status_code=503,
570
+ content={
571
+ 'detail': 'Server is shutting down, '
572
+ 'please try again later.'
573
+ })
574
+
575
+ return await call_next(request)
576
+
577
+
578
+ @middleware_utils.websocket_aware
579
+ class APIVersionMiddleware(starlette.middleware.base.BaseHTTPMiddleware):
580
+ """Middleware to add API version to the request."""
581
+
582
+ async def dispatch(self, request: fastapi.Request, call_next):
583
+ version_info = versions.check_compatibility_at_server(request.headers)
584
+ # Bypass version handling for backward compatibility with clients prior
585
+ # to v0.11.0, the client will check the version in the body of
586
+ # /api/health response and hint an upgrade.
587
+ # TODO(aylei): remove this after v0.13.0 is released.
588
+ if version_info is None:
589
+ return await call_next(request)
590
+ if version_info.error is None:
591
+ versions.set_remote_api_version(version_info.api_version)
592
+ versions.set_remote_version(version_info.version)
593
+ response = await call_next(request)
594
+ else:
595
+ response = fastapi.responses.JSONResponse(
596
+ status_code=400,
597
+ content={
598
+ 'error': common.ApiServerStatus.VERSION_MISMATCH.value,
599
+ 'message': version_info.error,
600
+ })
601
+ response.headers[server_constants.API_VERSION_HEADER] = str(
602
+ server_constants.API_VERSION)
603
+ response.headers[server_constants.VERSION_HEADER] = \
604
+ versions.get_local_readable_version()
605
+ return response
606
+
607
+
167
608
  app = fastapi.FastAPI(prefix='/api/v1', debug=True, lifespan=lifespan)
609
+ # Middleware wraps in the order defined here. E.g., given
610
+ # app.add_middleware(Middleware1)
611
+ # app.add_middleware(Middleware2)
612
+ # app.add_middleware(Middleware3)
613
+ # The effect will be like:
614
+ # Middleware3(Middleware2(Middleware1(request)))
615
+ # If MiddlewareN does something like print(n); call_next(); print(n), you'll get
616
+ # 3; 2; 1; <request>; 1; 2; 3
617
+ # Use environment variable to make the metrics middleware optional.
618
+ if os.environ.get(constants.ENV_VAR_SERVER_METRICS_ENABLED):
619
+ app.add_middleware(metrics.PrometheusMiddleware)
620
+ app.add_middleware(APIVersionMiddleware)
621
+ # The order of all the authentication-related middleware is important.
622
+ # RBACMiddleware must precede all the auth middleware, so it can access
623
+ # request.state.auth_user.
624
+ app.add_middleware(RBACMiddleware)
168
625
  app.add_middleware(InternalDashboardPrefixMiddleware)
626
+ app.add_middleware(GracefulShutdownMiddleware)
627
+ app.add_middleware(PathCleanMiddleware)
628
+ app.add_middleware(CacheControlStaticMiddleware)
169
629
  app.add_middleware(
170
630
  cors.CORSMiddleware,
171
631
  # TODO(zhwu): in production deployment, we should restrict the allowed
@@ -174,19 +634,104 @@ app.add_middleware(
174
634
  allow_credentials=True,
175
635
  allow_methods=['*'],
176
636
  allow_headers=['*'],
177
- expose_headers=['X-Request-ID'])
637
+ expose_headers=['X-Skypilot-Request-ID'])
638
+ # Authentication based on oauth2-proxy.
639
+ app.add_middleware(oauth2_proxy.OAuth2ProxyMiddleware)
640
+ # AuthProxyMiddleware should precede BasicAuthMiddleware and
641
+ # BearerTokenMiddleware, since it should be skipped if either of those set the
642
+ # auth user.
643
+ app.add_middleware(AuthProxyMiddleware)
644
+ enable_basic_auth = os.environ.get(constants.ENV_VAR_ENABLE_BASIC_AUTH, 'false')
645
+ if str(enable_basic_auth).lower() == 'true':
646
+ app.add_middleware(BasicAuthMiddleware)
647
+ # Bearer token middleware should always be present to handle service account
648
+ # authentication
649
+ app.add_middleware(BearerTokenMiddleware)
650
+ # InitializeRequestAuthUserMiddleware must be the last added middleware so that
651
+ # request.state.auth_user is always set, but can be overridden by the auth
652
+ # middleware above.
653
+ app.add_middleware(InitializeRequestAuthUserMiddleware)
178
654
  app.add_middleware(RequestIDMiddleware)
179
655
  app.include_router(jobs_rest.router, prefix='/jobs', tags=['jobs'])
180
656
  app.include_router(serve_rest.router, prefix='/serve', tags=['serve'])
657
+ app.include_router(users_rest.router, prefix='/users', tags=['users'])
658
+ app.include_router(workspaces_rest.router,
659
+ prefix='/workspaces',
660
+ tags=['workspaces'])
661
+ app.include_router(volumes_rest.router, prefix='/volumes', tags=['volumes'])
662
+ app.include_router(ssh_node_pools_rest.router,
663
+ prefix='/ssh_node_pools',
664
+ tags=['ssh_node_pools'])
665
+ # increase the resource limit for the server
666
+ soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
667
+ resource.setrlimit(resource.RLIMIT_NOFILE, (hard, hard))
668
+
669
+
670
+ @app.exception_handler(exceptions.ConcurrentWorkerExhaustedError)
671
+ def handle_concurrent_worker_exhausted_error(
672
+ request: fastapi.Request, e: exceptions.ConcurrentWorkerExhaustedError):
673
+ del request # request is not used
674
+ # Print detailed error message to server log
675
+ logger.error('Concurrent worker exhausted: '
676
+ f'{common_utils.format_exception(e)}')
677
+ with ux_utils.enable_traceback():
678
+ logger.error(f' Traceback: {traceback.format_exc()}')
679
+ # Return human readable error message to client
680
+ return fastapi.responses.JSONResponse(
681
+ status_code=503,
682
+ content={
683
+ 'detail':
684
+ ('The server has exhausted its concurrent worker limit. '
685
+ 'Please try again or scale the server if the load persists.')
686
+ })
687
+
688
+
689
+ @app.get('/token')
690
+ async def token(request: fastapi.Request,
691
+ local_port: Optional[int] = None) -> fastapi.responses.Response:
692
+ del local_port # local_port is used by the served js, but ignored by server
693
+ user = _get_auth_user_header(request)
694
+
695
+ token_data = {
696
+ 'v': 1, # Token version number; bump for backwards-incompatible changes.
697
+ 'user': user.id if user is not None else None,
698
+ 'cookies': request.cookies,
699
+ }
700
+ # Use base64 encoding to avoid having to escape anything in the HTML.
701
+ json_bytes = json.dumps(token_data).encode('utf-8')
702
+ base64_str = base64.b64encode(json_bytes).decode('utf-8')
703
+
704
+ html_dir = pathlib.Path(__file__).parent / 'html'
705
+ token_page_path = html_dir / 'token_page.html'
706
+ try:
707
+ with open(token_page_path, 'r', encoding='utf-8') as f:
708
+ html_content = f.read()
709
+ except FileNotFoundError as e:
710
+ raise fastapi.HTTPException(
711
+ status_code=500, detail='Token page template not found.') from e
712
+
713
+ user_info_string = f'Logged in as {user.name}' if user is not None else ''
714
+ html_content = html_content.replace(
715
+ 'SKYPILOT_API_SERVER_USER_TOKEN_PLACEHOLDER',
716
+ base64_str).replace('USER_PLACEHOLDER', user_info_string)
717
+
718
+ return fastapi.responses.HTMLResponse(
719
+ content=html_content,
720
+ headers={
721
+ 'Cache-Control': 'no-cache, no-transform',
722
+ # X-Accel-Buffering: no is useful for preventing buffering issues
723
+ # with some reverse proxies.
724
+ 'X-Accel-Buffering': 'no'
725
+ })
181
726
 
182
727
 
183
728
  @app.post('/check')
184
729
  async def check(request: fastapi.Request,
185
730
  check_body: payloads.CheckBody) -> None:
186
731
  """Checks enabled clouds."""
187
- executor.schedule_request(
732
+ await executor.schedule_request_async(
188
733
  request_id=request.state.request_id,
189
- request_name='check',
734
+ request_name=request_names.RequestName.CHECK,
190
735
  request_body=check_body,
191
736
  func=sky_check.check,
192
737
  schedule_type=requests_lib.ScheduleType.SHORT,
@@ -194,12 +739,15 @@ async def check(request: fastapi.Request,
194
739
 
195
740
 
196
741
  @app.get('/enabled_clouds')
197
- async def enabled_clouds(request: fastapi.Request) -> None:
742
+ async def enabled_clouds(request: fastapi.Request,
743
+ workspace: Optional[str] = None,
744
+ expand: bool = False) -> None:
198
745
  """Gets enabled clouds on the server."""
199
- executor.schedule_request(
746
+ await executor.schedule_request_async(
200
747
  request_id=request.state.request_id,
201
- request_name='enabled_clouds',
202
- request_body=payloads.RequestBody(),
748
+ request_name=request_names.RequestName.ENABLED_CLOUDS,
749
+ request_body=payloads.EnabledCloudsBody(workspace=workspace,
750
+ expand=expand),
203
751
  func=core.enabled_clouds,
204
752
  schedule_type=requests_lib.ScheduleType.SHORT,
205
753
  )
@@ -211,9 +759,10 @@ async def realtime_kubernetes_gpu_availability(
211
759
  realtime_gpu_availability_body: payloads.RealtimeGpuAvailabilityRequestBody
212
760
  ) -> None:
213
761
  """Gets real-time Kubernetes GPU availability."""
214
- executor.schedule_request(
762
+ await executor.schedule_request_async(
215
763
  request_id=request.state.request_id,
216
- request_name='realtime_kubernetes_gpu_availability',
764
+ request_name=request_names.RequestName.
765
+ REALTIME_KUBERNETES_GPU_AVAILABILITY,
217
766
  request_body=realtime_gpu_availability_body,
218
767
  func=core.realtime_kubernetes_gpu_availability,
219
768
  schedule_type=requests_lib.ScheduleType.SHORT,
@@ -226,9 +775,9 @@ async def kubernetes_node_info(
226
775
  kubernetes_node_info_body: payloads.KubernetesNodeInfoRequestBody
227
776
  ) -> None:
228
777
  """Gets Kubernetes nodes information and hints."""
229
- executor.schedule_request(
778
+ await executor.schedule_request_async(
230
779
  request_id=request.state.request_id,
231
- request_name='kubernetes_node_info',
780
+ request_name=request_names.RequestName.KUBERNETES_NODE_INFO,
232
781
  request_body=kubernetes_node_info_body,
233
782
  func=kubernetes_utils.get_kubernetes_node_info,
234
783
  schedule_type=requests_lib.ScheduleType.SHORT,
@@ -237,10 +786,11 @@ async def kubernetes_node_info(
237
786
 
238
787
  @app.get('/status_kubernetes')
239
788
  async def status_kubernetes(request: fastapi.Request) -> None:
240
- """Gets Kubernetes status."""
241
- executor.schedule_request(
789
+ """[Experimental] Get all SkyPilot resources (including from other
790
+ users) in the current Kubernetes context."""
791
+ await executor.schedule_request_async(
242
792
  request_id=request.state.request_id,
243
- request_name='status_kubernetes',
793
+ request_name=request_names.RequestName.STATUS_KUBERNETES,
244
794
  request_body=payloads.RequestBody(),
245
795
  func=core.status_kubernetes,
246
796
  schedule_type=requests_lib.ScheduleType.SHORT,
@@ -252,11 +802,11 @@ async def list_accelerators(
252
802
  request: fastapi.Request,
253
803
  list_accelerator_counts_body: payloads.ListAcceleratorsBody) -> None:
254
804
  """Gets list of accelerators from cloud catalog."""
255
- executor.schedule_request(
805
+ await executor.schedule_request_async(
256
806
  request_id=request.state.request_id,
257
- request_name='list_accelerators',
807
+ request_name=request_names.RequestName.LIST_ACCELERATORS,
258
808
  request_body=list_accelerator_counts_body,
259
- func=service_catalog.list_accelerators,
809
+ func=catalog.list_accelerators,
260
810
  schedule_type=requests_lib.ScheduleType.SHORT,
261
811
  )
262
812
 
@@ -267,11 +817,11 @@ async def list_accelerator_counts(
267
817
  list_accelerator_counts_body: payloads.ListAcceleratorCountsBody
268
818
  ) -> None:
269
819
  """Gets list of accelerator counts from cloud catalog."""
270
- executor.schedule_request(
820
+ await executor.schedule_request_async(
271
821
  request_id=request.state.request_id,
272
- request_name='list_accelerator_counts',
822
+ request_name=request_names.RequestName.LIST_ACCELERATOR_COUNTS,
273
823
  request_body=list_accelerator_counts_body,
274
- func=service_catalog.list_accelerator_counts,
824
+ func=catalog.list_accelerator_counts,
275
825
  schedule_type=requests_lib.ScheduleType.SHORT,
276
826
  )
277
827
 
@@ -289,26 +839,39 @@ async def validate(validate_body: payloads.ValidateBody) -> None:
289
839
  # pairs.
290
840
  logger.debug(f'Validating tasks: {validate_body.dag}')
291
841
 
842
+ context.initialize()
843
+ ctx = context.get()
844
+ assert ctx is not None
845
+ # TODO(aylei): generalize this to all requests without a db record.
846
+ ctx.override_envs(validate_body.env_vars)
847
+
292
848
  def validate_dag(dag: dag_utils.dag_lib.Dag):
293
849
  # TODO: Admin policy may contain arbitrary code, which may be expensive
294
850
  # to run and may block the server thread. However, moving it into the
295
851
  # executor adds a ~150ms penalty on the local API server because of
296
852
  # added RTTs. For now, we stick to doing the validation inline in the
297
853
  # server thread.
298
- dag, _ = admin_policy_utils.apply(
299
- dag, request_options=validate_body.request_options)
300
- # Skip validating workdir and file_mounts, as those need to be
301
- # validated after the files are uploaded to the SkyPilot API server
302
- # with `upload_mounts_to_api_server`.
303
- dag.validate(skip_file_mounts=True, skip_workdir=True)
854
+ with admin_policy_utils.apply_and_use_config_in_current_request(
855
+ dag,
856
+ request_name=request_names.AdminPolicyRequestName.VALIDATE,
857
+ request_options=validate_body.get_request_options()) as dag:
858
+ dag.resolve_and_validate_volumes()
859
+ # Skip validating workdir and file_mounts, as those need to be
860
+ # validated after the files are uploaded to the SkyPilot API server
861
+ # with `upload_mounts_to_api_server`.
862
+ dag.validate(skip_file_mounts=True, skip_workdir=True)
304
863
 
305
864
  try:
306
865
  dag = dag_utils.load_chain_dag_from_yaml_str(validate_body.dag)
307
- loop = asyncio.get_running_loop()
308
866
  # Apply admin policy and validate DAG is blocking, run it in a separate
309
867
  # thread executor to avoid blocking the uvicorn event loop.
310
- await loop.run_in_executor(None, validate_dag, dag)
868
+ await context_utils.to_thread(validate_dag, dag)
311
869
  except Exception as e: # pylint: disable=broad-except
870
+ # Print the exception to the API server log.
871
+ if env_options.Options.SHOW_DEBUG_INFO.get():
872
+ logger.info('/validate exception:', exc_info=True)
873
+ # Set the exception stacktrace for the serialized exception.
874
+ requests_lib.set_exception_stacktrace(e)
312
875
  raise fastapi.HTTPException(
313
876
  status_code=400, detail=exceptions.serialize_exception(e)) from e
314
877
 
@@ -317,9 +880,9 @@ async def validate(validate_body: payloads.ValidateBody) -> None:
317
880
  async def optimize(optimize_body: payloads.OptimizeBody,
318
881
  request: fastapi.Request) -> None:
319
882
  """Optimizes the user's DAG."""
320
- executor.schedule_request(
883
+ await executor.schedule_request_async(
321
884
  request_id=request.state.request_id,
322
- request_name='optimize',
885
+ request_name=request_names.RequestName.OPTIMIZE,
323
886
  request_body=optimize_body,
324
887
  ignore_return_value=True,
325
888
  func=core.optimize,
@@ -347,16 +910,30 @@ async def upload_zip_file(request: fastapi.Request, user_hash: str,
347
910
  chunk_index: The chunk index, starting from 0.
348
911
  total_chunks: The total number of chunks.
349
912
  """
913
+ # Field _body would be set if the request body has been received, fail fast
914
+ # to surface potential memory issues, i.e. catch the issue in our smoke
915
+ # test.
916
+ # pylint: disable=protected-access
917
+ if hasattr(request, '_body'):
918
+ raise fastapi.HTTPException(
919
+ status_code=500,
920
+ detail='Upload request body should not be received before streaming'
921
+ )
350
922
  # Add the upload id to the cleanup list.
351
923
  upload_ids_to_cleanup[(upload_id,
352
924
  user_hash)] = (datetime.datetime.now() +
353
925
  _DEFAULT_UPLOAD_EXPIRATION_TIME)
926
+ # For anonymous access, use the user hash from client
927
+ user_id = user_hash
928
+ if request.state.auth_user is not None:
929
+ # Otherwise, the authenticated identity should be used.
930
+ user_id = request.state.auth_user.id
354
931
 
355
932
  # TODO(SKY-1271): We need to double check security of uploading zip file.
356
933
  client_file_mounts_dir = (
357
- common.API_SERVER_CLIENT_DIR.expanduser().resolve() / user_hash /
934
+ common.API_SERVER_CLIENT_DIR.expanduser().resolve() / user_id /
358
935
  'file_mounts')
359
- client_file_mounts_dir.mkdir(parents=True, exist_ok=True)
936
+ await anyio.Path(client_file_mounts_dir).mkdir(parents=True, exist_ok=True)
360
937
 
361
938
  # Check upload_id to be a valid SkyPilot run_timestamp appended with 8 hex
362
939
  # characters, e.g. 'sky-2025-01-17-09-10-13-933602-35d31c22'.
@@ -379,7 +956,7 @@ async def upload_zip_file(request: fastapi.Request, user_hash: str,
379
956
  zip_file_path = client_file_mounts_dir / f'{upload_id}.zip'
380
957
  else:
381
958
  chunk_dir = client_file_mounts_dir / upload_id
382
- chunk_dir.mkdir(parents=True, exist_ok=True)
959
+ await anyio.Path(chunk_dir).mkdir(parents=True, exist_ok=True)
383
960
  zip_file_path = chunk_dir / f'part{chunk_index}.incomplete'
384
961
 
385
962
  try:
@@ -409,8 +986,9 @@ async def upload_zip_file(request: fastapi.Request, user_hash: str,
409
986
  zip_file_path.rename(zip_file_path.with_suffix(''))
410
987
  missing_chunks = get_missing_chunks(total_chunks)
411
988
  if missing_chunks:
412
- return payloads.UploadZipFileResponse(status='uploading',
413
- missing_chunks=missing_chunks)
989
+ return payloads.UploadZipFileResponse(
990
+ status=responses.UploadStatus.UPLOADING.value,
991
+ missing_chunks=missing_chunks)
414
992
  zip_file_path = client_file_mounts_dir / f'{upload_id}.zip'
415
993
  async with aiofiles.open(zip_file_path, 'wb') as zip_file:
416
994
  for chunk in range(total_chunks):
@@ -424,10 +1002,11 @@ async def upload_zip_file(request: fastapi.Request, user_hash: str,
424
1002
  await zip_file.write(data)
425
1003
 
426
1004
  logger.info(f'Uploaded zip file: {zip_file_path}')
427
- unzip_file(zip_file_path, client_file_mounts_dir)
1005
+ await unzip_file(zip_file_path, client_file_mounts_dir)
428
1006
  if total_chunks > 1:
429
- shutil.rmtree(chunk_dir)
430
- return payloads.UploadZipFileResponse(status='completed')
1007
+ await context_utils.to_thread(shutil.rmtree, chunk_dir)
1008
+ return payloads.UploadZipFileResponse(
1009
+ status=responses.UploadStatus.COMPLETED.value)
431
1010
 
432
1011
 
433
1012
  def _is_relative_to(path: pathlib.Path, parent: pathlib.Path) -> bool:
@@ -440,61 +1019,69 @@ def _is_relative_to(path: pathlib.Path, parent: pathlib.Path) -> bool:
440
1019
  return False
441
1020
 
442
1021
 
443
- def unzip_file(zip_file_path: pathlib.Path,
444
- client_file_mounts_dir: pathlib.Path) -> None:
445
- """Unzips a zip file."""
446
- try:
447
- with zipfile.ZipFile(zip_file_path, 'r') as zipf:
448
- for member in zipf.infolist():
449
- # Determine the new path
450
- original_path = os.path.normpath(member.filename)
451
- new_path = client_file_mounts_dir / original_path.lstrip('/')
452
-
453
- if (member.external_attr >> 28) == 0xA:
454
- # Symlink. Read the target path and create a symlink.
1022
+ async def unzip_file(zip_file_path: pathlib.Path,
1023
+ client_file_mounts_dir: pathlib.Path) -> None:
1024
+ """Unzips a zip file without blocking the event loop."""
1025
+
1026
+ def _do_unzip() -> None:
1027
+ try:
1028
+ with zipfile.ZipFile(zip_file_path, 'r') as zipf:
1029
+ for member in zipf.infolist():
1030
+ # Determine the new path
1031
+ original_path = os.path.normpath(member.filename)
1032
+ new_path = client_file_mounts_dir / original_path.lstrip(
1033
+ '/')
1034
+
1035
+ if (member.external_attr >> 28) == 0xA:
1036
+ # Symlink. Read the target path and create a symlink.
1037
+ new_path.parent.mkdir(parents=True, exist_ok=True)
1038
+ target = zipf.read(member).decode()
1039
+ assert not os.path.isabs(target), target
1040
+ # Since target is a relative path, we need to check that
1041
+ # it is under `client_file_mounts_dir` for security.
1042
+ full_target_path = (new_path.parent / target).resolve()
1043
+ if not _is_relative_to(full_target_path,
1044
+ client_file_mounts_dir):
1045
+ raise ValueError(
1046
+ f'Symlink target {target} leads to a '
1047
+ 'file not in userspace. Aborted.')
1048
+
1049
+ if new_path.exists() or new_path.is_symlink():
1050
+ new_path.unlink(missing_ok=True)
1051
+ new_path.symlink_to(
1052
+ target,
1053
+ target_is_directory=member.filename.endswith('/'))
1054
+ continue
1055
+
1056
+ # Handle directories
1057
+ if member.filename.endswith('/'):
1058
+ new_path.mkdir(parents=True, exist_ok=True)
1059
+ continue
1060
+
1061
+ # Handle files
455
1062
  new_path.parent.mkdir(parents=True, exist_ok=True)
456
- target = zipf.read(member).decode()
457
- assert not os.path.isabs(target), target
458
- # Since target is a relative path, we need to check that it
459
- # is under `client_file_mounts_dir` for security.
460
- full_target_path = (new_path.parent / target).resolve()
461
- if not _is_relative_to(full_target_path,
462
- client_file_mounts_dir):
463
- raise ValueError(f'Symlink target {target} leads to a '
464
- 'file not in userspace. Aborted.')
465
-
466
- if new_path.exists() or new_path.is_symlink():
467
- new_path.unlink(missing_ok=True)
468
- new_path.symlink_to(
469
- target,
470
- target_is_directory=member.filename.endswith('/'))
471
- continue
472
-
473
- # Handle directories
474
- if member.filename.endswith('/'):
475
- new_path.mkdir(parents=True, exist_ok=True)
476
- continue
477
-
478
- # Handle files
479
- new_path.parent.mkdir(parents=True, exist_ok=True)
480
- with zipf.open(member) as member_file, new_path.open('wb') as f:
481
- # Use shutil.copyfileobj to copy files in chunks, so it does
482
- # not load the entire file into memory.
483
- shutil.copyfileobj(member_file, f)
484
- except zipfile.BadZipFile as e:
485
- logger.error(f'Bad zip file: {zip_file_path}')
486
- raise fastapi.HTTPException(
487
- status_code=400,
488
- detail=f'Invalid zip file: {common_utils.format_exception(e)}')
489
- except Exception as e:
490
- logger.error(f'Error unzipping file: {zip_file_path}')
491
- raise fastapi.HTTPException(
492
- status_code=500,
493
- detail=(f'Error unzipping file: '
494
- f'{common_utils.format_exception(e)}'))
1063
+ with zipf.open(member) as member_file, new_path.open(
1064
+ 'wb') as f:
1065
+ # Use shutil.copyfileobj to copy files in chunks,
1066
+ # so it does not load the entire file into memory.
1067
+ shutil.copyfileobj(member_file, f)
1068
+ except zipfile.BadZipFile as e:
1069
+ logger.error(f'Bad zip file: {zip_file_path}')
1070
+ raise fastapi.HTTPException(
1071
+ status_code=400,
1072
+ detail=f'Invalid zip file: {common_utils.format_exception(e)}')
1073
+ except Exception as e:
1074
+ logger.error(f'Error unzipping file: {zip_file_path}')
1075
+ raise fastapi.HTTPException(
1076
+ status_code=500,
1077
+ detail=(f'Error unzipping file: '
1078
+ f'{common_utils.format_exception(e)}'))
1079
+ finally:
1080
+ # Cleanup the temporary file regardless of
1081
+ # success/failure handling above
1082
+ zip_file_path.unlink(missing_ok=True)
495
1083
 
496
- # Cleanup the temporary file
497
- zip_file_path.unlink()
1084
+ await context_utils.to_thread(_do_unzip)
498
1085
 
499
1086
 
500
1087
  @app.post('/launch')
@@ -503,13 +1090,14 @@ async def launch(launch_body: payloads.LaunchBody,
503
1090
  """Launches a cluster or task."""
504
1091
  request_id = request.state.request_id
505
1092
  logger.info(f'Launching request: {request_id}')
506
- executor.schedule_request(
1093
+ await executor.schedule_request_async(
507
1094
  request_id,
508
- request_name='launch',
1095
+ request_name=request_names.RequestName.CLUSTER_LAUNCH,
509
1096
  request_body=launch_body,
510
1097
  func=execution.launch,
511
1098
  schedule_type=requests_lib.ScheduleType.LONG,
512
1099
  request_cluster_name=launch_body.cluster_name,
1100
+ retryable=launch_body.retry_until_up,
513
1101
  )
514
1102
 
515
1103
 
@@ -518,9 +1106,9 @@ async def launch(launch_body: payloads.LaunchBody,
518
1106
  async def exec(request: fastapi.Request, exec_body: payloads.ExecBody) -> None:
519
1107
  """Executes a task on an existing cluster."""
520
1108
  cluster_name = exec_body.cluster_name
521
- executor.schedule_request(
1109
+ await executor.schedule_request_async(
522
1110
  request_id=request.state.request_id,
523
- request_name='exec',
1111
+ request_name=request_names.RequestName.CLUSTER_EXEC,
524
1112
  request_body=exec_body,
525
1113
  func=execution.exec,
526
1114
  precondition=preconditions.ClusterStartCompletePrecondition(
@@ -536,9 +1124,9 @@ async def exec(request: fastapi.Request, exec_body: payloads.ExecBody) -> None:
536
1124
  async def stop(request: fastapi.Request,
537
1125
  stop_body: payloads.StopOrDownBody) -> None:
538
1126
  """Stops a cluster."""
539
- executor.schedule_request(
1127
+ await executor.schedule_request_async(
540
1128
  request_id=request.state.request_id,
541
- request_name='stop',
1129
+ request_name=request_names.RequestName.CLUSTER_STOP,
542
1130
  request_body=stop_body,
543
1131
  func=core.stop,
544
1132
  schedule_type=requests_lib.ScheduleType.SHORT,
@@ -552,9 +1140,13 @@ async def status(
552
1140
  status_body: payloads.StatusBody = payloads.StatusBody()
553
1141
  ) -> None:
554
1142
  """Gets cluster statuses."""
555
- executor.schedule_request(
1143
+ if state.get_block_requests():
1144
+ raise fastapi.HTTPException(
1145
+ status_code=503,
1146
+ detail='Server is shutting down, please try again later.')
1147
+ await executor.schedule_request_async(
556
1148
  request_id=request.state.request_id,
557
- request_name='status',
1149
+ request_name=request_names.RequestName.CLUSTER_STATUS,
558
1150
  request_body=status_body,
559
1151
  func=core.status,
560
1152
  schedule_type=(requests_lib.ScheduleType.LONG if
@@ -567,9 +1159,9 @@ async def status(
567
1159
  async def endpoints(request: fastapi.Request,
568
1160
  endpoint_body: payloads.EndpointsBody) -> None:
569
1161
  """Gets the endpoint for a given cluster and port number (endpoint)."""
570
- executor.schedule_request(
1162
+ await executor.schedule_request_async(
571
1163
  request_id=request.state.request_id,
572
- request_name='endpoints',
1164
+ request_name=request_names.RequestName.CLUSTER_ENDPOINTS,
573
1165
  request_body=endpoint_body,
574
1166
  func=core.endpoints,
575
1167
  schedule_type=requests_lib.ScheduleType.SHORT,
@@ -581,9 +1173,9 @@ async def endpoints(request: fastapi.Request,
581
1173
  async def down(request: fastapi.Request,
582
1174
  down_body: payloads.StopOrDownBody) -> None:
583
1175
  """Tears down a cluster."""
584
- executor.schedule_request(
1176
+ await executor.schedule_request_async(
585
1177
  request_id=request.state.request_id,
586
- request_name='down',
1178
+ request_name=request_names.RequestName.CLUSTER_DOWN,
587
1179
  request_body=down_body,
588
1180
  func=core.down,
589
1181
  schedule_type=requests_lib.ScheduleType.SHORT,
@@ -595,9 +1187,9 @@ async def down(request: fastapi.Request,
595
1187
  async def start(request: fastapi.Request,
596
1188
  start_body: payloads.StartBody) -> None:
597
1189
  """Restarts a cluster."""
598
- executor.schedule_request(
1190
+ await executor.schedule_request_async(
599
1191
  request_id=request.state.request_id,
600
- request_name='start',
1192
+ request_name=request_names.RequestName.CLUSTER_START,
601
1193
  request_body=start_body,
602
1194
  func=core.start,
603
1195
  schedule_type=requests_lib.ScheduleType.LONG,
@@ -609,9 +1201,9 @@ async def start(request: fastapi.Request,
609
1201
  async def autostop(request: fastapi.Request,
610
1202
  autostop_body: payloads.AutostopBody) -> None:
611
1203
  """Schedules an autostop/autodown for a cluster."""
612
- executor.schedule_request(
1204
+ await executor.schedule_request_async(
613
1205
  request_id=request.state.request_id,
614
- request_name='autostop',
1206
+ request_name=request_names.RequestName.CLUSTER_AUTOSTOP,
615
1207
  request_body=autostop_body,
616
1208
  func=core.autostop,
617
1209
  schedule_type=requests_lib.ScheduleType.SHORT,
@@ -623,9 +1215,9 @@ async def autostop(request: fastapi.Request,
623
1215
  async def queue(request: fastapi.Request,
624
1216
  queue_body: payloads.QueueBody) -> None:
625
1217
  """Gets the job queue of a cluster."""
626
- executor.schedule_request(
1218
+ await executor.schedule_request_async(
627
1219
  request_id=request.state.request_id,
628
- request_name='queue',
1220
+ request_name=request_names.RequestName.CLUSTER_QUEUE,
629
1221
  request_body=queue_body,
630
1222
  func=core.queue,
631
1223
  schedule_type=requests_lib.ScheduleType.SHORT,
@@ -637,9 +1229,9 @@ async def queue(request: fastapi.Request,
637
1229
  async def job_status(request: fastapi.Request,
638
1230
  job_status_body: payloads.JobStatusBody) -> None:
639
1231
  """Gets the status of a job."""
640
- executor.schedule_request(
1232
+ await executor.schedule_request_async(
641
1233
  request_id=request.state.request_id,
642
- request_name='job_status',
1234
+ request_name=request_names.RequestName.CLUSTER_JOB_STATUS,
643
1235
  request_body=job_status_body,
644
1236
  func=core.job_status,
645
1237
  schedule_type=requests_lib.ScheduleType.SHORT,
@@ -651,9 +1243,9 @@ async def job_status(request: fastapi.Request,
651
1243
  async def cancel(request: fastapi.Request,
652
1244
  cancel_body: payloads.CancelBody) -> None:
653
1245
  """Cancels jobs on a cluster."""
654
- executor.schedule_request(
1246
+ await executor.schedule_request_async(
655
1247
  request_id=request.state.request_id,
656
- request_name='cancel',
1248
+ request_name=request_names.RequestName.CLUSTER_JOB_CANCEL,
657
1249
  request_body=cancel_body,
658
1250
  func=core.cancel,
659
1251
  schedule_type=requests_lib.ScheduleType.SHORT,
@@ -670,36 +1262,27 @@ async def logs(
670
1262
  # TODO(zhwu): This should wait for the request on the cluster, e.g., async
671
1263
  # launch, to finish, so that a user does not need to manually pull the
672
1264
  # request status.
673
- executor.schedule_request(
1265
+ executor.check_request_thread_executor_available()
1266
+ request_task = await executor.prepare_request_async(
674
1267
  request_id=request.state.request_id,
675
- request_name='logs',
1268
+ request_name=request_names.RequestName.CLUSTER_JOB_LOGS,
676
1269
  request_body=cluster_job_body,
677
1270
  func=core.tail_logs,
678
- # TODO(aylei): We have tail logs scheduled as SHORT request, because it
679
- # should be responsive. However, it can be long running if the user's
680
- # job keeps running, and we should avoid it taking the SHORT worker.
681
1271
  schedule_type=requests_lib.ScheduleType.SHORT,
682
1272
  request_cluster_name=cluster_job_body.cluster_name,
683
1273
  )
684
-
685
- request_task = requests_lib.get_request(request.state.request_id)
686
-
1274
+ task = executor.execute_request_in_coroutine(request_task)
1275
+ background_tasks.add_task(task.cancel)
687
1276
  # TODO(zhwu): This makes viewing logs in browser impossible. We should adopt
688
1277
  # the same approach as /stream.
689
- return stream_utils.stream_response(
690
- request_id=request_task.request_id,
1278
+ return stream_utils.stream_response_for_long_request(
1279
+ request_id=request.state.request_id,
691
1280
  logs_path=request_task.log_path,
692
1281
  background_tasks=background_tasks,
1282
+ kill_request_on_disconnect=False,
693
1283
  )
694
1284
 
695
1285
 
696
- @app.get('/users')
697
- async def users() -> List[Dict[str, Any]]:
698
- """Gets all users."""
699
- user_list = global_user_state.get_all_users()
700
- return [user.to_dict() for user in user_list]
701
-
702
-
703
1286
  @app.post('/download_logs')
704
1287
  async def download_logs(
705
1288
  request: fastapi.Request,
@@ -711,9 +1294,9 @@ async def download_logs(
711
1294
  # We should reuse the original request body, so that the env vars, such as
712
1295
  # user hash, are kept the same.
713
1296
  cluster_jobs_body.local_dir = str(logs_dir_on_api_server)
714
- executor.schedule_request(
1297
+ await executor.schedule_request_async(
715
1298
  request_id=request.state.request_id,
716
- request_name='download_logs',
1299
+ request_name=request_names.RequestName.CLUSTER_JOB_DOWNLOAD_LOGS,
717
1300
  request_body=cluster_jobs_body,
718
1301
  func=core.download_logs,
719
1302
  schedule_type=requests_lib.ScheduleType.SHORT,
@@ -722,7 +1305,8 @@ async def download_logs(
722
1305
 
723
1306
 
724
1307
  @app.post('/download')
725
- async def download(download_body: payloads.DownloadBody) -> None:
1308
+ async def download(download_body: payloads.DownloadBody,
1309
+ request: fastapi.Request) -> None:
726
1310
  """Downloads a folder from the cluster to the local machine."""
727
1311
  folder_paths = [
728
1312
  pathlib.Path(folder_path) for folder_path in download_body.folder_paths
@@ -747,11 +1331,25 @@ async def download(download_body: payloads.DownloadBody) -> None:
747
1331
  logs_dir_on_api_server).expanduser().resolve() / zip_filename
748
1332
 
749
1333
  try:
750
- folders = [
751
- str(folder_path.expanduser().resolve())
752
- for folder_path in folder_paths
753
- ]
754
- storage_utils.zip_files_and_folders(folders, zip_path)
1334
+
1335
+ def _zip_files_and_folders(folder_paths, zip_path):
1336
+ folders = [
1337
+ str(folder_path.expanduser().resolve())
1338
+ for folder_path in folder_paths
1339
+ ]
1340
+ # Check for optional query parameter to control zip entry structure
1341
+ relative = request.query_params.get('relative', 'home')
1342
+ if relative == 'items':
1343
+ # Dashboard-friendly: entries relative to selected folders
1344
+ storage_utils.zip_files_and_folders(folders,
1345
+ zip_path,
1346
+ relative_to_items=True)
1347
+ else:
1348
+ # CLI-friendly (default): entries with full paths for mapping
1349
+ storage_utils.zip_files_and_folders(folders, zip_path)
1350
+
1351
+ await context_utils.to_thread(_zip_files_and_folders, folder_paths,
1352
+ zip_path)
755
1353
 
756
1354
  # Add home path to the response headers, so that the client can replace
757
1355
  # the remote path in the zip file to the local path.
@@ -773,13 +1371,84 @@ async def download(download_body: payloads.DownloadBody) -> None:
773
1371
  detail=f'Error creating zip file: {str(e)}')
774
1372
 
775
1373
 
776
- @app.get('/cost_report')
777
- async def cost_report(request: fastapi.Request) -> None:
1374
+ # TODO(aylei): run it asynchronously after global_user_state support async op
1375
+ @app.post('/provision_logs')
1376
+ def provision_logs(provision_logs_body: payloads.ProvisionLogsBody,
1377
+ follow: bool = True,
1378
+ tail: int = 0) -> fastapi.responses.StreamingResponse:
1379
+ """Streams the provision.log for the latest launch request of a cluster."""
1380
+ log_path = None
1381
+ cluster_name = provision_logs_body.cluster_name
1382
+ worker = provision_logs_body.worker
1383
+ # stream head node logs
1384
+ if worker is None:
1385
+ # Prefer clusters table first, then cluster_history as fallback.
1386
+ log_path_str = global_user_state.get_cluster_provision_log_path(
1387
+ cluster_name)
1388
+ if not log_path_str:
1389
+ log_path_str = (
1390
+ global_user_state.get_cluster_history_provision_log_path(
1391
+ cluster_name))
1392
+ if not log_path_str:
1393
+ raise fastapi.HTTPException(
1394
+ status_code=404,
1395
+ detail=('Provision log path is not recorded for this cluster. '
1396
+ 'Please relaunch to generate provisioning logs.'))
1397
+ log_path = pathlib.Path(log_path_str).expanduser().resolve()
1398
+ if not log_path.exists():
1399
+ raise fastapi.HTTPException(
1400
+ status_code=404,
1401
+ detail=f'Provision log path does not exist: {str(log_path)}')
1402
+
1403
+ # stream worker node logs
1404
+ else:
1405
+ handle = global_user_state.get_handle_from_cluster_name(cluster_name)
1406
+ if handle is None:
1407
+ raise fastapi.HTTPException(
1408
+ status_code=404,
1409
+ detail=('Cluster handle is not recorded for this cluster. '
1410
+ 'Please relaunch to generate provisioning logs.'))
1411
+ # instance_ids includes head node
1412
+ instance_ids = handle.instance_ids
1413
+ if instance_ids is None:
1414
+ raise fastapi.HTTPException(
1415
+ status_code=400,
1416
+ detail='Instance IDs are not recorded for this cluster. '
1417
+ 'Please relaunch to generate provisioning logs.')
1418
+ if worker > len(instance_ids) - 1:
1419
+ raise fastapi.HTTPException(
1420
+ status_code=400,
1421
+ detail=f'Worker {worker} is out of range. '
1422
+ f'The cluster has {len(instance_ids)} nodes.')
1423
+ log_path = metadata_utils.get_instance_log_dir(
1424
+ handle.get_cluster_name_on_cloud(), instance_ids[worker])
1425
+
1426
+ # Tail semantics: 0 means print all lines. Convert 0 -> None for streamer.
1427
+ effective_tail = None if tail is None or tail <= 0 else tail
1428
+
1429
+ return fastapi.responses.StreamingResponse(
1430
+ content=stream_utils.log_streamer(None,
1431
+ log_path,
1432
+ tail=effective_tail,
1433
+ follow=follow,
1434
+ cluster_name=cluster_name),
1435
+ media_type='text/plain',
1436
+ headers={
1437
+ 'Cache-Control': 'no-cache, no-transform',
1438
+ 'X-Accel-Buffering': 'no',
1439
+ 'Transfer-Encoding': 'chunked',
1440
+ },
1441
+ )
1442
+
1443
+
1444
+ @app.post('/cost_report')
1445
+ async def cost_report(request: fastapi.Request,
1446
+ cost_report_body: payloads.CostReportBody) -> None:
778
1447
  """Gets the cost report of a cluster."""
779
- executor.schedule_request(
1448
+ await executor.schedule_request_async(
780
1449
  request_id=request.state.request_id,
781
- request_name='cost_report',
782
- request_body=payloads.RequestBody(),
1450
+ request_name=request_names.RequestName.CLUSTER_COST_REPORT,
1451
+ request_body=cost_report_body,
783
1452
  func=core.cost_report,
784
1453
  schedule_type=requests_lib.ScheduleType.SHORT,
785
1454
  )
@@ -788,9 +1457,9 @@ async def cost_report(request: fastapi.Request) -> None:
788
1457
  @app.get('/storage/ls')
789
1458
  async def storage_ls(request: fastapi.Request) -> None:
790
1459
  """Gets the storages."""
791
- executor.schedule_request(
1460
+ await executor.schedule_request_async(
792
1461
  request_id=request.state.request_id,
793
- request_name='storage_ls',
1462
+ request_name=request_names.RequestName.STORAGE_LS,
794
1463
  request_body=payloads.RequestBody(),
795
1464
  func=core.storage_ls,
796
1465
  schedule_type=requests_lib.ScheduleType.SHORT,
@@ -801,9 +1470,9 @@ async def storage_ls(request: fastapi.Request) -> None:
801
1470
  async def storage_delete(request: fastapi.Request,
802
1471
  storage_body: payloads.StorageBody) -> None:
803
1472
  """Deletes a storage."""
804
- executor.schedule_request(
1473
+ await executor.schedule_request_async(
805
1474
  request_id=request.state.request_id,
806
- request_name='storage_delete',
1475
+ request_name=request_names.RequestName.STORAGE_DELETE,
807
1476
  request_body=storage_body,
808
1477
  func=core.storage_delete,
809
1478
  schedule_type=requests_lib.ScheduleType.LONG,
@@ -814,9 +1483,9 @@ async def storage_delete(request: fastapi.Request,
814
1483
  async def local_up(request: fastapi.Request,
815
1484
  local_up_body: payloads.LocalUpBody) -> None:
816
1485
  """Launches a Kubernetes cluster on API server."""
817
- executor.schedule_request(
1486
+ await executor.schedule_request_async(
818
1487
  request_id=request.state.request_id,
819
- request_name='local_up',
1488
+ request_name=request_names.RequestName.LOCAL_UP,
820
1489
  request_body=local_up_body,
821
1490
  func=core.local_up,
822
1491
  schedule_type=requests_lib.ScheduleType.LONG,
@@ -824,37 +1493,65 @@ async def local_up(request: fastapi.Request,
824
1493
 
825
1494
 
826
1495
  @app.post('/local_down')
827
- async def local_down(request: fastapi.Request) -> None:
1496
+ async def local_down(request: fastapi.Request,
1497
+ local_down_body: payloads.LocalDownBody) -> None:
828
1498
  """Tears down the Kubernetes cluster started by local_up."""
829
- executor.schedule_request(
1499
+ await executor.schedule_request_async(
830
1500
  request_id=request.state.request_id,
831
- request_name='local_down',
832
- request_body=payloads.RequestBody(),
1501
+ request_name=request_names.RequestName.LOCAL_DOWN,
1502
+ request_body=local_down_body,
833
1503
  func=core.local_down,
834
1504
  schedule_type=requests_lib.ScheduleType.LONG,
835
1505
  )
836
1506
 
837
1507
 
1508
+ async def get_expanded_request_id(request_id: str) -> str:
1509
+ """Gets the expanded request ID for a given request ID prefix."""
1510
+ request_tasks = await requests_lib.get_requests_async_with_prefix(
1511
+ request_id, fields=['request_id'])
1512
+ if request_tasks is None:
1513
+ raise fastapi.HTTPException(status_code=404,
1514
+ detail=f'Request {request_id!r} not found')
1515
+ if len(request_tasks) > 1:
1516
+ raise fastapi.HTTPException(status_code=400,
1517
+ detail=('Multiple requests found for '
1518
+ f'request ID prefix: {request_id}'))
1519
+ return request_tasks[0].request_id
1520
+
1521
+
838
1522
  # === API server related APIs ===
839
- @app.get('/api/get')
840
- async def api_get(request_id: str) -> requests_lib.RequestPayload:
1523
+ @app.get('/api/get', response_class=fastapi_responses.ORJSONResponse)
1524
+ async def api_get(request_id: str) -> payloads.RequestPayload:
841
1525
  """Gets a request with a given request ID prefix."""
1526
+ # Validate request_id prefix matches a single request.
1527
+ request_id = await get_expanded_request_id(request_id)
1528
+
842
1529
  while True:
843
- request_task = requests_lib.get_request(request_id)
844
- if request_task is None:
1530
+ req_status = await requests_lib.get_request_status_async(request_id)
1531
+ if req_status is None:
845
1532
  print(f'No task with request ID {request_id}', flush=True)
846
1533
  raise fastapi.HTTPException(
847
1534
  status_code=404, detail=f'Request {request_id!r} not found')
848
- if request_task.status > requests_lib.RequestStatus.RUNNING:
849
- request_error = request_task.get_error()
850
- if request_error is not None:
851
- raise fastapi.HTTPException(status_code=500,
852
- detail=dataclasses.asdict(
853
- request_task.encode()))
854
- return request_task.encode()
1535
+ if (req_status.status == requests_lib.RequestStatus.RUNNING and
1536
+ daemons.is_daemon_request_id(request_id)):
1537
+ # Daemon requests run forever, break without waiting for complete.
1538
+ break
1539
+ if req_status.status > requests_lib.RequestStatus.RUNNING:
1540
+ break
855
1541
  # yield control to allow other coroutines to run, sleep shortly
856
1542
  # to avoid storming the DB and CPU in the meantime
857
1543
  await asyncio.sleep(0.1)
1544
+ request_task = await requests_lib.get_request_async(request_id)
1545
+ # TODO(aylei): refine this, /api/get will not be retried and this is
1546
+ # meaningless to retry. It is the original request that should be retried.
1547
+ if request_task.should_retry:
1548
+ raise fastapi.HTTPException(
1549
+ status_code=503, detail=f'Request {request_id!r} should be retried')
1550
+ request_error = request_task.get_error()
1551
+ if request_error is not None:
1552
+ raise fastapi.HTTPException(status_code=500,
1553
+ detail=request_task.encode().model_dump())
1554
+ return request_task.encode()
858
1555
 
859
1556
 
860
1557
  @app.get('/api/stream')
@@ -888,13 +1585,18 @@ async def stream(
888
1585
  clients, console for CLI/API clients), 'plain' (force plain text),
889
1586
  'html' (force HTML), or 'console' (force console)
890
1587
  """
1588
+ # We need to save the user-supplied request ID for the response header.
1589
+ user_supplied_request_id = request_id
891
1590
  if request_id is not None and log_path is not None:
892
1591
  raise fastapi.HTTPException(
893
1592
  status_code=400,
894
1593
  detail='Only one of request_id and log_path can be provided')
895
1594
 
1595
+ if request_id is not None:
1596
+ request_id = await get_expanded_request_id(request_id)
1597
+
896
1598
  if request_id is None and log_path is None:
897
- request_id = requests_lib.get_latest_request_id()
1599
+ request_id = await requests_lib.get_latest_request_id_async()
898
1600
  if request_id is None:
899
1601
  raise fastapi.HTTPException(status_code=404,
900
1602
  detail='No request found')
@@ -921,19 +1623,40 @@ async def stream(
921
1623
  'X-Accel-Buffering': 'no'
922
1624
  })
923
1625
 
1626
+ polling_interval = stream_utils.DEFAULT_POLL_INTERVAL
924
1627
  # Original plain text streaming logic
925
1628
  if request_id is not None:
926
- request_task = requests_lib.get_request(request_id)
1629
+ request_task = await requests_lib.get_request_async(
1630
+ request_id, fields=['request_id', 'schedule_type'])
927
1631
  if request_task is None:
928
1632
  print(f'No task with request ID {request_id}')
929
1633
  raise fastapi.HTTPException(
930
1634
  status_code=404, detail=f'Request {request_id!r} not found')
1635
+ # req.log_path is derived from request_id,
1636
+ # so it's ok to just grab the request_id in the above query.
931
1637
  log_path_to_stream = request_task.log_path
1638
+ if not log_path_to_stream.exists():
1639
+ # The log file might be deleted by the request GC daemon but the
1640
+ # request task is still in the database.
1641
+ raise fastapi.HTTPException(
1642
+ status_code=404,
1643
+ detail=f'Log of request {request_id!r} has been deleted')
1644
+ if request_task.schedule_type == requests_lib.ScheduleType.LONG:
1645
+ polling_interval = stream_utils.LONG_REQUEST_POLL_INTERVAL
1646
+ del request_task
932
1647
  else:
933
1648
  assert log_path is not None, (request_id, log_path)
934
1649
  if log_path == constants.API_SERVER_LOGS:
935
1650
  resolved_log_path = pathlib.Path(
936
1651
  constants.API_SERVER_LOGS).expanduser()
1652
+ if not resolved_log_path.exists():
1653
+ raise fastapi.HTTPException(
1654
+ status_code=404,
1655
+ detail='Server log file does not exist. The API server may '
1656
+ 'have been started with `--foreground` - check the '
1657
+ 'stdout of API server process, such as: '
1658
+ '`kubectl logs -n api-server-namespace '
1659
+ 'api-server-pod-name`')
937
1660
  else:
938
1661
  # This should be a log path under ~/sky_logs.
939
1662
  resolved_logs_directory = pathlib.Path(
@@ -954,18 +1677,26 @@ async def stream(
954
1677
  detail=f'Log path {log_path!r} does not exist')
955
1678
 
956
1679
  log_path_to_stream = resolved_log_path
1680
+
1681
+ headers = {
1682
+ 'Cache-Control': 'no-cache, no-transform',
1683
+ 'X-Accel-Buffering': 'no',
1684
+ 'Transfer-Encoding': 'chunked'
1685
+ }
1686
+ if request_id is not None:
1687
+ headers[server_constants.STREAM_REQUEST_HEADER] = (
1688
+ user_supplied_request_id
1689
+ if user_supplied_request_id else request_id)
1690
+
957
1691
  return fastapi.responses.StreamingResponse(
958
1692
  content=stream_utils.log_streamer(request_id,
959
1693
  log_path_to_stream,
960
1694
  plain_logs=format == 'plain',
961
1695
  tail=tail,
962
- follow=follow),
1696
+ follow=follow,
1697
+ polling_interval=polling_interval),
963
1698
  media_type='text/plain',
964
- headers={
965
- 'Cache-Control': 'no-cache, no-transform',
966
- 'X-Accel-Buffering': 'no',
967
- 'Transfer-Encoding': 'chunked'
968
- },
1699
+ headers=headers,
969
1700
  )
970
1701
 
971
1702
 
@@ -973,11 +1704,11 @@ async def stream(
973
1704
  async def api_cancel(request: fastapi.Request,
974
1705
  request_cancel_body: payloads.RequestCancelBody) -> None:
975
1706
  """Cancels requests."""
976
- executor.schedule_request(
1707
+ await executor.schedule_request_async(
977
1708
  request_id=request.state.request_id,
978
- request_name='api_cancel',
1709
+ request_name=request_names.RequestName.API_CANCEL,
979
1710
  request_body=request_cancel_body,
980
- func=requests_lib.kill_requests,
1711
+ func=requests_lib.kill_requests_with_prefix,
981
1712
  schedule_type=requests_lib.ScheduleType.SHORT,
982
1713
  )
983
1714
 
@@ -985,10 +1716,14 @@ async def api_cancel(request: fastapi.Request,
985
1716
  @app.get('/api/status')
986
1717
  async def api_status(
987
1718
  request_ids: Optional[List[str]] = fastapi.Query(
988
- None, description='Request IDs to get status for.'),
1719
+ None, description='Request ID prefixes to get status for.'),
989
1720
  all_status: bool = fastapi.Query(
990
1721
  False, description='Get finished requests as well.'),
991
- ) -> List[requests_lib.RequestPayload]:
1722
+ limit: Optional[int] = fastapi.Query(
1723
+ None, description='Number of requests to show.'),
1724
+ fields: Optional[List[str]] = fastapi.Query(
1725
+ None, description='Fields to get. If None, get all fields.'),
1726
+ ) -> List[payloads.RequestPayload]:
992
1727
  """Gets the list of requests."""
993
1728
  if request_ids is None:
994
1729
  statuses = None
@@ -997,53 +1732,120 @@ async def api_status(
997
1732
  requests_lib.RequestStatus.PENDING,
998
1733
  requests_lib.RequestStatus.RUNNING,
999
1734
  ]
1000
- return [
1001
- request_task.readable_encode()
1002
- for request_task in requests_lib.get_request_tasks(status=statuses)
1003
- ]
1735
+ request_tasks = await requests_lib.get_request_tasks_async(
1736
+ req_filter=requests_lib.RequestTaskFilter(
1737
+ status=statuses,
1738
+ limit=limit,
1739
+ fields=fields,
1740
+ sort=True,
1741
+ ))
1742
+ return requests_lib.encode_requests(request_tasks)
1004
1743
  else:
1005
1744
  encoded_request_tasks = []
1006
1745
  for request_id in request_ids:
1007
- request_task = requests_lib.get_request(request_id)
1008
- if request_task is None:
1746
+ request_tasks = await requests_lib.get_requests_async_with_prefix(
1747
+ request_id)
1748
+ if request_tasks is None:
1009
1749
  continue
1010
- encoded_request_tasks.append(request_task.readable_encode())
1750
+ for request_task in request_tasks:
1751
+ encoded_request_tasks.append(request_task.readable_encode())
1011
1752
  return encoded_request_tasks
1012
1753
 
1013
1754
 
1014
- @app.get('/api/health')
1015
- async def health() -> Dict[str, str]:
1755
+ @app.get(
1756
+ '/api/health',
1757
+ # response_model_exclude_unset omits unset fields
1758
+ # in the response JSON.
1759
+ response_model_exclude_unset=True)
1760
+ async def health(request: fastapi.Request) -> responses.APIHealthResponse:
1016
1761
  """Checks the health of the API server.
1017
1762
 
1018
1763
  Returns:
1019
- A dictionary with the following keys:
1020
- - status: str; The status of the API server.
1021
- - api_version: str; The API version of the API server.
1022
- - version: str; The version of SkyPilot used for API server.
1023
- - version_on_disk: str; The version of the SkyPilot installation on
1024
- disk, which can be used to warn about restarting the API server
1025
- - commit: str; The commit hash of SkyPilot used for API server.
1764
+ responses.APIHealthResponse: The health response.
1026
1765
  """
1027
- return {
1028
- 'status': common.ApiServerStatus.HEALTHY.value,
1029
- 'api_version': server_constants.API_VERSION,
1030
- 'version': sky.__version__,
1031
- 'version_on_disk': common.get_skypilot_version_on_disk(),
1032
- 'commit': sky.__commit__,
1033
- }
1766
+ user = request.state.auth_user
1767
+ server_status = common.ApiServerStatus.HEALTHY
1768
+ if getattr(request.state, 'anonymous_user', False):
1769
+ # API server authentication is enabled, but the request is not
1770
+ # authenticated. We still have to serve the request because the
1771
+ # /api/health endpoint has two different usages:
1772
+ # 1. For health check from `api start` and external orchestration
1773
+ # tools (k8s), which does not require authentication and user info.
1774
+ # 2. Return server info to client and hint client to login if required.
1775
+ # Separating these two usages into different APIs will break backward
1776
+ # compatibility for existing orchestration solutions (e.g. helm chart).
1777
+ # So we serve these two usages in a backward compatible manner below.
1778
+ client_version = versions.get_remote_api_version()
1779
+ # - For Client with API version >= 14, we return 200 response with
1780
+ # status=NEEDS_AUTH, new client will handle the login process.
1781
+ # - For health check from `sky api start`, the client code always uses
1782
+ # the same API version as the server, thus there is no compatibility
1783
+ # issue.
1784
+ server_status = common.ApiServerStatus.NEEDS_AUTH
1785
+ if client_version is None:
1786
+ # - For health check from orchestration tools (e.g. k8s), we also
1787
+ # return 200 with status=HEALTHY, which passes HTTP probe
1788
+ # check.
1789
+ # - There is no harm when a malicious client calls /api/health
1790
+ # without authentication since no sensitive information is
1791
+ # returned.
1792
+ return responses.APIHealthResponse(
1793
+ status=common.ApiServerStatus.HEALTHY,)
1794
+ # TODO(aylei): remove this after min_compatible_api_version >= 14.
1795
+ if client_version < 14:
1796
+ # For Client with API version < 14, the NEEDS_AUTH status is not
1797
+ # honored. Return 401 to trigger the login process.
1798
+ raise fastapi.HTTPException(status_code=401,
1799
+ detail='Authentication required')
1800
+
1801
+ logger.debug(f'Health endpoint: request.state.auth_user = {user}')
1802
+ return responses.APIHealthResponse(
1803
+ status=server_status,
1804
+ # Kept for backward compatibility, clients before 0.11.0 will read this
1805
+ # field to check compatibility and hint the user to upgrade the CLI.
1806
+ # TODO(aylei): remove this field after 0.13.0
1807
+ api_version=str(server_constants.API_VERSION),
1808
+ version=sky.__version__,
1809
+ version_on_disk=common.get_skypilot_version_on_disk(),
1810
+ commit=sky.__commit__,
1811
+ # Whether basic auth on api server is enabled
1812
+ basic_auth_enabled=os.environ.get(constants.ENV_VAR_ENABLE_BASIC_AUTH,
1813
+ 'false').lower() == 'true',
1814
+ user=user if user is not None else None,
1815
+ # Whether service account token is enabled
1816
+ service_account_token_enabled=(os.environ.get(
1817
+ constants.ENV_VAR_ENABLE_SERVICE_ACCOUNTS,
1818
+ 'false').lower() == 'true'),
1819
+ # Whether basic auth on ingress is enabled
1820
+ ingress_basic_auth_enabled=os.environ.get(
1821
+ constants.SKYPILOT_INGRESS_BASIC_AUTH_ENABLED,
1822
+ 'false').lower() == 'true',
1823
+ )
1824
+
1825
+
1826
+ class KubernetesSSHMessageType(IntEnum):
1827
+ REGULAR_DATA = 0
1828
+ PINGPONG = 1
1829
+ LATENCY_MEASUREMENT = 2
1034
1830
 
1035
1831
 
1036
1832
  @app.websocket('/kubernetes-pod-ssh-proxy')
1037
1833
  async def kubernetes_pod_ssh_proxy(
1038
- websocket: fastapi.WebSocket,
1039
- cluster_name_body: payloads.ClusterNameBody = fastapi.Depends()
1040
- ) -> None:
1834
+ websocket: fastapi.WebSocket,
1835
+ cluster_name: str,
1836
+ client_version: Optional[int] = None) -> None:
1041
1837
  """Proxies SSH to the Kubernetes pod with websocket."""
1042
1838
  await websocket.accept()
1043
- cluster_name = cluster_name_body.cluster_name
1044
1839
  logger.info(f'WebSocket connection accepted for cluster: {cluster_name}')
1045
1840
 
1046
- cluster_records = core.status(cluster_name, all_users=True)
1841
+ timestamps_supported = client_version is not None and client_version > 21
1842
+ logger.info(f'Websocket timestamps supported: {timestamps_supported}, \
1843
+ client_version = {client_version}')
1844
+
1845
+ # Run core.status in another thread to avoid blocking the event loop.
1846
+ with ThreadPoolExecutor(max_workers=1) as thread_pool_executor:
1847
+ cluster_records = await context_utils.to_thread_with_executor(
1848
+ thread_pool_executor, core.status, cluster_name, all_users=True)
1047
1849
  cluster_record = cluster_records[0]
1048
1850
  if cluster_record['status'] != status_lib.ClusterStatus.UP:
1049
1851
  raise fastapi.HTTPException(
@@ -1082,17 +1884,70 @@ async def kubernetes_pod_ssh_proxy(
1082
1884
  return
1083
1885
 
1084
1886
  logger.info(f'Starting port-forward to local port: {local_port}')
1887
+ conn_gauge = metrics_utils.SKY_APISERVER_WEBSOCKET_CONNECTIONS.labels(
1888
+ pid=os.getpid())
1889
+ ssh_failed = False
1890
+ websocket_closed = False
1085
1891
  try:
1892
+ conn_gauge.inc()
1086
1893
  # Connect to the local port
1087
1894
  reader, writer = await asyncio.open_connection('127.0.0.1', local_port)
1088
1895
 
1089
1896
  async def websocket_to_ssh():
1090
1897
  try:
1091
1898
  async for message in websocket.iter_bytes():
1899
+ if timestamps_supported:
1900
+ type_size = struct.calcsize('!B')
1901
+ message_type = struct.unpack('!B',
1902
+ message[:type_size])[0]
1903
+ if (message_type ==
1904
+ KubernetesSSHMessageType.REGULAR_DATA):
1905
+ # Regular data - strip type byte and forward to SSH
1906
+ message = message[type_size:]
1907
+ elif message_type == KubernetesSSHMessageType.PINGPONG:
1908
+ # PING message - respond with PONG (type 1)
1909
+ ping_id_size = struct.calcsize('!I')
1910
+ if len(message) != type_size + ping_id_size:
1911
+ raise ValueError('Invalid PING message '
1912
+ f'length: {len(message)}')
1913
+ # Return the same PING message, so that the client
1914
+ # can measure the latency.
1915
+ await websocket.send_bytes(message)
1916
+ continue
1917
+ elif (message_type ==
1918
+ KubernetesSSHMessageType.LATENCY_MEASUREMENT):
1919
+ # Latency measurement from client
1920
+ latency_size = struct.calcsize('!Q')
1921
+ if len(message) != type_size + latency_size:
1922
+ raise ValueError(
1923
+ 'Invalid latency measurement '
1924
+ f'message length: {len(message)}')
1925
+ avg_latency_ms = struct.unpack(
1926
+ '!Q',
1927
+ message[type_size:type_size + latency_size])[0]
1928
+ latency_seconds = avg_latency_ms / 1000
1929
+ metrics_utils.SKY_APISERVER_WEBSOCKET_SSH_LATENCY_SECONDS.labels(pid=os.getpid()).observe(latency_seconds) # pylint: disable=line-too-long
1930
+ continue
1931
+ else:
1932
+ # Unknown message type.
1933
+ raise ValueError(
1934
+ f'Unknown message type: {message_type}')
1092
1935
  writer.write(message)
1093
- await writer.drain()
1936
+ try:
1937
+ await writer.drain()
1938
+ except Exception as e: # pylint: disable=broad-except
1939
+ # Typically we will not reach here, if the ssh to pod
1940
+ # is disconnected, ssh_to_websocket will exit first.
1941
+ # But just in case.
1942
+ logger.error('Failed to write to pod through '
1943
+ f'port-forward connection: {e}')
1944
+ nonlocal ssh_failed
1945
+ ssh_failed = True
1946
+ break
1094
1947
  except fastapi.WebSocketDisconnect:
1095
1948
  pass
1949
+ nonlocal websocket_closed
1950
+ websocket_closed = True
1096
1951
  writer.close()
1097
1952
 
1098
1953
  async def ssh_to_websocket():
@@ -1100,87 +1955,262 @@ async def kubernetes_pod_ssh_proxy(
1100
1955
  while True:
1101
1956
  data = await reader.read(1024)
1102
1957
  if not data:
1958
+ if not websocket_closed:
1959
+ logger.warning('SSH connection to pod is '
1960
+ 'disconnected before websocket '
1961
+ 'connection is closed')
1962
+ nonlocal ssh_failed
1963
+ ssh_failed = True
1103
1964
  break
1965
+ if timestamps_supported:
1966
+ # Prepend message type byte (0 = regular data)
1967
+ message_type_bytes = struct.pack(
1968
+ '!B', KubernetesSSHMessageType.REGULAR_DATA.value)
1969
+ data = message_type_bytes + data
1104
1970
  await websocket.send_bytes(data)
1105
1971
  except Exception: # pylint: disable=broad-except
1106
1972
  pass
1107
- await websocket.close()
1973
+ try:
1974
+ await websocket.close()
1975
+ except Exception: # pylint: disable=broad-except
1976
+ # The websocket might have been closed by the client.
1977
+ pass
1108
1978
 
1109
1979
  await asyncio.gather(websocket_to_ssh(), ssh_to_websocket())
1110
1980
  finally:
1111
- proc.terminate()
1981
+ conn_gauge.dec()
1982
+ reason = ''
1983
+ try:
1984
+ logger.info('Terminating kubectl port-forward process')
1985
+ proc.terminate()
1986
+ except ProcessLookupError:
1987
+ stdout = await proc.stdout.read()
1988
+ logger.error('kubectl port-forward was terminated before the '
1989
+ 'ssh websocket connection was closed. Remaining '
1990
+ f'output: {str(stdout)}')
1991
+ reason = 'KubectlPortForwardExit'
1992
+ metrics_utils.SKY_APISERVER_WEBSOCKET_CLOSED_TOTAL.labels(
1993
+ pid=os.getpid(), reason='KubectlPortForwardExit').inc()
1994
+ else:
1995
+ if ssh_failed:
1996
+ reason = 'SSHToPodDisconnected'
1997
+ else:
1998
+ reason = 'ClientClosed'
1999
+ metrics_utils.SKY_APISERVER_WEBSOCKET_CLOSED_TOTAL.labels(
2000
+ pid=os.getpid(), reason=reason).inc()
2001
+
2002
+
2003
+ @app.get('/all_contexts')
2004
+ async def all_contexts(request: fastapi.Request) -> None:
2005
+ """Gets all Kubernetes and SSH node pool contexts."""
2006
+
2007
+ await executor.schedule_request_async(
2008
+ request_id=request.state.request_id,
2009
+ request_name=request_names.RequestName.ALL_CONTEXTS,
2010
+ request_body=payloads.RequestBody(),
2011
+ func=core.get_all_contexts,
2012
+ schedule_type=requests_lib.ScheduleType.SHORT,
2013
+ )
1112
2014
 
1113
2015
 
1114
2016
  # === Internal APIs ===
1115
2017
  @app.get('/api/completion/cluster_name')
1116
2018
  async def complete_cluster_name(incomplete: str,) -> List[str]:
1117
- return global_user_state.get_cluster_names_start_with(incomplete)
2019
+ return await context_utils.to_thread(
2020
+ global_user_state.get_cluster_names_start_with, incomplete)
1118
2021
 
1119
2022
 
1120
2023
  @app.get('/api/completion/storage_name')
1121
2024
  async def complete_storage_name(incomplete: str,) -> List[str]:
1122
- return global_user_state.get_storage_names_start_with(incomplete)
2025
+ return await context_utils.to_thread(
2026
+ global_user_state.get_storage_names_start_with, incomplete)
1123
2027
 
1124
2028
 
1125
- # Add a route to serve static files
1126
- @app.get('/{full_path:path}')
1127
- async def serve_static_or_dashboard(full_path: str):
1128
- """Serves static files for any unmatched routes.
2029
+ @app.get('/api/completion/volume_name')
2030
+ async def complete_volume_name(incomplete: str,) -> List[str]:
2031
+ return await context_utils.to_thread(
2032
+ global_user_state.get_volume_names_start_with, incomplete)
1129
2033
 
1130
- Handles the /dashboard prefix from Next.js configuration.
1131
- """
1132
- # Check if the path starts with 'dashboard/' and remove it if it does
1133
- if full_path.startswith('dashboard/'):
1134
- full_path = full_path[len('dashboard/'):]
1135
2034
 
1136
- # Try to serve the file directly from the out directory first
2035
+ @app.get('/api/completion/api_request')
2036
+ async def complete_api_request(incomplete: str,) -> List[str]:
2037
+ return await requests_lib.get_api_request_ids_start_with(incomplete)
2038
+
2039
+
2040
+ @app.get('/dashboard/{full_path:path}')
2041
+ async def serve_dashboard(full_path: str):
2042
+ """Serves the Next.js dashboard application.
2043
+
2044
+ Args:
2045
+ full_path: The path requested by the client.
2046
+ e.g. /clusters, /jobs
2047
+
2048
+ Returns:
2049
+ FileResponse for static files or index.html for client-side routing.
2050
+
2051
+ Raises:
2052
+ HTTPException: If the path is invalid or file not found.
2053
+ """
2054
+ # Try to serve the static file directly e.g. /skypilot.svg,
2055
+ # /favicon.ico, and /_next/, etc.
1137
2056
  file_path = os.path.join(server_constants.DASHBOARD_DIR, full_path)
1138
2057
  if os.path.isfile(file_path):
1139
2058
  return fastapi.responses.FileResponse(file_path)
1140
2059
 
1141
- # If file not found, serve the index.html for client-side routing.
1142
- # For example, the non-matched arbitrary route (/ or /test) from
1143
- # client will be redirected to the index.html.
2060
+ # Serve index.html for client-side routing
2061
+ # e.g. /clusters, /jobs
1144
2062
  index_path = os.path.join(server_constants.DASHBOARD_DIR, 'index.html')
1145
2063
  try:
1146
2064
  with open(index_path, 'r', encoding='utf-8') as f:
1147
2065
  content = f.read()
2066
+
1148
2067
  return fastapi.responses.HTMLResponse(content=content)
1149
2068
  except Exception as e:
1150
2069
  logger.error(f'Error serving dashboard: {e}')
1151
2070
  raise fastapi.HTTPException(status_code=500, detail=str(e))
1152
2071
 
1153
2072
 
2073
+ # Redirect the root path to dashboard
2074
+ @app.get('/')
2075
+ async def root():
2076
+ return fastapi.responses.RedirectResponse(url='/dashboard/')
2077
+
2078
+
2079
+ def _init_or_restore_server_user_hash():
2080
+ """Restores the server user hash from the global user state db.
2081
+
2082
+ The API server must have a stable user hash across restarts and potential
2083
+ multiple replicas. Thus we persist the user hash in db and restore it on
2084
+ startup. When upgrading from old version, the user hash will be read from
2085
+ the local file (if any) to keep the user hash consistent.
2086
+ """
2087
+
2088
+ def apply_user_hash(user_hash: str) -> None:
2089
+ # For local API server, the user hash in db and local file should be
2090
+ # same so there is no harm to override here.
2091
+ common_utils.set_user_hash_locally(user_hash)
2092
+ # Refresh the server user hash for current process after restore or
2093
+ # initialize the user hash in db, child processes will get the correct
2094
+ # server id from the local cache file.
2095
+ common_lib.refresh_server_id()
2096
+
2097
+ user_hash = global_user_state.get_system_config(_SERVER_USER_HASH_KEY)
2098
+ if user_hash is not None:
2099
+ apply_user_hash(user_hash)
2100
+ return
2101
+
2102
+ # Initial deployment, generate a user hash and save it to the db.
2103
+ user_hash = common_utils.get_user_hash()
2104
+ global_user_state.set_system_config(_SERVER_USER_HASH_KEY, user_hash)
2105
+ apply_user_hash(user_hash)
2106
+
2107
+
1154
2108
  if __name__ == '__main__':
1155
2109
  import uvicorn
1156
2110
 
1157
2111
  from sky.server import uvicorn as skyuvicorn
1158
2112
 
1159
- requests_lib.reset_db_and_logs()
2113
+ logger.info('Initializing SkyPilot API server')
2114
+ skyuvicorn.add_timestamp_prefix_for_server_logs()
1160
2115
 
1161
2116
  parser = argparse.ArgumentParser()
1162
2117
  parser.add_argument('--host', default='127.0.0.1')
1163
2118
  parser.add_argument('--port', default=46580, type=int)
1164
2119
  parser.add_argument('--deploy', action='store_true')
2120
+ # Serve metrics on a separate port to isolate it from the application APIs:
2121
+ # metrics port will not be exposed to the public network typically.
2122
+ parser.add_argument('--metrics-port', default=9090, type=int)
1165
2123
  cmd_args = parser.parse_args()
2124
+ if cmd_args.port == cmd_args.metrics_port:
2125
+ logger.error('port and metrics-port cannot be the same, exiting.')
2126
+ raise ValueError('port and metrics-port cannot be the same')
2127
+
2128
+ # Fail fast if the port is not available to avoid corrupting the state
2129
+ # of a potentially running server instance.
2130
+ # We might reach here because the running server is currently not
2131
+ # responding, thus the healthz check fails and `sky api start` thinks
2132
+ # we should start a new server instance.
2133
+ if not common_utils.is_port_available(cmd_args.port):
2134
+ logger.error(f'Port {cmd_args.port} is not available, exiting.')
2135
+ raise RuntimeError(f'Port {cmd_args.port} is not available')
2136
+
2137
+ # Maybe touch the signal file on API server startup. Do it again here even
2138
+ # if we already touched it in the sky/server/common.py::_start_api_server.
2139
+ # This is because the sky/server/common.py::_start_api_server function call
2140
+ # is running outside the skypilot API server process tree. The process tree
2141
+ # starts within that function (see the `subprocess.Popen` call in
2142
+ # sky/server/common.py::_start_api_server). When pg is used, the
2143
+ # _start_api_server function will not load the config file from db, which
2144
+ # will ignore the consolidation mode config. Here, inside the process tree,
2145
+ # we already reload the config as a server (with env var _start_api_server),
2146
+ # so we will respect the consolidation mode config.
2147
+ # Refer to #7717 for more details.
2148
+ managed_job_utils.is_consolidation_mode(on_api_restart=True)
2149
+
1166
2150
  # Show the privacy policy if it is not already shown. We place it here so
1167
2151
  # that it is shown only when the API server is started.
1168
2152
  usage_lib.maybe_show_privacy_policy()
1169
2153
 
1170
- config = server_config.compute_server_config(cmd_args.deploy)
2154
+ # Initialize global user state db
2155
+ db_utils.set_max_connections(1)
2156
+ logger.info('Initializing database engine')
2157
+ global_user_state.initialize_and_get_db()
2158
+ logger.info('Database engine initialized')
2159
+ # Initialize request db
2160
+ requests_lib.reset_db_and_logs()
2161
+ # Restore the server user hash
2162
+ logger.info('Initializing server user hash')
2163
+ _init_or_restore_server_user_hash()
2164
+
2165
+ max_db_connections = global_user_state.get_max_db_connections()
2166
+ logger.info(f'Max db connections: {max_db_connections}')
2167
+
2168
+ # Reserve memory for jobs and serve/pool controller in consolidation mode.
2169
+ reserved_memory_mb = (
2170
+ controller_utils.compute_memory_reserved_for_controllers(
2171
+ reserve_for_controllers=os.environ.get(
2172
+ constants.OVERRIDE_CONSOLIDATION_MODE) is not None,
2173
+ # For jobs controller, we need to reserve for both jobs and
2174
+ # pool controller.
2175
+ reserve_extra_for_pool=not os.environ.get(
2176
+ constants.IS_SKYPILOT_SERVE_CONTROLLER)))
2177
+
2178
+ config = server_config.compute_server_config(
2179
+ cmd_args.deploy,
2180
+ max_db_connections,
2181
+ reserved_memory_mb=reserved_memory_mb)
2182
+
1171
2183
  num_workers = config.num_server_workers
1172
2184
 
1173
- sub_procs = []
2185
+ queue_server: Optional[multiprocessing.Process] = None
2186
+ workers: List[executor.RequestWorker] = []
2187
+ # Global background tasks that will be scheduled in a separate event loop.
2188
+ global_tasks: List[asyncio.Task] = []
1174
2189
  try:
1175
- sub_procs = executor.start(config)
2190
+ background = uvloop.new_event_loop()
2191
+ if os.environ.get(constants.ENV_VAR_SERVER_METRICS_ENABLED):
2192
+ metrics_server = metrics.build_metrics_server(
2193
+ cmd_args.host, cmd_args.metrics_port)
2194
+ global_tasks.append(background.create_task(metrics_server.serve()))
2195
+ global_tasks.append(
2196
+ background.create_task(requests_lib.requests_gc_daemon()))
2197
+ global_tasks.append(
2198
+ background.create_task(
2199
+ global_user_state.cluster_event_retention_daemon()))
2200
+ threading.Thread(target=background.run_forever, daemon=True).start()
2201
+
2202
+ queue_server, workers = executor.start(config)
2203
+
1176
2204
  logger.info(f'Starting SkyPilot API server, workers={num_workers}')
1177
2205
  # We don't support reload for now, since it may cause leakage of request
1178
2206
  # workers or interrupt running requests.
1179
- config = uvicorn.Config('sky.server.server:app',
1180
- host=cmd_args.host,
1181
- port=cmd_args.port,
1182
- workers=num_workers)
1183
- skyuvicorn.run(config)
2207
+ uvicorn_config = uvicorn.Config('sky.server.server:app',
2208
+ host=cmd_args.host,
2209
+ port=cmd_args.port,
2210
+ workers=num_workers,
2211
+ ws_per_message_deflate=False)
2212
+ skyuvicorn.run(uvicorn_config,
2213
+ max_db_connections=config.num_db_connections_per_worker)
1184
2214
  except Exception as exc: # pylint: disable=broad-except
1185
2215
  logger.error(f'Failed to start SkyPilot API server: '
1186
2216
  f'{common_utils.format_exception(exc, use_bracket=True)}')
@@ -1188,17 +2218,11 @@ if __name__ == '__main__':
1188
2218
  finally:
1189
2219
  logger.info('Shutting down SkyPilot API server...')
1190
2220
 
1191
- def cleanup(proc: multiprocessing.Process) -> None:
1192
- try:
1193
- proc.terminate()
1194
- proc.join()
1195
- finally:
1196
- # The process may not be started yet, close it anyway.
1197
- proc.close()
1198
-
1199
- # Terminate processes in reverse order in case dependency, especially
1200
- # queue server. Terminate queue server first does not affect the
1201
- # correctness of cleanup but introduce redundant error messages.
1202
- subprocess_utils.run_in_parallel(cleanup,
1203
- list(reversed(sub_procs)),
1204
- num_threads=len(sub_procs))
2221
+ for gt in global_tasks:
2222
+ gt.cancel()
2223
+ subprocess_utils.run_in_parallel(lambda worker: worker.cancel(),
2224
+ workers,
2225
+ num_threads=len(workers))
2226
+ if queue_server is not None:
2227
+ queue_server.kill()
2228
+ queue_server.join()