skypilot-nightly 1.0.0.dev20250502__py3-none-any.whl → 1.0.0.dev20251203__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (546) hide show
  1. sky/__init__.py +22 -6
  2. sky/adaptors/aws.py +81 -16
  3. sky/adaptors/common.py +25 -2
  4. sky/adaptors/coreweave.py +278 -0
  5. sky/adaptors/do.py +8 -2
  6. sky/adaptors/gcp.py +11 -0
  7. sky/adaptors/hyperbolic.py +8 -0
  8. sky/adaptors/ibm.py +5 -2
  9. sky/adaptors/kubernetes.py +149 -18
  10. sky/adaptors/nebius.py +173 -30
  11. sky/adaptors/primeintellect.py +1 -0
  12. sky/adaptors/runpod.py +68 -0
  13. sky/adaptors/seeweb.py +183 -0
  14. sky/adaptors/shadeform.py +89 -0
  15. sky/admin_policy.py +187 -4
  16. sky/authentication.py +179 -225
  17. sky/backends/__init__.py +4 -2
  18. sky/backends/backend.py +22 -9
  19. sky/backends/backend_utils.py +1323 -397
  20. sky/backends/cloud_vm_ray_backend.py +1749 -1029
  21. sky/backends/docker_utils.py +1 -1
  22. sky/backends/local_docker_backend.py +11 -6
  23. sky/backends/task_codegen.py +633 -0
  24. sky/backends/wheel_utils.py +55 -9
  25. sky/{clouds/service_catalog → catalog}/__init__.py +21 -19
  26. sky/{clouds/service_catalog → catalog}/aws_catalog.py +27 -8
  27. sky/{clouds/service_catalog → catalog}/azure_catalog.py +10 -7
  28. sky/{clouds/service_catalog → catalog}/common.py +90 -49
  29. sky/{clouds/service_catalog → catalog}/cudo_catalog.py +8 -5
  30. sky/{clouds/service_catalog → catalog}/data_fetchers/analyze.py +1 -1
  31. sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_aws.py +116 -80
  32. sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_cudo.py +38 -38
  33. sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_gcp.py +70 -16
  34. sky/catalog/data_fetchers/fetch_hyperbolic.py +136 -0
  35. sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_lambda_cloud.py +1 -0
  36. sky/catalog/data_fetchers/fetch_nebius.py +338 -0
  37. sky/catalog/data_fetchers/fetch_runpod.py +698 -0
  38. sky/catalog/data_fetchers/fetch_seeweb.py +329 -0
  39. sky/catalog/data_fetchers/fetch_shadeform.py +142 -0
  40. sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_vast.py +1 -1
  41. sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_vsphere.py +1 -1
  42. sky/{clouds/service_catalog → catalog}/do_catalog.py +5 -2
  43. sky/{clouds/service_catalog → catalog}/fluidstack_catalog.py +6 -3
  44. sky/{clouds/service_catalog → catalog}/gcp_catalog.py +41 -15
  45. sky/catalog/hyperbolic_catalog.py +136 -0
  46. sky/{clouds/service_catalog → catalog}/ibm_catalog.py +9 -6
  47. sky/{clouds/service_catalog → catalog}/kubernetes_catalog.py +36 -24
  48. sky/{clouds/service_catalog → catalog}/lambda_catalog.py +9 -6
  49. sky/{clouds/service_catalog → catalog}/nebius_catalog.py +9 -7
  50. sky/{clouds/service_catalog → catalog}/oci_catalog.py +9 -6
  51. sky/{clouds/service_catalog → catalog}/paperspace_catalog.py +5 -2
  52. sky/catalog/primeintellect_catalog.py +95 -0
  53. sky/{clouds/service_catalog → catalog}/runpod_catalog.py +11 -4
  54. sky/{clouds/service_catalog → catalog}/scp_catalog.py +9 -6
  55. sky/catalog/seeweb_catalog.py +184 -0
  56. sky/catalog/shadeform_catalog.py +165 -0
  57. sky/catalog/ssh_catalog.py +167 -0
  58. sky/{clouds/service_catalog → catalog}/vast_catalog.py +6 -3
  59. sky/{clouds/service_catalog → catalog}/vsphere_catalog.py +5 -2
  60. sky/check.py +533 -185
  61. sky/cli.py +5 -5975
  62. sky/client/{cli.py → cli/command.py} +2591 -1956
  63. sky/client/cli/deprecation_utils.py +99 -0
  64. sky/client/cli/flags.py +359 -0
  65. sky/client/cli/table_utils.py +322 -0
  66. sky/client/cli/utils.py +79 -0
  67. sky/client/common.py +78 -32
  68. sky/client/oauth.py +82 -0
  69. sky/client/sdk.py +1219 -319
  70. sky/client/sdk_async.py +827 -0
  71. sky/client/service_account_auth.py +47 -0
  72. sky/cloud_stores.py +82 -3
  73. sky/clouds/__init__.py +13 -0
  74. sky/clouds/aws.py +564 -164
  75. sky/clouds/azure.py +105 -83
  76. sky/clouds/cloud.py +140 -40
  77. sky/clouds/cudo.py +68 -50
  78. sky/clouds/do.py +66 -48
  79. sky/clouds/fluidstack.py +63 -44
  80. sky/clouds/gcp.py +339 -110
  81. sky/clouds/hyperbolic.py +293 -0
  82. sky/clouds/ibm.py +70 -49
  83. sky/clouds/kubernetes.py +570 -162
  84. sky/clouds/lambda_cloud.py +74 -54
  85. sky/clouds/nebius.py +210 -81
  86. sky/clouds/oci.py +88 -66
  87. sky/clouds/paperspace.py +61 -44
  88. sky/clouds/primeintellect.py +317 -0
  89. sky/clouds/runpod.py +164 -74
  90. sky/clouds/scp.py +89 -86
  91. sky/clouds/seeweb.py +477 -0
  92. sky/clouds/shadeform.py +400 -0
  93. sky/clouds/ssh.py +263 -0
  94. sky/clouds/utils/aws_utils.py +10 -4
  95. sky/clouds/utils/gcp_utils.py +87 -11
  96. sky/clouds/utils/oci_utils.py +38 -14
  97. sky/clouds/utils/scp_utils.py +231 -167
  98. sky/clouds/vast.py +99 -77
  99. sky/clouds/vsphere.py +51 -40
  100. sky/core.py +375 -173
  101. sky/dag.py +15 -0
  102. sky/dashboard/out/404.html +1 -1
  103. sky/dashboard/out/_next/static/96_E2yl3QAiIJGOYCkSpB/_buildManifest.js +1 -0
  104. sky/dashboard/out/_next/static/chunks/1141-e6aa9ab418717c59.js +11 -0
  105. sky/dashboard/out/_next/static/chunks/1272-1ef0bf0237faccdb.js +1 -0
  106. sky/dashboard/out/_next/static/chunks/1871-7e202677c42f43fe.js +6 -0
  107. sky/dashboard/out/_next/static/chunks/2260-7703229c33c5ebd5.js +1 -0
  108. sky/dashboard/out/_next/static/chunks/2350.fab69e61bac57b23.js +1 -0
  109. sky/dashboard/out/_next/static/chunks/2369.fc20f0c2c8ed9fe7.js +15 -0
  110. sky/dashboard/out/_next/static/chunks/2755.edd818326d489a1d.js +26 -0
  111. sky/dashboard/out/_next/static/chunks/3294.20a8540fe697d5ee.js +1 -0
  112. sky/dashboard/out/_next/static/chunks/3785.7e245f318f9d1121.js +1 -0
  113. sky/dashboard/out/_next/static/chunks/3800-7b45f9fbb6308557.js +1 -0
  114. sky/dashboard/out/_next/static/chunks/3850-ff4a9a69d978632b.js +1 -0
  115. sky/dashboard/out/_next/static/chunks/3937.210053269f121201.js +1 -0
  116. sky/dashboard/out/_next/static/chunks/4725.172ede95d1b21022.js +1 -0
  117. sky/dashboard/out/_next/static/chunks/4937.a2baa2df5572a276.js +15 -0
  118. sky/dashboard/out/_next/static/chunks/5739-d67458fcb1386c92.js +8 -0
  119. sky/dashboard/out/_next/static/chunks/6130-2be46d70a38f1e82.js +1 -0
  120. sky/dashboard/out/_next/static/chunks/616-3d59f75e2ccf9321.js +39 -0
  121. sky/dashboard/out/_next/static/chunks/6212-7bd06f60ba693125.js +13 -0
  122. sky/dashboard/out/_next/static/chunks/6856-8f27d1c10c98def8.js +1 -0
  123. sky/dashboard/out/_next/static/chunks/6989-01359c57e018caa4.js +1 -0
  124. sky/dashboard/out/_next/static/chunks/6990-9146207c4567fdfd.js +1 -0
  125. sky/dashboard/out/_next/static/chunks/7359-c8d04e06886000b3.js +30 -0
  126. sky/dashboard/out/_next/static/chunks/7411-b15471acd2cba716.js +41 -0
  127. sky/dashboard/out/_next/static/chunks/7615-019513abc55b3b47.js +1 -0
  128. sky/dashboard/out/_next/static/chunks/8640.5b9475a2d18c5416.js +16 -0
  129. sky/dashboard/out/_next/static/chunks/8969-452f9d5cbdd2dc73.js +1 -0
  130. sky/dashboard/out/_next/static/chunks/9025.fa408f3242e9028d.js +6 -0
  131. sky/dashboard/out/_next/static/chunks/9353-cff34f7e773b2e2b.js +1 -0
  132. sky/dashboard/out/_next/static/chunks/9360.a536cf6b1fa42355.js +31 -0
  133. sky/dashboard/out/_next/static/chunks/9847.3aaca6bb33455140.js +30 -0
  134. sky/dashboard/out/_next/static/chunks/fd9d1056-86323a29a8f7e46a.js +1 -0
  135. sky/dashboard/out/_next/static/chunks/framework-cf60a09ccd051a10.js +33 -0
  136. sky/dashboard/out/_next/static/chunks/main-app-587214043926b3cc.js +1 -0
  137. sky/dashboard/out/_next/static/chunks/main-f15ccb73239a3bf1.js +1 -0
  138. sky/dashboard/out/_next/static/chunks/pages/_app-bde01e4a2beec258.js +34 -0
  139. sky/dashboard/out/_next/static/chunks/pages/_error-c66a4e8afc46f17b.js +1 -0
  140. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]/[job]-792db96d918c98c9.js +16 -0
  141. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]-abfcac9c137aa543.js +1 -0
  142. sky/dashboard/out/_next/static/chunks/pages/clusters-ee39056f9851a3ff.js +1 -0
  143. sky/dashboard/out/_next/static/chunks/pages/config-dfb9bf07b13045f4.js +1 -0
  144. sky/dashboard/out/_next/static/chunks/pages/index-444f1804401f04ea.js +1 -0
  145. sky/dashboard/out/_next/static/chunks/pages/infra/[context]-c0b5935149902e6f.js +1 -0
  146. sky/dashboard/out/_next/static/chunks/pages/infra-aed0ea19df7cf961.js +1 -0
  147. sky/dashboard/out/_next/static/chunks/pages/jobs/[job]-d66997e2bfc837cf.js +16 -0
  148. sky/dashboard/out/_next/static/chunks/pages/jobs/pools/[pool]-9faf940b253e3e06.js +21 -0
  149. sky/dashboard/out/_next/static/chunks/pages/jobs-2072b48b617989c9.js +1 -0
  150. sky/dashboard/out/_next/static/chunks/pages/users-f42674164aa73423.js +1 -0
  151. sky/dashboard/out/_next/static/chunks/pages/volumes-b84b948ff357c43e.js +1 -0
  152. sky/dashboard/out/_next/static/chunks/pages/workspace/new-3f88a1c7e86a3f86.js +1 -0
  153. sky/dashboard/out/_next/static/chunks/pages/workspaces/[name]-84a40f8c7c627fe4.js +1 -0
  154. sky/dashboard/out/_next/static/chunks/pages/workspaces-531b2f8c4bf89f82.js +1 -0
  155. sky/dashboard/out/_next/static/chunks/webpack-64e05f17bf2cf8ce.js +1 -0
  156. sky/dashboard/out/_next/static/css/0748ce22df867032.css +3 -0
  157. sky/dashboard/out/clusters/[cluster]/[job].html +1 -1
  158. sky/dashboard/out/clusters/[cluster].html +1 -1
  159. sky/dashboard/out/clusters.html +1 -1
  160. sky/dashboard/out/config.html +1 -0
  161. sky/dashboard/out/index.html +1 -1
  162. sky/dashboard/out/infra/[context].html +1 -0
  163. sky/dashboard/out/infra.html +1 -0
  164. sky/dashboard/out/jobs/[job].html +1 -1
  165. sky/dashboard/out/jobs/pools/[pool].html +1 -0
  166. sky/dashboard/out/jobs.html +1 -1
  167. sky/dashboard/out/users.html +1 -0
  168. sky/dashboard/out/volumes.html +1 -0
  169. sky/dashboard/out/workspace/new.html +1 -0
  170. sky/dashboard/out/workspaces/[name].html +1 -0
  171. sky/dashboard/out/workspaces.html +1 -0
  172. sky/data/data_utils.py +137 -1
  173. sky/data/mounting_utils.py +269 -84
  174. sky/data/storage.py +1460 -1807
  175. sky/data/storage_utils.py +43 -57
  176. sky/exceptions.py +126 -2
  177. sky/execution.py +216 -63
  178. sky/global_user_state.py +2390 -586
  179. sky/jobs/__init__.py +7 -0
  180. sky/jobs/client/sdk.py +300 -58
  181. sky/jobs/client/sdk_async.py +161 -0
  182. sky/jobs/constants.py +15 -8
  183. sky/jobs/controller.py +848 -275
  184. sky/jobs/file_content_utils.py +128 -0
  185. sky/jobs/log_gc.py +193 -0
  186. sky/jobs/recovery_strategy.py +402 -152
  187. sky/jobs/scheduler.py +314 -189
  188. sky/jobs/server/core.py +836 -255
  189. sky/jobs/server/server.py +156 -115
  190. sky/jobs/server/utils.py +136 -0
  191. sky/jobs/state.py +2109 -706
  192. sky/jobs/utils.py +1306 -215
  193. sky/logs/__init__.py +21 -0
  194. sky/logs/agent.py +108 -0
  195. sky/logs/aws.py +243 -0
  196. sky/logs/gcp.py +91 -0
  197. sky/metrics/__init__.py +0 -0
  198. sky/metrics/utils.py +453 -0
  199. sky/models.py +78 -1
  200. sky/optimizer.py +164 -70
  201. sky/provision/__init__.py +90 -4
  202. sky/provision/aws/config.py +147 -26
  203. sky/provision/aws/instance.py +136 -50
  204. sky/provision/azure/instance.py +11 -6
  205. sky/provision/common.py +13 -1
  206. sky/provision/cudo/cudo_machine_type.py +1 -1
  207. sky/provision/cudo/cudo_utils.py +14 -8
  208. sky/provision/cudo/cudo_wrapper.py +72 -71
  209. sky/provision/cudo/instance.py +10 -6
  210. sky/provision/do/instance.py +10 -6
  211. sky/provision/do/utils.py +4 -3
  212. sky/provision/docker_utils.py +140 -33
  213. sky/provision/fluidstack/instance.py +13 -8
  214. sky/provision/gcp/__init__.py +1 -0
  215. sky/provision/gcp/config.py +301 -19
  216. sky/provision/gcp/constants.py +218 -0
  217. sky/provision/gcp/instance.py +36 -8
  218. sky/provision/gcp/instance_utils.py +18 -4
  219. sky/provision/gcp/volume_utils.py +247 -0
  220. sky/provision/hyperbolic/__init__.py +12 -0
  221. sky/provision/hyperbolic/config.py +10 -0
  222. sky/provision/hyperbolic/instance.py +437 -0
  223. sky/provision/hyperbolic/utils.py +373 -0
  224. sky/provision/instance_setup.py +101 -20
  225. sky/provision/kubernetes/__init__.py +5 -0
  226. sky/provision/kubernetes/config.py +9 -52
  227. sky/provision/kubernetes/constants.py +17 -0
  228. sky/provision/kubernetes/instance.py +919 -280
  229. sky/provision/kubernetes/manifests/fusermount-server-daemonset.yaml +1 -2
  230. sky/provision/kubernetes/network.py +27 -17
  231. sky/provision/kubernetes/network_utils.py +44 -43
  232. sky/provision/kubernetes/utils.py +1221 -534
  233. sky/provision/kubernetes/volume.py +343 -0
  234. sky/provision/lambda_cloud/instance.py +22 -16
  235. sky/provision/nebius/constants.py +50 -0
  236. sky/provision/nebius/instance.py +19 -6
  237. sky/provision/nebius/utils.py +237 -137
  238. sky/provision/oci/instance.py +10 -5
  239. sky/provision/paperspace/instance.py +10 -7
  240. sky/provision/paperspace/utils.py +1 -1
  241. sky/provision/primeintellect/__init__.py +10 -0
  242. sky/provision/primeintellect/config.py +11 -0
  243. sky/provision/primeintellect/instance.py +454 -0
  244. sky/provision/primeintellect/utils.py +398 -0
  245. sky/provision/provisioner.py +117 -36
  246. sky/provision/runpod/__init__.py +5 -0
  247. sky/provision/runpod/instance.py +27 -6
  248. sky/provision/runpod/utils.py +51 -18
  249. sky/provision/runpod/volume.py +214 -0
  250. sky/provision/scp/__init__.py +15 -0
  251. sky/provision/scp/config.py +93 -0
  252. sky/provision/scp/instance.py +707 -0
  253. sky/provision/seeweb/__init__.py +11 -0
  254. sky/provision/seeweb/config.py +13 -0
  255. sky/provision/seeweb/instance.py +812 -0
  256. sky/provision/shadeform/__init__.py +11 -0
  257. sky/provision/shadeform/config.py +12 -0
  258. sky/provision/shadeform/instance.py +351 -0
  259. sky/provision/shadeform/shadeform_utils.py +83 -0
  260. sky/provision/ssh/__init__.py +18 -0
  261. sky/provision/vast/instance.py +13 -8
  262. sky/provision/vast/utils.py +10 -7
  263. sky/provision/volume.py +164 -0
  264. sky/provision/vsphere/common/ssl_helper.py +1 -1
  265. sky/provision/vsphere/common/vapiconnect.py +2 -1
  266. sky/provision/vsphere/common/vim_utils.py +4 -4
  267. sky/provision/vsphere/instance.py +15 -10
  268. sky/provision/vsphere/vsphere_utils.py +17 -20
  269. sky/py.typed +0 -0
  270. sky/resources.py +845 -119
  271. sky/schemas/__init__.py +0 -0
  272. sky/schemas/api/__init__.py +0 -0
  273. sky/schemas/api/responses.py +227 -0
  274. sky/schemas/db/README +4 -0
  275. sky/schemas/db/env.py +90 -0
  276. sky/schemas/db/global_user_state/001_initial_schema.py +124 -0
  277. sky/schemas/db/global_user_state/002_add_workspace_to_cluster_history.py +35 -0
  278. sky/schemas/db/global_user_state/003_fix_initial_revision.py +61 -0
  279. sky/schemas/db/global_user_state/004_is_managed.py +34 -0
  280. sky/schemas/db/global_user_state/005_cluster_event.py +32 -0
  281. sky/schemas/db/global_user_state/006_provision_log.py +41 -0
  282. sky/schemas/db/global_user_state/007_cluster_event_request_id.py +34 -0
  283. sky/schemas/db/global_user_state/008_skylet_ssh_tunnel_metadata.py +34 -0
  284. sky/schemas/db/global_user_state/009_last_activity_and_launched_at.py +89 -0
  285. sky/schemas/db/global_user_state/010_save_ssh_key.py +66 -0
  286. sky/schemas/db/global_user_state/011_is_ephemeral.py +34 -0
  287. sky/schemas/db/kv_cache/001_initial_schema.py +29 -0
  288. sky/schemas/db/script.py.mako +28 -0
  289. sky/schemas/db/serve_state/001_initial_schema.py +67 -0
  290. sky/schemas/db/serve_state/002_yaml_content.py +34 -0
  291. sky/schemas/db/skypilot_config/001_initial_schema.py +30 -0
  292. sky/schemas/db/spot_jobs/001_initial_schema.py +97 -0
  293. sky/schemas/db/spot_jobs/002_cluster_pool.py +42 -0
  294. sky/schemas/db/spot_jobs/003_pool_hash.py +34 -0
  295. sky/schemas/db/spot_jobs/004_job_file_contents.py +42 -0
  296. sky/schemas/db/spot_jobs/005_logs_gc.py +38 -0
  297. sky/schemas/db/spot_jobs/006_controller_pid_started_at.py +34 -0
  298. sky/schemas/db/spot_jobs/007_config_file_content.py +34 -0
  299. sky/schemas/generated/__init__.py +0 -0
  300. sky/schemas/generated/autostopv1_pb2.py +36 -0
  301. sky/schemas/generated/autostopv1_pb2.pyi +43 -0
  302. sky/schemas/generated/autostopv1_pb2_grpc.py +146 -0
  303. sky/schemas/generated/jobsv1_pb2.py +86 -0
  304. sky/schemas/generated/jobsv1_pb2.pyi +254 -0
  305. sky/schemas/generated/jobsv1_pb2_grpc.py +542 -0
  306. sky/schemas/generated/managed_jobsv1_pb2.py +76 -0
  307. sky/schemas/generated/managed_jobsv1_pb2.pyi +278 -0
  308. sky/schemas/generated/managed_jobsv1_pb2_grpc.py +278 -0
  309. sky/schemas/generated/servev1_pb2.py +58 -0
  310. sky/schemas/generated/servev1_pb2.pyi +115 -0
  311. sky/schemas/generated/servev1_pb2_grpc.py +322 -0
  312. sky/serve/autoscalers.py +357 -5
  313. sky/serve/client/impl.py +310 -0
  314. sky/serve/client/sdk.py +47 -139
  315. sky/serve/client/sdk_async.py +130 -0
  316. sky/serve/constants.py +12 -9
  317. sky/serve/controller.py +68 -17
  318. sky/serve/load_balancer.py +106 -60
  319. sky/serve/load_balancing_policies.py +116 -2
  320. sky/serve/replica_managers.py +434 -249
  321. sky/serve/serve_rpc_utils.py +179 -0
  322. sky/serve/serve_state.py +569 -257
  323. sky/serve/serve_utils.py +775 -265
  324. sky/serve/server/core.py +66 -711
  325. sky/serve/server/impl.py +1093 -0
  326. sky/serve/server/server.py +21 -18
  327. sky/serve/service.py +192 -89
  328. sky/serve/service_spec.py +144 -20
  329. sky/serve/spot_placer.py +3 -0
  330. sky/server/auth/__init__.py +0 -0
  331. sky/server/auth/authn.py +50 -0
  332. sky/server/auth/loopback.py +38 -0
  333. sky/server/auth/oauth2_proxy.py +202 -0
  334. sky/server/common.py +478 -182
  335. sky/server/config.py +85 -23
  336. sky/server/constants.py +44 -6
  337. sky/server/daemons.py +295 -0
  338. sky/server/html/token_page.html +185 -0
  339. sky/server/metrics.py +160 -0
  340. sky/server/middleware_utils.py +166 -0
  341. sky/server/requests/executor.py +558 -138
  342. sky/server/requests/payloads.py +364 -24
  343. sky/server/requests/preconditions.py +21 -17
  344. sky/server/requests/process.py +112 -29
  345. sky/server/requests/request_names.py +121 -0
  346. sky/server/requests/requests.py +822 -226
  347. sky/server/requests/serializers/decoders.py +82 -31
  348. sky/server/requests/serializers/encoders.py +140 -22
  349. sky/server/requests/threads.py +117 -0
  350. sky/server/rest.py +455 -0
  351. sky/server/server.py +1309 -285
  352. sky/server/state.py +20 -0
  353. sky/server/stream_utils.py +327 -61
  354. sky/server/uvicorn.py +217 -3
  355. sky/server/versions.py +270 -0
  356. sky/setup_files/MANIFEST.in +11 -1
  357. sky/setup_files/alembic.ini +160 -0
  358. sky/setup_files/dependencies.py +139 -31
  359. sky/setup_files/setup.py +44 -42
  360. sky/sky_logging.py +114 -7
  361. sky/skylet/attempt_skylet.py +106 -24
  362. sky/skylet/autostop_lib.py +129 -8
  363. sky/skylet/configs.py +29 -20
  364. sky/skylet/constants.py +216 -25
  365. sky/skylet/events.py +101 -21
  366. sky/skylet/job_lib.py +345 -164
  367. sky/skylet/log_lib.py +297 -18
  368. sky/skylet/log_lib.pyi +44 -1
  369. sky/skylet/providers/ibm/node_provider.py +12 -8
  370. sky/skylet/providers/ibm/vpc_provider.py +13 -12
  371. sky/skylet/ray_patches/__init__.py +17 -3
  372. sky/skylet/ray_patches/autoscaler.py.diff +18 -0
  373. sky/skylet/ray_patches/cli.py.diff +19 -0
  374. sky/skylet/ray_patches/command_runner.py.diff +17 -0
  375. sky/skylet/ray_patches/log_monitor.py.diff +20 -0
  376. sky/skylet/ray_patches/resource_demand_scheduler.py.diff +32 -0
  377. sky/skylet/ray_patches/updater.py.diff +18 -0
  378. sky/skylet/ray_patches/worker.py.diff +41 -0
  379. sky/skylet/runtime_utils.py +21 -0
  380. sky/skylet/services.py +568 -0
  381. sky/skylet/skylet.py +72 -4
  382. sky/skylet/subprocess_daemon.py +104 -29
  383. sky/skypilot_config.py +506 -99
  384. sky/ssh_node_pools/__init__.py +1 -0
  385. sky/ssh_node_pools/core.py +135 -0
  386. sky/ssh_node_pools/server.py +233 -0
  387. sky/task.py +685 -163
  388. sky/templates/aws-ray.yml.j2 +11 -3
  389. sky/templates/azure-ray.yml.j2 +2 -1
  390. sky/templates/cudo-ray.yml.j2 +1 -0
  391. sky/templates/do-ray.yml.j2 +2 -1
  392. sky/templates/fluidstack-ray.yml.j2 +1 -0
  393. sky/templates/gcp-ray.yml.j2 +62 -1
  394. sky/templates/hyperbolic-ray.yml.j2 +68 -0
  395. sky/templates/ibm-ray.yml.j2 +2 -1
  396. sky/templates/jobs-controller.yaml.j2 +27 -24
  397. sky/templates/kubernetes-loadbalancer.yml.j2 +2 -0
  398. sky/templates/kubernetes-ray.yml.j2 +611 -50
  399. sky/templates/lambda-ray.yml.j2 +2 -1
  400. sky/templates/nebius-ray.yml.j2 +34 -12
  401. sky/templates/oci-ray.yml.j2 +1 -0
  402. sky/templates/paperspace-ray.yml.j2 +2 -1
  403. sky/templates/primeintellect-ray.yml.j2 +72 -0
  404. sky/templates/runpod-ray.yml.j2 +10 -1
  405. sky/templates/scp-ray.yml.j2 +4 -50
  406. sky/templates/seeweb-ray.yml.j2 +171 -0
  407. sky/templates/shadeform-ray.yml.j2 +73 -0
  408. sky/templates/sky-serve-controller.yaml.j2 +22 -2
  409. sky/templates/vast-ray.yml.j2 +1 -0
  410. sky/templates/vsphere-ray.yml.j2 +1 -0
  411. sky/templates/websocket_proxy.py +212 -37
  412. sky/usage/usage_lib.py +31 -15
  413. sky/users/__init__.py +0 -0
  414. sky/users/model.conf +15 -0
  415. sky/users/permission.py +397 -0
  416. sky/users/rbac.py +121 -0
  417. sky/users/server.py +720 -0
  418. sky/users/token_service.py +218 -0
  419. sky/utils/accelerator_registry.py +35 -5
  420. sky/utils/admin_policy_utils.py +84 -38
  421. sky/utils/annotations.py +38 -5
  422. sky/utils/asyncio_utils.py +78 -0
  423. sky/utils/atomic.py +1 -1
  424. sky/utils/auth_utils.py +153 -0
  425. sky/utils/benchmark_utils.py +60 -0
  426. sky/utils/cli_utils/status_utils.py +159 -86
  427. sky/utils/cluster_utils.py +31 -9
  428. sky/utils/command_runner.py +354 -68
  429. sky/utils/command_runner.pyi +93 -3
  430. sky/utils/common.py +35 -8
  431. sky/utils/common_utils.py +314 -91
  432. sky/utils/config_utils.py +74 -5
  433. sky/utils/context.py +403 -0
  434. sky/utils/context_utils.py +242 -0
  435. sky/utils/controller_utils.py +383 -89
  436. sky/utils/dag_utils.py +31 -12
  437. sky/utils/db/__init__.py +0 -0
  438. sky/utils/db/db_utils.py +485 -0
  439. sky/utils/db/kv_cache.py +149 -0
  440. sky/utils/db/migration_utils.py +137 -0
  441. sky/utils/directory_utils.py +12 -0
  442. sky/utils/env_options.py +13 -0
  443. sky/utils/git.py +567 -0
  444. sky/utils/git_clone.sh +460 -0
  445. sky/utils/infra_utils.py +195 -0
  446. sky/utils/kubernetes/cleanup-tunnel.sh +62 -0
  447. sky/utils/kubernetes/config_map_utils.py +133 -0
  448. sky/utils/kubernetes/create_cluster.sh +15 -29
  449. sky/utils/kubernetes/delete_cluster.sh +10 -7
  450. sky/utils/kubernetes/deploy_ssh_node_pools.py +1177 -0
  451. sky/utils/kubernetes/exec_kubeconfig_converter.py +22 -31
  452. sky/utils/kubernetes/generate_kind_config.py +6 -66
  453. sky/utils/kubernetes/generate_kubeconfig.sh +4 -1
  454. sky/utils/kubernetes/gpu_labeler.py +18 -8
  455. sky/utils/kubernetes/k8s_gpu_labeler_job.yaml +2 -1
  456. sky/utils/kubernetes/k8s_gpu_labeler_setup.yaml +16 -16
  457. sky/utils/kubernetes/kubernetes_deploy_utils.py +284 -114
  458. sky/utils/kubernetes/rsync_helper.sh +11 -3
  459. sky/utils/kubernetes/ssh-tunnel.sh +379 -0
  460. sky/utils/kubernetes/ssh_utils.py +221 -0
  461. sky/utils/kubernetes_enums.py +8 -15
  462. sky/utils/lock_events.py +94 -0
  463. sky/utils/locks.py +416 -0
  464. sky/utils/log_utils.py +82 -107
  465. sky/utils/perf_utils.py +22 -0
  466. sky/utils/resource_checker.py +298 -0
  467. sky/utils/resources_utils.py +249 -32
  468. sky/utils/rich_utils.py +217 -39
  469. sky/utils/schemas.py +955 -160
  470. sky/utils/serialize_utils.py +16 -0
  471. sky/utils/status_lib.py +10 -0
  472. sky/utils/subprocess_utils.py +29 -15
  473. sky/utils/tempstore.py +70 -0
  474. sky/utils/thread_utils.py +91 -0
  475. sky/utils/timeline.py +26 -53
  476. sky/utils/ux_utils.py +84 -15
  477. sky/utils/validator.py +11 -1
  478. sky/utils/volume.py +165 -0
  479. sky/utils/yaml_utils.py +111 -0
  480. sky/volumes/__init__.py +13 -0
  481. sky/volumes/client/__init__.py +0 -0
  482. sky/volumes/client/sdk.py +150 -0
  483. sky/volumes/server/__init__.py +0 -0
  484. sky/volumes/server/core.py +270 -0
  485. sky/volumes/server/server.py +124 -0
  486. sky/volumes/volume.py +215 -0
  487. sky/workspaces/__init__.py +0 -0
  488. sky/workspaces/core.py +655 -0
  489. sky/workspaces/server.py +101 -0
  490. sky/workspaces/utils.py +56 -0
  491. sky_templates/README.md +3 -0
  492. sky_templates/__init__.py +3 -0
  493. sky_templates/ray/__init__.py +0 -0
  494. sky_templates/ray/start_cluster +183 -0
  495. sky_templates/ray/stop_cluster +75 -0
  496. skypilot_nightly-1.0.0.dev20251203.dist-info/METADATA +676 -0
  497. skypilot_nightly-1.0.0.dev20251203.dist-info/RECORD +611 -0
  498. {skypilot_nightly-1.0.0.dev20250502.dist-info → skypilot_nightly-1.0.0.dev20251203.dist-info}/WHEEL +1 -1
  499. skypilot_nightly-1.0.0.dev20251203.dist-info/top_level.txt +2 -0
  500. sky/benchmark/benchmark_state.py +0 -256
  501. sky/benchmark/benchmark_utils.py +0 -641
  502. sky/clouds/service_catalog/constants.py +0 -7
  503. sky/dashboard/out/_next/static/GWvVBSCS7FmUiVmjaL1a7/_buildManifest.js +0 -1
  504. sky/dashboard/out/_next/static/chunks/236-2db3ee3fba33dd9e.js +0 -6
  505. sky/dashboard/out/_next/static/chunks/312-c3c8845990db8ffc.js +0 -15
  506. sky/dashboard/out/_next/static/chunks/37-0a572fe0dbb89c4d.js +0 -6
  507. sky/dashboard/out/_next/static/chunks/678-206dddca808e6d16.js +0 -59
  508. sky/dashboard/out/_next/static/chunks/845-9e60713e0c441abc.js +0 -1
  509. sky/dashboard/out/_next/static/chunks/979-7bf73a4c7cea0f5c.js +0 -1
  510. sky/dashboard/out/_next/static/chunks/fd9d1056-2821b0f0cabcd8bd.js +0 -1
  511. sky/dashboard/out/_next/static/chunks/framework-87d061ee6ed71b28.js +0 -33
  512. sky/dashboard/out/_next/static/chunks/main-app-241eb28595532291.js +0 -1
  513. sky/dashboard/out/_next/static/chunks/main-e0e2335212e72357.js +0 -1
  514. sky/dashboard/out/_next/static/chunks/pages/_app-e6b013bc3f77ad60.js +0 -1
  515. sky/dashboard/out/_next/static/chunks/pages/_error-1be831200e60c5c0.js +0 -1
  516. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]/[job]-6ac338bc2239cb45.js +0 -1
  517. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]-f383db7389368ea7.js +0 -1
  518. sky/dashboard/out/_next/static/chunks/pages/clusters-a93b93e10b8b074e.js +0 -1
  519. sky/dashboard/out/_next/static/chunks/pages/index-f9f039532ca8cbc4.js +0 -1
  520. sky/dashboard/out/_next/static/chunks/pages/jobs/[job]-1c519e1afc523dc9.js +0 -1
  521. sky/dashboard/out/_next/static/chunks/pages/jobs-a75029b67aab6a2e.js +0 -1
  522. sky/dashboard/out/_next/static/chunks/webpack-830f59b8404e96b8.js +0 -1
  523. sky/dashboard/out/_next/static/css/c6933bbb2ce7f4dd.css +0 -3
  524. sky/jobs/dashboard/dashboard.py +0 -223
  525. sky/jobs/dashboard/static/favicon.ico +0 -0
  526. sky/jobs/dashboard/templates/index.html +0 -831
  527. sky/jobs/server/dashboard_utils.py +0 -69
  528. sky/skylet/providers/scp/__init__.py +0 -2
  529. sky/skylet/providers/scp/config.py +0 -149
  530. sky/skylet/providers/scp/node_provider.py +0 -578
  531. sky/templates/kubernetes-ssh-jump.yml.j2 +0 -94
  532. sky/utils/db_utils.py +0 -100
  533. sky/utils/kubernetes/deploy_remote_cluster.sh +0 -308
  534. sky/utils/kubernetes/ssh_jump_lifecycle_manager.py +0 -191
  535. skypilot_nightly-1.0.0.dev20250502.dist-info/METADATA +0 -361
  536. skypilot_nightly-1.0.0.dev20250502.dist-info/RECORD +0 -396
  537. skypilot_nightly-1.0.0.dev20250502.dist-info/top_level.txt +0 -1
  538. /sky/{clouds/service_catalog → catalog}/config.py +0 -0
  539. /sky/{benchmark → catalog/data_fetchers}/__init__.py +0 -0
  540. /sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_azure.py +0 -0
  541. /sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_fluidstack.py +0 -0
  542. /sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_ibm.py +0 -0
  543. /sky/{clouds/service_catalog/data_fetchers → client/cli}/__init__.py +0 -0
  544. /sky/dashboard/out/_next/static/{GWvVBSCS7FmUiVmjaL1a7 → 96_E2yl3QAiIJGOYCkSpB}/_ssgManifest.js +0 -0
  545. {skypilot_nightly-1.0.0.dev20250502.dist-info → skypilot_nightly-1.0.0.dev20251203.dist-info}/entry_points.txt +0 -0
  546. {skypilot_nightly-1.0.0.dev20250502.dist-info → skypilot_nightly-1.0.0.dev20251203.dist-info}/licenses/LICENSE +0 -0
sky/client/sdk.py CHANGED
@@ -10,75 +10,165 @@ Usage example:
10
10
  statuses = sky.get(request_id)
11
11
 
12
12
  """
13
- import getpass
13
+ from http import cookiejar
14
14
  import json
15
15
  import logging
16
16
  import os
17
- import pathlib
18
17
  import subprocess
19
18
  import typing
20
- from typing import Any, Dict, List, Optional, Tuple, Union
21
- import webbrowser
19
+ from typing import (Any, Dict, Iterator, List, Literal, Optional, Tuple,
20
+ TypeVar, Union)
21
+ from urllib import parse as urlparse
22
22
 
23
23
  import click
24
24
  import colorama
25
25
  import filelock
26
26
 
27
27
  from sky import admin_policy
28
- from sky import backends
29
28
  from sky import exceptions
30
29
  from sky import sky_logging
31
30
  from sky import skypilot_config
32
31
  from sky.adaptors import common as adaptors_common
33
32
  from sky.client import common as client_common
33
+ from sky.client import oauth as oauth_lib
34
+ from sky.jobs import scheduler
35
+ from sky.jobs import utils as managed_job_utils
36
+ from sky.schemas.api import responses
34
37
  from sky.server import common as server_common
38
+ from sky.server import rest
39
+ from sky.server import versions
35
40
  from sky.server.requests import payloads
41
+ from sky.server.requests import request_names
36
42
  from sky.server.requests import requests as requests_lib
43
+ from sky.skylet import autostop_lib
37
44
  from sky.skylet import constants
38
45
  from sky.usage import usage_lib
46
+ from sky.utils import admin_policy_utils
39
47
  from sky.utils import annotations
40
48
  from sky.utils import cluster_utils
41
49
  from sky.utils import common
42
50
  from sky.utils import common_utils
51
+ from sky.utils import context as sky_context
43
52
  from sky.utils import dag_utils
44
53
  from sky.utils import env_options
54
+ from sky.utils import infra_utils
45
55
  from sky.utils import rich_utils
46
56
  from sky.utils import status_lib
47
57
  from sky.utils import subprocess_utils
48
58
  from sky.utils import ux_utils
59
+ from sky.utils import yaml_utils
60
+ from sky.utils.kubernetes import ssh_utils
49
61
 
50
62
  if typing.TYPE_CHECKING:
63
+ import base64
64
+ import binascii
51
65
  import io
66
+ import pathlib
67
+ import time
68
+ import webbrowser
52
69
 
53
70
  import psutil
54
71
  import requests
55
72
 
56
73
  import sky
74
+ from sky import backends
75
+ from sky import catalog
76
+ from sky import models
77
+ from sky.provision.kubernetes import utils as kubernetes_utils
78
+ from sky.skylet import job_lib
57
79
  else:
80
+ # only used in api_login()
81
+ base64 = adaptors_common.LazyImport('base64')
82
+ binascii = adaptors_common.LazyImport('binascii')
83
+ pathlib = adaptors_common.LazyImport('pathlib')
84
+ time = adaptors_common.LazyImport('time')
85
+ # only used in dashboard() and api_login()
86
+ webbrowser = adaptors_common.LazyImport('webbrowser')
87
+ # only used in api_stop()
58
88
  psutil = adaptors_common.LazyImport('psutil')
59
- requests = adaptors_common.LazyImport('requests')
60
89
 
61
90
  logger = sky_logging.init_logger(__name__)
62
91
  logging.getLogger('httpx').setLevel(logging.CRITICAL)
63
92
 
93
+ _LINE_PROCESSED_KEY = 'line_processed'
64
94
 
65
- def stream_response(request_id: Optional[str],
95
+ T = TypeVar('T')
96
+
97
+
98
+ def reload_config() -> None:
99
+ """Reloads the client-side config."""
100
+ skypilot_config.safe_reload_config()
101
+
102
+
103
+ # The overloads are not comprehensive - e.g. get_result Literal[False] could be
104
+ # specified to return None. We can add more overloads if needed. To do that see
105
+ # https://github.com/python/mypy/issues/8634#issuecomment-609411104
106
+ @typing.overload
107
+ def stream_response(request_id: None,
66
108
  response: 'requests.Response',
67
- output_stream: Optional['io.TextIOBase'] = None) -> Any:
109
+ output_stream: Optional['io.TextIOBase'] = None,
110
+ resumable: bool = False,
111
+ get_result: bool = True) -> None:
112
+ ...
113
+
114
+
115
+ @typing.overload
116
+ def stream_response(request_id: server_common.RequestId[T],
117
+ response: 'requests.Response',
118
+ output_stream: Optional['io.TextIOBase'] = None,
119
+ resumable: bool = False,
120
+ get_result: Literal[True] = True) -> T:
121
+ ...
122
+
123
+
124
+ @typing.overload
125
+ def stream_response(request_id: server_common.RequestId[T],
126
+ response: 'requests.Response',
127
+ output_stream: Optional['io.TextIOBase'] = None,
128
+ resumable: bool = False,
129
+ get_result: bool = True) -> Optional[T]:
130
+ ...
131
+
132
+
133
+ def stream_response(request_id: Optional[server_common.RequestId[T]],
134
+ response: 'requests.Response',
135
+ output_stream: Optional['io.TextIOBase'] = None,
136
+ resumable: bool = False,
137
+ get_result: bool = True) -> Optional[T]:
68
138
  """Streams the response to the console.
69
139
 
70
140
  Args:
71
- request_id: The request ID.
141
+ request_id: The request ID of the request to stream. May be a full
142
+ request ID or a prefix.
143
+ If None, the latest request submitted to the API server is streamed.
144
+ Using None request_id is not recommended in multi-user environments.
72
145
  response: The HTTP response.
73
146
  output_stream: The output stream to write to. If None, print to the
74
147
  console.
148
+ resumable: Whether the response is resumable on retry. If True, the
149
+ streaming will start from the previous failure point on retry.
150
+ get_result: Whether to get the result of the request. This will
151
+ typically be set to False for `--no-follow` flags as requests may
152
+ continue to run for long periods of time without further streaming.
75
153
  """
76
154
 
155
+ retry_context: Optional[rest.RetryContext] = None
156
+ if resumable:
157
+ retry_context = rest.get_retry_context()
77
158
  try:
159
+ line_count = 0
78
160
  for line in rich_utils.decode_rich_status(response):
79
161
  if line is not None:
80
- print(line, flush=True, end='', file=output_stream)
81
- return get(request_id)
162
+ line_count += 1
163
+ if retry_context is None:
164
+ print(line, flush=True, end='', file=output_stream)
165
+ elif line_count > retry_context.line_processed:
166
+ print(line, flush=True, end='', file=output_stream)
167
+ retry_context.line_processed = line_count
168
+ if request_id is not None and get_result:
169
+ return get(request_id)
170
+ else:
171
+ return None
82
172
  except Exception: # pylint: disable=broad-except
83
173
  logger.debug(f'To stream request logs: sky api logs {request_id}')
84
174
  raise
@@ -87,13 +177,18 @@ def stream_response(request_id: Optional[str],
87
177
  @usage_lib.entrypoint
88
178
  @server_common.check_server_healthy_or_start
89
179
  @annotations.client_api
90
- def check(clouds: Optional[Tuple[str]],
91
- verbose: bool) -> server_common.RequestId:
180
+ def check(
181
+ infra_list: Optional[Tuple[str, ...]],
182
+ verbose: bool,
183
+ workspace: Optional[str] = None
184
+ ) -> server_common.RequestId[Dict[str, List[str]]]:
92
185
  """Checks the credentials to enable clouds.
93
186
 
94
187
  Args:
95
- clouds: The clouds to check.
188
+ infra_list: The infra to check. Only the cloud part is used; any region/zone is ignored.
96
189
  verbose: Whether to show verbose output.
190
+ workspace: The workspace to check. If None, all workspaces will be
191
+ checked.
97
192
 
98
193
  Returns:
99
194
  The request ID of the check request.
@@ -101,41 +196,69 @@ def check(clouds: Optional[Tuple[str]],
101
196
  Request Returns:
102
197
  None
103
198
  """
104
- body = payloads.CheckBody(clouds=clouds, verbose=verbose)
105
- response = requests.post(f'{server_common.get_server_url()}/check',
106
- json=json.loads(body.model_dump_json()),
107
- cookies=server_common.get_api_cookie_jar())
199
+ if infra_list is None:
200
+ clouds = None
201
+ else:
202
+ specified_clouds = []
203
+ for infra_str in infra_list:
204
+ infra = infra_utils.InfraInfo.from_str(infra_str)
205
+ if infra.cloud is None:
206
+ with ux_utils.print_exception_no_traceback():
207
+ raise ValueError(f'Invalid infra to check: {infra_str}')
208
+ if infra.region is not None or infra.zone is not None:
209
+ region_zone = infra_str.partition('/')[-1]
210
+ logger.warning(f'Infra {infra_str} is specified, but `check` '
211
+ f'only supports checking {infra.cloud}, '
212
+ f'ignoring {region_zone}')
213
+ specified_clouds.append(infra.cloud)
214
+ clouds = tuple(specified_clouds)
215
+ body = payloads.CheckBody(clouds=clouds,
216
+ verbose=verbose,
217
+ workspace=workspace)
218
+ response = server_common.make_authenticated_request(
219
+ 'POST', '/check', json=json.loads(body.model_dump_json()))
108
220
  return server_common.get_request_id(response)
109
221
 
110
222
 
111
223
  @usage_lib.entrypoint
112
224
  @server_common.check_server_healthy_or_start
113
225
  @annotations.client_api
114
- def enabled_clouds() -> server_common.RequestId:
226
+ def enabled_clouds(workspace: Optional[str] = None,
227
+ expand: bool = False) -> server_common.RequestId[List[str]]:
115
228
  """Gets the enabled clouds.
116
229
 
230
+ Args:
231
+ workspace: The workspace to get the enabled clouds for. If None, the
232
+ active workspace will be used.
233
+ expand: Whether to expand Kubernetes and SSH to a list of resource pools.
234
+
117
235
  Returns:
118
236
  The request ID of the enabled clouds request.
119
237
 
120
238
  Request Returns:
121
239
  A list of enabled clouds in string format.
122
240
  """
123
- response = requests.get(f'{server_common.get_server_url()}/enabled_clouds',
124
- cookies=server_common.get_api_cookie_jar())
241
+ if workspace is None:
242
+ workspace = skypilot_config.get_active_workspace()
243
+ response = server_common.make_authenticated_request(
244
+ 'GET', f'/enabled_clouds?workspace={workspace}&expand={expand}')
125
245
  return server_common.get_request_id(response)
126
246
 
127
247
 
128
248
  @usage_lib.entrypoint
129
249
  @server_common.check_server_healthy_or_start
130
250
  @annotations.client_api
131
- def list_accelerators(gpus_only: bool = True,
132
- name_filter: Optional[str] = None,
133
- region_filter: Optional[str] = None,
134
- quantity_filter: Optional[int] = None,
135
- clouds: Optional[Union[List[str], str]] = None,
136
- all_regions: bool = False,
137
- require_price: bool = True,
138
- case_sensitive: bool = True) -> server_common.RequestId:
251
+ def list_accelerators(
252
+ gpus_only: bool = True,
253
+ name_filter: Optional[str] = None,
254
+ region_filter: Optional[str] = None,
255
+ quantity_filter: Optional[int] = None,
256
+ clouds: Optional[Union[List[str], str]] = None,
257
+ all_regions: bool = False,
258
+ require_price: bool = True,
259
+ case_sensitive: bool = True
260
+ ) -> server_common.RequestId[Dict[str,
261
+ List['catalog.common.InstanceTypeInfo']]]:
139
262
  """Lists the names of all accelerators offered by Sky.
140
263
 
141
264
  This will include all accelerators offered by Sky, including those
@@ -169,10 +292,8 @@ def list_accelerators(gpus_only: bool = True,
169
292
  require_price=require_price,
170
293
  case_sensitive=case_sensitive,
171
294
  )
172
- response = requests.post(
173
- f'{server_common.get_server_url()}/list_accelerators',
174
- json=json.loads(body.model_dump_json()),
175
- cookies=server_common.get_api_cookie_jar())
295
+ response = server_common.make_authenticated_request(
296
+ 'POST', '/list_accelerators', json=json.loads(body.model_dump_json()))
176
297
  return server_common.get_request_id(response)
177
298
 
178
299
 
@@ -180,12 +301,12 @@ def list_accelerators(gpus_only: bool = True,
180
301
  @server_common.check_server_healthy_or_start
181
302
  @annotations.client_api
182
303
  def list_accelerator_counts(
183
- gpus_only: bool = True,
184
- name_filter: Optional[str] = None,
185
- region_filter: Optional[str] = None,
186
- quantity_filter: Optional[int] = None,
187
- clouds: Optional[Union[List[str],
188
- str]] = None) -> server_common.RequestId:
304
+ gpus_only: bool = True,
305
+ name_filter: Optional[str] = None,
306
+ region_filter: Optional[str] = None,
307
+ quantity_filter: Optional[int] = None,
308
+ clouds: Optional[Union[List[str], str]] = None
309
+ ) -> server_common.RequestId[Dict[str, List[float]]]:
189
310
  """Lists all accelerators offered by Sky and available counts.
190
311
 
191
312
  Args:
@@ -203,17 +324,17 @@ def list_accelerator_counts(
203
324
  accelerator names mapped to a list of available counts. See usage
204
325
  in cli.py.
205
326
  """
206
- body = payloads.ListAcceleratorsBody(
327
+ body = payloads.ListAcceleratorCountsBody(
207
328
  gpus_only=gpus_only,
208
329
  name_filter=name_filter,
209
330
  region_filter=region_filter,
210
331
  quantity_filter=quantity_filter,
211
332
  clouds=clouds,
212
333
  )
213
- response = requests.post(
214
- f'{server_common.get_server_url()}/list_accelerator_counts',
215
- json=json.loads(body.model_dump_json()),
216
- cookies=server_common.get_api_cookie_jar())
334
+ response = server_common.make_authenticated_request(
335
+ 'POST',
336
+ '/list_accelerator_counts',
337
+ json=json.loads(body.model_dump_json()))
217
338
  return server_common.get_request_id(response)
218
339
 
219
340
 
@@ -224,7 +345,7 @@ def optimize(
224
345
  dag: 'sky.Dag',
225
346
  minimize: common.OptimizeTarget = common.OptimizeTarget.COST,
226
347
  admin_policy_request_options: Optional[admin_policy.RequestOptions] = None
227
- ) -> server_common.RequestId:
348
+ ) -> server_common.RequestId['sky.Dag']:
228
349
  """Finds the best execution plan for the given DAG.
229
350
 
230
351
  Args:
@@ -250,12 +371,27 @@ def optimize(
250
371
  body = payloads.OptimizeBody(dag=dag_str,
251
372
  minimize=minimize,
252
373
  request_options=admin_policy_request_options)
253
- response = requests.post(f'{server_common.get_server_url()}/optimize',
254
- json=json.loads(body.model_dump_json()),
255
- cookies=server_common.get_api_cookie_jar())
374
+ response = server_common.make_authenticated_request(
375
+ 'POST', '/optimize', json=json.loads(body.model_dump_json()))
256
376
  return server_common.get_request_id(response)
257
377
 
258
378
 
379
+ def workspaces() -> server_common.RequestId[Dict[str, Any]]:
380
+ """Gets the workspaces."""
381
+ response = server_common.make_authenticated_request('GET', '/workspaces')
382
+ return server_common.get_request_id(response)
383
+
384
+
385
+ def _raise_exception_object_on_client(e: BaseException) -> None:
386
+ """Raise the exception object on the client."""
387
+ if env_options.Options.SHOW_DEBUG_INFO.get():
388
+ stacktrace = getattr(e, 'stacktrace', str(e))
389
+ logger.error('=== Traceback on SkyPilot API Server ===\n'
390
+ f'{stacktrace}')
391
+ with ux_utils.print_exception_no_traceback():
392
+ raise e
393
+
394
+
259
395
  @usage_lib.entrypoint
260
396
  @server_common.check_server_healthy_or_start
261
397
  @annotations.client_api
@@ -279,29 +415,35 @@ def validate(
279
415
  validation. This is only required when a admin policy is in use,
280
416
  see: https://docs.skypilot.co/en/latest/cloud-setup/policy.html
281
417
  """
418
+ remote_api_version = versions.get_remote_api_version()
419
+ # TODO(kevin): remove this in v0.13.0
420
+ omit_user_specified_yaml = (remote_api_version is None or
421
+ remote_api_version < 15)
282
422
  for task in dag.tasks:
423
+ if omit_user_specified_yaml:
424
+ # pylint: disable=protected-access
425
+ task._user_specified_yaml = None
283
426
  task.expand_and_validate_workdir()
284
427
  if not workdir_only:
285
428
  task.expand_and_validate_file_mounts()
286
429
  dag_str = dag_utils.dump_chain_dag_to_yaml_str(dag)
287
430
  body = payloads.ValidateBody(dag=dag_str,
288
431
  request_options=admin_policy_request_options)
289
- response = requests.post(f'{server_common.get_server_url()}/validate',
290
- json=json.loads(body.model_dump_json()),
291
- cookies=server_common.get_api_cookie_jar())
432
+ response = server_common.make_authenticated_request(
433
+ 'POST', '/validate', json=json.loads(body.model_dump_json()))
292
434
  if response.status_code == 400:
293
- with ux_utils.print_exception_no_traceback():
294
- raise exceptions.deserialize_exception(
295
- response.json().get('detail'))
435
+ _raise_exception_object_on_client(
436
+ exceptions.deserialize_exception(response.json().get('detail')))
296
437
 
297
438
 
298
439
  @usage_lib.entrypoint
299
440
  @server_common.check_server_healthy_or_start
300
441
  @annotations.client_api
301
- def dashboard() -> None:
442
+ def dashboard(starting_page: Optional[str] = None) -> None:
302
443
  """Starts the dashboard for SkyPilot."""
303
444
  api_server_url = server_common.get_server_url()
304
- url = server_common.get_dashboard_url(api_server_url)
445
+ url = server_common.get_dashboard_url(api_server_url,
446
+ starting_page=starting_page)
305
447
  logger.info(f'Opening dashboard in browser: {url}')
306
448
  webbrowser.open(url)
307
449
 
@@ -309,14 +451,16 @@ def dashboard() -> None:
309
451
  @usage_lib.entrypoint
310
452
  @server_common.check_server_healthy_or_start
311
453
  @annotations.client_api
454
+ @sky_context.contextual
312
455
  def launch(
313
456
  task: Union['sky.Task', 'sky.Dag'],
314
457
  cluster_name: Optional[str] = None,
315
458
  retry_until_up: bool = False,
316
459
  idle_minutes_to_autostop: Optional[int] = None,
460
+ wait_for: Optional[autostop_lib.AutostopWaitFor] = None,
317
461
  dryrun: bool = False,
318
462
  down: bool = False, # pylint: disable=redefined-outer-name
319
- backend: Optional[backends.Backend] = None,
463
+ backend: Optional['backends.Backend'] = None,
320
464
  optimize_target: common.OptimizeTarget = common.OptimizeTarget.COST,
321
465
  no_setup: bool = False,
322
466
  clone_disk_from: Optional[str] = None,
@@ -327,7 +471,8 @@ def launch(
327
471
  _is_launched_by_jobs_controller: bool = False,
328
472
  _is_launched_by_sky_serve_controller: bool = False,
329
473
  _disable_controller_check: bool = False,
330
- ) -> server_common.RequestId:
474
+ ) -> server_common.RequestId[Tuple[Optional[int],
475
+ Optional['backends.ResourceHandle']]]:
331
476
  """Launches a cluster or task.
332
477
 
333
478
  The task's setup and run commands are executed under the task's workdir
@@ -344,7 +489,7 @@ def launch(
344
489
  import sky
345
490
  task = sky.Task(run='echo hello SkyPilot')
346
491
  task.set_resources(
347
- sky.Resources(cloud=sky.AWS(), accelerators='V100:4'))
492
+ sky.Resources(infra='aws', accelerators='V100:4'))
348
493
  sky.launch(task, cluster_name='my-cluster')
349
494
 
350
495
 
@@ -355,18 +500,31 @@ def launch(
355
500
  retry_until_up: whether to retry launching the cluster until it is
356
501
  up.
357
502
  idle_minutes_to_autostop: automatically stop the cluster after this
358
- many minute of idleness, i.e., no running or pending jobs in the
359
- cluster's job queue. Idleness gets reset whenever setting-up/
360
- running/pending jobs are found in the job queue. Setting this
361
- flag is equivalent to running ``sky.launch()`` and then
362
- ``sky.autostop(idle_minutes=<minutes>)``. If not set, the cluster
363
- will not be autostopped.
503
+ many minutes of idleness, i.e., no running or pending jobs in the
504
+ cluster's job queue. Idleness gets reset whenever setting-up/
505
+ running/pending jobs are found in the job queue. Setting this
506
+ flag is equivalent to running
507
+ ``sky.launch(...)`` and then
508
+ ``sky.autostop(idle_minutes=<minutes>)``. If set, the autostop
509
+ config specified in the task's resources will be overridden by
510
+ this parameter.
511
+ wait_for: determines the condition for resetting the idleness timer.
512
+ This option works in conjunction with ``idle_minutes_to_autostop``.
513
+ Choices:
514
+
515
+ 1. "jobs_and_ssh" (default) - Wait for in-progress jobs and SSH
516
+ connections to finish.
517
+ 2. "jobs" - Only wait for in-progress jobs.
518
+ 3. "none" - Wait for nothing; autostop right after
519
+ ``idle_minutes_to_autostop``.
364
520
  dryrun: if True, do not actually launch the cluster.
365
521
  down: Tear down the cluster after all jobs finish (successfully or
366
- abnormally). If --idle-minutes-to-autostop is also set, the
367
- cluster will be torn down after the specified idle time.
368
- Note that if errors occur during provisioning/data syncing/setting
369
- up, the cluster will not be torn down for debugging purposes.
522
+ abnormally). If --idle-minutes-to-autostop is also set, the
523
+ cluster will be torn down after the specified idle time.
524
+ Note that if errors occur during provisioning/data syncing/setting
525
+ up, the cluster will not be torn down for debugging purposes. If
526
+ set, the autostop config specified in the task's resources will be
527
+ overridden by this parameter.
370
528
  backend: backend to use. If None, use the default backend
371
529
  (CloudVMRayBackend).
372
530
  optimize_target: target to optimize for. Choices: OptimizeTarget.COST,
@@ -422,35 +580,115 @@ def launch(
422
580
  raise NotImplementedError('clone_disk_from is not implemented yet. '
423
581
  'Please contact the SkyPilot team if you '
424
582
  'need this feature at slack.skypilot.co.')
583
+
584
+ remote_api_version = versions.get_remote_api_version()
585
+ if wait_for is not None and (remote_api_version is None or
586
+ remote_api_version < 13):
587
+ logger.warning('wait_for is not supported in your API server. '
588
+ 'Please upgrade to a newer API server to use it.')
589
+
425
590
  dag = dag_utils.convert_entrypoint_to_dag(task)
591
+ # Override the autostop config from command line flags to task YAML.
592
+ for task in dag.tasks:
593
+ for resource in task.resources:
594
+ if remote_api_version is None or remote_api_version < 13:
595
+ # An older server would not recognize the wait_for field
596
+ # in the schema, so we need to omit it.
597
+ resource.override_autostop_config(
598
+ down=down, idle_minutes=idle_minutes_to_autostop)
599
+ else:
600
+ resource.override_autostop_config(
601
+ down=down,
602
+ idle_minutes=idle_minutes_to_autostop,
603
+ wait_for=wait_for)
604
+ if resource.autostop_config is not None:
605
+ # For backward-compatibility, get the final autostop config for
606
+ # admin policy.
607
+ # TODO(aylei): remove this after 0.12.0
608
+ down = resource.autostop_config.down
609
+ idle_minutes_to_autostop = resource.autostop_config.idle_minutes
610
+
426
611
  request_options = admin_policy.RequestOptions(
427
612
  cluster_name=cluster_name,
428
613
  idle_minutes_to_autostop=idle_minutes_to_autostop,
429
614
  down=down,
430
615
  dryrun=dryrun)
616
+ with admin_policy_utils.apply_and_use_config_in_current_request(
617
+ dag,
618
+ request_name=request_names.AdminPolicyRequestName.CLUSTER_LAUNCH,
619
+ request_options=request_options,
620
+ at_client_side=True) as dag:
621
+ return _launch(
622
+ dag,
623
+ cluster_name,
624
+ request_options,
625
+ retry_until_up,
626
+ idle_minutes_to_autostop,
627
+ dryrun,
628
+ down,
629
+ backend,
630
+ optimize_target,
631
+ no_setup,
632
+ clone_disk_from,
633
+ fast,
634
+ _need_confirmation,
635
+ _is_launched_by_jobs_controller,
636
+ _is_launched_by_sky_serve_controller,
637
+ _disable_controller_check,
638
+ )
639
+
640
+
641
+ def _launch(
642
+ dag: 'sky.Dag',
643
+ cluster_name: str,
644
+ request_options: admin_policy.RequestOptions,
645
+ retry_until_up: bool = False,
646
+ idle_minutes_to_autostop: Optional[int] = None,
647
+ dryrun: bool = False,
648
+ down: bool = False, # pylint: disable=redefined-outer-name
649
+ backend: Optional['backends.Backend'] = None,
650
+ optimize_target: common.OptimizeTarget = common.OptimizeTarget.COST,
651
+ no_setup: bool = False,
652
+ clone_disk_from: Optional[str] = None,
653
+ fast: bool = False,
654
+ # Internal only:
655
+ # pylint: disable=invalid-name
656
+ _need_confirmation: bool = False,
657
+ _is_launched_by_jobs_controller: bool = False,
658
+ _is_launched_by_sky_serve_controller: bool = False,
659
+ _disable_controller_check: bool = False,
660
+ ) -> server_common.RequestId[Tuple[Optional[int],
661
+ Optional['backends.ResourceHandle']]]:
662
+ """Auxiliary function for launch(), refer to launch() for details."""
663
+
431
664
  validate(dag, admin_policy_request_options=request_options)
665
+ # The flags have been applied to the task YAML and the backward
666
+ # compatibility of admin policy has been handled. We should no longer use
667
+ # these flags.
668
+ del down, idle_minutes_to_autostop
432
669
 
433
670
  confirm_shown = False
434
671
  if _need_confirmation:
435
672
  cluster_status = None
436
673
  # TODO(SKY-998): we should reduce RTTs before launching the cluster.
437
- request_id = status([cluster_name], all_users=True)
438
- clusters = get(request_id)
674
+ status_request_id = status([cluster_name], all_users=True)
675
+ clusters = get(status_request_id)
439
676
  cluster_user_hash = common_utils.get_user_hash()
440
677
  cluster_user_hash_str = ''
441
- cluster_user_name = getpass.getuser()
678
+ current_user = common_utils.get_current_user_name()
679
+ cluster_user_name = current_user
442
680
  if not clusters:
443
681
  # Show the optimize log before the prompt if the cluster does not
444
682
  # exist.
445
- request_id = optimize(dag,
446
- admin_policy_request_options=request_options)
447
- stream_and_get(request_id)
683
+ optimize_request_id = optimize(
684
+ dag, admin_policy_request_options=request_options)
685
+ stream_and_get(optimize_request_id)
448
686
  else:
449
687
  cluster_record = clusters[0]
450
688
  cluster_status = cluster_record['status']
451
689
  cluster_user_hash = cluster_record['user_hash']
452
690
  cluster_user_name = cluster_record['user_name']
453
- if cluster_user_name == getpass.getuser():
691
+ if cluster_user_name == current_user:
454
692
  # Only show the hash if the username is the same as the local
455
693
  # username, to avoid confusion.
456
694
  cluster_user_hash_str = f' (hash: {cluster_user_hash})'
@@ -492,9 +730,7 @@ def launch(
492
730
  task=dag_str,
493
731
  cluster_name=cluster_name,
494
732
  retry_until_up=retry_until_up,
495
- idle_minutes_to_autostop=idle_minutes_to_autostop,
496
733
  dryrun=dryrun,
497
- down=down,
498
734
  backend=backend.NAME if backend else None,
499
735
  optimize_target=optimize_target,
500
736
  no_setup=no_setup,
@@ -507,12 +743,8 @@ def launch(
507
743
  _is_launched_by_sky_serve_controller),
508
744
  disable_controller_check=_disable_controller_check,
509
745
  )
510
- response = requests.post(
511
- f'{server_common.get_server_url()}/launch',
512
- json=json.loads(body.model_dump_json()),
513
- timeout=5,
514
- cookies=server_common.get_api_cookie_jar(),
515
- )
746
+ response = server_common.make_authenticated_request(
747
+ 'POST', '/launch', json=json.loads(body.model_dump_json()), timeout=5)
516
748
  return server_common.get_request_id(response)
517
749
 
518
750
 
@@ -524,8 +756,9 @@ def exec( # pylint: disable=redefined-builtin
524
756
  cluster_name: Optional[str] = None,
525
757
  dryrun: bool = False,
526
758
  down: bool = False, # pylint: disable=redefined-outer-name
527
- backend: Optional[backends.Backend] = None,
528
- ) -> server_common.RequestId:
759
+ backend: Optional['backends.Backend'] = None,
760
+ ) -> server_common.RequestId[Tuple[Optional[int],
761
+ Optional['backends.ResourceHandle']]]:
529
762
  """Executes a task on an existing cluster.
530
763
 
531
764
  This function performs two actions:
@@ -591,23 +824,49 @@ def exec( # pylint: disable=redefined-builtin
591
824
  backend=backend.NAME if backend else None,
592
825
  )
593
826
 
594
- response = requests.post(
595
- f'{server_common.get_server_url()}/exec',
596
- json=json.loads(body.model_dump_json()),
597
- timeout=5,
598
- cookies=server_common.get_api_cookie_jar(),
599
- )
827
+ response = server_common.make_authenticated_request(
828
+ 'POST', '/exec', json=json.loads(body.model_dump_json()), timeout=5)
600
829
  return server_common.get_request_id(response)
601
830
 
602
831
 
603
- @usage_lib.entrypoint
604
- @server_common.check_server_healthy_or_start
605
- @annotations.client_api
832
+ @typing.overload
833
+ def tail_logs(
834
+ cluster_name: str,
835
+ job_id: Optional[int],
836
+ follow: bool,
837
+ tail: int = 0,
838
+ output_stream: Optional['io.TextIOBase'] = None,
839
+ *, # keyword only separator
840
+ preload_content: Literal[True] = True) -> int:
841
+ ...
842
+
843
+
844
+ @typing.overload
606
845
  def tail_logs(cluster_name: str,
607
846
  job_id: Optional[int],
608
847
  follow: bool,
609
848
  tail: int = 0,
610
- output_stream: Optional['io.TextIOBase'] = None) -> int:
849
+ output_stream: None = None,
850
+ *,
851
+ preload_content: Literal[False]) -> Iterator[Optional[str]]:
852
+ ...
853
+
854
+
855
+ # TODO(aylei): when a logs request is retried, there will be duplicated log entries.
856
+ # We should fix this.
857
+ @usage_lib.entrypoint
858
+ @server_common.check_server_healthy_or_start
859
+ @annotations.client_api
860
+ @rest.retry_transient_errors()
861
+ def tail_logs(
862
+ cluster_name: str,
863
+ job_id: Optional[int],
864
+ follow: bool,
865
+ tail: int = 0,
866
+ output_stream: Optional['io.TextIOBase'] = None,
867
+ *, # keyword only separator
868
+ preload_content: bool = True
869
+ ) -> Union[int, Iterator[Optional[str]]]:
611
870
  """Tails the logs of a job.
612
871
 
613
872
  Args:
@@ -617,12 +876,21 @@ def tail_logs(cluster_name: str,
617
876
  immediately.
618
877
  tail: if > 0, tail the last N lines of the logs.
619
878
  output_stream: the stream to write the logs to. If None, print to the
620
- console.
879
+ console. Cannot be used with preload_content=False.
880
+ preload_content: if False, returns an Iterator[str | None] containing
881
+ the logs without the function blocking on the retrieval of the entire
882
+ log. Iterator returns None when the log has been completely
883
+ streamed. Default True. Cannot be used with output_stream.
621
884
 
622
885
  Returns:
623
- Exit code based on success or failure of the job. 0 if success,
624
- 100 if the job failed. See exceptions.JobExitCode for possible exit
625
- codes.
886
+ If preload_content is True:
887
+ Exit code based on success or failure of the job. 0 if success,
888
+ 100 if the job failed. See exceptions.JobExitCode for possible exit
889
+ codes.
890
+ If preload_content is False:
891
+ Iterator[str | None] containing the logs without the function
892
+ blocking on the retrieval of the entire log. Iterator returns None
893
+ when the log has been completely streamed.
626
894
 
627
895
  Request Raises:
628
896
  ValueError: if arguments are invalid or the cluster is not supported.
@@ -635,21 +903,110 @@ def tail_logs(cluster_name: str,
635
903
  sky.exceptions.CloudUserIdentityError: if we fail to get the current
636
904
  user identity.
637
905
  """
906
+ if output_stream is not None and not preload_content:
907
+ raise ValueError(
908
+ 'output_stream cannot be specified when preload_content is False')
909
+
638
910
  body = payloads.ClusterJobBody(
639
911
  cluster_name=cluster_name,
640
912
  job_id=job_id,
641
913
  follow=follow,
642
914
  tail=tail,
643
915
  )
644
- response = requests.post(
645
- f'{server_common.get_server_url()}/logs',
916
+ response = server_common.make_authenticated_request(
917
+ 'POST',
918
+ '/logs',
646
919
  json=json.loads(body.model_dump_json()),
647
920
  stream=True,
648
921
  timeout=(client_common.API_SERVER_REQUEST_CONNECTION_TIMEOUT_SECONDS,
649
- None),
650
- cookies=server_common.get_api_cookie_jar())
651
- request_id = server_common.get_request_id(response)
652
- return stream_response(request_id, response, output_stream)
922
+ None))
923
+ request_id: server_common.RequestId[int] = server_common.get_request_id(
924
+ response)
925
+ if preload_content:
926
+ # Log request is idempotent when tail is 0, thus can resume previous
927
+ # streaming point on retry.
928
+ return stream_response(request_id=request_id,
929
+ response=response,
930
+ output_stream=output_stream,
931
+ resumable=(tail == 0))
932
+ else:
933
+ return rich_utils.decode_rich_status(response)
934
+
935
+
936
@usage_lib.entrypoint
@server_common.check_server_healthy_or_start
@versions.minimal_api_version(17)
@annotations.client_api
@rest.retry_transient_errors()
def tail_provision_logs(cluster_name: str,
                        worker: Optional[int] = None,
                        follow: bool = True,
                        tail: int = 0,
                        output_stream: Optional['io.TextIOBase'] = None) -> int:
    """Streams the provisioning log (provision.log) of a cluster.

    Args:
        cluster_name: name of the cluster.
        worker: worker id in a multi-node cluster; must be >= 1 when given.
            If None, no worker is set on the request and the logs of the
            head node are streamed.
        follow: whether to follow the logs.
        tail: number of lines from the end of the log to tail.
        output_stream: optional stream to write the logs to.

    Returns:
        0 once the response has been streamed.

    Raises:
        ValueError: if ``worker`` is given but is smaller than 1.
        exceptions.APINotSupportedError: if ``worker`` is given and the
            remote API version is unknown or older than 21.
        exceptions.CommandError: if the server answers with a non-200
            status code.
    """
    request_body = payloads.ProvisionLogsBody(cluster_name=cluster_name)

    if worker is not None:
        # Selecting a worker node requires remote API version 21 or newer.
        server_api_version = versions.get_remote_api_version()
        if server_api_version is None or server_api_version < 21:
            raise exceptions.APINotSupportedError(
                'Worker node provision logs are not supported in your API '
                'server. Please upgrade to a newer API server to use it.')
        if worker < 1:
            raise ValueError('Worker must be a positive integer.')
        request_body.worker = worker

    query_params = {
        'follow': str(follow).lower(),
        'tail': tail,
    }

    response = server_common.make_authenticated_request(
        'POST',
        '/provision_logs',
        json=json.loads(request_body.model_dump_json()),
        params=query_params,
        stream=True,
        # Only the connection is bounded; reading the stream has no timeout.
        timeout=(client_common.API_SERVER_REQUEST_CONNECTION_TIMEOUT_SECONDS,
                 None))

    # Surface HTTP errors before any attempt to stream the body.
    if response.status_code != 200:
        with ux_utils.print_exception_no_traceback():
            raise exceptions.CommandError(response.status_code,
                                          'tail_provision_logs',
                                          'Failed to stream provision logs',
                                          response.text)

    # The log request is idempotent when tail is 0, so streaming can resume
    # from the previous point on retry.
    #
    # request_id is None because /provision_logs does not create an async
    # request; it streams a plain file from the server. This does not conflict
    # with the stream_response warning about None in multi-user environments:
    # stream_response is not asked to pick "the latest request", since the
    # HTTP response to stream is already in hand. A non-None request_id would
    # make the follow-up get(request_id) fail (there is no request record),
    # and None is also what lets --no-follow return cleanly after printing the
    # tailed lines. For the same reason get_result is False, which keeps
    # get() from running.
    stream_response(request_id=None,
                    response=response,
                    output_stream=output_stream,
                    resumable=(tail == 0),
                    get_result=False)
    return 0
653
1010
 
654
1011
 
655
1012
  @usage_lib.entrypoint
@@ -683,11 +1040,11 @@ def download_logs(cluster_name: str,
683
1040
  cluster_name=cluster_name,
684
1041
  job_ids=job_ids,
685
1042
  )
686
- response = requests.post(f'{server_common.get_server_url()}/download_logs',
687
- json=json.loads(body.model_dump_json()),
688
- cookies=server_common.get_api_cookie_jar())
689
- job_id_remote_path_dict = stream_and_get(
690
- server_common.get_request_id(response))
1043
+ response = server_common.make_authenticated_request(
1044
+ 'POST', '/download_logs', json=json.loads(body.model_dump_json()))
1045
+ request_id: server_common.RequestId[Dict[
1046
+ str, str]] = server_common.get_request_id(response)
1047
+ job_id_remote_path_dict = stream_and_get(request_id)
691
1048
  remote2local_path_dict = client_common.download_logs_from_api_server(
692
1049
  job_id_remote_path_dict.values())
693
1050
  return {
@@ -702,10 +1059,11 @@ def download_logs(cluster_name: str,
702
1059
  def start(
703
1060
  cluster_name: str,
704
1061
  idle_minutes_to_autostop: Optional[int] = None,
1062
+ wait_for: Optional[autostop_lib.AutostopWaitFor] = None,
705
1063
  retry_until_up: bool = False,
706
1064
  down: bool = False, # pylint: disable=redefined-outer-name
707
1065
  force: bool = False,
708
- ) -> server_common.RequestId:
1066
+ ) -> server_common.RequestId['backends.CloudVmRayResourceHandle']:
709
1067
  """Restart a cluster.
710
1068
 
711
1069
  If a cluster is previously stopped (status is STOPPED) or failed in
@@ -728,6 +1086,15 @@ def start(
728
1086
  flag is equivalent to running ``sky.launch()`` and then
729
1087
  ``sky.autostop(idle_minutes=<minutes>)``. If not set, the
730
1088
  cluster will not be autostopped.
1089
+ wait_for: determines the condition for resetting the idleness timer.
1090
+ This option works in conjunction with ``idle_minutes_to_autostop``.
1091
+ Choices:
1092
+
1093
+ 1. "jobs_and_ssh" (default) - Wait for in-progress jobs and SSH
1094
+ connections to finish.
1095
+ 2. "jobs" - Only wait for in-progress jobs.
1096
+ 3. "none" - Wait for nothing; autostop right after
1097
+ ``idle_minutes_to_autostop``.
731
1098
  retry_until_up: whether to retry launching the cluster until it is
732
1099
  up.
733
1100
  down: Autodown the cluster: tear down the cluster after specified
@@ -756,26 +1123,30 @@ def start(
756
1123
  sky.exceptions.ClusterOwnerIdentitiesMismatchError: if the cluster to
757
1124
  restart was launched by a different user.
758
1125
  """
1126
+ remote_api_version = versions.get_remote_api_version()
1127
+ if wait_for is not None and (remote_api_version is None or
1128
+ remote_api_version < 13):
1129
+ logger.warning('wait_for is not supported in your API server. '
1130
+ 'Please upgrade to a newer API server to use it.')
1131
+
759
1132
  body = payloads.StartBody(
760
1133
  cluster_name=cluster_name,
761
1134
  idle_minutes_to_autostop=idle_minutes_to_autostop,
1135
+ wait_for=wait_for,
762
1136
  retry_until_up=retry_until_up,
763
1137
  down=down,
764
1138
  force=force,
765
1139
  )
766
- response = requests.post(
767
- f'{server_common.get_server_url()}/start',
768
- json=json.loads(body.model_dump_json()),
769
- timeout=5,
770
- cookies=server_common.get_api_cookie_jar(),
771
- )
1140
+ response = server_common.make_authenticated_request(
1141
+ 'POST', '/start', json=json.loads(body.model_dump_json()), timeout=5)
772
1142
  return server_common.get_request_id(response)
773
1143
 
774
1144
 
775
1145
  @usage_lib.entrypoint
776
1146
  @server_common.check_server_healthy_or_start
777
1147
  @annotations.client_api
778
- def down(cluster_name: str, purge: bool = False) -> server_common.RequestId:
1148
+ def down(cluster_name: str,
1149
+ purge: bool = False) -> server_common.RequestId[None]:
779
1150
  """Tears down a cluster.
780
1151
 
781
1152
  Tearing down a cluster will delete all associated resources (all billing
@@ -809,19 +1180,16 @@ def down(cluster_name: str, purge: bool = False) -> server_common.RequestId:
809
1180
  cluster_name=cluster_name,
810
1181
  purge=purge,
811
1182
  )
812
- response = requests.post(
813
- f'{server_common.get_server_url()}/down',
814
- json=json.loads(body.model_dump_json()),
815
- timeout=5,
816
- cookies=server_common.get_api_cookie_jar(),
817
- )
1183
+ response = server_common.make_authenticated_request(
1184
+ 'POST', '/down', json=json.loads(body.model_dump_json()), timeout=5)
818
1185
  return server_common.get_request_id(response)
819
1186
 
820
1187
 
821
1188
  @usage_lib.entrypoint
822
1189
  @server_common.check_server_healthy_or_start
823
1190
  @annotations.client_api
824
- def stop(cluster_name: str, purge: bool = False) -> server_common.RequestId:
1191
+ def stop(cluster_name: str,
1192
+ purge: bool = False) -> server_common.RequestId[None]:
825
1193
  """Stops a cluster.
826
1194
 
827
1195
  Data on attached disks is not lost when a cluster is stopped. Billing for
@@ -858,12 +1226,8 @@ def stop(cluster_name: str, purge: bool = False) -> server_common.RequestId:
858
1226
  cluster_name=cluster_name,
859
1227
  purge=purge,
860
1228
  )
861
- response = requests.post(
862
- f'{server_common.get_server_url()}/stop',
863
- json=json.loads(body.model_dump_json()),
864
- timeout=5,
865
- cookies=server_common.get_api_cookie_jar(),
866
- )
1229
+ response = server_common.make_authenticated_request(
1230
+ 'POST', '/stop', json=json.loads(body.model_dump_json()), timeout=5)
867
1231
  return server_common.get_request_id(response)
868
1232
 
869
1233
 
@@ -871,10 +1235,11 @@ def stop(cluster_name: str, purge: bool = False) -> server_common.RequestId:
871
1235
  @server_common.check_server_healthy_or_start
872
1236
  @annotations.client_api
873
1237
  def autostop(
874
- cluster_name: str,
875
- idle_minutes: int,
876
- down: bool = False # pylint: disable=redefined-outer-name
877
- ) -> server_common.RequestId:
1238
+ cluster_name: str,
1239
+ idle_minutes: int,
1240
+ wait_for: Optional[autostop_lib.AutostopWaitFor] = None,
1241
+ down: bool = False, # pylint: disable=redefined-outer-name
1242
+ ) -> server_common.RequestId[None]:
878
1243
  """Schedules an autostop/autodown for a cluster.
879
1244
 
880
1245
  Autostop/autodown will automatically stop or teardown a cluster when it
@@ -904,6 +1269,14 @@ def autostop(
904
1269
  idle_minutes: the number of minutes of idleness (no pending/running
905
1270
  jobs) after which the cluster will be stopped automatically. Setting
906
1271
  to a negative number cancels any autostop/autodown setting.
1272
+ wait_for: determines the condition for resetting the idleness timer.
1273
+ This option works in conjunction with ``idle_minutes``.
1274
+ Choices:
1275
+
1276
+ 1. "jobs_and_ssh" (default) - Wait for in-progress jobs and SSH
1277
+ connections to finish.
1278
+ 2. "jobs" - Only wait for in-progress jobs.
1279
+ 3. "none" - Wait for nothing; autostop right after ``idle_minutes``.
907
1280
  down: if true, use autodown (tear down the cluster; non-restartable),
908
1281
  rather than autostop (restartable).
909
1282
 
@@ -923,26 +1296,31 @@ def autostop(
923
1296
  sky.exceptions.CloudUserIdentityError: if we fail to get the current
924
1297
  user identity.
925
1298
  """
1299
+ remote_api_version = versions.get_remote_api_version()
1300
+ if wait_for is not None and (remote_api_version is None or
1301
+ remote_api_version < 13):
1302
+ logger.warning('wait_for is not supported in your API server. '
1303
+ 'Please upgrade to a newer API server to use it.')
1304
+
926
1305
  body = payloads.AutostopBody(
927
1306
  cluster_name=cluster_name,
928
1307
  idle_minutes=idle_minutes,
1308
+ wait_for=wait_for,
929
1309
  down=down,
930
1310
  )
931
- response = requests.post(
932
- f'{server_common.get_server_url()}/autostop',
933
- json=json.loads(body.model_dump_json()),
934
- timeout=5,
935
- cookies=server_common.get_api_cookie_jar(),
936
- )
1311
+ response = server_common.make_authenticated_request(
1312
+ 'POST', '/autostop', json=json.loads(body.model_dump_json()), timeout=5)
937
1313
  return server_common.get_request_id(response)
938
1314
 
939
1315
 
940
1316
  @usage_lib.entrypoint
941
1317
  @server_common.check_server_healthy_or_start
942
1318
  @annotations.client_api
943
- def queue(cluster_name: str,
944
- skip_finished: bool = False,
945
- all_users: bool = False) -> server_common.RequestId:
1319
+ def queue(
1320
+ cluster_name: str,
1321
+ skip_finished: bool = False,
1322
+ all_users: bool = False
1323
+ ) -> server_common.RequestId[List[responses.ClusterJobRecord]]:
946
1324
  """Gets the job queue of a cluster.
947
1325
 
948
1326
  Args:
@@ -955,8 +1333,8 @@ def queue(cluster_name: str,
955
1333
  The request ID of the queue request.
956
1334
 
957
1335
  Request Returns:
958
- job_records (List[Dict[str, Any]]): A list of dicts for each job in the
959
- queue.
1336
+ job_records (List[responses.ClusterJobRecord]): A list of job records
1337
+ for each job in the queue.
960
1338
 
961
1339
  .. code-block:: python
962
1340
 
@@ -991,17 +1369,19 @@ def queue(cluster_name: str,
991
1369
  skip_finished=skip_finished,
992
1370
  all_users=all_users,
993
1371
  )
994
- response = requests.post(f'{server_common.get_server_url()}/queue',
995
- json=json.loads(body.model_dump_json()),
996
- cookies=server_common.get_api_cookie_jar())
1372
+ response = server_common.make_authenticated_request(
1373
+ 'POST', '/queue', json=json.loads(body.model_dump_json()))
997
1374
  return server_common.get_request_id(response)
998
1375
 
999
1376
 
1000
1377
  @usage_lib.entrypoint
1001
1378
  @server_common.check_server_healthy_or_start
1002
1379
  @annotations.client_api
1003
- def job_status(cluster_name: str,
1004
- job_ids: Optional[List[int]] = None) -> server_common.RequestId:
1380
+ def job_status(
1381
+ cluster_name: str,
1382
+ job_ids: Optional[List[int]] = None
1383
+ ) -> server_common.RequestId[Dict[Optional[int],
1384
+ Optional['job_lib.JobStatus']]]:
1005
1385
  """Gets the status of jobs on a cluster.
1006
1386
 
1007
1387
  Args:
@@ -1033,9 +1413,8 @@ def job_status(cluster_name: str,
1033
1413
  cluster_name=cluster_name,
1034
1414
  job_ids=job_ids,
1035
1415
  )
1036
- response = requests.post(f'{server_common.get_server_url()}/job_status',
1037
- json=json.loads(body.model_dump_json()),
1038
- cookies=server_common.get_api_cookie_jar())
1416
+ response = server_common.make_authenticated_request(
1417
+ 'POST', '/job_status', json=json.loads(body.model_dump_json()))
1039
1418
  return server_common.get_request_id(response)
1040
1419
 
1041
1420
 
@@ -1049,7 +1428,7 @@ def cancel(
1049
1428
  job_ids: Optional[List[int]] = None,
1050
1429
  # pylint: disable=invalid-name
1051
1430
  _try_cancel_if_cluster_is_init: bool = False
1052
- ) -> server_common.RequestId:
1431
+ ) -> server_common.RequestId[None]:
1053
1432
  """Cancels jobs on a cluster.
1054
1433
 
1055
1434
  Args:
@@ -1087,9 +1466,8 @@ def cancel(
1087
1466
  job_ids=job_ids,
1088
1467
  try_cancel_if_cluster_is_init=_try_cancel_if_cluster_is_init,
1089
1468
  )
1090
- response = requests.post(f'{server_common.get_server_url()}/cancel',
1091
- json=json.loads(body.model_dump_json()),
1092
- cookies=server_common.get_api_cookie_jar())
1469
+ response = server_common.make_authenticated_request(
1470
+ 'POST', '/cancel', json=json.loads(body.model_dump_json()))
1093
1471
  return server_common.get_request_id(response)
1094
1472
 
1095
1473
 
@@ -1100,7 +1478,10 @@ def status(
1100
1478
  cluster_names: Optional[List[str]] = None,
1101
1479
  refresh: common.StatusRefreshMode = common.StatusRefreshMode.NONE,
1102
1480
  all_users: bool = False,
1103
- ) -> server_common.RequestId:
1481
+ *,
1482
+ _include_credentials: bool = False,
1483
+ _summary_response: bool = False,
1484
+ ) -> server_common.RequestId[List[responses.StatusResponse]]:
1104
1485
  """Gets cluster statuses.
1105
1486
 
1106
1487
  If cluster_names is given, return those clusters. Otherwise, return all
@@ -1148,6 +1529,8 @@ def status(
1148
1529
  provider(s).
1149
1530
  all_users: whether to include all users' clusters. By default, only
1150
1531
  the current user's clusters are included.
1532
+ _include_credentials: (internal) whether to include cluster ssh
1533
+ credentials in the response (default: False).
1151
1534
 
1152
1535
  Returns:
1153
1536
  The request ID of the status request.
@@ -1182,10 +1565,11 @@ def status(
1182
1565
  cluster_names=cluster_names,
1183
1566
  refresh=refresh,
1184
1567
  all_users=all_users,
1568
+ include_credentials=_include_credentials,
1569
+ summary_response=_summary_response,
1185
1570
  )
1186
- response = requests.post(f'{server_common.get_server_url()}/status',
1187
- json=json.loads(body.model_dump_json()),
1188
- cookies=server_common.get_api_cookie_jar())
1571
+ response = server_common.make_authenticated_request(
1572
+ 'POST', '/status', json=json.loads(body.model_dump_json()))
1189
1573
  return server_common.get_request_id(response)
1190
1574
 
1191
1575
 
@@ -1193,10 +1577,19 @@ def status(
1193
1577
  @server_common.check_server_healthy_or_start
1194
1578
  @annotations.client_api
1195
1579
  def endpoints(
1196
- cluster: str,
1197
- port: Optional[Union[int, str]] = None) -> server_common.RequestId:
1580
+ cluster: str,
1581
+ port: Optional[Union[int, str]] = None
1582
+ ) -> server_common.RequestId[Dict[int, str]]:
1198
1583
  """Gets the endpoint for a given cluster and port number (endpoint).
1199
1584
 
1585
+ Example:
1586
+ .. code-block:: python
1587
+
1588
+ import sky
1589
+ request_id = sky.endpoints('test-cluster')
1590
+ sky.get(request_id)
1591
+
1592
+
1200
1593
  Args:
1201
1594
  cluster: The name of the cluster.
1202
1595
  port: The port number to get the endpoint for. If None, endpoints
@@ -1206,8 +1599,9 @@ def endpoints(
1206
1599
  The request ID of the endpoints request.
1207
1600
 
1208
1601
  Request Returns:
1209
- A dictionary of port numbers to endpoints. If port is None,
1210
- the dictionary will contain all ports:endpoints exposed on the cluster.
1602
+ A dictionary of port numbers to endpoints.
1603
+ If port is None, the dictionary contains all
1604
+ ports:endpoints exposed on the cluster.
1211
1605
 
1212
1606
  Request Raises:
1213
1607
  ValueError: if the cluster is not UP or the endpoint is not exposed.
@@ -1218,16 +1612,17 @@ def endpoints(
1218
1612
  cluster=cluster,
1219
1613
  port=port,
1220
1614
  )
1221
- response = requests.post(f'{server_common.get_server_url()}/endpoints',
1222
- json=json.loads(body.model_dump_json()),
1223
- cookies=server_common.get_api_cookie_jar())
1615
+ response = server_common.make_authenticated_request(
1616
+ 'POST', '/endpoints', json=json.loads(body.model_dump_json()))
1224
1617
  return server_common.get_request_id(response)
1225
1618
 
1226
1619
 
1227
1620
  @usage_lib.entrypoint
1228
1621
  @server_common.check_server_healthy_or_start
1229
1622
  @annotations.client_api
1230
- def cost_report() -> server_common.RequestId: # pylint: disable=redefined-builtin
1623
+ def cost_report(
1624
+ days: Optional[int] = None
1625
+ ) -> server_common.RequestId[List[Dict[str, Any]]]: # pylint: disable=redefined-builtin
1231
1626
  """Gets all cluster cost reports, including those that have been downed.
1232
1627
 
1233
1628
  The estimated cost column indicates price for the cluster based on the type
@@ -1237,6 +1632,10 @@ def cost_report() -> server_common.RequestId: # pylint: disable=redefined-built
1237
1632
  cache of the cluster status, and may not be accurate for the cluster with
1238
1633
  autostop/use_spot set or terminated/stopped on the cloud console.
1239
1634
 
1635
+ Args:
1636
+ days: The number of days to get the cost report for. If not provided,
1637
+ the default is 30 days.
1638
+
1240
1639
  Returns:
1241
1640
  The request ID of the cost report request.
1242
1641
 
@@ -1258,8 +1657,9 @@ def cost_report() -> server_common.RequestId: # pylint: disable=redefined-built
1258
1657
  'total_cost': (float) cost given resources and usage intervals,
1259
1658
  }
1260
1659
  """
1261
- response = requests.get(f'{server_common.get_server_url()}/cost_report',
1262
- cookies=server_common.get_api_cookie_jar())
1660
+ body = payloads.CostReportBody(days=days)
1661
+ response = server_common.make_authenticated_request(
1662
+ 'POST', '/cost_report', json=json.loads(body.model_dump_json()))
1263
1663
  return server_common.get_request_id(response)
1264
1664
 
1265
1665
 
@@ -1267,36 +1667,24 @@ def cost_report() -> server_common.RequestId: # pylint: disable=redefined-built
1267
1667
  @usage_lib.entrypoint
1268
1668
  @server_common.check_server_healthy_or_start
1269
1669
  @annotations.client_api
1270
- def storage_ls() -> server_common.RequestId:
1670
+ def storage_ls() -> server_common.RequestId[List[responses.StorageRecord]]:
1271
1671
  """Gets the storages.
1272
1672
 
1273
1673
  Returns:
1274
1674
  The request ID of the storage list request.
1275
1675
 
1276
1676
  Request Returns:
1277
- storage_records (List[Dict[str, Any]]): A list of dicts, with each dict
1278
- containing the information of a storage.
1279
-
1280
- .. code-block:: python
1281
-
1282
- {
1283
- 'name': (str) storage name,
1284
- 'launched_at': (int) timestamp of creation,
1285
- 'store': (List[sky.StoreType]) storage type,
1286
- 'last_use': (int) timestamp of last use,
1287
- 'status': (sky.StorageStatus) storage status,
1288
- }
1289
- ]
1677
+ storage_records (List[responses.StorageRecord]):
1678
+ A list of storage records.
1290
1679
  """
1291
- response = requests.get(f'{server_common.get_server_url()}/storage/ls',
1292
- cookies=server_common.get_api_cookie_jar())
1680
+ response = server_common.make_authenticated_request('GET', '/storage/ls')
1293
1681
  return server_common.get_request_id(response)
1294
1682
 
1295
1683
 
1296
1684
  @usage_lib.entrypoint
1297
1685
  @server_common.check_server_healthy_or_start
1298
1686
  @annotations.client_api
1299
- def storage_delete(name: str) -> server_common.RequestId:
1687
+ def storage_delete(name: str) -> server_common.RequestId[None]:
1300
1688
  """Deletes a storage.
1301
1689
 
1302
1690
  Args:
@@ -1312,9 +1700,8 @@ def storage_delete(name: str) -> server_common.RequestId:
1312
1700
  ValueError: If the storage does not exist.
1313
1701
  """
1314
1702
  body = payloads.StorageBody(name=name)
1315
- response = requests.post(f'{server_common.get_server_url()}/storage/delete',
1316
- json=json.loads(body.model_dump_json()),
1317
- cookies=server_common.get_api_cookie_jar())
1703
+ response = server_common.make_authenticated_request(
1704
+ 'POST', '/storage/delete', json=json.loads(body.model_dump_json()))
1318
1705
  return server_common.get_request_id(response)
1319
1706
 
1320
1707
 
@@ -1325,12 +1712,8 @@ def storage_delete(name: str) -> server_common.RequestId:
1325
1712
  @server_common.check_server_healthy_or_start
1326
1713
  @annotations.client_api
1327
1714
  def local_up(gpus: bool,
1328
- ips: Optional[List[str]],
1329
- ssh_user: Optional[str],
1330
- ssh_key: Optional[str],
1331
- cleanup: bool,
1332
- context_name: Optional[str] = None,
1333
- password: Optional[str] = None) -> server_common.RequestId:
1715
+ name: Optional[str] = None,
1716
+ port_start: Optional[int] = None) -> server_common.RequestId[None]:
1334
1717
  """Launches a Kubernetes cluster on local machines.
1335
1718
 
1336
1719
  Returns:
@@ -1341,36 +1724,151 @@ def local_up(gpus: bool,
1341
1724
  # TODO: move this check to server.
1342
1725
  if not server_common.is_api_server_local():
1343
1726
  with ux_utils.print_exception_no_traceback():
1344
- raise ValueError(
1345
- 'sky local up is only supported when running SkyPilot locally.')
1346
-
1347
- body = payloads.LocalUpBody(gpus=gpus,
1348
- ips=ips,
1349
- ssh_user=ssh_user,
1350
- ssh_key=ssh_key,
1351
- cleanup=cleanup,
1352
- context_name=context_name,
1353
- password=password)
1354
- response = requests.post(f'{server_common.get_server_url()}/local_up',
1355
- json=json.loads(body.model_dump_json()),
1356
- cookies=server_common.get_api_cookie_jar())
1727
+ raise ValueError('`sky local up` is only supported when '
1728
+ 'running SkyPilot locally.')
1729
+
1730
+ body = payloads.LocalUpBody(gpus=gpus, name=name, port_start=port_start)
1731
+ response = server_common.make_authenticated_request(
1732
+ 'POST', '/local_up', json=json.loads(body.model_dump_json()))
1357
1733
  return server_common.get_request_id(response)
1358
1734
 
1359
1735
 
1360
1736
  @usage_lib.entrypoint
1361
1737
  @server_common.check_server_healthy_or_start
1362
1738
  @annotations.client_api
1363
- def local_down() -> server_common.RequestId:
1739
+ def local_down(name: Optional[str]) -> server_common.RequestId[None]:
1364
1740
  """Tears down the Kubernetes cluster started by local_up."""
1365
1741
  # We do not allow local down when the API server is running remotely since it
1366
1742
  # will modify the kubeconfig.
1367
1743
  # TODO: move this check to remote server.
1368
1744
  if not server_common.is_api_server_local():
1369
1745
  with ux_utils.print_exception_no_traceback():
1370
- raise ValueError('sky local down is only supported when running '
1746
+ raise ValueError('`sky local down` is only supported when running '
1371
1747
  'SkyPilot locally.')
1372
- response = requests.post(f'{server_common.get_server_url()}/local_down',
1373
- cookies=server_common.get_api_cookie_jar())
1748
+
1749
+ body = payloads.LocalDownBody(name=name)
1750
+ response = server_common.make_authenticated_request(
1751
+ 'POST', '/local_down', json=json.loads(body.model_dump_json()))
1752
+ return server_common.get_request_id(response)
1753
+
1754
+
1755
def _update_remote_ssh_node_pools(file: str,
                                  infra: Optional[str] = None) -> None:
    """Pushes the local SSH node pool definitions to the remote server.

    While the host information of each pool is prepared, the local SSH keys
    are uploaded through ``_upload_ssh_key_and_wait`` so that the config sent
    to the server can refer to the key paths on the server side.

    Args:
        file: The path to the local SSH node pools config file. ``~`` is
            expanded.
        infra: The name of the cluster configuration in the local SSH node
            pools config file. If None, all clusters in the file are updated.

    Raises:
        ValueError: if the config file does not exist.
    """
    config_path = os.path.expanduser(file)
    if not os.path.exists(config_path):
        with ux_utils.print_exception_no_traceback():
            raise ValueError(
                f'SSH Node Pool config file {config_path} does not exist. '
                'Please check if the file exists and the path is correct.')

    selected_pools = ssh_utils.get_cluster_config(
        ssh_utils.load_ssh_targets(config_path), infra)

    # One entry per pool: {'<pool name>': {'hosts': <prepared host info>}}.
    pools_config = {
        pool_name: {
            'hosts': ssh_utils.prepare_hosts_info(
                pool_name,
                pool_settings,
                upload_ssh_key_func=_upload_ssh_key_and_wait)
        } for pool_name, pool_settings in selected_pools.items()
    }

    server_common.make_authenticated_request('POST',
                                             '/ssh_node_pools',
                                             json=pools_config)
1783
+
1784
+
1785
def _upload_ssh_key_and_wait(key_name: str, key_file_path: str) -> str:
    """Uploads a local SSH key file to the API server.

    The upload is a single blocking POST; the function returns once the
    server has answered.

    Args:
        key_name: The name of the SSH key.
        key_file_path: The path to the local SSH key file. ``~`` is expanded.

    Returns:
        The ``key_path`` field of the server's JSON reply, i.e. the path of
        the SSH key file on the API server.

    Raises:
        ValueError: if the key file does not exist locally.
    """
    local_key_path = os.path.expanduser(key_file_path)
    if not os.path.exists(local_key_path):
        with ux_utils.print_exception_no_traceback():
            # The message reports the path exactly as the caller passed it.
            raise ValueError(f'SSH key file not found: {key_file_path}')

    with open(local_key_path, 'rb') as key_file:
        multipart_files = {
            'key_file': (key_name, key_file, 'application/octet-stream')
        }
        # NOTE(review): this is the only call in this part of the file that
        # passes the cookie jar explicitly; confirm whether
        # make_authenticated_request already attaches it.
        response = server_common.make_authenticated_request(
            'POST',
            '/ssh_node_pools/keys',
            files=multipart_files,
            data={'key_name': key_name},
            cookies=server_common.get_api_cookie_jar())

    return response.json()['key_path']
1811
+
1812
+
1813
@usage_lib.entrypoint
@server_common.check_server_healthy_or_start
@annotations.client_api
def ssh_up(infra: Optional[str] = None,
           file: Optional[str] = None) -> server_common.RequestId[None]:
    """Deploys the SSH Node Pools defined in ~/.sky/ssh_node_pools.yaml.

    Args:
        infra: Name of the SSH Node Pool to deploy. If None, no single pool
            is selected and the general deployment endpoint is called.
        file: Name of the ssh node pool configuration file to use. If given,
            the pools it defines (restricted to ``infra`` when set) are
            uploaded to the API server before deployment is requested. If
            None, the default path, ~/.sky/ssh_node_pools.yaml is used.

    Returns:
        request_id: The request ID of the SSH cluster deployment request.
    """
    if file is not None:
        # Sync the local pool definitions (and SSH keys) to the server first.
        _update_remote_ssh_node_pools(file, infra)

    # Use SSH node pools router endpoint
    if infra is not None:
        # Call the specific pool deployment endpoint; it takes no body.
        response = server_common.make_authenticated_request(
            'POST', f'/ssh_node_pools/{infra}/deploy')
    else:
        # Call the general deployment endpoint. The payload is only needed
        # on this path, so it is built here rather than unconditionally.
        body = payloads.SSHUpBody(infra=infra, cleanup=False)
        response = server_common.make_authenticated_request(
            'POST',
            '/ssh_node_pools/deploy',
            json=json.loads(body.model_dump_json()))
    return server_common.get_request_id(response)
1845
+
1846
+
1847
@usage_lib.entrypoint
@server_common.check_server_healthy_or_start
@annotations.client_api
def ssh_down(infra: Optional[str] = None) -> server_common.RequestId[None]:
    """Tears down the Kubernetes cluster deployed on SSH Node Pools.

    Args:
        infra: Name of the SSH Node Pool to tear down. If None, no single
            pool is selected and the general down endpoint is called.

    Returns:
        request_id: The request ID of the SSH cluster teardown request.
    """
    # Use SSH node pools router endpoint
    if infra is not None:
        # Call the specific pool down endpoint; it takes no body.
        response = server_common.make_authenticated_request(
            'POST', f'/ssh_node_pools/{infra}/down')
    else:
        # Call the general down endpoint. The payload is only needed on this
        # path, so it is built here rather than unconditionally. Teardown
        # reuses the SSHUpBody payload with cleanup=True.
        body = payloads.SSHUpBody(infra=infra, cleanup=True)
        response = server_common.make_authenticated_request(
            'POST',
            '/ssh_node_pools/down',
            json=json.loads(body.model_dump_json()))
    return server_common.get_request_id(response)
1375
1873
 
1376
1874
 
@@ -1378,9 +1876,12 @@ def local_down() -> server_common.RequestId:
1378
1876
  @server_common.check_server_healthy_or_start
1379
1877
  @annotations.client_api
1380
1878
  def realtime_kubernetes_gpu_availability(
1381
- context: Optional[str] = None,
1382
- name_filter: Optional[str] = None,
1383
- quantity_filter: Optional[int] = None) -> server_common.RequestId:
1879
+ context: Optional[str] = None,
1880
+ name_filter: Optional[str] = None,
1881
+ quantity_filter: Optional[int] = None,
1882
+ is_ssh: Optional[bool] = None
1883
+ ) -> server_common.RequestId[List[Tuple[
1884
+ str, List['models.RealtimeGpuAvailability']]]]:
1384
1885
  """Gets the real-time Kubernetes GPU availability.
1385
1886
 
1386
1887
  Returns:
@@ -1390,12 +1891,12 @@ def realtime_kubernetes_gpu_availability(
1390
1891
  context=context,
1391
1892
  name_filter=name_filter,
1392
1893
  quantity_filter=quantity_filter,
1894
+ is_ssh=is_ssh,
1393
1895
  )
1394
- response = requests.post(
1395
- f'{server_common.get_server_url()}/'
1396
- 'realtime_kubernetes_gpu_availability',
1397
- json=json.loads(body.model_dump_json()),
1398
- cookies=server_common.get_api_cookie_jar())
1896
+ response = server_common.make_authenticated_request(
1897
+ 'POST',
1898
+ '/realtime_kubernetes_gpu_availability',
1899
+ json=json.loads(body.model_dump_json()))
1399
1900
  return server_common.get_request_id(response)
1400
1901
 
1401
1902
 
@@ -1403,7 +1904,8 @@ def realtime_kubernetes_gpu_availability(
1403
1904
  @server_common.check_server_healthy_or_start
1404
1905
  @annotations.client_api
1405
1906
  def kubernetes_node_info(
1406
- context: Optional[str] = None) -> server_common.RequestId:
1907
+ context: Optional[str] = None
1908
+ ) -> server_common.RequestId['models.KubernetesNodesInfo']:
1407
1909
  """Gets the resource information for all the nodes in the cluster.
1408
1910
 
1409
1911
  Currently only GPU resources are supported. The function returns the total
@@ -1424,18 +1926,22 @@ def kubernetes_node_info(
1424
1926
  information.
1425
1927
  """
1426
1928
  body = payloads.KubernetesNodeInfoRequestBody(context=context)
1427
- response = requests.post(
1428
- f'{server_common.get_server_url()}/kubernetes_node_info',
1429
- json=json.loads(body.model_dump_json()),
1430
- cookies=server_common.get_api_cookie_jar())
1929
+ response = server_common.make_authenticated_request(
1930
+ 'POST',
1931
+ '/kubernetes_node_info',
1932
+ json=json.loads(body.model_dump_json()))
1431
1933
  return server_common.get_request_id(response)
1432
1934
 
1433
1935
 
1434
1936
  @usage_lib.entrypoint
1435
1937
  @server_common.check_server_healthy_or_start
1436
1938
  @annotations.client_api
1437
- def status_kubernetes() -> server_common.RequestId:
1438
- """Gets all SkyPilot clusters and jobs in the Kubernetes cluster.
1939
+ def status_kubernetes() -> server_common.RequestId[
1940
+ Tuple[List['kubernetes_utils.KubernetesSkyPilotClusterInfoPayload'],
1941
+ List['kubernetes_utils.KubernetesSkyPilotClusterInfoPayload'],
1942
+ List[responses.ManagedJobRecord], Optional[str]]]:
1943
+ """[Experimental] Gets all SkyPilot clusters and jobs
1944
+ in the Kubernetes cluster.
1439
1945
 
1440
1946
  Managed jobs and services are also included in the clusters returned.
1441
1947
  The caller must parse the controllers to identify which clusters are run
@@ -1455,21 +1961,24 @@ def status_kubernetes() -> server_common.RequestId:
1455
1961
  dictionary job info, see jobs.queue_from_kubernetes_pod for details.
1456
1962
  - context: Kubernetes context used to fetch the cluster information.
1457
1963
  """
1458
- response = requests.get(
1459
- f'{server_common.get_server_url()}/status_kubernetes',
1460
- cookies=server_common.get_api_cookie_jar())
1964
+ response = server_common.make_authenticated_request('GET',
1965
+ '/status_kubernetes')
1461
1966
  return server_common.get_request_id(response)
1462
1967
 
1463
1968
 
1464
1969
  # === API request APIs ===
1465
1970
  @usage_lib.entrypoint
1466
- @server_common.check_server_healthy_or_start
1467
1971
  @annotations.client_api
1468
- def get(request_id: str) -> Any:
1972
+ def get(request_id: server_common.RequestId[T]) -> T:
1469
1973
  """Waits for and gets the result of a request.
1470
1974
 
1975
+ This function will not check the server health since /api/get is typically
1976
+ not the first API call in an SDK session and checking the server health
1977
+ may cause GET /api/get being sent to a restarted API server.
1978
+
1471
1979
  Args:
1472
- request_id: The request ID of the request to get.
1980
+ request_id: The request ID of the request to get. May be a full request
1981
+ ID or a prefix.
1473
1982
 
1474
1983
  Returns:
1475
1984
  The ``Request Returns`` of the specified request. See the documentation
@@ -1480,19 +1989,20 @@ def get(request_id: str) -> Any:
1480
1989
  see ``Request Raises`` in the documentation of the specific requests
1481
1990
  above.
1482
1991
  """
1483
- response = requests.get(
1484
- f'{server_common.get_server_url()}/api/get?request_id={request_id}',
1992
+ response = server_common.make_authenticated_request(
1993
+ 'GET',
1994
+ f'/api/get?request_id={request_id}',
1995
+ retry=False,
1485
1996
  timeout=(client_common.API_SERVER_REQUEST_CONNECTION_TIMEOUT_SECONDS,
1486
- None),
1487
- cookies=server_common.get_api_cookie_jar())
1997
+ None))
1488
1998
  request_task = None
1489
1999
  if response.status_code == 200:
1490
2000
  request_task = requests_lib.Request.decode(
1491
- requests_lib.RequestPayload(**response.json()))
2001
+ payloads.RequestPayload(**response.json()))
1492
2002
  elif response.status_code == 500:
1493
2003
  try:
1494
2004
  request_task = requests_lib.Request.decode(
1495
- requests_lib.RequestPayload(**response.json().get('detail')))
2005
+ payloads.RequestPayload(**response.json().get('detail')))
1496
2006
  logger.debug(f'Got request with error: {request_task.name}')
1497
2007
  except Exception: # pylint: disable=broad-except
1498
2008
  request_task = None
@@ -1503,12 +2013,7 @@ def get(request_id: str) -> Any:
1503
2013
  error = request_task.get_error()
1504
2014
  if error is not None:
1505
2015
  error_obj = error['object']
1506
- if env_options.Options.SHOW_DEBUG_INFO.get():
1507
- stacktrace = getattr(error_obj, 'stacktrace', str(error_obj))
1508
- logger.error('=== Traceback on SkyPilot API Server ===\n'
1509
- f'{stacktrace}')
1510
- with ux_utils.print_exception_no_traceback():
1511
- raise error_obj
2016
+ _raise_exception_object_on_client(error_obj)
1512
2017
  if request_task.status == requests_lib.RequestStatus.CANCELLED:
1513
2018
  with ux_utils.print_exception_no_traceback():
1514
2019
  raise exceptions.RequestCancelled(
@@ -1518,23 +2023,45 @@ def get(request_id: str) -> Any:
1518
2023
  return request_task.get_return_value()
1519
2024
 
1520
2025
 
2026
+ @typing.overload
2027
+ def stream_and_get(request_id: server_common.RequestId[T],
2028
+ log_path: Optional[str] = None,
2029
+ tail: Optional[int] = None,
2030
+ follow: bool = True,
2031
+ output_stream: Optional['io.TextIOBase'] = None) -> T:
2032
+ ...
2033
+
2034
+
2035
+ @typing.overload
2036
+ def stream_and_get(request_id: None = None,
2037
+ log_path: Optional[str] = None,
2038
+ tail: Optional[int] = None,
2039
+ follow: bool = True,
2040
+ output_stream: Optional['io.TextIOBase'] = None) -> None:
2041
+ ...
2042
+
2043
+
1521
2044
  @usage_lib.entrypoint
1522
2045
  @server_common.check_server_healthy_or_start
1523
2046
  @annotations.client_api
2047
+ @rest.retry_transient_errors()
1524
2048
  def stream_and_get(
1525
- request_id: Optional[str] = None,
2049
+ request_id: Optional[server_common.RequestId[T]] = None,
1526
2050
  log_path: Optional[str] = None,
1527
2051
  tail: Optional[int] = None,
1528
2052
  follow: bool = True,
1529
2053
  output_stream: Optional['io.TextIOBase'] = None,
1530
- ) -> Any:
2054
+ ) -> Optional[T]:
1531
2055
  """Streams the logs of a request or a log file and gets the final result.
1532
2056
 
1533
2057
  This will block until the request is finished. The request id can be a
1534
2058
  prefix of the full request id.
1535
2059
 
1536
2060
  Args:
1537
- request_id: The prefix of the request ID of the request to stream.
2061
+ request_id: The request ID of the request to stream. May be a full
2062
+ request ID or a prefix.
2063
+ If None, the latest request submitted to the API server is streamed.
2064
+ Using None request_id is not recommended in multi-user environments.
1538
2065
  log_path: The path to the log file to stream.
1539
2066
  tail: The number of lines to show from the end of the logs.
1540
2067
  If None, show all logs.
@@ -1545,6 +2072,8 @@ def stream_and_get(
1545
2072
  Returns:
1546
2073
  The ``Request Returns`` of the specified request. See the documentation
1547
2074
  of the specific requests above for more details.
2075
+ If follow is False, will always return None. See note on
2076
+ stream_response.
1548
2077
 
1549
2078
  Raises:
1550
2079
  Exception: It raises the same exceptions as the specific requests,
@@ -1558,27 +2087,44 @@ def stream_and_get(
1558
2087
  'follow': follow,
1559
2088
  'format': 'console',
1560
2089
  }
1561
- response = requests.get(
1562
- f'{server_common.get_server_url()}/api/stream',
2090
+ response = server_common.make_authenticated_request(
2091
+ 'GET',
2092
+ '/api/stream',
1563
2093
  params=params,
2094
+ retry=False,
1564
2095
  timeout=(client_common.API_SERVER_REQUEST_CONNECTION_TIMEOUT_SECONDS,
1565
2096
  None),
1566
- stream=True,
1567
- cookies=server_common.get_api_cookie_jar())
2097
+ stream=True)
1568
2098
  if response.status_code in [404, 400]:
1569
2099
  detail = response.json().get('detail')
1570
2100
  with ux_utils.print_exception_no_traceback():
1571
- raise RuntimeError(f'Failed to stream logs: {detail}')
2101
+ raise exceptions.ClientError(f'Failed to stream logs: {detail}')
2102
+ stream_request_id: Optional[server_common.RequestId[
2103
+ T]] = server_common.get_stream_request_id(response)
2104
+ if request_id is not None and stream_request_id is not None:
2105
+ assert request_id == stream_request_id
2106
+ if request_id is None:
2107
+ request_id = stream_request_id
1572
2108
  elif response.status_code != 200:
2109
+ # TODO(syang): handle the case where the requestID is not provided
2110
+ # see https://github.com/skypilot-org/skypilot/issues/6549
2111
+ if request_id is None:
2112
+ return None
1573
2113
  return get(request_id)
1574
- return stream_response(request_id, response, output_stream)
2114
+ return stream_response(request_id,
2115
+ response,
2116
+ output_stream,
2117
+ resumable=True,
2118
+ get_result=follow)
1575
2119
 
1576
2120
 
1577
2121
  @usage_lib.entrypoint
1578
2122
  @annotations.client_api
1579
- def api_cancel(request_ids: Optional[Union[str, List[str]]] = None,
2123
+ def api_cancel(request_ids: Optional[Union[server_common.RequestId[T],
2124
+ List[server_common.RequestId[T]],
2125
+ str, List[str]]] = None,
1580
2126
  all_users: bool = False,
1581
- silent: bool = False) -> server_common.RequestId:
2127
+ silent: bool = False) -> server_common.RequestId[List[str]]:
1582
2128
  """Aborts a request or all requests.
1583
2129
 
1584
2130
  Args:
@@ -1618,20 +2164,35 @@ def api_cancel(request_ids: Optional[Union[str, List[str]]] = None,
1618
2164
  echo(f'Cancelling {len(request_ids)} request{plural}: '
1619
2165
  f'{request_id_str}...')
1620
2166
 
1621
- response = requests.post(f'{server_common.get_server_url()}/api/cancel',
1622
- json=json.loads(body.model_dump_json()),
1623
- timeout=5,
1624
- cookies=server_common.get_api_cookie_jar())
2167
+ response = server_common.make_authenticated_request(
2168
+ 'POST',
2169
+ '/api/cancel',
2170
+ json=json.loads(body.model_dump_json()),
2171
+ timeout=5)
1625
2172
  return server_common.get_request_id(response)
1626
2173
 
1627
2174
 
2175
+ def _local_api_server_running(kill: bool = False) -> bool:
2176
+ """Checks if the local api server is running."""
2177
+ for process in psutil.process_iter(attrs=['pid', 'cmdline']):
2178
+ cmdline = process.info['cmdline']
2179
+ if cmdline and server_common.API_SERVER_CMD in ' '.join(cmdline):
2180
+ if kill:
2181
+ subprocess_utils.kill_children_processes(
2182
+ parent_pids=[process.pid], force=True)
2183
+ return True
2184
+ return False
2185
+
2186
+
1628
2187
  @usage_lib.entrypoint
1629
2188
  @annotations.client_api
1630
2189
  def api_status(
1631
- request_ids: Optional[List[str]] = None,
2190
+ request_ids: Optional[List[Union[server_common.RequestId[T], str]]] = None,
1632
2191
  # pylint: disable=redefined-builtin
1633
- all_status: bool = False
1634
- ) -> List[requests_lib.RequestPayload]:
2192
+ all_status: bool = False,
2193
+ limit: Optional[int] = None,
2194
+ fields: Optional[List[str]] = None,
2195
+ ) -> List[payloads.RequestPayload]:
1635
2196
  """Lists all requests.
1636
2197
 
1637
2198
  Args:
@@ -1639,29 +2200,37 @@ def api_status(
1639
2200
  If None, all requests are queried.
1640
2201
  all_status: Whether to list all finished requests as well. This argument
1641
2202
  is ignored if request_ids is not None.
2203
+ limit: The number of requests to show. If None, show all requests.
2204
+ fields: The fields to get. If None, get all fields.
1642
2205
 
1643
2206
  Returns:
1644
2207
  A list of request payloads.
1645
2208
  """
1646
- body = payloads.RequestStatusBody(request_ids=request_ids,
1647
- all_status=all_status)
1648
- response = requests.get(
1649
- f'{server_common.get_server_url()}/api/status',
2209
+ if server_common.is_api_server_local() and not _local_api_server_running():
2210
+ logger.info('SkyPilot API server is not running.')
2211
+ return []
2212
+
2213
+ body = payloads.RequestStatusBody(
2214
+ request_ids=request_ids,
2215
+ all_status=all_status,
2216
+ limit=limit,
2217
+ fields=fields,
2218
+ )
2219
+ response = server_common.make_authenticated_request(
2220
+ 'GET',
2221
+ '/api/status',
1650
2222
  params=server_common.request_body_to_params(body),
1651
2223
  timeout=(client_common.API_SERVER_REQUEST_CONNECTION_TIMEOUT_SECONDS,
1652
- None),
1653
- cookies=server_common.get_api_cookie_jar())
2224
+ None))
1654
2225
  server_common.handle_request_error(response)
1655
- return [
1656
- requests_lib.RequestPayload(**request) for request in response.json()
1657
- ]
2226
+ return [payloads.RequestPayload(**request) for request in response.json()]
1658
2227
 
1659
2228
 
1660
2229
  # === API server management APIs ===
1661
2230
  @usage_lib.entrypoint
1662
2231
  @server_common.check_server_healthy_or_start
1663
2232
  @annotations.client_api
1664
- def api_info() -> Dict[str, str]:
2233
+ def api_info() -> responses.APIHealthResponse:
1665
2234
  """Gets the server's status, commit and version.
1666
2235
 
1667
2236
  Returns:
@@ -1674,13 +2243,19 @@ def api_info() -> Dict[str, str]:
1674
2243
  'api_version': '1',
1675
2244
  'commit': 'abc1234567890',
1676
2245
  'version': '1.0.0',
2246
+ 'version_on_disk': '1.0.0',
2247
+ 'user': {
2248
+ 'name': 'test@example.com',
2249
+ 'id': '12345abcd',
2250
+ },
1677
2251
  }
1678
2252
 
2253
+ Note that user may be None if we are not using an auth proxy.
2254
+
1679
2255
  """
1680
- response = requests.get(f'{server_common.get_server_url()}/api/health',
1681
- cookies=server_common.get_api_cookie_jar())
2256
+ response = server_common.make_authenticated_request('GET', '/api/health')
1682
2257
  response.raise_for_status()
1683
- return response.json()
2258
+ return responses.APIHealthResponse(**response.json())
1684
2259
 
1685
2260
 
1686
2261
  @usage_lib.entrypoint
@@ -1690,6 +2265,9 @@ def api_start(
1690
2265
  deploy: bool = False,
1691
2266
  host: str = '127.0.0.1',
1692
2267
  foreground: bool = False,
2268
+ metrics: bool = False,
2269
+ metrics_port: Optional[int] = None,
2270
+ enable_basic_auth: bool = False,
1693
2271
  ) -> None:
1694
2272
  """Starts the API server.
1695
2273
 
@@ -1703,6 +2281,10 @@ def api_start(
1703
2281
  if deploy is True, to allow remote access.
1704
2282
  foreground: Whether to run the API server in the foreground (run in
1705
2283
  the current process).
2284
+ metrics: Whether to export metrics of the API server.
2285
+ metrics_port: The port to export metrics of the API server.
2286
+ enable_basic_auth: Whether to enable basic authentication
2287
+ in the API server.
1706
2288
  Returns:
1707
2289
  None
1708
2290
  """
@@ -1721,15 +2303,15 @@ def api_start(
1721
2303
  'from the config file and/or unset the '
1722
2304
  'SKYPILOT_API_SERVER_ENDPOINT environment '
1723
2305
  'variable.')
1724
- server_common.check_server_healthy_or_start_fn(deploy, host, foreground)
2306
+ server_common.check_server_healthy_or_start_fn(deploy, host, foreground,
2307
+ metrics, metrics_port,
2308
+ enable_basic_auth)
1725
2309
  if foreground:
1726
2310
  # Explain why current process exited
1727
2311
  logger.info('API server is already running:')
1728
2312
  api_server_url = server_common.get_server_url(host)
1729
- dashboard_url = server_common.get_dashboard_url(api_server_url)
1730
- dashboard_msg = f'Dashboard: {dashboard_url}'
1731
- logger.info(f'{ux_utils.INDENT_SYMBOL}SkyPilot API server: '
1732
- f'{api_server_url} {dashboard_msg}\n'
2313
+ logger.info(f'{ux_utils.INDENT_SYMBOL}SkyPilot API server and dashboard: '
2314
+ f'{api_server_url}\n'
1733
2315
  f'{ux_utils.INDENT_LAST_SYMBOL}'
1734
2316
  f'View API server logs at: {constants.API_SERVER_LOGS}')
1735
2317
 
@@ -1752,16 +2334,32 @@ def api_stop() -> None:
1752
2334
  f'Cannot kill the API server at {server_url} because it is not '
1753
2335
  f'the default SkyPilot API server started locally.')
1754
2336
 
1755
- found = False
1756
- for process in psutil.process_iter(attrs=['pid', 'cmdline']):
1757
- cmdline = process.info['cmdline']
1758
- if cmdline and server_common.API_SERVER_CMD in ' '.join(cmdline):
1759
- subprocess_utils.kill_children_processes(parent_pids=[process.pid],
1760
- force=True)
1761
- found = True
1762
-
1763
- # Remove the database for requests.
1764
- server_common.clear_local_api_server_database()
2337
+ # Acquire the api server creation lock to prevent multiple processes from
2338
+ # stopping and starting the API server at the same time.
2339
+ with filelock.FileLock(
2340
+ os.path.expanduser(constants.API_SERVER_CREATION_LOCK_PATH)):
2341
+ try:
2342
+ records = scheduler.get_controller_process_records()
2343
+ if records is not None:
2344
+ for record in records:
2345
+ try:
2346
+ if managed_job_utils.controller_process_alive(
2347
+ record, quiet=False):
2348
+ subprocess_utils.kill_children_processes(
2349
+ parent_pids=[record.pid], force=True)
2350
+ except (psutil.NoSuchProcess, psutil.ZombieProcess):
2351
+ continue
2352
+ os.remove(os.path.expanduser(scheduler.JOB_CONTROLLER_PID_PATH))
2353
+ except FileNotFoundError:
2354
+ # its fine we will create it
2355
+ pass
2356
+ except Exception as e: # pylint: disable=broad-except
2357
+ # in case we get perm issues or something is messed up, just ignore
2358
+ # it and assume the process is dead
2359
+ logger.error(f'Error looking at job controller pid file: {e}')
2360
+ pass
2361
+
2362
+ found = _local_api_server_running(kill=True)
1765
2363
 
1766
2364
  if found:
1767
2365
  logger.info(f'{colorama.Fore.GREEN}SkyPilot API server stopped.'
@@ -1796,9 +2394,86 @@ def api_server_logs(follow: bool = True, tail: Optional[int] = None) -> None:
1796
2394
  stream_and_get(log_path=constants.API_SERVER_LOGS, tail=tail)
1797
2395
 
1798
2396
 
2397
+ def _save_config_updates(endpoint: Optional[str] = None,
2398
+ service_account_token: Optional[str] = None) -> None:
2399
+ """Save endpoint and/or service account token to config file."""
2400
+ config_path = pathlib.Path(
2401
+ skypilot_config.get_user_config_path()).expanduser()
2402
+ with filelock.FileLock(config_path.with_suffix('.lock')):
2403
+ if not config_path.exists():
2404
+ config_path.touch()
2405
+ config: Dict[str, Any] = {}
2406
+ else:
2407
+ config = skypilot_config.get_user_config()
2408
+ config = dict(config)
2409
+
2410
+ # Update endpoint if provided
2411
+ if endpoint is not None:
2412
+ # We should always reset the api_server config to avoid legacy
2413
+ # service account token.
2414
+ config['api_server'] = {}
2415
+ config['api_server']['endpoint'] = endpoint
2416
+
2417
+ # Update service account token if provided
2418
+ if service_account_token is not None:
2419
+ if 'api_server' not in config:
2420
+ config['api_server'] = {}
2421
+ config['api_server'][
2422
+ 'service_account_token'] = service_account_token
2423
+
2424
+ yaml_utils.dump_yaml(str(config_path), config)
2425
+ skypilot_config.reload_config()
2426
+
2427
+
2428
+ def _clear_api_server_config() -> None:
2429
+ """Clear endpoint and service account token from config file."""
2430
+ config_path = pathlib.Path(
2431
+ skypilot_config.get_user_config_path()).expanduser()
2432
+ with filelock.FileLock(config_path.with_suffix('.lock')):
2433
+ if not config_path.exists():
2434
+ return
2435
+
2436
+ config = skypilot_config.get_user_config()
2437
+ config = dict(config)
2438
+ if 'api_server' in config:
2439
+ # We might not have set the endpoint in the config file, so we
2440
+ # need to check before deleting.
2441
+ del config['api_server']
2442
+
2443
+ yaml_utils.dump_yaml(str(config_path), config, blank=True)
2444
+ skypilot_config.reload_config()
2445
+
2446
+
2447
+ def _validate_endpoint(endpoint: Optional[str]) -> str:
2448
+ """Validate and normalize the endpoint URL."""
2449
+ if endpoint is None:
2450
+ endpoint = click.prompt('Enter your SkyPilot API server endpoint')
2451
+ # Check endpoint is a valid URL
2452
+ if (endpoint is not None and not endpoint.startswith('http://') and
2453
+ not endpoint.startswith('https://')):
2454
+ raise click.BadParameter('Endpoint must be a valid URL.')
2455
+ return endpoint.rstrip('/')
2456
+
2457
+
2458
+ def _check_endpoint_in_env_var(is_login: bool) -> None:
2459
+ # If the user has set the endpoint via the environment variable, we should
2460
+ # not do anything as we can't disambiguate between the env var and the
2461
+ # config file.
2462
+ """Check if the endpoint is set in the environment variable."""
2463
+ if constants.SKY_API_SERVER_URL_ENV_VAR in os.environ:
2464
+ with ux_utils.print_exception_no_traceback():
2465
+ action = 'login to' if is_login else 'logout of'
2466
+ raise RuntimeError(f'Cannot {action} API server when the endpoint '
2467
+ 'is set via the environment variable. Run unset '
2468
+ f'{constants.SKY_API_SERVER_URL_ENV_VAR} to '
2469
+ 'clear the environment variable.')
2470
+
2471
+
1799
2472
  @usage_lib.entrypoint
1800
2473
  @annotations.client_api
1801
- def api_login(endpoint: Optional[str] = None) -> None:
2474
+ def api_login(endpoint: Optional[str] = None,
2475
+ relogin: bool = False,
2476
+ service_account_token: Optional[str] = None) -> None:
1802
2477
  """Logs into a SkyPilot API server.
1803
2478
 
1804
2479
  This sets the endpoint globally, i.e., all SkyPilot CLI and SDK calls will
@@ -1810,37 +2485,262 @@ def api_login(endpoint: Optional[str] = None) -> None:
1810
2485
  Args:
1811
2486
  endpoint: The endpoint of the SkyPilot API server, e.g.,
1812
2487
  http://1.2.3.4:46580 or https://skypilot.mydomain.com.
2488
+ relogin: Whether to force relogin with OAuth2 when enabled.
2489
+ service_account_token: Service account token for authentication.
1813
2490
 
1814
2491
  Returns:
1815
2492
  None
1816
2493
  """
1817
- # TODO(zhwu): this SDK sets global endpoint, which may not be the best
1818
- # design as a user may expect this is only effective for the current
1819
- # session. We should consider using env var for specifying endpoint.
1820
- if endpoint is None:
1821
- endpoint = click.prompt('Enter your SkyPilot API server endpoint')
1822
- # Check endpoint is a valid URL
1823
- if (endpoint is not None and not endpoint.startswith('http://') and
1824
- not endpoint.startswith('https://')):
1825
- raise click.BadParameter('Endpoint must be a valid URL.')
1826
-
1827
- server_common.check_server_healthy(endpoint)
1828
-
1829
- # Set the endpoint in the config file
1830
- config_path = pathlib.Path(
1831
- skypilot_config.get_user_config_path()).expanduser()
1832
- with filelock.FileLock(config_path.with_suffix('.lock')):
1833
- if not config_path.exists():
1834
- config_path.touch()
1835
- config = {'api_server': {'endpoint': endpoint}}
2494
+ _check_endpoint_in_env_var(is_login=True)
2495
+
2496
+ # Validate and normalize endpoint
2497
+ endpoint = _validate_endpoint(endpoint)
2498
+
2499
+ def _show_logged_in_message(
2500
+ endpoint: str, dashboard_url: str, user: Optional[Dict[str, Any]],
2501
+ server_status: server_common.ApiServerStatus) -> None:
2502
+ """Show the logged in message."""
2503
+ if server_status != server_common.ApiServerStatus.HEALTHY:
2504
+ with ux_utils.print_exception_no_traceback():
2505
+ raise ValueError(f'Cannot log in API server at '
2506
+ f'{endpoint} (status: {server_status.value})')
2507
+
2508
+ identity_info = f'\n{ux_utils.INDENT_SYMBOL}{colorama.Fore.GREEN}User: '
2509
+ if user:
2510
+ user_name = user.get('name')
2511
+ user_id = user.get('id')
2512
+ if user_name and user_id:
2513
+ identity_info += f'{user_name} ({user_id})'
2514
+ elif user_id:
2515
+ identity_info += user_id
1836
2516
  else:
1837
- config = skypilot_config.get_user_config()
1838
- config.set_nested(('api_server', 'endpoint'), endpoint)
1839
- common_utils.dump_yaml(str(config_path), dict(config))
1840
- dashboard_url = server_common.get_dashboard_url(endpoint)
2517
+ identity_info = ''
1841
2518
  dashboard_msg = f'Dashboard: {dashboard_url}'
1842
2519
  click.secho(
1843
2520
  f'Logged into SkyPilot API server at: {endpoint}'
2521
+ f'{identity_info}'
1844
2522
  f'\n{ux_utils.INDENT_LAST_SYMBOL}{colorama.Fore.GREEN}'
1845
2523
  f'{dashboard_msg}',
1846
2524
  fg='green')
2525
+
2526
+ def _set_user_hash(user_hash: Optional[str]) -> None:
2527
+ if user_hash is not None:
2528
+ if not common_utils.is_valid_user_hash(user_hash):
2529
+ raise ValueError(f'Invalid user hash: {user_hash}')
2530
+ common_utils.set_user_hash_locally(user_hash)
2531
+
2532
+ # Handle service account token authentication
2533
+ if service_account_token:
2534
+ if not service_account_token.startswith('sky_'):
2535
+ raise ValueError('Invalid service account token format. '
2536
+ 'Token must start with "sky_"')
2537
+
2538
+ # Save both endpoint and token to config in a single operation
2539
+ _save_config_updates(endpoint=endpoint,
2540
+ service_account_token=service_account_token)
2541
+
2542
+ # Test the authentication by checking server health
2543
+ try:
2544
+ server_status, api_server_info = server_common.check_server_healthy(
2545
+ endpoint)
2546
+ dashboard_url = server_common.get_dashboard_url(endpoint)
2547
+ if api_server_info.user is not None:
2548
+ _set_user_hash(api_server_info.user.get('id'))
2549
+ _show_logged_in_message(endpoint, dashboard_url,
2550
+ api_server_info.user, server_status)
2551
+
2552
+ return
2553
+ except exceptions.ApiServerConnectionError as e:
2554
+ with ux_utils.print_exception_no_traceback():
2555
+ raise RuntimeError(
2556
+ f'Failed to connect to API server at {endpoint}: {e}'
2557
+ ) from e
2558
+ except Exception as e: # pylint: disable=broad-except
2559
+ with ux_utils.print_exception_no_traceback():
2560
+ raise RuntimeError(
2561
+ f'{colorama.Fore.RED}Service account token authentication '
2562
+ f'failed:{colorama.Style.RESET_ALL} {e}') from None
2563
+
2564
+ # OAuth2/cookie-based authentication flow
2565
+ # TODO(zhwu): this SDK sets global endpoint, which may not be the best
2566
+ # design as a user may expect this is only effective for the current
2567
+ # session. We should consider using env var for specifying endpoint.
2568
+
2569
+ server_status, api_server_info = server_common.check_server_healthy(
2570
+ endpoint)
2571
+ if server_status == server_common.ApiServerStatus.NEEDS_AUTH or relogin:
2572
+ # We detected an auth proxy, so go through the auth proxy cookie flow.
2573
+ token: Optional[str] = None
2574
+ server: Optional[oauth_lib.HTTPServer] = None
2575
+ try:
2576
+ callback_port = common_utils.find_free_port(8000)
2577
+
2578
+ token_container: Dict[str, Optional[str]] = {'token': None}
2579
+ logger.debug('Starting local authentication server...')
2580
+ server = oauth_lib.start_local_auth_server(callback_port,
2581
+ token_container,
2582
+ endpoint)
2583
+
2584
+ token_url = (f'{endpoint}/token?local_port={callback_port}')
2585
+ if webbrowser.open(token_url):
2586
+ click.echo(f'{colorama.Fore.GREEN}A web browser has been '
2587
+ f'opened at {token_url}. Please continue the login '
2588
+ f'in the web browser.{colorama.Style.RESET_ALL}\n'
2589
+ f'{colorama.Style.DIM}To manually copy the token, '
2590
+ f'press ctrl+c.{colorama.Style.RESET_ALL}')
2591
+ else:
2592
+ raise ValueError('Failed to open browser.')
2593
+
2594
+ start_time = time.time()
2595
+
2596
+ while (token_container['token'] is None and
2597
+ time.time() - start_time < oauth_lib.AUTH_TIMEOUT):
2598
+ time.sleep(1)
2599
+
2600
+ if token_container['token'] is None:
2601
+ click.echo(f'{colorama.Fore.YELLOW}Authentication timed out '
2602
+ f'after {oauth_lib.AUTH_TIMEOUT} seconds.')
2603
+ else:
2604
+ token = token_container['token']
2605
+
2606
+ except (Exception, KeyboardInterrupt) as e: # pylint: disable=broad-except
2607
+ logger.debug(f'Automatic authentication failed: {e}, '
2608
+ 'falling back to manual token entry.')
2609
+ if isinstance(e, KeyboardInterrupt):
2610
+ click.echo(f'\n{colorama.Style.DIM}Interrupted. Press ctrl+c '
2611
+ f'again to exit.{colorama.Style.RESET_ALL}')
2612
+ # Fall back to manual token entry
2613
+ token_url = f'{endpoint}/token'
2614
+ click.echo('Authentication is needed. Please visit this URL '
2615
+ f'to set up the token:{colorama.Style.BRIGHT}\n\n'
2616
+ f'{token_url}\n{colorama.Style.RESET_ALL}')
2617
+ token = click.prompt('Paste the token')
2618
+ finally:
2619
+ if server is not None:
2620
+ try:
2621
+ server.server_close()
2622
+ except Exception: # pylint: disable=broad-except
2623
+ pass
2624
+ if not token:
2625
+ with ux_utils.print_exception_no_traceback():
2626
+ raise ValueError('Authentication failed.')
2627
+
2628
+ # Parse the token.
2629
+ # b64decode will ignore invalid characters, but does some length and
2630
+ # padding checks.
2631
+ try:
2632
+ data = base64.b64decode(token)
2633
+ except binascii.Error as e:
2634
+ raise ValueError(f'Malformed token: {token}') from e
2635
+ logger.debug(f'Token data: {data!r}')
2636
+ try:
2637
+ json_data = json.loads(data)
2638
+ except (json.JSONDecodeError, UnicodeDecodeError) as e:
2639
+ raise ValueError(f'Malformed token data: {data!r}') from e
2640
+ if not isinstance(json_data, dict):
2641
+ raise ValueError(f'Malformed token JSON: {json_data}')
2642
+
2643
+ if json_data.get('v') == 1:
2644
+ user_hash = json_data.get('user')
2645
+ cookie_dict = json_data['cookies']
2646
+ elif 'v' not in json_data:
2647
+ user_hash = None
2648
+ cookie_dict = json_data
2649
+ else:
2650
+ raise ValueError(f'Unsupported token version: {json_data.get("v")}')
2651
+
2652
+ parsed_url = urlparse.urlparse(endpoint)
2653
+ cookie_jar = cookiejar.MozillaCookieJar()
2654
+ for (name, value) in cookie_dict.items():
2655
+ # dict keys in JSON must be strings
2656
+ assert isinstance(name, str)
2657
+ if not isinstance(value, str):
2658
+ raise ValueError('Malformed token - bad key/value: '
2659
+ f'{name}: {value}')
2660
+
2661
+ # See CookieJar._cookie_from_cookie_tuple
2662
+ # oauth2proxy default is Max-Age 604800
2663
+ expires = int(time.time()) + 604800
2664
+ domain = str(parsed_url.hostname)
2665
+ domain_initial_dot = domain.startswith('.')
2666
+ secure = parsed_url.scheme == 'https'
2667
+ if not domain_initial_dot:
2668
+ domain = '.' + domain
2669
+
2670
+ cookie_jar.set_cookie(
2671
+ cookiejar.Cookie(
2672
+ version=0,
2673
+ name=name,
2674
+ value=value,
2675
+ port=None,
2676
+ port_specified=False,
2677
+ domain=domain,
2678
+ domain_specified=True,
2679
+ domain_initial_dot=domain_initial_dot,
2680
+ path='',
2681
+ path_specified=False,
2682
+ secure=secure,
2683
+ expires=expires,
2684
+ discard=False,
2685
+ comment=None,
2686
+ comment_url=None,
2687
+ rest=dict(),
2688
+ ))
2689
+
2690
+ # Now that the cookies are parsed, save them to the cookie jar.
2691
+ server_common.set_api_cookie_jar(cookie_jar)
2692
+
2693
+ # Set the user hash in the local file.
2694
+ # If the server already has a token for this user set it to the local
2695
+ # file, otherwise use the new user hash.
2696
+ if (api_server_info.user is not None and
2697
+ api_server_info.user.get('id') is not None):
2698
+ _set_user_hash(api_server_info.user.get('id'))
2699
+ else:
2700
+ _set_user_hash(user_hash)
2701
+ else:
2702
+ # Check if basic auth is enabled
2703
+ if api_server_info.basic_auth_enabled:
2704
+ if api_server_info.user is None:
2705
+ with ux_utils.print_exception_no_traceback():
2706
+ raise ValueError(
2707
+ 'Basic auth is enabled but no valid user is found')
2708
+
2709
+ # Set the user hash in the local file.
2710
+ if api_server_info.user is not None:
2711
+ _set_user_hash(api_server_info.user.get('id'))
2712
+
2713
+ # Set the endpoint in the config file
2714
+ _save_config_updates(endpoint=endpoint)
2715
+ dashboard_url = server_common.get_dashboard_url(endpoint)
2716
+
2717
+ # see https://github.com/python/mypy/issues/5107 on why
2718
+ # typing is disabled on this line
2719
+ server_common.get_api_server_status.cache_clear() # type: ignore
2720
+ # After successful authentication, check server health again to get user
2721
+ # identity
2722
+ server_status, final_api_server_info = server_common.check_server_healthy(
2723
+ endpoint)
2724
+ _show_logged_in_message(endpoint, dashboard_url, final_api_server_info.user,
2725
+ server_status)
2726
+
2727
+
2728
@usage_lib.entrypoint
@annotations.client_api
def api_logout() -> None:
    """Log out of the remote SkyPilot API server.

    Overwrites the stored API cookie jar with an empty one and clears the
    API server settings kept in ~/.sky/config.yaml.

    Raises:
        RuntimeError: If the current API server is a local one; the message
            points the user to `sky api stop` instead.
    """
    # NOTE(review): helper is defined elsewhere in this file; presumably it
    # validates the endpoint environment variable for the logout flow --
    # confirm against its definition.
    _check_endpoint_in_env_var(is_login=False)

    targets_local_server = server_common.is_api_server_local()
    if targets_local_server:
        with ux_utils.print_exception_no_traceback():
            raise RuntimeError('Local api server cannot be logged out. '
                               'Use `sky api stop` instead.')

    # Swap in an empty jar. There is nothing to clear when no cookie jar
    # exists yet, hence create_if_not_exists=False.
    empty_jar = cookiejar.MozillaCookieJar()
    server_common.set_api_cookie_jar(empty_jar, create_if_not_exists=False)
    _clear_api_server_config()

    green, reset = colorama.Fore.GREEN, colorama.Style.RESET_ALL
    logger.info(f'{green}Logged out of SkyPilot API server.{reset}')