skypilot-nightly 1.0.0.dev20250509__py3-none-any.whl → 1.0.0.dev20251107__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of skypilot-nightly might be problematic. Click here for more details.

Files changed (512)
  1. sky/__init__.py +22 -6
  2. sky/adaptors/aws.py +25 -7
  3. sky/adaptors/common.py +24 -1
  4. sky/adaptors/coreweave.py +278 -0
  5. sky/adaptors/do.py +8 -2
  6. sky/adaptors/hyperbolic.py +8 -0
  7. sky/adaptors/kubernetes.py +149 -18
  8. sky/adaptors/nebius.py +170 -17
  9. sky/adaptors/primeintellect.py +1 -0
  10. sky/adaptors/runpod.py +68 -0
  11. sky/adaptors/seeweb.py +167 -0
  12. sky/adaptors/shadeform.py +89 -0
  13. sky/admin_policy.py +187 -4
  14. sky/authentication.py +179 -225
  15. sky/backends/__init__.py +4 -2
  16. sky/backends/backend.py +22 -9
  17. sky/backends/backend_utils.py +1299 -380
  18. sky/backends/cloud_vm_ray_backend.py +1715 -518
  19. sky/backends/docker_utils.py +1 -1
  20. sky/backends/local_docker_backend.py +11 -6
  21. sky/backends/wheel_utils.py +37 -9
  22. sky/{clouds/service_catalog → catalog}/__init__.py +21 -19
  23. sky/{clouds/service_catalog → catalog}/aws_catalog.py +27 -8
  24. sky/{clouds/service_catalog → catalog}/azure_catalog.py +10 -7
  25. sky/{clouds/service_catalog → catalog}/common.py +89 -48
  26. sky/{clouds/service_catalog → catalog}/cudo_catalog.py +8 -5
  27. sky/{clouds/service_catalog → catalog}/data_fetchers/analyze.py +1 -1
  28. sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_aws.py +30 -40
  29. sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_cudo.py +38 -38
  30. sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_gcp.py +42 -15
  31. sky/catalog/data_fetchers/fetch_hyperbolic.py +136 -0
  32. sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_lambda_cloud.py +1 -0
  33. sky/catalog/data_fetchers/fetch_nebius.py +335 -0
  34. sky/catalog/data_fetchers/fetch_runpod.py +698 -0
  35. sky/catalog/data_fetchers/fetch_seeweb.py +329 -0
  36. sky/catalog/data_fetchers/fetch_shadeform.py +142 -0
  37. sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_vast.py +1 -1
  38. sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_vsphere.py +1 -1
  39. sky/{clouds/service_catalog → catalog}/do_catalog.py +5 -2
  40. sky/{clouds/service_catalog → catalog}/fluidstack_catalog.py +6 -3
  41. sky/{clouds/service_catalog → catalog}/gcp_catalog.py +41 -15
  42. sky/catalog/hyperbolic_catalog.py +136 -0
  43. sky/{clouds/service_catalog → catalog}/ibm_catalog.py +9 -6
  44. sky/{clouds/service_catalog → catalog}/kubernetes_catalog.py +36 -24
  45. sky/{clouds/service_catalog → catalog}/lambda_catalog.py +9 -6
  46. sky/{clouds/service_catalog → catalog}/nebius_catalog.py +9 -7
  47. sky/{clouds/service_catalog → catalog}/oci_catalog.py +9 -6
  48. sky/{clouds/service_catalog → catalog}/paperspace_catalog.py +5 -2
  49. sky/catalog/primeintellect_catalog.py +95 -0
  50. sky/{clouds/service_catalog → catalog}/runpod_catalog.py +11 -4
  51. sky/{clouds/service_catalog → catalog}/scp_catalog.py +9 -6
  52. sky/catalog/seeweb_catalog.py +184 -0
  53. sky/catalog/shadeform_catalog.py +165 -0
  54. sky/catalog/ssh_catalog.py +167 -0
  55. sky/{clouds/service_catalog → catalog}/vast_catalog.py +6 -3
  56. sky/{clouds/service_catalog → catalog}/vsphere_catalog.py +5 -2
  57. sky/check.py +491 -203
  58. sky/cli.py +5 -6005
  59. sky/client/{cli.py → cli/command.py} +2477 -1885
  60. sky/client/cli/deprecation_utils.py +99 -0
  61. sky/client/cli/flags.py +359 -0
  62. sky/client/cli/table_utils.py +320 -0
  63. sky/client/common.py +70 -32
  64. sky/client/oauth.py +82 -0
  65. sky/client/sdk.py +1203 -297
  66. sky/client/sdk_async.py +833 -0
  67. sky/client/service_account_auth.py +47 -0
  68. sky/cloud_stores.py +73 -0
  69. sky/clouds/__init__.py +13 -0
  70. sky/clouds/aws.py +358 -93
  71. sky/clouds/azure.py +105 -83
  72. sky/clouds/cloud.py +127 -36
  73. sky/clouds/cudo.py +68 -50
  74. sky/clouds/do.py +66 -48
  75. sky/clouds/fluidstack.py +63 -44
  76. sky/clouds/gcp.py +339 -110
  77. sky/clouds/hyperbolic.py +293 -0
  78. sky/clouds/ibm.py +70 -49
  79. sky/clouds/kubernetes.py +563 -162
  80. sky/clouds/lambda_cloud.py +74 -54
  81. sky/clouds/nebius.py +206 -80
  82. sky/clouds/oci.py +88 -66
  83. sky/clouds/paperspace.py +61 -44
  84. sky/clouds/primeintellect.py +317 -0
  85. sky/clouds/runpod.py +164 -74
  86. sky/clouds/scp.py +89 -83
  87. sky/clouds/seeweb.py +466 -0
  88. sky/clouds/shadeform.py +400 -0
  89. sky/clouds/ssh.py +263 -0
  90. sky/clouds/utils/aws_utils.py +10 -4
  91. sky/clouds/utils/gcp_utils.py +87 -11
  92. sky/clouds/utils/oci_utils.py +38 -14
  93. sky/clouds/utils/scp_utils.py +177 -124
  94. sky/clouds/vast.py +99 -77
  95. sky/clouds/vsphere.py +51 -40
  96. sky/core.py +349 -139
  97. sky/dag.py +15 -0
  98. sky/dashboard/out/404.html +1 -1
  99. sky/dashboard/out/_next/static/chunks/1141-e6aa9ab418717c59.js +11 -0
  100. sky/dashboard/out/_next/static/chunks/1272-1ef0bf0237faccdb.js +1 -0
  101. sky/dashboard/out/_next/static/chunks/1871-74503c8e80fd253b.js +6 -0
  102. sky/dashboard/out/_next/static/chunks/2260-7703229c33c5ebd5.js +1 -0
  103. sky/dashboard/out/_next/static/chunks/2350.fab69e61bac57b23.js +1 -0
  104. sky/dashboard/out/_next/static/chunks/2369.fc20f0c2c8ed9fe7.js +15 -0
  105. sky/dashboard/out/_next/static/chunks/2755.fff53c4a3fcae910.js +26 -0
  106. sky/dashboard/out/_next/static/chunks/3294.72362fa129305b19.js +1 -0
  107. sky/dashboard/out/_next/static/chunks/3785.ad6adaa2a0fa9768.js +1 -0
  108. sky/dashboard/out/_next/static/chunks/3850-ff4a9a69d978632b.js +1 -0
  109. sky/dashboard/out/_next/static/chunks/3937.210053269f121201.js +1 -0
  110. sky/dashboard/out/_next/static/chunks/4725.a830b5c9e7867c92.js +1 -0
  111. sky/dashboard/out/_next/static/chunks/4937.a2baa2df5572a276.js +15 -0
  112. sky/dashboard/out/_next/static/chunks/5739-d67458fcb1386c92.js +8 -0
  113. sky/dashboard/out/_next/static/chunks/6130-2be46d70a38f1e82.js +1 -0
  114. sky/dashboard/out/_next/static/chunks/616-3d59f75e2ccf9321.js +39 -0
  115. sky/dashboard/out/_next/static/chunks/6212-7bd06f60ba693125.js +13 -0
  116. sky/dashboard/out/_next/static/chunks/6601-06114c982db410b6.js +1 -0
  117. sky/dashboard/out/_next/static/chunks/6856-ef8ba11f96d8c4a3.js +1 -0
  118. sky/dashboard/out/_next/static/chunks/6989-01359c57e018caa4.js +1 -0
  119. sky/dashboard/out/_next/static/chunks/6990-32b6e2d3822301fa.js +1 -0
  120. sky/dashboard/out/_next/static/chunks/7359-c8d04e06886000b3.js +30 -0
  121. sky/dashboard/out/_next/static/chunks/7411-b15471acd2cba716.js +41 -0
  122. sky/dashboard/out/_next/static/chunks/7615-3301e838e5f25772.js +1 -0
  123. sky/dashboard/out/_next/static/chunks/8640.5b9475a2d18c5416.js +16 -0
  124. sky/dashboard/out/_next/static/chunks/8969-1e4613c651bf4051.js +1 -0
  125. sky/dashboard/out/_next/static/chunks/9025.fa408f3242e9028d.js +6 -0
  126. sky/dashboard/out/_next/static/chunks/9353-cff34f7e773b2e2b.js +1 -0
  127. sky/dashboard/out/_next/static/chunks/9360.7310982cf5a0dc79.js +31 -0
  128. sky/dashboard/out/_next/static/chunks/9847.3aaca6bb33455140.js +30 -0
  129. sky/dashboard/out/_next/static/chunks/fd9d1056-86323a29a8f7e46a.js +1 -0
  130. sky/dashboard/out/_next/static/chunks/framework-cf60a09ccd051a10.js +33 -0
  131. sky/dashboard/out/_next/static/chunks/main-app-587214043926b3cc.js +1 -0
  132. sky/dashboard/out/_next/static/chunks/main-f15ccb73239a3bf1.js +1 -0
  133. sky/dashboard/out/_next/static/chunks/pages/_app-bde01e4a2beec258.js +34 -0
  134. sky/dashboard/out/_next/static/chunks/pages/_error-c66a4e8afc46f17b.js +1 -0
  135. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]/[job]-c736ead69c2d86ec.js +16 -0
  136. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]-a37d2063af475a1c.js +1 -0
  137. sky/dashboard/out/_next/static/chunks/pages/clusters-d44859594e6f8064.js +1 -0
  138. sky/dashboard/out/_next/static/chunks/pages/config-dfb9bf07b13045f4.js +1 -0
  139. sky/dashboard/out/_next/static/chunks/pages/index-444f1804401f04ea.js +1 -0
  140. sky/dashboard/out/_next/static/chunks/pages/infra/[context]-c0b5935149902e6f.js +1 -0
  141. sky/dashboard/out/_next/static/chunks/pages/infra-aed0ea19df7cf961.js +1 -0
  142. sky/dashboard/out/_next/static/chunks/pages/jobs/[job]-5796e8d6aea291a0.js +16 -0
  143. sky/dashboard/out/_next/static/chunks/pages/jobs/pools/[pool]-6edeb7d06032adfc.js +21 -0
  144. sky/dashboard/out/_next/static/chunks/pages/jobs-479dde13399cf270.js +1 -0
  145. sky/dashboard/out/_next/static/chunks/pages/users-5ab3b907622cf0fe.js +1 -0
  146. sky/dashboard/out/_next/static/chunks/pages/volumes-b84b948ff357c43e.js +1 -0
  147. sky/dashboard/out/_next/static/chunks/pages/workspace/new-3f88a1c7e86a3f86.js +1 -0
  148. sky/dashboard/out/_next/static/chunks/pages/workspaces/[name]-c5a3eeee1c218af1.js +1 -0
  149. sky/dashboard/out/_next/static/chunks/pages/workspaces-22b23febb3e89ce1.js +1 -0
  150. sky/dashboard/out/_next/static/chunks/webpack-2679be77fc08a2f8.js +1 -0
  151. sky/dashboard/out/_next/static/css/0748ce22df867032.css +3 -0
  152. sky/dashboard/out/_next/static/zB0ed6ge_W1MDszVHhijS/_buildManifest.js +1 -0
  153. sky/dashboard/out/clusters/[cluster]/[job].html +1 -1
  154. sky/dashboard/out/clusters/[cluster].html +1 -1
  155. sky/dashboard/out/clusters.html +1 -1
  156. sky/dashboard/out/config.html +1 -0
  157. sky/dashboard/out/index.html +1 -1
  158. sky/dashboard/out/infra/[context].html +1 -0
  159. sky/dashboard/out/infra.html +1 -0
  160. sky/dashboard/out/jobs/[job].html +1 -1
  161. sky/dashboard/out/jobs/pools/[pool].html +1 -0
  162. sky/dashboard/out/jobs.html +1 -1
  163. sky/dashboard/out/users.html +1 -0
  164. sky/dashboard/out/volumes.html +1 -0
  165. sky/dashboard/out/workspace/new.html +1 -0
  166. sky/dashboard/out/workspaces/[name].html +1 -0
  167. sky/dashboard/out/workspaces.html +1 -0
  168. sky/data/data_utils.py +137 -1
  169. sky/data/mounting_utils.py +269 -84
  170. sky/data/storage.py +1451 -1807
  171. sky/data/storage_utils.py +43 -57
  172. sky/exceptions.py +132 -2
  173. sky/execution.py +206 -63
  174. sky/global_user_state.py +2374 -586
  175. sky/jobs/__init__.py +5 -0
  176. sky/jobs/client/sdk.py +242 -65
  177. sky/jobs/client/sdk_async.py +143 -0
  178. sky/jobs/constants.py +9 -8
  179. sky/jobs/controller.py +839 -277
  180. sky/jobs/file_content_utils.py +80 -0
  181. sky/jobs/log_gc.py +201 -0
  182. sky/jobs/recovery_strategy.py +398 -152
  183. sky/jobs/scheduler.py +315 -189
  184. sky/jobs/server/core.py +829 -255
  185. sky/jobs/server/server.py +156 -115
  186. sky/jobs/server/utils.py +136 -0
  187. sky/jobs/state.py +2092 -701
  188. sky/jobs/utils.py +1242 -160
  189. sky/logs/__init__.py +21 -0
  190. sky/logs/agent.py +108 -0
  191. sky/logs/aws.py +243 -0
  192. sky/logs/gcp.py +91 -0
  193. sky/metrics/__init__.py +0 -0
  194. sky/metrics/utils.py +443 -0
  195. sky/models.py +78 -1
  196. sky/optimizer.py +164 -70
  197. sky/provision/__init__.py +90 -4
  198. sky/provision/aws/config.py +147 -26
  199. sky/provision/aws/instance.py +135 -50
  200. sky/provision/azure/instance.py +10 -5
  201. sky/provision/common.py +13 -1
  202. sky/provision/cudo/cudo_machine_type.py +1 -1
  203. sky/provision/cudo/cudo_utils.py +14 -8
  204. sky/provision/cudo/cudo_wrapper.py +72 -71
  205. sky/provision/cudo/instance.py +10 -6
  206. sky/provision/do/instance.py +10 -6
  207. sky/provision/do/utils.py +4 -3
  208. sky/provision/docker_utils.py +114 -23
  209. sky/provision/fluidstack/instance.py +13 -8
  210. sky/provision/gcp/__init__.py +1 -0
  211. sky/provision/gcp/config.py +301 -19
  212. sky/provision/gcp/constants.py +218 -0
  213. sky/provision/gcp/instance.py +36 -8
  214. sky/provision/gcp/instance_utils.py +18 -4
  215. sky/provision/gcp/volume_utils.py +247 -0
  216. sky/provision/hyperbolic/__init__.py +12 -0
  217. sky/provision/hyperbolic/config.py +10 -0
  218. sky/provision/hyperbolic/instance.py +437 -0
  219. sky/provision/hyperbolic/utils.py +373 -0
  220. sky/provision/instance_setup.py +93 -14
  221. sky/provision/kubernetes/__init__.py +5 -0
  222. sky/provision/kubernetes/config.py +9 -52
  223. sky/provision/kubernetes/constants.py +17 -0
  224. sky/provision/kubernetes/instance.py +789 -247
  225. sky/provision/kubernetes/manifests/fusermount-server-daemonset.yaml +1 -2
  226. sky/provision/kubernetes/network.py +27 -17
  227. sky/provision/kubernetes/network_utils.py +40 -43
  228. sky/provision/kubernetes/utils.py +1192 -531
  229. sky/provision/kubernetes/volume.py +282 -0
  230. sky/provision/lambda_cloud/instance.py +22 -16
  231. sky/provision/nebius/constants.py +50 -0
  232. sky/provision/nebius/instance.py +19 -6
  233. sky/provision/nebius/utils.py +196 -91
  234. sky/provision/oci/instance.py +10 -5
  235. sky/provision/paperspace/instance.py +10 -7
  236. sky/provision/paperspace/utils.py +1 -1
  237. sky/provision/primeintellect/__init__.py +10 -0
  238. sky/provision/primeintellect/config.py +11 -0
  239. sky/provision/primeintellect/instance.py +454 -0
  240. sky/provision/primeintellect/utils.py +398 -0
  241. sky/provision/provisioner.py +110 -36
  242. sky/provision/runpod/__init__.py +5 -0
  243. sky/provision/runpod/instance.py +27 -6
  244. sky/provision/runpod/utils.py +51 -18
  245. sky/provision/runpod/volume.py +180 -0
  246. sky/provision/scp/__init__.py +15 -0
  247. sky/provision/scp/config.py +93 -0
  248. sky/provision/scp/instance.py +531 -0
  249. sky/provision/seeweb/__init__.py +11 -0
  250. sky/provision/seeweb/config.py +13 -0
  251. sky/provision/seeweb/instance.py +807 -0
  252. sky/provision/shadeform/__init__.py +11 -0
  253. sky/provision/shadeform/config.py +12 -0
  254. sky/provision/shadeform/instance.py +351 -0
  255. sky/provision/shadeform/shadeform_utils.py +83 -0
  256. sky/provision/ssh/__init__.py +18 -0
  257. sky/provision/vast/instance.py +13 -8
  258. sky/provision/vast/utils.py +10 -7
  259. sky/provision/vsphere/common/vim_utils.py +1 -2
  260. sky/provision/vsphere/instance.py +15 -10
  261. sky/provision/vsphere/vsphere_utils.py +9 -19
  262. sky/py.typed +0 -0
  263. sky/resources.py +844 -118
  264. sky/schemas/__init__.py +0 -0
  265. sky/schemas/api/__init__.py +0 -0
  266. sky/schemas/api/responses.py +225 -0
  267. sky/schemas/db/README +4 -0
  268. sky/schemas/db/env.py +90 -0
  269. sky/schemas/db/global_user_state/001_initial_schema.py +124 -0
  270. sky/schemas/db/global_user_state/002_add_workspace_to_cluster_history.py +35 -0
  271. sky/schemas/db/global_user_state/003_fix_initial_revision.py +61 -0
  272. sky/schemas/db/global_user_state/004_is_managed.py +34 -0
  273. sky/schemas/db/global_user_state/005_cluster_event.py +32 -0
  274. sky/schemas/db/global_user_state/006_provision_log.py +41 -0
  275. sky/schemas/db/global_user_state/007_cluster_event_request_id.py +34 -0
  276. sky/schemas/db/global_user_state/008_skylet_ssh_tunnel_metadata.py +34 -0
  277. sky/schemas/db/global_user_state/009_last_activity_and_launched_at.py +89 -0
  278. sky/schemas/db/global_user_state/010_save_ssh_key.py +66 -0
  279. sky/schemas/db/script.py.mako +28 -0
  280. sky/schemas/db/serve_state/001_initial_schema.py +67 -0
  281. sky/schemas/db/skypilot_config/001_initial_schema.py +30 -0
  282. sky/schemas/db/spot_jobs/001_initial_schema.py +97 -0
  283. sky/schemas/db/spot_jobs/002_cluster_pool.py +42 -0
  284. sky/schemas/db/spot_jobs/003_pool_hash.py +34 -0
  285. sky/schemas/db/spot_jobs/004_job_file_contents.py +42 -0
  286. sky/schemas/db/spot_jobs/005_logs_gc.py +38 -0
  287. sky/schemas/generated/__init__.py +0 -0
  288. sky/schemas/generated/autostopv1_pb2.py +36 -0
  289. sky/schemas/generated/autostopv1_pb2.pyi +43 -0
  290. sky/schemas/generated/autostopv1_pb2_grpc.py +146 -0
  291. sky/schemas/generated/jobsv1_pb2.py +86 -0
  292. sky/schemas/generated/jobsv1_pb2.pyi +254 -0
  293. sky/schemas/generated/jobsv1_pb2_grpc.py +542 -0
  294. sky/schemas/generated/managed_jobsv1_pb2.py +74 -0
  295. sky/schemas/generated/managed_jobsv1_pb2.pyi +278 -0
  296. sky/schemas/generated/managed_jobsv1_pb2_grpc.py +278 -0
  297. sky/schemas/generated/servev1_pb2.py +58 -0
  298. sky/schemas/generated/servev1_pb2.pyi +115 -0
  299. sky/schemas/generated/servev1_pb2_grpc.py +322 -0
  300. sky/serve/autoscalers.py +357 -5
  301. sky/serve/client/impl.py +310 -0
  302. sky/serve/client/sdk.py +47 -139
  303. sky/serve/client/sdk_async.py +130 -0
  304. sky/serve/constants.py +10 -8
  305. sky/serve/controller.py +64 -19
  306. sky/serve/load_balancer.py +106 -60
  307. sky/serve/load_balancing_policies.py +115 -1
  308. sky/serve/replica_managers.py +273 -162
  309. sky/serve/serve_rpc_utils.py +179 -0
  310. sky/serve/serve_state.py +554 -251
  311. sky/serve/serve_utils.py +733 -220
  312. sky/serve/server/core.py +66 -711
  313. sky/serve/server/impl.py +1093 -0
  314. sky/serve/server/server.py +21 -18
  315. sky/serve/service.py +133 -48
  316. sky/serve/service_spec.py +135 -16
  317. sky/serve/spot_placer.py +3 -0
  318. sky/server/auth/__init__.py +0 -0
  319. sky/server/auth/authn.py +50 -0
  320. sky/server/auth/loopback.py +38 -0
  321. sky/server/auth/oauth2_proxy.py +200 -0
  322. sky/server/common.py +475 -181
  323. sky/server/config.py +81 -23
  324. sky/server/constants.py +44 -6
  325. sky/server/daemons.py +229 -0
  326. sky/server/html/token_page.html +185 -0
  327. sky/server/metrics.py +160 -0
  328. sky/server/requests/executor.py +528 -138
  329. sky/server/requests/payloads.py +351 -17
  330. sky/server/requests/preconditions.py +21 -17
  331. sky/server/requests/process.py +112 -29
  332. sky/server/requests/request_names.py +120 -0
  333. sky/server/requests/requests.py +817 -224
  334. sky/server/requests/serializers/decoders.py +82 -31
  335. sky/server/requests/serializers/encoders.py +140 -22
  336. sky/server/requests/threads.py +106 -0
  337. sky/server/rest.py +417 -0
  338. sky/server/server.py +1290 -284
  339. sky/server/state.py +20 -0
  340. sky/server/stream_utils.py +345 -57
  341. sky/server/uvicorn.py +217 -3
  342. sky/server/versions.py +270 -0
  343. sky/setup_files/MANIFEST.in +5 -0
  344. sky/setup_files/alembic.ini +156 -0
  345. sky/setup_files/dependencies.py +136 -31
  346. sky/setup_files/setup.py +44 -42
  347. sky/sky_logging.py +102 -5
  348. sky/skylet/attempt_skylet.py +1 -0
  349. sky/skylet/autostop_lib.py +129 -8
  350. sky/skylet/configs.py +27 -20
  351. sky/skylet/constants.py +171 -19
  352. sky/skylet/events.py +105 -21
  353. sky/skylet/job_lib.py +335 -104
  354. sky/skylet/log_lib.py +297 -18
  355. sky/skylet/log_lib.pyi +44 -1
  356. sky/skylet/ray_patches/__init__.py +17 -3
  357. sky/skylet/ray_patches/autoscaler.py.diff +18 -0
  358. sky/skylet/ray_patches/cli.py.diff +19 -0
  359. sky/skylet/ray_patches/command_runner.py.diff +17 -0
  360. sky/skylet/ray_patches/log_monitor.py.diff +20 -0
  361. sky/skylet/ray_patches/resource_demand_scheduler.py.diff +32 -0
  362. sky/skylet/ray_patches/updater.py.diff +18 -0
  363. sky/skylet/ray_patches/worker.py.diff +41 -0
  364. sky/skylet/services.py +564 -0
  365. sky/skylet/skylet.py +63 -4
  366. sky/skylet/subprocess_daemon.py +103 -29
  367. sky/skypilot_config.py +506 -99
  368. sky/ssh_node_pools/__init__.py +1 -0
  369. sky/ssh_node_pools/core.py +135 -0
  370. sky/ssh_node_pools/server.py +233 -0
  371. sky/task.py +621 -137
  372. sky/templates/aws-ray.yml.j2 +10 -3
  373. sky/templates/azure-ray.yml.j2 +1 -1
  374. sky/templates/do-ray.yml.j2 +1 -1
  375. sky/templates/gcp-ray.yml.j2 +57 -0
  376. sky/templates/hyperbolic-ray.yml.j2 +67 -0
  377. sky/templates/jobs-controller.yaml.j2 +27 -24
  378. sky/templates/kubernetes-loadbalancer.yml.j2 +2 -0
  379. sky/templates/kubernetes-ray.yml.j2 +607 -51
  380. sky/templates/lambda-ray.yml.j2 +1 -1
  381. sky/templates/nebius-ray.yml.j2 +33 -12
  382. sky/templates/paperspace-ray.yml.j2 +1 -1
  383. sky/templates/primeintellect-ray.yml.j2 +71 -0
  384. sky/templates/runpod-ray.yml.j2 +9 -1
  385. sky/templates/scp-ray.yml.j2 +3 -50
  386. sky/templates/seeweb-ray.yml.j2 +108 -0
  387. sky/templates/shadeform-ray.yml.j2 +72 -0
  388. sky/templates/sky-serve-controller.yaml.j2 +22 -2
  389. sky/templates/websocket_proxy.py +178 -18
  390. sky/usage/usage_lib.py +18 -11
  391. sky/users/__init__.py +0 -0
  392. sky/users/model.conf +15 -0
  393. sky/users/permission.py +387 -0
  394. sky/users/rbac.py +121 -0
  395. sky/users/server.py +720 -0
  396. sky/users/token_service.py +218 -0
  397. sky/utils/accelerator_registry.py +34 -5
  398. sky/utils/admin_policy_utils.py +84 -38
  399. sky/utils/annotations.py +16 -5
  400. sky/utils/asyncio_utils.py +78 -0
  401. sky/utils/auth_utils.py +153 -0
  402. sky/utils/benchmark_utils.py +60 -0
  403. sky/utils/cli_utils/status_utils.py +159 -86
  404. sky/utils/cluster_utils.py +31 -9
  405. sky/utils/command_runner.py +354 -68
  406. sky/utils/command_runner.pyi +93 -3
  407. sky/utils/common.py +35 -8
  408. sky/utils/common_utils.py +310 -87
  409. sky/utils/config_utils.py +87 -5
  410. sky/utils/context.py +402 -0
  411. sky/utils/context_utils.py +222 -0
  412. sky/utils/controller_utils.py +264 -89
  413. sky/utils/dag_utils.py +31 -12
  414. sky/utils/db/__init__.py +0 -0
  415. sky/utils/db/db_utils.py +470 -0
  416. sky/utils/db/migration_utils.py +133 -0
  417. sky/utils/directory_utils.py +12 -0
  418. sky/utils/env_options.py +13 -0
  419. sky/utils/git.py +567 -0
  420. sky/utils/git_clone.sh +460 -0
  421. sky/utils/infra_utils.py +195 -0
  422. sky/utils/kubernetes/cleanup-tunnel.sh +62 -0
  423. sky/utils/kubernetes/config_map_utils.py +133 -0
  424. sky/utils/kubernetes/create_cluster.sh +13 -27
  425. sky/utils/kubernetes/delete_cluster.sh +10 -7
  426. sky/utils/kubernetes/deploy_remote_cluster.py +1299 -0
  427. sky/utils/kubernetes/exec_kubeconfig_converter.py +22 -31
  428. sky/utils/kubernetes/generate_kind_config.py +6 -66
  429. sky/utils/kubernetes/generate_kubeconfig.sh +4 -1
  430. sky/utils/kubernetes/gpu_labeler.py +5 -5
  431. sky/utils/kubernetes/kubernetes_deploy_utils.py +354 -47
  432. sky/utils/kubernetes/ssh-tunnel.sh +379 -0
  433. sky/utils/kubernetes/ssh_utils.py +221 -0
  434. sky/utils/kubernetes_enums.py +8 -15
  435. sky/utils/lock_events.py +94 -0
  436. sky/utils/locks.py +368 -0
  437. sky/utils/log_utils.py +300 -6
  438. sky/utils/perf_utils.py +22 -0
  439. sky/utils/resource_checker.py +298 -0
  440. sky/utils/resources_utils.py +249 -32
  441. sky/utils/rich_utils.py +213 -37
  442. sky/utils/schemas.py +905 -147
  443. sky/utils/serialize_utils.py +16 -0
  444. sky/utils/status_lib.py +10 -0
  445. sky/utils/subprocess_utils.py +38 -15
  446. sky/utils/tempstore.py +70 -0
  447. sky/utils/timeline.py +24 -52
  448. sky/utils/ux_utils.py +84 -15
  449. sky/utils/validator.py +11 -1
  450. sky/utils/volume.py +86 -0
  451. sky/utils/yaml_utils.py +111 -0
  452. sky/volumes/__init__.py +13 -0
  453. sky/volumes/client/__init__.py +0 -0
  454. sky/volumes/client/sdk.py +149 -0
  455. sky/volumes/server/__init__.py +0 -0
  456. sky/volumes/server/core.py +258 -0
  457. sky/volumes/server/server.py +122 -0
  458. sky/volumes/volume.py +212 -0
  459. sky/workspaces/__init__.py +0 -0
  460. sky/workspaces/core.py +655 -0
  461. sky/workspaces/server.py +101 -0
  462. sky/workspaces/utils.py +56 -0
  463. skypilot_nightly-1.0.0.dev20251107.dist-info/METADATA +675 -0
  464. skypilot_nightly-1.0.0.dev20251107.dist-info/RECORD +594 -0
  465. {skypilot_nightly-1.0.0.dev20250509.dist-info → skypilot_nightly-1.0.0.dev20251107.dist-info}/WHEEL +1 -1
  466. sky/benchmark/benchmark_state.py +0 -256
  467. sky/benchmark/benchmark_utils.py +0 -641
  468. sky/clouds/service_catalog/constants.py +0 -7
  469. sky/dashboard/out/_next/static/LksQgChY5izXjokL3LcEu/_buildManifest.js +0 -1
  470. sky/dashboard/out/_next/static/chunks/236-f49500b82ad5392d.js +0 -6
  471. sky/dashboard/out/_next/static/chunks/312-c3c8845990db8ffc.js +0 -15
  472. sky/dashboard/out/_next/static/chunks/37-0a572fe0dbb89c4d.js +0 -6
  473. sky/dashboard/out/_next/static/chunks/678-206dddca808e6d16.js +0 -59
  474. sky/dashboard/out/_next/static/chunks/845-0f8017370869e269.js +0 -1
  475. sky/dashboard/out/_next/static/chunks/979-7bf73a4c7cea0f5c.js +0 -1
  476. sky/dashboard/out/_next/static/chunks/fd9d1056-2821b0f0cabcd8bd.js +0 -1
  477. sky/dashboard/out/_next/static/chunks/framework-87d061ee6ed71b28.js +0 -33
  478. sky/dashboard/out/_next/static/chunks/main-app-241eb28595532291.js +0 -1
  479. sky/dashboard/out/_next/static/chunks/main-e0e2335212e72357.js +0 -1
  480. sky/dashboard/out/_next/static/chunks/pages/_app-e6b013bc3f77ad60.js +0 -1
  481. sky/dashboard/out/_next/static/chunks/pages/_error-1be831200e60c5c0.js +0 -1
  482. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]/[job]-e15db85d0ea1fbe1.js +0 -1
  483. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]-f383db7389368ea7.js +0 -1
  484. sky/dashboard/out/_next/static/chunks/pages/clusters-a93b93e10b8b074e.js +0 -1
  485. sky/dashboard/out/_next/static/chunks/pages/index-f9f039532ca8cbc4.js +0 -1
  486. sky/dashboard/out/_next/static/chunks/pages/jobs/[job]-03f279c6741fb48b.js +0 -1
  487. sky/dashboard/out/_next/static/chunks/pages/jobs-a75029b67aab6a2e.js +0 -1
  488. sky/dashboard/out/_next/static/chunks/webpack-830f59b8404e96b8.js +0 -1
  489. sky/dashboard/out/_next/static/css/c6933bbb2ce7f4dd.css +0 -3
  490. sky/jobs/dashboard/dashboard.py +0 -223
  491. sky/jobs/dashboard/static/favicon.ico +0 -0
  492. sky/jobs/dashboard/templates/index.html +0 -831
  493. sky/jobs/server/dashboard_utils.py +0 -69
  494. sky/skylet/providers/scp/__init__.py +0 -2
  495. sky/skylet/providers/scp/config.py +0 -149
  496. sky/skylet/providers/scp/node_provider.py +0 -578
  497. sky/templates/kubernetes-ssh-jump.yml.j2 +0 -94
  498. sky/utils/db_utils.py +0 -100
  499. sky/utils/kubernetes/deploy_remote_cluster.sh +0 -308
  500. sky/utils/kubernetes/ssh_jump_lifecycle_manager.py +0 -191
  501. skypilot_nightly-1.0.0.dev20250509.dist-info/METADATA +0 -361
  502. skypilot_nightly-1.0.0.dev20250509.dist-info/RECORD +0 -396
  503. /sky/{clouds/service_catalog → catalog}/config.py +0 -0
  504. /sky/{benchmark → catalog/data_fetchers}/__init__.py +0 -0
  505. /sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_azure.py +0 -0
  506. /sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_fluidstack.py +0 -0
  507. /sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_ibm.py +0 -0
  508. /sky/{clouds/service_catalog/data_fetchers → client/cli}/__init__.py +0 -0
  509. /sky/dashboard/out/_next/static/{LksQgChY5izXjokL3LcEu → zB0ed6ge_W1MDszVHhijS}/_ssgManifest.js +0 -0
  510. {skypilot_nightly-1.0.0.dev20250509.dist-info → skypilot_nightly-1.0.0.dev20251107.dist-info}/entry_points.txt +0 -0
  511. {skypilot_nightly-1.0.0.dev20250509.dist-info → skypilot_nightly-1.0.0.dev20251107.dist-info}/licenses/LICENSE +0 -0
  512. {skypilot_nightly-1.0.0.dev20250509.dist-info → skypilot_nightly-1.0.0.dev20251107.dist-info}/top_level.txt +0 -0
sky/server/server.py CHANGED
@@ -2,55 +2,88 @@
2
2
 
3
3
  import argparse
4
4
  import asyncio
5
+ import base64
6
+ from concurrent.futures import ThreadPoolExecutor
5
7
  import contextlib
6
- import dataclasses
7
8
  import datetime
8
- import logging
9
+ from enum import IntEnum
10
+ import hashlib
11
+ import json
9
12
  import multiprocessing
10
13
  import os
11
14
  import pathlib
15
+ import posixpath
12
16
  import re
17
+ import resource
13
18
  import shutil
19
+ import struct
14
20
  import sys
15
- from typing import Any, Dict, List, Literal, Optional, Set, Tuple
21
+ import threading
22
+ import traceback
23
+ from typing import Dict, List, Literal, Optional, Set, Tuple
16
24
  import uuid
17
25
  import zipfile
18
26
 
19
27
  import aiofiles
28
+ import anyio
20
29
  import fastapi
30
+ from fastapi import responses as fastapi_responses
21
31
  from fastapi.middleware import cors
22
32
  import starlette.middleware.base
33
+ import uvloop
23
34
 
24
35
  import sky
36
+ from sky import catalog
25
37
  from sky import check as sky_check
26
38
  from sky import clouds
27
39
  from sky import core
28
40
  from sky import exceptions
29
41
  from sky import execution
30
42
  from sky import global_user_state
43
+ from sky import models
31
44
  from sky import sky_logging
32
- from sky.clouds import service_catalog
33
45
  from sky.data import storage_utils
46
+ from sky.jobs import utils as managed_job_utils
34
47
  from sky.jobs.server import server as jobs_rest
48
+ from sky.metrics import utils as metrics_utils
49
+ from sky.provision import metadata_utils
35
50
  from sky.provision.kubernetes import utils as kubernetes_utils
51
+ from sky.schemas.api import responses
36
52
  from sky.serve.server import server as serve_rest
37
53
  from sky.server import common
38
54
  from sky.server import config as server_config
39
55
  from sky.server import constants as server_constants
56
+ from sky.server import daemons
57
+ from sky.server import metrics
58
+ from sky.server import state
40
59
  from sky.server import stream_utils
60
+ from sky.server import versions
61
+ from sky.server.auth import authn
62
+ from sky.server.auth import loopback
63
+ from sky.server.auth import oauth2_proxy
41
64
  from sky.server.requests import executor
42
65
  from sky.server.requests import payloads
43
66
  from sky.server.requests import preconditions
67
+ from sky.server.requests import request_names
44
68
  from sky.server.requests import requests as requests_lib
45
69
  from sky.skylet import constants
70
+ from sky.ssh_node_pools import server as ssh_node_pools_rest
46
71
  from sky.usage import usage_lib
72
+ from sky.users import permission
73
+ from sky.users import server as users_rest
47
74
  from sky.utils import admin_policy_utils
48
75
  from sky.utils import common as common_lib
49
76
  from sky.utils import common_utils
77
+ from sky.utils import context
78
+ from sky.utils import context_utils
50
79
  from sky.utils import dag_utils
51
- from sky.utils import env_options
80
+ from sky.utils import perf_utils
52
81
  from sky.utils import status_lib
53
82
  from sky.utils import subprocess_utils
83
+ from sky.utils import ux_utils
84
+ from sky.utils.db import db_utils
85
+ from sky.volumes.server import server as volumes_rest
86
+ from sky.workspaces import server as workspaces_rest
54
87
 
55
88
  # pylint: disable=ungrouped-imports
56
89
  if sys.version_info >= (3, 10):
@@ -60,31 +93,8 @@ else:
60
93
 
61
94
  P = ParamSpec('P')
62
95
 
96
+ _SERVER_USER_HASH_KEY = 'server_user_hash'
63
97
 
64
- def _add_timestamp_prefix_for_server_logs() -> None:
65
- server_logger = sky_logging.init_logger('sky.server')
66
- # Clear existing handlers first to prevent duplicates
67
- server_logger.handlers.clear()
68
- # Disable propagation to avoid the root logger of SkyPilot being affected
69
- server_logger.propagate = False
70
- # Add date prefix to the log message printed by loggers under
71
- # server.
72
- stream_handler = logging.StreamHandler(sys.stdout)
73
- if env_options.Options.SHOW_DEBUG_INFO.get():
74
- stream_handler.setLevel(logging.DEBUG)
75
- else:
76
- stream_handler.setLevel(logging.INFO)
77
- stream_handler.flush = sys.stdout.flush # type: ignore
78
- stream_handler.setFormatter(sky_logging.FORMATTER)
79
- server_logger.addHandler(stream_handler)
80
- # Add date prefix to the log message printed by uvicorn.
81
- for name in ['uvicorn', 'uvicorn.access']:
82
- uvicorn_logger = logging.getLogger(name)
83
- uvicorn_logger.handlers.clear()
84
- uvicorn_logger.addHandler(stream_handler)
85
-
86
-
87
- _add_timestamp_prefix_for_server_logs()
88
98
  logger = sky_logging.init_logger(__name__)
89
99
 
90
100
  # TODO(zhwu): Streaming requests, such as log tailing after sky launch or sky logs,
@@ -92,11 +102,72 @@ logger = sky_logging.init_logger(__name__)
92
102
  # response will block other requests from being processed.
93
103
 
94
104
 
105
+ def _basic_auth_401_response(content: str):
106
+ """Return a 401 response with basic auth realm."""
107
+ return fastapi.responses.JSONResponse(
108
+ status_code=401,
109
+ headers={'WWW-Authenticate': 'Basic realm=\"SkyPilot\"'},
110
+ content=content)
111
+
112
+
113
+ def _try_set_basic_auth_user(request: fastapi.Request):
114
+ auth_header = request.headers.get('authorization')
115
+ if not auth_header or not auth_header.lower().startswith('basic '):
116
+ return
117
+
118
+ # Check username and password
119
+ encoded = auth_header.split(' ', 1)[1]
120
+ try:
121
+ decoded = base64.b64decode(encoded).decode()
122
+ username, password = decoded.split(':', 1)
123
+ except Exception: # pylint: disable=broad-except
124
+ return
125
+
126
+ users = global_user_state.get_user_by_name(username)
127
+ if not users:
128
+ return
129
+
130
+ for user in users:
131
+ if not user.name or not user.password:
132
+ continue
133
+ username_encoded = username.encode('utf8')
134
+ db_username_encoded = user.name.encode('utf8')
135
+ if (username_encoded == db_username_encoded and
136
+ common.crypt_ctx.verify(password, user.password)):
137
+ request.state.auth_user = user
138
+ break
139
+
140
+
141
+ class RBACMiddleware(starlette.middleware.base.BaseHTTPMiddleware):
142
+ """Middleware to handle RBAC."""
143
+
144
+ async def dispatch(self, request: fastapi.Request, call_next):
145
+ # TODO(hailong): should have a list of paths
146
+ # that are not checked for RBAC
147
+ if (request.url.path.startswith('/dashboard/') or
148
+ request.url.path.startswith('/api/')):
149
+ return await call_next(request)
150
+
151
+ auth_user = request.state.auth_user
152
+ if auth_user is None:
153
+ return await call_next(request)
154
+
155
+ permission_service = permission.permission_service
156
+ # Check the role permission
157
+ if permission_service.check_endpoint_permission(auth_user.id,
158
+ request.url.path,
159
+ request.method):
160
+ return fastapi.responses.JSONResponse(
161
+ status_code=403, content={'detail': 'Forbidden'})
162
+
163
+ return await call_next(request)
164
+
165
+
95
166
  class RequestIDMiddleware(starlette.middleware.base.BaseHTTPMiddleware):
96
167
  """Middleware to add a request ID to each request."""
97
168
 
98
169
  async def dispatch(self, request: fastapi.Request, call_next):
99
- request_id = str(uuid.uuid4())
170
+ request_id = requests_lib.get_new_request_id()
100
171
  request.state.request_id = request_id
101
172
  response = await call_next(request)
102
173
  # TODO(syang): remove X-Request-ID when v0.10.0 is released.
@@ -105,6 +176,238 @@ class RequestIDMiddleware(starlette.middleware.base.BaseHTTPMiddleware):
105
176
  return response
106
177
 
107
178
 
179
+ def _get_auth_user_header(request: fastapi.Request) -> Optional[models.User]:
180
+ header_name = os.environ.get(constants.ENV_VAR_SERVER_AUTH_USER_HEADER,
181
+ 'X-Auth-Request-Email')
182
+ if header_name not in request.headers:
183
+ return None
184
+ user_name = request.headers[header_name]
185
+ user_hash = hashlib.md5(
186
+ user_name.encode()).hexdigest()[:common_utils.USER_HASH_LENGTH]
187
+ return models.User(id=user_hash, name=user_name)
188
+
189
+
190
+ class InitializeRequestAuthUserMiddleware(
191
+ starlette.middleware.base.BaseHTTPMiddleware):
192
+
193
+ async def dispatch(self, request: fastapi.Request, call_next):
194
+ # Make sure that request.state.auth_user is set. Otherwise, we may get a
195
+ # KeyError while trying to read it.
196
+ request.state.auth_user = None
197
+ return await call_next(request)
198
+
199
+
200
+ class BasicAuthMiddleware(starlette.middleware.base.BaseHTTPMiddleware):
201
+ """Middleware to handle HTTP Basic Auth."""
202
+
203
+ async def dispatch(self, request: fastapi.Request, call_next):
204
+ if managed_job_utils.is_consolidation_mode(
205
+ ) and loopback.is_loopback_request(request):
206
+ return await call_next(request)
207
+
208
+ if request.url.path.startswith('/api/health'):
209
+ # Try to set the auth user from basic auth
210
+ _try_set_basic_auth_user(request)
211
+ return await call_next(request)
212
+
213
+ auth_header = request.headers.get('authorization')
214
+ if not auth_header:
215
+ return _basic_auth_401_response('Authentication required')
216
+
217
+ # Only handle basic auth
218
+ if not auth_header.lower().startswith('basic '):
219
+ return _basic_auth_401_response('Invalid authentication method')
220
+
221
+ # Check username and password
222
+ encoded = auth_header.split(' ', 1)[1]
223
+ try:
224
+ decoded = base64.b64decode(encoded).decode()
225
+ username, password = decoded.split(':', 1)
226
+ except Exception: # pylint: disable=broad-except
227
+ return _basic_auth_401_response('Invalid basic auth')
228
+
229
+ users = global_user_state.get_user_by_name(username)
230
+ if not users:
231
+ return _basic_auth_401_response('Invalid credentials')
232
+
233
+ valid_user = False
234
+ for user in users:
235
+ if not user.name or not user.password:
236
+ continue
237
+ username_encoded = username.encode('utf8')
238
+ db_username_encoded = user.name.encode('utf8')
239
+ if (username_encoded == db_username_encoded and
240
+ common.crypt_ctx.verify(password, user.password)):
241
+ valid_user = True
242
+ request.state.auth_user = user
243
+ await authn.override_user_info_in_request_body(request, user)
244
+ break
245
+ if not valid_user:
246
+ return _basic_auth_401_response('Invalid credentials')
247
+
248
+ return await call_next(request)
249
+
250
+
251
+ class BearerTokenMiddleware(starlette.middleware.base.BaseHTTPMiddleware):
252
+ """Middleware to handle Bearer Token Auth (Service Accounts)."""
253
+
254
+ async def dispatch(self, request: fastapi.Request, call_next):
255
+ """Make sure correct bearer token auth is present.
256
+
257
+ 1. If the request has the X-Skypilot-Auth-Mode: token header, it must
258
+ have a valid bearer token.
259
+ 2. For backwards compatibility, if the request has a Bearer token
260
+ beginning with "sky_" (even if X-Skypilot-Auth-Mode is not present),
261
+ it must be a valid token.
262
+ 3. If X-Skypilot-Auth-Mode is not set to "token", and there is no Bearer
263
+ token beginning with "sky_", allow the request to continue.
264
+
265
+ In conjunction with an auth proxy, the idea is to make the auth proxy
266
+ bypass requests with bearer tokens, instead setting the
267
+ X-Skypilot-Auth-Mode header. The auth proxy should either validate the
268
+ auth or set the header X-Skypilot-Auth-Mode: token.
269
+ """
270
+ has_skypilot_auth_header = (
271
+ request.headers.get('X-Skypilot-Auth-Mode') == 'token')
272
+ auth_header = request.headers.get('authorization')
273
+ has_bearer_token_starting_with_sky = (
274
+ auth_header and auth_header.lower().startswith('bearer ') and
275
+ auth_header.split(' ', 1)[1].startswith('sky_'))
276
+
277
+ if (not has_skypilot_auth_header and
278
+ not has_bearer_token_starting_with_sky):
279
+ # This is case #3 above. We do not need to validate the request.
280
+ # No Bearer token, continue with normal processing (OAuth2 cookies,
281
+ # etc.)
282
+ return await call_next(request)
283
+ # After this point, all requests must be validated.
284
+
285
+ if auth_header is None:
286
+ return fastapi.responses.JSONResponse(
287
+ status_code=401, content={'detail': 'Authentication required'})
288
+
289
+ # Extract token
290
+ split_header = auth_header.split(' ', 1)
291
+ if split_header[0].lower() != 'bearer':
292
+ return fastapi.responses.JSONResponse(
293
+ status_code=401,
294
+ content={'detail': 'Invalid authentication method'})
295
+ sa_token = split_header[1]
296
+
297
+ # Handle SkyPilot service account tokens
298
+ return await self._handle_service_account_token(request, sa_token,
299
+ call_next)
300
+
301
+ async def _handle_service_account_token(self, request: fastapi.Request,
302
+ sa_token: str, call_next):
303
+ """Handle SkyPilot service account tokens."""
304
+ # Check if service account tokens are enabled
305
+ sa_enabled = os.environ.get(constants.ENV_VAR_ENABLE_SERVICE_ACCOUNTS,
306
+ 'false').lower()
307
+ if sa_enabled != 'true':
308
+ return fastapi.responses.JSONResponse(
309
+ status_code=401,
310
+ content={'detail': 'Service account authentication disabled'})
311
+
312
+ try:
313
+ # Import here to avoid circular imports
314
+ # pylint: disable=import-outside-toplevel
315
+ from sky.users.token_service import token_service
316
+
317
+ # Verify and decode JWT token
318
+ payload = token_service.verify_token(sa_token)
319
+
320
+ if payload is None:
321
+ logger.warning('Service account token verification failed')
322
+ return fastapi.responses.JSONResponse(
323
+ status_code=401,
324
+ content={
325
+ 'detail': 'Invalid or expired service account token'
326
+ })
327
+
328
+ # Extract user information from JWT payload
329
+ user_id = payload.get('sub')
330
+ user_name = payload.get('name')
331
+ token_id = payload.get('token_id')
332
+
333
+ if not user_id or not token_id:
334
+ logger.warning(
335
+ 'Invalid token payload: missing user_id or token_id')
336
+ return fastapi.responses.JSONResponse(
337
+ status_code=401,
338
+ content={'detail': 'Invalid token payload'})
339
+
340
+ # Verify user still exists in database
341
+ user_info = global_user_state.get_user(user_id)
342
+ if user_info is None:
343
+ logger.warning(
344
+ f'Service account user {user_id} no longer exists')
345
+ return fastapi.responses.JSONResponse(
346
+ status_code=401,
347
+ content={'detail': 'Service account user no longer exists'})
348
+
349
+ # Update last used timestamp for token tracking
350
+ try:
351
+ global_user_state.update_service_account_token_last_used(
352
+ token_id)
353
+ except Exception as e: # pylint: disable=broad-except
354
+ logger.debug(f'Failed to update token last used time: {e}')
355
+
356
+ # Set the authenticated user
357
+ auth_user = models.User(id=user_id,
358
+ name=user_name or user_info.name)
359
+ request.state.auth_user = auth_user
360
+
361
+ # Override user info in request body for service account requests
362
+ await authn.override_user_info_in_request_body(request, auth_user)
363
+
364
+ logger.debug(f'Authenticated service account: {user_id}')
365
+
366
+ except Exception as e: # pylint: disable=broad-except
367
+ logger.error(f'Service account authentication failed: {e}',
368
+ exc_info=True)
369
+ return fastapi.responses.JSONResponse(
370
+ status_code=401,
371
+ content={
372
+ 'detail': f'Service account authentication failed: {str(e)}'
373
+ })
374
+
375
+ return await call_next(request)
376
+
377
+
378
+ class AuthProxyMiddleware(starlette.middleware.base.BaseHTTPMiddleware):
379
+ """Middleware to handle auth proxy."""
380
+
381
+ async def dispatch(self, request: fastapi.Request, call_next):
382
+ auth_user = _get_auth_user_header(request)
383
+
384
+ if request.state.auth_user is not None:
385
+ # Previous middleware is trusted more than this middleware. For
386
+ # instance, a client could set the Authorization and the
387
+ # X-Auth-Request-Email header. In that case, the auth proxy will be
388
+ # skipped and we should rely on the Bearer token to authenticate the
389
+ # user - but that means the user could set X-Auth-Request-Email to
390
+ # whatever the user wants. We should thus ignore it.
391
+ if auth_user is not None:
392
+ logger.debug('Warning: ignoring auth proxy header since the '
393
+ 'auth user was already set.')
394
+ return await call_next(request)
395
+
396
+ # Add user to database if auth_user is present
397
+ if auth_user is not None:
398
+ newly_added = global_user_state.add_or_update_user(auth_user)
399
+ if newly_added:
400
+ permission.permission_service.add_user_if_not_exists(
401
+ auth_user.id)
402
+
403
+ # Store user info in request.state for access by GET endpoints
404
+ if auth_user is not None:
405
+ request.state.auth_user = auth_user
406
+
407
+ await authn.override_user_info_in_request_body(request, auth_user)
408
+ return await call_next(request)
409
+
410
+
108
411
  # Default expiration time for upload ids before cleanup.
109
412
  _DEFAULT_UPLOAD_EXPIRATION_TIME = datetime.timedelta(hours=1)
110
413
  # Key: (upload_id, user_hash), Value: the time when the upload id needs to be
@@ -134,21 +437,74 @@ async def cleanup_upload_ids():
134
437
  upload_ids_to_cleanup.pop((upload_id, user_hash))
135
438
 
136
439
 
440
+ async def loop_lag_monitor(loop: asyncio.AbstractEventLoop,
441
+ interval: float = 0.1) -> None:
442
+ target = loop.time() + interval
443
+
444
+ pid = str(os.getpid())
445
+ lag_threshold = perf_utils.get_loop_lag_threshold()
446
+
447
+ def tick():
448
+ nonlocal target
449
+ now = loop.time()
450
+ lag = max(0.0, now - target)
451
+ if lag_threshold is not None and lag > lag_threshold:
452
+ logger.warning(f'Event loop lag {lag} seconds exceeds threshold '
453
+ f'{lag_threshold} seconds.')
454
+ metrics_utils.SKY_APISERVER_EVENT_LOOP_LAG_SECONDS.labels(
455
+ pid=pid).observe(lag)
456
+ target = now + interval
457
+ loop.call_at(target, tick)
458
+
459
+ loop.call_at(target, tick)
460
+
461
+
462
+ async def schedule_on_boot_check_async():
463
+ try:
464
+ await executor.schedule_request_async(
465
+ request_id='skypilot-server-on-boot-check',
466
+ request_name=request_names.RequestName.CHECK,
467
+ request_body=payloads.CheckBody(),
468
+ func=sky_check.check,
469
+ schedule_type=requests_lib.ScheduleType.SHORT,
470
+ is_skypilot_system=True,
471
+ )
472
+ except exceptions.RequestAlreadyExistsError:
473
+ # Lifespan will be executed in each uvicorn worker process, we
474
+ # can safely ignore the error if the task is already scheduled.
475
+ logger.debug('Request skypilot-server-on-boot-check already exists.')
476
+
477
+
137
478
  @contextlib.asynccontextmanager
138
479
  async def lifespan(app: fastapi.FastAPI): # pylint: disable=redefined-outer-name
139
480
  """FastAPI lifespan context manager."""
140
481
  del app # unused
141
482
  # Startup: Run background tasks
142
- for event in requests_lib.INTERNAL_REQUEST_DAEMONS:
143
- executor.schedule_request(
144
- request_id=event.id,
145
- request_name=event.name,
146
- request_body=payloads.RequestBody(),
147
- func=event.event_fn,
148
- schedule_type=requests_lib.ScheduleType.SHORT,
149
- is_skypilot_system=True,
150
- )
483
+ for event in daemons.INTERNAL_REQUEST_DAEMONS:
484
+ if event.should_skip():
485
+ continue
486
+ try:
487
+ await executor.schedule_request_async(
488
+ request_id=event.id,
489
+ request_name=event.name,
490
+ request_body=payloads.RequestBody(),
491
+ func=event.run_event,
492
+ schedule_type=requests_lib.ScheduleType.SHORT,
493
+ is_skypilot_system=True,
494
+ # Request deamon should be retried if the process pool is
495
+ # broken.
496
+ retryable=True,
497
+ )
498
+ except exceptions.RequestAlreadyExistsError:
499
+ # Lifespan will be executed in each uvicorn worker process, we
500
+ # can safely ignore the error if the task is already scheduled.
501
+ logger.debug(f'Request {event.id} already exists.')
502
+ await schedule_on_boot_check_async()
151
503
  asyncio.create_task(cleanup_upload_ids())
504
+ if metrics_utils.METRICS_ENABLED:
505
+ # Start monitoring the event loop lag in each server worker
506
+ # event loop (process).
507
+ asyncio.create_task(loop_lag_monitor(asyncio.get_event_loop()))
152
508
  yield
153
509
  # Shutdown: Add any cleanup code here if needed
154
510
 
@@ -166,8 +522,99 @@ class InternalDashboardPrefixMiddleware(
166
522
  return await call_next(request)
167
523
 
168
524
 
525
+ class CacheControlStaticMiddleware(starlette.middleware.base.BaseHTTPMiddleware
526
+ ):
527
+ """Middleware to add cache control headers to static files."""
528
+
529
+ async def dispatch(self, request: fastapi.Request, call_next):
530
+ if request.url.path.startswith('/dashboard/_next'):
531
+ response = await call_next(request)
532
+ response.headers['Cache-Control'] = 'max-age=3600'
533
+ return response
534
+ return await call_next(request)
535
+
536
+
537
+ class PathCleanMiddleware(starlette.middleware.base.BaseHTTPMiddleware):
538
+ """Middleware to check the path of requests."""
539
+
540
+ async def dispatch(self, request: fastapi.Request, call_next):
541
+ if request.url.path.startswith('/dashboard/'):
542
+ # If the requested path is not relative to the expected directory,
543
+ # then the user is attempting path traversal, so deny the request.
544
+ parent = pathlib.Path('/dashboard')
545
+ request_path = pathlib.Path(posixpath.normpath(request.url.path))
546
+ if not _is_relative_to(request_path, parent):
547
+ return fastapi.responses.JSONResponse(
548
+ status_code=403, content={'detail': 'Forbidden'})
549
+ return await call_next(request)
550
+
551
+
552
+ class GracefulShutdownMiddleware(starlette.middleware.base.BaseHTTPMiddleware):
553
+ """Middleware to control requests when server is shutting down."""
554
+
555
+ async def dispatch(self, request: fastapi.Request, call_next):
556
+ if state.get_block_requests():
557
+ # Allow /api/ paths to continue, which are critical to operate
558
+ # on-going requests but will not submit new requests.
559
+ if not request.url.path.startswith('/api/'):
560
+ # Client will retry on 503 error.
561
+ return fastapi.responses.JSONResponse(
562
+ status_code=503,
563
+ content={
564
+ 'detail': 'Server is shutting down, '
565
+ 'please try again later.'
566
+ })
567
+
568
+ return await call_next(request)
569
+
570
+
571
+ class APIVersionMiddleware(starlette.middleware.base.BaseHTTPMiddleware):
572
+ """Middleware to add API version to the request."""
573
+
574
+ async def dispatch(self, request: fastapi.Request, call_next):
575
+ version_info = versions.check_compatibility_at_server(request.headers)
576
+ # Bypass version handling for backward compatibility with clients prior
577
+ # to v0.11.0, the client will check the version in the body of
578
+ # /api/health response and hint an upgrade.
579
+ # TODO(aylei): remove this after v0.13.0 is released.
580
+ if version_info is None:
581
+ return await call_next(request)
582
+ if version_info.error is None:
583
+ versions.set_remote_api_version(version_info.api_version)
584
+ versions.set_remote_version(version_info.version)
585
+ response = await call_next(request)
586
+ else:
587
+ response = fastapi.responses.JSONResponse(
588
+ status_code=400,
589
+ content={
590
+ 'error': common.ApiServerStatus.VERSION_MISMATCH.value,
591
+ 'message': version_info.error,
592
+ })
593
+ response.headers[server_constants.API_VERSION_HEADER] = str(
594
+ server_constants.API_VERSION)
595
+ response.headers[server_constants.VERSION_HEADER] = \
596
+ versions.get_local_readable_version()
597
+ return response
598
+
599
+
169
600
  app = fastapi.FastAPI(prefix='/api/v1', debug=True, lifespan=lifespan)
601
+ # Middleware wraps in the order defined here. E.g., given
602
+ # app.add_middleware(Middleware1)
603
+ # app.add_middleware(Middleware2)
604
+ # app.add_middleware(Middleware3)
605
+ # The effect will be like:
606
+ # Middleware3(Middleware2(Middleware1(request)))
607
+ # If MiddlewareN does something like print(n); call_next(); print(n), you'll get
608
+ # 3; 2; 1; <request>; 1; 2; 3
609
+ # Use environment variable to make the metrics middleware optional.
610
+ if os.environ.get(constants.ENV_VAR_SERVER_METRICS_ENABLED):
611
+ app.add_middleware(metrics.PrometheusMiddleware)
612
+ app.add_middleware(APIVersionMiddleware)
613
+ app.add_middleware(RBACMiddleware)
170
614
  app.add_middleware(InternalDashboardPrefixMiddleware)
615
+ app.add_middleware(GracefulShutdownMiddleware)
616
+ app.add_middleware(PathCleanMiddleware)
617
+ app.add_middleware(CacheControlStaticMiddleware)
171
618
  app.add_middleware(
172
619
  cors.CORSMiddleware,
173
620
  # TODO(zhwu): in production deployment, we should restrict the allowed
@@ -176,20 +623,119 @@ app.add_middleware(
176
623
  allow_credentials=True,
177
624
  allow_methods=['*'],
178
625
  allow_headers=['*'],
179
- # TODO(syang): remove X-Request-ID when v0.10.0 is released.
626
+ # TODO(syang): remove X-Request-ID \when v0.10.0 is released.
180
627
  expose_headers=['X-Request-ID', 'X-Skypilot-Request-ID'])
628
+ # The order of all the authentication-related middleware is important.
629
+ # RBACMiddleware must precede all the auth middleware, so it can access
630
+ # request.state.auth_user.
631
+ app.add_middleware(RBACMiddleware)
632
+ # Authentication based on oauth2-proxy.
633
+ app.add_middleware(oauth2_proxy.OAuth2ProxyMiddleware)
634
+ # AuthProxyMiddleware should precede BasicAuthMiddleware and
635
+ # BearerTokenMiddleware, since it should be skipped if either of those set the
636
+ # auth user.
637
+ app.add_middleware(AuthProxyMiddleware)
638
+ enable_basic_auth = os.environ.get(constants.ENV_VAR_ENABLE_BASIC_AUTH, 'false')
639
+ if str(enable_basic_auth).lower() == 'true':
640
+ app.add_middleware(BasicAuthMiddleware)
641
+ # Bearer token middleware should always be present to handle service account
642
+ # authentication
643
+ app.add_middleware(BearerTokenMiddleware)
644
+ # InitializeRequestAuthUserMiddleware must be the last added middleware so that
645
+ # request.state.auth_user is always set, but can be overridden by the auth
646
+ # middleware above.
647
+ app.add_middleware(InitializeRequestAuthUserMiddleware)
181
648
  app.add_middleware(RequestIDMiddleware)
182
649
  app.include_router(jobs_rest.router, prefix='/jobs', tags=['jobs'])
183
650
  app.include_router(serve_rest.router, prefix='/serve', tags=['serve'])
651
+ app.include_router(users_rest.router, prefix='/users', tags=['users'])
652
+ app.include_router(workspaces_rest.router,
653
+ prefix='/workspaces',
654
+ tags=['workspaces'])
655
+ app.include_router(volumes_rest.router, prefix='/volumes', tags=['volumes'])
656
+ app.include_router(ssh_node_pools_rest.router,
657
+ prefix='/ssh_node_pools',
658
+ tags=['ssh_node_pools'])
659
+ # increase the resource limit for the server
660
+ soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
661
+ resource.setrlimit(resource.RLIMIT_NOFILE, (hard, hard))
662
+
663
+ # Increase the limit of files we can open to our hard limit. This fixes bugs
664
+ # where we can not aquire file locks or open enough logs and the API server
665
+ # crashes. On Mac, the hard limit is 9,223,372,036,854,775,807.
666
+ # TODO(luca) figure out what to do if we need to open more than 2^63 files.
667
+ try:
668
+ soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
669
+ resource.setrlimit(resource.RLIMIT_NOFILE, (hard, hard))
670
+ except Exception: # pylint: disable=broad-except
671
+ pass # no issue, we will warn the user later if its too low
672
+
673
+
674
+ @app.exception_handler(exceptions.ConcurrentWorkerExhaustedError)
675
+ def handle_concurrent_worker_exhausted_error(
676
+ request: fastapi.Request, e: exceptions.ConcurrentWorkerExhaustedError):
677
+ del request # request is not used
678
+ # Print detailed error message to server log
679
+ logger.error('Concurrent worker exhausted: '
680
+ f'{common_utils.format_exception(e)}')
681
+ with ux_utils.enable_traceback():
682
+ logger.error(f' Traceback: {traceback.format_exc()}')
683
+ # Return human readable error message to client
684
+ return fastapi.responses.JSONResponse(
685
+ status_code=503,
686
+ content={
687
+ 'detail':
688
+ ('The server has exhausted its concurrent worker limit. '
689
+ 'Please try again or scale the server if the load persists.')
690
+ })
691
+
692
+
693
+ @app.get('/token')
694
+ async def token(request: fastapi.Request,
695
+ local_port: Optional[int] = None) -> fastapi.responses.Response:
696
+ del local_port # local_port is used by the served js, but ignored by server
697
+ user = _get_auth_user_header(request)
698
+
699
+ token_data = {
700
+ 'v': 1, # Token version number, bump for backwards incompatible.
701
+ 'user': user.id if user is not None else None,
702
+ 'cookies': request.cookies,
703
+ }
704
+ # Use base64 encoding to avoid having to escape anything in the HTML.
705
+ json_bytes = json.dumps(token_data).encode('utf-8')
706
+ base64_str = base64.b64encode(json_bytes).decode('utf-8')
707
+
708
+ html_dir = pathlib.Path(__file__).parent / 'html'
709
+ token_page_path = html_dir / 'token_page.html'
710
+ try:
711
+ with open(token_page_path, 'r', encoding='utf-8') as f:
712
+ html_content = f.read()
713
+ except FileNotFoundError as e:
714
+ raise fastapi.HTTPException(
715
+ status_code=500, detail='Token page template not found.') from e
716
+
717
+ user_info_string = f'Logged in as {user.name}' if user is not None else ''
718
+ html_content = html_content.replace(
719
+ 'SKYPILOT_API_SERVER_USER_TOKEN_PLACEHOLDER',
720
+ base64_str).replace('USER_PLACEHOLDER', user_info_string)
721
+
722
+ return fastapi.responses.HTMLResponse(
723
+ content=html_content,
724
+ headers={
725
+ 'Cache-Control': 'no-cache, no-transform',
726
+ # X-Accel-Buffering: no is useful for preventing buffering issues
727
+ # with some reverse proxies.
728
+ 'X-Accel-Buffering': 'no'
729
+ })
184
730
 
185
731
 
186
732
  @app.post('/check')
187
733
  async def check(request: fastapi.Request,
188
734
  check_body: payloads.CheckBody) -> None:
189
735
  """Checks enabled clouds."""
190
- executor.schedule_request(
736
+ await executor.schedule_request_async(
191
737
  request_id=request.state.request_id,
192
- request_name='check',
738
+ request_name=request_names.RequestName.CHECK,
193
739
  request_body=check_body,
194
740
  func=sky_check.check,
195
741
  schedule_type=requests_lib.ScheduleType.SHORT,
@@ -197,12 +743,15 @@ async def check(request: fastapi.Request,
197
743
 
198
744
 
199
745
  @app.get('/enabled_clouds')
200
- async def enabled_clouds(request: fastapi.Request) -> None:
746
+ async def enabled_clouds(request: fastapi.Request,
747
+ workspace: Optional[str] = None,
748
+ expand: bool = False) -> None:
201
749
  """Gets enabled clouds on the server."""
202
- executor.schedule_request(
750
+ await executor.schedule_request_async(
203
751
  request_id=request.state.request_id,
204
- request_name='enabled_clouds',
205
- request_body=payloads.RequestBody(),
752
+ request_name=request_names.RequestName.ENABLED_CLOUDS,
753
+ request_body=payloads.EnabledCloudsBody(workspace=workspace,
754
+ expand=expand),
206
755
  func=core.enabled_clouds,
207
756
  schedule_type=requests_lib.ScheduleType.SHORT,
208
757
  )
@@ -214,9 +763,10 @@ async def realtime_kubernetes_gpu_availability(
214
763
  realtime_gpu_availability_body: payloads.RealtimeGpuAvailabilityRequestBody
215
764
  ) -> None:
216
765
  """Gets real-time Kubernetes GPU availability."""
217
- executor.schedule_request(
766
+ await executor.schedule_request_async(
218
767
  request_id=request.state.request_id,
219
- request_name='realtime_kubernetes_gpu_availability',
768
+ request_name=request_names.RequestName.
769
+ REALTIME_KUBERNETES_GPU_AVAILABILITY,
220
770
  request_body=realtime_gpu_availability_body,
221
771
  func=core.realtime_kubernetes_gpu_availability,
222
772
  schedule_type=requests_lib.ScheduleType.SHORT,
@@ -229,9 +779,9 @@ async def kubernetes_node_info(
229
779
  kubernetes_node_info_body: payloads.KubernetesNodeInfoRequestBody
230
780
  ) -> None:
231
781
  """Gets Kubernetes nodes information and hints."""
232
- executor.schedule_request(
782
+ await executor.schedule_request_async(
233
783
  request_id=request.state.request_id,
234
- request_name='kubernetes_node_info',
784
+ request_name=request_names.RequestName.KUBERNETES_NODE_INFO,
235
785
  request_body=kubernetes_node_info_body,
236
786
  func=kubernetes_utils.get_kubernetes_node_info,
237
787
  schedule_type=requests_lib.ScheduleType.SHORT,
@@ -241,9 +791,9 @@ async def kubernetes_node_info(
241
791
  @app.get('/status_kubernetes')
242
792
  async def status_kubernetes(request: fastapi.Request) -> None:
243
793
  """Gets Kubernetes status."""
244
- executor.schedule_request(
794
+ await executor.schedule_request_async(
245
795
  request_id=request.state.request_id,
246
- request_name='status_kubernetes',
796
+ request_name=request_names.RequestName.STATUS_KUBERNETES,
247
797
  request_body=payloads.RequestBody(),
248
798
  func=core.status_kubernetes,
249
799
  schedule_type=requests_lib.ScheduleType.SHORT,
@@ -255,11 +805,11 @@ async def list_accelerators(
255
805
  request: fastapi.Request,
256
806
  list_accelerator_counts_body: payloads.ListAcceleratorsBody) -> None:
257
807
  """Gets list of accelerators from cloud catalog."""
258
- executor.schedule_request(
808
+ await executor.schedule_request_async(
259
809
  request_id=request.state.request_id,
260
- request_name='list_accelerators',
810
+ request_name=request_names.RequestName.LIST_ACCELERATORS,
261
811
  request_body=list_accelerator_counts_body,
262
- func=service_catalog.list_accelerators,
812
+ func=catalog.list_accelerators,
263
813
  schedule_type=requests_lib.ScheduleType.SHORT,
264
814
  )
265
815
 
@@ -270,11 +820,11 @@ async def list_accelerator_counts(
270
820
  list_accelerator_counts_body: payloads.ListAcceleratorCountsBody
271
821
  ) -> None:
272
822
  """Gets list of accelerator counts from cloud catalog."""
273
- executor.schedule_request(
823
+ await executor.schedule_request_async(
274
824
  request_id=request.state.request_id,
275
- request_name='list_accelerator_counts',
825
+ request_name=request_names.RequestName.LIST_ACCELERATOR_COUNTS,
276
826
  request_body=list_accelerator_counts_body,
277
- func=service_catalog.list_accelerator_counts,
827
+ func=catalog.list_accelerator_counts,
278
828
  schedule_type=requests_lib.ScheduleType.SHORT,
279
829
  )
280
830
 
@@ -292,25 +842,33 @@ async def validate(validate_body: payloads.ValidateBody) -> None:
292
842
  # pairs.
293
843
  logger.debug(f'Validating tasks: {validate_body.dag}')
294
844
 
845
+ context.initialize()
846
+ ctx = context.get()
847
+ assert ctx is not None
848
+ # TODO(aylei): generalize this to all requests without a db record.
849
+ ctx.override_envs(validate_body.env_vars)
850
+
295
851
  def validate_dag(dag: dag_utils.dag_lib.Dag):
296
852
  # TODO: Admin policy may contain arbitrary code, which may be expensive
297
853
  # to run and may block the server thread. However, moving it into the
298
854
  # executor adds a ~150ms penalty on the local API server because of
299
855
  # added RTTs. For now, we stick to doing the validation inline in the
300
856
  # server thread.
301
- dag, _ = admin_policy_utils.apply(
302
- dag, request_options=validate_body.request_options)
303
- # Skip validating workdir and file_mounts, as those need to be
304
- # validated after the files are uploaded to the SkyPilot API server
305
- # with `upload_mounts_to_api_server`.
306
- dag.validate(skip_file_mounts=True, skip_workdir=True)
857
+ with admin_policy_utils.apply_and_use_config_in_current_request(
858
+ dag,
859
+ request_name=request_names.AdminPolicyRequestName.VALIDATE,
860
+ request_options=validate_body.get_request_options()) as dag:
861
+ dag.resolve_and_validate_volumes()
862
+ # Skip validating workdir and file_mounts, as those need to be
863
+ # validated after the files are uploaded to the SkyPilot API server
864
+ # with `upload_mounts_to_api_server`.
865
+ dag.validate(skip_file_mounts=True, skip_workdir=True)
307
866
 
308
867
  try:
309
868
  dag = dag_utils.load_chain_dag_from_yaml_str(validate_body.dag)
310
- loop = asyncio.get_running_loop()
311
869
  # Apply admin policy and validate DAG is blocking, run it in a separate
312
870
  # thread executor to avoid blocking the uvicorn event loop.
313
- await loop.run_in_executor(None, validate_dag, dag)
871
+ await context_utils.to_thread(validate_dag, dag)
314
872
  except Exception as e: # pylint: disable=broad-except
315
873
  raise fastapi.HTTPException(
316
874
  status_code=400, detail=exceptions.serialize_exception(e)) from e
@@ -320,9 +878,9 @@ async def validate(validate_body: payloads.ValidateBody) -> None:
320
878
  async def optimize(optimize_body: payloads.OptimizeBody,
321
879
  request: fastapi.Request) -> None:
322
880
  """Optimizes the user's DAG."""
323
- executor.schedule_request(
881
+ await executor.schedule_request_async(
324
882
  request_id=request.state.request_id,
325
- request_name='optimize',
883
+ request_name=request_names.RequestName.OPTIMIZE,
326
884
  request_body=optimize_body,
327
885
  ignore_return_value=True,
328
886
  func=core.optimize,
@@ -350,16 +908,30 @@ async def upload_zip_file(request: fastapi.Request, user_hash: str,
350
908
  chunk_index: The chunk index, starting from 0.
351
909
  total_chunks: The total number of chunks.
352
910
  """
911
+ # Field _body would be set if the request body has been received, fail fast
912
+ # to surface potential memory issues, i.e. catch the issue in our smoke
913
+ # test.
914
+ # pylint: disable=protected-access
915
+ if hasattr(request, '_body'):
916
+ raise fastapi.HTTPException(
917
+ status_code=500,
918
+ detail='Upload request body should not be received before streaming'
919
+ )
353
920
  # Add the upload id to the cleanup list.
354
921
  upload_ids_to_cleanup[(upload_id,
355
922
  user_hash)] = (datetime.datetime.now() +
356
923
  _DEFAULT_UPLOAD_EXPIRATION_TIME)
924
+ # For anonymous access, use the user hash from client
925
+ user_id = user_hash
926
+ if request.state.auth_user is not None:
927
+ # Otherwise, the authenticated identity should be used.
928
+ user_id = request.state.auth_user.id
357
929
 
358
930
  # TODO(SKY-1271): We need to double check security of uploading zip file.
359
931
  client_file_mounts_dir = (
360
- common.API_SERVER_CLIENT_DIR.expanduser().resolve() / user_hash /
932
+ common.API_SERVER_CLIENT_DIR.expanduser().resolve() / user_id /
361
933
  'file_mounts')
362
- client_file_mounts_dir.mkdir(parents=True, exist_ok=True)
934
+ await anyio.Path(client_file_mounts_dir).mkdir(parents=True, exist_ok=True)
363
935
 
364
936
  # Check upload_id to be a valid SkyPilot run_timestamp appended with 8 hex
365
937
  # characters, e.g. 'sky-2025-01-17-09-10-13-933602-35d31c22'.
@@ -382,7 +954,7 @@ async def upload_zip_file(request: fastapi.Request, user_hash: str,
382
954
  zip_file_path = client_file_mounts_dir / f'{upload_id}.zip'
383
955
  else:
384
956
  chunk_dir = client_file_mounts_dir / upload_id
385
- chunk_dir.mkdir(parents=True, exist_ok=True)
957
+ await anyio.Path(chunk_dir).mkdir(parents=True, exist_ok=True)
386
958
  zip_file_path = chunk_dir / f'part{chunk_index}.incomplete'
387
959
 
388
960
  try:
@@ -412,8 +984,9 @@ async def upload_zip_file(request: fastapi.Request, user_hash: str,
412
984
  zip_file_path.rename(zip_file_path.with_suffix(''))
413
985
  missing_chunks = get_missing_chunks(total_chunks)
414
986
  if missing_chunks:
415
- return payloads.UploadZipFileResponse(status='uploading',
416
- missing_chunks=missing_chunks)
987
+ return payloads.UploadZipFileResponse(
988
+ status=responses.UploadStatus.UPLOADING.value,
989
+ missing_chunks=missing_chunks)
417
990
  zip_file_path = client_file_mounts_dir / f'{upload_id}.zip'
418
991
  async with aiofiles.open(zip_file_path, 'wb') as zip_file:
419
992
  for chunk in range(total_chunks):
@@ -427,10 +1000,11 @@ async def upload_zip_file(request: fastapi.Request, user_hash: str,
427
1000
  await zip_file.write(data)
428
1001
 
429
1002
  logger.info(f'Uploaded zip file: {zip_file_path}')
430
- unzip_file(zip_file_path, client_file_mounts_dir)
1003
+ await unzip_file(zip_file_path, client_file_mounts_dir)
431
1004
  if total_chunks > 1:
432
- shutil.rmtree(chunk_dir)
433
- return payloads.UploadZipFileResponse(status='completed')
1005
+ await context_utils.to_thread(shutil.rmtree, chunk_dir)
1006
+ return payloads.UploadZipFileResponse(
1007
+ status=responses.UploadStatus.COMPLETED.value)
434
1008
 
435
1009
 
436
1010
  def _is_relative_to(path: pathlib.Path, parent: pathlib.Path) -> bool:
@@ -443,61 +1017,69 @@ def _is_relative_to(path: pathlib.Path, parent: pathlib.Path) -> bool:
443
1017
  return False
444
1018
 
445
1019
 
446
- def unzip_file(zip_file_path: pathlib.Path,
447
- client_file_mounts_dir: pathlib.Path) -> None:
448
- """Unzips a zip file."""
449
- try:
450
- with zipfile.ZipFile(zip_file_path, 'r') as zipf:
451
- for member in zipf.infolist():
452
- # Determine the new path
453
- original_path = os.path.normpath(member.filename)
454
- new_path = client_file_mounts_dir / original_path.lstrip('/')
455
-
456
- if (member.external_attr >> 28) == 0xA:
457
- # Symlink. Read the target path and create a symlink.
1020
+ async def unzip_file(zip_file_path: pathlib.Path,
1021
+ client_file_mounts_dir: pathlib.Path) -> None:
1022
+ """Unzips a zip file without blocking the event loop."""
1023
+
1024
+ def _do_unzip() -> None:
1025
+ try:
1026
+ with zipfile.ZipFile(zip_file_path, 'r') as zipf:
1027
+ for member in zipf.infolist():
1028
+ # Determine the new path
1029
+ original_path = os.path.normpath(member.filename)
1030
+ new_path = client_file_mounts_dir / original_path.lstrip(
1031
+ '/')
1032
+
1033
+ if (member.external_attr >> 28) == 0xA:
1034
+ # Symlink. Read the target path and create a symlink.
1035
+ new_path.parent.mkdir(parents=True, exist_ok=True)
1036
+ target = zipf.read(member).decode()
1037
+ assert not os.path.isabs(target), target
1038
+ # Since target is a relative path, we need to check that
1039
+ # it is under `client_file_mounts_dir` for security.
1040
+ full_target_path = (new_path.parent / target).resolve()
1041
+ if not _is_relative_to(full_target_path,
1042
+ client_file_mounts_dir):
1043
+ raise ValueError(
1044
+ f'Symlink target {target} leads to a '
1045
+ 'file not in userspace. Aborted.')
1046
+
1047
+ if new_path.exists() or new_path.is_symlink():
1048
+ new_path.unlink(missing_ok=True)
1049
+ new_path.symlink_to(
1050
+ target,
1051
+ target_is_directory=member.filename.endswith('/'))
1052
+ continue
1053
+
1054
+ # Handle directories
1055
+ if member.filename.endswith('/'):
1056
+ new_path.mkdir(parents=True, exist_ok=True)
1057
+ continue
1058
+
1059
+ # Handle files
458
1060
  new_path.parent.mkdir(parents=True, exist_ok=True)
459
- target = zipf.read(member).decode()
460
- assert not os.path.isabs(target), target
461
- # Since target is a relative path, we need to check that it
462
- # is under `client_file_mounts_dir` for security.
463
- full_target_path = (new_path.parent / target).resolve()
464
- if not _is_relative_to(full_target_path,
465
- client_file_mounts_dir):
466
- raise ValueError(f'Symlink target {target} leads to a '
467
- 'file not in userspace. Aborted.')
468
-
469
- if new_path.exists() or new_path.is_symlink():
470
- new_path.unlink(missing_ok=True)
471
- new_path.symlink_to(
472
- target,
473
- target_is_directory=member.filename.endswith('/'))
474
- continue
475
-
476
- # Handle directories
477
- if member.filename.endswith('/'):
478
- new_path.mkdir(parents=True, exist_ok=True)
479
- continue
480
-
481
- # Handle files
482
- new_path.parent.mkdir(parents=True, exist_ok=True)
483
- with zipf.open(member) as member_file, new_path.open('wb') as f:
484
- # Use shutil.copyfileobj to copy files in chunks, so it does
485
- # not load the entire file into memory.
486
- shutil.copyfileobj(member_file, f)
487
- except zipfile.BadZipFile as e:
488
- logger.error(f'Bad zip file: {zip_file_path}')
489
- raise fastapi.HTTPException(
490
- status_code=400,
491
- detail=f'Invalid zip file: {common_utils.format_exception(e)}')
492
- except Exception as e:
493
- logger.error(f'Error unzipping file: {zip_file_path}')
494
- raise fastapi.HTTPException(
495
- status_code=500,
496
- detail=(f'Error unzipping file: '
497
- f'{common_utils.format_exception(e)}'))
1061
+ with zipf.open(member) as member_file, new_path.open(
1062
+ 'wb') as f:
1063
+ # Use shutil.copyfileobj to copy files in chunks,
1064
+ # so it does not load the entire file into memory.
1065
+ shutil.copyfileobj(member_file, f)
1066
+ except zipfile.BadZipFile as e:
1067
+ logger.error(f'Bad zip file: {zip_file_path}')
1068
+ raise fastapi.HTTPException(
1069
+ status_code=400,
1070
+ detail=f'Invalid zip file: {common_utils.format_exception(e)}')
1071
+ except Exception as e:
1072
+ logger.error(f'Error unzipping file: {zip_file_path}')
1073
+ raise fastapi.HTTPException(
1074
+ status_code=500,
1075
+ detail=(f'Error unzipping file: '
1076
+ f'{common_utils.format_exception(e)}'))
1077
+ finally:
1078
+ # Cleanup the temporary file regardless of
1079
+ # success/failure handling above
1080
+ zip_file_path.unlink(missing_ok=True)
498
1081
 
499
- # Cleanup the temporary file
500
- zip_file_path.unlink()
1082
+ await context_utils.to_thread(_do_unzip)
501
1083
 
502
1084
 
503
1085
  @app.post('/launch')
@@ -506,13 +1088,14 @@ async def launch(launch_body: payloads.LaunchBody,
506
1088
  """Launches a cluster or task."""
507
1089
  request_id = request.state.request_id
508
1090
  logger.info(f'Launching request: {request_id}')
509
- executor.schedule_request(
1091
+ await executor.schedule_request_async(
510
1092
  request_id,
511
- request_name='launch',
1093
+ request_name=request_names.RequestName.CLUSTER_LAUNCH,
512
1094
  request_body=launch_body,
513
1095
  func=execution.launch,
514
1096
  schedule_type=requests_lib.ScheduleType.LONG,
515
1097
  request_cluster_name=launch_body.cluster_name,
1098
+ retryable=launch_body.retry_until_up,
516
1099
  )
517
1100
 
518
1101
 
@@ -521,9 +1104,9 @@ async def launch(launch_body: payloads.LaunchBody,
521
1104
  async def exec(request: fastapi.Request, exec_body: payloads.ExecBody) -> None:
522
1105
  """Executes a task on an existing cluster."""
523
1106
  cluster_name = exec_body.cluster_name
524
- executor.schedule_request(
1107
+ await executor.schedule_request_async(
525
1108
  request_id=request.state.request_id,
526
- request_name='exec',
1109
+ request_name=request_names.RequestName.CLUSTER_EXEC,
527
1110
  request_body=exec_body,
528
1111
  func=execution.exec,
529
1112
  precondition=preconditions.ClusterStartCompletePrecondition(
@@ -539,9 +1122,9 @@ async def exec(request: fastapi.Request, exec_body: payloads.ExecBody) -> None:
539
1122
  async def stop(request: fastapi.Request,
540
1123
  stop_body: payloads.StopOrDownBody) -> None:
541
1124
  """Stops a cluster."""
542
- executor.schedule_request(
1125
+ await executor.schedule_request_async(
543
1126
  request_id=request.state.request_id,
544
- request_name='stop',
1127
+ request_name=request_names.RequestName.CLUSTER_STOP,
545
1128
  request_body=stop_body,
546
1129
  func=core.stop,
547
1130
  schedule_type=requests_lib.ScheduleType.SHORT,
@@ -555,9 +1138,13 @@ async def status(
555
1138
  status_body: payloads.StatusBody = payloads.StatusBody()
556
1139
  ) -> None:
557
1140
  """Gets cluster statuses."""
558
- executor.schedule_request(
1141
+ if state.get_block_requests():
1142
+ raise fastapi.HTTPException(
1143
+ status_code=503,
1144
+ detail='Server is shutting down, please try again later.')
1145
+ await executor.schedule_request_async(
559
1146
  request_id=request.state.request_id,
560
- request_name='status',
1147
+ request_name=request_names.RequestName.CLUSTER_STATUS,
561
1148
  request_body=status_body,
562
1149
  func=core.status,
563
1150
  schedule_type=(requests_lib.ScheduleType.LONG if
@@ -570,9 +1157,9 @@ async def status(
570
1157
  async def endpoints(request: fastapi.Request,
571
1158
  endpoint_body: payloads.EndpointsBody) -> None:
572
1159
  """Gets the endpoint for a given cluster and port number (endpoint)."""
573
- executor.schedule_request(
1160
+ await executor.schedule_request_async(
574
1161
  request_id=request.state.request_id,
575
- request_name='endpoints',
1162
+ request_name=request_names.RequestName.CLUSTER_ENDPOINTS,
576
1163
  request_body=endpoint_body,
577
1164
  func=core.endpoints,
578
1165
  schedule_type=requests_lib.ScheduleType.SHORT,
@@ -584,9 +1171,9 @@ async def endpoints(request: fastapi.Request,
584
1171
  async def down(request: fastapi.Request,
585
1172
  down_body: payloads.StopOrDownBody) -> None:
586
1173
  """Tears down a cluster."""
587
- executor.schedule_request(
1174
+ await executor.schedule_request_async(
588
1175
  request_id=request.state.request_id,
589
- request_name='down',
1176
+ request_name=request_names.RequestName.CLUSTER_DOWN,
590
1177
  request_body=down_body,
591
1178
  func=core.down,
592
1179
  schedule_type=requests_lib.ScheduleType.SHORT,
@@ -598,9 +1185,9 @@ async def down(request: fastapi.Request,
598
1185
  async def start(request: fastapi.Request,
599
1186
  start_body: payloads.StartBody) -> None:
600
1187
  """Restarts a cluster."""
601
- executor.schedule_request(
1188
+ await executor.schedule_request_async(
602
1189
  request_id=request.state.request_id,
603
- request_name='start',
1190
+ request_name=request_names.RequestName.CLUSTER_START,
604
1191
  request_body=start_body,
605
1192
  func=core.start,
606
1193
  schedule_type=requests_lib.ScheduleType.LONG,
@@ -612,9 +1199,9 @@ async def start(request: fastapi.Request,
612
1199
  async def autostop(request: fastapi.Request,
613
1200
  autostop_body: payloads.AutostopBody) -> None:
614
1201
  """Schedules an autostop/autodown for a cluster."""
615
- executor.schedule_request(
1202
+ await executor.schedule_request_async(
616
1203
  request_id=request.state.request_id,
617
- request_name='autostop',
1204
+ request_name=request_names.RequestName.CLUSTER_AUTOSTOP,
618
1205
  request_body=autostop_body,
619
1206
  func=core.autostop,
620
1207
  schedule_type=requests_lib.ScheduleType.SHORT,
@@ -626,9 +1213,9 @@ async def autostop(request: fastapi.Request,
626
1213
  async def queue(request: fastapi.Request,
627
1214
  queue_body: payloads.QueueBody) -> None:
628
1215
  """Gets the job queue of a cluster."""
629
- executor.schedule_request(
1216
+ await executor.schedule_request_async(
630
1217
  request_id=request.state.request_id,
631
- request_name='queue',
1218
+ request_name=request_names.RequestName.CLUSTER_QUEUE,
632
1219
  request_body=queue_body,
633
1220
  func=core.queue,
634
1221
  schedule_type=requests_lib.ScheduleType.SHORT,
@@ -640,9 +1227,9 @@ async def queue(request: fastapi.Request,
640
1227
  async def job_status(request: fastapi.Request,
641
1228
  job_status_body: payloads.JobStatusBody) -> None:
642
1229
  """Gets the status of a job."""
643
- executor.schedule_request(
1230
+ await executor.schedule_request_async(
644
1231
  request_id=request.state.request_id,
645
- request_name='job_status',
1232
+ request_name=request_names.RequestName.CLUSTER_JOB_STATUS,
646
1233
  request_body=job_status_body,
647
1234
  func=core.job_status,
648
1235
  schedule_type=requests_lib.ScheduleType.SHORT,
@@ -654,9 +1241,9 @@ async def job_status(request: fastapi.Request,
654
1241
  async def cancel(request: fastapi.Request,
655
1242
  cancel_body: payloads.CancelBody) -> None:
656
1243
  """Cancels jobs on a cluster."""
657
- executor.schedule_request(
1244
+ await executor.schedule_request_async(
658
1245
  request_id=request.state.request_id,
659
- request_name='cancel',
1246
+ request_name=request_names.RequestName.CLUSTER_JOB_CANCEL,
660
1247
  request_body=cancel_body,
661
1248
  func=core.cancel,
662
1249
  schedule_type=requests_lib.ScheduleType.SHORT,
@@ -673,36 +1260,27 @@ async def logs(
673
1260
  # TODO(zhwu): This should wait for the request on the cluster, e.g., async
674
1261
  # launch, to finish, so that a user does not need to manually pull the
675
1262
  # request status.
676
- executor.schedule_request(
1263
+ executor.check_request_thread_executor_available()
1264
+ request_task = await executor.prepare_request_async(
677
1265
  request_id=request.state.request_id,
678
- request_name='logs',
1266
+ request_name=request_names.RequestName.CLUSTER_JOB_LOGS,
679
1267
  request_body=cluster_job_body,
680
1268
  func=core.tail_logs,
681
- # TODO(aylei): We have tail logs scheduled as SHORT request, because it
682
- # should be responsive. However, it can be long running if the user's
683
- # job keeps running, and we should avoid it taking the SHORT worker.
684
1269
  schedule_type=requests_lib.ScheduleType.SHORT,
685
1270
  request_cluster_name=cluster_job_body.cluster_name,
686
1271
  )
687
-
688
- request_task = requests_lib.get_request(request.state.request_id)
689
-
1272
+ task = executor.execute_request_in_coroutine(request_task)
1273
+ background_tasks.add_task(task.cancel)
690
1274
  # TODO(zhwu): This makes viewing logs in browser impossible. We should adopt
691
1275
  # the same approach as /stream.
692
- return stream_utils.stream_response(
693
- request_id=request_task.request_id,
1276
+ return stream_utils.stream_response_for_long_request(
1277
+ request_id=request.state.request_id,
694
1278
  logs_path=request_task.log_path,
695
1279
  background_tasks=background_tasks,
1280
+ kill_request_on_disconnect=False,
696
1281
  )
697
1282
 
698
1283
 
699
- @app.get('/users')
700
- async def users() -> List[Dict[str, Any]]:
701
- """Gets all users."""
702
- user_list = global_user_state.get_all_users()
703
- return [user.to_dict() for user in user_list]
704
-
705
-
706
1284
  @app.post('/download_logs')
707
1285
  async def download_logs(
708
1286
  request: fastapi.Request,
@@ -714,9 +1292,9 @@ async def download_logs(
714
1292
  # We should reuse the original request body, so that the env vars, such as
715
1293
  # user hash, are kept the same.
716
1294
  cluster_jobs_body.local_dir = str(logs_dir_on_api_server)
717
- executor.schedule_request(
1295
+ await executor.schedule_request_async(
718
1296
  request_id=request.state.request_id,
719
- request_name='download_logs',
1297
+ request_name=request_names.RequestName.CLUSTER_JOB_DOWNLOAD_LOGS,
720
1298
  request_body=cluster_jobs_body,
721
1299
  func=core.download_logs,
722
1300
  schedule_type=requests_lib.ScheduleType.SHORT,
@@ -725,7 +1303,8 @@ async def download_logs(
725
1303
 
726
1304
 
727
1305
  @app.post('/download')
728
- async def download(download_body: payloads.DownloadBody) -> None:
1306
+ async def download(download_body: payloads.DownloadBody,
1307
+ request: fastapi.Request) -> None:
729
1308
  """Downloads a folder from the cluster to the local machine."""
730
1309
  folder_paths = [
731
1310
  pathlib.Path(folder_path) for folder_path in download_body.folder_paths
@@ -750,11 +1329,25 @@ async def download(download_body: payloads.DownloadBody) -> None:
750
1329
  logs_dir_on_api_server).expanduser().resolve() / zip_filename
751
1330
 
752
1331
  try:
753
- folders = [
754
- str(folder_path.expanduser().resolve())
755
- for folder_path in folder_paths
756
- ]
757
- storage_utils.zip_files_and_folders(folders, zip_path)
1332
+
1333
+ def _zip_files_and_folders(folder_paths, zip_path):
1334
+ folders = [
1335
+ str(folder_path.expanduser().resolve())
1336
+ for folder_path in folder_paths
1337
+ ]
1338
+ # Check for optional query parameter to control zip entry structure
1339
+ relative = request.query_params.get('relative', 'home')
1340
+ if relative == 'items':
1341
+ # Dashboard-friendly: entries relative to selected folders
1342
+ storage_utils.zip_files_and_folders(folders,
1343
+ zip_path,
1344
+ relative_to_items=True)
1345
+ else:
1346
+ # CLI-friendly (default): entries with full paths for mapping
1347
+ storage_utils.zip_files_and_folders(folders, zip_path)
1348
+
1349
+ await context_utils.to_thread(_zip_files_and_folders, folder_paths,
1350
+ zip_path)
758
1351
 
759
1352
  # Add home path to the response headers, so that the client can replace
760
1353
  # the remote path in the zip file to the local path.
@@ -776,13 +1369,84 @@ async def download(download_body: payloads.DownloadBody) -> None:
776
1369
  detail=f'Error creating zip file: {str(e)}')
777
1370
 
778
1371
 
779
- @app.get('/cost_report')
780
- async def cost_report(request: fastapi.Request) -> None:
1372
+ # TODO(aylei): run it asynchronously after global_user_state support async op
1373
+ @app.post('/provision_logs')
1374
+ def provision_logs(provision_logs_body: payloads.ProvisionLogsBody,
1375
+ follow: bool = True,
1376
+ tail: int = 0) -> fastapi.responses.StreamingResponse:
1377
+ """Streams the provision.log for the latest launch request of a cluster."""
1378
+ log_path = None
1379
+ cluster_name = provision_logs_body.cluster_name
1380
+ worker = provision_logs_body.worker
1381
+ # stream head node logs
1382
+ if worker is None:
1383
+ # Prefer clusters table first, then cluster_history as fallback.
1384
+ log_path_str = global_user_state.get_cluster_provision_log_path(
1385
+ cluster_name)
1386
+ if not log_path_str:
1387
+ log_path_str = (
1388
+ global_user_state.get_cluster_history_provision_log_path(
1389
+ cluster_name))
1390
+ if not log_path_str:
1391
+ raise fastapi.HTTPException(
1392
+ status_code=404,
1393
+ detail=('Provision log path is not recorded for this cluster. '
1394
+ 'Please relaunch to generate provisioning logs.'))
1395
+ log_path = pathlib.Path(log_path_str).expanduser().resolve()
1396
+ if not log_path.exists():
1397
+ raise fastapi.HTTPException(
1398
+ status_code=404,
1399
+ detail=f'Provision log path does not exist: {str(log_path)}')
1400
+
1401
+ # stream worker node logs
1402
+ else:
1403
+ handle = global_user_state.get_handle_from_cluster_name(cluster_name)
1404
+ if handle is None:
1405
+ raise fastapi.HTTPException(
1406
+ status_code=404,
1407
+ detail=('Cluster handle is not recorded for this cluster. '
1408
+ 'Please relaunch to generate provisioning logs.'))
1409
+ # instance_ids includes head node
1410
+ instance_ids = handle.instance_ids
1411
+ if instance_ids is None:
1412
+ raise fastapi.HTTPException(
1413
+ status_code=400,
1414
+ detail='Instance IDs are not recorded for this cluster. '
1415
+ 'Please relaunch to generate provisioning logs.')
1416
+ if worker > len(instance_ids) - 1:
1417
+ raise fastapi.HTTPException(
1418
+ status_code=400,
1419
+ detail=f'Worker {worker} is out of range. '
1420
+ f'The cluster has {len(instance_ids)} nodes.')
1421
+ log_path = metadata_utils.get_instance_log_dir(
1422
+ handle.get_cluster_name_on_cloud(), instance_ids[worker])
1423
+
1424
+ # Tail semantics: 0 means print all lines. Convert 0 -> None for streamer.
1425
+ effective_tail = None if tail is None or tail <= 0 else tail
1426
+
1427
+ return fastapi.responses.StreamingResponse(
1428
+ content=stream_utils.log_streamer(None,
1429
+ log_path,
1430
+ tail=effective_tail,
1431
+ follow=follow,
1432
+ cluster_name=cluster_name),
1433
+ media_type='text/plain',
1434
+ headers={
1435
+ 'Cache-Control': 'no-cache, no-transform',
1436
+ 'X-Accel-Buffering': 'no',
1437
+ 'Transfer-Encoding': 'chunked',
1438
+ },
1439
+ )
1440
+
1441
+
1442
+ @app.post('/cost_report')
1443
+ async def cost_report(request: fastapi.Request,
1444
+ cost_report_body: payloads.CostReportBody) -> None:
781
1445
  """Gets the cost report of a cluster."""
782
- executor.schedule_request(
1446
+ await executor.schedule_request_async(
783
1447
  request_id=request.state.request_id,
784
- request_name='cost_report',
785
- request_body=payloads.RequestBody(),
1448
+ request_name=request_names.RequestName.CLUSTER_COST_REPORT,
1449
+ request_body=cost_report_body,
786
1450
  func=core.cost_report,
787
1451
  schedule_type=requests_lib.ScheduleType.SHORT,
788
1452
  )
@@ -791,9 +1455,9 @@ async def cost_report(request: fastapi.Request) -> None:
791
1455
  @app.get('/storage/ls')
792
1456
  async def storage_ls(request: fastapi.Request) -> None:
793
1457
  """Gets the storages."""
794
- executor.schedule_request(
1458
+ await executor.schedule_request_async(
795
1459
  request_id=request.state.request_id,
796
- request_name='storage_ls',
1460
+ request_name=request_names.RequestName.STORAGE_LS,
797
1461
  request_body=payloads.RequestBody(),
798
1462
  func=core.storage_ls,
799
1463
  schedule_type=requests_lib.ScheduleType.SHORT,
@@ -804,9 +1468,9 @@ async def storage_ls(request: fastapi.Request) -> None:
804
1468
  async def storage_delete(request: fastapi.Request,
805
1469
  storage_body: payloads.StorageBody) -> None:
806
1470
  """Deletes a storage."""
807
- executor.schedule_request(
1471
+ await executor.schedule_request_async(
808
1472
  request_id=request.state.request_id,
809
- request_name='storage_delete',
1473
+ request_name=request_names.RequestName.STORAGE_DELETE,
810
1474
  request_body=storage_body,
811
1475
  func=core.storage_delete,
812
1476
  schedule_type=requests_lib.ScheduleType.LONG,
@@ -817,9 +1481,9 @@ async def storage_delete(request: fastapi.Request,
817
1481
  async def local_up(request: fastapi.Request,
818
1482
  local_up_body: payloads.LocalUpBody) -> None:
819
1483
  """Launches a Kubernetes cluster on API server."""
820
- executor.schedule_request(
1484
+ await executor.schedule_request_async(
821
1485
  request_id=request.state.request_id,
822
- request_name='local_up',
1486
+ request_name=request_names.RequestName.LOCAL_UP,
823
1487
  request_body=local_up_body,
824
1488
  func=core.local_up,
825
1489
  schedule_type=requests_lib.ScheduleType.LONG,
@@ -827,37 +1491,65 @@ async def local_up(request: fastapi.Request,
827
1491
 
828
1492
 
829
1493
  @app.post('/local_down')
830
- async def local_down(request: fastapi.Request) -> None:
1494
+ async def local_down(request: fastapi.Request,
1495
+ local_down_body: payloads.LocalDownBody) -> None:
831
1496
  """Tears down the Kubernetes cluster started by local_up."""
832
- executor.schedule_request(
1497
+ await executor.schedule_request_async(
833
1498
  request_id=request.state.request_id,
834
- request_name='local_down',
835
- request_body=payloads.RequestBody(),
1499
+ request_name=request_names.RequestName.LOCAL_DOWN,
1500
+ request_body=local_down_body,
836
1501
  func=core.local_down,
837
1502
  schedule_type=requests_lib.ScheduleType.LONG,
838
1503
  )
839
1504
 
840
1505
 
1506
+ async def get_expanded_request_id(request_id: str) -> str:
1507
+ """Gets the expanded request ID for a given request ID prefix."""
1508
+ request_tasks = await requests_lib.get_requests_async_with_prefix(
1509
+ request_id, fields=['request_id'])
1510
+ if request_tasks is None:
1511
+ raise fastapi.HTTPException(status_code=404,
1512
+ detail=f'Request {request_id!r} not found')
1513
+ if len(request_tasks) > 1:
1514
+ raise fastapi.HTTPException(status_code=400,
1515
+ detail=('Multiple requests found for '
1516
+ f'request ID prefix: {request_id}'))
1517
+ return request_tasks[0].request_id
1518
+
1519
+
841
1520
  # === API server related APIs ===
842
- @app.get('/api/get')
843
- async def api_get(request_id: str) -> requests_lib.RequestPayload:
1521
+ @app.get('/api/get', response_class=fastapi_responses.ORJSONResponse)
1522
+ async def api_get(request_id: str) -> payloads.RequestPayload:
844
1523
  """Gets a request with a given request ID prefix."""
1524
+ # Validate request_id prefix matches a single request.
1525
+ request_id = await get_expanded_request_id(request_id)
1526
+
845
1527
  while True:
846
- request_task = requests_lib.get_request(request_id)
847
- if request_task is None:
1528
+ req_status = await requests_lib.get_request_status_async(request_id)
1529
+ if req_status is None:
848
1530
  print(f'No task with request ID {request_id}', flush=True)
849
1531
  raise fastapi.HTTPException(
850
1532
  status_code=404, detail=f'Request {request_id!r} not found')
851
- if request_task.status > requests_lib.RequestStatus.RUNNING:
852
- request_error = request_task.get_error()
853
- if request_error is not None:
854
- raise fastapi.HTTPException(status_code=500,
855
- detail=dataclasses.asdict(
856
- request_task.encode()))
857
- return request_task.encode()
1533
+ if (req_status.status == requests_lib.RequestStatus.RUNNING and
1534
+ daemons.is_daemon_request_id(request_id)):
1535
+ # Daemon requests run forever, break without waiting for complete.
1536
+ break
1537
+ if req_status.status > requests_lib.RequestStatus.RUNNING:
1538
+ break
858
1539
  # yield control to allow other coroutines to run, sleep shortly
859
1540
  # to avoid storming the DB and CPU in the meantime
860
1541
  await asyncio.sleep(0.1)
1542
+ request_task = await requests_lib.get_request_async(request_id)
1543
+ # TODO(aylei): refine this, /api/get will not be retried and this is
1544
+ # meaningless to retry. It is the original request that should be retried.
1545
+ if request_task.should_retry:
1546
+ raise fastapi.HTTPException(
1547
+ status_code=503, detail=f'Request {request_id!r} should be retried')
1548
+ request_error = request_task.get_error()
1549
+ if request_error is not None:
1550
+ raise fastapi.HTTPException(status_code=500,
1551
+ detail=request_task.encode().model_dump())
1552
+ return request_task.encode()
861
1553
 
862
1554
 
863
1555
  @app.get('/api/stream')
@@ -891,13 +1583,18 @@ async def stream(
891
1583
  clients, console for CLI/API clients), 'plain' (force plain text),
892
1584
  'html' (force HTML), or 'console' (force console)
893
1585
  """
1586
+ # We need to save the user-supplied request ID for the response header.
1587
+ user_supplied_request_id = request_id
894
1588
  if request_id is not None and log_path is not None:
895
1589
  raise fastapi.HTTPException(
896
1590
  status_code=400,
897
1591
  detail='Only one of request_id and log_path can be provided')
898
1592
 
1593
+ if request_id is not None:
1594
+ request_id = await get_expanded_request_id(request_id)
1595
+
899
1596
  if request_id is None and log_path is None:
900
- request_id = requests_lib.get_latest_request_id()
1597
+ request_id = await requests_lib.get_latest_request_id_async()
901
1598
  if request_id is None:
902
1599
  raise fastapi.HTTPException(status_code=404,
903
1600
  detail='No request found')
@@ -924,19 +1621,40 @@ async def stream(
924
1621
  'X-Accel-Buffering': 'no'
925
1622
  })
926
1623
 
1624
+ polling_interval = stream_utils.DEFAULT_POLL_INTERVAL
927
1625
  # Original plain text streaming logic
928
1626
  if request_id is not None:
929
- request_task = requests_lib.get_request(request_id)
1627
+ request_task = await requests_lib.get_request_async(
1628
+ request_id, fields=['request_id', 'schedule_type'])
930
1629
  if request_task is None:
931
1630
  print(f'No task with request ID {request_id}')
932
1631
  raise fastapi.HTTPException(
933
1632
  status_code=404, detail=f'Request {request_id!r} not found')
1633
+ # req.log_path is derived from request_id,
1634
+ # so it's ok to just grab the request_id in the above query.
934
1635
  log_path_to_stream = request_task.log_path
1636
+ if not log_path_to_stream.exists():
1637
+ # The log file might be deleted by the request GC daemon but the
1638
+ # request task is still in the database.
1639
+ raise fastapi.HTTPException(
1640
+ status_code=404,
1641
+ detail=f'Log of request {request_id!r} has been deleted')
1642
+ if request_task.schedule_type == requests_lib.ScheduleType.LONG:
1643
+ polling_interval = stream_utils.LONG_REQUEST_POLL_INTERVAL
1644
+ del request_task
935
1645
  else:
936
1646
  assert log_path is not None, (request_id, log_path)
937
1647
  if log_path == constants.API_SERVER_LOGS:
938
1648
  resolved_log_path = pathlib.Path(
939
1649
  constants.API_SERVER_LOGS).expanduser()
1650
+ if not resolved_log_path.exists():
1651
+ raise fastapi.HTTPException(
1652
+ status_code=404,
1653
+ detail='Server log file does not exist. The API server may '
1654
+ 'have been started with `--foreground` - check the '
1655
+ 'stdout of API server process, such as: '
1656
+ '`kubectl logs -n api-server-namespace '
1657
+ 'api-server-pod-name`')
940
1658
  else:
941
1659
  # This should be a log path under ~/sky_logs.
942
1660
  resolved_logs_directory = pathlib.Path(
@@ -957,18 +1675,26 @@ async def stream(
957
1675
  detail=f'Log path {log_path!r} does not exist')
958
1676
 
959
1677
  log_path_to_stream = resolved_log_path
1678
+
1679
+ headers = {
1680
+ 'Cache-Control': 'no-cache, no-transform',
1681
+ 'X-Accel-Buffering': 'no',
1682
+ 'Transfer-Encoding': 'chunked'
1683
+ }
1684
+ if request_id is not None:
1685
+ headers[server_constants.STREAM_REQUEST_HEADER] = (
1686
+ user_supplied_request_id
1687
+ if user_supplied_request_id else request_id)
1688
+
960
1689
  return fastapi.responses.StreamingResponse(
961
1690
  content=stream_utils.log_streamer(request_id,
962
1691
  log_path_to_stream,
963
1692
  plain_logs=format == 'plain',
964
1693
  tail=tail,
965
- follow=follow),
1694
+ follow=follow,
1695
+ polling_interval=polling_interval),
966
1696
  media_type='text/plain',
967
- headers={
968
- 'Cache-Control': 'no-cache, no-transform',
969
- 'X-Accel-Buffering': 'no',
970
- 'Transfer-Encoding': 'chunked'
971
- },
1697
+ headers=headers,
972
1698
  )
973
1699
 
974
1700
 
@@ -976,11 +1702,11 @@ async def stream(
976
1702
  async def api_cancel(request: fastapi.Request,
977
1703
  request_cancel_body: payloads.RequestCancelBody) -> None:
978
1704
  """Cancels requests."""
979
- executor.schedule_request(
1705
+ await executor.schedule_request_async(
980
1706
  request_id=request.state.request_id,
981
- request_name='api_cancel',
1707
+ request_name=request_names.RequestName.API_CANCEL,
982
1708
  request_body=request_cancel_body,
983
- func=requests_lib.kill_requests,
1709
+ func=requests_lib.kill_requests_with_prefix,
984
1710
  schedule_type=requests_lib.ScheduleType.SHORT,
985
1711
  )
986
1712
 
@@ -988,10 +1714,14 @@ async def api_cancel(request: fastapi.Request,
988
1714
  @app.get('/api/status')
989
1715
  async def api_status(
990
1716
  request_ids: Optional[List[str]] = fastapi.Query(
991
- None, description='Request IDs to get status for.'),
1717
+ None, description='Request ID prefixes to get status for.'),
992
1718
  all_status: bool = fastapi.Query(
993
1719
  False, description='Get finished requests as well.'),
994
- ) -> List[requests_lib.RequestPayload]:
1720
+ limit: Optional[int] = fastapi.Query(
1721
+ None, description='Number of requests to show.'),
1722
+ fields: Optional[List[str]] = fastapi.Query(
1723
+ None, description='Fields to get. If None, get all fields.'),
1724
+ ) -> List[payloads.RequestPayload]:
995
1725
  """Gets the list of requests."""
996
1726
  if request_ids is None:
997
1727
  statuses = None
@@ -1000,53 +1730,120 @@ async def api_status(
1000
1730
  requests_lib.RequestStatus.PENDING,
1001
1731
  requests_lib.RequestStatus.RUNNING,
1002
1732
  ]
1003
- return [
1004
- request_task.readable_encode()
1005
- for request_task in requests_lib.get_request_tasks(status=statuses)
1006
- ]
1733
+ request_tasks = await requests_lib.get_request_tasks_async(
1734
+ req_filter=requests_lib.RequestTaskFilter(
1735
+ status=statuses,
1736
+ limit=limit,
1737
+ fields=fields,
1738
+ sort=True,
1739
+ ))
1740
+ return requests_lib.encode_requests(request_tasks)
1007
1741
  else:
1008
1742
  encoded_request_tasks = []
1009
1743
  for request_id in request_ids:
1010
- request_task = requests_lib.get_request(request_id)
1011
- if request_task is None:
1744
+ request_tasks = await requests_lib.get_requests_async_with_prefix(
1745
+ request_id)
1746
+ if request_tasks is None:
1012
1747
  continue
1013
- encoded_request_tasks.append(request_task.readable_encode())
1748
+ for request_task in request_tasks:
1749
+ encoded_request_tasks.append(request_task.readable_encode())
1014
1750
  return encoded_request_tasks
1015
1751
 
1016
1752
 
1017
- @app.get('/api/health')
1018
- async def health() -> Dict[str, str]:
1753
+ @app.get(
1754
+ '/api/health',
1755
+ # response_model_exclude_unset omits unset fields
1756
+ # in the response JSON.
1757
+ response_model_exclude_unset=True)
1758
+ async def health(request: fastapi.Request) -> responses.APIHealthResponse:
1019
1759
  """Checks the health of the API server.
1020
1760
 
1021
1761
  Returns:
1022
- A dictionary with the following keys:
1023
- - status: str; The status of the API server.
1024
- - api_version: str; The API version of the API server.
1025
- - version: str; The version of SkyPilot used for API server.
1026
- - version_on_disk: str; The version of the SkyPilot installation on
1027
- disk, which can be used to warn about restarting the API server
1028
- - commit: str; The commit hash of SkyPilot used for API server.
1762
+ responses.APIHealthResponse: The health response.
1029
1763
  """
1030
- return {
1031
- 'status': common.ApiServerStatus.HEALTHY.value,
1032
- 'api_version': server_constants.API_VERSION,
1033
- 'version': sky.__version__,
1034
- 'version_on_disk': common.get_skypilot_version_on_disk(),
1035
- 'commit': sky.__commit__,
1036
- }
1764
+ user = request.state.auth_user
1765
+ server_status = common.ApiServerStatus.HEALTHY
1766
+ if getattr(request.state, 'anonymous_user', False):
1767
+ # API server authentication is enabled, but the request is not
1768
+ # authenticated. We still have to serve the request because the
1769
+ # /api/health endpoint has two different usages:
1770
+ # 1. For health check from `api start` and external orchestration
1771
+ # tools (k8s), which does not require authentication and user info.
1772
+ # 2. Return server info to client and hint client to login if required.
1773
+ # Separating these two usages into different APIs will break backward
1774
+ # compatibility for existing orchestration solutions (e.g. helm chart).
1775
+ # So we serve these two usages in a backward compatible manner below.
1776
+ client_version = versions.get_remote_api_version()
1777
+ # - For Client with API version >= 14, we return 200 response with
1778
+ # status=NEEDS_AUTH, new client will handle the login process.
1779
+ # - For health check from `sky api start`, the client code always uses
1780
+ # the same API version with the server, thus there is no compatibility
1781
+ # issue.
1782
+ server_status = common.ApiServerStatus.NEEDS_AUTH
1783
+ if client_version is None:
1784
+ # - For health check from orchestration tools (e.g. k8s), we also
1785
+ # return 200 with status=NEEDS_AUTH, which passes HTTP probe
1786
+ # check.
1787
+ # - There is no harm when a malicious client calls /api/health
1788
+ # without authentication since no sensitive information is
1789
+ # returned.
1790
+ return responses.APIHealthResponse(
1791
+ status=common.ApiServerStatus.HEALTHY,)
1792
+ # TODO(aylei): remove this after min_compatible_api_version >= 14.
1793
+ if client_version < 14:
1794
+ # For Client with API version < 14, the NEEDS_AUTH status is not
1795
+ # honored. Return 401 to trigger the login process.
1796
+ raise fastapi.HTTPException(status_code=401,
1797
+ detail='Authentication required')
1798
+
1799
+ logger.debug(f'Health endpoint: request.state.auth_user = {user}')
1800
+ return responses.APIHealthResponse(
1801
+ status=server_status,
1802
+ # Kept for backward compatibility, clients before 0.11.0 will read this
1803
+ # field to check compatibility and hint the user to upgrade the CLI.
1804
+ # TODO(aylei): remove this field after 0.13.0
1805
+ api_version=str(server_constants.API_VERSION),
1806
+ version=sky.__version__,
1807
+ version_on_disk=common.get_skypilot_version_on_disk(),
1808
+ commit=sky.__commit__,
1809
+ # Whether basic auth on api server is enabled
1810
+ basic_auth_enabled=os.environ.get(constants.ENV_VAR_ENABLE_BASIC_AUTH,
1811
+ 'false').lower() == 'true',
1812
+ user=user if user is not None else None,
1813
+ # Whether service account token is enabled
1814
+ service_account_token_enabled=(os.environ.get(
1815
+ constants.ENV_VAR_ENABLE_SERVICE_ACCOUNTS,
1816
+ 'false').lower() == 'true'),
1817
+ # Whether basic auth on ingress is enabled
1818
+ ingress_basic_auth_enabled=os.environ.get(
1819
+ constants.SKYPILOT_INGRESS_BASIC_AUTH_ENABLED,
1820
+ 'false').lower() == 'true',
1821
+ )
1822
+
1823
+
1824
+ class KubernetesSSHMessageType(IntEnum):
1825
+ REGULAR_DATA = 0
1826
+ PINGPONG = 1
1827
+ LATENCY_MEASUREMENT = 2
1037
1828
 
1038
1829
 
1039
1830
  @app.websocket('/kubernetes-pod-ssh-proxy')
1040
1831
  async def kubernetes_pod_ssh_proxy(
1041
- websocket: fastapi.WebSocket,
1042
- cluster_name_body: payloads.ClusterNameBody = fastapi.Depends()
1043
- ) -> None:
1832
+ websocket: fastapi.WebSocket,
1833
+ cluster_name: str,
1834
+ client_version: Optional[int] = None) -> None:
1044
1835
  """Proxies SSH to the Kubernetes pod with websocket."""
1045
1836
  await websocket.accept()
1046
- cluster_name = cluster_name_body.cluster_name
1047
1837
  logger.info(f'WebSocket connection accepted for cluster: {cluster_name}')
1048
1838
 
1049
- cluster_records = core.status(cluster_name, all_users=True)
1839
+ timestamps_supported = client_version is not None and client_version > 21
1840
+ logger.info(f'Websocket timestamps supported: {timestamps_supported}, \
1841
+ client_version = {client_version}')
1842
+
1843
+ # Run core.status in another thread to avoid blocking the event loop.
1844
+ with ThreadPoolExecutor(max_workers=1) as thread_pool_executor:
1845
+ cluster_records = await context_utils.to_thread_with_executor(
1846
+ thread_pool_executor, core.status, cluster_name, all_users=True)
1050
1847
  cluster_record = cluster_records[0]
1051
1848
  if cluster_record['status'] != status_lib.ClusterStatus.UP:
1052
1849
  raise fastapi.HTTPException(
@@ -1085,17 +1882,70 @@ async def kubernetes_pod_ssh_proxy(
1085
1882
  return
1086
1883
 
1087
1884
  logger.info(f'Starting port-forward to local port: {local_port}')
1885
+ conn_gauge = metrics_utils.SKY_APISERVER_WEBSOCKET_CONNECTIONS.labels(
1886
+ pid=os.getpid())
1887
+ ssh_failed = False
1888
+ websocket_closed = False
1088
1889
  try:
1890
+ conn_gauge.inc()
1089
1891
  # Connect to the local port
1090
1892
  reader, writer = await asyncio.open_connection('127.0.0.1', local_port)
1091
1893
 
1092
1894
  async def websocket_to_ssh():
1093
1895
  try:
1094
1896
  async for message in websocket.iter_bytes():
1897
+ if timestamps_supported:
1898
+ type_size = struct.calcsize('!B')
1899
+ message_type = struct.unpack('!B',
1900
+ message[:type_size])[0]
1901
+ if (message_type ==
1902
+ KubernetesSSHMessageType.REGULAR_DATA):
1903
+ # Regular data - strip type byte and forward to SSH
1904
+ message = message[type_size:]
1905
+ elif message_type == KubernetesSSHMessageType.PINGPONG:
1906
+ # PING message - respond with PONG (type 1)
1907
+ ping_id_size = struct.calcsize('!I')
1908
+ if len(message) != type_size + ping_id_size:
1909
+ raise ValueError('Invalid PING message '
1910
+ f'length: {len(message)}')
1911
+ # Return the same PING message, so that the client
1912
+ # can measure the latency.
1913
+ await websocket.send_bytes(message)
1914
+ continue
1915
+ elif (message_type ==
1916
+ KubernetesSSHMessageType.LATENCY_MEASUREMENT):
1917
+ # Latency measurement from client
1918
+ latency_size = struct.calcsize('!Q')
1919
+ if len(message) != type_size + latency_size:
1920
+ raise ValueError(
1921
+ 'Invalid latency measurement '
1922
+ f'message length: {len(message)}')
1923
+ avg_latency_ms = struct.unpack(
1924
+ '!Q',
1925
+ message[type_size:type_size + latency_size])[0]
1926
+ latency_seconds = avg_latency_ms / 1000
1927
+ metrics_utils.SKY_APISERVER_WEBSOCKET_SSH_LATENCY_SECONDS.labels(pid=os.getpid()).observe(latency_seconds) # pylint: disable=line-too-long
1928
+ continue
1929
+ else:
1930
+ # Unknown message type.
1931
+ raise ValueError(
1932
+ f'Unknown message type: {message_type}')
1095
1933
  writer.write(message)
1096
- await writer.drain()
1934
+ try:
1935
+ await writer.drain()
1936
+ except Exception as e: # pylint: disable=broad-except
1937
+ # Typically we will not reach here, if the ssh to pod
1938
+ # is disconnected, ssh_to_websocket will exit first.
1939
+ # But just in case.
1940
+ logger.error('Failed to write to pod through '
1941
+ f'port-forward connection: {e}')
1942
+ nonlocal ssh_failed
1943
+ ssh_failed = True
1944
+ break
1097
1945
  except fastapi.WebSocketDisconnect:
1098
1946
  pass
1947
+ nonlocal websocket_closed
1948
+ websocket_closed = True
1099
1949
  writer.close()
1100
1950
 
1101
1951
  async def ssh_to_websocket():
@@ -1103,87 +1953,249 @@ async def kubernetes_pod_ssh_proxy(
1103
1953
  while True:
1104
1954
  data = await reader.read(1024)
1105
1955
  if not data:
1956
+ if not websocket_closed:
1957
+ logger.warning('SSH connection to pod is '
1958
+ 'disconnected before websocket '
1959
+ 'connection is closed')
1960
+ nonlocal ssh_failed
1961
+ ssh_failed = True
1106
1962
  break
1963
+ if timestamps_supported:
1964
+ # Prepend message type byte (0 = regular data)
1965
+ message_type_bytes = struct.pack(
1966
+ '!B', KubernetesSSHMessageType.REGULAR_DATA.value)
1967
+ data = message_type_bytes + data
1107
1968
  await websocket.send_bytes(data)
1108
1969
  except Exception: # pylint: disable=broad-except
1109
1970
  pass
1110
- await websocket.close()
1971
+ try:
1972
+ await websocket.close()
1973
+ except Exception: # pylint: disable=broad-except
1974
+ # The websocket might have been closed by the client.
1975
+ pass
1111
1976
 
1112
1977
  await asyncio.gather(websocket_to_ssh(), ssh_to_websocket())
1113
1978
  finally:
1114
- proc.terminate()
1979
+ conn_gauge.dec()
1980
+ reason = ''
1981
+ try:
1982
+ logger.info('Terminating kubectl port-forward process')
1983
+ proc.terminate()
1984
+ except ProcessLookupError:
1985
+ stdout = await proc.stdout.read()
1986
+ logger.error('kubectl port-forward was terminated before the '
1987
+ 'ssh websocket connection was closed. Remaining '
1988
+ f'output: {str(stdout)}')
1989
+ reason = 'KubectlPortForwardExit'
1990
+ metrics_utils.SKY_APISERVER_WEBSOCKET_CLOSED_TOTAL.labels(
1991
+ pid=os.getpid(), reason='KubectlPortForwardExit').inc()
1992
+ else:
1993
+ if ssh_failed:
1994
+ reason = 'SSHToPodDisconnected'
1995
+ else:
1996
+ reason = 'ClientClosed'
1997
+ metrics_utils.SKY_APISERVER_WEBSOCKET_CLOSED_TOTAL.labels(
1998
+ pid=os.getpid(), reason=reason).inc()
1999
+
2000
+
2001
+ @app.get('/all_contexts')
2002
+ async def all_contexts(request: fastapi.Request) -> None:
2003
+ """Gets all Kubernetes and SSH node pool contexts."""
2004
+
2005
+ await executor.schedule_request_async(
2006
+ request_id=request.state.request_id,
2007
+ request_name=request_names.RequestName.ALL_CONTEXTS,
2008
+ request_body=payloads.RequestBody(),
2009
+ func=core.get_all_contexts,
2010
+ schedule_type=requests_lib.ScheduleType.SHORT,
2011
+ )
1115
2012
 
1116
2013
 
1117
2014
  # === Internal APIs ===
1118
2015
  @app.get('/api/completion/cluster_name')
1119
2016
  async def complete_cluster_name(incomplete: str,) -> List[str]:
1120
- return global_user_state.get_cluster_names_start_with(incomplete)
2017
+ return await context_utils.to_thread(
2018
+ global_user_state.get_cluster_names_start_with, incomplete)
1121
2019
 
1122
2020
 
1123
2021
  @app.get('/api/completion/storage_name')
1124
2022
  async def complete_storage_name(incomplete: str,) -> List[str]:
1125
- return global_user_state.get_storage_names_start_with(incomplete)
2023
+ return await context_utils.to_thread(
2024
+ global_user_state.get_storage_names_start_with, incomplete)
1126
2025
 
1127
2026
 
1128
- # Add a route to serve static files
1129
- @app.get('/{full_path:path}')
1130
- async def serve_static_or_dashboard(full_path: str):
1131
- """Serves static files for any unmatched routes.
2027
+ @app.get('/api/completion/volume_name')
2028
+ async def complete_volume_name(incomplete: str,) -> List[str]:
2029
+ return await context_utils.to_thread(
2030
+ global_user_state.get_volume_names_start_with, incomplete)
1132
2031
 
1133
- Handles the /dashboard prefix from Next.js configuration.
1134
- """
1135
- # Check if the path starts with 'dashboard/' and remove it if it does
1136
- if full_path.startswith('dashboard/'):
1137
- full_path = full_path[len('dashboard/'):]
1138
2032
 
1139
- # Try to serve the file directly from the out directory first
2033
+ @app.get('/api/completion/api_request')
2034
+ async def complete_api_request(incomplete: str,) -> List[str]:
2035
+ return await requests_lib.get_api_request_ids_start_with(incomplete)
2036
+
2037
+
2038
+ @app.get('/dashboard/{full_path:path}')
2039
+ async def serve_dashboard(full_path: str):
2040
+ """Serves the Next.js dashboard application.
2041
+
2042
+ Args:
2043
+ full_path: The path requested by the client.
2044
+ e.g. /clusters, /jobs
2045
+
2046
+ Returns:
2047
+ FileResponse for static files or index.html for client-side routing.
2048
+
2049
+ Raises:
2050
+ HTTPException: If the path is invalid or file not found.
2051
+ """
2052
+ # Try to serve the static file directly e.g. /skypilot.svg,
2053
+ # /favicon.ico, and /_next/, etc.
1140
2054
  file_path = os.path.join(server_constants.DASHBOARD_DIR, full_path)
1141
2055
  if os.path.isfile(file_path):
1142
2056
  return fastapi.responses.FileResponse(file_path)
1143
2057
 
1144
- # If file not found, serve the index.html for client-side routing.
1145
- # For example, the non-matched arbitrary route (/ or /test) from
1146
- # client will be redirected to the index.html.
2058
+ # Serve index.html for client-side routing
2059
+ # e.g. /clusters, /jobs
1147
2060
  index_path = os.path.join(server_constants.DASHBOARD_DIR, 'index.html')
1148
2061
  try:
1149
2062
  with open(index_path, 'r', encoding='utf-8') as f:
1150
2063
  content = f.read()
2064
+
1151
2065
  return fastapi.responses.HTMLResponse(content=content)
1152
2066
  except Exception as e:
1153
2067
  logger.error(f'Error serving dashboard: {e}')
1154
2068
  raise fastapi.HTTPException(status_code=500, detail=str(e))
1155
2069
 
1156
2070
 
2071
+ # Redirect the root path to dashboard
2072
+ @app.get('/')
2073
+ async def root():
2074
+ return fastapi.responses.RedirectResponse(url='/dashboard/')
2075
+
2076
+
2077
+ def _init_or_restore_server_user_hash():
2078
+ """Restores the server user hash from the global user state db.
2079
+
2080
+ The API server must have a stable user hash across restarts and potential
2081
+ multiple replicas. Thus we persist the user hash in db and restore it on
2082
+ startup. When upgrading from old version, the user hash will be read from
2083
+ the local file (if any) to keep the user hash consistent.
2084
+ """
2085
+
2086
+ def apply_user_hash(user_hash: str) -> None:
2087
+ # For local API server, the user hash in db and local file should be
2088
+ # same so there is no harm to override here.
2089
+ common_utils.set_user_hash_locally(user_hash)
2090
+ # Refresh the server user hash for current process after restore or
2091
+ # initialize the user hash in db, child processes will get the correct
2092
+ # server id from the local cache file.
2093
+ common_lib.refresh_server_id()
2094
+
2095
+ user_hash = global_user_state.get_system_config(_SERVER_USER_HASH_KEY)
2096
+ if user_hash is not None:
2097
+ apply_user_hash(user_hash)
2098
+ return
2099
+
2100
+ # Initial deployment, generate a user hash and save it to the db.
2101
+ user_hash = common_utils.get_user_hash()
2102
+ global_user_state.set_system_config(_SERVER_USER_HASH_KEY, user_hash)
2103
+ apply_user_hash(user_hash)
2104
+
2105
+
1157
2106
  if __name__ == '__main__':
1158
2107
  import uvicorn
1159
2108
 
1160
2109
  from sky.server import uvicorn as skyuvicorn
1161
2110
 
1162
- requests_lib.reset_db_and_logs()
2111
+ logger.info('Initializing SkyPilot API server')
2112
+ skyuvicorn.add_timestamp_prefix_for_server_logs()
1163
2113
 
1164
2114
  parser = argparse.ArgumentParser()
1165
2115
  parser.add_argument('--host', default='127.0.0.1')
1166
2116
  parser.add_argument('--port', default=46580, type=int)
1167
2117
  parser.add_argument('--deploy', action='store_true')
2118
+ # Serve metrics on a separate port to isolate it from the application APIs:
2119
+ # metrics port will not be exposed to the public network typically.
2120
+ parser.add_argument('--metrics-port', default=9090, type=int)
1168
2121
  cmd_args = parser.parse_args()
2122
+ if cmd_args.port == cmd_args.metrics_port:
2123
+ logger.error('port and metrics-port cannot be the same, exiting.')
2124
+ raise ValueError('port and metrics-port cannot be the same')
2125
+
2126
+ # Fail fast if the port is not available to avoid corrupting the state
2127
+ # of a potentially running server instance.
2128
+ # We might reach here because the running server is currently not
2129
+ # responding, thus the healthz check fails and `sky api start` thinks
2130
+ # we should start a new server instance.
2131
+ if not common_utils.is_port_available(cmd_args.port):
2132
+ logger.error(f'Port {cmd_args.port} is not available, exiting.')
2133
+ raise RuntimeError(f'Port {cmd_args.port} is not available')
2134
+
2135
+ # Maybe touch the signal file on API server startup. Do it again here even
2136
+ # if we already touched it in the sky/server/common.py::_start_api_server.
2137
+ # This is because the sky/server/common.py::_start_api_server function call
2138
+ # is running outside the skypilot API server process tree. The process tree
2139
+ # starts within that function (see the `subprocess.Popen` call in
2140
+ # sky/server/common.py::_start_api_server). When pg is used, the
2141
+ # _start_api_server function will not load the config file from db, which
2142
+ # will ignore the consolidation mode config. Here, inside the process tree,
2143
+ # we already reload the config as a server (with env var _start_api_server),
2144
+ # so we will respect the consolidation mode config.
2145
+ # Refer to #7717 for more details.
2146
+ managed_job_utils.is_consolidation_mode(on_api_restart=True)
2147
+
1169
2148
  # Show the privacy policy if it is not already shown. We place it here so
1170
2149
  # that it is shown only when the API server is started.
1171
2150
  usage_lib.maybe_show_privacy_policy()
1172
2151
 
1173
- config = server_config.compute_server_config(cmd_args.deploy)
2152
+ # Initialize global user state db
2153
+ db_utils.set_max_connections(1)
2154
+ logger.info('Initializing database engine')
2155
+ global_user_state.initialize_and_get_db()
2156
+ logger.info('Database engine initialized')
2157
+ # Initialize request db
2158
+ requests_lib.reset_db_and_logs()
2159
+ # Restore the server user hash
2160
+ logger.info('Initializing server user hash')
2161
+ _init_or_restore_server_user_hash()
2162
+
2163
+ max_db_connections = global_user_state.get_max_db_connections()
2164
+ logger.info(f'Max db connections: {max_db_connections}')
2165
+ config = server_config.compute_server_config(cmd_args.deploy,
2166
+ max_db_connections)
2167
+
1174
2168
  num_workers = config.num_server_workers
1175
2169
 
1176
- sub_procs = []
2170
+ queue_server: Optional[multiprocessing.Process] = None
2171
+ workers: List[executor.RequestWorker] = []
2172
+ # Global background tasks that will be scheduled in a separate event loop.
2173
+ global_tasks: List[asyncio.Task] = []
1177
2174
  try:
1178
- sub_procs = executor.start(config)
2175
+ background = uvloop.new_event_loop()
2176
+ if os.environ.get(constants.ENV_VAR_SERVER_METRICS_ENABLED):
2177
+ metrics_server = metrics.build_metrics_server(
2178
+ cmd_args.host, cmd_args.metrics_port)
2179
+ global_tasks.append(background.create_task(metrics_server.serve()))
2180
+ global_tasks.append(
2181
+ background.create_task(requests_lib.requests_gc_daemon()))
2182
+ global_tasks.append(
2183
+ background.create_task(
2184
+ global_user_state.cluster_event_retention_daemon()))
2185
+ threading.Thread(target=background.run_forever, daemon=True).start()
2186
+
2187
+ queue_server, workers = executor.start(config)
2188
+
1179
2189
  logger.info(f'Starting SkyPilot API server, workers={num_workers}')
1180
2190
  # We don't support reload for now, since it may cause leakage of request
1181
2191
  # workers or interrupt running requests.
1182
- config = uvicorn.Config('sky.server.server:app',
1183
- host=cmd_args.host,
1184
- port=cmd_args.port,
1185
- workers=num_workers)
1186
- skyuvicorn.run(config)
2192
+ uvicorn_config = uvicorn.Config('sky.server.server:app',
2193
+ host=cmd_args.host,
2194
+ port=cmd_args.port,
2195
+ workers=num_workers,
2196
+ ws_per_message_deflate=False)
2197
+ skyuvicorn.run(uvicorn_config,
2198
+ max_db_connections=config.num_db_connections_per_worker)
1187
2199
  except Exception as exc: # pylint: disable=broad-except
1188
2200
  logger.error(f'Failed to start SkyPilot API server: '
1189
2201
  f'{common_utils.format_exception(exc, use_bracket=True)}')
@@ -1191,17 +2203,11 @@ if __name__ == '__main__':
1191
2203
  finally:
1192
2204
  logger.info('Shutting down SkyPilot API server...')
1193
2205
 
1194
- def cleanup(proc: multiprocessing.Process) -> None:
1195
- try:
1196
- proc.terminate()
1197
- proc.join()
1198
- finally:
1199
- # The process may not be started yet, close it anyway.
1200
- proc.close()
1201
-
1202
- # Terminate processes in reverse order in case dependency, especially
1203
- # queue server. Terminate queue server first does not affect the
1204
- # correctness of cleanup but introduce redundant error messages.
1205
- subprocess_utils.run_in_parallel(cleanup,
1206
- list(reversed(sub_procs)),
1207
- num_threads=len(sub_procs))
2206
+ for gt in global_tasks:
2207
+ gt.cancel()
2208
+ subprocess_utils.run_in_parallel(lambda worker: worker.cancel(),
2209
+ workers,
2210
+ num_threads=len(workers))
2211
+ if queue_server is not None:
2212
+ queue_server.kill()
2213
+ queue_server.join()