skypilot-nightly 1.0.0.dev20250509__py3-none-any.whl → 1.0.0.dev20251107__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of skypilot-nightly might be problematic. Click here for more details.

Files changed (512)
  1. sky/__init__.py +22 -6
  2. sky/adaptors/aws.py +25 -7
  3. sky/adaptors/common.py +24 -1
  4. sky/adaptors/coreweave.py +278 -0
  5. sky/adaptors/do.py +8 -2
  6. sky/adaptors/hyperbolic.py +8 -0
  7. sky/adaptors/kubernetes.py +149 -18
  8. sky/adaptors/nebius.py +170 -17
  9. sky/adaptors/primeintellect.py +1 -0
  10. sky/adaptors/runpod.py +68 -0
  11. sky/adaptors/seeweb.py +167 -0
  12. sky/adaptors/shadeform.py +89 -0
  13. sky/admin_policy.py +187 -4
  14. sky/authentication.py +179 -225
  15. sky/backends/__init__.py +4 -2
  16. sky/backends/backend.py +22 -9
  17. sky/backends/backend_utils.py +1299 -380
  18. sky/backends/cloud_vm_ray_backend.py +1715 -518
  19. sky/backends/docker_utils.py +1 -1
  20. sky/backends/local_docker_backend.py +11 -6
  21. sky/backends/wheel_utils.py +37 -9
  22. sky/{clouds/service_catalog → catalog}/__init__.py +21 -19
  23. sky/{clouds/service_catalog → catalog}/aws_catalog.py +27 -8
  24. sky/{clouds/service_catalog → catalog}/azure_catalog.py +10 -7
  25. sky/{clouds/service_catalog → catalog}/common.py +89 -48
  26. sky/{clouds/service_catalog → catalog}/cudo_catalog.py +8 -5
  27. sky/{clouds/service_catalog → catalog}/data_fetchers/analyze.py +1 -1
  28. sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_aws.py +30 -40
  29. sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_cudo.py +38 -38
  30. sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_gcp.py +42 -15
  31. sky/catalog/data_fetchers/fetch_hyperbolic.py +136 -0
  32. sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_lambda_cloud.py +1 -0
  33. sky/catalog/data_fetchers/fetch_nebius.py +335 -0
  34. sky/catalog/data_fetchers/fetch_runpod.py +698 -0
  35. sky/catalog/data_fetchers/fetch_seeweb.py +329 -0
  36. sky/catalog/data_fetchers/fetch_shadeform.py +142 -0
  37. sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_vast.py +1 -1
  38. sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_vsphere.py +1 -1
  39. sky/{clouds/service_catalog → catalog}/do_catalog.py +5 -2
  40. sky/{clouds/service_catalog → catalog}/fluidstack_catalog.py +6 -3
  41. sky/{clouds/service_catalog → catalog}/gcp_catalog.py +41 -15
  42. sky/catalog/hyperbolic_catalog.py +136 -0
  43. sky/{clouds/service_catalog → catalog}/ibm_catalog.py +9 -6
  44. sky/{clouds/service_catalog → catalog}/kubernetes_catalog.py +36 -24
  45. sky/{clouds/service_catalog → catalog}/lambda_catalog.py +9 -6
  46. sky/{clouds/service_catalog → catalog}/nebius_catalog.py +9 -7
  47. sky/{clouds/service_catalog → catalog}/oci_catalog.py +9 -6
  48. sky/{clouds/service_catalog → catalog}/paperspace_catalog.py +5 -2
  49. sky/catalog/primeintellect_catalog.py +95 -0
  50. sky/{clouds/service_catalog → catalog}/runpod_catalog.py +11 -4
  51. sky/{clouds/service_catalog → catalog}/scp_catalog.py +9 -6
  52. sky/catalog/seeweb_catalog.py +184 -0
  53. sky/catalog/shadeform_catalog.py +165 -0
  54. sky/catalog/ssh_catalog.py +167 -0
  55. sky/{clouds/service_catalog → catalog}/vast_catalog.py +6 -3
  56. sky/{clouds/service_catalog → catalog}/vsphere_catalog.py +5 -2
  57. sky/check.py +491 -203
  58. sky/cli.py +5 -6005
  59. sky/client/{cli.py → cli/command.py} +2477 -1885
  60. sky/client/cli/deprecation_utils.py +99 -0
  61. sky/client/cli/flags.py +359 -0
  62. sky/client/cli/table_utils.py +320 -0
  63. sky/client/common.py +70 -32
  64. sky/client/oauth.py +82 -0
  65. sky/client/sdk.py +1203 -297
  66. sky/client/sdk_async.py +833 -0
  67. sky/client/service_account_auth.py +47 -0
  68. sky/cloud_stores.py +73 -0
  69. sky/clouds/__init__.py +13 -0
  70. sky/clouds/aws.py +358 -93
  71. sky/clouds/azure.py +105 -83
  72. sky/clouds/cloud.py +127 -36
  73. sky/clouds/cudo.py +68 -50
  74. sky/clouds/do.py +66 -48
  75. sky/clouds/fluidstack.py +63 -44
  76. sky/clouds/gcp.py +339 -110
  77. sky/clouds/hyperbolic.py +293 -0
  78. sky/clouds/ibm.py +70 -49
  79. sky/clouds/kubernetes.py +563 -162
  80. sky/clouds/lambda_cloud.py +74 -54
  81. sky/clouds/nebius.py +206 -80
  82. sky/clouds/oci.py +88 -66
  83. sky/clouds/paperspace.py +61 -44
  84. sky/clouds/primeintellect.py +317 -0
  85. sky/clouds/runpod.py +164 -74
  86. sky/clouds/scp.py +89 -83
  87. sky/clouds/seeweb.py +466 -0
  88. sky/clouds/shadeform.py +400 -0
  89. sky/clouds/ssh.py +263 -0
  90. sky/clouds/utils/aws_utils.py +10 -4
  91. sky/clouds/utils/gcp_utils.py +87 -11
  92. sky/clouds/utils/oci_utils.py +38 -14
  93. sky/clouds/utils/scp_utils.py +177 -124
  94. sky/clouds/vast.py +99 -77
  95. sky/clouds/vsphere.py +51 -40
  96. sky/core.py +349 -139
  97. sky/dag.py +15 -0
  98. sky/dashboard/out/404.html +1 -1
  99. sky/dashboard/out/_next/static/chunks/1141-e6aa9ab418717c59.js +11 -0
  100. sky/dashboard/out/_next/static/chunks/1272-1ef0bf0237faccdb.js +1 -0
  101. sky/dashboard/out/_next/static/chunks/1871-74503c8e80fd253b.js +6 -0
  102. sky/dashboard/out/_next/static/chunks/2260-7703229c33c5ebd5.js +1 -0
  103. sky/dashboard/out/_next/static/chunks/2350.fab69e61bac57b23.js +1 -0
  104. sky/dashboard/out/_next/static/chunks/2369.fc20f0c2c8ed9fe7.js +15 -0
  105. sky/dashboard/out/_next/static/chunks/2755.fff53c4a3fcae910.js +26 -0
  106. sky/dashboard/out/_next/static/chunks/3294.72362fa129305b19.js +1 -0
  107. sky/dashboard/out/_next/static/chunks/3785.ad6adaa2a0fa9768.js +1 -0
  108. sky/dashboard/out/_next/static/chunks/3850-ff4a9a69d978632b.js +1 -0
  109. sky/dashboard/out/_next/static/chunks/3937.210053269f121201.js +1 -0
  110. sky/dashboard/out/_next/static/chunks/4725.a830b5c9e7867c92.js +1 -0
  111. sky/dashboard/out/_next/static/chunks/4937.a2baa2df5572a276.js +15 -0
  112. sky/dashboard/out/_next/static/chunks/5739-d67458fcb1386c92.js +8 -0
  113. sky/dashboard/out/_next/static/chunks/6130-2be46d70a38f1e82.js +1 -0
  114. sky/dashboard/out/_next/static/chunks/616-3d59f75e2ccf9321.js +39 -0
  115. sky/dashboard/out/_next/static/chunks/6212-7bd06f60ba693125.js +13 -0
  116. sky/dashboard/out/_next/static/chunks/6601-06114c982db410b6.js +1 -0
  117. sky/dashboard/out/_next/static/chunks/6856-ef8ba11f96d8c4a3.js +1 -0
  118. sky/dashboard/out/_next/static/chunks/6989-01359c57e018caa4.js +1 -0
  119. sky/dashboard/out/_next/static/chunks/6990-32b6e2d3822301fa.js +1 -0
  120. sky/dashboard/out/_next/static/chunks/7359-c8d04e06886000b3.js +30 -0
  121. sky/dashboard/out/_next/static/chunks/7411-b15471acd2cba716.js +41 -0
  122. sky/dashboard/out/_next/static/chunks/7615-3301e838e5f25772.js +1 -0
  123. sky/dashboard/out/_next/static/chunks/8640.5b9475a2d18c5416.js +16 -0
  124. sky/dashboard/out/_next/static/chunks/8969-1e4613c651bf4051.js +1 -0
  125. sky/dashboard/out/_next/static/chunks/9025.fa408f3242e9028d.js +6 -0
  126. sky/dashboard/out/_next/static/chunks/9353-cff34f7e773b2e2b.js +1 -0
  127. sky/dashboard/out/_next/static/chunks/9360.7310982cf5a0dc79.js +31 -0
  128. sky/dashboard/out/_next/static/chunks/9847.3aaca6bb33455140.js +30 -0
  129. sky/dashboard/out/_next/static/chunks/fd9d1056-86323a29a8f7e46a.js +1 -0
  130. sky/dashboard/out/_next/static/chunks/framework-cf60a09ccd051a10.js +33 -0
  131. sky/dashboard/out/_next/static/chunks/main-app-587214043926b3cc.js +1 -0
  132. sky/dashboard/out/_next/static/chunks/main-f15ccb73239a3bf1.js +1 -0
  133. sky/dashboard/out/_next/static/chunks/pages/_app-bde01e4a2beec258.js +34 -0
  134. sky/dashboard/out/_next/static/chunks/pages/_error-c66a4e8afc46f17b.js +1 -0
  135. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]/[job]-c736ead69c2d86ec.js +16 -0
  136. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]-a37d2063af475a1c.js +1 -0
  137. sky/dashboard/out/_next/static/chunks/pages/clusters-d44859594e6f8064.js +1 -0
  138. sky/dashboard/out/_next/static/chunks/pages/config-dfb9bf07b13045f4.js +1 -0
  139. sky/dashboard/out/_next/static/chunks/pages/index-444f1804401f04ea.js +1 -0
  140. sky/dashboard/out/_next/static/chunks/pages/infra/[context]-c0b5935149902e6f.js +1 -0
  141. sky/dashboard/out/_next/static/chunks/pages/infra-aed0ea19df7cf961.js +1 -0
  142. sky/dashboard/out/_next/static/chunks/pages/jobs/[job]-5796e8d6aea291a0.js +16 -0
  143. sky/dashboard/out/_next/static/chunks/pages/jobs/pools/[pool]-6edeb7d06032adfc.js +21 -0
  144. sky/dashboard/out/_next/static/chunks/pages/jobs-479dde13399cf270.js +1 -0
  145. sky/dashboard/out/_next/static/chunks/pages/users-5ab3b907622cf0fe.js +1 -0
  146. sky/dashboard/out/_next/static/chunks/pages/volumes-b84b948ff357c43e.js +1 -0
  147. sky/dashboard/out/_next/static/chunks/pages/workspace/new-3f88a1c7e86a3f86.js +1 -0
  148. sky/dashboard/out/_next/static/chunks/pages/workspaces/[name]-c5a3eeee1c218af1.js +1 -0
  149. sky/dashboard/out/_next/static/chunks/pages/workspaces-22b23febb3e89ce1.js +1 -0
  150. sky/dashboard/out/_next/static/chunks/webpack-2679be77fc08a2f8.js +1 -0
  151. sky/dashboard/out/_next/static/css/0748ce22df867032.css +3 -0
  152. sky/dashboard/out/_next/static/zB0ed6ge_W1MDszVHhijS/_buildManifest.js +1 -0
  153. sky/dashboard/out/clusters/[cluster]/[job].html +1 -1
  154. sky/dashboard/out/clusters/[cluster].html +1 -1
  155. sky/dashboard/out/clusters.html +1 -1
  156. sky/dashboard/out/config.html +1 -0
  157. sky/dashboard/out/index.html +1 -1
  158. sky/dashboard/out/infra/[context].html +1 -0
  159. sky/dashboard/out/infra.html +1 -0
  160. sky/dashboard/out/jobs/[job].html +1 -1
  161. sky/dashboard/out/jobs/pools/[pool].html +1 -0
  162. sky/dashboard/out/jobs.html +1 -1
  163. sky/dashboard/out/users.html +1 -0
  164. sky/dashboard/out/volumes.html +1 -0
  165. sky/dashboard/out/workspace/new.html +1 -0
  166. sky/dashboard/out/workspaces/[name].html +1 -0
  167. sky/dashboard/out/workspaces.html +1 -0
  168. sky/data/data_utils.py +137 -1
  169. sky/data/mounting_utils.py +269 -84
  170. sky/data/storage.py +1451 -1807
  171. sky/data/storage_utils.py +43 -57
  172. sky/exceptions.py +132 -2
  173. sky/execution.py +206 -63
  174. sky/global_user_state.py +2374 -586
  175. sky/jobs/__init__.py +5 -0
  176. sky/jobs/client/sdk.py +242 -65
  177. sky/jobs/client/sdk_async.py +143 -0
  178. sky/jobs/constants.py +9 -8
  179. sky/jobs/controller.py +839 -277
  180. sky/jobs/file_content_utils.py +80 -0
  181. sky/jobs/log_gc.py +201 -0
  182. sky/jobs/recovery_strategy.py +398 -152
  183. sky/jobs/scheduler.py +315 -189
  184. sky/jobs/server/core.py +829 -255
  185. sky/jobs/server/server.py +156 -115
  186. sky/jobs/server/utils.py +136 -0
  187. sky/jobs/state.py +2092 -701
  188. sky/jobs/utils.py +1242 -160
  189. sky/logs/__init__.py +21 -0
  190. sky/logs/agent.py +108 -0
  191. sky/logs/aws.py +243 -0
  192. sky/logs/gcp.py +91 -0
  193. sky/metrics/__init__.py +0 -0
  194. sky/metrics/utils.py +443 -0
  195. sky/models.py +78 -1
  196. sky/optimizer.py +164 -70
  197. sky/provision/__init__.py +90 -4
  198. sky/provision/aws/config.py +147 -26
  199. sky/provision/aws/instance.py +135 -50
  200. sky/provision/azure/instance.py +10 -5
  201. sky/provision/common.py +13 -1
  202. sky/provision/cudo/cudo_machine_type.py +1 -1
  203. sky/provision/cudo/cudo_utils.py +14 -8
  204. sky/provision/cudo/cudo_wrapper.py +72 -71
  205. sky/provision/cudo/instance.py +10 -6
  206. sky/provision/do/instance.py +10 -6
  207. sky/provision/do/utils.py +4 -3
  208. sky/provision/docker_utils.py +114 -23
  209. sky/provision/fluidstack/instance.py +13 -8
  210. sky/provision/gcp/__init__.py +1 -0
  211. sky/provision/gcp/config.py +301 -19
  212. sky/provision/gcp/constants.py +218 -0
  213. sky/provision/gcp/instance.py +36 -8
  214. sky/provision/gcp/instance_utils.py +18 -4
  215. sky/provision/gcp/volume_utils.py +247 -0
  216. sky/provision/hyperbolic/__init__.py +12 -0
  217. sky/provision/hyperbolic/config.py +10 -0
  218. sky/provision/hyperbolic/instance.py +437 -0
  219. sky/provision/hyperbolic/utils.py +373 -0
  220. sky/provision/instance_setup.py +93 -14
  221. sky/provision/kubernetes/__init__.py +5 -0
  222. sky/provision/kubernetes/config.py +9 -52
  223. sky/provision/kubernetes/constants.py +17 -0
  224. sky/provision/kubernetes/instance.py +789 -247
  225. sky/provision/kubernetes/manifests/fusermount-server-daemonset.yaml +1 -2
  226. sky/provision/kubernetes/network.py +27 -17
  227. sky/provision/kubernetes/network_utils.py +40 -43
  228. sky/provision/kubernetes/utils.py +1192 -531
  229. sky/provision/kubernetes/volume.py +282 -0
  230. sky/provision/lambda_cloud/instance.py +22 -16
  231. sky/provision/nebius/constants.py +50 -0
  232. sky/provision/nebius/instance.py +19 -6
  233. sky/provision/nebius/utils.py +196 -91
  234. sky/provision/oci/instance.py +10 -5
  235. sky/provision/paperspace/instance.py +10 -7
  236. sky/provision/paperspace/utils.py +1 -1
  237. sky/provision/primeintellect/__init__.py +10 -0
  238. sky/provision/primeintellect/config.py +11 -0
  239. sky/provision/primeintellect/instance.py +454 -0
  240. sky/provision/primeintellect/utils.py +398 -0
  241. sky/provision/provisioner.py +110 -36
  242. sky/provision/runpod/__init__.py +5 -0
  243. sky/provision/runpod/instance.py +27 -6
  244. sky/provision/runpod/utils.py +51 -18
  245. sky/provision/runpod/volume.py +180 -0
  246. sky/provision/scp/__init__.py +15 -0
  247. sky/provision/scp/config.py +93 -0
  248. sky/provision/scp/instance.py +531 -0
  249. sky/provision/seeweb/__init__.py +11 -0
  250. sky/provision/seeweb/config.py +13 -0
  251. sky/provision/seeweb/instance.py +807 -0
  252. sky/provision/shadeform/__init__.py +11 -0
  253. sky/provision/shadeform/config.py +12 -0
  254. sky/provision/shadeform/instance.py +351 -0
  255. sky/provision/shadeform/shadeform_utils.py +83 -0
  256. sky/provision/ssh/__init__.py +18 -0
  257. sky/provision/vast/instance.py +13 -8
  258. sky/provision/vast/utils.py +10 -7
  259. sky/provision/vsphere/common/vim_utils.py +1 -2
  260. sky/provision/vsphere/instance.py +15 -10
  261. sky/provision/vsphere/vsphere_utils.py +9 -19
  262. sky/py.typed +0 -0
  263. sky/resources.py +844 -118
  264. sky/schemas/__init__.py +0 -0
  265. sky/schemas/api/__init__.py +0 -0
  266. sky/schemas/api/responses.py +225 -0
  267. sky/schemas/db/README +4 -0
  268. sky/schemas/db/env.py +90 -0
  269. sky/schemas/db/global_user_state/001_initial_schema.py +124 -0
  270. sky/schemas/db/global_user_state/002_add_workspace_to_cluster_history.py +35 -0
  271. sky/schemas/db/global_user_state/003_fix_initial_revision.py +61 -0
  272. sky/schemas/db/global_user_state/004_is_managed.py +34 -0
  273. sky/schemas/db/global_user_state/005_cluster_event.py +32 -0
  274. sky/schemas/db/global_user_state/006_provision_log.py +41 -0
  275. sky/schemas/db/global_user_state/007_cluster_event_request_id.py +34 -0
  276. sky/schemas/db/global_user_state/008_skylet_ssh_tunnel_metadata.py +34 -0
  277. sky/schemas/db/global_user_state/009_last_activity_and_launched_at.py +89 -0
  278. sky/schemas/db/global_user_state/010_save_ssh_key.py +66 -0
  279. sky/schemas/db/script.py.mako +28 -0
  280. sky/schemas/db/serve_state/001_initial_schema.py +67 -0
  281. sky/schemas/db/skypilot_config/001_initial_schema.py +30 -0
  282. sky/schemas/db/spot_jobs/001_initial_schema.py +97 -0
  283. sky/schemas/db/spot_jobs/002_cluster_pool.py +42 -0
  284. sky/schemas/db/spot_jobs/003_pool_hash.py +34 -0
  285. sky/schemas/db/spot_jobs/004_job_file_contents.py +42 -0
  286. sky/schemas/db/spot_jobs/005_logs_gc.py +38 -0
  287. sky/schemas/generated/__init__.py +0 -0
  288. sky/schemas/generated/autostopv1_pb2.py +36 -0
  289. sky/schemas/generated/autostopv1_pb2.pyi +43 -0
  290. sky/schemas/generated/autostopv1_pb2_grpc.py +146 -0
  291. sky/schemas/generated/jobsv1_pb2.py +86 -0
  292. sky/schemas/generated/jobsv1_pb2.pyi +254 -0
  293. sky/schemas/generated/jobsv1_pb2_grpc.py +542 -0
  294. sky/schemas/generated/managed_jobsv1_pb2.py +74 -0
  295. sky/schemas/generated/managed_jobsv1_pb2.pyi +278 -0
  296. sky/schemas/generated/managed_jobsv1_pb2_grpc.py +278 -0
  297. sky/schemas/generated/servev1_pb2.py +58 -0
  298. sky/schemas/generated/servev1_pb2.pyi +115 -0
  299. sky/schemas/generated/servev1_pb2_grpc.py +322 -0
  300. sky/serve/autoscalers.py +357 -5
  301. sky/serve/client/impl.py +310 -0
  302. sky/serve/client/sdk.py +47 -139
  303. sky/serve/client/sdk_async.py +130 -0
  304. sky/serve/constants.py +10 -8
  305. sky/serve/controller.py +64 -19
  306. sky/serve/load_balancer.py +106 -60
  307. sky/serve/load_balancing_policies.py +115 -1
  308. sky/serve/replica_managers.py +273 -162
  309. sky/serve/serve_rpc_utils.py +179 -0
  310. sky/serve/serve_state.py +554 -251
  311. sky/serve/serve_utils.py +733 -220
  312. sky/serve/server/core.py +66 -711
  313. sky/serve/server/impl.py +1093 -0
  314. sky/serve/server/server.py +21 -18
  315. sky/serve/service.py +133 -48
  316. sky/serve/service_spec.py +135 -16
  317. sky/serve/spot_placer.py +3 -0
  318. sky/server/auth/__init__.py +0 -0
  319. sky/server/auth/authn.py +50 -0
  320. sky/server/auth/loopback.py +38 -0
  321. sky/server/auth/oauth2_proxy.py +200 -0
  322. sky/server/common.py +475 -181
  323. sky/server/config.py +81 -23
  324. sky/server/constants.py +44 -6
  325. sky/server/daemons.py +229 -0
  326. sky/server/html/token_page.html +185 -0
  327. sky/server/metrics.py +160 -0
  328. sky/server/requests/executor.py +528 -138
  329. sky/server/requests/payloads.py +351 -17
  330. sky/server/requests/preconditions.py +21 -17
  331. sky/server/requests/process.py +112 -29
  332. sky/server/requests/request_names.py +120 -0
  333. sky/server/requests/requests.py +817 -224
  334. sky/server/requests/serializers/decoders.py +82 -31
  335. sky/server/requests/serializers/encoders.py +140 -22
  336. sky/server/requests/threads.py +106 -0
  337. sky/server/rest.py +417 -0
  338. sky/server/server.py +1290 -284
  339. sky/server/state.py +20 -0
  340. sky/server/stream_utils.py +345 -57
  341. sky/server/uvicorn.py +217 -3
  342. sky/server/versions.py +270 -0
  343. sky/setup_files/MANIFEST.in +5 -0
  344. sky/setup_files/alembic.ini +156 -0
  345. sky/setup_files/dependencies.py +136 -31
  346. sky/setup_files/setup.py +44 -42
  347. sky/sky_logging.py +102 -5
  348. sky/skylet/attempt_skylet.py +1 -0
  349. sky/skylet/autostop_lib.py +129 -8
  350. sky/skylet/configs.py +27 -20
  351. sky/skylet/constants.py +171 -19
  352. sky/skylet/events.py +105 -21
  353. sky/skylet/job_lib.py +335 -104
  354. sky/skylet/log_lib.py +297 -18
  355. sky/skylet/log_lib.pyi +44 -1
  356. sky/skylet/ray_patches/__init__.py +17 -3
  357. sky/skylet/ray_patches/autoscaler.py.diff +18 -0
  358. sky/skylet/ray_patches/cli.py.diff +19 -0
  359. sky/skylet/ray_patches/command_runner.py.diff +17 -0
  360. sky/skylet/ray_patches/log_monitor.py.diff +20 -0
  361. sky/skylet/ray_patches/resource_demand_scheduler.py.diff +32 -0
  362. sky/skylet/ray_patches/updater.py.diff +18 -0
  363. sky/skylet/ray_patches/worker.py.diff +41 -0
  364. sky/skylet/services.py +564 -0
  365. sky/skylet/skylet.py +63 -4
  366. sky/skylet/subprocess_daemon.py +103 -29
  367. sky/skypilot_config.py +506 -99
  368. sky/ssh_node_pools/__init__.py +1 -0
  369. sky/ssh_node_pools/core.py +135 -0
  370. sky/ssh_node_pools/server.py +233 -0
  371. sky/task.py +621 -137
  372. sky/templates/aws-ray.yml.j2 +10 -3
  373. sky/templates/azure-ray.yml.j2 +1 -1
  374. sky/templates/do-ray.yml.j2 +1 -1
  375. sky/templates/gcp-ray.yml.j2 +57 -0
  376. sky/templates/hyperbolic-ray.yml.j2 +67 -0
  377. sky/templates/jobs-controller.yaml.j2 +27 -24
  378. sky/templates/kubernetes-loadbalancer.yml.j2 +2 -0
  379. sky/templates/kubernetes-ray.yml.j2 +607 -51
  380. sky/templates/lambda-ray.yml.j2 +1 -1
  381. sky/templates/nebius-ray.yml.j2 +33 -12
  382. sky/templates/paperspace-ray.yml.j2 +1 -1
  383. sky/templates/primeintellect-ray.yml.j2 +71 -0
  384. sky/templates/runpod-ray.yml.j2 +9 -1
  385. sky/templates/scp-ray.yml.j2 +3 -50
  386. sky/templates/seeweb-ray.yml.j2 +108 -0
  387. sky/templates/shadeform-ray.yml.j2 +72 -0
  388. sky/templates/sky-serve-controller.yaml.j2 +22 -2
  389. sky/templates/websocket_proxy.py +178 -18
  390. sky/usage/usage_lib.py +18 -11
  391. sky/users/__init__.py +0 -0
  392. sky/users/model.conf +15 -0
  393. sky/users/permission.py +387 -0
  394. sky/users/rbac.py +121 -0
  395. sky/users/server.py +720 -0
  396. sky/users/token_service.py +218 -0
  397. sky/utils/accelerator_registry.py +34 -5
  398. sky/utils/admin_policy_utils.py +84 -38
  399. sky/utils/annotations.py +16 -5
  400. sky/utils/asyncio_utils.py +78 -0
  401. sky/utils/auth_utils.py +153 -0
  402. sky/utils/benchmark_utils.py +60 -0
  403. sky/utils/cli_utils/status_utils.py +159 -86
  404. sky/utils/cluster_utils.py +31 -9
  405. sky/utils/command_runner.py +354 -68
  406. sky/utils/command_runner.pyi +93 -3
  407. sky/utils/common.py +35 -8
  408. sky/utils/common_utils.py +310 -87
  409. sky/utils/config_utils.py +87 -5
  410. sky/utils/context.py +402 -0
  411. sky/utils/context_utils.py +222 -0
  412. sky/utils/controller_utils.py +264 -89
  413. sky/utils/dag_utils.py +31 -12
  414. sky/utils/db/__init__.py +0 -0
  415. sky/utils/db/db_utils.py +470 -0
  416. sky/utils/db/migration_utils.py +133 -0
  417. sky/utils/directory_utils.py +12 -0
  418. sky/utils/env_options.py +13 -0
  419. sky/utils/git.py +567 -0
  420. sky/utils/git_clone.sh +460 -0
  421. sky/utils/infra_utils.py +195 -0
  422. sky/utils/kubernetes/cleanup-tunnel.sh +62 -0
  423. sky/utils/kubernetes/config_map_utils.py +133 -0
  424. sky/utils/kubernetes/create_cluster.sh +13 -27
  425. sky/utils/kubernetes/delete_cluster.sh +10 -7
  426. sky/utils/kubernetes/deploy_remote_cluster.py +1299 -0
  427. sky/utils/kubernetes/exec_kubeconfig_converter.py +22 -31
  428. sky/utils/kubernetes/generate_kind_config.py +6 -66
  429. sky/utils/kubernetes/generate_kubeconfig.sh +4 -1
  430. sky/utils/kubernetes/gpu_labeler.py +5 -5
  431. sky/utils/kubernetes/kubernetes_deploy_utils.py +354 -47
  432. sky/utils/kubernetes/ssh-tunnel.sh +379 -0
  433. sky/utils/kubernetes/ssh_utils.py +221 -0
  434. sky/utils/kubernetes_enums.py +8 -15
  435. sky/utils/lock_events.py +94 -0
  436. sky/utils/locks.py +368 -0
  437. sky/utils/log_utils.py +300 -6
  438. sky/utils/perf_utils.py +22 -0
  439. sky/utils/resource_checker.py +298 -0
  440. sky/utils/resources_utils.py +249 -32
  441. sky/utils/rich_utils.py +213 -37
  442. sky/utils/schemas.py +905 -147
  443. sky/utils/serialize_utils.py +16 -0
  444. sky/utils/status_lib.py +10 -0
  445. sky/utils/subprocess_utils.py +38 -15
  446. sky/utils/tempstore.py +70 -0
  447. sky/utils/timeline.py +24 -52
  448. sky/utils/ux_utils.py +84 -15
  449. sky/utils/validator.py +11 -1
  450. sky/utils/volume.py +86 -0
  451. sky/utils/yaml_utils.py +111 -0
  452. sky/volumes/__init__.py +13 -0
  453. sky/volumes/client/__init__.py +0 -0
  454. sky/volumes/client/sdk.py +149 -0
  455. sky/volumes/server/__init__.py +0 -0
  456. sky/volumes/server/core.py +258 -0
  457. sky/volumes/server/server.py +122 -0
  458. sky/volumes/volume.py +212 -0
  459. sky/workspaces/__init__.py +0 -0
  460. sky/workspaces/core.py +655 -0
  461. sky/workspaces/server.py +101 -0
  462. sky/workspaces/utils.py +56 -0
  463. skypilot_nightly-1.0.0.dev20251107.dist-info/METADATA +675 -0
  464. skypilot_nightly-1.0.0.dev20251107.dist-info/RECORD +594 -0
  465. {skypilot_nightly-1.0.0.dev20250509.dist-info → skypilot_nightly-1.0.0.dev20251107.dist-info}/WHEEL +1 -1
  466. sky/benchmark/benchmark_state.py +0 -256
  467. sky/benchmark/benchmark_utils.py +0 -641
  468. sky/clouds/service_catalog/constants.py +0 -7
  469. sky/dashboard/out/_next/static/LksQgChY5izXjokL3LcEu/_buildManifest.js +0 -1
  470. sky/dashboard/out/_next/static/chunks/236-f49500b82ad5392d.js +0 -6
  471. sky/dashboard/out/_next/static/chunks/312-c3c8845990db8ffc.js +0 -15
  472. sky/dashboard/out/_next/static/chunks/37-0a572fe0dbb89c4d.js +0 -6
  473. sky/dashboard/out/_next/static/chunks/678-206dddca808e6d16.js +0 -59
  474. sky/dashboard/out/_next/static/chunks/845-0f8017370869e269.js +0 -1
  475. sky/dashboard/out/_next/static/chunks/979-7bf73a4c7cea0f5c.js +0 -1
  476. sky/dashboard/out/_next/static/chunks/fd9d1056-2821b0f0cabcd8bd.js +0 -1
  477. sky/dashboard/out/_next/static/chunks/framework-87d061ee6ed71b28.js +0 -33
  478. sky/dashboard/out/_next/static/chunks/main-app-241eb28595532291.js +0 -1
  479. sky/dashboard/out/_next/static/chunks/main-e0e2335212e72357.js +0 -1
  480. sky/dashboard/out/_next/static/chunks/pages/_app-e6b013bc3f77ad60.js +0 -1
  481. sky/dashboard/out/_next/static/chunks/pages/_error-1be831200e60c5c0.js +0 -1
  482. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]/[job]-e15db85d0ea1fbe1.js +0 -1
  483. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]-f383db7389368ea7.js +0 -1
  484. sky/dashboard/out/_next/static/chunks/pages/clusters-a93b93e10b8b074e.js +0 -1
  485. sky/dashboard/out/_next/static/chunks/pages/index-f9f039532ca8cbc4.js +0 -1
  486. sky/dashboard/out/_next/static/chunks/pages/jobs/[job]-03f279c6741fb48b.js +0 -1
  487. sky/dashboard/out/_next/static/chunks/pages/jobs-a75029b67aab6a2e.js +0 -1
  488. sky/dashboard/out/_next/static/chunks/webpack-830f59b8404e96b8.js +0 -1
  489. sky/dashboard/out/_next/static/css/c6933bbb2ce7f4dd.css +0 -3
  490. sky/jobs/dashboard/dashboard.py +0 -223
  491. sky/jobs/dashboard/static/favicon.ico +0 -0
  492. sky/jobs/dashboard/templates/index.html +0 -831
  493. sky/jobs/server/dashboard_utils.py +0 -69
  494. sky/skylet/providers/scp/__init__.py +0 -2
  495. sky/skylet/providers/scp/config.py +0 -149
  496. sky/skylet/providers/scp/node_provider.py +0 -578
  497. sky/templates/kubernetes-ssh-jump.yml.j2 +0 -94
  498. sky/utils/db_utils.py +0 -100
  499. sky/utils/kubernetes/deploy_remote_cluster.sh +0 -308
  500. sky/utils/kubernetes/ssh_jump_lifecycle_manager.py +0 -191
  501. skypilot_nightly-1.0.0.dev20250509.dist-info/METADATA +0 -361
  502. skypilot_nightly-1.0.0.dev20250509.dist-info/RECORD +0 -396
  503. /sky/{clouds/service_catalog → catalog}/config.py +0 -0
  504. /sky/{benchmark → catalog/data_fetchers}/__init__.py +0 -0
  505. /sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_azure.py +0 -0
  506. /sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_fluidstack.py +0 -0
  507. /sky/{clouds/service_catalog → catalog}/data_fetchers/fetch_ibm.py +0 -0
  508. /sky/{clouds/service_catalog/data_fetchers → client/cli}/__init__.py +0 -0
  509. /sky/dashboard/out/_next/static/{LksQgChY5izXjokL3LcEu → zB0ed6ge_W1MDszVHhijS}/_ssgManifest.js +0 -0
  510. {skypilot_nightly-1.0.0.dev20250509.dist-info → skypilot_nightly-1.0.0.dev20251107.dist-info}/entry_points.txt +0 -0
  511. {skypilot_nightly-1.0.0.dev20250509.dist-info → skypilot_nightly-1.0.0.dev20251107.dist-info}/licenses/LICENSE +0 -0
  512. {skypilot_nightly-1.0.0.dev20250509.dist-info → skypilot_nightly-1.0.0.dev20251107.dist-info}/top_level.txt +0 -0
sky/client/sdk.py CHANGED
@@ -10,75 +10,164 @@ Usage example:
10
10
  statuses = sky.get(request_id)
11
11
 
12
12
  """
13
- import getpass
13
+ from http import cookiejar
14
14
  import json
15
15
  import logging
16
16
  import os
17
- import pathlib
18
17
  import subprocess
19
18
  import typing
20
- from typing import Any, Dict, List, Optional, Tuple, Union
21
- import webbrowser
19
+ from typing import (Any, Dict, Iterator, List, Literal, Optional, Tuple,
20
+ TypeVar, Union)
21
+ from urllib import parse as urlparse
22
22
 
23
23
  import click
24
24
  import colorama
25
25
  import filelock
26
26
 
27
27
  from sky import admin_policy
28
- from sky import backends
29
28
  from sky import exceptions
30
29
  from sky import sky_logging
31
30
  from sky import skypilot_config
32
31
  from sky.adaptors import common as adaptors_common
33
32
  from sky.client import common as client_common
33
+ from sky.client import oauth as oauth_lib
34
+ from sky.jobs import scheduler
35
+ from sky.schemas.api import responses
34
36
  from sky.server import common as server_common
37
+ from sky.server import rest
38
+ from sky.server import versions
35
39
  from sky.server.requests import payloads
40
+ from sky.server.requests import request_names
36
41
  from sky.server.requests import requests as requests_lib
42
+ from sky.skylet import autostop_lib
37
43
  from sky.skylet import constants
38
44
  from sky.usage import usage_lib
45
+ from sky.utils import admin_policy_utils
39
46
  from sky.utils import annotations
40
47
  from sky.utils import cluster_utils
41
48
  from sky.utils import common
42
49
  from sky.utils import common_utils
50
+ from sky.utils import context as sky_context
43
51
  from sky.utils import dag_utils
44
52
  from sky.utils import env_options
53
+ from sky.utils import infra_utils
45
54
  from sky.utils import rich_utils
46
55
  from sky.utils import status_lib
47
56
  from sky.utils import subprocess_utils
48
57
  from sky.utils import ux_utils
58
+ from sky.utils import yaml_utils
59
+ from sky.utils.kubernetes import ssh_utils
49
60
 
50
61
  if typing.TYPE_CHECKING:
62
+ import base64
63
+ import binascii
51
64
  import io
65
+ import pathlib
66
+ import time
67
+ import webbrowser
52
68
 
53
69
  import psutil
54
70
  import requests
55
71
 
56
72
  import sky
73
+ from sky import backends
74
+ from sky import catalog
75
+ from sky import models
76
+ from sky.provision.kubernetes import utils as kubernetes_utils
77
+ from sky.skylet import job_lib
57
78
  else:
79
+ # only used in api_login()
80
+ base64 = adaptors_common.LazyImport('base64')
81
+ binascii = adaptors_common.LazyImport('binascii')
82
+ pathlib = adaptors_common.LazyImport('pathlib')
83
+ time = adaptors_common.LazyImport('time')
84
+ # only used in dashboard() and api_login()
85
+ webbrowser = adaptors_common.LazyImport('webbrowser')
86
+ # only used in api_stop()
58
87
  psutil = adaptors_common.LazyImport('psutil')
59
- requests = adaptors_common.LazyImport('requests')
60
88
 
61
89
  logger = sky_logging.init_logger(__name__)
62
90
  logging.getLogger('httpx').setLevel(logging.CRITICAL)
63
91
 
92
+ _LINE_PROCESSED_KEY = 'line_processed'
64
93
 
65
- def stream_response(request_id: Optional[str],
94
+ T = TypeVar('T')
95
+
96
+
97
+ def reload_config() -> None:
98
+ """Reloads the client-side config."""
99
+ skypilot_config.safe_reload_config()
100
+
101
+
102
+ # The overloads are not comprehensive - e.g. get_result Literal[False] could be
103
+ # specified to return None. We can add more overloads if needed. To do that see
104
+ # https://github.com/python/mypy/issues/8634#issuecomment-609411104
105
+ @typing.overload
106
+ def stream_response(request_id: None,
66
107
  response: 'requests.Response',
67
- output_stream: Optional['io.TextIOBase'] = None) -> Any:
108
+ output_stream: Optional['io.TextIOBase'] = None,
109
+ resumable: bool = False,
110
+ get_result: bool = True) -> None:
111
+ ...
112
+
113
+
114
+ @typing.overload
115
+ def stream_response(request_id: server_common.RequestId[T],
116
+ response: 'requests.Response',
117
+ output_stream: Optional['io.TextIOBase'] = None,
118
+ resumable: bool = False,
119
+ get_result: Literal[True] = True) -> T:
120
+ ...
121
+
122
+
123
+ @typing.overload
124
+ def stream_response(request_id: server_common.RequestId[T],
125
+ response: 'requests.Response',
126
+ output_stream: Optional['io.TextIOBase'] = None,
127
+ resumable: bool = False,
128
+ get_result: bool = True) -> Optional[T]:
129
+ ...
130
+
131
+
132
+ def stream_response(request_id: Optional[server_common.RequestId[T]],
133
+ response: 'requests.Response',
134
+ output_stream: Optional['io.TextIOBase'] = None,
135
+ resumable: bool = False,
136
+ get_result: bool = True) -> Optional[T]:
68
137
  """Streams the response to the console.
69
138
 
70
139
  Args:
71
- request_id: The request ID.
140
+ request_id: The request ID of the request to stream. May be a full
141
+ request ID or a prefix.
142
+ If None, the latest request submitted to the API server is streamed.
143
+ Using None request_id is not recommended in multi-user environments.
72
144
  response: The HTTP response.
73
145
  output_stream: The output stream to write to. If None, print to the
74
146
  console.
147
+ resumable: Whether the response is resumable on retry. If True, the
148
+ streaming will start from the previous failure point on retry.
149
+ get_result: Whether to get the result of the request. This will
150
+ typically be set to False for `--no-follow` flags as requests may
151
+ continue to run for long periods of time without further streaming.
75
152
  """
76
153
 
154
+ retry_context: Optional[rest.RetryContext] = None
155
+ if resumable:
156
+ retry_context = rest.get_retry_context()
77
157
  try:
158
+ line_count = 0
78
159
  for line in rich_utils.decode_rich_status(response):
79
160
  if line is not None:
80
- print(line, flush=True, end='', file=output_stream)
81
- return get(request_id)
161
+ line_count += 1
162
+ if retry_context is None:
163
+ print(line, flush=True, end='', file=output_stream)
164
+ elif line_count > retry_context.line_processed:
165
+ print(line, flush=True, end='', file=output_stream)
166
+ retry_context.line_processed = line_count
167
+ if request_id is not None and get_result:
168
+ return get(request_id)
169
+ else:
170
+ return None
82
171
  except Exception: # pylint: disable=broad-except
83
172
  logger.debug(f'To stream request logs: sky api logs {request_id}')
84
173
  raise
@@ -87,13 +176,18 @@ def stream_response(request_id: Optional[str],
87
176
  @usage_lib.entrypoint
88
177
  @server_common.check_server_healthy_or_start
89
178
  @annotations.client_api
90
- def check(clouds: Optional[Tuple[str]],
91
- verbose: bool) -> server_common.RequestId:
179
+ def check(
180
+ infra_list: Optional[Tuple[str, ...]],
181
+ verbose: bool,
182
+ workspace: Optional[str] = None
183
+ ) -> server_common.RequestId[Dict[str, List[str]]]:
92
184
  """Checks the credentials to enable clouds.
93
185
 
94
186
  Args:
95
- clouds: The clouds to check.
187
+ infra: The infra to check.
96
188
  verbose: Whether to show verbose output.
189
+ workspace: The workspace to check. If None, all workspaces will be
190
+ checked.
97
191
 
98
192
  Returns:
99
193
  The request ID of the check request.
@@ -101,41 +195,69 @@ def check(clouds: Optional[Tuple[str]],
101
195
  Request Returns:
102
196
  None
103
197
  """
104
- body = payloads.CheckBody(clouds=clouds, verbose=verbose)
105
- response = requests.post(f'{server_common.get_server_url()}/check',
106
- json=json.loads(body.model_dump_json()),
107
- cookies=server_common.get_api_cookie_jar())
198
+ if infra_list is None:
199
+ clouds = None
200
+ else:
201
+ specified_clouds = []
202
+ for infra_str in infra_list:
203
+ infra = infra_utils.InfraInfo.from_str(infra_str)
204
+ if infra.cloud is None:
205
+ with ux_utils.print_exception_no_traceback():
206
+ raise ValueError(f'Invalid infra to check: {infra_str}')
207
+ if infra.region is not None or infra.zone is not None:
208
+ region_zone = infra_str.partition('/')[-1]
209
+ logger.warning(f'Infra {infra_str} is specified, but `check` '
210
+ f'only supports checking {infra.cloud}, '
211
+ f'ignoring {region_zone}')
212
+ specified_clouds.append(infra.cloud)
213
+ clouds = tuple(specified_clouds)
214
+ body = payloads.CheckBody(clouds=clouds,
215
+ verbose=verbose,
216
+ workspace=workspace)
217
+ response = server_common.make_authenticated_request(
218
+ 'POST', '/check', json=json.loads(body.model_dump_json()))
108
219
  return server_common.get_request_id(response)
109
220
 
110
221
 
111
222
  @usage_lib.entrypoint
112
223
  @server_common.check_server_healthy_or_start
113
224
  @annotations.client_api
114
- def enabled_clouds() -> server_common.RequestId:
225
+ def enabled_clouds(workspace: Optional[str] = None,
226
+ expand: bool = False) -> server_common.RequestId[List[str]]:
115
227
  """Gets the enabled clouds.
116
228
 
229
+ Args:
230
+ workspace: The workspace to get the enabled clouds for. If None, the
231
+ active workspace will be used.
232
+ expand: Whether to expand Kubernetes and SSH to list of resource pools.
233
+
117
234
  Returns:
118
235
  The request ID of the enabled clouds request.
119
236
 
120
237
  Request Returns:
121
238
  A list of enabled clouds in string format.
122
239
  """
123
- response = requests.get(f'{server_common.get_server_url()}/enabled_clouds',
124
- cookies=server_common.get_api_cookie_jar())
240
+ if workspace is None:
241
+ workspace = skypilot_config.get_active_workspace()
242
+ response = server_common.make_authenticated_request(
243
+ 'GET', f'/enabled_clouds?workspace={workspace}&expand={expand}')
125
244
  return server_common.get_request_id(response)
126
245
 
127
246
 
128
247
  @usage_lib.entrypoint
129
248
  @server_common.check_server_healthy_or_start
130
249
  @annotations.client_api
131
- def list_accelerators(gpus_only: bool = True,
132
- name_filter: Optional[str] = None,
133
- region_filter: Optional[str] = None,
134
- quantity_filter: Optional[int] = None,
135
- clouds: Optional[Union[List[str], str]] = None,
136
- all_regions: bool = False,
137
- require_price: bool = True,
138
- case_sensitive: bool = True) -> server_common.RequestId:
250
+ def list_accelerators(
251
+ gpus_only: bool = True,
252
+ name_filter: Optional[str] = None,
253
+ region_filter: Optional[str] = None,
254
+ quantity_filter: Optional[int] = None,
255
+ clouds: Optional[Union[List[str], str]] = None,
256
+ all_regions: bool = False,
257
+ require_price: bool = True,
258
+ case_sensitive: bool = True
259
+ ) -> server_common.RequestId[Dict[str,
260
+ List['catalog.common.InstanceTypeInfo']]]:
139
261
  """Lists the names of all accelerators offered by Sky.
140
262
 
141
263
  This will include all accelerators offered by Sky, including those
@@ -169,10 +291,8 @@ def list_accelerators(gpus_only: bool = True,
169
291
  require_price=require_price,
170
292
  case_sensitive=case_sensitive,
171
293
  )
172
- response = requests.post(
173
- f'{server_common.get_server_url()}/list_accelerators',
174
- json=json.loads(body.model_dump_json()),
175
- cookies=server_common.get_api_cookie_jar())
294
+ response = server_common.make_authenticated_request(
295
+ 'POST', '/list_accelerators', json=json.loads(body.model_dump_json()))
176
296
  return server_common.get_request_id(response)
177
297
 
178
298
 
@@ -180,12 +300,12 @@ def list_accelerators(gpus_only: bool = True,
180
300
  @server_common.check_server_healthy_or_start
181
301
  @annotations.client_api
182
302
  def list_accelerator_counts(
183
- gpus_only: bool = True,
184
- name_filter: Optional[str] = None,
185
- region_filter: Optional[str] = None,
186
- quantity_filter: Optional[int] = None,
187
- clouds: Optional[Union[List[str],
188
- str]] = None) -> server_common.RequestId:
303
+ gpus_only: bool = True,
304
+ name_filter: Optional[str] = None,
305
+ region_filter: Optional[str] = None,
306
+ quantity_filter: Optional[int] = None,
307
+ clouds: Optional[Union[List[str], str]] = None
308
+ ) -> server_common.RequestId[Dict[str, List[float]]]:
189
309
  """Lists all accelerators offered by Sky and available counts.
190
310
 
191
311
  Args:
@@ -203,17 +323,17 @@ def list_accelerator_counts(
203
323
  accelerator names mapped to a list of available counts. See usage
204
324
  in cli.py.
205
325
  """
206
- body = payloads.ListAcceleratorsBody(
326
+ body = payloads.ListAcceleratorCountsBody(
207
327
  gpus_only=gpus_only,
208
328
  name_filter=name_filter,
209
329
  region_filter=region_filter,
210
330
  quantity_filter=quantity_filter,
211
331
  clouds=clouds,
212
332
  )
213
- response = requests.post(
214
- f'{server_common.get_server_url()}/list_accelerator_counts',
215
- json=json.loads(body.model_dump_json()),
216
- cookies=server_common.get_api_cookie_jar())
333
+ response = server_common.make_authenticated_request(
334
+ 'POST',
335
+ '/list_accelerator_counts',
336
+ json=json.loads(body.model_dump_json()))
217
337
  return server_common.get_request_id(response)
218
338
 
219
339
 
@@ -224,7 +344,7 @@ def optimize(
224
344
  dag: 'sky.Dag',
225
345
  minimize: common.OptimizeTarget = common.OptimizeTarget.COST,
226
346
  admin_policy_request_options: Optional[admin_policy.RequestOptions] = None
227
- ) -> server_common.RequestId:
347
+ ) -> server_common.RequestId['sky.Dag']:
228
348
  """Finds the best execution plan for the given DAG.
229
349
 
230
350
  Args:
@@ -250,9 +370,14 @@ def optimize(
250
370
  body = payloads.OptimizeBody(dag=dag_str,
251
371
  minimize=minimize,
252
372
  request_options=admin_policy_request_options)
253
- response = requests.post(f'{server_common.get_server_url()}/optimize',
254
- json=json.loads(body.model_dump_json()),
255
- cookies=server_common.get_api_cookie_jar())
373
+ response = server_common.make_authenticated_request(
374
+ 'POST', '/optimize', json=json.loads(body.model_dump_json()))
375
+ return server_common.get_request_id(response)
376
+
377
+
378
+ def workspaces() -> server_common.RequestId[Dict[str, Any]]:
379
+ """Gets the workspaces."""
380
+ response = server_common.make_authenticated_request('GET', '/workspaces')
256
381
  return server_common.get_request_id(response)
257
382
 
258
383
 
@@ -279,16 +404,22 @@ def validate(
279
404
  validation. This is only required when a admin policy is in use,
280
405
  see: https://docs.skypilot.co/en/latest/cloud-setup/policy.html
281
406
  """
407
+ remote_api_version = versions.get_remote_api_version()
408
+ # TODO(kevin): remove this in v0.13.0
409
+ omit_user_specified_yaml = (remote_api_version is None or
410
+ remote_api_version < 15)
282
411
  for task in dag.tasks:
412
+ if omit_user_specified_yaml:
413
+ # pylint: disable=protected-access
414
+ task._user_specified_yaml = None
283
415
  task.expand_and_validate_workdir()
284
416
  if not workdir_only:
285
417
  task.expand_and_validate_file_mounts()
286
418
  dag_str = dag_utils.dump_chain_dag_to_yaml_str(dag)
287
419
  body = payloads.ValidateBody(dag=dag_str,
288
420
  request_options=admin_policy_request_options)
289
- response = requests.post(f'{server_common.get_server_url()}/validate',
290
- json=json.loads(body.model_dump_json()),
291
- cookies=server_common.get_api_cookie_jar())
421
+ response = server_common.make_authenticated_request(
422
+ 'POST', '/validate', json=json.loads(body.model_dump_json()))
292
423
  if response.status_code == 400:
293
424
  with ux_utils.print_exception_no_traceback():
294
425
  raise exceptions.deserialize_exception(
@@ -298,10 +429,11 @@ def validate(
298
429
  @usage_lib.entrypoint
299
430
  @server_common.check_server_healthy_or_start
300
431
  @annotations.client_api
301
- def dashboard() -> None:
432
+ def dashboard(starting_page: Optional[str] = None) -> None:
302
433
  """Starts the dashboard for SkyPilot."""
303
434
  api_server_url = server_common.get_server_url()
304
- url = server_common.get_dashboard_url(api_server_url)
435
+ url = server_common.get_dashboard_url(api_server_url,
436
+ starting_page=starting_page)
305
437
  logger.info(f'Opening dashboard in browser: {url}')
306
438
  webbrowser.open(url)
307
439
 
@@ -309,14 +441,16 @@ def dashboard() -> None:
309
441
  @usage_lib.entrypoint
310
442
  @server_common.check_server_healthy_or_start
311
443
  @annotations.client_api
444
+ @sky_context.contextual
312
445
  def launch(
313
446
  task: Union['sky.Task', 'sky.Dag'],
314
447
  cluster_name: Optional[str] = None,
315
448
  retry_until_up: bool = False,
316
449
  idle_minutes_to_autostop: Optional[int] = None,
450
+ wait_for: Optional[autostop_lib.AutostopWaitFor] = None,
317
451
  dryrun: bool = False,
318
452
  down: bool = False, # pylint: disable=redefined-outer-name
319
- backend: Optional[backends.Backend] = None,
453
+ backend: Optional['backends.Backend'] = None,
320
454
  optimize_target: common.OptimizeTarget = common.OptimizeTarget.COST,
321
455
  no_setup: bool = False,
322
456
  clone_disk_from: Optional[str] = None,
@@ -327,7 +461,8 @@ def launch(
327
461
  _is_launched_by_jobs_controller: bool = False,
328
462
  _is_launched_by_sky_serve_controller: bool = False,
329
463
  _disable_controller_check: bool = False,
330
- ) -> server_common.RequestId:
464
+ ) -> server_common.RequestId[Tuple[Optional[int],
465
+ Optional['backends.ResourceHandle']]]:
331
466
  """Launches a cluster or task.
332
467
 
333
468
  The task's setup and run commands are executed under the task's workdir
@@ -344,7 +479,7 @@ def launch(
344
479
  import sky
345
480
  task = sky.Task(run='echo hello SkyPilot')
346
481
  task.set_resources(
347
- sky.Resources(cloud=sky.AWS(), accelerators='V100:4'))
482
+ sky.Resources(infra='aws', accelerators='V100:4'))
348
483
  sky.launch(task, cluster_name='my-cluster')
349
484
 
350
485
 
@@ -355,18 +490,31 @@ def launch(
355
490
  retry_until_up: whether to retry launching the cluster until it is
356
491
  up.
357
492
  idle_minutes_to_autostop: automatically stop the cluster after this
358
- many minute of idleness, i.e., no running or pending jobs in the
359
- cluster's job queue. Idleness gets reset whenever setting-up/
360
- running/pending jobs are found in the job queue. Setting this
361
- flag is equivalent to running ``sky.launch()`` and then
362
- ``sky.autostop(idle_minutes=<minutes>)``. If not set, the cluster
363
- will not be autostopped.
493
+ many minute of idleness, i.e., no running or pending jobs in the
494
+ cluster's job queue. Idleness gets reset whenever setting-up/
495
+ running/pending jobs are found in the job queue. Setting this
496
+ flag is equivalent to running
497
+ ``sky.launch(...)`` and then
498
+ ``sky.autostop(idle_minutes=<minutes>)``. If set, the autostop
499
+ config specified in the task' resources will be overridden by
500
+ this parameter.
501
+ wait_for: determines the condition for resetting the idleness timer.
502
+ This option works in conjunction with ``idle_minutes_to_autostop``.
503
+ Choices:
504
+
505
+ 1. "jobs_and_ssh" (default) - Wait for in-progress jobs and SSH
506
+ connections to finish.
507
+ 2. "jobs" - Only wait for in-progress jobs.
508
+ 3. "none" - Wait for nothing; autostop right after
509
+ ``idle_minutes_to_autostop``.
364
510
  dryrun: if True, do not actually launch the cluster.
365
511
  down: Tear down the cluster after all jobs finish (successfully or
366
- abnormally). If --idle-minutes-to-autostop is also set, the
367
- cluster will be torn down after the specified idle time.
368
- Note that if errors occur during provisioning/data syncing/setting
369
- up, the cluster will not be torn down for debugging purposes.
512
+ abnormally). If --idle-minutes-to-autostop is also set, the
513
+ cluster will be torn down after the specified idle time.
514
+ Note that if errors occur during provisioning/data syncing/setting
515
+ up, the cluster will not be torn down for debugging purposes. If
516
+ set, the autostop config specified in the task' resources will be
517
+ overridden by this parameter.
370
518
  backend: backend to use. If None, use the default backend
371
519
  (CloudVMRayBackend).
372
520
  optimize_target: target to optimize for. Choices: OptimizeTarget.COST,
@@ -422,35 +570,115 @@ def launch(
422
570
  raise NotImplementedError('clone_disk_from is not implemented yet. '
423
571
  'Please contact the SkyPilot team if you '
424
572
  'need this feature at slack.skypilot.co.')
573
+
574
+ remote_api_version = versions.get_remote_api_version()
575
+ if wait_for is not None and (remote_api_version is None or
576
+ remote_api_version < 13):
577
+ logger.warning('wait_for is not supported in your API server. '
578
+ 'Please upgrade to a newer API server to use it.')
579
+
425
580
  dag = dag_utils.convert_entrypoint_to_dag(task)
581
+ # Override the autostop config from command line flags to task YAML.
582
+ for task in dag.tasks:
583
+ for resource in task.resources:
584
+ if remote_api_version is None or remote_api_version < 13:
585
+ # An older server would not recognize the wait_for field
586
+ # in the schema, so we need to omit it.
587
+ resource.override_autostop_config(
588
+ down=down, idle_minutes=idle_minutes_to_autostop)
589
+ else:
590
+ resource.override_autostop_config(
591
+ down=down,
592
+ idle_minutes=idle_minutes_to_autostop,
593
+ wait_for=wait_for)
594
+ if resource.autostop_config is not None:
595
+ # For backward-compatibility, get the final autostop config for
596
+ # admin policy.
597
+ # TODO(aylei): remove this after 0.12.0
598
+ down = resource.autostop_config.down
599
+ idle_minutes_to_autostop = resource.autostop_config.idle_minutes
600
+
426
601
  request_options = admin_policy.RequestOptions(
427
602
  cluster_name=cluster_name,
428
603
  idle_minutes_to_autostop=idle_minutes_to_autostop,
429
604
  down=down,
430
605
  dryrun=dryrun)
606
+ with admin_policy_utils.apply_and_use_config_in_current_request(
607
+ dag,
608
+ request_name=request_names.AdminPolicyRequestName.CLUSTER_LAUNCH,
609
+ request_options=request_options,
610
+ at_client_side=True) as dag:
611
+ return _launch(
612
+ dag,
613
+ cluster_name,
614
+ request_options,
615
+ retry_until_up,
616
+ idle_minutes_to_autostop,
617
+ dryrun,
618
+ down,
619
+ backend,
620
+ optimize_target,
621
+ no_setup,
622
+ clone_disk_from,
623
+ fast,
624
+ _need_confirmation,
625
+ _is_launched_by_jobs_controller,
626
+ _is_launched_by_sky_serve_controller,
627
+ _disable_controller_check,
628
+ )
629
+
630
+
631
+ def _launch(
632
+ dag: 'sky.Dag',
633
+ cluster_name: str,
634
+ request_options: admin_policy.RequestOptions,
635
+ retry_until_up: bool = False,
636
+ idle_minutes_to_autostop: Optional[int] = None,
637
+ dryrun: bool = False,
638
+ down: bool = False, # pylint: disable=redefined-outer-name
639
+ backend: Optional['backends.Backend'] = None,
640
+ optimize_target: common.OptimizeTarget = common.OptimizeTarget.COST,
641
+ no_setup: bool = False,
642
+ clone_disk_from: Optional[str] = None,
643
+ fast: bool = False,
644
+ # Internal only:
645
+ # pylint: disable=invalid-name
646
+ _need_confirmation: bool = False,
647
+ _is_launched_by_jobs_controller: bool = False,
648
+ _is_launched_by_sky_serve_controller: bool = False,
649
+ _disable_controller_check: bool = False,
650
+ ) -> server_common.RequestId[Tuple[Optional[int],
651
+ Optional['backends.ResourceHandle']]]:
652
+ """Auxiliary function for launch(), refer to launch() for details."""
653
+
431
654
  validate(dag, admin_policy_request_options=request_options)
655
+ # The flags have been applied to the task YAML and the backward
656
+ # compatibility of admin policy has been handled. We should no longer use
657
+ # these flags.
658
+ del down, idle_minutes_to_autostop
432
659
 
433
660
  confirm_shown = False
434
661
  if _need_confirmation:
435
662
  cluster_status = None
436
663
  # TODO(SKY-998): we should reduce RTTs before launching the cluster.
437
- request_id = status([cluster_name], all_users=True)
438
- clusters = get(request_id)
664
+ status_request_id = status([cluster_name], all_users=True)
665
+ clusters = get(status_request_id)
439
666
  cluster_user_hash = common_utils.get_user_hash()
440
667
  cluster_user_hash_str = ''
441
- cluster_user_name = getpass.getuser()
668
+ current_user = common_utils.get_current_user_name()
669
+ cluster_user_name = current_user
442
670
  if not clusters:
443
671
  # Show the optimize log before the prompt if the cluster does not
444
672
  # exist.
445
- request_id = optimize(dag,
446
- admin_policy_request_options=request_options)
447
- stream_and_get(request_id)
673
+ optimize_request_id = optimize(
674
+ dag, admin_policy_request_options=request_options)
675
+ stream_and_get(optimize_request_id)
448
676
  else:
449
677
  cluster_record = clusters[0]
450
678
  cluster_status = cluster_record['status']
451
679
  cluster_user_hash = cluster_record['user_hash']
452
680
  cluster_user_name = cluster_record['user_name']
453
- if cluster_user_name == getpass.getuser():
681
+ if cluster_user_name == current_user:
454
682
  # Only show the hash if the username is the same as the local
455
683
  # username, to avoid confusion.
456
684
  cluster_user_hash_str = f' (hash: {cluster_user_hash})'
@@ -492,9 +720,7 @@ def launch(
492
720
  task=dag_str,
493
721
  cluster_name=cluster_name,
494
722
  retry_until_up=retry_until_up,
495
- idle_minutes_to_autostop=idle_minutes_to_autostop,
496
723
  dryrun=dryrun,
497
- down=down,
498
724
  backend=backend.NAME if backend else None,
499
725
  optimize_target=optimize_target,
500
726
  no_setup=no_setup,
@@ -507,12 +733,8 @@ def launch(
507
733
  _is_launched_by_sky_serve_controller),
508
734
  disable_controller_check=_disable_controller_check,
509
735
  )
510
- response = requests.post(
511
- f'{server_common.get_server_url()}/launch',
512
- json=json.loads(body.model_dump_json()),
513
- timeout=5,
514
- cookies=server_common.get_api_cookie_jar(),
515
- )
736
+ response = server_common.make_authenticated_request(
737
+ 'POST', '/launch', json=json.loads(body.model_dump_json()), timeout=5)
516
738
  return server_common.get_request_id(response)
517
739
 
518
740
 
@@ -524,8 +746,9 @@ def exec( # pylint: disable=redefined-builtin
524
746
  cluster_name: Optional[str] = None,
525
747
  dryrun: bool = False,
526
748
  down: bool = False, # pylint: disable=redefined-outer-name
527
- backend: Optional[backends.Backend] = None,
528
- ) -> server_common.RequestId:
749
+ backend: Optional['backends.Backend'] = None,
750
+ ) -> server_common.RequestId[Tuple[Optional[int],
751
+ Optional['backends.ResourceHandle']]]:
529
752
  """Executes a task on an existing cluster.
530
753
 
531
754
  This function performs two actions:
@@ -591,23 +814,49 @@ def exec( # pylint: disable=redefined-builtin
591
814
  backend=backend.NAME if backend else None,
592
815
  )
593
816
 
594
- response = requests.post(
595
- f'{server_common.get_server_url()}/exec',
596
- json=json.loads(body.model_dump_json()),
597
- timeout=5,
598
- cookies=server_common.get_api_cookie_jar(),
599
- )
817
+ response = server_common.make_authenticated_request(
818
+ 'POST', '/exec', json=json.loads(body.model_dump_json()), timeout=5)
600
819
  return server_common.get_request_id(response)
601
820
 
602
821
 
603
- @usage_lib.entrypoint
604
- @server_common.check_server_healthy_or_start
605
- @annotations.client_api
822
+ @typing.overload
823
+ def tail_logs(
824
+ cluster_name: str,
825
+ job_id: Optional[int],
826
+ follow: bool,
827
+ tail: int = 0,
828
+ output_stream: Optional['io.TextIOBase'] = None,
829
+ *, # keyword only separator
830
+ preload_content: Literal[True] = True) -> int:
831
+ ...
832
+
833
+
834
+ @typing.overload
606
835
  def tail_logs(cluster_name: str,
607
836
  job_id: Optional[int],
608
837
  follow: bool,
609
838
  tail: int = 0,
610
- output_stream: Optional['io.TextIOBase'] = None) -> int:
839
+ output_stream: None = None,
840
+ *,
841
+ preload_content: Literal[False]) -> Iterator[Optional[str]]:
842
+ ...
843
+
844
+
845
+ # TODO(aylei): when retry logs request, there will be duplicated log entries.
846
+ # We should fix this.
847
+ @usage_lib.entrypoint
848
+ @server_common.check_server_healthy_or_start
849
+ @annotations.client_api
850
+ @rest.retry_transient_errors()
851
+ def tail_logs(
852
+ cluster_name: str,
853
+ job_id: Optional[int],
854
+ follow: bool,
855
+ tail: int = 0,
856
+ output_stream: Optional['io.TextIOBase'] = None,
857
+ *, # keyword only separator
858
+ preload_content: bool = True
859
+ ) -> Union[int, Iterator[Optional[str]]]:
611
860
  """Tails the logs of a job.
612
861
 
613
862
  Args:
@@ -617,12 +866,21 @@ def tail_logs(cluster_name: str,
617
866
  immediately.
618
867
  tail: if > 0, tail the last N lines of the logs.
619
868
  output_stream: the stream to write the logs to. If None, print to the
620
- console.
869
+ console. Cannot be used with preload_content=False.
870
+ preload_content: if False, returns an Iterator[str | None] containing
871
+ the logs without the function blocking on the retrieval of entire
872
+ log. Iterator returns None when the log has been completely
873
+ streamed. Default True. Cannot be used with output_stream.
621
874
 
622
875
  Returns:
623
- Exit code based on success or failure of the job. 0 if success,
624
- 100 if the job failed. See exceptions.JobExitCode for possible exit
625
- codes.
876
+ If preload_content is True:
877
+ Exit code based on success or failure of the job. 0 if success,
878
+ 100 if the job failed. See exceptions.JobExitCode for possible exit
879
+ codes.
880
+ If preload_content is False:
881
+ Iterator[str | None] containing the logs without the function
882
+ blocking on the retrieval of entire log. Iterator returns None
883
+ when the log has been completely streamed.
626
884
 
627
885
  Request Raises:
628
886
  ValueError: if arguments are invalid or the cluster is not supported.
@@ -635,21 +893,110 @@ def tail_logs(cluster_name: str,
635
893
  sky.exceptions.CloudUserIdentityError: if we fail to get the current
636
894
  user identity.
637
895
  """
896
+ if output_stream is not None and not preload_content:
897
+ raise ValueError(
898
+ 'output_stream cannot be specified when preload_content is False')
899
+
638
900
  body = payloads.ClusterJobBody(
639
901
  cluster_name=cluster_name,
640
902
  job_id=job_id,
641
903
  follow=follow,
642
904
  tail=tail,
643
905
  )
644
- response = requests.post(
645
- f'{server_common.get_server_url()}/logs',
906
+ response = server_common.make_authenticated_request(
907
+ 'POST',
908
+ '/logs',
646
909
  json=json.loads(body.model_dump_json()),
647
910
  stream=True,
648
911
  timeout=(client_common.API_SERVER_REQUEST_CONNECTION_TIMEOUT_SECONDS,
649
- None),
650
- cookies=server_common.get_api_cookie_jar())
651
- request_id = server_common.get_request_id(response)
652
- return stream_response(request_id, response, output_stream)
912
+ None))
913
+ request_id: server_common.RequestId[int] = server_common.get_request_id(
914
+ response)
915
+ if preload_content:
916
+ # Log request is idempotent when tail is 0, thus can resume previous
917
+ # streaming point on retry.
918
+ return stream_response(request_id=request_id,
919
+ response=response,
920
+ output_stream=output_stream,
921
+ resumable=(tail == 0))
922
+ else:
923
+ return rich_utils.decode_rich_status(response)
924
+
925
+
926
+ @usage_lib.entrypoint
927
+ @server_common.check_server_healthy_or_start
928
+ @versions.minimal_api_version(17)
929
+ @annotations.client_api
930
+ @rest.retry_transient_errors()
931
+ def tail_provision_logs(cluster_name: str,
932
+ worker: Optional[int] = None,
933
+ follow: bool = True,
934
+ tail: int = 0,
935
+ output_stream: Optional['io.TextIOBase'] = None) -> int:
936
+ """Tails the provisioning logs (provision.log) for a cluster.
937
+
938
+ Args:
939
+ cluster_name: name of the cluster.
940
+ worker: worker id in multi-node cluster.
941
+ If None, stream the logs of the head node.
942
+ follow: follow the logs.
943
+ tail: lines from end to tail.
944
+ output_stream: optional stream to write logs.
945
+ Returns:
946
+ Exit code 0 on streaming success; raises on HTTP error.
947
+ """
948
+ body = payloads.ProvisionLogsBody(cluster_name=cluster_name)
949
+
950
+ if worker is not None:
951
+ remote_api_version = versions.get_remote_api_version()
952
+ if remote_api_version is not None and remote_api_version >= 21:
953
+ if worker < 1:
954
+ raise ValueError('Worker must be a positive integer.')
955
+ body.worker = worker
956
+ else:
957
+ raise exceptions.APINotSupportedError(
958
+ 'Worker node provision logs are not supported in your API '
959
+ 'server. Please upgrade to a newer API server to use it.')
960
+ params = {
961
+ 'follow': str(follow).lower(),
962
+ 'tail': tail,
963
+ }
964
+
965
+ response = server_common.make_authenticated_request(
966
+ 'POST',
967
+ '/provision_logs',
968
+ json=json.loads(body.model_dump_json()),
969
+ params=params,
970
+ stream=True,
971
+ timeout=(client_common.API_SERVER_REQUEST_CONNECTION_TIMEOUT_SECONDS,
972
+ None))
973
+ # Check for HTTP errors before streaming the response
974
+ if response.status_code != 200:
975
+ with ux_utils.print_exception_no_traceback():
976
+ raise exceptions.CommandError(response.status_code,
977
+ 'tail_provision_logs',
978
+ 'Failed to stream provision logs',
979
+ response.text)
980
+
981
+ # Log request is idempotent when tail is 0, thus can resume previous
982
+ # streaming point on retry.
983
+ # request_id=None here because /provision_logs does not create an async
984
+ # request. Instead, it streams a plain file from the server. This does NOT
985
+ # violate the stream_response doc warning about None in multi-user
986
+ # environments: we are not asking stream_response to select "the latest
987
+ # request". We already have the HTTP response to stream; request_id=None
988
+ # merely disables the follow-up GET. It is also necessary for --no-follow
989
+ # to return cleanly after printing the tailed lines. If we provided a
990
+ # non-None request_id here, the get(request_id) in stream_response(
991
+ # would fail since /provision_logs does not create a request record.
992
+ # By virtue of this, we set get_result to False to block get() from
993
+ # running.
994
+ stream_response(request_id=None,
995
+ response=response,
996
+ output_stream=output_stream,
997
+ resumable=(tail == 0),
998
+ get_result=False)
999
+ return 0
653
1000
 
654
1001
 
655
1002
  @usage_lib.entrypoint
@@ -683,11 +1030,11 @@ def download_logs(cluster_name: str,
683
1030
  cluster_name=cluster_name,
684
1031
  job_ids=job_ids,
685
1032
  )
686
- response = requests.post(f'{server_common.get_server_url()}/download_logs',
687
- json=json.loads(body.model_dump_json()),
688
- cookies=server_common.get_api_cookie_jar())
689
- job_id_remote_path_dict = stream_and_get(
690
- server_common.get_request_id(response))
1033
+ response = server_common.make_authenticated_request(
1034
+ 'POST', '/download_logs', json=json.loads(body.model_dump_json()))
1035
+ request_id: server_common.RequestId[Dict[
1036
+ str, str]] = server_common.get_request_id(response)
1037
+ job_id_remote_path_dict = stream_and_get(request_id)
691
1038
  remote2local_path_dict = client_common.download_logs_from_api_server(
692
1039
  job_id_remote_path_dict.values())
693
1040
  return {
@@ -702,10 +1049,11 @@ def download_logs(cluster_name: str,
702
1049
  def start(
703
1050
  cluster_name: str,
704
1051
  idle_minutes_to_autostop: Optional[int] = None,
1052
+ wait_for: Optional[autostop_lib.AutostopWaitFor] = None,
705
1053
  retry_until_up: bool = False,
706
1054
  down: bool = False, # pylint: disable=redefined-outer-name
707
1055
  force: bool = False,
708
- ) -> server_common.RequestId:
1056
+ ) -> server_common.RequestId['backends.CloudVmRayResourceHandle']:
709
1057
  """Restart a cluster.
710
1058
 
711
1059
  If a cluster is previously stopped (status is STOPPED) or failed in
@@ -728,6 +1076,15 @@ def start(
728
1076
  flag is equivalent to running ``sky.launch()`` and then
729
1077
  ``sky.autostop(idle_minutes=<minutes>)``. If not set, the
730
1078
  cluster will not be autostopped.
1079
+ wait_for: determines the condition for resetting the idleness timer.
1080
+ This option works in conjunction with ``idle_minutes_to_autostop``.
1081
+ Choices:
1082
+
1083
+ 1. "jobs_and_ssh" (default) - Wait for in-progress jobs and SSH
1084
+ connections to finish.
1085
+ 2. "jobs" - Only wait for in-progress jobs.
1086
+ 3. "none" - Wait for nothing; autostop right after
1087
+ ``idle_minutes_to_autostop``.
731
1088
  retry_until_up: whether to retry launching the cluster until it is
732
1089
  up.
733
1090
  down: Autodown the cluster: tear down the cluster after specified
@@ -756,26 +1113,30 @@ def start(
756
1113
  sky.exceptions.ClusterOwnerIdentitiesMismatchError: if the cluster to
757
1114
  restart was launched by a different user.
758
1115
  """
1116
+ remote_api_version = versions.get_remote_api_version()
1117
+ if wait_for is not None and (remote_api_version is None or
1118
+ remote_api_version < 13):
1119
+ logger.warning('wait_for is not supported in your API server. '
1120
+ 'Please upgrade to a newer API server to use it.')
1121
+
759
1122
  body = payloads.StartBody(
760
1123
  cluster_name=cluster_name,
761
1124
  idle_minutes_to_autostop=idle_minutes_to_autostop,
1125
+ wait_for=wait_for,
762
1126
  retry_until_up=retry_until_up,
763
1127
  down=down,
764
1128
  force=force,
765
1129
  )
766
- response = requests.post(
767
- f'{server_common.get_server_url()}/start',
768
- json=json.loads(body.model_dump_json()),
769
- timeout=5,
770
- cookies=server_common.get_api_cookie_jar(),
771
- )
1130
+ response = server_common.make_authenticated_request(
1131
+ 'POST', '/start', json=json.loads(body.model_dump_json()), timeout=5)
772
1132
  return server_common.get_request_id(response)
773
1133
 
774
1134
 
775
1135
  @usage_lib.entrypoint
776
1136
  @server_common.check_server_healthy_or_start
777
1137
  @annotations.client_api
778
- def down(cluster_name: str, purge: bool = False) -> server_common.RequestId:
1138
+ def down(cluster_name: str,
1139
+ purge: bool = False) -> server_common.RequestId[None]:
779
1140
  """Tears down a cluster.
780
1141
 
781
1142
  Tearing down a cluster will delete all associated resources (all billing
@@ -809,19 +1170,16 @@ def down(cluster_name: str, purge: bool = False) -> server_common.RequestId:
809
1170
  cluster_name=cluster_name,
810
1171
  purge=purge,
811
1172
  )
812
- response = requests.post(
813
- f'{server_common.get_server_url()}/down',
814
- json=json.loads(body.model_dump_json()),
815
- timeout=5,
816
- cookies=server_common.get_api_cookie_jar(),
817
- )
1173
+ response = server_common.make_authenticated_request(
1174
+ 'POST', '/down', json=json.loads(body.model_dump_json()), timeout=5)
818
1175
  return server_common.get_request_id(response)
819
1176
 
820
1177
 
821
1178
  @usage_lib.entrypoint
822
1179
  @server_common.check_server_healthy_or_start
823
1180
  @annotations.client_api
824
- def stop(cluster_name: str, purge: bool = False) -> server_common.RequestId:
1181
+ def stop(cluster_name: str,
1182
+ purge: bool = False) -> server_common.RequestId[None]:
825
1183
  """Stops a cluster.
826
1184
 
827
1185
  Data on attached disks is not lost when a cluster is stopped. Billing for
@@ -858,12 +1216,8 @@ def stop(cluster_name: str, purge: bool = False) -> server_common.RequestId:
858
1216
  cluster_name=cluster_name,
859
1217
  purge=purge,
860
1218
  )
861
- response = requests.post(
862
- f'{server_common.get_server_url()}/stop',
863
- json=json.loads(body.model_dump_json()),
864
- timeout=5,
865
- cookies=server_common.get_api_cookie_jar(),
866
- )
1219
+ response = server_common.make_authenticated_request(
1220
+ 'POST', '/stop', json=json.loads(body.model_dump_json()), timeout=5)
867
1221
  return server_common.get_request_id(response)
868
1222
 
869
1223
 
@@ -871,10 +1225,11 @@ def stop(cluster_name: str, purge: bool = False) -> server_common.RequestId:
871
1225
  @server_common.check_server_healthy_or_start
872
1226
  @annotations.client_api
873
1227
  def autostop(
874
- cluster_name: str,
875
- idle_minutes: int,
876
- down: bool = False # pylint: disable=redefined-outer-name
877
- ) -> server_common.RequestId:
1228
+ cluster_name: str,
1229
+ idle_minutes: int,
1230
+ wait_for: Optional[autostop_lib.AutostopWaitFor] = None,
1231
+ down: bool = False, # pylint: disable=redefined-outer-name
1232
+ ) -> server_common.RequestId[None]:
878
1233
  """Schedules an autostop/autodown for a cluster.
879
1234
 
880
1235
  Autostop/autodown will automatically stop or teardown a cluster when it
@@ -904,6 +1259,14 @@ def autostop(
904
1259
  idle_minutes: the number of minutes of idleness (no pending/running
905
1260
  jobs) after which the cluster will be stopped automatically. Setting
906
1261
  to a negative number cancels any autostop/autodown setting.
1262
+ wait_for: determines the condition for resetting the idleness timer.
1263
+ This option works in conjunction with ``idle_minutes``.
1264
+ Choices:
1265
+
1266
+ 1. "jobs_and_ssh" (default) - Wait for in-progress jobs and SSH
1267
+ connections to finish.
1268
+ 2. "jobs" - Only wait for in-progress jobs.
1269
+ 3. "none" - Wait for nothing; autostop right after ``idle_minutes``.
907
1270
  down: if true, use autodown (tear down the cluster; non-restartable),
908
1271
  rather than autostop (restartable).
909
1272
 
@@ -923,26 +1286,31 @@ def autostop(
923
1286
  sky.exceptions.CloudUserIdentityError: if we fail to get the current
924
1287
  user identity.
925
1288
  """
1289
+ remote_api_version = versions.get_remote_api_version()
1290
+ if wait_for is not None and (remote_api_version is None or
1291
+ remote_api_version < 13):
1292
+ logger.warning('wait_for is not supported in your API server. '
1293
+ 'Please upgrade to a newer API server to use it.')
1294
+
926
1295
  body = payloads.AutostopBody(
927
1296
  cluster_name=cluster_name,
928
1297
  idle_minutes=idle_minutes,
1298
+ wait_for=wait_for,
929
1299
  down=down,
930
1300
  )
931
- response = requests.post(
932
- f'{server_common.get_server_url()}/autostop',
933
- json=json.loads(body.model_dump_json()),
934
- timeout=5,
935
- cookies=server_common.get_api_cookie_jar(),
936
- )
1301
+ response = server_common.make_authenticated_request(
1302
+ 'POST', '/autostop', json=json.loads(body.model_dump_json()), timeout=5)
937
1303
  return server_common.get_request_id(response)
938
1304
 
939
1305
 
940
1306
  @usage_lib.entrypoint
941
1307
  @server_common.check_server_healthy_or_start
942
1308
  @annotations.client_api
943
- def queue(cluster_name: str,
944
- skip_finished: bool = False,
945
- all_users: bool = False) -> server_common.RequestId:
1309
+ def queue(
1310
+ cluster_name: str,
1311
+ skip_finished: bool = False,
1312
+ all_users: bool = False
1313
+ ) -> server_common.RequestId[List[responses.ClusterJobRecord]]:
946
1314
  """Gets the job queue of a cluster.
947
1315
 
948
1316
  Args:
@@ -955,8 +1323,8 @@ def queue(cluster_name: str,
955
1323
  The request ID of the queue request.
956
1324
 
957
1325
  Request Returns:
958
- job_records (List[Dict[str, Any]]): A list of dicts for each job in the
959
- queue.
1326
+ job_records (List[responses.ClusterJobRecord]): A list of job records
1327
+ for each job in the queue.
960
1328
 
961
1329
  .. code-block:: python
962
1330
 
@@ -991,17 +1359,19 @@ def queue(cluster_name: str,
991
1359
  skip_finished=skip_finished,
992
1360
  all_users=all_users,
993
1361
  )
994
- response = requests.post(f'{server_common.get_server_url()}/queue',
995
- json=json.loads(body.model_dump_json()),
996
- cookies=server_common.get_api_cookie_jar())
1362
+ response = server_common.make_authenticated_request(
1363
+ 'POST', '/queue', json=json.loads(body.model_dump_json()))
997
1364
  return server_common.get_request_id(response)
998
1365
 
999
1366
 
1000
1367
  @usage_lib.entrypoint
1001
1368
  @server_common.check_server_healthy_or_start
1002
1369
  @annotations.client_api
1003
- def job_status(cluster_name: str,
1004
- job_ids: Optional[List[int]] = None) -> server_common.RequestId:
1370
+ def job_status(
1371
+ cluster_name: str,
1372
+ job_ids: Optional[List[int]] = None
1373
+ ) -> server_common.RequestId[Dict[Optional[int],
1374
+ Optional['job_lib.JobStatus']]]:
1005
1375
  """Gets the status of jobs on a cluster.
1006
1376
 
1007
1377
  Args:
@@ -1033,9 +1403,8 @@ def job_status(cluster_name: str,
1033
1403
  cluster_name=cluster_name,
1034
1404
  job_ids=job_ids,
1035
1405
  )
1036
- response = requests.post(f'{server_common.get_server_url()}/job_status',
1037
- json=json.loads(body.model_dump_json()),
1038
- cookies=server_common.get_api_cookie_jar())
1406
+ response = server_common.make_authenticated_request(
1407
+ 'POST', '/job_status', json=json.loads(body.model_dump_json()))
1039
1408
  return server_common.get_request_id(response)
1040
1409
 
1041
1410
 
@@ -1049,7 +1418,7 @@ def cancel(
1049
1418
  job_ids: Optional[List[int]] = None,
1050
1419
  # pylint: disable=invalid-name
1051
1420
  _try_cancel_if_cluster_is_init: bool = False
1052
- ) -> server_common.RequestId:
1421
+ ) -> server_common.RequestId[None]:
1053
1422
  """Cancels jobs on a cluster.
1054
1423
 
1055
1424
  Args:
@@ -1087,9 +1456,8 @@ def cancel(
1087
1456
  job_ids=job_ids,
1088
1457
  try_cancel_if_cluster_is_init=_try_cancel_if_cluster_is_init,
1089
1458
  )
1090
- response = requests.post(f'{server_common.get_server_url()}/cancel',
1091
- json=json.loads(body.model_dump_json()),
1092
- cookies=server_common.get_api_cookie_jar())
1459
+ response = server_common.make_authenticated_request(
1460
+ 'POST', '/cancel', json=json.loads(body.model_dump_json()))
1093
1461
  return server_common.get_request_id(response)
1094
1462
 
1095
1463
 
@@ -1100,7 +1468,10 @@ def status(
1100
1468
  cluster_names: Optional[List[str]] = None,
1101
1469
  refresh: common.StatusRefreshMode = common.StatusRefreshMode.NONE,
1102
1470
  all_users: bool = False,
1103
- ) -> server_common.RequestId:
1471
+ *,
1472
+ _include_credentials: bool = False,
1473
+ _summary_response: bool = False,
1474
+ ) -> server_common.RequestId[List[responses.StatusResponse]]:
1104
1475
  """Gets cluster statuses.
1105
1476
 
1106
1477
  If cluster_names is given, return those clusters. Otherwise, return all
@@ -1148,6 +1519,8 @@ def status(
1148
1519
  provider(s).
1149
1520
  all_users: whether to include all users' clusters. By default, only
1150
1521
  the current user's clusters are included.
1522
+ _include_credentials: (internal) whether to include cluster ssh
1523
+ credentials in the response (default: False).
1151
1524
 
1152
1525
  Returns:
1153
1526
  The request ID of the status request.
@@ -1182,10 +1555,11 @@ def status(
1182
1555
  cluster_names=cluster_names,
1183
1556
  refresh=refresh,
1184
1557
  all_users=all_users,
1558
+ include_credentials=_include_credentials,
1559
+ summary_response=_summary_response,
1185
1560
  )
1186
- response = requests.post(f'{server_common.get_server_url()}/status',
1187
- json=json.loads(body.model_dump_json()),
1188
- cookies=server_common.get_api_cookie_jar())
1561
+ response = server_common.make_authenticated_request(
1562
+ 'POST', '/status', json=json.loads(body.model_dump_json()))
1189
1563
  return server_common.get_request_id(response)
1190
1564
 
1191
1565
 
@@ -1193,10 +1567,19 @@ def status(
1193
1567
  @server_common.check_server_healthy_or_start
1194
1568
  @annotations.client_api
1195
1569
  def endpoints(
1196
- cluster: str,
1197
- port: Optional[Union[int, str]] = None) -> server_common.RequestId:
1570
+ cluster: str,
1571
+ port: Optional[Union[int, str]] = None
1572
+ ) -> server_common.RequestId[Dict[int, str]]:
1198
1573
  """Gets the endpoint for a given cluster and port number (endpoint).
1199
1574
 
1575
+ Example:
1576
+ .. code-block:: python
1577
+
1578
+ import sky
1579
+ request_id = sky.endpoints('test-cluster')
1580
+ sky.get(request_id)
1581
+
1582
+
1200
1583
  Args:
1201
1584
  cluster: The name of the cluster.
1202
1585
  port: The port number to get the endpoint for. If None, endpoints
@@ -1206,8 +1589,9 @@ def endpoints(
1206
1589
  The request ID of the endpoints request.
1207
1590
 
1208
1591
  Request Returns:
1209
- A dictionary of port numbers to endpoints. If port is None,
1210
- the dictionary will contain all ports:endpoints exposed on the cluster.
1592
+ A dictionary of port numbers to endpoints.
1593
+ If port is None, the dictionary contains all
1594
+ ports:endpoints exposed on the cluster.
1211
1595
 
1212
1596
  Request Raises:
1213
1597
  ValueError: if the cluster is not UP or the endpoint is not exposed.
@@ -1218,16 +1602,17 @@ def endpoints(
1218
1602
  cluster=cluster,
1219
1603
  port=port,
1220
1604
  )
1221
- response = requests.post(f'{server_common.get_server_url()}/endpoints',
1222
- json=json.loads(body.model_dump_json()),
1223
- cookies=server_common.get_api_cookie_jar())
1605
+ response = server_common.make_authenticated_request(
1606
+ 'POST', '/endpoints', json=json.loads(body.model_dump_json()))
1224
1607
  return server_common.get_request_id(response)
1225
1608
 
1226
1609
 
1227
1610
  @usage_lib.entrypoint
1228
1611
  @server_common.check_server_healthy_or_start
1229
1612
  @annotations.client_api
1230
- def cost_report() -> server_common.RequestId: # pylint: disable=redefined-builtin
1613
+ def cost_report(
1614
+ days: Optional[int] = None
1615
+ ) -> server_common.RequestId[List[Dict[str, Any]]]: # pylint: disable=redefined-builtin
1231
1616
  """Gets all cluster cost reports, including those that have been downed.
1232
1617
 
1233
1618
  The estimated cost column indicates price for the cluster based on the type
@@ -1237,6 +1622,10 @@ def cost_report() -> server_common.RequestId: # pylint: disable=redefined-built
1237
1622
  cache of the cluster status, and may not be accurate for the cluster with
1238
1623
  autostop/use_spot set or terminated/stopped on the cloud console.
1239
1624
 
1625
+ Args:
1626
+ days: The number of days to get the cost report for. If not provided,
1627
+ the default is 30 days.
1628
+
1240
1629
  Returns:
1241
1630
  The request ID of the cost report request.
1242
1631
 
@@ -1258,8 +1647,9 @@ def cost_report() -> server_common.RequestId: # pylint: disable=redefined-built
1258
1647
  'total_cost': (float) cost given resources and usage intervals,
1259
1648
  }
1260
1649
  """
1261
- response = requests.get(f'{server_common.get_server_url()}/cost_report',
1262
- cookies=server_common.get_api_cookie_jar())
1650
+ body = payloads.CostReportBody(days=days)
1651
+ response = server_common.make_authenticated_request(
1652
+ 'POST', '/cost_report', json=json.loads(body.model_dump_json()))
1263
1653
  return server_common.get_request_id(response)
1264
1654
 
1265
1655
 
@@ -1267,36 +1657,24 @@ def cost_report() -> server_common.RequestId: # pylint: disable=redefined-built
1267
1657
  @usage_lib.entrypoint
1268
1658
  @server_common.check_server_healthy_or_start
1269
1659
  @annotations.client_api
1270
def storage_ls() -> server_common.RequestId[List[responses.StorageRecord]]:
    """Submits a request that lists the storages.

    Returns:
        The request ID of the storage list request.

    Request Returns:
        storage_records (List[responses.StorageRecord]):
            A list of storage records.
    """
    # Issue the GET and hand the response straight to the request-id helper.
    return server_common.get_request_id(
        server_common.make_authenticated_request('GET', '/storage/ls'))
1294
1672
 
1295
1673
 
1296
1674
  @usage_lib.entrypoint
1297
1675
  @server_common.check_server_healthy_or_start
1298
1676
  @annotations.client_api
1299
- def storage_delete(name: str) -> server_common.RequestId:
1677
+ def storage_delete(name: str) -> server_common.RequestId[None]:
1300
1678
  """Deletes a storage.
1301
1679
 
1302
1680
  Args:
@@ -1312,9 +1690,8 @@ def storage_delete(name: str) -> server_common.RequestId:
1312
1690
  ValueError: If the storage does not exist.
1313
1691
  """
1314
1692
  body = payloads.StorageBody(name=name)
1315
- response = requests.post(f'{server_common.get_server_url()}/storage/delete',
1316
- json=json.loads(body.model_dump_json()),
1317
- cookies=server_common.get_api_cookie_jar())
1693
+ response = server_common.make_authenticated_request(
1694
+ 'POST', '/storage/delete', json=json.loads(body.model_dump_json()))
1318
1695
  return server_common.get_request_id(response)
1319
1696
 
1320
1697
 
@@ -1330,7 +1707,9 @@ def local_up(gpus: bool,
1330
1707
  ssh_key: Optional[str],
1331
1708
  cleanup: bool,
1332
1709
  context_name: Optional[str] = None,
1333
- password: Optional[str] = None) -> server_common.RequestId:
1710
+ password: Optional[str] = None,
1711
+ name: Optional[str] = None,
1712
+ port_start: Optional[int] = None) -> server_common.RequestId[None]:
1334
1713
  """Launches a Kubernetes cluster on local machines.
1335
1714
 
1336
1715
  Returns:
@@ -1341,8 +1720,8 @@ def local_up(gpus: bool,
1341
1720
  # TODO: move this check to server.
1342
1721
  if not server_common.is_api_server_local():
1343
1722
  with ux_utils.print_exception_no_traceback():
1344
- raise ValueError(
1345
- 'sky local up is only supported when running SkyPilot locally.')
1723
+ raise ValueError('`sky local up` is only supported when '
1724
+ 'running SkyPilot locally.')
1346
1725
 
1347
1726
  body = payloads.LocalUpBody(gpus=gpus,
1348
1727
  ips=ips,
@@ -1350,27 +1729,150 @@ def local_up(gpus: bool,
1350
1729
  ssh_key=ssh_key,
1351
1730
  cleanup=cleanup,
1352
1731
  context_name=context_name,
1353
- password=password)
1354
- response = requests.post(f'{server_common.get_server_url()}/local_up',
1355
- json=json.loads(body.model_dump_json()),
1356
- cookies=server_common.get_api_cookie_jar())
1732
+ password=password,
1733
+ name=name,
1734
+ port_start=port_start)
1735
+ response = server_common.make_authenticated_request(
1736
+ 'POST', '/local_up', json=json.loads(body.model_dump_json()))
1357
1737
  return server_common.get_request_id(response)
1358
1738
 
1359
1739
 
1360
1740
  @usage_lib.entrypoint
1361
1741
  @server_common.check_server_healthy_or_start
1362
1742
  @annotations.client_api
1363
def local_down(name: Optional[str] = None) -> server_common.RequestId[None]:
    """Tears down the Kubernetes cluster started by local_up.

    Args:
        name: Forwarded to the server as ``LocalDownBody.name``. Defaults to
            None so that existing callers using the previous zero-argument
            form ``local_down()`` keep working.

    Returns:
        The request ID of the local down request.

    Raises:
        ValueError: if the API server is not running locally.
    """
    # We do not allow local down when the API server is running remotely since
    # it will modify the kubeconfig.
    # TODO: move this check to remote server.
    if not server_common.is_api_server_local():
        with ux_utils.print_exception_no_traceback():
            raise ValueError('`sky local down` is only supported when running '
                             'SkyPilot locally.')

    body = payloads.LocalDownBody(name=name)
    response = server_common.make_authenticated_request(
        'POST', '/local_down', json=json.loads(body.model_dump_json()))
    return server_common.get_request_id(response)
1757
+
1758
+
1759
def _update_remote_ssh_node_pools(file: str,
                                  infra: Optional[str] = None) -> None:
    """Pushes the local SSH node pool configuration to the API server.

    The config file is loaded, narrowed with
    ``ssh_utils.get_cluster_config(config, infra)``, and every resulting pool
    is converted with ``ssh_utils.prepare_hosts_info``. That helper receives
    ``_upload_ssh_key_and_wait`` as its key-upload callback, so local SSH keys
    are presumably uploaded and their paths swapped for the server-side paths
    while the hosts info is prepared.

    Args:
        file: The path to the local SSH node pools config file ("~" is
            expanded).
        infra: The name of the cluster configuration in the local SSH node
            pools config file. Passed through to
            ``ssh_utils.get_cluster_config``.

    Raises:
        ValueError: if the (expanded) config file path does not exist.
    """
    config_path = os.path.expanduser(file)
    if not os.path.exists(config_path):
        with ux_utils.print_exception_no_traceback():
            raise ValueError(
                f'SSH Node Pool config file {config_path} does not exist. '
                'Please check if the file exists and the path is correct.')

    all_targets = ssh_utils.load_ssh_targets(config_path)
    selected_pools = ssh_utils.get_cluster_config(all_targets, infra)

    # One entry per pool: {pool name: {'hosts': prepared hosts info}}.
    payload = {
        pool_name: {
            'hosts': ssh_utils.prepare_hosts_info(
                pool_name,
                pool_settings,
                upload_ssh_key_func=_upload_ssh_key_and_wait)
        } for pool_name, pool_settings in selected_pools.items()
    }
    server_common.make_authenticated_request('POST',
                                             '/ssh_node_pools',
                                             json=payload)
1787
+
1788
+
1789
def _upload_ssh_key_and_wait(key_name: str, key_file_path: str) -> str:
    """Uploads a local SSH key file to the API server.

    The key is sent as a multipart upload to ``/ssh_node_pools/keys`` and the
    call blocks until the server responds.

    Args:
        key_name: The name of the SSH key; used as the uploaded file name and
            as the ``key_name`` form field.
        key_file_path: The path to the local SSH key file ("~" is expanded).

    Returns:
        The ``key_path`` value from the server's JSON response, i.e. the path
        of the key file on the API server.

    Raises:
        ValueError: if the local key file does not exist.
    """
    local_key_path = os.path.expanduser(key_file_path)
    if not os.path.exists(local_key_path):
        with ux_utils.print_exception_no_traceback():
            raise ValueError(f'SSH key file not found: {key_file_path}')

    with open(local_key_path, 'rb') as key_file:
        upload_parts = {
            'key_file': (key_name, key_file, 'application/octet-stream')
        }
        # The request is issued while the file handle is still open so the
        # key content can be read for the upload.
        response = server_common.make_authenticated_request(
            'POST',
            '/ssh_node_pools/keys',
            files=upload_parts,
            data={'key_name': key_name},
            cookies=server_common.get_api_cookie_jar())

    return response.json()['key_path']
1815
+
1816
+
1817
@usage_lib.entrypoint
@server_common.check_server_healthy_or_start
@annotations.client_api
def ssh_up(infra: Optional[str] = None,
           file: Optional[str] = None) -> server_common.RequestId[None]:
    """Deploys SSH Node Pools through the API server.

    Args:
        infra: Name of the SSH node pool (cluster configuration) to deploy.
            When given, the pool-specific deploy endpoint is called; when
            None, the general deploy endpoint is called instead.
        file: Path to a local SSH node pool configuration file. When given,
            its pools are uploaded to the API server before deploying; when
            None, no local file is uploaded.

    Returns:
        request_id: The request ID of the SSH cluster deployment request.
    """
    if file is not None:
        _update_remote_ssh_node_pools(file, infra)

    # Use SSH node pools router endpoint. The body is only sent to the
    # general endpoint; the pool-specific endpoint carries the name in its
    # URL path.
    deploy_body = payloads.SSHUpBody(infra=infra, cleanup=False)
    if infra is None:
        endpoint = '/ssh_node_pools/deploy'
        extra_kwargs = {'json': json.loads(deploy_body.model_dump_json())}
    else:
        endpoint = f'/ssh_node_pools/{infra}/deploy'
        extra_kwargs = {}
    response = server_common.make_authenticated_request(
        'POST', endpoint, **extra_kwargs)
    return server_common.get_request_id(response)
1849
+
1850
+
1851
@usage_lib.entrypoint
@server_common.check_server_healthy_or_start
@annotations.client_api
def ssh_down(infra: Optional[str] = None) -> server_common.RequestId[None]:
    """Tears down a Kubernetes cluster on SSH targets.

    Args:
        infra: Name of the SSH node pool (cluster configuration) to tear
            down. When given, the pool-specific down endpoint is called;
            when None, the general down endpoint is called instead.

    Returns:
        request_id: The request ID of the SSH cluster teardown request.
    """
    # Use SSH node pools router endpoint. The body (with cleanup=True) is
    # only sent to the general endpoint; the pool-specific endpoint carries
    # the name in its URL path.
    teardown_body = payloads.SSHUpBody(infra=infra, cleanup=True)
    if infra is None:
        endpoint = '/ssh_node_pools/down'
        extra_kwargs = {'json': json.loads(teardown_body.model_dump_json())}
    else:
        endpoint = f'/ssh_node_pools/{infra}/down'
        extra_kwargs = {}
    response = server_common.make_authenticated_request(
        'POST', endpoint, **extra_kwargs)
    return server_common.get_request_id(response)
1375
1877
 
1376
1878
 
@@ -1378,9 +1880,12 @@ def local_down() -> server_common.RequestId:
1378
1880
  @server_common.check_server_healthy_or_start
1379
1881
  @annotations.client_api
1380
1882
  def realtime_kubernetes_gpu_availability(
1381
- context: Optional[str] = None,
1382
- name_filter: Optional[str] = None,
1383
- quantity_filter: Optional[int] = None) -> server_common.RequestId:
1883
+ context: Optional[str] = None,
1884
+ name_filter: Optional[str] = None,
1885
+ quantity_filter: Optional[int] = None,
1886
+ is_ssh: Optional[bool] = None
1887
+ ) -> server_common.RequestId[List[Tuple[
1888
+ str, List['models.RealtimeGpuAvailability']]]]:
1384
1889
  """Gets the real-time Kubernetes GPU availability.
1385
1890
 
1386
1891
  Returns:
@@ -1390,12 +1895,12 @@ def realtime_kubernetes_gpu_availability(
1390
1895
  context=context,
1391
1896
  name_filter=name_filter,
1392
1897
  quantity_filter=quantity_filter,
1898
+ is_ssh=is_ssh,
1393
1899
  )
1394
- response = requests.post(
1395
- f'{server_common.get_server_url()}/'
1396
- 'realtime_kubernetes_gpu_availability',
1397
- json=json.loads(body.model_dump_json()),
1398
- cookies=server_common.get_api_cookie_jar())
1900
+ response = server_common.make_authenticated_request(
1901
+ 'POST',
1902
+ '/realtime_kubernetes_gpu_availability',
1903
+ json=json.loads(body.model_dump_json()))
1399
1904
  return server_common.get_request_id(response)
1400
1905
 
1401
1906
 
@@ -1403,7 +1908,8 @@ def realtime_kubernetes_gpu_availability(
1403
1908
  @server_common.check_server_healthy_or_start
1404
1909
  @annotations.client_api
1405
1910
  def kubernetes_node_info(
1406
- context: Optional[str] = None) -> server_common.RequestId:
1911
+ context: Optional[str] = None
1912
+ ) -> server_common.RequestId['models.KubernetesNodesInfo']:
1407
1913
  """Gets the resource information for all the nodes in the cluster.
1408
1914
 
1409
1915
  Currently only GPU resources are supported. The function returns the total
@@ -1424,17 +1930,20 @@ def kubernetes_node_info(
1424
1930
  information.
1425
1931
  """
1426
1932
  body = payloads.KubernetesNodeInfoRequestBody(context=context)
1427
- response = requests.post(
1428
- f'{server_common.get_server_url()}/kubernetes_node_info',
1429
- json=json.loads(body.model_dump_json()),
1430
- cookies=server_common.get_api_cookie_jar())
1933
+ response = server_common.make_authenticated_request(
1934
+ 'POST',
1935
+ '/kubernetes_node_info',
1936
+ json=json.loads(body.model_dump_json()))
1431
1937
  return server_common.get_request_id(response)
1432
1938
 
1433
1939
 
1434
1940
  @usage_lib.entrypoint
1435
1941
  @server_common.check_server_healthy_or_start
1436
1942
  @annotations.client_api
1437
- def status_kubernetes() -> server_common.RequestId:
1943
+ def status_kubernetes() -> server_common.RequestId[
1944
+ Tuple[List['kubernetes_utils.KubernetesSkyPilotClusterInfoPayload'],
1945
+ List['kubernetes_utils.KubernetesSkyPilotClusterInfoPayload'],
1946
+ List[responses.ManagedJobRecord], Optional[str]]]:
1438
1947
  """Gets all SkyPilot clusters and jobs in the Kubernetes cluster.
1439
1948
 
1440
1949
  Managed jobs and services are also included in the clusters returned.
@@ -1455,21 +1964,24 @@ def status_kubernetes() -> server_common.RequestId:
1455
1964
  dictionary job info, see jobs.queue_from_kubernetes_pod for details.
1456
1965
  - context: Kubernetes context used to fetch the cluster information.
1457
1966
  """
1458
- response = requests.get(
1459
- f'{server_common.get_server_url()}/status_kubernetes',
1460
- cookies=server_common.get_api_cookie_jar())
1967
+ response = server_common.make_authenticated_request('GET',
1968
+ '/status_kubernetes')
1461
1969
  return server_common.get_request_id(response)
1462
1970
 
1463
1971
 
1464
1972
  # === API request APIs ===
1465
1973
  @usage_lib.entrypoint
1466
- @server_common.check_server_healthy_or_start
1467
1974
  @annotations.client_api
1468
- def get(request_id: str) -> Any:
1975
+ def get(request_id: server_common.RequestId[T]) -> T:
1469
1976
  """Waits for and gets the result of a request.
1470
1977
 
1978
+ This function will not check the server health since /api/get is typically
1979
+ not the first API call in an SDK session and checking the server health
1980
+ may cause GET /api/get being sent to a restarted API server.
1981
+
1471
1982
  Args:
1472
- request_id: The request ID of the request to get.
1983
+ request_id: The request ID of the request to get. May be a full request
1984
+ ID or a prefix.
1473
1985
 
1474
1986
  Returns:
1475
1987
  The ``Request Returns`` of the specified request. See the documentation
@@ -1480,19 +1992,20 @@ def get(request_id: str) -> Any:
1480
1992
  see ``Request Raises`` in the documentation of the specific requests
1481
1993
  above.
1482
1994
  """
1483
- response = requests.get(
1484
- f'{server_common.get_server_url()}/api/get?request_id={request_id}',
1995
+ response = server_common.make_authenticated_request(
1996
+ 'GET',
1997
+ f'/api/get?request_id={request_id}',
1998
+ retry=False,
1485
1999
  timeout=(client_common.API_SERVER_REQUEST_CONNECTION_TIMEOUT_SECONDS,
1486
- None),
1487
- cookies=server_common.get_api_cookie_jar())
2000
+ None))
1488
2001
  request_task = None
1489
2002
  if response.status_code == 200:
1490
2003
  request_task = requests_lib.Request.decode(
1491
- requests_lib.RequestPayload(**response.json()))
2004
+ payloads.RequestPayload(**response.json()))
1492
2005
  elif response.status_code == 500:
1493
2006
  try:
1494
2007
  request_task = requests_lib.Request.decode(
1495
- requests_lib.RequestPayload(**response.json().get('detail')))
2008
+ payloads.RequestPayload(**response.json().get('detail')))
1496
2009
  logger.debug(f'Got request with error: {request_task.name}')
1497
2010
  except Exception: # pylint: disable=broad-except
1498
2011
  request_task = None
@@ -1518,23 +2031,45 @@ def get(request_id: str) -> Any:
1518
2031
  return request_task.get_return_value()
1519
2032
 
1520
2033
 
2034
# Typing-only overload: when a concrete request ID is passed, the declared
# result type is the payload type T carried by that RequestId.
@typing.overload
def stream_and_get(request_id: server_common.RequestId[T],
                   log_path: Optional[str] = None,
                   tail: Optional[int] = None,
                   follow: bool = True,
                   output_stream: Optional['io.TextIOBase'] = None) -> T:
    ...


# Typing-only overload: when request_id is omitted or None, the declared
# result type is None.
@typing.overload
def stream_and_get(request_id: None = None,
                   log_path: Optional[str] = None,
                   tail: Optional[int] = None,
                   follow: bool = True,
                   output_stream: Optional['io.TextIOBase'] = None) -> None:
    ...
2050
+
2051
+
1521
2052
  @usage_lib.entrypoint
1522
2053
  @server_common.check_server_healthy_or_start
1523
2054
  @annotations.client_api
2055
+ @rest.retry_transient_errors()
1524
2056
  def stream_and_get(
1525
- request_id: Optional[str] = None,
2057
+ request_id: Optional[server_common.RequestId[T]] = None,
1526
2058
  log_path: Optional[str] = None,
1527
2059
  tail: Optional[int] = None,
1528
2060
  follow: bool = True,
1529
2061
  output_stream: Optional['io.TextIOBase'] = None,
1530
- ) -> Any:
2062
+ ) -> Optional[T]:
1531
2063
  """Streams the logs of a request or a log file and gets the final result.
1532
2064
 
1533
2065
  This will block until the request is finished. The request id can be a
1534
2066
  prefix of the full request id.
1535
2067
 
1536
2068
  Args:
1537
- request_id: The prefix of the request ID of the request to stream.
2069
+ request_id: The request ID of the request to stream. May be a full
2070
+ request ID or a prefix.
2071
+ If None, the latest request submitted to the API server is streamed.
2072
+ Using None request_id is not recommended in multi-user environments.
1538
2073
  log_path: The path to the log file to stream.
1539
2074
  tail: The number of lines to show from the end of the logs.
1540
2075
  If None, show all logs.
@@ -1545,6 +2080,8 @@ def stream_and_get(
1545
2080
  Returns:
1546
2081
  The ``Request Returns`` of the specified request. See the documentation
1547
2082
  of the specific requests above for more details.
2083
+ If follow is False, will always return None. See note on
2084
+ stream_response.
1548
2085
 
1549
2086
  Raises:
1550
2087
  Exception: It raises the same exceptions as the specific requests,
@@ -1558,27 +2095,44 @@ def stream_and_get(
1558
2095
  'follow': follow,
1559
2096
  'format': 'console',
1560
2097
  }
1561
- response = requests.get(
1562
- f'{server_common.get_server_url()}/api/stream',
2098
+ response = server_common.make_authenticated_request(
2099
+ 'GET',
2100
+ '/api/stream',
1563
2101
  params=params,
2102
+ retry=False,
1564
2103
  timeout=(client_common.API_SERVER_REQUEST_CONNECTION_TIMEOUT_SECONDS,
1565
2104
  None),
1566
- stream=True,
1567
- cookies=server_common.get_api_cookie_jar())
2105
+ stream=True)
1568
2106
  if response.status_code in [404, 400]:
1569
2107
  detail = response.json().get('detail')
1570
2108
  with ux_utils.print_exception_no_traceback():
1571
- raise RuntimeError(f'Failed to stream logs: {detail}')
2109
+ raise exceptions.ClientError(f'Failed to stream logs: {detail}')
2110
+ stream_request_id: Optional[server_common.RequestId[
2111
+ T]] = server_common.get_stream_request_id(response)
2112
+ if request_id is not None and stream_request_id is not None:
2113
+ assert request_id == stream_request_id
2114
+ if request_id is None:
2115
+ request_id = stream_request_id
1572
2116
  elif response.status_code != 200:
2117
+ # TODO(syang): handle the case where the requestID is not provided
2118
+ # see https://github.com/skypilot-org/skypilot/issues/6549
2119
+ if request_id is None:
2120
+ return None
1573
2121
  return get(request_id)
1574
- return stream_response(request_id, response, output_stream)
2122
+ return stream_response(request_id,
2123
+ response,
2124
+ output_stream,
2125
+ resumable=True,
2126
+ get_result=follow)
1575
2127
 
1576
2128
 
1577
2129
  @usage_lib.entrypoint
1578
2130
  @annotations.client_api
1579
- def api_cancel(request_ids: Optional[Union[str, List[str]]] = None,
2131
+ def api_cancel(request_ids: Optional[Union[server_common.RequestId[T],
2132
+ List[server_common.RequestId[T]],
2133
+ str, List[str]]] = None,
1580
2134
  all_users: bool = False,
1581
- silent: bool = False) -> server_common.RequestId:
2135
+ silent: bool = False) -> server_common.RequestId[List[str]]:
1582
2136
  """Aborts a request or all requests.
1583
2137
 
1584
2138
  Args:
@@ -1618,20 +2172,35 @@ def api_cancel(request_ids: Optional[Union[str, List[str]]] = None,
1618
2172
  echo(f'Cancelling {len(request_ids)} request{plural}: '
1619
2173
  f'{request_id_str}...')
1620
2174
 
1621
- response = requests.post(f'{server_common.get_server_url()}/api/cancel',
1622
- json=json.loads(body.model_dump_json()),
1623
- timeout=5,
1624
- cookies=server_common.get_api_cookie_jar())
2175
+ response = server_common.make_authenticated_request(
2176
+ 'POST',
2177
+ '/api/cancel',
2178
+ json=json.loads(body.model_dump_json()),
2179
+ timeout=5)
1625
2180
  return server_common.get_request_id(response)
1626
2181
 
1627
2182
 
2183
def _local_api_server_running(kill: bool = False) -> bool:
    """Checks if the local api server is running.

    A process is treated as the API server when its joined command line
    contains ``server_common.API_SERVER_CMD``.

    Args:
        kill: If True, every matching process is handed to
            ``subprocess_utils.kill_children_processes`` with ``force=True``
            instead of only being reported.

    Returns:
        True if at least one matching process was found, False otherwise.
    """
    found = False
    for process in psutil.process_iter(attrs=['pid', 'cmdline']):
        cmdline = process.info['cmdline']
        if not cmdline:
            continue
        if server_common.API_SERVER_CMD not in ' '.join(cmdline):
            continue
        if not kill:
            # Pure liveness check: the first match answers the question.
            return True
        # Keep scanning after a kill so that every matching server process
        # is stopped, not only the first one the iterator happens to yield.
        subprocess_utils.kill_children_processes(parent_pids=[process.pid],
                                                 force=True)
        found = True
    return found
2193
+
2194
+
1628
2195
  @usage_lib.entrypoint
1629
2196
  @annotations.client_api
1630
2197
  def api_status(
1631
- request_ids: Optional[List[str]] = None,
2198
+ request_ids: Optional[List[Union[server_common.RequestId[T], str]]] = None,
1632
2199
  # pylint: disable=redefined-builtin
1633
- all_status: bool = False
1634
- ) -> List[requests_lib.RequestPayload]:
2200
+ all_status: bool = False,
2201
+ limit: Optional[int] = None,
2202
+ fields: Optional[List[str]] = None,
2203
+ ) -> List[payloads.RequestPayload]:
1635
2204
  """Lists all requests.
1636
2205
 
1637
2206
  Args:
@@ -1639,29 +2208,37 @@ def api_status(
1639
2208
  If None, all requests are queried.
1640
2209
  all_status: Whether to list all finished requests as well. This argument
1641
2210
  is ignored if request_ids is not None.
2211
+ limit: The number of requests to show. If None, show all requests.
2212
+ fields: The fields to get. If None, get all fields.
1642
2213
 
1643
2214
  Returns:
1644
2215
  A list of request payloads.
1645
2216
  """
1646
- body = payloads.RequestStatusBody(request_ids=request_ids,
1647
- all_status=all_status)
1648
- response = requests.get(
1649
- f'{server_common.get_server_url()}/api/status',
2217
+ if server_common.is_api_server_local() and not _local_api_server_running():
2218
+ logger.info('SkyPilot API server is not running.')
2219
+ return []
2220
+
2221
+ body = payloads.RequestStatusBody(
2222
+ request_ids=request_ids,
2223
+ all_status=all_status,
2224
+ limit=limit,
2225
+ fields=fields,
2226
+ )
2227
+ response = server_common.make_authenticated_request(
2228
+ 'GET',
2229
+ '/api/status',
1650
2230
  params=server_common.request_body_to_params(body),
1651
2231
  timeout=(client_common.API_SERVER_REQUEST_CONNECTION_TIMEOUT_SECONDS,
1652
- None),
1653
- cookies=server_common.get_api_cookie_jar())
2232
+ None))
1654
2233
  server_common.handle_request_error(response)
1655
- return [
1656
- requests_lib.RequestPayload(**request) for request in response.json()
1657
- ]
2234
+ return [payloads.RequestPayload(**request) for request in response.json()]
1658
2235
 
1659
2236
 
1660
2237
  # === API server management APIs ===
1661
2238
  @usage_lib.entrypoint
1662
2239
  @server_common.check_server_healthy_or_start
1663
2240
  @annotations.client_api
1664
- def api_info() -> Dict[str, str]:
2241
+ def api_info() -> responses.APIHealthResponse:
1665
2242
  """Gets the server's status, commit and version.
1666
2243
 
1667
2244
  Returns:
@@ -1674,13 +2251,19 @@ def api_info() -> Dict[str, str]:
1674
2251
  'api_version': '1',
1675
2252
  'commit': 'abc1234567890',
1676
2253
  'version': '1.0.0',
2254
+ 'version_on_disk': '1.0.0',
2255
+ 'user': {
2256
+ 'name': 'test@example.com',
2257
+ 'id': '12345abcd',
2258
+ },
1677
2259
  }
1678
2260
 
2261
+ Note that user may be None if we are not using an auth proxy.
2262
+
1679
2263
  """
1680
- response = requests.get(f'{server_common.get_server_url()}/api/health',
1681
- cookies=server_common.get_api_cookie_jar())
2264
+ response = server_common.make_authenticated_request('GET', '/api/health')
1682
2265
  response.raise_for_status()
1683
- return response.json()
2266
+ return responses.APIHealthResponse(**response.json())
1684
2267
 
1685
2268
 
1686
2269
  @usage_lib.entrypoint
@@ -1690,6 +2273,9 @@ def api_start(
1690
2273
  deploy: bool = False,
1691
2274
  host: str = '127.0.0.1',
1692
2275
  foreground: bool = False,
2276
+ metrics: bool = False,
2277
+ metrics_port: Optional[int] = None,
2278
+ enable_basic_auth: bool = False,
1693
2279
  ) -> None:
1694
2280
  """Starts the API server.
1695
2281
 
@@ -1703,6 +2289,10 @@ def api_start(
1703
2289
  if deploy is True, to allow remote access.
1704
2290
  foreground: Whether to run the API server in the foreground (run in
1705
2291
  the current process).
2292
+ metrics: Whether to export metrics of the API server.
2293
+ metrics_port: The port to export metrics of the API server.
2294
+ enable_basic_auth: Whether to enable basic authentication
2295
+ in the API server.
1706
2296
  Returns:
1707
2297
  None
1708
2298
  """
@@ -1721,15 +2311,15 @@ def api_start(
1721
2311
  'from the config file and/or unset the '
1722
2312
  'SKYPILOT_API_SERVER_ENDPOINT environment '
1723
2313
  'variable.')
1724
- server_common.check_server_healthy_or_start_fn(deploy, host, foreground)
2314
+ server_common.check_server_healthy_or_start_fn(deploy, host, foreground,
2315
+ metrics, metrics_port,
2316
+ enable_basic_auth)
1725
2317
  if foreground:
1726
2318
  # Explain why current process exited
1727
2319
  logger.info('API server is already running:')
1728
2320
  api_server_url = server_common.get_server_url(host)
1729
- dashboard_url = server_common.get_dashboard_url(api_server_url)
1730
- dashboard_msg = f'Dashboard: {dashboard_url}'
1731
- logger.info(f'{ux_utils.INDENT_SYMBOL}SkyPilot API server: '
1732
- f'{api_server_url} {dashboard_msg}\n'
2321
+ logger.info(f'{ux_utils.INDENT_SYMBOL}SkyPilot API server and dashboard: '
2322
+ f'{api_server_url}\n'
1733
2323
  f'{ux_utils.INDENT_LAST_SYMBOL}'
1734
2324
  f'View API server logs at: {constants.API_SERVER_LOGS}')
1735
2325
 
@@ -1752,16 +2342,30 @@ def api_stop() -> None:
1752
2342
  f'Cannot kill the API server at {server_url} because it is not '
1753
2343
  f'the default SkyPilot API server started locally.')
1754
2344
 
1755
- found = False
1756
- for process in psutil.process_iter(attrs=['pid', 'cmdline']):
1757
- cmdline = process.info['cmdline']
1758
- if cmdline and server_common.API_SERVER_CMD in ' '.join(cmdline):
1759
- subprocess_utils.kill_children_processes(parent_pids=[process.pid],
1760
- force=True)
1761
- found = True
1762
-
1763
- # Remove the database for requests.
1764
- server_common.clear_local_api_server_database()
2345
+ # Acquire the api server creation lock to prevent multiple processes from
2346
+ # stopping and starting the API server at the same time.
2347
+ with filelock.FileLock(
2348
+ os.path.expanduser(constants.API_SERVER_CREATION_LOCK_PATH)):
2349
+ try:
2350
+ with open(os.path.expanduser(scheduler.JOB_CONTROLLER_PID_PATH),
2351
+ 'r',
2352
+ encoding='utf-8') as f:
2353
+ pids = f.read().split('\n')[:-1]
2354
+ for pid in pids:
2355
+ if subprocess_utils.is_process_alive(int(pid.strip())):
2356
+ subprocess_utils.kill_children_processes(
2357
+ parent_pids=[int(pid.strip())], force=True)
2358
+ os.remove(os.path.expanduser(scheduler.JOB_CONTROLLER_PID_PATH))
2359
+ except FileNotFoundError:
2360
+ # its fine we will create it
2361
+ pass
2362
+ except Exception as e: # pylint: disable=broad-except
2363
+ # in case we get perm issues or something is messed up, just ignore
2364
+ # it and assume the process is dead
2365
+ logger.error(f'Error looking at job controller pid file: {e}')
2366
+ pass
2367
+
2368
+ found = _local_api_server_running(kill=True)
1765
2369
 
1766
2370
  if found:
1767
2371
  logger.info(f'{colorama.Fore.GREEN}SkyPilot API server stopped.'
@@ -1796,9 +2400,86 @@ def api_server_logs(follow: bool = True, tail: Optional[int] = None) -> None:
1796
2400
  stream_and_get(log_path=constants.API_SERVER_LOGS, tail=tail)
1797
2401
 
1798
2402
 
2403
def _save_config_updates(endpoint: Optional[str] = None,
                         service_account_token: Optional[str] = None) -> None:
    """Write the API server endpoint and/or token to the user config file.

    The file is created if it does not exist yet. The update is performed
    while holding the ``.lock`` file next to the config, and the in-memory
    config is reloaded afterwards.

    Args:
        endpoint: New API server endpoint. When given, the whole
            ``api_server`` section is replaced by one that only holds the
            endpoint.
        service_account_token: Service account token to store under
            ``api_server``. Left untouched when None.
    """
    user_config_file = pathlib.Path(
        skypilot_config.get_user_config_path()).expanduser()
    with filelock.FileLock(user_config_file.with_suffix('.lock')):
        updated: Dict[str, Any]
        if user_config_file.exists():
            updated = dict(skypilot_config.get_user_config())
        else:
            user_config_file.touch()
            updated = {}

        if endpoint is not None:
            # Start from a clean api_server section so that a legacy
            # service account token is never carried over to a new endpoint.
            updated['api_server'] = {'endpoint': endpoint}

        if service_account_token is not None:
            api_server_section = updated.setdefault('api_server', {})
            api_server_section['service_account_token'] = (
                service_account_token)

        yaml_utils.dump_yaml(str(user_config_file), updated)
        skypilot_config.reload_config()
2432
+
2433
+
2434
def _clear_api_server_config() -> None:
    """Remove the ``api_server`` section from the user config file.

    Does nothing when the config file does not exist. Otherwise the file is
    rewritten under the ``.lock`` file next to it and the in-memory config
    is reloaded.
    """
    user_config_file = pathlib.Path(
        skypilot_config.get_user_config_path()).expanduser()
    with filelock.FileLock(user_config_file.with_suffix('.lock')):
        if user_config_file.exists():
            remaining = dict(skypilot_config.get_user_config())
            # The section may never have been written to the file, so drop
            # it only if it is present.
            remaining.pop('api_server', None)

            yaml_utils.dump_yaml(str(user_config_file),
                                 remaining,
                                 blank=True)
            skypilot_config.reload_config()
2451
+
2452
+
2453
+ def _validate_endpoint(endpoint: Optional[str]) -> str:
2454
+ """Validate and normalize the endpoint URL."""
2455
+ if endpoint is None:
2456
+ endpoint = click.prompt('Enter your SkyPilot API server endpoint')
2457
+ # Check endpoint is a valid URL
2458
+ if (endpoint is not None and not endpoint.startswith('http://') and
2459
+ not endpoint.startswith('https://')):
2460
+ raise click.BadParameter('Endpoint must be a valid URL.')
2461
+ return endpoint.rstrip('/')
2462
+
2463
+
2464
def _check_endpoint_in_env_var(is_login: bool) -> None:
    """Check if the endpoint is set in the environment variable.

    Login and logout edit the config file; when the endpoint comes from the
    environment variable we cannot disambiguate between the two sources, so
    the operation is refused.

    Args:
        is_login: True when called for login, False for logout. Only
            affects the wording of the error message.

    Raises:
        RuntimeError: If ``constants.SKY_API_SERVER_URL_ENV_VAR`` is present
            in the environment.
    """
    if constants.SKY_API_SERVER_URL_ENV_VAR not in os.environ:
        return

    if is_login:
        action = 'login to'
    else:
        action = 'logout of'
    with ux_utils.print_exception_no_traceback():
        raise RuntimeError(f'Cannot {action} API server when the endpoint '
                           'is set via the environment variable. Run unset '
                           f'{constants.SKY_API_SERVER_URL_ENV_VAR} to '
                           'clear the environment variable.')
2476
+
2477
+
1799
2478
  @usage_lib.entrypoint
1800
2479
  @annotations.client_api
1801
- def api_login(endpoint: Optional[str] = None) -> None:
2480
+ def api_login(endpoint: Optional[str] = None,
2481
+ relogin: bool = False,
2482
+ service_account_token: Optional[str] = None) -> None:
1802
2483
  """Logs into a SkyPilot API server.
1803
2484
 
1804
2485
  This sets the endpoint globally, i.e., all SkyPilot CLI and SDK calls will
@@ -1810,37 +2491,262 @@ def api_login(endpoint: Optional[str] = None) -> None:
1810
2491
  Args:
1811
2492
  endpoint: The endpoint of the SkyPilot API server, e.g.,
1812
2493
  http://1.2.3.4:46580 or https://skypilot.mydomain.com.
2494
+ relogin: Whether to force relogin with OAuth2 when enabled.
2495
+ service_account_token: Service account token for authentication.
1813
2496
 
1814
2497
  Returns:
1815
2498
  None
1816
2499
  """
1817
- # TODO(zhwu): this SDK sets global endpoint, which may not be the best
1818
- # design as a user may expect this is only effective for the current
1819
- # session. We should consider using env var for specifying endpoint.
1820
- if endpoint is None:
1821
- endpoint = click.prompt('Enter your SkyPilot API server endpoint')
1822
- # Check endpoint is a valid URL
1823
- if (endpoint is not None and not endpoint.startswith('http://') and
1824
- not endpoint.startswith('https://')):
1825
- raise click.BadParameter('Endpoint must be a valid URL.')
1826
-
1827
- server_common.check_server_healthy(endpoint)
1828
-
1829
- # Set the endpoint in the config file
1830
- config_path = pathlib.Path(
1831
- skypilot_config.get_user_config_path()).expanduser()
1832
- with filelock.FileLock(config_path.with_suffix('.lock')):
1833
- if not config_path.exists():
1834
- config_path.touch()
1835
- config = {'api_server': {'endpoint': endpoint}}
2500
+ _check_endpoint_in_env_var(is_login=True)
2501
+
2502
+ # Validate and normalize endpoint
2503
+ endpoint = _validate_endpoint(endpoint)
2504
+
2505
+ def _show_logged_in_message(
2506
+ endpoint: str, dashboard_url: str, user: Optional[Dict[str, Any]],
2507
+ server_status: server_common.ApiServerStatus) -> None:
2508
+ """Show the logged in message."""
2509
+ if server_status != server_common.ApiServerStatus.HEALTHY:
2510
+ with ux_utils.print_exception_no_traceback():
2511
+ raise ValueError(f'Cannot log in API server at '
2512
+ f'{endpoint} (status: {server_status.value})')
2513
+
2514
+ identity_info = f'\n{ux_utils.INDENT_SYMBOL}{colorama.Fore.GREEN}User: '
2515
+ if user:
2516
+ user_name = user.get('name')
2517
+ user_id = user.get('id')
2518
+ if user_name and user_id:
2519
+ identity_info += f'{user_name} ({user_id})'
2520
+ elif user_id:
2521
+ identity_info += user_id
1836
2522
  else:
1837
- config = skypilot_config.get_user_config()
1838
- config.set_nested(('api_server', 'endpoint'), endpoint)
1839
- common_utils.dump_yaml(str(config_path), dict(config))
1840
- dashboard_url = server_common.get_dashboard_url(endpoint)
2523
+ identity_info = ''
1841
2524
  dashboard_msg = f'Dashboard: {dashboard_url}'
1842
2525
  click.secho(
1843
2526
  f'Logged into SkyPilot API server at: {endpoint}'
2527
+ f'{identity_info}'
1844
2528
  f'\n{ux_utils.INDENT_LAST_SYMBOL}{colorama.Fore.GREEN}'
1845
2529
  f'{dashboard_msg}',
1846
2530
  fg='green')
2531
+
2532
+ def _set_user_hash(user_hash: Optional[str]) -> None:
2533
+ if user_hash is not None:
2534
+ if not common_utils.is_valid_user_hash(user_hash):
2535
+ raise ValueError(f'Invalid user hash: {user_hash}')
2536
+ common_utils.set_user_hash_locally(user_hash)
2537
+
2538
+ # Handle service account token authentication
2539
+ if service_account_token:
2540
+ if not service_account_token.startswith('sky_'):
2541
+ raise ValueError('Invalid service account token format. '
2542
+ 'Token must start with "sky_"')
2543
+
2544
+ # Save both endpoint and token to config in a single operation
2545
+ _save_config_updates(endpoint=endpoint,
2546
+ service_account_token=service_account_token)
2547
+
2548
+ # Test the authentication by checking server health
2549
+ try:
2550
+ server_status, api_server_info = server_common.check_server_healthy(
2551
+ endpoint)
2552
+ dashboard_url = server_common.get_dashboard_url(endpoint)
2553
+ if api_server_info.user is not None:
2554
+ _set_user_hash(api_server_info.user.get('id'))
2555
+ _show_logged_in_message(endpoint, dashboard_url,
2556
+ api_server_info.user, server_status)
2557
+
2558
+ return
2559
+ except exceptions.ApiServerConnectionError as e:
2560
+ with ux_utils.print_exception_no_traceback():
2561
+ raise RuntimeError(
2562
+ f'Failed to connect to API server at {endpoint}: {e}'
2563
+ ) from e
2564
+ except Exception as e: # pylint: disable=broad-except
2565
+ with ux_utils.print_exception_no_traceback():
2566
+ raise RuntimeError(
2567
+ f'{colorama.Fore.RED}Service account token authentication '
2568
+ f'failed:{colorama.Style.RESET_ALL} {e}') from None
2569
+
2570
+ # OAuth2/cookie-based authentication flow
2571
+ # TODO(zhwu): this SDK sets global endpoint, which may not be the best
2572
+ # design as a user may expect this is only effective for the current
2573
+ # session. We should consider using env var for specifying endpoint.
2574
+
2575
+ server_status, api_server_info = server_common.check_server_healthy(
2576
+ endpoint)
2577
+ if server_status == server_common.ApiServerStatus.NEEDS_AUTH or relogin:
2578
+ # We detected an auth proxy, so go through the auth proxy cookie flow.
2579
+ token: Optional[str] = None
2580
+ server: Optional[oauth_lib.HTTPServer] = None
2581
+ try:
2582
+ callback_port = common_utils.find_free_port(8000)
2583
+
2584
+ token_container: Dict[str, Optional[str]] = {'token': None}
2585
+ logger.debug('Starting local authentication server...')
2586
+ server = oauth_lib.start_local_auth_server(callback_port,
2587
+ token_container,
2588
+ endpoint)
2589
+
2590
+ token_url = (f'{endpoint}/token?local_port={callback_port}')
2591
+ if webbrowser.open(token_url):
2592
+ click.echo(f'{colorama.Fore.GREEN}A web browser has been '
2593
+ f'opened at {token_url}. Please continue the login '
2594
+ f'in the web browser.{colorama.Style.RESET_ALL}\n'
2595
+ f'{colorama.Style.DIM}To manually copy the token, '
2596
+ f'press ctrl+c.{colorama.Style.RESET_ALL}')
2597
+ else:
2598
+ raise ValueError('Failed to open browser.')
2599
+
2600
+ start_time = time.time()
2601
+
2602
+ while (token_container['token'] is None and
2603
+ time.time() - start_time < oauth_lib.AUTH_TIMEOUT):
2604
+ time.sleep(1)
2605
+
2606
+ if token_container['token'] is None:
2607
+ click.echo(f'{colorama.Fore.YELLOW}Authentication timed out '
2608
+ f'after {oauth_lib.AUTH_TIMEOUT} seconds.')
2609
+ else:
2610
+ token = token_container['token']
2611
+
2612
+ except (Exception, KeyboardInterrupt) as e: # pylint: disable=broad-except
2613
+ logger.debug(f'Automatic authentication failed: {e}, '
2614
+ 'falling back to manual token entry.')
2615
+ if isinstance(e, KeyboardInterrupt):
2616
+ click.echo(f'\n{colorama.Style.DIM}Interrupted. Press ctrl+c '
2617
+ f'again to exit.{colorama.Style.RESET_ALL}')
2618
+ # Fall back to manual token entry
2619
+ token_url = f'{endpoint}/token'
2620
+ click.echo('Authentication is needed. Please visit this URL '
2621
+ f'to set up the token:{colorama.Style.BRIGHT}\n\n'
2622
+ f'{token_url}\n{colorama.Style.RESET_ALL}')
2623
+ token = click.prompt('Paste the token')
2624
+ finally:
2625
+ if server is not None:
2626
+ try:
2627
+ server.server_close()
2628
+ except Exception: # pylint: disable=broad-except
2629
+ pass
2630
+ if not token:
2631
+ with ux_utils.print_exception_no_traceback():
2632
+ raise ValueError('Authentication failed.')
2633
+
2634
+ # Parse the token.
2635
+ # b64decode will ignore invalid characters, but does some length and
2636
+ # padding checks.
2637
+ try:
2638
+ data = base64.b64decode(token)
2639
+ except binascii.Error as e:
2640
+ raise ValueError(f'Malformed token: {token}') from e
2641
+ logger.debug(f'Token data: {data!r}')
2642
+ try:
2643
+ json_data = json.loads(data)
2644
+ except (json.JSONDecodeError, UnicodeDecodeError) as e:
2645
+ raise ValueError(f'Malformed token data: {data!r}') from e
2646
+ if not isinstance(json_data, dict):
2647
+ raise ValueError(f'Malformed token JSON: {json_data}')
2648
+
2649
+ if json_data.get('v') == 1:
2650
+ user_hash = json_data.get('user')
2651
+ cookie_dict = json_data['cookies']
2652
+ elif 'v' not in json_data:
2653
+ user_hash = None
2654
+ cookie_dict = json_data
2655
+ else:
2656
+ raise ValueError(f'Unsupported token version: {json_data.get("v")}')
2657
+
2658
+ parsed_url = urlparse.urlparse(endpoint)
2659
+ cookie_jar = cookiejar.MozillaCookieJar()
2660
+ for (name, value) in cookie_dict.items():
2661
+ # dict keys in JSON must be strings
2662
+ assert isinstance(name, str)
2663
+ if not isinstance(value, str):
2664
+ raise ValueError('Malformed token - bad key/value: '
2665
+ f'{name}: {value}')
2666
+
2667
+ # See CookieJar._cookie_from_cookie_tuple
2668
+ # oauth2proxy default is Max-Age 604800
2669
+ expires = int(time.time()) + 604800
2670
+ domain = str(parsed_url.hostname)
2671
+ domain_initial_dot = domain.startswith('.')
2672
+ secure = parsed_url.scheme == 'https'
2673
+ if not domain_initial_dot:
2674
+ domain = '.' + domain
2675
+
2676
+ cookie_jar.set_cookie(
2677
+ cookiejar.Cookie(
2678
+ version=0,
2679
+ name=name,
2680
+ value=value,
2681
+ port=None,
2682
+ port_specified=False,
2683
+ domain=domain,
2684
+ domain_specified=True,
2685
+ domain_initial_dot=domain_initial_dot,
2686
+ path='',
2687
+ path_specified=False,
2688
+ secure=secure,
2689
+ expires=expires,
2690
+ discard=False,
2691
+ comment=None,
2692
+ comment_url=None,
2693
+ rest=dict(),
2694
+ ))
2695
+
2696
+ # Now that the cookies are parsed, save them to the cookie jar.
2697
+ server_common.set_api_cookie_jar(cookie_jar)
2698
+
2699
+ # Set the user hash in the local file.
2700
+ # If the server already has a token for this user set it to the local
2701
+ # file, otherwise use the new user hash.
2702
+ if (api_server_info.user is not None and
2703
+ api_server_info.user.get('id') is not None):
2704
+ _set_user_hash(api_server_info.user.get('id'))
2705
+ else:
2706
+ _set_user_hash(user_hash)
2707
+ else:
2708
+ # Check if basic auth is enabled
2709
+ if api_server_info.basic_auth_enabled:
2710
+ if api_server_info.user is None:
2711
+ with ux_utils.print_exception_no_traceback():
2712
+ raise ValueError(
2713
+ 'Basic auth is enabled but no valid user is found')
2714
+
2715
+ # Set the user hash in the local file.
2716
+ if api_server_info.user is not None:
2717
+ _set_user_hash(api_server_info.user.get('id'))
2718
+
2719
+ # Set the endpoint in the config file
2720
+ _save_config_updates(endpoint=endpoint)
2721
+ dashboard_url = server_common.get_dashboard_url(endpoint)
2722
+
2723
+ # see https://github.com/python/mypy/issues/5107 on why
2724
+ # typing is disabled on this line
2725
+ server_common.get_api_server_status.cache_clear() # type: ignore
2726
+ # After successful authentication, check server health again to get user
2727
+ # identity
2728
+ server_status, final_api_server_info = server_common.check_server_healthy(
2729
+ endpoint)
2730
+ _show_logged_in_message(endpoint, dashboard_url, final_api_server_info.user,
2731
+ server_status)
2732
+
2733
+
2734
@usage_lib.entrypoint
@annotations.client_api
def api_logout() -> None:
    """Logout of the API server.

    Replaces the stored cookie jar with an empty one and removes the
    ``api_server`` section from the user config file.

    Raises:
        RuntimeError: If the endpoint is set via the environment variable,
            or if the configured API server is the local one (which has to
            be stopped rather than logged out of).
    """
    _check_endpoint_in_env_var(is_login=False)

    is_local_server = server_common.is_api_server_local()
    if is_local_server:
        with ux_utils.print_exception_no_traceback():
            raise RuntimeError('Local api server cannot be logged out. '
                               'Use `sky api stop` instead.')

    # Overwrite any stored cookies with an empty jar. There is nothing to
    # clear if no cookie file exists, hence create_if_not_exists=False.
    empty_jar = cookiejar.MozillaCookieJar()
    server_common.set_api_cookie_jar(empty_jar, create_if_not_exists=False)
    _clear_api_server_config()
    goodbye = (f'{colorama.Fore.GREEN}Logged out of SkyPilot API server.'
               f'{colorama.Style.RESET_ALL}')
    logger.info(goodbye)